This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 86b5a1301c [Relax] Allow composition of DFPattern replacements  (#16732)
86b5a1301c is described below

commit 86b5a1301c18a411ea920ee26bcbe8f0af70bd75
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Mar 27 12:50:35 2024 -0500

    [Relax] Allow composition of DFPattern replacements  (#16732)
    
    [Relax] Allow composition of DFPattern replacements
    
    The `rewrite_call` function accepts a `DFPattern` and a function to
    rewrite expressions matching that pattern.  Often, the rewriting
    function will perform additional validation that cannot be expressed
    within the `DFPattern` itself.  If this additional validation fails,
    the rewriter function will return the matched expression unmodified.
    
    Prior to this commit, an `OrPattern` that matches on the first branch,
    but whose rewriter function does not apply a modification, would
    prevent the second branch from being checked.  This commit updates the
    `ExprPatternRewriter` to check both branches of an `OrPattern`, if the
    rewriter function of the first branch does not modify the result.
---
 src/relax/ir/dataflow_matcher.cc            | 44 +++++++++++++++++---
 tests/python/relax/test_dataflow_pattern.py | 63 +++++++++++++++++++++++++++++
 2 files changed, 102 insertions(+), 5 deletions(-)

diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index 531971d3db..db70ef6a9c 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -1158,17 +1158,51 @@ class ExprPatternRewriter : ExprMutator {
   Expr VisitExpr(const Expr& expr) override {
     auto node = ExprMutator::VisitExpr(expr);
 
-    if (auto matches_opt = ExtractMatchedExpr(pattern_, node, bindings_)) {
-      Expr rewritten_expr = rewriter_func_(node, matches_opt.value());
-      if (!rewritten_expr.same_as(node)) {
-        return builder_->Normalize(rewritten_expr);
-      }
+    std::vector<DFPattern> matches_top_level;
+    if (auto rewritten = TryRewrite(node, pattern_, &matches_top_level)) {
+      return builder_->Normalize(rewritten.value());
     }
 
     return node;
   }
 
  private:
+  Optional<Expr> TryRewrite(const Expr& expr, const DFPattern& pattern,
+                            std::vector<DFPattern>* matches_top_level) {
+    ICHECK(matches_top_level);
+
+    // Special handling if the user-supplied pattern is an `OrPattern`.
+    // While `ExtractMatchedExpr` can handle matching the
+    // `OrPattern`, it will return on the first match, even if the
+    // `rewriter_func_` doesn't apply a replacement.  Unpacking the
+    // `OrPattern` here allows the match to be resumed if
+    // `rewriter_func_` returns the original expression unmodified.
+    // This is only valid for a top-level match.
+    if (auto or_pattern = pattern.as<OrPatternNode>()) {
+      matches_top_level->push_back(pattern);
+      Optional<Expr> output = TryRewrite(expr, or_pattern->left, matches_top_level);
+      if (!output.defined()) {
+        output = TryRewrite(expr, or_pattern->right, matches_top_level);
+      }
+      matches_top_level->pop_back();
+      return output;
+    }
+    // Non-OR pattern: enclosing `OrPattern`s are also reported as matching `expr`.
+    if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings_)) {
+      auto matches = opt_matches.value();
+      for (const auto& pat : *matches_top_level) {
+        matches.Set(pat, expr);
+      }
+      // The match only counts if `rewriter_func_` returns a different expression.
+      Expr rewritten_expr = rewriter_func_(expr, matches);
+      if (!rewritten_expr.same_as(expr)) {
+        return builder_->Normalize(rewritten_expr);
+      }
+    }
+    // No rewrite applied (NOTE(review): VisitExpr also re-normalizes successful results).
+    return NullOpt;
+  }
+
   /*! \brief The pattern for rewriting call nodes */
   DFPattern pattern_;
   /*!
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index 583e2a8d08..81cd8da7fe 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1889,5 +1889,68 @@ def test_wildcard_struct_info_with_symbolic_vars():
     tvm.ir.assert_structural_equal(expected, after)
 
 
+def test_backtrack_if_rewriter_returns_no_op():
+    """Rewriter participates in the pattern matching
+
+    Sometimes, the pattern-matching syntax is insufficient to check if
+    a replacement may be performed.  In this case, the `rewriter`
+    function may perform additional validation.  If this validation
+    fails, the `rewriter` function can return the original expression,
+    and no replacement is performed.
+
+    In addition, when the `rewriter` returns the original expression,
+    the pattern match should backtrack to determine if another branch
+    of the match may have produced a replacement.
+
+    This functionality allows pattern replacements to be composed.
+    """
+
+    pat_match_no_rewrite = is_op("relax.add")(wildcard(), wildcard())
+
+    pat_arg = wildcard()
+    pat_zeros = is_op("relax.zeros")(wildcard())
+    pat_add = is_op("relax.add")(pat_arg, pat_zeros)
+
+    # OR conditions are checked in the order that they occur.  Because
+    # `pat_match_no_rewrite` is a superset of `pat_add`, it will
+    # always match first.
+    pat = pat_match_no_rewrite | pat_add
+
+    def rewriter(expr, matches):
+        if pat_match_no_rewrite in matches:
+            # This branch simulates a rewrite whose precondition has
+            # failed.  If the pattern-matching treats this as a
+            # successful match with no replacement required, then no
+            # rewrite would be performed.  On the other hand, if the
+            # pattern-matching treats this as an unsuccessful match,
+            # then it can backtrack and attempt `pat_add` instead.
+            return expr
+        elif pat_add in matches:
+            return matches[pat_arg]
+        else:
+            raise RuntimeError("Pattern matched, but neither branch matched")
+
+    @R.function(private=True)
+    def before():
+        with R.dataflow():
+            A = R.ones([64, 128], "int32")
+            B = R.zeros([64, 128], "int32")
+            C = R.add(A, B)
+
+            R.output(C)
+        return C
+
+    @R.function(private=True)
+    def expected():
+        with R.dataflow():
+            C = R.ones([64, 128], "int32")
+
+            R.output(C)
+        return C
+    # Only reachable by backtracking: `pat_add` replaces `add(A, zeros)` with `A`.
+    after = rewrite_call(pat, rewriter, before)
+    tvm.ir.assert_structural_equal(expected, after)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to