
Finetuning bugs #100

Merged: 55 commits, Oct 23, 2024

Changes from all commits

Commits (55)
c31a750
fix: Convert processor outputs to numpy arrays
saattrupdan Oct 22, 2024
2bb52f1
chore: Revert commit
saattrupdan Oct 22, 2024
d677a96
debug
saattrupdan Oct 22, 2024
5673c60
debug
saattrupdan Oct 22, 2024
f533a1e
docs: Casing
saattrupdan Oct 22, 2024
777e270
debug: Temp logging
saattrupdan Oct 22, 2024
bd6aa4a
debug
saattrupdan Oct 22, 2024
dbeb536
docs: Remove temp comments
saattrupdan Oct 22, 2024
c284fab
docs: Remove temp comment
saattrupdan Oct 22, 2024
034e277
chore: Logging
saattrupdan Oct 22, 2024
c9d7e65
debug
saattrupdan Oct 22, 2024
71bad8c
chore: Revert
saattrupdan Oct 22, 2024
215f451
fix: Always process dataset, as it is necessary for training
saattrupdan Oct 22, 2024
a52e06d
debug
saattrupdan Oct 22, 2024
91cd695
debug
saattrupdan Oct 22, 2024
7d3af9b
debug
saattrupdan Oct 22, 2024
9085910
fix: Input features to numpy array in data collator
saattrupdan Oct 22, 2024
d5611fd
fix: Process dataset after interleave_datasets
saattrupdan Oct 22, 2024
99ef2ac
fix: Cast audio to sampling rate before interleave_datasets
saattrupdan Oct 22, 2024
fb38d03
chore: Update make recipe
saattrupdan Oct 22, 2024
b6432e4
chore: Make recipe
saattrupdan Oct 22, 2024
26d3cfb
chore: Logging
saattrupdan Oct 22, 2024
8e4df84
docs: Temp comment
saattrupdan Oct 22, 2024
b2658df
chore: Use AutoModelForSpeechSeq2Seq for initiating Whisper
saattrupdan Oct 22, 2024
60daf3f
chore: Use AutoProcessor
saattrupdan Oct 22, 2024
ee08ffb
debug: Try using default_data_collator
saattrupdan Oct 22, 2024
ff46961
chore: Revert
saattrupdan Oct 22, 2024
9be064a
chore: Temporarily only use one dataset
saattrupdan Oct 22, 2024
b8a2d99
chore: Remove temp block
saattrupdan Oct 22, 2024
3541bae
chore: Revert
saattrupdan Oct 22, 2024
4fdc302
chore: use pip3 to install pipx
saattrupdan Oct 22, 2024
8c7cbe2
chore: revert
saattrupdan Oct 22, 2024
75efbff
fix: Use apt to install pipx on Linux
saattrupdan Oct 22, 2024
c88926b
debug
saattrupdan Oct 23, 2024
be2ffba
fix: Use hf_config.max_length when setting tokenizer model_max_length…
saattrupdan Oct 23, 2024
dcd610f
chore: Update lock file
saattrupdan Oct 23, 2024
2c1ed86
chore: Update pre-commit
saattrupdan Oct 23, 2024
1bcac9e
fix: Use tokenizer model_max_length when padding to max_length
saattrupdan Oct 23, 2024
a4a558f
chore: Re-enable hook
saattrupdan Oct 23, 2024
e216bea
chore: Update lock file
saattrupdan Oct 23, 2024
26f241a
fix: Set split_batches=True
saattrupdan Oct 23, 2024
dd24672
Merge branch 'fix/whisper-finetuning' of github.com:alexandrainst/cor…
saattrupdan Oct 23, 2024
6338756
fix: Temporarily disable specaugment
saattrupdan Oct 23, 2024
d40f6ea
fix: Set feature_size and num_mel_bins to 160k
saattrupdan Oct 23, 2024
8b4a2b5
fix: Set ignore_mismatched_sizes=True
saattrupdan Oct 23, 2024
6dd8fb7
chore: Revert
saattrupdan Oct 23, 2024
6a026a6
fix: Disable split_batches again
saattrupdan Oct 23, 2024
1917e07
fix: Do not force padding to be max_length in multigpu setups
saattrupdan Oct 23, 2024
38cdcb2
fix: Set dispatch_batches=False
saattrupdan Oct 23, 2024
d4e9559
docs: Add TODO comment regarding the commenting of max length change …
saattrupdan Oct 23, 2024
c477ce9
Merge branch 'main' of github.com:alexandrainst/coral into fix/whispe…
saattrupdan Oct 23, 2024
3b5e084
chore: Update gradio
saattrupdan Oct 23, 2024
9b68b86
fix: Set dispatch_batches=False for wav2vec2 models as well
saattrupdan Oct 23, 2024
70899f6
Merge branch 'fix/whisper-finetuning' of github.com:alexandrainst/cor…
saattrupdan Oct 23, 2024
dff2c7a
fix: Remove cast_to_sampling_rate from process_dataset call in `valid…
saattrupdan Oct 23, 2024
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -4,7 +4,7 @@ repos:
hooks:
- id: python-use-type-annotations
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
1 change: 0 additions & 1 deletion config/datasets/common_voice_17.yaml
@@ -5,4 +5,3 @@ common_voice_17:
text_column: sentence
audio_column: audio
filter_dataset: false
process_dataset: true
1 change: 0 additions & 1 deletion config/datasets/common_voice_9.yaml
@@ -5,4 +5,3 @@ common_voice_9:
text_column: sentence
audio_column: audio
filter_dataset: true
process_dataset: true
1 change: 0 additions & 1 deletion config/datasets/coral.yaml
@@ -5,4 +5,3 @@ coral:
text_column: text
audio_column: audio
filter_dataset: false
process_dataset: true
1 change: 0 additions & 1 deletion config/datasets/fleurs.yaml
@@ -5,4 +5,3 @@ fleurs:
text_column: raw_transcription
audio_column: audio
filter_dataset: true
process_dataset: true
1 change: 0 additions & 1 deletion config/datasets/ftspeech.yaml
@@ -5,4 +5,3 @@ ftspeech:
text_column: sentence
audio_column: audio
filter_dataset: true
process_dataset: true
1 change: 0 additions & 1 deletion config/datasets/nota.yaml
@@ -5,4 +5,3 @@ nota:
text_column: text
audio_column: audio
filter_dataset: true
process_dataset: true
1 change: 0 additions & 1 deletion config/datasets/nst.yaml
@@ -5,4 +5,3 @@ nst:
text_column: text
audio_column: audio
filter_dataset: true
process_dataset: true
1 change: 0 additions & 1 deletion config/datasets/test_dataset.yaml
@@ -5,4 +5,3 @@ test_dataset:
text_column: sentence
audio_column: audio
filter_dataset: true
process_dataset: true
3 changes: 3 additions & 0 deletions makefile
@@ -53,6 +53,7 @@ install-pipx:
@if [ "$(shell which pipx)" = "" ]; then \
uname=$$(uname); \
case $${uname} in \
(*Linux*) installCmd='sudo apt install pipx'; ;; \
(*Darwin*) installCmd='brew install pipx'; ;; \
(*CYGWIN*) installCmd='py -3 -m pip install --upgrade --user pipx'; ;; \
(*) installCmd='python3 -m pip install --upgrade --user pipx'; ;; \
@@ -118,6 +119,8 @@ type-check: ## Run type checking
--show-error-codes \
--check-untyped-defs

check: lint format type-check ## Check the code

roest-315m: ## Train the Røst-315M model
@accelerate launch \
--use-deepspeed \
2,371 changes: 1,252 additions & 1,119 deletions poetry.lock

Large diffs are not rendered by default.

65 changes: 32 additions & 33 deletions src/coral/data.py
@@ -171,35 +171,25 @@ def load_data_for_finetuning(
num_proc=config.dataset_num_workers,
)

if dataset_config.process_dataset:
ds = ds.remove_columns(
column_names=[
column
for column in ds.column_names or list()
if column not in ["audio", "text"]
]
).shuffle(seed=config.seed)
ds = process_dataset(
dataset=ds,
clean_text=config.model.clean_text,
lower_case=config.model.lower_case,
characters_to_keep=config.characters_to_keep,
text_column="text",
audio_column="audio",
convert_numerals=False,
remove_input_dataset_columns=True,
cast_to_sampling_rate=config.model.sampling_rate,
processor=processor,
num_proc=config.dataset_num_workers,
)
ds = ds.remove_columns(
column_names=[
column
for column in ds.column_names or list()
if column not in ["audio", "text"]
]
).shuffle(seed=config.seed)

ds = ds.cast_column(
column="audio", feature=Audio(sampling_rate=config.model.sampling_rate)
)

all_datasets.append(ds)

assert len(all_datasets) > 0, "No datasets were loaded"

if len(all_datasets) > 1:
if is_main_process:
logger.info("Interleaving datasets")
logger.info("Interleaving datasets...")
if config.dataset_probabilities is None and len(all_datasets) > 1:
logger.warning(
"No dataset probabilities were specified for the training split. "
@@ -228,6 +218,19 @@ def load_data_for_finetuning(
else:
train = all_datasets[0]

train = process_dataset(
dataset=train,
clean_text=config.model.clean_text,
lower_case=config.model.lower_case,
characters_to_keep=config.characters_to_keep,
text_column="text",
audio_column="audio",
convert_numerals=False,
remove_input_dataset_columns=True,
processor=processor,
num_proc=config.dataset_num_workers,
)

data_dict = dict(train=train)
dataset = IterableDatasetDict(data_dict)

@@ -256,6 +259,10 @@ def load_data_for_finetuning(
if config.evaluation_dataset.audio_column != "audio":
val = val.rename_column(config.evaluation_dataset.audio_column, "audio")

val = val.cast_column(
column="audio", feature=Audio(sampling_rate=config.model.sampling_rate)
)

val = process_dataset(
dataset=val,
clean_text=config.model.clean_text,
@@ -265,7 +272,6 @@
audio_column="audio",
convert_numerals=False,
remove_input_dataset_columns=True,
cast_to_sampling_rate=config.model.sampling_rate,
processor=processor,
num_proc=config.dataset_num_workers,
)
@@ -328,6 +334,9 @@ def load_dataset_for_evaluation(config: DictConfig) -> Dataset:
max_seconds_per_example=config.max_seconds_per_example,
is_main_process=is_main_process,
)
dataset = dataset.cast_column(
column=config.audio_column, feature=Audio(sampling_rate=config.sampling_rate)
)
dataset = process_dataset(
dataset=dataset,
clean_text=config.clean_text,
@@ -336,7 +345,6 @@ def load_dataset_for_evaluation(config: DictConfig) -> Dataset:
text_column=config.text_column,
audio_column=config.audio_column,
remove_input_dataset_columns=False,
cast_to_sampling_rate=config.sampling_rate,
convert_numerals=True,
)

@@ -450,7 +458,6 @@ def process_dataset(
audio_column: str | None,
convert_numerals: bool,
num_proc: int | None = None,
cast_to_sampling_rate: int | None = None,
processor: Callable | None = None,
) -> Data:
"""Process the dataset.
@@ -479,21 +486,13 @@ def process_dataset(
num_proc (optional):
The number of processes to use for processing the dataset. If `None`, then
no multiprocessing is used. Defaults to `None`.
cast_to_sampling_rate (optional):
The sampling rate to cast the audio to. If `None`, then the audio is not
cast. Defaults to `None`.
processor (optional):
The processor to use for processing the audio and transcriptions. If `None`,
then the processor is not used. Defaults to `None`.

Returns:
The cleaned dataset.
"""
if audio_column is not None:
dataset = dataset.cast_column(
column=audio_column, feature=Audio(sampling_rate=cast_to_sampling_rate)
)

if isinstance(dataset, Dataset) or isinstance(dataset, IterableDataset):
column_names = dataset.column_names
elif isinstance(dataset, DatasetDict) or isinstance(dataset, IterableDatasetDict):
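
Note on the data.py reordering above: the audio cast now happens per dataset before `interleave_datasets`, while `process_dataset` runs once on the merged stream. A rough, stand-alone sketch of that ordering with the `datasets` API (the dataset ids, probabilities and column names below are placeholders, not the project's actual config):

```python
from datasets import Audio, interleave_datasets, load_dataset

SAMPLING_RATE = 16_000

streams = []
for name in ["org/streaming-asr-dataset-a", "org/streaming-asr-dataset-b"]:  # placeholder ids
    ds = load_dataset(name, split="train", streaming=True)
    # Cast every source to one sampling rate *before* interleaving, so the merged
    # stream has a single, consistent audio feature.
    ds = ds.cast_column("audio", Audio(sampling_rate=SAMPLING_RATE))
    streams.append(ds)

# Interleave first...
train = interleave_datasets(streams, probabilities=[0.8, 0.2], seed=4242)

# ...then run feature extraction and tokenisation once, on the merged stream,
# which is roughly what the process_dataset call after interleaving does here.
def prepare(example, processor):
    audio = example["audio"]
    batch = dict(processor(audio=audio["array"], sampling_rate=audio["sampling_rate"]))
    batch["labels"] = processor.tokenizer(example["text"]).input_ids
    return batch

# train = train.map(prepare, fn_kwargs=dict(processor=processor))
```
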
10 changes: 5 additions & 5 deletions src/coral/data_collators.py
@@ -70,7 +70,7 @@ def torch_call(self, features: list[dict]) -> BatchFeature:
labels=label_features,
padding=self.padding,
return_tensors=self.return_tensors,
max_length=512,
max_length=min(self.processor.tokenizer.model_max_length, 512),
)

# Replace padding with -100 to ignore loss correctly
@@ -147,16 +147,16 @@ def torch_call(self, features: list[dict]) -> BatchFeature:
label_features,
padding=self.padding,
return_tensors=self.return_tensors,
max_length=512,
max_length=min(self.processor.tokenizer.model_max_length, 512),
)

# replace padding with -100 to ignore loss correctly
# Replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(
labels_batch.attention_mask.ne(1), -100
)

# if bos token is appended in previous tokenization step,
# cut bos token here as it's append later anyways
# If bos token is appended in previous tokenization step, cut BOS token here as
# it's appended later anyway
if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
labels = labels[:, 1:]

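
The collator change above caps label padding at the tokenizer's `model_max_length` rather than a hard-coded 512. A minimal stand-alone sketch of the same padding-and-masking pattern, assuming an example Whisper checkpoint:

```python
from transformers import WhisperProcessor

# Example checkpoint; the PR works with whichever pretrained_model_id is configured.
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

label_features = [{"input_ids": processor.tokenizer("halløj").input_ids}]

labels_batch = processor.tokenizer.pad(
    label_features,
    padding="max_length",
    return_tensors="pt",
    # Never pad labels beyond what the tokenizer (and hence the decoder) supports.
    max_length=min(processor.tokenizer.model_max_length, 512),
)

# Replace padding token ids with -100 so they are ignored by the loss.
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
```
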
7 changes: 5 additions & 2 deletions src/coral/validation.py
@@ -5,7 +5,7 @@
from typing import TypeVar

import torch
from datasets import Dataset, DatasetDict
from datasets import Audio, Dataset, DatasetDict
from transformers import AutomaticSpeechRecognitionPipeline, pipeline

from .compute_metrics import compute_metrics_of_dataset_using_pipeline
@@ -60,6 +60,10 @@ def add_validations(
if input_is_single_split:
dataset = DatasetDict(dict(train=dataset))

dataset = dataset.cast_column(
column=audio_column, feature=Audio(sampling_rate=sampling_rate)
)

processed_dataset = process_dataset(
dataset=dataset,
clean_text=clean_text,
@@ -69,7 +73,6 @@
convert_numerals=False,
remove_input_dataset_columns=True,
lower_case=lower_case,
cast_to_sampling_rate=sampling_rate,
)

logger.info(f"Loading the {model_id!r} ASR model...")
1 change: 1 addition & 0 deletions src/coral/wav2vec2.py
@@ -209,6 +209,7 @@ def load_training_arguments(self) -> TrainingArguments:
use_cpu=hasattr(sys, "_called_from_test"),
dataloader_num_workers=self.config.dataloader_num_workers,
ddp_find_unused_parameters=False,
dispatch_batches=False,
)
return args

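
If I read the `dispatch_batches=False` change right, each process now iterates its own shard of the streaming dataset instead of having the main process fetch and dispatch batches, which is presumably what unblocks the multi-GPU runs. A hedged sketch of the setting (output_dir and batch size are placeholders, and newer transformers releases route this through `accelerator_config` instead):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="models/wav2vec2-finetuned",  # placeholder
    per_device_train_batch_size=8,           # placeholder
    ddp_find_unused_parameters=False,
    # Let every process iterate its own shard of the IterableDataset instead of
    # having the main process fetch batches and dispatch slices to the others.
    dispatch_batches=False,
)
```
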
13 changes: 12 additions & 1 deletion src/coral/whisper.py
@@ -12,6 +12,8 @@
from omegaconf import DictConfig
from torch.backends.mps import is_available as mps_is_available
from transformers import (
AutoConfig,
AutoModelForSpeechSeq2Seq,
EvalPrediction,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
@@ -51,12 +53,20 @@ def load_processor(self) -> WhisperProcessor:
)
assert isinstance(processor_or_tup, WhisperProcessor)
self.processor = processor_or_tup

# Whisper tokenizers are misconfigured with a max_length that is too high, but
# the correct max_length is stored in the model config, so we'll update it here.
hf_config = AutoConfig.from_pretrained(self.config.model.pretrained_model_id)
self.processor.tokenizer.model_max_length = min(
self.processor.tokenizer.model_max_length, hf_config.max_length
)

return self.processor

def load_model(self) -> WhisperForConditionalGeneration:
"""Return the model for the setup."""
with transformers_output_ignored():
model = WhisperForConditionalGeneration.from_pretrained(
model = AutoModelForSpeechSeq2Seq.from_pretrained(
self.config.model.pretrained_model_id,
dropout=self.config.model.dropout,
activation_dropout=self.config.model.activation_dropout,
@@ -179,6 +189,7 @@ def load_training_arguments(self) -> TrainingArguments:
use_cpu=hasattr(sys, "_called_from_test"),
dataloader_num_workers=self.config.dataloader_num_workers,
ddp_find_unused_parameters=False,
dispatch_batches=False,
)
return args

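
For readers skimming the whisper.py diff: the `load_processor` fix clamps the tokenizer's `model_max_length` to the limit stored in the model config. A stand-alone sketch of the same idea (the checkpoint id is only an example):

```python
from transformers import AutoConfig, WhisperProcessor

model_id = "openai/whisper-large-v3"  # example checkpoint
processor = WhisperProcessor.from_pretrained(model_id)
hf_config = AutoConfig.from_pretrained(model_id)

# The tokenizer's advertised limit can exceed what the decoder actually supports;
# the model config carries the real generation limit (448 for Whisper models).
processor.tokenizer.model_max_length = min(
    processor.tokenizer.model_max_length, hf_config.max_length
)
```
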
17 changes: 10 additions & 7 deletions src/scripts/finetune_asr_model.py
@@ -52,13 +52,16 @@ def main(config: DictConfig) -> None:
"training"
)
config.model.layerdrop = 0.0
if config.padding != "max_length":
if is_main_process:
logger.info(
"Forcing `padding` to be 'max_length' as this is required in a "
"multi-GPU training"
)
config.padding = "max_length"

# TODO: This doesn't seem to be changed anymore, but keeping it here for some
# time in case we need to re-enable it.
# if config.padding != "max_length":
# if is_main_process:
# logger.info(
# "Forcing `padding` to be 'max_length' as this is required in a "
# "multi-GPU training"
# )
# config.padding = "max_length"

elif torch.cuda.device_count() > 1:
if is_main_process: