Add minimal FSDP example #23

Merged: 4 commits merged from minimal-fsdp into main on Jul 2, 2024

Conversation

coreystatendet (Contributor):

Add an example (originally by Garrett with small updates from me) that shows how to use torch's FSDP implementation for LLM training alongside Core API.

Comment on lines 41 to 42
if next_idx == 0:
    generator.manual_seed(42 + rank + 100000 * is_validation)
Member:

Just curious: why this if block? Why not seed immediately after creating the generator?

Also, this isn't very important, but I guess the restarting logic starts the dataloader over, rather than continuing from the last batch. I wouldn't fix this, just maybe note it in a comment.

coreystatendet (Contributor Author):

Goal: I'm trying to have the generator yield a recurring sequence of length simulated_size_in_batches. I do this by reseeding to the same initial seed after simulated_size_in_batches steps, which comes from setting next_idx = (next_idx + 1) % simulated_size_in_batches.

I avoid repeating the seeding code by having the initial seed and the subsequent reseeds share the same if condition here.
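
For reference, a hypothetical standalone sketch of that pattern (the seed formula and simulated_size_in_batches follow the example; the batch shape and vocab size here are made up):

import torch

def fake_batches(rank: int, is_validation: bool, simulated_size_in_batches: int):
    # Yields a repeating sequence of simulated_size_in_batches random batches.
    generator = torch.Generator()
    next_idx = 0
    while True:
        if next_idx == 0:
            # Hit on the very first batch (the initial seed) and again every
            # simulated_size_in_batches batches (the reseed), so one branch
            # covers both cases.
            generator.manual_seed(42 + rank + 100000 * is_validation)
        yield torch.randint(0, 32000, (8, 512), generator=generator)
        next_idx = (next_idx + 1) % simulated_size_in_batches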

Member:

ah, I missed the modulo, got it now


# Wrap the embedding layer, the lm head, and each transformer block into its own FSDP unit:
auto_wrap_policy = ModuleWrapPolicy([TransformerBlock, EmbedAndEncode, LMHead])

Member:

I should have caught this before, but there are some small use_amp = True-related issues we probably want to address.

First, there is some special handling of the FSDP weight/comms/etc. dtypes when using amp, as in:

from torch.distributed.fsdp import MixedPrecision

fsdp_model = FSDP(
    model,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
    ),
    ...,
)

Not needed for correctness, just efficiency.

Second, for the autocast in get_loss it's better practice to specify the dtype=torch.bfloat16 arg, as in

with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.bfloat16):
    outputs = fsdp_model(inputs)

The default is dtype=torch.float16 (not bfloat16).

Third, with bfloat16 there should be no need for the ShardedGradScaler; only needed for float16.
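
For reference, here is a hedged, self-contained sketch that puts the three suggestions together: bfloat16 MixedPrecision on the FSDP wrapper, an explicit bfloat16 autocast, and no ShardedGradScaler. The tiny Block model and the bare training step are stand-ins of mine, not the example's actual code; launch with torchrun on CUDA GPUs that support bfloat16.

import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import ModuleWrapPolicy


class Block(nn.Module):
    # Minimal stand-in for the example's TransformerBlock.
    def __init__(self, dim: int = 128) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.net(x)


if __name__ == "__main__":
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = nn.Sequential(*[Block() for _ in range(4)]).cuda()
    fsdp_model = FSDP(
        model,
        auto_wrap_policy=ModuleWrapPolicy([Block]),
        # Keep sharded params and gradient reductions in bfloat16 for efficiency.
        mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16),
    )
    optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-4)

    inputs = torch.randn(8, 128, device="cuda")
    # Explicit bfloat16 autocast (the default dtype is float16); no ShardedGradScaler
    # is needed because bfloat16 doesn't require loss scaling.
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        loss = fsdp_model(inputs).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    dist.destroy_process_group()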

Member:

Apologies, I should have caught these. Think I missed them when just looking at the diffs.

coreystatendet (Contributor Author):

First, there is some special handling for the FSDP weight/comms/etc types when using amp ... Not needed for correctness, just efficiency.

Good catch, I wasn't aware of this; I assume autocast just fails to recognize / handle some of the buffers and operations used inside of FSDP without it?

Second, for the autocast in get_loss it's better practice to specify the dtype=torch.bfloat16 arg, as in

bfloat16 isn't supported on pre-A100 GPUs, right? I imagine we want our example to run on any GPU, hence defaulting to float16. From a quick Google search, it looks like this can be made conditional via e.g.

compute_capability = torch.cuda.get_device_capability()
if compute_capability[0] < 8:
     ....

So we could switch to using bfloat16 iff it's supported.

Third, with bfloat16 there should be no need for the ShardedGradScaler; only needed for float16.

See above, but yes -- if we switch to bfloat16 entirely or conditionally, we could omit.
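
A possible shape for that conditional version, as a sketch only (use_amp, fsdp_model, inputs, compute_loss, and optimizer are assumed to come from the existing example; the compute-capability threshold of 8 follows the comment above):

import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

# bfloat16 needs compute capability >= 8.0 (A100 / Ampere and newer); otherwise
# fall back to float16 plus gradient scaling.
supports_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
amp_dtype = torch.bfloat16 if supports_bf16 else torch.float16

# Loss scaling is only needed for float16's narrow dynamic range. With enabled=False,
# scale() returns the loss unchanged and step() just calls optimizer.step().
scaler = ShardedGradScaler(enabled=use_amp and amp_dtype == torch.float16)

with torch.cuda.amp.autocast(enabled=use_amp, dtype=amp_dtype):
    outputs = fsdp_model(inputs)
    loss = compute_loss(outputs)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()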

garrett361 (Member), Jul 2, 2024:

Good call on hardware compatibility. Apparently if the hardware doesn't support bfloat16, it silently falls back to float32? Seems weird.

Anyway, your call on what to support there. Could also make it configurable as another hparam, amp_dtype?

I assume autocast just fails to recognize / handle some of the buffers and operations used inside of FSDP without it?

Nah, it's a little different, IIUC. The actual weights used for the forwards are held in whatever precision you specify in MixedPrecision, or float32 if omitted. Under autocasting, various tensors have their dtypes changed so that specific operations occur in either high or low precision. E.g. matmuls in low-precision and softmax in high-precision. MixedPrecision affects which direction most of the casts occur in (among other things).

In the non-FSDP case, weights would always be kept in high precision and down-cast as needed. With FSDP and MixedPrecision(param_dtype=torch.bfloat16), say, you also avoid those down-casts because the weights are already in the desired precision. But you then might need some extra up-casts elsewhere.

There are no explicit failures if you don't specify MixedPrecision w/ FSDP and run under autocast; just perf differences, and likely some (hopefully) small numerical differences also.

I likely got some details incorrect here, also. The FSDP + amp API isn't super well documented. A while ago I talked to one of the OLMo researchers about exactly this topic and how confusing it is. (They are heavily FSDP-based.)

Member:

Alternatively, if you specify MixedPrecision in FSDP but don't wrap in autocast, none of the casts ever happen, and you can end up running every part of the forwards in low precision rather than mixed.
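
To make that concrete, here is a small, hypothetical non-FSDP illustration of the casting behavior described above (requires a CUDA device; the two weight tensors stand in for FSDP params with and without MixedPrecision):

import torch

w_fp32 = torch.randn(64, 64, device="cuda")                        # like params with no MixedPrecision
w_bf16 = torch.randn(64, 64, device="cuda", dtype=torch.bfloat16)  # like param_dtype=torch.bfloat16
x = torch.randn(8, 64, device="cuda")

# Under autocast, matmuls run in low precision either way; low-precision weights
# just avoid the down-cast (and, with FSDP, shrink the comms).
with torch.autocast("cuda", dtype=torch.bfloat16):
    print((x @ w_fp32).dtype)                   # torch.bfloat16 (operands down-cast by autocast)
    print((x @ w_bf16).dtype)                   # torch.bfloat16 (weight already low precision)
    print(torch.softmax(x @ w_bf16, -1).dtype)  # torch.float32 (softmax autocasts to high precision)

# Without autocast nothing is cast for you, so bfloat16 weights mean the whole
# forward runs in low precision rather than mixed precision.
print((x.to(torch.bfloat16) @ w_bf16).dtype)    # torch.bfloat16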

garrett361 (Member) left a comment:

Looks good! I missed some things about mixed precision that might be worth fixing, but up to you.



coreystatendet merged commit d5f4b27 into main on Jul 2, 2024 (1 check failed).
coreystatendet deleted the minimal-fsdp branch on July 2, 2024 at 19:51.