You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/04/07 12:07:04 UTC

[tvm] branch main updated: [M1b] Scaffolding ScheduleState data structure (#7765)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bf0f87d  [M1b] Scaffolding ScheduleState data structure (#7765)
bf0f87d is described below

commit bf0f87dcc6400ab15626538cdf694cb9fa07a7df
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Wed Apr 7 05:04:48 2021 -0700

    [M1b] Scaffolding ScheduleState data structure (#7765)
    
    Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
    Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
    Co-authored-by: Ruihang Lai <la...@qq.com>
    Co-authored-by: Hongyi Jin <32...@qq.com>
    Co-authored-by: Wuwei Lin <wu...@apache.org>
    Co-authored-by: Cody Yu <co...@gmail.com>
    Co-authored-by: Jared Roesch <ro...@gmail.com>
---
 include/tvm/runtime/object.h                     |  15 +
 include/tvm/tir/schedule/block_scope.h           | 271 +++++++
 include/tvm/tir/schedule/state.h                 | 216 ++++++
 python/tvm/tir/__init__.py                       |   3 +
 python/tvm/tir/schedule/__init__.py              |  21 +
 python/tvm/tir/schedule/_ffi_api_schedule.py     |  20 +
 python/tvm/tir/schedule/block_scope.py           | 152 ++++
 python/tvm/tir/schedule/state.py                 | 185 +++++
 python/tvm/tir/stmt.py                           |  19 +-
 src/printer/tvmscript_printer.cc                 |   1 +
 src/tir/analysis/var_touch.cc                    |   2 +-
 src/tir/ir/stmt.cc                               |   6 +-
 src/tir/schedule/analysis.h                      |  47 ++
 src/tir/schedule/analysis/analysis.cc            |  60 ++
 src/tir/schedule/analysis/verify.cc              | 146 ++++
 src/tir/schedule/block_scope.cc                  | 162 +++++
 src/tir/schedule/state.cc                        | 870 +++++++++++++++++++++++
 src/tir/schedule/utils.h                         |  93 +++
 tests/python/unittest/test_tir_block_scope.py    | 145 ++++
 tests/python/unittest/test_tir_schedule_state.py | 352 +++++++++
 20 files changed, 2776 insertions(+), 10 deletions(-)

diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h
index 048fc1d..f13bdee 100644
--- a/include/tvm/runtime/object.h
+++ b/include/tvm/runtime/object.h
@@ -739,6 +739,21 @@ struct ObjectPtrEqual {
   ObjectName* operator->() const { return static_cast<ObjectName*>(data_.get()); }          \
   using ContainerType = ObjectName;
 
+/*
+ * \brief Define object reference methods that are both not nullable and mutable.
+ *
+ * \param TypeName The object type name
+ * \param ParentType The parent type of the objectref
+ * \param ObjectName The type name of the object.
+ */
+#define TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
+  explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \
+  TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName);                                        \
+  ObjectName* operator->() const { return static_cast<ObjectName*>(data_.get()); }          \
+  ObjectName* get() const { return operator->(); }                                          \
+  static constexpr bool _type_is_nullable = false;                                          \
+  using ContainerType = ObjectName;
+
 /*!
  * \brief Define CopyOnWrite function in an ObjectRef.
  * \param ObjectName The Type of the Node.
diff --git a/include/tvm/tir/schedule/block_scope.h b/include/tvm/tir/schedule/block_scope.h
new file mode 100644
index 0000000..49d5e7f
--- /dev/null
+++ b/include/tvm/tir/schedule/block_scope.h
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/block_scope.h
+ * \brief Definition of two pillar data structures for TensorIR scheduling: StmtSRef, BlockScope.
+ * \sa StmtSRefNode
+ * \sa BlockScopeNode
+ */
+#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref".
+ *
+ * Glossary
+ * - Block sref: A StmtSRef that points to a TensorIR block.
+ * - Loop sref: A StmtSRef that points to a TensorIR for loop.
+ * - Parent sref: The parent sref of an sref is the block or loop sref that points to the closest
+ * schedulable statement among its ancestors. We define "closest" to be the nearest schedulable
+ * statement (block or for-loop) found when walking up from the statement through its ancestors
+ * on the TensorIR AST.
+ * - Root sref: Sref to the root block. Every sref has exactly one parent sref except for root sref.
+ * - Sref tree: The parent-children-relationship of srefs that forms a tree, uniquely determined by
+ * the TensorIR AST.
+ */
+class StmtSRefNode : public Object {
+ public:
+  /*!
+   * \brief The block or `for` stmt the object refers to
+   * \note Non-owned reference (raw pointer) is used here, so that we can perform copy-on-write
+   * optimization on statements when possible. The strong reference is held in the ScheduleState.
+   */
+  const StmtNode* stmt;
+  /*! \brief The parent sref. */
+  StmtSRefNode* parent;
+  /*!
+   * \brief If the statement the sref points to is an element of a SeqStmt in the AST,
+   * then `seq_index` is set to its index; otherwise `seq_index` is -1
+   */
+  int64_t seq_index;
+
+  void VisitAttrs(AttrVisitor* v) {
+    // `stmt` is not visited
+    // `parent` is not visited
+    v->Visit("seq_index", &seq_index);
+  }
+
+  static constexpr const char* _type_key = "tir.StmtSRef";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object);
+
+  /*! \brief Reset the object inplace to the invalid state */
+  void Reset() {
+    this->stmt = nullptr;
+    this->parent = nullptr;
+    this->seq_index = -1;
+  }
+
+  /*!
+   * \brief Get the referenced statement with proper type checking.
+   * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt`
+   * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably
+   * tvm::tir::BlockNode or tvm::tir::ForNode
+   * \return nullptr if type check fails, otherwise the casted result for `this->stmt`
+   */
+  template <typename StmtType>
+  const StmtType* StmtAs() const {
+    if (stmt != nullptr && stmt->IsInstance<StmtType>()) {
+      return static_cast<const StmtType*>(stmt);
+    } else {
+      return nullptr;
+    }
+  }
+};
+
+/*!
+ * \brief Managed reference to StmtSRefNode
+ * \sa StmtSRefNode
+ */
+class StmtSRef : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor
+   * \param stmt The corresponding stmt node, can be either block or for loop.
+   * \param parent The parent sref.
+   * \param seq_index The location in an array if the parent of the stmt contains multiple children.
+   * -1 if the parent does not contain multiple children.
+   */
+  TVM_DLL explicit StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index);
+
+  /*! \return The mutable pointer to the StmtSRefNode */
+  StmtSRefNode* get() const { return static_cast<StmtSRefNode*>(data_.get()); }
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(StmtSRef, ObjectRef, StmtSRefNode);
+
+ public:
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do the work of compute-inline
+   * \note This is only as a faked loop sref for compute-at and reverse-compute-at,
+   * i.e.
+   *
+   * compute-at(block, loop_sref):
+   *   compute-inline(block)                if loop_sref.same_as(InlineMark())
+   *   no-op                                if loop_sref.same_as(RootMark())
+   *   compute-at-impl(block, loop_sref)    otherwise
+   */
+  TVM_DLL static StmtSRef InlineMark();
+  /*!
+   * \return A special StmtSRef, which doesn't point to any stmt in the AST,
+   * only serving as a "mark" to hint compute-at to do nothing
+   * \note This is only as a faked loop sref for compute-at and reverse-compute-at,
+   * i.e.
+   *
+   * compute-at(block, loop_sref):
+   *   compute-inline(block)                if loop_sref.same_as(InlineMark())
+   *   no-op                                if loop_sref.same_as(RootMark())
+   *   compute-at-impl(block, loop_sref)    otherwise
+   */
+  TVM_DLL static StmtSRef RootMark();
+};
+
+/*!
+ * \brief Type of dependency. Right now we have 4 types of dependencies
+ * 1) Read-after-write (kRAW)
+ * 2) Write-after-write (kWAW)
+ * 3) Write-after-read (kWAR)
+ * 4) Opaque dependency (kOpaque)
+ */
+enum class DepKind : int32_t {
+  kRAW = 0,
+  kWAW = 1,
+  kWAR = 2,
+  kOpaque = 3,
+};
+
+/*!
+ * \brief A tuple (src, dst, kind) representing certain types of dependency.
+ * For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is
+ * read-after-write, which means block B reads the result written by block A.
+ */
+class DependencyNode : public Object {
+ public:
+  /*! \brief The source of the dependency relation */
+  StmtSRef src;
+  /*! \brief The destination of the dependency relation */
+  StmtSRef dst;
+  /*! \brief The dependency kind */
+  DepKind kind;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("src", &src);
+    v->Visit("dst", &dst);
+    v->Visit("kind", &kind);
+  }
+
+  static constexpr const char* _type_key = "tir.Dependency";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object);
+};
+
+/*!
+ * \brief Managed reference to DependencyNode
+ * \sa DependencyNode
+ */
+class Dependency : public ObjectRef {
+ public:
+  /*! \brief Constructor */
+  TVM_DLL explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode);
+};
+
+/*!
+ * \brief An object with 1-to-1 correspondence with each block reference in the sref tree.
+ * This data structure is used to track the producer-consumer dependencies between blocks.
+ * For example even leaf nodes have a scope node, even though they have no dependencies.
+ *
+ * Glossary:
+ * - Block scope: A contiguous subtree of the sref tree, rooted at each block sref,
+ * whose components are:
+ *   - scope root: a block sref
+ *   - internal srefs: loop srefs
+ *   - scope leaves: block srefs
+ * - Child block: The scope leaf blocks under the scope root or a specific internal sref
+ */
+class BlockScopeNode : public Object {
+ public:
+  /*!
+   * \brief Lookup table for the `src` of dependencies
+   * \note We intentionally didn't use tvm::Map as the data structure, because we need the values
+   * inside to be mutable so that they could be further maintained properly during transformations.
+   */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> src2deps;
+  /*! \brief Lookup table for the `dst` of dependencies */
+  std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> dst2deps;
+  /*! \brief The mapping from the buffer to the blocks who write it */
+  std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;
+  /*!
+   * \brief This property indicates that the block scope (rooted at its corresponding block) is
+   * equivalent to a stage pipeline under the following conditions:
+   *
+   * 1) The region cover property holds for each of its child blocks
+   * 2) No write-after-read dependency
+   */
+  bool stage_pipeline{false};
+
+  void VisitAttrs(AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "tir.BlockScope";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BlockScopeNode, Object);
+
+ public:
+  /******** Dependency ********/
+  /*!
+   * \brief Get all dependencies whose `src` equals `src`
+   * \param src The queried block
+   * \return The dependencies
+   */
+  TVM_DLL Array<Dependency> GetDepsBySrc(const StmtSRef& src) const;
+  /*!
+   * \brief Get all dependencies whose `dst` equals `dst`
+   * \param dst The queried block
+   * \return The dependencies
+   */
+  TVM_DLL Array<Dependency> GetDepsByDst(const StmtSRef& dst) const;
+};
+
+/*!
+ * \brief Managed reference to BlockScopeNode
+ * \sa BlockScopeNode
+ */
+class BlockScope : public ObjectRef {
+ public:
+  /*! \brief The constructor creating an empty block scope with no dependency information */
+  TVM_DLL BlockScope();
+  /*!
+   * \brief Create the object with the specific leaf blocks, and compute the dependency information
+   * between the leaf blocks.
+   * \param child_block_srefs The srefs to the leaf blocks
+   * \note We assume the leaf blocks are given in pre-DFS order
+   */
+  TVM_DLL BlockScope(const Array<StmtSRef>& child_block_srefs);
+
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockScope, ObjectRef, BlockScopeNode);
+};
+
+}  // namespace tir
+}  // namespace tvm
+
+#endif  // TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h
new file mode 100644
index 0000000..12b6fc1
--- /dev/null
+++ b/include/tvm/tir/schedule/state.h
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/tir/schedule/state.h
+ * \brief This file defines ScheduleState, the core data structure of TensorIR scheduling.
+ */
+#ifndef TVM_TIR_SCHEDULE_STATE_H_
+#define TVM_TIR_SCHEDULE_STATE_H_
+
+#include <tvm/ir/module.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/schedule/block_scope.h>
+
+#include <unordered_map>
+#include <utility>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief The information about a TensorIR block, it contains two categories of information
+ * 1) Info on the block scope rooted at a specific block, including dependency tracking,
+ * flags indicating if the scope is a stage pipeline, etc.
+ * 2) Info on the block itself, including if the block has a quasi-affine binding, if the regions it
+ * reads are completely covered by their producers, etc.
+ */
+struct BlockInfo {
+  /*! \brief Property of a block scope rooted at the block, storing dependencies in the scope */
+  BlockScope scope{nullptr};
+  // The properties below are information about the current block realization under its parent scope
+  /*! \brief Property of a block, indicating the block realization binding is quasi-affine */
+  bool affine_binding{false};
+  /*!
+   * \brief Property of a block, indicating each of the block's read regions is fully
+   * produced by its producers
+   */
+  bool region_cover{false};
+
+  BlockInfo() = default;
+
+  explicit BlockInfo(BlockScope scope, bool affine_binding = false, bool region_cover = false)
+      : scope(std::move(scope)),         //
+        affine_binding(affine_binding),  //
+        region_cover(region_cover) {}
+};
+
+/*!
+ * \brief The bitmask of the debug flag in the ScheduleStateNode.
+ * \sa ScheduleStateNode
+ */
+enum class ScheduleDebugMask : int32_t {
+  /*! \brief Verify the correctness of the sref tree */
+  kVerifySRefTree = 1,
+  /*! \brief Verify the correctness of affine_binding */
+  kVerifyAffineBinding = 2,
+  /*! \brief Verify the correctness of region_cover */
+  kVerifyRegionCover = 4,
+  /*! \brief Verify the correctness of stage_pipeline */
+  kVerifyStagePipeline = 8,
+};
+
+/*!
+ * \brief The state of scheduling, which exposes a `Replace` method as
+ * the primary interface for all the scheduling primitives to manipulate the TensorIR.
+ *
+ * The data structure contains the following information
+ * 1) The AST being scheduled (mod)
+ * 2) The sref tree of schedulable statements (indicated by the srefs)
+ * 3) The dependency information of each block scope (block_info)
+ * 4) A reverse mapping from the AST nodes to that in the sref tree (stmt2ref)
+ * 5) A debug flag, if set, extra checking is enabled (debug_mode)
+ */
+class ScheduleStateNode : public Object {
+ public:
+  /*! \brief The AST of the module being scheduled */
+  IRModule mod;
+  /*!
+   * \brief Mapping from a block sref to its corresponding BlockInfo,
+   * tracking the dependency inside the block scope,
+   * and storing necessary information flags for scheduling
+   */
+  std::unordered_map<StmtSRef, BlockInfo, ObjectPtrHash, ObjectPtrEqual> block_info;
+  /*! \brief The reverse mapping from block/for-loop to their corresponding srefs */
+  std::unordered_map<const StmtNode*, StmtSRef> stmt2ref;
+  /*!
+   * \brief Do extra correctness checking after the class creation
+   * and each time after calling the Replace method.
+   * \sa ScheduleDebugMask
+   */
+  int debug_mode;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("mod", &mod);
+    // `block_info` is not visited
+    // `stmt2ref` is not visited
+    v->Visit("debug_mode", &debug_mode);
+  }
+  /*!
+   * \brief Replace the part of the AST, as being pointed to by `src_sref`,
+   * with a specific statement `tgt_stmt`, and maintain the sref tree accordingly.
+   * Replace will try to perform copy on write as much as possible when the ScheduleState holds
+   * the only copy to the IRModule and IR nodes.
+   *
+   * Only 3 types of replacements are allowed: from `src_sref->stmt` to `tgt_stmt`.
+   * 1) Block -> Block
+   * 2) Loop -> Loop
+   * 3) Loop -> BlockRealize
+   *
+   * \param src_sref The sref to the statement to be replaced
+   * \param tgt_stmt The new statement that replaces `src_sref->stmt`
+   * \param block_sref_reuse Maps an old block (to be replaced in the subtree under
+   * `src_sref->stmt`) to a new block (replaced to, in the subtree under `tgt_stmt`), and enforces
+   * reuse of srefs between them (rather than create new srefs) i.e. after being replaced, the sref
+   * that points to the old block will point to the new one
+   * \note The reuse of loop srefs is detected automatically according to the reuse of loop vars.
+   */
+  TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt,
+                       const Map<Block, Block>& block_sref_reuse);
+  /*!
+   * \brief Trigger the verification according to the `debug_mode` bitmask.
+   * 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree.
+   * 2) If the bitmask `kVerifyAffineBinding` is on, verify the correctness of `affine_binding`
+   * 3) If the bitmask `kVerifyRegionCover` is on, verify the correctness of `region_cover`
+   * 4) If the bitmask `kVerifyStagePipeline` is on, verify the correctness of `stage_pipeline`
+   */
+  TVM_DLL void DebugVerify() const;
+
+  static constexpr const char* _type_key = "tir.ScheduleState";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleStateNode, Object);
+
+  /******** Property of blocks ********/
+  /*! \brief Returns the BlockInfo corresponding to the block sref */
+  TVM_DLL BlockInfo GetBlockInfo(const StmtSRef& block_sref) const;
+  /*!
+   * \brief Get the BlockScope corresponding to the sref of scope root block
+   * \param scope_root The block sref to be retrieved
+   * \return The corresponding BlockScope
+   */
+  BlockScope GetBlockScope(const StmtSRef& scope_root) const {
+    return GetBlockInfo(scope_root).scope;
+  }
+  /*!
+   * \brief Check a cached flag indicating if the specific block has quasi-affine bindings
+   * \param block_sref The block sref to be checked
+   * \return A boolean flag indicating if the block has quasi-affine bindings
+   */
+  bool IsAffineBlockBinding(const StmtSRef& block_sref) const {
+    return GetBlockInfo(block_sref).affine_binding;
+  }
+  /*!
+   * \brief Check a cached flag indicating if each of the specific consumer block's read region
+   * is fully produced by its producers
+   * \param consumer_block_sref The specific consumer block
+   * \return A boolean flag indicating if the read regions are fully covered by the producers
+   */
+  bool IsRegionCoveredConsumer(const StmtSRef& consumer_block_sref) const {
+    return GetBlockInfo(consumer_block_sref).region_cover;
+  }
+  /*!
+   * \brief Check a cached flag indicating if a block scope is equivalent to a stage pipeline
+   * \param scope_root The sref of the scope root block to be checked
+   * \return A boolean flag indicating if the block scope is equivalent to a stage pipeline
+   */
+  bool IsStagePipeline(const StmtSRef& scope_root) const {
+    return GetBlockScope(scope_root)->stage_pipeline;
+  }
+};
+
+/*!
+ * \brief Managed reference to ScheduleStateNode
+ * \sa ScheduleStateNode
+ */
+class ScheduleState : public ObjectRef {
+ public:
+  /*!
+   * \brief Construct a schedule state from an IRModule
+   * \param mod The IRModule to be scheduled
+   * \param debug_mode Do extra correctness checking after the class creation
+   * and each time after calling the Replace method.
+   */
+  TVM_DLL explicit ScheduleState(IRModule mod, int debug_mode = 0);
+  /*!
+   * \brief Construct a schedule state from a PrimFunc
+   * \param func The PrimFunc to be scheduled. A new IRModule will be created with
+   * this specific PrimFunc as "main" function in the module to be scheduled
+   * \param debug_mode Do extra correctness checking after the class creation
+   * and each time after calling the Replace method.
+   */
+  TVM_DLL explicit ScheduleState(PrimFunc func, int debug_mode = 0);
+
+  /*! \return The mutable pointer to the ScheduleStateNode */
+  ScheduleStateNode* get() const { return static_cast<ScheduleStateNode*>(data_.get()); }
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleState, ObjectRef, ScheduleStateNode);
+};
+
+}  // namespace tir
+}  // namespace tvm
+
+#endif  // TVM_TIR_SCHEDULE_STATE_H_
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index ad91eab..681fc31 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -48,6 +48,9 @@ from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
 from .op import comm_reducer, min, max, sum
 from .op import q_multiply_shift
 
