This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4e41b42fa3 [Relax][PyTorch] Support narrow and broadcast_to ops for ExportedProgram importer (#17830)
4e41b42fa3 is described below

commit 4e41b42fa3d9e2fdcb11f79ff4e2d4f757479894
Author: Shushi Hong <[email protected]>
AuthorDate: Mon Apr 14 10:17:56 2025 +0800

    [Relax][PyTorch] Support narrow and broadcast_to ops for ExportedProgram importer (#17830)
    
    * Update exported_program_translator.py
    
    * Update test_frontend_from_exported_program.py
    
    * Update test_frontend_from_exported_program.py
    
    * Update test_frontend_from_exported_program.py
---
 .../frontend/torch/exported_program_translator.py  |  9 ++++
 .../relax/test_frontend_from_exported_program.py   | 50 ++++++++++++++++++++++
 2 files changed, 59 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index be17001fd0..c398cc4558 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -202,6 +202,13 @@ class ExportedProgramImporter(BaseFXGraphImporter):
 
     ########## Manipulation ##########
 
+    def _narrow(self, node: fx.Node) -> relax.Var:
+        """Import aten.narrow(x, dim, start, length) as x[start:start + length] along dim."""
+        x = self.env[node.args[0]]
+        dim, start, length = node.args[1:4]
+        # strided_slice takes an exclusive end index, not a length, so pass start + length.
+        return self.block_builder.emit(relax.op.strided_slice(x, [dim], [start], [start + length]))
+
     def _select(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         dim = node.args[1]
@@ -390,6 +397,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "where.self": self._where,
             # tensor manipulation
             "argsort.default": self._argsort,
+            "broadcast_to.default": self._broadcast_to,
             "cat.default": self._cat,
             "chunk.default": self._chunk,
             "clamp.Tensor": self._clamp,
@@ -402,6 +410,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "flatten.using_ints": self._flatten,
             "flip.default": self._flip,
             "gather.default": self._gather,
+            "narrow.default": self._narrow,
             "permute.default": self._permute,
             "repeat.default": self._repeat,
             "select.int": self._select,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 42288cf562..284544be50 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3856,5 +3856,55 @@ def test_dynamic_shape():
     verify_model(DynamicModel(), example_args, {}, Expected, 
dynamic_shapes=dynamic_shapes)
 
 
+def test_broadcast_to():  # torch.broadcast_to should import as a single R.broadcast_to
+    class BroadcastTo(Module):
+        def forward(self, x):
+            return torch.broadcast_to(x, (5, 3))  # expand the size-1 axis: (5, 1) -> (5, 3)
+
+    @tvm.script.ir_module  # reference IR the imported module is compared against
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((5, 1), dtype="float32")
+        ) -> R.Tuple(R.Tensor((5, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((5, 3), dtype="float32") = R.broadcast_to(x, R.shape([5, 3]))
+                gv: R.Tuple(R.Tensor((5, 3), dtype="float32")) = (lv,)
+                R.output(gv)
+
+            return gv
+
+    example_args = (torch.randn(5, 1, dtype=torch.float32),)
+    verify_model(BroadcastTo(), example_args, {}, Expected)
+
+
+def test_narrow():  # NOTE(review): start=0 hides end-vs-length mix-ups; add a start>0 case
+    class Narrow(Module):
+        def forward(self, x):
+            return torch.narrow(x, 1, 0, 2)  # dim=1, start=0, length=2 -> x[:, 0:2]
+
+    @tvm.script.ir_module  # reference IR the imported module is compared against
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((5, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((5, 2), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((5, 2), dtype="float32") = R.strided_slice(
+                    x,
+                    (R.prim_value(1),),  # axes: the narrowed dim
+                    (R.prim_value(0),),  # begin: start
+                    (R.prim_value(2),),  # end (exclusive): start + length
+                    assume_inbound=False,
+                )
+                gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,)
+                R.output(gv)
+
+            return gv
+
+    example_args = (torch.randn(5, 3, dtype=torch.float32),)
+    verify_model(Narrow(), example_args, {}, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to