Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

↔️ GRPO: Set max_model_len when initializing vLLM instance #2728

Merged
merged 16 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ class GRPOConfig(TrainingArguments):
vllm_dtype (`str`, *optional*, defaults to `"auto"`):
Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
based on the model configuration. Find the supported values in the vLLM documentation.
vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
context size, which might be much larger than the KV cache, leading to inefficiencies.

> Parameters that control the training

Expand Down Expand Up @@ -181,6 +185,14 @@ class GRPOConfig(TrainingArguments):
"determined based on the model configuration. Find the supported values in the vLLM documentation."
},
)
vllm_max_model_len: Optional[int] = field(
default=None,
metadata={
"help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced "
"`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
"context size, which might be much larger than the KV cache, leading to inefficiencies."
},
)

# Parameters that control the training
learning_rate: float = field(
Expand Down
1 change: 1 addition & 0 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ def data_collator(features): # No data collation is needed in GRPO
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
# This is particularly useful here because we generate completions from the same prompts.
enable_prefix_caching=True,
max_model_len=self.args.vllm_max_model_len,
)
self.sampling_params = SamplingParams(
n=self.num_generations,
Expand Down