fix (ui): single assistant message with multiple tool steps (#4591)
lgrammel authored Jan 30, 2025
1 parent 30e7cb5 commit 0d2d9bf
Showing 25 changed files with 1,573 additions and 826 deletions.
8 changes: 8 additions & 0 deletions .changeset/lucky-carpets-boil.md
@@ -0,0 +1,8 @@
---
'@ai-sdk/svelte': patch
'@ai-sdk/react': patch
'@ai-sdk/solid': patch
'@ai-sdk/vue': patch
---

fix (ui): empty submits (with allowEmptySubmit) create user messages
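
For context, `allowEmptySubmit` is the `handleSubmit` option this changeset refers to. A minimal sketch of how it is typically passed from a `useChat` form — the component and markup below are illustrative assumptions, not part of this commit:

'use client';

import { useChat } from '@ai-sdk/react';

export default function Chat() {
  const { input, handleInputChange, handleSubmit } = useChat();

  return (
    <form
      // submitting with an empty input still sends a request when
      // allowEmptySubmit is set; this changeset concerns the user message
      // that gets created for such empty submits
      onSubmit={event => handleSubmit(event, { allowEmptySubmit: true })}
    >
      <input value={input} onChange={handleInputChange} />
    </form>
  );
}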
10 changes: 10 additions & 0 deletions .changeset/tasty-insects-applaud.md
@@ -0,0 +1,10 @@
---
'@ai-sdk/ui-utils': patch
'@ai-sdk/svelte': patch
'@ai-sdk/react': patch
'@ai-sdk/solid': patch
'@ai-sdk/vue': patch
'ai': patch
---

fix (ui): single assistant message with multiple tool steps
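
Concretely, when `streamText` runs several steps (`maxSteps > 1`), the UI packages should now surface one assistant message whose `toolInvocations` array collects the tool call from each step, instead of one assistant message per step. A rough sketch of that expected client-side shape — the id, tool name, arguments, and result below are invented for illustration:

import type { Message } from 'ai';

// Illustrative only: values are made up; further steps append their
// tool invocations to this same assistant message.
const assistantMessage: Message = {
  id: 'msgs_example123',
  role: 'assistant',
  content: 'The weather in Berlin is sunny.',
  toolInvocations: [
    // step 1: server-side tool call plus its result
    {
      state: 'result',
      toolCallId: 'call_1',
      toolName: 'getWeatherInformation',
      args: { city: 'Berlin' },
      result: 'sunny',
    },
  ],
};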
@@ -3,6 +3,7 @@ import { loadChat, saveChat } from '@util/chat-store';
 import {
   appendClientMessage,
   appendResponseMessages,
+  createDataStreamResponse,
   createIdGenerator,
   streamText,
   tool,
@@ -12,6 +13,8 @@ import { z } from 'zod';
 // Allow streaming responses up to 30 seconds
 export const maxDuration = 30;
 
+let count = 0;
+
 export async function POST(req: Request) {
   // get the last message from the client:
   const { message, id } = await req.json();
@@ -25,55 +28,87 @@ export async function POST(req: Request) {
     message,
   });
 
-  const result = streamText({
-    model: openai('gpt-4o-mini'),
-    messages,
-    toolCallStreaming: true,
-    maxSteps: 5, // multi-steps for server-side tools
-    tools: {
-      // server-side tool with execute function:
-      getWeatherInformation: tool({
-        description: 'show the weather in a given city to the user',
-        parameters: z.object({ city: z.string() }),
-        execute: async ({}: { city: string }) => {
-          // Add artificial delay of 2 seconds
-          await new Promise(resolve => setTimeout(resolve, 2000));
-
-          const weatherOptions = ['sunny', 'cloudy', 'rainy', 'snowy', 'windy'];
-          return weatherOptions[
-            Math.floor(Math.random() * weatherOptions.length)
-          ];
-        },
-      }),
-      // client-side tool that starts user interaction:
-      askForConfirmation: tool({
-        description: 'Ask the user for confirmation.',
-        parameters: z.object({
-          message: z.string().describe('The message to ask for confirmation.'),
-        }),
-      }),
-      // client-side tool that is automatically executed on the client:
-      getLocation: tool({
-        description:
-          'Get the user location. Always ask for confirmation before using this tool.',
-        parameters: z.object({}),
-      }),
-    },
-    // id format for server-side messages:
-    experimental_generateMessageId: createIdGenerator({
-      prefix: 'msgs',
-      size: 16,
-    }),
-    async onFinish({ response }) {
-      await saveChat({
-        id,
-        messages: appendResponseMessages({
-          messages,
-          responseMessages: response.messages,
-        }),
-      });
-    },
-  });
-
-  return result.toDataStreamResponse();
+  // immediately start streaming (solves RAG issues with status, etc.)
+  return createDataStreamResponse({
+    execute: dataStream => {
+      dataStream.writeMessageAnnotation({
+        start: 'start',
+        count: count++,
+      });
+
+      const result = streamText({
+        model: openai('gpt-4o'),
+        messages,
+        toolCallStreaming: true,
+        maxSteps: 5, // multi-steps for server-side tools
+        tools: {
+          // server-side tool with execute function:
+          getWeatherInformation: tool({
+            description: 'show the weather in a given city to the user',
+            parameters: z.object({ city: z.string() }),
+            execute: async ({ city }: { city: string }) => {
+              // Add artificial delay of 2 seconds
+              await new Promise(resolve => setTimeout(resolve, 2000));
+
+              const weatherOptions = [
+                'sunny',
+                'cloudy',
+                'rainy',
+                'snowy',
+                'windy',
+              ];
+
+              const weather =
+                weatherOptions[
+                  Math.floor(Math.random() * weatherOptions.length)
+                ];
+
+              dataStream.writeMessageAnnotation({
+                city,
+                weather,
+              });
+
+              return weather;
+            },
+          }),
+          // client-side tool that starts user interaction:
+          askForConfirmation: tool({
+            description: 'Ask the user for confirmation.',
+            parameters: z.object({
+              message: z
+                .string()
+                .describe('The message to ask for confirmation.'),
+            }),
+          }),
+          // client-side tool that is automatically executed on the client:
+          getLocation: tool({
+            description:
+              'Get the user location. Always ask for confirmation before using this tool.',
+            parameters: z.object({}),
+          }),
+        },
+        // id format for server-side messages:
+        experimental_generateMessageId: createIdGenerator({
+          prefix: 'msgs',
+          size: 16,
+        }),
+        async onFinish({ response }) {
+          await saveChat({
+            id,
+            messages: appendResponseMessages({
+              messages,
+              responseMessages: response.messages,
            }),
+          });
+        },
+      });
+
+      result.mergeIntoDataStream(dataStream);
+    },
+    onError: error => {
+      // Error messages are masked by default for security reasons.
+      // If you want to expose the error message to the client, you can do so here:
+      return error instanceof Error ? error.message : String(error);
+    },
+  });
 }
@@ -40,7 +40,6 @@ export default function Chat({
       {messages?.map((m: Message) => (
         <div key={m.id} className="whitespace-pre-wrap">
           <strong>{`${m.role}: `}</strong>
-          {m.content}
           {m.toolInvocations?.map((toolInvocation: ToolInvocation) => {
             const toolCallId = toolInvocation.toolCallId;
 
@@ -103,7 +102,13 @@
                 Calling {toolInvocation.toolName}...
               </div>
             );
-          })}
+          })}{' '}
+          {m.annotations && (
+            <pre className="p-4 text-sm bg-gray-100">
+              {JSON.stringify(m.annotations, null, 2)}
+            </pre>
+          )}
+          {m.content}
           <br />
           <br />
         </div>
2 changes: 1 addition & 1 deletion examples/next-openai/app/use-chat-tools/page.tsx
@@ -28,7 +28,6 @@ export default function Chat() {
       {messages?.map((m: Message) => (
         <div key={m.id} className="whitespace-pre-wrap">
           <strong>{`${m.role}: `}</strong>
-          {m.content}
           {m.toolInvocations?.map((toolInvocation: ToolInvocation) => {
             const toolCallId = toolInvocation.toolCallId;
 
@@ -92,6 +91,7 @@ export default function Chat() {
               </div>
             );
           })}
+          {m.content}
           <br />
           <br />
         </div>