Add Automicrobatching for Non-Powers-of-2 + Fixes to FSDP deadlocks using Adaptive Sync Hooks #3503

Draft: wants to merge 23 commits into base: main
29 changes: 28 additions & 1 deletion composer/algorithms/seq_length_warmup/seq_length_warmup.py
@@ -5,6 +5,8 @@

import logging
import textwrap
import warnings
from collections import defaultdict
from typing import Mapping, Optional

import torch
@@ -284,7 +286,32 @@ def _activate_model(self, state: State, logger: Logger) -> None:
batch_clone[k] = v[:, :self.max_seq_length].contiguous()

# In-line to avoid circular dependency
from composer.trainer.trainer import _adjust_device_train_microbatch_size, _is_cuda_oom
from composer.trainer.trainer import _clear_incomplete_train_states, _is_cuda_oom

def _adjust_device_train_microbatch_size(state: State):
    """Adjust device_train_microbatch_size if we encounter OOM.

    Args:
        state (State): State of trainer.
    """
    # If any rank hit CUDA OOM, update device_train_microbatch_size and retry. Raise a runtime
    # error if training 1 sample at a time still resulted in CUDA out of memory.
    assert state.device_train_microbatch_size is not None
    if state.device_train_microbatch_size == 1:
        raise RuntimeError((
            'CUDA out of memory. The train loop failed with an internal microbatch of size 1. '
            'The GPU does not have enough memory to process even 1 sample during train.'
        ))
    else:
        original_microbatch_size = state.device_train_microbatch_size
        state.device_train_microbatch_size = max(int(original_microbatch_size / 2), 1)
        warnings.warn(
            RuntimeWarning(
                'CUDA out of memory detected. Train microbatch size will be decreased from '
                f'{original_microbatch_size} -> {state.device_train_microbatch_size}.',
            ),
        )
    # The bare name here was a no-op; presumably this is meant to be a call that clears any
    # partially accumulated train state (outputs, loss, gradients) before retrying.
    _clear_incomplete_train_states(state)

# This loop tries to do a forward/backward pass using the current microbatch size.
# If it hits an OOM error, it halves `state.device_train_microbatch_size` and tries again
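As context for the comment above, here is a minimal standalone sketch of the halve-on-OOM retry pattern, assuming a hypothetical `run_microbatches` callable and a plain string check in place of Composer's `_is_cuda_oom` helper:

```python
import warnings


def train_step_with_retry(state, run_microbatches):
    """Sketch: retry the step, halving the microbatch size whenever CUDA OOMs."""
    while True:
        try:
            # `run_microbatches` is a hypothetical callable that runs the
            # forward/backward pass at the current microbatch size.
            run_microbatches(state.device_train_microbatch_size)
            return
        except RuntimeError as e:
            if 'CUDA out of memory' not in str(e):
                raise
            if state.device_train_microbatch_size == 1:
                raise RuntimeError('OOM even with a microbatch size of 1.') from e
            old_size = state.device_train_microbatch_size
            state.device_train_microbatch_size = max(old_size // 2, 1)
            warnings.warn(
                f'CUDA OOM detected. Microbatch size decreased from {old_size} '
                f'-> {state.device_train_microbatch_size}.',
            )
```

Composer's own loop additionally clears incomplete train state between attempts via the `_clear_incomplete_train_states` helper imported above.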
24 changes: 16 additions & 8 deletions composer/distributed/dist_strategy.py
@@ -7,7 +7,7 @@
import logging
import warnings
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, ContextManager, Iterator, Optional, Sequence, Union, cast
from typing import Any, Callable, ContextManager, Iterator, Optional, Sequence, Union, Tuple, cast

import torch
from packaging import version
@@ -203,7 +203,7 @@ def prepare_fsdp_module(
device: Device,
auto_microbatching: bool,
te_rng_seed: int = 1234,
) -> None:
) -> Tuple[list, dict]:
"""Prepare a module (assumed ComposerModel) and optimizer for use with :class:`torch.distributed.fsdp.FullyShardedDataParallel`.

Args:
@@ -230,6 +230,9 @@
'some weights may be randomly initialized when loading a checkpoint.',
)

# Handles of FSDP sync hooks if automicrobatching is on
hook_handles = []

# Check if other ranks OOMed after forward/backward pass when using auto microbatching. This
# may happen when close to memory limit or with uneven memory usage across ranks. Since we
# need to do this before the model weights are gathered for the next FSDP block, we wrap every
@@ -512,9 +515,6 @@ def lambda_fn(module: torch.nn.Module) -> Union[bool, dict]:
ret = obj.fsdp_wrap_fn(module)
if isinstance(ret, dict):
ret = set_custom_fsdp_module_kwargs(ret, process_group_cache)
if ret and auto_microbatching:
module.register_forward_hook(sync_hook)
module.register_full_backward_hook(sync_hook)
return ret

_auto_wrap_policy = CustomPolicy(lambda_fn)
@@ -531,9 +531,6 @@ def __auto_wrap_policy(module: torch.nn.Module, recurse: bool, nonwrapped_numel:
elif hasattr(obj, 'fsdp_wrap_fn') and isinstance(obj.fsdp_wrap_fn, Callable):
should_be_wrapped = obj.fsdp_wrap_fn(module)

if should_be_wrapped and auto_microbatching:
module.register_forward_hook(sync_hook)
module.register_full_backward_hook(sync_hook)
return should_be_wrapped

def _auto_wrap_policy_new(module: torch.nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
@@ -567,6 +564,15 @@ def _auto_wrap_policy_new(module: torch.nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
log.info(f'Calling prepare_te_modules_for_fsdp to enable TE weights sharding')
prepare_te_modules_for_fsdp(fsdp_obj)


if auto_microbatching:
Contributor comment on lines +567 to +568: can you add a comment on what this is doing?
    for _, module in fsdp_obj.named_modules():
        if isinstance(module, FullyShardedDataParallel):
            hook_handles.append(module.register_forward_pre_hook(sync_hook, prepend=True))
            hook_handles.append(module.register_full_backward_pre_hook(sync_hook, prepend=True))
        else:
            hook_handles.append(module.register_full_backward_hook(sync_hook))

if hasattr(fsdp_obj, '_exec_order_data'):
    if hasattr(fsdp_obj._exec_order_data, '_forward_prefetch_limit'):
        fsdp_obj._exec_order_data._forward_prefetch_limit = fsdp_config.forward_prefetch_limit
@@ -727,3 +733,5 @@ def _check_fn(module: torch.nn.Module) -> bool:
assert optimizer_specific_info is not None
optimizer_specific_info.update({'params': list(model.parameters())})
optim.add_param_group(optimizer_specific_info)

return hook_handles, dict(fsdp_obj.named_modules())
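The `sync_hook` registered above is defined elsewhere in dist_strategy.py and is not shown in this excerpt. As a rough, hedged illustration of what such a hook does (an assumption based on the OOM rendezvous used elsewhere in this PR, not the PR's actual hook), it amounts to all ranks agreeing on whether anyone has already hit CUDA OOM before entering the next collective:

```python
import torch

from composer.utils import dist


def illustrative_sync_hook(*args, **kwargs):
    """Hypothetical stand-in for `sync_hook`: rendezvous before an FSDP module runs."""
    # Presumably a rank that already hit CUDA OOM would contribute 1 here; healthy
    # ranks contribute 0, and the MAX reduction spreads the signal to everyone.
    found_cuda_oom = torch.tensor([0], dtype=torch.uint8, device='cuda')
    dist.all_reduce(found_cuda_oom, reduce_operation='MAX')
    if found_cuda_oom.item() == 1:
        # Every rank raises together, so no rank is left waiting inside a collective.
        raise RuntimeError('CUDA out of memory encountered on a different rank')
```

Returning `hook_handles` from `prepare_fsdp_module` lets the trainer later call `handle.remove()` on each one, dropping the per-module syncs once a stable microbatch size has been found.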
76 changes: 76 additions & 0 deletions composer/trainer/_patch_pytorch.py
@@ -31,8 +31,17 @@
from torch.distributed.fsdp._fsdp_extensions import _ext_pre_load_state_dict_transform
from torch.distributed.utils import _replace_by_prefix

from composer.utils import dist

log = logging.getLogger(__name__)

def patch_unshard_for_automicrobatching(auto_microbatch_size_found=False):
    if version.parse(torch.__version__) >= version.parse('2.3.1'):
        from torch.distributed.fsdp._flat_param import FlatParamHandle
        if auto_microbatch_size_found:
            # A stable microbatch size was found: revert to the stock unshard for throughput.
            FlatParamHandle.unshard = unshard
        else:
            # Still searching for a non-OOMing microbatch size: use the sync-guarded unshard.
            FlatParamHandle.unshard = unshard_with_sync

def patch_pytorch():
"""Monkey patches pytorch functions based on pytorch version."""
@@ -122,6 +131,73 @@ def patch_pytorch():
_MeshEnv.create_child_mesh = create_child_mesh
DeviceMesh.__getitem__ = device_mesh__getitem__

Contributor comment on lines +134 to +144: This should probably be in the if torch 2.3.1 section

@no_type_check
def unshard(self):
    """Run the unshard logic.

    This is the unpatched method from PyTorch, meant to be reverted to whenever
    automicrobatching turns off its hooks for increased throughput.

    This includes all-gathering the flat parameter and switching to using the
    unsharded flat parameter. If the handle does not need unsharding, then this
    only switches to using the unsharded flat parameter. For ``NO_SHARD``, this
    is a no-op.

    If FSDP is in :meth:`summon_full_params` and the handle uses parameter
    mixed precision, then the parameter is forced to full precision.
    """
    if not self.needs_unshard():
        # Even when not needing an unshard, we should switch to using
        # the unsharded flat parameter
        unsharded_flat_param = (
            self._get_padded_unsharded_flat_param()
            if self.uses_sharded_strategy
            else self.flat_param
        )
        self._use_unsharded_flat_param(unsharded_flat_param)
        return
    unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
    padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param)
    self._use_unsharded_flat_param(padded_unsharded_flat_param)

@no_type_check
def unshard_with_sync(self):
    """Run the unshard logic, but with a sync after :meth:`_alloc_padded_unsharded_flat_param`
    to prevent deadlocks when some ranks OOM after the alloc call and others do not.

    This is a patched method from PyTorch, meant to be called when automicrobatching
    turns on hooks in its search process for the optimal non-OOMing microbatch size.

    This includes all-gathering the flat parameter and switching to using the
    unsharded flat parameter. If the handle does not need unsharding, then this
    only switches to using the unsharded flat parameter. For ``NO_SHARD``, this
    is a no-op.

    If FSDP is in :meth:`summon_full_params` and the handle uses parameter
    mixed precision, then the parameter is forced to full precision.
    """
    if not self.needs_unshard():
        # Even when not needing an unshard, we should switch to using
        # the unsharded flat parameter
        unsharded_flat_param = (
            self._get_padded_unsharded_flat_param()
            if self.uses_sharded_strategy
            else self.flat_param
        )
        self._use_unsharded_flat_param(unsharded_flat_param)
        return
    unsharded_flat_param = self._alloc_padded_unsharded_flat_param()

    # Check if any other rank hit an OOM
    found_cuda_oom_tensor = torch.tensor([0], dtype=torch.uint8).to(self.device, non_blocking=True)
    dist.all_reduce(found_cuda_oom_tensor, reduce_operation='MAX')
    found_cuda_oom = found_cuda_oom_tensor.item()

    # Signal that the current rank is still in the batch
    all_ranks_finished_tensor = torch.tensor([0], dtype=torch.uint8).to(self.device, non_blocking=True)
    dist.all_reduce(all_ranks_finished_tensor, reduce_operation='MIN')

    if found_cuda_oom == 1:
        raise RuntimeError('CUDA out of memory encountered on a different rank')
    padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param)
    self._use_unsharded_flat_param(padded_unsharded_flat_param)

def build_metadata(
self,
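As a usage note, here is a hedged sketch of how the patch above might be toggled around the microbatch-size search; the call sites and surrounding comments are assumptions, not code from this PR:

```python
from composer.trainer._patch_pytorch import patch_unshard_for_automicrobatching

# While automicrobatching is still probing candidate sizes, install the sync-guarded
# unshard so a rank that OOMs mid-step cannot deadlock its peers inside an all-gather.
patch_unshard_for_automicrobatching(auto_microbatch_size_found=False)

# ... run candidate microbatches, shrinking on OOM as in the trainer loop ...

# Once every rank fits the chosen size, restore the stock unshard and drop the
# extra all_reduce from every FSDP unshard to recover throughput.
patch_unshard_for_automicrobatching(auto_microbatch_size_found=True)
```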