This is an automated email from the ASF dual-hosted git repository.
kparzysz pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 9fb191159f [Unity] Commutative pattern match based on relax.Expr op
(#15494)
9fb191159f is described below
commit 9fb191159ffb8a628e5ca0a02201156d2030edee
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Aug 7 08:06:55 2023 -0500
[Unity] Commutative pattern match based on relax.Expr op (#15494)
Prior to this commit, commutative pattern matching was enabled based on
the operator named in the pattern. As a result, commutative matches
were only checked when the pattern matched a single, explicitly named
operator, and not when the operator was itself a pattern (e.g. an
`OrPattern`) that resolved to a commutative operator. For example:
```python
pattern_add = ExprPattern(Op.get("relax.add"))
pattern_mul = ExprPattern(Op.get("relax.multiply"))
uses_commutative_matching = pattern_add(lhs, rhs)
no_commutative_matching = OrPattern(pattern_add, pattern_mul)(lhs, rhs)
```
This commit updates the pattern matcher to inspect the operator of the
matched call expression, rather than the pattern, when deciding whether
to also try matching the arguments in reversed order (currently done
only for `relax.add` and `relax.multiply`).
---
src/relax/ir/dataflow_matcher.cc | 4 ++--
tests/python/relax/test_dataflow_pattern.py | 36 +++++++++++++++++++++++++++++
2 files changed, 38 insertions(+), 2 deletions(-)
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index 2d06ce1fb9..290ee42eff 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -275,8 +275,8 @@ bool DFPatternMatcher::VisitDFPattern_(const
CallPatternNode* op, const Expr& ex
// Standard case
if (match_args(op->args, call_node->args.begin(),
call_node->args.end())) return true;
- // Commutative Matching
- if (const OpNode* op_node = get_op_node(op)) {
+ // Commutative Matching.
+ if (const OpNode* op_node = call_node->op.as<OpNode>()) {
if ((op_node->name == "relax.add") || (op_node->name ==
"relax.multiply")) {
if (match_args(op->args, call_node->args.rbegin(),
call_node->args.rend())) {
return true;
diff --git a/tests/python/relax/test_dataflow_pattern.py
b/tests/python/relax/test_dataflow_pattern.py
index ea83807bf8..202db9b5b3 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1282,5 +1282,41 @@ def test_combine_transposed_matmul_twice():
rx.build(mod, target="llvm")
+def test_commutative_pattern_match():
+ @R.function(private=True)
+ def before(
+ x: R.Tensor((1024,)),
+ ):
+ with R.dataflow():
+ out = R.add(R.const(1.0), x)
+ R.output(out)
+ return out
+
+ @R.function(private=True)
+ def expected(
+ x: R.Tensor((1024,)),
+ ):
+ with R.dataflow():
+ out = R.add(x, R.const(2.0))
+ R.output(out)
+ return out
+
+ pattern_add = is_op("relax.add")
+ pattern_mul = is_op("relax.multiply")
+ pattern_op = pattern_add | pattern_mul
+ pattern_arg = wildcard()
+ pattern_const = is_const()
+
+ pattern = pattern_op(pattern_arg, pattern_const)
+
+ def rewriter(_expr, matches):
+ op = matches[pattern_op]
+ arg = matches[pattern_arg]
+ return rx.Call(op, [arg, rx.const(2.0)])
+
+ after = rewrite_call(pattern, rewriter, before)
+ tvm.ir.assert_structural_equal(after, expected)
+
+
if __name__ == "__main__":
tvm.testing.main()