Skip to content

Commit

Permalink
feat: support for Sonnet 3.7 Reasoning for ANTHROPIC_JSON (#1361)
Browse files Browse the repository at this point in the history
  • Loading branch information
A-F-V authored Mar 3, 2025
1 parent c6b8275 commit d38c8d0
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 70 deletions.
12 changes: 5 additions & 7 deletions instructor/client_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,11 @@ def from_anthropic(
TypeError: If enable_prompt_caching is True and client is not Anthropic or AsyncAnthropic
AssertionError: If mode is not ANTHROPIC_JSON or ANTHROPIC_TOOLS
"""
assert (
mode
in {
instructor.Mode.ANTHROPIC_JSON,
instructor.Mode.ANTHROPIC_TOOLS,
}
), "Mode be one of {instructor.Mode.ANTHROPIC_JSON, instructor.Mode.ANTHROPIC_TOOLS}"
assert mode in {
instructor.Mode.ANTHROPIC_JSON,
instructor.Mode.ANTHROPIC_TOOLS,
instructor.Mode.ANTHROPIC_REASONING_TOOLS,
}, "Mode be one of {instructor.Mode.ANTHROPIC_JSON, instructor.Mode.ANTHROPIC_TOOLS, instructor.Mode.ANTHROPIC_REASONING_TOOLS}"

assert isinstance(
client,
Expand Down
6 changes: 4 additions & 2 deletions instructor/function_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def from_response(
Returns:
cls (OpenAISchema): An instance of the class
"""
if mode == Mode.ANTHROPIC_TOOLS:
if mode == Mode.ANTHROPIC_TOOLS or mode == Mode.ANTHROPIC_REASONING_TOOLS:
return cls.parse_anthropic_tools(completion, validation_context, strict)

if mode == Mode.ANTHROPIC_JSON:
Expand Down Expand Up @@ -226,7 +226,9 @@ def parse_anthropic_json(
assert isinstance(completion, Message)
if completion.stop_reason == "max_tokens":
raise IncompleteOutputException(last_completion=completion)
text = completion.content[0].text
# Find the first text block
text_blocks = [c for c in completion.content if c.type == "text"]
text = text_blocks[0].text

extra_text = extract_json_from_codeblock(text)

Expand Down
1 change: 1 addition & 0 deletions instructor/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class Mode(enum.Enum):
MD_JSON = "markdown_json_mode"
JSON_SCHEMA = "json_schema_mode"
ANTHROPIC_TOOLS = "anthropic_tools"
ANTHROPIC_REASONING_TOOLS = "anthropic_reasoning_tools"
ANTHROPIC_JSON = "anthropic_json"
COHERE_TOOLS = "cohere_tools"
VERTEXAI_TOOLS = "vertexai_tools"
Expand Down
14 changes: 11 additions & 3 deletions instructor/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,9 @@ def to_openai(self) -> dict[str, Any]:
class Audio(BaseModel):
"""Represents an audio that can be loaded from a URL or file path."""

source: str | Path = Field(description="URL or file path of the audio") # noqa: UP007
source: str | Path = Field(
description="URL or file path of the audio"
) # noqa: UP007
data: Union[str, None] = Field( # noqa: UP007
None, description="Base64 encoded audio data", repr=False
)
Expand Down Expand Up @@ -293,7 +295,11 @@ def convert_contents(
elif isinstance(content, dict):
converted_contents.append(content)
elif isinstance(content, (Image, Audio)):
if mode in {Mode.ANTHROPIC_JSON, Mode.ANTHROPIC_TOOLS}:
if mode in {
Mode.ANTHROPIC_JSON,
Mode.ANTHROPIC_TOOLS,
Mode.ANTHROPIC_REASONING_TOOLS,
}:
converted_contents.append(content.to_anthropic())
elif mode in {Mode.GEMINI_JSON, Mode.GEMINI_TOOLS}:
raise NotImplementedError("Gemini is not supported yet")
Expand Down Expand Up @@ -339,7 +345,9 @@ def is_image_params(x: Any) -> bool:
}
if autodetect_images:
if isinstance(content, list):
new_content: list[str | dict[str, Any] | Image | Audio] = [] # noqa: UP007
new_content: list[str | dict[str, Any] | Image | Audio] = (
[]
) # noqa: UP007
for item in content:
if isinstance(item, str):
new_content.append(Image.autodetect_safely(item))
Expand Down
43 changes: 34 additions & 9 deletions instructor/process_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
from instructor.mode import Mode
from instructor.dsl.iterable import IterableBase, IterableModel
from instructor.dsl.parallel import (
ParallelBase,
ParallelModel,
handle_parallel_model,
ParallelBase,
ParallelModel,
handle_parallel_model,
get_types_array,
VertexAIParallelBase,
VertexAIParallelModel
VertexAIParallelModel,
)
from instructor.dsl.partial import PartialBase
from instructor.dsl.simple_type import AdapterBase, ModelAdapter, is_simple_type
Expand Down Expand Up @@ -357,6 +357,30 @@ def handle_anthropic_tools(
return response_model, new_kwargs


def handle_anthropic_reasoning_tools(
    response_model: type[T], new_kwargs: dict[str, Any]
) -> tuple[type[T], dict[str, Any]]:
    """Prepare request kwargs for Anthropic tool use with extended thinking.

    Builds on ``handle_anthropic_tools`` but relaxes ``tool_choice`` to
    ``auto``, because Anthropic does not allow forced tool use while
    reasoning (extended thinking) is enabled:
    https://docs.anthropic.com/en/docs/build-with-claude/tool-use/overview#forcing-tool-use

    Args:
        response_model: The model type the tool call should produce.
        new_kwargs: Keyword arguments destined for the Anthropic API.

    Returns:
        The (possibly wrapped) response model and the updated kwargs.
    """
    response_model, new_kwargs = handle_anthropic_tools(response_model, new_kwargs)

    # Reasoning does not allow forced tool use, so fall back to "auto".
    new_kwargs["tool_choice"] = {"type": "auto"}

    # Since the tool cannot be forced, nudge the model via the system
    # prompt to emit only the tool call. (Was an f-string with no
    # placeholders — Ruff F541 — and a typo'd variable name.)
    implicit_forced_tool_message = dedent(
        """
        Return only the tool call and no additional text.
        """
    )
    new_kwargs["system"] = combine_system_messages(
        new_kwargs.get("system"),
        [{"type": "text", "text": implicit_forced_tool_message}],
    )
    return response_model, new_kwargs


def handle_anthropic_json(
response_model: type[T], new_kwargs: dict[str, Any]
) -> tuple[type[T], dict[str, Any]]:
Expand Down Expand Up @@ -498,17 +522,17 @@ def handle_vertexai_parallel_tools(
assert (
new_kwargs.get("stream", False) is False
), "stream=True is not supported when using PARALLEL_TOOLS mode"

from instructor.client_vertexai import vertexai_process_response

# Extract concrete types before passing to vertexai_process_response
model_types = list(get_types_array(response_model))
contents, tools, tool_config = vertexai_process_response(new_kwargs, model_types)

new_kwargs["contents"] = contents
new_kwargs["tools"] = tools
new_kwargs["tool_config"] = tool_config

return VertexAIParallelModel(typehint=response_model), new_kwargs


Expand Down Expand Up @@ -612,7 +636,7 @@ def handle_cohere_tools(


def handle_writer_tools(
response_model: type[T], new_kwargs: dict[str, Any]
response_model: type[T], new_kwargs: dict[str, Any]
) -> tuple[type[T], dict[str, Any]]:
new_kwargs["tools"] = [
{
Expand Down Expand Up @@ -745,6 +769,7 @@ def handle_response_model(
Mode.MD_JSON: lambda rm, nk: handle_json_modes(rm, nk, Mode.MD_JSON), # type: ignore
Mode.JSON_SCHEMA: lambda rm, nk: handle_json_modes(rm, nk, Mode.JSON_SCHEMA), # type: ignore
Mode.ANTHROPIC_TOOLS: handle_anthropic_tools,
Mode.ANTHROPIC_REASONING_TOOLS: handle_anthropic_reasoning_tools,
Mode.ANTHROPIC_JSON: handle_anthropic_json,
Mode.COHERE_JSON_SCHEMA: handle_cohere_json_schema,
Mode.COHERE_TOOLS: handle_cohere_tools,
Expand Down
1 change: 1 addition & 0 deletions instructor/reask.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ def handle_reask_kwargs(

functions = {
Mode.ANTHROPIC_TOOLS: reask_anthropic_tools,
Mode.ANTHROPIC_REASONING_TOOLS: reask_anthropic_tools,
Mode.ANTHROPIC_JSON: reask_anthropic_json,
Mode.COHERE_TOOLS: reask_cohere_tools,
Mode.COHERE_JSON_SCHEMA: reask_cohere_tools, # Same Function
Expand Down
51 changes: 13 additions & 38 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
[project]
authors = [
{name = "Jason Liu", email = "[email protected]"},
]
license = {text = "MIT"}
authors = [{ name = "Jason Liu", email = "[email protected]" }]
license = { text = "MIT" }
requires-python = "<4.0,>=3.9"
dependencies = [
"openai<2.0.0,>=1.52.0",
Expand Down Expand Up @@ -59,33 +57,17 @@ test-docs = [
"litellm<2.0.0,>=1.35.31",
"mistralai<2.0.0,>=1.0.3",
]
anthropic = [
"anthropic==0.42.0",
"xmltodict<0.15,>=0.13",
]
groq = [
"groq<0.14.0,>=0.4.2",
]
cohere = [
"cohere<6.0.0,>=5.1.8",
]
anthropic = ["anthropic==0.47.2", "xmltodict<0.15,>=0.13"]
groq = ["groq<0.14.0,>=0.4.2"]
cohere = ["cohere<6.0.0,>=5.1.8"]
google-generativeai = [
"google-generativeai<1.0.0,>=0.8.2",
"jsonref<2.0.0,>=1.1.0",
]
vertexai = [
"google-cloud-aiplatform<2.0.0,>=1.53.0",
"jsonref<2.0.0,>=1.1.0",
]
cerebras_cloud_sdk = [
"cerebras-cloud-sdk<2.0.0,>=1.5.0",
]
fireworks-ai = [
"fireworks-ai<1.0.0,>=0.15.4",
]
writer = [
"writer-sdk<2.0.0,>=1.2.0",
]
vertexai = ["google-cloud-aiplatform<2.0.0,>=1.53.0", "jsonref<2.0.0,>=1.1.0"]
cerebras_cloud_sdk = ["cerebras-cloud-sdk<2.0.0,>=1.5.0"]
fireworks-ai = ["fireworks-ai<1.0.0,>=0.15.4"]
writer = ["writer-sdk<2.0.0,>=1.2.0"]

[project.scripts]
instructor = "instructor.cli.cli:app"
Expand All @@ -112,9 +94,7 @@ docs = [
"mkdocs-redirects<2.0.0,>=1.2.1",
"material>=0.1",
]
anthropic = [
"anthropic==0.42.0",
]
anthropic = ["anthropic==0.47.2"]
test-docs = [
"fastapi<0.116.0,>=0.109.2",
"redis<6.0.0,>=5.0.1",
Expand All @@ -123,7 +103,7 @@ test-docs = [
"tabulate<1.0.0,>=0.9.0",
"pydantic-extra-types<3.0.0,>=2.6.0",
"litellm<2.0.0,>=1.35.31",
"anthropic==0.42.0",
"anthropic==0.47.2",
"xmltodict<0.15,>=0.13",
"groq<0.14.0,>=0.4.2",
"phonenumbers<9.0.0,>=8.13.33",
Expand All @@ -138,14 +118,9 @@ test-docs = [
"datasets<4.0.0,>=3.0.1",
"writer-sdk<2.0.0,>=1.2.0",
]
litellm = [
"litellm<2.0.0,>=1.35.31",
]
litellm = ["litellm<2.0.0,>=1.35.31"]
google-generativeai = [
"google-generativeai<1.0.0,>=0.8.2",
"jsonref<2.0.0,>=1.1.0",
]
vertexai = [
"google-cloud-aiplatform<2.0.0,>=1.53.0",
"jsonref<2.0.0,>=1.1.0",
]
vertexai = ["google-cloud-aiplatform<2.0.0,>=1.53.0", "jsonref<2.0.0,>=1.1.0"]
39 changes: 39 additions & 0 deletions tests/llm/test_anthropic/test_reasoning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import anthropic
import pytest
import instructor
from pydantic import BaseModel
from itertools import product

Check failure on line 5 in tests/llm/test_anthropic/test_reasoning.py

View workflow job for this annotation

GitHub Actions / Ruff (ubuntu-latest)

Ruff (F401)

tests/llm/test_anthropic/test_reasoning.py:5:23: F401 `itertools.product` imported but unused
from anthropic.types.message import Message

Check failure on line 6 in tests/llm/test_anthropic/test_reasoning.py

View workflow job for this annotation

GitHub Actions / Ruff (ubuntu-latest)

Ruff (F401)

tests/llm/test_anthropic/test_reasoning.py:6:37: F401 `anthropic.types.message.Message` imported but unused


class Answer(BaseModel):
    """Structured output schema for the test: a single numeric answer."""

    answer: float


# Modes under test: tool-based extraction and JSON extraction, both with
# extended thinking enabled.
modes = [
    instructor.Mode.ANTHROPIC_REASONING_TOOLS,
    instructor.Mode.ANTHROPIC_JSON,
]


@pytest.mark.parametrize("mode", modes)
def test_reasoning(mode):
    """Round-trip a simple numeric comparison through Claude 3.7 Sonnet with
    extended thinking enabled and validate the structured result."""
    client = instructor.from_anthropic(anthropic.Anthropic(), mode=mode)
    result = client.chat.completions.create(
        model="claude-3-7-sonnet-latest",
        response_model=Answer,
        messages=[
            {"role": "user", "content": "Which is larger, 9.11 or 9.8"},
        ],
        temperature=1,
        max_tokens=2000,
        thinking={"type": "enabled", "budget_tokens": 1024},
    )

    # The structured output should parse into Answer with the correct value.
    assert isinstance(result, Answer)
    assert result.answer == 9.8
Loading

0 comments on commit d38c8d0

Please sign in to comment.