This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 9c0b41bc74 [Unity] Allow filtering out unwanted branches in matmul
combining pass (#14971)
9c0b41bc74 is described below
commit 9c0b41bc741a6814ab03e685486491f3afd441c2
Author: masahi <[email protected]>
AuthorDate: Mon May 29 10:32:45 2023 +0900
[Unity] Allow filtering out unwanted branches in matmul combining pass
(#14971)
* pass current bindings to rewriter
* add check func to CombineParallelMatmul
* clean
* add doc for df binding rewrite update
* add test
* black
---
python/tvm/relax/dpl/rewrite.py | 9 ++--
python/tvm/relax/transform/transform.py | 12 ++++-
src/relax/ir/dataflow_matcher.cc | 29 +++++++-----
src/relax/transform/combine_parallel_matmul.cc | 41 +++++++++++------
tests/python/relax/test_dataflow_pattern.py | 2 +-
.../test_transform_combine_parallel_matmul.py | 52 ++++++++++++++++++++++
6 files changed, 116 insertions(+), 29 deletions(-)
diff --git a/python/tvm/relax/dpl/rewrite.py b/python/tvm/relax/dpl/rewrite.py
index 1b62a42903..291061090f 100644
--- a/python/tvm/relax/dpl/rewrite.py
+++ b/python/tvm/relax/dpl/rewrite.py
@@ -60,7 +60,9 @@ def rewrite_call(
def rewrite_bindings(
- ctx: PatternContext, rewriter: Callable[[Dict[DFPattern, Var]], Dict[Var,
Expr]], func: Function
+ ctx: PatternContext,
+ rewriter: Callable[[Dict[DFPattern, Var], Dict[Var, Expr]], Dict[Var,
Expr]],
+ func: Function,
) -> Function:
"""
Rewrite a function with the given pattern and the rewriter function.
@@ -70,10 +72,11 @@ def rewrite_bindings(
ctx: PatternContext
The pattern constraint context under which rewriting takes place.
- rewriter: Callable[[Dict[DFPattern, Var]], Dict[Var, Expr]]
+ rewriter: Callable[[Dict[DFPattern, Var], Dict[Var, Expr]], Dict[Var,
Expr]]
The function to be called on a successful matching for rewriting.
Given the map of patterns
and corresponding variables (bound variables or parameters), it should
return a map that
- specifies new values for matched bound variables.
+ specifies new values for matched bound variables. It can refer to the
passed bindings to
+ create the replacement expressions.
For example, to rewrite three matmuls for QKV projection in
transformer models into one
matmul followed by slicing, one can use the following rewriter:
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index 6013073a37..7d390ed1f9 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -1012,7 +1012,7 @@ def SplitCallTIRByPattern(patterns, fcodegen) ->
tvm.ir.transform.Pass:
return _ffi_api.SplitCallTIRByPattern(patterns, fcodegen) # type: ignore
-def CombineParallelMatmul():
+def CombineParallelMatmul(check=None):
"""Combine multiple matmul operators sharing the same LHS matrix into one,
followed by slicing. When all matmul branches in a tree have the same set
of fused ops,
the fused ops are applied to the combined matmul output before slicing.
@@ -1020,12 +1020,20 @@ def CombineParallelMatmul():
Currently, only a limited set of fused ops is supported. It includes bias
add,
relu, gelu, gelu_tanh and silu activation.
+ Parameters
+ ----------
+ check : Callable[[Var, List[Var], List[Var], Dict[Var, Expr]], bool]
+ A function to filter out unwanted branches, with the signature
+ (input, [rhs], [bias], binding) -> bool.
+
Returns
-------
ret : tvm.transform.Pass
The corresponding pass.
"""
- return _ffi_api.CombineParallelMatmul() # type: ignore
+ if check is None:
+ check = lambda *_: True
+ return _ffi_api.CombineParallelMatmul(check) # type: ignore
def RewriteCUDAGraph() -> tvm.ir.transform.Pass:
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index b06da62c26..81cccec86d 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -707,12 +707,11 @@ static std::optional<MatchState> MatchTree(
return std::nullopt;
}
-Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const
DataflowBlock& dfb) {
+Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const
DataflowBlock& dfb,
+ const Map<Var, Expr>& bindings) {
// TODO(@ganler): Handle non-may external use.
ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is
supported yet.";
-
- const auto var2val = AnalyzeVar2Value(dfb);
- DFPatternMatcher matcher(var2val);
+ DFPatternMatcher matcher(bindings);
MatcherUseDefAnalysis ud_analysis;
ud_analysis.VisitBindingBlock_(dfb.get());
@@ -772,7 +771,14 @@ Optional<Map<DFPattern, Var>> MatchGraph(const
PatternContext& ctx, const Datafl
return NullOpt;
}
-TVM_REGISTER_GLOBAL("relax.dpl.match_dfb").set_body_typed(MatchGraph);
+Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const
DataflowBlock& dfb) {
+ return MatchGraph(ctx, dfb, AnalyzeVar2Value(dfb));
+}
+
+TVM_REGISTER_GLOBAL("relax.dpl.match_dfb")
+ .set_body_typed([](const PatternContext& ctx, const DataflowBlock& dfb) {
+ return MatchGraph(ctx, dfb);
+ });
/*!
* \brief Apply pattern matching to each call node and dataflow block, and
replace matching ones
@@ -863,9 +869,11 @@ class PatternRewriter : ExprMutator {
// Repeat until all matchable subsets of bindings are rewritten.
BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) {
- if (auto matches = MatchGraph(ctx_.value(),
Downcast<DataflowBlock>(block))) {
+ auto df_block = Downcast<DataflowBlock>(block);
+ Map<Var, Expr> bindings = AnalyzeVar2Value(df_block);
+ if (auto matches = MatchGraph(ctx_.value(), df_block, bindings)) {
builder_->BeginDataflowBlock();
- Map<Var, Expr> replacements = rewriter_func_(matches.value());
+ Map<Var, Expr> replacements = rewriter_func_(matches.value(), bindings);
std::unordered_set<const VarNode*> emitted_vars;
@@ -906,9 +914,10 @@ class PatternRewriter : ExprMutator {
* - (Call, Map<DFPattern, Expr>) -> Call for call node rewriting. Given the
matched
* call node and the map of patterns and matched expressions, it should
return a new call node
* to replace the original one or the original matched call node as is.
- * - Map<DFPattern, Var> -> Map<Var, Expr> for dataflow block rewriting.
Given the map of patterns
- * and corresponding variables (bound variables or parameters), it should
return a map that
- * specifies new values for matched bound variables.
+ * - (Map<DFPattern, Var>, Map<Var, Expr>) -> Map<Var, Expr> for dataflow
block rewriting.
+ * Given the map of patterns and corresponding variables (bound variables
or parameters),
+ * it should return a map that specifies new values for matched bound
variables. It can refer
+ * to the passed bindings to create the replacement expressions.
*/
PackedFunc rewriter_func_;
std::unordered_set<const VarNode*> params_;
diff --git a/src/relax/transform/combine_parallel_matmul.cc
b/src/relax/transform/combine_parallel_matmul.cc
index 6efa4552ac..3ea17fdd70 100644
--- a/src/relax/transform/combine_parallel_matmul.cc
+++ b/src/relax/transform/combine_parallel_matmul.cc
@@ -40,6 +40,8 @@ namespace relax {
using runtime::Map;
+using FCheck = runtime::TypedPackedFunc<bool(Var, Array<Var>, Array<Var>,
Map<Var, Expr>)>;
+
/*! \brief Group shapes of the RHS matrices by rank. Matrices in a group whose
batch sizes
are compatible are combined.
*/
@@ -106,8 +108,8 @@ Patterns CreatePatterns(const BranchInfo& branch_info) {
}
/*! \brief Create a rewriter for the given parallel matmul branches. */
-runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>)> GetRewriter(
- const Patterns& patterns, const BranchInfo& branch_info) {
+runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)>
GetRewriter(
+ const Patterns& patterns, const BranchInfo& branch_info, FCheck check) {
auto batch_dims_compatible = [](size_t rhs_dim, const std::vector<size_t>&
indices,
const std::vector<Array<PrimExpr>>&
rhs_shapes) {
arith::Analyzer ana;
@@ -123,7 +125,7 @@ runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern,
Var>)> GetRewriter(
return true;
};
- return [=](Map<DFPattern, Var> matchings) {
+ return [=](Map<DFPattern, Var> matchings, Map<Var, Expr> bindings) {
std::vector<Array<PrimExpr>> rhs_shapes;
for (const auto& rhs_pat : patterns.rhs) {
auto rhs_shape_opt = GetTensorSInfo(matchings[rhs_pat])->GetShape();
@@ -138,7 +140,9 @@ runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern,
Var>)> GetRewriter(
for (const auto& [rhs_dim, indices] : GroupShapes(rhs_shapes)) {
if (indices.size() == 1 || !batch_dims_compatible(rhs_dim, indices,
rhs_shapes)) continue;
- Array<Expr> rhs, bias;
+ auto inp = matchings[patterns.input];
+
+ Array<Var> rhs, bias;
for (auto ind : indices) {
rhs.push_back(matchings[patterns.rhs[ind]]);
if (branch_info.bias_dim) {
@@ -147,8 +151,17 @@ runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern,
Var>)> GetRewriter(
}
}
- auto inp = matchings[patterns.input];
- auto concat_rhs = concat(Tuple(rhs), Integer(rhs_dim - 1));
+ if (!check(inp, rhs, bias, bindings)) {
+ continue;
+ }
+
+ auto make_tuple = [](const Array<Var>& var_array) {
+ Array<Expr> exp_array;
+ for (auto v : var_array) exp_array.push_back(v);
+ return Tuple(exp_array);
+ };
+
+ auto concat_rhs = concat(make_tuple(rhs), Integer(rhs_dim - 1));
auto out_dtype =
GetTensorSInfo(matchings[patterns.matmul[indices[0]]])->dtype;
auto matmul_combined = matmul(inp, concat_rhs, out_dtype);
@@ -160,7 +173,7 @@ runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern,
Var>)> GetRewriter(
if (branch_info.bias_dim) {
auto bias_dim = GetTensorSInfo(bias[0])->ndim;
- auto concat_bias = concat(Tuple(bias), Integer(bias_dim - 1));
+ auto concat_bias = concat(make_tuple(bias), Integer(bias_dim - 1));
matmul_combined = add(matmul_combined, concat_bias);
}
@@ -200,9 +213,9 @@ runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern,
Var>)> GetRewriter(
};
}
-Function Rewrite(Function f, const BranchInfo& branch_info) {
+Function Rewrite(Function f, const BranchInfo& branch_info, FCheck check) {
auto patterns = CreatePatterns(branch_info);
- auto rewriter = GetRewriter(patterns, branch_info);
+ auto rewriter = GetRewriter(patterns, branch_info, check);
return RewriteBindings(patterns.ctx, rewriter, f);
}
@@ -313,22 +326,24 @@ std::vector<BranchInfo> GetBranchInfo(Function f) {
return info;
}
-Function CombineParallelMatmul(Function f) {
+Function CombineParallelMatmul(Function f, FCheck check) {
auto branches = GetBranchInfo(f);
std::sort(branches.begin(), branches.end(),
[](const auto& b1, const auto& b2) { return b1.num_branches >
b2.num_branches; });
for (const auto& branch : branches) {
- f = Rewrite(f, branch);
+ f = Rewrite(f, branch, check);
}
return f;
}
namespace transform {
-Pass CombineParallelMatmul() {
+Pass CombineParallelMatmul(FCheck check) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>
pass_func =
- [=](Function f, IRModule m, PassContext pc) { return
relax::CombineParallelMatmul(f); };
+ [=](Function f, IRModule m, PassContext pc) {
+ return relax::CombineParallelMatmul(f, check);
+ };
return CreateFunctionPass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
/*pass_name=*/"CombineParallelMatmul", //
diff --git a/tests/python/relax/test_dataflow_pattern.py
b/tests/python/relax/test_dataflow_pattern.py
index ed221f54be..e32314cf39 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1033,7 +1033,7 @@ def test_attention_fake_qkv():
def get_qkv_proj_rewriter(
inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2,
matmul3
):
- def qkv_proj_rewriter(matchings):
+ def qkv_proj_rewriter(matchings, _):
inp = matchings[inp_pat]
Q_weight = matchings[Q_weight_pat]
K_weight = matchings[K_weight_pat]
diff --git a/tests/python/relax/test_transform_combine_parallel_matmul.py
b/tests/python/relax/test_transform_combine_parallel_matmul.py
index 41cba1a58b..719daaf449 100644
--- a/tests/python/relax/test_transform_combine_parallel_matmul.py
+++ b/tests/python/relax/test_transform_combine_parallel_matmul.py
@@ -479,5 +479,57 @@ def test_multiple_combine():
tvm.ir.assert_structural_equal(mod["main"], expected1)
+def test_check():
+ @tvm.script.ir_module
+ class multiple_combine:
+ @R.function
+ def main(
+ x1: R.Tensor((2, 1024, 640), "float32"),
+ x2: R.Tensor((2, 1024, 640), "float32"),
+ w0: R.Tensor((640, 640), "float32"),
+ w1: R.Tensor((640, 640), "float32"),
+ w2: R.Tensor((640, 640), "float32"),
+ w3: R.Tensor((640, 640), "float32"),
+ w4: R.Tensor((640, 640), "float32"),
+ ) -> R.Tensor:
+ with R.dataflow():
+ lv0 = R.matmul(x1, w0)
+ lv1 = R.matmul(x1, w1)
+ lv2 = R.matmul(x1, w2)
+ lv3 = R.matmul(x2, w3)
+ lv4 = R.matmul(x2, w4)
+ out = (lv0, lv1, lv2, lv3, lv4)
+ R.output(out)
+ return out
+
+ check = lambda *inp: len(inp[1]) > 2 # Ignore branches with two matmuls
+ mod = CombineParallelMatmul(check)(multiple_combine)
+
+ @R.function
+ def expected(
+ x1: R.Tensor((2, 1024, 640), dtype="float32"),
+ x2: R.Tensor((2, 1024, 640), dtype="float32"),
+ w0: R.Tensor((640, 640), dtype="float32"),
+ w1: R.Tensor((640, 640), dtype="float32"),
+ w2: R.Tensor((640, 640), dtype="float32"),
+ w3: R.Tensor((640, 640), dtype="float32"),
+ w4: R.Tensor((640, 640), dtype="float32"),
+ ) -> R.Tensor:
+ with R.dataflow():
+ lv = R.concat((w0, w1, w2), axis=1)
+ lv1 = R.matmul(x1, lv, out_dtype="float32")
+ lv2 = R.split(lv1, indices_or_sections=[640, 1280], axis=2)
+ lv0 = lv2[0]
+ lv1_1 = lv2[1]
+ lv2_1 = lv2[2]
+ lv3 = R.matmul(x2, w3, out_dtype="void")
+ lv4 = R.matmul(x2, w4, out_dtype="void")
+ out = (lv0, lv1_1, lv2_1, lv3, lv4)
+ R.output(out)
+ return out
+
+ tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
if __name__ == "__main__":
tvm.testing.main()