This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c581abeee4 [Relax][PyTorch] Fix _slice and _expand for dynamic shapes in PyTorch ExportedProgram frontend (#18918)
c581abeee4 is described below

commit c581abeee410cc5426a667f64b3ac27b59ace2af
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Fri Mar 20 15:18:52 2026 +0900

    [Relax][PyTorch] Fix _slice and _expand for dynamic shapes in PyTorch ExportedProgram frontend (#18918)
    
    Fixes two issues when translating PyTorch models with dynamic shapes:
    
    1. **_slice**: Resolve `fx.Node` references in start/end/step arguments
    and detect identity slices where the symbolic end equals the tensor
    dimension (avoids redundant `strided_slice` ops).
    
    2. **_expand**: Fall back to FX node metadata when `shape_of()` returns
    `None` for tensors with unknown shapes.
---
 .../frontend/torch/base_fx_graph_translator.py     | 17 +++++-
 .../frontend/torch/exported_program_translator.py  | 21 ++++++-
 .../relax/test_frontend_from_exported_program.py   | 67 ++++++++++++++++++++++
 3 files changed, 101 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 803b4b7e11..c146cf6c00 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1754,13 +1754,24 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
     def _expand(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         sizes = args[1:] if len(args) > 2 else args[1]
-        broadcast_shape, in_shape = [], self.shape_of(args[0])
+        x = args[0]
+        broadcast_shape = []
+        in_shape = self.shape_of(x)
         for idx, i in enumerate(sizes):
             if isinstance(i, int) and i == -1:
-                broadcast_shape.append(in_shape[idx])
+                if in_shape is not None:
+                    broadcast_shape.append(in_shape[idx])
+                elif hasattr(node.args[0], "meta") and "val" in node.args[0].meta:
+                    # Fallback: get shape from FX node metadata (FakeTensor)
+                    fake_shape = node.args[0].meta["val"].shape
+                    broadcast_shape.append(fake_shape[idx])
+                else:
+                    raise ValueError(
+                        f"Cannot use -1 in expand for dim {idx} when input shape is unknown"
+                    )
             else:
                 broadcast_shape.append(i)
-        return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
+        return self.block_builder.emit(relax.op.broadcast_to(x, broadcast_shape))
 
     def _expand_as(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 2487b904c6..fd03f67332 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -29,7 +29,7 @@ import torch
 from torch import fx
 
 import tvm
-from tvm import relax
+from tvm import relax, tir
 
 from .base_fx_graph_translator import BaseFXGraphImporter
 
@@ -937,6 +937,14 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         end_val = node.args[3] if len(node.args) > 3 else None
         step = node.args[4] if len(node.args) > 4 else 1
 
+        # Resolve fx.Node references (e.g. symbolic sizes from dynamic shapes)
+        if isinstance(start, fx.Node):
+            start = self.env[start]
+        if isinstance(end_val, fx.Node):
+            end_val = self.env[end_val]
+        if isinstance(step, fx.Node):
+            step = self.env[step]
+
         if start is None:
             start = 0
         if end_val is None:
@@ -956,6 +964,17 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         ):
             return x
 
+        # Skip identity slice where end_val is a symbolic expression equal to the
+        # tensor's own dimension size (common with dynamic shapes).
+        if isinstance(start, int) and start == 0 and isinstance(step, int) and step == 1:
+            in_shape = self.shape_of(x)
+            if in_shape is not None and isinstance(end_val, tir.PrimExpr):
+                actual_dim = dim if dim >= 0 else len(in_shape) + dim
+                dim_expr = in_shape[actual_dim]
+                if isinstance(dim_expr, tir.PrimExpr):
+                    if tir.analysis.expr_deep_equal(end_val, dim_expr):
+                        return x
+
         axes = [dim]
         begin = [start]
         end = [end_val]
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index a9cea19fdc..e1cadb9d02 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -5634,6 +5634,73 @@ def test_slice_scatter():
     verify_model(SliceScatterNegative(), example_args, {}, expected_slice_scatter)
 
 
+def test_slice_with_symbolic_end():
+    """_slice correctly handles symbolic end values from dynamic shapes."""
+
+    class SliceIdentityModel(torch.nn.Module):
+        def forward(self, x):
+            # x[:, :x.size(1)] is an identity slice that torch.export emits
+            # as slice(x, 1, 0, sym_size_int(x, 1), 1) with dynamic shapes.
+            seq_len = x.size(1)
+            return x[:, :seq_len] + 0.0  # +0.0 to ensure output is a new tensor
+
+    # The identity slice is elided; only x + 0.0 remains.
+    @I.ir_module
+    class ExpectedIdentity:
+        @R.function
+        def main(x: R.Tensor(("s0", "s1", 4), dtype="float32")) -> R.Tuple(
+            R.Tensor(("s0", "s1", 4), dtype="float32")
+        ):
+            s0 = T.int64(is_size_var=True)
+            s1 = T.int64(is_size_var=True)
+            R.func_attr({"tir_var_lower_bound": {"s27": 2, "s77": 2}})
+            with R.dataflow():
+                lv: R.Tensor((s0, s1, 4), dtype="float32") = R.add(x, R.const(0.0, "float32"))
+                gv: R.Tuple(R.Tensor((s0, s1, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(2, 8, 4, dtype=torch.float32),)
+    batch = torch.export.Dim("batch", min=2)
+    seq = torch.export.Dim("seq", min=2)
+    dynamic_shapes = {"x": {0: batch, 1: seq}}
+
+    verify_model(
+        SliceIdentityModel(),
+        example_args,
+        {},
+        ExpectedIdentity,
+        dynamic_shapes=dynamic_shapes,
+        map_free_vars=True,
+    )
+
+    class SliceStaticModel(torch.nn.Module):
+        def forward(self, x):
+            # A non-identity static slice
+            return x[:, :3]
+
+    @tvm.script.ir_module
+    class ExpectedStatic:
+        @R.function
+        def main(x: R.Tensor((2, 8, 4), dtype="float32")) -> R.Tuple(
+            R.Tensor((2, 3, 4), dtype="float32")
+        ):
+            with R.dataflow():
+                lv: R.Tensor((2, 3, 4), dtype="float32") = R.strided_slice(
+                    x,
+                    axes=[1],
+                    begin=[0],
+                    end=[3],
+                    strides=[1],
+                )
+                gv: R.Tuple(R.Tensor((2, 3, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args_static = (torch.randn(2, 8, 4, dtype=torch.float32),)
+    verify_model(SliceStaticModel(), example_args_static, {}, ExpectedStatic)
+
+
 def test_split():
     class Chunk(Module):
         def forward(self, input):

Reply via email to