Skip to content

Commit

Permalink
fix: Disable saving model checkpoints if gradient accumulation is disabled
Browse files Browse the repository at this point in the history
  • Loading branch information
saattrupdan committed Oct 24, 2024
1 parent 4ef2be8 commit 126b148
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 1 deletion.
2 changes: 1 addition & 1 deletion config/asr_finetuning.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ wandb_group: default
wandb_name: ${model_id}
resume_from_checkpoint: false
ignore_data_skip: false
save_total_limit: 1
save_total_limit: 0 # Will automatically be set to >=1 if `early_stopping` is enabled

# Optimisation parameters
learning_rate: 1e-4
Expand Down
3 changes: 3 additions & 0 deletions src/coral/wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ def load_training_arguments(self) -> TrainingArguments:
if self.is_main_process:
logger.info("Mixed precision training with FP16 enabled.")

if self.config.early_stopping:
self.config.save_total_limit = max(self.config.save_total_limit, 1)

args = TrainingArguments(
output_dir=self.config.model_dir,
hub_model_id=f"{self.config.hub_organisation}/{self.config.model_id}",
Expand Down
4 changes: 4 additions & 0 deletions src/coral/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ def load_training_arguments(self) -> TrainingArguments:
if self.is_main_process:
logger.info("Mixed precision training with FP16 enabled.")

if self.config.early_stopping:
self.config.save_total_limit = max(self.config.save_total_limit, 1)

args = Seq2SeqTrainingArguments(
output_dir=self.config.model_dir,
hub_model_id=f"{self.config.hub_organisation}/{self.config.model_id}",
Expand All @@ -172,6 +175,7 @@ def load_training_arguments(self) -> TrainingArguments:
eval_strategy="steps",
eval_steps=self.config.eval_steps,
save_steps=self.config.save_steps,
save_strategy="no" if self.config.save_total_limit == 0 else "steps",
logging_steps=self.config.logging_steps,
length_column_name="input_length",
gradient_checkpointing=True,
Expand Down

0 comments on commit 126b148

Please sign in to comment.