This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 27a0284  [Relay][Pass] Fix bug in re-processing call node in MergeComposite pass (#4879)
27a0284 is described below

commit 27a02844cb52e883a4a66da68a527590d76f7d01
Author: Jon Soifer <soif...@gmail.com>
AuthorDate: Mon Feb 17 12:18:15 2020 -0800

    [Relay][Pass] Fix bug in re-processing call node in MergeComposite pass (#4879)
    
    * Fix bug in re-processing call node
    
    * Add test
    
    * Add to main
    
    * temp changes to work from another machine
    
    * fix rest of tests
    
    * fix test_reuse_call_merge
    
    * fix merge
    
    Co-authored-by: Jon Soifer <jo...@microsoft.com>
---
 src/relay/pass/merge_composite.cc               | 25 +++++---
 tests/python/relay/test_pass_merge_composite.py | 82 +++++++++++++++++++++++++
 2 files changed, 98 insertions(+), 9 deletions(-)

diff --git a/src/relay/pass/merge_composite.cc b/src/relay/pass/merge_composite.cc
index 28bf8fa..4e1094b 100644
--- a/src/relay/pass/merge_composite.cc
+++ b/src/relay/pass/merge_composite.cc
@@ -87,7 +87,7 @@ class MergeCompositeWrapper : public ExprMutator {
    * a new Relay expression ready to be wrapped into a composite function.
    */
   Expr ExtractPattern(const Call& pattern, const Call& root,
-          Map<std::string, Array<Expr>>* var_map) {
+          Map<std::string, Array<Expr>>* var_map, Map<Expr, Expr>* call_map) {
     // check to make sure both calls are to operators (not functions)
     if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
       return Expr();
@@ -99,14 +99,20 @@ class MergeCompositeWrapper : public ExprMutator {
     for (const auto& arg : pattern->args) {
       Expr new_arg;
       if (arg->IsInstance<CallNode>()) {
-        // fail if the root argument is not also a call node
-        if (!root->args[i]->IsInstance<CallNode>()) {
-          return Expr();
+        // if we've already processed this call node, return the previous result
+        if (call_map->find(arg) != call_map->end()) {
+          new_arg = (*call_map)[arg];
+        } else {
+          // fail if the root argument is not also a call node
+          if (!root->args[i]->IsInstance<CallNode>()) {
+            return Expr();
+          }
+          // if it's a call node, recursively call this function
+          new_arg = ExtractPattern(Downcast<Call>(arg),
+                                  Downcast<Call>(root->args[i]),
+                                  var_map, call_map);
+          call_map->Set(arg, new_arg);
         }
-        // if it's a call node, recursively call this function
-        new_arg = ExtractPattern(Downcast<Call>(arg),
-                                 Downcast<Call>(root->args[i]),
-                                 var_map);
       } else if (arg->IsInstance<VarNode>()) {
         // if there's a var in the pattern, it must be a free var
         // so call the function to update the var_map
@@ -155,7 +161,8 @@ class MergeCompositeWrapper : public ExprMutator {
     Call pattern = Downcast<Call>(pattern_);
     CHECK(pattern.defined());
     Map<std::string, Array<Expr>> args_map;
-    auto extract = ExtractPattern(pattern, call, &args_map);
+    Map<Expr, Expr> call_map;
+    auto extract = ExtractPattern(pattern, call, &args_map, &call_map);
     if (extract.defined()) {
       auto free_vars = FreeVars(extract);
       // make the composite function
diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py
index 4f5acc7..b96a89b 100644
--- a/tests/python/relay/test_pass_merge_composite.py
+++ b/tests/python/relay/test_pass_merge_composite.py
@@ -110,6 +110,26 @@ def make_conv_bias_relu_pattern():
     return r
 
 
+def make_add_add_add_pattern():
+    """Create a pattern to match the following graph.
+       Useful for testing re-using a call node.
+
+        x    y
+      /  \  /
+      |  add
+       \  |  \
+         add |
+          | /
+         add
+    """
+    x = relay.var('x')
+    y = relay.var('y')
+    add_node = relay.add(x, y)
+    add_node_1 = relay.add(x, add_node)
+    r = relay.add(add_node_1, add_node)
+    return r
+
+
 def test_simple_merge():
     """Test composite function is correctly produced from simple graph.
 
@@ -239,6 +259,67 @@ def test_branch_merge():
     assert relay.analysis.alpha_equal(result, expected)
 
 
+def test_reuse_call_merge():
+    """Test composite function is correctly produced from simple graph
+       which re-uses call nodes.
+
+    We could expect the pattern `make_add_add_add` to be merged
+    into a single op `add_add_add`.
+
+        x     y
+         \   / \
+          sub  |           x     y
+        /  |  /             \   / |
+        | add      ====>     sub  |
+         \ |  \               |  /
+          add |           add_add_add
+           | /
+          add
+
+    """
+    pattern_table = [
+        ("add_add_add", make_add_add_add_pattern())
+    ]
+
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+        sub_node = relay.subtract(a, b)
+
+        # pattern
+        add_node = relay.add(sub_node, b)
+        add_node_1 = relay.add(sub_node, add_node)
+        r = relay.add(add_node_1, add_node)
+
+        return relay.Function([a, b], r)
+
+    def expected():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+
+        # add_relu_add function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        add_node_1 = relay.add(in_1, add_node)
+        add_node_2 = relay.add(add_node_1, add_node)
+        add_add_add = relay.Function([in_1, in_2], add_node_2)
+        add_add_add = add_add_add.set_attribute("Primitive",
+                                                tir.IntImm("int32", 1))
+        add_add_add = add_add_add.set_attribute("Composite",
+                                                tir.StringImm("add_add_add"))
+
+        # merged function
+        sub_node = relay.subtract(a, b)
+        call = relay.Call(add_add_add, [sub_node, b])
+        return relay.Function([a, b], call)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+    assert not relay.analysis.free_vars(result)
+    expected = run_opt_pass(expected(), relay.transform.InferType())
+    assert relay.analysis.alpha_equal(result, expected)
+
+
 def test_multiple_patterns():
     """Test different patterns are merged correctly in the graph.
 
@@ -608,3 +689,4 @@ if __name__ == "__main__":
     test_merge_order()
     test_parallel_merge()
     test_multiple_input_subgraphs()
+    test_reuse_call_merge()
\ No newline at end of file

Reply via email to