This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4e41b42fa3 [Relax][PyTorch] Support narrow and broadcast_to ops for ExportedProgram importer (#17830)
4e41b42fa3 is described below
commit 4e41b42fa3d9e2fdcb11f79ff4e2d4f757479894
Author: Shushi Hong <[email protected]>
AuthorDate: Mon Apr 14 10:17:56 2025 +0800
[Relax][PyTorch] Support narrow and broadcast_to ops for ExportedProgram importer (#17830)
* Update exported_program_translator.py
* Update test_frontend_from_exported_program.py
* Update test_frontend_from_exported_program.py
* Update test_frontend_from_exported_program.py
---
.../frontend/torch/exported_program_translator.py | 9 ++++
.../relax/test_frontend_from_exported_program.py | 50 ++++++++++++++++++++++
2 files changed, 59 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index be17001fd0..c398cc4558 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -202,6 +202,13 @@ class ExportedProgramImporter(BaseFXGraphImporter):
########## Manipulation ##########
+ def _narrow(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ dim = node.args[1]
+ start = node.args[2]
+ length = node.args[3]
+ return self.block_builder.emit(relax.op.strided_slice(x, [dim], [start], [length]))
+
def _select(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
dim = node.args[1]
@@ -390,6 +397,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"where.self": self._where,
# tensor manipulation
"argsort.default": self._argsort,
+ "broadcast_to.default": self._broadcast_to,
"cat.default": self._cat,
"chunk.default": self._chunk,
"clamp.Tensor": self._clamp,
@@ -402,6 +410,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"flatten.using_ints": self._flatten,
"flip.default": self._flip,
"gather.default": self._gather,
+ "narrow.default": self._narrow,
"permute.default": self._permute,
"repeat.default": self._repeat,
"select.int": self._select,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 42288cf562..284544be50 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3856,5 +3856,55 @@ def test_dynamic_shape():
verify_model(DynamicModel(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes)
+def test_broadcast_to():
+ class BroadcastTo(Module):
+ def forward(self, x):
+ return torch.broadcast_to(x, (5, 3))
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((5, 1), dtype="float32")
+ ) -> R.Tuple(R.Tensor((5, 3), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((5, 3), dtype="float32") = R.broadcast_to(x, R.shape([5, 3]))
+ gv: R.Tuple(R.Tensor((5, 3), dtype="float32")) = (lv,)
+ R.output(gv)
+
+ return gv
+
+ example_args = (torch.randn(5, 1, dtype=torch.float32),)
+ verify_model(BroadcastTo(), example_args, {}, Expected)
+
+
+def test_narrow():
+ class Narrow(Module):
+ def forward(self, x):
+ return torch.narrow(x, 1, 0, 2)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((5, 3), dtype="float32")
+ ) -> R.Tuple(R.Tensor((5, 2), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((5, 2), dtype="float32") = R.strided_slice(
+ x,
+ (R.prim_value(1),),
+ (R.prim_value(0),),
+ (R.prim_value(2),),
+ assume_inbound=False,
+ )
+ gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,)
+ R.output(gv)
+
+ return gv
+
+ example_args = (torch.randn(5, 3, dtype=torch.float32),)
+ verify_model(Narrow(), example_args, {}, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()