Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Variants: allow registering callbacks for variant changes #1367

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/core/python/drjit_v.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,4 +189,3 @@ MI_PY_EXPORT(DrJit) {
// Loop type alias
m.attr("while_loop") = drjit.attr("while_loop");
}

54 changes: 54 additions & 0 deletions src/core/tests/test_variants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest
import mitsuba as mi


def test01_variants_callbacks(variants_all_backends_once):
available = mi.variants()
if len(available) <= 1:
pytest.mark.skip("Test requires more than 1 enabled variant")

history = []
change_count = 0
def track_changes(old, new):
history.append((old, new))
def count_changes(old, new):
nonlocal change_count
change_count += 1

mi.detail.add_variant_callback(track_changes)
mi.detail.add_variant_callback(count_changes)
# Adding the same callback multiple times does nothing.
# It won't be called multiple times.
mi.detail.add_variant_callback(track_changes)
mi.detail.add_variant_callback(track_changes)

try:
previous = mi.variant()
base_i = available.index(previous)

expected = []
for i in range(1, len(available) + 1):
next_i = (base_i + i) % len(available)
next_variant = available[next_i]

assert next_variant != mi.variant()
mi.set_variant(next_variant)
expected.append((previous, next_variant))
previous = next_variant

assert len(expected) > 1
assert len(history) == len(expected)
assert change_count == len(expected)
for e, h in zip(expected, history):
assert h == e

finally:
# The callback shouldn't stay on even if the test fails.
mi.detail.remove_variant_callback(track_changes)

# Callback shouldn't be called anymore
len_e = len(expected)
next_variant = available[(available.index(mi.variant()) + 1) % len(available)]
with mi.util.scoped_set_variant(next_variant):
pass
assert len(expected) == len_e
62 changes: 52 additions & 10 deletions src/python/alias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
namespace nb = nanobind;

/**
* On Windows the GIL is not held when we load DLLs due to potential deadlocks
* On Windows the GIL is not held when we load DLLs due to potential deadlocks
* with the Windows loader-lock.
* (see https://github.com/python/cpython/issues/78076 that describes a similar
* issue). Here, initialization of static variables is performed during DLL
Expand All @@ -44,6 +44,10 @@ PyObject *mi_dict = nullptr;
/// Current variant (string)
nb::object curr_variant;

/// Set of user-provided callback functions to call when switching variants
PyObject *variant_change_callbacks;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this not an std::set or at least an nb::object?



nb::object import_with_deepbind_if_necessary(const char* name) {
#if defined(__clang__) && !defined(__APPLE__)
nb::int_ backupflags;
Expand Down Expand Up @@ -112,23 +116,50 @@ static void set_variant(nb::args args) {
}

if (!curr_variant.equal(new_variant)) {
curr_variant = new_variant;
nb::object curr_variant_module = variant_module(curr_variant);
nb::object new_variant_module = variant_module(new_variant);

nb::dict variant_dict = curr_variant_module.attr("__dict__");
nb::dict variant_dict = new_variant_module.attr("__dict__");
for (const auto &k : variant_dict.keys())
if (!nb::bool_(k.attr("startswith")("__")) &&
!nb::bool_(k.attr("endswith")("__")))
Safe_PyDict_SetItem(mi_dict, k.ptr(),
PyDict_GetItem(variant_dict.ptr(), k.ptr()));

if (new_variant.attr("startswith")(nb::make_tuple("llvm_", "cuda_"))) {
nb::module_ mi_python = nb::module_::import_("mitsuba.python.ad.integrators");
nb::steal(PyImport_ReloadModule(mi_python.ptr()));
}
const auto &callbacks = nb::borrow<nb::set>(variant_change_callbacks);
for (const auto &cb : callbacks)
cb(curr_variant, new_variant);

curr_variant = new_variant;
}
}

/**
* The given callable will be called each time the Mitsuba variable is changed.
* Note that callbacks are kept in a set: a given callback function will only be
* called once, even if it is added multiple times.
*
* `callback` will be called with the arguments `old_variant: str, new_variant: str`.
*/
static void add_variant_callback(const nb::callable &callback) {
nb::borrow<nb::set>(variant_change_callbacks).add(callback);
}

