gemini-code-assist[bot] commented on code in PR #18720:
URL: https://github.com/apache/tvm/pull/18720#discussion_r2776711509
##########
tests/python/relax/test_transform_dead_code_elimination.py:
##########
@@ -799,5 +799,117 @@ def while_loop(
verify(Before, Expected)
+def test_mutual_recursion_unused_params():
+ """Test that unused parameters are removed even in mutual recursion"""
+
+ @tvm.script.ir_module
+ class Input:
+ @R.function
+ def func_a(
+ x: R.Tensor([32, 32], "float32"),
+ y: R.Tensor([32, 32], "float32"), # Unused in both
+ z: R.Tensor([32, 32], "float32"),
+ ):
+ cls = Input
+ out = R.add(x, z)
+ result = cls.func_b(out, y, z) # y passed but func_b doesn't use
it
+ return result
+
+ @R.function
+ def func_b(
+ a: R.Tensor([32, 32], "float32"),
+ b: R.Tensor([32, 32], "float32"), # Unused
+ c: R.Tensor([32, 32], "float32"),
+ ):
+ cls = Input
+ out = R.add(a, c)
+ result = cls.func_a(out, b, c)
+ return result
+
+ @R.function
+ def main():
+ x = R.zeros([32, 32], "float32")
+ y = R.ones([32, 32], "float32")
+ z = R.full([32, 32], R.const(2.0, dtype="float32"), "float32")
+ return Input.func_a(x, y, z)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def func_a(x: R.Tensor([32, 32], "float32"), z: R.Tensor([32, 32],
"float32")):
+ cls = Expected
+ out = R.add(x, z)
+ result = cls.func_b(out, z)
+ return result
+
+ @R.function
+ def func_b(a: R.Tensor([32, 32], "float32"), c: R.Tensor([32, 32],
"float32")):
+ cls = Expected
+ out = R.add(a, c)
+ result = cls.func_a(out, c)
+ return result
+
+ @R.function
+ def main():
+ x = R.zeros([32, 32], "float32")
+ z = R.full([32, 32], R.const(2.0, dtype="float32"), "float32")
+ return Expected.func_a(x, z)
+
+ verify(Input, Input)
Review Comment:

The test `test_mutual_recursion_unused_params` defines an `Expected`
IRModule but then verifies that the transformed `Input` is equal to the
original `Input`. This is incorrect: that assertion can only pass if the pass
leaves the module unchanged, so it does not test the pass's effect (and it
would fail once parameter removal works). It should verify against the
`Expected` module.
```suggestion
verify(Input, Expected)
```
##########
src/relax/transform/dead_code_elimination.cc:
##########
@@ -91,6 +95,153 @@ IRModule RemoveUnusedFunctions(IRModule mod, const
std::unordered_set<GlobalVar>
return mod;
}
+// two-stage dead parameter elimination
+// 1. collect all unused parameters with propagation
+// 2. update all call points with new functions and inputs
+std::unordered_map<GlobalVar, std::vector<int>> CollectUsedParamIndices(const
IRModule& mod) {
+ std::unordered_map<GlobalVar, std::vector<int>> result;
+
+ for (const auto& [gvar, base_func] : mod->functions) {
+ if (auto opt_func = base_func.as<Function>()) {
+ auto func = opt_func.value();
+ std::vector<bool> used(func->params.size(), false);
+
+ PostOrderVisit(func->body, [&](const ObjectRef& obj) {
+ if (auto v = obj.as<VarNode>()) {
+ Var var = ffi::GetRef<Var>(v);
+ for (size_t i = 0; i < func->params.size(); ++i) {
+ if (var.same_as(func->params[i])) {
+ used[i] = true;
+ }
+ }
+ }
+ });
Review Comment:

This approach to finding used parameters involves a nested loop within
`PostOrderVisit`, which can be inefficient for functions with large bodies. For
each variable encountered, it iterates through all function parameters. This
can be optimized by building a map from parameter `Var` to its index before the
visit, and then performing a direct lookup for each variable inside the
visitor. This would change the complexity from O(num_vars * num_params) to
O(num_vars + num_params).
```cpp
std::unordered_map<Var, int, ObjectPtrHash, ObjectPtrEqual>
param_to_idx;
for (size_t i = 0; i < func->params.size(); ++i) {
param_to_idx[func->params[i]] = i;
}
PostOrderVisit(func->body, [&](const ObjectRef& obj) {
if (auto v = obj.as<VarNode>()) {
Var var = ffi::GetRef<Var>(v);
auto it = param_to_idx.find(var);
if (it != param_to_idx.end()) {
used[it->second] = true;
}
}
});
```
##########
src/relax/transform/dead_code_elimination.cc:
##########
@@ -91,6 +95,153 @@ IRModule RemoveUnusedFunctions(IRModule mod, const
std::unordered_set<GlobalVar>
return mod;
}
+// two-stage dead parameter elimination
+// 1. collect all unused parameters with propagation
+// 2. update all call points with new functions and inputs
+std::unordered_map<GlobalVar, std::vector<int>> CollectUsedParamIndices(const
IRModule& mod) {
+ std::unordered_map<GlobalVar, std::vector<int>> result;
+
+ for (const auto& [gvar, base_func] : mod->functions) {
+ if (auto opt_func = base_func.as<Function>()) {
+ auto func = opt_func.value();
+ std::vector<bool> used(func->params.size(), false);
+
+ PostOrderVisit(func->body, [&](const ObjectRef& obj) {
+ if (auto v = obj.as<VarNode>()) {
+ Var var = ffi::GetRef<Var>(v);
+ for (size_t i = 0; i < func->params.size(); ++i) {
+ if (var.same_as(func->params[i])) {
+ used[i] = true;
+ }
+ }
+ }
+ });
+
+ std::vector<int> indices;
+ for (size_t i = 0; i < used.size(); ++i) {
+ if (used[i]) indices.push_back(i);
+ }
+
+ result[gvar] = std::move(indices);
+ }
+ }
+
+ return result;
+}
+
+struct CallSiteUpdater : public ExprMutator {
+ const std::unordered_map<GlobalVar, std::vector<int>>& used_param_indices;
+
+ explicit CallSiteUpdater(const std::unordered_map<GlobalVar,
std::vector<int>>& used)
+ : ExprMutator(std::nullopt), used_param_indices(used) {}
+
+ using ExprMutator::VisitExpr_;
+
+ Expr VisitExpr_(const CallNode* call) final {
+ if (auto gvar = call->op.as<GlobalVar>()) {
+ auto it = used_param_indices.find(gvar.value());
+ if (it != used_param_indices.end()) {
+ const auto& used = it->second;
+
+ if (used.size() == call->args.size()) {
+ return ExprMutator::VisitExpr_(call);
+ }
+
+ ffi::Array<Expr> new_args;
+ for (int idx : used) {
+ new_args.push_back(call->args[idx]);
+ }
+
+ auto new_call = Call(call->op, new_args, call->attrs);
+ if (call->struct_info_.defined()) {
+ new_call->struct_info_ = call->struct_info_;
+ }
Review Comment:

Manually copying the `struct_info_` from the old call to the new call is
incorrect. When arguments are removed from a call, the `struct_info` of the
call (which describes the output) may change, especially if the return type
depends on the removed arguments. The `ExprMutator` framework re-infers the
`struct_info` for the new call node during normalization, but only when the
call's `struct_info_` is still undefined, so copying the old value suppresses
that re-inference. Removing this manual assignment will allow the correct
`struct_info` to be inferred.
##########
src/relax/transform/dead_code_elimination.cc:
##########
@@ -91,6 +95,153 @@ IRModule RemoveUnusedFunctions(IRModule mod, const
std::unordered_set<GlobalVar>
return mod;
}
+// two-stage dead parameter elimination
+// 1. collect all unused parameters with propagation
+// 2. update all call points with new functions and inputs
+std::unordered_map<GlobalVar, std::vector<int>> CollectUsedParamIndices(const
IRModule& mod) {
+ std::unordered_map<GlobalVar, std::vector<int>> result;
+
+ for (const auto& [gvar, base_func] : mod->functions) {
+ if (auto opt_func = base_func.as<Function>()) {
+ auto func = opt_func.value();
+ std::vector<bool> used(func->params.size(), false);
+
+ PostOrderVisit(func->body, [&](const ObjectRef& obj) {
+ if (auto v = obj.as<VarNode>()) {
+ Var var = ffi::GetRef<Var>(v);
+ for (size_t i = 0; i < func->params.size(); ++i) {
+ if (var.same_as(func->params[i])) {
+ used[i] = true;
+ }
+ }
+ }
+ });
+
+ std::vector<int> indices;
+ for (size_t i = 0; i < used.size(); ++i) {
+ if (used[i]) indices.push_back(i);
+ }
+
+ result[gvar] = std::move(indices);
+ }
+ }
+
+ return result;
+}
+
+struct CallSiteUpdater : public ExprMutator {
+ const std::unordered_map<GlobalVar, std::vector<int>>& used_param_indices;
+
+ explicit CallSiteUpdater(const std::unordered_map<GlobalVar,
std::vector<int>>& used)
+ : ExprMutator(std::nullopt), used_param_indices(used) {}
+
+ using ExprMutator::VisitExpr_;
+
+ Expr VisitExpr_(const CallNode* call) final {
+ if (auto gvar = call->op.as<GlobalVar>()) {
+ auto it = used_param_indices.find(gvar.value());
+ if (it != used_param_indices.end()) {
+ const auto& used = it->second;
+
+ if (used.size() == call->args.size()) {
+ return ExprMutator::VisitExpr_(call);
+ }
+
+ ffi::Array<Expr> new_args;
+ for (int idx : used) {
+ new_args.push_back(call->args[idx]);
+ }
+
+ auto new_call = Call(call->op, new_args, call->attrs);
+ if (call->struct_info_.defined()) {
+ new_call->struct_info_ = call->struct_info_;
+ }
+ return new_call;
+ }
+ }
+ return ExprMutator::VisitExpr_(call);
+ }
+};
+
+IRModule RemoveUnusedParameters(IRModule mod) {
+ auto write_ptr = mod.CopyOnWrite();
+ bool changed = true;
+
+ do {
+ changed = false;
+
+ auto used_param_indices = CollectUsedParamIndices(mod);
+
+ for (const auto& [gvar, used] : used_param_indices) {
+ if (auto opt_func = mod->Lookup(gvar).as<Function>()) {
+ auto func = opt_func.value();
+ if (used.size() < func->params.size()) {
+ changed = true;
+ break;
+ }
+ }
+ }
+
+ if (!changed) break;
+
+ std::vector<GlobalVar> worklist;
+ std::unordered_set<GlobalVar> visited;
+ std::function<void(GlobalVar)> dfs = [&](GlobalVar gvar) {
+ if (visited.count(gvar)) return;
+ visited.insert(gvar);
+
+ if (auto opt_func = mod->Lookup(gvar).as<Function>()) {
+ auto func = opt_func.value();
+ PostOrderVisit(func->body, [&](const ObjectRef& obj) {
+ if (auto call = obj.as<CallNode>()) {
+ if (auto callee_gvar = call->op.as<GlobalVar>()) {
+ dfs(callee_gvar.value());
+ }
+ }
+ });
+ }
+ worklist.push_back(gvar);
+ };
+
+ for (const auto& [gvar, _] : mod->functions) {
+ dfs(gvar);
+ }
Review Comment:

The post-order DFS over the call graph (which fills `worklist`) is performed
inside the `do-while` loop. However, the call graph structure does not change
during dead parameter elimination. You can move this traversal outside the
loop to avoid re-computing it on every iteration, which would improve
performance.
##########
src/relax/transform/dead_code_elimination.cc:
##########
@@ -127,9 +278,28 @@ IRModule DeadCodeElimination(const IRModule& arg_mod,
}
}
- // S3: remove unused functions again as some callers may be removed in S2.
+ // S3: remove unused parameters in each function
+ mod = RemoveUnusedParameters(mod);
+
+ // S4: remove unused functions again as some callers may be removed in S2
and S3.
mod = RemoveUnusedFunctions(mod, entry_functions);
+ // S5: remove unused variables again as some arguments may be removed in S3
+ {
+ IRModule updates;
+ for (const auto& [gvar, base_func] : mod->functions) {
+ if (auto opt = base_func.as<Function>()) {
+ auto new_func = Downcast<Function>(RemoveAllUnused(opt.value()));
+ if (!new_func.same_as(base_func)) {
+ updates->Add(gvar, new_func);
+ }
+ }
+ }
+ if (updates->functions.size()) {
+ mod.CopyOnWrite()->Update(updates);
+ }
+ }
Review Comment:

This block of code for removing unused variables (`S5`) is identical to the
block for `S2` (lines 266-279). To improve maintainability and avoid code
duplication, you could extract this logic into a helper function, for example
`RemoveUnusedVariables(IRModule mod)`, and call it in both places.
##########
src/ir/module.cc:
##########
@@ -152,6 +152,9 @@ void IRModuleNode::Add(const GlobalVar& var, const
BaseFunc& f, bool update) {
void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) {
this->functions.Set(var, func);
+ auto* var_node = const_cast<GlobalVarNode*>(var.get());
Review Comment:

The `const_cast` here is unnecessary. The `struct_info_` member of
`RelaxExprNode` (which `GlobalVarNode` inherits from) is marked as `mutable`,
which means it can be modified on a `const` object. You can directly get a
`const GlobalVarNode*` and assign to its `struct_info_`.
```suggestion
const auto* var_node = var.get();
```
##########
tests/python/relax/test_transform_dead_code_elimination.py:
##########
@@ -799,5 +799,117 @@ def while_loop(
verify(Before, Expected)
+def test_mutual_recursion_unused_params():
+ """Test that unused parameters are removed even in mutual recursion"""
+
+ @tvm.script.ir_module
+ class Input:
+ @R.function
+ def func_a(
+ x: R.Tensor([32, 32], "float32"),
+ y: R.Tensor([32, 32], "float32"), # Unused in both
+ z: R.Tensor([32, 32], "float32"),
+ ):
+ cls = Input
+ out = R.add(x, z)
+ result = cls.func_b(out, y, z) # y passed but func_b doesn't use
it
+ return result
+
+ @R.function
+ def func_b(
+ a: R.Tensor([32, 32], "float32"),
+ b: R.Tensor([32, 32], "float32"), # Unused
+ c: R.Tensor([32, 32], "float32"),
+ ):
+ cls = Input
+ out = R.add(a, c)
+ result = cls.func_a(out, b, c)
+ return result
+
+ @R.function
+ def main():
+ x = R.zeros([32, 32], "float32")
+ y = R.ones([32, 32], "float32")
+ z = R.full([32, 32], R.const(2.0, dtype="float32"), "float32")
+ return Input.func_a(x, y, z)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def func_a(x: R.Tensor([32, 32], "float32"), z: R.Tensor([32, 32],
"float32")):
+ cls = Expected
+ out = R.add(x, z)
+ result = cls.func_b(out, z)
+ return result
+
+ @R.function
+ def func_b(a: R.Tensor([32, 32], "float32"), c: R.Tensor([32, 32],
"float32")):
+ cls = Expected
+ out = R.add(a, c)
+ result = cls.func_a(out, c)
+ return result
+
+ @R.function
+ def main():
+ x = R.zeros([32, 32], "float32")
+ z = R.full([32, 32], R.const(2.0, dtype="float32"), "float32")
+ return Expected.func_a(x, z)
+
+ verify(Input, Input)
+
+
+def test_deep_recursion_chain():
+ """Test parameter removal through a chain of recursive calls"""
+
+ @tvm.script.ir_module
+ class Input:
+ @R.function
+ def depth_1(x: R.Tensor([32], "float32"), dead1: R.Tensor([32],
"float32")): # Unused
+ out = R.add(x, x)
+ result = Input.depth_2(out, dead1)
+ return result
+
+ @R.function
+ def depth_2(a: R.Tensor([32], "float32"), dead2: R.Tensor([32],
"float32")): # Unused
+ out = R.multiply(a, a)
+ result = Input.depth_3(out, dead2)
+ return result
+
+ @R.function
+ def depth_3(b: R.Tensor([32], "float32"), dead3: R.Tensor([32],
"float32")): # Unused
+ return R.subtract(b, b)
+
+ @R.function
+ def main():
+ x = R.zeros([32], "float32")
+ dead = R.ones([32], "float32")
+ return Input.depth_1(x, dead)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def depth_1(x: R.Tensor([32], "float32")):
+ out = R.add(x, x)
+ result = Expected.depth_2(out)
+ return result
+
+ @R.function
+ def depth_2(a: R.Tensor([32], "float32")):
+ out = R.multiply(a, a)
+ result = Expected.depth_3(out)
+ return result
+
+ @R.function
+ def depth_3(b: R.Tensor([32], "float32")): # Unused
Review Comment:

The comment `# Unused` is incorrect here. The parameter `b` is used in the
function body on the next line: `return R.subtract(b, b)`. This comment seems
to be a leftover from the `Input` module definition and should be removed to
avoid confusion.
```suggestion
def depth_3(b: R.Tensor([32], "float32")):
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]