Skip to content

Commit

Permalink
Formatting of TraversableBase related changes
Browse files Browse the repository at this point in the history
  • Loading branch information
DoeringChristian committed Sep 25, 2024
1 parent de16c73 commit 33ed806
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 64 deletions.
12 changes: 6 additions & 6 deletions include/drjit/array_traverse.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ template <typename T> using enable_if_traversable_t = enable_if_t<is_traversable
template <typename T> static constexpr bool is_dynamic_traversable_v =
is_jit_v<T> && is_dynamic_array_v<T> && is_vector_v<T> && !is_tensor_v<T>;

template <typename T> struct is_ref_t: std::false_type{};
template <typename T> struct is_ref_t : std::false_type {};

template <typename T, typename = void>
struct is_iterable_t : std::false_type {};
Expand Down Expand Up @@ -207,15 +207,15 @@ void traverse_1_fn_ro(const Value &value, void *payload, void (*fn)(void *, uint
is_detected_v<detail::det_traverse_1_cb_ro, Value>) {
if (value)
value->traverse_1_cb_ro(payload, fn);

} else if constexpr (is_iterable_t<Value>::value) {
for (auto elem: value){
for (auto elem : value) {
traverse_1_fn_ro(elem, payload, fn);
}
} else if constexpr (is_ref_t<Value>::value) {
const auto *tmp = value.get();
traverse_1_fn_ro(tmp, payload, fn);
} else if constexpr (is_detected_v<detail::det_traverse_1_cb_ro, Value*>) {
} else if constexpr (is_detected_v<detail::det_traverse_1_cb_ro, Value *>) {
value.traverse_1_cb_ro(payload, fn);
} else {
// static_assert(false, "Failed to traverse field!");
Expand All @@ -242,13 +242,13 @@ void traverse_1_fn_rw(Value &value, void *payload, uint64_t (*fn)(void *, uint64
if (value)
value->traverse_1_cb_rw(payload, fn);
} else if constexpr (is_iterable_t<Value>::value) {
for (auto elem: value){
for (auto elem : value) {
traverse_1_fn_rw(elem, payload, fn);
}
} else if constexpr (is_ref_t<Value>::value) {
auto *tmp = value.get();
traverse_1_fn_rw(tmp, payload, fn);
} else if constexpr (is_detected_v<detail::det_traverse_1_cb_rw, Value*>) {
} else if constexpr (is_detected_v<detail::det_traverse_1_cb_rw, Value *>) {
value.traverse_1_cb_rw(payload, fn);
} else {
// static_assert(false, "Failed to traverse field!");
Expand Down
28 changes: 13 additions & 15 deletions include/drjit/python.h
Original file line number Diff line number Diff line change
Expand Up @@ -1056,7 +1056,6 @@ template <typename T, typename... Args> auto& bind_traverse(nanobind::class_<T,

cls.def("_traverse_1_cb_rw", [](T *self, nb::callable c) {
Payload payload{ std::move(c) };
jit_log(LogLevel::Debug, "pointer %p", &T::traverse_1_cb_rw);
self->traverse_1_cb_rw((void *) &payload, [](void *p, uint64_t index) {
return nb::cast<uint64_t>(((Payload *) p)->c(index));
});
Expand All @@ -1066,35 +1065,34 @@ template <typename T, typename... Args> auto& bind_traverse(nanobind::class_<T,
}

inline void traverse_py_cb_ro(const TraversableBase *base, void *payload,
void (*fn)(void *, uint64_t)) {
namespace nb = nanobind;
void (*fn)(void *, uint64_t)) {
namespace nb = nanobind;
nb::handle self = base->self_py();
if (!self)
return;

auto detail = nb::module_::import_("drjit.detail");
nb::callable traverse_py_cb_ro =
nb::borrow<nb::callable>(nb::getattr(detail, "traverse_py_cb_ro"));

traverse_py_cb_ro(self, nb::cpp_function([&](uint64_t index){
fn(payload, index);
}));
traverse_py_cb_ro(
self, nb::cpp_function([&](uint64_t index) { fn(payload, index); }));
}
inline void traverse_py_cb_rw(TraversableBase *base, void *payload,
uint64_t (*fn)(void *, uint64_t)) {
namespace nb = nanobind;
uint64_t (*fn)(void *, uint64_t)) {

namespace nb = nanobind;
nb::handle self = base->self_py();
if (!self)
return;

auto detail = nb::module_::import_("drjit.detail");
nb::callable traverse_py_cb_rw =
nb::borrow<nb::callable>(nb::getattr(detail, "traverse_py_cb_rw"));
traverse_py_cb_rw(self, nb::cpp_function([&](uint64_t index){
return fn(payload, index);
}));

traverse_py_cb_rw(self, nb::cpp_function([&](uint64_t index) {
return fn(payload, index);
}));
}

NAMESPACE_END(drjit)
40 changes: 20 additions & 20 deletions include/drjit/texture.h
Original file line number Diff line number Diff line change
Expand Up @@ -1384,26 +1384,26 @@ template <typename _Storage, size_t Dimension> class Texture : TraversableBase {
mutable bool m_migrated = false;

public:
void traverse_1_cb_ro(void *payload,
void (*fn)(void *, uint64_t)) const override {
if constexpr (!std ::is_same_v<drjit ::TraversableBase,
drjit ::TraversableBase>)
drjit ::TraversableBase ::traverse_1_cb_ro(payload, fn);

DR_TRAVERSE_MEMBER_RO(m_value)
DR_TRAVERSE_MEMBER_RO(m_shape_opaque)
DR_TRAVERSE_MEMBER_RO(m_inv_resolution)
}
void traverse_1_cb_rw(void *payload,
uint64_t (*fn)(void *, uint64_t)) override {
if constexpr (!std ::is_same_v<drjit ::TraversableBase,
drjit ::TraversableBase>)
drjit ::TraversableBase ::traverse_1_cb_rw(payload, fn);

DR_TRAVERSE_MEMBER_RW(m_value)
DR_TRAVERSE_MEMBER_RW(m_shape_opaque)
DR_TRAVERSE_MEMBER_RW(m_inv_resolution)
}
void traverse_1_cb_ro(void *payload,
void (*fn)(void *, uint64_t)) const override {
if constexpr (!std ::is_same_v<drjit ::TraversableBase,
drjit ::TraversableBase>)
drjit ::TraversableBase ::traverse_1_cb_ro(payload, fn);

DR_TRAVERSE_MEMBER_RO(m_value)
DR_TRAVERSE_MEMBER_RO(m_shape_opaque)
DR_TRAVERSE_MEMBER_RO(m_inv_resolution)
}
void traverse_1_cb_rw(void *payload,
uint64_t (*fn)(void *, uint64_t)) override {
if constexpr (!std ::is_same_v<drjit ::TraversableBase,
drjit ::TraversableBase>)
drjit ::TraversableBase ::traverse_1_cb_rw(payload, fn);

DR_TRAVERSE_MEMBER_RW(m_value)
DR_TRAVERSE_MEMBER_RW(m_shape_opaque)
DR_TRAVERSE_MEMBER_RW(m_inv_resolution)
}
};

NAMESPACE_END(drjit)
24 changes: 11 additions & 13 deletions include/drjit/traversable_base.h
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@

#pragma once

#include "drjit-core/macros.h"
#include "array_traverse.h"
#include "drjit-core/macros.h"
#include "nanobind/intrusive/counter.h"
#include "nanobind/intrusive/ref.h"
#include <type_traits>
#include <vector>
#include <drjit-core/jit.h>
#include <drjit/map.h>
#include <type_traits>
#include <vector>

NAMESPACE_BEGIN(drjit)

Expand All @@ -19,7 +18,8 @@ struct TraversableBase : nanobind::intrusive_base {

template <typename T> struct is_ref_t<nanobind::ref<T>> : std::true_type {};
template <typename T> struct is_ref_t<std::unique_ptr<T>> : std::true_type {};
// template <typename T> struct is_iterable_t<std::vector<T>> : std::true_type {};
// template <typename T> struct is_iterable_t<std::vector<T>> : std::true_type
// {};

#define DR_TRAVERSE_MEMBER_RO(member) \
drjit::log_member_open(false, #member); \
Expand All @@ -30,13 +30,11 @@ template <typename T> struct is_ref_t<std::unique_ptr<T>> : std::true_type {};
drjit::traverse_1_fn_rw(member, payload, fn); \
drjit::log_member_close();

inline void log_member_open(bool rw, const char *member){
inline void log_member_open(bool rw, const char *member) {
jit_log(LogLevel::Debug, "%s%s{", rw ? "rw " : "ro ", member);
}

inline void log_member_close(){
jit_log(LogLevel::Debug, "}");
}
inline void log_member_close() { jit_log(LogLevel::Debug, "}"); }

#define DR_TRAVERSE_CB_RO(Base, ...) \
void traverse_1_cb_ro(void *payload, void (*fn)(void *, uint64_t)) \
Expand All @@ -63,21 +61,21 @@ public: \
public: \
void traverse_1_cb_ro(void *payload, void (*fn)(void *, uint64_t)) \
const override { \
if constexpr (!std ::is_same_v<Base, drjit ::TraversableBase>) \
if constexpr (!std ::is_same_v<Base, drjit ::TraversableBase>) \
Base ::traverse_1_cb_ro(payload, fn); \
drjit::traverse_py_cb_ro(this, payload, fn); \
} \
void traverse_1_cb_rw(void *payload, uint64_t (*fn)(void *, uint64_t)) \
override { \
if constexpr (!std ::is_same_v<Base, drjit ::TraversableBase>) \
if constexpr (!std ::is_same_v<Base, drjit ::TraversableBase>) \
Base ::traverse_1_cb_rw(payload, fn); \
drjit::traverse_py_cb_rw(this, payload, fn); \
}

#if defined(_MSC_VER)
# define DRJIT_EXPORT __declspec(dllexport)
#define DRJIT_EXPORT __declspec(dllexport)
#else
# define DRJIT_EXPORT __attribute__ ((visibility("default")))
#define DRJIT_EXPORT __attribute__((visibility("default")))
#endif

NAMESPACE_END(drjit)
14 changes: 6 additions & 8 deletions src/python/detail.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,26 +271,24 @@ void disable_py_tracing() {
nb::module_::import_("sys").attr("settrace")(nb::none());
}

void traverse_py_cb_ro_impl(nb::handle self, nb::callable c){
struct PyTraverseCallback: TraverseCallback{
void operator()(nb::handle h) override{
void traverse_py_cb_ro_impl(nb::handle self, nb::callable c) {
struct PyTraverseCallback : TraverseCallback {
void operator()(nb::handle h) override {
auto index_fn = supp(h.type()).index;
if (index_fn)
operator()(index_fn(inst_ptr(h)));
}
void operator()(uint64_t index) override{
m_callback(index);
}
void operator()(uint64_t index) override { m_callback(index); }
nb::callable m_callback;

PyTraverseCallback(nb::callable c): m_callback(c){}
PyTraverseCallback(nb::callable c) : m_callback(c) {}
};

PyTraverseCallback traverse_cb(std::move(c));

auto dict = nb::borrow<nb::dict>(nb::getattr(self, "__dict__"));

for (auto value: dict.values()){
for (auto value : dict.values()) {
traverse("traverse_py_cb_ro", traverse_cb, value);
}
}
Expand Down
3 changes: 1 addition & 2 deletions tests/call_ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ using namespace nb::literals;

template <typename T>
struct Sampler : dr::TraversableBase {
Sampler() : rng(1) { }
Sampler() : rng(1) {}
Sampler(size_t size) : rng(size) { }

T next() { return rng.next_float32(); }
Expand Down Expand Up @@ -218,7 +218,6 @@ void bind(nb::module_ &m) {
.def_rw("opaque", &BT::opaque)
.def_rw("value", &BT::value);
bind_traverse(b_cls);


using BaseArray = dr::DiffArray<Backend, BaseT *>;
m.def("dispatch_f", [](BaseArray &self, Float a, Float b) {
Expand Down

0 comments on commit 33ed806

Please sign in to comment.