This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ff884b609a [Relax][Transform] Handle tuple return in 
RemoveUnusedOutputs (#17253)
ff884b609a is described below

commit ff884b609a2eb94fef1f061bff0ec867b79d4ba0
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Sep 6 11:28:28 2024 -0500

    [Relax][Transform] Handle tuple return in RemoveUnusedOutputs (#17253)
    
    * [Relax][Transform] Handle tuple return in RemoveUnusedOutputs
    
    Prior to this commit, the `relax.transform.RemoveUnusedOutputs` pass
    only marked a tuple element as used if it occurred in a `TupleGetItem`
    node.  This ignored use cases where a tuple is used as an aggregate
    object, such as returning a tuple from a function.  As a result, the
    pass would produce incorrect results for a Relax function that calls a
    subroutine, receives a tuple as the return value of the subroutine,
    then returns that tuple.
    
    This commit updates `RemoveUnusedOutputs` to look for usage of a tuple
    object, not just for usage in `TupleGetItem`.
    
    Closes https://github.com/apache/tvm/issues/17247
---
 src/relax/transform/remove_unused_outputs.cc       | 59 ++++++++++++++--------
 .../relax/test_transform_remove_unused_outputs.py  | 20 ++++++++
 2 files changed, 59 insertions(+), 20 deletions(-)

diff --git a/src/relax/transform/remove_unused_outputs.cc 
b/src/relax/transform/remove_unused_outputs.cc
index e3bf12382c..9a5c31e79b 100644
--- a/src/relax/transform/remove_unused_outputs.cc
+++ b/src/relax/transform/remove_unused_outputs.cc
@@ -92,29 +92,48 @@ class PartialTupleUsageCollector : ExprVisitor {
   }
 
   void VisitExpr_(const TupleGetItemNode* op) override {
-    Expr tuple = UnwrapBindings(op->tuple);
-
-    if (auto call = tuple.as<CallNode>()) {
-      if (auto opt_callee = call->op.as<GlobalVar>()) {
-        auto callee = opt_callee.value();
-        if (auto it = output_usage_mask_.find(callee); it != 
output_usage_mask_.end()) {
-          auto& used_indices = it->second;
-
-          CHECK_GE(op->index, 0) << "IndexError: "
-                                 << "Indices for TupleGetItem must be 
non-negative, "
-                                 << "but expression " << GetRef<Expr>(op)
-                                 << " uses a tuple index of " << op->index;
-          size_t index = op->index;
-
-          CHECK_LT(index, used_indices.size())
-              << "IndexError: "
-              << "Indices for TupleGetItem must be less than the size of the 
tuple, "
-              << "but expression " << GetRef<Expr>(op) << " uses a tuple index 
of " << op->index
-              << " for a tuple of size " << used_indices.size();
-          used_indices[index] = true;
+    if (auto* usage_mask_ptr = GetCalleeUsageMask(op->tuple)) {
+      auto& used_indices = *usage_mask_ptr;
+
+      CHECK_GE(op->index, 0) << "IndexError: "
+                             << "Indices for TupleGetItem must be 
non-negative, "
+                             << "but expression " << GetRef<Expr>(op) << " 
uses a tuple index of "
+                             << op->index;
+      size_t index = op->index;
+
+      CHECK_LT(index, used_indices.size())
+          << "IndexError: "
+          << "Indices for TupleGetItem must be less than the size of the 
tuple, "
+          << "but expression " << GetRef<Expr>(op) << " uses a tuple index of 
" << op->index
+          << " for a tuple of size " << used_indices.size();
+      used_indices[index] = true;
+    }
+  }
+
+  void VisitExpr_(const VarNode* op) override {
+    if (auto* usage_mask_ptr = GetCalleeUsageMask(GetRef<Var>(op))) {
+      auto& usage_mask = *usage_mask_ptr;
+      for (size_t i = 0; i < usage_mask.size(); i++) {
+        usage_mask[i] = true;
+      }
+    }
+  }
+
+  std::vector<bool>* GetCalleeUsageMask(Expr expr) {
+    if (!expr->struct_info_.as<TupleStructInfoNode>()) {
+      return nullptr;
+    }
+
+    expr = UnwrapBindings(expr);
+    if (auto call = expr.as<CallNode>()) {
+      if (auto callee = call->op.as<GlobalVar>()) {
+        if (auto it = output_usage_mask_.find(callee.value()); it != 
output_usage_mask_.end()) {
+          return &it->second;
         }
       }
     }
+
+    return nullptr;
   }
 
   Expr UnwrapBindings(Expr expr) const {
diff --git a/tests/python/relax/test_transform_remove_unused_outputs.py 
b/tests/python/relax/test_transform_remove_unused_outputs.py
index c0405ca58d..365ce1695d 100644
--- a/tests/python/relax/test_transform_remove_unused_outputs.py
+++ b/tests/python/relax/test_transform_remove_unused_outputs.py
@@ -119,5 +119,25 @@ class TestMultipleCallSites(BaseCompare):
             return (A, C)
 
 
+class TestReturnTuple(BaseCompare):
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor([16, 16], "int32")):
+            B = R.add(A, A)
+            out_tuple = Before.func(B)
+            return out_tuple
+
+        @R.function(private=True)
+        def func(
+            B: R.Tensor([16, 16], "int32")
+        ) -> R.Tuple(R.Tensor([16, 16], "int32"), R.Tensor([16, 16], "int32")):
+            C = R.multiply(B, B)
+            D = R.add(B, B)
+            return (C, D)
+
+    Expected = Before
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to