Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/27 07:01:00 UTC

[GitHub] [incubator-tvm] jiuqi-yang opened a new pull request #6142: [Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps

jiuqi-yang opened a new pull request #6142:
URL: https://github.com/apache/incubator-tvm/pull/6142


   For the full upstream plan, see the Ansor RFC.
   
   This PR adds the follow_split and follow_fused_split steps to the Ansor auto_scheduler.
   
   cc @merrymercy @comaniac @junrushao1994 @FrozenGene @jroesch
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] merrymercy merged pull request #6142: [Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps

merrymercy merged pull request #6142:
URL: https://github.com/apache/incubator-tvm/pull/6142


   



[GitHub] [incubator-tvm] jcf94 commented on a change in pull request #6142: [Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps

jcf94 commented on a change in pull request #6142:
URL: https://github.com/apache/incubator-tvm/pull/6142#discussion_r461261906



##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -111,6 +115,10 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
     ps->ApplyToState(state);
   } else if (auto ps = step.as<SplitStepNode>()) {
     ps->ApplyToState(state);
+  } else if (auto ps = step.as<FollowSplitStepNode>()) {

Review comment:
       Yes, the problem is that these functions may have different parameters and return values, so they cannot be merged into a single virtual function interface.

##########
File path: include/tvm/auto_scheduler/transform_step.h
##########
@@ -487,6 +490,164 @@ class SplitStep : public Step {
   TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
 };
 
+/*! \brief Similar to SplitStepNode, but uses split factors from another step
+ * (i.e. Follow another split step) */
+class FollowSplitStepNode : public StepNode {
+ public:
+  /*! \brief The id of the iter to be split. */
+  int iter_id;
+  /*! \brief The index of the split step to follow in the history. */
+  int src_step_id;
+  /*! \brief The number of split level. */
+  int n_split;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Extract split lengths.
+   * \param transform_steps An array record all transform steps.
+   * \param lengths The multiple split factors. Can be None to be filled by search policy.
+   */
+  void ExtractSplitLengths(const Array<Step>& transform_steps,
+                           Array<Optional<Integer>>* lengths) const;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.

Review comment:
       ```suggestion
      * \param state A mutable pointer to state, which will be updated.
      * \return The iterator results after split.
   ```

##########
File path: include/tvm/auto_scheduler/transform_step.h
##########
@@ -487,6 +490,164 @@ class SplitStep : public Step {
   TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
 };
 
+/*! \brief Similar to SplitStepNode, but uses split factors from another step
+ * (i.e. Follow another split step) */
+class FollowSplitStepNode : public StepNode {
+ public:
+  /*! \brief The id of the iter to be split. */
+  int iter_id;
+  /*! \brief The index of the split step to follow in the history. */
+  int src_step_id;
+  /*! \brief The number of split level. */
+  int n_split;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Extract split lengths.
+   * \param transform_steps An array record all transform steps.
+   * \param lengths The multiple split factors. Can be None to be filled by search policy.
+   */
+  void ExtractSplitLengths(const Array<Step>& transform_steps,
+                           Array<Optional<Integer>>* lengths) const;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   */
+  Array<Iterator> ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param transform_steps An array record all transform steps.

Review comment:
       ```suggestion
      * \param stages The `te::Stage`s used when applying the TVM schedule.
      * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
      * \param transform_steps An array record all transform steps.
      * \return The iterator results after split.
   ```

##########
File path: include/tvm/auto_scheduler/transform_step.h
##########
@@ -487,6 +490,164 @@ class SplitStep : public Step {
   TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
 };
 
+/*! \brief Similar to SplitStepNode, but uses split factors from another step
+ * (i.e. Follow another split step) */
+class FollowSplitStepNode : public StepNode {
+ public:
+  /*! \brief The id of the iter to be split. */
+  int iter_id;
+  /*! \brief The index of the split step to follow in the history. */
+  int src_step_id;
+  /*! \brief The number of split level. */
+  int n_split;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Extract split lengths.
+   * \param transform_steps An array record all transform steps.
+   * \param lengths The multiple split factors. Can be None to be filled by search policy.
+   */
+  void ExtractSplitLengths(const Array<Step>& transform_steps,
+                           Array<Optional<Integer>>* lengths) const;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   */
+  Array<Iterator> ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param transform_steps An array record all transform steps.
+   */
+  Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                      const Array<Step>& transform_steps) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param transform_steps An array record all transform steps.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                          const Array<Step>& transform_steps) const;
+
+  static constexpr const char* record_prefix_str = "FSP";
+
+  static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to FollowSplitStepNode.
+ * \sa FollowSplitStepNode
+ */
+class FollowSplitStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be split.
+   * \param iter_id The index of the iterator to be split.
+   * \param src_step_id The index of the split step to follow in the history.
+   * \param n_split The number of split level.
+   */
+  FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit FollowSplitStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode);
+};
+
+/*! \brief Similar to FollowSplitStep, but uses split factors from multiple steps.
+ *  \note This can be used for the split in cooperative fetching.
+ */
+class FollowFusedSplitStepNode : public StepNode {
+ public:
+  /*! \brief The id of the iter to split. */
+  int iter_id;
+  /*! \brief The indices of the split steps to follow in the history. */
+  Array<Integer> src_step_ids;
+  /*! \brief  Use the length in this split level. */
+  int level;
+  /*! \brief If this is true, use factor. Otherwise, use nparts. */
+  bool factor_or_nparts;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Extract split length.
+   * \param transform_steps An array record all transform steps.
+   * \return Split factor.
+   */
+  Optional<Integer> ExtractSplitLength(const Array<Step>& transform_steps) const;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   */
+  Array<Iterator> ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param transform_steps An array record all transform steps.
+   */
+  Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                      const Array<Step>& transform_steps) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.

Review comment:
       ```suggestion
      * \param stages The `te::Stage`s used when applying the TVM schedule.
      * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
   ```

##########
File path: include/tvm/auto_scheduler/transform_step.h
##########
@@ -487,6 +490,164 @@ class SplitStep : public Step {
   TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
 };
 
+/*! \brief Similar to SplitStepNode, but uses split factors from another step
+ * (i.e. Follow another split step) */
+class FollowSplitStepNode : public StepNode {
+ public:
+  /*! \brief The id of the iter to be split. */
+  int iter_id;
+  /*! \brief The index of the split step to follow in the history. */
+  int src_step_id;
+  /*! \brief The number of split level. */
+  int n_split;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Extract split lengths.
+   * \param transform_steps An array record all transform steps.
+   * \param lengths The multiple split factors. Can be None to be filled by search policy.
+   */
+  void ExtractSplitLengths(const Array<Step>& transform_steps,
+                           Array<Optional<Integer>>* lengths) const;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   */
+  Array<Iterator> ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param transform_steps An array record all transform steps.
+   */
+  Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                      const Array<Step>& transform_steps) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param transform_steps An array record all transform steps.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                          const Array<Step>& transform_steps) const;
+
+  static constexpr const char* record_prefix_str = "FSP";
+
+  static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to FollowSplitStepNode.
+ * \sa FollowSplitStepNode
+ */
+class FollowSplitStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be split.
+   * \param iter_id The index of the iterator to be split.
+   * \param src_step_id The index of the split step to follow in the history.
+   * \param n_split The number of split level.
+   */
+  FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit FollowSplitStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode);
+};
+
+/*! \brief Similar to FollowSplitStep, but uses split factors from multiple steps.
+ *  \note This can be used for the split in cooperative fetching.
+ */
+class FollowFusedSplitStepNode : public StepNode {
+ public:
+  /*! \brief The id of the iter to split. */
+  int iter_id;
+  /*! \brief The indices of the split steps to follow in the history. */
+  Array<Integer> src_step_ids;
+  /*! \brief  Use the length in this split level. */
+  int level;
+  /*! \brief If this is true, use factor. Otherwise, use nparts. */
+  bool factor_or_nparts;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Extract split length.
+   * \param transform_steps An array record all transform steps.
+   * \return Split factor.
+   */
+  Optional<Integer> ExtractSplitLength(const Array<Step>& transform_steps) const;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   */
+  Array<Iterator> ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param transform_steps An array record all transform steps.

Review comment:
       ```suggestion
      * \param stages The `te::Stage`s used when applying the TVM schedule.
      * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
      * \param transform_steps An array record all transform steps.
      * \return The iterator results after split.
   ```

##########
File path: include/tvm/auto_scheduler/transform_step.h
##########
@@ -487,6 +490,164 @@ class SplitStep : public Step {
   TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
 };
 
+/*! \brief Similar to SplitStepNode, but uses split factors from another step
+ * (i.e. Follow another split step) */
+class FollowSplitStepNode : public StepNode {
+ public:
+  /*! \brief The id of the iter to be split. */
+  int iter_id;
+  /*! \brief The index of the split step to follow in the history. */
+  int src_step_id;
+  /*! \brief The number of split level. */
+  int n_split;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Extract split lengths.
+   * \param transform_steps An array record all transform steps.
+   * \param lengths The multiple split factors. Can be None to be filled by search policy.
+   */
+  void ExtractSplitLengths(const Array<Step>& transform_steps,
+                           Array<Optional<Integer>>* lengths) const;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   */
+  Array<Iterator> ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param transform_steps An array record all transform steps.
+   */
+  Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                      const Array<Step>& transform_steps) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.

Review comment:
       ```suggestion
      * \param stages The `te::Stage`s used when applying the TVM schedule.
      * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
   ```
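For context on what `PrintAsPythonAPI` emits, here is a minimal Python sketch of printing one split as te-style schedule calls, assuming the split lengths have already been extracted from the followed step and are all known integers. The function name and the generated variable names (`i_o`, `i_0`, ...) are hypothetical; only the `s[stage].split(axis, factor=...)` call form follows TVM's te scheduling API.

```python
from typing import List


def print_split_as_python(stage: str, axis: str, lengths: List[int]) -> str:
    """Hypothetical printer: emit te-style split calls for one split.

    `lengths` are the inner extents listed outer-to-inner. The innermost
    length is split off first, and each later call splits the remaining
    outer part again. Variable names are illustrative only.
    """
    lines = []
    to_split = axis
    for i in range(len(lengths) - 1, -1, -1):
        # The last call produces the final outermost iterator `<axis>_0`;
        # earlier calls produce a temporary outer part `<axis>_o`.
        outer = f"{axis}_0" if i == 0 else f"{axis}_o"
        inner = f"{axis}_{i + 1}"
        lines.append(f"{outer}, {inner} = s[{stage}].split({to_split}, factor={lengths[i]})")
        to_split = outer
    return "\n".join(lines)
```

For lengths [4, 64] on axis `i` of stage `C`, this prints two calls: the inner 64 is split off first, then the remaining outer part is split by 4.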

##########
File path: include/tvm/auto_scheduler/transform_step.h
##########
@@ -487,6 +490,164 @@ class SplitStep : public Step {
   TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
 };
 
+/*! \brief Similar to SplitStepNode, but uses split factors from another step
+ * (i.e. Follow another split step) */
+class FollowSplitStepNode : public StepNode {
+ public:
+  /*! \brief The id of the iter to be split. */
+  int iter_id;
+  /*! \brief The index of the split step to follow in the history. */
+  int src_step_id;
+  /*! \brief The number of split level. */
+  int n_split;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Extract split lengths.
+   * \param transform_steps An array record all transform steps.
+   * \param lengths The multiple split factors. Can be None to be filled by search policy.
+   */
+  void ExtractSplitLengths(const Array<Step>& transform_steps,
+                           Array<Optional<Integer>>* lengths) const;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   */
+  Array<Iterator> ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param transform_steps An array record all transform steps.
+   */
+  Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                      const Array<Step>& transform_steps) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param transform_steps An array record all transform steps.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                          const Array<Step>& transform_steps) const;
+
+  static constexpr const char* record_prefix_str = "FSP";
+
+  static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to FollowSplitStepNode.
+ * \sa FollowSplitStepNode
+ */
+class FollowSplitStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be split.
+   * \param iter_id The index of the iterator to be split.
+   * \param src_step_id The index of the split step to follow in the history.
+   * \param n_split The number of split level.
+   */
+  FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit FollowSplitStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode);
+};
+
+/*! \brief Similar to FollowSplitStep, but uses split factors from multiple steps.
+ *  \note This can be used for the split in cooperative fetching.
+ */
+class FollowFusedSplitStepNode : public StepNode {
+ public:
+  /*! \brief The id of the iter to split. */
+  int iter_id;
+  /*! \brief The indices of the split steps to follow in the history. */
+  Array<Integer> src_step_ids;
+  /*! \brief  Use the length in this split level. */
+  int level;
+  /*! \brief If this is true, use factor. Otherwise, use nparts. */
+  bool factor_or_nparts;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Extract split length.
+   * \param transform_steps An array record all transform steps.
+   * \return Split factor.
+   */
+  Optional<Integer> ExtractSplitLength(const Array<Step>& transform_steps) const;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.

Review comment:
       ```suggestion
      * \param state A mutable pointer to state, which will be updated.
      * \return The iterator results after split.
   ```
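To make "the iterator results after split" concrete for FollowFusedSplitStep, here is a minimal sketch of the resulting extents: the step splits one iterator into two, using the extracted length as the inner factor when `factor_or_nparts` is true and as the number of outer parts otherwise. `split_in_two` is a hypothetical helper, and rounding up for non-divisible extents is an assumption.

```python
from typing import Tuple


def ceil_div(a: int, b: int) -> int:
    """Integer division rounded up."""
    return -(-a // b)


def split_in_two(extent: int, length: int, factor_or_nparts: bool) -> Tuple[int, int]:
    """Hypothetical helper: (outer_extent, inner_extent) after a two-way split.

    factor_or_nparts=True  -> `length` is the inner factor.
    factor_or_nparts=False -> `length` is the number of outer parts (nparts).
    """
    if length <= 0:
        raise ValueError("split length must be positive")
    if factor_or_nparts:
        return ceil_div(extent, length), length
    return length, ceil_div(extent, length)
```

For an iterator of extent 256 and a length of 32, factor mode yields extents (8, 32) and nparts mode yields (32, 8).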





[GitHub] [incubator-tvm] comaniac commented on a change in pull request #6142: [Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps

comaniac commented on a change in pull request #6142:
URL: https://github.com/apache/incubator-tvm/pull/6142#discussion_r461082016



##########
File path: python/tvm/auto_scheduler/loop_state.py
##########
@@ -301,6 +310,89 @@ def split(self, stage, iterator, lengths, inner_to_outer=True):
                                                      iterator, lengths, inner_to_outer)
         return res
 
+    def follow_split(self, stage, iterator, src_step_id, n_split):
+        """ Schedule primitive extends to split step.
+
+        This step is used to follow a former SplitStep, keeps their iterator structures to be same.

Review comment:
       ```suggestion
           This step splits the iterator by the same factors as the given SplitStep.
   ```

##########
File path: python/tvm/auto_scheduler/loop_state.py
##########
@@ -301,6 +310,89 @@ def split(self, stage, iterator, lengths, inner_to_outer=True):
                                                      iterator, lengths, inner_to_outer)
         return res
 
+    def follow_split(self, stage, iterator, src_step_id, n_split):
+        """ Schedule primitive extends to split step.
+
+        This step is used to follow a former SplitStep, keeps their iterator structures to be same.
+
+        Example cases:
+            With subgraph: Dense -> Relu
+            Some tiling structures are used in Relu stage and we intend to compute the Dense
+            stage at Relu.
+            The follow_split is used here to keep their outer most few iterators the same for
+            applying compute at.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be split, which can be specified by the integer index, Operation,
+            or output tensor of the stage.
+        iterator : Iterator
+            The iterator to split.
+        src_step_id : int
+            The index of the split step to follow in the history.
+        n_split : int
+            The number of split level.
+
+        Returns
+        -------
+        res_its : List[Iterator]
+            The splitted new Iterators.
+        """
+
+        self.state_object, res = _ffi_api.StateFollowSplit(self.state_object,
+                                                           self._resolve_stage_id(stage),
+                                                           iterator,
+                                                           src_step_id, n_split)
+        return res
+
+    def follow_fused_split(self, stage, iterator, src_step_ids, level,
+                           factor_or_nparts):
+        """ Schedule primitive extends to split step.
+
+        This step is used to follow several former SplitSteps and FuseSteps.
+
+        Example cases:
+            With subgraph in GPU schedule: Input -> Dense
+            for i.0@j.0 = ... : Bind to blockIdx.x
+                for i.1@j.1 = ... : Bind to threadIdx.x
+                    for i.2@j.2 = ...
+                        Input_shared = Input ...
+                        for k = ...
+                            Dense = ...
+            We intend to apply cooperative fetching with the Input stage, while the threadIdx.x
+            axis is binded to a iterator generated by split & fuse step.
+            The follow_fused_step is used here to figure out the final extent of the threadIdx.x
+            binded iterator.

Review comment:
       ```suggestion
           This step is used to split an iterator by the same factors as the given list of SplitSteps and FuseSteps.
           
           Notes
           -----
               This step is useful in a scenario where we have a subgraph in a GPU schedule: Input -> Dense
               for i.0@j.0 = ... : Bind to blockIdx.x
                   for i.1@j.1 = ... : Bind to threadIdx.x
                       for i.2@j.2 = ...
                           Input_shared = Input ...
                           for k = ...
                               Dense = ...
               We intend to apply cooperative fetching to the Input stage, while the threadIdx.x
               axis is bound to an iterator generated by split and fuse steps.
               The follow_fused_split step is used to split the iterator into 2 parts, where the split factor
               matches the final extent of the threadIdx.x-bound iterator.
   ```
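A minimal sketch of how the split length for follow_fused_split could be derived, assuming it is the product of the lengths at `level` across all followed split steps (and undecided if any of them is undecided). `SplitRecord` and `extract_fused_split_length` are hypothetical stand-ins, not auto_scheduler classes.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SplitRecord:
    """Hypothetical stand-in for a recorded split step: only its lengths."""
    lengths: List[Optional[int]]


def extract_fused_split_length(history: List[object], src_step_ids: List[int],
                               level: int) -> Optional[int]:
    """Multiply the length at `level` across every followed split step.

    Returns None when any of those lengths is still undecided (None).
    """
    result = 1
    for step_id in src_step_ids:
        # Each followed id must point at a split step in the history.
        if not 0 <= step_id < len(history):
            raise IndexError(f"src_step_id {step_id} is out of range")
        step = history[step_id]
        if not isinstance(step, SplitRecord):
            raise TypeError(f"step {step_id} is not a split step")
        length = step.lengths[level]
        if length is None:
            return None
        result *= length
    return result
```

For example, if i and j were split with lengths [4, 2, 8, 4] and [2, 2, 4, 8], and we assume lengths[0] sizes i.1 and j.1 (the outermost i.0 and j.0 take the remainder), then following both steps at level 0 gives 4 * 2 = 8 as the extent of the fused i.1@j.1 iterator.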

##########
File path: python/tvm/auto_scheduler/loop_state.py
##########
@@ -301,6 +310,89 @@ def split(self, stage, iterator, lengths, inner_to_outer=True):
                                                      iterator, lengths, inner_to_outer)
         return res
 
+    def follow_split(self, stage, iterator, src_step_id, n_split):
+        """ Schedule primitive extends to split step.
+
+        This step is used to follow a former SplitStep, keeps their iterator structures to be same.
+
+        Example cases:
+            With subgraph: Dense -> Relu
+            Some tiling structures are used in Relu stage and we intend to compute the Dense
+            stage at Relu.
+            The follow_split is used here to keep their outer most few iterators the same for
+            applying compute at.

Review comment:
       ```suggestion
           Notes
           -----
               This step is useful in a scenario where we have a subgraph Dense -> ReLU
               and want to compute the Dense stage at ReLU. In this case, both stages need
               the same tiling structure for their common outer loops.
               The follow_split step can be used here to split the Dense stage and make sure its
               splitting factors are the same as those of the given split step for the ReLU stage.
   ```
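A small worked sketch of the Notes above, using hypothetical numbers: ReLU's first axis has extent 512 and is split with lengths [4, 2, 8, 4]. Assuming follow_split with n_split = 2 keeps the first length and folds the remaining ones (2 * 8 * 4 = 64) into a single factor, the Dense axis is split with [4, 64]. `split_extents` is a hypothetical helper, not part of the auto_scheduler API.

```python
import math
from typing import List


def split_extents(extent: int, lengths: List[int]) -> List[int]:
    """Loop extents after splitting a loop of `extent` by `lengths`.

    `lengths` are the inner extents listed outer-to-inner; the new outermost
    loop covers the rest, rounded up when the split is not exact.
    """
    inner = math.prod(lengths)
    return [-(-extent // inner)] + list(lengths)


# Hypothetical tiling: ReLU's first axis (extent 512) split with [4, 2, 8, 4].
relu_loops = split_extents(512, [4, 2, 8, 4])
# Following that split with n_split = 2 keeps the first length (4) and folds
# the rest (2 * 8 * 4 = 64) into one factor for the Dense stage.
dense_loops = split_extents(512, [4, 64])
```

Here both stages start with loops of extent 2 and 4, which is the shared outer tiling the Notes say is needed before Dense can be computed at ReLU.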

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -111,6 +115,10 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
     ps->ApplyToState(state);
   } else if (auto ps = step.as<SplitStepNode>()) {
     ps->ApplyToState(state);
+  } else if (auto ps = step.as<FollowSplitStepNode>()) {

Review comment:
       This approach leads to a very long if-statement, since every step type has to be listed here. Per the discussion in the cache_read/write PR, we will refactor this part in the future to eliminate the redundant step dispatching.

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -778,6 +795,184 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
 
 /********** Steps working on multiple stages **********/
 
+/********** Follow Split **********/
+FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) {
+  auto node = make_object<FollowSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_id = src_step_id;
+  node->n_split = n_split;
+  data_ = std::move(node);
+}
+
+void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);

Review comment:
       The order here corresponds to the read order defined in the constructor of this step.
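In the diff, `FollowSplitStepNode::WriteToRecord` writes the prefix "FSP" followed by stage_id, iter_id, src_step_id and n_split, so the reading constructor has to consume the fields in exactly that order. A minimal Python sketch of that symmetry, using the standard `json` module as a stand-in for dmlc's JSON writer and reader; `FollowSplitRecord` is a hypothetical class, not TVM code.

```python
import json
from dataclasses import dataclass
from typing import ClassVar, List


@dataclass
class FollowSplitRecord:
    """Hypothetical mirror of the fields FollowSplitStepNode serializes."""
    PREFIX: ClassVar[str] = "FSP"
    stage_id: int
    iter_id: int
    src_step_id: int
    n_split: int

    def to_record(self) -> list:
        # Write order: prefix, stage_id, iter_id, src_step_id, n_split.
        return [self.PREFIX, self.stage_id, self.iter_id, self.src_step_id, self.n_split]

    @classmethod
    def from_record(cls, record: List) -> "FollowSplitRecord":
        # Read fields back in exactly the order they were written.
        prefix, stage_id, iter_id, src_step_id, n_split = record
        if prefix != cls.PREFIX:
            raise ValueError(f"expected prefix {cls.PREFIX!r}, got {prefix!r}")
        return cls(stage_id, iter_id, src_step_id, n_split)
```

Since all four fields are plain integers, swapping two of them on only one side would still parse but silently assign the wrong ids, which is why the orders must match.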

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -778,6 +795,184 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
 
 /********** Steps working on multiple stages **********/
 
+/********** Follow Split **********/
+FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) {
+  auto node = make_object<FollowSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_id = src_step_id;
+  node->n_split = n_split;
+  data_ = std::move(node);
+}
+
+void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(src_step_id);
+  writer->WriteArrayItem(n_split);
+}
+
+void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps,
+                                              Array<Optional<Integer>>* lengths) const {
+  CHECK_LT(src_step_id, transform_steps.size());
+  auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+  CHECK(ps != nullptr);
+
+  // get lengths from src step
+  lengths->reserve(n_split);
+  int j = 0;
+  for (; j < n_split - 1; ++j) {
+    lengths->push_back(ps->lengths[j]);

Review comment:
       1. Need to make sure the size of `ps->lengths` is at least `n_split - 1`.
       2. Add comments for this loop and the following loop.
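Since the reviewer asks for a bounds check and comments on both loops, here is a minimal Python sketch of what `ExtractSplitLengths` might look like with both. It is hypothetical: the function takes the followed step's lengths directly instead of looking them up in `transform_steps`, and because the quoted diff cuts off right after `PrimExpr last_factor = 1;`, the second loop is assumed to multiply the remaining lengths into one last factor (staying undecided if any of them is undecided).

```python
from typing import List, Optional


def extract_split_lengths(src_lengths: List[Optional[int]],
                          n_split: int) -> List[Optional[int]]:
    """Hypothetical sketch: derive n_split lengths from a followed split step."""
    # Bounds check: the first loop reads src_lengths[0 .. n_split - 2],
    # so the followed step needs at least n_split - 1 lengths.
    if n_split < 1 or len(src_lengths) < n_split - 1:
        raise ValueError(f"n_split={n_split} needs at least {n_split - 1} source lengths, "
                         f"got {len(src_lengths)}")
    lengths: List[Optional[int]] = []
    # Loop 1: copy the outer n_split - 1 lengths unchanged.
    for j in range(n_split - 1):
        lengths.append(src_lengths[j])
    # Loop 2 (assumed): fold the remaining lengths into one last factor.
    # If any of them is still undecided (None), the last factor is too.
    last: Optional[int] = 1
    for length in src_lengths[n_split - 1:]:
        if length is None:
            last = None
            break
        last *= length
    lengths.append(last)
    return lengths
```

For source lengths [4, 2, 8, 4], n_split = 2 gives [4, 64] and n_split = 5 gives [4, 2, 8, 4, 1].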

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -778,6 +795,184 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
 
 /********** Steps working on multiple stages **********/
 
+/********** Follow Split **********/
+FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) {
+  auto node = make_object<FollowSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_id = src_step_id;
+  node->n_split = n_split;
+  data_ = std::move(node);
+}
+
+void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(src_step_id);
+  writer->WriteArrayItem(n_split);
+}
+
+void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps,
+                                              Array<Optional<Integer>>* lengths) const {
+  CHECK_LT(src_step_id, transform_steps.size());
+  auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+  CHECK(ps != nullptr);
+
+  // get lengths from src step
+  lengths->reserve(n_split);
+  int j = 0;
+  for (; j < n_split - 1; ++j) {
+    lengths->push_back(ps->lengths[j]);
+  }
+  PrimExpr last_factor = 1;

Review comment:
       Same here. IMO, the following loop is not required.

##########
File path: include/tvm/auto_scheduler/transform_step.h
##########
@@ -489,6 +492,166 @@ class SplitStep : public Step {
 
 /********** Steps working on multiple stages **********/
 
+/*! \brief Similar to SplitStepNode, but uses split factors from another step
+ * (i.e. Follow another split step) */
+class FollowSplitStepNode : public StepNode {
+ public:
+  /*! \brief The id of the iter to be split. */
+  int iter_id;
+  /*! \brief The index of the split step to follow in the history. */
+  int src_step_id;
+  /*! \brief The number of split level. */
+  int n_split;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Extract split lengths.
+   * \param transform_steps An array record all transform steps.
+   * \param lengths The multiple split factors. Can be None to be filled by search policy.
+   */
+  void ExtractSplitLengths(const Array<Step>& transform_steps,
+                           Array<Optional<Integer>>* lengths) const;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   */
+  Array<Iterator> ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param transform_steps An array record all transform steps.
+   */
+  Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                      const Array<Step>& transform_steps) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param transform_steps An array record all transform steps.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                          const Array<Step>& transform_steps) const;
+
+  static constexpr const char* record_prefix_str = "FSP";
+
+  static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to FollowSplitStepNode.
+ * \sa FollowSplitStepNode
+ */
+class FollowSplitStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be split.
+   * \param iter_id The index of the iterator to be split.
+   * \param src_step_id The index of the split step to follow in the history.
+   * \param n_split The number of split level.
+   */
+  FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit FollowSplitStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode);
+};
+
+/*! \brief Similar to FollowSplitStep, but use split factors from multiple steps.
+ *  \note This can be used for the split in cooperative fetching

Review comment:
       ```suggestion
   /*! \brief Similar to FollowSplitStep, but uses split factors from multiple steps.
    *  \note This can be used for the split in cooperative fetching.
   ```

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -778,6 +795,184 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
 
 /********** Steps working on multiple stages **********/
 
+/********** Follow Split **********/
+FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) {
+  auto node = make_object<FollowSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_id = src_step_id;
+  node->n_split = n_split;
+  data_ = std::move(node);
+}
+
+void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(src_step_id);
+  writer->WriteArrayItem(n_split);
+}
+
+void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps,
+                                              Array<Optional<Integer>>* lengths) const {
+  CHECK_LT(src_step_id, transform_steps.size());
+  auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+  CHECK(ps != nullptr);

Review comment:
       Add comment.

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -778,6 +795,184 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
 
 /********** Steps working on multiple stages **********/
 
+/********** Follow Split **********/
+FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) {
+  auto node = make_object<FollowSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_id = src_step_id;
+  node->n_split = n_split;
+  data_ = std::move(node);
+}
+
+void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(src_step_id);
+  writer->WriteArrayItem(n_split);
+}
+
+void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps,
+                                              Array<Optional<Integer>>* lengths) const {
+  CHECK_LT(src_step_id, transform_steps.size());
+  auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+  CHECK(ps != nullptr);
+
+  // get lengths from src step
+  lengths->reserve(n_split);
+  int j = 0;
+  for (; j < n_split - 1; ++j) {
+    lengths->push_back(ps->lengths[j]);
+  }
+  PrimExpr last_factor = 1;
+  for (; j < static_cast<int>(ps->lengths.size()); ++j) {
+    if (ps->lengths[j]) {
+      last_factor *= ps->lengths[j].value();
+    } else {
+      last_factor = PrimExpr();
+      break;
+    }
+  }
+  if (last_factor.defined()) {
+    lengths->push_back(Downcast<Integer>(last_factor));
+  } else {
+    lengths->push_back(NullOpt);
+  }
+}
+
+FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) {
+  auto node = make_object<FollowSplitStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->src_step_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->n_split);
+
+  data_ = std::move(node);
+}
+
+Array<Iterator> FollowSplitStepNode::ApplyToState(State* state) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths((*state)->transform_steps, &lengths);
+  return ApplySplitToState(state, stage_id, iter_id, lengths, true);
+}
+
+Array<IterVar> FollowSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                                    StageToAxesMap* stage_to_axes,
+                                                    const Array<Step>& transform_steps) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+}
+
+String FollowSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                             StageToAxesMap* stage_to_axes,
+                                             const Array<Step>& transform_steps) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+}
+
+/********** Follow Fused Split **********/
+FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id,
+                                           const Array<Integer>& src_step_ids, int level,
+                                           bool factor_or_nparts) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_ids = src_step_ids;
+  node->level = level;
+  node->factor_or_nparts = factor_or_nparts;
+  data_ = std::move(node);
+}
+
+FollowFusedSplitStep::FollowFusedSplitStep(dmlc::JSONReader* reader) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::vector<int> int_list;
+  reader->Read(&int_list);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->level);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->factor_or_nparts);
+
+  ::tvm::Array<::tvm::Integer> src_step_ids;
+  for (const auto& i : int_list) {
+    src_step_ids.push_back(i);
+  }
+  node->src_step_ids = src_step_ids;
+  data_ = std::move(node);
+}
+
+void FollowFusedSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(IntArrayToVector(src_step_ids));
+  writer->WriteArrayItem(level);
+  writer->WriteArrayItem(static_cast<int>(factor_or_nparts));
+}
+
+Optional<Integer> FollowFusedSplitStepNode::ExtractSplitLength(
+    const Array<Step>& transform_steps) const {
+  PrimExpr ret(1);
+
+  for (int src_step_id : src_step_ids) {
+    CHECK_LT(src_step_id, transform_steps.size());
+    auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+    CHECK(ps != nullptr);

Review comment:
       Add comments.
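For reference, a commented sketch of what `FollowFusedSplitStepNode::ExtractSplitLength` computes: the split length at `level` of every source step, multiplied together, or no value as soon as one of them is undefined. Plain `std::vector`/`std::optional` replace TVM's `Array`/`Optional`, `FusedSplitLength` is a hypothetical name, and the `level < lengths.size()` guard is an extra check that is not in the diff.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical stand-in for FollowFusedSplitStepNode::ExtractSplitLength.
// Each inner vector plays the role of one source SplitStep's `lengths`.
std::optional<int> FusedSplitLength(
    const std::vector<std::vector<std::optional<int>>>& src_steps, std::size_t level) {
  int product = 1;
  for (const auto& lengths : src_steps) {
    // Extra guard (not in the diff): every followed step needs a factor at `level`.
    assert(level < lengths.size());
    // An undefined factor in any source step makes the fused length undefined.
    if (!lengths[level]) return std::nullopt;
    // The fused length is the product of the per-step factors at this level.
    product *= *lengths[level];
  }
  return product;
}
```

With two source steps whose lengths are `[2, 4]` and `[8, 1]`, level `0` gives `2 * 8 = 16`.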

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -778,6 +795,184 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
 
 /********** Steps working on multiple stages **********/
 
+/********** Follow Split **********/
+FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) {
+  auto node = make_object<FollowSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_id = src_step_id;
+  node->n_split = n_split;
+  data_ = std::move(node);
+}
+
+void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(src_step_id);
+  writer->WriteArrayItem(n_split);
+}
+
+void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps,
+                                              Array<Optional<Integer>>* lengths) const {
+  CHECK_LT(src_step_id, transform_steps.size());
+  auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+  CHECK(ps != nullptr);
+
+  // get lengths from src step
+  lengths->reserve(n_split);
+  int j = 0;
+  for (; j < n_split - 1; ++j) {
+    lengths->push_back(ps->lengths[j]);
+  }
+  PrimExpr last_factor = 1;
+  for (; j < static_cast<int>(ps->lengths.size()); ++j) {
+    if (ps->lengths[j]) {
+      last_factor *= ps->lengths[j].value();
+    } else {
+      last_factor = PrimExpr();
+      break;
+    }
+  }
+  if (last_factor.defined()) {
+    lengths->push_back(Downcast<Integer>(last_factor));
+  } else {
+    lengths->push_back(NullOpt);
+  }
+}
+
+FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) {
+  auto node = make_object<FollowSplitStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->src_step_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->n_split);
+
+  data_ = std::move(node);
+}
+
+Array<Iterator> FollowSplitStepNode::ApplyToState(State* state) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths((*state)->transform_steps, &lengths);
+  return ApplySplitToState(state, stage_id, iter_id, lengths, true);
+}
+
+Array<IterVar> FollowSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                                    StageToAxesMap* stage_to_axes,
+                                                    const Array<Step>& transform_steps) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+}
+
+String FollowSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                             StageToAxesMap* stage_to_axes,
+                                             const Array<Step>& transform_steps) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+}
+
+/********** Follow Fused Split **********/
+FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id,
+                                           const Array<Integer>& src_step_ids, int level,
+                                           bool factor_or_nparts) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_ids = src_step_ids;
+  node->level = level;
+  node->factor_or_nparts = factor_or_nparts;
+  data_ = std::move(node);
+}
+
+FollowFusedSplitStep::FollowFusedSplitStep(dmlc::JSONReader* reader) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);

Review comment:
       This is the decision we made in the last PR. Since the (de)serialization logic differs a lot from step to step, keeping it self-contained in each step is easier to maintain: we can trace the logic simply by looking at the step's class definition instead of searching around the code base.
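To illustrate the self-contained pattern, here is a minimal sketch in which one step type owns both its writer and its reader, so the record layout (`FSP`, then `stage_id`, `iter_id`, `src_step_id`, `n_split`, the same order as `WriteToRecord` in the diff) is defined in one place. It writes a simplified comma-separated string rather than the JSON array produced by `dmlc::JSONWriter`; `FollowSplitRecord`, `Write` and `Read` are hypothetical names.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Hypothetical, simplified stand-in for FollowSplitStepNode's record fields.
struct FollowSplitRecord {
  int stage_id = 0;
  int iter_id = 0;
  int src_step_id = 0;
  int n_split = 0;

  static constexpr const char* record_prefix_str = "FSP";

  // Writer: the prefix followed by the fields in a fixed order.
  std::string Write() const {
    std::ostringstream os;
    os << record_prefix_str << ',' << stage_id << ',' << iter_id << ',' << src_step_id << ','
       << n_split;
    return os.str();
  }

  // Reader: parses the same fields in the same order, right next to Write(),
  // so a change to the layout only touches this one type.
  static FollowSplitRecord Read(const std::string& record) {
    std::istringstream is(record);
    std::string prefix;
    std::getline(is, prefix, ',');
    assert(prefix == record_prefix_str);
    FollowSplitRecord r;
    char sep = 0;
    is >> r.stage_id >> sep >> r.iter_id >> sep >> r.src_step_id >> sep >> r.n_split;
    assert(!is.fail());
    return r;
  }
};
```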







[GitHub] [incubator-tvm] jiuqi-yang commented on a change in pull request #6142: [Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps

Posted by GitBox <gi...@apache.org>.
jiuqi-yang commented on a change in pull request #6142:
URL: https://github.com/apache/incubator-tvm/pull/6142#discussion_r461279732



##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -778,6 +795,184 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
 
 /********** Steps working on multiple stages **********/
 
+/********** Follow Split **********/
+FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) {
+  auto node = make_object<FollowSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_id = src_step_id;
+  node->n_split = n_split;
+  data_ = std::move(node);
+}
+
+void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(src_step_id);
+  writer->WriteArrayItem(n_split);
+}
+
+void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps,
+                                              Array<Optional<Integer>>* lengths) const {
+  CHECK_LT(src_step_id, transform_steps.size());
+  auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+  CHECK(ps != nullptr);
+
+  // get lengths from src step
+  lengths->reserve(n_split);
+  int j = 0;
+  for (; j < n_split - 1; ++j) {
+    lengths->push_back(ps->lengths[j]);
+  }
+  PrimExpr last_factor = 1;

Review comment:
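       I think the following loop, which computes the last split factor for the follow_split step, is required. Otherwise we would miss a factor: the first loop only copies `n_split - 1` lengths, and this loop folds all remaining source lengths into the last one.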
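A small worked example of the missed factor, assuming a source split step with lengths `[4, 8, 2]` and `n_split = 2`: the first loop copies `4`, and the trailing loop folds `8 * 2` into `16`, giving `[4, 16]`. Without the trailing loop only `[4]` remains, so the followed split has one level fewer and the inner extent is lost. The sketch uses plain integers (all factors assumed defined) and hypothetical function names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// With the trailing loop: copy n_split - 1 factors, then fold the rest
// into one last factor (mirrors the two loops in ExtractSplitLengths).
std::vector<int> FollowWithTail(const std::vector<int>& src, int n_split) {
  assert(n_split >= 1 && static_cast<int>(src.size()) >= n_split - 1);
  std::vector<int> out(src.begin(), src.begin() + (n_split - 1));
  int last = 1;
  for (std::size_t j = static_cast<std::size_t>(n_split - 1); j < src.size(); ++j) {
    last *= src[j];
  }
  out.push_back(last);
  return out;
}

// Without the trailing loop: only the first n_split - 1 factors survive.
std::vector<int> FollowWithoutTail(const std::vector<int>& src, int n_split) {
  assert(n_split >= 1 && static_cast<int>(src.size()) >= n_split - 1);
  return std::vector<int>(src.begin(), src.begin() + (n_split - 1));
}
```

With the tail, the product of the followed factors (`4 * 16 = 64`) still equals the product of the source factors (`4 * 8 * 2`).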







[GitHub] [incubator-tvm] jcf94 commented on a change in pull request #6142: [Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps

Posted by GitBox <gi...@apache.org>.
jcf94 commented on a change in pull request #6142:
URL: https://github.com/apache/incubator-tvm/pull/6142#discussion_r460685681



##########
File path: include/tvm/auto_scheduler/loop_state.h
##########
@@ -356,9 +356,30 @@ class State : public ObjectRef {
    * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner
    * most iterator of split results will become the new attach point.
    */
-  TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
-                                const Array<Optional<Integer>>& lengths,
-                                bool inner_to_outer = true);
+  Array<Iterator> split(int stage_id, const Iterator& it, const Array<Optional<Integer>>& lengths,
+                        bool inner_to_outer = true);

Review comment:
       Restore the `split` declaration to the original one (including the `TVM_DLL` hint).

##########
File path: include/tvm/auto_scheduler/loop_state.h
##########
@@ -356,9 +356,30 @@ class State : public ObjectRef {
    * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner
    * most iterator of split results will become the new attach point.
    */
-  TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
-                                const Array<Optional<Integer>>& lengths,
-                                bool inner_to_outer = true);
+  Array<Iterator> split(int stage_id, const Iterator& it, const Array<Optional<Integer>>& lengths,
+                        bool inner_to_outer = true);
+  /*!
+   * \brief Schedule primitive extends to split step.
+   * \param stage_id The index of the stage to be split.
+   * \param it The iterator to be split.
+   * \param src_step_id The index of the split step to follow in the history.
+   * \param n_split The number of split level.
+   * \return The splitted new Iterators.
+   */
+  Array<Iterator> follow_split(int stage_id, const Iterator& it, int src_step_id, int n_split);

Review comment:
       Add the `TVM_DLL` hint before this function; see the other functions for reference.
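For context, `TVM_DLL` is TVM's symbol-export macro; roughly, it expands to `__declspec(dllexport)` or `__declspec(dllimport)` on Windows and to `__attribute__((visibility("default")))` elsewhere, so a method such as `State::follow_split` stays exported even if the library is built with hidden default visibility. A self-contained sketch of the same declaration pattern, using a hypothetical `SKETCH_DLL` macro and `SketchState` class (simplified to export-only on Windows so it compiles in a single file):

```cpp
#include <cassert>

// Hypothetical export macro following the same pattern as TVM_DLL.
#if defined(_WIN32)
#define SKETCH_DLL __declspec(dllexport)
#else
#define SKETCH_DLL __attribute__((visibility("default")))
#endif

class SketchState {
 public:
  // The export hint goes in front of the declaration, like the other
  // TVM_DLL-annotated State methods.
  SKETCH_DLL int follow_split(int n_split);
};

// Placeholder body: a split with n_split factors yields n_split + 1 iterators.
int SketchState::follow_split(int n_split) { return n_split + 1; }
```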

##########
File path: src/auto_scheduler/loop_state.cc
##########
@@ -454,6 +475,25 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit")
       return Array<ObjectRef>{state, res};
     });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit")
+    .set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id,
+                       int n_split) {
+      const auto& res = state.follow_split(stage_id, it, src_step_id, n_split);
+      return Array<ObjectRef>{state, Array<Iterator>(res)};
+    });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit")
+    .set_body_typed([](State state, int stage_id, const Iterator& it,
+                       const Array<IntImm>& src_step_ids, int level, bool factor_or_nparts) {

Review comment:
       This seems to be a historical issue from our old code base. We can use `Array<Integer>` directly as the function parameter and remove the `for ... copy` below.

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -778,6 +795,184 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
 
 /********** Steps working on multiple stages **********/
 
+/********** Follow Split **********/
+FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) {
+  auto node = make_object<FollowSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_id = src_step_id;
+  node->n_split = n_split;
+  data_ = std::move(node);
+}
+
+void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(src_step_id);
+  writer->WriteArrayItem(n_split);
+}
+
+void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps,
+                                              Array<Optional<Integer>>* lengths) const {
+  CHECK_LT(src_step_id, transform_steps.size());
+  auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+  CHECK(ps != nullptr);
+
+  // get lengths from src step
+  lengths->reserve(n_split);
+  int j = 0;
+  for (; j < n_split - 1; ++j) {
+    lengths->push_back(ps->lengths[j]);
+  }
+  PrimExpr last_factor = 1;
+  for (; j < static_cast<int>(ps->lengths.size()); ++j) {
+    if (ps->lengths[j]) {
+      last_factor *= ps->lengths[j].value();
+    } else {
+      last_factor = PrimExpr();
+      break;
+    }
+  }
+  if (last_factor.defined()) {
+    lengths->push_back(Downcast<Integer>(last_factor));
+  } else {
+    lengths->push_back(NullOpt);
+  }
+}
+
+FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) {
+  auto node = make_object<FollowSplitStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->src_step_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->n_split);
+
+  data_ = std::move(node);
+}
+
+Array<Iterator> FollowSplitStepNode::ApplyToState(State* state) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths((*state)->transform_steps, &lengths);
+  return ApplySplitToState(state, stage_id, iter_id, lengths, true);
+}
+
+Array<IterVar> FollowSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                                    StageToAxesMap* stage_to_axes,
+                                                    const Array<Step>& transform_steps) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+}
+
+String FollowSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                             StageToAxesMap* stage_to_axes,
+                                             const Array<Step>& transform_steps) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+}
+
+/********** Follow Fused Split **********/
+FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id,
+                                           const Array<Integer>& src_step_ids, int level,
+                                           bool factor_or_nparts) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_ids = src_step_ids;
+  node->level = level;
+  node->factor_or_nparts = factor_or_nparts;
+  data_ = std::move(node);
+}
+
+FollowFusedSplitStep::FollowFusedSplitStep(dmlc::JSONReader* reader) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::vector<int> int_list;
+  reader->Read(&int_list);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->level);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->factor_or_nparts);
+
+  ::tvm::Array<::tvm::Integer> src_step_ids;
+  for (const auto& i : int_list) {
+    src_step_ids.push_back(i);
+  }
+  node->src_step_ids = src_step_ids;
+  data_ = std::move(node);
+}
+
+void FollowFusedSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(IntArrayToVector(src_step_ids));
+  writer->WriteArrayItem(level);
+  writer->WriteArrayItem(static_cast<int>(factor_or_nparts));
+}
+
+Optional<Integer> FollowFusedSplitStepNode::ExtractSplitLength(
+    const Array<Step>& transform_steps) const {
+  PrimExpr ret(1);
+
+  for (int src_step_id : src_step_ids) {
+    CHECK_LT(src_step_id, transform_steps.size());
+    auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+    CHECK(ps != nullptr);
+    if (ps->lengths[level] && ret.defined()) {
+      ret *= ps->lengths[level].value();
+    } else {
+      return NullOpt;
+    }
+  }
+  return Downcast<Integer>(ret);
+}
+
+Array<Iterator> FollowFusedSplitStepNode::ApplyToState(State* state) const {
+  const Optional<Integer>& length = ExtractSplitLength((*state)->transform_steps);
+  return ApplySplitToState(state, stage_id, iter_id, {length}, factor_or_nparts);
+}
+
+Array<IterVar> FollowFusedSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                                         StageToAxesMap* stage_to_axes,
+                                                         const Array<Step>& transform_steps) const {
+  const Optional<Integer>& length = ExtractSplitLength(transform_steps);
+  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, {length}, factor_or_nparts);
+}
+
+String FollowFusedSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                                  StageToAxesMap* stage_to_axes,
+                                                  const Array<Step>& transform_steps) const {
+  const Optional<Integer>& length = ExtractSplitLength(transform_steps);
+  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, {length},
+                               factor_or_nparts);
+}
+
+/********** Primitives working on multiple stages **********/

Review comment:
       Remove this line, and move line 796 (`Steps working on multiple stages`) here.

##########
File path: src/auto_scheduler/loop_state.cc
##########
@@ -268,6 +268,27 @@ Array<Iterator> State::split(int stage_id, const Iterator& it,
   return step->ApplyToState(this);
 }
 
+Array<Iterator> State::follow_split(int stage_id, const Iterator& it, int src_step_id,
+                                    int n_split) {
+  const Stage& stage = operator->()->stages[stage_id];
+
+  FollowSplitStep step =
+      FollowSplitStep(stage_id, GetIndex(stage->iters, it), src_step_id, n_split);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return step->ApplyToState(this);
+}
+
+Array<Iterator> State::follow_fused_split(int stage_id, const Iterator& it,
+                                          const Array<Integer>& src_step_ids, int level,
+                                          bool factor_or_nparts) {
+  const Stage& stage = operator->()->stages[stage_id];
+

Review comment:
       ditto

##########
File path: src/auto_scheduler/loop_state.cc
##########
@@ -268,6 +268,27 @@ Array<Iterator> State::split(int stage_id, const Iterator& it,
   return step->ApplyToState(this);
 }
 
+Array<Iterator> State::follow_split(int stage_id, const Iterator& it, int src_step_id,
+                                    int n_split) {
+  const Stage& stage = operator->()->stages[stage_id];
+

Review comment:
       Remove this blank line.

##########
File path: include/tvm/auto_scheduler/transform_step.h
##########
@@ -489,6 +490,166 @@ class SplitStep : public Step {
 
 /********** Steps working on multiple stages **********/
 
+/*! \brief Similar to SplitStepNode, but uses split factors from another step
+ * (i.e. Follow another split step) */
+class FollowSplitStepNode : public StepNode {
+ public:
+  /*! \brief The id of the iter to be split. */
+  int iter_id;
+  /*! \brief The index of the split step to follow in the history. */
+  int src_step_id;
+  /*! \brief The number of split level. */
+  int n_split;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Extract split lengths.
+   * \param transform_steps An array record all transform steps.
+   * \param lengths The multiple split factors. Can be None to be filled by search policy.
+   */
+  void ExtractSplitLengths(const Array<Step>& transform_steps,
+                           Array<Optional<Integer>>* lengths) const;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   */
+  Array<Iterator> ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param transform_steps An array record all transform steps.
+   */
+  Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                      const Array<Step>& transform_steps) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param transform_steps An array record all transform steps.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                          const Array<Step>& transform_steps) const;
+
+  static constexpr const char* record_prefix_str = "FSP";
+
+  static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to FollowSplitStepNode.
+ * \sa FollowSplitStepNode
+ */
+class FollowSplitStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be split.
+   * \param iter_id The index of the iterator to be split.
+   * \param src_step_id The index of the split step to follow in the history.
+   * \param n_split The number of split level.
+   */
+  FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit FollowSplitStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode);
+};
+
+/*! \brief Similar to FollowSplitStep, but uses split factors from multiple steps.
+ *  \note This can be used for the split in cooperative fetching.
+ */
+class FollowFusedSplitStepNode : public StepNode {
+ public:
+  /*! \brief The id of the iter to split. */
+  int iter_id;
+  /*! \brief The indices of the split steps to follow in the history. */
+  Array<Integer> src_step_ids;
+  /*! \brief Use the length in this split level. */
+  int level;
+  /*! \brief If this is true, use factor. Otherwise, use nparts. */
+  bool factor_or_nparts;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Extract split length.
+   * \param transform_steps An array recording all transform steps.
+   * \return Split factor.
+   */
+  Optional<Integer> ExtractSplitLength(const Array<Step>& transform_steps) const;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   */
+  Array<Iterator> ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param transform_steps An array recording all transform steps.
+   */
+  Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                      const Array<Step>& transform_steps) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param transform_steps An array recording all transform steps.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                          const Array<Step>& transform_steps) const;
+
+  static constexpr const char* record_prefix_str = "FFSP";
+
+  static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to FollowFusedSplitStepNode.
+ * \sa FollowFusedSplitStepNode
+ */
+class FollowFusedSplitStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be split.
+   * \param iter_id The index of the iterator to be split.
+   * \param src_step_ids An array of indices of the split steps to follow in the history.
+   * \param level Use the length in this split level.
+   * \param factor_or_nparts If this is true, use factor. Otherwise, use nparts.
+   */
+  FollowFusedSplitStep(int stage_id, int iter_id, const Array<Integer>& src_step_ids, int level,
+                       bool factor_or_nparts);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit FollowFusedSplitStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode);
+};
+
+/********** Primitives working on multiple stages **********/

Review comment:
       Remove this line, and move line 491 (`Steps working on multiple stages`) here.

##########
File path: tests/python/unittest/test_auto_scheduler_measure.py
##########
@@ -22,13 +22,12 @@
 from tvm import te, auto_scheduler
 import tempfile
 
-from test_auto_scheduler_common import get_tiled_matmul
-
+from test_auto_scheduler_common import matmul_auto_scheduler_test, get_tiled_matmul
 
 def test_record():
     if not tvm.runtime.enabled("llvm"):
         return
-
+    #pdb.set_trace()

Review comment:
       Remove the debug code.

##########
File path: tests/python/unittest/test_auto_scheduler_loop_state.py
##########
@@ -142,6 +142,318 @@ def test_compute_at_root_inline():
     assert s0[conv].iters[5].range.extent == 7
     assert s0[conv].iters[6].range.extent == 7
 
+def test_follow_split_follow_fused_split():

Review comment:
       Put this function after `test_cache_read_write`.
   Please check carefully: there seem to be two `test_cache_read_write` functions in this file. This should not show up in the git diff.







[GitHub] [incubator-tvm] merrymercy commented on a change in pull request #6142: [Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps

Posted by GitBox <gi...@apache.org>.
merrymercy commented on a change in pull request #6142:
URL: https://github.com/apache/incubator-tvm/pull/6142#discussion_r461125355



##########
File path: include/tvm/auto_scheduler/loop_state.h
##########
@@ -359,6 +359,29 @@ class State : public ObjectRef {
   TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
                                 const Array<Optional<Integer>>& lengths,
                                 bool inner_to_outer = true);
+  /*!
+   * \brief Schedule primitive that extends the split step.
+   * \param stage_id The index of the stage to be split.
+   * \param it The iterator to be split.
+   * \param src_step_id The index of the split step to follow in the history.

Review comment:
       ```suggestion
      * \param src_step_id The index of the split step to be followed in the history.
   ```

##########
File path: python/tvm/auto_scheduler/loop_state.py
##########
@@ -301,6 +310,89 @@ def split(self, stage, iterator, lengths, inner_to_outer=True):
                                                      iterator, lengths, inner_to_outer)
         return res
 
+    def follow_split(self, stage, iterator, src_step_id, n_split):
+        """ Schedule primitive that extends the split step.
+
+        This step is used to follow a former SplitStep, keeps their iterator structures to be same.

Review comment:
       ```suggestion
           This step is used to follow a former SplitStep and keep their iterator structures to be the same.
   ```

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -778,6 +795,184 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
 
 /********** Steps working on multiple stages **********/
 
+/********** Follow Split **********/
+FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) {
+  auto node = make_object<FollowSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_id = src_step_id;
+  node->n_split = n_split;
+  data_ = std::move(node);
+}
+
+void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(src_step_id);
+  writer->WriteArrayItem(n_split);
+}
+
+void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps,
+                                              Array<Optional<Integer>>* lengths) const {
+  CHECK_LT(src_step_id, transform_steps.size());
+  auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+  CHECK(ps != nullptr);
+
+  // get lengths from src step
+  lengths->reserve(n_split);
+  int j = 0;
+  for (; j < n_split - 1; ++j) {
+    lengths->push_back(ps->lengths[j]);
+  }
+  PrimExpr last_factor = 1;
+  for (; j < static_cast<int>(ps->lengths.size()); ++j) {
+    if (ps->lengths[j]) {
+      last_factor *= ps->lengths[j].value();
+    } else {
+      last_factor = PrimExpr();
+      break;
+    }
+  }
+  if (last_factor.defined()) {
+    lengths->push_back(Downcast<Integer>(last_factor));
+  } else {
+    lengths->push_back(NullOpt);
+  }
+}
+
+FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) {
+  auto node = make_object<FollowSplitStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->src_step_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->n_split);
+
+  data_ = std::move(node);
+}
+
+Array<Iterator> FollowSplitStepNode::ApplyToState(State* state) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths((*state)->transform_steps, &lengths);
+  return ApplySplitToState(state, stage_id, iter_id, lengths, true);
+}
+
+Array<IterVar> FollowSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                                    StageToAxesMap* stage_to_axes,
+                                                    const Array<Step>& transform_steps) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+}
+
+String FollowSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                             StageToAxesMap* stage_to_axes,
+                                             const Array<Step>& transform_steps) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+}
+
+/********** Follow Fused Split **********/
+FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id,
+                                           const Array<Integer>& src_step_ids, int level,
+                                           bool factor_or_nparts) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_ids = src_step_ids;
+  node->level = level;
+  node->factor_or_nparts = factor_or_nparts;
+  data_ = std::move(node);
+}
+
+FollowFusedSplitStep::FollowFusedSplitStep(dmlc::JSONReader* reader) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);

Review comment:
       Yes, it depends on the order. This is a limitation of TVM's JSON serialization support.
   However, as you suggested, the only two functions that rely on this order (`FollowFusedSplitStep::FollowFusedSplitSte` and `FollowFusedSplitStepNode::WriteToRecord`) are already located next to each other in this file.
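   The order dependence discussed above is easy to see in a minimal Python sketch. It uses the `FSP` field order from `FollowSplitStepNode::WriteToRecord` and the `FollowSplitStep(dmlc::JSONReader*)` constructor in the diff. Plain dicts and the standard `json` module stand in for the C++ classes and `dmlc::JSONWriter`. The helper names are hypothetical and are not part of the auto_scheduler API.

```python
import json

# Field order taken from FollowSplitStepNode::WriteToRecord in the diff.
# A record is a positional JSON array with no field names, so the reader
# must consume the fields in exactly this order.
FSP_FIELDS = ("stage_id", "iter_id", "src_step_id", "n_split")


def write_follow_split_record(step):
    """Serialize a follow-split step (a plain dict here) to a positional list."""
    return ["FSP"] + [step[name] for name in FSP_FIELDS]


def read_follow_split_record(record):
    """Rebuild the step dict, reading the fields in the writer's order."""
    prefix, *values = record
    if prefix != "FSP":
        raise ValueError("unexpected record prefix")
    if len(values) != len(FSP_FIELDS):
        raise ValueError("record has the wrong number of fields")
    return dict(zip(FSP_FIELDS, values))


step = {"stage_id": 2, "iter_id": 0, "src_step_id": 5, "n_split": 2}
text = json.dumps(write_follow_split_record(step))
print(text)  # ["FSP", 2, 0, 5, 2]

# Swapping two fields still parses, but silently yields a different step.
swapped = ["FSP", 2, 0, 2, 5]
print(read_follow_split_record(swapped) == step)  # False
```

   Nothing in the record itself catches a reordering. The only safeguard is keeping the writer and reader in sync, which is why placing them next to each other helps.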

##########
File path: include/tvm/auto_scheduler/loop_state.h
##########
@@ -359,6 +359,29 @@ class State : public ObjectRef {
   TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
                                 const Array<Optional<Integer>>& lengths,
                                 bool inner_to_outer = true);
+  /*!
+   * \brief Schedule primitive that extends the split step.
+   * \param stage_id The index of the stage to be split.
+   * \param it The iterator to be split.
+   * \param src_step_id The index of the split step to follow in the history.
+   * \param n_split The number of split levels.
+   * \return The new Iterators produced by the split.
+   */
+  TVM_DLL Array<Iterator> follow_split(int stage_id, const Iterator& it, int src_step_id,
+                                       int n_split);
+  /*!
+   * \brief Schedule primitive that extends the split step.
+   * \param stage_id The index of the stage to be split.
+   * \param it The iterator to be split.
+   * \param src_step_ids The indices of the split steps to follow in the history.

Review comment:
       Propagate this change to other places.

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -111,6 +115,10 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
     ps->ApplyToState(state);
   } else if (auto ps = step.as<SplitStepNode>()) {
     ps->ApplyToState(state);
+  } else if (auto ps = step.as<FollowSplitStepNode>()) {

Review comment:
       The `ps`s in different branches have different types, so they cannot be merged.

##########
File path: include/tvm/auto_scheduler/loop_state.h
##########
@@ -359,6 +359,29 @@ class State : public ObjectRef {
   TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
                                 const Array<Optional<Integer>>& lengths,
                                 bool inner_to_outer = true);
+  /*!
+   * \brief Schedule primitive that extends the split step.
+   * \param stage_id The index of the stage to be split.
+   * \param it The iterator to be split.
+   * \param src_step_id The index of the split step to follow in the history.
+   * \param n_split The number of split levels.
+   * \return The new Iterators produced by the split.
+   */
+  TVM_DLL Array<Iterator> follow_split(int stage_id, const Iterator& it, int src_step_id,
+                                       int n_split);
+  /*!
+   * \brief Schedule primitive that extends the split step.
+   * \param stage_id The index of the stage to be split.
+   * \param it The iterator to be split.
+   * \param src_step_ids The indices of the split steps to follow in the history.

Review comment:
       ```suggestion
      * \param src_step_ids The indices of the split steps to be followed in the history.
   ```







[GitHub] [incubator-tvm] jcf94 edited a comment on pull request #6142: [Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps

Posted by GitBox <gi...@apache.org>.
jcf94 edited a comment on pull request #6142:
URL: https://github.com/apache/incubator-tvm/pull/6142#issuecomment-664167975


   Hi, all. This is a student intern of ours who is now helping us with the Ansor upstreaming. 😄 
   
   The follow_split and follow_fused_split steps both extend `te.Stage.Split`. Each of them collects information from earlier steps in the history and then performs the split.
   
   ## FollowSplit
   
   This is mainly used for stage fusion via compute_at.
   For example, suppose we have the stages `Dense -> Relu`.
   We have already done some tiling on Relu, and we would like to compute Dense at the Relu stage. The FollowSplit step keeps the outermost few iterators of Dense the same as those of the Relu stage.
   In Ansor, the split factors of the Relu stage may be left as None placeholders for the search policy to fill in later. Following the earlier split therefore lets us easily write a schedule with this kind of dynamic dependence.
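   The way FollowSplit derives its lengths from the followed step can be restated as a minimal Python sketch. It mirrors `FollowSplitStepNode::ExtractSplitLengths` in the diff:

   * The first `n_split - 1` lengths are copied unchanged.
   * The remaining lengths are multiplied into one last factor.
   * An unfilled placeholder in that tail makes the last factor unknown.

   The function name and the plain-list representation of a SplitStep's `lengths` are hypothetical. `None` plays the role of `NullOpt`.

```python
from math import prod
from typing import List, Optional


def follow_split_lengths(src_lengths: List[Optional[int]], n_split: int) -> List[Optional[int]]:
    """Lengths a FollowSplit step reuses from the SplitStep it follows.

    Copies the first n_split - 1 lengths and folds the rest into a single
    last factor. If any folded length is None (a placeholder not yet filled
    by the search policy), the last factor is None as well.
    """
    if not 1 <= n_split <= len(src_lengths) + 1:
        raise ValueError("n_split is out of range for the followed split step")
    head = list(src_lengths[: n_split - 1])
    tail = src_lengths[n_split - 1 :]
    last = None if any(x is None for x in tail) else prod(tail)
    return head + [last]


# Relu was split with three lengths; Dense follows it with two levels.
print(follow_split_lengths([4, 8, 2], n_split=2))     # [4, 16]
print(follow_split_lengths([4, None, 2], n_split=2))  # [4, None]
```

   Because the copied values may themselves be placeholders, the Dense split automatically picks up whatever the search policy later chooses for Relu.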
   
   ## FollowFusedSplit
   
   This is mainly used for GPU cooperative fetching.
   For example, suppose we have the stages `Input -> Dense`:
   ```
               for i.0@j.0 = ... : Bind to blockIdx.x
                   for i.1@j.1 = ... : Bind to threadIdx.x
                       for i.2@j.2 = ...
                           Input_shared = Input ...
                           for k = ...
                               Dense = ...
   ```
   In Ansor's search policy, the outer stage has already been tiled, and the threadIdx.x axis is bound to an iterator generated by split and fuse steps. We use this step to compute the final extent of the iterator bound to threadIdx.x, so that the Input_shared stage can split out an iterator with the same extent.
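   The following is a minimal Python sketch of that extent computation. It assumes, based on the `src_step_ids` and `level` fields declared on `FollowFusedSplitStepNode`, that the extent is the product of the `level`-th length of every followed split step. Which loop a given `level` corresponds to depends on how those splits were made, so the mapping used in the example is an assumption. The helper names are hypothetical. The `factor`/`nparts` keywords are those of `te.Stage.split`.

```python
from math import prod
from typing import Dict, List, Optional


def follow_fused_split_length(
    src_lengths: List[List[Optional[int]]], level: int
) -> Optional[int]:
    """Extent of the fused iterator that the new split should match.

    Multiplies the `level`-th length of every followed split step. If any
    of them is still an unfilled placeholder (None), the extent is unknown.
    """
    picked = [lengths[level] for lengths in src_lengths]
    if any(x is None for x in picked):
        return None
    return prod(picked)


def split_kwargs(length: Optional[int], factor_or_nparts: bool) -> Dict[str, Optional[int]]:
    """Map the flag onto te.Stage.split keywords: factor if True, else nparts."""
    return {"factor": length} if factor_or_nparts else {"nparts": length}


# Assumed: i and j were split with these lengths, and level 0 holds the
# extents that were fused into i.1@j.1 (bound to threadIdx.x).
extent = follow_fused_split_length([[4, 8], [2, 16]], level=0)
print(extent)  # 8
print(split_kwargs(extent, factor_or_nparts=False))  # {'nparts': 8}
```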





[GitHub] [incubator-tvm] jcf94 edited a comment on pull request #6142: [Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps

Posted by GitBox <gi...@apache.org>.
jcf94 edited a comment on pull request #6142:
URL: https://github.com/apache/incubator-tvm/pull/6142#issuecomment-664167975


   Hi, all. This is a student intern of ours who is now helping us with the Ansor upstreaming. 😄 
   
   The follow_split and follow_fused_split steps both extend `te.Stage.Split`. Each of them collects information from earlier steps in the history and then performs the split.
   
   ## FollowSplit
   
   This is mainly used for stage fusion via compute_at.
   For example, suppose we have the stages `Dense -> Relu`.
   We have already done some tiling on Relu, and we would like to compute Dense at the Relu stage. The FollowSplit step keeps the outermost few iterators the same as those of the Relu stage.
   In Ansor, the split factors of the Relu stage may be left as None placeholders for the search policy to fill in later. Following the earlier split therefore lets us easily write a schedule with dynamic dependence.
   
   ## FollowFusedSplit
   
   This is mainly used for GPU cooperative fetching.
   For example, suppose we have the stages `Input -> Dense`:
   ```
               for i.0@j.0 = ... : Bind to blockIdx.x
                   for i.1@j.1 = ... : Bind to threadIdx.x
                       for i.2@j.2 = ...
                           Input_shared = Input ...
                           for k = ...
                               Dense = ...
   ```
   In Ansor's search policy, the outer stage has already been tiled, and the threadIdx.x axis is bound to an iterator generated by split and fuse steps. We use this step to compute the final extent of the iterator bound to threadIdx.x, so that the Input_shared stage can split out an iterator with the same extent.





[GitHub] [incubator-tvm] jcf94 commented on a change in pull request #6142: [Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps

Posted by GitBox <gi...@apache.org>.
jcf94 commented on a change in pull request #6142:
URL: https://github.com/apache/incubator-tvm/pull/6142#discussion_r460686060



##########
File path: include/tvm/auto_scheduler/loop_state.h
##########
@@ -356,9 +356,30 @@ class State : public ObjectRef {
    * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner
    * most iterator of split results will become the new attach point.
    */
-  TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
-                                const Array<Optional<Integer>>& lengths,
-                                bool inner_to_outer = true);
+  Array<Iterator> split(int stage_id, const Iterator& it, const Array<Optional<Integer>>& lengths,
+                        bool inner_to_outer = true);
+  /*!
+   * \brief Schedule primitive that extends the split step.
+   * \param stage_id The index of the stage to be split.
+   * \param it The iterator to be split.
+   * \param src_step_id The index of the split step to follow in the history.
+   * \param n_split The number of split levels.
+   * \return The new Iterators produced by the split.
+   */
+  Array<Iterator> follow_split(int stage_id, const Iterator& it, int src_step_id, int n_split);

Review comment:
       Add the `TVM_DLL` macro before this function; see the other functions for reference.







[GitHub] [incubator-tvm] jcf94 edited a comment on pull request #6142: [Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps

Posted by GitBox <gi...@apache.org>.
jcf94 edited a comment on pull request #6142:
URL: https://github.com/apache/incubator-tvm/pull/6142#issuecomment-664167975


   Hi, all. This is a student intern of ours who is now helping us with the Ansor upstreaming. 😄 
   
   The follow_split and follow_fused_split steps both extend `te.Stage.Split`. Each of them collects information from earlier steps in the history and then performs the split.
   
   ## FollowSplit
   
   This is mainly used for stage fusion via compute_at.
   For example, suppose we have the stages `Dense -> Relu`.
   We have already done some tiling on Relu, and we would like to compute Dense at the Relu stage. The FollowSplit step keeps the outermost few iterators of Dense the same as those of the Relu stage.
   In Ansor, the split factors of the Relu stage may be left as None placeholders for the search policy to fill in later. Following the earlier split therefore lets us easily write a schedule with this kind of dynamic dependence.
   
   ## FollowFusedSplit
   
   This is mainly used for GPU cooperative fetching.
   For example, suppose we have the stages `Input -> Dense`:
   ```
               for i.0@j.0 = ... : Bind to blockIdx.x
                   for i.1@j.1 = ... : Bind to threadIdx.x
                       for i.2@j.2 = ...
                           Input_shared = Input ...
                           for k = ...
                               Dense = ...
   ```
   In Ansor's search policy, the outer stage has already been tiled, and the threadIdx.x axis is bound to an iterator generated by split and fuse steps.
   We use the FollowFusedSplit step to compute the final extent of the iterator bound to threadIdx.x, so that the Input_shared stage can split out an iterator with the same extent.





[GitHub] [incubator-tvm] jcf94 commented on pull request #6142: [Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps

Posted by GitBox <gi...@apache.org>.
jcf94 commented on pull request #6142:
URL: https://github.com/apache/incubator-tvm/pull/6142#issuecomment-664167975


   Hi, all. This is a student intern of ours who is now helping us with the Ansor upstreaming. 😄 





[GitHub] [incubator-tvm] yangjunpro commented on pull request #6142: [Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps

Posted by GitBox <gi...@apache.org>.
yangjunpro commented on pull request #6142:
URL: https://github.com/apache/incubator-tvm/pull/6142#issuecomment-664170160


   Thanks @jiuqi-yang for the nice work. @merrymercy @tqchen @FrozenGene @comaniac, would you please take a look at this PR? We are trying to accelerate the auto_scheduler upstreaming process. 

