You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/04/22 22:00:17 UTC

[tvm] branch main updated: [Analysis] Exposed Analyzer::CanProveEqual to Python API (#11102)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 83672c65c7 [Analysis] Exposed Analyzer::CanProveEqual to Python API (#11102)
83672c65c7 is described below

commit 83672c65c77554a4c0b26691ea3364bed2cf08af
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Fri Apr 22 17:00:10 2022 -0500

    [Analysis] Exposed Analyzer::CanProveEqual to Python API (#11102)
    
    * [Analysis] Exposed Analyzer::CanProveEqual to Python API
    
    Checking for `analyzer.simplify(lhs-rhs) == 0` was a frequent pattern
    in Python unit tests, and already had a utility function in the C++
    public API.  Exposing this utility function to Python allowed this
    pattern to be cleaned up.
    
    * Replaced more cases of .simplify with .can_prove_equal
---
 python/tvm/arith/analyzer.py                          | 19 +++++++++++++++++++
 python/tvm/testing/utils.py                           |  4 +---
 src/arith/analyzer.cc                                 |  3 +++
 tests/python/unittest/test_arith_intset.py            |  8 ++------
 tests/python/unittest/test_arith_iter_affine_map.py   | 12 ++++++------
 .../test_tir_analysis_get_block_access_region.py      | 10 +++-------
 vta/python/vta/transform.py                           |  7 +++----
 7 files changed, 37 insertions(+), 26 deletions(-)

diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py
index 5c532c692b..28adbe9d81 100644
--- a/python/tvm/arith/analyzer.py
+++ b/python/tvm/arith/analyzer.py
@@ -90,6 +90,7 @@ class Analyzer:
         self._canonical_simplify = _mod("canonical_simplify")
         self._int_set = _mod("int_set")
         self._enter_constraint_context = _mod("enter_constraint_context")
+        self._can_prove_equal = _mod("can_prove_equal")
 
     def const_int_bound(self, expr):
         """Find constant integer bound for expr.
@@ -251,3 +252,21 @@ class Analyzer:
             self._const_int_bound_update(var, info, override)
         else:
             raise TypeError("Do not know how to handle type {}".format(type(info)))
+
+    def can_prove_equal(self, lhs: "PrimExpr", rhs: "PrimExpr") -> bool:
+        """Whether we can prove that lhs == rhs
+
+        Parameters
+        ----------
+        lhs: PrimExpr
+            The left-hand side of the comparison
+
+        rhs: PrimExpr
+            The right-hand side of the comparison
+
+        Returns
+        -------
+        result: bool
+            True if lhs == rhs was proven; False if it could not be proven
+        """
+        return bool(self._can_prove_equal(lhs, rhs))
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index eeb9c35b4a..b86596feed 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -274,9 +274,7 @@ def assert_prim_expr_equal(lhs, rhs):
         The left operand.
     """
     ana = tvm.arith.Analyzer()
-    res = ana.simplify(lhs - rhs)
-    equal = isinstance(res, tvm.tir.IntImm) and res.value == 0
-    if not equal:
+    if not ana.can_prove_equal(lhs, rhs):
         raise ValueError("{} and {} are not equal".format(lhs, rhs))
 
 
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 08e32f5762..5309aa3270 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -185,6 +185,9 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu
         auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable { ctx.reset(); };
         *ret = PackedFunc(fexit);
       });
+    } else if (name == "can_prove_equal") {
+      return PackedFunc(
+          [self](TVMArgs args, TVMRetValue* ret) { *ret = self->CanProveEqual(args[0], args[1]); });
     }
     return PackedFunc();
   };
diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py
index e741ee88a6..9ca6cb8e02 100644
--- a/tests/python/unittest/test_arith_intset.py
+++ b/tests/python/unittest/test_arith_intset.py
@@ -30,12 +30,8 @@ class IntSetChecker:
         def err_msg():
             return "\ndata={}\ndmap={}\nres={}\nexpected={}".format(data, dmap, res, expected)
 
