From ac8606ea63e1640ca097d65c2f004edc9bd9d8a4 Mon Sep 17 00:00:00 2001
From: Supadchaya Puangpontip
Date: Thu, 9 Jan 2025 13:37:20 -0800
Subject: [PATCH] Add new optimizer state `row_counter` for Adam [Frontend]

Summary:
A new optional optimizer state `row_counter` is added to Adam to perform bias
correction per embedding row. `row_counter` serves as the per-row iteration
counter: it is incremented each time a row (an index) occurs and is used to
perform the bias correction for that row.

Without rowwise bias correction (existing Adam):
```
m_hat_t = m_t / (1.0 - powf(beta1, iter));
v_hat_t = v_t / (1.0 - powf(beta2, iter));
```

With rowwise bias correction enabled:
```
// when index `idx` occurs
_row_counter = row_counter[idx] + 1;
m_hat_t = m_t / (1.0 - powf(beta1, _row_counter));
v_hat_t = v_t / (1.0 - powf(beta2, _row_counter));
```

This request comes from IG to allow all models to be scaled on sparse
features, with an expected 1.5% NE gain on Stories.

-------
**__This functionality is not enabled by default.__**

Frontend: D64848802
To enable the bias correction, set `use_rowwise_bias_correction` to True
through `extra_optimizer_config`:
```
extra_optimizer_config = UserEnabledConfigDefinition(use_rowwise_bias_correction=True)
emb_op = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        (E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed)
    ],
    optimizer=OptimType.ADAM,
    extra_optimizer_config=extra_optimizer_config,
    ...
)
```

------
**__Performance__**
```
         |      | Baseline*  | default**  | enabled***
forward  | cpu  | 2.293 s    | 2.188 s    | 2.043 s
         | cuda | 12.512 ms  | 12.539 ms  | 12.547 ms
backward | cpu  | 69.861 ms  | 66.546 ms  | 65.880 ms
         | cuda | 103.429 ms | 103.395 ms | 103.130 ms
```
\* Baseline: before the changes
\** default: default setting; use_rowwise_bias_correction = False
\*** enabled: use_rowwise_bias_correction = True

Reviewed By: sryap

Differential Revision: D64848802
---
 .../training/python/lookup_args.template      |   1 +
 ..._embedding_codegen_lookup_invoker.template |  27 ++++-
 ...t_table_batched_embeddings_ops_training.py |  91 ++++++++++++---
 fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py     |   1 +
 .../tbe/training/backward_optimizers_test.py  | 108 ++++++++++++++++--
 5 files changed, 201 insertions(+), 27 deletions(-)

diff --git a/fbgemm_gpu/codegen/training/python/lookup_args.template b/fbgemm_gpu/codegen/training/python/lookup_args.template
index f3fd7aa87a..ca79a16f8f 100644
--- a/fbgemm_gpu/codegen/training/python/lookup_args.template
+++ b/fbgemm_gpu/codegen/training/python/lookup_args.template
@@ -76,6 +76,7 @@ class OptimizerArgs(NamedTuple):
     weight_norm_coefficient: float
     lower_bound: float
     regularization_mode: int
+    use_rowwise_bias_correction: bool  # Used for OptimType.ADAM


 class Momentum(NamedTuple):
