-
-
Notifications
You must be signed in to change notification settings - Fork 724
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(response model): introduce handling of simple types #447
Merged
Merged
Changes from 2 commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from inspect import isclass | ||
import typing | ||
from pydantic import BaseModel, create_model | ||
from enum import Enum | ||
|
||
|
||
from instructor.dsl.partial import Partial | ||
from instructor.function_calls import OpenAISchema | ||
|
||
|
||
T = typing.TypeVar("T") | ||
|
||
|
||
class AdapterBase(BaseModel): | ||
pass | ||
|
||
|
||
class ModelAdapter(typing.Generic[T]): | ||
""" | ||
Accepts a response model and returns a BaseModel with the response model as the content. | ||
""" | ||
|
||
def __class_getitem__(cls, response_model) -> typing.Type[BaseModel]: | ||
assert is_simple_type(response_model), "Only simple types are supported" | ||
tmp = create_model( | ||
"Response", | ||
content=(response_model, ...), | ||
__doc__="Correctly Formated and Extracted Response.", | ||
__base__=(AdapterBase, OpenAISchema), | ||
) | ||
return tmp | ||
|
||
|
||
def is_simple_type(response_model) -> bool: | ||
# ! we're getting mixes between classes and instances due to how we handle some | ||
# ! response model types, we should fix this in later PRs | ||
if isclass(response_model) and issubclass(response_model, BaseModel): | ||
return False | ||
|
||
if typing.get_origin(response_model) in {typing.Iterable, Partial}: | ||
# These are reserved for streaming types, would be nice to | ||
return False | ||
|
||
if response_model in { | ||
str, | ||
int, | ||
float, | ||
bool, | ||
}: | ||
return True | ||
|
||
# If the response_model is a simple type like annotated | ||
if typing.get_origin(response_model) in { | ||
typing.Annotated, | ||
typing.Literal, | ||
typing.Union, | ||
list, # origin of List[T] is list | ||
}: | ||
return True | ||
|
||
if isclass(response_model) and issubclass(response_model, Enum): | ||
return True | ||
|
||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import pytest | ||
import instructor | ||
import enum | ||
|
||
from typing import Annotated, Literal, Union | ||
from pydantic import Field | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_response_simple_types(aclient): | ||
client = instructor.patch(aclient, mode=instructor.Mode.TOOLS) | ||
|
||
for response_model in [int, bool, str]: | ||
response = await client.chat.completions.create( | ||
model="gpt-3.5-turbo", | ||
response_model=response_model, | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "Product a Random but correct response given the desired output", | ||
}, | ||
], | ||
) | ||
assert type(response) == response_model | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_annotate(aclient): | ||
client = instructor.patch(aclient, mode=instructor.Mode.TOOLS) | ||
|
||
response = await client.chat.completions.create( | ||
model="gpt-3.5-turbo", | ||
response_model=Annotated[int, Field(description="test")], | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "Product a Random but correct response given the desired output", | ||
}, | ||
], | ||
) | ||
assert type(response) == int | ||
|
||
|
||
def test_literal(client): | ||
client = instructor.patch(client, mode=instructor.Mode.TOOLS) | ||
|
||
response = client.chat.completions.create( | ||
model="gpt-3.5-turbo", | ||
response_model=Literal["1231", "212", "331"], | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "Product a Random but correct response given the desired output", | ||
}, | ||
], | ||
) | ||
assert response in ["1231", "212", "331"] | ||
|
||
|
||
def test_union(client): | ||
client = instructor.patch(client, mode=instructor.Mode.TOOLS) | ||
|
||
response = client.chat.completions.create( | ||
model="gpt-3.5-turbo", | ||
response_model=Union[int, str], | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "Product a Random but correct response given the desired output", | ||
}, | ||
], | ||
) | ||
assert type(response) in [int, str] | ||
|
||
|
||
def test_enum(client): | ||
class Options(enum.Enum): | ||
A = "A" | ||
B = "B" | ||
C = "C" | ||
|
||
client = instructor.patch(client, mode=instructor.Mode.TOOLS) | ||
|
||
response = client.chat.completions.create( | ||
model="gpt-3.5-turbo", | ||
response_model=Options, | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "Product a Random but correct response given the desired output", | ||
}, | ||
], | ||
) | ||
assert response in [Options.A, Options.B, Options.C] | ||
|
||
|
||
def test_bool(client): | ||
client = instructor.patch(client, mode=instructor.Mode.TOOLS) | ||
|
||
response = client.chat.completions.create( | ||
model="gpt-3.5-turbo", | ||
response_model=bool, | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "Product a Random but correct response given the desired output", | ||
}, | ||
], | ||
) | ||
assert type(response) == bool |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from instructor.dsl import is_simple_type, Partial | ||
from pydantic import BaseModel | ||
|
||
|
||
def test_enum_simple(): | ||
from enum import Enum | ||
|
||
class Color(Enum): | ||
RED = 1 | ||
GREEN = 2 | ||
BLUE = 3 | ||
|
||
assert is_simple_type(Color), "Failed for type: " + str(Color) | ||
|
||
|
||
def test_standard_types(): | ||
for t in [str, int, float, bool]: | ||
assert is_simple_type(t), "Failed for type: " + str(t) | ||
|
||
|
||
def test_partial_not_simple(): | ||
class SampleModel(BaseModel): | ||
data: int | ||
|
||
assert not is_simple_type(Partial[SampleModel]), "Failed for type: " + str( | ||
Partial[int] | ||
) | ||
|
||
|
||
def test_annotated_simple(): | ||
from pydantic import Field | ||
from typing import Annotated | ||
|
||
new_type = Annotated[int, Field(description="test")] | ||
|
||
assert is_simple_type(new_type), "Failed for type: " + str(new_type) | ||
|
||
|
||
def test_literal_simple(): | ||
from typing import Literal | ||
|
||
new_type = Literal[1, 2, 3] | ||
|
||
assert is_simple_type(new_type), "Failed for type: " + str(new_type) | ||
|
||
|
||
def test_union_simple(): | ||
from typing import Union | ||
|
||
new_type = Union[int, str] | ||
|
||
assert is_simple_type(new_type), "Failed for type: " + str(new_type) | ||
|
||
|
||
def test_iterable_not_simple(): | ||
from typing import Iterable | ||
|
||
new_type = Iterable[int] | ||
|
||
assert not is_simple_type(new_type), "Failed for type: " + str(new_type) | ||
|
||
|
||
def test_list_is_simple(): | ||
from typing import List | ||
|
||
new_type = List[int] | ||
|
||
assert is_simple_type(new_type), "Failed for type: " + str(new_type) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"Product a"