diff --git a/js/plugins/googleai/package.json b/js/plugins/googleai/package.json
index 342589d7a..88f70fca7 100644
--- a/js/plugins/googleai/package.json
+++ b/js/plugins/googleai/package.json
@@ -33,7 +33,7 @@
   "dependencies": {
     "@genkit-ai/ai": "workspace:*",
     "@genkit-ai/core": "workspace:*",
-    "@google/generative-ai": "^0.10.0",
+    "@google/generative-ai": "^0.14.1",
     "google-auth-library": "^9.6.3",
     "node-fetch": "^3.3.2",
     "zod": "^3.22.4"
diff --git a/js/plugins/googleai/src/gemini.ts b/js/plugins/googleai/src/gemini.ts
index 03a15b2ab..cedbbcbf2 100644
--- a/js/plugins/googleai/src/gemini.ts
+++ b/js/plugins/googleai/src/gemini.ts
@@ -44,6 +44,7 @@ import {
   Content as GeminiMessage,
   Part as GeminiPart,
   GenerateContentResponse,
+  GenerationConfig,
   GoogleGenerativeAI,
   InlineDataPart,
   RequestOptions,
@@ -112,6 +113,7 @@ export const gemini15Pro = modelRef({
       media: true,
       tools: true,
       systemRole: true,
+      output: ['text', 'json'],
     },
     versions: ['gemini-1.5-pro-001'],
   },
@@ -127,6 +129,7 @@ export const gemini15Flash = modelRef({
       media: true,
       tools: true,
       systemRole: true,
+      output: ['text', 'json'],
     },
     versions: ['gemini-1.5-flash-001'],
   },
@@ -321,7 +324,10 @@ function toGeminiPart(part: Part): GeminiPart {
   throw new Error('Unsupported Part type');
 }
 
-function fromGeminiPart(part: GeminiPart): Part {
+function fromGeminiPart(part: GeminiPart, jsonMode: boolean): Part {
+  if (jsonMode && part.text !== undefined) {
+    return { data: JSON.parse(part.text) };
+  }
   if (part.text !== undefined) return { text: part.text };
   if (part.inlineData) return fromInlineData(part);
   if (part.functionCall) return fromFunctionCall(part);
@@ -363,12 +369,17 @@ function fromGeminiFinishReason(
   }
 }
 
-export function fromGeminiCandidate(candidate: GeminiCandidate): CandidateData {
+export function fromGeminiCandidate(
+  candidate: GeminiCandidate,
+  jsonMode: boolean = false
+): CandidateData {
   return {
     index: candidate.index || 0, // reasonable default?
     message: {
       role: 'model',
-      content: (candidate.content?.parts || []).map(fromGeminiPart),
+      content: (candidate.content?.parts || []).map((part) =>
+        fromGeminiPart(part, jsonMode)
+      ),
     },
     finishReason: fromGeminiFinishReason(candidate.finishReason),
     finishMessage: candidate.finishMessage,
@@ -450,33 +461,44 @@ export function googleAIModel(
           systemInstruction = toGeminiSystemInstruction(systemMessage);
         }
       }
-
+      const generationConfig: GenerationConfig = {
+        candidateCount: request.candidates || undefined,
+        temperature: request.config?.temperature,
+        maxOutputTokens: request.config?.maxOutputTokens,
+        topK: request.config?.topK,
+        topP: request.config?.topP,
+        stopSequences: request.config?.stopSequences,
+        responseMimeType:
+          request.output?.format === 'json' || request.output?.schema
+            ? 'application/json'
+            : undefined,
+      };
       const chatRequest = {
         systemInstruction,
+        generationConfig,
         tools: request.tools?.length
           ? [{ functionDeclarations: request.tools?.map(toGeminiTool) }]
           : [],
         history: messages
           .slice(0, -1)
           .map((message) => toGeminiMessage(message, model)),
-        generationConfig: {
-          candidateCount: request.candidates || undefined,
-          temperature: request.config?.temperature,
-          maxOutputTokens: request.config?.maxOutputTokens,
-          topK: request.config?.topK,
-          topP: request.config?.topP,
-          stopSequences: request.config?.stopSequences,
-        },
         safetySettings: request.config?.safetySettings,
       } as StartChatParams;
       const msg = toGeminiMessage(messages[messages.length - 1], model);
+      const jsonMode =
+        request.output?.format === 'json' || !!request.output?.schema;
+      const fromJSONModeScopedGeminiCandidate = (
+        candidate: GeminiCandidate
+      ) => {
+        return fromGeminiCandidate(candidate, jsonMode);
+      };
       if (streamingCallback) {
         const result = await client
           .startChat(chatRequest)
           .sendMessageStream(msg.parts);
         for await (const item of result.stream) {
           (item as GenerateContentResponse).candidates?.forEach((candidate) => {
-            const c = fromGeminiCandidate(candidate);
+            const c = fromJSONModeScopedGeminiCandidate(candidate);
             streamingCallback({
               index: c.index,
               content: c.message.content,
@@ -488,7 +510,8 @@ export function googleAIModel(
           throw new Error('No valid candidates returned.');
         }
         return {
-          candidates: response.candidates?.map(fromGeminiCandidate) || [],
+          candidates:
+            response.candidates?.map(fromJSONModeScopedGeminiCandidate) || [],
           custom: response,
         };
       } else {
@@ -498,7 +521,8 @@ export function googleAIModel(
         if (!result.response.candidates?.length)
           throw new Error('No valid candidates returned.');
         const responseCandidates =
-          result.response.candidates?.map(fromGeminiCandidate) || [];
+          result.response.candidates?.map(fromJSONModeScopedGeminiCandidate) ||
+          [];
         return {
           candidates: responseCandidates,
           custom: result.response,
diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml
index 4d436b8f0..9e4454938 100644
--- a/js/pnpm-lock.yaml
+++ b/js/pnpm-lock.yaml
@@ -453,8 +453,8 @@ importers:
         specifier: workspace:*
         version: link:../../core
       '@google/generative-ai':
-        specifier: ^0.10.0
-        version: 0.10.0
+        specifier: ^0.14.1
+        version: 0.14.1
       google-auth-library:
         specifier: ^9.6.3
         version: 9.7.0(encoding@0.1.13)
@@ -1451,8 +1451,8 @@
     resolution: {integrity: sha512-hfwfdlVpJ+kM6o2b5UFfPnweBcz8tgHAFRswnqUKYqLJsvKU0DDD0Z2/YKoHyAUoPJAv20qg6KlC3msNeUKUiw==}
     engines: {node: '>=18.0.0'}
 
-  '@google/generative-ai@0.10.0':
-    resolution: {integrity: sha512-fZJEL8DcDgvBCguLdaAdBBEoh+83LDXK3m9rVh5iksvwVJDgZqkpsLGKJuM5FEBKltWhbJ62WSyMEUGgy8eMUg==}
+  '@google/generative-ai@0.14.1':
+    resolution: {integrity: sha512-pevEyZCb0Oc+dYNlSberW8oZBm4ofeTD5wN01TowQMhTwdAbGAnJMtQzoklh6Blq2AKsx8Ox6FWa44KioZLZiA==}
     engines: {node: '>=18.0.0'}
 
   '@grpc/grpc-js@1.10.10':
@@ -5178,7 +5178,7 @@ snapshots:
       - encoding
      - supports-color
 
-  '@google/generative-ai@0.10.0': {}
+  '@google/generative-ai@0.14.1': {}
 
   '@grpc/grpc-js@1.10.10':
     dependencies:
diff --git a/js/testapps/flow-simple-ai/src/index.ts b/js/testapps/flow-simple-ai/src/index.ts
index 352aca188..b25f8f71b 100644
--- a/js/testapps/flow-simple-ai/src/index.ts
+++ b/js/testapps/flow-simple-ai/src/index.ts
@@ -229,6 +229,32 @@ export const jokeWithToolsFlow = defineFlow(
   }
 );
 
+const outputSchema = z.object({
+  joke: z.string(),
+});
+
+export const jokeWithOutputFlow = defineFlow(
+  {
+    name: 'jokeWithOutputFlow',
+    inputSchema: z.object({
+      modelName: z.enum([gemini15Flash.name]),
+      subject: z.string(),
+    }),
+    outputSchema,
+  },
+  async (input) => {
+    const llmResponse = await generate({
+      model: input.modelName,
+      output: {
+        format: 'json',
+        schema: outputSchema,
+      },
+      prompt: `Tell a joke about ${input.subject}.`,
+    });
+    return { ...llmResponse.output()! };
+  }
+);
+
 export const vertexStreamer = defineFlow(
   {
     name: 'vertexStreamer',
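Usage note: with this change, setting `output.format: 'json'` (or supplying an `output.schema`) on a `generate` call maps to `responseMimeType: 'application/json'` in the Gemini `generationConfig`, and `fromGeminiPart` parses the returned text into a structured `data` part. Below is a minimal caller sketch mirroring the `jokeWithOutputFlow` test app in this diff; the `@genkit-ai/googleai` import path and the standalone wiring are assumptions for illustration, not part of the diff itself:

```ts
import { generate } from '@genkit-ai/ai';
import * as z from 'zod';
// Assumed plugin entry point exporting the model ref; adjust to your setup.
import { gemini15Flash } from '@genkit-ai/googleai';

// Schema the model's JSON output should conform to.
const jokeSchema = z.object({
  joke: z.string(),
});

export async function tellJsonJoke(subject: string) {
  // format: 'json' (or the presence of a schema) now sets
  // responseMimeType: 'application/json' on the Gemini request.
  const llmResponse = await generate({
    model: gemini15Flash,
    output: { format: 'json', schema: jokeSchema },
    prompt: `Tell a joke about ${subject}.`,
  });
  // In JSON mode the plugin emits a parsed data part, so output()
  // yields the structured object rather than raw text.
  return llmResponse.output();
}
```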