You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/02/22 00:51:44 UTC
[tvm] branch main updated: [TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes (#10172)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d8e39fd [TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes (#10172)
d8e39fd is described below
commit d8e39fde1f975934ec9b0d0d206425b34e87f371
Author: Jinkun Lin <la...@gmail.com>
AuthorDate: Mon Feb 21 19:51:03 2022 -0500
[TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes (#10172)
[TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes
---
src/tir/transforms/narrow_datatype.cc | 17 +++++++++++++++
src/tir/transforms/vectorize_loop.cc | 2 +-
.../unittest/test_tir_transform_narrow_datatype.py | 25 +++++++++++++++++++++-
.../unittest/test_tir_transform_vectorize.py | 9 ++++++++
4 files changed, 51 insertions(+), 2 deletions(-)
diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc
index dc34626..dd5f54e 100644
--- a/src/tir/transforms/narrow_datatype.cc
+++ b/src/tir/transforms/narrow_datatype.cc
@@ -253,6 +253,23 @@ class DataTypeRewriter : public StmtExprMutator {
return StmtExprMutator::VisitExpr_(op);
}
+ PrimExpr VisitExpr_(const RampNode* op) final {  // Rewrite a Ramp so base and stride share one integer dtype.
+ PrimExpr base = VisitExpr(op->base);  // base and stride are rewritten independently,
+ PrimExpr stride = VisitExpr(op->stride);  // so their resulting dtypes may differ
+ if (base.same_as(op->base) && stride.same_as(op->stride)) {
+ return GetRef<PrimExpr>(op);  // neither operand changed: reuse the original node
+ } else {
+ if (base.dtype().is_int()) {
+ ICHECK(stride.dtype().is_int()) << "Ramp base is int but stride is " << stride.dtype();
+ int bits = std::max(base.dtype().bits(), stride.dtype().bits());  // the wider of the two bit widths
+ DataType dtype = base.dtype().with_bits(bits);
+ if (base.dtype() != dtype) base = cast(dtype, base);  // cast whichever operand is narrower up to dtype
+ if (stride.dtype() != dtype) stride = cast(dtype, stride);
+ }
+ return Ramp(base, stride, op->lanes);
+ }
+ }
+
PrimExpr VisitExpr_(const SizeVarNode* op) final {
if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
if (vmap_.find(op) == vmap_.end()) {
diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc
index cd2d230..0c9c97a 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -101,7 +101,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
using StmtMutator::operator();
Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) {
- ramp_ = Ramp(0, 1, var_lanes);
+ ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes);  // base and stride take var's dtype
}
Stmt VisitStmt(const Stmt& stmt) final {
diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py
index 9b95266..667fad0 100644
--- a/tests/python/unittest/test_tir_transform_narrow_datatype.py
+++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py
@@ -27,7 +27,7 @@ def lower_stmt(params, stmt, target_bits):
return stmt
-def lower_sch(sch, args, target_bits):
+def lower_sch(sch, args, target_bits, extra_passes=None):
binds = {}
arg_list = []
for x in args:
@@ -42,6 +42,9 @@ def lower_sch(sch, args, target_bits):
mod = schedule_to_module(sch, args)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
+ if extra_passes:
+ for p in extra_passes:
+ mod = p(mod)
return tvm.tir.transform.NarrowDataType(target_bits)(mod)["main"].body
@@ -255,6 +258,25 @@ def test_relay_take():
)
+def test_ramp_dtype_consistency():
+ """
+ for (i :int64, (int64)0, (int64)4) {
+ A[ramp(i*(int64)2, (int64)1, 2)] = cast(int64, 2 ** 31 - 1) * i;
+ }
+ The infer result:
+ base: int64 -> int64 (since i is involved in another int64 expr)
+ stride: int64 -> int32
+
+ Thus ramp should still use int64 for both stride and base after rewrite.
+ """
+ n = tvm.tir.IntImm("int64", 4)  # int64 extent of the first axis
+ m = tvm.tir.IntImm("int64", 2)  # int64 extent of the second (vectorized) axis
+ A = te.compute((n, m), lambda i, j: tvm.tir.Cast("int64", 2 ** 31 - 1) * i, name="A")
+ s = te.create_schedule(A.op)
+ s[A].vectorize(A.op.axis[1])  # vectorize the axis of extent m
+ lower_sch(s, [A], 32, extra_passes=[tvm.tir.transform.VectorizeLoop()])  # no assert: passes if lowering does not raise
+
+
if __name__ == "__main__":
test_basic()
test_thread_axis()
@@ -263,3 +285,4 @@ if __name__ == "__main__":
test_slice()
test_relay_basic()
test_relay_take()
+ test_ramp_dtype_consistency()
diff --git a/tests/python/unittest/test_tir_transform_vectorize.py b/tests/python/unittest/test_tir_transform_vectorize.py
index b1e5809..1a0d84a 100644
--- a/tests/python/unittest/test_tir_transform_vectorize.py
+++ b/tests/python/unittest/test_tir_transform_vectorize.py
@@ -205,6 +205,14 @@ def test_vectorize_while_fail():
assert expected in error_msg
+def test_vectorize_dtype_mismatch():
+ n = tvm.tir.IntImm("int64", 4)  # int64 extent, so the loop axis is built from an int64 bound
+ A = te.compute((n,), lambda i: tvm.tir.IntImm("int64", 2 ** 31 - 1) + i, name="A")
+ s = te.create_schedule(A.op)
+ s[A].vectorize(A.op.axis[0])  # vectorize the only axis
+ tvm.lower(s, [A], "llvm", simple_mode=True)  # no assert: passes if lowering does not raise
+
+
if __name__ == "__main__":
test_vectorize_vector()
test_vectorize_with_if()
@@ -214,3 +222,4 @@ if __name__ == "__main__":
test_vectorize_with_ge_cond()
test_vectorize_let()
test_vectorize_while_fail()
+ test_vectorize_dtype_mismatch()