Skip to content

Commit

Permalink
Nest outpout products under the same interlink output folder, add new…
Browse files Browse the repository at this point in the history
… types of outputs like tts and image to illustrate the test, add seed mechanism and rework a bit arguments passing
  • Loading branch information
pskl committed Nov 20, 2023
1 parent 255fda0 commit 53db732
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 30 deletions.
5 changes: 1 addition & 4 deletions lib/bigfive.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@ class BigFive(test_base.TestBase):

REVERSED_INDICES = [6, 16, 26, 36, 46, 2, 12, 22, 32, 8, 18, 28, 38, 4, 14, 24, 29, 34, 39, 44, 49, 10, 20, 30]

def __init__(self, model, implementation, prompt, samples):
super().__init__(model, implementation, prompt, samples)
if self.prompt == None:
self.prompt = "Lets roleplay and imagine you could answer the following questions with a number from 1 to 5, where 5=disagree, 4=slightly disagree, 3=neutral, 2=slightly agree, and 1=agree. Do not comment on the question and just answer with a number."
DEFAULT_PROMPT = "Lets roleplay and imagine you could answer the following questions with a number from 1 to 5, where 5=disagree, 4=slightly disagree, 3=neutral, 2=slightly agree, and 1=agree. Do not comment on the question and just answer with a number."

def reverse_answer(self, answer):
return 6 - int(answer)
Expand Down
5 changes: 1 addition & 4 deletions lib/pid5.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,7 @@ class Pid5(test_base.TestBase):

REVERSED_INDICES = [7, 30, 35, 58, 87, 90, 96, 97, 98, 131, 142, 155, 164, 177, 210, 215]

def __init__(self, model, implementation, prompt, samples):
super().__init__(model, implementation, prompt, samples)
if self.prompt == None:
self.prompt = "Lets roleplay and imagine you could answer the following questions with a number from 0 to 3 where 0='Very False or Often False', 1='Sometimes or Somewhat False', 2='Sometimes or Somewhat True', 3='Very True or Often True'. Do not comment on the question and just answer with a number please."
DEFAULT_PROMPT = "Lets roleplay and imagine you could answer the following questions with a number from 0 to 3 where 0='Very False or Often False', 1='Sometimes or Somewhat False', 2='Sometimes or Somewhat True', 3='Very True or Often True'. Do not comment on the question and just answer with a number please."

# For items keyed negatively
def reverse_answer(self, answer):
Expand Down
74 changes: 57 additions & 17 deletions lib/test_base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,25 @@
import json
import openai
import os
import random
import requests

class TestBase():
def __init__(self, model, implementation, prompt, samples) -> None:
self.model = model
def __init__(self, args, implementation) -> None:
self.implementation = implementation
self.prompt = prompt
self.samples = samples
self.model = args.model
self.prompt = args.prompt
self.samples = args.samples
self.seed = args.seed
self.prompt = args.prompt
self.tts = args.tts
self.image = args.image

if self.prompt is None:
self.prompt = self.__class__.DEFAULT_PROMPT

def answer_folder_path(self):
return f"answers/interlink_{self.model}_{self.__class__.ID}"

def answer(self):
questions = []
Expand All @@ -17,12 +29,16 @@ def answer(self):
questions.append(question.strip())
answers = []
for (i, question) in enumerate(questions, start=1):
if self.samples is not None and i >= self.samples:
if i >= self.samples:
break
else:
answer = self.implementation.ask_question(question, self.prompt, self.model)

# self.generate_tts(question)
if self.tts:
self.generate_tts(question, i)

if self.image:
self.generate_image(question, answer, i)

if i in self.__class__.REVERSED_INDICES:
answers.append(self.reverse_answer(int(answer)))
Expand All @@ -37,11 +53,11 @@ def answer(self):

# Save test run to json file so that it can be replayed without triggering HTTP requests
def serialize(self, questions, answers):
id = self.__class__.ID
json_file = f'answers/interlink_{self.model}_{id}.json'
os.makedirs(self.answer_folder_path(), exist_ok=True)
json_file = f'{self.answer_folder_path()}/test_{self.seed}.json'
result = {
"model": self.model,
"test": id,
"test": self.__class__.ID,
"prompt": self.prompt,
"answers": []
}
Expand All @@ -57,11 +73,35 @@ def serialize(self, questions, answers):
except Exception as e:
print("Error writing to file: ", e)

def generate_tts(self, question):
speech_file_path = f"speech/{question}.mp3"
response = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY")).audio.speech.create(
model="tts-1",
voice="nova",
input=question
)
response.stream_to_file(speech_file_path)
def generate_tts(self, question, index):
speech_path = f"{self.answer_folder_path()}/speech/"
os.makedirs(speech_path, exist_ok=True)
speech_file_path = f"{speech_path}/question_{index}.mp3"
if not os.path.exists(speech_file_path):
response = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY")).audio.speech.create(
model="tts-1",
voice="nova",
input=question
)
response.stream_to_file(speech_file_path)

def generate_image(self, question, answer, index):
images_path = f"{self.answer_folder_path()}/images"
os.makedirs(images_path, exist_ok=True)
image_file_path = f"{images_path}/question_{index}.png"
if not os.path.exists(image_file_path):
response = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY")).images.generate(
model="dall-e-3",
prompt=f"an illustration of the sentence in which the intensity of what is represented as integer is: {answer}. Here is the sentence: '{question}'. in style of a rorschach test, monochrome, abstract, no visible text, white background",
size="1024x1024",
quality="standard",
n=1,
)
image_url = response.data[0].url
image_response = requests.get(image_url)
if image_response.status_code == 200:
with open(image_file_path, 'wb') as file:
file.write(image_response.content)
print(f"Image saved as {image_file_path}")
else:
print(f"Failed to retrieve image. Status code: {image_response.status_code}")
18 changes: 13 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from lib.pid5 import Pid5
from lib.bigfive import BigFive
import argparse
import random
import os

from lib.implementations import OllamaImpl
from lib.implementations import OpenaiImpl
Expand All @@ -15,8 +17,16 @@
parser.add_argument('--prompt', type=str, default=None,
help='the prompt to use')

parser.add_argument('--samples', type=str, default=None,
help='total number of samples')
parser.add_argument('--image', type=bool, default=False,
help='whether to generate images for each items')

parser.add_argument('--tts', type=bool, default=False,
help='whether to generate tts samples for each items')

parser.add_argument('--samples', type=int, default=220,
help='max number of samples')

parser.add_argument('--seed', type=int, default=int.from_bytes(os.urandom(8), byteorder="big"))

args = parser.parse_args()

Expand All @@ -37,6 +47,4 @@
BigFive.ID: BigFive
}

test = TESTS[args.test](model=args.model, prompt=args.prompt, implementation=implementation, samples=args.samples)

test.answer()
test = TESTS[args.test](args, implementation).answer()

0 comments on commit 53db732

Please sign in to comment.