Skip to content

Commit

Permalink
feat(js): add support for json mode in googleai plugin (#568)
Browse files · Browse the repository at this point in the history
  • Loading branch information
huangjeff5 authored Jul 10, 2024
1 parent 3b587e7 commit 9f23469
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 21 deletions.
2 changes: 1 addition & 1 deletion js/plugins/googleai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
"dependencies": {
"@genkit-ai/ai": "workspace:*",
"@genkit-ai/core": "workspace:*",
"@google/generative-ai": "^0.10.0",
"@google/generative-ai": "^0.14.1",
"google-auth-library": "^9.6.3",
"node-fetch": "^3.3.2",
"zod": "^3.22.4"
Expand Down
54 changes: 39 additions & 15 deletions js/plugins/googleai/src/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import {
Content as GeminiMessage,
Part as GeminiPart,
GenerateContentResponse,
GenerationConfig,
GoogleGenerativeAI,
InlineDataPart,
RequestOptions,
Expand Down Expand Up @@ -112,6 +113,7 @@ export const gemini15Pro = modelRef({
media: true,
tools: true,
systemRole: true,
output: ['text', 'json'],
},
versions: ['gemini-1.5-pro-001'],
},
Expand All @@ -127,6 +129,7 @@ export const gemini15Flash = modelRef({
media: true,
tools: true,
systemRole: true,
output: ['text', 'json'],
},
versions: ['gemini-1.5-flash-001'],
},
Expand Down Expand Up @@ -321,7 +324,10 @@ function toGeminiPart(part: Part): GeminiPart {
throw new Error('Unsupported Part type');
}

function fromGeminiPart(part: GeminiPart): Part {
function fromGeminiPart(part: GeminiPart, jsonMode: boolean): Part {
if (jsonMode && part.text !== undefined) {
return { data: JSON.parse(part.text) };
}
if (part.text !== undefined) return { text: part.text };
if (part.inlineData) return fromInlineData(part);
if (part.functionCall) return fromFunctionCall(part);
Expand Down Expand Up @@ -363,12 +369,17 @@ function fromGeminiFinishReason(
}
}

export function fromGeminiCandidate(candidate: GeminiCandidate): CandidateData {
export function fromGeminiCandidate(
candidate: GeminiCandidate,
jsonMode: boolean = false
): CandidateData {
return {
index: candidate.index || 0, // reasonable default?
message: {
role: 'model',
content: (candidate.content?.parts || []).map(fromGeminiPart),
content: (candidate.content?.parts || []).map((part) =>
fromGeminiPart(part, jsonMode)
),
},
finishReason: fromGeminiFinishReason(candidate.finishReason),
finishMessage: candidate.finishMessage,
Expand Down Expand Up @@ -450,33 +461,44 @@ export function googleAIModel(
systemInstruction = toGeminiSystemInstruction(systemMessage);
}
}

const generationConfig: GenerationConfig = {
candidateCount: request.candidates || undefined,
temperature: request.config?.temperature,
maxOutputTokens: request.config?.maxOutputTokens,
topK: request.config?.topK,
topP: request.config?.topP,
stopSequences: request.config?.stopSequences,
responseMimeType:
request.output?.format === 'json' || request.output?.schema
? 'application/json'
: undefined,
};
const chatRequest = {
systemInstruction,
generationConfig,
tools: request.tools?.length
? [{ functionDeclarations: request.tools?.map(toGeminiTool) }]
: [],
history: messages
.slice(0, -1)
.map((message) => toGeminiMessage(message, model)),
generationConfig: {
candidateCount: request.candidates || undefined,
temperature: request.config?.temperature,
maxOutputTokens: request.config?.maxOutputTokens,
topK: request.config?.topK,
topP: request.config?.topP,
stopSequences: request.config?.stopSequences,
},
safetySettings: request.config?.safetySettings,
} as StartChatParams;
const msg = toGeminiMessage(messages[messages.length - 1], model);
const jsonMode =
request.output?.format === 'json' || !!request.output?.schema;
const fromJSONModeScopedGeminiCandidate = (
candidate: GeminiCandidate
) => {
return fromGeminiCandidate(candidate, jsonMode);
};
if (streamingCallback) {
const result = await client
.startChat(chatRequest)
.sendMessageStream(msg.parts);
for await (const item of result.stream) {
(item as GenerateContentResponse).candidates?.forEach((candidate) => {
const c = fromGeminiCandidate(candidate);
const c = fromJSONModeScopedGeminiCandidate(candidate);
streamingCallback({
index: c.index,
content: c.message.content,
Expand All @@ -488,7 +510,8 @@ export function googleAIModel(
throw new Error('No valid candidates returned.');
}
return {
candidates: response.candidates?.map(fromGeminiCandidate) || [],
candidates:
response.candidates?.map(fromJSONModeScopedGeminiCandidate) || [],
custom: response,
};
} else {
Expand All @@ -498,7 +521,8 @@ export function googleAIModel(
if (!result.response.candidates?.length)
throw new Error('No valid candidates returned.');
const responseCandidates =
result.response.candidates?.map(fromGeminiCandidate) || [];
result.response.candidates?.map(fromJSONModeScopedGeminiCandidate) ||
[];
return {
candidates: responseCandidates,
custom: result.response,
Expand Down
10 changes: 5 additions & 5 deletions js/pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 26 additions & 0 deletions js/testapps/flow-simple-ai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,32 @@ export const jokeWithToolsFlow = defineFlow(
}
);

// Structured output contract for jokeWithOutputFlow: the model must return
// a JSON object with a single string field `joke`.
const outputSchema = z.object({
  joke: z.string(),
});

/**
 * Demo flow for JSON-mode output: asks the selected Gemini model for a joke
 * about `input.subject` and returns the model's structured response,
 * validated against `outputSchema` by the flow's output schema.
 *
 * @throws Error if the model did not produce parseable JSON output.
 */
export const jokeWithOutputFlow = defineFlow(
  {
    name: 'jokeWithOutputFlow',
    inputSchema: z.object({
      modelName: z.enum([gemini15Flash.name]),
      subject: z.string(),
    }),
    outputSchema,
  },
  async (input) => {
    const llmResponse = await generate({
      model: input.modelName,
      output: {
        format: 'json',
        schema: outputSchema,
      },
      prompt: `Tell a joke about ${input.subject}.`,
    });
    // output() is null when the model returned no valid structured output.
    // The previous `{ ...llmResponse.output()! }` hid that case: spreading
    // null yields `{}`, silently violating outputSchema. Fail loudly instead.
    const output = llmResponse.output();
    if (output == null) {
      throw new Error('Model did not return the expected JSON output.');
    }
    return output;
  }
);

export const vertexStreamer = defineFlow(
{
name: 'vertexStreamer',
Expand Down

0 comments on commit 9f23469

Please sign in to comment.