slyubomirsky commented on code in PR #16411:
URL: https://github.com/apache/tvm/pull/16411#discussion_r1462672163
##########
src/relax/transform/lambda_lift.cc:
##########
@@ -336,176 +266,235 @@ class LambdaLifter : public ExprMutator {
return it->second;
}();
- auto global = GlobalVar(lift_func_name);
- Array<Var> free_vars = FreeVars(func);
Array<Var> captured_vars;
-
- Array<Var> typed_captured_vars;
- bool recursive = false;
- for (const auto& var : free_vars) {
- if (!recur_vars_.empty() && var == recur_vars_.back()) {
- recursive = true;
+ bool is_recursive = false;
+ bool is_closure = false;
+ for (const auto& var : FreeVars(func)) {
+ if (var.same_as(current_lambda_var_)) {
+ is_recursive = true;
} else {
+ is_closure = true;
captured_vars.push_back(var);
}
}
+ Array<Var> typed_captured_vars;
Map<Var, Expr> rebinding_map;
for (auto free_var : captured_vars) {
Var var = Var(free_var->name_hint(), GetStructInfo(free_var),
free_var->span);
typed_captured_vars.push_back(var);
rebinding_map.Set(free_var, var);
}
- // recursive call
- if (recursive) {
- if (!captured_vars.empty()) {
- Array<Expr> fvs;
- for (auto fv : captured_vars) {
- fvs.push_back(fv);
- }
- // it is required by block_blocker, will be updated later
- UpdateStructInfo(global, GetStructInfo(recur_vars_.back()));
- lambda_map_.emplace(recur_vars_.back(), Call(global, fvs));
- } else {
- if (recur_vars_.size() > 0) {
- lambda_map_.emplace(recur_vars_.back(), global);
- }
- }
+ tvm::Array<Var> lifted_func_params =
+ func_node->params.Map([this](Var var) { return VisitVarDef(var); });
+ for (const auto& var : typed_captured_vars) {
+ lifted_func_params.push_back(var);
}
- tvm::Array<Var> params;
- bool all_params_unchanged = true;
- for (Var param : func_node->params) {
- Var new_param = this->VisitVarDef(param);
- params.push_back(new_param);
- all_params_unchanged &= param.same_as(new_param);
+ auto gvar_lifted_func = GlobalVar(lift_func_name);
+ {
+ auto func_sinfo = Downcast<FuncStructInfo>(func_node->struct_info_);
+ if (is_closure) {
+ func_sinfo = FuncStructInfo(lifted_func_params.Map(GetStructInfo),
func_sinfo->ret,
+ func_sinfo->purity);
+ }
+ UpdateStructInfo(gvar_lifted_func, func_sinfo);
}
- Expr body = this->VisitWithNewScope(func_node->body);
- Expr visited_func;
+ Expr body = func_node->body;
- if (all_params_unchanged && body.same_as(func_node->body)) {
- visited_func = GetRef<Expr>(func_node);
- } else if (const auto& body_sinfo =
MatchStructInfo<ObjectStructInfo>(body)) {
- visited_func =
- Function(params, body, body_sinfo.value(), func_node->is_pure,
func_node->attrs);
- } else {
- visited_func =
- Function(params, body, func_node->ret_struct_info,
func_node->is_pure, func_node->attrs);
+ // recursive call
+ if (is_recursive && is_closure) {
+ // it is required by block_blocker, will be updated later
+ nested_closure_map_.emplace(
+ current_lambda_var_.value(),
+ Call(gvar_lifted_func, captured_vars.Map([](Var var) -> Expr {
return var; })));
Review Comment:
We might have to be careful to ensure that type safety can't be broken this way;
that kind of thing has tended to lead to soundness issues in other languages. (IIRC,
it is one source of unsoundness in TypeScript.)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]