diff --git a/kronfluence/module/tracked_module.py b/kronfluence/module/tracked_module.py index 85cd6f2..923c7b5 100644 --- a/kronfluence/module/tracked_module.py +++ b/kronfluence/module/tracked_module.py @@ -163,8 +163,8 @@ def forward(self, inputs: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tens The output of the forward pass. """ outputs = self.original_module(inputs, *args, **kwargs) - # if outputs.requires_grad: - # return outputs + if outputs.requires_grad and self.gradient_scale == 1.0: + return outputs return outputs + self._constant def prepare_storage(self, device: torch.device) -> None: