You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/27 07:15:37 UTC

[GitHub] [incubator-tvm] jcf94 commented on a change in pull request #6142: [Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps

jcf94 commented on a change in pull request #6142:
URL: https://github.com/apache/incubator-tvm/pull/6142#discussion_r460685681



##########
File path: include/tvm/auto_scheduler/loop_state.h
##########
@@ -356,9 +356,30 @@ class State : public ObjectRef {
    * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner
    * most iterator of split results will become the new attach point.
    */
-  TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
-                                const Array<Optional<Integer>>& lengths,
-                                bool inner_to_outer = true);
+  Array<Iterator> split(int stage_id, const Iterator& it, const Array<Optional<Integer>>& lengths,
+                        bool inner_to_outer = true);

Review comment:
       Please restore the `split` function declaration to its original form.

##########
File path: include/tvm/auto_scheduler/loop_state.h
##########
@@ -356,9 +356,30 @@ class State : public ObjectRef {
    * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner
    * most iterator of split results will become the new attach point.
    */
-  TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
-                                const Array<Optional<Integer>>& lengths,
-                                bool inner_to_outer = true);
+  Array<Iterator> split(int stage_id, const Iterator& it, const Array<Optional<Integer>>& lengths,
+                        bool inner_to_outer = true);
+  /*!
+   * \brief Schedule primitive extends to split step.
+   * \param stage_id The index of the stage to be split.
+   * \param it The iterator to be split.
+   * \param src_step_id The index of the split step to follow in the history.
+   * \param n_split The number of split level.
+   * \return The splitted new Iterators.
+   */
+  Array<Iterator> follow_split(int stage_id, const Iterator& it, int src_step_id, int n_split);

Review comment:
       Add the `TVM_DLL` qualifier before this function; see the other functions for reference.

##########
File path: src/auto_scheduler/loop_state.cc
##########
@@ -454,6 +475,25 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit")
       return Array<ObjectRef>{state, res};
     });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit")
