This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 820642b6c4 [Relax] Fix Torch frontends to report all the missing ops (#17826)
820642b6c4 is described below

commit 820642b6c4709aca510bedbd4fae028da9d126c3
Author: Deivanayaki S <[email protected]>
AuthorDate: Mon Apr 14 07:31:01 2025 +0530

    [Relax] Fix Torch frontends to report all the missing ops (#17826)
    
    * enhance missing func types finding in exported program and fx graph frontend
    
    * fix trailing space issue
    
    * fix lint issues by formatting the code
    
    * fix name error in fx frontend
    
    ---------
    
    Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
---
 .../relax/frontend/torch/exported_program_translator.py   | 15 ++++++++++++---
 python/tvm/relax/frontend/torch/fx_translator.py          | 15 ++++++++++++---
 2 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 875ec3b83e..be17001fd0 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -511,6 +511,18 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         ):
             output = None
             with self.block_builder.dataflow():
+
+                # Find all the missing function types
+                missing_func_types = list(
+                    {
+                        node.target.__name__
+                        for node in nodes
+                        if node.op == "call_function"
+                        and node.target.__name__ not in self.convert_map
+                    }
+                )
+                assert not missing_func_types, f"Unsupported function types 
{missing_func_types}"
+
                 # Translate the model.
                 for node in nodes:
                     if node.op == "placeholder":
@@ -537,9 +549,6 @@ class ExportedProgramImporter(BaseFXGraphImporter):
                         self.env[node] = 
getattr(exported_program.graph_module, node.target)
                     elif node.op == "call_function":
                         func_name = node.target.__name__
-                        assert (
-                            func_name in self.convert_map
-                        ), f"Unsupported function type {func_name}"
                         self.env[node] = self.convert_map[func_name](node)
                     else:
                         raise ValueError(f"Unsupported op {node.op}")
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index a5b50a7d1d..f6dd235d5a 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -884,6 +884,18 @@ class TorchFXImporter(BaseFXGraphImporter):
         with self.block_builder.function(name=func_name, params=inputs.copy(), 
attrs=func_attrs):
             output = None
             with self.block_builder.dataflow():
+
+                # Find all the missing function types
+                missing_func_types = list(
+                    {
+                        node.target.__name__
+                        for node in graph.nodes
+                        if node.op == "call_function"
+                        and node.target.__name__ not in self.convert_map
+                    }
+                )
+                assert not missing_func_types, f"Unsupported function types 
{missing_func_types}"
+
                 # Translate model parameters.
                 for _, param in model.named_parameters():
                     shape = param.data.shape
@@ -929,9 +941,6 @@ class TorchFXImporter(BaseFXGraphImporter):
                         self.env[node] = self.convert_map[type(module)](node)
                     elif node.op == "call_function":
                         func_name = node.target.__name__
-                        assert (
-                            func_name in self.convert_map
-                        ), f"Unsupported function type {func_name}"
                         if func_name in custom_ops:
                             self.env[node] = self.convert_map[func_name](node, 
self)
                         else:

Reply via email to