From 562d7e29f1fd121f06bfc0853bea7a6b70b23ab5 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Wed, 10 Jul 2024 14:27:13 -0700 Subject: [PATCH 01/29] Renamed parallel styles for transformer block weights ghstack-source-id: 5fb0bf3d08cacf27242ec0f85d5dd3cdc03b739e Pull Request resolved: https://github.com/pytorch/torchtrain/pull/448 --- torchtitan/parallelisms/parallelize_llama.py | 22 +++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 8d796c36..d092ef2a 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -316,9 +316,11 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig): """ tp_mesh = world_mesh["tp"] + # Parallel styles used for transformer block linear weights and their + # inputs may be different for float8 linears ( - row_parallel_strategy, - col_parallel_strategy, + rowwise_parallel_weight, + colwise_parallel_weight, prepare_module_input, ) = get_tp_parallel_strategy(job_config) loss_parallel = parallel_dims.loss_parallel_enabled @@ -336,7 +338,7 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig): output_layouts=Shard(1), ), "norm": SequenceParallel(), - "output": col_parallel_strategy( + "output": colwise_parallel_weight( input_layouts=Shard(1), output_layouts=Shard(-1) if loss_parallel else Replicate(), use_local_output=not loss_parallel, @@ -355,18 +357,18 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig): input_layouts=(Shard(1), None), desired_input_layouts=(Replicate(), None), ), - "attention.wq": col_parallel_strategy(), - "attention.wk": col_parallel_strategy(), - "attention.wv": col_parallel_strategy(), - "attention.wo": row_parallel_strategy(output_layouts=Shard(1)), + "attention.wq": colwise_parallel_weight(), + "attention.wk": colwise_parallel_weight(), + "attention.wv": colwise_parallel_weight(), + "attention.wo": rowwise_parallel_weight(output_layouts=Shard(1)), "ffn_norm": SequenceParallel(), "feed_forward": prepare_module_input( input_layouts=(Shard(1),), desired_input_layouts=(Replicate(),), ), - "feed_forward.w1": col_parallel_strategy(), - "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)), - "feed_forward.w3": col_parallel_strategy(), + "feed_forward.w1": colwise_parallel_weight(), + "feed_forward.w2": rowwise_parallel_weight(output_layouts=Shard(1)), + "feed_forward.w3": colwise_parallel_weight(), } # Adjust attention module to use the local number of heads From 0ddf49ba75b4abbd14f451daf17311bb0458f82b Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Wed, 10 Jul 2024 14:27:13 -0700 Subject: [PATCH 02/29] Added type annotations and more stylistic changes ghstack-source-id: 1bd5b9d5abc8644785132f8eb2baaf8b1cfc5fb5 Pull Request resolved: https://github.com/pytorch/torchtrain/pull/449 --- torchtitan/parallelisms/parallelize_llama.py | 111 ++++++++++++------- 1 file changed, 69 insertions(+), 42 deletions(-) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index d092ef2a..1b414159 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -9,9 +9,11 @@ import copy from collections import defaultdict -from typing import Dict, Tuple +from typing import Tuple, TYPE_CHECKING, Union import torch +import torch.nn as nn +from torch.distributed import DeviceMesh from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy from torch.distributed._tensor import Replicate, Shard @@ -29,8 +31,15 @@ from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.logging_utils import logger +from torchtitan.models.llama.model import ModelArgs from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank +if TYPE_CHECKING: + from torchtitan.parallelisms import ParallelDims + + +DeviceType = Union[int, str, torch.device] + # for selective AC no_recompute_list = { torch.ops.aten.mm.default, @@ -125,23 +134,30 @@ def get_tp_parallel_strategy( def pipeline_llama( - model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict + model: nn.Module, + world_mesh: DeviceMesh, + parallel_dims: "ParallelDims", + job_config: JobConfig, + device: DeviceType, + model_config: ModelArgs, ): - if job_config.experimental.pipeline_parallel_split_mode == "manual": + split_mode = job_config.experimental.pipeline_parallel_split_mode + valid_split_modes = ("manual", "tracer") + if split_mode not in valid_split_modes: + raise ValueError( + f"Invalid split mode: {split_mode}. Valid split modes: {valid_split_modes}" + ) + if split_mode == "manual": return pipeline_llama_manual( model, world_mesh, parallel_dims, job_config, device, model_config ) - elif job_config.experimental.pipeline_parallel_split_mode == "tracer": + elif split_mode == "tracer": return pipeline_llama_tracer( model, world_mesh, parallel_dims, job_config, device, model_config ) - else: - raise NotImplementedError( - f"{job_config.experimental.pipeline_parallel_split_mode} is not a valid split mode" - ) -def _llama_trace_input(job_config, model_config, device="meta"): +def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"): """Get meta tensors with the right input shapes used for tracing""" tokens_shape = (job_config.training.batch_size, job_config.training.seq_len) tokens = torch.randint( @@ -153,18 +169,18 @@ def _llama_trace_input(job_config, model_config, device="meta"): def _mixed_precision_dtype( job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32 ) -> torch.dtype: - """Get the mixed precision dtype if fsdp is enabled, otherwise return the default""" + """Get the mixed precision dtype if FSDP is enabled, otherwise return the default""" mp_arg = job_config.training.mixed_precision_param return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default def pipeline_llama_manual( - whole_model, - world_mesh, - parallel_dims, + whole_model: nn.Module, + world_mesh: DeviceMesh, + parallel_dims: "ParallelDims", job_config: JobConfig, - device, - model_config: Dict, + device: DeviceType, + model_config: ModelArgs, ): """ This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage. @@ -262,19 +278,24 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal def pipeline_llama_tracer( - model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict + model: nn.Module, + world_mesh: DeviceMesh, + parallel_dims: "ParallelDims", + job_config: JobConfig, + device: DeviceType, + model_config: ModelArgs, ): if job_config.model.norm_type == "fused_rmsnorm": - # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode - # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm + # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr + # invocation stride in strict mode from `if dy.stride(-1) != 1:` in + # fused_rmsnorm raise NotImplementedError( - "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm." + "fused_rmsnorm is not compatible with Pipeline Tracer yet. Please use rmsnorm or layernorm." ) - - if _mixed_precision_dtype(job_config, parallel_dims) == torch.bfloat16: + if _mixed_precision_dtype(job_config, parallel_dims) != torch.float32: raise NotImplementedError( - "pipeline tracer doesn't work with fsdp mixed precision currently. " - "To work around, edit fsdp mixed precision config to use fp32." + "Pipeline tracer does not work with FSDP mixed precision yet. " + "To work around, set mixed_precision_param to float32." ) pp_mesh = world_mesh["pp"] @@ -310,10 +331,13 @@ def pipeline_llama_tracer( return (stages, models) -def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig): - """ - Apply tensor parallelism. - """ +def apply_tp( + model: nn.Module, + world_mesh: DeviceMesh, + parallel_dims: "ParallelDims", + job_config: JobConfig, +): + """Apply tensor parallelism.""" tp_mesh = world_mesh["tp"] # Parallel styles used for transformer block linear weights and their @@ -392,10 +416,8 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig): return model -def apply_ac(model, job_config: JobConfig): - """ - Apply activation checkpointing to the model. - """ +def apply_ac(model: nn.Module, job_config: JobConfig): + """Apply activation checkpointing to the model.""" ac_config = job_config.activation_checkpoint @@ -407,18 +429,15 @@ def apply_ac(model, job_config: JobConfig): return model -def apply_compile(model, job_config: JobConfig): - """ - Apply torch.compile to the model. - """ +def apply_compile(model: nn.Module, job_config: JobConfig): + """Apply torch.compile to each transformer block.""" if job_config.model.norm_type == "fused_rmsnorm": raise NotImplementedError( - "fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm." + "fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm." ) for layer_id, transformer_block in model.layers.named_children(): - # turn on per-transformer block compile after AC wrapping and before FSDP # TODO: dynamic shape have some issues so we turn it off for now. # TODO: inline inbuilt nn modules does not work yet, enable it to accelarate # compile time. @@ -430,10 +449,13 @@ def apply_compile(model, job_config: JobConfig): return model -def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig): - """ - Apply data parallelism (FSDP2) to the model. - """ +def apply_dp( + model: nn.Module, + world_mesh: DeviceMesh, + parallel_dims: "ParallelDims", + job_config: JobConfig, +): + """Apply data parallelism (FSDP2) to the model.""" dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names @@ -466,7 +488,12 @@ def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig): return model -def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): +def parallelize_llama( + model: nn.Module, + world_mesh: DeviceMesh, + parallel_dims: "ParallelDims", + job_config: JobConfig, +): """ Apply tensor parallelism, activation checkpointing, torch.compile, and data parallelism to the model. From 535acf60c6ea6f1e3e61ca91a3842a616d9f612b Mon Sep 17 00:00:00 2001 From: Will Constable Date: Sun, 14 Jul 2024 21:17:02 -0700 Subject: [PATCH 03/29] [Cleanup] Remove libuv from run_llama_train.sh libuv is now enabled by default. we can proably do without the educational blurb there, and don't need the env either since the default has landed. ghstack-source-id: 68c8d2abe7eb0777e2add8df7634367c31b7ec06 Pull Request resolved: https://github.com/pytorch/torchtitan/pull/453 --- create_seed_checkpoint.sh | 1 - multinode_trainer.slurm | 1 - run_llama_train.sh | 3 --- 3 files changed, 5 deletions(-) diff --git a/create_seed_checkpoint.sh b/create_seed_checkpoint.sh index 1abc77ec..3dfbde71 100755 --- a/create_seed_checkpoint.sh +++ b/create_seed_checkpoint.sh @@ -18,7 +18,6 @@ set -ex -export USE_LIBUV=1 TRAINER_DIR=${1:-/home/$USER/local/torchtitan} NGPU=1 LOG_RANK=0 diff --git a/multinode_trainer.slurm b/multinode_trainer.slurm index 09b94ef1..4bc495d3 100644 --- a/multinode_trainer.slurm +++ b/multinode_trainer.slurm @@ -53,7 +53,6 @@ export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond" export NCCL_BUFFSIZE=2097152 #export TORCH_DIST_INIT_BARRIER=1 export FI_EFA_SET_CUDA_SYNC_MEMOPS=0 -#export USE_LIBUV=1 CONFIG_FILE=${CONFIG_FILE:-"./train_configs/llama2_13b.toml"} dcgmi profile --pause diff --git a/run_llama_train.sh b/run_llama_train.sh index cf4943a6..5a661284 100755 --- a/run_llama_train.sh +++ b/run_llama_train.sh @@ -7,9 +7,6 @@ set -ex -# libUV is a scalable backend for TCPStore which is used in processGroup -# rendezvous. This is the recommended backend for distributed training. -export USE_LIBUV=1 TRAINER_DIR=${TRAINER_DIR:-/home/$USER/local/torchtitan} # use envs as local overrides for convenience From ac72078bb590d71a74872e514be01c579544ef84 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Sun, 14 Jul 2024 21:17:03 -0700 Subject: [PATCH 04/29] [Cleanup] Organize run_llama_train.sh options Just a little code motion but it looks cleaner to me this way ghstack-source-id: 055fbd557cd9cf189e6b9bd6a7048f1204e1dc5c Pull Request resolved: https://github.com/pytorch/torchtitan/pull/454 --- run_llama_train.sh | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/run_llama_train.sh b/run_llama_train.sh index 5a661284..91133a25 100755 --- a/run_llama_train.sh +++ b/run_llama_train.sh @@ -7,19 +7,13 @@ set -ex -TRAINER_DIR=${TRAINER_DIR:-/home/$USER/local/torchtitan} - # use envs as local overrides for convenience # e.g. # LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh - NGPU=${NGPU:-"8"} NNODES=${NNODES:-"1"} - -# by default log just rank 0 output, +TRAINER_DIR=${TRAINER_DIR:-/home/$USER/local/torchtitan} LOG_RANK=${LOG_RANK:-0} - - CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"} overrides="" From 4b6cdc17627deae58113032889d8d04245a94111 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Sun, 14 Jul 2024 21:17:03 -0700 Subject: [PATCH 05/29] [Cleanup] Split run_llama_train.sh and run_memory_estimation.sh Make each script simpler to read ghstack-source-id: ba3aa65feb6e304736c73daf5bc8ab5fb254f196 Pull Request resolved: https://github.com/pytorch/torchtitan/pull/455 --- run_llama_train.sh | 17 +++-------------- run_memory_estimation.sh | 27 +++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 14 deletions(-) create mode 100755 run_memory_estimation.sh diff --git a/run_llama_train.sh b/run_llama_train.sh index 91133a25..8da1ebda 100755 --- a/run_llama_train.sh +++ b/run_llama_train.sh @@ -11,7 +11,6 @@ set -ex # e.g. # LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh NGPU=${NGPU:-"8"} -NNODES=${NNODES:-"1"} TRAINER_DIR=${TRAINER_DIR:-/home/$USER/local/torchtitan} LOG_RANK=${LOG_RANK:-0} CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"} @@ -21,16 +20,6 @@ if [ $# -ne 0 ]; then overrides="$*" fi -# Check if --estimate.memory=True is in the arguments -if echo "$overrides" | grep -q -- "--memory_estimation.enabled"; then - # Calculate WORLD_SIZE as the product of NGPU and NNODES - # Export WORLD_SIZE and LOCAL_RANK - export WORLD_SIZE=$((NGPU * NNODES)) - export LOCAL_RANK=0 - python estimation.py --job.config_file ${CONFIG_FILE} $overrides -else - # Call train.py if not in estimation mode - torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ - --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ - train.py --job.config_file ${CONFIG_FILE} $overrides -fi +torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ +--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ +train.py --job.config_file ${CONFIG_FILE} $overrides diff --git a/run_memory_estimation.sh b/run_memory_estimation.sh new file mode 100755 index 00000000..f58b089a --- /dev/null +++ b/run_memory_estimation.sh @@ -0,0 +1,27 @@ +#!/usr/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -ex + +# use envs as local overrides for convenience +# e.g. +# NGPU=4 ./run_memory_estimation.sh +NGPU=${NGPU:-"8"} +NNODES=${NNODES:-"1"} +TRAINER_DIR=${TRAINER_DIR:-/home/$USER/local/torchtitan} +CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"} + +overrides="" +if [ $# -ne 0 ]; then + overrides="$*" +fi + +# Calculate WORLD_SIZE as the product of NGPU and NNODES +# Export WORLD_SIZE and LOCAL_RANK +export WORLD_SIZE=$((NGPU * NNODES)) +export LOCAL_RANK=0 +python estimation.py --job.config_file ${CONFIG_FILE} --memory_estimation.enabled $overrides From 8fa11f0f55a6f2d2e5bf460098546e600cc3058e Mon Sep 17 00:00:00 2001 From: Will Constable Date: Sun, 14 Jul 2024 21:17:04 -0700 Subject: [PATCH 06/29] [Cleanup] Remove unused TRAINER_DIR This argument seems to be left over from older times- it is not used anywhere in the codebase. ghstack-source-id: abbcf82ed4d1b8fbb71c6a6b48acbc1296dbec64 Pull Request resolved: https://github.com/pytorch/torchtitan/pull/456 --- create_seed_checkpoint.sh | 1 - run_llama_train.sh | 1 - run_memory_estimation.sh | 1 - 3 files changed, 3 deletions(-) diff --git a/create_seed_checkpoint.sh b/create_seed_checkpoint.sh index 3dfbde71..77185bfc 100755 --- a/create_seed_checkpoint.sh +++ b/create_seed_checkpoint.sh @@ -18,7 +18,6 @@ set -ex -TRAINER_DIR=${1:-/home/$USER/local/torchtitan} NGPU=1 LOG_RANK=0 CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"} diff --git a/run_llama_train.sh b/run_llama_train.sh index 8da1ebda..a4107806 100755 --- a/run_llama_train.sh +++ b/run_llama_train.sh @@ -11,7 +11,6 @@ set -ex # e.g. # LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh NGPU=${NGPU:-"8"} -TRAINER_DIR=${TRAINER_DIR:-/home/$USER/local/torchtitan} LOG_RANK=${LOG_RANK:-0} CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"} diff --git a/run_memory_estimation.sh b/run_memory_estimation.sh index f58b089a..02148b84 100755 --- a/run_memory_estimation.sh +++ b/run_memory_estimation.sh @@ -12,7 +12,6 @@ set -ex # NGPU=4 ./run_memory_estimation.sh NGPU=${NGPU:-"8"} NNODES=${NNODES:-"1"} -TRAINER_DIR=${TRAINER_DIR:-/home/$USER/local/torchtitan} CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"} overrides="" From 174c44a3967790f9fedc3513d67cb662a2407382 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Sun, 14 Jul 2024 21:23:56 -0700 Subject: [PATCH 07/29] Add educational code pointers to top level README ghstack-source-id: 522aa2fa0bf1679f55d9f3a8a38fdcd319d5e3df Pull Request resolved: https://github.com/pytorch/torchtitan/pull/457 --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index 18364d8f..dde75e20 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,14 @@ Our guiding principles when building `torchtitan`: [![Welcome to torchtitan!](assets/images/titan_play_video.png)](https://youtu.be/ee5DOEqD35I?si=_B94PbVv0V5ZnNKE "Welcome to torchtitan!") +### Dive into the code + +You may want to see how the model is defined or how parallelism techniques are applied. For a guided tour, see these files first: +* [train.py](https://github.com/pytorch/torchtitan/blob/main/train.py) - the main training loop and high-level setup code +* [torchtitan/parallelisms/parallelize_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py) - helpers for applying Data / Tensor / Pipeline Parallelisms to the model +* [torchtitan/checkpoint.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/checkpoint.py) - utils for saving/loading distributed checkpoints +* [torchtitan/models/llama/model.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py) - the Llama model definition (shared for Llama2 and Llama3 variants) + ## Pre-Release Updates: #### (4/25/2024): `torchtitan` is now public but in a pre-release state and under development. Currently we showcase pre-training **Llama 3 and Llama 2** LLMs of various sizes from scratch. `torchtitan` is tested and verified with the PyTorch nightly version `torch-2.4.0.dev20240412`. (We recommend latest PyTorch nightly). From a4b2ee3c0dc03648f8d38032e529f199a18f841e Mon Sep 17 00:00:00 2001 From: "Wei (Will) Feng" <134637289+weifengpy@users.noreply.github.com> Date: Tue, 16 Jul 2024 15:46:21 -0700 Subject: [PATCH 08/29] enable FSDP2 + fp8 all-gather and fix TP fp8 all-gather (#413) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit we have landed fp8 all-gather optimizations in float8_experimental https://github.com/pytorch-labs/float8_experimental/pull/266/ this PR proposes torchtitan changes. also include fp8 in CI ``` from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp # inside the training loop model(input).sum().backward() optim.step() precompute_float8_dynamic_scale_for_fsdp(model) ``` FSDP2 fp8 all-gather are added to CI ``` CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp ``` TP fp8 all-gather are locally tested. will add them to CI after uploading a new tokenizer with vacab size 2560 (divisible by 16) ``` CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4 CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 2 --training.tensor_parallel_degree 2 ``` precompute scales after optimizer.step Screenshot 2024-07-12 at 5 11 14 PM FSDP2 pre-all-gather do not have any small all-reduces Screenshot 2024-07-12 at 5 13 04 PM TODO * upload tokenizer with vacab size 2560 to enable CI on TP fp8 all-gather * torch.compile complains about fp8 * add delayed scaling and brainstorm about best config option to express fp8 * compare perf between delayed scaling and dynamic scaling https://github.com/pytorch-labs/float8_experimental/pull/312/files --- estimation.py | 4 +- test_runner.py | 33 ++++++++++++++++ torchtitan/config_manager.py | 14 ++++++- torchtitan/float8_linear.py | 41 ++++++++++++++++---- torchtitan/parallelisms/parallelize_llama.py | 16 +++++++- train.py | 14 ++++++- train_configs/debug_model.toml | 2 +- train_configs/llama2_13b.toml | 2 +- train_configs/llama2_70b.toml | 2 +- train_configs/llama2_7b.toml | 2 +- train_configs/llama3_70b.toml | 2 +- train_configs/llama3_8b.toml | 2 +- 12 files changed, 114 insertions(+), 20 deletions(-) diff --git a/estimation.py b/estimation.py index ddf24d8a..e652c581 100644 --- a/estimation.py +++ b/estimation.py @@ -124,8 +124,8 @@ def loss_fn(pred, labels): whole_model = model_cls.from_model_args(model_config) # apply fp8 linear module swap - if job_config.training.fp8_linear: - build_fp8_linear(whole_model, job_config) + if job_config.training.enable_fp8_linear: + build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled) # apply PT-D DP/TP parallelisms and activation checkpointing model_parts = [whole_model] diff --git a/test_runner.py b/test_runner.py index 319f99d7..f2f80504 100755 --- a/test_runner.py +++ b/test_runner.py @@ -273,6 +273,39 @@ def build_test_list(): "fsdp2_mem_tracker", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--training.enable_fp8_linear", + ] + ], + "FSDP2 with original dtype", + "fp8_fsdp2_orig_all_gather", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--training.enable_fp8_linear", + "--training.enable_fsdp_fp8_all_gather", + ] + ], + "FSDP2 with fp8 all-gather", + "fsdp2_fp8_all_gather", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--training.enable_fp8_linear", + "--training.enable_fsdp_fp8_all_gather", + "--training.precompute_float8_dynamic_scale_for_fsdp", + ] + ], + "FSDP2 with fp8 all-gather and precomputed dynamic scales", + "fsdp2_fp8_all_gather_precompute_dynamic_scales", + ngpu=4, + ), ] return integration_tests_flavors diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 3ade1b9d..0dfe1bb0 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -338,7 +338,7 @@ def __init__(self): help="Whether to compile the model", ) self.parser.add_argument( - "--training.fp8_linear", + "--training.enable_fp8_linear", action="store_true", help=""" If true, swaps `torch.nn.Linear` with `Float8Linear` with @@ -347,6 +347,18 @@ def __init__(self): here: https://github.com/pytorch-labs/float8_experimental """, ) + self.parser.add_argument( + "--training.enable_fsdp_fp8_all_gather", + action="store_true", + default=False, + help="Whether enable fp8 all-gather in FSDP", + ) + self.parser.add_argument( + "--training.precompute_float8_dynamic_scale_for_fsdp", + action="store_true", + default=False, + help="Whether precompute fp8 scales dynamically for FSDP", + ) self.parser.add_argument( "--training.gc_freq", type=int, diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index 0bd0900c..f41a812d 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -12,31 +12,58 @@ # Note: Performance # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance +import contextlib +from typing import Optional +import float8_experimental.config as config + +import torch import torch.nn as nn +from float8_experimental.float8_linear import TensorScalingType from torchtitan.config_manager import JobConfig from torchtitan.logging_utils import logger -def build_fp8_linear(model: nn.Module, job_config: JobConfig): +@contextlib.contextmanager +def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool): + prev = config.enable_fsdp_fp8_all_gather + torch.distributed.barrier() + config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather + try: + yield + finally: + torch.distributed.barrier() + config.enable_fsdp_fp8_all_gather = prev + + +def build_fp8_linear( + model: nn.Module, job_config: JobConfig, dp_enabled: Optional[bool] = False +): """ This function converts the linear layers to `Float8Linear`. Note that today, only dynamic tensor scaling (the default) is supported. This will mutate the model inplace. """ - use_fp8_linear = job_config.training.fp8_linear + enable_fp8_linear = job_config.training.enable_fp8_linear + enable_fsdp_fp8_all_gather = ( + job_config.training.enable_fsdp_fp8_all_gather and dp_enabled + ) try: - from float8_experimental.float8_linear import Float8Linear from float8_experimental.float8_linear_utils import ( swap_linear_with_float8_linear, ) + + # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear + with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather): + swap_linear_with_float8_linear( + model, scaling_type_w=TensorScalingType.DYNAMIC + ) + logger.info( + f"Swapped to Float8Linear layers with {enable_fsdp_fp8_all_gather=}" + ) except ImportError as exc: raise ImportError( "float8_experimental is not installed. Please install it to use fp8 linear layers." ) from exc - if use_fp8_linear: - # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear - swap_linear_with_float8_linear(model, Float8Linear) - logger.info("Swapped to Float8Linear layers") diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 1b414159..b33e8870 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -117,12 +117,24 @@ def selective_checkpointing_context_fn(): def get_tp_parallel_strategy( job_config: JobConfig, + model: nn.Module, ) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]: """Get the parallel strategy for the transformer model. This function handles the special case of using float8 with tensor parallelism. """ - if job_config.training.fp8_linear == "dynamic": + if job_config.training.enable_fp8_linear: + from float8_experimental.float8_linear import Float8Linear, TensorScalingType + + if any( + isinstance(m, Float8Linear) + and m.scaling_type_w is TensorScalingType.DELAYED + for m in model.modules() + ): + raise NotImplementedError( + "1D TP fp8 all-gather only supports dynamic scaling" + ) + from float8_experimental.float8_tensor_parallel import ( Float8ColwiseParallel, Float8RowwiseParallel, @@ -346,7 +358,7 @@ def apply_tp( rowwise_parallel_weight, colwise_parallel_weight, prepare_module_input, - ) = get_tp_parallel_strategy(job_config) + ) = get_tp_parallel_strategy(job_config, model) loss_parallel = parallel_dims.loss_parallel_enabled # 1. Parallelize the embedding and shard its outputs (which are the first diff --git a/train.py b/train.py index 8e55c210..14008525 100644 --- a/train.py +++ b/train.py @@ -19,6 +19,7 @@ import torch import torch.nn.functional as F +from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp from torch.distributed import destroy_process_group from torch.distributed.checkpoint.stateful import Stateful from torch.distributed.elastic.multiprocessing.errors import record @@ -216,8 +217,8 @@ def loss_fn(pred, labels): whole_model = model_cls.from_model_args(model_config) # apply fp8 linear module swap - if job_config.training.fp8_linear: - build_fp8_linear(whole_model, job_config) + if job_config.training.enable_fp8_linear: + build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled) # log model size model_param_count = get_num_params(whole_model) @@ -398,6 +399,15 @@ def loss_fn(pred, labels): optimizers.step() lr_schedulers.step() + if ( + job_config.training.enable_fp8_linear + and job_config.training.enable_fsdp_fp8_all_gather + and job_config.training.precompute_float8_dynamic_scale_for_fsdp + ): + # calculate float8 dynamic amax/scale for all-parameter for FSDP2 + # it issues a single all-reduce for all parameters at once for better performance + precompute_float8_dynamic_scale_for_fsdp(model) + losses_since_last_log.append(loss) # log metrics diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index cb2fb215..6064ced1 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -37,7 +37,7 @@ max_norm = 1.0 # grad norm clipping steps = 10 data_parallel_degree = -1 tensor_parallel_degree = 1 -fp8_linear = false +enable_fp8_linear = false compile = false dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M) diff --git a/train_configs/llama2_13b.toml b/train_configs/llama2_13b.toml index 05e3c27b..f4061ad0 100644 --- a/train_configs/llama2_13b.toml +++ b/train_configs/llama2_13b.toml @@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 1 -fp8_linear = false +enable_fp8_linear = false compile = false dataset = "c4" diff --git a/train_configs/llama2_70b.toml b/train_configs/llama2_70b.toml index 5b2dd493..19e033b8 100644 --- a/train_configs/llama2_70b.toml +++ b/train_configs/llama2_70b.toml @@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 8 # 8-way TP -fp8_linear = false +enable_fp8_linear = false compile = false dataset = "c4" diff --git a/train_configs/llama2_7b.toml b/train_configs/llama2_7b.toml index 9b72246a..95d67667 100644 --- a/train_configs/llama2_7b.toml +++ b/train_configs/llama2_7b.toml @@ -32,7 +32,7 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 1 # dp-only would be sufficient for 7B -fp8_linear = false +enable_fp8_linear = false compile = false dataset = "c4" diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml index 93b529f6..ac6b31c1 100644 --- a/train_configs/llama3_70b.toml +++ b/train_configs/llama3_70b.toml @@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 8 # 8-way TP -fp8_linear = false +enable_fp8_linear = false compile = false dataset = "c4" diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml index 95a53d56..2c3c6e63 100644 --- a/train_configs/llama3_8b.toml +++ b/train_configs/llama3_8b.toml @@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 1 -fp8_linear = false +enable_fp8_linear = false compile = false dataset = "c4" From ae8181b9189d32f3da13700ba73e575df644b44a Mon Sep 17 00:00:00 2001 From: "Wei (Will) Feng" <134637289+weifengpy@users.noreply.github.com> Date: Tue, 16 Jul 2024 17:56:17 -0700 Subject: [PATCH 09/29] import float8_experimental only when fp8 is enabled and install it in CI (#464) make sure to only import float8_experimental when fp8 is enabled for 4 gpu CI, make sure we can import float8_experimental correctly in CI `python -m pip install git+https://github.com/pytorch-labs/float8_experimental.git` --- .github/workflows/integration_test_4gpu.yaml | 1 + torchtitan/float8_linear.py | 6 +++--- train.py | 5 ++++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/.github/workflows/integration_test_4gpu.yaml b/.github/workflows/integration_test_4gpu.yaml index 3816f404..7c913b07 100644 --- a/.github/workflows/integration_test_4gpu.yaml +++ b/.github/workflows/integration_test_4gpu.yaml @@ -39,5 +39,6 @@ jobs: python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/ + python -m pip install git+https://github.com/pytorch-labs/float8_experimental.git mkdir artifacts-to-be-uploaded python ./test_runner.py artifacts-to-be-uploaded --ngpu 4 diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index f41a812d..496b590a 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -15,11 +15,8 @@ import contextlib from typing import Optional -import float8_experimental.config as config - import torch import torch.nn as nn -from float8_experimental.float8_linear import TensorScalingType from torchtitan.config_manager import JobConfig from torchtitan.logging_utils import logger @@ -27,6 +24,8 @@ @contextlib.contextmanager def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool): + import float8_experimental.config as config + prev = config.enable_fsdp_fp8_all_gather torch.distributed.barrier() config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather @@ -51,6 +50,7 @@ def build_fp8_linear( job_config.training.enable_fsdp_fp8_all_gather and dp_enabled ) try: + from float8_experimental.float8_linear import TensorScalingType from float8_experimental.float8_linear_utils import ( swap_linear_with_float8_linear, ) diff --git a/train.py b/train.py index 14008525..2c63e299 100644 --- a/train.py +++ b/train.py @@ -19,7 +19,6 @@ import torch import torch.nn.functional as F -from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp from torch.distributed import destroy_process_group from torch.distributed.checkpoint.stateful import Stateful from torch.distributed.elastic.multiprocessing.errors import record @@ -404,6 +403,10 @@ def loss_fn(pred, labels): and job_config.training.enable_fsdp_fp8_all_gather and job_config.training.precompute_float8_dynamic_scale_for_fsdp ): + from float8_experimental.fsdp_utils import ( + precompute_float8_dynamic_scale_for_fsdp, + ) + # calculate float8 dynamic amax/scale for all-parameter for FSDP2 # it issues a single all-reduce for all parameters at once for better performance precompute_float8_dynamic_scale_for_fsdp(model) From 3760bcf3b7b7f6c9bce74bdd4bc3947f259d9c15 Mon Sep 17 00:00:00 2001 From: "Wei (Will) Feng" <134637289+weifengpy@users.noreply.github.com> Date: Tue, 16 Jul 2024 22:09:58 -0700 Subject: [PATCH 10/29] skip fp8 CI on non-H100 GPUs (#465) skip fp8 tests on non-H100 GPUs by checking `torch.cuda.get_device_capability() >= (9, 0)` this makes 4 GPU CI healthy again --- torchtitan/float8_linear.py | 44 +++++++++++++++++++++++++++++++++---- train.py | 26 +++++++++------------- 2 files changed, 50 insertions(+), 20 deletions(-) diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index 496b590a..9b92400c 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -13,10 +13,12 @@ # Note: Performance # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance import contextlib +import functools from typing import Optional import torch import torch.nn as nn +from torch._logging import warning_once from torchtitan.config_manager import JobConfig from torchtitan.logging_utils import logger @@ -36,7 +38,13 @@ def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool): config.enable_fsdp_fp8_all_gather = prev -def build_fp8_linear( +@functools.lru_cache(None) +def is_sm90_or_later(): + # Float8 is only supported on H100+ GPUs + return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) + + +def maybe_build_fp8_linear( model: nn.Module, job_config: JobConfig, dp_enabled: Optional[bool] = False ): """ @@ -46,9 +54,14 @@ def build_fp8_linear( This will mutate the model inplace. """ enable_fp8_linear = job_config.training.enable_fp8_linear - enable_fsdp_fp8_all_gather = ( - job_config.training.enable_fsdp_fp8_all_gather and dp_enabled - ) + if not enable_fp8_linear: + return + if not is_sm90_or_later(): + warning_once( + logger, + "Failed to swap to Float8Linear because SM90 or later is not available", + ) + return try: from float8_experimental.float8_linear import TensorScalingType from float8_experimental.float8_linear_utils import ( @@ -56,6 +69,9 @@ def build_fp8_linear( ) # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear + enable_fsdp_fp8_all_gather = ( + job_config.training.enable_fsdp_fp8_all_gather and dp_enabled + ) with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather): swap_linear_with_float8_linear( model, scaling_type_w=TensorScalingType.DYNAMIC @@ -67,3 +83,23 @@ def build_fp8_linear( raise ImportError( "float8_experimental is not installed. Please install it to use fp8 linear layers." ) from exc + + +def maybe_precompute_fp8_dynamic_scale_for_fsdp( + model: nn.Module, job_config: JobConfig +): + if not ( + job_config.training.enable_fp8_linear + and job_config.training.enable_fsdp_fp8_all_gather + and job_config.training.precompute_float8_dynamic_scale_for_fsdp + ): + return + if not is_sm90_or_later(): + warning_once( + logger, + "Skipped precomputing fp8 scales because SM90 or later is not available", + ) + return + from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp + + precompute_float8_dynamic_scale_for_fsdp(model) diff --git a/train.py b/train.py index 2c63e299..afd1d888 100644 --- a/train.py +++ b/train.py @@ -27,7 +27,10 @@ from torchtitan.checkpoint import CheckpointManager from torchtitan.config_manager import JobConfig from torchtitan.datasets import build_hf_data_loader, create_tokenizer -from torchtitan.float8_linear import build_fp8_linear +from torchtitan.float8_linear import ( + maybe_build_fp8_linear, + maybe_precompute_fp8_dynamic_scale_for_fsdp, +) from torchtitan.logging_utils import init_logger, logger from torchtitan.lr_scheduling import get_lr_schedulers from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger @@ -215,9 +218,8 @@ def loss_fn(pred, labels): with torch.device("meta"): whole_model = model_cls.from_model_args(model_config) - # apply fp8 linear module swap - if job_config.training.enable_fp8_linear: - build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled) + # swap to Float8Linear base on fp8 config + maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled) # log model size model_param_count = get_num_params(whole_model) @@ -398,18 +400,10 @@ def loss_fn(pred, labels): optimizers.step() lr_schedulers.step() - if ( - job_config.training.enable_fp8_linear - and job_config.training.enable_fsdp_fp8_all_gather - and job_config.training.precompute_float8_dynamic_scale_for_fsdp - ): - from float8_experimental.fsdp_utils import ( - precompute_float8_dynamic_scale_for_fsdp, - ) - - # calculate float8 dynamic amax/scale for all-parameter for FSDP2 - # it issues a single all-reduce for all parameters at once for better performance - precompute_float8_dynamic_scale_for_fsdp(model) + # when fp8 config is on, + # calculate float8 dynamic amax/scale for all-parameter for FSDP2 + # it issues a single all-reduce for all parameters at once for better performance + maybe_precompute_fp8_dynamic_scale_for_fsdp(model, job_config) losses_since_last_log.append(loss) From 69fe8defc248479acc50dc795fb652e2a11e07f1 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Wed, 17 Jul 2024 15:09:35 -0700 Subject: [PATCH 11/29] clean up float8 configs in torchtitan (#466) Summary: 1. standardizes on `float8` instead of `fp8` for config names 2. removes usage of non-public objects such as `Float8Linear` Test Plan: ``` with-proxy NGPU=1 CUDA_VISIBLE_DEVICES=7 CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.compile --training.enable_float8_linear ``` Reviewers: Subscribers: Tasks: Tags: --- test_runner.py | 20 ++++++++++---------- torchtitan/config_manager.py | 8 ++++---- torchtitan/float8_linear.py | 18 +++++++++--------- torchtitan/parallelisms/parallelize_llama.py | 16 ++++------------ train_configs/debug_model.toml | 2 +- train_configs/llama2_13b.toml | 2 +- train_configs/llama2_70b.toml | 2 +- train_configs/llama2_7b.toml | 2 +- train_configs/llama3_70b.toml | 2 +- train_configs/llama3_8b.toml | 2 +- 10 files changed, 33 insertions(+), 41 deletions(-) diff --git a/test_runner.py b/test_runner.py index f2f80504..c84ca6af 100755 --- a/test_runner.py +++ b/test_runner.py @@ -276,34 +276,34 @@ def build_test_list(): OverrideDefinitions( [ [ - "--training.enable_fp8_linear", + "--training.enable_float8_linear", ] ], "FSDP2 with original dtype", - "fp8_fsdp2_orig_all_gather", + "float8_fsdp2_orig_all_gather", ngpu=4, ), OverrideDefinitions( [ [ - "--training.enable_fp8_linear", - "--training.enable_fsdp_fp8_all_gather", + "--training.enable_float8_linear", + "--training.enable_fsdp_float8_all_gather", ] ], - "FSDP2 with fp8 all-gather", - "fsdp2_fp8_all_gather", + "FSDP2 with float8 all-gather", + "fsdp2_float8_all_gather", ngpu=4, ), OverrideDefinitions( [ [ - "--training.enable_fp8_linear", - "--training.enable_fsdp_fp8_all_gather", + "--training.enable_float8_linear", + "--training.enable_fsdp_float8_all_gather", "--training.precompute_float8_dynamic_scale_for_fsdp", ] ], - "FSDP2 with fp8 all-gather and precomputed dynamic scales", - "fsdp2_fp8_all_gather_precompute_dynamic_scales", + "FSDP2 with float8 all-gather and precomputed dynamic scales", + "fsdp2_float8_all_gather_precompute_dynamic_scales", ngpu=4, ), ] diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 0dfe1bb0..2bd6e370 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -338,7 +338,7 @@ def __init__(self): help="Whether to compile the model", ) self.parser.add_argument( - "--training.enable_fp8_linear", + "--training.enable_float8_linear", action="store_true", help=""" If true, swaps `torch.nn.Linear` with `Float8Linear` with @@ -348,16 +348,16 @@ def __init__(self): """, ) self.parser.add_argument( - "--training.enable_fsdp_fp8_all_gather", + "--training.enable_fsdp_float8_all_gather", action="store_true", default=False, - help="Whether enable fp8 all-gather in FSDP", + help="Whether enable float8 all-gather in FSDP", ) self.parser.add_argument( "--training.precompute_float8_dynamic_scale_for_fsdp", action="store_true", default=False, - help="Whether precompute fp8 scales dynamically for FSDP", + help="Whether precompute float8 scales dynamically for FSDP", ) self.parser.add_argument( "--training.gc_freq", diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index 9b92400c..50c971ae 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -25,7 +25,7 @@ @contextlib.contextmanager -def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool): +def set_enable_fsdp_float8_all_gather(enable_fsdp_fp8_all_gather: bool): import float8_experimental.config as config prev = config.enable_fsdp_fp8_all_gather @@ -53,8 +53,8 @@ def maybe_build_fp8_linear( This will mutate the model inplace. """ - enable_fp8_linear = job_config.training.enable_fp8_linear - if not enable_fp8_linear: + enable_float8_linear = job_config.training.enable_float8_linear + if not enable_float8_linear: return if not is_sm90_or_later(): warning_once( @@ -69,15 +69,15 @@ def maybe_build_fp8_linear( ) # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear - enable_fsdp_fp8_all_gather = ( - job_config.training.enable_fsdp_fp8_all_gather and dp_enabled + enable_fsdp_float8_all_gather = ( + job_config.training.enable_fsdp_float8_all_gather and dp_enabled ) - with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather): + with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather): swap_linear_with_float8_linear( model, scaling_type_w=TensorScalingType.DYNAMIC ) logger.info( - f"Swapped to Float8Linear layers with {enable_fsdp_fp8_all_gather=}" + f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}" ) except ImportError as exc: raise ImportError( @@ -89,8 +89,8 @@ def maybe_precompute_fp8_dynamic_scale_for_fsdp( model: nn.Module, job_config: JobConfig ): if not ( - job_config.training.enable_fp8_linear - and job_config.training.enable_fsdp_fp8_all_gather + job_config.training.enable_float8_linear + and job_config.training.enable_fsdp_float8_all_gather and job_config.training.precompute_float8_dynamic_scale_for_fsdp ): return diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index b33e8870..ec0f6763 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -123,18 +123,10 @@ def get_tp_parallel_strategy( This function handles the special case of using float8 with tensor parallelism. """ - if job_config.training.enable_fp8_linear: - from float8_experimental.float8_linear import Float8Linear, TensorScalingType - - if any( - isinstance(m, Float8Linear) - and m.scaling_type_w is TensorScalingType.DELAYED - for m in model.modules() - ): - raise NotImplementedError( - "1D TP fp8 all-gather only supports dynamic scaling" - ) - + if job_config.training.enable_float8_linear: + # TODO(future PR): once float8 configuration supports delayed + # scaling, add a check here to enforce supported float8 all-gather + # configurations from float8_experimental.float8_tensor_parallel import ( Float8ColwiseParallel, Float8RowwiseParallel, diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 6064ced1..7c849976 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -37,7 +37,7 @@ max_norm = 1.0 # grad norm clipping steps = 10 data_parallel_degree = -1 tensor_parallel_degree = 1 -enable_fp8_linear = false +enable_float8_linear = false compile = false dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M) diff --git a/train_configs/llama2_13b.toml b/train_configs/llama2_13b.toml index f4061ad0..2dc29f2e 100644 --- a/train_configs/llama2_13b.toml +++ b/train_configs/llama2_13b.toml @@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 1 -enable_fp8_linear = false +enable_float8_linear = false compile = false dataset = "c4" diff --git a/train_configs/llama2_70b.toml b/train_configs/llama2_70b.toml index 19e033b8..f17496c5 100644 --- a/train_configs/llama2_70b.toml +++ b/train_configs/llama2_70b.toml @@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 8 # 8-way TP -enable_fp8_linear = false +enable_float8_linear = false compile = false dataset = "c4" diff --git a/train_configs/llama2_7b.toml b/train_configs/llama2_7b.toml index 95d67667..69ae7285 100644 --- a/train_configs/llama2_7b.toml +++ b/train_configs/llama2_7b.toml @@ -32,7 +32,7 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 1 # dp-only would be sufficient for 7B -enable_fp8_linear = false +enable_float8_linear = false compile = false dataset = "c4" diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml index ac6b31c1..660f2c0b 100644 --- a/train_configs/llama3_70b.toml +++ b/train_configs/llama3_70b.toml @@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 8 # 8-way TP -enable_fp8_linear = false +enable_float8_linear = false compile = false dataset = "c4" diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml index 2c3c6e63..7e5ac63c 100644 --- a/train_configs/llama3_8b.toml +++ b/train_configs/llama3_8b.toml @@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 1 -enable_fp8_linear = false +enable_float8_linear = false compile = false dataset = "c4" From 2f989b9f36d782f3ae3a6059621d68585a0e68ee Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 18 Jul 2024 09:55:38 -0700 Subject: [PATCH 12/29] Add support of DDP and experimental CompiledAutograd Summary: Address the comments in https://github.com/pytorch/torchtitan/pull/319 and resubmit the PR to fit the current code base. Test Plan: ``` CONFIG_FILE=./train_configs/debug_model.toml ./run_llama_train.sh --comm.train_timeout_seconds=3600 --training.tensor_parallel_degree=1 --training.data_parallel_degree=8 --experimental.data_parallel_type=ddp --training.steps=1000 --metrics.log_freq=10 --profiling.profile_freq=1000 ``` ghstack-source-id: 81dc85d42df13df4ed727bebd825681879af936b Pull Request resolved: https://github.com/pytorch/torchtitan/pull/432 --- estimation.py | 1 + test_runner.py | 9 +++++ torchtitan/config_manager.py | 11 ++++++ torchtitan/parallelisms/__init__.py | 3 ++ torchtitan/parallelisms/parallelize_llama.py | 36 ++++++++++++++++++-- train.py | 27 ++++++++++++--- 6 files changed, 79 insertions(+), 8 deletions(-) diff --git a/estimation.py b/estimation.py index e652c581..3e393399 100644 --- a/estimation.py +++ b/estimation.py @@ -71,6 +71,7 @@ def estimate_memory(job_config: JobConfig): pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, enable_loss_parallel=job_config.training.enable_loss_parallel, + dp_type=job_config.training.data_parallel_type, ) device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") diff --git a/test_runner.py b/test_runner.py index c84ca6af..6a7b6b1a 100755 --- a/test_runner.py +++ b/test_runner.py @@ -304,6 +304,15 @@ def build_test_list(): ], "FSDP2 with float8 all-gather and precomputed dynamic scales", "fsdp2_float8_all_gather_precompute_dynamic_scales", + ), + OverrideDefinitions( + [ + [ + "--training.data_parallel_type ddp", + ] + ], + "DDP", + "ddp", ngpu=4, ), ] diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 2bd6e370..9a086830 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -312,6 +312,17 @@ def __init__(self): The default value will be the number of pipeline stages, if unspecified. """, ) + self.parser.add_argument( + "--training.data_parallel_type", + type=str, + default="fsdp", + help="Data parallelism type. TorchTitan currently supports FSDP and DDP.", + ) + self.parser.add_argument( + "--experimental.enable_compiled_autograd", + action="store_true", + help="Enable CompiledAutograd to compile the backward.", + ) self.parser.add_argument( "--training.mixed_precision_param", type=str, diff --git a/torchtitan/parallelisms/__init__.py b/torchtitan/parallelisms/__init__.py index 7e1b21c7..2fdba316 100644 --- a/torchtitan/parallelisms/__init__.py +++ b/torchtitan/parallelisms/__init__.py @@ -28,8 +28,10 @@ class ParallelDims: pp: int world_size: int enable_loss_parallel: bool + dp_type: str def __post_init__(self): + self.dp_type = self.dp_type.lower() self._validate() def _validate(self): @@ -42,6 +44,7 @@ def _validate(self): assert ( dp * tp * pp == self.world_size ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" + assert self.dp_type in ("fsdp", "ddp") def build_mesh(self, device_type): dims = [] diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index ec0f6763..33b9d6d3 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -16,6 +16,8 @@ from torch.distributed import DeviceMesh from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy + +from torch.distributed._composable.replicate import replicate from torch.distributed._tensor import Replicate, Shard from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper as ptd_checkpoint_wrapper, @@ -453,13 +455,15 @@ def apply_compile(model: nn.Module, job_config: JobConfig): return model -def apply_dp( +def apply_fsdp( model: nn.Module, world_mesh: DeviceMesh, parallel_dims: "ParallelDims", job_config: JobConfig, ): - """Apply data parallelism (FSDP2) to the model.""" + """ + Apply data parallelism to the model. FSDP2 is used here. + """ dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names @@ -492,6 +496,29 @@ def apply_dp( return model +def apply_ddp( + model: nn.Module, + world_mesh: DeviceMesh, + parallel_dims: "ParallelDims", + job_config: JobConfig, +): + if world_mesh.ndim > 1: + raise RuntimeError("DDP has not supported > 1D parallelism.") + + if job_config.training.compile: + if job_config.experimental.enable_compiled_autograd: + torch._dynamo.config.optimize_ddp = ( + "python_reducer_without_compiled_forward" + ) + else: + torch._dynamo.config.optimize_ddp = "ddp_optimizer" + + model = replicate(model, device_mesh=world_mesh, bucket_cap_mb=100) + + logger.info("Applied DDP to the model") + return model + + def parallelize_llama( model: nn.Module, world_mesh: DeviceMesh, @@ -516,6 +543,9 @@ def parallelize_llama( model = apply_compile(model, job_config) if parallel_dims.dp_enabled: - model = apply_dp(model, world_mesh, parallel_dims, job_config) + if parallel_dims.dp_type == "fsdp": + model = apply_fsdp(model, world_mesh, parallel_dims, job_config) + else: + model = apply_ddp(model, world_mesh, parallel_dims, job_config) return model diff --git a/train.py b/train.py index afd1d888..b7eee302 100644 --- a/train.py +++ b/train.py @@ -138,6 +138,22 @@ def zero_grad(self): return OptimizersContainer([_build_optimizer(model) for model in model_parts]) +def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool): + @contextlib.contextmanager + def context(): + with contextlib.ExitStack() as stack: + if enable_loss_parallel: + stack.enter_context(loss_parallel()) + if enable_compiled_autograd: + stack.enter_context( + torch._dynamo.utils.maybe_enable_compiled_autograd(True) + ) + + yield + + return context + + # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html @record def main(job_config: JobConfig): @@ -160,6 +176,7 @@ def main(job_config: JobConfig): pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, enable_loss_parallel=job_config.training.enable_loss_parallel, + dp_type=job_config.training.data_parallel_type, ) device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") torch.cuda.set_device(device) @@ -194,9 +211,9 @@ def main(job_config: JobConfig): dp_rank, ) - # loss_parallel enables dispatching to efficient loss operators - loss_parallel_ctx = ( - loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext + train_context = get_train_context( + parallel_dims.loss_parallel_enabled, + job_config.experimental.enable_compiled_autograd, ) # loss fn can be shared by pipeline-parallel or non-pp execution @@ -364,7 +381,7 @@ def loss_fn(pred, labels): # pipeline parallel forward / backward inside step() call is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1 - with loss_parallel_ctx(): + with train_context(): if pp_mesh.get_local_rank() == 0: pp_schedule.step(input_ids) elif is_last_stage: @@ -381,7 +398,7 @@ def loss_fn(pred, labels): ) else: # Non-PP forward / backward - with loss_parallel_ctx(): + with train_context(): pred = model(input_ids) loss = loss_fn(pred, labels) # pred.shape=(bs, seq_len, vocab_size) From 71b8eaecf7a9f66376fc74b693d09d4d839361c9 Mon Sep 17 00:00:00 2001 From: "Wei (Will) Feng" <134637289+weifengpy@users.noreply.github.com> Date: Thu, 18 Jul 2024 19:17:30 -0700 Subject: [PATCH 13/29] add torch.compile + FSDP2 float8 all-gather in CI (#468) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fixed my bug in float8_experimental. now we can torch.compile transfromer blocks with FSDP float8 all-gather https://github.com/pytorch-labs/float8_experimental/pull/321 local test: `CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.enable_fsdp_float8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp --training.compile` profiler traces: I can see compiled region in cpu thread and float8 malmul `sm90_xmma_gemm_e4m3bf16...` in cuda stream Screenshot 2024-07-18 at 4 22 17 PM --- test_runner.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test_runner.py b/test_runner.py index 6a7b6b1a..b0eb04c9 100755 --- a/test_runner.py +++ b/test_runner.py @@ -305,6 +305,18 @@ def build_test_list(): "FSDP2 with float8 all-gather and precomputed dynamic scales", "fsdp2_float8_all_gather_precompute_dynamic_scales", ), + OverrideDefinitions( + [ + [ + "--training.enable_float8_linear", + "--training.enable_fsdp_float8_all_gather", + "--training.precompute_float8_dynamic_scale_for_fsdp", + "--training.compile", + ] + ], + "FSDP2 with float8 all-gather and precomputed dynamic scales", + "fsdp2_float8_all_gather_precompute_dynamic_scales_compile", + ), OverrideDefinitions( [ [ From 0c6f9a24de0bf1dd9548e3691de7b228179bbdd8 Mon Sep 17 00:00:00 2001 From: "Wei (Will) Feng" <134637289+weifengpy@users.noreply.github.com> Date: Fri, 19 Jul 2024 16:07:12 -0700 Subject: [PATCH 14/29] [float8] keep model.output as `nn.Linear` (high precision, not fp8) (#469) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **keep model.output as nn.Linear**: it's a common practice to NOT apply fp8 on final output layer * specify `skip_fqn_list` in swapping * when applying TP to model.output, use plain `ColwiseParallel` instead of `Float8ColwiseParallel` credit to @awgu, we do not need tokentizer vacab size to be divisible by 16 https://github.com/pytorch/torchtitan/issues/461 1D TP + float8 all-gather, eager mode: `CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4 ./run_llama_train.sh --training.enable_float8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4` 1D TP + float8 all-gather, compile mode: `CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4 ./run_llama_train.sh --training.enable_float8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4 --training.compile` 2D FSDP2 + TP + float8 all-gather, eager mode: `CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4 ./run_llama_train.sh --training.enable_float8_linear --training.enable_fsdp_float8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp --training.tensor_parallel_degree 2` 2D FSDP2 + TP + float8 all-gather, eager mode: `CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4 ./run_llama_train.sh --training.enable_float8_linear --training.enable_fsdp_float8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp --training.tensor_parallel_degree 2 --training.compile` 1D TP + float8 all-gather trace: see float8 and all-gather in the trace Screenshot 2024-07-19 at 1 16 59 PM 2D + float8 all-gather trace: see float8 and FSDP collectives and TP collectives Screenshot 2024-07-19 at 1 29 59 PM --- torchtitan/float8_linear.py | 4 +++- torchtitan/parallelisms/parallelize_llama.py | 19 ++++++++++--------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index 50c971ae..770531d5 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -74,7 +74,9 @@ def maybe_build_fp8_linear( ) with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather): swap_linear_with_float8_linear( - model, scaling_type_w=TensorScalingType.DYNAMIC + model, + scaling_type_w=TensorScalingType.DYNAMIC, + skip_fqn_list=["output"], ) logger.info( f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}" diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 33b9d6d3..634c70a0 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -117,7 +117,7 @@ def selective_checkpointing_context_fn(): return module -def get_tp_parallel_strategy( +def get_tp_parallel_strategy_for_transformer_block( job_config: JobConfig, model: nn.Module, ) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]: @@ -346,13 +346,6 @@ def apply_tp( """Apply tensor parallelism.""" tp_mesh = world_mesh["tp"] - # Parallel styles used for transformer block linear weights and their - # inputs may be different for float8 linears - ( - rowwise_parallel_weight, - colwise_parallel_weight, - prepare_module_input, - ) = get_tp_parallel_strategy(job_config, model) loss_parallel = parallel_dims.loss_parallel_enabled # 1. Parallelize the embedding and shard its outputs (which are the first @@ -368,7 +361,7 @@ def apply_tp( output_layouts=Shard(1), ), "norm": SequenceParallel(), - "output": colwise_parallel_weight( + "output": ColwiseParallel( input_layouts=Shard(1), output_layouts=Shard(-1) if loss_parallel else Replicate(), use_local_output=not loss_parallel, @@ -376,6 +369,14 @@ def apply_tp( }, ) + # Parallel styles used for transformer block linear weights and their + # inputs may be different for float8 linears + ( + rowwise_parallel_weight, + colwise_parallel_weight, + prepare_module_input, + ) = get_tp_parallel_strategy_for_transformer_block(job_config, model) + # Apply tensor + sequence parallelism to every transformer block # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. From 0a17c26d31af264930d25ed32c546f1d87162be9 Mon Sep 17 00:00:00 2001 From: "Wei (Will) Feng" <134637289+weifengpy@users.noreply.github.com> Date: Fri, 19 Jul 2024 17:28:10 -0700 Subject: [PATCH 15/29] remove CI for FSDP2 + fp8 all-gather (#470) per discussion from https://github.com/pytorch/torchtitan/pull/469#issuecomment-2240258083 we are planning BC breaking changes in float8_experimental. remove CI for FSDP2 + fp8 all-gather for now. When public APIs are finalized, we can discuss bringing it back --- test_runner.py | 44 -------------------------------------------- 1 file changed, 44 deletions(-) diff --git a/test_runner.py b/test_runner.py index b0eb04c9..67117bfe 100755 --- a/test_runner.py +++ b/test_runner.py @@ -273,50 +273,6 @@ def build_test_list(): "fsdp2_mem_tracker", ngpu=4, ), - OverrideDefinitions( - [ - [ - "--training.enable_float8_linear", - ] - ], - "FSDP2 with original dtype", - "float8_fsdp2_orig_all_gather", - ngpu=4, - ), - OverrideDefinitions( - [ - [ - "--training.enable_float8_linear", - "--training.enable_fsdp_float8_all_gather", - ] - ], - "FSDP2 with float8 all-gather", - "fsdp2_float8_all_gather", - ngpu=4, - ), - OverrideDefinitions( - [ - [ - "--training.enable_float8_linear", - "--training.enable_fsdp_float8_all_gather", - "--training.precompute_float8_dynamic_scale_for_fsdp", - ] - ], - "FSDP2 with float8 all-gather and precomputed dynamic scales", - "fsdp2_float8_all_gather_precompute_dynamic_scales", - ), - OverrideDefinitions( - [ - [ - "--training.enable_float8_linear", - "--training.enable_fsdp_float8_all_gather", - "--training.precompute_float8_dynamic_scale_for_fsdp", - "--training.compile", - ] - ], - "FSDP2 with float8 all-gather and precomputed dynamic scales", - "fsdp2_float8_all_gather_precompute_dynamic_scales_compile", - ), OverrideDefinitions( [ [ From 0ee573cdd1fc97ebdf0912d6f92d0ed08c86ada3 Mon Sep 17 00:00:00 2001 From: Less Wright Date: Sat, 20 Jul 2024 17:14:52 -0700 Subject: [PATCH 16/29] dynamically update torch.compile cache config to ensure async tp support, enhance async tp UX (#471) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds some enhancements for supporting async tp: 1 - if async tp is active, auto updates the torch.dynamo cache limit to 10K. If this is not updated, async tp will not be activated on larger models as it will quietly stop compilation due to 'cache limit reached' with no info for the user. This config update is logged. 2 - if async tp is enabled, verifies that torch.compile is set to true for this job config. If not, it warns and then activates torch.compile to ensure user gets working async tp. (see WARNING in below screenshot) Screenshot 2024-07-20 at 4 33 04 PM 3 - Updates the 'Applied Tensor Parallel' to the model to be 'Applied Async Tensor Parallel' when async tp is active to make it clear in the logs which TP is active. (see above screenshot) --- torchtitan/parallelisms/parallelize_llama.py | 16 +++++++++++++++- train_configs/debug_model.toml | 1 + 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 634c70a0..31eabc6c 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -413,13 +413,27 @@ def apply_tp( parallelize_plan=layer_plan, ) + # updates expressly for async tensor parallel if job_config.experimental.enable_async_tensor_parallel: from torch.distributed._symmetric_memory import enable_symm_mem_for_group + torch._dynamo.config.cache_size_limit = 10000 + logger.info( + "Updating torch._dynamo.config.cache_size_limit to 10000 to support Async TP" + ) + torch._inductor.config._micro_pipeline_tp = True enable_symm_mem_for_group(tp_mesh.get_group().group_name) - logger.info("Applied Tensor Parallelism to the model") + if not job_config.training.compile: + logger.warning( + "Async TP requires compilation...auto enabling compile = True for this job to resolve." + ) + job_config.training.compile = True + + logger.info( + f"Applied{' Async ' if job_config.experimental.enable_async_tensor_parallel else ' '}Tensor Parallelism to the model" + ) return model diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 7c849976..b36e9d0c 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -43,6 +43,7 @@ dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M) [experimental] pipeline_parallel_degree = 1 +enable_async_tensor_parallel = false [checkpoint] enable_checkpoint = false From 69c9bb27b0779739227f278fd4738697707c2660 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Mon, 15 Jul 2024 12:11:50 -0700 Subject: [PATCH 17/29] Fix 8gpu PP failure due to 2D DCP disablement DCP recently added safeties to avoid using it for 2D/3D since strided sharding (a feature needed for safe 2D/3D resharding) is not ready yet. PP uses DCP to load a seed checkpoint. Disabling the safety mechanism is enough to make 3D/PP still work (for the case where we train from the beginning or do not re-shard. (Resharding refers to saving a checkpoint from one world size/parallelism config and loading/resuming under a different one). ghstack-source-id: c069d2186c79517c72f5b3c99485cebdc15df08f Pull Request resolved: https://github.com/pytorch/torchtitan/pull/460 --- torchtitan/parallelisms/parallelize_llama.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 31eabc6c..3d123953 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -507,6 +507,17 @@ def apply_fsdp( model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled ) + if parallel_dims.pp_enabled: + # TODO + # This PR https://github.com/pytorch/pytorch/pull/129519 added a safety check to avoid using 2D/3D DCP since + # without strided sharding, DCP can not safely support resharding for 2D/3D. However, for PP to work, even + # without resharding, we load a seed-checkpoint and need to disable the safety mechanism. This hack should be + # removed after strided sharding is landed in DCP. + for module in model.modules(): + assert len(module._load_state_dict_pre_hooks) <= 1 + module._load_state_dict_pre_hooks.clear() + assert len(module._state_dict_pre_hooks) <= 1 + module._state_dict_pre_hooks.clear() logger.info("Applied FSDP to the model") return model From 90e2070349fd602e43c1f81a7fc03be6fd230c8c Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Fri, 26 Jul 2024 13:48:29 -0700 Subject: [PATCH 18/29] update float8 integration after UX changes (#484) Summary: float8_experimental landed various BC-breaking UX changes last week. This PR updates torchtitan to work with the version of float8_experimental after https://github.com/pytorch-labs/float8_experimental/pull/332 and https://github.com/pytorch-labs/float8_experimental/pull/337 Test Plan: ``` with-proxy CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 NGPU=8 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile ``` Reviewers: Subscribers: Tasks: Tags: --- torchtitan/float8_linear.py | 40 ++++++++++++++----------------------- 1 file changed, 15 insertions(+), 25 deletions(-) diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index 770531d5..557fca64 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -12,7 +12,6 @@ # Note: Performance # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance -import contextlib import functools from typing import Optional @@ -24,20 +23,6 @@ from torchtitan.logging_utils import logger -@contextlib.contextmanager -def set_enable_fsdp_float8_all_gather(enable_fsdp_fp8_all_gather: bool): - import float8_experimental.config as config - - prev = config.enable_fsdp_fp8_all_gather - torch.distributed.barrier() - config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather - try: - yield - finally: - torch.distributed.barrier() - config.enable_fsdp_fp8_all_gather = prev - - @functools.lru_cache(None) def is_sm90_or_later(): # Float8 is only supported on H100+ GPUs @@ -63,21 +48,26 @@ def maybe_build_fp8_linear( ) return try: - from float8_experimental.float8_linear import TensorScalingType - from float8_experimental.float8_linear_utils import ( - swap_linear_with_float8_linear, + from float8_experimental import ( + CastConfig, + convert_to_float8_training, + Float8LinearConfig, + ScalingType, ) # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear enable_fsdp_float8_all_gather = ( job_config.training.enable_fsdp_float8_all_gather and dp_enabled ) - with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather): - swap_linear_with_float8_linear( - model, - scaling_type_w=TensorScalingType.DYNAMIC, - skip_fqn_list=["output"], - ) + float8_config = Float8LinearConfig( + enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, + cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC), + ) + convert_to_float8_training( + model, + config=float8_config, + module_filter_fn=lambda mod, fqn: fqn != "output", + ) logger.info( f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}" ) @@ -102,6 +92,6 @@ def maybe_precompute_fp8_dynamic_scale_for_fsdp( "Skipped precomputing fp8 scales because SM90 or later is not available", ) return - from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp + from float8_experimental import precompute_float8_dynamic_scale_for_fsdp precompute_float8_dynamic_scale_for_fsdp(model) From 42f4ff5a54f2c6f06984e57fd7b34527e8cd943f Mon Sep 17 00:00:00 2001 From: Sanket Jayant Purandare Date: Fri, 26 Jul 2024 14:38:19 -0700 Subject: [PATCH 19/29] Re-enable FSDP2 Mem Tracker integration tests ghstack-source-id: 8344603f7a5596cb2909c9bf04dd1b9e4730c9b8 Pull Request resolved: https://github.com/pytorch/torchtitan/pull/485 --- estimation.py | 28 +++++++++++++++++----------- test_runner.py | 2 ++ 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/estimation.py b/estimation.py index 3e393399..f5527f74 100644 --- a/estimation.py +++ b/estimation.py @@ -14,17 +14,19 @@ from torch._subclasses.fake_tensor import FakeTensorMode from torch.distributed import destroy_process_group from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker -from torch.distributed.tensor.parallel import loss_parallel from torch.testing._internal.distributed.fake_pg import FakeStore from torchtitan.config_manager import JobConfig from torchtitan.datasets import create_tokenizer -from torchtitan.float8_linear import build_fp8_linear +from torchtitan.float8_linear import ( + maybe_build_fp8_linear, + maybe_precompute_fp8_dynamic_scale_for_fsdp, +) from torchtitan.logging_utils import init_logger, logger from torchtitan.lr_scheduling import get_lr_schedulers from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config from torchtitan.parallelisms import models_parallelize_fns, ParallelDims -from train import build_optimizers +from train import build_optimizers, get_train_context def estimate_memory(job_config: JobConfig): @@ -61,9 +63,10 @@ def estimate_memory(job_config: JobConfig): logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.") job_config.model.norm_type = "rmsnorm" - if job_config.training.compile: + if job_config.training.compile or job_config.experimental.enable_compiled_autograd: logger.info("Compile mode is not supported yet. Switching to eager mode.") job_config.training.compile = False + job_config.experimental.enable_compiled_autograd = False parallel_dims = ParallelDims( dp=job_config.training.data_parallel_degree, @@ -96,9 +99,9 @@ def estimate_memory(job_config: JobConfig): tokenizer_type = model_name_to_tokenizer[model_name] tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path) - # loss_parallel enables dispatching to efficient loss operators - loss_parallel_ctx = ( - loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext + train_context = get_train_context( + parallel_dims.loss_parallel_enabled, + job_config.experimental.enable_compiled_autograd, ) # loss fn can be shared by pipeline-parallel or non-pp execution @@ -124,9 +127,8 @@ def loss_fn(pred, labels): with torch.device("meta"): whole_model = model_cls.from_model_args(model_config) - # apply fp8 linear module swap - if job_config.training.enable_fp8_linear: - build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled) + # swap to Float8Linear base on fp8 config + maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled) # apply PT-D DP/TP parallelisms and activation checkpointing model_parts = [whole_model] @@ -171,7 +173,7 @@ def loss_fn(pred, labels): for iter_idx in range(2): input_ids, labels = batch # train step - with loss_parallel_ctx(): + with train_context(): pred = whole_model(input_ids) loss = loss_fn(pred, labels) del pred @@ -185,6 +187,10 @@ def loss_fn(pred, labels): # optimizer step optimizers.step() lr_schedulers.step() + # when fp8 config is on, + # calculate float8 dynamic amax/scale for all-parameter for FSDP2 + # it issues a single all-reduce for all parameters at once for better performance + maybe_precompute_fp8_dynamic_scale_for_fsdp(whole_model, job_config) optimizers.zero_grad() print(f"Peak Memory at iter: {iter_idx}") fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True) diff --git a/test_runner.py b/test_runner.py index 67117bfe..a8df397c 100755 --- a/test_runner.py +++ b/test_runner.py @@ -314,6 +314,8 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str): for override_arg in test_flavor.override_args: cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_llama_train.sh" + if test_name == "fsdp2_mem_tracker": + cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_memory_estimation.sh" cmd += " " + dump_folder_arg cmd += " " + model_flavor_arg if override_arg: From a48de09a29326e958a8c21dd466c73a8b15a65b6 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Mon, 29 Jul 2024 07:35:53 -0700 Subject: [PATCH 20/29] Used `partial` instead of global vars for LR scheduling ghstack-source-id: 12c4418b0574d93e1441f4ca3d1de79c8aad7a40 Pull Request resolved: https://github.com/pytorch/torchtitan/pull/487 --- torchtitan/lr_scheduling.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/torchtitan/lr_scheduling.py b/torchtitan/lr_scheduling.py index 35f39e13..9f766268 100644 --- a/torchtitan/lr_scheduling.py +++ b/torchtitan/lr_scheduling.py @@ -4,31 +4,30 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import functools + from torch.optim.lr_scheduler import LambdaLR from torchtitan.config_manager import JobConfig -# global states for scheduling -# these are needed as LambdaLR does not support argument passing -_warmup_steps = 200 -_decay_steps = 0 - -def linear_warmup_linear_decay(current_step: int) -> float: +def linear_warmup_linear_decay( + warmup_steps: int, decay_steps: int, current_step: int +) -> float: """Computes linear warmup followed by linear decay. Per LambdaLR requirement, this is accomplished by returning a multiplicative factor to adjust the learning rate to create the desired schedule. """ - if current_step < _warmup_steps: + if current_step < warmup_steps: # linear warmup # 0-indexed step, hence + 1 adjustments current_step += 1 - curr_adjustment = float(current_step / (_warmup_steps + 1)) + curr_adjustment = float(current_step / (warmup_steps + 1)) else: # linear decay - normalized_step = _decay_steps - (current_step - _warmup_steps) - curr_adjustment = 1 - (_decay_steps - normalized_step) / _decay_steps + normalized_step = decay_steps - (current_step - warmup_steps) + curr_adjustment = 1 - (decay_steps - normalized_step) / decay_steps return curr_adjustment @@ -36,11 +35,12 @@ def linear_warmup_linear_decay(current_step: int) -> float: def get_lr_schedulers(optimizers, job_config: JobConfig): def _get_lr_scheduler(optimizer): """Build a linear warmup and linear decay scheduler""" - global _warmup_steps, _decay_steps - _warmup_steps = int(job_config.training.warmup_steps) - _decay_steps = float(max(1, job_config.training.steps - _warmup_steps)) - - warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay) + warmup_steps = int(job_config.training.warmup_steps) + decay_steps = float(max(1, job_config.training.steps - warmup_steps)) + lr_lambda = functools.partial( + linear_warmup_linear_decay, warmup_steps, decay_steps + ) + warmup_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) return warmup_scheduler class SchedulersContainer: From b63e20988a75a8d16c376b9a04de897e3edecbf7 Mon Sep 17 00:00:00 2001 From: Hugo <6937752+fduwjj@users.noreply.github.com> Date: Mon, 29 Jul 2024 19:38:01 -0700 Subject: [PATCH 21/29] =?UTF-8?q?[EZ]=20Add=20logs=20for=20some=20basic=20?= =?UTF-8?q?training=20params=20so=20that=20we=20can=20verify=20in=E2=80=A6?= =?UTF-8?q?=20(#491)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit As title, while testing on 405B model, I found that we need to somehow need the logs for some training params. So added some here. Tested locally and the logging is shown as in the screenshot: image --- train.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index b7eee302..38983972 100644 --- a/train.py +++ b/train.py @@ -355,7 +355,12 @@ def loss_fn(pred, labels): gpu_memory_monitor.reset_peak_stats() # train loop - logger.info(f"Training starts at step {train_state.step + 1}") + logger.info( + f"Training starts at step {train_state.step + 1}, " + f"with local batch size: {job_config.training.batch_size}, " + f"sequence length: {job_config.training.seq_len}, " + f"total steps: {job_config.training.steps}({job_config.training.warmup_steps}), " + ) with maybe_enable_profiling( job_config, global_step=train_state.step ) as torch_profiler, maybe_enable_memory_snapshot( From 91f075ae2a38862fb4691ee56e0c89692d145f66 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Tue, 30 Jul 2024 09:20:14 -0700 Subject: [PATCH 22/29] make float8 scaling type configurable (#489) Summary: Adds config options to configure float8 scaling type for input, weight, grad_output. Performance is not ideal yet, but that's because we have not optimized it. Test Plan: ``` // repeat for input, weight, grad_out with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.float8_scaling_type_weight delayed --training.compile ``` Reviewers: Subscribers: Tasks: Tags: --- torchtitan/config_manager.py | 22 ++++++++++++++++-- torchtitan/float8_linear.py | 43 +++++++++++++++++++++++++++++++++++- train.py | 6 ++++- 3 files changed, 67 insertions(+), 4 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 9a086830..f0a69b4d 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -352,8 +352,7 @@ def __init__(self): "--training.enable_float8_linear", action="store_true", help=""" - If true, swaps `torch.nn.Linear` with `Float8Linear` with - default settings (dynamic scaling). + If true, swaps `torch.nn.Linear` with `Float8Linear`. This feature requires you to install 'float8_experimental' which can be found here: https://github.com/pytorch-labs/float8_experimental """, @@ -370,6 +369,25 @@ def __init__(self): default=False, help="Whether precompute float8 scales dynamically for FSDP", ) + self.parser.add_argument( + "--training.float8_scaling_type_input", + type=str, + default="dynamic", + help="float8 scaling for input, dynamic (default) or delayed", + choices=["dynamic", "delayed"], + ) + self.parser.add_argument( + "--training.float8_scaling_type_weight", + type=str, + default="dynamic", + help="float8 scaling for input, dynamic (default) or delayed", + ) + self.parser.add_argument( + "--training.float8_scaling_type_grad_output", + type=str, + default="dynamic", + help="float8 scaling for input, dynamic (default) or delayed", + ) self.parser.add_argument( "--training.gc_freq", type=int, diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index 557fca64..1651585e 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -59,9 +59,19 @@ def maybe_build_fp8_linear( enable_fsdp_float8_all_gather = ( job_config.training.enable_fsdp_float8_all_gather and dp_enabled ) + scaling_type_input = ScalingType(job_config.training.float8_scaling_type_input) + scaling_type_weight = ScalingType( + job_config.training.float8_scaling_type_weight + ) + scaling_type_grad_output = ScalingType( + job_config.training.float8_scaling_type_grad_output + ) float8_config = Float8LinearConfig( enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, - cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC), + cast_config_input=CastConfig(scaling_type=scaling_type_input), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), + enable_pre_and_post_forward=False, ) convert_to_float8_training( model, @@ -95,3 +105,34 @@ def maybe_precompute_fp8_dynamic_scale_for_fsdp( from float8_experimental import precompute_float8_dynamic_scale_for_fsdp precompute_float8_dynamic_scale_for_fsdp(model) + + +_sync_float8_amax_and_scale_history = None + + +def maybe_sync_float8_amax_and_scale_history(model: nn.Module, job_config: JobConfig): + if not ( + job_config.training.enable_float8_linear + and ( + job_config.training.float8_scaling_type_input == "delayed" + or job_config.training.float8_scaling_type_weight == "delayed" + or job_config.training.float8_scaling_type_grad_output == "delayed" + ) + ): + return + + from float8_experimental import sync_float8_amax_and_scale_history + + # TODO(future): see if precalculating the modules to sync over is going to + # meaningfully help performance + + global _sync_float8_amax_and_scale_history + if _sync_float8_amax_and_scale_history is None: + if job_config.training.compile: + _sync_float8_amax_and_scale_history = torch.compile( + sync_float8_amax_and_scale_history + ) + else: + _sync_float8_amax_and_scale_history = sync_float8_amax_and_scale_history + + sync_float8_amax_and_scale_history(model) diff --git a/train.py b/train.py index 38983972..5a637f46 100644 --- a/train.py +++ b/train.py @@ -30,6 +30,7 @@ from torchtitan.float8_linear import ( maybe_build_fp8_linear, maybe_precompute_fp8_dynamic_scale_for_fsdp, + maybe_sync_float8_amax_and_scale_history, ) from torchtitan.logging_utils import init_logger, logger from torchtitan.lr_scheduling import get_lr_schedulers @@ -417,12 +418,15 @@ def loss_fn(pred, labels): model.parameters(), job_config.training.max_norm, foreach=True ) + # if float8 is enabled, sync float8 amaxes and scales + maybe_sync_float8_amax_and_scale_history(model, job_config) + # optimizer step checkpoint.wait_for_staging() optimizers.step() lr_schedulers.step() - # when fp8 config is on, + # when float8 config is on, # calculate float8 dynamic amax/scale for all-parameter for FSDP2 # it issues a single all-reduce for all parameters at once for better performance maybe_precompute_fp8_dynamic_scale_for_fsdp(model, job_config) From 9358d7086ed7e67b1b6ef3e7cbea794577820988 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Tue, 30 Jul 2024 14:40:01 -0400 Subject: [PATCH 23/29] [PP] add flexible interleaved 1f1b schedule #490 (#493) This was approved in https://github.com/pytorch/torchtitan/pull/490, but merged into the wrong branch, merging this into main --- test_runner.py | 15 +++++++++++++++ torchtitan/config_manager.py | 2 +- torchtitan/parallelisms/pipelining_utils.py | 7 +++++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/test_runner.py b/test_runner.py index a8df397c..a1d3bf22 100755 --- a/test_runner.py +++ b/test_runner.py @@ -46,6 +46,21 @@ def build_test_list(): """ integration_tests_flavors = defaultdict(list) integration_tests_flavors["debug_model.toml"] = [ + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + "--experimental.pipeline_parallel_degree 4", + "--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7", + "--experimental.pipeline_parallel_schedule flexible_interleaved_1f1b", + "--model.norm_type rmsnorm", # fused_rmsnorm throws cuda context error with pp + ], + ], + "PP looped flexible 1f1b test", + "pp_looped_flexible_1f1b", + requires_seed_checkpoint=True, + ngpu=4, + ), OverrideDefinitions( [ [ diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index f0a69b4d..26570ec7 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -275,7 +275,7 @@ def __init__(self): self.parser.add_argument( "--experimental.pipeline_parallel_schedule", type=str, - choices=["1f1b", "gpipe", "interleaved_1f1b"], + choices=["1f1b", "gpipe", "interleaved_1f1b", "flexible_interleaved_1f1b"], default="1f1b", help=""" Specify the Pipeline Parallel schedule to use. diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipelining_utils.py index e60b7f51..adf9eb09 100644 --- a/torchtitan/parallelisms/pipelining_utils.py +++ b/torchtitan/parallelisms/pipelining_utils.py @@ -7,6 +7,7 @@ from torch.distributed.pipelining import ( Schedule1F1B, + ScheduleFlexibleInterleaved1F1B, ScheduleGPipe, ScheduleInterleaved1F1B, ) @@ -23,6 +24,12 @@ def build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn): elif job_config.experimental.pipeline_parallel_schedule == "interleaved_1f1b": schedule_class = ScheduleInterleaved1F1B looped_schedule = True + elif ( + job_config.experimental.pipeline_parallel_schedule + == "flexible_interleaved_1f1b" + ): + schedule_class = ScheduleFlexibleInterleaved1F1B + looped_schedule = True else: raise NotImplementedError( f"{job_config.experimental.pipeline_parallel_schedule} is not implemented" From 239d56fc1468106aa411f542ec809e4198fdc5c6 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Tue, 30 Jul 2024 11:57:37 -0700 Subject: [PATCH 24/29] move float8 callsites to torchao.float8 (#492) Summary: The `float8_experimental` repository moved to `torchao.float8` in https://github.com/pytorch/ao/pull/551 This PR updates `torchtitan` to use float8 from the new location. Test Plan: ``` with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile ``` Reviewers: Subscribers: Tasks: Tags: --- .github/workflows/integration_test_4gpu.yaml | 2 +- torchtitan/config_manager.py | 4 ++-- torchtitan/float8_linear.py | 16 ++++++++-------- torchtitan/parallelisms/parallelize_llama.py | 4 +++- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/.github/workflows/integration_test_4gpu.yaml b/.github/workflows/integration_test_4gpu.yaml index 7c913b07..813e11af 100644 --- a/.github/workflows/integration_test_4gpu.yaml +++ b/.github/workflows/integration_test_4gpu.yaml @@ -39,6 +39,6 @@ jobs: python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/ - python -m pip install git+https://github.com/pytorch-labs/float8_experimental.git + USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git mkdir artifacts-to-be-uploaded python ./test_runner.py artifacts-to-be-uploaded --ngpu 4 diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 26570ec7..33070120 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -353,8 +353,8 @@ def __init__(self): action="store_true", help=""" If true, swaps `torch.nn.Linear` with `Float8Linear`. - This feature requires you to install 'float8_experimental' which can be found - here: https://github.com/pytorch-labs/float8_experimental + This feature requires you to install 'torchao' which can be found + here: https://github.com/pytorch/ao """, ) self.parser.add_argument( diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index 1651585e..658a41cc 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -4,11 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# [Note] Getting the 'float8_experimental' package: -# This script requires the 'float8_experimental' package to function correctly. +# [Note] Getting the 'torchao' package: +# This script requires the 'torchao' package to function correctly. # Please ensure you have this package installed from the appropriate repository. -# You can obtain it from https://github.com/pytorch-labs/float8_experimental. -# Either clone and run `pip install .` or run `pip install git+https://github.com/pytorch-labs/float8_experimental.git` +# You can obtain it from https://github.com/pytorch/ao by following the +# installation instructions. # Note: Performance # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance @@ -48,7 +48,7 @@ def maybe_build_fp8_linear( ) return try: - from float8_experimental import ( + from torchao.float8 import ( CastConfig, convert_to_float8_training, Float8LinearConfig, @@ -83,7 +83,7 @@ def maybe_build_fp8_linear( ) except ImportError as exc: raise ImportError( - "float8_experimental is not installed. Please install it to use fp8 linear layers." + "torchao is not installed. Please install it to use fp8 linear layers." ) from exc @@ -102,7 +102,7 @@ def maybe_precompute_fp8_dynamic_scale_for_fsdp( "Skipped precomputing fp8 scales because SM90 or later is not available", ) return - from float8_experimental import precompute_float8_dynamic_scale_for_fsdp + from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp precompute_float8_dynamic_scale_for_fsdp(model) @@ -121,7 +121,7 @@ def maybe_sync_float8_amax_and_scale_history(model: nn.Module, job_config: JobCo ): return - from float8_experimental import sync_float8_amax_and_scale_history + from torchao.float8 import sync_float8_amax_and_scale_history # TODO(future): see if precalculating the modules to sync over is going to # meaningfully help performance diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 3d123953..e3c6fc80 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -129,7 +129,9 @@ def get_tp_parallel_strategy_for_transformer_block( # TODO(future PR): once float8 configuration supports delayed # scaling, add a check here to enforce supported float8 all-gather # configurations - from float8_experimental.float8_tensor_parallel import ( + # TODO(future PR): add the items below to __init__.py of torchao.float8, + # and import from there + from torchao.float8.float8_tensor_parallel import ( Float8ColwiseParallel, Float8RowwiseParallel, PrepareFloat8ModuleInput, From 3c77e9fa2fb384cabcf0bf71571caaaea6cf90e7 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Tue, 30 Jul 2024 23:52:12 -0700 Subject: [PATCH 25/29] [BE][1/n] simplify train.py ghstack-source-id: 3879e764e7b33afde5d778810c71d1d2a8f82f6d Pull Request resolved: https://github.com/pytorch/torchtitan/pull/494 --- estimation.py | 20 +- test/datasets/test_checkpoint.py | 4 +- torchtitan/checkpoint.py | 45 +++- torchtitan/config_manager.py | 2 +- torchtitan/datasets/__init__.py | 4 +- torchtitan/datasets/hf_datasets.py | 2 +- torchtitan/datasets/tokenizer/__init__.py | 4 +- .../datasets/tokenizer/sentencepiece.py | 2 +- torchtitan/datasets/tokenizer/tiktoken.py | 2 +- torchtitan/float8_linear.py | 2 +- torchtitan/{logging_utils.py => logging.py} | 0 torchtitan/metrics.py | 32 ++- torchtitan/models/llama/model.py | 8 +- torchtitan/models/norms.py | 8 +- torchtitan/{lr_scheduling.py => optimizer.py} | 53 ++++- torchtitan/parallelisms/__init__.py | 11 +- torchtitan/parallelisms/parallelize_llama.py | 2 +- torchtitan/parallelisms/pipelining_utils.py | 4 +- torchtitan/profiling.py | 2 +- torchtitan/utils.py | 39 ++- train.py | 225 +++++------------- 21 files changed, 231 insertions(+), 240 deletions(-) rename torchtitan/{logging_utils.py => logging.py} (100%) rename torchtitan/{lr_scheduling.py => optimizer.py} (51%) diff --git a/estimation.py b/estimation.py index f5527f74..3adcf663 100644 --- a/estimation.py +++ b/estimation.py @@ -9,24 +9,22 @@ import os import torch -import torch.nn.functional as F from torch._guards import active_fake_mode from torch._subclasses.fake_tensor import FakeTensorMode -from torch.distributed import destroy_process_group from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker from torch.testing._internal.distributed.fake_pg import FakeStore from torchtitan.config_manager import JobConfig -from torchtitan.datasets import create_tokenizer +from torchtitan.datasets import build_tokenizer from torchtitan.float8_linear import ( maybe_build_fp8_linear, maybe_precompute_fp8_dynamic_scale_for_fsdp, ) -from torchtitan.logging_utils import init_logger, logger -from torchtitan.lr_scheduling import get_lr_schedulers +from torchtitan.logging import init_logger, logger from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config +from torchtitan.optimizer import build_lr_schedulers, build_optimizers from torchtitan.parallelisms import models_parallelize_fns, ParallelDims -from train import build_optimizers, get_train_context +from train import get_train_context def estimate_memory(job_config: JobConfig): @@ -97,7 +95,7 @@ def estimate_memory(job_config: JobConfig): # build tokenizer tokenizer_type = model_name_to_tokenizer[model_name] - tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path) + tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path) train_context = get_train_context( parallel_dims.loss_parallel_enabled, @@ -106,7 +104,9 @@ def estimate_memory(job_config: JobConfig): # loss fn can be shared by pipeline-parallel or non-pp execution def loss_fn(pred, labels): - return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1)) + return torch.nn.functional.cross_entropy( + pred.flatten(0, 1), labels.flatten(0, 1) + ) # build model (using meta init) model_cls = model_name_to_cls[model_name] @@ -146,7 +146,7 @@ def loss_fn(pred, labels): # build optimizer after applying parallelisms to the model optimizers = build_optimizers(model_parts, job_config) - lr_schedulers = get_lr_schedulers(optimizers.optimizers, job_config) + lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config) for model in model_parts: model.train() @@ -224,4 +224,4 @@ def loss_fn(pred, labels): try: estimate_memory(config) finally: - destroy_process_group() + torch.distributed.destroy_process_group() diff --git a/test/datasets/test_checkpoint.py b/test/datasets/test_checkpoint.py index 6f04dd23..741c997f 100644 --- a/test/datasets/test_checkpoint.py +++ b/test/datasets/test_checkpoint.py @@ -6,7 +6,7 @@ import torch from torchtitan.datasets.hf_datasets import build_hf_data_loader -from torchtitan.datasets.tokenizer import create_tokenizer +from torchtitan.datasets.tokenizer import build_tokenizer class TestCheckpoint: @@ -42,7 +42,7 @@ def _build_dataloader( self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank ): tokenizer_type = "tiktoken" - tokenizer = create_tokenizer("tiktoken", "./test/assets/test_tiktoken.model") + tokenizer = build_tokenizer("tiktoken", "./test/assets/test_tiktoken.model") return build_hf_data_loader( dataset_name=dataset_name, dataset_path=dataset_path, diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 30317e3c..b71419c6 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -10,6 +10,8 @@ import re import shutil import time +from dataclasses import dataclass, field +from io import BytesIO from multiprocessing import get_context from typing import Any, Dict, List, Union @@ -27,7 +29,7 @@ from torch.distributed.checkpoint.stateful import Stateful from torch.utils.data import DataLoader from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP -from torchtitan.logging_utils import init_logger, logger +from torchtitan.logging import init_logger, logger class IntervalType(enum.Enum): @@ -41,6 +43,43 @@ class AsyncMode(str, enum.Enum): ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem" +@dataclass +class TrainState(Stateful): + step: int = 0 + global_avg_losses: List[float] = field(default_factory=list) + global_max_losses: List[float] = field(default_factory=list) + log_steps: List[int] = field(default_factory=list) + + def state_dict(self) -> Dict[str, Any]: + # Only checkpoint global_avg_losses and global_max_losses per log frequency + # to avoid sync overhead in every iteration. + global_avg_losses_bytes = BytesIO() + torch.save(self.global_avg_losses, global_avg_losses_bytes) + global_max_losses_bytes = BytesIO() + torch.save(self.global_max_losses, global_max_losses_bytes) + log_steps_bytes = BytesIO() + torch.save(self.log_steps, log_steps_bytes) + return { + "step": torch.tensor(self.step, dtype=torch.int32), + "global_avg_losses": global_avg_losses_bytes, + "global_max_losses": global_max_losses_bytes, + "log_steps": log_steps_bytes, + } + + def load_state_dict(self, state_dict) -> None: + self.step = state_dict["step"].item() + state_dict["global_avg_losses"].seek(0) + self.global_avg_losses = torch.load( + state_dict["global_avg_losses"], weights_only=False + ) + state_dict["global_max_losses"].seek(0) + self.global_max_losses = torch.load( + state_dict["global_max_losses"], weights_only=False + ) + state_dict["log_steps"].seek(0) + self.log_steps = torch.load(state_dict["log_steps"], weights_only=False) + + class ModelWrapper(Stateful): def __init__(self, model: Union[nn.Module, List[nn.Module]]) -> None: self.model = [model] if isinstance(model, nn.Module) else model @@ -124,10 +163,10 @@ def checkpoint_mp(recv, send): class CheckpointManager: def __init__( self, + dataloader: DataLoader, model_parts: List[nn.Module], optimizers: List[torch.optim.Optimizer], lr_schedulers: List[torch.optim.lr_scheduler.LRScheduler], - dataloader: DataLoader, states: Dict[str, Any], job_config: JobConfig, ) -> None: @@ -390,7 +429,7 @@ def save(self, curr_step: int, force: bool = False) -> None: f"in {time.monotonic() - begin:.2f} seconds." ) - def wait_for_staging(self) -> None: + def maybe_wait_for_staging(self) -> None: if ( self.enable_checkpoint and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 33070120..dd5ba7cd 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -16,7 +16,7 @@ except ModuleNotFoundError: import tomli as tomllib -from torchtitan.logging_utils import logger +from torchtitan.logging import logger TORCH_DTYPE_MAP = { "float16": torch.float16, diff --git a/torchtitan/datasets/__init__.py b/torchtitan/datasets/__init__.py index e9a149c6..75ea6b66 100644 --- a/torchtitan/datasets/__init__.py +++ b/torchtitan/datasets/__init__.py @@ -5,9 +5,9 @@ # LICENSE file in the root directory of this source tree. from torchtitan.datasets.hf_datasets import build_hf_data_loader -from torchtitan.datasets.tokenizer import create_tokenizer +from torchtitan.datasets.tokenizer import build_tokenizer __all__ = [ "build_hf_data_loader", - "create_tokenizer", + "build_tokenizer", ] diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index d8cd5d83..0b894e24 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -20,7 +20,7 @@ ) from e from torchtitan.datasets.tokenizer import Tokenizer -from torchtitan.logging_utils import logger +from torchtitan.logging import logger from datasets import load_dataset from datasets.distributed import split_dataset_by_node diff --git a/torchtitan/datasets/tokenizer/__init__.py b/torchtitan/datasets/tokenizer/__init__.py index 346caf83..7ff74722 100644 --- a/torchtitan/datasets/tokenizer/__init__.py +++ b/torchtitan/datasets/tokenizer/__init__.py @@ -8,10 +8,10 @@ from torchtitan.datasets.tokenizer.tiktoken import TikTokenizer from torchtitan.datasets.tokenizer.tokenizer import Tokenizer -from torchtitan.logging_utils import logger +from torchtitan.logging import logger -def create_tokenizer(tokenizer_type: str, tokenizer_path: str) -> Tokenizer: +def build_tokenizer(tokenizer_type: str, tokenizer_path: str) -> Tokenizer: logger.info(f"Building {tokenizer_type} tokenizer locally from {tokenizer_path}") if tokenizer_type == "sentencepiece": return SentencePieceTokenizer(tokenizer_path) diff --git a/torchtitan/datasets/tokenizer/sentencepiece.py b/torchtitan/datasets/tokenizer/sentencepiece.py index 7229daa3..c71afddd 100644 --- a/torchtitan/datasets/tokenizer/sentencepiece.py +++ b/torchtitan/datasets/tokenizer/sentencepiece.py @@ -11,7 +11,7 @@ from sentencepiece import SentencePieceProcessor from torchtitan.datasets.tokenizer.tokenizer import Tokenizer -from torchtitan.logging_utils import logger +from torchtitan.logging import logger class SentencePieceTokenizer(Tokenizer): diff --git a/torchtitan/datasets/tokenizer/tiktoken.py b/torchtitan/datasets/tokenizer/tiktoken.py index 1ec5de20..c879e7f3 100644 --- a/torchtitan/datasets/tokenizer/tiktoken.py +++ b/torchtitan/datasets/tokenizer/tiktoken.py @@ -26,7 +26,7 @@ from tiktoken.load import load_tiktoken_bpe from torchtitan.datasets.tokenizer.tokenizer import Tokenizer -from torchtitan.logging_utils import logger +from torchtitan.logging import logger class TikTokenizer(Tokenizer): diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index 658a41cc..fa311061 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -20,7 +20,7 @@ from torch._logging import warning_once from torchtitan.config_manager import JobConfig -from torchtitan.logging_utils import logger +from torchtitan.logging import logger @functools.lru_cache(None) diff --git a/torchtitan/logging_utils.py b/torchtitan/logging.py similarity index 100% rename from torchtitan/logging_utils.py rename to torchtitan/logging.py diff --git a/torchtitan/metrics.py b/torchtitan/metrics.py index 1717439b..f86ccc98 100644 --- a/torchtitan/metrics.py +++ b/torchtitan/metrics.py @@ -12,7 +12,8 @@ import torch from torch.utils.tensorboard import SummaryWriter from torchtitan.config_manager import JobConfig -from torchtitan.logging_utils import logger +from torchtitan.logging import logger +from torchtitan.parallelisms import ParallelDims # named tuple for passing GPU memory stats for logging GPUMemStats = namedtuple( @@ -110,16 +111,29 @@ def close(self): self.writer.close() +def _get_metrics_rank(parallel_dims: ParallelDims) -> int: + """ + Returns global rank 0 in non-pipeline-parallel configs, and returns the global + rank of the 0th rank in the last pipeline stage when pipeline parallelism is enabled. + """ + if parallel_dims.pp_enabled: + world_size = parallel_dims.world_size + pp_size = parallel_dims.pp + metrics_log_rank = (world_size // pp_size) * (pp_size - 1) + else: + metrics_log_rank = 0 + + return metrics_log_rank + + def build_metric_logger( - config: JobConfig, metrics_log_rank: int = 0, tag: Optional[str] = None + config: JobConfig, parallel_dims: ParallelDims, tag: Optional[str] = None ): """ - metrics_log_rank controls which rank acts as 'rank 0' for logging metrics. - - If 'tb_config.rank_0_only' is set, then `metrics_log_rank` will be used as the rank to log metrics. - This is intended to allow logging from the 0th rank within the last pipeline stage group, in case pipeline - parallelism is enabled, without forcing logging from all ranks to capture loss information when using pipeline - parallelism. + parallel_dims is used to determine the rank to log metrics from if 'tb_config.rank_0_only=True'. + In that case, `_get_metrics_rank` will be used to calculate which rank acts as 'rank 0'. This is + intended to allow logging from the 0th rank within the last pipeline stage group, in case pipeline + parallelism is enabled, without forcing logging from all ranks to capture loss information. """ dump_dir = config.job.dump_folder tb_config = config.metrics @@ -134,7 +148,7 @@ def build_metric_logger( f"Metrics logging active. Tensorboard logs will be saved at {log_dir}" ) if tb_config.rank_0_only: - enable_tb = torch.distributed.get_rank() == metrics_log_rank + enable_tb = torch.distributed.get_rank() == _get_metrics_rank(parallel_dims) else: rank_str = f"rank_{torch.distributed.get_rank()}" log_dir = os.path.join(log_dir, rank_str) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 49cda624..e47d0fb9 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -14,7 +14,7 @@ import torch import torch.nn.functional as F from torch import nn -from torchtitan.models.norms import create_norm +from torchtitan.models.norms import build_norm @dataclass @@ -291,10 +291,10 @@ def __init__(self, layer_id: int, model_args: ModelArgs): self.layer_id = layer_id self.num_layers = model_args.n_layers - self.attention_norm = create_norm( + self.attention_norm = build_norm( model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps ) - self.ffn_norm = create_norm( + self.ffn_norm = build_norm( model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps ) @@ -370,7 +370,7 @@ def __init__(self, model_args: ModelArgs): for layer_id in range(model_args.n_layers): self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) - self.norm = create_norm( + self.norm = build_norm( model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps ) diff --git a/torchtitan/models/norms.py b/torchtitan/models/norms.py index 10a6b853..c0ef6a80 100644 --- a/torchtitan/models/norms.py +++ b/torchtitan/models/norms.py @@ -18,18 +18,18 @@ from torch.distributed._tensor.experimental import local_map -def create_norm(norm_type: str, dim: int, eps: float = 1e-6): +def build_norm(norm_type: str, dim: int, eps: float = 1e-6): """ - Creates the specified normalization layer based on the norm_type. + Builds the specified normalization layer based on the norm_type. Args: - norm_type (str): The type of normalization layer to create. + norm_type (str): The type of normalization layer to build. Supported types: 1. rmsnorm 2. fused_rmsnorm 3. layernorm 4. np_layernorm dim (int): The dimension of the normalization layer. eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6. Returns: - The created normalization layer. + The built normalization layer. Raises: NotImplementedError: If an unknown norm_type is provided. diff --git a/torchtitan/lr_scheduling.py b/torchtitan/optimizer.py similarity index 51% rename from torchtitan/lr_scheduling.py rename to torchtitan/optimizer.py index 9f766268..3f9eb3a8 100644 --- a/torchtitan/lr_scheduling.py +++ b/torchtitan/optimizer.py @@ -6,10 +6,57 @@ import functools +import torch from torch.optim.lr_scheduler import LambdaLR from torchtitan.config_manager import JobConfig +# consider split between PP and non-PP +def build_optimizers(model_parts, job_config: JobConfig): + """Wrap one optimizer per model part in an OptimizersContainer which provides a single + step() and zero_grad() method for all the child optimizers. + """ + + def _build_optimizer(model): + name = job_config.optimizer.name + lr = job_config.optimizer.lr + fused = job_config.optimizer.fused + + # Common parameters for both optimizers + optimizer_kwargs = { + "lr": lr, + "betas": (0.9, 0.95), + "weight_decay": 0.1, + "fused": fused, + "foreach": not fused, + } + if name == "Adam": + # TODO: make the optimizer options configurable by toml/cmd args + optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs) + elif name == "AdamW": + optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs) + else: + raise NotImplementedError(f"Optimizer {name} not added.") + + return optimizer + + class OptimizersContainer: + """Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages""" + + def __init__(self, optimizers): + self.optimizers = optimizers + + def step(self): + for optimizer in self.optimizers: + optimizer.step() + + def zero_grad(self): + for optimizer in self.optimizers: + optimizer.zero_grad() + + return OptimizersContainer([_build_optimizer(model) for model in model_parts]) + + def linear_warmup_linear_decay( warmup_steps: int, decay_steps: int, current_step: int ) -> float: @@ -32,8 +79,8 @@ def linear_warmup_linear_decay( return curr_adjustment -def get_lr_schedulers(optimizers, job_config: JobConfig): - def _get_lr_scheduler(optimizer): +def build_lr_schedulers(optimizers, job_config: JobConfig): + def _build_lr_scheduler(optimizer): """Build a linear warmup and linear decay scheduler""" warmup_steps = int(job_config.training.warmup_steps) decay_steps = float(max(1, job_config.training.steps - warmup_steps)) @@ -54,5 +101,5 @@ def step(self): schedulers.step() return SchedulersContainer( - [_get_lr_scheduler(optimizer) for optimizer in optimizers] + [_build_lr_scheduler(optimizer) for optimizer in optimizers] ) diff --git a/torchtitan/parallelisms/__init__.py b/torchtitan/parallelisms/__init__.py index 2fdba316..7188474d 100644 --- a/torchtitan/parallelisms/__init__.py +++ b/torchtitan/parallelisms/__init__.py @@ -8,8 +8,17 @@ from functools import cached_property from torch.distributed.device_mesh import init_device_mesh -from torchtitan.logging_utils import logger +from torchtitan.logging import logger from torchtitan.parallelisms.parallelize_llama import parallelize_llama, pipeline_llama +from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule + + +__all__ = [ + "build_pipeline_schedule", + "models_parallelize_fns", + "models_pipelining_fns", + "ParallelDims", +] models_parallelize_fns = { "llama2": parallelize_llama, diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index e3c6fc80..11a8188f 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -32,7 +32,7 @@ ) from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP -from torchtitan.logging_utils import logger +from torchtitan.logging import logger from torchtitan.models.llama.model import ModelArgs from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipelining_utils.py index adf9eb09..aafe70fa 100644 --- a/torchtitan/parallelisms/pipelining_utils.py +++ b/torchtitan/parallelisms/pipelining_utils.py @@ -11,12 +11,12 @@ ScheduleGPipe, ScheduleInterleaved1F1B, ) -from torchtitan.logging_utils import logger +from torchtitan.logging import logger def build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn): - looped_schedule = False + if job_config.experimental.pipeline_parallel_schedule == "1f1b": schedule_class = Schedule1F1B elif job_config.experimental.pipeline_parallel_schedule == "gpipe": diff --git a/torchtitan/profiling.py b/torchtitan/profiling.py index 662b64f8..9da5c8fb 100644 --- a/torchtitan/profiling.py +++ b/torchtitan/profiling.py @@ -11,7 +11,7 @@ import torch from torchtitan.config_manager import JobConfig -from torchtitan.logging_utils import logger +from torchtitan.logging import logger # the number of warmup steps before the active step in each profiling cycle WARMUP = 3 diff --git a/torchtitan/utils.py b/torchtitan/utils.py index c2983660..3ed74d13 100644 --- a/torchtitan/utils.py +++ b/torchtitan/utils.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import gc import os from dataclasses import dataclass from datetime import timedelta @@ -13,18 +14,17 @@ import torch.distributed._functional_collectives as funcol import torch.distributed.distributed_c10d as c10d from torch.distributed.device_mesh import DeviceMesh -from torchtitan.logging_utils import logger -from torchtitan.parallelisms import ParallelDims +from torchtitan.logging import logger def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float: tensor = torch.tensor(x).cuda() - return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh) + return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh).item() def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float: tensor = torch.tensor(x).cuda() - return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh) + return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh).item() def _warn_overwrite_env(env, val): @@ -35,24 +35,6 @@ def _warn_overwrite_env(env, val): os.environ[env] = val -def get_metrics_rank(world_mesh: DeviceMesh, parallel_dims: ParallelDims) -> int: - """ - Returns global rank 0 in non-pipeline-parallel configs, and returns the global - rank of the 0th rank in the last pipeline stage when pipeline parallelism is enabled. - """ - if parallel_dims.pp_enabled: - assert ( - world_mesh.mesh_dim_names[0] == "pp" - ), "get_metrics_rank assumes pp is the outer mesh dim" - pp_mesh = world_mesh["pp"] - pp_size = pp_mesh.size() - metrics_log_rank = int((world_mesh.size() // pp_size) * (pp_size - 1)) - else: - metrics_log_rank = 0 - - return metrics_log_rank - - def set_pg_timeouts(timeout, world_mesh): """ Sets the timeout for all PGs in the provided mesh, and the default (world) group. @@ -80,6 +62,19 @@ def set_pg_timeouts(timeout, world_mesh): torch.distributed.distributed_c10d._set_pg_timeout(timeout, group) +# used to avoid stragglers in garbage collection +class GarbageCollection: + def __init__(self, gc_freq=1000): + assert gc_freq > 0, "gc_freq must be a positive integer" + self.gc_freq = gc_freq + gc.disable() + gc.collect(1) + + def run(self, step_count): + if step_count > 1 and step_count % self.gc_freq == 0: + gc.collect(1) + + TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE" TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE" DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT" diff --git a/train.py b/train.py index 5a637f46..eef7401d 100644 --- a/train.py +++ b/train.py @@ -5,138 +5,32 @@ # LICENSE file in the root directory of this source tree. import contextlib -import gc import os import time - -from dataclasses import dataclass, field from datetime import timedelta -from io import BytesIO -from timeit import default_timer as timer -from typing import Any, Dict, List - -import numpy as np import torch -import torch.nn.functional as F -from torch.distributed import destroy_process_group -from torch.distributed.checkpoint.stateful import Stateful +import torchtitan.utils as utils from torch.distributed.elastic.multiprocessing.errors import record -from torch.distributed.tensor.parallel import loss_parallel - -from torchtitan.checkpoint import CheckpointManager +from torchtitan.checkpoint import CheckpointManager, TrainState from torchtitan.config_manager import JobConfig -from torchtitan.datasets import build_hf_data_loader, create_tokenizer +from torchtitan.datasets import build_hf_data_loader, build_tokenizer from torchtitan.float8_linear import ( maybe_build_fp8_linear, maybe_precompute_fp8_dynamic_scale_for_fsdp, maybe_sync_float8_amax_and_scale_history, ) -from torchtitan.logging_utils import init_logger, logger -from torchtitan.lr_scheduling import get_lr_schedulers +from torchtitan.logging import init_logger, logger from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config +from torchtitan.optimizer import build_lr_schedulers, build_optimizers from torchtitan.parallelisms import ( + build_pipeline_schedule, models_parallelize_fns, models_pipelining_fns, ParallelDims, ) -from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling -from torchtitan.utils import ( - Color, - dist_max, - dist_mean, - get_metrics_rank, - get_num_flop_per_token, - get_num_params, - get_peak_flops, - init_distributed, - NoColor, - set_pg_timeouts, -) - - -@dataclass -class TrainState(Stateful): - step: int = 0 - global_avg_losses: List[float] = field(default_factory=list) - global_max_losses: List[float] = field(default_factory=list) - log_steps: List[int] = field(default_factory=list) - - def state_dict(self) -> Dict[str, Any]: - # Only checkpoint global_avg_losses and global_max_losses per log frequency - # to avoid sync overhead in every iteration. - global_avg_losses_bytes = BytesIO() - torch.save(self.global_avg_losses, global_avg_losses_bytes) - global_max_losses_bytes = BytesIO() - torch.save(self.global_max_losses, global_max_losses_bytes) - log_steps_bytes = BytesIO() - torch.save(self.log_steps, log_steps_bytes) - return { - "step": torch.tensor(self.step, dtype=torch.int32), - "global_avg_losses": global_avg_losses_bytes, - "global_max_losses": global_max_losses_bytes, - "log_steps": log_steps_bytes, - } - - def load_state_dict(self, state_dict) -> None: - self.step = state_dict["step"].item() - state_dict["global_avg_losses"].seek(0) - self.global_avg_losses = torch.load( - state_dict["global_avg_losses"], weights_only=False - ) - state_dict["global_max_losses"].seek(0) - self.global_max_losses = torch.load( - state_dict["global_max_losses"], weights_only=False - ) - state_dict["log_steps"].seek(0) - self.log_steps = torch.load(state_dict["log_steps"], weights_only=False) - - -def build_optimizers(model_parts, job_config: JobConfig): - """Wrap one optimizer per model part in an OptimizersContainer which provides a single - step() and zero_grad() method for all the child optimizers. - """ - - def _build_optimizer(model): - name = job_config.optimizer.name - lr = job_config.optimizer.lr - fused = job_config.optimizer.fused - - # Common parameters for both optimizers - optimizer_kwargs = { - "lr": lr, - "betas": (0.9, 0.95), - "weight_decay": 0.1, - "fused": fused, - "foreach": not fused, - } - if name == "Adam": - # TODO: make the optimizer options configurable by toml/cmd args - optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs) - elif name == "AdamW": - optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs) - else: - raise NotImplementedError(f"Optimizer {name} not added.") - - return optimizer - - class OptimizersContainer: - """Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages""" - - def __init__(self, optimizers): - self.optimizers = optimizers - - def step(self): - for optimizer in self.optimizers: - optimizer.step() - - def zero_grad(self): - for optimizer in self.optimizers: - optimizer.zero_grad() - - return OptimizersContainer([_build_optimizer(model) for model in model_parts]) def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool): @@ -144,12 +38,11 @@ def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool def context(): with contextlib.ExitStack() as stack: if enable_loss_parallel: - stack.enter_context(loss_parallel()) + stack.enter_context(torch.distributed.tensor.parallel.loss_parallel()) if enable_compiled_autograd: stack.enter_context( torch._dynamo.utils.maybe_enable_compiled_autograd(True) ) - yield return context @@ -162,12 +55,10 @@ def main(job_config: JobConfig): logger.info(f"Starting job: {job_config.job.description}") # used for colorful printing - color = Color if job_config.metrics.enable_color_printing else NoColor + color = utils.Color if job_config.metrics.enable_color_printing else utils.NoColor # take control of garbage collection to avoid stragglers - _gc_freq = job_config.training.gc_freq - gc.disable() - gc.collect(1) + gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq) # init distributed world_size = int(os.environ["WORLD_SIZE"]) @@ -181,14 +72,16 @@ def main(job_config: JobConfig): ) device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") torch.cuda.set_device(device) - init_distributed(job_config) + utils.init_distributed(job_config) + # initialize GPU memory monitor and get peak flops for MFU calculation + gpu_memory_monitor = build_gpu_memory_monitor() + gpu_peak_flops = utils.get_peak_flops(gpu_memory_monitor.device_name) # build meshes world_mesh = parallel_dims.build_mesh(device_type="cuda") if parallel_dims.dp_enabled: dp_mesh = world_mesh["dp"] - dp_degree = dp_mesh.size() - dp_rank = dp_mesh.get_local_rank() + dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() else: dp_degree, dp_rank = 1, 0 @@ -199,7 +92,7 @@ def main(job_config: JobConfig): # build tokenizer tokenizer_type = model_name_to_tokenizer[model_name] - tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path) + tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path) # build dataloader data_loader = build_hf_data_loader( @@ -212,15 +105,6 @@ def main(job_config: JobConfig): dp_rank, ) - train_context = get_train_context( - parallel_dims.loss_parallel_enabled, - job_config.experimental.enable_compiled_autograd, - ) - - # loss fn can be shared by pipeline-parallel or non-pp execution - def loss_fn(pred, labels): - return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1)) - # build model (using meta init) model_cls = model_name_to_cls[model_name] model_config = models_config[model_name][job_config.model.flavor] @@ -240,9 +124,9 @@ def loss_fn(pred, labels): maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled) # log model size - model_param_count = get_num_params(whole_model) - num_flop_per_token = get_num_flop_per_token( - get_num_params(whole_model, exclude_embedding=True), + model_param_count = utils.get_num_params(whole_model) + num_flop_per_token = utils.get_num_flop_per_token( + utils.get_num_params(whole_model, exclude_embedding=True), model_config, job_config.training.seq_len, ) @@ -251,11 +135,6 @@ def loss_fn(pred, labels): f"{color.red}size: {model_param_count:,} total parameters{color.reset}" ) - # initialize GPU memory monitor before applying parallelisms to the model - gpu_memory_monitor = build_gpu_memory_monitor() - # obtain the peak flops of bf16 type for MFU calculation - gpu_peak_flops = get_peak_flops(gpu_memory_monitor.device_name) - if parallel_dims.pp_enabled: stages, model_parts = models_pipelining_fns[model_name]( whole_model, world_mesh, parallel_dims, job_config, device, model_config @@ -276,6 +155,12 @@ def loss_fn(pred, labels): for model in model_parts: model.to_empty(device=init_device) + # loss fn can be shared by pipeline-parallel or non-pp execution + def loss_fn(pred, labels): + return torch.nn.functional.cross_entropy( + pred.flatten(0, 1), labels.flatten(0, 1) + ) + if parallel_dims.pp_enabled: pp_schedule = build_pipeline_schedule( job_config, parallel_dims, stages, loss_fn @@ -295,11 +180,7 @@ def loss_fn(pred, labels): # build optimizer after applying parallelisms to the model optimizers = build_optimizers(model_parts, job_config) - lr_schedulers = get_lr_schedulers(optimizers.optimizers, job_config) - - metric_logger = build_metric_logger( - job_config, metrics_log_rank=get_metrics_rank(world_mesh, parallel_dims) - ) + lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config) train_state = TrainState() @@ -309,10 +190,10 @@ def loss_fn(pred, labels): # load initial checkpoint checkpoint = CheckpointManager( + dataloader=data_loader, model_parts=model_parts, optimizers=optimizers.optimizers, lr_schedulers=lr_schedulers.schedulers, - dataloader=data_loader, states={"train_state": train_state}, job_config=job_config, ) @@ -333,6 +214,8 @@ def loss_fn(pred, labels): "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`" ) + metric_logger = build_metric_logger(job_config, parallel_dims) + # plot losses loaded from checkpoint (if any) to TensorBoard # NOTE: Loss info after the last log step before checkpoint saving will not be ploted. # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq @@ -346,21 +229,28 @@ def loss_fn(pred, labels): data_iterator = iter(data_loader) - checkpoint.reset() + train_context = get_train_context( + parallel_dims.loss_parallel_enabled, + job_config.experimental.enable_compiled_autograd, + ) # variables used to keep info for metrics logging - losses_since_last_log: List[float] = [] + losses_since_last_log = [] ntokens_since_last_log = 0 - data_loading_times: List[float] = [] - time_last_log = timer() + data_loading_times = [] + time_last_log = time.perf_counter() gpu_memory_monitor.reset_peak_stats() + checkpoint.reset() + # train loop logger.info( f"Training starts at step {train_state.step + 1}, " - f"with local batch size: {job_config.training.batch_size}, " - f"sequence length: {job_config.training.seq_len}, " - f"total steps: {job_config.training.steps}({job_config.training.warmup_steps}), " + f"with local batch size {job_config.training.batch_size}, " + f"global batch size {job_config.training.batch_size * dp_degree}, " + f"sequence length {job_config.training.seq_len}, " + f"total steps {job_config.training.steps} " + f"(warmup {job_config.training.warmup_steps})" ) with maybe_enable_profiling( job_config, global_step=train_state.step @@ -369,15 +259,14 @@ def loss_fn(pred, labels): ) as memory_profiler: while train_state.step < job_config.training.steps: train_state.step += 1 - if train_state.step > 1 and train_state.step % _gc_freq == 0: - gc.collect(1) + gc_handler.run(train_state.step) # get batch - data_load_start = timer() + data_load_start = time.perf_counter() batch = next(data_iterator) input_ids, labels = batch ntokens_since_last_log += labels.numel() - data_loading_times.append(timer() - data_load_start) + data_loading_times.append(time.perf_counter() - data_load_start) input_ids = input_ids.cuda() labels = labels.cuda() @@ -422,7 +311,7 @@ def loss_fn(pred, labels): maybe_sync_float8_amax_and_scale_history(model, job_config) # optimizer step - checkpoint.wait_for_staging() + checkpoint.maybe_wait_for_staging() optimizers.step() lr_schedulers.step() @@ -439,23 +328,21 @@ def loss_fn(pred, labels): or train_state.step % job_config.metrics.log_freq == 0 ): losses = [loss.item() for loss in losses_since_last_log] - avg_loss, max_loss = ( - np.mean(losses), - np.max(losses), - ) + avg_loss, max_loss = sum(losses) / len(losses), max(losses) if parallel_dims.dp_enabled: global_avg_loss, global_max_loss = ( - dist_mean(avg_loss, dp_mesh).item(), - dist_max(max_loss, dp_mesh).item(), + utils.dist_mean(avg_loss, dp_mesh), + utils.dist_max(max_loss, dp_mesh), ) else: global_avg_loss, global_max_loss = avg_loss, max_loss + # update train state train_state.log_steps.append(train_state.step) train_state.global_avg_losses.append(global_avg_loss) train_state.global_max_losses.append(global_max_loss) - time_delta = timer() - time_last_log + time_delta = time.perf_counter() - time_last_log # tokens per second, abbr. as wps by convention wps = ntokens_since_last_log / ( @@ -467,8 +354,8 @@ def loss_fn(pred, labels): mfu = 100 * num_flop_per_token * wps / gpu_peak_flops time_end_to_end = time_delta / job_config.metrics.log_freq - time_data_loading = np.mean(data_loading_times) - time_data_loading_pct = 100 * np.sum(data_loading_times) / time_delta + time_data_loading = sum(data_loading_times) / len(data_loading_times) + time_data_loading_pct = 100 * sum(data_loading_times) / time_delta gpu_mem_stats = gpu_memory_monitor.get_peak_stats() @@ -501,7 +388,7 @@ def loss_fn(pred, labels): losses_since_last_log.clear() ntokens_since_last_log = 0 data_loading_times.clear() - time_last_log = timer() + time_last_log = time.perf_counter() gpu_memory_monitor.reset_peak_stats() checkpoint.save( @@ -517,7 +404,7 @@ def loss_fn(pred, labels): # Reduce timeout after first train step for faster signal (assumes lazy init, compile are finished) if train_state.step == 1: - set_pg_timeouts( + utils.set_pg_timeouts( timeout=timedelta(seconds=job_config.comm.train_timeout_seconds), world_mesh=world_mesh, ) @@ -534,4 +421,4 @@ def loss_fn(pred, labels): config = JobConfig() config.parse_args() main(config) - destroy_process_group() + torch.distributed.destroy_process_group() From bf907104facee024ac176ed3da0459f9b304ffe5 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Tue, 30 Jul 2024 23:52:15 -0700 Subject: [PATCH 26/29] [BE][2/n] use proper method signatures in parallelize_llama ghstack-source-id: 17a1ee9f03f13423a30183c5c8d7ad30f8c8dbfc Pull Request resolved: https://github.com/pytorch/torchtitan/pull/495 --- torchtitan/parallelisms/parallelize_llama.py | 129 ++++++++++--------- train.py | 2 +- 2 files changed, 67 insertions(+), 64 deletions(-) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 11a8188f..e86f93b9 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -118,18 +118,17 @@ def selective_checkpointing_context_fn(): def get_tp_parallel_strategy_for_transformer_block( - job_config: JobConfig, - model: nn.Module, + enable_float8: bool, ) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]: """Get the parallel strategy for the transformer model. This function handles the special case of using float8 with tensor parallelism. """ - if job_config.training.enable_float8_linear: - # TODO(future PR): once float8 configuration supports delayed + if enable_float8: + # TODO(vkuzo): once float8 configuration supports delayed # scaling, add a check here to enforce supported float8 all-gather # configurations - # TODO(future PR): add the items below to __init__.py of torchao.float8, + # TODO(vkuzo): add the items below to __init__.py of torchao.float8, # and import from there from torchao.float8.float8_tensor_parallel import ( Float8ColwiseParallel, @@ -143,7 +142,7 @@ def get_tp_parallel_strategy_for_transformer_block( def pipeline_llama( model: nn.Module, - world_mesh: DeviceMesh, + pp_mesh: DeviceMesh, parallel_dims: "ParallelDims", job_config: JobConfig, device: DeviceType, @@ -157,11 +156,11 @@ def pipeline_llama( ) if split_mode == "manual": return pipeline_llama_manual( - model, world_mesh, parallel_dims, job_config, device, model_config + model, pp_mesh, parallel_dims, job_config, device, model_config ) elif split_mode == "tracer": return pipeline_llama_tracer( - model, world_mesh, parallel_dims, job_config, device, model_config + model, pp_mesh, parallel_dims, job_config, device, model_config ) @@ -184,7 +183,7 @@ def _mixed_precision_dtype( def pipeline_llama_manual( whole_model: nn.Module, - world_mesh: DeviceMesh, + pp_mesh: DeviceMesh, parallel_dims: "ParallelDims", job_config: JobConfig, device: DeviceType, @@ -198,7 +197,6 @@ def pipeline_llama_manual( The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD parallelism. """ - pp_mesh = world_mesh["pp"] pp_rank = pp_mesh.get_local_rank() pp_size = pp_mesh.size() microbatches = ( @@ -287,7 +285,7 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal def pipeline_llama_tracer( model: nn.Module, - world_mesh: DeviceMesh, + pp_mesh: DeviceMesh, parallel_dims: "ParallelDims", job_config: JobConfig, device: DeviceType, @@ -306,7 +304,6 @@ def pipeline_llama_tracer( "To work around, set mixed_precision_param to float32." ) - pp_mesh = world_mesh["pp"] pp_rank = pp_mesh.get_local_rank() pp_size = pp_mesh.size() microbatches = ( @@ -341,15 +338,12 @@ def pipeline_llama_tracer( def apply_tp( model: nn.Module, - world_mesh: DeviceMesh, - parallel_dims: "ParallelDims", - job_config: JobConfig, + tp_mesh: DeviceMesh, + loss_parallel: bool, + enable_float8: bool, + enable_async_tp: bool, ): """Apply tensor parallelism.""" - - tp_mesh = world_mesh["tp"] - loss_parallel = parallel_dims.loss_parallel_enabled - # 1. Parallelize the embedding and shard its outputs (which are the first # transformer block's inputs) # 2. Parallelize the root norm layer over the sequence dim @@ -377,7 +371,7 @@ def apply_tp( rowwise_parallel_weight, colwise_parallel_weight, prepare_module_input, - ) = get_tp_parallel_strategy_for_transformer_block(job_config, model) + ) = get_tp_parallel_strategy_for_transformer_block(enable_float8) # Apply tensor + sequence parallelism to every transformer block # NOTE: At the cost of model code change, we can accelerate Sequence Parallel @@ -416,7 +410,7 @@ def apply_tp( ) # updates expressly for async tensor parallel - if job_config.experimental.enable_async_tensor_parallel: + if enable_async_tp: from torch.distributed._symmetric_memory import enable_symm_mem_for_group torch._dynamo.config.cache_size_limit = 10000 @@ -434,16 +428,14 @@ def apply_tp( job_config.training.compile = True logger.info( - f"Applied{' Async ' if job_config.experimental.enable_async_tensor_parallel else ' '}Tensor Parallelism to the model" + f"Applied {'Async ' if enable_async_tp else ''}" + "Tensor Parallelism to the model" ) return model -def apply_ac(model: nn.Module, job_config: JobConfig): +def apply_ac(model: nn.Module, ac_config: JobConfig): """Apply activation checkpointing to the model.""" - - ac_config = job_config.activation_checkpoint - for layer_id, transformer_block in model.layers.named_children(): transformer_block = checkpoint_wrapper(transformer_block, ac_config) model.layers.register_module(layer_id, transformer_block) @@ -452,14 +444,8 @@ def apply_ac(model: nn.Module, job_config: JobConfig): return model -def apply_compile(model: nn.Module, job_config: JobConfig): +def apply_compile(model: nn.Module): """Apply torch.compile to each transformer block.""" - - if job_config.model.norm_type == "fused_rmsnorm": - raise NotImplementedError( - "fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm." - ) - for layer_id, transformer_block in model.layers.named_children(): # TODO: dynamic shape have some issues so we turn it off for now. # TODO: inline inbuilt nn modules does not work yet, enable it to accelarate @@ -474,25 +460,19 @@ def apply_compile(model: nn.Module, job_config: JobConfig): def apply_fsdp( model: nn.Module, - world_mesh: DeviceMesh, - parallel_dims: "ParallelDims", - job_config: JobConfig, + dp_mesh: DeviceMesh, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + pp_enabled: bool, ): """ Apply data parallelism to the model. FSDP2 is used here. """ - - dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh - assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names - - mp_policy = MixedPrecisionPolicy( - param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], - reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], - ) + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} for layer_id, transformer_block in model.layers.items(): - if parallel_dims.pp_enabled: + if pp_enabled: # For PP, do not reshard after forward to avoid per-microbatch # all-gathers, which can be expensive and non-overlapped reshard_after_forward = False @@ -505,11 +485,9 @@ def apply_fsdp( **fsdp_config, reshard_after_forward=reshard_after_forward, ) - fully_shard( - model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled - ) + fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled) - if parallel_dims.pp_enabled: + if pp_enabled: # TODO # This PR https://github.com/pytorch/pytorch/pull/129519 added a safety check to avoid using 2D/3D DCP since # without strided sharding, DCP can not safely support resharding for 2D/3D. However, for PP to work, even @@ -526,22 +504,19 @@ def apply_fsdp( def apply_ddp( model: nn.Module, - world_mesh: DeviceMesh, - parallel_dims: "ParallelDims", - job_config: JobConfig, + dp_mesh: DeviceMesh, + enable_compile: bool, + enable_compiled_autograd: bool, ): - if world_mesh.ndim > 1: - raise RuntimeError("DDP has not supported > 1D parallelism.") - - if job_config.training.compile: - if job_config.experimental.enable_compiled_autograd: + if enable_compile: + if enable_compiled_autograd: torch._dynamo.config.optimize_ddp = ( "python_reducer_without_compiled_forward" ) else: torch._dynamo.config.optimize_ddp = "ddp_optimizer" - model = replicate(model, device_mesh=world_mesh, bucket_cap_mb=100) + model = replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) logger.info("Applied DDP to the model") return model @@ -562,18 +537,46 @@ def parallelize_llama( """ if parallel_dims.tp_enabled: - model = apply_tp(model, world_mesh, parallel_dims, job_config) + model = apply_tp( + model, + world_mesh["tp"], + loss_parallel=parallel_dims.loss_parallel_enabled, + enable_float8=job_config.training.enable_float8_linear, + enable_async_tp=job_config.experimental.enable_async_tensor_parallel, + ) if job_config.activation_checkpoint.mode != "none": - model = apply_ac(model, job_config) + model = apply_ac(model, job_config.activation_checkpoint) if job_config.training.compile: - model = apply_compile(model, job_config) + if job_config.model.norm_type == "fused_rmsnorm": + raise NotImplementedError( + "fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm." + ) + model = apply_compile(model) if parallel_dims.dp_enabled: if parallel_dims.dp_type == "fsdp": - model = apply_fsdp(model, world_mesh, parallel_dims, job_config) + dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh + assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names + + model = apply_fsdp( + model, + dp_mesh, + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[ + job_config.training.mixed_precision_reduce + ], + pp_enabled=parallel_dims.pp_enabled, + ) else: - model = apply_ddp(model, world_mesh, parallel_dims, job_config) + if world_mesh.ndim > 1: + raise RuntimeError("DDP has not supported > 1D parallelism.") + model = apply_ddp( + model, + world_mesh, + enable_compile=job_config.training.compile, + enable_compiled_autograd=job_config.experimental.enable_compiled_autograd, + ) return model diff --git a/train.py b/train.py index eef7401d..92e29058 100644 --- a/train.py +++ b/train.py @@ -137,7 +137,7 @@ def main(job_config: JobConfig): if parallel_dims.pp_enabled: stages, model_parts = models_pipelining_fns[model_name]( - whole_model, world_mesh, parallel_dims, job_config, device, model_config + whole_model, pp_mesh, parallel_dims, job_config, device, model_config ) else: # In 1D/2D cases or PP with simple schedules, model_parts is just one item From 40f79d79f067d3d1887e9476474d17d791848be9 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Tue, 30 Jul 2024 23:52:17 -0700 Subject: [PATCH 27/29] [BE][3/n] wrap fp8 logic using Float8Handler ghstack-source-id: e94c7f6f4fad87c5432262c54beabd02de5541b8 Pull Request resolved: https://github.com/pytorch/torchtitan/pull/496 --- estimation.py | 14 +- torchtitan/config_manager.py | 83 ++++----- torchtitan/float8_linear.py | 173 ++++++++++--------- torchtitan/parallelisms/parallelize_llama.py | 2 +- train.py | 17 +- train_configs/debug_model.toml | 4 +- train_configs/llama2_13b.toml | 4 +- train_configs/llama2_70b.toml | 4 +- train_configs/llama2_7b.toml | 4 +- train_configs/llama3_70b.toml | 4 +- train_configs/llama3_8b.toml | 4 +- 11 files changed, 163 insertions(+), 150 deletions(-) diff --git a/estimation.py b/estimation.py index 3adcf663..acf867d5 100644 --- a/estimation.py +++ b/estimation.py @@ -16,10 +16,7 @@ from torchtitan.config_manager import JobConfig from torchtitan.datasets import build_tokenizer -from torchtitan.float8_linear import ( - maybe_build_fp8_linear, - maybe_precompute_fp8_dynamic_scale_for_fsdp, -) +from torchtitan.float8_linear import Float8Handler from torchtitan.logging import init_logger, logger from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config from torchtitan.optimizer import build_lr_schedulers, build_optimizers @@ -127,8 +124,10 @@ def loss_fn(pred, labels): with torch.device("meta"): whole_model = model_cls.from_model_args(model_config) + # a no-op hander if fp8 is not enabled + float8_handler = Float8Handler(job_config, parallel_dims) # swap to Float8Linear base on fp8 config - maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled) + float8_handler.convert_to_float8_training(whole_model) # apply PT-D DP/TP parallelisms and activation checkpointing model_parts = [whole_model] @@ -184,13 +183,14 @@ def loss_fn(pred, labels): torch.nn.utils.clip_grad_norm_( model.parameters(), job_config.training.max_norm, foreach=True ) + # sync float8 amaxes and scales + float8_handler.sync_float8_amax_and_scale_history(model) # optimizer step optimizers.step() lr_schedulers.step() - # when fp8 config is on, # calculate float8 dynamic amax/scale for all-parameter for FSDP2 # it issues a single all-reduce for all parameters at once for better performance - maybe_precompute_fp8_dynamic_scale_for_fsdp(whole_model, job_config) + float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model) optimizers.zero_grad() print(f"Peak Memory at iter: {iter_idx}") fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index dd5ba7cd..2bc37bfb 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -348,46 +348,6 @@ def __init__(self): action="store_true", help="Whether to compile the model", ) - self.parser.add_argument( - "--training.enable_float8_linear", - action="store_true", - help=""" - If true, swaps `torch.nn.Linear` with `Float8Linear`. - This feature requires you to install 'torchao' which can be found - here: https://github.com/pytorch/ao - """, - ) - self.parser.add_argument( - "--training.enable_fsdp_float8_all_gather", - action="store_true", - default=False, - help="Whether enable float8 all-gather in FSDP", - ) - self.parser.add_argument( - "--training.precompute_float8_dynamic_scale_for_fsdp", - action="store_true", - default=False, - help="Whether precompute float8 scales dynamically for FSDP", - ) - self.parser.add_argument( - "--training.float8_scaling_type_input", - type=str, - default="dynamic", - help="float8 scaling for input, dynamic (default) or delayed", - choices=["dynamic", "delayed"], - ) - self.parser.add_argument( - "--training.float8_scaling_type_weight", - type=str, - default="dynamic", - help="float8 scaling for input, dynamic (default) or delayed", - ) - self.parser.add_argument( - "--training.float8_scaling_type_grad_output", - type=str, - default="dynamic", - help="float8 scaling for input, dynamic (default) or delayed", - ) self.parser.add_argument( "--training.gc_freq", type=int, @@ -483,6 +443,7 @@ def __init__(self): 0 is the default value. """, ) + # activation checkpointing configs self.parser.add_argument( "--activation_checkpoint.mode", @@ -500,6 +461,48 @@ def __init__(self): """, ) + # float8 configs + self.parser.add_argument( + "--float8.enable_float8_linear", + action="store_true", + help=""" + If true, swaps `torch.nn.Linear` with `Float8Linear`. + This feature requires you to install 'torchao' which can be found + here: https://github.com/pytorch/ao + """, + ) + self.parser.add_argument( + "--float8.enable_fsdp_float8_all_gather", + action="store_true", + default=False, + help="Whether enable float8 all-gather in FSDP", + ) + self.parser.add_argument( + "--float8.precompute_float8_dynamic_scale_for_fsdp", + action="store_true", + default=False, + help="Whether precompute float8 scales dynamically for FSDP", + ) + self.parser.add_argument( + "--float8.scaling_type_input", + type=str, + default="dynamic", + help="float8 scaling for input, dynamic (default) or delayed", + choices=["dynamic", "delayed"], + ) + self.parser.add_argument( + "--float8.scaling_type_weight", + type=str, + default="dynamic", + help="float8 scaling for input, dynamic (default) or delayed", + ) + self.parser.add_argument( + "--float8.scaling_type_grad_output", + type=str, + default="dynamic", + help="float8 scaling for input, dynamic (default) or delayed", + ) + # communications library settings self.parser.add_argument( "--comm.init_timeout_seconds", diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index fa311061..494b6046 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -12,127 +12,128 @@ # Note: Performance # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance -import functools -from typing import Optional import torch import torch.nn as nn -from torch._logging import warning_once from torchtitan.config_manager import JobConfig from torchtitan.logging import logger +from torchtitan.parallelisms import ParallelDims -@functools.lru_cache(None) def is_sm90_or_later(): # Float8 is only supported on H100+ GPUs return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) -def maybe_build_fp8_linear( - model: nn.Module, job_config: JobConfig, dp_enabled: Optional[bool] = False -): - """ - This function converts the linear layers to `Float8Linear`. Note that today, - only dynamic tensor scaling (the default) is supported. - - This will mutate the model inplace. - """ - enable_float8_linear = job_config.training.enable_float8_linear - if not enable_float8_linear: - return - if not is_sm90_or_later(): - warning_once( - logger, - "Failed to swap to Float8Linear because SM90 or later is not available", - ) - return - try: - from torchao.float8 import ( - CastConfig, - convert_to_float8_training, - Float8LinearConfig, - ScalingType, - ) +class Float8Handler: + def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): + self.enabled = False + + float8_config = job_config.float8 + if not float8_config.enable_float8_linear: + return + if not is_sm90_or_later(): + logger.warning( + "Failed to swap to Float8Linear because SM90 or later is not available", + ) + return + try: + from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType + except ImportError as e: + raise ImportError( + "torchao is not installed. Please install it to use fp8 linear layers." + ) from e # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear enable_fsdp_float8_all_gather = ( - job_config.training.enable_fsdp_float8_all_gather and dp_enabled - ) - scaling_type_input = ScalingType(job_config.training.float8_scaling_type_input) - scaling_type_weight = ScalingType( - job_config.training.float8_scaling_type_weight + parallel_dims.dp_enabled + and parallel_dims.dp_type == "fsdp" + and float8_config.enable_fsdp_float8_all_gather ) - scaling_type_grad_output = ScalingType( - job_config.training.float8_scaling_type_grad_output - ) - float8_config = Float8LinearConfig( + scaling_type_input = ScalingType(float8_config.scaling_type_input) + scaling_type_weight = ScalingType(float8_config.scaling_type_weight) + scaling_type_grad_output = ScalingType(float8_config.scaling_type_grad_output) + self.config = Float8LinearConfig( enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, cast_config_input=CastConfig(scaling_type=scaling_type_input), cast_config_weight=CastConfig(scaling_type=scaling_type_weight), cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), enable_pre_and_post_forward=False, ) + + self.enabled = True + + # for precompute_fp8_dynamic_scale_for_fsdp + self.precompute_scale = ( + enable_fsdp_float8_all_gather + and float8_config.precompute_float8_dynamic_scale_for_fsdp + ) + + # for sync_float8_amax_and_scale_history + self.delayed_scaling = ( + scaling_type_input == "delayed" + or scaling_type_weight == "delayed" + or scaling_type_grad_output == "delayed" + ) + self._sync_float8_amax_and_scale_history = None + self.compile = job_config.training.compile + + logger.info("Float8 training active") + + def convert_to_float8_training(self, model: nn.Module): + """ + This function converts the linear layers of `model` to `Float8Linear`. + Note that today, only dynamic tensor scaling (the default) is supported. + This will mutate the model inplace. + """ + if not self.enabled: + return + + from torchao.float8 import convert_to_float8_training + + # Mutates the model inplace replacing instances of nn.Linear with Float8Linear convert_to_float8_training( model, - config=float8_config, + config=self.config, module_filter_fn=lambda mod, fqn: fqn != "output", ) logger.info( - f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}" + "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather=" + f"{self.config.enable_fsdp_float8_all_gather}" ) - except ImportError as exc: - raise ImportError( - "torchao is not installed. Please install it to use fp8 linear layers." - ) from exc - - -def maybe_precompute_fp8_dynamic_scale_for_fsdp( - model: nn.Module, job_config: JobConfig -): - if not ( - job_config.training.enable_float8_linear - and job_config.training.enable_fsdp_float8_all_gather - and job_config.training.precompute_float8_dynamic_scale_for_fsdp - ): - return - if not is_sm90_or_later(): - warning_once( - logger, - "Skipped precomputing fp8 scales because SM90 or later is not available", - ) - return - from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp - precompute_float8_dynamic_scale_for_fsdp(model) + def precompute_fp8_dynamic_scale_for_fsdp(self, model: nn.Module): + if not self.enabled: + return + if not self.precompute_scale: + return -_sync_float8_amax_and_scale_history = None + from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp + precompute_float8_dynamic_scale_for_fsdp(model) -def maybe_sync_float8_amax_and_scale_history(model: nn.Module, job_config: JobConfig): - if not ( - job_config.training.enable_float8_linear - and ( - job_config.training.float8_scaling_type_input == "delayed" - or job_config.training.float8_scaling_type_weight == "delayed" - or job_config.training.float8_scaling_type_grad_output == "delayed" - ) - ): - return + def sync_float8_amax_and_scale_history(self, model: nn.Module): + if not self.enabled: + return - from torchao.float8 import sync_float8_amax_and_scale_history + if not self.delayed_scaling: + return - # TODO(future): see if precalculating the modules to sync over is going to - # meaningfully help performance + from torchao.float8 import sync_float8_amax_and_scale_history - global _sync_float8_amax_and_scale_history - if _sync_float8_amax_and_scale_history is None: - if job_config.training.compile: - _sync_float8_amax_and_scale_history = torch.compile( - sync_float8_amax_and_scale_history - ) - else: - _sync_float8_amax_and_scale_history = sync_float8_amax_and_scale_history + # TODO(vkuzo): see if precalculating the modules to sync over is going to + # meaningfully help performance + + if self._sync_float8_amax_and_scale_history is None: + if self.compile: + self._sync_float8_amax_and_scale_history = torch.compile( + sync_float8_amax_and_scale_history + ) + else: + self._sync_float8_amax_and_scale_history = ( + sync_float8_amax_and_scale_history + ) - sync_float8_amax_and_scale_history(model) + self._sync_float8_amax_and_scale_history(model) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index e86f93b9..bdafc8e2 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -541,7 +541,7 @@ def parallelize_llama( model, world_mesh["tp"], loss_parallel=parallel_dims.loss_parallel_enabled, - enable_float8=job_config.training.enable_float8_linear, + enable_float8=job_config.float8.enable_float8_linear, enable_async_tp=job_config.experimental.enable_async_tensor_parallel, ) diff --git a/train.py b/train.py index 92e29058..615ed4e3 100644 --- a/train.py +++ b/train.py @@ -15,11 +15,7 @@ from torchtitan.checkpoint import CheckpointManager, TrainState from torchtitan.config_manager import JobConfig from torchtitan.datasets import build_hf_data_loader, build_tokenizer -from torchtitan.float8_linear import ( - maybe_build_fp8_linear, - maybe_precompute_fp8_dynamic_scale_for_fsdp, - maybe_sync_float8_amax_and_scale_history, -) +from torchtitan.float8_linear import Float8Handler from torchtitan.logging import init_logger, logger from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config @@ -120,8 +116,10 @@ def main(job_config: JobConfig): with torch.device("meta"): whole_model = model_cls.from_model_args(model_config) + # a no-op hander if fp8 is not enabled + float8_handler = Float8Handler(job_config, parallel_dims) # swap to Float8Linear base on fp8 config - maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled) + float8_handler.convert_to_float8_training(whole_model) # log model size model_param_count = utils.get_num_params(whole_model) @@ -307,18 +305,17 @@ def loss_fn(pred, labels): model.parameters(), job_config.training.max_norm, foreach=True ) - # if float8 is enabled, sync float8 amaxes and scales - maybe_sync_float8_amax_and_scale_history(model, job_config) + # sync float8 amaxes and scales + float8_handler.sync_float8_amax_and_scale_history(model) # optimizer step checkpoint.maybe_wait_for_staging() optimizers.step() lr_schedulers.step() - # when float8 config is on, # calculate float8 dynamic amax/scale for all-parameter for FSDP2 # it issues a single all-reduce for all parameters at once for better performance - maybe_precompute_fp8_dynamic_scale_for_fsdp(model, job_config) + float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model) losses_since_last_log.append(loss) diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index b36e9d0c..7d4187dc 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -37,7 +37,6 @@ max_norm = 1.0 # grad norm clipping steps = 10 data_parallel_degree = -1 tensor_parallel_degree = 1 -enable_float8_linear = false compile = false dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M) @@ -57,3 +56,6 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'selective' # ['none', 'selective', 'full'] selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_float8_linear = false diff --git a/train_configs/llama2_13b.toml b/train_configs/llama2_13b.toml index 2dc29f2e..4727f965 100644 --- a/train_configs/llama2_13b.toml +++ b/train_configs/llama2_13b.toml @@ -33,7 +33,6 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 1 -enable_float8_linear = false compile = false dataset = "c4" @@ -52,3 +51,6 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'selective' # ['none', 'selective', 'full'] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_float8_linear = false diff --git a/train_configs/llama2_70b.toml b/train_configs/llama2_70b.toml index f17496c5..83114876 100644 --- a/train_configs/llama2_70b.toml +++ b/train_configs/llama2_70b.toml @@ -33,7 +33,6 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 8 # 8-way TP -enable_float8_linear = false compile = false dataset = "c4" @@ -51,3 +50,6 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'full' # ['none', 'selective', 'full'] + +[float8] +enable_float8_linear = false diff --git a/train_configs/llama2_7b.toml b/train_configs/llama2_7b.toml index 69ae7285..22ab6c76 100644 --- a/train_configs/llama2_7b.toml +++ b/train_configs/llama2_7b.toml @@ -32,7 +32,6 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 1 # dp-only would be sufficient for 7B -enable_float8_linear = false compile = false dataset = "c4" @@ -51,3 +50,6 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'selective' # ['none', 'selective', 'full'] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_float8_linear = false diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml index 660f2c0b..62d75dfb 100644 --- a/train_configs/llama3_70b.toml +++ b/train_configs/llama3_70b.toml @@ -33,7 +33,6 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 8 # 8-way TP -enable_float8_linear = false compile = false dataset = "c4" @@ -51,3 +50,6 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'full' + +[float8] +enable_float8_linear = false diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml index 7e5ac63c..517dd81e 100644 --- a/train_configs/llama3_8b.toml +++ b/train_configs/llama3_8b.toml @@ -33,7 +33,6 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 1 -enable_float8_linear = false compile = false dataset = "c4" @@ -52,3 +51,6 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'selective' # ['none', 'selective', 'full'] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_float8_linear = false From 48713581479319a15f4d10c05b450ad0b9f796bd Mon Sep 17 00:00:00 2001 From: Hugo <6937752+fduwjj@users.noreply.github.com> Date: Thu, 1 Aug 2024 16:16:17 -0700 Subject: [PATCH 28/29] Bring LLaMa 3.1 405B to TorchTitan family (#481) With the official launch of LLaMa 3.1 model, we want to add the config to TorchTitan. Of course, there are more work to be done, but we want to go an incremental way. So more PRs will be needed. For now, we try on 128 GPUs with current config (TP=8, FSDP=16). The perf number is wps: 109 mfu: 29%. Loss curve for 3000 steps with 600 warmup (lr = 0.8e-4). image Loss curve for 3000 steps with 600 warmup (lr = 1.1e-4). ![image](https://github.com/user-attachments/assets/429b9738-94cb-4b37-90ef-049a5587ddd0) --- README.md | 2 +- torchtitan/datasets/download_tokenizer.py | 4 +- torchtitan/models/llama/__init__.py | 9 ++++ train_configs/llama3_405b.toml | 53 +++++++++++++++++++++++ 4 files changed, 65 insertions(+), 3 deletions(-) create mode 100644 train_configs/llama3_405b.toml diff --git a/README.md b/README.md index dde75e20..56785112 100644 --- a/README.md +++ b/README.md @@ -74,7 +74,7 @@ Once you have confirmed access, you can run the following command to download th ```bash # Get your HF token from https://huggingface.co/settings/tokens -# llama3 tokenizer.model +# llama3 or 3.1 tokenizer.model python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3-8B --tokenizer_path "original" --hf_token=... # llama2 tokenizer.model diff --git a/torchtitan/datasets/download_tokenizer.py b/torchtitan/datasets/download_tokenizer.py index 44ef5f59..a419d709 100644 --- a/torchtitan/datasets/download_tokenizer.py +++ b/torchtitan/datasets/download_tokenizer.py @@ -20,8 +20,8 @@ def hf_download( try: hf_hub_download( - repo_id, - tokenizer_path, + repo_id=repo_id, + filename=tokenizer_path, local_dir=local_dir, local_dir_use_symlinks=False, token=hf_token, diff --git a/torchtitan/models/llama/__init__.py b/torchtitan/models/llama/__init__.py index 3cdfe0f9..887a96cd 100644 --- a/torchtitan/models/llama/__init__.py +++ b/torchtitan/models/llama/__init__.py @@ -48,4 +48,13 @@ multiple_of=4096, rope_theta=500000, ), + "405B": ModelArgs( + dim=16384, + n_layers=126, + n_heads=128, + n_kv_heads=8, + ffn_dim_multiplier=1.2, + multiple_of=4096, + rope_theta=500000, + ), } diff --git a/train_configs/llama3_405b.toml b/train_configs/llama3_405b.toml new file mode 100644 index 00000000..fb250642 --- /dev/null +++ b/train_configs/llama3_405b.toml @@ -0,0 +1,53 @@ +# torchtitan Config.toml +# NOTE: this toml config is a preset for 128 H100 GPUs. + +[job] +dump_folder = "./outputs" +description = "Llama 3 405B training" + +[profiling] +enable_profiling = true +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 10 +enable_tensorboard = true +save_tb_folder = "tb" + +[model] +name = "llama3" +flavor = "405B" +norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm +tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model" + +[optimizer] +name = "AdamW" +lr = 0.8e-4 + +[training] +batch_size = 2 +seq_len = 8192 +warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps +max_norm = 1.0 # grad norm clipping +steps = 3000 +data_parallel_degree = -1 +tensor_parallel_degree = 8 # 8-way TP +enable_float8_linear = false +compile = false +dataset = "c4" + +[experimental] +pipeline_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval_type = "steps" +interval = 500 +model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = 'full' # ['none', 'selective', 'full'] From d41d6045d847e1315be431be1a7bbe216984d180 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 1 Aug 2024 20:00:58 -0700 Subject: [PATCH 29/29] [TP] Infer local n_heads instead of ad-hoc model changes ghstack-source-id: 587e3d6e5270714ca734b8031ce41a962e6394ea Pull Request resolved: https://github.com/pytorch/torchtitan/pull/498 --- torchtitan/models/llama/model.py | 9 ++++++--- torchtitan/parallelisms/parallelize_llama.py | 5 ----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index e47d0fb9..e357f432 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -190,9 +190,12 @@ def forward( bs, seqlen, _ = x.shape xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - xq = xq.view(bs, seqlen, self.n_heads, self.head_dim) - xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim) - xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim) + # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. + xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index bdafc8e2..c21479a2 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -398,11 +398,6 @@ def apply_tp( "feed_forward.w3": colwise_parallel_weight(), } - # Adjust attention module to use the local number of heads - attn_layer = transformer_block.attention - attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size() - attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size() - parallelize_module( module=transformer_block, device_mesh=tp_mesh,