-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsharegpt-to-dpo.py
233 lines (204 loc) · 7.54 KB
/
sharegpt-to-dpo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import argparse
import json
import pathlib
import time
from functools import lru_cache
import httpx
from datasets import Dataset, load_dataset
def infer_completion(prompt, genparams: dict):
    """Generate a response from the API's /v1/completions endpoint.

    Streams SSE chunks and concatenates the text deltas. On any non-200
    response the request is retried forever with a 5-second backoff.

    Args:
        prompt: Fully formatted prompt string to send.
        genparams: Extra generation parameters merged into the request body.

    Returns:
        The concatenated generated text.
    """
    while True:
        request = httpx.stream(
            "POST",
            f"{API_URL}/v1/completions",
            headers={"authorization": API_KEY, "x-api-key": API_KEY},
            json={"prompt": prompt, "stream": True, "model": MODEL, **genparams},
            timeout=None,
        )
        with request as response:
            if response.status_code != 200:
                # Endpoint busy or unavailable -- back off and retry
                time.sleep(5)
                continue
            generated_text = ""
            for chunk in response.iter_lines():
                if not chunk.startswith("data: "):
                    continue
                # Strip only the leading SSE prefix; str.replace would also
                # clobber any literal "data: " inside the generated text
                payload = chunk[len("data: "):]
                if payload == "[DONE]":
                    break
                chunk_data = json.loads(payload)
                if "choices" in chunk_data:
                    generated_text += chunk_data["choices"][0]["text"]
            return generated_text
def infer_chat_completion(messages, genparams: dict):
    """Generate a response from the API's /v1/chat/completions endpoint.

    Streams SSE chunks and concatenates the delta contents. On any non-200
    response the request is retried forever with a 5-second backoff.

    Args:
        messages: Chat messages as a list of {"role", "content"} dicts.
        genparams: Extra generation parameters merged into the request body.

    Returns:
        The concatenated generated text.
    """
    while True:
        request = httpx.stream(
            "POST",
            f"{API_URL}/v1/chat/completions",
            headers={"authorization": API_KEY, "x-api-key": API_KEY},
            json={
                "messages": messages,
                "stream": True,
                "model": MODEL,
                "add_generation_prompt": True,
                **genparams,
            },
            timeout=None,
        )
        with request as response:
            if response.status_code != 200:
                # Endpoint busy or unavailable -- back off and retry
                time.sleep(5)
                continue
            generated_text = ""
            for chunk in response.iter_lines():
                if not chunk.startswith("data: "):
                    continue
                # Strip only the leading SSE prefix; str.replace would also
                # clobber any literal "data: " inside the generated text
                payload = chunk[len("data: "):]
                # The terminal sentinel is not JSON -- check before parsing
                if payload == "[DONE]":
                    break
                chunk_data = json.loads(payload)
                if "choices" not in chunk_data:
                    continue
                choice = chunk_data["choices"][0]
                if choice.get("finish_reason"):
                    break
                # Role-only / empty deltas carry no "content" key
                generated_text += choice.get("delta", {}).get("content") or ""
            return generated_text
def format_prompt_jinja(
    messages, template: str, add_generation_prompt: bool, special_tokens: dict
):
    """Render the given messages through a Jinja2 chat template.

    Args:
        messages: Chat messages as a list of {"role", "content"} dicts.
        template: Jinja2 template source text.
        add_generation_prompt: Whether the template should append the
            assistant generation prompt.
        special_tokens: Extra template variables (e.g. {"eos_token": ...}).

    Returns:
        The rendered prompt string.
    """
    return _compile_template(template).render(
        messages=messages,
        add_generation_prompt=add_generation_prompt,
        **special_tokens,
    )
# Inspired from
# https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
@lru_cache
def _compile_template(template: str):
    """Compile a Jinja2 template string, caching one compiled template per
    unique source string."""

    def raise_exception(message):
        # Lets templates call raise_exception(...) like HF chat templates do
        raise TemplateError(message)

    env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
    env.globals["raise_exception"] = raise_exception
    return env.from_string(template)
def get_template_from_file(template_path_raw: str):
    """Read a Jinja template file and return its contents as a string.

    Args:
        template_path_raw: Path to the template file.

    Raises:
        FileNotFoundError: If the path does not exist.
    """
    template_path = pathlib.Path(template_path_raw)
    if not template_path.exists():
        raise FileNotFoundError(f'Template "{template_path_raw}" not found.')
    return template_path.read_text(encoding="utf8")
