Skip to content

Commit

Permalink
final fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
younesbelkada committed Jan 12, 2024
1 parent 434fb28 commit 9585ce8
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 28 deletions.
14 changes: 2 additions & 12 deletions .github/workflows/slow-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,6 @@ jobs:
source activate trl
pip install -e ".[test]" --no-deps
pip install pytest-reportlog parameterized
- name: Run common tests on single GPU
if: always()
run: |
source activate trl
make tests_common_gpu
- name: Run slow SFT tests on single GPU
if: always()
Expand Down Expand Up @@ -80,12 +74,6 @@ jobs:
source activate trl
pip install -e ".[test]" --no-deps
pip install pytest-reportlog parameterized
- name: Run common tests on single GPU
if: always()
run: |
source activate trl
make tests_common_gpu
- name: Run slow SFT tests on single GPU
if: always()
Expand All @@ -103,12 +91,14 @@ jobs:
if: always()
run: |
source activate trl
pip install deepspeed
make run_sft_examples
- name: Run end-to-end DPO examples tests on multi GPU
if: always()
run: |
source activate trl
pip install deepspeed
make run_dpo_examples
- name: Generate Reports
Expand Down
3 changes: 1 addition & 2 deletions scripts/log_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,7 @@ def main(slack_channel_name=None):
from slack_sdk import WebClient

if len(message) > MAX_LEN_MESSAGE:
print(f"Truncating long message from {len(message)} to {MAX_LEN_MESSAGE}")
message = message[:MAX_LEN_MESSAGE] + "..."
message = f"There are {total_num_failed} failed tests in total ! Cannot display the entire summary - please check the action results directly"

if len(message) != 0:
md_report = {
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
"tyro>=0.5.11",
]
EXTRAS = {
"test": ["parameterized", "pytest", "pytest-xdist", "accelerate", "pytest-cov", "pytest-xdist", "wandb"],
"test": ["parameterized", "pytest", "pytest-xdist", "accelerate", "pytest-cov", "pytest-xdist"],
"peft": ["peft>=0.4.0"],
"diffusers": ["diffusers>=0.18.0"],
"deepspeed": ["deepspeed>=0.9.5"],
Expand Down
15 changes: 3 additions & 12 deletions tests/slow/test_dpo_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,7 @@
from trl import DPOTrainer, is_peft_available

from ..testing_utils import require_bitsandbytes, require_peft, require_torch_gpu, require_torch_multi_gpu
from .testing_constants import (
DPO_GEN_DURING_EVAL,
DPO_LOSS_TYPES,
DPO_PRECOMPUTE_LOGITS,
GRADIENT_CHECKPOINTING_KWARGS,
MODELS_TO_TEST,
)
from .testing_constants import DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST


if is_peft_available():
Expand All @@ -57,10 +51,8 @@ def tearDown(self):
torch.cuda.empty_cache()
gc.collect()

@parameterized.expand(
list(itertools.product(MODELS_TO_TEST, DPO_LOSS_TYPES, DPO_GEN_DURING_EVAL, DPO_PRECOMPUTE_LOGITS))
)
def test_dpo_bare_model(self, model_id, loss_type, gen_during_eval, pre_compute_logits):
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS)))
def test_dpo_bare_model(self, model_id, loss_type, pre_compute_logits):
"""
A test that tests the simple usage of `DPOTrainer` using a bare model in full precision.
"""
Expand Down Expand Up @@ -90,7 +82,6 @@ def test_dpo_bare_model(self, model_id, loss_type, gen_during_eval, pre_compute_
tokenizer=tokenizer,
train_dataset=self.dataset,
eval_dataset=self.dataset,
generate_during_eval=gen_during_eval,
loss_type=loss_type,
precompute_ref_log_probs=pre_compute_logits,
max_length=self.max_length,
Expand Down
1 change: 0 additions & 1 deletion tests/slow/testing_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,3 @@

DPO_LOSS_TYPES = ["sigmoid", "ipo", "kto_pair"]
DPO_PRECOMPUTE_LOGITS = [True, False]
DPO_GEN_DURING_EVAL = [True, False]

0 comments on commit 9585ce8

Please sign in to comment.