This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 0e87752  [BYOC, MergeComposite] Add additional check before re-using the cached match (#5552)
0e87752 is described below

commit 0e877521f454e239f5c44bb88e557801444d81a5
Author: masahi <[email protected]>
AuthorDate: Mon May 11 21:00:43 2020 +0900

    [BYOC, MergeComposite] Add additional check before re-using the cached match (#5552)
    
    * Add additional check before re-using the cached match in merge composite
    
    * clean up ExtractPattern calls
---
 src/relay/transforms/merge_composite.cc         | 11 ++-----
 tests/python/relay/test_pass_merge_composite.py | 39 +++++++++++++++++++++++++
 2 files changed, 42 insertions(+), 8 deletions(-)

diff --git a/src/relay/transforms/merge_composite.cc b/src/relay/transforms/merge_composite.cc
index 46fdae0..ae549fa 100644
--- a/src/relay/transforms/merge_composite.cc
+++ b/src/relay/transforms/merge_composite.cc
@@ -121,17 +121,12 @@ class MergeCompositeWrapper : public ExprMutator {
     for (const auto& arg : pattern->args) {
       Expr new_arg;
       if (arg->IsInstance<CallNode>()) {
+        new_arg =
+            ExtractPattern(Downcast<Call>(arg), Downcast<Call>(root->args[i]), 
var_map, call_map);
         // if we've already processed this call node, return the previous 
result
-        if (call_map->find(arg) != call_map->end()) {
+        if (call_map->find(arg) != call_map->end() && new_arg.defined()) {
           new_arg = (*call_map)[arg];
         } else {
-          // fail if the root argument is not also a call node
-          if (!root->args[i]->IsInstance<CallNode>()) {
-            return Expr();
-          }
-          // if it's a call node, recursively call this function
-          new_arg =
-              ExtractPattern(Downcast<Call>(arg), 
Downcast<Call>(root->args[i]), var_map, call_map);
           call_map->Set(arg, new_arg);
         }
       } else if (arg->IsInstance<VarNode>()) {
diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py
index e3c8991..317bb42 100644
--- a/tests/python/relay/test_pass_merge_composite.py
+++ b/tests/python/relay/test_pass_merge_composite.py
@@ -765,6 +765,44 @@ def test_pattern_with_check():
     assert result.body.op.attrs["Composite"] == "conv_bias_relu"
 
 
+def test_diamond_not_merge():
+    """
+    The pattern on the left shouldn't match the structure on the right
+
+    relu             relu
+     | \              | \
+     | clip           | add
+     |  /             |  |
+     mul              | clip
+                      |  /
+                      mul
+    """
+    def get_pattern():
+        conv = make_conv_bias_relu_pattern()
+        clip = relay.op.clip(conv, 0, 255)
+        return relay.op.multiply(conv, clip)
+
+    def get_net():
+        data = relay.var('data', shape=(1, 512, 28, 28))
+        kernel = relay.var('kernel', shape=(256, 512, 1, 1))
+        conv = relay.nn.conv2d(data, kernel,
+                               kernel_size=(1, 1),
+                               padding=(0, 0),
+                               strides=(1, 1))
+        bias = relay.nn.bias_add(conv, relay.var('bias', shape=(256,)))
+        relu = relay.nn.relu(bias)
+        add = relay.op.add(relu, relay.const(1.0))
+        clip2 = relay.op.clip(add, 0, 255)
+        mul = relay.op.multiply(relu, clip2)
+        return relay.Function(relay.analysis.free_vars(mul), mul)
+
+    pat_table = [("pat", get_pattern())]
+    net = get_net()
+    result = run_opt_pass(net, relay.transform.MergeComposite(pat_table))
+    expected = run_opt_pass(net, relay.transform.InferType())
+    assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
+
+
 if __name__ == "__main__":
     test_simple_merge()
     test_branch_merge()
@@ -775,3 +813,4 @@ if __name__ == "__main__":
     test_reuse_call_merge()
     test_tuple_get_item_merge()
     test_pattern_with_check()
+    test_diamond_not_merge()

Reply via email to