You are viewing a plain-text version of this content. (The canonical link to the original message was not preserved in this copy.)
Posted to commits@tvm.apache.org by lm...@apache.org on 2020/12/04 13:03:57 UTC
[tvm] branch main updated: [AutoScheduler] Refactor task interface for tuning single operators (#7028)
This is an automated email from the ASF dual-hosted git repository.
lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 75afcd7 [AutoScheduler] Refactor task interface for tuning single operators (#7028)
75afcd7 is described below
commit 75afcd766e0e45fdb8bd3007b3114257f11a7ec4
Author: Lianmin Zheng <li...@gmail.com>
AuthorDate: Fri Dec 4 05:03:34 2020 -0800
[AutoScheduler] Refactor task interface for tuning single operators (#7028)
* [AutoScheduler] Refactor task interface
* update tutorials and tests
* update
* fix lint
* fix lint
* update
* fix test
---
include/tvm/auto_scheduler/compute_dag.h | 4 +-
python/tvm/auto_scheduler/__init__.py | 7 +-
python/tvm/auto_scheduler/compute_dag.py | 23 +-
python/tvm/auto_scheduler/measure.py | 18 +-
python/tvm/auto_scheduler/measure_record.py | 2 +-
python/tvm/auto_scheduler/relay_integration.py | 9 +-
python/tvm/auto_scheduler/search_task.py | 304 ++++++++++++++++++++-
src/auto_scheduler/utils.h | 2 +-
src/tir/ir/expr.cc | 25 +-
.../unittest/test_auto_scheduler_compute_dag.py | 6 +-
.../unittest/test_auto_scheduler_cost_model.py | 2 +-
.../test_auto_scheduler_evolutionary_search.py | 10 +-
.../python/unittest/test_auto_scheduler_feature.py | 15 +-
.../unittest/test_auto_scheduler_layout_rewrite.py | 22 +-
.../python/unittest/test_auto_scheduler_measure.py | 15 +-
.../unittest/test_auto_scheduler_search_policy.py | 8 +-
.../test_auto_scheduler_sketch_generation.py | 2 +-
.../unittest/test_auto_scheduler_task_scheduler.py | 12 +-
tutorials/auto_scheduler/ci_logs/matmul.json | 2 +-
tutorials/auto_scheduler/tune_conv2d_layer_cuda.py | 36 +--
tutorials/auto_scheduler/tune_matmul_x86.py | 49 ++--
tutorials/auto_scheduler/tune_network_cuda.py | 5 +-
tutorials/auto_scheduler/tune_network_x86.py | 3 +-
23 files changed, 456 insertions(+), 125 deletions(-)
diff --git a/include/tvm/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h
index da0d196..b9306c6 100755
--- a/include/tvm/auto_scheduler/compute_dag.h
+++ b/include/tvm/auto_scheduler/compute_dag.h
@@ -199,7 +199,7 @@ class ComputeDAGNode : public Object {
* This is an optimization to rewrite the layout of input tensors according to the schedule we get.
*/
enum class LayoutRewriteOption : int {
- /*! \brief Do not process layout rewrite. */
+ /*! \brief Do not perform layout rewrite. */
NoRewrite = 0,
/*! \brief Insert layout transformation stages for input placeholders in the compute DAG */
InsertTransformStage = 1,
@@ -207,7 +207,7 @@ enum class LayoutRewriteOption : int {
* \brief Do not insert layout transformation stages and assume the input placeholders
* are pre-transformed.
* \note The lowered function with this option does not accept the origial input shapes,
- * so this option must be used along with a layout conversion pass in Relay.
+ * so this option must be used along with `AutoSchedulerLayoutRewrite` pass in Relay.
*/
RewriteForPreTransformed = 2,
};
diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py
index bee2e7f..4926b88 100644
--- a/python/tvm/auto_scheduler/__init__.py
+++ b/python/tvm/auto_scheduler/__init__.py
@@ -31,8 +31,7 @@ from . import utils
from . import workload_registry
# Shortcut
-from .auto_schedule import TuningOptions, HardwareParams, create_task, auto_schedule
-from .compute_dag import ComputeDAG
+from .compute_dag import ComputeDAG, LayoutRewriteOption
from .cost_model import RandomModel, XGBModel
from .dispatcher import DispatchContext, ApplyHistoryBest
from .measure import (
@@ -43,14 +42,14 @@ from .measure import (
RPCRunner,
LocalRPCMeasureContext,
)
-from .measure_record import RecordToFile, RecordReader, load_best, load_records, save_records
+from .measure_record import RecordToFile, RecordReader, load_best_record, load_records, save_records
from .relay_integration import (
extract_tasks,
remove_index_check,
rewrite_compute_body,
is_auto_scheduler_enabled,
)
-from .search_task import SearchTask
+from .search_task import SearchTask, TuningOptions, HardwareParams, create_task, auto_schedule
from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
from .task_scheduler import TaskScheduler
from .workload_registry import register_workload, make_workload_key
diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py
index cba3600..a6f9954 100755
--- a/python/tvm/auto_scheduler/compute_dag.py
+++ b/python/tvm/auto_scheduler/compute_dag.py
@@ -31,6 +31,20 @@ from .utils import get_const_tuple
from .workload_registry import workload_key_to_tensors
+class LayoutRewriteOption:
+ """Options for applying layout rewrite."""
+
+ # Do not perform layout rewrite
+ NO_REWRITE = 0
+ # Insert layout transformation stages for input placeholders in the compute DAG
+ INSERT_TRANSFORM_STAGE = 1
+ # Do not insert layout transformation stages and assume the input placeholders
+ # are pre-transformed.
+ # Note: The lowered function with this option does not accept the origial input shapes,
+ # so this option must be used along with `AutoSchedulerLayoutRewrite` pass in Relay.
+ REWRITE_FOR_PRE_TRANSFORMED = 2
+
+
@tvm._ffi.register_object("auto_scheduler.ComputeDAG")
class ComputeDAG(Object):
"""
@@ -52,11 +66,6 @@ class ComputeDAG(Object):
Input/output tensors or workload key for a compute declaration.
"""
- # Layout Rewrite Options
- NoRewrite = 0
- InsertTransformStage = 1
- RewriteForPreTransformed = 2
-
def __init__(self, compute_or_sche):
if isinstance(compute_or_sche, str):
compute = workload_key_to_tensors(compute_or_sche)
@@ -92,7 +101,7 @@ class ComputeDAG(Object):
"""
return State(self.init_state, self)
- def apply_steps_from_state(self, state, layout_rewrite=NoRewrite):
+ def apply_steps_from_state(self, state, layout_rewrite=LayoutRewriteOption.NO_REWRITE):
"""
Apply the history transform steps from a State to get a TVM schedule.
@@ -101,7 +110,7 @@ class ComputeDAG(Object):
state : Union[State, StateObject]
The state from which we get transform steps.
- layout_rewrite: Bool
+ layout_rewrite: LayoutRewriteOption = NoRewrite
Rewrite the layout of placeholders specified by "layout_free_placeholders" attr
to make it most friendly for the generated schedule to read from.
diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index b282651..7e4f149 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -53,8 +53,7 @@ from .utils import (
make_traceback_info,
request_remote,
)
-from .compute_dag import ComputeDAG
-from .search_task import SearchTask
+from .compute_dag import LayoutRewriteOption
from .workload_registry import (
serialize_workload_registry_entry,
deserialize_workload_registry_entry,
@@ -178,13 +177,15 @@ def recover_measure_input(inp, rebuild_state=False):
new_input: MeasureInput
The fully recovered MeasureInput with all fields rebuilt.
"""
+ # pylint: disable=import-outside-toplevel
+ from .search_task import SearchTask # lazily import to avoid recursive dependency
+
task = inp.task
new_task = SearchTask(
- ComputeDAG(task.workload_key),
- task.workload_key,
- task.target,
- task.target_host,
- task.hardware_params,
+ workload_key=task.workload_key,
+ target=task.target,
+ target_host=task.target_host,
+ hardware_params=task.hardware_params,
)
if rebuild_state:
@@ -521,6 +522,7 @@ class LocalRPCMeasureContext:
# Close the tracker and server before exit
self.tracker.terminate()
self.server.terminate()
+ time.sleep(0.5)
class MeasureErrorNo(object):
@@ -549,7 +551,7 @@ def _timed_func(inp_serialized, build_func, verbose):
try:
sch, args = task.compute_dag.apply_steps_from_state(
- inp.state, layout_rewrite=ComputeDAG.RewriteForPreTransformed
+ inp.state, layout_rewrite=LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED
)
# pylint: disable=broad-except
except Exception:
diff --git a/python/tvm/auto_scheduler/measure_record.py b/python/tvm/auto_scheduler/measure_record.py
index 2569f39..d6fea5c 100644
--- a/python/tvm/auto_scheduler/measure_record.py
+++ b/python/tvm/auto_scheduler/measure_record.py
@@ -137,7 +137,7 @@ def save_records(filename, inputs, results):
_ffi_api.SaveRecords(filename, inputs, results)
-def load_best(filename, workload_key=None, target=None):
+def load_best_record(filename, workload_key=None, target=None):
"""Return the best measurement pair form a log file. This may return none results if
there is no legal measure pair with the specified workload_key/target found from the log file.
diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index 5a19791..4c493d1 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -120,7 +120,14 @@ def extract_tasks(
weights = []
for wkl_key, ccache_key in env.wkl_key_to_ccache_key.items():
dag = ComputeDAG(wkl_key)
- tasks.append(SearchTask(dag, wkl_key, target, target_host, hardware_params))
+ tasks.append(
+ SearchTask(
+ workload_key=wkl_key,
+ target=target,
+ target_host=target_host,
+ hardware_params=hardware_params,
+ )
+ )
weights.append(use_count_dict[ccache_key] + 1)
# clean the cached lowering results
diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py
index f2dadcc..31698d0 100644
--- a/python/tvm/auto_scheduler/search_task.py
+++ b/python/tvm/auto_scheduler/search_task.py
@@ -22,8 +22,139 @@ import json
import tvm._ffi
from tvm.runtime import Object
-from . import _ffi_api
+from tvm.driver.build_module import build
+from tvm.target import Target
+from .measure import LocalBuilder, LocalRunner
+from .measure_record import load_best_record
+from .workload_registry import make_workload_key
+from .compute_dag import ComputeDAG, LayoutRewriteOption
+from .cost_model import XGBModel
+from .search_policy import SketchPolicy
from .workload_registry import register_workload_tensors
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("auto_scheduler.HardwareParams")
+class HardwareParams(Object):
+ """The parameters of target hardware used to guide the search policy
+ TODO(jcf94): This is considered to be merged with the new Target specification:
+ https://discuss.tvm.apache.org/t/rfc-tvm-target-specification/6844
+ Parameters
+ ----------
+ num_cores : int
+ The number of device cores.
+ vector_unit_bytes : int
+ The width of vector units in bytes.
+ cache_line_bytes : int
+ The size of cache line in bytes.
+ max_shared_memory_per_block : int
+ The max shared memory per block in bytes.
+ max_registers_per_block : int
+ The max number of register per block.
+ max_threads_per_block : int
+ The max number of threads per block.
+ max_vthread_extent : int
+ The max vthread extent.
+ warp_size : int
+ The thread numbers of a warp.
+ """
+
+ def __init__(
+ self,
+ num_cores,
+ vector_unit_bytes,
+ cache_line_bytes,
+ max_shared_memory_per_block,
+ max_registers_per_block,
+ max_threads_per_block,
+ max_vthread_extent,
+ warp_size,
+ ):
+ self.__init_handle_by_constructor__(
+ _ffi_api.HardwareParams,
+ num_cores,
+ vector_unit_bytes,
+ cache_line_bytes,
+ max_shared_memory_per_block,
+ max_registers_per_block,
+ max_threads_per_block,
+ max_vthread_extent,
+ warp_size,
+ )
+
+
+@tvm._ffi.register_object("auto_scheduler.TuningOptions")
+class TuningOptions(Object):
+ """This controls the options of performance tuning.
+
+ Parameters
+ ----------
+ num_measure_trials: int = 0
+ The number of measurement trials.
+ The search policy measures `num_measure_trials` schedules in total and returns the best one
+ among them.
+ With `num_measure_trials` == 0, the policy will do the schedule search but won't involve
+ measurement. This can be used to get a runnable schedule quickly without auto-tuning.
+ early_stopping: Optional[int]
+ Stop the tuning early if getting no improvement after n measurements.
+ num_measures_per_round: int = 64
+ The number of schedules to be measured at each search round.
+ The whole schedule search process will try a total number of `num_measure_trials` in several
+ rounds.
+ verbose: int = 1
+ Verbosity level. 0 for silent, 1 to output information during schedule search.
+ builder: Union[ProgramBuilder, str] = 'local'
+ ProgramBuilder which builds the program.
+ runner: Union[ProgramRunner, str] = 'local'
+ ProgramRunner which runs the program and measures time costs.
+ measure_callbacks: Optional[List[MeasureCallback]]
+ Callback functions called after each measurement.
+ Candidates:
+ - auto_scheduler.RecordToFile
+ """
+
+ def __init__(
+ self,
+ num_measure_trials=0,
+ early_stopping=None,
+ num_measures_per_round=64,
+ verbose=1,
+ builder="local",
+ runner="local",
+ measure_callbacks=None,
+ ):
+ if isinstance(builder, str):
+ if builder == "local":
+ builder = LocalBuilder()
+ else:
+ raise ValueError("Invalid builder: " + builder)
+ elif not isinstance(builder, tvm.auto_scheduler.measure.ProgramBuilder):
+ raise ValueError(
+ "Invalid builder: "
+ + builder
+ + " . TuningOptions expects a ProgramBuilder or string."
+ )
+
+ if isinstance(runner, str):
+ if runner == "local":
+ runner = LocalRunner()
+ else:
+ raise ValueError("Invalid runner: " + runner)
+ elif not isinstance(runner, tvm.auto_scheduler.measure.ProgramRunner):
+ raise ValueError(
+ "Invalid runner: " + runner + " . TuningOptions expects a ProgramRunner or string."
+ )
+
+ self.__init_handle_by_constructor__(
+ _ffi_api.TuningOptions,
+ num_measure_trials,
+ early_stopping or -1,
+ num_measures_per_round,
+ verbose,
+ builder,
+ runner,
+ measure_callbacks,
+ )
@tvm._ffi.register_object("auto_scheduler.SearchTask")
@@ -32,7 +163,12 @@ class SearchTask(Object):
Parameters
----------
- dag : ComputeDAG
+ func : Union[Function, str]
+ The function that returns the compute declaration Tensors.
+ Can be the a function or the function name.
+ args : Union[Tuple[Any, ...], List[Any]]
+ The args of the function.
+ compute_dag : ComputeDAG
The ComputeDAG for the corresponding compute declaration.
workload_key : str
The workload key for the corresponding compute declaration.
@@ -42,18 +178,123 @@ class SearchTask(Object):
The target host device of this search task.
hardware_params : Optional[HardwareParams]
Hardware parameters used in this search task.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ # We support two ways to create a search task
+
+ # Way 1: create a task by a workload generation function.
+ # The `workload_func` is a function decorated by @auto_scheduler.register_workload
+ task = SearchTask(func=workload_func, args=args, target=target)
+
+ # Way 2: create a task by a workload_key.
+ # The `workload_key` is a string, which can be either a hash key or a json-serialized
+ # tuple(func, args).
+ task = SearchTask(workload_key=workload_key, target=target)
"""
- def __init__(self, dag, workload_key, target, target_host=None, hardware_params=None):
- self.dag = dag
+ def __init__(
+ self,
+ func=None,
+ args=None,
+ compute_dag=None,
+ workload_key=None,
+ target=None,
+ target_host=None,
+ hardware_params=None,
+ ):
+ assert (
+ func is not None or workload_key is not None
+ ), "Either a workload generation function or a workload key should be provided"
+
+ if func is not None:
+ workload_key = make_workload_key(func, args)
+ if compute_dag is None:
+ compute_dag = ComputeDAG(workload_key)
+
+ assert target is not None, "Must specify a target."
+ if isinstance(target, str):
+ target = Target(target)
+ if isinstance(target_host, str):
+ target_host = Target(target_host)
+
+ self.dag = compute_dag
self.workload_key = workload_key
self.target = target
self.target_host = target_host
self.hardware_params = hardware_params
self.__init_handle_by_constructor__(
- _ffi_api.SearchTask, dag, workload_key, target, target_host, hardware_params
+ _ffi_api.SearchTask, compute_dag, workload_key, target, target_host, hardware_params
)
+ def tune(self, tuning_options, search_policy=None):
+ """Run auto scheduling search for a task
+
+ Parameters
+ ----------
+ tuning_options : TuningOptions
+ Tuning and measurement options.
+ search_policy : Optional[SearchPolicy]
+ The search policy to be used for schedule search.
+ """
+ if search_policy is None:
+ cost_model = XGBModel()
+ search_policy = SketchPolicy(self, cost_model)
+
+ _ffi_api.AutoSchedule(search_policy, tuning_options)
+
+ def apply_best(self, log_file, layout_rewrite_option=None):
+ """Apply the history best from a log file and return the schedule.
+
+ Parameters
+ ----------
+ log_file : str
+ The name of the log file.
+ layout_rewrite_option : Optional[LayoutRewriteOption]
+ The layout rewrite option.
+
+ Returns
+ -------
+ A `te.Schedule` and the a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`.
+ """
+ inp, _ = load_best_record(log_file, self.workload_key)
+
+ if layout_rewrite_option is None:
+ layout_rewrite_option = LayoutRewriteOption.NO_REWRITE
+ if self.target.kind.name == "llvm":
+ layout_rewrite_option = LayoutRewriteOption.INSERT_TRANSFORM_STAGE
+ sch, args = self.compute_dag.apply_steps_from_state(inp.state, layout_rewrite_option)
+ return sch, args
+
+ def print_best(self, log_file, print_mode="schedule"):
+ """Print the best schedule as python schedule API code or CUDA source code.
+
+ Parameters
+ ----------
+ log_file : str
+ The name of the log file
+ print_mode: str
+ if "schedule", print the best schedule as python schedule API code.
+ if "cuda", print the best schedule as CUDA source code.
+
+ Returns
+ -------
+ code: str
+ The best schedule code in python API or CUDA source code
+ """
+ inp, _ = load_best_record(log_file, self.workload_key)
+
+ if print_mode == "schedule":
+ return self.compute_dag.print_python_code_from_state(inp.state)
+ if print_mode == "cuda":
+ assert self.target.kind.name == "cuda"
+ sch, args = self.compute_dag.apply_steps_from_state(inp.state)
+ func = build(sch, args, "cuda")
+ return func.imported_modules[0].get_source()
+ raise ValueError("Invalid print_mode: %s" % print_mode)
+
def __getstate__(self):
return {
"dag": self.dag,
@@ -90,3 +331,56 @@ class SearchTask(Object):
self.target_host,
self.hardware_params,
)
+
+
+def create_task(func, args, target, target_host=None, hardware_params=None):
+ """THIS API IS DEPRECATED.
+
+ Create a search task.
+
+ Parameters
+ ----------
+ func : Union[Function, str]
+ The function that returns the compute declaration Tensors.
+ Can be the a function or the function name.
+ args : Union[Tuple[Any, ...], List[Any]]
+ The args of the function.
+ target : Union[tvm.target.Target, str]
+ The target device of this search task.
+ target_host : Optional[Union[tvm.target.Target, str]]
+ The target host device of this search task.
+ hardware_params : Optional[HardwareParams]
+ Hardware parameters used in this search task.
+
+ Returns
+ -------
+ SearchTask: the created task
+ """
+ raise ValueError(
+ 'The API "auto_scheduler.create_task" is deprecated.'
+ "See https://github.com/apache/tvm/pull/7028 for the upgrade guide"
+ )
+
+
+def auto_schedule(task, search_policy=None, tuning_options=TuningOptions()):
+ """THIS API IS DEPRECATED.
+
+ Run auto scheduling search for a task.
+
+ Parameters
+ ----------
+ task : SearchTask
+ The SearchTask for the computation declaration.
+ search_policy : Optional[SearchPolicy]
+ The search policy to be used for schedule search.
+ tuning_options : Optional[TuningOptions]
+ Tuning and measurement options.
+
+ Returns
+ -------
+ A `te.Schedule` and the a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`.
+ """
+ raise ValueError(
+ 'The API "auto_scheduler.create_task" is deprecated.'
+ "See https://github.com/apache/tvm/pull/7028 for the upgrade guide."
+ )
diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h
index bc29a37..9fc5a1d 100755
--- a/src/auto_scheduler/utils.h
+++ b/src/auto_scheduler/utils.h
@@ -192,7 +192,7 @@ inline bool StrEndsWith(const String& a, const String& b) {
/*! \brief Get an int value from an Expr */
inline int64_t GetIntImm(const PrimExpr& expr) {
auto pint = expr.as<IntImmNode>();
- ICHECK(pint != nullptr);
+ ICHECK(pint != nullptr) << "Expect an IntImm but get " << expr;
return pint->value;
}
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 2d2a299..aa40099 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -33,18 +33,19 @@
namespace tvm {
namespace tir {
-#define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \
- Name::Name(PrimExpr a, PrimExpr b, Span span) { \
- using T = Name::ContainerType; \
- ICHECK(a.defined()) << "ValueError: a is undefined\n"; \
- ICHECK(b.defined()) << "ValueError: b is undefined\n"; \
- ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; \
- ObjectPtr<T> node = make_object<T>(); \
- node->dtype = a.dtype(); \
- node->a = std::move(a); \
- node->b = std::move(b); \
- node->span = std::move(span); \
- data_ = std::move(node); \
+#define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \
+ Name::Name(PrimExpr a, PrimExpr b, Span span) { \
+ using T = Name::ContainerType; \
+ ICHECK(a.defined()) << "ValueError: a is undefined\n"; \
+ ICHECK(b.defined()) << "ValueError: b is undefined\n"; \
+ ICHECK(a.dtype() == b.dtype()) \
+ << "TypeError: mismatched types. " << a.dtype() << " vs. " << b.dtype() << "\n"; \
+ ObjectPtr<T> node = make_object<T>(); \
+ node->dtype = a.dtype(); \
+ node->a = std::move(a); \
+ node->b = std::move(b); \
+ node->span = std::move(span); \
+ data_ = std::move(node); \
}
#define TVM_DEFINE_CMPOP_CONSTRUCTOR(Name) \
diff --git a/tests/python/unittest/test_auto_scheduler_compute_dag.py b/tests/python/unittest/test_auto_scheduler_compute_dag.py
index 1356154..859964f 100644
--- a/tests/python/unittest/test_auto_scheduler_compute_dag.py
+++ b/tests/python/unittest/test_auto_scheduler_compute_dag.py
@@ -121,9 +121,9 @@ def test_stage_order():
# Serialize and deserialize the search task.
task = auto_scheduler.SearchTask(
- dag,
- json.dumps(("test-key",)),
- tvm.target.Target("llvm"),
+ compute_dag=dag,
+ workload_key=json.dumps(("test-key",)),
+ target=tvm.target.Target("llvm"),
hardware_params=auto_scheduler.HardwareParams(100000, 16, 64, 0, 0, 0, 0, 0),
)
diff --git a/tests/python/unittest/test_auto_scheduler_cost_model.py b/tests/python/unittest/test_auto_scheduler_cost_model.py
index 5ed736a..36360da 100644
--- a/tests/python/unittest/test_auto_scheduler_cost_model.py
+++ b/tests/python/unittest/test_auto_scheduler_cost_model.py
@@ -30,7 +30,7 @@ from test_auto_scheduler_common import matmul_auto_scheduler_test
def get_sample_records(number):
"""Generate a list of random MeasureInput and MeasureResult pairs"""
N = 128
- task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), "llvm")
+ task = auto_scheduler.SearchTask(func=matmul_auto_scheduler_test, args=(N, N, N), target="llvm")
policy = auto_scheduler.SketchPolicy(task, verbose=0)
states = policy.sample_initial_population()[:number]
diff --git a/tests/python/unittest/test_auto_scheduler_evolutionary_search.py b/tests/python/unittest/test_auto_scheduler_evolutionary_search.py
index 70bea3a..e28219d 100644
--- a/tests/python/unittest/test_auto_scheduler_evolutionary_search.py
+++ b/tests/python/unittest/test_auto_scheduler_evolutionary_search.py
@@ -48,9 +48,9 @@ def test_mutate_tile_size():
scores.append(1 if self.is_good_state(state) else 0)
return scores
- workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (10, 10, 4))
- dag = auto_scheduler.ComputeDAG(workload_key)
- task = auto_scheduler.SearchTask(dag, workload_key, tvm.target.Target("llvm"))
+ task = auto_scheduler.SearchTask(
+ func=matmul_auto_scheduler_test, args=(10, 10, 4), target=tvm.target.Target("llvm")
+ )
policy = auto_scheduler.SketchPolicy(task, program_cost_model=MockCostModel(), verbose=0)
states = policy.sample_initial_population()[:50]
@@ -92,7 +92,9 @@ def test_mutate_parallel():
scores.append(1 if self.is_good_state(state) else 0)
return scores
- task = auto_scheduler.create_task(matmul_auto_scheduler_test, (1024, 1024, 1024), "llvm")
+ task = auto_scheduler.SearchTask(
+ func=matmul_auto_scheduler_test, args=(1024, 1024, 1024), target="llvm"
+ )
policy = auto_scheduler.SketchPolicy(task, program_cost_model=MockCostModel(), verbose=0)
found = False
diff --git a/tests/python/unittest/test_auto_scheduler_feature.py b/tests/python/unittest/test_auto_scheduler_feature.py
index 7412dbc..b52b538 100644
--- a/tests/python/unittest/test_auto_scheduler_feature.py
+++ b/tests/python/unittest/test_auto_scheduler_feature.py
@@ -45,7 +45,7 @@ def test_cpu_matmul():
s.unroll(C, k)
target = tvm.target.Target("llvm")
- task = auto_scheduler.SearchTask(dag, "test", target)
+ task = auto_scheduler.SearchTask(compute_dag=dag, workload_key="test", target=target)
names = auto_scheduler.feature.get_per_store_feature_names()
fea = auto_scheduler.feature.get_per_store_features_from_states([s], task)[0]
@@ -103,7 +103,7 @@ def test_cpu_fusion():
s.compute_at(1, 2, s.stages[2].iters[1])
target = tvm.target.Target("llvm")
- task = auto_scheduler.SearchTask(dag, "test", target)
+ task = auto_scheduler.SearchTask(compute_dag=dag, workload_key="test", target=target)
names = auto_scheduler.feature.get_per_store_feature_names()
fea = auto_scheduler.feature.get_per_store_features_from_states([s], task)[0]
@@ -147,18 +147,15 @@ def test_gpu_feature():
inputs, results = auto_scheduler.RecordReader(f.name).read_lines()
inp = inputs[0]
- dag = auto_scheduler.ComputeDAG(inp.task.workload_key)
task = auto_scheduler.SearchTask(
- dag,
- inp.task.workload_key,
- inp.task.target,
- None,
- auto_scheduler.HardwareParams(
+ workload_key=inp.task.workload_key,
+ target=inp.task.target,
+ hardware_params=auto_scheduler.HardwareParams(
100000, 16, 64, 1 << 30, 1 << 30, 1 << 30, 1 << 30, 1 << 30
),
)
- state = dag.infer_bound_from_state(inputs[0].state)
+ state = task.dag.infer_bound_from_state(inputs[0].state)
fea = auto_scheduler.feature.get_per_store_features_from_states([state], task)[0]
names = auto_scheduler.feature.get_per_store_feature_names()
diff --git a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py
index 9d9704d..6ca56bd 100644
--- a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py
+++ b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py
@@ -35,7 +35,7 @@ def test_apply_steps_with_layout_rewrite():
assert bufs[1].shape[0] == 512
assert bufs[1].shape[1] == 512
_, bufs = dag.apply_steps_from_state(
- s, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.RewriteForPreTransformed
+ s, layout_rewrite=auto_scheduler.LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED
)
assert bufs[1].shape[0] == 4
assert bufs[1].shape[1] == 8
@@ -43,7 +43,7 @@ def test_apply_steps_with_layout_rewrite():
assert bufs[1].shape[3] == 4
assert bufs[1].shape[4] == 512
_, bufs = dag.apply_steps_from_state(
- s, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.InsertTransformStage
+ s, layout_rewrite=auto_scheduler.LayoutRewriteOption.INSERT_TRANSFORM_STAGE
)
assert bufs[1].shape[0] == 512
assert bufs[1].shape[1] == 512
@@ -53,7 +53,7 @@ def test_apply_steps_with_layout_rewrite():
def test_correctness_layout_rewrite_rewrite_for_preTransformed():
N = 128
target = tvm.target.Target("llvm")
- task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), target)
+ task = auto_scheduler.SearchTask(func=matmul_auto_scheduler_test, args=(N, N, N), target=target)
dag = task.compute_dag
with tempfile.NamedTemporaryFile() as fp:
@@ -65,13 +65,13 @@ def test_correctness_layout_rewrite_rewrite_for_preTransformed():
tuning_options = auto_scheduler.TuningOptions(
num_measure_trials=2,
runner=measure_ctx.runner,
- verbose=1,
+ verbose=2,
measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)
- auto_scheduler.auto_schedule(task, search_policy, tuning_options)
- inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target)
+ task.tune(tuning_options, search_policy=search_policy)
+ inp, _ = auto_scheduler.load_best_record(log_file, task.workload_key, target)
s, bufs = dag.apply_steps_from_state(
- inp.state, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.RewriteForPreTransformed
+ inp.state, layout_rewrite=auto_scheduler.LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED
)
s_ref, bufs_ref = dag.apply_steps_from_state(inp.state)
np_args = [np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs]
@@ -123,7 +123,7 @@ def test_correctness_layout_rewrite_rewrite_for_preTransformed():
def test_correctness_layout_rewrite_insert_transform_stage():
N = 128
target = tvm.target.Target("llvm")
- task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), target)
+ task = auto_scheduler.SearchTask(func=matmul_auto_scheduler_test, args=(N, N, N), target=target)
dag = task.compute_dag
with tempfile.NamedTemporaryFile() as fp:
@@ -138,10 +138,10 @@ def test_correctness_layout_rewrite_insert_transform_stage():
verbose=1,
measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)
- auto_scheduler.auto_schedule(task, search_policy, tuning_options)
- inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target)
+ task.tune(tuning_options, search_policy=search_policy)
+ inp, _ = auto_scheduler.load_best_record(log_file, task.workload_key, target)
s, bufs = dag.apply_steps_from_state(
- inp.state, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.InsertTransformStage
+ inp.state, layout_rewrite=auto_scheduler.LayoutRewriteOption.INSERT_TRANSFORM_STAGE
)
s_ref, bufs_ref = dag.apply_steps_from_state(inp.state)
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py
index 80ce98d..b214d9c 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -29,7 +29,7 @@ from test_auto_scheduler_common import matmul_auto_scheduler_test, get_tiled_mat
def record_common(dag, s):
target = tvm.target.Target("llvm")
- task = auto_scheduler.SearchTask(dag, "test", target)
+ task = auto_scheduler.SearchTask(compute_dag=dag, workload_key="test", target=target)
inp = auto_scheduler.measure.MeasureInput(task, s)
res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1)
@@ -169,7 +169,9 @@ def test_record_pragma_storage_align_rfactor():
def test_recover_measure_input():
- task = auto_scheduler.create_task(matmul_auto_scheduler_test, [512, 512, 512], "llvm")
+ task = auto_scheduler.SearchTask(
+ func=matmul_auto_scheduler_test, args=(512, 512, 512), target="llvm"
+ )
inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state)
res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1)
@@ -194,7 +196,9 @@ def test_measure_local_builder_runner():
if not tvm.testing.device_enabled("llvm"):
return
- task = auto_scheduler.create_task(matmul_auto_scheduler_test, [512, 512, 512], "llvm")
+ task = auto_scheduler.SearchTask(
+ func=matmul_auto_scheduler_test, args=(512, 512, 512), target="llvm"
+ )
for enable_cpu_cache_flush in [True, False]:
minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state)
@@ -213,7 +217,9 @@ def test_measure_local_builder_rpc_runner():
if not tvm.testing.device_enabled("llvm"):
return
- task = auto_scheduler.create_task(matmul_auto_scheduler_test, [512, 512, 512], "llvm")
+ task = auto_scheduler.SearchTask(
+ func=matmul_auto_scheduler_test, args=(512, 512, 512), target="llvm"
+ )
for enable_cpu_cache_flush in [True, False]:
minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state)
@@ -251,5 +257,4 @@ if __name__ == "__main__":
test_record_pragma_storage_align_rfactor()
test_recover_measure_input()
test_measure_local_builder_runner()
- test_measure_local_builder_runner_spawn()
test_measure_local_builder_rpc_runner()
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py
index a4f3c4e..1bb7449 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -45,7 +45,7 @@ def search_common(
random.seed(seed)
N = 128
target = tvm.target.Target(target)
- task = auto_scheduler.create_task(workload, (N, N, N), target)
+ task = auto_scheduler.SearchTask(func=workload, args=(N, N, N), target=target)
with tempfile.NamedTemporaryFile() as fp:
log_file = fp.name
@@ -70,11 +70,11 @@ def search_common(
verbose=2,
measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)
- sch, args = auto_scheduler.auto_schedule(task, search_policy, tuning_options)
- inp, res = auto_scheduler.load_best(log_file, task.workload_key, target)
+ task.tune(tuning_options=tuning_options, search_policy=search_policy)
+ sch, args = task.apply_best(log_file)
print("==== Python Code ====")
- print(task.compute_dag.print_python_code_from_state(inp.state))
+ print(task.print_best(log_file))
try:
print("==== Lowered Stmt ====")
diff --git a/tests/python/unittest/test_auto_scheduler_sketch_generation.py b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
index 1c8b993..74d5729 100644
--- a/tests/python/unittest/test_auto_scheduler_sketch_generation.py
+++ b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
@@ -36,7 +36,7 @@ from test_auto_scheduler_common import (
def generate_sketches(workload_func, args, target, print_for_debug=False):
- task = auto_scheduler.create_task(workload_func, args, tvm.target.Target(target))
+ task = auto_scheduler.SearchTask(func=workload_func, args=args, target=target)
policy = auto_scheduler.SketchPolicy(task, verbose=0)
return policy.generate_sketches(print_for_debug)
diff --git a/tests/python/unittest/test_auto_scheduler_task_scheduler.py b/tests/python/unittest/test_auto_scheduler_task_scheduler.py
index 680a783..032933f 100644
--- a/tests/python/unittest/test_auto_scheduler_task_scheduler.py
+++ b/tests/python/unittest/test_auto_scheduler_task_scheduler.py
@@ -32,7 +32,11 @@ from test_auto_scheduler_common import matmul_auto_scheduler_test
def test_task_scheduler_round_robin():
tasks = []
for n in [2, 4, 8]:
- tasks.append(auto_scheduler.create_task(matmul_auto_scheduler_test, (n, n, n), "llvm"))
+ tasks.append(
+ auto_scheduler.SearchTask(
+ func=matmul_auto_scheduler_test, args=(n, n, n), target="llvm"
+ )
+ )
with tempfile.NamedTemporaryFile() as fp:
log_file = fp.name
@@ -90,7 +94,11 @@ def test_task_scheduler_round_robin_spawn():
def test_task_scheduler_gradient():
tasks = []
for n in [2, 4]:
- tasks.append(auto_scheduler.create_task(matmul_auto_scheduler_test, (n, n, n), "llvm"))
+ tasks.append(
+ auto_scheduler.SearchTask(
+ func=matmul_auto_scheduler_test, args=(n, n, n), target="llvm"
+ )
+ )
def objective_func(costs):
return costs[0]
diff --git a/tutorials/auto_scheduler/ci_logs/matmul.json b/tutorials/auto_scheduler/ci_logs/matmul.json
index 827cfc9..bc5d6f0 100644
--- a/tutorials/auto_scheduler/ci_logs/matmul.json
+++ b/tutorials/auto_scheduler/ci_logs/matmul.json
@@ -1,2 +1,2 @@
# Keep a valid schedule for demonstraction. This is used to prevent flasky errors in CI.
-{"i": [["[\"matmul_add\", 128, 128, 128, \"float32\"]", "llvm -keys=cpu"], [[], [["SP", 2, 0, 128, [4, 2, 4], 1], ["SP", 2, 4, 128, [1, 32, 2], 1], ["SP", 2, 8, 128, [2], 1], ["RE", 2, [0, 4, 1, 5, 8, 2, 6, 9, 3, 7]], ["FSP", 4, 0, 0, 1], ["FSP", 4, 2, 1, 1], ["RE", 4, [0, 2, 1, 3]], ["CA", 2, 4, 1], ["FU", 4, [0, 1]], ["AN", 4, 0, 3], ["PR", 2, 0, "auto_unroll_max_step$0"], ["AN", 2, 9, 2]]]], "r": [[5.80388e-05], 0, 0.299169, 1603402396], "v": "v0.2"}
+{"i": [["[\"matmul_add\", 1024, 1024, 1024, \"float32\"]", "llvm -keys=cpu -link-params=0", [24, 64, 64, 0, 0, 0, 0, 0]], [[], [["SP", 2, 0, 1024, [2, 4, 16], 1], ["SP", 2, 4, 1024, [16, 4, 16], 1], ["SP", 2, 8, 1024, [8], 1], ["RE", 2, [0, 4, 1, 5, 8, 2, 6, 9, 3, 7]], ["FU", 2, [0, 1, 2, 3]], ["AN", 2, 0, 3], ["AN", 4, 0, 3], ["PR", 2, 0, "auto_unroll_max_step$0"], ["AN", 2, 6, 2]]]], "r": [[0.028777], 0, 0.613435, 1607038574], "v": "v0.3"}
diff --git a/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py b/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
index 9aeea84..103ceb4 100644
--- a/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
+++ b/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
@@ -71,9 +71,12 @@ target = tvm.target.Target("cuda")
# Use the last layer in ResNet-50
N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
-task = auto_scheduler.create_task(conv2d_layer, (N, H, W, CO, CI, KH, KW, strides, padding), target)
+task = auto_scheduler.SearchTask(
+ func=conv2d_layer, args=(N, H, W, CO, CI, KH, KW, strides, padding), target=target
+)
# Inspect the computational graph
+print("Computational DAG:")
print(task.compute_dag)
######################################################################
@@ -109,11 +112,15 @@ tune_option = auto_scheduler.TuningOptions(
# ^^^^^^^^^^^^^^
# Now we get all inputs ready. Pretty simple, isn't it?
# We can kick off the search and let the auto-scheduler do its magic.
-# After some measurement trials, it will return the best schedule it found.
+# After some measurement trials, we can load the best schedule from the log
+# file and apply it.
-sch, args = auto_scheduler.auto_schedule(task, tuning_options=tune_option)
+# Run auto-tuning (search)
+task.tune(tune_option)
+# Apply the best schedule
+sch, args = task.apply_best(log_file)
-# Kill the process for measurement
+# Kill the measurement process
del measure_ctx
######################################################################
@@ -121,6 +128,7 @@ del measure_ctx
# The auto-scheduler correctly performs optimizations including multi-level tiling,
# cooperative fetching, unrolling and operator fusion.
+print("Lowered TIR:")
print(tvm.lower(sch, args, simple_mode=True))
######################################################################
@@ -157,26 +165,20 @@ print(
######################################################################
# Using the record file
# ^^^^^^^^^^^^^^^^^^^^^
-# During the search, all measuremnt records are dumpped into the record
+# During the search, all measurement records are dumped into the record
# file "conv2d.json". The measurement records can be used to re-apply search results,
# resume the search, and perform other analyses.
######################################################################
# Here is an example where we load the best schedule from a file,
-# print the equivalent python schedule API, and build the binary again.
+# print the equivalent python schedule API and CUDA source code.
+# They can be used for debugging and learning the behavior of the auto-scheduler.
-# Load the measuremnt record for the best schedule
-inp, res = auto_scheduler.load_best(log_file, task.workload_key)
-
-# Print equivalent python schedule API. This can be used for debugging and
-# learning the behavior of the auto-scheduler.
print("Equivalent python schedule:")
-print(task.compute_dag.print_python_code_from_state(inp.state))
+print(task.print_best(log_file, print_mode="schedule"))
-# Rebuild the binary. This shows how you can apply the best schedule from a
-# log file without reruning the search again.
-sch, args = task.compute_dag.apply_steps_from_state(inp.state)
-func = tvm.build(sch, args, target)
+print("CUDA source code:")
+print(task.print_best(log_file, print_mode="cuda"))
######################################################################
# A more complicated example is to resume the search.
@@ -195,7 +197,7 @@ tune_option = auto_scheduler.TuningOptions(
runner=measure_ctx.runner,
measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)
-sch, args = auto_scheduler.auto_schedule(task, search_policy, tuning_options=tune_option)
+task.tune(tune_option, search_policy=search_policy)
# Kill the measurement process
del measure_ctx
diff --git a/tutorials/auto_scheduler/tune_matmul_x86.py b/tutorials/auto_scheduler/tune_matmul_x86.py
index 6d75629..bdd14be 100644
--- a/tutorials/auto_scheduler/tune_matmul_x86.py
+++ b/tutorials/auto_scheduler/tune_matmul_x86.py
@@ -56,7 +56,12 @@ def matmul_add(N, L, M, dtype):
C = te.placeholder((N, M), name="C", dtype=dtype)
k = te.reduce_axis((0, L), name="k")
- matmul = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="matmul")
+ matmul = te.compute(
+ (N, M),
+ lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
+ name="matmul",
+ attrs={"layout_free_placeholders": [B]}, # enable automatic layout transform for tensor B
+ )
out = te.compute((N, M), lambda i, j: matmul[i, j] + C[i, j], name="out")
return [A, B, C, out]
@@ -65,16 +70,18 @@ def matmul_add(N, L, M, dtype):
######################################################################
# Create the search task
# ^^^^^^^^^^^^^^^^^^^^^^
-# We then create a search task with N=L=M=128 and dtype="float32"
+# We then create a search task with N=L=M=1024 and dtype="float32"
# If your machine supports avx instructions, you can
#
# - replace "llvm" below with "llvm -mcpu=core-avx2" to enable AVX2
# - replace "llvm" below with "llvm -mcpu=skylake-avx512" to enable AVX-512
target = tvm.target.Target("llvm")
-task = tvm.auto_scheduler.create_task(matmul_add, (128, 128, 128, "float32"), target)
+N = L = M = 1024
+task = tvm.auto_scheduler.SearchTask(func=matmul_add, args=(N, L, M, "float32"), target=target)
# Inspect the computational graph
+print("Computational DAG:")
print(task.compute_dag)
######################################################################
@@ -100,15 +107,20 @@ tune_option = auto_scheduler.TuningOptions(
# ^^^^^^^^^^^^^^
# Now we get all inputs ready. Pretty simple, isn't it?
# We can kick off the search and let the auto-scheduler do its magic.
-# After some measurement trials, it will return the best schedule it found.
+# After some measurement trials, we can load the best schedule from the log
+# file and apply it.
-sch, args = auto_scheduler.auto_schedule(task, tuning_options=tune_option)
+# Run auto-tuning (search)
+task.tune(tune_option)
+# Apply the best schedule
+sch, args = task.apply_best(log_file)
######################################################################
# We can lower the schedule to see the IR after auto-scheduling.
# The auto-scheduler correctly performs optimizations including multi-level tiling,
# parallelization, vectorization, unrolling and operator fusion.
+print("Lowered TIR:")
print(tvm.lower(sch, args, simple_mode=True))
######################################################################
@@ -116,10 +128,10 @@ print(tvm.lower(sch, args, simple_mode=True))
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# We build the binary and check its correctness and performance.
-func = tvm.build(sch, args)
-a_np = np.random.uniform(size=(128, 128)).astype(np.float32)
-b_np = np.random.uniform(size=(128, 128)).astype(np.float32)
-c_np = np.random.uniform(size=(128, 128)).astype(np.float32)
+func = tvm.build(sch, args, target)
+a_np = np.random.uniform(size=(N, L)).astype(np.float32)
+b_np = np.random.uniform(size=(L, M)).astype(np.float32)
+c_np = np.random.uniform(size=(N, M)).astype(np.float32)
out_np = a_np.dot(b_np) + c_np
ctx = tvm.cpu()
@@ -143,26 +155,17 @@ print(
######################################################################
# Using the record file
# ^^^^^^^^^^^^^^^^^^^^^
-# During the search, all measuremnt records are dumpped into the record
+# During the search, all measurement records are dumped into the record
# file "matmul.json". The measurement records can be used to re-apply search results,
# resume the search, and perform other analyses.
######################################################################
# Here is an example where we load the best schedule from a file,
-# print the equivalent python schedule API, and build the binary again.
-
-# Load the measuremnt record for the best schedule
-inp, res = auto_scheduler.load_best(log_file, task.workload_key)
+# and print the equivalent python schedule API. This can be used for
+# debugging and learning the behavior of the auto-scheduler.
-# Print equivalent python schedule API. This can be used for debugging and
-# learning the behavior of the auto-scheduler.
print("Equivalent python schedule:")
-print(task.compute_dag.print_python_code_from_state(inp.state))
-
-# Rebuild the binary. This shows how you can apply the best schedule from a
-# log file without reruning the search again.
-sch, args = task.compute_dag.apply_steps_from_state(inp.state)
-func = tvm.build(sch, args)
+print(task.print_best(log_file))
######################################################################
# A more complicated example is to resume the search.
@@ -182,7 +185,7 @@ def resume_search(task, log_file_name):
tune_option = auto_scheduler.TuningOptions(
num_measure_trials=5, measure_callbacks=[auto_scheduler.RecordToFile(log_file_name)]
)
- sch, args = auto_scheduler.auto_schedule(task, search_policy, tuning_options=tune_option)
+ task.tune(tune_option, search_policy=search_policy)
# resume_search(task, log_file)
diff --git a/tutorials/auto_scheduler/tune_network_cuda.py b/tutorials/auto_scheduler/tune_network_cuda.py
index 90f531f..03be05a 100644
--- a/tutorials/auto_scheduler/tune_network_cuda.py
+++ b/tutorials/auto_scheduler/tune_network_cuda.py
@@ -299,9 +299,10 @@ print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), n
# 1. During the tuning, the auto-scheduler needs to compile many programs and
# extract feature from them. This part is CPU-intensive,
# so a high-performance CPU with many cores is recommended for faster search.
-# 2. If you have multiple GPUs, you can use all of them for measurements to
+# 2. If you have multiple target GPUs, you can use all of them for measurements to
# parallelize the measurements. Check this :ref:`section <tutorials-autotvm-rpc-tracker>`
# to learn how to use the RPC Tracker and RPC Server.
# To use the RPC Tracker in auto-scheduler, replace the runner in :code:`TuningOptions`
# with :any:`auto_scheduler.RPCRunner`.
-#
+# 3. You can use :code:`python3 -m tvm.auto_scheduler.measure_record --mode distill --i log.json`
+#    to distill a large log file and keep only the best record for each workload.
diff --git a/tutorials/auto_scheduler/tune_network_x86.py b/tutorials/auto_scheduler/tune_network_x86.py
index 8dd9230..aba75b2 100644
--- a/tutorials/auto_scheduler/tune_network_x86.py
+++ b/tutorials/auto_scheduler/tune_network_x86.py
@@ -303,4 +303,5 @@ print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), n
# to learn how to use the RPC Tracker and RPC Server.
# To use the RPC Tracker in auto-scheduler, replace the runner in :code:`TuningOptions`
# with :any:`auto_scheduler.RPCRunner`.
-#
+# 3. You can use :code:`python3 -m tvm.auto_scheduler.measure_record --mode distill --i log.json`
+#    to distill a large log file and keep only the best record for each workload.