slyubomirsky commented on code in PR #16411:
URL: https://github.com/apache/tvm/pull/16411#discussion_r1460702833
##########
src/relax/transform/lambda_lift.cc:
##########
@@ -336,176 +266,235 @@ class LambdaLifter : public ExprMutator {
return it->second;
}();
- auto global = GlobalVar(lift_func_name);
- Array<Var> free_vars = FreeVars(func);
Array<Var> captured_vars;
-
- Array<Var> typed_captured_vars;
- bool recursive = false;
- for (const auto& var : free_vars) {
- if (!recur_vars_.empty() && var == recur_vars_.back()) {
- recursive = true;
+ bool is_recursive = false;
+ bool is_closure = false;
+ for (const auto& var : FreeVars(func)) {
+ if (var.same_as(current_lambda_var_)) {
+ is_recursive = true;
} else {
+ is_closure = true;
captured_vars.push_back(var);
}
}
+ Array<Var> typed_captured_vars;
Map<Var, Expr> rebinding_map;
for (auto free_var : captured_vars) {
Var var = Var(free_var->name_hint(), GetStructInfo(free_var),
free_var->span);
typed_captured_vars.push_back(var);
rebinding_map.Set(free_var, var);
}
- // recursive call
- if (recursive) {
- if (!captured_vars.empty()) {
- Array<Expr> fvs;
- for (auto fv : captured_vars) {
- fvs.push_back(fv);
- }
- // it is required by block_blocker, will be updated later
- UpdateStructInfo(global, GetStructInfo(recur_vars_.back()));
- lambda_map_.emplace(recur_vars_.back(), Call(global, fvs));
- } else {
- if (recur_vars_.size() > 0) {
- lambda_map_.emplace(recur_vars_.back(), global);
- }
- }
+ tvm::Array<Var> lifted_func_params =
+ func_node->params.Map([this](Var var) { return VisitVarDef(var); });
+ for (const auto& var : typed_captured_vars) {
+ lifted_func_params.push_back(var);
}
- tvm::Array<Var> params;
- bool all_params_unchanged = true;
- for (Var param : func_node->params) {
- Var new_param = this->VisitVarDef(param);
- params.push_back(new_param);
- all_params_unchanged &= param.same_as(new_param);
+ auto gvar_lifted_func = GlobalVar(lift_func_name);
+ {
+ auto func_sinfo = Downcast<FuncStructInfo>(func_node->struct_info_);
+ if (is_closure) {
+ func_sinfo = FuncStructInfo(lifted_func_params.Map(GetStructInfo),
func_sinfo->ret,
+ func_sinfo->purity);
+ }
+ UpdateStructInfo(gvar_lifted_func, func_sinfo);
}
- Expr body = this->VisitWithNewScope(func_node->body);
- Expr visited_func;
+ Expr body = func_node->body;
- if (all_params_unchanged && body.same_as(func_node->body)) {
- visited_func = GetRef<Expr>(func_node);
- } else if (const auto& body_sinfo =
MatchStructInfo<ObjectStructInfo>(body)) {
- visited_func =
- Function(params, body, body_sinfo.value(), func_node->is_pure,
func_node->attrs);
- } else {
- visited_func =
- Function(params, body, func_node->ret_struct_info,
func_node->is_pure, func_node->attrs);
+ // recursive call
+ if (is_recursive && is_closure) {
+ // it is required by block_blocker, will be updated later
Review Comment:
It looks like this typo was present in the original, but I presume this
comment should refer to the BlockBuilder (i.e., "block_blocker" should be
"block_builder").
##########
src/relax/transform/lambda_lift.cc:
##########
@@ -236,95 +236,25 @@ class LambdaLifter : public ExprMutator {
using ExprMutator::VisitExpr_;
- void VisitBinding_(const VarBindingNode* binding) final {
- bool is_lambda = binding->value->IsInstance<FunctionNode>();
- if (is_lambda) {
- recur_vars_.push_back(binding->var);
+ void VisitBinding_(const VarBindingNode* binding, const FunctionNode*
func_node) final {
+ auto cache = current_lambda_var_;
+ current_lambda_var_ = binding->var;
+
+ // ExprMutator::VisitBinding_(binding, func_node);
Review Comment:
This is probably subjective, but this commented-out call is a little more
cryptic than a comment simply noting that we are visiting a function literal.
##########
src/relax/transform/lambda_lift.cc:
##########
@@ -336,176 +266,235 @@ class LambdaLifter : public ExprMutator {
return it->second;
}();
- auto global = GlobalVar(lift_func_name);
- Array<Var> free_vars = FreeVars(func);
Array<Var> captured_vars;
-
- Array<Var> typed_captured_vars;
- bool recursive = false;
- for (const auto& var : free_vars) {
- if (!recur_vars_.empty() && var == recur_vars_.back()) {
- recursive = true;
+ bool is_recursive = false;
+ bool is_closure = false;
+ for (const auto& var : FreeVars(func)) {
+ if (var.same_as(current_lambda_var_)) {
+ is_recursive = true;
} else {
+ is_closure = true;
captured_vars.push_back(var);
}
}
+ Array<Var> typed_captured_vars;
Map<Var, Expr> rebinding_map;
for (auto free_var : captured_vars) {
Var var = Var(free_var->name_hint(), GetStructInfo(free_var),
free_var->span);
typed_captured_vars.push_back(var);
rebinding_map.Set(free_var, var);
}
- // recursive call
- if (recursive) {
- if (!captured_vars.empty()) {
- Array<Expr> fvs;
- for (auto fv : captured_vars) {
- fvs.push_back(fv);
- }
- // it is required by block_blocker, will be updated later
- UpdateStructInfo(global, GetStructInfo(recur_vars_.back()));
- lambda_map_.emplace(recur_vars_.back(), Call(global, fvs));
- } else {
- if (recur_vars_.size() > 0) {
- lambda_map_.emplace(recur_vars_.back(), global);
- }
- }
+ tvm::Array<Var> lifted_func_params =
+ func_node->params.Map([this](Var var) { return VisitVarDef(var); });
+ for (const auto& var : typed_captured_vars) {
+ lifted_func_params.push_back(var);
}
- tvm::Array<Var> params;
- bool all_params_unchanged = true;
- for (Var param : func_node->params) {
- Var new_param = this->VisitVarDef(param);
- params.push_back(new_param);
- all_params_unchanged &= param.same_as(new_param);
+ auto gvar_lifted_func = GlobalVar(lift_func_name);
+ {
+ auto func_sinfo = Downcast<FuncStructInfo>(func_node->struct_info_);
+ if (is_closure) {
+ func_sinfo = FuncStructInfo(lifted_func_params.Map(GetStructInfo),
func_sinfo->ret,
+ func_sinfo->purity);
+ }
+ UpdateStructInfo(gvar_lifted_func, func_sinfo);
}
- Expr body = this->VisitWithNewScope(func_node->body);
- Expr visited_func;
+ Expr body = func_node->body;
- if (all_params_unchanged && body.same_as(func_node->body)) {
- visited_func = GetRef<Expr>(func_node);
- } else if (const auto& body_sinfo =
MatchStructInfo<ObjectStructInfo>(body)) {
- visited_func =
- Function(params, body, body_sinfo.value(), func_node->is_pure,
func_node->attrs);
- } else {
- visited_func =
- Function(params, body, func_node->ret_struct_info,
func_node->is_pure, func_node->attrs);
+ // recursive call
+ if (is_recursive && is_closure) {
+ // it is required by block_blocker, will be updated later
+ nested_closure_map_.emplace(
+ current_lambda_var_.value(),
+ Call(gvar_lifted_func, captured_vars.Map([](Var var) -> Expr {
return var; })));
}
- auto new_func = Downcast<Function>(visited_func);
- Function lifted_func;
- bool is_closure = IsClosure(captured_vars);
if (!is_closure) {
- lifted_func = Function(
- /*params=*/new_func->params,
- /*body=*/new_func->body,
- /*ret_struct_info=*/new_func->ret_struct_info,
- /*is_pure=*/new_func->is_pure,
- /*attrs=*/new_func->attrs,
- /*span=*/new_func->span);
- } else {
- // Flatten the Closure
- std::vector<Var> closure_params;
- closure_params.reserve(func->params.size() + typed_captured_vars.size());
- for (size_t i = 0; i < func->params.size(); ++i) {
- closure_params.emplace_back(func->params[i]);
- }
- for (size_t i = 0; i < typed_captured_vars.size(); ++i) {
- closure_params.emplace_back(typed_captured_vars[i]);
- }
+ rebind_map_.emplace(current_lambda_var_.value(), gvar_lifted_func);
+ }
- lifted_func = Function(/*params=*/closure_params,
- /*body=*/Bind(new_func->body, rebinding_map),
- /*ret_struct_info=*/new_func->ret_struct_info,
- /*is_pure=*/new_func->is_pure,
- /*attrs=*/new_func->attrs,
- /*span=*/func->span);
+ body = this->VisitWithNewScope(body, lifted_func_params);
+ StructInfo ret_struct_info = GetStructInfo(body);
+ body = Bind(body, rebinding_map);
- for (Var param : closure_params) {
- CHECK(param->checked_type_.defined())
- << "relax.Function requires params to contain checked_type_";
- }
+ Function lifted_func;
+ if (lifted_func_params.same_as(func_node->params) &&
body.same_as(func_node->body) &&
+ ret_struct_info.same_as(func_node->ret_struct_info)) {
+ lifted_func = GetRef<Function>(func_node);
+ } else {
+ lifted_func =
+ Function(lifted_func_params, body, ret_struct_info,
func_node->is_pure, func_node->attrs);
+ }
+
+ for (Var param : lifted_func->params) {
+ CHECK(param->checked_type_.defined())
+ << "relax.Function requires all parameters to contain checked_type_.
"
+ << "However, parameter " << param << " with struct info " <<
param->struct_info_
+ << " has no checked type";
}
ICHECK(lifted_func.defined());
+ if (is_closure || IsClosure(lifted_func)) {
+ closures_.insert(gvar_lifted_func);
+ }
+
// Add the lifted function to the module.
- global->struct_info_ = GetStructInfo(lifted_func);
- global->checked_type_ = lifted_func->checked_type_;
- builder_->UpdateFunction(global, lifted_func);
+ lifted_func = CopyWithNewVars(lifted_func);
+ gvar_lifted_func->struct_info_ = GetStructInfo(lifted_func);
+ gvar_lifted_func->checked_type_ = lifted_func->checked_type_;
- if (!is_closure) {
- return std::move(global);
- } else {
+ builder_->UpdateFunction(gvar_lifted_func, lifted_func);
+
+ Expr callable_value = gvar_lifted_func;
+ if (is_closure) {
// If we need to allocate a closure,
// we pass the variables in its environment here.
- Array<Expr> fvs;
- for (auto fv : captured_vars) {
- fvs.push_back(fv);
- }
+ Tuple arg_tuple(captured_vars.Map([](Var var) -> Expr { return var; }));
// Call make_closure intrinsic
- return Call(make_closure_op_, {global, Tuple(fvs)}, {}, {});
+ callable_value = Call(make_closure_op_, {gvar_lifted_func, arg_tuple},
{}, {});
}
+
+ return callable_value;
}
- bool HasClosure(const Var& var) {
- auto val = builder_->LookupBinding(var);
- if (const auto* value = val.as<GlobalVarNode>()) {
- IRModule ctx_mod = builder_->GetContextIRModule();
- ICHECK(ctx_mod->functions.size() > 0);
- BaseFunc func = ctx_mod->Lookup(GetRef<GlobalVar>(value));
- if (const auto* func_node = func.as<FunctionNode>()) {
- if (const auto* call_node = func_node->body.as<CallNode>()) {
- if (call_node->op == make_closure_op_) {
- return true;
- }
- } else if (const auto* seq_expr_node =
func_node->body.as<SeqExprNode>()) {
- // the return var points to a make_closure intrinsic
- if (const auto* var = seq_expr_node->body.as<VarNode>()) {
- return HasClosure(GetRef<Var>(var));
+ Expr VisitExpr_(const CallNode* call_node) final {
+ auto call = GetRef<Call>(call_node);
+
+ auto orig_sinfo = Downcast<StructInfo>(call->struct_info_);
+
+ if (auto opt_var = call->op.as<Var>()) {
+ auto var = opt_var.value();
+
+ // Call "relax.invoke_closure" to invoke closure
+
+ if (bool is_closure = IsClosure(var);
+ is_closure && builder_->LookupBinding(var).as<CallNode>()) {
+ // if the original op was pure, we should use invoke_pure_closure
+ Call orig_call = Downcast<Call>(builder_->LookupBinding(var));
+ bool is_pure = [&]() -> bool {
+ if (auto op = orig_call->op.as<Op>()) {
+ static const auto& purity_map = Op::GetAttrMap<Bool>("FPurity");
+ return purity_map.get(op.value(), Bool(false))->value;
+ } else if (const auto* func_sinfo =
+ orig_call->op->struct_info_.as<FuncStructInfoNode>())
{
+ return func_sinfo->purity;
+ } else {
+ LOG(FATAL) << "Could not determine purity of call to " <<
orig_call->op
+ << ", as it is neither a tvm::Op (type = \"" <<
orig_call->op->GetTypeKey()
+ << "\"), "
+ << "nor is is annotated with FuncStructInfo (sinfo = "
+ << orig_call->op->struct_info_ << ")";
}
- }
+ }();
+
+ auto prev = call;
+ call = Call(is_pure ? invoke_pure_closure_op_ : invoke_closure_op_,
+ {var, Tuple(call->args)}, {}, {orig_sinfo});
}
- } else if (const auto* func_node = val.as<FunctionNode>()) {
- if (const auto* call_node = func_node->body.as<CallNode>()) {
- if (call_node->op == make_closure_op_) {
- return true;
+ }
+
+ if (auto opt_var = call->op.as<Var>()) {
+ auto var = opt_var.value();
+ if (auto it = nested_closure_map_.find(var); it !=
nested_closure_map_.end()) {
+ Call nested_call = it->second;
+
+ Array<relay::Expr> new_args = call->args;
+ for (const auto arg : nested_call->args) {
+ new_args.push_back(arg);
}
+
+ auto prev = call;
+ call = Call(nested_call->op, new_args, call->attrs, call->sinfo_args);
}
- } else if (const auto* call_node = val.as<relax::CallNode>()) {
+ }
+
+ return ExprMutator::VisitExpr_(call.get());
+ }
+
+ Expr VisitExpr_(const VarNode* op) override {
+ auto var = GetRef<Var>(op);
+ if (auto it = rebind_map_.find(var); it != rebind_map_.end()) {
+ return it->second;
+ }
+ return ExprMutator::VisitExpr_(op);
+ }
+
+ bool IsClosure(Expr val) {
+ static int depth = -1;
+ struct Context {
+ explicit Context(int* ptr) : ptr(ptr) { (*ptr)++; }
+ ~Context() { (*ptr)--; }
+ int* ptr;
+ } context(&depth);
Review Comment:
The static `depth` counter (and the `Context` guard that increments and
decrements it) doesn't appear to be used anywhere.
##########
src/relax/transform/lambda_lift.cc:
##########
@@ -336,176 +266,235 @@ class LambdaLifter : public ExprMutator {
return it->second;
}();
- auto global = GlobalVar(lift_func_name);
- Array<Var> free_vars = FreeVars(func);
Array<Var> captured_vars;
-
- Array<Var> typed_captured_vars;
- bool recursive = false;
- for (const auto& var : free_vars) {
- if (!recur_vars_.empty() && var == recur_vars_.back()) {
- recursive = true;
+ bool is_recursive = false;
+ bool is_closure = false;
+ for (const auto& var : FreeVars(func)) {
+ if (var.same_as(current_lambda_var_)) {
+ is_recursive = true;
} else {
+ is_closure = true;
captured_vars.push_back(var);
}
}
+ Array<Var> typed_captured_vars;
Map<Var, Expr> rebinding_map;
for (auto free_var : captured_vars) {
Var var = Var(free_var->name_hint(), GetStructInfo(free_var),
free_var->span);
typed_captured_vars.push_back(var);
rebinding_map.Set(free_var, var);
}
- // recursive call
- if (recursive) {
- if (!captured_vars.empty()) {
- Array<Expr> fvs;
- for (auto fv : captured_vars) {
- fvs.push_back(fv);
- }
- // it is required by block_blocker, will be updated later
- UpdateStructInfo(global, GetStructInfo(recur_vars_.back()));
- lambda_map_.emplace(recur_vars_.back(), Call(global, fvs));
- } else {
- if (recur_vars_.size() > 0) {
- lambda_map_.emplace(recur_vars_.back(), global);
- }
- }
+ tvm::Array<Var> lifted_func_params =
+ func_node->params.Map([this](Var var) { return VisitVarDef(var); });
+ for (const auto& var : typed_captured_vars) {
+ lifted_func_params.push_back(var);
}
- tvm::Array<Var> params;
- bool all_params_unchanged = true;
- for (Var param : func_node->params) {
- Var new_param = this->VisitVarDef(param);
- params.push_back(new_param);
- all_params_unchanged &= param.same_as(new_param);
+ auto gvar_lifted_func = GlobalVar(lift_func_name);
+ {
+ auto func_sinfo = Downcast<FuncStructInfo>(func_node->struct_info_);
+ if (is_closure) {
+ func_sinfo = FuncStructInfo(lifted_func_params.Map(GetStructInfo),
func_sinfo->ret,
+ func_sinfo->purity);
+ }
+ UpdateStructInfo(gvar_lifted_func, func_sinfo);
}
- Expr body = this->VisitWithNewScope(func_node->body);
- Expr visited_func;
+ Expr body = func_node->body;
- if (all_params_unchanged && body.same_as(func_node->body)) {
- visited_func = GetRef<Expr>(func_node);
- } else if (const auto& body_sinfo =
MatchStructInfo<ObjectStructInfo>(body)) {
- visited_func =
- Function(params, body, body_sinfo.value(), func_node->is_pure,
func_node->attrs);
- } else {
- visited_func =
- Function(params, body, func_node->ret_struct_info,
func_node->is_pure, func_node->attrs);
+ // recursive call
+ if (is_recursive && is_closure) {
+ // it is required by block_blocker, will be updated later
+ nested_closure_map_.emplace(
+ current_lambda_var_.value(),
+ Call(gvar_lifted_func, captured_vars.Map([](Var var) -> Expr {
return var; })));
}
- auto new_func = Downcast<Function>(visited_func);
- Function lifted_func;
- bool is_closure = IsClosure(captured_vars);
if (!is_closure) {
- lifted_func = Function(
- /*params=*/new_func->params,
- /*body=*/new_func->body,
- /*ret_struct_info=*/new_func->ret_struct_info,
- /*is_pure=*/new_func->is_pure,
- /*attrs=*/new_func->attrs,
- /*span=*/new_func->span);
- } else {
- // Flatten the Closure
- std::vector<Var> closure_params;
- closure_params.reserve(func->params.size() + typed_captured_vars.size());
- for (size_t i = 0; i < func->params.size(); ++i) {
- closure_params.emplace_back(func->params[i]);
- }
- for (size_t i = 0; i < typed_captured_vars.size(); ++i) {
- closure_params.emplace_back(typed_captured_vars[i]);
- }
+ rebind_map_.emplace(current_lambda_var_.value(), gvar_lifted_func);
+ }
- lifted_func = Function(/*params=*/closure_params,
- /*body=*/Bind(new_func->body, rebinding_map),
- /*ret_struct_info=*/new_func->ret_struct_info,
- /*is_pure=*/new_func->is_pure,
- /*attrs=*/new_func->attrs,
- /*span=*/func->span);
+ body = this->VisitWithNewScope(body, lifted_func_params);
+ StructInfo ret_struct_info = GetStructInfo(body);
+ body = Bind(body, rebinding_map);
- for (Var param : closure_params) {
- CHECK(param->checked_type_.defined())
- << "relax.Function requires params to contain checked_type_";
- }
+ Function lifted_func;
+ if (lifted_func_params.same_as(func_node->params) &&
body.same_as(func_node->body) &&
+ ret_struct_info.same_as(func_node->ret_struct_info)) {
+ lifted_func = GetRef<Function>(func_node);
+ } else {
+ lifted_func =
+ Function(lifted_func_params, body, ret_struct_info,
func_node->is_pure, func_node->attrs);
+ }
+
+ for (Var param : lifted_func->params) {
+ CHECK(param->checked_type_.defined())
+ << "relax.Function requires all parameters to contain checked_type_.
"
+ << "However, parameter " << param << " with struct info " <<
param->struct_info_
+ << " has no checked type";
}
ICHECK(lifted_func.defined());
+ if (is_closure || IsClosure(lifted_func)) {
+ closures_.insert(gvar_lifted_func);
+ }
+
// Add the lifted function to the module.
- global->struct_info_ = GetStructInfo(lifted_func);
- global->checked_type_ = lifted_func->checked_type_;
- builder_->UpdateFunction(global, lifted_func);
+ lifted_func = CopyWithNewVars(lifted_func);
+ gvar_lifted_func->struct_info_ = GetStructInfo(lifted_func);
+ gvar_lifted_func->checked_type_ = lifted_func->checked_type_;
- if (!is_closure) {
- return std::move(global);
- } else {
+ builder_->UpdateFunction(gvar_lifted_func, lifted_func);
+
+ Expr callable_value = gvar_lifted_func;
+ if (is_closure) {
// If we need to allocate a closure,
// we pass the variables in its environment here.
- Array<Expr> fvs;
- for (auto fv : captured_vars) {
- fvs.push_back(fv);
- }
+ Tuple arg_tuple(captured_vars.Map([](Var var) -> Expr { return var; }));
// Call make_closure intrinsic
- return Call(make_closure_op_, {global, Tuple(fvs)}, {}, {});
+ callable_value = Call(make_closure_op_, {gvar_lifted_func, arg_tuple},
{}, {});
}
+
+ return callable_value;
}
- bool HasClosure(const Var& var) {
- auto val = builder_->LookupBinding(var);
- if (const auto* value = val.as<GlobalVarNode>()) {
- IRModule ctx_mod = builder_->GetContextIRModule();
- ICHECK(ctx_mod->functions.size() > 0);
- BaseFunc func = ctx_mod->Lookup(GetRef<GlobalVar>(value));
- if (const auto* func_node = func.as<FunctionNode>()) {
- if (const auto* call_node = func_node->body.as<CallNode>()) {
- if (call_node->op == make_closure_op_) {
- return true;
- }
- } else if (const auto* seq_expr_node =
func_node->body.as<SeqExprNode>()) {
- // the return var points to a make_closure intrinsic
- if (const auto* var = seq_expr_node->body.as<VarNode>()) {
- return HasClosure(GetRef<Var>(var));
+ Expr VisitExpr_(const CallNode* call_node) final {
+ auto call = GetRef<Call>(call_node);
+
+ auto orig_sinfo = Downcast<StructInfo>(call->struct_info_);
+
+ if (auto opt_var = call->op.as<Var>()) {
+ auto var = opt_var.value();
+
+ // Call "relax.invoke_closure" to invoke closure
+
+ if (bool is_closure = IsClosure(var);
+ is_closure && builder_->LookupBinding(var).as<CallNode>()) {
Review Comment:
Is there a reason to write it this way instead of `IsClosure(var) && ...`?
##########
src/relax/transform/lambda_lift.cc:
##########
@@ -336,176 +266,235 @@ class LambdaLifter : public ExprMutator {
return it->second;
}();
- auto global = GlobalVar(lift_func_name);
- Array<Var> free_vars = FreeVars(func);
Array<Var> captured_vars;
-
- Array<Var> typed_captured_vars;
- bool recursive = false;
- for (const auto& var : free_vars) {
- if (!recur_vars_.empty() && var == recur_vars_.back()) {
- recursive = true;
+ bool is_recursive = false;
+ bool is_closure = false;
+ for (const auto& var : FreeVars(func)) {
+ if (var.same_as(current_lambda_var_)) {
+ is_recursive = true;
} else {
+ is_closure = true;
captured_vars.push_back(var);
}
}
+ Array<Var> typed_captured_vars;
Map<Var, Expr> rebinding_map;
for (auto free_var : captured_vars) {
Var var = Var(free_var->name_hint(), GetStructInfo(free_var),
free_var->span);
typed_captured_vars.push_back(var);
rebinding_map.Set(free_var, var);
}
- // recursive call
- if (recursive) {
- if (!captured_vars.empty()) {
- Array<Expr> fvs;
- for (auto fv : captured_vars) {
- fvs.push_back(fv);
- }
- // it is required by block_blocker, will be updated later
- UpdateStructInfo(global, GetStructInfo(recur_vars_.back()));
- lambda_map_.emplace(recur_vars_.back(), Call(global, fvs));
- } else {
- if (recur_vars_.size() > 0) {
- lambda_map_.emplace(recur_vars_.back(), global);
- }
- }
+ tvm::Array<Var> lifted_func_params =
+ func_node->params.Map([this](Var var) { return VisitVarDef(var); });
+ for (const auto& var : typed_captured_vars) {
+ lifted_func_params.push_back(var);
}
- tvm::Array<Var> params;
- bool all_params_unchanged = true;
- for (Var param : func_node->params) {
- Var new_param = this->VisitVarDef(param);
- params.push_back(new_param);
- all_params_unchanged &= param.same_as(new_param);
+ auto gvar_lifted_func = GlobalVar(lift_func_name);
+ {
+ auto func_sinfo = Downcast<FuncStructInfo>(func_node->struct_info_);
+ if (is_closure) {
+ func_sinfo = FuncStructInfo(lifted_func_params.Map(GetStructInfo),
func_sinfo->ret,
+ func_sinfo->purity);
+ }
+ UpdateStructInfo(gvar_lifted_func, func_sinfo);
}
- Expr body = this->VisitWithNewScope(func_node->body);
- Expr visited_func;
+ Expr body = func_node->body;
- if (all_params_unchanged && body.same_as(func_node->body)) {
- visited_func = GetRef<Expr>(func_node);
- } else if (const auto& body_sinfo =
MatchStructInfo<ObjectStructInfo>(body)) {
- visited_func =
- Function(params, body, body_sinfo.value(), func_node->is_pure,
func_node->attrs);
- } else {
- visited_func =
- Function(params, body, func_node->ret_struct_info,
func_node->is_pure, func_node->attrs);
+ // recursive call
+ if (is_recursive && is_closure) {
+ // it is required by block_blocker, will be updated later
+ nested_closure_map_.emplace(
+ current_lambda_var_.value(),
+ Call(gvar_lifted_func, captured_vars.Map([](Var var) -> Expr {
return var; })));
Review Comment:
I assume the `.Map(...)` call is there so that the `Array<Var>` can be passed
as an `Array<Expr>`. Is that correct?
##########
src/relax/transform/lambda_lift.cc:
##########
@@ -336,176 +266,235 @@ class LambdaLifter : public ExprMutator {
return it->second;
}();
- auto global = GlobalVar(lift_func_name);
- Array<Var> free_vars = FreeVars(func);
Array<Var> captured_vars;
-
- Array<Var> typed_captured_vars;
- bool recursive = false;
- for (const auto& var : free_vars) {
- if (!recur_vars_.empty() && var == recur_vars_.back()) {
- recursive = true;
+ bool is_recursive = false;
+ bool is_closure = false;
+ for (const auto& var : FreeVars(func)) {
+ if (var.same_as(current_lambda_var_)) {
+ is_recursive = true;
} else {
+ is_closure = true;
captured_vars.push_back(var);
}
}
+ Array<Var> typed_captured_vars;
Map<Var, Expr> rebinding_map;
for (auto free_var : captured_vars) {
Var var = Var(free_var->name_hint(), GetStructInfo(free_var),
free_var->span);
typed_captured_vars.push_back(var);
rebinding_map.Set(free_var, var);
}
- // recursive call
- if (recursive) {
- if (!captured_vars.empty()) {
- Array<Expr> fvs;
- for (auto fv : captured_vars) {
- fvs.push_back(fv);
- }
- // it is required by block_blocker, will be updated later
- UpdateStructInfo(global, GetStructInfo(recur_vars_.back()));
- lambda_map_.emplace(recur_vars_.back(), Call(global, fvs));
- } else {
- if (recur_vars_.size() > 0) {
- lambda_map_.emplace(recur_vars_.back(), global);
- }
- }
+ tvm::Array<Var> lifted_func_params =
+ func_node->params.Map([this](Var var) { return VisitVarDef(var); });
+ for (const auto& var : typed_captured_vars) {
+ lifted_func_params.push_back(var);
}
- tvm::Array<Var> params;
- bool all_params_unchanged = true;
- for (Var param : func_node->params) {
- Var new_param = this->VisitVarDef(param);
- params.push_back(new_param);
- all_params_unchanged &= param.same_as(new_param);
+ auto gvar_lifted_func = GlobalVar(lift_func_name);
+ {
+ auto func_sinfo = Downcast<FuncStructInfo>(func_node->struct_info_);
+ if (is_closure) {
+ func_sinfo = FuncStructInfo(lifted_func_params.Map(GetStructInfo),
func_sinfo->ret,
+ func_sinfo->purity);
+ }
+ UpdateStructInfo(gvar_lifted_func, func_sinfo);
}
- Expr body = this->VisitWithNewScope(func_node->body);
- Expr visited_func;
+ Expr body = func_node->body;
- if (all_params_unchanged && body.same_as(func_node->body)) {
- visited_func = GetRef<Expr>(func_node);
- } else if (const auto& body_sinfo =
MatchStructInfo<ObjectStructInfo>(body)) {
- visited_func =
- Function(params, body, body_sinfo.value(), func_node->is_pure,
func_node->attrs);
- } else {
- visited_func =
- Function(params, body, func_node->ret_struct_info,
func_node->is_pure, func_node->attrs);
+ // recursive call
+ if (is_recursive && is_closure) {
+ // it is required by block_blocker, will be updated later
+ nested_closure_map_.emplace(
+ current_lambda_var_.value(),
+ Call(gvar_lifted_func, captured_vars.Map([](Var var) -> Expr {
return var; })));
}
- auto new_func = Downcast<Function>(visited_func);
- Function lifted_func;
- bool is_closure = IsClosure(captured_vars);
if (!is_closure) {
- lifted_func = Function(
- /*params=*/new_func->params,
- /*body=*/new_func->body,
- /*ret_struct_info=*/new_func->ret_struct_info,
- /*is_pure=*/new_func->is_pure,
- /*attrs=*/new_func->attrs,
- /*span=*/new_func->span);
- } else {
- // Flatten the Closure
- std::vector<Var> closure_params;
- closure_params.reserve(func->params.size() + typed_captured_vars.size());
- for (size_t i = 0; i < func->params.size(); ++i) {
- closure_params.emplace_back(func->params[i]);
- }
- for (size_t i = 0; i < typed_captured_vars.size(); ++i) {
- closure_params.emplace_back(typed_captured_vars[i]);
- }
+ rebind_map_.emplace(current_lambda_var_.value(), gvar_lifted_func);
+ }
- lifted_func = Function(/*params=*/closure_params,
- /*body=*/Bind(new_func->body, rebinding_map),
- /*ret_struct_info=*/new_func->ret_struct_info,
- /*is_pure=*/new_func->is_pure,
- /*attrs=*/new_func->attrs,
- /*span=*/func->span);
+ body = this->VisitWithNewScope(body, lifted_func_params);
+ StructInfo ret_struct_info = GetStructInfo(body);
+ body = Bind(body, rebinding_map);
- for (Var param : closure_params) {
- CHECK(param->checked_type_.defined())
- << "relax.Function requires params to contain checked_type_";
- }
+ Function lifted_func;
+ if (lifted_func_params.same_as(func_node->params) &&
body.same_as(func_node->body) &&
+ ret_struct_info.same_as(func_node->ret_struct_info)) {
+ lifted_func = GetRef<Function>(func_node);
+ } else {
+ lifted_func =
+ Function(lifted_func_params, body, ret_struct_info,
func_node->is_pure, func_node->attrs);
+ }
+
+ for (Var param : lifted_func->params) {
+ CHECK(param->checked_type_.defined())
+ << "relax.Function requires all parameters to contain checked_type_.
"
+ << "However, parameter " << param << " with struct info " <<
param->struct_info_
+ << " has no checked type";
}
ICHECK(lifted_func.defined());
+ if (is_closure || IsClosure(lifted_func)) {
+ closures_.insert(gvar_lifted_func);
+ }
+
// Add the lifted function to the module.
- global->struct_info_ = GetStructInfo(lifted_func);
- global->checked_type_ = lifted_func->checked_type_;
- builder_->UpdateFunction(global, lifted_func);
+ lifted_func = CopyWithNewVars(lifted_func);
+ gvar_lifted_func->struct_info_ = GetStructInfo(lifted_func);
+ gvar_lifted_func->checked_type_ = lifted_func->checked_type_;
- if (!is_closure) {
- return std::move(global);
- } else {
+ builder_->UpdateFunction(gvar_lifted_func, lifted_func);
+
+ Expr callable_value = gvar_lifted_func;
+ if (is_closure) {
// If we need to allocate a closure,
// we pass the variables in its environment here.
- Array<Expr> fvs;
- for (auto fv : captured_vars) {
- fvs.push_back(fv);
- }
+ Tuple arg_tuple(captured_vars.Map([](Var var) -> Expr { return var; }));
// Call make_closure intrinsic
- return Call(make_closure_op_, {global, Tuple(fvs)}, {}, {});
+ callable_value = Call(make_closure_op_, {gvar_lifted_func, arg_tuple},
{}, {});
}
+
+ return callable_value;
}
- bool HasClosure(const Var& var) {
- auto val = builder_->LookupBinding(var);
- if (const auto* value = val.as<GlobalVarNode>()) {
- IRModule ctx_mod = builder_->GetContextIRModule();
- ICHECK(ctx_mod->functions.size() > 0);
- BaseFunc func = ctx_mod->Lookup(GetRef<GlobalVar>(value));
- if (const auto* func_node = func.as<FunctionNode>()) {
- if (const auto* call_node = func_node->body.as<CallNode>()) {
- if (call_node->op == make_closure_op_) {
- return true;
- }
- } else if (const auto* seq_expr_node =
func_node->body.as<SeqExprNode>()) {
- // the return var points to a make_closure intrinsic
- if (const auto* var = seq_expr_node->body.as<VarNode>()) {
- return HasClosure(GetRef<Var>(var));
+ Expr VisitExpr_(const CallNode* call_node) final {
+ auto call = GetRef<Call>(call_node);
+
+ auto orig_sinfo = Downcast<StructInfo>(call->struct_info_);
+
+ if (auto opt_var = call->op.as<Var>()) {
+ auto var = opt_var.value();
+
+ // Call "relax.invoke_closure" to invoke closure
+
+ if (bool is_closure = IsClosure(var);
+ is_closure && builder_->LookupBinding(var).as<CallNode>()) {
+ // if the original op was pure, we should use invoke_pure_closure
+ Call orig_call = Downcast<Call>(builder_->LookupBinding(var));
+ bool is_pure = [&]() -> bool {
+ if (auto op = orig_call->op.as<Op>()) {
+ static const auto& purity_map = Op::GetAttrMap<Bool>("FPurity");
+ return purity_map.get(op.value(), Bool(false))->value;
+ } else if (const auto* func_sinfo =
+ orig_call->op->struct_info_.as<FuncStructInfoNode>())
{
+ return func_sinfo->purity;
+ } else {
+ LOG(FATAL) << "Could not determine purity of call to " <<
orig_call->op
+ << ", as it is neither a tvm::Op (type = \"" <<
orig_call->op->GetTypeKey()
+ << "\"), "
+ << "nor is is annotated with FuncStructInfo (sinfo = "
+ << orig_call->op->struct_info_ << ")";
}
- }
+ }();
+
+ auto prev = call;
+ call = Call(is_pure ? invoke_pure_closure_op_ : invoke_closure_op_,
+ {var, Tuple(call->args)}, {}, {orig_sinfo});
}
- } else if (const auto* func_node = val.as<FunctionNode>()) {
- if (const auto* call_node = func_node->body.as<CallNode>()) {
- if (call_node->op == make_closure_op_) {
- return true;
+ }
+
+ if (auto opt_var = call->op.as<Var>()) {
+ auto var = opt_var.value();
+ if (auto it = nested_closure_map_.find(var); it !=
nested_closure_map_.end()) {
+ Call nested_call = it->second;
+
+ Array<relay::Expr> new_args = call->args;
+ for (const auto arg : nested_call->args) {
+ new_args.push_back(arg);
}
+
+ auto prev = call;
+ call = Call(nested_call->op, new_args, call->attrs, call->sinfo_args);
}
- } else if (const auto* call_node = val.as<relax::CallNode>()) {
+ }
+
+ return ExprMutator::VisitExpr_(call.get());
+ }
+
+ Expr VisitExpr_(const VarNode* op) override {
+ auto var = GetRef<Var>(op);
+ if (auto it = rebind_map_.find(var); it != rebind_map_.end()) {
+ return it->second;
+ }
+ return ExprMutator::VisitExpr_(op);
+ }
+
+ bool IsClosure(Expr val) {
+ static int depth = -1;
+ struct Context {
+ explicit Context(int* ptr) : ptr(ptr) { (*ptr)++; }
+ ~Context() { (*ptr)--; }
+ int* ptr;
+ } context(&depth);
+
+ if (auto opt_var = val.as<Var>()) {
+ if (closures_.count(opt_var.value())) {
+ return true;
+ }
+ if (auto bound_value = builder_->LookupBinding(opt_var.value())) {
+ val = bound_value.value();
+ }
+ }
+
+ if (const auto* call_node = val.as<relax::CallNode>()) {
// recursive call
auto op = call_node->op;
- if (make_closure_op_ == op) {
+ if (auto local_var = op.as<Var>()) {
+ return IsClosure(local_var.value());
+ } else if (auto global_var = op.as<GlobalVar>()) {
+ return IsClosure(global_var.value());
+ } else {
+ return make_closure_op_ == op;
+ }
+
+ } else if (const auto* global_var = val.as<GlobalVarNode>()) {
+ if (closures_.count(GetRef<GlobalVar>(global_var))) {
return true;
}
- if (const auto* lv = op.as<VarNode>()) {
- return HasClosure(GetRef<Var>(lv));
+ IRModule ctx_mod = builder_->GetContextIRModule();
+ ICHECK(ctx_mod->functions.size() > 0);
+ BaseFunc func = ctx_mod->Lookup(GetRef<GlobalVar>(global_var));
+ const auto* func_node = func.as<FunctionNode>();
+ if (func_node) {
+ return IsClosure(func_node->body);
+ } else {
+ return false;
}
+
+ } else if (const auto* func_node = val.as<FunctionNode>()) {
+ return IsClosure(func_node->body);
+
+ } else if (const auto* seq_node = val.as<SeqExprNode>()) {
+ return IsClosure(seq_node->body);
+
+ } else {
+ return false;
}
- return false;
}
- bool IsClosure(const Array<Var>& captured_vars) { return
captured_vars.size() > 0; }
-
IRModule Lift() {
auto glob_funcs = mod_->functions;
- for (auto pair : glob_funcs) {
- if (auto* n = pair.second.as<FunctionNode>()) {
- auto func = GetRef<Function>(n);
- func = Function(func->params, VisitExpr(func->body),
func->ret_struct_info, func->is_pure,
- func->attrs);
- builder_->UpdateFunction(pair.first, func);
+ for (auto [gvar, base_func] : glob_funcs) {
+ if (auto opt = base_func.as<Function>()) {
+ // Must visit the function itself, and not just the function
+ // body, to ensure that EraseToWellDefined recognized symbolic
+ // variables that are exposed by the function signature.
+ auto func = Downcast<Function>(VisitExpr(opt.value()));
+ builder_->UpdateFunction(gvar, func);
Review Comment:
Is this fundamentally the only change needed to handle symbolic vars? I ask
because I didn't see any logic specific to symbolic vars above.
##########
src/relax/transform/lambda_lift.cc:
##########
@@ -336,176 +266,235 @@ class LambdaLifter : public ExprMutator {
return it->second;
}();
- auto global = GlobalVar(lift_func_name);
- Array<Var> free_vars = FreeVars(func);
Array<Var> captured_vars;
-
- Array<Var> typed_captured_vars;
- bool recursive = false;
- for (const auto& var : free_vars) {
- if (!recur_vars_.empty() && var == recur_vars_.back()) {
- recursive = true;
+ bool is_recursive = false;
+ bool is_closure = false;
+ for (const auto& var : FreeVars(func)) {
+ if (var.same_as(current_lambda_var_)) {
+ is_recursive = true;
} else {
+ is_closure = true;
captured_vars.push_back(var);
}
}
+ Array<Var> typed_captured_vars;
Map<Var, Expr> rebinding_map;
for (auto free_var : captured_vars) {
Var var = Var(free_var->name_hint(), GetStructInfo(free_var),
free_var->span);
typed_captured_vars.push_back(var);
rebinding_map.Set(free_var, var);
}
- // recursive call
- if (recursive) {
- if (!captured_vars.empty()) {
- Array<Expr> fvs;
- for (auto fv : captured_vars) {
- fvs.push_back(fv);
- }
- // it is required by block_blocker, will be updated later
- UpdateStructInfo(global, GetStructInfo(recur_vars_.back()));
- lambda_map_.emplace(recur_vars_.back(), Call(global, fvs));
- } else {
- if (recur_vars_.size() > 0) {
- lambda_map_.emplace(recur_vars_.back(), global);
- }
- }
+ tvm::Array<Var> lifted_func_params =
+ func_node->params.Map([this](Var var) { return VisitVarDef(var); });
+ for (const auto& var : typed_captured_vars) {
+ lifted_func_params.push_back(var);
}
- tvm::Array<Var> params;
- bool all_params_unchanged = true;
- for (Var param : func_node->params) {
- Var new_param = this->VisitVarDef(param);
- params.push_back(new_param);
- all_params_unchanged &= param.same_as(new_param);
+ auto gvar_lifted_func = GlobalVar(lift_func_name);
+ {
+ auto func_sinfo = Downcast<FuncStructInfo>(func_node->struct_info_);
+ if (is_closure) {
+ func_sinfo = FuncStructInfo(lifted_func_params.Map(GetStructInfo),
func_sinfo->ret,
+ func_sinfo->purity);
+ }
+ UpdateStructInfo(gvar_lifted_func, func_sinfo);
}
- Expr body = this->VisitWithNewScope(func_node->body);
- Expr visited_func;
+ Expr body = func_node->body;
- if (all_params_unchanged && body.same_as(func_node->body)) {
- visited_func = GetRef<Expr>(func_node);
- } else if (const auto& body_sinfo =
MatchStructInfo<ObjectStructInfo>(body)) {
- visited_func =
- Function(params, body, body_sinfo.value(), func_node->is_pure,
func_node->attrs);
- } else {
- visited_func =
- Function(params, body, func_node->ret_struct_info,
func_node->is_pure, func_node->attrs);
+ // recursive call
+ if (is_recursive && is_closure) {
+ // it is required by block_blocker, will be updated later
+ nested_closure_map_.emplace(
+ current_lambda_var_.value(),
+ Call(gvar_lifted_func, captured_vars.Map([](Var var) -> Expr {
return var; })));
}
- auto new_func = Downcast<Function>(visited_func);
- Function lifted_func;
- bool is_closure = IsClosure(captured_vars);
if (!is_closure) {
- lifted_func = Function(
- /*params=*/new_func->params,
- /*body=*/new_func->body,
- /*ret_struct_info=*/new_func->ret_struct_info,
- /*is_pure=*/new_func->is_pure,
- /*attrs=*/new_func->attrs,
- /*span=*/new_func->span);
- } else {
- // Flatten the Closure
- std::vector<Var> closure_params;
- closure_params.reserve(func->params.size() + typed_captured_vars.size());
- for (size_t i = 0; i < func->params.size(); ++i) {
- closure_params.emplace_back(func->params[i]);
- }
- for (size_t i = 0; i < typed_captured_vars.size(); ++i) {
- closure_params.emplace_back(typed_captured_vars[i]);
- }
+ rebind_map_.emplace(current_lambda_var_.value(), gvar_lifted_func);
+ }
- lifted_func = Function(/*params=*/closure_params,
- /*body=*/Bind(new_func->body, rebinding_map),
- /*ret_struct_info=*/new_func->ret_struct_info,
- /*is_pure=*/new_func->is_pure,
- /*attrs=*/new_func->attrs,
- /*span=*/func->span);
+ body = this->VisitWithNewScope(body, lifted_func_params);
+ StructInfo ret_struct_info = GetStructInfo(body);
+ body = Bind(body, rebinding_map);
- for (Var param : closure_params) {
- CHECK(param->checked_type_.defined())
- << "relax.Function requires params to contain checked_type_";
- }
+ Function lifted_func;
+ if (lifted_func_params.same_as(func_node->params) &&
body.same_as(func_node->body) &&
+ ret_struct_info.same_as(func_node->ret_struct_info)) {
+ lifted_func = GetRef<Function>(func_node);
+ } else {
+ lifted_func =
+ Function(lifted_func_params, body, ret_struct_info,
func_node->is_pure, func_node->attrs);
+ }
+
+ for (Var param : lifted_func->params) {
+ CHECK(param->checked_type_.defined())
+ << "relax.Function requires all parameters to contain checked_type_.
"
+ << "However, parameter " << param << " with struct info " <<
param->struct_info_
+ << " has no checked type";
}
ICHECK(lifted_func.defined());
+ if (is_closure || IsClosure(lifted_func)) {
+ closures_.insert(gvar_lifted_func);
+ }
+
// Add the lifted function to the module.
- global->struct_info_ = GetStructInfo(lifted_func);
- global->checked_type_ = lifted_func->checked_type_;
- builder_->UpdateFunction(global, lifted_func);
+ lifted_func = CopyWithNewVars(lifted_func);
+ gvar_lifted_func->struct_info_ = GetStructInfo(lifted_func);
+ gvar_lifted_func->checked_type_ = lifted_func->checked_type_;
- if (!is_closure) {
- return std::move(global);
- } else {
+ builder_->UpdateFunction(gvar_lifted_func, lifted_func);
+
+ Expr callable_value = gvar_lifted_func;
+ if (is_closure) {
// If we need to allocate a closure,
// we pass the variables in its environment here.
- Array<Expr> fvs;
- for (auto fv : captured_vars) {
- fvs.push_back(fv);
- }
+ Tuple arg_tuple(captured_vars.Map([](Var var) -> Expr { return var; }));
// Call make_closure intrinsic
- return Call(make_closure_op_, {global, Tuple(fvs)}, {}, {});
+ callable_value = Call(make_closure_op_, {gvar_lifted_func, arg_tuple},
{}, {});
}
+
+ return callable_value;
}
- bool HasClosure(const Var& var) {
- auto val = builder_->LookupBinding(var);
- if (const auto* value = val.as<GlobalVarNode>()) {
- IRModule ctx_mod = builder_->GetContextIRModule();
- ICHECK(ctx_mod->functions.size() > 0);
- BaseFunc func = ctx_mod->Lookup(GetRef<GlobalVar>(value));
- if (const auto* func_node = func.as<FunctionNode>()) {
- if (const auto* call_node = func_node->body.as<CallNode>()) {
- if (call_node->op == make_closure_op_) {
- return true;
- }
- } else if (const auto* seq_expr_node =
func_node->body.as<SeqExprNode>()) {
- // the return var points to a make_closure intrinsic
- if (const auto* var = seq_expr_node->body.as<VarNode>()) {
- return HasClosure(GetRef<Var>(var));
+ Expr VisitExpr_(const CallNode* call_node) final {
+ auto call = GetRef<Call>(call_node);
+
+ auto orig_sinfo = Downcast<StructInfo>(call->struct_info_);
+
+ if (auto opt_var = call->op.as<Var>()) {
+ auto var = opt_var.value();
+
+ // Call "relax.invoke_closure" to invoke closure
+
+ if (bool is_closure = IsClosure(var);
+ is_closure && builder_->LookupBinding(var).as<CallNode>()) {
+ // if the original op was pure, we should use invoke_pure_closure
+ Call orig_call = Downcast<Call>(builder_->LookupBinding(var));
+ bool is_pure = [&]() -> bool {
+ if (auto op = orig_call->op.as<Op>()) {
+ static const auto& purity_map = Op::GetAttrMap<Bool>("FPurity");
+ return purity_map.get(op.value(), Bool(false))->value;
+ } else if (const auto* func_sinfo =
+ orig_call->op->struct_info_.as<FuncStructInfoNode>())
{
+ return func_sinfo->purity;
+ } else {
+ LOG(FATAL) << "Could not determine purity of call to " <<
orig_call->op
+ << ", as it is neither a tvm::Op (type = \"" <<
orig_call->op->GetTypeKey()
+ << "\"), "
+ << "nor is is annotated with FuncStructInfo (sinfo = "
+ << orig_call->op->struct_info_ << ")";
}
- }
+ }();
+
+ auto prev = call;
+ call = Call(is_pure ? invoke_pure_closure_op_ : invoke_closure_op_,
+ {var, Tuple(call->args)}, {}, {orig_sinfo});
}
- } else if (const auto* func_node = val.as<FunctionNode>()) {
- if (const auto* call_node = func_node->body.as<CallNode>()) {
- if (call_node->op == make_closure_op_) {
- return true;
+ }
+
+ if (auto opt_var = call->op.as<Var>()) {
+ auto var = opt_var.value();
+ if (auto it = nested_closure_map_.find(var); it !=
nested_closure_map_.end()) {
+ Call nested_call = it->second;
+
+ Array<relay::Expr> new_args = call->args;
+ for (const auto arg : nested_call->args) {
+ new_args.push_back(arg);
}
+
+ auto prev = call;
+ call = Call(nested_call->op, new_args, call->attrs, call->sinfo_args);
}
- } else if (const auto* call_node = val.as<relax::CallNode>()) {
+ }
+
+ return ExprMutator::VisitExpr_(call.get());
+ }
+
+ Expr VisitExpr_(const VarNode* op) override {
+ auto var = GetRef<Var>(op);
+ if (auto it = rebind_map_.find(var); it != rebind_map_.end()) {
+ return it->second;
+ }
+ return ExprMutator::VisitExpr_(op);
+ }
+
+ bool IsClosure(Expr val) {
+ static int depth = -1;
+ struct Context {
+ explicit Context(int* ptr) : ptr(ptr) { (*ptr)++; }
+ ~Context() { (*ptr)--; }
+ int* ptr;
+ } context(&depth);
+
+ if (auto opt_var = val.as<Var>()) {
+ if (closures_.count(opt_var.value())) {
+ return true;
+ }
+ if (auto bound_value = builder_->LookupBinding(opt_var.value())) {
+ val = bound_value.value();
+ }
+ }
+
+ if (const auto* call_node = val.as<relax::CallNode>()) {
// recursive call
auto op = call_node->op;
- if (make_closure_op_ == op) {
+ if (auto local_var = op.as<Var>()) {
+ return IsClosure(local_var.value());
+ } else if (auto global_var = op.as<GlobalVar>()) {
+ return IsClosure(global_var.value());
+ } else {
+ return make_closure_op_ == op;
+ }
+
+ } else if (const auto* global_var = val.as<GlobalVarNode>()) {
+ if (closures_.count(GetRef<GlobalVar>(global_var))) {
return true;
}
- if (const auto* lv = op.as<VarNode>()) {
- return HasClosure(GetRef<Var>(lv));
+ IRModule ctx_mod = builder_->GetContextIRModule();
+ ICHECK(ctx_mod->functions.size() > 0);
+ BaseFunc func = ctx_mod->Lookup(GetRef<GlobalVar>(global_var));
+ const auto* func_node = func.as<FunctionNode>();
+ if (func_node) {
+ return IsClosure(func_node->body);
+ } else {
+ return false;
}
+
+ } else if (const auto* func_node = val.as<FunctionNode>()) {
+ return IsClosure(func_node->body);
+
+ } else if (const auto* seq_node = val.as<SeqExprNode>()) {
+ return IsClosure(seq_node->body);
+
+ } else {
+ return false;
}
- return false;
}
- bool IsClosure(const Array<Var>& captured_vars) { return
captured_vars.size() > 0; }
-
IRModule Lift() {
auto glob_funcs = mod_->functions;
- for (auto pair : glob_funcs) {
- if (auto* n = pair.second.as<FunctionNode>()) {
- auto func = GetRef<Function>(n);
- func = Function(func->params, VisitExpr(func->body),
func->ret_struct_info, func->is_pure,
- func->attrs);
- builder_->UpdateFunction(pair.first, func);
+ for (auto [gvar, base_func] : glob_funcs) {
+ if (auto opt = base_func.as<Function>()) {
Review Comment:
Cool, I didn't know about using `as` for `ObjectRef`s directly :+1:
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]