This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 95de08b  Fix alpha_equal bug (#4897)
95de08b is described below

commit 95de08ba4f0d90dde308f4b2b401da8aaa333d2b
Author: Zhi <[email protected]>
AuthorDate: Sun Feb 16 17:44:22 2020 -0800

    Fix alpha_equal bug (#4897)
---
 src/relay/ir/alpha_equal.cc                     |  2 +-
 tests/python/relay/test_ir_nodes.py             |  2 +
 tests/python/relay/test_pass_alpha_equal.py     | 25 ++++++-
 tests/python/relay/test_pass_fuse_ops.py        | 36 +++++++++-
 tests/python/relay/test_pass_merge_composite.py | 93 +++++++++++++------------
 5 files changed, 109 insertions(+), 49 deletions(-)

diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc
index 48634ba..78688d7 100644
--- a/src/relay/ir/alpha_equal.cc
+++ b/src/relay/ir/alpha_equal.cc
@@ -92,7 +92,7 @@ class AlphaEqualHandler:
     auto compute = [&]() {
       if (&lhs == &rhs) return true;
       if (auto lhsd = lhs.as<DictAttrsNode>()) {
-        auto rhsd = lhs.as<DictAttrsNode>();
+        auto rhsd = rhs.as<DictAttrsNode>();
         if (!rhsd) return false;
         if (lhsd->dict.size() != rhsd->dict.size()) return false;
         for (const auto& k : lhsd->dict) {
diff --git a/tests/python/relay/test_ir_nodes.py 
b/tests/python/relay/test_ir_nodes.py
index ad15255..bdda72c 100644
--- a/tests/python/relay/test_ir_nodes.py
+++ b/tests/python/relay/test_ir_nodes.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """ test ir"""
+import pytest
 import tvm
 from tvm import relay
 from tvm.tir.expr import *
@@ -174,6 +175,7 @@ def test_function():
     str(fn)
     check_json_roundtrip(fn)
 
[email protected](reason="AttrsEqualHandler doesn't handle Map so far.")
 def test_function_attrs():
     param_names = ['a', 'b', 'c', 'd']
     params = tvm.convert([relay.var(n, shape=(5, 2)) for n in param_names])
diff --git a/tests/python/relay/test_pass_alpha_equal.py 
b/tests/python/relay/test_pass_alpha_equal.py
index 5985273..0319d0b 100644
--- a/tests/python/relay/test_pass_alpha_equal.py
+++ b/tests/python/relay/test_pass_alpha_equal.py
@@ -18,6 +18,7 @@ import numpy as np
 import tvm
 from tvm import relay
 from tvm.relay import analysis
+from tvm.relay.testing import run_opt_pass
 
 def alpha_equal(x, y):
     """
@@ -313,7 +314,7 @@ def test_tuple_get_item_alpha_equal():
     assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))
 
 
-def test_multi_node_subgraph():
+def test_function_attr():
     x0 = relay.var('x0', shape=(10, 10))
     w00 = relay.var('w00', shape=(10, 10))
     w01 = relay.var('w01', shape=(10, 10))
@@ -608,6 +609,7 @@ def test_graph_equal():
     z3 = relay.add(relay.add(x, x), relay.add(x, x))
 
     assert alpha_equal(z0, z1)
+    assert alpha_equal(z0, z1)
 
     # z3's dataflow format is different from z0
     # z0 is computed from a common y0 node
@@ -649,6 +651,26 @@ def test_tuple_match():
     assert analysis.structural_hash(x) == analysis.structural_hash(y)
 
 
+def test_fn_attribute():
+    # create function that performs add
+    a = relay.var('a', shape=(10, 10))
+    b = relay.var('b', shape=(10, 10))
+    add = relay.add(a, b)
+    add_fn = relay.Function([a, b], add)
+    add_fn = run_opt_pass(add_fn, relay.transform.InferType())
+
+    # create function that performs add with test attribute
+    c = relay.var('c', shape=(10, 10))
+    d = relay.var('d', shape=(10, 10))
+    add_1 = relay.add(c, d)
+    add_1_fn = relay.Function([c, d], add_1)
+    add_1_fn = add_1_fn.set_attribute("TestAttribute", 
tvm.tir.StringImm("test"))
+    add_1_fn = run_opt_pass(add_1_fn, relay.transform.InferType())
+
+    assert not relay.analysis.alpha_equal(add_1_fn, add_fn)
+    assert not relay.analysis.alpha_equal(add_fn, add_1_fn)
+
+
 if __name__ == "__main__":
     test_tensor_type_alpha_equal()
     test_incomplete_type_alpha_equal()
@@ -672,3 +694,4 @@ if __name__ == "__main__":
     test_var_alpha_equal()
     test_graph_equal()
     test_hash_unequal()
+    test_fn_attribute()
diff --git a/tests/python/relay/test_pass_fuse_ops.py 
b/tests/python/relay/test_pass_fuse_ops.py
index 18916f7..e11b6ae 100644
--- a/tests/python/relay/test_pass_fuse_ops.py
+++ b/tests/python/relay/test_pass_fuse_ops.py
@@ -35,6 +35,7 @@ def test_fuse_simple():
         z = relay.exp(y)
         w = relay.squeeze(z)
         f1 = relay.Function([x], w)
+        f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
         x = relay.var("x", shape=(10, 20))
         y = relay.Call(f1, [x])
         return relay.Function([x], y)
@@ -76,6 +77,8 @@ def test_conv2d_fuse():
         x = relay.var("p0", shape=dshape)
         y = relay.add(x, relay.const(1, "float32"))
         f0 = relay.Function([x], y)
+        f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
+
         # segment 1
         x = relay.var("p0", shape=dshape)
         w = relay.var("p1")
@@ -86,6 +89,8 @@ def test_conv2d_fuse():
         y1 = relay.add(relay.const(1, "float32"), y)
         y = relay.add(y, y1)
         f1 = relay.Function([x, w], y)
+        f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
+
         # segment 2
         x = relay.var("p0", shape=dshape)
         w = relay.var("p1")
@@ -94,6 +99,8 @@ def test_conv2d_fuse():
                              padding=(1,1),
                              channels=16)
         f2 = relay.Function([x, w], z2)
+        f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
+
         # segment 3
         x = relay.var("p0", shape=dshape)
         w = relay.var("p1")
@@ -104,6 +111,8 @@ def test_conv2d_fuse():
                              channels=16)
         z3 = relay.add(z3, offset)
         f3 = relay.Function([x, w, offset], z3)
+        f3 = f3.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
+
         # compose
         x = relay.var("x", shape=dshape)
         y = relay.Call(f0, [x])
@@ -135,6 +144,7 @@ def test_concatenate():
         x = relay.var("x", shape=dshape)
         pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), 
padding=(0, 0))
         f0 = relay.Function([x], pooled)
+        f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
 
         p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, 
dshape[3]//2))
         p1 = relay.var("p1", shape=dshape)
@@ -142,6 +152,7 @@ def test_concatenate():
         concat = relay.concatenate((upsampled, p1), axis=1)
         out = relay.add(concat, relay.const(1, "float32"))
         f1 = relay.Function([p0, p1], out)
+        f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
 
         x = relay.var("x", shape=dshape)
         y = relay.Call(f0, [x])
@@ -172,10 +183,12 @@ def test_tuple_root():
         x = relay.var("x", shape=dshape)
         pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), 
padding=(0, 0))
         f0 = relay.Function([x], pooled)
+        f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
 
         p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, 
dshape[3]//2))
         upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, 
layout="NCHW")
         f1 = relay.Function([p0], upsampled)
+        f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
 
         x = relay.var("x", shape=dshape)
         y = relay.Call(f0, [x])
@@ -205,10 +218,12 @@ def test_stop_fusion():
         x = relay.var("p0", shape=dshape)
         y = relay.add(x, relay.const(1, "float32"))
         f1 = relay.Function([x], y)
+        f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
 
         x = relay.var("p01", shape=dshape)
         y = relay.exp(x)
         f2 = relay.Function([x], y)
+        f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
 
         x = relay.var("x", shape=dshape)
         y = relay.Call(f1, [x])
@@ -242,6 +257,7 @@ def test_fuse_myia_regression():
         p2 = relay.var('p2', shape=dshape, dtype=dtype)
         fused_gt = relay.Function([p1, p2],
             relay.op.greater(p1, p2))
+        fused_gt = fused_gt.set_attribute("Primitive", tvm.tir.IntImm("int32", 
1))
         with sb.if_scope(fused_gt(x, y)):
             sb.ret(relay.Function([], x))
         with sb.else_scope():
@@ -271,11 +287,13 @@ def test_fuse_tuple_get_elemwise():
         p1 = relay.var("p1", shape=(3 * dim, dim))
         matmul = relay.nn.dense(p0, p1)
         f0 = relay.Function([p0, p1], matmul)
+        f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
 
         p01 = relay.var("p01", shape=(1, 3 * dim))
         splitted = relay.split(p01, indices_or_sections=3, axis=1)
         out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * 
relay.exp(splitted[2])
         f1 = relay.Function([p01], out)
+        f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
 
         X = relay.var("X", shape=(1, dim))
         W = relay.var("W", shape=(3 * dim, dim))
@@ -306,11 +324,13 @@ def test_tuple_get_root():
         splitted = relay.split(p0, indices_or_sections=3, axis=1)
         out = splitted[0]
         f0 = relay.Function([p0], out)
+        f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
 
         p01 = relay.var("p01", shape=(1, dim))
         p1 = relay.var("p1", shape=(dim, dim))
         out = relay.nn.dense(p01, p1)
         f1 = relay.Function([p01, p1], out)
+        f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
 
         X = relay.var("X", shape=(1, 3 * dim))
         W = relay.var("W", shape=(dim, dim))
@@ -346,8 +366,9 @@ def test_tuple_intermediate():
 
     def expected(p0):
         f0 = before(p0)
+        f1 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
         x = relay.var("x", shape=dshape)
-        y = relay.Call(f0, [x])
+        y = relay.Call(f1, [x])
         return relay.Function([x], y)
 
     dshape = (1, 16, 64, 64)
@@ -388,15 +409,18 @@ def test_tuple_consecutive():
         p0 = relay.var("p0", shape=dshape)
         concat = gen_consecutive_tuple(p0)
         f0 = relay.Function([p0], concat)
+        f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
 
         p01 = relay.var("p01", shape=(1, dshape[1]*9, dshape[2], dshape[3]))
         pooled = relay.nn.max_pool2d(p01, pool_size=(2, 2), strides=(2, 2), 
padding=(0, 0))
         out = relay.add(pooled, relay.const(1, "float32"))
         f1 = relay.Function([p01], out)
+        f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
 
         p02 = relay.var("p02", shape=(1, dshape[1]*9, dshape[2]//2, 
dshape[3]//2))
         out = relay.add(p02, relay.const(1, "float32"))
         f2 = relay.Function([p02], out)
+        f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
 
         x = relay.var("x", shape=dshape)
         y = relay.Call(f0, [x])
@@ -438,30 +462,36 @@ def test_inception_like():
         p0 = relay.var("p0", shape=dshape)
         c = conv(p0)
         f0 = relay.Function(relay.analysis.free_vars(c), c)
+        f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
 
         p01 = relay.var("p01", shape=dshape)
         c = conv(p01)
         f1 = relay.Function(relay.analysis.free_vars(c), c)
+        f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
 
         p02 = relay.var("p02", shape=dshape)
         p12 = relay.var("p12", shape=dshape)
         concat1 = relay.concatenate((p02, p12), axis=1)
         f_concat1 = relay.Function([p02, p12], concat1)
+        f_concat1 = f_concat1.set_attribute("Primitive", 
tvm.tir.IntImm("int32", 1))
 
         dshape2 = (dshape[0], dshape[1]*2, dshape[2], dshape[3])
 
         p03 = relay.var("p03", shape=dshape2)
         c = conv(p03)
         f2 = relay.Function(relay.analysis.free_vars(c), c)
+        f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
 
         p04 = relay.var("p04", shape=dshape2)
         c = conv(p04)
         f3 = relay.Function(relay.analysis.free_vars(c), c)
+        f3 = f3.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
 
         p05 = relay.var("p05", shape=dshape)
         p15 = relay.var("p15", shape=dshape)
         concat2 = relay.concatenate((p05, p15), axis=1)
         f_concat2 = relay.Function([p05, p15], concat2)
+        f_concat2 = f_concat2.set_attribute("Primitive", 
tvm.tir.IntImm("int32", 1))
 
         x = relay.var("x", shape=dshape)
         c1 = relay.Call(f0, [x, relay.var("w1")])
@@ -499,6 +529,7 @@ def test_fuse_parallel_injective():
         u = relay.transpose(y, axes=[0, 1])
         w = relay.left_shift(z, u)
         f1 = relay.Function([x], w)
+        f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
         x = relay.var("x", shape=(10, 20))
         y = relay.Call(f1, [x])
         return relay.Function([x], y)
@@ -529,6 +560,7 @@ def test_immutable():
         z = relay.exp(y)
         w = relay.squeeze(z)
         f1 = relay.Function([x], w)
+        f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
         x = relay.var("x", shape=(10, 20))
         y = relay.Call(f1, [x])
         mod = tvm.IRModule()
@@ -570,6 +602,7 @@ def test_fuse_max():
         for i in range(max_fused_ops):
             y = relay.exp(y)
         f1 = relay.Function([x], y)
+        f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
         x = relay.var("x", shape=(10, 20))
         z = relay.Call(f1, [x])
         xx = relay.var("pp", shape=(10, 20))
@@ -577,6 +610,7 @@ def test_fuse_max():
         for i in range(n-max_fused_ops):
             yy = relay.exp(yy)
         f2 = relay.Function([xx], yy)
+        f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
         zz = relay.Call(f2, [z])
         return relay.Function([x], zz)
 
diff --git a/tests/python/relay/test_pass_merge_composite.py 
b/tests/python/relay/test_pass_merge_composite.py
index 4f785d7..4f5acc7 100644
--- a/tests/python/relay/test_pass_merge_composite.py
+++ b/tests/python/relay/test_pass_merge_composite.py
@@ -15,8 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """Unit tests for merge composite."""
-from tvm import expr
 from tvm import relay
+from tvm import tir
 from tvm.relay.testing import run_opt_pass
 
 """
@@ -144,6 +144,8 @@ def test_simple_merge():
         add_node = relay.add(in_1, in_2)
         relu_node = relay.nn.relu(add_node)
         add_relu = relay.Function([in_1, in_2], relu_node)
+        add_relu = add_relu.set_attribute("Primitive", tir.IntImm("int32", 1))
+        add_relu = add_relu.set_attribute("Composite", 
tir.StringImm("add_relu"))
 
         # merged function
         r = relay.Call(add_relu, [a, b])
@@ -208,11 +210,27 @@ def test_branch_merge():
         sub_node = relay.subtract(in_1, in_2)
         mul_node = relay.multiply(add_node, sub_node)
         add_sub_mul = relay.Function([in_1, in_2], mul_node)
+        add_sub_mul = add_sub_mul.set_attribute("Primitive",
+                                                tir.IntImm("int32", 1))
+        add_sub_mul = add_sub_mul.set_attribute("Composite",
+                                                tir.StringImm("add_sub_mul"))
+
+        # add_sub_mul1 function
+        in_3 = relay.var('in_3', shape=(10, 10))
+        in_4 = relay.var('in_4', shape=(10, 10))
+        add_node_1 = relay.add(in_3, in_4)
+        sub_node_1 = relay.subtract(in_3, in_4)
+        mul_node_1 = relay.multiply(add_node_1, sub_node_1)
+        add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1)
+        add_sub_mul_1 = add_sub_mul_1.set_attribute("Primitive",
+                                                    tir.IntImm("int32", 1))
+        add_sub_mul_1 = add_sub_mul_1.set_attribute("Composite",
+                                                    
tir.StringImm("add_sub_mul"))
 
         # merged function
-        add_sub_mul_1 = relay.Call(add_sub_mul, [a, b])
-        add_sub_mul_2 = relay.Call(add_sub_mul, [c, add_sub_mul_1])
-        r = relay.nn.relu(add_sub_mul_2)
+        m_add_sub_mul_1 = relay.Call(add_sub_mul, [a, b])
+        m_add_sub_mul_2 = relay.Call(add_sub_mul_1, [c, m_add_sub_mul_1])
+        r = relay.nn.relu(m_add_sub_mul_2)
         return relay.Function([a, b, c], r)
 
     result = run_opt_pass(before(), 
relay.transform.MergeComposite(pattern_table))
@@ -291,6 +309,9 @@ def test_multiple_patterns():
         bias_node = relay.nn.bias_add(conv_node, in_3)
         r = relay.nn.relu(bias_node)
         conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
+        conv_bias_add_relu = conv_bias_add_relu.set_attribute("Primitive", 
tir.IntImm("int32", 1))
+        conv_bias_add_relu = conv_bias_add_relu.set_attribute("Composite",
+                                                              
tir.StringImm("conv2d_bias_relu"))
 
         # add_relu function
         in_4 = relay.var('in_4', shape=(1, 256, 28, 28))
@@ -298,6 +319,8 @@ def test_multiple_patterns():
         add_node = relay.add(in_4, in_5)
         r = relay.nn.relu(add_node)
         add_relu = relay.Function([in_4, in_5], r)
+        add_relu = add_relu.set_attribute("Primitive", tir.IntImm("int32", 1))
+        add_relu = add_relu.set_attribute("Composite", 
tir.StringImm("add_relu"))
 
         # merged function
         conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, 
bias])
@@ -357,7 +380,7 @@ def test_merge_order():
         out = relay.nn.relu(out)
         return relay.Function([input_1, input_2], out)
 
-    def after_A_priority():
+    def after_A_priority(composite_name):
         input_1 = relay.var('input_1', shape=(10, 10))
         input_2 = relay.var('input_2', shape=(10, 10))
         x = relay.var('x')
@@ -366,38 +389,12 @@ def test_merge_order():
         out = relay.abs(out)
         out = relay.nn.relu(out)
         merged_func = relay.Function([x, y], out)
-        merged_func = merged_func.set_attribute('Primitive', 
expr.IntImm('int32', 1))
-        merged_func = merged_func.set_attribute('Composite', 
expr.StringImm('A'))
+        merged_func = merged_func.set_attribute('Primitive', 
tir.IntImm('int32', 1))
+        merged_func = merged_func.set_attribute('Composite',
+                                                tir.StringImm(composite_name))
         ret = relay.Call(merged_func, [input_1, input_2])
         return relay.Function([input_1, input_2], ret)
 
-    def after_B_priority():
-        input_1 = relay.var('input_1', shape=(10, 10))
-        input_2 = relay.var('input_2', shape=(10, 10))
-        x = relay.var('x')
-        y = relay.var('y')
-        out = relay.add(x, y)
-        out = relay.abs(out)
-        merged_func = relay.Function([x, y], out)
-        merged_func = merged_func.set_attribute('Primitive', 
expr.IntImm('int32', 1))
-        merged_func = merged_func.set_attribute('Composite', 
expr.StringImm('B'))
-        merged_call = relay.Call(merged_func, [input_1, input_2])
-        ret = relay.nn.relu(merged_call)
-        return relay.Function([input_1, input_2], ret)
-
-    def after_C_priority():
-        input_1 = relay.var('input_1', shape=(10, 10))
-        input_2 = relay.var('input_2', shape=(10, 10))
-        add = relay.add(input_1, input_2)
-        x = relay.var('x')
-        out = relay.abs(x)
-        out = relay.nn.relu(out)
-        merged_func = relay.Function([x], out)
-        merged_func = merged_func.set_attribute('Primitive', 
expr.IntImm('int32', 1))
-        merged_func = merged_func.set_attribute('Composite', 
expr.StringImm('C'))
-        ret = relay.Call(merged_func, [add])
-        return relay.Function([input_1, input_2], ret)
-
     # check A highest priority
     pattern_table = [
         ("A", pattern_A()),
@@ -406,7 +403,7 @@ def test_merge_order():
     ]
     result = run_opt_pass(before(), 
relay.transform.MergeComposite(pattern_table))
     assert not relay.analysis.free_vars(result)
-    expected = run_opt_pass(after_A_priority(), relay.transform.InferType())
+    expected = run_opt_pass(after_A_priority("A"), relay.transform.InferType())
     assert relay.analysis.alpha_equal(result, expected)
 
     # check B highest priority
@@ -417,7 +414,7 @@ def test_merge_order():
     ]
     result = run_opt_pass(before(), 
relay.transform.MergeComposite(pattern_table))
     assert not relay.analysis.free_vars(result)
-    expected = run_opt_pass(after_A_priority(), relay.transform.InferType())
+    expected = run_opt_pass(after_A_priority("B"), relay.transform.InferType())
     assert relay.analysis.alpha_equal(result, expected)
 
     # check C highest priority
@@ -428,7 +425,7 @@ def test_merge_order():
     ]
     result = run_opt_pass(before(), 
relay.transform.MergeComposite(pattern_table))
     assert not relay.analysis.free_vars(result)
