This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b4e0d3eafa [Relax][PyTorch] Add negative slicing support in
`slice_scatter` operation (#18494)
b4e0d3eafa is described below
commit b4e0d3eafab4b000a7ae7987f167c355e128c6a7
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Mon Nov 24 03:18:39 2025 +0900
[Relax][PyTorch] Add negative slicing support in `slice_scatter` operation (#18494)
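Normalize the `start` and `end` bounds passed to `slice_scatter` in the PyTorch frontend so they follow PyTorch slicing semantics. A negative bound is offset by the length of the sliced dimension. The open-ended sentinel 2^63-1 is replaced by that length, and a bound beyond a statically known length is clamped to it. For example, `end=-2` on a dimension of length 5 becomes `end=3`. Adds `SliceScatterNegative` tests for both the FX and ExportedProgram importers.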
---
.../frontend/torch/base_fx_graph_translator.py | 35 ++++++++++++++++++++++
.../relax/test_frontend_from_exported_program.py | 21 +++++++++++++
tests/python/relax/test_frontend_from_fx.py | 25 +++++++++++++++-
3 files changed, 80 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 9c2e45c8fd..3a3e0360af 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1795,6 +1795,41 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
end = args[4] if len(args) > 4 else node.kwargs.get("end",
self.shape_of(input_tensor)[dim])
step = args[5] if len(args) > 5 else node.kwargs.get("step", 1)
+ # Normalize bounds to match PyTorch behavior (negative and open-ended
slices).
+ input_shape = self.shape_of(input_tensor)
+ axis = dim if dim >= 0 else dim + len(input_shape)
+
+ def _normalize_bound(bound):
+ # PyTorch uses a large positive value (2^63-1) to mean "len".
+ max_index_val = 9223372036854775807
+
+ def _adjust(val):
+ if isinstance(val, (int, tir.IntImm)):
+ int_val = int(val)
+ if int_val >= max_index_val:
+ return input_shape[axis]
+ if int_val < 0:
+ return input_shape[axis] + int_val
+ if isinstance(input_shape[axis], (int, tir.IntImm)) and
int_val > int(
+ input_shape[axis]
+ ):
+ return input_shape[axis]
+ return val
+
+ if isinstance(bound, relax.PrimValue):
+ value = _adjust(bound.value)
+ return relax.PrimValue(value)
+
+ bound = _adjust(bound)
+ if not isinstance(bound, relax.PrimValue):
+ bound = relax.PrimValue(bound)
+ return bound
+
+ start = _normalize_bound(start)
+ end = _normalize_bound(end)
+ if not isinstance(step, relax.PrimValue):
+ step = relax.PrimValue(step)
+
return self.block_builder.emit(
relax.op.slice_scatter(input_tensor, src, start, end, step,
axis=dim)
)
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 4c5d71216c..3435ac5670 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -5283,12 +5283,33 @@ def test_slice_scatter():
R.output(gv)
return gv
+ class SliceScatterNegative(Module):
+ def forward(self, input, src):
+ return torch.slice_scatter(input, src, dim=1, start=0, end=-2,
step=1)
+
+ @tvm.script.ir_module
+ class expected_slice_scatter:
+ @R.function
+ def main(
+ a: R.Tensor((2, 5), dtype="float32"), b: R.Tensor((2, 3),
dtype="float32")
+ ) -> R.Tuple(R.Tensor((2, 5), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((2, 5), dtype="float32") = R.slice_scatter(
+ a, b, R.prim_value(0), R.prim_value(3), R.prim_value(1),
axis=1
+ )
+ gv: R.Tuple(R.Tensor((2, 5), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
example_args = (torch.randn(8, 8, 10, 10, dtype=torch.float32),
torch.randn(8, 3, 10, 10))
verify_model(SliceScatter1(), example_args, {}, expected1)
example_args = (torch.randn(8, 16, dtype=torch.float32), torch.randn(6,
16))
verify_model(SliceScatter2(), example_args, {}, expected2)
+ example_args = (torch.randn(2, 5, dtype=torch.float32), torch.randn(2, 3,
dtype=torch.float32))
+ verify_model(SliceScatterNegative(), example_args, {},
expected_slice_scatter)
+
def test_split():
class Chunk(Module):
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 7f0905088c..9840665251 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -5142,11 +5142,34 @@ def test_slice_scatter():
R.output(gv)
return gv
+ class SliceScatterNegative(Module):
+ def forward(self, input, src):
+ return torch.slice_scatter(input, src, dim=1, start=0, end=-2,
step=1)
+
+ @tvm.script.ir_module
+ class expected_slice_scatter:
+ @R.function
+ def main(
+ a: R.Tensor((2, 5), dtype="float32"), b: R.Tensor((2, 3),
dtype="float32")
+ ) -> R.Tensor((2, 5), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((2, 5), dtype="float32") = R.slice_scatter(
+ a, b, R.prim_value(0), R.prim_value(3), R.prim_value(1),
axis=1
+ )
+ gv: R.Tensor((2, 5), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
verify_model(
SliceScatter1(), [((8, 8, 10, 10), "float32"), ((8, 3, 10, 10),
"float32")], {}, expected1
)
-
verify_model(SliceScatter2(), [((8, 16), "float32"), ((6, 16),
"float32")], {}, expected2)
+ verify_model(
+ SliceScatterNegative(),
+ [((2, 5), "float32"), ((2, 3), "float32")],
+ {},
+ expected_slice_scatter,
+ )
def test_masked_scatter():