Lunderberg commented on code in PR #16306:
URL: https://github.com/apache/tvm/pull/16306#discussion_r1446428657
##########
src/relax/transform/lambda_lift.cc:
##########
@@ -34,6 +34,195 @@
namespace tvm {
namespace relax {
+namespace {
+
+/* \brief Collect names of functions to be lifted out */
+class LambdaNameCollector : ExprVisitor {
+ public:
+ static std::unordered_map<const FunctionNode*, String> Collect(const
IRModule& mod) {
+ LambdaNameCollector visitor;
+
+ for (const auto& [gvar, base_func] : mod->functions) {
+ visitor.previous_global_vars_.insert(gvar->name_hint);
+ }
+
+ for (const auto& [gvar, base_func] : mod->functions) {
+ if (auto func = base_func.as<Function>()) {
+ visitor.name_stack_.push_back(gvar->name_hint);
+ visitor(func.value());
+ visitor.name_stack_.pop_back();
+ }
+ }
+
+ return visitor.Finalize();
+ }
+
+ private:
+ void VisitBinding_(const VarBindingNode* binding, const FunctionNode* func)
override {
+ if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
+ String public_name = opt.value();
+
+ // If a kGlobalSymbol exists, we must use the name exactly as it
+ // appears, with no modifications. Because these errors would
+ // be raised from deep within an optimization pipeline, but
+ // depends on small annotation changes from a user's initial
+ // model definition, they are intentionally verbose to
+ // (hopefully) provide sufficient context to a user encountering
+ // the error.
+ CHECK(!previous_global_vars_.count(public_name))
+ << "Function " << name_stack_.front() << " contains a lambda with
kGlobalSymbol (\""
+ << tvm::attr::kGlobalSymbol << "\" attribute of \"" << public_name
<< "\". "
+ << "However, the module already contains a GlobalVar with this name.
"
+ << "If present, the kGlobalSymbol attribute must match the name of
the GlobalVar, "
+ << "and GlobalVar names must be unique across an IRModule. "
+ << "Lifting the " << public_name << " function out of " <<
name_stack_.front()
+ << " would require violating one of these two conditions.";
+
+ auto it = new_public_names_.find(public_name);
+ CHECK(it == new_public_names_.end())
+ << "Function " << name_stack_.front() << " contains a lambda with
kGlobalSymbol (\""
+ << tvm::attr::kGlobalSymbol << "\" attribute of \"" << public_name
<< "\". "
+ << "However, the function " << it->second.front()
+ << " also contains a lambda with the same value for kGlobalSymbol. "
+ << "If present, the kGlobalSymbol attribute must match the name of
the GlobalVar, "
+ << "and GlobalVar names must be unique across an IRModule. "
+ << "Lifting the " << public_name << " function out of both " <<
name_stack_.front()
+ << " and " << it->second.front()
+ << " would require violating one of these two conditions.";
+
+ new_public_names_.insert({public_name, name_stack_});
+ lifted_with_global_symbol_.insert({func, public_name});
+ }
+
+ name_stack_.push_back(binding->var->name_hint());
+ lambda_location_.insert({func, name_stack_});
+ ExprVisitor::VisitBinding_(binding, func);
+ name_stack_.pop_back();
+ }
+
+ // De-duplication of collected names
+ std::unordered_map<const FunctionNode*, String> Finalize() const {
+ // The functions which still must be assigned a name
+ std::unordered_map<const FunctionNode*, Array<String>> remaining_to_name =
lambda_location_;
+
+ // Collecting the functions that now have a name.
+ std::unordered_map<const FunctionNode*, String> lifted_names;
+
+ // A lookup for names that are unavailable for use.
+ std::unordered_set<String> unavailable_names = previous_global_vars_;
+
+ // A helper function to generate de-duplicated names. The
+ // `proposed_name_generation_func` should be a function with
+ // signature:
+ //
+ // Optional<String> func(const FunctionNode*, const Array<String>&)
+ //
+ // The first argument will be the lambda function being lifted.
+ // The second argument will be the nested location where that
+ // lambda function was found. The function should return the
+ // proposed name for the lifted lambda function. The proposed
+ // name will be accepted if it does not conflict with any previous
+ // names, and is unique for all lambda functions being lifted.
+ //
+ // This helper function is used to apply several different schemes
+ // to generate the name of the lifted lambda function. The
+ // overall goal is to provide names that are unique (required by
+ // IRModule), deterministic (required for unit testing), and
+ // human-readable.
+ auto attempt_name_generation = [&](const auto&
proposed_name_generation_func) {
+ if (remaining_to_name.empty()) {
+ return;
+ }
+
+ std::unordered_map<String, const FunctionNode*> new_names;
+ for (const auto& [func, location] : remaining_to_name) {
+ if (Optional<String> opt_proposed_name =
proposed_name_generation_func(func, location)) {
+ auto proposed_name = opt_proposed_name.value();
+
+ if (unavailable_names.count(proposed_name)) {
Review Comment:
Yeah, I went back and forth on it. This check could be merged into the following
`else if`/`else` chain, but that would require repeating (and re-evaluating) the
condition in the other two branches. Alternatively, it could be inverted to
`if (!unavailable_names.count(proposed_name))` and wrapped around the other two
cases, but this utility is already nested deeply enough to hurt readability.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]