You are viewing a plain-text version of this content. The canonical link to it (originally given as a hyperlink on this page) is not preserved in this copy.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/10/02 13:38:33 UTC

[incubator-tvm] branch master updated: [Ansor] Support multiple output ops and fix Python API printing (#6584)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 72969b2  [Ansor] Support multiple output ops and fix Python API printing (#6584)
72969b2 is described below

commit 72969b27c6476d2269a3b0caf0c100cdae000c7a
Author: Cody Yu <co...@gmail.com>
AuthorDate: Fri Oct 2 06:38:22 2020 -0700

    [Ansor] Support multiple output ops and fix Python API printing (#6584)
---
 src/auto_scheduler/compute_dag.cc                  | 32 ++++----
 .../search_policy/sketch_policy_rules.h            |  8 ++
 src/auto_scheduler/transform_step.cc               | 87 +++++++++++++---------
 src/auto_scheduler/utils.h                         |  8 +-
 4 files changed, 81 insertions(+), 54 deletions(-)

diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index 7c9ce4c..23b3817 100755
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -969,6 +969,9 @@ void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
       }
     }  // end for placeholder
   }    // end for stage
+  p_dag->access_analyzer = AccessAnalyzer(p_dag->tensors);
+  p_dag->ops = p_dag->access_analyzer->ops_topo_order;
+  p_dag->flop_ct = FlopEstimator().EstimateFlop(p_dag->ops);
 }
 
 std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
@@ -989,16 +992,15 @@ std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
   if (stage_to_axes == nullptr) {
     stage_to_axes = &temp_stage_to_axes;
   }
