You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by cs...@apache.org on 2022/05/06 22:21:23 UTC

[tvm] branch main updated: [TIR] Add schedule primitive SetAxisSeparator (#11225)

This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 31be30062b [TIR] Add schedule primitive SetAxisSeparator (#11225)
31be30062b is described below

commit 31be30062badf658bc71cb3a906411291a7db12a
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Fri May 6 15:21:18 2022 -0700

    [TIR] Add schedule primitive SetAxisSeparator (#11225)
    
    * [TIR] Add schedule primitive SetAxisSeparator
    
    * remove unused include
    
    * Move ReplaceBufferMutator impl to cc file
---
 include/tvm/tir/schedule/schedule.h                |  12 ++
 python/tvm/script/tir/__init__.pyi                 |   3 +
 python/tvm/script/tir/special_stmt.py              |  12 +-
 python/tvm/script/tir/ty.py                        |   3 +
 python/tvm/tir/schedule/schedule.py                |  81 +++++++++++-
 src/printer/tvmscript_printer.cc                   |   6 +
 src/tir/schedule/concrete_schedule.cc              |  10 ++
 src/tir/schedule/concrete_schedule.h               |   3 +
 src/tir/schedule/primitive.h                       |  11 ++
 src/tir/schedule/primitive/block_annotate.cc       | 107 ++--------------
 .../schedule/primitive/layout_transformation.cc    | 112 +++++++++++++++++
 src/tir/schedule/schedule.cc                       |   7 +-
 src/tir/schedule/traced_schedule.cc                |  13 ++
 src/tir/schedule/traced_schedule.h                 |   3 +
 src/tir/schedule/transform.cc                      |  81 ++++++++++++
 src/tir/schedule/transform.h                       |  55 ++++++++
 .../test_tir_schedule_set_axis_separator.py        | 139 +++++++++++++++++++++
 tests/python/unittest/test_tvmscript_roundtrip.py  |  20 +++
 18 files changed, 575 insertions(+), 103 deletions(-)

diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index e78cef2cac..18e15d1670 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -545,6 +545,18 @@ class ScheduleNode : public runtime::Object {
   virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
                                BufferIndexType buffer_index_type, const IndexMap& index_map) = 0;
 
+  /*!
+   * \brief Set the axis separator of a buffer, where the buffer is specified by a block and a read
+   * or write index
+   * \param block_rv The block that accesses the target buffer.
+   * \param buffer_index The index of the buffer in block's read or write region.
+   * \param buffer_index_type The type of the buffer index, kRead or kWrite.
+   * \param axis_separators The axis separator of the buffer
+   */
+  virtual void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
+                                BufferIndexType buffer_index_type,
+                                const Array<IntImm>& axis_separators) = 0;
+
   /******** Schedule: Misc ********/
   /*! \brief A no-op that marks the start of postprocessing phase of scheduling */
   virtual void EnterPostproc() = 0;
diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi
index 9727a8db63..e4513feb43 100644
--- a/python/tvm/script/tir/__init__.pyi
+++ b/python/tvm/script/tir/__init__.pyi
@@ -199,6 +199,7 @@ def match_buffer(
     align: int = -1,
     offset_factor: int = 0,
     buffer_type: str = "default",
+    axis_separators: Optional[List[int]] = None,
 ) -> Buffer: ...
 def buffer_decl(
     shape: Sequence[Union[PrimExpr, int]],
@@ -210,6 +211,7 @@ def buffer_decl(
     align: int = -1,
     offset_factor: int = 0,
     buffer_type: str = "default",
+    axis_separators: Optional[List[int]] = None,
 ) -> Buffer: ...
 def alloc_buffer(
     shape: Sequence[Union[PrimExpr, int]],
@@ -221,6 +223,7 @@ def alloc_buffer(
     align: int = -1,
     offset_factor: int = 0,
     buffer_type: str = "default",
+    axis_separators: Optional[List[int]] = None,
 ) -> Buffer: ...
 
 """
diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py
index 45eaa8b8be..39a345de7f 100644
--- a/python/tvm/script/tir/special_stmt.py
+++ b/python/tvm/script/tir/special_stmt.py
@@ -100,7 +100,7 @@ class SpecialStmt:
 @register
 class MatchBuffer(SpecialStmt):
     """Special Stmt match_buffer(param, shape, dtype, data, strides, elem_offset, scope, align,
-                                 offset_factor, buffer_type)
+                                 offset_factor, buffer_type, axis_separators)
 
     Note
     ----
@@ -131,6 +131,7 @@ class MatchBuffer(SpecialStmt):
             align=-1,
             offset_factor=0,
             buffer_type="default",
+            axis_separators=None,
             span=None,
         ):
             if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
@@ -157,6 +158,7 @@ class MatchBuffer(SpecialStmt):
                 align,
                 offset_factor,
                 buffer_type,
+                axis_separators,
                 span=span,
             )
             if isinstance(param, tvm.tir.Var):
@@ -184,7 +186,7 @@ class MatchBuffer(SpecialStmt):
 @register
 class BufferDeclare(SpecialStmt):
     """Special Stmt buffer_decl(shape, dtype, data, strides, elem_offset, scope, align,
-                                offset_factor, buffer_type)
+                                offset_factor, buffer_type, axis_separators)
     Example
     -------
     .. code-block:: python
@@ -202,6 +204,7 @@ class BufferDeclare(SpecialStmt):
             align=-1,
             offset_factor=0,
             buffer_type="default",
+            axis_separators=None,
             span=None,
         ):
             if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
@@ -228,6 +231,7 @@ class BufferDeclare(SpecialStmt):
                 align,
                 offset_factor,
                 buffer_type,
+                axis_separators,
                 span=span,
             )
             self.context.update_symbol(buffer_name, buffer, self.node)
@@ -239,7 +243,7 @@ class BufferDeclare(SpecialStmt):
 @register
 class AllocBuffer(SpecialStmt):
     """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, scope, align,
-                                     offset_factor, buffer_type)
+                                     offset_factor, buffer_type, axis_separators)
 
     Example
     -------
