Recognizer: Incorporate EMA (#922)
Summary:

Add EMA to the recognizer:
- Separate learning rate scheduler updates from EMA model updates: in d2go, the EMA weights were updated every step, while the scheduler was updated every epoch. To get the same behavior in Vizard, we decouple the two and override `on_train_step_end` so that the EMA weights are updated every step (irrespective of other parameters).
- Update torchtnt `auto_unit` to pass `self.device` to the EMA / SWA model instead of the `device` argument. `self.device` may be set from the environment in the superclass init, so this enables model evaluation on GPU.
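
For illustration, the update cadence described in the first bullet can be sketched in plain Python. This is a toy under stated assumptions, not the Vizard or torchtnt implementation: `ToyTrainer`, `ema_update`, `on_train_epoch_end`, and the decay and learning-rate values are hypothetical; only the `on_train_step_end` hook name comes from the summary above.

```python
from typing import List


def ema_update(ema: List[float], weights: List[float], decay: float) -> List[float]:
    """Standard EMA rule: ema <- decay * ema + (1 - decay) * weights."""
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema, weights)]


class ToyTrainer:
    """Hypothetical trainer: EMA is updated per step, the scheduler per epoch."""

    def __init__(self, weights: List[float], decay: float = 0.9,
                 steps_per_epoch: int = 4, lr: float = 0.1,
                 lr_gamma: float = 0.5) -> None:
        self.weights = list(weights)
        self.ema = list(weights)  # EMA copy starts from the current weights
        self.decay = decay
        self.steps_per_epoch = steps_per_epoch
        self.lr = lr
        self.lr_gamma = lr_gamma
        self.ema_updates = 0
        self.scheduler_steps = 0

    def on_train_step_end(self) -> None:
        # Runs after every optimizer step, independent of the scheduler.
        self.ema = ema_update(self.ema, self.weights, self.decay)
        self.ema_updates += 1

    def on_train_epoch_end(self) -> None:
        # Hypothetical hook name: the scheduler only advances once per epoch.
        self.lr *= self.lr_gamma
        self.scheduler_steps += 1

    def fit(self, epochs: int) -> None:
        for _ in range(epochs):
            for _ in range(self.steps_per_epoch):
                # Stand-in for a real optimizer step.
                self.weights = [w - self.lr for w in self.weights]
                self.on_train_step_end()
            self.on_train_epoch_end()


trainer = ToyTrainer([1.0])
trainer.fit(epochs=2)
print(trainer.ema_updates, trainer.scheduler_steps)  # prints: 8 2
```

Keeping the two hooks separate means the EMA cadence does not change if the scheduler interval is later reconfigured.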

Differential Revision: D64206735
Victor Bourgin authored and facebook-github-bot committed Oct 11, 2024
1 parent 1beb1f0 commit 6d0c078
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchtnt/framework/auto_unit.py
@@ -512,7 +512,7 @@ def __init__(

             self.swa_model = AveragedModel(
                 module_for_swa,
-                device=device,
+                device=self.device,
                 use_buffers=swa_params.use_buffers,
                 averaging_method=swa_params.averaging_method,
                 ema_decay=swa_params.ema_decay,
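
As a rough sketch of why this one-line change matters, the toy classes below mimic a superclass that resolves its device from the environment when the caller passes none. Everything here is an assumption made for illustration: `SketchBase`, `SketchAveraged`, `SketchUnit`, and the `SKETCH_DEVICE` variable are hypothetical, and plain strings stand in for real device objects. It is not the torchtnt code.

```python
import os
from typing import Optional


class SketchBase:
    """Hypothetical superclass: falls back to the environment for its device."""

    def __init__(self, device: Optional[str] = None) -> None:
        # Assumed behavior: resolve the device when the argument is None.
        self.device = device or os.environ.get("SKETCH_DEVICE", "cpu")


class SketchAveraged:
    """Hypothetical stand-in for an averaged (EMA / SWA) model wrapper."""

    def __init__(self, device: Optional[str]) -> None:
        self.device = device


class SketchUnit(SketchBase):
    def __init__(self, device: Optional[str] = None) -> None:
        super().__init__(device)
        # Before the change: the raw argument is forwarded and may be None.
        self.before_fix = SketchAveraged(device)
        # After the change: the resolved attribute is forwarded.
        self.after_fix = SketchAveraged(self.device)


os.environ["SKETCH_DEVICE"] = "cuda:0"
unit = SketchUnit()
print(unit.before_fix.device, unit.after_fix.device)  # prints: None cuda:0
```

In this sketch, the averaged model only ends up on the same device as the unit when the resolved attribute is passed through.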