Posted to commits@tvm.apache.org by ma...@apache.org on 2022/04/13 09:54:15 UTC

[tvm] branch main updated: [FIX] resolve int64/32 for AttrStmtNode (#10983)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 61a9269d85 [FIX] resolve int64/32 for AttrStmtNode (#10983)
61a9269d85 is described below

commit 61a9269d85c2966d122e1216c5c91e2d9764dc84
Author: Jiawei Liu <ja...@gmail.com>
AuthorDate: Wed Apr 13 04:54:09 2022 -0500

    [FIX] resolve int64/32 for AttrStmtNode (#10983)
    
    * resolve int64/32 for AttrStmtNode
    
    * rm debug header
    
    * refine
    
    * add test case
    
    * lint
---
 src/tir/transforms/narrow_datatype.cc | 12 +++++++++++-
 tests/python/relay/test_op_level10.py | 17 +++++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)
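
For context, a condensed, standalone version of the regression test added below
in test_op_level10.py. The int64 constant shape is the kind of input this fix
addresses in the NarrowDataType pass; the "llvm" target, CPU device, and
variable names here are assumptions for a standalone run, not part of the commit.

    import numpy as np
    import tvm
    from tvm import relay

    # broadcast_to with an int64 constant shape, then reduced over axis 0.
    shape_like = relay.const(np.array([1, 5]), dtype="int64")
    x = relay.var("x", shape=(1,), dtype="int64")
    f = relay.Function([x], relay.sum(relay.broadcast_to(x, shape=shape_like), axis=0))

    data = np.random.randint(10, size=(1,), dtype="int64")
    res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(f)(data)
    np.testing.assert_allclose(res.numpy(), np.broadcast_to(data, (5,)))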

diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc
index c2bf273931..8df7b57eaf 100644
--- a/src/tir/transforms/narrow_datatype.cc
+++ b/src/tir/transforms/narrow_datatype.cc
@@ -276,7 +276,17 @@ class DataTypeRewriter : public StmtExprMutator {
       PrimExpr e = VisitExpr(iv->var);
       Var var = Downcast<Var>(e);
       if (ivmap_.find(iv) == ivmap_.end()) {
-        ivmap_[iv] = IterVar(iv->dom, var, iv->iter_type, iv->thread_tag);
+        Range dom = iv->dom;
+        if (dom.defined()) {
+          PrimExpr extend = dom->extent;
+          if (extend.dtype().is_int() && var.dtype().is_int() &&
+              var.dtype().bits() != extend.dtype().bits()) {
+            int bits = std::max(extend.dtype().bits(), var.dtype().bits());
+            DataType dtype = var.dtype().with_bits(bits);
+            dom = Range(cast(dtype, dom->min), cast(dtype, extend), dom->span);
+          }
+        }
+        ivmap_[iv] = IterVar(dom, var, iv->iter_type, iv->thread_tag);
       }
       return AttrStmt(ivmap_[iv], op->attr_key, cast(var.dtype(), op->value), op->body);
     }
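
For readers following the hunk above: when the rewritten loop variable and the
IterVar's Range extent are both integers of different bit widths, the Range
bounds are cast up to the wider width so the dtypes agree in the resulting
AttrStmt. A minimal Python sketch of that rule follows; it is illustrative only
(the helper name and the simplified "int%d" dtype handling are mine), not the
TVM source, which additionally preserves the variable's dtype via with_bits.

    import tvm
    from tvm import tir

    def widen_range_to_var(dom, var):
        """If var and dom.extent are ints of different widths, cast the
        Range bounds to the wider width; otherwise return dom unchanged."""
        def bits(dtype):  # assumes a plain "intNN" dtype string
            return int(str(dtype)[3:])
        if bits(var.dtype) != bits(dom.extent.dtype):
            wide = "int%d" % max(bits(var.dtype), bits(dom.extent.dtype))
            return tvm.ir.Range.from_min_extent(tir.Cast(wide, dom.min),
                                                tir.Cast(wide, dom.extent))
        return dom

    # e.g. an int32 Range paired with a variable kept at int64:
    i = tir.Var("i", "int64")
    dom = tvm.ir.Range.from_min_extent(tir.IntImm("int32", 0), tir.IntImm("int32", 5))
    print(widen_range_to_var(dom, i))  # bounds become cast('int64', ...)
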
diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py
index 0486ef4001..85a3dd5636 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -229,6 +229,23 @@ def test_broadcast_to():
             tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
 
 
+@tvm.testing.uses_gpu
+def test_broadcast_to_const_shape_int64():
+    shape_like = relay.const(np.array([1, 5]), dtype="int64")
+    x = relay.var("x", shape=(1,), dtype="int64")
+    z = relay.broadcast_to(x, shape=shape_like)
+    z = relay.sum(z, axis=0)
+
+    f = relay.Function([x], z)
+
+    x = np.random.randint(10, size=(1,), dtype="int64")
+    ref_res = np.broadcast_to(x, (5,))
+    for target, dev in tvm.testing.enabled_targets():
+        for kind in ["graph", "debug"]:
+            op_res = relay.create_executor(kind, device=dev, target=target).evaluate(f)(x)
+            tvm.testing.assert_allclose(op_res.numpy(), ref_res)
+
+
 @tvm.testing.uses_gpu
 def test_broadcast_to_like():
     shape = (4, 1, 6)