Posted to commits@tvm.apache.org by lm...@apache.org on 2020/07/21 19:58:59 UTC

[incubator-tvm] branch master updated: [Ansor][AutoTVM v2.0] Phase 1: Add annotation/compute_at/compute_root/compute_inline steps (#6073)

This is an automated email from the ASF dual-hosted git repository.

lmzheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 71533a5  [Ansor][AutoTVM v2.0] Phase 1: Add annotation/compute_at/compute_root/compute_inline steps (#6073)
71533a5 is described below

commit 71533a5c252d9c501a20b906c9cc5e6471d3f686
Author: Chenfan <ch...@alibaba-inc.com>
AuthorDate: Wed Jul 22 03:58:48 2020 +0800

    [Ansor][AutoTVM v2.0] Phase 1: Add annotation/compute_at/compute_root/compute_inline steps (#6073)
    
    * Add annotation step
    
    * Add compute_at/compute_root/compute_inline
    
    * Doc update
    
    * Update
    
    * Update
    
    * Update measure record UT
    
    * Update
    
    * Update
    
    * Update
    
    * Move state implementation to step
    
    * Move measure_record implementation to step
    
    * Order update & API update
    
    * Update the order of state api
    
    * Update
---
 python/tvm/auto_scheduler/loop_state.py            | 228 ++++++-
 src/auto_scheduler/compute_dag.cc                  |  26 +-
 src/auto_scheduler/loop_state.cc                   | 359 +++++-----
 src/auto_scheduler/loop_state.h                    | 274 ++++----
 src/auto_scheduler/measure_record.cc               | 122 +---
 src/auto_scheduler/transform_step.cc               | 754 ++++++++++++++++++++-
 src/auto_scheduler/transform_step.h                | 503 +++++++++++++-
 src/auto_scheduler/utils.h                         |  32 +-
 .../python/unittest/test_auto_scheduler_common.py  |   2 +-
 .../unittest/test_auto_scheduler_loop_state.py     |  89 ++-
 .../python/unittest/test_auto_scheduler_measure.py |  42 +-
 11 files changed, 1914 insertions(+), 517 deletions(-)
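
For orientation before the diff: the new annotation steps compose with the
existing split/fuse/reorder on the Python side. Below is a minimal sketch;
the matmul workload, the stage index, and the ComputeDAG helpers
(get_init_state etc.) are assumptions based on the surrounding
auto_scheduler package, not part of this commit:

    from tvm import te, auto_scheduler

    # Hypothetical workload, for illustration only.
    A = te.placeholder((512, 512), name="A")
    B = te.placeholder((512, 512), name="B")
    k = te.reduce_axis((0, 512), name="k")
    C = te.compute((512, 512),
                   lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
                   name="C")

    dag = auto_scheduler.ComputeDAG([A, B, C])
    state = dag.get_init_state()

    # C is the last stage here; its iterators are i, j, k.
    i, j, kk = state.stages[2].iters

    # New in this commit: annotation steps, recorded in the state history.
    fused = state.fuse(C, [i, j])        # i and j are adjacent, so fusable
    fo, fi = state.split(C, fused, [32])
    state.parallel(C, fo)
    state.vectorize(C, fi)
    # For GPU targets one would bind threads instead, e.g.:
    #   state.bind(C, fo, "blockIdx.x"); state.bind(C, fi, "threadIdx.x")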

diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py
index 693a668..ab041cf 100644
--- a/python/tvm/auto_scheduler/loop_state.py
+++ b/python/tvm/auto_scheduler/loop_state.py
@@ -83,6 +83,24 @@ class State:
     -----
     This is a wrapper class of StateObject to deal with copy-on-write property
     """
+
+    # Static translation table for annotation names
+    # This is used to transform the annotation name to the C++ enum
+    ANNOTATION_TRANS_TABLE = {
+        "none": 0,
+        "unroll": 1,
+        "vectorize": 2,
+        "parallel": 3,
+        "vthread": 4,
+        "blockIdx.x": 5,
+        "threadIdx.x": 6,
+        "blockIdx.y": 7,
+        "threadIdx.y": 8,
+        "blockIdx.z": 9,
+        "threadIdx.z": 10,
+        "tensorize": 11
+    }
+
     def __init__(self, state_object, dag):
         self.state_object = state_object
         self.compute_dag = dag
@@ -108,20 +126,140 @@ class State:
         """
         return [stage.op for stage in self.stages]
 
+    def bind(self, stage, iterator, thread_name):
+        """ Schedule primitive corresponds to te.bind.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be bound, which can be specified by the integer index, Operation,
+            or output tensor of the stage.
+        iterator : Iterator
+            The iterator to be bound.
+        thread_name : str
+            The thread type to bind. Candidates:
+            - vthread
+            - blockIdx.x
+            - threadIdx.x
+            - blockIdx.y
+            - threadIdx.y
+            - blockIdx.z
+            - threadIdx.z
+
+        Returns
+        -------
+        res_it : Iterator
+            The bound Iterator.
+        """
+        if thread_name not in State.ANNOTATION_TRANS_TABLE:
+            raise ValueError("Invalid thread_name: " + thread_name)
+
+        self.state_object, res = _ffi_api.StateBind(self.state_object,
+                                                    self._resolve_stage_id(stage), iterator,
+                                                    State.ANNOTATION_TRANS_TABLE[thread_name])
+        return res
+
+    def parallel(self, stage, iterator):
+        """ Schedule primitive corresponds to te.parallel.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be parallelized, which can be specified by the integer index, Operation,
+            or output tensor of the stage.
+        iterator : Iterator
+            The iterator to be parallelized.
+
+        Returns
+        -------
+        res_it : Iterator
+            The parallelized Iterator.
+        """
+        self.state_object, res = _ffi_api.StateParallel(self.state_object,
+                                                        self._resolve_stage_id(stage), iterator)
+        return res
+
+    def unroll(self, stage, iterator, max_unroll=None):
+        """ Schedule primitive corresponds to te.unroll.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be unrolled, which can be specified by the integer index, Operation,
+            or output tensor of the stage.
+        iterator : Iterator
+            The iterator to be unrolled.
+        max_unroll : Optional[int]
+            The maximum unroll limit. An iterator with extent larger than this limit is skipped.
+
+        Returns
+        -------
+        res_it : Iterator
+            The unrolled Iterator.
+        """
+        self.state_object, res = _ffi_api.StateUnroll(self.state_object,
+                                                      self._resolve_stage_id(stage), iterator,
+                                                      max_unroll if max_unroll else -1)
+        return res
+
+    def vectorize(self, stage, iterator):
+        """ Schedule primitive corresponds to te.vectorize.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be vectorized, which can be specified by the integer index, Operation,
+            or output tensor of the stage.
+        iterator : Iterator
+            The iterator to be vectorized.
+
+        Returns
+        -------
+        res_it : Iterator
+            The vectorized Iterator.
+        """
+        self.state_object, res = _ffi_api.StateVectorize(self.state_object,
+                                                         self._resolve_stage_id(stage), iterator)
+        return res
+
+    def fuse(self, stage, iters):
+        """ Schedule primitive corresponds to te.fuse.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be fused, which can be specified by the integer index, Operation,
+            or output tensor of the stage.
+        iters : List[Iterator]
+            The iterators to be fused.
+
+        Returns
+        -------
+        res_it : Iterator
+            The fused Iterator.
+
+        Notes
+        -----
+        If the iterators to be fused have stages attached to them (by compute_at), the fused
+        result will become the new attach point.
+        """
+        self.state_object, res = _ffi_api.StateFuse(self.state_object,
+                                                    self._resolve_stage_id(stage), iters)
+        return res
+
     def reorder(self, stage, order):
         """ Schedule primitive corresponds to te.reorder.
 
         Parameters
         ----------
         stage : Union[int, Operation, Tensor]
-            The Stage to be reordered, can be a Stage order index, Stage operation or stage
-            output tensor.
+            The Stage to be reordered, which can be specified by the integer index, Operation,
+            or output tensor of the stage.
         order : List[Iterator]
             Iterators in the expected order.
         """
-        stage_id = self._resolve_stage_id(stage)
-
-        self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order)
+        self.state_object = _ffi_api.StateReorder(self.state_object, self._resolve_stage_id(stage),
+                                                  order)
 
     def split(self, stage, iterator, lengths, inner_to_outer=True):
         """ Schedule primitive corresponds to te.split.
@@ -132,8 +270,8 @@ class State:
         Parameters
         ----------
         stage : Union[int, Operation, Tensor]
-            The Stage to be split, can be a Stage order index, Stage operation or stage
-            output tensor.
+            The Stage to be split, which can be specified by the integer index, Operation,
+            or output tensor of the stage.
         iterator : Iterator
             The iterator to be split.
         lengths: List[int]
@@ -144,34 +282,74 @@ class State:
         Returns
         -------
         res_its : List[Iterator]
-            The splitted new Iterators
-        """
-        stage_id = self._resolve_stage_id(stage)
+            The new split Iterators.
 
-        self.state_object, res = _ffi_api.StateSplit(self.state_object, stage_id, iterator, lengths,
-                                                     inner_to_outer)
+        Notes
+        -----
+        If we split an iterator that has stages attached to it (by compute_at), the innermost
+        iterator of the split results will become the new attach point.
+        """
+        self.state_object, res = _ffi_api.StateSplit(self.state_object,
+                                                     self._resolve_stage_id(stage),
+                                                     iterator, lengths, inner_to_outer)
         return res
 
-    def fuse(self, stage, iters):
-        """ Schedule primitive corresponds to te.fuse.
+    def compute_at(self, stage, target_stage, target_iter):
+        """ Schedule primitive corresponds to te.compute_at.
 
         Parameters
         ----------
         stage : Union[int, Operation, Tensor]
-            The Stage to be fused, can be a Stage order index, Stage operation or stage
-            output tensor.
-        iters : List[Iterator]
-            The iterators to be fused
+            The source Stage of the compute_at, which can be specified by the integer index,
+            Operation, or output tensor of the stage.
+        target_stage : Union[int, Operation, Tensor]
+            The target stage of compute_at, which can be specified by the integer index, Operation,
+            or output tensor of the stage.
+        target_iter : Iterator
+            The target Iterator of compute_at.
+
+        Notes
+        -----
+        After compute_at, we need careful dependency analysis to compute the accurate bound
+        information. However, it is relatively expensive and complicated, so we just fill "None"
+        as bound for the newly created iterators.
+        Call ComputeDAG::InferBound on the returned state to get the complete bound information.
+        """
+        self.state_object = _ffi_api.StateComputeAt(self.state_object,
+                                                    self._resolve_stage_id(stage),
+                                                    self._resolve_stage_id(target_stage),
+                                                    target_iter)
 
-        Returns
-        -------
-        res_it : Iterator
-            The fused Iterator
+    def compute_inline(self, stage):
+        """ Schedule primitive corresponds to te.compute_inline.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be inlined, which can be specified by the integer index, Operation,
+            or output tensor of the stage.
         """
-        stage_id = self._resolve_stage_id(stage)
+        self.state_object = _ffi_api.StateComputeInline(self.state_object,
+                                                        self._resolve_stage_id(stage))
 
-        self.state_object, res = _ffi_api.StateFuse(self.state_object, stage_id, iters)
-        return res
+    def compute_root(self, stage):
+        """ Schedule primitive corresponds to te.compute_root.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be set as the compute root, which can be specified by the integer
+            index, Operation, or output tensor of the stage.
+
+        Notes
+        -----
+        After compute_root, we need careful dependency analysis to compute the accurate bound
+        information. However, it is relatively expensive and complicated, so we just fill "None"
+        as bound for the newly created iterators.
+        Call ComputeDAG::InferBound on the returned state to get the complete bound information.
+        """
+        self.state_object = _ffi_api.StateComputeRoot(self.state_object,
+                                                      self._resolve_stage_id(stage))
 
     def copy(self):
         """ Do deep copy of this State. """
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index a7abcb8..d81dff6 100644
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -270,19 +270,9 @@ std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
   }
 
   // Apply the history steps to TVM schedule
+  // Call each step's ApplyToSchedule method
   for (const auto& step : transform_steps) {
-    // Call each step's ApplyToSchedule method
-    // Note: some steps have extra parameters that must be passed and they may need different
-    // return value, so the ApplyToSchedule is not able to be merged to single interface
-    if (auto ps = step.as<ReorderStepNode>()) {
-      ps->ApplyToSchedule(stages, stage_to_axes);
-    } else if (auto ps = step.as<SplitStepNode>()) {
-      ps->ApplyToSchedule(stages, stage_to_axes);
-    } else if (auto ps = step.as<FuseStepNode>()) {
-      ps->ApplyToSchedule(stages, stage_to_axes);
-    } else {
-      LOG(FATAL) << "Invalid Step";
-    }
+    StepApplyToSchedule(step, stages, stage_to_axes);
   }
 
   return std::make_pair(schedule, operator->()->tensors);
@@ -326,15 +316,7 @@ String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const
   }
   // Call each step's PrintAsPythonAPI method
   for (const auto& step : transform_steps) {
-    if (auto ps = step.as<ReorderStepNode>()) {
-      ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes);
-    } else if (auto ps = step.as<SplitStepNode>()) {
-      ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes);
-    } else if (auto ps = step.as<FuseStepNode>()) {
-      ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes);
-    } else {
-      LOG(FATAL) << "Invalid Step";
-    }
+    ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes);
   }
 
   return ss.str();
@@ -352,7 +334,7 @@ State ComputeDAG::InferBound(const State& state) const {
     ret_state = operator->()->init_state;
     pstate = ret_state.CopyOnWrite();
     pstate->transform_steps = state->transform_steps;
-    ret_state.DoSteps(*this);
+    ret_state.ApplySteps(*this);
   } else {
     ret_state = state;
     pstate = ret_state.CopyOnWrite();
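
With the dispatch moved into StepApplyToSchedule / StepPrintAsPythonAPI,
ComputeDAG itself stays step-agnostic. The Python entry points are
unchanged; a sketch, assuming the phase-0 ComputeDAG bindings that are not
shown in this diff:

    # Lower the recorded transform steps onto a te.Schedule.
    sch, args = dag.apply_steps_from_state(state)

    # Or render the equivalent schedule as Python source for inspection.
    print(dag.print_python_code_from_state(state))

    # Replay the history on a fresh initial state and fill iterator bounds
    # (needed after compute_at/compute_root, whose new iterators get None).
    state = dag.infer_bound_from_state(state)
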
diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc
index 1bfcb9e..bfe5478 100644
--- a/src/auto_scheduler/loop_state.cc
+++ b/src/auto_scheduler/loop_state.cc
@@ -90,36 +90,122 @@ Stage::Stage(te::Operation op, StageKind op_type, const Array<Iterator>& iters,
   data_ = std::move(node);
 }
 
+/********** AttachMap **********/
+void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id) {
+  AttachMapNode* pnode = CopyOnWrite();
+
+  // Delete the current entry of this stage
+  DeleteStageEntry(pnode, stage_id);
+
+  // Store the new stage/iterator relations to map
+  IterKey iter_key(target_stage_id, target_iter_id);
+  pnode->stage_to_attach_iter[stage_id] = iter_key;
+  pnode->iter_to_attached_stages[iter_key].push_back(stage_id);
+}
+
+void AttachMap::DeleteStage(int stage_id) {
+  AttachMapNode* pnode = CopyOnWrite();
+  // Delete the original stage entry
+  DeleteStageEntry(pnode, stage_id);
+}
+
+void AttachMap::UpdateIters(const std::vector<IterKey>& original_iters,
+                            const std::vector<IterKey>& new_iters) {
+  CHECK_EQ(original_iters.size(), new_iters.size());
+  AttachMapNode* pnode = CopyOnWrite();
+  for (size_t i = 0; i < original_iters.size(); ++i) {
+    auto entry = pnode->iter_to_attached_stages.find(original_iters[i]);
+    // We get <IterKey, std::vector<StageKey>> from this map
+    if (entry == pnode->iter_to_attached_stages.end()) {
+      // Skip if this iterator does not have any attach relations
+      continue;
+    }
+
+    // Update the attach target of a stage to the new iter in `stage_to_attach_iter`
+    for (const auto& s : entry->second) {
+      pnode->stage_to_attach_iter[s] = new_iters[i];
+    }
+
+    // Remove the original iterator relation from `iter_to_attached_stages` and add the new
+    // iterator to it
+    std::vector<int> attached_stages = std::move(entry->second);
+    pnode->iter_to_attached_stages.erase(entry);
+    pnode->iter_to_attached_stages[new_iters[i]] = std::move(attached_stages);
+  }
+}
+
+void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) {
+  auto old_entry = pnode->stage_to_attach_iter.find(stage_id);
+  // We get <StageKey, IterKey> from this map
+  if (old_entry != pnode->stage_to_attach_iter.end()) {
+    // Delete the stage in `iter_to_attached_stages`; if the corresponding iterator does not
+    // have any attached stage, delete this item too
+    auto entry2 = pnode->iter_to_attached_stages.find(old_entry->second);
+    // We get <IterKey, std::vector<StageKey>> from this map
+    FindAndDeleteItem(&entry2->second, stage_id);
+    if (entry2->second.size() == 0) {
+      pnode->iter_to_attached_stages.erase(entry2);
+    }
+    // Delete the stage in `stage_to_attach_iter`
+    pnode->stage_to_attach_iter.erase(old_entry);
+  }
+}
+
 /********** State **********/
 State::State(const Array<te::Operation>& ops) {
   auto node = make_object<StateNode>();
   for (const auto& op : ops) {
     node->stages.push_back(Stage(op));
   }
+  node->attach_map = AttachMap(make_object<AttachMapNode>());
   node->concrete = true;
   data_ = std::move(node);
 }
 
 /********** Schedule primitives apis for state **********/
-void State::reorder(int stage_id, const Array<Iterator>& order) {
+Iterator State::bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type) {
   const Stage& stage = operator->()->stages[stage_id];
-  CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators "
-                                              << "should be specified";
-  Array<Integer> after_ids;
-  GetIndices(stage->iters, order, &after_ids);
-  ReorderStep step = ReorderStep(stage_id, after_ids);
+  if (thread_type < IteratorAnnotation::kVThread || thread_type > IteratorAnnotation::kThreadZ) {
+    LOG(FATAL) << "thread_type error, valid: kVThread, kBlockX, kBlockY, "
+               << "kThreadX, kThreadY, kBlockZ, kThreadZ";
+  }
+  AnnotationStep step = AnnotationStep(stage_id, GetIndex(stage->iters, it), thread_type);
   CopyOnWrite()->transform_steps.push_back(step);
-  DoReorderStep(step);
+  return step->ApplyToState(this);
 }
 
-Array<Iterator> State::split(int stage_id, const Iterator& it,
-                             const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
+Iterator State::parallel(int stage_id, const Iterator& it) {
   const Stage& stage = operator->()->stages[stage_id];
-  SplitStep step =
-      SplitStep(stage_id, GetIndex(stage->iters, it),
-                it->range.defined() ? it->range->extent : PrimExpr(), lengths, inner_to_outer);
+  AnnotationStep step =
+      AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kParallel);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return step->ApplyToState(this);
+}
+
+Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) {
+  const Stage& stage = operator->()->stages[stage_id];
+
+  // Don't unroll if the extent is larger than max_unroll
+  if (max_unroll != -1 && it->range.defined()) {
+    if (auto imm = it->range->extent.as<IntImmNode>()) {
+      if (imm->value > max_unroll) {
+        return it;
+      }
+    }
+  }
+
+  AnnotationStep step =
+      AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kUnroll);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return step->ApplyToState(this);
+}
+
+Iterator State::vectorize(int stage_id, const Iterator& it) {
+  const Stage& stage = operator->()->stages[stage_id];
+  AnnotationStep step =
+      AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kVectorize);
   CopyOnWrite()->transform_steps.push_back(step);
-  return DoSplitStep(step);
+  return step->ApplyToState(this);
 }
 
 Iterator State::fuse(int stage_id, const Array<Iterator>& iters) {
@@ -128,174 +214,59 @@ Iterator State::fuse(int stage_id, const Array<Iterator>& iters) {
   GetIndices(stage->iters, iters, &indices);
   FuseStep step = FuseStep(stage_id, indices);
   CopyOnWrite()->transform_steps.push_back(step);
-  return DoFuseStep(step);
+  return step->ApplyToState(this);
 }
 
-/********** Step implementations for state **********/
-void State::DoReorderStep(const ReorderStep& step) {
-  const Stage& stage = operator->()->stages[step->stage_id];
-  Array<Iterator> iters;
-  for (auto x : step->after_ids) {
-    iters.push_back(stage->iters[x]);
-  }
-  StateNode* pstate = CopyOnWrite();
-  pstate->stages.Set(step->stage_id,
-                     Stage(stage->op, stage->op_type, iters, stage->compute_at, stage->attrs));
+void State::reorder(int stage_id, const Array<Iterator>& order) {
+  const Stage& stage = operator->()->stages[stage_id];
+  CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators "
+                                              << "should be specified";
+  Array<Integer> after_ids;
+  GetIndices(stage->iters, order, &after_ids);
+  ReorderStep step = ReorderStep(stage_id, after_ids);
+  CopyOnWrite()->transform_steps.push_back(step);
+  step->ApplyToState(this);
 }
 
-// common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep
-Array<Iterator> State::DoSplitStepCommon(int stage_id, int iter_id,
-                                         const Array<Optional<Integer>>& lengths,
-                                         bool inner_to_outer) {
+Array<Iterator> State::split(int stage_id, const Iterator& it,
+                             const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
   const Stage& stage = operator->()->stages[stage_id];
-  const Iterator& it = stage->iters[iter_id];
-  bool concrete = true;
-
-  Optional<PrimExpr> tosplit_min, tosplit_extent;
-  if (it->range.defined()) {
-    tosplit_min = it->range->min;
-    tosplit_extent = it->range->extent;
-  } else {
-    tosplit_min = NullOpt;
-    tosplit_extent = NullOpt;
-  }
-
-  Array<Iterator> outs;
-  for (size_t i = 0; i < lengths.size(); ++i) {
-    Optional<Integer> l;
-    String name;
-    if (inner_to_outer) {
-      l = lengths[lengths.size() - i - 1];
-      name = it->name + "." + std::to_string(lengths.size() - i);
-    } else {
-      l = lengths[i];
-      name = it->name + "." + std::to_string(i);
-    }
-    Iterator res;
-    if (l && tosplit_min && tosplit_extent) {
-      res = Iterator(name, Range::FromMinExtent(tosplit_min.value(), l.value()), it->iter_kind,
-                     IteratorAnnotation::kNone);
-      tosplit_min = Integer(0);
-      tosplit_extent = indexdiv(tosplit_extent.value() + l.value() - 1, l.value());
-    } else {
-      res = Iterator(name, Range(), it->iter_kind, IteratorAnnotation::kNone);
-      tosplit_min = NullOpt;
-      tosplit_extent = NullOpt;
-      concrete = false;
-    }
-    outs.push_back(std::move(res));
-  }
-
-  Range range;
-  if (tosplit_min && tosplit_extent) {
-    range = Range::FromMinExtent(tosplit_min.value(), tosplit_extent.value());
-  }
-  if (inner_to_outer) {
-    outs.push_back(Iterator(it->name + ".0", range, it->iter_kind, IteratorAnnotation::kNone));
-    // Reverse the Iterator array
-    Array<Iterator> temp(outs.rbegin(), outs.rend());
-    outs = std::move(temp);
-  } else {
-    outs.push_back(Iterator(it->name + "." + std::to_string(lengths.size()), range, it->iter_kind,
-                            IteratorAnnotation::kNone));
-  }
-
-  Array<Iterator> new_iters;
-  new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id);
-  new_iters.insert(new_iters.end(), outs.begin(), outs.end());
-  new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, stage->iters.end());
-
-  StateNode* pstate = CopyOnWrite();
-  pstate->stages.Set(stage_id,
-                     Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs));
-  pstate->concrete &= concrete;
-
-  return outs;
+  SplitStep step =
+      SplitStep(stage_id, GetIndex(stage->iters, it),
+                it->range.defined() ? it->range->extent : PrimExpr(), lengths, inner_to_outer);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return step->ApplyToState(this);
 }
 
-Array<Iterator> State::DoSplitStep(const SplitStep& step) {
-  return DoSplitStepCommon(step->stage_id, step->iter_id, step->lengths, step->inner_to_outer);
+void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) {
+  const Stage& target_stage = operator->()->stages[target_stage_id];
+  ComputeAtStep step =
+      ComputeAtStep(stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter));
+  CopyOnWrite()->transform_steps.push_back(step);
+  step->ApplyToState(this);
 }
 
-Iterator State::DoFuseStep(const FuseStep& step) {
-  int stage_id = step->stage_id;
-  const Stage& stage = operator->()->stages[stage_id];
-
-  String new_name;
-  PrimExpr new_extent = 1;
-  IteratorKind new_iter_kind = IteratorKind::kSpecial;
-
-  for (size_t i = 0; i < step->fused_ids.size(); ++i) {
-    if (i > 0) {
-      CHECK_EQ(step->fused_ids[i]->value, step->fused_ids[i - 1]->value + 1);
-    }
-
-    const Iterator& it = stage->iters[step->fused_ids[i]];
-    new_name = new_name + it->name + "@";
-
-    if (it->range.defined() && new_extent.defined()) {
-      new_extent = new_extent * it->range->extent;
-    } else {
-      new_extent = PrimExpr();
-    }
-
-    if (i == 0) {
-      new_iter_kind = it->iter_kind;
-    } else {
-      if (new_iter_kind != it->iter_kind) {
-        new_iter_kind = IteratorKind::kMixed;
-      }
-    }
-  }
+void State::compute_inline(int stage_id) {
+  ComputeInlineStep step = ComputeInlineStep(stage_id);
+  CopyOnWrite()->transform_steps.push_back(step);
+  step->ApplyToState(this);
+}
 
-  Range range;
-  if (new_extent.defined()) {
-    range = Range::FromMinExtent(0, new_extent);
-  }
-  Iterator new_it = Iterator(new_name, range, new_iter_kind, IteratorAnnotation::kNone);
-  Array<Iterator> new_iters;
-  new_iters.insert(new_iters.end(), stage->iters.begin(),
-                   stage->iters.begin() + step->fused_ids.front());
-  new_iters.push_back(new_it);
-  new_iters.insert(new_iters.end(), stage->iters.begin() + step->fused_ids.back() + 1,
-                   stage->iters.end());
-
-  StateNode* pstate = CopyOnWrite();
-  pstate->stages.Set(stage_id,
-                     Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs));
-
-  return new_it;
+void State::compute_root(int stage_id) {
+  ComputeRootStep step = ComputeRootStep(stage_id);
+  CopyOnWrite()->transform_steps.push_back(step);
+  step->ApplyToState(this);
 }
 
-void State::DoSteps(const ComputeDAG& dag) {
+void State::ApplySteps(const ComputeDAG& dag) {
   CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages.";
 
+  // Call each step's ApplyToState method
   for (const auto& step : operator->()->transform_steps) {
-    if (auto ps = step.as<ReorderStepNode>()) {
-      DoReorderStep(GetRef<ReorderStep>(ps));
-    } else if (auto ps = step.as<SplitStepNode>()) {
-      DoSplitStep(GetRef<SplitStep>(ps));
-    } else if (auto ps = step.as<FuseStepNode>()) {
-      DoFuseStep(GetRef<FuseStep>(ps));
-    } else {
-      LOG(FATAL) << "Invalid step: " << step;
-    }
+    StepApplyToState(step, this, dag);
   }
 }
 
-static const char* IteratorAnnotationString[] = {
-    "for",              // kNone = 0
-    "unroll",           // kUnroll = 1
-    "vectorize",        // kVectorize = 2
-    "parallel",         // kParallel = 3
-    "vthread",          // kVThread = 4
-    "gpu.blockIdx.x",   // kBlockX = 5
-    "gpu.threadIdx.x",  // kThreadX = 6
-    "gpu.blockIdx.y",   // kBlockY = 7
-    "gpu.threadIdx.y",  // kThreadY = 8
-    "tensorize"         // kTensorized = 9
-};
-
 // Print stage to ostream
 void PrintStage(std::ostream* os, int stage_id, const State& state, size_t base_indent,
                 bool delete_trivial_loop) {
@@ -332,6 +303,17 @@ void PrintStage(std::ostream* os, int stage_id, const State& state, size_t base_
 
       indent += 2;
     }
+
+    if (state.defined()) {
+      IterKey iter_key(stage_id, i);
+      auto pair = state->attach_map->iter_to_attached_stages.find(iter_key);
+      if (pair != state->attach_map->iter_to_attached_stages.end()) {
+        // Print the attached stage
+        for (const auto& attach_stage_id : pair->second) {
+          PrintStage(os, attach_stage_id, state, base_indent + indent, delete_trivial_loop);
+        }
+      }
+    }
   }
 
   for (size_t j = 0; j < base_indent + indent; ++j) {
@@ -386,6 +368,36 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     });
 
 /********** State interface API for ffi **********/
+TVM_REGISTER_GLOBAL("auto_scheduler.StateBind")
+    .set_body_typed([](State state, int stage_id, const Iterator& it, int thread_type) {
+      const auto& res = state.bind(stage_id, it, IteratorAnnotation(thread_type));
+      return Array<ObjectRef>{state, res};
+    });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.StateParallel")
+    .set_body_typed([](State state, int stage_id, const Iterator& it) {
+      const auto& res = state.parallel(stage_id, it);
+      return Array<ObjectRef>{state, res};
+    });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.StateUnroll")
+    .set_body_typed([](State state, int stage_id, const Iterator& it, int max_unroll) {
+      const auto& res = state.unroll(stage_id, it, max_unroll);
+      return Array<ObjectRef>{state, res};
+    });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.StateVectorize")
+    .set_body_typed([](State state, int stage_id, const Iterator& it) {
+      const auto& res = state.vectorize(stage_id, it);
+      return Array<ObjectRef>{state, res};
+    });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse")
+    .set_body_typed([](State state, int stage_id, const Array<Iterator>& iters) {
+      const auto& res = state.fuse(stage_id, iters);
+      return Array<ObjectRef>{state, res};
+    });
+
 TVM_REGISTER_GLOBAL("auto_scheduler.StateReorder")
     .set_body_typed([](State state, int stage_id, const Array<Iterator>& order) {
       state.reorder(stage_id, order);
@@ -399,10 +411,23 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit")
       return Array<ObjectRef>{state, res};
     });
 
-TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse")
-    .set_body_typed([](State state, int stage_id, const Array<Iterator>& iters) {
-      const auto& res = state.fuse(stage_id, iters);
-      return Array<ObjectRef>{state, res};
+TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt")
+    .set_body_typed([](State state, int stage_id, int target_stage_id,
+                       const Iterator& target_iter) {
+      state.compute_at(stage_id, target_stage_id, target_iter);
+      return state;
+    });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeInline")
+    .set_body_typed([](State state, int stage_id) {
+      state.compute_inline(stage_id);
+      return state;
+    });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeRoot")
+    .set_body_typed([](State state, int stage_id) {
+      state.compute_root(stage_id);
+      return state;
     });
 
 TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual").set_body_typed([](State state1, State state2) {
diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h
index 04e5304..4d6477b 100644
--- a/src/auto_scheduler/loop_state.h
+++ b/src/auto_scheduler/loop_state.h
@@ -51,6 +51,9 @@
 #include <tvm/runtime/container.h>
 
 #include <functional>
+#include <unordered_map>
+#include <utility>
+#include <vector>
 
 #include "transform_step.h"
 
@@ -79,84 +82,6 @@ enum class ComputeAtKind : int {
   kIter = 2,
 };
 
-/*! \brief The type of an iterator. */
-enum class IteratorKind : int {
-  /*! \brief Spatial iterator. */
-  kSpatial = 0,
-  /*! \brief Reduction iterator. */
-  kReduction = 1,
-  /*! \brief Fused spatial and reduction iterator. */
-  kMixed = 2,
-  /*! \brief Special iterator. (e.g. virtual root iterator) */
-  kSpecial = 3
-};
-
-/*! \brief The type of an iterator's annotation. */
-enum class IteratorAnnotation : int {
-  /*! \brief This iterator has no annotation. */
-  kNone = 0,
-  /*! \brief This iterator has been unrolled. */
-  kUnroll = 1,
-  /*! \brief This iterator has been vectorized. */
-  kVectorize = 2,
-  /*! \brief This iterator has been paralleld. */
-  kParallel = 3,
-  /*! \brief This iterator has been bind to vthread. */
-  kVThread = 4,
-  /*! \brief This iterator has been bind to blockIdx.x. */
-  kBlockX = 5,
-  /*! \brief This iterator has been bind to threadIdx.x. */
-  kThreadX = 6,
-  /*! \brief This iterator has been bind to blockIdx.y. */
-  kBlockY = 7,
-  /*! \brief This iterator has been bind to threadIdx.y. */
-  kThreadY = 8,
-  /*! \brief This iterator has been mapped with a tensorize intrinsic. */
-  kTensorized = 9
-};
-
-/*!
- * \brief A for loop iterator
- * Similar to tvm::IterVar in `include/tvm/tir/expr.h`
- */
-class IteratorNode : public Object {
- public:
-  /*! \brief The name of this iterator. */
-  String name;
-  /*! \brief The range of this iterator. */
-  Range range;
-  /*! \brief The iterator type of this iterator. */
-  IteratorKind iter_kind;
-  /*! \brief The annotation type of this iterator. */
-  IteratorAnnotation annotation;
-
-  void VisitAttrs(tvm::AttrVisitor* v) {
-    v->Visit("name", &name);
-    v->Visit("range", &range);
-  }
-
-  static constexpr const char* _type_key = "auto_scheduler.Iterator";
-  TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object);
-};
-
-/*!
- * \brief Managed reference to IteratorNode.
- * \sa IteratorNode
- */
-class Iterator : public ObjectRef {
- public:
-  /*!
-   * \brief The constructor.
-   * \param name The name of this iterator.
-   * \param range The range of this iterator.
-   * \param iter_kind The iterator type of this iterator.
-   * \param annotation The annotation type of this iterator.
-   */
-  Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation);
-
-  TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode);
-};
-
 /*! \brief Stage-level attributes. */
 struct StageAttributes {
   /*! \brief The maximum steps for the pragma `auto_unroll_max_step`. */
@@ -167,16 +92,16 @@ struct StageAttributes {
 
 /*!
  * \brief A op stage in the compute declaration.
- * Similar to te::Stage in `include/schedule.h`.
+ * Similar to te::Stage in `include/tvm/te/schedule.h`.
  */
 class StageNode : public Object {
  public:
   /*! \brief The operator of this stage */
   te::Operation op;
-  /*! \brief The type of this stage. */
-  StageKind op_type;
   /*! \brief The iterators in this stage. */
   Array<Iterator> iters;
+  /*! \brief The type of this stage. */
+  StageKind op_type;
   /*! \brief The compute location of this stage. */
   ComputeAtKind compute_at;
   /*! \brief Other stage-level attributes. */
@@ -185,6 +110,8 @@ class StageNode : public Object {
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("op", &op);
     v->Visit("iters", &iters);
+    v->Visit("op_type", &op_type);
+    v->Visit("compute_at", &compute_at);
   }
 
   static constexpr const char* _type_key = "auto_scheduler.Stage";
@@ -217,6 +144,70 @@ class Stage : public ObjectRef {
   TVM_DEFINE_OBJECT_REF_COW_METHOD(StageNode);
 };
 
+/*! \brief Use stage_id to represent a stage. */
+using StageKey = int;
+/*! \brief Use stage_id and iter_id to represent a iterator. */
+using IterKey = std::pair<int, int>;
+
+/*!
+ * \brief Stores the compute_at relations between stages.
+ * This stores a bi-directional mapping between stages and iterators:
+ * 1. Stage to the iterator it is attached to
+ * 2. Iterator to the stages attached to it
+ * You can use AttachMapNode::stage_to_attach_iter and AttachMapNode::iter_to_attached_stages
+ * to query the relations.
+ */
+class AttachMapNode : public Object {
+ public:
+  /*! \brief A Map to store the mapping of stage to its attached iterator. */
+  std::unordered_map<StageKey, IterKey> stage_to_attach_iter;
+  /*! \brief A Map to store the mapping of iterator to the stages attached to it. */
+  std::unordered_map<IterKey, std::vector<StageKey>> iter_to_attached_stages;
+
+  static constexpr const char* _type_key = "auto_scheduler.AttachMap";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object);
+};
+
+/*!
+ * \brief Managed reference to AttachMapNode.
+ * \sa AttachMapNode
+ */
+class AttachMap : public ObjectRef {
+ public:
+  /*!
+   * \brief Update the stage/iterator mapping after a compute_at.
+   * \param stage_id The index of the source stage of the compute_at.
+   * \param target_stage_id The index of the target stage of the compute_at.
+   * \param target_iter_id The index of the target iterator in the target stage.
+   */
+  void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id);
+  /*!
+   * \brief Delete the entry of a specific stage. This is a public wrapper of `DeleteStageEntry`.
+   * \param stage_id The index of the stage whose entry is to be deleted.
+   */
+  void DeleteStage(int stage_id);
+  /*!
+   * \brief Find the relations of original iterators in AttachMap, and update them with the new
+   * iterators. Both `stage_to_attach_iter` and `iter_to_attached_stages` will be updated.
+   * \param original_iters The original IterKeys.
+   * \param new_iters The new IterKeys to update to.
+   */
+  void UpdateIters(const std::vector<IterKey>& original_iters,
+                   const std::vector<IterKey>& new_iters);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode);
+
+ private:
+  /*!
+   * \brief Delete the entry of a specific stage. This removes the items related to this
+   * stage in both the `stage_to_attach_iter` and `iter_to_attached_stages` maps.
+   * \param pnode A mutable pointer to AttachMapNode.
+   * \param stage_id The index of stage that will be removed from the map.
+   */
+  static void DeleteStageEntry(AttachMapNode* pnode, int stage_id);
+};
+
 /*!
  * \brief A state in the search process.
  * It consists of the current loop structure and a list of transformation steps used to construct
@@ -230,6 +221,11 @@ class StateNode : public Object {
   /*! \brief History transformation steps. */
   Array<Step> transform_steps;
   /*!
+   * \brief The attach relations of stages and iterators. This is used to track the compute_at
+   * operations.
+   */
+  AttachMap attach_map;
+  /*!
   * \brief Indicate whether this state has unfilled tile sizes. A concrete state means that all
   * tile sizes of the state are filled. Only a concrete state can be applied to a TVM schedule.
    */
@@ -275,17 +271,60 @@ class State : public ObjectRef {
   String ToStr(bool delete_trivial_loop = true) const;
 
   /*!
-   * \brief General do step functions with a runtime dynamic dispatcher. This will re-apply all the
-   * transform steps with the initial state.
+   * \brief Re-apply all the transform steps from the initial state, with a runtime dynamic
+   * dispatcher calling each step's application function.
    * \param dag The original ComputeDAG of this state.
-   * \note This is different from the class member `current_compute_dag`, for some transform step
-   * may change the op stage structure of the ComputeDAG.
+   * \note The input `dag` is different from the class member `current_compute_dag`.
+   * This function takes the initial ComputeDAG as input to replay the whole history, while
+   * `current_compute_dag` is used to track the current stage status, since some transform steps
+   * may change the op stage structure.
    */
-  void DoSteps(const ComputeDAG& dag);
+  void ApplySteps(const ComputeDAG& dag);
 
-  /* Step APIs for State. */
+  /********** Step APIs working on single stage **********/
 
   /*!
+   * \brief Schedule primitive corresponds to te.bind.
+   * \param stage_id The index of the stage to be bound.
+   * \param it The iterator to be bound.
+   * \param thread_type The thread type to bind. We directly use the IteratorAnnotation as
+   * this input.
+   * \return The iterator result after bind.
+   */
+  Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type);
+  /*!
+   * \brief Schedule primitive corresponds to te.parallel.
+   * \param stage_id The index of the stage to be parallelized.
+   * \param it The iterator to be parallelized.
+   * \return The iterator result after parallel.
+   */
+  Iterator parallel(int stage_id, const Iterator& it);
+  /*!
+   * \brief Schedule primitive corresponds to te.unroll.
+   * \param stage_id The index of the stage to be unrolled.
+   * \param it The iterator to be unrolled.
+   * \param max_unroll The maximum unroll limit. An iterator with extent larger than this limit
+   * will be skipped.
+   * \return The iterator result after unroll.
+   */
+  Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1);
+  /*!
+   * \brief Schedule primitive corresponds to te.vectorize.
+   * \param stage_id The index of the stage to be vectorized.
+   * \param it The iterator to be vectorized.
+   * \return The iterator result after vectorize.
+   */
+  Iterator vectorize(int stage_id, const Iterator& it);
+  /*!
+   * \brief Schedule primitive corresponds to te.fuse.
+   * \param stage_id The index of the stage to be fused.
+   * \param iters The iterators to be fused.
+   * \return The iterator result after fuse.
+   * \note If the iterators to be fused have stages attached to them (by compute_at), the fused
+   * result will become the new attach point.
+   */
+  Iterator fuse(int stage_id, const Array<Iterator>& iters);
+  /*!
    * \brief Schedule primitive corresponds to te.reorder.
    * \param stage_id The index of the stage to be reordered.
    * \param order The expected iterator order.
@@ -294,57 +333,46 @@ class State : public ObjectRef {
   /*!
    * \brief Schedule primitive corresponds to te.split.
    * \param stage_id The index of the stage to be split.
-   * \param it The iterator the be split.
+   * \param it The iterator to be split.
    * \param lengths The multiple split factors. Can be None to be filled by search policy.
    * \param inner_to_outer Whether the factor go from inner to outer, or from outer to inner.
    * \return The iterator results after split.
+   * \note If we split an iterator that has stages attached to it (by compute_at), the innermost
+   * iterator of the split results will become the new attach point.
    */
   Array<Iterator> split(int stage_id, const Iterator& it, const Array<Optional<Integer>>& lengths,
                         bool inner_to_outer = true);
-  /*!
-   * \brief Schedule primitive corresponds to te.fuse.
-   * \param stage_id The index of the stage to be fused.
-   * \param iters The iterators to be fused.
-   * \return The iterator result after fuse.
-   */
-  Iterator fuse(int stage_id, const Array<Iterator>& iters);
 
-  TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
-  TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode);
-
- private:
-  /* Do transform steps
-   * Note: The following functions only change loop state but do not change transform_history.
-   * We separate these functions out, so you can call them for replay easily given history steps */
+  /********** Step APIs working on multiple stages **********/
 
   /*!
-   * \brief Apply reorder step to current state.
-   * \param step A ReorderStep.
+   * \brief Schedule primitive corresponds to te.compute_at.
+   * \param stage_id The index of the source stage of the compute_at.
+   * \param target_stage_id The index of the target stage of the compute_at.
+   * \param target_iter The target iterator in the target stage.
+   * \note After compute_at, we need careful dependency analysis to compute the accurate bound
+   * information. However, it is relatively expensive and complicated, so we just fill "None" as
+   * bound for the newly created iterators.
+   * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
    */
-  void DoReorderStep(const ReorderStep& step);
+  void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
   /*!
-   * \brief Apply split step to current state.
-   * \param step A SplitStep.
-   * \return The iterator results after split.
+   * \brief Schedule primitive corresponds to te.compute_inline.
+   * \param stage_id The index of the stage to be inlined.
    */
-  Array<Iterator> DoSplitStep(const SplitStep& step);
+  void compute_inline(int stage_id);
   /*!
-   * \brief Apply fuse step to current state.
-   * \param step A FuseStep.
-   * \return The iterator result after fuse.
+   * \brief Schedule primitive corresponds to te.compute_root.
+   * \param stage_id The index of the stage to be set as the compute root.
+   * \note After compute_root, we need careful dependency analysis to compute the accurate bound
+   * information. However, it is relatively expensive and complicated, so we just fill "None" as
+   * bound for the newly created iterators.
+   * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
    */
-  Iterator DoFuseStep(const FuseStep& step);
+  void compute_root(int stage_id);
 
-  /*!
-   * \brief Common function for DoSplitStep and DoFollowSplitStep(Will be added later).
-   * \param stage_id The index of the stage to be split.
-   * \param iter_id The index of the iterator to be split.
-   * \param lengths The multiple split factors.
-   * \param inner_to_outer The split direction.
-   * \return The iterator results after split.
-   */
-  Array<Iterator> DoSplitStepCommon(int stage_id, int iter_id,
-                                    const Array<Optional<Integer>>& lengths, bool inner_to_outer);
+  TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode);
 };
 
 }  // namespace auto_scheduler
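
AttachMap is the one genuinely new data structure in this header: a
bidirectional index that SetComputeAtIter/DeleteStage/UpdateIters keep
consistent. A rough Python model of the invariant, illustrative only:

    def new_attach_map():
        return {"stage_to_attach_iter": {}, "iter_to_attached_stages": {}}

    def delete_stage_entry(m, stage_id):
        key = m["stage_to_attach_iter"].pop(stage_id, None)
        if key is not None:
            stages = m["iter_to_attached_stages"][key]
            stages.remove(stage_id)
            if not stages:  # no attached stage left on this iterator
                del m["iter_to_attached_stages"][key]

    def set_compute_at_iter(m, stage_id, target_stage_id, target_iter_id):
        delete_stage_entry(m, stage_id)  # drop any previous relation
        key = (target_stage_id, target_iter_id)
        m["stage_to_attach_iter"][stage_id] = key
        m["iter_to_attached_stages"].setdefault(key, []).append(stage_id)

UpdateIters is what lets fuse and split move attach points: it rewrites the
(stage_id, iter_id) keys while preserving the attached-stage lists.
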
diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc
index f6f882e..39f9ad8 100644
--- a/src/auto_scheduler/measure_record.cc
+++ b/src/auto_scheduler/measure_record.cc
@@ -42,25 +42,6 @@
 namespace dmlc {
 namespace json {
 
-inline std::vector<int> IntArrayToVector(const ::tvm::Array<::tvm::Integer>& data) {
-  std::vector<int> out;
-  for (const auto& x : data) {
-    CHECK(x.defined());
-    out.push_back(x);
-  }
-  return out;
-}
-
-inline std::vector<int> IntArrayToVector(
-    const ::tvm::Array<::tvm::Optional<::tvm::Integer>>& data) {
-  std::vector<int> out;
-  for (const auto& x : data) {
-    CHECK(x);
-    out.push_back(x.value());
-  }
-  return out;
-}
-
 template <>
 struct Handler<::tvm::Array<::tvm::auto_scheduler::Stage>> {
   inline static void Write(dmlc::JSONWriter* writer,
@@ -82,28 +63,10 @@ struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> {
   inline static void Write(dmlc::JSONWriter* writer,
                            const ::tvm::Array<::tvm::auto_scheduler::Step>& data) {
     writer->BeginArray(false);
-    for (size_t i = 0; i < data.size(); ++i) {
+    for (const auto& step : data) {
       writer->WriteArraySeperator();
       writer->BeginArray(false);
-      if (auto ps = data[i].as<::tvm::auto_scheduler::ReorderStepNode>()) {
-        writer->WriteArrayItem(std::string("RE"));
-        writer->WriteArrayItem(ps->stage_id);
-        writer->WriteArrayItem(IntArrayToVector(ps->after_ids));
-      } else if (auto ps = data[i].as<::tvm::auto_scheduler::SplitStepNode>()) {
-        writer->WriteArrayItem(std::string("SP"));
-        writer->WriteArrayItem(ps->stage_id);
-        writer->WriteArrayItem(ps->iter_id);
-        writer->WriteArrayItem(ps->extent ? ::tvm::auto_scheduler::GetIntImm(ps->extent.value())
-                                          : 0);
-        writer->WriteArrayItem(IntArrayToVector(ps->lengths));
-        writer->WriteArrayItem(static_cast<int>(ps->inner_to_outer));
-      } else if (auto ps = data[i].as<::tvm::auto_scheduler::FuseStepNode>()) {
-        writer->WriteArrayItem(std::string("FU"));
-        writer->WriteArrayItem(ps->stage_id);
-        writer->WriteArrayItem(IntArrayToVector(ps->fused_ids));
-      } else {
-        LOG(FATAL) << "Invalid step: " << data[i];
-      }
+      step->WriteToRecord(writer);
       writer->EndArray();
     }
     writer->EndArray();
@@ -111,67 +74,12 @@ struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> {
 
   inline static void Read(dmlc::JSONReader* reader,
                           ::tvm::Array<::tvm::auto_scheduler::Step>* data) {
-    std::vector<int> int_list;
-    bool s, inner_to_outer;
-    std::string name, scope_name, pragma_type, ti_func_name;
-    int stage_id, iter_id, extent;
-
+    bool s;
     reader->BeginArray();
     data->clear();
     while (reader->NextArrayItem()) {
       reader->BeginArray();
-      s = reader->NextArrayItem();
-      CHECK(s);
-      reader->Read(&name);
-      if (name == "RE") {
-        s = reader->NextArrayItem();
-        CHECK(s);
-        reader->Read(&stage_id);
-        s = reader->NextArrayItem();
-        CHECK(s);
-        reader->Read(&int_list);
-        ::tvm::Array<::tvm::Integer> after_ids;
-        for (const auto& i : int_list) {
-          after_ids.push_back(i);
-        }
-        data->push_back(::tvm::auto_scheduler::ReorderStep(stage_id, after_ids));
-      } else if (name == "SP") {
-        s = reader->NextArrayItem();
-        CHECK(s);
-        reader->Read(&stage_id);
-        s = reader->NextArrayItem();
-        CHECK(s);
-        reader->Read(&iter_id);
-        s = reader->NextArrayItem();
-        CHECK(s);
-        reader->Read(&extent);
-        s = reader->NextArrayItem();
-        CHECK(s);
-        reader->Read(&int_list);
-        s = reader->NextArrayItem();
-        CHECK(s);
-        reader->Read(&inner_to_outer);
-        ::tvm::Array<::tvm::Optional<::tvm::Integer>> lengths;
-        for (const auto& i : int_list) {
-          lengths.push_back(::tvm::Integer(i));
-        }
-        data->push_back(::tvm::auto_scheduler::SplitStep(
-            stage_id, iter_id, extent == 0 ? ::tvm::PrimExpr() : extent, lengths, inner_to_outer));
-      } else if (name == "FU") {
-        s = reader->NextArrayItem();
-        CHECK(s);
-        reader->Read(&stage_id);
-        s = reader->NextArrayItem();
-        CHECK(s);
-        reader->Read(&int_list);
-        ::tvm::Array<::tvm::Integer> fused_ids;
-        for (const auto& i : int_list) {
-          fused_ids.push_back(i);
-        }
-        data->push_back(::tvm::auto_scheduler::FuseStep(stage_id, fused_ids));
-      } else {
-        LOG(FATAL) << "Invalid step format";
-      }
+      data->push_back(::tvm::auto_scheduler::StepReadFromRecord(reader));
       s = reader->NextArrayItem();
       CHECK(!s);
     }
@@ -187,8 +95,8 @@ struct Handler<::tvm::auto_scheduler::StateNode> {
     writer->EndArray();
   }
   inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::StateNode* data) {
-    reader->BeginArray();
     bool s;
+    reader->BeginArray();
     s = reader->NextArrayItem();
     CHECK(s);
     reader->Read(&data->stages);
@@ -210,18 +118,17 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> {
     writer->EndArray();
   }
   inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::SearchTaskNode* data) {
-    std::string target_str;
     bool s;
-
+    std::string str_value;
     reader->BeginArray();
     s = reader->NextArrayItem();
     CHECK(s);
-    reader->Read(&target_str);
-    data->workload_key = std::move(target_str);
+    reader->Read(&str_value);
+    data->workload_key = std::move(str_value);
     s = reader->NextArrayItem();
     CHECK(s);
-    reader->Read(&target_str);
-    data->target = ::tvm::Target::Create(target_str);
+    reader->Read(&str_value);
+    data->target = ::tvm::Target::Create(str_value);
     s = reader->NextArrayItem();
     CHECK(!s);
   }
@@ -237,11 +144,11 @@ struct Handler<::tvm::auto_scheduler::MeasureInputNode> {
     writer->EndArray();
   }
   inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::MeasureInputNode* data) {
-    bool s;
     auto task_node = ::tvm::make_object<::tvm::auto_scheduler::SearchTaskNode>();
     auto state_node = ::tvm::make_object<::tvm::auto_scheduler::StateNode>();
     state_node->concrete = true;
 
+    bool s;
     reader->BeginArray();
     s = reader->NextArrayItem();
     CHECK(s);
@@ -277,15 +184,14 @@ struct Handler<::tvm::auto_scheduler::MeasureResultNode> {
   }
   inline static void Read(dmlc::JSONReader* reader,
                           ::tvm::auto_scheduler::MeasureResultNode* data) {
+    std::vector<double> double_list;
     bool s;
-    std::vector<double> tmp;
-
     reader->BeginArray();
     s = reader->NextArrayItem();
     CHECK(s);
-    reader->Read(&tmp);
+    reader->Read(&double_list);
     data->costs.clear();
-    for (const auto& i : tmp) {
+    for (const auto& i : double_list) {
       data->costs.push_back(::tvm::FloatImm(::tvm::DataType::Float(64), i));
     }
     s = reader->NextArrayItem();
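
After this refactor each step serializes itself through WriteToRecord and is
reconstructed by StepReadFromRecord; the wire format of the pre-existing
steps is unchanged. Written as Python list literals with field names
standing in for values ("RE"/"SP"/"FU" come from the removed code above;
the new steps carry their own record_prefix_str values not visible here):

    # ["RE", stage_id, after_ids]                                 ReorderStep
    # ["SP", stage_id, iter_id, extent, lengths, inner_to_outer]  SplitStep
    # ["FU", stage_id, fused_ids]                                 FuseStep
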
diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc
index 90b4db8..6c672a5 100644
--- a/src/auto_scheduler/transform_step.cc
+++ b/src/auto_scheduler/transform_step.cc
@@ -28,7 +28,9 @@
 #include <tvm/runtime/registry.h>
 #include <tvm/te/operation.h>
 
+#include <string>
 #include <utility>
+#include <vector>
 
 #include "loop_state.h"
 #include "utils.h"
@@ -36,6 +38,404 @@
 namespace tvm {
 namespace auto_scheduler {
 
+const char* IteratorAnnotationString[] = {
+    "for",          // kNone = 0
+    "unroll",       // kUnroll = 1
+    "vectorize",    // kVectorize = 2
+    "parallel",     // kParallel = 3
+    "vthread",      // kVThread = 4
+    "blockIdx.x",   // kBlockX = 5
+    "threadIdx.x",  // kThreadX = 6
+    "blockIdx.y",   // kBlockY = 7
+    "threadIdx.y",  // kThreadY = 8
+    "blockIdx.z",   // kBlockZ = 9
+    "threadIdx.z",  // kThreadZ = 10
+    "tensorize"     // kTensorized = 11
+};
+
+Step StepReadFromRecord(dmlc::JSONReader* reader) {
+  std::string name;
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&name);
+  if (name == AnnotationStepNode::record_prefix_str) {
+    return AnnotationStep(reader);
+  } else if (name == FuseStepNode::record_prefix_str) {
+    return FuseStep(reader);
+  } else if (name == ReorderStepNode::record_prefix_str) {
+    return ReorderStep(reader);
+  } else if (name == SplitStepNode::record_prefix_str) {
+    return SplitStep(reader);
+  } else if (name == ComputeAtStepNode::record_prefix_str) {
+    return ComputeAtStep(reader);
+  } else if (name == ComputeInlineStepNode::record_prefix_str) {
+    return ComputeInlineStep(reader);
+  } else if (name == ComputeRootStepNode::record_prefix_str) {
+    return ComputeRootStep(reader);
+  } else {
+    LOG(FATAL) << "Invalid step format: " << name;
+  }
+  return Step();
+}
+
+void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
+  if (auto ps = step.as<AnnotationStepNode>()) {
+    ps->ApplyToState(state);
+  } else if (auto ps = step.as<FuseStepNode>()) {
+    ps->ApplyToState(state);
+  } else if (auto ps = step.as<ReorderStepNode>()) {
+    ps->ApplyToState(state);
+  } else if (auto ps = step.as<SplitStepNode>()) {
+    ps->ApplyToState(state);
+  } else if (auto ps = step.as<ComputeAtStepNode>()) {
+    ps->ApplyToState(state);
+  } else if (auto ps = step.as<ComputeInlineStepNode>()) {
+    ps->ApplyToState(state);
+  } else if (auto ps = step.as<ComputeRootStepNode>()) {
+    ps->ApplyToState(state);
+  } else {
+    LOG(FATAL) << "Invalid step: " << step;
+  }
+}
+
+void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages,
+                         StageToAxesMap* stage_to_axes) {
+  if (auto ps = step.as<AnnotationStepNode>()) {
+    ps->ApplyToSchedule(stages, stage_to_axes);
+  } else if (auto ps = step.as<FuseStepNode>()) {
+    ps->ApplyToSchedule(stages, stage_to_axes);
+  } else if (auto ps = step.as<ReorderStepNode>()) {
+    ps->ApplyToSchedule(stages, stage_to_axes);
+  } else if (auto ps = step.as<SplitStepNode>()) {
+    ps->ApplyToSchedule(stages, stage_to_axes);
+  } else if (auto ps = step.as<ComputeAtStepNode>()) {
+    ps->ApplyToSchedule(stages, stage_to_axes);
+  } else if (auto ps = step.as<ComputeInlineStepNode>()) {
+    ps->ApplyToSchedule(stages, stage_to_axes);
+  } else if (auto ps = step.as<ComputeRootStepNode>()) {
+    ps->ApplyToSchedule(stages, stage_to_axes);
+  } else {
+    LOG(FATAL) << "Invalid Step: " << step;
+  }
+}
+
+String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
+                            StageToAxesMap* stage_to_axes) {
+  if (auto ps = step.as<AnnotationStepNode>()) {
+    return ps->PrintAsPythonAPI(stages, stage_to_axes);
+  } else if (auto ps = step.as<FuseStepNode>()) {
+    return ps->PrintAsPythonAPI(stages, stage_to_axes);
+  } else if (auto ps = step.as<ReorderStepNode>()) {
+    return ps->PrintAsPythonAPI(stages, stage_to_axes);
+  } else if (auto ps = step.as<SplitStepNode>()) {
+    return ps->PrintAsPythonAPI(stages, stage_to_axes);
+  } else if (auto ps = step.as<ComputeAtStepNode>()) {
+    return ps->PrintAsPythonAPI(stages, stage_to_axes);
+  } else if (auto ps = step.as<ComputeInlineStepNode>()) {
+    return ps->PrintAsPythonAPI(stages, stage_to_axes);
+  } else if (auto ps = step.as<ComputeRootStepNode>()) {
+    return ps->PrintAsPythonAPI(stages, stage_to_axes);
+  } else {
+    LOG(FATAL) << "Invalid Step: " << step;
+  }
+  return "";
+}
+
+/********** Primitives working on single stage **********/
+
+/********** Annotation **********/
+AnnotationStep::AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann) {
+  auto node = make_object<AnnotationStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->annotation = ann;
+  data_ = std::move(node);
+}
+
+AnnotationStep::AnnotationStep(dmlc::JSONReader* reader) {
+  auto node = make_object<AnnotationStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  int int_val;
+  reader->Read(&int_val);
+  node->annotation = IteratorAnnotation(int_val);
+  data_ = std::move(node);
+}
+
+void AnnotationStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(static_cast<int>(annotation));
+}
+
+Iterator AnnotationStepNode::ApplyToState(State* state) const {
+  const Stage& stage = (*state)->stages[stage_id];
+  Iterator it = stage->iters[iter_id];
+
+  CHECK(it->annotation == IteratorAnnotation::kNone);
+  Iterator new_it = Iterator(it->name, it->range, it->iter_kind, annotation);
+  Stage new_stage = stage;
+  new_stage.CopyOnWrite()->iters.Set(iter_id, new_it);
+  state->CopyOnWrite()->stages.Set(stage_id, std::move(new_stage));
+  return new_it;
+}
+
+void AnnotationStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                         StageToAxesMap* stage_to_axes) const {
+  te::Stage stage = (*stages)[stage_id];
+  const Array<IterVar>& axes = (*stage_to_axes)[stage];
+
+  switch (annotation) {
+    case IteratorAnnotation::kUnroll:
+      stage.unroll(axes[iter_id]);
+      break;
+    case IteratorAnnotation::kVectorize:
+      stage.vectorize(axes[iter_id]);
+      break;
+    case IteratorAnnotation::kParallel:
+      stage.parallel(axes[iter_id]);
+      break;
+    case IteratorAnnotation::kVThread:
+    case IteratorAnnotation::kBlockX:
+    case IteratorAnnotation::kBlockY:
+    case IteratorAnnotation::kBlockZ:
+    case IteratorAnnotation::kThreadX:
+    case IteratorAnnotation::kThreadY:
+    case IteratorAnnotation::kThreadZ:
+      stage.bind(axes[iter_id],
+                 te::thread_axis(Range(), IteratorAnnotationString[static_cast<int>(annotation)]));
+      break;
+    case IteratorAnnotation::kNone:
+      break;
+    default:
+      LOG(FATAL) << "Invalid Annotation " << static_cast<int>(annotation);
+      break;
+  }
+
+  stages->Set(stage_id, std::move(stage));
+}
+
+String AnnotationStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                            StageToAxesMap* stage_to_axes) const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+  const auto& iter = (*stage_to_axes)[stage][iter_id];
+
+  ss << "s[" << CleanName(stage->op->name) << "].";
+  switch (annotation) {
+    case IteratorAnnotation::kUnroll:
+      ss << "unroll(";
+      break;
+    case IteratorAnnotation::kVectorize:
+      ss << "vectorize(";
+      break;
+    case IteratorAnnotation::kParallel:
+      ss << "parallel(";
+      break;
+    case IteratorAnnotation::kVThread:
+    case IteratorAnnotation::kBlockX:
+    case IteratorAnnotation::kBlockY:
+    case IteratorAnnotation::kBlockZ:
+    case IteratorAnnotation::kThreadX:
+    case IteratorAnnotation::kThreadY:
+    case IteratorAnnotation::kThreadZ:
+      ss << "bind(";
+      break;
+    case IteratorAnnotation::kNone:
+      break;
+    default:
+      LOG(FATAL) << "Invalid annotation " << static_cast<int>(annotation);
+      break;
+  }
+  ss << CleanName(iter->var->name_hint);
+  switch (annotation) {
+    case IteratorAnnotation::kVThread:
+    case IteratorAnnotation::kBlockX:
+    case IteratorAnnotation::kBlockY:
+    case IteratorAnnotation::kBlockZ:
+    case IteratorAnnotation::kThreadX:
+    case IteratorAnnotation::kThreadY:
+    case IteratorAnnotation::kThreadZ:
+      ss << ", tvm.thread_axis(\"" << IteratorAnnotationString[static_cast<int>(annotation)]
+         << "\")";
+      break;
+    default:
+      break;
+  }
+  ss << ")\n";
+
+  ApplyToSchedule(stages, stage_to_axes);
+  return ss.str();
+}
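+
+// Example of the generated Python (stage/iterator names are illustrative):
+// a kBlockX annotation on iterator i0 of stage C prints
+//   s[C].bind(i0, tvm.thread_axis("blockIdx.x"))
+// while kParallel prints s[C].parallel(i0) and kUnroll prints s[C].unroll(i0).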
+
+/********** Fuse **********/
+FuseStep::FuseStep(int stage_id, const Array<Integer>& fused_ids) {
+  auto node = make_object<FuseStepNode>();
+  node->stage_id = stage_id;
+  for (const auto& x : fused_ids) {
+    CHECK(x->IsInstance<IntImmNode>());
+  }
+  node->fused_ids = fused_ids;
+  data_ = std::move(node);
+}
+
+FuseStep::FuseStep(dmlc::JSONReader* reader) {
+  auto node = make_object<FuseStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  std::vector<int> int_list;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&int_list);
+  ::tvm::Array<::tvm::Integer> fused_ids;
+  for (const auto& i : int_list) {
+    fused_ids.push_back(i);
+  }
+  node->fused_ids = fused_ids;
+  data_ = std::move(node);
+}
+
+void FuseStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(IntArrayToVector(fused_ids));
+}
+
+Iterator FuseStepNode::ApplyToState(State* state) const {
+  const Stage& stage = (*state)->stages[stage_id];
+  size_t old_iter_size = stage->iters.size();
+
+  String new_name;
+  PrimExpr new_extent = 1;
+  IteratorKind new_iter_kind = IteratorKind::kSpecial;
+
+  for (size_t i = 0; i < fused_ids.size(); ++i) {
+    if (i > 0) {
+      CHECK_EQ(fused_ids[i]->value, fused_ids[i - 1]->value + 1);
+    }
+
+    if (i != fused_ids.size() - 1) {
+      const auto& iter_to_attached_stage = (*state)->attach_map->iter_to_attached_stages;
+      if (iter_to_attached_stage.find(std::make_pair(stage_id, fused_ids[i])) !=
+          iter_to_attached_stage.end()) {
+        LOG(FATAL) << "Invalid Fuse. Trying to fuse iterators that have been attached by some "
+                   << "stages. State before fusion:\n"
+                   << (*state);
+      }
+    }
+
+    const Iterator& it = stage->iters[fused_ids[i]];
+    new_name = new_name + it->name + "@";
+
+    if (it->range.defined() && new_extent.defined()) {
+      new_extent = new_extent * it->range->extent;
+    } else {
+      new_extent = PrimExpr();
+    }
+
+    if (i == 0) {
+      new_iter_kind = it->iter_kind;
+    } else {
+      if (new_iter_kind != it->iter_kind) {
+        new_iter_kind = IteratorKind::kMixed;
+      }
+    }
+  }
+
+  Range range;
+  if (new_extent.defined()) {
+    range = Range::FromMinExtent(0, new_extent);
+  }
+  Iterator new_it = Iterator(new_name, range, new_iter_kind, IteratorAnnotation::kNone);
+  Array<Iterator> new_iters;
+  new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + fused_ids.front());
+  new_iters.push_back(new_it);
+  new_iters.insert(new_iters.end(), stage->iters.begin() + fused_ids.back() + 1,
+                   stage->iters.end());
+
+  StateNode* pstate = state->CopyOnWrite();
+  pstate->stages.Set(stage_id,
+                     Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs));
+
+  // Two vectors are used to represent the iterator relation before and after fuse
+  // The original iterators in AttachMap will be updated with the new iterators
+  std::vector<IterKey> from_iters;
+  std::vector<IterKey> to_iters;
+  const size_t begin_id = fused_ids.front(), end_id = fused_ids.back();
+  for (size_t i = 0; i < old_iter_size; ++i) {
+    if (i <= begin_id) {
+      continue;
+    } else if (i > end_id) {
+      // move forward
+      from_iters.emplace_back(stage_id, i);
+      to_iters.emplace_back(stage_id, i - end_id + begin_id);
+    } else {
+      // move to the fused id
+      from_iters.emplace_back(stage_id, i);
+      to_iters.emplace_back(stage_id, begin_id);
+    }
+  }
+  pstate->attach_map.UpdateIters(from_iters, to_iters);
+
+  return new_it;
+}
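+
+// For example (illustrative): fusing iterators i, j, k with extents 4, 8, 32
+// produces one iterator named "i@j@k@" with extent 4 * 8 * 32 = 1024; if any
+// input range is undefined, the fused range is left undefined as well.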
+
+IterVar FuseStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                      StageToAxesMap* stage_to_axes) const {
+  auto stage = (*stages)[stage_id];
+  const Array<IterVar>& axes = stage_to_axes->at(stage);
+
+  Array<IterVar> to_fuse;
+  for (const auto& i : fused_ids) {
+    to_fuse.push_back(axes[i]);
+  }
+  IterVar fused_axis;
+  stage.fuse(to_fuse, &fused_axis);
+
+  Array<IterVar> new_axes;
+  new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids.front());
+  new_axes.push_back(fused_axis);
+  new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, axes.end());
+
+  stage_to_axes->Set(stage, std::move(new_axes));
+  stages->Set(stage_id, std::move(stage));
+  return fused_axis;
+}
+
+String FuseStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                      StageToAxesMap* stage_to_axes) const {
+  const auto& stage = (*stages)[stage_id];
+  std::stringstream to_fuse;
+
+  for (size_t i = 0; i < fused_ids.size(); ++i) {
+    to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i]]->var->name_hint);
+    if (i != fused_ids.size() - 1) {
+      to_fuse << ", ";
+    }
+  }
+
+  std::stringstream ss;
+  const auto& fused = ApplyToSchedule(stages, stage_to_axes);
+
+  ss << CleanName(fused->var->name_hint) << " = s[" << CleanName(stage->op->name) << "].fuse("
+     << to_fuse.str() << ")\n";
+
+  return ss.str();
+}
+
 /********** Reorder **********/
 ReorderStep::ReorderStep(int stage_id, const Array<Integer>& after_ids) {
   auto node = make_object<ReorderStepNode>();
@@ -47,6 +447,41 @@ ReorderStep::ReorderStep(int stage_id, const Array<Integer>& after_ids) {
   data_ = std::move(node);
 }
 
+ReorderStep::ReorderStep(dmlc::JSONReader* reader) {
+  auto node = make_object<ReorderStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::vector<int> int_list;
+  reader->Read(&int_list);
+  ::tvm::Array<::tvm::Integer> after_ids;
+  for (const auto& i : int_list) {
+    after_ids.push_back(i);
+  }
+  node->after_ids = after_ids;
+  data_ = std::move(node);
+}
+
+void ReorderStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(IntArrayToVector(after_ids));
+}
+
+void ReorderStepNode::ApplyToState(State* state) const {
+  const Stage& stage = (*state)->stages[stage_id];
+  Array<Iterator> iters;
+  for (auto x : after_ids) {
+    iters.push_back(stage->iters[x]);
+  }
+  state->CopyOnWrite()->stages.Set(
+      stage_id, Stage(stage->op, stage->op_type, iters, stage->compute_at, stage->attrs));
+}
+
 void ReorderStepNode::ApplyToSchedule(Array<te::Stage>* stages,
                                       StageToAxesMap* stage_to_axes) const {
   auto stage = (*stages)[stage_id];
@@ -83,6 +518,86 @@ String ReorderStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
 }
 
 /********** Split **********/
+// common part for SplitStep, FollowSplitStep, and FollowFusedSplitStep
+Array<Iterator> ApplySplitToState(State* state, int stage_id, int iter_id,
+                                  const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
+  const Stage& stage = (*state)->stages[stage_id];
+  const Iterator& it = stage->iters[iter_id];
+  size_t old_iter_size = stage->iters.size();
+  bool concrete = true;
+
+  Optional<PrimExpr> tosplit_min, tosplit_extent;
+  if (it->range.defined()) {
+    tosplit_min = it->range->min;
+    tosplit_extent = it->range->extent;
+  } else {
+    tosplit_min = NullOpt;
+    tosplit_extent = NullOpt;
+  }
+
+  Array<Iterator> outs;
+  for (size_t i = 0; i < lengths.size(); ++i) {
+    Optional<Integer> l;
+    String name;
+    if (inner_to_outer) {
+      l = lengths[lengths.size() - i - 1];
+      name = it->name + "." + std::to_string(lengths.size() - i);
+    } else {
+      l = lengths[i];
+      name = it->name + "." + std::to_string(i);
+    }
+    Iterator res;
+    if (l && tosplit_min && tosplit_extent) {
+      res = Iterator(name, Range::FromMinExtent(tosplit_min.value(), l.value()), it->iter_kind,
+                     IteratorAnnotation::kNone);
+      tosplit_min = Integer(0);
+      tosplit_extent = indexdiv(tosplit_extent.value() + l.value() - 1, l.value());
+    } else {
+      res = Iterator(name, Range(), it->iter_kind, IteratorAnnotation::kNone);
+      tosplit_min = NullOpt;
+      tosplit_extent = NullOpt;
+      concrete = false;
+    }
+    outs.push_back(std::move(res));
+  }
+
+  Range range;
+  if (tosplit_min && tosplit_extent) {
+    range = Range::FromMinExtent(tosplit_min.value(), tosplit_extent.value());
+  }
+  if (inner_to_outer) {
+    outs.push_back(Iterator(it->name + ".0", range, it->iter_kind, IteratorAnnotation::kNone));
+    // Reverse the Iterator array
+    Array<Iterator> temp(outs.rbegin(), outs.rend());
+    outs = std::move(temp);
+  } else {
+    outs.push_back(Iterator(it->name + "." + std::to_string(lengths.size()), range, it->iter_kind,
+                            IteratorAnnotation::kNone));
+  }
+
+  Array<Iterator> new_iters;
+  new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id);
+  new_iters.insert(new_iters.end(), outs.begin(), outs.end());
+  new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, stage->iters.end());
+
+  StateNode* pstate = state->CopyOnWrite();
+  pstate->stages.Set(stage_id,
+                     Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs));
+  pstate->concrete &= concrete;
+
+  // Two vectors are used to represent the iterator relation before and after split
+  // The original iterators in AttachMap will be updated with the new iterators
+  std::vector<IterKey> from_iters;
+  std::vector<IterKey> to_iters;
+  for (size_t i = iter_id; i < old_iter_size; ++i) {
+    from_iters.emplace_back(stage_id, i);
+    to_iters.emplace_back(stage_id, i + lengths.size());
+  }
+  pstate->attach_map.UpdateIters(from_iters, to_iters);
+
+  return outs;
+}
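+
+// Worked example (illustrative): splitting an extent-512 iterator "i" with
+// lengths [4, 8, 8] and inner_to_outer = true yields iterators
+//   i.0, i.1, i.2, i.3  with extents  2, 4, 8, 8
+// where the leftover outer extent comes from repeated ceiling division:
+// ceil(512 / 8) = 64, ceil(64 / 8) = 8, ceil(8 / 4) = 2.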
+
 Array<IterVar> ApplySplitToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
                                     int stage_id, int iter_id,
                                     const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
@@ -171,6 +686,51 @@ SplitStep::SplitStep(int stage_id, int iter_id, Optional<PrimExpr> extent,
   data_ = std::move(node);
 }
 
+SplitStep::SplitStep(dmlc::JSONReader* reader) {
+  auto node = make_object<SplitStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  int int_val;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&int_val);
+  if (int_val) {
+    node->extent = Integer(int_val);
+  }
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::vector<int> int_list;
+  reader->Read(&int_list);
+  ::tvm::Array<::tvm::Optional<::tvm::Integer>> lengths;
+  for (const auto& i : int_list) {
+    lengths.push_back(::tvm::Integer(i));
+  }
+  node->lengths = lengths;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->inner_to_outer);
+  data_ = std::move(node);
+}
+
+void SplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(extent ? GetIntImm(extent.value()) : 0);
+  writer->WriteArrayItem(IntArrayToVector(lengths));
+  writer->WriteArrayItem(static_cast<int>(inner_to_outer));
+}
+
+Array<Iterator> SplitStepNode::ApplyToState(State* state) const {
+  return ApplySplitToState(state, stage_id, iter_id, lengths, inner_to_outer);
+}
+
 Array<IterVar> SplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
                                               StageToAxesMap* stage_to_axes) const {
   return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer);
@@ -181,57 +741,185 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer);
 }
 
-/********** Fuse **********/
-FuseStep::FuseStep(int stage_id, const Array<Integer>& fused_ids) {
-  auto node = make_object<FuseStepNode>();
+/********** Primitives working on multiple stages **********/
+
+/********** Compute At **********/
+ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id) {
+  auto node = make_object<ComputeAtStepNode>();
   node->stage_id = stage_id;
-  for (const auto& x : fused_ids) {
-    CHECK(x->IsInstance<IntImmNode>());
-  }
-  node->fused_ids = fused_ids;
+  node->target_stage_id = target_stage_id;
+  node->target_iter_id = target_iter_id;
   data_ = std::move(node);
 }
 
-IterVar FuseStepNode::ApplyToSchedule(Array<te::Stage>* stages,
-                                      StageToAxesMap* stage_to_axes) const {
-  auto stage = (*stages)[stage_id];
-  const Array<IterVar>& axes = stage_to_axes->at(stage);
+ComputeAtStep::ComputeAtStep(dmlc::JSONReader* reader) {
+  auto node = make_object<ComputeAtStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->target_stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->target_iter_id);
+  data_ = std::move(node);
+}
 
-  Array<IterVar> to_fuse;
-  for (const auto& i : fused_ids) {
-    to_fuse.push_back(axes[i]);
+void ComputeAtStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(target_stage_id);
+  writer->WriteArrayItem(target_iter_id);
+}
+void ComputeAtStepNode::ApplyToState(State* state) const {
+  const Stage& stage = (*state)->stages[stage_id];
+
+  // Remove the bound information of each iterator since it may no longer be accurate after
+  // compute_at
+  Array<Iterator> new_iters;
+  for (const Iterator& it : stage->iters) {
+    new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation));
   }
-  IterVar fused_axis;
-  stage.fuse(to_fuse, &fused_axis);
 
-  Array<IterVar> new_axes;
-  new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids.front());
-  new_axes.push_back(fused_axis);
-  new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, axes.end());
+  StateNode* pstate = state->CopyOnWrite();
+  pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters),
+                                     ComputeAtKind::kIter, stage->attrs));
+  // Update attach map
+  pstate->attach_map.SetComputeAtIter(stage_id, target_stage_id, target_iter_id);
+}
+
+void ComputeAtStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                        StageToAxesMap* stage_to_axes) const {
+  te::Stage stage = (*stages)[stage_id];
+  const auto& target_stage = (*stages)[target_stage_id];
+  const auto& target_axis = (*stage_to_axes)[target_stage][target_iter_id];
+  stage.compute_at(target_stage, target_axis);
 
-  stage_to_axes->Set(stage, std::move(new_axes));
   stages->Set(stage_id, std::move(stage));
-  return fused_axis;
 }
 
-String FuseStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
-                                      StageToAxesMap* stage_to_axes) const {
+String ComputeAtStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                           StageToAxesMap* stage_to_axes) const {
+  std::stringstream ss;
   const auto& stage = (*stages)[stage_id];
-  std::stringstream to_fuse;
+  const auto& target_stage = (*stages)[target_stage_id];
+  ss << "s[" << CleanName(stage->op->name) << "].compute_at(s[" << CleanName(target_stage->op->name)
+     << "], " << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint) << ")\n";
+  ApplyToSchedule(stages, stage_to_axes);
+  return ss.str();
+}
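+
+// Example of the generated Python (stage/iterator names are illustrative):
+//   s[conv].compute_at(s[relu], yy)
+// where yy is the name of the target iterator in the relu stage.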
 
-  for (size_t i = 0; i < fused_ids.size(); ++i) {
-    to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i]]->var->name_hint);
-    if (i != fused_ids.size() - 1) {
-      to_fuse << ", ";
-    }
+/********** Compute Inline **********/
+ComputeInlineStep::ComputeInlineStep(int stage_id) {
+  auto node = make_object<ComputeInlineStepNode>();
+  node->stage_id = stage_id;
+  data_ = std::move(node);
+}
+
+ComputeInlineStep::ComputeInlineStep(dmlc::JSONReader* reader) {
+  auto node = make_object<ComputeInlineStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  data_ = std::move(node);
+}
+
+void ComputeInlineStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+}
+
+void ComputeInlineStepNode::ApplyToState(State* state) const {
+  const Stage& stage = (*state)->stages[stage_id];
+
+  // Check the validity of compute_inline
+  for (size_t i = 0; i < stage->iters.size(); ++i) {
+    CHECK_EQ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, i)), 0)
+        << "Invalid compute_inline: There are some other stages that are attached to the "
+        << "target stage";
   }
 
+  StateNode* pstate = state->CopyOnWrite();
+  auto new_stage = pstate->stages[stage_id];
+  new_stage.CopyOnWrite()->compute_at = ComputeAtKind::kInlined;
+  pstate->stages.Set(stage_id, std::move(new_stage));
+  // Update attach map
+  pstate->attach_map.DeleteStage(stage_id);
+}
+
+void ComputeInlineStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                            StageToAxesMap* stage_to_axes) const {
+  auto stage = (*stages)[stage_id];
+  stage.compute_inline();
+  stages->Set(stage_id, std::move(stage));
+}
+
+String ComputeInlineStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                               StageToAxesMap* stage_to_axes) const {
   std::stringstream ss;
-  const auto& fused = ApplyToSchedule(stages, stage_to_axes);
+  const auto& stage = (*stages)[stage_id];
+  ss << "s[" << CleanName(stage->op->name) << "].compute_inline()\n";
+  ApplyToSchedule(stages, stage_to_axes);
+  return ss.str();
+}
 
-  ss << CleanName(fused->var->name_hint) << " = s[" << CleanName(stage->op->name) << "].fuse("
-     << to_fuse.str() << ")\n";
+/********** Compute Root **********/
+ComputeRootStep::ComputeRootStep(int stage_id) {
+  auto node = make_object<ComputeRootStepNode>();
+  node->stage_id = stage_id;
+  data_ = std::move(node);
+}
 
+ComputeRootStep::ComputeRootStep(dmlc::JSONReader* reader) {
+  auto node = make_object<ComputeRootStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  data_ = std::move(node);
+}
+
+void ComputeRootStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+}
+
+void ComputeRootStepNode::ApplyToState(State* state) const {
+  const Stage& stage = (*state)->stages[stage_id];
+
+  // Remove the bound information of each iterator since it may no longer be accurate after
+  // compute_root
+  Array<Iterator> new_iters;
+  for (const Iterator& it : stage->iters) {
+    new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation));
+  }
+
+  StateNode* pstate = state->CopyOnWrite();
+  pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters),
+                                     ComputeAtKind::kRoot, stage->attrs));
+  // Update attach map
+  pstate->attach_map.DeleteStage(stage_id);
+}
+
+void ComputeRootStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                          StageToAxesMap* stage_to_axes) const {
+  auto stage = (*stages)[stage_id];
+  stage.compute_root();
+  stages->Set(stage_id, std::move(stage));
+}
+
+String ComputeRootStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                             StageToAxesMap* stage_to_axes) const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+  ss << "s[" << CleanName(stage->op->name) << "].compute_root()\n";
+  ApplyToSchedule(stages, stage_to_axes);
   return ss.str();
 }
 
diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h
index d840cc0..ce3ca50 100644
--- a/src/auto_scheduler/transform_step.h
+++ b/src/auto_scheduler/transform_step.h
@@ -20,29 +20,34 @@
 /*!
  * \file auto_scheduler/transform_step.h
  * \brief Transformation steps. For each schedule primitive, there is a corresponding transform
- * step. The implementation of each step consists of 2 parts:
- * - transform_step.cc: How each step interacts with TE and TE's schedule primitives
- * - loop_state.cc:     How each step updates LoopState
+ * step.
  *
  * \note To add a new transform step:
  * Take fuse step for example:
- * 1. Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its construction
- *    function `FuseStep::FuseStep(...)` in `transform_steps.cc`
- * 2. Implement `FuseStepNode::ApplyToSchedule` and `FuseStepNode::PrintAsPythonAPI`.
+ * 1. Define class `FuseStepNode`, `FuseStep` in `transform_step.h`, and implement its first
+ *    construction function `FuseStep::FuseStep()` in `transform_step.cc`.
+ * 2. Implement `FuseStepNode::ApplyToSchedule()` and `FuseStepNode::PrintAsPythonAPI()`.
  *    - In these two functions you need to lower this step with tvm's te schedule API
- * 3. Implement `State::fuse` and `State::DoFuseStep`.
+ * 3. Implement `FuseStepNode::ApplyToState` and the state API `State::fuse`.
  *    - In these two functions you need to incrementally update all data structures in State with
- *      CopyOnWrite style
- * 4. Add you step to `ComputeDAG::ApplySteps` and make sure it works.
- * 5. Add log record serialization support in `struct Handler<Array<::tvm::auto_scheduler::Step>>`
- *    in `record.cc`.
- * 6. Add its corresponding Python API to `loop_state.py` and necessary unit test.
+ *      CopyOnWrite style.
+ * 4. Add your step implementation to `StepApplyToState`, `StepApplyToSchedule` and
+ *    `StepPrintAsPythonAPI`, and make sure it works.
+ * 5. Log record serialization support:
+ *    - Add `FuseStepNode::WriteToRecord`, which takes a mutable JSONWriter pointer as input and
+ *      writes the record to it.
+ *    - Add another construction function that takes a mutable JSONReader as input; it reads a
+ *      step record from the reader and creates the step.
+ *    - Add the step implementation to `StepReadFromRecord`.
+ * 6. Add its corresponding Python API to `loop_state.py` and the necessary unit tests. The tests
+ *    should consist of at least two parts: a functional test and a record serialization test
+ *    (see the skeleton sketched below).
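+ *
+ * A rough skeleton of such a new step, with all names hypothetical and for
+ * illustration only:
+ *
+ * \code
+ * class MyStepNode : public StepNode {
+ *  public:
+ *   void WriteToRecord(dmlc::JSONWriter* writer) const final;
+ *   void ApplyToState(State* state) const;
+ *   void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+ *   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+ *
+ *   static constexpr const char* record_prefix_str = "MY";
+ *   static constexpr const char* _type_key = "auto_scheduler.MyStep";
+ *   TVM_DECLARE_FINAL_OBJECT_INFO(MyStepNode, Object);
+ * };
+ *
+ * class MyStep : public Step {
+ *  public:
+ *   explicit MyStep(int stage_id);              // normal constructor
+ *   explicit MyStep(dmlc::JSONReader* reader);  // record deserialization
+ *   TVM_DEFINE_OBJECT_REF_METHODS(MyStep, Step, MyStepNode);
+ * };
+ * \endcode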
  */
 
 #ifndef TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_
 #define TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_
 
 #include <dmlc/common.h>
+#include <dmlc/json.h>
 #include <tvm/node/node.h>
 #include <tvm/te/schedule.h>
 
@@ -53,6 +58,92 @@ namespace auto_scheduler {
 
 typedef Map<tvm::te::Stage, Array<tir::IterVar>, ObjectHash, ObjectEqual> StageToAxesMap;
 
+/*! \brief The type of an iterator. */
+enum class IteratorKind : int {
+  /*! \brief Spatial iterator. */
+  kSpatial = 0,
+  /*! \brief Reduction iterator. */
+  kReduction = 1,
+  /*! \brief Fused spatial and reduction iterator. */
+  kMixed = 2,
+  /*! \brief Special iterator. (e.g. virtual root iterator) */
+  kSpecial = 3
+};
+
+/*! \brief The type of an iterator's annotation. */
+enum class IteratorAnnotation : int {
+  /*! \brief This iterator has no annotation. */
+  kNone = 0,
+  /*! \brief This iterator has been unrolled. */
+  kUnroll = 1,
+  /*! \brief This iterator has been vectorized. */
+  kVectorize = 2,
+  /*! \brief This iterator has been parallelized. */
+  kParallel = 3,
+  /*! \brief This iterator has been bound to vthread. */
+  kVThread = 4,
+  /*! \brief This iterator has been bound to blockIdx.x. */
+  kBlockX = 5,
+  /*! \brief This iterator has been bound to threadIdx.x. */
+  kThreadX = 6,
+  /*! \brief This iterator has been bound to blockIdx.y. */
+  kBlockY = 7,
+  /*! \brief This iterator has been bound to threadIdx.y. */
+  kThreadY = 8,
+  /*! \brief This iterator has been bound to blockIdx.z. */
+  kBlockZ = 9,
+  /*! \brief This iterator has been bound to threadIdx.z. */
+  kThreadZ = 10,
+  /*! \brief This iterator has been mapped with a tensorize intrinsic. */
+  kTensorize = 11
+};
+
+extern const char* IteratorAnnotationString[];
+
+/*!
+ * \brief A for-loop iterator.
+ * Similar to tvm::IterVar in `include/tvm/tir/expr.h`.
+ */
+class IteratorNode : public Object {
+ public:
+  /*! \brief The name of this iterator. */
+  String name;
+  /*! \brief The range of this iterator. */
+  Range range;
+  /*! \brief The iterator type of this iterator. */
+  IteratorKind iter_kind;
+  /*! \brief The annotation type of this iterator. */
+  IteratorAnnotation annotation;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("name", &name);
+    v->Visit("range", &range);
+    v->Visit("iter_kind", &iter_kind);
+    v->Visit("annotation", &annotation);
+  }
+
+  static constexpr const char* _type_key = "auto_scheduler.Iterator";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object);
+};
+
+/*!
+ * \brief Managed reference to IteratorNode.
+ * \sa IteratorNode
+ */
+class Iterator : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param name The name of this iterator.
+   * \param range The range of this iterator.
+   * \param iter_kind The iterator type of this iterator.
+   * \param annotation The annotation type of this iterator.
+   */
+  Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode);
+};
+
 /*!
  * \brief The base class of transformation steps. Each step has its corresponding tvm.te
  * schedule primitives.
@@ -62,6 +153,12 @@ class StepNode : public Object {
   /*! \brief The index of the stage. */
   int stage_id;
 
+  /*!
+   * \brief Serialize the current step record to JSONWriter.
+   * \param writer The output JSONWriter.
+   */
+  virtual void WriteToRecord(dmlc::JSONWriter* writer) const = 0;
+
   static constexpr const char* _type_key = "auto_scheduler.Step";
   TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object);
 };
@@ -75,6 +172,172 @@ class Step : public ObjectRef {
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode);
 };
 
+// Forward declaration
+class State;
+class ComputeDAG;
+
+/*!
+ * \brief Read a step record from JSONReader and create the corresponding step.
+ * \param reader The input JSONReader.
+ */
+Step StepReadFromRecord(dmlc::JSONReader* reader);
+
+/*!
+ * \brief Apply the step to State.
+ * \param step The step to be applied to State.
+ * \param state A mutable pointer to State.
+ * \param dag The original ComputeDAG of this state.
+ */
+void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag);
+
+/*!
+ * \brief Apply the step to tvm.schedule.
+ * \param step The step to be applied to tvm.schedule.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ */
+void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes);
+
+/*!
+ * \brief Print the step as equivalent python schedule API.
+ * \param step The step to be applied to python API.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \return Python schedule code.
+ */
+String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
+                            StageToAxesMap* stage_to_axes);
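+
+// Hypothetical usage sketch (variable names are illustrative), in the spirit
+// of ComputeDAG::ApplySteps: replay recorded steps onto a te schedule.
+//
+//   Array<te::Stage> stages;
+//   StageToAxesMap stage_to_axes;
+//   // ... initialize stages / stage_to_axes from the original schedule ...
+//   for (const Step& step : transform_steps) {
+//     StepApplyToSchedule(step, &stages, &stage_to_axes);
+//   }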
+
+/********** Primitives working on single stage **********/
+
+/*!
+ * \brief Annotation step that corresponds to vectorize, parallel, unroll and thread binding.
+ * (i.e. te::Stage::vectorize, te::Stage::parallel, te::Stage::unroll, te::Stage::bind)
+ */
+class AnnotationStepNode : public StepNode {
+ public:
+  /*! \brief The index of the iterator to add annotation. */
+  int iter_id;
+  /*! \brief The annotation type of this step. */
+  IteratorAnnotation annotation;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   * \return The resulting iterator after annotation.
+   */
+  Iterator ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   */
+  void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  static constexpr const char* record_prefix_str = "AN";
+
+  static constexpr const char* _type_key = "auto_scheduler.AnnotationStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to AnnotationStepNode.
+ * \sa AnnotationStepNode
+ */
+class AnnotationStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to add annotation.
+   * \param iter_id The index of the iterator to add annotation.
+   * \param ann The annotation type of this step.
+   */
+  AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit AnnotationStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(AnnotationStep, Step, AnnotationStepNode);
+};
+
+/*! \brief Fuse step that corresponds to te::Stage::fuse */
+class FuseStepNode : public StepNode {
+ public:
+  /*! \brief The ids of iterators to fuse. */
+  Array<Integer> fused_ids;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   * \return The iterator result after fuse.
+   * \note If the iterators to be fused have stages attached to them (by compute_at), the fused
+   * result will become the new attach point.
+   */
+  Iterator ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \return The iterator result after fuse.
+   */
+  tir::IterVar ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  static constexpr const char* record_prefix_str = "FU";
+
+  static constexpr const char* _type_key = "auto_scheduler.FuseStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to FuseStepNode.
+ * \sa FuseStepNode
+ */
+class FuseStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be fused.
+   * \param fused_ids The index of the iterators to be fused.
+   */
+  FuseStep(int stage_id, const Array<Integer>& fused_ids);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit FuseStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode);
+};
+
 /*! \brief Reorder step that corresponds to te::Stage::reorder */
 class ReorderStepNode : public StepNode {
  public:
@@ -84,21 +347,31 @@ class ReorderStepNode : public StepNode {
    */
   Array<Integer> after_ids;
 
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   */
+  void ApplyToState(State* state) const;
+
   /*!
-   * \brief Apply the current state to tvm.schedule
+   * \brief Apply the current step to tvm.schedule.
    * \param stages A pointer to a `te::Stage` Array.
    * \param stage_to_axes A pointer to a StageToAxesMap.
    */
   void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
   /*!
-   * \brief Print step as equivalent python schedule API.
+   * \brief Print the current step as equivalent python schedule API.
    * \param stages A pointer to a `te::Stage` Array.
    * \param stage_to_axes A pointer to a StageToAxesMap.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
+  static constexpr const char* record_prefix_str = "RE";
+
   static constexpr const char* _type_key = "auto_scheduler.ReorderStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object);
 };
@@ -116,6 +389,13 @@ class ReorderStep : public Step {
    */
   ReorderStep(int stage_id, const Array<Integer>& after_ids);
 
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit ReorderStep(dmlc::JSONReader* reader);
+
   TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode);
 };
 
@@ -137,8 +417,19 @@ class SplitStepNode : public StepNode {
    */
   bool inner_to_outer;
 
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
   /*!
-   * \brief Apply the current state to tvm.schedule
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   * \return The iterator results after split.
+   * \note If we split an iterator that has stages attached to it (by compute_at), the innermost
+   * iterator of the split results will become the new attach point.
+   */
+  Array<Iterator> ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
    * \param stages A pointer to a `te::Stage` Array.
    * \param stage_to_axes A pointer to a StageToAxesMap.
    * \return The iterator results after split.
@@ -147,13 +438,15 @@ class SplitStepNode : public StepNode {
                                       StageToAxesMap* stage_to_axes) const;
 
   /*!
-   * \brief Print step as equivalent python schedule API.
+   * \brief Print the current step as equivalent python schedule API.
    * \param stages A pointer to a `te::Stage` Array.
    * \param stage_to_axes A pointer to a StageToAxesMap.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
+  static constexpr const char* record_prefix_str = "SP";
+
   static constexpr const char* _type_key = "auto_scheduler.SplitStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object);
 };
@@ -175,49 +468,195 @@ class SplitStep : public Step {
   SplitStep(int stage_id, int iter_id, Optional<PrimExpr> extent,
             const Array<Optional<Integer>>& lengths, bool inner_to_outer);
 
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit SplitStep(dmlc::JSONReader* reader);
+
   TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
 };
 
-/*! \brief Fuse step that corresponds to te::Stage::fuse */
-class FuseStepNode : public StepNode {
+/********** Primitives working on multiple stages **********/
+
+/*! \brief Compute at step that corresponds to te::Stage::compute_at */
+class ComputeAtStepNode : public StepNode {
  public:
-  /*! \brief The ids of iterators to fuse. */
-  Array<Integer> fused_ids;
+  /*! \brief The index of the target stage that this step will compute at. */
+  int target_stage_id;
+  /*! \brief The index of the target iterator in the target stage. */
+  int target_iter_id;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   * \note After compute_at, we need careful dependency analysis to compute the accurate bound
+   * information. However, it is relatively expensive and complicated, so we just fill "None" as
+   * bound for the newly created iterators.
+   * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
+   */
+  void ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   */
+  void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  static constexpr const char* record_prefix_str = "CA";
+
+  static constexpr const char* _type_key = "auto_scheduler.ComputeAtStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to ComputeAtStepNode.
+ * \sa ComputeAtStepNode
+ */
+class ComputeAtStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be computed at the target.
+   * \param target_stage_id The index of the target stage.
+   * \param target_iter_id The index of the target iterator in the target stage.
+   */
+  ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit ComputeAtStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode);
+};
+
+/*! \brief Compute inline step that corresponds to te::Stage::compute_inline */
+class ComputeInlineStepNode : public StepNode {
+ public:
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
 
   /*!
-   * \brief Apply the current state to tvm.schedule
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   */
+  void ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
    * \param stages A pointer to a `te::Stage` Array.
    * \param stage_to_axes A pointer to a StageToAxesMap.
-   * \return The iterator result after fuse.
    */
-  tir::IterVar ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+  void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
   /*!
-   * \brief Print step as equivalent python schedule API.
+   * \brief Print the current step as equivalent python schedule API.
    * \param stages A pointer to a `te::Stage` Array.
    * \param stage_to_axes A pointer to a StageToAxesMap.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
-  static constexpr const char* _type_key = "auto_scheduler.FuseStep";
-  TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object);
+  static constexpr const char* record_prefix_str = "CI";
+
+  static constexpr const char* _type_key = "auto_scheduler.ComputeInlineStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object);
 };
 
 /*!
- * \brief Managed reference to FuseStepNode.
- * \sa FuseStepNode
+ * \brief Managed reference to ComputeInlineStepNode.
+ * \sa ComputeInlineStepNode
  */
-class FuseStep : public Step {
+class ComputeInlineStep : public Step {
  public:
   /*!
    * \brief The constructor.
-   * \param stage_id The index of the stage to be fused.
-   * \param fused_ids The index of the iterators to be fused.
+   * \param stage_id The index of the stage to be computed inline.
    */
-  FuseStep(int stage_id, const Array<Integer>& fused_ids);
+  explicit ComputeInlineStep(int stage_id);
 
-  TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode);
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit ComputeInlineStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ComputeInlineStep, Step, ComputeInlineStepNode);
+};
+
+/*! \brief Compute root step that corresponds to te::Stage::compute_root */
+class ComputeRootStepNode : public StepNode {
+ public:
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   * \note After compute_root, we need careful dependency analysis to compute the accurate bound
+   * information. However, it is relatively expensive and complicated, so we just fill "None" as
+   * bound for the newly created iterators.
+   * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
+   */
+  void ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   */
+  void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  static constexpr const char* record_prefix_str = "CR";
+
+  static constexpr const char* _type_key = "auto_scheduler.ComputeRootStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to ComputeRootStepNode.
+ * \sa ComputeRootStepNode
+ */
+class ComputeRootStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be computed at root.
+   */
+  explicit ComputeRootStep(int stage_id);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit ComputeRootStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode);
 };
 
 }  // namespace auto_scheduler
diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h
index 5637780..de800da 100644
--- a/src/auto_scheduler/utils.h
+++ b/src/auto_scheduler/utils.h
@@ -63,7 +63,7 @@ struct hash<std::tuple<T1, T2, T3>> {
 namespace tvm {
 namespace auto_scheduler {
 
-/********** Utilities for Array, std::string **********/
+/********** Utilities for Array, std::vector, std::string **********/
 /*! \brief Get the first appearance index of elements in an Array */
 template <typename T>
 inline void GetIndices(const Array<T>& array, const Array<T>& to_locate, Array<Integer>* indices) {
@@ -89,6 +89,15 @@ inline int GetIndex(const Array<T>& array, const T& to_locate) {
   return -1;
 }
 
+/*! \brief Delete the item in a std::vector if it exists. */
+template <typename T>
+inline void FindAndDeleteItem(std::vector<T>* array, const T& to_delete) {
+  auto iter = std::find(array->begin(), array->end(), to_delete);
+  if (iter != array->end()) {
+    array->erase(iter);
+  }
+}
+
 /*! \brief Replace a sub-string to another sub-string in a string */
 inline void StrReplace(std::string* base, const std::string& from, const std::string& to) {
   auto pos = base->find(from);
@@ -98,6 +107,27 @@ inline void StrReplace(std::string* base, const std::string& from, const std::st
   }
 }
 
+/*! \brief Convert an Array<Integer> to std::vector<int>. */
+inline std::vector<int> IntArrayToVector(const ::tvm::Array<::tvm::Integer>& data) {
+  std::vector<int> out;
+  for (const auto& x : data) {
+    CHECK(x.defined());
+    out.push_back(x);
+  }
+  return out;
+}
+
+/*! \brief Convert an Array<Optional<Integer>> to std::vector<int>. */
+inline std::vector<int> IntArrayToVector(
+    const ::tvm::Array<::tvm::Optional<::tvm::Integer>>& data) {
+  std::vector<int> out;
+  for (const auto& x : data) {
+    CHECK(x);
+    out.push_back(x.value());
+  }
+  return out;
+}
+
 /********** Utilities for TVM Containers / ByteArray **********/
 /*! \brief Compute mean of a FloatImm array */
 inline double FloatArrayMean(const Array<PrimExpr>& float_array) {
diff --git a/tests/python/unittest/test_auto_scheduler_common.py b/tests/python/unittest/test_auto_scheduler_common.py
index 078e1ae..fa22fdc 100644
--- a/tests/python/unittest/test_auto_scheduler_common.py
+++ b/tests/python/unittest/test_auto_scheduler_common.py
@@ -40,7 +40,7 @@ def matmul_auto_scheduler_test_rename_0(N, M, K):
     C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
     return [A, B, C]
 
-
+@auto_scheduler.register_workload
 def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1):
     data = te.placeholder((N, CI, H, W), name='Data')
     kernel = te.placeholder((CO, CI, kernel_size, kernel_size), name='Kernel')
diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py
index 0801d92..32ea8fa 100644
--- a/tests/python/unittest/test_auto_scheduler_loop_state.py
+++ b/tests/python/unittest/test_auto_scheduler_loop_state.py
@@ -26,8 +26,8 @@ import topi
 from test_auto_scheduler_common import matmul_auto_scheduler_test, conv2d_nchw_bn_relu
 
 
-def test_split_fuse_reorder():
-    A, B, C = matmul_auto_scheduler_test(512, 512, 512)
+def test_split_fuse_reorder_annotation():
+    A, B, C = matmul_auto_scheduler_test(N=512, M=512, K=512)
     dag = auto_scheduler.ComputeDAG([A, B, C])
     s0 = dag.get_init_state()
     i, j, k = s0[C].iters
@@ -61,5 +61,88 @@ def test_split_fuse_reorder():
     assert s1[C].iters[4].range.extent == 8
     assert s1[C].iters[5].range.extent == 2
 
+    res = s1.bind(C, i1, "blockIdx.x")
+    assert res == s1[C].iters[0]
+    assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["blockIdx.x"]
+
+    res = s1.bind(C, i2, "vthread")
+    assert res == s1[C].iters[1]
+    assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["vthread"]
+
+    res = s1.bind(C, i3, "threadIdx.y")
+    assert res == s1[C].iters[2]
+    assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["threadIdx.y"]
+
+    res = s1.parallel(C, j1)
+    assert res == s1[C].iters[3]
+    assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["parallel"]
+
+    res = s1.unroll(C, j2)
+    assert res == s1[C].iters[4]
+    assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["unroll"]
+
+    res = s1.vectorize(C, j3)
+    assert res == s1[C].iters[5]
+    assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["vectorize"]
+
+
+def test_compute_at_root_inline():
+    dag = auto_scheduler.ComputeDAG(conv2d_nchw_bn_relu(N=1, H=224, W=224, CI=3, CO=64,
+                                                        kernel_size=7, strides=2, padding=3))
+    s0 = dag.get_init_state()
+
+    # data, padding, kernel = 0, 1, 2
+    conv = s0.stage_ops[3]
+    # bias = 4
+    bias_add = s0.stage_ops[5]
+    # bn_scale = 6
+    bn_mul = s0.stage_ops[7]
+    # bn_offset = 8
+    bn_add = s0.stage_ops[9]
+    relu = s0.stage_ops[10]
+
+    s0.compute_inline(bn_add)
+    assert s0[bn_add].compute_at == 1
+
+    s0.compute_inline(bn_mul)
+    assert s0[bn_mul].compute_at == 1
+
+    s0.compute_inline(bias_add)
+    assert s0[bias_add].compute_at == 1
+
+    assert s0[conv].iters[0].range.extent == 1
+    assert s0[conv].iters[1].range.extent == 64
+    assert s0[conv].iters[2].range.extent == 112
+    assert s0[conv].iters[3].range.extent == 112
+    assert s0[conv].iters[4].range.extent == 3
+    assert s0[conv].iters[5].range.extent == 7
+    assert s0[conv].iters[6].range.extent == 7
+    s0.compute_at(conv, relu, s0[relu].iters[2])
+    assert s0[conv].compute_at == 2
+    s0 = dag.infer_bound_from_state(s0)
+    assert s0[conv].iters[0].range.extent == 1
+    assert s0[conv].iters[1].range.extent == 1
+    assert s0[conv].iters[2].range.extent == 1
+    assert s0[conv].iters[3].range.extent == 112
+    assert s0[conv].iters[4].range.extent == 3
+    assert s0[conv].iters[5].range.extent == 7
+    assert s0[conv].iters[6].range.extent == 7
+
+    s0.compute_root(bn_mul)
+    assert s0[bn_mul].compute_at == 0
+
+    s0.compute_root(conv)
+    assert s0[conv].compute_at == 0
+    s0 = dag.infer_bound_from_state(s0)
+    assert s0[conv].iters[0].range.extent == 1
+    assert s0[conv].iters[1].range.extent == 64
+    assert s0[conv].iters[2].range.extent == 112
+    assert s0[conv].iters[3].range.extent == 112
+    assert s0[conv].iters[4].range.extent == 3
+    assert s0[conv].iters[5].range.extent == 7
+    assert s0[conv].iters[6].range.extent == 7
+
+
 if __name__ == "__main__":
-    test_split_fuse_reorder()
+    test_split_fuse_reorder_annotation()
+    test_compute_at_root_inline()
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py
index d6e6c51..333d20e 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -18,7 +18,8 @@
 """ Test measurement and log serialization. """
 
 import tvm
-from tvm import auto_scheduler
+import topi
+from tvm import te, auto_scheduler
 import tempfile
 
 from test_auto_scheduler_common import get_tiled_matmul
@@ -28,7 +29,44 @@ def test_record():
     if not tvm.runtime.enabled("llvm"):
         return
 
-    dag, s = get_tiled_matmul()
+    A = te.placeholder((512, 512), name='A')
+    B = te.placeholder((512, 512), name='B')
+    k = te.reduce_axis((0, 512), name='k')
+    C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
+    D = topi.nn.relu(C)
+    k = te.reduce_axis((0, 512), name='k')
+    E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E')
+    F = topi.nn.relu(E)
+
+    dag = auto_scheduler.ComputeDAG([A, B, F])
+    s = dag.get_init_state()
+
+    # Split
+    its0 = s.split(C, s[C].iters[0], [4, 8, 8])
+    its1 = s.split(C, s[C].iters[4], [8, 4, 4])
+    # Reorder
+    s.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], s[C].iters[8],
+                  its1[3]])
+    # Fuse
+    s.fuse(C, [s[C].iters[0], s[C].iters[1], s[C].iters[2]])
+    # Compute at
+    s.split(F, s[F].iters[0], [2])
+    s.compute_at(E, F, s[F].iters[0])
+    # Compute inline
+    s.compute_inline(D)
+    # Compute root
+    s.compute_root(D)
+    # Parallel
+    s.parallel(C, s[C].iters[0])
+    # Thread bind (blockIdx & threadIdx are GPU-specific; used here just for record testing)
+    s.bind(C, s[C].iters[1], "blockIdx.x")
+    s.bind(C, s[C].iters[2], "threadIdx.z")
+    s.bind(C, s[C].iters[3], "vthread")
+    # Unroll
+    s.unroll(C, s[C].iters[4])
+    # Vectorize
+    s.vectorize(C, s[C].iters[6])
+
     target = tvm.target.create("llvm")
     task = auto_scheduler.SearchTask(dag, "test", target)