-
Notifications
You must be signed in to change notification settings - Fork 2
/
configData.py
517 lines (436 loc) · 19 KB
/
configData.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
import logging
import os
import sys
import time
from dotenv import load_dotenv
import openai
from openai import AzureOpenAI
from langchain.chains.question_answering import load_qa_chain
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
logging.basicConfig(
    format="%(asctime)s %(levelname)-4s [%(filename)s:%(lineno)d] - %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%S%z",
    level=logging.INFO,
)

# Folder containing the caption/transcript files for the videos.
captionsFolder: str = "Captions"
# Root folder for intermediate pipeline artifacts; one sub-folder per data type below.
saveFolder: str = "savedData"
savedFileTypes: list = ["transcriptData", "topicModel", "topicsOverTime", "questionData"]
# Folder for the final generated output.
outputFolder: str = "Output Data"
representationModelType: str = "langchain"
# Toggle for using KeyBERT vectorization in the BERTopic Model. Default is True.
useKeyBERT: bool = True
# Minimum threshold for the video duration in seconds for processing.
# Shorter videos might not have enough content to generate meaningful topics and questions.
# Default is 300s.
minVideoLength: int = 300
# Because of how BERTopic works, there needs to be a minimum number of sentences per window of time.
# When attempting to automatically segment the transcript into sentences using spaCy,
# If there are sentences longer than this duration, the transcript data might not be suitable for this process.
# Instead, the transcript will be segmented based on the `WINDOW_SIZE`.
# Defaults to 120s, must be >=60.
maxSentenceDuration: int = 120
# This sets the numbers of times a topic must appear in a given region of time to be considered a valid topic.
# This is used to filter out topics that are not relevant to the video content.
# Default is 2. Higher values will result in fewer questions possibly being generated.
minTopicFrequency: int = 2

# Ensure every save sub-folder exists before the pipeline runs.
for folder in savedFileTypes:
    folderPath = os.path.join(saveFolder, folder)
    try:
        # exist_ok avoids the check-then-create race of the exists()/makedirs() pair.
        os.makedirs(folderPath, exist_ok=True)
    except OSError:
        logging.error(f"Creation of the directory {folderPath} failed.")
        # No f-string needed here: the message has no placeholders.
        sys.exit("Directory creation failure. Exiting...")
