This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 36f2502f00 Add op support for roll op (#17839)
36f2502f00 is described below
commit 36f2502f00d58128451d64c8ff12fa871c994c92
Author: Pratheesh-04-MCW <[email protected]>
AuthorDate: Fri Apr 18 17:25:36 2025 +0530
Add op support for roll op (#17839)
* add op support for roll op
* lint fix
* fixed unity check
* add unit test in fx_graph
* lint issues
* lint check
* conflict resolved
---
.../frontend/torch/base_fx_graph_translator.py | 81 ++++++++++++-
.../frontend/torch/exported_program_translator.py | 1 +
python/tvm/relax/frontend/torch/fx_translator.py | 1 +
.../relax/test_frontend_from_exported_program.py | 125 +++++++++++++++++++++
tests/python/relax/test_frontend_from_fx.py | 120 ++++++++++++++++++++
5 files changed, 327 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 2652b167e5..ae4c918900 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -23,7 +23,7 @@ from functools import reduce
import math
from typing import Callable, Dict, Optional, Tuple, Union, List
-from tvm import relax
+from tvm import relax, tir
class BaseFXGraphImporter(metaclass=abc.ABCMeta):
@@ -1164,6 +1164,85 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else
args[1:]
return self.block_builder.emit(relax.op.tile(x, dims))
+ def _roll(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ input_tensor = args[0]
+ shifts = args[1] if len(node.args) > 1 else node.kwargs.get("shifts",
None)
+ dims = args[2] if len(node.args) > 2 else node.kwargs.get("dims", None)
+
+ # Get original shape
+ original_shape = self.shape_of(input_tensor)
+
+ def to_int(val):
+ if isinstance(val, tir.IntImm):
+ return int(val.value)
+ elif isinstance(val, int):
+ return val
+ elif hasattr(val, "__int__"):
+ return int(val)
+ raise TypeError(f"Unsupported type for shift/dim: {type(val)}")
+
+ def roll_single_dim(tensor: relax.Var, shift: int, dim: int) ->
relax.Var:
+ shape = self.shape_of(tensor)
+
+ dim_size = shape.values[dim]
+ shift_val = to_int(shift)
+ dim_size_val = to_int(dim_size)
+ shift_mod = shift_val % dim_size_val
+ if shift_mod == 0:
+ return tensor
+
+ split_pos = dim_size_val - shift_mod
+ part1 = self.block_builder.emit(
+ relax.op.strided_slice(
+ tensor,
+ axes=[dim],
+ begin=[0],
+ end=[split_pos],
+ strides=[1],
+ )
+ )
+ part2 = self.block_builder.emit(
+ relax.op.strided_slice(
+ tensor,
+ axes=[dim],
+ begin=[split_pos],
+ end=[dim_size_val],
+ strides=[1],
+ )
+ )
+ return self.block_builder.emit(relax.op.concat([part2, part1],
axis=dim))
+
+ # Handle dims=None (flatten -> roll -> reshape)
+ if dims is None:
+ flattened = self.block_builder.emit(relax.op.reshape(input_tensor,
(-1,)))
+ shift_scalar = to_int(shifts[0] if isinstance(shifts, (list,
tuple)) else shifts)
+ rolled = roll_single_dim(flattened, shift_scalar, 0)
+ return self.block_builder.emit(relax.op.reshape(rolled,
original_shape))
+
+ # Normalize shifts and dims
+ if isinstance(shifts, (list, tuple)):
+ shifts = [to_int(s) for s in shifts]
+ else:
+ shifts = [to_int(shifts)]
+
+ if isinstance(dims, (list, tuple)):
+ dims = [to_int(d) for d in dims]
+ else:
+ dims = [to_int(dims)]
+
+ if len(shifts) != len(dims):
+ raise ValueError("shifts and dims must have the same length")
+
+ result = input_tensor
+ rank = len(original_shape.values)
+ for shift, dim in zip(shifts, dims):
+ if dim < 0:
+ dim += rank
+ result = roll_single_dim(result, shift, dim)
+
+ return result
+
def _reshape(self, node: fx.Node) -> relax.Var:
import torch # type: ignore
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 5d4f3437b2..9326072875 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -423,6 +423,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"narrow.default": self._narrow,
"permute.default": self._permute,
"repeat.default": self._repeat,
+ "roll.default": self._roll,
"select.int": self._select,
"slice.Tensor": self._slice,
"split.Tensor": self._split,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index e6b1fdd223..5a34befb92 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -750,6 +750,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"numel": self._numel,
"permute": self._permute,
"repeat": self._repeat,
+ "roll": self._roll,
"reshape": self._reshape,
"scatter": self._scatter,
"select": self._select,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 9259936dc2..80c0bd5fb4 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2968,6 +2968,131 @@ def test_reshape_as():
verify_model(ReshapeAs(), example_args, {}, expected1)
def test_roll():
    """Check torch.roll lowering through the exported-program frontend.

    Covers three variants: the flatten path (no dims given), a negative
    shift on one axis, and multiple shifts/dims applied sequentially.
    The Expected modules are exact structural-equality fixtures for the
    strided_slice + concat decomposition the importer emits.
    """
    class Roll1(Module):
        def forward(self, x):
            return torch.roll(x, 1)

    class Roll2(Module):
        def forward(self, x):
            return torch.roll(x, -1, 0)

    class Roll3(Module):
        def forward(self, x):
            return torch.roll(x, shifts=(2, 1), dims=(0, 1))

    # Test case 1: torch.roll(x, 1)
    # dims omitted -> flatten to (8,), roll by 1, reshape back to (4, 2).
    @I.ir_module
    class Expected1:
        @R.function
        def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")):
            with R.dataflow():
                lv: R.Tensor((8,), dtype="int64") = R.reshape(x, R.shape([8]))
                lv1: R.Tensor((7,), dtype="int64") = R.strided_slice(
                    lv,
                    axes=[0],
                    begin=[R.prim_value(0)],
                    end=[R.prim_value(7)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv2: R.Tensor((1,), dtype="int64") = R.strided_slice(
                    lv,
                    axes=[0],
                    begin=[R.prim_value(7)],
                    end=[R.prim_value(8)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv3: R.Tensor((8,), dtype="int64") = R.concat((lv2, lv1), axis=0)
                lv4: R.Tensor((4, 2), dtype="int64") = R.reshape(lv3, R.shape([4, 2]))
                gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv4,)
                R.output(gv)
            return gv

    # Test case 2: torch.roll(x, -1, 0)
    # Negative shift -1 on a size-4 axis is equivalent to shift 3.
    @I.ir_module
    class Expected2:
        @R.function
        def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")):
            with R.dataflow():
                lv: R.Tensor((1, 2), dtype="int64") = R.strided_slice(
                    x,
                    axes=[0],
                    begin=[R.prim_value(0)],
                    end=[R.prim_value(1)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv1: R.Tensor((3, 2), dtype="int64") = R.strided_slice(
                    x,
                    axes=[0],
                    begin=[R.prim_value(1)],
                    end=[R.prim_value(4)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0)
                gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv2,)
                R.output(gv)
            return gv

    # Test case 3: torch.roll(x, shifts=(2,1), dims=(0,1))
    # Each (shift, dim) pair lowers to its own slice/slice/concat triple.
    @I.ir_module
    class Expected3:
        @R.function
        def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")):
            with R.dataflow():
                # First roll along dim=0 with shift=2
                lv: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
                    x,
                    axes=[0],
                    begin=[R.prim_value(0)],
                    end=[R.prim_value(2)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv1: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
                    x,
                    axes=[0],
                    begin=[R.prim_value(2)],
                    end=[R.prim_value(4)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0)

                # Second roll along dim=1 with shift=1
                lv3: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
                    lv2,
                    axes=[1],
                    begin=[R.prim_value(0)],
                    end=[R.prim_value(1)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv4: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
                    lv2,
                    axes=[1],
                    begin=[R.prim_value(1)],
                    end=[R.prim_value(2)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), axis=1)
                gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv5,)
                R.output(gv)
            return gv

    # Test inputs
    example_input = torch.randint(0, 10, (4, 2), dtype=torch.int64)

    # Run verification for each case
    verify_model(Roll1(), (example_input,), {}, Expected1)
    verify_model(Roll2(), (example_input,), {}, Expected2)
    verify_model(Roll3(), (example_input,), {}, Expected3)
+
+
def test_select_slice():
class Slice1(Module):
def forward(self, x):
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 53c925e14e..c522556380 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3560,6 +3560,126 @@ def test_repeat():
verify_model(Tile2(), [(torch.Size([1, 3]), "float32")], {}, expected2)
def test_roll():
    """Check torch.roll lowering through the FX-graph frontend.

    Mirrors the exported-program test: flatten path (no dims), negative
    shift on one axis, and multiple shifts/dims. The Expected modules
    are exact structural-equality fixtures for the strided_slice + concat
    decomposition.
    """
    class Roll1(Module):
        def forward(self, x):
            return torch.roll(x, 1)

    class Roll2(Module):
        def forward(self, x):
            return torch.roll(x, -1, 0)

    class Roll3(Module):
        def forward(self, x):
            return torch.roll(x, shifts=(2, 1), dims=(0, 1))

    # Test case 1: torch.roll(x, 1)
    # dims omitted -> flatten to (8,), roll by 1, reshape back to (4, 2).
    @I.ir_module
    class Expected1:
        @R.function
        def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int64"):
            with R.dataflow():
                lv: R.Tensor((8,), dtype="int64") = R.reshape(inp_0, R.shape([8]))
                lv1: R.Tensor((7,), dtype="int64") = R.strided_slice(
                    lv,
                    axes=[0],
                    begin=[R.prim_value(0)],
                    end=[R.prim_value(7)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv2: R.Tensor((1,), dtype="int64") = R.strided_slice(
                    lv,
                    axes=[0],
                    begin=[R.prim_value(7)],
                    end=[R.prim_value(8)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv3: R.Tensor((8,), dtype="int64") = R.concat((lv2, lv1), axis=0)
                lv4: R.Tensor((4, 2), dtype="int64") = R.reshape(lv3, R.shape([4, 2]))
                gv: R.Tensor((4, 2), dtype="int64") = lv4
                R.output(gv)
            return gv

    # Test case 2: torch.roll(x, -1, 0)
    # Negative shift -1 on a size-4 axis is equivalent to shift 3.
    @I.ir_module
    class Expected2:
        @R.function
        def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int64"):
            with R.dataflow():
                lv: R.Tensor((1, 2), dtype="int64") = R.strided_slice(
                    inp_0,
                    axes=[0],
                    begin=[R.prim_value(0)],
                    end=[R.prim_value(1)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv1: R.Tensor((3, 2), dtype="int64") = R.strided_slice(
                    inp_0,
                    axes=[0],
                    begin=[R.prim_value(1)],
                    end=[R.prim_value(4)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0)
                gv: R.Tensor((4, 2), dtype="int64") = lv2
                R.output(gv)
            return gv

    # Test case 3: torch.roll(x, shifts=(2, 1), dims=(0, 1))
    # Each (shift, dim) pair lowers to its own slice/slice/concat triple.
    @I.ir_module
    class Expected3:
        @R.function
        def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int64"):
            with R.dataflow():
                lv: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
                    inp_0,
                    axes=[0],
                    begin=[R.prim_value(0)],
                    end=[R.prim_value(2)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv1: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
                    inp_0,
                    axes=[0],
                    begin=[R.prim_value(2)],
                    end=[R.prim_value(4)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0)
                lv3: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
                    lv2,
                    axes=[1],
                    begin=[R.prim_value(0)],
                    end=[R.prim_value(1)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv4: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
                    lv2,
                    axes=[1],
                    begin=[R.prim_value(1)],
                    end=[R.prim_value(2)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), axis=1)
                gv: R.Tensor((4, 2), dtype="int64") = lv5
                R.output(gv)
            return gv

    input_info = [([4, 2], "int64")]

    verify_model(Roll1(), input_info, {}, Expected1)
    verify_model(Roll2(), input_info, {}, Expected2)
    verify_model(Roll3(), input_info, {}, Expected3)
+
+
def test_view():
input_info = [([1, 2, 3, 4], "float32")]