From 4847322f4f5bc86e513b861d2bcdd70bcd889f94 Mon Sep 17 00:00:00 2001 From: Minhua Chen Date: Thu, 10 Oct 2024 01:53:13 -0700 Subject: [PATCH] Nesterov (#3232) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/330 Use step_mode to cover a few special cases: step_mode=0: embedding scaling; step_mode=1: Nesterov accelerated gradient; step_mode=2: pure EMA (compatible with the previous diff). Reviewed By: q10 Differential Revision: D63875074 --- ...t_table_batched_embeddings_ops_training.py | 39 ++++++++++++------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 1730dbedc..f3dcc0cae 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -1937,31 +1937,40 @@ def forward( # noqa: C901 raise ValueError(f"Invalid OptimType: {self.optimizer}") def ensemble_and_swap(self, ensemble_mode: Dict[str, float]) -> None: + """ + Perform ensemble and swap operations on the full sparse embedding tables. + + Returns: + Sparse embedding weights and optimizer states will be updated in-place. 
+ """ should_ema = self.iter.item() % int(ensemble_mode["step_ema"]) == 0 should_swap = self.iter.item() % int(ensemble_mode["step_swap"]) == 0 if should_ema or should_swap: weights = self.split_embedding_weights() states = self.split_optimizer_states() + coef_ema = ( + 0.0 + if self.iter.item() <= int(ensemble_mode["step_start"]) + else ensemble_mode["step_ema_coef"] + ) for i in range(len(self.embedding_specs)): + # 0) copying weights from gpu to cpu + weights_cpu = weights[i].to( + dtype=states[i][1].dtype, device=states[i][1].device + ) + # 1) ema step if should_ema: - step_start = int(ensemble_mode["step_start"]) - if int(ensemble_mode["step_mode"]) == 1: - should_ema_reset = self.iter.item() % step_start == 0 - elif int(ensemble_mode["step_mode"]) == 2: - should_ema_reset = self.iter.item() <= step_start - else: - should_ema_reset = (self.iter.item() <= step_start) or ( - self.iter.item() % step_start == 0 - ) - coef_ema = ( - 0.0 if should_ema_reset else ensemble_mode["step_ema_coef"] - ) - weights_cpu = weights[i].to( - dtype=states[i][1].dtype, device=states[i][1].device - ) states[i][1].lerp_(weights_cpu, 1.0 - coef_ema) + # 2) swap step if should_swap: weights[i].copy_(states[i][1], non_blocking=True) + # 3) post-processing step + if should_ema: + if int(ensemble_mode["step_mode"]) == 0: # embedding scaling + states[i][1].mul_(0.0) + elif int(ensemble_mode["step_mode"]) == 1: # nesterov + states[i][1].copy_(weights_cpu, non_blocking=True) + # elif int(ensemble_mode["step_mode"]) == 2: pure ema def reset_uvm_cache_stats(self) -> None: assert (