Skip to content

Commit

Permalink
✅ Add test for GRPOTrainer with beta=0 to ensure no reference model is created and no KL divergence is computed
Browse files Browse the repository at this point in the history
  • Loading branch information
ingambe committed Feb 8, 2025
1 parent 175ff2e commit 2833aca
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,3 +455,34 @@ def test_training_with_sync_ref_model(self):
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

def test_beta_zero_no_ref_model_and_no_kl(self):
    """When beta == 0, the trainer must skip both the reference model and the KL metric.

    Verifies two things:
      1. ``trainer.ref_model`` is never instantiated.
      2. No ``"kl"`` entry appears anywhere in the training log history.
    """
    train_data = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

    with tempfile.TemporaryDirectory() as work_dir:
        # beta=0 disables the KL penalty entirely; a single step keeps the test fast.
        config = GRPOConfig(
            output_dir=work_dir,
            beta=0,  # set beta to 0 to test the case where the reference model is not used
            learning_rate=0.1,
            per_device_train_batch_size=3,
            num_generations=3,
            max_completion_length=32,
            max_steps=1,  # run only one training step to keep the test fast
            report_to="none",
        )
        trainer = GRPOTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
            args=config,
            train_dataset=train_data,
        )

        # With the KL term disabled, no reference model should ever be built.
        self.assertIsNone(trainer.ref_model, "ref_model should be None when beta==0")

        trainer.train()

        # The KL metric must be absent from every logged step.
        for entry in trainer.state.log_history:
            self.assertNotIn("kl", entry, "KL divergence metric should not be logged when beta==0")


0 comments on commit 2833aca

Please sign in to comment.