This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d4ca123afc [BugFix] Support rewrite_once when the number of callbacks > 1 (#14344)
d4ca123afc is described below
commit d4ca123afc54ebabe3c9b0666a5456aaf25eeaa2
Author: sisleyli <[email protected]>
AuthorDate: Wed Mar 22 02:34:01 2023 +0800
[BugFix] Support rewrite_once when the number of callbacks > 1 (#14344)
* [BugFix] Support rewrite_once when the number of callbacks > 1
* callbacks_map -> done, swapping false and true
---------
Co-authored-by: Bin Li <[email protected]>
---
src/relay/ir/dataflow_matcher.cc | 37 +++++++++-----
tests/python/relay/test_dataflow_pattern.py | 79 +++++++++++++++++++++++++----
2 files changed, 94 insertions(+), 22 deletions(-)
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index cf186c474e..67c6bae6c5 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -796,24 +796,35 @@ Expr PatternRewriter::Rewrite(const Array<DFPatternCallback>& callbacks, const E
bool equal = true;
static auto* structural_equal =
runtime::Registry::Get("node.StructuralEqual");
ICHECK(structural_equal) << "node.StructuralEqual is not registered.";
+ // Keep track of callbacks that have finished rewriting
+ std::unordered_map<DFPatternCallback, bool, ObjectPtrHash, ObjectPtrEqual>
done;
do {
last = post;
for (auto callback : callbacks) {
- callback_ = callback;
- if (callback_->require_type) {
- post = InferTypeWithModule(post, mod_);
- }
- auto grouper = PatternGrouper();
- groups_ = grouper.GroupMatches(callback_->pattern, post);
- gid_assignments_ = grouper.GetGIDAssignments();
- memo_.clear();
- VLOG(1) << "pre rewritten:" << std::endl << PrettyPrint(pre);
- post = this->VisitExpr(post);
- VLOG(1) << "post rewritten:" << std::endl << PrettyPrint(post);
- count++;
+ if (!done[callback]) {
+ auto before = post;
+ callback_ = callback;
+ if (callback_->require_type) {
+ post = InferTypeWithModule(post, mod_);
+ }
+ auto grouper = PatternGrouper();
+ groups_ = grouper.GroupMatches(callback_->pattern, post);
+ gid_assignments_ = grouper.GetGIDAssignments();
+ memo_.clear();
+ VLOG(1) << "pre rewritten:" << std::endl << PrettyPrint(pre);
+ post = this->VisitExpr(post);
+ VLOG(1) << "post rewritten:" << std::endl << PrettyPrint(post);
+ count++;
+ if (callback_->rewrite_once) {
+ bool current_equal = (*structural_equal)(before, post, false, true);
+ if (!current_equal) {
+ done[callback] = true;
+ }
+ }
+ }
}
equal = (*structural_equal)(last, post, false, true);
- } while (!equal && count < 100 && !callback_->rewrite_once);
+ } while (!equal && count < 100);
if (count >= 100) {
LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting passes?";
}
diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py
index 1bd05f5258..bcb665121b 100644
--- a/tests/python/relay/test_dataflow_pattern.py
+++ b/tests/python/relay/test_dataflow_pattern.py
@@ -1804,22 +1804,83 @@ def test_rewrite_once():
if new_args:
return relay.op.concatenate(relay.expr.Tuple(new_args), axis=0)
else:
- return concat_args
+ return concat_args[0]
x = relay.var("x")
y = relay.var("y")
z = relay.var("z")
concat = relay.op.concatenate(relay.expr.Tuple([x, y, z]), axis=0)
- # Let the rewriter run recursively
- out = rewrite(ConcatRewriter(False), concat)
- expected = relay.expr.Tuple([x])
- assert tvm.ir.structural_equal(out, expected)
+ def test_one_callback():
+ # Let the rewriter run recursively
+ out = rewrite(ConcatRewriter(False), concat)
+ expected = x
+ assert tvm.ir.structural_equal(out, expected)
+
+ # Run the rewriter once
+ out = rewrite(ConcatRewriter(True), concat)
+ expected = relay.op.concatenate(relay.expr.Tuple([x, y]), axis=0)
+ assert tvm.ir.structural_equal(out, expected)
+
+ def test_multi_callbacks():
+ # This class recursively add a nn.relu operator after nn.softmax
+ class OneMoreReluRewriter(DFPatternCallback):
+ def __init__(self, rewrite_once):
+ super().__init__(rewrite_once=rewrite_once)
+ self.pattern = is_op("nn.softmax")(None)
+
+ def callback(self, pre, post, node_map):
+ return relay.nn.relu(post)
+
+ def before():
+ # Before:
+ # x y z
+ # | | |
+ # concat
+ # |
+ # softmax
+ return relay.nn.softmax(concat)
+
+ def once_concat():
+ # ConcatRewrite once, OneMoreReluRewrite once
+ # Expected:
+ # x y
+ # | |
+ # concat
+ # |
+ # softmax
+ # |
+ # relu
+ return relay.nn.relu(
+ relay.nn.softmax(relay.op.concatenate(relay.expr.Tuple([x,
y]), axis=0))
+ )
+
+ def recursive_concat():
+ # ConcatRewrite recursively, OneMoreReluRewrite once
+ # Expected:
+ # x
+ # |
+ # softmax
+ # |
+ # relu
+ return relay.nn.relu(relay.nn.softmax(x))
+
+ # Run ConcatRewriter once, OneMoreReluRewriter once
+ out = rewrite(
+ [OneMoreReluRewriter(True), ConcatRewriter(True)],
+ before(),
+ )
+ assert tvm.ir.structural_equal(out, once_concat())
+
+ # Run ConcatRewriter recursively, OneMoreReluRewriter once
+ out = rewrite(
+ [OneMoreReluRewriter(True), ConcatRewriter(False)],
+ before(),
+ )
+ assert tvm.ir.structural_equal(out, recursive_concat())
- # Run the rewriter once
- out = rewrite(ConcatRewriter(True), concat)
- expected = relay.op.concatenate(relay.expr.Tuple([x, y]), axis=0)
- assert tvm.ir.structural_equal(out, expected)
+ test_one_callback()
+ test_multi_callbacks()
def test_matched_outside_but_dominated():