This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ff884b609a [Relax][Transform] Handle tuple return in RemoveUnusedOutputs (#17253)
ff884b609a is described below
commit ff884b609a2eb94fef1f061bff0ec867b79d4ba0
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Sep 6 11:28:28 2024 -0500
[Relax][Transform] Handle tuple return in RemoveUnusedOutputs (#17253)
* [Relax][Transform] Handle tuple return in RemoveUnusedOutputs
Prior to this commit, the `relax.transform.RemoveUnusedOutputs` pass
only marked a tuple element as used if it occurred in a `TupleGetItem`
node. This ignored use cases where a tuple is used as an aggregate
object, such as returning a tuple from a function. As a result, the
pass produced incorrect results for a Relax function that calls a
subroutine, receives a tuple as the return value of the subroutine,
and then returns that tuple: because no `TupleGetItem` accessed the
tuple, its elements were considered unused even though the whole tuple
is returned.
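This commit updates `RemoveUnusedOutputs` to look for any usage of the
tuple object itself (for example, returning it), and to treat such a
usage as using every element, rather than only counting element
accesses through `TupleGetItem`.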
Closes https://github.com/apache/tvm/issues/17247
---
src/relax/transform/remove_unused_outputs.cc | 59 ++++++++++++++--------
.../relax/test_transform_remove_unused_outputs.py | 20 ++++++++
2 files changed, 59 insertions(+), 20 deletions(-)
diff --git a/src/relax/transform/remove_unused_outputs.cc b/src/relax/transform/remove_unused_outputs.cc
index e3bf12382c..9a5c31e79b 100644
--- a/src/relax/transform/remove_unused_outputs.cc
+++ b/src/relax/transform/remove_unused_outputs.cc
@@ -92,29 +92,48 @@ class PartialTupleUsageCollector : ExprVisitor {
}
void VisitExpr_(const TupleGetItemNode* op) override {
- Expr tuple = UnwrapBindings(op->tuple);
-
- if (auto call = tuple.as<CallNode>()) {
- if (auto opt_callee = call->op.as<GlobalVar>()) {
- auto callee = opt_callee.value();
- if (auto it = output_usage_mask_.find(callee); it != output_usage_mask_.end()) {
- auto& used_indices = it->second;
-
- CHECK_GE(op->index, 0) << "IndexError: "
- << "Indices for TupleGetItem must be non-negative, "
- << "but expression " << GetRef<Expr>(op)
- << " uses a tuple index of " << op->index;
- size_t index = op->index;
-
- CHECK_LT(index, used_indices.size())
- << "IndexError: "
- << "Indices for TupleGetItem must be less than the size of the tuple, "
- << "but expression " << GetRef<Expr>(op) << " uses a tuple index of " << op->index
- << " for a tuple of size " << used_indices.size();
- used_indices[index] = true;
+ if (auto* usage_mask_ptr = GetCalleeUsageMask(op->tuple)) {
+ auto& used_indices = *usage_mask_ptr;
+
+ CHECK_GE(op->index, 0) << "IndexError: "
+ << "Indices for TupleGetItem must be non-negative, "
+ << "but expression " << GetRef<Expr>(op) << " uses a tuple index of "
+ << op->index;
+ size_t index = op->index;
+
+ CHECK_LT(index, used_indices.size())
+ << "IndexError: "
+ << "Indices for TupleGetItem must be less than the size of the tuple, "
+ << "but expression " << GetRef<Expr>(op) << " uses a tuple index of " << op->index
+ << " for a tuple of size " << used_indices.size();
+ used_indices[index] = true;
+ }
+ }
+
+ void VisitExpr_(const VarNode* op) override {
+ if (auto* usage_mask_ptr = GetCalleeUsageMask(GetRef<Var>(op))) {
+ auto& usage_mask = *usage_mask_ptr;
+ for (size_t i = 0; i < usage_mask.size(); i++) {
+ usage_mask[i] = true;
+ }
+ }
+ }
+
+ std::vector<bool>* GetCalleeUsageMask(Expr expr) {
+ if (!expr->struct_info_.as<TupleStructInfoNode>()) {
+ return nullptr;
+ }
+
+ expr = UnwrapBindings(expr);
+ if (auto call = expr.as<CallNode>()) {
+ if (auto callee = call->op.as<GlobalVar>()) {
+ if (auto it = output_usage_mask_.find(callee.value()); it != output_usage_mask_.end()) {
+ return &it->second;
}
}
}
+
+ return nullptr;
}
Expr UnwrapBindings(Expr expr) const {
diff --git a/tests/python/relax/test_transform_remove_unused_outputs.py b/tests/python/relax/test_transform_remove_unused_outputs.py
index c0405ca58d..365ce1695d 100644
--- a/tests/python/relax/test_transform_remove_unused_outputs.py
+++ b/tests/python/relax/test_transform_remove_unused_outputs.py
@@ -119,5 +119,25 @@ class TestMultipleCallSites(BaseCompare):
return (A, C)
+class TestReturnTuple(BaseCompare):
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(A: R.Tensor([16, 16], "int32")):
+ B = R.add(A, A)
+ out_tuple = Before.func(B)
+ return out_tuple
+
+ @R.function(private=True)
+ def func(
+ B: R.Tensor([16, 16], "int32")
+ ) -> R.Tuple(R.Tensor([16, 16], "int32"), R.Tensor([16, 16], "int32")):
+ C = R.multiply(B, B)
+ D = R.add(B, B)
+ return (C, D)
+
+ Expected = Before
+
+
if __name__ == "__main__":
tvm.testing.main()