Tkurth/mplamb fixed (NVIDIA#1684)
azrael417 authored Jun 22, 2023
1 parent 2d8302a commit 30a7ad3
Showing 1 changed file with 8 additions and 5 deletions.
apex/optimizers/fused_mixed_precision_lamb.py: 13 changes (8 additions & 5 deletions)
```diff
@@ -12,22 +12,25 @@ def __init__(self, params, lr=1e-3, step=0, bias_correction=True,
                  amsgrad=False, adam_w_mode=True,
                  grad_averaging=True, max_grad_norm=1.0, use_nvlamb=False,
                  reduced_precision_dtype=None):
 
         if amsgrad:
             raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')
 
-        # The learning rate (lr) and optimizer step (step) should be located on device
-        # in order to faciliated device sync free execution
-
+        # init defaults
         defaults = dict(lr=torch.tensor(lr, dtype=torch.float32),
                         step=torch.tensor([step], dtype=torch.int),
                         bias_correction=bias_correction,
                         betas=betas, eps=eps, weight_decay=weight_decay,
                         grad_averaging=grad_averaging,
                         max_grad_norm=max_grad_norm)
-        tensor_state = ['lr', 'step']
-
+        # init base module
         super(FusedMixedPrecisionLamb, self).__init__(params, defaults)
+
+        # The learning rate (lr) and optimizer step (step) should be located on device
+        # in order to faciliated device sync free execution
+        device = self.param_groups[0]['params'][0].device
+
+        tensor_state = ['lr', 'step']
         for idx,group in enumerate(self.param_groups):
             for item in tensor_state:
                 self.param_groups[idx][item] = group[item].to(device=device)
```
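For context, a minimal usage sketch of the changed optimizer follows. It assumes `FusedMixedPrecisionLamb` is importable from `apex.optimizers`, that apex is built with its CUDA extensions, and that the parameters already live on the GPU; the in-place learning-rate update at the end is an illustrative pattern, not an API introduced by this commit. It demonstrates what the moved comment describes: after `super().__init__` populates `param_groups`, `lr` and `step` are stored as tensors on the parameters' device, so reading or updating them does not force a host-device synchronization.

```python
# Minimal sketch, not part of the commit: assumes apex with CUDA extensions is
# installed and that FusedMixedPrecisionLamb is exported by apex.optimizers.
import torch
from apex.optimizers import FusedMixedPrecisionLamb

model = torch.nn.Linear(1024, 1024).cuda()

# Constructor arguments mirror the signature shown in the hunk above.
optimizer = FusedMixedPrecisionLamb(model.parameters(), lr=1e-3, step=0)

# After __init__, 'lr' and 'step' are device tensors inside param_groups.
for group in optimizer.param_groups:
    print(group['lr'].device, group['step'].device)

# Hypothetical sync-free LR update: write the new value into the existing
# device tensor instead of rebinding group['lr'] to a Python float.
for group in optimizer.param_groups:
    group['lr'].fill_(5e-4)
```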
