enable FSDP2 + fp8 all-gather and fix TP fp8 all-gather #413

Merged · 24 commits · Jul 16, 2024

Changes from 17 commits
33 changes: 33 additions & 0 deletions test_runner.py
@@ -273,6 +273,39 @@ def build_test_list():
"fsdp2_mem_tracker",
ngpu=4,
),
OverrideDefinitions(
weifengpy (Contributor, Author) commented on Jul 13, 2024:

Added the following to CI:

  • 1D FSDP original-dtype all-gather
  • 1D FSDP fp8 all-gather
  • 1D FSDP fp8 all-gather with precomputed dynamic scales

Follow-ups are needed to enable TP fp8 all-gather in CI (the current CI tokenizer has a vocab size of 2556, which is not divisible by 16; see #461):

  • 1D TP fp8 all-gather
  • 2D FSDP + TP fp8 all-gather

[
[
"--training.fp8_linear",
]
],
"FSDP2 with original dtype",
"fp8_fsdp2_orig_all_gather",
ngpu=4,
),
OverrideDefinitions(
[
[
"--training.fp8_linear",
"--training.enable_fsdp_fp8_all_gather",
]
],
"FSDP2 with fp8 all-gather",
"fp8_fsdp2_fp8_all_gather",
ngpu=4,
),
OverrideDefinitions(
[
[
"--training.fp8_linear",
"--training.enable_fsdp_fp8_all_gather",
"--training.precompute_float8_dynamic_scale_for_fsdp",
]
],
"FSDP2 with fp8 all-gather and precomputed dynamic scales",
Contributor commented:

nit: comment for 2D

weifengpy (Contributor, Author) replied:

In the end I had to remove 2D from this PR. The current CI tokenizer has a vocab size of 2556, but fp8 GEMM needs the vocab size to be divisible by 16 (#461).

I can follow up with you on how to get a tokenizer with a vocab size of 2560 to unblock 1D TP + fp8 and 2D + fp8 in CI.

"fp8_fsdp2_fp8_all_gather_precompute_dynamic_scales",
ngpu=4,
),
]
return integration_tests_flavors

12 changes: 12 additions & 0 deletions torchtitan/config_manager.py
@@ -347,6 +347,18 @@ def __init__(self):
here: https://github.com/pytorch-labs/float8_experimental
""",
)
self.parser.add_argument(
Contributor commented:

As discussed offline, let's refactor the fp8 configs, e.g. have a dedicated field for enabling fp8 or not.

weifengpy (Contributor, Author) replied:

Renamed fp8_linear to enable_fp8_linear.

Contributor commented:

One thing to note is that right now this is a boolean that swaps in the default float8 recipe: dynamic scaling x tensor-wise ScalingGranularity x all tensors involved in the matmul (input, weight, grad).

I think we should brainstorm an elegant solution for users to express their desired config here.

weifengpy (Contributor, Author) replied:

Good question. Eventually we might have to expose args/kwargs from swap_linear_with_float8_linear for flexibility.
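A hypothetical sketch of what exposing those kwargs through the job config might look like. Only scaling_type_w=TensorScalingType.DYNAMIC is exercised in this PR; the fp8_swap_kwargs field and the helper name are illustrative assumptions, not existing torchtitan options.

from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear


def swap_linears_from_config(model, job_config):
    # Default recipe used in this PR: dynamic scaling for the weight.
    swap_kwargs = {"scaling_type_w": TensorScalingType.DYNAMIC}
    # Forward any user-provided overrides for flexibility (hypothetical field).
    swap_kwargs.update(getattr(job_config.training, "fp8_swap_kwargs", {}))
    return swap_linear_with_float8_linear(model, **swap_kwargs)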

"--training.enable_fsdp_fp8_all_gather",
action="store_true",
default=False,
help="Whether enable fp8 all-gather in FSDP",
)
self.parser.add_argument(
"--training.precompute_float8_dynamic_scale_for_fsdp",
action="store_true",
default=False,
help="Whether precompute fp8 scales dynamically for FSDP",
)
self.parser.add_argument(
"--training.gc_freq",
type=int,
Expand Down
28 changes: 25 additions & 3 deletions torchtitan/float8_linear.py
@@ -12,13 +12,30 @@

# Note: Performance
# Float8 experimental is intended to be run under `torch.compile` for competitive performance
import contextlib

import float8_experimental.config as config

import torch
import torch.nn as nn
from float8_experimental.float8_linear import TensorScalingType

from torchtitan.config_manager import JobConfig
from torchtitan.logging_utils import logger


@contextlib.contextmanager
def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool):
prev = config.enable_fsdp_fp8_all_gather
torch.distributed.barrier()
config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
try:
yield
finally:
torch.distributed.barrier()
config.enable_fsdp_fp8_all_gather = prev


def build_fp8_linear(model: nn.Module, job_config: JobConfig):
"""
This function converts the linear layers to `Float8Linear`. Note that today,
@@ -27,8 +44,8 @@ def build_fp8_linear(model: nn.Module, job_config: JobConfig):
This will mutate the model inplace.
"""
use_fp8_linear = job_config.training.fp8_linear
enable_fsdp_fp8_all_gather = job_config.training.enable_fsdp_fp8_all_gather
Contributor commented:

Discussed offline: please check if it makes sense to enable it only when dp_degree > 1.

weifengpy (Contributor, Author) replied:

Added a check on parallel_dims.dp_enabled.

try:
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
swap_linear_with_float8_linear,
)
@@ -38,5 +55,10 @@ def build_fp8_linear(model: nn.Module, job_config: JobConfig):
) from exc
if use_fp8_linear:
# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
swap_linear_with_float8_linear(model, Float8Linear)
logger.info("Swapped to Float8Linear layers")
with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
Contributor commented:

noop Q: do we need this in a context manager to make testing + resetting easier?

weifengpy (Contributor, Author) commented on Jul 16, 2024:

Hmm, set_enable_fsdp_fp8_all_gather is a context manager right now. Do you mean "why" it should be a context manager?

EDIT: I also see you mentioned "make testing + resetting easier", which answers the why, so I am not sure it's a question for me.

swap_linear_with_float8_linear(
model, scaling_type_w=TensorScalingType.DYNAMIC
)
logger.info(
f"Swapped to Float8Linear layers with {enable_fsdp_fp8_all_gather=}"
)
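Following up on the "testing + resetting" question in the thread above, a minimal sketch of how the context manager keeps the global flag scoped. It assumes a process group is already initialized, since set_enable_fsdp_fp8_all_gather calls torch.distributed.barrier(); the function name is illustrative.

import float8_experimental.config as config

from torchtitan.float8_linear import set_enable_fsdp_fp8_all_gather


def check_flag_is_scoped():
    original = config.enable_fsdp_fp8_all_gather
    with set_enable_fsdp_fp8_all_gather(True):
        assert config.enable_fsdp_fp8_all_gather is True
        # model construction / swap_linear_with_float8_linear would run here
    # On exit the previous value is restored, so later tests are unaffected.
    assert config.enable_fsdp_fp8_all_gather is original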
17 changes: 17 additions & 0 deletions train.py
@@ -19,6 +19,7 @@

import torch
import torch.nn.functional as F
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
wanchaol (Contributor) commented on Jul 16, 2024:

@weifengpy I think we should hide this import in the path where the enable_fp8_allgather path happens.

The problem is that for every feature requiring an additional install of another dependency, we should hide the import in the path that uses it instead of importing it globally; otherwise, users who didn't install float8_experimental would fail to train after they rebase.

Please submit a follow-up PR to fix this.

weifengpy (Contributor, Author) replied:

Got you. I am moving it from top level into the if-else now: #464.

Thanks for the timely reminder.

from torch.distributed import destroy_process_group
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.elastic.multiprocessing.errors import record
@@ -218,6 +219,11 @@ def loss_fn(pred, labels):
# apply fp8 linear module swap
if job_config.training.fp8_linear:
build_fp8_linear(whole_model, job_config)
else:
Contributor commented:

Can remove this in favor of simplicity if it is a no-op flag when fp8_linear=False.

weifengpy (Contributor, Author) replied:

Removed the ValueError on enable_fp8_linear=False.

if job_config.training.enable_fsdp_fp8_all_gather:
raise ValueError(
"enable_fsdp_fp8_all_gather can only be used with fp8_linear"
)

# log model size
model_param_count = get_num_params(whole_model)
@@ -398,6 +404,17 @@ def loss_fn(pred, labels):
optimizers.step()
lr_schedulers.step()

weifengpy (Contributor, Author) commented:

Add a comment to explain precompute_float8_dynamic_scale_for_fsdp.

weifengpy (Contributor, Author) replied:

Done.

if job_config.training.precompute_float8_dynamic_scale_for_fsdp:
Contributor commented:

Discussed offline: can refactor to make it simpler.

weifengpy (Contributor, Author) commented on Jul 16, 2024:

Removed the ValueError when enable_fp8_linear/enable_fsdp_fp8_all_gather=False.

if (not job_config.training.fp8_linear) or (
not job_config.training.enable_fsdp_fp8_all_gather
):
raise ValueError(
"precompute_float8_dynamic_scale_for_fsdp is only "
"supported when fp8_linear and "
"enable_fsdp_fp8_all_gather are both enabled"
)
precompute_float8_dynamic_scale_for_fsdp(model)
Contributor commented:

Maybe a noob question: could you briefly explain what this is doing?
I wonder, since we are already using context functions for fp8, whether we could have a context and run it in a .step() function here, just like the optimizer, LR scheduler, and profiler. This would make the code consistent.

weifengpy (Contributor, Author) replied:

> could you briefly explain what this is doing

precompute_float8_dynamic_scale_for_fsdp is a for-loop over model.parameters(). It issues a single all-reduce for all parameters, i.e. abs(max(param)) for param in model.parameters(), and saves the amax/scale as param._precomputed_scale. This speeds up the training loop since we do not need to compute the amax/scale for each parameter inside the loop.

> we are already using context functions for fp8

Do you mean set_enable_fsdp_fp8_all_gather? That is for model initialization, where we swap nn.Linear with the user-defined float8 linear. precompute_float8_dynamic_scale_for_fsdp is for the training loop.
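To make that description concrete, here is a minimal conceptual sketch under stated assumptions: the real helper is float8_experimental.fsdp_utils.precompute_float8_dynamic_scale_for_fsdp and additionally handles FSDP2's sharded (DTensor) weights; the function name, clamp epsilon, and dtype choice below are illustrative.

import torch
import torch.distributed as dist

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # max representable fp8 e4m3 value


def precompute_scales_sketch(model: torch.nn.Module) -> None:
    params = list(model.parameters())
    if not params:
        return
    # One local abs-max per parameter, stacked so a single collective covers
    # every parameter instead of issuing one all-reduce per weight.
    amax = torch.stack([p.detach().abs().amax().float() for p in params])
    if dist.is_initialized():
        dist.all_reduce(amax, op=dist.ReduceOp.MAX)
    scales = E4M3_MAX / torch.clamp(amax, min=1e-12)
    for p, scale in zip(params, scales):
        # Cache the scale on the parameter, mirroring the
        # param._precomputed_scale attribute described in the comment above.
        p._precomputed_scale = scale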

weifengpy (Contributor, Author) commented:

Per the suggestion, raise an error if use_fp8_linear=False or enable_fsdp_fp8_all_gather=False.

Contributor commented:

noob q: do we eventually want to just put this in fsdp2?

Contributor commented:

It has to be done after the optimizer step (since parameter values change). Are you suggesting running this in the root module's pre-forward?

Contributor commented:

Yeah, anywhere between the (n-1)-th optimizer step and the first all-gather in the n-th step where fsdp2 has control (if there's any).

Contributor commented:

That makes sense. One concern is that FSDP is agnostic to the fp8 all-gather: FSDP does not know that the fsdp_pre_all_gather and fsdp_post_all_gather of the Float8Linear weights are implemented to do fp8 all-gather, so at best the user would still need to register a module forward pre-hook or something to run this method.

yifuwang (Contributor) commented on Jul 15, 2024:

Ah I see. Somehow I thought fsdp2 was fp8-aware.
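A hedged sketch of the forward pre-hook alternative discussed above, assuming the helper accepts the root (FSDP-wrapped) module as it does in this PR; the registration function name is illustrative. The hook refreshes the cached scales between the previous optimizer step and the next forward's first all-gather.

from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp


def register_precompute_pre_hook(root_module):
    def _pre_forward(module, args):
        # Parameters changed in the previous optimizer step, so refresh the
        # cached amax/scale before FSDP all-gathers the fp8 weights.
        precompute_float8_dynamic_scale_for_fsdp(module)

    # register_forward_pre_hook returns a handle; call handle.remove() to undo.
    return root_module.register_forward_pre_hook(_pre_forward)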


losses_since_last_log.append(loss)

# log metrics