Merge branch 'ClashLuke-patch-1'
rwightman committed Jan 31, 2025
2 parents 8b3c07a + 490d222 · commit a49b020
Showing 1 changed file with 12 additions and 22 deletions.
34 changes: 12 additions & 22 deletions timm/optim/kron.py
@@ -157,11 +157,7 @@ def __init__(
         self._param_exprs = {}  # cache for einsum expr
         self._tiny = torch.finfo(torch.bfloat16).tiny
         self.rng = random.Random(1337)
-        if deterministic:
-            # Use a Generator to try to be more deterministic across resume (save/load)
-            self.torch_rng = torch.Generator().manual_seed(1337)
-        else:
-            self.torch_rng = None
+        self.deterministic = deterministic
 
         # make compile optional (for bwd compat)
         if has_dynamo:
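
The deleted branch above persisted a CPU `torch.Generator` for the optimizer's lifetime; the replacement keeps only the `deterministic` flag, leaving the picklable Python `random.Random` as the sole RNG state that must survive checkpoints. A minimal sketch (plain PyTorch, not timm code) of why the persisted generator was awkward:

    import torch

    cpu_gen = torch.Generator().manual_seed(1337)
    state = cpu_gen.get_state()  # ByteTensor snapshot in CPU format, not portable to CUDA generators

    # A CPU generator cannot drive sampling on another device; the line below
    # would raise a device-mismatch RuntimeError, which is why the old code
    # sampled on CPU and copied the result over:
    # torch.randn(4, generator=cpu_gen, device='cuda')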
@@ -178,7 +174,6 @@ def __init__(
     def __getstate__(self):
         _dict = super().__getstate__()
         _dict["rng"] = self.rng
-        _dict["torch_rng"] = self.torch_rng
         return _dict
 
     def state_dict(self) -> Dict[str, Any]:
@@ -187,28 +182,21 @@ def state_dict(self) -> Dict[str, Any]:
 
         # Add the generator state
         optimizer_state['rng_state'] = self.rng.getstate()
-        if self.torch_rng is not None:
-            optimizer_state['torch_rng_state'] = self.torch_rng.get_state()
-
         return optimizer_state
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         # Extract and remove the RNG state from the state dict
         rng_states = {}
         if 'rng_state' in state_dict:
             rng_states['rng_state'] = state_dict.pop('rng_state')
-        if 'torch_rng_state' in state_dict:
-            rng_states['torch_rng_state'] = state_dict.pop('torch_rng_state')
-
+
         # Load the optimizer state
         super().load_state_dict(state_dict)
         state_dict.update(rng_states)  # add back
 
         # Restore the RNG state if it exists
         if 'rng_state' in rng_states:
             self.rng.setstate(rng_states['rng_state'])
-        if 'torch_rng_state' in rng_states:
-            self.torch_rng.set_state(rng_states['torch_rng_state'])
 
     def __setstate__(self, state):
         super().__setstate__(state)
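
The surviving lines round-trip only the Python RNG state through `state_dict` / `load_state_dict`. A standalone sketch of that pattern (illustrative only, not the timm API):

    import random

    rng = random.Random(1337)
    saved = rng.getstate()  # what state_dict() stores under 'rng_state'
    before = [rng.randint(0, 2 ** 31) for _ in range(3)]

    resumed = random.Random()
    resumed.setstate(saved)  # what load_state_dict() restores
    after = [resumed.randint(0, 2 ** 31) for _ in range(3)]
    assert before == after  # the per-step seed sequence resumes exactly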
@@ -317,15 +305,17 @@ def step(self, closure=None):
                 if do_update:
                     exprA, exprGs, _ = exprs
                     Q = state["Q"]
-                    if self.torch_rng is None:
-                        V = torch.randn_like(debiased_momentum, dtype=precond_dtype)
+                    if self.deterministic:
+                        torch_rng = torch.Generator(device=debiased_momentum.device)
+                        torch_rng.manual_seed(self.rng.randint(0, 2 ** 31))
                     else:
-                        # Restoring generator state to device is messy. For now,
-                        # we keep RNG on CPU, but this slows the optimizer down quite a bit.
-                        # FIXME Need a better approach
-                        V = torch.randn(
-                            debiased_momentum.shape, generator=self.torch_rng, dtype=precond_dtype, device='cpu')
-                        V = V.to(debiased_momentum.device)
+                        torch_rng = None
+                    V = torch.randn(
+                        debiased_momentum.shape,
+                        generator=torch_rng,
+                        dtype=precond_dtype,
+                        device=debiased_momentum.device,
+                    )
                     G = debiased_momentum if momentum_into_precond_update else grad
 
                     A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)
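The new path seeds a fresh device-local generator from `self.rng` on every preconditioner update, so the probe noise V is reproducible across save/load without serializing any torch generator state, and it is sampled directly on the parameter's device instead of on CPU. A self-contained sketch of the pattern (`sample_noise` is a hypothetical helper, not part of timm):

    import random

    import torch

    rng = random.Random(1337)

    def sample_noise(shape, dtype, device, deterministic=True):
        if deterministic:
            gen = torch.Generator(device=device)      # generator lives on the target device
            gen.manual_seed(rng.randint(0, 2 ** 31))  # fresh per-call seed from the Python RNG
        else:
            gen = None  # fall back to the global RNG
        return torch.randn(shape, generator=gen, dtype=dtype, device=device)

    v1 = sample_noise((4, 4), torch.float32, 'cpu')
    rng.seed(1337)  # simulate a resume that restored the same Python RNG state
    v2 = sample_noise((4, 4), torch.float32, 'cpu')
    assert torch.equal(v1, v2)  # same seeds, identical noise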
