This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 32ff55b8ce [Unity][BYOC] Add pass to merge composite functions to
offload large subgraphs (#14062)
32ff55b8ce is described below
commit 32ff55b8cecd213d9cfe92be54ae7aeb4a0d29b5
Author: masahi <[email protected]>
AuthorDate: Tue Feb 21 16:13:58 2023 +0900
[Unity][BYOC] Add pass to merge composite functions to offload large
subgraphs (#14062)
This PR adds a pass that merges neighboring calls to composite functions
offloaded to the same external backend into one function. This is important for
backends that want to receive as large a subgraph as possible, for example
TensorRT. It plays the same role as the MergeCompilerRegions pass in Relay BYOC,
and the algorithm follows the same idea described in
https://discuss.tvm.apache.org/t/relay-improved-graph-partitioning-algorithm/5830.
Original PR
https://github.com/tlc-pack/relax/pull/372
Substantial improvement by @yelite
https://github.com/tlc-pack/relax/pull/411
Related fix PR by @yelite
https://github.com/tlc-pack/relax/pull/406
Co-authored-by: Lite Ye <[email protected]>
---
include/tvm/relax/utils.h | 11 +-
python/tvm/relax/transform/transform.py | 14 +
python/tvm/relax/utils.py | 12 +-
src/relax/transform/merge_composite_functions.cc | 355 +++++++
src/relax/utils.cc | 29 +
.../test_transform_merge_composite_functions.py | 1051 ++++++++++++++++++++
tests/python/relax/test_utils.py | 107 ++
7 files changed, 1570 insertions(+), 9 deletions(-)
diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h
index c1d984a21a..b3cc76768d 100644
--- a/include/tvm/relax/utils.h
+++ b/include/tvm/relax/utils.h
@@ -142,13 +142,16 @@ TVM_DLL bool IsBoolScalarType(const Type& ty, bool
permit_unknown_rank = true,
TVM_DLL bool IsLeafOrTuple(const Expr& expr);
/*!
- * \brief Copy the given function. The parameters of the original function
would be copied to
- * satisfy the restriction in the well-formed check: any two functions cannot
share the same
- * parameter variable.
+ * \brief Copy the given function. All variables that are bound inside the original function
+ * are copied to satisfy the restriction in the well-formed check: variables in
+ * Relax must be bound exactly once. This also ensures that both the function and its copy
+ * can be inserted into the same IRModule, and be checked for structural equality
+ * against an IRModule created by TVMScript.
+ *
* \param func The relax function to copy.
* \return The copied function.
*/
-TVM_DLL Function CopyWithNewParams(Function func);
+TVM_DLL Function CopyWithNewVars(Function func);
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index bf90ef0b09..12ed27f73a 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -276,6 +276,20 @@ def FuseOpsByPattern(
return _ffi_api.FuseOpsByPattern(pattern_names, df_patterns,
annotate_codegen) # type: ignore
+def MergeCompositeFunctions() -> tvm.ir.transform.Pass:
+    """Group one or multiple composite functions created by FuseOpsByPattern into a new function.
+
+    The new function will be annotated with "Codegen" and "global_symbol" attributes, and it
+    is intended to be offloaded to an external backend. Composite calls are grouped within
+    the module's "main" function.
+
+    Returns
+    -------
+    ret : tvm.ir.transform.Pass
+        The registered pass for merging composite functions.
+
+    """
+    return _ffi_api.MergeCompositeFunctions()  # type: ignore
+
+
def LegalizeOps(customize_legalize_map: Optional[Dict[str, LegalizeFunc]] =
None):
"""Legalize high-level operator calls in Relax functions to call_tir
with corresponding low-level TIR PrimFuncs.
diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index 0bb82c79f4..d6b405f183 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -250,10 +250,12 @@ class _ArgsConverter:
args_converter = _ArgsConverter() # pylint: disable=invalid-name
-def copy_with_new_params(func: Function) -> Function:
- """Copy the given function. The parameters of the original function would
be copied to
- satisfy the restriction in the well-formed check: any two functions cannot
share the same
- parameter variable.
+def copy_with_new_vars(func: Function) -> Function:
+    """Copy the given function. All variables that are bound inside the original function
+    are copied to satisfy the restriction in the well-formed check: variables in
+    Relax must be bound exactly once. This also ensures that both the function and its copy
+    can be inserted into the same IRModule, and be checked for structural equality
+    against an IRModule created by TVMScript.
Parameters
----------
@@ -265,4 +267,4 @@ def copy_with_new_params(func: Function) -> Function:
ret : Function
The copied function.
"""
- return _ffi_api.CopyWithNewParams(func) # type: ignore
+ return _ffi_api.CopyWithNewVars(func) # type: ignore
diff --git a/src/relax/transform/merge_composite_functions.cc
b/src/relax/transform/merge_composite_functions.cc
new file mode 100644
index 0000000000..db73392b02
--- /dev/null
+++ b/src/relax/transform/merge_composite_functions.cc
@@ -0,0 +1,355 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/merge_composite_functions.cc
+ * \brief Group one or multiple composite functions created by FuseOpsByPattern into a new
+ * function.
+ *
+ * The new function will be annotated with kCodegen and kGlobalSymbol attributes, and it is
+ * intended to be offloaded to an external backend.
+ *
+ * A group for one composite function can be merged into another group for one of its arguments,
+ * which we call the parent group for brevity, if the following conditions are met:
+ * - The argument is the result of calling a composite function offloaded to the same backend
+ * - Merging into the parent group would not create a cyclic dependency with other parent groups
+ *
+ * For example, in the subgraph below the bottom group cannot be merged into the left parent group,
+ * since the right parent group for X depends on an output from the left parent group.
+ *
+ * O = Offloaded to A
+ * X = Offloaded to B
+ *
+ *       O
+ *      / \
+ *     O   X
+ *      \ /
+ *       O
+ *
+ * Correct partitioning: {top O, left O}, {X}, {bottom O}. The bottom O cannot join the group of
+ * the top and left O, because X depends on that group and the bottom O depends on X.
+ *
+ * The algorithm proceeds by assigning a group to each subexpression in the function according to
+ * its dataflow. On encountering a call node whose callee is a composite function, we check the
+ * two conditions above to see if we can merge this call node into one of its parent groups, and
+ * if we can merge some of its parent groups.
+ *
+ * To detect cyclic dependencies between groups, we propagate dependency relations, both direct
+ * and indirect ones, as we flow through the function. The propagation of indirect dependencies
+ * is important since the dependency relation is transitive.
+ */
+
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/function.h>
+
+#include "../../support/arena.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+using relay::GraphPartitioner;
+
+namespace {
+
+using Group = GraphPartitioner::Group;
+
+/*! \brief Assign a group to each subexpression in a function according to its
+ * dataflow, and return a mapping from a subexpression to its group. */
+class CompositeGroupsBuilder : public MemoizedExprTranslator<Group*> {
+ public:
+  using GroupMap = std::unordered_map<const Object*, Group*>;
+  using MemoizedExprTranslator<Group*>::VisitExpr_;
+
+  CompositeGroupsBuilder(IRModule mod, support::Arena* arena) : mod_(mod), arena_(arena) {}
+
+  /*!
+   * \brief Assign groups to the parameters and bindings of `func`.
+   * \param func The function to partition.
+   * \return A map from every expression that was assigned a group to the root of that group.
+   */
+  GroupMap Run(Function func) {
+    // Each parameter starts in its own group, which carries no codegen attribute.
+    for (const auto& param : func->params) {
+      memo_[param] = arena_->make<Group>();
+    }
+    VisitExpr(func->body);
+
+    GroupMap group_map;
+    for (const auto& [expr, group] : memo_) {
+      group_map[expr.get()] = group->FindRoot();
+    }
+
+    return group_map;
+  }
+
+  // Only VarBinding is handled; any other binding kind is rejected.
+  Group* VisitBinding(const Binding& binding) {
+    if (const auto* node = binding.as<VarBindingNode>()) {
+      return VisitBinding_(node);
+    } else {
+      LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey();
+    }
+  }
+
+  void VisitBindingBlock_(const BindingBlockNode* block) {
+    for (Binding binding : block->bindings) {
+      VisitBinding(binding);
+    }
+  }
+
+  void VisitBindingBlock_(const DataflowBlockNode* block) {
+    for (Binding binding : block->bindings) {
+      VisitBinding(binding);
+    }
+  }
+
+  void VisitBindingBlock(const BindingBlock& block) {
+    if (const auto* node = block.as<DataflowBlockNode>()) {
+      VisitBindingBlock_(node);
+    } else if (const auto* node = block.as<BindingBlockNode>()) {
+      VisitBindingBlock_(node);
+    } else {
+      LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey();
+    }
+  }
+
+  Group* VisitExpr_(const SeqExprNode* op) {
+    for (BindingBlock block : op->blocks) {
+      VisitBindingBlock(block);
+    }
+    return VisitExpr(op->body);
+  }
+
+  Group* VisitExpr_(const CallNode* call) {
+    std::vector<Group*> groups_to_merge = GetGroupsToMerge(call);
+    Group* group;
+
+    if (groups_to_merge.empty()) {
+      // Create new group if there is nothing to merge with
+      group = CreateNewGroup(call);
+    } else {
+      auto it = groups_to_merge.cbegin();
+      // Assign the first mergeable group to the current node to reduce the number of groups
+      // created. Use its root, since MergeGroup accumulates node counts on root groups only.
+      group = (*it++)->FindRoot();
+      group->num_nodes += 1;
+
+      // Merge all remaining candidate parent groups into it.
+      for (; it != groups_to_merge.cend(); ++it) {
+        MergeGroup(*it, group);
+      }
+    }
+
+    UpdateGroupDependencies(group, call->args);
+    return group;
+  }
+
+ private:
+  String GetCodegenName(const std::string& composite_name) {
+    auto delim_pos = composite_name.find(".");
+    ICHECK(delim_pos != std::string::npos) << "The pattern name for a composite function should "
+                                              "start with a compiler name followed by period.";
+    return composite_name.substr(0, delim_pos);
+  }
+
+  // Returns the codegen name of a composite function callee, or NullOpt if the callee is not a
+  // GlobalVar referring to a function with the kComposite attribute.
+  Optional<String> GetCodegenName(const Expr& callee) {
+    auto const* gvar = callee.as<GlobalVarNode>();
+    if (!gvar) {
+      return NullOpt;
+    }
+
+    auto composite_name_opt =
+        mod_->Lookup(GetRef<GlobalVar>(gvar))->GetAttr<String>(attr::kComposite);
+    if (!composite_name_opt) {
+      return NullOpt;
+    }
+
+    return GetCodegenName(composite_name_opt.value());
+  }
+
+  Optional<String> GetCodegenName(Group* group) {
+    return Downcast<Optional<String>>(group->attrs.Get(attr::kCodegen));
+  }
+
+  Group* CreateNewGroup(const CallNode* call) {
+    Group* group = arena_->make<Group>();
+    if (Optional<String> codegen_name = GetCodegenName(call->op)) {
+      group->attrs.Set(attr::kCodegen, codegen_name.value());
+    }
+    return group;
+  }
+
+  void MergeGroup(Group* from, Group* to) {
+    ICHECK_EQ(GetCodegenName(from), GetCodegenName(to));
+
+    Group* from_root = from->FindRoot();
+    Group* to_root = to->FindRoot();
+    if (from_root == to_root) {
+      return;
+    }
+
+    from_root->parent = to_root;
+    to_root->num_nodes += from_root->num_nodes;
+
+    // Update the group_deps_, maintaining the invariant that
+    // all groups in the map are root groups.
+    group_deps_[to_root].merge(group_deps_[from_root]);
+    group_deps_.erase(from_root);
+    for (auto& it : group_deps_) {
+      if (it.second.count(from_root)) {
+        it.second.erase(from_root);
+        it.second.insert(to_root);
+      }
+    }
+  }
+
+  /*!
+   * \brief Look up the group assigned to an already-visited expression.
+   * \return The group, or nullptr if `expr` was never assigned one (for example a constant or a
+   *         tuple literal passed directly as a call argument). Unlike memo_[expr], this does not
+   *         insert a null entry into memo_.
+   */
+  Group* LookupGroup(const Expr& expr) {
+    auto it = memo_.find(expr);
+    return it == memo_.end() ? nullptr : it->second;
+  }
+
+  /*!
+   * \brief Collect the root groups producing the values read by `args`.
+   *
+   * Arguments without a group contribute nothing. The fields of a tuple literal argument are
+   * inspected individually, so that dependencies flowing through the tuple are not lost.
+   */
+  std::vector<Group*> GetArgRootGroups(const Array<Expr>& args) {
+    std::vector<Group*> roots;
+    for (const auto& arg : args) {
+      if (const auto* tuple = arg.as<TupleNode>()) {
+        std::vector<Group*> field_roots = GetArgRootGroups(tuple->fields);
+        roots.insert(roots.end(), field_roots.begin(), field_roots.end());
+      } else if (Group* group = LookupGroup(arg)) {
+        roots.push_back(group->FindRoot());
+      }
+    }
+    return roots;
+  }
+
+  std::unordered_set<Group*> GetParentGroupDependencies(const Array<Expr>& args) {
+    // Collect groups that parent groups depend on
+    std::unordered_set<Group*> dependencies;
+
+    for (Group* parent_root : GetArgRootGroups(args)) {
+      for (auto dep : group_deps_[parent_root]) {
+        dependencies.insert(dep);
+      }
+    }
+
+    return dependencies;
+  }
+
+  void UpdateGroupDependencies(Group* group, const Array<Expr>& args) {
+    Group* group_root = group->FindRoot();
+
+    for (Group* arg_group_root : GetArgRootGroups(args)) {
+      if (arg_group_root == group_root) {
+        // If arg and the current node are in the same group,
+        // there is nothing to update.
+        continue;
+      }
+      // Add the group of arg as dependency
+      group_deps_[group_root].insert(arg_group_root);
+      // Propagate dependencies of arg
+      for (auto dep : group_deps_[arg_group_root]) {
+        group_deps_[group_root].insert(dep);
+      }
+    }
+  }
+
+  std::vector<Group*> GetGroupsToMerge(const CallNode* call) {
+    Optional<String> codegen_name = GetCodegenName(call->op);
+    if (!codegen_name.defined()) {
+      return {};
+    }
+
+    std::vector<Group*> groups_to_merge;
+    std::unordered_set<Group*> parent_dependencies = GetParentGroupDependencies(call->args);
+
+    for (const auto& arg : call->args) {
+      // Only direct arguments that already belong to a group are merge candidates.
+      Group* arg_group = LookupGroup(arg);
+      if (arg_group == nullptr) {
+        continue;
+      }
+      Optional<String> arg_codegen_name = GetCodegenName(arg_group);
+      if (arg_codegen_name == codegen_name && !parent_dependencies.count(arg_group->FindRoot())) {
+        // If there is a parent group with the same target, which none of the parent dependency
+        // groups depends on, merging "this" call node into the parent group will not form a
+        // cyclic dependency.
+        groups_to_merge.push_back(arg_group);
+      }
+    }
+
+    return groups_to_merge;
+  }
+
+  IRModule mod_;
+  support::Arena* arena_;
+  // Map from group to its dependencies. All groups in this map, whether it's
+  // the key or in value, should be root node (that is, group->parent == nullptr).
+  std::unordered_map<Group*, std::unordered_set<Group*>> group_deps_;
+};
+
+/*!
+ * \brief Inline definitions of composite functions at the global level into their call sites.
+ *
+ * This is necessary to make functions created by MergeCompositeFunctions self-contained - each
+ * external backend compiler does not need to refer to the original containing module.
+ */
+class CompositeInliner : public ExprMutator {
+ public:
+  explicit CompositeInliner(IRModule mod) : ExprMutator(mod), mod_(mod) {}
+  using ExprMutator::VisitExpr_;
+
+  /*!
+   * \brief Rewrite `func` so that calls to global composite functions call local copies instead.
+   * \param func The function to rewrite; its params, return struct info, attrs and span are kept.
+   * \return The rewritten function.
+   */
+  Function Run(Function func) {
+    // The cache is reset per function, so each rewritten function gets its own copies.
+    inlined_functions_ = Map<Function, Function>();
+    auto new_body = VisitExpr(func->body);
+    auto new_func =
+        Function(func->params, new_body, func->ret_struct_info, func->attrs, func->span);
+    return new_func;
+  }
+
+  Expr VisitExpr_(const CallNode* call) {
+    if (const auto* gvar = call->op.as<GlobalVarNode>()) {
+      // Only Relax functions can carry the kComposite attribute here; any other kind of global
+      // function is left untouched instead of failing a Downcast.
+      const auto* func = mod_->Lookup(GetRef<GlobalVar>(gvar)).as<FunctionNode>();
+      if (func && func->GetAttr<String>(attr::kComposite)) {
+        Function callee = GetRef<Function>(func);
+        if (!inlined_functions_.count(callee)) {
+          // One fresh copy per composite function, reused by all its calls in this function.
+          inlined_functions_.Set(callee, CopyWithNewVars(callee));
+        }
+        return Call(inlined_functions_[callee], call->args, call->attrs, call->sinfo_args,
+                    call->span);
+      }
+    }
+
+    return ExprMutator::VisitExpr_(call);
+  }
+
+ private:
+  IRModule mod_;
+  Map<Function, Function> inlined_functions_;
+};
+
+} // namespace
+
+/*!
+ * \brief Group composite-function calls in "main" that target the same backend into new
+ * functions, and make each of those functions self-contained.
+ * \param mod The module; it must contain a Relax function named "main".
+ * \return The rewritten module, with functions no longer reachable from "main" removed.
+ */
+IRModule MergeCompositeFunctions(IRModule mod) {
+  GlobalVar main_gvar = mod->GetGlobalVar("main");
+  auto main_func = Downcast<Function>(mod->Lookup(main_gvar));
+  support::Arena arena;
+  auto group_map = CompositeGroupsBuilder(mod, &arena).Run(main_func);
+  auto new_mod = MakeGroupedFunctions(mod, group_map);
+
+  // Compute all rewritten functions first and apply them afterwards, rather than updating
+  // new_mod while iterating over its function map.
+  CompositeInliner inliner(mod);
+  std::vector<std::pair<GlobalVar, Function>> updates;
+  for (const auto& [grouped_gvar, grouped_func] : new_mod->functions) {
+    if (grouped_func->GetAttr<String>(attr::kCodegen)) {
+      auto new_func = inliner.Run(Downcast<Function>(grouped_func));
+      new_func = WithAttr(new_func, tvm::attr::kGlobalSymbol, grouped_gvar->name_hint);
+      updates.emplace_back(grouped_gvar, new_func);
+    }
+  }
+  for (const auto& [updated_gvar, updated_func] : updates) {
+    new_mod->Update(updated_gvar, updated_func);
+  }
+  // TODO(@tvm-team): Implicit pass dependency. Revisit when we have a better way to handle this.
+  return RemoveUnusedFunctions(new_mod, {"main"});
+}
+
+namespace transform {
+
+/*! \brief Create the module pass wrapping relax::MergeCompositeFunctions. */
+Pass MergeCompositeFunctions() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
+      [=](IRModule mod, PassContext pc) { return relax::MergeCompositeFunctions(mod); };
+  // The pass name must identify this pass (it was previously copy-pasted as "FuseOpsByPattern").
+  return CreateModulePass(/*pass_function=*/pass_func,              //
+                          /*opt_level=*/0,                          //
+                          /*pass_name=*/"MergeCompositeFunctions",  //
+                          /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.MergeCompositeFunctions")
+    .set_body_typed(MergeCompositeFunctions);
+
+} // namespace transform
+
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index 24414f250c..110bdb5c8c 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -82,5 +82,34 @@ bool IsLeafOrTuple(const Expr& expr) {
expr.as<OpNode>() || expr.as<TupleNode>();
}
+/*! \brief Mutator that rebuilds a function with a fresh Var for every variable it binds. */
+class FunctionCopier : public ExprMutator {
+ public:
+  static Function Transform(Function func) {
+    // Relax's well-formed check requires every variable to be bound exactly once, so the copy
+    // must not share any bound variable with the original.
+    FunctionCopier copier;
+    Expr copied = copier.VisitExpr(func);
+    return Downcast<Function>(copied);
+  }
+
+  Var VisitVarDef_(const DataflowVarNode* var) override {
+    Var visited = ExprMutator::VisitVarDef_(var);
+    return Remap(var, DataflowVar(visited->name_hint(), GetStructInfo(visited), visited->span));
+  }
+
+  Var VisitVarDef_(const VarNode* var) override {
+    Var visited = ExprMutator::VisitVarDef_(var);
+    return Remap(var, Var(visited->name_hint(), GetStructInfo(visited), visited->span));
+  }
+
+ private:
+  /*! \brief Redirect later uses of `original` to `fresh`, then hand `fresh` back. */
+  Var Remap(const VarNode* original, Var fresh) {
+    var_remap_[original->vid] = fresh;
+    return fresh;
+  }
+};
+
+/*! \brief Copy `func` so that every variable bound inside it is a new Var. */
+Function CopyWithNewVars(Function func) {
+  Function copy = FunctionCopier::Transform(func);
+  return copy;
+}
+
+TVM_REGISTER_GLOBAL("relax.CopyWithNewVars").set_body_typed(CopyWithNewVars);
+
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_transform_merge_composite_functions.py
b/tests/python/relax/test_transform_merge_composite_functions.py
new file mode 100644
index 0000000000..8577a4d93c
--- /dev/null
+++ b/tests/python/relax/test_transform_merge_composite_functions.py
@@ -0,0 +1,1051 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+import tvm
+from tvm import relax
+from tvm.script import relax as R
+
+
+# Input fixture: "main" calls two "dnnl.conv2d_relu" composite functions in sequence.
+@tvm.script.ir_module
+class Conv2dReLUx2:
+    @R.function
+    def main(
+        data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+        weight2: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        with R.dataflow():
+            lv: R.Tensor((1, 64, 56, 56), dtype="float32") = fused_relax_nn_conv2d_relax_nn_relu(
+                data, weight1
+            )
+            gv: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_conv2d_relax_nn_relu1(
+                lv, weight2
+            )
+            R.output(gv)
+        return gv
+
+    @R.function
+    def fused_relax_nn_conv2d_relax_nn_relu(
+        data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight11: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+        R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"})
+        with R.dataflow():
+            lv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(
+                data1,
+                weight11,
+                padding=[1, 1, 1, 1],
+            )
+            gv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv1)
+            R.output(gv1)
+        return gv1
+
+    @R.function
+    def fused_relax_nn_conv2d_relax_nn_relu1(
+        conv1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight21: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"})
+        with R.dataflow():
+            lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(
+                conv1,
+                weight21,
+                padding=[0, 0, 0, 0],
+            )
+            gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv2)
+            R.output(gv2)
+        return gv2
+
+
+# Merged counterpart of Conv2dReLUx2: both "dnnl" composites are inlined into a single function
+# annotated with Codegen "dnnl" and a matching global_symbol.
+@tvm.script.ir_module
+class Conv2dReLUx2_merged:
+    @R.function
+    def main(
+        data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+        weight2: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        with R.dataflow():
+            gv: R.Tensor(
+                (1, 64, 54, 54), dtype="float32"
+            ) = fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1(
+                data, weight1, weight2
+            )
+            R.output(gv)
+        return gv
+
+    @R.function
+    def fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1(
+        data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight11: R.Tensor((64, 64, 3, 3), dtype="float32"),
+        weight21: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        R.func_attr(
+            {
+                "Primitive": 1,
+                "Codegen": "dnnl",
+                "global_symbol": "fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1",
+            }
+        )
+        with R.dataflow():
+
+            @R.function
+            def lv(
+                data11: R.Tensor((1, 64, 56, 56), dtype="float32"),
+                weight111: R.Tensor((64, 64, 3, 3), dtype="float32"),
+            ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+                R.func_attr({"Composite": "dnnl.conv2d_relu", "Primitive": 1})
+                with R.dataflow():
+                    lv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(
+                        data11,
+                        weight111,
+                        padding=[1, 1, 1, 1],
+                    )
+                    gv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv1)
+                    R.output(gv1)
+                return gv1
+
+            lv2: R.Tensor((1, 64, 56, 56), dtype="float32") = lv(data1, weight11)
+
+            @R.function
+            def lv11(
+                conv1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+                weight211: R.Tensor((64, 64, 3, 3), dtype="float32"),
+            ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+                R.func_attr({"Composite": "dnnl.conv2d_relu", "Primitive": 1})
+                with R.dataflow():
+                    lv21: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(
+                        conv1,
+                        weight211,
+                        padding=[0, 0, 0, 0],
+                    )
+                    gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv21)
+                    R.output(gv2)
+                return gv2
+
+            gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv11(lv2, weight21)
+            R.output(gv3)
+        return gv3
+
+
+# Input fixture: conv2d feeds both relu and gelu, whose results are added; all four composite
+# functions target "compiler_A".
+@tvm.script.ir_module
+class Diamond:
+    @R.function
+    def main(
+        data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        with R.dataflow():
+            lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_conv2d(data, weight)
+            lv3: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_relu(lv2)
+            lv4: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_gelu(lv2)
+            gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_add(lv3, lv4)
+            R.output(gv2)
+        return gv2
+
+    @R.function
+    def fused_relax_nn_gelu(
+        lv: R.Tensor((1, 64, 54, 54), dtype="float32")
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        R.func_attr({"Primitive": 1, "Composite": "compiler_A.gelu"})
+        with R.dataflow():
+            gv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv)
+            R.output(gv)
+        return gv
+
+    @R.function
+    def fused_relax_nn_relu(
+        lv1: R.Tensor((1, 64, 54, 54), dtype="float32")
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"})
+        with R.dataflow():
+            gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv1)
+            R.output(gv1)
+        return gv1
+
+    @R.function
+    def fused_relax_add(
+        lv5: R.Tensor((1, 64, 54, 54), dtype="float32"),
+        gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"),
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"})
+        with R.dataflow():
+            gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(lv5, gelu1)
+            R.output(gv3)
+        return gv3
+
+    @R.function
+    def fused_relax_nn_conv2d(
+        data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        R.func_attr({"Primitive": 1, "Composite": "compiler_A.conv2d"})
+        with R.dataflow():
+            gv4: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(
+                data1,
+                weight1,
+                padding=[0, 0, 0, 0],
+            )
+            R.output(gv4)
+        return gv4
+
+
+# Merged counterpart of Diamond: all four "compiler_A" composites end up in one function
+# annotated with Codegen "compiler_A".
+@tvm.script.ir_module
+class Diamond_merged:
+    @R.function
+    def fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add(
+        data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        # function attr dict
+        R.func_attr(
+            {
+                "Codegen": "compiler_A",
+                "Primitive": 1,
+                "global_symbol": "fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add",
+            }
+        )
+        # block 0
+        with R.dataflow():
+
+            @R.function
+            def lv(
+                data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+                weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+            ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+                # function attr dict
+                R.func_attr({"Composite": "compiler_A.conv2d", "Primitive": 1})
+                # block 0
+                with R.dataflow():
+                    gv4: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(
+                        data1,
+                        weight1,
+                        strides=[1, 1],
+                        padding=[0, 0, 0, 0],
+                        dilation=[1, 1],
+                        groups=1,
+                        data_layout="NCHW",
+                        kernel_layout="OIHW",
+                        out_layout="NCHW",
+                        out_dtype="",
+                    )
+                    R.output(gv4)
+                return gv4
+
+            lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = lv(data, weight)
+
+            @R.function
+            def lv1(
+                lv11: R.Tensor((1, 64, 54, 54), dtype="float32")
+            ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+                # function attr dict
+                R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1})
+                # block 0
+                with R.dataflow():
+                    gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv11)
+                    R.output(gv1)
+                return gv1
+
+            lv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv1(lv2)
+
+            @R.function
+            def lv21(
+                lv4: R.Tensor((1, 64, 54, 54), dtype="float32")
+            ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+                # function attr dict
+                R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1})
+                # block 0
+                with R.dataflow():
+                    gv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv4)
+                    R.output(gv)
+                return gv
+
+            lv41: R.Tensor((1, 64, 54, 54), dtype="float32") = lv21(lv2)
+
+            @R.function
+            def lv31(
+                lv5: R.Tensor((1, 64, 54, 54), dtype="float32"),
+                gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"),
+            ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+                # function attr dict
+                R.func_attr({"Composite": "compiler_A.add", "Primitive": 1})
+                # block 0
+                with R.dataflow():
+                    gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(lv5, gelu1)
+                    R.output(gv3)
+                return gv3
+
+            gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = lv31(lv3, lv41)
+            R.output(gv2)
+        return gv2
+
+    @R.function
+    def main(
+        data2: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight2: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        # block 0
+        with R.dataflow():
+            gv5: R.Tensor(
+                (1, 64, 54, 54), dtype="float32"
+            ) = fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add(data2, weight2)
+            R.output(gv5)
+        return gv5
+
+
+# Input fixture: the same diamond as Diamond, except that gelu targets "compiler_B". The final add
+# depends on the compiler_B group, which itself depends on the compiler_A conv2d group.
+@tvm.script.ir_module
+class Diamond_cyclic_dep:
+    @R.function
+    def main(
+        data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        with R.dataflow():
+            lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_conv2d(data, weight)
+            lv3: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_relu(lv2)
+            lv4: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_gelu(lv2)
+            gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_add(lv3, lv4)
+            R.output(gv2)
+        return gv2
+
+    @R.function
+    def fused_relax_nn_gelu(
+        lv: R.Tensor((1, 64, 54, 54), dtype="float32")
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        R.func_attr({"Primitive": 1, "Composite": "compiler_B.gelu"})
+        with R.dataflow():
+            gv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv)
+            R.output(gv)
+        return gv
+
+    @R.function
+    def fused_relax_nn_relu(
+        lv1: R.Tensor((1, 64, 54, 54), dtype="float32")
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"})
+        with R.dataflow():
+            gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv1)
+            R.output(gv1)
+        return gv1
+
+    @R.function
+    def fused_relax_add(
+        lv5: R.Tensor((1, 64, 54, 54), dtype="float32"),
+        gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"),
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"})
+        with R.dataflow():
+            gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(lv5, gelu1)
+            R.output(gv3)
+        return gv3
+
+    @R.function
+    def fused_relax_nn_conv2d(
+        data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        R.func_attr({"Primitive": 1, "Composite": "compiler_A.conv2d"})
+        with R.dataflow():
+            gv4: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(
+                data1,
+                weight1,
+                padding=[0, 0, 0, 0],
+            )
+            R.output(gv4)
+        return gv4
+
+
+# Merged counterpart of Diamond_cyclic_dep: conv2d and relu form one "compiler_A" function that
+# returns both outputs, while gelu ("compiler_B") and add ("compiler_A") stay in separate functions.
+@tvm.script.ir_module
+class Diamond_cyclic_dep_merged:
+    @R.function
+    def main(
+        data2: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight2: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        with R.dataflow():
+            lv4: R.Tuple(
+                R.Tensor((1, 64, 54, 54), dtype="float32"),
+                R.Tensor((1, 64, 54, 54), dtype="float32"),
+            ) = fused_relax_nn_conv2d_relax_nn_relu(data2, weight2)
+            lv12: R.Tensor((1, 64, 54, 54), dtype="float32") = lv4[0]
+            lv22: R.Tensor((1, 64, 54, 54), dtype="float32") = lv4[1]
+            lv31: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_gelu1(lv12)
+            gv5: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_add1(lv22, lv31)
+            R.output(gv5)
+        return gv5
+
+    @R.function
+    def fused_relax_nn_conv2d_relax_nn_relu(
+        data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tuple(
+        R.Tensor((1, 64, 54, 54), dtype="float32"), R.Tensor((1, 64, 54, 54), dtype="float32")
+    ):
+        R.func_attr(
+            {
+                "Primitive": 1,
+                "Codegen": "compiler_A",
+                "global_symbol": "fused_relax_nn_conv2d_relax_nn_relu",
+            }
+        )
+        with R.dataflow():
+
+            @R.function
+            def lv(
+                data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+                weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+            ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+                R.func_attr({"Composite": "compiler_A.conv2d", "Primitive": 1})
+                with R.dataflow():
+                    gv4: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(
+                        data1,
+                        weight1,
+                        padding=[0, 0, 0, 0],
+                    )
+                    R.output(gv4)
+                return gv4
+
+            gv: R.Tensor((1, 64, 54, 54), dtype="float32") = lv(data, weight)
+
+            @R.function
+            def lv1(
+                lv11: R.Tensor((1, 64, 54, 54), dtype="float32")
+            ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+                R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1})
+                with R.dataflow():
+                    gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv11)
+                    R.output(gv1)
+                return gv1
+
+            gv11: R.Tensor((1, 64, 54, 54), dtype="float32") = lv1(gv)
+            R.output(gv, gv11)
+        return (gv, gv11)
+
+    @R.function
+    def fused_relax_nn_gelu1(
+        lv2: R.Tensor((1, 64, 54, 54), dtype="float32")
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        R.func_attr(
+            {"Primitive": 1, "Codegen": "compiler_B", "global_symbol": "fused_relax_nn_gelu1"}
+        )
+        with R.dataflow():
+
+            @R.function
+            def lv21(
+                lv3: R.Tensor((1, 64, 54, 54), dtype="float32")
+            ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+                R.func_attr({"Composite": "compiler_B.gelu", "Primitive": 1})
+                with R.dataflow():
+                    gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv3)
+                    R.output(gv2)
+                return gv2
+
+            gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv21(lv2)
+            R.output(gv3)
+        return gv3
+
+    @R.function
+    def fused_relax_add1(
+        lv32: R.Tensor((1, 64, 54, 54), dtype="float32"),
+        lv41: R.Tensor((1, 64, 54, 54), dtype="float32"),
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        R.func_attr({"Primitive": 1, "Codegen": "compiler_A", "global_symbol": "fused_relax_add1"})
+        with R.dataflow():
+
+            @R.function
+            def lv33(
+                lv5: R.Tensor((1, 64, 54, 54), dtype="float32"),
+                gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"),
+            ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+                R.func_attr({"Composite": "compiler_A.add", "Primitive": 1})
+                with R.dataflow():
+                    gv31: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(lv5, gelu1)
+                    R.output(gv31)
+                return gv31
+
+            gv6: R.Tensor((1, 64, 54, 54), dtype="float32") = lv33(lv32, lv41)
+            R.output(gv6)
+        return gv6
+
+
[email protected]_module
+class MultipleProducers:
+ @R.function
+ def main(
+ x1: R.Tensor((10,), dtype="float32"), x2: R.Tensor((10,),
dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ with R.dataflow():
+ lv1: R.Tensor((10,), dtype="float32") = fused_relax_nn_relu(x1)
+ lv2: R.Tensor((10,), dtype="float32") = fused_relax_nn_gelu(x2)
+ lv3: R.Tensor((10,), dtype="float32") = fused_relax_nn_relu(lv1)
+ lv4: R.Tensor((10,), dtype="float32") = fused_relax_nn_gelu(lv2)
+ gv1: R.Tensor((10,), dtype="float32") = fused_relax_add(lv3, lv4)
+ R.output(gv1)
+ return gv1
+
+ @R.function
+ def fused_relax_nn_relu(
+ x11: R.Tensor((10,), dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"})
+ with R.dataflow():
+ gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x11)
+ R.output(gv2)
+ return gv2
+
+ @R.function
+ def fused_relax_nn_gelu(
+ x21: R.Tensor((10,), dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ R.func_attr({"Primitive": 1, "Composite": "compiler_A.gelu"})
+ with R.dataflow():
+ gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21)
+ R.output(gv3)
+ return gv3
+
+ @R.function
+ def fused_relax_add(
+ lv: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,),
dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"})
+ with R.dataflow():
+ gv: R.Tensor((10,), dtype="float32") = R.add(lv, gelu1)
+ R.output(gv)
+ return gv
+
+
[email protected]_module
+class MultipleProducers_merged:
+ @R.function
+ def
fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add(
+ x1: R.Tensor((10,), dtype="float32"), x2: R.Tensor((10,),
dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ # function attr dict
+ R.func_attr(
+ {
+ "Codegen": "compiler_A",
+ "Primitive": 1,
+ "global_symbol":
"fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add",
+ }
+ )
+ # block 0
+ with R.dataflow():
+
+ @R.function
+ def lv(x11: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
+ # function attr dict
+ R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1})
+ # block 0
+ with R.dataflow():
+ gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x11)
+ R.output(gv2)
+ return gv2
+
+ lv1: R.Tensor((10,), dtype="float32") = lv(x1)
+
+ @R.function
+ def lv11(x21: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
+ # function attr dict
+ R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1})
+ # block 0
+ with R.dataflow():
+ gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21)
+ R.output(gv3)
+ return gv3
+
+ lv2: R.Tensor((10,), dtype="float32") = lv11(x2)
+ lv3: R.Tensor((10,), dtype="float32") = lv(lv1)
+ lv4: R.Tensor((10,), dtype="float32") = lv11(lv2)
+
+ @R.function
+ def lv21(
+ lv5: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,),
dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ # function attr dict
+ R.func_attr({"Composite": "compiler_A.add", "Primitive": 1})
+ # block 0
+ with R.dataflow():
+ gv: R.Tensor((10,), dtype="float32") = R.add(lv5, gelu1)
+ R.output(gv)
+ return gv
+
+ gv1: R.Tensor((10,), dtype="float32") = lv21(lv3, lv4)
+ R.output(gv1)
+ return gv1
+
+ @R.function
+ def main(
+ x12: R.Tensor((10,), dtype="float32"), x22: R.Tensor((10,),
dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ gv4: R.Tensor(
+ (10,), dtype="float32"
+ ) =
fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add(x12,
x22)
+ R.output(gv4)
+ return gv4
+
+
[email protected]_module
+class MultipleProducersCyclic:
+ @R.function
+ def main(x1: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
+ with R.dataflow():
+ lv1: R.Tensor((10,), dtype="float32") = fused_relax_nn_relu(x1)
+ lv2: R.Tensor((10,), dtype="float32") = R.nn.relu(lv1)
+ lv3: R.Tensor((10,), dtype="float32") = fused_relax_nn_gelu(lv2)
+ gv1: R.Tensor((10,), dtype="float32") = fused_relax_add(lv1, lv3)
+ R.output(gv1)
+ return gv1
+
+ @R.function
+ def fused_relax_nn_relu(
+ x11: R.Tensor((10,), dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"})
+ with R.dataflow():
+ gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x11)
+ R.output(gv2)
+ return gv2
+
+ @R.function
+ def fused_relax_nn_gelu(
+ x21: R.Tensor((10,), dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ R.func_attr({"Primitive": 1, "Composite": "compiler_A.gelu"})
+ with R.dataflow():
+ gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21)
+ R.output(gv3)
+ return gv3
+
+ @R.function
+ def fused_relax_add(
+ lv: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,),
dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"})
+ with R.dataflow():
+ gv: R.Tensor((10,), dtype="float32") = R.add(lv, gelu1)
+ R.output(gv)
+ return gv
+
+
[email protected]_module
+class MultipleProducersCyclic_merged:
+ @R.function
+ def main(x1: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((10,), dtype="float32") = fused_relax_nn_relu1(x1)
+ lv2: R.Tensor((10,), dtype="float32") = R.nn.relu(lv)
+ gv: R.Tensor((10,), dtype="float32") =
fused_relax_nn_gelu_relax_add(lv2, lv)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def fused_relax_nn_relu1(
+ x11: R.Tensor((10,), dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ # function attr dict
+ R.func_attr(
+ {"Codegen": "compiler_A", "Primitive": 1, "global_symbol":
"fused_relax_nn_relu1"}
+ )
+ # block 0
+ with R.dataflow():
+
+ @R.function
+ def lv1(x111: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
+ # function attr dict
+ R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1})
+ # block 0
+ with R.dataflow():
+ gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x111)
+ R.output(gv2)
+ return gv2
+
+ gv1: R.Tensor((10,), dtype="float32") = lv1(x11)
+ R.output(gv1)
+ return gv1
+
+ @R.function
+ def fused_relax_nn_gelu_relax_add(
+ lv21: R.Tensor((10,), dtype="float32"), lv11: R.Tensor((10,),
dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ # function attr dict
+ R.func_attr(
+ {
+ "Codegen": "compiler_A",
+ "Primitive": 1,
+ "global_symbol": "fused_relax_nn_gelu_relax_add",
+ }
+ )
+ # block 0
+ with R.dataflow():
+
+ @R.function
+ def lv12(x21: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
+ # function attr dict
+ R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1})
+ # block 0
+ with R.dataflow():
+ gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21)
+ R.output(gv3)
+ return gv3
+
+ lv3: R.Tensor((10,), dtype="float32") = lv12(lv21)
+
+ @R.function
+ def lv22(
+ lv4: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,),
dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ # function attr dict
+ R.func_attr({"Composite": "compiler_A.add", "Primitive": 1})
+ # block 0
+ with R.dataflow():
+ gv4: R.Tensor((10,), dtype="float32") = R.add(lv4, gelu1)
+ R.output(gv4)
+ return gv4
+
+ gv5: R.Tensor((10,), dtype="float32") = lv22(lv11, lv3)
+ R.output(gv5)
+ return gv5
+
+
[email protected]_module
+class MergeCompilerRegionsExample:
+ @R.function
+ def main(
+ x1: R.Tensor((10,), dtype="float32"),
+ x2: R.Tensor((10,), dtype="float32"),
+ x3: R.Tensor((10,), dtype="float32"),
+ ) -> R.Tensor((10,), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((10,), dtype="float32") = fused_relax_add(x1, x2)
+ lv1: R.Tensor((10,), dtype="float32") = fused_relax_nn_gelu(x3)
+ lv11: R.Tensor((10,), dtype="float32") = fused_relax_add(lv, lv1)
+ lv12: R.Tensor((10,), dtype="float32") = fused_relax_nn_gelu(lv11)
+ lv2: R.Tensor((10,), dtype="float32") = fused_relax_nn_relu(lv11)
+ lv21: R.Tensor((10,), dtype="float32") = fused_relax_add(lv12, lv2)
+ gv1: R.Tensor((10,), dtype="float32") = fused_relax_nn_relu(lv21)
+ R.output(gv1)
+ return gv1
+
+ @R.function
+ def fused_relax_nn_relu(
+ add2: R.Tensor((10,), dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"})
+ with R.dataflow():
+ gv: R.Tensor((10,), dtype="float32") = R.nn.relu(add2)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def fused_relax_add(
+ x11: R.Tensor((10,), dtype="float32"), x21: R.Tensor((10,),
dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"})
+ with R.dataflow():
+ gv2: R.Tensor((10,), dtype="float32") = R.add(x11, x21)
+ R.output(gv2)
+ return gv2
+
+ @R.function
+ def fused_relax_nn_gelu(
+ x31: R.Tensor((10,), dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ R.func_attr({"Primitive": 1, "Composite": "compiler_B.gelu"})
+ with R.dataflow():
+ gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x31)
+ R.output(gv3)
+ return gv3
+
+
[email protected]_module
+class MergeCompilerRegionsExampleRef:
+ @R.function
+ def fused_relax_add_relax_add_relax_nn_relu(
+ x1: R.Tensor((10,), dtype="float32"),
+ x2: R.Tensor((10,), dtype="float32"),
+ lv: R.Tensor((10,), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((10,),
dtype="float32")):
+ R.func_attr(
+ {
+ "Primitive": 1,
+ "Codegen": "compiler_A",
+ "global_symbol": "fused_relax_add_relax_add_relax_nn_relu",
+ }
+ )
+ with R.dataflow():
+
+ @R.function
+ def lv1(
+ x11: R.Tensor((10,), dtype="float32"), x21: R.Tensor((10,),
dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"})
+ with R.dataflow():
+ gv: R.Tensor((10,), dtype="float32") = R.add(x11, x21)
+ R.output(gv)
+ return gv
+
+ lv2: R.Tensor((10,), dtype="float32") = lv1(x1, x2)
+ gv1: R.Tensor((10,), dtype="float32") = lv1(lv2, lv)
+
+ @R.function
+ def lv11(add2: R.Tensor((10,), dtype="float32")) ->
R.Tensor((10,), dtype="float32"):
+ R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"})
+ with R.dataflow():
+ gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(add2)
+ R.output(gv2)
+ return gv2
+
+ gv11: R.Tensor((10,), dtype="float32") = lv11(gv1)
+ R.output(gv1, gv11)
+ return (gv1, gv11)
+
+ @R.function
+ def fused_relax_add_relax_nn_relu(
+ lv12: R.Tensor((10,), dtype="float32"), lv3: R.Tensor((10,),
dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ R.func_attr(
+ {
+ "Primitive": 1,
+ "Codegen": "compiler_A",
+ "global_symbol": "fused_relax_add_relax_nn_relu",
+ }
+ )
+ with R.dataflow():
+
+ @R.function
+ def lv21(
+ x11: R.Tensor((10,), dtype="float32"), x21: R.Tensor((10,),
dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"})
+ with R.dataflow():
+ gv: R.Tensor((10,), dtype="float32") = R.add(x11, x21)
+ R.output(gv)
+ return gv
+
+ lv22: R.Tensor((10,), dtype="float32") = lv21(lv12, lv3)
+
+ @R.function
+ def lv31(add2: R.Tensor((10,), dtype="float32")) ->
R.Tensor((10,), dtype="float32"):
+ R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"})
+ with R.dataflow():
+ gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(add2)
+ R.output(gv2)
+ return gv2
+
+ gv3: R.Tensor((10,), dtype="float32") = lv31(lv22)
+ R.output(gv3)
+ return gv3
+
+ @R.function
+ def fused_relax_nn_gelu1(
+ x3: R.Tensor((10,), dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ R.func_attr(
+ {"Primitive": 1, "Codegen": "compiler_B", "global_symbol":
"fused_relax_nn_gelu1"}
+ )
+ with R.dataflow():
+
+ @R.function
+ def lv4(x31: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
+ R.func_attr({"Primitive": 1, "Composite": "compiler_B.gelu"})
+ with R.dataflow():
+ gv4: R.Tensor((10,), dtype="float32") = R.nn.gelu(x31)
+ R.output(gv4)
+ return gv4
+
+ gv5: R.Tensor((10,), dtype="float32") = lv4(x3)
+ R.output(gv5)
+ return gv5
+
+ @R.function
+ def main(
+ x12: R.Tensor((10,), dtype="float32"),
+ x22: R.Tensor((10,), dtype="float32"),
+ x32: R.Tensor((10,), dtype="float32"),
+ ) -> R.Tensor((10,), dtype="float32"):
+ with R.dataflow():
+ lv5: R.Tensor((10,), dtype="float32") = fused_relax_nn_gelu1(x32)
+ lv13: R.Tuple(
+ R.Tensor((10,), dtype="float32"), R.Tensor((10,),
dtype="float32")
+ ) = fused_relax_add_relax_add_relax_nn_relu(x12, x22, lv5)
+ lv23: R.Tensor((10,), dtype="float32") = lv13[0]
+ lv32: R.Tensor((10,), dtype="float32") = lv13[1]
+ lv41: R.Tensor((10,), dtype="float32") = fused_relax_nn_gelu1(lv23)
+ gv6: R.Tensor((10,), dtype="float32") =
fused_relax_add_relax_nn_relu(lv41, lv32)
+ R.output(gv6)
+ return gv6
+
+
[email protected]_module
+class ModuleWithNonComposite:
+ @R.function
+ def main(
+ data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((1, 64, 56, 56), dtype="float32") =
fused_relax_nn_conv2d(data, weight)
+ conv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)
+ R.output(conv)
+ return conv
+
+ @R.function
+ def fused_relax_nn_conv2d(
+ data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+ R.func_attr({"Composite": "tensorrt.conv2d", "Primitive": 1})
+ with R.dataflow():
+ gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(
+ data1,
+ weight1,
+ padding=[1, 1, 1, 1],
+ )
+ R.output(gv)
+ return gv
+
+
[email protected]_module
+class ModuleWithNonComposite_ref:
+ @R.function
+ def main(
+ data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((1, 64, 56, 56), dtype="float32") =
fused_relax_nn_conv2d1(data, weight)
+ conv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)
+ R.output(conv)
+ return conv
+
+ @R.function
+ def fused_relax_nn_conv2d1(
+ data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+ R.func_attr(
+ {"Codegen": "tensorrt", "Primitive": 1, "global_symbol":
"fused_relax_nn_conv2d1"}
+ )
+ with R.dataflow():
+
+ @R.function
+ def lv1(
+ data2: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight2: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+ R.func_attr({"Composite": "tensorrt.conv2d", "Primitive": 1})
+ with R.dataflow():
+ gv: R.Tensor((1, 64, 56, 56), dtype="float32") =
R.nn.conv2d(
+ data2,
+ weight2,
+ padding=[1, 1, 1, 1],
+ )
+ R.output(gv)
+ return gv
+
+ gv1: R.Tensor((1, 64, 56, 56), dtype="float32") = lv1(data1,
weight1)
+ R.output(gv1)
+ return gv1
+
+
+def check(mod, expected):
+ partitioned = relax.transform.MergeCompositeFunctions()(mod)
+ tvm.ir.assert_structural_equal(partitioned, expected)
+
+
+def test_conv2d_relu_x2():
+ check(Conv2dReLUx2, Conv2dReLUx2_merged)
+
+
+def test_diamond_cyclic_dep():
+ """
+ O = Offloaded to A
+ X = Offloaded to B
+
+ O O
+ / \\ / \\
+ O X --> O + + X
+ \\ / \\ /
+ O O
+
+ We cannot merge all 'O' since it would create a cyclic dependency between
the group of `X`.
+ """
+ check(Diamond_cyclic_dep, Diamond_cyclic_dep_merged)
+
+
+def test_diamond():
+ """
+ O = Offloaded to A
+
+ O O
+ / \\ / \\
+ O O --> O O
+ \\ / \\ /
+ O O
+
+ """
+ check(Diamond, Diamond_merged)
+
+
+def test_merge_producers():
+ """
+ Test merging multiple producer groups into a single representative group.
+ O O
+ | |
+ O O
+ \\ /
+ O
+ """
+ check(MultipleProducers, MultipleProducers_merged)
+
+
+def test_merge_producers_cyclic_dep():
+ """
+ Test when multiple producer groups being blocked to merge due to circular
dependency
+ in the result.
+ O
+ |\\
+ | X
+ | |
+ | O
+ |/
+ O
+ """
+ check(MultipleProducersCyclic, MultipleProducersCyclic_merged)
+
+
+def test_merge_compiler_regions_example():
+ """
+ A tricky example from
https://discuss.tvm.apache.org/t/relay-improved-graph-partitioning-algorithm/5830
+ See also the corresponding test case for Relay MergeCompilerRegions in
relay/test_pass_merge_compiler_regions.py.
+ """
+ check(
+ MergeCompilerRegionsExample,
+ MergeCompilerRegionsExampleRef,
+ )
+
+
+def test_mixed_non_composite():
+ check(ModuleWithNonComposite, ModuleWithNonComposite_ref)
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py
new file mode 100644
index 0000000000..fbeb57564f
--- /dev/null
+++ b/tests/python/relax/test_utils.py
@@ -0,0 +1,107 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+import tvm
+from tvm import relax
+from tvm.ir.base import assert_structural_equal
+from tvm.script.parser import relax as R
+
+
+def test_copy_with_new_vars():
+ @R.function
+ def before(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")):
+ gv = R.add(x, y)
+ return gv
+
+ after = relax.utils.copy_with_new_vars(before)
+ assert_structural_equal(after, before)
+
+ assert len(after.params) == len(before.params)
+ for before_var, after_var in zip(before.params, after.params):
+ assert before_var != after_var
+
+
+def test_copy_with_new_vars_on_ir_module():
+ @tvm.script.ir_module
+ class Actual:
+ @R.function
+ def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")):
+ gv = R.add(x, y)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")):
+ gv = R.add(x, y)
+ return gv
+
+ @R.function
+ def func_copied(x: R.Tensor((3,), "float32"), y: R.Tensor((3,),
"float32")):
+ gv = R.add(x, y)
+ return gv
+
+ Actual["func_copied"] = relax.utils.copy_with_new_vars(Actual["func"])
+
+ # Assertion will fail if the f_copied contains the same VarNode that's
used in
+ # the original function, due to var mapping during structural equal.
+ assert_structural_equal(Actual, Expected)
+
+
+def test_copy_with_new_vars_on_ir_module_nested_function():
+ @tvm.script.ir_module
+ class Actual:
+ @R.function
+ def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")):
+ @R.function
+ def inner(x: R.Tensor((3,), "float32")):
+ gv = R.add(x, x)
+ return gv
+
+ gv = R.add(x, y)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")):
+ @R.function
+ def inner(x: R.Tensor((3,), "float32")):
+ gv = R.add(x, x)
+ return gv
+
+ gv = R.add(x, y)
+ return gv
+
+ @R.function
+ def func_copied(x: R.Tensor((3,), "float32"), y: R.Tensor((3,),
"float32")):
+ @R.function
+ def inner(x: R.Tensor((3,), "float32")):
+ gv = R.add(x, x)
+ return gv
+
+ gv = R.add(x, y)
+ return gv
+
+ Actual["func_copied"] = relax.utils.copy_with_new_vars(Actual["func"])
+
+ assert_structural_equal(Actual, Expected)
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])