class configVars:
    """
    Container for the runtime configuration of the question-generation pipeline.

    Defaults are assigned in ``__init__`` and may be overridden from a ``.env``
    file via :meth:`setFromEnv`. Every imported value is recorded in
    ``envImportSuccess`` (keyed by env-var name) so that all configuration
    problems can be reported in a single run before the program exits.
    """

    def __init__(self):
        # Logging verbosity; may be overridden by the LOG_LEVEL env var.
        self.logLevel = logging.INFO
        # Azure OpenAI credentials/settings, filled in from the environment.
        self.openAIParams: dict = {
            "KEY": "",
            "BASE": "",
            "VERSION": "",
            "MODEL": "",
            "ORGANIZATION": "",
        }
        # Name of the video (caption set) to process.
        self.videoToUse: str = ""
        # Number of questions to generate; validated as > 0 or the sentinel -1.
        self.questionCount: int = 3
        # Question-generation backend: "BERTopic" or "LangChain".
        self.generationModel: str = "BERTopic"
        # Maps env-var names to True/False import outcomes.
        self.envImportSuccess: dict = {}
        # BERTopic-specific Parameters
        self.windowSize: int = 30
        self.contextWindowSize: int = 600
        self.overwriteTranscriptData: bool = False
        self.overwriteTopicModel: bool = False
        self.overwriteQuestionData: bool = False
        self.langchainPrompt: str = (
            "Give a single label that is only a few words long to summarize what these documents are about."
        )
        self.questionPrompt: str = (
            "You are a question-generating bot that generates questions for a given topic based on the provided relevant trancription text from a video."
        )

    def set(self, name, value):
        """
        Set the value of a configuration parameter.

        Args:
            name (str): The name of the configuration parameter.
            value: The value to be set.

        Raises:
            NameError: If the name is not an existing configuration attribute.
        """
        if name in self.__dict__:
            # Bug fix: the original did `self.name = value`, which always
            # assigned the literal attribute "name" instead of the attribute
            # named by the parameter.
            setattr(self, name, value)
        else:
            raise NameError("Name not accepted in set() method")

    def configFetch(
        self, name, default=None, casting=None, validation=None, valErrorMsg=None
    ):
        """
        Fetch a configuration parameter from the environment variables.

        Args:
            name (str): The environment variable to read.
            default: Value used when the variable is absent.
            casting (type): Optional type to cast the raw value to.
            validation (callable): Optional predicate the cast value must satisfy.
            valErrorMsg (str): Unused; kept for interface compatibility.

        Returns:
            The cast, validated value, or None on a casting/validation error.
        """
        value = os.environ.get(name, default)
        if casting is not None:
            try:
                if casting is bool:
                    # Env values arrive as strings; "0"/"1" must go through
                    # int() first, otherwise bool("0") would be True.
                    value = int(value)
                value = casting(value)
            except ValueError:
                errorMsg = f'Casting error for config item "{name}" value "{value}".'
                logging.error(errorMsg)
                return None
        if validation is not None and not validation(value):
            errorMsg = f'Validation error for config item "{name}" value "{value}".'
            logging.error(errorMsg)
            return None
        return value

    def setFromEnv(self):
        """
        Set configuration parameters from environment variables.

        Reads all configuration from the .env file, validates each value, and
        records the outcome in ``envImportSuccess`` so every problem is
        reported in one pass. Exits the process if the .env file is missing
        or any import failed.
        """
        if not os.path.exists(".env"):
            logging.error(
                "No .env file found. Please configure your environment variables use the .env.sample file as a template."
            )
            sys.exit("Missing .env file. Exiting...")

        # Force the environment variables to be read from the .env file every time.
        load_dotenv(".env", override=True)

        try:
            self.logLevel = str(os.environ.get("LOG_LEVEL", self.logLevel)).upper()
        except ValueError:
            warnMsg = f"Casting error for config item LOG_LEVEL value. Defaulting to {logging.getLevelName(logging.root.level)}."
            logging.warning(warnMsg)
        try:
            logging.getLogger().setLevel(logging.getLevelName(self.logLevel))
        except ValueError:
            warnMsg = f"Validation error for config item LOG_LEVEL value. Defaulting to {logging.getLevelName(logging.root.level)}."
            logging.warning(warnMsg)

        # Currently the code will check and validate all config variables before stopping.
        # Reduces the number of runs needed to validate the config variables.
        for credPart in self.openAIParams:
            if credPart == "BASE":
                envVarName = "AZURE_OPENAI_ENDPOINT"
            else:
                envVarName = "OPENAI_API_" + credPart
            self.openAIParams[credPart] = self.configFetch(
                envVarName,
                self.openAIParams[credPart],
                str,
                lambda param: len(param) > 0,
            )
            # Bug fix: key the success map by the env-var *name*. The original
            # keyed it by the fetched *value*, so all failures (None) collided
            # on a single key and the map lost track of which import failed.
            self.envImportSuccess[envVarName] = bool(self.openAIParams[credPart])

        if len(self.videoToUse) == 0:
            self.videoToUse = self.configFetch(
                "VIDEO_TO_USE",
                self.videoToUse,
                str,
                lambda name: len(name) > 0,
            )
            self.envImportSuccess["VIDEO_TO_USE"] = bool(self.videoToUse)

        self.questionCount = self.configFetch(
            "QUESTION_COUNT",
            self.questionCount,
            int,
            lambda x: x > 0 or x == -1,
        )
        self.envImportSuccess["QUESTION_COUNT"] = bool(self.questionCount)

        self.generationModel = self.configFetch(
            "GENERATION_MODEL",
            self.generationModel,
            str,
            lambda model: model.lower() in ["bertopic", "langchain"],
        )
        # This should allow for the model to be set to either 'BERTopic' or
        # 'LangChain' in the .env file without being case-sensitive.
        if self.generationModel:
            self.generationModel = {
                "bertopic": "BERTopic",
                "langchain": "LangChain",
            }.get(self.generationModel.lower(), None)
        self.envImportSuccess["GENERATION_MODEL"] = bool(self.generationModel)

        self.overwriteTranscriptData = self.configFetch(
            "OVERWRITE_EXISTING_TRANSCRIPT",
            self.overwriteTranscriptData,
            bool,
            None,
        )
        # A non-bool here means the fetch/cast failed (returned None).
        self.envImportSuccess["OVERWRITE_EXISTING_TRANSCRIPT"] = isinstance(
            self.overwriteTranscriptData, bool
        )

        self.overwriteQuestionData = self.configFetch(
            "OVERWRITE_EXISTING_QUESTIONS",
            self.overwriteQuestionData,
            bool,
            None,
        )
        self.envImportSuccess["OVERWRITE_EXISTING_QUESTIONS"] = isinstance(
            self.overwriteQuestionData, bool
        )

        # Regenerating the transcript invalidates anything derived from it.
        if self.overwriteTranscriptData:
            self.overwriteQuestionData = True
            logging.info(
                "Generated Question data will also be overwritten as Transcript data is being overwritten."
            )

        # Pushed BERTopic variables out to a seperate method to keep the main method clean.
        if self.generationModel == "BERTopic":
            self.setBERTopicVarsFromEnv()

        if False in self.envImportSuccess.values():
            sys.exit("Configuration parameter import problems. Exiting...")
        logging.info("All configuration parameters set up successfully.")

    def setBERTopicVarsFromEnv(self):
        """
        Sets the BERTopic variables from the environment configuration.

        Fetches and validates the BERTopic-specific values, updates the
        corresponding instance attributes, and records each outcome in
        ``envImportSuccess`` (keyed by env-var name).

        Returns:
            None
        """
        self.windowSize = self.configFetch(
            "WINDOW_SIZE",
            self.windowSize,
            int,
            lambda x: x > 0,
        )
        self.envImportSuccess["WINDOW_SIZE"] = bool(self.windowSize)

        self.contextWindowSize = self.configFetch(
            "RELEVANT_TEXT_CONTEXT_WINDOW",
            self.contextWindowSize,
            int,
            lambda x: x >= 0,
        )
        # 0 is a valid context window, so only None (fetch failure) is a failure.
        self.envImportSuccess["RELEVANT_TEXT_CONTEXT_WINDOW"] = (
            self.contextWindowSize is not None
        )

        # NOTE(review): these two overwrite flags are also fetched in
        # setFromEnv(); the re-fetch is kept for behavioral parity.
        self.overwriteTranscriptData = self.configFetch(
            "OVERWRITE_EXISTING_TRANSCRIPT",
            self.overwriteTranscriptData,
            bool,
            None,
        )
        self.envImportSuccess["OVERWRITE_EXISTING_TRANSCRIPT"] = isinstance(
            self.overwriteTranscriptData, bool
        )

        self.overwriteQuestionData = self.configFetch(
            "OVERWRITE_EXISTING_QUESTIONS",
            self.overwriteQuestionData,
            bool,
            None,
        )
        self.envImportSuccess["OVERWRITE_EXISTING_QUESTIONS"] = isinstance(
            self.overwriteQuestionData, bool
        )

        self.langchainPrompt = self.configFetch(
            "LANGCHAIN_PROMPT",
            self.langchainPrompt,
            str,
            lambda prompt: len(prompt) > 0,
        )
        self.envImportSuccess["LANGCHAIN_PROMPT"] = bool(self.langchainPrompt)

        self.questionPrompt = self.configFetch(
            "QUESTION_PROMPT",
            self.questionPrompt,
            str,
            lambda prompt: len(prompt) > 0,
        )
        self.envImportSuccess["QUESTION_PROMPT"] = bool(self.questionPrompt)

        # This checks to set data in the later stages to be overwritten if the earlier stages are set to be overwritten.
        if self.overwriteTranscriptData:
            self.overwriteTopicModel = True
            logging.info(
                "Topic Model data will also be overwritten as Transcript data is being overwritten."
            )
        if self.overwriteTopicModel:
            self.overwriteQuestionData = True
            logging.info(
                "Generated Question data will also be overwritten as Topic Model data is being overwritten."
            )
