diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 1730dbedc..f3dcc0cae 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -1937,31 +1937,40 @@ def forward( # noqa: C901 raise ValueError(f"Invalid OptimType: {self.optimizer}") def ensemble_and_swap(self, ensemble_mode: Dict[str, float]) -> None: + """ + Perform ensemble and swap operations on the full sparse embedding tables. + + Returns: + Sparse embedding weights and optimizer states will be updated in-place. + """ should_ema = self.iter.item() % int(ensemble_mode["step_ema"]) == 0 should_swap = self.iter.item() % int(ensemble_mode["step_swap"]) == 0 if should_ema or should_swap: weights = self.split_embedding_weights() states = self.split_optimizer_states() + coef_ema = ( + 0.0 + if self.iter.item() <= int(ensemble_mode["step_start"]) + else ensemble_mode["step_ema_coef"] + ) for i in range(len(self.embedding_specs)): + # 0) copying weights from gpu to cpu + weights_cpu = weights[i].to( + dtype=states[i][1].dtype, device=states[i][1].device + ) + # 1) ema step if should_ema: - step_start = int(ensemble_mode["step_start"]) - if int(ensemble_mode["step_mode"]) == 1: - should_ema_reset = self.iter.item() % step_start == 0 - elif int(ensemble_mode["step_mode"]) == 2: - should_ema_reset = self.iter.item() <= step_start - else: - should_ema_reset = (self.iter.item() <= step_start) or ( - self.iter.item() % step_start == 0 - ) - coef_ema = ( - 0.0 if should_ema_reset else ensemble_mode["step_ema_coef"] - ) - weights_cpu = weights[i].to( - dtype=states[i][1].dtype, device=states[i][1].device - ) states[i][1].lerp_(weights_cpu, 1.0 - coef_ema) + # 2) swap step if should_swap: weights[i].copy_(states[i][1], non_blocking=True) + # 3) post-processing step + if should_ema: + if int(ensemble_mode["step_mode"]) == 0: # embedding scaling + states[i][1].mul_(0.0) + elif int(ensemble_mode["step_mode"]) == 1: # nesterov + states[i][1].copy_(weights_cpu, non_blocking=True) + # elif int(ensemble_mode["step_mode"]) == 2: pure ema def reset_uvm_cache_stats(self) -> None: assert (