Skip to content

Commit

Permalink
feat(js/ai): stream tools responses (#1614)
Browse files (browse the repository at this point in the history)
  • Loading branch information
pavelgj authored Jan 16, 2025
1 parent 7d22c8b commit d918681
Show file tree
Hide file tree
Showing 14 changed files with 221 additions and 56 deletions.
8 changes: 4 additions & 4 deletions genkit-tools/common/src/types/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,9 @@ export type GenerateResponseData = z.infer<typeof GenerateResponseSchema>;

/** ModelResponseChunkSchema represents a chunk of content to stream to the client. */
export const ModelResponseChunkSchema = z.object({
role: RoleSchema.optional(),
/** Index of the message this chunk belongs to. */
index: z.number().optional(),
/** The chunk of content to stream right now. */
content: z.array(PartSchema),
/** Model-specific extra information attached to this chunk. */
Expand All @@ -254,10 +257,7 @@ export const ModelResponseChunkSchema = z.object({
});
export type ModelResponseChunkData = z.infer<typeof ModelResponseChunkSchema>;

export const GenerateResponseChunkSchema = ModelResponseChunkSchema.extend({
/** @deprecated The index of the candidate this chunk belongs to. Always 0. */
index: z.number(),
});
export const GenerateResponseChunkSchema = ModelResponseChunkSchema.extend({});
export type GenerateResponseChunkData = z.infer<
typeof GenerateResponseChunkSchema
>;
18 changes: 13 additions & 5 deletions genkit-tools/genkit-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,12 @@
"GenerateResponseChunk": {
"type": "object",
"properties": {
"role": {
"$ref": "#/$defs/Role"
},
"index": {
"type": "number"
},
"content": {
"type": "array",
"items": {
Expand All @@ -461,14 +467,10 @@
"custom": {},
"aggregated": {
"type": "boolean"
},
"index": {
"type": "number"
}
},
"required": [
"content",
"index"
"content"
],
"additionalProperties": false
},
Expand Down Expand Up @@ -714,6 +716,12 @@
"ModelResponseChunk": {
"type": "object",
"properties": {
"role": {
"$ref": "#/$defs/GenerateResponseChunk/properties/role"
},
"index": {
"$ref": "#/$defs/GenerateResponseChunk/properties/index"
},
"content": {
"$ref": "#/$defs/GenerateResponseChunk/properties/content"
},
Expand Down
53 changes: 40 additions & 13 deletions js/ai/src/generate/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,11 @@ export async function generateHelper(
registry: Registry,
input: z.infer<typeof GenerateUtilParamSchema>,
middleware?: ModelMiddleware[],
currentTurns?: number
currentTurns?: number,
messageIndex?: number
): Promise<GenerateResponseData> {
currentTurns = currentTurns ?? 0;
messageIndex = messageIndex ?? 0;
// do tracing
return await runInNewSpan(
registry,
Expand All @@ -100,7 +102,13 @@ export async function generateHelper(
async (metadata) => {
metadata.name = 'generate';
metadata.input = input;
const output = await generate(registry, input, middleware, currentTurns!);
const output = await generate(
registry,
input,
middleware,
currentTurns!,
messageIndex!
);
metadata.output = JSON.stringify(output);
return output;
}
Expand All @@ -111,7 +119,8 @@ async function generate(
registry: Registry,
rawRequest: z.infer<typeof GenerateUtilParamSchema>,
middleware: ModelMiddleware[] | undefined,
currentTurn: number
currentTurn: number,
messageIndex: number
): Promise<GenerateResponseData> {
const { modelAction: model } = await resolveModel(registry, rawRequest.model);
if (model.__action.metadata?.model.stage === 'deprecated') {
Expand Down Expand Up @@ -152,15 +161,17 @@ async function generate(
streamingCallback
? (chunk: GenerateResponseChunkData) => {
// Store accumulated chunk data
streamingCallback(
new GenerateResponseChunk(chunk, {
index: 0,
role: 'model',
previousChunks: accumulatedChunks,
parser: resolvedFormat?.handler(request.output?.schema)
.parseChunk,
})
);
if (streamingCallback) {
streamingCallback!(
new GenerateResponseChunk(chunk, {
index: messageIndex,
role: 'model',
previousChunks: accumulatedChunks,
parser: resolvedFormat?.handler(request.output?.schema)
.parseChunk,
})
);
}
accumulatedChunks.push(chunk);
}
: undefined,
Expand Down Expand Up @@ -246,6 +257,7 @@ async function generate(
});
}
}
messageIndex++;
const nextRequest = {
...rawRequest,
messages: [
Expand All @@ -257,11 +269,26 @@ async function generate(
] as MessageData[],
tools: newTools,
};
// stream out the tool responses
streamingCallback?.(
new GenerateResponseChunk(
{
content: toolResponses,
},
{
index: messageIndex,
role: 'model',
previousChunks: accumulatedChunks,
parser: resolvedFormat?.handler(request.output?.schema).parseChunk,
}
)
);
return await generateHelper(
registry,
nextRequest,
middleware,
currentTurn + 1
currentTurn + 1,
messageIndex + 1
);
}

Expand Down
28 changes: 18 additions & 10 deletions js/ai/src/generate/chunk.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ export class GenerateResponseChunk<T = unknown>
implements GenerateResponseChunkData
{
/** The index of the message this chunk corresponds to, starting with `0` for the first model response of the generation. */
index?: number;
index: number;
/** The role of the message this chunk corresponds to. Will always be `model` or `tool`. */
role?: Role;
role: Role;
/** The content generated in this chunk. */
content: Part[];
/** Custom model-specific data for this chunk. */
Expand All @@ -45,21 +45,21 @@ export class GenerateResponseChunk<T = unknown>

constructor(
data: GenerateResponseChunkData,
options?: {
options: {
previousChunks?: GenerateResponseChunkData[];
role?: Role;
index?: number;
role: Role;
index: number;
parser?: ChunkParser<T>;
}
) {
this.content = data.content || [];
this.custom = data.custom;
this.previousChunks = options?.previousChunks
this.previousChunks = options.previousChunks
? [...options.previousChunks]
: undefined;
this.index = options?.index;
this.role = options?.role;
this.parser = options?.parser;
this.index = options.index;
this.role = options.role;
this.parser = options.parser;
}

/**
Expand Down Expand Up @@ -130,6 +130,14 @@ export class GenerateResponseChunk<T = unknown>
}

toJSON(): GenerateResponseChunkData {
return { content: this.content, custom: this.custom };
const data = {
role: this.role,
index: this.index,
content: this.content,
} as GenerateResponseChunkData;
if (this.custom) {
data.custom = this.custom;
}
return data;
}
}
8 changes: 4 additions & 4 deletions js/ai/src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,9 @@ export type GenerateResponseData = z.infer<typeof GenerateResponseSchema>;

/** ModelResponseChunkSchema represents a chunk of content to stream to the client. */
export const ModelResponseChunkSchema = z.object({
role: RoleSchema.optional(),
/** Index of the message this chunk belongs to. */
index: z.number().optional(),
/** The chunk of content to stream right now. */
content: z.array(PartSchema),
/** Model-specific extra information attached to this chunk. */
Expand All @@ -409,10 +412,7 @@ export const ModelResponseChunkSchema = z.object({
});
export type ModelResponseChunkData = z.infer<typeof ModelResponseChunkSchema>;

export const GenerateResponseChunkSchema = ModelResponseChunkSchema.extend({
/** @deprecated The index of the candidate this chunk belongs to. Always 0. */
index: z.number().optional(),
});
export const GenerateResponseChunkSchema = ModelResponseChunkSchema;
export type GenerateResponseChunkData = z.infer<
typeof GenerateResponseChunkSchema
>;
Expand Down
8 changes: 7 additions & 1 deletion js/ai/tests/formats/array_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,17 @@ describe('arrayFormat', () => {

for (const chunk of st.chunks) {
const newChunk: GenerateResponseChunkData = {
index: 0,
role: 'model',
content: [{ text: chunk.text }],
};

const result = parser.parseChunk!(
new GenerateResponseChunk(newChunk, { previousChunks: chunks })
new GenerateResponseChunk(newChunk, {
index: 0,
role: 'model',
previousChunks: chunks,
})
);
chunks.push(newChunk);

Expand Down
8 changes: 7 additions & 1 deletion js/ai/tests/formats/json_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,17 @@ describe('jsonFormat', () => {

for (const chunk of st.chunks) {
const newChunk: GenerateResponseChunkData = {
index: 0,
role: 'model',
content: [{ text: chunk.text }],
};

const result = parser.parseChunk!(
new GenerateResponseChunk(newChunk, { previousChunks: [...chunks] })
new GenerateResponseChunk(newChunk, {
index: 0,
role: 'model',
previousChunks: [...chunks],
})
);
chunks.push(newChunk);

Expand Down
8 changes: 7 additions & 1 deletion js/ai/tests/formats/jsonl_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,17 @@ describe('jsonlFormat', () => {

for (const chunk of st.chunks) {
const newChunk: GenerateResponseChunkData = {
index: 0,
role: 'model',
content: [{ text: chunk.text }],
};

const result = parser.parseChunk!(
new GenerateResponseChunk(newChunk, { previousChunks: chunks })
new GenerateResponseChunk(newChunk, {
index: 0,
role: 'model',
previousChunks: chunks,
})
);
chunks.push(newChunk);

Expand Down
8 changes: 7 additions & 1 deletion js/ai/tests/formats/text_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,17 @@ describe('textFormat', () => {

for (const chunk of st.chunks) {
const newChunk: GenerateResponseChunkData = {
index: 0,
role: 'model',
content: [{ text: chunk.text }],
};

const result = parser.parseChunk!(
new GenerateResponseChunk(newChunk, { previousChunks: chunks })
new GenerateResponseChunk(newChunk, {
index: 0,
role: 'model',
previousChunks: chunks,
})
);
chunks.push(newChunk);

Expand Down
6 changes: 3 additions & 3 deletions js/ai/tests/generate/chunk_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ import { GenerateResponseChunk } from '../../src/generate.js';
describe('GenerateResponseChunk', () => {
describe('text accumulation', () => {
const testChunk = new GenerateResponseChunk(
{ content: [{ text: 'new' }] },
{ index: 0, role: 'model', content: [{ text: 'new' }] },
{
previousChunks: [
{ content: [{ text: 'old1' }] },
{ content: [{ text: 'old2' }] },
{ index: 0, role: 'model', content: [{ text: 'old1' }] },
{ index: 0, role: 'model', content: [{ text: 'old2' }] },
],
}
);
Expand Down
19 changes: 15 additions & 4 deletions js/ai/tests/generate/generate_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ describe('generate', () => {
});

describe('generateStream', () => {
it('should pass a smoke test', async () => {
it('should stream out chunks', async () => {
let registry = new Registry();

defineModel(
Expand All @@ -390,11 +390,22 @@ describe('generate', () => {
prompt: 'Testing streaming',
});

let streamed: string[] = [];
let streamed: any[] = [];
for await (const chunk of stream) {
streamed.push(chunk.text);
streamed.push(chunk.toJSON());
}
assert.deepEqual(streamed, ['hello, ', 'world!']);
assert.deepStrictEqual(streamed, [
{
index: 0,
role: 'model',
content: [{ text: 'hello, ' }],
},
{
index: 0,
role: 'model',
content: [{ text: 'world!' }],
},
]);
assert.deepEqual(
(await response).messages.map((m) => m.content[0].text),
['Testing streaming', 'Testing streaming']
Expand Down
Loading

0 comments on commit d918681

Please sign in to comment.