From 988a39e4902a8bd10359da99333ab32adcc2a2a7 Mon Sep 17 00:00:00 2001
From: Jason Senthil
Date: Mon, 3 Feb 2025 11:11:34 -0800
Subject: [PATCH] move out zero grad logic into separate function (#969)

Summary:
# Context

Currently it isn't possible to log gradients from the AutoUnit because they are zeroed out before `on_train_step_end()` is reached.

# This Diff

Moves the zero-grad logic out of `_update_weights` and into its own `zero_grad()` method. It can be overridden, e.g.

```
class MyAutoUnit(AutoUnit):
    ...

    def zero_grad(self) -> None:
        self.logger.log(self.module.grad)
        super().zero_grad()
```

to log the gradients prior to zeroing them out.

Reviewed By: galrotem, diego-urgell

Differential Revision: D68983117
---
 torchtnt/framework/auto_unit.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py
index d4c7d636bb..6e361564e0 100644
--- a/torchtnt/framework/auto_unit.py
+++ b/torchtnt/framework/auto_unit.py
@@ -829,6 +829,22 @@ def step_lr_scheduler(self) -> None:
         """
         none_throws(self.lr_scheduler).step()
 
+    def zero_grad(self) -> None:
+        """
+        Zeroes the gradients of the module's parameters. Override this if you need to log the gradients before zeroing them.
+
+        Example of overriding:
+        class CustomAutoUnit(MyAutoUnit):
+            ...
+
+            def zero_grad(self):
+                # log before zeroing gradients
+                super().zero_grad()
+        """
+
+        optimizer = none_throws(self.optimizer)
+        optimizer.zero_grad(set_to_none=True)
+
     def _update_weights(self, state: State) -> Optional[torch.Tensor]:
         """
         Updates weights of the module, handles clip gradient norm, etc.
@@ -892,7 +908,7 @@ def _update_weights(self, state: State) -> Optional[torch.Tensor]:
             with get_timing_context(
                 state, f"{self.__class__.__name__}.optimizer_zero_grad"
             ):
-                optimizer.zero_grad(set_to_none=True)
+                self.zero_grad()
 
             if self.step_lr_interval == "step":
                 self._update_lr_and_swa(state, self.train_progress.num_steps_completed)
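
---

For reference, here is a minimal sketch (not part of the patch) of how a user-defined subclass might use the new `zero_grad()` hook to log the global gradient norm before the gradients are cleared. The class name `GradLoggingAutoUnit` and the use of `print` for logging are illustrative assumptions; a real unit would also implement AutoUnit's abstract methods (e.g. `compute_loss`, `configure_optimizers_and_lr_scheduler`) and would likely report through its own logger.

```
import torch

from torchtnt.framework.auto_unit import AutoUnit


class GradLoggingAutoUnit(AutoUnit):
    # Only the zero_grad override is shown; the abstract AutoUnit methods
    # are assumed to be implemented elsewhere.

    def zero_grad(self) -> None:
        # Gather all parameter gradients that are currently populated.
        grads = [
            p.grad.detach().flatten()
            for p in self.module.parameters()
            if p.grad is not None
        ]
        if grads:
            # Global L2 norm of the gradients, logged before they are cleared.
            total_norm = torch.linalg.vector_norm(torch.cat(grads))
            print(f"grad norm before zeroing: {total_norm.item():.4f}")
        # Delegate to AutoUnit.zero_grad(), which zeroes the optimizer's gradients.
        super().zero_grad()
```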