LLMInterface.py · 133 lines (108 loc) · 4.29 KB

import random
import requests
import itertools
import copy
import os
from Data import load_qafiyas, load_poets, load_bohours
from abc import ABC, abstractmethod
from ibm_cloud_sdk_core.authenticators import IAMAuthenticator


def load_env(file_path):
    """Load environment variables from a .env file."""
    with open(file_path) as f:
        for line in f:
            # Remove comments and whitespace
            line = line.strip()
            if line and not line.startswith('#'):
                key, value = line.split('=', 1)
                os.environ[key] = value
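
# Example of the .env format load_env expects: one KEY=VALUE pair per line,
# values taken verbatim (no quote stripping), lines starting with '#' ignored.
# The key name below is purely illustrative:
#
#     # watsonx.ai credentials
#     WATSONX_API_KEY=abc123
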
BASE_URL = "https://eu-de.ml.cloud.ibm.com/ml/"


class OpenAI_Generator:
    def __init__(self, API_KEY):
        pass

    def generate(self, prompt):
        pass


# Tokens that mark the end of a generated bayt (verse); used as stop sequences.
BAYT_SEPARATORS = ["\n", "*", "#", '/', '.']


class ALLAM_GENERATOR:
    def __init__(self, API_KEY):
        self.model_id = "sdaia/allam-1-13b-instruct"
        self.project_id = "0a443bde-e9c6-41dc-b1f2-65c6292030e4"
        # get authentication token
        authenticator = IAMAuthenticator(API_KEY)
        token = authenticator.token_manager.get_token()
        self.headers = {
            'Accept': 'application/json',
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {token}'
        }
        # set default parameters
        self.parameters = {
            "decoding_method": "sample",
            "max_new_tokens": 15,
            "min_new_tokens": 3,
            "temperature": 0.3,
            "top_k": 40,
            # "top_p": 0.5,
            "repetition_penalty": 1.25,
            "stop_sequences": BAYT_SEPARATORS,
        }
        self.critic_parameters = {
            "decoding_method": "greedy",
            "max_new_tokens": 250,
            # "stop_sequences": ["\n"],
            "stop_sequences": [],
        }

    def generate(self, prompt, is_critic=False, temp=None, stop_tokens=None):
        url = BASE_URL + "v1/text/generation?version=2024-08-30"
        # Start from a copy of the defaults so per-call overrides do not leak between calls.
        params = copy.deepcopy(self.critic_parameters if is_critic else self.parameters)
        # Caller-supplied stop tokens replace the defaults; no argument means no stop sequences.
        params["stop_sequences"] = stop_tokens or []
        body = {
            "input": prompt,
            "model_id": self.model_id,
            "project_id": self.project_id,
            "parameters": params
        }
        if not is_critic:
            # Use the caller's temperature when given (including 0.0); otherwise keep the default.
            body["parameters"]["temperature"] = temp if temp is not None else 0.3
        else:
            # The critic decodes greedily, so the temperature is pinned to 0.
            body["parameters"]["temperature"] = 0.0
        response = requests.post(url, headers=self.headers, json=body)
        response.raise_for_status()
        data = response.json()
        return data['results'][0]['generated_text']
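
    # A minimal usage sketch (comments only, nothing runs at import time). It
    # assumes a `.env` file providing the watsonx.ai API key under a name such
    # as WATSONX_API_KEY; that variable name is illustrative, not defined here:
    #
    #     load_env(".env")
    #     generator = ALLAM_GENERATOR(os.environ["WATSONX_API_KEY"])
    #     text = generator.generate("...", stop_tokens=BAYT_SEPARATORS)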


# A fake LLM for testing
class FakeGenerator:
    def __init__(self, poet=None, wazn=None, qafiya=None):
        self.poets = load_poets()
        self.bohours = load_bohours()
        self.qafiyas = load_qafiyas()
        if poet:  # in Arabic
            if poet not in self.poets:
                raise ValueError(f"Poet not found in database: Could not find {poet} in 'poet.json'")
            self.poet = self.poets[poet]
        else:
            # self.poet = self.poets[random.choice(list(self.poets.keys()))]
            self.poet = self.poets["المتنبي"]
        if wazn:  # in Arabic
            if wazn not in self.bohours:
                raise ValueError(f"Bahr not found in database: Could not find {wazn} in 'bohours.json'")
            if not poet:
                # self.poet = self.poets[random.choice(list(self.poets.keys()))]
                self.poet = self.poets["المتنبي"]
            try:
                self.poems = self.poet['poems'][self.bohours[wazn]['name_en']]
            except KeyError:
                print(f"Poet {self.poet['name']} does not have any poems in {wazn}")
                self.poems = random.choice(list(self.poet['poems'].values()))
        else:
            self.poems = random.choice(list(self.poet['poems'].values()))
        if qafiya:
            # TODO: Implement this
            pass
        # cycle through the poem lines infinitely
        self.poem = itertools.cycle(random.choice(self.poems))

    def generate(self, prompt=None, is_critic=False, temp=0.5, stop_tokens=None):
        # Arguments are accepted for interface compatibility but ignored;
        # each call simply returns the next line of the selected poem.
        return next(self.poem)
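

# Minimal smoke test, guarded so it only runs when this file is executed directly.
# It assumes the JSON data loaded by Data.load_poets/load_bohours/load_qafiyas is
# available locally; the poet name below is the same default used in FakeGenerator.
if __name__ == "__main__":
    fake = FakeGenerator(poet="المتنبي")
    for _ in range(4):
        print(fake.generate())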