You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2019/12/08 19:57:33 UTC

[incubator-tvm] branch master updated: Check function attr for alpha equal (#4479)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 5fe5cee  Check function attr for alpha equal (#4479)
5fe5cee is described below

commit 5fe5ceee9ebfe40ade22033a1b20691d020ebb1f
Author: Zhi <51...@users.noreply.github.com>
AuthorDate: Sun Dec 8 11:57:25 2019 -0800

    Check function attr for alpha equal (#4479)
---
 src/relay/ir/hash.cc                        |  2 ++
 tests/python/relay/test_pass_alpha_equal.py | 24 ++++++++++++++++++++++++
 2 files changed, 26 insertions(+)

diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc
index bce3610..f37b1a4 100644
--- a/src/relay/ir/hash.cc
+++ b/src/relay/ir/hash.cc
@@ -267,6 +267,8 @@ class RelayHashHandler:
     hash = Combine(hash, TypeHash(func->ret_type));
     hash = Combine(hash, ExprHash(func->body));
 
+    hash = Combine(hash, AttrHash(func->attrs));
+
     return hash;
   }
 
diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py
index b240daf..6ef435a 100644
--- a/tests/python/relay/test_pass_alpha_equal.py
+++ b/tests/python/relay/test_pass_alpha_equal.py
@@ -313,6 +313,29 @@ def test_tuple_get_item_alpha_equal():
     assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))
 
 
+def test_function_attr():  # same body, different "FuncName" attr => not alpha-equal
+    x0 = relay.var('x0', shape=(10, 10))
+    w00 = relay.var('w00', shape=(10, 10))
+    w01 = relay.var('w01', shape=(10, 10))
+    w02 = relay.var('w02', shape=(10, 10))
+    z00 = relay.add(x0, w00)
+    p00 = relay.subtract(z00, w01)
+    q00 = relay.multiply(p00, w02)
+    func0 = relay.Function([x0, w00, w01, w02], q00)
+    func0 = func0.set_attribute("FuncName", tvm.expr.StringImm("a"))
+
+    x1 = relay.var('x1', shape=(10, 10))
+    w10 = relay.var('w10', shape=(10, 10))
+    w11 = relay.var('w11', shape=(10, 10))
+    w12 = relay.var('w12', shape=(10, 10))
+    z10 = relay.add(x1, w10)
+    p10 = relay.subtract(z10, w11)
+    q10 = relay.multiply(p10, w12)
+    func1 = relay.Function([x1, w10, w11, w12], q10)
+    func1 = func1.set_attribute("FuncName", tvm.expr.StringImm("b"))
+    assert not alpha_equal(func0, func1)
+
+
 def test_function_alpha_equal():
     tt1 = relay.TensorType((1, 2, 3), "float32")
     tt2 = relay.TensorType((4, 5, 6), "int8")
@@ -639,6 +662,7 @@ if __name__ == "__main__":
     test_tuple_alpha_equal()
     test_tuple_get_item_alpha_equal()
     test_function_alpha_equal()
+    test_function_attr()
     test_call_alpha_equal()
     test_let_alpha_equal()
     test_if_alpha_equal()