diff --git a/estimation.py b/estimation.py
index ddf24d8a..e652c581 100644
--- a/estimation.py
+++ b/estimation.py
@@ -124,8 +124,8 @@ def loss_fn(pred, labels):
     whole_model = model_cls.from_model_args(model_config)
 
     # apply fp8 linear module swap
-    if job_config.training.fp8_linear:
-        build_fp8_linear(whole_model, job_config)
+    if job_config.training.enable_fp8_linear:
+        build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)
 
     # apply PT-D DP/TP parallelisms and activation checkpointing
     model_parts = [whole_model]
diff --git a/test_runner.py b/test_runner.py
index 319f99d7..f2f80504 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -273,6 +273,39 @@ def build_test_list():
             "fsdp2_mem_tracker",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.enable_fp8_linear",
+                ]
+            ],
+            "FSDP2 with original dtype",
+            "fp8_fsdp2_orig_all_gather",
+            ngpu=4,
+        ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.enable_fp8_linear",
+                    "--training.enable_fsdp_fp8_all_gather",
+                ]
+            ],
+            "FSDP2 with fp8 all-gather",
+            "fsdp2_fp8_all_gather",
+            ngpu=4,
+        ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.enable_fp8_linear",
+                    "--training.enable_fsdp_fp8_all_gather",
+                    "--training.precompute_float8_dynamic_scale_for_fsdp",
+                ]
+            ],
+            "FSDP2 with fp8 all-gather and precomputed dynamic scales",
+            "fsdp2_fp8_all_gather_precompute_dynamic_scales",
+            ngpu=4,
+        ),
     ]
     return integration_tests_flavors
 
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 3ade1b9d..0dfe1bb0 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -338,7 +338,7 @@ def __init__(self):
             help="Whether to compile the model",
         )
         self.parser.add_argument(
-            "--training.fp8_linear",
+            "--training.enable_fp8_linear",
            action="store_true",
             help="""
                 If true, swaps `torch.nn.Linear` with `Float8Linear` with
@@ -347,6 +347,18 @@ def __init__(self):
                 here: https://github.com/pytorch-labs/float8_experimental
             """,
         )
+        self.parser.add_argument(
+            "--training.enable_fsdp_fp8_all_gather",
+            action="store_true",
+            default=False,
+            help="Whether to enable fp8 all-gather in FSDP",
+        )
+        self.parser.add_argument(
+            "--training.precompute_float8_dynamic_scale_for_fsdp",
+            action="store_true",
+            default=False,
+            help="Whether to precompute fp8 scales dynamically for FSDP",
+        )
         self.parser.add_argument(
             "--training.gc_freq",
             type=int,
diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py
index 0bd0900c..f41a812d 100644
--- a/torchtitan/float8_linear.py
+++ b/torchtitan/float8_linear.py
@@ -12,31 +12,58 @@
 # Note: Performance
 # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
+import contextlib
+from typing import Optional
+
+import float8_experimental.config as config
+
+import torch
 import torch.nn as nn
+from float8_experimental.float8_linear import TensorScalingType
 
 from torchtitan.config_manager import JobConfig
 from torchtitan.logging_utils import logger
 
 
-def build_fp8_linear(model: nn.Module, job_config: JobConfig):
+@contextlib.contextmanager
+def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool):
+    prev = config.enable_fsdp_fp8_all_gather
+    torch.distributed.barrier()
+    config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
+    try:
+        yield
+    finally:
+        torch.distributed.barrier()
+        config.enable_fsdp_fp8_all_gather = prev
+
+
+def build_fp8_linear(
+    model: nn.Module, job_config: JobConfig, dp_enabled: Optional[bool] = False
+):
     """
     This function converts the linear layers to `Float8Linear`. Note that today,
     only dynamic tensor scaling (the default) is supported.
 
     This will mutate the model inplace.
     """
-    use_fp8_linear = job_config.training.fp8_linear
+    enable_fp8_linear = job_config.training.enable_fp8_linear
+    enable_fsdp_fp8_all_gather = (
+        job_config.training.enable_fsdp_fp8_all_gather and dp_enabled
+    )
     try:
-        from float8_experimental.float8_linear import Float8Linear
         from float8_experimental.float8_linear_utils import (
             swap_linear_with_float8_linear,
         )
+
+        # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
+        with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
+            swap_linear_with_float8_linear(
+                model, scaling_type_w=TensorScalingType.DYNAMIC
+            )
+        logger.info(
+            f"Swapped to Float8Linear layers with {enable_fsdp_fp8_all_gather=}"
+        )
     except ImportError as exc:
         raise ImportError(
             "float8_experimental is not installed. Please install it to use fp8 linear layers."
         ) from exc
-    if use_fp8_linear:
-        # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
-        swap_linear_with_float8_linear(model, Float8Linear)
-        logger.info("Swapped to Float8Linear layers")
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 1b414159..b33e8870 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -117,12 +117,24 @@ def selective_checkpointing_context_fn():
 
 def get_tp_parallel_strategy(
     job_config: JobConfig,
+    model: nn.Module,
 ) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
     """Get the parallel strategy for the transformer model.
 
     This function handles the special case of using float8 with tensor parallelism.
     """
-    if job_config.training.fp8_linear == "dynamic":
+    if job_config.training.enable_fp8_linear:
+        from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+
+        if any(
+            isinstance(m, Float8Linear)
+            and m.scaling_type_w is TensorScalingType.DELAYED
+            for m in model.modules()
+        ):
+            raise NotImplementedError(
+                "1D TP fp8 all-gather only supports dynamic scaling"
+            )
+
         from float8_experimental.float8_tensor_parallel import (
             Float8ColwiseParallel,
             Float8RowwiseParallel,
@@ -346,7 +358,7 @@ def apply_tp(
         rowwise_parallel_weight,
         colwise_parallel_weight,
         prepare_module_input,
-    ) = get_tp_parallel_strategy(job_config)
+    ) = get_tp_parallel_strategy(job_config, model)
     loss_parallel = parallel_dims.loss_parallel_enabled
 
     # 1. Parallelize the embedding and shard its outputs (which are the first
diff --git a/train.py b/train.py
index 8e55c210..14008525 100644
--- a/train.py
+++ b/train.py
@@ -19,6 +19,7 @@
 
 import torch
 import torch.nn.functional as F
+from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 from torch.distributed import destroy_process_group
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
@@ -216,8 +217,8 @@ def loss_fn(pred, labels):
     whole_model = model_cls.from_model_args(model_config)
 
     # apply fp8 linear module swap
-    if job_config.training.fp8_linear:
-        build_fp8_linear(whole_model, job_config)
+    if job_config.training.enable_fp8_linear:
+        build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)
 
     # log model size
     model_param_count = get_num_params(whole_model)
@@ -398,6 +399,15 @@ def loss_fn(pred, labels):
             optimizers.step()
             lr_schedulers.step()
 
+            if (
+                job_config.training.enable_fp8_linear
+                and job_config.training.enable_fsdp_fp8_all_gather
+                and job_config.training.precompute_float8_dynamic_scale_for_fsdp
+            ):
+                # calculate float8 dynamic amax/scale for all parameters for FSDP2
+                # it issues a single all-reduce for all parameters at once for better performance
+                precompute_float8_dynamic_scale_for_fsdp(model)
+
             losses_since_last_log.append(loss)
 
             # log metrics
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index cb2fb215..6064ced1 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -37,7 +37,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 10
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-fp8_linear = false
+enable_fp8_linear = false
 compile = false
 dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M)
diff --git a/train_configs/llama2_13b.toml b/train_configs/llama2_13b.toml
index 05e3c27b..f4061ad0 100644
--- a/train_configs/llama2_13b.toml
+++ b/train_configs/llama2_13b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-fp8_linear = false
+enable_fp8_linear = false
 compile = false
 dataset = "c4"
diff --git a/train_configs/llama2_70b.toml b/train_configs/llama2_70b.toml
index 5b2dd493..19e033b8 100644
--- a/train_configs/llama2_70b.toml
+++ b/train_configs/llama2_70b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 8 # 8-way TP
-fp8_linear = false
+enable_fp8_linear = false
 compile = false
 dataset = "c4"
diff --git a/train_configs/llama2_7b.toml b/train_configs/llama2_7b.toml
index 9b72246a..95d67667 100644
--- a/train_configs/llama2_7b.toml
+++ b/train_configs/llama2_7b.toml
@@ -32,7 +32,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1 # dp-only would be sufficient for 7B
-fp8_linear = false
+enable_fp8_linear = false
 compile = false
 dataset = "c4"
diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml
index 93b529f6..ac6b31c1 100644
--- a/train_configs/llama3_70b.toml
+++ b/train_configs/llama3_70b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 8 # 8-way TP
-fp8_linear = false
+enable_fp8_linear = false
 compile = false
 dataset = "c4"
diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml
index 95a53d56..2c3c6e63 100644
--- a/train_configs/llama3_8b.toml
+++ b/train_configs/llama3_8b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-fp8_linear = false
+enable_fp8_linear = false
 compile = false
 dataset = "c4"
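The three new training options compose as exercised by the integration tests above: --training.enable_fp8_linear turns on the Float8Linear module swap, --training.enable_fsdp_fp8_all_gather additionally enables fp8 all-gather (it is and-ed with dp_enabled in build_fp8_linear, so it only takes effect when data parallelism is on), and --training.precompute_float8_dynamic_scale_for_fsdp enables the per-step scale precompute guarded in train.py. As a sketch, assuming the usual config_manager mapping of `[training]` TOML keys onto `--training.*` flags (the same mapping the `enable_fp8_linear = false` entries above rely on), the equivalent config-file form would be:

    [training]
    enable_fp8_linear = true
    enable_fsdp_fp8_all_gather = true
    precompute_float8_dynamic_scale_for_fsdp = true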