You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lm...@apache.org on 2020/10/11 09:51:36 UTC

[incubator-tvm] branch master updated: [AutoScheduler] Improve test cases (#6657)

This is an automated email from the ASF dual-hosted git repository.

lmzheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new dd60d24  [AutoScheduler] Improve test cases (#6657)
dd60d24 is described below

commit dd60d249e50f29c3cb34704693ba54803ae75a36
Author: Lianmin Zheng <li...@gmail.com>
AuthorDate: Sun Oct 11 02:51:22 2020 -0700

    [AutoScheduler] Improve test cases (#6657)
    
    * Improve test cases
    
    * update
    
    * fix lint
    
    * fix lint
    
    * trigger CI
    
    * address comments
    
    * trigger CI
---
 python/tvm/auto_scheduler/__init__.py              |  9 ++-
 python/tvm/auto_scheduler/auto_schedule.py         | 34 ++-------
 python/tvm/auto_scheduler/measure.py               |  2 +-
 python/tvm/auto_scheduler/measure_record.py        | 53 +++++++++++++-
 python/tvm/auto_scheduler/search_policy.py         |  2 +-
 python/tvm/auto_scheduler/search_task.py           | 47 ++++++++++++
 .../search_policy/sketch_policy_rules.cc           |  6 +-
 .../python/unittest/test_auto_scheduler_common.py  |  1 +
 .../unittest/test_auto_scheduler_cost_model.py     | 16 ++--
 .../test_auto_scheduler_evolutionary_search.py     | 19 ++---
 .../unittest/test_auto_scheduler_layout_rewrite.py | 17 ++---
 .../python/unittest/test_auto_scheduler_measure.py | 85 ++++++++++++++--------
 .../unittest/test_auto_scheduler_search_policy.py  | 10 +--
 .../test_auto_scheduler_sketch_generation.py       |  4 +-
 14 files changed, 202 insertions(+), 103 deletions(-)

diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py
index 2b36287..6a395e7 100644
--- a/python/tvm/auto_scheduler/__init__.py
+++ b/python/tvm/auto_scheduler/__init__.py
@@ -18,15 +18,17 @@
 """ Namespace for TVM Auto-scheduler. """
 
 from . import compute_dag
+from . import feature
+from . import loop_state
 from . import measure
 from . import measure_record
-from . import loop_state
+from . import search_policy
+from . import search_task
 from . import utils
 from . import workload_registry
-from . import feature
 
 # Shortcut
-from .auto_schedule import SearchTask, TuningOptions, HardwareParams, create_task, auto_schedule
+from .auto_schedule import TuningOptions, HardwareParams, create_task, auto_schedule
 from .compute_dag import ComputeDAG
 from .cost_model import RandomModel, XGBModel
 from .measure import (
@@ -38,5 +40,6 @@ from .measure import (
     LocalRPCMeasureContext,
 )
 from .measure_record import RecordToFile, RecordReader, load_best, load_records, save_records
+from .search_task import SearchTask
 from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
 from .workload_registry import register_workload, make_workload_key
diff --git a/python/tvm/auto_scheduler/auto_schedule.py b/python/tvm/auto_scheduler/auto_schedule.py
index d8763db..ca069bb 100644
--- a/python/tvm/auto_scheduler/auto_schedule.py
+++ b/python/tvm/auto_scheduler/auto_schedule.py
@@ -30,11 +30,13 @@ Candidate schedules are measured against the specific hardware target.
 
 import tvm._ffi
 from tvm.runtime import Object
+from tvm.target import Target
 from .measure import LocalBuilder, LocalRunner
 from .workload_registry import make_workload_key
 from .compute_dag import ComputeDAG
 from .cost_model import XGBModel
 from .search_policy import SketchPolicy
+from .search_task import SearchTask
 from . import _ffi_api
 
 
@@ -61,30 +63,6 @@ class HardwareParams(Object):
         )
 
 
-@tvm._ffi.register_object("auto_scheduler.SearchTask")
-class SearchTask(Object):
-    """The computation information and hardware parameters for a schedule search task.
-
-    Parameters
-    ----------
-    dag : ComputeDAG
-        The ComputeDAG for the corresponding compute declaration.
-    workload_key : str
-        The workload key for the corresponding compute declaration.
-    target : tvm.target.Target
-        The target device of this search task.
-    target_host : Optional[tvm.target.Target]
-        The target host device of this search task.
-    hardware_params : Optional[HardwareParams]
-        Hardware parameters used in this search task.
-    """
-
-    def __init__(self, dag, workload_key, target, target_host=None, hardware_params=None):
-        self.__init_handle_by_constructor__(
-            _ffi_api.SearchTask, dag, workload_key, target, target_host, hardware_params
-        )
-
-
 @tvm._ffi.register_object("auto_scheduler.TuningOptions")
 class TuningOptions(Object):
     """This controls the options of performance tuning.
@@ -169,9 +147,9 @@ def create_task(func, args, target, target_host=None, hardware_params=None):
         Can be the a function or the function name.
     args : Union[Tuple[Any, ...], List[Any]]
         The args of the function.
-    target : tvm.target.Target
+    target : Union[tvm.target.Target, str]
         The target device of this search task.
-    target_host : Optional[tvm.target.Target]
+    target_host : Optional[Union[tvm.target.Target, str]]
         The target host device of this search task.
     hardware_params : Optional[HardwareParams]
         Hardware parameters used in this search task.
@@ -182,6 +160,10 @@ def create_task(func, args, target, target_host=None, hardware_params=None):
     """
     workload_key = make_workload_key(func, args)
     dag = ComputeDAG(workload_key)
+    if isinstance(target, str):
+        target = Target(target)
+    if isinstance(target_host, str):
+        target_host = Target(target_host)
     return SearchTask(dag, workload_key, target, target_host, hardware_params)
 
 
diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 7648ebe..81c314f 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -25,7 +25,7 @@ We separate the measurement into two steps: build and run.
 A builder builds the executable binary files and a runner runs the binary files to
 get the measurement results. The flow of data structures is
 
-  .                `ProgramBuilder`                 `ProgramRunner`
+  .               `ProgramBuilder`                 `ProgramRunner`
   `MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult`
 
 We implement these in python to utilize python's multiprocessing and error handling.
diff --git a/python/tvm/auto_scheduler/measure_record.py b/python/tvm/auto_scheduler/measure_record.py
index c7ae196..1d0d765 100644
--- a/python/tvm/auto_scheduler/measure_record.py
+++ b/python/tvm/auto_scheduler/measure_record.py
@@ -21,7 +21,9 @@ import numpy as np
 
 import tvm._ffi
 from tvm.runtime import Object
-from .measure import MeasureCallback, MeasureErrorNo
+from .compute_dag import ComputeDAG
+from .measure import MeasureErrorNo, MeasureInput, MeasureCallback
+from .search_task import SearchTask
 from . import _ffi_api
 
 
@@ -70,6 +72,13 @@ class RecordReader(Object):
             The MeasureInputs loaded from the log file.
         results : List[auto_scheduler.measure.MeasureResult]
             The MeasureResults loaded from the log file.
+
+        Notes
+        -----
+        Some unimportant and expensive fields in the returned MeasureInput are not deserialized
+        for faster read speed (e.g. input.task.compute_dag, input.state.stages).
+        If you want to use them, you can call the :code:`recover_measure_input` below
+        to rebuild these fields.
         """
         inputs, results = _ffi_api.RecordReaderReadLines(
             self, max_lines if max_lines else -1, skip_lines
@@ -96,6 +105,13 @@ def load_records(filename):
     Returns
     -------
     logs : List[auto_scheduler.measure.MeasureInput, auto_scheduler.measure.MeasureResult]
+
+    Notes
+    -----
+    Some unimportant and expensive fields in the returned MeasureInput are not deserialized
+    for faster read speed (e.g., input.task.compute_dag, input.state.stages).
+    If you want to use them, you can call the :code:`recover_measure_input` below
+    to rebuild these fields.
     """
     return zip(*RecordReader(filename).read_lines())
 
@@ -159,3 +175,38 @@ def load_best(filename, workload_key=None, target=None):
             best_res = res
 
     return best_inp, best_res
+
+
+def recover_measure_input(inp, rebuild_state=False):
+    """
+    Recover a deserialized MeasureInput by rebuilding the missing fields.
+    1. Rebuild the compute_dag in inp.task
+    2. (Optional) Rebuild the stages in inp.state
+
+    Parameters
+    ----------
+    inp: MeasureInput
+        The deserialized MeasureInput
+    rebuild_state: bool = False
+        Whether to rebuild the stages in MeasureInput.state
+
+    Returns
+    -------
+    new_input: MeasureInput
+        The fully recovered MeasureInput with all fields rebuilt.
+    """
+    task = inp.task
+    new_task = SearchTask(
+        ComputeDAG(task.workload_key),
+        task.workload_key,
+        task.target,
+        task.target_host,
+        task.hardware_params,
+    )
+
+    if rebuild_state:
+        new_state = new_task.compute_dag.infer_bound_from_state(inp.state)
+    else:
+        new_state = inp.state
+
+    return MeasureInput(new_task, new_state)
diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py
index bf7e2eb..5533aec 100644
--- a/python/tvm/auto_scheduler/search_policy.py
+++ b/python/tvm/auto_scheduler/search_policy.py
@@ -123,7 +123,7 @@ class SketchPolicy(SearchPolicy):
         "gpu_multi_level_tiling_structure": "SSSRRSRS",
         # Notice: the default thread bind policy of GPU assumes the tiling structure to have at
         # least 3 spatial tiling levels in outermost
-        "max_innermost_split_factor": 16,
+        "max_innermost_split_factor": 64,
         "max_vectorize_size": 16,
         "disable_change_compute_location": 0,
     }
diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py
new file mode 100644
index 0000000..92c4f48
--- /dev/null
+++ b/python/tvm/auto_scheduler/search_task.py
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" The definition of SearchTask """
+
+import tvm._ffi
+from tvm.runtime import Object
+
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("auto_scheduler.SearchTask")
+class SearchTask(Object):
+    """The computation information and hardware parameters for a schedule search task.
+
+    Parameters
+    ----------
+    dag : ComputeDAG
+        The ComputeDAG for the corresponding compute declaration.
+    workload_key : str
+        The workload key for the corresponding compute declaration.
+    target : tvm.target.Target
+        The target device of this search task.
+    target_host : Optional[tvm.target.Target]
+        The target host device of this search task.
+    hardware_params : Optional[HardwareParams]
+        Hardware parameters used in this search task.
+    """
+
+    def __init__(self, dag, workload_key, target, target_host=None, hardware_params=None):
+        self.__init_handle_by_constructor__(
+            _ffi_api.SearchTask, dag, workload_key, target, target_host, hardware_params
+        )
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
index c8370d6..045ee86 100644
--- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
@@ -441,6 +441,9 @@ std::vector<std::pair<State, int>> RuleSpecialComputeLocationGPU::Apply(
 
 PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, State* state,
                                                              std::mt19937* rand_gen) const {
+  int max_innermost_split_factor =
+      GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor);
+
   StateNode* pstate = state->CopyOnWrite();
   // Scan the transformation history and randomly fill tiles size for all SplitStep
   for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) {
@@ -459,8 +462,7 @@ PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* p
       CHECK(ps->extent);
       int extent = GetIntImm(ps->extent.value());
       const auto& candidate_lens = policy->split_memo.GetFactorizationSchemes(
-          extent, ps->lengths.size(),
-          GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor));
+          extent, ps->lengths.size(), max_innermost_split_factor);
       const auto& candidate_lengths = candidate_lens[(*rand_gen)() % candidate_lens.size()];
 
       pstate->transform_steps.Set(
diff --git a/tests/python/unittest/test_auto_scheduler_common.py b/tests/python/unittest/test_auto_scheduler_common.py
index eaf328c..880b112 100644
--- a/tests/python/unittest/test_auto_scheduler_common.py
+++ b/tests/python/unittest/test_auto_scheduler_common.py
@@ -216,6 +216,7 @@ def conv2d_winograd_nhwc_auto_scheduler_test(
 
 
 def get_tiled_matmul():
+    """Get a compute dag and a state for tiled matmul"""
     A, B, C = matmul_auto_scheduler_test(512, 512, 512)
     dag = auto_scheduler.ComputeDAG([A, B, C])
 
diff --git a/tests/python/unittest/test_auto_scheduler_cost_model.py b/tests/python/unittest/test_auto_scheduler_cost_model.py
index a28618c..62acb6b 100644
--- a/tests/python/unittest/test_auto_scheduler_cost_model.py
+++ b/tests/python/unittest/test_auto_scheduler_cost_model.py
@@ -28,12 +28,9 @@ from test_auto_scheduler_common import matmul_auto_scheduler_test
 
 
 def get_sample_records(number):
-    """Generate random a list of random MeasureInput and MeasureResult pairs"""
+    """Generate a list of random MeasureInput and MeasureResult pairs"""
     N = 128
-    workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (N, N, N))
-    dag = auto_scheduler.ComputeDAG(workload_key)
-    target = tvm.target.Target("llvm")
-    task = auto_scheduler.SearchTask(dag, workload_key, target)
+    task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), "llvm")
     policy = auto_scheduler.SketchPolicy(task, verbose=0)
     states = policy.sample_initial_population(number)
 
