This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e715814985 [Relax][Transform] Preserve param names in 
LiftTransformParams (#16594)
e715814985 is described below

commit e715814985fc88d813eabfa9ada364bfcadc7bec
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Feb 23 08:36:33 2024 -0600

    [Relax][Transform] Preserve param names in LiftTransformParams (#16594)
    
    * [Relax][Transform] Preserve param names in LiftTransformParams
    
    The `relax.transform.LiftTransformParams` pass splits apart a relax
    function, extracting the steps that could be performed at
    compile-time.  Prior to this commit, the transformed parameters were
    named `param0`, `param1`, and so on.
    
    This commit updates the `LiftTransformParams` pass to preserve any
    human-readable parameter names.  The parameter names for the updated
    function are taken from the original parameter names, if no
    transformation is performed, or from the internal variable binding, if
    a transformation is applied.  This implementation uses `LambdaLift`
    internally, relying on the changes made in
    https://github.com/apache/tvm/pull/16306.
    
    * Update based on review comments
---
 src/relax/transform/lift_transform_params.cc       | 637 +++++++++++----------
 .../relax/test_transform_lift_transform_params.py  |  86 ++-
 2 files changed, 380 insertions(+), 343 deletions(-)

diff --git a/src/relax/transform/lift_transform_params.cc 
b/src/relax/transform/lift_transform_params.cc
index b500a3c3a3..15b60f5492 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -18,7 +18,7 @@
  */
 
 /*!
- * \file tvm/relax/transform/lambda_lift.cc
+ * \file tvm/relax/transform/lift_transform_params.cc
  * \brief Lift local functions into global functions.
  */
 
@@ -29,6 +29,7 @@
 #include <tvm/runtime/logging.h>
 
 #include <iostream>
+#include <tuple>
 #include <vector>
 
 #include "../../support/ordered_set.h"
@@ -37,405 +38,443 @@
 namespace tvm {
 namespace relax {
 
-/*! \brief Plan of lifting transform params */
-struct LiftTransformParamsInfoPlan {
-  Function f_transform_params;  // the lifted function that transforms the 
parameters
-  std::unordered_map<Var, int, ObjectPtrHash, ObjectPtrEqual>
-      output_to_index;  // the index of the original bindings in the output 
tuple
-  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>
-      lifted_bindings;  // the bindings of the original function that are 
lifted
-};
+namespace {
 
-/*! \brief Builder of the function that transforms the parameters. */
-class TransformParamsFuncBuilder : public ExprMutator {
- public:
-  TransformParamsFuncBuilder() { builder_->BeginDataflowBlock(); }
+struct CollectInfo {
+  /*! \brief The analyzed function */
+  Function orig_func;
+
+  /*! \brief The number of parameters unknown until runtime */
+  size_t num_runtime_params;
+
+  /*! \brief Bindings that can be lifted out into a pre-processing function
+   *
+   * - All bindings in `computable_at_compile_time` are suitable for
+   *   use in a DataflowBlock.
+   *
+   * - Do not depend on any parameter prior to attr::kNumInput.
+   *
+   * - Does not include "relax.builtin.stop_lift_params"
+   */
+  std::vector<Binding> computable_at_compile_time;
 
-  /*! \brief Add a input parameter. */
-  void AddInput(const Var& var) {
-    inputs_.push_back(var);
-    lifted_binding_lookup_.insert(var);
+  /*! \brief Variables that are required at runtime */
+  std::unordered_set<Variant<relax::Var, tir::Var>, ObjectPtrHash, 
ObjectPtrEqual>
+      required_at_runtime;
+
+  Array<Var> GetCompileTimeInputs() const {
+    return Array<Var>(orig_func->params.begin() + num_runtime_params, 
orig_func->params.end());
   }
 
-  void UpdateBasedOnRuntimeInput(const Var& var) {
-    for (const auto& var : DefinableTIRVarsInStructInfo(GetStructInfo(var))) {
-      known_symbolic_var_during_inference_.insert(var);
-    }
-    for (const auto& var : TIRVarsInStructInfo(GetStructInfo(var))) {
-      required_symbolic_var_during_inference_.insert(var);
-    }
+  Array<Var> GetRuntimeInputs() const {
+    return Array<Var>(orig_func->params.begin(), orig_func->params.begin() + 
num_runtime_params);
   }
 
-  /*! \brief Add a binding to lift. */
-  void AddInternalBinding(const VarBinding& binding) {
-    bindings_.push_back(binding);
-    lifted_binding_lookup_.insert(binding->var);
+  Array<tir::Var> GetPropagatedSymbolicVariables() const {
+    auto vars_from_any_param =
+        
DefinableTIRVarsInStructInfo(TupleStructInfo(orig_func->params.Map(GetStructInfo)));
+
+    auto vars_from_runtime_params =
+        [&]() -> std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> {
+      auto tir_var_vec =
+          
DefinableTIRVarsInStructInfo(TupleStructInfo(GetRuntimeInputs().Map(GetStructInfo)));
+      return {tir_var_vec.begin(), tir_var_vec.end()};
+    }();
+
+    auto vars_from_transformed_params =
+        [&]() -> std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> {
+      auto tir_var_vec =
+          
DefinableTIRVarsInStructInfo(TupleStructInfo(GetCompileTimeOutputs().Map(GetStructInfo)));
+      return {tir_var_vec.begin(), tir_var_vec.end()};
+    }();
+
+    Array<tir::Var> output;
+    for (const auto& tir_var : vars_from_any_param) {
+      if (required_at_runtime.count(tir_var) && 
!vars_from_runtime_params.count(tir_var) &&
+          !vars_from_transformed_params.count(tir_var)) {
+        output.push_back(tir_var);
+      }
+    }
+    return output;
   }
 
-  /*! \brief Update based on bindings not being lifted. */
-  void UpdateBasedOnRuntimeBinding(const VarBinding& binding) {
-    for (const auto& producer : FreeVars(binding->value)) {
-      // An external value that uses a lifted binding requires the
-      // lifted binding to be returned as output.
-      if (lifted_binding_lookup_.count(producer)) {
-        outputs_.insert(producer);
+  Array<Var> GetCompileTimeOutputs() const {
+    Array<Var> params;
 
-        for (const auto& var : 
DefinableTIRVarsInStructInfo(GetStructInfo(producer))) {
-          known_symbolic_var_during_inference_.insert(var);
-        }
+    // Any value that is available at compile-time, but is also
+    // required at runtime, must be passed through the compile-time
+    // function.
+    for (size_t i = num_runtime_params; i < orig_func->params.size(); i++) {
+      Var var = orig_func->params[i];
+      if (required_at_runtime.count(var)) {
+        params.push_back(var);
       }
     }
 
-    // All TIR variables used in the binding must be available at runtime.
-    for (const auto& var : FreeSymbolicVars(binding->value)) {
-      required_symbolic_var_during_inference_.insert(var);
+    // Any variable that is computed at compile-time, but is required
+    // at runtime, must be provided as a parameter.
+    for (const auto& binding : computable_at_compile_time) {
+      if (required_at_runtime.count(binding->var)) {
+        params.push_back(binding->var);
+      }
     }
-  }
 
-  bool UsesOnlyLiftableProducers(const Expr& expr) {
-    auto producers = FreeVars(expr);
-    bool uses_only_liftable_producers = [&]() {
-      return std::all_of(producers.begin(), producers.end(),
-                         [&](const auto& var) { return 
lifted_binding_lookup_.count(var); });
-    }();
-    return uses_only_liftable_producers;
+    return params;
   }
 
-  /*!
-   * \brief Build the function that transforms the parameters
-   * \return The created function, and a map from the variable in the original 
function to the index
-   * of the element of the output tuple
-   */
-  std::pair<Function, std::unordered_map<Var, int, ObjectPtrHash, 
ObjectPtrEqual>> Build() {
-    Array<PrimExpr> extra_symbolic_vars;
-    for (const auto& var : required_symbolic_var_during_inference_) {
-      if (!known_symbolic_var_during_inference_.count(var)) {
-        extra_symbolic_vars.push_back(var);
-      }
-    }
+  Function MakeCompileTimeFunction() const {
+    auto compile_time_params = GetCompileTimeInputs();
 
-    Array<StructInfo> input_sinfo;
-    Array<Expr> output_vars;
-    std::unordered_map<Var, int, ObjectPtrHash, ObjectPtrEqual> 
output_to_index;
+    Array<Binding> output_var_binding;
+    Array<Expr> output_exprs;
 
-    for (const auto& input : inputs_) {
-      input_sinfo.push_back(Downcast<StructInfo>(input->struct_info_.value()));
+    // Any symbolic variables that are inferrable from compile-time
+    // parameters, but are not inferrable from run-time parameters,
+    // must be propagated to the output.
+    if (auto propagated_tir_vars = GetPropagatedSymbolicVariables(); 
propagated_tir_vars.size()) {
+      output_exprs.push_back(
+          ShapeExpr(propagated_tir_vars.Map([](tir::Var var) -> PrimExpr { 
return var; })));
     }
-    Var params("params", TupleStructInfo(input_sinfo));
 
-    if (extra_symbolic_vars.size()) {
-      output_vars.push_back(builder_->Emit(ShapeExpr(extra_symbolic_vars), 
"extra_symbolic_vars"));
+    for (const auto& var : GetCompileTimeOutputs()) {
+      Var out_var(var->name_hint() + "_output", GetStructInfo(var));
+      output_var_binding.push_back(VarBinding(out_var, var));
+      output_exprs.push_back(out_var);
     }
 
-    // Helper to add a variable to the output tuple
-    // original_var: the binding variable in the original function
-    // output_var: the variable, which is a binding in the transform_params 
function, that is added
-    // to the output tuple
-    auto f_add_output = [&](const Var& original_var, const Var& output_var) -> 
void {
-      output_to_index[original_var] = output_vars.size();
-      output_vars.push_back(output_var);
-    };
+    Var tuple_var("output_tuple", 
TupleStructInfo(output_exprs.Map(GetStructInfo)));
+    output_var_binding.push_back(VarBinding(tuple_var, Tuple(output_exprs)));
+
+    SeqExpr body(
+        {
+            DataflowBlock(computable_at_compile_time),
+            DataflowBlock(output_var_binding),
+        },
+        tuple_var);
+
+    Function func(compile_time_params, body, GetStructInfo(tuple_var));
+    func = WithAttr(func, attr::kNumInput, Integer(0));
+    func = CopyWithNewVars(func);
+    func = Downcast<Function>(CanonicalizeBindings(func));
+    return func;
+  }
 
-    // Create mapping from the original input variables to the TupleGetItem 
from the packed
-    // parameter tuple Add the parameters that are marked as the output of the 
function to the
-    // output tuple
-    for (const auto& input : inputs_) {
-      input_remap_.emplace(input.get(), TupleGetItem(params, 
input_remap_.size()));
-      if (outputs_.count(input)) {
-        auto output_var = builder_->Emit(input_remap_.at(input.get()));
-        f_add_output(input, output_var);
-      }
+  Function MakeRuntimeFunction() const {
+    Array<Binding> bindings;
+
+    // Any parameter that isn't available until runtime must be an
+    // input, along with any output from the compile-time function.
+    // Compile-time outputs must have a fresh non-dataflow var to
+    // serve as the parameter.  This trivial binding will later be
+    // removed with CanonicalizeBindings.
+    Array<Var> params = GetRuntimeInputs();
+    if (auto propagated_tir_vars = GetPropagatedSymbolicVariables(); 
propagated_tir_vars.size()) {
+      ShapeStructInfo shape_sinfo(
+          propagated_tir_vars.Map([](tir::Var var) -> PrimExpr { return var; 
}));
+      Var shape_expr("vars_from_compile_time_params", shape_sinfo);
+      params.push_back(shape_expr);
+    }
+    for (const auto& var : GetCompileTimeOutputs()) {
+      Var param_var(var->name_hint(), GetStructInfo(var));
+      bindings.push_back(VarBinding(var, param_var));
+      params.push_back(param_var);
     }
 
-    // Re-emit the bindings that are lifted. Update the output tuple if the 
binding is marked as the
-    // output.
-    for (const auto& binding : bindings_) {
-      if (outputs_.count(binding->var)) {
-        auto output_var = builder_->Emit(VisitExpr(binding->value));
-        var_remap_[binding->var->vid] = output_var;
-        f_add_output(binding->var, output_var);
-      } else {
-        VisitBinding(binding);
+    // Any binding that is computable at compile-time should be
+    // suppressed at run-time.
+    struct SuppressCompileTime : ExprMutator {
+      std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> to_suppress;
+      explicit SuppressCompileTime(const std::vector<Binding>& bindings) {
+        for (const auto& binding : bindings) {
+          to_suppress.insert(binding->var);
+        }
       }
-    }
 
-    // Create the function.
-    Expr transformed_params = builder_->EmitOutput(Tuple(output_vars));
-    BindingBlock block = builder_->EndBlock();
-    Expr body = VisitWithNewScope(SeqExpr({block}, transformed_params), 
Array<Var>{params});
-    Function f_transform_params =
-        Function(/*params=*/{params}, /*body=*/body, 
/*ret_struct_info=*/NullOpt);
-    return {f_transform_params, output_to_index};
-  }
+      void VisitBinding(const Binding& binding) override {
+        if (!to_suppress.count(binding->var)) {
+          ExprMutator::VisitBinding(binding);
+        }
+      }
 
-  Expr VisitExpr_(const VarNode* var) final {
-    if (auto it = input_remap_.find(var); it != input_remap_.end()) {
-      return builder_->Emit((*it).second);
-    } else {
-      return ExprMutator::VisitExpr_(var);
-    }
+      using ExprMutator::VisitExpr_;
+      Expr VisitExpr_(const CallNode* call) override {
+        static const Op& stop_lift_params_op = 
Op::Get("relax.builtin.stop_lift_params");
+        if (call->op.same_as(stop_lift_params_op)) {
+          return VisitExpr(call->args[0]);
+        } else {
+          return ExprMutator::VisitExpr_(call);
+        }
+      }
+    };
+    Expr body = 
SuppressCompileTime(computable_at_compile_time)(orig_func->body);
+    body = SeqExpr({DataflowBlock(bindings)}, body);
+
+    Function func(params, body, orig_func->ret_struct_info, 
orig_func->is_pure, orig_func->attrs);
+    func = WithoutAttr(func, tvm::attr::kGlobalSymbol);
+    func = CopyWithNewVars(func);
+    return func;
   }
 
-  // The input parameters of the function.
-  Array<Var> inputs_;
-  // Remap from the original input variable to TupleGetItem from the packed 
parameter tuple, which
-  // is the input of the lifted function.
-  std::unordered_map<const VarNode*, Expr> input_remap_;
-  // The bindings that are lifted.
-  Array<VarBinding> bindings_;
-  // The variables that are marked as the output of the function.
-  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> outputs_;
+  Function MakePartitionedFunction() const {
+    Array<Binding> inner_func_bindings;
+    Var compile_time_func = [&]() {
+      auto func = MakeCompileTimeFunction();
+      Var var("transform_params", GetStructInfo(func));
+      inner_func_bindings.push_back(VarBinding(var, std::move(func)));
+      return var;
+    }();
+    Var runtime_func = [&]() {
+      auto func = MakeRuntimeFunction();
+      Var var("runtime", GetStructInfo(func));
+      inner_func_bindings.push_back(VarBinding(var, std::move(func)));
+      return var;
+    }();
 
-  // The bindings that are lifted
-  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> 
lifted_binding_lookup_;
+    Array<Binding> calling_scope;
 
-  /* Symbolic variables that are known during the transform_params execution.
-   *
-   * This set is populated based on the variables declared with
-   * AddInput, and contains variables that may appear inside the
-   * transformation function.  A binding that depends on a symbolic
-   * variable not contained in this set may not be lifted.
-   */
-  support::OrderedSet<tir::Var> known_symbolic_var_during_transform_;
+    Call compile_time_preprocess(
+        compile_time_func, GetCompileTimeInputs().Map([](const Var& var) -> 
Expr { return var; }));
 
-  /* Symbolic variables that are known during the runtime
-   *
-   * This set is populated based on the variables declared with
-   * UpdateBasedOnRuntimeInput, and contains variables that are
-   * defined at runtime.  A variable that present in
-   * required_symbolic_var_during_inference_, but not present in this
-   * set, causes the Build() function to output an additional
-   * R.ShapeExpr in order to propagate the symbolic variables.
-   */
-  support::OrderedSet<tir::Var> known_symbolic_var_during_inference_;
+    // Use a fresh variable in case it is passed through unmodified in
+    // the compile-time function.
+    Array<Var> compile_time_outputs;
+    if (auto propagated_tir_vars = GetPropagatedSymbolicVariables(); 
propagated_tir_vars.size()) {
+      ShapeStructInfo shape_sinfo(
+          propagated_tir_vars.Map([](tir::Var var) -> PrimExpr { return var; 
}));
+      Var shape_expr("vars_from_compile_time_params", shape_sinfo);
+      compile_time_outputs.push_back(shape_expr);
+    }
+    for (const auto& relax_var : GetCompileTimeOutputs()) {
+      compile_time_outputs.push_back(
+          Var(relax_var->name_hint(), GetStructInfo(relax_var), 
relax_var->span));
+    }
+    {
+      Var tuple_output("compile_time_output",
+                       
TupleStructInfo(compile_time_outputs.Map(GetStructInfo)));
+      calling_scope.push_back(VarBinding(tuple_output, 
compile_time_preprocess));
+      for (size_t i = 0; i < compile_time_outputs.size(); i++) {
+        calling_scope.push_back(VarBinding(compile_time_outputs[i], 
TupleGetItem(tuple_output, i)));
+      }
+    }
 
-  /* Symbolic variables that must be known at runtime
-   *
-   * This set is populated based on the variables used in external
-   * bindings.  A variable that is present here, but not present in
-   * known_symbolic_var_during_inference_, must be provided as an
-   * additional R.ShapeExpr parameter from the transform_params
-   * function.
-   */
-  support::OrderedSet<tir::Var> required_symbolic_var_during_inference_;
+    Array<Expr> runtime_args = GetRuntimeInputs().Map([](const Var& var) -> 
Expr { return var; });
+    for (const auto& var : compile_time_outputs) {
+      runtime_args.push_back(var);
+    }
+
+    Call runtime_execution(runtime_func, runtime_args);
+    Var output_var("output", orig_func->ret_struct_info);
+    calling_scope.push_back(VarBinding(output_var, runtime_execution));
+
+    SeqExpr body(
+        {
+            BindingBlock(inner_func_bindings),
+            DataflowBlock(calling_scope),
+        },
+        output_var);
+
+    Function func = orig_func;
+    func.CopyOnWrite()->body = body;
+    func = Downcast<Function>(CanonicalizeBindings(func));
+    return func;
+  }
 };
 
-/*!
- * \brief Visitor that creates the plan of lifting transform params.
- *
- * Starting from the parameters of the function (they are the initial set of 
lifted bindings), we
- * will visit the body of the function to find the bindings that can be 
lifted. A binding can be
- * lifted if all the variables that it depends on are also lifted.
- *
- * When a binding cannot be lifted, all the variables that 1) it depends on, 
and 2) have been
- * lifted, will be marked as the boundary variable and will be in the output 
of the lifted function.
- */
-class LiftTransformParamsPlanner : public ExprVisitor {
+class LiftableBindingCollector : ExprVisitor {
  public:
-  LiftTransformParamsInfoPlan Plan(const Function& function, int num_inputs) {
-    for (int i = 0; i < static_cast<int>(function->params.size()); ++i) {
-      if (i < num_inputs) {
-        builder_.UpdateBasedOnRuntimeInput(function->params[i]);
-      } else {
-        builder_.AddInput(function->params[i]);
-        if (function->params[i]->struct_info_.defined()) {
-          Array<tir::Var> symbolic_vars = DefinableTIRVarsInStructInfo(
-              Downcast<StructInfo>(function->params[i]->struct_info_.value()));
-          for (const auto& var : symbolic_vars) {
-            param_symbolic_vars_.insert(var);
-          }
-        }
-      }
+  static CollectInfo Collect(const Function& func) {
+    LiftableBindingCollector visitor;
+    visitor(func);
+    visitor.info_.orig_func = func;
+    return visitor.info_;
+  }
+
+ private:
+  void VisitExpr_(const FunctionNode* func) override {
+    size_t num_runtime_params = func->params.size();
+    if (auto opt = func->attrs.GetAttr<Integer>(attr::kNumInput)) {
+      num_runtime_params = opt.value()->value;
     }
-    VisitExpr(function->body);
 
-    const auto& [f_transform_params, output_to_index] = builder_.Build();
-    return {f_transform_params, output_to_index, 
std::move(builder_.lifted_binding_lookup_)};
+    info_.num_runtime_params = num_runtime_params;
+
+    for (size_t i = num_runtime_params; i < func->params.size(); i++) {
+      liftable_vars_.insert(func->params[i]);
+      for (const auto& tir_var : 
DefinableTIRVarsInStructInfo(GetStructInfo(func->params[i]))) {
+        liftable_vars_.insert(tir_var);
+      }
+    }
+    ExprVisitor::VisitExpr_(func);
   }
 
- private:
   void VisitBindingBlock_(const DataflowBlockNode* block) final {
+    bool cache = is_in_dataflow_block_;
     is_in_dataflow_block_ = true;
     ExprVisitor::VisitBindingBlock_(block);
-    is_in_dataflow_block_ = false;
+    is_in_dataflow_block_ = cache;
+  }
+
+  void VisitBinding(const Binding& binding) override {
+    if (CanLiftBinding(binding)) {
+      info_.computable_at_compile_time.push_back(binding);
+      liftable_vars_.insert(binding->var);
+    } else {
+      info_.required_at_runtime.insert(binding->var);
+      auto bound_value = GetBoundValue(binding);
+      for (const auto& upstream_var : FreeVars(bound_value)) {
+        info_.required_at_runtime.insert(upstream_var);
+      }
+      for (const auto& tir_var : FreeSymbolicVars(bound_value)) {
+        info_.required_at_runtime.insert(tir_var);
+      }
+    }
   }
 
-  void VisitBinding_(const VarBindingNode* binding) final {
-    bool can_lift = true;
+  bool CanLiftBinding(const Binding& binding) const {
+    auto value = GetBoundValue(binding);
 
     // Cond 1. Do not lift bindings outside dataflow blocks.
     if (!is_in_dataflow_block_) {
-      can_lift = false;
+      return false;
     }
 
     // Cond 2. Do not lift regarding the "builtin.stop_lift_params" op.
-    if (const auto* call = binding->value.as<CallNode>()) {
+    if (const auto* call = value.as<CallNode>()) {
       static const Op& stop_lift_params_op = 
Op::Get("relax.builtin.stop_lift_params");
       if (call->op.same_as(stop_lift_params_op)) {
-        can_lift = false;
+        return false;
       }
     }
 
     // Cond 3. Do not lift when involving Vars that are not liftable.
-    auto producers = FreeVars(binding->value);
-    bool uses_only_liftable_producers = 
builder_.UsesOnlyLiftableProducers(binding->value);
-    if (!uses_only_liftable_producers) {
-      can_lift = false;
+    for (const auto& var : FreeVars(value)) {
+      if (!liftable_vars_.count(var)) {
+        return false;
+      }
     }
 
     // Cond 4. Do not lift when its struct info contains symbolic variables 
that do not appear in
     // params.
     for (const auto& var : TIRVarsInStructInfo(GetStructInfo(binding->var))) {
-      if (!param_symbolic_vars_.count(var)) {
-        can_lift = false;
+      if (!liftable_vars_.count(var)) {
+        return false;
       }
     }
 
     // Cond 5. Do not lift declarations of external functions
-    if (binding->value.as<relax::ExternFuncNode>()) {
-      can_lift = false;
+    if (value.as<relax::ExternFuncNode>()) {
+      return false;
     }
 
-    if (can_lift) {
-      builder_.AddInternalBinding(GetRef<VarBinding>(binding));
-    } else {
-      builder_.UpdateBasedOnRuntimeBinding(GetRef<VarBinding>(binding));
-    }
+    return true;
   }
 
-  // The builder of the function that transforms the parameters
-  TransformParamsFuncBuilder builder_;
-  // Whether we are in a dataflow block
+  CollectInfo info_;
+  std::unordered_set<Variant<Var, tir::Var>, ObjectPtrHash, ObjectPtrEqual> 
liftable_vars_;
   bool is_in_dataflow_block_{false};
-  // The symbolic variables in the parameters
-  std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> 
param_symbolic_vars_;
 };
 
-/*!
- *\brief The rewriter that lifts the transform params of a function and 
updates the original
- * function.
- */
-class TransformParamsLifter : ExprMutator {
+class PreprocessPartitioner : public ExprMutator {
  public:
-  explicit TransformParamsLifter(const IRModule& module) : ExprMutator(module) 
{}
-
-  Function VisitFunction(GlobalVar gvar, Function func) {
-    current_gvar_ = gvar;
-    auto out = Downcast<Function>(VisitExpr(std::move(func)));
-    current_gvar_ = NullOpt;
-    return out;
-  }
-
-  Map<GlobalVar, Function> GetTransformParamFunctions() const { return 
transform_param_funcs_; }
-
- private:
+  using ExprMutator::VisitExpr_;
   Expr VisitExpr_(const FunctionNode* op) override {
     auto func = GetRef<Function>(op);
-    Optional<Integer> opt_num_input = 
func->attrs.GetAttr<Integer>(attr::kNumInput);
-    if (!opt_num_input) {
+    if (func->attrs.GetAttr<Integer>(attr::kNumInput)) {
+      auto info = LiftableBindingCollector::Collect(func);
+      return info.MakePartitionedFunction();
+    } else {
       return func;
     }
-    auto signed_num_input = opt_num_input.value()->value;
-    ICHECK_GE(signed_num_input, 0);
-    ICHECK_LE(signed_num_input, func->params.size());
-    size_t num_input = signed_num_input;
-
-    LiftTransformParamsPlanner planner;
-
-    // Step 1: Create the plan of lifting transform params
-    lift_plan_ = planner.Plan(func, num_input);
-
-    // Step 2: Stash the lifted function to add to the module
-    transform_param_funcs_.Set(current_gvar_.value(), 
lift_plan_.f_transform_params);
-
-    // Step 3: Update the current function.
-
-    // Step 3.1: Update the function signature
-    Array<StructInfo> param_fields =
-        
Downcast<TupleStructInfo>(lift_plan_.f_transform_params->ret_struct_info)->fields;
-
-    Array<Var> new_params(func->params.begin(), func->params.begin() + 
num_input);
-    for (size_t i = 0; i < param_fields.size(); i++) {
-      std::stringstream name;
-      name << "transformed_param_" << i;
-      Var param(name.str(), param_fields[i]);
-      new_params.push_back(param);
-    }
-
-    // Step 3.2: Update the function body
-    for (const auto& [var, index] : lift_plan_.output_to_index) {
-      ICHECK_LT(num_input + index, new_params.size());
-      param_remap_[var] = new_params[num_input + index];
-    }
-    auto new_body = VisitWithNewScope(func->body, new_params);
-
-    return Function(new_params, new_body, func->ret_struct_info, 
func->is_pure, func->attrs);
-  }
-
-  void VisitBinding_(const VarBindingNode* binding) final {
-    if (lift_plan_.lifted_bindings.count(binding->var)) {
-      return;
-    }
-    if (const auto* call = binding->value.as<CallNode>()) {
-      static const Op& stop_lift_params_op = 
Op::Get("relax.builtin.stop_lift_params");
-      if (call->op.same_as(stop_lift_params_op)) {
-        var_remap_[binding->var->vid] = 
Downcast<Var>(VisitExpr(call->args[0]));
-        return;
-      }
-    }
-    ExprMutator::VisitBinding_(binding);
-  }
-
-  Expr VisitExpr_(const VarNode* var) final {
-    auto it = param_remap_.find(GetRef<Var>(var));
-    if (it != param_remap_.end()) {
-      return builder_->Emit(it->second);
-    }
-    return ExprMutator::VisitExpr_(var);
   }
+};
 
-  // Remap the original parameters to TupleGetItem from the packed tuple of 
transformed parameters.
-  std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> param_remap_;
-  // The plan of lifting the transform params
-  LiftTransformParamsInfoPlan lift_plan_;
+// Adapted from https://stackoverflow.com/a/2072890
+inline bool ends_with(const std::string& value, const std::string& ending) {
+  return ending.size() <= value.size() &&
+         std::equal(ending.rbegin(), ending.rend(), value.rbegin());
+}
 
-  Map<GlobalVar, Function> transform_param_funcs_;
-  Optional<GlobalVar> current_gvar_;
-};
+}  // namespace
 
 namespace transform {
-Pass LiftTransformParams() {
-  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = 
[=](IRModule mod,
-                                                                            
PassContext pc) {
-    TransformParamsLifter mutator(mod);
+
+Pass PartitionTransformParams() {
+  auto pass_func = [=](IRModule mod, PassContext pc) {
+    PreprocessPartitioner mutator;
 
     IRModule updates;
     for (const auto& [gvar, func] : mod->functions) {
       if (auto opt = func.as<relax::Function>()) {
-        auto new_func = mutator.VisitFunction(gvar, opt.value());
+        auto new_func = Downcast<Function>(mutator(opt.value()));
         if (!new_func.same_as(func)) {
           updates->Add(gvar, new_func);
         }
       }
     }
-    for (auto [gvar, transform_func] : mutator.GetTransformParamFunctions()) {
-      String name = gvar->name_hint + "_transform_params";
-      GlobalVar new_gvar(name);
-      new_gvar->struct_info_ = transform_func->struct_info_;
 
-      transform_func = CopyWithNewVars(transform_func);
-      transform_func = WithAttr(transform_func, tvm::attr::kGlobalSymbol, 
name);
+    if (updates->functions.size()) {
+      mod.CopyOnWrite()->Update(updates);
+    }
+
+    return mod;
+  };
+  return tvm::transform::CreateModulePass(pass_func, 1, 
"PartitionTransformParams", {});
+}
 
-      updates->Add(new_gvar, transform_func);
+Pass LiftTransformParams() {
+  // A post-proc utility as the third step in LiftTransformParams
+  //
+  // 1. PartitionTransformParams: Partition each function into a
+  // compile-time and run-time lambda functions.
+  //
+  // 2. LambdaLift: Lift the compile-time and run-time lambda
+  // functions out of the end-to-end function.
+  //
+  // 3. Post-proc: Expose the compile-time and run-time functions for
+  // external use, replacing the end-to-end functions.
+  auto post_proc_func = [=](IRModule mod, PassContext pc) {
+    std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual> to_remove;
+    std::unordered_map<GlobalVar, Function, ObjectPtrHash, ObjectPtrEqual> 
to_add;
+    for (const auto& [gvar, base_func] : mod->functions) {
+      if (auto opt = base_func.as<Function>()) {
+        auto func = opt.value();
+
+        std::string func_name = gvar->name_hint;
+        if (ends_with(func_name, "transform_params")) {
+          func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint);
+          func = BundleModelParams(func);
+          to_add[gvar] = func;
+        } else if (ends_with(func_name, "_runtime")) {
+          std::string name(func_name.begin(), func_name.end() - 
sizeof("_runtime") + 1);
+          to_remove.insert(mod->GetGlobalVar(name));
+          to_remove.insert(gvar);
+          to_add[GlobalVar(name)] = WithAttr(func, tvm::attr::kGlobalSymbol, 
String(name));
+        }
+      }
     }
 
-    if (updates->functions.size()) {
-      mod.CopyOnWrite()->Update(updates);
+    if (to_remove.size() || to_add.size()) {
+      auto write_ptr = mod.CopyOnWrite();
+      for (const auto& gvar : to_remove) {
+        write_ptr->Remove(gvar);
+      }
+      for (const auto& [gvar, func] : to_add) {
+        write_ptr->Add(gvar, func);
+      }
     }
 
     return mod;
   };
-  return CreateModulePass(pass_func, 1, "LiftTransformParams", {});
+  auto post_proc =
+      tvm::transform::CreateModulePass(post_proc_func, 1, 
"LiftTransformParamsPostProc", {});
+
+  return tvm::transform::Sequential(
+      {
+          PartitionTransformParams(),
+          LambdaLift(),
+          post_proc,
+      },
+      "LiftTransformParams");
 }
 
 
TVM_REGISTER_GLOBAL("relax.transform.LiftTransformParams").set_body_typed(LiftTransformParams);
diff --git a/tests/python/relax/test_transform_lift_transform_params.py 
b/tests/python/relax/test_transform_lift_transform_params.py
index 5b24614469..8042765d40 100644
--- a/tests/python/relax/test_transform_lift_transform_params.py
+++ b/tests/python/relax/test_transform_lift_transform_params.py
@@ -64,15 +64,14 @@ def test_basic():
         @R.function
         def main(
             x: R.Tensor((1, 3, 224, 224), dtype="float32"),
-            param0: R.Tensor((16, 16, 3, 3), dtype="float32"),
-            param1: R.Tensor((16, 3, 3, 3), dtype="float32"),
+            w2: R.Tensor((16, 16, 3, 3), dtype="float32"),
+            w1_transformed: R.Tensor((16, 3, 3, 3), dtype="float32"),
         ) -> R.Tensor((1, 16, 224, 224), dtype="float32"):
             R.func_attr({"num_input": 1})
             with R.dataflow():
-                param1 = param1
                 conv1: R.Tensor((1, 16, 224, 224), dtype="float32") = 
R.nn.conv2d(
                     x,
-                    param1,
+                    w1_transformed,
                     strides=[1, 1],
                     padding=[1, 1, 1, 1],
                     dilation=[1, 1],
@@ -82,10 +81,9 @@ def test_basic():
                     out_layout="NCHW",
                     out_dtype="void",
                 )
-                param0 = param0
                 conv2: R.Tensor((1, 16, 224, 224), dtype="float32") = 
R.nn.conv2d(
                     conv1,
-                    param0,
+                    w2,
                     strides=[1, 1],
                     padding=[1, 1, 1, 1],
                     dilation=[1, 1],
@@ -117,15 +115,16 @@ def test_basic():
         ) -> R.Tuple(
             R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), 
dtype="float32")
         ):
+            R.func_attr({"num_input": 0})
             cls = Expected
             with R.dataflow():
-                lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
                 lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0]
                 lv2 = R.call_tir(
                     cls.transform_layout_IOHW_to_OIHW,
                     (lv1,),
                     out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
                 )
+                lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
                 gv: R.Tuple(
                     R.Tensor((16, 16, 3, 3), dtype="float32"),
                     R.Tensor((16, 3, 3, 3), dtype="float32"),
@@ -137,6 +136,10 @@ def test_basic():
     after = relax.transform.LiftTransformParams()(mod)
     tvm.ir.assert_structural_equal(after, Expected)
 
+    names_after = [param.name_hint for param in after["main"].params]
+    names_expected = [param.name_hint for param in Expected["main"].params]
+    assert names_after == names_expected
+
 
 def test_tuple():
     @tvm.script.ir_module
@@ -168,10 +171,9 @@ def test_tuple():
         ) -> R.Tensor((1, 16, 224, 224), dtype="float32"):
             R.func_attr({"num_input": 1})
             with R.dataflow():
-                lv: R.Tensor((16, 16, 3, 3), dtype="float32") = param1
                 conv1: R.Tensor((1, 16, 224, 224), dtype="float32") = 
R.nn.conv2d(
                     x,
-                    lv,
+                    param1,
                     strides=[1, 1],
                     padding=[1, 1, 1, 1],
                     dilation=[1, 1],
@@ -181,10 +183,9 @@ def test_tuple():
                     out_layout="NCHW",
                     out_dtype="void",
                 )
-                lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = param0
                 conv2: R.Tensor((1, 16, 224, 224), dtype="float32") = 
R.nn.conv2d(
                     conv1,
-                    lv1,
+                    param0,
                     strides=[1, 1],
                     padding=[1, 1, 1, 1],
                     dilation=[1, 1],
@@ -203,17 +204,14 @@ def test_tuple():
         ) -> R.Tuple(
             R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 
3), dtype="float32")
         ):
-            with R.dataflow():
-                lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0]
-                lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0]
-                l0: R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32")) = (lv1,)
-                l1: R.Tuple(R.Tuple(R.Tensor((16, 16, 3, 3), 
dtype="float32"))) = (l0,)
-                l2: R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32")) = l1[0]
-                lv2: R.Tensor((16, 16, 3, 3), dtype="float32") = l2[0]
-                gv: R.Tuple(
-                    R.Tensor((16, 16, 3, 3), dtype="float32"),
-                    R.Tensor((16, 16, 3, 3), dtype="float32"),
-                ) = (lv, lv2)
+            R.func_attr({"num_input": 0})
+            with R.dataflow():
+                lv = params[0]
+                lv0 = (lv,)
+                lv1 = (lv0,)
+                lv2 = params[0]
+                lv3 = params[0]
+                gv = (lv2, lv3)
                 R.output(gv)
             return gv
 
@@ -258,6 +256,7 @@ def test_condition():
             R.Tensor((16, 16, 3, 3), dtype="float32"),
             R.Tensor((), dtype="bool"),
         ):
+            R.func_attr({"num_input": 0})
             with R.dataflow():
                 lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0]
                 lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
@@ -278,13 +277,10 @@ def test_condition():
             param2: R.Tensor((), dtype="bool"),
         ) -> R.Tensor((1, 16, 224, 224), "float32"):
             R.func_attr({"num_input": 1})
-            gv: R.Tensor((), dtype="bool") = param2
-            if gv:
-                gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = param0
-                w: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
+            if param2:
+                w: R.Tensor((16, 16, 3, 3), dtype="float32") = param0
             else:
-                gv2: R.Tensor((16, 16, 3, 3), dtype="float32") = param1
-                w: R.Tensor((16, 16, 3, 3), dtype="float32") = gv2
+                w: R.Tensor((16, 16, 3, 3), dtype="float32") = param1
             with R.dataflow():
                 conv1 = R.nn.conv2d(x, w, padding=(1, 1), data_layout="NCHW", 
kernel_layout="OIHW")
                 R.output(conv1)
@@ -342,8 +338,7 @@ def test_multiple_functions():
         ) -> R.Tensor((256, 256), dtype="float32"):
             R.func_attr({"num_input": 1})
             with R.dataflow():
-                lv: R.Tensor((256, 256), dtype="float32") = param0
-                y: R.Tensor((256, 256), dtype="float32") = R.matmul(x, lv, 
out_dtype="void")
+                y: R.Tensor((256, 256), dtype="float32") = R.matmul(x, param0, 
out_dtype="void")
                 R.output(y)
             return y
 
@@ -351,6 +346,7 @@ def test_multiple_functions():
         def func1_transform_params(
             params: R.Tuple(R.Tensor((256, 256), dtype="float32"))
         ) -> R.Tuple(R.Tensor((256, 256), dtype="float32")):
+            R.func_attr({"num_input": 0})
             with R.dataflow():
                 lv: R.Tensor((256, 256), dtype="float32") = params[0]
                 lv1: R.Tensor((256, 256), dtype="float32") = 
R.permute_dims(lv, axes=[1, 0])
@@ -365,8 +361,7 @@ def test_multiple_functions():
         ) -> R.Tensor((256, 128), dtype="float32"):
             R.func_attr({"num_input": 1})
             with R.dataflow():
-                lv1: R.Tensor((256, 128), dtype="float32") = param0
-                y: R.Tensor((256, 128), dtype="float32") = R.matmul(x, lv1, 
out_dtype="void")
+                y: R.Tensor((256, 128), dtype="float32") = R.matmul(x, param0, 
out_dtype="void")
                 R.output(y)
             return y
 
@@ -374,6 +369,7 @@ def test_multiple_functions():
         def func2_transform_params(
             params: R.Tuple(R.Tensor((128, 256), dtype="float32"))
         ) -> R.Tuple(R.Tensor((256, 128), dtype="float32")):
+            R.func_attr({"num_input": 0})
             with R.dataflow():
                 lv: R.Tensor((128, 256), dtype="float32") = params[0]
                 lv1: R.Tensor((256, 128), dtype="float32") = 
R.permute_dims(lv, axes=[1, 0])
@@ -422,8 +418,7 @@ def test_stop_lifting():
         ) -> R.Tensor((256, 256), dtype="float32"):
             R.func_attr({"num_input": 1})
             with R.dataflow():
-                lv: R.Tensor((256, 256), dtype="float32") = param0
-                w1_add: R.Tensor((256, 256), dtype="float32") = R.add(lv, 
R.const(1, "float32"))
+                w1_add: R.Tensor((256, 256), dtype="float32") = R.add(param0, 
R.const(1, "float32"))
                 y: R.Tensor((256, 256), dtype="float32") = R.matmul(x, w1_add, 
out_dtype="void")
                 R.output(y)
             return y
@@ -432,6 +427,7 @@ def test_stop_lifting():
         def func1_transform_params(
             params: R.Tuple(R.Tensor((256, 256), dtype="float32"))
         ) -> R.Tuple(R.Tensor((256, 256), dtype="float32")):
+            R.func_attr({"num_input": 0})
             with R.dataflow():
                 lv: R.Tensor((256, 256), dtype="float32") = params[0]
                 lv1: R.Tensor((256, 256), dtype="float32") = 
R.permute_dims(lv, axes=[1, 0])
@@ -459,6 +455,7 @@ def test_symbolic_var_1():
     class Expected:
         @R.function
         def main_transform_params(params: R.Tuple) -> R.Tuple:
+            R.func_attr({"num_input": 0})
             with R.dataflow():
                 gv: R.Tuple = R.tuple()
                 R.output(gv)
@@ -522,6 +519,7 @@ def test_symbolic_var_2():
 
         @R.function
         def main_transform_params(params: R.Tuple) -> R.Tuple:
+            R.func_attr({"num_input": 0})
             with R.dataflow():
                 gv: R.Tuple = R.tuple()
                 R.output(gv)
@@ -603,7 +601,6 @@ def test_symbolic_var_from_shape():
                     tir_vars=R.ShapeExpr([slice_index]),
                     out_sinfo=R.Tensor([16], dtype="int32"),
                 )
-                B_slice = B_slice
                 A_scale = R.multiply(A_slice, B_slice)
                 R.output(A_scale)
             return A_scale
@@ -612,18 +609,19 @@ def test_symbolic_var_from_shape():
         def main_transform_params(
             params: R.Tuple(R.Tensor([16, 16], "int32"), 
R.Shape(["slice_index"]))
         ):
+            R.func_attr({"num_input": 0})
             slice_index = T.int64()
             cls = Expected
             with R.dataflow():
-                extra_symbolic_vars = R.ShapeExpr([slice_index])
                 B = params[0]
+                # extra_symbolic_vars = params[1]
                 B_slice = R.call_tir(
                     cls.slice,
                     [B],
                     tir_vars=R.ShapeExpr([slice_index]),
                     out_sinfo=R.Tensor([16], dtype="int32"),
                 )
-                output = (extra_symbolic_vars, B_slice)
+                output = (R.ShapeExpr([slice_index]), B_slice)
                 R.output(output)
             return output
 
@@ -652,7 +650,7 @@ def test_symbolic_var_in_param_shape():
             x: R.Tensor((1, 16, 224, "n"), "float32"),
             w1: R.Tensor((16, "m", 3, 3), "float32"),
             w2: R.Tensor((16, "m", 3, 3), "float32"),
-        ) -> R.Tensor((1, 16, 224, 224), "float32"):
+        ) -> R.Tensor((1, 16, 224, "n"), "float32"):
             m = T.int64()
             n = T.int64()
             R.func_attr({"num_input": 1})
@@ -677,11 +675,12 @@ def test_symbolic_var_in_param_shape():
         ) -> R.Tuple(
             R.Tensor((16, "m", 3, 3), dtype="float32"), R.Tensor((16, "m", 3, 
3), dtype="float32")
         ):
+            R.func_attr({"num_input": 0})
             m = T.int64()
             with R.dataflow():
-                lv: R.Tensor((16, m, 3, 3), dtype="float32") = params[1]
                 lv1: R.Tensor((16, m, 3, 3), dtype="float32") = params[0]
                 lv2: R.Tensor((16, m, 3, 3), dtype="float32") = R.add(lv1, 
R.const(1, "float32"))
+                lv: R.Tensor((16, m, 3, 3), dtype="float32") = params[1]
                 gv: R.Tuple(
                     R.Tensor((16, m, 3, 3), dtype="float32"),
                     R.Tensor((16, m, 3, 3), dtype="float32"),
@@ -694,16 +693,15 @@ def test_symbolic_var_in_param_shape():
             x: R.Tensor((1, 16, 224, "n"), dtype="float32"),
             transformed_param_0: R.Tensor((16, "m", 3, 3), dtype="float32"),
             transformed_param_1: R.Tensor((16, "m", 3, 3), dtype="float32"),
-        ) -> R.Tensor((1, 16, 224, 224), dtype="float32"):
+        ) -> R.Tensor((1, 16, 224, "n"), dtype="float32"):
             n = T.int64()
             m = T.int64()
             R.func_attr({"num_input": 1})
             with R.dataflow():
                 zeros: R.Tensor((n, n), dtype="float32") = R.zeros(R.shape([n, 
n]), dtype="float32")
-                lv: R.Tensor((16, m, 3, 3), dtype="float32") = 
transformed_param_1
                 conv1: R.Tensor((1, 16, 224, n), dtype="float32") = 
R.nn.conv2d(
                     x,
-                    lv,
+                    transformed_param_1,
                     strides=[1, 1],
                     padding=[1, 1, 1, 1],
                     dilation=[1, 1],
@@ -713,10 +711,9 @@ def test_symbolic_var_in_param_shape():
                     out_layout="NCHW",
                     out_dtype="void",
                 )
-                lv1: R.Tensor((16, m, 3, 3), dtype="float32") = 
transformed_param_0
                 conv2: R.Tensor((1, 16, 224, n), dtype="float32") = 
R.nn.conv2d(
                     conv1,
-                    lv1,
+                    transformed_param_0,
                     strides=[1, 1],
                     padding=[1, 1, 1, 1],
                     dilation=[1, 1],
@@ -770,6 +767,7 @@ def 
test_symbolic_var_defined_in_params_but_used_in_weights():
         def main_transform_params(
             params: R.Tuple(R.Tensor(("k",), dtype="float32"))
         ) -> R.Tuple(R.Tensor(dtype="float32", ndim=1)):
+            R.func_attr({"num_input": 0})
             k = T.int64()
             with R.dataflow():
                 lv: R.Tensor((k,), dtype="float32") = params[0]


Reply via email to