Commit

Fix comment, add 'stochastic weight decay' idea because why not
rwightman committed Jan 30, 2025
1 parent 5940cc1 commit 5f85f8e
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions timm/optim/kron.py
@@ -95,7 +95,9 @@ class Kron(torch.optim.Optimizer):
         precond_dtype: Dtype of the preconditioner.
         decoupled_decay: AdamW style decoupled weight decay
         flatten: Flatten dimensions instead of fully relying on expressions for higher rank params
-        flatten_start_end: Range of dimensions to flatten, defaults to (2, -1).
+        flatten_start_dim: Start of flatten range, defaults to 2. Seems a good tradeoff for ConvNets.
+        flatten_end_dim: End of flatten range, defaults to -1.
+        stochastic_weight_decay: Enable random modulation of weight decay
         deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
     """

@@ -118,6 +120,7 @@ def __init__(
             flatten: bool = False,
             flatten_start_dim: int = 2,
             flatten_end_dim: int = -1,
+            stochastic_weight_decay: bool = False,
             deterministic: bool = False,
     ):
         if not has_opt_einsum:
@@ -147,6 +150,7 @@ def __init__(
             flatten=flatten,
             flatten_start_dim=flatten_start_dim,
             flatten_end_dim=flatten_end_dim,
+            stochastic_weight_decay=stochastic_weight_decay,
         )
         super(Kron, self).__init__(params, defaults)
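
A minimal usage sketch for the new flag (editor's illustration, not from the repo; the lr/weight_decay values are arbitrary and the rest of the constructor is assumed to keep the defaults shown in the file):

import torch
from timm.optim.kron import Kron

model = torch.nn.Linear(10, 10)
optimizer = Kron(
    model.parameters(),
    lr=3e-4,                       # arbitrary value for illustration
    weight_decay=1e-2,             # arbitrary value for illustration
    decoupled_decay=True,          # AdamW-style decay, per the docstring above
    stochastic_weight_decay=True,  # option added in this commit
)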

@@ -353,11 +357,15 @@ def step(self, closure=None):
                 pre_grad = pre_grad.view(p.shape)

                 # Apply weight decay
-                if group["weight_decay"] != 0:
+                weight_decay = group["weight_decay"]
+                if weight_decay != 0:
+                    if group["stochastic_weight_decay"]:
+                        weight_decay = 2 * self.rng.random() * weight_decay
+
                     if group["decoupled_decay"]:
-                        p.mul_(1. - group["lr"] * group["weight_decay"])
+                        p.mul_(1. - group["lr"] * weight_decay)
                     else:
-                        pre_grad.add_(p, alpha=group["weight_decay"])
+                        pre_grad.add_(p, alpha=weight_decay)

                 # Update parameters
                 p.add_(pre_grad, alpha=-group["lr"])
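
The modulation above scales the configured decay by a uniform draw in [0, 2), so the per-step decay magnitude varies while its expectation stays at the configured value. A standalone sketch of the sampling (editor's illustration, using stdlib random in place of self.rng):

import random

rng = random.Random(0)
weight_decay = 1e-2
samples = [2 * rng.random() * weight_decay for _ in range(100_000)]
print(min(samples), max(samples))   # spread over [0, 2 * weight_decay)
print(sum(samples) / len(samples))  # mean stays close to weight_decay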
