You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/03/23 15:47:37 UTC
[incubator-tvm] branch master updated: [Bugfix] Fixed bug where shifting by out-of-bounds value results in no compute code being emitted. (#5115)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new a422589 [Bugfix] Fixed bug where shifting by out-of-bounds value results in no compute code being emitted. (#5115)
a422589 is described below
commit a422589c36ad11dea7a3b4c94534e36833a12c50
Author: pankratz <35...@users.noreply.github.com>
AuthorDate: Mon Mar 23 09:47:29 2020 -0600
[Bugfix] Fixed bug where shifting by out-of-bounds value results in no compute code being emitted. (#5115)
* Fixed bug where shifting by out-of-bounds RHS values caused LLVM to generate no code. Added a regression test case.
* Updated the test case to be more precise.
* Fixed the test case.
---
src/tir/ir/op.cc | 6 ++++++
tests/python/unittest/test_tir_nodes.py | 18 ++++++++++++++++++
2 files changed, 24 insertions(+)
diff --git a/src/tir/ir/op.cc b/src/tir/ir/op.cc
index cf1c24c..4ad244f 100644
--- a/src/tir/ir/op.cc
+++ b/src/tir/ir/op.cc
@@ -469,6 +469,9 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
+ if (pb) CHECK(pb->value >= 0 && pb->value < rtype.bits()) <<
+ "Shift amount must be non-negative and less than " << rtype.bits()
+ << " for type " << rtype;
if (pa && pb) return IntImm(rtype, (pa->value >> pb->value));
if (pb) {
if (pb->value == 0) return a;
@@ -484,6 +487,9 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
+ if (pb) CHECK(pb->value >= 0 && pb->value < rtype.bits()) <<
+ "Shift amount must be non-negative and less than " << rtype.bits()
+ << " for type " << rtype;
if (pa && pb) return IntImm(rtype, (pa->value << pb->value));
if (pb) {
if (pb->value == 0) return a;
diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py
index 7e2c8b5..2904953 100644
--- a/tests/python/unittest/test_tir_nodes.py
+++ b/tests/python/unittest/test_tir_nodes.py
@@ -207,6 +207,23 @@ def test_float_bitwise():
pass
+def test_shift_bounds():
+ x = te.var('x')
+ for test in [lambda lhs, rhs : lhs << rhs,
+ lambda lhs, rhs : lhs >> rhs]:
+ #negative case
+ for testcase in [(x,-1), (x,32)]:
+ try:
+ test(*testcase)
+ assert False
+ except tvm.TVMError:
+ pass
+
+ #positive case
+ for testcase in [(x,0), (x,16), (x,31)]:
+ test(*testcase)
+
+
def test_divide_by_zero():
for test in [lambda lhs, rhs : tvm.tir.floormod(lhs,rhs),
lambda lhs, rhs : tvm.tir.floordiv(lhs,rhs),
@@ -293,6 +310,7 @@ if __name__ == "__main__":
test_all()
test_bitwise()
test_float_bitwise()
+ test_shift_bounds()
test_divide_by_zero()
test_isnan()
test_equality()