Posted to commits@tvm.apache.org by jc...@apache.org on 2020/12/28 01:54:26 UTC

[tvm] branch main updated: [AutoScheduler] Update layout rewrite option setting for measuring (#7156)

This is an automated email from the ASF dual-hosted git repository.

jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2dec2dd  [AutoScheduler] Update layout rewrite option setting for measuring (#7156)
2dec2dd is described below

commit 2dec2dd9a7836e142effa9af2f4ff7a3bb3b0d44
Author: Chenfan <ch...@alibaba-inc.com>
AuthorDate: Mon Dec 28 09:54:05 2020 +0800

    [AutoScheduler] Update layout rewrite option setting for measuring (#7156)
    
    * Add layout rewrite options for measure
    
    * Update schedule for inserted transform stage
    
    * Set layout rewrite when tuning for network
    
    * Update the log version
---
 include/tvm/auto_scheduler/measure_record.h    |  2 +-
 include/tvm/auto_scheduler/search_task.h       |  6 ++++-
 python/tvm/auto_scheduler/compute_dag.py       | 36 +++++++++++++++++++++++++-
 python/tvm/auto_scheduler/measure.py           |  4 +--
 python/tvm/auto_scheduler/relay_integration.py | 16 ++++++------
 python/tvm/auto_scheduler/search_task.py       | 27 ++++++++++++++-----
 src/auto_scheduler/compute_dag.cc              | 20 +++++++++++---
 src/auto_scheduler/feature.cc                  |  6 +++--
 src/auto_scheduler/measure_record.cc           | 12 ++++++++-
 src/auto_scheduler/search_task.cc              | 10 ++++---
 10 files changed, 110 insertions(+), 29 deletions(-)

diff --git a/include/tvm/auto_scheduler/measure_record.h b/include/tvm/auto_scheduler/measure_record.h
index 4d7952f..ec40611 100755
--- a/include/tvm/auto_scheduler/measure_record.h
+++ b/include/tvm/auto_scheduler/measure_record.h
@@ -34,7 +34,7 @@
 namespace tvm {
 namespace auto_scheduler {
 
-const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.4";  // NOLINT(*)
+const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.5";  // NOLINT(*)
 
 /*! \brief Callback for logging the input and results of measurements to file */
 class RecordToFileNode : public MeasureCallbackNode {
diff --git a/include/tvm/auto_scheduler/search_task.h b/include/tvm/auto_scheduler/search_task.h
index 60e721b..9e7d3aa 100755
--- a/include/tvm/auto_scheduler/search_task.h
+++ b/include/tvm/auto_scheduler/search_task.h
@@ -118,6 +118,8 @@ class SearchTaskNode : public Object {
   Target target_host;
   /*! \brief Hardware parameters used in this search task. */
   HardwareParams hardware_params;
+  /*! \brief The layout rewrite option used for measuring programs. */
+  LayoutRewriteOption layout_rewrite_option;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("compute_dag", &compute_dag);
@@ -125,6 +127,7 @@ class SearchTaskNode : public Object {
     v->Visit("target", &target);
     v->Visit("target_host", &target_host);
     v->Visit("hardware_params", &hardware_params);
+    v->Visit("layout_rewrite_option", &layout_rewrite_option);
   }
 
   static constexpr const char* _type_key = "auto_scheduler.SearchTask";
@@ -144,9 +147,10 @@ class SearchTask : public ObjectRef {
    * \param target The target device of this search task.
    * \param target_host The target host device of this search task.
    * \param hardware_params Hardware parameters used in this search task.
+   * \param layout_rewrite_option The layout rewrite option used for measuring programs.
    */
   SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host,
-             Optional<HardwareParams> hardware_params);
+             Optional<HardwareParams> hardware_params, LayoutRewriteOption layout_rewrite_option);
 
   TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode);
 };
diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py
index d8a2422..a7f200a 100755
--- a/python/tvm/auto_scheduler/compute_dag.py
+++ b/python/tvm/auto_scheduler/compute_dag.py
@@ -32,7 +32,12 @@ from .workload_registry import workload_key_to_tensors
 
 
 class LayoutRewriteOption:
-    """Options for applying layout rewrite."""
+    """
+    Options for applying layout rewrite.
+
+    NO_REWRITE and INSERT_TRANSFORM_STAGE are expected to be used when tuning a standalone op,
+    and REWRITE_FOR_PRE_TRANSFORMED is expected to be used when tuning ops inside a network.
+    """
 
     # Do not perform layout rewrite
     NO_REWRITE = 0
@@ -44,6 +49,35 @@ class LayoutRewriteOption:
     # so this option must be used along with `AutoSchedulerLayoutRewrite` pass in Relay.
     REWRITE_FOR_PRE_TRANSFORMED = 2
 
+    @staticmethod
+    def get_target_default(target, in_relay_integration=False):
+        """Get the default layout rewrite option for the specified target.
+        Layout rewrite is currently only enabled for the cpu / mali backends.
+
+        Parameters
+        ----------
+        target: tvm.target.Target
+            The compilation target.
+        in_relay_integration: bool
+            Whether this check is for Relay integration.
+
+        Returns
+        -------
+        layout_rewrite_option: LayoutRewriteOption
+            The default layout rewrite option for the specified target.
+        """
+        layout_rewrite_option = LayoutRewriteOption.NO_REWRITE
+        if target.kind.name == "llvm" or (
+            "device" in target.attrs and target.attrs["device"] == "mali"
+        ):
+            layout_rewrite_option = (
+                LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED
+                if in_relay_integration
+                else LayoutRewriteOption.INSERT_TRANSFORM_STAGE
+            )
+
+        return layout_rewrite_option
+
 
 @tvm._ffi.register_object("auto_scheduler.ComputeDAG")
 class ComputeDAG(Object):
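For illustration, a minimal sketch of how the new helper resolves the per-target
default (standalone usage; the targets below are just examples, not part of this patch):

    from tvm.target import Target
    from tvm.auto_scheduler.compute_dag import LayoutRewriteOption

    # Standalone tuning on an x86 CPU defaults to inserting a transform stage.
    assert (LayoutRewriteOption.get_target_default(Target("llvm"))
            == LayoutRewriteOption.INSERT_TRANSFORM_STAGE)
    # Relay integration on the same target rewrites for pre-transformed layouts.
    assert (LayoutRewriteOption.get_target_default(Target("llvm"), in_relay_integration=True)
            == LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED)
    # Targets outside cpu / mali keep NO_REWRITE.
    assert (LayoutRewriteOption.get_target_default(Target("cuda"))
            == LayoutRewriteOption.NO_REWRITE)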
diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 24a7577..2f177a2 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -53,7 +53,6 @@ from .utils import (
     make_traceback_info,
     request_remote,
 )
-from .compute_dag import LayoutRewriteOption
 from .workload_registry import (
     serialize_workload_registry_entry,
     deserialize_workload_registry_entry,
@@ -211,6 +210,7 @@ def recover_measure_input(inp, rebuild_state=False):
         target=task.target,
         target_host=task.target_host,
         hardware_params=task.hardware_params,
+        layout_rewrite_option=task.layout_rewrite_option,
     )
 
     if rebuild_state:
@@ -576,7 +576,7 @@ def _timed_func(inp_serialized, build_func, verbose):
 
     try:
         sch, args = task.compute_dag.apply_steps_from_state(
-            inp.state, layout_rewrite=LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED
+            inp.state, layout_rewrite=task.layout_rewrite_option
         )
     # pylint: disable=broad-except
     except Exception:
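_timed_func now applies whatever option is stored on the task instead of
unconditionally using REWRITE_FOR_PRE_TRANSFORMED, and recover_measure_input
round-trips the option when rebuilding tasks. A hedged sketch of replaying a
logged measurement ("tune.json" is a placeholder path):

    from tvm import auto_scheduler
    from tvm.auto_scheduler.measure import recover_measure_input

    inputs, _ = auto_scheduler.RecordReader("tune.json").read_lines()
    inp = recover_measure_input(inputs[0], rebuild_state=True)
    # The rebuilt task keeps the layout rewrite option from the log, so a
    # replayed build matches the layouts used in the original measurement.
    print(inp.task.layout_rewrite_option)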
diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index 2b26fc4..3287f3d 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -33,7 +33,7 @@ from tvm.runtime import convert_to_object
 from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor
 from tvm.tir import expr as _expr
 from . import _ffi_api
-from .compute_dag import ComputeDAG
+from .compute_dag import ComputeDAG, LayoutRewriteOption
 from .dispatcher import DispatchContext
 from .search_task import SearchTask
 from .workload_registry import register_workload_tensors
@@ -126,6 +126,9 @@ def extract_tasks(
                 target=target,
                 target_host=target_host,
                 hardware_params=hardware_params,
+                # When the auto-scheduler is used on an end-to-end network, try to apply
+                # layout rewrite to improve the overall performance
+                layout_rewrite_option=LayoutRewriteOption.get_target_default(target, True),
             )
         )
         weights.append(use_count_dict[ccache_key] + 1)
@@ -259,13 +262,7 @@ def auto_schedule_topi(outs, has_complex_op):
 
     key = register_workload_tensors(dag.hash_key(), io_tensors)
 
-    # only enable layout rewrite for cpu / mali backend
     target = tvm.target.Target.current()
-    enable_layout_rewrite_targets = ["cpu", "mali"]
-    enable_layout_rewrite = any(
-        enable_layout_rewrite_target in target.keys
-        for enable_layout_rewrite_target in enable_layout_rewrite_targets
-    )
 
     env = TracingEnvironment.current
     if env is None:
@@ -284,7 +281,10 @@ def auto_schedule_topi(outs, has_complex_op):
         schedule = te.create_schedule([x.op for x in outs])
     elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE:
         # in prepare_layout_rewrite mode
-        if enable_layout_rewrite and has_layout_free:
+        if (
+            LayoutRewriteOption.get_target_default(target, True) != LayoutRewriteOption.NO_REWRITE
+            and has_layout_free
+        ):
             dispatch_ctx = DispatchContext.current
             state = dispatch_ctx.query(target, key, has_complex_op, dag)
             if state is None:
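With this change extract_tasks stamps each task with the Relay-integration
default up front, rather than measure-time code guessing it. A hedged sketch
(the ResNet workload below is only a placeholder model):

    from tvm import auto_scheduler
    from tvm.relay import testing

    mod, params = testing.resnet.get_workload(num_layers=18)  # placeholder model
    tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target="llvm")
    # On llvm / mali targets each task now carries REWRITE_FOR_PRE_TRANSFORMED,
    # so measured programs see the layouts Relay will feed at compile time.
    for task in tasks:
        print(task.workload_key, task.layout_rewrite_option)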
diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py
index be83e06..bfa596a 100644
--- a/python/tvm/auto_scheduler/search_task.py
+++ b/python/tvm/auto_scheduler/search_task.py
@@ -178,6 +178,13 @@ class SearchTask(Object):
         The target host device of this search task.
     hardware_params : Optional[HardwareParams]
         Hardware parameters used in this search task.
+    layout_rewrite_option : Optional[LayoutRewriteOption]
+        The layout rewrite option used for measuring programs. If None, the default value
+        is set depending on the specified target.
+        The auto-scheduler will find a better schedule for the specified layout rewrite option.
+        NO_REWRITE and INSERT_TRANSFORM_STAGE are expected to be used when tuning a
+        standalone op, and REWRITE_FOR_PRE_TRANSFORMED is expected to be used when tuning
+        ops inside a network.
 
     Examples
     --------
@@ -204,6 +211,7 @@ class SearchTask(Object):
         target=None,
         target_host=None,
         hardware_params=None,
+        layout_rewrite_option=None,
     ):
         assert (
             func is not None or workload_key is not None
@@ -221,7 +229,13 @@ class SearchTask(Object):
             target_host = Target(target_host)
 
         self.__init_handle_by_constructor__(
-            _ffi_api.SearchTask, compute_dag, workload_key, target, target_host, hardware_params
+            _ffi_api.SearchTask,
+            compute_dag,
+            workload_key,
+            target,
+            target_host,
+            hardware_params,
+            layout_rewrite_option or LayoutRewriteOption.get_target_default(target),
         )
 
     def tune(self, tuning_options, search_policy=None):
@@ -250,6 +264,7 @@ class SearchTask(Object):
         layout_rewrite_option : Optional[LayoutRewriteOption]
            The layout rewrite option.
 
+
         Returns
         -------
             A `te.Schedule` and the a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`.
@@ -260,11 +275,9 @@ class SearchTask(Object):
                 "Cannot find any valid schedule for %s in file %s" % (self.workload_key, log_file)
             )
 
-        if layout_rewrite_option is None:
-            layout_rewrite_option = LayoutRewriteOption.NO_REWRITE
-            if self.target.kind.name == "llvm":
-                layout_rewrite_option = LayoutRewriteOption.INSERT_TRANSFORM_STAGE
-        sch, args = self.compute_dag.apply_steps_from_state(inp.state, layout_rewrite_option)
+        sch, args = self.compute_dag.apply_steps_from_state(
+            inp.state, layout_rewrite_option or self.layout_rewrite_option
+        )
         return sch, args
 
     def print_best(self, log_file, print_mode="schedule"):
@@ -305,6 +318,7 @@ class SearchTask(Object):
             "target": self.target,
             "target_host": self.target_host,
             "hardware_params": self.hardware_params,
+            "layout_rewrite_option": self.layout_rewrite_option,
         }
 
     def __setstate__(self, state):
@@ -327,6 +341,7 @@ class SearchTask(Object):
             state["target"],
             state["target_host"],
             state["hardware_params"],
+            state["layout_rewrite_option"],
         )
 
 
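The Python constructor now fills in the target-dependent default when
layout_rewrite_option is None, so apply_best no longer needs its llvm special
case. A minimal sketch of overriding the default (matmul_func is a hypothetical
function registered with @auto_scheduler.register_workload):

    from tvm import auto_scheduler
    from tvm.auto_scheduler.compute_dag import LayoutRewriteOption

    task = auto_scheduler.SearchTask(
        func=matmul_func,  # hypothetical registered workload
        args=(1024, 1024, 1024),
        target="llvm",
        # If omitted, get_target_default would pick INSERT_TRANSFORM_STAGE on llvm.
        layout_rewrite_option=LayoutRewriteOption.NO_REWRITE,
    )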
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index 64114c8..b658782 100755
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -998,11 +998,20 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
             transform_steps->Set(i, std::move(step));
           }
         }
+
+        // Add schedule for the newly added transform stage
         Array<Integer> to_fuse;
-        for (size_t i = 0; i < new_shape.size() - 1; i++) {
-          to_fuse.push_back(i);
+
+        if (new_shape.size() >= 5) {
+          to_fuse.push_back(0);
+          to_fuse.push_back(1);
+          to_fuse.push_back(2);
+          transform_steps->push_back(FuseStep(stage_id, to_fuse));
+        } else if (new_shape.size() >= 3) {
+          to_fuse.push_back(0);
+          to_fuse.push_back(1);
+          transform_steps->push_back(FuseStep(stage_id, to_fuse));
         }
-        transform_steps->push_back(FuseStep(stage_id, to_fuse));
         transform_steps->push_back(AnnotationStep(stage_id, 0, IteratorAnnotation::kParallel));
       }
 
@@ -1024,7 +1033,10 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
             }
             original_compute_op = op;
             CHECK(!new_compute_op.defined());
-            new_compute_op = te::ComputeOp(pop->name, pop->tag, pop->attrs, pop->axis, new_body);
+            auto new_attrs = pop->attrs;
+            new_attrs.Set("ori_placeholder_layout", tvm::String(origin_layout));
+            new_attrs.Set("new_placeholder_layout", tvm::String(new_layout));
+            new_compute_op = te::ComputeOp(pop->name, pop->tag, new_attrs, pop->axis, new_body);
           }
         }
       }
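The fuse heuristic for the inserted transform stage, restated as a short Python
sketch for readability (the C++ above is the actual implementation):

    def axes_to_fuse(rank):
        # Which outermost axes of the transform stage are fused before the
        # parallel annotation, per the rule added above.
        if rank >= 5:
            return [0, 1, 2]  # fuse the three outermost axes
        if rank >= 3:
            return [0, 1]     # fuse the two outermost axes
        return []             # low-rank shapes: just parallelize axis 0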
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index 53287a0..47b9fb6 100755
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -1398,7 +1398,8 @@ void GetPerStoreFeaturesFromFile(const std::string& filename, int max_lines, int
       // rebuild task
       Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
       task = SearchTask(ComputeDAG(tensors), workload_key, cur_inp->task->target,
-                        cur_inp->task->target_host, cur_inp->task->hardware_params);
+                        cur_inp->task->target_host, cur_inp->task->hardware_params,
+                        cur_inp->task->layout_rewrite_option);
       task_id = task_cache.size();
 
       // compute min cost for each task
@@ -1465,7 +1466,8 @@ void GetPerStoreFeaturesFromMeasurePairs(const Array<MeasureInput>& inputs,
         // rebuild task for incomplete measure pairs read from file
         Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
         task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target,
-                          inputs[i]->task->target_host, inputs[i]->task->hardware_params);
+                          inputs[i]->task->target_host, inputs[i]->task->hardware_params,
+                          inputs[i]->task->layout_rewrite_option);
       }
       task_id = task_cache.size();
 
diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc
index faf3fca..1120f43 100644
--- a/src/auto_scheduler/measure_record.cc
+++ b/src/auto_scheduler/measure_record.cc
@@ -165,12 +165,16 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> {
     writer->WriteArrayItem(*data.hardware_params.get());
     if (data.target_host.defined()) {
       writer->WriteArrayItem(data.target_host->str());
+    } else {
+      writer->WriteArrayItem(std::string(""));
     }
+    writer->WriteArrayItem(static_cast<int>(data.layout_rewrite_option));
     writer->EndArray();
   }
   inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::SearchTaskNode* data) {
     bool s;
     std::string str_value;
+    int int_value;
     auto hardware_params_node = ::tvm::make_object<::tvm::auto_scheduler::HardwareParamsNode>();
     reader->BeginArray();
     s = reader->NextArrayItem();
@@ -188,7 +192,13 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> {
       data->hardware_params = ::tvm::auto_scheduler::HardwareParams(hardware_params_node);
       if (s) {
         reader->Read(&str_value);
-        data->target_host = ::tvm::Target(str_value);
+        if (!str_value.empty()) {
+          data->target_host = ::tvm::Target(str_value);
+        }
+        s = reader->NextArrayItem();
+        ICHECK(s);
+        reader->Read(&int_value);
+        data->layout_rewrite_option = ::tvm::auto_scheduler::LayoutRewriteOption(int_value);
         s = reader->NextArrayItem();
         ICHECK(!s);
       }
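Because the serialized task now appends target_host (an empty string when it is
undefined) and the layout rewrite option, the log version moves from v0.4 to
v0.5; a v0.4 record that included target_host would trip the reader's ICHECK on
the missing extra item, hence the version bump. A small sketch of inspecting the
new records ("tune_v05.json" is a placeholder path):

    from tvm import auto_scheduler

    # Records written under v0.5 carry the option on every serialized task;
    # older v0.4 logs must be regenerated.
    inputs, _ = auto_scheduler.RecordReader("tune_v05.json").read_lines()
    print(inputs[0].task.layout_rewrite_option)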
diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc
index 93f3460..0abee16 100755
--- a/src/auto_scheduler/search_task.cc
+++ b/src/auto_scheduler/search_task.cc
@@ -113,7 +113,8 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target
 }
 
 SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target target,
-                       Target target_host, Optional<HardwareParams> hardware_params) {
+                       Target target_host, Optional<HardwareParams> hardware_params,
+                       LayoutRewriteOption layout_rewrite_option) {
   auto node = make_object<SearchTaskNode>();
   node->compute_dag = std::move(compute_dag);
   node->workload_key = std::move(workload_key);
@@ -125,6 +126,7 @@ SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target targe
     node->hardware_params =
         HardwareParamsNode::GetDefaultHardwareParams(node->target, node->target_host);
   }
+  node->layout_rewrite_option = layout_rewrite_option;
   data_ = std::move(node);
 }
 
@@ -139,8 +141,10 @@ TVM_REGISTER_GLOBAL("auto_scheduler.HardwareParams")
 
 TVM_REGISTER_GLOBAL("auto_scheduler.SearchTask")
     .set_body_typed([](ComputeDAG compute_dag, String workload_key, Target target,
-                       Target target_host, Optional<HardwareParams> hardware_params) {
-      return SearchTask(compute_dag, workload_key, target, target_host, hardware_params);
+                       Target target_host, Optional<HardwareParams> hardware_params,
+                       int layout_rewrite_option) {
+      return SearchTask(compute_dag, workload_key, target, target_host, hardware_params,
+                        LayoutRewriteOption(layout_rewrite_option));
     });
 
 }  // namespace auto_scheduler