+from .schedule import StmtSRef, BlockScope, ScheduleState
+
+from . import schedule
 from . import ir_builder
 from . import transform
 from . import analysis
diff --git a/python/tvm/tir/schedule/__init__.py b/python/tvm/tir/schedule/__init__.py
new file mode 100644
index 0000000..21721f7
--- /dev/null
+++ b/python/tvm/tir/schedule/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+"""Namespace for the TensorIR schedule API."""
+
+from .block_scope import BlockScope, Dependency, DepKind, StmtSRef
+from .state import ScheduleDebugMask, ScheduleState
diff --git a/python/tvm/tir/schedule/_ffi_api_schedule.py b/python/tvm/tir/schedule/_ffi_api_schedule.py
new file mode 100644
index 0000000..ae8bdfd
--- /dev/null
+++ b/python/tvm/tir/schedule/_ffi_api_schedule.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI APIs for tvm.tir.schedule"""
+import tvm._ffi
+
+tvm._ffi._init_api("tir.schedule", __name__)  # pylint: disable=protected-access
diff --git a/python/tvm/tir/schedule/block_scope.py b/python/tvm/tir/schedule/block_scope.py
new file mode 100644
index 0000000..8281452
--- /dev/null
+++ b/python/tvm/tir/schedule/block_scope.py
@@ -0,0 +1,152 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Definition of two pillar data structures for TensorIR scheduling: StmtSRef, BlockScope."""
+from enum import IntEnum
+from typing import List, Optional, Union
+
+from tvm._ffi import register_object
+from tvm.runtime import Object
+from tvm.tir import Block, For
+
+from . import _ffi_api_schedule
+
+
+@register_object("tir.StmtSRef")
+class StmtSRef(Object):
+    """An object that refers to schedulable elements in the TensorIR, aka "sref".
+
+    Glossary
+    - Block sref: A StmtSRef that points to a TensorIR block.
+    - Loop sref: A StmtSRef that points to a TensorIR for loop.
+    - Parent sref: The parent sref of an sref is the block/loop sref that points to its closest
+    schedulable statement of its ancestors on the TensorIR AST.
+    - Root sref: Sref to the root block. Every sref has exactly one parent sref
+    except for root sref.
+    - Sref tree: The parent-children-relationship of srefs that forms a tree,
+    uniquely determined by the TensorIR AST.
+    """
+
+    seq_index: int
+
+    @property
+    def stmt(self) -> Optional[Union[Block, For]]:
+        """The block/for stmt the object refers to"""
+        return _ffi_api_schedule.StmtSRefStmt(self)  # pylint: disable=no-member
+
+    @property
+    def parent(self) -> Optional["StmtSRef"]:
+        """The parent sref"""
+        return _ffi_api_schedule.StmtSRefParent(self)  # pylint: disable=no-member
+
+    @staticmethod
+    def inline_mark() -> "StmtSRef":
+        """A special StmtSRef, which doesn't point to any stmt in the AST,
+        only serving as a "mark" to hint compute-at to do the work of compute-inline"""
+        return _ffi_api_schedule.StmtSRefInlineMark()  # pylint: disable=no-member
+
+    @staticmethod
+    def root_mark() -> "StmtSRef":
+        """A special StmtSRef, which doesn't point to any stmt in the AST,
+        only serving as a "mark" to hint compute-at to do nothing"""
+        return _ffi_api_schedule.StmtSRefRootMark()  # pylint: disable=no-member
+
+
+class DepKind(IntEnum):
+    """Type of dependency.
+
+    Attributes
+    ----------
+    RAW : int = 0
+        Read-after-write dependency
+    WAW : int = 1
+        Write-after-write dependency
+    WAR : int = 2
+        Write-after-read dependency. Not supported in TensorIR for now.
+    OPAQUE: int = 3
+        Opaque dependency
+    """
+
+    RAW = 0
+    WAW = 1
+    WAR = 2
+    OPAQUE = 3
+
+
+@register_object("tir.Dependency")
+class Dependency(Object):
+    """A tuple (src, dst, kind) representing certain types of dependency.
+    For example, (A, B, kRAW) means block B depends on block A, and the dependency kind is
+    read-after-write, which means block B reads the result written by block A.
+
+    Parameters
+    ----------
+    src : StmtSRef
+        The source of the dependency relation
+    dst : StmtSRef
+        The destination of the dependency relation
+    kind : DepKind
+        The dependency kind
+    """
+
+    src: StmtSRef
+    dst: StmtSRef
+    kind: DepKind
+
+
+@register_object("tir.BlockScope")
+class BlockScope(Object):
+    """An object corresponding to each block sref in the sref tree,
+       which tracks the producer-consumer dependency between blocks.
+
+    Glossary:
+    - Block scope: A contiguous subtree of the sref tree, rooted at each block sref,
+    whose components are:
+        - scope root: a block sref
+        - internal srefs: loop srefs
+        - scope leaves: block srefs
+    - Child block: The scope leaf blocks under the scope root or a specific internal sref
+    """
+
+    def get_deps_by_src(self, block: StmtSRef) -> List[Dependency]:
+        """Get all dependencies whose `src` is the target `block`.
+
+        Parameters
+        ----------
+        block: StmtSRef
+            The queried block
+
+        Returns
+        -------
+        blocks: List[Dependency]
+            The dependencies
+        """
+        return _ffi_api_schedule.BlockScopeGetDepsBySrc(self, block)  # pylint: disable=no-member
+
+    def get_deps_by_dst(self, block: StmtSRef) -> List[Dependency]:
+        """Get all dependencies whose `dst` is the target `block`.
+
+        Parameters
+        ----------
+        block: StmtSRef
+            The queried block
+
+        Returns
+        -------
+        blocks: List[Dependency]
+            The dependencies
+        """
+        return _ffi_api_schedule.BlockScopeGetDepsByDst(self, block)  # pylint: disable=no-member
diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py
new file mode 100644
index 0000000..180fede
--- /dev/null
+++ b/python/tvm/tir/schedule/state.py
@@ -0,0 +1,185 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This file defines ScheduleState, the core data structure of TensorIR scheduling."""
+from enum import IntEnum
+from typing import Dict, Optional, Union
+
+from tvm._ffi import register_object
+from tvm.ir import IRModule
+from tvm.runtime import Object
+from tvm.tir import Block, BlockRealize, For, PrimFunc
+
+from . import _ffi_api_schedule
+from .block_scope import BlockScope, StmtSRef
+
+
+class ScheduleDebugMask(IntEnum):
+    """The bitmask of the `debug_mode` flag in the ScheduleState class.
+
+    If the `debug_mode` flag has a certain bit on, then the corresponding
+    verification pass will be conducted. For example, if `(debug_mode & VERIFY_SREF_TREE) != 0`,
+    then the correctness of the sref tree will be verified after each schedule instruction.
+
+    Attributes
+    ----------
+    VERIFY_SREF_TREE : int = 1
+        Verify the correctness of the sref tree
+    VERIFY_AFFINE_BINDING : int = 2
+        Verify the correctness of affine_binding
+    VERIFY_REGION_COVER : int = 4
+        Verify the correctness of region_cover
+    VERIFY_STAGE_PIPELINE: int = 8
+        Verify the correctness of stage_pipeline
+    """
+
+    VERIFY_SREF_TREE = 1
+    VERIFY_AFFINE_BINDING = 2
+    VERIFY_REGION_COVER = 4
+    VERIFY_STAGE_PIPELINE = 8
+
+
+@register_object("tir.ScheduleState")
+class ScheduleState(Object):
+    """The state of scheduling, which exposes a `Replace` method as
+    the primary means for all the scheduling primitives to manipulate the TensorIR.
+
+    The data structure contains the following information
+    1) The AST being scheduled (mod)
+    2) The sref tree of schedulable statements (indicated by the srefs)
+    3) The dependency information of each block scope (block_info)
+    4) A reverse mapping from the AST nodes to their srefs in the sref tree (get_sref)
+    5) A debug flag, if set, extra checking is enabled (debug_mode)
+
+    Parameters
+    ----------
+    mod : IRModule
+        The AST of the module being scheduled
+    debug_mode : int
+        Do extra correctness checking after the object construction
+        and each time after calling the Replace method.
+    """
+
+    mod: IRModule
+    debug_mode: int
+
+    def __init__(
+        self,
+        func_or_mod: Union[PrimFunc, IRModule],
+        debug_mode: Union[bool, int] = False,
+    ):
+        """Construct a schedule state from an IRModule or a PrimFunc
+
+        Parameters
+        ----------
+        func_or_mod : Union[PrimFunc, IRModule]
+            The IRModule or PrimFunc to be scheduled
+        debug_mode : Union[bool, int]
+            Do extra correctness checking after the class creation and each time
+            after calling the Replace method.
+            Possible choices of `debug_mode`:
+            1) True - Turn on all the checks
+            2) False - Turn off all the checks
+            3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask
+        """
+        if isinstance(debug_mode, bool):
+            if debug_mode:
+                debug_mode = -1
+            else:
+                debug_mode = 0
+        if not isinstance(debug_mode, int):
+            raise TypeError(f"`debug_mode` should be integer or boolean, but gets: {debug_mode}")
+        self.__init_handle_by_constructor__(
+            _ffi_api_schedule.ScheduleState,  # pylint: disable=no-member
+            func_or_mod,
+            debug_mode,
+        )
+
+    def get_sref(self, stmt: Union[Block, For]) -> Optional[StmtSRef]:
+        """Return the corresponding sref that points to the stmt
+
+        Parameters
+        ----------
+        stmt : Union[Block, For]
+            The schedulable statement in the TensorIR to be retrieved for its sref
+
+        Returns
+        -------
+        sref : StmtSRef
+            The corresponding sref
+        """
+        return _ffi_api_schedule.ScheduleStateGetSRef(self, stmt)  # pylint: disable=no-member
+
+    def get_block_scope(self, block_sref: StmtSRef) -> BlockScope:
+        """Get the BlockScope corresponding to the block sref
+
+        Parameters
+        ----------
+        block_sref : StmtSRef
+            The block sref to be retrieved
+
+        Returns
+        -------
+        scope : BlockScope
+            The corresponding BlockScope
+        """
+        return _ffi_api_schedule.ScheduleStateGetBlockScope(  # pylint: disable=no-member
+            self, block_sref
+        )
+
+    def replace(
+        self,
+        src_sref: StmtSRef,
+        tgt_stmt: Union[Block, For, BlockRealize],
+        block_sref_reuse: Optional[Dict[Block, Block]] = None,
+    ) -> None:
+        """
+        Replace the part of the AST, as being pointed to by `src_sref`,
+        with a specific statement `tgt_stmt`, and maintain the sref tree accordingly.
+        Replace will try to perform copy on write as much as possible when the ScheduleState holds
+        the only copy to the IRModule and IR nodes.
+
+        Only 3 types of replacements are allowed: from `src_sref->stmt` to `tgt_stmt`.
+        1) Block -> Block
+        2) Loop -> Loop
+        3) Loop -> BlockRealize
+
+        Parameters
+        ----------
+        src_sref : StmtSRef
+            The sref to the statement to be replaced in the TensorIR AST
+
+        tgt_stmt : Union[Block, For, BlockRealize]
+            The statement that replaces the one pointed to by `src_sref`
+
+        block_sref_reuse : Optional[Dict[Block, Block]] = None
+            Maps an old block (to be replaced in the subtree under `src_sref->stmt`)
+            to a new block (replaced to, in the subtree under `tgt_stmt`), and enforces
+            reuse of srefs between them (rather than create new srefs) i.e. after being replaced,
+            the sref that points to the old block will point to the new one
+
+        Note
+        ----
+        The reuse of loop srefs is detected automatically according to the reuse of loop vars.
+        """
+        if block_sref_reuse is None:
+            block_sref_reuse = {}
+        _ffi_api_schedule.ScheduleStateReplace(  # pylint: disable=no-member
+            self,
+            src_sref,
+            tgt_stmt,
+            block_sref_reuse,
+        )
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index 4746206..46f456c 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -26,12 +26,13 @@ Each statement node have subfields that can be visited from python side.
     assert isinstance(st, tvm.tir.stmt.Store)
     assert(st.buffer_var == a)
 """
