You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2022/06/24 21:46:52 UTC
[tvm] branch main updated: [TIR][Arith] Avoid assigning range of possible values to integers (#11859)
This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ed638ef6db [TIR][Arith] Avoid assigning range of possible values to integers (#11859)
ed638ef6db is described below
commit ed638ef6db007772cbf84b13c26836a8a53706b3
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Fri Jun 24 16:46:46 2022 -0500
[TIR][Arith] Avoid assigning range of possible values to integers (#11859)
Previously, in `ConstIntBoundAnalyzer`, entering a conditional such as
`if 2==0` could result in the expression `2` being treated as having a
known value of zero within the body of the conditional. Evaluating
the range of expressions using `2` in the body of the conditional
could result in exceptions being thrown, such as evaluating `expr / 2`
while setting `2` to its maximum value of zero.
This issue was present for conditions with inequalities for some time,
but was introduced for conditions with equalities in
https://github.com/apache/tvm/pull/11524. Both types are resolved in
this PR.
---
src/arith/const_int_bound.cc | 41 ++++++++++-------
...test_tir_transform_renormalize_split_pattern.py | 53 +++++++++++++++++++++-
2 files changed, 76 insertions(+), 18 deletions(-)
diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index cabf299a88..fa74f83313 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -637,29 +637,36 @@ class ConstIntBoundAnalyzer::Impl
static std::vector<BoundInfo> DetectBoundInfo(const PrimExpr& cond) {
PVar<PrimExpr> x, y;
PVar<IntImm> c;
- // NOTE: canonical form always use <= or <
- if ((c <= x).Match(cond)) {
- return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value, kPosInf))};
- }
- if ((c < x).Match(cond)) {
- return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value + 1, kPosInf))};
- }
- if ((x <= c).Match(cond)) {
- return {BoundInfo(x.Eval(), MakeBound(kNegInf, c.Eval()->value))};
- }
- if ((x < c).Match(cond)) {
- return {BoundInfo(x.Eval(), MakeBound(kNegInf, c.Eval()->value - 1))};
- }
- if ((x == c).Match(cond) || (c == x).Match(cond)) {
- return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value, c.Eval()->value))};
- }
if ((x && y).Match(cond)) {
auto ret1 = DetectBoundInfo(x.Eval());
auto ret2 = DetectBoundInfo(y.Eval());
ret1.insert(ret1.end(), ret2.begin(), ret2.end());
return ret1;
}
- return {};
+
+ // NOTE: canonical form always use <= or <
+ Entry bound;
+ if ((c <= x).Match(cond)) {
+ bound = MakeBound(c.Eval()->value, kPosInf);
+ } else if ((c < x).Match(cond)) {
+ bound = MakeBound(c.Eval()->value + 1, kPosInf);
+ } else if ((x <= c).Match(cond)) {
+ bound = MakeBound(kNegInf, c.Eval()->value);
+ } else if ((x < c).Match(cond)) {
+ bound = MakeBound(kNegInf, c.Eval()->value - 1);
+ } else if ((x == c).Match(cond) || (c == x).Match(cond)) {
+ bound = MakeBound(c.Eval()->value, c.Eval()->value);
+ } else {
+ return {};
+ }
+
+ // If the conditional is comparing two integers, do not assign a
+ // value to them.
+ if (x.Eval().as<IntImmNode>()) {
+ return {};
+ }
+
+ return {BoundInfo(x.Eval(), bound)};
}
/*!
diff --git a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py
index fb1fb72eb8..872afeeba5 100644
--- a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py
+++ b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py
@@ -16,6 +16,7 @@
# under the License.
import tvm
+import tvm.testing
from tvm.script import tir as T
# fmt: off
@@ -124,5 +125,55 @@ def test_renormalize_split_pattern():
tvm.ir.assert_structural_equal(after, After_simplified)
+@T.prim_func
+def impossible_equality(n: T.int32):
+ # Prior to bugfix, this conditional defined the expression "2" as
+ # equal to zero within the then_case. [min_value=2, max_value=0]
+ if 2 == 0:
+ # Then this expression evaluates n/2, using the min/max values
+ # of "2", which is caught as a divide by zero error.
+ if n / 2 >= 16:
+ T.evaluate(0)
+
+
+@T.prim_func
+def impossible_inequality(n: T.int32):
+ # Prior to bugfix, this conditional set up a range of possible
+ # values for the expression "-2" as [0, kPosInf].
+ if -1 < -2:
+ if n / (-2) >= 16:
+ T.evaluate(0)
+
+
+integer_condition = tvm.testing.parameter(
+ impossible_equality,
+ impossible_inequality,
+)
+
+
+def test_analyze_inside_integer_conditional(integer_condition):
+ """Avoid crash occurring in ConstIntBoundAnalyzer.
+
+ Crash occurred when simplifying some expressions with provably
+ false integer expressions. If the expressions were renormalized
+ before calling Simplify, conditional statements could assign a
+ range of possible values to integers, as if they were variables.
+ This would result in divide by zero throwing an exception,
+ followed by a second exception during stack unwinding causing the
+ program to crash.
+ """
+
+ # Similar issue would occur in most transformations that subclass
+ # IRMutatorWithAnalyzer. tir.transform.Simplify() is an
+ # exception, as it rewrites the integer conditionals first. These
+ # tests are written using RenormalizeSplitPattern as it is the
+ # first case identified.
+ transform = tvm.tir.transform.RenormalizeSplitPattern()
+
+ # Issue would result in an error being thrown while applying the transformation.
+ mod = tvm.IRModule.from_expr(integer_condition)
+ transform(mod)
+
+
if __name__ == "__main__":
- tesd_renormalize_split_pattern()
+ tvm.testing.main()