Posted to commits@tvm.apache.org by wu...@apache.org on 2022/09/15 20:15:20 UTC

[tvm] branch main updated: [TIR, Schedule] Add schedule primitive PadEinsum (#12750)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1f8b5dec29 [TIR, Schedule] Add schedule primitive PadEinsum (#12750)
1f8b5dec29 is described below

commit 1f8b5dec29e6e34b4cf5f092acf5b1d197a59d42
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Thu Sep 15 13:15:10 2022 -0700

    [TIR, Schedule] Add schedule primitive PadEinsum (#12750)
    
    * [TIR, Schedule] Add schedule primitive PadEinsum
    
    Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
    
    * lint
    
    * [TIR] Fix producer indices check in PadEinsum
    
    * address comments
    
    * simplify lambda expr
    
    * fix
    
    Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
---
 include/tvm/tir/schedule/schedule.h                |  20 +
 python/tvm/tir/schedule/schedule.py                | 122 ++++++
 src/tir/schedule/analysis.h                        |  27 ++
 src/tir/schedule/analysis/analysis.cc              |  29 ++
 src/tir/schedule/concrete_schedule.cc              |   6 +
 src/tir/schedule/concrete_schedule.h               |   1 +
 src/tir/schedule/primitive.h                       |  11 +-
 .../schedule/primitive/layout_transformation.cc    |  36 +-
 src/tir/schedule/primitive/pad_einsum.cc           | 474 +++++++++++++++++++++
 src/tir/schedule/schedule.cc                       |   3 +-
 src/tir/schedule/traced_schedule.cc                |  12 +-
 src/tir/schedule/traced_schedule.h                 |   3 +-
 src/tir/schedule/transform.cc                      |   8 +
 src/tir/schedule/transform.h                       |   7 +-
 .../unittest/test_tir_schedule_pad_einsum.py       | 122 ++++++
 15 files changed, 841 insertions(+), 40 deletions(-)
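
For orientation, a minimal sketch of how the new primitive is driven from Python. It mirrors the
docstring example added below; `matmul_before` is the workload defined in the new test file at the
end of this diff:

    from tvm import tir

    # Build a schedule over the matmul workload and locate the Einsum block.
    sch = tir.Schedule(matmul_before, debug_mask="all")
    block = sch.get_block("C_shared")
    # Pad the block iters (i, j, k) by (0, 1, 1): extents 128x127x127 -> 128x128x128.
    sch.pad_einsum(block, [0, 1, 1])
    print(sch.mod["main"].script())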

diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index da399ab976..8e5cd34d2e 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -627,6 +627,7 @@ class ScheduleNode : public runtime::Object {
                                 BufferIndexType buffer_index_type,
                                 const Array<IntImm>& axis_separators) = 0;
 
+  /******** Schedule: Padding ********/
   /*!
    * \brief Decompose a padding block into a block filling const pad values and a block
    * writing in-bound values.
@@ -636,6 +637,25 @@ class ScheduleNode : public runtime::Object {
    */
   virtual BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) = 0;
 
+  /*!
+   * \brief Pad the computation of Einsum.
+   * \param block_rv The block that matches the Einsum pattern.
+   * \param padding The padding for each block iter.
+   * \details This schedule primitive identifies the Einsum pattern in the block body and finds
+   * its producer blocks. It then pads the computation of the Einsum pattern and its producer
+   * blocks. The output buffer and the producer buffers are resized according to the padding
+   * size. It requires the output buffer and the producer buffers to be allocated inside the
+   * PrimFunc.
+   *
+   * The padding is a list of non-negative integers; each element corresponds to the padding of
+   * one block iter, in the order of the block iters. The block and its producer blocks should
+   * have trivial bindings, i.e. each block iter is bound to a single loop variable. After
+   * padding, each block iter extent and the corresponding outer loop are extended by the
+   * padding size.
+   *
+   * The sizes of the producer buffers are inferred from the padding sizes of the Einsum
+   * computation. The producer buffers are padded with the initial value of the corresponding
+   * reduction.
+   */
+  virtual void PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) = 0;
+
   /******** Schedule: Misc ********/
   /*! \brief A no-op that marks the start of postprocessing phase of scheduling */
   virtual void EnterPostproc() = 0;
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index d1293371a0..fdc8717032 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2783,6 +2783,128 @@ class Schedule(Object):
         """Check whether the block match padding pattern and can be decomposed."""
         return _ffi_api.CanDecomposePadding(self, block, loop)  # type: ignore # pylint: disable=no-member
 
+    @type_checked
+    def pad_einsum(self, block: Union[BlockRV, str], padding: List[int]) -> None:
+        """Pad the computation of Einsum.
+
+        This schedule primitive identifies the Einsum pattern in the block body and finds its
+        producer blocks. It then pads the computation of the Einsum pattern and its producer
+        blocks. The output buffer and the producer buffers are resized according to the padding
+        size. It requires the output buffer and the producer buffers to be allocated inside the
+        PrimFunc.
+
+        The padding is a list of non-negative integers; each element corresponds to the padding of
+        one block iter, in the order of the block iters. The block and its producer blocks should
+        have trivial bindings, i.e. each block iter is bound to a single loop variable. After
+        padding, each block iter extent and the corresponding outer loop are extended by the
+        padding size.
+
+        The sizes of the producer buffers are inferred from the padding sizes of the Einsum
+        computation. The producer buffers are padded with the initial value of the corresponding
+        reduction.
+
+        Parameters
+        ----------
+        block : Union[BlockRV, str]
+            The block that matches the Einsum pattern.
+
+        padding : List[int]
+            The padding for each block iter.
+
+        Examples
+        --------
+
+        Before applying pad-einsum, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_pad_einsum(
+                A: T.Buffer[(128, 127), "float32"],
+                B: T.Buffer[(127, 127), "float32"],
+                C: T.Buffer[(128, 127), "float32"],
+            ) -> None:
+                A_shared = T.alloc_buffer((128, 127), "float32", scope="shared")
+                B_shared = T.alloc_buffer((127, 127), "float32", scope="shared")
+                C_shared = T.alloc_buffer((128, 127), "float32", scope="shared")
+                for i0, i1 in T.grid(128, 127):
+                    with T.block("A"):
+                        i, j = T.axis.remap("SS", [i0, i1])
+                        A_shared[i, j] = A[i, j]
+                for i0, i1 in T.grid(127, 127):
+                    with T.block("B"):
+                        i, j = T.axis.remap("SS", [i0, i1])
+                        B_shared[i, j] = B[i, j]
+                for i0, i1, i2 in T.grid(128, 127, 127):
+                    with T.block("C_shared"):
+                        i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+                        with T.init():
+                            C_shared[i, j] = T.float32(0)
+                        C_shared[i, j] = C_shared[i, j] + A_shared[i, k] * B_shared[k, j]
+                for i0, i1 in T.grid(128, 127):
+                    with T.block("C"):
+                        i, j = T.axis.remap("SS", [i0, i1])
+                        C[i, j] = C_shared[i, j]
+
+        Create the schedule and do pad-einsum with the specified block:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_pad_einsum, debug_mask="all")
+            block = sch.get_block("C_shared")
+            sch.pad_einsum(block, [0, 1, 1])
+            print(sch.mod["main"].script())
+
+        After applying pad-einsum, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_pad_einsum(
+                A: T.Buffer[(128, 127), "float32"],
+                B: T.Buffer[(127, 127), "float32"],
+                C: T.Buffer[(128, 127), "float32"],
+            ) -> None:
+                A_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
+                B_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
+                C_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
+                for i0, i1 in T.grid(128, 128):
+                    with T.block("A"):
+                        i, j = T.axis.remap("SS", [i0, i1])
+                        T.reads(A[i, j])
+                        T.writes(A_shared_padded[i, j])
+                        A_shared_padded[i, j] = T.if_then_else(
+                            j < 127, A[i, j], T.float32(0), dtype="float32"
+                        )
+                for i0, i1 in T.grid(128, 128):
+                    with T.block("B"):
+                        i, j = T.axis.remap("SS", [i0, i1])
+                        T.reads(B[i, j])
+                        T.writes(B_shared_padded[i, j])
+                        B_shared_padded[i, j] = T.if_then_else(
+                            i < 127 and j < 127, B[i, j], T.float32(0), dtype="float32"
+                        )
+                for i0, i1, i2 in T.grid(128, 128, 128):
+                    with T.block("C_shared"):
+                        i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+                        T.reads(A_shared_padded[i, k], B_shared_padded[k, j])
+                        T.writes(C_shared_padded[i, j])
+                        with T.init():
+                            C_shared_padded[i, j] = T.float32(0)
+                        C_shared_padded[i, j] = (
+                            C_shared_padded[i, j] + A_shared_padded[i, k] * B_shared_padded[k, j]
+                        )
+                for i0, i1 in T.grid(128, 127):
+                    with T.block("C"):
+                        i, j = T.axis.remap("SS", [i0, i1])
+                        T.reads(C_shared_padded[i, j])
+                        T.writes(C[i, j])
+                        C[i, j] = C_shared_padded[i, j]
+
+        """
+        block = self._normalize_block_arg(block)
+        return _ffi_api.SchedulePadEinsum(  # type: ignore # pylint: disable=no-member
+            self, block, padding
+        )
+
     ########## Schedule: Misc ##########
 
     @type_checked
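
Since the primitive is also registered as a traced instruction (see `TracedScheduleNode::PadEinsum`
and `PadEinsumTraits` further down), the call is recorded in the schedule trace. A small sketch,
reusing `matmul_before` from the test file and assuming the default `tir.Schedule` is traced:

    from tvm import tir

    sch = tir.Schedule(matmul_before)  # tir.Schedule is traced by default
    sch.pad_einsum(sch.get_block("C_shared"), [0, 1, 1])
    # Per UnpackedAsPython in PadEinsumTraits, the trace renders this call
    # roughly as: sch.pad_einsum(block=b0, padding=[0, 1, 1])
    print(sch.trace)
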
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 489df8959d..ca45bcac6b 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -298,6 +298,15 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize,
 void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref,
                              arith::Analyzer* analyzer);
 
+/*!
+ * \brief Check whether a block has a trivial binding, i.e. each block var is bound to an outer
+ * loop var, from outer to inner.
+ * \param self The schedule state
+ * \param block_sref The block to be checked
+ * \throw ScheduleError If the block does not have trivial bindings
+ */
+void CheckBlockHasTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref);
+
 /******** Block-loop relation ********/
 
 /*!
@@ -697,6 +706,24 @@ Array<arith::IntSet> AnalyzeRegionLowerBound(const BufferRegion& region, const P
                                              const StmtSRef& dom_high_exclusive,
                                              arith::Analyzer* analyzer);
 
+/*!
+ * \brief Check if the buffer indices are all Vars and extract them
+ * \param buffer_access The BufferLoad or BufferStore
+ * \return The indices if they are all Vars, otherwise NullOpt
+ */
+template <typename T>
+Optional<Array<Var>> CheckTrivialBufferIndices(const T& buffer_access) {
+  Array<Var> indices;
+  for (const PrimExpr& index : buffer_access->indices) {
+    const VarNode* var = index.as<VarNode>();
+    if (var == nullptr) {
+      return NullOpt;
+    }
+    indices.push_back(GetRef<Var>(var));
+  }
+  return indices;
+}
+
 /*! \brief Necessary information used for tensorization */
 class TensorizeInfoNode : public Object {
  public:
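
As a concrete illustration of what `CheckTrivialBufferIndices` accepts, consider these hypothetical
accesses (TVMScript-style; not part of the commit):

    # Trivial indices: every index is a plain block iter Var.
    #     C[i, j] = C[i, j] + A[i, k] * B[k, j]   # (i, j), (i, k), (k, j) -> extracted
    # Non-trivial index: any composite PrimExpr.
    #     C[i, j] = A[i + 1, j]                   # `i + 1` is not a Var -> NullOpt
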
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 7ed60876ab..4f78b0c9cd 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -652,6 +652,35 @@ void CheckAffineBinding(const ScheduleState& self, Block block) {
   CheckPartialAffineBinding(self, std::move(block), NullOpt);
 }
 
+void CheckBlockHasTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) {
+  class NotTrivialBindingError : public ScheduleError {
+   public:
+    explicit NotTrivialBindingError(IRModule mod, Block block)
+        : mod_(std::move(mod)), block_(std::move(block)) {}
+
+    String FastErrorString() const final {
+      return "ScheduleError: The binding values of the block are not variables of outer loops.";
+    }
+
+    String DetailRenderTemplate() const final {
+      std::ostringstream os;
+      os << "The binding values of the {0} are not variables of outer loops.";
+      return os.str();
+    }
+
+    IRModule mod() const final { return mod_; }
+    Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+
+   private:
+    IRModule mod_;
+    Block block_;
+  };
+
+  if (!IsTrivialBinding(self, block_sref)) {
+    throw NotTrivialBindingError(self->mod, GetRef<Block>(block_sref->StmtAs<BlockNode>()));
+  }
+}
+
 Map<Var, Range> LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive,
                                          const Optional<StmtSRef>& high_exclusive,
                                          const runtime::StorageScope& extra_relax_scope) {
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index afc6757997..9d7dc6b95f 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -795,6 +795,12 @@ BlockRV ConcreteScheduleNode::DecomposePadding(const BlockRV& block_rv, const Lo
   return CreateRV<BlockRV>(result);
 }
 
+void ConcreteScheduleNode::PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) {
+  TVM_TIR_SCHEDULE_BEGIN();
+  tir::PadEinsum(state_, this->GetSRef(block_rv), padding);
+  TVM_TIR_SCHEDULE_END("pad-einsum", this->error_render_level_);
+  this->state_->DebugVerify();
+}
 /******** Schedule: Misc ********/
 
 }  // namespace tir
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index e79d1d5288..1aa9dafcc9 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -128,6 +128,7 @@ class ConcreteScheduleNode : public ScheduleNode {
   /******** Schedule: Reduction ********/
   BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override;
   BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) override;
+  void PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) override;
   /******** Schedule: Block annotation ********/
   void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
                     int offset) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 05d9e4cf94..97233fe4bc 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -490,7 +490,7 @@ TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int
 TVM_DLL void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
                                   const IndexMap& index_map);
 
-/******** Schedule: Padding decomposition ********/
+/******** Schedule: Padding ********/
 /*!
  * \brief Decompose a padding block into a block filling const pad values and a block
  * writing in-bound values.
@@ -501,6 +501,15 @@ TVM_DLL void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref
 TVM_DLL StmtSRef DecomposePadding(ScheduleState self, const StmtSRef& block_sref,
                                   const StmtSRef& loop_sref);
 
+/*!
+ * \brief Pad the computation of Einsum.
+ * \param self The state of the schedule
+ * \param block_sref The block sref that matches the Einsum pattern.
+ * \param padding The padding for each block iter.
+ */
+TVM_DLL void PadEinsum(ScheduleState self, const StmtSRef& block_sref,
+                       const Array<Integer>& padding);
+
 /******** Schedule: Misc ********/
 
 }  // namespace tir
diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index 8e2643db01..32ed279f02 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -278,40 +278,6 @@ class IndexMapNotApplicableToBlockIterError : public ScheduleError {
   IndexMap index_map_;
 };
 
-class NotTrivialBindingError : public ScheduleError {
- public:
-  explicit NotTrivialBindingError(IRModule mod, Block block)
-      : mod_(std::move(mod)), block_(std::move(block)) {}
-
-  static void CheckBlockHasTrivialBinding(const IRModule& mod, const BlockRealize& block_realize,
-                                          std::unordered_set<const VarNode*> outer_loop_vars) {
-    // Step 2: Check all the binding values are loops vars
-    for (const PrimExpr& iter_value : block_realize->iter_values) {
-      const VarNode* loop_var = iter_value.as<VarNode>();
-      if (!loop_var || !outer_loop_vars.count(loop_var)) {
-        throw NotTrivialBindingError(mod, block_realize->block);
-      }
-    }
-  }
-
-  String FastErrorString() const final {
-    return "ScheduleError: The binding values of the block are not variables of outer loops.";
-  }
-
-  String DetailRenderTemplate() const final {
-    std::ostringstream os;
-    os << "The binding values of the {0} are not variables of outer loops.";
-    return os.str();
-  }
-
-  IRModule mod() const final { return mod_; }
-  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
-
- private:
-  IRModule mod_;
-  Block block_;
-};
-
 class OpaqueNewIterTypeError : public ScheduleError {
  public:
   explicit OpaqueNewIterTypeError(IRModule mod, Block block, PrimExpr iter_value)
@@ -363,7 +329,7 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
   }
 
   BlockRealize block_realize = GetBlockRealize(self, block_sref);
-  NotTrivialBindingError::CheckBlockHasTrivialBinding(self->mod, block_realize, loop_vars);
+  CheckBlockHasTrivialBinding(self, block_sref);
 
   // Step 3: Collect information of block iter vars
   Array<PrimExpr> block_vars;      // iter_var->var of each block iter
diff --git a/src/tir/schedule/primitive/pad_einsum.cc b/src/tir/schedule/primitive/pad_einsum.cc
new file mode 100644
index 0000000000..7a7b88d686
--- /dev/null
+++ b/src/tir/schedule/primitive/pad_einsum.cc
@@ -0,0 +1,474 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <optional>
+
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*! \brief The schedule error class when the padding size is invalid. */
+class InvalidPaddingError : public ScheduleError {
+ public:
+  InvalidPaddingError(IRModule mod, Block block, Array<Integer> padding)
+      : mod_(std::move(mod)), block_(std::move(block)), padding_(std::move(padding)) {}
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+  String FastErrorString() const final {
+    return "ScheduleError: The padding size for the block is invalid.";
+  }
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The padding for the block {0} are invalid. It should be a list of "
+       << block_->iter_vars.size() << " non-negative integers. Got " << padding_;
+    return os.str();
+  }
+
+  static void Check(const ScheduleState& self, const Block& block, Array<Integer> padding) {
+    if (padding.size() != block->iter_vars.size()) {
+      throw InvalidPaddingError(self->mod, block, padding);
+    }
+    for (const auto& pad : padding) {
+      if (pad->value < 0) {
+        throw InvalidPaddingError(self->mod, block, padding);
+      }
+    }
+  }
+
+ private:
+  IRModule mod_;
+  Block block_;
+  Array<Integer> padding_;
+};
+
+/*! \brief The schedule error class when the block body is not an Einsum pattern. */
+class NonEinsumError : public ScheduleError {
+ public:
+  explicit NonEinsumError(IRModule mod, Block block)
+      : mod_(std::move(mod)), block_(std::move(block)) {}
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+  String FastErrorString() const final {
+    return "ScheduleError: The block is not a computation of Einsum pattern.";
+  }
+  String DetailRenderTemplate() const final {
+    return "The block {0} not a computation of Einsum pattern.";
+  }
+
+ private:
+  IRModule mod_;
+  Block block_;
+};
+
+/*! \brief Data structure that represents an Einsum computation. */
+struct Einsum {
+  // The output buffer
+  Buffer output_buffer;
+  // The indices of the output buffer
+  Array<Var> output_indices;
+  // The indices of the input buffers
+  Map<Buffer, Array<Var>> input_indices;
+};
+
+class EinsumExtractor : public ExprVisitor {
+ public:
+  EinsumExtractor() = default;
+
+  std::optional<Einsum> Extract(const Block& block) {
+    const BufferStoreNode* update = block->body.as<BufferStoreNode>();
+    // Step 1: Check that the body is a BufferStore, that the block has an init statement, and
+    // that the BufferStore and the init statement write to the same output buffer indices.
+    if (update == nullptr || !block->init.defined()) {
+      return std::nullopt;
+    }
+
+    if (Optional<Array<Var>> opt_indices = CheckTrivialBufferIndices(update);
+        opt_indices.defined()) {
+      ein_sum_.output_indices = std::move(opt_indices.value());
+    } else {
+      return std::nullopt;
+    }
+    ein_sum_.output_buffer = update->buffer;
+
+    const BufferStoreNode* init = block->init.value().as<BufferStoreNode>();
+    ICHECK(init != nullptr);
+    if (!CompareBufferIndices(init->indices, ein_sum_.output_indices)) {
+      return std::nullopt;
+    }
+    // Step 2: Check that the BufferStore updates the output buffer and that the indices of the
+    // input buffers are block iter variables.
+    CheckStoreValue(update->value);
+    if (fail_) {
+      return std::nullopt;
+    }
+    return std::move(ein_sum_);
+  }
+
+ private:
+  void CheckStoreValue(const PrimExpr& update) {
+    // Check the update part has the form:
+    //   Output[output_indices] += Input_0[input_indices_0] op_0 Input_1[input_indices_1] op_1 ...
+    // where output_indices and input_indices_i are arrays whose elements are block iter
+    // variables rather than composite PrimExprs, and op_i are binary operations.
+
+    // Check the value is an Add and either the LHS or the RHS is a BufferLoad from the output
+    // buffer.
+    const AddNode* add = update.as<AddNode>();
+    if (add == nullptr) {
+      fail_ = true;
+      return;
+    }
+    const BufferLoadNode* lhs = add->a.as<BufferLoadNode>();
+    PrimExpr rest = add->b;  // the operand that is not the load of the output buffer
+    if (lhs == nullptr && add->b.as<BufferLoadNode>() != nullptr) {
+      lhs = add->b.as<BufferLoadNode>();
+      rest = add->a;
+    }
+    if (lhs == nullptr || !lhs->buffer.same_as(ein_sum_.output_buffer) ||
+        !CompareBufferIndices(lhs->indices, ein_sum_.output_indices)) {
+      fail_ = true;
+      return;
+    }
+    VisitExpr(rest);
+  }
+
+  void VisitExpr(const PrimExpr& n) final {
+    if (n->IsInstance<BufferLoadNode>() || n->IsInstance<MulNode>() || n->IsInstance<CastNode>()) {
+      ExprVisitor::VisitExpr(n);
+    } else {
+      fail_ = true;
+      return;
+    }
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
+    if (auto it = ein_sum_.input_indices.find(op->buffer);
+        it != ein_sum_.input_indices.end() && !CompareBufferIndices(op->indices, (*it).second)) {
+      fail_ = true;
+      return;
+    }
+    if (Optional<Array<Var>> opt_indices = CheckTrivialBufferIndices(op); opt_indices.defined()) {
+      ein_sum_.input_indices.Set(op->buffer, std::move(opt_indices.value()));
+    } else {
+      fail_ = true;
+      return;
+    }
+  }
+
+  void VisitExpr_(const CastNode* op) final { VisitExpr(op->value); }
+
+  bool Fail() { return fail_; }
+
+  bool CompareBufferIndices(const Array<PrimExpr>& indices, const Array<Var>& other) {
+    return std::equal(indices.begin(), indices.end(), other.begin(), other.end(),
+                      [](const PrimExpr& a, const Var& b) { return a.same_as(b); });
+  }
+
+  Einsum ein_sum_;
+  bool fail_{false};
+};
+
+Einsum ExtractEinsum(const ScheduleState& self, const Block& block) {
+  EinsumExtractor extractor;
+  std::optional<Einsum> einsum = extractor.Extract(block);
+  if (!einsum.has_value()) {
+    throw NonEinsumError(self->mod, block);
+  }
+  return einsum.value();
+}
+
+class BufferNotAllocatedInScopeError : public ScheduleError {
+ public:
+  explicit BufferNotAllocatedInScopeError(IRModule mod, Buffer buffer)
+      : mod_(std::move(mod)), buffer_(std::move(buffer)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: The buffer is not allocated as an intermediate buffer in current "
+           "PrimFunc.";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The buffer " << buffer_->name
+       << " is not allocated as an intermediate buffer in current PrimFunc.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+
+ private:
+  IRModule mod_;
+  Buffer buffer_;
+};
+
+class PadEinsumRewriter : public ReplaceBufferMutator {
+ public:
+  PadEinsumRewriter(const std::unordered_map<const BlockNode*, PrimExpr> producer_predicate,
+                    Map<Var, PrimExpr> padded_iter_extents, const Map<Buffer, Buffer>& buffer_remap,
+                    Map<Block, Block>* block_sref_reuse, arith::Analyzer* analyzer)
+      : ReplaceBufferMutator(buffer_remap, block_sref_reuse),
+        producer_predicate_(producer_predicate),
+        padded_iter_extents_(padded_iter_extents),
+        analyzer_(analyzer) {}
+
+  Stmt VisitStmt_(const ForNode* op) final {
+    For new_for = Downcast<For>(ReplaceBufferMutator::VisitStmt_(op));
+    if (padded_iter_extents_.count(new_for->loop_var)) {
+      new_for.CopyOnWrite()->extent = padded_iter_extents_.at(new_for->loop_var);
+    }
+    return std::move(new_for);
+  }
+
+  Block PadProducerBlock(Block block, const PrimExpr& predicate) {
+    BufferStore store = Downcast<BufferStore>(block->body);
+    store.CopyOnWrite()->value =
+        analyzer_->Simplify(if_then_else(predicate, store->value, make_zero(store->value.dtype())));
+    block.CopyOnWrite()->body = std::move(store);
+    return block;
+  }
+
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Block old_block = GetRef<Block>(op);
+    Block new_block = Downcast<Block>(ReplaceBufferMutator::VisitStmt_(op));
+    if (auto it = producer_predicate_.find(op); it != producer_predicate_.end()) {
+      new_block = PadProducerBlock(std::move(new_block), (*it).second);
+    }
+
+    // Mutate block iters
+    Array<IterVar> new_iters;
+    bool changed = false;
+    for (const IterVar& iter : new_block->iter_vars) {
+      if (auto it = padded_iter_extents_.find(iter->var); it != padded_iter_extents_.end()) {
+        changed = true;
+        new_iters.push_back(
+            IterVar(Range::FromMinExtent(0, (*it).second), iter->var, iter->iter_type));
+      } else {
+        new_iters.push_back(iter);
+      }
+    }
+    if (changed) {
+      new_block.CopyOnWrite()->iter_vars = std::move(new_iters);
+    }
+    if (!old_block.same_as(new_block)) {
+      block_sref_reuse_->Set(old_block, new_block);
+    }
+    return std::move(new_block);
+  }
+
+ private:
+  const std::unordered_set<const BlockNode*> producer_blocks_;
+  const std::unordered_map<const BlockNode*, PrimExpr> producer_predicate_;
+  const Map<Var, PrimExpr> padded_iter_extents_;
+  arith::Analyzer* analyzer_;
+};
+
+/*! \brief The schedule error class when the producer block cannot be padded. */
+class InvalidProducerError : public ScheduleError {
+ public:
+  explicit InvalidProducerError(IRModule mod, Block producer)
+      : mod_(std::move(mod)), producer_(std::move(producer)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: The producer block cannot be padded.";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The producer block {0} cannot be padded. It should write to a single buffer and the "
+          "body should be a BufferStore.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {producer_}; }
+
+ private:
+  IRModule mod_;
+  Buffer buffer_;
+  Block producer_;
+};
+
+void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Array<Integer>& padding) {
+  arith::Analyzer analyzer;
+  // Step 1: Input checking and error handling
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+  BlockRealize realize = GetBlockRealize(self, block_sref);
+
+  const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
+  InvalidPaddingError::Check(self, GetRef<Block>(block), padding);
+
+  const Array<StmtSRef> producers = GetProducers(self, block_sref);
+  {
+    auto f_check_block_properties = [&](const StmtSRef& block_sref, bool is_producer) {
+      CheckBlockHasTrivialBinding(self, block_sref);
+      if (is_producer) {
+        CheckCompleteBlock(self, block_sref, scope_sref);
+      } else {
+        CheckReductionBlock(self, block_sref, scope_sref);
+      }
+      Array<StmtSRef> loops = GetLoops(block_sref);
+      ICHECK(!loops.empty());
+      CheckGetSingleChildBlockRealizeOnSRefTree(self, loops.front());
+    };
+
+    // Check block properties of the computation block
+    f_check_block_properties(block_sref, false);
+
+    // Check block properties of the producer block
+    for (const StmtSRef& producer_sref : producers) {
+      f_check_block_properties(producer_sref, true);
+    }
+  }
+
+  Einsum einsum = ExtractEinsum(self, GetRef<Block>(block));
+
+  // Check input and output buffers are all allocated in the current scope.
+  {
+    auto f_check_buffer_allocated = [&](const Buffer& buffer) {
+      auto [defining_site_sref, is_allocate] = GetBufferDefiningSite(block_sref, buffer);
+      if (!defining_site_sref.defined() || !is_allocate) {
+        throw BufferNotAllocatedInScopeError(self->mod, buffer);
+      }
+    };
+    f_check_buffer_allocated(einsum.output_buffer);
+    for (const auto& buffer_indices_pair : einsum.input_indices) {
+      f_check_buffer_allocated(buffer_indices_pair.first);
+    }
+  }
+
+  // Step 2: Prepare buffer and variable remapping. Infer the new shape of the input and the output
+  // buffers. Infer the new extent of the block iters of the computation block and the producer
+  // block.
+
+  Map<Var, PrimExpr> padded_iter_extents;  // The new extents of both the block iters and loop vars
+
+  // Convert the input padding array to a map from variables to the padded extents
+  for (int i = 0, n = padding.size(); i < n; ++i) {
+    const IterVar& iter = block->iter_vars[i];
+    PrimExpr new_extent =
+        IntImm(iter->var->dtype, Downcast<Integer>(iter->dom->extent)->value + padding[i]->value);
+    padded_iter_extents.Set(iter->var, new_extent);
+    padded_iter_extents.Set(Downcast<Var>(realize->iter_values[i]), new_extent);
+  }
+
+  Map<Buffer, Buffer> buffer_remap;  // mapping from buffers to new buffers with padded shapes
+
+  // Utility function to pad a buffer with the new shape
+  auto f_pad_buffer = [&padded_iter_extents, &buffer_remap](Buffer buffer,
+                                                            const Array<Var>& indices) -> Buffer {
+    Array<PrimExpr> new_shape;
+    for (const Var& index : indices) {
+      new_shape.push_back(padded_iter_extents.at(index));
+    }
+    ICHECK_EQ(buffer->shape.size(), new_shape.size());
+    buffer.CopyOnWrite()->shape = std::move(new_shape);
+    return buffer;
+  };
+
+  buffer_remap.Set(einsum.output_buffer, f_pad_buffer(einsum.output_buffer, einsum.output_indices));
+
+  std::unordered_map<const BlockNode*, PrimExpr> producer_predicate;
+
+  // Unlike the output block, the padding for a producer block is not directly specified as an
+  // input argument. Instead, it is inferred from the indices of the producer buffer accessed in
+  // the output block: we find the indices (which are block iters) of the BufferStore to the
+  // producer buffer and infer the new extents of the block iters and the corresponding loop vars.
+  for (const StmtSRef& producer_sref : producers) {
+    const BlockNode* producer_block = TVM_SREF_TO_BLOCK(producer_sref);
+    const BufferStoreNode* buffer_store = producer_block->body.as<BufferStoreNode>();
+    Optional<Array<Var>> producer_store_indices;
+    if (!buffer_store || producer_block->writes.size() != 1 ||
+        !(producer_store_indices = CheckTrivialBufferIndices(buffer_store)).defined()) {
+      throw InvalidProducerError(self->mod, GetRef<Block>(producer_block));
+    }
+    BlockRealize producer_realize = GetBlockRealize(self, producer_sref);
+
+    const Buffer& old_buffer = producer_block->writes[0]->buffer;
+    Buffer new_buffer = f_pad_buffer(old_buffer, einsum.input_indices.at(old_buffer));
+    buffer_remap.Set(old_buffer, new_buffer);
+
+    // The predicate restricting the producer block to the original (pre-padding) bounds
+    PrimExpr predicate = Bool(true);
+    Map<Var, PrimExpr> indices_to_padded_extents;  // buffer indices to padded extents
+    for (int i = 0, n = producer_store_indices.value().size(); i < n; ++i) {
+      const Var& index = producer_store_indices.value()[i];
+      PrimExpr padded_extent = new_buffer->shape[i];
+      if (!analyzer.CanProveEqual(padded_extent, old_buffer->shape[i])) {
+        predicate = predicate && (index < old_buffer->shape[i]);
+      }
+      indices_to_padded_extents.Set(index, padded_extent);
+    }
+
+    for (int i = 0, n = producer_block->iter_vars.size(); i < n; ++i) {
+      const IterVar& iter = producer_block->iter_vars[i];
+      if (auto it = indices_to_padded_extents.find(iter->var);
+          it != indices_to_padded_extents.end()) {
+        const PrimExpr& padded_extent = (*it).second;
+        padded_iter_extents.Set(iter->var, padded_extent);
+        padded_iter_extents.Set(Downcast<Var>(producer_realize->iter_values[i]), padded_extent);
+      } else if (!is_one(iter->dom->extent)) {
+        throw InvalidProducerError(self->mod, GetRef<Block>(producer_block));
+      }
+    }
+    producer_predicate[producer_block] = predicate;
+  }
+
+  // Step 3: Mutate the AST subtree with the new buffers and the new block iter extents.
+  Map<Block, Block> block_sref_reuse;
+  PadEinsumRewriter rewriter(producer_predicate, padded_iter_extents, buffer_remap,
+                             &block_sref_reuse, &analyzer);
+  const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref);
+  Stmt new_scope_block = rewriter(GetRef<Block>(scope_block));
+
+  // Step 4: Do the actual replacement.
+  self->Replace(scope_sref, new_scope_block, block_sref_reuse);
+}
+
+/******** Instruction Registration ********/
+
+struct PadEinsumTraits : public UnpackedInstTraits<PadEinsumTraits> {
+  static constexpr const char* kName = "PadEinsum";
+  static constexpr bool kIsPure = false;
+
+ private:
+  static constexpr size_t kNumInputs = 1;
+  static constexpr size_t kNumAttrs = 1;
+  static constexpr size_t kNumDecisions = 0;
+
+  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, Array<Integer> padding) {
+    sch->PadEinsum(block, padding);
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String block, Array<Integer> padding) {
+    PythonAPICall py("pad_einsum");
+    py.Input("block", block);
+    py.Input("padding", padding);
+    return py.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
+TVM_REGISTER_INST_KIND_TRAITS(PadEinsumTraits);
+
+}  // namespace tir
+}  // namespace tvm
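
To make Step 2 above concrete, here is a worked sketch of the shape and predicate inference for the
matmul example from the docstring (an illustration of the algorithm, not additional code in the
commit):

    # Einsum block iters of C_shared: (i: 128, j: 127, k: 127); padding = [0, 1, 1]
    # => padded iter extents: i -> 128, j -> 128, k -> 128
    #
    # Producer A_shared is read as A_shared[i, k] in the Einsum block, so its
    # padded shape is (extent(i), extent(k)) = (128, 128). Only dimension 1 grew
    # (127 -> 128), so the guard predicate, written in the producer's own iter
    # names, is `j < 127`, and its store is rewritten to:
    #     A_shared_padded[i, j] = T.if_then_else(j < 127, A[i, j], T.float32(0))
    # B_shared is read as B_shared[k, j]; both dimensions grow (127 -> 128), so
    # its predicate is `i < 127 and j < 127`.
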
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 091db344aa..d72f67fb7c 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -264,7 +264,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetAxisSeparator")
 /******** (FFI) Padding decomposition ********/
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposePadding")
     .set_body_method<Schedule>(&ScheduleNode::DecomposePadding);
-
+TVM_REGISTER_GLOBAL("tir.schedule.SchedulePadEinsum")
+    .set_body_method<Schedule>(&ScheduleNode::PadEinsum);
 /******** (FFI) Misc ********/
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc")
     .set_body_method<Schedule>(&ScheduleNode::EnterPostproc);
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 04ddc0507d..a31950d331 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -520,7 +520,7 @@ void TracedScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_in
       /*outputs=*/{}));
 }
 
-/******** Schedule: Padding decomposition ********/
+/******** Schedule: Padding ********/
 BlockRV TracedScheduleNode::DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) {
   BlockRV new_block = ConcreteScheduleNode::DecomposePadding(block_rv, loop_rv);
   static const InstructionKind& kind = InstructionKind::Get("DecomposePadding");
@@ -532,6 +532,16 @@ BlockRV TracedScheduleNode::DecomposePadding(const BlockRV& block_rv, const Loop
   return new_block;
 }
 
+void TracedScheduleNode::PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) {
+  ConcreteScheduleNode::PadEinsum(block_rv, padding);
+  static const InstructionKind& kind = InstructionKind::Get("PadEinsum");
+  trace_->Append(/*inst=*/Instruction(
+      /*kind=*/kind,
+      /*inputs=*/{block_rv},
+      /*attrs=*/{padding},
+      /*outputs=*/{}));
+}
+
 /******** Schedule: Misc ********/
 
 void TracedScheduleNode::EnterPostproc() {
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index d98e4ba4bb..ad44cc6ae5 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -108,8 +108,9 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
                         BufferIndexType buffer_index_type,
                         const Array<IntImm>& axis_separators) final;
-  /******** Schedule: Padding decomposition ********/
+  /******** Schedule: Padding ********/
   BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) final;
+  void PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) final;
   /******** Schedule: Misc ********/
   void EnterPostproc() final;
 };
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index c11fa656d6..dfbd3dbcbc 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -103,6 +103,14 @@ ReplaceBufferMutator::ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_
   buffer_var_map_[old_buffer->data.get()] = std::move(new_buffer);
 }
 
+ReplaceBufferMutator::ReplaceBufferMutator(const Map<Buffer, Buffer>& buffer_map,
+                                           Map<Block, Block>* block_sref_reuse)
+    : block_sref_reuse_(block_sref_reuse) {
+  for (const auto& [old_buffer, new_buffer] : buffer_map) {
+    buffer_var_map_[old_buffer->data.get()] = new_buffer;
+  }
+}
+
 PrimExpr ReplaceBufferMutator::VisitExpr_(const VarNode* var) {
   auto it = buffer_var_map_.find(var);
   return it != buffer_var_map_.end() ? it->second->data : GetRef<Var>(var);
diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h
index 908a823c2d..4de3685e24 100644
--- a/src/tir/schedule/transform.h
+++ b/src/tir/schedule/transform.h
@@ -114,7 +114,12 @@ class ReplaceBufferMutator : public StmtExprMutator {
   ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer,
                        Map<Block, Block>* block_sref_reuse);
 
+  ReplaceBufferMutator(const Map<Buffer, Buffer>& buffer_map, Map<Block, Block>* block_sref_reuse);
+
  protected:
+  using StmtExprMutator::VisitExpr_;
+  using StmtExprMutator::VisitStmt_;
+
   PrimExpr VisitExpr_(const VarNode* var) final;
 
   template <typename Node>
@@ -132,7 +137,7 @@ class ReplaceBufferMutator : public StmtExprMutator {
 
   virtual MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer);
 
-  Stmt VisitStmt_(const BlockNode* block) final;
+  Stmt VisitStmt_(const BlockNode* block) override;
 
   /*!
    * \brief A mapping which maps old buffer vars to new buffers, including the buffers defined in
diff --git a/tests/python/unittest/test_tir_schedule_pad_einsum.py b/tests/python/unittest/test_tir_schedule_pad_einsum.py
new file mode 100644
index 0000000000..89628db4ff
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_pad_einsum.py
@@ -0,0 +1,122 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir, te
+from tvm.script import tir as T
+from tvm.tir.schedule.schedule import ScheduleError
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+from tvm.meta_schedule.testing import te_workload
+
+# pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
+
+
+@T.prim_func
+def matmul_before(
+    A: T.Buffer[(128, 127), "float32"],
+    B: T.Buffer[(127, 127), "float32"],
+    C: T.Buffer[(128, 127), "float32"],
+) -> None:
+    A_shared = T.alloc_buffer((128, 127), "float32", scope="shared")
+    B_shared = T.alloc_buffer((127, 127), "float32", scope="shared")
+    C_shared = T.alloc_buffer((128, 127), "float32", scope="shared")
+    for i0, i1 in T.grid(128, 127):
+        with T.block("A"):
+            i, j = T.axis.remap("SS", [i0, i1])
+            A_shared[i, j] = A[i, j]
+    for i0, i1 in T.grid(127, 127):
+        with T.block("B"):
+            i, j = T.axis.remap("SS", [i0, i1])
+            B_shared[i, j] = B[i, j]
+    for i0, i1, i2 in T.grid(128, 127, 127):
+        with T.block("C_shared"):
+            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+            with T.init():
+                C_shared[i, j] = T.float32(0)
+            C_shared[i, j] = C_shared[i, j] + A_shared[i, k] * B_shared[k, j]
+    for i0, i1 in T.grid(128, 127):
+        with T.block("C"):
+            i, j = T.axis.remap("SS", [i0, i1])
+            C[i, j] = C_shared[i, j]
+
+
+@T.prim_func
+def matmul_expected(
+    A: T.Buffer[(128, 127), "float32"],
+    B: T.Buffer[(127, 127), "float32"],
+    C: T.Buffer[(128, 127), "float32"],
+) -> None:
+    A_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
+    B_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
+    C_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
+    for i0, i1 in T.grid(128, 128):
+        with T.block("A"):
+            i, j = T.axis.remap("SS", [i0, i1])
+            T.reads(A[i, j])
+            T.writes(A_shared_padded[i, j])
+            A_shared_padded[i, j] = T.if_then_else(j < 127, A[i, j], T.float32(0), dtype="float32")
+    for i0, i1 in T.grid(128, 128):
+        with T.block("B"):
+            i, j = T.axis.remap("SS", [i0, i1])
+            T.reads(B[i, j])
+            T.writes(B_shared_padded[i, j])
+            B_shared_padded[i, j] = T.if_then_else(
+                i < 127 and j < 127, B[i, j], T.float32(0), dtype="float32"
+            )
+    for i0, i1, i2 in T.grid(128, 128, 128):
+        with T.block("C_shared"):
+            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+            T.reads(A_shared_padded[i, k], B_shared_padded[k, j])
+            T.writes(C_shared_padded[i, j])
+            with T.init():
+                C_shared_padded[i, j] = T.float32(0)
+            C_shared_padded[i, j] = (
+                C_shared_padded[i, j] + A_shared_padded[i, k] * B_shared_padded[k, j]
+            )
+    for i0, i1 in T.grid(128, 127):
+        with T.block("C"):
+            i, j = T.axis.remap("SS", [i0, i1])
+            T.reads(C_shared_padded[i, j])
+            T.writes(C[i, j])
+            C[i, j] = C_shared_padded[i, j]
+
+
+# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
+
+
+def test_pad_matmul():
+    sch = tir.Schedule(matmul_before, debug_mask="all")
+    C = sch.get_block("C_shared")
+    sch.pad_einsum(C, [0, 1, 1])
+    tvm.ir.assert_structural_equal(matmul_expected, sch.mod["main"])
+    verify_trace_roundtrip(sch, mod=matmul_before)
+
+
+def test_pad_matmul_error_non_intermediate_buffer():
+    func = te.create_prim_func(te_workload.matmul(128, 127, 127))
+    sch = tir.Schedule(func, debug_mask="all")
+    C = sch.get_block("C")
+    with pytest.raises(ScheduleError):
+        sch.pad_einsum(C, [0, 1, 1])
+
+
+if __name__ == "__main__":
+    tvm.testing.main()