diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py index e8680afcdf98..6465e0335db0 100644 --- a/python/tvm/_ffi/_ctypes/packed_func.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -340,6 +340,8 @@ def _init_pythonapi_inc_def_ref(): register_func = _LIB.TVMBackendRegisterEnvCAPI register_func(c_str("Py_IncRef"), ctypes.pythonapi.Py_IncRef) register_func(c_str("Py_DecRef"), ctypes.pythonapi.Py_DecRef) + register_func(c_str("PyGILState_Ensure"), ctypes.pythonapi.PyGILState_Ensure) + register_func(c_str("PyGILState_Release"), ctypes.pythonapi.PyGILState_Release) _init_pythonapi_inc_def_ref() diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi index ae528bcb7828..3d1e87bf563d 100644 --- a/python/tvm/_ffi/_cython/packed_func.pxi +++ b/python/tvm/_ffi/_cython/packed_func.pxi @@ -17,7 +17,7 @@ import ctypes import traceback -from cpython cimport Py_INCREF, Py_DECREF +from cpython cimport Py_INCREF, Py_DECREF, PyGILState_Ensure, PyGILState_Release from numbers import Number, Integral from ..base import string_types, py2cerror from ..runtime_ctypes import DataType, Device, TVMByteArray, ObjectRValueRef @@ -381,5 +381,7 @@ def _init_pythonapi_inc_def_ref(): register_func = TVMBackendRegisterEnvCAPI register_func(c_str("Py_IncRef"), _py_incref_wrapper) register_func(c_str("Py_DecRef"), _py_decref_wrapper) + register_func(c_str("PyGILState_Ensure"), PyGILState_Ensure) + register_func(c_str("PyGILState_Release"), PyGILState_Release) _init_pythonapi_inc_def_ref() diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc index 0db8786145d3..b9699c85d77b 100644 --- a/src/runtime/registry.cc +++ b/src/runtime/registry.cc @@ -148,6 +148,16 @@ class EnvCAPIRegistry { */ F_Py_IncDefRef py_dec_ref = nullptr; + /*! + \brief PyGILState_Ensure function + */ + void* (*py_gil_state_ensure)() = nullptr; + + /*! + \brief PyGILState_Release function + */ + void (*py_gil_state_release)(void*) = nullptr; + static EnvCAPIRegistry* Global() { static EnvCAPIRegistry* inst = new EnvCAPIRegistry(); return inst; @@ -161,6 +171,10 @@ class EnvCAPIRegistry { Update(symbol_name, &py_inc_ref, fptr); } else if (symbol_name == "Py_DecRef") { Update(symbol_name, &py_dec_ref, fptr); + } else if (symbol_name == "PyGILState_Ensure") { + Update(symbol_name, &py_gil_state_ensure, fptr); + } else if (symbol_name == "PyGILState_Release") { + Update(symbol_name, &py_gil_state_release, fptr); } else { LOG(FATAL) << "Unknown env API " << symbol_name; } @@ -177,15 +191,17 @@ class EnvCAPIRegistry { } void IncRef(void* python_obj) { + WithGIL context(this); ICHECK(py_inc_ref) << "Attempted to call Py_IncRef through EnvCAPIRegistry, " << "but Py_IncRef wasn't registered"; (*py_inc_ref)(python_obj); } void DecRef(void* python_obj) { - ICHECK(py_inc_ref) << "Attempted to call Py_IncRef through EnvCAPIRegistry, " - << "but Py_IncRef wasn't registered"; - (*py_inc_ref)(python_obj); + WithGIL context(this); + ICHECK(py_dec_ref) << "Attempted to call Py_DefRef through EnvCAPIRegistry, " + << "but Py_DefRef wasn't registered"; + (*py_dec_ref)(python_obj); } private: @@ -198,6 +214,28 @@ class EnvCAPIRegistry { } target[0] = ptr_casted; } + + struct WithGIL { + WithGIL(EnvCAPIRegistry* self) : self(self) { + ICHECK(self->py_gil_state_ensure) << "Attempted to acquire GIL through EnvCAPIRegistry, " + << "but PyGILState_Ensure wasn't registered"; + ICHECK(self->py_gil_state_release) << "Attempted to acquire GIL through EnvCAPIRegistry, " + << "but PyGILState_Release wasn't registered"; + gil_state = self->py_gil_state_ensure(); + } + ~WithGIL() { + if (self && gil_state) { + self->py_gil_state_release(gil_state); + } + } + WithGIL(const WithGIL&) = delete; + WithGIL(WithGIL&&) = delete; + WithGIL& operator=(const WithGIL&) = delete; + WithGIL& operator=(WithGIL&&) = delete; + + EnvCAPIRegistry* self = nullptr; + void* gil_state = nullptr; + }; }; void EnvCheckSignals() { EnvCAPIRegistry::Global()->CheckSignals(); }