You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/23 21:01:34 UTC

[GitHub] [incubator-tvm] jroesch commented on a change in pull request #6107: [Ansor][AutoTVM v2.0] Phase 1: Add cache_read/cache_write steps

jroesch commented on a change in pull request #6107:
URL: https://github.com/apache/incubator-tvm/pull/6107#discussion_r459719321



##########
File path: src/auto_scheduler/loop_state.h
##########
@@ -358,19 +369,43 @@ class State : public ObjectRef {
   void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
   /*!
    * \brief Schedule primitive corresponds to te.compute_inline.
-   * \param stage_id The index of the stage to be reordered.
+   * \param stage_id The index of the stage to be compute inlined.
    */
   void compute_inline(int stage_id);
   /*!
    * \brief Schedule primitive corresponds to te.compute_root.
-   * \param stage_id The index of the stage to be reordered.
+   * \param stage_id The index of the stage to be compute root.

Review comment:
       ```suggestion
      * \param stage_id The index of the stage to be the compute root.
   ```

##########
File path: python/tvm/auto_scheduler/loop_state.py
##########
@@ -351,6 +351,72 @@ def compute_root(self, stage):
         self.state_object = _ffi_api.StateComputeRoot(self.state_object,
                                                       self._resolve_stage_id(stage))
 
+    def cache_read(self, stage, scope_name, reader_stages):
+        """ Schedule primitive corresponds to te.schedule.cache_read.
+

Review comment:
       Can you explain what this step does?

##########
File path: python/tvm/auto_scheduler/loop_state.py
##########
@@ -351,6 +351,72 @@ def compute_root(self, stage):
         self.state_object = _ffi_api.StateComputeRoot(self.state_object,
                                                       self._resolve_stage_id(stage))
 
+    def cache_read(self, stage, scope_name, reader_stages):
+        """ Schedule primitive corresponds to te.schedule.cache_read.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be cache read, which can be specified by the integer index, Operation,
+            or output tensor of the stage.
+        scope_name : str
+            The scope name to be set for the new added read stage.
+        reader_stages : List[Union[int, Operation, Tensor]]
+            The reader stages. Each of the list can be specified by the integer index, Operation,
+            or output tensor of the stage.
+
+        Returns
+        -------
+        new_stage_op : Operator
+            The Operator of the new added stage.
+
+        Notes
+        -----
+        Cache read step will insert an extra stage to the original ComputeDAG (at the back of the
+        target stage).
+        """
+        reader_stage_ids = [self._resolve_stage_id(i) for i in reader_stages]
+        self.state_object, new_stage_id = _ffi_api.StateCacheRead(self.state_object,
+                                                                  self._resolve_stage_id(stage),
+                                                                  scope_name, reader_stage_ids,
+                                                                  self.compute_dag)
+        # Add a new stage will change all ops behind the added stage. But we still want to keep the
+        # original ops map, apply stage id offset to stage_id_map to make them work.
+        self._apply_stage_id_offset(int(new_stage_id))
+        self._update_stage_id_map()
+        return self.stages[int(new_stage_id)].op
+
+    def cache_write(self, stage, scope_name):
+        """ Schedule primitive corresponds to te.schedule.cache_write.
+

Review comment:
       Same as above: can we provide more documentation on what this does, given that these are user-visible APIs?

##########
File path: src/auto_scheduler/loop_state.h
##########
@@ -358,19 +369,43 @@ class State : public ObjectRef {
   void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
   /*!
    * \brief Schedule primitive corresponds to te.compute_inline.
-   * \param stage_id The index of the stage to be reordered.
+   * \param stage_id The index of the stage to be compute inlined.
    */
   void compute_inline(int stage_id);
   /*!
    * \brief Schedule primitive corresponds to te.compute_root.
-   * \param stage_id The index of the stage to be reordered.
+   * \param stage_id The index of the stage to be compute root.
    * \note After compute_root, we need careful dependency analysis to compute the accurate bound
    * information. However, it is relatively expensive and complicated, so we just fill "None" as
    * bound for the newly created iterators.
    * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
    */
   void compute_root(int stage_id);
 
+  /********** Step APIs adding new stages **********/
+
+  /*!
+   * \brief Schedule primitive corresponds to te.schedule.cache_read.
+   * \param stage_id The index of the stage to be cache read.
+   * \param scope_name The scope name to be set for the new added read stage.
+   * \param reader_stage_ids The indexes of reader stages.

Review comment:
       ```suggestion
      * \param reader_stage_ids The indices of read stages.
   ```

##########
File path: tests/python/unittest/test_auto_scheduler_loop_state.py
##########
@@ -143,6 +143,282 @@ def test_compute_at_root_inline():
     assert s0[conv].iters[6].range.extent == 7
 
 
+def test_cache_read_write():
+    N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, (
+        1, 1), (1, 1)
+
+    data = te.placeholder((N, CI, H, W), name='Data')
+    kernel_data = te.placeholder((CO, CI, KH, KW), name='Kernel_data')
+    k0, k1 = te.compute(kernel_data.shape,
+                        lambda *i: (kernel_data(*i)+1, kernel_data(*i)/2),
+                        name='Kernel_split')
+    kernel = te.compute(kernel_data.shape,
+                        lambda *i: k0(*i) + k1(*i),
+                        name='Kernel')
+    conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1)
+    relu = topi.nn.relu(conv)
+    add = topi.add(data, relu)
+
+    dag = auto_scheduler.ComputeDAG([data, kernel_data, add])
+    s0 = dag.get_init_state()
+
+    pad_temp = s0.stage_ops[1]
+    kernel_split = s0.stage_ops[3]
+
+    # 0: init state
+    ori_its = s0[add].iters
+    its = s0.split(add, s0[add].iters[0], [2])
+    s0.reorder(add, [its[0], ori_its[1], its[1], ori_its[2], ori_its[3]])
+    s0.compute_inline(relu)
+
+    # 1: simple cache_write with compute_at
+    conv_global = s0.cache_write(conv, "global")
+    s0.compute_at(conv_global, conv, s0[conv].iters[3])
+
+    # 2: simple cache_read with compute_at
+    kernel_global = s0.cache_read(kernel, "global", [conv_global])
+    s0.compute_at(kernel_global, conv_global, s0[conv_global].iters[4])
+    """
+        Placeholder: Data, Kernel_data
+        for i0 (0,4)
+          for i1 (0,512)
+            for i2 (0,9)
+              for i3 (0,9)
+                pad_temp = ...
+        for i0 (0,512)
+          for i1 (0,512)
+            for i2 (0,3)
+              for i3 (0,3)
+                Kernel_split = ...
+        for i0 (0,512)
+          for i1 (0,512)
+            for i2 (0,3)
+              for i3 (0,3)
+                Kernel = ...
+        for nn (0,4)
+          for ff (0,512)
+            for yy (0,7)
+              for xx (0,7)
+                for nn_c (None)
+                  for ff_c (None)
+                    for yy_c (None)
+                      for xx_c (None)
+                        for rc (None)
+                          for ax0 (None)
+                            for ax1 (None)
+                              for ax2 (None)
+                                for ax3 (None)
+                                  Kernel.global = ...
+                          for ry (None)
+                            for rx (None)
+                              compute.global = ...
+                compute = ...
+        for ax0.0 (0,2)
+          for ax1 (0,512)
+            for ax0.1 (0,2)
+              for ax2 (0,7)
+                for ax3 (0,7)
+                  T_add = ...
+    """
+    s1 = dag.infer_bound_from_state(s0)
+    assert s1[conv].iters[0].range.extent == 4
+    assert s1[conv].iters[1].range.extent == 512
+    assert s1[conv].iters[2].range.extent == 7
+    assert s1[conv].iters[3].range.extent == 7
+    assert s1[kernel_global].iters[0].range.extent == 1
+    assert s1[kernel_global].iters[1].range.extent == 1
+    assert s1[kernel_global].iters[2].range.extent == 3
+    assert s1[kernel_global].iters[3].range.extent == 3
+    assert s1[conv_global].iters[0].range.extent == 1
+    assert s1[conv_global].iters[1].range.extent == 1
+    assert s1[conv_global].iters[2].range.extent == 1
+    assert s1[conv_global].iters[3].range.extent == 1
+    assert s1[conv_global].iters[4].range.extent == 512
+    assert s1[conv_global].iters[5].range.extent == 3
+    assert s1[conv_global].iters[6].range.extent == 3
+
+    # 3: two level cache_read with compute_at
+    #    preparing for GPU's shared memory & local memory
+    pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global])
+    pad_temp_shared = s0.cache_read(pad_temp_global, "shared", [conv_global])
+    s0.compute_at(pad_temp_global, conv_global, s0[conv_global].iters[2])
+    s0.compute_at(pad_temp_shared, conv_global, s0[conv_global].iters[4])
+
+    # 4: cache_read with multi readers
+    #    This stage cannot be compute at to its consumer
+    s0.cache_read(data, "global", [pad_temp, add])
+    """
+        Placeholder: Data, Kernel_data
+        for ax0 (0,4)
+          for ax1 (0,512)
+            for ax2 (0,7)
+              for ax3 (0,7)
+                Data.global = ...
+        for i0 (0,4)
+          for i1 (0,512)
+            for i2 (0,9)
+              for i3 (0,9)
+                pad_temp = ...
+        for i0 (0,512)
+          for i1 (0,512)
+            for i2 (0,3)
+              for i3 (0,3)
+                Kernel_split = ...
+        for i0 (0,512)
+          for i1 (0,512)
+            for i2 (0,3)
+              for i3 (0,3)
+                Kernel = ...
+        for nn (0,4)
+          for ff (0,512)
+            for yy (0,7)
+              for xx (0,7)
+                for nn_c (None)
+                  for ff_c (None)
+                    for yy_c (None)
+                      for ax0 (None)
+                        for ax1 (None)
+                          for ax2 (None)
+                            for ax3 (None)
+                              pad_temp.global = ...
+                      for xx_c (None)
+                        for rc (None)
+                          for ax0 (None)
+                            for ax1 (None)
+                              for ax2 (None)
+                                for ax3 (None)
+                                  Kernel.global = ...
+                          for ax0 (None)
+                            for ax1 (None)
+                              for ax2 (None)
+                                for ax3 (None)
+                                  pad_temp.global.shared = ...
+                          for ry (None)
+                            for rx (None)
+                              compute.global = ...
+                compute = ...
+        for ax0.0 (0,2)
+          for ax1 (0,512)
+            for ax0.1 (0,2)
+              for ax2 (0,7)
+                for ax3 (0,7)
+                  T_add = ...
+    """
+    s1 = dag.infer_bound_from_state(s0)
+    assert s1[conv].iters[0].range.extent == 4
+    assert s1[conv].iters[1].range.extent == 512
+    assert s1[conv].iters[2].range.extent == 7
+    assert s1[conv].iters[3].range.extent == 7
+    assert s1[kernel_global].iters[0].range.extent == 1
+    assert s1[kernel_global].iters[1].range.extent == 1
+    assert s1[kernel_global].iters[2].range.extent == 3
+    assert s1[kernel_global].iters[3].range.extent == 3
+    assert s1[conv_global].iters[0].range.extent == 1
+    assert s1[conv_global].iters[1].range.extent == 1
+    assert s1[conv_global].iters[2].range.extent == 1
+    assert s1[conv_global].iters[3].range.extent == 1
+    assert s1[conv_global].iters[4].range.extent == 512
+    assert s1[conv_global].iters[5].range.extent == 3
+    assert s1[conv_global].iters[6].range.extent == 3
+    assert s1[pad_temp_global].iters[0].range.extent == 1
+    assert s1[pad_temp_global].iters[1].range.extent == 512
+    assert s1[pad_temp_global].iters[2].range.extent == 3
+    assert s1[pad_temp_global].iters[3].range.extent == 3
+    assert s1[pad_temp_shared].iters[0].range.extent == 1
+    assert s1[pad_temp_shared].iters[1].range.extent == 1
+    assert s1[pad_temp_shared].iters[2].range.extent == 3
+    assert s1[pad_temp_shared].iters[3].range.extent == 3
+
+    # 5: cache_write with multi outputs
+    # TVM's cache_write actually has a bug with this case:
+    #
+    # After schedule.cache_write, TVM generate one new stage:
+    #   From: kernel_data -> kernel_split -> kernel
+    #   To:   kernel_data -> kernel_split_global -> kernel_split -> kernel
+    #
+    # But with topo sort analyse, we get:
+    #  //   kernel_data -> kernel_split_global -> kernel_split -> kernel
+    #         \                                                /
+    #          ----------------> kernel_split ---------------->
+    #
+    # Seems there's bug with the input/output tensor. Such multi outputs case
+    # should be unusual, so we make some hack on DoCacheWrite
+    # To be fixed in the future

Review comment:
       Same comment as above: we shouldn't use TODOs here; instead, we should track these in an Ansor stabilization PR.

##########
File path: src/auto_scheduler/transform_step.h
##########
@@ -659,6 +671,153 @@ class ComputeRootStep : public Step {
   TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode);
 };
 
+/********** Primitives adding new stages **********/
+
+/*!
+ * \brief Cache read step that corresponds to te::Schedule::cache_read.
+ * \note Cache read step will add an extra stage to the original ComputeDAG, a up-to-date ComputeDAG

Review comment:
       Can we clarify what `up-to-date` means? Given the different methods, it is not clear what "up-to-date" means: do we always replay the steps to freshen the graph before storing it back in `current_compute_dag`?

##########
File path: src/auto_scheduler/loop_state.h
##########
@@ -225,6 +238,13 @@ class StateNode : public Object {
    * operation.
    */
   AttachMap attach_map;
+  /*!
+   * \brief The up-to-date ComputeDAG of this state, used for some steps that may change the

Review comment:
       Can you explain this better? Given the above methods, it seems that `current_compute_dag` might in fact not be up-to-date, since some scheduling steps modify the compute DAG.

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -923,5 +958,272 @@ String ComputeRootStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   return ss.str();
 }
 
+/********** Primitives adding new stages **********/
+
+/*!
+ * \brief Common part for steps that add new stages(e.g. CacheReadStep, CacheWriteStep,
+ * RfactorStep). This will filter out all steps that can change the stages of ComputeDAG.
+ */
+Array<Step> GetStageModifiableSteps(Step current_step, const Array<Step>& transform_steps) {
+  Array<Step> ret_steps;
+  for (const Step& step : transform_steps) {
+    if (step->IsInstance<CacheWriteStepNode>() || step->IsInstance<CacheReadStepNode>()) {
+      ret_steps.push_back(step);
+    }
+    // TODO(jcf94): add rfactor support
+    if (step.same_as(current_step)) {
+      break;
+    }
+  }
+  return ret_steps;
+}
+
+/********** Cache Read **********/
+CacheReadStep::CacheReadStep(int stage_id, String scope_name,
+                             const Array<Integer>& reader_stage_ids) {
+  auto node = make_object<CacheReadStepNode>();
+  node->stage_id = stage_id;
+  node->scope_name = std::move(scope_name);
+  node->reader_stage_ids = reader_stage_ids;
+  data_ = std::move(node);
+}
+
+CacheReadStep::CacheReadStep(dmlc::JSONReader* reader) {
+  auto node = make_object<CacheReadStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::string string_value;
+  reader->Read(&string_value);
+  node->scope_name = std::move(string_value);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::vector<int> int_list;
+  reader->Read(&int_list);
+  Array<Integer> reader_stage_ids;
+  for (int i : int_list) {
+    reader_stage_ids.push_back(i);
+  }
+  node->reader_stage_ids = std::move(reader_stage_ids);
+  data_ = std::move(node);
+}
+
+void CacheReadStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArraySeperator();
+  writer->WriteString(scope_name);
+  writer->WriteArrayItem(IntArrayToVector(reader_stage_ids));
+}
+
+int CacheReadStepNode::ApplyToState(State* state, const ComputeDAG& dag) const {
+  StateNode* pstate = state->CopyOnWrite();
+  const ComputeDAG& current_compute_dag =
+      dag.ReplayAndGetDAG(GetStageModifiableSteps(GetRef<Step>(this), (*state)->transform_steps));
+
+  // target_stage -> target_stage + target_store
+  // Update the op of the target stage, insert a new cache read stage behind, update the op of
+  // later stages, then update the stage_id mapping in AttachMap
+  int added_stage_id = stage_id + 1;
+  Stage tmp_stage = pstate->stages[stage_id];
+  tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[stage_id];
+  pstate->stages.Set(stage_id, std::move(tmp_stage));
+  pstate->stages.insert(pstate->stages.begin() + added_stage_id,
+                        Stage(current_compute_dag->ops[added_stage_id]));
+  for (size_t i = added_stage_id + 1; i < pstate->stages.size(); ++i) {
+    tmp_stage = pstate->stages[i];
+    tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[i];
+    pstate->stages.Set(i, std::move(tmp_stage));
+  }
+  pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(added_stage_id);
+  pstate->current_compute_dag = std::move(current_compute_dag);
+
+  return added_stage_id;
+}
+
+te::Tensor CacheReadStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                              StageToAxesMap* stage_to_axes,
+                                              te::Schedule* schedule) const {
+  const te::Stage& stage = (*stages)[stage_id];
+  Array<te::Operation> readers;
+  for (const auto& i : reader_stage_ids) {
+    readers.push_back((*stages)[i]->origin_op);
+  }
+  auto out = schedule->cache_read(stage->origin_op.output(0), scope_name, readers);
+
+  const auto& new_stage = (*schedule)[out->op];
+  UpdateStageToAxesMap(new_stage, stage_to_axes);
+  stages->insert(stages->begin() + stage_id + 1, new_stage);
+
+  return out;
+}
+
+String CacheReadStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                           te::Schedule* schedule) const {
+  std::stringstream ss;
+  // Since the original stage will be changed after schedule apply, keep a copy here
+  // These information will be used to print Python API string later
+  auto stage = (*stages)[stage_id];
+  Array<te::Stage> reader_stages;
+  for (size_t i = 0; i < reader_stage_ids.size(); ++i) {
+    reader_stages.push_back((*stages)[reader_stage_ids[i]]);
+  }
+  auto out = ApplyToSchedule(stages, stage_to_axes, schedule);
+
+  ss << CleanName(out->op->name) << " = "
+     << "s.cache_read(" << CleanName(stage->op->name) << ", \"" << scope_name << "\", ["
+     << CleanName(reader_stages[0]->op->name);
+  for (size_t i = 1; i < reader_stage_ids.size(); ++i) {
+    ss << ", " << CleanName(reader_stages[i]->op->name);
+  }
+  ss << "])\n";
+
+  // Print the iterators of the new added stage
+  const auto& iters = out->op->root_iter_vars();
+  for (size_t i = 0; i < iters.size(); ++i) {
+    ss << CleanName(iters[i]->var->name_hint);
+    if (i != iters.size() - 1) {
+      ss << ", ";
+    }
+  }
+  ss << " = "
+     << "tuple(" << CleanName(out->op->name) << ".op.axis)\n";
+
+  return ss.str();
+}
+
+/********** Cache Write **********/
+CacheWriteStep::CacheWriteStep(int stage_id, String scope_name) {
+  auto node = make_object<CacheWriteStepNode>();
+  node->stage_id = stage_id;
+  node->scope_name = std::move(scope_name);
+  data_ = std::move(node);
+}
+
+CacheWriteStep::CacheWriteStep(dmlc::JSONReader* reader) {
+  auto node = make_object<CacheWriteStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::string string_value;
+  reader->Read(&string_value);
+  node->scope_name = std::move(string_value);
+  data_ = std::move(node);
+}
+
+void CacheWriteStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArraySeperator();
+  writer->WriteString(scope_name);
+}
+
+int CacheWriteStepNode::ApplyToState(State* state, const ComputeDAG& dag) const {
+  StateNode* pstate = state->CopyOnWrite();
+  int last_dag_op_size = pstate->current_compute_dag
+                             ? pstate->current_compute_dag.value().as<ComputeDAGNode>()->ops.size()
+                             : dag->ops.size();
+  const ComputeDAG& current_compute_dag =
+      dag.ReplayAndGetDAG(GetStageModifiableSteps(GetRef<Step>(this), (*state)->transform_steps));
+  int added_ops = current_compute_dag->ops.size() - last_dag_op_size;
+  // TODO(jcf94): Update this check to equal after fixing the cache write bug in TVM

Review comment:
       Can we track all of these in an Ansor tracking issue instead of putting TODOs in the code? My worry is that it is very easy to forget about all the bugs that must be resolved before stabilizing a new subsystem.

##########
File path: src/auto_scheduler/loop_state.h
##########
@@ -358,19 +369,43 @@ class State : public ObjectRef {
   void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
   /*!
    * \brief Schedule primitive corresponds to te.compute_inline.
-   * \param stage_id The index of the stage to be reordered.
+   * \param stage_id The index of the stage to be compute inlined.
    */
   void compute_inline(int stage_id);
   /*!
    * \brief Schedule primitive corresponds to te.compute_root.
-   * \param stage_id The index of the stage to be reordered.
+   * \param stage_id The index of the stage to be compute root.
    * \note After compute_root, we need careful dependency analysis to compute the accurate bound
    * information. However, it is relatively expensive and complicated, so we just fill "None" as
    * bound for the newly created iterators.
    * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
    */
   void compute_root(int stage_id);
 
+  /********** Step APIs adding new stages **********/
+
+  /*!
+   * \brief Schedule primitive corresponds to te.schedule.cache_read.
+   * \param stage_id The index of the stage to be cache read.
+   * \param scope_name The scope name to be set for the new added read stage.

Review comment:
       ```suggestion
      * \param scope_name The scope name of the new read stage which will be inserted.
   ```

##########
File path: src/auto_scheduler/loop_state.h
##########
@@ -358,19 +369,43 @@ class State : public ObjectRef {
   void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
   /*!
    * \brief Schedule primitive corresponds to te.compute_inline.
-   * \param stage_id The index of the stage to be reordered.
+   * \param stage_id The index of the stage to be compute inlined.

Review comment:
       ```suggestion
      * \param stage_id The index of the stage to be marked compute inlined.
   ```
   
   This doesn't parse; see the suggestion above.

##########
File path: src/auto_scheduler/loop_state.h
##########
@@ -347,7 +358,7 @@ class State : public ObjectRef {
 
   /*!
    * \brief Schedule primitive corresponds to te.compute_at.
-   * \param stage_id The index of the stage to be reordered.
+   * \param stage_id The index of the stage to be compute at.

Review comment:
       ```suggestion
      * \param stage_id The index of the stage to be computed at.
   ```

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -923,5 +958,272 @@ String ComputeRootStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   return ss.str();
 }
 
+/********** Primitives adding new stages **********/
+
+/*!
+ * \brief Common part for steps that add new stages(e.g. CacheReadStep, CacheWriteStep,

Review comment:
       Can you explain this more?

##########
File path: src/auto_scheduler/loop_state.h
##########
@@ -195,6 +197,17 @@ class AttachMap : public ObjectRef {
   void UpdateIters(const std::vector<IterKey>& original_iters,
                    const std::vector<IterKey>& new_iters);
 
+  /*!
+   * \brief Traverse through `stage_to_attach_iter` and `iter_to_attached_stages` map, add offset
+   * to stage indexes that are larger than the start_id. Used for steps that inserts net stages to

Review comment:
       ```suggestion
      * to stage indexes that are larger than the start_id. Used for steps that inserts net stages to
   ```
   What does "net stages" mean?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org