You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/06/29 03:14:16 UTC

[tvm] branch main updated: [MetaSchedule] Refactor MultiLevelTiling state to allow subclassing (#11931)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ae33f408c4 [MetaSchedule] Refactor MultiLevelTiling state to allow subclassing (#11931)
ae33f408c4 is described below

commit ae33f408c4658ad38a959489f3fd5d3b7fc37dfc
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Tue Jun 28 20:14:11 2022 -0700

    [MetaSchedule] Refactor MultiLevelTiling state to allow subclassing (#11931)
    
    This PR makes `State` in `MultiLevelTiling` an `Object`-based type (`StateNode` inherits from `Object`, and `State` becomes its managed reference) to allow future subclassing of `State`. Because `State` is now a reference type, instances of `State` and its subclasses can be stored in `std::vector<State>`.
---
 .../schedule_rule/multi_level_tiling.cc            | 70 ++++++++++++----------
 .../schedule_rule/multi_level_tiling.h             | 25 ++++++--
 .../multi_level_tiling_with_intrin.cc              |  2 +-
 3 files changed, 60 insertions(+), 37 deletions(-)

diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index 07c5ddd7ae..28c1a0fdb6 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -61,6 +61,20 @@ using tir::IterVarType;
 using tir::LoopRV;
 using tir::Schedule;
 
+State::State(tir::Schedule sch, tir::BlockRV block_rv, Array<Array<tir::LoopRV>> tiles) {
+  ObjectPtr<StateNode> node = make_object<StateNode>();
+  node->sch = std::move(sch);
+  node->block_rv = std::move(block_rv);
+  node->tiles = std::move(tiles);
+  data_ = std::move(node);
+}
+
+State StateNode::Copy() const {
+  ObjectPtr<StateNode> node = make_object<StateNode>(*this);
+  node->sch = sch->Copy();
+  return State(node);
+}
+
 // Do nothing; Inherited from ScheduleRuleNode
 void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context) {
   if (Optional<Integer> v = context->target.value()->GetAttr<Integer>("max_threads_per_block")) {
@@ -82,15 +96,15 @@ Array<Schedule> MultiLevelTilingNode::Apply(const Schedule& sch, const BlockRV&
 
   Array<Schedule> results;
   for (auto&& state : ApplySubRules({State(sch, block_rv)})) {
-    results.push_back(std::move(state.sch));
+    results.push_back(std::move(state->sch));
   }
   return results;
 }
 
 std::vector<State> MultiLevelTilingNode::ApplySubRules(std::vector<State> states) {
-  states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); });
-  states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); });
-  states = SubRule(std::move(states), [&](State state) { return AddReadReuse(state); });
+  states = SubRule(std::move(states), [&](State state) { return TileLoopNest(std::move(state)); });
+  states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(std::move(state)); });
+  states = SubRule(std::move(states), [&](State state) { return AddReadReuse(std::move(state)); });
   return states;
 }
 
@@ -102,53 +116,49 @@ std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
   std::vector<int> levels = config.levels;
   ReuseType req = config.req;
   if (Optional<Array<Integer>> ann = tir::GetAnn<Array<Integer>>(
-          state.sch->GetSRef(state.block_rv), "meta_schedule.write_cache_level")) {
+          state->sch->GetSRef(state->block_rv), "meta_schedule.write_cache_level")) {
     req = ReuseType::kMustReuse;
     levels = std::vector<int>(ann.value().begin(), ann.value().end());
   }
   std::vector<State> results;
   if (req == ReuseType::kMayReuse) {
     // Case 1. If the write cache is already there, we don't need to add another.
-    Array<BlockRV> consumer_rvs = state.sch->GetConsumers(state.block_rv);
-    if (consumer_rvs.size() == 1 && IsWriteCache(state.sch->GetSRef(consumer_rvs[0]))) {
+    Array<BlockRV> consumer_rvs = state->sch->GetConsumers(state->block_rv);
+    if (consumer_rvs.size() == 1 && IsWriteCache(state->sch->GetSRef(consumer_rvs[0]))) {
       for (int level : levels) {
-        State new_state = state;
-        new_state.sch = state.sch->Copy();
-        new_state.sch->Seed(state.sch->ForkSeed());
-        const LoopRV& loop_rv = new_state.tiles[level - 1].back();
-        new_state.sch->ReverseComputeAt(consumer_rvs[0], loop_rv, true);
+        State new_state = state->Copy();
+        const LoopRV& loop_rv = new_state->tiles[level - 1].back();
+        new_state->sch->ReverseComputeAt(consumer_rvs[0], loop_rv, true);
         results.push_back(std::move(new_state));
       }
       results.push_back(state);
       return results;
     } else {
       // Case 2. No write cache is added
-      State new_state(/*sch=*/state.sch->Copy(), /*block_rv=*/state.block_rv);
-      new_state.sch->Seed(state.sch->ForkSeed());
+      State new_state = state->Copy();
       results.emplace_back(std::move(new_state));
     }
   }
 
   // Case 3. Add one write cache
-  BlockRV write_cache = state.sch->CacheWrite(/*block_rv=*/state.block_rv, /*read_buffer_index=*/0,
-                                              /*storage_scope=*/config.scope);
+  BlockRV write_cache =
+      state->sch->CacheWrite(/*block_rv=*/state->block_rv, /*read_buffer_index=*/0,
+                             /*storage_scope=*/config.scope);
   for (int level : levels) {
-    State new_state = state;
-    new_state.sch = state.sch->Copy();
-    new_state.sch->Seed(state.sch->ForkSeed());
-    const LoopRV& loop_rv = new_state.tiles[level - 1].back();
-    new_state.sch->ReverseComputeAt(write_cache, loop_rv, true);
+    State new_state = state->Copy();
+    const LoopRV& loop_rv = new_state->tiles[level - 1].back();
+    new_state->sch->ReverseComputeAt(write_cache, loop_rv, true);
     results.push_back(std::move(new_state));
   }
   return results;
 }
 
 std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
