You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lm...@apache.org on 2020/07/28 14:49:15 UTC
[incubator-tvm] branch master updated: [Ansor][AutoTVM v2.0] Phase
1: Add follow_split and follow_fused_split steps (#6142)
This is an automated email from the ASF dual-hosted git repository.
lmzheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new bbc2dbf [Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps (#6142)
bbc2dbf is described below
commit bbc2dbf9f81669c505ac8c73f4a6511bfc941d4f
Author: jiuqi-yang <68...@users.noreply.github.com>
AuthorDate: Tue Jul 28 22:49:05 2020 +0800
[Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps (#6142)
* Add cache_read/cache_write step
* Update
* Add follow split and follow fused split
Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
Conflicts:
src/auto_scheduler/compute_dag.cc
src/auto_scheduler/transform_step.cc
src/auto_scheduler/transform_step.h
tests/python/unittest/test_auto_scheduler_loop_state.py
* add loop_state.py
Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
* Update
* Update
* Update state->current_compute_dag to Optional
* Add some doc strings for Follow_Split and Follow_fused_split
Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
* Check code using c-lint
Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
* Add more doc strings and change the order for follow split.
Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
* Add record test for follow_split and follow_fused_split
Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
* Add record test for follow_split
Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
* Add record test for follow_fused_split.
Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
* Add test record for follow_fused_split
1. delete a comment
2. add "fuse" between follow_split and follow_fused_split
Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
* Add doc strings for some functions and variables
Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
* Fix the code format in src/auto_scheduler/transform_step.h
Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
* Update
* Update doc
* Update
* Update
* Fix follow_split and follow_fused_split record test.
Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
* Doc update
* Update some doc strings
Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
* Fix code style and some function definitions.
Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
* Update
Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
* Add comments on parameters.
Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
* Add more doc strings and fix some.
Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
* Update
Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
* Update
Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
* Update
Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
* Update.
Signed-off-by: jingbang.yjb <ji...@alibaba-inc.com>
Co-authored-by: chengfan.jcf <ch...@alibaba-inc.com>
Co-authored-by: jingbang.yjb <ji...@alibaba-inc.com>
---
include/tvm/auto_scheduler/loop_state.h | 23 +++
include/tvm/auto_scheduler/transform_step.h | 168 ++++++++++++++++-
python/tvm/auto_scheduler/loop_state.py | 96 ++++++++++
src/auto_scheduler/compute_dag.cc | 4 +-
src/auto_scheduler/loop_state.cc | 34 ++++
src/auto_scheduler/transform_step.cc | 208 ++++++++++++++++++++-
.../unittest/test_auto_scheduler_loop_state.py | 42 ++++-
.../python/unittest/test_auto_scheduler_measure.py | 25 ++-
8 files changed, 589 insertions(+), 11 deletions(-)
diff --git a/include/tvm/auto_scheduler/loop_state.h b/include/tvm/auto_scheduler/loop_state.h
index 1c8ea77..9850620 100644
--- a/include/tvm/auto_scheduler/loop_state.h
+++ b/include/tvm/auto_scheduler/loop_state.h
@@ -359,6 +359,29 @@ class State : public ObjectRef {
TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
const Array<Optional<Integer>>& lengths,
bool inner_to_outer = true);
+ /*!
+ * \brief Schedule primitive that extends the split step.
+ * \param stage_id The index of the stage to be split.
+ * \param it The iterator to be split.
+ * \param src_step_id The index of the split step to be followed in the history.
+ * \param n_split The number of split level.
+ * \return The split new Iterators.
+ */
+ TVM_DLL Array<Iterator> follow_split(int stage_id, const Iterator& it, int src_step_id,
+ int n_split);
+ /*!
+ * \brief Schedule primitive that extends the split step.
+ * \param stage_id The index of the stage to be split.
+ * \param it The iterator to be split.
+ * \param src_step_ids The indices of the split steps to be followed in the history.
+ * \param level Use the length in this split level.
+ * \param factor_or_nparts True to use `factor` for split from inner to outer,
+ False to use `nparts` for split from outer to inner.
+ * \return The split new Iterators.
+ */
+ TVM_DLL Array<Iterator> follow_fused_split(int stage_id, const Iterator& it,
+ const Array<Integer>& src_step_ids, int level,
+ bool factor_or_nparts);
/********** Step APIs working on multiple stages **********/
diff --git a/include/tvm/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h
index 83d6e29..f91505c 100644
--- a/include/tvm/auto_scheduler/transform_step.h
+++ b/include/tvm/auto_scheduler/transform_step.h
@@ -202,9 +202,10 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag);
* \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \param schedule A mutable pointer to a `te::Schedule`. This is required by some steps which need
* `te::Schedule` API. (e.g. CacheRead/CacheWrite step)
+ * \param transform_steps An array record all transform steps.
*/
void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
- te::Schedule* schedule);
+ te::Schedule* schedule, const Array<Step>& transform_steps);
/*!
* \brief Print the step as equivalent python schedule API.
@@ -213,10 +214,12 @@ void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxes
* \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \param schedule A mutable pointer to a te::Schedule. This is required by some steps. (e.g.
* CacheRead/CacheWrite step)
+ * \param transform_steps An array record all transform steps.
* \return Python schedule code.
*/
String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
- StageToAxesMap* stage_to_axes, te::Schedule* schedule);
+ StageToAxesMap* stage_to_axes, te::Schedule* schedule,
+ const Array<Step>& transform_steps);
/********** Steps working on single stage **********/
@@ -487,6 +490,167 @@ class SplitStep : public Step {
TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
};
+/*! \brief Similar to SplitStepNode, but uses split factors from another step
+ * (i.e. Follow another split step) */
+class FollowSplitStepNode : public StepNode {
+ public:
+ /*! \brief The id of the iter to be split. */
+ int iter_id;
+ /*! \brief The index of the split step to follow in the history. */
+ int src_step_id;
+ /*! \brief The number of split level. */
+ int n_split;
+
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+ /*!
+ * \brief Extract split lengths.
+ * \param transform_steps An array record all transform steps.
+ * \param lengths The multiple split factors. Can be None to be filled by search policy.
+ */
+ void ExtractSplitLengths(const Array<Step>& transform_steps,
+ Array<Optional<Integer>>* lengths) const;
+
+ /*!
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to state, which will be updated.
+ */
+ Array<Iterator> ApplyToState(State* state) const;
+
+ /*!
+ * \brief Apply the current step to tvm.schedule.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param transform_steps An array record all transform steps.
+ * \return The iterator results after split.
+ */
+ Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const;
+
+ /*!
+ * \brief Print the current step as equivalent python schedule API.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param transform_steps An array record all transform steps.
+ * \return Python schedule code.
+ */
+ String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const;
+
+ static constexpr const char* record_prefix_str = "FSP";
+
+ static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep";
+ TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to FollowSplitStepNode.
+ * \sa FollowSplitStepNode
+ */
+class FollowSplitStep : public Step {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param stage_id The index of the stage to be split.
+ * \param iter_id The index of the iterator to be split.
+ * \param src_step_id The index of the split step to follow in the history.
+ * \param n_split The number of split level.
+ */
+ FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split);
+
+ /*!
+ * \brief The constructor used to read a step record from JSONReader and create the
+ * corresponding step.
+ * \param reader The input JSONReader.
+ */
+ explicit FollowSplitStep(dmlc::JSONReader* reader);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode);
+};
+
+/*! \brief Similar to FollowSplitStep, but uses split factors from multiple steps.
+ * \note This can be used for the split in cooperative fetching.
+ */
+class FollowFusedSplitStepNode : public StepNode {
+ public:
+ /*! \brief The id of the iter to split. */
+ int iter_id;
+ /*! \brief The indices of the split steps to follow in the history. */
+ Array<Integer> src_step_ids;
+ /*! \brief Use the length in this split level. */
+ int level;
+ /*! \brief If this is true, use factor. Otherwise, use nparts. */
+ bool factor_or_nparts;
+
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+ /*!
+ * \brief Extract split length.
+ * \param transform_steps An array record all transform steps.
+ * \return Split factor.
+ */
+ Optional<Integer> ExtractSplitLength(const Array<Step>& transform_steps) const;
+
+ /*!
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to state, which will be updated.
+ * \return The iterator results after split.
+ */
+ Array<Iterator> ApplyToState(State* state) const;
+
+ /*!
+ * \brief Apply the current step to tvm.schedule.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param transform_steps An array record all transform steps.
+ * \return The iterator results after split.
+ */
+ Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const;
+
+ /*!
+ * \brief Print the current step as equivalent python schedule API.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param transform_steps An array record all transform steps.
+ * \return Python schedule code.
+ */
+ String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const;
+
+ static constexpr const char* record_prefix_str = "FFSP";
+
+ static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep";
+ TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to FollowFusedSplitStepNode.
+ * \sa FollowFusedSplitStepNode
+ */
+class FollowFusedSplitStep : public Step {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param stage_id The index of the stage to be split.
+ * \param iter_id The index of the iterator to be split.
+ * \param src_step_ids An array of index for split step to follow in the history.
+ * \param level Use the length in this split level.
+ * \param factor_or_nparts If this is true, use factor. Otherwise, use nparts.
+ */
+ FollowFusedSplitStep(int stage_id, int iter_id, const Array<Integer>& src_step_ids, int level,
+ bool factor_or_nparts);
+
+ /*!
+ * \brief The constructor used to read a step record from JSONReader and create the
+ * corresponding step.
+ * \param reader The input JSONReader.
+ */
+ explicit FollowFusedSplitStep(dmlc::JSONReader* reader);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode);
+};
+
/********** Steps working on multiple stages **********/
/*! \brief Compute at step that corresponds to te::Stage::compute_at */
diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py
index 8c3a936..9ec26c3 100644
--- a/python/tvm/auto_scheduler/loop_state.py
+++ b/python/tvm/auto_scheduler/loop_state.py
@@ -118,6 +118,15 @@ class State:
return self.state_object.stages
@property
+ def transform_steps(self):
+ """
+ Returns
+ -------
+ transform_steps : List[transform_steps]
+ """
+ return self.state_object.transform_steps
+
+ @property
def stage_ops(self):
"""
Returns
@@ -301,6 +310,93 @@ class State:
iterator, lengths, inner_to_outer)
return res
+ def follow_split(self, stage, iterator, src_step_id, n_split):
+ """ Schedule primitive extends to split step.
+
+ This step splits the iterator by the same factors as the given SplitStep.
+
+ Notes
+ ------
+ This step is useful in a scenario that we have subgraph Dense -> Relu,
+ and we want to compute the Dense stage at ReLU. In this case, we need them to have
+ the same tiling structure of common outer loops.
+ The follow_split step could be used here to split the Dense stage and make sure its
+ splitting factors are the same as the given split step for the ReLU stage.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be split, which can be specified by the integer index, Operation,
+ or output tensor of the stage.
+ iterator : Iterator
+ The iterator to split.
+ src_step_id : int
+ The index of the split step to follow in the history.
+ n_split : int
+ The number of split level.
+
+ Returns
+ -------
+ res_its : List[Iterator]
+ The split new Iterators.
+ """
+
+ self.state_object, res = _ffi_api.StateFollowSplit(self.state_object,
+ self._resolve_stage_id(stage),
+ iterator,
+ src_step_id, n_split)
+ return res
+
+ def follow_fused_split(self, stage, iterator, src_step_ids, level,
+ factor_or_nparts):
+ """ Schedule primitive extends to split step.
+
+ This step is used to split an iterator by the same factors
+ as the given list of SplitSteps and FuseSteps.
+
+ Notes
+ ------
+ This step is useful in a scenario that we have a subgraph
+ in GPU schedule: Input -> Dense
+ for i.0@j.0 = ... : Bind to blockIdx.x
+ for i.1@j.1 = ... : Bind to threadIdx.x
+ for i.2@j.2 = ...
+ Input_shared = Input ...
+ for k = ...
+ Dense = ...
+ We intend to apply cooperative fetching with the input stage, while the threadIdx.x
+ axis is bound to an iterator generated by split & fuse step.
+ The follow_fused_split step is used to split the iterator into 2 parts, while the split
+ factor matches the final extent of the threadIdx.x bound iterator.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be split, which can be specified by the integer index, Operation,
+ or output tensor of the stage.
+ iterator : Iterator
+ The iterator to split.
+ src_step_ids : List[int]
+ The indices of the split steps to follow in the history.
+ level : int
+ Use the length in this split level.
+ factor_or_nparts : bool
+ True to use `factor` for split from inner to outer,
+ False to use `nparts` for split from outer to inner.
+
+ Returns
+ -------
+ res_its : List[Iterator]
+ The split new Iterators.
+ """
+
+ self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object,
+ self._resolve_stage_id(stage),
+ iterator,
+ src_step_ids, level,
+ factor_or_nparts)
+ return res
+
def compute_at(self, stage, target_stage, target_iter):
""" Schedule primitive corresponds to `te.Stage.compute_at`, see also the `te.Stage` for
more details.
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index 2f6e948..f2815fb 100644
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -678,7 +678,7 @@ std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
// Apply the history steps to TVM schedule
// Call each step's ApplyToSchedule method
for (const auto& step : transform_steps) {
- StepApplyToSchedule(step, stages, stage_to_axes, &schedule);
+ StepApplyToSchedule(step, stages, stage_to_axes, &schedule, transform_steps);
}
return std::make_pair(schedule, operator->()->tensors);
@@ -722,7 +722,7 @@ String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const
}
// Call each step's PrintAsPythonAPI method
for (const auto& step : transform_steps) {
- ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule);
+ ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule, transform_steps);
}
return ss.str();
diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc
index 67c6b38..636066a 100644
--- a/src/auto_scheduler/loop_state.cc
+++ b/src/auto_scheduler/loop_state.cc
@@ -268,6 +268,25 @@ Array<Iterator> State::split(int stage_id, const Iterator& it,
return step->ApplyToState(this);
}
+Array<Iterator> State::follow_split(int stage_id, const Iterator& it, int src_step_id,
+ int n_split) {
+ const Stage& stage = operator->()->stages[stage_id];
+ FollowSplitStep step =
+ FollowSplitStep(stage_id, GetIndex(stage->iters, it), src_step_id, n_split);
+ CopyOnWrite()->transform_steps.push_back(step);
+ return step->ApplyToState(this);
+}
+
+Array<Iterator> State::follow_fused_split(int stage_id, const Iterator& it,
+ const Array<Integer>& src_step_ids, int level,
+ bool factor_or_nparts) {
+ const Stage& stage = operator->()->stages[stage_id];
+ FollowFusedSplitStep step = FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it),
+ src_step_ids, level, factor_or_nparts);
+ CopyOnWrite()->transform_steps.push_back(step);
+ return step->ApplyToState(this);
+}
+
void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) {
const Stage& target_stage = operator->()->stages[target_stage_id];
ComputeAtStep step =
@@ -454,6 +473,21 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit")
return Array<ObjectRef>{state, res};
});
+TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit")
+ .set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id,
+ int n_split) {
+ const auto& res = state.follow_split(stage_id, it, src_step_id, n_split);
+ return Array<ObjectRef>{state, Array<Iterator>(res)};
+ });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit")
+ .set_body_typed([](State state, int stage_id, const Iterator& it,
+ const Array<Integer>& src_step_ids, int level, bool factor_or_nparts) {
+ const auto& res =
+ state.follow_fused_split(stage_id, it, src_step_ids, level, factor_or_nparts);
+ return Array<ObjectRef>{state, Array<Iterator>(res)};
+ });
+
TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt")
.set_body_typed([](State state, int stage_id, int target_stage_id,
const Iterator& target_iter) {
diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc
index 5c5cc4b..d43d0af 100644
--- a/src/auto_scheduler/transform_step.cc
+++ b/src/auto_scheduler/transform_step.cc
@@ -85,6 +85,10 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) {
return ReorderStep(reader);
} else if (name == SplitStepNode::record_prefix_str) {
return SplitStep(reader);
+ } else if (name == FollowSplitStepNode::record_prefix_str) {
+ return FollowSplitStep(reader);
+ } else if (name == FollowFusedSplitStepNode::record_prefix_str) {
+ return FollowFusedSplitStep(reader);
} else if (name == ComputeAtStepNode::record_prefix_str) {
return ComputeAtStep(reader);
} else if (name == ComputeInlineStepNode::record_prefix_str) {
@@ -111,6 +115,10 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
ps->ApplyToState(state);
} else if (auto ps = step.as<SplitStepNode>()) {
ps->ApplyToState(state);
+ } else if (auto ps = step.as<FollowSplitStepNode>()) {
+ ps->ApplyToState(state);
+ } else if (auto ps = step.as<FollowFusedSplitStepNode>()) {
+ ps->ApplyToState(state);
} else if (auto ps = step.as<ComputeAtStepNode>()) {
ps->ApplyToState(state);
} else if (auto ps = step.as<ComputeInlineStepNode>()) {
@@ -127,7 +135,7 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
}
void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
- te::Schedule* schedule) {
+ te::Schedule* schedule, const Array<Step>& transform_steps) {
if (auto ps = step.as<AnnotationStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<FuseStepNode>()) {
@@ -136,6 +144,10 @@ void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxes
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<SplitStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
+ } else if (auto ps = step.as<FollowSplitStepNode>()) {
+ ps->ApplyToSchedule(stages, stage_to_axes, transform_steps);
+ } else if (auto ps = step.as<FollowFusedSplitStepNode>()) {
+ ps->ApplyToSchedule(stages, stage_to_axes, transform_steps);
} else if (auto ps = step.as<ComputeAtStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<ComputeInlineStepNode>()) {
@@ -152,7 +164,8 @@ void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxes
}
String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
- StageToAxesMap* stage_to_axes, te::Schedule* schedule) {
+ StageToAxesMap* stage_to_axes, te::Schedule* schedule,
+ const Array<Step>& transform_steps) {
if (auto ps = step.as<AnnotationStepNode>()) {
return ps->PrintAsPythonAPI(stages, stage_to_axes);
} else if (auto ps = step.as<FuseStepNode>()) {
@@ -161,6 +174,10 @@ String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
return ps->PrintAsPythonAPI(stages, stage_to_axes);
} else if (auto ps = step.as<SplitStepNode>()) {
return ps->PrintAsPythonAPI(stages, stage_to_axes);
+ } else if (auto ps = step.as<FollowSplitStepNode>()) {
+ return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps);
+ } else if (auto ps = step.as<FollowFusedSplitStepNode>()) {
+ return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps);
} else if (auto ps = step.as<ComputeAtStepNode>()) {
return ps->PrintAsPythonAPI(stages, stage_to_axes);
} else if (auto ps = step.as<ComputeInlineStepNode>()) {
@@ -776,6 +793,193 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer);
}
+/********** Follow Split **********/
+FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) {
+ auto node = make_object<FollowSplitStepNode>();
+ node->stage_id = stage_id;
+ node->iter_id = iter_id;
+ node->src_step_id = src_step_id;
+ node->n_split = n_split;
+ data_ = std::move(node);
+}
+
+void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+ writer->WriteArraySeperator();
+ writer->WriteString(record_prefix_str);
+ writer->WriteArrayItem(stage_id);
+ writer->WriteArrayItem(iter_id);
+ writer->WriteArrayItem(src_step_id);
+ writer->WriteArrayItem(n_split);
+}
+
+void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps,
+ Array<Optional<Integer>>* lengths) const {
+ // Make sure src_step_id is within the range of transform_steps.
+ CHECK_LT(src_step_id, transform_steps.size());
+ auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+ CHECK(ps != nullptr);
+
+ // Make sure the size of ps->lengths is not smaller than n_split-1.
+ // Note that the number of actual splitting factors of src_step is ps->lengths.size()+1.
+ CHECK_LE(n_split, ps->lengths.size() + 1);
+ CHECK(ps != nullptr);
+
+ lengths->reserve(n_split);
+ int j = 0;
+ // Get the first (n_split-1) split factors of followed src_step.
+ for (; j < n_split - 1; ++j) {
+ lengths->push_back(ps->lengths[j]);
+ }
+
+ // Get the last split factor of src_step for splitting level if n_split is smaller than
+ // ps->lengths.size()+1.
+ PrimExpr last_factor = 1;
+ for (; j < static_cast<int>(ps->lengths.size()); ++j) {
+ if (ps->lengths[j]) {
+ last_factor *= ps->lengths[j].value();
+ } else {
+ last_factor = PrimExpr();
+ break;
+ }
+ }
+ if (last_factor.defined()) {
+ lengths->push_back(Downcast<Integer>(last_factor));
+ } else {
+ lengths->push_back(NullOpt);
+ }
+}
+
+FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) {
+ auto node = make_object<FollowSplitStepNode>();
+ bool s;
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->stage_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->iter_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->src_step_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->n_split);
+
+ data_ = std::move(node);
+}
+
+Array<Iterator> FollowSplitStepNode::ApplyToState(State* state) const {
+ Array<Optional<Integer>> lengths;
+ ExtractSplitLengths((*state)->transform_steps, &lengths);
+ return ApplySplitToState(state, stage_id, iter_id, lengths, true);
+}
+
+Array<IterVar> FollowSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const {
+ Array<Optional<Integer>> lengths;
+ ExtractSplitLengths(transform_steps, &lengths);
+ return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+}
+
+String FollowSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const {
+ Array<Optional<Integer>> lengths;
+ ExtractSplitLengths(transform_steps, &lengths);
+ return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+}
+
+/********** Follow Fused Split **********/
+FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id,
+ const Array<Integer>& src_step_ids, int level,
+ bool factor_or_nparts) {
+ auto node = make_object<FollowFusedSplitStepNode>();
+ node->stage_id = stage_id;
+ node->iter_id = iter_id;
+ node->src_step_ids = src_step_ids;
+ node->level = level;
+ node->factor_or_nparts = factor_or_nparts;
+ data_ = std::move(node);
+}
+
+FollowFusedSplitStep::FollowFusedSplitStep(dmlc::JSONReader* reader) {
+ auto node = make_object<FollowFusedSplitStepNode>();
+ bool s;
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->stage_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->iter_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ std::vector<int> int_list;
+ reader->Read(&int_list);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->level);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->factor_or_nparts);
+
+ ::tvm::Array<::tvm::Integer> src_step_ids;
+ for (const auto& i : int_list) {
+ src_step_ids.push_back(i);
+ }
+ node->src_step_ids = src_step_ids;
+ data_ = std::move(node);
+}
+
+void FollowFusedSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+ writer->WriteArraySeperator();
+ writer->WriteString(record_prefix_str);
+ writer->WriteArrayItem(stage_id);
+ writer->WriteArrayItem(iter_id);
+ writer->WriteArrayItem(IntArrayToVector(src_step_ids));
+ writer->WriteArrayItem(level);
+ writer->WriteArrayItem(static_cast<int>(factor_or_nparts));
+}
+
+Optional<Integer> FollowFusedSplitStepNode::ExtractSplitLength(
+ const Array<Step>& transform_steps) const {
+ PrimExpr ret(1);
+
+ for (int src_step_id : src_step_ids) {
+ // Make sure the src_step_id is within the range of transform_steps.
+ CHECK_LT(src_step_id, transform_steps.size());
+ auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+ CHECK(ps != nullptr);
+ // Multiply the splitting factor at the corresponding splitting level of src_steps.
+ if (ps->lengths[level] && ret.defined()) {
+ ret *= ps->lengths[level].value();
+ } else {
+ return NullOpt;
+ }
+ }
+ return Downcast<Integer>(ret);
+}
+
+Array<Iterator> FollowFusedSplitStepNode::ApplyToState(State* state) const {
+ const Optional<Integer>& length = ExtractSplitLength((*state)->transform_steps);
+ return ApplySplitToState(state, stage_id, iter_id, {length}, factor_or_nparts);
+}
+
+Array<IterVar> FollowFusedSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const {
+ const Optional<Integer>& length = ExtractSplitLength(transform_steps);
+ return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, {length}, factor_or_nparts);
+}
+
+String FollowFusedSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const {
+ const Optional<Integer>& length = ExtractSplitLength(transform_steps);
+ return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, {length},
+ factor_or_nparts);
+}
+
/********** Steps working on multiple stages **********/
/********** Compute At **********/
diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py
index 8282d4a..e35dfe3 100644
--- a/tests/python/unittest/test_auto_scheduler_loop_state.py
+++ b/tests/python/unittest/test_auto_scheduler_loop_state.py
@@ -85,7 +85,6 @@ def test_split_fuse_reorder_annotation():
assert res == s1[C].iters[5]
assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["vectorize"]
-
def test_compute_at_root_inline():
dag = auto_scheduler.ComputeDAG(conv2d_nchw_bn_relu(N=1, H=224, W=224, CI=3, CO=64,
kernel_size=7, strides=2, padding=3))
@@ -142,7 +141,6 @@ def test_compute_at_root_inline():
assert s0[conv].iters[5].range.extent == 7
assert s0[conv].iters[6].range.extent == 7
-
def test_cache_read_write():
N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, (
1, 1), (1, 1)
@@ -417,7 +415,47 @@ def test_cache_read_write():
for it0, it1 in zip(s0[kernel_split].iters, s0[kernel_split_global].iters):
assert it0.range == it1.range
+def test_follow_split_follow_fused_split():
+ A, B, C = matmul_auto_scheduler_test(512, 512, 512)
+ dag = auto_scheduler.ComputeDAG([A, B, C])
+ s0 = dag.get_init_state()
+
+ C_global = s0.cache_write(C, "global")
+ its0 = s0.split(C, s0[C].iters[0], [4, 2, 8, 4], True)
+ split_step0 = len(s0.transform_steps) - 1
+ for level in range(1, 6):
+ tmp = s0.copy()
+ tmp.follow_split(C_global, tmp[C_global].iters[0], split_step0, level)
+ for i in range(0, level):
+ assert tmp[C].iters[i].range.extent == \
+ tmp[C_global].iters[i].range.extent
+
+ its1 = s0.split(C, s0[C].iters[5], [2, 2, 4, 8])
+ split_step1 = len(s0.transform_steps) - 1
+ its = []
+ for i0, i1 in zip(its0, its1):
+ its.append(i0)
+ its.append(i1)
+ s0.reorder(C, its)
+ for i in range(0, 5):
+ s0.fuse(C, [s0[C].iters[i], s0[C].iters[i + 1]])
+
+ for level in range(0, 4):
+ tmp = s0.copy()
+ tmp.follow_fused_split(C_global, tmp[C_global].iters[0],
+ [split_step0, split_step1], level, False)
+ assert tmp[C].iters[level + 1].range.extent == \
+ tmp[C_global].iters[0].range.extent
+
+ for level in range(0, 4):
+ tmp = s0.copy()
+ tmp.follow_fused_split(C_global, tmp[C_global].iters[0],
+ [split_step0, split_step1], level, True)
+ assert tmp[C].iters[level + 1].range.extent == \
+ tmp[C_global].iters[1].range.extent
+
if __name__ == "__main__":
test_split_fuse_reorder_annotation()
test_compute_at_root_inline()
test_cache_read_write()
+ test_follow_split_follow_fused_split()
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py
index 5f2f87a..39d01e0 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -22,8 +22,7 @@ import topi
from tvm import te, auto_scheduler
import tempfile
-from test_auto_scheduler_common import get_tiled_matmul
-
+from test_auto_scheduler_common import matmul_auto_scheduler_test, get_tiled_matmul
def test_record():
if not tvm.runtime.enabled("llvm"):
@@ -37,8 +36,12 @@ def test_record():
k = te.reduce_axis((0, 512), name='k')
E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E')
F = topi.nn.relu(E)
+ k = te.reduce_axis((0, 512), name='k')
+ G = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * F[k][j], axis=[k]), name='G')
+ H = topi.nn.relu(G)
+ I = topi.nn.relu(H)
- dag = auto_scheduler.ComputeDAG([A, B, F])
+ dag = auto_scheduler.ComputeDAG([A, B, I])
s = dag.get_init_state()
# Split
@@ -71,6 +74,22 @@ def test_record():
s.compute_at(D_global, E, s[E].iters[2])
# Cache Write
s.cache_write(D, "shared")
+ #follow_split
+ its2 = s.split(G, s[G].iters[0], [4, 2, 8, 4], True)
+ split_step0 = len(s.transform_steps) - 1
+ s.follow_split(G, s[G].iters[5], split_step0, 4)
+ #follow_fused_split
+ its2 = s.split(H, s[H].iters[0], [4, 2, 8, 4], True)
+ split_step1 = len(s.transform_steps) - 1
+ its3 = s.split(H, s[H].iters[5], [2, 4, 2, 4], True)
+ split_step2 = len(s.transform_steps) - 1
+ its = []
+ for i0, i1 in zip(its2, its3):
+ its.append(i0)
+ its.append(i1)
+ for i in range(0, 5):
+ s.fuse(H, [s[H].iters[i], s[H].iters[i + 1]])
+ s.follow_fused_split(I, s[I].iters[0], [split_step1, split_step2], 0, False)
target = tvm.target.create("llvm")
task = auto_scheduler.SearchTask(dag, "test", target)