yongwww commented on code in PR #16306:
URL: https://github.com/apache/tvm/pull/16306#discussion_r1446380538
##########
src/relax/transform/lambda_lift.cc:
##########
@@ -34,6 +34,195 @@
namespace tvm {
namespace relax {
+namespace {
+
+/* \brief Collect names of functions to be lifted out */
+class LambdaNameCollector : ExprVisitor {
+ public:
+  /* \brief Determine the name each lifted lambda in `mod` will receive
+   *
+   * \param mod The module whose Relax functions are scanned for lambdas.
+   * \return A map from each collected lambda to its de-duplicated name.
+   */
+  static std::unordered_map<const FunctionNode*, String> Collect(const IRModule& mod) {
+    LambdaNameCollector collector;
+
+    // Every name already used by a GlobalVar is reserved.  This must be
+    // complete before any function body is visited, so that the
+    // kGlobalSymbol conflict check sees all existing names.
+    for (const auto& kv : mod->functions) {
+      collector.previous_global_vars_.insert(kv.first->name_hint);
+    }
+
+    // Visit each Relax function, with its own name as the outermost
+    // entry of the location stack while its body is traversed.
+    for (const auto& kv : mod->functions) {
+      auto opt_func = kv.second.as<Function>();
+      if (!opt_func) {
+        // Not a Relax function (e.g. a PrimFunc); nothing to lift.
+        continue;
+      }
+      collector.name_stack_.push_back(kv.first->name_hint);
+      collector(opt_func.value());
+      collector.name_stack_.pop_back();
+    }
+
+    return collector.Finalize();
+  }
+
+ private:
+  /* \brief Record a nested function that is bound to a variable
+   *
+   * A lambda carrying kGlobalSymbol must be lifted under exactly that
+   * name, so the name is checked against the module's existing
+   * GlobalVar names and against other lambdas' public names.  Every
+   * visited lambda then has its nesting location (the enclosing
+   * function and binding names) recorded for later name generation.
+   */
+  void VisitBinding_(const VarBindingNode* binding, const FunctionNode* func) override {
+    if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
+      String public_name = opt.value();
+
+      // If a kGlobalSymbol exists, we must use the name exactly as it
+      // appears, with no modifications.  Because these errors would
+      // be raised from deep within an optimization pipeline, but
+      // depend on small annotation changes from a user's initial
+      // model definition, they are intentionally verbose to
+      // (hopefully) provide sufficient context to a user encountering
+      // the error.
+      CHECK(!previous_global_vars_.count(public_name))
+          << "Function " << name_stack_.front() << " contains a lambda with kGlobalSymbol (\""
+          << tvm::attr::kGlobalSymbol << "\" attribute) of \"" << public_name << "\". "
+          << "However, the module already contains a GlobalVar with this name. "
+          << "If present, the kGlobalSymbol attribute must match the name of the GlobalVar, "
+          << "and GlobalVar names must be unique across an IRModule. "
+          << "Lifting the " << public_name << " function out of " << name_stack_.front()
+          << " would require violating one of these two conditions.";
+
+      // The message below dereferences `it`; it is only evaluated when the
+      // CHECK fails, i.e. when `it` refers to an existing entry.
+      auto it = new_public_names_.find(public_name);
+      CHECK(it == new_public_names_.end())
+          << "Function " << name_stack_.front() << " contains a lambda with kGlobalSymbol (\""
+          << tvm::attr::kGlobalSymbol << "\" attribute) of \"" << public_name << "\". "
+          << "However, the function " << it->second.front()
+          << " also contains a lambda with the same value for kGlobalSymbol. "
+          << "If present, the kGlobalSymbol attribute must match the name of the GlobalVar, "
+          << "and GlobalVar names must be unique across an IRModule. "
+          << "Lifting the " << public_name << " function out of both " << name_stack_.front()
+          << " and " << it->second.front()
+          << " would require violating one of these two conditions.";
+
+      new_public_names_.insert({public_name, name_stack_});
+      lifted_with_global_symbol_.insert({func, public_name});
+    }
+
+    // The lambda's location is the enclosing names plus the variable it is bound to.
+    name_stack_.push_back(binding->var->name_hint());
+    lambda_location_.insert({func, name_stack_});
+    ExprVisitor::VisitBinding_(binding, func);
+    name_stack_.pop_back();
+  }
+
+ // De-duplication of collected names
+ std::unordered_map<const FunctionNode*, String> Finalize() const {
+ // The functions which still must be assigned a name
+ std::unordered_map<const FunctionNode*, Array<String>> remaining_to_name =
lambda_location_;
+
+ // Collecting the functions that now have a name.
+ std::unordered_map<const FunctionNode*, String> lifted_names;
+
+ // A lookup for names that are unavailable for use.
+ std::unordered_set<String> unavailable_names = previous_global_vars_;
+
+ // A helper function to generate de-duplicated names. The
+ // `proposed_name_generation_func` should be a function with
+ // signature:
+ //
+ // Optional<String> func(const FunctionNode*, const Array<String>&)
+ //
+ // The first argument will be the lambda function being lifted.
+ // The second argument will be the nested location where that
+ // lambda function was found. The function should return the
+ // proposed name for the lifted lambda function. The proposed
+ // name will be accepted if it does not conflict with any previous
+ // names, and is unique for all lambda functions being lifted.
+ //
+ // This helper function is used to apply several different schemes
+ // to generate the name of the lifted lambda function. The
+ // overall goal is to provide names that are unique (required by
+ // IRModule), deterministic (required for unit testing), and
+ // human-readable.
+ auto attempt_name_generation = [&](const auto&
proposed_name_generation_func) {
+ if (remaining_to_name.empty()) {
+ return;
+ }
+
+ std::unordered_map<String, const FunctionNode*> new_names;
+ for (const auto& [func, location] : remaining_to_name) {
+ if (Optional<String> opt_proposed_name =
proposed_name_generation_func(func, location)) {
+ auto proposed_name = opt_proposed_name.value();
+
+ if (unavailable_names.count(proposed_name)) {
Review Comment:
The body of this `if` statement is empty, so it can probably be merged into
the following `else if`/`else` branches.
##########
tests/python/relax/test_transform_lambda_lift.py:
##########
@@ -205,34 +211,40 @@ def while_loop(
def test_multi_func():
+ """Lifting may be required for multiple top-level functions
+
+ De-duplication of GlobalVar names at the IRModule is done by
+ appending the name of the function from which they were lifted.
+ """
+
# expected IRModule
@tvm.script.ir_module
class Expected:
@R.function
def glob_func_1(
x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
) -> R.Tensor(None, "float32", ndim=2):
- inner = Expected.lifted_func_0
+ inner = Expected.glob_func_1_inner
gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
return gv1
@R.function
def glob_func_2(
x11: R.Tensor((10, 5), "float32"), y11: R.Tensor((10, 5),
"float32")
) -> R.Tensor(None, "float32", ndim=2):
- inner = Expected.lifted_func_1
+ inner = Expected.glob_func_2_inner
gv11: R.Tensor((10, 5), "float32") = inner(x11, y11)
return gv11
@R.function(private=True)
- def lifted_func_0(
+ def glob_func_1_inner(
x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
) -> R.Tensor((10, 5), "float32"):
s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
return s
@R.function(private=True)
- def lifted_func_1(
+ def glob_func_2_inner(
Review Comment:
How about adding a test case that covers de-duplication of lifted names that
could collide, such as:
```
@ir_module
class Before:
@R.function
def foo_foo(x1: R.Tensor):
@R.function
def foo(x2: R.Tensor):
...
@R.function
def foo(x1: R.Tensor):
@R.function
def foo_foo(x2: R.Tensor):
...
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]