Test this function and call it within the SmoothMPRT metric:
def explain_smooth_batch_torch(
    self,
    model: ModelInterface,
    x_batch: np.ndarray,
    y_batch: np.ndarray,
    std: float,
    **kwargs,
) -> np.ndarray:
    """
    Compute explanations, normalise and take the absolute value (if configured so during metric initialisation).

    This method should primarily be used if you need to generate additional explanations in the metric's body.
    It encapsulates the pre- and postprocessing approach typical for Quantus.
    It will do a few things:
        - call model.shape_input (if ModelInterface instance was provided)
        - unwrap model (if ModelInterface instance was provided)
        - call explain_func
        - expand attribution channel

    Parameters
    -------
    model:
        A model that is subject to explanation.
    x_batch:
        A np.ndarray which contains the input data that are explained.
    y_batch:
        A np.ndarray which contains the output labels that are explained.
    std: float
        Standard deviation of the Gaussian noise.
    kwargs: optional, dict
        List of hyperparameters.

    Returns
    -------
    a_batch:
        Batch of explanations ready to be evaluated.
    """
    if not isinstance(x_batch, torch.Tensor):
        x_batch = torch.Tensor(x_batch).to(self.device)
    if not isinstance(y_batch, torch.Tensor):
        y_batch = torch.as_tensor(y_batch).to(self.device)

    # Start from None so the first explanation sets the shape and type of the
    # running mean (the original torch.zeros_like(x_batch) initialisation made
    # the `is None` branch below unreachable).
    a_batch_smooth = None
    for n in range(self.nr_samples):
        # The last epsilon is defined as zero to compute the true output,
        # and to have SmoothGrad with n_iter = 1 be equivalent to the plain gradient.
        if n == self.nr_samples - 1:
            epsilon = torch.zeros_like(x_batch)
        else:
            epsilon = torch.randn_like(x_batch) * std

        a_batch = quantus.explain(model, x_batch + epsilon, y_batch, **kwargs)

        if a_batch_smooth is None:
            a_batch_smooth = a_batch / self.nr_samples
        else:
            a_batch_smooth += a_batch / self.nr_samples

    return a_batch_smooth
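To make the "test" part of this task concrete, here is a minimal standalone sketch of the averaging loop above, together with the property it is meant to satisfy: with a single sample (or zero noise) the result should equal the plain explanation. Everything in it is an assumption for illustration only: `gradient_explain` is a hypothetical stand-in for `quantus.explain`, `smooth_explain` is a free-function version of the method, and the toy model and data are made up.

```python
import torch


def gradient_explain(model, x, y):
    """Hypothetical stand-in for quantus.explain: plain input gradients."""
    x = x.clone().detach().requires_grad_(True)
    logits = model(x)
    # Pick the target-class logit of every sample and sum them, so that a
    # single backward pass yields the per-sample input gradients.
    selected = logits.gather(1, y.view(-1, 1)).sum()
    (grad,) = torch.autograd.grad(selected, x)
    return grad


def smooth_explain(model, x, y, std, nr_samples, explain_func=gradient_explain):
    """Average explanations over noisy copies of x; the last copy is noise-free."""
    a_smooth = None
    for n in range(nr_samples):
        if n == nr_samples - 1:
            # Zero noise on the last iteration, as in the method above.
            epsilon = torch.zeros_like(x)
        else:
            epsilon = torch.randn_like(x) * std
        a = explain_func(model, x + epsilon, y)
        a_smooth = a / nr_samples if a_smooth is None else a_smooth + a / nr_samples
    return a_smooth


# Toy model and data (assumed, for illustration only).
torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(4, 8), torch.nn.Tanh(), torch.nn.Linear(8, 3)
)
x = torch.randn(5, 4)
y = torch.tensor([0, 1, 2, 0, 1])

plain = gradient_explain(model, x, y)
single = smooth_explain(model, x, y, std=0.5, nr_samples=1)
smoothed = smooth_explain(model, x, y, std=0.5, nr_samples=10)
no_noise = smooth_explain(model, x, y, std=0.0, nr_samples=4)
```

A unit test for the real method could assert the same two invariants (`nr_samples=1` and `std=0.0` both reproduce the unsmoothed explanation) and that the output shape matches the input batch.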
Minimum acceptance criteria
Specify what is necessary for the issue to be closed.
@mentions of the person that is apt to review these changes e.g., @annahedstroem
annahedstroem changed the title from "Add torch implementation to SmoothMPRT" to "Make SmoothMPRT faster for torch (see implementation)" on Nov 24, 2023