+    .set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id,
+                       int n_split) {
+      const auto& res = state.follow_split(stage_id, it, src_step_id, n_split);
+      return Array<ObjectRef>{state, Array<Iterator>(res)};
+    });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit")
+    .set_body_typed([](State state, int stage_id, const Iterator& it,
+                       const Array<IntImm>& src_step_ids, int level, bool factor_or_nparts) {

Review comment:
       This seems to be a historical issue from our old code base. We can directly use `Array<Integer>` as the function parameter and remove the `for ... copy` loop below.

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -778,6 +795,184 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
 
 /********** Steps working on multiple stages **********/
 
+/********** Follow Split **********/
+FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) {
+  auto node = make_object<FollowSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_id = src_step_id;
+  node->n_split = n_split;
+  data_ = std::move(node);
+}
+
+void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(src_step_id);
+  writer->WriteArrayItem(n_split);
+}
+
+void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps,
+                                              Array<Optional<Integer>>* lengths) const {
+  CHECK_LT(src_step_id, transform_steps.size());
+  auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+  CHECK(ps != nullptr);
+
+  // get lengths from src step
+  lengths->reserve(n_split);
+  int j = 0;
+  for (; j < n_split - 1; ++j) {
+    lengths->push_back(ps->lengths[j]);
+  }
+  PrimExpr last_factor = 1;
+  for (; j < static_cast<int>(ps->lengths.size()); ++j) {
+    if (ps->lengths[j]) {
+      last_factor *= ps->lengths[j].value();
+    } else {
+      last_factor = PrimExpr();
+      break;
+    }
+  }
+  if (last_factor.defined()) {
+    lengths->push_back(Downcast<Integer>(last_factor));
+  } else {
+    lengths->push_back(NullOpt);
+  }
+}
+
+FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) {
+  auto node = make_object<FollowSplitStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->src_step_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->n_split);
+
+  data_ = std::move(node);
+}
+
+Array<Iterator> FollowSplitStepNode::ApplyToState(State* state) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths((*state)->transform_steps, &lengths);
+  return ApplySplitToState(state, stage_id, iter_id, lengths, true);
+}
+
+Array<IterVar> FollowSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                                    StageToAxesMap* stage_to_axes,
+                                                    const Array<Step>& transform_steps) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+}
+
+String FollowSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                             StageToAxesMap* stage_to_axes,
+                                             const Array<Step>& transform_steps) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+}
+
+/********** Follow Fused Split **********/
+FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id,
+                                           const Array<Integer>& src_step_ids, int level,
+                                           bool factor_or_nparts) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_ids = src_step_ids;
+  node->level = level;
+  node->factor_or_nparts = factor_or_nparts;
+  data_ = std::move(node);
+}
+
+FollowFusedSplitStep::FollowFusedSplitStep(dmlc::JSONReader* reader) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::vector<int> int_list;
+  reader->Read(&int_list);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->level);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->factor_or_nparts);
+
+  ::tvm::Array<::tvm::Integer> src_step_ids;
+  for (const auto& i : int_list) {
+    src_step_ids.push_back(i);
+  }
+  node->src_step_ids = src_step_ids;
+  data_ = std::move(node);
+}
+
+void FollowFusedSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(IntArrayToVector(src_step_ids));
+  writer->WriteArrayItem(level);
+  writer->WriteArrayItem(static_cast<int>(factor_or_nparts));
+}
+
+Optional<Integer> FollowFusedSplitStepNode::ExtractSplitLength(
+    const Array<Step>& transform_steps) const {
+  PrimExpr ret(1);
+
+  for (int src_step_id : src_step_ids) {
+    CHECK_LT(src_step_id, transform_steps.size());
+    auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+    CHECK(ps != nullptr);
+    if (ps->lengths[level] && ret.defined()) {
+      ret *= ps->lengths[level].value();
+    } else {
+      return NullOpt;
+    }
+  }
+  return Downcast<Integer>(ret);
+}
+
+Array<Iterator> FollowFusedSplitStepNode::ApplyToState(State* state) const {
+  const Optional<Integer>& length = ExtractSplitLength((*state)->transform_steps);
+  return ApplySplitToState(state, stage_id, iter_id, {length}, factor_or_nparts);
+}
+
+Array<IterVar> FollowFusedSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                                         StageToAxesMap* stage_to_axes,
+                                                         const Array<Step>& transform_steps) const {
+  const Optional<Integer>& length = ExtractSplitLength(transform_steps);
+  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, {length}, factor_or_nparts);
+}
+
+String FollowFusedSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                                  StageToAxesMap* stage_to_axes,
+                                                  const Array<Step>& transform_steps) const {
+  const Optional<Integer>& length = ExtractSplitLength(transform_steps);
+  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, {length},
+                               factor_or_nparts);
+}
+
+/********** Primitives working on multiple stages **********/

Review comment:
       Remove this line, and move the comment from line 796 (`Steps working on multiple stages`) here.

##########
File path: src/auto_scheduler/loop_state.cc
##########
@@ -268,6 +268,27 @@ Array<Iterator> State::split(int stage_id, const Iterator& it,
   return step->ApplyToState(this);
 }
 
