Merge pull request speechbrain#2405 from Adel-Moumen/gigaspeech
add GigaSpeech dataset in SpeechBrain
asumagic authored Oct 25, 2024
2 parents fd0cd20 + db5b629 commit 3d2eeee
Showing 19 changed files with 3,068 additions and 2 deletions.
91 changes: 91 additions & 0 deletions recipes/GigaSpeech/ASR/CTC/README.md
# Speech Recognition on GigaSpeech with pre-trained self-supervised models and CTC

This folder contains the scripts to fine-tune any HuggingFace transformer-based model
(WavLM, wav2vec 2.0, HuBERT, ...) with CTC for speech recognition on GigaSpeech.
Training can be done on any of the GigaSpeech subsets (XL, L, S, etc.).

## Data access and download

**The XL set is fairly large: 2.2 TB are needed to store both the compressed and uncompressed versions of the data.**

SpeechBrain supports two ways of dealing with the GigaSpeech dataset:
1. [HuggingFace dataset](https://huggingface.co/datasets/speechcolab/gigaspeech/). For HuggingFace, note that **you must use** the HuggingFace client to log in before running the recipe (see the login sketch below).
2. [Original Github](https://github.com/SpeechColab/GigaSpeech).

Simply follow the instructions at either of the links above. **We strongly
recommend using HuggingFace, as the download speed is much faster for users
outside of China.**
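
If you go the HuggingFace route, one possible login flow is sketched below. It assumes the `huggingface_hub` package (which provides the `huggingface-cli` command) and an access token created on your HuggingFace account:

```
# Sketch: authenticate with HuggingFace before launching the recipe.
pip install huggingface_hub       # provides the huggingface-cli command
huggingface-cli login             # paste an access token from huggingface.co/settings/tokens
```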

## Data preparation

**For the XL split of GigaSpeech, this step can take a very long time, depending on your internet connection and filesystem. For DDP (multi-GPU) training, the recipe must first be run once without DDP, otherwise data preparation will time out; you do not want to leave several GPUs sitting idle for hours anyway. Use the *data_prep_only* flag from the yaml to exit right after data preparation.**

SpeechBrain will automatically download the dataset if you use HuggingFace. In that case, the *data_folder* argument is used to store the **extracted** dataset. However, HuggingFace first needs to download the compressed data, which is not stored in *data_folder* by default. HuggingFace is rather strict about how it handles datasets: the compressed data is put into the folder specified by the environment variable *HF_HUB_CACHE* or, if not set, *HF_HOME* or, if not set, *XDG_CACHE_HOME*. Hence, we recommend setting *HF_HUB_CACHE* to the place where you want to store the data before launching the recipe. For example:

```
export HF_HUB_CACHE=/path/to/your/data/folder
```
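
For instance, a data-preparation-only run could combine the cache location, the *data_folder* argument, and the *data_prep_only* flag. The sketch below assumes the usual SpeechBrain/HyperPyYAML `--key=value` command-line override syntax, and all paths are placeholders:

```
# Sketch: download/extract GigaSpeech, then exit before training starts.
export HF_HUB_CACHE=/path/to/your/data/folder/hf_cache   # compressed HuggingFace download
python train_with_wavlm.py hparams/train_hf_wavlm.yaml \
    --data_folder=/path/to/your/data/folder \
    --data_prep_only=True
```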

## Installing Extra Dependencies

Before proceeding, ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal:

```
pip install -r extra_requirements.txt
```

# How to run

With a single GPU:
```
python train_with_wavlm.py hparams/file.yaml
```
With multiple GPUs:
```
torchrun --nproc_per_node=8 train_with_wavlm.py hparams/file.yaml
```

# KenLM n-gram CTC rescoring
To enable n-gram rescoring during decoding, you must download (or train yourself) the n-gram language model:

```
wget https://huggingface.co/wgb14/gigaspeech_lm/resolve/main/3gram_pruned_1e7.arpa.gz
wget https://huggingface.co/wgb14/gigaspeech_lm/resolve/main/4gram.arpa.gz
gunzip -c 3gram_pruned_1e7.arpa.gz > 3gram_pruned_1e7.arpa
gunzip -c 4gram.arpa.gz > 4gram.arpa
```

Then simply modify the *test_beam_search* block in the yaml by adding a *kenlm_model_path* parameter pointing to your ARPA file, as sketched below.
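
For reference, the modified block could look like the following sketch; the parameters are copied from `hparams/train_hf_wavlm.yaml`, and the *kenlm_model_path* value is a placeholder for wherever you stored the decompressed ARPA file:

```yaml
test_beam_search:
    beam_size: 143
    topk: 1
    blank_index: !ref <blank_index>
    space_token: ' '
    beam_prune_logp: -12.0
    token_prune_min_logp: -1.2
    prune_history: True
    kenlm_model_path: /path/to/4gram.arpa # placeholder: path to the downloaded LM
```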

# Rescoring with a Neural Language Model
This can be done by modifying the current recipe. We invite you to have a look at our LibriSpeech CTC recipe for many different examples.

# Results

| Release | Hyperparams file | Decoding method | Finetuning Split | Test WER | Dev WER | HuggingFace link | Full model link | Training GPUs |
|:-------------:|:---------------------------:| :----------:| :-----:| :-----:| :-----:| :-----:| :-----:| :-----:|
| 25-10-2024 | train_hf_wavlm.yaml | GreedySearch | XL | 11.88% | 11.86% | Unavailable\* | Unavailable\* | 8xRTX 3090 |

\*: Unfortunately, we are unable to upload the checkpoints for the WavLM model at this time. We currently don't have plans to remedy this.

# **Citing SpeechBrain**
Please cite SpeechBrain if you use it for your research or business.

```bibtex
@misc{speechbrainV1,
  title={Open-Source Conversational AI with SpeechBrain 1.0},
  author={Mirco Ravanelli and Titouan Parcollet and Adel Moumen and Sylvain de Langen and Cem Subakan and Peter Plantinga and Yingzhi Wang and Pooneh Mousavi and Luca Della Libera and Artem Ploujnikov and Francesco Paissan and Davide Borra and Salah Zaiem and Zeyu Zhao and Shucong Zhang and Georgios Karakasidis and Sung-Lin Yeh and Pierre Champion and Aku Rouhe and Rudolf Braun and Florian Mai and Juan Zuluaga-Gomez and Seyed Mahed Mousavi and Andreas Nautsch and Xuechen Liu and Sangeet Sagar and Jarod Duret and Salima Mdhaffar and Gaelle Laperriere and Mickael Rouvier and Renato De Mori and Yannick Esteve},
  year={2024},
  eprint={2407.00463},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2407.00463},
}
@misc{speechbrain,
  title={{SpeechBrain}: A General-Purpose Speech Toolkit},
  author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
  year={2021},
  eprint={2106.04624},
  archivePrefix={arXiv},
  primaryClass={eess.AS},
  note={arXiv:2106.04624}
}
```
1 change: 1 addition & 0 deletions recipes/GigaSpeech/ASR/CTC/dataset.py
5 changes: 5 additions & 0 deletions recipes/GigaSpeech/ASR/CTC/extra_requirements.txt
datasets
kenlm
soundfile
speechcolab
transformers
1 change: 1 addition & 0 deletions recipes/GigaSpeech/ASR/CTC/gigaspeech_prepare.py
240 changes: 240 additions & 0 deletions recipes/GigaSpeech/ASR/CTC/hparams/train_hf_wavlm.yaml
# ################################
# Model: wavlm + DNN + CTC
# Decoding AM: Greedy for validation, and Beam search for testing
# Augmentation: SpecAugment
# Authors: Adel Moumen 2024, Titouan Parcollet 2024
# ################################

# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 1986
__set_seed: !apply:speechbrain.utils.seed_everything [!ref <seed>]
experiment_name: train_wavlm_char
output_folder: !ref results/<experiment_name>/<seed>
output_wer_folder: !ref <output_folder>/
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

wav2vec2_hub: microsoft/wavlm-large
wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint

# Data files
data_folder: !PLACEHOLDER # e.g., /path/to/GigaSpeech

# see https://github.com/SpeechColab/GigaSpeech for more details on the dataset
# must be one of ["XS", "S", "M", "L", "XL"]
# and ["DEV", "TEST"] for the eval splits.
splits: ["XL", "DEV", "TEST"]
skip_prep: False
data_prep_only: False
download_with_HF: True
convert_opus_to_wav: True
keep_filler_words: False
keep_punctuation: False
ckpt_interval_minutes: 25 # save checkpoint every N min
train_csv: !ref <output_folder>/train.csv
valid_csv: !ref <output_folder>/dev.csv
test_csv: !ref <output_folder>/test.csv
json_file: !ref <data_folder>/GigaSpeech.json

# Training parameters

# Training stops at either number_of_epochs or optimizer_step_limit,
# whichever is reached first.
number_of_epochs: 10
optimizer_step_limit: 300000
warmup: 1000 # Not much is needed as models are pretrained
lr: 0.001
lr_wav2vec: 0.0001
sorting: ascending
num_workers: 4
precision: fp16 # bf16, fp16 or fp32
sample_rate: 16000

# With data_parallel batch_size is split into N jobs
# With DDP batch_size is multiplied by N jobs
# Must be 3 per GPU to fit 32GB of VRAM
batch_size: 8
test_batch_size: 1

# Dataloader options
train_dataloader_opts:
    batch_size: !ref <batch_size>
    num_workers: !ref <num_workers>

valid_dataloader_opts:
    batch_size: !ref <test_batch_size>

test_dataloader_opts:
    batch_size: !ref <test_batch_size>

# Using dynamic batching by default. This works with 4x24GB GPUs
# Or turn it off (but training speed will decrease)
dynamic_batching: True
max_batch_length_train: 50
max_batch_length_val: 30 # we reduce it as the beam is much wider (VRAM)
num_bucket: 200
shuffle: True # if true re-creates batches at each epoch shuffling examples.
batch_ordering: random
max_batch_ex: 256

dynamic_batch_sampler_train:
    max_batch_length: !ref <max_batch_length_train>
    num_buckets: !ref <num_bucket>
    shuffle: !ref <shuffle>
    batch_ordering: !ref <batch_ordering>
    max_batch_ex: !ref <max_batch_ex>

dynamic_batch_sampler_valid:
    max_batch_length: !ref <max_batch_length_val>
    num_buckets: !ref <num_bucket>
    shuffle: !ref <shuffle>
    batch_ordering: !ref <batch_ordering>
    max_batch_ex: !ref <max_batch_ex>

# BPE parameters
token_type: char # ["unigram", "bpe", "char"]
character_coverage: 1.0

# Model parameters
dnn_neurons: 1024
dropout: 0.1
freeze_wav2vec: False
freeze_wav2vec_extractor: False
wav2vec_output_dim: 1024

# Outputs
output_neurons: 29 # without punctuation
blank_index: 0
bos_index: -1 # No bos/eos with CTC
eos_index: -1

# Decoding parameters
test_beam_search:
    beam_size: 143
    topk: 1
    blank_index: !ref <blank_index>
    space_token: ' ' # make sure this is the same as the one used in the tokenizer
    beam_prune_logp: -12.0
    token_prune_min_logp: -1.2
    prune_history: True

#
# Functions and classes
#
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

# Speed perturbation
speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
    orig_freq: !ref <sample_rate>
    speeds: [95, 100, 105]

drop_freq: !new:speechbrain.augment.time_domain.DropFreq
    drop_freq_low: 0
    drop_freq_high: 1
    drop_freq_count_low: 1
    drop_freq_count_high: 3
    drop_freq_width: 0.05

drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
    drop_length_low: 1
    drop_length_high: 5
    drop_count_low: 1000
    drop_count_high: 2000

# Augmenter: Combines previously defined augmentations to perform data augmentation
wav_augment: !new:speechbrain.augment.augmenter.Augmenter
    parallel_augment: False
    concat_original: True
    repeat_augment: 1
    shuffle_augmentations: False
    min_augmentations: 2
    max_augmentations: 2
    augment_prob: 1.0
    augmentations: [
        !ref <speed_perturb>,
        !ref <drop_freq>,
        !ref <drop_chunk>]


enc: !new:speechbrain.nnet.containers.Sequential
    input_shape: [null, null, !ref <wav2vec_output_dim>]
    linear1: !name:speechbrain.nnet.linear.Linear
        n_neurons: !ref <dnn_neurons>
        bias: True
    bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
    activation: !new:torch.nn.LeakyReLU
    drop: !new:torch.nn.Dropout
        p: !ref <dropout>
    linear2: !name:speechbrain.nnet.linear.Linear
        n_neurons: !ref <dnn_neurons>
        bias: True
    bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
    activation2: !new:torch.nn.LeakyReLU
    drop2: !new:torch.nn.Dropout
        p: !ref <dropout>
    linear3: !name:speechbrain.nnet.linear.Linear
        n_neurons: !ref <dnn_neurons>
        bias: True
    bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
    activation3: !new:torch.nn.LeakyReLU

wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
    source: !ref <wav2vec2_hub>
    output_norm: False
    freeze: !ref <freeze_wav2vec>
    freeze_feature_extractor: !ref <freeze_wav2vec_extractor>
    save_path: !ref <wav2vec2_folder>

ctc_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <dnn_neurons>
    n_neurons: !ref <output_neurons>

log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True

ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
    blank_index: !ref <blank_index>

modules:
    wav2vec2: !ref <wav2vec2>
    enc: !ref <enc>
    ctc_lin: !ref <ctc_lin>

model: !new:torch.nn.ModuleList
    - [!ref <enc>, !ref <ctc_lin>]

model_opt_class: !name:torch.optim.AdamW
    lr: !ref <lr>

wav2vec_opt_class: !name:torch.optim.AdamW
    lr: !ref <lr_wav2vec>

lr_annealing_model: !new:speechbrain.nnet.schedulers.WarmAndExpDecayLRSchedule
    lr: !ref <lr>
    n_warmup_steps: !ref <warmup>
    total_steps: !ref <optimizer_step_limit>
    decay_factor: 0.05 # Divided by twenty at the end.

lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.WarmAndExpDecayLRSchedule
    lr: !ref <lr_wav2vec>
    n_warmup_steps: !ref <warmup>
    total_steps: !ref <optimizer_step_limit>
    decay_factor: 0.1 # Divided by ten at the end.

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        wav2vec2: !ref <wav2vec2>
        model: !ref <model>
        scheduler_model: !ref <lr_annealing_model>
        scheduler_wav2vec: !ref <lr_annealing_wav2vec>
        counter: !ref <epoch_counter>

train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats

cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
    split_tokens: True
