This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f30b29c2c5 [Relax][PyTorch] Fix the segfault in from_exported_program when model returns (Tensor, None) tuple (#18359)
f30b29c2c5 is described below
commit f30b29c2c5e35eb975ae8926fb7ebfae4d817a50
Author: Shushi Hong <[email protected]>
AuthorDate: Sun Oct 5 00:58:40 2025 -0400
[Relax][PyTorch] Fix the segfault in from_exported_program when model returns (Tensor, None) tuple (#18359)
* finish1
* finish2
* add unittest
---
.../frontend/torch/base_fx_graph_translator.py | 2 ++
.../relax/test_frontend_from_exported_program.py | 22 ++++++++++++++++++++++
2 files changed, 24 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 12b460e859..c1cbd3416c 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -102,6 +102,8 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
return [self._retrieve_args(x) for x in node]
elif isinstance(node, dict):
return {self._retrieve_args(k): self._retrieve_args(v) for k, v in
node.items()}
+ elif node is None:
+ return relax.op.null_value()
else:
return node
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 4b0672ccc1..b35af088b5 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -6028,5 +6028,27 @@ def test_lstm():
np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np,
rtol=1e-4, atol=1e-5)
+def test_tensor_none_tuple():
+ """Regression test for #18359: a model returning a (Tensor, None) tuple must import without crashing."""
+ example_args = (torch.tensor([1.0, 2.0, 3.0]),)
+
+ class TensorNoneModel(Module):
+ def forward(self, x):
+ return x + 1, None
+
+ # The None element of the output is expected as R.null_value(), typed R.Object in the tuple.
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((3,), dtype="float32")
+ ) -> R.Tuple(R.Tensor((3,), dtype="float32"), R.Object):
+ with R.dataflow():
+ lv: R.Tensor((3,), dtype="float32") = R.add(x, R.const(1.0,
"float32"))
+ gv: R.Tuple(R.Tensor((3,), dtype="float32"), R.Object) = (lv,
R.null_value())
+ R.output(gv)
+ return gv
+
+ verify_model(TensorNoneModel(), example_args, {}, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()