adds remote ref models to GRPO
edbeeching committed Feb 4, 2025
1 parent 1f344c9 commit 9b4c4c1
Showing 5 changed files with 197 additions and 10 deletions.
2 changes: 2 additions & 0 deletions trl/models/__init__.py
@@ -21,6 +21,7 @@
"modeling_base": ["GeometricMixtureWrapper", "PreTrainedModelWrapper", "create_reference_model"],
"modeling_value_head": ["AutoModelForCausalLMWithValueHead", "AutoModelForSeq2SeqLMWithValueHead"],
"utils": ["SUPPORTED_ARCHITECTURES", "prepare_deepspeed", "setup_chat_format", "unwrap_model_for_generation"],
"remote_models": ["RemoteModel"],
}

try:
@@ -40,6 +41,7 @@
from .modeling_base import GeometricMixtureWrapper, PreTrainedModelWrapper, create_reference_model
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
from .utils import SUPPORTED_ARCHITECTURES, prepare_deepspeed, setup_chat_format, unwrap_model_for_generation
from .remote_models import RemoteModel

try:
if not is_diffusers_available():
94 changes: 94 additions & 0 deletions trl/models/remote_model_app.py
@@ -0,0 +1,94 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM

from trl import ModelConfig

"""
Usage:
python trl/models/remote_model_app.py --model_name_or_path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

The server listens on uvicorn's default address, http://127.0.0.1:8000.
"""

app = FastAPI()
model = None


class ForwardPassRequest(BaseModel):
    input_ids: list[list[int]]
    attention_mask: list[list[int]]
    logits_to_keep: int


@app.post("/forward/")
async def forward_pass(request: ForwardPassRequest):
    device = model.device
    input_ids = torch.LongTensor(request.input_ids).to(device)
    attention_mask = torch.LongTensor(request.attention_mask).to(device)
    logits_to_keep = request.logits_to_keep
    # Perform the forward pass
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            logits_to_keep=logits_to_keep,
        )
    logits = outputs.logits

    # Move logits to CPU and convert to a nested list for JSON serialization
    logits_list = logits.cpu().tolist()

    return {"logits": logits_list}


@app.get("/health")
async def health_check():
    """
    Provides a health check endpoint for the server.

    Returns:
        dict: A dictionary indicating the server's health status.
    """
    return {"status": "OK"}

def init_model(model_config: ModelConfig):
    global model

    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path,
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
    )

    if torch.cuda.is_available():
        model.to("cuda")
        print(f"Model '{model_config.model_name_or_path}' loaded on GPU")
    else:
        print(f"Model '{model_config.model_name_or_path}' loaded on CPU")

if __name__ == "__main__":
    from trl import TrlParser

    parser = TrlParser(ModelConfig)
    model_args = parser.parse_args_and_config()[0]
    init_model(model_args)
    uvicorn.run(app)
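
Once started, the app can be smoke-tested over plain HTTP. A minimal sketch using requests, assuming the server is running on uvicorn's default address http://127.0.0.1:8000; the token IDs below are arbitrary placeholders:

import requests

base_url = "http://127.0.0.1:8000"

# Verify the model server is up before sending work
assert requests.get(f"{base_url}/health").json() == {"status": "OK"}

# Request logits for a tiny batch
payload = {
    "input_ids": [[1, 2, 3]],
    "attention_mask": [[1, 1, 1]],
    "logits_to_keep": 1,
}
logits = requests.post(f"{base_url}/forward/", json=payload).json()["logits"]
print(len(logits), len(logits[0]))  # batch size, number of kept positions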
80 changes: 80 additions & 0 deletions trl/models/remote_models.py
@@ -0,0 +1,80 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import requests
from transformers.modeling_outputs import CausalLMOutputWithPast

class RemoteModel:
    def __init__(self, remote_model_url):
        self.remote_model_url = remote_model_url
        # Check that the remote server is healthy before use
        health_check_url = f"{self.remote_model_url}/health"
        response = requests.get(health_check_url)
        if response.status_code != 200:
            raise Exception(f"Server health check failed: {response.text}")

    def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, logits_to_keep: int) -> CausalLMOutputWithPast:
        """
        Sends a request to the remote server to perform a forward pass.

        Args:
            input_ids (`torch.Tensor`): The input token IDs.
            attention_mask (`torch.Tensor`): The attention mask.
            logits_to_keep (`int`): The number of logits to keep.

        Returns:
            `CausalLMOutputWithPast`: Contains only the logits.
        """
        # Convert tensors to lists for JSON serialization
        device = input_ids.device
        input_ids_list = input_ids.tolist()
        attention_mask_list = attention_mask.tolist()

        # Prepare the request body
        request_body = {
            "input_ids": input_ids_list,
            "attention_mask": attention_mask_list,
            "logits_to_keep": logits_to_keep,
        }

        # Send the POST request to the server
        # TODO: consider retrying on transient network failures
        response = requests.post(f"{self.remote_model_url}/forward/", json=request_body)

        # Check for errors
        if response.status_code != 200:
            raise Exception(f"Error from server: {response.text}")

        # Parse the response
        response_json = response.json()
        logits_list = response_json["logits"]

        # Convert the logits back to a tensor on the caller's device
        logits = torch.tensor(logits_list).to(device)

        return CausalLMOutputWithPast(logits=logits)

if __name__ == "__main__":
    import argparse

    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, required=True)
    args = parser.parse_args()

    remote_model = RemoteModel(args.url)
    print(remote_model.remote_model_url)
    # Token IDs must be integer tensors, not the float default of torch.Tensor
    input_ids = torch.LongTensor([[1, 2, 3]])
    attention_mask = torch.LongTensor([[1, 1, 1]])
    logits_to_keep = 1
    print(remote_model(input_ids, attention_mask, logits_to_keep))
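
In GRPO, the reference model's logits feed the per-token KL term, so the returned logits are typically reduced to per-token log-probabilities. A minimal sketch of that reduction, assuming the server was asked for logits over the full sequence (logits_to_keep equal to the sequence length); it mirrors, but is not, the trainer's exact implementation:

import torch.nn.functional as F

def per_token_logps(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    # logits[:, i] is the distribution over token i + 1, so shift by one position
    logits = logits[:, :-1, :]   # (batch, seq_len - 1, vocab)
    targets = input_ids[:, 1:]   # the tokens those positions predict
    logps = F.log_softmax(logits, dim=-1)
    return torch.gather(logps, dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)

# Ask the server for logits over the whole sequence, then reduce locally
out = remote_model(input_ids, attention_mask, logits_to_keep=input_ids.size(1))
print(per_token_logps(out.logits, input_ids))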

5 changes: 5 additions & 0 deletions trl/trainer/grpo_config.py
@@ -78,6 +78,9 @@ class GRPOConfig(TrainingArguments):
Number of updates steps to accumulate the gradients for, before performing a backward/update pass.
beta (`float`, *optional*, defaults to `0.04`):
KL coefficient.
> Parameters that control remote models

ref_model_url (`str` or `None`, *optional*, defaults to `None`):
    URL of a server hosting the reference model, if you are using a TRL `RemoteModel`.
"""

# Parameters that control the model and reference model
@@ -174,3 +177,5 @@ class GRPOConfig(TrainingArguments):
default=0.04,
metadata={"help": "KL coefficient."},
)

ref_model_url: Optional[str] = field(
    default=None,
    metadata={"help": "URL of the reference model, if you are using a TRL RemoteModel."},
)
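
In use, the new option is just another keyword on the config. A minimal sketch, assuming a remote_model_app.py server is already serving the reference model locally; the output path is a placeholder:

from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="grpo-remote-ref",           # placeholder
    ref_model_url="http://127.0.0.1:8000",  # server started via remote_model_app.py
    beta=0.04,
)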
26 changes: 16 additions & 10 deletions trl/trainer/grpo_trainer.py
@@ -41,7 +41,7 @@

from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from ..import_utils import is_vllm_available
from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation, RemoteModel
from .grpo_config import GRPOConfig
from .utils import generate_model_card, get_comet_experiment_url, pad

@@ -198,15 +198,21 @@ def __init__(
model = get_peft_model(model, peft_config)

# Reference model
if args.ref_model_url is None:
    if is_deepspeed_zero3_enabled():
        self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
    elif peft_config is None:
        # If PEFT configuration is not provided, create a reference model based on the initial model.
        self.ref_model = create_reference_model(model)
    else:
        # If PEFT is used, the reference model is not needed since the adapter can be disabled
        # to revert to the initial model.
        self.ref_model = None
else:
    if peft_config is not None:
        raise ValueError("You cannot use PEFT and a remote model at the same time")

    self.ref_model = RemoteModel(args.ref_model_url)

# Processing class
if processing_class is None:
@@ -348,7 +354,7 @@ def data_collator(features): # No data collation is needed in GRPO
# Add tags to the model
self.model.add_model_tags(self._tag_names)

if self.ref_model is not None and not isinstance(self.ref_model, RemoteModel):
    if self.is_deepspeed_enabled:
        self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
    else:
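
Taken together, training with a remote reference model becomes a two-process launch. A sketch under assumptions: the model name, dataset, and length-based reward below are illustrative placeholders, not part of this commit:

# Process 1: serve the frozen reference model
#   python trl/models/remote_model_app.py --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct

# Process 2: run GRPO, pointing the trainer at the server
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

def reward_len(completions, **kwargs):
    # Toy reward: prefer shorter completions
    return [-float(len(c)) for c in completions]

dataset = load_dataset("trl-lib/tldr", split="train")
training_args = GRPOConfig(output_dir="grpo-remote-ref", ref_model_url="http://127.0.0.1:8000")
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

Because the reference forward passes run out of process, the trainer skips wrapping the RemoteModel in DeepSpeed or moving it to the accelerator device, as the isinstance check above shows.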
