You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/04/22 22:00:17 UTC
[tvm] branch main updated: [Analysis] Exposed Analyzer::CanProveEqual to Python API (#11102)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 83672c65c7 [Analysis] Exposed Analyzer::CanProveEqual to Python API (#11102)
83672c65c7 is described below
commit 83672c65c77554a4c0b26691ea3364bed2cf08af
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Fri Apr 22 17:00:10 2022 -0500
[Analysis] Exposed Analyzer::CanProveEqual to Python API (#11102)
* [Analysis] Exposed Analyzer::CanProveEqual to Python API
Checking for `analyzer.simplify(lhs-rhs) == 0` was a frequent pattern
in Python unit tests, and the C++ public API already provided a utility
function for this check (`Analyzer::CanProveEqual`). Exposing that
function to Python allowed this pattern to be cleaned up.
* Replaced more cases of `.simplify` with `.can_prove_equal`
---
python/tvm/arith/analyzer.py | 19 +++++++++++++++++++
python/tvm/testing/utils.py | 4 +---
src/arith/analyzer.cc | 3 +++
tests/python/unittest/test_arith_intset.py | 8 ++------
tests/python/unittest/test_arith_iter_affine_map.py | 12 ++++++------
.../test_tir_analysis_get_block_access_region.py | 10 +++-------
vta/python/vta/transform.py | 7 +++----
7 files changed, 37 insertions(+), 26 deletions(-)
diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py
index 5c532c692b..28adbe9d81 100644
--- a/python/tvm/arith/analyzer.py
+++ b/python/tvm/arith/analyzer.py
@@ -90,6 +90,7 @@ class Analyzer:
self._canonical_simplify = _mod("canonical_simplify")
self._int_set = _mod("int_set")
self._enter_constraint_context = _mod("enter_constraint_context")
+ self._can_prove_equal = _mod("can_prove_equal")
def const_int_bound(self, expr):
"""Find constant integer bound for expr.
@@ -251,3 +252,21 @@ class Analyzer:
self._const_int_bound_update(var, info, override)
else:
raise TypeError("Do not know how to handle type {}".format(type(info)))
+
+ def can_prove_equal(self, lhs: "PrimExpr", rhs: "PrimExpr"):
+ """Whether we can prove that lhs == rhs
+
+ Parameters
+ ----------
+ lhs: PrimExpr
+ The left-hand side of the comparison
+
+ rhs: PrimExpr
+ The right-hand side of the comparison
+
+ Returns
+ -------
+ result: bool
+ Whether we can prove that lhs == rhs
+ """
+ return self._can_prove_equal(lhs, rhs)
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index eeb9c35b4a..b86596feed 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -274,9 +274,7 @@ def assert_prim_expr_equal(lhs, rhs):
The left operand.
"""
ana = tvm.arith.Analyzer()
- res = ana.simplify(lhs - rhs)
- equal = isinstance(res, tvm.tir.IntImm) and res.value == 0
- if not equal:
+ if not ana.can_prove_equal(lhs, rhs):
raise ValueError("{} and {} are not equal".format(lhs, rhs))
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 08e32f5762..5309aa3270 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -185,6 +185,9 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu
auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable { ctx.reset(); };
*ret = PackedFunc(fexit);
});
+ } else if (name == "can_prove_equal") {
+ return PackedFunc(
+ [self](TVMArgs args, TVMRetValue* ret) { *ret = self->CanProveEqual(args[0], args[1]); });
}
return PackedFunc();
};
diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py
index e741ee88a6..9ca6cb8e02 100644
--- a/tests/python/unittest/test_arith_intset.py
+++ b/tests/python/unittest/test_arith_intset.py
@@ -30,12 +30,8 @@ class IntSetChecker:
def err_msg():
return "\ndata={}\ndmap={}\nres={}\nexpected={}".format(data, dmap, res, expected)
- def equal(x, y):
- res = self.analyzer.simplify(x - y)
- return tvm.tir.analysis.expr_deep_equal(res, 0)
-
- assert equal(res.min_value, expected[0]), err_msg()
- assert equal(res.max_value, expected[1]), err_msg()
+ assert self.analyzer.can_prove_equal(res.min_value, expected[0]), err_msg()
+ assert self.analyzer.can_prove_equal(res.max_value, expected[1]), err_msg()
def test_basic():
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py
index 3dd6ee1c2b..5beec1c08c 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -850,8 +850,8 @@ def test_inverse_affine_iter_map():
assert len(res) == 2
l0_inverse = floormod(floordiv(outputs[0], 4), 16) + outputs[1] * 16
l1_inverse = floormod(outputs[0], 4) + outputs[2] * 4
- assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0
- assert analyzer.simplify(res[l1[0]] - l1_inverse) == 0
+ assert analyzer.can_prove_equal(res[l0[0]], l0_inverse)
+ assert analyzer.can_prove_equal(res[l1[0]], l1_inverse)
# compound case
l0_0, l0_1 = isplit(l0, 16)
@@ -873,9 +873,9 @@ def test_inverse_affine_iter_map():
floormod(outputs[0], 4) * 16 + floormod(floordiv(outputs[0], 16), 4) * 4 + outputs[2]
)
- assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0
- assert analyzer.simplify(res[l1[0]] - l1_inverse) == 0
- assert analyzer.simplify(res[l2[0]] - l2_inverse) == 0
+ assert analyzer.can_prove_equal(res[l0[0]], l0_inverse)
+ assert analyzer.can_prove_equal(res[l1[0]], l1_inverse)
+ assert analyzer.can_prove_equal(res[l2[0]], l2_inverse)
# diamond-shape DAG
l0_0, l0_1 = isplit(l0, 16)
@@ -890,7 +890,7 @@ def test_inverse_affine_iter_map():
l1_inverse = floormod(outputs[0], 8) * 8 + floormod(floordiv(outputs[0], 8), 8)
l0_inverse = floormod(l1_inverse, 4) * 16 + floormod(floordiv(l1_inverse, 4), 16)
- assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0
+ assert analyzer.can_prove_equal(res[l0[0]], l0_inverse)
def test_free_variables():
diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
index f5d701ea71..463f2a7f0e 100644
--- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py
+++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
@@ -291,13 +291,9 @@ def test_access_of_padding_pattern():
def do_compare_buffer_region(region, expect):
assert region.buffer == expect.buffer
analyzer = tvm.arith.Analyzer()
- for k, rng in enumerate(region.region):
- tvm.ir.assert_structural_equal(
- analyzer.simplify(rng.min), analyzer.simplify(expect.region[k].min)
- )
- tvm.ir.assert_structural_equal(
- analyzer.simplify(rng.extent), analyzer.simplify(expect.region[k].extent)
- )
+ for observed_range, expected_range in zip(region.region, expect.region):
+ analyzer.can_prove_equal(observed_range.min, expected_range.min)
+ analyzer.can_prove_equal(observed_range.extent, expected_range.extent)
def do_check_block(block_name):
block = s.get_sref(s.get_block(block_name)).stmt
diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py
index 1e8247c6e1..38d58179c4 100644
--- a/vta/python/vta/transform.py
+++ b/vta/python/vta/transform.py
@@ -902,9 +902,6 @@ def InjectALUIntrin():
analyzer = tvm.arith.Analyzer()
def _do_fold(stmt):
- def _equal(x, y):
- return tvm.ir.structural_equal(analyzer.simplify(x - y), 0)
-
def _flatten_loop(src_coeff, dst_coeff, extents):
src_coeff = list(src_coeff)
dst_coeff = list(dst_coeff)
@@ -921,7 +918,9 @@ def InjectALUIntrin():
next_dst = dst_coeff.pop()
next_ext = extents.pop()
- if _equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext):
+ if analyzer.can_prove_equal(next_src, vsrc * vext) and analyzer.can_prove_equal(
+ next_dst, vdst * vext
+ ):
vext = analyzer.simplify(vext * next_ext)
else:
rev_src_coeff.append(vsrc)