Skip to content

Commit

Permalink
Explain example_policy; add richer example of callback tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
oremanj committed Nov 5, 2024
1 parent e04a637 commit 83d6606
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 0 deletions.
4 changes: 4 additions & 0 deletions docs/api_core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2002,6 +2002,10 @@ parameter of :cpp:func:`module_::def`, :cpp:func:`class_::def`,
}
};
For a more complex example (binding an object that uses trivially-copyable
callbacks), see ``tests/test_callbacks.cpp`` in the nanobind source
distribution.

.. _class_binding_annotations:

Class binding annotations
Expand Down
2 changes: 2 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ endif()

set(TEST_NAMES
functions
callbacks
classes
holders
stl
Expand Down Expand Up @@ -137,6 +138,7 @@ target_link_libraries(test_inter_module_2_ext PRIVATE inter_module)

set(TEST_FILES
common.py
test_callbacks.py
test_classes.py
test_eigen.py
test_enum.py
Expand Down
125 changes: 125 additions & 0 deletions tests/test_callbacks.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// This is an example of using nb::call_policy to support binding an
// object that takes non-owning callbacks. Since the callbacks can't
// directly keep a Python object alive (they're trivially copyable), we
// maintain a sideband structure to manage the lifetimes.

#include <unordered_set>
#include <vector>

#include <nanobind/nanobind.h>
#include <nanobind/stl/unordered_set.h>

namespace nb = nanobind;

// The callback type accepted by the object, which we assume we can't change.
// It's trivially copyable, so it can't directly keep a Python object alive.
struct callback {
void *context;
void (*func)(void *context, int arg);

void operator()(int arg) const { (*func)(context, arg); }
bool operator==(const callback& other) const {
return context == other.context && func == other.func;
}
};

// An object that uses these callbacks, which we want to write bindings for
class publisher {
public:
void subscribe(callback cb) { cbs.push_back(cb); }
void unsubscribe(callback cb) {
cbs.erase(std::remove(cbs.begin(), cbs.end(), cb), cbs.end());
}
void emit(int arg) const { for (auto cb : cbs) cb(arg); }
private:
std::vector<callback> cbs;
};

template <> struct nanobind::detail::type_caster<callback> {
static void wrap_call(void *context, int arg) {
borrow<callable>((PyObject *) context)(arg);
}
bool from_python(handle src, uint8_t, cleanup_list*) noexcept {
if (!isinstance<callable>(src)) return false;
value = {(void *) src.ptr(), &wrap_call};
return true;
}
static handle from_cpp(callback cb, rv_policy policy, cleanup_list*) noexcept {
if (cb.func == &wrap_call)
return handle((PyObject *) cb.context).inc_ref();
if (policy == rv_policy::none)
return handle();
return cpp_function(cb, policy).release();
}
NB_TYPE_CASTER(callback, const_name("Callable[[int], None]"))
};

nb::dict cb_registry() {
return nb::cast<nb::dict>(
nb::module_::import_("test_callbacks_ext").attr("registry"));
}

struct callback_data {
struct py_hash {
size_t operator()(const nb::object& obj) const { return nb::hash(obj); }
};
struct py_eq {
bool operator()(const nb::object& a, const nb::object& b) const {
return a.equal(b);
}
};
std::unordered_set<nb::object, py_hash, py_eq> subscribers;
};

callback_data& callbacks_for(nb::handle publisher) {
auto registry = cb_registry();
nb::weakref key(publisher, registry.attr("__delitem__"));
if (nb::handle value = PyDict_GetItem(registry.ptr(), key.ptr())) {
return nb::cast<callback_data&>(value);
}
nb::object new_data = nb::cast(callback_data{});
registry[key] = new_data;
return nb::cast<callback_data&>(new_data);
}

// to check at compile time that the subscribe/unsubscribe functions take
// two arguments: self, callback
using TwoArgs = std::integral_constant<size_t, 2>;

struct subscribe_policy {
static void precall(PyObject **, TwoArgs, nb::detail::cleanup_list *) {}
static void postcall(PyObject **args, TwoArgs, nb::handle) {
nb::handle self = args[0], cb = args[1];
callbacks_for(self).subscribers.insert(nb::borrow(cb));
}
};

struct unsubscribe_policy {
static void precall(PyObject **args, TwoArgs, nb::detail::cleanup_list *) {
nb::handle self = args[0], cb = args[1];
auto& cbs = callbacks_for(self);
auto it = cbs.subscribers.find(nb::borrow(cb));
if (it != cbs.subscribers.end() && !it->is(cb)) {
// No callback identical to this one is subscribed. Substitute
// one that is Python-equal.
args[1] = it->ptr();
}
}
static void postcall(PyObject **args, TwoArgs, nb::handle) {
nb::handle self = args[0], cb = args[1];
callbacks_for(self).subscribers.erase(nb::borrow(cb));
}
};

NB_MODULE(test_callbacks_ext, m) {
m.attr("registry") = nb::dict();
nb::class_<callback_data>(m, "callback_data")
.def_ro("subscribers", &callback_data::subscribers);
nb::class_<publisher>(m, "publisher", nb::is_weak_referenceable())
.def(nb::init<>())
.def("subscribe", &publisher::subscribe,
nb::call_policy<subscribe_policy>())
.def("unsubscribe", &publisher::unsubscribe,
nb::call_policy<unsubscribe_policy>())
.def("emit", &publisher::emit);
}
58 changes: 58 additions & 0 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import test_callbacks_ext as t
import gc


def test_callbacks():
pub1 = t.publisher()
pub2 = t.publisher()
record = []

def sub1(x):
record.append(x + 10)

def sub2(x):
record.append(x + 20)

pub1.subscribe(sub1)
pub2.subscribe(sub2)
for pub in (pub1, pub2):
pub.subscribe(record.append)

pub1.emit(1)
assert record == [11, 1]
del record[:]

pub2.emit(2)
assert record == [22, 2]
del record[:]

pub1_w, pub2_w = t.registry.keys() # weakrefs to pub1, pub2
assert pub1_w() is pub1
assert pub2_w() is pub2
assert t.registry[pub1_w].subscribers == {sub1, record.append}
assert t.registry[pub2_w].subscribers == {sub2, record.append}

# NB: this `record.append` is a different object than the one we subscribed
# above, so we're testing the normalization logic in unsubscribe_policy
pub1.unsubscribe(record.append)
assert t.registry[pub1_w].subscribers == {sub1}
pub1.emit(3)
assert record == [13]
del record[:]

del pub, pub1
gc.collect()
gc.collect()
assert pub1_w() is None
assert pub2_w() is pub2
assert t.registry.keys() == {pub2_w}

pub2.emit(4)
assert record == [24, 4]
del record[:]

del pub2
gc.collect()
gc.collect()
assert pub2_w() is None
assert not t.registry
11 changes: 11 additions & 0 deletions tests/test_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@ struct my_call_guard {
~my_call_guard() { call_guard_value = 2; }
};

// Example call policy for use with nb::call_policy<>. Each call will add
// an entry to `calls` containing the arguments tuple and return value.
// The return value will be recorded as "<unfinished>" if the function
// did not return (still executing or threw an exception) and as
// "<return conversion failed>" if the function returned something that we
// couldn't convert to a Python object.
// Additional features to test particular interactions:
// - the precall hook will throw if any arguments are not strings
// - any argument equal to "swapfrom" will be replaced by a temporary
// string object equal to "swapto", which will be destroyed at end of call
// - the postcall hook will throw if any argument equals "postthrow"
struct example_policy {
static inline std::vector<std::pair<nb::tuple, nb::object>> calls;
static void precall(PyObject **args, size_t nargs,
Expand Down

0 comments on commit 83d6606

Please sign in to comment.