From 110c26e04914ddfd6e6494f81c4412b5ec1eb298 Mon Sep 17 00:00:00 2001 From: Vassilis Vassiliadis Date: Fri, 14 Jun 2024 08:36:26 +0100 Subject: [PATCH 1/3] fix: concatenate the trainer_callbacks and additional_callbacks in sft_trainer:train() Signed-off-by: Vassilis Vassiliadis --- tuning/sft_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index de616fd2..483ef392 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -150,7 +150,7 @@ def train( # Add any extra callback if passed by users if additional_callbacks is not None: - trainer_callbacks.append(additional_callbacks) + trainer_callbacks.extend(additional_callbacks) framework = AccelerationFrameworkConfig.from_dataclasses( quantized_lora_config, fusedops_kernels_config From e0168403e6dd4e5431f8f0a4a7e3c8149ef76d7e Mon Sep 17 00:00:00 2001 From: Vassilis Vassiliadis Date: Fri, 14 Jun 2024 09:02:47 +0100 Subject: [PATCH 2/3] test: add a unit test for additional_callbacks param of train() Signed-off-by: Vassilis Vassiliadis --- tests/test_sft_trainer.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index c02146bf..d5d1619e 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -616,3 +616,22 @@ def test_bad_torch_dtype(): with pytest.raises(ValueError): sft_trainer.train(model_args, DATA_ARGS, train_args, PEFT_PT_ARGS) + + +def test_run_with_additional_callbacks(): + """Ensure that train() can work with additional_callbacks""" + # Third Party + from transformers.trainer_callback import TrainerCallback + + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + model_args = copy.deepcopy(MODEL_ARGS) + + sft_trainer.train( + model_args, + DATA_ARGS, + train_args, + PEFT_PT_ARGS, + additional_callbacks=[TrainerCallback()], + ) From 890590481a1d7eaf4bfd59288bdea01bd04a88d8 Mon Sep 17 00:00:00 2001 From: Vassilis Vassiliadis Date: Fri, 14 Jun 2024 09:20:46 +0100 Subject: [PATCH 3/3] chore: fix formating of test_run_with_additional_callbacks Signed-off-by: Vassilis Vassiliadis --- tests/test_sft_trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index d5d1619e..7041bd13 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -23,6 +23,7 @@ # Third Party from datasets.exceptions import DatasetGenerationError +from transformers.trainer_callback import TrainerCallback import pytest import torch import transformers @@ -620,8 +621,6 @@ def test_bad_torch_dtype(): def test_run_with_additional_callbacks(): """Ensure that train() can work with additional_callbacks""" - # Third Party - from transformers.trainer_callback import TrainerCallback with tempfile.TemporaryDirectory() as tempdir: train_args = copy.deepcopy(TRAIN_ARGS)