@@ -259,6 +263,7 @@ class AllocBuffer(SpecialStmt):
             align=-1,
             offset_factor=0,
             buffer_type="default",
+            axis_separators=None,
             span=None,
         ):
             if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
@@ -286,6 +291,7 @@ class AllocBuffer(SpecialStmt):
                 align,
                 offset_factor,
                 buffer_type,
+                axis_separators,
                 span=span,
             )
             if self.context.current_block_scope():
diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py
index dfe2fbbe42..7d90dec646 100644
--- a/python/tvm/script/tir/ty.py
+++ b/python/tvm/script/tir/ty.py
@@ -121,6 +121,7 @@ class GenericBufferType(SpecialStmt):  # pylint: disable=too-few-public-methods,
             align=-1,
             offset_factor=0,
             buffer_type="default",
+            axis_separators=None,
             span=None,
         ):
             if strides is None:
@@ -140,6 +141,7 @@ class GenericBufferType(SpecialStmt):  # pylint: disable=too-few-public-methods,
                 align,
                 offset_factor,
                 buffer_type,
+                axis_separators,
                 span=span,
             )
             return buffer
@@ -160,6 +162,7 @@ class GenericBufferType(SpecialStmt):  # pylint: disable=too-few-public-methods,
         align=-1,
         offset_factor=0,
         buffer_type="default",
+        axis_separators=None,
         span=None,
     ):
         """
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index d537db2800..8bfd906315 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2125,7 +2125,7 @@ class Schedule(Object):
         """Apply a transformation represented by IndexMap to buffer
         Parameters
         ----------
-        block_rv : BlockRV
+        block : BlockRV
             The block that accesses the target buffer
         buffer_index: int
             The index of the buffer in block's read or write region
@@ -2190,6 +2190,85 @@ class Schedule(Object):
             self, block, buffer_index, buffer_index_type_enum, index_map
         )
 
+    @type_checked
+    def set_axis_separator(
+        self,
+        block: BlockRV,
+        buffer_index: int,
+        buffer_index_type: str,
+        axis_separators: Optional[List[int]],
+    ) -> None:
+        """Set the axis separator of a buffer, where the buffer is specified by a block and a read
+        or write index.
+
+        Parameters
+        ----------
+        block : BlockRV
+            The block that accesses the target buffer
+        buffer_index: int
+            The index of the buffer in block's read or write region
+        buffer_index_type : str
+            Type of the buffer index, "read" or "write"
+        axis_separators : Optional[List[int]]
+            The axis separators.
+
+        Examples
+        --------
+
+        Before set_axis_separator, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_set_axis_separator(
+                A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]
+            ) -> None:
+                B = T.alloc_buffer((128, 128), dtype="float32")
+
+                for i, j in T.grid(128, 128):
+                    with T.block("B"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = A[vi, vj] * 2.0
+                for i, j in T.grid(128, 128):
+                    with T.block("C"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        C[vi, vj] = B[vi, vj] + 1.0
+
+        Create the schedule and do set_axis_separator:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_set_axis_separator)
+            sch.set_axis_separator(sch.get_block("B"), buffer_index=0, buffer_index_type="write",
+                                   axis_separators=[1])
+            print(sch.mod["main"].script())
+
+        After applying set_axis_separator, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_set_axis_separator(
+                A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]
+            ) -> None:
+                B = T.alloc_buffer([128, 128], dtype="float32", axis_separators=[1])
+
+                for i, j in T.grid(128, 128):
+                    with T.block("B"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = A[vi, vj] * T.float32(2)
+                for i, j in T.grid(128, 128):
+                    with T.block("C"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        C[vi, vj] = B[vi, vj] + T.float32(1)
+        """
+        axis_separators = axis_separators or []
+        assert buffer_index_type in ["read", "write"], "Invalid buffer_index_type"
+        buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
+        _ffi_api.ScheduleSetAxisSeparator(  # type: ignore # pylint: disable=no-member
+            self, block, buffer_index, buffer_index_type_enum, axis_separators
+        )
+
     ########## Schedule: Misc ##########
 
     @type_checked
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index aeb118a49c..6f8d10b320 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -502,6 +502,9 @@ Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) {
   if (buf->buffer_type != BufferType::kDefault) {
     doc << ", type=" << Doc::StrLiteral("auto");
   }
+  if (buf->axis_separators.size()) {
+    doc << ", axis_separators=" << Print(buf->axis_separators);
+  }
   return doc;
 }
 
