This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b4e0d3eafa [Relax][PyTorch] Add negative slicing support in `slice_scatter` operation (#18494)
b4e0d3eafa is described below

commit b4e0d3eafab4b000a7ae7987f167c355e128c6a7
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Mon Nov 24 03:18:39 2025 +0900

    [Relax][PyTorch] Add negative slicing support in `slice_scatter` operation (#18494)
    
    As per title.
---
 .../frontend/torch/base_fx_graph_translator.py     | 35 ++++++++++++++++++++++
 .../relax/test_frontend_from_exported_program.py   | 21 +++++++++++++
 tests/python/relax/test_frontend_from_fx.py        | 25 +++++++++++++++-
 3 files changed, 80 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 9c2e45c8fd..3a3e0360af 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1795,6 +1795,41 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         end = args[4] if len(args) > 4 else node.kwargs.get("end", self.shape_of(input_tensor)[dim])
         step = args[5] if len(args) > 5 else node.kwargs.get("step", 1)
 
+        # Normalize bounds to match PyTorch behavior (negative and open-ended slices).
+        input_shape = self.shape_of(input_tensor)
+        axis = dim if dim >= 0 else dim + len(input_shape)
+
+        def _normalize_bound(bound):
+            # PyTorch uses a large positive value (2^63-1) to mean "len".
+            max_index_val = 9223372036854775807
+
+            def _adjust(val):
+                if isinstance(val, (int, tir.IntImm)):
+                    int_val = int(val)
+                    if int_val >= max_index_val:
+                        return input_shape[axis]
+                    if int_val < 0:
+                        return input_shape[axis] + int_val
+                    if isinstance(input_shape[axis], (int, tir.IntImm)) and int_val > int(
+                        input_shape[axis]
+                    ):
+                        return input_shape[axis]
+                return val
+
+            if isinstance(bound, relax.PrimValue):
+                value = _adjust(bound.value)
+                return relax.PrimValue(value)
+
+            bound = _adjust(bound)
+            if not isinstance(bound, relax.PrimValue):
+                bound = relax.PrimValue(bound)
+            return bound
+
+        start = _normalize_bound(start)
+        end = _normalize_bound(end)
+        if not isinstance(step, relax.PrimValue):
+            step = relax.PrimValue(step)
+
         return self.block_builder.emit(
             relax.op.slice_scatter(input_tensor, src, start, end, step, axis=dim)
         )
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 4c5d71216c..3435ac5670 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -5283,12 +5283,33 @@ def test_slice_scatter():
                 R.output(gv)
             return gv
 
+    class SliceScatterNegative(Module):
+        def forward(self, input, src):
+            return torch.slice_scatter(input, src, dim=1, start=0, end=-2, step=1)
+
+    @tvm.script.ir_module
+    class expected_slice_scatter:
+        @R.function
+        def main(
+            a: R.Tensor((2, 5), dtype="float32"), b: R.Tensor((2, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((2, 5), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 5), dtype="float32") = R.slice_scatter(
+                    a, b, R.prim_value(0), R.prim_value(3), R.prim_value(1), axis=1
+                )
+                gv: R.Tuple(R.Tensor((2, 5), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
     example_args = (torch.randn(8, 8, 10, 10, dtype=torch.float32), torch.randn(8, 3, 10, 10))
     verify_model(SliceScatter1(), example_args, {}, expected1)
 
     example_args = (torch.randn(8, 16, dtype=torch.float32), torch.randn(6, 16))
     verify_model(SliceScatter2(), example_args, {}, expected2)
 
+    example_args = (torch.randn(2, 5, dtype=torch.float32), torch.randn(2, 3, dtype=torch.float32))
+    verify_model(SliceScatterNegative(), example_args, {}, expected_slice_scatter)
+
 
 def test_split():
     class Chunk(Module):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 7f0905088c..9840665251 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -5142,11 +5142,34 @@ def test_slice_scatter():
                 R.output(gv)
             return gv
 
+    class SliceScatterNegative(Module):
+        def forward(self, input, src):
+            return torch.slice_scatter(input, src, dim=1, start=0, end=-2, step=1)
+
+    @tvm.script.ir_module
+    class expected_slice_scatter:
+        @R.function
+        def main(
+            a: R.Tensor((2, 5), dtype="float32"), b: R.Tensor((2, 3), dtype="float32")
+        ) -> R.Tensor((2, 5), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 5), dtype="float32") = R.slice_scatter(
+                    a, b, R.prim_value(0), R.prim_value(3), R.prim_value(1), axis=1
+                )
+                gv: R.Tensor((2, 5), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
     verify_model(
         SliceScatter1(), [((8, 8, 10, 10), "float32"), ((8, 3, 10, 10), "float32")], {}, expected1
     )
-
     verify_model(SliceScatter2(), [((8, 16), "float32"), ((6, 16), "float32")], {}, expected2)
+    verify_model(
+        SliceScatterNegative(),
+        [((2, 5), "float32"), ((2, 3), "float32")],
+        {},
+        expected_slice_scatter,
+    )
 
 
 def test_masked_scatter():

Reply via email to