Commit f25fe0a (1 parent: 9fb9d6d)

Showing 3 changed files with 124 additions and 106 deletions.
The first hunk shown adds a new Next.js API route handler (`@@ -0,0 +1,5 @@`). Only the empty handler skeleton is visible in the hunk; the type import from "next" is implied by the signature and is a likely reconstruction of the remaining added lines:

```ts
import type { NextApiRequest, NextApiResponse } from "next";

export default function handler(req: NextApiRequest, res: NextApiResponse) {
}
```
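The handler body is empty in the commit; a minimal sketch of a filled-in version, assuming the route only needs to acknowledge requests (the response shape below is an assumption, not part of the commit):

```ts
import type { NextApiRequest, NextApiResponse } from "next";

export default function handler(req: NextApiRequest, res: NextApiResponse) {
  // Hypothetical body: acknowledge the request with a small JSON payload.
  res.status(200).json({ ok: true });
}
```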
The second hunk shown rewrites the OpenAI service module (`@@ -1,131 +1,148 @@`). The old version built the client with the v3 SDK's `Configuration` and `OpenAIApi` classes and called `openAi.createImage` (with a `GENERATE_IMAGE_SIZE` constant), `openAi.createModeration`, and `openAi.createChatCompletion`, reading each payload from `response.data`. The new version instantiates the v4 `OpenAI` client directly and calls `openAi.images.generate` (now passing `model: "dall-e-3"`), `openAi.moderations.create`, and `openAi.chat.completions.create`, which return their payloads directly. The resulting file:
```ts
import { TRPCError } from "@trpc/server";
import { OpenAI } from "openai";
import { MAX_PROMPT_LENGTH } from "~/common/constants";
import { env } from "~/env.mjs";

const openAi = new OpenAI({
  apiKey: env.OPENAI_API_KEY,
});

interface GenerateImageOptions {
  prompt: string;
  userId?: string;
  count: number;
}

export interface GeneratedImageData {
  blob: Blob;
  url: string;
  prompt: string;
}

// eslint-disable-next-line @typescript-eslint/no-namespace
export namespace AI {
  // Generates `count` images for the prompt and downloads each one as a Blob.
  export async function generateImages({
    prompt,
    userId,
    count,
  }: GenerateImageOptions) {
    const imageResponse = await openAi.images.generate({
      prompt,
      n: count,
      response_format: "url",
      model: "dall-e-3",
      user: userId,
    });

    const imageData = imageResponse.data;

    if (imageData.length === 0) {
      throw new Error("No images were returned from OpenAI");
    }

    console.log("generated images: ", imageData);

    const imagesPromises: Promise<GeneratedImageData>[] = [];

    for (const d of imageData) {
      const imageUrl = d.url;

      if (imageUrl == null) {
        throw new Error("Image url was null");
      }

      const fetchImage = async () => {
        const res = await fetch(imageUrl);

        if (!res.ok) {
          const error = await res.text();
          throw new Error(
            `Failed to fetch image (${res.statusText}) ${imageUrl}: ${error}`
          );
        }

        const imageBlob = await res.blob();

        return {
          blob: imageBlob,
          url: imageUrl,
          prompt,
        };
      };

      imagesPromises.push(fetchImage());
    }

    const result = await Promise.all(imagesPromises);
    console.log(`${result.length} images were generated for prompt: ${prompt}`);
    return result;
  }

  // Runs the input through the moderation endpoint and reports whether any result was flagged.
  export async function moderateContent(input: string) {
    const moderationResponse = await openAi.moderations.create({
      input,
    });

    const isFlagged = moderationResponse.results.some(
      (x) => x.flagged === true
    );

    return {
      results: moderationResponse.results,
      isFlagged,
    };
  }

  // Asks the chat model to rewrite the prompt into a more detailed version.
  export async function improveImagePrompt({
    prompt,
    userId,
  }: {
    prompt: string;
    userId: string;
  }) {
    console.log(`Prompt to update: ${prompt}`);

    const ERROR_MESSAGE = "[INVALID PROMPT]";
    const response = await openAi.chat.completions.create({
      model: "gpt-3.5-turbo",
      user: userId,
      temperature: 1.6,
      stream: false,
      messages: [
        {
          // We seed the max characters but currently the AI may not be able to infer that
          role: "system",
          content: `You are an assistant that improves image generation prompts, for a given
prompt you MUST return a more detailed version in a single paragraph with less than ${MAX_PROMPT_LENGTH} characters
of text with more details if not specified but if the prompt is not a valid
word or phrase return the text: "${ERROR_MESSAGE}".`,
        },
        {
          role: "assistant",
          content: prompt,
        },
      ],
    });

    const choice = response.choices[0];
    const content = choice?.message?.content;

    console.log(`Updated prompt content: '${content ?? ""}'`);

    if (content == null) {
      throw new TRPCError({
        code: "INTERNAL_SERVER_ERROR",
        message: "Invalid OpenAI response",
      });
    }

    if (content.includes(ERROR_MESSAGE)) {
      throw new TRPCError({
        code: "BAD_REQUEST",
        message: "Not enough context to improve the prompt",
      });
    }

    return content;
  }
}
```
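For context, a sketch of how these exports could be wired together from a tRPC procedure or an API route. The module path and the moderate-improve-generate flow below are assumptions for illustration, not part of the commit:

```ts
import { TRPCError } from "@trpc/server";
import { AI } from "~/server/ai"; // hypothetical path for the module above

export async function createImageForUser(prompt: string, userId: string) {
  // Reject prompts that the moderation endpoint flags.
  const moderation = await AI.moderateContent(prompt);
  if (moderation.isFlagged) {
    throw new TRPCError({
      code: "BAD_REQUEST",
      message: "Prompt was flagged by moderation",
    });
  }

  // Expand the prompt, then request a single image for the user.
  const improvedPrompt = await AI.improveImagePrompt({ prompt, userId });
  const images = await AI.generateImages({
    prompt: improvedPrompt,
    userId,
    count: 1,
  });
  return images[0];
}
```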