@@ -606,6 +609,9 @@ bool TVMScriptPrinter::IsSimpleBuffer(const Buffer& buf) {
   if (buf->buffer_type != BufferType::kDefault) {
     return false;
   }
+  if (buf->axis_separators.size()) {
+    return false;
+  }
   return true;
 }
 
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 59a19631fc..7b953220f2 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -693,6 +693,16 @@ void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_i
   TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_);
 }
 
+void ConcreteScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
+                                            BufferIndexType buffer_index_type,
+                                            const Array<IntImm>& axis_separators) {
+  TVM_TIR_SCHEDULE_BEGIN();
+  tir::SetAxisSeparator(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type,
+                        axis_separators);
+  TVM_TIR_SCHEDULE_END("set-axis-separator", this->error_render_level_);
+  this->state_->DebugVerify();
+}
+
 /******** Schedule: Misc ********/
 
 }  // namespace tir
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 4534406d79..9293aa3493 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -134,6 +134,9 @@ class ConcreteScheduleNode : public ScheduleNode {
   /******** Schedule: Layout transformation ********/
   void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type,
                        const IndexMap& index_map) override;
+  void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
+                        BufferIndexType buffer_index_type,
+                        const Array<IntImm>& axis_separators) override;
   /******** Schedule: Misc ********/
   void EnterPostproc() override {}
 
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 5e21075d58..d55b896934 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -377,6 +377,17 @@ TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int bu
  */
 TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
                       const String& storage_scope);
+/*!
+ * \brief Set the axis separator of a buffer, where the buffer is specified by a block and a read
+ * or write index
+ * \param block_sref The sref of the block that accesses the target buffer.
+ * \param buffer_index The index of the buffer in block's read or write region.
+ * \param buffer_index_type The type of the buffer index, kRead or kWrite.
+ * \param axis_separators The axis separator of the buffer
+ */
+TVM_DLL void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
+                              BufferIndexType buffer_index_type,
+                              const Array<IntImm>& axis_separators);
 
 /******** Schedule: Blockize & Tensorize ********/
 
diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc
index f9cec421cd..ede239878a 100644
--- a/src/tir/schedule/primitive/block_annotate.cc
+++ b/src/tir/schedule/primitive/block_annotate.cc
@@ -16,7 +16,6 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-#include "../../ir/functor_common.h"
 #include "../utils.h"
 
 namespace tvm {
@@ -202,7 +201,7 @@ class StorageAlignInvalidAnnotationError : public ScheduleError {
  * \brief A helper mutator which recursively mutates the old buffer's storage scope and collects
  * the block sref reuse information for the following replacement.
  */
-class StorageScopeMutator : StmtExprMutator {
+class StorageScopeMutator : private ReplaceBufferMutator {
  public:
   /*!
    * \param allocate_site The block where `old_buffer` was allocated.
@@ -222,107 +221,19 @@ class StorageScopeMutator : StmtExprMutator {
  private:
   StorageScopeMutator(const Buffer& old_buffer, Buffer new_buffer, String storage_scope,
                       Map<Block, Block>* block_sref_reuse)
-      : storage_scope_(std::move(storage_scope)), block_sref_reuse_(block_sref_reuse) {
-    buffer_var_map_[old_buffer->data.get()] = std::move(new_buffer);
-  }
-
-  PrimExpr VisitExpr_(const VarNode* var) final {
-    auto it = buffer_var_map_.find(var);
-    return it != buffer_var_map_.end() ? it->second->data : GetRef<Var>(var);
-  }
-
-  PrimExpr VisitExpr_(const BufferLoadNode* load) final {
-    BufferLoad res = Downcast<BufferLoad>(ExprMutator::VisitExpr_(load));
-
-    auto it = buffer_var_map_.find(res->buffer->data.get());
-    if (it != buffer_var_map_.end()) {
-      ObjectPtr<BufferLoadNode> ptr = make_object<BufferLoadNode>(*res.get());
-      ptr->buffer = it->second;
-      return PrimExpr(ptr);
-    } else {
-      return std::move(res);
-    }
-  }
+      : ReplaceBufferMutator(old_buffer, std::move(new_buffer), block_sref_reuse) {}
 
-  Stmt VisitStmt_(const BufferStoreNode* store) final {
-    BufferStore res = Downcast<BufferStore>(StmtMutator::VisitStmt_(store));
-
-    auto it = buffer_var_map_.find(res->buffer->data.get());
+  MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer) final {
+    auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get());
     if (it != buffer_var_map_.end()) {
-      ObjectPtr<BufferStoreNode> ptr = make_object<BufferStoreNode>(*res.get());
-      ptr->buffer = it->second;
-      return Stmt(ptr);
+      Buffer new_target_buffer = WithScope(match_buffer->buffer, it->second.scope());
+      buffer_var_map_[match_buffer->buffer->data.get()] = new_target_buffer;
+      return MatchBufferRegion(new_target_buffer,
+                               BufferRegion(it->second, match_buffer->source->region));
     } else {
-      return std::move(res);
+      return match_buffer;
     }
   }
-
-  Stmt VisitStmt_(const BlockNode* block) final {
-    // To reduce the number of blocks in block sref reuse map, we check whether the block is really
-    // mutated (i.e., the old buffer appears in the block). If so, we return the block after
-    // mutation. Otherwise we just return the original block.
-
-    // Define the mutation functions.
-    auto f_mutate_match_buffers = [this](const MatchBufferRegion& match_buffer) {
-      auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get());
-      if (it != buffer_var_map_.end()) {
-        Buffer new_target_buffer = WithScope(match_buffer->buffer, storage_scope_);
-        buffer_var_map_[match_buffer->buffer->data.get()] = new_target_buffer;
-        return MatchBufferRegion(new_target_buffer,
-                                 BufferRegion(it->second, match_buffer->source->region));
-      } else {
-        return match_buffer;
-      }
-    };
-    auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) {
-      auto it = buffer_var_map_.find(buffer_region->buffer->data.get());
-      return it == buffer_var_map_.end() ? buffer_region
-                                         : BufferRegion(it->second, buffer_region->region);
-    };
-    auto f_mutate_alloc_buffers = [this](const Buffer& buffer) {
-      auto it = buffer_var_map_.find(buffer->data.get());
-      return it == buffer_var_map_.end() ? buffer : it->second;
-    };
-
-    // Step 1. Mutate `match_buffers`. If an old buffer appears as a source of MatchBufferRegion,
-    // the storage scope of the target buffer also needs to be set.
-    Array<MatchBufferRegion> match_buffers =
-        MutateArray(block->match_buffers, f_mutate_match_buffers);
-    // Step 2. Mutate the read/write region.
-    Array<BufferRegion> reads = MutateArray(block->reads, f_mutate_read_write_region);
-    Array<BufferRegion> writes = MutateArray(block->writes, f_mutate_read_write_region);
-    // Step 3. Mutate `alloc_buffers` for the old buffer allocated in this block.
-    Array<Buffer> alloc_buffers = MutateArray(block->alloc_buffers, f_mutate_alloc_buffers);
-    // Step 4. Recursively mutate the block.
-    Block mutated_block = Downcast<Block>(StmtMutator::VisitStmt_(block));
-
-    if (mutated_block.get() == block && reads.same_as(mutated_block->reads) &&
-        writes.same_as(mutated_block->writes) &&
-        alloc_buffers.same_as(mutated_block->alloc_buffers) &&
-        match_buffers.same_as(mutated_block->match_buffers)) {
-      return GetRef<Block>(block);
-    } else {
-      ObjectPtr<BlockNode> n = CopyOnWrite(mutated_block.get());
-      n->reads = std::move(reads);
-      n->writes = std::move(writes);
-      n->alloc_buffers = std::move(alloc_buffers);
-      n->match_buffers = std::move(match_buffers);
-
-      Block new_block(n);
-      block_sref_reuse_->Set(GetRef<Block>(block), new_block);
-      return std::move(new_block);
-    }
-  }
-
-  /*! \brief The storage scope to be set. */
-  String storage_scope_;
-  /*!
-   * \brief A mapping which maps old buffer vars to new buffers, including the buffers defined in
-   * MatchBufferRegion.
-   */
-  std::unordered_map<const VarNode*, Buffer> buffer_var_map_;
-  /*! \brief The block sref reuse map for the following replacement */
-  Map<Block, Block>* block_sref_reuse_;
 };
 
 void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis,
diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index fcfce5d217..b133f537b5 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -185,6 +185,84 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_
   self->Replace(scope_sref, new_scope_block, block_sref_reuse);
 }
 
+class BufferAxisSeparatorMutator : private ReplaceBufferMutator {
+ public:
+  static Block Mutate(const Block& scope_block, const Buffer& old_buffer, Buffer new_buffer,
+                      Map<Block, Block>* block_sref_reuse) {
+    BufferAxisSeparatorMutator mutator(old_buffer, std::move(new_buffer), block_sref_reuse);
+    return Downcast<Block>(mutator.VisitStmt(scope_block));
+  }
+
+ private:
+  BufferAxisSeparatorMutator(const Buffer& old_buffer, Buffer new_buffer,
+                             Map<Block, Block>* block_sref_reuse)
+      : ReplaceBufferMutator(old_buffer, new_buffer, block_sref_reuse) {}
+
+  MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer) final {
+    auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get());
+    if (it != buffer_var_map_.end()) {
+      const Buffer& new_source_buffer = it->second;
+      Buffer new_target_buffer = match_buffer->buffer;
+      new_target_buffer.CopyOnWrite()->axis_separators = new_source_buffer->axis_separators;
+      if (new_target_buffer->shape.size() != new_source_buffer->shape.size()) {
+        LOG(WARNING)
+            << "Target buffer in match_buffer doesn't have the same dimensionality as its source "
+               "buffer. `axis_separators` for the target buffer might be incorrect.";
+      }
+      buffer_var_map_[new_target_buffer->data.get()] = new_target_buffer;
+      return MatchBufferRegion(new_target_buffer,
+                               BufferRegion(new_source_buffer, match_buffer->source->region));
+    }
+    return match_buffer;
+  }
+};
+
+void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
+                      BufferIndexType buffer_index_type, const Array<IntImm>& axis_separators) {
+  const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
+  Buffer old_buffer = GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index,
+                                         buffer_index_type == BufferIndexType::kWrite);
+  Optional<StmtSRef> defining_site_sref;
+  bool is_alloc;
+  std::tie(defining_site_sref, is_alloc) = GetBufferDefiningSite(block_sref, old_buffer);
+  if (defining_site_sref.defined() && !is_alloc) {
+    throw BufferIsSubregionError(self->mod, old_buffer);
+  }
+
+  StmtSRef scope_sref = defining_site_sref.defined()
+                            ? defining_site_sref.value()
+                            : GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
+  const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
+
+  // Step 1: Make a copy of the buffer and set the new axis_separators on it.
+  Buffer new_buffer = old_buffer;
+  new_buffer.CopyOnWrite()->axis_separators = axis_separators;
+
+  Map<Block, Block> block_sref_reuse;
+
+  // Step 2: Rewrite alloc_buffer of the block or buffer_map of the PrimFunc.
+  Block new_scope_block = BufferAxisSeparatorMutator::Mutate(GetRef<Block>(scope_block), old_buffer,
+                                                             new_buffer, &block_sref_reuse);
+  if (!defining_site_sref.defined()) {
+    // mutate buffer_map of the PrimFunc
+    GlobalVar g_var;
+    GetRootPrimFunc(self->mod, scope_block, &g_var);
+    IRModuleNode* new_mod = self->mod.CopyOnWrite();
+    MapNode* new_map = new_mod->functions.CopyOnWrite();
+    PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var)));
+    PrimFuncNode* new_func = ref_new_func.CopyOnWrite();
+    MapNode* new_buffer_map = new_func->buffer_map.CopyOnWrite();
+    for (auto it = new_buffer_map->begin(); it != new_buffer_map->end(); ++it) {
+      if ((*it).second.same_as(old_buffer)) {
+        (*it).second = new_buffer;
+      }
+    }
+    new_map->at(g_var) = std::move(ref_new_func);
+  }
+
+  // Step 3: Replace the scope block with the new block
+  self->Replace(scope_sref, new_scope_block, block_sref_reuse);
+}
 /******** InstructionKind Registration ********/
 
 struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits> {
@@ -238,7 +316,41 @@ struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits>
   friend struct ::tvm::tir::UnpackedInstTraits;
 };
 
