You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/02/22 00:51:44 UTC

[tvm] branch main updated: [TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes (#10172)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d8e39fd  [TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes (#10172)
d8e39fd is described below

commit d8e39fde1f975934ec9b0d0d206425b34e87f371
Author: Jinkun Lin <la...@gmail.com>
AuthorDate: Mon Feb 21 19:51:03 2022 -0500

    [TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes (#10172)
    
    [TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes
---
 src/tir/transforms/narrow_datatype.cc              | 17 +++++++++++++++
 src/tir/transforms/vectorize_loop.cc               |  2 +-
 .../unittest/test_tir_transform_narrow_datatype.py | 25 +++++++++++++++++++++-
 .../unittest/test_tir_transform_vectorize.py       |  9 ++++++++
 4 files changed, 51 insertions(+), 2 deletions(-)

diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc
index dc34626..dd5f54e 100644
--- a/src/tir/transforms/narrow_datatype.cc
+++ b/src/tir/transforms/narrow_datatype.cc
@@ -253,6 +253,23 @@ class DataTypeRewriter : public StmtExprMutator {
     return StmtExprMutator::VisitExpr_(op);
   }
 
+  PrimExpr VisitExpr_(const RampNode* op) final {
+    PrimExpr base = VisitExpr(op->base);
+    PrimExpr stride = VisitExpr(op->stride);
+    if (base.same_as(op->base) && stride.same_as(op->stride)) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      if (base.dtype().is_int()) {
+        ICHECK(stride.dtype().is_int()) << "Ramp base is int but stride is " << stride.dtype();
+        int bits = std::max(base.dtype().bits(), stride.dtype().bits());
+        DataType dtype = base.dtype().with_bits(bits);
+        if (base.dtype() != dtype) base = cast(dtype, base);
+        if (stride.dtype() != dtype) stride = cast(dtype, stride);
+      }
+      return Ramp(base, stride, op->lanes);
+    }
+  }
+
   PrimExpr VisitExpr_(const SizeVarNode* op) final {
     if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
       if (vmap_.find(op) == vmap_.end()) {
diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc
index cd2d230..0c9c97a 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -101,7 +101,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
   using StmtMutator::operator();
 
   Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) {
-    ramp_ = Ramp(0, 1, var_lanes);
+    ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes);
   }
 
   Stmt VisitStmt(const Stmt& stmt) final {
diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py
index 9b95266..667fad0 100644
--- a/tests/python/unittest/test_tir_transform_narrow_datatype.py
+++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py
@@ -27,7 +27,7 @@ def lower_stmt(params, stmt, target_bits):
     return stmt
 
 
-def lower_sch(sch, args, target_bits):
+def lower_sch(sch, args, target_bits, extra_passes=None):
     binds = {}
     arg_list = []
     for x in args:
@@ -42,6 +42,9 @@ def lower_sch(sch, args, target_bits):
 
     mod = schedule_to_module(sch, args)
     mod = tvm.tir.transform.StorageFlatten(64)(mod)
+    if extra_passes:
+        for p in extra_passes:
+            mod = p(mod)
     return tvm.tir.transform.NarrowDataType(target_bits)(mod)["main"].body
 
 
@@ -255,6 +258,25 @@ def test_relay_take():
     )
 
 
+def test_ramp_dtype_consistency():
+    """
+    for (i :int64, (int64)0, (int64)4) {
+        A[ramp(i*(int64)2, (int64)1, 2)] = cast(int64, 2 ** 31 - 1) * i;
+    }
+    The infer result:
+        base:   int64 -> int64 (since i is involved in another int64 expr)
+        stride: int64 -> int32
+
+    Thus ramp should still use int64 for both stride and base after rewrite.
+    """
+    n = tvm.tir.IntImm("int64", 4)
+    m = tvm.tir.IntImm("int64", 2)
+    A = te.compute((n, m), lambda i, j: tvm.tir.Cast("int64", 2 ** 31 - 1) * i, name="A")
+    s = te.create_schedule(A.op)
+    s[A].vectorize(A.op.axis[1])
+    lower_sch(s, [A], 32, extra_passes=[tvm.tir.transform.VectorizeLoop()])
+
+
 if __name__ == "__main__":
     test_basic()
     test_thread_axis()
@@ -263,3 +285,4 @@ if __name__ == "__main__":
     test_slice()
     test_relay_basic()
     test_relay_take()
+    test_ramp_dtype_consistency()
diff --git a/tests/python/unittest/test_tir_transform_vectorize.py b/tests/python/unittest/test_tir_transform_vectorize.py
index b1e5809..1a0d84a 100644
--- a/tests/python/unittest/test_tir_transform_vectorize.py
+++ b/tests/python/unittest/test_tir_transform_vectorize.py
@@ -205,6 +205,14 @@ def test_vectorize_while_fail():
         assert expected in error_msg
 
 
+def test_vectorize_dtype_mismatch():
+    n = tvm.tir.IntImm("int64", 4)
+    A = te.compute((n,), lambda i: tvm.tir.IntImm("int64", 2 ** 31 - 1) + i, name="A")
+    s = te.create_schedule(A.op)
+    s[A].vectorize(A.op.axis[0])
+    tvm.lower(s, [A], "llvm", simple_mode=True)
+
+
 if __name__ == "__main__":
     test_vectorize_vector()
     test_vectorize_with_if()
@@ -214,3 +222,4 @@ if __name__ == "__main__":
     test_vectorize_with_ge_cond()
     test_vectorize_let()
     test_vectorize_while_fail()
+    test_vectorize_dtype_mismatch()