You are viewing a plain text version of this content; the original hyperlink to the canonical version is available in the HTML rendering of this message.
Posted to commits@tvm.apache.org by tq...@apache.org on 2022/01/06 16:13:07 UTC
[tvm] branch main updated: [BugFix] resolve int32/int64 integer mismatch by casting (#9582)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 07a46a1 [BugFix] resolve int32/int64 integer mismatch by casting (#9582)
07a46a1 is described below
commit 07a46a1c10ec7c9d9b6d5bf8aa7c2377cc511010
Author: Jiawei Liu <ja...@gmail.com>
AuthorDate: Thu Jan 6 10:12:41 2022 -0600
[BugFix] resolve int32/int64 integer mismatch by casting (#9582)
---
src/tir/ir/data_layout.cc | 8 +++++---
tests/python/relay/test_type_solver.py | 24 ++++++++++++++++++++++++
2 files changed, 29 insertions(+), 3 deletions(-)
diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc
index da3496d..8dea343 100644
--- a/src/tir/ir/data_layout.cc
+++ b/src/tir/ir/data_layout.cc
@@ -218,7 +218,7 @@ inline bool GetStoreRule(Array<PrimExpr>* rule, const Layout& src_layout,
PrimExpr orig_var = orig_axis_impl->var;
const int32_t factor = src_layout.FactorOf(orig_axis);
if (factor > 0) {
- orig_var = orig_var * PrimExpr(factor);
+ orig_var = orig_var * factor;
}
store = store + orig_var;
} else {
@@ -304,9 +304,11 @@ inline Array<PrimExpr> TransformShape(const Array<PrimExpr>& src_shape,
<< ", get " << orig_shape;
}
}
- bind_map[orig_axis->var.get()] = PrimExpr(0);
+ bind_map[orig_axis->var.get()] = IntImm(orig_axis->var->dtype, 0);
} else {
- bind_map[orig_axis->var.get()] = orig_shape;
+ bind_map[orig_axis->var.get()] = orig_axis->var->dtype == orig_shape->dtype
+ ? orig_shape
+ : cast(orig_axis->var->dtype, orig_shape);
}
}
// infer the target shape,
diff --git a/tests/python/relay/test_type_solver.py b/tests/python/relay/test_type_solver.py
index 88bdd16..c1dc5c0 100644
--- a/tests/python/relay/test_type_solver.py
+++ b/tests/python/relay/test_type_solver.py
@@ -16,7 +16,10 @@
# under the License.
import tvm
from tvm import relay
+from tvm.relay import testing
+
import pytest
+import numpy as np
def make_rel(name, args, num_inputs=None, attrs=None):
@@ -338,6 +341,26 @@ def test_incompatible_quantified_func_unification():
solver.Unify(ft1, ft2)
+def test_integer_compatibility_in_layout_transform():
+ x = relay.var("data", shape=(2, 3, 48, 48), dtype="float32")
+ conv_out = relay.nn.conv2d(
+ x,
+ relay.var("weight", shape=(1, 3, 1, 1), dtype="float32"),
+ strides=[47, 47],
+ channels=1,
+ kernel_size=[1, 1],
+ )
+ bias_out = relay.nn.bias_add(conv_out, relay.var("bias"))
+ broadcast_out = relay.op.broadcast_to(bias_out, relay.const([2, 1, 2, 2], dtype="int64"))
+ y = relay.add(bias_out, broadcast_out)
+
+ mod, _ = testing.create_workload(y)
+ with tvm.transform.PassContext(opt_level=3):
+ with tvm.target.Target("llvm"):
+ mod = relay.transform.CanonicalizeOps()(mod)
+ mod = relay.transform.AlterOpLayout()(mod)
+
+
if __name__ == "__main__":
test_bcast()
test_backward_solving()
@@ -357,3 +380,4 @@ if __name__ == "__main__":
test_incompatible_typecall_var_unification()
test_incompatible_typecall_args_unification()
test_incompatible_quantified_func_unification()
+ test_integer_compatibility_in_layout_transform()