From 8d00b73c6bc49774253ac29e324d0cc9633276e8 Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 12 Jun 2024 00:09:54 -0700 Subject: [PATCH 01/20] float8 tmp save Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchtitan/float8_linear.py | 20 ++++++++++++++++++-- train_configs/llama3_70b.toml | 4 ++-- train_configs/llama3_8b.toml | 7 ++++--- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index 9bd88cae..e7ca307b 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -12,11 +12,25 @@ # Note: Performance # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance - +import torch import torch.nn as nn +import contextlib +import torch.distributed as dist from torchtitan.config_manager import JobConfig from torchtitan.logging_utils import logger +import float8_experimental.config as config + +@contextlib.contextmanager +def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool): + prev = config.enable_fsdp_fp8_all_gather + dist.barrier() + config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather + try: + yield + finally: + dist.barrier() + config.enable_fsdp_fp8_all_gather = prev def build_fp8_linear(model: nn.Module, job_config: JobConfig): @@ -28,6 +42,7 @@ def build_fp8_linear(model: nn.Module, job_config: JobConfig): This will mutate the model inplace. """ linear_type = job_config.training.fp8_linear.lower() + enable_fsdp_fp8_all_gather = job_config.training.enable_fsdp_fp8_all_gather if hasattr(job_config.training, 'enable_fsdp_fp8_all_gather') else False try: from float8_experimental.float8_dynamic_linear import Float8DynamicLinear @@ -50,5 +65,6 @@ def build_fp8_linear(model: nn.Module, job_config: JobConfig): float8_linear_type = linear_type_map[linear_type.lower()] # Mutates the model inplace replacing instances of torch.nn.Linear with float8_linear_type - swap_linear_with_float8_linear(model, float8_linear_type) + with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather): + swap_linear_with_float8_linear(model, float8_linear_type) logger.info(f"Swapped to {linear_type} float8 linear layers") diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml index f45632ad..6c5f0b56 100644 --- a/train_configs/llama3_70b.toml +++ b/train_configs/llama3_70b.toml @@ -26,7 +26,7 @@ name = "AdamW" lr = 1.5e-4 [training] -batch_size = 8 +batch_size = 1 seq_len = 8192 warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping @@ -34,7 +34,7 @@ steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 8 # 8-way TP pipeline_parallel_degree = 1 -fp8_linear = "" +fp8_linear = "dynamic" compile = false dataset = "c4" diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml index aaba99a2..51d7ffb9 100644 --- a/train_configs/llama3_8b.toml +++ b/train_configs/llama3_8b.toml @@ -8,7 +8,7 @@ description = "Llama 3 8B training" [profiling] enable_profiling = true save_traces_folder = "profile_trace" -profile_freq = 100 +profile_freq = 1 [metrics] log_freq = 10 @@ -30,11 +30,12 @@ batch_size = 1 seq_len = 8192 warmup_steps = 200 # lr scheduler warm up max_norm = 1.0 # grad norm clipping -steps = 1000 +steps = 10 data_parallel_degree = -1 tensor_parallel_degree = 1 pipeline_parallel_degree = 1 -fp8_linear = "" +fp8_linear = "dynamic" +enable_fsdp_fp8_all_gather = true compile = false dataset = "c4" From 4cd5f7443d3db2ef998ca71759f79bfd57537028 Mon Sep 17 
00:00:00 2001 From: willfengg Date: Wed, 19 Jun 2024 16:05:33 -0700 Subject: [PATCH 02/20] run 8b eager successfully Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchtitan/config_manager.py | 6 ++++++ torchtitan/float8_linear.py | 4 ++-- train.py | 2 +- train_configs/llama3_8b.toml | 11 +++++------ 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 0eeac026..16ad2c06 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -344,6 +344,12 @@ def __init__(self): here: https://github.com/pytorch-labs/float8_experimental """, ) + self.parser.add_argument( + "--training.enable_fsdp_fp8_all_gather", + action="store_true", + default=False, + help="Whether enable fp8 all-gather in FSDP", + ) self.parser.add_argument( "--training.gc_freq", type=int, diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index e7ca307b..0f0b1451 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -42,7 +42,7 @@ def build_fp8_linear(model: nn.Module, job_config: JobConfig): This will mutate the model inplace. """ linear_type = job_config.training.fp8_linear.lower() - enable_fsdp_fp8_all_gather = job_config.training.enable_fsdp_fp8_all_gather if hasattr(job_config.training, 'enable_fsdp_fp8_all_gather') else False + enable_fsdp_fp8_all_gather = job_config.training.enable_fsdp_fp8_all_gather try: from float8_experimental.float8_dynamic_linear import Float8DynamicLinear @@ -67,4 +67,4 @@ def build_fp8_linear(model: nn.Module, job_config: JobConfig): # Mutates the model inplace replacing instances of torch.nn.Linear with float8_linear_type with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather): swap_linear_with_float8_linear(model, float8_linear_type) - logger.info(f"Swapped to {linear_type} float8 linear layers") + logger.info(f"Swapped to {linear_type} float8 linear layers with {enable_fsdp_fp8_all_gather=}") diff --git a/train.py b/train.py index 64a50990..8e55c210 100644 --- a/train.py +++ b/train.py @@ -217,7 +217,7 @@ def loss_fn(pred, labels): # apply fp8 linear module swap if job_config.training.fp8_linear: - build_fp8_linear(model, job_config) + build_fp8_linear(whole_model, job_config) # log model size model_param_count = get_num_params(whole_model) diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml index 79bc894b..781351d6 100644 --- a/train_configs/llama3_8b.toml +++ b/train_configs/llama3_8b.toml @@ -8,10 +8,10 @@ description = "Llama 3 8B training" [profiling] enable_profiling = true save_traces_folder = "profile_trace" -profile_freq = 1 +profile_freq = 5 [metrics] -log_freq = 10 +log_freq = 1 enable_tensorboard = true save_tb_folder = "tb" @@ -27,10 +27,10 @@ lr = 3e-4 [training] batch_size = 1 -seq_len = 8192 +seq_len = 256 warmup_steps = 200 # lr scheduler warm up max_norm = 1.0 # grad norm clipping -steps = 10 +steps = 20 data_parallel_degree = -1 tensor_parallel_degree = 1 fp8_linear = "dynamic" @@ -51,5 +51,4 @@ export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] -mode = 'selective' # ['none', 'selective', 'full'] -selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy +mode = 'full' # ['none', 'selective', 'full'] From 05a4a06f0b3c870df06bc563818bfce7f9577bc4 Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 19 Jun 2024 17:02:07 -0700 Subject: [PATCH 03/20] enable compile Summary: Test Plan: 
Reviewers: Subscribers: Tasks: Tags: --- train_configs/llama3_8b.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml index 781351d6..eecf1375 100644 --- a/train_configs/llama3_8b.toml +++ b/train_configs/llama3_8b.toml @@ -35,7 +35,7 @@ data_parallel_degree = -1 tensor_parallel_degree = 1 fp8_linear = "dynamic" enable_fsdp_fp8_all_gather = true -compile = false +compile = true dataset = "c4" [experimental] From f48a82e5543073c4db0a86f86035a3827cbb4003 Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 19 Jun 2024 17:15:54 -0700 Subject: [PATCH 04/20] benchmark Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchtitan/parallelisms/parallelize_llama.py | 4 +++- train_configs/llama3_70b.toml | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 5b781201..9e115527 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -31,6 +31,8 @@ from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.logging_utils import logger +import torch._functorch.config as functorch_config +functorch_config.activation_memory_budget = 0.0 # for selective AC no_recompute_list = { @@ -376,7 +378,7 @@ def apply_ac(model, job_config: JobConfig): transformer_block = checkpoint_wrapper(transformer_block, ac_config) model.layers.register_module(layer_id, transformer_block) - logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") + logger.info(f"Applied {ac_config.mode} activation checkpointing to the model with {functorch_config.activation_memory_budget=}") return model diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml index 9bb2e416..660c7d97 100644 --- a/train_configs/llama3_70b.toml +++ b/train_configs/llama3_70b.toml @@ -26,14 +26,15 @@ name = "AdamW" lr = 1.5e-4 [training] -batch_size = 1 +batch_size = 2 seq_len = 8192 warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 -tensor_parallel_degree = 8 # 8-way TP +tensor_parallel_degree = 1 # 8-way TP fp8_linear = "dynamic" +enable_fsdp_fp8_all_gather = true compile = false dataset = "c4" From 14aabfb4c60347059a67f92423a8f00f9c520242 Mon Sep 17 00:00:00 2001 From: willfengg Date: Thu, 20 Jun 2024 17:49:06 -0700 Subject: [PATCH 05/20] 1d setup Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchtitan/config_manager.py | 6 ++++++ train.py | 4 ++++ train_configs/llama3_70b.toml | 1 + 3 files changed, 11 insertions(+) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 16ad2c06..aa0cb13d 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -350,6 +350,12 @@ def __init__(self): default=False, help="Whether enable fp8 all-gather in FSDP", ) + self.parser.add_argument( + "--training.precompute_float8_amax", + action="store_true", + default=False, + help="Whether precompute fp8 amax for FSDP", + ) self.parser.add_argument( "--training.gc_freq", type=int, diff --git a/train.py b/train.py index 8e55c210..1b756826 100644 --- a/train.py +++ b/train.py @@ -23,6 +23,7 @@ from torch.distributed.checkpoint.stateful import Stateful from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed.tensor.parallel import loss_parallel +from 
float8_experimental.float8_linear_utils import precompute_float8_amax from torchtitan.checkpoint import CheckpointManager from torchtitan.config_manager import JobConfig @@ -398,6 +399,9 @@ def loss_fn(pred, labels): optimizers.step() lr_schedulers.step() + if job_config.training.precompute_float8_amax: + precompute_float8_amax(model) + losses_since_last_log.append(loss) # log metrics diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml index 660c7d97..891b912c 100644 --- a/train_configs/llama3_70b.toml +++ b/train_configs/llama3_70b.toml @@ -35,6 +35,7 @@ data_parallel_degree = -1 tensor_parallel_degree = 1 # 8-way TP fp8_linear = "dynamic" enable_fsdp_fp8_all_gather = true +precompute_float8_amax = true compile = false dataset = "c4" From b88aee92637a6a218181a7f5b0206c3e0b6abb33 Mon Sep 17 00:00:00 2001 From: willfengg Date: Thu, 20 Jun 2024 17:51:31 -0700 Subject: [PATCH 06/20] 2d setup Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- train_configs/llama3_70b.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml index 891b912c..c17b90a2 100644 --- a/train_configs/llama3_70b.toml +++ b/train_configs/llama3_70b.toml @@ -26,17 +26,17 @@ name = "AdamW" lr = 1.5e-4 [training] -batch_size = 2 +batch_size = 16 seq_len = 8192 warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 -tensor_parallel_degree = 1 # 8-way TP +tensor_parallel_degree = 8 # 8-way TP fp8_linear = "dynamic" enable_fsdp_fp8_all_gather = true precompute_float8_amax = true -compile = false +compile = true dataset = "c4" [experimental] From 2b4e0c2a768ccba62cdff2ccd9231185d302ef12 Mon Sep 17 00:00:00 2001 From: willfengg Date: Fri, 21 Jun 2024 14:36:03 -0700 Subject: [PATCH 07/20] 2d setup Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- train_configs/llama3_8b.toml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml index eecf1375..7846eafc 100644 --- a/train_configs/llama3_8b.toml +++ b/train_configs/llama3_8b.toml @@ -32,10 +32,11 @@ warmup_steps = 200 # lr scheduler warm up max_norm = 1.0 # grad norm clipping steps = 20 data_parallel_degree = -1 -tensor_parallel_degree = 1 +tensor_parallel_degree = 4 fp8_linear = "dynamic" enable_fsdp_fp8_all_gather = true -compile = true +precompute_float8_amax = true +compile = false dataset = "c4" [experimental] From 23536e98f0741d1de85d457e7b0164fc5314a447 Mon Sep 17 00:00:00 2001 From: willfengg Date: Fri, 12 Jul 2024 15:38:41 -0700 Subject: [PATCH 08/20] fp8 all-gather FSDP Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchtitan/config_manager.py | 4 ++-- torchtitan/float8_linear.py | 4 +++- torchtitan/parallelisms/parallelize_llama.py | 5 +---- train.py | 6 +++--- train_configs/llama3_8b.toml | 4 ++-- 5 files changed, 11 insertions(+), 12 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 2a8b86b4..88d04b6a 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -354,10 +354,10 @@ def __init__(self): help="Whether enable fp8 all-gather in FSDP", ) self.parser.add_argument( - "--training.precompute_float8_amax", + "--training.precompute_float8_dynamic_scale_for_fsdp", action="store_true", default=False, - help="Whether precompute fp8 amax for FSDP", + help="Whether precompute fp8 scales dynamically for 
FSDP", ) self.parser.add_argument( "--training.gc_freq", diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index 142a1bb5..cf6bcacc 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -20,6 +20,8 @@ from torchtitan.config_manager import JobConfig from torchtitan.logging_utils import logger import float8_experimental.config as config +from float8_experimental.float8_linear import TensorScalingType + @contextlib.contextmanager def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool): @@ -54,5 +56,5 @@ def build_fp8_linear(model: nn.Module, job_config: JobConfig): if use_fp8_linear: # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather): - swap_linear_with_float8_linear(model, Float8Linear) + swap_linear_with_float8_linear(model, scaling_type_w=TensorScalingType.DYNAMIC) logger.info(f"Swapped to Float8Linear layers with {enable_fsdp_fp8_all_gather=}") diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 9cdd0b2c..1b414159 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -40,9 +40,6 @@ DeviceType = Union[int, str, torch.device] -import torch._functorch.config as functorch_config -functorch_config.activation_memory_budget = 0.0 - # for selective AC no_recompute_list = { torch.ops.aten.mm.default, @@ -428,7 +425,7 @@ def apply_ac(model: nn.Module, job_config: JobConfig): transformer_block = checkpoint_wrapper(transformer_block, ac_config) model.layers.register_module(layer_id, transformer_block) - logger.info(f"Applied {ac_config.mode} activation checkpointing to the model with {functorch_config.activation_memory_budget=}") + logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") return model diff --git a/train.py b/train.py index 1b756826..7f47490c 100644 --- a/train.py +++ b/train.py @@ -23,7 +23,7 @@ from torch.distributed.checkpoint.stateful import Stateful from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed.tensor.parallel import loss_parallel -from float8_experimental.float8_linear_utils import precompute_float8_amax +from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp from torchtitan.checkpoint import CheckpointManager from torchtitan.config_manager import JobConfig @@ -399,8 +399,8 @@ def loss_fn(pred, labels): optimizers.step() lr_schedulers.step() - if job_config.training.precompute_float8_amax: - precompute_float8_amax(model) + if job_config.training.precompute_float8_dynamic_scale_for_fsdp: + precompute_float8_dynamic_scale_for_fsdp(model) losses_since_last_log.append(loss) diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml index f95158ca..83064215 100644 --- a/train_configs/llama3_8b.toml +++ b/train_configs/llama3_8b.toml @@ -32,10 +32,10 @@ warmup_steps = 200 # lr scheduler warm up max_norm = 1.0 # grad norm clipping steps = 20 data_parallel_degree = -1 -tensor_parallel_degree = 4 +tensor_parallel_degree = 1 fp8_linear = true enable_fsdp_fp8_all_gather = true -precompute_float8_amax = true +precompute_float8_dynamic_scale_for_fsdp = true compile = false dataset = "c4" From bdb0fd039ba29969673183b634137ac49babfbeb Mon Sep 17 00:00:00 2001 From: willfengg Date: Fri, 12 Jul 2024 15:47:03 -0700 Subject: [PATCH 09/20] linter Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- 
torchtitan/float8_linear.py | 22 ++++++++++++---------- train.py | 2 +- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index cf6bcacc..0d2aca3d 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -12,26 +12,25 @@ # Note: Performance # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance -import torch -import torch.nn as nn import contextlib -import torch.distributed as dist -from torchtitan.config_manager import JobConfig -from torchtitan.logging_utils import logger import float8_experimental.config as config +import torch.nn as nn from float8_experimental.float8_linear import TensorScalingType +from torchtitan.config_manager import JobConfig +from torchtitan.logging_utils import logger + @contextlib.contextmanager def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool): prev = config.enable_fsdp_fp8_all_gather - dist.barrier() + torch.distributed.barrier() config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather try: yield finally: - dist.barrier() + torch.distributed.barrier() config.enable_fsdp_fp8_all_gather = prev @@ -45,7 +44,6 @@ def build_fp8_linear(model: nn.Module, job_config: JobConfig): use_fp8_linear = job_config.training.fp8_linear enable_fsdp_fp8_all_gather = job_config.training.enable_fsdp_fp8_all_gather try: - from float8_experimental.float8_linear import Float8Linear from float8_experimental.float8_linear_utils import ( swap_linear_with_float8_linear, ) @@ -56,5 +54,9 @@ def build_fp8_linear(model: nn.Module, job_config: JobConfig): if use_fp8_linear: # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather): - swap_linear_with_float8_linear(model, scaling_type_w=TensorScalingType.DYNAMIC) - logger.info(f"Swapped to Float8Linear layers with {enable_fsdp_fp8_all_gather=}") + swap_linear_with_float8_linear( + model, scaling_type_w=TensorScalingType.DYNAMIC + ) + logger.info( + f"Swapped to Float8Linear layers with {enable_fsdp_fp8_all_gather=}" + ) diff --git a/train.py b/train.py index 7f47490c..cc953588 100644 --- a/train.py +++ b/train.py @@ -19,11 +19,11 @@ import torch import torch.nn.functional as F +from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp from torch.distributed import destroy_process_group from torch.distributed.checkpoint.stateful import Stateful from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed.tensor.parallel import loss_parallel -from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp from torchtitan.checkpoint import CheckpointManager from torchtitan.config_manager import JobConfig From ef0e843ef70471e3aebb99b930c2515e22805ac4 Mon Sep 17 00:00:00 2001 From: willfengg Date: Fri, 12 Jul 2024 16:07:52 -0700 Subject: [PATCH 10/20] add unit test and restore original toml Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test_runner.py | 46 +++++++++++++++++++++++++++++++++++ train_configs/llama3_70b.toml | 8 +++--- train_configs/llama3_8b.toml | 15 ++++++------ 3 files changed, 56 insertions(+), 13 deletions(-) diff --git a/test_runner.py b/test_runner.py index 319f99d7..feecd12e 100755 --- a/test_runner.py +++ b/test_runner.py @@ -273,6 +273,52 @@ def build_test_list(): "fsdp2_mem_tracker", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--training.fp8_linear", + ] + ], + "FSDP2 with bf16 all-gather", 
+ "fp8_fsdp2_bf16_all_gather", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--training.fp8_linear", + "--training.enable_fsdp_fp8_all_gather", + ] + ], + "FSDP2 with fp8 all-gather", + "fp8_fsdp2_fp8_all_gather", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--training.fp8_linear", + "--training.enable_fsdp_fp8_all_gather", + "--precompute_float8_dynamic_scale_for_fsdp", + ] + ], + "FSDP2 with fp8 all-gather and precomputed dynamic scales", + "fp8_fsdp2_fp8_all_gather_precompute_dynamic_scales", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--training.fp8_linear", + "--training.enable_fsdp_fp8_all_gather", + "--precompute_float8_dynamic_scale_for_fsdp", + "--training.compile", + ] + ], + "FSDP2 with fp8 all-gather and precomputed dynamic scales, graph-break compile", + "fp8_fsdp2_fp8_all_gather_precompute_dynamic_scales_compile", + ngpu=4, + ), ] return integration_tests_flavors diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml index 92d419f4..93b529f6 100644 --- a/train_configs/llama3_70b.toml +++ b/train_configs/llama3_70b.toml @@ -26,17 +26,15 @@ name = "AdamW" lr = 1.5e-4 [training] -batch_size = 16 +batch_size = 8 seq_len = 8192 warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 8 # 8-way TP -fp8_linear = true -enable_fsdp_fp8_all_gather = true -precompute_float8_amax = true -compile = true +fp8_linear = false +compile = false dataset = "c4" [experimental] diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml index 83064215..95a53d56 100644 --- a/train_configs/llama3_8b.toml +++ b/train_configs/llama3_8b.toml @@ -8,10 +8,10 @@ description = "Llama 3 8B training" [profiling] enable_profiling = true save_traces_folder = "profile_trace" -profile_freq = 5 +profile_freq = 100 [metrics] -log_freq = 1 +log_freq = 10 enable_tensorboard = true save_tb_folder = "tb" @@ -27,15 +27,13 @@ lr = 3e-4 [training] batch_size = 1 -seq_len = 256 +seq_len = 8192 warmup_steps = 200 # lr scheduler warm up max_norm = 1.0 # grad norm clipping -steps = 20 +steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 1 -fp8_linear = true -enable_fsdp_fp8_all_gather = true -precompute_float8_dynamic_scale_for_fsdp = true +fp8_linear = false compile = false dataset = "c4" @@ -52,4 +50,5 @@ export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] -mode = 'full' # ['none', 'selective', 'full'] +mode = 'selective' # ['none', 'selective', 'full'] +selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy From c294f6a3caadff153066383c54a5bef65f55a1a5 Mon Sep 17 00:00:00 2001 From: willfengg Date: Fri, 12 Jul 2024 16:43:31 -0700 Subject: [PATCH 11/20] add unit test for float8 Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test_runner.py | 15 +-------------- torchtitan/float8_linear.py | 2 ++ 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/test_runner.py b/test_runner.py index feecd12e..b914f88c 100755 --- a/test_runner.py +++ b/test_runner.py @@ -299,26 +299,13 @@ def build_test_list(): [ "--training.fp8_linear", "--training.enable_fsdp_fp8_all_gather", - "--precompute_float8_dynamic_scale_for_fsdp", + "--training.precompute_float8_dynamic_scale_for_fsdp", ] ], "FSDP2 with fp8 all-gather and precomputed dynamic scales", "fp8_fsdp2_fp8_all_gather_precompute_dynamic_scales", ngpu=4, ), - OverrideDefinitions( - [ 
- [ - "--training.fp8_linear", - "--training.enable_fsdp_fp8_all_gather", - "--precompute_float8_dynamic_scale_for_fsdp", - "--training.compile", - ] - ], - "FSDP2 with fp8 all-gather and precomputed dynamic scales, graph-break compile", - "fp8_fsdp2_fp8_all_gather_precompute_dynamic_scales_compile", - ngpu=4, - ), ] return integration_tests_flavors diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index 0d2aca3d..bc7cbab2 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -15,6 +15,8 @@ import contextlib import float8_experimental.config as config + +import torch import torch.nn as nn from float8_experimental.float8_linear import TensorScalingType From b58b07b19fea1e42ade9b7cd70cf69cdacd5a1b5 Mon Sep 17 00:00:00 2001 From: willfengg Date: Mon, 15 Jul 2024 14:00:00 -0700 Subject: [PATCH 12/20] better doc with original dtype all-gather and value error on fp8 Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test_runner.py | 4 ++-- train.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/test_runner.py b/test_runner.py index b914f88c..96a89df5 100755 --- a/test_runner.py +++ b/test_runner.py @@ -279,8 +279,8 @@ def build_test_list(): "--training.fp8_linear", ] ], - "FSDP2 with bf16 all-gather", - "fp8_fsdp2_bf16_all_gather", + "FSDP2 with original dtype", + "fp8_fsdp2_orig_all_gather", ngpu=4, ), OverrideDefinitions( diff --git a/train.py b/train.py index cc953588..5984186e 100644 --- a/train.py +++ b/train.py @@ -219,6 +219,11 @@ def loss_fn(pred, labels): # apply fp8 linear module swap if job_config.training.fp8_linear: build_fp8_linear(whole_model, job_config) + else: + if job_config.training.enable_fsdp_fp8_all_gather: + raise ValueError( + "enable_fsdp_fp8_all_gather can only be used with fp8_linear" + ) # log model size model_param_count = get_num_params(whole_model) @@ -400,6 +405,14 @@ def loss_fn(pred, labels): lr_schedulers.step() if job_config.training.precompute_float8_dynamic_scale_for_fsdp: + if (not job_config.training.use_fp8_linear) or ( + not job_config.training.enable_fsdp_fp8_all_gather + ): + raise ValueError( + "precompute_float8_dynamic_scale_for_fsdp is only ", + "supported when use_fp8_linear and ", + "enable_fsdp_fp8_all_gather are both enabled.", + ) precompute_float8_dynamic_scale_for_fsdp(model) losses_since_last_log.append(loss) From 7df10aedecf2f724cf3762100dd1bf190e425726 Mon Sep 17 00:00:00 2001 From: willfengg Date: Mon, 15 Jul 2024 14:10:58 -0700 Subject: [PATCH 13/20] improve config msg Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index 5984186e..11a11823 100644 --- a/train.py +++ b/train.py @@ -405,13 +405,13 @@ def loss_fn(pred, labels): lr_schedulers.step() if job_config.training.precompute_float8_dynamic_scale_for_fsdp: - if (not job_config.training.use_fp8_linear) or ( + if (not job_config.training.fp8_linear) or ( not job_config.training.enable_fsdp_fp8_all_gather ): raise ValueError( - "precompute_float8_dynamic_scale_for_fsdp is only ", - "supported when use_fp8_linear and ", - "enable_fsdp_fp8_all_gather are both enabled.", + "precompute_float8_dynamic_scale_for_fsdp is only " + "supported when fp8_linear and " + "enable_fsdp_fp8_all_gather are both enabled" ) precompute_float8_dynamic_scale_for_fsdp(model) From 7dd788c45a254500046b6e1d8ed1259a624989da Mon Sep 17 00:00:00 2001 From: willfengg Date: Mon, 15 Jul 2024 18:55:50 -0700 Subject: 
[PATCH 14/20] rename config to enable_fp8_linear and improve comments Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- estimation.py | 4 ++-- test_runner.py | 12 ++++++++++ torchtitan/config_manager.py | 2 +- torchtitan/float8_linear.py | 22 ++++++++++------- torchtitan/parallelisms/parallelize_llama.py | 14 +++++++++-- train.py | 25 +++++++------------- 6 files changed, 50 insertions(+), 29 deletions(-) diff --git a/estimation.py b/estimation.py index ddf24d8a..e652c581 100644 --- a/estimation.py +++ b/estimation.py @@ -124,8 +124,8 @@ def loss_fn(pred, labels): whole_model = model_cls.from_model_args(model_config) # apply fp8 linear module swap - if job_config.training.fp8_linear: - build_fp8_linear(whole_model, job_config) + if job_config.training.enable_fp8_linear: + build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled) # apply PT-D DP/TP parallelisms and activation checkpointing model_parts = [whole_model] diff --git a/test_runner.py b/test_runner.py index 96a89df5..ed938b1a 100755 --- a/test_runner.py +++ b/test_runner.py @@ -283,6 +283,18 @@ def build_test_list(): "fp8_fsdp2_orig_all_gather", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--training.fp8_linear", + "--training.data_parallel_degree 1" + "--training.tensor_parallel_degree 4", + ] + ], + "1D TP with fp8 all-gather", + "tp_fp8_all_gather", + ngpu=4, + ), OverrideDefinitions( [ [ diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 88d04b6a..0dfe1bb0 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -338,7 +338,7 @@ def __init__(self): help="Whether to compile the model", ) self.parser.add_argument( - "--training.fp8_linear", + "--training.enable_fp8_linear", action="store_true", help=""" If true, swaps `torch.nn.Linear` with `Float8Linear` with diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index bc7cbab2..0aced204 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -36,24 +36,26 @@ def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool): config.enable_fsdp_fp8_all_gather = prev -def build_fp8_linear(model: nn.Module, job_config: JobConfig): +def build_fp8_linear( + model: nn.Module, job_config: JobConfig, dp_enabled: Optional[bool] = False +): """ This function converts the linear layers to `Float8Linear`. Note that today, only dynamic tensor scaling (the default) is supported. This will mutate the model inplace. """ - use_fp8_linear = job_config.training.fp8_linear - enable_fsdp_fp8_all_gather = job_config.training.enable_fsdp_fp8_all_gather + enable_fp8_linear = job_config.training.enable_fp8_linear + if not enable_fp8_linear: + return + enable_fsdp_fp8_all_gather = ( + job_config.training.enable_fsdp_fp8_all_gather and dp_enabled + ) try: from float8_experimental.float8_linear_utils import ( swap_linear_with_float8_linear, ) - except ImportError as exc: - raise ImportError( - "float8_experimental is not installed. Please install it to use fp8 linear layers." - ) from exc - if use_fp8_linear: + # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather): swap_linear_with_float8_linear( @@ -62,3 +64,7 @@ def build_fp8_linear(model: nn.Module, job_config: JobConfig): logger.info( f"Swapped to Float8Linear layers with {enable_fsdp_fp8_all_gather=}" ) + except ImportError as exc: + raise ImportError( + "float8_experimental is not installed. Please install it to use fp8 linear layers." 
+ ) from exc diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 1b414159..bb5606c8 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -117,12 +117,22 @@ def selective_checkpointing_context_fn(): def get_tp_parallel_strategy( job_config: JobConfig, + model: nn.Module, ) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]: """Get the parallel strategy for the transformer model. This function handles the special case of using float8 with tensor parallelism. """ - if job_config.training.fp8_linear == "dynamic": + if job_config.training.enable_fp8_linear: + from float8_experimental.float8_linear import Float8Linear, TensorScalingType + + if any( + isinstance(m, Float8Linear) + and m.scaling_type_w is TensorScalingType.DELAYED + for m in module.modules() + ): + raise NotImplementedError("Only supports delayed scaling") + from float8_experimental.float8_tensor_parallel import ( Float8ColwiseParallel, Float8RowwiseParallel, @@ -346,7 +356,7 @@ def apply_tp( rowwise_parallel_weight, colwise_parallel_weight, prepare_module_input, - ) = get_tp_parallel_strategy(job_config) + ) = get_tp_parallel_strategy(job_config, model) loss_parallel = parallel_dims.loss_parallel_enabled # 1. Parallelize the embedding and shard its outputs (which are the first diff --git a/train.py b/train.py index 11a11823..14008525 100644 --- a/train.py +++ b/train.py @@ -217,13 +217,8 @@ def loss_fn(pred, labels): whole_model = model_cls.from_model_args(model_config) # apply fp8 linear module swap - if job_config.training.fp8_linear: - build_fp8_linear(whole_model, job_config) - else: - if job_config.training.enable_fsdp_fp8_all_gather: - raise ValueError( - "enable_fsdp_fp8_all_gather can only be used with fp8_linear" - ) + if job_config.training.enable_fp8_linear: + build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled) # log model size model_param_count = get_num_params(whole_model) @@ -404,15 +399,13 @@ def loss_fn(pred, labels): optimizers.step() lr_schedulers.step() - if job_config.training.precompute_float8_dynamic_scale_for_fsdp: - if (not job_config.training.fp8_linear) or ( - not job_config.training.enable_fsdp_fp8_all_gather - ): - raise ValueError( - "precompute_float8_dynamic_scale_for_fsdp is only " - "supported when fp8_linear and " - "enable_fsdp_fp8_all_gather are both enabled" - ) + if ( + job_config.training.enable_fp8_linear + and job_config.training.enable_fsdp_fp8_all_gather + and job_config.training.precompute_float8_dynamic_scale_for_fsdp + ): + # calculate float8 dynamic amax/scale for all-parameter for FSDP2 + # it issues a single all-reduce for all parameters at once for better performance precompute_float8_dynamic_scale_for_fsdp(model) losses_since_last_log.append(loss) From faefe2796e8c91791d5a46c207ffc32da748b8eb Mon Sep 17 00:00:00 2001 From: willfengg Date: Mon, 15 Jul 2024 18:57:49 -0700 Subject: [PATCH 15/20] rename to enable_fp8_linear Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- train_configs/debug_model.toml | 2 +- train_configs/llama2_13b.toml | 2 +- train_configs/llama2_70b.toml | 2 +- train_configs/llama2_7b.toml | 2 +- train_configs/llama3_70b.toml | 2 +- train_configs/llama3_8b.toml | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index cb2fb215..6064ced1 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -37,7 
+37,7 @@ max_norm = 1.0 # grad norm clipping steps = 10 data_parallel_degree = -1 tensor_parallel_degree = 1 -fp8_linear = false +enable_fp8_linear = false compile = false dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M) diff --git a/train_configs/llama2_13b.toml b/train_configs/llama2_13b.toml index 05e3c27b..f4061ad0 100644 --- a/train_configs/llama2_13b.toml +++ b/train_configs/llama2_13b.toml @@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 1 -fp8_linear = false +enable_fp8_linear = false compile = false dataset = "c4" diff --git a/train_configs/llama2_70b.toml b/train_configs/llama2_70b.toml index 5b2dd493..19e033b8 100644 --- a/train_configs/llama2_70b.toml +++ b/train_configs/llama2_70b.toml @@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 8 # 8-way TP -fp8_linear = false +enable_fp8_linear = false compile = false dataset = "c4" diff --git a/train_configs/llama2_7b.toml b/train_configs/llama2_7b.toml index 9b72246a..95d67667 100644 --- a/train_configs/llama2_7b.toml +++ b/train_configs/llama2_7b.toml @@ -32,7 +32,7 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 1 # dp-only would be sufficient for 7B -fp8_linear = false +enable_fp8_linear = false compile = false dataset = "c4" diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml index 93b529f6..ac6b31c1 100644 --- a/train_configs/llama3_70b.toml +++ b/train_configs/llama3_70b.toml @@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 8 # 8-way TP -fp8_linear = false +enable_fp8_linear = false compile = false dataset = "c4" diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml index 95a53d56..2c3c6e63 100644 --- a/train_configs/llama3_8b.toml +++ b/train_configs/llama3_8b.toml @@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 1 -fp8_linear = false +enable_fp8_linear = false compile = false dataset = "c4" From 7aad0668b904ce3445dafb4ddda93162a1a54a1f Mon Sep 17 00:00:00 2001 From: willfengg Date: Mon, 15 Jul 2024 19:14:07 -0700 Subject: [PATCH 16/20] add 2D test Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test_runner.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/test_runner.py b/test_runner.py index ed938b1a..c28f1a23 100755 --- a/test_runner.py +++ b/test_runner.py @@ -276,7 +276,7 @@ def build_test_list(): OverrideDefinitions( [ [ - "--training.fp8_linear", + "--training.enable_fp8_linear", ] ], "FSDP2 with original dtype", @@ -286,7 +286,7 @@ def build_test_list(): OverrideDefinitions( [ [ - "--training.fp8_linear", + "--training.enable_fp8_linear", "--training.data_parallel_degree 1" "--training.tensor_parallel_degree 4", ] @@ -298,24 +298,37 @@ def build_test_list(): OverrideDefinitions( [ [ - "--training.fp8_linear", + "--training.enable_fp8_linear", "--training.enable_fsdp_fp8_all_gather", ] ], "FSDP2 with fp8 all-gather", - "fp8_fsdp2_fp8_all_gather", + "fsdp2_fp8_all_gather", ngpu=4, ), OverrideDefinitions( [ [ - "--training.fp8_linear", + "--training.enable_fp8_linear", "--training.enable_fsdp_fp8_all_gather", "--training.precompute_float8_dynamic_scale_for_fsdp", ] ], "FSDP2 with fp8 all-gather and precomputed dynamic scales", - "fp8_fsdp2_fp8_all_gather_precompute_dynamic_scales", 
+ "fsdp2_fp8_all_gather_precompute_dynamic_scales", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--training.enable_fp8_linear", + "--training.enable_fsdp_fp8_all_gather", + "--training.precompute_float8_dynamic_scale_for_fsdp", + "--training.tensor_parallel_degree 2", + ] + ], + "FSDP2 with fp8 all-gather and precomputed dynamic scales", + "fsdp2_tp_fp8_all_gather_precompute_dynamic_scales", ngpu=4, ), ] From 5040c31925b5b63a212b336f115e0d1b2ddb6d7c Mon Sep 17 00:00:00 2001 From: willfengg Date: Mon, 15 Jul 2024 20:38:34 -0700 Subject: [PATCH 17/20] import Optional and NotImplement for delayed scaling Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchtitan/float8_linear.py | 1 + torchtitan/parallelisms/parallelize_llama.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index 0aced204..f8599a84 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -13,6 +13,7 @@ # Note: Performance # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance import contextlib +from typing import Optional import float8_experimental.config as config diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index bb5606c8..1f6fb3c3 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -129,9 +129,9 @@ def get_tp_parallel_strategy( if any( isinstance(m, Float8Linear) and m.scaling_type_w is TensorScalingType.DELAYED - for m in module.modules() + for m in model.modules() ): - raise NotImplementedError("Only supports delayed scaling") + raise NotImplementedError("1D TP fp8 all-gather only supports dynamic scaling") from float8_experimental.float8_tensor_parallel import ( Float8ColwiseParallel, From cee653e801306a9b3c474d20110d70bd5da4e3b5 Mon Sep 17 00:00:00 2001 From: willfengg Date: Mon, 15 Jul 2024 20:48:51 -0700 Subject: [PATCH 18/20] remove TP fp8 all-gather from CI Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test_runner.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/test_runner.py b/test_runner.py index c28f1a23..f2f80504 100755 --- a/test_runner.py +++ b/test_runner.py @@ -283,18 +283,6 @@ def build_test_list(): "fp8_fsdp2_orig_all_gather", ngpu=4, ), - OverrideDefinitions( - [ - [ - "--training.enable_fp8_linear", - "--training.data_parallel_degree 1" - "--training.tensor_parallel_degree 4", - ] - ], - "1D TP with fp8 all-gather", - "tp_fp8_all_gather", - ngpu=4, - ), OverrideDefinitions( [ [ @@ -318,19 +306,6 @@ def build_test_list(): "fsdp2_fp8_all_gather_precompute_dynamic_scales", ngpu=4, ), - OverrideDefinitions( - [ - [ - "--training.enable_fp8_linear", - "--training.enable_fsdp_fp8_all_gather", - "--training.precompute_float8_dynamic_scale_for_fsdp", - "--training.tensor_parallel_degree 2", - ] - ], - "FSDP2 with fp8 all-gather and precomputed dynamic scales", - "fsdp2_tp_fp8_all_gather_precompute_dynamic_scales", - ngpu=4, - ), ] return integration_tests_flavors From e16428533f642d54855bf7f5d0419078a2f94ebe Mon Sep 17 00:00:00 2001 From: willfengg Date: Tue, 16 Jul 2024 10:24:31 -0700 Subject: [PATCH 19/20] fix linter Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchtitan/parallelisms/parallelize_llama.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 
1f6fb3c3..b33e8870 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -131,7 +131,9 @@ def get_tp_parallel_strategy( and m.scaling_type_w is TensorScalingType.DELAYED for m in model.modules() ): - raise NotImplementedError("1D TP fp8 all-gather only supports dynamic scaling") + raise NotImplementedError( + "1D TP fp8 all-gather only supports dynamic scaling" + ) from float8_experimental.float8_tensor_parallel import ( Float8ColwiseParallel, From 22c71ea819d9df905bb802ebd991e4945fd74474 Mon Sep 17 00:00:00 2001 From: willfengg Date: Tue, 16 Jul 2024 15:45:42 -0700 Subject: [PATCH 20/20] remove redundant check Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchtitan/float8_linear.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index f8599a84..f41a812d 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -47,8 +47,6 @@ def build_fp8_linear( This will mutate the model inplace. """ enable_fp8_linear = job_config.training.enable_fp8_linear - if not enable_fp8_linear: - return enable_fsdp_fp8_all_gather = ( job_config.training.enable_fsdp_fp8_all_gather and dp_enabled )
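
Taken together, the series lands on a small usage pattern: swap nn.Linear for Float8Linear with dynamic weight scaling while the library-level fp8 all-gather flag is set, then recompute the dynamic scales once per training step after the optimizer update. The Python sketch below is a minimal recap of that pattern under stated assumptions: it uses only the float8_experimental APIs referenced in the diffs (swap_linear_with_float8_linear, TensorScalingType, precompute_float8_dynamic_scale_for_fsdp, config.enable_fsdp_fp8_all_gather), while the helper names swap_to_float8 and training_step, and the model/optimizer objects, are illustrative placeholders rather than torchtitan code.

import contextlib

import float8_experimental.config as config
import torch
import torch.nn as nn
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp


@contextlib.contextmanager
def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool):
    # Toggle the library-level flag only for the duration of the module swap,
    # with barriers so every rank sees the same value (assumes an initialized
    # process group), mirroring torchtitan/float8_linear.py in the final patch.
    prev = config.enable_fsdp_fp8_all_gather
    torch.distributed.barrier()
    config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
    try:
        yield
    finally:
        torch.distributed.barrier()
        config.enable_fsdp_fp8_all_gather = prev


def swap_to_float8(model: nn.Module, enable_fsdp_fp8_all_gather: bool) -> None:
    # Replace nn.Linear modules in place with Float8Linear, using dynamic
    # scaling for weights (the only mode the series supports).
    with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
        swap_linear_with_float8_linear(model, scaling_type_w=TensorScalingType.DYNAMIC)


def training_step(model, optimizer, lr_scheduler, loss, precompute_scale: bool) -> None:
    # Placeholder step loop body showing where the precompute hook goes.
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    if precompute_scale:
        # With fp8 all-gather enabled, refresh the dynamic scales for all
        # parameters at once (a single all-reduce) after the weights change.
        precompute_float8_dynamic_scale_for_fsdp(model)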