This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fcb8853603 [Relax] Refactor missing op check into shared utility for Torch frontends (#17840)
fcb8853603 is described below
commit fcb88536034f207333d199a587636a7d87576e58
Author: Deivanayaki S <[email protected]>
AuthorDate: Wed Apr 16 13:32:40 2025 +0530
[Relax] Refactor missing op check into shared utility for Torch frontends (#17840)
* combine the missing-op check of the exported-program and FX graph translators into a common utility

* move the helper call above the block builder and fix a lint issue

* add a type hint for the nodes argument in the helper function
---------
Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
---
.../tvm/relax/frontend/torch/base_fx_graph_translator.py | 12 +++++++++++-
.../relax/frontend/torch/exported_program_translator.py | 15 ++++-----------
python/tvm/relax/frontend/torch/fx_translator.py | 14 +++-----------
3 files changed, 18 insertions(+), 23 deletions(-)
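For context, a minimal standalone sketch of the check this commit factors out; the free function check_unsupported_func_type and the toy module are hypothetical stand-ins for illustration, while the real helper is the _check_unsupported_func_type method added to BaseFXGraphImporter below:

    # Hypothetical standalone version of the shared check; convert_map stands
    # in for the importer's real conversion table.
    from typing import Dict, List

    import torch
    from torch import fx

    def check_unsupported_func_type(nodes: List[fx.Node], convert_map: Dict) -> None:
        # Collect the distinct names of call_function targets with no converter.
        missing_func_types = list(
            {
                node.target.__name__
                for node in nodes
                if node.op == "call_function" and node.target.__name__ not in convert_map
            }
        )
        assert not missing_func_types, f"Unsupported function types {missing_func_types}"

    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.sin(x)

    graph = fx.symbolic_trace(Toy()).graph
    try:
        # An empty convert_map marks every call_function target as missing.
        check_unsupported_func_type(list(graph.nodes), convert_map={})
    except AssertionError as err:
        print(err)  # Unsupported function types ['sin']

Because the call now sits above the block builder (second bullet above), an unsupported graph is rejected before any partial Relax function is emitted.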
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 3018b0db77..6d880ab90d 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -21,7 +21,7 @@
 import abc
 from functools import reduce
 import math
-from typing import Callable, Dict, Optional, Tuple, Union
+from typing import Callable, Dict, Optional, Tuple, Union, List
 
 from tvm import relax
@@ -103,6 +103,16 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         else:
             return node
 
+    def _check_unsupported_func_type(self, nodes: List[fx.Node]):
+        missing_func_types = list(
+            {
+                node.target.__name__
+                for node in nodes
+                if node.op == "call_function" and node.target.__name__ not in self.convert_map
+            }
+        )
+        assert not missing_func_types, f"Unsupported function types {missing_func_types}"
+
     ########## Unary Ops ##########
 
     def _unary_op(self, op: Callable) -> Callable:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 7b9587b675..8f6418891b 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -518,23 +518,16 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else None
         nodes: List[fx.Node] = exported_program.graph.nodes
+
+        # Find all the missing function types
+        self._check_unsupported_func_type(nodes)
+
         with self.block_builder.function(
             name=func_name, params=list(inputs_vars.values()).copy(), attrs=func_attrs
         ):
             output = None
             with self.block_builder.dataflow():
-                # Find all the missing function types
-                missing_func_types = list(
-                    {
-                        node.target.__name__
-                        for node in nodes
-                        if node.op == "call_function"
-                        and node.target.__name__ not in self.convert_map
-                    }
-                )
-                assert not missing_func_types, f"Unsupported function types {missing_func_types}"
-
                 # Translate the model.
                 for node in nodes:
                     if node.op == "placeholder":
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index d24d67105e..594344fef8 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -848,21 +848,13 @@ class TorchFXImporter(BaseFXGraphImporter):
         else:
             func_attrs = None
 
+        # Find all the missing function types
+        self._check_unsupported_func_type(graph.nodes)
+
         with self.block_builder.function(name=func_name, params=inputs.copy(), attrs=func_attrs):
             output = None
             with self.block_builder.dataflow():
-                # Find all the missing function types
-                missing_func_types = list(
-                    {
-                        node.target.__name__
-                        for node in graph.nodes
-                        if node.op == "call_function"
-                        and node.target.__name__ not in self.convert_map
-                    }
-                )
-                assert not missing_func_types, f"Unsupported function types {missing_func_types}"
-
                 # Translate model parameters.
                 for _, param in model.named_parameters():
                     shape = param.data.shape