Skip to content

Commit

Permalink
🔧 Refactor GRPOTrainer code for improved readability and maintainability
Browse files Browse the repository at this point in the history
  • Loading branch information
ingambe committed Feb 8, 2025
1 parent 2833aca commit 7efc037
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
4 changes: 1 addition & 3 deletions tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def test_beta_zero_no_ref_model_and_no_kl(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
beta=0, # set beta to 0 to test the case where the reference model is not used
beta=0, # set beta to 0 to test the case where the reference model is not used
learning_rate=0.1,
per_device_train_batch_size=3,
num_generations=3,
Expand All @@ -484,5 +484,3 @@ def test_beta_zero_no_ref_model_and_no_kl(self):
# Check that no KL divergence was computed during training
for log in trainer.state.log_history:
self.assertNotIn("kl", log, "KL divergence metric should not be logged when beta==0")


7 changes: 4 additions & 3 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def __init__(
"You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
"This argument can only be used when the `model` argument is a string."
)

self.beta = args.beta

if peft_config is not None:
Expand Down Expand Up @@ -293,7 +293,6 @@ def data_collator(features): # No data collation is needed in GRPO
self.num_generations = args.num_generations # = G in the GRPO paper
self.use_vllm = args.use_vllm


# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
"input_ids" key. Instead, the available key is "prompt". As a result, the trainer issues the warning:
Expand Down Expand Up @@ -651,7 +650,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
per_token_loss = -per_token_loss
else:
# we need to compute the KL divergence between the model and the reference model
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
per_token_kl = (
torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
)

per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
Expand Down

0 comments on commit 7efc037

Please sign in to comment.