-    expected = run_opt_pass(after_A_priority(), relay.transform.InferType())
+    expected = run_opt_pass(after_A_priority("C"), relay.transform.InferType())
     assert relay.analysis.alpha_equal(result, expected)
 
 
@@ -459,11 +456,15 @@ def test_parallel_merge():
         y = relay.var('y')
         branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y))
         func_1 = relay.Function([x, y], branch_1)
+        func_1 = func_1.set_attribute('Primitive', tir.IntImm('int32', 1))
+        func_1 = func_1.set_attribute('Composite', 
tir.StringImm("add_sub_mul"))
         call_1 = relay.Call(func_1, [input_1, input_2])
         x1 = relay.var('x1')
         y1 = relay.var('y1')
         branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1))
         func_2 = relay.Function([x1, y1], branch_2)
+        func_2 = func_2.set_attribute('Primitive', tir.IntImm('int32', 1))
+        func_2 = func_2.set_attribute('Composite', 
tir.StringImm("add_sub_mul"))
         call_2 = relay.Call(func_2, [input_1, input_2])
         out = relay.multiply(call_1, call_2)
         return relay.Function([input_1, input_2], out)
@@ -542,16 +543,16 @@ def test_multiple_input_subgraphs():
         add_relu_1 = relay.add(x, y)
         add_relu_1 = relay.nn.relu(add_relu_1)
         add_relu_1 = relay.Function([x, y], add_relu_1)
