Lunderberg commented on code in PR #16717:
URL: https://github.com/apache/tvm/pull/16717#discussion_r1527169353


##########
python/tvm/relax/transform/transform.py:
##########
@@ -844,12 +844,28 @@ def LiftTransformParams() -> tvm.ir.transform.Pass:
     Users are expected to invoke the `transform_params` function in runtime and pass the transformed
     parameters to the original function as input.
 
+    Parameters
+    ----------
+    shared_transform : Union[bool, List[str]]
+        Boolean to indicate whether to share the transformation of the parameters among functions or

Review Comment:
   When a parameter can be interpreted in several ways, it helps readers to list each behavior explicitly. I'd recommend changing the description of this parameter to state the behavior for `True`, `False`, and a list of names, as in the example below.
   
   ```
   shared_transform : Union[bool, List[str]]
       Indicates how the parameter transformation function will be produced.
   
       - `False` (default): A separate parameter transformation function will be
         produced for each function with the `"num_input"` attribute.
   
       - `True`: A single parameter transformation function will be produced,
         containing the preprocessing steps common across all functions with
         the `"num_input"` attribute.
   
       - List[str]: A single parameter transformation function will be produced,
         containing the preprocessing steps common across each function whose
         name is in the list. Passing a list of all functions with the
         `"num_input"` attribute is equivalent to passing `True`.
   ```



##########
src/relax/transform/lift_transform_params.cc:
##########
@@ -494,22 +700,71 @@ class ConsumeBundledParams : public ExprMutator {
   std::unordered_map<int, Expr> param_remap_;
 };
 
+std::vector<std::pair<GlobalVar, Function>> GetTargetFunctions(
+    const IRModule& mod, const Array<String>& target_function_names) {
+  std::vector<std::pair<GlobalVar, Function>> target_functions;
+  if (target_function_names.size()) {
+    for (const auto& name : target_function_names) {
+      auto gvar = mod->GetGlobalVar(name);
+      target_functions.push_back({gvar, Downcast<Function>(mod->Lookup(gvar))});
+    }
+  } else {
+    // Get all the functions that have the `num_input` attribute.
+    for (const auto& [gvar, func] : mod->functions) {
+      if (func->IsInstance<FunctionNode>()) {
+        auto opt_num_input = func->GetAttr<Integer>(attr::kNumInput);
+        if (opt_num_input) {
+          target_functions.emplace_back(gvar, Downcast<Function>(func));
+        }
+      }
+    }
+    std::sort(target_functions.begin(), target_functions.end(),
+              [](const auto& lhs, const auto& rhs) {
+                return lhs.first->name_hint < rhs.first->name_hint;
+              });
+  }
+  return target_functions;
+}
+
 }  // namespace
 
 namespace transform {
 
-Pass PartitionTransformParams() {
+Pass PartitionTransformParams(bool shared_transform, const Array<String>& target_function_names) {
   auto pass_func = [=](IRModule mod, PassContext pc) {
-    PreprocessPartitioner mutator;
-
     IRModule updates;
-    for (const auto& [gvar, func] : mod->functions) {
-      if (auto opt = func.as<relax::Function>()) {
-        auto new_func = Downcast<Function>(mutator(opt.value()));
-        if (!new_func.same_as(func)) {
-          updates->Add(gvar, new_func);
-        }
+

Review Comment:
   If the `shared_transform` argument accepts a `Variant<Bool, Array<String>>`, we can then check whether to generate the shared transform as follows:
   
   ```c++
   bool generate_shared_transform = shared_transform.as<Bool>().value_or(Bool(true))->value;
   ```
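For readers less familiar with TVM's `Variant`, the same decision can be sketched in standard C++17, assuming `std::variant<bool, std::vector<std::string>>` as a stand-in for `Variant<Bool, Array<String>>`. This is an analogy, not TVM's API; `SharedTransform` and `GenerateSharedTransform` are hypothetical names. A `bool` is used as given, and any list of names implies a shared transform, mirroring `value_or(Bool(true))`.

```cpp
#include <cassert>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for TVM's Variant<Bool, Array<String>>.
using SharedTransform = std::variant<bool, std::vector<std::string>>;

// Hypothetical analogue of
//   shared_transform.as<Bool>().value_or(Bool(true))->value
// If a bool is held, use it directly; a list of function names always
// means that a shared transform should be generated.
bool GenerateSharedTransform(const SharedTransform& shared_transform) {
  if (const bool* flag = std::get_if<bool>(&shared_transform)) {
    return *flag;
  }
  return true;
}
```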



##########
python/tvm/relax/transform/transform.py:
##########
@@ -844,12 +844,28 @@ def LiftTransformParams() -> tvm.ir.transform.Pass:
     Users are expected to invoke the `transform_params` function in runtime and pass the transformed
     parameters to the original function as input.
 
+    Parameters
+    ----------
+    shared_transform : Union[bool, List[str]]
+        Boolean to indicate whether to share the transformation of the parameters among functions or
+        a list of function names to apply the shared transformation.
+
+        When the shared transformation is enabled, all the target functions should have the same
+        parameters and the common part of the transformations will be lifted to a global
+        `transform_params` function that is shared among all functions.
+        Otherwise, each function will have its own `transform_params` function.
+
     Returns
     -------
     ret : tvm.transform.Pass
         The registered pass for lifting transformation of parameters.
     """
-    return _ffi_api.LiftTransformParams()  # type: ignore
+    if isinstance(shared_transform, bool):

Review Comment:
   Once the C++ `LiftTransformParams` argument is changed to `Variant<Bool, Array<String>>`, this conditional check can be removed. The FFI handles the conversion from a Python `bool` or `List[str]` to a `Variant<Bool, Array<String>>` internally.



##########
include/tvm/relax/transform.h:
##########
@@ -265,9 +265,15 @@ TVM_DLL Pass RealizeVDevice();
  * Users are expected to invoke the `transform_params` function in runtime and pass the transformed
  * parameters to the original function as input.
  *
+ * \param shared_transform Whether to share the transformation of the parameters among all functions
+ * If true, all functions should have the same parameters and the common part of the transformations
+ * will be lifted to a global `transform_params` function that is shared among all functions. If
+ * false, each function will have its own `transform_params` function.
+ * \param target_functions The list of functions to apply the shared transformation. If empty, the

Review Comment:
   Instead of two independent parameters, can this be a single parameter of type `Variant<Bool, Array<String>>`? As it is, the two parameters are coupled: the second parameter is ignored whenever the first is `false`. Using `Variant<Bool, Array<String>>` would remove that coupling and provide the same interface in both C++ and Python.



##########
src/relax/transform/lift_transform_params.cc:
##########
@@ -494,22 +700,71 @@ class ConsumeBundledParams : public ExprMutator {
   std::unordered_map<int, Expr> param_remap_;
 };
 
+std::vector<std::pair<GlobalVar, Function>> GetTargetFunctions(
+    const IRModule& mod, const Array<String>& target_function_names) {
+  std::vector<std::pair<GlobalVar, Function>> target_functions;
+  if (target_function_names.size()) {

Review Comment:
   If the shared transform argument accepts a `Variant<Bool, Array<String>>`, that argument can then be forwarded directly to the `GetTargetFunctions` utility.
   
   ```c++
   if (auto function_names = shared_transform.as<Array<String>>()) {
     for (const auto& name : function_names.value()) {
       ...
     }
   } else {
     ...
   }
   ```
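A standalone sketch of that dispatch, again assuming `std::variant` as a stand-in for TVM's `Variant`, and modelling the `IRModule` as a hypothetical `std::map` from function name to whether the function carries the `"num_input"` attribute. An explicit list is used in the given order, and unknown names raise an error (loosely analogous to `GetGlobalVar` failing). Otherwise every function with `"num_input"` is selected in name order, as the `std::sort` in the diff does.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for TVM's Variant<Bool, Array<String>>.
using SharedTransform = std::variant<bool, std::vector<std::string>>;

// Hypothetical module model: function name -> has the "num_input" attribute.
using Module = std::map<std::string, bool>;

std::vector<std::string> GetTargetFunctions(const Module& mod,
                                            const SharedTransform& shared_transform) {
  std::vector<std::string> targets;
  if (const auto* names = std::get_if<std::vector<std::string>>(&shared_transform)) {
    // Explicit list: every requested name must exist in the module.
    for (const auto& name : *names) {
      if (mod.count(name) == 0) {
        throw std::out_of_range("function not found in module: " + name);
      }
      targets.push_back(name);
    }
  } else {
    // true or false: all functions with "num_input". std::map iterates in
    // key order, so the result is already sorted by name.
    for (const auto& [name, has_num_input] : mod) {
      if (has_num_input) {
        targets.push_back(name);
      }
    }
  }
  return targets;
}
```

Note that `false` still selects every `"num_input"` function; whether their transforms are shared is decided separately by the `Bool` alternative.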



##########
src/relax/transform/lift_transform_params.cc:
##########
@@ -494,22 +700,71 @@ class ConsumeBundledParams : public ExprMutator {
   std::unordered_map<int, Expr> param_remap_;
 };
 
+std::vector<std::pair<GlobalVar, Function>> GetTargetFunctions(
+    const IRModule& mod, const Array<String>& target_function_names) {
+  std::vector<std::pair<GlobalVar, Function>> target_functions;
+  if (target_function_names.size()) {
+    for (const auto& name : target_function_names) {
+      auto gvar = mod->GetGlobalVar(name);
+      target_functions.push_back({gvar, Downcast<Function>(mod->Lookup(gvar))});
+    }
+  } else {
+    // Get all the functions that have the `num_input` attribute.
+    for (const auto& [gvar, func] : mod->functions) {
+      if (func->IsInstance<FunctionNode>()) {
+        auto opt_num_input = func->GetAttr<Integer>(attr::kNumInput);
+        if (opt_num_input) {
+          target_functions.emplace_back(gvar, Downcast<Function>(func));
+        }
+      }
+    }
+    std::sort(target_functions.begin(), target_functions.end(),
+              [](const auto& lhs, const auto& rhs) {
+                return lhs.first->name_hint < rhs.first->name_hint;
+              });
+  }
+  return target_functions;
+}
+
 }  // namespace
 
 namespace transform {
 
-Pass PartitionTransformParams() {
+Pass PartitionTransformParams(bool shared_transform, const Array<String>& target_function_names) {
   auto pass_func = [=](IRModule mod, PassContext pc) {
-    PreprocessPartitioner mutator;
-
     IRModule updates;
-    for (const auto& [gvar, func] : mod->functions) {
-      if (auto opt = func.as<relax::Function>()) {
-        auto new_func = Downcast<Function>(mutator(opt.value()));
-        if (!new_func.same_as(func)) {
-          updates->Add(gvar, new_func);
-        }
+
+    std::optional<GlobalCollectInfo> global_collect_info;
+    auto target_functions = GetTargetFunctions(mod, target_function_names);
+
+    if (shared_transform) {
+      std::vector<Function> functions;
+      for (const auto& [_, func] : target_functions) {
+        functions.push_back(func);
       }
+      global_collect_info = MakeGlobalLiftPlan(mod, functions);
+    }
+
+    std::unordered_map<GlobalVar, LocalCollectInfo, ObjectPtrHash, ObjectPtrEqual>
+        local_collect_info;
+    for (const auto& [gvar, func] : target_functions) {
+      auto info = LocalLiftableBindingCollector::Collect(
+          func, global_collect_info.has_value() ? &global_collect_info.value() : nullptr);
+      local_collect_info[gvar] = info;
+    }
+
+    for (const auto& [gvar, info] : local_collect_info) {
+      auto new_runtime_func = info.MakeRuntimeFunction();
+      updates->Add(gvar, new_runtime_func);
+      if (!global_collect_info.has_value()) {

Review Comment:
   Can this conditional be moved into the `else` branch of `if (global_collect_info.has_value())`? That would make it clearer to a reader that the runtime function is always generated, and that there is a choice between a single global transformation and several per-function transformations.
   
   ```c++
   if (global_collect_info.has_value()) {
     auto global_transform = global_collect_info.value().MakeCompileTimeFunc();
     updates->Add(GlobalVar("transform_params"), global_transform);
   } else {
     for (const auto& [gvar, info] : local_collect_info) {
       updates->Add(GlobalVar(gvar->name_hint + "_transform_params"),
                    info.MakeCompileTimeFunction());
     }
   }
   ```
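The suggested control flow can be modelled without TVM. This is a hypothetical sketch: functions are represented by descriptive strings and the module updates by a `std::map`, so only the shape of the logic carries over. Every target function always receives a runtime function, and exactly one of the two branches emits the compile-time transforms, named `transform_params` or `<name>_transform_params` as in the snippet above.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical model: a "function" is a descriptive string, and the module
// updates are a map from global name to function.
using Updates = std::map<std::string, std::string>;

Updates MakeUpdates(const std::vector<std::string>& target_functions,
                    bool shared_transform) {
  Updates updates;
  // A runtime function is always produced for every target function.
  for (const auto& name : target_functions) {
    updates[name] = "runtime(" + name + ")";
  }
  if (shared_transform) {
    // A single compile-time transform shared by all target functions.
    updates["transform_params"] = "shared compile-time transform";
  } else {
    // One compile-time transform per target function.
    for (const auto& name : target_functions) {
      updates[name + "_transform_params"] = "compile-time transform(" + name + ")";
    }
  }
  return updates;
}
```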



##########
include/tvm/relax/transform.h:
##########
@@ -265,9 +265,15 @@ TVM_DLL Pass RealizeVDevice();
  * Users are expected to invoke the `transform_params` function in runtime and pass the transformed
  * parameters to the original function as input.
  *
+ * \param shared_transform Whether to share the transformation of the parameters among all functions
+ * If true, all functions should have the same parameters and the common part of the transformations
+ * will be lifted to a global `transform_params` function that is shared among all functions. If
+ * false, each function will have its own `transform_params` function.
+ * \param target_functions The list of functions to apply the shared transformation. If empty, the
+ * shared transformation will be applied to all functions.
  * \return The Pass.
  */
-TVM_DLL Pass LiftTransformParams();
+TVM_DLL Pass LiftTransformParams(bool shared_transform, Array<String> target_functions);

Review Comment:
   To avoid breaking backward compatibility, new parameters should have default values, chosen to reproduce the previous behavior.
   
   ```c++
   TVM_DLL Pass LiftTransformParams(Variant<Bool, Array<String>> shared_transform = Bool(false));
   ```
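A minimal standalone illustration of why the default matters, again assuming `std::variant` as a stand-in for `Variant<Bool, Array<String>>` and a hypothetical `PassConfig` struct in place of the returned `Pass`. An existing call with no arguments still compiles and selects the previous per-function behavior (`false`), while new callers can pass either alternative.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical stand-in for TVM's Variant<Bool, Array<String>>.
using SharedTransform = std::variant<bool, std::vector<std::string>>;

// Hypothetical stand-in for the tvm::transform::Pass that would be returned.
struct PassConfig {
  SharedTransform shared_transform;
};

// Hypothetical analogue of the proposed declaration. Defaulting to `false`
// keeps zero-argument call sites working with the old behavior: one
// transformation function per function with "num_input".
PassConfig LiftTransformParamsSketch(SharedTransform shared_transform = false) {
  return PassConfig{std::move(shared_transform)};
}
```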



-- 
This is an automated message from the Apache Git Service.