diff --git a/trl/models/__init__.py b/trl/models/__init__.py
index 2365e7c1de..94fbf9c188 100644
--- a/trl/models/__init__.py
+++ b/trl/models/__init__.py
@@ -21,6 +21,7 @@
     "modeling_base": ["GeometricMixtureWrapper", "PreTrainedModelWrapper", "create_reference_model"],
     "modeling_value_head": ["AutoModelForCausalLMWithValueHead", "AutoModelForSeq2SeqLMWithValueHead"],
     "utils": ["SUPPORTED_ARCHITECTURES", "prepare_deepspeed", "setup_chat_format", "unwrap_model_for_generation"],
+    "remote_models": ["RemoteModel"],
 }

 try:
@@ -40,6 +41,7 @@
     from .modeling_base import GeometricMixtureWrapper, PreTrainedModelWrapper, create_reference_model
     from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
     from .utils import SUPPORTED_ARCHITECTURES, prepare_deepspeed, setup_chat_format, unwrap_model_for_generation
+    from .remote_models import RemoteModel

 try:
     if not is_diffusers_available():
diff --git a/trl/models/remote_model_app.py b/trl/models/remote_model_app.py
new file mode 100644
index 0000000000..dd593dfbf7
--- /dev/null
+++ b/trl/models/remote_model_app.py
@@ -0,0 +1,94 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import uvicorn
+from fastapi import FastAPI
+from pydantic import BaseModel
+from transformers import AutoModelForCausalLM
+
+from trl import ModelConfig
+
+"""
+Usage (uvicorn serves on port 8000 by default):
+python trl/models/remote_model_app.py --model_name_or_path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
+"""
+
+app = FastAPI()
+model = None
+
+
+class ForwardPassRequest(BaseModel):
+    input_ids: list[list[int]]
+    attention_mask: list[list[int]]
+    logits_to_keep: int
+
+
+@app.post("/forward/")
+async def forward_pass(request: ForwardPassRequest):
+    device = model.device
+    input_ids = torch.LongTensor(request.input_ids).to(device)
+    attention_mask = torch.LongTensor(request.attention_mask).to(device)
+    logits_to_keep = request.logits_to_keep
+    # Perform the forward pass without building a computation graph
+    with torch.no_grad():
+        outputs = model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            logits_to_keep=logits_to_keep,
+        )
+        logits = outputs.logits
+
+    # Move the logits to CPU and convert them to a list for JSON serialization
+    logits_list = logits.cpu().tolist()
+
+    return {"logits": logits_list}
+
+
+@app.get("/health")
+async def health_check():
+    """
+    Provides a health check endpoint for the server.
+
+    Returns:
+        dict: A dictionary indicating the server's health status.
+    """
+    return {"status": "OK"}
+
+
+def init_model(model_config: ModelConfig):
+    global model
+
+    torch_dtype = (
+        model_config.torch_dtype
+        if model_config.torch_dtype in ["auto", None]
+        else getattr(torch, model_config.torch_dtype)
+    )
+    model = AutoModelForCausalLM.from_pretrained(
+        model_config.model_name_or_path,
+        revision=model_config.model_revision,
+        trust_remote_code=model_config.trust_remote_code,
+        attn_implementation=model_config.attn_implementation,
+        torch_dtype=torch_dtype,
+    )
+
+    if torch.cuda.is_available():
+        model.to("cuda")
+        print(f"Model '{model_config.model_name_or_path}' loaded on GPU")
+    else:
+        print(f"Model '{model_config.model_name_or_path}' loaded on CPU")
+
+
+if __name__ == "__main__":
+    from trl import TrlParser
+
+    parser = TrlParser(ModelConfig)
+    model_args = parser.parse_args_and_config()[0]
+    init_model(model_args)
+    uvicorn.run(app)
\ No newline at end of file
diff --git a/trl/models/remote_models.py b/trl/models/remote_models.py
new file mode 100644
index 0000000000..a2167b55b8
--- /dev/null
+++ b/trl/models/remote_models.py
@@ -0,0 +1,80 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import requests
+import torch
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+
+class RemoteModel:
+    def __init__(self, remote_model_url):
+        self.remote_model_url = remote_model_url
+        # Check that the remote server is healthy before accepting any work
+        health_check_url = f"{self.remote_model_url}/health"
+        response = requests.get(health_check_url)
+        if response.status_code != 200:
+            raise Exception(f"Server health check failed: {response.text}")
+
+    def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, logits_to_keep: int) -> CausalLMOutputWithPast:
+        """
+        Sends a request to the remote server to perform a forward pass.
+
+        Args:
+            input_ids (torch.Tensor): The input token IDs.
+            attention_mask (torch.Tensor): The attention mask.
+            logits_to_keep (int): The number of logits to keep.
+
+        Returns:
+            CausalLMOutputWithPast: Contains only the logits.
+        """
+        # Convert tensors to lists for JSON serialization
+        device = input_ids.device
+        input_ids_list = input_ids.tolist()
+        attention_mask_list = attention_mask.tolist()
+
+        # Prepare the request body
+        request_body = {
+            "input_ids": input_ids_list,
+            "attention_mask": attention_mask_list,
+            "logits_to_keep": logits_to_keep,
+        }
+
+        # Send the POST request to the server
+        # TODO: add a few retries?
+        response = requests.post(f"{self.remote_model_url}/forward/", json=request_body)
+
+        # Check for errors
+        if response.status_code != 200:
+            raise Exception(f"Error from server: {response.text}")
+
+        # Parse the response
+        response_json = response.json()
+        logits_list = response_json["logits"]
+
+        # Convert the logits back to a tensor on the caller's device
+        logits = torch.tensor(logits_list).to(device)
+
+        return CausalLMOutputWithPast(logits=logits)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    # Parse command line arguments
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--url", type=str, required=True)
+    args = parser.parse_args()
+
+    remote_model = RemoteModel(args.url)
+    print(remote_model.remote_model_url)
+
+    input_ids = torch.LongTensor([[1, 2, 3]])
+    attention_mask = torch.LongTensor([[1, 1, 1]])
+    logits_to_keep = 1
+    print(remote_model(input_ids, attention_mask, logits_to_keep))
\ No newline at end of file
diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index 0fd0d9f5d2..1a3b81a295 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -78,6 +78,9 @@ class GRPOConfig(TrainingArguments):
             Number of updates steps to accumulate the gradients for, before performing a backward/update pass.
         beta (`float`, *optional*, defaults to `0.04`):
             KL coefficient.
+
+        > Parameters that control remote models
+
+        ref_model_url (`str` or `None`, *optional*, defaults to `None`):
+            URL of the reference model server, if you are using a TRL `RemoteModel` as the reference model.
     """

     # Parameters that control the model and reference model
@@ -174,3 +177,5 @@ class GRPOConfig(TrainingArguments):
         default=0.04,
         metadata={"help": "KL coefficient."},
     )
+
+    ref_model_url: Optional[str] = field(
+        default=None,
+        metadata={"help": "URL of the reference model, if you are using a TRL RemoteModel."},
+    )
\ No newline at end of file
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 38286335ff..6d38054852 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -41,7 +41,7 @@
 from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
 from ..import_utils import is_vllm_available
-from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
+from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation, RemoteModel
 from .grpo_config import GRPOConfig
 from .utils import generate_model_card, get_comet_experiment_url, pad
@@ -198,15 +198,21 @@ def __init__(
             model = get_peft_model(model, peft_config)

         # Reference model
-        if is_deepspeed_zero3_enabled():
-            self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
-        elif peft_config is None:
-            # If PEFT configuration is not provided, create a reference model based on the initial model.
-            self.ref_model = create_reference_model(model)
+        if args.ref_model_url is None:
+            if is_deepspeed_zero3_enabled():
+                self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
+            elif peft_config is None:
+                # If PEFT configuration is not provided, create a reference model based on the initial model.
+                self.ref_model = create_reference_model(model)
+            else:
+                # If PEFT is used, the reference model is not needed since the adapter can be disabled
+                # to revert to the initial model.
+                self.ref_model = None
         else:
-            # If PEFT is used, the reference model is not needed since the adapter can be disabled
-            # to revert to the initial model.
-            self.ref_model = None
+            if peft_config is not None:
+                raise ValueError("You cannot use PEFT and a remote model at the same time")
+
+            self.ref_model = RemoteModel(args.ref_model_url)

         # Processing class
         if processing_class is None:
@@ -348,7 +354,7 @@ def data_collator(features):  # No data collation is needed in GRPO

         # Add tags to the model
         self.model.add_model_tags(self._tag_names)
-        if self.ref_model is not None:
+        if self.ref_model is not None and not isinstance(self.ref_model, RemoteModel):
             if self.is_deepspeed_enabled:
                 self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
             else:
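Usage sketch (not part of the patch): a minimal end-to-end example of how the pieces above are intended to fit together, assuming the server runs on uvicorn's default http://127.0.0.1:8000 and that the other GRPOTrainer arguments (model, dataset, reward functions) are supplied as usual; output_dir below is a placeholder.

# 1) Start the reference-model server in a separate process or on another machine:
#    python trl/models/remote_model_app.py --model_name_or_path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

# 2) Query it directly through the RemoteModel client added above:
import torch

from trl.models import RemoteModel

remote_ref = RemoteModel("http://127.0.0.1:8000")  # raises if the /health check does not return 200
output = remote_ref(
    input_ids=torch.LongTensor([[1, 2, 3]]),
    attention_mask=torch.LongTensor([[1, 1, 1]]),
    logits_to_keep=1,
)
print(output.logits.shape)

# 3) Or let GRPOTrainer use the remote server as its reference model:
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="grpo-remote-ref",  # placeholder
    ref_model_url="http://127.0.0.1:8000",
)
# GRPOTrainer(..., args=training_args) then queries the remote server for reference
# logits instead of loading a second copy of the model locally.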