This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 00395ae43d [Relax][Bugfix] Provide the full Expr to pattern-match
rewriter (#16828)
00395ae43d is described below
commit 00395ae43d3d6024c900a32f512f136cf818a3af
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Apr 1 17:23:35 2024 -0500
[Relax][Bugfix] Provide the full Expr to pattern-match rewriter (#16828)
* [Relax][Bugfix] Provide the full Expr to pattern-match rewriter
This resolves a bug that was introduced in
https://github.com/apache/tvm/pull/16732. If a rewriter function
returned a no-op, and the pattern-match continued, then the `matches`
provided to the rewriter function in subsequent calls would contain
a variable to which the matched expression was bound, not the matched
expression itself. (e.g. For a match of `C = R.add(A,B)`, passing `C`
to the rewriter instead of `R.add(A,B)`.)
This bug was caused by incorrect re-wrapping of `OrPattern` in
`ExprPatternRewriter`. Prior to
https://github.com/apache/tvm/pull/16732, all pattern-match results
were populated by `ExtractMatchExpr`, and contained the result after
applying `TryGetValOfVar`. When re-wrapping the result of an
`OrPattern`, https://github.com/apache/tvm/pull/16732 populated the
additional matches with the result before applying `TryGetValOfVar`.
This commit fixes the bug by applying `TryGetValOfVar` to the matched expression before populating the additional top-level `OrPattern` matches.
* Update with PR link of bugfix
---
src/relax/ir/dataflow_matcher.cc | 13 ++++++++++--
tests/python/relax/test_dataflow_pattern.py | 33 +++++++++++++++++++++++++++++
2 files changed, 44 insertions(+), 2 deletions(-)
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index db70ef6a9c..cf8934c372 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -1190,8 +1190,17 @@ class ExprPatternRewriter : ExprMutator {
if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings_)) {
auto matches = opt_matches.value();
- for (const auto& pat : *matches_top_level) {
- matches.Set(pat, expr);
+
+ // Append any additional matches from the unwrapped
+ // `OrPattern`. When matching against `pat = pat_lhs |
+ // pat_rhs`, we call `ExtractMatchedExpr` on `pat_lhs` and
+ // `pat_rhs` separately. The top-level `pat` is never seen by
+ // `ExtractMatchedExpr`, and must be re-added afterward.
+ if (matches_top_level->size()) {
+ auto matched_expr = TryGetValOfVar(expr, bindings_);
+ for (const auto& pat : *matches_top_level) {
+ matches.Set(pat, matched_expr);
+ }
}
Expr rewritten_expr = rewriter_func_(expr, matches);
diff --git a/tests/python/relax/test_dataflow_pattern.py
b/tests/python/relax/test_dataflow_pattern.py
index 81cd8da7fe..24c36d20dc 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1952,5 +1952,38 @@ def test_backtrack_if_rewriter_returns_no_op():
tvm.ir.assert_structural_equal(expected, after)
+def test_backtrack_for_no_op_rewriter_does_not_match_on_var():
+ """The matches should always contain the bound value, not the bound variable
+
+ This is a regression test. In versions from
+ https://github.com/apache/tvm/pull/16732 up to the fix in
+ https://github.com/apache/tvm/pull/16828, the `rewrite_call`
+ function could erroneously call the rewriter with `expr` and
+ `matches[pat]` set to a variable (`C`) instead of the value to
+ which it is bound (`R.add(A,B)`).
+ """
+ pat_a = is_op("relax.add")(wildcard(), wildcard())
+ pat_b = is_op("relax.add")(wildcard(), wildcard())
+ pat = pat_a | pat_b  # OrPattern: the top-level `pat` match is re-added after matching pat_a/pat_b
+
+ def rewriter(expr, matches):
+ assert isinstance(matches[pat], rx.Call)  # expect the bound value R.add(A, B), not the Var C
+ return expr  # no-op, so the pattern-match continues (backtracks) and the rewriter is called again
+
+ @R.function(private=True)
+ def before():
+ with R.dataflow():
+ A = R.ones([64, 128], "int32")
+ B = R.zeros([64, 128], "int32")
+ C = R.add(A, B)
+
+ R.output(C)
+ return C
+
+ expected = before  # the rewriter never changes anything, so the function must be unchanged
+ after = rewrite_call(pat, rewriter, before)
+ tvm.ir.assert_structural_equal(expected, after)
+
+
if __name__ == "__main__":
tvm.testing.main()