This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new f63fcba404 [Unity] Fix FuseOpsByPattern when a subgraph can be matched
by multiple residual patterns (#15308)
f63fcba404 is described below
commit f63fcba40432964133a3d7609bd298734ae0fb1d
Author: masahi <[email protected]>
AuthorDate: Fri Jul 14 09:29:41 2023 +0900
[Unity] Fix FuseOpsByPattern when a subgraph can be matched by multiple
residual patterns (#15308)
Fix FuseOpsByPattern when a subgraph can be matched by multiple
residual patterns
---
src/relax/transform/fuse_ops.cc | 19 +++++++++++++++
.../relax/test_transform_fuse_ops_by_pattern.py | 27 ++++++++++++++++++++++
2 files changed, 46 insertions(+)
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 8cf0d63a1e..463772f1f2 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -1062,6 +1062,18 @@ class PatternBasedPartitioner : ExprVisitor {
if (check_ != nullptr && !check_(CreatePatternCheckContext(call, matches_opt.value()))) {
return;
}
+
+ for (const auto& [pat, match] : matches_opt.value()) {
+ if ((pat->IsInstance<CallPatternNode>() && match != GetRef<Call>(call)) ||
+ pat->IsInstance<TupleGetItemPatternNode>()) {
+ auto g = GetGroup(match);
+ if (g && g->FindRoot()->num_nodes > 1) {
+ // This expression has already been matched to a previous pattern.
+ return;
+ }
+ }
+ }
+
// If a match is found, put all matching expressions into the same group.
// OperatorFusor also requires that the bound variable be in the same group as the RHS value.
// Since is_op(...) based pattern only matches against call nodes on the right hand side,
@@ -1108,6 +1120,13 @@ class PatternBasedPartitioner : ExprVisitor {
return group_map_[bound_var.get()]->FindRoot();
}
+ Group* GetGroup(const Expr& exp) {
+ if (value_to_bound_var_.count(exp) && group_map_.count(value_to_bound_var_[exp].get())) {
+ return group_map_[value_to_bound_var_[exp].get()];
+ }
+ return nullptr;
+ }
+
PatternCheckContext CreatePatternCheckContext(const CallNode* call,
const Map<DFPattern, Expr>& matched_result) {
Map<String, Expr> annotated_expr;
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index a4df89ad54..1352a52674 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -26,6 +26,7 @@ from tvm.relax.dpl.pattern import (
wildcard,
)
from tvm.relax.transform import PatternCheckContext
+from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
@@ -894,5 +895,31 @@ def test_clip():
check(mod, [("x.clip", pat_clip)], Expected2)
+def test_matmul_add3():
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(
+ x: R.Tensor((32, 8), dtype="float16"),
+ y: R.Tensor((8, 8), dtype="float16"),
+ x2: R.Tensor((32, 8), dtype="float16"),
+ y2: R.Tensor((8, 8), dtype="float16"),
+ bias: R.Tensor((8,), dtype="float16"),
+ residual: R.Tensor((32, 8), dtype="float16"),
+ ) -> R.Tensor((32, 8), dtype="float16"):
+ with R.dataflow():
+ lv_: R.Tensor((32, 8), dtype="float16") = R.matmul(x2, y2, out_dtype="float16")
+ lv: R.Tensor((32, 8), dtype="float16") = R.matmul(x, y, out_dtype="float16")
+ lv1: R.Tensor((32, 8), dtype="float16") = R.add(lv, bias)
+ lv2: R.Tensor((32, 8), dtype="float16") = R.add(lv1, lv_)
+ out: R.Tensor((32, 8), dtype="float16") = R.add(lv2, residual)
+ R.output(out)
+ return out
+
+ mod = partition_for_cutlass(Module)
+ func_names = [name.name_hint for (name, _) in mod.functions.items()]
+ assert "fused_relax_matmul_relax_add_relax_add_cutlass" in func_names
+
+
if __name__ == "__main__":
pytest.main([__file__])