Lunderberg commented on code in PR #16411:
URL: https://github.com/apache/tvm/pull/16411#discussion_r1462330502
##########
src/relax/transform/lambda_lift.cc:
##########
@@ -336,176 +266,235 @@ class LambdaLifter : public ExprMutator {
return it->second;
}();
- auto global = GlobalVar(lift_func_name);
- Array<Var> free_vars = FreeVars(func);
Array<Var> captured_vars;
-
- Array<Var> typed_captured_vars;
- bool recursive = false;
- for (const auto& var : free_vars) {
- if (!recur_vars_.empty() && var == recur_vars_.back()) {
- recursive = true;
+ bool is_recursive = false;
+ bool is_closure = false;
+ for (const auto& var : FreeVars(func)) {
+ if (var.same_as(current_lambda_var_)) {
+ is_recursive = true;
} else {
+ is_closure = true;
captured_vars.push_back(var);
}
}
+ Array<Var> typed_captured_vars;
Map<Var, Expr> rebinding_map;
for (auto free_var : captured_vars) {
Var var = Var(free_var->name_hint(), GetStructInfo(free_var),
free_var->span);
typed_captured_vars.push_back(var);
rebinding_map.Set(free_var, var);
}
- // recursive call
- if (recursive) {
- if (!captured_vars.empty()) {
- Array<Expr> fvs;
- for (auto fv : captured_vars) {
- fvs.push_back(fv);
- }
- // it is required by block_blocker, will be updated later
- UpdateStructInfo(global, GetStructInfo(recur_vars_.back()));
- lambda_map_.emplace(recur_vars_.back(), Call(global, fvs));
- } else {
- if (recur_vars_.size() > 0) {
- lambda_map_.emplace(recur_vars_.back(), global);
- }
- }
+ tvm::Array<Var> lifted_func_params =
+ func_node->params.Map([this](Var var) { return VisitVarDef(var); });
+ for (const auto& var : typed_captured_vars) {
+ lifted_func_params.push_back(var);
}
- tvm::Array<Var> params;
- bool all_params_unchanged = true;
- for (Var param : func_node->params) {
- Var new_param = this->VisitVarDef(param);
- params.push_back(new_param);
- all_params_unchanged &= param.same_as(new_param);
+ auto gvar_lifted_func = GlobalVar(lift_func_name);
+ {
+ auto func_sinfo = Downcast<FuncStructInfo>(func_node->struct_info_);
+ if (is_closure) {
+ func_sinfo = FuncStructInfo(lifted_func_params.Map(GetStructInfo),
func_sinfo->ret,
+ func_sinfo->purity);
+ }
+ UpdateStructInfo(gvar_lifted_func, func_sinfo);
}
- Expr body = this->VisitWithNewScope(func_node->body);
- Expr visited_func;
+ Expr body = func_node->body;
- if (all_params_unchanged && body.same_as(func_node->body)) {
- visited_func = GetRef<Expr>(func_node);
- } else if (const auto& body_sinfo =
MatchStructInfo<ObjectStructInfo>(body)) {
- visited_func =
- Function(params, body, body_sinfo.value(), func_node->is_pure,
func_node->attrs);
- } else {
- visited_func =
- Function(params, body, func_node->ret_struct_info,
func_node->is_pure, func_node->attrs);
+ // recursive call
+ if (is_recursive && is_closure) {
+ // it is required by block_blocker, will be updated later
+ nested_closure_map_.emplace(
+ current_lambda_var_.value(),
+ Call(gvar_lifted_func, captured_vars.Map([](Var var) -> Expr {
return var; })));
}
- auto new_func = Downcast<Function>(visited_func);
- Function lifted_func;
- bool is_closure = IsClosure(captured_vars);
if (!is_closure) {
- lifted_func = Function(
- /*params=*/new_func->params,
- /*body=*/new_func->body,
- /*ret_struct_info=*/new_func->ret_struct_info,
- /*is_pure=*/new_func->is_pure,
- /*attrs=*/new_func->attrs,
- /*span=*/new_func->span);
- } else {
- // Flatten the Closure
- std::vector<Var> closure_params;
- closure_params.reserve(func->params.size() + typed_captured_vars.size());
- for (size_t i = 0; i < func->params.size(); ++i) {
- closure_params.emplace_back(func->params[i]);
- }
- for (size_t i = 0; i < typed_captured_vars.size(); ++i) {
- closure_params.emplace_back(typed_captured_vars[i]);
- }
+ rebind_map_.emplace(current_lambda_var_.value(), gvar_lifted_func);
+ }
- lifted_func = Function(/*params=*/closure_params,
- /*body=*/Bind(new_func->body, rebinding_map),
- /*ret_struct_info=*/new_func->ret_struct_info,
- /*is_pure=*/new_func->is_pure,
- /*attrs=*/new_func->attrs,
- /*span=*/func->span);
+ body = this->VisitWithNewScope(body, lifted_func_params);
+ StructInfo ret_struct_info = GetStructInfo(body);
+ body = Bind(body, rebinding_map);
- for (Var param : closure_params) {
- CHECK(param->checked_type_.defined())
- << "relax.Function requires params to contain checked_type_";
- }
+ Function lifted_func;
+ if (lifted_func_params.same_as(func_node->params) &&
body.same_as(func_node->body) &&
+ ret_struct_info.same_as(func_node->ret_struct_info)) {
+ lifted_func = GetRef<Function>(func_node);
+ } else {
+ lifted_func =
+ Function(lifted_func_params, body, ret_struct_info,
func_node->is_pure, func_node->attrs);
+ }
+
+ for (Var param : lifted_func->params) {
+ CHECK(param->checked_type_.defined())
+ << "relax.Function requires all parameters to contain checked_type_.
"
+ << "However, parameter " << param << " with struct info " <<
param->struct_info_
+ << " has no checked type";
}
ICHECK(lifted_func.defined());
+ if (is_closure || IsClosure(lifted_func)) {
+ closures_.insert(gvar_lifted_func);
+ }
+
// Add the lifted function to the module.
- global->struct_info_ = GetStructInfo(lifted_func);
- global->checked_type_ = lifted_func->checked_type_;
- builder_->UpdateFunction(global, lifted_func);
+ lifted_func = CopyWithNewVars(lifted_func);
+ gvar_lifted_func->struct_info_ = GetStructInfo(lifted_func);
+ gvar_lifted_func->checked_type_ = lifted_func->checked_type_;
- if (!is_closure) {
- return std::move(global);
- } else {
+ builder_->UpdateFunction(gvar_lifted_func, lifted_func);
+
+ Expr callable_value = gvar_lifted_func;
+ if (is_closure) {
// If we need to allocate a closure,
// we pass the variables in its environment here.
- Array<Expr> fvs;
- for (auto fv : captured_vars) {
- fvs.push_back(fv);
- }
+ Tuple arg_tuple(captured_vars.Map([](Var var) -> Expr { return var; }));
// Call make_closure intrinsic
- return Call(make_closure_op_, {global, Tuple(fvs)}, {}, {});
+ callable_value = Call(make_closure_op_, {gvar_lifted_func, arg_tuple},
{}, {});
}
+
+ return callable_value;
}
- bool HasClosure(const Var& var) {
- auto val = builder_->LookupBinding(var);
- if (const auto* value = val.as<GlobalVarNode>()) {
- IRModule ctx_mod = builder_->GetContextIRModule();
- ICHECK(ctx_mod->functions.size() > 0);
- BaseFunc func = ctx_mod->Lookup(GetRef<GlobalVar>(value));
- if (const auto* func_node = func.as<FunctionNode>()) {
- if (const auto* call_node = func_node->body.as<CallNode>()) {
- if (call_node->op == make_closure_op_) {
- return true;
- }
- } else if (const auto* seq_expr_node =
func_node->body.as<SeqExprNode>()) {
- // the return var points to a make_closure intrinsic
- if (const auto* var = seq_expr_node->body.as<VarNode>()) {
- return HasClosure(GetRef<Var>(var));
+ Expr VisitExpr_(const CallNode* call_node) final {
+ auto call = GetRef<Call>(call_node);
+
+ auto orig_sinfo = Downcast<StructInfo>(call->struct_info_);
+
+ if (auto opt_var = call->op.as<Var>()) {
+ auto var = opt_var.value();
+
+ // Call "relax.invoke_closure" to invoke closure
+
+ if (bool is_closure = IsClosure(var);
+ is_closure && builder_->LookupBinding(var).as<CallNode>()) {
Review Comment:
Nope. A previous implementation of `IsClosure` returned an `Optional<_>`,
and this `if`-statement initializer was left over from that version. Updated.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]