You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by xi...@apache.org on 2022/06/15 06:56:27 UTC

[tvm] branch main updated: [Bugfix][TIR] Narrow-Datatype for thread axis (#11725)

This is an automated email from the ASF dual-hosted git repository.

xiyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 954a927be3 [Bugfix][TIR] Narrow-Datatype for thread axis (#11725)
954a927be3 is described below

commit 954a927be3bb00076ae66b3997483f7ce9b4c355
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Tue Jun 14 23:56:22 2022 -0700

    [Bugfix][TIR] Narrow-Datatype for thread axis (#11725)
    
    This PR fixes a bug in the pass Narrow-Datatype in TIR, where the dtypes of
    certain IterVar and loop variables are adjusted to narrower ones.
    
    The bug occurs when the dtype of the thread axis is int32 while its extent
    is int64; the original behavior did not narrow the extent to int32, which
    caused an assertion to be thrown in IterVar's constructor. An alternative
    approach would be to re-dtype the IterVar to int64; however, subsequent
    passes do not actually respect int64 thread axes, which leads to even more
    issues in lowering.
    
    This bug prevented AutoTIR from tuning Huggingface DistilBERT.
---
 src/tir/transforms/narrow_datatype.cc              |  3 +--
 .../unittest/test_tir_transform_narrow_datatype.py | 31 +++++++++++++++++++++-
 2 files changed, 31 insertions(+), 3 deletions(-)

diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc
index 8df7b57eaf..16ec86d018 100644
--- a/src/tir/transforms/narrow_datatype.cc
+++ b/src/tir/transforms/narrow_datatype.cc
@@ -281,8 +281,7 @@ class DataTypeRewriter : public StmtExprMutator {
           PrimExpr extend = dom->extent;
           if (extend.dtype().is_int() && var.dtype().is_int() &&
               var.dtype().bits() != extend.dtype().bits()) {
-            int bits = std::max(extend.dtype().bits(), var.dtype().bits());
-            DataType dtype = var.dtype().with_bits(bits);
+            DataType dtype = var.dtype();
             dom = Range(cast(dtype, dom->min), cast(dtype, extend), dom->span);
           }
         }
diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py
index 9909262a44..5c69ddc412 100644
--- a/tests/python/unittest/test_tir_transform_narrow_datatype.py
+++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py
@@ -15,8 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm import te, relay
+from tvm import relay, te
 from tvm.driver.build_module import schedule_to_module
+from tvm.script import tir as T
 from tvm.tir import const
 
 
@@ -118,6 +119,33 @@ def test_thread_axis():
     check(2**14, 32, target_bits=16, target_dtype="int32")
 
 
+def test_thread_axis_2():
+    # fmt: off
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def main(T_reshape: T.Buffer[(1, 12, 384, 384), "float32"], placeholder_1: T.Buffer[(T.int64(1), T.int64(12), T.int64(384), 384), "bool"], T_where: T.Buffer[(T.int64(1), T.int64(12), T.int64(384), 384), "float32"]) -> None:
+            # function attr dict
+            T.func_attr({"global_symbol": "main", "tir.noalias": True})
+            # body
+            # with T.block("root")
+            for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
+                for i0_i1_i2_i3_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
+                    for i0_i1_i2_i3_fused_0 in T.serial(T.int64(7)):
+                        with T.block("T_where"):
+                            ax0 = T.axis.spatial(T.int64(1), T.int64(0))
+                            ax1 = T.axis.spatial(T.int64(12), ((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2) % T.int64(1769472) // T.int64(147456))
+                            ax2 = T.axis.spatial(T.int64(384), ((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2) % T.int64(147456) // T.int64(384))
+                            ax3 = T.axis.spatial(384, T.cast(((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2) % T.int64(384), "int32"))
+                            T.where((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2 < T.int64(1769472))
+                            T.reads(placeholder_1[ax0, ax1, ax2, ax3], T_reshape[ax0, ax1, ax2, ax3])
+                            T.writes(T_where[ax0, ax1, ax2, ax3])
+                            T_where[ax0, ax1, ax2, ax3] = T.Select(T.cast(placeholder_1[ax0, ax1, ax2, ax3], "int32") != 0, T.float32(-1000000000), T_reshape[ax0, ax1, ax2, ax3])
+    # fmt: on
+    # TODO(@junrushao1994): make this test more "unit" after the new TVMScript printer/parser lands
+    tvm.lower(Before)
+
+
 def test_multilanes():
     def check(m, lanes, target_bits, target_dtype):
         ib = tvm.tir.ir_builder.create()
@@ -280,6 +308,7 @@ def test_ramp_dtype_consistency():
 if __name__ == "__main__":
     test_basic()
     test_thread_axis()
+    test_thread_axis_2()
     test_multilanes()
     test_reduce()
     test_slice()