You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/11/03 13:42:43 UTC

[GitHub] [incubator-tvm] codeislife99 opened a new pull request #6826: Fix Annotate Target to support freevars(zeros, ones etc) of size 0

codeislife99 opened a new pull request #6826:
URL: https://github.com/apache/incubator-tvm/pull/6826


   In the Torchvision 0.7 MaskRCNN model there are free variables (`zeros`, `ones`) of shape 0. The AnnotateTarget pass currently generates a `compiler_end` annotation (without a matching `compiler_begin`) for all such cases. This PR fixes that.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] codeislife99 commented on pull request #6826: Fix Annotate Target to support freevars(zeros, ones etc) of size 0

Posted by GitBox <gi...@apache.org>.
codeislife99 commented on pull request #6826:
URL: https://github.com/apache/incubator-tvm/pull/6826#issuecomment-720778342






----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] comaniac commented on a change in pull request #6826: Fix Annotate Target to support freevars(zeros, ones etc) of size 0

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #6826:
URL: https://github.com/apache/incubator-tvm/pull/6826#discussion_r516302304



##########
File path: src/relay/transforms/annotate_target.cc
##########
@@ -120,7 +124,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
      * \return An annotated and target-propagated relay expression.
      */
     Expr new_expr = expr;
-    if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) {
+    if (op_expr_to_target_.find(expr) != op_expr_to_target_.end() && FreeVars(expr).size() != 0) {

Review comment:
       Ditto. Please comment why we skip `FreeVars`.

##########
File path: src/relay/transforms/annotate_target.cc
##########
@@ -77,7 +77,11 @@ class AnnotateTargetRewriter : public ExprRewriter {
         compiler_ends.push_back(call->args[0]);
       } else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
         arg_target = op_expr_to_target_[arg];
-        compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op));
+        if (call && call->args.size() == 0) {

Review comment:
       Please add comments to indicate in which case we don't want `compiler_end`.

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""

Review comment:
       Move this line under `def test_if_free_vars_1():`

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""

Review comment:
       Ditto. Move to the function declaration.
   
   Also, from what is being tested in this function, it seems to me (IIUC) that the purpose of this test is the zero-shape case rather than the If-else node. If so, please update the docstring and the function name to reflect that.

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_1():
+    target = "test_free_vars_zeros_1"
+
+    """Test that free variables compile correctly on their own"""
+
+    def before():
+
+        func = relay.Function([], relay.zeros(shape=(1, 32), dtype="float32"))
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        func = relay.Function([], relay.zeros(shape=(1, 32), dtype="float32"))
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_2():

Review comment:
       Same question. Do we need this test?

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_1():
+    target = "test_free_vars_zeros_1"
+
+    """Test that free variables compile correctly on their own"""

Review comment:
       Ditto. Also, as we already have the above two tests, do we really need this one?

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_1():
+    target = "test_free_vars_zeros_1"
+
+    """Test that free variables compile correctly on their own"""
+
+    def before():
+
+        func = relay.Function([], relay.zeros(shape=(1, 32), dtype="float32"))
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+

Review comment:
       remove this line.

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_1():
+    target = "test_free_vars_zeros_1"
+
+    """Test that free variables compile correctly on their own"""
+
+    def before():
+

Review comment:
       remove this line.

##########
File path: src/relay/transforms/annotate_target.cc
##########
@@ -77,7 +77,11 @@ class AnnotateTargetRewriter : public ExprRewriter {
         compiler_ends.push_back(call->args[0]);
       } else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
         arg_target = op_expr_to_target_[arg];
-        compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op));
+        if (call && call->args.size() == 0) {

Review comment:
       Thanks. Could you be more specific, e.g. "If an argument is a call node and has no arguments, then it should be a tensor op such as zeros, so we treat it as an input var."




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] comaniac commented on pull request #6826: Fix Annotate Target to support freevars(relay.zeros, relay.ones etc) of any size (including zero)

Posted by GitBox <gi...@apache.org>.
comaniac commented on pull request #6826:
URL: https://github.com/apache/incubator-tvm/pull/6826#issuecomment-721297658


   Thanks @codeislife99 @anijain2305 


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] comaniac merged pull request #6826: Fix Annotate Target to support freevars(relay.zeros, relay.ones etc) of any size (including zero)

Posted by GitBox <gi...@apache.org>.
comaniac merged pull request #6826:
URL: https://github.com/apache/incubator-tvm/pull/6826


   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #6826: Fix Annotate Target to support freevars(relay.zeros, relay.ones etc) of any size (including zero)

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on a change in pull request #6826:
URL: https://github.com/apache/incubator-tvm/pull/6826#discussion_r516402796



##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_1():
+    target = "test_free_vars_zeros_1"
+
+    """Test that free variables compile correctly on their own"""

Review comment:
       So, FunctionNode/IfNode go through a different code path when inserting the final compiler_end. If the FunctionNode/IfNode has just one `zeros` op, it might lead to another compiler_end. That's why we need both tests.

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_1():
+    target = "test_free_vars_zeros_1"
+
+    """Test that free variables compile correctly on their own"""

Review comment:
       This is related to the FreeVars change in the source file.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] codeislife99 commented on a change in pull request #6826: Fix Annotate Target to support freevars(relay.zeros, relay.ones etc) of any size (including zero)

Posted by GitBox <gi...@apache.org>.
codeislife99 commented on a change in pull request #6826:
URL: https://github.com/apache/incubator-tvm/pull/6826#discussion_r516384510



##########
File path: src/relay/transforms/annotate_target.cc
##########
@@ -77,7 +77,11 @@ class AnnotateTargetRewriter : public ExprRewriter {
         compiler_ends.push_back(call->args[0]);
       } else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
         arg_target = op_expr_to_target_[arg];
-        compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op));
+        if (call && call->args.size() == 0) {

Review comment:
       Done. 

##########
File path: src/relay/transforms/annotate_target.cc
##########
@@ -120,7 +124,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
      * \return An annotated and target-propagated relay expression.
      */
     Expr new_expr = expr;
-    if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) {
+    if (op_expr_to_target_.find(expr) != op_expr_to_target_.end() && FreeVars(expr).size() != 0) {

Review comment:
       Done and added test cases. 

##########
File path: src/relay/transforms/annotate_target.cc
##########
@@ -77,7 +77,11 @@ class AnnotateTargetRewriter : public ExprRewriter {
         compiler_ends.push_back(call->args[0]);
       } else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
         arg_target = op_expr_to_target_[arg];
-        compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op));
+        if (call && call->args.size() == 0) {

Review comment:
       Done. 

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_1():
+    target = "test_free_vars_zeros_1"
+
+    """Test that free variables compile correctly on their own"""

Review comment:
       I removed the two unnecessary tests, but I think we need one test for a function node and one test for a regular one.

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_1():
+    target = "test_free_vars_zeros_1"
+
+    """Test that free variables compile correctly on their own"""
+
+    def before():
+
+        func = relay.Function([], relay.zeros(shape=(1, 32), dtype="float32"))
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        func = relay.Function([], relay.zeros(shape=(1, 32), dtype="float32"))
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_2():

Review comment:
       Removed. 

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_1():
+    target = "test_free_vars_zeros_1"
+
+    """Test that free variables compile correctly on their own"""
+
+    def before():
+
+        func = relay.Function([], relay.zeros(shape=(1, 32), dtype="float32"))
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+

Review comment:
       Done. 

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_1():
+    target = "test_free_vars_zeros_1"
+
+    """Test that free variables compile correctly on their own"""
+
+    def before():
+

Review comment:
       Done. 

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""

Review comment:
       Removed this test; it is not necessary.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org