/**
* Removes the given `callback` callable from the list of callbacks to be called
* when the Mitsuba variant changes.
*/
static void remove_variant_callback(const nb::callable &callback) {
nb::borrow<nb::set>(variant_change_callbacks).discard(callback);
}

/**
* Removes all callbacks to be called when the Mitsuba variant changes.
*/
static void clear_variant_callbacks() {
nb::borrow<nb::set>(variant_change_callbacks).clear();
}


/// Fallback for when we're attempting to fetch variant-specific attribute
static nb::object get_attr(nb::handle key) {
if (PyDict_Contains(variant_modules, key.ptr()) == 1)
Expand All @@ -142,6 +173,7 @@ NB_MODULE(mitsuba_alias, m) {
m.attr("__name__") = "mitsuba";

curr_variant = nb::none();
variant_change_callbacks = PySet_New(nullptr);

if (!variant_modules) {
variant_modules = PyDict_New();
Expand Down Expand Up @@ -178,17 +210,24 @@ NB_MODULE(mitsuba_alias, m) {
/// Only used for variant-specific attributes e.g. mi.scalar_rgb.T
m.def("__getattr__", [](nb::handle key) { return get_attr(key); });

// `mitsuba.detail` submodule
nb::module_ mi_detail = m.def_submodule("detail");
wjakob marked this conversation as resolved.
Show resolved Hide resolved
mi_detail.def("add_variant_callback", add_variant_callback);
mi_detail.def("remove_variant_callback", remove_variant_callback);
mi_detail.def("clear_variant_callbacks", clear_variant_callbacks);

/// Fill `__dict__` with all objects in `mitsuba_ext` and `mitsuba.python`
mi_dict = m.attr("__dict__").ptr();
nb::object mi_ext = import_with_deepbind_if_necessary("mitsuba.mitsuba_ext");
nb::object mi_python = nb::module_::import_("mitsuba.python");
nb::dict mitsuba_ext_dict = mi_ext.attr("__dict__");
for (const auto &k : mitsuba_ext_dict.keys())
if (!nb::bool_(k.attr("startswith")("__")) &&
!nb::bool_(k.attr("endswith")("__"))) {
Safe_PyDict_SetItem(mi_dict, k.ptr(), mitsuba_ext_dict[k].ptr());
}

// Import contents of `mitsuba.python` into top-level `mitsuba` module
nb::object mi_python = nb::module_::import_("mitsuba.python");
nb::dict mitsuba_python_dict = mi_python.attr("__dict__");
for (const auto &k : mitsuba_python_dict.keys())
if (!nb::bool_(k.attr("startswith")("__")) &&
Expand All @@ -204,9 +243,12 @@ NB_MODULE(mitsuba_alias, m) {
PyDict_Clear(mi_dict);
mi_dict = nullptr;

PySet_Clear(variant_change_callbacks);
variant_change_callbacks = nullptr;

if (variant_modules) {
Py_DECREF(variant_modules);
variant_modules = nullptr;
variant_modules = nullptr;
}
}));
}
2 changes: 1 addition & 1 deletion src/python/python/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .util import traverse, SceneParameters, render, cornell_box, variant_context
from .util import traverse, SceneParameters, render, cornell_box, variant_context, scoped_set_variant
from . import chi2
from . import xml
from . import ad
Expand Down
10 changes: 7 additions & 3 deletions src/python/python/ad/integrators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# Import/re-import all files in this folder to register AD integrators
import importlib
import mitsuba as mi

if mi.variant() is not None and not mi.variant().startswith('scalar'):
def integrators_variants_cb(old, new):
if new is None or new.startswith("scalar"):
return

import importlib
from . import common
importlib.reload(common)

Expand All @@ -21,4 +24,5 @@
from . import prb_projective
importlib.reload(prb_projective)

del importlib, mi

mi.detail.add_variant_callback(integrators_variants_cb)
2 changes: 2 additions & 0 deletions src/python/python/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,3 +704,5 @@ def variant_context(*args) -> None:
raise
finally:
mi.set_variant(old_variant)

scoped_set_variant = variant_context