This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f41c3b54ba [Relax][PyTorch] Support `leaky_relu_.default` and `reshape_as.default` in ExportedProgram frontend (#17851)
f41c3b54ba is described below
commit f41c3b54bad0135441dd60fdd40bdeb27896b8d8
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Fri Apr 18 14:59:02 2025 +0900
[Relax][PyTorch] Support `leaky_relu_.default` and `reshape_as.default` in ExportedProgram frontend (#17851)
* support `leaky_relu_.default`
* support `reshape_as.default`
---
.../frontend/torch/base_fx_graph_translator.py | 7 +++++
.../frontend/torch/exported_program_translator.py | 2 ++
.../relax/test_frontend_from_exported_program.py | 31 ++++++++++++++++++++++
3 files changed, 40 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 3ea70df9a1..2652b167e5 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1172,6 +1172,13 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
return self.block_builder.emit(relax.op.reshape(x, dims))
+ def _reshape_as(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ other = args[1]
+ dims = self.shape_of(other)
+ return self.block_builder.emit(relax.op.reshape(x, dims))
+
def _scatter(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
if len(node.args) == 1:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index cecffa753f..5d4f3437b2 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -293,6 +293,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"isinf.default": self._unary_op(relax.op.isinf),
"isnan.default": self._unary_op(relax.op.isnan),
"leaky_relu.default": self._leakyrelu,
+ "leaky_relu_.default": self._leakyrelu,
"log.default": self._unary_op(relax.op.log),
"log2.default": self._log2,
"log10.default": self._log10,
@@ -439,6 +440,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
),
"view.default": self._reshape,
"reshape.default": self._reshape,
+ "reshape_as.default": self._reshape_as,
# tensor creation
"_to_copy.default": self._to_copy,
"arange.default": self._arange,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 6cdefbb12e..9259936dc2 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -605,6 +605,10 @@ def test_leakyrelu():
def forward(self, input):
return torch.nn.functional.leaky_relu(input, 0.02)
+ class LeakyReLU2(Module):
+ def forward(self, input):
+ return torch.ops.aten.leaky_relu_(input, 0.02)
+
@tvm.script.ir_module
class expected:
@R.function
@@ -621,6 +625,7 @@ def test_leakyrelu():
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
verify_model(LeakyReLU0(), example_args, {}, expected)
verify_model(LeakyReLU1(), example_args, {}, expected)
+ verify_model(LeakyReLU2(), example_args, {}, expected)
def test_logaddexp():
@@ -2937,6 +2942,32 @@ def test_reshape():
verify_model(Reshape(), example_args, {}, expected1)
+def test_reshape_as():
+ class ReshapeAs(Module):
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
+ return x.reshape_as(y)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ x: R.Tensor((1, 2, 3, 4), dtype="float32"),
+ y: R.Tensor((2, 12), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12))
+ gv: R.Tuple(R.Tensor((2, 12), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (
+ torch.randn(1, 2, 3, 4, dtype=torch.float32),
+ torch.randn(2, 12, dtype=torch.float32),
+ )
+ verify_model(ReshapeAs(), example_args, {}, expected1)
+
+
def test_select_slice():
class Slice1(Module):
def forward(self, x):