-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathopenai_utils.py
79 lines (63 loc) · 2.26 KB
/
openai_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import json
import logging
from time import sleep

import numpy as np
import openai
# Read the API credentials once, at import time. The path is relative, so it
# is resolved against the current working directory of the process.
with open("openai_config.json", encoding="utf-8") as f:
    config = json.loads(f.read())
openai.api_key = config["openai_api_key"]
def parse_results(result):
    """Turn the first-token top logprobs of a Completion response into a
    probability distribution over integer labels.

    Only candidate tokens that are decimal digits after stripping surrounding
    whitespace are kept. Their log-probabilities are renormalized with a
    softmax so that the kept probabilities sum to 1.

    Args:
        result: Completion response whose
            ``result["choices"][0]["logprobs"]["top_logprobs"][0]`` maps
            candidate token strings to log-probabilities.

    Returns:
        List of ``(label, probability)`` tuples sorted from most to least
        likely, or an empty list when no candidate token is numeric.

    Note:
        Tokens that differ only in whitespace (e.g. ``"1"`` and ``" 1"``)
        produce separate entries with the same label; they are not merged.
    """
    top_logprobs = result["choices"][0]["logprobs"]["top_logprobs"][0]

    labelled = []
    for token, logprob in top_logprobs.items():
        stripped = token.strip()
        if stripped.isdecimal():
            labelled.append((int(stripped), logprob))
    if not labelled:
        # No numeric candidates; also avoids calling max() on an empty array.
        return []

    labelled.sort(key=lambda pair: pair[1], reverse=True)

    logprobs = np.array([logprob for _, logprob in labelled], dtype=float)
    # softmax(log p) == p / sum(p), i.e. renormalize onto the numeric tokens.
    # Shifting by the max first keeps exp() from underflowing to 0 for very
    # negative logprobs, which would otherwise yield 0/0 = nan.
    weights = np.exp(logprobs - logprobs.max())
    probabilities = weights / weights.sum()
    return [(label, prob) for (label, _), prob in zip(labelled, probabilities)]
def parse_results_chatgpt(result):
    """Parse a ChatCompletion response whose reply should be a bare integer.

    Surrounding whitespace is stripped before the check, consistent with how
    ``parse_results`` treats candidate tokens.

    Args:
        result: ChatCompletion response; the reply is read from
            ``result["choices"][0]["message"]["content"]``.

    Returns:
        ``[(label, 1)]`` when the stripped reply consists only of decimal
        digits (the chat endpoint gives no logprobs here, so the single answer
        gets probability 1), otherwise an empty list. This includes empty or
        missing (``None``) content.
    """
    content = result["choices"][0]["message"]["content"] or ""
    answer = content.strip()
    if not answer.isdecimal():
        return []
    return [(int(answer), 1)]
def predict(prompt, args, *, max_attempts=None, retry_delay=3):
    """Query the legacy Completions endpoint and parse an integer-label distribution.

    The request uses greedy decoding (temperature 0), at most 64 tokens, stops
    at ``"]"`` or ``"."``, and asks for the top-10 logprobs per position. The
    response is handed to ``parse_results``.

    Args:
        prompt: Prompt text sent to the model.
        args: Namespace-like object; only ``args.model`` (the engine name) is read.
        max_attempts: Maximum number of failed requests before the last error
            is re-raised. ``None`` (the default) retries indefinitely.
        retry_delay: Seconds to sleep after a failed request before retrying.

    Returns:
        Whatever ``parse_results`` returns for the response.

    Raises:
        Exception: The last request error, only when ``max_attempts`` is set
            and exhausted.
    """
    logger = logging.getLogger(__name__)
    attempt = 0
    while True:
        try:
            results = openai.Completion.create(
                engine=args.model,
                prompt=prompt,
                max_tokens=64,
                temperature=0.0,
                top_p=1,
                n=1,
                stop=["]", "."],
                logprobs=10,
            )
        # Any failure (rate limit, network, server error) is retried; log it so
        # a persistent error (bad engine name, auth) is visible instead of a
        # silent hang.
        except Exception as err:  # pylint: disable=broad-exception-caught
            attempt += 1
            if max_attempts is not None and attempt >= max_attempts:
                raise
            logger.warning(
                "Completion request failed (attempt %d): %s; retrying in %s s",
                attempt,
                err,
                retry_delay,
            )
            sleep(retry_delay)
        else:
            return parse_results(results)
def predict_chatgpt(
    prompt,
    args,
    *,
    model="gpt-3.5-turbo-0301",
    max_attempts=None,
    retry_delay=3,
):
    """Query the ChatCompletion endpoint and parse a single integer answer.

    The request uses greedy decoding (temperature 0), at most 64 tokens, and
    stops at ``"]"`` or ``"."``. The response is handed to
    ``parse_results_chatgpt``.

    Args:
        prompt: User message text.
        args: Namespace-like object; ``args.sys_instruction`` is read. When it
            is non-empty it is sent as a system message ahead of the user
            message. An empty string or ``None`` means no system message.
        model: Chat model name; defaults to the previously hard-coded snapshot.
        max_attempts: Maximum number of failed requests before the last error
            is re-raised. ``None`` (the default) retries indefinitely.
        retry_delay: Seconds to sleep after a failed request before retrying.

    Returns:
        Whatever ``parse_results_chatgpt`` returns for the response.

    Raises:
        Exception: The last request error, only when ``max_attempts`` is set
            and exhausted.
    """
    messages = [{"role": "user", "content": prompt}]
    if args.sys_instruction:
        messages.insert(0, {"role": "system", "content": args.sys_instruction})

    logger = logging.getLogger(__name__)
    attempt = 0
    while True:
        try:
            results = openai.ChatCompletion.create(
                model=model,
                messages=messages,
                max_tokens=64,
                temperature=0.0,
                top_p=1,
                n=1,
                stop=["]", "."],
            )
        # Any failure is retried; log it so persistent errors are visible.
        except Exception as err:  # pylint: disable=broad-exception-caught
            attempt += 1
            if max_attempts is not None and attempt >= max_attempts:
                raise
            logger.warning(
                "ChatCompletion request failed (attempt %d): %s; retrying in %s s",
                attempt,
                err,
                retry_delay,
            )
            sleep(retry_delay)
        else:
            return parse_results_chatgpt(results)