This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new f63fcba404 [Unity] Fix FuseOpsByPattern when a subgraph can be matched by multiple residual patterns (#15308)
f63fcba404 is described below

commit f63fcba40432964133a3d7609bd298734ae0fb1d
Author: masahi <[email protected]>
AuthorDate: Fri Jul 14 09:29:41 2023 +0900

    [Unity] Fix FuseOpsByPattern when a subgraph can be matched by multiple residual patterns (#15308)
    
    Fix FuseOpsByPattern when a subgraph can be matched by multiple
    residual patterns
---
 src/relax/transform/fuse_ops.cc                    | 19 +++++++++++++++
 .../relax/test_transform_fuse_ops_by_pattern.py    | 27 ++++++++++++++++++++++
 2 files changed, 46 insertions(+)

diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 8cf0d63a1e..463772f1f2 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -1062,6 +1062,18 @@ class PatternBasedPartitioner : ExprVisitor {
       if (check_ != nullptr && !check_(CreatePatternCheckContext(call, matches_opt.value()))) {
         return;
       }
+
+      for (const auto& [pat, match] : matches_opt.value()) {
+        if ((pat->IsInstance<CallPatternNode>() && match != GetRef<Call>(call)) ||
+            pat->IsInstance<TupleGetItemPatternNode>()) {
+          auto g = GetGroup(match);
+          if (g && g->FindRoot()->num_nodes > 1) {
+            // This expression has already been matched to a previous pattern.
+            return;
+          }
+        }
+      }
+
       // If a match is found, put all matching expressions into the same group.
       // OperatorFusor also requires that the bound variable be in the same group as the RHS value.
       // Since is_op(...) based pattern only matches against call nodes on the right hand side,
@@ -1108,6 +1120,13 @@ class PatternBasedPartitioner : ExprVisitor {
     return group_map_[bound_var.get()]->FindRoot();
   }
 
+  Group* GetGroup(const Expr& exp) {
+    if (value_to_bound_var_.count(exp) && group_map_.count(value_to_bound_var_[exp].get())) {
+      return group_map_[value_to_bound_var_[exp].get()];
+    }
+    return nullptr;
+  }
+
   PatternCheckContext CreatePatternCheckContext(const CallNode* call,
                                                 const Map<DFPattern, Expr>& matched_result) {
     Map<String, Expr> annotated_expr;
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index a4df89ad54..1352a52674 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -26,6 +26,7 @@ from tvm.relax.dpl.pattern import (
     wildcard,
 )
 from tvm.relax.transform import PatternCheckContext
+from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
 from tvm.script import ir as I
 from tvm.script import relax as R
 from tvm.script import tir as T
@@ -894,5 +895,31 @@ def test_clip():
     check(mod, [("x.clip", pat_clip)], Expected2)
 
 
+def test_matmul_add3():
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(
+            x: R.Tensor((32, 8), dtype="float16"),
+            y: R.Tensor((8, 8), dtype="float16"),
+            x2: R.Tensor((32, 8), dtype="float16"),
+            y2: R.Tensor((8, 8), dtype="float16"),
+            bias: R.Tensor((8,), dtype="float16"),
+            residual: R.Tensor((32, 8), dtype="float16"),
+        ) -> R.Tensor((32, 8), dtype="float16"):
+            with R.dataflow():
+                lv_: R.Tensor((32, 8), dtype="float16") = R.matmul(x2, y2, out_dtype="float16")
+                lv: R.Tensor((32, 8), dtype="float16") = R.matmul(x, y, out_dtype="float16")
+                lv1: R.Tensor((32, 8), dtype="float16") = R.add(lv, bias)
+                lv2: R.Tensor((32, 8), dtype="float16") = R.add(lv1, lv_)
+                out: R.Tensor((32, 8), dtype="float16") = R.add(lv2, residual)
+                R.output(out)
+            return out
+
+    mod = partition_for_cutlass(Module)
+    func_names = [name.name_hint for (name, _) in mod.functions.items()]
+    assert "fused_relax_matmul_relax_add_relax_add_cutlass" in func_names
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

Reply via email to