@@ -43,11 +40,11 @@ def get_sample_records(number):
         for _ in range(len(inputs))
     ]
 
-    return task, dag, inputs, results
+    return task, inputs, results
 
 
 def test_random_model():
-    task, dag, inputs, results = get_sample_records(50)
+    task, inputs, results = get_sample_records(50)
 
     model = auto_scheduler.RandomModel()
     model.update(inputs, results)
@@ -56,7 +53,7 @@ def test_random_model():
 
 
 def test_xgb_model():
-    task, dag, inputs, results = get_sample_records(50)
+    task, inputs, results = get_sample_records(50)
 
     model = auto_scheduler.XGBModel(num_warmup_sample=-1)
     model.update(inputs, results)
@@ -66,13 +63,16 @@ def test_xgb_model():
     costs = [np.mean([x.value for x in res.costs]) for res in results]
     throughputs = np.min(costs) / costs
 
+    # test regression quality
     rmse = np.sqrt(np.mean([np.square(pred - label) for pred, label in zip(preds, throughputs)]))
     assert rmse <= 0.3
 
+    # test loading a record file
     with tempfile.NamedTemporaryFile() as fp:
         auto_scheduler.save_records(fp.name, inputs, results)
         model.update_from_file(fp.name)
 
+    # test model serialization
     with tempfile.NamedTemporaryFile() as fp:
         model.save(fp.name)
         model.load(fp.name)