-  Schedule& sch = state.sch;
-  const BlockRV& block_rv = state.block_rv;
+  Schedule& sch = state->sch;
+  const BlockRV& block_rv = state->block_rv;
   // Step 1. Assuming trivial binding, pair the loops and their iter-var-types
   Array<LoopRV> loops = sch->GetLoops(block_rv);
-  std::vector<IterVarType> iter_types = GetBlockVarTypes(sch->GetSRef(state.block_rv));
+  std::vector<IterVarType> iter_types = GetBlockVarTypes(sch->GetSRef(state->block_rv));
   ICHECK_EQ(loops.size(), iter_types.size());
   // Step 2. For each loop axis, tile it
   int64_t spatial_loop_product = 1;
@@ -192,7 +202,7 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
     sch->Bind(fused, tile_binds[i]);
     tiles[i] = {fused};
   }
-  state.tiles = Array<Array<LoopRV>>{tiles.begin(), tiles.end()};
+  state->tiles = Array<Array<LoopRV>>{tiles.begin(), tiles.end()};
   if (this->thread_warp_size_ != -1) {
     int64_t low_inclusive = 1;
     int64_t high_inclusive = this->max_threads_per_block_;
@@ -213,13 +223,13 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
     return {std::move(state)};
   }
   ICHECK(config.req != ReuseType::kMayReuse);
-  const BlockRV& block_rv = state.block_rv;
+  const BlockRV& block_rv = state->block_rv;
   std::vector<State> results;
   results.reserve(config.levels.size());
   for (int level : config.levels) {
-    Schedule sch = state.sch->Copy();
-    sch->Seed(state.sch->ForkSeed());
-    const LoopRV& loop_rv = state.tiles[level - 1].back();
+    State new_state = state->Copy();
+    Schedule& sch = new_state->sch;
+    const LoopRV& loop_rv = state->tiles[level - 1].back();
     // Enumerate all buffers that are read but not written
     std::vector<int> read_buffer_ndims = tir::GetReadBufferNDims(sch->GetSRef(block_rv));
     for (int i = 0, n_reads = read_buffer_ndims.size(); i < n_reads; ++i) {
@@ -246,8 +256,6 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
                       vector_load_len);
       }
     }
-    State new_state = state;
-    new_state.sch = sch;
     results.push_back(std::move(new_state));
   }
   return results;
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h
index f260c4856e..05179318d0 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.h
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -81,8 +81,12 @@ struct ReuseConfig {
   }
 };
 
+// Forward declaration
+class State;
+
 /*! \brief The state of auto scheduling for the multi-level tiling rule */
-struct State {
+class StateNode : public Object {
+ public:
   /*! \brief The schedule to date */
   tir::Schedule sch;
   /*! \brief The block to be tiled */
@@ -90,11 +94,22 @@ struct State {
   /*! \brief The loop tiles */
   Array<Array<tir::LoopRV>> tiles;
 
+  /*!
+   * \brief Create a copy of the state. The underlying schedule is copied. Schedule rules that
+   * produce multiple states should use this method to create new states.
+   */
+  virtual State Copy() const;
+
+  static constexpr const char* _type_key = "meta_schedule.State";
+  TVM_DECLARE_BASE_OBJECT_INFO(StateNode, Object);
+};
+
+/*! \brief Managed reference to StateNode */
+class State : public ObjectRef {
+ public:
   /*! \brief Default constructor */
-  explicit State(tir::Schedule sch, tir::BlockRV block_rv,
-                 Optional<tir::BlockRV> write_cache = NullOpt, bool write_cache_is_added = false,
-                 Array<Array<tir::LoopRV>> tiles = {})
-      : sch(sch), block_rv(block_rv), tiles(tiles) {}
+  explicit State(tir::Schedule sch, tir::BlockRV block_rv, Array<Array<tir::LoopRV>> tiles = {});
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
 };
 
 /*!
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
index da3ea2484e..9dd720db4a 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
@@ -45,7 +45,7 @@ class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode {
   // tile the outerloops.
   virtual std::vector<State> ApplySubRules(std::vector<State> states) {
     states = SubRule(std::move(states), [&](State state) {
-      state.block_rv = TileForIntrin(state.sch, state.block_rv, intrin_name);
+      state->block_rv = TileForIntrin(state->sch, state->block_rv, intrin_name);
       return std::vector<State>(1, state);
     });
     return MultiLevelTilingNode::ApplySubRules(states);