This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 0877cf65ba [Unity][BYOC] Add pattern-based partitioning pass (#14054)
0877cf65ba is described below
commit 0877cf65ba9688f0f45bb429caece6253da53e3c
Author: masahi <[email protected]>
AuthorDate: Mon Feb 20 17:08:44 2023 +0900
[Unity][BYOC] Add pattern-based partitioning pass (#14054)
This adds a new pass, FuseOpsByPattern, which applies pattern matching to
each function in the given module, and groups matched expressions into a new
function. The end result is similar to FuseOps, but fusion is driven completely
by the provided patterns. The implementation also reuses OperatorFusor used by
FuseOps to create grouped functions from partitioned groups, further
illustrating the similarity between the two passes.
The new pass will serve the same role the MergeComposite pass plays in
Relay BYOC - grouped functions are annotated with the "composite" attribute to
denote what operations a given function consists of, and offloaded to external
backends. But it can also be useful in non-BYOC settings, for example to
support advanced fusion that op-kind-based fusion (FuseOps) doesn't handle (fused MHA,
conv2d / gemm + reduction fusion, etc.).
The original PR: https://github.com/tlc-pack/relax/pull/366
---
python/tvm/relax/transform/transform.py | 37 +-
src/relax/transform/fuse_ops.cc | 199 +++++++++
.../relax/test_transform_fuse_ops_by_pattern.py | 464 +++++++++++++++++++++
3 files changed, 699 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 1f14823b5a..bf90ef0b09 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -19,7 +19,7 @@
import functools
import inspect
import types
-from typing import Callable, Dict, Union, Optional, List
+from typing import Callable, Dict, Union, Optional, List, Tuple
import numpy as np # type: ignore
import tvm.ir
from tvm.runtime import NDArray
@@ -241,6 +241,41 @@ def FuseTIR() -> tvm.ir.transform.Pass:
return _ffi_api.FuseTIR() # type: ignore
+def FuseOpsByPattern(
+ patterns: List[Tuple], annotate_codegen: bool = False
+) -> tvm.ir.transform.Pass:
+ """Apply pattern matching to each function in the given module, and group
matched expressions
+ into a new function.
+
+ The end result is similar to FuseOps, but fusion is driven completely by
the provided patterns.
+
+ Parameters
+ ----------
+ patterns : List[Tuple[str, DFPattern]]
+ The patterns to detect. The order of the patterns determines the order
of priority in which
+ they are matched. Higher-priority patterns should come earlier in the
list.
+ The string is the name of the corresponding pattern. It becomes the
value of the kComposite
+ attribute of a fused function after a successful matching.
+
+ annotate_codegen : bool
+ If True, wrap each created composite function with another function,
whose body consists
+ only of a call to the composite function, and annotate the outer
function with "Codegen"
+ and "global_symbol" attributes. The "Codegen" attribute is set as the
prefix of the
+ corresponding pattern name. For example, "dnnl" if the pattern name is
"dnnl.conv2d_relu".
+
+ This must be True if the created composite functions are intended to
be offloaded to
+ an external backend without using the MergeCompositeFunctions pass.
+
+ Returns
+ -------
+ ret : tvm.transform.Pass
+ The registered pass for pattern-based fusion.
+
+ """
+ pattern_names, df_patterns = zip(*patterns)
+ return _ffi_api.FuseOpsByPattern(pattern_names, df_patterns,
annotate_codegen) # type: ignore
+
+
def LegalizeOps(customize_legalize_map: Optional[Dict[str, LegalizeFunc]] =
None):
"""Legalize high-level operator calls in Relax functions to call_tir
with corresponding low-level TIR PrimFuncs.
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 0a0209bb87..3b78274cec 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -28,12 +28,15 @@
*/
#include <tvm/relax/analysis.h>
+#include <tvm/relax/dataflow_matcher.h>
+#include <tvm/relax/dataflow_pattern.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/transform.h>
#include <tvm/tir/function.h>
#include <optional>
+#include <unordered_map>
#include "../../relay/analysis/graph_partitioner.h"
#include "../../support/arena.h"
@@ -880,6 +883,188 @@ IRModule FuseOps(IRModule mod, int opt_level, size_t max_fuse_depth) {
return OperatorFusor(mod, graph, groups, /*lift_constants*/
true).Transform();
}
+/*!
+ * \brief Turn every group of a precomputed partition into its own grouped function.
+ * \param mod The module to transform.
+ * \param partition Map from each expression node to the group it was assigned to.
+ * \param lift_constants Forwarded unchanged to OperatorFusor.
+ * \return The module produced by OperatorFusor::Transform().
+ */
+IRModule MakeGroupedFunctions(
+    IRModule mod, const std::unordered_map<const Object*, GraphPartitioner::Group*>& partition,
+    bool lift_constants) {
+  OperatorFusor fusor(mod, partition, lift_constants);
+  return fusor.Transform();
+}
+
+/*!
+ * \brief Invert a variable-to-value binding map.
+ * \param binding Map from each bound variable to the value expression bound to it.
+ * \return Map from each value expression back to its variable. When several variables map to
+ *         the same value, later entries in iteration order overwrite earlier ones.
+ */
+static Map<Expr, Var> GetBindingInverse(const Map<Var, Expr>& binding) {
+  Map<Expr, Var> inverse;
+  for (const auto& kv : binding) {
+    inverse.Set(kv.second, kv.first);
+  }
+  return inverse;
+}
+
+/*! \brief Create a "partitioning", a map from interior / leaf expr to its representative group,
+ * based on the provided pattern. The result can be passed to OperatorFusor above to fuse operations
+ * in a group and create a grouped function.
+ *
+ * Every sub-expression starts in its own singleton group. For each call that matches the pattern,
+ * the other matched call nodes and the variables bound to them are re-parented under the group of
+ * the variable bound to the matched (root) call, and that group is tagged with kComposite.
+ *
+ * ExprVisitor is inherited privately (the default for `class`), so VisitExpr is only reachable from
+ * inside this class, i.e. through Run().
+ */
+class PatternBasedPartitioner : ExprVisitor {
+ public:
+  using Group = GraphPartitioner::Group;
+  using GroupMap = OperatorFusor::GroupMap;
+  using ExprVisitor::VisitExpr_;
+
+  /*!
+   * \brief Partition `expr` according to a single pattern.
+   * \param pattern_name Name stored as the kComposite attribute of each matched group.
+   * \param pattern The dataflow pattern to match.
+   * \param expr The function (or expression) to partition.
+   * \param arena Allocator that owns the created groups; must outlive the returned map.
+   * \return Map from every sub-expression node of `expr` to the group allocated for it. Merged
+   *         groups are linked through their `parent` pointers rather than by rewriting entries.
+   */
+  static GroupMap Run(String pattern_name, DFPattern pattern, Expr expr, support::Arena* arena) {
+    PatternBasedPartitioner part(pattern_name, pattern, AnalyzeVar2Value(expr));
+    // Initialize each expr to have its own group
+    PostOrderVisit(
+        expr, [arena, &part](const Expr& e) { part.group_map_[e.get()] = arena->make<Group>(); });
+    part.VisitExpr(expr);
+    return part.group_map_;
+  }
+
+  PatternBasedPartitioner(String pattern_name, DFPattern pattern, const Map<Var, Expr>& bindings)
+      : pat_name_(pattern_name),
+        pat_(pattern),
+        bindings_(bindings),
+        value_to_bound_var_(GetBindingInverse(bindings)) {}
+
+  // NOTE(review): the base-class VisitExpr_(CallNode) is never invoked here, so the op and the
+  // arguments of a call are not traversed by this visitor. This presumably relies on nested calls
+  // having been bound to variables (normalized form) - TODO confirm.
+  void VisitExpr_(const CallNode* call) override {
+    if (auto matches_opt = ExtractMatchedExpr(pat_, GetRef<Call>(call), bindings_)) {
+      // If a match is found, put all matching expressions into the same group.
+      // OperatorFusor also requires that the bound variable be in the same group as the RHS value.
+      // Since is_op(...) based pattern only matches against call nodes on the right hand side,
+      // we need to take care of groups corresponding to the LHS bound variables carefully.
+
+      // In the example below, conv2d + relu pattern would match if the "call" variable in this
+      // function points to the relu op. We identify the group corresponding to "conv1", and make
+      // it the representative group for relu and conv2d on the RHS and also "lv" on the LHS.
+
+      // with R.dataflow():
+      //   lv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(...)
+      //   conv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)
+
+      // parent_group corresponds to the group of "conv1" above.
+      auto parent_group = GetGroupForBoundVar(GetRef<Call>(call));
+      ICHECK(parent_group);
+      parent_group->attrs.Set(attr::kComposite, pat_name_);
+
+      for (const auto& [pat, match] : matches_opt.value()) {
+        ICHECK(group_map_.count(match.get()));
+        // Put all matching call nodes into the parent group.
+        // The root call itself is skipped here; its bound variable already owns parent_group.
+        if (pat->IsInstance<CallPatternNode>() && match != GetRef<Call>(call)) {
+          AddToGroup(match, parent_group);
+          // Put the bound variable on the LHS into the same parent group.
+          // Map::operator[] is used, so every matched call is expected to be bound to a variable.
+          AddToGroup(value_to_bound_var_[match], parent_group);
+        }
+      }
+    }
+  }
+
+ private:
+  /*!
+   * \brief Re-parent the group currently recorded for `e` under `to`, moving one node count.
+   * The entry in group_map_ itself is left pointing at the old group. Note that the comparison
+   * is against the recorded group, not against its root.
+   */
+  void AddToGroup(Expr e, Group* to) {
+    if (group_map_[e.get()] != to) {
+      --group_map_[e.get()]->num_nodes;
+      group_map_[e.get()]->parent = to;
+      ++to->num_nodes;
+    }
+  }
+
+  /*!
+   * \brief Return the root group of the variable bound to `e`.
+   * `e` must appear as the right-hand side of a binding; this is checked.
+   */
+  Group* GetGroupForBoundVar(Expr e) {
+    ICHECK(value_to_bound_var_.count(e));
+    auto bound_var = value_to_bound_var_[e];
+    ICHECK(group_map_.count(bound_var.get()));
+    return group_map_[bound_var.get()]->FindRoot();
+  }
+
+  // Name written to the kComposite attribute of each matched group.
+  String pat_name_;
+  // The pattern being matched.
+  DFPattern pat_;
+  // Variable -> bound value, as computed by AnalyzeVar2Value in Run().
+  Map<Var, Expr> bindings_;
+  // Bound value -> variable, the inverse of bindings_.
+  Map<Expr, Var> value_to_bound_var_;
+  // Expression node -> group initially allocated for it (see Run()).
+  GroupMap group_map_;
+};
+
+/*!
+ * \brief Wrap each created composite function with another function, whose body consists
+ * only of a call to the composite function, and annotate the outer function with kCodegen
+ * and kGlobalSymbol attributes.
+ *
+ * NOTE(review): only the function named "main" is rewritten by Run(). Composite functions that
+ * are called only from other functions are left unwrapped. A module without "main" presumably
+ * fails in GetGlobalVar - confirm whether that restriction is intended.
+ */
+class CompositeFunctionAnnotator : public ExprMutator {
+ public:
+  explicit CompositeFunctionAnnotator(IRModule mod) : ExprMutator(mod) {}
+  using ExprMutator::VisitExpr_;
+
+  /*! \brief Rewrite "main" so that its calls to composite functions go through new wrappers. */
+  IRModule Run() {
+    auto mod = builder_->GetContextIRModule();
+    auto gvar = mod->GetGlobalVar("main");
+    auto func = Downcast<Function>(mod->Lookup(gvar));
+    auto new_func =
+        Function(func->params, VisitExpr(func->body), func->ret_struct_info, func->attrs);
+    builder_->UpdateFunction(gvar, new_func);
+    return builder_->GetContextIRModule();
+  }
+
+  /*!
+   * \brief Redirect a call to a global composite function to its codegen wrapper.
+   *
+   * On the first call to a given composite function, the wrapper is built, tagged with kCodegen
+   * (the pattern-name prefix) and kGlobalSymbol ("<old name>_<codegen>"), and added to the module
+   * under that symbol, while the old composite function is removed from the module. gvar_map_
+   * ensures that later calls to the same function reuse the same wrapper.
+   */
+  Expr VisitExpr_(const CallNode* call_node) final {
+    if (auto const* gvar = call_node->op.as<GlobalVarNode>()) {
+      if (auto it = gvar_map_.find(gvar); it != gvar_map_.end()) {
+        return Call(it->second, call_node->args);
+      }
+      auto func = builder_->GetContextIRModule()->Lookup(GetRef<GlobalVar>(gvar));
+      if (auto composite_name = func->GetAttr<String>(attr::kComposite)) {
+        auto new_func = Downcast<Function>(VisitExpr(func));
+        auto codegen_name = GetCodegenName(composite_name.value());
+        auto gsymbol = gvar->name_hint + "_" + codegen_name;
+        new_func = WithAttrs(new_func,
+                             {{attr::kCodegen, codegen_name}, {tvm::attr::kGlobalSymbol, gsymbol}});
+        builder_->GetContextIRModule()->Remove(GetRef<GlobalVar>(gvar));
+        auto new_gvar = builder_->AddFunction(new_func, gsymbol);
+        gvar_map_[gvar] = new_gvar;
+        return Call(new_gvar, call_node->args);
+      }
+    }
+    return ExprMutator::VisitExpr_(call_node);
+  }
+
+  /*!
+   * \brief Build the wrapper around a composite function.
+   *
+   * The wrapper takes fresh parameters with the same names and struct info, and its body is a
+   * single call to the (visited) inner composite function. The wrapper is created without attrs;
+   * the caller attaches the codegen attributes.
+   * NOTE(review): every function literal reached by this mutator must carry kComposite, otherwise
+   * the ICHECK below fires (e.g. for a non-composite local function inside "main").
+   */
+  Expr VisitExpr_(const FunctionNode* func_node) final {
+    auto f_inner = ExprMutator::VisitExpr_(func_node);
+    auto composite_name = func_node->GetAttr<String>(attr::kComposite);
+    ICHECK(composite_name);
+
+    Array<Var> param_vars;
+    Array<Expr> params;
+
+    for (auto v : func_node->params) {
+      Var new_v(v->name_hint(), GetStructInfo(v));
+      param_vars.push_back(new_v);
+      params.push_back(new_v);
+    }
+
+    return Function(param_vars, Call(f_inner, params), func_node->ret_struct_info);
+  }
+
+ private:
+  /*! \brief Return the part of a pattern name before its first '.', e.g. "dnnl" for
+   *  "dnnl.conv2d_relu". Fails if the name contains no '.'. */
+  String GetCodegenName(const std::string& composite_name) {
+    auto delim_pos = composite_name.find(".");
+    ICHECK(delim_pos != std::string::npos) << "The pattern name for a composite function should "
+                                              "start with a compiler name followed by period.";
+    return composite_name.substr(0, delim_pos);
+  }
+
+  /*! \brief A map from old global vars to their replacements. */
+  std::unordered_map<const GlobalVarNode*, GlobalVar> gvar_map_;
+};
+
+/*!
+ * \brief Fuse the expressions matched by each pattern into grouped (composite) functions.
+ * \param pattern_names Names of the patterns; element i names patterns[i].
+ * \param patterns The patterns, in priority order (earlier entries are applied first).
+ * \param mod The module to transform.
+ * \param annotate_codegen Whether to wrap composite functions for external codegen.
+ * \return The transformed module.
+ */
+IRModule FuseOpsByPattern(const tvm::Array<String>& pattern_names,
+                          const tvm::Array<DFPattern>& patterns, IRModule mod,
+                          bool annotate_codegen) {
+  ICHECK_EQ(pattern_names.size(), patterns.size())
+      << "FuseOpsByPattern expects exactly one name per pattern.";
+  // Groups are allocated from this arena and referenced by raw pointer in every group map below,
+  // so it must outlive all MakeGroupedFunctions calls.
+  support::Arena arena;
+  // Patterns are applied one at a time; each round operates on the module produced by the
+  // previous (higher-priority) round.
+  for (size_t i = 0; i < pattern_names.size(); ++i) {
+    OperatorFusor::GroupMap group_map;
+    for (const auto& entry : mod->functions) {
+      // The partitioner traverses with Relax visitors (AnalyzeVar2Value / PostOrderVisit), which
+      // cannot handle non-Relax functions such as tir::PrimFunc (e.g. present after LegalizeOps).
+      if (!entry.second->IsInstance<FunctionNode>()) {
+        continue;
+      }
+      auto map = PatternBasedPartitioner::Run(pattern_names[i], patterns[i], entry.second, &arena);
+      group_map.insert(map.begin(), map.end());
+    }
+    mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ false);
+  }
+  if (annotate_codegen) {
+    return CompositeFunctionAnnotator(mod).Run();
+  }
+  return mod;
+}
+
namespace transform {
Pass FuseOps(int fuse_opt_level) {
@@ -897,6 +1082,20 @@ Pass FuseOps(int fuse_opt_level) {
TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps);
+/*!
+ * \brief Create the module pass that performs pattern-based fusion.
+ * \param pattern_names Names of the patterns; element i names patterns[i].
+ * \param patterns The patterns, in priority order.
+ * \param annotate_codegen Whether to wrap composite functions for external codegen.
+ * \return A module pass named "FuseOpsByPattern" with opt_level 0 and no required passes.
+ */
+Pass FuseOpsByPattern(const tvm::Array<String>& pattern_names,
+                      const tvm::Array<DFPattern>& patterns, bool annotate_codegen) {
+  auto run_on_module = [=](IRModule mod, PassContext ctx) -> IRModule {
+    return relax::FuseOpsByPattern(pattern_names, patterns, mod, annotate_codegen);
+  };
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func(run_on_module);
+  return CreateModulePass(pass_func, /*opt_level=*/0, /*pass_name=*/"FuseOpsByPattern",
+                          /*required=*/{});
+}
+
+// Expose the pass constructor through the FFI as "relax.transform.FuseOpsByPattern".
+TVM_REGISTER_GLOBAL("relax.transform.FuseOpsByPattern").set_body_typed(FuseOpsByPattern);
+
} // namespace transform
} // namespace relax
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
new file mode 100644
index 0000000000..da5b92fb64
--- /dev/null
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -0,0 +1,464 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import numpy as np
+
+import tvm
+
+from tvm import relax
+from tvm.script import relax as R
+from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern, is_op,
wildcard
+
+
[email protected]_module
+class Conv2dReLU:
+ @R.function
+ def main(
+ data: R.Tensor((1, 64, 56, 56), "float32"),
+ weight1: R.Tensor((64, 64, 3, 3), "float32"),
+ ):
+ with R.dataflow():
+ conv1 = R.nn.relu(R.nn.conv2d(data, weight1, padding=(1, 1)))
+ R.output(conv1)
+
+ return conv1
+
+
[email protected]_module
+class Conv2dReLU_composite_annotated:
+ @R.function
+ def main(
+ data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+ with R.dataflow():
+ gv: R.Tensor(
+ (1, 64, 56, 56), dtype="float32"
+ ) = fused_relax_nn_conv2d_relax_nn_relu_dnnl(data, weight1)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def fused_relax_nn_conv2d_relax_nn_relu_dnnl(
+ data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight11: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+ R.func_attr(
+ {"Codegen": "dnnl", "global_symbol":
"fused_relax_nn_conv2d_relax_nn_relu_dnnl"}
+ )
+
+ @R.function
+ def gv1(
+ data2: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight12: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+ R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"})
+ with R.dataflow():
+ lv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(
+ data2,
+ weight12,
+ padding=[1, 1, 1, 1],
+ )
+ gv2: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)
+ R.output(gv2)
+ return gv2
+
+ gv11: R.Tensor((1, 64, 56, 56), dtype="float32") = gv1(data1, weight11)
+ return gv11
+
+
[email protected]_module
+class Conv2dReLUx2:
+ @R.function
+ def main(
+ data: R.Tensor((1, 64, 56, 56), "float32"),
+ weight1: R.Tensor((64, 64, 3, 3), "float32"),
+ weight2: R.Tensor((64, 64, 3, 3), "float32"),
+ ):
+ with R.dataflow():
+ conv1 = R.nn.relu(R.nn.conv2d(data, weight1, padding=(1, 1)))
+ conv2 = R.nn.relu(R.nn.conv2d(conv1, weight2, padding=(0, 0)))
+ R.output(conv2)
+
+ return conv2
+
+
[email protected]_module
+class Conv2dReLUx2Partitioned:
+ @R.function
+ def main(
+ data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ weight2: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((1, 64, 56, 56), dtype="float32") =
fused_relax_nn_conv2d_relax_nn_relu(
+ data, weight1
+ )
+ gv: R.Tensor((1, 64, 54, 54), dtype="float32") =
fused_relax_nn_conv2d_relax_nn_relu1(
+ lv, weight2
+ )
+ R.output(gv)
+ return gv
+
+ @R.function
+ def fused_relax_nn_conv2d_relax_nn_relu(
+ data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight11: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+ R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"})
+ with R.dataflow():
+ lv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(
+ data1, weight11, padding=[1, 1, 1, 1]
+ )
+ gv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv1)
+ R.output(gv1)
+ return gv1
+
+ @R.function
+ def fused_relax_nn_conv2d_relax_nn_relu1(
+ conv1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight21: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+ R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"})
+ with R.dataflow():
+ lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(
+ conv1, weight21, padding=[0, 0, 0, 0]
+ )
+ gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv2)
+ R.output(gv2)
+ return gv2
+
+
[email protected]_module
+class Conv2dReLUx2Partitioned_only_conv2d:
+ @R.function
+ def main(
+ data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ weight2: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((1, 64, 56, 56), dtype="float32") =
fused_relax_nn_conv2d(data, weight1)
+ conv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)
+ lv1: R.Tensor((1, 64, 54, 54), dtype="float32") =
fused_relax_nn_conv2d1(conv1, weight2)
+ conv2d: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv1)
+ R.output(conv2d)
+ return conv2d
+
+ @R.function
+ def fused_relax_nn_conv2d(
+ data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight11: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+ R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d"})
+ with R.dataflow():
+ gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(
+ data1, weight11, padding=[1, 1, 1, 1]
+ )
+ R.output(gv)
+ return gv
+
+ @R.function
+ def fused_relax_nn_conv2d1(
+ conv11: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight21: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+ R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d"})
+ with R.dataflow():
+ gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(
+ conv11, weight21, padding=[0, 0, 0, 0]
+ )
+ R.output(gv1)
+ return gv1
+
+
[email protected]_module
+class Conv2dConv2dReLU:
+ @R.function
+ def main(
+ data: R.Tensor((1, 64, 56, 56), "float32"),
+ weight1: R.Tensor((64, 64, 3, 3), "float32"),
+ weight2: R.Tensor((64, 64, 3, 3), "float32"),
+ ):
+ with R.dataflow():
+ conv1 = R.nn.conv2d(data, weight1, padding=(1, 1))
+ conv2d = R.nn.relu(R.nn.conv2d(conv1, weight2, padding=(0, 0)))
+ R.output(conv2d)
+
+ return conv2d
+
+
[email protected]_module
+class Conv2dConv2dReLUPartitioned:
+ @R.function
+ def main(
+ data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ weight2: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((1, 64, 56, 56), dtype="float32") =
fused_relax_nn_conv2d(data, weight1)
+ gv: R.Tensor((1, 64, 54, 54), dtype="float32") =
fused_relax_nn_conv2d_relax_nn_relu(
+ lv, weight2
+ )
+ R.output(gv)
+ return gv
+
+ @R.function
+ def fused_relax_nn_conv2d_relax_nn_relu(
+ conv1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight21: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+ R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"})
+ with R.dataflow():
+ lv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(
+ conv1, weight21, padding=[0, 0, 0, 0]
+ )
+ gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv1)
+ R.output(gv1)
+ return gv1
+
+ @R.function
+ def fused_relax_nn_conv2d(
+ data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight11: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+ R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d"})
+ with R.dataflow():
+ gv2: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(
+ data1, weight11, padding=[1, 1, 1, 1]
+ )
+ R.output(gv2)
+ return gv2
+
+
[email protected]_module
+class BranchTupleOutput:
+ @R.function
+ def main(
+ data: R.Tensor((1, 64, 56, 56), "float32"),
+ weight: R.Tensor((64, 64, 3, 3), "float32"),
+ ):
+ with R.dataflow():
+ conv1 = R.nn.conv2d(data, weight)
+ relu1 = R.nn.relu(conv1)
+ gelu1 = R.nn.gelu(relu1)
+ gelu2 = R.nn.gelu(conv1)
+ out = relax.op.add(gelu1, gelu2)
+ R.output(out)
+
+ return out
+
+
[email protected]_module
+class BranchTupleOutputPartitioned:
+ @R.function
+ def main(
+ data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tuple(
+ R.Tensor((1, 64, 54, 54), dtype="float32"),
+ R.Tensor((1, 64, 54, 54), dtype="float32"),
+ ) = fused_relax_nn_conv2d_relax_nn_relu(data, weight)
+ lv1: R.Tensor((1, 64, 54, 54), dtype="float32") = lv[1] # conv1
+ lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = lv[0] #
relu(conv1)
+ gelu1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv2)
+ gelu2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv1)
+ out: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(gelu1,
gelu2)
+ R.output(out)
+ return out
+
+ @R.function
+ def fused_relax_nn_conv2d_relax_nn_relu(
+ data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tuple(
+ R.Tensor((1, 64, 54, 54), dtype="float32"), R.Tensor((1, 64, 54, 54),
dtype="float32")
+ ):
+ R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"})
+ with R.dataflow():
+ gv: R.Tensor((1, 64, 54, 54), dtype="float32") =
R.nn.conv2d(data1, weight1)
+ gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(gv)
+ R.output(gv, gv1)
+ return (gv1, gv)
+
+
[email protected]_module
+class Branch:
+ @R.function
+ def main(
+ data: R.Tensor((1, 64, 56, 56), "float32"),
+ weight: R.Tensor((64, 64, 3, 3), "float32"),
+ ):
+ with R.dataflow():
+ conv1 = R.nn.conv2d(data, weight)
+ relu1 = R.nn.relu(conv1)
+ gelu1 = R.nn.gelu(conv1)
+
+ out = relax.op.add(relu1, gelu1)
+ R.output(out)
+
+ return out
+
+
[email protected]_module
+class Conv2dx2:
+ @R.function
+ def main(
+ data: R.Tensor((16, 32, 32, 16), "float16"),
+ weight1: R.Tensor((16, 3, 3, 16), "float16"),
+ weight2: R.Tensor((16, 3, 3, 16), "float16"),
+ ):
+ with R.dataflow():
+ conv1 = relax.op.nn.conv2d(
+ data, weight1, padding=(1, 1), data_layout="NHWC",
kernel_layout="OHWI"
+ )
+ conv2 = relax.op.nn.conv2d(
+ conv1, weight2, padding=(1, 1), data_layout="NHWC",
kernel_layout="OHWI"
+ )
+ R.output(conv2)
+
+ return conv2
+
+
[email protected]_module
+class Conv2dx2_partitioned:
+ @R.function
+ def main(
+ data: R.Tensor((16, 32, 32, 16), dtype="float16"),
+ weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
+ weight2: R.Tensor((16, 3, 3, 16), dtype="float16"),
+ ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
+ with R.dataflow():
+ lv: R.Tensor((16, 32, 32, 16), dtype="float16") =
fused_relax_nn_conv2d_cutlass(
+ data, weight1
+ )
+ gv: R.Tensor((16, 32, 32, 16), dtype="float16") =
fused_relax_nn_conv2d_cutlass(
+ lv, weight2
+ )
+ R.output(gv)
+ return gv
+
+ @R.function
+ def fused_relax_nn_conv2d_cutlass(
+ data: R.Tensor((16, 32, 32, 16), dtype="float16"),
+ weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
+ ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
+ R.func_attr({"Codegen": "cutlass", "global_symbol":
"fused_relax_nn_conv2d_cutlass"})
+
+ @R.function
+ def gv(
+ data_1: R.Tensor((16, 32, 32, 16), dtype="float16"),
+ weight1_1: R.Tensor((16, 3, 3, 16), dtype="float16"),
+ ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
+ R.func_attr({"Composite": "cutlass.conv2d", "Primitive": 1})
+ with R.dataflow():
+ gv_1: R.Tensor((16, 32, 32, 16), dtype="float16") =
R.nn.conv2d(
+ data_1,
+ weight1_1,
+ padding=[1, 1, 1, 1],
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ out_layout="NHWC",
+ )
+ R.output(gv_1)
+ return gv_1
+
+ gv1: R.Tensor((16, 32, 32, 16), dtype="float16") = gv(data, weight1)
+ return gv1
+
+
+conv2d_pat = make_fused_bias_activation_pattern("relax.nn.conv2d",
activation=None)
+conv2d_relu_pat = make_fused_bias_activation_pattern("relax.nn.conv2d",
activation="relax.nn.relu")
+
+
+def check(mod, patterns, expected, annoatate_codegen=False):
+ partitioned = relax.transform.FuseOpsByPattern(patterns,
annoatate_codegen)(mod)
+ tvm.ir.assert_structural_equal(partitioned, expected)
+
+
+def test_partition_conv2d_relu():
+ check(Conv2dReLUx2, [("dnnl.conv2d_relu", conv2d_relu_pat)],
Conv2dReLUx2Partitioned)
+
+
+def test_partition_multiple_patterns():
+ check(
+ Conv2dConv2dReLU,
+ [("dnnl.conv2d_relu", conv2d_relu_pat), ("dnnl.conv2d", conv2d_pat)],
+ Conv2dConv2dReLUPartitioned,
+ )
+
+
+def test_partition_order():
+ check(
+ Conv2dReLUx2,
+ [("dnnl.conv2d", conv2d_pat), ("dnnl.conv2d_relu", conv2d_relu_pat)],
+ Conv2dReLUx2Partitioned_only_conv2d,
+ )
+
+
+def test_branch_tuple_output():
+ check(BranchTupleOutput, [("dnnl.conv2d_relu", conv2d_relu_pat)],
BranchTupleOutputPartitioned)
+
+
+def test_cyclic_dependency():
+ conv_pat = make_fused_bias_activation_pattern("relax.nn.conv2d")
+ relu_pat = is_op("relax.nn.relu")(conv_pat)
+ add_pat = is_op("relax.add")(relu_pat, wildcard())
+
+ with pytest.raises(tvm.error.TVMError) as err:
+ relax.transform.FuseOpsByPattern([("compiler_A.conv2d_relu_add",
add_pat)])(Branch)
+
+ assert "A cyclic dependency detected" in str(err.value)
+
+
+def test_bind_params():
+ weight_np = np.random.randn(64, 64, 3, 3).astype("float32")
+ mod = tvm.transform.Sequential(
+ [
+ relax.transform.BindParams("main", {"weight1": weight_np}),
+ relax.transform.FuseOpsByPattern([("dnnl.conv2d_relu",
conv2d_relu_pat)]),
+ ]
+ )(Conv2dReLU)
+
+ assert "fused_relax_nn_conv2d_relax_nn_relu" in [var.name_hint for var in
mod.functions.keys()]
+
+ for gvar, f in mod.functions.items():
+ if gvar.name_hint == "fused_relax_nn_conv2d_relax_nn_relu":
+ conv2d = f.body.blocks[0].bindings[0].value
+ assert isinstance(conv2d.args[1], relax.Constant)
+
+
+def test_annotate_codegen():
+ check(
+ Conv2dReLU,
+ [("dnnl.conv2d_relu", conv2d_relu_pat)],
+ Conv2dReLU_composite_annotated,
+ annoatate_codegen=True,
+ )
+
+
+def test_multiple_calls_same_extern():
+ pat = make_fused_bias_activation_pattern("relax.nn.conv2d",
with_bias=False, activation=None)
+ check(Conv2dx2, [("cutlass.conv2d", pat)], Conv2dx2_partitioned,
annoatate_codegen=True)
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])