Lunderberg commented on code in PR #16717:
URL: https://github.com/apache/tvm/pull/16717#discussion_r1527169353
##########
python/tvm/relax/transform/transform.py:
##########
@@ -844,12 +844,28 @@ def LiftTransformParams() -> tvm.ir.transform.Pass:
Users are expected to invoke the `transform_params` function in runtime
and pass the transformed
parameters to the original function as input.
+ Parameters
+ ----------
+ shared_transform : Union[bool, List[str]]
+ Boolean to indicate whether to share the transformation of the
parameters among functions or
Review Comment:
When there are multiple options for how a parameter will be interpreted, it
can be helpful for readers to list the behavior explicitly. For example, I'd
recommend changing the description of this parameter to explicitly state the
behavior for true/false/list of names, like the example below.
```
shared_transform: Union[bool, List[str]]

    Indicates how the parameter transformation function will be produced.

    - `False` (default): A separate parameter transformation function will be
      produced for each function with the `"num_input"` attribute.

    - `True`: A single parameter transformation function will be produced,
      containing the preprocessing steps common across all functions with
      the `"num_input"` attribute.

    - List[str]: A single parameter transformation function will be produced,
      containing the preprocessing steps common across each function whose
      name is in the list. Passing a list of all functions with the
      `"num_input"` attribute is equivalent to passing `True`.
```
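To make the three cases concrete, here is a minimal plain-Python sketch of how the argument could be interpreted. It does not use TVM: `select_targets` and `num_input_funcs` are hypothetical stand-ins introduced only for illustration, and the sorting in the bool case mirrors the sort by `name_hint` in `GetTargetFunctions`.

```python
from typing import List, Tuple, Union


def select_targets(
    shared_transform: Union[bool, List[str]],
    num_input_funcs: List[str],
) -> Tuple[bool, List[str]]:
    """Hypothetical helper: return (is_shared, target function names).

    `num_input_funcs` stands in for the names of all functions in the
    module that carry the "num_input" attribute.
    """
    if isinstance(shared_transform, bool):
        # True/False: every function with "num_input" is a target, and
        # the bool decides whether one shared transform is produced.
        return shared_transform, sorted(num_input_funcs)
    # A list of names: only the named functions are targets, and a
    # single shared transform is produced for them.
    return True, list(shared_transform)


# Assumed example module with two functions carrying "num_input".
all_funcs = ["prefill", "decode"]
separate = select_targets(False, all_funcs)
shared_all = select_targets(True, all_funcs)
shared_some = select_targets(["decode"], all_funcs)
```

Under these assumptions, passing the full list of names selects the same set of functions as passing `True`, which is the equivalence the docstring should call out.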
##########
src/relax/transform/lift_transform_params.cc:
##########
@@ -494,22 +700,71 @@ class ConsumeBundledParams : public ExprMutator {
std::unordered_map<int, Expr> param_remap_;
};
+std::vector<std::pair<GlobalVar, Function>> GetTargetFunctions(
+ const IRModule& mod, const Array<String>& target_function_names) {
+ std::vector<std::pair<GlobalVar, Function>> target_functions;
+ if (target_function_names.size()) {
+ for (const auto& name : target_function_names) {
+ auto gvar = mod->GetGlobalVar(name);
+ target_functions.push_back({gvar,
Downcast<Function>(mod->Lookup(gvar))});
+ }
+ } else {
+ // Get all the functions that have the `num_input` attribute.
+ for (const auto& [gvar, func] : mod->functions) {
+ if (func->IsInstance<FunctionNode>()) {
+ auto opt_num_input = func->GetAttr<Integer>(attr::kNumInput);
+ if (opt_num_input) {
+ target_functions.emplace_back(gvar, Downcast<Function>(func));
+ }
+ }
+ }
+ std::sort(target_functions.begin(), target_functions.end(),
+ [](const auto& lhs, const auto& rhs) {
+ return lhs.first->name_hint < rhs.first->name_hint;
+ });
+ }
+ return target_functions;
+}
+
} // namespace
namespace transform {
-Pass PartitionTransformParams() {
+Pass PartitionTransformParams(bool shared_transform, const Array<String>&
target_function_names) {
auto pass_func = [=](IRModule mod, PassContext pc) {
- PreprocessPartitioner mutator;
-
IRModule updates;
- for (const auto& [gvar, func] : mod->functions) {
- if (auto opt = func.as<relax::Function>()) {
- auto new_func = Downcast<Function>(mutator(opt.value()));
- if (!new_func.same_as(func)) {
- updates->Add(gvar, new_func);
- }
+
Review Comment:
If `shared_transform` accepts a `Variant<Bool, Array<String>>`, we would then
check whether the shared transform needs to be generated as follows. A list of
function names implies a shared transform, so the fallback value is `true`.
```c++
bool generate_shared_transform =
    shared_transform.as<Bool>().value_or(Bool(true))->value;
```
##########
python/tvm/relax/transform/transform.py:
##########
@@ -844,12 +844,28 @@ def LiftTransformParams() -> tvm.ir.transform.Pass:
Users are expected to invoke the `transform_params` function in runtime
and pass the transformed
parameters to the original function as input.
+ Parameters
+ ----------
+ shared_transform : Union[bool, List[str]]
+ Boolean to indicate whether to share the transformation of the
parameters among functions or
+ a list of function names to apply the shared transformation.
+
+ When the shared transformation is enabled, all the target functions
should have the same
+ parameters and the common part of the transformations will be lifted
to a global
+ `transform_params` function that is shared among all functions.
+ Otherwise, each function will have its own `transform_params` function.
+
Returns
-------
ret : tvm.transform.Pass
The registered pass for lifting transformation of parameters.
"""
- return _ffi_api.LiftTransformParams() # type: ignore
+ if isinstance(shared_transform, bool):
Review Comment:
With the `LiftTransformParams` argument in C++ changed to `Variant<Bool,
Array<String>>`, this conditional check can be removed. The conversion from a
Python `bool` or a Python `List[str]` to a `Variant<Bool, Array<String>>` is
handled internally by the FFI.
##########
include/tvm/relax/transform.h:
##########
@@ -265,9 +265,15 @@ TVM_DLL Pass RealizeVDevice();
* Users are expected to invoke the `transform_params` function in runtime and
pass the transformed
* parameters to the original function as input.
*
+ * \param shared_transform Whether to share the transformation of the
parameters among all functions
+ * If true, all functions should have the same parameters and the common part
of the transformations
+ * will be lifted to a global `transform_params` function that is shared among
all functions. If
+ * false, each function will have its own `transform_params` function.
+ * \param target_functions The list of functions to apply the shared
transformation. If empty, the
Review Comment:
Instead of two independent parameters, can this be a single parameter of
type `Variant<Bool, Array<String>>`? As it is, there's coupling between the
two parameters, as the second parameter is ignored whenever the first parameter
is `false`. Using `Variant<Bool, Array<String>>` would avoid having coupled
parameters, and would provide the same interface in both C++ and Python.
##########
src/relax/transform/lift_transform_params.cc:
##########
@@ -494,22 +700,71 @@ class ConsumeBundledParams : public ExprMutator {
std::unordered_map<int, Expr> param_remap_;
};
+std::vector<std::pair<GlobalVar, Function>> GetTargetFunctions(
+ const IRModule& mod, const Array<String>& target_function_names) {
+ std::vector<std::pair<GlobalVar, Function>> target_functions;
+ if (target_function_names.size()) {
Review Comment:
If `shared_transform` accepts a `Variant<Bool, Array<String>>`, that argument
could then be forwarded directly to the `GetTargetFunctions` utility.
```c++
if (auto function_names = shared_transform.as<Array<String>>()) {
  for (const auto& name : function_names.value()) {
    ...
  }
} else {
  ...
}
```
##########
src/relax/transform/lift_transform_params.cc:
##########
@@ -494,22 +700,71 @@ class ConsumeBundledParams : public ExprMutator {
std::unordered_map<int, Expr> param_remap_;
};
+std::vector<std::pair<GlobalVar, Function>> GetTargetFunctions(
+ const IRModule& mod, const Array<String>& target_function_names) {
+ std::vector<std::pair<GlobalVar, Function>> target_functions;
+ if (target_function_names.size()) {
+ for (const auto& name : target_function_names) {
+ auto gvar = mod->GetGlobalVar(name);
+ target_functions.push_back({gvar,
Downcast<Function>(mod->Lookup(gvar))});
+ }
+ } else {
+ // Get all the functions that have the `num_input` attribute.
+ for (const auto& [gvar, func] : mod->functions) {
+ if (func->IsInstance<FunctionNode>()) {
+ auto opt_num_input = func->GetAttr<Integer>(attr::kNumInput);
+ if (opt_num_input) {
+ target_functions.emplace_back(gvar, Downcast<Function>(func));
+ }
+ }
+ }
+ std::sort(target_functions.begin(), target_functions.end(),
+ [](const auto& lhs, const auto& rhs) {
+ return lhs.first->name_hint < rhs.first->name_hint;
+ });
+ }
+ return target_functions;
+}
+
} // namespace
namespace transform {
-Pass PartitionTransformParams() {
+Pass PartitionTransformParams(bool shared_transform, const Array<String>&
target_function_names) {
auto pass_func = [=](IRModule mod, PassContext pc) {
- PreprocessPartitioner mutator;
-
IRModule updates;
- for (const auto& [gvar, func] : mod->functions) {
- if (auto opt = func.as<relax::Function>()) {
- auto new_func = Downcast<Function>(mutator(opt.value()));
- if (!new_func.same_as(func)) {
- updates->Add(gvar, new_func);
- }
+
+ std::optional<GlobalCollectInfo> global_collect_info;
+ auto target_functions = GetTargetFunctions(mod, target_function_names);
+
+ if (shared_transform) {
+ std::vector<Function> functions;
+ for (const auto& [_, func] : target_functions) {
+ functions.push_back(func);
}
+ global_collect_info = MakeGlobalLiftPlan(mod, functions);
+ }
+
+ std::unordered_map<GlobalVar, LocalCollectInfo, ObjectPtrHash,
ObjectPtrEqual>
+ local_collect_info;
+ for (const auto& [gvar, func] : target_functions) {
+ auto info = LocalLiftableBindingCollector::Collect(
+ func, global_collect_info.has_value() ? &global_collect_info.value()
: nullptr);
+ local_collect_info[gvar] = info;
+ }
+
+ for (const auto& [gvar, info] : local_collect_info) {
+ auto new_runtime_func = info.MakeRuntimeFunction();
+ updates->Add(gvar, new_runtime_func);
+ if (!global_collect_info.has_value()) {
Review Comment:
Can this conditional be moved to an else block of `if
(global_collect_info.has_value())`? That would make it clearer to a reader
that the runtime function is always generated, and that there's a choice
between a single global transformation and several individual transformations.
```c++
if (global_collect_info.has_value()) {
  auto global_transform = global_collect_info.value().MakeCompileTimeFunc();
  updates->Add(GlobalVar("transform_params"), global_transform);
} else {
  for (const auto& [gvar, info] : local_collect_info) {
    updates->Add(GlobalVar(gvar->name_hint + "_transform_params"),
                 info.MakeCompileTimeFunction());
  }
}
```
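The effect of this restructuring on the module's contents can be sketched in plain Python. This is only an illustration of the naming scheme in the snippet above. `compile_time_function_names` is a hypothetical helper, not a TVM function, and the function names in the usage lines are assumed examples.

```python
from typing import List


def compile_time_function_names(target_names: List[str], shared: bool) -> List[str]:
    """Hypothetical helper listing the names of the generated transforms."""
    if shared:
        # One global transform shared by every target function.
        return ["transform_params"]
    # Otherwise, one transform per target function, named after it.
    return [f"{name}_transform_params" for name in target_names]


# Assumed example targets.
shared_names = compile_time_function_names(["decode", "prefill"], shared=True)
separate_names = compile_time_function_names(["decode", "prefill"], shared=False)
```

In both branches the runtime functions are still updated for every target. Only the choice between one global transform and several per-function transforms differs.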
##########
include/tvm/relax/transform.h:
##########
@@ -265,9 +265,15 @@ TVM_DLL Pass RealizeVDevice();
* Users are expected to invoke the `transform_params` function in runtime and
pass the transformed
* parameters to the original function as input.
*
+ * \param shared_transform Whether to share the transformation of the
parameters among all functions
+ * If true, all functions should have the same parameters and the common part
of the transformations
+ * will be lifted to a global `transform_params` function that is shared among
all functions. If
+ * false, each function will have its own `transform_params` function.
+ * \param target_functions The list of functions to apply the shared
transformation. If empty, the
+ * shared transformation will be applied to all functions.
* \return The Pass.
*/
-TVM_DLL Pass LiftTransformParams();
+TVM_DLL Pass LiftTransformParams(bool shared_transform, Array<String>
target_functions);
Review Comment:
To avoid breaking backwards compatibility, new parameters should have
default values, with the default value chosen to reproduce the previous
behavior.
```c++
TVM_DLL Pass LiftTransformParams(
    Variant<Bool, Array<String>> shared_transform = Bool(false));
```
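The same idea applies on the Python side. A minimal sketch, assuming a default of `False`: `lift_transform_params` below is a hypothetical stand-in that returns a description string rather than a real pass, so it only demonstrates that call sites passing no argument keep the per-function behavior.

```python
from typing import List, Union


def lift_transform_params(shared_transform: Union[bool, List[str]] = False) -> str:
    """Hypothetical stand-in for the pass constructor; returns a description."""
    if isinstance(shared_transform, bool) and not shared_transform:
        # Previous behavior: each function gets its own transform.
        return "per-function transform_params"
    # True or a list of names: a single shared transform is produced.
    return "shared transform_params"


# Existing call sites that pass no argument behave as if False was passed.
assert lift_transform_params() == lift_transform_params(False)
```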