-        add_relu_1 = add_relu_1.set_attribute('Primitive', 
expr.IntImm('int32', 1))
-        add_relu_1 = add_relu_1.set_attribute('Composite', 
expr.StringImm('add_relu'))
+        add_relu_1 = add_relu_1.set_attribute('Primitive', tir.IntImm('int32', 
1))
+        add_relu_1 = add_relu_1.set_attribute('Composite', 
tir.StringImm('add_relu'))
         add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]])
         x1 = relay.var('x1')
         y1 = relay.var('y1')
         add_relu_2 = relay.add(x1, y1)
         add_relu_2 = relay.nn.relu(add_relu_2)
         add_relu_2 = relay.Function([x1, y1], add_relu_2)
-        add_relu_2 = add_relu_2.set_attribute('Primitive', 
expr.IntImm('int32', 1))
-        add_relu_2 = add_relu_2.set_attribute('Composite', 
expr.StringImm('add_relu'))
+        add_relu_2 = add_relu_2.set_attribute('Primitive', tir.IntImm('int32', 
1))
+        add_relu_2 = add_relu_2.set_attribute('Composite', 
tir.StringImm('add_relu'))
         add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]])
         x2 = relay.var('x2')
         y2 = relay.var('y2')
@@ -559,8 +560,8 @@ def test_multiple_input_subgraphs():
         sub = relay.subtract(x2, y2)
         add_sub_mul = relay.multiply(add, sub)
         add_sub_mul = relay.Function([x2, y2], add_sub_mul)
