vinx13 commented on code in PR #14166:
URL: https://github.com/apache/tvm/pull/14166#discussion_r1122422145


##########
python/tvm/relax/backend/pattern_registry.py:
##########
@@ -46,25 +47,59 @@ class PatternRegistryEntry(Object):
     arg_patterns: Mapping[str, DFPattern]
         The mapping from arg name to its pattern. It can be used to extract arg expression
         from match result. All DFPattern in this map should be part of the `pattern`.
+
+    check: Callable[[Mapping[DFPattern, Expr], Expr], bool]
+        The function to check whether the match result is accepted.
     """
 
     name: str
     pattern: DFPattern
     arg_patterns: Mapping[str, DFPattern]
-
-    def __init__(self, name: str, pattern: DFPattern, arg_patterns: Mapping[str, DFPattern]):
+    check: Callable[[Mapping[DFPattern, Expr], Expr], bool]
+
+    def __init__(
+        self,
+        name: str,
+        pattern: DFPattern,
+        arg_patterns: Mapping[str, DFPattern],
+        check: Callable[[Mapping[DFPattern, Expr], Expr], bool],
+    ):
         self.__init_handle_by_constructor__(
-            _ffi_api.PatternRegistryEntry, name, pattern, arg_patterns  # type: ignore
+            _ffi_api.PatternRegistryEntry, name, pattern, arg_patterns, check  # type: ignore
         )
 
 
+_REGISTERED_PATTERN_NAMES: Set[str] = set()
+
+
+def _cleanup_registered_patterns():
+    _ffi_api.RemovePatterns(list(_REGISTERED_PATTERN_NAMES))  # type: ignore # pylint: disable=no-member
+
+
+_CLEANUP_REGISTERED = False
+
+
+def _ensure_cleanup_function_registered():
+    """
+    Add a cleanup function to be called on interpreter termination, to remove all
+    patterns registered on the Python side. Without cleaning up those patterns,
+    program will segfault on termination. It's because the check functiosn of pattern

Review Comment:
   This seems to be a general FFI issue rather than something specific to the pattern registry.
   cc @tqchen @junrushao



##########
python/tvm/relax/backend/contrib/cutlass.py:
##########
@@ -116,7 +173,6 @@ def partition_for_cutlass(mod):
         compiled by the CUTLASS backend.
     """
 
-    cutlass_patterns = get_patterns_with_prefix("cutlass")
-    return transform.FuseOpsByPattern(cutlass_patterns, bind_constants=True, annotate_codegen=True)(
-        mod
-    )
+    cutlass_pattern_entries = get_patterns_with_prefix("cutlass")
+    patterns = [(e.name, e.pattern, e.check) for e in cutlass_pattern_entries]
+    return transform.FuseOpsByPattern(patterns, bind_constants=True, annotate_codegen=True)(mod)

Review Comment:
   ```suggestion
       return transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod)
   ```
   I made a typo in the previous PR; `bind_constants` should be `False`.



##########
python/tvm/relax/backend/pattern_registry.py:
##########
@@ -46,25 +47,59 @@ class PatternRegistryEntry(Object):
     arg_patterns: Mapping[str, DFPattern]
         The mapping from arg name to its pattern. It can be used to extract arg expression
         from match result. All DFPattern in this map should be part of the `pattern`.
+
+    check: Callable[[Mapping[DFPattern, Expr], Expr], bool]
+        The function to check whether the match result is accepted.
     """
 
     name: str
     pattern: DFPattern
     arg_patterns: Mapping[str, DFPattern]
-
-    def __init__(self, name: str, pattern: DFPattern, arg_patterns: Mapping[str, DFPattern]):
+    check: Callable[[Mapping[DFPattern, Expr], Expr], bool]
+
+    def __init__(
+        self,
+        name: str,
+        pattern: DFPattern,
+        arg_patterns: Mapping[str, DFPattern],
+        check: Callable[[Mapping[DFPattern, Expr], Expr], bool],
+    ):
         self.__init_handle_by_constructor__(
-            _ffi_api.PatternRegistryEntry, name, pattern, arg_patterns  # type: ignore
+            _ffi_api.PatternRegistryEntry, name, pattern, arg_patterns, check  # type: ignore
         )
 
 
+_REGISTERED_PATTERN_NAMES: Set[str] = set()
+
+
+def _cleanup_registered_patterns():
+    _ffi_api.RemovePatterns(list(_REGISTERED_PATTERN_NAMES))  # type: ignore # pylint: disable=no-member
+
+
+_CLEANUP_REGISTERED = False
+
+
+def _ensure_cleanup_function_registered():
+    """
+    Add a cleanup function to be called on interpreter termination, to remove all
+    patterns registered on the Python side. Without cleaning up those patterns,
+    program will segfault on termination. It's because the check functiosn of pattern

Review Comment:
   ```suggestion
       program will segfault on termination. It's because the check functions of pattern
   ```


