fix: Set Whisper TrainingArguments as wav2vec2
saattrupdan committed Oct 24, 2024
1 parent b85036c commit 2eff85b
Showing 1 changed file with 5 additions and 1 deletion.
src/coral/whisper.py (5 additions, 1 deletion)
@@ -15,6 +15,7 @@
     AutoConfig,
     AutoModelForSpeechSeq2Seq,
     EvalPrediction,
+    SchedulerType,
     Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
     Trainer,
@@ -167,6 +168,7 @@ def load_training_arguments(self) -> TrainingArguments:
             per_device_eval_batch_size=self.config.per_device_batch_size,
             gradient_accumulation_steps=gradient_accumulation_steps,
             learning_rate=self.config.learning_rate,
+            lr_scheduler_type=SchedulerType.COSINE,
             warmup_steps=self.config.warmup_steps,
             max_steps=self.config.max_steps,
             fp16=fp16,
@@ -178,14 +180,16 @@ def load_training_arguments(self) -> TrainingArguments:
             save_strategy="no" if self.config.save_total_limit == 0 else "steps",
             logging_steps=self.config.logging_steps,
             length_column_name="input_length",
-            gradient_checkpointing=True,
+            gradient_checkpointing=self.config.gradient_checkpointing,
             save_total_limit=self.config.save_total_limit,
             load_best_model_at_end=self.config.early_stopping,
             metric_for_best_model="wer",
             greater_is_better=False,
             seed=self.config.seed,
             remove_unused_columns=False,
             optim=OptimizerNames.ADAMW_TORCH,
+            adam_beta1=self.config.adam_first_momentum,
+            adam_beta2=self.config.adam_second_momentum,
             report_to=["wandb"] if self.config.wandb else [],
             ignore_data_skip=self.config.ignore_data_skip,
             save_safetensors=True,
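For context, here is a minimal, runnable sketch of the Seq2SeqTrainingArguments that load_training_arguments builds after this commit. Only the keyword arguments visible in the diff are taken from the source file; the config object, its values, and the output directory are illustrative assumptions.

```python
# Sketch of the training arguments after this commit. Only the keyword
# arguments shown in the diff come from the source; the SimpleNamespace
# config values and output_dir below are illustrative assumptions.
from types import SimpleNamespace

from transformers import SchedulerType, Seq2SeqTrainingArguments
from transformers.training_args import OptimizerNames

# Hypothetical stand-in for self.config; attribute names follow the diff.
config = SimpleNamespace(
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=10_000,
    gradient_checkpointing=True,
    adam_first_momentum=0.9,
    adam_second_momentum=0.999,
)

args = Seq2SeqTrainingArguments(
    output_dir="whisper-finetuned",  # assumed; not shown in the diff
    learning_rate=config.learning_rate,
    lr_scheduler_type=SchedulerType.COSINE,  # new: cosine decay after warmup
    warmup_steps=config.warmup_steps,
    max_steps=config.max_steps,
    gradient_checkpointing=config.gradient_checkpointing,  # now configurable
    optim=OptimizerNames.ADAMW_TORCH,
    adam_beta1=config.adam_first_momentum,   # new: configurable beta1
    adam_beta2=config.adam_second_momentum,  # new: configurable beta2
)
```

With SchedulerType.COSINE, the learning rate ramps up linearly over warmup_steps and then follows a cosine curve down towards zero over the remaining max_steps, mirroring the wav2vec2 training setup the commit title refers to.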
