This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 820642b6c4 [Relax] Fix Torch frontends to report all the missing ops
(#17826)
820642b6c4 is described below
commit 820642b6c4709aca510bedbd4fae028da9d126c3
Author: Deivanayaki S <[email protected]>
AuthorDate: Mon Apr 14 07:31:01 2025 +0530
[Relax] Fix Torch frontends to report all the missing ops (#17826)
* enhance missing func types finding in exported program and fx graph
frontend
* fix trailing space issue
* fix lint issues by formatting the code
* fix name error in fx frontend
---------
Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
---
.../relax/frontend/torch/exported_program_translator.py | 15 ++++++++++++---
python/tvm/relax/frontend/torch/fx_translator.py | 15 ++++++++++++---
2 files changed, 24 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 875ec3b83e..be17001fd0 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -511,6 +511,18 @@ class ExportedProgramImporter(BaseFXGraphImporter):
):
output = None
with self.block_builder.dataflow():
+
+ # Find all the missing function types
+ missing_func_types = list(
+ {
+ node.target.__name__
+ for node in nodes
+ if node.op == "call_function"
+ and node.target.__name__ not in self.convert_map
+ }
+ )
+            assert not missing_func_types, f"Unsupported function types {missing_func_types}"
+
# Translate the model.
for node in nodes:
if node.op == "placeholder":
@@ -537,9 +549,6 @@ class ExportedProgramImporter(BaseFXGraphImporter):
self.env[node] = getattr(exported_program.graph_module, node.target)
elif node.op == "call_function":
func_name = node.target.__name__
- assert (
- func_name in self.convert_map
- ), f"Unsupported function type {func_name}"
self.env[node] = self.convert_map[func_name](node)
else:
raise ValueError(f"Unsupported op {node.op}")
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index a5b50a7d1d..f6dd235d5a 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -884,6 +884,18 @@ class TorchFXImporter(BaseFXGraphImporter):
with self.block_builder.function(name=func_name, params=inputs.copy(), attrs=func_attrs):
output = None
with self.block_builder.dataflow():
+
+ # Find all the missing function types
+ missing_func_types = list(
+ {
+ node.target.__name__
+ for node in graph.nodes
+ if node.op == "call_function"
+ and node.target.__name__ not in self.convert_map
+ }
+ )
+            assert not missing_func_types, f"Unsupported function types {missing_func_types}"
+
# Translate model parameters.
for _, param in model.named_parameters():
shape = param.data.shape
@@ -929,9 +941,6 @@ class TorchFXImporter(BaseFXGraphImporter):
self.env[node] = self.convert_map[type(module)](node)
elif node.op == "call_function":
func_name = node.target.__name__
- assert (
- func_name in self.convert_map
- ), f"Unsupported function type {func_name}"
if func_name in custom_ops:
self.env[node] = self.convert_map[func_name](node, self)
else: