Posted to commits@tvm.apache.org by lm...@apache.org on 2020/10/11 09:51:36 UTC
[incubator-tvm] branch master updated: [AutoScheduler] Improve test cases (#6657)
This is an automated email from the ASF dual-hosted git repository.
lmzheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new dd60d24 [AutoScheduler] Improve test cases (#6657)
dd60d24 is described below
commit dd60d249e50f29c3cb34704693ba54803ae75a36
Author: Lianmin Zheng <li...@gmail.com>
AuthorDate: Sun Oct 11 02:51:22 2020 -0700
[AutoScheduler] Improve test cases (#6657)
* Improve test cases
* update
* fix lint
* fix lint
* trigger CI
* address comments
* trigger CI
---
python/tvm/auto_scheduler/__init__.py | 9 ++-
python/tvm/auto_scheduler/auto_schedule.py | 34 ++-------
python/tvm/auto_scheduler/measure.py | 2 +-
python/tvm/auto_scheduler/measure_record.py | 53 +++++++++++++-
python/tvm/auto_scheduler/search_policy.py | 2 +-
python/tvm/auto_scheduler/search_task.py | 47 ++++++++++++
.../search_policy/sketch_policy_rules.cc | 6 +-
.../python/unittest/test_auto_scheduler_common.py | 1 +
.../unittest/test_auto_scheduler_cost_model.py | 16 ++--
.../test_auto_scheduler_evolutionary_search.py | 19 ++---
.../unittest/test_auto_scheduler_layout_rewrite.py | 17 ++---
.../python/unittest/test_auto_scheduler_measure.py | 85 ++++++++++++++--------
.../unittest/test_auto_scheduler_search_policy.py | 10 +--
.../test_auto_scheduler_sketch_generation.py | 4 +-
14 files changed, 202 insertions(+), 103 deletions(-)
diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py
index 2b36287..6a395e7 100644
--- a/python/tvm/auto_scheduler/__init__.py
+++ b/python/tvm/auto_scheduler/__init__.py
@@ -18,15 +18,17 @@
""" Namespace for TVM Auto-scheduler. """
from . import compute_dag
+from . import feature
+from . import loop_state
from . import measure
from . import measure_record
-from . import loop_state
+from . import search_policy
+from . import search_task
from . import utils
from . import workload_registry
-from . import feature
# Shortcut
-from .auto_schedule import SearchTask, TuningOptions, HardwareParams, create_task, auto_schedule
+from .auto_schedule import TuningOptions, HardwareParams, create_task, auto_schedule
from .compute_dag import ComputeDAG
from .cost_model import RandomModel, XGBModel
from .measure import (
@@ -38,5 +40,6 @@ from .measure import (
LocalRPCMeasureContext,
)
from .measure_record import RecordToFile, RecordReader, load_best, load_records, save_records
+from .search_task import SearchTask
from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
from .workload_registry import register_workload, make_workload_key
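The reorganization above moves SearchTask out of auto_schedule.py (see the new search_task.py later in this diff) while keeping the package-level shortcut intact, so user code that imports it from the top-level namespace is unaffected. A minimal sketch of the preserved API surface, assuming a TVM build that includes this commit:

    from tvm import auto_scheduler

    # The shortcut import and the new module path resolve to the same class.
    assert auto_scheduler.SearchTask is auto_scheduler.search_task.SearchTask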
diff --git a/python/tvm/auto_scheduler/auto_schedule.py b/python/tvm/auto_scheduler/auto_schedule.py
index d8763db..ca069bb 100644
--- a/python/tvm/auto_scheduler/auto_schedule.py
+++ b/python/tvm/auto_scheduler/auto_schedule.py
@@ -30,11 +30,13 @@ Candidate schedules are measured against the specific hardware target.
import tvm._ffi
from tvm.runtime import Object
+from tvm.target import Target
from .measure import LocalBuilder, LocalRunner
from .workload_registry import make_workload_key
from .compute_dag import ComputeDAG
from .cost_model import XGBModel
from .search_policy import SketchPolicy
+from .search_task import SearchTask
from . import _ffi_api
@@ -61,30 +63,6 @@ class HardwareParams(Object):
)
-@tvm._ffi.register_object("auto_scheduler.SearchTask")
-class SearchTask(Object):
- """The computation information and hardware parameters for a schedule search task.
-
- Parameters
- ----------
- dag : ComputeDAG
- The ComputeDAG for the corresponding compute declaration.
- workload_key : str
- The workload key for the corresponding compute declaration.
- target : tvm.target.Target
- The target device of this search task.
- target_host : Optional[tvm.target.Target]
- The target host device of this search task.
- hardware_params : Optional[HardwareParams]
- Hardware parameters used in this search task.
- """
-
- def __init__(self, dag, workload_key, target, target_host=None, hardware_params=None):
- self.__init_handle_by_constructor__(
- _ffi_api.SearchTask, dag, workload_key, target, target_host, hardware_params
- )
-
-
@tvm._ffi.register_object("auto_scheduler.TuningOptions")
class TuningOptions(Object):
"""This controls the options of performance tuning.
@@ -169,9 +147,9 @@ def create_task(func, args, target, target_host=None, hardware_params=None):
Can be a function or the function name.
args : Union[Tuple[Any, ...], List[Any]]
The args of the function.
- target : tvm.target.Target
+ target : Union[tvm.target.Target, str]
The target device of this search task.
- target_host : Optional[tvm.target.Target]
+ target_host : Optional[Union[tvm.target.Target, str]]
The target host device of this search task.
hardware_params : Optional[HardwareParams]
Hardware parameters used in this search task.
@@ -182,6 +160,10 @@ def create_task(func, args, target, target_host=None, hardware_params=None):
"""
workload_key = make_workload_key(func, args)
dag = ComputeDAG(workload_key)
+ if isinstance(target, str):
+ target = Target(target)
+ if isinstance(target_host, str):
+ target_host = Target(target_host)
return SearchTask(dag, workload_key, target, target_host, hardware_params)
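With this change, create_task accepts either tvm.target.Target objects or plain strings for target and target_host, converting strings internally. A minimal usage sketch (matmul_auto_scheduler_test is the registered test workload used by the test files later in this diff; importing it assumes those test files are on the path):

    from tvm import auto_scheduler
    from test_auto_scheduler_common import matmul_auto_scheduler_test

    # "llvm" is passed as a plain string and wrapped in Target("llvm") internally.
    task = auto_scheduler.create_task(matmul_auto_scheduler_test, (128, 128, 128), "llvm")
    print(task.workload_key)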
diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 7648ebe..81c314f 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -25,7 +25,7 @@ We separate the measurement into two steps: build and run.
A builder builds the executable binary files and a runner runs the binary files to
get the measurement results. The flow of data structures is
- . `ProgramBuilder` `ProgramRunner`
+ . `ProgramBuilder` `ProgramRunner`
`MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult`
We implement these in python to utilize python's multiprocessing and error handling.
diff --git a/python/tvm/auto_scheduler/measure_record.py b/python/tvm/auto_scheduler/measure_record.py
index c7ae196..1d0d765 100644
--- a/python/tvm/auto_scheduler/measure_record.py
+++ b/python/tvm/auto_scheduler/measure_record.py
@@ -21,7 +21,9 @@ import numpy as np
import tvm._ffi
from tvm.runtime import Object
-from .measure import MeasureCallback, MeasureErrorNo
+from .compute_dag import ComputeDAG
+from .measure import MeasureErrorNo, MeasureInput, MeasureCallback
+from .search_task import SearchTask
from . import _ffi_api
@@ -70,6 +72,13 @@ class RecordReader(Object):
The MeasureInputs loaded from the log file.
results : List[auto_scheduler.measure.MeasureResult]
The MeasureResults loaded from the log file.
+
+ Notes
+ -----
+ Some unimportant and expensive fields in the returned MeasureInput are not deserialized
+ for faster read speed (e.g., input.task.compute_dag, input.state.stages).
+ If you want to use them, you can call :code:`recover_measure_input` below
+ to rebuild these fields.
"""
inputs, results = _ffi_api.RecordReaderReadLines(
self, max_lines if max_lines else -1, skip_lines
@@ -96,6 +105,13 @@ def load_records(filename):
Returns
-------
logs : List[auto_scheduler.measure.MeasureInput, auto_scheduler.measure.MeasureResult]
+
+ Notes
+ -----
+ Some unimportant and expensive fields in the returned MeasureInput are not deserialized
+ for faster read speed (e.g., input.task.compute_dag, input.state.stages).
+ If you want to use them, you can call :code:`recover_measure_input` below
+ to rebuild these fields.
"""
return zip(*RecordReader(filename).read_lines())
@@ -159,3 +175,38 @@ def load_best(filename, workload_key=None, target=None):
best_res = res
return best_inp, best_res
+
+
+def recover_measure_input(inp, rebuild_state=False):
+ """
+ Recover a deserialized MeasureInput by rebuilding the missing fields.
+ 1. Rebuild the compute_dag in inp.task
+ 2. (Optional) Rebuild the stages in inp.state
+
+ Parameters
+ ----------
+ inp: MeasureInput
+ The deserialized MeasureInput
+ rebuild_state: bool = False
+ Whether to rebuild the stages in MeasureInput.State
+
+ Returns
+ -------
+ new_input: MeasureInput
+ The fully recovered MeasureInput with all fields rebuilt.
+ """
+ task = inp.task
+ new_task = SearchTask(
+ ComputeDAG(task.workload_key),
+ task.workload_key,
+ task.target,
+ task.target_host,
+ task.hardware_params,
+ )
+
+ if rebuild_state:
+ new_state = new_task.compute_dag.infer_bound_from_state(inp.state)
+ else:
+ new_state = inp.state
+
+ return MeasureInput(new_task, new_state)
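recover_measure_input pairs with the faster record deserialization documented in the Notes above: read the records quickly, then rebuild the skipped fields only for the inputs you actually need. A sketch, assuming log_file points at a record file previously written by RecordToFile:

    from tvm import auto_scheduler
    from tvm.auto_scheduler.measure_record import recover_measure_input

    log_file = "matmul_tuning.json"  # hypothetical record file
    for inp, _ in auto_scheduler.load_records(log_file):
        # inp.task.compute_dag (and optionally the stages in inp.state) are
        # skipped during deserialization; rebuild them before use.
        full_inp = recover_measure_input(inp, rebuild_state=True)
        print(full_inp.task.compute_dag)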
diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py
index bf7e2eb..5533aec 100644
--- a/python/tvm/auto_scheduler/search_policy.py
+++ b/python/tvm/auto_scheduler/search_policy.py
@@ -123,7 +123,7 @@ class SketchPolicy(SearchPolicy):
"gpu_multi_level_tiling_structure": "SSSRRSRS",
# Notice: the default thread bind policy of GPU assumes the tiling structure to have at
# least 3 spatial tiling levels in outermost
- "max_innermost_split_factor": 16,
+ "max_innermost_split_factor": 64,
"max_vectorize_size": 16,
"disable_change_compute_location": 0,
}
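Raising max_innermost_split_factor from 16 to 64 widens the factorization space that InitFillTileSize samples from (see the C++ change below). The value remains overridable per policy instance; a sketch, assuming an existing SearchTask named task:

    from tvm import auto_scheduler

    # Restore the previous bound for a single policy instance.
    policy = auto_scheduler.SketchPolicy(
        task, params={"max_innermost_split_factor": 16}
    )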
diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py
new file mode 100644
index 0000000..92c4f48
--- /dev/null
+++ b/python/tvm/auto_scheduler/search_task.py
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" The definiton of SearchTask """
+
+import tvm._ffi
+from tvm.runtime import Object
+
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("auto_scheduler.SearchTask")
+class SearchTask(Object):
+ """The computation information and hardware parameters for a schedule search task.
+
+ Parameters
+ ----------
+ dag : ComputeDAG
+ The ComputeDAG for the corresponding compute declaration.
+ workload_key : str
+ The workload key for the corresponding compute declaration.
+ target : tvm.target.Target
+ The target device of this search task.
+ target_host : Optional[tvm.target.Target]
+ The target host device of this search task.
+ hardware_params : Optional[HardwareParams]
+ Hardware parameters used in this search task.
+ """
+
+ def __init__(self, dag, workload_key, target, target_host=None, hardware_params=None):
+ self.__init_handle_by_constructor__(
+ _ffi_api.SearchTask, dag, workload_key, target, target_host, hardware_params
+ )
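SearchTask itself is unchanged; it only moved out of auto_schedule.py, presumably so that measure_record.py can import it without a circular dependency. Direct construction still works exactly as before; a sketch mirroring the pattern the test files below are migrating away from:

    import tvm
    from tvm import auto_scheduler
    from test_auto_scheduler_common import matmul_auto_scheduler_test

    workload_key = auto_scheduler.make_workload_key(
        matmul_auto_scheduler_test, (128, 128, 128)
    )
    dag = auto_scheduler.ComputeDAG(workload_key)
    task = auto_scheduler.SearchTask(dag, workload_key, tvm.target.Target("llvm"))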
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
index c8370d6..045ee86 100644
--- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
@@ -441,6 +441,9 @@ std::vector<std::pair<State, int>> RuleSpecialComputeLocationGPU::Apply(
PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, State* state,
std::mt19937* rand_gen) const {
+ int max_innermost_split_factor =
+ GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor);
+
StateNode* pstate = state->CopyOnWrite();
// Scan the transformation history and randomly fill tiles size for all SplitStep
for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) {
@@ -459,8 +462,7 @@ PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* p
CHECK(ps->extent);
int extent = GetIntImm(ps->extent.value());
const auto& candidate_lens = policy->split_memo.GetFactorizationSchemes(
- extent, ps->lengths.size(),
- GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor));
+ extent, ps->lengths.size(), max_innermost_split_factor);
const auto& candidate_lengths = candidate_lens[(*rand_gen)() % candidate_lens.size()];
pstate->transform_steps.Set(
diff --git a/tests/python/unittest/test_auto_scheduler_common.py b/tests/python/unittest/test_auto_scheduler_common.py
index eaf328c..880b112 100644
--- a/tests/python/unittest/test_auto_scheduler_common.py
+++ b/tests/python/unittest/test_auto_scheduler_common.py
@@ -216,6 +216,7 @@ def conv2d_winograd_nhwc_auto_scheduler_test(
def get_tiled_matmul():
+ """Get a compute dag and a state for tiled matmul"""
A, B, C = matmul_auto_scheduler_test(512, 512, 512)
dag = auto_scheduler.ComputeDAG([A, B, C])
diff --git a/tests/python/unittest/test_auto_scheduler_cost_model.py b/tests/python/unittest/test_auto_scheduler_cost_model.py
index a28618c..62acb6b 100644
--- a/tests/python/unittest/test_auto_scheduler_cost_model.py
+++ b/tests/python/unittest/test_auto_scheduler_cost_model.py
@@ -28,12 +28,9 @@ from test_auto_scheduler_common import matmul_auto_scheduler_test
def get_sample_records(number):
- """Generate random a list of random MeasureInput and MeasureResult pairs"""
+ """Generate a list of random MeasureInput and MeasureResult pairs"""
N = 128
- workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (N, N, N))
- dag = auto_scheduler.ComputeDAG(workload_key)
- target = tvm.target.Target("llvm")
- task = auto_scheduler.SearchTask(dag, workload_key, target)
+ task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), "llvm")
policy = auto_scheduler.SketchPolicy(task, verbose=0)
states = policy.sample_initial_population(number)
@@ -43,11 +40,11 @@ def get_sample_records(number):
for _ in range(len(inputs))
]
- return task, dag, inputs, results
+ return task, inputs, results
def test_random_model():
- task, dag, inputs, results = get_sample_records(50)
+ task, inputs, results = get_sample_records(50)
model = auto_scheduler.RandomModel()
model.update(inputs, results)
@@ -56,7 +53,7 @@ def test_random_model():
def test_xgb_model():
- task, dag, inputs, results = get_sample_records(50)
+ task, inputs, results = get_sample_records(50)
model = auto_scheduler.XGBModel(num_warmup_sample=-1)
model.update(inputs, results)
@@ -66,13 +63,16 @@ def test_xgb_model():
costs = [np.mean([x.value for x in res.costs]) for res in results]
throughputs = np.min(costs) / costs
+ # test regression quality
rmse = np.sqrt(np.mean([np.square(pred - label) for pred, label in zip(preds, throughputs)]))
assert rmse <= 0.3
+ # test loading a record file
with tempfile.NamedTemporaryFile() as fp:
auto_scheduler.save_records(fp.name, inputs, results)
model.update_from_file(fp.name)
+ # test model serialization
with tempfile.NamedTemporaryFile() as fp:
model.save(fp.name)
model.load(fp.name)
diff --git a/tests/python/unittest/test_auto_scheduler_evolutionary_search.py b/tests/python/unittest/test_auto_scheduler_evolutionary_search.py
index b51066b..9fec6f1 100644
--- a/tests/python/unittest/test_auto_scheduler_evolutionary_search.py
+++ b/tests/python/unittest/test_auto_scheduler_evolutionary_search.py
@@ -69,7 +69,6 @@ def test_mutate_tile_size():
assert found
-@pytest.mark.skip(reason="flaky")
def test_mutate_parallel():
"""
The test case initializes evo search with a batch of "bad" states and checks whether
@@ -95,20 +94,18 @@ def test_mutate_parallel():
scores.append(1 if self.is_good_state(state) else 0)
return scores
- workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (1024, 1024, 1024))
- dag = auto_scheduler.ComputeDAG(workload_key)
- task = auto_scheduler.SearchTask(dag, workload_key, tvm.target.Target("llvm"))
+ task = auto_scheduler.create_task(matmul_auto_scheduler_test, (1024, 1024, 1024), "llvm")
policy = auto_scheduler.SketchPolicy(task, program_cost_model=MockCostModel(), verbose=0)
- states = policy.sample_initial_population(100)
-
- bad_states = []
- for state in states:
- if not MockCostModel.is_good_state(state):
- bad_states.append(state)
found = False
retry_ct = 0
- while retry_ct < 5 and not found:
+ while retry_ct < 10 and not found:
+ states = policy.sample_initial_population(100)
+ bad_states = []
+ for state in states:
+ if not MockCostModel.is_good_state(state):
+ bad_states.append(state)
+
new_states = policy.evolutionary_search(bad_states, 50)
for state in new_states:
if MockCostModel.is_good_state(state):
diff --git a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py
index caa1d6a..3ce7a43 100644
--- a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py
+++ b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py
@@ -41,12 +41,9 @@ def test_apply_steps_with_layout_rewrite():
def test_layout_rewrite_correctness():
N = 128
- target = "llvm"
- workload = matmul_auto_scheduler_test
- workload_key = auto_scheduler.make_workload_key(workload, (N, N, N))
- dag = auto_scheduler.ComputeDAG(workload_key)
- target = tvm.target.Target(target)
- task = auto_scheduler.SearchTask(dag, workload_key, target)
+ target = tvm.target.Target("llvm")
+ task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), target)
+ dag = task.compute_dag
with tempfile.NamedTemporaryFile() as fp:
log_file = fp.name
@@ -60,7 +57,7 @@ def test_layout_rewrite_correctness():
measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)
auto_scheduler.auto_schedule(task, search_policy, tuning_options)
- inp, _ = auto_scheduler.load_best(log_file, workload_key, target)
+ inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target)
s, bufs = dag.apply_steps_from_state(inp.state, layout_rewrite=True)
s_ref, bufs_ref = dag.apply_steps_from_state(inp.state, layout_rewrite=False)
np_args = [np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs]
@@ -89,10 +86,10 @@ def test_layout_rewrite_correctness():
np_args_ref[1] = np_args_ref[1].transpose(new_order)
np_args_ref[1] = np_args_ref[1].reshape((red_dim, out_dim))
- func = tvm.build(s, bufs, target=inp.task.target, target_host=inp.task.target_host)
- func_ref = tvm.build(s_ref, bufs_ref, target="llvm")
+ func = tvm.build(s, bufs, target=target)
+ func_ref = tvm.build(s_ref, bufs_ref, target=target)
- ctx = tvm.context(str(inp.task.target))
+ ctx = tvm.context(str(target))
ctx_ref = tvm.cpu()
args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py
index 5dae2a5..4369d20 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -167,45 +167,69 @@ def test_record_pragma_storage_align_rfactor():
record_common(dag, s)
-def test_measure_local_builder_runner(enable_cpu_cache_flush=False):
+def test_recover_measure_input():
+ task = auto_scheduler.create_task(matmul_auto_scheduler_test, [512, 512, 512], "llvm")
+
+ inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state)
+ res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1)
+
+ with tempfile.NamedTemporaryFile() as fp:
+ auto_scheduler.save_records(fp.name, [inp], [res])
+
+ log_reader = auto_scheduler.RecordReader(fp.name)
+ inputs, results = log_reader.read_lines()
+ assert len(inputs) == 1
+
+ raw_inp = inputs[0]
+
+ correct_inp = auto_scheduler.measure_record.recover_measure_input(raw_inp)
+ assert str(correct_inp.task.compute_dag) == str(inp.task.compute_dag)
+
+ correct_inp = auto_scheduler.measure_record.recover_measure_input(
+ raw_inp, rebuild_state=True
+ )
+ assert str(correct_inp.state) == str(inp.state)
+
+
+def test_measure_local_builder_runner():
if not tvm.testing.device_enabled("llvm"):
return
- dag, s0 = get_tiled_matmul()
- tgt = tvm.target.Target("llvm")
- task = auto_scheduler.SearchTask(dag, "test", tgt)
+ task = auto_scheduler.create_task(matmul_auto_scheduler_test, [512, 512, 512], "llvm")
- minp = auto_scheduler.MeasureInput(task, s0)
- local_builder = auto_scheduler.LocalBuilder()
- local_runner = auto_scheduler.LocalRunner(
- timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
- )
+ for enable_cpu_cache_flush in [True, False]:
+ minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state)
+ local_builder = auto_scheduler.LocalBuilder()
+ local_runner = auto_scheduler.LocalRunner(
+ timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
+ )
- bress = local_builder.build([minp])
- assert bress[0].error_no == 0
- mress = local_runner.run([minp], bress)
- assert mress[0].error_no == 0
+ bress = local_builder.build([minp])
+ assert bress[0].error_no == 0
+ mress = local_runner.run([minp], bress)
+ assert mress[0].error_no == 0
-def test_measure_local_builder_rpc_runner(enable_cpu_cache_flush=False):
+def test_measure_local_builder_rpc_runner():
if not tvm.testing.device_enabled("llvm"):
return
- dag, s0 = get_tiled_matmul()
- tgt = tvm.target.Target("llvm")
- task = auto_scheduler.SearchTask(dag, "test", tgt)
+ task = auto_scheduler.create_task(matmul_auto_scheduler_test, [512, 512, 512], "llvm")
- minp = auto_scheduler.MeasureInput(task, s0)
- local_builder = auto_scheduler.LocalBuilder()
- measure_ctx = auto_scheduler.LocalRPCMeasureContext(
- timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
- )
- rpc_runner = measure_ctx.runner
+ for enable_cpu_cache_flush in [True, False]:
+ minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state)
+ local_builder = auto_scheduler.LocalBuilder()
+ measure_ctx = auto_scheduler.LocalRPCMeasureContext(
+ timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
+ )
+ rpc_runner = measure_ctx.runner
+
+ bress = local_builder.build([minp])
+ assert bress[0].error_no == 0
+ mress = rpc_runner.run([minp], bress)
+ assert mress[0].error_no == 0
- bress = local_builder.build([minp])
- assert bress[0].error_no == 0
- mress = rpc_runner.run([minp], bress)
- assert mress[0].error_no == 0
+ del measure_ctx
if __name__ == "__main__":
@@ -213,7 +237,6 @@ if __name__ == "__main__":
test_record_compute_at_root_inline_cache_read_write()
test_record_follow_split_follow_fused_split()
test_record_pragma_storage_align_rfactor()
- test_measure_local_builder_runner(enable_cpu_cache_flush=True)
- test_measure_local_builder_runner(enable_cpu_cache_flush=False)
- test_measure_local_builder_rpc_runner(enable_cpu_cache_flush=True)
- test_measure_local_builder_rpc_runner(enable_cpu_cache_flush=False)
+ test_recover_measure_input()
+ test_measure_local_builder_runner()
+ test_measure_local_builder_rpc_runner()
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py
index 04b54b2..07cf4c8 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -42,10 +42,8 @@ def search_common(
random.seed(seed)
N = 128
- workload_key = auto_scheduler.make_workload_key(workload, (N, N, N))
- dag = auto_scheduler.ComputeDAG(workload_key)
target = tvm.target.Target(target)
- task = auto_scheduler.SearchTask(dag, workload_key, target)
+ task = auto_scheduler.create_task(workload, (N, N, N), target)
with tempfile.NamedTemporaryFile() as fp:
log_file = fp.name
@@ -70,10 +68,10 @@ def search_common(
print("*" * 80)
print(target)
print("*" * 80)
- inp, res = auto_scheduler.load_best(log_file, workload_key, target)
+ inp, res = auto_scheduler.load_best(log_file, task.workload_key, target)
print("==== Python Code ====")
- print(dag.print_python_code_from_state(inp.state))
+ print(task.compute_dag.print_python_code_from_state(inp.state))
try:
print("==== Lowered Stmt ====")
@@ -81,7 +79,7 @@ def search_common(
mod = tvm.build(sch, args, target)
ctx = tvm.context(str(target), 0)
- dtype = dag.tensors[0].dtype
+ dtype = task.compute_dag.tensors[0].dtype
a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx)
diff --git a/tests/python/unittest/test_auto_scheduler_sketch_generation.py b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
index 5a687da..47aa78a 100644
--- a/tests/python/unittest/test_auto_scheduler_sketch_generation.py
+++ b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
@@ -36,9 +36,7 @@ from test_auto_scheduler_common import (
def generate_sketches(workload_func, args, target, print_for_debug=False):
- workload_key = auto_scheduler.make_workload_key(workload_func, args)
- dag = auto_scheduler.ComputeDAG(workload_key)
- task = auto_scheduler.SearchTask(dag, workload_key, tvm.target.Target(target))
+ task = auto_scheduler.create_task(workload_func, args, tvm.target.Target(target))
policy = auto_scheduler.SketchPolicy(task, verbose=0)
return policy.generate_sketches(print_for_debug)