You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/01/15 14:06:03 UTC
[tvm] branch main updated: [BYOC][bugfix] Handle empty tuples in annotation pass (#7288)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4c5c086 [BYOC][bugfix] Handle empty tuples in annotation pass (#7288)
4c5c086 is described below
commit 4c5c086e2a259adeb486878c76c53896f3377fe8
Author: Steven S. Lyubomirsky <ss...@cs.washington.edu>
AuthorDate: Fri Jan 15 09:05:42 2021 -0500
[BYOC][bugfix] Handle empty tuples in annotation pass (#7288)
---
src/relay/transforms/annotate_target.cc | 5 +++--
tests/python/relay/test_pass_annotate_target.py | 26 +++++++++++++++++++++++--
2 files changed, 27 insertions(+), 4 deletions(-)
diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc
index 76585cf..e365dca 100644
--- a/src/relay/transforms/annotate_target.cc
+++ b/src/relay/transforms/annotate_target.cc
@@ -144,11 +144,12 @@ class AnnotateTargetRewriter : public ExprRewriter {
*/
Expr new_expr = expr;
const CallNode* call = expr.as<CallNode>();
+ const TupleNode* tup = expr.as<TupleNode>();
if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) {
// Check whether expr has args, if not - do not insert compiler_end.
if (expr->IsInstance<RefWriteNode>() || expr->IsInstance<RefCreateNode>() ||
- expr->IsInstance<RefReadNode>() || expr->IsInstance<TupleNode>() ||
- expr->IsInstance<TupleGetItemNode>() || (call && !call->args.empty())) {
+ expr->IsInstance<RefReadNode>() || expr->IsInstance<TupleGetItemNode>() ||
+ (call && !call->args.empty()) || (tup && !tup->fields.empty())) {
std::string target = op_expr_to_target_[new_expr];
new_expr = InsertAnnotation(new_expr, target, make_end_op);
op_expr_to_target_[new_expr] = target;
diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py
index 4f35066..ce86cc6 100644
--- a/tests/python/relay/test_pass_annotate_target.py
+++ b/tests/python/relay/test_pass_annotate_target.py
@@ -738,8 +738,8 @@ def test_if_free_vars():
mod = tvm.IRModule.from_expr(func)
return mod
- for annotate_non_call_ops in [True, False, True]:
- result = transform.AnnotateTarget(target)(before())
+ for annotate_non_call_ops in [True, False]:
+ result = transform.AnnotateTarget(target, annotate_non_call_ops)(before())
expected = transform.InferType()(after())
assert tvm.ir.structural_equal(expected, result)
@@ -764,6 +764,27 @@ def test_free_vars_zeros():
assert tvm.ir.structural_equal(expected, result)
+def test_empty_tuple():
+    """An empty tuple should behave just like a call with no args (see test_free_vars_zeros)."""
+    target = "test_empty_tuple"
+
+    # No annotations should be inserted, so after() builds the same module as before().
+    def before():
+        func = relay.Function([], relay.Tuple([]))
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    def after():
+        func = relay.Function([], relay.Tuple([]))
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    for annotate_non_call_ops in [True, False]:
+        result = transform.AnnotateTarget(target, annotate_non_call_ops)(before())
+        expected = transform.InferType()(after())
+        assert tvm.ir.structural_equal(expected, result)
+
+
if __name__ == "__main__":
test_extern_dnnl()
test_composite_function()
@@ -780,3 +801,4 @@ if __name__ == "__main__":
test_double_target()
test_ends_with_tuple()
test_ref_create_read_write()
+ test_empty_tuple()