diff --git a/trojanvision/defenses/backdoor/model_inspection/deep_inspect.py b/trojanvision/defenses/backdoor/model_inspection/deep_inspect.py index d52ca074..a30114fc 100644 --- a/trojanvision/defenses/backdoor/model_inspection/deep_inspect.py +++ b/trojanvision/defenses/backdoor/model_inspection/deep_inspect.py @@ -45,10 +45,11 @@ def __init__(self, defense_remask_epoch: int = 20, defense_remask_lr=0.01, subset, _ = self.dataset.split_dataset(dataset, percent=sample_ratio) self.loader = self.dataset.get_dataloader(mode='train', dataset=subset) - def optimize_mark(self, label: int) -> tuple[torch.Tensor, float]: + def optimize_mark(self, label: int, **kwargs) -> tuple[torch.Tensor, float]: r""" Args: label (int): The class label to optimize. + **kwargs: Any keyword argument (unused). Returns: (torch.Tensor, torch.Tensor):