Hi,
Thank you very much for this implementation!
I noticed that the only allowed loss is `CrossEntropyLoss`. As I'd need your module to compute Influence Functions on other kinds of losses, how about adding the possibility to use a different loss, perhaps passed as a parameter in the `config` dict?

I was thinking about the following changes (to avoid disruptive changes to the API):
`pytorch_influence_functions/pytorch_influence_functions/influence_function.py`, lines 8 to 9 in 4df5d2e
to
And then change
`pytorch_influence_functions/pytorch_influence_functions/influence_function.py`, line 45 in 4df5d2e
to
With something analogous for:
`pytorch_influence_functions/pytorch_influence_functions/influence_function.py`, line 75 in 4df5d2e
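To make the idea concrete, here is a rough, framework-free sketch of the default-argument approach. It assumes a `loss_fn` keyword that defaults to `None` and falls back to cross-entropy, so existing calls keep their current behaviour. `calc_loss`, `get_loss_fn`, `mse` and the `"loss_fn"` config key are hypothetical names used only for illustration, and plain Python lists stand in for PyTorch tensors.

```python
import math
from typing import Any, Callable, Mapping, Optional, Sequence

# Hypothetical loss signature: (model outputs, target) -> scalar loss.
LossFn = Callable[[Sequence[float], Any], float]


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Cross-entropy for one sample: logsumexp(logits) - logits[target]."""
    m = max(logits)  # shift by the max so exp() cannot overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def mse(outputs: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error, an example of a non-default loss."""
    return sum((o - t) ** 2 for o, t in zip(outputs, target)) / len(outputs)


def calc_loss(y: Sequence[float], t: Any, loss_fn: Optional[LossFn] = None) -> float:
    """Compute the loss; loss_fn=None keeps the current cross-entropy behaviour."""
    if loss_fn is None:
        loss_fn = cross_entropy
    return loss_fn(y, t)


def get_loss_fn(config: Mapping[str, Any]) -> Optional[LossFn]:
    """Read an optional 'loss_fn' entry from the config dict (None if absent)."""
    return config.get("loss_fn")
```

For example, `calc_loss([0.0, 0.0], 1)` returns log 2 (about 0.693) with the cross-entropy default, while `calc_loss([1.0, 2.0], [0.0, 2.0], get_loss_fn({"loss_fn": mse}))` returns 0.5. Because the new parameter is optional, nobody who calls the current API needs to change anything.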
It would obviously be necessary to propagate the `loss_fn` parameter to the functions that call `grad_z` and `s_test`. I'd be glad to create a new branch and open a pull request with this change!
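The propagation step could look roughly like the dependency-free sketch below. Only the names `grad_z` and `loss_fn` come from this issue; the signatures are not the module's real ones. `linear_model` is a toy stand-in for the network, gradients are taken by central finite differences instead of PyTorch autograd, and `compute_train_grads` is a hypothetical caller that reads `loss_fn` from the config once and passes it down.

```python
import math
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

Matrix = List[List[float]]
LossFn = Callable[[List[float], Any], float]


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Default loss, mirroring the module's current CrossEntropyLoss."""
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits)) - logits[target]


def linear_model(w: Matrix, x: Sequence[float]) -> List[float]:
    """Toy stand-in for the network: one logit per row of w."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def grad_z(x: Sequence[float], t: Any, w: Matrix,
           loss_fn: Optional[LossFn] = None, eps: float = 1e-6) -> Matrix:
    """Gradient of loss_fn(model(x), t) w.r.t. w by central differences.

    Finite differences replace autograd only to keep the sketch dependency-free.
    """
    loss_fn = loss_fn or cross_entropy  # None -> existing cross-entropy behaviour
    grads: Matrix = []
    for r, row in enumerate(w):
        row_grads = []
        for c in range(len(row)):
            w_plus = [list(wr) for wr in w]
            w_minus = [list(wr) for wr in w]
            w_plus[r][c] += eps
            w_minus[r][c] -= eps
            diff = (loss_fn(linear_model(w_plus, x), t)
                    - loss_fn(linear_model(w_minus, x), t))
            row_grads.append(diff / (2 * eps))
        grads.append(row_grads)
    return grads


def compute_train_grads(config: Dict[str, Any],
                        samples: Sequence[Tuple[Sequence[float], Any]],
                        w: Matrix) -> List[Matrix]:
    """Hypothetical caller of grad_z that forwards the configured loss."""
    loss_fn = config.get("loss_fn")  # absent key -> None -> cross-entropy
    return [grad_z(x, t, w, loss_fn=loss_fn) for x, t in samples]
```

Resolving the loss once in the caller and passing it down explicitly keeps `grad_z` (and, analogously, `s_test`) free of any config lookups. It also means the default only has to be defined in one place.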
Thanks :)