Skip to content

Commit

Permalink
fix (provider/mistral): correctly parse complex content type responses (
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel authored Feb 8, 2025
1 parent 3249157 commit e6a7628
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 8 deletions.
5 changes: 5 additions & 0 deletions .changeset/fifty-fans-march.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@ai-sdk/mistral': patch
---

fix (provider/mistral): correctly parse complex content type responses
67 changes: 67 additions & 0 deletions packages/mistral/src/mistral-chat-language-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,39 @@ describe('doGenerate', () => {
body: '{"model":"mistral-small-latest","messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}',
});
});

it('should extract text response when message content is a content object', async () => {
  // Arrange: the stubbed server answers with a completion whose message
  // content is a single content object ({ type: 'text', text }) rather than
  // the usual plain string.
  server.responseBodyJson = {
    object: 'chat.completion',
    id: 'object-id',
    created: 1711113008,
    model: 'mistral-small-latest',
    choices: [
      {
        index: 0,
        message: {
          role: 'assistant',
          content: {
            type: 'text',
            text: 'Hello from object',
          },
          tool_calls: null,
        },
        finish_reason: 'stop',
        logprobs: null,
      },
    ],
    usage: { prompt_tokens: 4, total_tokens: 34, completion_tokens: 30 },
  };

  // Act: run a plain (non-tool) generation against the stub.
  const result = await model.doGenerate({
    inputFormat: 'prompt',
    mode: { type: 'regular' },
    prompt: TEST_PROMPT,
  });

  // Assert: the `text` field of the content object surfaces as the text result.
  expect(result.text).toStrictEqual('Hello from object');
});
});

describe('doStream', () => {
Expand Down Expand Up @@ -545,4 +578,38 @@ describe('doStream', () => {
body: '{"model":"mistral-small-latest","messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}],"stream":true}',
});
});

