You are viewing a plain-text version of this content. The link to the canonical version (originally the word "here") was not preserved in this plain-text rendering.
Posted to commits@tvm.apache.org by ju...@apache.org on 2021/12/27 00:16:49 UTC
[tvm] branch main updated: [BugFix][TensorIR] Non-positive constant input factors for `split` (#9805)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2c654b57 [BugFix][TensorIR] Non-positive constant input factors for `split` (#9805)
2c654b57 is described below
commit 2c654b577cc84cca6bb3322000c2d596e8a1034c
Author: Ruihang Lai <la...@qq.com>
AuthorDate: Mon Dec 27 08:16:11 2021 +0800
[BugFix][TensorIR] Non-positive constant input factors for `split` (#9805)
* Update docs of GetProducers/GetConsumers
* Fix split for non-positive factors
---
include/tvm/tir/schedule/schedule.h | 14 +++++----
python/tvm/tir/schedule/schedule.py | 2 +-
src/tir/schedule/concrete_schedule.cc | 33 ++++++++++++++++++++--
.../unittest/test_tir_schedule_split_fuse.py | 12 ++++++++
4 files changed, 51 insertions(+), 10 deletions(-)
diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 1b64a71..210ed53 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -239,15 +239,17 @@ class ScheduleNode : public runtime::Object {
*/
virtual Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) = 0;
/*!
- * \brief Get the producer of a specific block
+ * \brief Get the producer of a specific block, under the same block scope
* \param block_rv The block in the query
- * \return A list of blocks, the producers of the given block
+ * \return A list of blocks, the producers of the given block under the same scope of the given
+ * block
*/
virtual Array<BlockRV> GetProducers(const BlockRV& block_rv) = 0;
/*!
- * \brief Get the consumers of a specific block
+ * \brief Get the consumers of a specific block, under the same block scope
* \param block_rv The block to be queried
- * \return A list of blocks, the consumers of the given block
+ * \return A list of blocks, the consumers of the given block under the same scope of the given
+ * block
*/
virtual Array<BlockRV> GetConsumers(const BlockRV& block_rv) = 0;
/******** Schedule: Transform loops ********/
@@ -266,8 +268,8 @@ class ScheduleNode : public runtime::Object {
* 1) The loop can't have annotation or thread binding.
* 2) The loop must start with 0.
* \param loop_rv The loop to be split
- * \param factors The tiling factors, and at most one of which is -1, which means that
- * factor is inferred.
+ * \param factors The positive tiling factors, and at most one of which is `NullOpt`, which means
+ * that factor is inferred.
* \return The new loops after split
*/
virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) = 0;
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index e82d75f..50905ee 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -546,7 +546,7 @@ class Schedule(Object):
Potential inputs are:
- None
- ExprRV
- - Non-negative constant integers
+ - Positive constant integers
Returns
-------
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 81d91b7..65886da 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -376,6 +376,31 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
IRModule mod_;
For loop_;
};
+
+ class NonPositiveFactorError : public ScheduleError {
+ public:
+ explicit NonPositiveFactorError(IRModule mod, int64_t factor, size_t idx)
+ : mod_(std::move(mod)), factor_(factor), idx_(idx) {}
+
+ String FastErrorString() const final {
+ return "ScheduleError: All the constant factors are required to be positive. However, some "
+ "constant input factor is zero or negative.";
+ }
+ String DetailRenderTemplate() const final {
+ std::ostringstream os;
+ os << "All the constant factors are required to be positive. However, the factor at position "
+ << idx_ << " is " << factor_;
+ return os.str();
+ }
+ IRModule mod() const final { return mod_; }
+ Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+
+ private:
+ IRModule mod_;
+ int64_t factor_;
+ size_t idx_;
+ };
+
// Prepare for the splitting
StmtSRef loop_sref = this->GetSRef(loop_rv);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
@@ -389,13 +414,15 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
for (size_t i = 0; i < factor_rvs.size(); i++) {
if (!factor_rvs[i].defined()) {
factors.push_back(Integer(-1));
- if (infer_index == -1) {
- infer_index = i;
- } else {
+ if (infer_index != -1) {
throw NotSingleInferFactorError(state_->mod);
}
+ infer_index = i;
} else {
PrimExpr factor = this->Get(factor_rvs[i].value());
+ if (is_const_int(factor) && !is_positive_const(factor)) {
+ throw NonPositiveFactorError(state_->mod, factor.as<IntImmNode>()->value, i);
+ }
factors.push_back(factor);
tot_length *= factor;
}
diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py
index d2365c3..84ecece 100644
--- a/tests/python/unittest/test_tir_schedule_split_fuse.py
+++ b/tests/python/unittest/test_tir_schedule_split_fuse.py
@@ -466,6 +466,18 @@ def test_split_with_opaque_access():
verify_trace_roundtrip(sch=sch, mod=opaque_access)
+def test_split_with_non_positive_factors():
+ sch = tir.Schedule(elementwise, debug_mask="all")
+ block_b = sch.get_block("B")
+ i, j, k = sch.get_loops(block_b)
+ with pytest.raises(tvm.tir.ScheduleError):
+ sch.split(i, factors=[-2, -64])
+ with pytest.raises(tvm.tir.ScheduleError):
+ sch.split(j, factors=[0, None])
+ with pytest.raises(tvm.tir.ScheduleError):
+ sch.split(k, factors=[None, -16])
+
+
def test_fuse_split_fail_with_thread_binding():
sch = tir.Schedule(elementwise_with_thread_binding, debug_mask="all")
block_b = sch.get_block("B")