diff --git a/tests/python/unittest/test_auto_scheduler_evolutionary_search.py b/tests/python/unittest/test_auto_scheduler_evolutionary_search.py
index b51066b..9fec6f1 100644
--- a/tests/python/unittest/test_auto_scheduler_evolutionary_search.py
+++ b/tests/python/unittest/test_auto_scheduler_evolutionary_search.py
@@ -69,7 +69,6 @@ def test_mutate_tile_size():
     assert found
 
 
-@pytest.mark.skip(reason="flaky")
 def test_mutate_parallel():
     """
     The test case initializes evo search with a batch of "bad" states and check whether
@@ -95,20 +94,18 @@ def test_mutate_parallel():
                 scores.append(1 if self.is_good_state(state) else 0)
             return scores
 
-    workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (1024, 1024, 1024))
-    dag = auto_scheduler.ComputeDAG(workload_key)
-    task = auto_scheduler.SearchTask(dag, workload_key, tvm.target.Target("llvm"))
+    task = auto_scheduler.create_task(matmul_auto_scheduler_test, (1024, 1024, 1024), "llvm")
     policy = auto_scheduler.SketchPolicy(task, program_cost_model=MockCostModel(), verbose=0)
-    states = policy.sample_initial_population(100)
-
-    bad_states = []
-    for state in states:
-        if not MockCostModel.is_good_state(state):
-            bad_states.append(state)
 
     found = False
     retry_ct = 0
-    while retry_ct < 5 and not found:
+    while retry_ct < 10 and not found:
+        states = policy.sample_initial_population(100)
+        bad_states = []
+        for state in states:
+            if not MockCostModel.is_good_state(state):
+                bad_states.append(state)
+
         new_states = policy.evolutionary_search(bad_states, 50)
         for state in new_states:
             if MockCostModel.is_good_state(state):
diff --git a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py
index caa1d6a..3ce7a43 100644
--- a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py
+++ b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py
@@ -41,12 +41,9 @@ def test_apply_steps_with_layout_rewrite():
 
 def test_layout_rewrite_correctness():
     N = 128
-    target = "llvm"
-    workload = matmul_auto_scheduler_test
-    workload_key = auto_scheduler.make_workload_key(workload, (N, N, N))
-    dag = auto_scheduler.ComputeDAG(workload_key)
-    target = tvm.target.Target(target)
-    task = auto_scheduler.SearchTask(dag, workload_key, target)
+    target = tvm.target.Target("llvm")
+    task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), target)
+    dag = task.compute_dag
 
     with tempfile.NamedTemporaryFile() as fp:
         log_file = fp.name
@@ -60,7 +57,7 @@ def test_layout_rewrite_correctness():
             measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
         )
         auto_scheduler.auto_schedule(task, search_policy, tuning_options)
-        inp, _ = auto_scheduler.load_best(log_file, workload_key, target)
+        inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target)
         s, bufs = dag.apply_steps_from_state(inp.state, layout_rewrite=True)
         s_ref, bufs_ref = dag.apply_steps_from_state(inp.state, layout_rewrite=False)
         np_args = [np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs]
@@ -89,10 +86,10 @@ def test_layout_rewrite_correctness():
             np_args_ref[1] = np_args_ref[1].transpose(new_order)
             np_args_ref[1] = np_args_ref[1].reshape((red_dim, out_dim))
 
-        func = tvm.build(s, bufs, target=inp.task.target, target_host=inp.task.target_host)
-        func_ref = tvm.build(s_ref, bufs_ref, target="llvm")
+        func = tvm.build(s, bufs, target=target)
+        func_ref = tvm.build(s_ref, bufs_ref, target=target)
 
-        ctx = tvm.context(str(inp.task.target))
+        ctx = tvm.context(str(target))
         ctx_ref = tvm.cpu()
 
         args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py
index 5dae2a5..4369d20 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -167,45 +167,69 @@ def test_record_pragma_storage_align_rfactor():
     record_common(dag, s)
 
 
-def test_measure_local_builder_runner(enable_cpu_cache_flush=False):
+def test_recover_measure_input():
+    task = auto_scheduler.create_task(matmul_auto_scheduler_test, [512, 512, 512], "llvm")
+
+    inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state)
+    res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1)
+
+    with tempfile.NamedTemporaryFile() as fp:
+        auto_scheduler.save_records(fp.name, [inp], [res])
+
+        log_reader = auto_scheduler.RecordReader(fp.name)
+        inputs, results = log_reader.read_lines()
+        assert len(inputs) == 1
+
+        raw_inp = inputs[0]
+
+        correct_inp = auto_scheduler.measure_record.recover_measure_input(raw_inp)
+        assert str(correct_inp.task.compute_dag) == str(inp.task.compute_dag)
+
+        correct_inp = auto_scheduler.measure_record.recover_measure_input(
+            raw_inp, rebuild_state=True
+        )
+        assert str(correct_inp.state) == str(inp.state)
+
+
+def test_measure_local_builder_runner():
     if not tvm.testing.device_enabled("llvm"):
         return
 
-    dag, s0 = get_tiled_matmul()
-    tgt = tvm.target.Target("llvm")
-    task = auto_scheduler.SearchTask(dag, "test", tgt)
+    task = auto_scheduler.create_task(matmul_auto_scheduler_test, [512, 512, 512], "llvm")
 
-    minp = auto_scheduler.MeasureInput(task, s0)
-    local_builder = auto_scheduler.LocalBuilder()
-    local_runner = auto_scheduler.LocalRunner(
-        timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
-    )
+    for enable_cpu_cache_flush in [True, False]:
+        minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state)
+        local_builder = auto_scheduler.LocalBuilder()
+        local_runner = auto_scheduler.LocalRunner(
+            timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
+        )
 
-    bress = local_builder.build([minp])
-    assert bress[0].error_no == 0
-    mress = local_runner.run([minp], bress)
-    assert mress[0].error_no == 0
+        bress = local_builder.build([minp])
+        assert bress[0].error_no == 0
+        mress = local_runner.run([minp], bress)
+        assert mress[0].error_no == 0
 
 
-def test_measure_local_builder_rpc_runner(enable_cpu_cache_flush=False):
+def test_measure_local_builder_rpc_runner():
     if not tvm.testing.device_enabled("llvm"):
         return
 
-    dag, s0 = get_tiled_matmul()
-    tgt = tvm.target.Target("llvm")
-    task = auto_scheduler.SearchTask(dag, "test", tgt)
+    task = auto_scheduler.create_task(matmul_auto_scheduler_test, [512, 512, 512], "llvm")
 
-    minp = auto_scheduler.MeasureInput(task, s0)
-    local_builder = auto_scheduler.LocalBuilder()
-    measure_ctx = auto_scheduler.LocalRPCMeasureContext(
-        timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
-    )
-    rpc_runner = measure_ctx.runner
+    for enable_cpu_cache_flush in [True, False]:
+        minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state)
+        local_builder = auto_scheduler.LocalBuilder()
+        measure_ctx = auto_scheduler.LocalRPCMeasureContext(
+            timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
+        )
+        rpc_runner = measure_ctx.runner
+
+        bress = local_builder.build([minp])
+        assert bress[0].error_no == 0
+        mress = rpc_runner.run([minp], bress)
+        assert mress[0].error_no == 0
 
-    bress = local_builder.build([minp])
-    assert bress[0].error_no == 0
-    mress = rpc_runner.run([minp], bress)
-    assert mress[0].error_no == 0
+        del measure_ctx
 
 
 if __name__ == "__main__":
@@ -213,7 +237,6 @@ if __name__ == "__main__":
     test_record_compute_at_root_inline_cache_read_write()
     test_record_follow_split_follow_fused_split()
     test_record_pragma_storage_align_rfactor()
-    test_measure_local_builder_runner(enable_cpu_cache_flush=True)
-    test_measure_local_builder_runner(enable_cpu_cache_flush=False)
-    test_measure_local_builder_rpc_runner(enable_cpu_cache_flush=True)
-    test_measure_local_builder_rpc_runner(enable_cpu_cache_flush=False)
+    test_recover_measure_input()
+    test_measure_local_builder_runner()
+    test_measure_local_builder_rpc_runner()
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py
index 04b54b2..07cf4c8 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -42,10 +42,8 @@ def search_common(
 
     random.seed(seed)
     N = 128
-    workload_key = auto_scheduler.make_workload_key(workload, (N, N, N))
-    dag = auto_scheduler.ComputeDAG(workload_key)
     target = tvm.target.Target(target)
-    task = auto_scheduler.SearchTask(dag, workload_key, target)
+    task = auto_scheduler.create_task(workload, (N, N, N), target)
 
     with tempfile.NamedTemporaryFile() as fp:
         log_file = fp.name
@@ -70,10 +68,10 @@ def search_common(
         print("*" * 80)
         print(target)
         print("*" * 80)
-        inp, res = auto_scheduler.load_best(log_file, workload_key, target)
+        inp, res = auto_scheduler.load_best(log_file, task.workload_key, target)
 
         print("==== Python Code ====")
-        print(dag.print_python_code_from_state(inp.state))
+        print(task.compute_dag.print_python_code_from_state(inp.state))
 
         try:
             print("==== Lowered Stmt ====")
@@ -81,7 +79,7 @@ def search_common(
             mod = tvm.build(sch, args, target)
 
             ctx = tvm.context(str(target), 0)
-            dtype = dag.tensors[0].dtype
+            dtype = task.compute_dag.tensors[0].dtype
             a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
             b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
             c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx)
diff --git a/tests/python/unittest/test_auto_scheduler_sketch_generation.py b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
index 5a687da..47aa78a 100644
--- a/tests/python/unittest/test_auto_scheduler_sketch_generation.py
+++ b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
@@ -36,9 +36,7 @@ from test_auto_scheduler_common import (
 
 
 def generate_sketches(workload_func, args, target, print_for_debug=False):
-    workload_key = auto_scheduler.make_workload_key(workload_func, args)
-    dag = auto_scheduler.ComputeDAG(workload_key)
-    task = auto_scheduler.SearchTask(dag, workload_key, tvm.target.Target(target))
+    task = auto_scheduler.create_task(workload_func, args, tvm.target.Target(target))
     policy = auto_scheduler.SketchPolicy(task, verbose=0)
     return policy.generate_sketches(print_for_debug)