Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/09/09 23:24:19 UTC

[GitHub] [tvm] vinx13 opened a new pull request, #12750: [TIR, Schedule] Add schedule primitive PadEinsum

vinx13 opened a new pull request, #12750:
URL: https://github.com/apache/tvm/pull/12750

   Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
   
   This PR adds a schedule primitive `PadEinsum`. It is intended specifically for computations in the Einsum pattern, which covers most cases for tensorization. Unlike the general padding approach in https://github.com/apache/tvm-rfcs/blob/main/rfcs/0077-layout-transform-padding.md, this primitive pads the output blocks and the input blocks at once, which eliminates the need for extra arithmetic analysis to guarantee program correctness.
   
   cc @Hzfengsy @wrongtest-intellif @spectrometerHBH @Lunderberg 
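Why padding the Einsum block and its producers together preserves correctness can be illustrated outside TVM. The following is a minimal, TVM-free sketch that assumes a matmul (`ik,kj->ij`) as the Einsum and plain `std::vector` buffers; `Matrix`, `MatMul`, `PadProducer`, and `PaddedEinsumMatches` are hypothetical names, not part of the PR. Because each padded producer writes zero wherever the padding predicate fails, the extra reduction terms add nothing, so the original region of the padded output matches the unpadded result.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major dense matrix stored as a flat vector (illustrative stand-in, not a TVM type).
using Matrix = std::vector<float>;

// Reference einsum "ik,kj->ij": C[i, j] = sum over r of A[i, r] * B[r, j].
Matrix MatMul(const Matrix& a, const Matrix& b, std::size_t m, std::size_t n, std::size_t k) {
  Matrix c(m * n, 0.0f);  // zero-fill plays the role of the block's init statement
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t j = 0; j < n; ++j)
      for (std::size_t r = 0; r < k; ++r) c[i * n + j] += a[i * k + r] * b[r * n + j];
  return c;
}

// Padded producer: copies src (rows x cols) where the predicate holds and writes zero in
// the padded region, mirroring if_then_else(predicate, value, 0).
Matrix PadProducer(const Matrix& src, std::size_t rows, std::size_t cols, std::size_t prows,
                   std::size_t pcols) {
  assert(prows >= rows && pcols >= cols);  // padding amounts must be non-negative
  Matrix dst(prows * pcols, 0.0f);
  for (std::size_t r = 0; r < prows; ++r)
    for (std::size_t c = 0; c < pcols; ++c)
      dst[r * pcols + c] = (r < rows && c < cols) ? src[r * cols + c] : 0.0f;
  return dst;
}

// Runs the einsum on padded extents (pm, pn, pk) and checks that the original m x n
// region of the padded output equals the unpadded result.
bool PaddedEinsumMatches(std::size_t m, std::size_t n, std::size_t k, std::size_t pm,
                         std::size_t pn, std::size_t pk) {
  Matrix a(m * k), b(k * n);
  for (std::size_t i = 0; i < a.size(); ++i) a[i] = static_cast<float>(static_cast<int>(i % 7) - 3);
  for (std::size_t i = 0; i < b.size(); ++i) b[i] = static_cast<float>(static_cast<int>(i % 5) - 2);
  Matrix expected = MatMul(a, b, m, n, k);
  Matrix padded = MatMul(PadProducer(a, m, k, pm, pk), PadProducer(b, k, n, pk, pn), pm, pn, pk);
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t j = 0; j < n; ++j)
      if (padded[i * pn + j] != expected[i * n + j]) return false;
  return true;
}
```

For example, `PaddedEinsumMatches(3, 5, 7, 4, 8, 8)` models padding odd extents up to tensor-core-friendly multiples.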


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] Lunderberg commented on a diff in pull request #12750: [TIR, Schedule] Add schedule primitive PadEinsum

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on code in PR #12750:
URL: https://github.com/apache/tvm/pull/12750#discussion_r970943198


##########
include/tvm/tir/schedule/schedule.h:
##########
@@ -636,6 +637,25 @@ class ScheduleNode : public runtime::Object {
    */
   virtual BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) = 0;
 
+  /*!
+   * \brief Pad the computation of Einsum.
+   * \param block_rv The block that matches the Einsum pattern.
+   * \param padding The padding for each block iter.
+   * \details This schedule primitives identifies the Einsum pattern in the block body, and find its
+   * producer blocks. It then pads the computation of the Einsum pattern and its producer blocks.
+   * The output buffer and the producer buffer is resized according to the padding size. It requires
+   * the output buffer and the producer buffer to be allocated inside the PrimFunc.
+   *
+   * The padding is a list of non-negative integers, each element corresponds to the padding for
+   * each block iter in the order of block iters. The block and it's producer blocks should have

Review Comment:
   Nitpick: "it's" should be "its", without an apostrophe.



##########
src/tir/schedule/primitive/pad_einsum.cc:
##########
@@ -0,0 +1,493 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <optional>
+
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*! \brief The schedule error class when the padding size is invalid. */
+class InvalidPaddingError : public ScheduleError {
+ public:
+  InvalidPaddingError(IRModule mod, Block block, Array<Integer> padding)
+      : mod_(std::move(mod)), block_(std::move(block)), padding_(std::move(padding)) {}
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+  String FastErrorString() const final {
+    return "ScheduleError: The padding size for the block is invalid.";
+  }
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The padding for the block {0} are invalid. It should be a list of "
+       << block_->iter_vars.size() << " non-negative integers. Got " << padding_;
+    return os.str();
+  }
+
+  static void Check(const ScheduleState& self, const Block& block, Array<Integer> padding) {
+    if (padding.size() != block->iter_vars.size()) {
+      throw InvalidPaddingError(self->mod, block, padding);
+    }
+    for (const auto& pad : padding) {
+      if (pad->value < 0) {
+        throw InvalidPaddingError(self->mod, block, padding);
+      }
+    }
+  }
+
+ private:
+  IRModule mod_;
+  Block block_;
+  Array<Integer> padding_;
+};
+
+/*! \brief The schedule error class when the block body is not an Einsum pattern. */
+class NonEinsumError : public ScheduleError {
+ public:
+  explicit NonEinsumError(IRModule mod, Block block)
+      : mod_(std::move(mod)), block_(std::move(block)) {}
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+  String FastErrorString() const final {
+    return "ScheduleError: The block is not a computation of Einsum pattern.";
+  }
+  String DetailRenderTemplate() const final {
+    return "The block {0} not a computation of Einsum pattern.";
+  }
+
+ private:
+  IRModule mod_;
+  Block block_;
+};
+
+/*! \brief Data structure that represents a Einsum computation. */
+struct Einsum {
+  // The output buffer
+  Buffer output_buffer;
+  // The indices of the output buffer
+  Array<Var> output_indices;
+  // The indices of the input buffers
+  Map<Buffer, Array<Var>> input_indices;
+};
+
+/*!

Review Comment:
   Could this function be exposed in `analysis.h` alongside `CheckBlockHasTrivialBinding`?  There are similar checks in https://github.com/apache/tvm/pull/12720 that could take advantage of a shared utility.



##########
include/tvm/tir/schedule/schedule.h:
##########
@@ -636,6 +637,25 @@ class ScheduleNode : public runtime::Object {
    */
   virtual BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) = 0;
 
+  /*!
+   * \brief Pad the computation of Einsum.
+   * \param block_rv The block that matches the Einsum pattern.
+   * \param padding The padding for each block iter.
+   * \details This schedule primitives identifies the Einsum pattern in the block body, and find its
+   * producer blocks. It then pads the computation of the Einsum pattern and its producer blocks.
+   * The output buffer and the producer buffer is resized according to the padding size. It requires
+   * the output buffer and the producer buffer to be allocated inside the PrimFunc.
+   *
+   * The padding is a list of non-negative integers, each element corresponds to the padding for

Review Comment:
   It looks like the padding can only be applied to the end of an axis/iterator, and cannot be applied to the beginning. Could we specify two arrays of padding, one for the lower end of each block iter and one for the upper end?
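To make the two-sided suggestion concrete, here is a minimal sketch on a single 1-D axis, assuming each block iter gets a `lower` and an `upper` amount. `AxisPadding`, `PaddedExtent`, `InOriginalRange`, and `PadAxis` are hypothetical helpers, not APIs from the PR. Upper-only padding, as in the quoted doc comment, is the special case `lower == 0`; the two-sided version additionally shifts reads by `lower` and adds a lower bound to the predicate.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical per-axis padding: `lower` elements before index 0, `upper` after the end.
struct AxisPadding {
  std::size_t lower;
  std::size_t upper;
};

// Extent of the padded axis.
std::size_t PaddedExtent(std::size_t extent, AxisPadding pad) {
  return pad.lower + extent + pad.upper;
}

// Predicate guarding reads of the original data: true only inside the original range.
bool InOriginalRange(std::size_t padded_index, std::size_t extent, AxisPadding pad) {
  return padded_index >= pad.lower && padded_index < pad.lower + extent;
}

// Padded producer for one axis: shifted copy inside the range, zero in both margins.
std::vector<int> PadAxis(const std::vector<int>& src, AxisPadding pad) {
  std::vector<int> dst(PaddedExtent(src.size(), pad), 0);
  for (std::size_t p = 0; p < dst.size(); ++p)
    if (InOriginalRange(p, src.size(), pad)) dst[p] = src[p - pad.lower];
  return dst;
}
```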



##########
src/tir/schedule/primitive/pad_einsum.cc:
##########
@@ -0,0 +1,493 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <optional>
+
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*! \brief The schedule error class when the padding size is invalid. */
+class InvalidPaddingError : public ScheduleError {
+ public:
+  InvalidPaddingError(IRModule mod, Block block, Array<Integer> padding)
+      : mod_(std::move(mod)), block_(std::move(block)), padding_(std::move(padding)) {}
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+  String FastErrorString() const final {
+    return "ScheduleError: The padding size for the block is invalid.";
+  }
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The padding for the block {0} are invalid. It should be a list of "
+       << block_->iter_vars.size() << " non-negative integers. Got " << padding_;
+    return os.str();
+  }
+
+  static void Check(const ScheduleState& self, const Block& block, Array<Integer> padding) {
+    if (padding.size() != block->iter_vars.size()) {
+      throw InvalidPaddingError(self->mod, block, padding);
+    }
+    for (const auto& pad : padding) {
+      if (pad->value < 0) {
+        throw InvalidPaddingError(self->mod, block, padding);
+      }
+    }
+  }
+
+ private:
+  IRModule mod_;
+  Block block_;
+  Array<Integer> padding_;
+};
+
+/*! \brief The schedule error class when the block body is not an Einsum pattern. */
+class NonEinsumError : public ScheduleError {
+ public:
+  explicit NonEinsumError(IRModule mod, Block block)
+      : mod_(std::move(mod)), block_(std::move(block)) {}
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+  String FastErrorString() const final {
+    return "ScheduleError: The block is not a computation of Einsum pattern.";
+  }
+  String DetailRenderTemplate() const final {
+    return "The block {0} not a computation of Einsum pattern.";
+  }
+
+ private:
+  IRModule mod_;
+  Block block_;
+};
+
+/*! \brief Data structure that represents a Einsum computation. */
+struct Einsum {
+  // The output buffer
+  Buffer output_buffer;
+  // The indices of the output buffer
+  Array<Var> output_indices;
+  // The indices of the input buffers
+  Map<Buffer, Array<Var>> input_indices;
+};
+
+/*!
+ * \brief Check if buffer indices are all Vars and extr
+ * \param buffer_access The BufferLoad or BufferStore
+ * \param[out] indices The optional array to store the indices
+ * \return Whether The indices are all Vars
+ */
+template <typename T>
+bool CheckTrivialBufferIndices(const T& buffer_access, Array<Var>* indices = nullptr) {
+  for (const PrimExpr& index : buffer_access->indices) {
+    const VarNode* var = index.as<VarNode>();
+    if (var == nullptr) {
+      return false;
+    }
+    if (indices != nullptr) {
+      indices->push_back(GetRef<Var>(var));
+    }
+  }
+  return true;
+}
+
+class EinsumExtractor : public ExprVisitor {
+ public:
+  EinsumExtractor() = default;
+
+  std::optional<Einsum> Extract(const Block& block) {
+    const BufferStoreNode* update = block->body.as<BufferStoreNode>();
+    // Step 1: Check the body is a BufferStore and the block has the init statement, and the
+    // BufferStore and the init statement store have the same output buffer indices.
+    if (update == nullptr || !block->init.defined()) {
+      return std::nullopt;
+    }
+
+    if (!CheckTrivialBufferIndices(update, &(ein_sum_.output_indices))) {
+      return std::nullopt;
+    }
+    ein_sum_.output_buffer = update->buffer;
+
+    const BufferStoreNode* init = block->init.value().as<BufferStoreNode>();
+    ICHECK(init != nullptr);
+    if (!CompareBufferIndices(init->indices, ein_sum_.output_indices)) {
+      return std::nullopt;
+    }
+    // Step 2: Check the BufferStore updates the output buffer and the input buffers indices are
+    // block iter variables.
+    CheckStoreValue(update->value);
+    if (fail_) {
+      return std::nullopt;
+    }
+    return std::move(ein_sum_);
+  }
+
+ private:
+  void CheckStoreValue(const PrimExpr& update) {
+    // Check the update part has the form:
+    //   Output[output_indices] += Input_0[input_indices_0] op_0 Input_1[input_indices_1] op_1 ...
+    // where output_indices and input_indices_i are the indices are arrays whose elements are the
+    // block iter variables instead of composite PrimExpr, and op_i are the binary operations.
+
+    // Check the value is Add and eithe LHS or RHS is the BufferLoad from the output buffer.
+    const AddNode* add = update.as<AddNode>();
+    if (add == nullptr) {
+      fail_ = true;
+      return;
+    }
+    const BufferLoadNode* lhs = add->a.as<BufferLoadNode>();
+    const BufferLoadNode* rhs = add->b.as<BufferLoadNode>();
+    if (lhs == nullptr && rhs != nullptr) {
+      std::swap(lhs, rhs);
+    }
+    if (lhs == nullptr || !lhs->buffer.same_as(ein_sum_.output_buffer) ||
+        !CompareBufferIndices(lhs->indices, ein_sum_.output_indices)) {
+      fail_ = true;
+      return;
+    }
+    VisitExpr(add->b);
+  }
+
+  void VisitExpr(const PrimExpr& n) final {
+    if (n->IsInstance<BufferLoadNode>() || n->IsInstance<MulNode>() || n->IsInstance<CastNode>()) {
+      ExprVisitor::VisitExpr(n);
+    } else {
+      fail_ = true;
+      return;
+    }
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
+    if (auto it = ein_sum_.input_indices.find(op->buffer);
+        it != ein_sum_.input_indices.end() && !CompareBufferIndices(op->indices, (*it).second)) {
+      fail_ = true;
+      return;
+    }
+    Array<Var> indices;
+    if (!CheckTrivialBufferIndices(GetRef<BufferLoad>(op), &indices)) {
+      fail_ = true;
+      return;
+    }
+    ein_sum_.input_indices.Set(op->buffer, std::move(indices));
+  }
+
+  void VisitExpr_(const CastNode* op) { VisitExpr(op->value); }
+
+  bool Fail() { return fail_; }
+
+  bool CompareBufferIndices(const Array<PrimExpr>& indices, const Array<Var>& other) {
+    return std::equal(indices.begin(), indices.end(), other.begin(), other.end(),
+                      [](const PrimExpr& a, const Var& b) { return a.same_as(b); });
+  }
+
+  Einsum ein_sum_;
+  bool fail_{false};
+};
+
+Einsum ExtractEinsum(const ScheduleState& self, const Block& block) {
+  EinsumExtractor extractor;
+  std::optional<Einsum> einsum = extractor.Extract(block);
+  if (!einsum.has_value()) {
+    throw NonEinsumError(self->mod, block);
+  }
+  return einsum.value();
+}
+
+class BufferNotAllocatedInScopeError : public ScheduleError {
+ public:
+  explicit BufferNotAllocatedInScopeError(IRModule mod, Buffer buffer)
+      : mod_(std::move(mod)), buffer_(std::move(buffer)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: The buffer is not allocated as an intermediate buffer in current "
+           "PrimFunc.";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The buffer " << buffer_->name
+       << " is not allocated as an intermediate buffer in current PrimFunc.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+
+ private:
+  IRModule mod_;
+  Buffer buffer_;
+};
+
+class PadEinsumRewriter : public ReplaceBufferMutator {
+ public:
+  PadEinsumRewriter(const std::unordered_map<const BlockNode*, PrimExpr> producer_predicate,
+                    Map<Var, PrimExpr> padded_iter_extents, const Map<Buffer, Buffer>& buffer_remap,
+                    Map<Block, Block>* block_sref_reuse, arith::Analyzer* analyzer)
+      : ReplaceBufferMutator(buffer_remap, block_sref_reuse),
+        producer_predicate_(producer_predicate),
+        padded_iter_extents_(padded_iter_extents),
+        analyzer_(analyzer) {}
+
+  Stmt VisitStmt_(const ForNode* op) final {
+    For new_for = Downcast<For>(ReplaceBufferMutator::VisitStmt_(op));
+    if (padded_iter_extents_.count(new_for->loop_var)) {
+      new_for.CopyOnWrite()->extent = padded_iter_extents_.at(new_for->loop_var);
+    }
+    return std::move(new_for);
+  }
+
+  Block PadProducerBlock(Block block, const PrimExpr& predicate) {
+    BufferStore store = Downcast<BufferStore>(block->body);
+    store.CopyOnWrite()->value =
+        analyzer_->Simplify(if_then_else(predicate, store->value, make_zero(store->value.dtype())));
+    block.CopyOnWrite()->body = std::move(store);
+    return block;
+  }
+
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Block old_block = GetRef<Block>(op);
+    Block new_block = Downcast<Block>(ReplaceBufferMutator::VisitStmt_(op));
+    if (auto it = producer_predicate_.find(op); it != producer_predicate_.end()) {
+      new_block = PadProducerBlock(std::move(new_block), (*it).second);
+    }
+
+    // Mutate block iters
+    Array<IterVar> new_iters;
+    bool changed = false;
+    for (const IterVar& iter : new_block->iter_vars) {
+      if (auto it = padded_iter_extents_.find(iter->var); it != padded_iter_extents_.end()) {
+        changed = true;
+        new_iters.push_back(
+            IterVar(Range::FromMinExtent(0, (*it).second), iter->var, iter->iter_type));
+      } else {
+        new_iters.push_back(iter);
+      }
+    }
+    if (changed) {
+      new_block.CopyOnWrite()->iter_vars = std::move(new_iters);
+    }
+    if (!old_block.same_as(new_block)) {
+      block_sref_reuse_->Set(old_block, new_block);
+    }
+    return std::move(new_block);
+  }
+
+ private:
+  const std::unordered_set<const BlockNode*> producer_blocks_;
+  const std::unordered_map<const BlockNode*, PrimExpr> producer_predicate_;
+  const Map<Var, PrimExpr> padded_iter_extents_;
+  arith::Analyzer* analyzer_;
+};
+
+/*! \brief The schedule error class when the producer block cannot be padded. */
+class InvalidProducerError : public ScheduleError {
+ public:
+  explicit InvalidProducerError(IRModule mod, Block producer)
+      : mod_(std::move(mod)), producer_(std::move(producer)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: The producer block cannot be padded.";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The producer block {0} cannot be padded. It should write to a single buffer and the "
+          "body should be a BufferStore.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {producer_}; }
+
+ private:
+  IRModule mod_;
+  Buffer buffer_;
+  Block producer_;
+};
+
+void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Array<Integer>& padding) {
+  arith::Analyzer analyzer;
+  // Step 1: Input checking and error handling
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+  BlockRealize realize = GetBlockRealize(self, block_sref);
+
+  const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
+  InvalidPaddingError::Check(self, GetRef<Block>(block), padding);
+
+  // Check block properties of the computation block
+  CheckBlockHasTrivialBinding(self, block_sref);
+  CheckReductionBlock(self, block_sref, scope_sref);
+  Array loops = GetLoops(block_sref);
+  ICHECK(!loops.empty());
+  CheckGetSingleChildBlockRealizeOnSRefTree(self, loops.front());
+
+  // Check block properties of the producer block
+  const Array<StmtSRef> producers = GetProducers(self, block_sref);
+  {
+    auto f_check_producer = [&](const StmtSRef& producer_sref) {

Review Comment:
   Since this lambda function is only called once, unlike `f_check_buffer_allocated` further down, should its body be moved directly into the loop over `producers`?
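A minimal sketch of the suggested shape, using hypothetical stand-in types rather than TVM's `StmtSRef`/`Block`: the one-off check is written inline in the loop, while a named lambda remains worthwhile only when it is reused. The two conditions checked (a single written buffer and a `BufferStore` body) are taken from the `InvalidProducerError` message in the quoted diff; `ProducerInfo` and `CheckProducers` are illustrative names only.

```cpp
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for what the check needs to know about a producer block.
struct ProducerInfo {
  std::string name;
  int num_written_buffers;
  bool body_is_buffer_store;
};

// The producer check lives directly in the loop body because it is needed only here.
void CheckProducers(const std::vector<ProducerInfo>& producers) {
  for (const ProducerInfo& producer : producers) {
    if (producer.num_written_buffers != 1 || !producer.body_is_buffer_store) {
      throw std::invalid_argument("producer block " + producer.name + " cannot be padded");
    }
  }
}
```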
   



##########
src/tir/schedule/primitive/pad_einsum.cc:
##########
@@ -0,0 +1,493 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <optional>
+
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*! \brief The schedule error class when the padding size is invalid. */
+class InvalidPaddingError : public ScheduleError {
+ public:
+  InvalidPaddingError(IRModule mod, Block block, Array<Integer> padding)
+      : mod_(std::move(mod)), block_(std::move(block)), padding_(std::move(padding)) {}
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+  String FastErrorString() const final {
+    return "ScheduleError: The padding size for the block is invalid.";
+  }
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The padding for the block {0} are invalid. It should be a list of "
+       << block_->iter_vars.size() << " non-negative integers. Got " << padding_;
+    return os.str();
+  }
+
+  static void Check(const ScheduleState& self, const Block& block, Array<Integer> padding) {
+    if (padding.size() != block->iter_vars.size()) {
+      throw InvalidPaddingError(self->mod, block, padding);
+    }
+    for (const auto& pad : padding) {
+      if (pad->value < 0) {
+        throw InvalidPaddingError(self->mod, block, padding);
+      }
+    }
+  }
+
+ private:
+  IRModule mod_;
+  Block block_;
+  Array<Integer> padding_;
+};
+
+/*! \brief The schedule error class when the block body is not an Einsum pattern. */
+class NonEinsumError : public ScheduleError {
+ public:
+  explicit NonEinsumError(IRModule mod, Block block)
+      : mod_(std::move(mod)), block_(std::move(block)) {}
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+  String FastErrorString() const final {
+    return "ScheduleError: The block is not a computation of Einsum pattern.";
+  }
+  String DetailRenderTemplate() const final {
+    return "The block {0} not a computation of Einsum pattern.";
+  }
+
+ private:
+  IRModule mod_;
+  Block block_;
+};
+
+/*! \brief Data structure that represents a Einsum computation. */
+struct Einsum {
+  // The output buffer
+  Buffer output_buffer;
+  // The indices of the output buffer
+  Array<Var> output_indices;
+  // The indices of the input buffers
+  Map<Buffer, Array<Var>> input_indices;
+};
+
+/*!
+ * \brief Check if buffer indices are all Vars and extr
+ * \param buffer_access The BufferLoad or BufferStore
+ * \param[out] indices The optional array to store the indices
+ * \return Whether The indices are all Vars
+ */
+template <typename T>
+bool CheckTrivialBufferIndices(const T& buffer_access, Array<Var>* indices = nullptr) {

Review Comment:
   Instead of an output parameter, could we return either `Optional<Array<Var>>` or `std::optional<Array<Var>>`?  It looks like all the current usages use the extracted indices, so it wouldn't introduce an unnecessary allocation.
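The suggested signature change can be sketched with standard-library stand-ins, assuming a simplified index type in place of TVM's `PrimExpr`/`Var` (`IndexExpr` and `TrivialIndices` are hypothetical). Returning `std::optional` makes "not all indices are variables" an empty result instead of a `bool` plus an output parameter.

```cpp
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a buffer index: a plain variable or a composite expression.
struct IndexExpr {
  bool is_var;
  std::string text;  // variable name if is_var, otherwise the expression text
};

// Returns the variable names if every index is a plain variable, std::nullopt otherwise.
std::optional<std::vector<std::string>> TrivialIndices(const std::vector<IndexExpr>& indices) {
  std::vector<std::string> vars;
  vars.reserve(indices.size());
  for (const IndexExpr& index : indices) {
    if (!index.is_var) return std::nullopt;
    vars.push_back(index.text);
  }
  return vars;
}
```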





[GitHub] [tvm] wrongtest-intellif commented on a diff in pull request #12750: [TIR, Schedule] Add schedule primitive PadEinsum

Posted by GitBox <gi...@apache.org>.
wrongtest-intellif commented on code in PR #12750:
URL: https://github.com/apache/tvm/pull/12750#discussion_r967571891


##########
src/tir/schedule/primitive/pad_einsum.cc:
##########
@@ -0,0 +1,496 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <optional>
+
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*! \brief The schedule error class when the padding size is invalid. */
+class InvalidPaddingError : public ScheduleError {
+ public:
+  InvalidPaddingError(IRModule mod, Block block, Array<Integer> padding)
+      : mod_(std::move(mod)), block_(std::move(block)), padding_(std::move(padding)) {}
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+  String FastErrorString() const final {
+    return "ScheduleError: The padding size for the block is invalid.";
+  }
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The padding for the block {0} are invalid. It should be a list of "
+       << block_->iter_vars.size() << " non-negative integers. Got " << padding_;
+    return os.str();
+  }
+
+  static void Check(const ScheduleState& self, const Block& block, Array<Integer> padding) {
+    if (padding.size() != block->iter_vars.size()) {
+      throw InvalidPaddingError(self->mod, block, padding);
+    }
+    for (const auto& pad : padding) {
+      if (pad->value < 0) {
+        throw InvalidPaddingError(self->mod, block, padding);
+      }
+    }
+  }
+
+ private:
+  IRModule mod_;
+  Block block_;
+  Array<Integer> padding_;
+};
+
+/*! \brief The schedule error class when the block body is not an Einsum pattern. */
+class NonEinsumError : public ScheduleError {
+ public:
+  explicit NonEinsumError(IRModule mod, Block block)
+      : mod_(std::move(mod)), block_(std::move(block)) {}
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+  String FastErrorString() const final {
+    return "ScheduleError: The block is not a computation of Einsum pattern.";
+  }
+  String DetailRenderTemplate() const final {
+    return "The block {0} not a computation of Einsum pattern.";
+  }
+
+ private:
+  IRModule mod_;
+  Block block_;
+};
+
+/*! \brief Data structure that represents a Einsum computation. */
+struct Einsum {
+  // The output buffer
+  Buffer output_buffer;
+  // The indices of the output buffer
+  Array<Var> output_indices;
+  // The indices of the input buffers
+  Map<Buffer, Array<Var>> input_indices;
+};
+
+/*!
+ * \brief Check if buffer indices are all Vars and extr
+ * \param buffer_access The BufferLoad or BufferStore
+ * \param[out] indices The optional array to store the indices
+ * \return Whether The indices are all Vars
+ */
+template <typename T>
+bool CheckTrivialBufferIndices(const T& buffer_access, Array<Var>* indices = nullptr) {
+  for (const PrimExpr& index : buffer_access->indices) {
+    const VarNode* var = index.as<VarNode>();
+    if (var == nullptr) {
+      return false;
+    }
+    if (indices != nullptr) {
+      indices->push_back(GetRef<Var>(var));
+    }
+  }
+  return true;
+}
+
+class EinsumExtractor : public ExprVisitor {
+ public:
+  explicit EinsumExtractor() = default;
+
+  std::optional<Einsum> Extract(const Block& block) {
+    const BufferStoreNode* update = block->body.as<BufferStoreNode>();
+    // Step 1: Check the body is a BufferStore and the block has the init statement, and the
+    // BufferStore and the init statement store have the same output buffer indices.
+    if (update == nullptr || !block->init.defined()) {
+      return std::nullopt;
+    }
+
+    if (!CheckTrivialBufferIndices(update, &(ein_sum_.output_indices))) {
+      return std::nullopt;
+    }
+    ein_sum_.output_buffer = update->buffer;
+
+    const BufferStoreNode* init = block->init.value().as<BufferStoreNode>();
+    ICHECK(init != nullptr);
+    if (!CompareBufferIndices(init->indices, ein_sum_.output_indices)) {
+      return std::nullopt;
+    }
+    // Step 2: Check the BufferStore updates the output buffer and the input buffers indices are
+    // block iter variables.
+    CheckStoreValue(update->value);
+    if (fail_) {
+      return std::nullopt;
+    }
+    return std::move(ein_sum_);
+  }
+
+ private:
+  void CheckStoreValue(const PrimExpr& update) {
+    // Check the update part has the form:
+    //   Output[output_indices] += Input_0[input_indices_0] op_0 Input_1[input_indices_1] op_1 ...
+    // where output_indices and input_indices_i are the indices are arrays whose elements are the
+    // block iter variables instead of composite PrimExpr, and op_i are the binary operations.
+
+    // Check the value is Add and eithe LHS or RHS is the BufferLoad from the output buffer.
+    const AddNode* add = update.as<AddNode>();
+    if (add == nullptr) {
+      fail_ = true;
+      return;
+    }
+    const BufferLoadNode* lhs = add->a.as<BufferLoadNode>();
+    const BufferLoadNode* rhs = add->b.as<BufferLoadNode>();
+    if (lhs == nullptr && rhs != nullptr) {
+      std::swap(lhs, rhs);
+    }
+    if (lhs == nullptr || !lhs->buffer.same_as(ein_sum_.output_buffer) ||
+        !CompareBufferIndices(lhs->indices, ein_sum_.output_indices)) {
+      fail_ = true;
+      return;
+    }
+    VisitExpr(add->b);
+  }
+
+  void VisitExpr(const PrimExpr& n) final {
+    if (n->IsInstance<BufferLoadNode>() || n->IsInstance<MulNode>() || n->IsInstance<CastNode>()) {
+      ExprVisitor::VisitExpr(n);
+    } else {
+      fail_ = true;
+      return;
+    }
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
+    if (auto it = ein_sum_.input_indices.find(op->buffer);
+        it != ein_sum_.input_indices.end() && !CompareBufferIndices(op->indices, (*it).second)) {
+      fail_ = true;
+      return;
+    }
+    Array<Var> indices;
+    if (!CheckTrivialBufferIndices(GetRef<BufferLoad>(op), &indices)) {
+      fail_ = true;
+      return;
+    }
+    ein_sum_.input_indices.Set(op->buffer, std::move(indices));
+  }
+
+  void VisitExpr_(const CastNode* op) { VisitExpr(op->value); }
+
+  bool Fail() { return fail_; }
+
+  bool CompareBufferIndices(const Array<PrimExpr>& indices, const Array<Var>& other) {
+    return std::equal(indices.begin(), indices.end(), other.begin(), other.end(),
+                      [](const PrimExpr& a, const Var& b) { return a.same_as(b); });
+  }
+
+  Einsum ein_sum_;
+  bool fail_{false};
+};
+
+Einsum ExtractEinsum(const ScheduleState& self, const Block& block) {
+  EinsumExtractor extractor;
+  std::optional<Einsum> einsum = extractor.Extract(block);
+  if (!einsum.has_value()) {
+    throw NonEinsumError(self->mod, block);
+  }
+  return einsum.value();
+}
+
+class BufferNotAllocatedInScopeError : public ScheduleError {
+ public:
+  explicit BufferNotAllocatedInScopeError(IRModule mod, Buffer buffer)
+      : mod_(std::move(mod)), buffer_(std::move(buffer)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: The buffer is not allocated as an intermediate buffer in current "
+           "PrimFunc.";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The buffer " << buffer_->name
+       << " is not allocated as an intermediate buffer in current PrimFunc.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+
+ private:
+  IRModule mod_;
+  Buffer buffer_;
+};
+
+class PadEinsumRewriter : public ReplaceBufferMutator {
+ public:
+  PadEinsumRewriter(const std::unordered_map<const BlockNode*, PrimExpr> producer_predicate,
+                    Map<Var, PrimExpr> padded_iter_extents, const Map<Buffer, Buffer>& buffer_remap,
+                    Map<Block, Block>* block_sref_reuse, arith::Analyzer* analyzer)
+      : ReplaceBufferMutator(buffer_remap, block_sref_reuse),
+        producer_predicate_(producer_predicate),
+        padded_iter_extents_(padded_iter_extents),
+        analyzer_(analyzer) {}
+
+  Stmt VisitStmt_(const ForNode* op) final {
+    For new_for = Downcast<For>(ReplaceBufferMutator::VisitStmt_(op));
+    if (padded_iter_extents_.count(new_for->loop_var)) {
+      new_for.CopyOnWrite()->extent = padded_iter_extents_.at(new_for->loop_var);
+    }
+    return std::move(new_for);
+  }
+
+  Block PadProducerBlock(Block block, const PrimExpr& predicate) {
+    BufferStore store = Downcast<BufferStore>(block->body);
+    store.CopyOnWrite()->value =
+        analyzer_->Simplify(if_then_else(predicate, store->value, make_zero(store->value.dtype())));

Review Comment:
   Should this be the zero element (identity) of the reduce op rather than a literal zero?
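The quoted `PadProducerBlock` fills the padded region of each producer with `make_zero(...)`. Whether zero is the right fill value depends on the reduction. It is safe for a sum of products, because 0 is the identity of `+` and any product with a padded 0 operand is 0. The sketch below is a pure-Python illustration on nested lists, not TVM code; `matmul` and `pad` are hypothetical helpers. It checks that a zero-padded, over-computed matmul, cropped back to its original region, matches the unpadded result. It also shows that a max-reduction would need -inf instead.

```python
def matmul(a, b):
    """Naive C[i][j] = sum over k of A[i][k] * B[k][j] on nested lists."""
    n, kdim, m = len(a), len(b), len(b[0])
    return [[sum(a[i][k] * b[k][j] for k in range(kdim)) for j in range(m)]
            for i in range(n)]


def pad(mat, rows, cols, fill=0.0):
    """Return a rows x cols copy of mat; new cells hold `fill`."""
    return [[mat[i][j] if i < len(mat) and j < len(mat[0]) else fill
             for j in range(cols)]
            for i in range(rows)]


A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]       # 2 x 3
B = [[1.0, 0.5], [2.0, 1.5], [3.0, 2.5]]     # 3 x 2
ref = matmul(A, B)                           # 2 x 2

# Over-compute on 4 x 4 zero-padded operands, then crop back to 2 x 2.
padded = matmul(pad(A, 4, 4), pad(B, 4, 4))
cropped = [r[:2] for r in padded[:2]]
assert cropped == ref  # padded k terms contribute 0 * 0 == 0 to each sum

# With max as the reduction, 0 is not the identity: it masks negative maxima.
row = [-3.0, -1.0, -2.0]
assert max(row + [0.0]) != max(row)
assert max(row + [float("-inf")]) == max(row)
```

The `CheckStoreValue` step in the diff only accepts an `Add` update, so the zero fill matches the pattern this primitive admits. Supporting other reductions would presumably require filling with that reduction's identity element.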



##########
tests/python/unittest/test_tir_schedule_pad_einsum.py:
##########
@@ -0,0 +1,123 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir, te
+from tvm.script import tir as T
+from tvm.tir.schedule.schedule import ScheduleError
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+from tvm.meta_schedule.testing import te_workload
+
+# pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
+
+
+@T.prim_func
+def matmul_before(
+    A: T.Buffer[(128, 127), "float32"],
+    B: T.Buffer[(127, 127), "float32"],
+    C: T.Buffer[(128, 127), "float32"],
+) -> None:
+    A_shared = T.alloc_buffer((128, 127), "float32", scope="shared")
+    B_shared = T.alloc_buffer((127, 127), "float32", scope="shared")
+    C_shared = T.alloc_buffer((128, 127), "float32", scope="shared")
+    for i0, i1 in T.grid(128, 127):
+        with T.block("A"):
+            i, j = T.axis.remap("SS", [i0, i1])
+            A_shared[i, j] = A[i, j]
+    for i0, i1 in T.grid(127, 127):
+        with T.block("B"):
+            i, j = T.axis.remap("SS", [i0, i1])
+            B_shared[i, j] = B[i, j]
+    for i0, i1, i2 in T.grid(128, 127, 127):
+        with T.block("C_shared"):
+            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+            with T.init():
+                C_shared[i, j] = T.float32(0)
+            C_shared[i, j] = C_shared[i, j] + A_shared[i, k] * B_shared[k, j]
+    for i0, i1 in T.grid(128, 127):
+        with T.block("C"):
+            i, j = T.axis.remap("SS", [i0, i1])
+            C[i, j] = C_shared[i, j]
+
+
+@T.prim_func
+def matmul_expected(
+    A: T.Buffer[(128, 127), "float32"],
+    B: T.Buffer[(127, 127), "float32"],
+    C: T.Buffer[(128, 127), "float32"],
+) -> None:
+    A_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
+    B_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
+    C_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
+    for i0, i1 in T.grid(128, 128):
+        with T.block("A"):
+            i, j = T.axis.remap("SS", [i0, i1])
+            T.reads(A[i, j])
+            T.writes(A_shared_padded[i, j])
+            A_shared_padded[i, j] = T.if_then_else(j < 127, A[i, j], T.float32(0), dtype="float32")
+    for i0, i1 in T.grid(128, 128):
+        with T.block("B"):
+            i, j = T.axis.remap("SS", [i0, i1])
+            T.reads(B[i, j])
+            T.writes(B_shared_padded[i, j])
+            B_shared_padded[i, j] = T.if_then_else(
+                i < 127 and j < 127, B[i, j], T.float32(0), dtype="float32"
+            )
+    for i0, i1, i2 in T.grid(128, 128, 128):
+        with T.block("C_shared"):
+            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+            T.reads(A_shared_padded[i, k], B_shared_padded[k, j])
+            T.writes(C_shared_padded[i, j])
+            with T.init():
+                C_shared_padded[i, j] = T.float32(0)
+            C_shared_padded[i, j] = (
+                C_shared_padded[i, j] + A_shared_padded[i, k] * B_shared_padded[k, j]
+            )
+    for i0, i1 in T.grid(128, 127):
+        with T.block("C"):
+            i, j = T.axis.remap("SS", [i0, i1])
+            T.reads(C_shared_padded[i, j])
+            T.writes(C[i, j])
+            C[i, j] = C_shared_padded[i, j]
+
+
+# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
+
+
+def test_pad_matmul():
+    sch = tir.Schedule(matmul_before, debug_mask="all")
+    C = sch.get_block("C_shared")
+    sch.pad_einsum(C, [0, 1, 1])
+    print(sch.mod["main"].script())

Review Comment:
   print
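The 128 x 128 buffers in `matmul_expected` follow from simple arithmetic on the block iters of `C_shared`. The sketch below makes two assumptions. First, as the test implies, each entry of `padding` in this revision is added to the corresponding iter extent, so `[0, 1, 1]` turns `(128, 127, 127)` into `(128, 128, 128)`. Second, each buffer dimension is indexed by exactly one block iter var, which is what `CheckTrivialBufferIndices` enforces. The helper names are hypothetical, and the validation only mirrors the two conditions in `InvalidPaddingError::Check`.

```python
def padded_extents(extents, padding):
    """Add padding to each iter extent after validating it.

    Mirrors InvalidPaddingError::Check: one entry per block iter, and every
    entry non-negative.
    """
    if len(padding) != len(extents) or any(p < 0 for p in padding):
        raise ValueError(
            f"expected {len(extents)} non-negative integers, got {padding}")
    return [e + p for e, p in zip(extents, padding)]


def padded_shape(indices, iter_extent):
    """Each dim is indexed by a single iter var, so it takes that iter's extent."""
    return tuple(iter_extent[v] for v in indices)


# Block "C_shared" in matmul_before: iters (i, j, k) with kinds "SSR".
iter_extent = dict(zip("ijk", padded_extents([128, 127, 127], [0, 1, 1])))

# Index patterns of the Einsum C_shared[i, j] += A_shared[i, k] * B_shared[k, j].
accesses = {
    "A_shared": ("i", "k"),
    "B_shared": ("k", "j"),
    "C_shared": ("i", "j"),
}
shapes = {name: padded_shape(idx, iter_extent) for name, idx in accesses.items()}
assert shapes == {name: (128, 128) for name in accesses}
```

Only the intermediate buffers are replaced. The function argument `C` keeps its `(128, 127)` shape, and the consumer block "C" still iterates over 128 x 127, as in `matmul_expected`.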



##########
src/tir/schedule/primitive/pad_einsum.cc:
##########
@@ -0,0 +1,496 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <optional>
+
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*! \brief The schedule error class when the padding size is invalid. */
+class InvalidPaddingError : public ScheduleError {
+ public:
+  InvalidPaddingError(IRModule mod, Block block, Array<Integer> padding)
+      : mod_(std::move(mod)), block_(std::move(block)), padding_(std::move(padding)) {}
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+  String FastErrorString() const final {
+    return "ScheduleError: The padding size for the block is invalid.";
+  }
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The padding for the block {0} is invalid. It should be a list of "
+       << block_->iter_vars.size() << " non-negative integers. Got " << padding_;
+    return os.str();
+  }
+
+  static void Check(const ScheduleState& self, const Block& block, Array<Integer> padding) {
+    if (padding.size() != block->iter_vars.size()) {
+      throw InvalidPaddingError(self->mod, block, padding);
+    }
+    for (const auto& pad : padding) {
+      if (pad->value < 0) {
+        throw InvalidPaddingError(self->mod, block, padding);
+      }
+    }
+  }
+
+ private:
+  IRModule mod_;
+  Block block_;
+  Array<Integer> padding_;
+};
+
+/*! \brief The schedule error class when the block body is not an Einsum pattern. */
+class NonEinsumError : public ScheduleError {
+ public:
+  explicit NonEinsumError(IRModule mod, Block block)
+      : mod_(std::move(mod)), block_(std::move(block)) {}
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+  String FastErrorString() const final {
+    return "ScheduleError: The block is not a computation of Einsum pattern.";
+  }
+  String DetailRenderTemplate() const final {
+    return "The block {0} is not a computation of Einsum pattern.";
+  }
+
+ private:
+  IRModule mod_;
+  Block block_;
+};
+
+/*! \brief Data structure that represents an Einsum computation. */
+struct Einsum {
+  // The output buffer
+  Buffer output_buffer;
+  // The indices of the output buffer
+  Array<Var> output_indices;
+  // The indices of the input buffers
+  Map<Buffer, Array<Var>> input_indices;
+};
+
+/*!
+ * \brief Check if buffer indices are all Vars and extract them
+ * \param buffer_access The BufferLoad or BufferStore
+ * \param[out] indices The optional array to store the indices
+ * \return Whether the indices are all Vars
+ */
+template <typename T>
+bool CheckTrivialBufferIndices(const T& buffer_access, Array<Var>* indices = nullptr) {
+  for (const PrimExpr& index : buffer_access->indices) {
+    const VarNode* var = index.as<VarNode>();
+    if (var == nullptr) {
+      return false;
+    }
+    if (indices != nullptr) {
+      indices->push_back(GetRef<Var>(var));
+    }
+  }
+  return true;
+}
+
+class EinsumExtractor : public ExprVisitor {
+ public:
+  explicit EinsumExtractor() = default;
+
+  std::optional<Einsum> Extract(const Block& block) {
+    const BufferStoreNode* update = block->body.as<BufferStoreNode>();
+    // Step 1: Check the body is a BufferStore and the block has the init statement, and the
+    // BufferStore and the init statement store have the same output buffer indices.
+    if (update == nullptr || !block->init.defined()) {
+      return std::nullopt;
+    }
+
+    if (!CheckTrivialBufferIndices(update, &(ein_sum_.output_indices))) {
+      return std::nullopt;
+    }
+    ein_sum_.output_buffer = update->buffer;
+
+    const BufferStoreNode* init = block->init.value().as<BufferStoreNode>();
+    ICHECK(init != nullptr);
+    if (!CompareBufferIndices(init->indices, ein_sum_.output_indices)) {
+      return std::nullopt;
+    }
+    // Step 2: Check the BufferStore updates the output buffer and the input buffers' indices are
+    // block iter variables.
+    CheckStoreValue(update->value);
+    if (fail_) {
+      return std::nullopt;
+    }
+    return std::move(ein_sum_);
+  }
+
+ private:
+  void CheckStoreValue(const PrimExpr& update) {
+    // Check the update part has the form:
+    //   Output[output_indices] += Input_0[input_indices_0] op_0 Input_1[input_indices_1] op_1 ...
+    // where output_indices and input_indices_i are arrays whose elements are the
+    // block iter variables instead of composite PrimExpr, and op_i are the binary operations.
+
+    // Check the value is Add and either LHS or RHS is the BufferLoad from the output buffer.
+    const AddNode* add = update.as<AddNode>();
+    if (add == nullptr) {
+      fail_ = true;
+      return;
+    }
+    const BufferLoadNode* lhs = add->a.as<BufferLoadNode>();
+    const BufferLoadNode* rhs = add->b.as<BufferLoadNode>();
+    if (lhs == nullptr && rhs != nullptr) {
+      std::swap(lhs, rhs);
+    }
+    if (lhs == nullptr || !lhs->buffer.same_as(ein_sum_.output_buffer) ||
+        !CompareBufferIndices(lhs->indices, ein_sum_.output_indices)) {
+      fail_ = true;
+      return;
+    }
+    VisitExpr(add->b);
+  }
+
+  void VisitExpr(const PrimExpr& n) final {
+    if (n->IsInstance<BufferLoadNode>() || n->IsInstance<MulNode>() || n->IsInstance<CastNode>()) {
+      ExprVisitor::VisitExpr(n);
+    } else {
+      fail_ = true;
+      return;
+    }
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
+    if (auto it = ein_sum_.input_indices.find(op->buffer);
+        it != ein_sum_.input_indices.end() && !CompareBufferIndices(op->indices, (*it).second)) {
+      fail_ = true;
+      return;
+    }
+    Array<Var> indices;
+    if (!CheckTrivialBufferIndices(GetRef<BufferLoad>(op), &indices)) {
+      fail_ = true;
+      return;
+    }
+    ein_sum_.input_indices.Set(op->buffer, std::move(indices));
+  }
+
+  void VisitExpr_(const CastNode* op) { VisitExpr(op->value); }
+
+  bool Fail() { return fail_; }
+
+  bool CompareBufferIndices(const Array<PrimExpr>& indices, const Array<Var>& other) {
+    return std::equal(indices.begin(), indices.end(), other.begin(), other.end(),
+                      [](const PrimExpr& a, const Var& b) { return a.same_as(b); });
+  }
+
+  Einsum ein_sum_;
+  bool fail_{false};
+};
+
+Einsum ExtractEinsum(const ScheduleState& self, const Block& block) {
+  EinsumExtractor extractor;
+  std::optional<Einsum> einsum = extractor.Extract(block);
+  if (!einsum.has_value()) {
+    throw NonEinsumError(self->mod, block);
+  }
+  return einsum.value();
+}
+
+class BufferNotAllocatedInScopeError : public ScheduleError {
+ public:
+  explicit BufferNotAllocatedInScopeError(IRModule mod, Buffer buffer)
+      : mod_(std::move(mod)), buffer_(std::move(buffer)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: The buffer is not allocated as an intermediate buffer in current "
+           "PrimFunc.";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The buffer " << buffer_->name
+       << " is not allocated as an intermediate buffer in current PrimFunc.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+
+ private:
+  IRModule mod_;
+  Buffer buffer_;
+};
+
+class PadEinsumRewriter : public ReplaceBufferMutator {
+ public:
+  PadEinsumRewriter(const std::unordered_map<const BlockNode*, PrimExpr> producer_predicate,
+                    Map<Var, PrimExpr> padded_iter_extents, const Map<Buffer, Buffer>& buffer_remap,
+                    Map<Block, Block>* block_sref_reuse, arith::Analyzer* analyzer)
+      : ReplaceBufferMutator(buffer_remap, block_sref_reuse),
+        producer_predicate_(producer_predicate),
+        padded_iter_extents_(padded_iter_extents),
+        analyzer_(analyzer) {}
+
+  Stmt VisitStmt_(const ForNode* op) final {
+    For new_for = Downcast<For>(ReplaceBufferMutator::VisitStmt_(op));
+    if (padded_iter_extents_.count(new_for->loop_var)) {
+      new_for.CopyOnWrite()->extent = padded_iter_extents_.at(new_for->loop_var);
+    }
+    return std::move(new_for);
+  }
+
+  Block PadProducerBlock(Block block, const PrimExpr& predicate) {
+    BufferStore store = Downcast<BufferStore>(block->body);
+    store.CopyOnWrite()->value =
+        analyzer_->Simplify(if_then_else(predicate, store->value, make_zero(store->value.dtype())));

Review Comment:
   Oh I see, it checks for the Add pattern.
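The Add pattern here is the one `CheckStoreValue` enforces. The update must have the form `Out[out_indices] + <expr>`, where `<expr>` is built only from `BufferLoad`, `Mul`, and `Cast` nodes. Every load in it must use plain block iter vars as indices. Below is a simplified Python model of that check. It uses hypothetical tuple-based expression nodes rather than TVM's IR and is not the PR's code. It is also slightly more permissive: it accepts the accumulator load on either side of the `Add` and then checks the other operand.

```python
# Hypothetical toy IR (not TVM's):
#   ("load", buffer, indices)  indices: str = iter var, tuple = composite expr
#   ("add", a, b), ("mul", a, b), ("cast", x); any other tag breaks the pattern.


def trivial_indices(indices):
    """Return the indices as a tuple of var names, or None if any is composite."""
    if all(isinstance(ix, str) for ix in indices):
        return tuple(indices)
    return None


def extract_einsum(out_buf, out_indices, value):
    """Return {input_buffer: indices} if value is Out[out_indices] + <product>."""
    out_idx = trivial_indices(out_indices)
    if out_idx is None or value[0] != "add":
        return None

    def is_accumulator(e):
        return e[0] == "load" and e[1] == out_buf and tuple(e[2]) == out_idx

    _, a, b = value
    if is_accumulator(a):
        operand = b
    elif is_accumulator(b):
        operand = a
    else:
        return None

    inputs = {}

    def visit(e):
        if e[0] == "mul":
            return visit(e[1]) and visit(e[2])
        if e[0] == "cast":
            return visit(e[1])
        if e[0] == "load":
            idx = trivial_indices(e[2])
            # Reject composite indices, or one buffer read with two index patterns.
            if idx is None or inputs.get(e[1], idx) != idx:
                return False
            inputs[e[1]] = idx
            return True
        return False

    return inputs if visit(operand) else None


matmul = ("add", ("load", "C", ["i", "j"]),
          ("mul", ("load", "A", ["i", "k"]), ("load", "B", ["k", "j"])))
assert extract_einsum("C", ["i", "j"], matmul) == {"A": ("i", "k"), "B": ("k", "j")}

# A shifted index such as A[i, k + 1] is not an Einsum access.
shifted = ("add", ("load", "C", ["i", "j"]),
           ("mul", ("load", "A", ["i", ("add", "k", 1)]), ("load", "B", ["k", "j"])))
assert extract_einsum("C", ["i", "j"], shifted) is None
```

One difference from the quoted C++ is worth flagging. The sketch always checks the operand that is not the accumulator. The C++ swaps `lhs`/`rhs` when only `add->b` is a load, but then always visits `add->b`. For an update written as `A * B + C`, it therefore visits the accumulator load rather than the product.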





[GitHub] [tvm] vinx13 commented on pull request #12750: [TIR, Schedule] Add schedule primitive PadEinsum

Posted by GitBox <gi...@apache.org>.
vinx13 commented on PR #12750:
URL: https://github.com/apache/tvm/pull/12750#issuecomment-1247139119

   @Lunderberg The current assumption is to over-compute the reduction block and infer the padding of the producers. Since the padding is inferred from the buffer access pattern, I don't think we can specify the padding as a tuple.
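Following that reasoning, the producer guards in `matmul_expected` can be read off by comparing each intermediate buffer's original shape with its padded shape. Block "A" gets `j < 127` and block "B" gets `i < 127 and j < 127`, because only padded dimensions need a bound. The sketch below shows that derivation in plain Python. It assumes each producer stores to its buffer using its own block iter vars as indices, and that its body is a single `BufferStore` (as the quoted `InvalidProducerError` requires). All helper names are hypothetical. Out-of-range positions get `0.0`, matching the `make_zero` fill in `PadProducerBlock`.

```python
def padded_dims(old_shape, new_shape):
    """(dim, original_size) for every dimension that was padded."""
    return [(d, old) for d, (old, new) in enumerate(zip(old_shape, new_shape))
            if new != old]


def predicate_text(store_indices, old_shape, new_shape):
    """Render the guard a padded producer would need, TVMScript-style."""
    conds = [f"{store_indices[d]} < {old}"
             for d, old in padded_dims(old_shape, new_shape)]
    return " and ".join(conds) or "True"


def materialize(src, old_shape, new_shape):
    """Evaluate a 2-D padded producer: copy src where the guard holds, else 0.0."""
    dims = padded_dims(old_shape, new_shape)
    return [[src[i][j] if all((i, j)[d] < old for d, old in dims) else 0.0
             for j in range(new_shape[1])]
            for i in range(new_shape[0])]


# Producers "A" and "B" from matmul_before, both writing X_shared[i, j].
assert predicate_text(("i", "j"), (128, 127), (128, 128)) == "j < 127"
assert predicate_text(("i", "j"), (127, 127), (128, 128)) == "i < 127 and j < 127"
```

The real primitive builds the guard as a `PrimExpr` and simplifies it with `arith::Analyzer`, as the quoted `PadProducerBlock` shows. The string form here is only for comparison with the TVMScript in the test.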




[GitHub] [tvm] vinx13 commented on a diff in pull request #12750: [TIR, Schedule] Add schedule primitive PadEinsum

Posted by GitBox <gi...@apache.org>.
vinx13 commented on code in PR #12750:
URL: https://github.com/apache/tvm/pull/12750#discussion_r971197674


##########
src/tir/schedule/primitive/pad_einsum.cc:
##########
@@ -0,0 +1,493 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <optional>
+
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*! \brief The schedule error class when the padding size is invalid. */
+class InvalidPaddingError : public ScheduleError {
+ public:
+  InvalidPaddingError(IRModule mod, Block block, Array<Integer> padding)
+      : mod_(std::move(mod)), block_(std::move(block)), padding_(std::move(padding)) {}
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+  String FastErrorString() const final {
+    return "ScheduleError: The padding size for the block is invalid.";
+  }
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The padding for the block {0} is invalid. It should be a list of "
+       << block_->iter_vars.size() << " non-negative integers. Got " << padding_;
+    return os.str();
+  }
+
+  static void Check(const ScheduleState& self, const Block& block, Array<Integer> padding) {
+    if (padding.size() != block->iter_vars.size()) {
+      throw InvalidPaddingError(self->mod, block, padding);
+    }
+    for (const auto& pad : padding) {
+      if (pad->value < 0) {
+        throw InvalidPaddingError(self->mod, block, padding);
+      }
+    }
+  }
+
+ private:
+  IRModule mod_;
+  Block block_;
+  Array<Integer> padding_;
+};
+
+/*! \brief The schedule error class when the block body is not an Einsum pattern. */
+class NonEinsumError : public ScheduleError {
+ public:
+  explicit NonEinsumError(IRModule mod, Block block)
+      : mod_(std::move(mod)), block_(std::move(block)) {}
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+  String FastErrorString() const final {
+    return "ScheduleError: The block is not a computation of Einsum pattern.";
+  }
+  String DetailRenderTemplate() const final {
+    return "The block {0} is not a computation of Einsum pattern.";
+  }
+
+ private:
+  IRModule mod_;
+  Block block_;
+};
+
+/*! \brief Data structure that represents an Einsum computation. */
+struct Einsum {
+  // The output buffer
+  Buffer output_buffer;
+  // The indices of the output buffer
+  Array<Var> output_indices;
+  // The indices of the input buffers
+  Map<Buffer, Array<Var>> input_indices;
+};
+
+/*!
+ * \brief Check if buffer indices are all Vars and extract them
+ * \param buffer_access The BufferLoad or BufferStore
+ * \param[out] indices The optional array to store the indices
+ * \return Whether the indices are all Vars
+ */
+template <typename T>
+bool CheckTrivialBufferIndices(const T& buffer_access, Array<Var>* indices = nullptr) {
+  for (const PrimExpr& index : buffer_access->indices) {
+    const VarNode* var = index.as<VarNode>();
+    if (var == nullptr) {
+      return false;
+    }
+    if (indices != nullptr) {
+      indices->push_back(GetRef<Var>(var));
+    }
+  }
+  return true;
+}
+
+class EinsumExtractor : public ExprVisitor {
+ public:
+  EinsumExtractor() = default;
+
+  std::optional<Einsum> Extract(const Block& block) {
+    const BufferStoreNode* update = block->body.as<BufferStoreNode>();
+    // Step 1: Check the body is a BufferStore and the block has the init statement, and the
+    // BufferStore and the init statement store have the same output buffer indices.
+    if (update == nullptr || !block->init.defined()) {
+      return std::nullopt;
+    }
+
+    if (!CheckTrivialBufferIndices(update, &(ein_sum_.output_indices))) {
+      return std::nullopt;
+    }
+    ein_sum_.output_buffer = update->buffer;
+
+    const BufferStoreNode* init = block->init.value().as<BufferStoreNode>();
+    ICHECK(init != nullptr);
+    if (!CompareBufferIndices(init->indices, ein_sum_.output_indices)) {
+      return std::nullopt;
+    }
+    // Step 2: Check the BufferStore updates the output buffer and the input buffers' indices are
+    // block iter variables.
+    CheckStoreValue(update->value);
+    if (fail_) {
+      return std::nullopt;
+    }
+    return std::move(ein_sum_);
+  }
+
+ private:
+  void CheckStoreValue(const PrimExpr& update) {
+    // Check the update part has the form:
+    //   Output[output_indices] += Input_0[input_indices_0] op_0 Input_1[input_indices_1] op_1 ...
+    // where output_indices and input_indices_i are arrays whose elements are the
+    // block iter variables instead of composite PrimExpr, and op_i are the binary operations.
+
+    // Check the value is Add and either LHS or RHS is the BufferLoad from the output buffer.
+    const AddNode* add = update.as<AddNode>();
+    if (add == nullptr) {
+      fail_ = true;
+      return;
+    }
+    const BufferLoadNode* lhs = add->a.as<BufferLoadNode>();
+    const BufferLoadNode* rhs = add->b.as<BufferLoadNode>();
+    if (lhs == nullptr && rhs != nullptr) {
+      std::swap(lhs, rhs);
+    }
+    if (lhs == nullptr || !lhs->buffer.same_as(ein_sum_.output_buffer) ||
+        !CompareBufferIndices(lhs->indices, ein_sum_.output_indices)) {
+      fail_ = true;
+      return;
+    }
+    VisitExpr(add->b);
+  }
+
+  void VisitExpr(const PrimExpr& n) final {
+    if (n->IsInstance<BufferLoadNode>() || n->IsInstance<MulNode>() || n->IsInstance<CastNode>()) {
+      ExprVisitor::VisitExpr(n);
+    } else {
+      fail_ = true;
+      return;
+    }
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
+    if (auto it = ein_sum_.input_indices.find(op->buffer);
+        it != ein_sum_.input_indices.end() && !CompareBufferIndices(op->indices, (*it).second)) {
+      fail_ = true;
+      return;
+    }
+    Array<Var> indices;
+    if (!CheckTrivialBufferIndices(GetRef<BufferLoad>(op), &indices)) {
+      fail_ = true;
+      return;
+    }
+    ein_sum_.input_indices.Set(op->buffer, std::move(indices));
+  }
+
+  void VisitExpr_(const CastNode* op) { VisitExpr(op->value); }
+
+  bool Fail() { return fail_; }
+
+  bool CompareBufferIndices(const Array<PrimExpr>& indices, const Array<Var>& other) {
+    return std::equal(indices.begin(), indices.end(), other.begin(), other.end(),
+                      [](const PrimExpr& a, const Var& b) { return a.same_as(b); });
+  }
+
+  Einsum ein_sum_;
+  bool fail_{false};
+};
+
+Einsum ExtractEinsum(const ScheduleState& self, const Block& block) {
+  EinsumExtractor extractor;
+  std::optional<Einsum> einsum = extractor.Extract(block);
+  if (!einsum.has_value()) {
+    throw NonEinsumError(self->mod, block);
+  }
+  return einsum.value();
+}
+
+class BufferNotAllocatedInScopeError : public ScheduleError {
+ public:
+  explicit BufferNotAllocatedInScopeError(IRModule mod, Buffer buffer)
+      : mod_(std::move(mod)), buffer_(std::move(buffer)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: The buffer is not allocated as an intermediate buffer in current "
+           "PrimFunc.";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The buffer " << buffer_->name
+       << " is not allocated as an intermediate buffer in current PrimFunc.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+
+ private:
+  IRModule mod_;
+  Buffer buffer_;
+};
+
+class PadEinsumRewriter : public ReplaceBufferMutator {
+ public:
+  PadEinsumRewriter(const std::unordered_map<const BlockNode*, PrimExpr> producer_predicate,
+                    Map<Var, PrimExpr> padded_iter_extents, const Map<Buffer, Buffer>& buffer_remap,
+                    Map<Block, Block>* block_sref_reuse, arith::Analyzer* analyzer)
+      : ReplaceBufferMutator(buffer_remap, block_sref_reuse),
+        producer_predicate_(producer_predicate),
+        padded_iter_extents_(padded_iter_extents),
+        analyzer_(analyzer) {}
+
+  Stmt VisitStmt_(const ForNode* op) final {
+    For new_for = Downcast<For>(ReplaceBufferMutator::VisitStmt_(op));
+    if (padded_iter_extents_.count(new_for->loop_var)) {
+      new_for.CopyOnWrite()->extent = padded_iter_extents_.at(new_for->loop_var);
+    }
+    return std::move(new_for);
+  }
+
+  Block PadProducerBlock(Block block, const PrimExpr& predicate) {
+    BufferStore store = Downcast<BufferStore>(block->body);
+    store.CopyOnWrite()->value =
+        analyzer_->Simplify(if_then_else(predicate, store->value, make_zero(store->value.dtype())));
+    block.CopyOnWrite()->body = std::move(store);
+    return block;
+  }
+
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Block old_block = GetRef<Block>(op);
+    Block new_block = Downcast<Block>(ReplaceBufferMutator::VisitStmt_(op));
+    if (auto it = producer_predicate_.find(op); it != producer_predicate_.end()) {
+      new_block = PadProducerBlock(std::move(new_block), (*it).second);
+    }
+
+    // Mutate block iters
+    Array<IterVar> new_iters;
+    bool changed = false;
+    for (const IterVar& iter : new_block->iter_vars) {
+      if (auto it = padded_iter_extents_.find(iter->var); it != padded_iter_extents_.end()) {
+        changed = true;
+        new_iters.push_back(
+            IterVar(Range::FromMinExtent(0, (*it).second), iter->var, iter->iter_type));
+      } else {
+        new_iters.push_back(iter);
+      }
+    }
+    if (changed) {
+      new_block.CopyOnWrite()->iter_vars = std::move(new_iters);
+    }
+    if (!old_block.same_as(new_block)) {
+      block_sref_reuse_->Set(old_block, new_block);
+    }
+    return std::move(new_block);
+  }
+
+ private:
+  const std::unordered_set<const BlockNode*> producer_blocks_;
+  const std::unordered_map<const BlockNode*, PrimExpr> producer_predicate_;
+  const Map<Var, PrimExpr> padded_iter_extents_;
+  arith::Analyzer* analyzer_;
+};
+
+/*! \brief The schedule error class when the producer block cannot be padded. */
+class InvalidProducerError : public ScheduleError {
+ public:
+  explicit InvalidProducerError(IRModule mod, Block producer)
+      : mod_(std::move(mod)), producer_(std::move(producer)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: The producer block cannot be padded.";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The producer block {0} cannot be padded. It should write to a single buffer and the "
+          "body should be a BufferStore.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {producer_}; }
+
+ private:
+  IRModule mod_;
+  Buffer buffer_;
+  Block producer_;
+};
+
+void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Array<Integer>& padding) {
+  arith::Analyzer analyzer;
+  // Step 1: Input checking and error handling
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+  BlockRealize realize = GetBlockRealize(self, block_sref);
+
+  const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
+  InvalidPaddingError::Check(self, GetRef<Block>(block), padding);
+
+  // Check block properties of the computation block
+  CheckBlockHasTrivialBinding(self, block_sref);
+  CheckReductionBlock(self, block_sref, scope_sref);
+  Array loops = GetLoops(block_sref);
+  ICHECK(!loops.empty());
+  CheckGetSingleChildBlockRealizeOnSRefTree(self, loops.front());
+
+  // Check block properties of the producer block
+  const Array<StmtSRef> producers = GetProducers(self, block_sref);
+  {
+    auto f_check_producer = [&](const StmtSRef& producer_sref) {

Review Comment:
   Thanks. I found this is similar to the earlier check, so I consolidated both into an `f_check_block_properties` function.





[GitHub] [tvm] junrushao commented on pull request #12750: [TIR, Schedule] Add schedule primitive PadEinsum

Posted by GitBox <gi...@apache.org>.
junrushao commented on PR #12750:
URL: https://github.com/apache/tvm/pull/12750#issuecomment-1253141898

   @vinx13 let's fix the following warnings:
   
   ```
   /root/Projects/tvm-dev/src/tir/schedule/primitive/pad_einsum.cc:231:8: warning: 'tvm::tir::PadEinsumRewriter::VisitStmt_' hides overloaded virtual function [-Woverloaded-virtual]
     Stmt VisitStmt_(const ForNode* op) final {
          ^
   /root/Projects/tvm-dev/src/tir/schedule/primitive/.././transform.h:134:8: note: hidden overloaded virtual function 'tvm::tir::ReplaceBufferMutator::VisitStmt_' declared here: type mismatch at 1st parameter ('const tvm::tir::BufferStoreNode *' vs 'const tvm::tir::ForNode *')
     Stmt VisitStmt_(const BufferStoreNode* op) final;
          ^
   /root/Projects/tvm-dev/src/tir/schedule/primitive/pad_einsum.cc:374:47: warning: lambda capture 'buffer_remap' is not used [-Wunused-lambda-capture]
     auto f_pad_buffer = [&padded_iter_extents, &buffer_remap](Buffer buffer,
                                              ~~~^~~~~~~~~~~~
   ```




[GitHub] [tvm] wrongtest-intellif commented on a diff in pull request #12750: [TIR, Schedule] Add schedule primitive PadEinsum

Posted by GitBox <gi...@apache.org>.
wrongtest-intellif commented on code in PR #12750:
URL: https://github.com/apache/tvm/pull/12750#discussion_r967571170


##########
tests/python/unittest/test_tir_schedule_pad_einsum.py:
##########
@@ -0,0 +1,123 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir, te
+from tvm.script import tir as T
+from tvm.tir.schedule.schedule import ScheduleError
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+from tvm.meta_schedule.testing import te_workload
+
+# pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
+
+
+@T.prim_func
+def matmul_before(
+    A: T.Buffer[(128, 127), "float32"],
+    B: T.Buffer[(127, 127), "float32"],
+    C: T.Buffer[(128, 127), "float32"],
+) -> None:
+    A_shared = T.alloc_buffer((128, 127), "float32", scope="shared")
+    B_shared = T.alloc_buffer((127, 127), "float32", scope="shared")
+    C_shared = T.alloc_buffer((128, 127), "float32", scope="shared")
+    for i0, i1 in T.grid(128, 127):
+        with T.block("A"):
+            i, j = T.axis.remap("SS", [i0, i1])
+            A_shared[i, j] = A[i, j]
+    for i0, i1 in T.grid(127, 127):
+        with T.block("B"):
+            i, j = T.axis.remap("SS", [i0, i1])
+            B_shared[i, j] = B[i, j]
+    for i0, i1, i2 in T.grid(128, 127, 127):
+        with T.block("C_shared"):
+            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+            with T.init():
+                C_shared[i, j] = T.float32(0)
+            C_shared[i, j] = C_shared[i, j] + A_shared[i, k] * B_shared[k, j]
+    for i0, i1 in T.grid(128, 127):
+        with T.block("C"):
+            i, j = T.axis.remap("SS", [i0, i1])
+            C[i, j] = C_shared[i, j]
+
+
+@T.prim_func
+def matmul_expected(

Review Comment:
   Compared with https://github.com/apache/tvm/pull/12720 (cc @Lunderberg): would it be correct to understand this primitive as equivalent to a bundle of existing operations applied to a certain workload pattern? Something like:
   ```python
   for buffer in [A_shared, B_shared, C_shared]:
       # shorthand: pad the layout from (127, 127) to (128, 128), filling with 0
       s.transform_layout(buffer, (127, 127) -> (128, 128), pad_value=0)
   for block in [A, B, C_shared]:
       for axis in s.get_loops(block):
           s.fuse(*s.split(axis, [1, 128]))
   s.annotate(C_shared, "en_some_predicate_versus_overcomputation_selection", 1)
   ```
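For reference, the split-then-fuse step in the sketch above can be modeled in plain Python. This is a minimal sketch, assuming that splitting a loop of extent 127 by factors [1, 128] behaves like TVM's `split` on a non-divisible extent, i.e. the extra iteration is guarded by a block predicate. The helper `split_then_fuse_indices` is hypothetical, not a TVM API. It shows that the loop extent grows to 128 while the predicate still restricts the body to the original 127 iterations, which is why a separate choice between keeping the predicate and over-computing on padded buffers is needed.

```python
def split_then_fuse_indices(extent, factors):
    """Model splitting a loop of `extent` into (outer, inner) loops and fusing them back.

    Hypothetical helper: returns the fused loop extent and the original iter values
    whose body actually executes under the predicate `i < extent`.
    """
    outer, inner = factors
    executed = []
    for fused in range(outer * inner):
        i0, i1 = divmod(fused, inner)  # recover the split loop variables
        i = i0 * inner + i1            # original iter value
        if i < extent:                 # predicate guarding the non-divisible tail
            executed.append(i)
    return outer * inner, executed


loop_extent, executed = split_then_fuse_indices(127, (1, 128))
print(loop_extent, len(executed))  # 128 127
```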





[GitHub] [tvm] vinx13 commented on a diff in pull request #12750: [TIR, Schedule] Add schedule primitive PadEinsum

vinx13 commented on code in PR #12750:
URL: https://github.com/apache/tvm/pull/12750#discussion_r968855917


##########
tests/python/unittest/test_tir_schedule_pad_einsum.py:
##########

Review Comment:
   Yes. It pads the producers with the reduction's init value (zero) and over-computes the reduction block.
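As a rough illustration of that answer, the pure-Python sketch below (not TVM code; `pad2d` and `padded_matmul` are hypothetical helpers) pads both producers of a matmul at the end of each axis with zero, the init value of the sum reduction, runs the reduction over the full padded extents with no predicate, and lets the consumer read back only the original region. Padded `k` positions contribute 0 to every sum, so the original region of the output is unchanged.

```python
def matmul(A, B):
    """Reference einsum: C[i][j] = sum over k of A[i][k] * B[k][j]."""
    n, kk, m = len(A), len(B), len(B[0])
    return [[sum(A[i][k] * B[k][j] for k in range(kk)) for j in range(m)]
            for i in range(n)]


def pad2d(X, rows, cols, value=0):
    """Return a rows x cols copy of X, padded at the end of each axis with `value`."""
    out = [[value] * cols for _ in range(rows)]
    for i, row in enumerate(X):
        for j, v in enumerate(row):
            out[i][j] = v
    return out


def padded_matmul(A, B, pad_i, pad_j, pad_k):
    """Pad the producers, over-compute the reduction, and slice the original region."""
    n, k, m = len(A), len(A[0]), len(B[0])
    # Producers are padded with the reduction's init value (0).
    A_pad = pad2d(A, n + pad_i, k + pad_k)
    B_pad = pad2d(B, k + pad_k, m + pad_j)
    # The reduction runs over the full padded extents; no predicate is needed.
    C_pad = matmul(A_pad, B_pad)
    # The consumer only reads the original (n, m) region.
    return [row[:m] for row in C_pad[:n]]


A = [[1, 2], [3, 4], [5, 6]]   # shape (3, 2)
B = [[1, 0, 2], [0, 1, 3]]     # shape (2, 3)
print(padded_matmul(A, B, 0, 1, 1) == matmul(A, B))  # True
```

With the shapes in `matmul_before` (A: 128x127, B: 127x127), the analogous call would pad `i` by 0 and `j`, `k` by 1, so every extent becomes 128.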





[GitHub] [tvm] vinx13 merged pull request #12750: [TIR, Schedule] Add schedule primitive PadEinsum

vinx13 merged PR #12750:
URL: https://github.com/apache/tvm/pull/12750




[GitHub] [tvm] Lunderberg commented on pull request #12750: [TIR, Schedule] Add schedule primitive PadEinsum

Lunderberg commented on PR #12750:
URL: https://github.com/apache/tvm/pull/12750#issuecomment-1247156286

   @vinx13 Thank you, that makes sense. So one of the simplifying assumptions is that all padding is on one side only. If padding were allowed on both sides, that wouldn't just add a free parameter for the final output, but also one for each producer.
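To make those extra parameters concrete, here is a pure-Python sketch (not TVM code; the helper names are hypothetical) in which padding is given per block iter as separate `before` and `after` amounts. Each producer's padding is derived from the iters that index it (`A[i, k]`, `B[k, j]`), so both producers must agree on the leading offset of `k`, and the consumer must read the output back at the leading offsets of `i` and `j`. With one-sided padding every `before` entry is 0 and this bookkeeping disappears.

```python
def pad2d(X, before, after, value=0):
    """Pad X with before[d] leading and after[d] trailing entries on axis d."""
    rows = before[0] + len(X) + after[0]
    cols = before[1] + len(X[0]) + after[1]
    out = [[value] * cols for _ in range(rows)]
    for i, row in enumerate(X):
        for j, v in enumerate(row):
            out[before[0] + i][before[1] + j] = v
    return out


def matmul(A, B):
    """Reference einsum: C[i][j] = sum over k of A[i][k] * B[k][j]."""
    n, kk, m = len(A), len(B), len(B[0])
    return [[sum(A[i][k] * B[k][j] for k in range(kk)) for j in range(m)]
            for i in range(n)]


def two_sided_padded_matmul(A, B, before, after):
    """`before`/`after` map each block iter ("i", "j", "k") to its padding amount."""
    # Each producer's padding follows from the block iters it is indexed by,
    # so A and B share the same leading offset on k.
    A_pad = pad2d(A, (before["i"], before["k"]), (after["i"], after["k"]))
    B_pad = pad2d(B, (before["k"], before["j"]), (after["k"], after["j"]))
    C_pad = matmul(A_pad, B_pad)
    # The consumer must skip the leading padding of i and j.
    n, m = len(A), len(B[0])
    bi, bj = before["i"], before["j"]
    return [row[bj:bj + m] for row in C_pad[bi:bi + n]]


A = [[1, 2], [3, 4], [5, 6]]
B = [[1, 0, 2], [0, 1, 3]]
before = {"i": 1, "j": 2, "k": 1}
after = {"i": 0, "j": 1, "k": 2}
print(two_sided_padded_matmul(A, B, before, after) == matmul(A, B))  # True
```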

