You are viewing a plain-text version of this content. The link to its canonical version was not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by ju...@apache.org on 2021/12/27 00:16:49 UTC

[tvm] branch main updated: [BugFix][TensorIR] Non-positive constant input factors for `split` (#9805)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2c654b57 [BugFix][TensorIR] Non-positive constant input factors for `split` (#9805)
2c654b57 is described below

commit 2c654b577cc84cca6bb3322000c2d596e8a1034c
Author: Ruihang Lai <la...@qq.com>
AuthorDate: Mon Dec 27 08:16:11 2021 +0800

    [BugFix][TensorIR] Non-positive constant input factors for `split` (#9805)
    
    * Update docs of GetProducers/GetConsumers
    
    * Fix split for non-positive factors
---
 include/tvm/tir/schedule/schedule.h                | 14 +++++----
 python/tvm/tir/schedule/schedule.py                |  2 +-
 src/tir/schedule/concrete_schedule.cc              | 33 ++++++++++++++++++++--
 .../unittest/test_tir_schedule_split_fuse.py       | 12 ++++++++
 4 files changed, 51 insertions(+), 10 deletions(-)

diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 1b64a71..210ed53 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -239,15 +239,17 @@ class ScheduleNode : public runtime::Object {
    */
   virtual Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) = 0;
   /*!
-   * \brief Get the producer of a specific block
+   * \brief Get the producer of a specific block, under the same block scope
    * \param block_rv The block in the query
-   * \return A list of blocks, the producers of the given block
+   * \return A list of blocks, the producers of the given block under the same scope of the given
+   * block
    */
   virtual Array<BlockRV> GetProducers(const BlockRV& block_rv) = 0;
   /*!
-   * \brief Get the consumers of a specific block
+   * \brief Get the consumers of a specific block, under the same block scope
    * \param block_rv The block to be queried
-   * \return A list of blocks, the consumers of the given block
+   * \return A list of blocks, the consumers of the given block under the same scope of the given
+   * block
    */
   virtual Array<BlockRV> GetConsumers(const BlockRV& block_rv) = 0;
   /******** Schedule: Transform loops ********/
@@ -266,8 +268,8 @@ class ScheduleNode : public runtime::Object {
    * 1) The loop can't have annotation or thread binding.
    * 2) The loop must start with 0.
    * \param loop_rv The loop to be split
-   * \param factors The tiling factors, and at most one of which is -1, which means that
-   * factor is inferred.
+   * \param factors The positive tiling factors, and at most one of which is `NullOpt`, which means
+   * that factor is inferred.
    * \return The new loops after split
    */
   virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) = 0;
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index e82d75f..50905ee 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -546,7 +546,7 @@ class Schedule(Object):
             Potential inputs are:
             - None
             - ExprRV
-            - Non-negative constant integers
+            - Positive constant integers
 
         Returns
         -------
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 81d91b7..65886da 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -376,6 +376,31 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
     IRModule mod_;
     For loop_;
   };
+
+  class NonPositiveFactorError : public ScheduleError {
+   public:
+    explicit NonPositiveFactorError(IRModule mod, int64_t factor, size_t idx)
+        : mod_(std::move(mod)), factor_(factor), idx_(idx) {}
+
+    String FastErrorString() const final {
+      return "ScheduleError: All the constant factors are required to be positive. However, some "
+             "constant input factor is zero or negative.";
+    }
+    String DetailRenderTemplate() const final {
+      std::ostringstream os;
+      os << "All the constant factors are required to be positive. However, the factor at position "
+         << idx_ << " is " << factor_;
+      return os.str();
+    }
+    IRModule mod() const final { return mod_; }
+    Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+
+   private:
+    IRModule mod_;
+    int64_t factor_;
+    size_t idx_;
+  };
+
   // Prepare for the splitting
   StmtSRef loop_sref = this->GetSRef(loop_rv);
   const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
@@ -389,13 +414,15 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
   for (size_t i = 0; i < factor_rvs.size(); i++) {
     if (!factor_rvs[i].defined()) {
       factors.push_back(Integer(-1));
-      if (infer_index == -1) {
-        infer_index = i;
-      } else {
+      if (infer_index != -1) {
         throw NotSingleInferFactorError(state_->mod);
       }
+      infer_index = i;
     } else {
       PrimExpr factor = this->Get(factor_rvs[i].value());
+      if (is_const_int(factor) && !is_positive_const(factor)) {
+        throw NonPositiveFactorError(state_->mod, factor.as<IntImmNode>()->value, i);
+      }
       factors.push_back(factor);
       tot_length *= factor;
     }
diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py
index d2365c3..84ecece 100644
--- a/tests/python/unittest/test_tir_schedule_split_fuse.py
+++ b/tests/python/unittest/test_tir_schedule_split_fuse.py
@@ -466,6 +466,18 @@ def test_split_with_opaque_access():
     verify_trace_roundtrip(sch=sch, mod=opaque_access)
 
 
+def test_split_with_non_positive_factors():
+    sch = tir.Schedule(elementwise, debug_mask="all")
+    block_b = sch.get_block("B")
+    i, j, k = sch.get_loops(block_b)
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.split(i, factors=[-2, -64])
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.split(j, factors=[0, None])
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.split(k, factors=[None, -16])
+
+
 def test_fuse_split_fail_with_thread_binding():
     sch = tir.Schedule(elementwise_with_thread_binding, debug_mask="all")
     block_b = sch.get_block("B")