This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5e85443e43 [FFI][BUGFIX] Grab GIL when check env signals (#17419)
5e85443e43 is described below
commit 5e85443e43f9befcf8319cdc4045597aa49bf724
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Sep 26 09:22:13 2024 -0400
[FFI][BUGFIX] Grab GIL when check env signals (#17419)
This PR updates the CheckSignals function to acquire the GIL.
This is needed because we now explicitly release the GIL when calling
any C function. Without reacquiring the GIL, checking for signals
segfaults.
With this update, Ctrl+C can interrupt long-running C functions.
---
python/tvm/_ffi/_cython/base.pxi | 16 +++++++++++-----
python/tvm/_ffi/_cython/packed_func.pxi | 16 ----------------
src/runtime/registry.cc | 12 ++++++++----
src/support/ffi_testing.cc | 8 ++++++++
4 files changed, 27 insertions(+), 25 deletions(-)
diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi
index 0f7e5fcae6..887ac123ce 100644
--- a/python/tvm/_ffi/_cython/base.pxi
+++ b/python/tvm/_ffi/_cython/base.pxi
@@ -201,6 +201,10 @@ cdef inline void* c_handle(object handle):
# python env API
cdef extern from "Python.h":
int PyErr_CheckSignals()
+ void* PyGILState_Ensure()
+ void PyGILState_Release(void*)
+ void Py_IncRef(void*)
+ void Py_DecRef(void*)
cdef extern from "tvm/runtime/c_backend_api.h":
int TVMBackendRegisterEnvCAPI(const char* name, void* ptr)
@@ -210,11 +214,13 @@ cdef _init_env_api():
# so backend can call tvm::runtime::EnvCheckSignals to check
# signal when executing a long running function.
#
- # This feature is only enabled in cython for now due to problems of calling
- # these functions in ctypes.
- #
- # When the functions are not registered, the signals will be handled
- # only when the FFI function returns.
+ # Also registers PyGILState_Ensure/PyGILState_Release: EnvCheckSignals is
+ # invoked from C++ with the GIL released, so the backend must reacquire the
+ # GIL before calling PyErr_CheckSignals. Py_IncRef/Py_DecRef are exported too.
CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyErr_CheckSignals"),
<void*>PyErr_CheckSignals))
+ CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyGILState_Ensure"),
<void*>PyGILState_Ensure))
+ CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyGILState_Release"),
<void*>PyGILState_Release))
+ CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("Py_IncRef"), <void*>Py_IncRef))
+ CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("Py_DecRef"), <void*>Py_DecRef))
_init_env_api()
diff --git a/python/tvm/_ffi/_cython/packed_func.pxi
b/python/tvm/_ffi/_cython/packed_func.pxi
index 6e062ab5f1..b9516e79e3 100644
--- a/python/tvm/_ffi/_cython/packed_func.pxi
+++ b/python/tvm/_ffi/_cython/packed_func.pxi
@@ -376,19 +376,3 @@ def _set_class_object_generic(object_generic_class,
func_convert_to_object):
global _FUNC_CONVERT_TO_OBJECT
_CLASS_OBJECT_GENERIC = object_generic_class
_FUNC_CONVERT_TO_OBJECT = func_convert_to_object
-
-# Py_INCREF and Py_DECREF are C macros, not function objects.
-# Therefore, providing a wrapper function.
-cdef void _py_incref_wrapper(void* py_object):
- Py_INCREF(<object>py_object)
-cdef void _py_decref_wrapper(void* py_object):
- Py_DECREF(<object>py_object)
-
-def _init_pythonapi_inc_def_ref():
- register_func = TVMBackendRegisterEnvCAPI
- register_func(c_str("Py_IncRef"), <void*>_py_incref_wrapper)
- register_func(c_str("Py_DecRef"), <void*>_py_decref_wrapper)
- register_func(c_str("PyGILState_Ensure"), <void*>PyGILState_Ensure)
- register_func(c_str("PyGILState_Release"), <void*>PyGILState_Release)
-
-_init_pythonapi_inc_def_ref()
diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc
index 0a034a7b58..09674edf35 100644
--- a/src/runtime/registry.cc
+++ b/src/runtime/registry.cc
@@ -183,10 +183,14 @@ class EnvCAPIRegistry {
// implementation of tvm::runtime::EnvCheckSignals
void CheckSignals() {
// check python signal to see if there are exception raised
- if (pyerr_check_signals != nullptr && (*pyerr_check_signals)() != 0) {
- // The error will let FFI know that the frontend environment
- // already set an error.
- throw EnvErrorAlreadySet("");
+ if (pyerr_check_signals != nullptr) {
+ // The FFI releases the GIL before calling into C++, so reacquire it before calling PyErr_CheckSignals.
+ WithGIL context(this);
+ if ((*pyerr_check_signals)() != 0) {
+ // Non-zero means a Python signal handler raised (e.g. KeyboardInterrupt on Ctrl+C);
+ // EnvErrorAlreadySet lets the FFI know the frontend environment already set that error.
+ throw EnvErrorAlreadySet("");
+ }
}
}
diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc
index 928cdfcab8..52ffedda80 100644
--- a/src/support/ffi_testing.cc
+++ b/src/support/ffi_testing.cc
@@ -178,6 +178,14 @@
TVM_REGISTER_GLOBAL("testing.sleep_in_ffi").set_body_typed([](double timeout) {
std::this_thread::sleep_for(duration);
});
+TVM_REGISTER_GLOBAL("testing.check_signals").set_body_typed([](double
sleep_period) {
+ std::chrono::duration<int64_t, std::nano>
duration(static_cast<int64_t>(sleep_period * 1e9));
+ while (true) {  // no break: ends only via an exception, e.g. EnvCheckSignals throwing on Ctrl+C
+ std::this_thread::sleep_for(duration);
+ runtime::EnvCheckSignals();
+ }
+});
+
TVM_REGISTER_GLOBAL("testing.ReturnsVariant").set_body_typed([](int x) ->
Variant<String, IntImm> {
if (x % 2 == 0) {
return IntImm(DataType::Int(64), x / 2);