Obtaining gradient of LRP output w.r.t. network parameters #183
Hey @MikiFER
Otherwise, you can try to use this proof of concept I quickly put together:

from itertools import islice

import torch
from torchvision.models import AlexNet

from zennit.core import BasicHook, ParamMod
from zennit.rules import Epsilon, Gamma, ZBox
from zennit.composites import EpsilonGammaBox
from zennit.attribution import Gradient
from zennit.types import Convolution


class ParamBasicHook(BasicHook):
    '''Hook to also get the relevance wrt. parameters.'''
    def backward(self, module, grad_input, grad_output):
        '''Backward hook to compute LRP based on the class attributes.'''
        original_input = self.stored_tensors['input'][0].clone()
        inputs = []
        outputs = []
        params = {key: [] for key, _ in module.named_parameters()}
        for in_mod, param_mod, out_mod in zip(self.input_modifiers, self.param_modifiers, self.output_modifiers):
            input = in_mod(original_input).requires_grad_()
            with ParamMod.ensure(param_mod)(module) as modified, torch.autograd.enable_grad():
                # remember the gradient state
                grad_states = [param.requires_grad for param in modified.parameters()]
                # require the gradients to compute the relevance
                for param in modified.parameters():
                    param.requires_grad_()
                output = modified.forward(input)
                output = out_mod(output)
                # keep track of the params for later gradient computation
                for key, param in modified.named_parameters():
                    params[key].append(param)
                # reset the gradient state
                for param, grad_state in zip(modified.parameters(), grad_states):
                    param.requires_grad = grad_state
            inputs.append(input)
            outputs.append(output)
        grad_outputs = self.gradient_mapper(grad_output[0], outputs)
        if isinstance(grad_outputs, torch.Tensor):
            grad_outputs = [grad_outputs]
        gradients = torch.autograd.grad(
            outputs * (1 + len(params)),
            inputs + sum(params.values(), []),
            grad_outputs=grad_outputs * (1 + len(params)),
            create_graph=grad_output[0].requires_grad
        )
        # split the flat gradient tuple into one group for the inputs and one group per parameter
        grad_groups = [list(islice(elem, len(inputs))) for elem in [iter(gradients)] * (1 + len(params))]
        relevance = self.reducer(inputs, grad_groups[0])
        # set the .grad of the original parameters to their relevance
        for (key, param), gradient in zip(params.items(), grad_groups[1:]):
            getattr(module, key).grad = self.reducer(param, gradient)
        return tuple(relevance if original.shape == relevance.shape else None for original in grad_input)

    @classmethod
    def inject(cls, hook_type):
        '''Create a subclass of hook_type and this class, injecting this class
        before BasicHook in order to give this class' backward a higher
        priority. May also be done manually with e.g.
        ``class EpsilonParam(Epsilon, ParamBasicHook, BasicHook): pass``.'''
        return type(f'{hook_type.__name__}Param', (hook_type, cls, BasicHook), {})


ZBoxParam = ParamBasicHook.inject(ZBox)
GammaParam = ParamBasicHook.inject(Gamma)
EpsilonParam = ParamBasicHook.inject(Epsilon)


def main():
    torch.manual_seed(0xdeadbeef)
    net = AlexNet().eval()

    layer_map = [
        (torch.nn.Linear, EpsilonParam(epsilon=1e-6)),
        (Convolution, GammaParam(gamma=0.25)),
    ]
    first_map = [
        (Convolution, ZBoxParam(low=-3., high=3.)),
    ]
    composite = EpsilonGammaBox(low=-3., high=3., layer_map=layer_map, first_map=first_map)

    data = torch.randn((1, 3, 224, 224))

    # not needed for LRP, only used here to compute the plain gradients for comparison
    for param in net.parameters():
        param.requires_grad = True
    net(data).sum().backward()
    weight_grad = net.features[0].weight.grad[:]
    for param in net.parameters():
        del param.grad
        param.requires_grad = False

    # compute LRP
    with Gradient(net, composite=composite) as attributor:
        out2, relevance = attributor(data)

    weight_relevance = net.features[0].weight.grad[:]
    # demonstrate that the gradient was modified
    print((weight_grad - weight_relevance).abs().sum())


if __name__ == '__main__':
    main()

Here, the trick is to create a new hook class for each rule via ParamBasicHook.inject, which places ParamBasicHook before BasicHook in the method resolution order, so that its backward, which additionally writes the parameter relevances to the parameters' .grad, takes precedence. @MaxH1996 may also be interested in this code.
Hi @chr5tphr, thank you for the response.
Hey @MikiFER, sorry for the confusion! Let me know in case you have issues with this, so we can try to figure it out together.
Thank you so much. Will try it and will get back to you if there are issues :)
Hi @chr5tphr, by diving a little deeper into the code I arrived at a question I cannot answer.
I am afraid that calling combined_loss.backward() will result in gradients for the canonized network, but I want to optimize the parameters of the "normal" network, that is, the batch-norm and the corresponding linear-layer parameters would never be optimized. Is there something that I am not understanding correctly?
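A rough sketch of the combined-loss step described above (not taken from the thread): the model, data, target, criterion and loss weight lam are placeholder assumptions, and whether double-backpropagation through the rule hooks behaves as intended should be verified for the concrete setup.

import torch
from zennit.composites import EpsilonPlusFlat
from zennit.canonizers import SequentialMergeBatchNorm

# placeholder model and data (assumptions, not from the thread)
net = torch.nn.Sequential(
    torch.nn.Linear(32, 32),
    torch.nn.BatchNorm1d(32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 10),
)
data = torch.randn((8, 32), requires_grad=True)
target = torch.randint(0, 10, (8,))
criterion = torch.nn.CrossEntropyLoss()
lam = 0.1

composite = EpsilonPlusFlat(canonizers=[SequentialMergeBatchNorm()])

with composite.context(net) as modified:
    output = modified(data)
    # relevance of the input, computed as the modified (LRP) gradient;
    # create_graph=True keeps it inside the autograd graph
    relevance, = torch.autograd.grad(output.sum(), data, create_graph=True)
    # combined objective: task loss plus a term on the explanation
    combined_loss = criterion(output, target) + lam * relevance.abs().sum()
    # the gradients should land on the original (non-canonized) parameters,
    # since the canonized parameters are computed from them
    combined_loss.backward()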
Hey @MikiFER, theoretically, this should not be a problem, as the canonized parameters should be computed from the original parameters in such a way that the gradient is the same. You can check it out by directly installing with pip:

pip install git+https://github.com/chr5tphr/zennit.git@canonizer-merge-batchnorm-gradfix

Let me know whether it works for you. Here's a proof-of-concept check:

import torch

from zennit.core import Composite
from zennit.canonizers import SequentialMergeBatchNorm


def main():
    torch.manual_seed(0xdeadbeef)
    net = torch.nn.Sequential(
        torch.nn.Linear(32, 32),
        torch.nn.BatchNorm1d(32),
    )
    weight = net[0].weight
    net.eval()
    net[1].running_mean += 1.
    net[1].running_var *= 3.

    canonizers = [
        SequentialMergeBatchNorm()
    ]
    composite = Composite(canonizers=canonizers)

    data = torch.randn((1, 32))
    weight.requires_grad = True

    out_base = net(data).sum()
    grad_base, = torch.autograd.grad(out_base, weight)

    with composite.context(net) as modified:
        out_canon = modified(data).sum()
        grad_canon, = torch.autograd.grad(out_canon, weight)

    # both the output and the gradient w.r.t. the original weight should match
    print((out_base - out_canon).abs().sum())
    print((grad_base - grad_canon).abs().sum())


if __name__ == '__main__':
    main()
Hi @chr5tphr, thank you so much, I will try it out and get back to you if there are any more issues.
Assuming there were no more issues, closing this for now after merging #185. Feel free to reopen once something pops up.
Hi @chr5tphr, I have a question regarding the explanation obtained when using the ResNetCanonizer in combination with the EpsilonPlusFlat composite. I noticed that the sum of attributions for the input image is not 1, even though, when using LRP with a starting relevance of 1 for the output layer, the sum of relevance in every layer should be 1. Here is a piece of code I used to replicate this behavior.
Am I not understanding something correctly, or is this an error?
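The referenced snippet is not shown above; a hypothetical reconstruction of such a conservation check could look as follows (ResNet-18, a single random input, and a starting relevance of 1 at output class 0 are assumptions):

import torch
from torchvision.models import resnet18
from zennit.composites import EpsilonPlusFlat
from zennit.torchvision import ResNetCanonizer
from zennit.attribution import Gradient

torch.manual_seed(0)
net = resnet18().eval()
data = torch.randn((1, 3, 224, 224))

composite = EpsilonPlusFlat(canonizers=[ResNetCanonizer()])

with Gradient(net, composite=composite) as attributor:
    # start the backward pass with a relevance of 1. at class 0
    output, relevance = attributor(data, torch.eye(1000)[[0]])

# if relevance were fully conserved, this would print 1.
print(relevance.sum().item())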
Hey @MikiFER, usually the attributions will not sum to one, unless you are certain that no attribution is lost to the bias, which you can do by passing

composite = EpsilonPlusFlat(canonizers=[canonizer], zero_params='bias')

While investigating your issue, I noticed that, although #185 increased the overall attribution stability within ResNet, it led to a negative attribution sum in the input (which can happen if some attribution is lost to biases in combination with skip-connections), for which I have opened #194. While at least for …
Hi @chr5tphr, thanks for the response. Also one unrelated question: have you tried pairing up your library with pytorch-lightning? I get some weird results when trying to use half-precision (fp16) training, where model inference results in NaNs when inside the composite context.
Hey @MikiFER, it's not the stability parameter, but the bias term which silently receives attribution: in the epsilon rule, the denominator includes not only the weighted inputs but also the bias, so part of the relevance is silently assigned to it. This lost relevance can be omitted by removing the bias term from the denominator, which is what zero_params='bias' does. There is, however, as you also pointed out, currently something wrong with the changes introduced by #185, and my investigation so far points to the ResNet canonizer. To have a better overview, feel free to create a new issue when the topics are not directly related.
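For reference, the epsilon rule in its common textbook form (notation assumed here, not quoted from the thread) reads

\[
    R_j = \sum_k \frac{a_j w_{jk}}{z_k + \epsilon \,\operatorname{sign}(z_k)} R_k,
    \qquad z_k = \sum_i a_i w_{ik} + b_k ,
\]

so the bias $b_k$ enters the denominator $z_k$ but never appears among the attributed terms $a_j w_{jk}$; the corresponding share of $R_k$ is therefore absorbed and the layer-wise sums shrink, which dropping $b_k$ via zero_params='bias' avoids.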
Hi, first of all thank you for all the hard work that was put into developing this framework and then making it available to everyone.
I was wondering if there is a way to obtain the gradient of an explanation obtained using LRP with respect to the network parameters, in order to optimize it.
I stumbled across your overview paper and would like to use the framework in my own EGL research.