diff --git a/genkit-tools/common/src/types/model.ts b/genkit-tools/common/src/types/model.ts index 5ce711c17..d3f218ff4 100644 --- a/genkit-tools/common/src/types/model.ts +++ b/genkit-tools/common/src/types/model.ts @@ -245,6 +245,9 @@ export type GenerateResponseData = z.infer; /** ModelResponseChunkSchema represents a chunk of content to stream to the client. */ export const ModelResponseChunkSchema = z.object({ + role: RoleSchema.optional(), + /** index of the message this chunk belongs to. */ + index: z.number().optional(), /** The chunk of content to stream right now. */ content: z.array(PartSchema), /** Model-specific extra information attached to this chunk. */ @@ -254,10 +257,7 @@ export const ModelResponseChunkSchema = z.object({ }); export type ModelResponseChunkData = z.infer; -export const GenerateResponseChunkSchema = ModelResponseChunkSchema.extend({ - /** @deprecated The index of the candidate this chunk belongs to. Always 0. */ - index: z.number(), -}); +export const GenerateResponseChunkSchema = ModelResponseChunkSchema.extend({}); export type GenerateResponseChunkData = z.infer< typeof GenerateResponseChunkSchema >; diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 9d629ed14..d6f36262f 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -452,6 +452,12 @@ "GenerateResponseChunk": { "type": "object", "properties": { + "role": { + "$ref": "#/$defs/Role" + }, + "index": { + "type": "number" + }, "content": { "type": "array", "items": { @@ -461,14 +467,10 @@ "custom": {}, "aggregated": { "type": "boolean" - }, - "index": { - "type": "number" } }, "required": [ - "content", - "index" + "content" ], "additionalProperties": false }, @@ -714,6 +716,12 @@ "ModelResponseChunk": { "type": "object", "properties": { + "role": { + "$ref": "#/$defs/GenerateResponseChunk/properties/role" + }, + "index": { + "$ref": "#/$defs/GenerateResponseChunk/properties/index" + }, "content": { "$ref": 
"#/$defs/GenerateResponseChunk/properties/content" }, diff --git a/js/ai/src/generate/action.ts b/js/ai/src/generate/action.ts index 675c8be3a..8a31a87bb 100644 --- a/js/ai/src/generate/action.ts +++ b/js/ai/src/generate/action.ts @@ -83,9 +83,11 @@ export async function generateHelper( registry: Registry, input: z.infer, middleware?: ModelMiddleware[], - currentTurns?: number + currentTurns?: number, + messageIndex?: number ): Promise { currentTurns = currentTurns ?? 0; + messageIndex = messageIndex ?? 0; // do tracing return await runInNewSpan( registry, @@ -100,7 +102,13 @@ export async function generateHelper( async (metadata) => { metadata.name = 'generate'; metadata.input = input; - const output = await generate(registry, input, middleware, currentTurns!); + const output = await generate( + registry, + input, + middleware, + currentTurns!, + messageIndex! + ); metadata.output = JSON.stringify(output); return output; } @@ -111,7 +119,8 @@ async function generate( registry: Registry, rawRequest: z.infer, middleware: ModelMiddleware[] | undefined, - currentTurn: number + currentTurn: number, + messageIndex: number ): Promise { const { modelAction: model } = await resolveModel(registry, rawRequest.model); if (model.__action.metadata?.model.stage === 'deprecated') { @@ -152,15 +161,17 @@ async function generate( streamingCallback ? 
(chunk: GenerateResponseChunkData) => { // Store accumulated chunk data - streamingCallback( - new GenerateResponseChunk(chunk, { - index: 0, - role: 'model', - previousChunks: accumulatedChunks, - parser: resolvedFormat?.handler(request.output?.schema) - .parseChunk, - }) - ); + if (streamingCallback) { + streamingCallback!( + new GenerateResponseChunk(chunk, { + index: messageIndex, + role: 'model', + previousChunks: accumulatedChunks, + parser: resolvedFormat?.handler(request.output?.schema) + .parseChunk, + }) + ); + } accumulatedChunks.push(chunk); } : undefined, @@ -246,6 +257,7 @@ async function generate( }); } } + messageIndex++; const nextRequest = { ...rawRequest, messages: [ @@ -257,11 +269,26 @@ async function generate( ] as MessageData[], tools: newTools, }; + // stream out the tool responses + streamingCallback?.( + new GenerateResponseChunk( + { + content: toolResponses, + }, + { + index: messageIndex, + role: 'model', + previousChunks: accumulatedChunks, + parser: resolvedFormat?.handler(request.output?.schema).parseChunk, + } + ) + ); return await generateHelper( registry, nextRequest, middleware, - currentTurn + 1 + currentTurn + 1, + messageIndex + 1 ); } diff --git a/js/ai/src/generate/chunk.ts b/js/ai/src/generate/chunk.ts index 4cf20d67a..231cbf939 100644 --- a/js/ai/src/generate/chunk.ts +++ b/js/ai/src/generate/chunk.ts @@ -31,9 +31,9 @@ export class GenerateResponseChunk implements GenerateResponseChunkData { /** The index of the message this chunk corresponds to, starting with `0` for the first model response of the generation. */ - index?: number; + index: number; /** The role of the message this chunk corresponds to. Will always be `model` or `tool`. */ - role?: Role; + role: Role; /** The content generated in this chunk. */ content: Part[]; /** Custom model-specific data for this chunk. 
*/ @@ -45,21 +45,21 @@ export class GenerateResponseChunk constructor( data: GenerateResponseChunkData, - options?: { + options: { previousChunks?: GenerateResponseChunkData[]; - role?: Role; - index?: number; + role: Role; + index: number; parser?: ChunkParser; } ) { this.content = data.content || []; this.custom = data.custom; - this.previousChunks = options?.previousChunks + this.previousChunks = options.previousChunks ? [...options.previousChunks] : undefined; - this.index = options?.index; - this.role = options?.role; - this.parser = options?.parser; + this.index = options.index; + this.role = options.role; + this.parser = options.parser; } /** @@ -130,6 +130,14 @@ export class GenerateResponseChunk } toJSON(): GenerateResponseChunkData { - return { content: this.content, custom: this.custom }; + const data = { + role: this.role, + index: this.index, + content: this.content, + } as GenerateResponseChunkData; + if (this.custom) { + data.custom = this.custom; + } + return data; } } diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 71a067b46..56bb3fb09 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -400,6 +400,9 @@ export type GenerateResponseData = z.infer; /** ModelResponseChunkSchema represents a chunk of content to stream to the client. */ export const ModelResponseChunkSchema = z.object({ + role: RoleSchema.optional(), + /** index of the message this chunk belongs to. */ + index: z.number().optional(), /** The chunk of content to stream right now. */ content: z.array(PartSchema), /** Model-specific extra information attached to this chunk. */ @@ -409,10 +412,7 @@ export const ModelResponseChunkSchema = z.object({ }); export type ModelResponseChunkData = z.infer; -export const GenerateResponseChunkSchema = ModelResponseChunkSchema.extend({ - /** @deprecated The index of the candidate this chunk belongs to. Always 0. 
*/ - index: z.number().optional(), -}); +export const GenerateResponseChunkSchema = ModelResponseChunkSchema; export type GenerateResponseChunkData = z.infer< typeof GenerateResponseChunkSchema >; diff --git a/js/ai/tests/formats/array_test.ts b/js/ai/tests/formats/array_test.ts index a56752950..e85729b31 100644 --- a/js/ai/tests/formats/array_test.ts +++ b/js/ai/tests/formats/array_test.ts @@ -71,11 +71,17 @@ describe('arrayFormat', () => { for (const chunk of st.chunks) { const newChunk: GenerateResponseChunkData = { + index: 0, + role: 'model', content: [{ text: chunk.text }], }; const result = parser.parseChunk!( - new GenerateResponseChunk(newChunk, { previousChunks: chunks }) + new GenerateResponseChunk(newChunk, { + index: 0, + role: 'model', + previousChunks: chunks, + }) ); chunks.push(newChunk); diff --git a/js/ai/tests/formats/json_test.ts b/js/ai/tests/formats/json_test.ts index dff531b73..6a8e94420 100644 --- a/js/ai/tests/formats/json_test.ts +++ b/js/ai/tests/formats/json_test.ts @@ -67,11 +67,17 @@ describe('jsonFormat', () => { for (const chunk of st.chunks) { const newChunk: GenerateResponseChunkData = { + index: 0, + role: 'model', content: [{ text: chunk.text }], }; const result = parser.parseChunk!( - new GenerateResponseChunk(newChunk, { previousChunks: [...chunks] }) + new GenerateResponseChunk(newChunk, { + index: 0, + role: 'model', + previousChunks: [...chunks], + }) ); chunks.push(newChunk); diff --git a/js/ai/tests/formats/jsonl_test.ts b/js/ai/tests/formats/jsonl_test.ts index 4a58122fb..a5f3bce77 100644 --- a/js/ai/tests/formats/jsonl_test.ts +++ b/js/ai/tests/formats/jsonl_test.ts @@ -80,11 +80,17 @@ describe('jsonlFormat', () => { for (const chunk of st.chunks) { const newChunk: GenerateResponseChunkData = { + index: 0, + role: 'model', content: [{ text: chunk.text }], }; const result = parser.parseChunk!( - new GenerateResponseChunk(newChunk, { previousChunks: chunks }) + new GenerateResponseChunk(newChunk, { + index: 0, + role: 
'model', + previousChunks: chunks, + }) ); chunks.push(newChunk); diff --git a/js/ai/tests/formats/text_test.ts b/js/ai/tests/formats/text_test.ts index 5dabc35f4..2c9fecd5f 100644 --- a/js/ai/tests/formats/text_test.ts +++ b/js/ai/tests/formats/text_test.ts @@ -54,11 +54,17 @@ describe('textFormat', () => { for (const chunk of st.chunks) { const newChunk: GenerateResponseChunkData = { + index: 0, + role: 'model', content: [{ text: chunk.text }], }; const result = parser.parseChunk!( - new GenerateResponseChunk(newChunk, { previousChunks: chunks }) + new GenerateResponseChunk(newChunk, { + index: 0, + role: 'model', + previousChunks: chunks, + }) ); chunks.push(newChunk); diff --git a/js/ai/tests/generate/chunk_test.ts b/js/ai/tests/generate/chunk_test.ts index febc670a1..21c5d6718 100644 --- a/js/ai/tests/generate/chunk_test.ts +++ b/js/ai/tests/generate/chunk_test.ts @@ -21,11 +21,11 @@ import { GenerateResponseChunk } from '../../src/generate.js'; describe('GenerateResponseChunk', () => { describe('text accumulation', () => { const testChunk = new GenerateResponseChunk( - { content: [{ text: 'new' }] }, + { index: 0, role: 'model', content: [{ text: 'new' }] }, { previousChunks: [ - { content: [{ text: 'old1' }] }, - { content: [{ text: 'old2' }] }, + { index: 0, role: 'model', content: [{ text: 'old1' }] }, + { index: 0, role: 'model', content: [{ text: 'old2' }] }, ], } ); diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index c7ff5ec10..594375144 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -369,7 +369,7 @@ describe('generate', () => { }); describe('generateStream', () => { - it('should pass a smoke test', async () => { + it('should stream out chunks', async () => { let registry = new Registry(); defineModel( @@ -390,11 +390,22 @@ describe('generate', () => { prompt: 'Testing streaming', }); - let streamed: string[] = []; + let streamed: any[] = []; for await (const 
chunk of stream) { - streamed.push(chunk.text); + streamed.push(chunk.toJSON()); } - assert.deepEqual(streamed, ['hello, ', 'world!']); + assert.deepStrictEqual(streamed, [ + { + index: 0, + role: 'model', + content: [{ text: 'hello, ' }], + }, + { + index: 0, + role: 'model', + content: [{ text: 'world!' }], + }, + ]); assert.deepEqual( (await response).messages.map((m) => m.content[0].text), ['Testing streaming', 'Testing streaming'] diff --git a/js/genkit/tests/generate_test.ts b/js/genkit/tests/generate_test.ts index 3120b458e..75a954898 100644 --- a/js/genkit/tests/generate_test.ts +++ b/js/genkit/tests/generate_test.ts @@ -401,19 +401,105 @@ describe('generate', () => { }, }; }; - const { text, output } = await ai.generate({ output: { schema }, prompt: 'call the tool', tools: ['testTool'], }); - assert.strictEqual(text, "```\n{foo: 'fromModel'}\n```"); assert.deepStrictEqual(output, { foo: 'fromModel', }); }); + it('streams the tool responses', async () => { + ai.defineTool( + { name: 'testTool', description: 'description' }, + async () => 'tool called' + ); + + // The first response is a tool call; the subsequent one is just a text response. + let reqCounter = 0; + pm.handleResponse = async (req, sc) => { + if (sc) { + sc({ + content: [ + reqCounter === 0 + ? { + toolRequest: { + name: 'testTool', + input: {}, + ref: 'ref123', + }, + } + : { text: 'done' }, + ], + }); + } + return { + message: { + role: 'model', + content: [ + reqCounter++ === 0 + ? 
{ + toolRequest: { + name: 'testTool', + input: {}, + ref: 'ref123', + }, + } + : { text: 'done' }, + ], + }, + }; + }; + + const { stream, response } = await ai.generateStream({ + prompt: 'call the tool', + tools: ['testTool'], + }); + + const chunks: any[] = []; + for await (const chunk of stream) { + chunks.push(chunk.toJSON()); + } + + assert.strictEqual((await response).text, 'done'); + assert.deepStrictEqual(chunks, [ + { + content: [ + { + toolRequest: { + input: {}, + name: 'testTool', + ref: 'ref123', + }, + }, + ], + index: 0, + role: 'model', + }, + { + content: [ + { + toolResponse: { + name: 'testTool', + output: 'tool called', + ref: 'ref123', + }, + }, + ], + index: 1, + role: 'model', + }, + { + content: [{ text: 'done' }], + index: 2, + role: 'model', + }, + ]); + }); + it('throws when exceeding max tool call iterations', async () => { ai.defineTool( { name: 'testTool', description: 'description' }, diff --git a/js/plugins/express/tests/express_test.ts b/js/plugins/express/tests/express_test.ts index 4866ef3f6..77d9bc85b 100644 --- a/js/plugins/express/tests/express_test.ts +++ b/js/plugins/express/tests/express_test.ts @@ -288,9 +288,9 @@ describe('expressHandler', async () => { } assert.deepStrictEqual(gotChunks, [ - { content: [{ text: '3' }] }, - { content: [{ text: '2' }] }, - { content: [{ text: '1' }] }, + { index: 0, role: 'model', content: [{ text: '3' }] }, + { index: 0, role: 'model', content: [{ text: '2' }] }, + { index: 0, role: 'model', content: [{ text: '1' }] }, ]); assert.strictEqual(await result.output(), 'Echo: olleh'); @@ -507,9 +507,9 @@ describe('startFlowServer', async () => { } assert.deepStrictEqual(gotChunks, [ - { content: [{ text: '3' }] }, - { content: [{ text: '2' }] }, - { content: [{ text: '1' }] }, + { index: 0, role: 'model', content: [{ text: '3' }] }, + { index: 0, role: 'model', content: [{ text: '2' }] }, + { index: 0, role: 'model', content: [{ text: '1' }] }, ]); assert.strictEqual(await result.output(), 
'Echo: olleh'); diff --git a/js/testapps/flow-simple-ai/src/index.ts b/js/testapps/flow-simple-ai/src/index.ts index 419c18624..e973c0eca 100644 --- a/js/testapps/flow-simple-ai/src/index.ts +++ b/js/testapps/flow-simple-ai/src/index.ts @@ -572,11 +572,12 @@ ai.defineFlow( inputSchema: z.string(), outputSchema: z.string(), }, - async (query) => { + async (query, { sendChunk }) => { const { text } = await ai.generate({ model: gemini15Flash, prompt: query, tools: ['math/add', 'math/subtract'], + onChunk: sendChunk, }); return text; }