-  Array<te::Operation> ops;
+  Array<te::Operation> out_ops;
   for (const auto& op : operator->()->ops) {
-    if (!op->IsInstance<te::PlaceholderOpNode>()) {
-      ops.push_back(op);
+    if (operator->()->access_analyzer.IsOutput(op)) {
+      out_ops.push_back(op);
     }
   }
+
   // Create the initial schedule
-  // TODO(jcf94): Currently we only checked single output dag for TVM Auto-scheduler,
-  // update this after testing with multiple outputs.
-  te::Schedule schedule = te::create_schedule({ops.back()});
+  te::Schedule schedule = te::create_schedule(out_ops);
 
   // init axes
   for (const auto& x : operator->()->ops) {
@@ -1019,16 +1021,14 @@ std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
 String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const {
   Array<te::Stage> stages;
   StageToAxesMap stage_to_axes;
-  Array<te::Operation> ops;
+  Array<te::Operation> out_ops;
   for (const auto& op : operator->()->ops) {
-    if (!op->IsInstance<te::PlaceholderOpNode>()) {
-      ops.push_back(op);
+    if (operator->()->access_analyzer.IsOutput(op)) {
+      out_ops.push_back(op);
     }
   }
   // Create the initial schedule
-  // TODO(jcf94): Currently we only checked single output dag for TVM Auto-scheduler,
-  // update this after testing with multiple outputs.
-  te::Schedule schedule = te::create_schedule({ops.back()});
+  te::Schedule schedule = te::create_schedule(out_ops);
 
   // init axes
   for (const auto& x : operator->()->ops) {
@@ -1040,16 +1040,18 @@ String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const
   std::stringstream ss;
   for (const auto& stage : stages) {
     if (stage->op->IsInstance<te::ComputeOpNode>()) {
+      auto op_name = CleanName(stage->op->name);
+
       for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
-        ss << stage->leaf_iter_vars[i]->var->name_hint;
+        ss << CleanName(stage->leaf_iter_vars[i]->var->name_hint, op_name);
         if (i != stage->leaf_iter_vars.size() - 1) {
           ss << ", ";
         }
       }
       ss << " = "
-         << "tuple(" << stage->op->name << ".op.axis)"
+         << "tuple(" << op_name << ".op.axis)"
          << " + "
-         << "tuple(" << stage->op->name << ".op.reduce_axis)\n";
+         << "tuple(" << op_name << ".op.reduce_axis)\n";
     }
   }
   // Call each step's PrintAsPythonAPI method
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.h b/src/auto_scheduler/search_policy/sketch_policy_rules.h
index 928efc5..035dc89 100644
--- a/src/auto_scheduler/search_policy/sketch_policy_rules.h
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.h
@@ -28,6 +28,7 @@
 #include <tvm/auto_scheduler/loop_state.h>
 #include <tvm/auto_scheduler/search_task.h>
 
+#include <string>
 #include <utility>
 #include <vector>
 
@@ -74,6 +75,12 @@ class SketchGenerationRule {
    */
   virtual std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy,
                                                    const State& state, int stage_id) const = 0;
+
+  /*!
+   * \brief Get the name of this rule.
+   * \return A string of the rule name.
+   */
+  virtual std::string GetRuleName() const = 0;
 };
 
 #define DEFINE_SKETCH_GENERATION_RULE(rule_name)                                                 \
@@ -83,6 +90,7 @@ class SketchGenerationRule {
                                 int stage_id) const final;                                       \
     std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state, \
                                              int stage_id) const final;                          \
+    std::string GetRuleName() const final { return #rule_name; }                                 \
   };
 
 /*! \brief The rule that simply skips the current stage. It returns an unchanged state and move to
diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc
index 2a93497..73f6734 100755
--- a/src/auto_scheduler/transform_step.cc
+++ b/src/auto_scheduler/transform_step.cc
@@ -356,8 +356,9 @@ String AnnotationStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   std::stringstream ss;
   const auto& stage = (*stages)[stage_id];
   const auto& iter = (*stage_to_axes)[stage][iter_id];
+  const auto& op_name = CleanName(stage->op->name);
 
-  ss << "s[" << CleanName(stage->op->name) << "].";
+  ss << "s[" << op_name << "].";
   switch (annotation) {
     case IteratorAnnotation::kUnroll:
       ss << "unroll(";
@@ -383,7 +384,7 @@ String AnnotationStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
       LOG(FATAL) << "Invalid annotation " << static_cast<int>(annotation);
       break;
   }
-  ss << CleanName(iter->var->name_hint);
+  ss << CleanName(iter->var->name_hint, op_name);
   switch (annotation) {
     case IteratorAnnotation::kVThread:
     case IteratorAnnotation::kBlockX:
@@ -392,7 +393,7 @@ String AnnotationStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
     case IteratorAnnotation::kThreadX:
     case IteratorAnnotation::kThreadY:
     case IteratorAnnotation::kThreadZ:
-      ss << ", tvm.thread_axis(\"" << IteratorAnnotationString[static_cast<int>(annotation)]
+      ss << ", te.thread_axis(\"" << IteratorAnnotationString[static_cast<int>(annotation)]
          << "\")";
       break;
     default:
@@ -541,10 +542,11 @@ IterVar FuseStepNode::ApplyToSchedule(Array<te::Stage>* stages,
 String FuseStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
                                       StageToAxesMap* stage_to_axes) const {
   const auto& stage = (*stages)[stage_id];
+  const auto& op_name = CleanName(stage->op->name);
   std::stringstream to_fuse;
 
   for (size_t i = 0; i < fused_ids.size(); ++i) {
-    to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i]]->var->name_hint);
+    to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i]]->var->name_hint, op_name);
     if (i != fused_ids.size() - 1) {
       to_fuse << ", ";
     }
@@ -553,7 +555,7 @@ String FuseStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   std::stringstream ss;
   const auto& fused = ApplyToSchedule(stages, stage_to_axes);
 
-  ss << CleanName(fused->var->name_hint) << " = s[" << CleanName(stage->op->name) << "].fuse("
+  ss << CleanName(fused->var->name_hint, op_name) << " = s[" << op_name << "].fuse("
      << to_fuse.str() << ")\n";
 
   return ss.str();
@@ -640,6 +642,7 @@ String PragmaStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
                                         StageToAxesMap* stage_to_axes) const {
   std::stringstream ss;
   const auto& stage = (*stages)[stage_id];
+  const auto& op_name = CleanName(stage->op->name);
 
   if (StrStartsWith(pragma_type, "auto_unroll_max_step")) {
     size_t pos = 0;
@@ -650,16 +653,16 @@ String PragmaStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
     }
     CHECK_LT(pos, pragma_type.size()) << "max step value not found.";
     int value = atoi(pragma_type.c_str() + pos + 1);
-    ss << "s[" << CleanName(stage->op->name) << "].pragma("
-       << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint)
+    ss << "s[" << op_name << "].pragma("
+       << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name)
        << ", \"auto_unroll_max_step\", " << value << ")\n";
-    ss << "s[" << CleanName(stage->op->name) << "].pragma("
-       << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint)
+    ss << "s[" << op_name << "].pragma("
+       << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name)
        << ", \"unroll_explicit\", True)\n";
   } else {
-    ss << "s[" << CleanName(stage->op->name) << "].pragma("
-       << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", \"" << pragma_type
-       << "\")\n";
+    ss << "s[" << op_name << "].pragma("
+       << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name) << ", \""
+       << pragma_type << "\")\n";
   }
 
   ApplyToSchedule(stages, stage_to_axes);
@@ -726,11 +729,12 @@ void ReorderStepNode::ApplyToSchedule(Array<te::Stage>* stages,
 String ReorderStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
                                          StageToAxesMap* stage_to_axes) const {
   const auto& stage = (*stages)[stage_id];
+  const auto& op_name = CleanName(stage->op->name);
   std::stringstream ss;
 
-  ss << "s[" << CleanName(stage->op->name) << "].reorder(";
+  ss << "s[" << op_name << "].reorder(";
   for (size_t i = 0; i < after_ids.size(); ++i) {
-    ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint);
+    ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint, op_name);
     if (i != after_ids.size() - 1) {
       ss << ", ";
     }
@@ -881,16 +885,17 @@ String PrintSplitAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_
   int size = static_cast<int>(lengths.size());
   if (inner_to_outer) {
     for (int i = size - 1; i >= 0; i--) {
-      ss << CleanName(outs[size - i]->var->name_hint) << ", "
-         << CleanName(outs[size - i - 1]->var->name_hint) << " = s[" << func_name << "].split("
-         << CleanName(to_split->var->name_hint) << ", factor=" << lengths[i] << ")\n";
+      ss << CleanName(outs[size - i]->var->name_hint, func_name) << ", "
+         << CleanName(outs[size - i - 1]->var->name_hint, func_name) << " = s[" << func_name
+         << "].split(" << CleanName(to_split->var->name_hint, func_name)
+         << ", factor=" << lengths[i] << ")\n";
       to_split = outs[size - i];
     }
   } else {
     for (int i = 0; i < size; i++) {
-      ss << CleanName(outs[i]->var->name_hint) << ", " << CleanName(outs[i + 1]->var->name_hint)
-         << " = s[" << func_name << "].split(" << CleanName(to_split->var->name_hint)
-         << ", nparts=" << lengths[i] << ")\n";
+      ss << CleanName(outs[i]->var->name_hint, func_name) << ", "
+         << CleanName(outs[i + 1]->var->name_hint, func_name) << " = s[" << func_name << "].split("
+         << CleanName(to_split->var->name_hint, func_name) << ", nparts=" << lengths[i] << ")\n";
       to_split = outs[i + 1];
     }
   }
@@ -1195,9 +1200,10 @@ String StorageAlignStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
                                               StageToAxesMap* stage_to_axes) const {
   std::stringstream ss;
   const auto& stage = (*stages)[stage_id];
-  ss << "s[" << CleanName(stage->op->name) << "].storage_align("
-     << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", " << factor << ", "
-     << offset << ")\n";
+  const auto& op_name = CleanName(stage->op->name);
+  ss << "s[" << op_name << "].storage_align("
+     << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name) << ", " << factor
+     << ", " << offset << ")\n";
 
   ApplyToSchedule(stages, stage_to_axes);
   return ss.str();
@@ -1269,8 +1275,11 @@ String ComputeAtStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   std::stringstream ss;
   const auto& stage = (*stages)[stage_id];
   const auto& target_stage = (*stages)[target_stage_id];
-  ss << "s[" << CleanName(stage->op->name) << "].compute_at(s[" << CleanName(target_stage->op->name)
-     << "], " << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint) << ")\n";
+  const auto& op_name = CleanName(stage->op->name);
+  const auto& target_op_name = CleanName(target_stage->op->name);
+  ss << "s[" << op_name << "].compute_at(s[" << target_op_name << "], "
+     << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint, target_op_name)
+     << ")\n";
   ApplyToSchedule(stages, stage_to_axes);
   return ss.str();
 }
@@ -1516,7 +1525,8 @@ String CacheReadStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxes
   }
   auto out = ApplyToSchedule(stages, stage_to_axes, schedule);
 
-  ss << CleanName(out->op->name) << " = "
+  const auto& op_name = CleanName(out->op->name);
+  ss << op_name << " = "
      << "s.cache_read(" << CleanName(stage->op->name) << ", \"" << scope_name << "\", ["
      << CleanName(reader_stages[0]->op->name);
   for (size_t i = 1; i < reader_stage_ids.size(); ++i) {
@@ -1527,13 +1537,13 @@ String CacheReadStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxes
   // Print the iterators of the new added stage
   const auto& iters = out->op->root_iter_vars();
   for (size_t i = 0; i < iters.size(); ++i) {
-    ss << CleanName(iters[i]->var->name_hint);
+    ss << CleanName(iters[i]->var->name_hint, op_name);
     if (i != iters.size() - 1) {
       ss << ", ";
     }
   }
   ss << " = "
-     << "tuple(" << CleanName(out->op->name) << ".op.axis)\n";
+     << "tuple(" << op_name << ".op.axis)\n";
 
   return ss.str();
 }
@@ -1652,16 +1662,17 @@ String CacheWriteStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxe
   // Print the iterators of the new added stage
   for (const auto& out : outs) {
     const auto& iters = out->op->root_iter_vars();
+    const auto& op_name = CleanName(out->op->name);
     for (size_t i = 0; i < iters.size(); ++i) {
-      ss << CleanName(iters[i]->var->name_hint);
+      ss << CleanName(iters[i]->var->name_hint, op_name);
       if (i != iters.size() - 1) {
         ss << ", ";
       }
     }
     ss << " = "
-       << "tuple(" << CleanName(out->op->name) << ".op.axis)"
+       << "tuple(" << op_name << ".op.axis)"
        << " + "
-       << "tuple(" << CleanName(out->op->name) << ".op.reduce_axis)\n";
+       << "tuple(" << op_name << ".op.reduce_axis)\n";
   }
 
   return ss.str();
@@ -1764,30 +1775,32 @@ String RfactorStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMa
 
   for (const auto& out : outs) {
     const auto& iters = out->op->root_iter_vars();
+    const auto& op_name = CleanName(out->op->name);
     for (size_t i = 0; i < iters.size(); ++i) {
-      ss << CleanName(iters[i]->var->name_hint);
+      ss << CleanName(iters[i]->var->name_hint, op_name);
       if (i != iters.size() - 1) {
         ss << ", ";
       }
     }
     ss << " = "
-       << "tuple(" << CleanName(out->op->name) << ".op.axis)"
+       << "tuple(" << op_name << ".op.axis)"
        << " + "
-       << "tuple(" << CleanName(out->op->name) << ".op.reduce_axis)\n";
+       << "tuple(" << op_name << ".op.reduce_axis)\n";
   }
 
   const auto& output = (*stages)[stage_id + 1]->op.output(0);
   const auto& iters = output->op->root_iter_vars();
+  const auto& op_name = CleanName(output->op->name);
   for (size_t i = 0; i < iters.size(); ++i) {
-    ss << CleanName(iters[i]->var->name_hint);
+    ss << CleanName(iters[i]->var->name_hint, op_name);
     if (i != iters.size() - 1) {
       ss << ", ";
     }
   }
   ss << " = "
-     << "tuple(s[" << CleanName(output->op->name) << "].op.axis)"
+     << "tuple(s[" << op_name << "].op.axis)"
      << " + "
-     << "tuple(s[" << CleanName(output->op->name) << "].op.reduce_axis)\n";
+     << "tuple(s[" << op_name << "].op.reduce_axis)\n";
 
   return ss.str();
 }
diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h
index d036743..610fec9 100755
--- a/src/auto_scheduler/utils.h
+++ b/src/auto_scheduler/utils.h
@@ -209,16 +209,20 @@ inline int64_t AxisLengthProd(const Array<tir::IterVar>& axes) {
 }
 
 /*!
- * \brief Clean the name of an iterator to make it valid in python code.
+ * \brief Clean the name of an iterator or an op to make it valid in python code.
  * \param str The original name.
+ * \param prefix The name prefix to differentiate the same name (e.g., the same iterator names).
  * \return The cleaned name.
  */
-inline std::string CleanName(const std::string& str) {
+inline std::string CleanName(const std::string& str, const std::string& prefix = "") {
   std::string ret = str;
   StrReplace(&ret, ".", "_");
   StrReplace(&ret, "@", "_");
   StrReplace(&ret, "outer", "o");
   StrReplace(&ret, "inner", "i");
+  if (prefix != "") {
+    return prefix + "_" + ret;
+  }
   return ret;
 }