This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 36f2502f00 Add op support for roll op (#17839)
36f2502f00 is described below

commit 36f2502f00d58128451d64c8ff12fa871c994c92
Author: Pratheesh-04-MCW <[email protected]>
AuthorDate: Fri Apr 18 17:25:36 2025 +0530

    Add op support for roll op (#17839)
    
    * add op support for roll op
    
    * lint fix
    
    * fixed unity check
    
    * add unit test in fx_graph
    
    * lint issues
    
    * lint check
    
    * conflict resolved
---
 .../frontend/torch/base_fx_graph_translator.py     |  81 ++++++++++++-
 .../frontend/torch/exported_program_translator.py  |   1 +
 python/tvm/relax/frontend/torch/fx_translator.py   |   1 +
 .../relax/test_frontend_from_exported_program.py   | 125 +++++++++++++++++++++
 tests/python/relax/test_frontend_from_fx.py        | 120 ++++++++++++++++++++
 5 files changed, 327 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 2652b167e5..ae4c918900 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -23,7 +23,7 @@ from functools import reduce
 import math
 from typing import Callable, Dict, Optional, Tuple, Union, List
 
-from tvm import relax
+from tvm import relax, tir
 
 
 class BaseFXGraphImporter(metaclass=abc.ABCMeta):
@@ -1164,6 +1164,85 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else 
args[1:]
         return self.block_builder.emit(relax.op.tile(x, dims))
 
+    def _roll(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        input_tensor = args[0]
+        shifts = args[1] if len(node.args) > 1 else node.kwargs.get("shifts", 
None)
+        dims = args[2] if len(node.args) > 2 else node.kwargs.get("dims", None)
+
+        # Get original shape
+        original_shape = self.shape_of(input_tensor)
+
+        def to_int(val):
+            if isinstance(val, tir.IntImm):
+                return int(val.value)
+            elif isinstance(val, int):
+                return val
+            elif hasattr(val, "__int__"):
+                return int(val)
+            raise TypeError(f"Unsupported type for shift/dim: {type(val)}")
+
+        def roll_single_dim(tensor: relax.Var, shift: int, dim: int) -> 
relax.Var:
+            shape = self.shape_of(tensor)
+
+            dim_size = shape.values[dim]
+            shift_val = to_int(shift)
+            dim_size_val = to_int(dim_size)
+            shift_mod = shift_val % dim_size_val
+            if shift_mod == 0:
+                return tensor
+
+            split_pos = dim_size_val - shift_mod
+            part1 = self.block_builder.emit(
+                relax.op.strided_slice(
+                    tensor,
+                    axes=[dim],
+                    begin=[0],
+                    end=[split_pos],
+                    strides=[1],
+                )
+            )
+            part2 = self.block_builder.emit(
+                relax.op.strided_slice(
+                    tensor,
+                    axes=[dim],
+                    begin=[split_pos],
+                    end=[dim_size_val],
+                    strides=[1],
+                )
+            )
+            return self.block_builder.emit(relax.op.concat([part2, part1], 
axis=dim))
+
+        # Handle dims=None (flatten -> roll -> reshape)
+        if dims is None:
+            flattened = self.block_builder.emit(relax.op.reshape(input_tensor, 
(-1,)))
+            shift_scalar = to_int(shifts[0] if isinstance(shifts, (list, 
tuple)) else shifts)
+            rolled = roll_single_dim(flattened, shift_scalar, 0)
+            return self.block_builder.emit(relax.op.reshape(rolled, 
original_shape))
+
+        # Normalize shifts and dims
+        if isinstance(shifts, (list, tuple)):
+            shifts = [to_int(s) for s in shifts]
+        else:
+            shifts = [to_int(shifts)]
+
+        if isinstance(dims, (list, tuple)):
+            dims = [to_int(d) for d in dims]
+        else:
+            dims = [to_int(dims)]
+
+        if len(shifts) != len(dims):
+            raise ValueError("shifts and dims must have the same length")
+
+        result = input_tensor
+        rank = len(original_shape.values)
+        for shift, dim in zip(shifts, dims):
+            if dim < 0:
+                dim += rank
+            result = roll_single_dim(result, shift, dim)
+
+        return result
+
     def _reshape(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 5d4f3437b2..9326072875 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -423,6 +423,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "narrow.default": self._narrow,
             "permute.default": self._permute,
             "repeat.default": self._repeat,
+            "roll.default": self._roll,
             "select.int": self._select,
             "slice.Tensor": self._slice,
             "split.Tensor": self._split,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index e6b1fdd223..5a34befb92 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -750,6 +750,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "numel": self._numel,
             "permute": self._permute,
             "repeat": self._repeat,
+            "roll": self._roll,
             "reshape": self._reshape,
             "scatter": self._scatter,
             "select": self._select,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 9259936dc2..80c0bd5fb4 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2968,6 +2968,131 @@ def test_reshape_as():
     verify_model(ReshapeAs(), example_args, {}, expected1)
 
 
+def test_roll():
+    """Verify torch.roll conversion through the exported-program frontend."""
+    # Roll1 omits dims, Roll2 passes shift and dim positionally, and Roll3 passes
+    # tuples of shifts and dims by keyword.
+    class Roll1(Module):
+        def forward(self, x):
+            return torch.roll(x, 1)
+
+    class Roll2(Module):
+        def forward(self, x):
+            return torch.roll(x, -1, 0)
+
+    class Roll3(Module):
+        def forward(self, x):
+            return torch.roll(x, shifts=(2, 1), dims=(0, 1))
+
+    # Test case 1: torch.roll(x, 1)
+    # Without dims, the input is flattened to (8,), rolled along axis 0 by 1 (split at
+    # 8 - 1 = 7), concatenated tail-before-head, then reshaped back to (4, 2).
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 
2), dtype="int64")):
+            with R.dataflow():
+                lv: R.Tensor((8,), dtype="int64") = R.reshape(x, R.shape([8]))
+                lv1: R.Tensor((7,), dtype="int64") = R.strided_slice(
+                    lv,
+                    axes=[0],
+                    begin=[R.prim_value(0)],
+                    end=[R.prim_value(7)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv2: R.Tensor((1,), dtype="int64") = R.strided_slice(
+                    lv,
+                    axes=[0],
+                    begin=[R.prim_value(7)],
+                    end=[R.prim_value(8)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv3: R.Tensor((8,), dtype="int64") = R.concat((lv2, lv1), 
axis=0)
+                lv4: R.Tensor((4, 2), dtype="int64") = R.reshape(lv3, 
R.shape([4, 2]))
+                gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv4,)
+                R.output(gv)
+            return gv
+
+    # Test case 2: torch.roll(x, -1, 0)
+    # A shift of -1 on the size-4 axis becomes 3 after the modulo, so the split point
+    # is 1 and rows [1:4] are placed before row [0:1].
+    @I.ir_module
+    class Expected2:
+        @R.function
+        def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 
2), dtype="int64")):
+            with R.dataflow():
+                lv: R.Tensor((1, 2), dtype="int64") = R.strided_slice(
+                    x,
+                    axes=[0],
+                    begin=[R.prim_value(0)],
+                    end=[R.prim_value(1)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv1: R.Tensor((3, 2), dtype="int64") = R.strided_slice(
+                    x,
+                    axes=[0],
+                    begin=[R.prim_value(1)],
+                    end=[R.prim_value(4)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), 
axis=0)
+                gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv2,)
+                R.output(gv)
+            return gv
+
+    # Test case 3: torch.roll(x, shifts=(2,1), dims=(0,1))
+    # Each (shift, dim) pair is applied in turn: the dim=1 roll slices lv2, which is
+    # the result of the dim=0 roll.
+    @I.ir_module
+    class Expected3:
+        @R.function
+        def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 
2), dtype="int64")):
+            with R.dataflow():
+                # First roll along dim=0 with shift=2
+                lv: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
+                    x,
+                    axes=[0],
+                    begin=[R.prim_value(0)],
+                    end=[R.prim_value(2)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv1: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
+                    x,
+                    axes=[0],
+                    begin=[R.prim_value(2)],
+                    end=[R.prim_value(4)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), 
axis=0)
+
+                # Second roll along dim=1 with shift=1
+                lv3: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
+                    lv2,
+                    axes=[1],
+                    begin=[R.prim_value(0)],
+                    end=[R.prim_value(1)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv4: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
+                    lv2,
+                    axes=[1],
+                    begin=[R.prim_value(1)],
+                    end=[R.prim_value(2)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), 
axis=1)
+                gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv5,)
+                R.output(gv)
+            return gv
+
+    # Test inputs
+    # A random int64 tensor of shape (4, 2), matching the x parameter of each main.
+    example_input = torch.randint(0, 10, (4, 2), dtype=torch.int64)
+
+    # Run verification for each case
+    verify_model(Roll1(), (example_input,), {}, Expected1)
+    verify_model(Roll2(), (example_input,), {}, Expected2)
+    verify_model(Roll3(), (example_input,), {}, Expected3)
+
+
 def test_select_slice():
     class Slice1(Module):
         def forward(self, x):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 53c925e14e..c522556380 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3560,6 +3560,126 @@ def test_repeat():
     verify_model(Tile2(), [(torch.Size([1, 3]), "float32")], {}, expected2)
 
 
