Posted to commits@tvm.apache.org by wu...@apache.org on 2022/09/15 20:15:20 UTC
[tvm] branch main updated: [TIR, Schedule] Add schedule primitive PadEinsum (#12750)
This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1f8b5dec29 [TIR, Schedule] Add schedule primitive PadEinsum (#12750)
1f8b5dec29 is described below
commit 1f8b5dec29e6e34b4cf5f092acf5b1d197a59d42
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Thu Sep 15 13:15:10 2022 -0700
[TIR, Schedule] Add schedule primitive PadEinsum (#12750)
* [TIR, Schedule] Add schedule primitive PadEinsum
Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
* lint
* [TIR] Fix producer indices check in PadEinsum
* address comments
* simplify lambda expr
* fix
Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
---
include/tvm/tir/schedule/schedule.h | 20 +
python/tvm/tir/schedule/schedule.py | 122 ++++++
src/tir/schedule/analysis.h | 27 ++
src/tir/schedule/analysis/analysis.cc | 29 ++
src/tir/schedule/concrete_schedule.cc | 6 +
src/tir/schedule/concrete_schedule.h | 1 +
src/tir/schedule/primitive.h | 11 +-
.../schedule/primitive/layout_transformation.cc | 36 +-
src/tir/schedule/primitive/pad_einsum.cc | 474 +++++++++++++++++++++
src/tir/schedule/schedule.cc | 3 +-
src/tir/schedule/traced_schedule.cc | 12 +-
src/tir/schedule/traced_schedule.h | 3 +-
src/tir/schedule/transform.cc | 8 +
src/tir/schedule/transform.h | 7 +-
.../unittest/test_tir_schedule_pad_einsum.py | 122 ++++++
15 files changed, 841 insertions(+), 40 deletions(-)
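Before the diff itself: the user-facing entry point added by this commit is `Schedule.pad_einsum`. As a quick orientation, a minimal usage sketch (assuming the `matmul_before` TVMScript workload defined in the new unit test at the end of this diff):

    import tvm
    from tvm import tir

    # `matmul_before` is the 128x127x127 matmul from the new test file below,
    # staged through shared-memory buffers A_shared/B_shared/C_shared.
    sch = tir.Schedule(matmul_before, debug_mask="all")
    block = sch.get_block("C_shared")
    # Pad the block iters (i, j, k) by (0, 1, 1) so every extent becomes 128.
    sch.pad_einsum(block, [0, 1, 1])
    print(sch.mod["main"].script())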
diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index da399ab976..8e5cd34d2e 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -627,6 +627,7 @@ class ScheduleNode : public runtime::Object {
BufferIndexType buffer_index_type,
const Array<IntImm>& axis_separators) = 0;
+ /******** Schedule: Padding ********/
/*!
* \brief Decompose a padding block into a block filling const pad values and a block
* writing in-bound values.
@@ -636,6 +637,25 @@ class ScheduleNode : public runtime::Object {
*/
virtual BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) = 0;
+ /*!
+ * \brief Pad the computation of Einsum.
+ * \param block_rv The block that matches the Einsum pattern.
+ * \param padding The padding for each block iter.
+ * \details This schedule primitive identifies the Einsum pattern in the block body and finds its
+ * producer blocks. It then pads the computation of the Einsum pattern and its producer blocks.
+ * The output buffer and the producer buffers are resized according to the padding size. It
+ * requires the output buffer and the producer buffers to be allocated inside the PrimFunc.
+ *
+ * The padding is a list of non-negative integers; each element corresponds to the padding of one
+ * block iter, in the order of the block iters. The block and its producer blocks should have
+ * trivial bindings, i.e. each block iter is bound to a single loop variable. After padding, each
+ * block iter extent and the corresponding outer loop are extended by the padding size.
+ *
+ * The sizes of the producer buffers are inferred from the padding size of the Einsum computation.
+ * The producer buffers are padded with the initial value of the corresponding reduction.
+ */
+ virtual void PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) = 0;
+
/******** Schedule: Misc ********/
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
virtual void EnterPostproc() = 0;
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index d1293371a0..fdc8717032 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2783,6 +2783,128 @@ class Schedule(Object):
"""Check whether the block match padding pattern and can be decomposed."""
return _ffi_api.CanDecomposePadding(self, block, loop) # type: ignore # pylint: disable=no-member
+ @type_checked
+ def pad_einsum(self, block: Union[BlockRV, str], padding: List[int]) -> None:
+ """Pad the computation of Einsum.
+
+ This schedule primitive identifies the Einsum pattern in the block body and finds its
+ producer blocks. It then pads the computation of the Einsum pattern and its producer blocks.
+ The output buffer and the producer buffers are resized according to the padding size. It
+ requires the output buffer and the producer buffers to be allocated inside the PrimFunc.
+
+ The padding is a list of non-negative integers; each element corresponds to the padding of
+ one block iter, in the order of the block iters. The block and its producer blocks should
+ have trivial bindings, i.e. each block iter is bound to a single loop variable. After
+ padding, each block iter extent and the corresponding outer loop are extended by the
+ padding size.
+
+ The sizes of the producer buffers are inferred from the padding size of the Einsum
+ computation. The producer buffers are padded with the initial value of the corresponding
+ reduction.
+
+ Parameters
+ ----------
+ block : Union[BlockRV, str]
+ The block that matches the Einsum pattern.
+
+ padding : List[int]
+ The padding for each block iter.
+
+ Examples
+ --------
+
+ Before applying pad-einsum, in TensorIR, the IR is:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def before_pad_einsum(
+ A: T.Buffer[(128, 127), "float32"],
+ B: T.Buffer[(127, 127), "float32"],
+ C: T.Buffer[(128, 127), "float32"],
+ ) -> None:
+ A_shared = T.alloc_buffer((128, 127), "float32", scope="shared")
+ B_shared = T.alloc_buffer((127, 127), "float32", scope="shared")
+ C_shared = T.alloc_buffer((128, 127), "float32", scope="shared")
+ for i0, i1 in T.grid(128, 127):
+ with T.block("A"):
+ i, j = T.axis.remap("SS", [i0, i1])
+ A_shared[i, j] = A[i, j]
+ for i0, i1 in T.grid(127, 127):
+ with T.block("B"):
+ i, j = T.axis.remap("SS", [i0, i1])
+ B_shared[i, j] = B[i, j]
+ for i0, i1, i2 in T.grid(128, 127, 127):
+ with T.block("C_shared"):
+ i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+ with T.init():
+ C_shared[i, j] = T.float32(0)
+ C_shared[i, j] = C_shared[i, j] + A_shared[i, k] * B_shared[k, j]
+ for i0, i1 in T.grid(128, 127):
+ with T.block("C"):
+ i, j = T.axis.remap("SS", [i0, i1])
+ C[i, j] = C_shared[i, j]
+
+ Create the schedule and apply pad-einsum to the specified block:
+
+ .. code-block:: python
+
+ sch = tir.Schedule(before_pad_einsum, debug_mask="all")
+ block = sch.get_block("C_shared")
+ sch.pad_einsum(block, [0, 1, 1])
+ print(sch.mod["main"].script())
+
+ After applying pad-einsum, the IR becomes:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def after_pad_einsum(
+ A: T.Buffer[(128, 127), "float32"],
+ B: T.Buffer[(127, 127), "float32"],
+ C: T.Buffer[(128, 127), "float32"],
+ ) -> None:
+ A_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
+ B_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
+ C_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
+ for i0, i1 in T.grid(128, 128):
+ with T.block("A"):
+ i, j = T.axis.remap("SS", [i0, i1])
+ T.reads(A[i, j])
+ T.writes(A_shared_padded[i, j])
+ A_shared_padded[i, j] = T.if_then_else(
+ j < 127, A[i, j], T.float32(0), dtype="float32"
+ )
+ for i0, i1 in T.grid(128, 128):
+ with T.block("B"):
+ i, j = T.axis.remap("SS", [i0, i1])
+ T.reads(B[i, j])
+ T.writes(B_shared_padded[i, j])
+ B_shared_padded[i, j] = T.if_then_else(
+ i < 127 and j < 127, B[i, j], T.float32(0), dtype="float32"
+ )
+ for i0, i1, i2 in T.grid(128, 128, 128):
+ with T.block("C_shared"):
+ i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+ T.reads(A_shared_padded[i, k], B_shared_padded[k, j])
+ T.writes(C_shared_padded[i, j])
+ with T.init():
+ C_shared_padded[i, j] = T.float32(0)
+ C_shared_padded[i, j] = (
+ C_shared_padded[i, j] + A_shared_padded[i, k] * B_shared_padded[k, j]
+ )
+ for i0, i1 in T.grid(128, 127):
+ with T.block("C"):
+ i, j = T.axis.remap("SS", [i0, i1])
+ T.reads(C_shared_padded[i, j])
+ T.writes(C[i, j])
+ C[i, j] = C_shared_padded[i, j]
+
+ """
+ block = self._normalize_block_arg(block)
+ return _ffi_api.SchedulePadEinsum( # type: ignore # pylint: disable=no-member
+ self, block, padding
+ )
+
########## Schedule: Misc ##########
@type_checked
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 489df8959d..ca45bcac6b 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -298,6 +298,15 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize,
void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref,
arith::Analyzer* analyzer);
+/*!
+ * \brief Check whether a block has a trivial binding, i.e. each block var is bound to an outer
+ * loop variable, from outer to inner.
+ * \param self The schedule state
+ * \param block_sref The block to be checked
+ * \throw ScheduleError If the block does not have trivial bindings
+ */
+void CheckBlockHasTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref);
+
/******** Block-loop relation ********/
/*!
@@ -697,6 +706,24 @@ Array<arith::IntSet> AnalyzeRegionLowerBound(const BufferRegion& region, const P
const StmtSRef& dom_high_exclusive,
arith::Analyzer* analyzer);
+/*!
+ * \brief Check if the buffer indices are all Vars and, if so, extract them
+ * \param buffer_access The BufferLoad or BufferStore
+ * \return The indices if the indices are all Vars, otherwise NullOpt
+ */
+template <typename T>
+Optional<Array<Var>> CheckTrivialBufferIndices(const T& buffer_access) {
+ Array<Var> indices;
+ for (const PrimExpr& index : buffer_access->indices) {
+ const VarNode* var = index.as<VarNode>();
+ if (var == nullptr) {
+ return NullOpt;
+ }
+ indices.push_back(GetRef<Var>(var));
+ }
+ return indices;
+}
+
/*! \brief Necessary information used for tensorization */
class TensorizeInfoNode : public Object {
public:
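For intuition: `CheckTrivialBufferIndices` succeeds only when every index of the access is a bare Var, e.g. `A[i, k]`, and returns NullOpt for composite indices such as `A[i * 2, k + 1]`. A rough Python analogue of the check (a hypothetical helper for illustration, not part of this patch):

    from typing import List, Optional
    from tvm import tir

    def trivial_buffer_indices(indices) -> Optional[List[tir.Var]]:
        # Return the indices only if each one is a plain block-iter variable.
        out = []
        for index in indices:
            if not isinstance(index, tir.Var):
                return None  # composite PrimExpr such as i * 2 or i + j
            out.append(index)
        return out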
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 7ed60876ab..4f78b0c9cd 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -652,6 +652,35 @@ void CheckAffineBinding(const ScheduleState& self, Block block) {
CheckPartialAffineBinding(self, std::move(block), NullOpt);
}
+void CheckBlockHasTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) {
+ class NotTrivialBindingError : public ScheduleError {
+ public:
+ explicit NotTrivialBindingError(IRModule mod, Block block)
+ : mod_(std::move(mod)), block_(std::move(block)) {}
+
+ String FastErrorString() const final {
+ return "ScheduleError: The binding values of the block are not variables of outer loops.";
+ }
+
+ String DetailRenderTemplate() const final {
+ std::ostringstream os;
+ os << "The binding values of the {0} are not variables of outer loops.";
+ return os.str();
+ }
+
+ IRModule mod() const final { return mod_; }
+ Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+
+ private:
+ IRModule mod_;
+ Block block_;
+ };
+
+ if (!IsTrivialBinding(self, block_sref)) {
+ throw NotTrivialBindingError(self->mod, GetRef<Block>(block_sref->StmtAs<BlockNode>()));
+ }
+}
+
Map<Var, Range> LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive,
const Optional<StmtSRef>& high_exclusive,
const runtime::StorageScope& extra_relax_scope) {
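To illustrate what `CheckBlockHasTrivialBinding` accepts: every block iter must be bound to exactly one enclosing loop var, in order from outer to inner. A TVMScript sketch of both cases (illustrative only; `A` and `B` are assumed 128x128 buffers):

    # Trivial binding: each block iter is bound to one loop var, outer to inner.
    for i0, i1 in T.grid(128, 128):
        with T.block("ok"):
            i, j = T.axis.remap("SS", [i0, i1])
            B[i, j] = A[i, j]

    # Non-trivial binding: the iter value `i0 * 2` is a composite expression,
    # so CheckBlockHasTrivialBinding raises a ScheduleError for this block.
    for i0 in T.serial(64):
        with T.block("bad"):
            i = T.axis.spatial(128, i0 * 2)
            B[i, 0] = A[i, 0]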
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index afc6757997..9d7dc6b95f 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -795,6 +795,12 @@ BlockRV ConcreteScheduleNode::DecomposePadding(const BlockRV& block_rv, const Lo
return CreateRV<BlockRV>(result);
}
+void ConcreteScheduleNode::PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) {
+ TVM_TIR_SCHEDULE_BEGIN();
+ tir::PadEinsum(state_, this->GetSRef(block_rv), padding);
+ TVM_TIR_SCHEDULE_END("pad-einsum", this->error_render_level_);
+ this->state_->DebugVerify();
+}
/******** Schedule: Misc ********/
} // namespace tir
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index e79d1d5288..1aa9dafcc9 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -128,6 +128,7 @@ class ConcreteScheduleNode : public ScheduleNode {
/******** Schedule: Reduction ********/
BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override;
BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) override;
+ void PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) override;
/******** Schedule: Block annotation ********/
void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
int offset) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 05d9e4cf94..97233fe4bc 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -490,7 +490,7 @@ TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int
TVM_DLL void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
const IndexMap& index_map);
-/******** Schedule: Padding decomposition ********/
+/******** Schedule: Padding ********/
/*!
* \brief Decompose a padding block into a block filling const pad values and a block
* writing in-bound values.
@@ -501,6 +501,15 @@ TVM_DLL void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref
TVM_DLL StmtSRef DecomposePadding(ScheduleState self, const StmtSRef& block_sref,
const StmtSRef& loop_sref);
+/*!
+ * \brief Pad the computation of Einsum.
+ * \param self The state of the schedule
+ * \param block_sref The block sref that matches the Einsum pattern.
+ * \param padding The padding for each block iter.
+ */
+TVM_DLL void PadEinsum(ScheduleState self, const StmtSRef& block_sref,
+ const Array<Integer>& padding);
+
/******** Schedule: Misc ********/
} // namespace tir
diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index 8e2643db01..32ed279f02 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -278,40 +278,6 @@ class IndexMapNotApplicableToBlockIterError : public ScheduleError {
IndexMap index_map_;
};
-class NotTrivialBindingError : public ScheduleError {
- public:
- explicit NotTrivialBindingError(IRModule mod, Block block)
- : mod_(std::move(mod)), block_(std::move(block)) {}
-
- static void CheckBlockHasTrivialBinding(const IRModule& mod, const BlockRealize& block_realize,
- std::unordered_set<const VarNode*> outer_loop_vars) {
- // Step 2: Check all the binding values are loops vars
- for (const PrimExpr& iter_value : block_realize->iter_values) {
- const VarNode* loop_var = iter_value.as<VarNode>();
- if (!loop_var || !outer_loop_vars.count(loop_var)) {
- throw NotTrivialBindingError(mod, block_realize->block);
- }
- }
- }
-
- String FastErrorString() const final {
- return "ScheduleError: The binding values of the block are not variables of outer loops.";
- }
-
- String DetailRenderTemplate() const final {
- std::ostringstream os;
- os << "The binding values of the {0} are not variables of outer loops.";
- return os.str();
- }
-
- IRModule mod() const final { return mod_; }
- Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
-
- private:
- IRModule mod_;
- Block block_;
-};
-
class OpaqueNewIterTypeError : public ScheduleError {
public:
explicit OpaqueNewIterTypeError(IRModule mod, Block block, PrimExpr iter_value)
@@ -363,7 +329,7 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
}
BlockRealize block_realize = GetBlockRealize(self, block_sref);
- NotTrivialBindingError::CheckBlockHasTrivialBinding(self->mod, block_realize, loop_vars);
+ CheckBlockHasTrivialBinding(self, block_sref);
// Step 3: Collect information of block iter vars
Array<PrimExpr> block_vars; // iter_var->var of each block iter
diff --git a/src/tir/schedule/primitive/pad_einsum.cc b/src/tir/schedule/primitive/pad_einsum.cc
new file mode 100644
index 0000000000..7a7b88d686
--- /dev/null
+++ b/src/tir/schedule/primitive/pad_einsum.cc
@@ -0,0 +1,474 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <optional>
+
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*! \brief The schedule error class when the padding size is invalid. */
+class InvalidPaddingError : public ScheduleError {
+ public:
+ InvalidPaddingError(IRModule mod, Block block, Array<Integer> padding)
+ : mod_(std::move(mod)), block_(std::move(block)), padding_(std::move(padding)) {}
+ IRModule mod() const final { return mod_; }
+ Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+ String FastErrorString() const final {
+ return "ScheduleError: The padding size for the block is invalid.";
+ }
+ String DetailRenderTemplate() const final {
+ std::ostringstream os;
+ os << "The padding for the block {0} are invalid. It should be a list of "
+ << block_->iter_vars.size() << " non-negative integers. Got " << padding_;
+ return os.str();
+ }
+
+ static void Check(const ScheduleState& self, const Block& block, Array<Integer> padding) {
+ if (padding.size() != block->iter_vars.size()) {
+ throw InvalidPaddingError(self->mod, block, padding);
+ }
+ for (const auto& pad : padding) {
+ if (pad->value < 0) {
+ throw InvalidPaddingError(self->mod, block, padding);
+ }
+ }
+ }
+
+ private:
+ IRModule mod_;
+ Block block_;
+ Array<Integer> padding_;
+};
+
+/*! \brief The schedule error class when the block body is not an Einsum pattern. */
+class NonEinsumError : public ScheduleError {
+ public:
+ explicit NonEinsumError(IRModule mod, Block block)
+ : mod_(std::move(mod)), block_(std::move(block)) {}
+
+ IRModule mod() const final { return mod_; }
+ Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+ String FastErrorString() const final {
+ return "ScheduleError: The block is not a computation of Einsum pattern.";
+ }
+ String DetailRenderTemplate() const final {
+ return "The block {0} not a computation of Einsum pattern.";
+ }
+
+ private:
+ IRModule mod_;
+ Block block_;
+};
+
+/*! \brief Data structure that represents an Einsum computation. */
+struct Einsum {
+ // The output buffer
+ Buffer output_buffer;
+ // The indices of the output buffer
+ Array<Var> output_indices;
+ // The indices of the input buffers
+ Map<Buffer, Array<Var>> input_indices;
+};
+
+class EinsumExtractor : public ExprVisitor {
+ public:
+ EinsumExtractor() = default;
+
+ std::optional<Einsum> Extract(const Block& block) {
+ const BufferStoreNode* update = block->body.as<BufferStoreNode>();
+ // Step 1: Check that the body is a BufferStore, that the block has an init statement, and
+ // that the BufferStore and the init statement write to the same output buffer indices.
+ if (update == nullptr || !block->init.defined()) {
+ return std::nullopt;
+ }
+
+ if (Optional<Array<Var>> opt_indices = CheckTrivialBufferIndices(update);
+ opt_indices.defined()) {
+ ein_sum_.output_indices = std::move(opt_indices.value());
+ } else {
+ return std::nullopt;
+ }
+ ein_sum_.output_buffer = update->buffer;
+
+ const BufferStoreNode* init = block->init.value().as<BufferStoreNode>();
+ ICHECK(init != nullptr);
+ if (!CompareBufferIndices(init->indices, ein_sum_.output_indices)) {
+ return std::nullopt;
+ }
+ // Step 2: Check that the BufferStore updates the output buffer and that the input buffer
+ // indices are all block iter variables.
+ CheckStoreValue(update->value);
+ if (fail_) {
+ return std::nullopt;
+ }
+ return std::move(ein_sum_);
+ }
+
+ private:
+ void CheckStoreValue(const PrimExpr& update) {
+ // Check the update part has the form:
+ // Output[output_indices] += Input_0[input_indices_0] op_0 Input_1[input_indices_1] op_1 ...
+ // where output_indices and input_indices_i are arrays whose elements are block iter
+ // variables rather than composite PrimExprs, and op_i are binary operations.
+
+ // Check the value is an Add and either the LHS or the RHS is a BufferLoad from the output buffer.
+ const AddNode* add = update.as<AddNode>();
+ if (add == nullptr) {
+ fail_ = true;
+ return;
+ }
+ const BufferLoadNode* lhs = add->a.as<BufferLoadNode>();
+ const BufferLoadNode* rhs = add->b.as<BufferLoadNode>();
+ if (lhs == nullptr && rhs != nullptr) {
+ std::swap(lhs, rhs);
+ }
+ if (lhs == nullptr || !lhs->buffer.same_as(ein_sum_.output_buffer) ||
+ !CompareBufferIndices(lhs->indices, ein_sum_.output_indices)) {
+ fail_ = true;
+ return;
+ }
+ VisitExpr(add->b);
+ }
+
+ void VisitExpr(const PrimExpr& n) final {
+ if (n->IsInstance<BufferLoadNode>() || n->IsInstance<MulNode>() || n->IsInstance<CastNode>()) {
+ ExprVisitor::VisitExpr(n);
+ } else {
+ fail_ = true;
+ return;
+ }
+ }
+
+ void VisitExpr_(const BufferLoadNode* op) final {
+ if (auto it = ein_sum_.input_indices.find(op->buffer);
+ it != ein_sum_.input_indices.end() && !CompareBufferIndices(op->indices, (*it).second)) {
+ fail_ = true;
+ return;
+ }
+ if (Optional<Array<Var>> opt_indices = CheckTrivialBufferIndices(op); opt_indices.defined()) {
+ ein_sum_.input_indices.Set(op->buffer, std::move(opt_indices.value()));
+ } else {
+ fail_ = true;
+ return;
+ }
+ }
+
+ void VisitExpr_(const CastNode* op) { VisitExpr(op->value); }
+
+ bool Fail() { return fail_; }
+
+ bool CompareBufferIndices(const Array<PrimExpr>& indices, const Array<Var>& other) {
+ return std::equal(indices.begin(), indices.end(), other.begin(), other.end(),
+ [](const PrimExpr& a, const Var& b) { return a.same_as(b); });
+ }
+
+ Einsum ein_sum_;
+ bool fail_{false};
+};
+
+Einsum ExtractEinsum(const ScheduleState& self, const Block& block) {
+ EinsumExtractor extractor;
+ std::optional<Einsum> einsum = extractor.Extract(block);
+ if (!einsum.has_value()) {
+ throw NonEinsumError(self->mod, block);
+ }
+ return einsum.value();
+}
+
+class BufferNotAllocatedInScopeError : public ScheduleError {
+ public:
+ explicit BufferNotAllocatedInScopeError(IRModule mod, Buffer buffer)
+ : mod_(std::move(mod)), buffer_(std::move(buffer)) {}
+
+ String FastErrorString() const final {
+ return "ScheduleError: The buffer is not allocated as an intermediate buffer in current "
+ "PrimFunc.";
+ }
+
+ String DetailRenderTemplate() const final {
+ std::ostringstream os;
+ os << "The buffer " << buffer_->name
+ << " is not allocated as an intermediate buffer in current PrimFunc.";
+ return os.str();
+ }
+
+ IRModule mod() const final { return mod_; }
+ Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+
+ private:
+ IRModule mod_;
+ Buffer buffer_;
+};
+
+class PadEinsumRewriter : public ReplaceBufferMutator {
+ public:
+ PadEinsumRewriter(const std::unordered_map<const BlockNode*, PrimExpr> producer_predicate,
+ Map<Var, PrimExpr> padded_iter_extents, const Map<Buffer, Buffer>& buffer_remap,
+ Map<Block, Block>* block_sref_reuse, arith::Analyzer* analyzer)
+ : ReplaceBufferMutator(buffer_remap, block_sref_reuse),
+ producer_predicate_(producer_predicate),
+ padded_iter_extents_(padded_iter_extents),
+ analyzer_(analyzer) {}
+
+ Stmt VisitStmt_(const ForNode* op) final {
+ For new_for = Downcast<For>(ReplaceBufferMutator::VisitStmt_(op));
+ if (padded_iter_extents_.count(new_for->loop_var)) {
+ new_for.CopyOnWrite()->extent = padded_iter_extents_.at(new_for->loop_var);
+ }
+ return std::move(new_for);
+ }
+
+ Block PadProducerBlock(Block block, const PrimExpr& predicate) {
+ BufferStore store = Downcast<BufferStore>(block->body);
+ store.CopyOnWrite()->value =
+ analyzer_->Simplify(if_then_else(predicate, store->value, make_zero(store->value.dtype())));
+ block.CopyOnWrite()->body = std::move(store);
+ return block;
+ }
+
+ Stmt VisitStmt_(const BlockNode* op) final {
+ Block old_block = GetRef<Block>(op);
+ Block new_block = Downcast<Block>(ReplaceBufferMutator::VisitStmt_(op));
+ if (auto it = producer_predicate_.find(op); it != producer_predicate_.end()) {
+ new_block = PadProducerBlock(std::move(new_block), (*it).second);
+ }
+
+ // Mutate block iters
+ Array<IterVar> new_iters;
+ bool changed = false;
+ for (const IterVar& iter : new_block->iter_vars) {
+ if (auto it = padded_iter_extents_.find(iter->var); it != padded_iter_extents_.end()) {
+ changed = true;
+ new_iters.push_back(
+ IterVar(Range::FromMinExtent(0, (*it).second), iter->var, iter->iter_type));
+ } else {
+ new_iters.push_back(iter);
+ }
+ }
+ if (changed) {
+ new_block.CopyOnWrite()->iter_vars = std::move(new_iters);
+ }
+ if (!old_block.same_as(new_block)) {
+ block_sref_reuse_->Set(old_block, new_block);
+ }
+ return std::move(new_block);
+ }
+
+ private:
+ const std::unordered_set<const BlockNode*> producer_blocks_;
+ const std::unordered_map<const BlockNode*, PrimExpr> producer_predicate_;
+ const Map<Var, PrimExpr> padded_iter_extents_;
+ arith::Analyzer* analyzer_;
+};
+
+/*! \brief The schedule error class when the producer block cannot be padded. */
+class InvalidProducerError : public ScheduleError {
+ public:
+ explicit InvalidProducerError(IRModule mod, Block producer)
+ : mod_(std::move(mod)), producer_(std::move(producer)) {}
+
+ String FastErrorString() const final {
+ return "ScheduleError: The producer block cannot be padded.";
+ }
+
+ String DetailRenderTemplate() const final {
+ std::ostringstream os;
+ os << "The producer block {0} cannot be padded. It should write to a single buffer and the "
+ "body should be a BufferStore.";
+ return os.str();
+ }
+
+ IRModule mod() const final { return mod_; }
+ Array<ObjectRef> LocationsOfInterest() const final { return {producer_}; }
+
+ private:
+ IRModule mod_;
+ Buffer buffer_;
+ Block producer_;
+};
+
+void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Array<Integer>& padding) {
+ arith::Analyzer analyzer;
+ // Step 1: Input checking and error handling
+ const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+ BlockRealize realize = GetBlockRealize(self, block_sref);
+
+ const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
+ InvalidPaddingError::Check(self, GetRef<Block>(block), padding);
+
+ const Array<StmtSRef> producers = GetProducers(self, block_sref);
+ {
+ auto f_check_block_properties = [&](const StmtSRef& block_sref, bool is_producer) {
+ CheckBlockHasTrivialBinding(self, block_sref);
+ if (is_producer) {
+ CheckCompleteBlock(self, block_sref, scope_sref);
+ } else {
+ CheckReductionBlock(self, block_sref, scope_sref);
+ }
+ Array<StmtSRef> loops = GetLoops(block_sref);
+ ICHECK(!loops.empty());
+ CheckGetSingleChildBlockRealizeOnSRefTree(self, loops.front());
+ };
+
+ // Check block properties of the computation block
+ f_check_block_properties(block_sref, false);
+
+ // Check block properties of the producer blocks
+ for (const StmtSRef& producer_sref : producers) {
+ f_check_block_properties(producer_sref, true);
+ }
+ }
+
+ Einsum einsum = ExtractEinsum(self, GetRef<Block>(block));
+
+ // Check input and output buffers are all allocated in the current scope.
+ {
+ auto f_check_buffer_allocated = [&](const Buffer& buffer) {
+ auto [defining_site_sref, is_allocate] = GetBufferDefiningSite(block_sref, buffer);
+ if (!defining_site_sref.defined() || !is_allocate) {
+ throw BufferNotAllocatedInScopeError(self->mod, buffer);
+ }
+ };
+ f_check_buffer_allocated(einsum.output_buffer);
+ for (const auto& buffer_indices_pair : einsum.input_indices) {
+ f_check_buffer_allocated(buffer_indices_pair.first);
+ }
+ }
+
+ // Step 2: Prepare buffer and variable remapping. Infer the new shapes of the input and output
+ // buffers, and the new extents of the block iters of the computation block and the producer
+ // blocks.
+
+ Map<Var, PrimExpr> padded_iter_extents; // The new extents of both the block iters and loop vars
+
+ // Convert the input padding array to a map from variables to the padded extents
+ for (int i = 0, n = padding.size(); i < n; ++i) {
+ const IterVar& iter = block->iter_vars[i];
+ PrimExpr new_extent =
+ IntImm(iter->var->dtype, Downcast<Integer>(iter->dom->extent)->value + padding[i]->value);
+ padded_iter_extents.Set(iter->var, new_extent);
+ padded_iter_extents.Set(Downcast<Var>(realize->iter_values[i]), new_extent);
+ }
+
+ Map<Buffer, Buffer> buffer_remap; // mapping from buffers to new buffers with padded shapes
+
+ // Utility function to pad a buffer with the new shape
+ auto f_pad_buffer = [&padded_iter_extents, &buffer_remap](Buffer buffer,
+ const Array<Var>& indices) -> Buffer {
+ Array<PrimExpr> new_shape;
+ for (const Var& index : indices) {
+ new_shape.push_back(padded_iter_extents.at(index));
+ }
+ ICHECK_EQ(buffer->shape.size(), new_shape.size());
+ buffer.CopyOnWrite()->shape = std::move(new_shape);
+ return buffer;
+ };
+
+ buffer_remap.Set(einsum.output_buffer, f_pad_buffer(einsum.output_buffer, einsum.output_indices));
+
+ std::unordered_map<const BlockNode*, PrimExpr> producer_predicate;
+
+ // Unlike the output block, the padding for a producer block is not directly specified as an
+ // input argument. Instead, it is inferred from the indices of the producer buffer accessed in
+ // the output block.
+ // We find the indices (which are block iters) of the BufferStore to the producer buffer and
+ // infer the new extents of the block iters and the corresponding loop vars.
+ for (const StmtSRef& producer_sref : producers) {
+ const BlockNode* producer_block = TVM_SREF_TO_BLOCK(producer_sref);
+ const BufferStoreNode* buffer_store = producer_block->body.as<BufferStoreNode>();
+ Optional<Array<Var>> producer_store_indices;
+ if (!buffer_store || producer_block->writes.size() != 1 ||
+ !(producer_store_indices = CheckTrivialBufferIndices(buffer_store)).defined()) {
+ throw InvalidProducerError(self->mod, GetRef<Block>(producer_block));
+ }
+ BlockRealize producer_realize = GetBlockRealize(self, producer_sref);
+
+ const Buffer& old_buffer = producer_block->writes[0]->buffer;
+ Buffer new_buffer = f_pad_buffer(old_buffer, einsum.input_indices.at(old_buffer));
+ buffer_remap.Set(old_buffer, new_buffer);
+
+ // The predicate ensuring the producer block writes within the original (pre-padding) bounds
+ PrimExpr predicate = Bool(true);
+ Map<Var, PrimExpr> indices_to_padded_extents; // buffer indices to padded extents
+ for (int i = 0, n = producer_store_indices.value().size(); i < n; ++i) {
+ const Var& index = producer_store_indices.value()[i];
+ PrimExpr padded_extent = new_buffer->shape[i];
+ if (!analyzer.CanProveEqual(padded_extent, old_buffer->shape[i])) {
+ predicate = predicate && (index < old_buffer->shape[i]);
+ }
+ indices_to_padded_extents.Set(index, padded_extent);
+ }
+
+ for (int i = 0, n = producer_block->iter_vars.size(); i < n; ++i) {
+ const IterVar& iter = producer_block->iter_vars[i];
+ if (auto it = indices_to_padded_extents.find(iter->var);
+ it != indices_to_padded_extents.end()) {
+ const PrimExpr& padded_extent = (*it).second;
+ padded_iter_extents.Set(iter->var, padded_extent);
+ padded_iter_extents.Set(Downcast<Var>(producer_realize->iter_values[i]), padded_extent);
+ } else if (!is_one(iter->dom->extent)) {
+ throw InvalidProducerError(self->mod, GetRef<Block>(producer_block));
+ }
+ }
+ producer_predicate[producer_block] = predicate;
+ }
+
+ // Step 3: Mutate the AST subtree with the new buffers and the new block iter extents.
+ Map<Block, Block> block_sref_reuse;
+ PadEinsumRewriter rewriter(producer_predicate, padded_iter_extents, buffer_remap,
+ &block_sref_reuse, &analyzer);
+ const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref);
+ Stmt new_scope_block = rewriter(GetRef<Block>(scope_block));
+
+ // Step 4: Do the actual replacement.
+ self->Replace(scope_sref, new_scope_block, block_sref_reuse);
+}
+
+/******** Instruction Registration ********/
+
+struct PadEinsumTraits : public UnpackedInstTraits<PadEinsumTraits> {
+ static constexpr const char* kName = "PadEinsum";
+ static constexpr bool kIsPure = false;
+
+ private:
+ static constexpr size_t kNumInputs = 1;
+ static constexpr size_t kNumAttrs = 1;
+ static constexpr size_t kNumDecisions = 0;
+
+ static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, Array<Integer> padding) {
+ sch->PadEinsum(block, padding);
+ }
+
+ static String UnpackedAsPython(Array<String> outputs, String block, Array<Integer> padding) {
+ PythonAPICall py("pad_einsum");
+ py.Input("block", block);
+ py.Input("padding", padding);
+ return py.Str();
+ }
+
+ template <typename>
+ friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
+TVM_REGISTER_INST_KIND_TRAITS(PadEinsumTraits);
+
+} // namespace tir
+} // namespace tvm
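To make the inference in Step 2 concrete for the matmul workload in the tests: with padding [0, 1, 1] on block iters (i, j, k) of extents (128, 127, 127), the padded extents and producer predicates work out as below (plain Python redoing the arithmetic, not TVM API calls):

    # Step 2a: padded extents of the Einsum block iters.
    extents = {"i": 128, "j": 127, "k": 127}
    padding = {"i": 0, "j": 1, "k": 1}
    padded = {v: extents[v] + padding[v] for v in extents}
    assert padded == {"i": 128, "j": 128, "k": 128}

    # Step 2b: A_shared is read as A_shared[i, k] in the Einsum block, so its
    # padded shape is (padded["i"], padded["k"]) == (128, 128). Only the second
    # axis actually grew (127 -> 128), so the producer's second store index
    # gets the predicate `index < 127` -- which appears as `j < 127` in the
    # rewritten A block -- and the out-of-bounds region is filled with the
    # reduction's init value, 0.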
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 091db344aa..d72f67fb7c 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -264,7 +264,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetAxisSeparator")
/******** (FFI) Padding decomposition ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposePadding")
.set_body_method<Schedule>(&ScheduleNode::DecomposePadding);
-
+TVM_REGISTER_GLOBAL("tir.schedule.SchedulePadEinsum")
+ .set_body_method<Schedule>(&ScheduleNode::PadEinsum);
/******** (FFI) Misc ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc")
.set_body_method<Schedule>(&ScheduleNode::EnterPostproc);
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 04ddc0507d..a31950d331 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -520,7 +520,7 @@ void TracedScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_in
/*outputs=*/{}));
}
-/******** Schedule: Padding decomposition ********/
+/******** Schedule: Padding ********/
BlockRV TracedScheduleNode::DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) {
BlockRV new_block = ConcreteScheduleNode::DecomposePadding(block_rv, loop_rv);
static const InstructionKind& kind = InstructionKind::Get("DecomposePadding");
@@ -532,6 +532,16 @@ BlockRV TracedScheduleNode::DecomposePadding(const BlockRV& block_rv, const Loop
return new_block;
}
+void TracedScheduleNode::PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) {
+ ConcreteScheduleNode::PadEinsum(block_rv, padding);
+ static const InstructionKind& kind = InstructionKind::Get("PadEinsum");
+ trace_->Append(/*inst=*/Instruction(
+ /*kind=*/kind,
+ /*inputs=*/{block_rv},
+ /*attrs=*/{padding},
+ /*outputs=*/{}));
+}
+
/******** Schedule: Misc ********/
void TracedScheduleNode::EnterPostproc() {
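Because the primitive is traced, serializing a schedule reproduces the call via `UnpackedAsPython` above. A sketch of what the recorded trace looks like (`b0` stands for whichever block RV the trace assigned):

    sch = tir.Schedule(matmul_before)  # tir.Schedule is traced by default
    sch.pad_einsum(sch.get_block("C_shared"), [0, 1, 1])
    print(sch.trace)
    # Expected to contain a line like:
    #   sch.pad_einsum(block=b0, padding=[0, 1, 1])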
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index d98e4ba4bb..ad44cc6ae5 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -108,8 +108,9 @@ class TracedScheduleNode : public ConcreteScheduleNode {
void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type,
const Array<IntImm>& axis_separators) final;
- /******** Schedule: Padding decomposition ********/
+ /******** Schedule: Padding ********/
BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) final;
+ void PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) final;
/******** Schedule: Misc ********/
void EnterPostproc() final;
};
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index c11fa656d6..dfbd3dbcbc 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -103,6 +103,14 @@ ReplaceBufferMutator::ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_
buffer_var_map_[old_buffer->data.get()] = std::move(new_buffer);
}
+ReplaceBufferMutator::ReplaceBufferMutator(const Map<Buffer, Buffer>& buffer_map,
+ Map<Block, Block>* block_sref_reuse)
+ : block_sref_reuse_(block_sref_reuse) {
+ for (const auto& [old_buffer, new_buffer] : buffer_map) {
+ buffer_var_map_[old_buffer->data.get()] = new_buffer;
+ }
+}
+
PrimExpr ReplaceBufferMutator::VisitExpr_(const VarNode* var) {
auto it = buffer_var_map_.find(var);
return it != buffer_var_map_.end() ? it->second->data : GetRef<Var>(var);
diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h
index 908a823c2d..4de3685e24 100644
--- a/src/tir/schedule/transform.h
+++ b/src/tir/schedule/transform.h
@@ -114,7 +114,12 @@ class ReplaceBufferMutator : public StmtExprMutator {
ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer,
Map<Block, Block>* block_sref_reuse);
+ ReplaceBufferMutator(const Map<Buffer, Buffer>& buffer_map, Map<Block, Block>* block_sref_reuse);
+
protected:
+ using StmtExprMutator::VisitExpr_;
+ using StmtExprMutator::VisitStmt_;
+
PrimExpr VisitExpr_(const VarNode* var) final;
template <typename Node>
@@ -132,7 +137,7 @@ class ReplaceBufferMutator : public StmtExprMutator {
virtual MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer);
- Stmt VisitStmt_(const BlockNode* block) final;
+ Stmt VisitStmt_(const BlockNode* block) override;
/*!
* \brief A mapping which maps old buffer vars to new buffers, including the buffers defined in
diff --git a/tests/python/unittest/test_tir_schedule_pad_einsum.py b/tests/python/unittest/test_tir_schedule_pad_einsum.py
new file mode 100644
index 0000000000..89628db4ff
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_pad_einsum.py
@@ -0,0 +1,122 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir, te
+from tvm.script import tir as T
+from tvm.tir.schedule.schedule import ScheduleError
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+from tvm.meta_schedule.testing import te_workload
+
+# pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
+
+
+@T.prim_func
+def matmul_before(
+ A: T.Buffer[(128, 127), "float32"],
+ B: T.Buffer[(127, 127), "float32"],
+ C: T.Buffer[(128, 127), "float32"],
+) -> None:
+ A_shared = T.alloc_buffer((128, 127), "float32", scope="shared")
+ B_shared = T.alloc_buffer((127, 127), "float32", scope="shared")
+ C_shared = T.alloc_buffer((128, 127), "float32", scope="shared")
+ for i0, i1 in T.grid(128, 127):
+ with T.block("A"):
+ i, j = T.axis.remap("SS", [i0, i1])
+ A_shared[i, j] = A[i, j]
+ for i0, i1 in T.grid(127, 127):
+ with T.block("B"):
+ i, j = T.axis.remap("SS", [i0, i1])
+ B_shared[i, j] = B[i, j]
+ for i0, i1, i2 in T.grid(128, 127, 127):
+ with T.block("C_shared"):
+ i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+ with T.init():
+ C_shared[i, j] = T.float32(0)
+ C_shared[i, j] = C_shared[i, j] + A_shared[i, k] * B_shared[k, j]
+ for i0, i1 in T.grid(128, 127):
+ with T.block("C"):
+ i, j = T.axis.remap("SS", [i0, i1])
+ C[i, j] = C_shared[i, j]
+
+
+@T.prim_func
+def matmul_expected(
+ A: T.Buffer[(128, 127), "float32"],
+ B: T.Buffer[(127, 127), "float32"],
+ C: T.Buffer[(128, 127), "float32"],
+) -> None:
+ A_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
+ B_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
+ C_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
+ for i0, i1 in T.grid(128, 128):
+ with T.block("A"):
+ i, j = T.axis.remap("SS", [i0, i1])
+ T.reads(A[i, j])
+ T.writes(A_shared_padded[i, j])
+ A_shared_padded[i, j] = T.if_then_else(j < 127, A[i, j], T.float32(0), dtype="float32")
+ for i0, i1 in T.grid(128, 128):
+ with T.block("B"):
+ i, j = T.axis.remap("SS", [i0, i1])
+ T.reads(B[i, j])
+ T.writes(B_shared_padded[i, j])
+ B_shared_padded[i, j] = T.if_then_else(
+ i < 127 and j < 127, B[i, j], T.float32(0), dtype="float32"
+ )
+ for i0, i1, i2 in T.grid(128, 128, 128):
+ with T.block("C_shared"):
+ i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+ T.reads(A_shared_padded[i, k], B_shared_padded[k, j])
+ T.writes(C_shared_padded[i, j])
+ with T.init():
+ C_shared_padded[i, j] = T.float32(0)
+ C_shared_padded[i, j] = (
+ C_shared_padded[i, j] + A_shared_padded[i, k] * B_shared_padded[k, j]
+ )
+ for i0, i1 in T.grid(128, 127):
+ with T.block("C"):
+ i, j = T.axis.remap("SS", [i0, i1])
+ T.reads(C_shared_padded[i, j])
+ T.writes(C[i, j])
+ C[i, j] = C_shared_padded[i, j]
+
+
+# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
+
+
+def test_pad_matmul():
+ sch = tir.Schedule(matmul_before, debug_mask="all")
+ C = sch.get_block("C_shared")
+ sch.pad_einsum(C, [0, 1, 1])
+ tvm.ir.assert_structural_equal(matmul_expected, sch.mod["main"])
+ verify_trace_roundtrip(sch, mod=matmul_before)
+
+
+def test_pad_matmul_error_non_intermediate_buffer():
+ func = te.create_prim_func(te_workload.matmul(128, 127, 127))
+ sch = tir.Schedule(func, debug_mask="all")
+ C = sch.get_block("C")
+ with pytest.raises(ScheduleError):
+ sch.pad_einsum(C, [0, 1, 1])
+
+
+if __name__ == "__main__":
+ tvm.testing.main()