Skip to content

Commit

Permalink
Added traverse_callback functions for frozen functions support
Browse files Browse the repository at this point in the history
  • Loading branch information
DoeringChristian committed Jan 27, 2025
1 parent f2967f1 commit 2d8b9a8
Show file tree
Hide file tree
Showing 119 changed files with 1,807 additions and 208 deletions.
2 changes: 2 additions & 0 deletions include/mitsuba/core/bitmap.h
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,8 @@ class MI_EXPORT_LIB Bitmap : public Object {
bool m_premultiplied_alpha;
bool m_owns_data;
Properties m_metadata;

DR_TRAVERSE_CB(Object, m_size);
};


Expand Down
2 changes: 2 additions & 0 deletions include/mitsuba/core/bsphere.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ template <typename Point_> struct BoundingSphere {
dr::squared_norm(o) - dr::square(radius)
);
}

DRJIT_STRUCT_NODEF(BoundingSphere, center, radius)
};

/// Print a string representation of the bounding sphere
Expand Down
2 changes: 1 addition & 1 deletion include/mitsuba/core/class.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ NAMESPACE_END(detail)

#define MI_REGISTRY_PUT(name, ptr) \
if constexpr (dr::is_jit_v<Float>) { \
jit_registry_put(::mitsuba::detail::get_variant<Float, Spectrum>(), \
drjit::registry_put(::mitsuba::detail::get_variant<Float, Spectrum>(), \
"mitsuba::" name, ptr); \
}

Expand Down
17 changes: 14 additions & 3 deletions include/mitsuba/core/distr_1d.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <mitsuba/core/vector.h>
#include <mitsuba/core/math.h>
#include <drjit/dynamic.h>
#include <drjit/traversable_base.h>

NAMESPACE_BEGIN(mitsuba)

Expand All @@ -17,7 +18,7 @@ NAMESPACE_BEGIN(mitsuba)
* initialization. The associated scale factor can be retrieved using the
* function \ref normalization().
*/
template <typename Value> struct DiscreteDistribution {
template <typename Value> struct DiscreteDistribution: drjit::TraversableBase {
using Float = std::conditional_t<dr::is_static_array_v<Value>,
dr::value_t<Value>, Value>;
using FloatStorage = DynamicBuffer<Float>;
Expand Down Expand Up @@ -269,6 +270,9 @@ template <typename Value> struct DiscreteDistribution {
Float m_sum = 0.f;
Float m_normalization = 0.f;
Vector2u m_valid;

DR_TRAVERSE_CB(drjit::TraversableBase, m_pmf, m_cdf, m_sum, m_normalization,
m_valid);
};

/**
Expand All @@ -283,7 +287,7 @@ template <typename Value> struct DiscreteDistribution {
* initialization. The associated scale factor can be retrieved using the
* function \ref normalization().
*/
template <typename Value> struct ContinuousDistribution {
template <typename Value> struct ContinuousDistribution: drjit::TraversableBase {
using Float = std::conditional_t<dr::is_static_array_v<Value>,
dr::value_t<Value>, Value>;
using FloatStorage = DynamicBuffer<Float>;
Expand Down Expand Up @@ -601,6 +605,10 @@ template <typename Value> struct ContinuousDistribution {
ScalarVector2f m_range { 0.f, 0.f };
Vector2u m_valid;
ScalarFloat m_max = 0.f;

DR_TRAVERSE_CB(drjit::TraversableBase, m_pdf, m_cdf, m_integral,
m_normalization, m_interval_size, m_inv_interval_size,
m_valid);
};

/**
Expand All @@ -615,7 +623,7 @@ template <typename Value> struct ContinuousDistribution {
* initialization. The associated scale factor can be retrieved using the
* function \ref normalization().
*/
template <typename Value> struct IrregularContinuousDistribution {
template <typename Value> struct IrregularContinuousDistribution : public drjit::TraversableBase{
using Float = std::conditional_t<dr::is_static_array_v<Value>,
dr::value_t<Value>, Value>;
using FloatStorage = DynamicBuffer<Float>;
Expand Down Expand Up @@ -973,6 +981,9 @@ template <typename Value> struct IrregularContinuousDistribution {
Vector2u m_valid;
ScalarFloat m_interval_size = 0.f;
ScalarFloat m_max = 0.f;

DR_TRAVERSE_CB(drjit::TraversableBase, m_nodes, m_pdf, m_cdf, m_integral,
m_normalization, m_valid);
};

template <typename Value>
Expand Down
37 changes: 35 additions & 2 deletions include/mitsuba/core/distr_2d.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <mitsuba/core/util.h>
#include <drjit/dynamic.h>
#include <array>
#include <drjit/traversable_base.h>

NAMESPACE_BEGIN(mitsuba)

Expand Down Expand Up @@ -72,7 +73,7 @@ NAMESPACE_BEGIN(mitsuba)
*/

template <typename Float_, size_t Dimension_ = 0>
class DiscreteDistribution2D {
class DiscreteDistribution2D : drjit::TraversableBase{
public:
using Float = Float_;
using UInt32 = dr::uint32_array_t<Float>;
Expand Down Expand Up @@ -201,10 +202,14 @@ class DiscreteDistribution2D {

Float m_inv_normalization;
Float m_normalization;

DR_TRAVERSE_CB(drjit::TraversableBase, m_data, m_marg_cdf, m_cond_cdf,
m_inv_normalization, m_normalization)
};

/// Base class of Hierarchical2D and Marginal2D with common functionality
template <typename Float_, size_t Dimension_ = 0> class Distribution2D {
template <typename Float_, size_t Dimension_ = 0>
class Distribution2D : drjit::TraversableBase {
public:
static constexpr size_t Dimension = Dimension_;
using Float = Float_;
Expand Down Expand Up @@ -308,6 +313,28 @@ template <typename Float_, size_t Dimension_ = 0> class Distribution2D {

/// Total number of slices (in case Dimension > 1)
uint32_t m_slices;

public:
void
traverse_1_cb_ro(void *payload,
drjit::detail::traverse_callback_ro fn) const override {
if constexpr (!std ::is_same_v<drjit ::TraversableBase,
drjit ::TraversableBase>)
drjit ::TraversableBase ::traverse_1_cb_ro(payload, fn);
for (const auto &param_value : m_param_values) {
drjit ::traverse_1_fn_ro(param_value, payload, fn);
}
}
void traverse_1_cb_rw(void *payload,
drjit::detail::traverse_callback_rw fn) override {
if constexpr (!std ::is_same_v<drjit ::TraversableBase,
drjit ::TraversableBase>)
drjit ::TraversableBase ::traverse_1_cb_rw(payload, fn);

for (auto &param_value : m_param_values) {
drjit ::traverse_1_fn_rw(param_value, payload, fn);
}
}
};

/**
Expand Down Expand Up @@ -788,13 +815,17 @@ class Hierarchical2D : public Distribution2D<Float_, Dimension_> {
return dr::gather<Float>(data, i0, active);
}
}

DRJIT_STRUCT_NODEF(Level, data)
};

/// MIP hierarchy over linearly interpolated patches
std::vector<Level> m_levels;

/// Number of bilinear patches in the X/Y dimension - 1
ScalarVector2u m_max_patch_index;

DR_TRAVERSE_CB(Base, m_levels)
};

/**
Expand Down Expand Up @@ -1454,6 +1485,8 @@ class Marginal2D : public Distribution2D<Float_, Dimension_> {

/// Are the probability values normalized?
bool m_normalized;

DR_TRAVERSE_CB(Base, m_data, m_marg_cdf, m_cond_cdf)
};

//! @}
Expand Down
18 changes: 18 additions & 0 deletions include/mitsuba/core/field.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <type_traits>

#include <drjit/array.h>
#include <drjit/array_traverse.h>
namespace dr = drjit;

NAMESPACE_BEGIN(mitsuba)
Expand Down Expand Up @@ -62,6 +63,12 @@ struct field<DeviceType, HostType,
}
private:
DeviceType m_scalar;

public:
void traverse_1_cb_ro(void * /*payload*/,
drjit::detail::traverse_callback_ro) const {}
void traverse_1_cb_rw(void * /*payload*/,
drjit::detail::traverse_callback_rw) {}
};

template <typename DeviceType, typename HostType>
Expand Down Expand Up @@ -105,6 +112,17 @@ struct field<DeviceType, HostType,
private:
DeviceType m_value;
HostType m_scalar;

public:
void traverse_1_cb_ro(void *payload,
drjit::detail::traverse_callback_ro fn) const {

drjit ::traverse_1_fn_ro(m_value, payload, fn);
}
void traverse_1_cb_rw(void *payload,
drjit::detail::traverse_callback_rw fn) {
drjit ::traverse_1_fn_rw(m_value, payload, fn);
}
};

/// Prints the canonical string representation of a field
Expand Down
24 changes: 24 additions & 0 deletions include/mitsuba/core/fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,30 @@ extern "C" {
})
#endif

#define MI_DECLARE_TRAVERSE_CB() \
public: \
void traverse_1_cb_ro(void *payload, \
drjit::detail::traverse_callback_ro fn) \
const override; \
void traverse_1_cb_rw( \
void *payload, drjit::detail::traverse_callback_rw fn) override;

#define MI_IMPLEMENT_TRAVERSE_CB(Type, Base, ...) \
MI_VARIANT \
void Type<Float, Spectrum>::traverse_1_cb_ro( \
void *payload, drjit::detail::traverse_callback_ro fn) const { \
if constexpr (!std ::is_same_v<Base, drjit ::TraversableBase>) \
Base ::traverse_1_cb_ro(payload, fn); \
DRJIT_MAP(DR_TRAVERSE_MEMBER_RO, __VA_ARGS__) \
} \
MI_VARIANT \
void Type<Float, Spectrum>::traverse_1_cb_rw( \
void *payload, drjit::detail::traverse_callback_rw fn) { \
if constexpr (!std ::is_same_v<Base, drjit ::TraversableBase>) \
Base ::traverse_1_cb_rw(payload, fn); \
DRJIT_MAP(DR_TRAVERSE_MEMBER_RW, __VA_ARGS__) \
}

//! @}
// =============================================================

Expand Down
5 changes: 4 additions & 1 deletion include/mitsuba/core/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <atomic>
#include <stdexcept>
#include <mitsuba/core/class.h>
#include <drjit/traversable_base.h>

NAMESPACE_BEGIN(mitsuba)

Expand All @@ -29,8 +30,10 @@ NAMESPACE_BEGIN(mitsuba)
* Python, this counter is shared with Python such that the ownerhsip and
* lifetime of any ``Object`` instance across C++ and Python is managed by it.
*/
class MI_EXPORT_LIB Object : public nanobind::intrusive_base {
class MI_EXPORT_LIB Object : public drjit::TraversableBase {
public:
DR_TRAVERSE_CB(drjit::TraversableBase)

/// Default constructor
Object() { }

Expand Down
2 changes: 2 additions & 0 deletions include/mitsuba/render/bsdf.h
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,8 @@ class MI_EXPORT_LIB BSDF : public Object {

/// Identifier (if available)
std::string m_id;

DR_TRAVERSE_CB(Object);
};

// -----------------------------------------------------------------------
Expand Down
2 changes: 2 additions & 0 deletions include/mitsuba/render/emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ class MI_EXPORT_LIB Emitter : public Endpoint<Float, Spectrum> {

/// True if the emitters's parameters have changed
bool m_dirty = false;

DR_TRAVERSE_CB(Base);
};

MI_EXTERN_CLASS(Emitter)
Expand Down
2 changes: 2 additions & 0 deletions include/mitsuba/render/endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,8 @@ class MI_EXPORT_LIB Endpoint : public Object {
bool m_needs_sample_2 = true;
bool m_needs_sample_3 = true;
std::string m_id;

MI_DECLARE_TRAVERSE_CB()
};

MI_EXTERN_CLASS(Endpoint)
Expand Down
2 changes: 2 additions & 0 deletions include/mitsuba/render/film.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ class MI_EXPORT_LIB Film : public Object {
bool m_sample_border;
ref<ReconstructionFilter> m_filter;
ref<Texture> m_srf;

MI_DECLARE_TRAVERSE_CB()
};

MI_EXTERN_CLASS(Film)
Expand Down
2 changes: 2 additions & 0 deletions include/mitsuba/render/imageblock.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,8 @@ class MI_EXPORT_LIB ImageBlock : public Object {
bool m_compensate;
bool m_warn_negative;
bool m_warn_invalid;

DR_TRAVERSE_CB(Object, m_tensor, m_tensor_compensation)
};

MI_EXTERN_CLASS(ImageBlock)
Expand Down
Loading

0 comments on commit 2d8b9a8

Please sign in to comment.