This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b01dadb49d [Relax][PyTorch] Add `as_strided` operator in ExportedProgram frontend (#18490)
b01dadb49d is described below
commit b01dadb49d002a542700840d6b5714877451e712
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sun Nov 23 17:04:56 2025 +0900
[Relax][PyTorch] Add `as_strided` operator in ExportedProgram frontend (#18490)
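Lower `aten.as_strided.default` in the ExportedProgram frontend to `relax.op.reshape`. Only view-like calls are supported: `storage_offset` must be zero and the strides must describe a contiguous layout of `size`; other calls raise an error.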
---
.../frontend/torch/exported_program_translator.py | 32 ++++++++++++++++++
.../relax/test_frontend_from_exported_program.py | 38 ++++++++++++++++++++++
2 files changed, 70 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 1961898f76..d7975a8dde 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -954,6 +954,37 @@ class ExportedProgramImporter(BaseFXGraphImporter):
return self.block_builder.emit(relax.op.scatter_elements(x, index,
src, axis=dim))
def _as_strided(self, node: fx.Node) -> relax.Var:
    """Convert ``aten.as_strided.default`` into a Relax ``reshape``.

    Relax tensors carry no strides, so only the view-like case is lowered:
    ``storage_offset`` must be zero and, when ``size`` and ``stride`` are
    constants, the strides must describe a contiguous row-major layout of
    ``size``. Such a call reads the input's elements in order, so it is
    emitted as a plain reshape to ``size``.

    NOTE(review): the lowering treats the input's logical element order as
    its storage order. If the input is itself a non-contiguous view in the
    original PyTorch program (e.g. the output of a transpose), PyTorch would
    read the underlying storage instead -- confirm such inputs cannot reach
    this converter.

    Parameters
    ----------
    node : fx.Node
        The ``as_strided`` node. Its arguments are
        ``(input, size, stride[, storage_offset])``; ``storage_offset`` may
        also be passed as a keyword.

    Returns
    -------
    relax.Var
        The emitted ``reshape`` of the input to ``size``.

    Raises
    ------
    AssertionError
        If ``storage_offset`` is not a constant zero, if ``size`` and
        ``stride`` differ in length, or if constant strides are not
        contiguous for ``size``.
    """
    args = self.retrieve_args(node)
    x = args[0]
    size = args[1]
    stride = args[2]
    storage_offset = args[3] if len(args) > 3 else node.kwargs.get("storage_offset", 0)

    # The aten schema declares ``storage_offset`` optional (``SymInt?``);
    # treat an explicit None like the omitted argument. Raise explicitly
    # rather than ``assert`` so the check is not stripped under ``python -O``.
    if storage_offset is None:
        storage_offset = 0
    if not isinstance(storage_offset, (int, tvm.tir.IntImm)) or int(storage_offset) != 0:
        raise AssertionError("as_strided with non-zero storage_offset is not supported yet")

    # ``zip`` below would silently truncate mismatched lists.
    if len(size) != len(stride):
        raise AssertionError(
            f"as_strided expects size and stride of equal length, got {size} and {stride}"
        )

    # Only constant sizes/strides can be validated here; symbolic ones are
    # lowered as-is on the assumption that they describe a contiguous layout.
    is_static = all(isinstance(v, (int, tvm.tir.IntImm)) for v in (*size, *stride))
    if is_static:
        dims = [int(dim) for dim in size]
        # An empty result reads no elements, so any stride is acceptable
        # (torch itself reports e.g. stride (1, 1) for a contiguous (3, 0)).
        if 0 not in dims:
            expected = 1
            # Walk from the innermost dimension outward: a contiguous stride
            # equals the product of all inner dimension sizes.
            for dim, st in zip(reversed(dims), reversed(stride)):
                # Size-1 dimensions are never stepped along, so their stride
                # does not affect which elements are read.
                if dim != 1 and int(st) != expected:
                    raise AssertionError(
                        f"as_strided with non-contiguous stride {stride} for "
                        f"size {size} is not supported"
                    )
                expected *= dim

    return self.block_builder.emit(relax.op.reshape(x, size))
+
########## Others ##########
def create_convert_map(
@@ -1219,6 +1250,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"view.default": self._reshape,
"reshape.default": self._reshape,
"reshape_as.default": self._reshape_as,
+ "as_strided.default": self._as_strided,
# tensor creation
"_to_copy.default": self._to_copy,
"arange.default": self._arange,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index a61da359d3..341bafc267 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -5633,6 +5633,44 @@ def test_view():
verify_model(View(), example_args, {}, expected1)
def test_as_strided():
    """Check ``aten.as_strided`` lowering and its unsupported-case errors."""

    class AsStrided(Module):
        def forward(self, x):
            return torch.ops.aten.as_strided.default(x, (3, 2, 2), (4, 2, 1))

    # Contiguous strides for (3, 2, 2) lower to a single reshape.
    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 2, 3), dtype="float32")
        ) -> R.Tuple(R.Tensor((3, 2, 2), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((3, 2, 2), dtype="float32") = R.reshape(x, (3, 2, 2))
                gv: R.Tuple(R.Tensor((3, 2, 2), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    class AsStridedNonContiguous(Module):
        def forward(self, x):
            return torch.ops.aten.as_strided.default(x, (2, 2, 2), (6, 3, 1))

    class AsStridedWithStorageOffset(Module):
        def forward(self, x):
            return torch.ops.aten.as_strided.default(x, (2, 2), (2, 1), 1)

    input_3d = (torch.randn(2, 2, 3, dtype=torch.float32),)
    verify_model(AsStrided(), input_3d, {}, Expected)

    # Each unsupported pattern must be rejected with a descriptive error.
    rejected_cases = [
        (AsStridedNonContiguous(), input_3d, "non-contiguous stride"),
        (
            AsStridedWithStorageOffset(),
            (torch.randn(2, 2, dtype=torch.float32),),
            "storage_offset",
        ),
    ]
    for model, inputs, pattern in rejected_cases:
        program = export(model, args=inputs)
        with pytest.raises(AssertionError, match=pattern):
            from_exported_program(program)
+
def test_arange():
class Arange(Module):
def forward(self, input):