From 37d559696ab993b284179f57850edbfd78b2fb15 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 12 Jan 2024 10:58:37 +0000 Subject: [PATCH] oops --- tests/slow/test_dpo_slow.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/slow/test_dpo_slow.py b/tests/slow/test_dpo_slow.py index 9c15673478..f33c5820ab 100644 --- a/tests/slow/test_dpo_slow.py +++ b/tests/slow/test_dpo_slow.py @@ -222,7 +222,24 @@ def test_dpo_peft_model_qlora(self, model_id, loss_type, pre_compute_logits, gra @require_torch_multi_gpu -class DPOTrainerSlowTesterMultiGPU(DPOTrainerSlowTester): +class DPOTrainerSlowTesterMultiGPU(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.dataset = load_dataset("trl-internal-testing/mlabonne-chatml-dpo-pairs-copy", split="train[:10%]") + cls.peft_config = LoraConfig( + lora_alpha=16, + lora_dropout=0.1, + r=8, + bias="none", + task_type="CAUSAL_LM", + ) + cls.max_length = 128 + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + gc.collect() + @parameterized.expand( list( itertools.product(