-        def equal(x, y):
-            res = self.analyzer.simplify(x - y)
-            return tvm.tir.analysis.expr_deep_equal(res, 0)
-
-        assert equal(res.min_value, expected[0]), err_msg()
-        assert equal(res.max_value, expected[1]), err_msg()
+        assert self.analyzer.can_prove_equal(res.min_value, expected[0]), err_msg()
+        assert self.analyzer.can_prove_equal(res.max_value, expected[1]), err_msg()
 
 
 def test_basic():
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py
index 3dd6ee1c2b..5beec1c08c 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -850,8 +850,8 @@ def test_inverse_affine_iter_map():
     assert len(res) == 2
     l0_inverse = floormod(floordiv(outputs[0], 4), 16) + outputs[1] * 16
     l1_inverse = floormod(outputs[0], 4) + outputs[2] * 4
-    assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0
-    assert analyzer.simplify(res[l1[0]] - l1_inverse) == 0
+    assert analyzer.can_prove_equal(res[l0[0]], l0_inverse)
+    assert analyzer.can_prove_equal(res[l1[0]], l1_inverse)
 
     # compound case
     l0_0, l0_1 = isplit(l0, 16)
@@ -873,9 +873,9 @@ def test_inverse_affine_iter_map():
         floormod(outputs[0], 4) * 16 + floormod(floordiv(outputs[0], 16), 4) * 4 + outputs[2]
     )
 
-    assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0
-    assert analyzer.simplify(res[l1[0]] - l1_inverse) == 0
-    assert analyzer.simplify(res[l2[0]] - l2_inverse) == 0
+    assert analyzer.can_prove_equal(res[l0[0]], l0_inverse)
+    assert analyzer.can_prove_equal(res[l1[0]], l1_inverse)
+    assert analyzer.can_prove_equal(res[l2[0]], l2_inverse)
 
     # diamond-shape DAG
     l0_0, l0_1 = isplit(l0, 16)
@@ -890,7 +890,7 @@ def test_inverse_affine_iter_map():
     l1_inverse = floormod(outputs[0], 8) * 8 + floormod(floordiv(outputs[0], 8), 8)
     l0_inverse = floormod(l1_inverse, 4) * 16 + floormod(floordiv(l1_inverse, 4), 16)
 
-    assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0
+    assert analyzer.can_prove_equal(res[l0[0]], l0_inverse)
 
 
 def test_free_variables():
diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
index f5d701ea71..463f2a7f0e 100644
--- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py
+++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
@@ -291,13 +291,9 @@ def test_access_of_padding_pattern():
     def do_compare_buffer_region(region, expect):
         assert region.buffer == expect.buffer
         analyzer = tvm.arith.Analyzer()
-        for k, rng in enumerate(region.region):
-            tvm.ir.assert_structural_equal(
-                analyzer.simplify(rng.min), analyzer.simplify(expect.region[k].min)
-            )
-            tvm.ir.assert_structural_equal(
-                analyzer.simplify(rng.extent), analyzer.simplify(expect.region[k].extent)
-            )
+        for observed_range, expected_range in zip(region.region, expect.region):
+            assert analyzer.can_prove_equal(observed_range.min, expected_range.min)
+            assert analyzer.can_prove_equal(observed_range.extent, expected_range.extent)
 
     def do_check_block(block_name):
         block = s.get_sref(s.get_block(block_name)).stmt
diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py
index 1e8247c6e1..38d58179c4 100644
--- a/vta/python/vta/transform.py
+++ b/vta/python/vta/transform.py
@@ -902,9 +902,6 @@ def InjectALUIntrin():
         analyzer = tvm.arith.Analyzer()
 
         def _do_fold(stmt):
-            def _equal(x, y):
-                return tvm.ir.structural_equal(analyzer.simplify(x - y), 0)
-
             def _flatten_loop(src_coeff, dst_coeff, extents):
                 src_coeff = list(src_coeff)
                 dst_coeff = list(dst_coeff)
@@ -921,7 +918,9 @@ def InjectALUIntrin():
                     next_dst = dst_coeff.pop()
                     next_ext = extents.pop()
 
-                    if _equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext):
+                    if analyzer.can_prove_equal(next_src, vsrc * vext) and analyzer.can_prove_equal(
+                        next_dst, vdst * vext
+                    ):
                         vext = analyzer.simplify(vext * next_ext)
                     else:
                         rev_src_coeff.append(vsrc)