
Are customized loss functions/layers supported? #119

Open
zzpustc opened this issue Dec 22, 2022 · 2 comments

zzpustc commented Dec 22, 2022
Hi!

Thanks for the great work. I'm trying to use the Laplace approximation in my own project, but your package only supports the MSE and cross-entropy losses, and only layers that are standard nn.Module subclasses. Is there any way to use your package with customized loss functions/layers?

Best

aleximmer (Owner) commented

Hi,

Thanks for your interest. Unfortunately, only basic nn.Module layers are supported, since that is what allows us to compute Hessian approximations (see, for example, the extensions that were necessary in ASDL). The same applies to losses.

However, in some cases it is not that complicated to extend the corresponding backend, so if you have a specific use case, we can try to suggest how to get it done with the help of the library, if possible.

wiseodd (Collaborator) commented Mar 12, 2024

The CurvlinopsGGN/EF backends with the diagonal Hessian structure should be able to handle non-standard layers (as long as their parameters are included in model.parameters()), since they are implemented purely with torch.func; see the rough sketch below.
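
For instance, something along these lines should work. This is an untested sketch: it assumes CurvlinopsGGN is importable from laplace.curvature, that the usual Laplace(...) factory accepts the backend keyword, and train_loader is a placeholder for a standard PyTorch DataLoader.

import torch
import torch.nn as nn
from laplace import Laplace
from laplace.curvature import CurvlinopsGGN

class CustomBlock(nn.Module):
    # Not a standard layer: the forward pass is hand-written and the weights are
    # plain nn.Parameters, but they still show up in model.parameters().
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W = nn.Parameter(0.1 * torch.randn(d_out, d_in))
        self.b = nn.Parameter(torch.zeros(d_out))

    def forward(self, x):
        return torch.tanh(x @ self.W.T + self.b)

model = nn.Sequential(CustomBlock(10, 20), nn.Linear(20, 2))

la = Laplace(model, likelihood='classification',
             subset_of_weights='all', hessian_structure='diag',
             backend=CurvlinopsGGN)
la.fit(train_loader)  # train_loader: your usual DataLoader of (x, y) batches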

For custom loss functions, the requirement is that they correspond to a log-likelihood. Then we need to know how to sample from it, or how to compute the second derivative w.r.t. the network's output. E.g., here:

def _get_mc_functional_fisher(self, f):
    """Approximate the Fisher's middle matrix (expected outer product of the functional
    gradient) using an MC integral with `self.num_samples` many samples.
    """
    F = 0

    for _ in range(self.num_samples):
        if self.likelihood == 'regression':
            y_sample = f + torch.randn(f.shape, device=f.device)  # N(y | f, 1)
            grad_sample = f - y_sample  # functional MSE grad
        else:  # classification with softmax
            y_sample = torch.distributions.Multinomial(logits=f).sample()
            # First functional derivative of the loglik is p - y
            p = torch.softmax(f, dim=-1)
            grad_sample = p - y_sample

        F += 1 / self.num_samples * torch.einsum('bc,bk->bck', grad_sample, grad_sample)

    return F

We plan to support more likelihoods, e.g. BCE (#130), after milestone 0.2.
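
For a Bernoulli (BCE) likelihood, the same recipe would look roughly like the standalone sketch below. This is only an illustration, not the library's implementation: the function name mc_functional_fisher_bernoulli is made up, and f is assumed to be the logits with shape (batch, 1).

import torch

def mc_functional_fisher_bernoulli(f, num_samples=10):
    """MC estimate of the functional Fisher for a Bernoulli likelihood (sketch only)."""
    F = 0
    p = torch.sigmoid(f)  # Bernoulli mean, i.e. the predicted probability of class 1
    for _ in range(num_samples):
        y_sample = torch.distributions.Bernoulli(probs=p).sample()
        # Functional gradient of the negative log-lik (BCE) w.r.t. the logit f is p - y
        grad_sample = p - y_sample
        F += 1 / num_samples * torch.einsum('bc,bk->bck', grad_sample, grad_sample)
    return F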

wiseodd added this to the 0.3 milestone Jul 8, 2024