slyubomirsky commented on code in PR #16411:
URL: https://github.com/apache/tvm/pull/16411#discussion_r1460702809
##########
src/relax/transform/lambda_lift.cc:
##########
@@ -336,176 +266,235 @@ class LambdaLifter : public ExprMutator {
return it->second;
}();
- auto global = GlobalVar(lift_func_name);
- Array<Var> free_vars = FreeVars(func);
Array<Var> captured_vars;
-
- Array<Var> typed_captured_vars;
- bool recursive = false;
- for (const auto& var : free_vars) {
- if (!recur_vars_.empty() && var == recur_vars_.back()) {
- recursive = true;
+ bool is_recursive = false;
+ bool is_closure = false;
+ for (const auto& var : FreeVars(func)) {
+ if (var.same_as(current_lambda_var_)) {
+ is_recursive = true;
} else {
+ is_closure = true;
captured_vars.push_back(var);
}
}
+ Array<Var> typed_captured_vars;
Map<Var, Expr> rebinding_map;
for (auto free_var : captured_vars) {
Var var = Var(free_var->name_hint(), GetStructInfo(free_var),
free_var->span);
typed_captured_vars.push_back(var);
rebinding_map.Set(free_var, var);
}
- // recursive call
- if (recursive) {
- if (!captured_vars.empty()) {
- Array<Expr> fvs;
- for (auto fv : captured_vars) {
- fvs.push_back(fv);
- }
- // it is required by block_blocker, will be updated later
- UpdateStructInfo(global, GetStructInfo(recur_vars_.back()));
- lambda_map_.emplace(recur_vars_.back(), Call(global, fvs));
- } else {
- if (recur_vars_.size() > 0) {
- lambda_map_.emplace(recur_vars_.back(), global);
- }
- }
+ tvm::Array<Var> lifted_func_params =
+ func_node->params.Map([this](Var var) { return VisitVarDef(var); });
+ for (const auto& var : typed_captured_vars) {
+ lifted_func_params.push_back(var);
}
- tvm::Array<Var> params;
- bool all_params_unchanged = true;
- for (Var param : func_node->params) {
- Var new_param = this->VisitVarDef(param);
- params.push_back(new_param);
- all_params_unchanged &= param.same_as(new_param);
+ auto gvar_lifted_func = GlobalVar(lift_func_name);
+ {
+ auto func_sinfo = Downcast<FuncStructInfo>(func_node->struct_info_);
+ if (is_closure) {
+ func_sinfo = FuncStructInfo(lifted_func_params.Map(GetStructInfo),
func_sinfo->ret,
+ func_sinfo->purity);
+ }
+ UpdateStructInfo(gvar_lifted_func, func_sinfo);
}
- Expr body = this->VisitWithNewScope(func_node->body);
- Expr visited_func;
+ Expr body = func_node->body;
- if (all_params_unchanged && body.same_as(func_node->body)) {
- visited_func = GetRef<Expr>(func_node);
- } else if (const auto& body_sinfo =
MatchStructInfo<ObjectStructInfo>(body)) {
- visited_func =
- Function(params, body, body_sinfo.value(), func_node->is_pure,
func_node->attrs);
- } else {
- visited_func =
- Function(params, body, func_node->ret_struct_info,
func_node->is_pure, func_node->attrs);
+ // recursive call
+ if (is_recursive && is_closure) {
+ // it is required by block_blocker, will be updated later
Review Comment:
I assume that by `block_blocker` you mean the `BlockBuilder`? If so, could the comment be updated to say `BlockBuilder`?
##########
src/relax/transform/lambda_lift.cc:
##########
@@ -336,176 +266,235 @@ class LambdaLifter : public ExprMutator {
return it->second;
}();
- auto global = GlobalVar(lift_func_name);
- Array<Var> free_vars = FreeVars(func);
Array<Var> captured_vars;
-
- Array<Var> typed_captured_vars;
- bool recursive = false;
- for (const auto& var : free_vars) {
- if (!recur_vars_.empty() && var == recur_vars_.back()) {
- recursive = true;
+ bool is_recursive = false;
+ bool is_closure = false;
+ for (const auto& var : FreeVars(func)) {
+ if (var.same_as(current_lambda_var_)) {
+ is_recursive = true;
} else {
+ is_closure = true;
captured_vars.push_back(var);
}
}
+ Array<Var> typed_captured_vars;
Map<Var, Expr> rebinding_map;
for (auto free_var : captured_vars) {
Var var = Var(free_var->name_hint(), GetStructInfo(free_var),
free_var->span);
typed_captured_vars.push_back(var);
rebinding_map.Set(free_var, var);
}
- // recursive call
- if (recursive) {
- if (!captured_vars.empty()) {
- Array<Expr> fvs;
- for (auto fv : captured_vars) {
- fvs.push_back(fv);
- }
- // it is required by block_blocker, will be updated later
- UpdateStructInfo(global, GetStructInfo(recur_vars_.back()));
- lambda_map_.emplace(recur_vars_.back(), Call(global, fvs));
- } else {
- if (recur_vars_.size() > 0) {
- lambda_map_.emplace(recur_vars_.back(), global);
- }
- }
+ tvm::Array<Var> lifted_func_params =
+ func_node->params.Map([this](Var var) { return VisitVarDef(var); });
+ for (const auto& var : typed_captured_vars) {
+ lifted_func_params.push_back(var);
}
- tvm::Array<Var> params;
- bool all_params_unchanged = true;
- for (Var param : func_node->params) {
- Var new_param = this->VisitVarDef(param);
- params.push_back(new_param);
- all_params_unchanged &= param.same_as(new_param);
+ auto gvar_lifted_func = GlobalVar(lift_func_name);
+ {
+ auto func_sinfo = Downcast<FuncStructInfo>(func_node->struct_info_);
+ if (is_closure) {
+ func_sinfo = FuncStructInfo(lifted_func_params.Map(GetStructInfo),
func_sinfo->ret,
+ func_sinfo->purity);
+ }
+ UpdateStructInfo(gvar_lifted_func, func_sinfo);
}
- Expr body = this->VisitWithNewScope(func_node->body);
- Expr visited_func;
+ Expr body = func_node->body;
- if (all_params_unchanged && body.same_as(func_node->body)) {
- visited_func = GetRef<Expr>(func_node);
- } else if (const auto& body_sinfo =
MatchStructInfo<ObjectStructInfo>(body)) {
- visited_func =
- Function(params, body, body_sinfo.value(), func_node->is_pure,
func_node->attrs);
- } else {
- visited_func =
- Function(params, body, func_node->ret_struct_info,
func_node->is_pure, func_node->attrs);
+ // recursive call
+ if (is_recursive && is_closure) {
+ // it is required by block_blocker, will be updated later
Review Comment:
I assume that by `block_blocker` you mean the `BlockBuilder`? If so, could the comment be updated to say `BlockBuilder`?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]