This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e715814985 [Relax][Transform] Preserve param names in LiftTransformParams (#16594)
e715814985 is described below
commit e715814985fc88d813eabfa9ada364bfcadc7bec
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Feb 23 08:36:33 2024 -0600
[Relax][Transform] Preserve param names in LiftTransformParams (#16594)
* [Relax][Transform] Preserve param names in LiftTransformParams
The `relax.transform.LiftTransformParams` pass splits apart a relax
function, extracting the steps that could be performed at
compile-time. Prior to this commit, the transformed parameters were
named `param0`, `param1`, and so on.
This commit updates the `LiftTransformParams` pass to preserve any
human-readable parameter names. The parameter names of the updated
function are taken from the original parameter names when no
transformation is performed, or from the internal variable bindings when
a transformation is applied. This implementation uses `LambdaLift`
internally, relying on the changes made in
https://github.com/apache/tvm/pull/16306.
* Update based on review comments
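As a rough illustration of the intended behavior (a minimal sketch derived
from the updated unit tests; the module, parameter names, and shapes below
are illustrative and not part of this commit), a weight that is transformed
at compile time now produces a runtime parameter named after the lifted
binding rather than a generic `param0`:

    import tvm
    from tvm import relax
    from tvm.script import relax as R

    @tvm.script.ir_module
    class Module:
        @R.function
        def main(
            x: R.Tensor((256, 256), "float32"),
            weight: R.Tensor((256, 256), "float32"),
        ) -> R.Tensor((256, 256), "float32"):
            # Parameters after the first are model weights.
            R.func_attr({"num_input": 1})
            with R.dataflow():
                # Depends only on `weight`, so it is computable at compile time.
                weight_t = R.permute_dims(weight, axes=[1, 0])
                out = R.matmul(x, weight_t)
                R.output(out)
            return out

    # After the pass, the module contains `main_transform_params`, which performs
    # the transpose at compile time, while the runtime `main` takes a parameter
    # named after the lifted binding (e.g. `weight_t`) instead of `param0`.
    lifted = relax.transform.LiftTransformParams()(Module)
    print(lifted)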
---
src/relax/transform/lift_transform_params.cc | 637 +++++++++++----------
.../relax/test_transform_lift_transform_params.py | 86 ++-
2 files changed, 380 insertions(+), 343 deletions(-)
diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc
index b500a3c3a3..15b60f5492 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -18,7 +18,7 @@
*/
/*!
- * \file tvm/relax/transform/lambda_lift.cc
+ * \file tvm/relax/transform/lift_transform_params.cc
* \brief Lift local functions into global functions.
*/
@@ -29,6 +29,7 @@
#include <tvm/runtime/logging.h>
#include <iostream>
+#include <tuple>
#include <vector>
#include "../../support/ordered_set.h"
@@ -37,405 +38,443 @@
namespace tvm {
namespace relax {
-/*! \brief Plan of lifting transform params */
-struct LiftTransformParamsInfoPlan {
- Function f_transform_params; // the lifted function that transforms the parameters
- std::unordered_map<Var, int, ObjectPtrHash, ObjectPtrEqual>
- output_to_index; // the index of the original bindings in the output tuple
- std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>
- lifted_bindings; // the bindings of the original function that are lifted
-};
+namespace {
-/*! \brief Builder of the function that transforms the parameters. */
-class TransformParamsFuncBuilder : public ExprMutator {
- public:
- TransformParamsFuncBuilder() { builder_->BeginDataflowBlock(); }
+struct CollectInfo {
+ /* \brief The analyzed function */
+ Function orig_func;
+
+ /* \brief The number of parameters unknown until runtime */
+ size_t num_runtime_params;
+
+ /*! \brief Bindings that can be lifted out into a pre-processing
+ *
+ * - All bindings in `computable_at_compile_time` are suitable for
+ * use in a DataflowBlock.
+ *
+ * - Do not depend on any parameter prior to attr::kNumInput.
+ *
+ * - Does not include "relax.builtin.stop_lift_params"
+ */
+ std::vector<Binding> computable_at_compile_time;
- /*! \brief Add a input parameter. */
- void AddInput(const Var& var) {
- inputs_.push_back(var);
- lifted_binding_lookup_.insert(var);
+ /*! \brief Variables that are required at runtime */
+ std::unordered_set<Variant<relax::Var, tir::Var>, ObjectPtrHash, ObjectPtrEqual>
+ required_at_runtime;
+
+ Array<Var> GetCompileTimeInputs() const {
+ return Array<Var>(orig_func->params.begin() + num_runtime_params, orig_func->params.end());
}
- void UpdateBasedOnRuntimeInput(const Var& var) {
- for (const auto& var : DefinableTIRVarsInStructInfo(GetStructInfo(var))) {
- known_symbolic_var_during_inference_.insert(var);
- }
- for (const auto& var : TIRVarsInStructInfo(GetStructInfo(var))) {
- required_symbolic_var_during_inference_.insert(var);
- }
+ Array<Var> GetRuntimeInputs() const {
+ return Array<Var>(orig_func->params.begin(), orig_func->params.begin() + num_runtime_params);
}
- /*! \brief Add a binding to lift. */
- void AddInternalBinding(const VarBinding& binding) {
- bindings_.push_back(binding);
- lifted_binding_lookup_.insert(binding->var);
+ Array<tir::Var> GetPropagatedSymbolicVariables() const {
+ auto vars_from_any_param =
+ DefinableTIRVarsInStructInfo(TupleStructInfo(orig_func->params.Map(GetStructInfo)));
+
+ auto vars_from_runtime_params =
+ [&]() -> std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> {
+ auto tir_var_vec =
+ DefinableTIRVarsInStructInfo(TupleStructInfo(GetRuntimeInputs().Map(GetStructInfo)));
+ return {tir_var_vec.begin(), tir_var_vec.end()};
+ }();
+
+ auto vars_from_transformed_params =
+ [&]() -> std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> {
+ auto tir_var_vec =
+ DefinableTIRVarsInStructInfo(TupleStructInfo(GetCompileTimeOutputs().Map(GetStructInfo)));
+ return {tir_var_vec.begin(), tir_var_vec.end()};
+ }();
+
+ Array<tir::Var> output;
+ for (const auto& tir_var : vars_from_any_param) {
+ if (required_at_runtime.count(tir_var) && !vars_from_runtime_params.count(tir_var) &&
+ !vars_from_transformed_params.count(tir_var)) {
+ output.push_back(tir_var);
+ }
+ }
+ return output;
}
- /*! \brief Update based on bindings not being lifted. */
- void UpdateBasedOnRuntimeBinding(const VarBinding& binding) {
- for (const auto& producer : FreeVars(binding->value)) {
- // An external value that uses a lifted binding requires the
- // lifted binding to be returned as output.
- if (lifted_binding_lookup_.count(producer)) {
- outputs_.insert(producer);
+ Array<Var> GetCompileTimeOutputs() const {
+ Array<Var> params;
- for (const auto& var : DefinableTIRVarsInStructInfo(GetStructInfo(producer))) {
- known_symbolic_var_during_inference_.insert(var);
- }
+ // Any value that is available at compile-time, but is also
+ // required at runtime, must be passed through the compile-time
+ // function.
+ for (size_t i = num_runtime_params; i < orig_func->params.size(); i++) {
+ Var var = orig_func->params[i];
+ if (required_at_runtime.count(var)) {
+ params.push_back(var);
}
}
- // All TIR variables used in the binding must be available at runtime.
- for (const auto& var : FreeSymbolicVars(binding->value)) {
- required_symbolic_var_during_inference_.insert(var);
+ // Any variable that is computed at compile-time, but is required
+ // at runtime, must be provided as a parameter.
+ for (const auto& binding : computable_at_compile_time) {
+ if (required_at_runtime.count(binding->var)) {
+ params.push_back(binding->var);
+ }
}
- }
- bool UsesOnlyLiftableProducers(const Expr& expr) {
- auto producers = FreeVars(expr);
- bool uses_only_liftable_producers = [&]() {
- return std::all_of(producers.begin(), producers.end(),
- [&](const auto& var) { return lifted_binding_lookup_.count(var); });
- }();
- return uses_only_liftable_producers;
+ return params;
}
- /*!
- * \brief Build the function that transforms the parameters
- * \return The created function, and a map from the variable in the original function to the index
- * of the element of the output tuple
- */
- std::pair<Function, std::unordered_map<Var, int, ObjectPtrHash, ObjectPtrEqual>> Build() {
- Array<PrimExpr> extra_symbolic_vars;
- for (const auto& var : required_symbolic_var_during_inference_) {
- if (!known_symbolic_var_during_inference_.count(var)) {
- extra_symbolic_vars.push_back(var);
- }
- }
+ Function MakeCompileTimeFunction() const {
+ auto compile_time_params = GetCompileTimeInputs();
- Array<StructInfo> input_sinfo;
- Array<Expr> output_vars;
- std::unordered_map<Var, int, ObjectPtrHash, ObjectPtrEqual> output_to_index;
+ Array<Binding> output_var_binding;
+ Array<Expr> output_exprs;
- for (const auto& input : inputs_) {
- input_sinfo.push_back(Downcast<StructInfo>(input->struct_info_.value()));
+ // Any symbolic variables that are inferrable from compile-time
+ // parameters, but are not inferrable from run-time parameters,
+ // must be propagated to the output.
+ if (auto propagated_tir_vars = GetPropagatedSymbolicVariables(); propagated_tir_vars.size()) {
+ output_exprs.push_back(
+ ShapeExpr(propagated_tir_vars.Map([](tir::Var var) -> PrimExpr { return var; })));
}
- Var params("params", TupleStructInfo(input_sinfo));
- if (extra_symbolic_vars.size()) {
- output_vars.push_back(builder_->Emit(ShapeExpr(extra_symbolic_vars), "extra_symbolic_vars"));
+ for (const auto& var : GetCompileTimeOutputs()) {
+ Var out_var(var->name_hint() + "_output", GetStructInfo(var));
+ output_var_binding.push_back(VarBinding(out_var, var));
+ output_exprs.push_back(out_var);
}
- // Helper to add a variable to the output tuple
- // original_var: the binding variable in the original function
- // output_var: the variable, which is a binding in the transform_params function, that is added
- // to the output tuple
- auto f_add_output = [&](const Var& original_var, const Var& output_var) -> void {
- output_to_index[original_var] = output_vars.size();
- output_vars.push_back(output_var);
- };
+ Var tuple_var("output_tuple", TupleStructInfo(output_exprs.Map(GetStructInfo)));
+ output_var_binding.push_back(VarBinding(tuple_var, Tuple(output_exprs)));
+
+ SeqExpr body(
+ {
+ DataflowBlock(computable_at_compile_time),
+ DataflowBlock(output_var_binding),
+ },
+ tuple_var);
+
+ Function func(compile_time_params, body, GetStructInfo(tuple_var));
+ func = WithAttr(func, attr::kNumInput, Integer(0));
+ func = CopyWithNewVars(func);
+ func = Downcast<Function>(CanonicalizeBindings(func));
+ return func;
+ }
- // Create mapping from the original input variables to the TupleGetItem from the packed
- // parameter tuple Add the parameters that are marked as the output of the function to the
- // output tuple
- for (const auto& input : inputs_) {
- input_remap_.emplace(input.get(), TupleGetItem(params, input_remap_.size()));
- if (outputs_.count(input)) {
- auto output_var = builder_->Emit(input_remap_.at(input.get()));
- f_add_output(input, output_var);
- }
+ Function MakeRuntimeFunction() const {
+ Array<Binding> bindings;
+
+ // Any parameter that isn't available until runtime must be an
+ // input, along with any output from the compile-time function.
+ // Compile-time outputs must have a fresh non-dataflow var to
+ // serve as the parameter. This trivial binding will later be
+ // removed with CanonicalizeBindings.
+ Array<Var> params = GetRuntimeInputs();
+ if (auto propagated_tir_vars = GetPropagatedSymbolicVariables(); propagated_tir_vars.size()) {
+ ShapeStructInfo shape_sinfo(
+ propagated_tir_vars.Map([](tir::Var var) -> PrimExpr { return var; }));
+ Var shape_expr("vars_from_compile_time_params", shape_sinfo);
+ params.push_back(shape_expr);
+ }
+ for (const auto& var : GetCompileTimeOutputs()) {
+ Var param_var(var->name_hint(), GetStructInfo(var));
+ bindings.push_back(VarBinding(var, param_var));
+ params.push_back(param_var);
}
- // Re-emit the bindings that are lifted. Update the output tuple if the binding is marked as the
- // output.
- for (const auto& binding : bindings_) {
- if (outputs_.count(binding->var)) {
- auto output_var = builder_->Emit(VisitExpr(binding->value));
- var_remap_[binding->var->vid] = output_var;
- f_add_output(binding->var, output_var);
- } else {
- VisitBinding(binding);
+ // Any binding that is computable at compile-time should be
+ // suppressed at run-time.
+ struct SuppressCompileTime : ExprMutator {
+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> to_suppress;
+ explicit SuppressCompileTime(const std::vector<Binding>& bindings) {
+ for (const auto& binding : bindings) {
+ to_suppress.insert(binding->var);
+ }
}
- }
- // Create the function.
- Expr transformed_params = builder_->EmitOutput(Tuple(output_vars));
- BindingBlock block = builder_->EndBlock();
- Expr body = VisitWithNewScope(SeqExpr({block}, transformed_params), Array<Var>{params});
- Function f_transform_params =
- Function(/*params=*/{params}, /*body=*/body, /*ret_struct_info=*/NullOpt);
- return {f_transform_params, output_to_index};
- }
+ void VisitBinding(const Binding& binding) override {
+ if (!to_suppress.count(binding->var)) {
+ ExprMutator::VisitBinding(binding);
+ }
+ }
- Expr VisitExpr_(const VarNode* var) final {
- if (auto it = input_remap_.find(var); it != input_remap_.end()) {
- return builder_->Emit((*it).second);
- } else {
- return ExprMutator::VisitExpr_(var);
- }
+ using ExprMutator::VisitExpr_;
+ Expr VisitExpr_(const CallNode* call) override {
+ static const Op& stop_lift_params_op = Op::Get("relax.builtin.stop_lift_params");
+ if (call->op.same_as(stop_lift_params_op)) {
+ return VisitExpr(call->args[0]);
+ } else {
+ return ExprMutator::VisitExpr_(call);
+ }
+ }
+ };
+ Expr body = SuppressCompileTime(computable_at_compile_time)(orig_func->body);
+ body = SeqExpr({DataflowBlock(bindings)}, body);
+
+ Function func(params, body, orig_func->ret_struct_info, orig_func->is_pure, orig_func->attrs);
+ func = WithoutAttr(func, tvm::attr::kGlobalSymbol);
+ func = CopyWithNewVars(func);
+ return func;
}
- // The input parameters of the function.
- Array<Var> inputs_;
- // Remap from the original input variable to TupleGetItem from the packed parameter tuple, which
- // is the input of the lifted function.
- std::unordered_map<const VarNode*, Expr> input_remap_;
- // The bindings that are lifted.
- Array<VarBinding> bindings_;
- // The variables that are marked as the output of the function.
- std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> outputs_;
+ Function MakePartitionedFunction() const {
+ Array<Binding> inner_func_bindings;
+ Var compile_time_func = [&]() {
+ auto func = MakeCompileTimeFunction();
+ Var var("transform_params", GetStructInfo(func));
+ inner_func_bindings.push_back(VarBinding(var, std::move(func)));
+ return var;
+ }();
+ Var runtime_func = [&]() {
+ auto func = MakeRuntimeFunction();
+ Var var("runtime", GetStructInfo(func));
+ inner_func_bindings.push_back(VarBinding(var, std::move(func)));
+ return var;
+ }();
- // The bindings that are lifted
- std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> lifted_binding_lookup_;
+ Array<Binding> calling_scope;
- /* Symbolic variables that are known during the transform_params execution.
- *
- * This set is populated based on the variables declared with
- * AddInput, and contains variables that may appear inside the
- * transformation function. A binding that depends on a symbolic
- * variable not contained in this set may not be lifted.
- */
- support::OrderedSet<tir::Var> known_symbolic_var_during_transform_;
+ Call compile_time_preprocess(
+ compile_time_func, GetCompileTimeInputs().Map([](const Var& var) -> Expr { return var; }));
- /* Symbolic variables that are known during the runtime
- *
- * This set is populated based on the variables declared with
- * UpdateBasedOnRuntimeInput, and contains variables that are
- * defined at runtime. A variable that present in
- * required_symbolic_var_during_inference_, but not present in this
- * set, causes the Build() function to output an additional
- * R.ShapeExpr in order to propagate the symbolic variables.
- */
- support::OrderedSet<tir::Var> known_symbolic_var_during_inference_;
+ // Use a fresh variable in case it is passed through unmodified in
+ // the compile-time function.
+ Array<Var> compile_time_outputs;
+ if (auto propagated_tir_vars = GetPropagatedSymbolicVariables(); propagated_tir_vars.size()) {
+ ShapeStructInfo shape_sinfo(
+ propagated_tir_vars.Map([](tir::Var var) -> PrimExpr { return var; }));
+ Var shape_expr("vars_from_compile_time_params", shape_sinfo);
+ compile_time_outputs.push_back(shape_expr);
+ }
+ for (const auto& relax_var : GetCompileTimeOutputs()) {
+ compile_time_outputs.push_back(
+ Var(relax_var->name_hint(), GetStructInfo(relax_var), relax_var->span));
+ }
+ {
+ Var tuple_output("compile_time_output",
+ TupleStructInfo(compile_time_outputs.Map(GetStructInfo)));
+ calling_scope.push_back(VarBinding(tuple_output, compile_time_preprocess));
+ for (size_t i = 0; i < compile_time_outputs.size(); i++) {
+ calling_scope.push_back(VarBinding(compile_time_outputs[i], TupleGetItem(tuple_output, i)));
+ }
+ }
- /* Symbolic variables that must be known at runtime
- *
- * This set is populated based on the variables used in external
- * bindings. A variable that is present here, but not present in
- * known_symbolic_var_during_inference_, must be provided as an
- * additional R.ShapeExpr parameter from the transform_params
- * function.
- */
- support::OrderedSet<tir::Var> required_symbolic_var_during_inference_;
+ Array<Expr> runtime_args = GetRuntimeInputs().Map([](const Var& var) -> Expr { return var; });
+ for (const auto& var : compile_time_outputs) {
+ runtime_args.push_back(var);
+ }
+
+ Call runtime_execution(runtime_func, runtime_args);
+ Var output_var("output", orig_func->ret_struct_info);
+ calling_scope.push_back(VarBinding(output_var, runtime_execution));
+
+ SeqExpr body(
+ {
+ BindingBlock(inner_func_bindings),
+ DataflowBlock(calling_scope),
+ },
+ output_var);
+
+ Function func = orig_func;
+ func.CopyOnWrite()->body = body;
+ func = Downcast<Function>(CanonicalizeBindings(func));
+ return func;
+ }
};
-/*!
- * \brief Visitor that creates the plan of lifting transform params.
- *
- * Starting from the parameters of the function (they are the initial set of lifted bindings), we
- * will visit the body of the function to find the bindings that can be lifted. A binding can be
- * lifted if all the variables that it depends on are also lifted.
- *
- * When a binding cannot be lifted, all the variables that 1) it depends on, and 2) have been
- * lifted, will be marked as the boundary variable and will be in the output of the lifted function.
- */
-class LiftTransformParamsPlanner : public ExprVisitor {
+class LiftableBindingCollector : ExprVisitor {
public:
- LiftTransformParamsInfoPlan Plan(const Function& function, int num_inputs) {
- for (int i = 0; i < static_cast<int>(function->params.size()); ++i) {
- if (i < num_inputs) {
- builder_.UpdateBasedOnRuntimeInput(function->params[i]);
- } else {
- builder_.AddInput(function->params[i]);
- if (function->params[i]->struct_info_.defined()) {
- Array<tir::Var> symbolic_vars = DefinableTIRVarsInStructInfo(
- Downcast<StructInfo>(function->params[i]->struct_info_.value()));
- for (const auto& var : symbolic_vars) {
- param_symbolic_vars_.insert(var);
- }
- }
- }
+ static CollectInfo Collect(const Function& func) {
+ LiftableBindingCollector visitor;
+ visitor(func);
+ visitor.info_.orig_func = func;
+ return visitor.info_;
+ }
+
+ private:
+ void VisitExpr_(const FunctionNode* func) override {
+ size_t num_runtime_params = func->params.size();
+ if (auto opt = func->attrs.GetAttr<Integer>(attr::kNumInput)) {
+ num_runtime_params = opt.value()->value;
}
- VisitExpr(function->body);
- const auto& [f_transform_params, output_to_index] = builder_.Build();
- return {f_transform_params, output_to_index, std::move(builder_.lifted_binding_lookup_)};
+ info_.num_runtime_params = num_runtime_params;
+
+ for (size_t i = num_runtime_params; i < func->params.size(); i++) {
+ liftable_vars_.insert(func->params[i]);
+ for (const auto& tir_var : DefinableTIRVarsInStructInfo(GetStructInfo(func->params[i]))) {
+ liftable_vars_.insert(tir_var);
+ }
+ }
+ ExprVisitor::VisitExpr_(func);
}
- private:
void VisitBindingBlock_(const DataflowBlockNode* block) final {
+ bool cache = is_in_dataflow_block_;
is_in_dataflow_block_ = true;
ExprVisitor::VisitBindingBlock_(block);
- is_in_dataflow_block_ = false;
+ is_in_dataflow_block_ = cache;
+ }
+
+ void VisitBinding(const Binding& binding) override {
+ if (CanLiftBinding(binding)) {
+ info_.computable_at_compile_time.push_back(binding);
+ liftable_vars_.insert(binding->var);
+ } else {
+ info_.required_at_runtime.insert(binding->var);
+ auto bound_value = GetBoundValue(binding);
+ for (const auto& upstream_var : FreeVars(bound_value)) {
+ info_.required_at_runtime.insert(upstream_var);
+ }
+ for (const auto& tir_var : FreeSymbolicVars(bound_value)) {
+ info_.required_at_runtime.insert(tir_var);
+ }
+ }
}
- void VisitBinding_(const VarBindingNode* binding) final {
- bool can_lift = true;
+ bool CanLiftBinding(const Binding& binding) const {
+ auto value = GetBoundValue(binding);
// Cond 1. Do not lift bindings outside dataflow blocks.
if (!is_in_dataflow_block_) {
- can_lift = false;
+ return false;
}
// Cond 2. Do not lift regarding the "builtin.stop_lift_params" op.
- if (const auto* call = binding->value.as<CallNode>()) {
+ if (const auto* call = value.as<CallNode>()) {
static const Op& stop_lift_params_op = Op::Get("relax.builtin.stop_lift_params");
if (call->op.same_as(stop_lift_params_op)) {
- can_lift = false;
+ return false;
}
}
// Cond 3. Do not lift when involving Vars that are not liftable.
- auto producers = FreeVars(binding->value);
- bool uses_only_liftable_producers = builder_.UsesOnlyLiftableProducers(binding->value);
- if (!uses_only_liftable_producers) {
- can_lift = false;
+ for (const auto& var : FreeVars(value)) {
+ if (!liftable_vars_.count(var)) {
+ return false;
+ }
}
// Cond 4. Do not lift when its struct info contains symbolic variables that do not appear in
// params.
for (const auto& var : TIRVarsInStructInfo(GetStructInfo(binding->var))) {
- if (!param_symbolic_vars_.count(var)) {
- can_lift = false;
+ if (!liftable_vars_.count(var)) {
+ return false;
}
}
// Cond 5. Do not lift declarations of external functions
- if (binding->value.as<relax::ExternFuncNode>()) {
- can_lift = false;
+ if (value.as<relax::ExternFuncNode>()) {
+ return false;
}
- if (can_lift) {
- builder_.AddInternalBinding(GetRef<VarBinding>(binding));
- } else {
- builder_.UpdateBasedOnRuntimeBinding(GetRef<VarBinding>(binding));
- }
+ return true;
}
- // The builder of the function that transforms the parameters
- TransformParamsFuncBuilder builder_;
- // Whether we are in a dataflow block
+ CollectInfo info_;
+ std::unordered_set<Variant<Var, tir::Var>, ObjectPtrHash, ObjectPtrEqual> liftable_vars_;
bool is_in_dataflow_block_{false};
- // The symbolic variables in the parameters
- std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> param_symbolic_vars_;
};
-/*!
- *\brief The rewriter that lifts the transform params of a function and updates the original
- * function.
- */
-class TransformParamsLifter : ExprMutator {
+class PreprocessPartitioner : public ExprMutator {
public:
- explicit TransformParamsLifter(const IRModule& module) : ExprMutator(module) {}
-
- Function VisitFunction(GlobalVar gvar, Function func) {
- current_gvar_ = gvar;
- auto out = Downcast<Function>(VisitExpr(std::move(func)));
- current_gvar_ = NullOpt;
- return out;
- }
-
- Map<GlobalVar, Function> GetTransformParamFunctions() const { return transform_param_funcs_; }
-
- private:
+ using ExprMutator::VisitExpr_;
Expr VisitExpr_(const FunctionNode* op) override {
auto func = GetRef<Function>(op);
- Optional<Integer> opt_num_input = func->attrs.GetAttr<Integer>(attr::kNumInput);
- if (!opt_num_input) {
+ if (func->attrs.GetAttr<Integer>(attr::kNumInput)) {
+ auto info = LiftableBindingCollector::Collect(func);
+ return info.MakePartitionedFunction();
+ } else {
return func;
}
- auto signed_num_input = opt_num_input.value()->value;
- ICHECK_GE(signed_num_input, 0);
- ICHECK_LE(signed_num_input, func->params.size());
- size_t num_input = signed_num_input;
-
- LiftTransformParamsPlanner planner;
-
- // Step 1: Create the plan of lifting transform params
- lift_plan_ = planner.Plan(func, num_input);
-
- // Step 2: Stash the lifted function to add to the module
- transform_param_funcs_.Set(current_gvar_.value(), lift_plan_.f_transform_params);
-
- // Step 3: Update the current function.
-
- // Step 3.1: Update the function signature
- Array<StructInfo> param_fields =
- Downcast<TupleStructInfo>(lift_plan_.f_transform_params->ret_struct_info)->fields;
-
- Array<Var> new_params(func->params.begin(), func->params.begin() + num_input);
- for (size_t i = 0; i < param_fields.size(); i++) {
- std::stringstream name;
- name << "transformed_param_" << i;
- Var param(name.str(), param_fields[i]);
- new_params.push_back(param);
- }
-
- // Step 3.2: Update the function body
- for (const auto& [var, index] : lift_plan_.output_to_index) {
- ICHECK_LT(num_input + index, new_params.size());
- param_remap_[var] = new_params[num_input + index];
- }
- auto new_body = VisitWithNewScope(func->body, new_params);
-
- return Function(new_params, new_body, func->ret_struct_info, func->is_pure, func->attrs);
- }
-
- void VisitBinding_(const VarBindingNode* binding) final {
- if (lift_plan_.lifted_bindings.count(binding->var)) {
- return;
- }
- if (const auto* call = binding->value.as<CallNode>()) {
- static const Op& stop_lift_params_op = Op::Get("relax.builtin.stop_lift_params");
- if (call->op.same_as(stop_lift_params_op)) {
- var_remap_[binding->var->vid] = Downcast<Var>(VisitExpr(call->args[0]));
- return;
- }
- }
- ExprMutator::VisitBinding_(binding);
- }
-
- Expr VisitExpr_(const VarNode* var) final {
- auto it = param_remap_.find(GetRef<Var>(var));
- if (it != param_remap_.end()) {
- return builder_->Emit(it->second);
- }
- return ExprMutator::VisitExpr_(var);
}
+};
- // Remap the original parameters to TupleGetItem from the packed tuple of transformed parameters.
- std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> param_remap_;
- // The plan of lifting the transform params
- LiftTransformParamsInfoPlan lift_plan_;
+// Adapted from https://stackoverflow.com/a/2072890
+inline bool ends_with(const std::string& value, const std::string& ending) {
+ return ending.size() <= value.size() &&
+ std::equal(ending.rbegin(), ending.rend(), value.rbegin());
+}
- Map<GlobalVar, Function> transform_param_funcs_;
- Optional<GlobalVar> current_gvar_;
-};
+} // namespace
namespace transform {
-Pass LiftTransformParams() {
- runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
- PassContext pc) {
- TransformParamsLifter mutator(mod);
+
+Pass PartitionTransformParams() {
+ auto pass_func = [=](IRModule mod, PassContext pc) {
+ PreprocessPartitioner mutator;
IRModule updates;
for (const auto& [gvar, func] : mod->functions) {
if (auto opt = func.as<relax::Function>()) {
- auto new_func = mutator.VisitFunction(gvar, opt.value());
+ auto new_func = Downcast<Function>(mutator(opt.value()));
if (!new_func.same_as(func)) {
updates->Add(gvar, new_func);
}
}
}
- for (auto [gvar, transform_func] : mutator.GetTransformParamFunctions()) {
- String name = gvar->name_hint + "_transform_params";
- GlobalVar new_gvar(name);
- new_gvar->struct_info_ = transform_func->struct_info_;
- transform_func = CopyWithNewVars(transform_func);
- transform_func = WithAttr(transform_func, tvm::attr::kGlobalSymbol, name);
+ if (updates->functions.size()) {
+ mod.CopyOnWrite()->Update(updates);
+ }
+
+ return mod;
+ };
+ return tvm::transform::CreateModulePass(pass_func, 1, "PartitionTransformParams", {});
+}
- updates->Add(new_gvar, transform_func);
+Pass LiftTransformParams() {
+ // A post-proc utility used as the third step in LiftTransformParams
+ //
+ // 1. PartitionTransformParams: Partition each function into a
+ // compile-time and run-time lambda functions.
+ //
+ // 2. LambdaLift: Lift the compile-time and run-time lambda
+ // functions out of the end-to-end function.
+ //
+ // 3. Post-proc: Expose the compile-time and run-time functions for
+ // external use, replacing the end-to-end functions.
+ auto post_proc_func = [=](IRModule mod, PassContext pc) {
+ std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual> to_remove;
+ std::unordered_map<GlobalVar, Function, ObjectPtrHash, ObjectPtrEqual> to_add;
+ for (const auto& [gvar, base_func] : mod->functions) {
+ if (auto opt = base_func.as<Function>()) {
+ auto func = opt.value();
+
+ std::string func_name = gvar->name_hint;
+ if (ends_with(func_name, "transform_params")) {
+ func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint);
+ func = BundleModelParams(func);
+ to_add[gvar] = func;
+ } else if (ends_with(func_name, "_runtime")) {
+ std::string name(func_name.begin(), func_name.end() - sizeof("_runtime") + 1);
+ to_remove.insert(mod->GetGlobalVar(name));
+ to_remove.insert(gvar);
+ to_add[GlobalVar(name)] = WithAttr(func, tvm::attr::kGlobalSymbol, String(name));
+ }
+ }
}
- if (updates->functions.size()) {
- mod.CopyOnWrite()->Update(updates);
+ if (to_remove.size() || to_add.size()) {
+ auto write_ptr = mod.CopyOnWrite();
+ for (const auto& gvar : to_remove) {
+ write_ptr->Remove(gvar);
+ }
+ for (const auto& [gvar, func] : to_add) {
+ write_ptr->Add(gvar, func);
+ }
}
return mod;
};
- return CreateModulePass(pass_func, 1, "LiftTransformParams", {});
+ auto post_proc =
+ tvm::transform::CreateModulePass(post_proc_func, 1, "LiftTransformParamsPostProc", {});
+
+ return tvm::transform::Sequential(
+ {
+ PartitionTransformParams(),
+ LambdaLift(),
+ post_proc,
+ },
+ "LiftTransformParams");
}
TVM_REGISTER_GLOBAL("relax.transform.LiftTransformParams").set_body_typed(LiftTransformParams);
diff --git a/tests/python/relax/test_transform_lift_transform_params.py b/tests/python/relax/test_transform_lift_transform_params.py
index 5b24614469..8042765d40 100644
--- a/tests/python/relax/test_transform_lift_transform_params.py
+++ b/tests/python/relax/test_transform_lift_transform_params.py
@@ -64,15 +64,14 @@ def test_basic():
@R.function
def main(
x: R.Tensor((1, 3, 224, 224), dtype="float32"),
- param0: R.Tensor((16, 16, 3, 3), dtype="float32"),
- param1: R.Tensor((16, 3, 3, 3), dtype="float32"),
+ w2: R.Tensor((16, 16, 3, 3), dtype="float32"),
+ w1_transformed: R.Tensor((16, 3, 3, 3), dtype="float32"),
) -> R.Tensor((1, 16, 224, 224), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
- param1 = param1
conv1: R.Tensor((1, 16, 224, 224), dtype="float32") = R.nn.conv2d(
x,
- param1,
+ w1_transformed,
strides=[1, 1],
padding=[1, 1, 1, 1],
dilation=[1, 1],
@@ -82,10 +81,9 @@ def test_basic():
out_layout="NCHW",
out_dtype="void",
)
- param0 = param0
conv2: R.Tensor((1, 16, 224, 224), dtype="float32") = R.nn.conv2d(
conv1,
- param0,
+ w2,
strides=[1, 1],
padding=[1, 1, 1, 1],
dilation=[1, 1],
@@ -117,15 +115,16 @@ def test_basic():
) -> R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32")
):
+ R.func_attr({"num_input": 0})
cls = Expected
with R.dataflow():
- lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0]
lv2 = R.call_tir(
cls.transform_layout_IOHW_to_OIHW,
(lv1,),
out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
)
+ lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
gv: R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"),
R.Tensor((16, 3, 3, 3), dtype="float32"),
@@ -137,6 +136,10 @@ def test_basic():
after = relax.transform.LiftTransformParams()(mod)
tvm.ir.assert_structural_equal(after, Expected)
+ names_after = [param.name_hint for param in after["main"].params]
+ names_expected = [param.name_hint for param in Expected["main"].params]
+ assert names_after == names_expected
+
def test_tuple():
@tvm.script.ir_module
@@ -168,10 +171,9 @@ def test_tuple():
) -> R.Tensor((1, 16, 224, 224), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
- lv: R.Tensor((16, 16, 3, 3), dtype="float32") = param1
conv1: R.Tensor((1, 16, 224, 224), dtype="float32") = R.nn.conv2d(
x,
- lv,
+ param1,
strides=[1, 1],
padding=[1, 1, 1, 1],
dilation=[1, 1],
@@ -181,10 +183,9 @@ def test_tuple():
out_layout="NCHW",
out_dtype="void",
)
- lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = param0
conv2: R.Tensor((1, 16, 224, 224), dtype="float32") = R.nn.conv2d(
conv1,
- lv1,
+ param0,
strides=[1, 1],
padding=[1, 1, 1, 1],
dilation=[1, 1],
@@ -203,17 +204,14 @@ def test_tuple():
) -> R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32")
):
- with R.dataflow():
- lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0]
- lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0]
- l0: R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32")) = (lv1,)
- l1: R.Tuple(R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32"))) = (l0,)
- l2: R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32")) = l1[0]
- lv2: R.Tensor((16, 16, 3, 3), dtype="float32") = l2[0]
- gv: R.Tuple(
- R.Tensor((16, 16, 3, 3), dtype="float32"),
- R.Tensor((16, 16, 3, 3), dtype="float32"),
- ) = (lv, lv2)
+ R.func_attr({"num_input": 0})
+ with R.dataflow():
+ lv = params[0]
+ lv0 = (lv,)
+ lv1 = (lv0,)
+ lv2 = params[0]
+ lv3 = params[0]
+ gv = (lv2, lv3)
R.output(gv)
return gv
@@ -258,6 +256,7 @@ def test_condition():
R.Tensor((16, 16, 3, 3), dtype="float32"),
R.Tensor((), dtype="bool"),
):
+ R.func_attr({"num_input": 0})
with R.dataflow():
lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0]
lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
@@ -278,13 +277,10 @@ def test_condition():
param2: R.Tensor((), dtype="bool"),
) -> R.Tensor((1, 16, 224, 224), "float32"):
R.func_attr({"num_input": 1})
- gv: R.Tensor((), dtype="bool") = param2
- if gv:
- gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = param0
- w: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
+ if param2:
+ w: R.Tensor((16, 16, 3, 3), dtype="float32") = param0
else:
- gv2: R.Tensor((16, 16, 3, 3), dtype="float32") = param1
- w: R.Tensor((16, 16, 3, 3), dtype="float32") = gv2
+ w: R.Tensor((16, 16, 3, 3), dtype="float32") = param1
with R.dataflow():
conv1 = R.nn.conv2d(x, w, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW")
R.output(conv1)
@@ -342,8 +338,7 @@ def test_multiple_functions():
) -> R.Tensor((256, 256), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
- lv: R.Tensor((256, 256), dtype="float32") = param0
- y: R.Tensor((256, 256), dtype="float32") = R.matmul(x, lv, out_dtype="void")
+ y: R.Tensor((256, 256), dtype="float32") = R.matmul(x, param0, out_dtype="void")
R.output(y)
return y
@@ -351,6 +346,7 @@ def test_multiple_functions():
def func1_transform_params(
params: R.Tuple(R.Tensor((256, 256), dtype="float32"))
) -> R.Tuple(R.Tensor((256, 256), dtype="float32")):
+ R.func_attr({"num_input": 0})
with R.dataflow():
lv: R.Tensor((256, 256), dtype="float32") = params[0]
lv1: R.Tensor((256, 256), dtype="float32") = R.permute_dims(lv, axes=[1, 0])
@@ -365,8 +361,7 @@ def test_multiple_functions():
) -> R.Tensor((256, 128), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
- lv1: R.Tensor((256, 128), dtype="float32") = param0
- y: R.Tensor((256, 128), dtype="float32") = R.matmul(x, lv1, out_dtype="void")
+ y: R.Tensor((256, 128), dtype="float32") = R.matmul(x, param0, out_dtype="void")
R.output(y)
return y
@@ -374,6 +369,7 @@ def test_multiple_functions():
def func2_transform_params(
params: R.Tuple(R.Tensor((128, 256), dtype="float32"))
) -> R.Tuple(R.Tensor((256, 128), dtype="float32")):
+ R.func_attr({"num_input": 0})
with R.dataflow():
lv: R.Tensor((128, 256), dtype="float32") = params[0]
lv1: R.Tensor((256, 128), dtype="float32") = R.permute_dims(lv, axes=[1, 0])
@@ -422,8 +418,7 @@ def test_stop_lifting():
) -> R.Tensor((256, 256), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
- lv: R.Tensor((256, 256), dtype="float32") = param0
- w1_add: R.Tensor((256, 256), dtype="float32") = R.add(lv, R.const(1, "float32"))
+ w1_add: R.Tensor((256, 256), dtype="float32") = R.add(param0, R.const(1, "float32"))
y: R.Tensor((256, 256), dtype="float32") = R.matmul(x, w1_add, out_dtype="void")
R.output(y)
return y
@@ -432,6 +427,7 @@ def test_stop_lifting():
def func1_transform_params(
params: R.Tuple(R.Tensor((256, 256), dtype="float32"))
) -> R.Tuple(R.Tensor((256, 256), dtype="float32")):
+ R.func_attr({"num_input": 0})
with R.dataflow():
lv: R.Tensor((256, 256), dtype="float32") = params[0]
lv1: R.Tensor((256, 256), dtype="float32") = R.permute_dims(lv, axes=[1, 0])
@@ -459,6 +455,7 @@ def test_symbolic_var_1():
class Expected:
@R.function
def main_transform_params(params: R.Tuple) -> R.Tuple:
+ R.func_attr({"num_input": 0})
with R.dataflow():
gv: R.Tuple = R.tuple()
R.output(gv)
@@ -522,6 +519,7 @@ def test_symbolic_var_2():
@R.function
def main_transform_params(params: R.Tuple) -> R.Tuple:
+ R.func_attr({"num_input": 0})
with R.dataflow():
gv: R.Tuple = R.tuple()
R.output(gv)
@@ -603,7 +601,6 @@ def test_symbolic_var_from_shape():
tir_vars=R.ShapeExpr([slice_index]),
out_sinfo=R.Tensor([16], dtype="int32"),
)
- B_slice = B_slice
A_scale = R.multiply(A_slice, B_slice)
R.output(A_scale)
return A_scale
@@ -612,18 +609,19 @@ def test_symbolic_var_from_shape():
def main_transform_params(
params: R.Tuple(R.Tensor([16, 16], "int32"), R.Shape(["slice_index"]))
):
+ R.func_attr({"num_input": 0})
slice_index = T.int64()
cls = Expected
with R.dataflow():
- extra_symbolic_vars = R.ShapeExpr([slice_index])
B = params[0]
+ # extra_symbolic_vars = params[1]
B_slice = R.call_tir(
cls.slice,
[B],
tir_vars=R.ShapeExpr([slice_index]),
out_sinfo=R.Tensor([16], dtype="int32"),
)
- output = (extra_symbolic_vars, B_slice)
+ output = (R.ShapeExpr([slice_index]), B_slice)
R.output(output)
return output
@@ -652,7 +650,7 @@ def test_symbolic_var_in_param_shape():
x: R.Tensor((1, 16, 224, "n"), "float32"),
w1: R.Tensor((16, "m", 3, 3), "float32"),
w2: R.Tensor((16, "m", 3, 3), "float32"),
- ) -> R.Tensor((1, 16, 224, 224), "float32"):
+ ) -> R.Tensor((1, 16, 224, "n"), "float32"):
m = T.int64()
n = T.int64()
R.func_attr({"num_input": 1})
@@ -677,11 +675,12 @@ def test_symbolic_var_in_param_shape():
) -> R.Tuple(
R.Tensor((16, "m", 3, 3), dtype="float32"), R.Tensor((16, "m", 3,
3), dtype="float32")
):
+ R.func_attr({"num_input": 0})
m = T.int64()
with R.dataflow():
- lv: R.Tensor((16, m, 3, 3), dtype="float32") = params[1]
lv1: R.Tensor((16, m, 3, 3), dtype="float32") = params[0]
lv2: R.Tensor((16, m, 3, 3), dtype="float32") = R.add(lv1, R.const(1, "float32"))
+ lv: R.Tensor((16, m, 3, 3), dtype="float32") = params[1]
gv: R.Tuple(
R.Tensor((16, m, 3, 3), dtype="float32"),
R.Tensor((16, m, 3, 3), dtype="float32"),
@@ -694,16 +693,15 @@ def test_symbolic_var_in_param_shape():
x: R.Tensor((1, 16, 224, "n"), dtype="float32"),
transformed_param_0: R.Tensor((16, "m", 3, 3), dtype="float32"),
transformed_param_1: R.Tensor((16, "m", 3, 3), dtype="float32"),
- ) -> R.Tensor((1, 16, 224, 224), dtype="float32"):
+ ) -> R.Tensor((1, 16, 224, "n"), dtype="float32"):
n = T.int64()
m = T.int64()
R.func_attr({"num_input": 1})
with R.dataflow():
zeros: R.Tensor((n, n), dtype="float32") = R.zeros(R.shape([n, n]), dtype="float32")
- lv: R.Tensor((16, m, 3, 3), dtype="float32") = transformed_param_1
conv1: R.Tensor((1, 16, 224, n), dtype="float32") = R.nn.conv2d(
x,
- lv,
+ transformed_param_1,
strides=[1, 1],
padding=[1, 1, 1, 1],
dilation=[1, 1],
@@ -713,10 +711,9 @@ def test_symbolic_var_in_param_shape():
out_layout="NCHW",
out_dtype="void",
)
- lv1: R.Tensor((16, m, 3, 3), dtype="float32") = transformed_param_0
conv2: R.Tensor((1, 16, 224, n), dtype="float32") = R.nn.conv2d(
conv1,
- lv1,
+ transformed_param_0,
strides=[1, 1],
padding=[1, 1, 1, 1],
dilation=[1, 1],
@@ -770,6 +767,7 @@ def test_symbolic_var_defined_in_params_but_used_in_weights():
def main_transform_params(
params: R.Tuple(R.Tensor(("k",), dtype="float32"))
) -> R.Tuple(R.Tensor(dtype="float32", ndim=1)):
+ R.func_attr({"num_input": 0})
k = T.int64()
with R.dataflow():
lv: R.Tensor((k,), dtype="float32") = params[0]