-
Notifications
You must be signed in to change notification settings - Fork 20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Scale factors across plate dims in partial_sum_product
#606
Conversation
@eb8680 I opened a PR in NumPyro (pyro-ppl/numpyro#1572) where I tried to expand the math and check whether expectations match for a full- and a mini-batch log-likelihoods. Can you verify that the math is correct? |
funsor/sum_product.py
Outdated
elif sum_op is ops.add and prod_op is ops.mul: | ||
pow_op = ops.pow | ||
else: | ||
raise ValueError("should not be here!") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be NotImplementedError
funsor/sum_product.py
Outdated
if plate_to_scale: | ||
if sum_op is ops.logaddexp and prod_op is ops.add: | ||
pow_op = ops.mul | ||
elif sum_op is ops.add and prod_op is ops.mul: | ||
pow_op = ops.pow | ||
else: | ||
raise ValueError("should not be here!") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move this out to a PROD_TO_POWER dict or similar, see ops
funsor/sum_product.py
Outdated
f = f.reduce(prod_op, leaf & eliminate) | ||
f_scales = [ | ||
plate_to_scale[plate] | ||
for plate in leaf & eliminate | ||
if plate in plate_to_scale | ||
] | ||
if f_scales: | ||
scale = reduce(ops.mul, f_scales) | ||
f = pow_op(f, scale) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we wrap this in an if plate_to_scale:
guard to improve readability?
funsor/sum_product.py
Outdated
@@ -306,6 +330,14 @@ def partial_sum_product( | |||
reduced_plates = leaf - new_plates | |||
assert reduced_plates.issubset(eliminate) | |||
f = f.reduce(prod_op, reduced_plates) | |||
f_scales = [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto: if plate_to_scale: ...
funsor/sum_product.py
Outdated
eliminate=frozenset(), | ||
plates=frozenset(), | ||
pedantic=False, | ||
plate_to_scale={}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
default to None, maybe comment on datatype
@fritzo can you review the changes please? Also I believe the rule for jax tests needs to be changed in the Settings to use python 3.9 instead of 3.8. Jax doesn't support 3.8 anymore. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for addressing my comments.
One of the features not supported by
TraceEnum_ELBO
is that you cannot subsample a local variable when it depends on a global variable that is enumerated in the model because it requires a common scale:This has been asked on the forum as well: https://forum.pyro.ai/t/enumeration-and-subsampling-expected-all-enumerated-sample-sites-to-share-common-poutine-scale/4938
A solution I'm proposing here is to perform plate-wise scaling inside the
partial_sum_product
by passing theplate_to_scale
dictionary. Then whenever a plate is reduced we can scale the factor: