From 33ed8069c4d2bf4763de0c7a9c1c50ca97657b88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20D=C3=B6ring?= Date: Wed, 25 Sep 2024 15:34:43 +0200 Subject: [PATCH] Formatting of TraversableBase related changes --- include/drjit/array_traverse.h | 12 +++++----- include/drjit/python.h | 28 +++++++++++----------- include/drjit/texture.h | 40 ++++++++++++++++---------------- include/drjit/traversable_base.h | 24 +++++++++---------- src/python/detail.cpp | 14 +++++------ tests/call_ext.cpp | 3 +-- 6 files changed, 57 insertions(+), 64 deletions(-) diff --git a/include/drjit/array_traverse.h b/include/drjit/array_traverse.h index 70b465dd8..98eda0775 100644 --- a/include/drjit/array_traverse.h +++ b/include/drjit/array_traverse.h @@ -171,7 +171,7 @@ template using enable_if_traversable_t = enable_if_t static constexpr bool is_dynamic_traversable_v = is_jit_v && is_dynamic_array_v && is_vector_v && !is_tensor_v; -template struct is_ref_t: std::false_type{}; +template struct is_ref_t : std::false_type {}; template struct is_iterable_t : std::false_type {}; @@ -207,15 +207,15 @@ void traverse_1_fn_ro(const Value &value, void *payload, void (*fn)(void *, uint is_detected_v) { if (value) value->traverse_1_cb_ro(payload, fn); - + } else if constexpr (is_iterable_t::value) { - for (auto elem: value){ + for (auto elem : value) { traverse_1_fn_ro(elem, payload, fn); } } else if constexpr (is_ref_t::value) { const auto *tmp = value.get(); traverse_1_fn_ro(tmp, payload, fn); - } else if constexpr (is_detected_v) { + } else if constexpr (is_detected_v) { value.traverse_1_cb_ro(payload, fn); } else { // static_assert(false, "Failed to traverse field!"); @@ -242,13 +242,13 @@ void traverse_1_fn_rw(Value &value, void *payload, uint64_t (*fn)(void *, uint64 if (value) value->traverse_1_cb_rw(payload, fn); } else if constexpr (is_iterable_t::value) { - for (auto elem: value){ + for (auto elem : value) { traverse_1_fn_rw(elem, payload, fn); } } else if constexpr (is_ref_t::value) { auto *tmp = value.get(); traverse_1_fn_rw(tmp, payload, fn); - } else if constexpr (is_detected_v) { + } else if constexpr (is_detected_v) { value.traverse_1_cb_rw(payload, fn); } else { // static_assert(false, "Failed to traverse field!"); diff --git a/include/drjit/python.h b/include/drjit/python.h index 2ee82868f..521b396b4 100644 --- a/include/drjit/python.h +++ b/include/drjit/python.h @@ -1056,7 +1056,6 @@ template auto& bind_traverse(nanobind::class_traverse_1_cb_rw((void *) &payload, [](void *p, uint64_t index) { return nb::cast(((Payload *) p)->c(index)); }); @@ -1066,35 +1065,34 @@ template auto& bind_traverse(nanobind::class_self_py(); if (!self) return; - + auto detail = nb::module_::import_("drjit.detail"); nb::callable traverse_py_cb_ro = nb::borrow(nb::getattr(detail, "traverse_py_cb_ro")); - traverse_py_cb_ro(self, nb::cpp_function([&](uint64_t index){ - fn(payload, index); - })); + traverse_py_cb_ro( + self, nb::cpp_function([&](uint64_t index) { fn(payload, index); })); } inline void traverse_py_cb_rw(TraversableBase *base, void *payload, - uint64_t (*fn)(void *, uint64_t)) { - - namespace nb = nanobind; + uint64_t (*fn)(void *, uint64_t)) { + + namespace nb = nanobind; nb::handle self = base->self_py(); if (!self) return; - + auto detail = nb::module_::import_("drjit.detail"); nb::callable traverse_py_cb_rw = nb::borrow(nb::getattr(detail, "traverse_py_cb_rw")); - - traverse_py_cb_rw(self, nb::cpp_function([&](uint64_t index){ - return fn(payload, index); - })); + + traverse_py_cb_rw(self, nb::cpp_function([&](uint64_t index) { + return fn(payload, index); + })); } NAMESPACE_END(drjit) diff --git a/include/drjit/texture.h b/include/drjit/texture.h index 1ae3bc893..808ba5121 100644 --- a/include/drjit/texture.h +++ b/include/drjit/texture.h @@ -1384,26 +1384,26 @@ template class Texture : TraversableBase { mutable bool m_migrated = false; public: -void traverse_1_cb_ro(void *payload, - void (*fn)(void *, uint64_t)) const override { - if constexpr (!std ::is_same_v) - drjit ::TraversableBase ::traverse_1_cb_ro(payload, fn); - - DR_TRAVERSE_MEMBER_RO(m_value) - DR_TRAVERSE_MEMBER_RO(m_shape_opaque) - DR_TRAVERSE_MEMBER_RO(m_inv_resolution) -} -void traverse_1_cb_rw(void *payload, - uint64_t (*fn)(void *, uint64_t)) override { - if constexpr (!std ::is_same_v) - drjit ::TraversableBase ::traverse_1_cb_rw(payload, fn); - - DR_TRAVERSE_MEMBER_RW(m_value) - DR_TRAVERSE_MEMBER_RW(m_shape_opaque) - DR_TRAVERSE_MEMBER_RW(m_inv_resolution) -} + void traverse_1_cb_ro(void *payload, + void (*fn)(void *, uint64_t)) const override { + if constexpr (!std ::is_same_v) + drjit ::TraversableBase ::traverse_1_cb_ro(payload, fn); + + DR_TRAVERSE_MEMBER_RO(m_value) + DR_TRAVERSE_MEMBER_RO(m_shape_opaque) + DR_TRAVERSE_MEMBER_RO(m_inv_resolution) + } + void traverse_1_cb_rw(void *payload, + uint64_t (*fn)(void *, uint64_t)) override { + if constexpr (!std ::is_same_v) + drjit ::TraversableBase ::traverse_1_cb_rw(payload, fn); + + DR_TRAVERSE_MEMBER_RW(m_value) + DR_TRAVERSE_MEMBER_RW(m_shape_opaque) + DR_TRAVERSE_MEMBER_RW(m_inv_resolution) + } }; NAMESPACE_END(drjit) diff --git a/include/drjit/traversable_base.h b/include/drjit/traversable_base.h index 6f2f082ee..a530269f5 100644 --- a/include/drjit/traversable_base.h +++ b/include/drjit/traversable_base.h @@ -1,14 +1,13 @@ - #pragma once -#include "drjit-core/macros.h" #include "array_traverse.h" +#include "drjit-core/macros.h" #include "nanobind/intrusive/counter.h" #include "nanobind/intrusive/ref.h" -#include -#include #include #include +#include +#include NAMESPACE_BEGIN(drjit) @@ -19,7 +18,8 @@ struct TraversableBase : nanobind::intrusive_base { template struct is_ref_t> : std::true_type {}; template struct is_ref_t> : std::true_type {}; -// template struct is_iterable_t> : std::true_type {}; +// template struct is_iterable_t> : std::true_type +// {}; #define DR_TRAVERSE_MEMBER_RO(member) \ drjit::log_member_open(false, #member); \ @@ -30,13 +30,11 @@ template struct is_ref_t> : std::true_type {}; drjit::traverse_1_fn_rw(member, payload, fn); \ drjit::log_member_close(); -inline void log_member_open(bool rw, const char *member){ +inline void log_member_open(bool rw, const char *member) { jit_log(LogLevel::Debug, "%s%s{", rw ? "rw " : "ro ", member); } -inline void log_member_close(){ - jit_log(LogLevel::Debug, "}"); -} +inline void log_member_close() { jit_log(LogLevel::Debug, "}"); } #define DR_TRAVERSE_CB_RO(Base, ...) \ void traverse_1_cb_ro(void *payload, void (*fn)(void *, uint64_t)) \ @@ -63,21 +61,21 @@ public: \ public: \ void traverse_1_cb_ro(void *payload, void (*fn)(void *, uint64_t)) \ const override { \ - if constexpr (!std ::is_same_v) \ + if constexpr (!std ::is_same_v) \ Base ::traverse_1_cb_ro(payload, fn); \ drjit::traverse_py_cb_ro(this, payload, fn); \ } \ void traverse_1_cb_rw(void *payload, uint64_t (*fn)(void *, uint64_t)) \ override { \ - if constexpr (!std ::is_same_v) \ + if constexpr (!std ::is_same_v) \ Base ::traverse_1_cb_rw(payload, fn); \ drjit::traverse_py_cb_rw(this, payload, fn); \ } #if defined(_MSC_VER) -# define DRJIT_EXPORT __declspec(dllexport) +#define DRJIT_EXPORT __declspec(dllexport) #else -# define DRJIT_EXPORT __attribute__ ((visibility("default"))) +#define DRJIT_EXPORT __attribute__((visibility("default"))) #endif NAMESPACE_END(drjit) diff --git a/src/python/detail.cpp b/src/python/detail.cpp index 94968092d..88fdaeb48 100644 --- a/src/python/detail.cpp +++ b/src/python/detail.cpp @@ -271,26 +271,24 @@ void disable_py_tracing() { nb::module_::import_("sys").attr("settrace")(nb::none()); } -void traverse_py_cb_ro_impl(nb::handle self, nb::callable c){ - struct PyTraverseCallback: TraverseCallback{ - void operator()(nb::handle h) override{ +void traverse_py_cb_ro_impl(nb::handle self, nb::callable c) { + struct PyTraverseCallback : TraverseCallback { + void operator()(nb::handle h) override { auto index_fn = supp(h.type()).index; if (index_fn) operator()(index_fn(inst_ptr(h))); - } - void operator()(uint64_t index) override{ - m_callback(index); } + void operator()(uint64_t index) override { m_callback(index); } nb::callable m_callback; - PyTraverseCallback(nb::callable c): m_callback(c){} + PyTraverseCallback(nb::callable c) : m_callback(c) {} }; PyTraverseCallback traverse_cb(std::move(c)); auto dict = nb::borrow(nb::getattr(self, "__dict__")); - for (auto value: dict.values()){ + for (auto value : dict.values()) { traverse("traverse_py_cb_ro", traverse_cb, value); } } diff --git a/tests/call_ext.cpp b/tests/call_ext.cpp index 857c1cd50..4c7a0e9d1 100644 --- a/tests/call_ext.cpp +++ b/tests/call_ext.cpp @@ -15,7 +15,7 @@ using namespace nb::literals; template struct Sampler : dr::TraversableBase { - Sampler() : rng(1) { } + Sampler() : rng(1) {} Sampler(size_t size) : rng(size) { } T next() { return rng.next_float32(); } @@ -218,7 +218,6 @@ void bind(nb::module_ &m) { .def_rw("opaque", &BT::opaque) .def_rw("value", &BT::value); bind_traverse(b_cls); - using BaseArray = dr::DiffArray; m.def("dispatch_f", [](BaseArray &self, Float a, Float b) {