-
Notifications
You must be signed in to change notification settings - Fork 264
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Frozen Function Support #1477
base: master
Are you sure you want to change the base?
Frozen Function Support #1477
Conversation
ced44ea
to
2d8b9a8
Compare
b1a70c2
to
25c82f2
Compare
25c82f2
to
3781eef
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice PR! I like that the tests systematically cover a lot of the plugins.
Given the number of new tests, we may have to check that they don't slow down the CI too much.
Mostly small comments & consistency fixes, the bigger points are about:
- How to name the traverse macros
- Ways to improve the tests a bit
public: | ||
void traverse_1_cb_ro(void *payload, | ||
drjit::detail::traverse_callback_ro fn) const { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@@ -368,6 +368,30 @@ extern "C" { | |||
}) | |||
#endif | |||
|
|||
#define MI_DECLARE_TRAVERSE_CB() \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alignment of \
(also below)
void Type<Float, Spectrum>::traverse_1_cb_ro( \ | ||
void *payload, drjit::detail::traverse_callback_ro fn) const { \ | ||
if constexpr (!std ::is_same_v<Base, drjit ::TraversableBase>) \ | ||
Base ::traverse_1_cb_ro(payload, fn); \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Base ::traverse_1_cb_ro(payload, fn); \ | |
Base::traverse_1_cb_ro(payload, fn); \ |
and also below.
@@ -609,6 +609,8 @@ class MI_EXPORT_LIB BSDF : public Object { | |||
|
|||
/// Identifier (if available) | |||
std::string m_id; | |||
|
|||
DR_TRAVERSE_CB(Object); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be possible to use Base
consistently for most of these?
@@ -636,6 +636,13 @@ class MI_EXPORT_LIB Scene : public Object { | |||
std::unique_ptr<DiscreteDistribution<Float>> m_silhouette_distr = nullptr; | |||
|
|||
bool m_shapes_grad_enabled; | |||
|
|||
void traverse_1_cb_ro_cpu(void *payload, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a comment to explain why these two are needed on top of the normal MI_DECLARE_TRAVERSE_CB()
macro.
mi.util.write_bitmap(f"out/{shape}/ref{i}.jpg", ref) | ||
mi.util.write_bitmap(f"out/{shape}/frozen{i}.jpg", frozen) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Save as EXR
if optimizer == "adam": | ||
opt = mi.ad.Adam(lr=0.05) | ||
elif optimizer == "sgd": | ||
opt = mi.ad.SGD(lr=0.005) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could try enabling momentum in SGD, since that's one more field that gets used.
os.makedirs(f"out/{medium}", exist_ok=True) | ||
mi.util.write_bitmap(f"out/{medium}/ref{i}.jpg", ref) | ||
mi.util.write_bitmap(f"out/{medium}/frozen{i}.jpg", frozen) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comments as above
mi.util.write_bitmap(f"out/{sampler}/frozen{i}.jpg", frozen) | ||
|
||
for ref, frozen in zip(images_ref, images_frozen): | ||
assert dr.allclose(ref, frozen) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing newline at end of file (you can configure your editor so that it always adds one).
os.makedirs(f"out/{sampler}", exist_ok=True) | ||
mi.util.write_bitmap(f"out/{sampler}/ref{i}.jpg", ref) | ||
mi.util.write_bitmap(f"out/{sampler}/frozen{i}.jpg", frozen) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comment as above.
Description
This PR adds
TRAVERSE_CB
macros to most classes in Mitsuba. This allows the frozen function feature to traverse a scene and enables frozen function support (see here).Additionally, it expands on #1326 to pass UInt32 type seeds, allowing rendering scenes in frozen functions with different seeds without re-tracing.
The following is a more concise list of changes:
Object
is now inheriting fromdrjit::TraversableBase
and implements thetraverse_1_cb_ro
andtraverse_1_cb_rw
functions.drjit::registry_put
instead ofjit_registry_put
, enforcing that they inherit fromdrjit::TraversableBase
at compile time.traverse_1_cb_ro
andtraverse_cb_rw
from thedrjit::TraversableBase
class using either a combination of the macrosDR_TRAVERSE_CB
,MI_DECLARE_TRAVERSE_CB
, andMI_IMPLEMENT_TRAVERSE_CB
.DRJIT_STRUCT
propertyScene
s and it's subtypes (GPU and CPU)Testing
Added tests for shapes, bsdfs, emitters, integrators, optimizers, meidia and samplers as well as tests verifying gradient descent in frozen functions (pose and material parameter optimization).
Checklist
cuda_*
andllvm_*
variants. If you can't test this, please leave below