You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text version.
Posted to commits@tvm.apache.org by tq...@apache.org on 2022/01/06 16:13:07 UTC

[tvm] branch main updated: [BugFix] resolve integer 32. ~ 64. mismatch by casting (#9582)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 07a46a1  [BugFix] resolve integer 32. ~ 64. mismatch by casting (#9582)
07a46a1 is described below

commit 07a46a1c10ec7c9d9b6d5bf8aa7c2377cc511010
Author: Jiawei Liu <ja...@gmail.com>
AuthorDate: Thu Jan 6 10:12:41 2022 -0600

    [BugFix] resolve integer 32. ~ 64. mismatch by casting (#9582)
---
 src/tir/ir/data_layout.cc              |  8 +++++---
 tests/python/relay/test_type_solver.py | 24 ++++++++++++++++++++++++
 2 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc
index da3496d..8dea343 100644
--- a/src/tir/ir/data_layout.cc
+++ b/src/tir/ir/data_layout.cc
@@ -218,7 +218,7 @@ inline bool GetStoreRule(Array<PrimExpr>* rule, const Layout& src_layout,
           PrimExpr orig_var = orig_axis_impl->var;
           const int32_t factor = src_layout.FactorOf(orig_axis);
           if (factor > 0) {
-            orig_var = orig_var * PrimExpr(factor);
+            orig_var = orig_var * factor;
           }
           store = store + orig_var;
         } else {
@@ -304,9 +304,11 @@ inline Array<PrimExpr> TransformShape(const Array<PrimExpr>& src_shape,
               << ", get " << orig_shape;
         }
       }
-      bind_map[orig_axis->var.get()] = PrimExpr(0);
+      bind_map[orig_axis->var.get()] = IntImm(orig_axis->var->dtype, 0);
     } else {
-      bind_map[orig_axis->var.get()] = orig_shape;
+      bind_map[orig_axis->var.get()] = orig_axis->var->dtype == orig_shape->dtype
+                                           ? orig_shape
+                                           : cast(orig_axis->var->dtype, orig_shape);
     }
   }
   // infer the target shape,
diff --git a/tests/python/relay/test_type_solver.py b/tests/python/relay/test_type_solver.py
index 88bdd16..c1dc5c0 100644
--- a/tests/python/relay/test_type_solver.py
+++ b/tests/python/relay/test_type_solver.py
@@ -16,7 +16,10 @@
 # under the License.
 import tvm
 from tvm import relay
+from tvm.relay import testing
+
 import pytest
+import numpy as np
 
 
 def make_rel(name, args, num_inputs=None, attrs=None):
@@ -338,6 +341,26 @@ def test_incompatible_quantified_func_unification():
     solver.Unify(ft1, ft2)
 
 
+def test_integer_compatibility_in_layout_transform():
+    x = relay.var("data", shape=(2, 3, 48, 48), dtype="float32")
+    conv_out = relay.nn.conv2d(
+        x,
+        relay.var("weight", shape=(1, 3, 1, 1), dtype="float32"),
+        strides=[47, 47],
+        channels=1,
+        kernel_size=[1, 1],
+    )
+    bias_out = relay.nn.bias_add(conv_out, relay.var("bias"))
+    broadcast_out = relay.op.broadcast_to(bias_out, relay.const([2, 1, 2, 2], dtype="int64"))
+    y = relay.add(bias_out, broadcast_out)
+
+    mod, _ = testing.create_workload(y)
+    with tvm.transform.PassContext(opt_level=3):
+        with tvm.target.Target("llvm"):
+            mod = relay.transform.CanonicalizeOps()(mod)
+            mod = relay.transform.AlterOpLayout()(mod)
+
+
 if __name__ == "__main__":
     test_bcast()
     test_backward_solving()
@@ -357,3 +380,4 @@ if __name__ == "__main__":
     test_incompatible_typecall_var_unification()
     test_incompatible_typecall_args_unification()
     test_incompatible_quantified_func_unification()
+    test_integer_compatibility_in_layout_transform()