diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index abc5dcb2ee..3c520206bc 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -45,8 +45,8 @@ class GRPOConfig(TrainingArguments): max_prompt_length (`int` or `None`, *optional*, defaults to `512`): Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left. num_generations (`int` or `None`, *optional*, defaults to `8`): - Number of generations per prompt to sample. It must be evenly divisible by the effective batch size - (`num_processes` * `per_device_train_batch_size`). + Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size) + must be divisible by this value. temperature (`float`, *optional*, defaults to `0.9`): Temperature for sampling. The higher the temperature, the more random the completions. max_completion_length (`int` or `None`, *optional*, defaults to `256`): @@ -125,8 +125,8 @@ class GRPOConfig(TrainingArguments): num_generations: Optional[int] = field( default=8, metadata={ - "help": "Number of generations to sample. It must be evenly divisible by the effective batch size " - "(num_processes * per_device_train_batch_size)." + "help": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) " + "must be divisible by this value." 
}, ) temperature: Optional[float] = field( diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 892143aa18..04a7e79cb0 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -341,9 +341,7 @@ def data_collator(features): # No data collation is needed in GRPO optimizers=optimizers, ) - # Maybe move this elsewhere in the future: # Check if the per_device_train_batch_size * num processes can be divided by the number of generations - if args.per_device_train_batch_size * self.accelerator.num_processes % self.num_generations != 0: possible_values = [ i @@ -351,8 +349,8 @@ def data_collator(features): # No data collation is needed in GRPO if (self.accelerator.num_processes * args.per_device_train_batch_size) % i == 0 ] raise ValueError( - f"The number of generations per prompt ({self.num_generations}) must be evenly divisible by the " - f"effective batch size ({self.accelerator.num_processes} x {args.per_device_train_batch_size}). " + f"The global batch size ({self.accelerator.num_processes} x {args.per_device_train_batch_size}) " + f"must be evenly divisible by the number of generations per prompt ({self.num_generations}). " f"Given this batch size, the valid values for the number of generations are: {possible_values}." )