diff --git a/src/core/python/drjit_v.cpp b/src/core/python/drjit_v.cpp index 12a82b5bf..5afb5545b 100644 --- a/src/core/python/drjit_v.cpp +++ b/src/core/python/drjit_v.cpp @@ -189,4 +189,3 @@ MI_PY_EXPORT(DrJit) { // Loop type alias m.attr("while_loop") = drjit.attr("while_loop"); } - diff --git a/src/core/tests/test_variants.py b/src/core/tests/test_variants.py new file mode 100644 index 000000000..cd3c34563 --- /dev/null +++ b/src/core/tests/test_variants.py @@ -0,0 +1,54 @@ +import pytest +import mitsuba as mi + + +def test01_variants_callbacks(variants_all_backends_once): + available = mi.variants() + if len(available) <= 1: + pytest.mark.skip("Test requires more than 1 enabled variant") + + history = [] + change_count = 0 + def track_changes(old, new): + history.append((old, new)) + def count_changes(old, new): + nonlocal change_count + change_count += 1 + + mi.detail.add_variant_callback(track_changes) + mi.detail.add_variant_callback(count_changes) + # Adding the same callback multiple times does nothing. + # It won't be called multiple times. + mi.detail.add_variant_callback(track_changes) + mi.detail.add_variant_callback(track_changes) + + try: + previous = mi.variant() + base_i = available.index(previous) + + expected = [] + for i in range(1, len(available) + 1): + next_i = (base_i + i) % len(available) + next_variant = available[next_i] + + assert next_variant != mi.variant() + mi.set_variant(next_variant) + expected.append((previous, next_variant)) + previous = next_variant + + assert len(expected) > 1 + assert len(history) == len(expected) + assert change_count == len(expected) + for e, h in zip(expected, history): + assert h == e + + finally: + # The callback shouldn't stay on even if the test fails. + mi.detail.remove_variant_callback(track_changes) + + # Callback shouldn't be called anymore + len_e = len(expected) + next_variant = available[(available.index(mi.variant()) + 1) % len(available)] + with mi.util.scoped_set_variant(next_variant): + pass + assert len(expected) == len_e diff --git a/src/python/alias.cpp b/src/python/alias.cpp index b2ba1c6fe..f58090291 100644 --- a/src/python/alias.cpp +++ b/src/python/alias.cpp @@ -23,7 +23,7 @@ namespace nb = nanobind; /** - * On Windows the GIL is not held when we load DLLs due to potential deadlocks + * On Windows the GIL is not held when we load DLLs due to potential deadlocks * with the Windows loader-lock. * (see https://github.com/python/cpython/issues/78076 that describes a similar * issue). Here, initialization of static variables is performed during DLL @@ -44,6 +44,10 @@ PyObject *mi_dict = nullptr; /// Current variant (string) nb::object curr_variant; +/// Set of user-provided callback functions to call when switching variants +PyObject *variant_change_callbacks; + + nb::object import_with_deepbind_if_necessary(const char* name) { #if defined(__clang__) && !defined(__APPLE__) nb::int_ backupflags; @@ -112,23 +116,50 @@ static void set_variant(nb::args args) { } if (!curr_variant.equal(new_variant)) { - curr_variant = new_variant; - nb::object curr_variant_module = variant_module(curr_variant); + nb::object new_variant_module = variant_module(new_variant); - nb::dict variant_dict = curr_variant_module.attr("__dict__"); + nb::dict variant_dict = new_variant_module.attr("__dict__"); for (const auto &k : variant_dict.keys()) if (!nb::bool_(k.attr("startswith")("__")) && !nb::bool_(k.attr("endswith")("__"))) Safe_PyDict_SetItem(mi_dict, k.ptr(), PyDict_GetItem(variant_dict.ptr(), k.ptr())); - if (new_variant.attr("startswith")(nb::make_tuple("llvm_", "cuda_"))) { - nb::module_ mi_python = nb::module_::import_("mitsuba.python.ad.integrators"); - nb::steal(PyImport_ReloadModule(mi_python.ptr())); - } + const auto &callbacks = nb::borrow(variant_change_callbacks); + for (const auto &cb : callbacks) + cb(curr_variant, new_variant); + + curr_variant = new_variant; } } +/** + * The given callable will be called each time the Mitsuba variable is changed. + * Note that callbacks are kept in a set: a given callback function will only be + * called once, even if it is added multiple times. + * + * `callback` will be called with the arguments `old_variant: str, new_variant: str`. + */ +static void add_variant_callback(const nb::callable &callback) { + nb::borrow(variant_change_callbacks).add(callback); +} + +/** + * Removes the given `callback` callable from the list of callbacks to be called + * when the Mitsuba variant changes. + */ +static void remove_variant_callback(const nb::callable &callback) { + nb::borrow(variant_change_callbacks).discard(callback); +} + +/** + * Removes all callbacks to be called when the Mitsuba variant changes. + */ +static void clear_variant_callbacks() { + nb::borrow(variant_change_callbacks).clear(); +} + + /// Fallback for when we're attempting to fetch variant-specific attribute static nb::object get_attr(nb::handle key) { if (PyDict_Contains(variant_modules, key.ptr()) == 1) @@ -142,6 +173,7 @@ NB_MODULE(mitsuba_alias, m) { m.attr("__name__") = "mitsuba"; curr_variant = nb::none(); + variant_change_callbacks = PySet_New(nullptr); if (!variant_modules) { variant_modules = PyDict_New(); @@ -178,10 +210,15 @@ NB_MODULE(mitsuba_alias, m) { /// Only used for variant-specific attributes e.g. mi.scalar_rgb.T m.def("__getattr__", [](nb::handle key) { return get_attr(key); }); + // `mitsuba.detail` submodule + nb::module_ mi_detail = m.def_submodule("detail"); + mi_detail.def("add_variant_callback", add_variant_callback); + mi_detail.def("remove_variant_callback", remove_variant_callback); + mi_detail.def("clear_variant_callbacks", clear_variant_callbacks); + /// Fill `__dict__` with all objects in `mitsuba_ext` and `mitsuba.python` mi_dict = m.attr("__dict__").ptr(); nb::object mi_ext = import_with_deepbind_if_necessary("mitsuba.mitsuba_ext"); - nb::object mi_python = nb::module_::import_("mitsuba.python"); nb::dict mitsuba_ext_dict = mi_ext.attr("__dict__"); for (const auto &k : mitsuba_ext_dict.keys()) if (!nb::bool_(k.attr("startswith")("__")) && @@ -189,6 +226,8 @@ NB_MODULE(mitsuba_alias, m) { Safe_PyDict_SetItem(mi_dict, k.ptr(), mitsuba_ext_dict[k].ptr()); } + // Import contents of `mitsuba.python` into top-level `mitsuba` module + nb::object mi_python = nb::module_::import_("mitsuba.python"); nb::dict mitsuba_python_dict = mi_python.attr("__dict__"); for (const auto &k : mitsuba_python_dict.keys()) if (!nb::bool_(k.attr("startswith")("__")) && @@ -204,9 +243,12 @@ NB_MODULE(mitsuba_alias, m) { PyDict_Clear(mi_dict); mi_dict = nullptr; + PySet_Clear(variant_change_callbacks); + variant_change_callbacks = nullptr; + if (variant_modules) { Py_DECREF(variant_modules); - variant_modules = nullptr; + variant_modules = nullptr; } })); } diff --git a/src/python/python/__init__.py b/src/python/python/__init__.py index 636cdbd39..c937a75b7 100644 --- a/src/python/python/__init__.py +++ b/src/python/python/__init__.py @@ -1,4 +1,4 @@ -from .util import traverse, SceneParameters, render, cornell_box, variant_context +from .util import traverse, SceneParameters, render, cornell_box, variant_context, scoped_set_variant from . import chi2 from . import xml from . import ad diff --git a/src/python/python/ad/integrators/__init__.py b/src/python/python/ad/integrators/__init__.py index c83f8184a..01348a594 100644 --- a/src/python/python/ad/integrators/__init__.py +++ b/src/python/python/ad/integrators/__init__.py @@ -1,8 +1,11 @@ # Import/re-import all files in this folder to register AD integrators -import importlib import mitsuba as mi -if mi.variant() is not None and not mi.variant().startswith('scalar'): +def integrators_variants_cb(old, new): + if new is None or new.startswith("scalar"): + return + + import importlib from . import common importlib.reload(common) @@ -21,4 +24,5 @@ from . import prb_projective importlib.reload(prb_projective) -del importlib, mi + +mi.detail.add_variant_callback(integrators_variants_cb) diff --git a/src/python/python/util.py b/src/python/python/util.py index dc32c5e0b..48d513744 100644 --- a/src/python/python/util.py +++ b/src/python/python/util.py @@ -704,3 +704,5 @@ def variant_context(*args) -> None: raise finally: mi.set_variant(old_variant) + +scoped_set_variant = variant_context