gemini-code-assist[bot] commented on code in PR #19612:
URL: https://github.com/apache/tvm/pull/19612#discussion_r3304784882
##########
src/relax/transform/decompose_ops.cc:
##########
@@ -212,6 +215,90 @@ class OpDecomposer : public ExprMutator {
namespace transform {
+namespace {
+
+/*! \brief Helper: add or remove an attribute on a BaseFunc */
+BaseFunc BaseFuncWithAttr(BaseFunc func, const std::string& attr_key, Any
attr_value) {
+ if (auto tirx = func.as<tirx::PrimFunc>()) {
+ return WithAttr(tirx.value(), attr_key, attr_value);
+ } else if (auto relax_fn = func.as<relax::Function>()) {
+ return WithAttr(relax_fn.value(), attr_key, attr_value);
+ } else {
+ return func;
+ }
+}
+
+BaseFunc BaseFuncWithoutAttr(BaseFunc func, const std::string& attr_key) {
+ if (auto tirx = func.as<tirx::PrimFunc>()) {
+ return WithoutAttr(tirx.value(), attr_key);
+ } else if (auto relax_fn = func.as<relax::Function>()) {
+ return WithoutAttr(relax_fn.value(), attr_key);
+ } else {
+ return func;
+ }
+}
+
+/*!
+ * \brief Apply a pass to a single named function within an IRModule.
+ *
+ * Replaces all other functions with dummy ExternFunc stubs so that the
+ * pass does not see them, then restores the original module. Uses
+ * exact name match (not a regex) because all in-tree callers supply a
+ * literal function name.
+ */
+Pass ApplyDecomposeToFunction(Pass pass, ffi::String func_name) {
+ auto pass_func = [pass, func_name](IRModule mod, PassContext) -> IRModule {
+ std::unordered_set<ffi::String> keep_original_version;
+ std::unordered_set<ffi::String> internal_functions;
+ IRModule subset;
+
+ for (auto [gvar, func] : mod->functions) {
+ std::string name = gvar->name_hint;
+ if (name == std::string(func_name)) {
+ if (!func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol).has_value())
{
+ // Mark internal functions as externally-exposed so that
+ // call-tracing transforms inside the pass do not remove them.
+ internal_functions.insert(gvar->name_hint);
+ func = BaseFuncWithAttr(func, tvm::attr::kGlobalSymbol,
gvar->name_hint);
+ }
+ } else {
+ // Replace non-target functions with stubs to keep references intact.
+ keep_original_version.insert(gvar->name_hint);
+ func = relax::ExternFunc("dummy_" + name);
+ func->struct_info_ = gvar->struct_info_;
+ }
+ subset->Add(gvar, func);
+ }
Review Comment:

Comparing `gvar->name_hint` directly against `func_name` as `ffi::String`
avoids converting both to `std::string` on every loop iteration. The local
variable `name` can also be dropped, building the dummy function name only
inside the `else` block where it is actually needed. This removes avoidable
string copies (and possible heap allocations), which matters most for modules
with many functions.
```cpp
for (auto [gvar, func] : mod->functions) {
if (gvar->name_hint == func_name) {
if
(!func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol).has_value()) {
// Mark internal functions as externally-exposed so that
// call-tracing transforms inside the pass do not remove them.
internal_functions.insert(gvar->name_hint);
func = BaseFuncWithAttr(func, tvm::attr::kGlobalSymbol,
gvar->name_hint);
}
} else {
// Replace non-target functions with stubs to keep references intact.
keep_original_version.insert(gvar->name_hint);
func = relax::ExternFunc("dummy_" + std::string(gvar->name_hint));
func->struct_info_ = gvar->struct_info_;
}
subset->Add(gvar, func);
}
```
##########
src/relax/transform/decompose_ops.cc:
##########
@@ -212,6 +215,90 @@ class OpDecomposer : public ExprMutator {
namespace transform {
+namespace {
+
+/*! \brief Helper: add or remove an attribute on a BaseFunc */
+BaseFunc BaseFuncWithAttr(BaseFunc func, const std::string& attr_key, Any
attr_value) {
+ if (auto tirx = func.as<tirx::PrimFunc>()) {
+ return WithAttr(tirx.value(), attr_key, attr_value);
+ } else if (auto relax_fn = func.as<relax::Function>()) {
+ return WithAttr(relax_fn.value(), attr_key, attr_value);
+ } else {
+ return func;
+ }
+}
+
+BaseFunc BaseFuncWithoutAttr(BaseFunc func, const std::string& attr_key) {
+ if (auto tirx = func.as<tirx::PrimFunc>()) {
+ return WithoutAttr(tirx.value(), attr_key);
+ } else if (auto relax_fn = func.as<relax::Function>()) {
+ return WithoutAttr(relax_fn.value(), attr_key);
+ } else {
+ return func;
+ }
+}
+
+/*!
+ * \brief Apply a pass to a single named function within an IRModule.
+ *
+ * Replaces all other functions with dummy ExternFunc stubs so that the
+ * pass does not see them, then restores the original module. Uses
+ * exact name match (not a regex) because all in-tree callers supply a
+ * literal function name.
+ */
+Pass ApplyDecomposeToFunction(Pass pass, ffi::String func_name) {
+ auto pass_func = [pass, func_name](IRModule mod, PassContext) -> IRModule {
+ std::unordered_set<ffi::String> keep_original_version;
+ std::unordered_set<ffi::String> internal_functions;
+ IRModule subset;
+
+ for (auto [gvar, func] : mod->functions) {
+ std::string name = gvar->name_hint;
+ if (name == std::string(func_name)) {
+ if (!func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol).has_value())
{
+ // Mark internal functions as externally-exposed so that
+ // call-tracing transforms inside the pass do not remove them.
+ internal_functions.insert(gvar->name_hint);
+ func = BaseFuncWithAttr(func, tvm::attr::kGlobalSymbol,
gvar->name_hint);
+ }
+ } else {
+ // Replace non-target functions with stubs to keep references intact.
+ keep_original_version.insert(gvar->name_hint);
+ func = relax::ExternFunc("dummy_" + name);
+ func->struct_info_ = gvar->struct_info_;
+ }
+ subset->Add(gvar, func);
+ }
+
+ IRModule new_subset = pass(subset);
+ if (new_subset.same_as(subset)) {
+ return mod;
+ }
+
+ auto write_ptr = mod.CopyOnWrite();
+ for (auto [gvar, func] : new_subset->functions) {
+ if (!keep_original_version.count(gvar->name_hint)) {
+ if (auto it = write_ptr->global_var_map_.find(gvar->name_hint);
+ it != write_ptr->global_var_map_.end()) {
+ write_ptr->Remove((*it).second);
+ }
+ if (internal_functions.count(gvar->name_hint)) {
+ func = BaseFuncWithoutAttr(func, tvm::attr::kGlobalSymbol);
+ }
+ write_ptr->Add(gvar, func);
+ }
+ }
+ return mod;
+ };
+
+ auto pass_name =
+ static_cast<const std::stringstream&>(std::stringstream() <<
"ApplyDecomposeTo" << func_name)
+ .str();
Review Comment:

Building the pass name through a `std::stringstream` plus a `static_cast` is
needlessly verbose. Since `func_name` converts to `std::string`, plain string
concatenation does the same job more simply:
```suggestion
std::string pass_name = "ApplyDecomposeTo" + std::string(func_name);
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]