yongwww commented on code in PR #16306:
URL: https://github.com/apache/tvm/pull/16306#discussion_r1446380538


##########
src/relax/transform/lambda_lift.cc:
##########
@@ -34,6 +34,195 @@
 namespace tvm {
 namespace relax {
 
+namespace {
+
+/* \brief Collect names of functions to be lifted out */
+class LambdaNameCollector : ExprVisitor {
+ public:
+  static std::unordered_map<const FunctionNode*, String> Collect(const 
IRModule& mod) {
+    LambdaNameCollector visitor;
+
+    for (const auto& [gvar, base_func] : mod->functions) {
+      visitor.previous_global_vars_.insert(gvar->name_hint);
+    }
+
+    for (const auto& [gvar, base_func] : mod->functions) {
+      if (auto func = base_func.as<Function>()) {
+        visitor.name_stack_.push_back(gvar->name_hint);
+        visitor(func.value());
+        visitor.name_stack_.pop_back();
+      }
+    }
+
+    return visitor.Finalize();
+  }
+
+ private:
+  void VisitBinding_(const VarBindingNode* binding, const FunctionNode* func) 
override {
+    if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
+      String public_name = opt.value();
+
+      // If a kGlobalSymbol exists, we must use the name exactly as it
+      // appears, with no modifications.  Because these errors would
+      // be raised from deep within an optimization pipeline, but
+      // depends on small annotation changes from a user's initial
+      // model definition, they are intentionally verbose to
+      // (hopefully) provide sufficient context to a user encountering
+      // the error.
+      CHECK(!previous_global_vars_.count(public_name))
+          << "Function " << name_stack_.front() << " contains a lambda with 
kGlobalSymbol (\""
+          << tvm::attr::kGlobalSymbol << "\" attribute of \"" << public_name 
<< "\".  "
+          << "However, the module already contains a GlobalVar with this name. 
 "
+          << "If present, the kGlobalSymbol attribute must match the name of 
the GlobalVar, "
+          << "and GlobalVar names must be unique across an IRModule.  "
+          << "Lifting the " << public_name << " function out of " << 
name_stack_.front()
+          << " would require violating one of these two conditions.";
+
+      auto it = new_public_names_.find(public_name);
+      CHECK(it == new_public_names_.end())
+          << "Function " << name_stack_.front() << " contains a lambda with 
kGlobalSymbol (\""
+          << tvm::attr::kGlobalSymbol << "\" attribute of \"" << public_name 
<< "\".  "
+          << "However, the function " << it->second.front()
+          << " also contains a lambda with the same value for kGlobalSymbol.  "
+          << "If present, the kGlobalSymbol attribute must match the name of 
the GlobalVar, "
+          << "and GlobalVar names must be unique across an IRModule.  "
+          << "Lifting the " << public_name << " function out of both " << 
name_stack_.front()
+          << " and " << it->second.front()
+          << " would require violating one of these two conditions.";
+
+      new_public_names_.insert({public_name, name_stack_});
+      lifted_with_global_symbol_.insert({func, public_name});
+    }
+
+    name_stack_.push_back(binding->var->name_hint());
+    lambda_location_.insert({func, name_stack_});
+    ExprVisitor::VisitBinding_(binding, func);
+    name_stack_.pop_back();
+  }
+
+  // De-duplication of collected names
+  std::unordered_map<const FunctionNode*, String> Finalize() const {
+    // The functions which still must be assigned a name
+    std::unordered_map<const FunctionNode*, Array<String>> remaining_to_name = 
lambda_location_;
+
+    // Collecting the functions that now have a name.
+    std::unordered_map<const FunctionNode*, String> lifted_names;
+
+    // A lookup for names that are unavailable for use.
+    std::unordered_set<String> unavailable_names = previous_global_vars_;
+
+    // A helper function to generate de-duplicated names.  The
+    // `proposed_name_generation_func` should be a function with
+    // signature:
+    //
+    //     Optional<String> func(const FunctionNode*, const Array<String>&)
+    //
+    // The first argument will be the lambda function being lifted.
+    // The second argument will be the nested location where that
+    // lambda function was found.  The function should return the
+    // proposed name for the lifted lambda function.  The proposed
+    // name will be accepted if it does not conflict with any previous
+    // names, and is unique for all lambda functions being lifted.
+    //
+    // This helper function is used to apply several different schemes
+    // to generate the name of the lifted lambda function.  The
+    // overall goal is to provide names that are unique (required by
+    // IRModule), deterministic (required for unit testing), and
+    // human-readable.
+    auto attempt_name_generation = [&](const auto& 
proposed_name_generation_func) {
+      if (remaining_to_name.empty()) {
+        return;
+      }
+
+      std::unordered_map<String, const FunctionNode*> new_names;
+      for (const auto& [func, location] : remaining_to_name) {
+        if (Optional<String> opt_proposed_name = 
proposed_name_generation_func(func, location)) {
+          auto proposed_name = opt_proposed_name.value();
+
+          if (unavailable_names.count(proposed_name)) {

Review Comment:
   The body of the `if` statement is empty; we could probably merge its condition into the
following `else if` / `else` branches.



##########
tests/python/relax/test_transform_lambda_lift.py:
##########
@@ -205,34 +211,40 @@ def while_loop(
 
 
 def test_multi_func():
+    """Lifting may be required for multiple top-level functions
+
+    De-duplication of GlobalVar names at the IRModule is done by
+    appending the name of the function from which they were lifted.
+    """
+
     # expected IRModule
     @tvm.script.ir_module
     class Expected:
         @R.function
         def glob_func_1(
             x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
         ) -> R.Tensor(None, "float32", ndim=2):
-            inner = Expected.lifted_func_0
+            inner = Expected.glob_func_1_inner
             gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
             return gv1
 
         @R.function
         def glob_func_2(
             x11: R.Tensor((10, 5), "float32"), y11: R.Tensor((10, 5), 
"float32")
         ) -> R.Tensor(None, "float32", ndim=2):
-            inner = Expected.lifted_func_1
+            inner = Expected.glob_func_2_inner
             gv11: R.Tensor((10, 5), "float32") = inner(x11, y11)
             return gv11
 
         @R.function(private=True)
-        def lifted_func_0(
+        def glob_func_1_inner(
             x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
         ) -> R.Tensor((10, 5), "float32"):
             s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
             return s
 
         @R.function(private=True)
-        def lifted_func_1(
+        def glob_func_2_inner(

Review Comment:
   How about adding a test case to cover de-duplication of potentially colliding lifted
names, for example:
   
   ```
   @ir_module
   class Before:
       @R.function
       def foo_foo(x1: R.Tensor):
           @R.function
           def foo(x2: R.Tensor):
               ...
   
       @R.function
       def foo(x1: R.Tensor):
           @R.function
           def foo_foo(x2: R.Tensor):
               ...
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to