You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/07/08 17:26:01 UTC

[tvm] branch main updated: [TIR] Avoid unnecessary dtype escalation in loop splitting (#12035)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 684a838160 [TIR] Avoid unnecessary dtype escalation in loop splitting (#12035)
684a838160 is described below

commit 684a8381608f0978ea91539af7c9d3c2f6e85eaa
Author: Xiyou Zhou <xi...@octoml.ai>
AuthorDate: Fri Jul 8 10:25:56 2022 -0700

    [TIR] Avoid unnecessary dtype escalation in loop splitting (#12035)
    
    This PR introduces a type check that casts loop split factors (sometimes given as `int64`) back to the loop extent's data type whenever that data type is narrower than the factor's. The issue usually occurs when a trace is reloaded from disk via the JSON database, and it causes the `CompactBufferAllocation` pass to fail.
---
 src/tir/schedule/concrete_schedule.cc                 | 3 +++
 tests/python/unittest/test_tir_schedule_split_fuse.py | 9 +++++++++
 2 files changed, 12 insertions(+)

diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index c19735025d..35f31ac916 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -452,6 +452,9 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
       if (is_const_int(factor) && !is_positive_const(factor)) {
         throw NonPositiveFactorError(state_->mod, factor.as<IntImmNode>()->value, i);
       }
+      if (factor.dtype().bits() > loop->extent.dtype().bits()) {
+        factor = cast(loop->extent.dtype(), factor);
+      }
       factors.push_back(factor);
       tot_length *= factor;
     }
diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py
index 0bfac4e425..9fd678174d 100644
--- a/tests/python/unittest/test_tir_schedule_split_fuse.py
+++ b/tests/python/unittest/test_tir_schedule_split_fuse.py
@@ -20,6 +20,7 @@ import tvm
 import tvm.testing
 from tvm import te, tir
 from tvm.script import tir as T
+from tvm.tir.expr import IntImm
 from tvm.tir.schedule.testing import verify_trace_roundtrip
 
 # pylint: disable=no-member,invalid-name,unused-variable
@@ -637,5 +638,13 @@ def test_split_int64_extent_with_int32_factors():
     )
 
 
+def test_split_int64_factors():
+    sch = tir.Schedule(elementwise_symbolic, debug_mask="all")
+    block_b = sch.get_block("B")
+    _, _, k = sch.get_loops(block_b)
+    sch.split(k, factors=[IntImm(dtype="int64", value=10), None])
+    tvm.ir.assert_structural_equal(elementwise_symbolic_split, sch.mod["main"])
+
+
 if __name__ == "__main__":
     tvm.testing.main()