From 26036b430de56cc59de55791bb2946d3894b9eb9 Mon Sep 17 00:00:00 2001 From: Merlin Nimier-David Date: Mon, 28 Oct 2024 18:03:12 +0100 Subject: [PATCH 1/4] Variants: allow registering callbacks for variant changes + unit test --- src/core/python/drjit_v.cpp | 1 - src/core/tests/test_variants.py | 54 +++++++++++++++++++++++++++++++++ src/python/alias.cpp | 54 ++++++++++++++++++++++++++++++--- 3 files changed, 103 insertions(+), 6 deletions(-) create mode 100644 src/core/tests/test_variants.py diff --git a/src/core/python/drjit_v.cpp b/src/core/python/drjit_v.cpp index 12a82b5bf..5afb5545b 100644 --- a/src/core/python/drjit_v.cpp +++ b/src/core/python/drjit_v.cpp @@ -189,4 +189,3 @@ MI_PY_EXPORT(DrJit) { // Loop type alias m.attr("while_loop") = drjit.attr("while_loop"); } - diff --git a/src/core/tests/test_variants.py b/src/core/tests/test_variants.py new file mode 100644 index 000000000..b96df8bdf --- /dev/null +++ b/src/core/tests/test_variants.py @@ -0,0 +1,54 @@ +import pytest +import mitsuba as mi + + +def test01_variants_callbacks(variants_all_backends_once): + available = mi.variants() + if len(available) <= 1: + pytest.mark.skip("Test requires more than 1 enabled variant") + + history = [] + change_count = 0 + def track_changes(old, new): + history.append((old, new)) + def count_changes(old, new): + nonlocal change_count + change_count += 1 + + mi.add_set_variant_callback(track_changes) + mi.add_set_variant_callback(count_changes) + # Adding the same callback multiple times does nothing. + # It won't be called multiple times. + mi.add_set_variant_callback(track_changes) + mi.add_set_variant_callback(track_changes) + + try: + previous = mi.variant() + base_i = available.index(previous) + + expected = [] + for i in range(1, len(available) + 1): + next_i = (base_i + i) % len(available) + next_variant = available[next_i] + + assert next_variant != mi.variant() + mi.set_variant(next_variant) + expected.append((previous, next_variant)) + previous = next_variant + + assert len(expected) > 1 + assert len(history) == len(expected) + assert change_count == len(expected) + for e, h in zip(expected, history): + assert h == e + + finally: + # The callback shouldn't survive + mi.remove_set_variant_callback(track_changes) + + # Callback shouldn't be called anymore + len_e = len(expected) + next_variant = available[(available.index(mi.variant()) + 1) % len(available)] + with mi.util.scoped_set_variant(next_variant): + pass + assert len(expected) == len_e diff --git a/src/python/alias.cpp b/src/python/alias.cpp index b2ba1c6fe..81ad1d10d 100644 --- a/src/python/alias.cpp +++ b/src/python/alias.cpp @@ -23,7 +23,7 @@ namespace nb = nanobind; /** - * On Windows the GIL is not held when we load DLLs due to potential deadlocks + * On Windows the GIL is not held when we load DLLs due to potential deadlocks * with the Windows loader-lock. * (see https://github.com/python/cpython/issues/78076 that describes a similar * issue). Here, initialization of static variables is performed during DLL @@ -44,6 +44,10 @@ PyObject *mi_dict = nullptr; /// Current variant (string) nb::object curr_variant; +/// Set of user-provided callback functions to call when switching variants +PyObject *variant_change_callbacks; + + nb::object import_with_deepbind_if_necessary(const char* name) { #if defined(__clang__) && !defined(__APPLE__) nb::int_ backupflags; @@ -112,23 +116,56 @@ static void set_variant(nb::args args) { } if (!curr_variant.equal(new_variant)) { - curr_variant = new_variant; - nb::object curr_variant_module = variant_module(curr_variant); + nb::object new_variant_module = variant_module(new_variant); - nb::dict variant_dict = curr_variant_module.attr("__dict__"); + nb::dict variant_dict = new_variant_module.attr("__dict__"); for (const auto &k : variant_dict.keys()) if (!nb::bool_(k.attr("startswith")("__")) && !nb::bool_(k.attr("endswith")("__"))) Safe_PyDict_SetItem(mi_dict, k.ptr(), PyDict_GetItem(variant_dict.ptr(), k.ptr())); + const auto &callbacks = nb::borrow(variant_change_callbacks); + for (const auto &cb : callbacks) + cb(curr_variant, new_variant); + + // TODO: replace this with a callback? if (new_variant.attr("startswith")(nb::make_tuple("llvm_", "cuda_"))) { nb::module_ mi_python = nb::module_::import_("mitsuba.python.ad.integrators"); nb::steal(PyImport_ReloadModule(mi_python.ptr())); } + + curr_variant = new_variant; } } +/** + * The given callable will be called each time the Mitsuba variable is changed. + * Note that callbacks are kept in a set: a given callback function will only be + * called once, even if it is added multiple times. + * + * `callback` will be called with the arguments `old_variant: str, new_variant: str`. + */ +static void add_set_variant_callback(const nb::callable &callback) { + nb::borrow(variant_change_callbacks).add(callback); +} + +/** + * Removes the given `callback` callable from the list of callbacks to be called + * when the Mitsuba variant changes. + */ +static void remove_set_variant_callback(const nb::callable &callback) { + nb::borrow(variant_change_callbacks).discard(callback); +} + +/** + * Removes all callbacks to be called when the Mitsuba variant changes. + */ +static void clear_set_variant_callback() { + nb::borrow(variant_change_callbacks).clear(); +} + + /// Fallback for when we're attempting to fetch variant-specific attribute static nb::object get_attr(nb::handle key) { if (PyDict_Contains(variant_modules, key.ptr()) == 1) @@ -142,6 +179,7 @@ NB_MODULE(mitsuba_alias, m) { m.attr("__name__") = "mitsuba"; curr_variant = nb::none(); + variant_change_callbacks = PySet_New(nullptr); if (!variant_modules) { variant_modules = PyDict_New(); @@ -175,6 +213,9 @@ NB_MODULE(mitsuba_alias, m) { m.def("variant", []() { return curr_variant ? curr_variant : nb::none(); }); m.def("variants", []() { return nb::steal(PyDict_Keys(variant_modules)); }); m.def("set_variant", set_variant); + m.def("add_set_variant_callback", add_set_variant_callback); + m.def("remove_set_variant_callback", remove_set_variant_callback); + m.def("clear_set_variant_callback", clear_set_variant_callback); /// Only used for variant-specific attributes e.g. mi.scalar_rgb.T m.def("__getattr__", [](nb::handle key) { return get_attr(key); }); @@ -204,9 +245,12 @@ NB_MODULE(mitsuba_alias, m) { PyDict_Clear(mi_dict); mi_dict = nullptr; + PySet_Clear(variant_change_callbacks); + variant_change_callbacks = nullptr; + if (variant_modules) { Py_DECREF(variant_modules); - variant_modules = nullptr; + variant_modules = nullptr; } })); } From 6ca2661d27a8b1aaeb2a625700fe62c2a333a5a8 Mon Sep 17 00:00:00 2001 From: Merlin Nimier-David Date: Mon, 28 Oct 2024 18:03:40 +0100 Subject: [PATCH 2/4] variant_context: add alias `scoped_set_variant` Consistent with e.g. `dr.scoped_set_flag()` --- src/python/python/util.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/python/python/util.py b/src/python/python/util.py index dc32c5e0b..48d513744 100644 --- a/src/python/python/util.py +++ b/src/python/python/util.py @@ -704,3 +704,5 @@ def variant_context(*args) -> None: raise finally: mi.set_variant(old_variant) + +scoped_set_variant = variant_context From 5140706ae6af2735a046bbc84e3de0464c523ecd Mon Sep 17 00:00:00 2001 From: Merlin Nimier-David Date: Mon, 28 Oct 2024 19:55:28 +0100 Subject: [PATCH 3/4] Python integrators: register using the callback mechanism --- src/python/alias.cpp | 6 ------ src/python/python/ad/integrators/__init__.py | 10 +++++++--- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/python/alias.cpp b/src/python/alias.cpp index 81ad1d10d..b4ef8bc80 100644 --- a/src/python/alias.cpp +++ b/src/python/alias.cpp @@ -129,12 +129,6 @@ static void set_variant(nb::args args) { for (const auto &cb : callbacks) cb(curr_variant, new_variant); - // TODO: replace this with a callback? - if (new_variant.attr("startswith")(nb::make_tuple("llvm_", "cuda_"))) { - nb::module_ mi_python = nb::module_::import_("mitsuba.python.ad.integrators"); - nb::steal(PyImport_ReloadModule(mi_python.ptr())); - } - curr_variant = new_variant; } } diff --git a/src/python/python/ad/integrators/__init__.py b/src/python/python/ad/integrators/__init__.py index c83f8184a..3525c0ddd 100644 --- a/src/python/python/ad/integrators/__init__.py +++ b/src/python/python/ad/integrators/__init__.py @@ -1,8 +1,11 @@ # Import/re-import all files in this folder to register AD integrators -import importlib import mitsuba as mi -if mi.variant() is not None and not mi.variant().startswith('scalar'): +def integrators_variants_cb(old, new): + if new is None or new.startswith("scalar"): + return + + import importlib from . import common importlib.reload(common) @@ -21,4 +24,5 @@ from . import prb_projective importlib.reload(prb_projective) -del importlib, mi + +mi.add_set_variant_callback(integrators_variants_cb) From 99a223eb094db67282900ac4f7933e09ac45e3a0 Mon Sep 17 00:00:00 2001 From: Merlin Nimier-David Date: Tue, 29 Oct 2024 14:27:18 +0100 Subject: [PATCH 4/4] Rename new methods and move to `mi.detail` --- src/core/tests/test_variants.py | 12 ++++++------ src/python/alias.cpp | 18 +++++++++++------- src/python/python/__init__.py | 2 +- src/python/python/ad/integrators/__init__.py | 2 +- 4 files changed, 19 insertions(+), 15 deletions(-) diff --git a/src/core/tests/test_variants.py b/src/core/tests/test_variants.py index b96df8bdf..cd3c34563 100644 --- a/src/core/tests/test_variants.py +++ b/src/core/tests/test_variants.py @@ -15,12 +15,12 @@ def count_changes(old, new): nonlocal change_count change_count += 1 - mi.add_set_variant_callback(track_changes) - mi.add_set_variant_callback(count_changes) + mi.detail.add_variant_callback(track_changes) + mi.detail.add_variant_callback(count_changes) # Adding the same callback multiple times does nothing. # It won't be called multiple times. - mi.add_set_variant_callback(track_changes) - mi.add_set_variant_callback(track_changes) + mi.detail.add_variant_callback(track_changes) + mi.detail.add_variant_callback(track_changes) try: previous = mi.variant() @@ -43,8 +43,8 @@ def count_changes(old, new): assert h == e finally: - # The callback shouldn't survive - mi.remove_set_variant_callback(track_changes) + # The callback shouldn't stay on even if the test fails. + mi.detail.remove_variant_callback(track_changes) # Callback shouldn't be called anymore len_e = len(expected) diff --git a/src/python/alias.cpp b/src/python/alias.cpp index b4ef8bc80..f58090291 100644 --- a/src/python/alias.cpp +++ b/src/python/alias.cpp @@ -140,7 +140,7 @@ static void set_variant(nb::args args) { * * `callback` will be called with the arguments `old_variant: str, new_variant: str`. */ -static void add_set_variant_callback(const nb::callable &callback) { +static void add_variant_callback(const nb::callable &callback) { nb::borrow(variant_change_callbacks).add(callback); } @@ -148,14 +148,14 @@ static void add_set_variant_callback(const nb::callable &callback) { * Removes the given `callback` callable from the list of callbacks to be called * when the Mitsuba variant changes. */ -static void remove_set_variant_callback(const nb::callable &callback) { +static void remove_variant_callback(const nb::callable &callback) { nb::borrow(variant_change_callbacks).discard(callback); } /** * Removes all callbacks to be called when the Mitsuba variant changes. */ -static void clear_set_variant_callback() { +static void clear_variant_callbacks() { nb::borrow(variant_change_callbacks).clear(); } @@ -207,16 +207,18 @@ NB_MODULE(mitsuba_alias, m) { m.def("variant", []() { return curr_variant ? curr_variant : nb::none(); }); m.def("variants", []() { return nb::steal(PyDict_Keys(variant_modules)); }); m.def("set_variant", set_variant); - m.def("add_set_variant_callback", add_set_variant_callback); - m.def("remove_set_variant_callback", remove_set_variant_callback); - m.def("clear_set_variant_callback", clear_set_variant_callback); /// Only used for variant-specific attributes e.g. mi.scalar_rgb.T m.def("__getattr__", [](nb::handle key) { return get_attr(key); }); + // `mitsuba.detail` submodule + nb::module_ mi_detail = m.def_submodule("detail"); + mi_detail.def("add_variant_callback", add_variant_callback); + mi_detail.def("remove_variant_callback", remove_variant_callback); + mi_detail.def("clear_variant_callbacks", clear_variant_callbacks); + /// Fill `__dict__` with all objects in `mitsuba_ext` and `mitsuba.python` mi_dict = m.attr("__dict__").ptr(); nb::object mi_ext = import_with_deepbind_if_necessary("mitsuba.mitsuba_ext"); - nb::object mi_python = nb::module_::import_("mitsuba.python"); nb::dict mitsuba_ext_dict = mi_ext.attr("__dict__"); for (const auto &k : mitsuba_ext_dict.keys()) if (!nb::bool_(k.attr("startswith")("__")) && @@ -224,6 +226,8 @@ NB_MODULE(mitsuba_alias, m) { Safe_PyDict_SetItem(mi_dict, k.ptr(), mitsuba_ext_dict[k].ptr()); } + // Import contents of `mitsuba.python` into top-level `mitsuba` module + nb::object mi_python = nb::module_::import_("mitsuba.python"); nb::dict mitsuba_python_dict = mi_python.attr("__dict__"); for (const auto &k : mitsuba_python_dict.keys()) if (!nb::bool_(k.attr("startswith")("__")) && diff --git a/src/python/python/__init__.py b/src/python/python/__init__.py index 636cdbd39..c937a75b7 100644 --- a/src/python/python/__init__.py +++ b/src/python/python/__init__.py @@ -1,4 +1,4 @@ -from .util import traverse, SceneParameters, render, cornell_box, variant_context +from .util import traverse, SceneParameters, render, cornell_box, variant_context, scoped_set_variant from . import chi2 from . import xml from . import ad diff --git a/src/python/python/ad/integrators/__init__.py b/src/python/python/ad/integrators/__init__.py index 3525c0ddd..01348a594 100644 --- a/src/python/python/ad/integrators/__init__.py +++ b/src/python/python/ad/integrators/__init__.py @@ -25,4 +25,4 @@ def integrators_variants_cb(old, new): importlib.reload(prb_projective) -mi.add_set_variant_callback(integrators_variants_cb) +mi.detail.add_variant_callback(integrators_variants_cb)