+struct SetAxisSeparatorTraits : public UnpackedInstTraits<SetAxisSeparatorTraits> {
+  static constexpr const char* kName = "SetAxisSeparator";
+  static constexpr bool kIsPure = false;
+
+ private:
+  static constexpr size_t kNumInputs = 1;
+  static constexpr size_t kNumAttrs = 3;
+  static constexpr size_t kNumDecisions = 0;
+
+  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index,
+                                      Integer buffer_index_type, Array<IntImm> axis_separators) {
+    return sch->SetAxisSeparator(block_rv, buffer_index,
+                                 static_cast<BufferIndexType>(buffer_index_type->value),
+                                 axis_separators);
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String block_rv, Integer buffer_index,
+                                 Integer buffer_index_type, Array<IntImm> axis_separators) {
+    PythonAPICall py("set_axis_separator");
+    py.Input("block", block_rv);
+    py.Input("buffer_index", buffer_index);
+    py.Input("buffer_index_type", '"' +
+                                      std::string(BufferIndexType2Str(
+                                          static_cast<BufferIndexType>(buffer_index_type->value))) +
+                                      '"');
+    py.Input("axis_separators", axis_separators);
+    return py.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
 TVM_REGISTER_INST_KIND_TRAITS(TransformLayoutTraits);
+TVM_REGISTER_INST_KIND_TRAITS(SetAxisSeparatorTraits);
 
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 82cd0a4a35..8dc0c52111 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -233,7 +233,12 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout")
       return self->TransformLayout(block_rv, buffer_index,
                                    static_cast<BufferIndexType>(buffer_index_type), index_map);
     });
