This is an automated email from the ASF dual-hosted git repository.

wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2565aa38ef [BugFix][Relax] change FuseOpsByPattern strategy to pattern-match maximal subgraph (#16922)
2565aa38ef is described below

commit 2565aa38ef4d1d5a5ce5561ebf36910532993d90
Author: lazypanda <[email protected]>
AuthorDate: Fri May 10 21:07:48 2024 +0800

    [BugFix][Relax] change FuseOpsByPattern strategy to pattern-match maximal subgraph (#16922)
    
    * [BugFix][Relax] change FuseOpsByPattern strategy to pattern-match maximal subgraph
    
    * add testcase
    
    ---------
    
    Co-authored-by: Huibin Wang <[email protected]>
---
 src/relax/transform/fuse_ops.cc                    | 31 ++++++++++++++++++++--
 .../relax/test_transform_fuse_ops_by_pattern.py    | 26 ++++++++++++++++++
 2 files changed, 55 insertions(+), 2 deletions(-)

diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 04c07c439c..e89c5e4445 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -1073,7 +1073,11 @@ class PatternBasedPartitioner : ExprVisitor {
     current_block_use_def_ = {};
   }
 
-  void VisitVarDef(const Var& var) final { group_map_[var.get()] = arena_->make<Group>(); }
+  void VisitVarDef(const Var& var) final {
+    Group* g = arena_->make<Group>();
+    group_map_[var.get()] = g;
+    vars_in_group_[g].push_back(var);
+  }
 
   void VisitBinding_(const VarBindingNode* binding) final {
     bindings_.Set(binding->var, binding->value);
@@ -1097,7 +1101,13 @@ class PatternBasedPartitioner : ExprVisitor {
           auto g = GetGroup(match);
           if (g && g->FindRoot()->num_nodes > 1) {
             // This expression has already been matched to a previous pattern.
-            return;
+            // If the prior matched subgraph is subsumed by the new matched one,
+            // we can safely merge them, obtaining a maximized matched subgraph eventually.
+            // Otherwise, merging them will result in an incorrect subgraph,
+            // so we keep the prior subgraph and discard the current one by returning directly.
+            auto vars_in_prior_matched_graph = vars_in_group_[g];
+            if (!GraphSubsumedInMatchedValues(vars_in_prior_matched_graph, matches_opt.value()))
+              return;
           }
         }
       }
@@ -1145,6 +1155,7 @@ class PatternBasedPartitioner : ExprVisitor {
     if (group_map_[e.get()] != to) {
       --group_map_[e.get()]->num_nodes;
       group_map_[e.get()]->parent = to;
+      vars_in_group_[to].push_back(e);
       ++to->num_nodes;
     }
   }
@@ -1181,6 +1192,21 @@ class PatternBasedPartitioner : ExprVisitor {
                                current_block_use_def_, value_to_bound_var_);
   }
 
+  // check if a previously matched subgraph is subsumed by the current matched result
+  bool GraphSubsumedInMatchedValues(const Array<Expr>& vars_in_graph,
+                                    const Map<DFPattern, Expr>& matched_result) {
+    std::set<Expr> matched_vars;
+    for (const auto& [pat, match] : matched_result) {
+      if ((pat->IsInstance<CallPatternNode>() || pat->IsInstance<TupleGetItemPatternNode>()))
+        matched_vars.insert(value_to_bound_var_[match]);
+    }
+
+    for (const auto var : vars_in_graph) {
+      if (matched_vars.find(var) == matched_vars.end()) return false;
+    }
+    return true;
+  }
+
   String pat_name_;
   DFPattern pat_;
   Map<String, DFPattern> annotation_pat_;
@@ -1191,6 +1217,7 @@ class PatternBasedPartitioner : ExprVisitor {
   Map<Expr, Var> value_to_bound_var_;
   Map<Var, Array<Var>> current_block_use_def_;
   GroupMap group_map_;
+  std::map<Group*, Array<Expr>> vars_in_group_;
 };
 
 /*!
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index 5e700b277f..f5905f7643 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -1217,5 +1217,31 @@ def test_matmul_symbolic_var():
     tvm.ir.assert_structural_equal(Expected, After)
 
 
+def test_match_maximal_subgraph():
+    @R.function
+    def func(
+        x: R.Tensor((32, 8), dtype="int32"),
+        y: R.Tensor((8, 8), dtype="int32"),
+        bias: R.Tensor((8,), dtype="int32"),
+    ) -> R.Tensor((32, 8), dtype="int32"):
+        R.func_attr({"global_symbol": "main"})
+        with R.dataflow():
+            lv0 = R.matmul(x, y, out_dtype="int32")
+            lv1 = R.add(lv0, bias)
+            lv2 = R.clip(lv1, -128, 127)
+            R.output(lv2)
+        return lv2
+
+    mod = tvm.IRModule({"main": func})
+
+    matmul = is_op("relax.matmul")(wildcard(), wildcard())
+    matmul_add = is_op("relax.add")(matmul, wildcard())
+    pattern = matmul_add | is_op("relax.clip")(matmul_add, wildcard(), wildcard())
+
+    partitioned = relax.transform.FuseOpsByPattern([("orclip", pattern)])(mod)
+    func_names = [name.name_hint for (name, _) in partitioned.functions.items()]
+    assert "fused_relax_matmul_relax_add_relax_clip" in func_names
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

Reply via email to