Skip to content

Commit

Permalink
fix some logic errors
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec committed Feb 5, 2025
1 parent 5f37e3d commit 0b131b1
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
8 changes: 4 additions & 4 deletions trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ class GRPOConfig(TrainingArguments):
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
num_generations (`int` or `None`, *optional*, defaults to `8`):
Number of generations per prompt to sample. It must be evenly divisible by the effective batch size
(`num_processes` * `per_device_train_batch_size`).
Number of generations per prompt to sample. The global batch size (`num_processes` *
`per_device_train_batch_size`) must be divisible by this value.
temperature (`float`, *optional*, defaults to `0.9`):
Temperature for sampling. The higher the temperature, the more random the completions.
max_completion_length (`int` or `None`, *optional*, defaults to `256`):
Expand Down Expand Up @@ -125,8 +125,8 @@ class GRPOConfig(TrainingArguments):
num_generations: Optional[int] = field(
default=8,
metadata={
"help": "Number of generations to sample. It must be evenly divisible by the effective batch size "
"(num_processes * per_device_train_batch_size)."
"help": "Number of generations to sample. The global batch size (num_processes * "
"per_device_train_batch_size) must be divisible by this value."
},
)
temperature: Optional[float] = field(
Expand Down
6 changes: 2 additions & 4 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,18 +341,16 @@ def data_collator(features): # No data collation is needed in GRPO
optimizers=optimizers,
)

# Maybe move this elsewhere in the future:
# Check if the per_device_train_batch_size * num processes can be divided by the number of generations

if args.per_device_train_batch_size * self.accelerator.num_processes % self.num_generations != 0:
possible_values = [
i
for i in range(2, self.accelerator.num_processes * args.per_device_train_batch_size + 1)
if (self.accelerator.num_processes * args.per_device_train_batch_size) % i == 0
]
raise ValueError(
f"The number of generations per prompt ({self.num_generations}) must be evenly divisible by the "
f"effective batch size ({self.accelerator.num_processes} x {args.per_device_train_batch_size}). "
f"The global batch size ({self.accelerator.num_processes} x {args.per_device_train_batch_size}) "
f"must be evenly divisible by the number of generations per prompt ({self.num_generations}). "
f"Given this batch size, the valid values for the number of generations are: {possible_values}."
)

Expand Down

0 comments on commit 0b131b1

Please sign in to comment.