diff --git a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template
index b55b850c5d..1b100e4ce6 100644
--- a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template
+++ b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template
@@ -81,6 +81,12 @@ def invoke(
     prev_iter_dev: Optional[torch.Tensor] = None,
     {%- endif %}
     gwd_lower_bound: float = 0.0,
+    {%- if "use_rowwise_bias_correction" in args_pt2.unified_pt2.split_function_arg_names %}
+    use_rowwise_bias_correction: bool = False,
+    {%- endif %}
+    {%- if "row_counter" in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
+    row_counter: Optional[Momentum] = None,
+    {%- endif %}
 ) -> torch.Tensor:
     {%- if is_experimental_optimizer %}
     # By
design, the warning only shows up once @@ -94,7 +100,20 @@ def invoke( {%- endif %} vbe_metadata = common_args.vbe_metadata - + {%- if "row_counter" in args_pt2.unified_pt2.split_saved_tensorlist_optional %} + if not use_rowwise_bias_correction or row_counter is None: + row_counter_dev = None + row_counter_uvm = None + row_counter_offsets = None + row_counter_placements = None + elif use_rowwise_bias_correction and row_counter is None: + assert False, "use_rowwise_bias_correction is set but row_counter cannot be None" + else: + row_counter_dev = row_counter.host + row_counter_uvm = row_counter.uvm + row_counter_offsets = row_counter.offsets + row_counter_placements = row_counter.placements + {%- endif %} {%- if has_cpu_support and not ssd %} if (common_args.host_weights.numel() > 0): T = common_args.D_offsets.numel() - 1 @@ -393,6 +412,12 @@ def invoke( row_counter_offsets=row_counter.offsets, row_counter_placements=row_counter.placements, {%- endif %} + {%- if "row_counter" in args_pt2.unified_pt2.split_saved_tensorlist_optional %} + row_counter_dev=row_counter_dev, + row_counter_uvm=row_counter_uvm, + row_counter_offsets=row_counter_offsets, + row_counter_placements=row_counter_placements, + {%- endif %} # iter iter=iter, # max counter diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 8f8d5779ea..93080fa484 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -145,6 +145,15 @@ class GlobalWeightDecayDefinition: lower_bound: float = 0.0 +@dataclass(frozen=True) +class UserEnabledConfigDefinition: + """ + This class is used to configure whether certain modes are to be enabled + """ + + use_rowwise_bias_correction: bool = False # Adam + + @dataclass(frozen=True) class EnsembleModeDefinition: step_ema: float = 10000 @@ -630,6 +639,7 @@ def __init__( # noqa C901 multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None, global_weight_decay: Optional[GlobalWeightDecayDefinition] = None, uvm_host_mapped: bool = False, + extra_optimizer_config: Optional[UserEnabledConfigDefinition] = None, ) -> None: super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__() @@ -1006,6 +1016,20 @@ def __init__( # noqa C901 # and CowClipDefinition are not used counter_halflife = -1 + # TO DO: Enable this on the new interface + # learning_rate_tensor = torch.tensor( + # learning_rate, device=torch.device("cpu"), dtype=torch.float + # ) + if extra_optimizer_config is None: + extra_optimizer_config = UserEnabledConfigDefinition() + self.use_rowwise_bias_correction: bool = ( + extra_optimizer_config.use_rowwise_bias_correction + ) + if self.use_rowwise_bias_correction and not self.optimizer == OptimType.ADAM: + raise AssertionError( + "`use_rowwise_bias_correction` is only supported for OptimType.ADAM", + ) + self.optimizer_args = invokers.lookup_args.OptimizerArgs( stochastic_rounding=stochastic_rounding, gradient_clipping=gradient_clipping, @@ -1032,6 +1056,7 @@ def __init__( # noqa C901 weight_norm_coefficient=cowclip_regularization.weight_norm_coefficient, lower_bound=cowclip_regularization.lower_bound, regularization_mode=weight_decay_mode.value, + use_rowwise_bias_correction=self.use_rowwise_bias_correction, ) if optimizer != OptimType.NONE: @@ -1168,6 +1193,19 @@ def __init__( # noqa C901 torch.ones(1, dtype=torch.float32, device=self.current_device), persistent=False, ) + elif 
optimizer == OptimType.ADAM and self.use_rowwise_bias_correction: + self._apply_split( + construct_split_state( + embedding_specs, + rowwise=True, + cacheable=False, + ), + prefix="row_counter", + # pyre-fixme[6]: Expected `Type[Type[torch._dtype]]` for 3rd param + # but got `Type[torch.float32]`. + dtype=torch.float32, + uvm_host_mapped=self.uvm_host_mapped, + ) else: self._register_nonpersistent_buffers("prev_iter") self._register_nonpersistent_buffers("row_counter") @@ -1192,7 +1230,6 @@ def __init__( # noqa C901 "iter", torch.zeros(1, dtype=torch.int64, device=self.current_device), ) - else: self.register_buffer( "iter", @@ -1895,6 +1932,24 @@ def forward( # noqa: C901 iter_int = int(self.iter_cpu.add_(1).item()) # used for local computation self.iter.add_(1) # used for checkpointing + row_counter = invokers.lookup_args.Momentum( + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `Union[Module, Tensor]`. + dev=self.row_counter_dev, + # pyre-fixme[6]: For 2nd argument expected `Tensor` but got + # `Union[Module, Tensor]`. + host=self.row_counter_host, + # pyre-fixme[6]: For 3rd argument expected `Tensor` but got + # `Union[Module, Tensor]`. + uvm=self.row_counter_uvm, + # pyre-fixme[6]: For 4th argument expected `Tensor` but got + # `Union[Module, Tensor]`. + offsets=self.row_counter_offsets, + # pyre-fixme[6]: For 5th argument expected `Tensor` but got + # `Union[Module, Tensor]`. + placements=self.row_counter_placements, + ) + if self.optimizer == OptimType.ADAM: return self._report_io_size_count( "fwd_output", @@ -1904,6 +1959,10 @@ def forward( # noqa: C901 momentum1, momentum2, iter_int, + self.use_rowwise_bias_correction, + row_counter=( + row_counter if self.use_rowwise_bias_correction else None + ), ), ) if self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM: @@ -1957,23 +2016,6 @@ def forward( # noqa: C901 # `Union[Module, Tensor]`. placements=self.prev_iter_placements, ) - row_counter = invokers.lookup_args.Momentum( - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `Union[Module, Tensor]`. - dev=self.row_counter_dev, - # pyre-fixme[6]: For 2nd argument expected `Tensor` but got - # `Union[Module, Tensor]`. - host=self.row_counter_host, - # pyre-fixme[6]: For 3rd argument expected `Tensor` but got - # `Union[Module, Tensor]`. - uvm=self.row_counter_uvm, - # pyre-fixme[6]: For 4th argument expected `Tensor` but got - # `Union[Module, Tensor]`. - offsets=self.row_counter_offsets, - # pyre-fixme[6]: For 5th argument expected `Tensor` but got - # `Union[Module, Tensor]`. 
- placements=self.row_counter_placements, - ) if self.optimizer == OptimType.EMAINPLACE_ROWWISE_ADAGRAD: with torch.no_grad(): @@ -2543,6 +2585,15 @@ def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]: list_of_state_dict = [ {"momentum_buffer": states[0]} for states in split_optimizer_states ] + elif self.optimizer == OptimType.ADAM and self.use_rowwise_bias_correction: + list_of_state_dict = [ + { + "exp_avg": states[0], + "exp_avg_sq": states[1], + "row_counter": states[2], + } + for states in split_optimizer_states + ] elif ( self.optimizer == OptimType.ADAM or self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM @@ -2717,7 +2768,9 @@ def get_optimizer_states( rowwise=True, ) ) - if self._used_rowwise_adagrad_with_counter: + if self._used_rowwise_adagrad_with_counter or ( + self.optimizer == OptimType.ADAM and self.use_rowwise_bias_correction + ): states.append( get_optimizer_states( # pyre-fixme[6]: For 1st argument expected `Tensor` but got diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index 36dae3e11c..a4ef78d888 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -570,6 +570,7 @@ def __init__( weight_norm_coefficient=cowclip_regularization.weight_norm_coefficient, lower_bound=cowclip_regularization.lower_bound, regularization_mode=weight_decay_mode.value, + use_rowwise_bias_correction=False, # Unused, this is used in TBE's Adam ) table_embedding_dtype = weights_precision.as_dtype() diff --git a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py index 91b0d95fde..b9f8d7f308 100644 --- a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py +++ b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py @@ -33,6 +33,7 @@ SplitTableBatchedEmbeddingBagsCodegen, StepMode, TailIdThreshold, + UserEnabledConfigDefinition, WeightDecayMode, ) @@ -105,6 +106,7 @@ def execute_backward_optimizers_( # noqa C901 weight_decay_mode: WeightDecayMode = WeightDecayMode.NONE, uvm_non_rowwise_momentum: bool = False, optimizer_state_dtypes: Optional[Dict[str, SparseType]] = None, + use_rowwise_bias_correction: bool = False, ) -> None: # NOTE: limit (T * B * L * D) to avoid timeout for CPU version! 
@@ -307,6 +309,9 @@ def execute_backward_optimizers_( # noqa C901 optimizer_kwargs["weight_decay"] = weight_decay optimizer_kwargs["optimizer_state_dtypes"] = optimizer_state_dtypes + extra_optimizer_config = UserEnabledConfigDefinition( + use_rowwise_bias_correction=use_rowwise_bias_correction + ) if optimizer in (OptimType.PARTIAL_ROWWISE_LAMB, OptimType.LAMB): optimizer_kwargs["eps"] = eps optimizer_kwargs["beta1"] = beta1 @@ -335,6 +340,10 @@ def execute_backward_optimizers_( # noqa C901 step_ema_coef=momentum, step_mode=step_mode, ) + row_counter_ref = [torch.zeros(E, dtype=torch.float32) for E in Es] + if optimizer == OptimType.ADAM and use_rowwise_bias_correction: + for i, indices in enumerate(xs): + row_counter_ref[i][indices.cpu()] += 1 if optimizer == OptimType.EMAINPLACE_ROWWISE_ADAGRAD: (eps, step_ema, step_start) = ( @@ -356,6 +365,7 @@ def execute_backward_optimizers_( # noqa C901 optimizer=optimizer, pooling_mode=pooling_mode, uvm_non_rowwise_momentum=uvm_non_rowwise_momentum, + extra_optimizer_config=extra_optimizer_config, **optimizer_kwargs, ) @@ -539,8 +549,21 @@ def execute_backward_optimizers_( # noqa C901 if optimizer in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM): rowwise = optimizer == OptimType.PARTIAL_ROWWISE_ADAM + row_counter: Optional[torch.Tensor] = None for t in range(T): - (m1, m2) = split_optimizer_states[t] + if rowwise or not use_rowwise_bias_correction: + (m1, m2) = split_optimizer_states[t] + else: # Full adam with rowwise bias correction + (m1, m2, row_counter) = split_optimizer_states[t] + # check row counter + row_counter = row_counter.cpu() + torch.testing.assert_close( + row_counter, + row_counter_ref[t], + atol=0, + rtol=0, + ) + row_counter = row_counter.reshape(row_counter.size(0), 1) # Some optimizers have non-float momentums dense_cpu_grad = bs[t].weight.grad.cpu().to_dense() m2_ref = ( @@ -552,9 +575,10 @@ def execute_backward_optimizers_( # noqa C901 m1_ref = dense_cpu_grad * (1.0 - beta1) self.assert_close_optim_state(m1, m1_ref) iter_ = cc.iter.item() - v_hat_t = m2_ref / (1 - beta2**iter_) + power = row_counter if use_rowwise_bias_correction else iter_ + v_hat_t = m2_ref / (1 - beta2**power) v_hat_t = v_hat_t if not rowwise else v_hat_t.view(v_hat_t.numel(), 1) - m_hat_t = m1_ref / (1 - beta1**iter_) + m_hat_t = m1_ref / (1 - beta1**power) weights_new = split_weights[t] weights_ref = ( torch.addcdiv( @@ -574,10 +598,16 @@ def execute_backward_optimizers_( # noqa C901 if get_optimizer_states is not None: optimizer_states_dict = get_optimizer_states[t] - assert set(optimizer_states_dict.keys()) == { - "exp_avg", - "exp_avg_sq", - } + state_keys = ( + { + "exp_avg", + "exp_avg_sq", + "row_counter", + } + if use_rowwise_bias_correction + else {"exp_avg", "exp_avg_sq"} + ) + assert set(optimizer_states_dict.keys()) == state_keys if optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD: for t in range(T): @@ -938,6 +968,70 @@ def test_backward_optimizers_adam( # noqa C901 uvm_non_rowwise_momentum=uvm_non_rowwise_momentum, ) + @given( + T=st.integers(min_value=1, max_value=5), + D=st.integers(min_value=2, max_value=256), + B=st.integers(min_value=1, max_value=128), + log_E=st.integers(min_value=3, max_value=5), + L=st.integers(min_value=0, max_value=20), + weighted=st.booleans(), + mixed=st.booleans(), + optimizer=st.sampled_from( + [ + OptimType.ADAM, + OptimType.PARTIAL_ROWWISE_ADAM, + ] + ), + long_segments=st.booleans(), + pooling_mode=st.sampled_from( + [ + PoolingMode.SUM, + PoolingMode.MEAN, + PoolingMode.NONE, + ] + ), + 
use_cpu=use_cpu_strategy(), + uvm_non_rowwise_momentum=st.booleans(), + ) + @settings( + verbosity=VERBOSITY, + max_examples=MAX_EXAMPLES_LONG_RUNNING, + deadline=None, + suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], + ) + @unittest.skipIf(*gpu_unavailable) + def test_backward_optimizers_adam_rowwise_bias_correction( # noqa C901 + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weighted: bool, + mixed: bool, + optimizer: OptimType, + long_segments: bool, + pooling_mode: PoolingMode, + use_cpu: bool, + uvm_non_rowwise_momentum: bool, + ) -> None: + self.execute_backward_optimizers_( + T, + D, + B, + log_E, + L, + weighted, + mixed, + False, # mixed_B + optimizer, + long_segments, + pooling_mode, + use_cpu, + uvm_non_rowwise_momentum=uvm_non_rowwise_momentum, + use_rowwise_bias_correction=True, + ) + @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256),