Skip to content

Commit

Permalink
fix: Adding tests to make sure non-stream iterables work (#413)
Browse files Browse the repository at this point in the history
  • Loading branch information
jxnl authored Feb 7, 2024
1 parent 482143f commit f9389c1
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 5 deletions.
5 changes: 2 additions & 3 deletions instructor/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,7 @@ def process_response(

# ? This really hints at the fact that we need a better way of
# ? attaching usage data and the raw response to the model we return.
if isinstance(response_model, IterableBase):
#! If the response model is a multitask, return the tasks
if isinstance(model, IterableBase):
return [task for task in model.tasks]

if isinstance(response_model, ParallelBase):
Expand Down Expand Up @@ -267,7 +266,7 @@ async def process_response_async(

# ? This really hints at the fact that we need a better way of
# ? attaching usage data and the raw response to the model we return.
if isinstance(response_model, IterableBase):
if isinstance(model, IterableBase):
#! If the response model is a multitask, return the tasks
return [task for task in model.tasks]

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "instructor"
version = "0.5.0"
version = "0.5.1"
description = "structured outputs for llm"
authors = ["Jason Liu <[email protected]>"]
license = "MIT"
Expand Down
68 changes: 67 additions & 1 deletion tests/openai/test_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def test_multi_user(model, mode, client):
def stream_extract(input: str) -> Iterable[User]:
return client.chat.completions.create(
model=model,
stream=True,
response_model=Users,
messages=[
{
Expand Down Expand Up @@ -54,6 +53,73 @@ def stream_extract(input: str) -> Iterable[User]:
async def test_multi_user_tools_mode_async(model, mode, aclient):
client = instructor.patch(aclient, mode=mode)

async def stream_extract(input: str) -> Iterable[User]:
return await client.chat.completions.create(
model=model,
response_model=Users,
messages=[
{
"role": "user",
"content": (
f"Consider the data below:\n{input}"
"Correctly segment it into entitites"
"Make sure the JSON is correct"
),
},
],
max_tokens=1000,
)

resp = []
async for user in await stream_extract(input="Jason is 20, Sarah is 30"):
resp.append(user)
print(resp)
assert len(resp) == 2
assert resp[0].name == "Jason"
assert resp[0].age == 20
assert resp[1].name == "Sarah"
assert resp[1].age == 30


@pytest.mark.parametrize("model, mode", product(models, modes))
def test_multi_user_stream(model, mode, client):
    """Verify that a streamed multi-entity extraction yields each User in order.

    Patches the client for structured output, requests a streamed ``Users``
    extraction, and checks that exactly two users come back with the
    expected names and ages.
    """
    client = instructor.patch(client, mode=mode)

    def stream_extract(input: str) -> Iterable[User]:
        # NOTE: prompt text (including the "entitites" spelling) is kept
        # byte-for-byte — it is part of the request the test exercises.
        prompt_messages = [
            {
                "role": "system",
                "content": "You are a perfect entity extraction system",
            },
            {
                "role": "user",
                "content": (
                    f"Consider the data below:\n{input}"
                    "Correctly segment it into entitites"
                    "Make sure the JSON is correct"
                ),
            },
        ]
        return client.chat.completions.create(
            model=model,
            stream=True,
            response_model=Users,
            messages=prompt_messages,
            max_tokens=1000,
        )

    extracted = list(stream_extract(input="Jason is 20, Sarah is 30"))
    assert len(extracted) == 2
    first, second = extracted
    assert (first.name, first.age) == ("Jason", 20)
    assert (second.name, second.age) == ("Sarah", 30)


@pytest.mark.asyncio
@pytest.mark.parametrize("model, mode", product(models, modes))
async def test_multi_user_tools_mode_async_stream(model, mode, aclient):
client = instructor.patch(aclient, mode=mode)

async def stream_extract(input: str) -> Iterable[User]:
return await client.chat.completions.create(
model=model,
Expand Down

0 comments on commit f9389c1

Please sign in to comment.