Skip to content

Commit

Permalink
Add DPO single node example (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
AVSuni authored Feb 6, 2025
1 parent 62bbdf8 commit 5e7dc96
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,24 @@ Or set these variables yourself with the following command:

## Dependencies
- Secret `hf-token`: Hugging Face API token for model download


## Testing the workload

RayService typically starts a service called `yourworkloadname-serve-svc` at port 8000. For production, you should add an ingress to the service. For testing, you can use `kubectl port-forward` to access the service.

`kubectl port-forward svc/yourworkloadname-multi-serve-svc 8000:8000 -n yournamespace`

Then you can send a request to the service with curl:

```
curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "meta-llama/Llama-3.1-8B-Instruct",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Provide a brief sentence describing the Ray open-source project."}
],
"temperature": 0.7
}'
```
3 changes: 3 additions & 0 deletions workloads/training/LLMs/dpo/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Single-node Direct Preference Optimization (DPO) example with HF Accelerate

Run on 4 GPUs with `kaiwo submit -p workloads/training/LLMs/dpo -g 4 -n yournamespace --storage=100Gi,yourstorageclass`
135 changes: 135 additions & 0 deletions workloads/training/LLMs/dpo/dpo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# This code includes portions from TRL project, licensed under the Apache License 2.0.
# See https://github.com/huggingface/trl for details.

# Copyright 2025 Advanced Micro Devices, Inc. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.nn.attention import SDPBackend, sdpa_kernel

from trl import (
DPOConfig,
DPOTrainer,
ModelConfig,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE




def remove_invalid_entries(example):
"""
Removes NaN and Inf values from the dataset example.
"""
for key, value in example.items():
if isinstance(value, float) and (np.isnan(value) or np.isinf(value)):
return False # Filter out this entry
return True



def main(script_args, training_args, model_args):
################
# Model & Tokenizer
###################
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
)
peft_config = get_peft_config(model_args)
if peft_config is None:
ref_model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
)
else:
ref_model = None
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
if script_args.ignore_bias_buffers:
# torch distributed hack
model._ddp_params_and_buffers_to_ignore = [
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
]

################
# Dataset
################
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
dataset = dataset.filter(remove_invalid_entries)

##########
# Training
################
trainer = DPOTrainer(
model,
ref_model,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
processing_class=tokenizer,
peft_config=peft_config,
)

with sdpa_kernel(SDPBackend.MATH):
trainer.train()

if training_args.eval_strategy != "no":
metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub(dataset_name=script_args.dataset_name)


def make_parser(subparsers: argparse._SubParsersAction = None):
dataclass_types = (ScriptArguments, DPOConfig, ModelConfig)
if subparsers is not None:
parser = subparsers.add_parser("dpo", help="Run the DPO training script", dataclass_types=dataclass_types)
else:
parser = TrlParser(dataclass_types)
return parser


if __name__ == "__main__":
torch.autograd.set_detect_anomaly(True)
parser = make_parser()
script_args, training_args, model_args = parser.parse_args_and_config()
main(script_args, training_args, model_args)
16 changes: 16 additions & 0 deletions workloads/training/LLMs/dpo/entrypoint
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
accelerate launch mounted/dpo.py \
--dataset_name trl-lib/ultrafeedback_binarized \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--learning_rate 5.0e-6 \
--num_train_epochs 1 \
--per_device_train_batch_size 8 \
--logging_steps 25 \
--eval_strategy steps \
--eval_steps 50 \
--output_dir Qwen2-0.5B-DPO \
--no_remove_unused_columns \
--use_peft \
--lora_r 32 \
--lora_alpha 16 \
--bf16 \
--optim="adamw_torch"

0 comments on commit 5e7dc96

Please sign in to comment.