You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/11/03 13:49:08 UTC

[GitHub] [incubator-tvm] codeislife99 commented on a change in pull request #6826: Fix Annotate Target to support freevars(relay.zeros, relay.ones etc) of any size (including zero)

codeislife99 commented on a change in pull request #6826:
URL: https://github.com/apache/incubator-tvm/pull/6826#discussion_r516384510



##########
File path: src/relay/transforms/annotate_target.cc
##########
@@ -77,7 +77,11 @@ class AnnotateTargetRewriter : public ExprRewriter {
         compiler_ends.push_back(call->args[0]);
       } else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
         arg_target = op_expr_to_target_[arg];
-        compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op));
+        if (call && call->args.size() == 0) {

Review comment:
       Done. 

##########
File path: src/relay/transforms/annotate_target.cc
##########
@@ -120,7 +124,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
      * \return An annotated and target-propagated relay expression.
      */
     Expr new_expr = expr;
-    if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) {
+    if (op_expr_to_target_.find(expr) != op_expr_to_target_.end() && FreeVars(expr).size() != 0) {

Review comment:
       Done and added test cases. 

##########
File path: src/relay/transforms/annotate_target.cc
##########
@@ -77,7 +77,11 @@ class AnnotateTargetRewriter : public ExprRewriter {
         compiler_ends.push_back(call->args[0]);
       } else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
         arg_target = op_expr_to_target_[arg];
-        compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op));
+        if (call && call->args.size() == 0) {

Review comment:
       Done. 

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_1():
+    target = "test_free_vars_zeros_1"
+
+    """Test that free variables compile correctly on their own"""

Review comment:
       I removed the two unnecessary tests, but I think we still need one test for a function (`fn`) node and one test for a regular node.

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_1():
+    target = "test_free_vars_zeros_1"
+
+    """Test that free variables compile correctly on their own"""
+
+    def before():
+
+        func = relay.Function([], relay.zeros(shape=(1, 32), dtype="float32"))
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        func = relay.Function([], relay.zeros(shape=(1, 32), dtype="float32"))
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_2():

Review comment:
       Removed. 

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_1():
+    target = "test_free_vars_zeros_1"
+
+    """Test that free variables compile correctly on their own"""
+
+    def before():
+
+        func = relay.Function([], relay.zeros(shape=(1, 32), dtype="float32"))
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+

Review comment:
       Done. 

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_1():
+    target = "test_free_vars_zeros_1"
+
+    """Test that free variables compile correctly on their own"""
+
+    def before():
+

Review comment:
       Done. 

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""

Review comment:
       Removed this test; it is not necessary.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org