You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/07/08 17:26:01 UTC
[tvm] branch main updated: [TIR] Avoid unnecessary dtype escalation in loop splitting (#12035)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 684a838160 [TIR] Avoid unnecessary dtype escalation in loop splitting (#12035)
684a838160 is described below
commit 684a8381608f0978ea91539af7c9d3c2f6e85eaa
Author: Xiyou Zhou <xi...@octoml.ai>
AuthorDate: Fri Jul 8 10:25:56 2022 -0700
[TIR] Avoid unnecessary dtype escalation in loop splitting (#12035)
This PR adds a dtype check to loop splitting: when a split factor (sometimes given as `int64`) has a wider integer type than the loop's extent, the factor is cast down to the extent's data type. The mismatch typically arises when a trace is reloaded from disk through the JSON database, and it causes the `CompactBufferAllocation` pass to fail.
---
src/tir/schedule/concrete_schedule.cc | 3 +++
tests/python/unittest/test_tir_schedule_split_fuse.py | 9 +++++++++
2 files changed, 12 insertions(+)
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index c19735025d..35f31ac916 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -452,6 +452,9 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
if (is_const_int(factor) && !is_positive_const(factor)) {
throw NonPositiveFactorError(state_->mod, factor.as<IntImmNode>()->value, i);
}
+ if (factor.dtype().bits() > loop->extent.dtype().bits()) {
+ factor = cast(loop->extent.dtype(), factor);
+ }
factors.push_back(factor);
tot_length *= factor;
}
diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py
index 0bfac4e425..9fd678174d 100644
--- a/tests/python/unittest/test_tir_schedule_split_fuse.py
+++ b/tests/python/unittest/test_tir_schedule_split_fuse.py
@@ -20,6 +20,7 @@ import tvm
import tvm.testing
from tvm import te, tir
from tvm.script import tir as T
+from tvm.tir.expr import IntImm
from tvm.tir.schedule.testing import verify_trace_roundtrip
# pylint: disable=no-member,invalid-name,unused-variable
@@ -637,5 +638,13 @@ def test_split_int64_extent_with_int32_factors():
)
+def test_split_int64_factors():
+ sch = tir.Schedule(elementwise_symbolic, debug_mask="all")
+ block_b = sch.get_block("B")
+ _, _, k = sch.get_loops(block_b)
+ sch.split(k, factors=[IntImm(dtype="int64", value=10), None])
+ tvm.ir.assert_structural_equal(elementwise_symbolic_split, sch.mod["main"])
+
+
if __name__ == "__main__":
tvm.testing.main()