This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b01dadb49d [Relax][PyTorch] Add `as_strided` operator in ExportedProgram frontend (#18490)
b01dadb49d is described below

commit b01dadb49d002a542700840d6b5714877451e712
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sun Nov 23 17:04:56 2025 +0900

    [Relax][PyTorch] Add `as_strided` operator in ExportedProgram frontend (#18490)
    
    As per title.
---
 .../frontend/torch/exported_program_translator.py  | 32 ++++++++++++++++++
 .../relax/test_frontend_from_exported_program.py   | 38 ++++++++++++++++++++++
 2 files changed, 70 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 1961898f76..d7975a8dde 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -954,6 +954,37 @@ class ExportedProgramImporter(BaseFXGraphImporter):
 
         return self.block_builder.emit(relax.op.scatter_elements(x, index, src, axis=dim))
 
+    def _as_strided(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        size = args[1]
+        stride = args[2]
+        storage_offset = args[3] if len(args) > 3 else node.kwargs.get("storage_offset", 0)
+
+        assert storage_offset == 0, "as_strided with non-zero storage_offset is not supported yet"
+
+        # Only handle view-like cases where the provided strides align with a contiguous layout.
+        can_check = all(isinstance(dim, (int, tvm.tir.IntImm)) for dim in size) and all(
+            isinstance(st, (int, tvm.tir.IntImm)) for st in stride
+        )
+        if can_check:
+            expected_stride = []
+            running = 1
+            for dim in reversed(size):
+                dim_int = int(dim)
+                expected_stride.insert(0, running)
+                running *= dim_int
+
+            for dim, st, exp in zip(size, stride, expected_stride):
+                dim_int = int(dim)
+                if dim_int != 1 and int(st) != exp:
+                    raise AssertionError(
+                        f"as_strided with non-contiguous stride {stride} for "
+                        f"size {size} is not supported"
+                    )
+
+        return self.block_builder.emit(relax.op.reshape(x, size))
+
     ########## Others ##########
 
     def create_convert_map(
@@ -1219,6 +1250,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "view.default": self._reshape,
             "reshape.default": self._reshape,
             "reshape_as.default": self._reshape_as,
+            "as_strided.default": self._as_strided,
             # tensor creation
             "_to_copy.default": self._to_copy,
             "arange.default": self._arange,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index a61da359d3..341bafc267 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -5633,6 +5633,44 @@ def test_view():
     verify_model(View(), example_args, {}, expected1)
 
 
+def test_as_strided():
+    class AsStrided(Module):
+        def forward(self, x):
+            return torch.ops.aten.as_strided.default(x, (3, 2, 2), (4, 2, 1))
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 2, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((3, 2, 2), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3, 2, 2), dtype="float32") = R.reshape(x, (3, 2, 2))
+                gv: R.Tuple(R.Tensor((3, 2, 2), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class AsStridedNonContiguous(Module):
+        def forward(self, x):
+            return torch.ops.aten.as_strided.default(x, (2, 2, 2), (6, 3, 1))
+
+    class AsStridedWithStorageOffset(Module):
+        def forward(self, x):
+            return torch.ops.aten.as_strided.default(x, (2, 2), (2, 1), 1)
+
+    example_args = (torch.randn(2, 2, 3, dtype=torch.float32),)
+    verify_model(AsStrided(), example_args, {}, Expected)
+
+    exported = export(AsStridedNonContiguous(), args=example_args)
+    with pytest.raises(AssertionError, match="non-contiguous stride"):
+        from_exported_program(exported)
+
+    example_args = (torch.randn(2, 2, dtype=torch.float32),)
+    exported = export(AsStridedWithStorageOffset(), args=example_args)
+    with pytest.raises(AssertionError, match="storage_offset"):
+        from_exported_program(exported)
+
+
 def test_arange():
     class Arange(Module):
         def forward(self, input):

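For context, the check in `_as_strided` above only rewrites `as_strided` to `R.reshape` when the requested strides equal the row-major (contiguous) strides implied by the requested size; dimensions of extent 1 are exempt from the comparison, and a non-zero `storage_offset` is rejected. Below is a minimal standalone sketch of that stride computation, for illustration only (the helper name `contiguous_strides` is not part of the commit):

    def contiguous_strides(size):
        # Row-major strides: the innermost dimension has stride 1 and each
        # outer dimension's stride is the product of the extents to its right.
        strides = []
        running = 1
        for dim in reversed(size):
            strides.insert(0, running)
            running *= int(dim)
        return strides

    # (3, 2, 2) with strides (4, 2, 1) matches, so the op lowers to R.reshape.
    assert contiguous_strides((3, 2, 2)) == [4, 2, 1]
    # (2, 2, 2) with strides (6, 3, 1) does not match, so an AssertionError is raised.
    assert contiguous_strides((2, 2, 2)) != [6, 3, 1]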