-from typing import List, Optional, Mapping
 from enum import IntEnum
+from typing import List, Mapping, Optional, Union
+
 import tvm._ffi
+from tvm.ir import PrimExpr, Range, Span
+from tvm.runtime import Object, const
 
-from tvm.runtime import Object
-from tvm.ir import Span, PrimExpr, Range
 from . import _ffi_api
 from .buffer import Buffer
 from .expr import IterVar
@@ -589,7 +590,7 @@ class BlockRealize(Stmt):
     iter_values : List[PrimExpr]
         The binding values of the block var.
 
-    predicate : PrimExpr
+    predicate : Union[PrimExpr, bool]
         The predicate of the block.
 
     block : Block
@@ -607,12 +608,18 @@ class BlockRealize(Stmt):
     def __init__(
         self,
         iter_values: List[PrimExpr],
-        predicate: PrimExpr,
+        predicate: Union[PrimExpr, bool],
         block: Block,
         span: Optional[Span] = None,
     ):
+        if isinstance(predicate, bool):
+            predicate = const(predicate, "bool")
         self.__init_handle_by_constructor__(
-            _ffi_api.BlockRealize, iter_values, predicate, block, span
+            _ffi_api.BlockRealize,
+            iter_values,
+            predicate,
+            block,
+            span,
         )
 
 
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index 4380795..7afdcab 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -683,6 +683,7 @@ Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) {
   doc << "if " << Print(op->condition) << ":";
   doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->then_case));
   if (!is_one(op->condition) && op->else_case.defined()) {
+    doc << Doc::NewLine();
     doc << "else:" << Doc::Indent(4, Doc::NewLine() << PrintBody(op->else_case));
   }
   return doc;
