Skip to content

Commit

Permalink
oops
Browse files Browse the repository at this point in the history
  • Loading branch information
younesbelkada committed Jan 12, 2024
1 parent 4ea24af commit 37d5596
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion tests/slow/test_dpo_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,24 @@ def test_dpo_peft_model_qlora(self, model_id, loss_type, pre_compute_logits, gra


@require_torch_multi_gpu
class DPOTrainerSlowTesterMultiGPU(DPOTrainerSlowTester):
class DPOTrainerSlowTesterMultiGPU(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.dataset = load_dataset("trl-internal-testing/mlabonne-chatml-dpo-pairs-copy", split="train[:10%]")
cls.peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
r=8,
bias="none",
task_type="CAUSAL_LM",
)
cls.max_length = 128

def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()

@parameterized.expand(
list(
itertools.product(
Expand Down

0 comments on commit 37d5596

Please sign in to comment.