You are viewing a plain text version of this content. The canonical link for it was a hyperlink ("here") that is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/04/20 06:34:38 UTC

[tvm] branch main updated: Fix While Node StructuralEqual and StructuralHash issue (#11073)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8cf0c3e992 Fix While Node StructuralEqual and StructuralHash issue (#11073)
8cf0c3e992 is described below

commit 8cf0c3e9927cdbf4e9bcf538ffe6c798e0a7bc25
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Wed Apr 20 14:34:32 2022 +0800

    Fix While Node StructuralEqual and StructuralHash issue (#11073)
---
 include/tvm/tir/stmt.h                                  |  6 +++---
 tests/python/unittest/test_tir_structural_equal_hash.py | 10 ++++++++++
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 9ccab50ece..6cdd6499c8 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -996,12 +996,12 @@ class WhileNode : public StmtNode {
   }
 
   bool SEqualReduce(const WhileNode* other, SEqualReducer equal) const {
-    return equal.DefEqual(condition, other->condition) && equal.DefEqual(body, other->body);
+    return equal(condition, other->condition) && equal(body, other->body);
   }
 
   void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce.DefHash(condition);
-    hash_reduce.DefHash(body);
+    hash_reduce(condition);
+    hash_reduce(body);
   }
 
   static constexpr const char* _type_key = "tir.While";
diff --git a/tests/python/unittest/test_tir_structural_equal_hash.py b/tests/python/unittest/test_tir_structural_equal_hash.py
index d25780a01f..ff02f1e369 100644
--- a/tests/python/unittest/test_tir_structural_equal_hash.py
+++ b/tests/python/unittest/test_tir_structural_equal_hash.py
@@ -199,6 +199,15 @@ def test_buffer_load_store():
     assert not consistent_equal(sy, sz)
 
 
+def test_while():
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
+    wx = tvm.tir.While(x > 0, tvm.tir.Evaluate(x))
+    wy = tvm.tir.While(y > 0, tvm.tir.Evaluate(y))
+    assert not consistent_equal(wx, wy)
+    assert consistent_equal(wx, wy, map_free_vars=True)
+
+
 if __name__ == "__main__":
     test_exprs()
     test_prim_func()
@@ -208,3 +217,4 @@ if __name__ == "__main__":
     test_stmt()
     test_buffer_storage_scope()
     test_buffer_load_store()
+    test_while()