This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c12debb462 [Relax][PyTorch] Fix segfault in from_exported_program when 
model uses index_put_ with tuple output (#19488)
c12debb462 is described below

commit c12debb462c2e43270987db15ebe5074d33c190d
Author: Neo Chien <[email protected]>
AuthorDate: Fri May 8 17:49:18 2026 +0800

    [Relax][PyTorch] Fix segfault in from_exported_program when model uses 
index_put_ with tuple output (#19488)
    
    Hi Committers,
    
    This PR attempts to fix issue
    https://github.com/apache/tvm/issues/18363. Any suggestions would be
    appreciated if you have time.
    
    ### Root Cause
    - When an ExportedProgram's FX graph output node returns a **nested
    Python tuple** (e.g., buffer mutation outputs + user-defined tuple
    returns), `_translate_fx_graph()` passes the raw nested structure
    directly to the Relax FFI Tuple constructor.
    - The C++ Array<Expr> initializer cannot handle heterogeneous/nested
    Python containers, causing a segmentation fault in `expr.cc`.
    - Additionally, `index_put_` (an in-place write op) did not update `self.env`
    to map the source tensor to the mutated output, so subsequent FX
    nodes that read the same tensor observed **stale pre-mutation
    values**.
    
    ### Solution
    - exported_program_translator.py
    - Added static method `_flatten_output_args()` that recursively walks
    any Python `tuple`/`list`, collects `relax.Expr` leaves, preserves
    explicit `None` outputs as Relax null values, and raises `ValueError`
    for any other leaf type or when no outputs are produced.
    - Replaced the fragile `assert isinstance(output_args, tuple |
    relax.Tuple)` guard with a call to `_flatten_output_args()`, producing a
    clean flat tuple of `relax.Expr` before FFI construction.
    - base_fx_graph_translator.py
    - In `_index_put()`, after emitting the `relax.op.index_put(...)` call,
    added an env alias update: when the target op name starts with
    `index_put_` and the first argument is an FX node, every `self.env`
    entry bound to the input tensor expression is rebound to `output`,
    preserving in-place mutation semantics for downstream FX nodes
    (including reads through an alias).
    
    ---------
    
    Co-authored-by: cchung100m <[email protected]>
---
 .../frontend/torch/base_fx_graph_translator.py     |  20 +++-
 .../frontend/torch/exported_program_translator.py  |  37 ++++++-
 .../relax/test_frontend_from_exported_program.py   | 108 +++++++++++++++++++++
 3 files changed, 162 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py 
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 138176155a..89c91e3773 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1921,7 +1921,25 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
                 indices = relax.Tuple(processed_indices)
             else:
                 indices = relax.Tuple(indices)
-        return self.block_builder.emit(relax.op.index_put(tensor, indices, 
values, accumulate))
+
+        output = self.block_builder.emit(relax.op.index_put(tensor, indices, 
values, accumulate))
+
+        target_name = (
+            node.target if isinstance(node.target, str) else 
getattr(node.target, "__name__", "")
+        )
+        if target_name.startswith("index_put_") and len(node.args) > 0:
+            from torch import fx
+
+            if isinstance(node.args[0], fx.Node):
+                # `index_put_` is in-place. If the mutated input is an alias 
of another
+                # FX node, later reads via either the alias node or the 
original node
+                # must observe the updated tensor.
+                aliased_expr = tensor
+                for env_node, env_expr in list(self.env.items()):
+                    if env_expr is aliased_expr:
+                        self.env[env_node] = output
+
+        return output
 
     def _index_tensor(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index cc37554bf3..5bd2c785f2 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1338,7 +1338,40 @@ class ExportedProgramImporter(BaseFXGraphImporter):
                 raise ValueError(f"Unsupported op {node.op}")
 
         assert output_args is not None
-        return output_args
+        return self._flatten_output_args(output_args)
+
+    @staticmethod
+    def _flatten_output_args(output_args) -> tuple[relax.Expr, ...]:
+        """Flatten output args into a tuple of Relax expressions.
+
+        ExportedProgram output trees can contain nested Python tuple/list
containers
+        (e.g. mutation outputs + user tuple outputs). Passing nested Python
tuples
+        directly to the FFI Tuple constructor can crash the importer.
+        """
+
+        flattened: list[relax.Expr] = []
+
+        def _visit(value):
+            if isinstance(value, relax.Expr):
+                flattened.append(value)
+            elif isinstance(value, list | tuple):
+                for item in value:
+                    _visit(item)
+            elif value is None:
+                # Preserve explicit None outputs as Relax null objects.
+                flattened.append(relax.op.null_value())
+            else:
+                raise ValueError(
+                    "Unsupported output type in exported graph output: "
+                    f"{type(value)}"
+                )
+
+        _visit(output_args)
+
+        if not flattened:
+            raise ValueError("Exported graph produced no Relax outputs")
+
+        return tuple(flattened)
 
     def _import_branch_subgraph(
         self,
@@ -1995,7 +2028,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
                 output_args = self._translate_fx_graph(
                     exported_program.graph_module, nodes, inputs_vars, 
custom_ops
                 )
-                assert isinstance(output_args, tuple | relax.Tuple)
+                output_args = self._flatten_output_args(output_args)
 
                 if unwrap_unit_return_tuple and len(output_args) == 1:
                     ret = output_args[0]
diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
index e2f9751c15..f3e2e581e1 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7402,6 +7402,114 @@ def test_index_put():
     verify_model(IndexPutBatchedWithNone(), example_args_batched_none, {}, 
ExpectedBatchedWithNone)
 
 
+def test_index_put_with_tuple_output():
+    class IndexPutTupleOutput(Module):
+        def forward(self, x, l, idx):
+            values = x
+            l[..., idx, idx] = values
+            return x[..., 1], l
+
+    example_args = (
+        torch.ones(2, 3, 5, dtype=torch.float32),
+        torch.zeros(2, 3, 5, 5, dtype=torch.float32),
+        torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64),
+    )
+
+    exported_program = export(IndexPutTupleOutput(), args=example_args)
+    mod = from_exported_program(exported_program)
+
+    ret_sinfo = mod["main"].ret_struct_info
+    assert isinstance(ret_sinfo, relax.TupleStructInfo)
+
+    tensor_fields = [f for f in ret_sinfo.fields if isinstance(f, 
relax.TensorStructInfo)]
+    assert len(tensor_fields) >= 2
+
+    assert any(
+        len(f.shape) == 4 and int(f.shape[-2]) == 5 and int(f.shape[-1]) == 5 
+        for f in tensor_fields
+    )
+
+
+def test_m4d_diag_index_put_tuple_output_regression():
+    class M4D(Module):
+        def forward(self, x):
+            b, k, n = 2, 3, 5
+            l = x.new_zeros(b, k, n, n)
+            idx = torch.arange(n, device=x.device)
+
+            diag = l[..., idx, idx]
+            diag = torch.nn.functional.elu(diag) + 1.0 + 1e-8
+            l[..., idx, idx] = diag
+
+            return x[..., :1], l
+
+    ex_in = torch.zeros(2, 3, 5, dtype=torch.float32)
+    exported_program = export(M4D().eval(), args=(ex_in,))
+
+    exported_targets = [str(getattr(n, "target", "")) for n in 
exported_program.graph.nodes]
+    assert any("index_put" in target for target in exported_targets)
+
+    # Regression focus: importing this graph should not segfault at Tuple 
construction.
+    mod = from_exported_program(exported_program)
+    ret_sinfo = mod["main"].ret_struct_info
+    assert isinstance(ret_sinfo, relax.TupleStructInfo)
+
+    tensor_fields = [f for f in ret_sinfo.fields if isinstance(f, 
relax.TensorStructInfo)]
+    assert len(tensor_fields) >= 2
+    # x: (2, 3, 5) → x[..., :1]: (2, 3, 1)
+    assert any(len(f.shape) == 3 and int(f.shape[-1]) == 1 for f in 
tensor_fields)
+    # l: (2, 3, 5, 5) → 4-D with spatial dims 5×5
+    assert any(
+        len(f.shape) == 4 and int(f.shape[-2]) == 5 and int(f.shape[-1]) == 5
+        for f in tensor_fields
+    )
+
+
+def test_index_put_mutation_through_alias_regression():
+    class IndexPutAlias(Module):
+        def forward(self, x, idx, values):
+            y = torch.ops.aten.alias.default(x)
+            y[idx] = values
+            return x, y
+
+    example_args = (
+        torch.zeros(5, dtype=torch.float32),
+        torch.tensor([1, 3], dtype=torch.int64),
+        torch.tensor([2.0, 4.0], dtype=torch.float32),
+    )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((5,), dtype="float32"),
+            idx: R.Tensor((2,), dtype="int64"),
+            values: R.Tensor((2,), dtype="float32"),
+        ) -> R.Tuple(
+            R.Tensor((5,), dtype="float32"),
+            R.Tensor((5,), dtype="float32"),
+            R.Tensor((5,), dtype="float32"),
+        ):
+            with R.dataflow():
+                lv: R.Tensor((5,), dtype="float32") = R.index_put(
+                    x, (idx,), values, accumulate=False
+                )
+                # ExportedProgram may include an additional mutation output.
+                gv: R.Tuple(
+                    R.Tensor((5,), dtype="float32"),
+                    R.Tensor((5,), dtype="float32"),
+                    R.Tensor((5,), dtype="float32"),
+                ) = (
+                    lv,
+                    lv,
+                    lv,
+                )
+                R.output(gv)
+            return gv
+
+    verify_model(IndexPutAlias(), example_args, {}, Expected)
+
+
 def test_flip():
     class Flip0(Module):
         def forward(self, data):

Reply via email to