From febc6344b046f08e98cbb359ca0ef6db6a87274b Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Thu, 9 Jan 2025 11:26:25 -0500 Subject: [PATCH 01/25] Bookmark --- benchmarks/fp8/torchao/Dockerfile | 12 ++ benchmarks/fp8/torchao/README.md | 32 ++++ benchmarks/fp8/torchao/ddp.py | 144 +++++++++++++++ benchmarks/fp8/torchao/distrib_deepspeed.py | 190 ++++++++++++++++++++ benchmarks/fp8/torchao/fp8_utils.py | 116 ++++++++++++ benchmarks/fp8/torchao/fsdp.py | 161 +++++++++++++++++ benchmarks/fp8/torchao/non_distributed.py | 125 +++++++++++++ 7 files changed, 780 insertions(+) create mode 100644 benchmarks/fp8/torchao/Dockerfile create mode 100644 benchmarks/fp8/torchao/README.md create mode 100644 benchmarks/fp8/torchao/ddp.py create mode 100644 benchmarks/fp8/torchao/distrib_deepspeed.py create mode 100644 benchmarks/fp8/torchao/fp8_utils.py create mode 100644 benchmarks/fp8/torchao/fsdp.py create mode 100644 benchmarks/fp8/torchao/non_distributed.py diff --git a/benchmarks/fp8/torchao/Dockerfile b/benchmarks/fp8/torchao/Dockerfile new file mode 100644 index 00000000000..88c21934d4e --- /dev/null +++ b/benchmarks/fp8/torchao/Dockerfile @@ -0,0 +1,12 @@ +FROM nvcr.io/nvidia/pytorch:24.07-py3 + +RUN pip install transformers evaluate datasets +RUN git clone https://github.com/huggingface/accelerate.git + +RUN cd accelerate && \ + pip install -e . && \ + cd benchmarks/fp8 + +RUN /bin/bash + + diff --git a/benchmarks/fp8/torchao/README.md b/benchmarks/fp8/torchao/README.md new file mode 100644 index 00000000000..d5abadaf64e --- /dev/null +++ b/benchmarks/fp8/torchao/README.md @@ -0,0 +1,32 @@ +# FP8 Benchmarks + +Comparing and running [torchao](https://github.com/pytorch/ao/tree/main/torchao/float8) FP8 with accelerate + +## Overview + +This repo provides scripts which compare native `torchao` model training against `accelerate`'s own integration. Each modeling type is segmented out via a script, supporting the following: + +* Single GPU training (`non_distributed.py`) +* Multi-GPU training via DistributedDataParallelism (`ddp.py`) +* Fully Sharded Data Parallelism (`fsdp.py`) +* DeepSpeed ZeRO 1-3 (`deepspeed.py`) + +To run them, it's recommended to use a docker image (see the attached `Dockerfile`) and not install `torchao` manually. + +## Running: + +There are official Docker images located at `huggingface/accelerate:gpu-fp8-torchao-nightly` which can be used. + +You can run all scripts using the core `accelerate launch` command without any `accelerate config` being needed. + +For single GPU, run it via `python`: + +```bash +python non_distributed.py +``` + +For the rest, run it via `accelerate launch`: + +```bash +accelerate launch ddp.py # or distrib_deepspeed.py, ddp.py +``` \ No newline at end of file diff --git a/benchmarks/fp8/torchao/ddp.py b/benchmarks/fp8/torchao/ddp.py new file mode 100644 index 00000000000..ba708a27be4 --- /dev/null +++ b/benchmarks/fp8/torchao/ddp.py @@ -0,0 +1,144 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`. + +This particular script verifies this for DDP training. +""" + +import evaluate +import torch +import transformer_engine.common.recipe as te_recipe +import transformer_engine.pytorch as te +from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities +from torch.nn.parallel import DistributedDataParallel as DDP +from transformer_engine.common.recipe import DelayedScaling + +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from accelerate.utils import FP8RecipeKwargs, set_seed +from accelerate.utils.transformer_engine import convert_model + + +MODEL_NAME = "bert-base-cased" +METRIC = evaluate.load("glue", "mrpc") + + +def train_baseline(): + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) + accelerator = Accelerator() + device = accelerator.device + model.to(device) + + # Convert the model to TE + old_named_params = get_named_parameters(model) + + with torch.no_grad(): + convert_model(model) + + FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"} + fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS) + + new_named_params = get_named_parameters(model) + + # Convert the model to DDP + device_ids, output_device = [accelerator.local_process_index], accelerator.local_process_index + model = DDP(model, device_ids=device_ids, output_device=output_device) + + mapping = {p: new_named_params[n] for n, p in old_named_params.items()} + for param_group in optimizer.param_groups: + param_group["params"] = [mapping[p] for p in param_group["params"]] + + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + + for _ in range(2): + for batch in train_dataloader: + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch = batch.to(device) + outputs = model(**batch) + loss = outputs.loss + loss.backward() + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + +def train_integration(): + FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} + kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)] + AcceleratorState()._reset_state(True) + accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers) + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( + MODEL_NAME, accelerator=accelerator + ) + + model, optimizer = accelerator.prepare(model, optimizer) + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + + for _ in range(2): + for batch in 
train_dataloader: + outputs = model(**batch) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + +if __name__ == "__main__": + baseline_not_trained, baseline_trained = train_baseline() + accelerator_not_trained, accelerator_trained = train_integration() + + assert ( + baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + assert ( + baseline_not_trained["f1"] == accelerator_not_trained["f1"] + ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + assert ( + baseline_trained["accuracy"] == accelerator_trained["accuracy"] + ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' + + torch.distributed.destroy_process_group() diff --git a/benchmarks/fp8/torchao/distrib_deepspeed.py b/benchmarks/fp8/torchao/distrib_deepspeed.py new file mode 100644 index 00000000000..e678deb3659 --- /dev/null +++ b/benchmarks/fp8/torchao/distrib_deepspeed.py @@ -0,0 +1,190 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`. + +This particular script verifies this for DDP training. 
+""" + +from unittest.mock import patch + +import deepspeed +import evaluate +import torch +import transformer_engine.common.recipe as te_recipe +import transformer_engine.pytorch as te +from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities +from transformer_engine.common.recipe import DelayedScaling + +from accelerate import Accelerator, DeepSpeedPlugin +from accelerate.state import AcceleratorState +from accelerate.utils import FP8RecipeKwargs, set_seed +from accelerate.utils.transformer_engine import convert_model + + +MODEL_NAME = "bert-base-cased" +METRIC = evaluate.load("glue", "mrpc") + + +def train_baseline(zero_stage: int = 1): + # This forces transformers to think Zero-3 Init should be used + with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock: + mock.return_value = zero_stage == 3 + set_seed(42) + + accelerator = Accelerator() + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( + MODEL_NAME, accelerator=accelerator + ) + + # Convert the model to TE + old_named_params = get_named_parameters(model) + + with torch.no_grad(): + convert_model(model) + new_named_params = get_named_parameters(model) + + mapping = {p: new_named_params[n] for n, p in old_named_params.items()} + for param_group in optimizer.param_groups: + param_group["params"] = [mapping[p] for p in param_group["params"]] + + FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"} + fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS) + + import numpy as np + + config = { + "train_batch_size": 32, + "train_micro_batch_size_per_gpu": 16, + "gradient_accumulation_steps": 1, + "zero_optimization": { + "stage": zero_stage, + "offload_optimizer": {"device": "none", "nvme_path": None}, + "offload_param": {"device": "none", "nvme_path": None}, + "stage3_gather_16bit_weights_on_model_save": False, + }, + "gradient_clipping": 1.0, + "steps_per_print": np.inf, + "bf16": {"enabled": True}, + "fp16": {"enabled": False}, + "zero_allow_untested_optimizer": True, + } + + ( + model, + optimizer, + _, + _, + ) = deepspeed.initialize( + model=model, + optimizer=optimizer, + config_params=config, + ) + + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + + model_outputs = [] + data = [] + + for _ in range(2): + for batch in train_dataloader: + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + outputs = model(**batch) + data.append(batch.to("cpu")) + model_outputs.append(outputs.logits.to("cpu")) + loss = outputs.loss + model.backward(loss) + model.step() + for _ in range(accelerator.num_processes): + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.destroy() + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results, model_outputs, data + + +def train_integration(zero_stage: int = 1): + set_seed(42) + FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} + kwargs_handlers = [FP8RecipeKwargs(backend="TE", 
**FP8_RECIPE_KWARGS)] + AcceleratorState()._reset_state(True) + deepspeed_plugin = DeepSpeedPlugin( + zero_stage=zero_stage, + zero3_init_flag=zero_stage == 3, + ) + accelerator = Accelerator( + mixed_precision="fp8", kwargs_handlers=kwargs_handlers, deepspeed_plugin=deepspeed_plugin + ) + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 16 + + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( + MODEL_NAME, accelerator=accelerator + ) + + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + model_outputs = [] + data = [] + for _ in range(2): + for batch in train_dataloader: + outputs = model(**batch) + data.append(batch.to("cpu")) + model_outputs.append(outputs.logits.to("cpu")) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.destroy() + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results, model_outputs, data + + +if __name__ == "__main__": + # for zero_stage in [1, 2, 3]: + zero_stage = 1 + baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage) + accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage) + assert ( + baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + assert ( + baseline_not_trained["f1"] == accelerator_not_trained["f1"] + ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + assert ( + baseline_trained["accuracy"] == accelerator_trained["accuracy"] + ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' + + torch.distributed.destroy_process_group() diff --git a/benchmarks/fp8/torchao/fp8_utils.py b/benchmarks/fp8/torchao/fp8_utils.py new file mode 100644 index 00000000000..d28702e05ff --- /dev/null +++ b/benchmarks/fp8/torchao/fp8_utils.py @@ -0,0 +1,116 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + + +def get_dataloaders(model_name: str, batch_size: int = 16): + from datasets import load_dataset + from torch.utils.data import DataLoader + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_name) + datasets = load_dataset("glue", "mrpc") + + def tokenize_function(examples): + # max_length=None => use the model max length (it's actually the default) + outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None) + return outputs + + # Apply the method we just defined to all the examples in all the splits of the dataset + # starting with the main process first: + tokenized_datasets = datasets.map( + tokenize_function, + batched=True, + remove_columns=["idx", "sentence1", "sentence2"], + ) + + # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the + # transformers library + tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + + def collate_fn(examples): + return tokenizer.pad( + examples, + padding="longest", + pad_to_multiple_of=16, # Specific for FP8 + return_tensors="pt", + ) + + # Instantiate dataloaders. + train_dataloader = DataLoader( + tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True + ) + eval_dataloader = DataLoader( + tokenized_datasets["validation"], + shuffle=False, + collate_fn=collate_fn, + batch_size=16, + drop_last=True, + ) + + return train_dataloader, eval_dataloader + + +def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None): + """ + Returns a tuple of: + - Model + - Optimizer + - Train dataloader (prepared) + - Eval dataloader (prepared) + - LR Scheduler + Suitable for training on the MRPC dataset + """ + from torch.optim import AdamW + from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup + + from accelerate import Accelerator + + if accelerator is None: + accelerator = Accelerator() + model = AutoModelForSequenceClassification.from_pretrained(model_name) + train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size) + optimizer = AdamW(model.parameters(), lr=0.0001) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=100, + num_training_steps=len(train_dataloader) * 2, + ) + train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader) + return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + + +def get_named_parameters(model): + """ + Same thing as `Accelerator.get_named_parameters` Returns a list of the named parameters of the model (extracted + from parallel) + """ + from accelerate.utils import extract_model_from_parallel + + model = extract_model_from_parallel(model) + return {n: p for n, p in model.named_parameters()} + + +def evaluate_model(model, dataloader, metric, accelerator=None): + "Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on" + model.eval() + for step, batch in enumerate(dataloader): + with torch.no_grad(): + outputs = 
model(**batch) + predictions = outputs.logits.argmax(dim=-1) + references = batch["labels"] + if accelerator is not None and accelerator.num_processes > 1: + predictions, references = accelerator.gather_for_metrics((predictions, references)) + metric.add_batch(predictions=predictions, references=references) + return metric.compute() diff --git a/benchmarks/fp8/torchao/fsdp.py b/benchmarks/fp8/torchao/fsdp.py new file mode 100644 index 00000000000..418122185e1 --- /dev/null +++ b/benchmarks/fp8/torchao/fsdp.py @@ -0,0 +1,161 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`. + +This particular script verifies this for FSDP training. +""" + +from functools import partial + +import evaluate +import torch +import transformer_engine.common.recipe as te_recipe +import transformer_engine.pytorch as te +from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from transformer_engine.common.recipe import DelayedScaling +from transformers.models.bert import BertLayer + +from accelerate import Accelerator +from accelerate import FullyShardedDataParallelPlugin as FSDPPlugin +from accelerate.state import AcceleratorState +from accelerate.utils import FP8RecipeKwargs, set_seed +from accelerate.utils.transformer_engine import convert_model + + +MODEL_NAME = "bert-base-cased" +METRIC = evaluate.load("glue", "mrpc") + +FSDP_WRAP_POLICY = partial(transformer_auto_wrap_policy, transformer_layer_cls={BertLayer}) + + +def train_baseline(): + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) + accelerator = Accelerator() + device = accelerator.device + model.to(device) + + # Convert the model to TE + old_named_params = get_named_parameters(model) + + with torch.no_grad(): + convert_model(model) + + FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"} + fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS) + + new_named_params = get_named_parameters(model) + + # Convert the model to FSDP + model = FSDP( + model, + use_orig_params=True, + mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), + auto_wrap_policy=FSDP_WRAP_POLICY, + ) + + mapping = {p: new_named_params[n] for n, p in old_named_params.items()} + for param_group in optimizer.param_groups: + param_group["params"] = [mapping[p] for p in param_group["params"]] + + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + + for _ in range(2): + for batch in train_dataloader: + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch = batch.to(device) + outputs = model(**batch) + loss = outputs.loss + loss.backward() + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + +def train_integration(): + FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} + kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)] + AcceleratorState()._reset_state(True) + fsdp_plugin = FSDPPlugin( + auto_wrap_policy=FSDP_WRAP_POLICY, + use_orig_params=True, + mixed_precision_policy=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), + ) + accelerator = Accelerator(mixed_precision="fp8", fsdp_plugin=fsdp_plugin, kwargs_handlers=kwargs_handlers) + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( + MODEL_NAME, accelerator=accelerator + ) + + model, optimizer = accelerator.prepare(model, optimizer) + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + + for _ in range(2): + for batch in train_dataloader: + outputs = model(**batch) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + +if __name__ == "__main__": + baseline_not_trained, baseline_trained = train_baseline() + accelerator_not_trained, accelerator_trained = train_integration() + + assert ( + baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + assert ( + baseline_not_trained["f1"] == accelerator_not_trained["f1"] + ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + assert ( + baseline_trained["accuracy"] == accelerator_trained["accuracy"] + ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' + + torch.distributed.destroy_process_group() diff --git a/benchmarks/fp8/torchao/non_distributed.py b/benchmarks/fp8/torchao/non_distributed.py new file mode 100644 index 
00000000000..81ebec12fe1 --- /dev/null +++ b/benchmarks/fp8/torchao/non_distributed.py @@ -0,0 +1,125 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script tests to ensure that `accelerate` performs at the same level as raw `torchao`. + +This particular script verifies this for single GPU training. +""" + +import evaluate +import torch +from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities + +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from accelerate.utils import FP8RecipeKwargs, set_seed + +from torchao.float8 import convert_to_float8_training + +MODEL_NAME = "bert-base-cased" +METRIC = evaluate.load("glue", "mrpc") + + +def module_filter_func(module, *args): + if isinstance(module, torch.nn.Linear): + if module.in_features % 16 != 0 or module.out_features % 16 != 0: + return False + + return True + + +def train_baseline(): + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) + model.to("cuda") + convert_to_float8_training(model, module_filter_fn=module_filter_func) + base_model_results = evaluate_model(model, eval_dataloader, METRIC) + model.train() + + from accelerate.utils.modeling import get_mixed_precision_context_manager + from accelerate.utils.operations import convert_outputs_to_fp32 + autocast_context = get_mixed_precision_context_manager(True, {"dtype": torch.bfloat16}) + model_forward_func = model.forward + model.forward = convert_outputs_to_fp32(autocast_context(model_forward_func)) + + for batch in train_dataloader: + outputs = model(**batch) + loss = outputs.loss + loss.backward() + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + +def train_integration(): + FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} + kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)] + AcceleratorState()._reset_state(True) + accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers) + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( + MODEL_NAME, accelerator=accelerator + ) + + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + base_model_results = evaluate_model(model, eval_dataloader, METRIC) + model.train() + + for batch in train_dataloader: + outputs = model(**batch) + loss 
= outputs.loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + +if __name__ == "__main__": + baseline_not_trained, baseline_trained = train_baseline() + # accelerator_not_trained, accelerator_trained = train_integration() + # assert ( + # baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + # ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + # assert ( + # baseline_not_trained["f1"] == accelerator_not_trained["f1"] + # ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + # assert ( + # baseline_trained["accuracy"] == accelerator_trained["accuracy"] + # ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + # assert ( + # baseline_trained["f1"] == accelerator_trained["f1"] + # ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' From a7663c51a05d718c9081a612f8c52c3c2f94312a Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Fri, 10 Jan 2025 10:09:32 -0500 Subject: [PATCH 02/25] bookmark --- benchmarks/fp8/torchao/non_distributed.py | 139 +++++++++++++++------- 1 file changed, 94 insertions(+), 45 deletions(-) diff --git a/benchmarks/fp8/torchao/non_distributed.py b/benchmarks/fp8/torchao/non_distributed.py index 81ebec12fe1..08d1da99b74 100644 --- a/benchmarks/fp8/torchao/non_distributed.py +++ b/benchmarks/fp8/torchao/non_distributed.py @@ -20,18 +20,106 @@ import evaluate import torch -from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities +from datasets import load_dataset +from torch.optim import AdamW +from torch.utils.data import DataLoader +from torchao.float8 import convert_to_float8_training +from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup from accelerate import Accelerator from accelerate.state import AcceleratorState from accelerate.utils import FP8RecipeKwargs, set_seed -from torchao.float8 import convert_to_float8_training MODEL_NAME = "bert-base-cased" METRIC = evaluate.load("glue", "mrpc") +def get_dataloaders(model_name: str, batch_size: int = 16): + tokenizer = AutoTokenizer.from_pretrained(model_name) + datasets = load_dataset("glue", "mrpc") + + def tokenize_function(examples): + # max_length=None => use the model max length (it's actually the default) + outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None) + return outputs + + # Apply the method we just defined to all the examples in all the splits of the dataset + # starting with the main process first: + tokenized_datasets = datasets.map( + tokenize_function, + batched=True, + remove_columns=["idx", "sentence1", 
"sentence2"], + ) + + # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the + # transformers library + tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + + def collate_fn(examples): + return tokenizer.pad( + examples, + padding="longest", + pad_to_multiple_of=16, # Specific for FP8 + return_tensors="pt", + ) + + # Instantiate dataloaders. + train_dataloader = DataLoader( + tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True + ) + eval_dataloader = DataLoader( + tokenized_datasets["validation"], + shuffle=False, + collate_fn=collate_fn, + batch_size=16, + drop_last=True, + ) + + return train_dataloader, eval_dataloader + + +def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None): + """ + Returns a tuple of: + - Model + - Optimizer + - Train dataloader (prepared) + - Eval dataloader (prepared) + - LR Scheduler + Suitable for training on the MRPC dataset + """ + + if accelerator is None: + accelerator = Accelerator() + model = AutoModelForSequenceClassification.from_pretrained(model_name) + train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size) + optimizer = AdamW(model.parameters(), lr=0.0001) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=100, + num_training_steps=len(train_dataloader) * 2, + ) + train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader) + return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + + + + +def evaluate_model(model, dataloader, metric, accelerator=None): + "Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on" + model.eval() + for step, batch in enumerate(dataloader): + with torch.no_grad(): + outputs = model(**batch) + predictions = outputs.logits.argmax(dim=-1) + references = batch["labels"] + if accelerator is not None and accelerator.num_processes > 1: + predictions, references = accelerator.gather_for_metrics((predictions, references)) + metric.add_batch(predictions=predictions, references=references) + return metric.compute() + + def module_filter_func(module, *args): if isinstance(module, torch.nn.Linear): if module.in_features % 16 != 0 or module.out_features % 16 != 0: @@ -48,50 +136,11 @@ def train_baseline(): base_model_results = evaluate_model(model, eval_dataloader, METRIC) model.train() - from accelerate.utils.modeling import get_mixed_precision_context_manager - from accelerate.utils.operations import convert_outputs_to_fp32 - autocast_context = get_mixed_precision_context_manager(True, {"dtype": torch.bfloat16}) - model_forward_func = model.forward - model.forward = convert_outputs_to_fp32(autocast_context(model_forward_func)) - - for batch in train_dataloader: - outputs = model(**batch) - loss = outputs.loss - loss.backward() - optimizer.step() - optimizer.zero_grad() - lr_scheduler.step() - - trained_model_results = evaluate_model(model, eval_dataloader, METRIC) - - assert ( - trained_model_results["accuracy"] > base_model_results["accuracy"] - ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' - assert ( - trained_model_results["f1"] > base_model_results["f1"] - ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' - - return base_model_results, trained_model_results - - -def train_integration(): - 
FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} - kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)] - AcceleratorState()._reset_state(True) - accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers) - set_seed(42) - model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( - MODEL_NAME, accelerator=accelerator - ) - - model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) - base_model_results = evaluate_model(model, eval_dataloader, METRIC) - model.train() - for batch in train_dataloader: - outputs = model(**batch) - loss = outputs.loss - accelerator.backward(loss) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + outputs = model(**batch) + loss = outputs.loss + loss.backward() optimizer.step() optimizer.zero_grad() lr_scheduler.step() From ed1adb1d41735859c714f9a64c4afaf775621120 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Thu, 16 Jan 2025 08:56:35 -0500 Subject: [PATCH 03/25] Add torchao base example --- benchmarks/fp8/torchao/non_distributed.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/benchmarks/fp8/torchao/non_distributed.py b/benchmarks/fp8/torchao/non_distributed.py index 08d1da99b74..732a23a2570 100644 --- a/benchmarks/fp8/torchao/non_distributed.py +++ b/benchmarks/fp8/torchao/non_distributed.py @@ -20,6 +20,7 @@ import evaluate import torch +from functools import partial from datasets import load_dataset from torch.optim import AdamW from torch.utils.data import DataLoader @@ -104,8 +105,6 @@ def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=No return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler - - def evaluate_model(model, dataloader, metric, accelerator=None): "Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on" model.eval() @@ -120,19 +119,31 @@ def evaluate_model(model, dataloader, metric, accelerator=None): return metric.compute() -def module_filter_func(module, *args): +def module_filter_func(module, fqn, first_layer_name=None, last_layer_name=None): if isinstance(module, torch.nn.Linear): if module.in_features % 16 != 0 or module.out_features % 16 != 0: return False - + # For stability reasons, we skip the first and last linear layers + # Otherwise can lead to the model not training or converging properly + if fqn in (first_layer_name, last_layer_name): + return False return True def train_baseline(): set_seed(42) model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) + first_linear = None + last_linear = None + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if first_linear is None: + first_linear = name + last_linear = name + + func = partial(module_filter_func, first_layer_name=first_linear, last_layer_name=last_linear) model.to("cuda") - convert_to_float8_training(model, module_filter_fn=module_filter_func) + convert_to_float8_training(model, module_filter_fn=func) base_model_results = evaluate_model(model, eval_dataloader, METRIC) model.train() From 3d34f8ec92936a91f6f475d94d06a8dfe21b0de6 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Thu, 16 Jan 2025 10:46:05 -0500 Subject: [PATCH 04/25] Currently broken --- 
benchmarks/fp8/torchao/non_distributed.py | 39 ++++- src/accelerate/accelerator.py | 113 +++++++++----- src/accelerate/utils/__init__.py | 6 + src/accelerate/utils/ao.py | 112 ++++++++++++++ src/accelerate/utils/dataclasses.py | 176 +++++++++++++--------- src/accelerate/utils/imports.py | 20 +++ 6 files changed, 351 insertions(+), 115 deletions(-) create mode 100644 src/accelerate/utils/ao.py diff --git a/benchmarks/fp8/torchao/non_distributed.py b/benchmarks/fp8/torchao/non_distributed.py index 732a23a2570..81eb0d2bc73 100644 --- a/benchmarks/fp8/torchao/non_distributed.py +++ b/benchmarks/fp8/torchao/non_distributed.py @@ -28,8 +28,7 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup from accelerate import Accelerator -from accelerate.state import AcceleratorState -from accelerate.utils import FP8RecipeKwargs, set_seed +from accelerate.utils import AORecipeKwargs, set_seed MODEL_NAME = "bert-base-cased" @@ -119,7 +118,7 @@ def evaluate_model(model, dataloader, metric, accelerator=None): return metric.compute() -def module_filter_func(module, fqn, first_layer_name=None, last_layer_name=None): +def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None): if isinstance(module, torch.nn.Linear): if module.in_features % 16 != 0 or module.out_features % 16 != 0: return False @@ -141,7 +140,7 @@ def train_baseline(): first_linear = name last_linear = name - func = partial(module_filter_func, first_layer_name=first_linear, last_layer_name=last_linear) + func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear) model.to("cuda") convert_to_float8_training(model, module_filter_fn=func) base_model_results = evaluate_model(model, eval_dataloader, METRIC) @@ -168,9 +167,37 @@ def train_baseline(): return base_model_results, trained_model_results +def train_integration(): + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) + accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()]) + model = accelerator.prepare(model) + base_model_results = evaluate_model(model, eval_dataloader, METRIC) + model.train() + + for batch in train_dataloader: + outputs = model(**batch) + loss = outputs.loss + loss.backward() + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + if __name__ == "__main__": - baseline_not_trained, baseline_trained = train_baseline() - # accelerator_not_trained, accelerator_trained = train_integration() + # baseline_not_trained, baseline_trained = train_baseline() + accelerator_not_trained, accelerator_trained = train_integration() # assert ( # baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] # ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 
a483f0d1a39..476e0814b2d 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -29,6 +29,7 @@ from types import MethodType from typing import Any, Callable, Union +from accelerate.utils.imports import is_torchao_available import torch import torch.utils.hooks as hooks from huggingface_hub import split_torch_state_dict_into_shards @@ -49,6 +50,9 @@ WEIGHTS_NAME, WEIGHTS_PATTERN_NAME, AutocastKwargs, + AORecipeKwargs, + TERecipeKwargs, + MSAMPRecipeKwargs, DataLoaderConfiguration, DeepSpeedPlugin, DistributedDataParallelKwargs, @@ -73,6 +77,7 @@ clean_state_dict_for_safetensors, compare_versions, convert_model, + convert_to_float8_training, convert_outputs_to_fp32, ensure_weights_retied, extract_model_from_parallel, @@ -409,45 +414,39 @@ def __init__( self.scaler_handler = None self.init_handler = None self.fp8_recipe_handler = None + self.ao_recipe_handler = None + self.te_recipe_handler = None + self.msamp_recipe_handler = None self.autocast_handler = None self.profile_handler = None self.has_lomo_optimizer = False + found_handlers = set() + handler_class_to_attr = { + DistributedDataParallelKwargs: "ddp_handler", + GradScalerKwargs: "scaler_handler", + InitProcessGroupKwargs: "init_handler", + FP8RecipeKwargs: "fp8_recipe_handler", + AutocastKwargs: "autocast_handler", + ProfileKwargs: "profile_handler", + AORecipeKwargs: "ao_recipe_handler", + TERecipeKwargs: "te_recipe_handler", + MSAMPRecipeKwargs: "msamp_recipe_handler", + } + self.has_fp8_handler = False if kwargs_handlers is not None: for handler in kwargs_handlers: assert isinstance( handler, KwargsHandler ), f"Unsupported kwargs handler passed: {handler}, must be one that inherits `accelerate.utils.KwargsHandler`." - if isinstance(handler, DistributedDataParallelKwargs): - if self.ddp_handler is not None: - raise ValueError("You can only pass one `DistributedDataParallelKwargs` in `kwargs_handler`.") - else: - self.ddp_handler = handler - elif isinstance(handler, GradScalerKwargs): - if self.scaler_handler is not None: - raise ValueError("You can only pass one `GradScalerKwargs` in `kwargs_handler`.") - else: - self.scaler_handler = handler - elif isinstance(handler, InitProcessGroupKwargs): - if self.init_handler is not None: - raise ValueError("You can only pass one `InitProcessGroupKwargs` in `kwargs_handler`.") - else: - self.init_handler = handler - elif isinstance(handler, FP8RecipeKwargs): - if self.fp8_recipe_handler is not None: - raise ValueError("You can only pass one `FP8RecipeKwargs` in `kwargs_handler`.") - else: - self.fp8_recipe_handler = handler - elif isinstance(handler, AutocastKwargs): - if self.autocast_handler is not None: - raise ValueError("You can only pass one `AutocastKwargs` in `kwargs_handler`.") - else: - self.autocast_handler = handler - elif isinstance(handler, ProfileKwargs): - if self.profile_handler is not None: - raise ValueError("You can only pass one `ProfileKwargs` in `kwargs_handler`.") - else: - self.profile_handler = handler + # Add the handler class to the set of found handlers + if handler.__class__ in found_handlers: + raise ValueError(f"You can only pass one {handler.__class__} in `kwargs_handlers`.") + found_handlers.add(handler.__class__) + handler_attr = handler_class_to_attr[handler.__class__] + setattr(self, handler_attr, handler) + if "recipe_handler" in handler_attr and not self.has_fp8_handler: + self.has_fp8_handler = True kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {} self.state = AcceleratorState( @@ -463,17 
+462,27 @@ def __init__( ) self._mixed_precision = mixed_precision - if mixed_precision == "fp8" and self.fp8_recipe_handler is None: - self.fp8_recipe_handler = FP8RecipeKwargs() + # Check for automatic FP8 recipe creation + if self._mixed_precision == "fp8" and not self.has_fp8_handler: + # Prioritize TE -> AO -> MSAMP + if is_torchao_available(): + self.ao_recipe_handler = AORecipeKwargs() + elif is_transformer_engine_available(): + self.te_recipe_handler = TERecipeKwargs() + elif is_msamp_available(): + self.msamp_recipe_handler = MSAMPRecipeKwargs() + else: + raise ImportError("Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed.") self.delayed_fp8_autocast = False - if self.fp8_recipe_handler is not None: + if self.has_fp8_handler: # We already check if FP8 is available during `self.state` if mixed_precision != "fp8" and ( self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED) ): - raise ValueError("Passing in a `FP8RecipeKwargs` object requires setting `mixed_precision='fp8'`.") - self.delayed_fp8_autocast = self.fp8_recipe_handler.backend == "TE" and self.distributed_type in ( + raise ValueError("Passing in an FP8 configuration requires setting `mixed_precision='fp8'`.") + # DEPRECATE once 2.0 is released + self.delayed_fp8_autocast = self.fp8_backend == "TE" and self.distributed_type in ( DistributedType.MULTI_GPU, DistributedType.FSDP, ) @@ -1362,6 +1371,8 @@ def prepare(self, *args, device_placement=None): args = self._prepare_ipex_or_xpu(*args) if self.fp8_backend == "TE": args = self._prepare_te(*args) + elif self.fp8_backend == "AO": + args = self._prepare_ao(*args) if self.distributed_type == DistributedType.DEEPSPEED: result = self._prepare_deepspeed(*args) elif self.distributed_type == DistributedType.MEGATRON_LM: @@ -1447,7 +1458,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e # We prepare TE after, allowing for bf16 autocast to happen first if self.fp8_backend == "TE" and not self.delayed_fp8_autocast: - model = apply_fp8_autowrap(model, self.fp8_recipe_handler) + model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler) if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr( model, "hf_device_map", False @@ -1651,12 +1662,19 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e model = xmp.MpModelWrapper(model).to(self.device) # Now we can apply the FP8 autocast if self.delayed_fp8_autocast: - model = apply_fp8_autowrap(model, self.fp8_recipe_handler) + model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler) # torch.compile should be called last and only if the model isn't already compiled. if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model): model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs()) return model + def _prepare_ao(self, *args): + if not is_torchao_available(): + raise ImportError("`torchao` was not found on your system. 
Please ensure that `torchao` is installed") + for model in self._models: + convert_to_float8_training(model, config=self.ao_recipe_handler.config, module_filter_func=self.ao_recipe_handler.module_filter_func) + return args + def _prepare_te(self, *args): if not is_transformer_engine_available(): raise ImportError( @@ -1811,7 +1829,7 @@ def _prepare_deepspeed(self, *args): if model is not None: # If we are using FP8, we need to apply the autowrap now - if getattr(self.fp8_recipe_handler, "backend", None) == "TE": + if self.fp8_backend == "TE": model = apply_fp8_autowrap(model, self.fp8_recipe_handler) # if the model is an MOE, set the appropriate MOE layers as leaf Z3 modules deepspeed_plugin.set_moe_leaf_modules(model) @@ -2106,7 +2124,12 @@ def _prepare_msamp(self, *args, device_placement): f"You can't use multiple models ({num_models}) or optimizers {num_optimizers} with MS-AMP." ) else: - model, optimizer = msamp.initialize(model, optimizer, opt_level=self.fp8_recipe_handler.opt_level) + # DEPRECATE @ 2.0 + if self.fp8_recipe_handler is not None: + opt_level = self.fp8_recipe_handler.opt_level + else: + opt_level = self.msamp_recipe_handler.opt_level + model, optimizer = msamp.initialize(model, optimizer, opt_level=opt_level) for i in range(len(result)): if isinstance(result[i], torch.nn.Module): result[i] = model @@ -3647,8 +3670,20 @@ def lomo_backward(self, loss: torch.Tensor, learning_rate: float) -> None: @property def fp8_backend(self): "Returns the configured backend for training in FP8" +<<<<<<< HEAD if self._mixed_precision == "fp8" and self.fp8_recipe_handler is not None: return self.fp8_recipe_handler.backend +======= + if self.has_fp8_handler: + if self.fp8_recipe_handler is not None: + return self.fp8_recipe_handler.backend + elif self.ao_recipe_handler is not None: + return "AO" + elif self.te_recipe_handler is not None: + return "TE" + elif self.msamp_recipe_handler is not None: + return "MSAMP" +>>>>>>> be210db (Currently broken) elif self.state.deepspeed_plugin is not None and self.state.deepspeed_plugin.enable_msamp: return "MSAMP" return None diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py index e0ea5841372..cad3a06f018 100644 --- a/src/accelerate/utils/__init__.py +++ b/src/accelerate/utils/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from .ao import convert_to_float8_training from .constants import ( MITA_PROFILING_AVAILABLE_PYTORCH_VERSION, MODEL_NAME, @@ -33,6 +34,7 @@ ) from .dataclasses import ( AutocastKwargs, + AORecipeKwargs, BnbQuantizationConfig, ComputeEnvironment, CustomDtype, @@ -58,6 +60,8 @@ TensorInformation, TorchDynamoPlugin, TorchTensorParallelPlugin, + TERecipeKwargs, + MSAMPRecipeKwargs, add_model_config_to_megatron_parser, ) from .environment import ( @@ -78,6 +82,7 @@ ) from .imports import ( deepspeed_required, + torchao_required, get_ccl_version, is_4bit_bnb_available, is_8bit_bnb_available, @@ -115,6 +120,7 @@ is_tensorboard_available, is_timm_available, is_torch_xla_available, + is_torchao_available, is_torchdata_available, is_torchdata_stateful_dataloader_available, is_torchvision_available, diff --git a/src/accelerate/utils/ao.py b/src/accelerate/utils/ao.py new file mode 100644 index 00000000000..1d21738c495 --- /dev/null +++ b/src/accelerate/utils/ao.py @@ -0,0 +1,112 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Needed utilities for torchao FP8 training. +""" + +from functools import partial + +import torch + +from .imports import torchao_required + + +def find_first_last_linear_layers(model: torch.nn.Module): + """ + Finds the first and last linear layer names in a model. + + This is needed during FP8 to avoid issues with + instability by keeping the first and last layers + unquantized. + + Ref: https://x.com/xariusrke/status/1826669142604141052 + """ + first_linear, last_linear = None, None + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if first_linear is None: + first_linear = name + last_linear = name + return first_linear, last_linear + + +def filter_linear_layers(module, layer_name, first_layer_name, last_layer_name) -> bool: + """ + A function which will check if `module` is: + - a `torch.nn.Linear` layer + - has in_features and out_features divisible by 16 + - is not the first or last layer of the model. + + Args: + module (`torch.nn.Module`): + The module to check. + layer_name (`str`): + The fully qualified name of the layer. + first_layer_name (`str`): + The name of the first layer of the model. + last_layer_name (`str`): + The name of the last layer of the model. + """ + if isinstance(module, torch.nn.Linear): + if module.in_features % 16 != 0 or module.out_features % 16 != 0: + return False + # For stability reasons, we skip the first and last linear layers + # Otherwise can lead to the model not training or converging properly + # TODO: apply this to all FP8 backends + if layer_name in (first_layer_name, last_layer_name): + return False + return True + + +@torchao_required +def convert_to_float8_training( + model: torch.nn.Module, + config=None, + module_filter_func=None, + ): + """ + Converts all `nn.Linear` layers in the model (except the first and last) + to torchao's `Float8Linear` layer inplace. + + Args: + model (`torch.nn.Module`): + The model to convert. + config (`torchao.float8.Float8LinearConfig`, *optional*): + The configuration for the FP8 training. Recommended to utilize + `torchao.float8.recipe_name_to_linear_config` to generate this. + In general, the default config should be sufficient. + module_filter_func (`Callable`, *optional*): + Optional function that must take in a module and layer name, + and returns a boolean indicating whether the module should be + converted to FP8. Defaults to `filter_linear_layers`. See + it for an example. 
+ + Example: + + ```python + from accelerate.utils.ao import convert_to_float8_training + model = MyModel() + model.to("cuda") + convert_to_float8_training(model) + + model.train() + ``` + """ + from torchao.float8 import convert_to_float8_training + + first_linear, last_linear = find_first_last_linear_layers(model) + if module_filter_func is None: + module_filter_func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear) + convert_to_float8_training(model, config, module_filter_func) diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 3baa525d294..aa347100d3c 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -22,13 +22,15 @@ import functools import os import warnings +import logging from contextlib import contextmanager from dataclasses import dataclass, field from datetime import timedelta -from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union, get_args +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union, get_args, TYPE_CHECKING import torch +from .ao import filter_linear_layers from .constants import ( BETA_TP_AVAILABLE_PYTORCH_VERSION, FSDP_AUTO_WRAP_POLICY, @@ -49,6 +51,12 @@ ) from .versions import compare_versions, is_torch_version +if TYPE_CHECKING: + # Mock imports for type checking + from torchao.float8 import Float8LinearConfig + +logger = logging.getLogger(__name__) + class KwargsHandler: """ @@ -281,40 +289,57 @@ def __post_init__(self): AmaxComputeAlgorithm = Literal["max", "most_recent"] +# FP8 training recipe kwargs +@dataclass +class AORecipeKwargs(KwargsHandler): + """ + Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision + training with `torchao` FP8. + + Args: + recipe_name (`str`, *optional*, default to `None`): + The name of the recipe to use for FP8 training. Should + be compatible with `torchao.float8.recipe_name_to_linear_config`. + config (`torchao.float8.Float8LinearConfig`, *optional*, default to `None`): + The configuration for the FP8 training. In general, the default config + should be sufficient. + module_filter_func (`Callable`, *optional*, default to `None`): + Optional function that must take in a module and layer name, + and returns a boolean indicating whether the module should be + converted to FP8. Defaults to `accelerate.utils.ao.filter_linear_layers`. See + it for an example. + """ + recipe_name: str = None + config: "Float8LinearConfig" = None + module_filter_func: Callable = None + + def __post_init__(self): + if self.module_filter_func is None: + self.module_filter_func = filter_linear_layers + + @dataclass -class FP8RecipeKwargs(KwargsHandler): +class TERecipeKwargs(KwargsHandler): """ Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision - training with `transformer-engine` or `ms-amp`. + training with `transformer-engine`. - For more information on `transformer-engine` args, please refer to the API + For more information on the args, please refer to the API [documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html). - For more information on the `ms-amp` args, please refer to the Optimization Level - [documentation](https://azure.github.io/MS-AMP/docs/user-tutorial/optimization-level). 
- ```python from accelerate import Accelerator - from accelerate.utils import FP8RecipeKwargs + from accelerate.utils import TERecipeKwargs - kwargs = FP8RecipeKwargs(backend="te", fp8_format="HYBRID") + kwargs = TERecipeKwargs(fp8_format="HYBRID") accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[kwargs]) ``` - To use MS-AMP as an engine, pass `backend="msamp"` and the `optimization_level`: - - ```python - kwargs = FP8RecipeKwargs(backend="msamp", optimization_level="02") - ``` - Args: - backend (`str`, *optional*): - Which FP8 engine to use. Must be one of `"msamp"` (MS-AMP) or `"te"` (TransformerEngine). If not passed, - will use whichever is available in the environment, prioritizing MS-AMP. use_autocast_during_eval (`bool`, *optional*, default to `False`): Whether to use FP8 autocast during eval mode. Generally better metrics are found when this is `False`. margin (`int`, *optional*, default to 0): @@ -330,21 +355,8 @@ class FP8RecipeKwargs(KwargsHandler): The algorithm to use for the scaling factor computation. Must be one of `max` or `most_recent`. override_linear_precision (`tuple` of three `bool`, *optional*, default to `(False, False, False)`): Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision. - optimization_level (`str`), one of `O1`, `O2`. (default is `O2`): - What level of 8-bit collective communication should be used with MS-AMP. In general: - * O1: Weight gradients and `all_reduce` communications are done in fp8, reducing GPU - memory usage and communication bandwidth - * O2: First-order optimizer states are in 8-bit, and second order states are in FP16. - Only available when using Adam or AdamW. This maintains accuracy and can potentially save the - highest memory. - * 03: Specifically for DeepSpeed, implements capabilities so weights and master weights of models - are stored in FP8. If `fp8` is selected and deepspeed is enabled, will be used by default. (Not - available currently). """ - - backend: Backend = None use_autocast_during_eval: bool = None - opt_level: OptLevel = None margin: int = None interval: int = None fp8_format: FP8Format = None @@ -354,50 +366,74 @@ class FP8RecipeKwargs(KwargsHandler): def __post_init__(self): env_prefix = "ACCELERATE_FP8_" + if not is_transformer_engine_available(): + raise ImportError( + "TransformerEngine is not available. Please install it or use a different backend." 
+ ) + if self.use_autocast_during_eval is None: + self.use_autocast_during_eval = parse_flag_from_env(env_prefix + "USE_AUTOCAST_DURING_EVAL") + if self.margin is None: + self.margin = int(os.environ.get(env_prefix + "MARGIN", 0)) + if self.interval is None: + self.interval = int(os.environ.get(env_prefix + "INTERVAL", 1)) + if self.fp8_format is None: + self.fp8_format = os.environ.get(env_prefix + "FORMAT", "HYBRID") + self.fp8_format = self.fp8_format.upper() + if self.fp8_format not in get_args(FP8Format): + raise ValueError(f"`fp8_format` must be one of {' or '.join(get_args(FP8Format))}.") + if self.amax_compute_algo is None: + self.amax_compute_algo = os.environ.get(env_prefix + "AMAX_COMPUTE_ALGO", "most_recent") + self.amax_compute_algo = self.amax_compute_algo.lower() + if self.amax_compute_algo not in get_args(AmaxComputeAlgorithm): + raise ValueError(f"`amax_compute_algo` must be one of {' or '.join(get_args(AmaxComputeAlgorithm))}") + if self.amax_history_len is None: + self.amax_history_len = int(os.environ.get(env_prefix + "AMAX_HISTORY_LEN", 1024)) + if self.override_linear_precision is None: + fprop = parse_flag_from_env(env_prefix + "OVERRIDE_FPROP") + dgrad = parse_flag_from_env(env_prefix + "OVERRIDE_DGRAD") + wgrad = parse_flag_from_env(env_prefix + "OVERRIDE_WGRAD") + self.override_linear_precision = (fprop, dgrad, wgrad) + + +@dataclass +class MSAMPRecipeKwargs(KwargsHandler): + """ + Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision + training with `ms-amp`. + """ + opt_level: OptLevel = None + + def __post_init__(self): + env_prefix = "ACCELERATE_FP8_" + if self.opt_level is None: + self.opt_level = os.environ.get(env_prefix + "OPT_LEVEL", "O2") + if self.opt_level not in get_args(OptLevel): + raise ValueError(f"`opt_level` must be one of {' or '.join(get_args(OptLevel))}") + + +@dataclass +class FP8RecipeKwargs(TERecipeKwargs, MSAMPRecipeKwargs): + """ + Deprecated. Please use one of the proper FP8 recipe + kwargs classes such as `TERecipeKwargs` or `MSAMPRecipeKwargs` + instead. + """ + + backend: Backend = None + + def __post_init__(self): + env_prefix = "ACCELERATE_FP8_" + warnings.warn( + "FP8RecipeKwargs is deprecated and will be removed in Accelerate v2.0.0. " + "Please use one of the proper FP8 recipe kwargs classes such as TERecipeKwargs or MSAMPRecipeKwargs instead.", + FutureWarning, + ) default_backend = "msamp" if is_msamp_available() else "te" if self.backend is None: self.backend = os.environ.get(env_prefix + "BACKEND", default_backend) self.backend = self.backend.upper() if self.backend not in get_args(Backend): - raise ValueError("`backend` must be 'MSAMP' or 'TE' (TransformerEngine).") - # Check TE args - if self.backend == "TE": - if not is_transformer_engine_available(): - raise ValueError( - "TransformerEngine is not available. Please either install it, or use the 'MSAMP' backend (if installed)." 
- ) - if self.use_autocast_during_eval is None: - self.use_autocast_during_eval = parse_flag_from_env(env_prefix + "USE_AUTOCAST_DURING_EVAL") - if self.margin is None: - self.margin = int(os.environ.get(env_prefix + "MARGIN", 0)) - if self.interval is None: - self.interval = int(os.environ.get(env_prefix + "INTERVAL", 1)) - if self.fp8_format is None: - self.fp8_format = os.environ.get(env_prefix + "FORMAT", "HYBRID") - self.fp8_format = self.fp8_format.upper() - if self.fp8_format not in get_args(FP8Format): - raise ValueError(f"`fp8_format` must be one of {' or '.join(get_args(FP8Format))}.") - if self.amax_compute_algo is None: - self.amax_compute_algo = os.environ.get(env_prefix + "AMAX_COMPUTE_ALGO", "most_recent") - self.amax_compute_algo = self.amax_compute_algo.lower() - if self.amax_compute_algo not in get_args(AmaxComputeAlgorithm): - raise ValueError(f"`amax_compute_algo` must be one of {' or '.join(get_args(AmaxComputeAlgorithm))}") - if self.amax_history_len is None: - self.amax_history_len = int(os.environ.get(env_prefix + "AMAX_HISTORY_LEN", 1024)) - if self.override_linear_precision is None: - fprop = parse_flag_from_env(env_prefix + "OVERRIDE_FPROP") - dgrad = parse_flag_from_env(env_prefix + "OVERRIDE_DGRAD") - wgrad = parse_flag_from_env(env_prefix + "OVERRIDE_WGRAD") - self.override_linear_precision = (fprop, dgrad, wgrad) - elif self.backend == "MSAMP": - if not is_msamp_available(): - raise ValueError( - "MS-AMP is not available. Please either install it, or use the 'TE' backend (if installed)." - ) - if self.opt_level is None: - self.opt_level = os.environ.get(env_prefix + "OPT_LEVEL", "O2") - if self.opt_level not in get_args(OptLevel): - raise ValueError(f"`optimization_level` must be one of {' or '.join(get_args(OptLevel))}") + raise ValueError("`backend` must be 'MSAMP' or 'TE' (TransformerEngine) to use `FP8RecipeKwargs`.") # Literal diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py index b271dab9a9a..38daacf5831 100644 --- a/src/accelerate/utils/imports.py +++ b/src/accelerate/utils/imports.py @@ -142,6 +142,10 @@ def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False): return True +def is_torchao_available(): + return _is_package_available("torchao") + + def is_deepspeed_available(): if is_mlu_available(): return _is_package_available("deepspeed", metadata_name="deepspeed-mlu") @@ -422,6 +426,22 @@ def is_torchdata_stateful_dataloader_available(): return False +def torchao_required(func): + """ + A decorator that ensures the decorated function is only called when torchao is available. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + if not is_torchao_available(): + raise ImportError( + "`torchao` is not available, please install it before calling this function via `pip install torchao`." 
+ ) + return func(*args, **kwargs) + + return wrapper + + # TODO: Rework this into `utils.deepspeed` and migrate the "core" chunks into `accelerate.deepspeed` def deepspeed_required(func): """ From a032f1a5b1b11efaf0785467b255ccfa3d9c47d1 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Thu, 16 Jan 2025 10:55:29 -0500 Subject: [PATCH 05/25] Clean --- benchmarks/fp8/torchao/non_distributed.py | 35 +++++++++++++---------- src/accelerate/accelerator.py | 24 ++++++++++------ src/accelerate/utils/__init__.py | 7 +++-- src/accelerate/utils/ao.py | 28 ++++++++---------- src/accelerate/utils/dataclasses.py | 34 +++++++++------------- src/accelerate/utils/imports.py | 2 +- 6 files changed, 67 insertions(+), 63 deletions(-) diff --git a/benchmarks/fp8/torchao/non_distributed.py b/benchmarks/fp8/torchao/non_distributed.py index 81eb0d2bc73..e2426f162d5 100644 --- a/benchmarks/fp8/torchao/non_distributed.py +++ b/benchmarks/fp8/torchao/non_distributed.py @@ -18,9 +18,10 @@ This particular script verifies this for single GPU training. """ +from functools import partial + import evaluate import torch -from functools import partial from datasets import load_dataset from torch.optim import AdamW from torch.utils.data import DataLoader @@ -28,6 +29,7 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup from accelerate import Accelerator +from accelerate.state import AcceleratorState from accelerate.utils import AORecipeKwargs, set_seed @@ -169,8 +171,10 @@ def train_baseline(): def train_integration(): set_seed(42) - model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()]) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( + MODEL_NAME, accelerator=accelerator + ) model = accelerator.prepare(model) base_model_results = evaluate_model(model, eval_dataloader, METRIC) model.train() @@ -196,17 +200,18 @@ def train_integration(): if __name__ == "__main__": - # baseline_not_trained, baseline_trained = train_baseline() + baseline_not_trained, baseline_trained = train_baseline() + AcceleratorState._reset_state(True) accelerator_not_trained, accelerator_trained = train_integration() - # assert ( - # baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] - # ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' - # assert ( - # baseline_not_trained["f1"] == accelerator_not_trained["f1"] - # ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' - # assert ( - # baseline_trained["accuracy"] == accelerator_trained["accuracy"] - # ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' - # assert ( - # baseline_trained["f1"] == accelerator_trained["f1"] - # ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' + assert ( + baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + assert ( + baseline_not_trained["f1"] == 
accelerator_not_trained["f1"] + ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + assert ( + baseline_trained["accuracy"] == accelerator_trained["accuracy"] + ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 476e0814b2d..3544fc67fea 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -29,11 +29,12 @@ from types import MethodType from typing import Any, Callable, Union -from accelerate.utils.imports import is_torchao_available import torch import torch.utils.hooks as hooks from huggingface_hub import split_torch_state_dict_into_shards +from accelerate.utils.imports import is_torchao_available + from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches from .logging import get_logger @@ -49,10 +50,8 @@ WEIGHTS_INDEX_NAME, WEIGHTS_NAME, WEIGHTS_PATTERN_NAME, - AutocastKwargs, AORecipeKwargs, - TERecipeKwargs, - MSAMPRecipeKwargs, + AutocastKwargs, DataLoaderConfiguration, DeepSpeedPlugin, DistributedDataParallelKwargs, @@ -66,10 +65,12 @@ KwargsHandler, LoggerType, MegatronLMPlugin, + MSAMPRecipeKwargs, PrecisionType, ProfileKwargs, ProjectConfiguration, RNGType, + TERecipeKwargs, TorchDynamoPlugin, TorchTensorParallelPlugin, apply_fp8_autowrap, @@ -77,8 +78,8 @@ clean_state_dict_for_safetensors, compare_versions, convert_model, - convert_to_float8_training, convert_outputs_to_fp32, + convert_to_float8_training, ensure_weights_retied, extract_model_from_parallel, gather, @@ -472,7 +473,9 @@ def __init__( elif is_msamp_available(): self.msamp_recipe_handler = MSAMPRecipeKwargs() else: - raise ImportError("Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed.") + raise ImportError( + "Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed." + ) self.delayed_fp8_autocast = False if self.has_fp8_handler: @@ -1671,8 +1674,13 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e def _prepare_ao(self, *args): if not is_torchao_available(): raise ImportError("`torchao` was not found on your system. 
Please ensure that `torchao` is installed") - for model in self._models: - convert_to_float8_training(model, config=self.ao_recipe_handler.config, module_filter_func=self.ao_recipe_handler.module_filter_func) + for arg in args: + if isinstance(arg, torch.nn.Module): + convert_to_float8_training( + arg, + config=self.ao_recipe_handler.config, + module_filter_func=self.ao_recipe_handler.module_filter_func, + ) return args def _prepare_te(self, *args): diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py index cad3a06f018..8e760ccfb54 100644 --- a/src/accelerate/utils/__init__.py +++ b/src/accelerate/utils/__init__.py @@ -33,8 +33,8 @@ XPU_PROFILING_AVAILABLE_PYTORCH_VERSION, ) from .dataclasses import ( - AutocastKwargs, AORecipeKwargs, + AutocastKwargs, BnbQuantizationConfig, ComputeEnvironment, CustomDtype, @@ -52,6 +52,7 @@ KwargsHandler, LoggerType, MegatronLMPlugin, + MSAMPRecipeKwargs, PrecisionType, ProfileKwargs, ProjectConfiguration, @@ -61,7 +62,7 @@ TorchDynamoPlugin, TorchTensorParallelPlugin, TERecipeKwargs, - MSAMPRecipeKwargs, + TorchDynamoPlugin, add_model_config_to_megatron_parser, ) from .environment import ( @@ -82,7 +83,6 @@ ) from .imports import ( deepspeed_required, - torchao_required, get_ccl_version, is_4bit_bnb_available, is_8bit_bnb_available, @@ -130,6 +130,7 @@ is_wandb_available, is_weights_only_available, is_xpu_available, + torchao_required, ) from .modeling import ( align_module_device, diff --git a/src/accelerate/utils/ao.py b/src/accelerate/utils/ao.py index 1d21738c495..e0a2cf93d73 100644 --- a/src/accelerate/utils/ao.py +++ b/src/accelerate/utils/ao.py @@ -27,9 +27,7 @@ def find_first_last_linear_layers(model: torch.nn.Module): """ Finds the first and last linear layer names in a model. - This is needed during FP8 to avoid issues with - instability by keeping the first and last layers - unquantized. + This is needed during FP8 to avoid issues with instability by keeping the first and last layers unquantized. Ref: https://x.com/xariusrke/status/1826669142604141052 """ @@ -72,31 +70,29 @@ def filter_linear_layers(module, layer_name, first_layer_name, last_layer_name) @torchao_required def convert_to_float8_training( - model: torch.nn.Module, - config=None, - module_filter_func=None, - ): + model: torch.nn.Module, + config=None, + module_filter_func=None, +): """ - Converts all `nn.Linear` layers in the model (except the first and last) - to torchao's `Float8Linear` layer inplace. + Converts all `nn.Linear` layers in the model (except the first and last) to torchao's `Float8Linear` layer inplace. Args: model (`torch.nn.Module`): The model to convert. config (`torchao.float8.Float8LinearConfig`, *optional*): The configuration for the FP8 training. Recommended to utilize - `torchao.float8.recipe_name_to_linear_config` to generate this. - In general, the default config should be sufficient. + `torchao.float8.recipe_name_to_linear_config` to generate this. In general, the default config should be + sufficient. module_filter_func (`Callable`, *optional*): - Optional function that must take in a module and layer name, - and returns a boolean indicating whether the module should be - converted to FP8. Defaults to `filter_linear_layers`. See - it for an example. + Optional function that must take in a module and layer name, and returns a boolean indicating whether the + module should be converted to FP8. Defaults to `filter_linear_layers`. See it for an example. 
Example: ```python from accelerate.utils.ao import convert_to_float8_training + model = MyModel() model.to("cuda") convert_to_float8_training(model) @@ -109,4 +105,4 @@ def convert_to_float8_training( first_linear, last_linear = find_first_last_linear_layers(model) if module_filter_func is None: module_filter_func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear) - convert_to_float8_training(model, config, module_filter_func) + convert_to_float8_training(model, module_filter_fn=module_filter_func, config=config) diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index aa347100d3c..0490ca72085 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -20,17 +20,16 @@ import copy import enum import functools +import logging import os import warnings -import logging from contextlib import contextmanager from dataclasses import dataclass, field from datetime import timedelta -from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union, get_args, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union, get_args import torch -from .ao import filter_linear_layers from .constants import ( BETA_TP_AVAILABLE_PYTORCH_VERSION, FSDP_AUTO_WRAP_POLICY, @@ -51,6 +50,7 @@ ) from .versions import compare_versions, is_torch_version + if TYPE_CHECKING: # Mock imports for type checking from torchao.float8 import Float8LinearConfig @@ -298,25 +298,20 @@ class AORecipeKwargs(KwargsHandler): Args: recipe_name (`str`, *optional*, default to `None`): - The name of the recipe to use for FP8 training. Should - be compatible with `torchao.float8.recipe_name_to_linear_config`. + The name of the recipe to use for FP8 training. Should be compatible with + `torchao.float8.recipe_name_to_linear_config`. config (`torchao.float8.Float8LinearConfig`, *optional*, default to `None`): - The configuration for the FP8 training. In general, the default config - should be sufficient. + The configuration for the FP8 training. In general, the default config should be sufficient. module_filter_func (`Callable`, *optional*, default to `None`): - Optional function that must take in a module and layer name, - and returns a boolean indicating whether the module should be - converted to FP8. Defaults to `accelerate.utils.ao.filter_linear_layers`. See - it for an example. + Optional function that must take in a module and layer name, and returns a boolean indicating whether the + module should be converted to FP8. Defaults to `accelerate.utils.ao.filter_linear_layers`. See it for an + example. """ + recipe_name: str = None config: "Float8LinearConfig" = None module_filter_func: Callable = None - def __post_init__(self): - if self.module_filter_func is None: - self.module_filter_func = filter_linear_layers - @dataclass class TERecipeKwargs(KwargsHandler): @@ -356,6 +351,7 @@ class TERecipeKwargs(KwargsHandler): override_linear_precision (`tuple` of three `bool`, *optional*, default to `(False, False, False)`): Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision. """ + use_autocast_during_eval: bool = None margin: int = None interval: int = None @@ -367,9 +363,7 @@ class TERecipeKwargs(KwargsHandler): def __post_init__(self): env_prefix = "ACCELERATE_FP8_" if not is_transformer_engine_available(): - raise ImportError( - "TransformerEngine is not available. Please install it or use a different backend." 
- ) + raise ImportError("TransformerEngine is not available. Please install it or use a different backend.") if self.use_autocast_during_eval is None: self.use_autocast_during_eval = parse_flag_from_env(env_prefix + "USE_AUTOCAST_DURING_EVAL") if self.margin is None: @@ -401,6 +395,7 @@ class MSAMPRecipeKwargs(KwargsHandler): Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision training with `ms-amp`. """ + opt_level: OptLevel = None def __post_init__(self): @@ -414,8 +409,7 @@ def __post_init__(self): @dataclass class FP8RecipeKwargs(TERecipeKwargs, MSAMPRecipeKwargs): """ - Deprecated. Please use one of the proper FP8 recipe - kwargs classes such as `TERecipeKwargs` or `MSAMPRecipeKwargs` + Deprecated. Please use one of the proper FP8 recipe kwargs classes such as `TERecipeKwargs` or `MSAMPRecipeKwargs` instead. """ diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py index 38daacf5831..07f53e4fb91 100644 --- a/src/accelerate/utils/imports.py +++ b/src/accelerate/utils/imports.py @@ -110,7 +110,7 @@ def is_lomo_available(): def is_fp8_available(): - return is_msamp_available() or is_transformer_engine_available() + return is_msamp_available() or is_transformer_engine_available() or is_torchao_available() def is_cuda_available(): From 7197943672b1acee57b45a456f785972b033b37f Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Fri, 17 Jan 2025 08:57:08 -0500 Subject: [PATCH 06/25] DDP varient working --- benchmarks/fp8/torchao/ddp.py | 126 +++++++++++++++++++++++----------- 1 file changed, 85 insertions(+), 41 deletions(-) diff --git a/benchmarks/fp8/torchao/ddp.py b/benchmarks/fp8/torchao/ddp.py index ba708a27be4..873f13918c5 100644 --- a/benchmarks/fp8/torchao/ddp.py +++ b/benchmarks/fp8/torchao/ddp.py @@ -13,69 +13,116 @@ # limitations under the License. """ -This script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`. +This script tests to ensure that `accelerate` performs at the same level as raw `torchao`. This particular script verifies this for DDP training. 
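As a quick orientation, the `accelerate` integration benchmarked here reduces to passing `AORecipeKwargs` when constructing the `Accelerator`; the following minimal sketch (assuming a multi-GPU `accelerate launch` and an installed `torchao`) mirrors what `train_integration` below does:

```python
# Minimal sketch of the integration path (not the raw-torchao baseline).
# AORecipeKwargs() falls back to torchao's default Float8LinearConfig.
from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs

accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()])
# model, optimizer, dataloaders and scheduler are then passed through accelerator.prepare(...)
```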
""" +from functools import partial + import evaluate import torch -import transformer_engine.common.recipe as te_recipe -import transformer_engine.pytorch as te -from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities +from datasets import load_dataset from torch.nn.parallel import DistributedDataParallel as DDP -from transformer_engine.common.recipe import DelayedScaling +from torch.optim import AdamW +from torch.utils.data import DataLoader +from torchao.float8 import convert_to_float8_training +from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup from accelerate import Accelerator from accelerate.state import AcceleratorState -from accelerate.utils import FP8RecipeKwargs, set_seed -from accelerate.utils.transformer_engine import convert_model +from accelerate.utils import AORecipeKwargs, set_seed + +from fp8_utils import get_dataloaders MODEL_NAME = "bert-base-cased" METRIC = evaluate.load("glue", "mrpc") +def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None): + """ + Returns a tuple of: + - Model + - Optimizer + - Train dataloader (prepared) + - Eval dataloader (prepared) + - LR Scheduler + Suitable for training on the MRPC dataset + """ + + if accelerator is None: + accelerator = Accelerator() + model = AutoModelForSequenceClassification.from_pretrained(model_name) + train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size) + optimizer = AdamW(model.parameters(), lr=0.0001) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=100, + num_training_steps=len(train_dataloader) * 2, + ) + train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader) + return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + + +def evaluate_model(model, dataloader, metric, accelerator=None): + "Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on" + model.eval() + for step, batch in enumerate(dataloader): + with torch.no_grad(): + outputs = model(**batch) + predictions = outputs.logits.argmax(dim=-1) + references = batch["labels"] + if accelerator is not None and accelerator.num_processes > 1: + predictions, references = accelerator.gather_for_metrics((predictions, references)) + metric.add_batch(predictions=predictions, references=references) + return metric.compute() + + +def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None): + if isinstance(module, torch.nn.Linear): + if module.in_features % 16 != 0 or module.out_features % 16 != 0: + return False + # For stability reasons, we skip the first and last linear layers + # Otherwise can lead to the model not training or converging properly + if fqn in (first_layer_name, last_layer_name): + return False + return True + + def train_baseline(): set_seed(42) model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) + first_linear = None + last_linear = None + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if first_linear is None: + first_linear = name + last_linear = name + func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear) accelerator = Accelerator() device = accelerator.device model.to(device) - # Convert the model to TE - old_named_params = get_named_parameters(model) - - with torch.no_grad(): - convert_model(model) - - FP8_RECIPE_KWARGS = {"fp8_format": 
te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"} - fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS) - - new_named_params = get_named_parameters(model) + convert_to_float8_training(model, module_filter_fn=func) # Convert the model to DDP device_ids, output_device = [accelerator.local_process_index], accelerator.local_process_index model = DDP(model, device_ids=device_ids, output_device=output_device) - mapping = {p: new_named_params[n] for n, p in old_named_params.items()} - for param_group in optimizer.param_groups: - param_group["params"] = [mapping[p] for p in param_group["params"]] - base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.train() - for _ in range(2): - for batch in train_dataloader: - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - batch = batch.to(device) - outputs = model(**batch) + for batch in train_dataloader: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch = batch.to(device) + outputs = model(**batch) loss = outputs.loss loss.backward() - optimizer.step() - optimizer.zero_grad() - lr_scheduler.step() + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) @@ -90,10 +137,8 @@ def train_baseline(): def train_integration(): - FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} - kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)] AcceleratorState()._reset_state(True) - accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers) + accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()]) set_seed(42) model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( MODEL_NAME, accelerator=accelerator @@ -103,14 +148,13 @@ def train_integration(): base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.train() - for _ in range(2): - for batch in train_dataloader: - outputs = model(**batch) - loss = outputs.loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - lr_scheduler.step() + for batch in train_dataloader: + outputs = model(**batch) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) From dc797fdeb325504da42ac7dd5905d4fb70c1fb08 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Fri, 17 Jan 2025 09:03:49 -0500 Subject: [PATCH 07/25] FSDP as well --- benchmarks/fp8/torchao/ddp.py | 9 +-- benchmarks/fp8/torchao/fsdp.py | 127 +++++++++++++++++++++------------ 2 files changed, 86 insertions(+), 50 deletions(-) diff --git a/benchmarks/fp8/torchao/ddp.py b/benchmarks/fp8/torchao/ddp.py index 873f13918c5..0b7e6071ac2 100644 --- a/benchmarks/fp8/torchao/ddp.py +++ b/benchmarks/fp8/torchao/ddp.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
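For readers skimming the diff, the baseline path in `ddp.py` applies the torchao conversion before wrapping the model in DDP; a condensed, hedged sketch of that sequence (reusing the `model`, `func` filter, and `accelerator` defined in the script above) is:

```python
# Condensed sketch of the raw-torchao DDP baseline from ddp.py: convert Linear
# layers to Float8Linear in place, then wrap the converted model with DDP.
from torch.nn.parallel import DistributedDataParallel as DDP
from torchao.float8 import convert_to_float8_training

convert_to_float8_training(model, module_filter_fn=func)  # `func` skips the first/last Linear layers
model = DDP(
    model,
    device_ids=[accelerator.local_process_index],
    output_device=accelerator.local_process_index,
)
```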
@@ -22,19 +22,16 @@ import evaluate import torch -from datasets import load_dataset +from fp8_utils import get_dataloaders from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import AdamW -from torch.utils.data import DataLoader from torchao.float8 import convert_to_float8_training -from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup +from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup from accelerate import Accelerator from accelerate.state import AcceleratorState from accelerate.utils import AORecipeKwargs, set_seed -from fp8_utils import get_dataloaders - MODEL_NAME = "bert-base-cased" METRIC = evaluate.load("glue", "mrpc") diff --git a/benchmarks/fp8/torchao/fsdp.py b/benchmarks/fp8/torchao/fsdp.py index 418122185e1..a047f27bd86 100644 --- a/benchmarks/fp8/torchao/fsdp.py +++ b/benchmarks/fp8/torchao/fsdp.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ # limitations under the License. """ -This script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`. +This script tests to ensure that `accelerate` performs at the same level as raw `torchao`. This particular script verifies this for FSDP training. """ @@ -22,20 +22,19 @@ import evaluate import torch -import transformer_engine.common.recipe as te_recipe -import transformer_engine.pytorch as te -from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities +from fp8_utils import get_dataloaders from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import MixedPrecision from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy -from transformer_engine.common.recipe import DelayedScaling +from torch.optim import AdamW +from torchao.float8 import convert_to_float8_training +from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup from transformers.models.bert import BertLayer from accelerate import Accelerator from accelerate import FullyShardedDataParallelPlugin as FSDPPlugin from accelerate.state import AcceleratorState -from accelerate.utils import FP8RecipeKwargs, set_seed -from accelerate.utils.transformer_engine import convert_model +from accelerate.utils import AORecipeKwargs, set_seed MODEL_NAME = "bert-base-cased" @@ -44,23 +43,72 @@ FSDP_WRAP_POLICY = partial(transformer_auto_wrap_policy, transformer_layer_cls={BertLayer}) +def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None): + """ + Returns a tuple of: + - Model + - Optimizer + - Train dataloader (prepared) + - Eval dataloader (prepared) + - LR Scheduler + Suitable for training on the MRPC dataset + """ + + if accelerator is None: + accelerator = Accelerator() + model = AutoModelForSequenceClassification.from_pretrained(model_name) + train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size) + optimizer = AdamW(model.parameters(), lr=0.0001) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=100, + num_training_steps=len(train_dataloader) * 2, + ) + train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader) + return model, optimizer, 
train_dataloader, eval_dataloader, lr_scheduler + + +def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None): + if isinstance(module, torch.nn.Linear): + if module.in_features % 16 != 0 or module.out_features % 16 != 0: + return False + # For stability reasons, we skip the first and last linear layers + # Otherwise can lead to the model not training or converging properly + if fqn in (first_layer_name, last_layer_name): + return False + return True + + +def evaluate_model(model, dataloader, metric, accelerator=None): + "Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on" + model.eval() + for step, batch in enumerate(dataloader): + with torch.no_grad(): + outputs = model(**batch) + predictions = outputs.logits.argmax(dim=-1) + references = batch["labels"] + if accelerator is not None and accelerator.num_processes > 1: + predictions, references = accelerator.gather_for_metrics((predictions, references)) + metric.add_batch(predictions=predictions, references=references) + return metric.compute() + + def train_baseline(): set_seed(42) model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) + first_linear = None + last_linear = None + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if first_linear is None: + first_linear = name + last_linear = name + func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear) accelerator = Accelerator() device = accelerator.device model.to(device) - # Convert the model to TE - old_named_params = get_named_parameters(model) - - with torch.no_grad(): - convert_model(model) - - FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"} - fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS) - - new_named_params = get_named_parameters(model) + convert_to_float8_training(model, module_filter_fn=func) # Convert the model to FSDP model = FSDP( @@ -70,24 +118,18 @@ def train_baseline(): auto_wrap_policy=FSDP_WRAP_POLICY, ) - mapping = {p: new_named_params[n] for n, p in old_named_params.items()} - for param_group in optimizer.param_groups: - param_group["params"] = [mapping[p] for p in param_group["params"]] - base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.train() - for _ in range(2): - for batch in train_dataloader: - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - batch = batch.to(device) - outputs = model(**batch) - loss = outputs.loss - loss.backward() - optimizer.step() - optimizer.zero_grad() - lr_scheduler.step() + for batch in train_dataloader: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch = batch.to(device) + outputs = model(**batch) + loss = outputs.loss + loss.backward() + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) @@ -102,15 +144,13 @@ def train_baseline(): def train_integration(): - FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} - kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)] AcceleratorState()._reset_state(True) fsdp_plugin = FSDPPlugin( auto_wrap_policy=FSDP_WRAP_POLICY, use_orig_params=True, mixed_precision_policy=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), 
) - accelerator = Accelerator(mixed_precision="fp8", fsdp_plugin=fsdp_plugin, kwargs_handlers=kwargs_handlers) + accelerator = Accelerator(mixed_precision="fp8", fsdp_plugin=fsdp_plugin, kwargs_handlers=[AORecipeKwargs()]) set_seed(42) model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( MODEL_NAME, accelerator=accelerator @@ -120,14 +160,13 @@ def train_integration(): base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.train() - for _ in range(2): - for batch in train_dataloader: - outputs = model(**batch) - loss = outputs.loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - lr_scheduler.step() + for batch in train_dataloader: + outputs = model(**batch) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) From 145dec2cc8544fe8299a733f0a520a8654f545c7 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Fri, 17 Jan 2025 09:34:33 -0500 Subject: [PATCH 08/25] Works for all but zero3 --- benchmarks/fp8/torchao/distrib_deepspeed.py | 121 ++++++++++---------- 1 file changed, 63 insertions(+), 58 deletions(-) diff --git a/benchmarks/fp8/torchao/distrib_deepspeed.py b/benchmarks/fp8/torchao/distrib_deepspeed.py index e678deb3659..d8019524a10 100644 --- a/benchmarks/fp8/torchao/distrib_deepspeed.py +++ b/benchmarks/fp8/torchao/distrib_deepspeed.py @@ -13,31 +13,40 @@ # limitations under the License. """ -This script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`. +This script tests to ensure that `accelerate` performs at the same level as raw `torchao`. -This particular script verifies this for DDP training. +This particular script verifies this for deepspeed training. 
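The DeepSpeed variant follows the same pattern, just with a `DeepSpeedPlugin` attached; a minimal sketch of the integration setup used by `train_integration` below (ZeRO stage is a parameter, 1-3):

```python
# Sketch of the accelerate + DeepSpeed FP8 setup benchmarked here; assumes
# launch via `accelerate launch distrib_deepspeed.py` and an installed torchao.
from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.utils import AORecipeKwargs

zero_stage = 1  # 1, 2, or 3
deepspeed_plugin = DeepSpeedPlugin(zero_stage=zero_stage, zero3_init_flag=zero_stage == 3)
accelerator = Accelerator(
    mixed_precision="fp8",
    kwargs_handlers=[AORecipeKwargs()],
    deepspeed_plugin=deepspeed_plugin,
)
```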
""" from unittest.mock import patch +from functools import partial import deepspeed import evaluate import torch -import transformer_engine.common.recipe as te_recipe -import transformer_engine.pytorch as te from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities -from transformer_engine.common.recipe import DelayedScaling from accelerate import Accelerator, DeepSpeedPlugin from accelerate.state import AcceleratorState -from accelerate.utils import FP8RecipeKwargs, set_seed -from accelerate.utils.transformer_engine import convert_model +from accelerate.utils import AORecipeKwargs, set_seed +from torchao.float8 import convert_to_float8_training MODEL_NAME = "bert-base-cased" METRIC = evaluate.load("glue", "mrpc") +def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None): + if isinstance(module, torch.nn.Linear): + if module.in_features % 16 != 0 or module.out_features % 16 != 0: + return False + # For stability reasons, we skip the first and last linear layers + # Otherwise can lead to the model not training or converging properly + if fqn in (first_layer_name, last_layer_name): + return False + return True + + def train_baseline(zero_stage: int = 1): # This forces transformers to think Zero-3 Init should be used with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock: @@ -49,19 +58,17 @@ def train_baseline(zero_stage: int = 1): MODEL_NAME, accelerator=accelerator ) - # Convert the model to TE - old_named_params = get_named_parameters(model) - - with torch.no_grad(): - convert_model(model) - new_named_params = get_named_parameters(model) - - mapping = {p: new_named_params[n] for n, p in old_named_params.items()} - for param_group in optimizer.param_groups: - param_group["params"] = [mapping[p] for p in param_group["params"]] + first_linear = None + last_linear = None + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if first_linear is None: + first_linear = name + last_linear = name + func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear) + convert_to_float8_training(model, module_filter_fn=func) - FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"} - fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS) + accelerator = Accelerator() import numpy as np @@ -99,17 +106,15 @@ def train_baseline(zero_stage: int = 1): model_outputs = [] data = [] - for _ in range(2): - for batch in train_dataloader: - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - outputs = model(**batch) - data.append(batch.to("cpu")) - model_outputs.append(outputs.logits.to("cpu")) - loss = outputs.loss - model.backward(loss) - model.step() - for _ in range(accelerator.num_processes): - lr_scheduler.step() + for batch in train_dataloader: + outputs = model(**batch) + data.append(batch.to("cpu")) + model_outputs.append(outputs.logits.to("cpu")) + loss = outputs.loss + model.backward(loss) + model.step() + for _ in range(accelerator.num_processes): + lr_scheduler.step() trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.destroy() @@ -125,15 +130,16 @@ def train_baseline(zero_stage: int = 1): def train_integration(zero_stage: int = 1): set_seed(42) - FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} - kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)] 
AcceleratorState()._reset_state(True) + # This forces transformers to think Zero-3 Init should be used + with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock: + mock.return_value = zero_stage == 3 deepspeed_plugin = DeepSpeedPlugin( zero_stage=zero_stage, zero3_init_flag=zero_stage == 3, ) accelerator = Accelerator( - mixed_precision="fp8", kwargs_handlers=kwargs_handlers, deepspeed_plugin=deepspeed_plugin + mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()], deepspeed_plugin=deepspeed_plugin ) accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 16 @@ -146,16 +152,15 @@ def train_integration(zero_stage: int = 1): model.train() model_outputs = [] data = [] - for _ in range(2): - for batch in train_dataloader: - outputs = model(**batch) - data.append(batch.to("cpu")) - model_outputs.append(outputs.logits.to("cpu")) - loss = outputs.loss - accelerator.backward(loss) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() + for batch in train_dataloader: + outputs = model(**batch) + data.append(batch.to("cpu")) + model_outputs.append(outputs.logits.to("cpu")) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.destroy() @@ -171,20 +176,20 @@ def train_integration(zero_stage: int = 1): if __name__ == "__main__": # for zero_stage in [1, 2, 3]: - zero_stage = 1 - baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage) - accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage) - assert ( - baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] - ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' - assert ( - baseline_not_trained["f1"] == accelerator_not_trained["f1"] - ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' - assert ( - baseline_trained["accuracy"] == accelerator_trained["accuracy"] - ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' - assert ( - baseline_trained["f1"] == accelerator_trained["f1"] - ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' + for zero_stage in [3]: + baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage) + accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage) + assert ( + baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + assert ( + baseline_not_trained["f1"] == accelerator_not_trained["f1"] + ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + assert ( + baseline_trained["accuracy"] == accelerator_trained["accuracy"] + ), f'ZERO stage {zero_stage}: Accuracy should be 
the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' torch.distributed.destroy_process_group() From 676d5ac1b538d0d63bb64dc178f373c7bb90abcf Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Fri, 17 Jan 2025 09:51:44 -0500 Subject: [PATCH 09/25] Bookmark: currently zero3 is underperforming --- benchmarks/fp8/torchao/distrib_deepspeed.py | 29 ++++++++++----------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/benchmarks/fp8/torchao/distrib_deepspeed.py b/benchmarks/fp8/torchao/distrib_deepspeed.py index d8019524a10..836238149a0 100644 --- a/benchmarks/fp8/torchao/distrib_deepspeed.py +++ b/benchmarks/fp8/torchao/distrib_deepspeed.py @@ -131,12 +131,10 @@ def train_baseline(zero_stage: int = 1): def train_integration(zero_stage: int = 1): set_seed(42) AcceleratorState()._reset_state(True) - # This forces transformers to think Zero-3 Init should be used - with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock: - mock.return_value = zero_stage == 3 deepspeed_plugin = DeepSpeedPlugin( zero_stage=zero_stage, zero3_init_flag=zero_stage == 3, + gradient_clipping=1.0, ) accelerator = Accelerator( mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()], deepspeed_plugin=deepspeed_plugin @@ -179,17 +177,18 @@ def train_integration(zero_stage: int = 1): for zero_stage in [3]: baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage) accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage) - assert ( - baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] - ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' - assert ( - baseline_not_trained["f1"] == accelerator_not_trained["f1"] - ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' - assert ( - baseline_trained["accuracy"] == accelerator_trained["accuracy"] - ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' - assert ( - baseline_trained["f1"] == accelerator_trained["f1"] - ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' + print(baseline_trained, accelerator_trained) + # assert ( + # baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + # ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + # assert ( + # baseline_not_trained["f1"] == accelerator_not_trained["f1"] + # ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + # assert ( + # baseline_trained["accuracy"] == accelerator_trained["accuracy"] + # ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and 
accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + # assert ( + # baseline_trained["f1"] == accelerator_trained["f1"] + # ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' torch.distributed.destroy_process_group() From 92b3d9bbf335128472a998d99c42a553b4ae8ae8 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Fri, 7 Feb 2025 10:17:06 -0500 Subject: [PATCH 10/25] Bookmark --- benchmarks/fp8/torchao/distrib_deepspeed.py | 66 ++++++++++++--------- benchmarks/fp8/torchao/fp8_utils.py | 2 +- 2 files changed, 40 insertions(+), 28 deletions(-) diff --git a/benchmarks/fp8/torchao/distrib_deepspeed.py b/benchmarks/fp8/torchao/distrib_deepspeed.py index 836238149a0..04bbfd0d55f 100644 --- a/benchmarks/fp8/torchao/distrib_deepspeed.py +++ b/benchmarks/fp8/torchao/distrib_deepspeed.py @@ -31,6 +31,8 @@ from accelerate.utils import AORecipeKwargs, set_seed from torchao.float8 import convert_to_float8_training +from transformers.integrations import HfDeepSpeedConfig + MODEL_NAME = "bert-base-cased" METRIC = evaluate.load("glue", "mrpc") @@ -48,16 +50,20 @@ def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=Non def train_baseline(zero_stage: int = 1): - # This forces transformers to think Zero-3 Init should be used - with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock: - mock.return_value = zero_stage == 3 set_seed(42) - accelerator = Accelerator() + config = HfDeepSpeedConfig( + { + "train_micro_batch_size_per_gpu": 16, + "gradient_accumulation_steps": 1, + "zero_optimization": {"stage": zero_stage}, + } + ) + plugin = DeepSpeedPlugin(hf_ds_config=config) + accelerator = Accelerator(deepspeed_plugin=plugin) model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( MODEL_NAME, accelerator=accelerator ) - first_linear = None last_linear = None for name, module in model.named_modules(): @@ -66,9 +72,8 @@ def train_baseline(zero_stage: int = 1): first_linear = name last_linear = name func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear) - convert_to_float8_training(model, module_filter_fn=func) - accelerator = Accelerator() + convert_to_float8_training(model, module_filter_fn=func) import numpy as np @@ -125,27 +130,34 @@ def train_baseline(zero_stage: int = 1): trained_model_results["f1"] > base_model_results["f1"] ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + del config return base_model_results, trained_model_results, model_outputs, data def train_integration(zero_stage: int = 1): set_seed(42) AcceleratorState()._reset_state(True) + config = HfDeepSpeedConfig( + { + "train_micro_batch_size_per_gpu": 16, + "gradient_accumulation_steps": 1, + "zero_optimization": {"stage": zero_stage}, + } + ) deepspeed_plugin = DeepSpeedPlugin( - zero_stage=zero_stage, - zero3_init_flag=zero_stage == 3, - gradient_clipping=1.0, + hf_ds_config=config, ) accelerator = Accelerator( mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()], deepspeed_plugin=deepspeed_plugin ) - accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 16 model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( MODEL_NAME, accelerator=accelerator ) - model, optimizer, 
lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + model, optimizer, lr_scheduler, train_dataloader, eval_dataloader = accelerator.prepare( + model, optimizer, lr_scheduler, train_dataloader, eval_dataloader + ) base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.train() model_outputs = [] @@ -169,26 +181,26 @@ def train_integration(zero_stage: int = 1): trained_model_results["f1"] > base_model_results["f1"] ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + del config return base_model_results, trained_model_results, model_outputs, data if __name__ == "__main__": - # for zero_stage in [1, 2, 3]: for zero_stage in [3]: + # Expected baseline: ValueError: {'accuracy': 0.7916666666666666, 'f1': 0.8513011152416357} baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage) accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage) - print(baseline_trained, accelerator_trained) - # assert ( - # baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] - # ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' - # assert ( - # baseline_not_trained["f1"] == accelerator_not_trained["f1"] - # ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' - # assert ( - # baseline_trained["accuracy"] == accelerator_trained["accuracy"] - # ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' - # assert ( - # baseline_trained["f1"] == accelerator_trained["f1"] - # ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' - + assert ( + baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + assert ( + baseline_not_trained["f1"] == accelerator_not_trained["f1"] + ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + assert ( + baseline_trained["accuracy"] == accelerator_trained["accuracy"] + ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' + AcceleratorState()._reset_state(True) torch.distributed.destroy_process_group() diff --git a/benchmarks/fp8/torchao/fp8_utils.py b/benchmarks/fp8/torchao/fp8_utils.py index d28702e05ff..e010d9b4fba 100644 --- a/benchmarks/fp8/torchao/fp8_utils.py +++ b/benchmarks/fp8/torchao/fp8_utils.py @@ -62,7 +62,7 @@ def collate_fn(examples): return train_dataloader, eval_dataloader -def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None): +def get_training_utilities(model_name: str, batch_size: int = 
16, accelerator=None, prepare=True): """ Returns a tuple of: - Model From 660b6b59e93648e8808181f4ac7d3d83c2324e4b Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Mon, 10 Feb 2025 09:36:30 -0500 Subject: [PATCH 11/25] Another diff --- src/accelerate/accelerator.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 3544fc67fea..0a12ecf388c 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -3678,10 +3678,6 @@ def lomo_backward(self, loss: torch.Tensor, learning_rate: float) -> None: @property def fp8_backend(self): "Returns the configured backend for training in FP8" -<<<<<<< HEAD - if self._mixed_precision == "fp8" and self.fp8_recipe_handler is not None: - return self.fp8_recipe_handler.backend -======= if self.has_fp8_handler: if self.fp8_recipe_handler is not None: return self.fp8_recipe_handler.backend @@ -3691,7 +3687,6 @@ def fp8_backend(self): return "TE" elif self.msamp_recipe_handler is not None: return "MSAMP" ->>>>>>> be210db (Currently broken) elif self.state.deepspeed_plugin is not None and self.state.deepspeed_plugin.enable_msamp: return "MSAMP" return None From ca45f463eaabd4b5c55f6aef560031b049971069 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Mon, 10 Feb 2025 10:22:24 -0500 Subject: [PATCH 12/25] Fin --- benchmarks/fp8/torchao/distrib_deepspeed.py | 53 ++++++----- .../transformer_engine/distrib_deepspeed.py | 4 +- src/accelerate/test_utils/__init__.py | 1 + src/accelerate/test_utils/testing.py | 8 ++ src/accelerate/utils/__init__.py | 5 +- src/accelerate/utils/ao.py | 10 ++ tests/__init__.py | 13 +++ tests/test_fp8.py | 91 +++++++++++++++++-- 8 files changed, 150 insertions(+), 35 deletions(-) create mode 100644 tests/__init__.py diff --git a/benchmarks/fp8/torchao/distrib_deepspeed.py b/benchmarks/fp8/torchao/distrib_deepspeed.py index 04bbfd0d55f..2609af30730 100644 --- a/benchmarks/fp8/torchao/distrib_deepspeed.py +++ b/benchmarks/fp8/torchao/distrib_deepspeed.py @@ -18,20 +18,18 @@ This particular script verifies this for deepspeed training. 
""" +from functools import partial from unittest.mock import patch -from functools import partial import deepspeed import evaluate import torch -from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities +from fp8_utils import evaluate_model, get_training_utilities +from torchao.float8 import convert_to_float8_training +from transformers.integrations import HfDeepSpeedConfig from accelerate import Accelerator, DeepSpeedPlugin -from accelerate.state import AcceleratorState from accelerate.utils import AORecipeKwargs, set_seed -from torchao.float8 import convert_to_float8_training - -from transformers.integrations import HfDeepSpeedConfig MODEL_NAME = "bert-base-cased" @@ -51,6 +49,9 @@ def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=Non def train_baseline(zero_stage: int = 1): set_seed(42) + # This forces transformers to think Zero-3 Init should be used + with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock: + mock.return_value = zero_stage == 3 config = HfDeepSpeedConfig( { @@ -98,10 +99,11 @@ def train_baseline(zero_stage: int = 1): model, optimizer, _, - _, + lr_scheduler, ) = deepspeed.initialize( model=model, optimizer=optimizer, + lr_scheduler=lr_scheduler, config_params=config, ) @@ -136,7 +138,7 @@ def train_baseline(zero_stage: int = 1): def train_integration(zero_stage: int = 1): set_seed(42) - AcceleratorState()._reset_state(True) + # AcceleratorState()._reset_state(True) config = HfDeepSpeedConfig( { "train_micro_batch_size_per_gpu": 16, @@ -147,6 +149,9 @@ def train_integration(zero_stage: int = 1): deepspeed_plugin = DeepSpeedPlugin( hf_ds_config=config, ) + # This forces transformers to think Zero-3 Init should be used + with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock: + mock.return_value = zero_stage == 3 accelerator = Accelerator( mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()], deepspeed_plugin=deepspeed_plugin ) @@ -188,19 +193,21 @@ def train_integration(zero_stage: int = 1): if __name__ == "__main__": for zero_stage in [3]: # Expected baseline: ValueError: {'accuracy': 0.7916666666666666, 'f1': 0.8513011152416357} - baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage) - accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage) - assert ( - baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] - ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' - assert ( - baseline_not_trained["f1"] == accelerator_not_trained["f1"] - ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' - assert ( - baseline_trained["accuracy"] == accelerator_trained["accuracy"] - ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' - assert ( - baseline_trained["f1"] == accelerator_trained["f1"] - ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' - AcceleratorState()._reset_state(True) + # baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage) + accelerator_not_trained, 
accelerator_trained, accelerator_outputs, accelerator_data = train_integration( + zero_stage + ) + # assert ( + # baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + # ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + # assert ( + # baseline_not_trained["f1"] == accelerator_not_trained["f1"] + # ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + # assert ( + # baseline_trained["accuracy"] == accelerator_trained["accuracy"] + # ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + # assert ( + # baseline_trained["f1"] == accelerator_trained["f1"] + # ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' + # AcceleratorState()._reset_state(True) torch.distributed.destroy_process_group() diff --git a/benchmarks/fp8/transformer_engine/distrib_deepspeed.py b/benchmarks/fp8/transformer_engine/distrib_deepspeed.py index e678deb3659..7ea77266915 100644 --- a/benchmarks/fp8/transformer_engine/distrib_deepspeed.py +++ b/benchmarks/fp8/transformer_engine/distrib_deepspeed.py @@ -66,7 +66,7 @@ def train_baseline(zero_stage: int = 1): import numpy as np config = { - "train_batch_size": 32, + "train_batch_size": 16, "train_micro_batch_size_per_gpu": 16, "gradient_accumulation_steps": 1, "zero_optimization": { @@ -171,7 +171,7 @@ def train_integration(zero_stage: int = 1): if __name__ == "__main__": # for zero_stage in [1, 2, 3]: - zero_stage = 1 + zero_stage = 3 baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage) accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage) assert ( diff --git a/src/accelerate/test_utils/__init__.py b/src/accelerate/test_utils/__init__.py index f41473f6363..6b59fb0246d 100644 --- a/src/accelerate/test_utils/__init__.py +++ b/src/accelerate/test_utils/__init__.py @@ -41,6 +41,7 @@ require_single_gpu, require_single_xpu, require_torch_min_version, + require_torchao, require_torchvision, require_tpu, require_transformer_engine, diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py index bbcfc616ada..48350b0a4a7 100644 --- a/src/accelerate/test_utils/testing.py +++ b/src/accelerate/test_utils/testing.py @@ -53,6 +53,7 @@ is_timm_available, is_torch_version, is_torch_xla_available, + is_torchao_available, is_torchdata_stateful_dataloader_available, is_torchvision_available, is_transformer_engine_available, @@ -425,6 +426,13 @@ def require_transformer_engine(test_case): return unittest.skipUnless(is_transformer_engine_available(), "test requires transformers engine")(test_case) +def require_torchao(test_case): + """ + Decorator marking a test that requires torchao installed. 
These tests are skipped when torchao isn't installed + """ + return unittest.skipUnless(is_torchao_available(), "test requires torchao")(test_case) + + _atleast_one_tracker_available = ( any([is_wandb_available(), is_tensorboard_available()]) and not is_comet_ml_available() ) diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py index 8e760ccfb54..1c9dbff909a 100644 --- a/src/accelerate/utils/__init__.py +++ b/src/accelerate/utils/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .ao import convert_to_float8_training +from .ao import convert_to_float8_training, has_ao_layers from .constants import ( MITA_PROFILING_AVAILABLE_PYTORCH_VERSION, MODEL_NAME, @@ -59,10 +59,9 @@ RNGType, SageMakerDistributedType, TensorInformation, - TorchDynamoPlugin, - TorchTensorParallelPlugin, TERecipeKwargs, TorchDynamoPlugin, + TorchTensorParallelPlugin, add_model_config_to_megatron_parser, ) from .environment import ( diff --git a/src/accelerate/utils/ao.py b/src/accelerate/utils/ao.py index e0a2cf93d73..8d74774feae 100644 --- a/src/accelerate/utils/ao.py +++ b/src/accelerate/utils/ao.py @@ -68,6 +68,16 @@ def filter_linear_layers(module, layer_name, first_layer_name, last_layer_name) return True +@torchao_required +def has_ao_layers(model: torch.nn.Module): + from torchao.float8.float8_linear import Float8Linear + + for name, module in model.named_modules(): + if isinstance(module, Float8Linear): + return True + return False + + @torchao_required def convert_to_float8_training( model: torch.nn.Module, diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000000..8568c82be1c --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
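As a quick illustration of the helpers this patch introduces, `has_ao_layers` can serve as a post-`prepare` sanity check that the torchao backend actually swapped layers in. The sketch below is illustrative only: the tiny `torch.nn.Sequential` model is a stand-in (the real test in the next hunk uses `bert-base-cased`), and it assumes a CUDA machine with a recent `torchao` installed.

```python
import torch

from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs, has_ao_layers

accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()])

# Stand-in model; feature sizes are multiples of 16 so the default filter accepts the inner layers.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 64),
)

model = accelerator.prepare(model)
assert has_ao_layers(model), "expected at least one Float8Linear after prepare()"
```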
diff --git a/tests/test_fp8.py b/tests/test_fp8.py index eb35f183b6a..932594c76f4 100644 --- a/tests/test_fp8.py +++ b/tests/test_fp8.py @@ -17,16 +17,29 @@ import unittest import torch +from transformers import AutoModelForSequenceClassification from accelerate import Accelerator from accelerate.state import AcceleratorState -from accelerate.test_utils import get_launch_command, require_cuda, require_multi_gpu, require_transformer_engine +from accelerate.test_utils import ( + get_launch_command, + require_cuda, + require_multi_gpu, + require_torchao, + require_transformer_engine, +) from accelerate.test_utils.testing import require_deepspeed, run_command -from accelerate.utils import FP8RecipeKwargs, has_transformer_engine_layers +from accelerate.utils import ( + AORecipeKwargs, + FP8RecipeKwargs, + has_ao_layers, + has_transformer_engine_layers, + is_torchao_available, + is_transformer_engine_available, +) -def can_convert_model(): - print("Starting basic_fp8_test") +def can_convert_te_model(): accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [FP8RecipeKwargs(backend="TE")]} accelerator = Accelerator(**accelerator_kwargs) dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2) @@ -44,6 +57,18 @@ def maintain_proper_deepspeed_config(expected_version): ), f"Expected zero stage {expected_version} but got {AcceleratorState().deepspeed_plugin.zero_stage}" +def can_convert_ao_model(): + accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [AORecipeKwargs()]} + accelerator = Accelerator(**accelerator_kwargs) + dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2) + model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased") + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1) + + model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) + assert has_ao_layers(model) + + @require_transformer_engine class TestTransformerEngine(unittest.TestCase): @require_cuda @@ -91,7 +116,59 @@ def test_can_prepare_model_multigpu_deepspeed(self): run_command(command) +@require_torchao +class TestTorchAO(unittest.TestCase): + @require_cuda + def test_can_prepare_model_single_gpu(self): + command = get_launch_command(num_processes=1, monitor_interval=0.1) + command += ["-m", "tests.test_fp8"] + run_command(command) + + @require_multi_gpu + def test_can_prepare_model_multi_gpu(self): + command = get_launch_command(num_processes=2, monitor_interval=0.1) + command += ["-m", "tests.test_fp8"] + run_command(command) + + @require_deepspeed + @require_multi_gpu + def test_can_prepare_model_multigpu_deepspeed(self): + for zero_stage in [1, 2, 3]: + os.environ["ZERO_STAGE"] = str(zero_stage) + ds_config = { + "bf16": {"enabled": True}, + "zero_optimization": { + "stage": zero_stage, + "allgather_partitions": True, + "allgather_bucket_size": 2e8, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 2e8, + "contiguous_gradients": True, + }, + "gradient_accumulation_steps": 1, + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": False, + } + + ds_config = json.dumps(ds_config) + + command = get_launch_command( + num_processes=2, monitor_interval=0.1, use_deepspeed=True, deepspeed_config_file=ds_config + ) + command += ["-m", "tests.test_fp8"] + run_command(command) + + if __name__ 
== "__main__": - can_convert_model() - if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true": - maintain_proper_deepspeed_config(int(os.environ.get("ZERO_STAGE"))) + # TE suite + if is_transformer_engine_available(): + can_convert_te_model() + if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true": + maintain_proper_deepspeed_config(int(os.environ.get("ZERO_STAGE"))) + # AO suite + if is_torchao_available(): + can_convert_ao_model() From a0193ce27452cecf1645d10b2818d406e27a6740 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Mon, 10 Feb 2025 10:27:05 -0500 Subject: [PATCH 13/25] Fin --- src/accelerate/utils/dataclasses.py | 4 ---- src/accelerate/utils/imports.py | 6 +++++- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 0490ca72085..cf6511a7cb0 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -297,9 +297,6 @@ class AORecipeKwargs(KwargsHandler): training with `torchao` FP8. Args: - recipe_name (`str`, *optional*, default to `None`): - The name of the recipe to use for FP8 training. Should be compatible with - `torchao.float8.recipe_name_to_linear_config`. config (`torchao.float8.Float8LinearConfig`, *optional*, default to `None`): The configuration for the FP8 training. In general, the default config should be sufficient. module_filter_func (`Callable`, *optional*, default to `None`): @@ -308,7 +305,6 @@ class AORecipeKwargs(KwargsHandler): example. """ - recipe_name: str = None config: "Float8LinearConfig" = None module_filter_func: Callable = None diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py index 07f53e4fb91..c103b41f737 100644 --- a/src/accelerate/utils/imports.py +++ b/src/accelerate/utils/imports.py @@ -143,7 +143,11 @@ def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False): def is_torchao_available(): - return _is_package_available("torchao") + package_exists = _is_package_available("torchao") + if package_exists: + torchao_version = version.parse(importlib.metadata.version("torchao")) + return compare_versions(torchao_version, ">=", "0.6.1") + return False def is_deepspeed_available(): From b271b13967945dc4c25fced4b2185c6cae9ea937 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Mon, 10 Feb 2025 10:39:38 -0500 Subject: [PATCH 14/25] Add req huggingface suite --- tests/test_fp8.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_fp8.py b/tests/test_fp8.py index 932594c76f4..7e3814c35f2 100644 --- a/tests/test_fp8.py +++ b/tests/test_fp8.py @@ -17,13 +17,13 @@ import unittest import torch -from transformers import AutoModelForSequenceClassification from accelerate import Accelerator from accelerate.state import AcceleratorState from accelerate.test_utils import ( get_launch_command, require_cuda, + require_huggingface_suite, require_multi_gpu, require_torchao, require_transformer_engine, @@ -58,6 +58,8 @@ def maintain_proper_deepspeed_config(expected_version): def can_convert_ao_model(): + from transformers import AutoModelForSequenceClassification + accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [AORecipeKwargs()]} accelerator = Accelerator(**accelerator_kwargs) dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2) @@ -117,6 +119,7 @@ def 
test_can_prepare_model_multigpu_deepspeed(self): @require_torchao +@require_huggingface_suite class TestTorchAO(unittest.TestCase): @require_cuda def test_can_prepare_model_single_gpu(self): From a5d2c29fa193929fddc0a80d474e64d94ded88e8 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Wed, 12 Feb 2025 08:48:44 -0500 Subject: [PATCH 15/25] update tests for fp8/torchao/ddp --- benchmarks/fp8/torchao/ddp.py | 29 +------- benchmarks/fp8/torchao/distrib_deepspeed.py | 34 +++++----- benchmarks/fp8/torchao/fp8_utils.py | 2 +- benchmarks/fp8/torchao/fsdp.py | 29 +------- benchmarks/fp8/torchao/non_distributed.py | 74 +-------------------- 5 files changed, 21 insertions(+), 147 deletions(-) diff --git a/benchmarks/fp8/torchao/ddp.py b/benchmarks/fp8/torchao/ddp.py index 0b7e6071ac2..5cb125b56b2 100644 --- a/benchmarks/fp8/torchao/ddp.py +++ b/benchmarks/fp8/torchao/ddp.py @@ -22,11 +22,9 @@ import evaluate import torch -from fp8_utils import get_dataloaders +from fp8_utils import get_training_utilities from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import AdamW from torchao.float8 import convert_to_float8_training -from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup from accelerate import Accelerator from accelerate.state import AcceleratorState @@ -37,31 +35,6 @@ METRIC = evaluate.load("glue", "mrpc") -def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None): - """ - Returns a tuple of: - - Model - - Optimizer - - Train dataloader (prepared) - - Eval dataloader (prepared) - - LR Scheduler - Suitable for training on the MRPC dataset - """ - - if accelerator is None: - accelerator = Accelerator() - model = AutoModelForSequenceClassification.from_pretrained(model_name) - train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size) - optimizer = AdamW(model.parameters(), lr=0.0001) - lr_scheduler = get_linear_schedule_with_warmup( - optimizer=optimizer, - num_warmup_steps=100, - num_training_steps=len(train_dataloader) * 2, - ) - train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader) - return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler - - def evaluate_model(model, dataloader, metric, accelerator=None): "Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on" model.eval() diff --git a/benchmarks/fp8/torchao/distrib_deepspeed.py b/benchmarks/fp8/torchao/distrib_deepspeed.py index 2609af30730..6fc2080b304 100644 --- a/benchmarks/fp8/torchao/distrib_deepspeed.py +++ b/benchmarks/fp8/torchao/distrib_deepspeed.py @@ -29,6 +29,7 @@ from transformers.integrations import HfDeepSpeedConfig from accelerate import Accelerator, DeepSpeedPlugin +from accelerate.state import AcceleratorState from accelerate.utils import AORecipeKwargs, set_seed @@ -138,7 +139,7 @@ def train_baseline(zero_stage: int = 1): def train_integration(zero_stage: int = 1): set_seed(42) - # AcceleratorState()._reset_state(True) + AcceleratorState()._reset_state(True) config = HfDeepSpeedConfig( { "train_micro_batch_size_per_gpu": 16, @@ -191,23 +192,22 @@ def train_integration(zero_stage: int = 1): if __name__ == "__main__": - for zero_stage in [3]: - # Expected baseline: ValueError: {'accuracy': 0.7916666666666666, 'f1': 0.8513011152416357} - # baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage) + for zero_stage in [1, 2, 
3]: + baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage) accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration( zero_stage ) - # assert ( - # baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] - # ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' - # assert ( - # baseline_not_trained["f1"] == accelerator_not_trained["f1"] - # ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' - # assert ( - # baseline_trained["accuracy"] == accelerator_trained["accuracy"] - # ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' - # assert ( - # baseline_trained["f1"] == accelerator_trained["f1"] - # ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' - # AcceleratorState()._reset_state(True) + assert ( + baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + assert ( + baseline_not_trained["f1"] == accelerator_not_trained["f1"] + ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + assert ( + baseline_trained["accuracy"] == accelerator_trained["accuracy"] + ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' + AcceleratorState()._reset_state(True) torch.distributed.destroy_process_group() diff --git a/benchmarks/fp8/torchao/fp8_utils.py b/benchmarks/fp8/torchao/fp8_utils.py index e010d9b4fba..1aaa7db5df9 100644 --- a/benchmarks/fp8/torchao/fp8_utils.py +++ b/benchmarks/fp8/torchao/fp8_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
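One pattern worth calling out from the benchmark's `__main__` block above: because a single process builds several `Accelerator` objects back to back (one per ZeRO stage, plus the baseline), the state singleton has to be cleared between runs. Below is a minimal sketch of that pattern, with the loop body reduced to a placeholder rather than the benchmark code itself:

```python
from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.utils import set_seed

for zero_stage in (1, 2, 3):
    set_seed(42)  # keep the runs comparable
    AcceleratorState()._reset_state(True)  # drop the previous run's singleton state
    accelerator = Accelerator()  # the benchmark also passes fp8 kwargs and a DeepSpeed plugin here
    # ... build the model/optimizer, train, evaluate, then compare metrics ...
```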
diff --git a/benchmarks/fp8/torchao/fsdp.py b/benchmarks/fp8/torchao/fsdp.py index a047f27bd86..42eedb48bd5 100644 --- a/benchmarks/fp8/torchao/fsdp.py +++ b/benchmarks/fp8/torchao/fsdp.py @@ -22,13 +22,11 @@ import evaluate import torch -from fp8_utils import get_dataloaders +from fp8_utils import get_training_utilities from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import MixedPrecision from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy -from torch.optim import AdamW from torchao.float8 import convert_to_float8_training -from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup from transformers.models.bert import BertLayer from accelerate import Accelerator @@ -43,31 +41,6 @@ FSDP_WRAP_POLICY = partial(transformer_auto_wrap_policy, transformer_layer_cls={BertLayer}) -def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None): - """ - Returns a tuple of: - - Model - - Optimizer - - Train dataloader (prepared) - - Eval dataloader (prepared) - - LR Scheduler - Suitable for training on the MRPC dataset - """ - - if accelerator is None: - accelerator = Accelerator() - model = AutoModelForSequenceClassification.from_pretrained(model_name) - train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size) - optimizer = AdamW(model.parameters(), lr=0.0001) - lr_scheduler = get_linear_schedule_with_warmup( - optimizer=optimizer, - num_warmup_steps=100, - num_training_steps=len(train_dataloader) * 2, - ) - train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader) - return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler - - def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None): if isinstance(module, torch.nn.Linear): if module.in_features % 16 != 0 or module.out_features % 16 != 0: diff --git a/benchmarks/fp8/torchao/non_distributed.py b/benchmarks/fp8/torchao/non_distributed.py index e2426f162d5..7b8e5993e42 100644 --- a/benchmarks/fp8/torchao/non_distributed.py +++ b/benchmarks/fp8/torchao/non_distributed.py @@ -22,11 +22,8 @@ import evaluate import torch -from datasets import load_dataset -from torch.optim import AdamW -from torch.utils.data import DataLoader +from fp8_utils import get_training_utilities from torchao.float8 import convert_to_float8_training -from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup from accelerate import Accelerator from accelerate.state import AcceleratorState @@ -37,75 +34,6 @@ METRIC = evaluate.load("glue", "mrpc") -def get_dataloaders(model_name: str, batch_size: int = 16): - tokenizer = AutoTokenizer.from_pretrained(model_name) - datasets = load_dataset("glue", "mrpc") - - def tokenize_function(examples): - # max_length=None => use the model max length (it's actually the default) - outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None) - return outputs - - # Apply the method we just defined to all the examples in all the splits of the dataset - # starting with the main process first: - tokenized_datasets = datasets.map( - tokenize_function, - batched=True, - remove_columns=["idx", "sentence1", "sentence2"], - ) - - # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the - # transformers library - tokenized_datasets = tokenized_datasets.rename_column("label", "labels") - - def collate_fn(examples): - 
return tokenizer.pad( - examples, - padding="longest", - pad_to_multiple_of=16, # Specific for FP8 - return_tensors="pt", - ) - - # Instantiate dataloaders. - train_dataloader = DataLoader( - tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True - ) - eval_dataloader = DataLoader( - tokenized_datasets["validation"], - shuffle=False, - collate_fn=collate_fn, - batch_size=16, - drop_last=True, - ) - - return train_dataloader, eval_dataloader - - -def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None): - """ - Returns a tuple of: - - Model - - Optimizer - - Train dataloader (prepared) - - Eval dataloader (prepared) - - LR Scheduler - Suitable for training on the MRPC dataset - """ - - if accelerator is None: - accelerator = Accelerator() - model = AutoModelForSequenceClassification.from_pretrained(model_name) - train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size) - optimizer = AdamW(model.parameters(), lr=0.0001) - lr_scheduler = get_linear_schedule_with_warmup( - optimizer=optimizer, - num_warmup_steps=100, - num_training_steps=len(train_dataloader) * 2, - ) - train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader) - return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler - - def evaluate_model(model, dataloader, metric, accelerator=None): "Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on" model.eval() From 002c4be2b5d26d052785504bd4e6916e1065e4cb Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Wed, 12 Feb 2025 09:01:44 -0500 Subject: [PATCH 16/25] Log FP8 backend used and adjust typing --- src/accelerate/accelerator.py | 11 ++++++++--- src/accelerate/utils/ao.py | 11 ++++++++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 0a12ecf388c..370b655d3cc 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -467,14 +467,18 @@ def __init__( if self._mixed_precision == "fp8" and not self.has_fp8_handler: # Prioritize TE -> AO -> MSAMP if is_torchao_available(): + logger.info("Found `torchao` installed, using it for FP8 training.") self.ao_recipe_handler = AORecipeKwargs() elif is_transformer_engine_available(): + logger.info("Found `transformer-engine` installed, using it for FP8 training.") self.te_recipe_handler = TERecipeKwargs() elif is_msamp_available(): + logger.info("Found `msamp` installed, using it for FP8 training.") self.msamp_recipe_handler = MSAMPRecipeKwargs() else: raise ImportError( - "Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed." + "Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed. " + "Valid backends are: `torchao`, `transformer-engine`, and `msamp`." 
) self.delayed_fp8_autocast = False @@ -484,7 +488,6 @@ def __init__( self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED) ): raise ValueError("Passing in an FP8 configuration requires setting `mixed_precision='fp8'`.") - # DEPRECATE once 2.0 is released self.delayed_fp8_autocast = self.fp8_backend == "TE" and self.distributed_type in ( DistributedType.MULTI_GPU, DistributedType.FSDP, @@ -1673,7 +1676,9 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e def _prepare_ao(self, *args): if not is_torchao_available(): - raise ImportError("`torchao` was not found on your system. Please ensure that `torchao` is installed") + raise ImportError( + "`torchao` was not found on your system or is too old of a version. Please ensure that `torchao >= 0.6.1` is installed" + ) for arg in args: if isinstance(arg, torch.nn.Module): convert_to_float8_training( diff --git a/src/accelerate/utils/ao.py b/src/accelerate/utils/ao.py index 8d74774feae..ec7576c19de 100644 --- a/src/accelerate/utils/ao.py +++ b/src/accelerate/utils/ao.py @@ -17,10 +17,15 @@ """ from functools import partial +from typing import Callable, Optional import torch -from .imports import torchao_required +from .imports import is_torchao_available, torchao_required + + +if is_torchao_available(): + from torchao.float8.float8_linear import Float8LinearConfig def find_first_last_linear_layers(model: torch.nn.Module): @@ -81,8 +86,8 @@ def has_ao_layers(model: torch.nn.Module): @torchao_required def convert_to_float8_training( model: torch.nn.Module, - config=None, - module_filter_func=None, + config: Optional["Float8LinearConfig"] = None, + module_filter_func: Optional[Callable] = None, ): """ Converts all `nn.Linear` layers in the model (except the first and last) to torchao's `Float8Linear` layer inplace. From 95fbb8dc425c5cc97685b28639b0f037a2f1a248 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Wed, 12 Feb 2025 09:04:21 -0500 Subject: [PATCH 17/25] add documentation for convert_to_float8_training --- src/accelerate/utils/ao.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/accelerate/utils/ao.py b/src/accelerate/utils/ao.py index ec7576c19de..d8d81a14812 100644 --- a/src/accelerate/utils/ao.py +++ b/src/accelerate/utils/ao.py @@ -87,7 +87,7 @@ def has_ao_layers(model: torch.nn.Module): def convert_to_float8_training( model: torch.nn.Module, config: Optional["Float8LinearConfig"] = None, - module_filter_func: Optional[Callable] = None, + module_filter_func: Optional[Callable] = filter_linear_layers, ): """ Converts all `nn.Linear` layers in the model (except the first and last) to torchao's `Float8Linear` layer inplace. @@ -98,8 +98,8 @@ def convert_to_float8_training( config (`torchao.float8.Float8LinearConfig`, *optional*): The configuration for the FP8 training. Recommended to utilize `torchao.float8.recipe_name_to_linear_config` to generate this. In general, the default config should be - sufficient. - module_filter_func (`Callable`, *optional*): + sufficient (what is passed when set to `None`). + module_filter_func (`Callable`, *optional*, defaults to `filter_linear_layers`): Optional function that must take in a module and layer name, and returns a boolean indicating whether the module should be converted to FP8. Defaults to `filter_linear_layers`. See it for an example. 
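To make the documented defaults concrete, here is a minimal sketch of calling the utility as it is named at this point in the series; the toy `Sequential` model is an assumption (any model with 16-divisible `Linear` shapes would do), and the default `config`/`module_filter_func` are left in place:

```python
import torch

from accelerate.utils.ao import convert_to_float8_training

# Toy model for illustration; in/out features are multiples of 16 so the filter accepts them.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 32),
).to("cuda")

# Converts the eligible inner Linear layers in place; with the default filter the
# first and last Linear layers are kept in higher precision for stability.
convert_to_float8_training(model)
```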
From ac22296611773308e3b5203ed5ab4335ceb0b243 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Wed, 12 Feb 2025 09:07:22 -0500 Subject: [PATCH 18/25] Rename to convert_model_to_fp8_ao --- src/accelerate/accelerator.py | 4 ++-- src/accelerate/utils/__init__.py | 2 +- src/accelerate/utils/ao.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 370b655d3cc..7d05dafb600 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -78,8 +78,8 @@ clean_state_dict_for_safetensors, compare_versions, convert_model, + convert_model_to_fp8_ao, convert_outputs_to_fp32, - convert_to_float8_training, ensure_weights_retied, extract_model_from_parallel, gather, @@ -1681,7 +1681,7 @@ def _prepare_ao(self, *args): ) for arg in args: if isinstance(arg, torch.nn.Module): - convert_to_float8_training( + convert_model_to_fp8_ao( arg, config=self.ao_recipe_handler.config, module_filter_func=self.ao_recipe_handler.module_filter_func, diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py index 1c9dbff909a..4623abc52e1 100644 --- a/src/accelerate/utils/__init__.py +++ b/src/accelerate/utils/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .ao import convert_to_float8_training, has_ao_layers +from .ao import convert_model_to_fp8_ao, has_ao_layers from .constants import ( MITA_PROFILING_AVAILABLE_PYTORCH_VERSION, MODEL_NAME, diff --git a/src/accelerate/utils/ao.py b/src/accelerate/utils/ao.py index d8d81a14812..5f94b4c0a4f 100644 --- a/src/accelerate/utils/ao.py +++ b/src/accelerate/utils/ao.py @@ -84,7 +84,7 @@ def has_ao_layers(model: torch.nn.Module): @torchao_required -def convert_to_float8_training( +def convert_model_to_fp8_ao( model: torch.nn.Module, config: Optional["Float8LinearConfig"] = None, module_filter_func: Optional[Callable] = filter_linear_layers, @@ -106,7 +106,7 @@ def convert_to_float8_training( Example: ```python - from accelerate.utils.ao import convert_to_float8_training + from accelerate.utils.ao import convert_model_to_fp8_ao model = MyModel() model.to("cuda") From 06642ca11e6d96584ca65b98fcf0d45b6b124e9b Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Wed, 12 Feb 2025 09:08:45 -0500 Subject: [PATCH 19/25] Call superinit" --- src/accelerate/utils/dataclasses.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index cf6511a7cb0..97fa50ac5f9 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -424,6 +424,7 @@ def __post_init__(self): self.backend = self.backend.upper() if self.backend not in get_args(Backend): raise ValueError("`backend` must be 'MSAMP' or 'TE' (TransformerEngine) to use `FP8RecipeKwargs`.") + super().__post_init__() # Literal From d46b0a1df6c99f3848b2ecb002ecac7278a1deff Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Wed, 12 Feb 2025 09:09:38 -0500 Subject: [PATCH 20/25] Add types --- src/accelerate/utils/dataclasses.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/accelerate/utils/dataclasses.py 
b/src/accelerate/utils/dataclasses.py index 97fa50ac5f9..fec81048b97 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -305,9 +305,8 @@ class AORecipeKwargs(KwargsHandler): example. """ - config: "Float8LinearConfig" = None - module_filter_func: Callable = None - + config: Optional["Float8LinearConfig"] = None + module_filter_func: Optional[Callable] = None @dataclass class TERecipeKwargs(KwargsHandler): From 62881acb0c01fcb7431b69c8b88c665ba3096183 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Wed, 12 Feb 2025 09:10:10 -0500 Subject: [PATCH 21/25] Clean --- src/accelerate/utils/dataclasses.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index fec81048b97..9936ee8c00c 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -308,6 +308,7 @@ class AORecipeKwargs(KwargsHandler): config: Optional["Float8LinearConfig"] = None module_filter_func: Optional[Callable] = None + @dataclass class TERecipeKwargs(KwargsHandler): """ From f3ceb37bd172f69de6da0e6047f16e2d0e9b026f Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Fri, 14 Feb 2025 14:13:48 -0500 Subject: [PATCH 22/25] Use filter_first_and_last_linear_layers --- src/accelerate/utils/__init__.py | 2 +- src/accelerate/utils/ao.py | 39 +++++++++++++++++++++----------- 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py index 4623abc52e1..5f00d96c395 100644 --- a/src/accelerate/utils/__init__.py +++ b/src/accelerate/utils/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .ao import convert_model_to_fp8_ao, has_ao_layers +from .ao import convert_model_to_fp8_ao, has_ao_layers, filter_first_and_last_linear_layers from .constants import ( MITA_PROFILING_AVAILABLE_PYTORCH_VERSION, MODEL_NAME, diff --git a/src/accelerate/utils/ao.py b/src/accelerate/utils/ao.py index 5f94b4c0a4f..6c59f33c27a 100644 --- a/src/accelerate/utils/ao.py +++ b/src/accelerate/utils/ao.py @@ -17,7 +17,7 @@ """ from functools import partial -from typing import Callable, Optional +from typing import Callable, List, Optional import torch @@ -45,34 +45,47 @@ def find_first_last_linear_layers(model: torch.nn.Module): return first_linear, last_linear -def filter_linear_layers(module, layer_name, first_layer_name, last_layer_name) -> bool: +def filter_linear_layers(module, fqn: str, layers_to_filter: List[str]) -> bool: """ A function which will check if `module` is: - a `torch.nn.Linear` layer - has in_features and out_features divisible by 16 - - is not the first or last layer of the model. + - is not part of `layers_to_filter` Args: module (`torch.nn.Module`): The module to check. - layer_name (`str`): + fqn (`str`): The fully qualified name of the layer. - first_layer_name (`str`): - The name of the first layer of the model. - last_layer_name (`str`): - The name of the last layer of the model. + layers_to_filter (`List[str]`): + The list of layers to filter. 
""" if isinstance(module, torch.nn.Linear): if module.in_features % 16 != 0 or module.out_features % 16 != 0: return False - # For stability reasons, we skip the first and last linear layers - # Otherwise can lead to the model not training or converging properly - # TODO: apply this to all FP8 backends - if layer_name in (first_layer_name, last_layer_name): + if fqn in layers_to_filter: return False return True +def filter_first_and_last_linear_layers(module, fqn: str) -> bool: + """ + A filter function which will filter out all linear layers except the first and last. + + + For stability reasons, we skip the first and last linear layers + Otherwise can lead to the model not training or converging properly + + + Args: + module (`torch.nn.Module`): + The module to check. + fqn (`str`): + The fully qualified name of the layer. + """ + first_linear, last_linear = find_first_last_linear_layers(module) + return filter_linear_layers(module, fqn, layers_to_filter=[first_linear, last_linear]) + @torchao_required def has_ao_layers(model: torch.nn.Module): from torchao.float8.float8_linear import Float8Linear @@ -87,7 +100,7 @@ def has_ao_layers(model: torch.nn.Module): def convert_model_to_fp8_ao( model: torch.nn.Module, config: Optional["Float8LinearConfig"] = None, - module_filter_func: Optional[Callable] = filter_linear_layers, + module_filter_func: Optional[Callable] = filter_first_and_last_linear_layers, ): """ Converts all `nn.Linear` layers in the model (except the first and last) to torchao's `Float8Linear` layer inplace. From 14f6d04ebccabe8be86d4c9be7806ec0a2b2250c Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Mon, 17 Feb 2025 11:29:33 -0500 Subject: [PATCH 23/25] Update usage guide docs --- .../usage_guides/low_precision_training.md | 31 +++++++++++++++---- src/accelerate/utils/__init__.py | 2 +- src/accelerate/utils/ao.py | 7 +++-- 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/docs/source/usage_guides/low_precision_training.md b/docs/source/usage_guides/low_precision_training.md index c730136e1ce..08e533e60cf 100644 --- a/docs/source/usage_guides/low_precision_training.md +++ b/docs/source/usage_guides/low_precision_training.md @@ -15,7 +15,7 @@ rendered properly in your Markdown viewer. # Low Precision Training Methods -Accelerate provides integrations to train on lower precision methods using specified supported hardware through the `TransformersEngine` and `MS-AMP` packages. This documentation will help guide you through what hardware is supported, how to configure your [`Accelerator`] to leverage the low precision methods, and what you can expect when training. +Accelerate provides integrations to train on lower precision methods using specified supported hardware through the `TransformersEngine`, `MS-AMP`, and `torchao` packages. This documentation will help guide you through what hardware is supported, how to configure your [`Accelerator`] to leverage the low precision methods, and what you can expect when training. ## What training on FP8 means @@ -30,7 +30,7 @@ What this will result in is some gain in the memory used (as we've cut the neede ## Configuring the Accelerator -Currently two different backends for FP8 are supported (`TransformersEngine` and `MS-AMP`), each with different capabilities and configurations. +Currently three different backends for FP8 are supported (`TransformersEngine`, `torchao`, and `MS-AMP`), each with different capabilities and configurations. To use either, the same core API is used. 
Just pass `mixed_precision="fp8"` to either the [`Accelerator`], during `accelerate config` when prompted about mixed precision, or as part of your `config.yaml` file in the `mixed_precision` key: @@ -39,14 +39,16 @@ from accelerate import Accelerator accelerator = Accelerator(mixed_precision="fp8") ``` -By default, if `MS-AMP` is available in your environment, Accelerate will automatically utilize it as a backend. To specify it yourself (and customize other parts of the FP8 mixed precision setup), you can utilize the [`utils.FP8RecipeKwargs`] or clarify it in your config `yaml`/during `accelerate launch`: +By default, if `MS-AMP` is available in your environment, Accelerate will automatically utilize it as a backend. To specify it yourself (and customize other parts of the FP8 mixed precision setup), you can utilize one of the `RecipeKwargs` dataclasses such as [`utils.AORecipeKwargs`], [`utils.TERecipeKwargs`], or [`utils.MSAMPRecipeKwargs`]; you can also clarify it in your config `yaml`/during `accelerate launch`: ```{python} from accelerate import Accelerator -from accelerate.utils import FP8RecipeKwargs -kwargs = [FP8RecipeKwargs(backend="msamp")] +from accelerate.utils import MSAMPRecipeKwargs +kwargs = [MSAMPRecipeKwargs()] # Or to specify the backend as `TransformersEngine` even if MS-AMP is installed -# kwargs = [FP8RecipeKwargs(backend="te")] +# kwargs = [TERecipeKwargs()] +# Or to use torchao +# kwargs = [AORecipeKwargs()] accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs) ``` @@ -124,6 +126,22 @@ fp8_config: use_autocast_during_eval: false ``` +## Configuring `torchao` + +`torchao` is a [PyTorch-driven](https://github.com/pytorch/ao/tree/main/torchao/float8) hackable FP8 backend, aiming to be more approachable than the prior two engines. One of the core differences with `ao` compared to the prior two is that for numerical stability, it's found to be generally better off keeping the first *and* last layers in the model at the regular precision (be it FP32 or BF16), and then the other layers quantized down to FP8. As a result, a config for `ao` looks a bit different: + +> Note: this API is experimental and is subject to change + +```{python} +from accelerate import Accelerator +from accelerate.utils import AORecipeKwargs +kwargs = [AORecipeKwargs()] +accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs) +``` + +To learn more about the specific parameters to be used, please see the official `torchao` repo. + + ## Example Zoo We have examples showcasing training with FP8 both with accelerate and its underlying implementation available in the accelerate repo. @@ -143,3 +161,4 @@ To learn more about training in FP8 please check out the following resources: * [Our concept guide](../concept_guides/low_precision_training) detailing into more about both TransformersEngine and MS-AMP * [The `transformers-engine` documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html) * [The `MS-AMP` documentation](https://azure.github.io/MS-AMP/docs/) +* [The `torchao` documentation](https://github.com/pytorch/ao/tree/main/torchao/float8) diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py index 5f00d96c395..6b5bed62d73 100644 --- a/src/accelerate/utils/__init__.py +++ b/src/accelerate/utils/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from .ao import convert_model_to_fp8_ao, has_ao_layers, filter_first_and_last_linear_layers +from .ao import convert_model_to_fp8_ao, filter_first_and_last_linear_layers, has_ao_layers from .constants import ( MITA_PROFILING_AVAILABLE_PYTORCH_VERSION, MODEL_NAME, diff --git a/src/accelerate/utils/ao.py b/src/accelerate/utils/ao.py index 6c59f33c27a..2023371ca66 100644 --- a/src/accelerate/utils/ao.py +++ b/src/accelerate/utils/ao.py @@ -73,8 +73,10 @@ def filter_first_and_last_linear_layers(module, fqn: str) -> bool: A filter function which will filter out all linear layers except the first and last. - For stability reasons, we skip the first and last linear layers - Otherwise can lead to the model not training or converging properly + + For stability reasons, we skip the first and last linear layers Otherwise can lead to the model not training or + converging properly + Args: @@ -86,6 +88,7 @@ def filter_first_and_last_linear_layers(module, fqn: str) -> bool: first_linear, last_linear = find_first_last_linear_layers(module) return filter_linear_layers(module, fqn, layers_to_filter=[first_linear, last_linear]) + @torchao_required def has_ao_layers(model: torch.nn.Module): from torchao.float8.float8_linear import Float8Linear From a8b6b8c6d2830955fe1a0139783a2a7e0f02fead Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Mon, 17 Feb 2025 11:36:39 -0500 Subject: [PATCH 24/25] Actually loop through the zero stages --- .../transformer_engine/distrib_deepspeed.py | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/benchmarks/fp8/transformer_engine/distrib_deepspeed.py b/benchmarks/fp8/transformer_engine/distrib_deepspeed.py index 7ea77266915..a574a864e47 100644 --- a/benchmarks/fp8/transformer_engine/distrib_deepspeed.py +++ b/benchmarks/fp8/transformer_engine/distrib_deepspeed.py @@ -170,21 +170,20 @@ def train_integration(zero_stage: int = 1): if __name__ == "__main__": - # for zero_stage in [1, 2, 3]: - zero_stage = 3 - baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage) - accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage) - assert ( - baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] - ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' - assert ( - baseline_not_trained["f1"] == accelerator_not_trained["f1"] - ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' - assert ( - baseline_trained["accuracy"] == accelerator_trained["accuracy"] - ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' - assert ( - baseline_trained["f1"] == accelerator_trained["f1"] - ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' - - torch.distributed.destroy_process_group() + for zero_stage in [1, 2, 3]: + baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage) + accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage) + 
assert ( + baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + assert ( + baseline_not_trained["f1"] == accelerator_not_trained["f1"] + ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + assert ( + baseline_trained["accuracy"] == accelerator_trained["accuracy"] + ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' + + torch.distributed.destroy_process_group() From f222d7215a1287c9119f6266bc20c98a9c382809 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Mon, 17 Feb 2025 11:36:47 -0500 Subject: [PATCH 25/25] Clean --- benchmarks/fp8/transformer_engine/distrib_deepspeed.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/benchmarks/fp8/transformer_engine/distrib_deepspeed.py b/benchmarks/fp8/transformer_engine/distrib_deepspeed.py index a574a864e47..73953b6793f 100644 --- a/benchmarks/fp8/transformer_engine/distrib_deepspeed.py +++ b/benchmarks/fp8/transformer_engine/distrib_deepspeed.py @@ -172,7 +172,9 @@ def train_integration(zero_stage: int = 1): if __name__ == "__main__": for zero_stage in [1, 2, 3]: baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage) - accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage) + accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration( + zero_stage + ) assert ( baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
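For reference, the `torchao` + DeepSpeed configuration that the earlier patches in this series converged on can be condensed to roughly the following; this is a summary sketch of the benchmark code above (the ZeRO stage and batch size mirror the benchmark), not a new API:

```python
from transformers.integrations import HfDeepSpeedConfig

from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.utils import AORecipeKwargs, set_seed

set_seed(42)
ds_config = HfDeepSpeedConfig(
    {
        "train_micro_batch_size_per_gpu": 16,
        "gradient_accumulation_steps": 1,
        "zero_optimization": {"stage": 3},
    }
)
accelerator = Accelerator(
    mixed_precision="fp8",
    kwargs_handlers=[AORecipeKwargs()],
    deepspeed_plugin=DeepSpeedPlugin(hf_ds_config=ds_config),
)
# The benchmark's `get_training_utilities` supplies the model, optimizer, scheduler and
# dataloaders, which are then prepared together so DeepSpeed sees all of them:
# model, optimizer, lr_scheduler, train_dataloader, eval_dataloader = accelerator.prepare(
#     model, optimizer, lr_scheduler, train_dataloader, eval_dataloader
# )
```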