You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/29 07:18:49 UTC

[GitHub] [incubator-tvm] merrymercy commented on a change in pull request #6142: [Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps

merrymercy commented on a change in pull request #6142:
URL: https://github.com/apache/incubator-tvm/pull/6142#discussion_r461125355



##########
File path: include/tvm/auto_scheduler/loop_state.h
##########
@@ -359,6 +359,29 @@ class State : public ObjectRef {
   TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
                                 const Array<Optional<Integer>>& lengths,
                                 bool inner_to_outer = true);
+  /*!
+   * \brief Schedule primitive extends to split step.
+   * \param stage_id The index of the stage to be split.
+   * \param it The iterator to be split.
+   * \param src_step_id The index of the split step to follow in the history.

Review comment:
       ```suggestion
      * \param src_step_id The index of the split step to be followed in the history.
   ```

##########
File path: python/tvm/auto_scheduler/loop_state.py
##########
@@ -301,6 +310,89 @@ def split(self, stage, iterator, lengths, inner_to_outer=True):
                                                      iterator, lengths, inner_to_outer)
         return res
 
+    def follow_split(self, stage, iterator, src_step_id, n_split):
+        """ Schedule primitive extends to split step.
+
+        This step is used to follow a former SplitStep, keeps their iterator structures to be same.

Review comment:
       ```suggestion
           This step is used to follow a former SplitStep and keep their iterator structures to be the same.
   ```

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -778,6 +795,184 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
 
 /********** Steps working on multiple stages **********/
 
+/********** Follow Split **********/
+FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) {
+  auto node = make_object<FollowSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_id = src_step_id;
+  node->n_split = n_split;
+  data_ = std::move(node);
+}
+
+void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(src_step_id);
+  writer->WriteArrayItem(n_split);
+}
+
+void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps,
+                                              Array<Optional<Integer>>* lengths) const {
+  CHECK_LT(src_step_id, transform_steps.size());
+  auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+  CHECK(ps != nullptr);
+
+  // get lengths from src step
+  lengths->reserve(n_split);
+  int j = 0;
+  for (; j < n_split - 1; ++j) {
+    lengths->push_back(ps->lengths[j]);
+  }
+  PrimExpr last_factor = 1;
+  for (; j < static_cast<int>(ps->lengths.size()); ++j) {
+    if (ps->lengths[j]) {
+      last_factor *= ps->lengths[j].value();
+    } else {
+      last_factor = PrimExpr();
+      break;
+    }
+  }
+  if (last_factor.defined()) {
+    lengths->push_back(Downcast<Integer>(last_factor));
+  } else {
+    lengths->push_back(NullOpt);
+  }
+}
+
+FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) {
+  auto node = make_object<FollowSplitStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->src_step_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->n_split);
+
+  data_ = std::move(node);
+}
+
+Array<Iterator> FollowSplitStepNode::ApplyToState(State* state) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths((*state)->transform_steps, &lengths);
+  return ApplySplitToState(state, stage_id, iter_id, lengths, true);
+}
+
+Array<IterVar> FollowSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                                    StageToAxesMap* stage_to_axes,
+                                                    const Array<Step>& transform_steps) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+}
+
+String FollowSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                             StageToAxesMap* stage_to_axes,
+                                             const Array<Step>& transform_steps) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+}
+
+/********** Follow Fused Split **********/
+FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id,
+                                           const Array<Integer>& src_step_ids, int level,
+                                           bool factor_or_nparts) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_ids = src_step_ids;
+  node->level = level;
+  node->factor_or_nparts = factor_or_nparts;
+  data_ = std::move(node);
+}
+
+FollowFusedSplitStep::FollowFusedSplitStep(dmlc::JSONReader* reader) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);

Review comment:
       Yes, it depends on the order. This is a limitation of tvm's JSON serialization support.
   However, as you suggested, the only two functions that rely on this order (`FollowFusedSplitStep::FollowFusedSplitStep` and `FollowFusedSplitStepNode::WriteToRecord`) are already located next to each other in this file.

##########
File path: include/tvm/auto_scheduler/loop_state.h
##########
@@ -359,6 +359,29 @@ class State : public ObjectRef {
   TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
                                 const Array<Optional<Integer>>& lengths,
                                 bool inner_to_outer = true);
+  /*!
+   * \brief Schedule primitive extends to split step.
+   * \param stage_id The index of the stage to be split.
+   * \param it The iterator to be split.
+   * \param src_step_id The index of the split step to follow in the history.
+   * \param n_split The number of split level.
+   * \return The splitted new Iterators.
+   */
+  TVM_DLL Array<Iterator> follow_split(int stage_id, const Iterator& it, int src_step_id,
+                                       int n_split);
+  /*!
+   * \brief Schedule primitive extends to split step.
+   * \param stage_id The index of the stage to be split.
+   * \param it The iterator to be split.
+   * \param src_step_ids The indices of the split steps to follow in the history.

Review comment:
       Propagate this change to other places.

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -111,6 +115,10 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
     ps->ApplyToState(state);
   } else if (auto ps = step.as<SplitStepNode>()) {
     ps->ApplyToState(state);
+  } else if (auto ps = step.as<FollowSplitStepNode>()) {

Review comment:
       The `ps`s in different branches have different types, so they cannot be merged.

##########
File path: include/tvm/auto_scheduler/loop_state.h
##########
@@ -359,6 +359,29 @@ class State : public ObjectRef {
   TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
                                 const Array<Optional<Integer>>& lengths,
                                 bool inner_to_outer = true);
+  /*!
+   * \brief Schedule primitive extends to split step.
+   * \param stage_id The index of the stage to be split.
+   * \param it The iterator to be split.
+   * \param src_step_id The index of the split step to follow in the history.
+   * \param n_split The number of split level.
+   * \return The splitted new Iterators.
+   */
+  TVM_DLL Array<Iterator> follow_split(int stage_id, const Iterator& it, int src_step_id,
+                                       int n_split);
+  /*!
+   * \brief Schedule primitive extends to split step.
+   * \param stage_id The index of the stage to be split.
+   * \param it The iterator to be split.
+   * \param src_step_ids The indices of the split steps to follow in the history.

Review comment:
       ```suggestion
      * \param src_step_ids The indices of the split steps to be followed in the history.
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org