From d3ede552aa15233760cdd39989f4bf009a5638b3 Mon Sep 17 00:00:00 2001 From: Minhua Chen Date: Thu, 10 Oct 2024 01:01:52 -0700 Subject: [PATCH] Nesterov (#3232) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/330 using step_mode to cover a few special cases: step_mode=0: embedding scaling step_mode=1: nesterov accelerated gradient step_mode=2: pure ema (compatible with previous diff) Reviewed By: q10 Differential Revision: D63875074 --- ...t_table_batched_embeddings_ops_training.py | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 1730dbedc..ecb0b434a 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -1942,26 +1942,29 @@ def ensemble_and_swap(self, ensemble_mode: Dict[str, float]) -> None: if should_ema or should_swap: weights = self.split_embedding_weights() states = self.split_optimizer_states() + coef_ema = ( + 0.0 + if self.iter.item() <= int(ensemble_mode["step_start"]) + else ensemble_mode["step_ema_coef"] + ) for i in range(len(self.embedding_specs)): + # 0) copying weights from gpu to cpu + weights_cpu = weights[i].to( + dtype=states[i][1].dtype, device=states[i][1].device + ) + # 1) ema step if should_ema: - step_start = int(ensemble_mode["step_start"]) - if int(ensemble_mode["step_mode"]) == 1: - should_ema_reset = self.iter.item() % step_start == 0 - elif int(ensemble_mode["step_mode"]) == 2: - should_ema_reset = self.iter.item() <= step_start - else: - should_ema_reset = (self.iter.item() <= step_start) or ( - self.iter.item() % step_start == 0 - ) - coef_ema = ( - 0.0 if should_ema_reset else ensemble_mode["step_ema_coef"] - ) - weights_cpu = weights[i].to( - dtype=states[i][1].dtype, device=states[i][1].device - ) states[i][1].lerp_(weights_cpu, 1.0 - coef_ema) + # 2) swap step if should_swap: weights[i].copy_(states[i][1], non_blocking=True) + # 3) post-processing step + if should_ema: + if int(ensemble_mode["step_mode"]) == 0: # embedding scaling + states[i][1].mul_(0.0) + elif int(ensemble_mode["step_mode"]) == 1: # nesterov + states[i][1].copy_(weights_cpu, non_blocking=True) + # elif int(ensemble_mode["step_mode"]) == 2: pure ema def reset_uvm_cache_stats(self) -> None: assert (