class OpenAIBot:
    """
    A class representing an OpenAI chatbot.

    Attributes:
        config (object): The configuration object for the chatbot.
        messages (list): A list to store the chat messages.
        model (str): The OpenAI model to use for generating responses.
        systemPrompt (str): The system prompt to include in the chat messages.
        client (object): The AzureOpenAI client for making API calls.
        tokenUsage (int): The total number of tokens used by the chatbot.
        callMaxLimit (int): The maximum number of API call attempts allowed.

    Methods:
        getResponse(prompt): Generates a response for the given prompt.
    """

    def __init__(self, config):
        """
        Initializes a new instance of the OpenAIBot class.

        Args:
            config (object): The configuration object for the chatbot.
        """
        self.config = config
        self.messages = []
        self.model = self.config.openAIParams["MODEL"]
        self.systemPrompt = self.config.questionPrompt
        self.client = AzureOpenAI(
            api_key=self.config.openAIParams["KEY"],
            api_version=self.config.openAIParams["VERSION"],
            azure_endpoint=self.config.openAIParams["BASE"],
            organization=self.config.openAIParams["ORGANIZATION"],
        )
        self.tokenUsage = 0
        # Maximum number of API call attempts before giving up.
        self.callMaxLimit = 3

    def getResponse(self, prompt):
        """
        Generates a response for the given prompt.

        Args:
            prompt (str): The user prompt for the chatbot.

        Returns:
            tuple: A tuple containing the response text and a boolean indicating if the response was successful.

        Exits:
            On invalid credentials, or after ``callMaxLimit`` failed attempts.
        """
        callComplete = False
        callAttemptCount = 0
        response = None
        while not callComplete and callAttemptCount < self.callMaxLimit:
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {
                            "role": "system",
                            "content": self.systemPrompt,
                        },
                        {"role": "user", "content": prompt},
                    ],
                    temperature=0,
                    stop=None,
                )
                # Brief pause between calls to stay under rate limits.
                time.sleep(1)
                callComplete = True
            except openai.AuthenticationError as e:
                logging.error(f"Error Message: {e}")
                sys.exit("Invalid OpenAI credentials. Exiting...")
            except openai.RateLimitError as e:
                logging.error(f"Error Message: {e}")
                logging.error("Rate limit hit. Pausing for a minute.")
                time.sleep(60)
                callComplete = False
            # Bug fix: openai.Timeout does not exist in openai>=1.0 (which
            # provides AzureOpenAI); the timeout exception is APITimeoutError.
            except openai.APITimeoutError as e:
                logging.error(f"Error Message: {e}")
                logging.error("Timed out. Pausing for a minute.")
                time.sleep(60)
                callComplete = False
            except Exception as e:
                logging.error(f"Error Message: {e}")
                logging.error("Failed to send message. Trying again.")
                callComplete = False
            callAttemptCount += 1

        # Bug fix: check success *before* the attempt-limit. The original
        # tested callAttemptCount first, so a call that succeeded on the final
        # allowed attempt still exited with "Too many failed attempts".
        if callComplete:
            responseText = response.choices[0].message.content
            self.tokenUsage += response.usage.total_tokens
            return responseText, True

        logging.error(
            f"Failed to send message at max limit of {self.callMaxLimit} times."
        )
        sys.exit("Too many failed attempts. Exiting...")
