slyubomirsky commented on code in PR #16306:
URL: https://github.com/apache/tvm/pull/16306#discussion_r1445596559
##########
src/relax/transform/lambda_lift.cc:
##########
@@ -34,6 +34,178 @@
namespace tvm {
namespace relax {
+namespace {
+
+/* \brief Collect names of functions to be lifted out */
+class LambdaNameCollector : ExprVisitor {
+ public:
+ static std::unordered_map<const FunctionNode*, String> Collect(const
IRModule& mod) {
+ LambdaNameCollector visitor;
+
+ for (const auto& [gvar, base_func] : mod->functions) {
+ visitor.previous_global_vars_.insert(gvar->name_hint);
+ }
+
+ for (const auto& [gvar, base_func] : mod->functions) {
+ if (auto func = base_func.as<Function>()) {
+ visitor.name_stack_.push_back(gvar->name_hint);
+ visitor(func.value());
+ visitor.name_stack_.pop_back();
+ }
+ }
+
+ return visitor.Finalize();
+ }
+
+ private:
+ void VisitBinding_(const VarBindingNode* binding, const FunctionNode* func)
override {
+ if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
+ String public_name = opt.value();
+
+ // If a kGlobalSymbol exists, we must use the name exactly as it
+ // appears, with no modifications. Because these errors would
+ // be raised from deep within an optimization pipeline, but
+ // depends on small annotation changes from a user's initial
+ // model definition, they are intentionally verbose to
+ // (hopefully) provide sufficient context to a user encountering
+ // the error.
+ CHECK(!previous_global_vars_.count(public_name))
+ << "Function " << name_stack_.front() << " contains a lambda with
kGlobalSymbol (\""
+ << tvm::attr::kGlobalSymbol << "\" attribute of \"" << public_name
<< "\". "
+ << "However, the module already contains a GlobalVar with this name.
"
+ << "If present, the kGlobalSymbol attribute must match the name of
the GlobalVar, "
+ << "and GlobalVar names must be unique across an IRModule. "
+ << "Lifting the " << public_name << " function out of " <<
name_stack_.front()
+ << " would require violating one of these two conditions.";
+
+ auto it = new_public_names_.find(public_name);
+ CHECK(it == new_public_names_.end())
+ << "Function " << name_stack_.front() << " contains a lambda with
kGlobalSymbol (\""
+ << tvm::attr::kGlobalSymbol << "\" attribute of \"" << public_name
<< "\". "
+ << "However, the function " << it->second.front()
+ << " also contains a lambda with the same value for kGlobalSymbol. "
+ << "If present, the kGlobalSymbol attribute must match the name of
the GlobalVar, "
+ << "and GlobalVar names must be unique across an IRModule. "
+ << "Lifting the " << public_name << " function out of both " <<
name_stack_.front()
+ << " and " << it->second.front()
+ << " would require violating one of these two conditions.";
+
+ new_public_names_.insert({public_name, name_stack_});
+ lifted_with_global_symbol_.insert({func, public_name});
+ }
+
+ name_stack_.push_back(binding->var->name_hint());
+ lambda_location_.insert({func, name_stack_});
+ ExprVisitor::VisitBinding_(binding, func);
+ name_stack_.pop_back();
+ }
+
+ // De-duplication of collected names
+ std::unordered_map<const FunctionNode*, String> Finalize() const {
+ // The functions which still must be assigned a name
+ std::unordered_map<const FunctionNode*, Array<String>> remaining_to_name =
lambda_location_;
+
+ // Collecting the functions that now have a name.
+ std::unordered_map<const FunctionNode*, String> lifted_names;
+
+ // A lookup for names that are unavailable for use.
+ std::unordered_set<String> unavailable_names = previous_global_vars_;
+
+ // A helper function to check de-duplicate names
+ auto use_if_unique = [&](const auto& generate_proposed_name) {
Review Comment:
This might be a bit of a nitpick, but I don't think `use_if_unique` is
necessarily a fully descriptive name. What is important to convey is that it
takes a scheme for generating names and applies it to all remaining lambdas. It
had me scratching my head until I looked through the implementation. It's not
strictly necessary to change, but a more descriptive name might make this
easier to understand at a glance.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]