Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Frozen Function Support #1477

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/mitsuba/core/bitmap.h
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,8 @@ class MI_EXPORT_LIB Bitmap : public Object {
bool m_premultiplied_alpha;
bool m_owns_data;
Properties m_metadata;

DR_TRAVERSE_CB(Object, m_size);
};


Expand Down
4 changes: 3 additions & 1 deletion include/mitsuba/core/bsphere.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
NAMESPACE_BEGIN(mitsuba)

/// Generic n-dimensional bounding sphere data structure
template <typename Point_> struct BoundingSphere {
template <typename Point_> struct BoundingSphere: drjit::TraversableBase {
static constexpr size_t Size = Point_::Size;
using Point = Point_;
using Float = dr::value_t<Point>;
Expand Down Expand Up @@ -74,6 +74,8 @@ template <typename Point_> struct BoundingSphere {
dr::squared_norm(o) - dr::square(radius)
);
}

DR_TRAVERSE_CB(drjit::TraversableBase, center, radius);
};

/// Print a string representation of the bounding sphere
Expand Down
2 changes: 1 addition & 1 deletion include/mitsuba/core/class.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ NAMESPACE_END(detail)

#define MI_REGISTRY_PUT(name, ptr) \
if constexpr (dr::is_jit_v<Float>) { \
jit_registry_put(::mitsuba::detail::get_variant<Float, Spectrum>(), \
drjit::registry_put(::mitsuba::detail::get_variant<Float, Spectrum>(), \
"mitsuba::" name, ptr); \
}

Expand Down
17 changes: 14 additions & 3 deletions include/mitsuba/core/distr_1d.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <mitsuba/core/vector.h>
#include <mitsuba/core/math.h>
#include <drjit/dynamic.h>
#include <drjit/traversable_base.h>

NAMESPACE_BEGIN(mitsuba)

Expand All @@ -17,7 +18,7 @@ NAMESPACE_BEGIN(mitsuba)
* initialization. The associated scale factor can be retrieved using the
* function \ref normalization().
*/
template <typename Value> struct DiscreteDistribution {
template <typename Value> struct DiscreteDistribution: drjit::TraversableBase {
using Float = std::conditional_t<dr::is_static_array_v<Value>,
dr::value_t<Value>, Value>;
using FloatStorage = DynamicBuffer<Float>;
Expand Down Expand Up @@ -269,6 +270,9 @@ template <typename Value> struct DiscreteDistribution {
Float m_sum = 0.f;
Float m_normalization = 0.f;
Vector2u m_valid;

DR_TRAVERSE_CB(drjit::TraversableBase, m_pmf, m_cdf, m_sum, m_normalization,
m_valid);
};

/**
Expand All @@ -283,7 +287,7 @@ template <typename Value> struct DiscreteDistribution {
* initialization. The associated scale factor can be retrieved using the
* function \ref normalization().
*/
template <typename Value> struct ContinuousDistribution {
template <typename Value> struct ContinuousDistribution: drjit::TraversableBase {
using Float = std::conditional_t<dr::is_static_array_v<Value>,
dr::value_t<Value>, Value>;
using FloatStorage = DynamicBuffer<Float>;
Expand Down Expand Up @@ -601,6 +605,10 @@ template <typename Value> struct ContinuousDistribution {
ScalarVector2f m_range { 0.f, 0.f };
Vector2u m_valid;
ScalarFloat m_max = 0.f;

DR_TRAVERSE_CB(drjit::TraversableBase, m_pdf, m_cdf, m_integral,
m_normalization, m_interval_size, m_inv_interval_size,
m_valid);
};

/**
Expand All @@ -615,7 +623,7 @@ template <typename Value> struct ContinuousDistribution {
* initialization. The associated scale factor can be retrieved using the
* function \ref normalization().
*/
template <typename Value> struct IrregularContinuousDistribution {
template <typename Value> struct IrregularContinuousDistribution : public drjit::TraversableBase{
using Float = std::conditional_t<dr::is_static_array_v<Value>,
dr::value_t<Value>, Value>;
using FloatStorage = DynamicBuffer<Float>;
Expand Down Expand Up @@ -973,6 +981,9 @@ template <typename Value> struct IrregularContinuousDistribution {
Vector2u m_valid;
ScalarFloat m_interval_size = 0.f;
ScalarFloat m_max = 0.f;

DR_TRAVERSE_CB(drjit::TraversableBase, m_nodes, m_pdf, m_cdf, m_integral,
m_normalization, m_valid);
};

template <typename Value>
Expand Down
37 changes: 35 additions & 2 deletions include/mitsuba/core/distr_2d.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <mitsuba/core/util.h>
#include <drjit/dynamic.h>
#include <array>
#include <drjit/traversable_base.h>

NAMESPACE_BEGIN(mitsuba)

Expand Down Expand Up @@ -72,7 +73,7 @@ NAMESPACE_BEGIN(mitsuba)
*/

template <typename Float_, size_t Dimension_ = 0>
class DiscreteDistribution2D {
class DiscreteDistribution2D : drjit::TraversableBase{
public:
using Float = Float_;
using UInt32 = dr::uint32_array_t<Float>;
Expand Down Expand Up @@ -201,10 +202,14 @@ class DiscreteDistribution2D {

Float m_inv_normalization;
Float m_normalization;

DR_TRAVERSE_CB(drjit::TraversableBase, m_data, m_marg_cdf, m_cond_cdf,
m_inv_normalization, m_normalization)
};

/// Base class of Hierarchical2D and Marginal2D with common functionality
template <typename Float_, size_t Dimension_ = 0> class Distribution2D {
template <typename Float_, size_t Dimension_ = 0>
class Distribution2D : drjit::TraversableBase {
public:
static constexpr size_t Dimension = Dimension_;
using Float = Float_;
Expand Down Expand Up @@ -308,6 +313,28 @@ template <typename Float_, size_t Dimension_ = 0> class Distribution2D {

/// Total number of slices (in case Dimension > 1)
uint32_t m_slices;

public:
void
traverse_1_cb_ro(void *payload,
drjit::detail::traverse_callback_ro fn) const override {
if constexpr (!std ::is_same_v<drjit ::TraversableBase,
drjit ::TraversableBase>)
drjit ::TraversableBase ::traverse_1_cb_ro(payload, fn);
for (const auto &param_value : m_param_values) {
drjit ::traverse_1_fn_ro(param_value, payload, fn);
}
}
void traverse_1_cb_rw(void *payload,
drjit::detail::traverse_callback_rw fn) override {
if constexpr (!std ::is_same_v<drjit ::TraversableBase,
drjit ::TraversableBase>)
drjit ::TraversableBase ::traverse_1_cb_rw(payload, fn);

for (auto &param_value : m_param_values) {
drjit ::traverse_1_fn_rw(param_value, payload, fn);
}
}
};

/**
Expand Down Expand Up @@ -788,13 +815,17 @@ class Hierarchical2D : public Distribution2D<Float_, Dimension_> {
return dr::gather<Float>(data, i0, active);
}
}

DRJIT_STRUCT_NODEF(Level, data)
};

/// MIP hierarchy over linearly interpolated patches
std::vector<Level> m_levels;

/// Number of bilinear patches in the X/Y dimension - 1
ScalarVector2u m_max_patch_index;

DR_TRAVERSE_CB(Base, m_levels)
};

/**
Expand Down Expand Up @@ -1454,6 +1485,8 @@ class Marginal2D : public Distribution2D<Float_, Dimension_> {

/// Are the probability values normalized?
bool m_normalized;

DR_TRAVERSE_CB(Base, m_data, m_marg_cdf, m_cond_cdf)
};

//! @}
Expand Down
18 changes: 18 additions & 0 deletions include/mitsuba/core/field.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <type_traits>

#include <drjit/array.h>
#include <drjit/array_traverse.h>
namespace dr = drjit;

NAMESPACE_BEGIN(mitsuba)
Expand Down Expand Up @@ -62,6 +63,12 @@ struct field<DeviceType, HostType,
}
private:
DeviceType m_scalar;

public:
void traverse_1_cb_ro(void * /*payload*/,
drjit::detail::traverse_callback_ro) const {}
void traverse_1_cb_rw(void * /*payload*/,
drjit::detail::traverse_callback_rw) {}
};

template <typename DeviceType, typename HostType>
Expand Down Expand Up @@ -105,6 +112,17 @@ struct field<DeviceType, HostType,
private:
DeviceType m_value;
HostType m_scalar;

public:
void traverse_1_cb_ro(void *payload,
drjit::detail::traverse_callback_ro fn) const {

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

drjit ::traverse_1_fn_ro(m_value, payload, fn);
}
void traverse_1_cb_rw(void *payload,
drjit::detail::traverse_callback_rw fn) {
drjit ::traverse_1_fn_rw(m_value, payload, fn);
}
};

/// Prints the canonical string representation of a field
Expand Down
24 changes: 24 additions & 0 deletions include/mitsuba/core/fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,30 @@ extern "C" {
})
#endif

#define MI_DECLARE_TRAVERSE_CB() \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alignment of \ (also below)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we already discussed this, but I really wish there was a way to have the fields listed in the DECLARE_ macro, even if the actual implementation is in IMPLEMENT_ macro.
If someone adds a field to one of the classes that has the DECLARE_ macro, it would be very easy to forget to add it to the IMPLEMENT_ call in the .cpp file.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to have this called automatically by MI_DECLARE_CLASS()?
Especially if we keep this form which doesn't have any arguments.

public: \
void traverse_1_cb_ro(void *payload, \
drjit::detail::traverse_callback_ro fn) \
const override; \
void traverse_1_cb_rw( \
void *payload, drjit::detail::traverse_callback_rw fn) override;

#define MI_IMPLEMENT_TRAVERSE_CB(Type, Base, ...) \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise, could this be called by MI_IMPLEMENT_CLASS_VARIANTS?

MI_VARIANT \
void Type<Float, Spectrum>::traverse_1_cb_ro( \
void *payload, drjit::detail::traverse_callback_ro fn) const { \
if constexpr (!std ::is_same_v<Base, drjit ::TraversableBase>) \
Base ::traverse_1_cb_ro(payload, fn); \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Base ::traverse_1_cb_ro(payload, fn); \
Base::traverse_1_cb_ro(payload, fn); \

and also below.

DRJIT_MAP(DR_TRAVERSE_MEMBER_RO, __VA_ARGS__) \
} \
MI_VARIANT \
void Type<Float, Spectrum>::traverse_1_cb_rw( \
void *payload, drjit::detail::traverse_callback_rw fn) { \
if constexpr (!std ::is_same_v<Base, drjit ::TraversableBase>) \
Base ::traverse_1_cb_rw(payload, fn); \
DRJIT_MAP(DR_TRAVERSE_MEMBER_RW, __VA_ARGS__) \
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should add an alias for DR_TRAVERSE_CB, e.g. MI_PLUGIN_TRAVERSE or something.
Right now it's confusing why some places use MI_IMPLEMENT_TRAVERSE_CB and others use DR_TRAVERSE_CB.

The two MI_*_TRAVERSE_CB macros above could also be aliased to MI_*_INTERFACE_TRAVERSE or something, so that it's more clear which macro needs to be called where.

//! @}
// =============================================================

Expand Down
5 changes: 4 additions & 1 deletion include/mitsuba/core/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <atomic>
#include <stdexcept>
#include <mitsuba/core/class.h>
#include <drjit/traversable_base.h>

NAMESPACE_BEGIN(mitsuba)

Expand All @@ -29,8 +30,10 @@ NAMESPACE_BEGIN(mitsuba)
* Python, this counter is shared with Python such that the ownerhsip and
* lifetime of any ``Object`` instance across C++ and Python is managed by it.
*/
class MI_EXPORT_LIB Object : public nanobind::intrusive_base {
class MI_EXPORT_LIB Object : public drjit::TraversableBase {
public:
DR_TRAVERSE_CB(drjit::TraversableBase)

/// Default constructor
Object() { }

Expand Down
2 changes: 2 additions & 0 deletions include/mitsuba/render/bsdf.h
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,8 @@ class MI_EXPORT_LIB BSDF : public Object {

/// Identifier (if available)
std::string m_id;

DR_TRAVERSE_CB(Object);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to use Base consistently for most of these?

};

// -----------------------------------------------------------------------
Expand Down
2 changes: 2 additions & 0 deletions include/mitsuba/render/emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ class MI_EXPORT_LIB Emitter : public Endpoint<Float, Spectrum> {

/// True if the emitters's parameters have changed
bool m_dirty = false;

DR_TRAVERSE_CB(Base);
};

MI_EXTERN_CLASS(Emitter)
Expand Down
2 changes: 2 additions & 0 deletions include/mitsuba/render/endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,8 @@ class MI_EXPORT_LIB Endpoint : public Object {
bool m_needs_sample_2 = true;
bool m_needs_sample_3 = true;
std::string m_id;

MI_DECLARE_TRAVERSE_CB()
};

MI_EXTERN_CLASS(Endpoint)
Expand Down
2 changes: 2 additions & 0 deletions include/mitsuba/render/film.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ class MI_EXPORT_LIB Film : public Object {
bool m_sample_border;
ref<ReconstructionFilter> m_filter;
ref<Texture> m_srf;

MI_DECLARE_TRAVERSE_CB()
};

MI_EXTERN_CLASS(Film)
Expand Down
2 changes: 2 additions & 0 deletions include/mitsuba/render/imageblock.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,8 @@ class MI_EXPORT_LIB ImageBlock : public Object {
bool m_compensate;
bool m_warn_negative;
bool m_warn_invalid;

DR_TRAVERSE_CB(Object, m_tensor, m_tensor_compensation)
};

MI_EXTERN_CLASS(ImageBlock)
Expand Down
Loading