This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f41c3b54ba [Relax][PyTorch] Support `leaky_relu_.default` and `reshape_as.default` in ExportedProgram frontend (#17851)
f41c3b54ba is described below

commit f41c3b54bad0135441dd60fdd40bdeb27896b8d8
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Fri Apr 18 14:59:02 2025 +0900

    [Relax][PyTorch] Support `leaky_relu_.default` and `reshape_as.default` in ExportedProgram frontend (#17851)
    
    * support `leaky_relu_.default`
    
    * support `reshape_as.default`
---
 .../frontend/torch/base_fx_graph_translator.py     |  7 +++++
 .../frontend/torch/exported_program_translator.py  |  2 ++
 .../relax/test_frontend_from_exported_program.py   | 31 ++++++++++++++++++++++
 3 files changed, 40 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 3ea70df9a1..2652b167e5 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1172,6 +1172,13 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
         return self.block_builder.emit(relax.op.reshape(x, dims))
 
+    def _reshape_as(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        other = args[1]
+        dims = self.shape_of(other)
+        return self.block_builder.emit(relax.op.reshape(x, dims))
+
     def _scatter(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         if len(node.args) == 1:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index cecffa753f..5d4f3437b2 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -293,6 +293,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "isinf.default": self._unary_op(relax.op.isinf),
             "isnan.default": self._unary_op(relax.op.isnan),
             "leaky_relu.default": self._leakyrelu,
+            "leaky_relu_.default": self._leakyrelu,
             "log.default": self._unary_op(relax.op.log),
             "log2.default": self._log2,
             "log10.default": self._log10,
@@ -439,6 +440,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             ),
             "view.default": self._reshape,
             "reshape.default": self._reshape,
+            "reshape_as.default": self._reshape_as,
             # tensor creation
             "_to_copy.default": self._to_copy,
             "arange.default": self._arange,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 6cdefbb12e..9259936dc2 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -605,6 +605,10 @@ def test_leakyrelu():
         def forward(self, input):
             return torch.nn.functional.leaky_relu(input, 0.02)
 
+    class LeakyReLU2(Module):
+        def forward(self, input):
+            return torch.ops.aten.leaky_relu_(input, 0.02)
+
     @tvm.script.ir_module
     class expected:
         @R.function
@@ -621,6 +625,7 @@ def test_leakyrelu():
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
     verify_model(LeakyReLU0(), example_args, {}, expected)
     verify_model(LeakyReLU1(), example_args, {}, expected)
+    verify_model(LeakyReLU2(), example_args, {}, expected)
 
 
 def test_logaddexp():
@@ -2937,6 +2942,32 @@ def test_reshape():
     verify_model(Reshape(), example_args, {}, expected1)
 
 
+def test_reshape_as():
+    class ReshapeAs(Module):
+        def forward(self, x: torch.Tensor, y: torch.Tensor):
+            return x.reshape_as(y)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3, 4), dtype="float32"),
+            y: R.Tensor((2, 12), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12))
+                gv: R.Tuple(R.Tensor((2, 12), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(1, 2, 3, 4, dtype=torch.float32),
+        torch.randn(2, 12, dtype=torch.float32),
+    )
+    verify_model(ReshapeAs(), example_args, {}, expected1)
+
+
 def test_select_slice():
     class Slice1(Module):
         def forward(self, x):

Reply via email to