Skip to content

Commit

Permalink
feat: support for aws bedrock using boto3 (#1287)
Browse files Browse the repository at this point in the history
Co-authored-by: Jason Liu <[email protected]>
  • Loading branch information
imZain448 and jxnl authored Mar 3, 2025
1 parent d38c8d0 commit 77517f9
Show file tree
Hide file tree
Showing 10 changed files with 349 additions and 66 deletions.
1 change: 1 addition & 0 deletions .cursorignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Add directories or file patterns to ignore during indexing (e.g. foo/ or *.csv)
8 changes: 7 additions & 1 deletion instructor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,17 @@

__all__ += ["from_vertexai"]

if importlib.util.find_spec("boto3") is not None:
from .client_bedrock import from_bedrock

__all__ += ["from_bedrock"]

if importlib.util.find_spec("writerai") is not None:
from .client_writer import from_writer

__all__ += ["from_writer"]

if importlib.util.find_spec("openai") is not None:
from .client_perplexity import from_perplexity
__all__ += ["from_perplexity"]
__all__ += ["from_perplexity"]

56 changes: 56 additions & 0 deletions instructor/client_bedrock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from __future__ import annotations

from typing import Any, overload
import boto3
from botocore.client import BaseClient
import instructor
from instructor.client import AsyncInstructor, Instructor


# NOTE(review): these two overloads are byte-identical except for the return
# type (Instructor vs AsyncInstructor), so a type checker cannot choose
# between them — the second is effectively dead. Presumably the async
# overload was meant to take an async Bedrock client type; confirm the
# intended parameter type and differentiate the signatures.
# NOTE(review): `boto3.client` is a factory *function*, not a class, so it is
# questionable as a type annotation; the runtime check in the implementation
# uses `botocore.client.BaseClient` — TODO confirm and align.
@overload
def from_bedrock(
    client: boto3.client,
    mode: instructor.Mode = instructor.Mode.BEDROCK_TOOLS,
    **kwargs: Any,
) -> Instructor: ...


@overload
def from_bedrock(
    client: boto3.client,
    mode: instructor.Mode = instructor.Mode.BEDROCK_TOOLS,
    **kwargs: Any,
) -> AsyncInstructor: ...


def handle_bedrock_json(
    response_model: Any,
    new_kwargs: Any,
) -> tuple[Any, Any]:
    """Pass-through hook for Bedrock JSON-mode request preparation.

    Currently returns both arguments unchanged; it exists as the seam where
    Bedrock-specific rewriting of the outgoing request will be added.

    Args:
        response_model: The response model requested by the caller.
        new_kwargs: The keyword arguments about to be sent to the Bedrock client.

    Returns:
        The ``(response_model, new_kwargs)`` pair, unmodified.
    """
    # Debug print statements removed: library code must not write to stdout
    # on every request.
    return response_model, new_kwargs


def from_bedrock(
    client: BaseClient,
    mode: instructor.Mode = instructor.Mode.BEDROCK_JSON,
    **kwargs: Any,
) -> Instructor | AsyncInstructor:
    """Wrap an AWS Bedrock runtime client so it returns structured outputs.

    Args:
        client: A Bedrock runtime client, e.g. ``boto3.client("bedrock-runtime")``.
            Must be a ``botocore.client.BaseClient`` exposing ``converse``.
        mode: One of ``Mode.BEDROCK_JSON`` (default) or ``Mode.BEDROCK_TOOLS``.
        **kwargs: Extra keyword arguments forwarded to the ``Instructor``
            constructor.

    Returns:
        A patched ``Instructor`` whose ``create`` drives Bedrock's Converse API.

    Raises:
        AssertionError: If ``mode`` is not a Bedrock mode or ``client`` is not
            a ``BaseClient``.
    """
    # Bug fix: the original message was a plain string containing literal
    # "{instructor.Mode...}" braces — it was meant to interpolate the modes.
    assert mode in {
        instructor.Mode.BEDROCK_TOOLS,
        instructor.Mode.BEDROCK_JSON,
    }, f"Mode must be one of {instructor.Mode.BEDROCK_TOOLS}, {instructor.Mode.BEDROCK_JSON}"
    # Bug fix: the original message said "instance of boto3.client", but
    # boto3.client is a factory function; the actual requirement is BaseClient.
    assert isinstance(
        client,
        BaseClient,
    ), "Client must be a botocore.client.BaseClient, e.g. boto3.client('bedrock-runtime')"
    # Bedrock's Converse API is the chat entry point that gets patched.
    create = client.converse

    # NOTE(review): only a synchronous Instructor is ever returned here, so
    # the AsyncInstructor half of the declared return type is currently
    # unreachable — confirm whether async Bedrock support is planned.
    return Instructor(
        client=client,
        create=instructor.patch(create=create, mode=mode),
        provider=instructor.Provider.BEDROCK,
        mode=mode,
        **kwargs,
    )
103 changes: 87 additions & 16 deletions instructor/function_calls.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# type: ignore
import json
import logging
import re
from functools import wraps
from typing import Annotated, Any, Optional, TypeVar, cast
from docstring_parser import parse
Expand Down Expand Up @@ -45,7 +46,9 @@ def openai_schema(cls) -> dict[str, Any]:
schema = cls.model_json_schema()
docstring = parse(cls.__doc__ or "")
parameters = {
k: v for k, v in schema.items() if k not in ("title", "description")
k: v
for k, v in schema.items()
if k not in ("title", "description")
}
for param in docstring.params:
if (name := param.arg_name) in parameters["properties"] and (
Expand All @@ -55,7 +58,9 @@ def openai_schema(cls) -> dict[str, Any]:
parameters["properties"][name]["description"] = description

parameters["required"] = sorted(
k for k, v in parameters["properties"].items() if "default" not in v
k
for k, v in parameters["properties"].items()
if "default" not in v
)

if "description" not in schema:
Expand Down Expand Up @@ -88,7 +93,9 @@ def gemini_schema(cls) -> Any:
function = genai_types.FunctionDeclaration(
name=cls.openai_schema["name"],
description=cls.openai_schema["description"],
parameters=map_to_gemini_function_schema(cls.openai_schema["parameters"]),
parameters=map_to_gemini_function_schema(
cls.openai_schema["parameters"]
),
)
return function

Expand All @@ -112,32 +119,57 @@ def from_response(
Returns:
cls (OpenAISchema): An instance of the class
"""

if mode == Mode.ANTHROPIC_TOOLS:
return cls.parse_anthropic_tools(
completion, validation_context, strict
)

if mode == Mode.ANTHROPIC_TOOLS or mode == Mode.ANTHROPIC_REASONING_TOOLS:
return cls.parse_anthropic_tools(completion, validation_context, strict)

if mode == Mode.ANTHROPIC_JSON:
return cls.parse_anthropic_json(completion, validation_context, strict)
return cls.parse_anthropic_json(
completion, validation_context, strict
)

if mode == Mode.BEDROCK_JSON:
return cls.parse_bedrock_json(
completion, validation_context, strict
)

if mode in {Mode.VERTEXAI_TOOLS, Mode.GEMINI_TOOLS}:
return cls.parse_vertexai_tools(completion, validation_context)

if mode == Mode.VERTEXAI_JSON:
return cls.parse_vertexai_json(completion, validation_context, strict)
return cls.parse_vertexai_json(
completion, validation_context, strict
)

if mode == Mode.COHERE_TOOLS:
return cls.parse_cohere_tools(completion, validation_context, strict)
return cls.parse_cohere_tools(
completion, validation_context, strict
)

if mode == Mode.GEMINI_JSON:
return cls.parse_gemini_json(completion, validation_context, strict)
return cls.parse_gemini_json(
completion, validation_context, strict
)

if mode == Mode.GEMINI_TOOLS:
return cls.parse_gemini_tools(completion, validation_context, strict)
return cls.parse_gemini_tools(
completion, validation_context, strict
)

if mode == Mode.COHERE_JSON_SCHEMA:
return cls.parse_cohere_json_schema(completion, validation_context, strict)
return cls.parse_cohere_json_schema(
completion, validation_context, strict
)

if mode == Mode.WRITER_TOOLS:
return cls.parse_writer_tools(completion, validation_context, strict)
return cls.parse_writer_tools(
completion, validation_context, strict
)

if completion.choices[0].finish_reason == "length":
raise IncompleteOutputException(last_completion=completion)
Expand Down Expand Up @@ -191,12 +223,17 @@ def parse_anthropic_tools(
) -> BaseModel:
from anthropic.types import Message

if isinstance(completion, Message) and completion.stop_reason == "max_tokens":
if (
isinstance(completion, Message)
and completion.stop_reason == "max_tokens"
):
raise IncompleteOutputException(last_completion=completion)

# Anthropic returns arguments as a dict, dump to json for model validation below
tool_calls = [
json.dumps(c.input) for c in completion.content if c.type == "tool_use"
json.dumps(c.input)
for c in completion.content
if c.type == "tool_use"
] # TODO update with anthropic specific types

tool_calls_validator = TypeAdapter(
Expand Down Expand Up @@ -240,7 +277,35 @@ def parse_anthropic_json(
# Allow control characters.
parsed = json.loads(extra_text, strict=False)
# Pydantic non-strict: https://docs.pydantic.dev/latest/concepts/strict_mode/
return cls.model_validate(parsed, context=validation_context, strict=False)
return cls.model_validate(
parsed, context=validation_context, strict=False
)

@classmethod
def parse_bedrock_json(
cls: type[BaseModel],
completion: Any,
validation_context: Optional[dict[str, Any]] = None,
strict: Optional[bool] = None,
) -> BaseModel:
if isinstance(completion, dict):
text = (
completion.get("output")
.get("message")
.get("content")[0]
.get("text")
)

match = re.search(r"```?json(.*?)```?", text, re.DOTALL)
if match:
text = match.group(1).strip()

text = re.sub(r"```?json|\\n", "", text).strip()
else:
text = completion.text
return cls.model_validate_json(
text, context=validation_context, strict=strict
)

@classmethod
def parse_gemini_json(
Expand All @@ -259,7 +324,9 @@ def parse_gemini_json(
try:
extra_text = extract_json_from_codeblock(text) # type: ignore
except UnboundLocalError:
raise ValueError("Unable to extract JSON from completion text") from None
raise ValueError(
"Unable to extract JSON from completion text"
) from None

if strict:
return cls.model_validate_json(
Expand All @@ -269,7 +336,9 @@ def parse_gemini_json(
# Allow control characters.
parsed = json.loads(extra_text, strict=False)
# Pydantic non-strict: https://docs.pydantic.dev/latest/concepts/strict_mode/
return cls.model_validate(parsed, context=validation_context, strict=False)
return cls.model_validate(
parsed, context=validation_context, strict=False
)

@classmethod
def parse_vertexai_tools(
Expand All @@ -282,7 +351,9 @@ def parse_vertexai_tools(
for field in tool_call: # type: ignore
model[field] = tool_call[field]
# We enable strict=False because the conversion from protobuf -> dict often results in types like ints being cast to floats, as a result in order for model.validate to work we need to disable strict mode.
return cls.model_validate(model, context=validation_context, strict=False)
return cls.model_validate(
model, context=validation_context, strict=False
)

@classmethod
def parse_vertexai_json(
Expand Down
2 changes: 2 additions & 0 deletions instructor/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class Mode(enum.Enum):
FIREWORKS_TOOLS = "fireworks_tools"
FIREWORKS_JSON = "fireworks_json"
WRITER_TOOLS = "writer_tools"
BEDROCK_TOOLS = "bedrock_tools"
BEDROCK_JSON = "bedrock_json"
PERPLEXITY_JSON = "perplexity_json"

@classmethod
Expand Down
6 changes: 4 additions & 2 deletions instructor/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def new_create_sync(
**kwargs: T_ParamSpec.kwargs,
) -> T_Model:
context = handle_context(context, validation_context)

# print(f"instructor.patch: patched_function {func.__name__}")
response_model, new_kwargs = handle_response_model(
response_model=response_model, mode=mode, **kwargs
) # type: ignore
Expand Down Expand Up @@ -228,6 +228,8 @@ def apatch(client: AsyncOpenAI, mode: Mode = Mode.TOOLS) -> AsyncOpenAI:
import warnings

warnings.warn(
"apatch is deprecated, use patch instead", DeprecationWarning, stacklevel=2
"apatch is deprecated, use patch instead",
DeprecationWarning,
stacklevel=2,
)
return patch(client, mode=mode)
Loading

0 comments on commit 77517f9

Please sign in to comment.