This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 81d6778947 [Unity][Frontend] Add `no_bind_return_tuple` for PyTorch FX Translator (#14639)
81d6778947 is described below

commit 81d6778947ff6968ae2a3306b7d149f40a9e3f03
Author: Chaofan Lin <[email protected]>
AuthorDate: Tue Apr 18 09:11:16 2023 +0800

    [Unity][Frontend] Add `no_bind_return_tuple` for PyTorch FX Translator (#14639)
    
    Previously, if we have a Torch model with multiple return values like:
    ```
    class Mod(nn.Module):
        ...
        def forward(...):
            ...
            return logits, loss
    ```
    The translator will always bind the return tuple to a Relax var and return it:
    ```
        gv: R.Tuple(...) = (logits, loss)
        return gv
    ```
    This is inconvenient if we want to use the `Gradient` pass with the loss as the target (i.e. set `target_index=1`), because when the return value is a single Relax var, the `Gradient` pass does not treat it as multiple return values.
    
    This PR adds a flag that lets the user choose between the two return styles. If the user sets `no_bind_return_tuple = True`, they will get:
    ```
        return (logits, loss)
    ```
---
 python/tvm/relax/frontend/torch/fx_translator.py | 24 ++++++++++++-----
 tests/python/relax/test_frontend_from_fx.py      | 33 ++++++++++++++++++++++++
 2 files changed, 50 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 54890bd3c5..aa9f661803 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -1193,6 +1193,7 @@ class TorchFXImporter:
         input_info: List[Tuple[Tuple[int], str]],
         keep_params_as_input: bool,
         unwrap_unit_return_tuple: bool,
+        no_bind_return_tuple: bool,
     ) -> tvm.IRModule:
         """Convert a PyTorch FX GraphModule to a Relax program."""
         from torch import fx
@@ -1244,12 +1245,16 @@ class TorchFXImporter:
                     elif node.op == "output":
                         args = self.retrieve_args(node)
                         assert len(args) == 1
-                        if (
-                            unwrap_unit_return_tuple
-                            and isinstance(args[0], (tuple, list, relax.Tuple))
-                            and len(args[0]) == 1
-                        ):
-                            output = self.block_builder.emit_output(args[0][0])
+
+                        # return tuple
+                        if isinstance(args[0], (tuple, list, relax.Tuple)):
+                            # unit tuple
+                            if unwrap_unit_return_tuple and len(args[0]) == 1:
+                                output = self.block_builder.emit_output(args[0][0])
+                            elif no_bind_return_tuple:
+                                output = []
+                                for ret in args[0]:
+                                    output.append(self.block_builder.emit_output(ret))
                         else:
                             output = self.block_builder.emit_output(args[0])
                         break
@@ -1289,6 +1294,7 @@ def from_fx(
     *,
     keep_params_as_input: bool = False,
     unwrap_unit_return_tuple: bool = False,
+    no_bind_return_tuple: bool = False,
 ) -> tvm.IRModule:
     """Convert a PyTorch FX GraphModule to a Relax program
 
@@ -1307,6 +1313,10 @@ def from_fx(
         A boolean flag indicating if to the return value when it is an unit tuple.
         When the return value is not a unit tuple, no unwrap will take place.
 
+    no_bind_return_tuple : bool
+        A boolean flag indicating whether to bind the return tuple as a relax var.
+        If the flag is true and the return value is a tuple, it will not bind it to a var.
+
     Returns
     -------
     output : tvm.IRModule
@@ -1375,5 +1385,5 @@ def from_fx(
     check the placeholder rows in the beginning of the tabular.
     """
     return TorchFXImporter().from_fx(
-        model, input_info, keep_params_as_input, unwrap_unit_return_tuple
+        model, input_info, keep_params_as_input, unwrap_unit_return_tuple, no_bind_return_tuple
     )
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 4eb7c2afa4..6d20abe16d 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -2707,6 +2707,39 @@ def test_unwrap_unit_return_tuple():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+@tvm.testing.requires_gpu
+def test_no_bind_return_tuple():
+    import torch.fx as fx
+    from torch.nn import Module
+    from tvm.relax.frontend.torch import from_fx
+
+    class Identity(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y):
+            return (x, y)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32"),
+            inp_1: R.Tensor((256, 256), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((256, 256), dtype="float32"), R.Tensor((256, 256), dtype="float32")):
+            with R.dataflow():
+                gv: R.Tensor((256, 256), dtype="float32") = inp_0
+                gv1: R.Tensor((256, 256), dtype="float32") = inp_1
+                R.output(gv, gv1)
+            return (gv, gv1)
+
+    graph_model = fx.symbolic_trace(Identity())
+    mod = from_fx(
+        graph_model, [([256, 256], "float32"), ([256, 256], "float32")], no_bind_return_tuple=True
+    )
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 @tvm.testing.requires_gpu
 def test_argmax():
     import torch

Reply via email to