comaniac commented on a change in pull request #6826:
URL: https://github.com/apache/incubator-tvm/pull/6826#discussion_r516302304



##########
File path: src/relay/transforms/annotate_target.cc
##########
@@ -120,7 +124,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
      * \return An annotated and target-propagated relay expression.
      */
     Expr new_expr = expr;
-    if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) {
+    if (op_expr_to_target_.find(expr) != op_expr_to_target_.end() && 
FreeVars(expr).size() != 0) {

Review comment:
       Ditto. Please add a comment explaining why expressions without free variables (i.e. `FreeVars(expr).size() == 0`) are skipped here.

##########
File path: src/relay/transforms/annotate_target.cc
##########
@@ -77,7 +77,11 @@ class AnnotateTargetRewriter : public ExprRewriter {
         compiler_ends.push_back(call->args[0]);
       } else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
         arg_target = op_expr_to_target_[arg];
-        compiler_ends.push_back(InsertAnnotation(arg, arg_target, 
make_end_op));
+        if (call && call->args.size() == 0) {

Review comment:
       Please add a comment indicating in which cases we don't want to insert a `compiler_end`.

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free 
variables"""

Review comment:
       Move this line under `def test_if_free_vars_1():`

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free 
variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free 
variables"""

Review comment:
       Ditto. Move it directly under the function declaration.

   Also, judging from what this function tests, IIUC its purpose is to cover a zero-shape tensor rather than the If-else node. If so, please update the docstring and the function name to reflect that.

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free 
variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free 
variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_1():
+    target = "test_free_vars_zeros_1"
+
+    """Test that free variables compile correctly on their own"""
+
+    def before():
+
+        func = relay.Function([], relay.zeros(shape=(1, 32), dtype="float32"))
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        func = relay.Function([], relay.zeros(shape=(1, 32), dtype="float32"))
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_2():

Review comment:
       Same question. Do we need this test?

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free 
variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free 
variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_1():
+    target = "test_free_vars_zeros_1"
+
+    """Test that free variables compile correctly on their own"""

Review comment:
       Ditto. Also, since we already have the two tests above, do we really need this one?

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free 
variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free 
variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_1():
+    target = "test_free_vars_zeros_1"
+
+    """Test that free variables compile correctly on their own"""
+
+    def before():
+
+        func = relay.Function([], relay.zeros(shape=(1, 32), dtype="float32"))
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+

Review comment:
       Remove this blank line.

##########
File path: tests/python/relay/test_pass_annotate_target.py
##########
@@ -510,6 +510,188 @@ def after():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_if_free_vars_1():
+    target = "test_if_free_vars_1"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free 
variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_if_free_vars_2():
+    target = "test_if_free_vars_2"
+
+    @tvm.ir.register_op_attr("equal", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("sigmoid", "target." + target)
+    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @tvm.ir.register_op_attr("erf", "target." + target)
+    def erf(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that If-else nodes compiles correctly when surrounded by free 
variables"""
+
+    def before():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+        eq = relay.equal(eq1, eq2)
+
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+        false_branch = relay.sigmoid(data)
+        ife = relay.If(eq, true_branch, false_branch)
+        out = relay.erf(ife)
+
+        func = relay.Function([data, eq1, eq2], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        return mod
+
+    def after():
+
+        data = relay.var("data", shape=(1, 32))
+        eq1 = relay.var("e1", shape=[], dtype="float32")
+        eq2 = relay.var("e2", shape=[], dtype="float32")
+
+        cb_1 = relay.annotation.compiler_begin(eq1, target)
+        cb_2 = relay.annotation.compiler_begin(eq2, target)
+
+        equality_condition = relay.equal(cb_1, cb_2)
+        ce_1 = relay.annotation.compiler_end(equality_condition, target)
+
+        # if condition
+        true_branch = relay.zeros(shape=(0), dtype="float32")
+
+        # else condition
+        cb_3 = relay.annotation.compiler_begin(data, target)
+        false_branch = relay.sigmoid(cb_3)
+        ce_2 = relay.annotation.compiler_end(false_branch, target)
+
+        if_condition = relay.If(ce_1, true_branch, ce_2)
+        cb_4 = relay.annotation.compiler_begin(if_condition, target)
+        erf_out = relay.erf(cb_4)
+        ce_3 = relay.annotation.compiler_end(erf_out, target)
+        func = relay.Function([data, eq1, eq2], ce_3)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
+def test_free_vars_zeros_1():
+    target = "test_free_vars_zeros_1"
+
+    """Test that free variables compile correctly on their own"""
+
+    def before():
+

Review comment:
       Remove this blank line.

##########
File path: src/relay/transforms/annotate_target.cc
##########
@@ -77,7 +77,11 @@ class AnnotateTargetRewriter : public ExprRewriter {
         compiler_ends.push_back(call->args[0]);
       } else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
         arg_target = op_expr_to_target_[arg];
-        compiler_ends.push_back(InsertAnnotation(arg, arg_target, 
make_end_op));
+        if (call && call->args.size() == 0) {

Review comment:
       Thanks. Could you be more specific, e.g. "If an argument is a call node with no arguments, then it should be a tensor op such as zeros, so we treat it as an input var."




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to