You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/06/29 03:14:16 UTC
[tvm] branch main updated: [MetaSchedule] Refactor MultiLevelTiling state to allow subclassing (#11931)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ae33f408c4 [MetaSchedule] Refactor MultiLevelTiling state to allow subclassing (#11931)
ae33f408c4 is described below
commit ae33f408c4658ad38a959489f3fd5d3b7fc37dfc
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Tue Jun 28 20:14:11 2022 -0700
[MetaSchedule] Refactor MultiLevelTiling state to allow subclassing (#11931)
This PR makes `State` in `MultiLevelTiling` inherit from `Object` to allow future subclassing of `State`. Making `State` an `Object` allows instances of `State` and its subclasses to be stored in `std::vector<State>`.
---
.../schedule_rule/multi_level_tiling.cc | 70 ++++++++++++----------
.../schedule_rule/multi_level_tiling.h | 25 ++++++--
.../multi_level_tiling_with_intrin.cc | 2 +-
3 files changed, 60 insertions(+), 37 deletions(-)
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index 07c5ddd7ae..28c1a0fdb6 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -61,6 +61,20 @@ using tir::IterVarType;
using tir::LoopRV;
using tir::Schedule;
+State::State(tir::Schedule sch, tir::BlockRV block_rv, Array<Array<tir::LoopRV>> tiles) {
+ ObjectPtr<StateNode> node = make_object<StateNode>();
+ node->sch = std::move(sch);
+ node->block_rv = std::move(block_rv);
+ node->tiles = std::move(tiles);
+ data_ = std::move(node);
+}
+
+State StateNode::Copy() const {
+ ObjectPtr<StateNode> node = make_object<StateNode>(*this);
+ node->sch = sch->Copy();
+ return State(node);
+}
+
// Do nothing; Inherited from ScheduleRuleNode
void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context) {
if (Optional<Integer> v = context->target.value()->GetAttr<Integer>("max_threads_per_block")) {
@@ -82,15 +96,15 @@ Array<Schedule> MultiLevelTilingNode::Apply(const Schedule& sch, const BlockRV&
Array<Schedule> results;
for (auto&& state : ApplySubRules({State(sch, block_rv)})) {
- results.push_back(std::move(state.sch));
+ results.push_back(std::move(state->sch));
}
return results;
}
std::vector<State> MultiLevelTilingNode::ApplySubRules(std::vector<State> states) {
- states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); });
- states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); });
- states = SubRule(std::move(states), [&](State state) { return AddReadReuse(state); });
+ states = SubRule(std::move(states), [&](State state) { return TileLoopNest(std::move(state)); });
+ states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(std::move(state)); });
+ states = SubRule(std::move(states), [&](State state) { return AddReadReuse(std::move(state)); });
return states;
}
@@ -102,53 +116,49 @@ std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
std::vector<int> levels = config.levels;
ReuseType req = config.req;
if (Optional<Array<Integer>> ann = tir::GetAnn<Array<Integer>>(
- state.sch->GetSRef(state.block_rv), "meta_schedule.write_cache_level")) {
+ state->sch->GetSRef(state->block_rv), "meta_schedule.write_cache_level")) {
req = ReuseType::kMustReuse;
levels = std::vector<int>(ann.value().begin(), ann.value().end());
}
std::vector<State> results;
if (req == ReuseType::kMayReuse) {
// Case 1. If the write cache is already there, we don't need to add another.
- Array<BlockRV> consumer_rvs = state.sch->GetConsumers(state.block_rv);
- if (consumer_rvs.size() == 1 && IsWriteCache(state.sch->GetSRef(consumer_rvs[0]))) {
+ Array<BlockRV> consumer_rvs = state->sch->GetConsumers(state->block_rv);
+ if (consumer_rvs.size() == 1 && IsWriteCache(state->sch->GetSRef(consumer_rvs[0]))) {
for (int level : levels) {
- State new_state = state;
- new_state.sch = state.sch->Copy();
- new_state.sch->Seed(state.sch->ForkSeed());
- const LoopRV& loop_rv = new_state.tiles[level - 1].back();
- new_state.sch->ReverseComputeAt(consumer_rvs[0], loop_rv, true);
+ State new_state = state->Copy();
+ const LoopRV& loop_rv = new_state->tiles[level - 1].back();
+ new_state->sch->ReverseComputeAt(consumer_rvs[0], loop_rv, true);
results.push_back(std::move(new_state));
}
results.push_back(state);
return results;
} else {
// Case 2. No write cache is added
- State new_state(/*sch=*/state.sch->Copy(), /*block_rv=*/state.block_rv);
- new_state.sch->Seed(state.sch->ForkSeed());
+ State new_state = state->Copy();
results.emplace_back(std::move(new_state));
}
}
// Case 3. Add one write cache
- BlockRV write_cache = state.sch->CacheWrite(/*block_rv=*/state.block_rv, /*read_buffer_index=*/0,
- /*storage_scope=*/config.scope);
+ BlockRV write_cache =
+ state->sch->CacheWrite(/*block_rv=*/state->block_rv, /*read_buffer_index=*/0,
+ /*storage_scope=*/config.scope);
for (int level : levels) {
- State new_state = state;
- new_state.sch = state.sch->Copy();
- new_state.sch->Seed(state.sch->ForkSeed());
- const LoopRV& loop_rv = new_state.tiles[level - 1].back();
- new_state.sch->ReverseComputeAt(write_cache, loop_rv, true);
+ State new_state = state->Copy();
+ const LoopRV& loop_rv = new_state->tiles[level - 1].back();
+ new_state->sch->ReverseComputeAt(write_cache, loop_rv, true);
results.push_back(std::move(new_state));
}
return results;
}
std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
- Schedule& sch = state.sch;
- const BlockRV& block_rv = state.block_rv;
+ Schedule& sch = state->sch;
+ const BlockRV& block_rv = state->block_rv;
// Step 1. Assuming trivial binding, pair the loops and their iter-var-types
Array<LoopRV> loops = sch->GetLoops(block_rv);
- std::vector<IterVarType> iter_types = GetBlockVarTypes(sch->GetSRef(state.block_rv));
+ std::vector<IterVarType> iter_types = GetBlockVarTypes(sch->GetSRef(state->block_rv));
ICHECK_EQ(loops.size(), iter_types.size());
// Step 2. For each loop axis, tile it
int64_t spatial_loop_product = 1;
@@ -192,7 +202,7 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
sch->Bind(fused, tile_binds[i]);
tiles[i] = {fused};
}
- state.tiles = Array<Array<LoopRV>>{tiles.begin(), tiles.end()};
+ state->tiles = Array<Array<LoopRV>>{tiles.begin(), tiles.end()};
if (this->thread_warp_size_ != -1) {
int64_t low_inclusive = 1;
int64_t high_inclusive = this->max_threads_per_block_;
@@ -213,13 +223,13 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
return {std::move(state)};
}
ICHECK(config.req != ReuseType::kMayReuse);
- const BlockRV& block_rv = state.block_rv;
+ const BlockRV& block_rv = state->block_rv;
std::vector<State> results;
results.reserve(config.levels.size());
for (int level : config.levels) {
- Schedule sch = state.sch->Copy();
- sch->Seed(state.sch->ForkSeed());
- const LoopRV& loop_rv = state.tiles[level - 1].back();
+ State new_state = state->Copy();
+ Schedule& sch = new_state->sch;
+ const LoopRV& loop_rv = state->tiles[level - 1].back();
// Enumerate all buffers that are read but not written
std::vector<int> read_buffer_ndims = tir::GetReadBufferNDims(sch->GetSRef(block_rv));
for (int i = 0, n_reads = read_buffer_ndims.size(); i < n_reads; ++i) {
@@ -246,8 +256,6 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
vector_load_len);
}
}
- State new_state = state;
- new_state.sch = sch;
results.push_back(std::move(new_state));
}
return results;
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h
index f260c4856e..05179318d0 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.h
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -81,8 +81,12 @@ struct ReuseConfig {
}
};
+// Forward declaration
+class State;
+
/*! \brief The state of auto scheduling for the multi-level tiling rule */
-struct State {
+class StateNode : public Object {
+ public:
/*! \brief The schedule to date */
tir::Schedule sch;
/*! \brief The block to be tiled */
@@ -90,11 +94,22 @@ struct State {
/*! \brief The loop tiles */
Array<Array<tir::LoopRV>> tiles;
+ /*!
+ * \brief Create a copy of the state. The underlying schedule is copied. Schedule rules that
+ * produce multiple states should use this method to create new states.
+ */
+ virtual State Copy() const;
+
+ static constexpr const char* _type_key = "meta_schedule.State";
+ TVM_DECLARE_BASE_OBJECT_INFO(StateNode, Object);
+};
+
+/*! \brief Managed reference to StateNode */
+class State : public ObjectRef {
+ public:
/*! \brief Default constructor */
- explicit State(tir::Schedule sch, tir::BlockRV block_rv,
- Optional<tir::BlockRV> write_cache = NullOpt, bool write_cache_is_added = false,
- Array<Array<tir::LoopRV>> tiles = {})
- : sch(sch), block_rv(block_rv), tiles(tiles) {}
+ explicit State(tir::Schedule sch, tir::BlockRV block_rv, Array<Array<tir::LoopRV>> tiles = {});
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
};
/*!
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
index da3ea2484e..9dd720db4a 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
@@ -45,7 +45,7 @@ class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode {
// tile the outerloops.
virtual std::vector<State> ApplySubRules(std::vector<State> states) {
states = SubRule(std::move(states), [&](State state) {
- state.block_rv = TileForIntrin(state.sch, state.block_rv, intrin_name);
+ state->block_rv = TileForIntrin(state->sch, state->block_rv, intrin_name);
return std::vector<State>(1, state);
});
return MultiLevelTilingNode::ApplySubRules(states);