This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7271feba41 [Relax][PyTorch] Add support for Custom Ops for ExportedProgram frontend (#18544)
7271feba41 is described below

commit 7271feba4161d9751dc1d069d7a9223c9f736a84
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Tue Dec 9 14:10:43 2025 +0900

    [Relax][PyTorch] Add support for Custom Ops for ExportedProgram frontend (#18544)
    
    As per title.
    
    cc @tlopex @guan404ming
    
    We keep the interface the same as
    [`from_fx()`](https://github.com/apache/tvm/blob/ed97234b25a155bc66198ab5cd9e372a4772acec/python/tvm/relax/frontend/torch/fx_translator.py#L1152),
    so you can define and pass a custom converter like this:
    
    ```python
    import torch

    from tvm import relax
    from tvm.relax.frontend.torch import from_exported_program
    from tvm.relax.frontend.torch.exported_program_translator import ExportedProgramImporter

    def _rms_norm_converter(node: torch.fx.Node, self: ExportedProgramImporter) -> relax.Var:
        x = self.env[node.args[0]]
        torch_dtype = node.args[0].meta["tensor_meta"].dtype
        normalized_shape = node.args[1]
        weight = self.env.get(node.args[2], None) if len(node.args) > 2 else None
        eps = node.args[3] if len(node.args) > 3 else None

        # rms_norm normalizes over the trailing dimensions given by normalized_shape.
        N = len(self.shape_of(x))
        D = len(normalized_shape) if isinstance(normalized_shape, (tuple, list)) else 1
        axes = list(range(N - D, N))

        # Fall back to a ones weight and the dtype's machine epsilon when omitted.
        if weight is None:
            weight = self._convert_torch_tensor_to_relax(
                torch.ones(list(normalized_shape), dtype=torch_dtype)
            )
        eps = torch.finfo(torch_dtype).eps if eps is None else eps

        return self.block_builder.emit(relax.op.nn.rms_norm(x, weight, axes, eps))

    mod = from_exported_program(
        exported_program,
        custom_convert_map={"rms_norm.default": _rms_norm_converter},
        run_ep_decomposition=False,
    )
    ```
---
 .../frontend/torch/base_fx_graph_translator.py     | 11 +++++++
 .../frontend/torch/exported_program_translator.py  | 26 +++++++++++++---
 python/tvm/relax/frontend/torch/fx_translator.py   | 11 -------
 .../relax/test_frontend_from_exported_program.py   | 36 ++++++++++++++++++++++
 4 files changed, 69 insertions(+), 15 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 471d4209d7..47eb666210 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -46,6 +46,17 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
     ########## Utilities ##########
 
+    def update_convert_map(self, custom_convert_map: Dict[str, Callable]):
+        """Update self.convert_map with custom convert map
+
+        Parameters
+        ----------
+        custom_convert_map : Dict[str, Callable]
+            A custom op conversion map in the same format as self.convert_map
+        """
+
+        self.convert_map.update(custom_convert_map)
+
     @staticmethod
     def _convert_data_type(input_type: Union[str, torch.dtype], env: Optional[Dict] = None):
         """converts the PyTorch scalar type input_type to a TVM dtype."""
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 3e2274e551..3d6a632fb2 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -23,6 +23,7 @@ from functools import partial
 from typing import Callable, Dict, List, Optional, Tuple
 
 import torch
+from torch import fx
 import tvm
 from tvm import relax
 
@@ -32,8 +33,6 @@ from .base_fx_graph_translator import BaseFXGraphImporter
 class ExportedProgramImporter(BaseFXGraphImporter):
     """An importer from ExportedProgram to Relax."""
 
-    from torch import fx
-
     @staticmethod
     def _convert_pytorch_tensor_to_tvm(tensor_value: torch.Tensor) -> tvm.runtime.Tensor:
         """Convert a PyTorch tensor to TVM tensor, handling sparse tensors.
@@ -1615,9 +1614,18 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         keep_params_as_input: bool,
         unwrap_unit_return_tuple: bool,
         no_bind_return_tuple: bool,
+        custom_convert_map: Optional[
+            Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]]
+        ],
     ) -> tvm.IRModule:
         """Convert a PyTorch ExportedProgram to a Relax program."""
-        from torch import fx  # type: ignore
+
+        # Update the conversion map with custom ops if provided.
+        if custom_convert_map:
+            custom_ops = set(custom_convert_map.keys())
+            self.update_convert_map(custom_convert_map)
+        else:
+            custom_ops = set()
 
         # Create input variables.
         (
@@ -1682,7 +1690,10 @@ class ExportedProgramImporter(BaseFXGraphImporter):
                         self.env[node] = getattr(exported_program.graph_module, node.target)
                     elif node.op == "call_function":
                         func_name = node.target.__name__
-                        self.env[node] = self.convert_map[func_name](node)
+                        if func_name in custom_ops:
+                            self.env[node] = self.convert_map[func_name](node, self)
+                        else:
+                            self.env[node] = self.convert_map[func_name](node)
                     else:
                         raise ValueError(f"Unsupported op {node.op}")
             assert output is not None
@@ -1722,6 +1733,9 @@ def from_exported_program(
     keep_params_as_input: bool = False,
     unwrap_unit_return_tuple: bool = False,
     no_bind_return_tuple: bool = False,
+    custom_convert_map: Optional[
+        Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]]
+    ] = None,
     run_ep_decomposition: bool = True,
 ) -> tvm.IRModule:
     """Convert a PyTorch ExportedProgram to a Relax program
@@ -1742,6 +1756,9 @@ def from_exported_program(
         A boolean flag indicating whether to bind the return tuple as a relax var.
         If the flag is true and the return value is a tuple, it will not bind it to a var.
 
+    custom_convert_map : Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]]
+        A custom op conversion map in the same format as ExportedProgramImporter.convert_map above
+
     run_ep_decomposition : bool
         A boolean flag indicating whether to run PyTorch's decomposition on the
+        exported program before translation. When True, high-level operators will
@@ -1795,4 +1812,5 @@ def from_exported_program(
         keep_params_as_input,
         unwrap_unit_return_tuple,
         no_bind_return_tuple,
+        custom_convert_map,
     )
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 8b1f5de36b..f2a6c9e654 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -1037,17 +1037,6 @@ class TorchFXImporter(BaseFXGraphImporter):
             "item": self._item,
         }
 
-    def update_convert_map(self, custom_convert_map: dict):
-        """Update self.convert_map with custom convert map
-
-        Parameters
-        ----------
-        custom_convert_map : Dictionary of str to Relax op
-            A custom op conversion map in the same format as self.convert_map
-        """
-
-        self.convert_map.update(custom_convert_map)
-
     def from_fx(
         self,
         model,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 74ad2329fe..01e16e7564 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -42,6 +42,7 @@ def verify_model(
     unwrap_unit_return_tuple=False,
     no_bind_return_tuple=False,
     map_free_vars=False,
+    custom_convert_map=None,
 ):
     exported_program = export(torch_model, args=example_args, dynamic_shapes=dynamic_shapes)
     mod = from_exported_program(
@@ -50,6 +51,7 @@ def verify_model(
         keep_params_as_input=keep_params_as_input,
         unwrap_unit_return_tuple=unwrap_unit_return_tuple,
         no_bind_return_tuple=no_bind_return_tuple,
+        custom_convert_map=custom_convert_map,
     )
 
     binding = {k: tvm.runtime.tensor(v) for k, v in binding.items()}
@@ -6562,6 +6564,40 @@ def test_register_buffer():
     from_exported_program(ep)
 
 
+def test_custom_op():
+    class AddOp(Module):
+        def forward(self, x, y):
+            return torch.ops.aten.add.Tensor(x, y)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((5,), dtype="float32"),
+            y: R.Tensor((5,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((5,), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((5,), dtype="float32") = R.subtract(x, y)
+                gv: R.Tuple(R.Tensor((5,), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    from tvm.relax.frontend.torch.exported_program_translator import (
+        ExportedProgramImporter,
+    )
+
+    def custom_add_converter(node: torch.fx.Node, self: ExportedProgramImporter) -> relax.Var:
+        x = self.env[node.args[0]]
+        y = self.env[node.args[1]]
+
+        return self.block_builder.emit(R.subtract(x, y))
+
+    example_args = (torch.randn(5, dtype=torch.float32), torch.randn(5, dtype=torch.float32))
+    verify_model(
+        AddOp(), example_args, {}, Expected, custom_convert_map={"add.Tensor": custom_add_converter}
+    )
+
+
 def test_empty_like():
     class EmptyLike(Module):
         def forward(self, data):