+def test_roll():
+    """Verify torch.roll conversion through the torch.fx frontend."""
+    # Roll1 omits dims, Roll2 passes shift and dim positionally, and Roll3 passes
+    # tuples of shifts and dims by keyword.
+    class Roll1(Module):
+        def forward(self, x):
+            return torch.roll(x, 1)
+
+    class Roll2(Module):
+        def forward(self, x):
+            return torch.roll(x, -1, 0)
+
+    class Roll3(Module):
+        def forward(self, x):
+            return torch.roll(x, shifts=(2, 1), dims=(0, 1))
+
+    # Test case 1: torch.roll(x, 1)
+    # Without dims, the input is flattened to (8,), rolled along axis 0 by 1 (split at
+    # 8 - 1 = 7), concatenated tail-before-head, then reshaped back to (4, 2).
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), 
dtype="int64"):
+            with R.dataflow():
+                lv: R.Tensor((8,), dtype="int64") = R.reshape(inp_0, 
R.shape([8]))
+                lv1: R.Tensor((7,), dtype="int64") = R.strided_slice(
+                    lv,
+                    axes=[0],
+                    begin=[R.prim_value(0)],
+                    end=[R.prim_value(7)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv2: R.Tensor((1,), dtype="int64") = R.strided_slice(
+                    lv,
+                    axes=[0],
+                    begin=[R.prim_value(7)],
+                    end=[R.prim_value(8)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv3: R.Tensor((8,), dtype="int64") = R.concat((lv2, lv1), 
axis=0)
+                lv4: R.Tensor((4, 2), dtype="int64") = R.reshape(lv3, 
R.shape([4, 2]))
+                gv: R.Tensor((4, 2), dtype="int64") = lv4
+                R.output(gv)
+            return gv
+
+    # Test case 2: torch.roll(x, -1, 0)
+    # A shift of -1 on the size-4 axis becomes 3 after the modulo, so the split point
+    # is 1 and rows [1:4] are placed before row [0:1].
+    @I.ir_module
+    class Expected2:
+        @R.function
+        def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), 
dtype="int64"):
+            with R.dataflow():
+                lv: R.Tensor((1, 2), dtype="int64") = R.strided_slice(
+                    inp_0,
+                    axes=[0],
+                    begin=[R.prim_value(0)],
+                    end=[R.prim_value(1)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv1: R.Tensor((3, 2), dtype="int64") = R.strided_slice(
+                    inp_0,
+                    axes=[0],
+                    begin=[R.prim_value(1)],
+                    end=[R.prim_value(4)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), 
axis=0)
+                gv: R.Tensor((4, 2), dtype="int64") = lv2
+                R.output(gv)
+            return gv
+
+    # Test case 3: torch.roll(x, shifts=(2, 1), dims=(0, 1))
+    # Each (shift, dim) pair is applied in turn: the dim=1 roll (lv3..lv5) slices lv2,
+    # which is the result of the dim=0 roll (lv..lv2).
+    @I.ir_module
+    class Expected3:
+        @R.function
+        def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), 
dtype="int64"):
+            with R.dataflow():
+                lv: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
+                    inp_0,
+                    axes=[0],
+                    begin=[R.prim_value(0)],
+                    end=[R.prim_value(2)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv1: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
+                    inp_0,
+                    axes=[0],
+                    begin=[R.prim_value(2)],
+                    end=[R.prim_value(4)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), 
axis=0)
+                lv3: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
+                    lv2,
+                    axes=[1],
+                    begin=[R.prim_value(0)],
+                    end=[R.prim_value(1)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv4: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
+                    lv2,
+                    axes=[1],
+                    begin=[R.prim_value(1)],
+                    end=[R.prim_value(2)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), 
axis=1)
+                gv: R.Tensor((4, 2), dtype="int64") = lv5
+                R.output(gv)
+            return gv
+
+    # One int64 input of shape (4, 2); matches the inp_0 parameter of each main.
+    input_info = [([4, 2], "int64")]
+
+    verify_model(Roll1(), input_info, {}, Expected1)
+    verify_model(Roll2(), input_info, {}, Expected2)
+    verify_model(Roll3(), input_info, {}, Expected3)
+
+
 def test_view():
     input_info = [([1, 2, 3, 4], "float32")]
 

Reply via email to