You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2022/05/26 16:26:35 UTC
[tvm] branch main updated: [TIR] Additional Stmt/Expr simplification rules (#11373)
This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 52df2e8414 [TIR] Additional Stmt/Expr simplification rules (#11373)
52df2e8414 is described below
commit 52df2e84141b34cda2b1e723c22d38b22796d6a7
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Thu May 26 11:26:29 2022 -0500
[TIR] Additional Stmt/Expr simplification rules (#11373)
* [TIR] Additional Stmt/Expr simplification rules
- Enabled simplification of `A[i] = A[i] + 0` into a no-op. This fixes a
bug introduced in https://github.com/apache/tvm/pull/9727, which
applied this rewrite only to `A[i] = A[i]`, and not to statements
which simplify to `A[i] = A[i]`. A regression test was added to prevent
recurrence of this bug.
- Enabled simplification of `x - x` to zero for floating-point types.
Previously, this simplification was applied only to data types that
could be used as buffer indices. The related float cancellations
`(x + y) - y`, `(x + y) - x`, `x - (y + x)` and `x - (x + y)` were added
as well, each applied only when the cancelled operand has no side
effects beyond reading state.
* Updated to maintain separate int/float simplification paths
* Updated to use tvm.testing.main
* Remove duplicate rewrite rules
---
src/arith/rewrite_simplify.cc | 9 +++++
src/tir/transforms/simplify.cc | 12 +++---
.../python/unittest/test_arith_rewrite_simplify.py | 8 ++++
.../python/unittest/test_tir_transform_simplify.py | 45 +++++++++++++++++++---
4 files changed, 63 insertions(+), 11 deletions(-)
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 4d8b6ff769..dab78c77a0 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -411,6 +411,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
TVM_TRY_RECURSIVE_REWRITE((x + c1) - y, (x - y) + c1);
TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y);
TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1));
+ } else if (op->dtype.is_float()) {
+ // Cancellation rules. Deliberately off of the integer path, to
+ // avoid introducing checks on the side effects for the fast path.
+ TVM_TRY_REWRITE_IF(x - x, ZeroWithTypeLike(x),
+ SideEffect(x.Eval()) <= CallEffectKind::kReadState);
+ TVM_TRY_REWRITE_IF((x + y) - y, x, SideEffect(y.Eval()) <= CallEffectKind::kReadState);
+ TVM_TRY_REWRITE_IF((x + y) - x, y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
+ TVM_TRY_REWRITE_IF(x - (y + x), 0 - y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
+ TVM_TRY_REWRITE_IF(x - (x + y), 0 - y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
}
// condition rules.
diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc
index 7d4fac8d7b..85f405be44 100644
--- a/src/tir/transforms/simplify.cc
+++ b/src/tir/transforms/simplify.cc
@@ -90,12 +90,12 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
// eliminate useless stores
Stmt VisitStmt_(const BufferStoreNode* op) final {
BufferStore store = Downcast<BufferStore>(Parent::VisitStmt_(op));
- if (const BufferLoadNode* load = op->value.as<BufferLoadNode>()) {
- if (load->buffer->data.same_as(op->buffer->data) &&
- ArrayDeepEqual(load->indices, op->indices) &&
- tir::ExprDeepEqual()(load->buffer->elem_offset, op->buffer->elem_offset) &&
- ArrayDeepEqual(load->buffer->shape, op->buffer->shape) &&
- ArrayDeepEqual(load->buffer->strides, op->buffer->strides)) {
+ if (const BufferLoadNode* load = store->value.as<BufferLoadNode>()) {
+ if (load->buffer->data.same_as(store->buffer->data) &&
+ ArrayDeepEqual(load->indices, store->indices) &&
+ tir::ExprDeepEqual()(load->buffer->elem_offset, store->buffer->elem_offset) &&
+ ArrayDeepEqual(load->buffer->shape, store->buffer->shape) &&
+ ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) {
return Evaluate(0);
}
}
diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py
index 855635b3f9..8d26710f40 100644
--- a/tests/python/unittest/test_arith_rewrite_simplify.py
+++ b/tests/python/unittest/test_arith_rewrite_simplify.py
@@ -972,5 +972,13 @@ def test_div_zero_simplify():
assert "division by zero" in str(cm.execption)
+def test_sub_bufferload():
+ ck = RewriteChecker()
+ buf = tvm.tir.decl_buffer([1], dtype="float32")
+ load = tvm.tir.BufferLoad(buf, [0])
+ expr = load - load
+ ck.verify(expr, 0.0)
+
+
if __name__ == "__main__":
pytest.main([__file__])
diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py
index 824bef4f32..01cc41c7ce 100644
--- a/tests/python/unittest/test_tir_transform_simplify.py
+++ b/tests/python/unittest/test_tir_transform_simplify.py
@@ -15,7 +15,10 @@
# specific language governing permissions and limitations
# under the License.
import tvm
+import tvm.testing
+
from tvm import te
+from tvm.script import tir as T
def test_stmt_simplify():
@@ -133,9 +136,41 @@ def test_complex_likely_elimination():
assert "if" not in str(stmt)
+def test_load_store_noop():
+ """Store of a value that was just read from the same location is a no-op."""
+
+ @T.prim_func
+ def before(A: T.Buffer[(1,), "float32"]):
+ A[0] = A[0]
+
+ @T.prim_func
+ def expected(A: T.Buffer[(1,), "float32"]):
+ T.evaluate(0)
+
+ after = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(before))["main"]
+ tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_load_store_noop_after_simplify():
+ """As test_load_store_noop, but requiring simplification to identify.
+
+ Previously, a bug caused self-assignment of a buffer to be
+ checked against the pre-simplification statement, not the
+ post-simplification one. This test guards against any similar
+ regression.
+ """
+
+ @T.prim_func
+ def before(A: T.Buffer[(1,), "float32"]):
+ A[0] = A[0] + (5.0 - 5.0)
+
+ @T.prim_func
+ def expected(A: T.Buffer[(1,), "float32"]):
+ T.evaluate(0)
+
+ after = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(before))["main"]
+ tvm.ir.assert_structural_equal(after, expected)
+
+
if __name__ == "__main__":
- test_stmt_simplify()
- test_thread_extent_simplify()
- test_if_likely()
- test_basic_likely_elimination()
- test_complex_likely_elimination()
+ tvm.testing.main()