This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 9c0b41bc74 [Unity] Allow filtering out unwanted branches in matmul combining pass (#14971)
9c0b41bc74 is described below

commit 9c0b41bc741a6814ab03e685486491f3afd441c2
Author: masahi <[email protected]>
AuthorDate: Mon May 29 10:32:45 2023 +0900

    [Unity] Allow filtering out unwanted branches in matmul combining pass (#14971)
    
    * pass current bindings to rewriter
    
    * add check func to CombineParallelMatmul
    
    * clean
    
    * add doc for df binding rewrite update
    
    * add test
    
    * black
---
 python/tvm/relax/dpl/rewrite.py                    |  9 ++--
 python/tvm/relax/transform/transform.py            | 12 ++++-
 src/relax/ir/dataflow_matcher.cc                   | 29 +++++++-----
 src/relax/transform/combine_parallel_matmul.cc     | 41 +++++++++++------
 tests/python/relax/test_dataflow_pattern.py        |  2 +-
 .../test_transform_combine_parallel_matmul.py      | 52 ++++++++++++++++++++++
 6 files changed, 116 insertions(+), 29 deletions(-)

diff --git a/python/tvm/relax/dpl/rewrite.py b/python/tvm/relax/dpl/rewrite.py
index 1b62a42903..291061090f 100644
--- a/python/tvm/relax/dpl/rewrite.py
+++ b/python/tvm/relax/dpl/rewrite.py
@@ -60,7 +60,9 @@ def rewrite_call(
 
 
 def rewrite_bindings(
-    ctx: PatternContext, rewriter: Callable[[Dict[DFPattern, Var]], Dict[Var, 
Expr]], func: Function
+    ctx: PatternContext,
+    rewriter: Callable[[Dict[DFPattern, Var], Dict[Var, Expr]], Dict[Var, 
Expr]],
+    func: Function,
 ) -> Function:
     """
     Rewrite a function with the given pattern and the rewriter function.
@@ -70,10 +72,11 @@ def rewrite_bindings(
     ctx: PatternContext
         The pattern constraint context under which rewriting takes place.
 
-    rewriter: Callable[[Dict[DFPattern, Var]], Dict[Var, Expr]]
+    rewriter: Callable[[Dict[DFPattern, Var], Dict[Var, Expr]], Dict[Var, 
Expr]]
         The function to be called on a successful matching for rewriting. 
Given the map of patterns
         and corresponding variables (bound variables or parameters), it should 
return a map that
-        specifies new values for matched bound variables.
+        specifies new values for matched bound variables. It can refer to the 
passed bindings to
+        create the replacement expressions.
 
         For example, to rewrite three matmuls for QKV projection in 
transformer models into one
         matmul followed by slicing, one can use the follwoing rewriter:
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 6013073a37..7d390ed1f9 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -1012,7 +1012,7 @@ def SplitCallTIRByPattern(patterns, fcodegen) -> 
tvm.ir.transform.Pass:
     return _ffi_api.SplitCallTIRByPattern(patterns, fcodegen)  # type: ignore
 
 
-def CombineParallelMatmul():
+def CombineParallelMatmul(check=None):
     """Combine multiple matmul operators sharing the same LHS matrix into one,
     followed by slicing. When all matmul branches in a tree have the same set 
of fused ops,
     the fused ops are applied to the combined matmul output before slicing.
@@ -1020,12 +1020,20 @@ def CombineParallelMatmul():
     Currently, only a limited set of fused ops is supported. It includes bias 
add,
     relu, gelu, gelu_tanh and silu activation.
 
+    Parameters
+    ----------
+    check : Callable[[Var, List[Var], List[Var], Dict[Var, Expr]], bool]
+        A function to filter out unwanted branches, with the signature
+        (input, [rhs], [bias], binding) -> bool.
+
     Returns
     -------
     ret : tvm.transform.Pass
         The corresponding pass.
     """
-    return _ffi_api.CombineParallelMatmul()  # type: ignore
+    if check is None:
+        check = lambda *_: True
+    return _ffi_api.CombineParallelMatmul(check)  # type: ignore
 
 
 def RewriteCUDAGraph() -> tvm.ir.transform.Pass:
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index b06da62c26..81cccec86d 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -707,12 +707,11 @@ static std::optional<MatchState> MatchTree(
   return std::nullopt;
 }
 
-Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const 
DataflowBlock& dfb) {
+Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const 
DataflowBlock& dfb,
+                                         const Map<Var, Expr>& bindings) {
   // TODO(@ganler): Handle non-may external use.
   ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is 
supported yet.";
-
-  const auto var2val = AnalyzeVar2Value(dfb);
-  DFPatternMatcher matcher(var2val);
+  DFPatternMatcher matcher(bindings);
 
   MatcherUseDefAnalysis ud_analysis;
   ud_analysis.VisitBindingBlock_(dfb.get());
@@ -772,7 +771,14 @@ Optional<Map<DFPattern, Var>> MatchGraph(const 
PatternContext& ctx, const Datafl
   return NullOpt;
 }
 
-TVM_REGISTER_GLOBAL("relax.dpl.match_dfb").set_body_typed(MatchGraph);
+Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const 
DataflowBlock& dfb) {
+  return MatchGraph(ctx, dfb, AnalyzeVar2Value(dfb));
+}
+
+TVM_REGISTER_GLOBAL("relax.dpl.match_dfb")
+    .set_body_typed([](const PatternContext& ctx, const DataflowBlock& dfb) {
+      return MatchGraph(ctx, dfb);
+    });
 
 /*!
  * \brief Apply pattern matching to each call node and dataflow block, and 
replace matching ones
@@ -863,9 +869,11 @@ class PatternRewriter : ExprMutator {
 
   // Repeat until all matchable subsets of bindings are rewritten.
   BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) {
-    if (auto matches = MatchGraph(ctx_.value(), 
Downcast<DataflowBlock>(block))) {
+    auto df_block = Downcast<DataflowBlock>(block);
+    Map<Var, Expr> bindings = AnalyzeVar2Value(df_block);
+    if (auto matches = MatchGraph(ctx_.value(), df_block, bindings)) {
       builder_->BeginDataflowBlock();
-      Map<Var, Expr> replacements = rewriter_func_(matches.value());
+      Map<Var, Expr> replacements = rewriter_func_(matches.value(), bindings);
 
       std::unordered_set<const VarNode*> emitted_vars;
 
@@ -906,9 +914,10 @@ class PatternRewriter : ExprMutator {
    * - (Call, Map<DFPattern, Expr>) -> Call for call node rewriting. Given the 
matched
    *    call node and the map of patterns and matched expressions, it should 
return a new call node
    *    to replace the original one or the original matched call node as is.
-   * - Map<DFPattern, Var> -> Map<Var, Expr> for dataflow block rewriting. 
Given the map of patterns
-   *   and corresponding variables (bound variables or parameters), it should 
return a map that
-   *   specifies new values for matched bound variables.
+   * - (Map<DFPattern, Var>, Map<Var, Expr>) -> Map<Var, Expr> for dataflow 
block rewriting.
+   *    Given the map of patterns and corresponding variables (bound variables 
or parameters),
+   *    it should return a map that specifies new values for matched bound 
variables. It can refer
+   *    to the passed bindings to create the replacement expressions.
    */
   PackedFunc rewriter_func_;
   std::unordered_set<const VarNode*> params_;
diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc
index 6efa4552ac..3ea17fdd70 100644
--- a/src/relax/transform/combine_parallel_matmul.cc
+++ b/src/relax/transform/combine_parallel_matmul.cc
@@ -40,6 +40,8 @@ namespace relax {
 
 using runtime::Map;
 
+using FCheck = runtime::TypedPackedFunc<bool(Var, Array<Var>, Array<Var>, 
Map<Var, Expr>)>;
+
 /*! \brief Group shapes of the RHS matrices by rank. Matrices in a group whose 
batch sizes
   are compatible are combined.
 */
@@ -106,8 +108,8 @@ Patterns CreatePatterns(const BranchInfo& branch_info) {
 }
 
 /*! \brief Create a rewriter for the given parallel matmul branches. */
-runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>)> GetRewriter(
-    const Patterns& patterns, const BranchInfo& branch_info) {
+runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> 
GetRewriter(
+    const Patterns& patterns, const BranchInfo& branch_info, FCheck check) {
   auto batch_dims_compatible = [](size_t rhs_dim, const std::vector<size_t>& 
indices,
                                   const std::vector<Array<PrimExpr>>& 
rhs_shapes) {
     arith::Analyzer ana;
@@ -123,7 +125,7 @@ runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, 
Var>)> GetRewriter(
     return true;
   };
 
-  return [=](Map<DFPattern, Var> matchings) {
+  return [=](Map<DFPattern, Var> matchings, Map<Var, Expr> bindings) {
     std::vector<Array<PrimExpr>> rhs_shapes;
     for (const auto& rhs_pat : patterns.rhs) {
       auto rhs_shape_opt = GetTensorSInfo(matchings[rhs_pat])->GetShape();
@@ -138,7 +140,9 @@ runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, 
Var>)> GetRewriter(
     for (const auto& [rhs_dim, indices] : GroupShapes(rhs_shapes)) {
       if (indices.size() == 1 || !batch_dims_compatible(rhs_dim, indices, 
rhs_shapes)) continue;
 
-      Array<Expr> rhs, bias;
+      auto inp = matchings[patterns.input];
+
+      Array<Var> rhs, bias;
       for (auto ind : indices) {
         rhs.push_back(matchings[patterns.rhs[ind]]);
         if (branch_info.bias_dim) {
@@ -147,8 +151,17 @@ runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, 
Var>)> GetRewriter(
         }
       }
 
-      auto inp = matchings[patterns.input];
-      auto concat_rhs = concat(Tuple(rhs), Integer(rhs_dim - 1));
+      if (!check(inp, rhs, bias, bindings)) {
+        continue;
+      }
+
+      auto make_tuple = [](const Array<Var>& var_array) {
+        Array<Expr> exp_array;
+        for (auto v : var_array) exp_array.push_back(v);
+        return Tuple(exp_array);
+      };
+
+      auto concat_rhs = concat(make_tuple(rhs), Integer(rhs_dim - 1));
       auto out_dtype = 
GetTensorSInfo(matchings[patterns.matmul[indices[0]]])->dtype;
       auto matmul_combined = matmul(inp, concat_rhs, out_dtype);
 
@@ -160,7 +173,7 @@ runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, 
Var>)> GetRewriter(
 
       if (branch_info.bias_dim) {
         auto bias_dim = GetTensorSInfo(bias[0])->ndim;
-        auto concat_bias = concat(Tuple(bias), Integer(bias_dim - 1));
+        auto concat_bias = concat(make_tuple(bias), Integer(bias_dim - 1));
         matmul_combined = add(matmul_combined, concat_bias);
       }
 
@@ -200,9 +213,9 @@ runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, 
Var>)> GetRewriter(
   };
 }
 
-Function Rewrite(Function f, const BranchInfo& branch_info) {
+Function Rewrite(Function f, const BranchInfo& branch_info, FCheck check) {
   auto patterns = CreatePatterns(branch_info);
-  auto rewriter = GetRewriter(patterns, branch_info);
+  auto rewriter = GetRewriter(patterns, branch_info, check);
   return RewriteBindings(patterns.ctx, rewriter, f);
 }
 
@@ -313,22 +326,24 @@ std::vector<BranchInfo> GetBranchInfo(Function f) {
   return info;
 }
 
-Function CombineParallelMatmul(Function f) {
+Function CombineParallelMatmul(Function f, FCheck check) {
   auto branches = GetBranchInfo(f);
   std::sort(branches.begin(), branches.end(),
             [](const auto& b1, const auto& b2) { return b1.num_branches > 
b2.num_branches; });
 
   for (const auto& branch : branches) {
-    f = Rewrite(f, branch);
+    f = Rewrite(f, branch, check);
   }
   return f;
 }
 
 namespace transform {
 
-Pass CombineParallelMatmul() {
+Pass CombineParallelMatmul(FCheck check) {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> 
pass_func =
-      [=](Function f, IRModule m, PassContext pc) { return 
relax::CombineParallelMatmul(f); };
+      [=](Function f, IRModule m, PassContext pc) {
+        return relax::CombineParallelMatmul(f, check);
+      };
   return CreateFunctionPass(/*pass_function=*/pass_func,            //
                             /*opt_level=*/0,                        //
                             /*pass_name=*/"CombineParallelMatmul",  //
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index ed221f54be..e32314cf39 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1033,7 +1033,7 @@ def test_attention_fake_qkv():
 def get_qkv_proj_rewriter(
     inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, 
matmul3
 ):
-    def qkv_proj_rewriter(matchings):
+    def qkv_proj_rewriter(matchings, _):
         inp = matchings[inp_pat]
         Q_weight = matchings[Q_weight_pat]
         K_weight = matchings[K_weight_pat]
diff --git a/tests/python/relax/test_transform_combine_parallel_matmul.py b/tests/python/relax/test_transform_combine_parallel_matmul.py
index 41cba1a58b..719daaf449 100644
--- a/tests/python/relax/test_transform_combine_parallel_matmul.py
+++ b/tests/python/relax/test_transform_combine_parallel_matmul.py
@@ -479,5 +479,57 @@ def test_multiple_combine():
     tvm.ir.assert_structural_equal(mod["main"], expected1)
 
 
+def test_check():
+    @tvm.script.ir_module
+    class multiple_combine:
+        @R.function
+        def main(
+            x1: R.Tensor((2, 1024, 640), "float32"),
+            x2: R.Tensor((2, 1024, 640), "float32"),
+            w0: R.Tensor((640, 640), "float32"),
+            w1: R.Tensor((640, 640), "float32"),
+            w2: R.Tensor((640, 640), "float32"),
+            w3: R.Tensor((640, 640), "float32"),
+            w4: R.Tensor((640, 640), "float32"),
+        ) -> R.Tensor:
+            with R.dataflow():
+                lv0 = R.matmul(x1, w0)
+                lv1 = R.matmul(x1, w1)
+                lv2 = R.matmul(x1, w2)
+                lv3 = R.matmul(x2, w3)
+                lv4 = R.matmul(x2, w4)
+                out = (lv0, lv1, lv2, lv3, lv4)
+                R.output(out)
+            return out
+
+    check = lambda *inp: len(inp[1]) > 2  # Ignore branches with two matmuls
+    mod = CombineParallelMatmul(check)(multiple_combine)
+
+    @R.function
+    def expected(
+        x1: R.Tensor((2, 1024, 640), dtype="float32"),
+        x2: R.Tensor((2, 1024, 640), dtype="float32"),
+        w0: R.Tensor((640, 640), dtype="float32"),
+        w1: R.Tensor((640, 640), dtype="float32"),
+        w2: R.Tensor((640, 640), dtype="float32"),
+        w3: R.Tensor((640, 640), dtype="float32"),
+        w4: R.Tensor((640, 640), dtype="float32"),
+    ) -> R.Tensor:
+        with R.dataflow():
+            lv = R.concat((w0, w1, w2), axis=1)
+            lv1 = R.matmul(x1, lv, out_dtype="float32")
+            lv2 = R.split(lv1, indices_or_sections=[640, 1280], axis=2)
+            lv0 = lv2[0]
+            lv1_1 = lv2[1]
+            lv2_1 = lv2[2]
+            lv3 = R.matmul(x2, w3, out_dtype="void")
+            lv4 = R.matmul(x2, w4, out_dtype="void")
+            out = (lv0, lv1_1, lv2_1, lv3, lv4)
+            R.output(out)
+        return out
+
+    tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to