Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(response model): introduce handling of simple types #447

Merged
merged 3 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions instructor/dsl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .partial import Partial
from .validators import llm_validator, openai_moderation
from .citation import CitationMixin
from .simple_type import is_simple_type, ModelAdapter

__all__ = [ # noqa: F405
"CitationMixin",
Expand All @@ -11,4 +12,6 @@
"Partial",
"llm_validator",
"openai_moderation",
"is_simple_type",
"ModelAdapter",
]
64 changes: 64 additions & 0 deletions instructor/dsl/simple_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from inspect import isclass
import typing
from pydantic import BaseModel, create_model
from enum import Enum


from instructor.dsl.partial import Partial
from instructor.function_calls import OpenAISchema


T = typing.TypeVar("T")


# Marker base class for the pydantic models generated by ModelAdapter.
# It declares no fields of its own; it exists so that adapter-generated
# models can be recognised with an isinstance check.
class AdapterBase(BaseModel):
    pass


class ModelAdapter(typing.Generic[T]):
    """
    Wrap a simple response type in a generated pydantic model.

    Subscripting the class (``ModelAdapter[X]``) does not return a generic
    alias; it builds and returns a new model class named ``Response`` whose
    single required field ``content`` has type ``X``.
    """

    def __class_getitem__(cls, response_model) -> typing.Type[BaseModel]:
        # Only types accepted by is_simple_type may be wrapped.
        assert is_simple_type(response_model), "Only simple types are supported"

        # The generated model derives from AdapterBase (so it can be detected
        # later) and OpenAISchema (so it carries the schema helpers).
        return create_model(
            "Response",
            __base__=(AdapterBase, OpenAISchema),
            __doc__="Correctly Formated and Extracted Response.",
            content=(response_model, ...),
        )


def _is_subclass(candidate, parent) -> bool:
    """Return True only when ``candidate`` is a real class deriving from ``parent``."""
    # On Python 3.9/3.10 parameterized builtins such as ``list[int]`` satisfy
    # ``isclass`` yet make ``issubclass`` raise TypeError, so anything that has
    # a typing origin is rejected before the subclass test is attempted.
    if not isclass(candidate) or typing.get_origin(candidate) is not None:
        return False
    try:
        return issubclass(candidate, parent)
    except TypeError:
        return False


def is_simple_type(response_model) -> bool:
    """
    Decide whether ``response_model`` is a "simple" (non-BaseModel) type.

    Simple types are ``str``, ``int``, ``float``, ``bool``, ``Enum``
    subclasses, and annotations whose origin is ``Annotated``, ``Literal``,
    ``Union`` (including PEP 604 ``X | Y`` unions where available) or ``list``.

    Args:
        response_model: The class or typing annotation to inspect.

    Returns:
        True for a simple type; False for ``BaseModel`` subclasses, for the
        streaming wrappers (``Iterable[...]`` / ``Partial``), and for anything
        else that is not recognised.
    """
    import collections.abc
    import types

    # ! we're getting mixes between classes and instances due to how we handle some
    # ! response model types, we should fix this in later PRs
    if _is_subclass(response_model, BaseModel):
        return False

    origin = typing.get_origin(response_model)

    # These are reserved for streaming types. get_origin(typing.Iterable[X])
    # is collections.abc.Iterable, so that is the value that has to be matched;
    # typing.Iterable is kept for backward compatibility.
    if origin in {typing.Iterable, collections.abc.Iterable, Partial}:
        return False

    # Identity comparison avoids hashing response_model, which would raise
    # TypeError for unhashable inputs.
    if any(response_model is primitive for primitive in (str, int, float, bool)):
        return True

    # If the response_model is a simple type like annotated
    simple_origins = {
        typing.Annotated,
        typing.Literal,
        typing.Union,
        list,  # origin of List[T] is list
    }
    # ``int | str`` reports types.UnionType as its origin (Python 3.10+).
    union_type = getattr(types, "UnionType", None)
    if union_type is not None:
        simple_origins.add(union_type)
    if origin in simple_origins:
        return True

    return _is_subclass(response_model, Enum)
20 changes: 19 additions & 1 deletion instructor/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from instructor.dsl.iterable import IterableModel, IterableBase
from instructor.dsl.parallel import ParallelBase, ParallelModel, handle_parallel_model
from instructor.dsl.partial import PartialBase
from instructor.dsl.simple_type import ModelAdapter, AdapterBase, is_simple_type

from .function_calls import Mode, OpenAISchema, openai_schema

