Skip to content

Commit

Permalink
feat(response model): introduce handling of simple types (#447)
Browse files Browse the repository at this point in the history
  • Loading branch information
jxnl authored Feb 19, 2024
1 parent f29f1bd commit 2319fff
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 1 deletion.
3 changes: 3 additions & 0 deletions instructor/dsl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .partial import Partial
from .validators import llm_validator, openai_moderation
from .citation import CitationMixin
from .simple_type import is_simple_type, ModelAdapter

__all__ = [ # noqa: F405
"CitationMixin",
Expand All @@ -11,4 +12,6 @@
"Partial",
"llm_validator",
"openai_moderation",
"is_simple_type",
"ModelAdapter",
]
64 changes: 64 additions & 0 deletions instructor/dsl/simple_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from inspect import isclass
import typing
from pydantic import BaseModel, create_model
from enum import Enum


from instructor.dsl.partial import Partial
from instructor.function_calls import OpenAISchema


T = typing.TypeVar("T")


class AdapterBase(BaseModel):
pass


class ModelAdapter(typing.Generic[T]):
"""
Accepts a response model and returns a BaseModel with the response model as the content.
"""

def __class_getitem__(cls, response_model) -> typing.Type[BaseModel]:
assert is_simple_type(response_model), "Only simple types are supported"
tmp = create_model(
"Response",
content=(response_model, ...),
__doc__="Correctly Formated and Extracted Response.",
__base__=(AdapterBase, OpenAISchema),
)
return tmp


def is_simple_type(response_model) -> bool:
# ! we're getting mixes between classes and instances due to how we handle some
# ! response model types, we should fix this in later PRs
if isclass(response_model) and issubclass(response_model, BaseModel):
return False

if typing.get_origin(response_model) in {typing.Iterable, Partial}:
# These are reserved for streaming types, would be nice to
return False

if response_model in {
str,
int,
float,
bool,
}:
return True

# If the response_model is a simple type like annotated
if typing.get_origin(response_model) in {
typing.Annotated,
typing.Literal,
typing.Union,
list, # origin of List[T] is list
}:
return True

if isclass(response_model) and issubclass(response_model, Enum):
return True

return False
20 changes: 19 additions & 1 deletion instructor/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from instructor.dsl.iterable import IterableModel, IterableBase
from instructor.dsl.parallel import ParallelBase, ParallelModel, handle_parallel_model
from instructor.dsl.partial import PartialBase
from instructor.dsl.simple_type import ModelAdapter, AdapterBase, is_simple_type

from .function_calls import Mode, OpenAISchema, openai_schema

Expand Down Expand Up @@ -80,6 +81,12 @@ def handle_response_model(
"""
new_kwargs = kwargs.copy()
if response_model is not None:
# Handles the case where the response_model is a simple type
# Literal, Annotated, Union, str, int, float, bool, Enum
# We wrap the response_model in a ModelAdapter that sets 'content' as the response
if is_simple_type(response_model):
response_model = ModelAdapter[response_model]

# This a special case for parallel tools
if mode == Mode.PARALLEL_TOOLS:
assert (
Expand Down Expand Up @@ -213,11 +220,17 @@ def process_response(
# ? This really hints at the fact that we need a better way of
# ? attaching usage data and the raw response to the model we return.
if isinstance(model, IterableBase):
logger.debug(f"Returning takes from IterableBase")
return [task for task in model.tasks]

if isinstance(response_model, ParallelBase):
logger.debug(f"Returning model from ParallelBase")
return model

if isinstance(model, AdapterBase):
logger.debug(f"Returning model from AdapterBase")
return model.content

model._raw_response = response
return model

Expand Down Expand Up @@ -266,12 +279,17 @@ async def process_response_async(
# ? This really hints at the fact that we need a better way of
# ? attaching usage data and the raw response to the model we return.
if isinstance(model, IterableBase):
#! If the response model is a multitask, return the tasks
logger.debug(f"Returning takes from IterableBase")
return [task for task in model.tasks]

if isinstance(response_model, ParallelBase):
logger.debug(f"Returning model from ParallelBase")
return model

if isinstance(model, AdapterBase):
logger.debug(f"Returning model from AdapterBase")
return model.content

model._raw_response = response
return model

Expand Down
110 changes: 110 additions & 0 deletions tests/openai/test_simple_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import pytest
import instructor
import enum

from typing import Annotated, Literal, Union
from pydantic import Field


@pytest.mark.asyncio
async def test_response_simple_types(aclient):
client = instructor.patch(aclient, mode=instructor.Mode.TOOLS)

for response_model in [int, bool, str]:
response = await client.chat.completions.create(
model="gpt-3.5-turbo",
response_model=response_model,
messages=[
{
"role": "user",
"content": "Produce a Random but correct response given the desired output",
},
],
)
assert type(response) == response_model


@pytest.mark.asyncio
async def test_annotate(aclient):
client = instructor.patch(aclient, mode=instructor.Mode.TOOLS)

response = await client.chat.completions.create(
model="gpt-3.5-turbo",
response_model=Annotated[int, Field(description="test")],
messages=[
{
"role": "user",
"content": "Produce a Random but correct response given the desired output",
},
],
)
assert type(response) == int


def test_literal(client):
client = instructor.patch(client, mode=instructor.Mode.TOOLS)

response = client.chat.completions.create(
model="gpt-3.5-turbo",
response_model=Literal["1231", "212", "331"],
messages=[
{
"role": "user",
"content": "Produce a Random but correct response given the desired output",
},
],
)
assert response in ["1231", "212", "331"]


def test_union(client):
client = instructor.patch(client, mode=instructor.Mode.TOOLS)

response = client.chat.completions.create(
model="gpt-3.5-turbo",
response_model=Union[int, str],
messages=[
{
"role": "user",
"content": "Produce a Random but correct response given the desired output",
},
],
)
assert type(response) in [int, str]


def test_enum(client):
class Options(enum.Enum):
A = "A"
B = "B"
C = "C"

client = instructor.patch(client, mode=instructor.Mode.TOOLS)

response = client.chat.completions.create(
model="gpt-3.5-turbo",
response_model=Options,
messages=[
{
"role": "user",
"content": "Produce a Random but correct response given the desired output",
},
],
)
assert response in [Options.A, Options.B, Options.C]


def test_bool(client):
client = instructor.patch(client, mode=instructor.Mode.TOOLS)

response = client.chat.completions.create(
model="gpt-3.5-turbo",
response_model=bool,
messages=[
{
"role": "user",
"content": "Produce a Random but correct response given the desired output",
},
],
)
assert type(response) == bool
68 changes: 68 additions & 0 deletions tests/test_simple_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from instructor.dsl import is_simple_type, Partial
from pydantic import BaseModel


def test_enum_simple():
from enum import Enum

class Color(Enum):
RED = 1
GREEN = 2
BLUE = 3

assert is_simple_type(Color), "Failed for type: " + str(Color)


def test_standard_types():
for t in [str, int, float, bool]:
assert is_simple_type(t), "Failed for type: " + str(t)


def test_partial_not_simple():
class SampleModel(BaseModel):
data: int

assert not is_simple_type(Partial[SampleModel]), "Failed for type: " + str(
Partial[int]
)


def test_annotated_simple():
from pydantic import Field
from typing import Annotated

new_type = Annotated[int, Field(description="test")]

assert is_simple_type(new_type), "Failed for type: " + str(new_type)


def test_literal_simple():
from typing import Literal

new_type = Literal[1, 2, 3]

assert is_simple_type(new_type), "Failed for type: " + str(new_type)


def test_union_simple():
from typing import Union

new_type = Union[int, str]

assert is_simple_type(new_type), "Failed for type: " + str(new_type)


def test_iterable_not_simple():
from typing import Iterable

new_type = Iterable[int]

assert not is_simple_type(new_type), "Failed for type: " + str(new_type)


def test_list_is_simple():
from typing import List

new_type = List[int]

assert is_simple_type(new_type), "Failed for type: " + str(new_type)

0 comments on commit 2319fff

Please sign in to comment.