This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new a9cfe41 [RELAY] Remove primitive attribute from composite function (#5014)
a9cfe41 is described below
commit a9cfe4158badf906fdb72a2450db9335bd9386ad
Author: lhutton1 <[email protected]>
AuthorDate: Tue Mar 10 08:10:07 2020 +0000
[RELAY] Remove primitive attribute from composite function (#5014)
* A composite function should not be primitive since we still may need to
perform passes on it.
Change-Id: If62d06d265234861a6ec0df7749dc1c339c1055c
---
src/relay/pass/merge_composite.cc | 1 -
tests/python/relay/test_pass_merge_composite.py | 16 ----------------
2 files changed, 17 deletions(-)
diff --git a/src/relay/pass/merge_composite.cc b/src/relay/pass/merge_composite.cc
index 4e1094b..162bf3a 100644
--- a/src/relay/pass/merge_composite.cc
+++ b/src/relay/pass/merge_composite.cc
@@ -168,7 +168,6 @@ class MergeCompositeWrapper : public ExprMutator {
// make the composite function
auto f = FunctionNode::make(free_vars, extract, call->checked_type_, {},
Attrs());
f = FunctionSetAttr(f, attr::kComposite,
tir::StringImmNode::make(pattern_name_));
- f = FunctionSetAttr(f, attr::kPrimitive, tvm::Integer(1));
// find the expressions associated with the free vars using the args_map
// this tells us which expressions should be given as inputs to the composite function
Array<Expr> args;
diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py
index b96a89b..bcf61a0 100644
--- a/tests/python/relay/test_pass_merge_composite.py
+++ b/tests/python/relay/test_pass_merge_composite.py
@@ -164,7 +164,6 @@ def test_simple_merge():
add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node)
- add_relu = add_relu.set_attribute("Primitive", tir.IntImm("int32", 1))
add_relu = add_relu.set_attribute("Composite",
tir.StringImm("add_relu"))
# merged function
@@ -230,8 +229,6 @@ def test_branch_merge():
sub_node = relay.subtract(in_1, in_2)
mul_node = relay.multiply(add_node, sub_node)
add_sub_mul = relay.Function([in_1, in_2], mul_node)
- add_sub_mul = add_sub_mul.set_attribute("Primitive",
- tir.IntImm("int32", 1))
add_sub_mul = add_sub_mul.set_attribute("Composite",
tir.StringImm("add_sub_mul"))
@@ -242,8 +239,6 @@ def test_branch_merge():
sub_node_1 = relay.subtract(in_3, in_4)
mul_node_1 = relay.multiply(add_node_1, sub_node_1)
add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1)
- add_sub_mul_1 = add_sub_mul_1.set_attribute("Primitive",
- tir.IntImm("int32", 1))
add_sub_mul_1 = add_sub_mul_1.set_attribute("Composite",
tir.StringImm("add_sub_mul"))
@@ -304,8 +299,6 @@ def test_reuse_call_merge():
add_node_1 = relay.add(in_1, add_node)
add_node_2 = relay.add(add_node_1, add_node)
add_add_add = relay.Function([in_1, in_2], add_node_2)
- add_add_add = add_add_add.set_attribute("Primitive",
- tir.IntImm("int32", 1))
add_add_add = add_add_add.set_attribute("Composite",
tir.StringImm("add_add_add"))
@@ -390,7 +383,6 @@ def test_multiple_patterns():
bias_node = relay.nn.bias_add(conv_node, in_3)
r = relay.nn.relu(bias_node)
conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
- conv_bias_add_relu = conv_bias_add_relu.set_attribute("Primitive", tir.IntImm("int32", 1))
conv_bias_add_relu = conv_bias_add_relu.set_attribute("Composite",
tir.StringImm("conv2d_bias_relu"))
@@ -400,7 +392,6 @@ def test_multiple_patterns():
add_node = relay.add(in_4, in_5)
r = relay.nn.relu(add_node)
add_relu = relay.Function([in_4, in_5], r)
- add_relu = add_relu.set_attribute("Primitive", tir.IntImm("int32", 1))
add_relu = add_relu.set_attribute("Composite",
tir.StringImm("add_relu"))
# merged function
@@ -470,7 +461,6 @@ def test_merge_order():
out = relay.abs(out)
out = relay.nn.relu(out)
merged_func = relay.Function([x, y], out)
- merged_func = merged_func.set_attribute('Primitive', tir.IntImm('int32', 1))
merged_func = merged_func.set_attribute('Composite',
tir.StringImm(composite_name))
ret = relay.Call(merged_func, [input_1, input_2])
@@ -537,14 +527,12 @@ def test_parallel_merge():
y = relay.var('y')
branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y))
func_1 = relay.Function([x, y], branch_1)
- func_1 = func_1.set_attribute('Primitive', tir.IntImm('int32', 1))
func_1 = func_1.set_attribute('Composite',
tir.StringImm("add_sub_mul"))
call_1 = relay.Call(func_1, [input_1, input_2])
x1 = relay.var('x1')
y1 = relay.var('y1')
branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1))
func_2 = relay.Function([x1, y1], branch_2)
- func_2 = func_2.set_attribute('Primitive', tir.IntImm('int32', 1))
func_2 = func_2.set_attribute('Composite',
tir.StringImm("add_sub_mul"))
call_2 = relay.Call(func_2, [input_1, input_2])
out = relay.multiply(call_1, call_2)
@@ -624,7 +612,6 @@ def test_multiple_input_subgraphs():
add_relu_1 = relay.add(x, y)
add_relu_1 = relay.nn.relu(add_relu_1)
add_relu_1 = relay.Function([x, y], add_relu_1)
- add_relu_1 = add_relu_1.set_attribute('Primitive', tir.IntImm('int32', 1))
add_relu_1 = add_relu_1.set_attribute('Composite',
tir.StringImm('add_relu'))
add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]])
x1 = relay.var('x1')
@@ -632,7 +619,6 @@ def test_multiple_input_subgraphs():
add_relu_2 = relay.add(x1, y1)
add_relu_2 = relay.nn.relu(add_relu_2)
add_relu_2 = relay.Function([x1, y1], add_relu_2)
- add_relu_2 = add_relu_2.set_attribute('Primitive', tir.IntImm('int32', 1))
add_relu_2 = add_relu_2.set_attribute('Composite',
tir.StringImm('add_relu'))
add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]])
x2 = relay.var('x2')
@@ -641,7 +627,6 @@ def test_multiple_input_subgraphs():
sub = relay.subtract(x2, y2)
add_sub_mul = relay.multiply(add, sub)
add_sub_mul = relay.Function([x2, y2], add_sub_mul)
- add_sub_mul = add_sub_mul.set_attribute('Primitive', tir.IntImm('int32', 1))
add_sub_mul = add_sub_mul.set_attribute('Composite',
tir.StringImm('add_sub_mul'))
add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1,
add_relu_call_2])
return relay.Function(inputs, add_sub_mul_call)
@@ -655,7 +640,6 @@ def test_multiple_input_subgraphs():
add_relu = relay.add(x, y)
add_relu = relay.nn.relu(add_relu)
add_relu = relay.Function([x, y], add_relu)
- add_relu = add_relu.set_attribute('Primitive', tir.IntImm('int32', 1))
add_relu = add_relu.set_attribute('Composite',
tir.StringImm('add_relu'))
add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]])
add_relu_calls.append(add_relu_call)