Expand Down Expand Up @@ -80,6 +81,12 @@ def handle_response_model(
"""
new_kwargs = kwargs.copy()
if response_model is not None:
# Handles the case where the response_model is a simple type
# Literal, Annotated, Union, str, int, float, bool, Enum
# We wrap the response_model in a ModelAdapter that sets 'content' as the response
if is_simple_type(response_model):
response_model = ModelAdapter[response_model]

# This is a special case for parallel tools
if mode == Mode.PARALLEL_TOOLS:
assert (
Expand Down Expand Up @@ -213,11 +220,17 @@ def process_response(
# ? This really hints at the fact that we need a better way of
# ? attaching usage data and the raw response to the model we return.
if isinstance(model, IterableBase):
logger.debug(f"Returning takes from IterableBase")
return [task for task in model.tasks]

if isinstance(response_model, ParallelBase):
logger.debug(f"Returning model from ParallelBase")
return model

if isinstance(model, AdapterBase):
logger.debug(f"Returning model from AdapterBase")
return model.content

model._raw_response = response
return model

Expand Down Expand Up @@ -266,12 +279,17 @@ async def process_response_async(
# ? This really hints at the fact that we need a better way of
# ? attaching usage data and the raw response to the model we return.
if isinstance(model, IterableBase):
#! If the response model is a multitask, return the tasks
logger.debug(f"Returning takes from IterableBase")
return [task for task in model.tasks]

if isinstance(response_model, ParallelBase):
logger.debug(f"Returning model from ParallelBase")
return model

if isinstance(model, AdapterBase):
logger.debug(f"Returning model from AdapterBase")
return model.content

model._raw_response = response
return model

Expand Down
110 changes: 110 additions & 0 deletions tests/openai/test_simple_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import pytest
import instructor
import enum

from typing import Annotated, Literal, Union
from pydantic import Field


@pytest.mark.asyncio
async def test_response_simple_types(aclient):
    """Primitive response models come back as bare values of exactly that type."""
    patched = instructor.patch(aclient, mode=instructor.Mode.TOOLS)
    prompt_text = "Produce a Random but correct response given the desired output"

    for expected_type in (int, bool, str):
        # A fresh message list is built for every request.
        result = await patched.chat.completions.create(
            model="gpt-3.5-turbo",
            response_model=expected_type,
            messages=[{"role": "user", "content": prompt_text}],
        )
        assert type(result) is expected_type


@pytest.mark.asyncio
async def test_annotate(aclient):
    """An Annotated[int, ...] response model yields a plain int."""
    patched = instructor.patch(aclient, mode=instructor.Mode.TOOLS)
    annotated_int = Annotated[int, Field(description="test")]
    prompt_text = "Produce a Random but correct response given the desired output"

    result = await patched.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=annotated_int,
        messages=[{"role": "user", "content": prompt_text}],
    )
    assert type(result) is int


def test_literal(client):
    """A Literal response model yields one of its allowed values."""
    patched = instructor.patch(client, mode=instructor.Mode.TOOLS)
    prompt_text = "Produce a Random but correct response given the desired output"

    result = patched.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=Literal["1231", "212", "331"],
        messages=[{"role": "user", "content": prompt_text}],
    )

    allowed = ["1231", "212", "331"]
    assert result in allowed


def test_union(client):
    """A Union[int, str] response model yields exactly an int or a str."""
    patched = instructor.patch(client, mode=instructor.Mode.TOOLS)
    prompt_text = "Produce a Random but correct response given the desired output"

    result = patched.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=Union[int, str],
        messages=[{"role": "user", "content": prompt_text}],
    )

    result_type = type(result)
    assert result_type in [int, str]


def test_enum(client):
    """An Enum response model yields one of the enum's members."""

    class Options(enum.Enum):
        A = "A"
        B = "B"
        C = "C"

    patched = instructor.patch(client, mode=instructor.Mode.TOOLS)
    prompt_text = "Produce a Random but correct response given the desired output"

    result = patched.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=Options,
        messages=[{"role": "user", "content": prompt_text}],
    )

    # Iterating the Enum lists its members in definition order: A, B, C.
    assert result in list(Options)


def test_bool(client):
    """A bool response model yields a real bool."""
    patched = instructor.patch(client, mode=instructor.Mode.TOOLS)
    prompt_text = "Produce a Random but correct response given the desired output"

    result = patched.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=bool,
        messages=[{"role": "user", "content": prompt_text}],
    )
    assert type(result) is bool
68 changes: 68 additions & 0 deletions tests/test_simple_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from instructor.dsl import is_simple_type, Partial
from pydantic import BaseModel


def test_enum_simple():
    """Enum subclasses count as simple types."""
    from enum import Enum

    class Color(Enum):
        RED = 1
        GREEN = 2
        BLUE = 3

    verdict = is_simple_type(Color)
    assert verdict, f"Failed for type: {Color!s}"


def test_standard_types():
    """The four builtin scalar types are all simple."""
    builtin_scalars = (str, int, float, bool)
    for scalar_type in builtin_scalars:
        assert is_simple_type(scalar_type), f"Failed for type: {scalar_type!s}"


def test_partial_not_simple():
    """Partial[...] wrappers around a model must not be classified as simple."""

    class SampleModel(BaseModel):
        data: int

    partial_type = Partial[SampleModel]

    # Report the type that was actually checked. The message previously
    # stringified Partial[int], which is a different type from the one under test.
    assert not is_simple_type(partial_type), "Failed for type: " + str(partial_type)


def test_annotated_simple():
    """Annotated[...] aliases are simple types."""
    from pydantic import Field
    from typing import Annotated

    annotated_int = Annotated[int, Field(description="test")]
    verdict = is_simple_type(annotated_int)

    assert verdict, f"Failed for type: {annotated_int!s}"


def test_literal_simple():
    """Literal[...] aliases are simple types."""
    from typing import Literal

    literal_type = Literal[1, 2, 3]
    verdict = is_simple_type(literal_type)

    assert verdict, f"Failed for type: {literal_type!s}"


def test_union_simple():
    """Union[...] aliases are simple types."""
    from typing import Union

    union_type = Union[int, str]
    verdict = is_simple_type(union_type)

    assert verdict, f"Failed for type: {union_type!s}"


def test_iterable_not_simple():
    """Iterable[...] is reserved for streaming and is therefore not simple."""
    from typing import Iterable

    iterable_type = Iterable[int]
    verdict = is_simple_type(iterable_type)

    assert not verdict, f"Failed for type: {iterable_type!s}"


def test_list_is_simple():
    """List[...] aliases are simple types."""
    from typing import List

    list_type = List[int]
    verdict = is_simple_type(list_type)

    assert verdict, f"Failed for type: {list_type!s}"
Loading