-
-
Notifications
You must be signed in to change notification settings - Fork 721
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(response model): introduce handling of simple types (#447)
- Loading branch information
Showing
5 changed files
with
264 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from inspect import isclass | ||
import typing | ||
from pydantic import BaseModel, create_model | ||
from enum import Enum | ||
|
||
|
||
from instructor.dsl.partial import Partial | ||
from instructor.function_calls import OpenAISchema | ||
|
||
|
||
T = typing.TypeVar("T") | ||
|
||
|
||
class AdapterBase(BaseModel): | ||
pass | ||
|
||
|
||
class ModelAdapter(typing.Generic[T]): | ||
""" | ||
Accepts a response model and returns a BaseModel with the response model as the content. | ||
""" | ||
|
||
def __class_getitem__(cls, response_model) -> typing.Type[BaseModel]: | ||
assert is_simple_type(response_model), "Only simple types are supported" | ||
tmp = create_model( | ||
"Response", | ||
content=(response_model, ...), | ||
__doc__="Correctly Formated and Extracted Response.", | ||
__base__=(AdapterBase, OpenAISchema), | ||
) | ||
return tmp | ||
|
||
|
||
def is_simple_type(response_model) -> bool: | ||
# ! we're getting mixes between classes and instances due to how we handle some | ||
# ! response model types, we should fix this in later PRs | ||
if isclass(response_model) and issubclass(response_model, BaseModel): | ||
return False | ||
|
||
if typing.get_origin(response_model) in {typing.Iterable, Partial}: | ||
# These are reserved for streaming types, would be nice to | ||
return False | ||
|
||
if response_model in { | ||
str, | ||
int, | ||
float, | ||
bool, | ||
}: | ||
return True | ||
|
||
# If the response_model is a simple type like annotated | ||
if typing.get_origin(response_model) in { | ||
typing.Annotated, | ||
typing.Literal, | ||
typing.Union, | ||
list, # origin of List[T] is list | ||
}: | ||
return True | ||
|
||
if isclass(response_model) and issubclass(response_model, Enum): | ||
return True | ||
|
||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import pytest | ||
import instructor | ||
import enum | ||
|
||
from typing import Annotated, Literal, Union | ||
from pydantic import Field | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_response_simple_types(aclient): | ||
client = instructor.patch(aclient, mode=instructor.Mode.TOOLS) | ||
|
||
for response_model in [int, bool, str]: | ||
response = await client.chat.completions.create( | ||
model="gpt-3.5-turbo", | ||
response_model=response_model, | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "Produce a Random but correct response given the desired output", | ||
}, | ||
], | ||
) | ||
assert type(response) == response_model | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_annotate(aclient): | ||
client = instructor.patch(aclient, mode=instructor.Mode.TOOLS) | ||
|
||
response = await client.chat.completions.create( | ||
model="gpt-3.5-turbo", | ||
response_model=Annotated[int, Field(description="test")], | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "Produce a Random but correct response given the desired output", | ||
}, | ||
], | ||
) | ||
assert type(response) == int | ||
|
||
|
||
def test_literal(client): | ||
client = instructor.patch(client, mode=instructor.Mode.TOOLS) | ||
|
||
response = client.chat.completions.create( | ||
model="gpt-3.5-turbo", | ||
response_model=Literal["1231", "212", "331"], | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "Produce a Random but correct response given the desired output", | ||
}, | ||
], | ||
) | ||
assert response in ["1231", "212", "331"] | ||
|
||
|
||
def test_union(client): | ||
client = instructor.patch(client, mode=instructor.Mode.TOOLS) | ||
|
||
response = client.chat.completions.create( | ||
model="gpt-3.5-turbo", | ||
response_model=Union[int, str], | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "Produce a Random but correct response given the desired output", | ||
}, | ||
], | ||
) | ||
assert type(response) in [int, str] | ||
|
||
|
||
def test_enum(client): | ||
class Options(enum.Enum): | ||
A = "A" | ||
B = "B" | ||
C = "C" | ||
|
||
client = instructor.patch(client, mode=instructor.Mode.TOOLS) | ||
|
||
response = client.chat.completions.create( | ||
model="gpt-3.5-turbo", | ||
response_model=Options, | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "Produce a Random but correct response given the desired output", | ||
}, | ||
], | ||
) | ||
assert response in [Options.A, Options.B, Options.C] | ||
|
||
|
||
def test_bool(client): | ||
client = instructor.patch(client, mode=instructor.Mode.TOOLS) | ||
|
||
response = client.chat.completions.create( | ||
model="gpt-3.5-turbo", | ||
response_model=bool, | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "Produce a Random but correct response given the desired output", | ||
}, | ||
], | ||
) | ||
assert type(response) == bool |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from instructor.dsl import is_simple_type, Partial | ||
from pydantic import BaseModel | ||
|
||
|
||
def test_enum_simple(): | ||
from enum import Enum | ||
|
||
class Color(Enum): | ||
RED = 1 | ||
GREEN = 2 | ||
BLUE = 3 | ||
|
||
assert is_simple_type(Color), "Failed for type: " + str(Color) | ||
|
||
|
||
def test_standard_types(): | ||
for t in [str, int, float, bool]: | ||
assert is_simple_type(t), "Failed for type: " + str(t) | ||
|
||
|
||
def test_partial_not_simple(): | ||
class SampleModel(BaseModel): | ||
data: int | ||
|
||
assert not is_simple_type(Partial[SampleModel]), "Failed for type: " + str( | ||
Partial[int] | ||
) | ||
|
||
|
||
def test_annotated_simple(): | ||
from pydantic import Field | ||
from typing import Annotated | ||
|
||
new_type = Annotated[int, Field(description="test")] | ||
|
||
assert is_simple_type(new_type), "Failed for type: " + str(new_type) | ||
|
||
|
||
def test_literal_simple(): | ||
from typing import Literal | ||
|
||
new_type = Literal[1, 2, 3] | ||
|
||
assert is_simple_type(new_type), "Failed for type: " + str(new_type) | ||
|
||
|
||
def test_union_simple(): | ||
from typing import Union | ||
|
||
new_type = Union[int, str] | ||
|
||
assert is_simple_type(new_type), "Failed for type: " + str(new_type) | ||
|
||
|
||
def test_iterable_not_simple(): | ||
from typing import Iterable | ||
|
||
new_type = Iterable[int] | ||
|
||
assert not is_simple_type(new_type), "Failed for type: " + str(new_type) | ||
|
||
|
||
def test_list_is_simple(): | ||
from typing import List | ||
|
||
new_type = List[int] | ||
|
||
assert is_simple_type(new_type), "Failed for type: " + str(new_type) |