class LangChainBot:
    """
    Wraps the LangChain Azure OpenAI client plus the helper object required by
    the configured question-generation backend.
    """

    def __init__(self, config):
        """
        Initializes an instance of the LangChainBot class.

        Args:
            config (dict): A dictionary containing configuration parameters.

        Attributes:
            config (dict): The configuration parameters.
            model (str): The model specified in the configuration parameters.
            client (None): The client object (initially set to None).
            embeddings (None): Used only in LangChain-based Question Generation.
            chain (None): Used only in BERTopic-based Question Generation.
            tokenUsage (int): The token usage count.
        """
        self.config = config
        self.model = config.openAIParams["MODEL"]
        self.client = None
        # Exactly one of the following two is populated by initialize(),
        # depending on config.generationModel.
        self.embeddings = None  # LangChain-based question generation only
        self.chain = None  # BERTopic-based question generation only
        self.tokenUsage = 0
        self.initialize()

    def initialize(self):
        """
        Builds the client, then the helper object that matches the configured
        generation model; exits on an unrecognized model name.
        """
        self.initializeClient()
        chosenModel = self.config.generationModel
        if chosenModel == "BERTopic":
            self.initializeChain()
            return
        if chosenModel == "LangChain":
            self.initializeEmbeddings()
            return
        logging.error(
            f"Invalid generation model specified: {self.config.generationModel}, valid options are 'BERTopic' and 'LangChain'."
        )
        sys.exit("Invalid generation model specified. Exiting...")

    def initializeClient(self):
        """
        Creates the AzureChatOpenAI client from the configured credentials.
        """
        params = self.config.openAIParams
        self.client = AzureChatOpenAI(
            api_key=params["KEY"],
            api_version=params["VERSION"],
            azure_endpoint=params["BASE"],
            organization=params["ORGANIZATION"],
            azure_deployment=params["MODEL"],
            temperature=0,
        )

    def initializeEmbeddings(self):
        """
        Creates the embeddings object used by LangChain-based question generation.
        """
        params = self.config.openAIParams
        self.embeddings = AzureOpenAIEmbeddings(
            api_key=params["KEY"],
            api_version=params["VERSION"],
            azure_endpoint=params["BASE"],
            organization=params["ORGANIZATION"],
            # Hard-coded embedding deployment: using the chat deployment name
            # (e.g. 'gpt-4') here does not work, as that model is reserved for
            # the client above.
            azure_deployment="text-embedding-ada-002",  # This does not work if set to 'gpt-4', but seems to related to 'gpt-4' being the model used in the client.
        )

    def initializeChain(self):
        """
        Creates the QA chain used by BERTopic-based question generation.
        """
        self.chain = load_qa_chain(
            self.client,
            chain_type="stuff",
        )