Skip to content

Commit

Permalink
Fix race condition involving wrapper lookup (#865)
Browse files Browse the repository at this point in the history
There's a race condition between wrapper lookup and wrapper deallocation where a Python wrapper may be returned that's in the process of being deallocated. This commit fixes the issue (see #864 for further details).
  • Loading branch information
colesbury authored Jan 20, 2025
1 parent e134487 commit d758a00
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 5 deletions.
66 changes: 62 additions & 4 deletions src/nb_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,60 @@ static PyObject **nb_weaklist_ptr(PyObject *self) {
return weaklistoffset ? (PyObject **) ((uint8_t *) self + weaklistoffset) : nullptr;
}

static void nb_enable_try_inc_ref(PyObject *obj) noexcept {
#if 0 && defined(Py_GIL_DISABLED) && PY_VERSION_HEX >= 0x030E00A5
PyUnstable_EnableTryIncRef(obj);
#elif defined(Py_GIL_DISABLED)
// Since this is called during object construction, we know that we have
// the only reference to the object and can use a non-atomic write.
assert(obj->ob_ref_shared == 0);
obj->ob_ref_shared = _Py_REF_MAYBE_WEAKREF;
#endif
}

static bool nb_try_inc_ref(PyObject *obj) noexcept {
#if 0 && defined(Py_GIL_DISABLED) && PY_VERSION_HEX >= 0x030E00A5
return PyUnstable_TryIncRef(obj);
#elif defined(Py_GIL_DISABLED)
// See https://github.com/python/cpython/blob/d05140f9f77d7dfc753dd1e5ac3a5962aaa03eff/Include/internal/pycore_object.h#L761
uint32_t local = _Py_atomic_load_uint32_relaxed(&obj->ob_ref_local);
local += 1;
if (local == 0) {
// immortal
return true;
}
if (_Py_IsOwnedByCurrentThread(obj)) {
_Py_atomic_store_uint32_relaxed(&obj->ob_ref_local, local);
#ifdef Py_REF_DEBUG
_Py_INCREF_IncRefTotal();
#endif
return true;
}
Py_ssize_t shared = _Py_atomic_load_ssize_relaxed(&obj->ob_ref_shared);
for (;;) {
// If the shared refcount is zero and the object is either merged
// or may not have weak references, then we cannot incref it.
if (shared == 0 || shared == _Py_REF_MERGED) {
return false;
}

if (_Py_atomic_compare_exchange_ssize(
&obj->ob_ref_shared, &shared, shared + (1 << _Py_REF_SHARED_SHIFT))) {
#ifdef Py_REF_DEBUG
_Py_INCREF_IncRefTotal();
#endif
return true;
}
}
#else
if (Py_REFCNT(obj) > 0) {
Py_INCREF(obj);
return true;
}
return false;
#endif
}

static PyGetSetDef inst_getset[] = {
{ "__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict, nullptr, nullptr },
{ nullptr, nullptr, nullptr, nullptr, nullptr }
Expand Down Expand Up @@ -98,6 +152,7 @@ PyObject *inst_new_int(PyTypeObject *tp, PyObject * /* args */,
self->clear_keep_alive = 0;
self->intrusive = intrusive;
self->unused = 0;
nb_enable_try_inc_ref((PyObject *)self);

// Update hash table that maps from C++ to Python instance
nb_shard &shard = internals->shard((void *) payload);
Expand Down Expand Up @@ -163,6 +218,7 @@ PyObject *inst_new_ext(PyTypeObject *tp, void *value) {
self->clear_keep_alive = 0;
self->intrusive = intrusive;
self->unused = 0;
nb_enable_try_inc_ref((PyObject *)self);

nb_shard &shard = internals->shard(value);
lock_shard guard(shard);
Expand Down Expand Up @@ -1766,16 +1822,18 @@ PyObject *nb_type_put(const std::type_info *cpp_type,
PyTypeObject *tp = Py_TYPE(seq.inst);

if (nb_type_data(tp)->type == cpp_type) {
Py_INCREF(seq.inst);
return seq.inst;
if (nb_try_inc_ref(seq.inst)) {
return seq.inst;
}
}

if (!lookup_type())
return nullptr;

if (PyType_IsSubtype(tp, td->type_py)) {
Py_INCREF(seq.inst);
return seq.inst;
if (nb_try_inc_ref(seq.inst)) {
return seq.inst;
}
}

if (seq.next == nullptr)
Expand Down
5 changes: 5 additions & 0 deletions tests/test_thread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ struct Counter {
}
};

struct GlobalData {} global_data;

nb::ft_mutex mutex;

NB_MODULE(test_thread_ext, m) {
Expand All @@ -34,4 +36,7 @@ NB_MODULE(test_thread_ext, m) {
nb::ft_lock_guard guard(mutex);
c.inc();
}, "counter");

nb::class_<GlobalData>(m, "GlobalData")
.def_static("get", [] { return &global_data; }, nb::rv_policy::reference);
}
15 changes: 14 additions & 1 deletion tests/test_thread.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import test_thread_ext as t
from test_thread_ext import Counter
from test_thread_ext import Counter, GlobalData
from common import parallelize

def test01_object_creation(n_threads=8):
Expand Down Expand Up @@ -75,3 +75,16 @@ def f():

parallelize(f, n_threads=n_threads)
assert c.value == n * n_threads


def test_06_global_wrapper(n_threads=8):
# Check wrapper lookup racing with wrapper deallocation
n = 10000
def f():
for i in range(n):
GlobalData.get()
GlobalData.get()
GlobalData.get()
GlobalData.get()

parallelize(f, n_threads=n_threads)

0 comments on commit d758a00

Please sign in to comment.