This is an automated email from the ASF dual-hosted git repository.
wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2565aa38ef [BugFix][Relax] change FuseOpsByPattern strategy to
pattern-match maximal subgraph (#16922)
2565aa38ef is described below
commit 2565aa38ef4d1d5a5ce5561ebf36910532993d90
Author: lazypanda <[email protected]>
AuthorDate: Fri May 10 21:07:48 2024 +0800
[BugFix][Relax] change FuseOpsByPattern strategy to pattern-match maximal
subgraph (#16922)
* [BugFix][Relax] change FuseOpsByPattern strategy to pattern-match maximal
subgraph
* add testcase
---------
Co-authored-by: Huibin Wang <[email protected]>
---
src/relax/transform/fuse_ops.cc | 31 ++++++++++++++++++++--
.../relax/test_transform_fuse_ops_by_pattern.py | 26 ++++++++++++++++++
2 files changed, 55 insertions(+), 2 deletions(-)
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 04c07c439c..e89c5e4445 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -1073,7 +1073,11 @@ class PatternBasedPartitioner : ExprVisitor {
current_block_use_def_ = {};
}
- void VisitVarDef(const Var& var) final { group_map_[var.get()] =
arena_->make<Group>(); }
+ void VisitVarDef(const Var& var) final {
+ Group* g = arena_->make<Group>();
+ group_map_[var.get()] = g;
+ vars_in_group_[g].push_back(var);
+ }
void VisitBinding_(const VarBindingNode* binding) final {
bindings_.Set(binding->var, binding->value);
@@ -1097,7 +1101,13 @@ class PatternBasedPartitioner : ExprVisitor {
auto g = GetGroup(match);
if (g && g->FindRoot()->num_nodes > 1) {
// This expression has already been matched to a previous pattern.
- return;
+ // If the prior matched subgraph is subsumed by the new matched
one,
+ // we can safely merge them, obtaining a maximized matched
subgraph eventually.
+ // Otherwise, merging them will result in an incorrect subgraph,
+ // so we keep the prior subgraph and discard the current one by
returning directly.
+ auto vars_in_prior_matched_graph = vars_in_group_[g];
+ if (!GraphSubsumedInMatchedValues(vars_in_prior_matched_graph,
matches_opt.value()))
+ return;
}
}
}
@@ -1145,6 +1155,7 @@ class PatternBasedPartitioner : ExprVisitor {
if (group_map_[e.get()] != to) {
--group_map_[e.get()]->num_nodes;
group_map_[e.get()]->parent = to;
+ vars_in_group_[to].push_back(e);
++to->num_nodes;
}
}
@@ -1181,6 +1192,21 @@ class PatternBasedPartitioner : ExprVisitor {
current_block_use_def_, value_to_bound_var_);
}
+ // check if a previous matched subgraph is subsumed by the current matched
result
+ bool GraphSubsumedInMatchedValues(const Array<Expr>& vars_in_graph,
+ const Map<DFPattern, Expr>&
matched_result) {
+ std::set<Expr> matched_vars;
+ for (const auto& [pat, match] : matched_result) {
+ if ((pat->IsInstance<CallPatternNode>() ||
pat->IsInstance<TupleGetItemPatternNode>()))
+ matched_vars.insert(value_to_bound_var_[match]);
+ }
+
+ for (const auto var : vars_in_graph) {
+ if (matched_vars.find(var) == matched_vars.end()) return false;
+ }
+ return true;
+ }
+
String pat_name_;
DFPattern pat_;
Map<String, DFPattern> annotation_pat_;
@@ -1191,6 +1217,7 @@ class PatternBasedPartitioner : ExprVisitor {
Map<Expr, Var> value_to_bound_var_;
Map<Var, Array<Var>> current_block_use_def_;
GroupMap group_map_;
+ std::map<Group*, Array<Expr>> vars_in_group_;
};
/*!
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index 5e700b277f..f5905f7643 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -1217,5 +1217,31 @@ def test_matmul_symbolic_var():
tvm.ir.assert_structural_equal(Expected, After)
+def test_match_maximal_subgraph():
+ @R.function
+ def func(
+ x: R.Tensor((32, 8), dtype="int32"),
+ y: R.Tensor((8, 8), dtype="int32"),
+ bias: R.Tensor((8,), dtype="int32"),
+ ) -> R.Tensor((32, 8), dtype="int32"):
+ R.func_attr({"global_symbol": "main"})
+ with R.dataflow():
+ lv0 = R.matmul(x, y, out_dtype="int32")
+ lv1 = R.add(lv0, bias)
+ lv2 = R.clip(lv1, -128, 127)
+ R.output(lv2)
+ return lv2
+
+ mod = tvm.IRModule({"main": func})
+
+ matmul = is_op("relax.matmul")(wildcard(), wildcard())
+ matmul_add = is_op("relax.add")(matmul, wildcard())
+ pattern = matmul_add | is_op("relax.clip")(matmul_add, wildcard(),
wildcard())
+
+ partitioned = relax.transform.FuseOpsByPattern([("orclip", pattern)])(mod)
+ func_names = [name.name_hint for (name, _) in
partitioned.functions.items()]
+ assert "fused_relax_matmul_relax_add_relax_clip" in func_names
+
+
if __name__ == "__main__":
pytest.main([__file__])