This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 0e87752 [BYOC, MergeComposite] Add additional check before re-using the cached match (#5552)
0e87752 is described below
commit 0e877521f454e239f5c44bb88e557801444d81a5
Author: masahi <[email protected]>
AuthorDate: Mon May 11 21:00:43 2020 +0900
[BYOC, MergeComposite] Add additional check before re-using the cached match (#5552)
* Add additional check before re-using the cached match in merge composite
* Clean up ExtractPattern calls
---
src/relay/transforms/merge_composite.cc | 11 ++-----
tests/python/relay/test_pass_merge_composite.py | 39 +++++++++++++++++++++++++
2 files changed, 42 insertions(+), 8 deletions(-)
diff --git a/src/relay/transforms/merge_composite.cc b/src/relay/transforms/merge_composite.cc
index 46fdae0..ae549fa 100644
--- a/src/relay/transforms/merge_composite.cc
+++ b/src/relay/transforms/merge_composite.cc
@@ -121,17 +121,12 @@ class MergeCompositeWrapper : public ExprMutator {
for (const auto& arg : pattern->args) {
Expr new_arg;
if (arg->IsInstance<CallNode>()) {
+ new_arg =
+ ExtractPattern(Downcast<Call>(arg), Downcast<Call>(root->args[i]),
var_map, call_map);
// if we've already processed this call node, return the previous
result
- if (call_map->find(arg) != call_map->end()) {
+ if (call_map->find(arg) != call_map->end() && new_arg.defined()) {
new_arg = (*call_map)[arg];
} else {
- // fail if the root argument is not also a call node
- if (!root->args[i]->IsInstance<CallNode>()) {
- return Expr();
- }
- // if it's a call node, recursively call this function
- new_arg =
- ExtractPattern(Downcast<Call>(arg),
Downcast<Call>(root->args[i]), var_map, call_map);
call_map->Set(arg, new_arg);
}
} else if (arg->IsInstance<VarNode>()) {
diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py
index e3c8991..317bb42 100644
--- a/tests/python/relay/test_pass_merge_composite.py
+++ b/tests/python/relay/test_pass_merge_composite.py
@@ -765,6 +765,44 @@ def test_pattern_with_check():
assert result.body.op.attrs["Composite"] == "conv_bias_relu"
+def test_diamond_not_merge():
+ """
+ The pattern on the left shouldn't match the structure on the right
+
+ relu relu
+ | \ | \
+ | clip | add
+ | / | |
+ mul | clip
+ | /
+ mul
+ """
+ def get_pattern():
+ conv = make_conv_bias_relu_pattern()
+ clip = relay.op.clip(conv, 0, 255)
+ return relay.op.multiply(conv, clip)
+
+ def get_net():
+ data = relay.var('data', shape=(1, 512, 28, 28))
+ kernel = relay.var('kernel', shape=(256, 512, 1, 1))
+ conv = relay.nn.conv2d(data, kernel,
+ kernel_size=(1, 1),
+ padding=(0, 0),
+ strides=(1, 1))
+ bias = relay.nn.bias_add(conv, relay.var('bias', shape=(256,)))
+ relu = relay.nn.relu(bias)
+ add = relay.op.add(relu, relay.const(1.0))
+ clip2 = relay.op.clip(add, 0, 255)
+ mul = relay.op.multiply(relu, clip2)
+ return relay.Function(relay.analysis.free_vars(mul), mul)
+
+ pat_table = [("pat", get_pattern())]
+ net = get_net()
+ result = run_opt_pass(net, relay.transform.MergeComposite(pat_table))
+ expected = run_opt_pass(net, relay.transform.InferType())
+ assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
+
+
if __name__ == "__main__":
test_simple_merge()
test_branch_merge()
@@ -775,3 +813,4 @@ if __name__ == "__main__":
test_reuse_call_merge()
test_tuple_get_item_merge()
test_pattern_with_check()
+ test_diamond_not_merge()