You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2022/05/26 16:26:35 UTC

[tvm] branch main updated: [TIR] Additional Stmt/Expr simplication rules (#11373)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 52df2e8414 [TIR] Additional Stmt/Expr simplication rules (#11373)
52df2e8414 is described below

commit 52df2e84141b34cda2b1e723c22d38b22796d6a7
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Thu May 26 11:26:29 2022 -0500

    [TIR] Additional Stmt/Expr simplication rules (#11373)
    
    * [TIR] Additional Stmt/Expr simplication rules
    
    - Enabled simplification of `A[i] = A[i] + 0` into a no-op.  This was a
      bug introduced in https://github.com/apache/tvm/pull/9727, which
      applied this rewrite only to `A[i] = A[i]`, and not to statements
      which simplify to `A[i] = A[i]`.  A regression test was added to
      prevent this bug from recurring.
    
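    - Enabled simplification of `x - x` to zero for floating point types,
      along with the related cancellations `(x + y) - y`, `(x + y) - x`,
      `x - (y + x)` and `x - (x + y)`.  These float rewrites apply only
      when the cancelled operand has no side effects beyond reading state.
      Previously, this simplification was applied only for data types that
      could be used as buffer indices.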
    
    * Updated to maintain separate int/float simplification paths
    
    * Updated to use tvm.testing.main
    
    * Remove duplicate rewrite rules
---
 src/arith/rewrite_simplify.cc                      |  9 +++++
 src/tir/transforms/simplify.cc                     | 12 +++---
 .../python/unittest/test_arith_rewrite_simplify.py |  8 ++++
 .../python/unittest/test_tir_transform_simplify.py | 45 +++++++++++++++++++---
 4 files changed, 63 insertions(+), 11 deletions(-)

diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 4d8b6ff769..dab78c77a0 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -411,6 +411,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
     TVM_TRY_RECURSIVE_REWRITE((x + c1) - y, (x - y) + c1);
     TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y);
     TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1));
+  } else if (op->dtype.is_float()) {
+    // Cancellation rules.  Deliberately off of the integer path, to
+    // avoid introducing checks on the side effects for the fast path.
+    TVM_TRY_REWRITE_IF(x - x, ZeroWithTypeLike(x),
+                       SideEffect(x.Eval()) <= CallEffectKind::kReadState);
+    TVM_TRY_REWRITE_IF((x + y) - y, x, SideEffect(y.Eval()) <= CallEffectKind::kReadState);
+    TVM_TRY_REWRITE_IF((x + y) - x, y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
+    TVM_TRY_REWRITE_IF(x - (y + x), 0 - y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
+    TVM_TRY_REWRITE_IF(x - (x + y), 0 - y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
   }
 
   // condition rules.
diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc
index 7d4fac8d7b..85f405be44 100644
--- a/src/tir/transforms/simplify.cc
+++ b/src/tir/transforms/simplify.cc
@@ -90,12 +90,12 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
   // eliminate useless stores
   Stmt VisitStmt_(const BufferStoreNode* op) final {
     BufferStore store = Downcast<BufferStore>(Parent::VisitStmt_(op));
-    if (const BufferLoadNode* load = op->value.as<BufferLoadNode>()) {
-      if (load->buffer->data.same_as(op->buffer->data) &&
-          ArrayDeepEqual(load->indices, op->indices) &&
-          tir::ExprDeepEqual()(load->buffer->elem_offset, op->buffer->elem_offset) &&
-          ArrayDeepEqual(load->buffer->shape, op->buffer->shape) &&
-          ArrayDeepEqual(load->buffer->strides, op->buffer->strides)) {
+    if (const BufferLoadNode* load = store->value.as<BufferLoadNode>()) {
+      if (load->buffer->data.same_as(store->buffer->data) &&
+          ArrayDeepEqual(load->indices, store->indices) &&
+          tir::ExprDeepEqual()(load->buffer->elem_offset, store->buffer->elem_offset) &&
+          ArrayDeepEqual(load->buffer->shape, store->buffer->shape) &&
+          ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) {
         return Evaluate(0);
       }
     }
diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py
index 855635b3f9..8d26710f40 100644
--- a/tests/python/unittest/test_arith_rewrite_simplify.py
+++ b/tests/python/unittest/test_arith_rewrite_simplify.py
@@ -972,5 +972,13 @@ def test_div_zero_simplify():
         assert "division by zero" in str(cm.execption)
 
 
+def test_sub_bufferload():
+    ck = RewriteChecker()
+    buf = tvm.tir.decl_buffer([1], dtype="float32")
+    load = tvm.tir.BufferLoad(buf, [0])
+    expr = load - load
+    ck.verify(expr, 0.0)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py
index 824bef4f32..01cc41c7ce 100644
--- a/tests/python/unittest/test_tir_transform_simplify.py
+++ b/tests/python/unittest/test_tir_transform_simplify.py
@@ -15,7 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
+import tvm.testing
+
 from tvm import te
+from tvm.script import tir as T
 
 
 def test_stmt_simplify():
@@ -133,9 +136,41 @@ def test_complex_likely_elimination():
     assert "if" not in str(stmt)
 
 
+def test_load_store_noop():
+    """Store of a value that was just read from the same location is a no-op."""
+
+    @T.prim_func
+    def before(A: T.Buffer[(1,), "float32"]):
+        A[0] = A[0]
+
+    @T.prim_func
+    def expected(A: T.Buffer[(1,), "float32"]):
+        T.evaluate(0)
+
+    after = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(before))["main"]
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_load_store_noop_after_simplify():
+    """As test_load_store_noop, but requiring simplification to identify.
+
+    Previously, a bug caused the self-assignment of a buffer to be
+    checked based on the pre-simplification assignment, not the
+    post-simplification one.  This test guards against any similar
+    regression.
+    """
+
+    @T.prim_func
+    def before(A: T.Buffer[(1,), "float32"]):
+        A[0] = A[0] + (5.0 - 5.0)
+
+    @T.prim_func
+    def expected(A: T.Buffer[(1,), "float32"]):
+        T.evaluate(0)
+
+    after = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(before))["main"]
+    tvm.ir.assert_structural_equal(after, expected)
+
+
 if __name__ == "__main__":
-    test_stmt_simplify()
-    test_thread_extent_simplify()
-    test_if_likely()
-    test_basic_likely_elimination()
-    test_complex_likely_elimination()
+    tvm.testing.main()