This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 81d6778947 [Unity][Frontend] Add `no_bind_return_tuple` for PyTorch FX
Translator (#14639)
81d6778947 is described below
commit 81d6778947ff6968ae2a3306b7d149f40a9e3f03
Author: Chaofan Lin <[email protected]>
AuthorDate: Tue Apr 18 09:11:16 2023 +0800
[Unity][Frontend] Add `no_bind_return_tuple` for PyTorch FX Translator
(#14639)
Previously, if we have a Torch model with multiple return values like:
```
class Mod(nn.Module):
...
def forward(...):
...
return logits, loss
```
The translator will always bind the return tuple to a Relax var and return
it:
```
gv: R.Tuple(...) = (logits, loss)
return gv
```
This brings inconvenience if we want to use the `Gradient` pass to set the loss
as the target (i.e. set `target_index=1`), because if the return value is a
single Relax var, the `Gradient` pass will not treat it as multiple return
values.
This PR introduces a flag to let the user choose between different return styles.
If the user sets `no_bind_return_tuple = True`, they will get
```
return (logits, loss)
```
---
python/tvm/relax/frontend/torch/fx_translator.py | 24 ++++++++++++-----
tests/python/relax/test_frontend_from_fx.py | 33 ++++++++++++++++++++++++
2 files changed, 50 insertions(+), 7 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index 54890bd3c5..aa9f661803 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -1193,6 +1193,7 @@ class TorchFXImporter:
input_info: List[Tuple[Tuple[int], str]],
keep_params_as_input: bool,
unwrap_unit_return_tuple: bool,
+ no_bind_return_tuple: bool,
) -> tvm.IRModule:
"""Convert a PyTorch FX GraphModule to a Relax program."""
from torch import fx
@@ -1244,12 +1245,16 @@ class TorchFXImporter:
elif node.op == "output":
args = self.retrieve_args(node)
assert len(args) == 1
- if (
- unwrap_unit_return_tuple
- and isinstance(args[0], (tuple, list, relax.Tuple))
- and len(args[0]) == 1
- ):
- output = self.block_builder.emit_output(args[0][0])
+
+ # return tuple
+ if isinstance(args[0], (tuple, list, relax.Tuple)):
+ # unit tuple
+ if unwrap_unit_return_tuple and len(args[0]) == 1:
+ output =
self.block_builder.emit_output(args[0][0])
+ elif no_bind_return_tuple:
+ output = []
+ for ret in args[0]:
+
output.append(self.block_builder.emit_output(ret))
else:
output = self.block_builder.emit_output(args[0])
break
@@ -1289,6 +1294,7 @@ def from_fx(
*,
keep_params_as_input: bool = False,
unwrap_unit_return_tuple: bool = False,
+ no_bind_return_tuple: bool = False,
) -> tvm.IRModule:
"""Convert a PyTorch FX GraphModule to a Relax program
@@ -1307,6 +1313,10 @@ def from_fx(
A boolean flag indicating whether to unwrap the return value when it is a unit
tuple.
When the return value is not a unit tuple, no unwrap will take place.
+ no_bind_return_tuple : bool
+ A boolean flag indicating whether to bind the return tuple as a relax
var.
+ If the flag is true and the return value is a tuple, it will not bind
it to a var.
+
Returns
-------
output : tvm.IRModule
@@ -1375,5 +1385,5 @@ def from_fx(
check the placeholder rows in the beginning of the tabular.
"""
return TorchFXImporter().from_fx(
- model, input_info, keep_params_as_input, unwrap_unit_return_tuple
+ model, input_info, keep_params_as_input, unwrap_unit_return_tuple,
no_bind_return_tuple
)
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 4eb7c2afa4..6d20abe16d 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -2707,6 +2707,39 @@ def test_unwrap_unit_return_tuple():
tvm.ir.assert_structural_equal(mod, Expected)
[email protected]_gpu
+def test_no_bind_return_tuple():
+ import torch.fx as fx
+ from torch.nn import Module
+ from tvm.relax.frontend.torch import from_fx
+
+ class Identity(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, y):
+ return (x, y)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ inp_0: R.Tensor((256, 256), dtype="float32"),
+ inp_1: R.Tensor((256, 256), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((256, 256), dtype="float32"), R.Tensor((256,
256), dtype="float32")):
+ with R.dataflow():
+ gv: R.Tensor((256, 256), dtype="float32") = inp_0
+ gv1: R.Tensor((256, 256), dtype="float32") = inp_1
+ R.output(gv, gv1)
+ return (gv, gv1)
+
+ graph_model = fx.symbolic_trace(Identity())
+ mod = from_fx(
+ graph_model, [([256, 256], "float32"), ([256, 256], "float32")],
no_bind_return_tuple=True
+ )
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
@tvm.testing.requires_gpu
def test_argmax():
import torch