Skip to content

Commit

Permalink
feat (ai/core): support zod transformers in generateObject & streamOb…
Browse files Browse the repository at this point in the history
…ject (vercel#2549)
  • Loading branch information
lgrammel authored Aug 5, 2024
1 parent be6a179 commit 0762a22
Show file tree
Hide file tree
Showing 10 changed files with 132 additions and 21 deletions.
6 changes: 6 additions & 0 deletions .changeset/perfect-chairs-bathe.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
'@ai-sdk/provider-utils': patch
'ai': patch
---

feat (ai/core): support zod transformers in generateObject & streamObject
26 changes: 15 additions & 11 deletions content/docs/03-ai-sdk-core/10-generating-structured-data.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -93,21 +93,25 @@ console.log('users', users);

### Dates

Zod expects JavaScript Date objects, but most models return dates as strings. You can use the `z.string().datetime()` method to specify and validate datetime strings.
Zod expects JavaScript Date objects, but models return dates as strings.
You can specify and validate the date format using `z.string().datetime()` or `z.string().date()`,
and then use a Zod transformer to convert the string to a Date object.

```ts highlight="6-9"
```ts highlight="7-10"
const result = await generateObject({
model: openai('gpt-4o'),
model: openai('gpt-4-turbo'),
schema: z.object({
user: z.object({
login: z.string(),
lastSeen: z
.string()
.datetime()
.describe('Last time the user was seen (ISO 8601 date string().'),
}),
events: z.array(
z.object({
event: z.string(),
date: z
.string()
.date()
.transform(value => new Date(value)),
}),
),
}),
prompt: 'Generate a fake user profile for testing.',
prompt: 'List 5 important events from the the year 2000.',
});
```

Expand Down
30 changes: 30 additions & 0 deletions examples/ai-core/src/generate-object/openai-date-parsing.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import { openai } from '@ai-sdk/openai';
import { generateObject } from 'ai';
import dotenv from 'dotenv';
import { z } from 'zod';

dotenv.config();

async function main() {
const {
object: { events },
} = await generateObject({
model: openai('gpt-4-turbo'),
schema: z.object({
events: z.array(
z.object({
date: z
.string()
.date()
.transform(value => new Date(value)),
event: z.string(),
}),
),
}),
prompt: 'List 5 important events from the the year 2000.',
});

console.log(events);
}

main().catch(console.error);
69 changes: 69 additions & 0 deletions packages/core/core/generate-object/generate-object.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -494,3 +494,72 @@ describe('custom schema', () => {
assert.deepStrictEqual(result.object, { content: 'Hello, world!' });
});
});

describe('zod schema', () => {
it('should generate object when using zod transform', async () => {
const result = await generateObject({
model: new MockLanguageModelV1({
doGenerate: async ({ prompt, mode }) => {
assert.deepStrictEqual(mode, { type: 'object-json' });
assert.deepStrictEqual(prompt, [
{
role: 'system',
content:
'JSON schema:\n' +
'{"type":"object","properties":{"content":{"type":"string"}},"required":["content"],"additionalProperties":false,"$schema":"http://json-schema.org/draft-07/schema#"}\n' +
'You MUST answer with a JSON object that matches the JSON schema above.',
},
{ role: 'user', content: [{ type: 'text', text: 'prompt' }] },
]);

return {
...dummyResponseValues,
text: `{ "content": "Hello, world!" }`,
};
},
}),
schema: z.object({
content: z.string().transform(value => value.length),
}),
mode: 'json',
prompt: 'prompt',
});

assert.deepStrictEqual(result.object, { content: 13 });
});

it('should generate object with tool mode when using zod prePreprocess', async () => {
const result = await generateObject({
model: new MockLanguageModelV1({
doGenerate: async ({ prompt, mode }) => {
assert.deepStrictEqual(mode, { type: 'object-json' });
assert.deepStrictEqual(prompt, [
{
role: 'system',
content:
'JSON schema:\n' +
'{"type":"object","properties":{"content":{"type":"string"}},"required":["content"],"additionalProperties":false,"$schema":"http://json-schema.org/draft-07/schema#"}\n' +
'You MUST answer with a JSON object that matches the JSON schema above.',
},
{ role: 'user', content: [{ type: 'text', text: 'prompt' }] },
]);

return {
...dummyResponseValues,
text: `{ "content": "Hello, world!" }`,
};
},
}),
schema: z.object({
content: z.preprocess(
val => (typeof val === 'number' ? String(val) : val),
z.string(),
),
}),
mode: 'json',
prompt: 'prompt',
});

assert.deepStrictEqual(result.object, { content: 'Hello, world!' });
});
});
2 changes: 1 addition & 1 deletion packages/core/core/generate-object/generate-object.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ The language model to use.
/**
The schema of the object that the model should generate.
*/
schema: z.Schema<T> | Schema<T>;
schema: z.Schema<T, z.ZodTypeDef, any> | Schema<T>;

/**
The mode to use for object generation.
Expand Down
4 changes: 2 additions & 2 deletions packages/core/core/generate-object/stream-object.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ The language model to use.
/**
The schema of the object that the model should generate.
*/
schema: z.Schema<T> | Schema<T>;
schema: z.Schema<T, z.ZodTypeDef, any> | Schema<T>;

/**
The mode to use for object generation.
Expand Down Expand Up @@ -382,7 +382,7 @@ class DefaultStreamObjectResult<T> implements StreamObjectResult<T> {
>;
warnings: StreamObjectResult<T>['warnings'];
rawResponse?: StreamObjectResult<T>['rawResponse'];
schema: z.Schema<T> | Schema<T>;
schema: z.Schema<T, z.ZodTypeDef, any> | Schema<T>;
onFinish: Parameters<typeof streamObject<T>>[0]['onFinish'];
rootSpan: Span;
doStreamSpan: Span;
Expand Down
6 changes: 4 additions & 2 deletions packages/core/core/util/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,14 @@ function isSchema(value: unknown): value is Schema {
}

export function asSchema<OBJECT>(
schema: z.Schema<OBJECT> | Schema<OBJECT>,
schema: z.Schema<OBJECT, z.ZodTypeDef, any> | Schema<OBJECT>,
): Schema<OBJECT> {
return isSchema(schema) ? schema : zodSchema(schema);
}

export function zodSchema<OBJECT>(zodSchema: z.Schema<OBJECT>): Schema<OBJECT> {
export function zodSchema<OBJECT>(
zodSchema: z.Schema<OBJECT, z.ZodTypeDef, any>,
): Schema<OBJECT> {
return jsonSchema(
// we assume that zodToJsonSchema will return a valid JSONSchema7:
zodToJsonSchema(zodSchema) as JSONSchema7,
Expand Down
6 changes: 3 additions & 3 deletions packages/provider-utils/src/validate-types.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { TypeValidationError } from '@ai-sdk/provider';
import { ZodSchema } from 'zod';
import { z } from 'zod';
import { Validator, isValidator, zodValidator } from './validator';

/**
Expand All @@ -16,7 +16,7 @@ export function validateTypes<T>({
schema: inputSchema,
}: {
value: unknown;
schema: ZodSchema<T> | Validator<T>;
schema: z.Schema<T, z.ZodTypeDef, any> | Validator<T>;
}): T {
const result = safeValidateTypes({ value, schema: inputSchema });

Expand All @@ -41,7 +41,7 @@ export function safeValidateTypes<T>({
schema: inputSchema,
}: {
value: unknown;
schema: ZodSchema<T> | Validator<T>;
schema: z.Schema<T, z.ZodTypeDef, any> | Validator<T>;
}):
| { success: true; value: T }
| { success: false; error: TypeValidationError } {
Expand Down
2 changes: 1 addition & 1 deletion packages/provider-utils/src/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ export function isValidator(value: unknown): value is Validator {
}

export function zodValidator<OBJECT>(
zodSchema: z.Schema<OBJECT>,
zodSchema: z.Schema<OBJECT, z.ZodTypeDef, any>,
): Validator<OBJECT> {
return validator(value => {
const result = zodSchema.safeParse(value);
Expand Down
2 changes: 1 addition & 1 deletion packages/react/src/use-object.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export type Experimental_UseObjectOptions<RESULT> = {
/**
* A Zod schema that defines the shape of the complete object.
*/
schema: z.Schema<RESULT>;
schema: z.Schema<RESULT, z.ZodTypeDef, any>;

/**
* An unique identifier. If not provided, a random one will be
Expand Down

0 comments on commit 0762a22

Please sign in to comment.