This is an automated email from the ASF dual-hosted git repository.
zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 27a0284 [Relay][Pass] Fix bug in re-processing call node in MergeComposite pass (#4879)
27a0284 is described below
commit 27a02844cb52e883a4a66da68a527590d76f7d01
Author: Jon Soifer <[email protected]>
AuthorDate: Mon Feb 17 12:18:15 2020 -0800
[Relay][Pass] Fix bug in re-processing call node in MergeComposite pass (#4879)
* Fix bug in re-processing call node
* Add test
* Add to main
* temp changes to work from another machine
* fix rest of tests
* fix test_reuse_call_merge
* fix merge
Co-authored-by: Jon Soifer <[email protected]>
---
src/relay/pass/merge_composite.cc | 25 +++++---
tests/python/relay/test_pass_merge_composite.py | 82 +++++++++++++++++++++++++
2 files changed, 98 insertions(+), 9 deletions(-)
diff --git a/src/relay/pass/merge_composite.cc b/src/relay/pass/merge_composite.cc
index 28bf8fa..4e1094b 100644
--- a/src/relay/pass/merge_composite.cc
+++ b/src/relay/pass/merge_composite.cc
@@ -87,7 +87,7 @@ class MergeCompositeWrapper : public ExprMutator {
* a new Relay expression ready to be wrapped into a composite function.
*/
Expr ExtractPattern(const Call& pattern, const Call& root,
- Map<std::string, Array<Expr>>* var_map) {
+ Map<std::string, Array<Expr>>* var_map, Map<Expr, Expr>* call_map) {
// check to make sure both calls are to operators (not functions)
if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
return Expr();
@@ -99,14 +99,20 @@ class MergeCompositeWrapper : public ExprMutator {
for (const auto& arg : pattern->args) {
Expr new_arg;
if (arg->IsInstance<CallNode>()) {
- // fail if the root argument is not also a call node
- if (!root->args[i]->IsInstance<CallNode>()) {
- return Expr();
+ // if we've already processed this call node, return the previous result
+ if (call_map->find(arg) != call_map->end()) {
+ new_arg = (*call_map)[arg];
+ } else {
+ // fail if the root argument is not also a call node
+ if (!root->args[i]->IsInstance<CallNode>()) {
+ return Expr();
+ }
+ // if it's a call node, recursively call this function
+ new_arg = ExtractPattern(Downcast<Call>(arg),
+ Downcast<Call>(root->args[i]),
+ var_map, call_map);
+ call_map->Set(arg, new_arg);
}
- // if it's a call node, recursively call this function
- new_arg = ExtractPattern(Downcast<Call>(arg),
- Downcast<Call>(root->args[i]),
- var_map);
} else if (arg->IsInstance<VarNode>()) {
// if there's a var in the pattern, it must be a free var
// so call the function to update the var_map
@@ -155,7 +161,8 @@ class MergeCompositeWrapper : public ExprMutator {
Call pattern = Downcast<Call>(pattern_);
CHECK(pattern.defined());
Map<std::string, Array<Expr>> args_map;
- auto extract = ExtractPattern(pattern, call, &args_map);
+ Map<Expr, Expr> call_map;
+ auto extract = ExtractPattern(pattern, call, &args_map, &call_map);
if (extract.defined()) {
auto free_vars = FreeVars(extract);
// make the composite function
diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py
index 4f5acc7..b96a89b 100644
--- a/tests/python/relay/test_pass_merge_composite.py
+++ b/tests/python/relay/test_pass_merge_composite.py
@@ -110,6 +110,26 @@ def make_conv_bias_relu_pattern():
return r
+def make_add_add_add_pattern():
+ """Create a pattern to match the following graph.
+ Useful for testing re-using a call node.
+
+ x y
+ / \ /
+ | add
+ \ | \
+ add |
+ | /
+ add
+ """
+ x = relay.var('x')
+ y = relay.var('y')
+ add_node = relay.add(x, y)
+ add_node_1 = relay.add(x, add_node)
+ r = relay.add(add_node_1, add_node)
+ return r
+
+
def test_simple_merge():
"""Test composite function is correctly produced from simple graph.
@@ -239,6 +259,67 @@ def test_branch_merge():
assert relay.analysis.alpha_equal(result, expected)
+def test_reuse_call_merge():
+ """Test composite function is correctly produced from simple graph
+ which re-uses call nodes.
+
+ We expect the pattern `make_add_add_add` to be merged
+ into a single op `add_add_add`.
+
+ x y
+ \ / \
+ sub | x y
+ / | / \ / |
+ | add ====> sub |
+ \ | \ | /
+ add | add_add_add
+ | /
+ add
+
+ """
+ pattern_table = [
+ ("add_add_add", make_add_add_add_pattern())
+ ]
+
+ def before():
+ a = relay.var('a', shape=(10, 10))
+ b = relay.var('b', shape=(10, 10))
+ sub_node = relay.subtract(a, b)
+
+ # pattern
+ add_node = relay.add(sub_node, b)
+ add_node_1 = relay.add(sub_node, add_node)
+ r = relay.add(add_node_1, add_node)
+
+ return relay.Function([a, b], r)
+
+ def expected():
+ a = relay.var('a', shape=(10, 10))
+ b = relay.var('b', shape=(10, 10))
+
+ # add_add_add function
+ in_1 = relay.var('in_1', shape=(10, 10))
+ in_2 = relay.var('in_2', shape=(10, 10))
+ add_node = relay.add(in_1, in_2)
+ add_node_1 = relay.add(in_1, add_node)
+ add_node_2 = relay.add(add_node_1, add_node)
+ add_add_add = relay.Function([in_1, in_2], add_node_2)
+ add_add_add = add_add_add.set_attribute("Primitive",
+ tir.IntImm("int32", 1))
+ add_add_add = add_add_add.set_attribute("Composite",
+ tir.StringImm("add_add_add"))
+
+ # merged function
+ sub_node = relay.subtract(a, b)
+ call = relay.Call(add_add_add, [sub_node, b])
+ return relay.Function([a, b], call)
+
+ result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
+ assert not relay.analysis.free_vars(result)
+ expected = run_opt_pass(expected(), relay.transform.InferType())
+ assert relay.analysis.alpha_equal(result, expected)
+
+
def test_multiple_patterns():
"""Test different patterns are merged correctly in the graph.
@@ -608,3 +689,4 @@ if __name__ == "__main__":
test_merge_order()
test_parallel_merge()
test_multiple_input_subgraphs()
+ test_reuse_call_merge()
\ No newline at end of file