This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new a59c09de77 [Unity] Add pass for combining parallel matmul (#14583)
a59c09de77 is described below
commit a59c09de77bbab5514dc4ff3b1713f31113def51
Author: masahi <[email protected]>
AuthorDate: Thu Apr 13 05:40:53 2023 +0900
[Unity] Add pass for combining parallel matmul (#14583)
* stub
* make EnterWithScope and ExitWithScope public
* qkv combining works
* automatic branch extraction
* fix hardcoded concat axis
* handle non-uniform rhs ranks
* wip
* improve termination check in binding rewriter
* wip
* properly handle rhs with same rank but different batch size
* support bias and activation
* refactor
* fixed activation handling
* wip
* clean
* fix bias and activation combine logic
* add tests
* add comment
* add doc
* fix compile warning
---
include/tvm/relax/dataflow_matcher.h | 10 +
include/tvm/relax/dataflow_pattern.h | 10 +-
python/tvm/relax/transform/transform.py | 16 +
src/relax/ir/dataflow_matcher.cc | 24 +-
src/relax/ir/dataflow_pattern.cc | 18 +-
src/relax/transform/combine_parallel_matmul.cc | 337 +++++++++++++++
tests/python/relax/test_dataflow_pattern.py | 6 +-
.../test_transform_combine_parallel_matmul.py | 469 +++++++++++++++++++++
8 files changed, 863 insertions(+), 27 deletions(-)
diff --git a/include/tvm/relax/dataflow_matcher.h
b/include/tvm/relax/dataflow_matcher.h
index cf7c58f093..16249377a2 100644
--- a/include/tvm/relax/dataflow_matcher.h
+++ b/include/tvm/relax/dataflow_matcher.h
@@ -58,6 +58,16 @@ Optional<Map<DFPattern, Expr>> ExtractMatchedExpr(
TVM_DLL Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx,
const DataflowBlock& dfb);
+/**
+ * \brief Rewrite a function with the given pattern and the rewriter function.
+ * \param ctx The pattern constraint context under which rewriting takes place.
+ * \param rewriter The function to be called on a successful matching for
rewriting.
+ Given the map of patterns and corresponding variables (bound variables or
parameters),
+ it should return a map that specifies new values for matched bound
variables.
+ * \param f The function to rewrite
+ * \return The rewritten or the input function, depending on the pattern
matching result.
+ */
+TVM_DLL Function RewriteBindings(const PatternContext& ctx, PackedFunc
rewriter, Function f);
} // namespace relax
} // namespace tvm
diff --git a/include/tvm/relax/dataflow_pattern.h
b/include/tvm/relax/dataflow_pattern.h
index e4c27f3558..68cfdd83ad 100644
--- a/include/tvm/relax/dataflow_pattern.h
+++ b/include/tvm/relax/dataflow_pattern.h
@@ -248,14 +248,12 @@ class PatternContext : public ObjectRef {
/*! \brief Get the constraint context object on the top of the stack */
TVM_DLL static Optional<PatternContext> Current();
- class Internal;
-
- private:
/*! \brief The RAII-like entry of a constraint context scope */
- TVM_DLL void EnterWithScope();
+ TVM_DLL void EnterWithScope() const;
/*! \brief The RAII-like exit of a constraint context scope */
- TVM_DLL void ExitWithScope();
- friend class Internal;
+ TVM_DLL void ExitWithScope() const;
+
+ private:
friend class With<PatternContext>;
};
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index a53c45b655..f0277151bb 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -934,6 +934,22 @@ def SplitCallTIRByPattern(patterns, fcodegen) ->
tvm.ir.transform.Pass:
return _ffi_api.SplitCallTIRByPattern(patterns, fcodegen) # type: ignore
+def CombineParallelMatmul():
+ """Combine multiple matmul operators sharing the same LHS matrix into one,
+ followed by slicing. When all matmul branches in a tree have the same set
of fused ops,
+ the fused ops are applied to the combined matmul output before slicing.
+
+ Currently, only a limited set of fused ops is supported. It includes bias
add,
+ relu, gelu, and silu activations.
+
+ Returns
+ -------
+ ret : tvm.transform.Pass
+ The corresponding pass.
+ """
+ return _ffi_api.CombineParallelMatmul() # type: ignore
+
+
def _wrap_class_function_pass(pass_cls, pass_info):
"""Wrap a python class as function pass."""
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index 6e8211cfd3..b06da62c26 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -22,6 +22,7 @@
* \brief The dataflow pattern matcher for Relax.
*/
+#include <tvm/node/structural_equal.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/dataflow_matcher.h>
#include <tvm/relax/dataflow_pattern.h>
@@ -791,7 +792,7 @@ class PatternRewriter : ExprMutator {
: ctx_(ctx), rewriter_func_(rewriter_func), params_(params) {}
template <typename PatternType>
- static Expr Run(PatternType pat, PackedFunc rewriter_func, Function f) {
+ static Function Run(PatternType pat, PackedFunc rewriter_func, Function f) {
std::unordered_set<const VarNode*> params;
for (const auto& p : f->params) {
params.insert(p.get());
@@ -868,15 +869,17 @@ class PatternRewriter : ExprMutator {
std::unordered_set<const VarNode*> emitted_vars;
+ bool changed = false;
for (size_t i = 0; i < block->bindings.size(); ++i) {
const auto& binding = block->bindings[i];
if (auto var_bind = binding.as<VarBindingNode>()) {
- if (replacements.count(var_bind->var)) {
- auto new_val = replacements[var_bind->var];
+ if (auto new_val =
replacements.Get(var_bind->var).value_or(var_bind->value);
+ !StructuralEqual()(var_bind->value, new_val)) {
Array<Binding> pending_bindings(block->bindings.begin() + i + 1,
block->bindings.end());
// Make sure there is no unbound variable used in the new value
before it is emitted
EmitUsedVars(new_val, pending_bindings, &emitted_vars);
this->ReEmitBinding(var_bind, builder_->Normalize(new_val));
+ changed = true;
} else if (!emitted_vars.count(var_bind->var.get())) {
this->VisitBinding(binding);
emitted_vars.insert(var_bind->var.get());
@@ -885,7 +888,11 @@ class PatternRewriter : ExprMutator {
this->VisitBinding(binding);
}
}
- return RewriteDataflowBlockFixedPoint(builder_->EndBlock());
+
+ auto new_block = builder_->EndBlock();
+
+ if (!changed) return new_block;
+ return RewriteDataflowBlockFixedPoint(new_block);
}
return block;
}
@@ -909,15 +916,16 @@ class PatternRewriter : ExprMutator {
std::unordered_map<const Object*, Expr> memo_;
};
+Function RewriteBindings(const PatternContext& ctx, PackedFunc rewriter,
Function f) {
+ return PatternRewriter::Run(ctx, rewriter, f);
+}
+
TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call")
.set_body_typed([](DFPattern pat, PackedFunc rewriter, Function f) {
return PatternRewriter::Run(pat, rewriter, f);
});
-TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings")
- .set_body_typed([](const PatternContext& ctx, PackedFunc rewriter,
Function f) {
- return PatternRewriter::Run(ctx, rewriter, f);
- });
+TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings);
} // namespace relax
} // namespace tvm
diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc
index 4d225ceecf..cd1376303c 100644
--- a/src/relax/ir/dataflow_pattern.cc
+++ b/src/relax/ir/dataflow_pattern.cc
@@ -412,9 +412,9 @@ PatternContext::PatternContext(bool incremental) {
data_ = std::move(n);
}
-void PatternContext::EnterWithScope() { pattern_ctx_stack().push(*this); }
+void PatternContext::EnterWithScope() const { pattern_ctx_stack().push(*this);
}
-void PatternContext::ExitWithScope() {
+void PatternContext::ExitWithScope() const {
ICHECK(pattern_ctx_stack().top().same_as(*this));
pattern_ctx_stack().pop();
}
@@ -610,15 +610,13 @@
TVM_REGISTER_GLOBAL("relax.dpl.current_context").set_body_typed([] {
return PatternContext::Current();
});
-class PatternContext::Internal {
- public:
- static void EnterScope(PatternContext pass_ctx) { pass_ctx.EnterWithScope();
}
- static void ExitScope(PatternContext pass_ctx) { pass_ctx.ExitWithScope(); }
-};
-
-TVM_REGISTER_GLOBAL("relax.dpl.enter_context").set_body_typed(PatternContext::Internal::EnterScope);
+TVM_REGISTER_GLOBAL("relax.dpl.enter_context").set_body_typed([](const
PatternContext& ctx) {
+ ctx.EnterWithScope();
+});
-TVM_REGISTER_GLOBAL("relax.dpl.exit_context").set_body_typed(PatternContext::Internal::ExitScope);
+TVM_REGISTER_GLOBAL("relax.dpl.exit_context").set_body_typed([](const
PatternContext& ctx) {
+ ctx.ExitWithScope();
+});
} // namespace relax
} // namespace tvm
diff --git a/src/relax/transform/combine_parallel_matmul.cc
b/src/relax/transform/combine_parallel_matmul.cc
new file mode 100644
index 0000000000..d6435ec829
--- /dev/null
+++ b/src/relax/transform/combine_parallel_matmul.cc
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/dataflow_matcher.h>
+#include <tvm/relax/dataflow_pattern.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+
+#include <optional>
+#include <unordered_map>
+#include <vector>
+
+#include "../op/nn/nn.h"
+#include "../op/tensor/binary.h"
+#include "../op/tensor/index.h"
+#include "../op/tensor/linear_algebra.h"
+#include "../op/tensor/manipulate.h"
+
+namespace tvm {
+namespace relax {
+
+using runtime::Map;
+
+/*! \brief Group shapes of the RHS matrices by rank. Matrices in a group whose
batch sizes
+ are compatible are combined.
+*/
+std::unordered_map<size_t, std::vector<size_t>> GroupShapes(
+ const std::vector<Array<PrimExpr>>& shapes) {
+ std::unordered_map<size_t, std::vector<size_t>> indices_map;
+ for (size_t i = 0; i < shapes.size(); ++i) {
+ indices_map[shapes[i].size()].push_back(i);
+ }
+ return indices_map;
+}
+
+inline TensorStructInfo GetTensorSInfo(Expr e) {
+ return Downcast<TensorStructInfo>(GetStructInfo(e));
+}
+
+struct BranchInfo {
+ int num_branches;
+ std::optional<int> bias_dim;
+ std::optional<std::string> activation;
+};
+
+struct Patterns {
+ Patterns() : input(Wildcard()) { ctx.EnterWithScope(); }
+
+ PatternContext ctx;
+ WildcardPattern input;
+ std::vector<WildcardPattern> rhs;
+ std::vector<WildcardPattern> bias;
+ std::vector<CallPattern> matmul, bias_add, activation;
+};
+
+Patterns CreatePatterns(const BranchInfo& branch_info) {
+ Patterns patterns;
+
+ for (int i = 0; i < branch_info.num_branches; ++i) {
+ auto w_pat = Wildcard();
+ auto matmul_pat = IsOp("relax.matmul")(patterns.input, w_pat);
+ patterns.rhs.push_back(w_pat);
+ patterns.matmul.push_back(matmul_pat);
+ patterns.ctx.add_constraint(patterns.input, matmul_pat,
PairCons(PairCons::kUsedBy, 0));
+ patterns.ctx.add_constraint(w_pat, matmul_pat, PairCons(PairCons::kUsedBy,
1));
+
+ CallPattern matmul_out = matmul_pat;
+
+ if (branch_info.bias_dim) {
+ auto bias_pat = Wildcard();
+ auto bias_add_pat = IsOp("relax.add")(matmul_pat, bias_pat);
+ patterns.bias.push_back(bias_pat);
+ patterns.bias_add.push_back(bias_add_pat);
+ patterns.ctx.add_constraint(matmul_pat, bias_add_pat,
PairCons(PairCons::kUsedBy, 0));
+ patterns.ctx.add_constraint(bias_pat, bias_add_pat,
PairCons(PairCons::kUsedBy, 1));
+ matmul_out = bias_add_pat;
+ }
+
+ if (branch_info.activation) {
+ auto act_pat = IsOp(*branch_info.activation)(matmul_out);
+ patterns.activation.push_back(act_pat);
+ patterns.ctx.add_constraint(matmul_out, act_pat,
PairCons(PairCons::kUsedBy, 0));
+ }
+ }
+
+ return patterns;
+}
+
+/*! \brief Create a rewriter for the given parallel matmul branches. */
+runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>)> GetRewriter(
+ const Patterns& patterns, const BranchInfo& branch_info) {
+ auto batch_dims_compatible = [](size_t rhs_dim, const std::vector<size_t>&
indices,
+ const std::vector<Array<PrimExpr>>&
rhs_shapes) {
+ arith::Analyzer ana;
+ for (auto ind : indices) {
+ ICHECK_EQ(static_cast<int>(rhs_shapes[ind].size()), rhs_dim);
+ // -2 for reduction and concat axes
+ for (size_t i = 0; i < rhs_dim - 2; ++i) {
+ if (!ana.CanProve(rhs_shapes[indices[0]][i] == rhs_shapes[ind][i])) {
+ return false;
+ }
+ }
+ }
+ return true;
+ };
+
+ return [=](Map<DFPattern, Var> matchings) {
+ std::vector<Array<PrimExpr>> rhs_shapes;
+ for (const auto& rhs_pat : patterns.rhs) {
+ auto rhs_shape_opt = GetTensorSInfo(matchings[rhs_pat])->GetShape();
+ if (!rhs_shape_opt) {
+ return Map<Var, Expr>{};
+ }
+ rhs_shapes.push_back(rhs_shape_opt.value());
+ }
+
+ Map<Var, Expr> replacements;
+
+ for (const auto& [rhs_dim, indices] : GroupShapes(rhs_shapes)) {
+ if (indices.size() == 1 || !batch_dims_compatible(rhs_dim, indices,
rhs_shapes)) continue;
+
+ Array<Expr> rhs, bias;
+ for (auto ind : indices) {
+ rhs.push_back(matchings[patterns.rhs[ind]]);
+ if (branch_info.bias_dim) {
+ ICHECK(matchings.count(patterns.bias[ind]));
+ bias.push_back(matchings[patterns.bias[ind]]);
+ }
+ }
+
+ auto inp = matchings[patterns.input];
+ auto concat_rhs = concat(Tuple(rhs), Integer(rhs_dim - 1));
+ auto out_dtype =
GetTensorSInfo(matchings[patterns.matmul[indices[0]]])->dtype;
+ auto matmul_combined = matmul(inp, concat_rhs, out_dtype);
+
+ const auto& pattern_to_replace = [&patterns, &branch_info]() {
+ if (branch_info.activation) return patterns.activation;
+ if (branch_info.bias_dim) return patterns.bias_add;
+ return patterns.matmul;
+ }();
+
+ if (branch_info.bias_dim) {
+ auto bias_dim = GetTensorSInfo(bias[0])->ndim;
+ auto concat_bias = concat(Tuple(bias), Integer(bias_dim - 1));
+ matmul_combined = add(matmul_combined, concat_bias);
+ }
+
+ if (branch_info.activation) {
+ if (*branch_info.activation == "relax.nn.relu") {
+ matmul_combined = relu(matmul_combined);
+ } else if (*branch_info.activation == "relax.nn.gelu") {
+ matmul_combined = gelu(matmul_combined);
+ } else if (*branch_info.activation == "relax.nn.silu") {
+ matmul_combined = silu(matmul_combined);
+ } else {
+ LOG(FATAL) << "Unsupported activation: " << *branch_info.activation;
+ }
+ }
+
+ PrimExpr begin{0};
+ Array<PrimExpr> strides{1};
+ int lhs_dim = GetTensorSInfo(inp)->ndim;
+ int slice_axis = std::max<int>(lhs_dim, rhs_dim) - 1;
+
+ for (size_t i = 0; i < indices.size(); ++i) {
+ auto width = GetTensorSInfo(rhs[i])->GetShape().value()[rhs_dim - 1];
+ auto bound_var = matchings[pattern_to_replace[indices[i]]];
+ auto slice =
+ strided_slice(matmul_combined, {slice_axis}, {begin}, {begin +
width}, strides);
+ replacements.Set(bound_var, slice);
+ begin += width;
+ }
+ }
+
+ return replacements;
+ };
+}
+
+Function Rewrite(Function f, const BranchInfo& branch_info) {
+ auto patterns = CreatePatterns(branch_info);
+ auto rewriter = GetRewriter(patterns, branch_info);
+ return RewriteBindings(patterns.ctx, rewriter, f);
+}
+
+/*! \brief Look for subtrees with parallel matmul and return information about
+ them (the number of branches and the kind of fused ops)
+*/
+std::vector<BranchInfo> GetBranchInfo(Function f) {
+ auto bias_pat = Wildcard();
+ auto matmul_pat = IsOp("relax.matmul")(Wildcard(), Wildcard());
+ auto bias_add_pat = IsOp("relax.add")(matmul_pat, bias_pat);
+
+ std::vector<std::string> activations{"relax.nn.relu", "relax.nn.gelu",
"relax.nn.silu"};
+
+ std::vector<DFPattern> activation_pat, bias_activation_pat;
+ for (const auto& act : activations) {
+ activation_pat.push_back(IsOp(act)(matmul_pat));
+ bias_activation_pat.push_back(IsOp(act)(bias_add_pat));
+ }
+
+ auto bindings = AnalyzeVar2Value(f);
+
+ auto create_group = [&](DFPattern pat) {
+ // Maps a LHS matrix to consumer parallel matmuls
+ std::unordered_map<const VarNode*, BranchInfo> groups;
+
+ PostOrderVisit(f, [&](const Expr& e) {
+ if (!e->IsInstance<CallNode>()) return;
+ if (auto match = ExtractMatchedExpr(pat, e, bindings)) {
+ auto matmul_call = Downcast<Call>(match.value()[matmul_pat]);
+ auto matmul_lhs = Downcast<Var>(matmul_call->args[0]);
+
+ auto it = groups.find(matmul_lhs.get());
+ BranchInfo* branch = it != groups.end() ? &it->second : nullptr;
+ std::optional<int> bias_dim = std::nullopt;
+ std::optional<std::string> activation = std::nullopt;
+
+ if (match.value().count(bias_pat)) {
+ bias_dim = GetTensorSInfo(match.value()[bias_pat])->ndim;
+ }
+
+ for (size_t i = 0; i < activations.size(); ++i) {
+ if (match.value().count(activation_pat[i]) ||
+ match.value().count(bias_activation_pat[i])) {
+ activation = activations[i];
+ }
+ }
+
+ if (!branch) {
+ // Create a new subgraph with one matmul
+ groups[matmul_lhs.get()] = {1, bias_dim, activation};
+ } else {
+ // Create a new branch in the existing parallel matmul subtree, and
+ // invalidate bias and activation information when needed.
+ branch->num_branches += 1;
+
+ if (!bias_dim || (branch->bias_dim && *branch->bias_dim !=
*bias_dim)) {
+ branch->bias_dim = std::nullopt;
+ }
+
+ if (!activation || (branch->activation && *branch->activation !=
*activation)) {
+ branch->activation = std::nullopt;
+ }
+ }
+ return;
+ }
+ });
+
+ return groups;
+ };
+
+ std::unordered_map<const VarNode*, BranchInfo> groups_activation;
+ for (size_t i = 0; i < activations.size(); ++i) {
+ auto groups = create_group(bias_activation_pat[i]);
+ groups_activation.merge(std::move(groups));
+ }
+
+ for (size_t i = 0; i < activations.size(); ++i) {
+ auto groups = create_group(activation_pat[i]);
+ groups_activation.merge(std::move(groups));
+ }
+
+ auto groups_bias = create_group(bias_add_pat);
+ auto groups_matmul = create_group(matmul_pat);
+
+ for (const auto& groups : {groups_bias, groups_activation}) {
+ for (const auto& [lhs, branch] : groups) {
+ // Prefer combining more matmuls than combining fewer ones and leaving
additional uncombined
+ // matmuls followed by bias or activation. So we combine matmuls + fused
ops patterns only
+ // when all branches have the same fused ops.
+ if (auto it = groups_matmul.find(lhs);
+ it != groups_matmul.end() && it->second.num_branches ==
branch.num_branches) {
+ it->second = branch;
+ }
+ }
+ }
+
+ std::vector<BranchInfo> info;
+
+ for (const auto& groups : {groups_matmul, groups_activation, groups_bias}) {
+ for (const auto& group : groups) {
+ if (group.second.num_branches > 1) {
+ info.push_back(group.second);
+ }
+ }
+ }
+
+ return info;
+}
+
+Function CombineParallelMatmul(Function f) {
+ auto branches = GetBranchInfo(f);
+ std::sort(branches.begin(), branches.end(),
+ [](const auto& b1, const auto& b2) { return b1.num_branches >
b2.num_branches; });
+
+ for (const auto& branch : branches) {
+ f = Rewrite(f, branch);
+ }
+ return f;
+}
+
+namespace transform {
+
+Pass CombineParallelMatmul() {
+ runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>
pass_func =
+ [=](Function f, IRModule m, PassContext pc) { return
relax::CombineParallelMatmul(f); };
+ return CreateFunctionPass(/*pass_function=*/pass_func, //
+ /*opt_level=*/0, //
+ /*pass_name=*/"CombineParallelMatmul", //
+ /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.CombineParallelMatmul").set_body_typed(CombineParallelMatmul);
+
+} // namespace transform
+
+} // namespace relax
+} // namespace tvm
diff --git a/tests/python/relax/test_dataflow_pattern.py
b/tests/python/relax/test_dataflow_pattern.py
index a73a62eeef..ed221f54be 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -18,7 +18,7 @@
import pytest
import tvm.testing
-from tvm import relay, relax
+from tvm import relay
from tvm.relax.dpl import *
from tvm.relax.analysis import get_var2val
from tvm import relax as rx, tir
@@ -1177,9 +1177,9 @@ def test_combine_matmul_emit_order():
# make sure it builds
mod = tvm.IRModule()
mod["main"] = rewritten
- mod = relax.transform.LegalizeOps()(mod)
+ mod = rx.transform.LegalizeOps()(mod)
- relax.build(mod, target="llvm")
+ rx.build(mod, target="llvm")
if __name__ == "__main__":
diff --git a/tests/python/relax/test_transform_combine_parallel_matmul.py
b/tests/python/relax/test_transform_combine_parallel_matmul.py
new file mode 100644
index 0000000000..f5cc269620
--- /dev/null
+++ b/tests/python/relax/test_transform_combine_parallel_matmul.py
@@ -0,0 +1,469 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm.testing
+
+from tvm import relax, tir
+from tvm.script import relax as R, tir as T
+from tvm.relax.transform import CombineParallelMatmul
+from tvm.script.ir_builder import IRBuilder
+from tvm.script.ir_builder import relax as relax_builder
+
+
+def get_parallel_matmul(
+ num_branches,
+ lhs_shape=(640, 640),
+ rhs_shape=(640, 640),
+ with_bias=None,
+ activation=None,
+):
+ dtype = "float32"
+
+ activation_map = {"relu": R.nn.relu, "gelu": R.nn.gelu}
+
+ with IRBuilder() as builder:
+ with relax_builder.function():
+ R.func_name("main")
+ x = R.arg("x", R.Tensor(lhs_shape, dtype))
+
+ rhs = []
+ bias = []
+
+ for i in range(num_branches):
+ rhs.append(R.arg("y", R.Tensor(rhs_shape, dtype)))
+
+ if with_bias and with_bias[i]:
+ bias.append(R.arg("bias", R.Tensor((rhs_shape[1],),
dtype)))
+ else:
+ bias.append(None)
+
+ with R.dataflow() as frame:
+ branches = []
+
+ for i, r in enumerate(rhs):
+ result = R.emit(R.matmul(x, r, out_dtype=dtype))
+ if bias[i]:
+ result = R.emit(result + bias[i])
+ if activation and activation[i]:
+ result = R.emit(activation_map[activation[i]](result))
+
+ branches.append(result)
+
+ R.output(R.emit(R.concat(branches, axis=1)))
+
+ R.func_ret_value(frame.output_vars[0])
+
+ func = builder.get()
+ return tvm.IRModule({"main": func})
+
+
+def test_simple():
+ mod_orig = get_parallel_matmul(1)
+ mod = CombineParallelMatmul()(mod_orig)
+
+ tvm.ir.assert_structural_equal(mod, mod_orig)
+
+ mod = get_parallel_matmul(3)
+ mod = CombineParallelMatmul()(mod)
+
+ @R.function
+ def expected1(
+ x: R.Tensor((640, 640), dtype="float32"),
+ y: R.Tensor((640, 640), dtype="float32"),
+ y_1: R.Tensor((640, 640), dtype="float32"),
+ y_2: R.Tensor((640, 640), dtype="float32"),
+ ) -> R.Tensor((640, 1920), dtype="float32"):
+ with R.dataflow():
+ lv = R.concat((y, y_1, y_2), axis=1)
+ lv1 = R.matmul(x, lv, out_dtype="float32")
+ lv_1 = R.strided_slice(lv1, axes=[1], begin=[0], end=[640],
strides=[1])
+ lv1_1 = R.strided_slice(lv1, axes=[1], begin=[640], end=[1280],
strides=[1])
+ lv2 = R.strided_slice(lv1, axes=[1], begin=[1280], end=[1920],
strides=[1])
+ lv3 = R.concat((lv_1, lv1_1, lv2), axis=1)
+ R.output(lv3)
+ return lv3
+
+ tvm.ir.assert_structural_equal(mod["main"], expected1)
+
+ # Test a batched LHS case, slicing is done on the axis 2
+ mod = get_parallel_matmul(3, lhs_shape=(2, 1024, 640))
+ mod = CombineParallelMatmul()(mod)
+
+ @R.function
+ def expected2(
+ x: R.Tensor((2, 1024, 640), dtype="float32"),
+ y: R.Tensor((640, 640), dtype="float32"),
+ y_1: R.Tensor((640, 640), dtype="float32"),
+ y_2: R.Tensor((640, 640), dtype="float32"),
+ ) -> R.Tensor((2, 3072, 640), dtype="float32"):
+ with R.dataflow():
+ lv = R.concat((y, y_1, y_2), axis=1)
+ lv1 = R.matmul(x, lv, out_dtype="float32")
+ lv_1 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640],
strides=[1])
+ lv1_1 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280],
strides=[1])
+ lv2 = R.strided_slice(lv1, axes=[2], begin=[1280], end=[1920],
strides=[1])
+ lv3 = R.concat((lv_1, lv1_1, lv2), axis=1)
+ R.output(lv3)
+ return lv3
+
+ tvm.ir.assert_structural_equal(mod["main"], expected2)
+
+
+def test_bias():
+ mod = get_parallel_matmul(3, with_bias=[True, True, True])
+ mod = CombineParallelMatmul()(mod)
+
+ @R.function
+ def expected1(
+ x: R.Tensor((640, 640), dtype="float32"),
+ y: R.Tensor((640, 640), dtype="float32"),
+ bias: R.Tensor((640,), dtype="float32"),
+ y_1: R.Tensor((640, 640), dtype="float32"),
+ bias_1: R.Tensor((640,), dtype="float32"),
+ y_2: R.Tensor((640, 640), dtype="float32"),
+ bias_2: R.Tensor((640,), dtype="float32"),
+ ) -> R.Tensor((640, 1920), dtype="float32"):
+ with R.dataflow():
+ lv = R.concat((y, y_1, y_2), axis=1)
+ lv1 = R.matmul(x, lv, out_dtype="float32")
+ lv2 = R.concat((bias, bias_1, bias_2), axis=0)
+ lv3 = R.add(lv1, lv2)
+ lv1_1 = R.strided_slice(lv3, axes=[1], begin=[0], end=[640],
strides=[1])
+ lv3_1 = R.strided_slice(lv3, axes=[1], begin=[640], end=[1280],
strides=[1])
+ lv5 = R.strided_slice(lv3, axes=[1], begin=[1280], end=[1920],
strides=[1])
+ lv6 = R.concat((lv1_1, lv3_1, lv5), axis=1)
+ R.output(lv6)
+ return lv6
+
+ tvm.ir.assert_structural_equal(mod["main"], expected1)
+
+ mod = get_parallel_matmul(3, with_bias=[True, False, True])
+ mod = CombineParallelMatmul()(mod)
+
+ @R.function
+ def expected2(
+ x: R.Tensor((640, 640), dtype="float32"),
+ y: R.Tensor((640, 640), dtype="float32"),
+ bias: R.Tensor((640,), dtype="float32"),
+ y_1: R.Tensor((640, 640), dtype="float32"),
+ y_2: R.Tensor((640, 640), dtype="float32"),
+ bias_1: R.Tensor((640,), dtype="float32"),
+ ) -> R.Tensor((640, 1920), dtype="float32"):
+ with R.dataflow():
+ lv = R.concat((y, y_1, y_2), axis=1)
+ lv1 = R.matmul(x, lv, out_dtype="float32")
+ lv_1 = R.strided_slice(lv1, axes=[1], begin=[0], end=[640],
strides=[1])
+ lv1_1 = R.add(lv_1, bias)
+ lv2 = R.strided_slice(lv1, axes=[1], begin=[640], end=[1280],
strides=[1])
+ lv3 = R.strided_slice(lv1, axes=[1], begin=[1280], end=[1920],
strides=[1])
+ lv4 = R.add(lv3, bias_1)
+ lv5 = R.concat((lv1_1, lv2, lv4), axis=1)
+ R.output(lv5)
+ return lv5
+
+ tvm.ir.assert_structural_equal(mod["main"], expected2)
+
+
+def test_activation():
+ mod = get_parallel_matmul(3, activation=["relu", "relu", "relu"])
+ mod = CombineParallelMatmul()(mod)
+
+ @R.function
+ def expected1(
+ x: R.Tensor((640, 640), dtype="float32"),
+ y: R.Tensor((640, 640), dtype="float32"),
+ y_1: R.Tensor((640, 640), dtype="float32"),
+ y_2: R.Tensor((640, 640), dtype="float32"),
+ ) -> R.Tensor((640, 1920), dtype="float32"):
+ with R.dataflow():
+ lv = R.concat((y, y_1, y_2), axis=1)
+ lv1 = R.matmul(x, lv, out_dtype="float32")
+ lv2 = R.nn.relu(lv1)
+ lv1_1 = R.strided_slice(lv2, axes=[1], begin=[0], end=[640],
strides=[1])
+ lv3 = R.strided_slice(lv2, axes=[1], begin=[640], end=[1280],
strides=[1])
+ lv5 = R.strided_slice(lv2, axes=[1], begin=[1280], end=[1920],
strides=[1])
+ lv6 = R.concat((lv1_1, lv3, lv5), axis=1)
+ R.output(lv6)
+ return lv6
+
+ tvm.ir.assert_structural_equal(mod["main"], expected1)
+
+ mod = get_parallel_matmul(3, activation=["gelu", "relu", "relu"])
+ mod = CombineParallelMatmul()(mod)
+
+ @R.function
+ def expected2(
+ x: R.Tensor((640, 640), dtype="float32"),
+ y: R.Tensor((640, 640), dtype="float32"),
+ y_1: R.Tensor((640, 640), dtype="float32"),
+ y_2: R.Tensor((640, 640), dtype="float32"),
+ ) -> R.Tensor((640, 1920), dtype="float32"):
+ with R.dataflow():
+ lv = R.concat((y, y_1, y_2), axis=1)
+ lv1 = R.matmul(x, lv, out_dtype="float32")
+ lv_1 = R.strided_slice(lv1, axes=[1], begin=[0], end=[640],
strides=[1])
+ lv1_1 = R.nn.gelu(lv_1)
+ lv2 = R.strided_slice(lv1, axes=[1], begin=[640], end=[1280],
strides=[1])
+ lv3 = R.nn.relu(lv2)
+ lv4 = R.strided_slice(lv1, axes=[1], begin=[1280], end=[1920],
strides=[1])
+ lv5 = R.nn.relu(lv4)
+ lv6 = R.concat((lv1_1, lv3, lv5), axis=1)
+ R.output(lv6)
+ return lv6
+
+ tvm.ir.assert_structural_equal(mod["main"], expected2)
+
+ mod = get_parallel_matmul(3, activation=["relu", None, None])
+ mod = CombineParallelMatmul()(mod)
+
+ @R.function
+ def expected3(
+ x: R.Tensor((640, 640), dtype="float32"),
+ y: R.Tensor((640, 640), dtype="float32"),
+ y_1: R.Tensor((640, 640), dtype="float32"),
+ y_2: R.Tensor((640, 640), dtype="float32"),
+ ) -> R.Tensor((640, 1920), dtype="float32"):
+ with R.dataflow():
+ lv = R.concat((y, y_1, y_2), axis=1)
+ lv1 = R.matmul(x, lv, out_dtype="float32")
+ lv_1 = R.strided_slice(lv1, axes=[1], begin=[0], end=[640],
strides=[1])
+ lv1_1 = R.nn.relu(lv_1)
+ lv2 = R.strided_slice(lv1, axes=[1], begin=[640], end=[1280],
strides=[1])
+ lv3 = R.strided_slice(lv1, axes=[1], begin=[1280], end=[1920],
strides=[1])
+ lv4 = R.concat((lv1_1, lv2, lv3), axis=1)
+ R.output(lv4)
+ return lv4
+
+ tvm.ir.assert_structural_equal(mod["main"], expected3)
+
+
+def test_bias_activation():
+ mod = get_parallel_matmul(3, with_bias=[True, True, True],
activation=["relu", "relu", "relu"])
+ mod = CombineParallelMatmul()(mod)
+
+ @R.function
+ def expected1(
+ x: R.Tensor((640, 640), dtype="float32"),
+ y: R.Tensor((640, 640), dtype="float32"),
+ bias: R.Tensor((640,), dtype="float32"),
+ y_1: R.Tensor((640, 640), dtype="float32"),
+ bias_1: R.Tensor((640,), dtype="float32"),
+ y_2: R.Tensor((640, 640), dtype="float32"),
+ bias_2: R.Tensor((640,), dtype="float32"),
+ ) -> R.Tensor((640, 1920), dtype="float32"):
+ with R.dataflow():
+ lv = R.concat((y, y_1, y_2), axis=1)
+ lv1 = R.matmul(x, lv, out_dtype="float32")
+ lv2 = R.concat((bias, bias_1, bias_2), axis=0)
+ lv3 = R.add(lv1, lv2)
+ lv4 = R.nn.relu(lv3)
+ lv2_1 = R.strided_slice(lv4, axes=[1], begin=[0], end=[640],
strides=[1])
+ lv5 = R.strided_slice(lv4, axes=[1], begin=[640], end=[1280],
strides=[1])
+ lv8 = R.strided_slice(lv4, axes=[1], begin=[1280], end=[1920],
strides=[1])
+ lv9 = R.concat((lv2_1, lv5, lv8), axis=1)
+ R.output(lv9)
+ return lv9
+
+ tvm.ir.assert_structural_equal(mod["main"], expected1)
+
+ mod = get_parallel_matmul(3, with_bias=[True, True, True],
activation=["relu", None, "relu"])
+ mod = CombineParallelMatmul()(mod)
+
+ @R.function
+ def expected2(
+ x: R.Tensor((640, 640), dtype="float32"),
+ y: R.Tensor((640, 640), dtype="float32"),
+ bias: R.Tensor((640,), dtype="float32"),
+ y_1: R.Tensor((640, 640), dtype="float32"),
+ bias_1: R.Tensor((640,), dtype="float32"),
+ y_2: R.Tensor((640, 640), dtype="float32"),
+ bias_2: R.Tensor((640,), dtype="float32"),
+ ) -> R.Tensor((640, 1920), dtype="float32"):
+ with R.dataflow():
+ lv = R.concat((y, y_1, y_2), axis=1)
+ lv1 = R.matmul(x, lv, out_dtype="float32")
+ lv2 = R.concat((bias, bias_1, bias_2), axis=0)
+ lv3 = R.add(lv1, lv2)
+ lv1_1 = R.strided_slice(lv3, axes=[1], begin=[0], end=[640],
strides=[1])
+ lv2_1 = R.nn.relu(lv1_1)
+ lv4 = R.strided_slice(lv3, axes=[1], begin=[640], end=[1280],
strides=[1])
+ lv6 = R.strided_slice(lv3, axes=[1], begin=[1280], end=[1920],
strides=[1])
+ lv7 = R.nn.relu(lv6)
+ lv8 = R.concat((lv2_1, lv4, lv7), axis=1)
+ R.output(lv8)
+ return lv8
+
+ tvm.ir.assert_structural_equal(mod["main"], expected2)
+
+ mod = get_parallel_matmul(3, with_bias=[True, False, True],
activation=["relu", None, "relu"])
+ mod = CombineParallelMatmul()(mod)
+
+ @R.function
+ def expected3(
+ x: R.Tensor((640, 640), dtype="float32"),
+ y: R.Tensor((640, 640), dtype="float32"),
+ bias: R.Tensor((640,), dtype="float32"),
+ y_1: R.Tensor((640, 640), dtype="float32"),
+ y_2: R.Tensor((640, 640), dtype="float32"),
+ bias_1: R.Tensor((640,), dtype="float32"),
+ ) -> R.Tensor((640, 1920), dtype="float32"):
+ with R.dataflow():
+ lv = R.concat((y, y_1, y_2), axis=1)
+ lv1 = R.matmul(x, lv, out_dtype="float32")
+ lv_1 = R.strided_slice(lv1, axes=[1], begin=[0], end=[640], strides=[1])
+ lv1_1 = R.add(lv_1, bias)
+ lv2 = R.nn.relu(lv1_1)
+ lv3 = R.strided_slice(lv1, axes=[1], begin=[640], end=[1280], strides=[1])
+ lv4 = R.strided_slice(lv1, axes=[1], begin=[1280], end=[1920], strides=[1])
+ lv5 = R.add(lv4, bias_1)
+ lv6 = R.nn.relu(lv5)
+ lv7 = R.concat((lv2, lv3, lv6), axis=1)
+ R.output(lv7)
+ return lv7
+
+ tvm.ir.assert_structural_equal(mod["main"], expected3)
+
+
+def test_rhs_batched():
+ @tvm.script.ir_module
+ class four_matmul:
+ @R.function
+ def main(
+ x: R.Tensor((1024, 640), "float32"),
+ w0: R.Tensor((2, 640, 640), "float32"),
+ w1: R.Tensor((640, 640), "float32"),
+ w2: R.Tensor((2, 640, 640), "float32"),
+ w3: R.Tensor((3, 4, 640, 640), "float32"),
+ ) -> R.Tensor:
+ with R.dataflow():
+ lv0 = R.matmul(x, w0)
+ lv1 = R.matmul(x, w1)
+ lv2 = R.matmul(x, w2)
+ lv3 = R.matmul(x, w3)
+ out = (lv0, lv1, lv2, lv3)
+ R.output(out)
+ return out
+
+ mod = CombineParallelMatmul()(four_matmul)
+
+ @R.function
+ def expected1(
+ x: R.Tensor((1024, 640), dtype="float32"),
+ w0: R.Tensor((2, 640, 640), dtype="float32"),
+ w1: R.Tensor((640, 640), dtype="float32"),
+ w2: R.Tensor((2, 640, 640), dtype="float32"),
+ w3: R.Tensor((3, 4, 640, 640), dtype="float32"),
+ ) -> R.Tensor:
+ with R.dataflow():
+ lv = R.concat((w0, w2), axis=2)
+ lv1 = R.matmul(x, lv, out_dtype="float32")
+ lv0 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640], strides=[1])
+ lv1_1 = R.matmul(x, w1, out_dtype="void")
+ lv2 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280], strides=[1])
+ lv3 = R.matmul(x, w3, out_dtype="void")
+ out = lv0, lv1_1, lv2, lv3
+ R.output(out)
+ return out
+
+ tvm.ir.assert_structural_equal(mod["main"], expected1)
+
+ @tvm.script.ir_module
+ class four_matmul_incompatible_batches:
+ @R.function
+ def main(
+ x: R.Tensor((1024, 640), "float32"),
+ w0: R.Tensor((2, 640, 640), "float32"),
+ w1: R.Tensor((3, 640, 640), "float32"),
+ w2: R.Tensor((2, 640, 640), "float32"),
+ w3: R.Tensor((2, 640, 640), "float32"),
+ ) -> R.Tensor:
+ with R.dataflow():
+ lv0 = R.matmul(x, w0)
+ lv1 = R.matmul(x, w1)
+ lv2 = R.matmul(x, w2)
+ lv3 = R.matmul(x, w3)
+ out = (lv0, lv1, lv2, lv3)
+ R.output(out)
+ return out
+
+ mod = CombineParallelMatmul()(four_matmul_incompatible_batches)
+ # For now, when rhs matrices have the same rank but different batch sizes, we don't
+ # combine any of them.
+ tvm.ir.assert_structural_equal(mod, four_matmul_incompatible_batches)
+
+
+def test_multiple_combine():
+ @tvm.script.ir_module
+ class multiple_combine:
+ @R.function
+ def main(
+ x1: R.Tensor((2, 1024, 640), "float32"),
+ x2: R.Tensor((2, 1024, 640), "float32"),
+ w0: R.Tensor((640, 640), "float32"),
+ w1: R.Tensor((640, 640), "float32"),
+ w2: R.Tensor((640, 640), "float32"),
+ w3: R.Tensor((640, 640), "float32"),
+ w4: R.Tensor((640, 640), "float32"),
+ b0: R.Tensor((640,), "float32"),
+ b1: R.Tensor((640,), "float32"),
+ ) -> R.Tensor:
+ with R.dataflow():
+ lv0 = R.matmul(x1, w0)
+ lv3 = R.matmul(x2, w3)
+ lv1 = R.matmul(x1, w1)
+ lv5 = R.add(lv3, b0)
+ lv2 = R.matmul(x1, w2)
+ lv4 = R.matmul(x2, w4)
+ lv6 = R.add(lv4, b1)
+ out = (lv0, lv1, lv2, lv5, lv6)
+ R.output(out)
+ return out
+
+ mod = CombineParallelMatmul()(multiple_combine)
+
+ @R.function
+ def expected1(
+ x1: R.Tensor((2, 1024, 640), dtype="float32"),
+ x2: R.Tensor((2, 1024, 640), dtype="float32"),
+ w0: R.Tensor((640, 640), dtype="float32"),
+ w1: R.Tensor((640, 640), dtype="float32"),
+ w2: R.Tensor((640, 640), dtype="float32"),
+ w3: R.Tensor((640, 640), dtype="float32"),
+ w4: R.Tensor((640, 640), dtype="float32"),
+ b0: R.Tensor((640,), dtype="float32"),
+ b1: R.Tensor((640,), dtype="float32"),
+ ) -> R.Tensor:
+ with R.dataflow():
+ lv = R.concat((w0, w1, w2), axis=1)
+ lv1 = R.matmul(x1, lv, out_dtype="float32")
+ lv0 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640], strides=[1])
+ lv1_1 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280], strides=[1])
+ lv_1 = R.concat((w3, w4), axis=1)
+ lv1_2 = R.matmul(x2, lv_1, out_dtype="float32")
+ lv2 = R.concat((b0, b1), axis=0)
+ lv3 = R.add(lv1_2, lv2)
+ lv5 = R.strided_slice(lv3, axes=[2], begin=[0], end=[640], strides=[1])
+ lv2_1 = R.strided_slice(lv1, axes=[2], begin=[1280], end=[1920], strides=[1])
+ lv6 = R.strided_slice(lv3, axes=[2], begin=[640], end=[1280], strides=[1])
+ out = lv0, lv1_1, lv2_1, lv5, lv6
+ R.output(out)
+ return out
+
+ tvm.ir.assert_structural_equal(mod["main"], expected1)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()