This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 59b5f6d9ae [Unity][BYOC] Make `MergeCompositeFunctions` append codegen
name at the end of function name (#15534)
59b5f6d9ae is described below
commit 59b5f6d9ae0d805ec3afac7aae49645def51056b
Author: Sunghyun Park <[email protected]>
AuthorDate: Sat Aug 12 16:26:46 2023 -0700
[Unity][BYOC] Make `MergeCompositeFunctions` append codegen name at the end
of function name (#15534)
* append codegen name at the end of the funcname
* lint
---
src/relax/transform/merge_composite_functions.cc | 81 ++++++++++---
.../relax/test_transform_fuse_ops_by_pattern.py | 4 +-
.../test_transform_merge_composite_functions.py | 134 ++++++++-------------
3 files changed, 118 insertions(+), 101 deletions(-)
diff --git a/src/relax/transform/merge_composite_functions.cc
b/src/relax/transform/merge_composite_functions.cc
index 365b34c601..9d9d9aa644 100644
--- a/src/relax/transform/merge_composite_functions.cc
+++ b/src/relax/transform/merge_composite_functions.cc
@@ -319,6 +319,66 @@ class CompositeInliner : public ExprMutator {
Map<Function, Function> inlined_functions_;
};
+/*!
+ * \brief Wrap each created composite function with another function, whose
body consists
+ * only of a call to the composite function, and annotate the outer function
with kCodegen
+ * and kGlobalSymbol attributes.
+ */
+class CompositeFunctionAnnotator : public ExprMutator {
+ public:
+ explicit CompositeFunctionAnnotator(IRModule mod, IRModule new_mod)
+ : ExprMutator(new_mod), mod_(new_mod), inliner(mod) {
+ mod_.CopyOnWrite();
+ }
+ using ExprMutator::VisitExpr_;
+
+ IRModule update() {
+ auto gvar = mod_->GetGlobalVar("main");
+ auto func = Downcast<Function>(mod_->Lookup(gvar));
+ builder_->UpdateFunction(gvar, Downcast<Function>(VisitExpr(func)));
+ return builder_->GetContextIRModule();
+ }
+
+ Expr VisitExpr_(const CallNode* call) {
+ if (call->op->IsInstance<GlobalVarNode>()) {
+ GlobalVar cur_var = Downcast<GlobalVar>(call->op);
+ auto func = Downcast<Function>(mod_->Lookup(cur_var));
+ if (auto codegen_name = func->GetAttr<String>(attr::kCodegen)) {
+ GlobalVar new_var;
+ if (var_map_.count(cur_var) > 0) {
+ // if we visited before, we don't need to create the new function,
+ // use the one we stored.
+ new_var = var_map_[cur_var];
+ } else {
+ // if it is the first time, create the new function with a new name.
+ // remove old function from the irmodule under construction.
+ auto old_var =
builder_->GetContextIRModule()->GetGlobalVar(cur_var->name_hint);
+ builder_->GetContextIRModule()->Remove(old_var);
+
+ // rename the function.
+ String new_func_name = cur_var->name_hint + "_" +
codegen_name.value();
+ Function new_func = inliner.Run(Downcast<Function>(func));
+ new_func = WithAttr(new_func, tvm::attr::kGlobalSymbol,
new_func_name);
+ new_func = WithoutAttr(std::move(new_func),
tvm::relax::attr::kPrimitive);
+ // add a function with a new name.
+ new_var = builder_->AddFunction(new_func, new_func_name);
+ var_map_[cur_var] = new_var;
+ }
+ // we call the new var instead of the old one.
+ // we don't have to update args since we are just updating the
function to call,
+ // without any change in the arguments.
+ return Call(new_var, call->args);
+ }
+ }
+ return GetRef<Call>(call);
+ }
+
+ private:
+ IRModule mod_;
+ CompositeInliner inliner;
+ std::unordered_map<GlobalVar, GlobalVar, ObjectPtrHash, ObjectPtrEqual>
var_map_;
+};
+
} // namespace
IRModule MergeCompositeFunctions(IRModule mod) {
@@ -327,21 +387,8 @@ IRModule MergeCompositeFunctions(IRModule mod) {
support::Arena arena;
auto group_map = CompositeGroupsBuilder(mod, &arena).Run(func);
auto new_mod = MakeGroupedFunctions(mod, group_map);
+ new_mod = CompositeFunctionAnnotator(mod, new_mod).update();
- CompositeInliner inliner(mod);
- std::vector<std::pair<GlobalVar, BaseFunc>> to_update;
- for (const auto& [gvar, func] : new_mod->functions) {
- if (func->GetAttr<String>(attr::kCodegen)) {
- auto new_func = inliner.Run(Downcast<Function>(func));
- new_func = WithAttr(new_func, tvm::attr::kGlobalSymbol, gvar->name_hint);
- new_func = WithoutAttr(std::move(new_func),
tvm::relax::attr::kPrimitive);
- to_update.emplace_back(gvar, new_func);
- }
- }
-
- for (const auto& [gvar, func] : to_update) {
- new_mod->Update(gvar, Downcast<Function>(func));
- }
// TODO(@tvm-team): Implicit pass dependency. Revisit when we have a better
way to handle this.
return DeadCodeElimination(new_mod, {"main"});
}
@@ -351,9 +398,9 @@ namespace transform {
Pass MergeCompositeFunctions() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = //
[=](IRModule mod, PassContext pc) { return
relax::MergeCompositeFunctions(mod); };
- return CreateModulePass(/*pass_function=*/pass_func, //
- /*opt_level=*/0, //
- /*pass_name=*/"FuseOpsByPattern", //
+ return CreateModulePass(/*pass_function=*/pass_func, //
+ /*opt_level=*/0, //
+ /*pass_name=*/"MergeCompositeFunctions", //
/*required=*/{});
}
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index d95ed68c61..bbf2b0eea4 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -597,7 +597,7 @@ def test_compare_with_merge_composite_path():
@I.ir_module
class Expected2:
@R.function
- def fused_relax_multiply1(
+ def fused_relax_multiply1_cutlass(
x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10),
dtype="float32")
) -> R.Tensor((10, 10), dtype="float32"):
R.func_attr({"Codegen": "cutlass"})
@@ -624,7 +624,7 @@ def test_compare_with_merge_composite_path():
) -> R.Tensor((10, 10), dtype="float32"):
cls = Expected2
with R.dataflow():
- gv: R.Tensor((10, 10), dtype="float32") =
cls.fused_relax_multiply1(x, y)
+ gv: R.Tensor((10, 10), dtype="float32") =
cls.fused_relax_multiply1_cutlass(x, y)
R.output(gv)
return gv
diff --git a/tests/python/relax/test_transform_merge_composite_functions.py
b/tests/python/relax/test_transform_merge_composite_functions.py
index d56e1db564..6a36314a74 100644
--- a/tests/python/relax/test_transform_merge_composite_functions.py
+++ b/tests/python/relax/test_transform_merge_composite_functions.py
@@ -86,24 +86,19 @@ class Conv2dReLUx2_merged:
with R.dataflow():
gv: R.Tensor(
(1, 64, 54, 54), dtype="float32"
- ) =
cls.fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1(
+ ) =
cls.fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1_dnnl(
data, weight1, weight2
)
R.output(gv)
return gv
@R.function
- def fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1(
+ def
fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1_dnnl(
data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
weight11: R.Tensor((64, 64, 3, 3), dtype="float32"),
weight21: R.Tensor((64, 64, 3, 3), dtype="float32"),
) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
- R.func_attr(
- {
- "Codegen": "dnnl",
- "global_symbol":
"fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1",
- }
- )
+ R.func_attr({"Codegen": "dnnl"})
@R.function
def lv(
@@ -211,17 +206,12 @@ class Diamond:
@tvm.script.ir_module
class Diamond_merged:
@R.function
- def fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add(
+ def fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add_compiler_A(
data: R.Tensor((1, 64, 56, 56), dtype="float32"),
weight: R.Tensor((64, 64, 3, 3), dtype="float32"),
) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
# function attr dict
- R.func_attr(
- {
- "Codegen": "compiler_A",
- "global_symbol":
"fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add",
- }
- )
+ R.func_attr({"Codegen": "compiler_A"})
# block 0
@R.function
@@ -303,7 +293,9 @@ class Diamond_merged:
with R.dataflow():
gv5: R.Tensor(
(1, 64, 54, 54), dtype="float32"
- ) =
cls.fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add(data2, weight2)
+ ) =
cls.fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add_compiler_A(
+ data2, weight2
+ )
R.output(gv5)
return gv5
@@ -385,27 +377,26 @@ class Diamond_cyclic_dep_merged:
lv4: R.Tuple(
R.Tensor((1, 64, 54, 54), dtype="float32"),
R.Tensor((1, 64, 54, 54), dtype="float32"),
- ) = cls.fused_relax_nn_conv2d_relax_nn_relu(data2, weight2)
+ ) = cls.fused_relax_nn_conv2d_relax_nn_relu_compiler_A(data2,
weight2)
lv12: R.Tensor((1, 64, 54, 54), dtype="float32") = lv4[0]
lv22: R.Tensor((1, 64, 54, 54), dtype="float32") = lv4[1]
- lv31: R.Tensor((1, 64, 54, 54), dtype="float32") =
cls.fused_relax_nn_gelu1(lv12)
- gv5: R.Tensor((1, 64, 54, 54), dtype="float32") =
cls.fused_relax_add1(lv22, lv31)
+ lv31: R.Tensor((1, 64, 54, 54), dtype="float32") =
cls.fused_relax_nn_gelu1_compiler_B(
+ lv12
+ )
+ gv5: R.Tensor((1, 64, 54, 54), dtype="float32") =
cls.fused_relax_add1_compiler_A(
+ lv22, lv31
+ )
R.output(gv5)
return gv5
@R.function
- def fused_relax_nn_conv2d_relax_nn_relu(
+ def fused_relax_nn_conv2d_relax_nn_relu_compiler_A(
data: R.Tensor((1, 64, 56, 56), dtype="float32"),
weight: R.Tensor((64, 64, 3, 3), dtype="float32"),
) -> R.Tuple(
R.Tensor((1, 64, 54, 54), dtype="float32"), R.Tensor((1, 64, 54, 54),
dtype="float32")
):
- R.func_attr(
- {
- "Codegen": "compiler_A",
- "global_symbol": "fused_relax_nn_conv2d_relax_nn_relu",
- }
- )
+ R.func_attr({"Codegen": "compiler_A"})
@R.function
def lv(
@@ -439,10 +430,10 @@ class Diamond_cyclic_dep_merged:
return (gv, gv11)
@R.function
- def fused_relax_nn_gelu1(
+ def fused_relax_nn_gelu1_compiler_B(
lv2: R.Tensor((1, 64, 54, 54), dtype="float32")
) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
- R.func_attr({"Codegen": "compiler_B", "global_symbol":
"fused_relax_nn_gelu1"})
+ R.func_attr({"Codegen": "compiler_B"})
@R.function
def lv21(
@@ -458,11 +449,11 @@ class Diamond_cyclic_dep_merged:
return gv3
@R.function
- def fused_relax_add1(
+ def fused_relax_add1_compiler_A(
lv32: R.Tensor((1, 64, 54, 54), dtype="float32"),
lv41: R.Tensor((1, 64, 54, 54), dtype="float32"),
) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
- R.func_attr({"Codegen": "compiler_A", "global_symbol":
"fused_relax_add1"})
+ R.func_attr({"Codegen": "compiler_A"})
@R.function
def lv33(
@@ -529,16 +520,11 @@ class MultipleProducers:
@tvm.script.ir_module
class MultipleProducers_merged:
@R.function
- def
fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add(
+ def
fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add_compiler_A(
x1: R.Tensor((10,), dtype="float32"), x2: R.Tensor((10,),
dtype="float32")
) -> R.Tensor((10,), dtype="float32"):
# function attr dict
- R.func_attr(
- {
- "Codegen": "compiler_A",
- "global_symbol":
"fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add",
- }
- )
+ R.func_attr({"Codegen": "compiler_A"})
# block 0
@R.function
@@ -590,7 +576,7 @@ class MultipleProducers_merged:
with R.dataflow():
gv4: R.Tensor(
(10,), dtype="float32"
- ) =
cls.fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add(
+ ) =
cls.fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add_compiler_A(
x12, x22
)
R.output(gv4)
@@ -647,18 +633,20 @@ class MultipleProducersCyclic_merged:
def main(x1: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
cls = MultipleProducersCyclic_merged
with R.dataflow():
- lv: R.Tensor((10,), dtype="float32") = cls.fused_relax_nn_relu1(x1)
+ lv: R.Tensor((10,), dtype="float32") =
cls.fused_relax_nn_relu1_compiler_A(x1)
lv2: R.Tensor((10,), dtype="float32") = R.nn.relu(lv)
- gv: R.Tensor((10,), dtype="float32") =
cls.fused_relax_nn_gelu_relax_add(lv2, lv)
+ gv: R.Tensor((10,), dtype="float32") =
cls.fused_relax_nn_gelu_relax_add_compiler_A(
+ lv2, lv
+ )
R.output(gv)
return gv
@R.function
- def fused_relax_nn_relu1(
+ def fused_relax_nn_relu1_compiler_A(
x11: R.Tensor((10,), dtype="float32")
) -> R.Tensor((10,), dtype="float32"):
# function attr dict
- R.func_attr({"Codegen": "compiler_A", "global_symbol":
"fused_relax_nn_relu1"})
+ R.func_attr({"Codegen": "compiler_A"})
# block 0
@R.function
@@ -675,16 +663,11 @@ class MultipleProducersCyclic_merged:
return gv1
@R.function
- def fused_relax_nn_gelu_relax_add(
+ def fused_relax_nn_gelu_relax_add_compiler_A(
lv21: R.Tensor((10,), dtype="float32"), lv11: R.Tensor((10,),
dtype="float32")
) -> R.Tensor((10,), dtype="float32"):
# function attr dict
- R.func_attr(
- {
- "Codegen": "compiler_A",
- "global_symbol": "fused_relax_nn_gelu_relax_add",
- }
- )
+ R.func_attr({"Codegen": "compiler_A"})
# block 0
@R.function
@@ -770,17 +753,12 @@ class MergeCompilerRegionsExample:
@tvm.script.ir_module
class MergeCompilerRegionsExampleRef:
@R.function
- def fused_relax_add_relax_add_relax_nn_relu(
+ def fused_relax_add_relax_add_relax_nn_relu_compiler_A(
x1: R.Tensor((10,), dtype="float32"),
x2: R.Tensor((10,), dtype="float32"),
lv: R.Tensor((10,), dtype="float32"),
) -> R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((10,),
dtype="float32")):
- R.func_attr(
- {
- "Codegen": "compiler_A",
- "global_symbol": "fused_relax_add_relax_add_relax_nn_relu",
- }
- )
+ R.func_attr({"Codegen": "compiler_A"})
@R.function
def lv1(
@@ -807,15 +785,10 @@ class MergeCompilerRegionsExampleRef:
return (gv1, gv11)
@R.function
- def fused_relax_add_relax_nn_relu(
+ def fused_relax_add_relax_nn_relu_compiler_A(
lv12: R.Tensor((10,), dtype="float32"), lv3: R.Tensor((10,),
dtype="float32")
) -> R.Tensor((10,), dtype="float32"):
- R.func_attr(
- {
- "Codegen": "compiler_A",
- "global_symbol": "fused_relax_add_relax_nn_relu",
- }
- )
+ R.func_attr({"Codegen": "compiler_A"})
@R.function
def lv21(
@@ -842,10 +815,10 @@ class MergeCompilerRegionsExampleRef:
return gv3
@R.function
- def fused_relax_nn_gelu1(
+ def fused_relax_nn_gelu1_compiler_B(
x3: R.Tensor((10,), dtype="float32")
) -> R.Tensor((10,), dtype="float32"):
- R.func_attr({"Codegen": "compiler_B", "global_symbol":
"fused_relax_nn_gelu1"})
+ R.func_attr({"Codegen": "compiler_B"})
@R.function
def lv4(x31: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
@@ -866,14 +839,16 @@ class MergeCompilerRegionsExampleRef:
) -> R.Tensor((10,), dtype="float32"):
cls = MergeCompilerRegionsExampleRef
with R.dataflow():
- lv5: R.Tensor((10,), dtype="float32") =
cls.fused_relax_nn_gelu1(x32)
+ lv5: R.Tensor((10,), dtype="float32") =
cls.fused_relax_nn_gelu1_compiler_B(x32)
lv13: R.Tuple(
R.Tensor((10,), dtype="float32"), R.Tensor((10,),
dtype="float32")
- ) = cls.fused_relax_add_relax_add_relax_nn_relu(x12, x22, lv5)
+ ) = cls.fused_relax_add_relax_add_relax_nn_relu_compiler_A(x12,
x22, lv5)
lv23: R.Tensor((10,), dtype="float32") = lv13[0]
lv32: R.Tensor((10,), dtype="float32") = lv13[1]
- lv41: R.Tensor((10,), dtype="float32") =
cls.fused_relax_nn_gelu1(lv23)
- gv6: R.Tensor((10,), dtype="float32") =
cls.fused_relax_add_relax_nn_relu(lv41, lv32)
+ lv41: R.Tensor((10,), dtype="float32") =
cls.fused_relax_nn_gelu1_compiler_B(lv23)
+ gv6: R.Tensor((10,), dtype="float32") =
cls.fused_relax_add_relax_nn_relu_compiler_A(
+ lv41, lv32
+ )
R.output(gv6)
return gv6
@@ -917,7 +892,7 @@ class ModuleWithNonComposite_ref:
) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
cls = ModuleWithNonComposite_ref
with R.dataflow():
- lv: R.Tensor((1, 64, 56, 56), dtype="float32") =
cls.fused_relax_nn_conv2d1(
+ lv: R.Tensor((1, 64, 56, 56), dtype="float32") =
cls.fused_relax_nn_conv2d1_tensorrt(
data, weight
)
conv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)
@@ -925,11 +900,11 @@ class ModuleWithNonComposite_ref:
return conv
@R.function
- def fused_relax_nn_conv2d1(
+ def fused_relax_nn_conv2d1_tensorrt(
data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
- R.func_attr({"Codegen": "tensorrt", "global_symbol":
"fused_relax_nn_conv2d1"})
+ R.func_attr({"Codegen": "tensorrt"})
@R.function
def lv1(
@@ -1076,17 +1051,12 @@ def test_reshape():
@I.ir_module
class Expected:
@R.function
- def fused_relax_reshape_relax_matmul(
+ def fused_relax_reshape_relax_matmul_tensorrt(
inp_0: R.Tensor((1, 1, 28, 28), dtype="float32"),
param_0: R.Shape([1, 784]),
lv1: R.Tensor((784, 512), dtype="float32"),
) -> R.Tensor((1, 512), dtype="float32"):
- R.func_attr(
- {
- "Codegen": "tensorrt",
- "global_symbol": "fused_relax_reshape_relax_matmul",
- }
- )
+ R.func_attr({"Codegen": "tensorrt"})
# from tvm.script import relax as R
@R.function
@@ -1128,9 +1098,9 @@ def test_reshape():
lv1: R.Tensor((784, 512), dtype="float32") = R.permute_dims(
linear_relu_stack_0_weight, axes=None
)
- gv: R.Tensor((1, 512), dtype="float32") =
cls.fused_relax_reshape_relax_matmul(
- inp_0, R.shape([1, 784]), lv1
- )
+ gv: R.Tensor(
+ (1, 512), dtype="float32"
+ ) = cls.fused_relax_reshape_relax_matmul_tensorrt(inp_0,
R.shape([1, 784]), lv1)
R.output(gv)
return gv