From 6abda349c0f5f5256ebd893e6fa893d93cebcedd Mon Sep 17 00:00:00 2001 From: Mircea Pricop Date: Sat, 1 Feb 2025 16:58:41 +0100 Subject: [PATCH 01/11] Set max_model_len when initializing vLLM instance --- trl/trainer/grpo_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index f130d5155f..67337c8ccb 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -314,6 +314,7 @@ def data_collator(features): # No data collation is needed in GRPO model=model.name_or_path, device=vllm_device, gpu_memory_utilization=self.args.vllm_gpu_memory_utilization, + max_model_len=(self.max_prompt_length + self.max_completion_length), ) self.sampling_params = SamplingParams( n=self.num_generations, From 264f19de9fc30f30f6e208d7857a9c774b389717 Mon Sep 17 00:00:00 2001 From: Mircea Pricop Date: Sat, 1 Feb 2025 18:45:16 +0000 Subject: [PATCH 02/11] Introduce vllm_max_model_len arg --- trl/trainer/grpo_config.py | 12 ++++++++++++ trl/trainer/grpo_trainer.py | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 0fd0d9f5d2..34112204d1 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -65,6 +65,10 @@ class GRPOConfig(TrainingArguments): device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors during initialization. + vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`): + If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced + `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model + context size, which might be much larger than the KV cache, leading to inefficiencies. > Parameters that control the training @@ -144,6 +148,14 @@ class GRPOConfig(TrainingArguments): "out-of-memory (OOM) errors during initialization." }, ) + vllm_max_model_len: Optional[int] = field( + default=None, + metadata={ + "help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced" + "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model" + "context size, which might be much larger than the KV cache, leading to inefficiencies." 
+ }, + ) # Parameters that control the training learning_rate: float = field( diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 67337c8ccb..ae66fce11f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -314,7 +314,7 @@ def data_collator(features): # No data collation is needed in GRPO model=model.name_or_path, device=vllm_device, gpu_memory_utilization=self.args.vllm_gpu_memory_utilization, - max_model_len=(self.max_prompt_length + self.max_completion_length), + max_model_len=self.args.vllm_max_model_len, ) self.sampling_params = SamplingParams( n=self.num_generations, From c0fd3dfa9355ac4f709ed4c168b1962a22384e4d Mon Sep 17 00:00:00 2001 From: Mircea Pricop Date: Sun, 2 Feb 2025 20:35:32 +0000 Subject: [PATCH 03/11] Replace vllm args with vllm_init_kwargs --- trl/trainer/grpo_config.py | 43 ++++++++++++++++++++++++------------- trl/trainer/grpo_trainer.py | 8 +++---- 2 files changed, 32 insertions(+), 19 deletions(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 34112204d1..db35665c9e 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -131,29 +131,25 @@ class GRPOConfig(TrainingArguments): "(`pip install vllm`)." }, ) + vllm_init_kwargs: Optional[dict] = field( + default_factory=lambda: { + "device": "auto", + "gpu_memory_utilization": 0.9, + }, + metadata={ + "help": "Keyword arguments for `vllm.LLM.__init__` when `use_vllm` is true" + }, + ) vllm_device: Optional[str] = field( default="auto", metadata={ - "help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system " - "will automatically select the next available GPU after the last one used for training. This assumes " - "that training has not already occupied all available GPUs." + "help": "Deprecated. Set `device` in `vllm_init_kwargs` instead." }, ) vllm_gpu_memory_utilization: float = field( default=0.9, metadata={ - "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV " - "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache " - "size and thus improve the model's throughput. However, if the value is too high, it may cause " - "out-of-memory (OOM) errors during initialization." - }, - ) - vllm_max_model_len: Optional[int] = field( - default=None, - metadata={ - "help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced" - "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model" - "context size, which might be much larger than the KV cache, leading to inefficiencies." + "help": "Deprecated. Set `gpu_memory_utilization` in `vllm_init_kwargs` instead." }, ) @@ -186,3 +182,20 @@ class GRPOConfig(TrainingArguments): default=0.04, metadata={"help": "KL coefficient."}, ) + + def __post_init__(self): + super().__post_init__() + + if self.vllm_device: + warnings.warn( + "`vllm_device` is deprecated. Set `device` in `vllm_init_kwargs` instead.", + DeprecationWarning, + ) + self.vllm_init_kwargs["device"] = self.vllm_device + + if self.vllm_gpu_memory_utilization: + warnings.warn( + "`vllm_gpu_memory_utilization` is deprecated. 
Set `gpu_memory_utilization` in `vllm_init_kwargs` instead.", + DeprecationWarning, + ) + self.vllm_init_kwargs["gpu_memory_utilization"] = self.vllm_gpu_memory_utilization diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index ae66fce11f..ca17f7de2a 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -285,9 +285,11 @@ def data_collator(features): # No data collation is needed in GRPO ) if self.accelerator.is_main_process: - vllm_device = self.args.vllm_device + vllm_init_kwargs = self.args.vllm_init_kwargs + vllm_device = vllm_init_kwargs.get("device") if vllm_device == "auto": vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx + vllm_init_kwargs["device"] = vllm_device # Check that the requested device is available if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count(): raise ValueError( @@ -312,9 +314,7 @@ def data_collator(features): # No data collation is needed in GRPO with world_size_patch, profiling_patch: self.llm = LLM( model=model.name_or_path, - device=vllm_device, - gpu_memory_utilization=self.args.vllm_gpu_memory_utilization, - max_model_len=self.args.vllm_max_model_len, + **vllm_init_kwargs ) self.sampling_params = SamplingParams( n=self.num_generations, From db14bb02e47c488c84da6d4c1c49c9be43bc9a81 Mon Sep 17 00:00:00 2001 From: Mircea Pricop Date: Sun, 2 Feb 2025 20:40:01 +0000 Subject: [PATCH 04/11] Update docstring --- trl/trainer/grpo_config.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index db35665c9e..05aebd3b50 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -56,19 +56,8 @@ class GRPOConfig(TrainingArguments): use_vllm (`bool`, *optional*, defaults to `False`): Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`). - vllm_device (`str`, *optional*, defaults to `"auto"`): - Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will - automatically select the next available GPU after the last one used for training. This assumes that - training has not already occupied all available GPUs. - vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`): - Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the - device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus - improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors - during initialization. - vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`): - If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced - `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model - context size, which might be much larger than the KV cache, leading to inefficiencies. 
+ vllm_init_kwargs (`dict`, *optional*, defaults to {"device": "auto", "gpu_memory_utilization": 0.9}) + "Keyword arguments for `vllm.LLM.__init__` when `use_vllm` is true" > Parameters that control the training From e26d7fc0653c90841a9595901a4e544c5ddd9ac2 Mon Sep 17 00:00:00 2001 From: Mircea Pricop Date: Sun, 2 Feb 2025 21:07:35 +0000 Subject: [PATCH 05/11] Add missing import --- trl/trainer/grpo_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 05aebd3b50..5ce9d61b63 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from dataclasses import dataclass, field from typing import Optional From 0533e754ce27757b7c622f0612163aedeb9c740d Mon Sep 17 00:00:00 2001 From: Mircea Pricop Date: Sun, 2 Feb 2025 21:56:00 +0000 Subject: [PATCH 06/11] Remove default values from newly deprecated args --- trl/trainer/grpo_config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 5ce9d61b63..5051773b26 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -131,13 +131,13 @@ class GRPOConfig(TrainingArguments): }, ) vllm_device: Optional[str] = field( - default="auto", + default=None, metadata={ "help": "Deprecated. Set `device` in `vllm_init_kwargs` instead." }, ) - vllm_gpu_memory_utilization: float = field( - default=0.9, + vllm_gpu_memory_utilization: Optional[float] = field( + default=None, metadata={ "help": "Deprecated. Set `gpu_memory_utilization` in `vllm_init_kwargs` instead." }, From aa6c74c0b891594fad88b75cc6d15ba307d2054e Mon Sep 17 00:00:00 2001 From: Mircea Pricop Date: Tue, 4 Feb 2025 20:18:19 +0000 Subject: [PATCH 07/11] Docs update --- trl/trainer/grpo_config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 73523c396d..6be77ecdb6 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -57,7 +57,8 @@ class GRPOConfig(TrainingArguments): use_vllm (`bool`, *optional*, defaults to `False`): Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`). - vllm_init_kwargs (`dict`, *optional*, defaults to {"device": "auto", "gpu_memory_utilization": 0.9}) + vllm_init_kwargs (`dict`, *optional*, defaults to + {"device": "auto", "gpu_memory_utilization": 0.9, "enable_prefix_caching": True}) "Keyword arguments for `vllm.LLM.__init__` when `use_vllm` is true" > Parameters that control the training From 25d9abca62c05a62e1958ce9b71b286e2a9e3378 Mon Sep 17 00:00:00 2001 From: Mircea Pricop Date: Tue, 4 Feb 2025 23:33:41 +0000 Subject: [PATCH 08/11] Reverted to adding single arg for max_model_len --- trl/trainer/grpo_config.py | 82 ++++++++++++++++++------------------- trl/trainer/grpo_trainer.py | 12 ++++-- 2 files changed, 49 insertions(+), 45 deletions(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index d1e379e481..f9997faa2c 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -57,13 +57,22 @@ class GRPOConfig(TrainingArguments): use_vllm (`bool`, *optional*, defaults to `False`): Whether to use vLLM for generating completions. 
If set to `True`, ensure that a GPU is kept unused for training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`). - vllm_init_kwargs (`dict`, *optional*, defaults to { - "device": "auto", - "gpu_memory_utilization": 0.9, - "enable_prefix_caching": True, - "dtype": "auto" - }) - "Keyword arguments for `vllm.LLM.__init__` when `use_vllm` is true" + vllm_device (`str`, *optional*, defaults to `"auto"`): + Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will + automatically select the next available GPU after the last one used for training. This assumes that + training has not already occupied all available GPUs. + vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`): + Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the + device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus + improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors + during initialization. + vllm_dtype (`str`, *optional*, defaults to `"auto"`): + Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined + based on the model configuration. Find the supported values in the vLLM documentation. + vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`): + If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced + `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model + context size, which might be much larger than the KV cache, leading to inefficiencies. > Parameters that control the training @@ -126,30 +135,36 @@ class GRPOConfig(TrainingArguments): "(`pip install vllm`)." }, ) - vllm_init_kwargs: Optional[dict] = field( - default_factory=lambda: { - "device": "auto", - "gpu_memory_utilization": 0.9, - # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can - # directly reuse the KV cache if it shares the same prefix with one of the existing queries. - # This is particularly useful here because we generate completions from the same prompts. - "enable_prefix_caching": True, - "dtype": "auto" - }, + vllm_device: Optional[str] = field( + default="auto", metadata={ - "help": "Keyword arguments for `vllm.LLM.__init__` when `use_vllm` is true" + "help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system " + "will automatically select the next available GPU after the last one used for training. This assumes " + "that training has not already occupied all available GPUs." }, - ) - vllm_device: Optional[str] = field( - default=None, + ) + vllm_gpu_memory_utilization: float = field( + default=0.9, metadata={ - "help": "Deprecated. Set `device` in `vllm_init_kwargs` instead." + "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV " + "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache " + "size and thus improve the model's throughput. However, if the value is too high, it may cause " + "out-of-memory (OOM) errors during initialization." }, ) - vllm_gpu_memory_utilization: Optional[float] = field( + vllm_dtype: Optional[str] = field( + default="auto", + metadata={ + "help": "Data type to use for vLLM generation. 
If set to 'auto', the data type will be automatically " + "determined based on the model configuration. Find the supported values in the vLLM documentation." + } + ) + vllm_max_model_len: Optional[int] = field( default=None, metadata={ - "help": "Deprecated. Set `gpu_memory_utilization` in `vllm_init_kwargs` instead." + "help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced" + "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model" + "context size, which might be much larger than the KV cache, leading to inefficiencies." }, ) @@ -181,21 +196,4 @@ class GRPOConfig(TrainingArguments): beta: float = field( default=0.04, metadata={"help": "KL coefficient."}, - ) - - def __post_init__(self): - super().__post_init__() - - if self.vllm_device: - warnings.warn( - "`vllm_device` is deprecated. Set `device` in `vllm_init_kwargs` instead.", - DeprecationWarning, - ) - self.vllm_init_kwargs["device"] = self.vllm_device - - if self.vllm_gpu_memory_utilization: - warnings.warn( - "`vllm_gpu_memory_utilization` is deprecated. Set `gpu_memory_utilization` in `vllm_init_kwargs` instead.", - DeprecationWarning, - ) - self.vllm_init_kwargs["gpu_memory_utilization"] = self.vllm_gpu_memory_utilization + ) \ No newline at end of file diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index d4950124df..78d2206b2c 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -287,8 +287,7 @@ def data_collator(features): # No data collation is needed in GRPO ) if self.accelerator.is_main_process: - vllm_init_kwargs = self.args.vllm_init_kwargs - vllm_device = vllm_init_kwargs.get("device") + vllm_device = self.args.vllm_device if vllm_device == "auto": vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx vllm_init_kwargs["device"] = vllm_device @@ -316,7 +315,14 @@ def data_collator(features): # No data collation is needed in GRPO with world_size_patch, profiling_patch: self.llm = LLM( model=model.name_or_path, - **vllm_init_kwargs + device=vllm_device, + gpu_memory_utilization=self.args.vllm_gpu_memory_utilization, + dtype=self.args.vllm_dtype, + # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can + # directly reuse the KV cache if it shares the same prefix with one of the existing queries. + # This is particularly useful here because we generate completions from the same prompts. + enable_prefix_caching=True, + max_model_len=self.args.vllm_max_model_len, ) self.sampling_params = SamplingParams( n=self.num_generations, From adb0da232ec433a0d2ccd05b4f6489930858238e Mon Sep 17 00:00:00 2001 From: Mircea Pricop Date: Tue, 4 Feb 2025 23:35:08 +0000 Subject: [PATCH 09/11] Remove spurious import --- trl/trainer/grpo_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index f9997faa2c..87aff8ec27 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import warnings from dataclasses import dataclass, field from typing import Optional From 9e7030a7d1e78cfe6343a753d53db08ddbcd3221 Mon Sep 17 00:00:00 2001 From: Mircea Pricop Date: Tue, 4 Feb 2025 23:41:07 +0000 Subject: [PATCH 10/11] Remove spurious line --- trl/trainer/grpo_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 27589fce96..f855858cff 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -291,7 +291,6 @@ def data_collator(features): # No data collation is needed in GRPO vllm_device = self.args.vllm_device if vllm_device == "auto": vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx - vllm_init_kwargs["device"] = vllm_device # Check that the requested device is available if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count(): raise ValueError( From d65cec0be58e13c488e1a81938700a159e09b3c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 5 Feb 2025 21:52:11 +0000 Subject: [PATCH 11/11] style --- trl/trainer/grpo_config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 5d805a44d2..e641065968 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -183,13 +183,13 @@ class GRPOConfig(TrainingArguments): metadata={ "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically " "determined based on the model configuration. Find the supported values in the vLLM documentation." - } + }, ) vllm_max_model_len: Optional[int] = field( default=None, metadata={ - "help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced" - "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model" + "help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced " + "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model " "context size, which might be much larger than the KV cache, leading to inefficiencies." }, )
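
---

Usage note (not part of the patches above): the sketch below illustrates how the vLLM options introduced by this series would be set once the final patch is applied. It assumes the public `GRPOConfig`/`GRPOTrainer` API from trl as it stands after PATCH 11/11; the model id, dataset, and reward function are placeholders chosen for illustration, not values taken from the patches.

    # Minimal sketch, assuming the GRPOConfig fields added in this series
    # (use_vllm, vllm_device, vllm_gpu_memory_utilization, vllm_dtype, vllm_max_model_len).
    # Model id, dataset, and reward function are placeholders.
    from datasets import load_dataset
    from trl import GRPOConfig, GRPOTrainer

    def reward_len(completions, **kwargs):
        # Toy reward: prefer completions close to 100 characters.
        return [-abs(len(c) - 100) for c in completions]

    training_args = GRPOConfig(
        output_dir="grpo-vllm-demo",
        use_vllm=True,                    # generate on a dedicated GPU via vLLM
        vllm_device="auto",               # pick the next free GPU after the training GPUs
        vllm_gpu_memory_utilization=0.7,  # reduced to leave headroom on the generation GPU
        vllm_dtype="auto",                # let vLLM infer the dtype from the model config
        vllm_max_model_len=1024,          # cap the context so the smaller KV cache still fits
        max_prompt_length=512,
        max_completion_length=512,
    )

    trainer = GRPOTrainer(
        model="Qwen/Qwen2-0.5B-Instruct",  # placeholder model id
        reward_funcs=reward_len,
        args=training_args,
        train_dataset=load_dataset("trl-lib/tldr", split="train"),  # placeholder dataset
    )
    trainer.train()

Setting `vllm_max_model_len` to `max_prompt_length + max_completion_length` mirrors the intent of PATCH 01/11: with a reduced `vllm_gpu_memory_utilization` the KV cache is smaller, and capping the context length keeps vLLM from assuming the full model context. Note that after PATCH 08/11, `enable_prefix_caching=True` is passed unconditionally inside the trainer rather than exposed as a config field.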