Skip to content

Commit

Permalink
📐 Add vLLM dtype configuration for GRPO trainer (#2738)
Browse files Browse the repository at this point in the history
* feat: Add vLLM dtype configuration for GRPO trainer

* added vllm dtype info in docstring

* send to vLLM doc

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
  • Loading branch information
3 people authored Feb 4, 2025
1 parent 1a22764 commit 32f8fa8
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
13 changes: 12 additions & 1 deletion trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ class GRPOConfig(TrainingArguments):
device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
during initialization.
vllm_dtype (`str`, *optional*, defaults to `"auto"`):
Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
based on the model configuration. Find the supported values in the vLLM documentation.
> Parameters that control the training
Expand Down Expand Up @@ -144,7 +147,15 @@ class GRPOConfig(TrainingArguments):
"out-of-memory (OOM) errors during initialization."
},
)


vllm_dtype: Optional[str] = field(
default="auto",
metadata={
"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
"determined based on the model configuration. Find the supported values in the vLLM documentation."
},
)

# Parameters that control the training
learning_rate: float = field(
default=1e-6,
Expand Down
1 change: 1 addition & 0 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ def data_collator(features): # No data collation is needed in GRPO
model=model.name_or_path,
device=vllm_device,
gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
dtype=self.args.vllm_dtype,
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
# This is particularly useful here because we generate completions from the same prompts.
Expand Down

0 comments on commit 32f8fa8

Please sign in to comment.