Skip to content

Commit

Permalink
fix: add handling for List[non-object] types (#521)
Browse files Browse the repository at this point in the history
Co-authored-by: Jason Liu <[email protected]>
  • Loading branch information
shreya-51 and jxnl authored Mar 23, 2024
1 parent a9d6cd8 commit cea534f
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 13 deletions.
24 changes: 12 additions & 12 deletions instructor/anthropic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def _add_params(
# TODO: handling of nested params with the same name
properties = model_dict.get("properties", {})
list_found = False
nested_list_found = False

for field_name, details in properties.items():
parameter = ET.SubElement(root, "parameter")
Expand All @@ -74,11 +75,19 @@ def _add_params(
field_type = details.get(
"type", "unknown"
) # Might be better to fail here if there is no type since pydantic models require types

if "array" in field_type and "items" not in details:
raise ValueError("Invalid array item.")

# Adjust type if array
if "array" in field_type or "List" in field_type:
# Check for nested List
if "array" in field_type and "$ref" in details["items"]:
type_element.text = f"List[{details['title']}]"
list_found = True
nested_list_found = True
# Check for non-nested List
elif "array" in field_type and "type" in details["items"]:
type_element.text = f"List[{details['items']['type']}]"
list_found = True
else:
type_element.text = field_type

Expand All @@ -105,22 +114,13 @@ def _add_params(
reference,
references,
)
elif field_type == "array": # Handling for List[] type
elif field_type == "array" and nested_list_found: # Handling for List[] type
nested_params = ET.SubElement(parameter, "parameters")
list_found |= _add_params(
nested_params,
_resolve_reference(references, details["items"]["$ref"]),
references,
)
elif "array" in field_type: # Handling for optional List[] type
nested_params = ET.SubElement(parameter, "parameters")
list_found |= _add_params(
nested_params,
_resolve_reference(
references, details["anyOf"][0]["items"]["$ref"]
), # CHANGE
references,
)

return list_found

Expand Down
1 change: 1 addition & 0 deletions instructor/function_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def anthropic_schema(cls) -> str:
for line in parseString(json_to_xml(cls)).toprettyxml().splitlines()[1:]
)


@classmethod
def from_response(
cls,
Expand Down
3 changes: 2 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions tests/anthropic/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,31 @@ class User(BaseModel):
assert resp.address.street_name == "First Avenue"


def test_list():
class User(BaseModel):
name: str
age: int
family: List[str]

resp = create(
model="claude-3-opus-20240229", # Fails with claude-3-haiku-20240307
max_tokens=1024,
max_retries=0,
messages=[
{
"role": "user",
"content": "Create a user for a model with a name, age, and family members.",
}
],
response_model=User,
)

assert isinstance(resp, User)
assert isinstance(resp.family, List)
for member in resp.family:
assert isinstance(member, str)


def test_nested_list():
class Properties(BaseModel):
key: str
Expand Down

0 comments on commit cea534f

Please sign in to comment.