it('should stream text with content objects', async () => {
// Instead of using prepareStreamResponse (which sends strings),
// we set the chunks manually so that each delta's content is an object.
// Each SSE chunk carries delta.content as { type: 'text', text: ... } rather
// than a plain string; the final data chunk also reports usage, and the
// stream is closed by the [DONE] sentinel.
server.responseChunks = [
`data: {"id":"stream-object-id","object":"chat.completion.chunk","created":1711097175,"model":"mistral-small-latest","choices":[{"index":0,"delta":{"role":"assistant","content":{"type":"text","text":""}},"finish_reason":null,"logprobs":null}]}\n\n`,
`data: {"id":"stream-object-id","object":"chat.completion.chunk","created":1711097175,"model":"mistral-small-latest","choices":[{"index":0,"delta":{"content":{"type":"text","text":"Hello"}},"finish_reason":null,"logprobs":null}]}\n\n`,
`data: {"id":"stream-object-id","object":"chat.completion.chunk","created":1711097175,"model":"mistral-small-latest","choices":[{"index":0,"delta":{"content":{"type":"text","text":", world!"}},"finish_reason":"stop","logprobs":null}],"usage":{"prompt_tokens":4,"total_tokens":36,"completion_tokens":32}}\n\n`,
`data: [DONE]\n\n`,
];

const { stream } = await model.doStream({
inputFormat: 'prompt',
mode: { type: 'regular' },
prompt: TEST_PROMPT,
});

// The object-typed deltas must be unwrapped into plain text-delta parts,
// preceded by response metadata (timestamp is `created` in epoch seconds)
// and followed by the finish part with the usage from the last chunk.
expect(await convertReadableStreamToArray(stream)).toStrictEqual([
{
type: 'response-metadata',
id: 'stream-object-id',
timestamp: new Date(1711097175 * 1000),
modelId: 'mistral-small-latest',
},
{ type: 'text-delta', textDelta: '' },
{ type: 'text-delta', textDelta: 'Hello' },
{ type: 'text-delta', textDelta: ', world!' },
{
type: 'finish',
finishReason: 'stop',
usage: { promptTokens: 4, completionTokens: 32 },
},
]);
});
});
55 changes: 47 additions & 8 deletions packages/mistral/src/mistral-chat-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,10 @@ export class MistralChatLanguageModel implements LanguageModelV1 {

const { messages: rawPrompt, ...rawSettings } = args;
const choice = response.choices[0];
let text = choice.message.content ?? undefined;

// extract text content.
// image content or reference content is currently ignored.
let text = extractTextContent(choice.message.content);

// when there is a trailing assistant message, mistral will send the
// content of that message again. we skip this repeated content to
Expand Down Expand Up @@ -294,6 +297,10 @@ export class MistralChatLanguageModel implements LanguageModelV1 {

const delta = choice.delta;

// extract text content.
// image content or reference content is currently ignored.
const textContent = extractTextContent(delta.content);

// when there is a trailing assistant message, mistral will send the
// content of that message again. we skip this repeated content to
// avoid duplication, e.g. in continuation mode.
Expand All @@ -302,11 +309,11 @@ export class MistralChatLanguageModel implements LanguageModelV1 {

if (
lastMessage.role === 'assistant' &&
delta.content === lastMessage.content.trimEnd()
textContent === lastMessage.content.trimEnd()
) {
// Mistral moves the trailing space from the prefix to the next chunk.
// We trim the leading space to avoid duplication.
if (delta.content.length < lastMessage.content.length) {
if (textContent.length < lastMessage.content.length) {
trimLeadingSpace = true;
}

Expand All @@ -315,12 +322,12 @@ export class MistralChatLanguageModel implements LanguageModelV1 {
}
}

if (delta.content != null) {
if (textContent != null) {
controller.enqueue({
type: 'text-delta',
textDelta: trimLeadingSpace
? delta.content.trimStart()
: delta.content,
? textContent.trimStart()
: textContent,
});

trimLeadingSpace = false;
Expand Down Expand Up @@ -360,6 +367,38 @@ export class MistralChatLanguageModel implements LanguageModelV1 {
}
}

// Pulls the plain text out of a Mistral message/delta `content` value.
// Plain strings pass through unchanged; `{ type: 'text' }` objects yield
// their `text` field; everything else (image_url, reference, null) yields
// undefined so that non-text content is ignored by the caller.
function extractTextContent(content: z.infer<typeof mistralContentSchema>) {
  if (typeof content === 'string') {
    return content;
  }

  if (content != null && content.type === 'text') {
    return content.text;
  }

  return undefined;
}

// Schema for the `content` field on both response messages and stream deltas.
// Mistral may return either a plain string or a single typed content object;
// image and reference content is accepted by the schema but ignored by
// extractTextContent.
//
// NOTE: `.nullish()` (not `.nullable()`) is required here: stream deltas that
// carry only tool calls omit the `content` key entirely (undefined). The
// previous chunk schema used `z.string().nullish()` for exactly this reason,
// so a nullable-only schema would reject those chunks during validation.
const mistralContentSchema = z
  .union([
    z.string(),
    z.object({
      type: z.literal('text'),
      text: z.string(),
    }),
    z.object({
      type: z.literal('image_url'),
      image_url: z.union([
        z.string(),
        z.object({
          url: z.string(),
          detail: z.string().nullable(),
        }),
      ]),
    }),
    z.object({
      type: z.literal('reference'),
      reference_ids: z.array(z.number()),
    }),
  ])
  .nullish();

// Limited version of the schema, focused on what is needed for the implementation.
// This approach limits breakages when the API changes and increases efficiency.
const mistralChatResponseSchema = z.object({
Expand All @@ -370,7 +409,7 @@ const mistralChatResponseSchema = z.object({
z.object({
message: z.object({
role: z.literal('assistant'),
content: z.string().nullable(),
content: mistralContentSchema,
tool_calls: z
.array(
z.object({
Expand Down Expand Up @@ -401,7 +440,7 @@ const mistralChatChunkSchema = z.object({
z.object({
delta: z.object({
role: z.enum(['assistant']).optional(),
content: z.string().nullish(),
content: mistralContentSchema,
tool_calls: z
.array(
z.object({
Expand Down

0 comments on commit e6a7628

Please sign in to comment.