This is an automated email from the ASF dual-hosted git repository.

kparzysz pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 9fb191159f [Unity] Commutative pattern match based on relax.Expr op (#15494)
9fb191159f is described below

commit 9fb191159ffb8a628e5ca0a02201156d2030edee
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Aug 7 08:06:55 2023 -0500

    [Unity] Commutative pattern match based on relax.Expr op (#15494)
    
    Prior to this commit, commutative pattern matching was enabled
    based on the operator named in the pattern.  As a result,
    commutative matches were only attempted when the pattern matched a
    single, specific operator, and not when the operator was itself a
    pattern (e.g. an `OrPattern`) that resolved to a commutative operator.
    
    ```python
    pattern_add = ExprPattern(Op.get("relax.add"))
    pattern_mul = ExprPattern(Op.get("relax.multiply"))
    
    uses_commutative_matching = pattern_add(lhs, rhs)
    no_commutative_matching = OrPattern(pattern_add, pattern_mul)(lhs, rhs)
    ```
    
    This commit updates the pattern matcher to inspect the operator of
    the matched expression, rather than the operator in the pattern,
    when deciding whether to attempt a commutative match.
---
 src/relax/ir/dataflow_matcher.cc            |  4 ++--
 tests/python/relax/test_dataflow_pattern.py | 36 +++++++++++++++++++++++++++++
 2 files changed, 38 insertions(+), 2 deletions(-)

diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index 2d06ce1fb9..290ee42eff 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -275,8 +275,8 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex
       // Standard case
       if (match_args(op->args, call_node->args.begin(), call_node->args.end())) return true;
 
-      // Commutative Matching
-      if (const OpNode* op_node = get_op_node(op)) {
+      // Commutative Matching.
+      if (const OpNode* op_node = call_node->op.as<OpNode>()) {
         if ((op_node->name == "relax.add") || (op_node->name == "relax.multiply")) {
           if (match_args(op->args, call_node->args.rbegin(), call_node->args.rend())) {
             return true;
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index ea83807bf8..202db9b5b3 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1282,5 +1282,41 @@ def test_combine_transposed_matmul_twice():
         rx.build(mod, target="llvm")
 
 
+def test_commutative_pattern_match():
+    @R.function(private=True)
+    def before(
+        x: R.Tensor((1024,)),
+    ):
+        with R.dataflow():
+            out = R.add(R.const(1.0), x)
+            R.output(out)
+        return out
+
+    @R.function(private=True)
+    def expected(
+        x: R.Tensor((1024,)),
+    ):
+        with R.dataflow():
+            out = R.add(x, R.const(2.0))
+            R.output(out)
+        return out
+
+    pattern_add = is_op("relax.add")
+    pattern_mul = is_op("relax.multiply")
+    pattern_op = pattern_add | pattern_mul
+    pattern_arg = wildcard()
+    pattern_const = is_const()
+
+    pattern = pattern_op(pattern_arg, pattern_const)
+
+    def rewriter(_expr, matches):
+        op = matches[pattern_op]
+        arg = matches[pattern_arg]
+        return rx.Call(op, [arg, rx.const(2.0)])
+
+    after = rewrite_call(pattern, rewriter, before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to