Skip to content

Commit

Permalink
Nesterov (pytorch#3232)
Browse files Browse the repository at this point in the history
Summary:

X-link: facebookresearch/FBGEMM#330

Use `step_mode` in `ensemble_and_swap` to cover a few special cases:
      step_mode=0: embedding scaling
      step_mode=1: Nesterov accelerated gradient
      step_mode=2: pure EMA (compatible with the previous diff)

Reviewed By: q10

Differential Revision: D63875074
  • Loading branch information
minhua-chen authored and facebook-github-bot committed Oct 10, 2024
1 parent 8e7beba commit 4847322
Showing 1 changed file with 24 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1937,31 +1937,40 @@ def forward( # noqa: C901
raise ValueError(f"Invalid OptimType: {self.optimizer}")

def ensemble_and_swap(self, ensemble_mode: Dict[str, float]) -> None:
    """
    Perform ensemble and swap operations on the full sparse embedding tables.

    For every table, the second optimizer state tensor (``states[i][1]``)
    is used as the ensembled copy of the embedding weights.

    Args:
        ensemble_mode: Configuration values read by this method (each
            ``step_*`` value except ``step_ema_coef`` is cast to ``int``):
            "step_ema": the EMA step runs when the current iteration
                (``self.iter``) is a multiple of this value.
            "step_swap": the swap step runs when the current iteration is
                a multiple of this value.
            "step_start": threshold/period deciding when the EMA is reset
                (see "step_mode"). Only read on EMA steps.
            "step_ema_coef": EMA coefficient applied to the state when the
                EMA is not being reset. Only read on EMA steps.
            "step_mode": only read on EMA steps.
                0 = embedding scaling: reset when iteration <= step_start
                    or iteration is a multiple of step_start; the state is
                    zeroed after every EMA step.
                1 = nesterov: reset when iteration is a multiple of
                    step_start; after every EMA step the state is
                    overwritten with the weights captured before the swap.
                2 = pure ema: reset when iteration <= step_start; no
                    post-processing.
                Any other value uses the mode-0 reset rule and performs no
                post-processing.

    Returns:
        Sparse embedding weights and optimizer states will be updated in-place.

    Raises:
        ZeroDivisionError: if "step_ema" or "step_swap" casts to 0, or if
            "step_start" casts to 0 on an EMA step where the selected mode
            evaluates ``iteration % step_start``.
    """
    # Read the iteration counter once instead of calling .item() repeatedly.
    iteration = self.iter.item()
    should_ema = iteration % int(ensemble_mode["step_ema"]) == 0
    should_swap = iteration % int(ensemble_mode["step_swap"]) == 0
    if not (should_ema or should_swap):
        return

    weights = self.split_embedding_weights()
    states = self.split_optimizer_states()

    # The mode, the reset decision and the EMA coefficient do not depend on
    # the table, so they are computed once rather than per loop iteration.
    step_mode = 0
    coef_ema = 0.0
    if should_ema:
        step_mode = int(ensemble_mode["step_mode"])
        step_start = int(ensemble_mode["step_start"])
        if step_mode == 1:
            should_ema_reset = iteration % step_start == 0
        elif step_mode == 2:
            should_ema_reset = iteration <= step_start
        else:
            should_ema_reset = (iteration <= step_start) or (
                iteration % step_start == 0
            )
        # A coefficient of 0.0 makes the lerp below replace the state with
        # the current weights, i.e. it restarts the average.
        coef_ema = 0.0 if should_ema_reset else ensemble_mode["step_ema_coef"]

    for i in range(len(self.embedding_specs)):
        ema_state = states[i][1]
        # 1) ema step: state <- coef_ema * state + (1 - coef_ema) * weights
        if should_ema:
            # 0) bring the weights to the state's dtype/device (original
            # comment: "copying weights from gpu to cpu"). This must happen
            # before the swap so the nesterov branch sees pre-swap weights.
            # NOTE(review): Tensor.to may return the same tensor when dtype
            # and device already match, in which case this aliases
            # weights[i] instead of snapshotting it - confirm this cannot
            # happen in the configurations that use step_mode=1.
            weights_cpu = weights[i].to(
                dtype=ema_state.dtype, device=ema_state.device
            )
            ema_state.lerp_(weights_cpu, 1.0 - coef_ema)
        # 2) swap step: overwrite the live weights with the ensembled state
        if should_swap:
            weights[i].copy_(ema_state, non_blocking=True)
        # 3) post-processing step
        if should_ema:
            if step_mode == 0:  # embedding scaling
                ema_state.mul_(0.0)
            elif step_mode == 1:  # nesterov
                ema_state.copy_(weights_cpu, non_blocking=True)
            # step_mode == 2 (pure ema): state is left untouched

def reset_uvm_cache_stats(self) -> None:
assert (
Expand Down

0 comments on commit 4847322

Please sign in to comment.