diff --git a/src/tir/analysis/var_touch.cc b/src/tir/analysis/var_touch.cc
index 2a23329..40a2cce 100644
--- a/src/tir/analysis/var_touch.cc
+++ b/src/tir/analysis/var_touch.cc
@@ -18,7 +18,7 @@
  */
 
 /*!
- * \file simple_analysis.cc
+ * \file var_touch.cc
  * \brief Implementation of simple passes
  */
 #include <tvm/tir/analysis.h>
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 2aeaae3..87ead3e 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -221,7 +221,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<WhileNode>([](const ObjectRef& node, ReprPrinter* p) {
       auto* op = static_cast<const WhileNode*>(node.get());
       p->PrintIndent();
-      p->stream << "while(" << op->condition << "){\n";
+      p->stream << "while(" << op->condition << ") {\n";
       p->indent += 2;
       p->Print(op->body);
       p->indent -= 2;
@@ -781,7 +781,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       auto* op = static_cast<const BlockNode*>(node.get());
       p->PrintIndent();
       PrintBlockTitle(op, p);
-      p->stream << "{\n";
+      p->stream << " {\n";
       p->indent += 2;
 
       // Print block elements (e.g. reads/writes, etc)
@@ -820,7 +820,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       auto* block_op = op->block.get();
       p->PrintIndent();
       PrintBlockTitle(block_op, p);
-      p->stream << "{\n";
+      p->stream << " {\n";
       p->indent += 2;
 
       // Print binding iter_values
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
new file mode 100644
index 0000000..32d9f6d
--- /dev/null
+++ b/src/tir/schedule/analysis.h
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_SCHEDULE_ANALYSIS_H_
+#define TVM_TIR_SCHEDULE_ANALYSIS_H_
+
+#include <tvm/tir/schedule/state.h>
+
+namespace tvm {
+namespace tir {
+
+/******** Verification ********/
+/*!
+ * \brief Verify the sref tree state is consistent with the IR
+ * \param self The schedule state containing the sref to be verified
+ * \throw An exception will be thrown if the sref tree is not valid
+ */
+void VerifySRefTree(const ScheduleState& self);
+
+/******** Block-loop relation ********/
+/*!
+ * \brief Get the child blocks, i.e. the scope leaf blocks under a specific block/loop
+ * \param self The schedule state
+ * \param parent_sref The StmtSRef that points to the parent block/loop
+ * \return A list of leaf blocks
+ */
+Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref);
+
+}  // namespace tir
+}  // namespace tvm
+
+#endif  // TVM_TIR_SCHEDULE_ANALYSIS_H_
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
new file mode 100644
index 0000000..005ff37
--- /dev/null
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/******** Block-loop relation ********/
+
+Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref) {
+  struct Collector : public StmtVisitor {
+   public:
+    static Array<StmtSRef> Collect(const ScheduleState& self, const Stmt& stmt) {
+      Collector collector(self);
+      collector(stmt);
+      return std::move(collector.result_);
+    }
+
+   private:
+    explicit Collector(const ScheduleState& self) : self_(self) {}
+
+    void VisitStmt_(const BlockNode* block) final {
+      auto it = self_->stmt2ref.find(block);
+      ICHECK(it != self_->stmt2ref.end());
+      result_.push_back(it->second);
+    }
+
+    const ScheduleState& self_;
+    Array<StmtSRef> result_;
+  };
+
+  if (parent_sref->stmt->IsInstance<ForNode>()) {
+    const auto* loop = static_cast<const ForNode*>(parent_sref->stmt);
+    return Collector::Collect(self, loop->body);
+  } else if (parent_sref->stmt->IsInstance<BlockNode>()) {
+    const auto* block = static_cast<const BlockNode*>(parent_sref->stmt);
+    return Collector::Collect(self, block->body);
+  }
+  ICHECK(false) << "Unreachable";
+  throw;
+}
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/schedule/analysis/verify.cc b/src/tir/schedule/analysis/verify.cc
new file mode 100644
index 0000000..edb62b5
--- /dev/null
+++ b/src/tir/schedule/analysis/verify.cc
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+class SRefTreeVerifier : public StmtVisitor {
+ public:
+  static void Verify(const ScheduleStateNode* self) { SRefTreeVerifier(self).Verify(); }
+
+ private:
+  /*! \brief Constructor */
+  explicit SRefTreeVerifier(const ScheduleStateNode* self) : self_(self) {}
+
+  void Verify() {
+    VisitPrimFuncs(self_->mod, [this](const PrimFuncNode* func) { this->VisitStmt(func->body); });
+    ICHECK_EQ(n_sref_visited_, static_cast<int>(self_->stmt2ref.size()));
+    for (const auto& kv : self_->block_info) {
+      const StmtSRef& sref = kv.first;
+      ICHECK(sref->stmt != nullptr)
+          << "InternalError: An expired sref is found in the block_scope mapping";
+      auto it = self_->stmt2ref.find(sref->stmt);
+      ICHECK(it != self_->stmt2ref.end())
+          << "InternalError: The sref points to a statement that does not exist in stmt2ref";
+      const StmtSRef& sref2 = it->second;
+      ICHECK(sref.same_as(sref2))
+          << "InternalError: The sref points to a statement whose corresponding sref in stmt2ref "
+             "is not the same object as itself";
+    }
+    ICHECK_EQ(n_block_sref_visited_, static_cast<int>(self_->block_info.size()));
+  }
+
+  void VisitStmt_(const BlockNode* block) final {
+    if (init_block_depth_) {
+      ICHECK(!self_->stmt2ref.count(block)) << "InternalError: A block inside init block has its "
+                                               "corresponding sref, which is not allowed";
+      StmtVisitor::VisitStmt_(block);
+      return;
+    }
+    ICHECK(self_->stmt2ref.count(block))
+        << "InternalError: A BlockNode should appear in sref map, but it didn't\n"
+        << GetRef<Stmt>(block);
+    ++n_sref_visited_;
+    ++n_block_sref_visited_;
+    const StmtSRef& sref = self_->stmt2ref.at(block);
+    ICHECK(self_->block_info.count(sref))
+        << "InternalError: Cannot find scope information of the BlockNode:\n"
+        << GetRef<Stmt>(block);
+    ICHECK(sref->parent == ancestors_.back())
+        << "InternalError: Parent information mismatch for BlockNode:\n"
+        << GetRef<Stmt>(block) << "\nIts parent is supposed to be:\n"
+        << GetRef<Stmt>(ancestors_.back()->stmt) << "\nHowever, its parent is incorrect and is:\n"
+        << (sref->parent ? Optional<Stmt>(GetRef<Stmt>(sref->parent->stmt))
+                         : Optional<Stmt>(NullOpt));
+    ancestors_.push_back(sref.operator->());
+    if (block->init.defined()) {
+      ++init_block_depth_;
+      VisitStmt(block->init.value());
+      --init_block_depth_;
+    }
+    VisitStmt(block->body);
+    ancestors_.pop_back();
+  }
+
+  void VisitStmt_(const ForNode* loop) final {
+    if (init_block_depth_) {
+      ICHECK(!self_->stmt2ref.count(loop)) << "InternalError: A loop inside init block has its "
+                                              "corresponding sref, which is not allowed";
+      StmtVisitor::VisitStmt_(loop);
+      return;
+    }
+    ICHECK(self_->stmt2ref.count(loop))
+        << "InternalError: A ForNode should appear in sref map, but it didn't\n"
+        << GetRef<Stmt>(loop);
+    ++n_sref_visited_;
+    const StmtSRef& sref = self_->stmt2ref.at(loop);
+    Optional<Stmt> stmt = NullOpt;
+    ICHECK(sref->parent == ancestors_.back())
+        << "InternalError: Parent information mismatch for ForNode:\n"
+        << GetRef<Stmt>(loop) << "\nIts parent is supposed to be:\n"
+        << GetRef<Stmt>(ancestors_.back()->stmt) << "\nHowever, its parent is incorrect and is:\n"
+        << (sref->parent ? Optional<Stmt>(GetRef<Stmt>(sref->parent->stmt))
+                         : Optional<Stmt>(NullOpt));
+    ancestors_.push_back(sref.operator->());
+    StmtVisitor::VisitStmt_(loop);
+    ancestors_.pop_back();
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    // Verify seq_index
+    if (init_block_depth_) {
+      StmtVisitor::VisitStmt_(seq_stmt);
+      return;
+    }
+    int n = static_cast<int>(seq_stmt->seq.size());
+    for (int i = 0; i < n; ++i) {
+      const Stmt& child = seq_stmt->seq[i];
+      StmtSRef sref{nullptr};
+      if (const auto* realize = child.as<BlockRealizeNode>()) {
+        const auto* block = realize->block.get();
+        ICHECK(self_->stmt2ref.count(block));
+        sref = self_->stmt2ref.at(block);
+      } else if (child->IsInstance<ForNode>()) {
+        ICHECK(self_->stmt2ref.count(child.get()));
+        sref = self_->stmt2ref.at(child.get());
+      } else {
+        continue;
+      }
+      ICHECK_EQ(sref->seq_index, i) << "InternalError: A StmtSRef has incorrect seq_index";
+    }
+    StmtVisitor::VisitStmt_(seq_stmt);
+  }
+
+  /*! \brief The schedule it belongs to */
+  const ScheduleStateNode* self_;
+  /*! \brief Parent information during the visit */
+  std::vector<const StmtSRefNode*> ancestors_ = {nullptr};
+  /*! \brief Nesting depth of init blocks being visited; 0 means not inside any init block */
+  int init_block_depth_ = 0;
+  /*! \brief Number of srefs that are visited */
+  int n_sref_visited_ = 0;
+  /*! \brief Number of block srefs that are visited */
+  int n_block_sref_visited_ = 0;
+};
+
+void VerifySRefTree(const ScheduleState& self) { SRefTreeVerifier::Verify(self.get()); }
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/schedule/block_scope.cc b/src/tir/schedule/block_scope.cc
new file mode 100644
index 0000000..f1ce65e
--- /dev/null
+++ b/src/tir/schedule/block_scope.cc
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace tir {
+
+/******** Utility functions ********/
+
+template <class K, class V>
+using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;
+
+/*!
+ * \brief Add a dependency relation.
+ * \param src The source of the dependency
+ * \param dst The destination of the dependency
+ * \param kind Type of the dependency
+ * \note This method is effectively NOP on self-loops
+ */
+void AddDependency(BlockScopeNode* self, const StmtSRef& src, const StmtSRef& dst, DepKind kind) {
+  if (!src.same_as(dst)) {
+    Dependency dep(src, dst, kind);
+    self->src2deps[src].push_back(dep);
+    self->dst2deps[dst].push_back(dep);
+  }
+}
+
+/******** Constructors ********/
+
+StmtSRef::StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index) {
+  ObjectPtr<StmtSRefNode> n = make_object<StmtSRefNode>();
+  n->stmt = stmt;
+  n->parent = parent;
+  n->seq_index = seq_index;
+  data_ = std::move(n);
+}
+
+StmtSRef StmtSRef::InlineMark() {
+  static StmtSRef result(nullptr, nullptr, -1);
+  return result;
+}
+
+StmtSRef StmtSRef::RootMark() {
+  static StmtSRef result(nullptr, nullptr, -1);
+  return result;
+}
+
+Dependency::Dependency(StmtSRef src, StmtSRef dst, DepKind kind) {
+  ObjectPtr<DependencyNode> node = make_object<DependencyNode>();
+  node->src = std::move(src);
+  node->dst = std::move(dst);
+  node->kind = kind;
+  data_ = std::move(node);
+}
+
+BlockScope::BlockScope() { data_ = make_object<BlockScopeNode>(); }
+
+BlockScope::BlockScope(const Array<StmtSRef>& child_block_srefs) {
+  ObjectPtr<BlockScopeNode> n = make_object<BlockScopeNode>();
+  SMap<Buffer, Array<StmtSRef>> buffer_readers;
+  SMap<Buffer, Array<StmtSRef>>& buffer_writers = n->buffer_writers;
+  for (const StmtSRef& child_block_sref : child_block_srefs) {
+    const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block, child_block_sref);
+    // Step 1. Update `buffer_readers` and `buffer_writers` for each buffer
+    for (const BufferRegion& region : child_block->reads) {
+      buffer_readers[region->buffer].push_back(child_block_sref);
+    }
+    for (const BufferRegion& region : child_block->writes) {
+      buffer_writers[region->buffer].push_back(child_block_sref);
+    }
+    // Step 2. Update RAW dependency
+    for (const BufferRegion& region : child_block->reads) {
+      auto it = buffer_writers.find(region->buffer);
+      if (it != buffer_writers.end()) {
+        for (const StmtSRef& from : it->second) {
+          AddDependency(n.get(), from, child_block_sref, DepKind::kRAW);
+        }
+      }
+    }
+    // Step 3. Update WAW dependency
+    for (const BufferRegion& region : child_block->writes) {
+      auto it = buffer_writers.find(region->buffer);
+      if (it != buffer_writers.end()) {
+        for (const StmtSRef& from : it->second) {
+          AddDependency(n.get(), from, child_block_sref, DepKind::kWAW);
+        }
+      }
+    }
+    // Step 4. Update WAR dependency
+    for (const BufferRegion& region : child_block->writes) {
+      auto it = buffer_readers.find(region->buffer);
+      if (it != buffer_readers.end()) {
+        for (const StmtSRef& from : it->second) {
+          AddDependency(n.get(), from, child_block_sref, DepKind::kWAR);
+        }
+      }
+    }
+  }
+  data_ = std::move(n);
+}
+
+/******** Dependency ********/
+
+Array<Dependency> BlockScopeNode::GetDepsBySrc(const StmtSRef& block_sref) const {
+  auto iter = this->src2deps.find(block_sref);
+  if (iter != this->src2deps.end()) {
+    return iter->second;
+  } else {
+    return {};
+  }
+}
+
+Array<Dependency> BlockScopeNode::GetDepsByDst(const StmtSRef& block_sref) const {
+  auto iter = this->dst2deps.find(block_sref);
+  if (iter != this->dst2deps.end()) {
+    return iter->second;
+  } else {
+    return {};
+  }
+}
+
+/******** FFI ********/
+
+TVM_REGISTER_NODE_TYPE(StmtSRefNode);
+TVM_REGISTER_NODE_TYPE(DependencyNode);
+TVM_REGISTER_NODE_TYPE(BlockScopeNode);
+
+TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefStmt")
+    .set_body_typed([](StmtSRef sref) -> Optional<Stmt> {
+      return GetRef<Optional<Stmt>>(sref->stmt);
+    });
+TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefParent")
+    .set_body_typed([](StmtSRef sref) -> Optional<StmtSRef> {
+      return GetRef<Optional<StmtSRef>>(sref->parent);
+    });
+TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefRootMark")  //
+    .set_body_typed(StmtSRef::RootMark);
+TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefInlineMark")  //
+    .set_body_typed(StmtSRef::InlineMark);
+TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetDepsBySrc")
+    .set_body_method<BlockScope>(&BlockScopeNode::GetDepsBySrc);
+TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetDepsByDst")
+    .set_body_method<BlockScope>(&BlockScopeNode::GetDepsByDst);
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
new file mode 100644
index 0000000..d1b899b
--- /dev/null
+++ b/src/tir/schedule/state.cc
@@ -0,0 +1,870 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Shorthand for a `std::unordered_map` whose hasher is `ObjectPtrHash`
+ * and whose key-equality predicate is `ObjectPtrEqual`
+ * \tparam K The key type
+ * \tparam V The mapped type
+ */
+template <class K, class V>
+using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Utility functions ****************/
+
+/*!
+ * \brief Set the `StmtSRefNode::seq_index` field for stmt
+ * \param self The schedule class
+ * \param stmt The statement, or the realize node of the statement whose sref to be set
+ * \param seq_index The seq_index to be set
+ * \note The method is NOP for statements that are not scheduleable, i.e. not For or Block.
+ * For a schedulable statement, it is a fatal error if `self->stmt2ref` has no entry for it.
+ */
+void SetSeqIndex(ScheduleStateNode* self, const Stmt& stmt, int seq_index) {
+  // Resolve the node that keys `self->stmt2ref`:
+  // a BlockRealize is looked up through the block it wraps, a Block/For through itself.
+  const StmtNode* key = nullptr;
+  if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+    key = realize->block.get();
+  } else if (stmt.as<BlockNode>() != nullptr || stmt.as<ForNode>() != nullptr) {
+    key = stmt.get();
+  } else {
+    // Not a schedulable statement: nothing to update
+    return;
+  }
+  // Single lookup instead of `count` followed by `at`
+  auto it = self->stmt2ref.find(key);
+  ICHECK(it != self->stmt2ref.end());
+  it->second->seq_index = seq_index;
+}
+
+/*!
+ * \brief Assign each direct child of a SeqStmt its position as `seq_index`
+ * \param self The schedule class
+ * \param seq_stmt The SeqStmt whose children are numbered, starting from 0
+ */
+void SetSeqIndexInChildren(ScheduleStateNode* self, const SeqStmtNode* seq_stmt) {
+  const int num_children = static_cast<int>(seq_stmt->seq.size());
+  for (int pos = 0; pos < num_children; ++pos) {
+    SetSeqIndex(self, seq_stmt->seq[pos], pos);
+  }
+}
+
+/*!
+ * \brief Update the sref information on the schedule class, as well as the statement of sref itself
+ * More specifically, update
+ *  `sref->stmt` to `new_stmt`
+ *  `self->stmt2ref`, remove the old statement that sref points to, and add the new statement
+ * \param self The schedule class to be updated
+ * \param sref The sref to be updated
+ * \param new_stmt The statement that replaces the statement inside the sref;
+ * must be a Block or a For, and must differ from the statement `sref` currently points to
+ */
+void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new_stmt) {
+  ICHECK(new_stmt->IsInstance<BlockNode>() || new_stmt->IsInstance<ForNode>());
+  const StmtNode* old_stmt = sref->stmt;
+  ICHECK_NE(new_stmt, old_stmt);
+  // Register the sref under its new key before dropping the old entry.
+  // NOTE(review): presumably this order keeps at least one owning reference to `sref` alive in
+  // the map at all times -- confirm before reordering these two statements.
+  self->stmt2ref[new_stmt] = GetRef<StmtSRef>(sref);
+  self->stmt2ref.erase(sref->stmt);
+  sref->stmt = new_stmt;
+}
+
+/*!
+ * \brief Find the PrimFunc (and its GlobalVar) whose body is a BlockRealize of `root_block`
+ * \param mod The IRModule to search
+ * \param root_block The root block of the PrimFunc
+ * \param result_g_var Output: set to the GlobalVar of the matching PrimFunc
+ * \return The PrimFunc that the root block belongs to; a fatal error is raised if none matches
+ * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write
+ */
+const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
+                                    GlobalVar* result_g_var) {
+  for (const auto& entry : mod->functions) {
+    // Only PrimFuncs are candidates
+    const auto* prim_func = entry.second.as<PrimFuncNode>();
+    if (prim_func == nullptr) {
+      continue;
+    }
+    // The function body must be a BlockRealize wrapping exactly `root_block`
+    const auto* root_realize = prim_func->body.as<BlockRealizeNode>();
+    if (root_realize == nullptr || root_realize->block.get() != root_block) {
+      continue;
+    }
+    *result_g_var = entry.first;
+    return prim_func;
+  }
+  LOG(FATAL) << "IndexError: Could not get the correpsonding function in the schedule state of the "
+                "statement:\n"
+             << GetRef<Stmt>(root_block);
+  throw;
+}
+
+/**************** Creation ****************/
+
+/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
+class StateCreator : private StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param mod The IRModule; the body of each of its PrimFuncs is visited
+   * \param debug_mode The value stored into `ScheduleStateNode::debug_mode`
+   * \return The newly created ScheduleStateNode, with `mod`, `debug_mode`, `stmt2ref`
+   * and `block_info` filled in
+   */
+  static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mode) {
+    ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
+    ScheduleStateNode* self = n.get();
+    // Set `n->mod`
+    n->mod = std::move(mod);
+    // Set `n->debug_mode`
+    n->debug_mode = debug_mode;
+    // Set `n->stmt2ref` and `n->block_info`
+    StateCreator creator(self);
+    for (const auto& kv : n->mod->functions) {
+      const BaseFunc& base_func = kv.second;
+      if (const auto* func = base_func.as<PrimFuncNode>()) {
+        creator.VisitStmt(func->body);
+      }
+    }
+    return n;
+  }
+
+ private:
+  explicit StateCreator(ScheduleStateNode* self)
+      : self_(self), srefs_{}, realizes_{}, block_frames_{} {
+    // Bottom frame: it receives the srefs of the outermost blocks once their own frames are popped.
+    // `emplace_back()` is used instead of `emplace({})`, because the latter passes a
+    // value-initialized iterator as the insertion position, which is not a valid iterator
+    // into this vector.
+    block_frames_.emplace_back();
+  }
+
+  /*!
+   * \brief Add a new statement to the stack, which becomes the current scope
+   * \param stmt A for-loop statement or a block statement
+   * \return A sref to the stmt
+   */
+  StmtSRef PushSRef(const StmtNode* stmt) {
+    if (srefs_.empty()) {
+      srefs_.push_back(
+          StmtSRef(stmt,
+                   /*parent=*/nullptr,
+                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      StmtSRefNode* parent = srefs_.back().get();
+      srefs_.push_back(
+          StmtSRef(stmt, parent,
+                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
+    }
+    return srefs_.back();
+  }
+
+  /*!
+   * \brief Pop the top of the scope and record it in stmt2ref map
+   * \return The sref that has just been popped
+   */
+  StmtSRef PopAndRecordSRef() {
+    StmtSRef sref = std::move(srefs_.back());
+    self_->stmt2ref[sref->stmt] = sref;
+    srefs_.pop_back();
+    return sref;
+  }
+
+  /*!
+   * \brief Create the BlockInfo entry of a block from the child blocks in the top frame
+   * \param scope_root The sref to the block whose BlockInfo is created
+   * \note The top of `block_frames_` is moved from; the caller pops it afterwards
+   */
+  void MakeBlockInfo(StmtSRef scope_root) {
+    // Calculate `BlockInfo::scope`
+    Array<StmtSRef> child_block_srefs = std::move(block_frames_.back());
+    BlockInfo& info =
+        self_->block_info.emplace(std::move(scope_root), BlockInfo(BlockScope(child_block_srefs)))
+            .first->second;
+    // TODO(@junrushao1994): calculate the flags
+    // Set `affine_binding`
+    info.affine_binding = false;
+    // Set `region_cover`
+    info.region_cover = false;
+    // Set `stage_pipeline`
+    info.scope->stage_pipeline = false;
+  }
+
+  void VisitStmt_(const ForNode* loop) final {
+    PushSRef(loop);
+    VisitStmt(loop->body);
+    PopAndRecordSRef();
+  }
+
+  void VisitStmt_(const BlockRealizeNode* realize) final {
+    realizes_.push_back(realize);
+    // Open a fresh frame that collects the child blocks of this block
+    block_frames_.emplace_back();
+    const BlockNode* block = realize->block.get();
+    // Recursive visit
+    PushSRef(block);
+    VisitStmt(block->body);  // `block->init` is not visited
+    StmtSRef sref = PopAndRecordSRef();
+    // Create BlockInfo for the block
+    MakeBlockInfo(sref);
+    // Update parent scope
+    block_frames_.pop_back();
+    block_frames_.back().push_back(sref);
+    realizes_.pop_back();
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    // Set `seq_index` information for SeqStmtNode.
+    // The children are visited first so that their srefs exist in `stmt2ref`.
+    StmtVisitor::VisitStmt_(seq_stmt);
+    SetSeqIndexInChildren(self_, seq_stmt);
+  }
+
+  /*! \brief The result ScheduleStateNode */
+  ScheduleStateNode* self_;
+  /*! \brief The stack frame used to indicate the current scope */
+  std::vector<StmtSRef> srefs_;
+  /*! \brief The BlockRealize in the ancestors */
+  std::vector<const BlockRealizeNode*> realizes_;
+  /*! \brief The stack frames of blocks in the DFS visit. */
+  std::vector<Array<StmtSRef>> block_frames_;
+};
+
+/**************** Constructor ****************/
+
+/*!
+ * \brief Construct a schedule state from an IRModule
+ * \param mod The IRModule handed over to `StateCreator::Create`
+ * \param debug_mode The debug mode; values smaller than -1 are rejected
+ */
+ScheduleState::ScheduleState(IRModule mod, int debug_mode) {
+  CHECK_GE(debug_mode, -1) << "ValueError: negative `debug_mode` other than -1 is not supported";
+  // `mod` is a by-value parameter that is not used afterwards, so move it instead of copying
+  data_ = StateCreator::Create(std::move(mod), debug_mode);
+  (*this)->DebugVerify();
+}
+
+/*!
+ * \brief Construct a schedule state from a single PrimFunc
+ * \param func The PrimFunc, which is wrapped into an IRModule under a GlobalVar named "main"
+ * \param debug_mode Forwarded unchanged to the IRModule-based constructor
+ */
+ScheduleState::ScheduleState(PrimFunc func, int debug_mode)
+    : ScheduleState(IRModule({{GlobalVar("main"), func}}), debug_mode) {}
+
+/**************** Replace ****************/
+
+/*
+ * The goal of the replacement algorithm is to substitute a subtree `src_stmt` of the AST with a
+ * new subtree `tgt_stmt`, and maintain the corresponding sref tree accordingly, with some srefs
+ * reused, so that the srefs users hold don't expire. For example, if we split a loop into 2, and
+ * the original loop has a child block, then the sref to the child block should be reused, so that
+ * users won't have to acquire that sref again.
+ *
+ * The workflow of the replacement algorithm is:
+ * 1) Detect all possible reuses in class ReuseInfo
+ * 2) Remove the expired srefs in class SRefTreePruner
+ * 3) Update the reused srefs, and create the srefs for new statements, in class SRefUpdater
+ * 4) Renew the ancestors of `src_stmt` to reflect the replacement
+ */
+
+/*!
+ * \brief Record the different sref reuse types in the replacement
+ *
+ * 1) Intact: the subtree appears as the same object on both `src_stmt` and `tgt_stmt`,
+ * which, given the immutability of the IR, means the entire subtree is unchanged,
+ * and we do not need to recurse into the subtree.
+ *
+ * 2) Loop/Block sref reuse: for two different objects (`src`, `tgt`),
+ * which are both loops or both blocks,
+ * there is a correspondence between them,
+ * which lets us reuse the sref pointing to `src`, and change it to point to `tgt`.
+ *
+ * \note The intact reuse and loop sref reuse are collected in the ReuseCollector,
+ * while the block reuse is specified by the caller.
+ *
+ * \sa ReuseCollector
+ */
+struct ReuseInfo {
+  /*!
+   * \brief Kind 1. Intact reuse. If a stmt is in `intact`, it means its corresponding
+   * sref is reused and it is an intact reuse.
+   */
+  std::unordered_set<const StmtNode*> intact;
+  /*!
+   * \brief Kind 2.1. Loop sref reuse
+   * If the loop var of a loop is in `loop_sref_possible_reuse`,
+   * it means that when `src_stmt` has a loop that uses this loop var,
+   * the reuse kind is loop sref reuse.
+   * \note For each loop var in `loop_sref_possible_reuse`, it is possible that `src_stmt` doesn't
+   * contain a loop that uses this loop var, and that is the reason why it is named "possible".
+   */
+  std::unordered_set<const VarNode*> loop_sref_possible_reuse;
+  /*!
+   * \brief Kind 2.2. Block sref reuse.
+   * Maps an old Block in `src_stmt` to a new block in `tgt_stmt`,
+   * indicating the sref to the old block should be reused as the sref to the new block.
+   */
+  std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse;
+};
+
+/*!
+ * \brief A visitor over `tgt_stmt` that gathers the two kinds of reuse it can detect by itself:
+ *
+ * 1) Intact: a loop or block of `tgt_stmt` that already has an entry in `stmt2ref`.
+ * Given the immutability of the IR, such a subtree is unchanged, so it is not visited further.
+ *
+ * 2) Possible loop sref reuse: the loop var of every loop of `tgt_stmt` that is not intact.
+ *
+ * Block sref reuse is not detected here; it is supplied by the caller.
+ */
+class ReuseCollector : public StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param self The schedule state, whose `stmt2ref` is consulted
+   * \param tgt_stmt The statement to be visited
+   * \return A ReuseInfo with `intact` and `loop_sref_possible_reuse` filled in
+   */
+  static ReuseInfo Collect(const ScheduleStateNode* self, const Stmt& tgt_stmt) {
+    ReuseCollector visitor(self);
+    visitor.VisitStmt(tgt_stmt);
+    ReuseInfo info;
+    info.intact.insert(visitor.intact_.begin(), visitor.intact_.end());
+    info.loop_sref_possible_reuse.insert(visitor.loop_vars_.begin(), visitor.loop_vars_.end());
+    // `info.block_sref_reuse` is left empty on purpose: ReuseCollector doesn't collect it,
+    // and the caller is supposed to set it properly.
+    return info;
+  }
+
+ private:
+  explicit ReuseCollector(const ScheduleStateNode* self) : self_(self) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    if (self_->stmt2ref.count(op)) {
+      // Already known to the state: intact, no need to go deeper
+      intact_.push_back(op);
+      return;
+    }
+    // Remember the loop var so that a loop sref may be reused later
+    loop_vars_.push_back(op->loop_var.get());
+    StmtVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    if (self_->stmt2ref.count(op)) {
+      // Already known to the state: intact, no need to go deeper
+      intact_.push_back(op);
+      return;
+    }
+    StmtVisitor::VisitStmt_(op);
+  }
+
+  /*! \brief The schedule state being inspected */
+  const ScheduleStateNode* self_;
+  /*! \brief The intact statements found so far, in visiting order */
+  std::vector<const StmtNode*> intact_;
+  /*! \brief The loop variables of the non-intact loops found in the tgt_stmt */
+  std::vector<const VarNode*> loop_vars_;
+};
+
+/*!
+ * \brief A helper visitor which removes the stale srefs in the `src_stmt`
+ * that are useless after the replacement.
+ *
+ * It uses the reuse information previously collected to
+ * 1) delete those srefs that are not reused.
+ * 2) return the sref objects that are loop/block sref reuses, but not intact reuses
+ */
+class SRefTreePruner : public StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param self The schedule class
+   * \param reuse_info The reuse info about intact reuse and loop/block reuse
+   * \param src_stmt The `src_stmt` where stale srefs are to be removed
+   * \return Mapping from the reuse elements to reused srefs, more specifically:
+   * 1) Loop reuse: maps a loop var to the reused sref
+   * 2) Block reuse: maps a block stmt to the reused sref,
+   * where the block comes from the subtree of `tgt_stmt`
+   * 3) Intact reuse: not returned
+   */
+  static std::unordered_map<const Object*, StmtSRef> Prune(ScheduleStateNode* self,
+                                                           const ReuseInfo& reuse_info,
+                                                           const Stmt& src_stmt) {
+    SRefTreePruner pruner(self, reuse_info);
+    pruner.VisitStmt(src_stmt);
+    return std::move(pruner.reused_srefs_);
+  }
+
+ private:
+  explicit SRefTreePruner(ScheduleStateNode* self, const ReuseInfo& reuse_info)
+      : self_(self), reuse_info_(reuse_info) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    // An intact subtree keeps all its srefs: neither touch it nor recurse into it
+    if (reuse_info_.intact.count(op)) {
+      return;
+    }
+    auto it = self_->stmt2ref.find(op);
+    ICHECK(it != self_->stmt2ref.end())
+        << "IndexError: Cannot find correpsonding StmtSRef for the loop:\n"
+        << GetRef<For>(op);
+    StmtSRef& sref = it->second;
+    // Detect reuse
+    const VarNode* loop_var = op->loop_var.get();
+    if (reuse_info_.loop_sref_possible_reuse.count(loop_var)) {
+      // sref can be reused: hand it over, keyed by the loop var
+      reused_srefs_.emplace(loop_var, std::move(sref));
+    } else {
+      sref->Reset();
+    }
+    // erase the statement; `sref` refers into this entry and must not be used afterwards
+    self_->stmt2ref.erase(it);
+    // detect recursively
+    VisitStmt(op->body);
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    // An intact subtree keeps all its srefs: neither touch it nor recurse into it
+    if (reuse_info_.intact.count(op)) {
+      return;
+    }
+    auto it = self_->stmt2ref.find(op);
+    ICHECK(it != self_->stmt2ref.end())
+        << "IndexError: Cannot find correpsonding StmtSRef for the block:\n"
+        << GetRef<Block>(op);
+    StmtSRef& sref = it->second;
+    // Detect reuse
+    auto reuse_it = reuse_info_.block_sref_reuse.find(op);
+    if (reuse_it != reuse_info_.block_sref_reuse.end()) {
+      // sref can be reused: hand it over, keyed by the new block that replaces `op`
+      reused_srefs_.emplace(reuse_it->second, std::move(sref));
+    } else {
+      // Not reused: reset the sref and drop the BlockInfo entry stored under it
+      sref->Reset();
+      self_->block_info.erase(sref);
+    }
+    // erase the statement; `sref` refers into this entry and must not be used afterwards
+    self_->stmt2ref.erase(it);
+    // detect recursively
+    // op->init is omitted
+    VisitStmt(op->body);
+  }
+
+  /*! \brief The schedule state we are working on */
+  ScheduleStateNode* self_;
+  /*! \brief The reuse information we collected previously */
+  const ReuseInfo& reuse_info_;
+  /*!
+   * \brief Reused srefs:
+   * 1) loop var -> StmtSRef
+   * 2) block stmt -> StmtSRef, where the block comes from the subtree of `tgt_stmt`
+   */
+  std::unordered_map<const Object*, StmtSRef> reused_srefs_;
+};
+
+/*!
+ * \brief Update the sref in the `tgt_stmt` given the reuse information
+ *
+ * After being updated, in the `tgt_stmt` subtree,
+ * 1) all `StmtSRefNode::parent`s are correct
+ * 2) all `StmtSRefNode::seq_index`s are correct, except for the root
+ * 3) all `StmtSRefNode::stmt`s are correct, except for the root
+ */
+class SRefUpdater : public StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param self The schedule state to be updated
+   * \param src_stmt_parent The sref used as the parent of the root of `tgt_stmt`
+   * \param reused_srefs Maps a loop var / block of `tgt_stmt` to the sref it should reuse
+   * \param tgt_stmt The statement whose srefs are created or updated
+   */
+  static void Update(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent,
+                     const std::unordered_map<const Object*, StmtSRef>& reused_srefs,
+                     const Stmt& tgt_stmt) {
+    SRefUpdater(self, src_stmt_parent, reused_srefs).VisitStmt(tgt_stmt);
+  }
+
+ private:
+  explicit SRefUpdater(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent,
+                       const std::unordered_map<const Object*, StmtSRef>& reused_srefs)
+      : self_(GetRef<ScheduleState>(self)),
+        ancestors_{src_stmt_parent},
+        reused_srefs_(reused_srefs) {}
+
+  void VisitStmt_(const ForNode* op) final {
+    // `operator[]` creates an undefined entry when `op` is not in the map yet;
+    // that entry is filled in below.
+    StmtSRef& sref = self_->stmt2ref[op];
+    // Detect intact reuse
+    if (sref.defined()) {
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+      return;
+    }
+    // Detect loop reuse
+    auto it = reused_srefs_.find(op->loop_var.get());
+    if (it != reused_srefs_.end()) {
+      // Update `stmt2ref[op]` to `reused_srefs_[op->loop_var]`
+      sref = it->second;
+      sref->stmt = op;
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      // A new loop sref without reuse
+      sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(),
+                      /*seq_index=*/-1);  // `seq_index` will be set properly in SetSeqIndex
+    }
+    // Recursive visit
+    ancestors_.push_back(sref.get());
+    VisitStmt(op->body);
+    ancestors_.pop_back();
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    // `operator[]` creates an undefined entry when `op` is not in the map yet;
+    // that entry is filled in below.
+    StmtSRef& sref = self_->stmt2ref[op];
+    // Detect intact
+    if (sref.defined()) {
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+      return;
+    }
+    // Detect block reuse
+    auto it = reused_srefs_.find(op);
+    if (it != reused_srefs_.end()) {
+      // Update `stmt2ref[op]` to `reused_srefs_[op]`
+      sref = it->second;
+      sref->stmt = op;
+      sref->parent = ancestors_.back();
+      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      // A new block sref without reuse
+      sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(),
+                      /*seq_index=*/-1);  // `seq_index` will be set properly in SetSeqIndex
+    }
+    // Recursive visit
+    ancestors_.push_back(sref.get());
+    VisitStmt(op->body);
+    ancestors_.pop_back();
+    // Additionally, need to update the scope because the block is changed
+    UpdateBlockInfo(sref);
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    // Visit the children first so that their srefs exist before `seq_index` is assigned
+    StmtVisitor::VisitStmt_(seq_stmt);
+    SetSeqIndexInChildren(self_.get(), seq_stmt);
+  }
+
+  /*!
+   * \brief Create or refresh the BlockInfo entry of a block
+   * \param block_sref The sref to the block whose BlockInfo is updated
+   * \note For a new entry all flags are set to false. For an existing entry only the
+   * scope is replaced, and the caller is responsible for correcting the flags.
+   */
+  void UpdateBlockInfo(const StmtSRef& block_sref) {
+    using TIter = std::unordered_map<StmtSRef, BlockInfo, ObjectPtrHash, ObjectPtrEqual>::iterator;
+    // The caller is responsible for correcting the flags
+    BlockInfo new_info(BlockScope(GetChildBlocks(self_, block_sref)));
+    std::pair<TIter, bool> insert_result = self_->block_info.emplace(block_sref, new_info);
+    bool inserted = insert_result.second;
+    BlockInfo& info = insert_result.first->second;
+    if (inserted) {
+      // Insertion has happened, update the flags accordingly
+      info.affine_binding = false;
+      info.region_cover = false;
+      info.scope->stage_pipeline = false;
+    } else {
+      // Insertion didn't take place, because the entry has been there before.
+      // In this case, we assume that flags are still valid so intentionally keep them unchanged
+      info.scope = std::move(new_info.scope);
+    }
+  }
+
+  /*! \brief The schedule state class to be worked on */
+  ScheduleState self_;
+  /*! \brief A stack containing all the ancestor For/Block nodes during the visit */
+  std::vector<StmtSRefNode*> ancestors_;
+  /*! \brief Maps the loop var / block to the reused sref */
+  const std::unordered_map<const Object*, StmtSRef>& reused_srefs_;
+};
+
+/*!
+ * \brief A helper that returns a new copy of `parent_stmt`,
+ * where the subtree `child_src_stmt` is replaced with the subtree `child_tgt_stmt`.
+ * \note The visitor assumes `child_src_stmt` is the child of `parent_stmt` in the sref tree.
+ */
+class ChildReplacer : private StmtMutator {
+ public:
+  /*!
+   * \brief The entry function
+   * \param parent_stmt The parent statement, which must be a Block or a For
+   * \param child_src_stmt The child to be replaced, which must be a Block or a For
+   * \param child_tgt_stmt The replacement, which must be a Block, a For or a BlockRealize
+   * \param seq_index The index consulted when a SeqStmt is encountered in the parent's body
+   * \param allow_copy_on_write The value assigned to `allow_copy_on_write_` before mutation
+   * \return The parent statement with its body mutated
+   */
+  static Stmt Replace(const StmtNode* parent_stmt, const StmtNode* child_src_stmt,
+                      const Stmt& child_tgt_stmt, int seq_index, bool allow_copy_on_write) {
+    // Check the invariant
+    ICHECK(child_src_stmt->IsInstance<BlockNode>() ||  //
+           child_src_stmt->IsInstance<ForNode>());
+    ICHECK(child_tgt_stmt->IsInstance<BlockNode>() ||  //
+           child_tgt_stmt->IsInstance<ForNode>() ||    //
+           child_tgt_stmt->IsInstance<BlockRealizeNode>());
+    ChildReplacer replacer(child_src_stmt, child_tgt_stmt, seq_index);
+    // NOTE(review): `allow_copy_on_write_` is not declared in this class; presumably it is the
+    // StmtMutator member that decides whether `CopyOnWrite` may mutate in place -- confirm.
+    replacer.allow_copy_on_write_ = allow_copy_on_write;
+    return replacer.CopyOnWriteAndVisit(parent_stmt);
+  }
+
+ private:
+  explicit ChildReplacer(const StmtNode* src_stmt, const Stmt& tgt_stmt, int seq_index)
+      : src_stmt_(src_stmt), tgt_stmt_(tgt_stmt), seq_index_(seq_index) {}
+
+  Stmt VisitStmt(const Stmt& stmt) final {
+    if (stmt.get() == src_stmt_) {
+      // If the statement matches the `src_stmt` to be replaced, just return the `tgt_stmt`
+      return tgt_stmt_;
+    } else {
+      return StmtMutator::VisitStmt(stmt);
+    }
+  }
+
+  // Skipping sibling blocks and loops other than `src_stmt_`
+  Stmt VisitStmt_(const BlockNode* op) final { return GetRef<Stmt>(op); }
+  Stmt VisitStmt_(const ForNode* op) final { return GetRef<Stmt>(op); }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) final {
+    int i = this->seq_index_;
+    int n = static_cast<int>(op->seq.size());
+    // Only the child at position `seq_index_` is inspected, and only if the index is in range
+    if (0 <= i && i < n) {
+      const Stmt& stmt = op->seq[i];
+      Optional<Stmt> new_stmt = NullOpt;
+      const StmtNode* src_stmt = this->src_stmt_;
+      // `stmt` can be For or BlockRealize
+      // `src_stmt` can be For or Block
+      // so the match from `stmt` to `src_stmt` can be
+      // 1) For -> For
+      // 2) BlockRealize -> Block
+      if (stmt.get() == src_stmt) {
+        // Case 1. src_stmt is For, stmt is For
+        new_stmt = tgt_stmt_;
+      } else if (const auto* realize = stmt.as<BlockRealizeNode>()) {
+        // Case 2. stmt is BlockRealize, src_stmt is Block
+        if (realize->block.get() == src_stmt) {
+          // Copy the BlockRealize and swap in the target block
+          const auto* tgt_block = TVM_TYPE_AS(tgt_block, tgt_stmt_, BlockNode);
+          ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
+          new_realize->block = GetRef<Block>(tgt_block);
+          new_stmt = BlockRealize(std::move(new_realize));
+        }
+      }
+      // Move new_stmt to position i
+      if (new_stmt.defined()) {
+        ObjectPtr<SeqStmtNode> new_seq_stmt = CopyOnWrite(op);
+        new_seq_stmt->seq.Set(i, new_stmt.value());
+        return SeqStmt(std::move(new_seq_stmt));
+      }
+    }
+    // No match at position `seq_index_`: fall back to the default mutation
+    return StmtMutator::VisitStmt_(op);
+  }
+
+  /*!
+   * \brief Copy-on-write the parent and mutate its body
+   * \param parent_stmt The parent statement, a Block or a For; any other type is a fatal error
+   * \return The parent with the mutated body
+   */
+  Stmt CopyOnWriteAndVisit(const StmtNode* parent_stmt) {
+    // Step 1. Copy-on-write the `parent_stmt` and extract its `body`,
+    // where `body` means the body of either a block or a loop
+    // Step 2. Mutate the `block/loop->body`, searching for `child_src_stmt`
+    // and replace it with `child_tgt_stmt`
+    if (parent_stmt->IsInstance<BlockNode>()) {
+      auto* block = const_cast<BlockNode*>(static_cast<const BlockNode*>(parent_stmt));
+      ObjectPtr<BlockNode> new_block = CopyOnWrite(block);
+      new_block->body = this->VisitStmt(new_block->body);
+      return Block(std::move(new_block));
+    } else if (parent_stmt->IsInstance<ForNode>()) {
+      auto* loop = const_cast<ForNode*>(static_cast<const ForNode*>(parent_stmt));
+      ObjectPtr<ForNode> new_loop = CopyOnWrite(loop);
+      new_loop->body = this->VisitStmt(new_loop->body);
+      return For(std::move(new_loop));
+    }
+    LOG(FATAL) << "TypeError: Unexpected type: " << parent_stmt->GetTypeKey();
+    throw;
+  }
+
+  /*! \brief The `src_stmt` to be replaced */
+  const StmtNode* src_stmt_;
+  /*! \brief The `tgt_stmt` to be replaced in; held by reference, not owned */
+  const Stmt& tgt_stmt_;
+  /*!
+   * \brief The `seq_index` of the `src_stmt`
+   * \sa StmtSRefNode
+   */
+  int seq_index_;
+};
+
+/*!
+ * \brief Replace the statement that `_src_sref` points to with `tgt_stmt`,
+ * keeping the sref tree, `stmt2ref`, `block_info` and `mod` of this state in sync
+ * \param _src_sref The sref to the statement to be replaced
+ * \param tgt_stmt The statement to be put in place
+ * \param _block_sref_reuse Maps an old block in the replaced subtree to the new block in
+ * `tgt_stmt` that should take over its sref
+ * \note When `debug_mode` is non-zero, the combination of source/target statement types is
+ * checked first; the allowed combinations are For -> For, For -> BlockRealize, Block -> Block.
+ * It is a no-op if `_src_sref` already points to `tgt_stmt`.
+ */
+void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt,
+                                const Map<Block, Block>& _block_sref_reuse) {
+  if (this->debug_mode != 0) {
+    const StmtNode* src_stmt = _src_sref->stmt;
+    bool input_correct =
+        (src_stmt->IsInstance<ForNode>() && tgt_stmt->IsInstance<ForNode>()) ||
+        (src_stmt->IsInstance<ForNode>() && tgt_stmt->IsInstance<BlockRealizeNode>()) ||
+        (src_stmt->IsInstance<BlockNode>() && tgt_stmt->IsInstance<BlockNode>());
+    if (!input_correct) {
+      LOG(FATAL) << "TypeError: src_stmt has type: " << src_stmt->GetTypeKey()
+                 << ". tgt_stmt has type: " << tgt_stmt->GetTypeKey() << ".\nsrc_stmt:\n"
+                 << GetRef<Stmt>(src_stmt) << "\ntgt_stmt:\n"
+                 << tgt_stmt;
+    }
+  }
+  // Rule out the case that no replacement happens
+  if (_src_sref->stmt == tgt_stmt.get()) {
+    return;
+  }
+  // Reset sref as a new sref so that its content won't be affected by subsequent changes
+  StmtSRef src_sref(_src_sref->stmt, _src_sref->parent, _src_sref->seq_index);
+  Stmt src_stmt = GetRef<Stmt>(src_sref->stmt);
+  // Step 1. Create all the nodes needed for the new sref tree.
+  // After this step
+  // 1) all `parent`s are correct
+  // 2) all `seq_index`s are correct, except for the root
+  // 3) all `stmt`s are correct, except for the root
+  {
+    // Step 1.0. Setup block_sref_reuse
+    std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse;
+    block_sref_reuse.reserve(_block_sref_reuse.size() + 1);
+    for (const auto& kv : _block_sref_reuse) {
+      block_sref_reuse.emplace(kv.first.get(), kv.second.get());
+    }
+    // Step 1.1. Collect info for different kinds of reuses
+    // 1) intact
+    // 2) loop/block reuse
+    ReuseInfo reuse_info = ReuseCollector::Collect(this, tgt_stmt);
+    reuse_info.block_sref_reuse = std::move(block_sref_reuse);
+    // Step 1.2. Collect loop/block reuse to their corresponding srefs
+    // and remove those srefs in the `src_stmt` that are no longer used after replacement
+    std::unordered_map<const Object*, StmtSRef> reused_srefs =
+        SRefTreePruner::Prune(this, reuse_info, src_stmt);
+    // Step 1.3. Update the sref tree, inserting newly created srefs and properly handle reused
+    // srefs in `tgt_stmt`
+    SRefUpdater::Update(this, src_sref->parent, reused_srefs, tgt_stmt);
+  }
+  // Step 2. Set the ancestors' children properly
+  //   Iteratively visit the ancestors, creating new ones whose `body`s are properly fixed.
+  //   The visit stops when all the ancestors are uniquely referenced, i.e. can mutate inplace.
+  //   Along the way, because we create a new ancestor path,
+  //   we need to update those sref points from old ancestors to newly created ones
+  // Variables:
+  // 1) `num_copy_steps`. The maximum number of hops until we need to copy. To reach a node that
+  //   can be mutated inplace, it needs `num_copy_steps + 1` hops.
+  // 2) `need_module_copy`. If true, need to mutate the PrimFunc and IRModule the sref belongs to.
+  // 3) `g_var` and `g_func`. Indicate which GlobalVar and PrimFunc the sref corresponds to
+  int num_copy_steps = -1;
+  bool need_module_copy = false;
+  const PrimFuncNode* g_func = nullptr;
+  GlobalVar g_var;
+  {
+    // Walk from `src_sref` up to the root, remembering the farthest non-unique statement
+    int i = 0;
+    const StmtSRefNode* p = src_sref.get();
+    while (true) {
+      if (!p->stmt->unique()) {
+        num_copy_steps = i;
+      }
+      if (p->parent == nullptr) {
+        break;
+      }
+      ++i;
+      p = p->parent;
+    }
+    // Find `g_func` and `g_var` where the `src_sref` is in
+    g_func = GetRootPrimFunc(this->mod, p->stmt, &g_var);
+    need_module_copy = num_copy_steps == i ||             //
+                       !this->mod.unique() ||             //
+                       !this->mod->functions.unique() ||  //
+                       !g_func->unique();
+  }
+  // Loop invariant:
+  //
+  // Before step `i`:
+  // 1) `child_sref` is `src_sref` going up by `i` steps
+  // 2) `child_tgt_stmt` is the subtree that `child_sref` should correspond to after replacement
+  // 3) except for the subtree root, srefs that point to the subtree of `child_tgt_stmt` are
+  // correct
+  // 4) for the subtree root of `child_tgt_stmt`, `child_sref` has not pointed to it yet
+  // 5) `tgt_stmt` is of type Loop, Block or BlockRealize
+  //
+  // During step `i`:
+  // 1) Create `parent_stmt` that corresponds to `child_sref->parent`
+  // 2) Point `child_sref` to `child_tgt_stmt`
+  // 3) `tgt_stmt` is of type Loop or Block
+  StmtSRefNode* child_sref = src_sref.get();
+  // `tgt_stmt` is a const reference, so this is a copy of the handle (it cannot be moved from)
+  Stmt child_tgt_stmt = tgt_stmt;
+  for (int i = 0; (need_module_copy || i <= num_copy_steps) && child_sref->parent != nullptr; ++i) {
+    bool can_directly_mutate_parent = !need_module_copy && i == num_copy_steps;
+    // Replace `child_sref->stmt` to `child_tgt_stmt`.
+    const StmtNode* parent_stmt = child_sref->parent->stmt;
+    const StmtNode* child_src_stmt = child_sref->stmt;
+    // Step 2.1. Link `child_sref` to `child_tgt_stmt`
+    if (i == 0) {
+      // As the invariance of SRefUpdater,
+      // the `seq_index` of the root of `tgt_stmt` is set as -1,
+      // which might be incorrect
+      SetSeqIndex(this, child_tgt_stmt, child_sref->seq_index);
+    } else {
+      // Point `child_sref` to `child_tgt_stmt`
+      UpdateSRef(this, child_sref, child_tgt_stmt.get());
+    }
+    // Step 2.2. Create `new_parent_stmt`, by mutating the body of `parent_stmt`
+    Stmt new_parent_stmt =
+        ChildReplacer::Replace(parent_stmt, child_src_stmt, child_tgt_stmt,
+                               /*seq_index=*/child_sref->seq_index,
+                               /*allow_copy_on_write=*/can_directly_mutate_parent);
+    // Step 2.3. Go to next parent
+    if (can_directly_mutate_parent) {
+      // If the node can be directly mutated inplace,
+      // then there is no need to update its parent and the function
+      break;
+    }
+    child_tgt_stmt = std::move(new_parent_stmt);
+    child_sref = child_sref->parent;
+  }
+  // Step 3. Handle the case that we mutate the root
+  if (need_module_copy) {
+    // From the loop invariant, upon exit, while its subtree is properly set,
+    // `child_sref` is not properly pointed to `child_tgt_stmt` yet.
+    if (src_sref->parent != nullptr) {
+      // Not replacing a root
+      UpdateSRef(this, child_sref, child_tgt_stmt.get());
+    }
+    // Ensure the uniqueness of `this->mod` and `this->mod->functions`
+    IRModuleNode* new_mod = this->mod.CopyOnWrite();
+    MapNode* new_map = new_mod->functions.CopyOnWrite();
+    // Move out the PrimFunc where the sref belong while ensuring uniqueness
+    PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var)));
+    ICHECK(ref_new_func.get() == g_func);
+    PrimFuncNode* new_func = ref_new_func.CopyOnWrite();
+    // If `g_func` was not unique, after the 3 lines above:
+    //   `ref_new_func` points to a unique PrimFunc
+    //   `g_func` points to the previous PrimFunc if it is not unique
+    // If `g_func` was unique, after the 3 lines above:
+    //   `ref_new_func` points to the same unique function that `g_func` points to
+    // Update the body of the function the sref belongs to
+    const auto* realize = TVM_TYPE_AS(realize, g_func->body, BlockRealizeNode);
+    // Make `child_tgt_stmt` the root block
+    const auto* child_block = TVM_TYPE_AS(child_block, child_tgt_stmt, BlockNode);
+    ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
+    new_realize->block = GetRef<Block>(child_block);
+    new_func->body = BlockRealize(std::move(new_realize));
+    // Finally, move the `ref_new_func` back and update `this->mod`
+    new_map->at(g_var) = std::move(ref_new_func);
+    this->mod = GetRef<IRModule>(new_mod);
+  }
+  constexpr int kVerifySRefTree = static_cast<int>(ScheduleDebugMask::kVerifySRefTree);
+  if (debug_mode == -1 || (debug_mode & kVerifySRefTree)) {
+    VerifySRefTree(GetRef<ScheduleState>(this));
+  }
+}
+
+void ScheduleStateNode::DebugVerify() const {
+  // `debug_mode` is either -1 (run every check) or a bitwise-or of `ScheduleDebugMask` flags.
+  ICHECK_GE(debug_mode, -1);
+  auto is_enabled = [this](ScheduleDebugMask flag) -> bool {
+    return debug_mode == -1 || (debug_mode & static_cast<int>(flag));
+  };
+  if (is_enabled(ScheduleDebugMask::kVerifySRefTree)) {
+    VerifySRefTree(GetRef<ScheduleState>(this));
+  }
+  if (is_enabled(ScheduleDebugMask::kVerifyAffineBinding)) {
+    // TODO(@junrushao1994): Verify affine block binding
+  }
+  if (is_enabled(ScheduleDebugMask::kVerifyRegionCover)) {
+    // TODO(@junrushao1994): Verify region cover
+  }
+  if (is_enabled(ScheduleDebugMask::kVerifyStagePipeline)) {
+    // TODO(@junrushao1994): Verify stage pipeline
+  }
+}
+
+/**************** BlockInfo-related ****************/
+
+BlockInfo ScheduleStateNode::GetBlockInfo(const StmtSRef& block_sref) const {
+  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);  // `block` only feeds the macro's check
+  auto it = this->block_info.find(block_sref);  // look the sref up in `block_info`
+  CHECK(it != this->block_info.end())  // a missing entry is reported as an IndexError
+      << "IndexError: Cannot find the corresponding BlockScope to the block sref:\n"
+      << GetRef<Stmt>(block_sref->stmt);
+  return it->second;  // the BlockInfo stored under `block_sref`
+}
+
+/**************** FFI ****************/
+
+TVM_REGISTER_NODE_TYPE(ScheduleStateNode);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleState").set_body_typed([](ObjectRef obj, int debug_mode) {
+  if (const auto* func = obj.as<PrimFuncNode>()) {  // input is a PrimFunc
+    return ScheduleState(GetRef<PrimFunc>(func), debug_mode);
+  }
+  if (const auto* mod = obj.as<IRModuleNode>()) {  // input is an IRModule
+    return ScheduleState(GetRef<IRModule>(mod), debug_mode);
+  }
+  LOG(FATAL) << "TypeError: Expects `IRModule` or `PrimFunc`, but gets: " << obj->GetTypeKey();
+  throw;  // NOTE(review): presumably unreachable after LOG(FATAL) -- confirm
+});
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetBlockScope")
+    .set_body_method<ScheduleState>(&ScheduleStateNode::GetBlockScope);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateReplace")
+    .set_body_method<ScheduleState>(&ScheduleStateNode::Replace);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetSRef")
+    .set_body_typed([](ScheduleState self, Stmt stmt) -> Optional<StmtSRef> {
+      auto it = self->stmt2ref.find(stmt.get());  // keyed by the raw node pointer of `stmt`
+      return it != self->stmt2ref.end() ? it->second : Optional<StmtSRef>(NullOpt);
+    });
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
new file mode 100644
index 0000000..63ec77d
--- /dev/null
+++ b/src/tir/schedule/utils.h
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_SCHEDULE_UTILS_H_
+#define TVM_TIR_SCHEDULE_UTILS_H_
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/schedule/state.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "./analysis.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Cast the statement an sref points to via `StmtAs<Type>()`, then open an `ICHECK` on
+ * `Result`, so that usage reads `const auto* Result = TVM_SREF_AS_OR_ERR(...) << "message";`
+ * \param Result The variable the cast result is assigned to; it is what `ICHECK` tests
+ * \param SRef The SRef to be cast; parenthesized in the expansion so any expression is safe
+ * \param Type The node type to cast to, e.g. BlockNode or ForNode
+ */
+#define TVM_SREF_AS_OR_ERR(Result, SRef, Type) \
+  (SRef)->StmtAs<Type>();                      \
+  ICHECK(Result)
+
+/*!
+ * \brief Convert an sref to the `BlockNode` it points to; if the cast fails, the `ICHECK`
+ * message names the sref expression and the type key of its statement ("None" if null)
+ * \param Result The variable the cast result is assigned to, used for checking
+ * \param SRef The SRef to be cast; it appears several times in the expansion
+ */
+#define TVM_SREF_TO_BLOCK(Result, SRef)                   \
+  TVM_SREF_AS_OR_ERR(Result, SRef, ::tvm::tir::BlockNode) \
+      << "TypeError: Expects StmtSRef `" << #SRef         \
+      << "` points to `Block`, but gets: " << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None")
+
+/*!
+ * \brief Convert an sref to the `ForNode` it points to; if the cast fails, the `ICHECK`
+ * message names the sref expression and the type key of its statement ("None" if null)
+ * \param Result The variable the cast result is assigned to, used for checking
+ * \param SRef The SRef to be cast; it appears several times in the expansion
+ */
+#define TVM_SREF_TO_FOR(Result, SRef)                   \
+  TVM_SREF_AS_OR_ERR(Result, SRef, ::tvm::tir::ForNode) \
+      << "TypeError: Expects StmtSRef `" << #SRef       \
+      << "` points to `Loop`, but gets: " << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None")
+
+/*!
+ * \brief Downcast a TVM ObjectRef with `ObjectRef::as<Type>`, then open an `ICHECK` on
+ * `Result`, so that usage reads `const auto* Result = TVM_TYPE_AS_OR_ERR(...) << "message";`
+ * \param Result The variable the cast result is assigned to; it is what `ICHECK` tests
+ * \param From The ObjectRef to be downcast; parenthesized so any expression is safe
+ * \param Type The type to be downcasted to
+ */
+#define TVM_TYPE_AS_OR_ERR(Result, From, Type) \
+  (From).as<Type>();                           \
+  ICHECK(Result)
+
+/*!
+ * \brief Downcast a TVM ObjectRef with `ObjectRef::as<Type>`; if the cast fails, the `ICHECK`
+ * message names the expression, `Type::_type_key`, and the actual type key ("None" if undefined)
+ * \param Result The variable the cast result is assigned to, used for checking
+ * \param From The ObjectRef to be downcast; it appears several times in the expansion
+ * \param Type The type to be downcasted to
+ */
+#define TVM_TYPE_AS(Result, From, Type)                                           \
+  TVM_TYPE_AS_OR_ERR(Result, From, Type)                                          \
+      << "TypeError: Expects `" << #From << "` to have type `" << Type::_type_key \
+      << "`, but gets: " << ((From).defined() ? (From)->GetTypeKey() : "None")
+
+}  // namespace tir
+}  // namespace tvm
+
+#endif  // TVM_TIR_SCHEDULE_UTILS_H_
diff --git a/tests/python/unittest/test_tir_block_scope.py b/tests/python/unittest/test_tir_block_scope.py
new file mode 100644
index 0000000..4a914f5
--- /dev/null
+++ b/tests/python/unittest/test_tir_block_scope.py
@@ -0,0 +1,145 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import tvm
+from tvm import tir
+from tvm.script import ty
+from tvm.tir.schedule import DepKind
+from tvm.tir.stmt_functor import post_order_visit
+
+# pylint: disable=no-member,invalid-name,unused-variable
+
+
+@tvm.script.tir
+def elementwise(a: ty.handle, c: ty.handle) -> None:  # two chained element-wise blocks
+    A = tir.match_buffer(a, (128, 128), "float32")
+    C = tir.match_buffer(c, (128, 128), "float32")
+    B = tir.alloc_buffer((128, 128), "float32")  # intermediate: written by "B", read by "C"
+    with tir.block([128, 128], "B") as [vi, vj]:
+        B[vi, vj] = A[vi, vj] * 2.0
+    with tir.block([128, 128], "C") as [vi, vj]:  # reads what block "B" wrote
+        C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.tir
+def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:  # C[i, j] = sum_k A[i, k] * B[j, k]
+    A = tir.match_buffer(a, [128, 128])
+    B = tir.match_buffer(b, [128, 128])
+    C = tir.match_buffer(c, [128, 128])
+    for i, j in tir.grid(128, 128):
+        with tir.block([128, 128], "init") as [vi, vj]:  # writes C before the reduction
+            C[vi, vj] = tir.float32(0)
+        for k in range(0, 128):
+            with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]  # both reads and writes C
+
+
+@tvm.script.tir
+def war_dependency(a: ty.handle, b: ty.handle, c: ty.handle) -> None:  # read of B precedes its write
+    A = tir.match_buffer(a, (128, 128))
+    B = tir.match_buffer(b, (128, 128))
+    C = tir.match_buffer(c, (128, 128))
+
+    for i, j in tir.grid(128, 128):
+        with tir.block([128, 128], "C") as [vi, vj]:  # reads B first ...
+            C[vi, vj] = B[vi, vj] + 1.0
+        with tir.block([128, 128], "B") as [vi, vj]:  # ... then B is overwritten
+            B[vi, vj] = A[vi, vj] * 2.0
+
+
+# pylint: enable=no-member,invalid-name,unused-variable
+
+# pylint: disable=invalid-name
+
+
+def _get_block(s: tir.ScheduleState, name_hint: str) -> tir.StmtSRef:
+    matches = []  # every Block in `main` whose name_hint equals `name_hint`, in visiting order
+
+    def _collect(node):
+        # Callback handed to `post_order_visit`; keeps only matching Block nodes
+        if isinstance(node, tvm.tir.Block) and node.name_hint == name_hint:
+            matches.append(node)
+
+    post_order_visit(s.mod["main"].body, _collect)
+    found = matches[-1] if matches else None  # the last visited match is the one returned
+    assert found is not None and isinstance(found, tvm.tir.Block)
+    return s.get_sref(found)
+
+
+def test_elementwise_dependency():
+    state = tir.ScheduleState(elementwise, debug_mode=True)
+    root_sref = _get_block(state, "root")
+    producer = _get_block(state, "B")
+    consumer = _get_block(state, "C")
+    # Querying by source "B": exactly one dependency, B -> C, of kind RAW
+    [edge] = state.get_block_scope(root_sref).get_deps_by_src(producer)
+    assert edge.src.same_as(producer)
+    assert edge.dst.same_as(consumer)
+    assert edge.kind == DepKind.RAW
+    # Querying by destination "C" must report the same B -> C relation
+    [edge] = state.get_block_scope(root_sref).get_deps_by_dst(consumer)
+    assert edge.src.same_as(producer)
+    assert edge.dst.same_as(consumer)
+    assert edge.kind == DepKind.RAW
+
+
+def test_matmul_dependency():
+    state = tir.ScheduleState(matmul, debug_mode=True)
+    root_sref = _get_block(state, "root")
+    init_sref = _get_block(state, "init")
+    update_sref = _get_block(state, "update")
+    # Querying by source "init": two init -> update edges, one RAW and one WAW, in either order
+    first, second = state.get_block_scope(root_sref).get_deps_by_src(init_sref)
+    assert first.src.same_as(init_sref)
+    assert first.dst.same_as(update_sref)
+    assert second.src.same_as(init_sref)
+    assert second.dst.same_as(update_sref)
+    assert (first.kind == DepKind.RAW and second.kind == DepKind.WAW) or (
+        first.kind == DepKind.WAW and second.kind == DepKind.RAW
+    )
+    # Querying by destination "update" must report the same pair of edges
+    first, second = state.get_block_scope(root_sref).get_deps_by_dst(update_sref)
+    assert first.src.same_as(init_sref)
+    assert first.dst.same_as(update_sref)
+    assert second.src.same_as(init_sref)
+    assert second.dst.same_as(update_sref)
+    assert (first.kind == DepKind.RAW and second.kind == DepKind.WAW) or (
+        first.kind == DepKind.WAW and second.kind == DepKind.RAW
+    )
+
+
+def test_war_dependency():
+    state = tir.ScheduleState(war_dependency, debug_mode=True)
+    root_sref = _get_block(state, "root")
+    reader = _get_block(state, "C")
+    writer = _get_block(state, "B")
+    # Querying by source "C": exactly one dependency, C -> B, of kind WAR
+    [edge] = state.get_block_scope(root_sref).get_deps_by_src(reader)
+    assert edge.src.same_as(reader)
+    assert edge.dst.same_as(writer)
+    assert edge.kind == DepKind.WAR
+    # Querying by destination "B" must report the same C -> B relation
+    [edge] = state.get_block_scope(root_sref).get_deps_by_dst(writer)
+    assert edge.src.same_as(reader)
+    assert edge.dst.same_as(writer)
+    assert edge.kind == DepKind.WAR
+
+
+if __name__ == "__main__":  # direct-execution entry point: calls each test in order
+    test_elementwise_dependency()
+    test_matmul_dependency()
+    test_war_dependency()
diff --git a/tests/python/unittest/test_tir_schedule_state.py b/tests/python/unittest/test_tir_schedule_state.py
new file mode 100644
index 0000000..ac98725
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_state.py
@@ -0,0 +1,352 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+
+import gc
+
+import tvm
+from tvm import tir
+from tvm.ir import IRModule
+from tvm.script import ty
+
+# pylint: disable=no-member,invalid-name,unused-variable
+
+
+@tvm.script.tir
+def elementwise(a: ty.handle, c: ty.handle) -> None:  # two chained element-wise blocks
+    A = tir.match_buffer(a, (128, 128), "float32")
+    C = tir.match_buffer(c, (128, 128), "float32")
+    B = tir.alloc_buffer((128, 128), "float32")  # intermediate: written by "B", read by "C"
+    with tir.block([128, 128], "B") as [vi, vj]:  # first statement of the function body
+        B[vi, vj] = A[vi, vj] * 2.0
+    with tir.block([128, 128], "C") as [vi, vj]:  # second statement of the function body
+        C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.tir
+def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:  # C[i, j] = sum_k A[i, k] * B[j, k]
+    A = tir.match_buffer(a, [128, 128])
+    B = tir.match_buffer(b, [128, 128])
+    C = tir.match_buffer(c, [128, 128])
+    for i, j in tir.grid(128, 128):
+        with tir.block([128, 128], "init") as [vi, vj]:  # writes C before the reduction
+            C[vi, vj] = tir.float32(0)
+        for k in range(0, 128):
+            with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]  # both reads and writes C
+
+
+@tvm.script.tir
+def block_in_opaque_block(a: ty.handle, b: ty.handle) -> None:  # blocks nested in iter-var-free blocks
+    A = tir.match_buffer(a, (128, 128), "float32")
+    B = tir.match_buffer(b, (128, 128), "float32")
+    with tir.block([128], "B") as vi:
+        tir.reads([A[0:128, 0:128]])  # read/write regions are declared explicitly
+        tir.writes([B[0:128, 0:128]])
+        B[vi, 0] = A[vi, 0]
+        if A[vi, 0] == 0.0:  # data-dependent branch; each arm holds its own nested block pair
+            with tir.block([], "C"):  # block "C" declares no iter vars
+                tir.reads([A[0:128, 0:128]])
+                tir.writes([B[0:128, 0:128]])
+                with tir.block([128], "D") as vj:
+                    B[vi, vj] = A[vi, vj] * 3.0
+        else:
+            with tir.block([], "E"):  # block "E" declares no iter vars
+                tir.reads([A[0:128, 0:128]])
+                tir.writes([B[0:128, 0:128]])
+                with tir.block([128], "F") as vj:
+                    B[vi, vj] = A[vi, vj] * 2.0
+
+
+# pylint: enable=no-member,invalid-name,unused-variable
+
+
+def replace_ir_builder(deep_copy=False, realize=False):  # returns (state, replacement stmt)
+    new_func = tvm.script.from_source(tvm.script.asscript(elementwise))  # print, then re-parse
+    s = tir.ScheduleState(new_func, debug_mode=True)
+    target = tvm.tir.Block(  # block "target" whose body is the second stmt of the root block
+        iter_vars=[],
+        reads=[],
+        writes=[],
+        name_hint="target",
+        body=s.mod["main"].body.block.body[1],
+        init=None,
+        alloc_buffers=None,
+        match_buffers=None,
+        annotations=None,
+    )
+    if realize:  # optionally wrap the block in a BlockRealize with predicate True
+        target = tvm.tir.BlockRealize(
+            iter_values=[],
+            predicate=True,
+            block=target,
+        )
+    if deep_copy:  # NOTE(review): presumably detaches `target` from nodes owned by `s` -- confirm
+        target.__setstate__(target.__getstate__())
+    gc.collect()  # NOTE(review): presumably drops stale handles so ref counts are exact -- confirm
+    return s, target
+
+
+def replace_ir_builder_module(deep_copy=False, realize=False):  # returns (state, replacement stmt)
+    new_func = tvm.script.from_source(tvm.script.asscript(elementwise))  # print, then re-parse
+    other_func = tvm.script.from_source(tvm.script.asscript(elementwise))  # second, separate parse
+    mod = IRModule(functions={"main": new_func, "other": other_func})  # two-function module
+    s = tir.ScheduleState(mod, debug_mode=True)
+    target = tvm.tir.Block(  # block "target" whose body is the second stmt of main's root block
+        iter_vars=[],
+        reads=[],
+        writes=[],
+        name_hint="target",
+        body=s.mod["main"].body.block.body[1],
+        init=None,
+        alloc_buffers=None,
+        match_buffers=None,
+        annotations=None,
+    )
+    if realize:  # optionally wrap the block in a BlockRealize with predicate True
+        target = tvm.tir.BlockRealize(
+            iter_values=[],
+            predicate=True,
+            block=target,
+        )
+    if deep_copy:  # NOTE(review): presumably detaches `target` from nodes owned by `s` -- confirm
+        target.__setstate__(target.__getstate__())
+    gc.collect()  # NOTE(review): presumably drops stale handles so ref counts are exact -- confirm
+    return s, target
+
+
+def replace_ir_builder_with_opaque():  # returns a state built from `block_in_opaque_block`
+    func = tvm.script.from_source(tvm.script.asscript(block_in_opaque_block))  # print, then re-parse
+    s = tir.ScheduleState(func, debug_mode=True)
+    gc.collect()  # NOTE(review): presumably drops stale handles so ref counts are exact -- confirm
+    return s
+
+
+def test_replace_direct_write0():  # replace body[1]; `main` is expected to keep its hash
+    s, target = replace_ir_builder(realize=True)
+    old_hash = s.mod["main"].__hash__()
+    sref = s.get_sref(s.mod["main"].body.block.body[1])
+    s.replace(sref, target)
+    # There is no other reference, so the AST node can be written directly
+    assert old_hash == s.mod["main"].__hash__()
+    # Check the replaced part is equal to the target
+    tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[1], target)
+    # The target reuses the stmt of the sref, so `sref.stmt` won't be None
+    assert sref.stmt is not None
+
+
+def test_replace_direct_write1():  # replace body[1] while `hold_ref` still references it
+    s, target = replace_ir_builder(realize=True)
+    old_hash = s.mod["main"].body.block.body.__hash__()
+    hold_ref = s.mod["main"].body.block.body[1]
+    sref = s.get_sref(s.mod["main"].body.block.body[1])
+    s.replace(sref, target)
+    # The root block's body is expected to keep its hash, i.e. to be written directly
+    assert old_hash == s.mod["main"].body.block.body.__hash__()
+    assert not tvm.ir.structural_equal(hold_ref.body, target)
+    # Check the replaced part is equal to the target
+    tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[1], target)
+    # The target reuses `sref.stmt`, so `sref.stmt` won't be None
+    assert sref.stmt is not None
+
+
+def test_replace_copy():  # replace body[0] while `old_func` still references the function
+    s, target = replace_ir_builder(deep_copy=True, realize=True)
+    old_hash = s.mod["main"].__hash__()
+    # We hold another reference of func
+    old_func = s.mod["main"]
+    sref = s.get_sref(s.mod["main"].body.block.body[0])
+    s.replace(sref, target)
+    # The whole func must be copied so that `old_func` stays unchanged
+    assert old_hash != s.mod["main"].__hash__()
+    assert not tvm.ir.structural_equal(old_func.body, s.mod["main"].body)
+    assert old_hash == old_func.__hash__()
+    # Check the replaced part is equal to the target
+    tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[0], target)
+    # The replaced AST node will be deleted, so `sref.stmt` will be None
+    assert sref.stmt is None
+
+
+def test_replace_partial_copy0():  # replace body[0].body while `hold_ref` references body[0]
+    s, target = replace_ir_builder(deep_copy=True, realize=True)
+    func_old_hash = s.mod["main"].__hash__()
+    hold_ref = s.mod["main"].body.block.body[0]
+    ref_old_hash = hold_ref.__hash__()
+    sref = s.get_sref(s.mod["main"].body.block.body[0].body)
+    other_part_hash = s.mod["main"].body.block.body[1].__hash__()
+    s.replace(sref, target)
+    # The stmt is held by `hold_ref`, so copy-on-write must copy it: its ref count is not unique
+    assert ref_old_hash != s.mod["main"].body.block.body[0].__hash__()
+    assert not tvm.ir.structural_equal(hold_ref.body, target)
+    # The function and the other part stmt can be directly written
+    assert func_old_hash == s.mod["main"].__hash__()
+    assert other_part_hash == s.mod["main"].body.block.body[1].__hash__()
+    # Check the replaced part is equal to the target
+    tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[0].body, target)
+    # The replaced AST node will be deleted, so `sref.stmt` will be None
+    assert sref.stmt is None
+
+
+def test_replace_partial_copy1():  # replace the inner block while `hold_ref` references body[0].body
+    s, target = replace_ir_builder(deep_copy=True)
+    func_old_hash = s.mod["main"].__hash__()
+    hold_ref = s.mod["main"].body.block.body[0].body
+    stmt_old_hash = s.mod["main"].body.block.body[0].__hash__()
+    sref = s.get_sref(s.mod["main"].body.block.body[0].body.body.block)
+    other_part_hash = s.mod["main"].body.block.body[1].__hash__()
+    s.replace(sref, target)
+    # body[0] has only one reference, so it is expected to be mutated in place (same hash)
+    assert stmt_old_hash == s.mod["main"].body.block.body[0].__hash__()
+    assert not tvm.ir.structural_equal(hold_ref.body, target)
+    # The function and the other part stmt can be directly written
+    assert func_old_hash == s.mod["main"].__hash__()
+    assert other_part_hash == s.mod["main"].body.block.body[1].__hash__()
+    # Check the replaced part is equal to the target
+    tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[0].body.body.block, target)
+    # The replaced AST node will be deleted, so `sref.stmt` will be None
+    assert sref.stmt is None
+
+
+def test_replace_root_write():  # replace the root block with no extra reference held
+    s, target = replace_ir_builder()
+    old_hash = s.mod["main"].__hash__()
+    sref = s.get_sref(s.mod["main"].body.block)
+    s.replace(sref, target)
+    # Check there is no copy (same hash) and the new root block equals the target
+    assert old_hash == s.mod["main"].__hash__()
+    tvm.ir.assert_structural_equal(s.mod["main"].body.block, target)
+
+
+def test_replace_root_copy0():  # replace the root block while `func_ref` references the function
+    s, target = replace_ir_builder(deep_copy=True)
+    old_hash = s.mod["main"].__hash__()
+    func_ref = s.mod["main"]
+    sref = s.get_sref(s.mod["main"].body.block)
+    s.replace(sref, target)
+    # Check the function was copied and the new root block equals the target
+    assert old_hash != s.mod["main"].__hash__()
+    tvm.ir.assert_structural_equal(s.mod["main"].body.block, target)
+    # Check the original func remains unchanged
+    assert old_hash == func_ref.__hash__()
+    assert not tvm.ir.structural_equal(func_ref.body, target)
+
+
+def test_replace_root_copy1():  # replace body[0] while `func_ref` references the root block
+    s, target = replace_ir_builder(deep_copy=True, realize=True)
+    old_hash = s.mod["main"].body.block.__hash__()
+    func_ref = s.mod["main"].body.block
+    sref = s.get_sref(s.mod["main"].body.block.body[0])
+    s.replace(sref, target)
+    # Check the root block was copied and the replaced part equals the target
+    assert old_hash != s.mod["main"].body.block.__hash__()
+    tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[0], target)
+    # Check the original root block (held by `func_ref`) remains unchanged
+    assert old_hash == func_ref.__hash__()
+    assert not tvm.ir.structural_equal(func_ref.body, target)
+
+
+def test_replace_root_copy2():  # replace the root block while `func_ref` references `mod.functions`
+    s, target = replace_ir_builder(deep_copy=True)
+    old_hash = s.mod.functions.__hash__()
+    func_ref = s.mod.functions
+    sref = s.get_sref(s.mod["main"].body.block)
+    s.replace(sref, target)
+    # Check the function map was copied and the new root block equals the target
+    assert old_hash != s.mod.functions.__hash__()
+    tvm.ir.assert_structural_equal(s.mod["main"].body.block, target)
+    # Check the original function map (held by `func_ref`) remains unchanged
+    assert old_hash == func_ref.__hash__()
+    for _, v in func_ref.items():
+        assert not tvm.ir.structural_equal(v.body.block, target)
+
+
+def test_replace_root_copy3():  # replace the root block while `func_ref` references the module
+    s, target = replace_ir_builder(deep_copy=True)
+    old_hash = s.mod.__hash__()
+    func_ref = s.mod
+    sref = s.get_sref(s.mod["main"].body.block)
+    s.replace(sref, target)
+    # Check the module was copied and the new root block equals the target
+    assert old_hash != s.mod.__hash__()
+    tvm.ir.assert_structural_equal(s.mod["main"].body.block, target)
+    # Check the original module (held by `func_ref`) remains unchanged
+    assert old_hash == func_ref.__hash__()
+    assert not tvm.ir.structural_equal(func_ref["main"].body.block, target)
+
+
+def test_replace_block_remap():  # replace with an explicit {old block: new block} mapping
+    func = elementwise
+    s = tir.ScheduleState(func, debug_mode=True)
+    # The target stmt: the "init" block taken from `matmul`
+    target = matmul.body.block.body.body.body[0].block
+    sref = s.get_sref(s.mod["main"].body.block.body[0].body.body.block)
+    s.replace(sref, target, {sref.stmt: target})  # third argument remaps the old block to `target`
+    sref_new = s.get_sref(s.mod["main"].body.block.body[0].body.body.block)
+    # Check the original sref has been remapped: same sref hash, now pointing at `target`
+    assert sref.__hash__() == sref_new.__hash__()
+    tvm.ir.assert_structural_equal(sref.stmt, target)
+
+
+def test_replace_block_in_opaque_block():  # replace a loop in the `then` arm of block "B"
+    s = replace_ir_builder_with_opaque()
+    root_hash = s.mod["main"].__hash__()
+    for_loop = s.mod["main"].body.block.body.body.block.body[1].then_case.block.body
+    sref = s.get_sref(for_loop)
+    new_for_loop = tir.For(  # same loop variable and range, but the body is a bare Evaluate(0)
+        loop_var=for_loop.loop_var,
+        min_val=0,
+        extent=128,
+        kind=tir.ForKind.SERIAL,
+        body=tir.Evaluate(0),
+        thread_binding=None,
+        annotations=None,
+    )
+    s.replace(sref, new_for_loop)
+    assert root_hash == s.mod["main"].__hash__()  # the function is expected to keep its hash
+    tvm.ir.assert_structural_equal(sref.stmt, new_for_loop)
+
+
+def test_replace_ir_module():  # replace main's root block inside a two-function module
+    s, target = replace_ir_builder_module(deep_copy=True)
+    old_hash = s.mod["main"].__hash__()
+    other_func_hash = s.mod["other"].__hash__()
+    func_ref = s.mod["main"]
+    sref = s.get_sref(s.mod["main"].body.block)
+    s.replace(sref, target)
+    # Check `main` was copied and its new root block equals the target
+    assert old_hash != s.mod["main"].__hash__()
+    tvm.ir.assert_structural_equal(s.mod["main"].body.block, target)
+    # Check the original func remains unchanged
+    assert old_hash == func_ref.__hash__()
+    assert not tvm.ir.structural_equal(func_ref.body, target)
+    assert other_func_hash == s.mod["other"].__hash__()  # "other" must keep its hash
+
+
+if __name__ == "__main__":  # direct-execution entry point: calls each test in order
+    test_replace_direct_write0()
+    test_replace_direct_write1()
+    test_replace_copy()
+    test_replace_partial_copy0()
+    test_replace_partial_copy1()
+    test_replace_root_write()
+    test_replace_root_copy0()
+    test_replace_root_copy1()
+    test_replace_root_copy2()
+    test_replace_root_copy3()
+    test_replace_block_remap()
+    test_replace_block_in_opaque_block()
+    test_replace_ir_module()