You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2022/06/24 21:46:52 UTC

[tvm] branch main updated: [TIR][Arith] Avoid assigning range of possible values to integers (#11859)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ed638ef6db [TIR][Arith] Avoid assigning range of possible values to integers (#11859)
ed638ef6db is described below

commit ed638ef6db007772cbf84b13c26836a8a53706b3
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Fri Jun 24 16:46:46 2022 -0500

    [TIR][Arith] Avoid assigning range of possible values to integers (#11859)
    
    Previously, in `ConstIntBoundAnalyzer`, entering a conditional such as
    `if 2==0` could result in the integer literal `2` being treated as
    having a known value of zero within the body of the conditional.
    Evaluating the bounds of expressions that use `2` in the body of the
    conditional could then throw an exception, for example a
    divide-by-zero error when evaluating `expr / 2` with `2` set to its
    (incorrect) maximum value of zero.
    
    This issue had been present for conditions with inequalities for some
    time, but was only introduced for conditions with equalities in
    https://github.com/apache/tvm/pull/11524.  Both cases are resolved in
    this PR.
---
 src/arith/const_int_bound.cc                       | 41 ++++++++++-------
 ...test_tir_transform_renormalize_split_pattern.py | 53 +++++++++++++++++++++-
 2 files changed, 76 insertions(+), 18 deletions(-)

diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index cabf299a88..fa74f83313 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -637,29 +637,36 @@ class ConstIntBoundAnalyzer::Impl
   static std::vector<BoundInfo> DetectBoundInfo(const PrimExpr& cond) {
     PVar<PrimExpr> x, y;
     PVar<IntImm> c;
-    // NOTE: canonical form always use <= or <
-    if ((c <= x).Match(cond)) {
-      return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value, kPosInf))};
-    }
-    if ((c < x).Match(cond)) {
-      return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value + 1, kPosInf))};
-    }
-    if ((x <= c).Match(cond)) {
-      return {BoundInfo(x.Eval(), MakeBound(kNegInf, c.Eval()->value))};
-    }
-    if ((x < c).Match(cond)) {
-      return {BoundInfo(x.Eval(), MakeBound(kNegInf, c.Eval()->value - 1))};
-    }
-    if ((x == c).Match(cond) || (c == x).Match(cond)) {
-      return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value, c.Eval()->value))};
-    }
     if ((x && y).Match(cond)) {
       auto ret1 = DetectBoundInfo(x.Eval());
       auto ret2 = DetectBoundInfo(y.Eval());
       ret1.insert(ret1.end(), ret2.begin(), ret2.end());
       return ret1;
     }
-    return {};
+
+    // NOTE: canonical form always use <= or <
+    Entry bound;
+    if ((c <= x).Match(cond)) {
+      bound = MakeBound(c.Eval()->value, kPosInf);
+    } else if ((c < x).Match(cond)) {
+      bound = MakeBound(c.Eval()->value + 1, kPosInf);
+    } else if ((x <= c).Match(cond)) {
+      bound = MakeBound(kNegInf, c.Eval()->value);
+    } else if ((x < c).Match(cond)) {
+      bound = MakeBound(kNegInf, c.Eval()->value - 1);
+    } else if ((x == c).Match(cond) || (c == x).Match(cond)) {
+      bound = MakeBound(c.Eval()->value, c.Eval()->value);
+    } else {
+      return {};
+    }
+
+    // If the conditional is comparing two integers, do not assign a
+    // value to them.
+    if (x.Eval().as<IntImmNode>()) {
+      return {};
+    }
+
+    return {BoundInfo(x.Eval(), bound)};
   }
 
   /*!
diff --git a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py
index fb1fb72eb8..872afeeba5 100644
--- a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py
+++ b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py
@@ -16,6 +16,7 @@
 # under the License.
 
 import tvm
+import tvm.testing
 from tvm.script import tir as T
 
 # fmt: off
@@ -124,5 +125,55 @@ def test_renormalize_split_pattern():
     tvm.ir.assert_structural_equal(after, After_simplified)
 
 
+@T.prim_func
+def impossible_equality(n: T.int32):
+    # Prior to bugfix, this conditional defined the expression "2" as
+    # equal to zero within the then_case. [min_value=2, max_value=0]
+    if 2 == 0:
+        # Then this expression evaluates n/2, using the min/max values
+        # of "2", which is caught as a divide by zero error.
+        if n / 2 >= 16:
+            T.evaluate(0)
+
+
+@T.prim_func
+def impossible_inequality(n: T.int32):
+    # Prior to bugfix, this conditional set up a range of possible
+    # values for the expression "-2" as [0, kPosInf].
+    if -1 < -2:
+        if n / (-2) >= 16:
+            T.evaluate(0)
+
+
+integer_condition = tvm.testing.parameter(
+    impossible_equality,
+    impossible_inequality,
+)
+
+
+def test_analyze_inside_integer_conditional(integer_condition):
+    """Avoid crash occurring in ConstIntBoundAnalyzer.
+
+    Crash occurred when simplifying some expressions with provably
+    false integer expressions.  If the expressions were renormalized
+    before calling Simplify, conditional statements could assign a
+    range of possible values to integers, as if they were variables.
+    This would result in divide by zero throwing an exception,
+    followed by a second exception during stack unwinding causing the
+    program to crash.
+    """
+
+    # Similar issue would occur in most transformations that subclass
+    # IRMutatorWithAnalyzer.  tir.transform.Simplify() is an
+    # exception, as it rewrites the integer conditionals first.  These
+    # tests are written using RenormalizeSplitPattern as it is the
+    # first case identified.
+    transform = tvm.tir.transform.RenormalizeSplitPattern()
+
+    # Issue would result in an error being thrown while applying the transformation.
+    mod = tvm.IRModule.from_expr(integer_condition)
+    transform(mod)
+
+
 if __name__ == "__main__":
-    tesd_renormalize_split_pattern()
+    tvm.testing.main()