Update whisper fine-tune example (#1049)
aksh-at authored Jan 16, 2025
1 parent af3b503 commit 0f5afc7
Showing 6 changed files with 80 additions and 175 deletions.
32 changes: 2 additions & 30 deletions 06_gpu_and_ml/openai_whisper/finetuning/readme.md
@@ -7,38 +7,10 @@ epochs should improve performance, decreasing WER.
You can benchmark this example's performance using Hugging Face's [**autoevaluate leaderboard**](https://huggingface.co/spaces/autoevaluate/leaderboards?dataset=mozilla-foundation%2Fcommon_voice_11_0&only_verified=0&task=automatic-speech-recognition&config=hi&split=test&metric=wer).
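WER (word error rate) is the fraction of reference words that must be edited to match the prediction, so lower is better. As a quick sanity check of the metric itself, here is a minimal sketch using the `evaluate` library this example already depends on (the strings are made up):

```python
import evaluate

# Load the same WER metric the training script uses for evaluation.
wer = evaluate.load("wer")
score = wer.compute(
    predictions=["namaste duniya"],
    references=["namaste duniya kaise ho"],
)
print(score)  # 0.5: two of the four reference words are missing
```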

```bash
python3 -m train \
--model_name_or_path="openai/whisper-small" \
--dataset_config_name="clean" \
--train_split_name="train.100" \
--eval_split_name="validation" \
--preprocessing_num_workers="16" \
--length_column_name="input_length" \
--num_train_epochs="5" \
--freeze_feature_encoder=False \
--per_device_train_batch_size="8" \
--per_device_eval_batch_size="8" \
--gradient_accumulation_steps="8" \
--learning_rate="3e-4" \
--warmup_steps="400" \
--evaluation_strategy="steps" \
--text_column_name="sentence" \
--save_steps="400" \
--eval_steps="400" \
--logging_steps="10" \
--save_total_limit="3" \
--freeze_feature_encoder=False \
--gradient_checkpointing \
--fp16 \
--group_by_length \
--predict_with_generate \
--generation_max_length="40" \
--generation_num_beams="1" \
--do_train --do_eval \
--do_lower_case
modal run train.train --num_train_epochs=10
```
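The single flag works because `modal run` builds a command-line interface from the target function's typed parameters. A rough sketch of the mechanism, mirroring the `train` signature that appears later in this diff:

```python
import modal

app = modal.App("example-whisper-fine-tune")

# `modal run train.train --num_train_epochs=10` binds the flag to the
# matching typed keyword argument below.
@app.function()
def train(num_train_epochs: int = 5, warmup_steps: int = 400, max_steps: int = -1):
    print(num_train_epochs, warmup_steps, max_steps)
```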

### Testing

Use `python3 -m train.end_to_end_check` to do a full train → serialize → save → load → predict
Use `modal run train.end_to_end_check` to do a full train → serialize → save → load → predict
run in less than 5 minutes, checking that the finetuning program is functional.
15 changes: 8 additions & 7 deletions 06_gpu_and_ml/openai_whisper/finetuning/requirements.txt
@@ -1,7 +1,8 @@
datasets>=1.18.0
evaluate
jiwer
librosa
torch>=1.5
torchaudio
transformers~=4.28.1
datasets~=3.2.0
evaluate~=0.4.3
jiwer~=3.0.5
librosa~=0.10.0
torch~=2.5.1
torchaudio~=2.5.1
transformers~=4.48.0
accelerate~=1.2.1
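Every dependency now uses the compatible-release operator `~=`, which allows patch updates but not minor-version bumps. A quick illustration using the `packaging` library (not a dependency of this example, shown only for demonstration):

```python
from packaging.specifiers import SpecifierSet

spec = SpecifierSet("~=4.48.0")  # as in transformers~=4.48.0
assert "4.48.2" in spec      # patch releases are accepted
assert "4.49.0" not in spec  # minor-version bumps are excluded
```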
2 changes: 0 additions & 2 deletions 06_gpu_and_ml/openai_whisper/finetuning/train/config.py
@@ -4,8 +4,6 @@

@dataclass
class ModalAppConfig:
app_name = "example-whisper-fine-tune"
persistent_vol_name = "example-whisper-fine-tune-vol"
dataset = "mozilla-foundation/common_voice_11_0"
cache_dir = "/cache"
model_dir = "/models"
65 changes: 13 additions & 52 deletions 06_gpu_and_ml/openai_whisper/finetuning/train/end_to_end_check.py
@@ -9,53 +9,13 @@

import pathlib

import modal
from transformers import Seq2SeqTrainingArguments

from .__main__ import app, train
from .config import DataTrainingArguments, ModelArguments, app_config
from .config import app_config
from .logs import get_logger
from .train import app, persistent_volume, train
from .transcribe import whisper_transcribe_audio

test_volume = modal.NetworkFileSystem.from_name(
"example-whisper-fine-tune-test-vol", create_if_missing=True
)

logger = get_logger(__name__)

# Test the `main.train` function by passing in test-specific configuration
# that does only a minimal amount of training steps and saves the model
# to the temporary (ie. ephemeral) network file system disk.
#
# This remote function should take only ~1 min to run.


@app.function(network_file_systems={app_config.model_dir: test_volume})
def test_finetune_one_step_and_save_to_vol(run_id: str):
output_dir = pathlib.Path(app_config.model_dir, run_id)
test_model_args = ModelArguments(
model_name_or_path="openai/whisper-small",
freeze_feature_encoder=False,
)
test_data_args = DataTrainingArguments(
preprocessing_num_workers=16,
max_train_samples=5,
max_eval_samples=5,
)

train(
model_args=test_model_args,
data_args=test_data_args,
training_args=Seq2SeqTrainingArguments(
do_train=True,
output_dir=output_dir,
num_train_epochs=1.0,
learning_rate=3e-4,
warmup_steps=0,
max_steps=1,
),
)


# Test model serialization and persistence by starting a new remote
# function that reads back the model files from the temporary network file system disk
@@ -66,7 +26,7 @@ def test_finetune_one_step_and_save_to_vol(run_id: str):
# ephemeral app that ran the training has stopped.


@app.function(network_file_systems={app_config.model_dir: test_volume})
@app.function(volumes={app_config.model_dir: persistent_volume})
def test_download_and_tryout_model(run_id: str):
from datasets import Audio, load_dataset
from evaluate import load
@@ -83,6 +43,7 @@ def test_download_and_tryout_model(run_id: str):
lang_short,
split="test",
streaming=True,
trust_remote_code=True,
)
ds = ds.cast_column("audio", Audio(sampling_rate=16_000))
test_row = next(iter(ds))
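The cast to 16 kHz matches the input rate Whisper's feature extractor expects; a small check of that assumption via the standard `transformers` API:

```python
from transformers import WhisperFeatureExtractor

fe = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
# Whisper consumes 16 kHz log-mel features, hence Audio(sampling_rate=16_000).
print(fe.sampling_rate)  # 16000
```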
@@ -115,12 +76,12 @@ def test_download_and_tryout_model(run_id: str):
# Any runtime errors or assertion errors will fail the app and exit non-zero.


def run_test() -> int:
with app.run():
test_finetune_one_step_and_save_to_vol.remote(run_id=app.app_id)
test_download_and_tryout_model.remote(run_id=app.app_id)
return 0


if __name__ == "__main__":
raise SystemExit(run_test())
@app.local_entrypoint()
def run_test():
# Test the `train` function by passing in test-specific configuration
# that runs only a minimal number of training steps and saves the model
# to the persistent volume.
#
# This should take only ~1 min to run.
train.remote(num_train_epochs=1.0, warmup_steps=0, max_steps=1)
test_download_and_tryout_model.remote(run_id=app.app_id)
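Because training now writes to a persistent `modal.Volume` rather than an ephemeral network file system, the artifacts outlive the app. A sketch of browsing a finished run's files afterwards (volume name as defined in this example; each top-level directory is one run's `app_id`):

```python
import modal

vol = modal.Volume.from_name("example-whisper-fine-tune-vol")
# List run directories at the volume root; Volume.listdir returns FileEntry
# objects with .path and .type attributes.
for entry in vol.listdir("/"):
    print(entry.path)
```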
06_gpu_and_ml/openai_whisper/finetuning/train/__main__.py → 06_gpu_and_ml/openai_whisper/finetuning/train/train.py
@@ -4,33 +4,25 @@
# Based on the work done in https://huggingface.co/blog/fine-tune-whisper.

import os
import pathlib
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Union

import modal

from .config import DataTrainingArguments, ModelArguments, app_config
from .logs import get_logger, setup_logging

try:
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
except ModuleNotFoundError:
exit(
"The 'transformers' library is required to run both locally and in Modal."
)


persistent_volume = modal.Volume.from_name(
app_config.persistent_vol_name,
"example-whisper-fine-tune-vol",
create_if_missing=True,
)
image = modal.Image.debian_slim().pip_install_from_requirements(
"requirements.txt"
)

image = modal.Image.debian_slim(
python_version="3.12"
).pip_install_from_requirements("requirements.txt")
app = modal.App(
name=app_config.app_name,
name="example-whisper-fine-tune",
image=image,
secrets=[
modal.Secret.from_name("huggingface-secret", required_keys=["HF_TOKEN"])
@@ -49,9 +41,10 @@
retries=1,
)
def train(
model_args: ModelArguments,
data_args: DataTrainingArguments,
training_args: Seq2SeqTrainingArguments,
num_train_epochs: int = 5,
warmup_steps: int = 400,
max_steps: int = -1,
overwrite_output_dir: bool = False,
):
import datasets
import evaluate
@@ -64,9 +57,51 @@ def train(
AutoProcessor,
AutoTokenizer,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process

model_args = ModelArguments(
model_name_or_path="openai/whisper-small",
freeze_feature_encoder=False,
)

run_id = app.app_id
output_dir = Path(app_config.model_dir, run_id).as_posix()

data_args = DataTrainingArguments(
dataset_config_name="clean",
train_split_name="train.100",
eval_split_name="validation",
text_column_name="sentence",
preprocessing_num_workers=16,
max_train_samples=5,
max_eval_samples=5,
do_lower_case=True,
)

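# Note (inferred from the values below): per_device_train_batch_size=8
# combined with gradient_accumulation_steps=8 gives an effective batch
# size of 8 x 8 = 64 examples per optimizer step on a single GPU.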
training_args = Seq2SeqTrainingArguments(
length_column_name="input_length",
output_dir=output_dir,
num_train_epochs=num_train_epochs,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
gradient_accumulation_steps=8,
learning_rate=3e-4,
warmup_steps=warmup_steps,
max_steps=max_steps,
evaluation_strategy="steps",
save_total_limit=3,
gradient_checkpointing=True,
fp16=True,
group_by_length=True,
predict_with_generate=True,
generation_max_length=40,
generation_num_beams=1,
do_train=True,
do_eval=True,
)

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
"""
@@ -146,15 +181,16 @@ def __call__(
)
last_checkpoint = None
if (
os.path.isdir(training_args.output_dir)
Path(training_args.output_dir).exists()
and training_args.do_train
and not training_args.overwrite_output_dir
and not overwrite_output_dir
):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if (
last_checkpoint is None
and len(os.listdir(training_args.output_dir)) > 0
):
print(os.listdir(training_args.output_dir))
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
@@ -174,13 +210,12 @@ def __call__(
"mozilla-foundation/common_voice_11_0",
"hi",
split="train+validation",
use_auth_token=os.environ["HF_TOKEN"],
trust_remote_code=True,
)
raw_datasets["eval"] = load_dataset(
"mozilla-foundation/common_voice_11_0",
"hi",
split="test",
use_auth_token=os.environ["HF_TOKEN"],
)

# Most ASR datasets only provide input audio samples (audio) and
@@ -495,29 +530,3 @@ def compute_metrics(pred):

logger.info("Training run complete!")
return results


def main() -> int:
with app.run(detach=True):
run_id = app.app_id
output_dir = str(pathlib.Path(app_config.model_dir, run_id))
args = sys.argv[1:] + [f"--output_dir={str(output_dir)}"]
# Modal's @app.local_entrypoint() uses tiangolo/typer, which doesn't support
# building CLI interfaces from dataclasses. https://github.com/tiangolo/typer/issues/154
parser = HfArgumentParser(
(ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)
)
(
model_args,
data_args,
training_args,
) = parser.parse_args_into_dataclasses(args)

logger.info("Starting training")
result = train.remote(model_args, data_args, training_args)
logger.info(result)
return 0


if __name__ == "__main__":
raise SystemExit(main())
36 changes: 0 additions & 36 deletions 06_gpu_and_ml/openai_whisper/finetuning/train/transcribe.py
@@ -1,12 +1,6 @@
import os
import pathlib
import sys
from typing import TYPE_CHECKING

import modal
from modal.cli.volume import FileType

from .config import app_config
from .logs import get_logger

if TYPE_CHECKING:
@@ -15,32 +9,6 @@
logger = get_logger(__name__)


def download_model_locally(run_id: str) -> pathlib.Path:
"""
Download a finetuned model locally.
NOTE: These models were trained on GPU and require torch.distributed installed locally.
"""
logger.info(f"Saving finetuning run {run_id} model locally")
vol = modal.NetworkFileSystem.lookup(app_config.persistent_vol_name)
for entry in vol.listdir(f"{run_id}/**"):
p = pathlib.Path(f".{app_config.model_dir}", entry.path)

if entry.type == FileType.DIRECTORY:
p.mkdir(parents=True, exist_ok=True)
elif entry.type == FileType.FILE:
logger.info(f"Downloading {entry.path} to {p}")
p.parent.mkdir(parents=True, exist_ok=True)
with open(p, "wb") as f:
for chunk in vol.read_file(entry.path):
f.write(chunk)
else:
logger.warning(
f"Skipping unknown entry '{p}' with unknown filetype"
)
return pathlib.Path(f".{app_config.model_dir}", run_id)


def whisper_transcribe_local_file(
model_dir: os.PathLike,
language: str,
@@ -95,7 +63,3 @@ def whisper_transcribe_audio(
predicted_ids, skip_special_tokens=True
)[0]
return predicted_transcription


if __name__ == "__main__":
download_model_locally(run_id=sys.argv[1])
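With the local download helper removed, transcription is expected to run remotely against the volume-backed model. For reference, a self-contained sketch of the generate-and-decode path that `whisper_transcribe_audio` wraps (assuming the base checkpoint and a 16 kHz waveform; the silent array is only a stand-in input):

```python
import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

audio = np.zeros(16_000, dtype=np.float32)  # one second of 16 kHz silence
inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")
predicted_ids = model.generate(
    input_features=inputs.input_features, max_new_tokens=40
)
# Decode exactly as the helper does: batch_decode with special tokens stripped.
text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
print(text)
```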
