You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/04/20 06:34:38 UTC
[tvm] branch main updated: Fix While Node StructuralEqual and StructuralHash issue (#11073)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8cf0c3e992 Fix While Node StructuralEqual and StructuralHash issue (#11073)
8cf0c3e992 is described below
commit 8cf0c3e9927cdbf4e9bcf538ffe6c798e0a7bc25
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Wed Apr 20 14:34:32 2022 +0800
Fix While Node StructuralEqual and StructuralHash issue (#11073)
---
include/tvm/tir/stmt.h | 6 +++---
tests/python/unittest/test_tir_structural_equal_hash.py | 10 ++++++++++
2 files changed, 13 insertions(+), 3 deletions(-)
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 9ccab50ece..6cdd6499c8 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -996,12 +996,12 @@ class WhileNode : public StmtNode {
}
bool SEqualReduce(const WhileNode* other, SEqualReducer equal) const {
- return equal.DefEqual(condition, other->condition) && equal.DefEqual(body, other->body);
+ return equal(condition, other->condition) && equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
- hash_reduce.DefHash(condition);
- hash_reduce.DefHash(body);
+ hash_reduce(condition);
+ hash_reduce(body);
}
static constexpr const char* _type_key = "tir.While";
diff --git a/tests/python/unittest/test_tir_structural_equal_hash.py b/tests/python/unittest/test_tir_structural_equal_hash.py
index d25780a01f..ff02f1e369 100644
--- a/tests/python/unittest/test_tir_structural_equal_hash.py
+++ b/tests/python/unittest/test_tir_structural_equal_hash.py
@@ -199,6 +199,15 @@ def test_buffer_load_store():
assert not consistent_equal(sy, sz)
+def test_while():
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("y", "int32")
+ wx = tvm.tir.While(x > 0, tvm.tir.Evaluate(x))
+ wy = tvm.tir.While(y > 0, tvm.tir.Evaluate(y))
+ assert not consistent_equal(wx, wy)
+ assert consistent_equal(wx, wy, map_free_vars=True)
+
+
if __name__ == "__main__":
test_exprs()
test_prim_func()
@@ -208,3 +217,4 @@ if __name__ == "__main__":
test_stmt()
test_buffer_storage_scope()
test_buffer_load_store()
+ test_while()