+Array<Iterator> State::follow_split(int stage_id, const Iterator& it, int src_step_id,
+                                    int n_split) {
+  const Stage& stage = operator->()->stages[stage_id];
+
+  FollowSplitStep step =
+      FollowSplitStep(stage_id, GetIndex(stage->iters, it), src_step_id, n_split);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return step->ApplyToState(this);
+}
+
+Array<Iterator> State::follow_fused_split(int stage_id, const Iterator& it,
+                                          const Array<Integer>& src_step_ids, int level,
+                                          bool factor_or_nparts) {
+  const Stage& stage = operator->()->stages[stage_id];
+

Review comment:
       Ditto.

##########
File path: src/auto_scheduler/loop_state.cc
##########
@@ -268,6 +268,27 @@ Array<Iterator> State::split(int stage_id, const Iterator& it,
   return step->ApplyToState(this);
 }
 
+Array<Iterator> State::follow_split(int stage_id, const Iterator& it, int src_step_id,
+                                    int n_split) {
+  const Stage& stage = operator->()->stages[stage_id];
+

Review comment:
       Remove this blank line.

##########
File path: include/tvm/auto_scheduler/transform_step.h
##########
@@ -489,6 +490,166 @@ class SplitStep : public Step {
 
 /********** Steps working on multiple stages **********/
 
+/*! \brief Similar to SplitStepNode, but uses split factors from another step
+ * (i.e. Follow another split step) */
+class FollowSplitStepNode : public StepNode {
+ public:
+  /*! \brief The id of the iter to be split. */
+  int iter_id;
+  /*! \brief The index of the split step to follow in the history. */
+  int src_step_id;
+  /*! \brief The number of split level. */
+  int n_split;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Extract split lengths.
+   * \param transform_steps An array record all transform steps.
+   * \param lengths The multiple split factors. Can be None to be filled by search policy.
+   */
+  void ExtractSplitLengths(const Array<Step>& transform_steps,
+                           Array<Optional<Integer>>* lengths) const;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   */
+  Array<Iterator> ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param transform_steps An array record all transform steps.
+   */
+  Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                      const Array<Step>& transform_steps) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param transform_steps An array record all transform steps.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                          const Array<Step>& transform_steps) const;
+
+  static constexpr const char* record_prefix_str = "FSP";
+
+  static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to FollowSplitStepNode.
+ * \sa FollowSplitStepNode
+ */
+class FollowSplitStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be split.
+   * \param iter_id The index of the iterator to be split.
+   * \param src_step_id The index of the split step to follow in the history.
+   * \param n_split The number of split level.
+   */
+  FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit FollowSplitStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode);
+};
+
+/*! \brief Similar to FollowSplitStep, but use split factors from multiple steps.
+ *  \Note This can be used for the split in cooperative fetching
+ */
+class FollowFusedSplitStepNode : public StepNode {
+ public:
+  /*! \brief The id of the iter to split. */
+  int iter_id;
+  /*! \brief The indices of the split steps to follow in the history. */
+  Array<Integer> src_step_ids;
+  /*! \brief  Use the length in this split level. */
+  int level;
+  /*! \brief If this is true, use factor. Otherwise, use nparts. */
+  bool factor_or_nparts;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Extract split length.
+   * \param transform_steps An array record all transform steps.
+   * \return Split factor.
+   */
+  Optional<Integer> ExtractSplitLength(const Array<Step>& transform_steps) const;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   */
+  Array<Iterator> ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param transform_steps An array record all transform steps.
+   */
+  Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                      const Array<Step>& transform_steps) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param transform_steps An array record all transform steps.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                          const Array<Step>& transform_steps) const;
+
+  static constexpr const char* record_prefix_str = "FFSP";
+
+  static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to FollowFusedSplitStepNode.
+ * \sa FollowFusedSplitStepNode
+ */
+class FollowFusedSplitStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be split.
+   * \param iter_id The index of the iterator to be split.
+   * \param src_step_ids An array of index for split step to follow in the history.
+   * \param level Use the length in this split level.
+   * \param factor_or_nparts If this is true, use factor. Otherwise, use nparts.
+   */
+  FollowFusedSplitStep(int stage_id, int iter_id, const Array<Integer>& src_step_ids, int level,
+                       bool factor_or_nparts);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit FollowFusedSplitStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode);
+};
+
+/********** Primitives working on multiple stages **********/

Review comment:
       Remove this line, and move the comment from line 491 (`Steps working on multiple stages`) here.

##########
File path: tests/python/unittest/test_auto_scheduler_measure.py
##########
@@ -22,13 +22,12 @@
 from tvm import te, auto_scheduler
 import tempfile
 
-from test_auto_scheduler_common import get_tiled_matmul
-
+from test_auto_scheduler_common import matmul_auto_scheduler_test, get_tiled_matmul
 
 def test_record():
     if not tvm.runtime.enabled("llvm"):
         return
-
+    #pdb.set_trace()

Review comment:
       Remove the debug code.

##########
File path: tests/python/unittest/test_auto_scheduler_loop_state.py
##########
@@ -142,6 +142,318 @@ def test_compute_at_root_inline():
     assert s0[conv].iters[5].range.extent == 7
     assert s0[conv].iters[6].range.extent == 7
 
+def test_follow_split_follow_fused_split():

Review comment:
       Put this function after `test_cache_read_write`.
   Also, please check carefully: there seem to be two `test_cache_read_write` functions in this file. That should not show up in the git diff.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org