You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2019/12/08 19:57:33 UTC
[incubator-tvm] branch master updated: Check function attr for
alpha equal (#4479)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 5fe5cee Check function attr for alpha equal (#4479)
5fe5cee is described below
commit 5fe5ceee9ebfe40ade22033a1b20691d020ebb1f
Author: Zhi <51...@users.noreply.github.com>
AuthorDate: Sun Dec 8 11:57:25 2019 -0800
Check function attr for alpha equal (#4479)
---
src/relay/ir/hash.cc | 2 ++
tests/python/relay/test_pass_alpha_equal.py | 24 ++++++++++++++++++++++++
2 files changed, 26 insertions(+)
diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc
index bce3610..f37b1a4 100644
--- a/src/relay/ir/hash.cc
+++ b/src/relay/ir/hash.cc
@@ -267,6 +267,8 @@ class RelayHashHandler:
hash = Combine(hash, TypeHash(func->ret_type));
hash = Combine(hash, ExprHash(func->body));
+ hash = Combine(hash, AttrHash(func->attrs));
+
return hash;
}
diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py
index b240daf..6ef435a 100644
--- a/tests/python/relay/test_pass_alpha_equal.py
+++ b/tests/python/relay/test_pass_alpha_equal.py
@@ -313,6 +313,29 @@ def test_tuple_get_item_alpha_equal():
assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))
+def test_function_attr():
+ x0 = relay.var('x0', shape=(10, 10))
+ w00 = relay.var('w00', shape=(10, 10))
+ w01 = relay.var('w01', shape=(10, 10))
+ w02 = relay.var('w02', shape=(10, 10))
+ z00 = relay.add(x0, w00)
+ p00 = relay.subtract(z00, w01)
+ q00 = relay.multiply(p00, w02)
+ func0 = relay.Function([x0, w00, w01, w02], q00)
+ func0 = func0.set_attribute("FuncName", tvm.expr.StringImm("a"))
+
+ x1 = relay.var('x1', shape=(10, 10))
+ w10 = relay.var('w10', shape=(10, 10))
+ w11 = relay.var('w11', shape=(10, 10))
+ w12 = relay.var('w12', shape=(10, 10))
+ z10 = relay.add(x1, w10)
+ p10 = relay.subtract(z10, w11)
+ q10 = relay.multiply(p10, w12)
+ func1 = relay.Function([x1, w10, w11, w12], q10)
+ func1 = func1.set_attribute("FuncName", tvm.expr.StringImm("b"))
+ assert not alpha_equal(func0, func1)
+
+
def test_function_alpha_equal():
tt1 = relay.TensorType((1, 2, 3), "float32")
tt2 = relay.TensorType((4, 5, 6), "int8")
@@ -639,6 +662,7 @@ if __name__ == "__main__":
test_tuple_alpha_equal()
test_tuple_get_item_alpha_equal()
test_function_alpha_equal()
+ test_function_attr()
test_call_alpha_equal()
test_let_alpha_equal()
test_if_alpha_equal()