-
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetAxisSeparator")
+    .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index,
+                       int buffer_index_type, const Array<IntImm>& axis_separators) {
+      return self->SetAxisSeparator(
+          block_rv, buffer_index, static_cast<BufferIndexType>(buffer_index_type), axis_separators);
+    });
 /******** (FFI) Misc ********/
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc")
     .set_body_method<Schedule>(&ScheduleNode::EnterPostproc);
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 417f80dd93..865b6f3784 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -442,6 +442,19 @@ void TracedScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_ind
                            /*outputs=*/{}));
 }
 
+void TracedScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
+                                          BufferIndexType buffer_index_type,
+                                          const Array<IntImm>& axis_separators) {
+  ConcreteScheduleNode::SetAxisSeparator(block_rv, buffer_index, buffer_index_type,
+                                         axis_separators);
+  static const InstructionKind& kind = InstructionKind::Get("SetAxisSeparator");
+  trace_->Append(/*inst=*/Instruction(
+      /*kind=*/kind,
+      /*inputs=*/{block_rv},
+      /*attrs=*/{Integer(buffer_index), Integer(buffer_index_type), axis_separators},
+      /*outputs=*/{}));
+}
+
 /******** Schedule: Misc ********/
 
 void TracedScheduleNode::EnterPostproc() {
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index 442b50ad0c..12c076d886 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -98,6 +98,9 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   /******** Schedule: Layout transformation ********/
   void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type,
                        const IndexMap& index_map) override;
+  void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
+                        BufferIndexType buffer_index_type,
+                        const Array<IntImm>& axis_separators) final;
   /******** Schedule: Misc ********/
   void EnterPostproc() final;
 };
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index b2e71a9a0d..6c4f3e1b7a 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -70,6 +70,87 @@ Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, c
   return match_buffers;
 }
 
+/******** ReplaceBufferMutator ********/
+ReplaceBufferMutator::ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer,
+                                           Map<Block, Block>* block_sref_reuse)
+    : block_sref_reuse_(block_sref_reuse) {
+  buffer_var_map_[old_buffer->data.get()] = std::move(new_buffer);
+}
+
+PrimExpr ReplaceBufferMutator::VisitExpr_(const VarNode* var) {
+  auto it = buffer_var_map_.find(var);
+  return it != buffer_var_map_.end() ? it->second->data : GetRef<Var>(var);
+}
+
+Stmt ReplaceBufferMutator::VisitStmt_(const BufferStoreNode* op) {
+  auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+  return VisitBufferAccess(std::move(node));
+}
+
+PrimExpr ReplaceBufferMutator::VisitExpr_(const BufferLoadNode* op) {
+  auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+  return VisitBufferAccess(std::move(node));
+}
+
+MatchBufferRegion ReplaceBufferMutator::VisitMatchBufferRegion(
+    const MatchBufferRegion& match_buffer) {
+  auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get());
+  if (it != buffer_var_map_.end()) {
+    return MatchBufferRegion(match_buffer->buffer,
+                             BufferRegion(it->second, match_buffer->source->region));
+  } else {
+    return match_buffer;
+  }
+}
+
+Stmt ReplaceBufferMutator::VisitStmt_(const BlockNode* block) {
+  // To reduce the number of blocks in block sref reuse map, we check whether the block is really
+  // mutated (i.e., the old buffer appears in the block). If so, we return the block after
+  // mutation. Otherwise we just return the original block.
+
+  auto f_mutate_match_buffer = [this](const MatchBufferRegion& match_buffer) {
+    return this->VisitMatchBufferRegion(match_buffer);
+  };
+  auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) {
+    auto it = buffer_var_map_.find(buffer_region->buffer->data.get());
+    return it == buffer_var_map_.end() ? buffer_region
+                                       : BufferRegion(it->second, buffer_region->region);
+  };
+  auto f_mutate_alloc_buffers = [this](const Buffer& buffer) {
+    auto it = buffer_var_map_.find(buffer->data.get());
+    return it == buffer_var_map_.end() ? buffer : it->second;
+  };
+
+  // Step 1. Mutate `match_buffers`: replace the old buffer wherever it is a match-buffer source.
+  Array<MatchBufferRegion> match_buffers = MutateArray(block->match_buffers, f_mutate_match_buffer);
+  // Step 2. Mutate the read/write region.
+  Array<BufferRegion> reads = MutateArray(block->reads, f_mutate_read_write_region);
+  Array<BufferRegion> writes = MutateArray(block->writes, f_mutate_read_write_region);
+  // Step 3. Mutate `alloc_buffers` for the old buffer allocated in this block.
+  Array<Buffer> alloc_buffers = MutateArray(block->alloc_buffers, f_mutate_alloc_buffers);
+  // Step 4. Recursively mutate the block.
+  Block mutated_block = Downcast<Block>(StmtMutator::VisitStmt_(block));
+
+  if (mutated_block.get() == block && reads.same_as(mutated_block->reads) &&
+      writes.same_as(mutated_block->writes) &&
+      alloc_buffers.same_as(mutated_block->alloc_buffers) &&
+      match_buffers.same_as(mutated_block->match_buffers)) {
+    return GetRef<Block>(block);
+  } else {
+    ObjectPtr<BlockNode> n = CopyOnWrite(mutated_block.get());
+    n->reads = std::move(reads);
+    n->writes = std::move(writes);
+    n->alloc_buffers = std::move(alloc_buffers);
+    n->match_buffers = std::move(match_buffers);
+
+    Block new_block(n);
+    if (block_sref_reuse_ != nullptr) {
+      block_sref_reuse_->Set(GetRef<Block>(block), new_block);
+    }
+    return std::move(new_block);
+  }
+}
+
 /******** Block Removal ********/
 
 void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_sref,
diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h
index 12326b3418..52e27350d4 100644
--- a/src/tir/schedule/transform.h
+++ b/src/tir/schedule/transform.h
@@ -21,6 +21,12 @@
 
 #include <tvm/tir/schedule/schedule.h>
 #include <tvm/tir/schedule/state.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <unordered_map>
+#include <utility>
+
+#include "../ir/functor_common.h"
 
 namespace tvm {
 namespace tir {
@@ -66,6 +72,55 @@ Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& sou
 Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, const Buffer& source,
                                        const Buffer& target);
 
+/*!
+ * \brief A helper mutator which recursively replaces the old buffer with the new buffer and
+ * collects the block sref reuse information for the following replacement.
+ *
+ * If the buffer to be replaced is used as the source in `match_buffers`, then depending on the
+ * specific use case, the target buffers in `match_buffers` may also need to be mutated. In this
+ * case, this class should be subclassed to explicitly handle `match_buffers`.
+ */
+class ReplaceBufferMutator : public StmtExprMutator {
+ public:
+  /*!
+   * \brief The constructor
+   * \param old_buffer The old buffer
+   * \param new_buffer The new buffer
+   * \param block_sref_reuse Optional map to record mapping between old and new blocks that reuse
+   *        sref.
+   */
+  ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer,
+                       Map<Block, Block>* block_sref_reuse);
+
+ protected:
+  PrimExpr VisitExpr_(const VarNode* var) final;
+
+  template <typename Node>
+  Node VisitBufferAccess(Node node) {
+    auto it = buffer_var_map_.find(node->buffer->data.get());
+    if (it != buffer_var_map_.end()) {
+      node.CopyOnWrite()->buffer = it->second;
+    }
+    return node;
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final;
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final;
+
+  virtual MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer);
+
+  Stmt VisitStmt_(const BlockNode* block) final;
+
+  /*!
+   * \brief A mapping which maps old buffer vars to new buffers, including the buffers defined in
+   * MatchBufferRegion.
+   */
+  std::unordered_map<const VarNode*, Buffer> buffer_var_map_;
+  /*! \brief The block sref reuse map for the following replacement */
+  Map<Block, Block>* block_sref_reuse_;
+};
+
 /******** Block Removal ********/
 
 /*!
diff --git a/tests/python/unittest/test_tir_schedule_set_axis_separator.py b/tests/python/unittest/test_tir_schedule_set_axis_separator.py
new file mode 100644
index 0000000000..d829a3f1b7
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_set_axis_separator.py
@@ -0,0 +1,139 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+import pytest
+import tvm
+from tvm import tir
+from tvm.script import tir as T
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+
+# fmt: off
+# pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
+
+@T.prim_func
+def element_wise(A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
+    B = T.alloc_buffer((128, 128), dtype="float32")
+
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B[vi, vj] = A[vi, vj] * 2.0
+    for i, j in T.grid(128, 128):
+        with T.block("C"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            C[vi, vj] = B[vi, vj] + 1.0
+
+
+@T.prim_func
+def element_wise_set_axis_separator(A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
+    B = T.alloc_buffer([128, 128], dtype="float32", axis_separators=[1])
+
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B[vi, vj] = A[vi, vj] * T.float32(2)
+    for i, j in T.grid(128, 128):
+        with T.block("C"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            C[vi, vj] = B[vi, vj] + T.float32(1)
+
+
+@T.prim_func
+def element_wise_set_axis_separator_input_buffer(A: T.Buffer(shape=(128, 128), dtype="float32", axis_separators=(1,)), C: T.Buffer[(128, 128), "float32"]) -> None:
+    B = T.alloc_buffer([128, 128], dtype="float32")
+
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B[vi, vj] = A[vi, vj] * T.float32(2)
+    for i, j in T.grid(128, 128):
+        with T.block("C"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            C[vi, vj] = B[vi, vj] + T.float32(1)
+
+
+@T.prim_func
+def element_wise_subregion_match(A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
+    B = T.alloc_buffer((128, 128), dtype="float32")
+
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B_subregion0 = T.match_buffer(B[i, j], [], offset_factor=1)
+            B_subregion0[()] = A[vi, vj] * 2.0
+    for i, j in T.grid(128, 128):
+        with T.block("C"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B_subregion1 = T.match_buffer(B[i, j], [], offset_factor=1)
+            C[vi, vj] = B_subregion1[()] + 1.0
+
+
+@T.prim_func
+def element_wise_subregion_match_set_axis_separator(A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
+    B = T.alloc_buffer([128, 128], dtype="float32", axis_separators=[1])
+
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B_subregion0 = T.match_buffer(B[i, j], [], dtype="float32", offset_factor=1, axis_separators=[1])
+            B_subregion0[()] = A[vi, vj] * T.float32(2)
+    for i, j in T.grid(128, 128):
+        with T.block("C"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B_subregion1 = T.match_buffer(B[i, j], [], dtype="float32", offset_factor=1, axis_separators=[1])
+            C[vi, vj] = B_subregion1[()] + T.float32(1)
+
+
+# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
+
+
+def test_set_axis_separator():
+    func = element_wise
+    s = tir.Schedule(func, debug_mask='all')
+    s.set_axis_separator(s.get_block("B"), 0, "write", [1])
+    tvm.ir.assert_structural_equal(element_wise_set_axis_separator, s.mod["main"])
+    verify_trace_roundtrip(sch=s, mod=func)
+
+
+def test_set_scope_fail_on_index_out_of_bound():
+    func = element_wise
+    s = tir.Schedule(func, debug_mask='all')
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.set_axis_separator(s.get_block("B"), 1, "write",[1])
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.set_axis_separator(s.get_block("B"), -1, "read",[1])
+
+
+def test_set_axis_separator_input_buffer():
+    func = element_wise
+    s = tir.Schedule(func, debug_mask='all')
+    s.set_axis_separator(s.get_block("B"), 0, "read", [1])
+    tvm.ir.assert_structural_equal(element_wise_set_axis_separator_input_buffer, s.mod["main"])
+    verify_trace_roundtrip(sch=s, mod=func)
+
+
+def test_set_axis_separator_subregion():
+    func = element_wise_subregion_match
+    s = tir.Schedule(func, debug_mask='all')
+    s.set_axis_separator(s.get_block("B"), 0, "write", [1])
+    tvm.ir.assert_structural_equal(element_wise_subregion_match_set_axis_separator, s.mod["main"])
+    verify_trace_roundtrip(sch=s, mod=func)
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index 0437576462..c704baebc7 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3253,6 +3253,25 @@ def pointer_type():
     return func_with_ptr_type_annotations
 
 
+def buffer_axis_separator():
+    @T.prim_func
+    def element_wise(a: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (128, 128), "float32", axis_separators=[1])
+        C = T.match_buffer(c, (128, 128), "float32")
+        B = T.alloc_buffer((128, 128), "float32", axis_separators=[1])
+
+        for i, j in T.grid(128, 128):
+            with T.block("B"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                B[vi, vj] = A[vi, vj] * T.float32(2)
+        for i, j in T.grid(128, 128):
+            with T.block("C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                C[vi, vj] = B[vi, vj] + T.float32(1)
+
+    return element_wise
+
+
 ir_generator = tvm.testing.parameter(
     opt_gemm_normalize,
     opt_gemm_lower,
@@ -3288,6 +3307,7 @@ ir_generator = tvm.testing.parameter(
     int64_support,
     string_annotation_escaping,
     pointer_type,
+    buffer_axis_separator,
 )