def process(data):
    """Primary dataset building function.

    Extracts the first system/user/assistant exchange from a ShareGPT row,
    asks the API to generate a counterpart response, and fills in the
    system/prompt/chosen/rejected DPO columns. Rows whose conversation does
    not start with a user turn get all four columns set to None so they can
    be filtered out afterwards.
    """
    system = ""
    prompt = ""
    chosen = ""
    rejected = ""
    convo = []

    for turn in data["conversations"]:
        role = turn["from"]
        if role == "system" and not system:
            system = turn["value"].strip()
            convo.append({"role": "system", "content": system})
        elif role == "human" and not prompt:
            prompt = turn["value"].strip()
            convo.append({"role": "user", "content": prompt})
        elif role == "gpt" and not prompt:
            # Malformed conversation: the model speaks before the user does
            print("\nWARNING: Conversation does not begin with user turn - skipping.")
            data["system"] = None
            data["prompt"] = None
            data["chosen"] = None
            data["rejected"] = None
            return data
        elif role == "gpt" and not chosen:
            chosen = turn["value"].strip()
            break  # Only the first exchange is used

    if CHAT_COMPLETION:
        rejected = infer_chat_completion(convo, GEN_PARAMS)
    else:
        if PROMPT_TEMPLATE:
            # We don't need BOS token here because infer_completion already
            # asks the endpoint to add it
            constructed_prompt = format_prompt_jinja(
                convo,
                PROMPT_TEMPLATE,
                True,
                {"eos_token": EOS_TOKEN},
            )
        elif system:
            # Default Mistral prompt fallback (with system message)
            constructed_prompt = f"[INST] {system}\n\n{prompt} [/INST]"
        else:
            # Default Mistral prompt fallback
            constructed_prompt = f"[INST] {prompt} [/INST]"
        rejected = infer_completion(constructed_prompt, GEN_PARAMS)

    data["system"] = system
    data["prompt"] = prompt
    if args.chosen:
        # Reverse chosen and rejected if requested: the API output becomes
        # the preferred answer and the dataset answer the rejected one
        data["chosen"], data["rejected"] = rejected, chosen
    else:
        data["chosen"], data["rejected"] = chosen, rejected
    return data
# Setup args: load endpoint configuration from config.json next to this script
script_dir = pathlib.Path(__file__).parent.resolve()
conf_path = script_dir / "config.json"
config = json.loads(conf_path.read_text())

API_URL = config.get("api_url", "http://127.0.0.1:5000")
API_KEY = config.get("api_key", None)
MODEL = config.get("model", "gpt-3.5-turbo")
CHAT_COMPLETION = config.get("chat_completion", False)
EOS_TOKEN = config.get("eos_token", "</s>")
GEN_PARAMS = config.get("gen_params", {})

# Don't use EOS token in jinja2 template if user bans EOS token
if GEN_PARAMS.get("ban_eos_token"):
    EOS_TOKEN = None
# Command-line interface
parser = argparse.ArgumentParser(description="ShareGPT to DPO dataset creator")
parser.add_argument(
    "datafile",
    type=str,
    help="Dataset file in ShareGPT format, accepts .json/.jsonl/.parquet",
)
parser.add_argument(
    "-t", "--template",
    type=str,
    default=None,
    help="(Optional) Prompt template in Jinja2 format",
)
parser.add_argument(
    "-c", "--chosen",
    action="store_true",
    help="Generate 'chosen' response for DPO instead of 'rejected'",
)
args = parser.parse_args()
file = pathlib.Path(args.datafile)

# Map the input file extension to a `datasets` loader type. Compare the
# exact suffix -- endswith(".json") would wrongly match e.g. ".geojson".
datatype = None
if file.suffix in (".json", ".jsonl"):
    datatype = "json"
elif file.suffix == ".parquet":
    datatype = "parquet"

PROMPT_TEMPLATE = None
if args.template:
    template_file = pathlib.Path(args.template)
    try:
        # jinja2 names are imported at module scope so _compile_template
        # and format_prompt_jinja can use them
        from jinja2 import TemplateError
        from jinja2.sandbox import ImmutableSandboxedEnvironment

        PROMPT_TEMPLATE = get_template_from_file(template_file)
    except ImportError:
        # jinja2 isn't installed -- fall back to the default formatter
        print("jinja2 template not available, using default prompt formatter (Mistral)")
    except OSError as err:
        # Template file missing/unreadable: report the real cause instead of
        # misattributing it to jinja2, then fall back
        print(f"Could not read template ({err}), using default prompt formatter (Mistral)")
# Load the dataset, generate the counterpart responses, and keep only the
# DPO columns
dataset = load_dataset(datatype, data_files=str(file))
dataset = dataset.map(process)
dataset = dataset.select_columns(["system", "prompt", "chosen", "rejected"])

# Drop rows that process() flagged as skipped (prompt set to None)
kept_rows = [row for row in dataset["train"] if row.get("prompt")]
dataset = Dataset.from_list(kept_rows)

# Write the result next to the input file as <name>-dpo.jsonl
dataset.to_json(f"{file.stem}-dpo.jsonl")