You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jc...@apache.org on 2020/12/28 01:54:26 UTC
[tvm] branch main updated: [AutoScheduler] Update layout rewrite
option setting for measuring (#7156)
This is an automated email from the ASF dual-hosted git repository.
jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2dec2dd [AutoScheduler] Update layout rewrite option setting for measuring (#7156)
2dec2dd is described below
commit 2dec2dd9a7836e142effa9af2f4ff7a3bb3b0d44
Author: Chenfan <ch...@alibaba-inc.com>
AuthorDate: Mon Dec 28 09:54:05 2020 +0800
[AutoScheduler] Update layout rewrite option setting for measuring (#7156)
* Add layout rewrite options for measure
* Update schedule for inserted transform stage
* Set layout rewrite when tuning for network
* Update the log version
---
include/tvm/auto_scheduler/measure_record.h | 2 +-
include/tvm/auto_scheduler/search_task.h | 6 ++++-
python/tvm/auto_scheduler/compute_dag.py | 36 +++++++++++++++++++++++++-
python/tvm/auto_scheduler/measure.py | 4 +--
python/tvm/auto_scheduler/relay_integration.py | 16 ++++++------
python/tvm/auto_scheduler/search_task.py | 27 ++++++++++++++-----
src/auto_scheduler/compute_dag.cc | 20 +++++++++++---
src/auto_scheduler/feature.cc | 6 +++--
src/auto_scheduler/measure_record.cc | 12 ++++++++-
src/auto_scheduler/search_task.cc | 10 ++++---
10 files changed, 110 insertions(+), 29 deletions(-)
diff --git a/include/tvm/auto_scheduler/measure_record.h b/include/tvm/auto_scheduler/measure_record.h
index 4d7952f..ec40611 100755
--- a/include/tvm/auto_scheduler/measure_record.h
+++ b/include/tvm/auto_scheduler/measure_record.h
@@ -34,7 +34,7 @@
namespace tvm {
namespace auto_scheduler {
-const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.4"; // NOLINT(*)
+const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.5"; // NOLINT(*)
/*! \brief Callback for logging the input and results of measurements to file */
class RecordToFileNode : public MeasureCallbackNode {
diff --git a/include/tvm/auto_scheduler/search_task.h b/include/tvm/auto_scheduler/search_task.h
index 60e721b..9e7d3aa 100755
--- a/include/tvm/auto_scheduler/search_task.h
+++ b/include/tvm/auto_scheduler/search_task.h
@@ -118,6 +118,8 @@ class SearchTaskNode : public Object {
Target target_host;
/*! \brief Hardware parameters used in this search task. */
HardwareParams hardware_params;
+ /*! \brief The layout rewrite option used for measuring programs. */
+ LayoutRewriteOption layout_rewrite_option;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("compute_dag", &compute_dag);
@@ -125,6 +127,7 @@ class SearchTaskNode : public Object {
v->Visit("target", &target);
v->Visit("target_host", &target_host);
v->Visit("hardware_params", &hardware_params);
+ v->Visit("layout_rewrite_option", &layout_rewrite_option);
}
static constexpr const char* _type_key = "auto_scheduler.SearchTask";
@@ -144,9 +147,10 @@ class SearchTask : public ObjectRef {
* \param target The target device of this search task.
* \param target_host The target host device of this search task.
* \param hardware_params Hardware parameters used in this search task.
+ * \param layout_rewrite_option The layout rewrite option used for measuring programs.
*/
SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host,
- Optional<HardwareParams> hardware_params);
+ Optional<HardwareParams> hardware_params, LayoutRewriteOption layout_rewrite_option);
TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode);
};
diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py
index d8a2422..a7f200a 100755
--- a/python/tvm/auto_scheduler/compute_dag.py
+++ b/python/tvm/auto_scheduler/compute_dag.py
@@ -32,7 +32,12 @@ from .workload_registry import workload_key_to_tensors
class LayoutRewriteOption:
- """Options for applying layout rewrite."""
+ """
+ Options for applying layout rewrite.
+
+ The NO_REWRITE and INSERT_TRANSFORM_STAGE are expected to be used when tuning a standalone op,
+ and the REWRITE_FOR_PRE_TRANSFORMED is expected to be used when tuning ops inside a network.
+ """
# Do not perform layout rewrite
NO_REWRITE = 0
@@ -44,6 +49,35 @@ class LayoutRewriteOption:
# so this option must be used along with `AutoSchedulerLayoutRewrite` pass in Relay.
REWRITE_FOR_PRE_TRANSFORMED = 2
+ @staticmethod
+ def get_target_default(target, in_relay_integration=False):
+ """Get the default layout rewrite option for the specified target.
+ Currently layout rewrite is only enabled for the cpu / mali backends
+
+ Parameters
+ ----------
+ target: tvm.target.Target
+ The compilation target.
+ in_relay_integration: bool
+ Whether this check is requested for relay integration.
+
+ Returns
+ -------
+ layout_rewrite_option: LayoutRewriteOption
+ The default layout rewrite option for the specified target.
+ """
+ layout_rewrite_option = LayoutRewriteOption.NO_REWRITE
+ if target.kind.name == "llvm" or (
+ "device" in target.attrs and target.attrs["device"] == "mali"
+ ):
+ layout_rewrite_option = (
+ LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED
+ if in_relay_integration
+ else LayoutRewriteOption.INSERT_TRANSFORM_STAGE
+ )
+
+ return layout_rewrite_option
+
@tvm._ffi.register_object("auto_scheduler.ComputeDAG")
class ComputeDAG(Object):
diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 24a7577..2f177a2 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -53,7 +53,6 @@ from .utils import (
make_traceback_info,
request_remote,
)
-from .compute_dag import LayoutRewriteOption
from .workload_registry import (
serialize_workload_registry_entry,
deserialize_workload_registry_entry,
@@ -211,6 +210,7 @@ def recover_measure_input(inp, rebuild_state=False):
target=task.target,
target_host=task.target_host,
hardware_params=task.hardware_params,
+ layout_rewrite_option=task.layout_rewrite_option,
)
if rebuild_state:
@@ -576,7 +576,7 @@ def _timed_func(inp_serialized, build_func, verbose):
try:
sch, args = task.compute_dag.apply_steps_from_state(
- inp.state, layout_rewrite=LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED
+ inp.state, layout_rewrite=task.layout_rewrite_option
)
# pylint: disable=broad-except
except Exception:
diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index 2b26fc4..3287f3d 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -33,7 +33,7 @@ from tvm.runtime import convert_to_object
from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor
from tvm.tir import expr as _expr
from . import _ffi_api
-from .compute_dag import ComputeDAG
+from .compute_dag import ComputeDAG, LayoutRewriteOption
from .dispatcher import DispatchContext
from .search_task import SearchTask
from .workload_registry import register_workload_tensors
@@ -126,6 +126,9 @@ def extract_tasks(
target=target,
target_host=target_host,
hardware_params=hardware_params,
+ # When auto scheduler is used in end to end network, try to apply layout rewrite
+ # to improve the overall performance
+ layout_rewrite_option=LayoutRewriteOption.get_target_default(target, True),
)
)
weights.append(use_count_dict[ccache_key] + 1)
@@ -259,13 +262,7 @@ def auto_schedule_topi(outs, has_complex_op):
key = register_workload_tensors(dag.hash_key(), io_tensors)
- # only enable layout rewrite for cpu / mali backend
target = tvm.target.Target.current()
- enable_layout_rewrite_targets = ["cpu", "mali"]
- enable_layout_rewrite = any(
- enable_layout_rewrite_target in target.keys
- for enable_layout_rewrite_target in enable_layout_rewrite_targets
- )
env = TracingEnvironment.current
if env is None:
@@ -284,7 +281,10 @@ def auto_schedule_topi(outs, has_complex_op):
schedule = te.create_schedule([x.op for x in outs])
elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE:
# in prepare_layout_rewrite mode
- if enable_layout_rewrite and has_layout_free:
+ if (
+ LayoutRewriteOption.get_target_default(target, True) != LayoutRewriteOption.NO_REWRITE
+ and has_layout_free
+ ):
dispatch_ctx = DispatchContext.current
state = dispatch_ctx.query(target, key, has_complex_op, dag)
if state is None:
diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py
index be83e06..bfa596a 100644
--- a/python/tvm/auto_scheduler/search_task.py
+++ b/python/tvm/auto_scheduler/search_task.py
@@ -178,6 +178,13 @@ class SearchTask(Object):
The target host device of this search task.
hardware_params : Optional[HardwareParams]
Hardware parameters used in this search task.
+ layout_rewrite_option : Optional[LayoutRewriteOption]
+ The layout rewrite option used for measuring programs. If None, the default value will be
+ set depending on the specified target.
+ Auto_scheduler will find a better schedule for the specified layout rewrite option.
+ The NO_REWRITE and INSERT_TRANSFORM_STAGE are expected to be used when tuning a standalone
+ op, and the REWRITE_FOR_PRE_TRANSFORMED is expected to be used when tuning ops inside a
+ network.
Examples
--------
@@ -204,6 +211,7 @@ class SearchTask(Object):
target=None,
target_host=None,
hardware_params=None,
+ layout_rewrite_option=None,
):
assert (
func is not None or workload_key is not None
@@ -221,7 +229,13 @@ class SearchTask(Object):
target_host = Target(target_host)
self.__init_handle_by_constructor__(
- _ffi_api.SearchTask, compute_dag, workload_key, target, target_host, hardware_params
+ _ffi_api.SearchTask,
+ compute_dag,
+ workload_key,
+ target,
+ target_host,
+ hardware_params,
+ layout_rewrite_option or LayoutRewriteOption.get_target_default(target),
)
def tune(self, tuning_options, search_policy=None):
@@ -250,6 +264,7 @@ class SearchTask(Object):
layout_rewrite_option : Optional[LayoutRewriteOption]
The layout rewrite option.
+
Returns
-------
A `te.Schedule` and a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`.
@@ -260,11 +275,9 @@ class SearchTask(Object):
"Cannot find any valid schedule for %s in file %s" % (self.workload_key, log_file)
)
- if layout_rewrite_option is None:
- layout_rewrite_option = LayoutRewriteOption.NO_REWRITE
- if self.target.kind.name == "llvm":
- layout_rewrite_option = LayoutRewriteOption.INSERT_TRANSFORM_STAGE
- sch, args = self.compute_dag.apply_steps_from_state(inp.state, layout_rewrite_option)
+ sch, args = self.compute_dag.apply_steps_from_state(
+ inp.state, layout_rewrite_option or self.layout_rewrite_option
+ )
return sch, args
def print_best(self, log_file, print_mode="schedule"):
@@ -305,6 +318,7 @@ class SearchTask(Object):
"target": self.target,
"target_host": self.target_host,
"hardware_params": self.hardware_params,
+ "layout_rewrite_option": self.layout_rewrite_option,
}
def __setstate__(self, state):
@@ -327,6 +341,7 @@ class SearchTask(Object):
state["target"],
state["target_host"],
state["hardware_params"],
+ state["layout_rewrite_option"],
)
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index 64114c8..b658782 100755
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -998,11 +998,20 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
transform_steps->Set(i, std::move(step));
}
}
+
+ // Add schedule for the new added transform stage
Array<Integer> to_fuse;
- for (size_t i = 0; i < new_shape.size() - 1; i++) {
- to_fuse.push_back(i);
+
+ if (new_shape.size() >= 5) {
+ to_fuse.push_back(0);
+ to_fuse.push_back(1);
+ to_fuse.push_back(2);
+ transform_steps->push_back(FuseStep(stage_id, to_fuse));
+ } else if (new_shape.size() >= 3) {
+ to_fuse.push_back(0);
+ to_fuse.push_back(1);
+ transform_steps->push_back(FuseStep(stage_id, to_fuse));
}
- transform_steps->push_back(FuseStep(stage_id, to_fuse));
transform_steps->push_back(AnnotationStep(stage_id, 0, IteratorAnnotation::kParallel));
}
@@ -1024,7 +1033,10 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
}
original_compute_op = op;
CHECK(!new_compute_op.defined());
- new_compute_op = te::ComputeOp(pop->name, pop->tag, pop->attrs, pop->axis, new_body);
+ auto new_attrs = pop->attrs;
+ new_attrs.Set("ori_placeholder_layout", tvm::String(origin_layout));
+ new_attrs.Set("new_placeholder_layout", tvm::String(new_layout));
+ new_compute_op = te::ComputeOp(pop->name, pop->tag, new_attrs, pop->axis, new_body);
}
}
}
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index 53287a0..47b9fb6 100755
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -1398,7 +1398,8 @@ void GetPerStoreFeaturesFromFile(const std::string& filename, int max_lines, int
// rebuild task
Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
task = SearchTask(ComputeDAG(tensors), workload_key, cur_inp->task->target,
- cur_inp->task->target_host, cur_inp->task->hardware_params);
+ cur_inp->task->target_host, cur_inp->task->hardware_params,
+ cur_inp->task->layout_rewrite_option);
task_id = task_cache.size();
// compute min cost for each task
@@ -1465,7 +1466,8 @@ void GetPerStoreFeaturesFromMeasurePairs(const Array<MeasureInput>& inputs,
// rebuild task for incomplete measure pairs read from file
Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target,
- inputs[i]->task->target_host, inputs[i]->task->hardware_params);
+ inputs[i]->task->target_host, inputs[i]->task->hardware_params,
+ inputs[i]->task->layout_rewrite_option);
}
task_id = task_cache.size();
diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc
index faf3fca..1120f43 100644
--- a/src/auto_scheduler/measure_record.cc
+++ b/src/auto_scheduler/measure_record.cc
@@ -165,12 +165,16 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> {
writer->WriteArrayItem(*data.hardware_params.get());
if (data.target_host.defined()) {
writer->WriteArrayItem(data.target_host->str());
+ } else {
+ writer->WriteArrayItem(std::string(""));
}
+ writer->WriteArrayItem(static_cast<int>(data.layout_rewrite_option));
writer->EndArray();
}
inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::SearchTaskNode* data) {
bool s;
std::string str_value;
+ int int_value;
auto hardware_params_node = ::tvm::make_object<::tvm::auto_scheduler::HardwareParamsNode>();
reader->BeginArray();
s = reader->NextArrayItem();
@@ -188,7 +192,13 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> {
data->hardware_params = ::tvm::auto_scheduler::HardwareParams(hardware_params_node);
if (s) {
reader->Read(&str_value);
- data->target_host = ::tvm::Target(str_value);
+ if (!str_value.empty()) {
+ data->target_host = ::tvm::Target(str_value);
+ }
+ s = reader->NextArrayItem();
+ ICHECK(s);
+ reader->Read(&int_value);
+ data->layout_rewrite_option = ::tvm::auto_scheduler::LayoutRewriteOption(int_value);
s = reader->NextArrayItem();
ICHECK(!s);
}
diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc
index 93f3460..0abee16 100755
--- a/src/auto_scheduler/search_task.cc
+++ b/src/auto_scheduler/search_task.cc
@@ -113,7 +113,8 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target
}
SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target target,
- Target target_host, Optional<HardwareParams> hardware_params) {
+ Target target_host, Optional<HardwareParams> hardware_params,
+ LayoutRewriteOption layout_rewrite_option) {
auto node = make_object<SearchTaskNode>();
node->compute_dag = std::move(compute_dag);
node->workload_key = std::move(workload_key);
@@ -125,6 +126,7 @@ SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target targe
node->hardware_params =
HardwareParamsNode::GetDefaultHardwareParams(node->target, node->target_host);
}
+ node->layout_rewrite_option = layout_rewrite_option;
data_ = std::move(node);
}
@@ -139,8 +141,10 @@ TVM_REGISTER_GLOBAL("auto_scheduler.HardwareParams")
TVM_REGISTER_GLOBAL("auto_scheduler.SearchTask")
.set_body_typed([](ComputeDAG compute_dag, String workload_key, Target target,
- Target target_host, Optional<HardwareParams> hardware_params) {
- return SearchTask(compute_dag, workload_key, target, target_host, hardware_params);
+ Target target_host, Optional<HardwareParams> hardware_params,
+ int layout_rewrite_option) {
+ return SearchTask(compute_dag, workload_key, target, target_host, hardware_params,
+ LayoutRewriteOption(layout_rewrite_option));
});
} // namespace auto_scheduler