Torchao float8 training #3348

Open · wants to merge 22 commits into base: main

Changes from 21 commits
12 changes: 12 additions & 0 deletions benchmarks/fp8/torchao/Dockerfile
@@ -0,0 +1,12 @@
FROM nvcr.io/nvidia/pytorch:24.07-py3

RUN pip install transformers evaluate datasets
RUN git clone https://github.com/huggingface/accelerate.git

RUN cd accelerate && \
    pip install -e .

# A `cd` at the end of a RUN layer does not persist, so set the working
# directory explicitly (the NGC PyTorch base image defaults to /workspace)
WORKDIR /workspace/accelerate/benchmarks/fp8

# Use CMD instead of `RUN /bin/bash`, which would only run (and exit) during the build
CMD ["/bin/bash"]


32 changes: 32 additions & 0 deletions benchmarks/fp8/torchao/README.md
@@ -0,0 +1,32 @@
# FP8 Benchmarks

Comparing and running [torchao](https://github.com/pytorch/ao/tree/main/torchao/float8) FP8 training with `accelerate`

## Overview

This directory provides scripts that compare native `torchao` FP8 model training against `accelerate`'s own integration. Each training setup has its own script, covering the following:

* Single GPU training (`non_distributed.py`)
* Multi-GPU training via DistributedDataParallel (`ddp.py`)
* Fully Sharded Data Parallelism (`fsdp.py`)
* DeepSpeed ZeRO 1-3 (`deepspeed.py`)

To run them, it's recommended to use a Docker image (see the attached `Dockerfile`) rather than installing `torchao` manually.
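For example, to build the attached image locally (the image tag here is illustrative):

```bash
docker build -t accelerate-fp8-torchao -f benchmarks/fp8/torchao/Dockerfile .
```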

## Running

An official Docker image is available at `huggingface/accelerate:gpu-fp8-torchao-nightly` and can be used instead.
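To pull and start the prebuilt image (`--gpus all` assumes the NVIDIA Container Toolkit is installed):

```bash
docker run --gpus all -it --rm huggingface/accelerate:gpu-fp8-torchao-nightly
```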

You can run all scripts using the core `accelerate launch` command without any `accelerate config` being needed.

For single GPU, run it via `python`:

```bash
python non_distributed.py
```

For the rest, run it via `accelerate launch`:

```bash
accelerate launch ddp.py  # or fsdp.py, deepspeed.py
```
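To pin the number of processes explicitly (for example, on a 2-GPU node):

```bash
accelerate launch --num_processes 2 ddp.py
```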
158 changes: 158 additions & 0 deletions benchmarks/fp8/torchao/ddp.py
@@ -0,0 +1,158 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script tests to ensure that `accelerate` performs at the same level as raw `torchao`.

This particular script verifies this for DDP training.
"""

from functools import partial

import evaluate
import torch
from fp8_utils import get_training_utilities
from torch.nn.parallel import DistributedDataParallel as DDP
from torchao.float8 import convert_to_float8_training

from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.utils import AORecipeKwargs, set_seed


MODEL_NAME = "bert-base-cased"
METRIC = evaluate.load("glue", "mrpc")


def evaluate_model(model, dataloader, metric, accelerator=None):
"Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on"
model.eval()
for step, batch in enumerate(dataloader):
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
references = batch["labels"]
if accelerator is not None and accelerator.num_processes > 1:
predictions, references = accelerator.gather_for_metrics((predictions, references))
metric.add_batch(predictions=predictions, references=references)
return metric.compute()


def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None):
if isinstance(module, torch.nn.Linear):
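        # Float8 matmul kernels require both in_features and out_features to be divisible by 16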
if module.in_features % 16 != 0 or module.out_features % 16 != 0:
return False
[Review comment on lines +54 to +55, from a Member]: can you add a quick explanation for that?

    # For stability reasons, we skip the first and last linear layers;
    # otherwise the conversion can keep the model from training or converging properly
if fqn in (first_layer_name, last_layer_name):
return False
return True


def train_baseline():
set_seed(42)
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
first_linear = None
last_linear = None
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if first_linear is None:
first_linear = name
last_linear = name
func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear)
accelerator = Accelerator()
device = accelerator.device
model.to(device)

convert_to_float8_training(model, module_filter_fn=func)

    # Wrap the model in DDP (after the float8 conversion, so DDP wraps the swapped float8 modules)
device_ids, output_device = [accelerator.local_process_index], accelerator.local_process_index
model = DDP(model, device_ids=device_ids, output_device=output_device)

base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
model.train()

for batch in train_dataloader:
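        # Only nn.Linear modules were swapped to float8; remaining ops run under bf16 autocast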
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
batch = batch.to(device)
outputs = model(**batch)
loss = outputs.loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()

trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)

assert (
trained_model_results["accuracy"] > base_model_results["accuracy"]
), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
assert (
trained_model_results["f1"] > base_model_results["f1"]
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'

return base_model_results, trained_model_results


def train_integration():
AcceleratorState()._reset_state(True)
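    # AORecipeKwargs selects accelerate's torchao float8 recipe for fp8 mixed precision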
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()])
set_seed(42)
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
MODEL_NAME, accelerator=accelerator
)

model, optimizer = accelerator.prepare(model, optimizer)
base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
model.train()

for batch in train_dataloader:
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()

trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)

assert (
trained_model_results["accuracy"] > base_model_results["accuracy"]
), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
assert (
trained_model_results["f1"] > base_model_results["f1"]
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'

return base_model_results, trained_model_results


if __name__ == "__main__":
baseline_not_trained, baseline_trained = train_baseline()
accelerator_not_trained, accelerator_trained = train_integration()

assert (
baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
assert (
baseline_not_trained["f1"] == accelerator_not_trained["f1"]
), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
assert (
baseline_trained["accuracy"] == accelerator_trained["accuracy"]
), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
assert (
baseline_trained["f1"] == accelerator_trained["f1"]
), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'

torch.distributed.destroy_process_group()