diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index 0fd0d9f5d2..b1dc98b9da 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -65,6 +65,9 @@ class GRPOConfig(TrainingArguments):
             device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
             improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM)
             errors during initialization.
+        vllm_dtype (`str`, *optional*, defaults to `"auto"`):
+            Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
+            based on the model configuration. Find the supported values in the vLLM documentation.
 
         > Parameters that control the training
 
@@ -144,7 +147,15 @@ class GRPOConfig(TrainingArguments):
             "out-of-memory (OOM) errors during initialization."
         },
     )
-
+
+    vllm_dtype: Optional[str] = field(
+        default="auto",
+        metadata={
+            "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
+            "determined based on the model configuration. Find the supported values in the vLLM documentation."
+        },
+    )
+
     # Parameters that control the training
     learning_rate: float = field(
         default=1e-6,
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 38286335ff..7cbebada0a 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -314,6 +314,7 @@ def data_collator(features):  # No data collation is needed in GRPO
                 model=model.name_or_path,
                 device=vllm_device,
                 gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
+                dtype=self.args.vllm_dtype,
                 # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
                 # directly reuse the KV cache if it shares the same prefix with one of the existing queries.
                 # This is particularly useful here because we generate completions from the same prompts.