You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/01/15 14:06:03 UTC

[tvm] branch main updated: [BYOC][bugfix] Handle empty tuples in annotation pass (#7288)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4c5c086  [BYOC][bugfix] Handle empty tuples in annotation pass (#7288)
4c5c086 is described below

commit 4c5c086e2a259adeb486878c76c53896f3377fe8
Author: Steven S. Lyubomirsky <ss...@cs.washington.edu>
AuthorDate: Fri Jan 15 09:05:42 2021 -0500

    [BYOC][bugfix] Handle empty tuples in annotation pass (#7288)
---
 src/relay/transforms/annotate_target.cc         |  5 +++--
 tests/python/relay/test_pass_annotate_target.py | 26 +++++++++++++++++++++++--
 2 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc
index 76585cf..e365dca 100644
--- a/src/relay/transforms/annotate_target.cc
+++ b/src/relay/transforms/annotate_target.cc
@@ -144,11 +144,12 @@ class AnnotateTargetRewriter : public ExprRewriter {
      */
     Expr new_expr = expr;
     const CallNode* call = expr.as<CallNode>();
+    const TupleNode* tup = expr.as<TupleNode>();
     if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) {
       // Check whether expr has args, if not - do not insert compiler_end.
       if (expr->IsInstance<RefWriteNode>() || expr->IsInstance<RefCreateNode>() ||
-          expr->IsInstance<RefReadNode>() || expr->IsInstance<TupleNode>() ||
-          expr->IsInstance<TupleGetItemNode>() || (call && !call->args.empty())) {
+          expr->IsInstance<RefReadNode>() || expr->IsInstance<TupleGetItemNode>() ||
+          (call && !call->args.empty()) || (tup && !tup->fields.empty())) {
         std::string target = op_expr_to_target_[new_expr];
         new_expr = InsertAnnotation(new_expr, target, make_end_op);
         op_expr_to_target_[new_expr] = target;
diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py
index 4f35066..ce86cc6 100644
--- a/tests/python/relay/test_pass_annotate_target.py
+++ b/tests/python/relay/test_pass_annotate_target.py
@@ -738,8 +738,8 @@ def test_if_free_vars():
         mod = tvm.IRModule.from_expr(func)
         return mod
 
-    for annotate_non_call_ops in [True, False, True]:
-        result = transform.AnnotateTarget(target)(before())
+    for annotate_non_call_ops in [True, False]:
+        result = transform.AnnotateTarget(target, annotate_non_call_ops)(before())
         expected = transform.InferType()(after())
         assert tvm.ir.structural_equal(expected, result)
 
@@ -764,6 +764,27 @@ def test_free_vars_zeros():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_empty_tuple():
+    target = "test_empty_tuple"
+
+    """An empty tuple should behave just like a call with no args (see above test)."""
+
+    def before():
+        func = relay.Function([], relay.Tuple([]))
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    def after():
+        func = relay.Function([], relay.Tuple([]))
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    for annotate_non_call_ops in [True, False]:
+        result = transform.AnnotateTarget(target, annotate_non_call_ops)(before())
+        expected = transform.InferType()(after())
+        assert tvm.ir.structural_equal(expected, result)
+
+
 if __name__ == "__main__":
     test_extern_dnnl()
     test_composite_function()
@@ -780,3 +801,4 @@ if __name__ == "__main__":
     test_double_target()
     test_ends_with_tuple()
     test_ref_create_read_write()
+    test_empty_tuple()