Add Automicrobatching for Non-Powers-of-2 + Fixes to FSDP deadlocks using Adaptive Sync Hooks #3503
Draft: JackZ-db wants to merge 23 commits into mosaicml:main from JackZ-db:jz/auto_non_powers_of_2 (base: main).
Commits (23, all by JackZ-db):
bbb9a66  add automicrobatching for non-powers-of-2 + adaptive sync hooks
4228889  include auto helpers in _all_
a537c4c  fix circular imports
ff806d1  remove circular import
4025274  remove import state
9ad1719  dist
896b999  fix imports
7146f23  import defaultdict
cd2fe9f  log for hook on off
b1b16cd  fixed hook readd bug
476e028  rename hooks to fsdp hooks, will only trigger if fsdp
0693364  only invoke hook logic if fsdp enabled
df404ac  typo
0b6d6ce  fix seq length warmup
153c413  only patch flat param handle unshard if > 2.3
d98926d  fix version comparison
3e82ef6  mark unit test
a09b844  remove device mark
33840c5  filter user warnigns out
e87c9f6  fix
86c32a0  dist sampler
a193b76  ignore runtime warning
0b6a30d  only drop hooks after 3 consecutive successes with this microbatch size
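The last commit above captures the hook-dropping heuristic: the adaptive sync hooks stay installed while automicrobatching is still probing sizes, and are only removed after the same microbatch size has succeeded three times in a row. A minimal sketch of that kind of streak counter is shown below; the class and method names are hypothetical and are not Composer's API.

```python
# Illustrative sketch only; ConsecutiveSuccessTracker and its methods are
# hypothetical names, not part of Composer.


class ConsecutiveSuccessTracker:
    """Track how many times in a row the current microbatch size succeeded.

    The sync hooks cost an extra all-reduce per unshard, so they are only
    dropped once the size looks stable (3 consecutive successes, matching
    the commit message above).
    """

    def __init__(self, required_successes: int = 3):
        self.required_successes = required_successes
        self.streak = 0

    def record_success(self) -> bool:
        """Return True once it is safe to drop the hooks."""
        self.streak += 1
        return self.streak >= self.required_successes

    def record_oom(self) -> None:
        """Any OOM resets the streak, so the hooks stay installed."""
        self.streak = 0
```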
@@ -31,8 +31,17 @@
from torch.distributed.fsdp._fsdp_extensions import _ext_pre_load_state_dict_transform
from torch.distributed.utils import _replace_by_prefix

from composer.utils import dist

log = logging.getLogger(__name__)


def patch_unshard_for_automicrobatching(auto_microbatch_size_found=False):
    if version.parse(torch.__version__) >= version.parse('2.3.1'):
        from torch.distributed.fsdp._flat_param import FlatParamHandle
        if auto_microbatch_size_found:
            FlatParamHandle.unshard = (unshard)
        else:
            FlatParamHandle.unshard = (unshard_with_sync)


def patch_pytorch():
    """Monkey patches pytorch functions based on pytorch version."""
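As a rough usage sketch, the toggle above could be driven from the microbatch-size search along the following lines. The search loop, the `try_microbatch` callback, and the import path are assumptions for illustration; only `patch_unshard_for_automicrobatching` itself comes from this diff.

```python
# Hypothetical driver for illustration only; the import path and the search
# loop are assumptions, not Composer's actual automicrobatching code.
from composer.trainer.mosaic_fsdp import patch_unshard_for_automicrobatching  # assumed path


def run_microbatch_search(candidate_sizes, try_microbatch):
    """Try candidate microbatch sizes (largest first) until one fits."""
    for size in sorted(candidate_sizes, reverse=True):
        # While searching, install the synchronizing unshard so a rank that
        # OOMs cannot deadlock the ranks that did not.
        patch_unshard_for_automicrobatching(auto_microbatch_size_found=False)
        if try_microbatch(size):
            # Once a working size is found, revert to the stock unshard and
            # skip the extra all-reduces on every subsequent unshard.
            patch_unshard_for_automicrobatching(auto_microbatch_size_found=True)
            return size
    raise RuntimeError('No candidate microbatch size fit in memory')
```

Because the candidates are an arbitrary list, nothing in this pattern assumes powers of 2, which is the other half of the PR title.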
@@ -122,6 +131,73 @@ def patch_pytorch():
        _MeshEnv.create_child_mesh = create_child_mesh
        DeviceMesh.__getitem__ = device_mesh__getitem__


@no_type_check
def unshard(self):
    """
    Run the unshard logic.
    This is an unpatched method from pytorch, meant to be reverted to
    whenever automicrobatching turns off its hooks for increased throughput.
    This includes all-gathering the flat parameter
    and switching to using the unsharded flat parameter. If the handle does
    not need unsharding, then this only switches to using the unsharded
    flat parameter. For ``NO_SHARD``, this is a no-op.
    If FSDP is in :meth:`summon_full_params` and the handle uses parameter
    mixed precision, then the parameter is forced to full precision.
    """

Review comment (on lines +134 to +144): This should probably be in the if torch 2.3.1 section.
    if not self.needs_unshard():
        # Even when not needing an unshard, we should switch to using
        # the unsharded flat parameter
        unsharded_flat_param = (
            self._get_padded_unsharded_flat_param()
            if self.uses_sharded_strategy
            else self.flat_param
        )
        self._use_unsharded_flat_param(unsharded_flat_param)
        return
    unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
    padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param)
    self._use_unsharded_flat_param(padded_unsharded_flat_param)
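To make the inline review comment above concrete, one way to honor it would be to define the patched variants only when the version gate passes, for example as below. This is a sketch of the reviewer's suggestion, not code from this PR.

```python
from typing import no_type_check

import torch
from packaging import version

# Sketch of the reviewer's suggestion, not code from this PR: keep the patched
# unshard variants behind the same torch >= 2.3.1 gate that
# patch_unshard_for_automicrobatching checks, so older torch versions never
# even define them.
if version.parse(torch.__version__) >= version.parse('2.3.1'):

    @no_type_check
    def unshard(self):
        """Same body as in the diff above, just version-gated."""
        ...

    @no_type_check
    def unshard_with_sync(self):
        """Same body as in the diff above, just version-gated."""
        ...
```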
@no_type_check
def unshard_with_sync(self):
    """
    Run the unshard logic, but with a sync after a :meth:`_alloc_padded_unsharded_flat_param`
    to prevent deadlocks when some ranks OOM after the alloc call and others do not.
    This is a patched method from pytorch, meant to be called when automicrobatching
    turns on hooks in its search process for the optimal non-OOMing microbatch size.
    This includes all-gathering the flat parameter
    and switching to using the unsharded flat parameter. If the handle does
    not need unsharding, then this only switches to using the unsharded
    flat parameter. For ``NO_SHARD``, this is a no-op.
    If FSDP is in :meth:`summon_full_params` and the handle uses parameter
    mixed precision, then the parameter is forced to full precision.
    """
    if not self.needs_unshard():
        # Even when not needing an unshard, we should switch to using
        # the unsharded flat parameter
        unsharded_flat_param = (
            self._get_padded_unsharded_flat_param()
            if self.uses_sharded_strategy
            else self.flat_param
        )
        self._use_unsharded_flat_param(unsharded_flat_param)
        return
    unsharded_flat_param = self._alloc_padded_unsharded_flat_param()

    # Check if any other rank hit an OOM
    found_cuda_oom_tensor = torch.tensor([0], dtype=torch.uint8).to(self.device, non_blocking=True)

    dist.all_reduce(found_cuda_oom_tensor, reduce_operation='MAX')
    found_cuda_oom = found_cuda_oom_tensor.item()
    # Signal current rank is still in batch
    all_ranks_finished_tensor = torch.tensor([0], dtype=torch.uint8).to(self.device, non_blocking=True)

    dist.all_reduce(all_ranks_finished_tensor, reduce_operation='MIN')

    if found_cuda_oom == 1:
        raise RuntimeError('CUDA out of memory encountered on a different rank')
    padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param)
    self._use_unsharded_flat_param(padded_unsharded_flat_param)


def build_metadata(
    self,
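One thing this hunk does not show is how `found_cuda_oom_tensor` ever becomes 1: `unshard_with_sync` only contributes a zero-initialized flag, so the 1 has to come from the rank that actually hit the OOM, presumably in the trainer's OOM handling. The sketch below illustrates that general flag pattern under that assumption; the helper name and the place the flag is raised are illustrative and not taken from this PR.

```python
# Sketch of the OOM-signalling pattern only; the helper name and where it is
# called from are assumptions, not code from this PR. The real coordination
# between the training loop and the patched unshard is more involved.
import torch

from composer.utils import dist


def _microbatch_oom_flag(device, run_microbatch) -> bool:
    """Run one microbatch and turn a rank-local OOM into a global signal."""
    found_oom = 0
    try:
        run_microbatch()
    except torch.cuda.OutOfMemoryError:
        # This rank OOMed; its 1 will dominate the MAX all-reduce below.
        found_oom = 1

    # Every rank contributes its local flag. Because the reduction is MAX,
    # a single OOM anywhere makes the result 1 on all ranks, which is the
    # same convention unshard_with_sync checks with its zero-initialized flag.
    oom_tensor = torch.tensor([found_oom], dtype=torch.uint8, device=device)
    dist.all_reduce(oom_tensor, reduce_operation='MAX')
    return bool(oom_tensor.item())
```

A caller that sees True can then shrink the microbatch size on every rank at once instead of letting a single rank fail alone.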
Review comment: can you add a comment on what this is doing?