This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 00395ae43d [Relax][Bugfix] Provide the full Expr to pattern-match rewriter (#16828)
00395ae43d is described below

commit 00395ae43d3d6024c900a32f512f136cf818a3af
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Apr 1 17:23:35 2024 -0500

    [Relax][Bugfix] Provide the full Expr to pattern-match rewriter (#16828)
    
    * [Relax][Bugfix] Provide the full Expr to pattern-match rewriter
    
    This resolves a bug that was introduced in
    https://github.com/apache/tvm/pull/16732.  If a rewriter function
    returned a no-op, and the pattern-match continued, then the `matches`
    provided to the rewriter function in subsequent calls would contain
    a variable to which the matched expression was bound, not the matched
    expression itself.  (e.g. For a match of `C = R.add(A,B)`, passing `C`
    to the rewriter instead of `R.add(A,B)`.)
    
    This bug was caused by incorrect re-wrapping of `OrPattern` in
    `ExprPatternRewriter`.  Prior to
    https://github.com/apache/tvm/pull/16732, all pattern-match results
    were populated by `ExtractMatchedExpr`, and contained the result after
    applying `TryGetValOfVar`.  When re-wrapping the result of an
    `OrPattern`, https://github.com/apache/tvm/pull/16732 populated the
    additional matches with the result before applying `TryGetValOfVar`.
    This commit fixes the bug by applying `TryGetValOfVar`.
    
    * Update with PR link of bugfix
---
 src/relax/ir/dataflow_matcher.cc            | 13 ++++++++++--
 tests/python/relax/test_dataflow_pattern.py | 33 +++++++++++++++++++++++++++++
 2 files changed, 44 insertions(+), 2 deletions(-)

diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index db70ef6a9c..cf8934c372 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -1190,8 +1190,17 @@ class ExprPatternRewriter : ExprMutator {
 
     if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings_)) {
       auto matches = opt_matches.value();
-      for (const auto& pat : *matches_top_level) {
-        matches.Set(pat, expr);
+
+      // Append any additional matches that from the unwrapped
+      // `OrPattern`.  When matching against `pat = pat_lhs |
+      // pat_rhs`, we call `ExtractMatchedExpr` on `pat_lhs` and
+      // `pat_rhs` separately.  The top-level `pat` is never seen by
+      // `ExtractMatchedExpr`, and must be re-added afterward.
+      if (matches_top_level->size()) {
+        auto matched_expr = TryGetValOfVar(expr, bindings_);
+        for (const auto& pat : *matches_top_level) {
+          matches.Set(pat, matched_expr);
+        }
       }
 
       Expr rewritten_expr = rewriter_func_(expr, matches);
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index 81cd8da7fe..24c36d20dc 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1952,5 +1952,38 @@ def test_backtrack_if_rewriter_returns_no_op():
     tvm.ir.assert_structural_equal(expected, after)
 
 
+def test_backtrack_for_no_op_rewriter_does_not_match_on_var():
+    """The matches should always contain the bound value
+
+    This is a regression test.  In versions from
+    https://github.com/apache/tvm/pull/16732 to
+    https://github.com/apache/tvm/pull/16828, the `rewrite_call`
+    function could erroneously call the rewriter with `expr` and
+    `matches[pat]` set to a variable (`C`) instead of the value to
+    which it is bound (`R.add(A,B)`).
+    """
+    pat_a = is_op("relax.add")(wildcard(), wildcard())
+    pat_b = is_op("relax.add")(wildcard(), wildcard())
+    pat = pat_a | pat_b  # OrPattern; never seen directly by ExtractMatchedExpr
+
+    def rewriter(expr, matches):
+        assert isinstance(matches[pat], rx.Call)  # R.add(A, B), not the bound Var C
+        return expr  # no-op, so the pattern-match continues
+
+    @R.function(private=True)
+    def before():
+        with R.dataflow():
+            A = R.ones([64, 128], "int32")
+            B = R.zeros([64, 128], "int32")
+            C = R.add(A, B)
+
+            R.output(C)
+        return C
+
+    expected = before
+    after = rewrite_call(pat, rewriter, before)
+    tvm.ir.assert_structural_equal(expected, after)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to