diff --git a/recipes/GigaSpeech/ASR/CTC/README.md b/recipes/GigaSpeech/ASR/CTC/README.md
new file mode 100644
index 0000000000..488906ecb6
--- /dev/null
+++ b/recipes/GigaSpeech/ASR/CTC/README.md
@@ -0,0 +1,91 @@
+# Speech Recognition on GigaSpeech with pre-trained self-supervised models and CTC
+
+This folder contains the scripts to finetune any HuggingFace transformer-based model
+(WavLM, wav2vec 2.0, HuBERT...) with CTC for speech recognition on
+GigaSpeech. Training can be done on any of the GigaSpeech subsets (XL, L, S, etc.).
+
+## Data access and download
+
+**The XL set is fairly large: 2.2TB are necessary to store the compressed and uncompressed versions of the data.**
+
+SpeechBrain supports two ways of dealing with the GigaSpeech dataset:
+1. [HuggingFace dataset](https://huggingface.co/datasets/speechcolab/gigaspeech/). For HuggingFace, note that **you must use** the HuggingFace client to log in first before running the recipe.
+2. [Original Github](https://github.com/SpeechColab/GigaSpeech).
+
+You simply need to follow the instructions on either of the above links. **We strongly
+recommend using HuggingFace, as the download speed for people outside of China is
+much quicker**.
+
+## Data preparation
+
+**This step can be very long for the XL split of GigaSpeech, depending on your internet connection and filesystem. For DDP (multi-GPU) training, the recipe must be run once without DDP, otherwise data preparation will time out. You also do not want to leave several GPUs sitting idle for hours. Use the *data_prep_only* flag from the yaml to exit right after data preparation.**
+
+SpeechBrain will automatically download the dataset if you use HuggingFace. Note that if you use HuggingFace, the *data_folder* argument is used to store the **extracted** dataset. However, HuggingFace first needs to download the compressed data, and this is not stored in *data_folder* by default. Indeed, HuggingFace is a bit strict in the way it operates with datasets, and the data will be put into the folder specified by the environment variable *HF_HUB_CACHE* or, if not set, *HF_HOME* or, if not set, *XDG_CACHE_HOME*. Hence, we recommend setting *HF_HUB_CACHE* to the place where you want to store the data first. For example, you can set it like this:
+
+```export HF_HUB_CACHE=/path/to/your/data/folder```
+
+## Installing Extra Dependencies
+
+Before proceeding, ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal:
+
+```
+pip install -r extra_requirements.txt
+```
+
+# How to run
+
+With a single GPU:
+```
+python train_with_wavlm.py hparams/file.yaml
+```
+With multiple GPUs:
+```
+torchrun --nproc_per_node=8 train_with_wavlm.py hparams/file.yaml
+```
+
+# KenLM n-gram CTC rescoring
+To enable n-gram rescoring during decoding, you must download (or train yourself) an n-gram language model:
+
+```
+wget https://huggingface.co/wgb14/gigaspeech_lm/resolve/main/3gram_pruned_1e7.arpa.gz
+wget https://huggingface.co/wgb14/gigaspeech_lm/resolve/main/4gram.arpa.gz
+gunzip -c 3gram_pruned_1e7.arpa.gz > 3gram_pruned_1e7.arpa
+gunzip -c 4gram.arpa.gz > 4gram.arpa
+```
+
+Then simply modify *test_beam_search* in the yaml by adding *kenlm_model_path:* with your path as a parameter (see the sketch below).
+
+# Rescoring with a Neural Language Model
+This can be done by modifying the current recipe. We invite you to have a look at our LibriSpeech CTC recipe for many different examples.
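+For reference, the KenLM change mentioned above is a one-line addition to the *test_beam_search* block of the yaml. A minimal sketch (the path is only an example and should point to wherever you stored one of the ARPA files downloaded above):
+
+```yaml
+test_beam_search:
+    # ... keep the existing keys (beam_size, topk, blank_index, etc.) unchanged ...
+    kenlm_model_path: /path/to/4gram.arpa  # or 3gram_pruned_1e7.arpa
+```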
+ +# Results + +| Release | Hyperparams file | Decoding method | Finetuning Split | Test WER | Dev WER | HuggingFace link | Full model link | Training GPUs | +|:-------------:|:---------------------------:| :----------:| :-----:| :-----:| :-----:| :-----:| :-----:| :-----:| +| 25-10-2024 | train_hf_wavlm.yaml | GreedySearch | XL | 11.88% | 11.86% | Unavailable\* | Unavailable\* | 8xRTX 3090 | + +\*: Unfortunately, we are unable to upload the checkpoints for the WavLM model at this time. We currently don't have plans to remedy this. + +# **Citing SpeechBrain** +Please, cite SpeechBrain if you use it for your research or business. + +```bibtex +@misc{speechbrainV1, + title={Open-Source Conversational AI with SpeechBrain 1.0}, + author={Mirco Ravanelli and Titouan Parcollet and Adel Moumen and Sylvain de Langen and Cem Subakan and Peter Plantinga and Yingzhi Wang and Pooneh Mousavi and Luca Della Libera and Artem Ploujnikov and Francesco Paissan and Davide Borra and Salah Zaiem and Zeyu Zhao and Shucong Zhang and Georgios Karakasidis and Sung-Lin Yeh and Pierre Champion and Aku Rouhe and Rudolf Braun and Florian Mai and Juan Zuluaga-Gomez and Seyed Mahed Mousavi and Andreas Nautsch and Xuechen Liu and Sangeet Sagar and Jarod Duret and Salima Mdhaffar and Gaelle Laperriere and Mickael Rouvier and Renato De Mori and Yannick Esteve}, + year={2024}, + eprint={2407.00463}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2407.00463}, +} +@misc{speechbrain, + title={{SpeechBrain}: A General-Purpose Speech Toolkit}, + author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio}, + year={2021}, + eprint={2106.04624}, + archivePrefix={arXiv}, + primaryClass={eess.AS}, + note={arXiv:2106.04624} +} +``` diff --git a/recipes/GigaSpeech/ASR/CTC/dataset.py b/recipes/GigaSpeech/ASR/CTC/dataset.py new file mode 120000 index 0000000000..f3bfeaf826 --- /dev/null +++ b/recipes/GigaSpeech/ASR/CTC/dataset.py @@ -0,0 +1 @@ +../../dataset.py \ No newline at end of file diff --git a/recipes/GigaSpeech/ASR/CTC/extra_requirements.txt b/recipes/GigaSpeech/ASR/CTC/extra_requirements.txt new file mode 100644 index 0000000000..a619ba044a --- /dev/null +++ b/recipes/GigaSpeech/ASR/CTC/extra_requirements.txt @@ -0,0 +1,5 @@ +datasets +kenlm +soundfile +speechcolab +transformers diff --git a/recipes/GigaSpeech/ASR/CTC/gigaspeech_prepare.py b/recipes/GigaSpeech/ASR/CTC/gigaspeech_prepare.py new file mode 120000 index 0000000000..5190685a8e --- /dev/null +++ b/recipes/GigaSpeech/ASR/CTC/gigaspeech_prepare.py @@ -0,0 +1 @@ +../../gigaspeech_prepare.py \ No newline at end of file diff --git a/recipes/GigaSpeech/ASR/CTC/hparams/train_hf_wavlm.yaml b/recipes/GigaSpeech/ASR/CTC/hparams/train_hf_wavlm.yaml new file mode 100644 index 0000000000..71d2c8c7c3 --- /dev/null +++ b/recipes/GigaSpeech/ASR/CTC/hparams/train_hf_wavlm.yaml @@ -0,0 +1,240 @@ +# ################################ +# Model: wavlm + DNN + CTC +# Decoding AM: Greedy for validation, and Beam search for testing +# Augmentation: SpecAugment +# Authors: Adel Moumen 2024, Titouan Parcollet 2024 +# ################################ + +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 
1986 +__set_seed: !apply:speechbrain.utils.seed_everything [!ref ] +experiment_name: train_wavlm_char +output_folder: !ref results// +output_wer_folder: !ref / +save_folder: !ref /save +train_log: !ref /train_log.txt + +wav2vec2_hub: microsoft/wavlm-large +wav2vec2_folder: !ref /wav2vec2_checkpoint + +# Data files +data_folder: !PLACEHOLDER # e,g./path/to/GigaSpeech + +# see https://github.com/SpeechColab/GigaSpeech for more details on the dataset +# must be one of ["XS", "S", "M", "L", "XL"] +# and ["DEV", "TEST"] for the eval splits. +splits: ["XL", "DEV", "TEST"] +skip_prep: False +data_prep_only: False +download_with_HF: True +convert_opus_to_wav: True +keep_filler_words: False +keep_punctuation: False +ckpt_interval_minutes: 25 # save checkpoint every N min +train_csv: !ref /train.csv +valid_csv: !ref /dev.csv +test_csv: !ref /test.csv +json_file: !ref /GigaSpeech.json + +# Training parameters + +# The training will either stops at number_of_epochs or optimizer_step_limit +# I.e. the first that is reached. +number_of_epochs: 10 +optimizer_step_limit: 300000 +warmup: 1000 # Not much is needed as models are pretrained +lr: 0.001 +lr_wav2vec: 0.0001 +sorting: ascending +num_workers: 4 +precision: fp16 # bf16, fp16 or fp32 +sample_rate: 16000 + +# With data_parallel batch_size is split into N jobs +# With DDP batch_size is multiplied by N jobs +# Must be 3 per GPU to fit 32GB of VRAM +batch_size: 8 +test_batch_size: 1 + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + num_workers: !ref + +valid_dataloader_opts: + batch_size: !ref + +test_dataloader_opts: + batch_size: !ref + +# Using dynamic batching by default. This works with 4x24GB GPUs +# Or turn it off (but training speed will decrease) +dynamic_batching: True +max_batch_length_train: 50 +max_batch_length_val: 30 # we reduce it as the beam is much wider (VRAM) +num_bucket: 200 +shuffle: True # if true re-creates batches at each epoch shuffling examples. 
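+# Note: max_batch_length_* is expressed in the unit returned by the sampler's length_func (seconds of audio here, see train_with_wavlm.py), so a training batch holds roughly 50 s of speech.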
+batch_ordering: random +max_batch_ex: 256 + +dynamic_batch_sampler_train: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +dynamic_batch_sampler_valid: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +# BPE parameters +token_type: char # ["unigram", "bpe", "char"] +character_coverage: 1.0 + +# Model parameters +dnn_neurons: 1024 +dropout: 0.1 +freeze_wav2vec: False +freeze_wav2vec_extractor: False +wav2vec_output_dim: 1024 + +# Outputs +output_neurons: 29 # without punctuation +blank_index: 0 +bos_index: -1 # No bos/eos with CTC +eos_index: -1 + +# Decoding parameters +test_beam_search: + beam_size: 143 + topk: 1 + blank_index: !ref + space_token: ' ' # make sure this is the same as the one used in the tokenizer + beam_prune_logp: -12.0 + token_prune_min_logp: -1.2 + prune_history: True + +# +# Functions and classes +# +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +# Speed perturbation +speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb + orig_freq: !ref + speeds: [95, 100, 105] + +drop_freq: !new:speechbrain.augment.time_domain.DropFreq + drop_freq_low: 0 + drop_freq_high: 1 + drop_freq_count_low: 1 + drop_freq_count_high: 3 + drop_freq_width: 0.05 + +drop_chunk: !new:speechbrain.augment.time_domain.DropChunk + drop_length_low: 1 + drop_length_high: 5 + drop_count_low: 1000 + drop_count_high: 2000 + +# Augmenter: Combines previously defined augmentations to perform data augmentation +wav_augment: !new:speechbrain.augment.augmenter.Augmenter + parallel_augment: False + concat_original: True + repeat_augment: 1 + shuffle_augmentations: False + min_augmentations: 2 + max_augmentations: 2 + augment_prob: 1.0 + augmentations: [ + !ref , + !ref , + !ref ] + + +enc: !new:speechbrain.nnet.containers.Sequential + input_shape: [null, null, !ref ] + linear1: !name:speechbrain.nnet.linear.Linear + n_neurons: !ref + bias: True + bn1: !name:speechbrain.nnet.normalization.BatchNorm1d + activation: !new:torch.nn.LeakyReLU + drop: !new:torch.nn.Dropout + p: !ref + linear2: !name:speechbrain.nnet.linear.Linear + n_neurons: !ref + bias: True + bn2: !name:speechbrain.nnet.normalization.BatchNorm1d + activation2: !new:torch.nn.LeakyReLU + drop2: !new:torch.nn.Dropout + p: !ref + linear3: !name:speechbrain.nnet.linear.Linear + n_neurons: !ref + bias: True + bn3: !name:speechbrain.nnet.normalization.BatchNorm1d + activation3: !new:torch.nn.LeakyReLU + +wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 + source: !ref + output_norm: False + freeze: !ref + freeze_feature_extractor: !ref + save_path: !ref + +ctc_lin: !new:speechbrain.nnet.linear.Linear + input_size: !ref + n_neurons: !ref + +log_softmax: !new:speechbrain.nnet.activations.Softmax + apply_log: True + +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + +modules: + wav2vec2: !ref + enc: !ref + ctc_lin: !ref + +model: !new:torch.nn.ModuleList + - [!ref , !ref ] + +model_opt_class: !name:torch.optim.AdamW + lr: !ref + +wav2vec_opt_class: !name:torch.optim.AdamW + lr: !ref + +lr_annealing_model: !new:speechbrain.nnet.schedulers.WarmAndExpDecayLRSchedule + lr: !ref + n_warmup_steps: !ref + total_steps: !ref + decay_factor: 0.05 # Divided by twenty at the end. 
+ +lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.WarmAndExpDecayLRSchedule + lr: !ref + n_warmup_steps: !ref + total_steps: !ref + decay_factor: 0.1 # Divided by ten at the end. + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + wav2vec2: !ref + model: !ref + scheduler_model: !ref + scheduler_wav2vec: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: True diff --git a/recipes/GigaSpeech/ASR/CTC/train_with_wavlm.py b/recipes/GigaSpeech/ASR/CTC/train_with_wavlm.py new file mode 100644 index 0000000000..60e1f6bb4a --- /dev/null +++ b/recipes/GigaSpeech/ASR/CTC/train_with_wavlm.py @@ -0,0 +1,482 @@ +""" This recipe finetunes a pretrained wavlm model large +on GigaSpeech for speech recognition with CTC and at the character level. +The WavLM model can be swapped with any HuggingFace model if wanted. + +To run this recipe, do the follo +wing: +> python train_with_wavlm.py hparams/train_hf_wavlm.yaml + +Authors + * Adel Moumen 2024 + * Titouan Parcollet 2024 +""" + +import logging +import os +import sys + +import torch +from hyperpyyaml import load_hyperpyyaml + +import speechbrain as sb +from speechbrain.tokenizers.SentencePiece import SentencePiece +from speechbrain.utils.data_utils import undo_padding +from speechbrain.utils.distributed import if_main_process, run_on_main + +logger = logging.getLogger(__name__) + + +# Define training procedure +class ASR(sb.Brain): + def compute_forward(self, batch, stage): + """Forward computations from the waveform batches to the output probabilities.""" + batch = batch.to(self.device) + wavs, wav_lens = batch.sig + wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device) + + # Downsample the inputs if specified + if hasattr(self.modules, "downsampler"): + wavs = self.modules.downsampler(wavs) + + # Add waveform augmentation if specified. 
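+ # Note: with concat_original: True in the yaml, the augmenter returns the original batch followed by its augmented copies, which is why the labels are replicated in compute_objectives below.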
+ if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"): + wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens) + + # Forward pass + + # Handling SpeechBrain vs HuggingFace pretrained models + if hasattr(self.modules, "extractor"): # SpeechBrain pretrained model + latents = self.modules.extractor(wavs) + feats = self.modules.encoder_wrapper(latents, wav_lens=wav_lens)[ + "embeddings" + ] + else: # HuggingFace pretrained model + feats = self.modules.wav2vec2(wavs, wav_lens) + + x = self.modules.enc(feats) + + # Compute outputs + logits = self.modules.ctc_lin(x) + + # Upsample the inputs if they have been highly downsampled + if hasattr(self.hparams, "upsampling") and self.hparams.upsampling: + logits = logits.view( + logits.shape[0], -1, self.hparams.output_neurons + ) + + p_ctc = self.hparams.log_softmax(logits) + + if stage == sb.Stage.VALID: + p_tokens = sb.decoders.ctc_greedy_decode( + p_ctc, wav_lens, blank_id=self.hparams.blank_index + ) + elif stage == sb.Stage.TEST: + p_tokens = test_searcher(p_ctc, wav_lens) + else: + p_tokens = None + + return p_ctc, wav_lens, p_tokens + + def compute_objectives(self, predictions, batch, stage): + """Computes the loss (CTC+NLL) given predictions and targets.""" + + p_ctc, wav_lens, predicted_tokens = predictions + + ids = batch.id + tokens, tokens_lens = batch.tokens + + # Label Augmentation + if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"): + tokens = self.hparams.wav_augment.replicate_labels(tokens) + tokens_lens = self.hparams.wav_augment.replicate_labels(tokens_lens) + + loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens) + + if stage == sb.Stage.VALID: + # Decode token terms to words + predicted_words = self.tokenizer( + predicted_tokens, task="decode_from_list" + ) + elif stage == sb.Stage.TEST: + predicted_words = [ + hyp[0].text.split(" ") for hyp in predicted_tokens + ] + + if stage != sb.Stage.TRAIN: + # Convert indices to words + target_words = undo_padding(tokens, tokens_lens) + target_words = self.tokenizer(target_words, task="decode_from_list") + self.wer_metric.append(ids, predicted_words, target_words) + self.cer_metric.append(ids, predicted_words, target_words) + + return loss + + def on_stage_start(self, stage, epoch): + """Gets called at the beginning of each epoch""" + if stage != sb.Stage.TRAIN: + self.cer_metric = self.hparams.cer_computer() + self.wer_metric = self.hparams.error_rate_computer() + + if stage == sb.Stage.TEST: + if hasattr(self.hparams, "rescorer"): + self.hparams.rescorer.move_rescorers_to_device() + + def on_stage_end(self, stage, stage_loss, epoch): + """Gets called at the end of an epoch.""" + # Compute/store important stats + stage_stats = {"loss": stage_loss} + if stage == sb.Stage.TRAIN: + self.train_stats = stage_stats + else: + stage_stats["CER"] = self.cer_metric.summarize("error_rate") + stage_stats["WER"] = self.wer_metric.summarize("error_rate") + + # Perform end-of-iteration things, like annealing, logging, etc. 
+ if stage == sb.Stage.VALID: + new_lr_model = self.model_optimizer.param_groups[0]["lr"] + new_lr_wav2vec = self.wav2vec_optimizer.param_groups[0]["lr"] + + self.hparams.train_logger.log_stats( + stats_meta={ + "epoch": epoch, + "lr_model": new_lr_model, + "lr_wav2vec": new_lr_wav2vec, + }, + train_stats=self.train_stats, + valid_stats=stage_stats, + ) + self.checkpointer.save_and_keep_only( + meta={"WER": stage_stats["WER"]}, + min_keys=["WER"], + ) + elif stage == sb.Stage.TEST: + self.hparams.train_logger.log_stats( + stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, + test_stats=stage_stats, + ) + if if_main_process(): + with open(self.hparams.test_wer_file, "w") as w: + self.wer_metric.write_stats(w) + + def on_fit_batch_end(self, batch, outputs, loss, should_step): + """Called after ``fit_batch()``. + + Arguments + --------- + batch : list of torch.Tensors + Batch of data to use for training. Default implementation assumes + this batch has two elements: inputs and targets. + outputs : list or dictionary of torch.Tensors + Returned value of compute_forward(). + loss : torch.Tensor + Returned value of compute_objectives(). + should_step : boolean + Whether optimizer.step() was called or not. + """ + + self.hparams.lr_annealing_model(self.model_optimizer) + self.hparams.lr_annealing_wav2vec(self.wav2vec_optimizer) + + def init_optimizers(self): + "Initializes the wav2vec2 optimizer and model optimizer" + # Handling SpeechBrain vs HuggingFace pretrained models + if hasattr(self.modules, "extractor"): # SpeechBrain pretrained model + self.wav2vec_optimizer = self.hparams.wav2vec_opt_class( + self.modules.encoder_wrapper.parameters() + ) + + else: # HuggingFace pretrained model + self.wav2vec_optimizer = self.hparams.wav2vec_opt_class( + self.modules.wav2vec2.parameters() + ) + + self.model_optimizer = self.hparams.model_opt_class( + self.hparams.model.parameters() + ) + + # save the optimizers in a dictionary + # the key will be used in `freeze_optimizers()` + self.optimizers_dict = { + "model_optimizer": self.model_optimizer, + } + if not self.hparams.freeze_wav2vec: + self.optimizers_dict["wav2vec_optimizer"] = self.wav2vec_optimizer + + if self.checkpointer is not None: + self.checkpointer.add_recoverable( + "wav2vec_opt", self.wav2vec_optimizer + ) + self.checkpointer.add_recoverable("modelopt", self.model_optimizer) + + +def dataio_prepare(hparams, tokenizer): + """This function prepares the datasets to be used in the brain class. + It also defines the data processing pipeline through user-defined functions. + """ + data_folder = hparams["data_folder"] + + train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["train_csv"], + replacements={"data_root": data_folder}, + ) + + if hparams["sorting"] == "ascending": + # we sort training data to speed up training and get better results. + train_data = train_data.filtered_sorted(sort_key="duration") + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "descending": + train_data = train_data.filtered_sorted( + sort_key="duration", reverse=True + ) + # when sorting do not shuffle in dataloader ! 
otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "random": + pass + + else: + raise NotImplementedError( + "sorting must be random, ascending or descending" + ) + + valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["valid_csv"], + replacements={"data_root": data_folder}, + ) + valid_data = valid_data.filtered_sorted(sort_key="duration") + + test_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["test_csv"], + replacements={"data_root": data_folder}, + ) + + # We also sort the validation data so it is faster to validate + test_data = test_data.filtered_sorted(sort_key="duration") + + datasets = [train_data, valid_data, test_data] + + # 2. Define audio pipeline: + @sb.utils.data_pipeline.takes("audio_path", "begin_time", "end_time") + @sb.utils.data_pipeline.provides("sig") + def audio_pipeline(audio_path, begin_time, end_time): + if hparams["download_with_HF"]: + sig = sb.dataio.dataio.read_audio(audio_path) + else: + start_sample = int(float(begin_time) * hparams["sample_rate"]) + stop_sample = int(float(end_time) * hparams["sample_rate"]) + sig = sb.dataio.dataio.read_audio( + {"file": audio_path, "start": start_sample, "stop": stop_sample} + ) + return sig + + sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) + + # 3. Define text pipeline: + @sb.utils.data_pipeline.takes("text") + @sb.utils.data_pipeline.provides( + "wrd", "char_list", "tokens_list", "tokens" + ) + def text_pipeline(wrd): + yield wrd + char_list = list(wrd) + yield char_list + tokens_list = tokenizer.sp.encode_as_ids(wrd) + yield tokens_list + tokens = torch.LongTensor(tokens_list) + yield tokens + + sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) + + # 4. Set output: + sb.dataio.dataset.set_output_keys( + datasets, + ["id", "sig", "text", "char_list", "tokens"], + ) + + # 5. If Dynamic Batching is used, we instantiate the needed samplers. + train_batch_sampler = None + valid_batch_sampler = None + if hparams["dynamic_batching"]: + from speechbrain.dataio.sampler import DynamicBatchSampler # noqa + + dynamic_hparams_train = hparams["dynamic_batch_sampler_train"] + dynamic_hparams_valid = hparams["dynamic_batch_sampler_valid"] + + train_batch_sampler = DynamicBatchSampler( + train_data, + length_func=lambda x: x["duration"], + **dynamic_hparams_train, + ) + valid_batch_sampler = DynamicBatchSampler( + valid_data, + length_func=lambda x: x["duration"], + **dynamic_hparams_valid, + ) + + return ( + train_data, + valid_data, + test_data, + train_batch_sampler, + valid_batch_sampler, + ) + + +if __name__ == "__main__": + + # CLI: + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # create ddp_group with the right communication protocol + sb.utils.distributed.ddp_init_group(run_opts) + + # Create experiment directory + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + # Dataset prep (parsing Librispeech) + from gigaspeech_prepare import prepare_gigaspeech # noqa + + # We run on main for no reason as it is advised to not run this dataprep with + # DDP initialised. Indeed, it takes a lot of time and will most likely + # result in a timeout (internal DDP timeout). 
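+ # Tip: setting data_prep_only: True in the yaml runs only this preparation step and exits right afterwards (see below), which is the recommended way to prepare the data before launching DDP training.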
+ run_on_main( + prepare_gigaspeech, + kwargs={ + "data_folder": hparams["data_folder"], + "save_folder": hparams["save_folder"], + "splits": hparams["splits"], + "output_train": hparams["train_csv"], + "output_dev": hparams["valid_csv"], + "output_test": hparams["test_csv"], + "json_file": hparams["json_file"], + "convert_opus_to_wav": hparams["convert_opus_to_wav"], + "download_with_HF": hparams["download_with_HF"], + "punctuation": hparams["keep_punctuation"], + "skip_prep": hparams["skip_prep"], + "filler": hparams["keep_filler_words"], + }, + ) + + if hparams["data_prep_only"]: + logger.info( + "Data preparation finished. Restart the script with data_prep_only to False. " + ) + import sys + + sys.exit() + + # Defining tokenizer and loading it + tokenizer = SentencePiece( + model_dir=hparams["save_folder"], + vocab_size=hparams["output_neurons"], + annotation_train=hparams["train_csv"], + annotation_read="text", + model_type=hparams["token_type"], + character_coverage=hparams["character_coverage"], + bos_id=hparams["bos_index"], + eos_id=hparams["eos_index"], + ) + + # here we create the datasets objects as well as tokenization and encoding + ( + train_data, + valid_data, + test_data, + train_bsampler, + valid_bsampler, + ) = dataio_prepare(hparams, tokenizer) + + # Trainer initialization + asr_brain = ASR( + modules=hparams["modules"], + hparams=hparams, + run_opts=run_opts, + checkpointer=hparams["checkpointer"], + ) + + # We load the pretrained wav2vec2 model + if "pretrainer" in hparams.keys(): + run_on_main(hparams["pretrainer"].collect_files) + hparams["pretrainer"].load_collected() + + # We dynamically add the tokenizer to our brain class. + # NB: This tokenizer corresponds to the one used for the LM!! + asr_brain.tokenizer = tokenizer + + # Manage dynamic batching + train_dataloader_opts = hparams["train_dataloader_opts"] + valid_dataloader_opts = hparams["valid_dataloader_opts"] + if train_bsampler is not None: + collate_fn = None + if "collate_fn" in train_dataloader_opts: + collate_fn = train_dataloader_opts["collate_fn"] + + train_dataloader_opts = { + "batch_sampler": train_bsampler, + "num_workers": hparams["num_workers"], + } + + if collate_fn is not None: + train_dataloader_opts["collate_fn"] = collate_fn + + if valid_bsampler is not None: + collate_fn = None + if "collate_fn" in valid_dataloader_opts: + collate_fn = valid_dataloader_opts["collate_fn"] + + valid_dataloader_opts = {"batch_sampler": valid_bsampler} + + if collate_fn is not None: + valid_dataloader_opts["collate_fn"] = collate_fn + + vocab_list = [ + tokenizer.sp.id_to_piece(i) for i in range(tokenizer.sp.vocab_size()) + ] + + from speechbrain.decoders.ctc import CTCBeamSearcher + + test_searcher = CTCBeamSearcher( + **hparams["test_beam_search"], + vocab_list=vocab_list, + ) + + # Training + asr_brain.fit( + asr_brain.hparams.epoch_counter, + train_data, + valid_data, + train_loader_kwargs=train_dataloader_opts, + valid_loader_kwargs=valid_dataloader_opts, + ) + + # Testing + os.makedirs(hparams["output_wer_folder"], exist_ok=True) + + # report WER on valid data + asr_brain.hparams.test_wer_file = os.path.join( + hparams["output_wer_folder"], "valid_wer.txt" + ) + asr_brain.evaluate( + valid_data, + min_key="WER", + test_loader_kwargs=hparams["test_dataloader_opts"], + ) + + # report WER on test data + asr_brain.hparams.test_wer_file = os.path.join( + hparams["output_wer_folder"], "test_wer.txt" + ) + asr_brain.evaluate( + test_data, + min_key="WER", + test_loader_kwargs=hparams["test_dataloader_opts"], + 
)
diff --git a/recipes/GigaSpeech/ASR/transducer/README.md b/recipes/GigaSpeech/ASR/transducer/README.md
new file mode 100644
index 0000000000..46e8953160
--- /dev/null
+++ b/recipes/GigaSpeech/ASR/transducer/README.md
@@ -0,0 +1,127 @@
+# GigaSpeech streaming and non-streaming speech recognition with Transducer models
+This folder contains the scripts necessary to run an ASR experiment with the GigaSpeech dataset.
+Before running this recipe, make sure numba is installed (`pip install numba`).
+
+## Data access and download
+
+**The XL set is fairly large: 2.2TB are necessary to store the compressed and uncompressed versions of the data.**
+
+SpeechBrain supports two ways of dealing with the GigaSpeech dataset:
+1. [HuggingFace dataset](https://huggingface.co/datasets/speechcolab/gigaspeech/). For HuggingFace, note that **you must use** the HuggingFace client to log in first before running the recipe.
+2. [Original Github](https://github.com/SpeechColab/GigaSpeech).
+
+You simply need to follow the instructions on either of the above links. **We strongly
+recommend using HuggingFace, as the download speed for people outside of China is
+much quicker**.
+
+## Data preparation
+
+**This step can be very long for the XL split of GigaSpeech, depending on your internet connection and filesystem. For DDP (multi-GPU) training, the recipe must be run once without DDP, otherwise data preparation will time out. You also do not want to leave several GPUs sitting idle for hours. Use the *data_prep_only* flag from the yaml to exit right after data preparation.**
+
+SpeechBrain will automatically download the dataset if you use HuggingFace. Note that if you use HuggingFace, the *data_folder* argument is used to store the **extracted** dataset. However, HuggingFace first needs to download the compressed data, and this is not stored in *data_folder* by default. Indeed, HuggingFace is a bit strict in the way it operates with datasets, and the data will be put into the folder specified by the environment variable *HF_HUB_CACHE* or, if not set, *HF_HOME* or, if not set, *XDG_CACHE_HOME*. Hence, we recommend setting *HF_HUB_CACHE* to the place where you want to store the data first. For example, you can set it like this:
+
+```export HF_HUB_CACHE=/path/to/your/data/folder```
+
+# Extra-Dependencies
+This recipe supports two implementations of the transducer loss; see the `use_torchaudio` arg in the yaml file:
+1. Transducer loss from torchaudio (this requires torchaudio version >= 0.10.0).
+2. SpeechBrain implementation using Numba. To use it, please set `use_torchaudio=False` in the yaml file. This version is implemented within SpeechBrain and allows you to directly access the python code of the transducer loss (and directly modify it if needed).
+
+The Numba implementation is currently enabled by default because the `use_torchaudio` option is incompatible with `bfloat16` training.
+
+Note: Before running this recipe, make sure numba is installed. Otherwise, run:
+```
+pip install numba
+```
+
+# How to run it
+```shell
+python train.py hparams/conformer_transducer.yaml
+```
+
+## Precision Notes
+If your GPU effectively supports fp16 (half-precision) computations, it is recommended to execute the training script with the `--precision=fp16` (or `--precision=bf16`) option.
+Enabling half precision can significantly reduce the peak VRAM requirements. For example, in the case of the Conformer Transducer recipe trained with GigaSpeech, the peak VRAM decreases from 39GB to 12GB when using fp16.
+According to our tests, the performance is not affected. + +# Results (non-streaming) + +Results are obtained with beam search and no LM (no-streaming i.e. full context). + +**TBD: The final models are currently in training.** This model has already been successfully trained, though. This will be updated when the checkpoints are ready for download. + + + + + +## Streaming model + +### WER vs chunk size & left context + +The following matrix presents the Word Error Rate (WER%) achieved on GigaSpeech +`test` with various chunk sizes (in ms). + +The relative difference is not trivial to interpret, because we are not testing +against a continuous stream of speech, but rather against utterances of various +lengths. This tends to bias results in favor of larger chunk sizes. + +The chunk size might not accurately represent expected latency due to slight +padding differences in streaming contexts. + +The left chunk size is not representative of the receptive field of the model. +Because the model caches the streaming context at different layers, the model +may end up forming indirect dependencies to audio many seconds ago. + +| | full | cs=32 (1280ms) | 16 (640ms) | 8 (320ms) | +|:-----:|:----:|:-----:|:-----:|:-----:| + +**TBD: The final models are currently in training.** This model has already been successfully trained, though. This will be updated when the checkpoints are ready for download. + +### Inference + +Once your model is trained, you need a few manual steps in order to use it with the high-level streaming interfaces (`speechbrain.inference.ASR.StreamingASR`): + +1. Create a new directory where you want to store the model. +2. Copy `results/conformer_transducer//lm.ckpt` (optional; currently, for streaming rescoring LMs might be unsupported) and `tokenizer.ckpt` to that directory. +3. Copy `results/conformer_transducer//save/CKPT+????/model.ckpt` and `normalizer.ckpt` to that directory. +4. Copy your hyperparameters file to that directory. Uncomment the streaming specific keys and remove any training-specific keys. Alternatively, grab the inference hyperparameters YAML for this model from HuggingFace and adapt it to any changes you may have done. +5. You can now instantiate a `StreamingASR` with your model using `StreamingASR.from_hparams("/path/to/model/")`. + +The contents of that directory may be uploaded as a HuggingFace model, in which case the model source path can just be specified as `youruser/yourmodel`. + +# **About SpeechBrain** +- Website: https://speechbrain.github.io/ +- Code: https://github.com/speechbrain/speechbrain/ +- HuggingFace: https://huggingface.co/speechbrain/ + + +# **Citing SpeechBrain** +Please, cite SpeechBrain if you use it for your research or business. 
+ +```bibtex +@misc{speechbrainV1, + title={Open-Source Conversational AI with SpeechBrain 1.0}, + author={Mirco Ravanelli and Titouan Parcollet and Adel Moumen and Sylvain de Langen and Cem Subakan and Peter Plantinga and Yingzhi Wang and Pooneh Mousavi and Luca Della Libera and Artem Ploujnikov and Francesco Paissan and Davide Borra and Salah Zaiem and Zeyu Zhao and Shucong Zhang and Georgios Karakasidis and Sung-Lin Yeh and Pierre Champion and Aku Rouhe and Rudolf Braun and Florian Mai and Juan Zuluaga-Gomez and Seyed Mahed Mousavi and Andreas Nautsch and Xuechen Liu and Sangeet Sagar and Jarod Duret and Salima Mdhaffar and Gaelle Laperriere and Mickael Rouvier and Renato De Mori and Yannick Esteve}, + year={2024}, + eprint={2407.00463}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2407.00463}, +} +@misc{speechbrain, + title={{SpeechBrain}: A General-Purpose Speech Toolkit}, + author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio}, + year={2021}, + eprint={2106.04624}, + archivePrefix={arXiv}, + primaryClass={eess.AS}, + note={arXiv:2106.04624} +} +``` diff --git a/recipes/GigaSpeech/ASR/transducer/dataset.py b/recipes/GigaSpeech/ASR/transducer/dataset.py new file mode 120000 index 0000000000..f3bfeaf826 --- /dev/null +++ b/recipes/GigaSpeech/ASR/transducer/dataset.py @@ -0,0 +1 @@ +../../dataset.py \ No newline at end of file diff --git a/recipes/GigaSpeech/ASR/transducer/extra_requirements.txt b/recipes/GigaSpeech/ASR/transducer/extra_requirements.txt new file mode 100644 index 0000000000..f582033930 --- /dev/null +++ b/recipes/GigaSpeech/ASR/transducer/extra_requirements.txt @@ -0,0 +1,8 @@ +datasets +# Numba is used if use_torchaudio=False +# Numba might be faster, but it is harder to install +# You might need to install numba with conda +# You might also need to install other packages such as cudatoolkit +numba +soundfile +speechcolab diff --git a/recipes/GigaSpeech/ASR/transducer/gigaspeech_prepare.py b/recipes/GigaSpeech/ASR/transducer/gigaspeech_prepare.py new file mode 120000 index 0000000000..5190685a8e --- /dev/null +++ b/recipes/GigaSpeech/ASR/transducer/gigaspeech_prepare.py @@ -0,0 +1 @@ +../../gigaspeech_prepare.py \ No newline at end of file diff --git a/recipes/GigaSpeech/ASR/transducer/hparams/conformer_transducer.yaml b/recipes/GigaSpeech/ASR/transducer/hparams/conformer_transducer.yaml new file mode 100644 index 0000000000..3024e78522 --- /dev/null +++ b/recipes/GigaSpeech/ASR/transducer/hparams/conformer_transducer.yaml @@ -0,0 +1,402 @@ +# ############################################################################ +# Model: E2E ASR with transformer and transducer +# Encoder: Conformer +# Decoder: LSTM + beamsearch + RNNLM +# Tokens: BPE with unigram +# losses: Transducer + CTC (optional) + CE (optional) +# Training: GigaSpeech +# Authors: Titouan Parcollet 2024 +# ############################################################################ + +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 1986 +__set_seed: !apply:speechbrain.utils.seed_everything [!ref ] +experiment_name: conformer_transducer +output_folder: !ref results// +output_wer_folder: !ref / +save_folder: 
!ref /save +train_log: !ref /train_log.txt + +# Data files +data_folder: !PLACEHOLDER # e,g./path/to/GigaSpeech + +# see https://github.com/SpeechColab/GigaSpeech for more details on the dataset +# must be one of ["XS", "S", "M", "L", "XL"] +# and ["DEV", "TEST"] for the eval splits. +splits: ["XL", "DEV", "TEST"] +skip_prep: False +data_prep_only: False +download_with_HF: True +convert_opus_to_wav: True +keep_filler_words: False +keep_punctuation: False +ckpt_interval_minutes: 10 # save checkpoint every N min +train_csv: !ref /train.csv +valid_csv: !ref /dev.csv +test_csv: !ref /test.csv +json_file: !ref /GigaSpeech.json + +####################### Training Parameters #################################### + +# To make Transformers converge, the global bath size should be large enough. +# The global batch size is computed as batch_size * n_gpus * grad_accumulation_factor. +# Empirically, we found that this value should be >= 128. +# Please, set your parameters accordingly. +number_of_epochs: 10 +optimizer_step_limit: 400000 +warmup_steps: 30000 +num_workers: 4 +batch_size_valid: 4 +lr: 0.0008 +weight_decay: 0.01 +number_of_ctc_epochs: 2 +ctc_weight: 0.3 # Multitask with CTC for the encoder (0.0 = disabled) +ce_weight: 0.0 # Multitask with CE for the decoder (0.0 = disabled) +max_grad_norm: 5.0 +loss_reduction: 'batchmean' +precision: fp16 # bf16, fp16 or fp32 +grad_accumulation_factor: 1 + +# The batch size is used if and only if dynamic batching is set to False +# Validation and testing are done with fixed batches and not dynamic batching. +batch_size: 8 + +sorting: random +avg_checkpoints: 1 # Number of checkpoints to average for evaluation + +# Feature parameters +sample_rate: 16000 +n_fft: 512 +n_mels: 80 +win_length: 32 + +# Streaming & dynamic chunk training options +# At least for the current architecture on LibriSpeech, we found out that +# non-streaming accuracy is very similar between `streaming: True` and +# `streaming: False`. +streaming: True # controls all Dynamic Chunk Training & chunk size & left context mechanisms + +# Configuration for Dynamic Chunk Training. +# In this model, a chunk is roughly equivalent to 40ms of audio. +dynchunktrain_config_sampler: !new:speechbrain.utils.dynamic_chunk_training.DynChunkTrainConfigRandomSampler # yamllint disable-line rule:line-length + chunkwise_prob: 0.6 # Probability during a batch to limit attention and sample a random chunk size in the following range + chunk_size_min: 8 # Minimum chunk size (if in a DynChunkTrain batch) + chunk_size_max: 32 # Maximum chunk size (if in a DynChunkTrain batch) + limited_left_context_prob: 0.75 # If in a DynChunkTrain batch, the probability during a batch to restrict left context to a random number of chunks + left_context_chunks_min: 2 # Minimum left context size (in # of chunks) + left_context_chunks_max: 32 # Maximum left context size (in # of chunks) + # If you specify a valid/test config, you can optionally have evaluation be + # done with a specific DynChunkTrain configuration. + # valid_config: !new:speechbrain.utils.dynamic_chunk_training.DynChunkTrainConfig + # chunk_size: 24 + # left_context_size: 16 + # test_config: ... + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + num_workers: !ref + +valid_dataloader_opts: + batch_size: !ref + +test_dataloader_opts: + batch_size: !ref + +# Using dynamic batching by default. This works with 48GB GPUs +# Or turn it off (but training speed will decrease) +# Play with grad_accum_factor such that the total batch is around 600 to 1500 s. 
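+# With DDP, the effective batch per optimizer step is roughly max_batch_length_train (seconds of audio per GPU) * number of GPUs * grad_accumulation_factor.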
+dynamic_batching: True +max_batch_length_train: 250 +max_batch_length_val: 50 # we reduce it as the beam is much wider (VRAM) +num_bucket: 200 +shuffle: True # if true re-creates batches at each epoch shuffling examples. +batch_ordering: random +max_batch_ex: 256 + +dynamic_batch_sampler_train: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +dynamic_batch_sampler_valid: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +# BPE parameters +token_type: unigram # ["unigram", "bpe", "char"] +character_coverage: 1.0 + +####################### Model Parameters ####################################### + +# Transformer +d_model: 768 +joint_dim: 512 +nhead: 8 +num_encoder_layers: 12 +num_decoder_layers: 0 +d_ffn: 2048 +transformer_dropout: 0.1 +activation: !name:torch.nn.GELU +output_neurons: 1024 +dec_dim: 512 +dec_emb_dropout: 0.2 +dec_dropout: 0.1 + +# Decoding parameters +blank_index: 0 +bos_index: 1 +eos_index: 2 +pad_index: 0 +beam_size: 10 +nbest: 1 +# by default {state,expand}_beam = 2.3 as mention in paper +# https://arxiv.org/abs/1904.02619 +state_beam: 2.3 +expand_beam: 2.3 + +# If True uses torchaudio loss. Otherwise, the numba one +use_torchaudio: False + +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +normalize: !new:speechbrain.processing.features.InputNormalization + norm_type: global + update_until_epoch: 4 + +compute_features: !new:speechbrain.lobes.features.Fbank + sample_rate: !ref + n_fft: !ref + n_mels: !ref + win_length: !ref + +############################## Augmentations ################################### + +# Speed perturbation +speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb + orig_freq: !ref + speeds: [95, 100, 105] + +# Augmenter: Combines previously defined augmentations to perform data augmentation +wav_augment: !new:speechbrain.augment.augmenter.Augmenter + min_augmentations: 1 + max_augmentations: 1 + augment_prob: 1.0 + augmentations: [!ref ] + + +# Time Drop +time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop + drop_length_low: 12 + drop_length_high: 20 + drop_count_low: 1 + drop_count_high: 1 + replace: "zeros" + +# Frequency Drop +freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop + drop_length_low: 20 + drop_length_high: 25 + drop_count_low: 2 + drop_count_high: 2 + replace: "zeros" + dim: 2 + +# Time warp +time_warp: !new:speechbrain.augment.freq_domain.Warping + +fea_augment: !new:speechbrain.augment.augmenter.Augmenter + parallel_augment: False + concat_original: False + repeat_augment: 1 + shuffle_augmentations: False + min_augmentations: 2 + max_augmentations: 2 + augment_prob: 1.0 + augmentations: [ + !ref , + !ref , + !ref ] + +############################## Models ########################################## + +CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd + input_shape: (8, 10, 80) + num_blocks: 2 + num_layers_per_block: 1 + out_channels: (64, 32) + kernel_sizes: (3, 3) + strides: (2, 2) + residuals: (False, False) + +Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length + input_size: 640 + tgt_vocab: !ref + d_model: !ref + nhead: !ref + num_encoder_layers: !ref + num_decoder_layers: !ref + d_ffn: !ref + dropout: !ref + activation: !ref + encoder_module: conformer + attention_type: RelPosMHAXL + normalize_before: True + causal: False + +# We must call an encoder wrapper so 
the decoder isn't run (we don't have any) +enc: !new:speechbrain.lobes.models.transformer.TransformerASR.EncoderWrapper + transformer: !ref + +# For MTL CTC over the encoder +proj_ctc: !new:speechbrain.nnet.linear.Linear + input_size: !ref + n_neurons: !ref + +# Define some projection layers to make sure that enc and dec +# output dim are the same before joining +proj_enc: !new:speechbrain.nnet.linear.Linear + input_size: !ref + n_neurons: !ref + bias: False + +proj_dec: !new:speechbrain.nnet.linear.Linear + input_size: !ref + n_neurons: !ref + bias: False + +# Uncomment for MTL with CTC +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + reduction: !ref + +emb: !new:speechbrain.nnet.embedding.Embedding + num_embeddings: !ref + consider_as_one_hot: True + blank_id: !ref + +dec: !new:speechbrain.nnet.RNN.LSTM + input_shape: [null, null, !ref - 1] + hidden_size: !ref + num_layers: 1 + re_init: True + +# For MTL with LM over the decoder (need to uncomment to activate) +# dec_lin: !new:speechbrain.nnet.linear.Linear +# input_size: !ref +# n_neurons: !ref +# bias: False + +# For MTL +ce_cost: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.1 + +Tjoint: !new:speechbrain.nnet.transducer.transducer_joint.Transducer_joint + joint: sum # joint [sum | concat] + nonlinearity: !ref + +transducer_lin: !new:speechbrain.nnet.linear.Linear + input_size: !ref + n_neurons: !ref + bias: False + +log_softmax: !new:speechbrain.nnet.activations.Softmax + apply_log: True + +transducer_cost: !name:speechbrain.nnet.losses.transducer_loss + blank_index: !ref + use_torchaudio: !ref + +# for MTL +# update model if any HEAD module is added +modules: + CNN: !ref + enc: !ref + emb: !ref + dec: !ref + Tjoint: !ref + transducer_lin: !ref + normalize: !ref + proj_ctc: !ref + proj_dec: !ref + proj_enc: !ref + + +# update model if any HEAD module is added +model: !new:torch.nn.ModuleList + - [!ref , !ref , !ref , !ref , !ref , !ref , !ref , !ref ] + +############################## Decoding & optimiser ############################ + +Greedysearcher: !new:speechbrain.decoders.transducer.TransducerBeamSearcher + decode_network_lst: [!ref , !ref , !ref ] + tjoint: !ref + classifier_network: [!ref ] + blank_id: !ref + beam_size: 1 + nbest: 1 + +Beamsearcher: !new:speechbrain.decoders.transducer.TransducerBeamSearcher + decode_network_lst: [!ref , !ref , !ref ] + tjoint: !ref + classifier_network: [!ref ] + blank_id: !ref + beam_size: !ref + nbest: !ref + state_beam: !ref + expand_beam: !ref + +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler + lr_initial: !ref + n_warmup_steps: !ref + +############################## Logging and Pretrainer ########################## + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + normalizer: !ref + counter: !ref + + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: True + +# for the inference hparams, you will need to include and uncomment something like this: + +# make_tokenizer_streaming_context: !name:speechbrain.tokenizers.SentencePiece.SentencePieceDecoderStreamingContext +# tokenizer_decode_streaming: 
!name:speechbrain.tokenizers.SentencePiece.spm_decode_preserve_leading_space + +# make_decoder_streaming_context: !name:speechbrain.decoders.transducer.TransducerGreedySearcherStreamingContext # default constructor +# decoding_function: !name:speechbrain.decoders.transducer.TransducerBeamSearcher.transducer_greedy_decode_streaming +# - !ref # self + +# fea_streaming_extractor: !new:speechbrain.lobes.features.StreamingFeatureWrapper +# module: !new:speechbrain.nnet.containers.LengthsCapableSequential +# - !ref +# - !ref +# - !ref +# # don't consider normalization as part of the input filter chain. +# # normalization will operate at chunk level, which mismatches training +# # somewhat, but does not appear to result in noticeable degradation. +# properties: !apply:speechbrain.utils.filter_analysis.stack_filter_properties +# - [!ref , !ref ] diff --git a/recipes/GigaSpeech/ASR/transducer/train.py b/recipes/GigaSpeech/ASR/transducer/train.py new file mode 100644 index 0000000000..280946a065 --- /dev/null +++ b/recipes/GigaSpeech/ASR/transducer/train.py @@ -0,0 +1,540 @@ +#!/usr/bin/env/python3 +"""Recipe for training a Transducer ASR system with GigaSpeech. +The system employs an encoder, a decoder, and an joint network +between them. Decoding is performed with beamsearch coupled with a neural +language model. + +To run this recipe, do the following: +> python train.py hparams/conformer_transducer.yaml + +With the default hyperparameters, the system employs a conformer encoder. +The decoder is based on a standard LSTM. Beamsearch coupled with a RNN +language model is used on the top of decoder probabilities. + +The neural network is trained on both CTC and negative-log likelihood +targets and sub-word units estimated with Byte Pairwise Encoding (BPE) +are used as basic recognition tokens. + +The experiment file is flexible enough to support a large variety of +different systems. By properly changing the parameter files, you can try +different encoders, decoders, tokens (e.g, characters instead of BPE), +training split, and many +other possible variations. + + +Authors + * Sylvain de Langen 2024 + * Titouan Parcollet 2024 + * Abdel Heba 2020 + * Mirco Ravanelli 2020 + * Ju-Chieh Chou 2020 + * Peter Plantinga 2020 +""" + +import os +import sys + +import torch +from hyperpyyaml import load_hyperpyyaml + +import speechbrain as sb +from speechbrain.tokenizers.SentencePiece import SentencePiece +from speechbrain.utils.data_utils import undo_padding +from speechbrain.utils.distributed import if_main_process, run_on_main +from speechbrain.utils.logger import get_logger + +logger = get_logger(__name__) + + +class ASR(sb.Brain): + def compute_forward(self, batch, stage): + """Forward computations from the waveform batches to the output probabilities.""" + batch = batch.to(self.device) + wavs, wav_lens = batch.sig + tokens_with_bos, token_with_bos_lens = batch.tokens_bos + + # Add waveform augmentation if specified. + if stage == sb.Stage.TRAIN: + if hasattr(self.hparams, "wav_augment"): + wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens) + tokens_with_bos = self.hparams.wav_augment.replicate_labels( + tokens_with_bos + ) + + feats = self.hparams.compute_features(wavs) + + # Add feature augmentation if specified. 
+ if stage == sb.Stage.TRAIN and hasattr(self.hparams, "fea_augment"): + feats, fea_lens = self.hparams.fea_augment(feats, wav_lens) + tokens_with_bos = self.hparams.fea_augment.replicate_labels( + tokens_with_bos + ) + + current_epoch = self.hparams.epoch_counter.current + + # Old models may not have the streaming hparam, we don't break them in + # any other way so just check for its presence + if hasattr(self.hparams, "streaming") and self.hparams.streaming: + dynchunktrain_config = self.hparams.dynchunktrain_config_sampler( + stage + ) + else: + dynchunktrain_config = None + + feats = self.modules.normalize(feats, wav_lens, epoch=current_epoch) + + src = self.modules.CNN(feats) + x = self.modules.enc( + src, + wav_lens, + pad_idx=self.hparams.pad_index, + dynchunktrain_config=dynchunktrain_config, + ) + x = self.modules.proj_enc(x) + + e_in = self.modules.emb(tokens_with_bos) + e_in = torch.nn.functional.dropout( + e_in, + self.hparams.dec_emb_dropout, + training=(stage == sb.Stage.TRAIN), + ) + h, _ = self.modules.dec(e_in) + h = torch.nn.functional.dropout( + h, self.hparams.dec_dropout, training=(stage == sb.Stage.TRAIN) + ) + h = self.modules.proj_dec(h) + + # Joint network + # add labelseq_dim to the encoder tensor: [B,T,H_enc] => [B,T,1,H_enc] + # add timeseq_dim to the decoder tensor: [B,U,H_dec] => [B,1,U,H_dec] + joint = self.modules.Tjoint(x.unsqueeze(2), h.unsqueeze(1)) + + # Output layer for transducer log-probabilities + logits_transducer = self.modules.transducer_lin(joint) + + # Compute outputs + if stage == sb.Stage.TRAIN: + p_ctc = None + p_ce = None + + if ( + self.hparams.ctc_weight > 0.0 + and current_epoch <= self.hparams.number_of_ctc_epochs + ): + # Output layer for ctc log-probabilities + out_ctc = self.modules.proj_ctc(x) + p_ctc = self.hparams.log_softmax(out_ctc) + + if self.hparams.ce_weight > 0.0: + # Output layer for ctc log-probabilities + p_ce = self.modules.dec_lin(h) + p_ce = self.hparams.log_softmax(p_ce) + + return p_ctc, p_ce, logits_transducer, wav_lens + + elif stage == sb.Stage.VALID: + best_hyps, scores, _, _ = self.hparams.Greedysearcher(x) + return logits_transducer, wav_lens, best_hyps + else: + ( + best_hyps, + best_scores, + nbest_hyps, + nbest_scores, + ) = self.hparams.Beamsearcher(x) + return logits_transducer, wav_lens, best_hyps + + def compute_objectives(self, predictions, batch, stage): + """Computes the loss (Transducer+(CTC+NLL)) given predictions and targets.""" + + ids = batch.id + tokens, token_lens = batch.tokens + tokens_eos, token_eos_lens = batch.tokens_eos + + # Train returns 4 elements vs 3 for val and test + if len(predictions) == 4: + p_ctc, p_ce, logits_transducer, wav_lens = predictions + else: + logits_transducer, wav_lens, predicted_tokens = predictions + + if stage == sb.Stage.TRAIN: + # Labels must be extended if parallel augmentation or concatenated + # augmentation was performed on the input (increasing the time dimension) + if hasattr(self.hparams, "fea_augment"): + ( + tokens, + token_lens, + tokens_eos, + token_eos_lens, + ) = self.hparams.fea_augment.replicate_multiple_labels( + tokens, token_lens, tokens_eos, token_eos_lens + ) + + if stage == sb.Stage.TRAIN: + CTC_loss = 0.0 + CE_loss = 0.0 + if p_ctc is not None: + CTC_loss = self.hparams.ctc_cost( + p_ctc, tokens, wav_lens, token_lens + ) + if p_ce is not None: + CE_loss = self.hparams.ce_cost( + p_ce, tokens_eos, length=token_eos_lens + ) + loss_transducer = self.hparams.transducer_cost( + logits_transducer, tokens, wav_lens, token_lens + ) + loss = ( + 
self.hparams.ctc_weight * CTC_loss + + self.hparams.ce_weight * CE_loss + + (1 - (self.hparams.ctc_weight + self.hparams.ce_weight)) + * loss_transducer + ) + else: + loss = self.hparams.transducer_cost( + logits_transducer, tokens, wav_lens, token_lens + ) + + if stage != sb.Stage.TRAIN: + # Decode token terms to words + predicted_words = self.tokenizer( + predicted_tokens, task="decode_from_list" + ) + + # Convert indices to words + target_words = undo_padding(tokens, token_lens) + target_words = self.tokenizer(target_words, task="decode_from_list") + self.wer_metric.append(ids, predicted_words, target_words) + self.cer_metric.append(ids, predicted_words, target_words) + + return loss + + def on_fit_batch_end(self, batch, outputs, loss, should_step): + """At the end of the optimizer step, apply noam annealing.""" + if should_step: + self.hparams.noam_annealing(self.optimizer) + + def on_stage_start(self, stage, epoch): + """Gets called at the beginning of each epoch""" + if stage != sb.Stage.TRAIN: + self.cer_metric = self.hparams.cer_computer() + self.wer_metric = self.hparams.error_rate_computer() + + def on_stage_end(self, stage, stage_loss, epoch): + """Gets called at the end of a epoch.""" + # Compute/store important stats + stage_stats = {"loss": stage_loss} + if stage == sb.Stage.TRAIN: + self.train_stats = stage_stats + else: + stage_stats["CER"] = self.cer_metric.summarize("error_rate") + stage_stats["WER"] = self.wer_metric.summarize("error_rate") + + # Perform end-of-iteration things, like annealing, logging, etc. + if stage == sb.Stage.VALID: + lr = self.hparams.noam_annealing.current_lr + steps = self.optimizer_step + optimizer = self.optimizer.__class__.__name__ + + epoch_stats = { + "epoch": epoch, + "lr": lr, + "steps": steps, + "optimizer": optimizer, + } + + self.hparams.train_logger.log_stats( + stats_meta=epoch_stats, + train_stats=self.train_stats, + valid_stats=stage_stats, + ) + self.checkpointer.save_and_keep_only( + meta={"WER": stage_stats["WER"], "epoch": epoch}, + min_keys=["WER"], + num_to_keep=self.hparams.avg_checkpoints, + ) + + elif stage == sb.Stage.TEST: + self.hparams.train_logger.log_stats( + stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, + test_stats=stage_stats, + ) + if if_main_process(): + with open(self.hparams.test_wer_file, "w") as w: + self.wer_metric.write_stats(w) + + # save the averaged checkpoint at the end of the evaluation stage + # delete the rest of the intermediate checkpoints + # WER is set to -0.1 so checkpointer only keeps the averaged checkpoint + self.checkpointer.save_and_keep_only( + meta={"WER": -0.1, "epoch": epoch}, + min_keys=["WER"], + num_to_keep=1, + ) + + def on_evaluate_start(self, max_key=None, min_key=None): + """perform checkpoint average if needed""" + super().on_evaluate_start() + + ckpts = self.checkpointer.find_checkpoints( + max_key=max_key, + min_key=min_key, + ) + ckpt = sb.utils.checkpoints.average_checkpoints( + ckpts, recoverable_name="model" + ) + + self.hparams.model.load_state_dict(ckpt, strict=True) + self.hparams.model.eval() + + +def dataio_prepare(hparams): + """This function prepares the datasets to be used in the brain class. + It also defines the data processing pipeline through user-defined functions. 
+ """ + data_folder = hparams["data_folder"] + + train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["train_csv"], + replacements={"data_root": data_folder}, + ) + + if hparams["sorting"] == "ascending": + # we sort training data to speed up training and get better results. + train_data = train_data.filtered_sorted(sort_key="duration") + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "descending": + train_data = train_data.filtered_sorted( + sort_key="duration", reverse=True + ) + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "random": + pass + + else: + raise NotImplementedError( + "sorting must be random, ascending or descending" + ) + + valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["valid_csv"], + replacements={"data_root": data_folder}, + ) + valid_data = valid_data.filtered_sorted(sort_key="duration") + + test_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["test_csv"], + replacements={"data_root": data_folder}, + ) + + # We also sort the validation data so it is faster to validate + test_data = test_data.filtered_sorted(sort_key="duration") + + datasets = [train_data, valid_data, test_data] + + # 2. Define audio pipeline: + @sb.utils.data_pipeline.takes("audio_path", "begin_time", "end_time") + @sb.utils.data_pipeline.provides("sig") + def audio_pipeline(audio_path, begin_time, end_time): + if hparams["download_with_HF"]: + sig = sb.dataio.dataio.read_audio(audio_path) + else: + start_sample = int(float(begin_time) * hparams["sample_rate"]) + stop_sample = int(float(end_time) * hparams["sample_rate"]) + sig = sb.dataio.dataio.read_audio( + {"file": audio_path, "start": start_sample, "stop": stop_sample} + ) + return sig + + sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) + + # 3. Define text pipeline: + @sb.utils.data_pipeline.takes("text") + @sb.utils.data_pipeline.provides( + "wrd", "tokens_list", "tokens_bos", "tokens_eos", "tokens" + ) + def text_pipeline(wrd): + yield wrd + tokens_list = tokenizer.sp.encode_as_ids(wrd) + yield tokens_list + tokens_bos = torch.LongTensor([hparams["bos_index"]] + (tokens_list)) + yield tokens_bos + tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]]) + yield tokens_eos + tokens = torch.LongTensor(tokens_list) + yield tokens + + sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) + + # 4. Set output: + sb.dataio.dataset.set_output_keys( + datasets, + ["id", "sig", "wrd", "tokens_bos", "tokens_eos", "tokens"], + ) + + # 5. If Dynamic Batching is used, we instantiate the needed samplers. 
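+    # A dynamic batch sampler groups utterances of similar duration so that
+    # each batch contains roughly the same total amount of audio, as defined
+    # by the dynamic_batch_sampler_train/valid entries of the yaml. This
+    # reduces padding compared to a fixed batch size.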
+    train_batch_sampler = None
+    valid_batch_sampler = None
+    if hparams["dynamic_batching"]:
+        from speechbrain.dataio.sampler import DynamicBatchSampler  # noqa
+
+        dynamic_hparams_train = hparams["dynamic_batch_sampler_train"]
+        dynamic_hparams_valid = hparams["dynamic_batch_sampler_valid"]
+
+        train_batch_sampler = DynamicBatchSampler(
+            train_data,
+            length_func=lambda x: x["duration"],
+            **dynamic_hparams_train,
+        )
+        valid_batch_sampler = DynamicBatchSampler(
+            valid_data,
+            length_func=lambda x: x["duration"],
+            **dynamic_hparams_valid,
+        )
+
+    return (
+        train_data,
+        valid_data,
+        test_data,
+        train_batch_sampler,
+        valid_batch_sampler,
+    )
+
+
+if __name__ == "__main__":
+    # CLI:
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+
+    # Use the torchaudio transducer loss if the device is CPU
+    if run_opts.get("device") == "cpu":
+        if "use_torchaudio: False" in overrides:
+            overrides = overrides.replace(
+                "use_torchaudio: False", "use_torchaudio: True"
+            )
+        elif "use_torchaudio: True" not in overrides:
+            overrides += "\nuse_torchaudio: True"
+
+    # create ddp_group with the right communication protocol
+    sb.utils.distributed.ddp_init_group(run_opts)
+
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    # Dataset preparation (parsing GigaSpeech)
+    from gigaspeech_prepare import prepare_gigaspeech  # noqa
+
+    # multi-gpu (ddp) save data preparation
+    run_on_main(
+        prepare_gigaspeech,
+        kwargs={
+            "data_folder": hparams["data_folder"],
+            "save_folder": hparams["save_folder"],
+            "splits": hparams["splits"],
+            "output_train": hparams["train_csv"],
+            "output_dev": hparams["valid_csv"],
+            "output_test": hparams["test_csv"],
+            "json_file": hparams["json_file"],
+            "skip_prep": hparams["skip_prep"],
+            "convert_opus_to_wav": hparams["convert_opus_to_wav"],
+            "download_with_HF": hparams["download_with_HF"],
+            "punctuation": hparams["keep_punctuation"],
+            "filler": hparams["keep_filler_words"],
+        },
+    )
+
+    if hparams["data_prep_only"]:
+        logger.info(
+            "Data preparation finished. Set data_prep_only to False and restart the script to start training."
+        )
+        sys.exit()
+
+    # Defining tokenizer and loading it
+    tokenizer = SentencePiece(
+        model_dir=hparams["save_folder"],
+        vocab_size=hparams["output_neurons"],
+        annotation_train=hparams["train_csv"],
+        annotation_read="text",
+        model_type=hparams["token_type"],
+        character_coverage=hparams["character_coverage"],
+        bos_id=hparams["bos_index"],
+        eos_id=hparams["eos_index"],
+    )
+
+    # here we create the datasets objects as well as tokenization and encoding
+    (
+        train_data,
+        valid_data,
+        test_data,
+        train_bsampler,
+        valid_bsampler,
+    ) = dataio_prepare(hparams)
+
+    # Trainer initialization
+    asr_brain = ASR(
+        modules=hparams["modules"],
+        opt_class=hparams["opt_class"],
+        hparams=hparams,
+        run_opts=run_opts,
+        checkpointer=hparams["checkpointer"],
+    )
+
+    # We dynamically add the tokenizer to our brain class.
+    asr_brain.tokenizer = tokenizer
+
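+    # When dynamic batching is enabled, the samplers returned by
+    # dataio_prepare replace the static batch_size/shuffle options defined
+    # in the yaml dataloader options below.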
+ train_dataloader_opts = hparams["train_dataloader_opts"] + valid_dataloader_opts = hparams["valid_dataloader_opts"] + + if train_bsampler is not None: + train_dataloader_opts = { + "batch_sampler": train_bsampler, + "num_workers": hparams["num_workers"], + } + + if valid_bsampler is not None: + valid_dataloader_opts = {"batch_sampler": valid_bsampler} + + # Training + asr_brain.fit( + asr_brain.hparams.epoch_counter, + train_data, + valid_data, + train_loader_kwargs=train_dataloader_opts, + valid_loader_kwargs=valid_dataloader_opts, + ) + + # Testing + os.makedirs(hparams["output_wer_folder"], exist_ok=True) + + # report WER on valid data + asr_brain.hparams.test_wer_file = os.path.join( + hparams["output_wer_folder"], "valid_wer.txt" + ) + asr_brain.evaluate( + valid_data, + min_key="WER", + test_loader_kwargs=hparams["test_dataloader_opts"], + ) + + # report WER on test data + asr_brain.hparams.test_wer_file = os.path.join( + hparams["output_wer_folder"], "test_wer.txt" + ) + asr_brain.evaluate( + test_data, + min_key="WER", + test_loader_kwargs=hparams["test_dataloader_opts"], + ) diff --git a/recipes/GigaSpeech/README.md b/recipes/GigaSpeech/README.md new file mode 100644 index 0000000000..a71fd9593d --- /dev/null +++ b/recipes/GigaSpeech/README.md @@ -0,0 +1,13 @@ +# Experimenting with the GigaSpeech dataset + +GigaSpeech is an evolving, multi-domain English speech recognition corpus with 10,000 hours of high quality labeled audio suitable for supervised training, and 40,000 hours of total audio suitable for semi-supervised and unsupervised training (this implementation contains only labelled data for now). However, the data access is gated, meaning, you need to request access to it. + +# Data access and download + +SpeechBrain supports two ways of dealing with the GigaSpeech dataset: +1. [HuggingFace dataset](https://huggingface.co/datasets/speechcolab/gigaspeech/). For HuggingFace note that **you must use** the HuggingFace client to log in first before running the recipe. +2. [Original Github](https://github.com/SpeechColab/GigaSpeech). + +You simply need to follow the instructions on either of the above links. **We strongly +recomment using HuggingFace as the download speed for people outside of China is +much quicker**. \ No newline at end of file diff --git a/recipes/GigaSpeech/dataset.py b/recipes/GigaSpeech/dataset.py new file mode 100644 index 0000000000..9841bc22c9 --- /dev/null +++ b/recipes/GigaSpeech/dataset.py @@ -0,0 +1,446 @@ +# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MODIFIED BY: Adel Moumen 2024 +""" +GigaSpeech is an evolving, multi-domain English speech recognition corpus with 10,000 hours of high quality +labeled audio suitable for supervised training, and 40,000 hours of total audio suitable for semi-supervised +and unsupervised training. 
Around 40,000 hours of transcribed audio is first collected from audiobooks, podcasts +and YouTube, covering both read and spontaneous speaking styles, and a variety of topics, such as arts, science, +sports, etc. A new forced alignment and segmentation pipeline is proposed to create sentence segments suitable +for speech recognition training, and to filter out segments with low-quality transcription. For system training, +GigaSpeech provides five subsets of different sizes, 10h, 250h, 1000h, 2500h, and 10000h. +For our 10,000-hour XL training subset, we cap the word error rate at 4% during the filtering/validation stage, +and for all our other smaller training subsets, we cap it at 0%. The DEV and TEST evaluation sets, on the other hand, +are re-processed by professional human transcribers to ensure high transcription quality. +""" + +import csv +import os + +import datasets + +from speechbrain.utils.logger import get_logger + +logger = get_logger(__name__) + +_CITATION = """\ +@article{DBLP:journals/corr/abs-2106-06909, + author = {Guoguo Chen and + Shuzhou Chai and + Guanbo Wang and + Jiayu Du and + Wei{-}Qiang Zhang and + Chao Weng and + Dan Su and + Daniel Povey and + Jan Trmal and + Junbo Zhang and + Mingjie Jin and + Sanjeev Khudanpur and + Shinji Watanabe and + Shuaijiang Zhao and + Wei Zou and + Xiangang Li and + Xuchen Yao and + Yongqing Wang and + Yujun Wang and + Zhao You and + Zhiyong Yan}, + title = {GigaSpeech: An Evolving, Multi-domain {ASR} Corpus with 10, 000 Hours + of Transcribed Audio}, + journal = {CoRR}, + volume = {abs/2106.06909}, + year = {2021}, + url = {https://arxiv.org/abs/2106.06909}, + eprinttype = {arXiv}, + eprint = {2106.06909}, + timestamp = {Wed, 29 Dec 2021 14:29:26 +0100}, + biburl = {https://dblp.org/rec/journals/corr/abs-2106-06909.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} +""" + +_DESCRIPTION = """\ +GigaSpeech is an evolving, multi-domain English speech recognition corpus with 10,000 hours of high quality +labeled audio suitable for supervised training, and 40,000 hours of total audio suitable for semi-supervised +and unsupervised training. Around 40,000 hours of transcribed audio is first collected from audiobooks, podcasts +and YouTube, covering both read and spontaneous speaking styles, and a variety of topics, such as arts, science, +sports, etc. A new forced alignment and segmentation pipeline is proposed to create sentence segments suitable +for speech recognition training, and to filter out segments with low-quality transcription. For system training, +GigaSpeech provides five subsets of different sizes, 10h, 250h, 1000h, 2500h, and 10000h. +For our 10,000-hour XL training subset, we cap the word error rate at 4% during the filtering/validation stage, +and for all our other smaller training subsets, we cap it at 0%. The DEV and TEST evaluation sets, on the other hand, +are re-processed by professional human transcribers to ensure high transcription quality. 
+""" + +_HOMEPAGE = "https://github.com/SpeechColab/GigaSpeech" + +_LICENSE = "Apache License 2.0" + +_CATEGORIES = ( + "People and Blogs", + "Business", + "Nonprofits and Activism", + "Crime", + "History", + "Pets and Animals", + "News and Politics", + "Travel and Events", + "Kids and Family", + "Leisure", + "N/A", + "Comedy", + "News and Politics", + "Sports", + "Arts", + "Science and Technology", + "Autos and Vehicles", + "Science and Technology", + "People and Blogs", + "Music", + "Society and Culture", + "Education", + "Howto and Style", + "Film and Animation", + "Gaming", + "Entertainment", + "Travel and Events", + "Health and Fitness", + "audiobook", +) + +_SOURCES = ("audiobook", "podcast", "youtube") + +_SUBSETS = ("xs", "s", "m", "l", "xl") + +_BASE_DATA_URL = ( + "https://huggingface.co/datasets/speechcolab/gigaspeech/resolve/main/data/" +) + +_AUDIO_ARCHIVE_URL = ( + _BASE_DATA_URL + + "audio/{subset}_files{is_additional}/{subset}_chunks_{archive_id:04}.tar.gz" +) + +_META_URL = ( + _BASE_DATA_URL + + "metadata/{subset}_metadata{is_additional}/{subset}_chunks_{archive_id:04}_metadata.csv" +) + +_N_ARCHIVES_URL = _BASE_DATA_URL + "{subset}_n_archives{is_additional}.txt" + +logger = datasets.utils.logging.get_logger(__name__) + + +class GigaspeechConfig(datasets.BuilderConfig): + """BuilderConfig for Gigaspeech.""" + + def __init__(self, name, *args, **kwargs): + super().__init__(name=name, *args, **kwargs) + # larger subsets are supersets of smaller subsets, + # if we want to download "m", we need to download "xs" and "s" data too. + # so if name == "m", self.subsets_to_download will be ("xs", "s", "m") + if name not in {"dev", "test"}: + self.subsets_to_download = _SUBSETS[: _SUBSETS.index(name) + 1] + else: + self.subsets_to_download = (name,) + + +class Gigaspeech(datasets.GeneratorBasedBuilder): + """ + GigaSpeech is an evolving, multi-domain English speech recognition corpus with 10,000 hours of high quality + labeled audio suitable for supervised training, and 40,000 hours of total audio suitable for semi-supervised + and unsupervised training (this implementation contains only labelled data for now). + Around 40,000 hours of transcribed audio is first collected from audiobooks, podcasts + and YouTube, covering both read and spontaneous speaking styles, and a variety of topics, such as arts, science, + sports, etc. A new forced alignment and segmentation pipeline is proposed to create sentence segments suitable + for speech recognition training, and to filter out segments with low-quality transcription. For system training, + GigaSpeech provides five subsets of different sizes, 10h, 250h, 1000h, 2500h, and 10000h. + For our 10,000-hour XL training subset, we cap the word error rate at 4% during the filtering/validation stage, + and for all our other smaller training subsets, we cap it at 0%. The DEV and TEST evaluation sets, on the other hand, + are re-processed by professional human transcribers to ensure high transcription quality. 
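+
+    Note: this loading script is adapted from the official speechcolab/gigaspeech
+    HuggingFace dataset script and is loaded locally (as ``dataset.py``) by
+    ``gigaspeech_prepare.py`` when ``download_with_HF`` is enabled.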
+ """ + + VERSION = datasets.Version("1.0.0") + + BUILDER_CONFIGS = [ + GigaspeechConfig(name=subset) for subset in _SUBSETS + ("dev", "test") + ] + + DEFAULT_WRITER_BATCH_SIZE = 128 + + def _info(self): + features = datasets.Features( + { + "segment_id": datasets.Value("string"), + "speaker": datasets.Value("string"), + "text": datasets.Value("string"), + "audio": datasets.Audio(sampling_rate=16_000, decode=False), + "begin_time": datasets.Value("float32"), + "end_time": datasets.Value("float32"), + "audio_id": datasets.Value("string"), + "title": datasets.Value("string"), + "url": datasets.Value("string"), + "source": datasets.ClassLabel(names=_SOURCES), + "category": datasets.ClassLabel(names=_CATEGORIES), + "original_full_path": datasets.Value( + "string" + ), # relative path to full audio in original data dirs + } + ) + return datasets.DatasetInfo( + description=_DESCRIPTION, + features=features, + homepage=_HOMEPAGE, + license=_LICENSE, + citation=_CITATION, + ) + + def _is_additional_data(self, name): + if name in {"s", "m", "l", "xl"}: + return "_additional" + return "" + + @property + def _splits_to_subsets(self): + return { + "train": self.config.subsets_to_download, + "dev": ["dev"], + "test": ["test"], + } + + def _read_n_archives(self, n_archives_path): + with open(n_archives_path, encoding="utf-8") as f: + return int(f.read().strip()) + + def _split_generators(self, dl_manager): + splits_to_subsets = self._splits_to_subsets + if self.config.name in {"dev", "test"}: + splits = (self.config.name,) + else: + splits = ("train", "dev", "test") + + # 1. get number of archives (shards) in each subset + n_archives_links = { + split: { + subset: _N_ARCHIVES_URL.format( + subset=subset, + is_additional=self._is_additional_data(subset), + ) + for subset in splits_to_subsets[split] + } + for split in splits + } + logger.info("Downloading the data. It may take a while.") + paths = dl_manager.download(n_archives_links) + logger.info("Extracting the data. It may take a while.") + n_archives_paths = dl_manager.extract(paths) + n_archives = { + # mapping from a subset to a single number - number of audio archives (shards) in a subset + split: { + subset: self._read_n_archives(n_archives_paths[split][subset]) + for subset in splits_to_subsets[split] + } + for split in splits + } + + # 2. prepare sharded archives with audio files + audio_archives_urls = { + split: { + subset: [ + _AUDIO_ARCHIVE_URL.format( + subset=subset, + is_additional=self._is_additional_data(subset), + archive_id=i, + ) + for i in range(n_archives[split][subset]) + ] + for subset in splits_to_subsets[split] + } + for split in splits + } + audio_archives_paths = dl_manager.download(audio_archives_urls) + # flatten archives paths from + # {"train": {"xs": [path1, path2,], "s": [path3], "m": [path5, path5]}, "dev": {"dev": [path6,...]}, "test": {"test": [...]}} + # to {"train": [path1, path2, path3, path4, path5], "dev": [path6, ...], "test": [...]} + audio_archives_paths = _flatten_nested_dict(audio_archives_paths) + local_audio_archives_paths = ( + dl_manager.extract(audio_archives_paths) + if not dl_manager.is_streaming + else None + ) + + # 3. 
prepare sharded metadata csv files + meta_urls = { + split: { + subset: [ + _META_URL.format( + subset=subset, + is_additional=self._is_additional_data(subset), + archive_id=i, + ) + for i in range(n_archives[split][subset]) + ] + for subset in splits_to_subsets[split] + } + for split in splits + } + meta_paths = dl_manager.download_and_extract(meta_urls) + meta_paths = _flatten_nested_dict(meta_paths) + + if self.config.name not in {"dev", "test"}: + return [ + datasets.SplitGenerator( + name=datasets.Split.TRAIN, + gen_kwargs={ + "audio_archives_iterators": [ + dl_manager.iter_archive(archive_path) + for archive_path in audio_archives_paths["train"] + ], + "local_audio_archives_paths": ( + local_audio_archives_paths["train"] + if local_audio_archives_paths + else None + ), + "meta_paths": meta_paths["train"], + }, + ), + datasets.SplitGenerator( + name=datasets.Split.VALIDATION, + gen_kwargs={ + "audio_archives_iterators": [ + dl_manager.iter_archive(archive_path) + for archive_path in audio_archives_paths["dev"] + ], + "local_audio_archives_paths": ( + local_audio_archives_paths["dev"] + if local_audio_archives_paths + else None + ), + "meta_paths": meta_paths["dev"], + }, + ), + datasets.SplitGenerator( + name=datasets.Split.TEST, + gen_kwargs={ + "audio_archives_iterators": [ + dl_manager.iter_archive(archive_path) + for archive_path in audio_archives_paths["test"] + ], + "local_audio_archives_paths": ( + local_audio_archives_paths["test"] + if local_audio_archives_paths + else None + ), + "meta_paths": meta_paths["test"], + }, + ), + ] + + if self.config.name == "dev": + return [ + datasets.SplitGenerator( + name=datasets.Split.VALIDATION, + gen_kwargs={ + "audio_archives_iterators": [ + dl_manager.iter_archive(archive_path) + for archive_path in audio_archives_paths["dev"] + ], + "local_audio_archives_paths": ( + local_audio_archives_paths["dev"] + if local_audio_archives_paths + else None + ), + "meta_paths": meta_paths["dev"], + }, + ), + ] + + if self.config.name == "test": + return [ + datasets.SplitGenerator( + name=datasets.Split.TEST, + gen_kwargs={ + "audio_archives_iterators": [ + dl_manager.iter_archive(archive_path) + for archive_path in audio_archives_paths["test"] + ], + "local_audio_archives_paths": ( + local_audio_archives_paths["test"] + if local_audio_archives_paths + else None + ), + "meta_paths": meta_paths["test"], + }, + ), + ] + + def _generate_examples( + self, audio_archives_iterators, local_audio_archives_paths, meta_paths + ): + assert len(audio_archives_iterators) == len(meta_paths) + if local_audio_archives_paths: + assert len(audio_archives_iterators) == len( + local_audio_archives_paths + ) + + for i, (meta_path, audio_archive_iterator) in enumerate( + zip(meta_paths, audio_archives_iterators) + ): + meta_dict = dict() + with open(meta_path) as csvfile: + meta_csv = csv.DictReader(csvfile) + for line in meta_csv: + meta_dict[line["sid"]] = line + + for audio_path_in_archive, audio_file in audio_archive_iterator: + # `audio_path_in_archive` is like "dev_chunks_0000/YOU1000000029_S0000095.wav" + audio_filename = os.path.split(audio_path_in_archive)[1] + audio_id = audio_filename.split(".wav")[0] + audio_meta = meta_dict[audio_id] + audio_meta["segment_id"] = audio_meta.pop("sid") + audio_meta["original_full_path"] = audio_meta.pop("path") + audio_meta["text"] = audio_meta.pop("text_tn") + audio_meta["audio_id"] = audio_meta.pop("aid") + if not audio_meta["category"]: + audio_meta["category"] = "N/A" + + path = ( + os.path.join( + 
local_audio_archives_paths[i], audio_path_in_archive
+                    )
+                    if local_audio_archives_paths
+                    else audio_path_in_archive
+                )
+
+                yield audio_id, {
+                    "audio": {"path": path, "bytes": audio_file.read()},
+                    **{
+                        feature: value
+                        for feature, value in audio_meta.items()
+                        if feature in self.info.features
+                    },
+                }
+
+
+def _flatten_nested_dict(nested_dict):
+    return {
+        key: [
+            inner_list_element
+            for inner_list in value_to_lists.values()
+            for inner_list_element in inner_list
+        ]
+        for key, value_to_lists in nested_dict.items()
+    }
diff --git a/recipes/GigaSpeech/gigaspeech_prepare.py b/recipes/GigaSpeech/gigaspeech_prepare.py
new file mode 100644
index 0000000000..cdfb502cc1
--- /dev/null
+++ b/recipes/GigaSpeech/gigaspeech_prepare.py
@@ -0,0 +1,699 @@
+"""
+Data preparation script for the GigaSpeech dataset.
+
+Download instructions:
+    1. https://github.com/SpeechColab/GigaSpeech
+    2. https://huggingface.co/datasets/speechcolab/gigaspeech
+Reference: https://arxiv.org/abs/2106.06909
+
+Author
+-------
+ * Adel Moumen, 2024
+"""
+
+import csv
+import functools
+import json
+import logging
+import os
+from dataclasses import dataclass
+
+import torchaudio
+
+from speechbrain.utils.parallel import parallel_map
+
+logger = logging.getLogger(__name__)
+FILLERS = [
+    "UH",
+    "UHH",
+    "UM",
+    "EH",
+    "MM",
+    "HM",
+    "AH",
+    "HUH",
+    "HA",
+    "ER",
+    "OOF",
+    "HEE",
+    "ACH",
+    "EEE",
+    "EW",
+]
+GARBAGE_UTTERANCE_TAGS = ["<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"]
+PUNCTUATION_TAGS = {
+    "<COMMA>": ",",
+    "<EXCLAMATIONPOINT>": "!",
+    "<PERIOD>": ".",
+    "<QUESTIONMARK>": "?",
+}
+SPLITS = ["DEV", "TEST"]
+TRAIN_SUBSET = ["XS", "S", "M", "L", "XL"]
+SAMPLING_RATE = 16000
+
+
+@dataclass
+class GigaSpeechRow:
+    """Dataclass for handling GigaSpeech rows.
+
+    Attributes
+    ----------
+    utt_id : str
+        The segment ID.
+    audio_id : str
+        The audio ID.
+    audio_path : str
+        The path to the audio file.
+    speaker : str
+        The speaker ID.
+    begin_time : float
+        The start time of the segment.
+    end_time : float
+        The end time of the segment.
+    duration : float
+        The duration of the segment.
+    text : str
+        The text of the segment.
+    """
+
+    utt_id: str  # segment[sid]
+    audio_id: str  # audio[aid]
+    audio_path: str  # by default this is opus files
+    speaker: str  # audio["speaker"]
+    begin_time: float
+    end_time: float
+    duration: float
+    text: str
+
+
+def prepare_gigaspeech(
+    data_folder: str,
+    save_folder: str,
+    splits: list,
+    output_train: str,
+    output_dev: str,
+    output_test: str,
+    json_file: str = "GigaSpeech.json",
+    skip_prep: bool = False,
+    convert_opus_to_wav: bool = True,
+    download_with_HF: bool = False,
+    punctuation: bool = False,
+    filler: bool = False,
+    hf_multiprocess_load: bool = True,
+) -> None:
+    """Prepare the csv files for the GigaSpeech dataset.
+
+    Download instructions: https://github.com/SpeechColab/GigaSpeech
+    Reference: https://arxiv.org/abs/2106.06909
+
+    The `train.csv` file is created by following the train subset specified in the `splits` list.
+    It must be part of the `TRAIN_SUBSET` list. You cannot use multiple train subsets.
+
+    The `dev.csv` and `test.csv` files are created based on the `DEV` and `TEST` splits
+    specified in the `splits` list.
+
+    Parameters
+    ----------
+    data_folder : str
+        The path to the GigaSpeech dataset.
+    save_folder : str
+        The path to the folder where the CSV files will be saved.
+    splits : list
+        The list of splits to be used for creating the CSV files.
+    output_train : str
+        The path in which the train CSV or shards will be saved.
+    output_dev : str
+        The path in which the dev CSV or shards will be saved.
+ output_test : str + The path in which the test CSV or shards will be saved. + json_file : str, optional + The name of the JSON file containing the metadata of the GigaSpeech dataset. + skip_prep : bool, optional + If True, the data preparation will be skipped, and the function will return immediately. + convert_opus_to_wav : bool, optional + If True, the opus files will be converted to wav files. + download_with_HF : bool, optional + If True, the dataset will be downloaded using the Hugging Face datasets library. + We highly recommend using this option if you are based in the EU or US as it will + be faster and more reliable than the official host. Make sure to read the + instructions on how to get the dataset from Hugging Face here: + https://huggingface.co/datasets/speechcolab/gigaspeech + The dataset will be downloaded in the default folder specified in the + environment variable HF_HUB_CACHE. Please change it if necessary. + punctuation : bool, optional + Keeping the punctuation, or not. + filler : bool, optional + Keeping filler words (hum), or not. + hf_multiprocess_load: bool, optional + If True, all the CPU threads will be used for data prepration. If set to + False, only one will be. Note that the data prepration of the larger sets + on a single core car take more than 24 hours (from downloading to done). + + Returns + ------- + None + """ + logger.info(f"Preparing GigaSpeech dataset in {save_folder}...") + + if skip_prep: + logger.info("Skipping data preparation as `skip_prep` is set to `True`") + return + + # check that `splits` input is valid + for split in splits: + assert ( + split in SPLITS + TRAIN_SUBSET + ), f"Split {split} not recognized. Valid splits are {SPLITS + TRAIN_SUBSET}." + + # check that we are not using multiple train subsets + if len(set(splits).intersection(TRAIN_SUBSET)) > 1: + raise ValueError( + "You cannot use multiple train subsets. Please select only one train subset." + ) + + os.makedirs(save_folder, exist_ok=True) + + # Setting output paths + save_output = {} + split_map = {} + train_split = "" + for split in splits: + if split in TRAIN_SUBSET: + save_output["train"] = output_train + split_map["train"] = split + train_split = split + else: + if split == "DEV": + save_output["validation"] = output_dev + split_map["validation"] = split + elif split == "TEST": + save_output["test"] = output_test + split_map["test"] = split + + # check if the data is already prepared + if skip_csv(save_output): + logger.info("Skipping preparation, completed in previous run.") + return + else: + logger.info("Starting data preparation...") + + if download_with_HF: + from datasets import load_dataset + + if os.path.exists("dataset.py"): + logger.info("HuggingFace dataset.py found.") + else: + raise FileNotFoundError( + "HuggingFace dataset.py not found. Please run this recipe from the correct recipe folder or copy the dataset.py file." + ) + + if "HF_HUB_CACHE" in os.environ: + hf_caching_dir = os.environ["HF_HUB_CACHE"] + elif "HF_HOME" in os.environ: + hf_caching_dir = os.environ["HF_HOME"] + else: + hf_caching_dir = os.environ["XDG_CACHE_HOME"] + + logger.info( + "Downloading dataset from HuggingFace to: " + str(hf_caching_dir) + ) + logger.info( + "To change this directory modify the HF_HUB_CACHE env. variable." 
+ ) + + nproc = 1 + if hf_multiprocess_load: + import multiprocessing + + nproc = multiprocessing.cpu_count() + + hf_dataset = load_dataset( + "dataset.py", + train_split.lower(), + trust_remote_code=True, + data_dir=data_folder, + cache_dir=data_folder, + num_proc=nproc, + ) + for split, output in save_output.items(): + logger.info(f"Starting creating {output} using {split} split.") + HF_create_csv(output, hf_dataset[split], split, punctuation, filler) + else: + # check that the data folder contains the GigaSpeech dataset + check_gigaspeech_folders(data_folder, json_file) + + logger.info(f"Starting reading {json_file}.") + with open(json_file, "r") as f: + info = json.load(f) + logger.info(f"Reading {json_file} done.") + + for split, output in save_output.items(): + logger.info(f"Starting creating {output} using {split} split.") + create_csv( + output, + info, + data_folder, + split_map[split], + convert_opus_to_wav, + punctuation, + filler, + ) + logger.info("Data preparation completed!") + + +def process_line( + audio: json, + data_folder: str, + split: str, + convert_opus_to_wav: bool, + punctuation: bool, + stopwords: list, +) -> list: + """ + Process the audio line and return the utterances for the given split. + + Parameters + ---------- + audio : dict + The audio line to be processed. + data_folder : str + The path to the GigaSpeech dataset. + split : str + The split to be used for filtering the data. + convert_opus_to_wav : bool + If True, the opus files will be converted to wav files. + punctuation : bool + Keeping punctuation or not. Default is no. + stopwords: list + List of stopwords to remove from the text of the labels. + + Returns + ------- + list + The list of utterances for the given split. + """ + if ("{" + split + "}") in audio["subsets"]: + + audio_path = os.path.join(data_folder, audio["path"]) + assert os.path.isfile(audio_path), f"File not found: {audio_path}" + + if convert_opus_to_wav and audio_path.endswith(".opus"): + audio_path = convert_opus2wav(audio_path) + + # 2. iterate over the utterances + utterances = [] + for segment in audio["segments"]: + text = preprocess_text(segment["text_tn"], punctuation, stopwords) + if text: + begin_time = float(segment["begin_time"]) + end_time = float(segment["end_time"]) + duration = end_time - begin_time + utterance = GigaSpeechRow( + utt_id=segment["sid"], + audio_id=audio["aid"], + audio_path=str(audio_path), + speaker=audio["speaker"], + begin_time=begin_time, + end_time=end_time, + duration=duration, + text=text, + ) + utterances.append(utterance) + return utterances + + +def create_csv( + csv_file: str, + info: json, + data_folder: str, + split: str, + convert_opus_to_wav: bool, + punctuation: bool = False, + filler: bool = False, +) -> None: + """ + Create a CSV file based on the info in the GigaSpeech JSON file and filter the data based on the split. + + Parameters + ---------- + csv_file : str + The path to the CSV file to be created. + info : dict + The GigaSpeech JSON file content. + data_folder : str + The path to the GigaSpeech dataset. + split : str + The split to be used for filtering the data. + convert_opus_to_wav : bool + If True, the opus files will be converted to wav files. + punctuation : bool + Keeping punctuation or not. Default is no. + filler : bool + Keeping filler words or not (hum, er). Default is no. 
+ + Returns + ------- + None + """ + total_duration = 0.0 + nb_samples = 0 + + to_remove = GARBAGE_UTTERANCE_TAGS + if not filler: + to_remove += FILLERS + + line_processor = functools.partial( + process_line, + data_folder=data_folder, + split=split, + convert_opus_to_wav=convert_opus_to_wav, + stopwords=to_remove, + punctuation=punctuation, + ) + + csv_file_tmp = csv_file + ".tmp" + with open(csv_file_tmp, mode="w", encoding="utf-8") as csv_f: + csv_writer = csv.writer( + csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL + ) + header = [ + "ID", + "audio_id", + "audio_path", + "speaker", + "begin_time", + "end_time", + "duration", + "text", + ] + csv_writer.writerow(header) + for row in parallel_map(line_processor, info["audios"]): + if row is None: + continue + + for item in row: + csv_writer.writerow( + [ + item.utt_id, + item.audio_id, + item.audio_path, + item.speaker, + str(item.begin_time), + str(item.end_time), + str(item.duration), + item.text, + ] + ) + + total_duration += item.duration + nb_samples += 1 + + os.replace(csv_file_tmp, csv_file) + + logger.info(f"{csv_file} successfully created!") + logger.info(f"Number of samples in {split} split: {nb_samples}") + logger.info( + f"Total duration of {split} split: {round(total_duration / 3600, 2)} Hours" + ) + + +def HF_create_csv( + csv_file: str, + hf_dataset, + split: str, + punctuation: bool = False, + filler: bool = False, +) -> None: + """ + Create a CSV file based on a HuggingFace dataset. + + Parameters + ---------- + csv_file : str + The path to the CSV file to be created. + hf_dataset : huggingface dataset, + The huggingface dataset. + split : str + The split to be used for filtering the data. + punctuation : bool + Keeping punctuation or not. Default is no. + filler : bool + Keeping filler words or not (hum, er). Default is no. + + + Returns + ------- + None + """ + total_duration = 0.0 + nb_samples = 0 + + to_remove = GARBAGE_UTTERANCE_TAGS + if not filler: + to_remove += FILLERS + + line_processor = functools.partial( + HF_process_line, + stopwords=to_remove, + punctuation=punctuation, + ) + + csv_file_tmp = csv_file + ".tmp" + with open(csv_file_tmp, mode="w", encoding="utf-8") as csv_f: + csv_writer = csv.writer( + csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL + ) + header = [ + "ID", + "audio_id", + "audio_path", + "speaker", + "begin_time", + "end_time", + "duration", + "text", + ] + csv_writer.writerow(header) + + for row in parallel_map(line_processor, hf_dataset, chunk_size=1024): + if row is None: + continue + + csv_writer.writerow( + [ + row.utt_id, + row.audio_id, + row.audio_path, + row.speaker, + str(row.begin_time), + str(row.end_time), + str(row.duration), + row.text, + ] + ) + + total_duration += row.duration + nb_samples += 1 + + os.replace(csv_file_tmp, csv_file) + + logger.info(f"{csv_file} successfully created!") + logger.info(f"Number of samples in {split} split: {nb_samples}") + logger.info( + f"Total duration of {split} split: {round(total_duration / 3600, 2)} Hours" + ) + + +def HF_process_line(row: dict, punctuation: bool, stopwords: list) -> list: + """ + Process the audio line and return the utterances for the given split. + + Parameters + ---------- + row: dict + The audio line to be processed. + punctuation : bool + Keeping punctuation or not. Default is no. + stopwords: list + List of stopwords to remove from the text of the labels. + + Returns + ------- + list + The list of utterances for the given split. 
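+        Note that a single GigaSpeechRow is built per valid segment; None is
+        returned when the audio file is missing or unreadable, or when the
+        text is empty after preprocessing.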
+    """
+    audio_path = os.path.join(row["audio"]["path"])
+
+    if not os.path.isfile(audio_path):
+        return None
+
+    # Check that the audio file can be read; HF may have some corrupted files
+    try:
+        _ = torchaudio.info(audio_path)
+    except Exception as e:
+        logger.error(f"Failed reading {audio_path}: {e}")
+        return None
+
+    text = preprocess_text(row["text"], punctuation, stopwords)
+
+    if text:
+        utt_id = row["segment_id"]
+        audio_id = row["audio_id"]
+        audio_path = row["audio"]["path"]
+        speaker = row["speaker"]
+        begin_time = float(row["begin_time"])
+        end_time = float(row["end_time"])
+        duration = end_time - begin_time
+
+        row = GigaSpeechRow(
+            utt_id=utt_id,
+            audio_id=audio_id,
+            audio_path=audio_path,
+            speaker=speaker,
+            begin_time=begin_time,
+            end_time=end_time,
+            duration=duration,
+            text=text,
+        )
+
+        return row
+    else:
+        return None
+
+
+def convert_opus2wav(audio_opus_path):
+    """Convert an opus file to a wav file.
+
+    Parameters
+    ----------
+    audio_opus_path : str
+        The path to the opus file to be converted.
+
+    Returns
+    -------
+    str
+        The path to the converted wav file.
+
+    Notes
+    -----
+    The conversion is performed with ffmpeg through os.system, so failures
+    do not raise an exception; callers should check the returned path.
+    """
+    audio_wav_path = audio_opus_path.replace(".opus", ".wav")
+    os.system(
+        f"ffmpeg -y -i {audio_opus_path} -ac 1 -ar {SAMPLING_RATE} {audio_wav_path} > /dev/null 2>&1"
+    )
+    return audio_wav_path
+
+
+def preprocess_text(text: str, punctuation: bool, stopwords) -> str:
+    """
+    Preprocesses the input text by removing garbage tags and removing punctuation
+    and filler words if specified.
+
+    Parameters
+    ----------
+    text : str
+        The input text to be preprocessed.
+    punctuation : bool
+        Keeping punctuation or not. Default is no.
+    stopwords : list
+        List of words to remove from the input text.
+
+    Returns
+    -------
+    str
+        The preprocessed text with removed garbage tags and replaced punctuation tags.
+
+    Raises
+    ------
+    AssertionError
+        If '<' or '>' tags are found in the text after preprocessing.
+
+    Notes
+    -----
+    The function iterates over predefined garbage utterance tags (GARBAGE_UTTERANCE_TAGS)
+    and removes them from the input text. It then iterates over predefined punctuation tags
+    (PUNCTUATION_TAGS) and replaces them with the corresponding punctuation.
+
+    Examples
+    --------
+    >>> text = " DOUGLAS MCGRAY IS GOING TO BE OUR GUIDE YOU WALK THROUGH THE DOOR <COMMA> YOU SEE THE RED CARPETING <COMMA> YOU SEE SOMEONE IN A SUIT <PERIOD> THEY MAY BE GREETING YOU <PERIOD> "
+    >>> preprocess_text(text, punctuation=True, stopwords=GARBAGE_UTTERANCE_TAGS)
+    "DOUGLAS MCGRAY IS GOING TO BE OUR GUIDE YOU WALK THROUGH THE DOOR, YOU SEE THE RED CARPETING, YOU SEE SOMEONE IN A SUIT. THEY MAY BE GREETING YOU."
+    """
+
+    text = text.upper()
+    text = text.replace("-", " ")
+
+    sentence = " ".join(
+        [word for word in text.split() if word not in stopwords]
+    )
+
+    if punctuation:
+        for tag, punct in PUNCTUATION_TAGS.items():
+            sentence = sentence.replace(" " + tag, punct)
+
+    return sentence
+
+
+def skip_csv(save_csv_files: dict) -> bool:
+    """Check if the CSV files already exist.
+
+    Parameters
+    ----------
+    save_csv_files : dict
+        The dictionary containing the paths to the CSV files.
+
+    Returns
+    -------
+    bool
+        True if all the CSV files already exist, False otherwise.
+ """ + return all(os.path.isfile(path) for path in save_csv_files.values()) + + +def check_gigaspeech_folders( + data_folder: str, + json_file: str = "GigaSpeech.json", + audio_folder: str = "audio", +) -> None: + """Check if the data folder actually contains the GigaSpeech dataset. + + If it does not, an error is raised. + + Parameters + ---------- + data_folder : str + The path to the GigaSpeech dataset. + json_file : str, optional + The name of the JSON file containing the metadata of the GigaSpeech dataset. + audio_folder : str, optional + The name of the folder containing the audio files of the GigaSpeech dataset. + + Returns + ------- + None + + Raises + ------ + OSError + If GigaSpeech is not found at the specified path. + """ + # Checking if "GigaSpeech.json" exist + if not os.path.exists(json_file): + err_msg = ( + "the opus file %s does not exist (it is expected in the " + "Gigaspeech dataset)" % json_file + ) + raise OSError(err_msg) + + # Check if audio folders exist + for folder_subset in ["audiobook", "podcast", "youtube"]: + audio_subset = os.path.join(data_folder, audio_folder, folder_subset) + if not os.path.exists(audio_subset): + err_msg = ( + "the file %s does not exist (it is expected in the " + "Gigaspeech dataset)" % audio_subset + ) + raise OSError(err_msg) diff --git a/recipes/LibriSpeech/ASR/CTC/train_with_wav2vec_k2.py b/recipes/LibriSpeech/ASR/CTC/train_with_wav2vec_k2.py index 3501a268fa..2ddc4b72be 100644 --- a/recipes/LibriSpeech/ASR/CTC/train_with_wav2vec_k2.py +++ b/recipes/LibriSpeech/ASR/CTC/train_with_wav2vec_k2.py @@ -55,7 +55,7 @@ def compute_forward(self, batch, stage): # Forward pass - # Handling SpeechBrain vs HuggingFance pretrained models + # Handling SpeechBrain vs HuggingFace pretrained models if hasattr(self.modules, "extractor"): # SpeechBrain pretrained model latents = self.modules.extractor(wavs) feats = self.modules.encoder_wrapper(latents, wav_lens=wav_lens)[ diff --git a/speechbrain/inference/ASR.py b/speechbrain/inference/ASR.py index 9655a6f059..c35f7b05b4 100644 --- a/speechbrain/inference/ASR.py +++ b/speechbrain/inference/ASR.py @@ -259,7 +259,9 @@ def set_decoding_function(self): opt_beam_search_params["kenlm_model_path"] ) kenlm_model_path = str( - fetch(fl, source=source, savedir=".") + fetch( + fl, source=source, savedir=self.hparams.savedir + ) ) # we need to update the kenlm_model_path in the opt_beam_search_params opt_beam_search_params["kenlm_model_path"] = ( diff --git a/speechbrain/inference/interfaces.py b/speechbrain/inference/interfaces.py index 6c3bfb863e..a5e495222b 100644 --- a/speechbrain/inference/interfaces.py +++ b/speechbrain/inference/interfaces.py @@ -134,6 +134,7 @@ def foreign_class( with open(hparams_local_path) as fin: hparams = load_hyperpyyaml(fin, overrides, overrides_must_match) + hparams["savedir"] = savedir # Pretraining: pretrainer = hparams["pretrainer"] pretrainer.set_collect_in(savedir) @@ -508,6 +509,9 @@ def from_hparams( fin, overrides, overrides_must_match=overrides_must_match ) + # add savedir to hparams + hparams["savedir"] = savedir + # Pretraining: pretrainer = hparams.get("pretrainer", None) if pretrainer is not None: diff --git a/tests/recipes/GigaSpeech.csv b/tests/recipes/GigaSpeech.csv new file mode 100644 index 0000000000..ddf43f3cc5 --- /dev/null +++ b/tests/recipes/GigaSpeech.csv @@ -0,0 +1,3 @@ +Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,performance 
+ASR-CTC,GigaSpeech,recipes/GigaSpeech/ASR/CTC/train_with_wavlm.py,recipes/GigaSpeech/ASR/CTC/hparams/train_hf_wavlm.yaml,recipes/GigaSpeech/ASR/CTC/gigaspeech_prepare.py,recipes/GigaSpeech/ASR/CTC/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint, +ASR-Transducers,GigaSpeech,recipes/GigaSpeech/ASR/transducer/train.py,recipes/GigaSpeech/ASR/transducer/hparams/conformer_transducer.yaml,recipes/GigaSpeech/ASR/transducer/gigaspeech_prepare.py,recipes/GigaSpeech/ASR/transducer/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True, \ No newline at end of file