diff --git a/c_src/pythonx/python.cpp b/c_src/pythonx/python.cpp index 9a65c3c..70be734 100644 --- a/c_src/pythonx/python.cpp +++ b/c_src/pythonx/python.cpp @@ -31,13 +31,13 @@ DEF_SYMBOL(PyErr_Fetch) DEF_SYMBOL(PyErr_Occurred) DEF_SYMBOL(PyEval_GetBuiltins) DEF_SYMBOL(PyEval_EvalCode) +DEF_SYMBOL(PyEval_RestoreThread) DEF_SYMBOL(PyEval_SaveThread) DEF_SYMBOL(PyFloat_AsDouble) DEF_SYMBOL(PyFloat_FromDouble) -DEF_SYMBOL(PyGILState_Ensure) -DEF_SYMBOL(PyGILState_Release) DEF_SYMBOL(PyImport_AddModule) DEF_SYMBOL(PyImport_ImportModule) +DEF_SYMBOL(PyInterpreterState_Get) DEF_SYMBOL(PyIter_Next) DEF_SYMBOL(PyList_Append) DEF_SYMBOL(PyList_GetItem) @@ -60,6 +60,7 @@ DEF_SYMBOL(PyObject_Str) DEF_SYMBOL(PySet_Add) DEF_SYMBOL(PySet_New) DEF_SYMBOL(PySet_Size) +DEF_SYMBOL(PyThreadState_New) DEF_SYMBOL(PyTuple_GetItem) DEF_SYMBOL(PyTuple_New) DEF_SYMBOL(PyTuple_Pack) @@ -70,7 +71,6 @@ DEF_SYMBOL(PyUnicode_FromStringAndSize) DEF_SYMBOL(Py_BuildValue) DEF_SYMBOL(Py_CompileString) DEF_SYMBOL(Py_DecRef) -DEF_SYMBOL(Py_FinalizeEx) DEF_SYMBOL(Py_IncRef) DEF_SYMBOL(Py_InitializeEx) DEF_SYMBOL(Py_IsFalse) @@ -105,13 +105,13 @@ void load_python_library(std::string path) { LOAD_SYMBOL(python_library, PyErr_Occurred) LOAD_SYMBOL(python_library, PyEval_GetBuiltins) LOAD_SYMBOL(python_library, PyEval_EvalCode) + LOAD_SYMBOL(python_library, PyEval_RestoreThread) LOAD_SYMBOL(python_library, PyEval_SaveThread) LOAD_SYMBOL(python_library, PyFloat_AsDouble) LOAD_SYMBOL(python_library, PyFloat_FromDouble) - LOAD_SYMBOL(python_library, PyGILState_Ensure) - LOAD_SYMBOL(python_library, PyGILState_Release) LOAD_SYMBOL(python_library, PyImport_AddModule) LOAD_SYMBOL(python_library, PyImport_ImportModule) + LOAD_SYMBOL(python_library, PyInterpreterState_Get) LOAD_SYMBOL(python_library, PyIter_Next) LOAD_SYMBOL(python_library, PyList_Append) LOAD_SYMBOL(python_library, PyList_GetItem) @@ -134,6 +134,7 @@ void load_python_library(std::string path) { LOAD_SYMBOL(python_library, PySet_Add) LOAD_SYMBOL(python_library, PySet_New) LOAD_SYMBOL(python_library, PySet_Size) + LOAD_SYMBOL(python_library, PyThreadState_New) LOAD_SYMBOL(python_library, PyTuple_GetItem) LOAD_SYMBOL(python_library, PyTuple_New) LOAD_SYMBOL(python_library, PyTuple_Pack) @@ -144,7 +145,6 @@ void load_python_library(std::string path) { LOAD_SYMBOL(python_library, Py_BuildValue) LOAD_SYMBOL(python_library, Py_CompileString) LOAD_SYMBOL(python_library, Py_DecRef) - LOAD_SYMBOL(python_library, Py_FinalizeEx) LOAD_SYMBOL(python_library, Py_IncRef) LOAD_SYMBOL(python_library, Py_InitializeEx) LOAD_SYMBOL(python_library, Py_IsFalse) diff --git a/c_src/pythonx/python.hpp b/c_src/pythonx/python.hpp index 20d1414..3bbbdca 100644 --- a/c_src/pythonx/python.hpp +++ b/c_src/pythonx/python.hpp @@ -62,9 +62,9 @@ namespace pythonx::python { // Opaque types +using PyInterpreterStatePtr = void *; using PyObjectPtr = void *; using PyThreadStatePtr = void *; -using PyGILState_STATE = unsigned char; using Py_ssize_t = ssize_t; // Functions @@ -85,13 +85,13 @@ extern void (*PyErr_Fetch)(PyObjectPtr *, PyObjectPtr *, PyObjectPtr *); extern PyObjectPtr (*PyErr_Occurred)(); extern PyObjectPtr (*PyEval_GetBuiltins)(); extern PyObjectPtr (*PyEval_EvalCode)(PyObjectPtr, PyObjectPtr, PyObjectPtr); +extern void (*PyEval_RestoreThread)(PyThreadStatePtr); extern PyThreadStatePtr (*PyEval_SaveThread)(); extern double (*PyFloat_AsDouble)(PyObjectPtr); extern PyObjectPtr (*PyFloat_FromDouble)(double); -extern PyGILState_STATE (*PyGILState_Ensure)(); -extern void (*PyGILState_Release)(PyGILState_STATE); extern PyObjectPtr (*PyImport_AddModule)(const char *); extern PyObjectPtr (*PyImport_ImportModule)(const char *); +extern PyInterpreterStatePtr (*PyInterpreterState_Get)(); extern PyObjectPtr (*PyIter_Next)(PyObjectPtr); extern int (*PyList_Append)(PyObjectPtr, PyObjectPtr); extern PyObjectPtr (*PyList_GetItem)(PyObjectPtr, Py_ssize_t); @@ -114,6 +114,7 @@ extern PyObjectPtr (*PyObject_Str)(PyObjectPtr); extern int (*PySet_Add)(PyObjectPtr, PyObjectPtr); extern PyObjectPtr (*PySet_New)(PyObjectPtr); extern Py_ssize_t (*PySet_Size)(PyObjectPtr); +extern PyThreadStatePtr (*PyThreadState_New)(PyInterpreterStatePtr); extern PyObjectPtr (*PyTuple_GetItem)(PyObjectPtr, Py_ssize_t); extern PyObjectPtr (*PyTuple_New)(Py_ssize_t); extern PyObjectPtr (*PyTuple_Pack)(Py_ssize_t, ...); @@ -124,7 +125,6 @@ extern PyObjectPtr (*PyUnicode_FromStringAndSize)(const char *, Py_ssize_t); extern PyObjectPtr (*Py_BuildValue)(const char *, ...); extern PyObjectPtr (*Py_CompileString)(const char *, const char *, int); extern void (*Py_DecRef)(PyObjectPtr); -extern int (*Py_FinalizeEx)(); extern void (*Py_IncRef)(PyObjectPtr); extern void (*Py_InitializeEx)(int); extern int (*Py_IsFalse)(PyObjectPtr); diff --git a/c_src/pythonx/pythonx.cpp b/c_src/pythonx/pythonx.cpp index 8f79b65..2049247 100644 --- a/c_src/pythonx/pythonx.cpp +++ b/c_src/pythonx/pythonx.cpp @@ -26,6 +26,9 @@ std::wstring python_home_path_w; std::wstring python_executable_path_w; std::map> compilation_cache; std::mutex compilation_cache_mutex; +PyInterpreterStatePtr interpreter_state; +std::map thread_states; +std::mutex thread_states_mutex; // Wrapper around the Python Global Interpreter Lock (GIL). // @@ -36,12 +39,63 @@ std::mutex compilation_cache_mutex; // // [1]: https://en.wikipedia.org/wiki/Resource_acquisition_is_initialization class PyGILGuard { - PyGILState_STATE state; + // The simplest way to implement this guard is to use `PyGILState_Ensure` + // and `PyGILState_Release`, however this can lead to segfaults when + // using libraries depending on pybind11. + // + // pybind11 is a popular library for writing C extensions in Python + // packages. It provies convenient C++ API on top of the Python C + // API. In particular, it provides conveniences for dealing with + // GIL, one of them being `gil_scoped_acquire`. The implementation + // has a bug that results in a dangling pointer being used. This + // bug only appears when the code runs in a non-main thread that + // manages the `gil_scoped_acquire` checks if the calling thread + // alread holds GIL with `PyGILState_Ensure` and `PyGILState_Release`. + // Specifically, the GIL, in which case it stores the pointer to + // the corresponding `PyThreadState`. After `PyGILState_Release`, + // the thread state is freed, but subsequent usage of `gil_scoped_acquire` + // still re-uses the pointer. This issues has been reported in [1]. + // + // In our case, we evaluate Python code dirty scheduler threads. + // This means that the threads are reused and we acquire the GIL + // every time. In order to avoid the pybind11 bug, we want to avoid + // using `PyGILState_Release`, and instead have a permanent `PyThreadState` + // for each of the dirty scheduler threads. We do this by creating + // new state when the given scheduler thread obtains the GIL for + // the first time. Then, we use `PyEval_RestoreThread` and `PyEval_SaveThread` + // to acquire and release the GIL respectively. + // + // NOTE: the dirty scheduler thread pool is fixed, so the map does + // not grow beyond that. If we ever need to acquire the GIL from + // other threads, we should extend this implementation to either + // allow removing the state on destruction, or have a variant with + // `PyGILState_Ensure` and `PyGILState_Release`, as long as it does + // not fall into the bug described above. + // + // [1]: https://github.com/pybind/pybind11/issues/2888 public: - PyGILGuard() { this->state = PyGILState_Ensure(); } + PyGILGuard() { + auto thread_id = std::this_thread::get_id(); + + PyThreadStatePtr state; + + { + auto guard = std::lock_guard(thread_states_mutex); + + if (thread_states.find(thread_id) == thread_states.end()) { + // Note that PyThreadState_New does not require GIL to be held. + state = PyThreadState_New(interpreter_state); + thread_states[thread_id] = state; + } else { + state = thread_states[thread_id]; + } + } + + PyEval_RestoreThread(state); + } - ~PyGILGuard() { PyGILState_Release(this->state); } + ~PyGILGuard() { PyEval_SaveThread(); } }; // Ensures the given object refcount is decremented when the guard @@ -275,6 +329,8 @@ fine::Ok<> init(ErlNifEnv *env, std::string python_dl_path, Py_InitializeEx(0); + interpreter_state = PyInterpreterState_Get(); + // In order to use any of the Python C API functions, the calling // thread must hold the GIL. Since every NIF call may run on a // different dirty scheduler thread, we need to acquire the GIL at @@ -285,7 +341,7 @@ fine::Ok<> init(ErlNifEnv *env, std::string python_dl_path, // See pyo3 [1] for an extra reference. // // [1]: https://github.com/PyO3/pyo3/blob/v0.23.3/src/gil.rs#L63-L74 - PyEval_SaveThread(); + thread_states[std::this_thread::get_id()] = PyEval_SaveThread(); is_initialized = true; @@ -405,40 +461,6 @@ sys.stdin = Stdin() FINE_NIF(init, ERL_NIF_DIRTY_JOB_CPU_BOUND); -// Note that this NIF is here for the reference, but currently we do -// not support deinitialization. While in principle it should be -// possible to reinitialize Python, it can lead to issues in practice. -// For example, doing so while using numpy simply does not work, see -// [1] for discussion points. -// -// [1]: https://bugs.python.org/issue34309 -fine::Ok<> terminate(ErlNifEnv *env) { - ensure_initialized(); - - auto init_guard = std::lock_guard(init_mutex); - - // Here we only acquire the GIL, since releasing after finalization - // makes no sense - PyGILState_Ensure(); - - if (Py_FinalizeEx() == -1) { - throw std::runtime_error("failed to finalize Python interpreter"); - } - - is_initialized = false; - - auto compilation_cache_guard = - std::lock_guard(compilation_cache_mutex); - compilation_cache.clear(); - - // Raises runtime error on failure, which is propagated automatically - unload_python_library(); - - return fine::Ok<>(); -} - -FINE_NIF(terminate, ERL_NIF_DIRTY_JOB_CPU_BOUND); - fine::Ok<> janitor_decref(ErlNifEnv *env, uint64_t ptr) { auto init_guard = std::lock_guard(init_mutex); diff --git a/lib/pythonx/nif.ex b/lib/pythonx/nif.ex index 774f314..736d686 100644 --- a/lib/pythonx/nif.ex +++ b/lib/pythonx/nif.ex @@ -13,7 +13,6 @@ defmodule Pythonx.NIF do end def init(_python_dl_path, _python_home_path, _python_executable_path, _sys_paths), do: err!() - def terminate(), do: err!() def janitor_decref(_ptr), do: err!() def none_new(), do: err!() def false_new(), do: err!()