diff --git a/benchmarks/fp8/torchao/Dockerfile b/benchmarks/fp8/torchao/Dockerfile
new file mode 100644
index 00000000000..88c21934d4e
--- /dev/null
+++ b/benchmarks/fp8/torchao/Dockerfile
@@ -0,0 +1,12 @@
+FROM nvcr.io/nvidia/pytorch:24.07-py3
+
+RUN pip install transformers evaluate datasets
+RUN git clone https://github.com/huggingface/accelerate.git
+
+RUN cd accelerate && \
+ pip install -e . && \
+ cd benchmarks/fp8
+
+CMD ["/bin/bash"]
+
+
diff --git a/benchmarks/fp8/torchao/README.md b/benchmarks/fp8/torchao/README.md
new file mode 100644
index 00000000000..d5abadaf64e
--- /dev/null
+++ b/benchmarks/fp8/torchao/README.md
@@ -0,0 +1,32 @@
+# FP8 Benchmarks
+
+Comparing and running [torchao](https://github.com/pytorch/ao/tree/main/torchao/float8) FP8 with accelerate
+
+## Overview
+
+This repo provides scripts which compare native `torchao` model training against `accelerate`'s own integration. Each training setup is covered by its own script, supporting the following:
+
+* Single GPU training (`non_distributed.py`)
+* Multi-GPU training via DistributedDataParallelism (`ddp.py`)
+* Fully Sharded Data Parallelism (`fsdp.py`)
+* DeepSpeed ZeRO 1-3 (`distrib_deepspeed.py`)
+
+To run them, it is recommended to use a Docker image (see the attached `Dockerfile`) rather than installing `torchao` manually.
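+
+For example, a local image can be built from the repository root (the image tag below is arbitrary):
+
+```bash
+docker build -t accelerate-fp8-torchao -f benchmarks/fp8/torchao/Dockerfile .
+```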
+
+## Running
+
+An official Docker image is available at `huggingface/accelerate:gpu-fp8-torchao-nightly` and can be used directly.
+
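+For example, a minimal way to pull and start it (assuming the NVIDIA Container Toolkit is configured for GPU access):
+
+```bash
+docker pull huggingface/accelerate:gpu-fp8-torchao-nightly
+docker run --gpus all -it huggingface/accelerate:gpu-fp8-torchao-nightly
+```
+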
+All scripts can be run with the core `accelerate launch` command; no `accelerate config` is needed.
+
+For single GPU, run it via `python`:
+
+```bash
+python non_distributed.py
+```
+
+For the rest, run it via `accelerate launch`:
+
+```bash
+accelerate launch ddp.py # or fsdp.py, distrib_deepspeed.py
+```
\ No newline at end of file
diff --git a/benchmarks/fp8/torchao/ddp.py b/benchmarks/fp8/torchao/ddp.py
new file mode 100644
index 00000000000..5cb125b56b2
--- /dev/null
+++ b/benchmarks/fp8/torchao/ddp.py
@@ -0,0 +1,158 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script tests to ensure that `accelerate` performs at the same level as raw `torchao`.
+
+This particular script verifies this for DDP training.
+"""
+
+from functools import partial
+
+import evaluate
+import torch
+from fp8_utils import get_training_utilities
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torchao.float8 import convert_to_float8_training
+
+from accelerate import Accelerator
+from accelerate.state import AcceleratorState
+from accelerate.utils import AORecipeKwargs, set_seed
+
+
+MODEL_NAME = "bert-base-cased"
+METRIC = evaluate.load("glue", "mrpc")
+
+
+def evaluate_model(model, dataloader, metric, accelerator=None):
+ "Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on"
+ model.eval()
+ for step, batch in enumerate(dataloader):
+ with torch.no_grad():
+ outputs = model(**batch)
+ predictions = outputs.logits.argmax(dim=-1)
+ references = batch["labels"]
+ if accelerator is not None and accelerator.num_processes > 1:
+ predictions, references = accelerator.gather_for_metrics((predictions, references))
+ metric.add_batch(predictions=predictions, references=references)
+ return metric.compute()
+
+
+def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None):
+ if isinstance(module, torch.nn.Linear):
+ if module.in_features % 16 != 0 or module.out_features % 16 != 0:
+ return False
+ # For stability reasons, we skip the first and last linear layers
+ # Otherwise, this can lead to the model not training or converging properly
+ if fqn in (first_layer_name, last_layer_name):
+ return False
+ return True
+
+
+def train_baseline():
+ set_seed(42)
+ model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
+ first_linear = None
+ last_linear = None
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ if first_linear is None:
+ first_linear = name
+ last_linear = name
+ func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear)
+ accelerator = Accelerator()
+ device = accelerator.device
+ model.to(device)
+
+ convert_to_float8_training(model, module_filter_fn=func)
+
+ # Convert the model to DDP
+ device_ids, output_device = [accelerator.local_process_index], accelerator.local_process_index
+ model = DDP(model, device_ids=device_ids, output_device=output_device)
+
+ base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
+ model.train()
+
+ for batch in train_dataloader:
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+ batch = batch.to(device)
+ outputs = model(**batch)
+ loss = outputs.loss
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+ lr_scheduler.step()
+
+ trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
+
+ assert (
+ trained_model_results["accuracy"] > base_model_results["accuracy"]
+ ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
+ assert (
+ trained_model_results["f1"] > base_model_results["f1"]
+ ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
+
+ return base_model_results, trained_model_results
+
+
+def train_integration():
+ AcceleratorState()._reset_state(True)
+ accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()])
+ set_seed(42)
+ model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
+ MODEL_NAME, accelerator=accelerator
+ )
+
+ model, optimizer = accelerator.prepare(model, optimizer)
+ base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
+ model.train()
+
+ for batch in train_dataloader:
+ outputs = model(**batch)
+ loss = outputs.loss
+ accelerator.backward(loss)
+ optimizer.step()
+ optimizer.zero_grad()
+ lr_scheduler.step()
+
+ trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
+
+ assert (
+ trained_model_results["accuracy"] > base_model_results["accuracy"]
+ ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
+ assert (
+ trained_model_results["f1"] > base_model_results["f1"]
+ ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
+
+ return base_model_results, trained_model_results
+
+
+if __name__ == "__main__":
+ baseline_not_trained, baseline_trained = train_baseline()
+ accelerator_not_trained, accelerator_trained = train_integration()
+
+ assert (
+ baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
+ ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
+ assert (
+ baseline_not_trained["f1"] == accelerator_not_trained["f1"]
+ ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
+ assert (
+ baseline_trained["accuracy"] == accelerator_trained["accuracy"]
+ ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
+ assert (
+ baseline_trained["f1"] == accelerator_trained["f1"]
+ ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'
+
+ torch.distributed.destroy_process_group()
diff --git a/benchmarks/fp8/torchao/distrib_deepspeed.py b/benchmarks/fp8/torchao/distrib_deepspeed.py
new file mode 100644
index 00000000000..6fc2080b304
--- /dev/null
+++ b/benchmarks/fp8/torchao/distrib_deepspeed.py
@@ -0,0 +1,213 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script tests to ensure that `accelerate` performs at the same level as raw `torchao`.
+
+This particular script verifies this for DeepSpeed training.
+"""
+
+from functools import partial
+from unittest.mock import patch
+
+import deepspeed
+import evaluate
+import torch
+from fp8_utils import evaluate_model, get_training_utilities
+from torchao.float8 import convert_to_float8_training
+from transformers.integrations import HfDeepSpeedConfig
+
+from accelerate import Accelerator, DeepSpeedPlugin
+from accelerate.state import AcceleratorState
+from accelerate.utils import AORecipeKwargs, set_seed
+
+
+MODEL_NAME = "bert-base-cased"
+METRIC = evaluate.load("glue", "mrpc")
+
+
+def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None):
+ if isinstance(module, torch.nn.Linear):
+ if module.in_features % 16 != 0 or module.out_features % 16 != 0:
+ return False
+ # For stability reasons, we skip the first and last linear layers
+ # Otherwise, this can lead to the model not training or converging properly
+ if fqn in (first_layer_name, last_layer_name):
+ return False
+ return True
+
+
+def train_baseline(zero_stage: int = 1):
+ set_seed(42)
+ # This forces transformers to think Zero-3 Init should be used
+ with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock:
+ mock.return_value = zero_stage == 3
+
+ config = HfDeepSpeedConfig(
+ {
+ "train_micro_batch_size_per_gpu": 16,
+ "gradient_accumulation_steps": 1,
+ "zero_optimization": {"stage": zero_stage},
+ }
+ )
+ plugin = DeepSpeedPlugin(hf_ds_config=config)
+ accelerator = Accelerator(deepspeed_plugin=plugin)
+ model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
+ MODEL_NAME, accelerator=accelerator
+ )
+ first_linear = None
+ last_linear = None
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ if first_linear is None:
+ first_linear = name
+ last_linear = name
+ func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear)
+
+ convert_to_float8_training(model, module_filter_fn=func)
+
+ import numpy as np
+
+ config = {
+ "train_batch_size": 32,
+ "train_micro_batch_size_per_gpu": 16,
+ "gradient_accumulation_steps": 1,
+ "zero_optimization": {
+ "stage": zero_stage,
+ "offload_optimizer": {"device": "none", "nvme_path": None},
+ "offload_param": {"device": "none", "nvme_path": None},
+ "stage3_gather_16bit_weights_on_model_save": False,
+ },
+ "gradient_clipping": 1.0,
+ "steps_per_print": np.inf,
+ "bf16": {"enabled": True},
+ "fp16": {"enabled": False},
+ "zero_allow_untested_optimizer": True,
+ }
+
+ (
+ model,
+ optimizer,
+ _,
+ lr_scheduler,
+ ) = deepspeed.initialize(
+ model=model,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ config_params=config,
+ )
+
+ base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
+ model.train()
+
+ model_outputs = []
+ data = []
+
+ for batch in train_dataloader:
+ outputs = model(**batch)
+ data.append(batch.to("cpu"))
+ model_outputs.append(outputs.logits.to("cpu"))
+ loss = outputs.loss
+ model.backward(loss)
+ model.step()
+ for _ in range(accelerator.num_processes):
+ lr_scheduler.step()
+
+ trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
+ model.destroy()
+ assert (
+ trained_model_results["accuracy"] > base_model_results["accuracy"]
+ ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
+ assert (
+ trained_model_results["f1"] > base_model_results["f1"]
+ ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
+
+ del config
+ return base_model_results, trained_model_results, model_outputs, data
+
+
+def train_integration(zero_stage: int = 1):
+ set_seed(42)
+ AcceleratorState()._reset_state(True)
+ config = HfDeepSpeedConfig(
+ {
+ "train_micro_batch_size_per_gpu": 16,
+ "gradient_accumulation_steps": 1,
+ "zero_optimization": {"stage": zero_stage},
+ }
+ )
+ deepspeed_plugin = DeepSpeedPlugin(
+ hf_ds_config=config,
+ )
+ # This forces transformers to think Zero-3 Init should be used
+ with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock:
+ mock.return_value = zero_stage == 3
+ accelerator = Accelerator(
+ mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()], deepspeed_plugin=deepspeed_plugin
+ )
+
+ model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
+ MODEL_NAME, accelerator=accelerator
+ )
+
+ model, optimizer, lr_scheduler, train_dataloader, eval_dataloader = accelerator.prepare(
+ model, optimizer, lr_scheduler, train_dataloader, eval_dataloader
+ )
+ base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
+ model.train()
+ model_outputs = []
+ data = []
+ for batch in train_dataloader:
+ outputs = model(**batch)
+ data.append(batch.to("cpu"))
+ model_outputs.append(outputs.logits.to("cpu"))
+ loss = outputs.loss
+ accelerator.backward(loss)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
+ model.destroy()
+ assert (
+ trained_model_results["accuracy"] > base_model_results["accuracy"]
+ ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
+ assert (
+ trained_model_results["f1"] > base_model_results["f1"]
+ ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
+
+ del config
+ return base_model_results, trained_model_results, model_outputs, data
+
+
+if __name__ == "__main__":
+ for zero_stage in [1, 2, 3]:
+ baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage)
+ accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(
+ zero_stage
+ )
+ assert (
+ baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
+ ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
+ assert (
+ baseline_not_trained["f1"] == accelerator_not_trained["f1"]
+ ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
+ assert (
+ baseline_trained["accuracy"] == accelerator_trained["accuracy"]
+ ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
+ assert (
+ baseline_trained["f1"] == accelerator_trained["f1"]
+ ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'
+ AcceleratorState()._reset_state(True)
+ torch.distributed.destroy_process_group()
diff --git a/benchmarks/fp8/torchao/fp8_utils.py b/benchmarks/fp8/torchao/fp8_utils.py
new file mode 100644
index 00000000000..1aaa7db5df9
--- /dev/null
+++ b/benchmarks/fp8/torchao/fp8_utils.py
@@ -0,0 +1,116 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+
+def get_dataloaders(model_name: str, batch_size: int = 16):
+ from datasets import load_dataset
+ from torch.utils.data import DataLoader
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ datasets = load_dataset("glue", "mrpc")
+
+ def tokenize_function(examples):
+ # max_length=None => use the model max length (it's actually the default)
+ outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
+ return outputs
+
+ # Apply the method we just defined to all the examples in all the splits of the dataset
+ # starting with the main process first:
+ tokenized_datasets = datasets.map(
+ tokenize_function,
+ batched=True,
+ remove_columns=["idx", "sentence1", "sentence2"],
+ )
+
+ # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the
+ # transformers library
+ tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
+
+ def collate_fn(examples):
+ return tokenizer.pad(
+ examples,
+ padding="longest",
+ pad_to_multiple_of=16, # Specific for FP8
+ return_tensors="pt",
+ )
+
+ # Instantiate dataloaders.
+ train_dataloader = DataLoader(
+ tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True
+ )
+ eval_dataloader = DataLoader(
+ tokenized_datasets["validation"],
+ shuffle=False,
+ collate_fn=collate_fn,
+ batch_size=16,
+ drop_last=True,
+ )
+
+ return train_dataloader, eval_dataloader
+
+
+def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None, prepare=True):
+ """
+ Returns a tuple of:
+ - Model
+ - Optimizer
+ - Train dataloader (prepared)
+ - Eval dataloader (prepared)
+ - LR Scheduler
+ Suitable for training on the MRPC dataset
+ """
+ from torch.optim import AdamW
+ from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup
+
+ from accelerate import Accelerator
+
+ if accelerator is None:
+ accelerator = Accelerator()
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
+ train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size)
+ optimizer = AdamW(model.parameters(), lr=0.0001)
+ lr_scheduler = get_linear_schedule_with_warmup(
+ optimizer=optimizer,
+ num_warmup_steps=100,
+ num_training_steps=len(train_dataloader) * 2,
+ )
+ train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader)
+ return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
+
+
+def get_named_parameters(model):
+ """
+ Same as `Accelerator.get_named_parameters`: returns a dict of the model's named parameters (after extracting
+ the model from any parallel wrapper)
+ """
+ from accelerate.utils import extract_model_from_parallel
+
+ model = extract_model_from_parallel(model)
+ return {n: p for n, p in model.named_parameters()}
+
+
+def evaluate_model(model, dataloader, metric, accelerator=None):
+ "Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on"
+ model.eval()
+ for step, batch in enumerate(dataloader):
+ with torch.no_grad():
+ outputs = model(**batch)
+ predictions = outputs.logits.argmax(dim=-1)
+ references = batch["labels"]
+ if accelerator is not None and accelerator.num_processes > 1:
+ predictions, references = accelerator.gather_for_metrics((predictions, references))
+ metric.add_batch(predictions=predictions, references=references)
+ return metric.compute()
diff --git a/benchmarks/fp8/torchao/fsdp.py b/benchmarks/fp8/torchao/fsdp.py
new file mode 100644
index 00000000000..42eedb48bd5
--- /dev/null
+++ b/benchmarks/fp8/torchao/fsdp.py
@@ -0,0 +1,173 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script tests to ensure that `accelerate` performs at the same level as raw `torchao`.
+
+This particular script verifies this for FSDP training.
+"""
+
+from functools import partial
+
+import evaluate
+import torch
+from fp8_utils import get_training_utilities
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import MixedPrecision
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+from torchao.float8 import convert_to_float8_training
+from transformers.models.bert import BertLayer
+
+from accelerate import Accelerator
+from accelerate import FullyShardedDataParallelPlugin as FSDPPlugin
+from accelerate.state import AcceleratorState
+from accelerate.utils import AORecipeKwargs, set_seed
+
+
+MODEL_NAME = "bert-base-cased"
+METRIC = evaluate.load("glue", "mrpc")
+
+FSDP_WRAP_POLICY = partial(transformer_auto_wrap_policy, transformer_layer_cls={BertLayer})
+
+
+def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None):
+ if isinstance(module, torch.nn.Linear):
+ if module.in_features % 16 != 0 or module.out_features % 16 != 0:
+ return False
+ # For stability reasons, we skip the first and last linear layers
+ # Otherwise, this can lead to the model not training or converging properly
+ if fqn in (first_layer_name, last_layer_name):
+ return False
+ return True
+
+
+def evaluate_model(model, dataloader, metric, accelerator=None):
+ "Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on"
+ model.eval()
+ for step, batch in enumerate(dataloader):
+ with torch.no_grad():
+ outputs = model(**batch)
+ predictions = outputs.logits.argmax(dim=-1)
+ references = batch["labels"]
+ if accelerator is not None and accelerator.num_processes > 1:
+ predictions, references = accelerator.gather_for_metrics((predictions, references))
+ metric.add_batch(predictions=predictions, references=references)
+ return metric.compute()
+
+
+def train_baseline():
+ set_seed(42)
+ model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
+ first_linear = None
+ last_linear = None
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ if first_linear is None:
+ first_linear = name
+ last_linear = name
+ func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear)
+ accelerator = Accelerator()
+ device = accelerator.device
+ model.to(device)
+
+ convert_to_float8_training(model, module_filter_fn=func)
+
+ # Convert the model to FSDP
+ model = FSDP(
+ model,
+ use_orig_params=True,
+ mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
+ auto_wrap_policy=FSDP_WRAP_POLICY,
+ )
+
+ base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
+ model.train()
+
+ for batch in train_dataloader:
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+ batch = batch.to(device)
+ outputs = model(**batch)
+ loss = outputs.loss
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+ lr_scheduler.step()
+
+ trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
+
+ assert (
+ trained_model_results["accuracy"] > base_model_results["accuracy"]
+ ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
+ assert (
+ trained_model_results["f1"] > base_model_results["f1"]
+ ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
+
+ return base_model_results, trained_model_results
+
+
+def train_integration():
+ AcceleratorState()._reset_state(True)
+ fsdp_plugin = FSDPPlugin(
+ auto_wrap_policy=FSDP_WRAP_POLICY,
+ use_orig_params=True,
+ mixed_precision_policy=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
+ )
+ accelerator = Accelerator(mixed_precision="fp8", fsdp_plugin=fsdp_plugin, kwargs_handlers=[AORecipeKwargs()])
+ set_seed(42)
+ model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
+ MODEL_NAME, accelerator=accelerator
+ )
+
+ model, optimizer = accelerator.prepare(model, optimizer)
+ base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
+ model.train()
+
+ for batch in train_dataloader:
+ outputs = model(**batch)
+ loss = outputs.loss
+ accelerator.backward(loss)
+ optimizer.step()
+ optimizer.zero_grad()
+ lr_scheduler.step()
+
+ trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
+
+ assert (
+ trained_model_results["accuracy"] > base_model_results["accuracy"]
+ ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
+ assert (
+ trained_model_results["f1"] > base_model_results["f1"]
+ ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
+
+ return base_model_results, trained_model_results
+
+
+if __name__ == "__main__":
+ baseline_not_trained, baseline_trained = train_baseline()
+ accelerator_not_trained, accelerator_trained = train_integration()
+
+ assert (
+ baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
+ ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
+ assert (
+ baseline_not_trained["f1"] == accelerator_not_trained["f1"]
+ ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
+ assert (
+ baseline_trained["accuracy"] == accelerator_trained["accuracy"]
+ ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
+ assert (
+ baseline_trained["f1"] == accelerator_trained["f1"]
+ ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'
+
+ torch.distributed.destroy_process_group()
diff --git a/benchmarks/fp8/torchao/non_distributed.py b/benchmarks/fp8/torchao/non_distributed.py
new file mode 100644
index 00000000000..7b8e5993e42
--- /dev/null
+++ b/benchmarks/fp8/torchao/non_distributed.py
@@ -0,0 +1,145 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script tests to ensure that `accelerate` performs at the same level as raw `torchao`.
+
+This particular script verifies this for single GPU training.
+"""
+
+from functools import partial
+
+import evaluate
+import torch
+from fp8_utils import get_training_utilities
+from torchao.float8 import convert_to_float8_training
+
+from accelerate import Accelerator
+from accelerate.state import AcceleratorState
+from accelerate.utils import AORecipeKwargs, set_seed
+
+
+MODEL_NAME = "bert-base-cased"
+METRIC = evaluate.load("glue", "mrpc")
+
+
+def evaluate_model(model, dataloader, metric, accelerator=None):
+ "Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on"
+ model.eval()
+ for step, batch in enumerate(dataloader):
+ with torch.no_grad():
+ outputs = model(**batch)
+ predictions = outputs.logits.argmax(dim=-1)
+ references = batch["labels"]
+ if accelerator is not None and accelerator.num_processes > 1:
+ predictions, references = accelerator.gather_for_metrics((predictions, references))
+ metric.add_batch(predictions=predictions, references=references)
+ return metric.compute()
+
+
+def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None):
+ if isinstance(module, torch.nn.Linear):
+ if module.in_features % 16 != 0 or module.out_features % 16 != 0:
+ return False
+ # For stability reasons, we skip the first and last linear layers
+ # Otherwise, this can lead to the model not training or converging properly
+ if fqn in (first_layer_name, last_layer_name):
+ return False
+ return True
+
+
+def train_baseline():
+ set_seed(42)
+ model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
+ first_linear = None
+ last_linear = None
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ if first_linear is None:
+ first_linear = name
+ last_linear = name
+
+ func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear)
+ model.to("cuda")
+ convert_to_float8_training(model, module_filter_fn=func)
+ base_model_results = evaluate_model(model, eval_dataloader, METRIC)
+ model.train()
+
+ for batch in train_dataloader:
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+ outputs = model(**batch)
+ loss = outputs.loss
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+ lr_scheduler.step()
+
+ trained_model_results = evaluate_model(model, eval_dataloader, METRIC)
+
+ assert (
+ trained_model_results["accuracy"] > base_model_results["accuracy"]
+ ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
+ assert (
+ trained_model_results["f1"] > base_model_results["f1"]
+ ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
+
+ return base_model_results, trained_model_results
+
+
+def train_integration():
+ set_seed(42)
+ accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()])
+ model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
+ MODEL_NAME, accelerator=accelerator
+ )
+ model = accelerator.prepare(model)
+ base_model_results = evaluate_model(model, eval_dataloader, METRIC)
+ model.train()
+
+ for batch in train_dataloader:
+ outputs = model(**batch)
+ loss = outputs.loss
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+ lr_scheduler.step()
+
+ trained_model_results = evaluate_model(model, eval_dataloader, METRIC)
+
+ assert (
+ trained_model_results["accuracy"] > base_model_results["accuracy"]
+ ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
+ assert (
+ trained_model_results["f1"] > base_model_results["f1"]
+ ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
+
+ return base_model_results, trained_model_results
+
+
+if __name__ == "__main__":
+ baseline_not_trained, baseline_trained = train_baseline()
+ AcceleratorState._reset_state(True)
+ accelerator_not_trained, accelerator_trained = train_integration()
+ assert (
+ baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
+ ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
+ assert (
+ baseline_not_trained["f1"] == accelerator_not_trained["f1"]
+ ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
+ assert (
+ baseline_trained["accuracy"] == accelerator_trained["accuracy"]
+ ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
+ assert (
+ baseline_trained["f1"] == accelerator_trained["f1"]
+ ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'
diff --git a/benchmarks/fp8/transformer_engine/distrib_deepspeed.py b/benchmarks/fp8/transformer_engine/distrib_deepspeed.py
index e678deb3659..73953b6793f 100644
--- a/benchmarks/fp8/transformer_engine/distrib_deepspeed.py
+++ b/benchmarks/fp8/transformer_engine/distrib_deepspeed.py
@@ -66,7 +66,7 @@ def train_baseline(zero_stage: int = 1):
import numpy as np
config = {
- "train_batch_size": 32,
+ "train_batch_size": 16,
"train_micro_batch_size_per_gpu": 16,
"gradient_accumulation_steps": 1,
"zero_optimization": {
@@ -170,21 +170,22 @@ def train_integration(zero_stage: int = 1):
if __name__ == "__main__":
- # for zero_stage in [1, 2, 3]:
- zero_stage = 1
- baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage)
- accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage)
- assert (
- baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
- ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
- assert (
- baseline_not_trained["f1"] == accelerator_not_trained["f1"]
- ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
- assert (
- baseline_trained["accuracy"] == accelerator_trained["accuracy"]
- ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
- assert (
- baseline_trained["f1"] == accelerator_trained["f1"]
- ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'
-
- torch.distributed.destroy_process_group()
+ for zero_stage in [1, 2, 3]:
+ baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage)
+ accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(
+ zero_stage
+ )
+ assert (
+ baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
+ ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
+ assert (
+ baseline_not_trained["f1"] == accelerator_not_trained["f1"]
+ ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
+ assert (
+ baseline_trained["accuracy"] == accelerator_trained["accuracy"]
+ ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
+ assert (
+ baseline_trained["f1"] == accelerator_trained["f1"]
+ ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'
+
+ torch.distributed.destroy_process_group()
diff --git a/docs/source/usage_guides/low_precision_training.md b/docs/source/usage_guides/low_precision_training.md
index c730136e1ce..08e533e60cf 100644
--- a/docs/source/usage_guides/low_precision_training.md
+++ b/docs/source/usage_guides/low_precision_training.md
@@ -15,7 +15,7 @@ rendered properly in your Markdown viewer.
# Low Precision Training Methods
-Accelerate provides integrations to train on lower precision methods using specified supported hardware through the `TransformersEngine` and `MS-AMP` packages. This documentation will help guide you through what hardware is supported, how to configure your [`Accelerator`] to leverage the low precision methods, and what you can expect when training.
+Accelerate provides integrations to train on lower precision methods using specified supported hardware through the `TransformersEngine`, `MS-AMP`, and `torchao` packages. This documentation will help guide you through what hardware is supported, how to configure your [`Accelerator`] to leverage the low precision methods, and what you can expect when training.
## What training on FP8 means
@@ -30,7 +30,7 @@ What this will result in is some gain in the memory used (as we've cut the neede
## Configuring the Accelerator
-Currently two different backends for FP8 are supported (`TransformersEngine` and `MS-AMP`), each with different capabilities and configurations.
+Currently three different backends for FP8 are supported (`TransformersEngine`, `torchao`, and `MS-AMP`), each with different capabilities and configurations.
To use any of them, the same core API is used. Just pass `mixed_precision="fp8"` to the [`Accelerator`], during `accelerate config` when prompted about mixed precision, or as part of your `config.yaml` file in the `mixed_precision` key:
@@ -39,14 +39,16 @@ from accelerate import Accelerator
accelerator = Accelerator(mixed_precision="fp8")
```
-By default, if `MS-AMP` is available in your environment, Accelerate will automatically utilize it as a backend. To specify it yourself (and customize other parts of the FP8 mixed precision setup), you can utilize the [`utils.FP8RecipeKwargs`] or clarify it in your config `yaml`/during `accelerate launch`:
+By default, if a compatible backend is available in your environment, Accelerate will automatically pick one (checking for `torchao`, then `TransformersEngine`, then `MS-AMP`). To specify a backend yourself (and customize other parts of the FP8 mixed precision setup), you can utilize one of the `RecipeKwargs` dataclasses such as [`utils.AORecipeKwargs`], [`utils.TERecipeKwargs`], or [`utils.MSAMPRecipeKwargs`]; you can also specify it in your config `yaml`/during `accelerate launch`:
```{python}
from accelerate import Accelerator
-from accelerate.utils import FP8RecipeKwargs
-kwargs = [FP8RecipeKwargs(backend="msamp")]
+from accelerate.utils import MSAMPRecipeKwargs
+kwargs = [MSAMPRecipeKwargs()]
# Or to specify the backend as `TransformersEngine` even if MS-AMP is installed
-# kwargs = [FP8RecipeKwargs(backend="te")]
+# kwargs = [TERecipeKwargs()]
+# Or to use torchao
+# kwargs = [AORecipeKwargs()]
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs)
```
@@ -124,6 +126,22 @@ fp8_config:
use_autocast_during_eval: false
```
+## Configuring `torchao`
+
+`torchao` is a [PyTorch-driven](https://github.com/pytorch/ao/tree/main/torchao/float8) hackable FP8 backend, aiming to be more approachable than the prior two engines. One of the core differences compared to them is that, for numerical stability, it is generally better to keep the first *and* last layers of the model at the regular precision (be it FP32 or BF16) and quantize only the remaining layers down to FP8. As a result, a config for `ao` looks a bit different:
+
+> Note: this API is experimental and is subject to change
+
+```{python}
+from accelerate import Accelerator
+from accelerate.utils import AORecipeKwargs
+kwargs = [AORecipeKwargs()]
+accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs)
+```
+
+To learn more about the specific parameters to be used, please see the official `torchao` repo.
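+
+As a minimal sketch of customizing the recipe, a filter function can be passed through `module_filter_func` (it receives each module and its fully qualified name and returns whether that module should be converted; the `classifier` name below is a hypothetical example):
+
+```{python}
+import torch
+
+from accelerate import Accelerator
+from accelerate.utils import AORecipeKwargs
+
+
+def module_filter_func(module, fqn):
+    # Only `nn.Linear` layers with dimensions divisible by 16 can be converted to FP8
+    if isinstance(module, torch.nn.Linear) and (module.in_features % 16 != 0 or module.out_features % 16 != 0):
+        return False
+    # Hypothetical: keep the classification head at the original precision
+    if "classifier" in fqn:
+        return False
+    return True
+
+
+kwargs = [AORecipeKwargs(module_filter_func=module_filter_func)]
+accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs)
+```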
+
+
## Example Zoo
We have examples showcasing training with FP8 both with accelerate and its underlying implementation available in the accelerate repo.
@@ -143,3 +161,4 @@ To learn more about training in FP8 please check out the following resources:
* [Our concept guide](../concept_guides/low_precision_training) detailing into more about both TransformersEngine and MS-AMP
* [The `transformers-engine` documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html)
* [The `MS-AMP` documentation](https://azure.github.io/MS-AMP/docs/)
+* [The `torchao` documentation](https://github.com/pytorch/ao/tree/main/torchao/float8)
diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index a483f0d1a39..7d05dafb600 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -33,6 +33,8 @@
import torch.utils.hooks as hooks
from huggingface_hub import split_torch_state_dict_into_shards
+from accelerate.utils.imports import is_torchao_available
+
from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
from .logging import get_logger
@@ -48,6 +50,7 @@
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
WEIGHTS_PATTERN_NAME,
+ AORecipeKwargs,
AutocastKwargs,
DataLoaderConfiguration,
DeepSpeedPlugin,
@@ -62,10 +65,12 @@
KwargsHandler,
LoggerType,
MegatronLMPlugin,
+ MSAMPRecipeKwargs,
PrecisionType,
ProfileKwargs,
ProjectConfiguration,
RNGType,
+ TERecipeKwargs,
TorchDynamoPlugin,
TorchTensorParallelPlugin,
apply_fp8_autowrap,
@@ -73,6 +78,7 @@
clean_state_dict_for_safetensors,
compare_versions,
convert_model,
+ convert_model_to_fp8_ao,
convert_outputs_to_fp32,
ensure_weights_retied,
extract_model_from_parallel,
@@ -409,45 +415,39 @@ def __init__(
self.scaler_handler = None
self.init_handler = None
self.fp8_recipe_handler = None
+ self.ao_recipe_handler = None
+ self.te_recipe_handler = None
+ self.msamp_recipe_handler = None
self.autocast_handler = None
self.profile_handler = None
self.has_lomo_optimizer = False
+ found_handlers = set()
+ handler_class_to_attr = {
+ DistributedDataParallelKwargs: "ddp_handler",
+ GradScalerKwargs: "scaler_handler",
+ InitProcessGroupKwargs: "init_handler",
+ FP8RecipeKwargs: "fp8_recipe_handler",
+ AutocastKwargs: "autocast_handler",
+ ProfileKwargs: "profile_handler",
+ AORecipeKwargs: "ao_recipe_handler",
+ TERecipeKwargs: "te_recipe_handler",
+ MSAMPRecipeKwargs: "msamp_recipe_handler",
+ }
+ self.has_fp8_handler = False
if kwargs_handlers is not None:
for handler in kwargs_handlers:
assert isinstance(
handler, KwargsHandler
), f"Unsupported kwargs handler passed: {handler}, must be one that inherits `accelerate.utils.KwargsHandler`."
- if isinstance(handler, DistributedDataParallelKwargs):
- if self.ddp_handler is not None:
- raise ValueError("You can only pass one `DistributedDataParallelKwargs` in `kwargs_handler`.")
- else:
- self.ddp_handler = handler
- elif isinstance(handler, GradScalerKwargs):
- if self.scaler_handler is not None:
- raise ValueError("You can only pass one `GradScalerKwargs` in `kwargs_handler`.")
- else:
- self.scaler_handler = handler
- elif isinstance(handler, InitProcessGroupKwargs):
- if self.init_handler is not None:
- raise ValueError("You can only pass one `InitProcessGroupKwargs` in `kwargs_handler`.")
- else:
- self.init_handler = handler
- elif isinstance(handler, FP8RecipeKwargs):
- if self.fp8_recipe_handler is not None:
- raise ValueError("You can only pass one `FP8RecipeKwargs` in `kwargs_handler`.")
- else:
- self.fp8_recipe_handler = handler
- elif isinstance(handler, AutocastKwargs):
- if self.autocast_handler is not None:
- raise ValueError("You can only pass one `AutocastKwargs` in `kwargs_handler`.")
- else:
- self.autocast_handler = handler
- elif isinstance(handler, ProfileKwargs):
- if self.profile_handler is not None:
- raise ValueError("You can only pass one `ProfileKwargs` in `kwargs_handler`.")
- else:
- self.profile_handler = handler
+ # Add the handler class to the set of found handlers
+ if handler.__class__ in found_handlers:
+ raise ValueError(f"You can only pass one {handler.__class__} in `kwargs_handlers`.")
+ found_handlers.add(handler.__class__)
+ handler_attr = handler_class_to_attr[handler.__class__]
+ setattr(self, handler_attr, handler)
+ if "recipe_handler" in handler_attr and not self.has_fp8_handler:
+ self.has_fp8_handler = True
kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {}
self.state = AcceleratorState(
@@ -463,17 +463,32 @@ def __init__(
)
self._mixed_precision = mixed_precision
- if mixed_precision == "fp8" and self.fp8_recipe_handler is None:
- self.fp8_recipe_handler = FP8RecipeKwargs()
+ # Check for automatic FP8 recipe creation
+ if self._mixed_precision == "fp8" and not self.has_fp8_handler:
+ # Prioritize AO -> TE -> MSAMP
+ if is_torchao_available():
+ logger.info("Found `torchao` installed, using it for FP8 training.")
+ self.ao_recipe_handler = AORecipeKwargs()
+ elif is_transformer_engine_available():
+ logger.info("Found `transformer-engine` installed, using it for FP8 training.")
+ self.te_recipe_handler = TERecipeKwargs()
+ elif is_msamp_available():
+ logger.info("Found `msamp` installed, using it for FP8 training.")
+ self.msamp_recipe_handler = MSAMPRecipeKwargs()
+ else:
+ raise ImportError(
+ "Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed. "
+ "Valid backends are: `torchao`, `transformer-engine`, and `msamp`."
+ )
self.delayed_fp8_autocast = False
- if self.fp8_recipe_handler is not None:
+ if self.has_fp8_handler:
# We already check if FP8 is available during `self.state`
if mixed_precision != "fp8" and (
self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED)
):
- raise ValueError("Passing in a `FP8RecipeKwargs` object requires setting `mixed_precision='fp8'`.")
- self.delayed_fp8_autocast = self.fp8_recipe_handler.backend == "TE" and self.distributed_type in (
+ raise ValueError("Passing in an FP8 configuration requires setting `mixed_precision='fp8'`.")
+ self.delayed_fp8_autocast = self.fp8_backend == "TE" and self.distributed_type in (
DistributedType.MULTI_GPU,
DistributedType.FSDP,
)
@@ -1362,6 +1377,8 @@ def prepare(self, *args, device_placement=None):
args = self._prepare_ipex_or_xpu(*args)
if self.fp8_backend == "TE":
args = self._prepare_te(*args)
+ elif self.fp8_backend == "AO":
+ args = self._prepare_ao(*args)
if self.distributed_type == DistributedType.DEEPSPEED:
result = self._prepare_deepspeed(*args)
elif self.distributed_type == DistributedType.MEGATRON_LM:
@@ -1447,7 +1464,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
# We prepare TE after, allowing for bf16 autocast to happen first
if self.fp8_backend == "TE" and not self.delayed_fp8_autocast:
- model = apply_fp8_autowrap(model, self.fp8_recipe_handler)
+ model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)
if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr(
model, "hf_device_map", False
@@ -1651,12 +1668,26 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
model = xmp.MpModelWrapper(model).to(self.device)
# Now we can apply the FP8 autocast
if self.delayed_fp8_autocast:
- model = apply_fp8_autowrap(model, self.fp8_recipe_handler)
+ model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)
# torch.compile should be called last and only if the model isn't already compiled.
if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
return model
+ def _prepare_ao(self, *args):
+ if not is_torchao_available():
+ raise ImportError(
+ "`torchao` was not found on your system or is too old of a version. Please ensure that `torchao >= 0.6.1` is installed"
+ )
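+ # Convert any model passed in to FP8 in place; all other args are returned untouched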
+ for arg in args:
+ if isinstance(arg, torch.nn.Module):
+ convert_model_to_fp8_ao(
+ arg,
+ config=self.ao_recipe_handler.config,
+ module_filter_func=self.ao_recipe_handler.module_filter_func,
+ )
+ return args
+
def _prepare_te(self, *args):
if not is_transformer_engine_available():
raise ImportError(
@@ -1811,7 +1842,7 @@ def _prepare_deepspeed(self, *args):
if model is not None:
# If we are using FP8, we need to apply the autowrap now
- if getattr(self.fp8_recipe_handler, "backend", None) == "TE":
+ if self.fp8_backend == "TE":
model = apply_fp8_autowrap(model, self.fp8_recipe_handler)
# if the model is an MOE, set the appropriate MOE layers as leaf Z3 modules
deepspeed_plugin.set_moe_leaf_modules(model)
@@ -2106,7 +2137,12 @@ def _prepare_msamp(self, *args, device_placement):
f"You can't use multiple models ({num_models}) or optimizers {num_optimizers} with MS-AMP."
)
else:
- model, optimizer = msamp.initialize(model, optimizer, opt_level=self.fp8_recipe_handler.opt_level)
+ # DEPRECATE @ 2.0
+ if self.fp8_recipe_handler is not None:
+ opt_level = self.fp8_recipe_handler.opt_level
+ else:
+ opt_level = self.msamp_recipe_handler.opt_level
+ model, optimizer = msamp.initialize(model, optimizer, opt_level=opt_level)
for i in range(len(result)):
if isinstance(result[i], torch.nn.Module):
result[i] = model
@@ -3647,8 +3683,15 @@ def lomo_backward(self, loss: torch.Tensor, learning_rate: float) -> None:
@property
def fp8_backend(self):
"Returns the configured backend for training in FP8"
- if self._mixed_precision == "fp8" and self.fp8_recipe_handler is not None:
- return self.fp8_recipe_handler.backend
+ if self.has_fp8_handler:
+ if self.fp8_recipe_handler is not None:
+ return self.fp8_recipe_handler.backend
+ elif self.ao_recipe_handler is not None:
+ return "AO"
+ elif self.te_recipe_handler is not None:
+ return "TE"
+ elif self.msamp_recipe_handler is not None:
+ return "MSAMP"
elif self.state.deepspeed_plugin is not None and self.state.deepspeed_plugin.enable_msamp:
return "MSAMP"
return None
diff --git a/src/accelerate/test_utils/__init__.py b/src/accelerate/test_utils/__init__.py
index f41473f6363..6b59fb0246d 100644
--- a/src/accelerate/test_utils/__init__.py
+++ b/src/accelerate/test_utils/__init__.py
@@ -41,6 +41,7 @@
require_single_gpu,
require_single_xpu,
require_torch_min_version,
+ require_torchao,
require_torchvision,
require_tpu,
require_transformer_engine,
diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py
index bbcfc616ada..48350b0a4a7 100644
--- a/src/accelerate/test_utils/testing.py
+++ b/src/accelerate/test_utils/testing.py
@@ -53,6 +53,7 @@
is_timm_available,
is_torch_version,
is_torch_xla_available,
+ is_torchao_available,
is_torchdata_stateful_dataloader_available,
is_torchvision_available,
is_transformer_engine_available,
@@ -425,6 +426,13 @@ def require_transformer_engine(test_case):
return unittest.skipUnless(is_transformer_engine_available(), "test requires transformers engine")(test_case)
+def require_torchao(test_case):
+ """
+ Decorator marking a test that requires torchao installed. These tests are skipped when torchao isn't installed
+ """
+ return unittest.skipUnless(is_torchao_available(), "test requires torchao")(test_case)
+
+
_atleast_one_tracker_available = (
any([is_wandb_available(), is_tensorboard_available()]) and not is_comet_ml_available()
)
diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py
index e0ea5841372..6b5bed62d73 100644
--- a/src/accelerate/utils/__init__.py
+++ b/src/accelerate/utils/__init__.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from .ao import convert_model_to_fp8_ao, filter_first_and_last_linear_layers, has_ao_layers
from .constants import (
MITA_PROFILING_AVAILABLE_PYTORCH_VERSION,
MODEL_NAME,
@@ -32,6 +33,7 @@
XPU_PROFILING_AVAILABLE_PYTORCH_VERSION,
)
from .dataclasses import (
+ AORecipeKwargs,
AutocastKwargs,
BnbQuantizationConfig,
ComputeEnvironment,
@@ -50,12 +52,14 @@
KwargsHandler,
LoggerType,
MegatronLMPlugin,
+ MSAMPRecipeKwargs,
PrecisionType,
ProfileKwargs,
ProjectConfiguration,
RNGType,
SageMakerDistributedType,
TensorInformation,
+ TERecipeKwargs,
TorchDynamoPlugin,
TorchTensorParallelPlugin,
add_model_config_to_megatron_parser,
@@ -115,6 +119,7 @@
is_tensorboard_available,
is_timm_available,
is_torch_xla_available,
+ is_torchao_available,
is_torchdata_available,
is_torchdata_stateful_dataloader_available,
is_torchvision_available,
@@ -124,6 +129,7 @@
is_wandb_available,
is_weights_only_available,
is_xpu_available,
+ torchao_required,
)
from .modeling import (
align_module_device,
diff --git a/src/accelerate/utils/ao.py b/src/accelerate/utils/ao.py
new file mode 100644
index 00000000000..2023371ca66
--- /dev/null
+++ b/src/accelerate/utils/ao.py
@@ -0,0 +1,139 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Needed utilities for torchao FP8 training.
+"""
+
+from functools import partial
+from typing import Callable, List, Optional
+
+import torch
+
+from .imports import is_torchao_available, torchao_required
+
+
+if is_torchao_available():
+ from torchao.float8.float8_linear import Float8LinearConfig
+
+
+def find_first_last_linear_layers(model: torch.nn.Module):
+ """
+ Finds the first and last linear layer names in a model.
+
+ This is needed during FP8 to avoid issues with instability by keeping the first and last layers unquantized.
+
+ Ref: https://x.com/xariusrke/status/1826669142604141052
+ """
+ first_linear, last_linear = None, None
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ if first_linear is None:
+ first_linear = name
+ last_linear = name
+ return first_linear, last_linear
+
+
+def filter_linear_layers(module, fqn: str, layers_to_filter: List[str]) -> bool:
+ """
+ A function which will check if `module` is:
+ - a `torch.nn.Linear` layer
+ - has in_features and out_features divisible by 16
+ - is not part of `layers_to_filter`
+
+ Args:
+ module (`torch.nn.Module`):
+ The module to check.
+ fqn (`str`):
+ The fully qualified name of the layer.
+ layers_to_filter (`List[str]`):
+ The list of layers to filter.
+ """
+ if isinstance(module, torch.nn.Linear):
+ if module.in_features % 16 != 0 or module.out_features % 16 != 0:
+ return False
+ if fqn in layers_to_filter:
+ return False
+ return True
+
+
+def filter_first_and_last_linear_layers(module, fqn: str) -> bool:
+ """
+ A filter function which excludes the first and last linear layers from FP8 conversion.
+
+ For stability reasons, we skip the first and last linear layers; otherwise, this can lead to the model not
+ training or converging properly.
+
+ Args:
+ module (`torch.nn.Module`):
+ The module to check.
+ fqn (`str`):
+ The fully qualified name of the layer.
+ """
+ first_linear, last_linear = find_first_last_linear_layers(module)
+ return filter_linear_layers(module, fqn, layers_to_filter=[first_linear, last_linear])
+
+
+@torchao_required
+def has_ao_layers(model: torch.nn.Module):
+ from torchao.float8.float8_linear import Float8Linear
+
+ for name, module in model.named_modules():
+ if isinstance(module, Float8Linear):
+ return True
+ return False
+
+
+@torchao_required
+def convert_model_to_fp8_ao(
+ model: torch.nn.Module,
+ config: Optional["Float8LinearConfig"] = None,
+ module_filter_func: Optional[Callable] = filter_first_and_last_linear_layers,
+):
+ """
+ Converts all `nn.Linear` layers in the model (except the first and last) to torchao's `Float8Linear` layer inplace.
+
+ Args:
+ model (`torch.nn.Module`):
+ The model to convert.
+ config (`torchao.float8.Float8LinearConfig`, *optional*):
+ The configuration for the FP8 training. Recommended to utilize
+ `torchao.float8.recipe_name_to_linear_config` to generate this. In general, the default config should be
+ sufficient (what is passed when set to `None`).
+ module_filter_func (`Callable`, *optional*, defaults to `filter_first_and_last_linear_layers`):
+ Optional function that takes in a module and its fully qualified name and returns a boolean indicating whether
+ the module should be converted to FP8. Defaults to `filter_first_and_last_linear_layers`. See it for an example.
+
+ Example:
+
+ ```python
+ from accelerate.utils.ao import convert_model_to_fp8_ao
+
+ model = MyModel()
+ model.to("cuda")
+ convert_model_to_fp8_ao(model)
+
+ model.train()
+ ```
+ """
+ from torchao.float8 import convert_to_float8_training
+
+ first_linear, last_linear = find_first_last_linear_layers(model)
+ if module_filter_func is None:
+ module_filter_func = partial(filter_linear_layers, layers_to_filter=[first_linear, last_linear])
+ convert_to_float8_training(model, module_filter_fn=module_filter_func, config=config)
diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index 3baa525d294..9936ee8c00c 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -20,12 +20,13 @@
import copy
import enum
import functools
+import logging
import os
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import timedelta
-from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union, get_args
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union, get_args
import torch
@@ -50,6 +51,13 @@
from .versions import compare_versions, is_torch_version
+if TYPE_CHECKING:
+ # Mock imports for type checking
+ from torchao.float8 import Float8LinearConfig
+
+logger = logging.getLogger(__name__)
+
+
class KwargsHandler:
"""
Internal mixin that implements a `to_kwargs()` method for a dataclass.
@@ -281,40 +289,48 @@ def __post_init__(self):
AmaxComputeAlgorithm = Literal["max", "most_recent"]
+# FP8 training recipe kwargs
+@dataclass
+class AORecipeKwargs(KwargsHandler):
+ """
+ Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision
+ training with `torchao` FP8.
+
+ Args:
+ config (`torchao.float8.Float8LinearConfig`, *optional*, default to `None`):
+ The configuration for the FP8 training. In general, the default config should be sufficient.
+ module_filter_func (`Callable`, *optional*, default to `None`):
+ Optional function that takes in a module and its fully qualified name and returns a boolean indicating whether
+ the module should be converted to FP8. Defaults to `accelerate.utils.ao.filter_linear_layers`. See it for an
+ example.
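+
+ Example (a minimal sketch, mirroring the usage in `tests/test_fp8.py`):
+
+ ```python
+ from accelerate import Accelerator
+ from accelerate.utils import AORecipeKwargs
+
+ kwargs = AORecipeKwargs()
+ accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[kwargs])
+ ```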
+ """
+
+ config: Optional["Float8LinearConfig"] = None
+ module_filter_func: Optional[Callable] = None
+
+
@dataclass
-class FP8RecipeKwargs(KwargsHandler):
+class TERecipeKwargs(KwargsHandler):
"""
Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision
- training with `transformer-engine` or `ms-amp`.
+ training with `transformer-engine`.
- For more information on `transformer-engine` args, please refer to the API
+ For more information on the args, please refer to the API
[documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html).
- For more information on the `ms-amp` args, please refer to the Optimization Level
- [documentation](https://azure.github.io/MS-AMP/docs/user-tutorial/optimization-level).
-
```python
from accelerate import Accelerator
- from accelerate.utils import FP8RecipeKwargs
+ from accelerate.utils import TERecipeKwargs
- kwargs = FP8RecipeKwargs(backend="te", fp8_format="HYBRID")
+ kwargs = TERecipeKwargs(fp8_format="HYBRID")
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[kwargs])
```
- To use MS-AMP as an engine, pass `backend="msamp"` and the `optimization_level`:
-
- ```python
- kwargs = FP8RecipeKwargs(backend="msamp", optimization_level="02")
- ```
-
Args:
- backend (`str`, *optional*):
- Which FP8 engine to use. Must be one of `"msamp"` (MS-AMP) or `"te"` (TransformerEngine). If not passed,
- will use whichever is available in the environment, prioritizing MS-AMP.
use_autocast_during_eval (`bool`, *optional*, default to `False`):
Whether to use FP8 autocast during eval mode. Generally better metrics are found when this is `False`.
margin (`int`, *optional*, default to 0):
@@ -330,21 +346,9 @@ class FP8RecipeKwargs(KwargsHandler):
The algorithm to use for the scaling factor computation. Must be one of `max` or `most_recent`.
override_linear_precision (`tuple` of three `bool`, *optional*, default to `(False, False, False)`):
Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision.
- optimization_level (`str`), one of `O1`, `O2`. (default is `O2`):
- What level of 8-bit collective communication should be used with MS-AMP. In general:
- * O1: Weight gradients and `all_reduce` communications are done in fp8, reducing GPU
- memory usage and communication bandwidth
- * O2: First-order optimizer states are in 8-bit, and second order states are in FP16.
- Only available when using Adam or AdamW. This maintains accuracy and can potentially save the
- highest memory.
- * 03: Specifically for DeepSpeed, implements capabilities so weights and master weights of models
- are stored in FP8. If `fp8` is selected and deepspeed is enabled, will be used by default. (Not
- available currently).
"""
- backend: Backend = None
use_autocast_during_eval: bool = None
- opt_level: OptLevel = None
margin: int = None
interval: int = None
fp8_format: FP8Format = None
@@ -354,50 +358,73 @@ class FP8RecipeKwargs(KwargsHandler):
def __post_init__(self):
env_prefix = "ACCELERATE_FP8_"
+ if not is_transformer_engine_available():
+ raise ImportError("TransformerEngine is not available. Please install it or use a different backend.")
+ if self.use_autocast_during_eval is None:
+ self.use_autocast_during_eval = parse_flag_from_env(env_prefix + "USE_AUTOCAST_DURING_EVAL")
+ if self.margin is None:
+ self.margin = int(os.environ.get(env_prefix + "MARGIN", 0))
+ if self.interval is None:
+ self.interval = int(os.environ.get(env_prefix + "INTERVAL", 1))
+ if self.fp8_format is None:
+ self.fp8_format = os.environ.get(env_prefix + "FORMAT", "HYBRID")
+ self.fp8_format = self.fp8_format.upper()
+ if self.fp8_format not in get_args(FP8Format):
+ raise ValueError(f"`fp8_format` must be one of {' or '.join(get_args(FP8Format))}.")
+ if self.amax_compute_algo is None:
+ self.amax_compute_algo = os.environ.get(env_prefix + "AMAX_COMPUTE_ALGO", "most_recent")
+ self.amax_compute_algo = self.amax_compute_algo.lower()
+ if self.amax_compute_algo not in get_args(AmaxComputeAlgorithm):
+ raise ValueError(f"`amax_compute_algo` must be one of {' or '.join(get_args(AmaxComputeAlgorithm))}")
+ if self.amax_history_len is None:
+ self.amax_history_len = int(os.environ.get(env_prefix + "AMAX_HISTORY_LEN", 1024))
+ if self.override_linear_precision is None:
+ fprop = parse_flag_from_env(env_prefix + "OVERRIDE_FPROP")
+ dgrad = parse_flag_from_env(env_prefix + "OVERRIDE_DGRAD")
+ wgrad = parse_flag_from_env(env_prefix + "OVERRIDE_WGRAD")
+ self.override_linear_precision = (fprop, dgrad, wgrad)
+
+
+@dataclass
+class MSAMPRecipeKwargs(KwargsHandler):
+ """
+ Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision
+ training with `ms-amp`.
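+
+ Args:
+ opt_level (`str`), one of `O1`, `O2` (default is `O2`):
+ What level of 8-bit collective communication should be used with MS-AMP. In general:
+ * O1: Weight gradients and `all_reduce` communications are done in fp8, reducing GPU
+ memory usage and communication bandwidth
+ * O2: First-order optimizer states are in 8-bit, and second order states are in FP16.
+ Only available when using Adam or AdamW. This maintains accuracy and can potentially save the
+ highest memory.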
+ """
+
+ opt_level: OptLevel = None
+
+ def __post_init__(self):
+ env_prefix = "ACCELERATE_FP8_"
+ if self.opt_level is None:
+ self.opt_level = os.environ.get(env_prefix + "OPT_LEVEL", "O2")
+ if self.opt_level not in get_args(OptLevel):
+ raise ValueError(f"`opt_level` must be one of {' or '.join(get_args(OptLevel))}")
+
+
+@dataclass
+class FP8RecipeKwargs(TERecipeKwargs, MSAMPRecipeKwargs):
+ """
+ Deprecated. Please use one of the proper FP8 recipe kwargs classes such as `TERecipeKwargs` or `MSAMPRecipeKwargs`
+ instead.
+ """
+
+ backend: Backend = None
+
+ def __post_init__(self):
+ env_prefix = "ACCELERATE_FP8_"
+ warnings.warn(
+ "FP8RecipeKwargs is deprecated and will be removed in Accelerate v2.0.0. "
+ "Please use one of the proper FP8 recipe kwargs classes such as TERecipeKwargs or MSAMPRecipeKwargs instead.",
+ FutureWarning,
+ )
default_backend = "msamp" if is_msamp_available() else "te"
if self.backend is None:
self.backend = os.environ.get(env_prefix + "BACKEND", default_backend)
self.backend = self.backend.upper()
if self.backend not in get_args(Backend):
- raise ValueError("`backend` must be 'MSAMP' or 'TE' (TransformerEngine).")
- # Check TE args
- if self.backend == "TE":
- if not is_transformer_engine_available():
- raise ValueError(
- "TransformerEngine is not available. Please either install it, or use the 'MSAMP' backend (if installed)."
- )
- if self.use_autocast_during_eval is None:
- self.use_autocast_during_eval = parse_flag_from_env(env_prefix + "USE_AUTOCAST_DURING_EVAL")
- if self.margin is None:
- self.margin = int(os.environ.get(env_prefix + "MARGIN", 0))
- if self.interval is None:
- self.interval = int(os.environ.get(env_prefix + "INTERVAL", 1))
- if self.fp8_format is None:
- self.fp8_format = os.environ.get(env_prefix + "FORMAT", "HYBRID")
- self.fp8_format = self.fp8_format.upper()
- if self.fp8_format not in get_args(FP8Format):
- raise ValueError(f"`fp8_format` must be one of {' or '.join(get_args(FP8Format))}.")
- if self.amax_compute_algo is None:
- self.amax_compute_algo = os.environ.get(env_prefix + "AMAX_COMPUTE_ALGO", "most_recent")
- self.amax_compute_algo = self.amax_compute_algo.lower()
- if self.amax_compute_algo not in get_args(AmaxComputeAlgorithm):
- raise ValueError(f"`amax_compute_algo` must be one of {' or '.join(get_args(AmaxComputeAlgorithm))}")
- if self.amax_history_len is None:
- self.amax_history_len = int(os.environ.get(env_prefix + "AMAX_HISTORY_LEN", 1024))
- if self.override_linear_precision is None:
- fprop = parse_flag_from_env(env_prefix + "OVERRIDE_FPROP")
- dgrad = parse_flag_from_env(env_prefix + "OVERRIDE_DGRAD")
- wgrad = parse_flag_from_env(env_prefix + "OVERRIDE_WGRAD")
- self.override_linear_precision = (fprop, dgrad, wgrad)
- elif self.backend == "MSAMP":
- if not is_msamp_available():
- raise ValueError(
- "MS-AMP is not available. Please either install it, or use the 'TE' backend (if installed)."
- )
- if self.opt_level is None:
- self.opt_level = os.environ.get(env_prefix + "OPT_LEVEL", "O2")
- if self.opt_level not in get_args(OptLevel):
- raise ValueError(f"`optimization_level` must be one of {' or '.join(get_args(OptLevel))}")
+ raise ValueError("`backend` must be 'MSAMP' or 'TE' (TransformerEngine) to use `FP8RecipeKwargs`.")
+ super().__post_init__()
# Literal
diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py
index b271dab9a9a..c103b41f737 100644
--- a/src/accelerate/utils/imports.py
+++ b/src/accelerate/utils/imports.py
@@ -110,7 +110,7 @@ def is_lomo_available():
def is_fp8_available():
- return is_msamp_available() or is_transformer_engine_available()
+ return is_msamp_available() or is_transformer_engine_available() or is_torchao_available()
def is_cuda_available():
@@ -142,6 +142,14 @@ def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):
return True
+def is_torchao_available():
+ package_exists = _is_package_available("torchao")
+ if package_exists:
+ torchao_version = version.parse(importlib.metadata.version("torchao"))
+ return compare_versions(torchao_version, ">=", "0.6.1")
+ return False
+
+
def is_deepspeed_available():
if is_mlu_available():
return _is_package_available("deepspeed", metadata_name="deepspeed-mlu")
@@ -422,6 +430,22 @@ def is_torchdata_stateful_dataloader_available():
return False
+def torchao_required(func):
+ """
+ A decorator that ensures the decorated function is only called when torchao is available.
+ """
+
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ if not is_torchao_available():
+ raise ImportError(
+ "`torchao` is not available, please install it before calling this function via `pip install torchao`."
+ )
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
# TODO: Rework this into `utils.deepspeed` and migrate the "core" chunks into `accelerate.deepspeed`
def deepspeed_required(func):
"""
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000000..8568c82be1c
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/test_fp8.py b/tests/test_fp8.py
index eb35f183b6a..7e3814c35f2 100644
--- a/tests/test_fp8.py
+++ b/tests/test_fp8.py
@@ -20,13 +20,26 @@
from accelerate import Accelerator
from accelerate.state import AcceleratorState
-from accelerate.test_utils import get_launch_command, require_cuda, require_multi_gpu, require_transformer_engine
+from accelerate.test_utils import (
+ get_launch_command,
+ require_cuda,
+ require_huggingface_suite,
+ require_multi_gpu,
+ require_torchao,
+ require_transformer_engine,
+)
from accelerate.test_utils.testing import require_deepspeed, run_command
-from accelerate.utils import FP8RecipeKwargs, has_transformer_engine_layers
+from accelerate.utils import (
+ AORecipeKwargs,
+ FP8RecipeKwargs,
+ has_ao_layers,
+ has_transformer_engine_layers,
+ is_torchao_available,
+ is_transformer_engine_available,
+)
-def can_convert_model():
- print("Starting basic_fp8_test")
+def can_convert_te_model():
accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [FP8RecipeKwargs(backend="TE")]}
accelerator = Accelerator(**accelerator_kwargs)
dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)
@@ -44,6 +57,20 @@ def maintain_proper_deepspeed_config(expected_version):
), f"Expected zero stage {expected_version} but got {AcceleratorState().deepspeed_plugin.zero_stage}"
+def can_convert_ao_model():
+ from transformers import AutoModelForSequenceClassification
+
+ accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [AORecipeKwargs()]}
+ accelerator = Accelerator(**accelerator_kwargs)
+ dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)
+ model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
+
+ model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
+ assert has_ao_layers(model)
+
+
@require_transformer_engine
class TestTransformerEngine(unittest.TestCase):
@require_cuda
@@ -91,7 +118,60 @@ def test_can_prepare_model_multigpu_deepspeed(self):
run_command(command)
+@require_torchao
+@require_huggingface_suite
+class TestTorchAO(unittest.TestCase):
+ @require_cuda
+ def test_can_prepare_model_single_gpu(self):
+ command = get_launch_command(num_processes=1, monitor_interval=0.1)
+ command += ["-m", "tests.test_fp8"]
+ run_command(command)
+
+ @require_multi_gpu
+ def test_can_prepare_model_multi_gpu(self):
+ command = get_launch_command(num_processes=2, monitor_interval=0.1)
+ command += ["-m", "tests.test_fp8"]
+ run_command(command)
+
+ @require_deepspeed
+ @require_multi_gpu
+ def test_can_prepare_model_multigpu_deepspeed(self):
+ for zero_stage in [1, 2, 3]:
+ os.environ["ZERO_STAGE"] = str(zero_stage)
+ ds_config = {
+ "bf16": {"enabled": True},
+ "zero_optimization": {
+ "stage": zero_stage,
+ "allgather_partitions": True,
+ "allgather_bucket_size": 2e8,
+ "overlap_comm": True,
+ "reduce_scatter": True,
+ "reduce_bucket_size": 2e8,
+ "contiguous_gradients": True,
+ },
+ "gradient_accumulation_steps": 1,
+ "gradient_clipping": "auto",
+ "steps_per_print": 2000,
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "wall_clock_breakdown": False,
+ }
+
+ ds_config = json.dumps(ds_config)
+
+ command = get_launch_command(
+ num_processes=2, monitor_interval=0.1, use_deepspeed=True, deepspeed_config_file=ds_config
+ )
+ command += ["-m", "tests.test_fp8"]
+ run_command(command)
+
+
if __name__ == "__main__":
- can_convert_model()
- if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
- maintain_proper_deepspeed_config(int(os.environ.get("ZERO_STAGE")))
+ # TE suite
+ if is_transformer_engine_available():
+ can_convert_te_model()
+ if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
+ maintain_proper_deepspeed_config(int(os.environ.get("ZERO_STAGE")))
+ # AO suite
+ if is_torchao_available():
+ can_convert_ao_model()