This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 093d785e22 [Relax][PyTorch] Add atan2 converter (#19850)
093d785e22 is described below

commit 093d785e2297269cf71ea115aec1b62095ba11eb
Author: Javier De Jesus <[email protected]>
AuthorDate: Sun Jun 21 05:26:28 2026 +0200

    [Relax][PyTorch] Add atan2 converter (#19850)
    
    ### Motivation
    
    `torch.atan2` was not registered in either the ExportedProgram or FX
    frontend, so importing a model that uses it failed with an "Unsupported
    function types" error. The `relax.op.atan2` operator already exists and
    legalizes to `topi.atan2`, so the frontends only needed to route the op
    to it.
    
    ### Changes
    
    - Register `atan2` in the FX frontend and `atan2.default` in the
      ExportedProgram frontend, reusing the shared `_binary_op` helper (the
      same pattern as the existing `maximum`/`minimum`/`logaddexp` converters).
    - Add structural tests in `test_frontend_from_fx.py` and
      `test_frontend_from_exported_program.py`.
---
 .../frontend/torch/exported_program_translator.py  |  1 +
 python/tvm/relax/frontend/torch/fx_translator.py   |  2 ++
 .../relax/test_frontend_from_exported_program.py   | 26 ++++++++++++++++++++++
 tests/python/relax/test_frontend_from_fx.py        | 21 +++++++++++++++++
 4 files changed, 50 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 6c9e3e3f5e..b96316adee 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1603,6 +1603,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "add.Tensor": self._binary_op(relax.op.add, operator.add),
             "add.Scalar": self._binary_op(relax.op.add, operator.add),
             "add_.Tensor": self._binary_op(relax.op.add, operator.add),
+            "atan2.default": self._binary_op(relax.op.atan2, torch.atan2),
             "bitwise_and.Tensor": self._binary_op(relax.op.bitwise_and, 
operator.and_),
             "bitwise_and.Scalar": self._binary_op(relax.op.bitwise_and, 
operator.and_),
             "bitwise_or_.Scalar": self._binary_op(relax.op.bitwise_or, 
operator.or_),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 66d17a5828..4932871bad 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -791,6 +791,7 @@ class TorchFXImporter(BaseFXGraphImporter):
     ) -> dict[torch.nn.Module | str, Callable[[fx.Node], relax.Var]]:
         import operator
 
+        import torch  # type: ignore
         from torch import nn
 
         return {
@@ -909,6 +910,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             # binary
             "add": self._binary_op(relax.op.add, operator.add),
             "and_": self._binary_op(relax.op.bitwise_and, operator.and_),
+            "atan2": self._binary_op(relax.op.atan2, torch.atan2),
             "bitwise_or_": self._binary_op_inplace(relax.op.bitwise_or, 
operator.or_),
             "bitwise_or": self._binary_op(relax.op.bitwise_or, operator.or_),
             "div": self._div,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index ee2f4a8f8d..dac0bd1e2a 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1016,6 +1016,32 @@ def test_logaddexp():
     verify_model(LogAddExp(), example_args, {}, expected)
 
 
+def test_atan2():
+    class Atan2(Module):
+        def forward(self, lhs, rhs):
+            return torch.atan2(lhs, rhs)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            lhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            rhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atan2(lhs, 
rhs)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(1, 3, 10, 10, dtype=torch.float32),
+        torch.randn(1, 3, 10, 10, dtype=torch.float32),
+    )
+    verify_model(Atan2(), example_args, {}, expected)
+
+
 def test_logical_and():
     class LogicalAnd(Module):
         def forward(self, lhs, rhs):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 34da69d5f0..bcb9252b89 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -5335,6 +5335,27 @@ def test_min():
     verify_model(Min(), [([256, 256], "float32"), ([256, 256], "float32")], 
{}, Expected1)
 
 
+def test_atan2():
+    class Atan2(Module):
+        def forward(self, x, y):
+            return torch.atan2(x, y)
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32"),
+            inp_1: R.Tensor((256, 256), dtype="float32"),
+        ) -> R.Tensor((256, 256), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((256, 256), dtype="float32") = R.atan2(inp_0, 
inp_1)
+                gv: R.Tensor((256, 256), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Atan2(), [([256, 256], "float32"), ([256, 256], "float32")], 
{}, Expected1)
+
+
 def test_attention():
     @I.ir_module
     class Expected1:

Reply via email to