Posted to commits@tvm.apache.org by lm...@apache.org on 2020/07/28 14:49:15 UTC

[incubator-tvm] branch master updated: [Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps (#6142)

This is an automated email from the ASF dual-hosted git repository.

lmzheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new bbc2dbf  [Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps (#6142)
bbc2dbf is described below

commit bbc2dbf9f81669c505ac8c73f4a6511bfc941d4f
Author: jiuqi-yang <68...@users.noreply.github.com>
AuthorDate: Tue Jul 28 22:49:05 2020 +0800

    [Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps (#6142)
    
    * Add cache_read/cache_write step
    
    * Update
    
    * Add follow split and follow fused split
    
    Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
    
    Conflicts:
    	src/auto_scheduler/compute_dag.cc
    	src/auto_scheduler/transform_step.cc
    	src/auto_scheduler/transform_step.h
    	tests/python/unittest/test_auto_scheduler_loop_state.py
    
    * add loop_state.py
    
    Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
    
    * Update
    
    * Update
    
    * Update state->current_compute_dag to Optional
    
    * Add some doc strings for Follow_Split and Follow_fused_split
    
    Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
    
    * Check code using c-lint
    
    Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
    
    * Add more doc strings and change the order for follow split.
    
    Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
    
    * Add record test for follow_split and follow_fused_split
    
    Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
    
    * Add record test for follow_split
    
    Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
    
    * Add record test for follow_fused_split.
    
    Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
    
    * Add test record for follow_fused_split
    1. delete a comment
    2. add "fuse" between follow_split and follow_fused_split
    
    Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
    
    * Add doc strings for some functions and variables
    
    Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
    
    * Fix the code format in src/auto_scheduler/transform_step.h
    
    Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
    
    * Update
    
    * Update doc
    
    * Update
    
    * Update
    
    * Fix follow_split and follow_fused_split record test.
    
    Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
    
    * Doc update
    
    * Update some doc strings
    
    Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
    
    * Fix code style and some function definitions.
    
    Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
    
    * Update
    
    Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
    
    * Add comments on parameters.
    
    Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
    
    * Add more doc strings and fix some.
    
    Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
    
    * Update
    
    Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
    
    * Update
    
    Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
    
    * Update
    
    Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
    
    * Update.
    
    Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
    
    Co-authored-by: chengfan.jcf <ch...@alibaba-inc.com>
    Co-authored-by: jingbang.yjb <ji...@alibaba-inc.com>
---
 include/tvm/auto_scheduler/loop_state.h            |  23 +++
 include/tvm/auto_scheduler/transform_step.h        | 168 ++++++++++++++++-
 python/tvm/auto_scheduler/loop_state.py            |  96 ++++++++++
 src/auto_scheduler/compute_dag.cc                  |   4 +-
 src/auto_scheduler/loop_state.cc                   |  34 ++++
 src/auto_scheduler/transform_step.cc               | 208 ++++++++++++++++++++-
 .../unittest/test_auto_scheduler_loop_state.py     |  42 ++++-
 .../python/unittest/test_auto_scheduler_measure.py |  25 ++-
 8 files changed, 589 insertions(+), 11 deletions(-)
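
For reference, a minimal usage sketch of the two new primitives on the Python side
(assuming the matmul_auto_scheduler_test workload from
tests/python/unittest/test_auto_scheduler_common.py, as exercised in the tests below):

    from tvm import auto_scheduler
    from test_auto_scheduler_common import matmul_auto_scheduler_test

    A, B, C = matmul_auto_scheduler_test(512, 512, 512)
    dag = auto_scheduler.ComputeDAG([A, B, C])
    s = dag.get_init_state()

    C_global = s.cache_write(C, "global")
    # An ordinary split step; its factors are reused below.
    s.split(C, s[C].iters[0], [4, 2, 8, 4], True)
    split_step = len(s.transform_steps) - 1
    # follow_split reuses the first n_split factor levels of that split step.
    s.follow_split(C_global, s[C_global].iters[0], split_step, 4)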

diff --git a/include/tvm/auto_scheduler/loop_state.h b/include/tvm/auto_scheduler/loop_state.h
index 1c8ea77..9850620 100644
--- a/include/tvm/auto_scheduler/loop_state.h
+++ b/include/tvm/auto_scheduler/loop_state.h
@@ -359,6 +359,29 @@ class State : public ObjectRef {
   TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
                                 const Array<Optional<Integer>>& lengths,
                                 bool inner_to_outer = true);
+  /*!
+   * \brief Schedule primitive that extends the split step, reusing the factors of a previous split step.
+   * \param stage_id The index of the stage to be split.
+   * \param it The iterator to be split.
+   * \param src_step_id The index of the split step to be followed in the history.
+   * \param n_split The number of split levels.
+   * \return The new Iterators after splitting.
+   */
+  TVM_DLL Array<Iterator> follow_split(int stage_id, const Iterator& it, int src_step_id,
+                                       int n_split);
+  /*!
+   * \brief Schedule primitive that extends the split step, reusing the fused lengths of multiple previous split steps.
+   * \param stage_id The index of the stage to be split.
+   * \param it The iterator to be split.
+   * \param src_step_ids The indices of the split steps to be followed in the history.
+   * \param level The split level of the source steps whose lengths are multiplied to get the split factor.
+   * \param factor_or_nparts True to use `factor` for split from inner to outer,
+      False to use `nparts` for split from outer to inner.
+   * \return The new Iterators after splitting.
+   */
+  TVM_DLL Array<Iterator> follow_fused_split(int stage_id, const Iterator& it,
+                                             const Array<Integer>& src_step_ids, int level,
+                                             bool factor_or_nparts);
 
   /********** Step APIs working on multiple stages **********/
 
diff --git a/include/tvm/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h
index 83d6e29..f91505c 100644
--- a/include/tvm/auto_scheduler/transform_step.h
+++ b/include/tvm/auto_scheduler/transform_step.h
@@ -202,9 +202,10 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag);
  * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
  * \param schedule A mutable pointer to a `te::Schedule`. This is required by some steps which need
  * `te::Schedule` API. (e.g. CacheRead/CacheWrite step)
+ * \param transform_steps An array that records all transform steps.
  */
 void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
-                         te::Schedule* schedule);
+                         te::Schedule* schedule, const Array<Step>& transform_steps);
 
 /*!
  * \brief Print the step as equivalent python schedule API.
@@ -213,10 +214,12 @@ void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxes
  * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
  * \param schedule A mutable pointer to a te::Schedule. This is required by some steps. (e.g.
  * CacheRead/CacheWrite step)
+ * \param transform_steps An array that records all transform steps.
  * \return Python schedule code.
  */
 String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
-                            StageToAxesMap* stage_to_axes, te::Schedule* schedule);
+                            StageToAxesMap* stage_to_axes, te::Schedule* schedule,
+                            const Array<Step>& transform_steps);
 
 /********** Steps working on single stage **********/
 
@@ -487,6 +490,167 @@ class SplitStep : public Step {
   TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
 };
 
+/*! \brief Similar to SplitStepNode, but uses split factors from another step
+ * (i.e. it follows an earlier split step) */
+class FollowSplitStepNode : public StepNode {
+ public:
+  /*! \brief The id of the iter to be split. */
+  int iter_id;
+  /*! \brief The index of the split step to follow in the history. */
+  int src_step_id;
+  /*! \brief The number of split levels. */
+  int n_split;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Extract split lengths.
+   * \param transform_steps An array record all transform steps.
+   * \param lengths The extracted split factors; an entry may be NullOpt, to be filled in by the search policy.
+   */
+  void ExtractSplitLengths(const Array<Step>& transform_steps,
+                           Array<Optional<Integer>>* lengths) const;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to state, which will be updated.
+   */
+  Array<Iterator> ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages The `te::Stage`s used when applying the TVM schedule.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param transform_steps An array that records all transform steps.
+   * \return The iterator results after split.
+   */
+  Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                      const Array<Step>& transform_steps) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages The `te::Stage`s used when applying the TVM schedule.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param transform_steps An array that records all transform steps.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                          const Array<Step>& transform_steps) const;
+
+  static constexpr const char* record_prefix_str = "FSP";
+
+  static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to FollowSplitStepNode.
+ * \sa FollowSplitStepNode
+ */
+class FollowSplitStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be split.
+   * \param iter_id The index of the iterator to be split.
+   * \param src_step_id The index of the split step to follow in the history.
+   * \param n_split The number of split levels.
+   */
+  FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit FollowSplitStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode);
+};
+
+/*! \brief Similar to FollowSplitStep, but uses split factors from multiple steps.
+ *  \note This can be used for the split in cooperative fetching.
+ */
+class FollowFusedSplitStepNode : public StepNode {
+ public:
+  /*! \brief The id of the iter to split. */
+  int iter_id;
+  /*! \brief The indices of the split steps to follow in the history. */
+  Array<Integer> src_step_ids;
+  /*! \brief The split level of the source steps whose lengths are used. */
+  int level;
+  /*! \brief If this is true, use factor. Otherwise, use nparts. */
+  bool factor_or_nparts;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Extract split length.
+   * \param transform_steps An array that records all transform steps.
+   * \return Split factor.
+   */
+  Optional<Integer> ExtractSplitLength(const Array<Step>& transform_steps) const;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to state, which will be updated.
+   * \return The iterator results after split.
+   */
+  Array<Iterator> ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages The `te::Stage`s used when applying the TVM schedule.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param transform_steps An array that records all transform steps.
+   * \return The iterator results after split.
+   */
+  Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                      const Array<Step>& transform_steps) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages The `te::Stage`s used when applying the TVM schedule.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param transform_steps An array that records all transform steps.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                          const Array<Step>& transform_steps) const;
+
+  static constexpr const char* record_prefix_str = "FFSP";
+
+  static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to FollowFusedSplitStepNode.
+ * \sa FollowFusedSplitStepNode
+ */
+class FollowFusedSplitStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be split.
+   * \param iter_id The index of the iterator to be split.
+   * \param src_step_ids The indices of the split steps to follow in the history.
+   * \param level The split level of the source steps whose lengths are multiplied to get the split factor.
+   * \param factor_or_nparts If this is true, use factor. Otherwise, use nparts.
+   */
+  FollowFusedSplitStep(int stage_id, int iter_id, const Array<Integer>& src_step_ids, int level,
+                       bool factor_or_nparts);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit FollowFusedSplitStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode);
+};
+
 /********** Steps working on multiple stages **********/
 
 /*! \brief Compute at step that corresponds to te::Stage::compute_at */
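
Given the WriteToRecord implementations further down in transform_step.cc, the two steps
serialize into JSON array records of the following shapes; the concrete values here are
illustrative only:

    ["FSP", stage_id, iter_id, src_step_id, n_split]                e.g. ["FSP", 2, 0, 1, 4]
    ["FFSP", stage_id, iter_id, [src_step_id, ...], level, factor]  e.g. ["FFSP", 2, 0, [1, 2], 0, 1]

where the trailing field of "FFSP" is factor_or_nparts written as an integer (1 for factor,
0 for nparts).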
diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py
index 8c3a936..9ec26c3 100644
--- a/python/tvm/auto_scheduler/loop_state.py
+++ b/python/tvm/auto_scheduler/loop_state.py
@@ -118,6 +118,15 @@ class State:
         return self.state_object.stages
 
     @property
+    def transform_steps(self):
+        """
+        Returns
+        -------
+        transform_steps : List[Step]
+        """
+        return self.state_object.transform_steps
+
+    @property
     def stage_ops(self):
         """
         Returns
@@ -301,6 +310,93 @@ class State:
                                                      iterator, lengths, inner_to_outer)
         return res
 
+    def follow_split(self, stage, iterator, src_step_id, n_split):
+        """ Schedule primitive extends to split step.
+
+        This step splits the iterator by the same factors as the given SplitStep.
+
+        Notes
+        -----
+            This step is useful in a scenario where we have a subgraph Dense -> ReLU,
+            and we want to compute the Dense stage at the ReLU stage. In this case, the two
+            stages need to share the same tiling structure for their common outer loops.
+            The follow_split step can be used here to split the Dense stage and make sure its
+            split factors are the same as those of the given split step for the ReLU stage.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be split, which can be specified by the integer index, Operation,
+            or output tensor of the stage.
+        iterator : Iterator
+            The iterator to split.
+        src_step_id : int
+            The index of the split step to follow in the history.
+        n_split : int
+            The number of split levels.
+
+        Returns
+        -------
+        res_its : List[Iterator]
+            The new Iterators after splitting.
+        """
+
+        self.state_object, res = _ffi_api.StateFollowSplit(self.state_object,
+                                                           self._resolve_stage_id(stage),
+                                                           iterator,
+                                                           src_step_id, n_split)
+        return res
+
+    def follow_fused_split(self, stage, iterator, src_step_ids, level,
+                           factor_or_nparts):
+        """ Schedule primitive extends to split step.
+
+        This step is used to split an iterator by the same factors
+        as the given list of SplitSteps and FuseSteps.
+
+        Notes
+        -----
+            This step is useful in a scenario where we have a subgraph
+            in a GPU schedule: Input -> Dense
+            for i.0@j.0 = ... : Bind to blockIdx.x
+                for i.1@j.1 = ... : Bind to threadIdx.x
+                    for i.2@j.2 = ...
+                        Input_shared = Input ...
+                        for k = ...
+                            Dense = ...
+            We intend to apply cooperative fetching with the Input stage, while the threadIdx.x
+            axis is bound to an iterator generated by a split & fuse step.
+            The follow_fused_split step is used to split the iterator into two parts, so that the
+            split factor matches the final extent of the iterator bound to threadIdx.x.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be split, which can be specified by the integer index, Operation,
+            or output tensor of the stage.
+        iterator : Iterator
+            The iterator to split.
+        src_step_ids : List[int]
+            The indices of the split steps to follow in the history.
+        level : int
+            The split level of the source steps whose lengths are multiplied to get the split factor.
+        factor_or_nparts : bool
+            True to use `factor` for split from inner to outer,
+            False to use `nparts` for split from outer to inner.
+
+        Returns
+        -------
+        res_its : List[Iterator]
+            The new Iterators after splitting.
+        """
+
+        self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object,
+                                                                self._resolve_stage_id(stage),
+                                                                iterator,
+                                                                src_step_ids, level,
+                                                                factor_or_nparts)
+        return res
+
     def compute_at(self, stage, target_stage, target_iter):
         """ Schedule primitive corresponds to `te.Stage.compute_at`, see also the `te.Stage` for
         more details.
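
As a concrete reading of `level` and `factor_or_nparts` (a worked example with assumed
factors, mirroring ExtractSplitLength in transform_step.cc):

    # lengths of source split step 0: [4, 2, 8, 4]
    # lengths of source split step 1: [2, 2, 4, 8]
    # level = 1  ->  split length = 2 * 2 = 4 (the level-1 factors multiplied)
    # factor_or_nparts = True   ->  split(..., factor=4): 4 becomes the inner extent
    # factor_or_nparts = False  ->  split(..., nparts=4): 4 becomes the outer extent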
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index 2f6e948..f2815fb 100644
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -678,7 +678,7 @@ std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
   // Apply the history steps to TVM schedule
   // Call each step's ApplyToSchedule method
   for (const auto& step : transform_steps) {
-    StepApplyToSchedule(step, stages, stage_to_axes, &schedule);
+    StepApplyToSchedule(step, stages, stage_to_axes, &schedule, transform_steps);
   }
 
   return std::make_pair(schedule, operator->()->tensors);
@@ -722,7 +722,7 @@ String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const
   }
   // Call each step's PrintAsPythonAPI method
   for (const auto& step : transform_steps) {
-    ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule);
+    ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule, transform_steps);
   }
 
   return ss.str();
diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc
index 67c6b38..636066a 100644
--- a/src/auto_scheduler/loop_state.cc
+++ b/src/auto_scheduler/loop_state.cc
@@ -268,6 +268,25 @@ Array<Iterator> State::split(int stage_id, const Iterator& it,
   return step->ApplyToState(this);
 }
 
+Array<Iterator> State::follow_split(int stage_id, const Iterator& it, int src_step_id,
+                                    int n_split) {
+  const Stage& stage = operator->()->stages[stage_id];
+  FollowSplitStep step =
+      FollowSplitStep(stage_id, GetIndex(stage->iters, it), src_step_id, n_split);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return step->ApplyToState(this);
+}
+
+Array<Iterator> State::follow_fused_split(int stage_id, const Iterator& it,
+                                          const Array<Integer>& src_step_ids, int level,
+                                          bool factor_or_nparts) {
+  const Stage& stage = operator->()->stages[stage_id];
+  FollowFusedSplitStep step = FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it),
+                                                   src_step_ids, level, factor_or_nparts);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return step->ApplyToState(this);
+}
+
 void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) {
   const Stage& target_stage = operator->()->stages[target_stage_id];
   ComputeAtStep step =
@@ -454,6 +473,21 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit")
       return Array<ObjectRef>{state, res};
     });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit")
+    .set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id,
+                       int n_split) {
+      const auto& res = state.follow_split(stage_id, it, src_step_id, n_split);
+      return Array<ObjectRef>{state, Array<Iterator>(res)};
+    });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit")
+    .set_body_typed([](State state, int stage_id, const Iterator& it,
+                       const Array<Integer>& src_step_ids, int level, bool factor_or_nparts) {
+      const auto& res =
+          state.follow_fused_split(stage_id, it, src_step_ids, level, factor_or_nparts);
+      return Array<ObjectRef>{state, Array<Iterator>(res)};
+    });
+
 TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt")
     .set_body_typed([](State state, int stage_id, int target_stage_id,
                        const Iterator& target_iter) {
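
The registered packed functions return an Array<ObjectRef> holding the updated state and
the resulting iterators; on the Python side this pair is unpacked as in loop_state.py
above (a sketch, assuming an existing State wrapper `s` and valid arguments):

    from tvm.auto_scheduler import _ffi_api

    # StateFollowSplit returns [new_state_object, new_iterators]; the State
    # wrapper replaces its state_object and hands the iterators back to the caller.
    s.state_object, res = _ffi_api.StateFollowSplit(
        s.state_object, stage_id, iterator, src_step_id, n_split)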
diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc
index 5c5cc4b..d43d0af 100644
--- a/src/auto_scheduler/transform_step.cc
+++ b/src/auto_scheduler/transform_step.cc
@@ -85,6 +85,10 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) {
     return ReorderStep(reader);
   } else if (name == SplitStepNode::record_prefix_str) {
     return SplitStep(reader);
+  } else if (name == FollowSplitStepNode::record_prefix_str) {
+    return FollowSplitStep(reader);
+  } else if (name == FollowFusedSplitStepNode::record_prefix_str) {
+    return FollowFusedSplitStep(reader);
   } else if (name == ComputeAtStepNode::record_prefix_str) {
     return ComputeAtStep(reader);
   } else if (name == ComputeInlineStepNode::record_prefix_str) {
@@ -111,6 +115,10 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
     ps->ApplyToState(state);
   } else if (auto ps = step.as<SplitStepNode>()) {
     ps->ApplyToState(state);
+  } else if (auto ps = step.as<FollowSplitStepNode>()) {
+    ps->ApplyToState(state);
+  } else if (auto ps = step.as<FollowFusedSplitStepNode>()) {
+    ps->ApplyToState(state);
   } else if (auto ps = step.as<ComputeAtStepNode>()) {
     ps->ApplyToState(state);
   } else if (auto ps = step.as<ComputeInlineStepNode>()) {
@@ -127,7 +135,7 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
 }
 
 void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
-                         te::Schedule* schedule) {
+                         te::Schedule* schedule, const Array<Step>& transform_steps) {
   if (auto ps = step.as<AnnotationStepNode>()) {
     ps->ApplyToSchedule(stages, stage_to_axes);
   } else if (auto ps = step.as<FuseStepNode>()) {
@@ -136,6 +144,10 @@ void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxes
     ps->ApplyToSchedule(stages, stage_to_axes);
   } else if (auto ps = step.as<SplitStepNode>()) {
     ps->ApplyToSchedule(stages, stage_to_axes);
+  } else if (auto ps = step.as<FollowSplitStepNode>()) {
+    ps->ApplyToSchedule(stages, stage_to_axes, transform_steps);
+  } else if (auto ps = step.as<FollowFusedSplitStepNode>()) {
+    ps->ApplyToSchedule(stages, stage_to_axes, transform_steps);
   } else if (auto ps = step.as<ComputeAtStepNode>()) {
     ps->ApplyToSchedule(stages, stage_to_axes);
   } else if (auto ps = step.as<ComputeInlineStepNode>()) {
@@ -152,7 +164,8 @@ void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxes
 }
 
 String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
-                            StageToAxesMap* stage_to_axes, te::Schedule* schedule) {
+                            StageToAxesMap* stage_to_axes, te::Schedule* schedule,
+                            const Array<Step>& transform_steps) {
   if (auto ps = step.as<AnnotationStepNode>()) {
     return ps->PrintAsPythonAPI(stages, stage_to_axes);
   } else if (auto ps = step.as<FuseStepNode>()) {
@@ -161,6 +174,10 @@ String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
     return ps->PrintAsPythonAPI(stages, stage_to_axes);
   } else if (auto ps = step.as<SplitStepNode>()) {
     return ps->PrintAsPythonAPI(stages, stage_to_axes);
+  } else if (auto ps = step.as<FollowSplitStepNode>()) {
+    return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps);
+  } else if (auto ps = step.as<FollowFusedSplitStepNode>()) {
+    return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps);
   } else if (auto ps = step.as<ComputeAtStepNode>()) {
     return ps->PrintAsPythonAPI(stages, stage_to_axes);
   } else if (auto ps = step.as<ComputeInlineStepNode>()) {
@@ -776,6 +793,193 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer);
 }
 
+/********** Follow Split **********/
+FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) {
+  auto node = make_object<FollowSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_id = src_step_id;
+  node->n_split = n_split;
+  data_ = std::move(node);
+}
+
+void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(src_step_id);
+  writer->WriteArrayItem(n_split);
+}
+
+void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps,
+                                              Array<Optional<Integer>>* lengths) const {
+  // Make sure src_step_id is within the range of transform_steps.
+  CHECK_LT(src_step_id, transform_steps.size());
+  auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+  CHECK(ps != nullptr);
+
+  // Make sure the size of ps->lengths is not smaller than n_split-1.
+  // Note that the number of actual splitting factors of src_step is ps->lengths.size()+1.
+  CHECK_LE(n_split, ps->lengths.size() + 1);
+  // n_split may be smaller; the remaining factors are merged into the last level below.
+
+  lengths->reserve(n_split);
+  int j = 0;
+  // Get the first (n_split-1) split factors of followed src_step.
+  for (; j < n_split - 1; ++j) {
+    lengths->push_back(ps->lengths[j]);
+  }
+
+  // Merge the remaining split factors of src_step into the last level, in case n_split is
+  // smaller than ps->lengths.size()+1.
+  PrimExpr last_factor = 1;
+  for (; j < static_cast<int>(ps->lengths.size()); ++j) {
+    if (ps->lengths[j]) {
+      last_factor *= ps->lengths[j].value();
+    } else {
+      last_factor = PrimExpr();
+      break;
+    }
+  }
+  if (last_factor.defined()) {
+    lengths->push_back(Downcast<Integer>(last_factor));
+  } else {
+    lengths->push_back(NullOpt);
+  }
+}
+
+FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) {
+  auto node = make_object<FollowSplitStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->src_step_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->n_split);
+
+  data_ = std::move(node);
+}
+
+Array<Iterator> FollowSplitStepNode::ApplyToState(State* state) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths((*state)->transform_steps, &lengths);
+  return ApplySplitToState(state, stage_id, iter_id, lengths, true);
+}
+
+Array<IterVar> FollowSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                                    StageToAxesMap* stage_to_axes,
+                                                    const Array<Step>& transform_steps) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+}
+
+String FollowSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                             StageToAxesMap* stage_to_axes,
+                                             const Array<Step>& transform_steps) const {
+  Array<Optional<Integer>> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+}
+
+/********** Follow Fused Split **********/
+FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id,
+                                           const Array<Integer>& src_step_ids, int level,
+                                           bool factor_or_nparts) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_ids = src_step_ids;
+  node->level = level;
+  node->factor_or_nparts = factor_or_nparts;
+  data_ = std::move(node);
+}
+
+FollowFusedSplitStep::FollowFusedSplitStep(dmlc::JSONReader* reader) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::vector<int> int_list;
+  reader->Read(&int_list);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->level);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->factor_or_nparts);
+
+  ::tvm::Array<::tvm::Integer> src_step_ids;
+  for (const auto& i : int_list) {
+    src_step_ids.push_back(i);
+  }
+  node->src_step_ids = src_step_ids;
+  data_ = std::move(node);
+}
+
+void FollowFusedSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(IntArrayToVector(src_step_ids));
+  writer->WriteArrayItem(level);
+  writer->WriteArrayItem(static_cast<int>(factor_or_nparts));
+}
+
+Optional<Integer> FollowFusedSplitStepNode::ExtractSplitLength(
+    const Array<Step>& transform_steps) const {
+  PrimExpr ret(1);
+
+  for (int src_step_id : src_step_ids) {
+    // Make sure the src_step_id is within the range of transform_steps.
+    CHECK_LT(src_step_id, transform_steps.size());
+    auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+    CHECK(ps != nullptr);
+    // Multiply the split factors at the corresponding split level of all src_steps.
+    if (ps->lengths[level] && ret.defined()) {
+      ret *= ps->lengths[level].value();
+    } else {
+      return NullOpt;
+    }
+  }
+  return Downcast<Integer>(ret);
+}
+
+Array<Iterator> FollowFusedSplitStepNode::ApplyToState(State* state) const {
+  const Optional<Integer>& length = ExtractSplitLength((*state)->transform_steps);
+  return ApplySplitToState(state, stage_id, iter_id, {length}, factor_or_nparts);
+}
+
+Array<IterVar> FollowFusedSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                                         StageToAxesMap* stage_to_axes,
+                                                         const Array<Step>& transform_steps) const {
+  const Optional<Integer>& length = ExtractSplitLength(transform_steps);
+  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, {length}, factor_or_nparts);
+}
+
+String FollowFusedSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                                  StageToAxesMap* stage_to_axes,
+                                                  const Array<Step>& transform_steps) const {
+  const Optional<Integer>& length = ExtractSplitLength(transform_steps);
+  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, {length},
+                               factor_or_nparts);
+}
+
 /********** Steps working on multiple stages **********/
 
 /********** Compute At **********/
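
A pure-Python model of FollowSplitStepNode::ExtractSplitLengths above (not part of the
patch; None stands in for NullOpt):

    def extract_split_lengths(src_lengths, n_split):
        """Reuse the first n_split - 1 factors of the source split step and
        merge all remaining factors into the last level."""
        assert n_split <= len(src_lengths) + 1
        lengths = list(src_lengths[:n_split - 1])
        last = 1
        for factor in src_lengths[n_split - 1:]:
            if factor is None:      # an unknown factor makes the merged level unknown
                last = None
                break
            last *= factor
        lengths.append(last)
        return lengths

    extract_split_lengths([4, 2, 8, 4], 2)   # -> [4, 64]
    extract_split_lengths([4, 2, 8, 4], 5)   # -> [4, 2, 8, 4, 1]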
diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py
index 8282d4a..e35dfe3 100644
--- a/tests/python/unittest/test_auto_scheduler_loop_state.py
+++ b/tests/python/unittest/test_auto_scheduler_loop_state.py
@@ -85,7 +85,6 @@ def test_split_fuse_reorder_annotation():
     assert res == s1[C].iters[5]
     assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["vectorize"]
 
-
 def test_compute_at_root_inline():
     dag = auto_scheduler.ComputeDAG(conv2d_nchw_bn_relu(N=1, H=224, W=224, CI=3, CO=64,
                                                         kernel_size=7, strides=2, padding=3))
@@ -142,7 +141,6 @@ def test_compute_at_root_inline():
     assert s0[conv].iters[5].range.extent == 7
     assert s0[conv].iters[6].range.extent == 7
 
-
 def test_cache_read_write():
     N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, (
         1, 1), (1, 1)
@@ -417,7 +415,47 @@ def test_cache_read_write():
     for it0, it1 in zip(s0[kernel_split].iters, s0[kernel_split_global].iters):
         assert it0.range == it1.range
 
+def test_follow_split_follow_fused_split():
+    A, B, C = matmul_auto_scheduler_test(512, 512, 512)
+    dag = auto_scheduler.ComputeDAG([A, B, C])
+    s0 = dag.get_init_state()
+
+    C_global = s0.cache_write(C, "global")
+    its0 = s0.split(C, s0[C].iters[0], [4, 2, 8, 4], True)
+    split_step0 = len(s0.transform_steps) - 1
+    for level in range(1, 6):
+        tmp = s0.copy()
+        tmp.follow_split(C_global, tmp[C_global].iters[0], split_step0, level)
+        for i in range(0, level):
+            assert tmp[C].iters[i].range.extent == \
+                   tmp[C_global].iters[i].range.extent
+
+    its1 = s0.split(C, s0[C].iters[5], [2, 2, 4, 8])
+    split_step1 = len(s0.transform_steps) - 1
+    its = []
+    for i0, i1 in zip(its0, its1):
+        its.append(i0)
+        its.append(i1)
+    s0.reorder(C, its)
+    for i in range(0, 5):
+        s0.fuse(C, [s0[C].iters[i], s0[C].iters[i + 1]])
+
+    for level in range(0, 4):
+        tmp = s0.copy()
+        tmp.follow_fused_split(C_global, tmp[C_global].iters[0],
+                               [split_step0, split_step1], level, False)
+        assert tmp[C].iters[level + 1].range.extent == \
+               tmp[C_global].iters[0].range.extent
+
+    for level in range(0, 4):
+        tmp = s0.copy()
+        tmp.follow_fused_split(C_global, tmp[C_global].iters[0],
+                               [split_step0, split_step1], level, True)
+        assert tmp[C].iters[level + 1].range.extent == \
+               tmp[C_global].iters[1].range.extent
+
 if __name__ == "__main__":
     test_split_fuse_reorder_annotation()
     test_compute_at_root_inline()
     test_cache_read_write()
+    test_follow_split_follow_fused_split()
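
For intuition on the assertions above (a worked example; the extents follow from the 512
extent of the matmul iterators):

    # split(C, iters[0], [4, 2, 8, 4]) turns extent 512 into
    #   [512 / (4*2*8*4), 4, 2, 8, 4] = [2, 4, 2, 8, 4].
    # follow_split with n_split=2 reuses [4, 2*8*4] = [4, 64], so C_global gets
    #   extents [2, 4, 64] -- its first two levels match C's, as asserted.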
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py
index 5f2f87a..39d01e0 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -22,8 +22,7 @@ import topi
 from tvm import te, auto_scheduler
 import tempfile
 
-from test_auto_scheduler_common import get_tiled_matmul
-
+from test_auto_scheduler_common import matmul_auto_scheduler_test, get_tiled_matmul
 
 def test_record():
     if not tvm.runtime.enabled("llvm"):
@@ -37,8 +36,12 @@ def test_record():
     k = te.reduce_axis((0, 512), name='k')
     E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E')
     F = topi.nn.relu(E)
+    k = te.reduce_axis((0, 512), name='k')
+    G = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * F[k][j], axis=[k]), name='G')
+    H = topi.nn.relu(G)
+    I = topi.nn.relu(H)
 
-    dag = auto_scheduler.ComputeDAG([A, B, F])
+    dag = auto_scheduler.ComputeDAG([A, B, I])
     s = dag.get_init_state()
 
     # Split
@@ -71,6 +74,22 @@ def test_record():
     s.compute_at(D_global, E, s[E].iters[2])
     # Cache Write
     s.cache_write(D, "shared")
+    # follow_split
+    its2 = s.split(G, s[G].iters[0], [4, 2, 8, 4], True)
+    split_step0 = len(s.transform_steps) - 1
+    s.follow_split(G, s[G].iters[5], split_step0, 4)
+    # follow_fused_split
+    its2 = s.split(H, s[H].iters[0], [4, 2, 8, 4], True)
+    split_step1 = len(s.transform_steps) - 1
+    its3 = s.split(H, s[H].iters[5], [2, 4, 2, 4], True)
+    split_step2 = len(s.transform_steps) - 1
+    its = []
+    for i0, i1 in zip(its2, its3):
+        its.append(i0)
+        its.append(i1)
+    for i in range(0, 5):
+        s.fuse(H, [s[H].iters[i], s[H].iters[i + 1]])
+    s.follow_fused_split(I, s[I].iters[0], [split_step1, split_step2], 0, False)
 
     target = tvm.target.create("llvm")
     task = auto_scheduler.SearchTask(dag, "test", target)