This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7271feba41 [Relax][PyTorch] Add support for Custom Ops for
ExportedProgram frontend (#18544)
7271feba41 is described below
commit 7271feba4161d9751dc1d069d7a9223c9f736a84
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Tue Dec 9 14:10:43 2025 +0900
[Relax][PyTorch] Add support for Custom Ops for ExportedProgram frontend
(#18544)
As per title.
cc @tlopex @guan404ming
We keep the interface the same as
[`from_fx()`](https://github.com/apache/tvm/blob/ed97234b25a155bc66198ab5cd9e372a4772acec/python/tvm/relax/frontend/torch/fx_translator.py#L1152),
so you can define and pass a custom converter like this:
```python
from tvm.relax.frontend.torch.exported_program_translator import
ExportedProgramImporter
def _rms_norm_converter(node: torch.fx.Node, self: ExportedProgramImporter) -> relax.Var:
    """Convert an ``rms_norm`` call node into a Relax ``nn.rms_norm`` call.

    Parameters
    ----------
    node : torch.fx.Node
        The call node being translated. Its positional arguments are read as
        ``(input, normalized_shape[, weight[, eps]])``.
    self : ExportedProgramImporter
        The importer driving the translation. It supplies ``env``,
        ``shape_of``, ``block_builder`` and ``_convert_torch_tensor_to_relax``.

    Returns
    -------
    relax.Var
        The result of emitting ``relax.op.nn.rms_norm`` through the importer's
        block builder.
    """
    x = self.env[node.args[0]]
    torch_dtype = node.args[0].meta["tensor_meta"].dtype
    normalized_shape = node.args[1]
    weight = self.env.get(node.args[2], None) if len(node.args) > 2 else None
    eps = node.args[3] if len(node.args) > 3 else None

    # Accept a bare scalar as a one-element shape so that both the axis
    # computation and the default-weight construction below work on it.
    if not isinstance(normalized_shape, (tuple, list)):
        normalized_shape = [normalized_shape]

    # Normalize over the trailing D axes of the N-dimensional input.
    N = len(self.shape_of(x))
    D = len(normalized_shape)
    axes = list(range(N - D, N))

    if weight is None:
        # No weight argument: substitute an all-ones tensor of the
        # normalized shape, using the input's dtype.
        weight = self._convert_torch_tensor_to_relax(
            torch.ones(list(normalized_shape), dtype=torch_dtype)
        )

    # Keep an explicitly supplied eps; only fall back to the dtype's machine
    # epsilon when the node did not provide one.
    if eps is None:
        eps = torch.finfo(torch_dtype).eps

    return self.block_builder.emit(relax.op.nn.rms_norm(x, weight, axes, eps))
# Convert the exported program, routing the op whose target name is
# "rms_norm.default" through the custom converter defined above.
mod = from_exported_program(
    exported_program,
    custom_convert_map={"rms_norm.default": _rms_norm_converter},
    # NOTE(review): decomposition is disabled here, presumably so the
    # high-level rms_norm op is still present for the converter — confirm.
    run_ep_decomposition=False,
)
```
---
.../frontend/torch/base_fx_graph_translator.py | 11 +++++++
.../frontend/torch/exported_program_translator.py | 26 +++++++++++++---
python/tvm/relax/frontend/torch/fx_translator.py | 11 -------
.../relax/test_frontend_from_exported_program.py | 36 ++++++++++++++++++++++
4 files changed, 69 insertions(+), 15 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 471d4209d7..47eb666210 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -46,6 +46,17 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
########## Utilities ##########
+ def update_convert_map(self, custom_convert_map: Dict[str, Callable]):
+ """Update self.convert_map with custom convert map
+
+ Parameters
+ ----------
+ custom_convert_map : Dict[str, Callable]
+ A custom op conversion map in the same format as self.convert_map
+ """
+
+ self.convert_map.update(custom_convert_map)
+
@staticmethod
def _convert_data_type(input_type: Union[str, torch.dtype], env:
Optional[Dict] = None):
"""converts the PyTorch scalar type input_type to a TVM dtype."""
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 3e2274e551..3d6a632fb2 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -23,6 +23,7 @@ from functools import partial
from typing import Callable, Dict, List, Optional, Tuple
import torch
+from torch import fx
import tvm
from tvm import relax
@@ -32,8 +33,6 @@ from .base_fx_graph_translator import BaseFXGraphImporter
class ExportedProgramImporter(BaseFXGraphImporter):
"""An importer from ExportedProgram to Relax."""
- from torch import fx
-
@staticmethod
def _convert_pytorch_tensor_to_tvm(tensor_value: torch.Tensor) ->
tvm.runtime.Tensor:
"""Convert a PyTorch tensor to TVM tensor, handling sparse tensors.
@@ -1615,9 +1614,18 @@ class ExportedProgramImporter(BaseFXGraphImporter):
keep_params_as_input: bool,
unwrap_unit_return_tuple: bool,
no_bind_return_tuple: bool,
+ custom_convert_map: Optional[
+ Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]]
+ ],
) -> tvm.IRModule:
"""Convert a PyTorch ExportedProgram to a Relax program."""
- from torch import fx # type: ignore
+
+ # Update the conversion map with custom ops if provided.
+ if custom_convert_map:
+ custom_ops = set(custom_convert_map.keys())
+ self.update_convert_map(custom_convert_map)
+ else:
+ custom_ops = set()
# Create input variables.
(
@@ -1682,7 +1690,10 @@ class ExportedProgramImporter(BaseFXGraphImporter):
self.env[node] =
getattr(exported_program.graph_module, node.target)
elif node.op == "call_function":
func_name = node.target.__name__
- self.env[node] = self.convert_map[func_name](node)
+ if func_name in custom_ops:
+ self.env[node] = self.convert_map[func_name](node,
self)
+ else:
+ self.env[node] = self.convert_map[func_name](node)
else:
raise ValueError(f"Unsupported op {node.op}")
assert output is not None
@@ -1722,6 +1733,9 @@ def from_exported_program(
keep_params_as_input: bool = False,
unwrap_unit_return_tuple: bool = False,
no_bind_return_tuple: bool = False,
+ custom_convert_map: Optional[
+ Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]]
+ ] = None,
run_ep_decomposition: bool = True,
) -> tvm.IRModule:
"""Convert a PyTorch ExportedProgram to a Relax program
@@ -1742,6 +1756,9 @@ def from_exported_program(
A boolean flag indicating whether to bind the return tuple as a relax
var.
If the flag is true and the return value is a tuple, it will not bind
it to a var.
+ custom_convert_map : Dict[str, Callable[[fx.Node, BaseFXGraphImporter],
relax.Var]]
+ A custom op conversion map in the same format as
ExportedProgramImporter.convert_map above
+
run_ep_decomposition : bool
A boolean flag indicating whether to run PyTorch's decomposition on the
exported program before translation. When True, high-level operators
will
@@ -1795,4 +1812,5 @@ def from_exported_program(
keep_params_as_input,
unwrap_unit_return_tuple,
no_bind_return_tuple,
+ custom_convert_map,
)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index 8b1f5de36b..f2a6c9e654 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -1037,17 +1037,6 @@ class TorchFXImporter(BaseFXGraphImporter):
"item": self._item,
}
- def update_convert_map(self, custom_convert_map: dict):
- """Update self.convert_map with custom convert map
-
- Parameters
- ----------
- custom_convert_map : Dictionary of str to Relax op
- A custom op conversion map in the same format as self.convert_map
- """
-
- self.convert_map.update(custom_convert_map)
-
def from_fx(
self,
model,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 74ad2329fe..01e16e7564 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -42,6 +42,7 @@ def verify_model(
unwrap_unit_return_tuple=False,
no_bind_return_tuple=False,
map_free_vars=False,
+ custom_convert_map=None,
):
exported_program = export(torch_model, args=example_args,
dynamic_shapes=dynamic_shapes)
mod = from_exported_program(
@@ -50,6 +51,7 @@ def verify_model(
keep_params_as_input=keep_params_as_input,
unwrap_unit_return_tuple=unwrap_unit_return_tuple,
no_bind_return_tuple=no_bind_return_tuple,
+ custom_convert_map=custom_convert_map,
)
binding = {k: tvm.runtime.tensor(v) for k, v in binding.items()}
@@ -6562,6 +6564,40 @@ def test_register_buffer():
from_exported_program(ep)
+def test_custom_op():
+ class AddOp(Module):
+ def forward(self, x, y):
+ return torch.ops.aten.add.Tensor(x, y)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((5,), dtype="float32"),
+ y: R.Tensor((5,), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((5,), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((5,), dtype="float32") = R.subtract(x, y)
+ gv: R.Tuple(R.Tensor((5,), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ from tvm.relax.frontend.torch.exported_program_translator import (
+ ExportedProgramImporter,
+ )
+
+ def custom_add_converter(node: torch.fx.Node, self:
ExportedProgramImporter) -> relax.Var:
+ x = self.env[node.args[0]]
+ y = self.env[node.args[1]]
+
+ return self.block_builder.emit(R.subtract(x, y))
+
+ example_args = (torch.randn(5, dtype=torch.float32), torch.randn(5,
dtype=torch.float32))
+ verify_model(
+ AddOp(), example_args, {}, Expected, custom_convert_map={"add.Tensor":
custom_add_converter}
+ )
+
+
def test_empty_like():
class EmptyLike(Module):
def forward(self, data):