🏆 Custom reward function for GRPO and shiny doc (#2606)
* initial commit

* doc on custom reward function

* test

* doc doc doc

* fix collator

* style

* links?

* I need a docdoc 🎵

* fix link

* I do like writing doc tbh

* it takes time, but it's worth it

* no return!

* type hint

* it's probably the best of both worlds [ci skip]

* new doc before implementation

* tests

* more doc

* style

* multiple pretrained funcs

* fix arg name

* main?

* example for R1

* fix script

* clearer

* import [ci skip]

* Update docs/source/grpo_trainer.md

Co-authored-by: lewtun <[email protected]>

---------

Co-authored-by: lewtun <[email protected]>
qgallouedec and lewtun authored Jan 23, 2025
1 parent 887c1f3 commit a1d2955
Showing 4 changed files with 392 additions and 57 deletions.
93 changes: 93 additions & 0 deletions docs/source/grpo_trainer.md
@@ -114,6 +114,99 @@ The GRPO Trainer logs the following metrics:
- `reward_std` : The average standard deviation within reward groups.
- `kl` : The average KL divergence between the model and the reference model calculated on completions.

## Customization

### Using a custom reward function

The [`GRPOTrainer`] supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements:

1. **Input arguments**:
- The function must accept two arguments: `prompts` and `completions`.
- Depending on the dataset format, the input will vary:
- For [standard format](dataset_formats#standard), `prompts` and `completions` will be lists of strings.
- For [conversational format](dataset_formats#conversational), `prompts` and `completions` will be lists of message dictionaries.

2. **Return value**: The function must return a list of floats, with each float representing the reward for the corresponding completion (see the signature sketch below).
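
Concretely, the expected signature looks like the following minimal sketch (the function name, type hints, and placeholder return value are illustrative, not part of TRL's API):

```python
from typing import Union

# Minimal sketch of a custom reward function's signature. For the standard format
# the list elements are plain strings; for the conversational format they are
# lists of message dictionaries ({"role": ..., "content": ...}).
def my_reward_func(
    prompts: list[Union[str, list[dict]]],
    completions: list[Union[str, list[dict]]],
) -> list[float]:
    # Return exactly one float per completion, in the same order.
    return [0.0 for _ in completions]
```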

#### Example 1: Reward longer completions

Below is an example of a reward function for the standard format that rewards longer completions:

```python
def reward_func(prompts, completions):
    """Reward function that gives higher scores to longer completions."""
    return [float(len(completion)) for completion in completions]
```

You can test it as follows:

```python
>>> prompts = ["The sky is", "The sun is"]
>>> completions = [" blue.", " in the sky."]
>>> print(reward_func(prompts, completions))
[6.0, 12.0]
```
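
Note that this rewards the number of characters in the completion rather than the number of tokens. If you would rather reward token counts, a variant along the following lines could be used (a sketch; the tokenizer checkpoint is an illustrative assumption and should match your model):

```python
from transformers import AutoTokenizer

# Illustrative tokenizer checkpoint; use the one matching the model you train.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

def token_length_reward_func(prompts, completions):
    """Reward function that scores completions by their token count."""
    return [float(len(tokenizer(completion).input_ids)) for completion in completions]
```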

#### Example 2: Reward completions with specific format

Below is an example of a reward function that checks if the completion has a specific format. This example is inspired by the reward function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948).
It is designed for the conversational format, where prompts and completions consist of structured messages.

```python
import re

def format_reward_func(prompts, completions):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
```

You can test this function as follows:

```python
>>> prompts = [
...     [{"role": "user", "content": "What is the result of (1 + 2) * 4?"}],
...     [{"role": "user", "content": "What is the result of (3 + 1) * 2?"}],
... ]
>>> completions = [
...     [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
...     [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
... ]
>>> format_reward_func(prompts, completions)
[1.0, 0.0]
```
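
One caveat: by default, `.` in a Python regular expression does not match newline characters, so the pattern above rejects completions whose `<think>` or `<answer>` content spans multiple lines. If your completions may contain line breaks inside the tags (an assumption about your data, not part of the original example), passing `re.DOTALL` makes the check accept them:

```python
import re

def multiline_format_reward_func(prompts, completions):
    """Variant of format_reward_func that also accepts multi-line tag contents."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets "." match newlines, so multi-line reasoning still counts as a match.
    matches = [re.match(pattern, content, flags=re.DOTALL) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
```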

#### Passing the reward function to the trainer

To use your custom reward function, pass it to the [`GRPOTrainer`] as follows:

```python
from trl import GRPOTrainer

trainer = GRPOTrainer(
    reward_funcs=reward_func,
    ...,
)
```
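
Putting the pieces together, a minimal end-to-end training sketch might look like this (the model checkpoint and dataset are illustrative assumptions, not requirements of the trainer):

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

def reward_func(prompts, completions):
    """Reward function that gives higher scores to longer completions."""
    return [float(len(completion)) for completion in completions]

# Any prompt-only dataset works; this one is just an example.
dataset = load_dataset("trl-lib/tldr", split="train")

training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO")
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_func,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```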

If you have multiple reward functions, you can pass them as a list:

```python
from trl import GRPOTrainer

trainer = GRPOTrainer(
    reward_funcs=[reward_func1, reward_func2],
    ...,
)
```

The total reward is then computed as the sum of the rewards returned by each function.

Note that [`GRPOTrainer`] supports multiple reward functions of different types. See the parameters documentation for more details.
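
For example, you can mix a Python reward function with a pretrained reward model by passing the model id alongside the function, mirroring what the tests in this commit do. In the sketch below, the policy and reward-model checkpoints are illustrative, and `reward_func`, `training_args`, and `dataset` are assumed to be defined as in the sketch above:

```python
from trl import GRPOTrainer

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # illustrative policy checkpoint
    reward_funcs=[
        reward_func,                   # custom Python reward function
        "my-org/my-reward-model",      # hypothetical pretrained reward model id
    ],
    args=training_args,
    train_dataset=dataset,
)
```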

## GRPOTrainer

[[autodoc]] GRPOTrainer
155 changes: 150 additions & 5 deletions tests/test_grpo_trainer.py
@@ -35,7 +35,7 @@ def test_init_minimal(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
        GRPOTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            reward_model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
            train_dataset=dataset,
        )

@@ -54,7 +54,7 @@ def test_training(self, config_name):
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )
@@ -87,7 +87,7 @@ def test_training_peft(self):
            )
            trainer = GRPOTrainer(
                model=model,
                reward_model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
                peft_config=LoraConfig(),
@@ -130,10 +130,155 @@ def test_training_different_reward_model(self):
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_model=reward_model,
                reward_funcs=reward_model,
                args=training_args,
                train_dataset=dataset,
                reward_processing_classes=reward_tokenizer,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    def test_training_reward_func_standard(self):
        # Test if trainer can handle reward function with standard format
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        def reward_func(prompts, completions):
            """Reward function that rewards longer completions."""
            return [float(len(completion)) for completion in completions]

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=32,  # reduce the completion length to reduce memory usage
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs=reward_func,
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    def test_training_reward_func_conversational(self):
        # Test if trainer can handle reward function with conversational format
        dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")

        def reward_func(prompts, completions):
            """Reward function that gives higher scores to longer completion content."""
            completion_contents = [completion[0]["content"] for completion in completions]
            return [float(len(content)) for content in completion_contents]

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=32,  # reduce the completion length to reduce memory usage
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs=reward_func,
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    def test_training_multiple_reward_funcs(self):
        # Test that GRPOTrainer can be instantiated with multiple reward functions
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        def reward_func1(prompts, completions):
            """Reward function that rewards longer completions."""
            return [float(len(completion)) for completion in completions]

        def reward_func2(prompts, completions):
            """Reward function that rewards completions with more unique letters."""
            return [float(len(set(completion))) for completion in completions]

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=32,  # reduce the completion length to reduce memory usage
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs=[reward_func1, reward_func2],
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    def test_training_multiple_mixed_reward_funcs(self):
        # Test if the trainer can handle a mix of reward functions and reward models
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        def reward_func(prompts, completions):
            """Reward function that rewards longer completions."""
            return [float(len(completion)) for completion in completions]

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=32,  # reduce the completion length to reduce memory usage
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs=[reward_func, "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"],
                args=training_args,
                train_dataset=dataset,
                reward_processing_class=reward_tokenizer,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
2 changes: 1 addition & 1 deletion trl/scripts/grpo.py
@@ -60,7 +60,7 @@ def main(script_args, training_args, model_args):
    # Initialize the GRPO trainer
    trainer = GRPOTrainer(
        model=model,
        reward_model=reward_model,
        reward_funcs=reward_model,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,