This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fcb8853603 [Relax] Refactor missing op check into shared utility for Torch frontends (#17840)
fcb8853603 is described below

commit fcb88536034f207333d199a587636a7d87576e58
Author: Deivanayaki S <[email protected]>
AuthorDate: Wed Apr 16 13:32:40 2025 +0530

    [Relax] Refactor missing op check into shared utility for Torch frontends (#17840)
    
    * combine missing op logic of export and fx graph into common utilities
    
    * move func call above builder and fix lint issue
    
    * add type hint for nodes in helper function
    
    ---------
    
    Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
---
 .../tvm/relax/frontend/torch/base_fx_graph_translator.py  | 12 +++++++++++-
 .../relax/frontend/torch/exported_program_translator.py   | 15 ++++-----------
 python/tvm/relax/frontend/torch/fx_translator.py          | 14 +++-----------
 3 files changed, 18 insertions(+), 23 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 3018b0db77..6d880ab90d 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -21,7 +21,7 @@
 import abc
 from functools import reduce
 import math
-from typing import Callable, Dict, Optional, Tuple, Union
+from typing import Callable, Dict, Optional, Tuple, Union, List
 
 from tvm import relax
 
@@ -103,6 +103,16 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         else:
             return node
 
+    def _check_unsupported_func_type(self, nodes: List[fx.Node]):
+        missing_func_types = list(
+            {
+                node.target.__name__
+                for node in nodes
+                if node.op == "call_function" and node.target.__name__ not in 
self.convert_map
+            }
+        )
+        assert not missing_func_types, f"Unsupported function types 
{missing_func_types}"
+
     ########## Unary Ops ##########
 
     def _unary_op(self, op: Callable) -> Callable:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 7b9587b675..8f6418891b 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -518,23 +518,16 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         func_attrs = {"num_input": len(user_input_vars)} if 
keep_params_as_input else None
 
         nodes: List[fx.Node] = exported_program.graph.nodes
+
+        # Find all the missing function types
+        self._check_unsupported_func_type(nodes)
+
         with self.block_builder.function(
             name=func_name, params=list(inputs_vars.values()).copy(), 
attrs=func_attrs
         ):
             output = None
             with self.block_builder.dataflow():
 
-                # Find all the missing function types
-                missing_func_types = list(
-                    {
-                        node.target.__name__
-                        for node in nodes
-                        if node.op == "call_function"
-                        and node.target.__name__ not in self.convert_map
-                    }
-                )
-                assert not missing_func_types, f"Unsupported function types 
{missing_func_types}"
-
                 # Translate the model.
                 for node in nodes:
                     if node.op == "placeholder":
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index d24d67105e..594344fef8 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -848,21 +848,13 @@ class TorchFXImporter(BaseFXGraphImporter):
         else:
             func_attrs = None
 
+        # Find all the missing function types
+        self._check_unsupported_func_type(graph.nodes)
+
         with self.block_builder.function(name=func_name, params=inputs.copy(), 
attrs=func_attrs):
             output = None
             with self.block_builder.dataflow():
 
-                # Find all the missing function types
-                missing_func_types = list(
-                    {
-                        node.target.__name__
-                        for node in graph.nodes
-                        if node.op == "call_function"
-                        and node.target.__name__ not in self.convert_map
-                    }
-                )
-                assert not missing_func_types, f"Unsupported function types 
{missing_func_types}"
-
                 # Translate model parameters.
                 for _, param in model.named_parameters():
                     shape = param.data.shape

Reply via email to