This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new a59c09de77 [Unity] Add pass for combining parallel matmul (#14583)
a59c09de77 is described below

commit a59c09de77bbab5514dc4ff3b1713f31113def51
Author: masahi <[email protected]>
AuthorDate: Thu Apr 13 05:40:53 2023 +0900

    [Unity] Add pass for combining parallel matmul (#14583)
    
    * stub
    
    * make EnterWithScope and ExitWithScope public
    
    * qkv combining works
    
    * automatic branch extraction
    
    * fix hardcoded concat axis
    
    * handle non-uniform rhs ranks
    
    * wip
    
    * improve termination check in binding rewriter
    
    * wip
    
    * properly handle rhs with same rank but different batch size
    
    * support bias and activation
    
    * refactor
    
    * fixed activation handling
    
    * wip
    
    * clean
    
    * fix bias and activation combine logic
    
    * add tests
    
    * add comment
    
    * add doc
    
    * fix compile warning
---
 include/tvm/relax/dataflow_matcher.h               |  10 +
 include/tvm/relax/dataflow_pattern.h               |  10 +-
 python/tvm/relax/transform/transform.py            |  16 +
 src/relax/ir/dataflow_matcher.cc                   |  24 +-
 src/relax/ir/dataflow_pattern.cc                   |  18 +-
 src/relax/transform/combine_parallel_matmul.cc     | 337 +++++++++++++++
 tests/python/relax/test_dataflow_pattern.py        |   6 +-
 .../test_transform_combine_parallel_matmul.py      | 469 +++++++++++++++++++++
 8 files changed, 863 insertions(+), 27 deletions(-)

diff --git a/include/tvm/relax/dataflow_matcher.h 
b/include/tvm/relax/dataflow_matcher.h
index cf7c58f093..16249377a2 100644
--- a/include/tvm/relax/dataflow_matcher.h
+++ b/include/tvm/relax/dataflow_matcher.h
@@ -58,6 +58,16 @@ Optional<Map<DFPattern, Expr>> ExtractMatchedExpr(
 TVM_DLL Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx,
                                                  const DataflowBlock& dfb);
 
+/*!
+ * \brief Rewrite a function with the given pattern and the rewriter function.
+ * \param ctx The pattern constraint context under which rewriting takes place.
+ * \param rewriter The function to be called on a successful matching for 
rewriting.
+    Given the map of patterns and corresponding variables (bound variables or 
parameters),
+    it should return a map that specifies new values for matched bound 
variables.
+ * \param f The function to rewrite.
+ * \return The rewritten or the input function, depending on the pattern 
matching result.
+ */
+TVM_DLL Function RewriteBindings(const PatternContext& ctx, PackedFunc 
rewriter, Function f);
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/include/tvm/relax/dataflow_pattern.h 
b/include/tvm/relax/dataflow_pattern.h
index e4c27f3558..68cfdd83ad 100644
--- a/include/tvm/relax/dataflow_pattern.h
+++ b/include/tvm/relax/dataflow_pattern.h
@@ -248,14 +248,12 @@ class PatternContext : public ObjectRef {
   /*! \brief Get the constraint context object on the top of the stack */
   TVM_DLL static Optional<PatternContext> Current();
 
-  class Internal;
-
- private:
   /*! \brief The RAII-like entry of a constraint context scope */
-  TVM_DLL void EnterWithScope();
+  TVM_DLL void EnterWithScope() const;
   /*! \brief The RAII-like exit of a constraint context scope */
-  TVM_DLL void ExitWithScope();
-  friend class Internal;
+  TVM_DLL void ExitWithScope() const;
+
+ private:
   friend class With<PatternContext>;
 };
 
diff --git a/python/tvm/relax/transform/transform.py 
b/python/tvm/relax/transform/transform.py
index a53c45b655..f0277151bb 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -934,6 +934,22 @@ def SplitCallTIRByPattern(patterns, fcodegen) -> 
tvm.ir.transform.Pass:
     return _ffi_api.SplitCallTIRByPattern(patterns, fcodegen)  # type: ignore
 
 
+def CombineParallelMatmul():
+    """Combine multiple matmul operators sharing the same LHS matrix into one,
+    followed by slicing. When all matmul branches in a tree have the same set 
of fused ops,
+    the fused ops are applied to the combined matmul output before slicing.
+
+    Currently, only a limited set of fused ops is supported. It includes bias 
add,
+    relu, gelu, and silu activation.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The corresponding pass.
+    """
+    return _ffi_api.CombineParallelMatmul()  # type: ignore
+
+
 def _wrap_class_function_pass(pass_cls, pass_info):
     """Wrap a python class as function pass."""
 
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index 6e8211cfd3..b06da62c26 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -22,6 +22,7 @@
  * \brief The dataflow pattern matcher for Relax.
  */
 
+#include <tvm/node/structural_equal.h>
 #include <tvm/relax/analysis.h>
 #include <tvm/relax/dataflow_matcher.h>
 #include <tvm/relax/dataflow_pattern.h>
@@ -791,7 +792,7 @@ class PatternRewriter : ExprMutator {
       : ctx_(ctx), rewriter_func_(rewriter_func), params_(params) {}
 
   template <typename PatternType>
-  static Expr Run(PatternType pat, PackedFunc rewriter_func, Function f) {
+  static Function Run(PatternType pat, PackedFunc rewriter_func, Function f) {
     std::unordered_set<const VarNode*> params;
     for (const auto& p : f->params) {
       params.insert(p.get());
@@ -868,15 +869,17 @@ class PatternRewriter : ExprMutator {
 
       std::unordered_set<const VarNode*> emitted_vars;
 
+      bool changed = false;
       for (size_t i = 0; i < block->bindings.size(); ++i) {
         const auto& binding = block->bindings[i];
         if (auto var_bind = binding.as<VarBindingNode>()) {
-          if (replacements.count(var_bind->var)) {
-            auto new_val = replacements[var_bind->var];
+          if (auto new_val = 
replacements.Get(var_bind->var).value_or(var_bind->value);
+              !StructuralEqual()(var_bind->value, new_val)) {
             Array<Binding> pending_bindings(block->bindings.begin() + i + 1, 
block->bindings.end());
             // Make sure there is no unbound variable used in the new value 
before it is emitted
             EmitUsedVars(new_val, pending_bindings, &emitted_vars);
             this->ReEmitBinding(var_bind, builder_->Normalize(new_val));
+            changed = true;
           } else if (!emitted_vars.count(var_bind->var.get())) {
             this->VisitBinding(binding);
             emitted_vars.insert(var_bind->var.get());
@@ -885,7 +888,11 @@ class PatternRewriter : ExprMutator {
           this->VisitBinding(binding);
         }
       }
-      return RewriteDataflowBlockFixedPoint(builder_->EndBlock());
+
+      auto new_block = builder_->EndBlock();
+
+      if (!changed) return new_block;
+      return RewriteDataflowBlockFixedPoint(new_block);
     }
     return block;
   }
@@ -909,15 +916,16 @@ class PatternRewriter : ExprMutator {
   std::unordered_map<const Object*, Expr> memo_;
 };
 
+Function RewriteBindings(const PatternContext& ctx, PackedFunc rewriter, 
Function f) {
+  return PatternRewriter::Run(ctx, rewriter, f);
+}
+
 TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call")
     .set_body_typed([](DFPattern pat, PackedFunc rewriter, Function f) {
       return PatternRewriter::Run(pat, rewriter, f);
     });
 
-TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings")
-    .set_body_typed([](const PatternContext& ctx, PackedFunc rewriter, 
Function f) {
-      return PatternRewriter::Run(ctx, rewriter, f);
-    });
+TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc
index 4d225ceecf..cd1376303c 100644
--- a/src/relax/ir/dataflow_pattern.cc
+++ b/src/relax/ir/dataflow_pattern.cc
@@ -412,9 +412,9 @@ PatternContext::PatternContext(bool incremental) {
   data_ = std::move(n);
 }
 
-void PatternContext::EnterWithScope() { pattern_ctx_stack().push(*this); }
+void PatternContext::EnterWithScope() const { pattern_ctx_stack().push(*this); 
}
 
-void PatternContext::ExitWithScope() {
+void PatternContext::ExitWithScope() const {
   ICHECK(pattern_ctx_stack().top().same_as(*this));
   pattern_ctx_stack().pop();
 }
@@ -610,15 +610,13 @@ 
TVM_REGISTER_GLOBAL("relax.dpl.current_context").set_body_typed([] {
   return PatternContext::Current();
 });
 
-class PatternContext::Internal {
- public:
-  static void EnterScope(PatternContext pass_ctx) { pass_ctx.EnterWithScope(); 
}
-  static void ExitScope(PatternContext pass_ctx) { pass_ctx.ExitWithScope(); }
-};
-
-TVM_REGISTER_GLOBAL("relax.dpl.enter_context").set_body_typed(PatternContext::Internal::EnterScope);
+TVM_REGISTER_GLOBAL("relax.dpl.enter_context").set_body_typed([](const 
PatternContext& ctx) {
+  ctx.EnterWithScope();
+});
 
-TVM_REGISTER_GLOBAL("relax.dpl.exit_context").set_body_typed(PatternContext::Internal::ExitScope);
+TVM_REGISTER_GLOBAL("relax.dpl.exit_context").set_body_typed([](const 
PatternContext& ctx) {
+  ctx.ExitWithScope();
+});
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/transform/combine_parallel_matmul.cc 
b/src/relax/transform/combine_parallel_matmul.cc
new file mode 100644
index 0000000000..d6435ec829
--- /dev/null
+++ b/src/relax/transform/combine_parallel_matmul.cc
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/dataflow_matcher.h>
+#include <tvm/relax/dataflow_pattern.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+
+#include <optional>
+#include <unordered_map>
+#include <vector>
+
+#include "../op/nn/nn.h"
+#include "../op/tensor/binary.h"
+#include "../op/tensor/index.h"
+#include "../op/tensor/linear_algebra.h"
+#include "../op/tensor/manipulate.h"
+
+namespace tvm {
+namespace relax {
+
+using runtime::Map;
+
+/*! \brief Group shapes of the RHS matrices by rank. Matrices in a group whose 
batch sizes
+  are compatible are combined.
+*/
+std::unordered_map<size_t, std::vector<size_t>> GroupShapes(
+    const std::vector<Array<PrimExpr>>& shapes) {
+  std::unordered_map<size_t, std::vector<size_t>> indices_map;
+  for (size_t i = 0; i < shapes.size(); ++i) {
+    indices_map[shapes[i].size()].push_back(i);
+  }
+  return indices_map;
+}
+
+inline TensorStructInfo GetTensorSInfo(Expr e) {
+  return Downcast<TensorStructInfo>(GetStructInfo(e));
+}
+
+struct BranchInfo {
+  int num_branches;
+  std::optional<int> bias_dim;
+  std::optional<std::string> activation;
+};
+
+struct Patterns {
+  Patterns() : input(Wildcard()) { ctx.EnterWithScope(); }
+
+  PatternContext ctx;
+  WildcardPattern input;
+  std::vector<WildcardPattern> rhs;
+  std::vector<WildcardPattern> bias;
+  std::vector<CallPattern> matmul, bias_add, activation;
+};
+
+Patterns CreatePatterns(const BranchInfo& branch_info) {
+  Patterns patterns;
+
+  for (int i = 0; i < branch_info.num_branches; ++i) {
+    auto w_pat = Wildcard();
+    auto matmul_pat = IsOp("relax.matmul")(patterns.input, w_pat);
+    patterns.rhs.push_back(w_pat);
+    patterns.matmul.push_back(matmul_pat);
+    patterns.ctx.add_constraint(patterns.input, matmul_pat, 
PairCons(PairCons::kUsedBy, 0));
+    patterns.ctx.add_constraint(w_pat, matmul_pat, PairCons(PairCons::kUsedBy, 
1));
+
+    CallPattern matmul_out = matmul_pat;
+
+    if (branch_info.bias_dim) {
+      auto bias_pat = Wildcard();
+      auto bias_add_pat = IsOp("relax.add")(matmul_pat, bias_pat);
+      patterns.bias.push_back(bias_pat);
+      patterns.bias_add.push_back(bias_add_pat);
+      patterns.ctx.add_constraint(matmul_pat, bias_add_pat, 
PairCons(PairCons::kUsedBy, 0));
+      patterns.ctx.add_constraint(bias_pat, bias_add_pat, 
PairCons(PairCons::kUsedBy, 1));
+      matmul_out = bias_add_pat;
+    }
+
+    if (branch_info.activation) {
+      auto act_pat = IsOp(*branch_info.activation)(matmul_out);
+      patterns.activation.push_back(act_pat);
+      patterns.ctx.add_constraint(matmul_out, act_pat, 
PairCons(PairCons::kUsedBy, 0));
+    }
+  }
+
+  return patterns;
+}
+
+/*! \brief Create a rewriter for the given parallel matmul branches. */
+runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>)> GetRewriter(
+    const Patterns& patterns, const BranchInfo& branch_info) {
+  auto batch_dims_compatible = [](size_t rhs_dim, const std::vector<size_t>& 
indices,
+                                  const std::vector<Array<PrimExpr>>& 
rhs_shapes) {
+    arith::Analyzer ana;
+    for (auto ind : indices) {
+      ICHECK_EQ(static_cast<int>(rhs_shapes[ind].size()), rhs_dim);
+      // -2 for reduction and concat axes
+      for (size_t i = 0; i < rhs_dim - 2; ++i) {
+        if (!ana.CanProve(rhs_shapes[indices[0]][i] == rhs_shapes[ind][i])) {
+          return false;
+        }
+      }
+    }
+    return true;
+  };
+
+  return [=](Map<DFPattern, Var> matchings) {
+    std::vector<Array<PrimExpr>> rhs_shapes;
+    for (const auto& rhs_pat : patterns.rhs) {
+      auto rhs_shape_opt = GetTensorSInfo(matchings[rhs_pat])->GetShape();
+      if (!rhs_shape_opt) {
+        return Map<Var, Expr>{};
+      }
+      rhs_shapes.push_back(rhs_shape_opt.value());
+    }
+
+    Map<Var, Expr> replacements;
+
+    for (const auto& [rhs_dim, indices] : GroupShapes(rhs_shapes)) {
+      if (indices.size() == 1 || !batch_dims_compatible(rhs_dim, indices, 
rhs_shapes)) continue;
+
+      Array<Expr> rhs, bias;
+      for (auto ind : indices) {
+        rhs.push_back(matchings[patterns.rhs[ind]]);
+        if (branch_info.bias_dim) {
+          ICHECK(matchings.count(patterns.bias[ind]));
+          bias.push_back(matchings[patterns.bias[ind]]);
+        }
+      }
+
+      auto inp = matchings[patterns.input];
+      auto concat_rhs = concat(Tuple(rhs), Integer(rhs_dim - 1));
+      auto out_dtype = 
GetTensorSInfo(matchings[patterns.matmul[indices[0]]])->dtype;
+      auto matmul_combined = matmul(inp, concat_rhs, out_dtype);
+
+      const auto& pattern_to_replace = [&patterns, &branch_info]() {
+        if (branch_info.activation) return patterns.activation;
+        if (branch_info.bias_dim) return patterns.bias_add;
+        return patterns.matmul;
+      }();
+
+      if (branch_info.bias_dim) {
+        auto bias_dim = GetTensorSInfo(bias[0])->ndim;
+        auto concat_bias = concat(Tuple(bias), Integer(bias_dim - 1));
+        matmul_combined = add(matmul_combined, concat_bias);
+      }
+
+      if (branch_info.activation) {
+        if (*branch_info.activation == "relax.nn.relu") {
+          matmul_combined = relu(matmul_combined);
+        } else if (*branch_info.activation == "relax.nn.gelu") {
+          matmul_combined = gelu(matmul_combined);
+        } else if (*branch_info.activation == "relax.nn.silu") {
+          matmul_combined = silu(matmul_combined);
+        } else {
+          LOG(FATAL) << "Unsupported activation: " << *branch_info.activation;
+        }
+      }
+
+      PrimExpr begin{0};
+      Array<PrimExpr> strides{1};
+      int lhs_dim = GetTensorSInfo(inp)->ndim;
+      int slice_axis = std::max<int>(lhs_dim, rhs_dim) - 1;
+
+      for (size_t i = 0; i < indices.size(); ++i) {
+        auto width = GetTensorSInfo(rhs[i])->GetShape().value()[rhs_dim - 1];
+        auto bound_var = matchings[pattern_to_replace[indices[i]]];
+        auto slice =
+            strided_slice(matmul_combined, {slice_axis}, {begin}, {begin + 
width}, strides);
+        replacements.Set(bound_var, slice);
+        begin += width;
+      }
+    }
+
+    return replacements;
+  };
+}
+
+Function Rewrite(Function f, const BranchInfo& branch_info) {
+  auto patterns = CreatePatterns(branch_info);
+  auto rewriter = GetRewriter(patterns, branch_info);
+  return RewriteBindings(patterns.ctx, rewriter, f);
+}
+
+/*! \brief Look for subtrees with parallel matmul and return information about
+  them (the number of branches and the kind of fused ops)
+*/
+std::vector<BranchInfo> GetBranchInfo(Function f) {
+  auto bias_pat = Wildcard();
+  auto matmul_pat = IsOp("relax.matmul")(Wildcard(), Wildcard());
+  auto bias_add_pat = IsOp("relax.add")(matmul_pat, bias_pat);
+
+  std::vector<std::string> activations{"relax.nn.relu", "relax.nn.gelu", 
"relax.nn.silu"};
+
+  std::vector<DFPattern> activation_pat, bias_activation_pat;
+  for (const auto& act : activations) {
+    activation_pat.push_back(IsOp(act)(matmul_pat));
+    bias_activation_pat.push_back(IsOp(act)(bias_add_pat));
+  }
+
+  auto bindings = AnalyzeVar2Value(f);
+
+  auto create_group = [&](DFPattern pat) {
+    // Maps a LHS matrix to consumer parallel matmuls
+    std::unordered_map<const VarNode*, BranchInfo> groups;
+
+    PostOrderVisit(f, [&](const Expr& e) {
+      if (!e->IsInstance<CallNode>()) return;
+      if (auto match = ExtractMatchedExpr(pat, e, bindings)) {
+        auto matmul_call = Downcast<Call>(match.value()[matmul_pat]);
+        auto matmul_lhs = Downcast<Var>(matmul_call->args[0]);
+
+        auto it = groups.find(matmul_lhs.get());
+        BranchInfo* branch = it != groups.end() ? &it->second : nullptr;
+        std::optional<int> bias_dim = std::nullopt;
+        std::optional<std::string> activation = std::nullopt;
+
+        if (match.value().count(bias_pat)) {
+          bias_dim = GetTensorSInfo(match.value()[bias_pat])->ndim;
+        }
+
+        for (size_t i = 0; i < activations.size(); ++i) {
+          if (match.value().count(activation_pat[i]) ||
+              match.value().count(bias_activation_pat[i])) {
+            activation = activations[i];
+          }
+        }
+
+        if (!branch) {
+          // Create a new subgraph with one matmul
+          groups[matmul_lhs.get()] = {1, bias_dim, activation};
+        } else {
+          // Create a new branch in the existing parallel matmul subtree, and
+          // invalidate bias and activation information when needed.
+          branch->num_branches += 1;
+
+          if (!bias_dim || (branch->bias_dim && *branch->bias_dim != 
*bias_dim)) {
+            branch->bias_dim = std::nullopt;
+          }
+
+          if (!activation || (branch->activation && *branch->activation != 
*activation)) {
+            branch->activation = std::nullopt;
+          }
+        }
+        return;
+      }
+    });
+
+    return groups;
+  };
+
+  std::unordered_map<const VarNode*, BranchInfo> groups_activation;
+  for (size_t i = 0; i < activations.size(); ++i) {
+    auto groups = create_group(bias_activation_pat[i]);
+    groups_activation.merge(std::move(groups));
+  }
+
+  for (size_t i = 0; i < activations.size(); ++i) {
+    auto groups = create_group(activation_pat[i]);
+    groups_activation.merge(std::move(groups));
+  }
+
+  auto groups_bias = create_group(bias_add_pat);
+  auto groups_matmul = create_group(matmul_pat);
+
+  for (const auto& groups : {groups_bias, groups_activation}) {
+    for (const auto& [lhs, branch] : groups) {
+      // Prefer combining more matmuls than combining fewer ones and leaving 
additional uncombined
+      // matmuls followed by bias or activation. So we combine matmuls + fused 
ops patterns only
+      // when all branches have the same fused ops.
+      if (auto it = groups_matmul.find(lhs);
+          it != groups_matmul.end() && it->second.num_branches == 
branch.num_branches) {
+        it->second = branch;
+      }
+    }
+  }
+
+  std::vector<BranchInfo> info;
+
+  for (const auto& groups : {groups_matmul, groups_activation, groups_bias}) {
+    for (const auto& group : groups) {
+      if (group.second.num_branches > 1) {
+        info.push_back(group.second);
+      }
+    }
+  }
+
+  return info;
+}
+
+Function CombineParallelMatmul(Function f) {
+  auto branches = GetBranchInfo(f);
+  std::sort(branches.begin(), branches.end(),
+            [](const auto& b1, const auto& b2) { return b1.num_branches > 
b2.num_branches; });
+
+  for (const auto& branch : branches) {
+    f = Rewrite(f, branch);
+  }
+  return f;
+}
+
+namespace transform {
+
+Pass CombineParallelMatmul() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> 
pass_func =
+      [=](Function f, IRModule m, PassContext pc) { return 
relax::CombineParallelMatmul(f); };
+  return CreateFunctionPass(/*pass_function=*/pass_func,            //
+                            /*opt_level=*/0,                        //
+                            /*pass_name=*/"CombineParallelMatmul",  //
+                            /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.CombineParallelMatmul").set_body_typed(CombineParallelMatmul);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_dataflow_pattern.py 
b/tests/python/relax/test_dataflow_pattern.py
index a73a62eeef..ed221f54be 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -18,7 +18,7 @@
 import pytest
 import tvm.testing
 
-from tvm import relay, relax
+from tvm import relay
 from tvm.relax.dpl import *
 from tvm.relax.analysis import get_var2val
 from tvm import relax as rx, tir
@@ -1177,9 +1177,9 @@ def test_combine_matmul_emit_order():
         # make sure it builds
         mod = tvm.IRModule()
         mod["main"] = rewritten
-        mod = relax.transform.LegalizeOps()(mod)
+        mod = rx.transform.LegalizeOps()(mod)
 
-        relax.build(mod, target="llvm")
+        rx.build(mod, target="llvm")
 
 
 if __name__ == "__main__":
diff --git a/tests/python/relax/test_transform_combine_parallel_matmul.py 
b/tests/python/relax/test_transform_combine_parallel_matmul.py
new file mode 100644
index 0000000000..f5cc269620
--- /dev/null
+++ b/tests/python/relax/test_transform_combine_parallel_matmul.py
@@ -0,0 +1,469 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm.testing
+
+from tvm import relax, tir
+from tvm.script import relax as R, tir as T
+from tvm.relax.transform import CombineParallelMatmul
+from tvm.script.ir_builder import IRBuilder
+from tvm.script.ir_builder import relax as relax_builder
+
+
+def get_parallel_matmul(
+    num_branches,
+    lhs_shape=(640, 640),
+    rhs_shape=(640, 640),
+    with_bias=None,
+    activation=None,
+):
+    dtype = "float32"
+
+    activation_map = {"relu": R.nn.relu, "gelu": R.nn.gelu}
+
+    with IRBuilder() as builder:
+        with relax_builder.function():
+            R.func_name("main")
+            x = R.arg("x", R.Tensor(lhs_shape, dtype))
+
+            rhs = []
+            bias = []
+
+            for i in range(num_branches):
+                rhs.append(R.arg("y", R.Tensor(rhs_shape, dtype)))
+
+                if with_bias and with_bias[i]:
+                    bias.append(R.arg("bias", R.Tensor((rhs_shape[1],), 
dtype)))
+                else:
+                    bias.append(None)
+
+            with R.dataflow() as frame:
+                branches = []
+
+                for i, r in enumerate(rhs):
+                    result = R.emit(R.matmul(x, r, out_dtype=dtype))
+                    if bias[i]:
+                        result = R.emit(result + bias[i])
+                    if activation and activation[i]:
+                        result = R.emit(activation_map[activation[i]](result))
+
+                    branches.append(result)
+
+                R.output(R.emit(R.concat(branches, axis=1)))
+
+            R.func_ret_value(frame.output_vars[0])
+
+    func = builder.get()
+    return tvm.IRModule({"main": func})
+
+
+def test_simple():
+    mod_orig = get_parallel_matmul(1)
+    mod = CombineParallelMatmul()(mod_orig)
+
+    tvm.ir.assert_structural_equal(mod, mod_orig)
+
+    mod = get_parallel_matmul(3)
+    mod = CombineParallelMatmul()(mod)
+
+    @R.function
+    def expected1(
+        x: R.Tensor((640, 640), dtype="float32"),
+        y: R.Tensor((640, 640), dtype="float32"),
+        y_1: R.Tensor((640, 640), dtype="float32"),
+        y_2: R.Tensor((640, 640), dtype="float32"),
+    ) -> R.Tensor((640, 1920), dtype="float32"):
+        with R.dataflow():
+            lv = R.concat((y, y_1, y_2), axis=1)
+            lv1 = R.matmul(x, lv, out_dtype="float32")
+            lv_1 = R.strided_slice(lv1, axes=[1], begin=[0], end=[640], 
strides=[1])
+            lv1_1 = R.strided_slice(lv1, axes=[1], begin=[640], end=[1280], 
strides=[1])
+            lv2 = R.strided_slice(lv1, axes=[1], begin=[1280], end=[1920], 
strides=[1])
+            lv3 = R.concat((lv_1, lv1_1, lv2), axis=1)
+            R.output(lv3)
+        return lv3
+
+    tvm.ir.assert_structural_equal(mod["main"], expected1)
+
+    # Test a batched LHS case, slicing is done on the axis 2
+    mod = get_parallel_matmul(3, lhs_shape=(2, 1024, 640))
+    mod = CombineParallelMatmul()(mod)
+
+    @R.function
+    def expected2(
+        x: R.Tensor((2, 1024, 640), dtype="float32"),
+        y: R.Tensor((640, 640), dtype="float32"),
+        y_1: R.Tensor((640, 640), dtype="float32"),
+        y_2: R.Tensor((640, 640), dtype="float32"),
+    ) -> R.Tensor((2, 3072, 640), dtype="float32"):
+        with R.dataflow():
+            lv = R.concat((y, y_1, y_2), axis=1)
+            lv1 = R.matmul(x, lv, out_dtype="float32")
+            lv_1 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640], 
strides=[1])
+            lv1_1 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280], 
strides=[1])
+            lv2 = R.strided_slice(lv1, axes=[2], begin=[1280], end=[1920], 
strides=[1])
+            lv3 = R.concat((lv_1, lv1_1, lv2), axis=1)
+            R.output(lv3)
+        return lv3
+
+    tvm.ir.assert_structural_equal(mod["main"], expected2)
+
+
+def test_bias():
+    mod = get_parallel_matmul(3, with_bias=[True, True, True])
+    mod = CombineParallelMatmul()(mod)
+
+    @R.function
+    def expected1(
+        x: R.Tensor((640, 640), dtype="float32"),
+        y: R.Tensor((640, 640), dtype="float32"),
+        bias: R.Tensor((640,), dtype="float32"),
+        y_1: R.Tensor((640, 640), dtype="float32"),
+        bias_1: R.Tensor((640,), dtype="float32"),
+        y_2: R.Tensor((640, 640), dtype="float32"),
+        bias_2: R.Tensor((640,), dtype="float32"),
+    ) -> R.Tensor((640, 1920), dtype="float32"):
+        with R.dataflow():
+            lv = R.concat((y, y_1, y_2), axis=1)
+            lv1 = R.matmul(x, lv, out_dtype="float32")
+            lv2 = R.concat((bias, bias_1, bias_2), axis=0)
+            lv3 = R.add(lv1, lv2)
+            lv1_1 = R.strided_slice(lv3, axes=[1], begin=[0], end=[640], 
strides=[1])
+            lv3_1 = R.strided_slice(lv3, axes=[1], begin=[640], end=[1280], 
strides=[1])
+            lv5 = R.strided_slice(lv3, axes=[1], begin=[1280], end=[1920], 
strides=[1])
+            lv6 = R.concat((lv1_1, lv3_1, lv5), axis=1)
+            R.output(lv6)
+        return lv6
+
+    tvm.ir.assert_structural_equal(mod["main"], expected1)
+
+    mod = get_parallel_matmul(3, with_bias=[True, False, True])
+    mod = CombineParallelMatmul()(mod)
+
+    @R.function
+    def expected2(
+        x: R.Tensor((640, 640), dtype="float32"),
+        y: R.Tensor((640, 640), dtype="float32"),
+        bias: R.Tensor((640,), dtype="float32"),
+        y_1: R.Tensor((640, 640), dtype="float32"),
+        y_2: R.Tensor((640, 640), dtype="float32"),
+        bias_1: R.Tensor((640,), dtype="float32"),
+    ) -> R.Tensor((640, 1920), dtype="float32"):
+        with R.dataflow():
+            lv = R.concat((y, y_1, y_2), axis=1)
+            lv1 = R.matmul(x, lv, out_dtype="float32")
+            lv_1 = R.strided_slice(lv1, axes=[1], begin=[0], end=[640], 
strides=[1])
+            lv1_1 = R.add(lv_1, bias)
+            lv2 = R.strided_slice(lv1, axes=[1], begin=[640], end=[1280], 
strides=[1])
+            lv3 = R.strided_slice(lv1, axes=[1], begin=[1280], end=[1920], 
strides=[1])
+            lv4 = R.add(lv3, bias_1)
+            lv5 = R.concat((lv1_1, lv2, lv4), axis=1)
+            R.output(lv5)
+        return lv5
+
+    tvm.ir.assert_structural_equal(mod["main"], expected2)
+
+
+def test_activation():
+    mod = get_parallel_matmul(3, activation=["relu", "relu", "relu"])
+    mod = CombineParallelMatmul()(mod)
+
+    @R.function
+    def expected1(
+        x: R.Tensor((640, 640), dtype="float32"),
+        y: R.Tensor((640, 640), dtype="float32"),
+        y_1: R.Tensor((640, 640), dtype="float32"),
+        y_2: R.Tensor((640, 640), dtype="float32"),
+    ) -> R.Tensor((640, 1920), dtype="float32"):
+        with R.dataflow():
+            lv = R.concat((y, y_1, y_2), axis=1)
+            lv1 = R.matmul(x, lv, out_dtype="float32")
+            lv2 = R.nn.relu(lv1)
+            lv1_1 = R.strided_slice(lv2, axes=[1], begin=[0], end=[640], 
strides=[1])
+            lv3 = R.strided_slice(lv2, axes=[1], begin=[640], end=[1280], 
strides=[1])
+            lv5 = R.strided_slice(lv2, axes=[1], begin=[1280], end=[1920], 
strides=[1])
+            lv6 = R.concat((lv1_1, lv3, lv5), axis=1)
+            R.output(lv6)
+        return lv6
+
+    tvm.ir.assert_structural_equal(mod["main"], expected1)
+
+    mod = get_parallel_matmul(3, activation=["gelu", "relu", "relu"])
+    mod = CombineParallelMatmul()(mod)
+
+    @R.function
+    def expected2(
+        x: R.Tensor((640, 640), dtype="float32"),
+        y: R.Tensor((640, 640), dtype="float32"),
+        y_1: R.Tensor((640, 640), dtype="float32"),
+        y_2: R.Tensor((640, 640), dtype="float32"),
+    ) -> R.Tensor((640, 1920), dtype="float32"):
+        with R.dataflow():
+            lv = R.concat((y, y_1, y_2), axis=1)
+            lv1 = R.matmul(x, lv, out_dtype="float32")
+            lv_1 = R.strided_slice(lv1, axes=[1], begin=[0], end=[640], 
strides=[1])
+            lv1_1 = R.nn.gelu(lv_1)
+            lv2 = R.strided_slice(lv1, axes=[1], begin=[640], end=[1280], 
strides=[1])
+            lv3 = R.nn.relu(lv2)
+            lv4 = R.strided_slice(lv1, axes=[1], begin=[1280], end=[1920], 
strides=[1])
+            lv5 = R.nn.relu(lv4)
+            lv6 = R.concat((lv1_1, lv3, lv5), axis=1)
+            R.output(lv6)
+        return lv6
+
+    tvm.ir.assert_structural_equal(mod["main"], expected2)
+
+    mod = get_parallel_matmul(3, activation=["relu", None, None])
+    mod = CombineParallelMatmul()(mod)
+
+    @R.function
+    def expected3(
+        x: R.Tensor((640, 640), dtype="float32"),
+        y: R.Tensor((640, 640), dtype="float32"),
+        y_1: R.Tensor((640, 640), dtype="float32"),
+        y_2: R.Tensor((640, 640), dtype="float32"),
+    ) -> R.Tensor((640, 1920), dtype="float32"):
+        with R.dataflow():
+            lv = R.concat((y, y_1, y_2), axis=1)
+            lv1 = R.matmul(x, lv, out_dtype="float32")
+            lv_1 = R.strided_slice(lv1, axes=[1], begin=[0], end=[640], strides=[1])
+            lv1_1 = R.nn.relu(lv_1)
+            lv2 = R.strided_slice(lv1, axes=[1], begin=[640], end=[1280], strides=[1])
+            lv3 = R.strided_slice(lv1, axes=[1], begin=[1280], end=[1920], strides=[1])
+            lv4 = R.concat((lv1_1, lv2, lv3), axis=1)
+            R.output(lv4)
+        return lv4
+
+    tvm.ir.assert_structural_equal(mod["main"], expected3)
+
+
+def test_bias_activation():
+    mod = get_parallel_matmul(3, with_bias=[True, True, True], activation=["relu", "relu", "relu"])
+    mod = CombineParallelMatmul()(mod)
+
+    @R.function
+    def expected1(
+        x: R.Tensor((640, 640), dtype="float32"),
+        y: R.Tensor((640, 640), dtype="float32"),
+        bias: R.Tensor((640,), dtype="float32"),
+        y_1: R.Tensor((640, 640), dtype="float32"),
+        bias_1: R.Tensor((640,), dtype="float32"),
+        y_2: R.Tensor((640, 640), dtype="float32"),
+        bias_2: R.Tensor((640,), dtype="float32"),
+    ) -> R.Tensor((640, 1920), dtype="float32"):
+        with R.dataflow():
+            lv = R.concat((y, y_1, y_2), axis=1)
+            lv1 = R.matmul(x, lv, out_dtype="float32")
+            lv2 = R.concat((bias, bias_1, bias_2), axis=0)
+            lv3 = R.add(lv1, lv2)
+            lv4 = R.nn.relu(lv3)
+            lv2_1 = R.strided_slice(lv4, axes=[1], begin=[0], end=[640], strides=[1])
+            lv5 = R.strided_slice(lv4, axes=[1], begin=[640], end=[1280], strides=[1])
+            lv8 = R.strided_slice(lv4, axes=[1], begin=[1280], end=[1920], strides=[1])
+            lv9 = R.concat((lv2_1, lv5, lv8), axis=1)
+            R.output(lv9)
+        return lv9
+
+    tvm.ir.assert_structural_equal(mod["main"], expected1)
+
+    mod = get_parallel_matmul(3, with_bias=[True, True, True], activation=["relu", None, "relu"])
+    mod = CombineParallelMatmul()(mod)
+
+    @R.function
+    def expected2(
+        x: R.Tensor((640, 640), dtype="float32"),
+        y: R.Tensor((640, 640), dtype="float32"),
+        bias: R.Tensor((640,), dtype="float32"),
+        y_1: R.Tensor((640, 640), dtype="float32"),
+        bias_1: R.Tensor((640,), dtype="float32"),
+        y_2: R.Tensor((640, 640), dtype="float32"),
+        bias_2: R.Tensor((640,), dtype="float32"),
+    ) -> R.Tensor((640, 1920), dtype="float32"):
+        with R.dataflow():
+            lv = R.concat((y, y_1, y_2), axis=1)
+            lv1 = R.matmul(x, lv, out_dtype="float32")
+            lv2 = R.concat((bias, bias_1, bias_2), axis=0)
+            lv3 = R.add(lv1, lv2)
+            lv1_1 = R.strided_slice(lv3, axes=[1], begin=[0], end=[640], strides=[1])
+            lv2_1 = R.nn.relu(lv1_1)
+            lv4 = R.strided_slice(lv3, axes=[1], begin=[640], end=[1280], strides=[1])
+            lv6 = R.strided_slice(lv3, axes=[1], begin=[1280], end=[1920], strides=[1])
+            lv7 = R.nn.relu(lv6)
+            lv8 = R.concat((lv2_1, lv4, lv7), axis=1)
+            R.output(lv8)
+        return lv8
+
+    tvm.ir.assert_structural_equal(mod["main"], expected2)
+
+    mod = get_parallel_matmul(3, with_bias=[True, False, True], activation=["relu", None, "relu"])
+    mod = CombineParallelMatmul()(mod)
+
+    @R.function
+    def expected3(
+        x: R.Tensor((640, 640), dtype="float32"),
+        y: R.Tensor((640, 640), dtype="float32"),
+        bias: R.Tensor((640,), dtype="float32"),
+        y_1: R.Tensor((640, 640), dtype="float32"),
+        y_2: R.Tensor((640, 640), dtype="float32"),
+        bias_1: R.Tensor((640,), dtype="float32"),
+    ) -> R.Tensor((640, 1920), dtype="float32"):
+        with R.dataflow():
+            lv = R.concat((y, y_1, y_2), axis=1)
+            lv1 = R.matmul(x, lv, out_dtype="float32")
+            lv_1 = R.strided_slice(lv1, axes=[1], begin=[0], end=[640], strides=[1])
+            lv1_1 = R.add(lv_1, bias)
+            lv2 = R.nn.relu(lv1_1)
+            lv3 = R.strided_slice(lv1, axes=[1], begin=[640], end=[1280], strides=[1])
+            lv4 = R.strided_slice(lv1, axes=[1], begin=[1280], end=[1920], strides=[1])
+            lv5 = R.add(lv4, bias_1)
+            lv6 = R.nn.relu(lv5)
+            lv7 = R.concat((lv2, lv3, lv6), axis=1)
+            R.output(lv7)
+        return lv7
+
+    tvm.ir.assert_structural_equal(mod["main"], expected3)
+
+
+def test_rhs_batched():
+    @tvm.script.ir_module
+    class four_matmul:
+        @R.function
+        def main(
+            x: R.Tensor((1024, 640), "float32"),
+            w0: R.Tensor((2, 640, 640), "float32"),
+            w1: R.Tensor((640, 640), "float32"),
+            w2: R.Tensor((2, 640, 640), "float32"),
+            w3: R.Tensor((3, 4, 640, 640), "float32"),
+        ) -> R.Tensor:
+            with R.dataflow():
+                lv0 = R.matmul(x, w0)
+                lv1 = R.matmul(x, w1)
+                lv2 = R.matmul(x, w2)
+                lv3 = R.matmul(x, w3)
+                out = (lv0, lv1, lv2, lv3)
+                R.output(out)
+            return out
+
+    mod = CombineParallelMatmul()(four_matmul)
+
+    @R.function
+    def expected1(
+        x: R.Tensor((1024, 640), dtype="float32"),
+        w0: R.Tensor((2, 640, 640), dtype="float32"),
+        w1: R.Tensor((640, 640), dtype="float32"),
+        w2: R.Tensor((2, 640, 640), dtype="float32"),
+        w3: R.Tensor((3, 4, 640, 640), dtype="float32"),
+    ) -> R.Tensor:
+        with R.dataflow():
+            lv = R.concat((w0, w2), axis=2)
+            lv1 = R.matmul(x, lv, out_dtype="float32")
+            lv0 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640], strides=[1])
+            lv1_1 = R.matmul(x, w1, out_dtype="void")
+            lv2 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280], strides=[1])
+            lv3 = R.matmul(x, w3, out_dtype="void")
+            out = lv0, lv1_1, lv2, lv3
+            R.output(out)
+        return out
+
+    tvm.ir.assert_structural_equal(mod["main"], expected1)
+
+    @tvm.script.ir_module
+    class four_matmul_incompatible_batches:
+        @R.function
+        def main(
+            x: R.Tensor((1024, 640), "float32"),
+            w0: R.Tensor((2, 640, 640), "float32"),
+            w1: R.Tensor((3, 640, 640), "float32"),
+            w2: R.Tensor((2, 640, 640), "float32"),
+            w3: R.Tensor((2, 640, 640), "float32"),
+        ) -> R.Tensor:
+            with R.dataflow():
+                lv0 = R.matmul(x, w0)
+                lv1 = R.matmul(x, w1)
+                lv2 = R.matmul(x, w2)
+                lv3 = R.matmul(x, w3)
+                out = (lv0, lv1, lv2, lv3)
+                R.output(out)
+            return out
+
+    mod = CombineParallelMatmul()(four_matmul_incompatible_batches)
+    # For now, when rhs matrices have the same rank but different batch sizes, we don't
+    # combine any of them.
+    tvm.ir.assert_structural_equal(mod, four_matmul_incompatible_batches)
+
+
+def test_multiple_combine():
+    @tvm.script.ir_module
+    class multiple_combine:
+        @R.function
+        def main(
+            x1: R.Tensor((2, 1024, 640), "float32"),
+            x2: R.Tensor((2, 1024, 640), "float32"),
+            w0: R.Tensor((640, 640), "float32"),
+            w1: R.Tensor((640, 640), "float32"),
+            w2: R.Tensor((640, 640), "float32"),
+            w3: R.Tensor((640, 640), "float32"),
+            w4: R.Tensor((640, 640), "float32"),
+            b0: R.Tensor((640,), "float32"),
+            b1: R.Tensor((640,), "float32"),
+        ) -> R.Tensor:
+            with R.dataflow():
+                lv0 = R.matmul(x1, w0)
+                lv3 = R.matmul(x2, w3)
+                lv1 = R.matmul(x1, w1)
+                lv5 = R.add(lv3, b0)
+                lv2 = R.matmul(x1, w2)
+                lv4 = R.matmul(x2, w4)
+                lv6 = R.add(lv4, b1)
+                out = (lv0, lv1, lv2, lv5, lv6)
+                R.output(out)
+            return out
+
+    mod = CombineParallelMatmul()(multiple_combine)
+
+    @R.function
+    def expected1(
+        x1: R.Tensor((2, 1024, 640), dtype="float32"),
+        x2: R.Tensor((2, 1024, 640), dtype="float32"),
+        w0: R.Tensor((640, 640), dtype="float32"),
+        w1: R.Tensor((640, 640), dtype="float32"),
+        w2: R.Tensor((640, 640), dtype="float32"),
+        w3: R.Tensor((640, 640), dtype="float32"),
+        w4: R.Tensor((640, 640), dtype="float32"),
+        b0: R.Tensor((640,), dtype="float32"),
+        b1: R.Tensor((640,), dtype="float32"),
+    ) -> R.Tensor:
+        with R.dataflow():
+            lv = R.concat((w0, w1, w2), axis=1)
+            lv1 = R.matmul(x1, lv, out_dtype="float32")
+            lv0 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640], strides=[1])
+            lv1_1 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280], strides=[1])
+            lv_1 = R.concat((w3, w4), axis=1)
+            lv1_2 = R.matmul(x2, lv_1, out_dtype="float32")
+            lv2 = R.concat((b0, b1), axis=0)
+            lv3 = R.add(lv1_2, lv2)
+            lv5 = R.strided_slice(lv3, axes=[2], begin=[0], end=[640], strides=[1])
+            lv2_1 = R.strided_slice(lv1, axes=[2], begin=[1280], end=[1920], strides=[1])
+            lv6 = R.strided_slice(lv3, axes=[2], begin=[640], end=[1280], strides=[1])
+            out = lv0, lv1_1, lv2_1, lv5, lv6
+            R.output(out)
+        return out
+
+    tvm.ir.assert_structural_equal(mod["main"], expected1)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()


Reply via email to