This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 86b5a1301c [Relax] Allow composition of DFPattern replacements (#16732)
86b5a1301c is described below
commit 86b5a1301c18a411ea920ee26bcbe8f0af70bd75
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Mar 27 12:50:35 2024 -0500
[Relax] Allow composition of DFPattern replacements (#16732)
[Relax] Allow composition of DFPattern replacements
The `rewrite_call` function accepts a `DFPattern`, and a function to
rewrite expressions matching that pattern. Often, the rewriting
function will perform additional validation that cannot be expressed
within the `DFPattern` itself. If this additional validation fails,
the rewriter function will return the matched expression unmodified.
Prior to this commit, an `OrPattern` that matches on the first branch,
but whose rewriter function does not apply a modification, would
prevent the second branch from being checked. This commit updates the
`ExprPatternRewriter` to also check the second branch of an `OrPattern`
if the rewriter function returns the first branch's match unmodified.
---
src/relax/ir/dataflow_matcher.cc | 44 +++++++++++++++++---
tests/python/relax/test_dataflow_pattern.py | 63 +++++++++++++++++++++++++++++
2 files changed, 102 insertions(+), 5 deletions(-)
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index 531971d3db..db70ef6a9c 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -1158,17 +1158,51 @@ class ExprPatternRewriter : ExprMutator {
Expr VisitExpr(const Expr& expr) override {
auto node = ExprMutator::VisitExpr(expr);
- if (auto matches_opt = ExtractMatchedExpr(pattern_, node, bindings_)) {
- Expr rewritten_expr = rewriter_func_(node, matches_opt.value());
- if (!rewritten_expr.same_as(node)) {
- return builder_->Normalize(rewritten_expr);
- }
+ // Try to rewrite the visited node. `TryRewrite` uses the (initially
+ // empty) stack to record the top-level `OrPattern`s it unpacks.
+ std::vector<DFPattern> or_pattern_stack;
+ Optional<Expr> rewritten = TryRewrite(node, pattern_, &or_pattern_stack);
+ if (rewritten.defined()) {
+ return builder_->Normalize(rewritten.value());
}
return node;
}
private:
+ /*!
+ * \brief Try to rewrite `expr` by matching it against `pattern`.
+ *
+ * Top-level `OrPattern`s are unpacked, so that if `rewriter_func_`
+ * declines the match of the left branch (by returning `expr`
+ * unmodified), the right branch is still attempted.
+ *
+ * \param expr The expression to rewrite.
+ * \param pattern The pattern to match `expr` against.
+ * \param matches_top_level Stack of the enclosing `OrPattern`s that
+ * have been unpacked so far. Must be non-null.
+ * \return The normalized replacement, or NullOpt if no branch of
+ * `pattern` both matched and produced a modified expression.
+ */
+ Optional<Expr> TryRewrite(const Expr& expr, const DFPattern& pattern,
+ std::vector<DFPattern>* matches_top_level) {
+ ICHECK(matches_top_level);
+
+ // Special handling if the user-supplied pattern is an `OrPattern`.
+ // While `ExtractMatchedExpr` can handle matching the `OrPattern`,
+ // it returns on the first matching branch, even if
+ // `rewriter_func_` doesn't apply a replacement. Unpacking the
+ // `OrPattern` here allows the match to be resumed if
+ // `rewriter_func_` returns the original expression unmodified.
+ // This is only valid for a top-level match.
+ if (auto or_pattern = pattern.as<OrPatternNode>()) {
+ matches_top_level->push_back(pattern);
+ Optional<Expr> output = TryRewrite(expr, or_pattern->left, matches_top_level);
+ if (!output.defined()) {
+ output = TryRewrite(expr, or_pattern->right, matches_top_level);
+ }
+ matches_top_level->pop_back();
+ return output;
+ }
+
+ if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings_)) {
+ auto matches = opt_matches.value();
+ // `ExtractMatchedExpr` only saw the innermost branch, so also
+ // record each enclosing `OrPattern` as matching `expr`. This lets
+ // `rewriter_func_` look up the user-supplied top-level pattern.
+ for (const auto& pat : *matches_top_level) {
+ matches.Set(pat, expr);
+ }
+
+ Expr rewritten_expr = rewriter_func_(expr, matches);
+ // An unmodified result means the rewriter declined this match;
+ // fall through to NullOpt so the caller can try another branch.
+ if (!rewritten_expr.same_as(expr)) {
+ return builder_->Normalize(rewritten_expr);
+ }
+ }
+
+ return NullOpt;
+ }
+
/*! \brief The pattern for rewriting call nodes */
DFPattern pattern_;
/*!
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index 583e2a8d08..81cd8da7fe 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1889,5 +1889,68 @@ def test_wildcard_struct_info_with_symbolic_vars():
tvm.ir.assert_structural_equal(expected, after)
+def test_backtrack_if_rewriter_returns_no_op():
+ """Rewriter participates in the pattern matching
+
+ Sometimes, the pattern-matching syntax is insufficient to check if
+ a replacement may be performed. In this case, the `rewriter`
+ function may perform additional validation. If this validation
+ fails, the `rewriter` function can return the original expression,
+ and no replacement is performed.
+
+ In addition, when the `rewriter` returns the original expression,
+ the pattern match should backtrack to determine if another branch
+ of the match may have produced a replacement.
+
+ This functionality allows pattern replacements to be composed.
+ """
+
+ pat_match_no_rewrite = is_op("relax.add")(wildcard(), wildcard())
+
+ pat_arg = wildcard()
+ pat_zeros = is_op("relax.zeros")(wildcard())
+ pat_add = is_op("relax.add")(pat_arg, pat_zeros)
+
+ # OR conditions are checked in the order that they occur. Because
+ # `pat_match_no_rewrite` is a superset of `pat_add`, it will
+ # always match first.
+ pat = pat_match_no_rewrite | pat_add
+
+ def rewriter(expr, matches):
+ if pat_match_no_rewrite in matches:
+ # This branch simulates a rewrite whose precondition has
+ # failed. If the pattern-matching treats this as a
+ # successful match with no replacement required, then no
+ # rewrite would be performed. On the other hand, if the
+ # pattern-matching treats this as an unsuccessful match,
+ # then it can backtrack and attempt `pat_add` instead.
+ return expr
+ elif pat_add in matches:
+ return matches[pat_arg]
+ else:
+ raise RuntimeError("Pattern matched, but neither branch matched")
+
+ @R.function(private=True)
+ def before():
+ with R.dataflow():
+ A = R.ones([64, 128], "int32")
+ B = R.zeros([64, 128], "int32")
+ C = R.add(A, B)
+
+ R.output(C)
+ return C
+
+ # Expected result after backtracking to `pat_add`, whose rewrite
+ # replaces `R.add(A, B)` with its first argument `A`.
+ @R.function(private=True)
+ def expected():
+ with R.dataflow():
+ C = R.ones([64, 128], "int32")
+
+ R.output(C)
+ return C
+
+ after = rewrite_call(pat, rewriter, before)
+ tvm.ir.assert_structural_equal(expected, after)
+
+
if __name__ == "__main__":
tvm.testing.main()