-        add_sub_mul = add_sub_mul.set_attribute('Primitive', 
expr.IntImm('int32', 1))
-        add_sub_mul = add_sub_mul.set_attribute('Composite', 
expr.StringImm('add_sub_mul'))
+        add_sub_mul = add_sub_mul.set_attribute('Primitive', 
tir.IntImm('int32', 1))
+        add_sub_mul = add_sub_mul.set_attribute('Composite', 
tir.StringImm('add_sub_mul'))
         add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, 
add_relu_call_2])
         return relay.Function(inputs, add_sub_mul_call)
 
@@ -573,8 +574,8 @@ def test_multiple_input_subgraphs():
             add_relu = relay.add(x, y)
             add_relu = relay.nn.relu(add_relu)
             add_relu = relay.Function([x, y], add_relu)
-            add_relu = add_relu.set_attribute('Primitive', 
expr.IntImm('int32', 1))
-            add_relu = add_relu.set_attribute('Composite', 
expr.StringImm('add_relu'))
+            add_relu = add_relu.set_attribute('Primitive', tir.IntImm('int32', 
1))
+            add_relu = add_relu.set_attribute('Composite', 
tir.StringImm('add_relu'))
             add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]])
             add_relu_calls.append(add_relu_call)
 
@@ -606,4 +607,4 @@ if __name__ == "__main__":
     test_multiple_patterns()
     test_merge_order()
     test_parallel_merge()
-    test_multiple_input_subgraphs()
\ No newline at end of file
+    test_multiple_input_subgraphs()

Reply via email to