Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
dchourasia committed Aug 17, 2024
2 parents fd86b4f + a6d093e commit 610a3da
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def test_run_causallm_ft_save_with_save_model_dir_save_strategy_no():
# validate that no checkpoints created
assert not any(x.startswith("checkpoint-") for x in os.listdir(tempdir))

sft_trainer.save(tempdir, trainer)
sft_trainer.save(tempdir, trainer, "debug")
assert any(x.endswith(".safetensors") for x in os.listdir(tempdir))
_test_run_inference(checkpoint_path=tempdir)

Expand Down
4 changes: 3 additions & 1 deletion tuning/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,13 +373,15 @@ def save(path: str, trainer: SFTTrainer, log_level="WARNING"):
Path to save the model to.
trainer: SFTTrainer
Instance of SFTTrainer used for training to save the model.
log_level: str
Optional threshold to set save save logger to, default warning.
"""
logger = logging.getLogger("sft_trainer_save")
# default value from TrainingArguments
if log_level == "passive":
log_level = "WARNING"

logger.setLevel(log_level)
logger.setLevel(log_level.upper())

if not os.path.exists(path):
os.makedirs(path, exist_ok=True)
Expand Down

0 comments on commit 610a3da

Please sign in to comment.