-
Notifications
You must be signed in to change notification settings - Fork 207
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Explain example_policy; add richer example of callback tracking
- Loading branch information
Showing
5 changed files
with
205 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
// This is an example of using nb::call_policy to support binding an | ||
// object that takes non-owning callbacks. Since the callbacks can't | ||
// directly keep a Python object alive (they're trivially copyable), we | ||
// maintain a sideband structure to manage the lifetimes. | ||
|
||
#include <unordered_set> | ||
|
||
#include <nanobind/nanobind.h> | ||
#include <nanobind/stl/unordered_set.h> | ||
|
||
namespace nb = nanobind; | ||
|
||
// The callback type accepted by the object, which we assume we can't change. | ||
// It's trivially copyable, so it can't directly keep a Python object alive. | ||
struct callback { | ||
void *context; | ||
void (*func)(void *context, int arg); | ||
|
||
void operator()(int arg) const { (*func)(context, arg); } | ||
bool operator==(const callback& other) const { | ||
return context == other.context && func == other.func; | ||
} | ||
}; | ||
template <> struct std::hash<callback> { | ||
size_t operator()(const callback& cb) const { | ||
return std::hash<void*>()(cb.context) ^ | ||
std::hash<void(*)(void*, int)>()(cb.func); | ||
} | ||
}; | ||
|
||
// An object that uses these callbacks, which we want to write bindings for | ||
class publisher { | ||
public: | ||
void subscribe(callback cb) { cbs.push_back(cb); } | ||
void unsubscribe(callback cb) { | ||
cbs.erase(std::remove(cbs.begin(), cbs.end(), cb), cbs.end()); | ||
} | ||
void emit(int arg) const { for (auto cb : cbs) cb(arg); } | ||
private: | ||
std::vector<callback> cbs; | ||
}; | ||
|
||
template <> struct nanobind::detail::type_caster<callback> { | ||
static void wrap_call(void *context, int arg) { | ||
borrow<callable>((PyObject *) context)(arg); | ||
} | ||
bool from_python(handle src, uint8_t, cleanup_list*) noexcept { | ||
if (!isinstance<callable>(src)) return false; | ||
value = {(void *) src.ptr(), &wrap_call}; | ||
return true; | ||
} | ||
static handle from_cpp(callback cb, rv_policy policy, cleanup_list*) noexcept { | ||
if (cb.func == &wrap_call) | ||
return handle((PyObject *) cb.context).inc_ref(); | ||
if (policy == rv_policy::none) | ||
return handle(); | ||
return cpp_function(cb, policy).release(); | ||
} | ||
NB_TYPE_CASTER(callback, const_name("Callable[[int], None]")) | ||
}; | ||
|
||
nb::dict cb_registry() { | ||
return nb::cast<nb::dict>( | ||
nb::module_::import_("test_callbacks_ext").attr("registry")); | ||
} | ||
|
||
struct callback_data { | ||
struct py_hash { | ||
size_t operator()(const nb::object& obj) const { return nb::hash(obj); } | ||
}; | ||
struct py_eq { | ||
bool operator()(const nb::object& a, const nb::object& b) const { | ||
return a.equal(b); | ||
} | ||
}; | ||
std::unordered_set<nb::object, py_hash, py_eq> subscribers; | ||
}; | ||
|
||
callback_data& callbacks_for(nb::handle publisher) { | ||
auto registry = cb_registry(); | ||
nb::weakref key(publisher, registry.attr("__delitem__")); | ||
if (nb::handle value = PyDict_GetItem(registry.ptr(), key.ptr())) { | ||
return nb::cast<callback_data&>(value); | ||
} | ||
nb::object new_data = nb::cast(callback_data{}); | ||
registry[key] = new_data; | ||
return nb::cast<callback_data&>(new_data); | ||
} | ||
|
||
// to check at compile time that the subscribe/unsubscribe functions take | ||
// two arguments: self, callback | ||
using TwoArgs = std::integral_constant<size_t, 2>; | ||
|
||
struct subscribe_policy { | ||
static void precall(PyObject **, TwoArgs, nb::detail::cleanup_list *) {} | ||
static void postcall(PyObject **args, TwoArgs, nb::handle) { | ||
nb::handle self = args[0], cb = args[1]; | ||
callbacks_for(self).subscribers.insert(nb::borrow(cb)); | ||
} | ||
}; | ||
|
||
struct unsubscribe_policy { | ||
static void precall(PyObject **args, TwoArgs, nb::detail::cleanup_list *) { | ||
nb::handle self = args[0], cb = args[1]; | ||
auto& cbs = callbacks_for(self); | ||
auto it = cbs.subscribers.find(nb::borrow(cb)); | ||
if (it != cbs.subscribers.end() && !it->is(cb)) { | ||
// No callback identical to this one is subscribed. Substitute | ||
// one that is Python-equal. | ||
args[1] = it->ptr(); | ||
} | ||
} | ||
static void postcall(PyObject **args, TwoArgs, nb::handle) { | ||
nb::handle self = args[0], cb = args[1]; | ||
callbacks_for(self).subscribers.erase(nb::borrow(cb)); | ||
} | ||
}; | ||
|
||
NB_MODULE(test_callbacks_ext, m) { | ||
m.attr("registry") = nb::dict(); | ||
nb::class_<callback_data>(m, "callback_data") | ||
.def_ro("subscribers", &callback_data::subscribers); | ||
nb::class_<publisher>(m, "publisher", nb::is_weak_referenceable()) | ||
.def(nb::init<>()) | ||
.def("subscribe", &publisher::subscribe, | ||
nb::call_policy<subscribe_policy>()) | ||
.def("unsubscribe", &publisher::unsubscribe, | ||
nb::call_policy<unsubscribe_policy>()) | ||
.def("emit", &publisher::emit); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import test_callbacks_ext as t | ||
import gc | ||
|
||
|
||
def test_callbacks(): | ||
pub1 = t.publisher() | ||
pub2 = t.publisher() | ||
record = [] | ||
|
||
def sub1(x): | ||
record.append(x + 10) | ||
|
||
def sub2(x): | ||
record.append(x + 20) | ||
|
||
pub1.subscribe(sub1) | ||
pub2.subscribe(sub2) | ||
for pub in (pub1, pub2): | ||
pub.subscribe(record.append) | ||
|
||
pub1.emit(1) | ||
assert record == [11, 1] | ||
del record[:] | ||
|
||
pub2.emit(2) | ||
assert record == [22, 2] | ||
del record[:] | ||
|
||
pub1_w, pub2_w = t.registry.keys() # weakrefs to pub1, pub2 | ||
assert pub1_w() is pub1 | ||
assert pub2_w() is pub2 | ||
assert t.registry[pub1_w].subscribers == {sub1, record.append} | ||
assert t.registry[pub2_w].subscribers == {sub2, record.append} | ||
|
||
# NB: this `record.append` is a different object than the one we subscribed | ||
# above, so we're testing the normalization logic in unsubscribe_policy | ||
pub1.unsubscribe(record.append) | ||
assert t.registry[pub1_w].subscribers == {sub1} | ||
pub1.emit(3) | ||
assert record == [13] | ||
del record[:] | ||
|
||
del pub, pub1 | ||
gc.collect() | ||
gc.collect() | ||
assert pub1_w() is None | ||
assert pub2_w() is pub2 | ||
assert t.registry.keys() == {pub2_w} | ||
|
||
pub2.emit(4) | ||
assert record == [24, 4] | ||
del record[:] | ||
|
||
del pub2 | ||
gc.collect() | ||
gc.collect() | ||
assert pub2_w() is None | ||
assert not t.registry |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters