codeislife99 commented on a change in pull request #6826:
URL: https://github.com/apache/incubator-tvm/pull/6826#discussion_r516384510



##########
File path: src/relay/transforms/annotate_target.cc
##########
@@ -77,7 +77,11 @@ class AnnotateTargetRewriter : public ExprRewriter {
         compiler_ends.push_back(call->args[0]);
       } else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
         arg_target = op_expr_to_target_[arg];
-        compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op));
+        if (call && call->args.size() == 0) {

Review comment:
       Done. 

##########
File path: src/relay/transforms/annotate_target.cc
##########
@@ -120,7 +124,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
      * \return An annotated and target-propagated relay expression.
      */
     Expr new_expr = expr;
-    if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) {
+    if (op_expr_to_target_.find(expr) != op_expr_to_target_.end() && FreeVars(expr).size() != 0) {

Review comment:
       Done and added test cases. 

##########
File path: src/relay/transforms/annotate_target.cc
##########
@@ -77,7 +77,11 @@ class AnnotateTargetRewriter : public ExprRewriter {
         compiler_ends.push_back(call->args[0]);
       } else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
         arg_target = op_expr_to_target_[arg];
-        compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op));
+        if (call && call->args.size() == 0) {

Review comment:
       Done. 

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_1():
+    target = "test_free_vars_zeros_1"
+
+    """Test that free variables compile correctly on their own"""

Review comment:
       I removed the two unnecessary tests, but I think we still need one test for a
fn node and one test for a regular one.

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_1():
+    target = "test_free_vars_zeros_1"
+
+    """Test that free variables compile correctly on their own"""
+
+    def before():
+
+        func = relay.Function([], relay.zeros(shape=(1, 32), dtype="float32"))
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        func = relay.Function([], relay.zeros(shape=(1, 32), dtype="float32"))
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_2():

Review comment:
       Removed. 

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_1():
+    target = "test_free_vars_zeros_1"
+
+    """Test that free variables compile correctly on their own"""
+
+    def before():
+
+        func = relay.Function([], relay.zeros(shape=(1, 32), dtype="float32"))
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+

Review comment:
       Done. 

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_1():
+    target = "test_free_vars_zeros_1"
+
+    """Test that free variables compile correctly on their own"""
+
+    def before():
+

Review comment:
       Done. 

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free variables"""

Review comment:
       Removed this test; it is not necessary.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to