Merge pull request #42 from alexandrainst/chore/config-structure
Chore/config structure
saattrupdan authored Oct 26, 2023
2 parents 12b6829 + ed6f96a commit 272d401
Showing 28 changed files with 1,716 additions and 1,778 deletions.
7 changes: 2 additions & 5 deletions .github/workflows/ci.yaml
@@ -17,7 +17,7 @@ jobs:
    strategy:
      matrix:
        os: [windows-latest, macos-latest, ubuntu-latest]
-        python-version: ["3.10", "3.11"]
+        python-version: ["3.11"]
    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v4
@@ -26,7 +26,7 @@ jobs:
        uses: FedericoCarboni/setup-ffmpeg@v2

      - name: Install Poetry
-        run: pipx install poetry==1.4.0
+        run: pip3 install poetry==1.5.1

      - name: Set up Python
        uses: actions/setup-python@v4
@@ -39,9 +39,6 @@ jobs:
          poetry env use "${{ matrix.python-version }}"
          poetry install --no-interaction --no-cache
-      - name: Fix PyTorch bug
-        run: poetry add torch==2.0.0
-
      - name: Test with pytest
        run: poetry run pytest
        env:
13 changes: 8 additions & 5 deletions .github/workflows/docs.yaml
@@ -17,8 +17,11 @@ jobs:
    steps:
      - uses: actions/checkout@v3

+      - name: Install ffmpeg
+        uses: FedericoCarboni/setup-ffmpeg@v2
+
      - name: Install Poetry
-        run: pipx install poetry==1.4.0
+        run: pip3 install poetry==1.5.1

      - name: Set up Python
        uses: actions/setup-python@v4
@@ -27,12 +30,12 @@
        cache: "poetry"

      - name: Install Dependencies
-        run: poetry install
+        run: |
+          poetry env use "${{ matrix.python-version }}"
+          poetry install --no-interaction --no-cache
      - name: Build documentation
-        run: |
-          poetry env use "3.11"
-          poetry run pdoc --docformat google src/coral_models -o docs
+        run: poetry run pdoc --docformat google src/coral_models -o docs

      - name: Compress documentation
        run: tar --directory docs/ -hcf artifact.tar .
6 changes: 6 additions & 0 deletions .gitignore
@@ -89,6 +89,12 @@ target/
# pytest cache
.pytest_cache/

+# Linting cache
+.ruff_cache/
+
+# Python cache
+**/__pycache__
+
# Hydra logs
outputs/
multirun/
55 changes: 39 additions & 16 deletions README.md
@@ -6,7 +6,7 @@ ______________________________________________________________________
[![Documentation](https://img.shields.io/badge/docs-passing-green)](https://alexandrainst.github.io/coral_models/coral_models.html)
[![License](https://img.shields.io/github/license/alexandrainst/coral_models)](https://github.com/alexandrainst/coral_models/blob/main/LICENSE)
[![LastCommit](https://img.shields.io/github/last-commit/alexandrainst/coral_models)](https://github.com/alexandrainst/coral_models/commits/main)
-[![Code Coverage](https://img.shields.io/badge/Coverage-60%25-yellow.svg)](https://github.com/alexandrainst/coral_models/tree/main/tests)
+[![Code Coverage](https://img.shields.io/badge/Coverage-53%25-orange.svg)](https://github.com/alexandrainst/coral_models/tree/main/tests)


Developers:
@@ -54,7 +54,6 @@ publishing the code as a package and more.
## Project structure
```
.
├── .flake8
├── .github
│   └── workflows
│   ├── ci.yaml
@@ -66,22 +65,35 @@ publishing the code as a package and more.
├── config
│   ├── __init__.py
│   ├── config.yaml
│   ├── dataset
│   │   ├── common_voice_da.yaml
│   ├── datasets
│   │   ├── alvenir_test_set.yaml
│   │   ├── common_voice_13_da.yaml
│   │   ├── common_voice_13_nn.yaml
│   │   ├── common_voice_13_sv.yaml
│   │   ├── common_voice_9_da.yaml
│   │   ├── fleurs_da.yaml
│   │   ├── fleurs_nb.yaml
│   │   ├── fleurs_sv.yaml
│   │   ├── ftspeech.yaml
│   │   └── test.yaml
│   │   ├── nota.yaml
│   │   ├── nst_da.yaml
│   │   └── test_dataset.yaml
│   ├── hydra
│   │   └── job_logging
│   │   └── custom.yaml
│   └── model
│   ├── test.yaml
│   ├── test_wav2vec2.yaml
│   ├── test_whisper.yaml
│   ├── wav2vec2.yaml
│   ├── wav2vec2_with_lm.yaml
│   └── whisper.yaml
├── data
│   ├── whisper_large.yaml
│   ├── whisper_medium.yaml
│   ├── whisper_small.yaml
│   ├── whisper_xsmall.yaml
│   └── whisper_xxsmall.yaml
├── docs
│   └── .gitkeep
├── makefile
├── models
├── notebooks
├── poetry.lock
├── poetry.toml
├── pyproject.toml
├── src
@@ -90,15 +102,25 @@ publishing the code as a package and more.
│   │   ├── compute_metrics.py
│   │   ├── data.py
│   │   ├── finetune.py
│   │   ├── model_setup.py
│   │   ├── plot.py
│   │   ├── prepare_raw_data.py
│   │   ├── protocols.py
│   │   ├── utils.py
│   │   └── wav2vec2.py
│   │   ├── wav2vec2.py
│   │   └── whisper.py
│   └── scripts
│   ├── build_coral_data.py
│   ├── build_ftspeech.py
│   ├── evaluate.py
│   ├── finetune.py
│   ├── build_nota.py
│   ├── build_nst_da.py
│   ├── download_ftspeech.py
│   ├── evaluate_model.py
│   ├── find_faulty_audio_clips.py
│   ├── finetune_model.py
│   ├── fix_dot_env_file.py
│   ├── push_ftspeech_to_hub.py
│   ├── plot_training_trajectory.py
│   ├── push_to_hub.py
│   ├── train_ngram_decoder.py
│   └── versioning.py
└── tests
@@ -109,5 +131,6 @@ publishing the code as a package and more.
├── test_finetune.py
├── test_protocols.py
├── test_utils.py
└── test_wav2vec2.py
├── test_wav2vec2.py
└── test_whisper.py
```
24 changes: 19 additions & 5 deletions config/config.yaml
@@ -1,8 +1,7 @@
defaults:
  - model: wav2vec2
  - datasets:
-      - nst_da
-      - common_voice_9_da
+      - common_voice_13_da
  - override hydra/job_logging: custom
  - _self_

@@ -16,7 +15,13 @@ dirs:
seed: 4242

# Dataset parameters
-dataset_probabilities: # null = equal probability to every dataset
+characters_to_keep: 'abcdefghijklmnopqrstuvwxyzæøå0123456789éü '
+max_seconds_per_example: 10
+dataloader_num_workers: 8
+
+# This is a list of the sampling probability of each dataset, where null means that
+# each dataset will be sampled equally often
+dataset_probabilities:
  train: null
  val: null
  test: null
@@ -26,6 +31,7 @@ pipeline_id: ${model.name}-finetuned
hub_id: alexandrainst/${pipeline_id}
model_dir: ${dirs.models}/${pipeline_id}
push_to_hub: false
+fp16: true

# Training parameters
wandb: false
@@ -34,10 +40,18 @@ wandb_group: default
wandb_name: null
resume_from_checkpoint: false
ignore_data_skip: false
-save_total_limit: 2

# Optimisation parameters
learning_rate: 3e-5
adam_first_momentum: 0.9
adam_second_momentum: 0.98
-fp16: true
+batch_size: 8
+gradient_accumulation: 32
+max_steps: 50_000
+warmup_steps: 1_000
+logging_steps: 10
+eval_steps: 100
+save_steps: 100
+save_total_limit: 2
+early_stopping: false
+early_stopping_patience: 50
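Taken together, `config.yaml` now owns the dataset mixing and optimisation settings that previously lived in the individual model configs. The sketch below shows one way such a config could be consumed, assuming the Hugging Face `datasets` library and treating `null` probabilities as equal sampling; it is purely illustrative, and the repository's actual loading code in `src/coral_models/data.py` may differ.

```python
from datasets import interleave_datasets, load_dataset
from omegaconf import DictConfig


def load_train_data(cfg: DictConfig):
    """Illustrative only: load and interleave every dataset named in the Hydra config."""
    streams = [
        load_dataset(ds.id, ds.subset, split=ds.train_name)
        for ds in cfg.datasets.values()
    ]
    probs = cfg.dataset_probabilities.train
    if probs is None:
        # null in config.yaml means every dataset is sampled equally often
        probs = [1 / len(streams)] * len(streams)
    return interleave_datasets(streams, probabilities=probs, seed=cfg.seed)
```

Individual runs could still override the composition from the command line in the usual Hydra way, for example `python src/scripts/finetune_model.py model=whisper_xsmall 'datasets=[ftspeech]'` (a hypothetical invocation).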
7 changes: 7 additions & 0 deletions config/datasets/alvenir_test_set.yaml
@@ -0,0 +1,7 @@
+alvenir_test_set:
+  id: Alvenir/alvenir_asr_da_eval
+  subset: null
+  train_name: null
+  val_name: null
+  test_name: test
+  text_column: sentence
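For reference, these fields map onto a plain `datasets.load_dataset` call roughly as follows. This is a hypothetical usage sketch, not the project's own loading code:

```python
from datasets import load_dataset

# test_name is the only non-null split for this dataset, and text_column names
# the column holding the reference transcription.
eval_split = load_dataset("Alvenir/alvenir_asr_da_eval", split="test")
print(eval_split[0]["sentence"])
```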
6 changes: 5 additions & 1 deletion config/hydra/job_logging/custom.yaml
@@ -9,7 +9,11 @@ handlers:
    class: logging.StreamHandler
    formatter: simple
    stream: ext://sys.stdout
+  file:
+    class: logging.FileHandler
+    formatter: simple
+    filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
root:
-  handlers: [console]
+  handlers: [console, file]

disable_existing_loggers: false
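With the extra `file` handler, log records emitted inside a Hydra job should end up both on stdout and in a log file inside the run's output directory. A minimal sketch of the effect, assuming a standard `@hydra.main` entry point (script and config names here are illustrative):

```python
import logging

import hydra
from omegaconf import DictConfig

logger = logging.getLogger(__name__)


@hydra.main(config_path="config", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    # With the custom job_logging config, this message goes to the console
    # and to ${hydra.runtime.output_dir}/${hydra.job.name}.log
    logger.info("Finetuning %s", cfg.pipeline_id)


if __name__ == "__main__":
    main()
```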
13 changes: 0 additions & 13 deletions config/model/test_wav2vec2.yaml
@@ -5,7 +5,6 @@ freeze_feature_encoder: true

# Data hyperparameters
clean_dataset: true
-characters_to_keep: 'abcdefghijklmnopqrstuvwxyzæøå0123456789éü '

# Model hyperparameters
sampling_rate: 16_000
@@ -23,15 +22,3 @@ ctc_loss_reduction: sum

# Decoder hyperparameters
language_model_decoder: null
-
-# Training hyperparameters
-batch_size: 1
-gradient_accumulation: 1
-max_steps: 3
-learning_rate: 4e-5
-warmup_steps: 1
-early_stopping: true
-early_stopping_patience: 5
-adam_first_momentum: 0.9
-adam_second_momentum: 0.999
-fp16: false
12 changes: 1 addition & 11 deletions config/model/test_whisper.yaml
@@ -15,14 +15,4 @@ mask_time_prob: 0.5
mask_time_length: 10
mask_feature_prob: 0.5
mask_feature_length: 64
-
-# Training hyperparameters
-batch_size: 1
-gradient_accumulation: 1
-max_steps: 3
-learning_rate: 4e-5
-warmup_steps: 1
-early_stopping: true
-early_stopping_patience: 5
-fp16: false
-generation_max_length: 1
+generation_max_length: 128
14 changes: 2 additions & 12 deletions config/model/wav2vec2.yaml
@@ -5,7 +5,6 @@ freeze_feature_encoder: false

# Data hyperparameters
clean_dataset: true
-characters_to_keep: 'abcdefghijklmnopqrstuvwxyzæøå0123456789éü '

# Model hyperparameters
sampling_rate: 16_000
@@ -15,9 +14,9 @@ hidden_dropout: 0.0
feat_proj_dropout: 0.0
feat_quantizer_dropout: 0.0
final_dropout: 0.0
-mask_time_prob: 0.5
+mask_time_prob: 0.3
mask_time_length: 10
-mask_feature_prob: 0.5
+mask_feature_prob: 0.3
mask_feature_length: 64
layerdrop: 0.1
ctc_loss_reduction: mean
@@ -29,12 +28,3 @@ decoder:
  dataset_subset: null
  dataset_split: train
  n: 5
-
-# Training hyperparameters
-batch_size: 8
-gradient_accumulation: 32
-max_steps: 13_000 # Based on the XLS-R paper, section 4.3
-warmup_steps: 1_300 # Based on the XLS-R paper, section 4.3
-learning_rate: 3e-5
-adam_first_momentum: 0.9
-adam_second_momentum: 0.98
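The wav2vec2 config now keeps only model hyperparameters (the training block moved into `config.yaml`), with the SpecAugment-style masking relaxed from 0.5 to 0.3. As a hedged illustration of where such settings typically land, assuming the standard Hugging Face `transformers` wav2vec 2.0 classes and a placeholder base checkpoint (the repository's own `wav2vec2.py` wiring may differ):

```python
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC

# Values mirror config/model/wav2vec2.yaml above; the checkpoint name is illustrative.
config = Wav2Vec2Config.from_pretrained(
    "facebook/wav2vec2-xls-r-300m",
    mask_time_prob=0.3,
    mask_time_length=10,
    mask_feature_prob=0.3,
    mask_feature_length=64,
    layerdrop=0.1,
    ctc_loss_reduction="mean",
)
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xls-r-300m", config=config)
```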
40 changes: 0 additions & 40 deletions config/model/wav2vec2_no_reg.yaml

This file was deleted.

19 changes: 6 additions & 13 deletions config/model/whisper_large.yaml
@@ -8,18 +8,11 @@ clean_dataset: false

# Model hyperparameters
sampling_rate: 16_000
-dropout: 0.1
-activation_dropout: 0.1
-attention_dropout: 0.1
-mask_time_prob: 0.5
+dropout: 0.0
+activation_dropout: 0.0
+attention_dropout: 0.0
+mask_time_prob: 0.3
mask_time_length: 10
-mask_feature_prob: 0.5
+mask_feature_prob: 0.3
mask_feature_length: 64
-
-# Training hyperparameters
-batch_size: 1
-gradient_accumulation: 32
-max_steps: 120_000
-learning_rate: 3e-5
-warmup_steps: 500
-generation_max_length: 225
+generation_max_length: 128
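The dropout and masking changes mirror the wav2vec2 config above, and `generation_max_length` now caps the decoder at 128 tokens during evaluation instead of 225. A rough sketch of how these values, together with the training parameters moved into `config.yaml`, could be passed to a Hugging Face seq2seq trainer; the argument names follow `transformers.Seq2SeqTrainingArguments`, `output_dir` is a placeholder, and the project's own trainer setup may differ:

```python
from transformers import Seq2SeqTrainingArguments

args = Seq2SeqTrainingArguments(
    output_dir="models/whisper-large-finetuned",  # placeholder path
    per_device_train_batch_size=8,   # batch_size in config.yaml
    gradient_accumulation_steps=32,  # gradient_accumulation in config.yaml
    learning_rate=3e-5,
    max_steps=50_000,
    warmup_steps=1_000,
    logging_steps=10,
    eval_steps=100,
    save_steps=100,
    save_total_limit=2,
    fp16=True,
    predict_with_generate=True,
    generation_max_length=128,
)
```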