Skip to content

Commit

Permalink
fix: Remove cast_to_sampling_rate from process_dataset call in `valid…
Browse files Browse the repository at this point in the history
…ation` module
  • Loading branch information
saattrupdan committed Oct 23, 2024
1 parent 70899f6 commit dff2c7a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
2 changes: 2 additions & 0 deletions makefile
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ type-check: ## Run type checking
--show-error-codes \
--check-untyped-defs

check: lint format type-check ## Check the code

roest-315m: ## Train the Røst-315M model
@accelerate launch \
--use-deepspeed \
Expand Down
7 changes: 5 additions & 2 deletions src/coral/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import TypeVar

import torch
from datasets import Dataset, DatasetDict
from datasets import Audio, Dataset, DatasetDict
from transformers import AutomaticSpeechRecognitionPipeline, pipeline

from .compute_metrics import compute_metrics_of_dataset_using_pipeline
Expand Down Expand Up @@ -60,6 +60,10 @@ def add_validations(
if input_is_single_split:
dataset = DatasetDict(dict(train=dataset))

dataset = dataset.cast_column(
column=audio_column, feature=Audio(sampling_rate=sampling_rate)
)

processed_dataset = process_dataset(
dataset=dataset,
clean_text=clean_text,
Expand All @@ -69,7 +73,6 @@ def add_validations(
convert_numerals=False,
remove_input_dataset_columns=True,
lower_case=lower_case,
cast_to_sampling_rate=sampling_rate,
)

logger.info(f"Loading the {model_id!r} ASR model...")
Expand Down

0 comments on commit dff2c7a

Please sign in to comment.