Add new optimizer state row_counter for Adam [Frontend]
Summary:
A new optional optimizer state `row_counter` is added to Adam to perform bias correction per embedding row. `row_counter` serves as the per-row iteration counter: it records how many times a row (an index) has occurred and is used for bias correction.


Without rowwise bias correction (existing Adam),
```
m_hat_t = m_t / (1.0 - powf(beta1, iter));
v_hat_t = v_t / (1.0 - powf(beta2, iter));
```

With rowwise bias correction enabled,
```
// when index `idx` occurs
_row_counter = row_counter[idx] + 1;
m_hat_t = m_t / (1.0 - powf(beta1, _row_counter));
v_hat_t = v_t / (1.0 - powf(beta2, _row_counter));
```
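
For intuition, here is a minimal standalone sketch of the per-row update (NumPy, for exposition only — the function name, array layout, and update loop are illustrative assumptions, not the FBGEMM kernel):
```
import numpy as np

def adam_step_rowwise(weights, grads, rows, m, v, row_counter,
                      lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # grads[i] is the gradient for embedding row rows[i].
    for r, g in zip(rows, grads):
        # The counter advances only when this row occurs in a batch.
        row_counter[r] += 1
        t = row_counter[r]
        m[r] = beta1 * m[r] + (1.0 - beta1) * g
        v[r] = beta2 * v[r] + (1.0 - beta2) * g * g
        # Bias correction uses the per-row count instead of the global iteration.
        m_hat = m[r] / (1.0 - beta1 ** t)
        v_hat = v[r] / (1.0 - beta2 ** t)
        weights[r] -= lr * m_hat / (np.sqrt(v_hat) + eps)
```
A rarely seen row is thus corrected as if it were at an early training step (its moments are still biased toward zero), instead of receiving essentially no correction from a large global `iter`.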

This request comes from IG to allow all models to scale on sparse features, with an expected 1.5% NE improvement on Stories.

-------

**__The functionality is not enabled by default.__** Frontend: D64848802

To enable the bias correction, set `use_rowwise_bias_correction` to True through `extra_optimizer_config`:
```
extra_optimizer_config = UserEnabledConfigDefinition(use_rowwise_bias_correction=True)
emb_op = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        (E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed)
    ],
    optimizer=OptimType.ADAM,
    extra_optimizer_config=extra_optimizer_config,
    ...
)
```
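
Once enabled, the `row_counter` state is exposed next to the Adam moments through `get_optimizer_state()`, per the diff below. A sketch of inspecting it (the loop and prints are illustrative; the state-dict keys come from this commit):
```
# Each table's state dict gains a "row_counter" entry next to the Adam moments.
for table_idx, state in enumerate(emb_op.get_optimizer_state()):
    print(table_idx, state["exp_avg"].shape)      # first moment (m)
    print(table_idx, state["exp_avg_sq"].shape)   # second moment (v)
    print(table_idx, state["row_counter"].shape)  # per-row iteration counts
```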
------
**__Performance__**
```
                   Baseline* |  default** | enabled*** 
forward  | cpu  |   2.293 s  |   2.188 s  |   2.043 s
         | cuda |  12.512 ms |  12.539 ms |  12.547 ms
backward | cpu  |  69.861 ms |  66.546 ms |  65.880 ms
         | cuda | 103.429 ms | 103.395 ms | 103.130 ms
```
\* Baseline: before changes
\** default: default setting; use_rowwise_bias_correction = False
\*** enabled: use_rowwise_bias_correction = True

Reviewed By: sryap

Differential Revision: D64848802
spcyppt authored and facebook-github-bot committed Jan 9, 2025
1 parent c5a7077 commit ac8606e
Showing 5 changed files with 201 additions and 27 deletions.
1 change: 1 addition & 0 deletions fbgemm_gpu/codegen/training/python/lookup_args.template
@@ -76,6 +76,7 @@ class OptimizerArgs(NamedTuple):
    weight_norm_coefficient: float
    lower_bound: float
    regularization_mode: int
    use_rowwise_bias_correction: bool  # Used for OptimType.ADAM


class Momentum(NamedTuple):
@@ -81,6 +81,12 @@ def invoke(
    prev_iter_dev: Optional[torch.Tensor] = None,
    {%- endif %}
    gwd_lower_bound: float = 0.0,
    {%- if "use_rowwise_bias_correction" in args_pt2.unified_pt2.split_function_arg_names %}
    use_rowwise_bias_correction: bool = False,
    {%- endif %}
    {%- if "row_counter" in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
    row_counter: Optional[Momentum] = None,
    {%- endif %}
) -> torch.Tensor:
    {%- if is_experimental_optimizer %}
    # By design, the warning only shows up once
@@ -94,7 +100,20 @@
    {%- endif %}

    vbe_metadata = common_args.vbe_metadata

    {%- if "row_counter" in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
    if not use_rowwise_bias_correction:
        row_counter_dev = None
        row_counter_uvm = None
        row_counter_offsets = None
        row_counter_placements = None
    elif row_counter is None:
        assert False, "use_rowwise_bias_correction is set but row_counter cannot be None"
    else:
        row_counter_dev = row_counter.dev
        row_counter_uvm = row_counter.uvm
        row_counter_offsets = row_counter.offsets
        row_counter_placements = row_counter.placements
    {%- endif %}
    {%- if has_cpu_support and not ssd %}
    if (common_args.host_weights.numel() > 0):
        T = common_args.D_offsets.numel() - 1
@@ -393,6 +412,12 @@ def invoke(
        row_counter_offsets=row_counter.offsets,
        row_counter_placements=row_counter.placements,
        {%- endif %}
        {%- if "row_counter" in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
        row_counter_dev=row_counter_dev,
        row_counter_uvm=row_counter_uvm,
        row_counter_offsets=row_counter_offsets,
        row_counter_placements=row_counter_placements,
        {%- endif %}
        # iter
        iter=iter,
        # max counter
@@ -145,6 +145,15 @@ class GlobalWeightDecayDefinition:
    lower_bound: float = 0.0


@dataclass(frozen=True)
class UserEnabledConfigDefinition:
    """
    This class is used to configure whether certain modes are to be enabled
    """

    use_rowwise_bias_correction: bool = False  # Adam


@dataclass(frozen=True)
class EnsembleModeDefinition:
    step_ema: float = 10000
@@ -630,6 +639,7 @@ def __init__( # noqa C901
        multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None,
        global_weight_decay: Optional[GlobalWeightDecayDefinition] = None,
        uvm_host_mapped: bool = False,
        extra_optimizer_config: Optional[UserEnabledConfigDefinition] = None,
    ) -> None:
        super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()

@@ -1006,6 +1016,20 @@ def __init__( # noqa C901
            # and CowClipDefinition are not used
            counter_halflife = -1

        # TODO: Enable this on the new interface
        # learning_rate_tensor = torch.tensor(
        #     learning_rate, device=torch.device("cpu"), dtype=torch.float
        # )
        if extra_optimizer_config is None:
            extra_optimizer_config = UserEnabledConfigDefinition()
        self.use_rowwise_bias_correction: bool = (
            extra_optimizer_config.use_rowwise_bias_correction
        )
        if self.use_rowwise_bias_correction and self.optimizer != OptimType.ADAM:
            raise AssertionError(
                "`use_rowwise_bias_correction` is only supported for OptimType.ADAM",
            )

        self.optimizer_args = invokers.lookup_args.OptimizerArgs(
            stochastic_rounding=stochastic_rounding,
            gradient_clipping=gradient_clipping,
@@ -1032,6 +1056,7 @@ def __init__( # noqa C901
            weight_norm_coefficient=cowclip_regularization.weight_norm_coefficient,
            lower_bound=cowclip_regularization.lower_bound,
            regularization_mode=weight_decay_mode.value,
            use_rowwise_bias_correction=self.use_rowwise_bias_correction,
        )

        if optimizer != OptimType.NONE:
@@ -1168,6 +1193,19 @@ def __init__( # noqa C901
                torch.ones(1, dtype=torch.float32, device=self.current_device),
                persistent=False,
            )
        elif optimizer == OptimType.ADAM and self.use_rowwise_bias_correction:
            self._apply_split(
                construct_split_state(
                    embedding_specs,
                    rowwise=True,
                    cacheable=False,
                ),
                prefix="row_counter",
                # pyre-fixme[6]: Expected `Type[Type[torch._dtype]]` for 3rd param
                # but got `Type[torch.float32]`.
                dtype=torch.float32,
                uvm_host_mapped=self.uvm_host_mapped,
            )
        else:
            self._register_nonpersistent_buffers("prev_iter")
            self._register_nonpersistent_buffers("row_counter")
@@ -1192,7 +1230,6 @@ def __init__( # noqa C901
                "iter",
                torch.zeros(1, dtype=torch.int64, device=self.current_device),
            )

        else:
            self.register_buffer(
                "iter",
@@ -1895,6 +1932,24 @@ def forward( # noqa: C901
        iter_int = int(self.iter_cpu.add_(1).item())  # used for local computation
        self.iter.add_(1)  # used for checkpointing

        row_counter = invokers.lookup_args.Momentum(
            # pyre-fixme[6]: For 1st argument expected `Tensor` but got
            # `Union[Module, Tensor]`.
            dev=self.row_counter_dev,
            # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
            # `Union[Module, Tensor]`.
            host=self.row_counter_host,
            # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
            # `Union[Module, Tensor]`.
            uvm=self.row_counter_uvm,
            # pyre-fixme[6]: For 4th argument expected `Tensor` but got
            # `Union[Module, Tensor]`.
            offsets=self.row_counter_offsets,
            # pyre-fixme[6]: For 5th argument expected `Tensor` but got
            # `Union[Module, Tensor]`.
            placements=self.row_counter_placements,
        )

        if self.optimizer == OptimType.ADAM:
            return self._report_io_size_count(
                "fwd_output",
@@ -1904,6 +1959,10 @@ def forward( # noqa: C901
                    momentum1,
                    momentum2,
                    iter_int,
                    self.use_rowwise_bias_correction,
                    row_counter=(
                        row_counter if self.use_rowwise_bias_correction else None
                    ),
                ),
            )
        if self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
@@ -1957,23 +2016,6 @@ def forward( # noqa: C901
            # `Union[Module, Tensor]`.
            placements=self.prev_iter_placements,
        )
        row_counter = invokers.lookup_args.Momentum(
            # pyre-fixme[6]: For 1st argument expected `Tensor` but got
            # `Union[Module, Tensor]`.
            dev=self.row_counter_dev,
            # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
            # `Union[Module, Tensor]`.
            host=self.row_counter_host,
            # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
            # `Union[Module, Tensor]`.
            uvm=self.row_counter_uvm,
            # pyre-fixme[6]: For 4th argument expected `Tensor` but got
            # `Union[Module, Tensor]`.
            offsets=self.row_counter_offsets,
            # pyre-fixme[6]: For 5th argument expected `Tensor` but got
            # `Union[Module, Tensor]`.
            placements=self.row_counter_placements,
        )

        if self.optimizer == OptimType.EMAINPLACE_ROWWISE_ADAGRAD:
            with torch.no_grad():
@@ -2543,6 +2585,15 @@ def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]:
            list_of_state_dict = [
                {"momentum_buffer": states[0]} for states in split_optimizer_states
            ]
        elif self.optimizer == OptimType.ADAM and self.use_rowwise_bias_correction:
            list_of_state_dict = [
                {
                    "exp_avg": states[0],
                    "exp_avg_sq": states[1],
                    "row_counter": states[2],
                }
                for states in split_optimizer_states
            ]
        elif (
            self.optimizer == OptimType.ADAM
            or self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM
@@ -2717,7 +2768,9 @@ def get_optimizer_states(
                    rowwise=True,
                )
            )
        if self._used_rowwise_adagrad_with_counter:
        if self._used_rowwise_adagrad_with_counter or (
            self.optimizer == OptimType.ADAM and self.use_rowwise_bias_correction
        ):
            states.append(
                get_optimizer_states(
                    # pyre-fixme[6]: For 1st argument expected `Tensor` but got
1 change: 1 addition & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -570,6 +570,7 @@ def __init__(
            weight_norm_coefficient=cowclip_regularization.weight_norm_coefficient,
            lower_bound=cowclip_regularization.lower_bound,
            regularization_mode=weight_decay_mode.value,
            use_rowwise_bias_correction=False,  # Unused, this is used in TBE's Adam
        )

        table_embedding_dtype = weights_precision.as_dtype()