vinx13 commented on code in PR #14166:
URL: https://github.com/apache/tvm/pull/14166#discussion_r1122422145
##########
python/tvm/relax/backend/pattern_registry.py:
##########
@@ -46,25 +47,59 @@ class PatternRegistryEntry(Object):
arg_patterns: Mapping[str, DFPattern]
The mapping from arg name to its pattern. It can be used to extract
arg expression
from match result. All DFPattern in this map should be part of the
`pattern`.
+
+ check: Callable[[Mapping[DFPattern, Expr], Expr], bool]
+ The function to check whether the match result is accepted.
"""
name: str
pattern: DFPattern
arg_patterns: Mapping[str, DFPattern]
-
- def __init__(self, name: str, pattern: DFPattern, arg_patterns:
Mapping[str, DFPattern]):
+ check: Callable[[Mapping[DFPattern, Expr], Expr], bool]
+
+ def __init__(
+ self,
+ name: str,
+ pattern: DFPattern,
+ arg_patterns: Mapping[str, DFPattern],
+ check: Callable[[Mapping[DFPattern, Expr], Expr], bool],
+ ):
self.__init_handle_by_constructor__(
- _ffi_api.PatternRegistryEntry, name, pattern, arg_patterns # type: ignore
+ _ffi_api.PatternRegistryEntry, name, pattern, arg_patterns, check # type: ignore
)
+_REGISTERED_PATTERN_NAMES: Set[str] = set()
+
+
+def _cleanup_registered_patterns():
+ _ffi_api.RemovePatterns(list(_REGISTERED_PATTERN_NAMES)) # type: ignore # pylint: disable=no-member
+
+
+_CLEANUP_REGISTERED = False
+
+
+def _ensure_cleanup_function_registered():
+ """
+ Add a cleanup function to be called on interpreter termination, to remove
all
+ patterns registered on the Python side. Without cleaning up those patterns,
+ program will segfault on termination. It's because the check functiosn of
pattern
Review Comment:
This seems to be a general issue with the FFI.
cc @tqchen @junrushao
##########
python/tvm/relax/backend/contrib/cutlass.py:
##########
@@ -116,7 +173,6 @@ def partition_for_cutlass(mod):
compiled by the CUTLASS backend.
"""
- cutlass_patterns = get_patterns_with_prefix("cutlass")
- return transform.FuseOpsByPattern(cutlass_patterns, bind_constants=True,
annotate_codegen=True)(
- mod
- )
+ cutlass_pattern_entries = get_patterns_with_prefix("cutlass")
+ patterns = [(e.name, e.pattern, e.check) for e in cutlass_pattern_entries]
+ return transform.FuseOpsByPattern(patterns, bind_constants=True,
annotate_codegen=True)(mod)
Review Comment:
```suggestion
return transform.FuseOpsByPattern(patterns, bind_constants=False,
annotate_codegen=True)(mod)
```
I made a typo in the previous PR; it should be `False`.
##########
python/tvm/relax/backend/pattern_registry.py:
##########
@@ -46,25 +47,59 @@ class PatternRegistryEntry(Object):
arg_patterns: Mapping[str, DFPattern]
The mapping from arg name to its pattern. It can be used to extract
arg expression
from match result. All DFPattern in this map should be part of the
`pattern`.
+
+ check: Callable[[Mapping[DFPattern, Expr], Expr], bool]
+ The function to check whether the match result is accepted.
"""
name: str
pattern: DFPattern
arg_patterns: Mapping[str, DFPattern]
-
- def __init__(self, name: str, pattern: DFPattern, arg_patterns:
Mapping[str, DFPattern]):
+ check: Callable[[Mapping[DFPattern, Expr], Expr], bool]
+
+ def __init__(
+ self,
+ name: str,
+ pattern: DFPattern,
+ arg_patterns: Mapping[str, DFPattern],
+ check: Callable[[Mapping[DFPattern, Expr], Expr], bool],
+ ):
self.__init_handle_by_constructor__(
- _ffi_api.PatternRegistryEntry, name, pattern, arg_patterns # type: ignore
+ _ffi_api.PatternRegistryEntry, name, pattern, arg_patterns, check # type: ignore
)
+_REGISTERED_PATTERN_NAMES: Set[str] = set()
+
+
+def _cleanup_registered_patterns():
+ _ffi_api.RemovePatterns(list(_REGISTERED_PATTERN_NAMES)) # type: ignore # pylint: disable=no-member
+
+
+_CLEANUP_REGISTERED = False
+
+
+def _ensure_cleanup_function_registered():
+ """
+ Add a cleanup function to be called on interpreter termination, to remove
all
+ patterns registered on the Python side. Without cleaning up those patterns,
+ program will segfault on termination. It's because the check functiosn of
pattern
Review Comment:
```suggestion
program will segfault on termination. It's because the check functions
of pattern
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]