This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1499bdae59 [Relax][PyTorch] Fix crash on dynamic shapes with identity slice in ExportedProgram importer (#18903)
1499bdae59 is described below
commit 1499bdae5950b7bcaa585df4e736b894d11768f1
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Wed Mar 11 01:29:38 2026 +0900
[Relax][PyTorch] Fix crash on dynamic shapes with identity slice in ExportedProgram importer (#18903)
Fixes `TypeError: 'NoneType' object is not iterable` when importing
models with dynamic batch dimensions that contain identity slices (e.g.,
`x[:, :H, :W, :]` on a dynamic batch dim).
**Root cause:** `aten.slice.Tensor(x, 0, 0, INT_MAX)` (an identity slice
on a dynamic dim `s`) produces a result with shape `[T.min(INT_MAX, s),
...]` instead of `[s, ...]`. When this is combined with the original
tensor via `add`, TVM cannot unify the shapes, resulting in
`struct_info.shape = None`. Any subsequent `view`/`reshape` then crashes
when calling `list(None)`.
This pattern appears in models like `swin_t`, where shifted window
attention crops padded features with `x[:, :H, :W, :].contiguous()`.
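The identity-slice semantics can be seen with plain Python sequences, since `torch.export` encodes an unbounded slice end as `sys.maxsize` (a minimal sketch of the semantics only; the tensor and importer machinery are not involved here):

```python
import sys

# torch.export lowers `x[:, :H, :W, :]` into per-axis aten.slice.Tensor
# calls, encoding an open-ended stop as sys.maxsize. On a plain Python
# sequence, the same (start=0, end=sys.maxsize, step=1) slice is an
# identity operation, which is why the importer can safely return the
# input tensor unchanged instead of emitting strided_slice.
data = list(range(5))
identity = data[0:sys.maxsize:1]
print(identity == data)  # True
```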
**Changes:**
- `exported_program_translator.py`: Skip `strided_slice` for identity
slices (`start=0, end>=INT_MAX, step=1`) and return the input tensor
directly.
- `base_fx_graph_translator.py`: Guard the identity-reshape check in
`_reshape` against `None` shape.
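The guard added in `exported_program_translator.py` can be sketched as a standalone predicate (the helper name `is_identity_slice` is hypothetical and not part of the patch):

```python
import sys

def is_identity_slice(start, end, step) -> bool:
    """Hypothetical helper mirroring the patch's guard: a slice keeps the
    whole axis when it starts at 0, ends at or beyond sys.maxsize (the
    stand-in torch.export uses for an unbounded end), and has unit step."""
    return (
        isinstance(start, int)
        and isinstance(end, int)
        and isinstance(step, int)
        and start == 0
        and end >= sys.maxsize
        and step == 1
    )

print(is_identity_slice(0, sys.maxsize, 1))  # True: skip strided_slice
print(is_identity_slice(0, 4, 1))            # False: a real crop
```

When the predicate holds, the importer returns the input tensor directly, so no `T.min(9223372036854775807, s)` extent is ever created.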
---
.../tvm/relax/frontend/torch/base_fx_graph_translator.py | 2 +-
.../relax/frontend/torch/exported_program_translator.py | 14 ++++++++++++++
2 files changed, 15 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 2615a68ee3..3a7a62ba39 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -2076,7 +2076,7 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
# Skip identity reshape
current_shape = self.shape_of(x)
- if list(current_shape) == list(dims):
+ if current_shape is not None and list(current_shape) == list(dims):
return x
return self.block_builder.emit(relax.op.reshape(x, dims))
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 4c841b00b5..c9d0696277 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -941,6 +941,20 @@ class ExportedProgramImporter(BaseFXGraphImporter):
if end_val is None:
end_val = sys.maxsize
+ # Skip identity slice (start=0, end>=maxsize, step=1) which is commonly
+ # emitted by torch.export for dynamic shapes. Without this, strided_slice
+ # produces shapes like T.min(9223372036854775807, s) that don't simplify,
+ # causing downstream shape inference failures.
+ if (
+ isinstance(start, int)
+ and isinstance(end_val, int)
+ and isinstance(step, int)
+ and start == 0
+ and end_val >= sys.maxsize
+ and step == 1
+ ):
+ return x
+
axes = [dim]
begin = [start]
end = [end_val]