diff --git a/.cursorignore b/.cursorignore
new file mode 100644
index 000000000..6f9f00ff4
--- /dev/null
+++ b/.cursorignore
@@ -0,0 +1 @@
+# Add directories or file patterns to ignore during indexing (e.g. foo/ or *.csv)
diff --git a/instructor/__init__.py b/instructor/__init__.py
index a23aba983..531462756 100644
--- a/instructor/__init__.py
+++ b/instructor/__init__.py
@@ -92,6 +92,11 @@
     __all__ += ["from_vertexai"]
 
+if importlib.util.find_spec("boto3") is not None:
+    from .client_bedrock import from_bedrock
+
+    __all__ += ["from_bedrock"]
+
 if importlib.util.find_spec("writerai") is not None:
     from .client_writer import from_writer
 
@@ -99,4 +104,4 @@
 if importlib.util.find_spec("openai") is not None:
     from .client_perplexity import from_perplexity
 
-    __all__ += ["from_perplexity"]
\ No newline at end of file
+    __all__ += ["from_perplexity"]
diff --git a/instructor/client_bedrock.py b/instructor/client_bedrock.py
new file mode 100644
index 000000000..21e3133cd
--- /dev/null
+++ b/instructor/client_bedrock.py
@@ -0,0 +1,41 @@
+from __future__ import annotations
+
+from typing import Any
+
+from botocore.client import BaseClient
+
+import instructor
+from instructor.client import Instructor
+
+
+def handle_bedrock_json(
+    response_model: Any,
+    new_kwargs: Any,
+) -> tuple[Any, Any]:
+    # Pass-through hook; the schema injection for BEDROCK_JSON lives in
+    # instructor.process_response.handle_bedrock_json.
+    return response_model, new_kwargs
+
+
+def from_bedrock(
+    client: BaseClient,
+    mode: instructor.Mode = instructor.Mode.BEDROCK_JSON,
+    **kwargs: Any,
+) -> Instructor:
+    assert mode in {
+        instructor.Mode.BEDROCK_TOOLS,
+        instructor.Mode.BEDROCK_JSON,
+    }, f"Mode must be one of {instructor.Mode.BEDROCK_TOOLS}, {instructor.Mode.BEDROCK_JSON}"
+    assert isinstance(
+        client,
+        BaseClient,
+    ), "Client must be an instance of botocore.client.BaseClient"
+    create = client.converse  # The Bedrock Runtime Converse API
+
+    return Instructor(
+        client=client,
+        create=instructor.patch(create=create, mode=mode),
+        provider=instructor.Provider.BEDROCK,
+        mode=mode,
+        **kwargs,
+    )
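For context, a minimal usage sketch of the new entry point. The region, model id, and prompt below are illustrative assumptions, not values defined by this diff; any Converse-capable model should behave the same way. Note that `from_bedrock` returns a synchronous `Instructor`, since botocore clients are synchronous.

```python
import boto3
import instructor
from pydantic import BaseModel


class User(BaseModel):
    name: str
    age: int


# Wrap a regular bedrock-runtime client; BEDROCK_JSON is the default mode.
bedrock = boto3.client("bedrock-runtime", region_name="us-east-1")
client = instructor.from_bedrock(bedrock)

user = client.chat.completions.create(
    modelId="anthropic.claude-3-sonnet-20240229-v1:0",  # assumed example id
    messages=[{"role": "user", "content": [{"text": "Extract: Jason is 25"}]}],
    response_model=User,
)
print(user)  # User(name='Jason', age=25)
```
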
diff --git a/instructor/function_calls.py b/instructor/function_calls.py
index d14bbdf5a..eb2cbc895 100644
--- a/instructor/function_calls.py
+++ b/instructor/function_calls.py
@@ -1,6 +1,7 @@
 # type: ignore
 import json
 import logging
+import re
 from functools import wraps
 from typing import Annotated, Any, Optional, TypeVar, cast
 from docstring_parser import parse
@@ -45,7 +46,9 @@ def openai_schema(cls) -> dict[str, Any]:
         schema = cls.model_json_schema()
         docstring = parse(cls.__doc__ or "")
         parameters = {
-            k: v for k, v in schema.items() if k not in ("title", "description")
+            k: v
+            for k, v in schema.items()
+            if k not in ("title", "description")
         }
         for param in docstring.params:
             if (name := param.arg_name) in parameters["properties"] and (
@@ -55,7 +58,9 @@ def openai_schema(cls) -> dict[str, Any]:
                 parameters["properties"][name]["description"] = description
 
         parameters["required"] = sorted(
-            k for k, v in parameters["properties"].items() if "default" not in v
+            k
+            for k, v in parameters["properties"].items()
+            if "default" not in v
         )
 
         if "description" not in schema:
@@ -88,7 +93,9 @@ def gemini_schema(cls) -> Any:
         function = genai_types.FunctionDeclaration(
             name=cls.openai_schema["name"],
             description=cls.openai_schema["description"],
-            parameters=map_to_gemini_function_schema(cls.openai_schema["parameters"]),
+            parameters=map_to_gemini_function_schema(
+                cls.openai_schema["parameters"]
+            ),
         )
         return function
 
@@ -112,32 +119,57 @@ def from_response(
         Returns:
             cls (OpenAISchema): An instance of the class
         """
         if mode == Mode.ANTHROPIC_TOOLS or mode == Mode.ANTHROPIC_REASONING_TOOLS:
             return cls.parse_anthropic_tools(completion, validation_context, strict)
 
         if mode == Mode.ANTHROPIC_JSON:
-            return cls.parse_anthropic_json(completion, validation_context, strict)
+            return cls.parse_anthropic_json(
+                completion, validation_context, strict
+            )
+
+        if mode == Mode.BEDROCK_JSON:
+            return cls.parse_bedrock_json(
+                completion, validation_context, strict
+            )
 
         if mode in {Mode.VERTEXAI_TOOLS, Mode.GEMINI_TOOLS}:
             return cls.parse_vertexai_tools(completion, validation_context)
 
         if mode == Mode.VERTEXAI_JSON:
-            return cls.parse_vertexai_json(completion, validation_context, strict)
+            return cls.parse_vertexai_json(
+                completion, validation_context, strict
+            )
 
         if mode == Mode.COHERE_TOOLS:
-            return cls.parse_cohere_tools(completion, validation_context, strict)
+            return cls.parse_cohere_tools(
+                completion, validation_context, strict
+            )
 
         if mode == Mode.GEMINI_JSON:
-            return cls.parse_gemini_json(completion, validation_context, strict)
+            return cls.parse_gemini_json(
+                completion, validation_context, strict
+            )
 
         if mode == Mode.GEMINI_TOOLS:
-            return cls.parse_gemini_tools(completion, validation_context, strict)
+            return cls.parse_gemini_tools(
+                completion, validation_context, strict
+            )
 
         if mode == Mode.COHERE_JSON_SCHEMA:
-            return cls.parse_cohere_json_schema(completion, validation_context, strict)
+            return cls.parse_cohere_json_schema(
+                completion, validation_context, strict
+            )
 
         if mode == Mode.WRITER_TOOLS:
-            return cls.parse_writer_tools(completion, validation_context, strict)
+            return cls.parse_writer_tools(
+                completion, validation_context, strict
+            )
 
         if completion.choices[0].finish_reason == "length":
             raise IncompleteOutputException(last_completion=completion)
@@ -191,12 +223,17 @@ def parse_anthropic_tools(
     ) -> BaseModel:
         from anthropic.types import Message
 
-        if isinstance(completion, Message) and completion.stop_reason == "max_tokens":
+        if (
+            isinstance(completion, Message)
+            and completion.stop_reason == "max_tokens"
+        ):
             raise IncompleteOutputException(last_completion=completion)
 
         # Anthropic returns arguments as a dict, dump to json for model validation below
         tool_calls = [
-            json.dumps(c.input) for c in completion.content if c.type == "tool_use"
+            json.dumps(c.input)
+            for c in completion.content
+            if c.type == "tool_use"
         ]  # TODO update with anthropic specific types
 
         tool_calls_validator = TypeAdapter(
@@ -240,7 +277,35 @@ def parse_anthropic_json(
         # Allow control characters.
         parsed = json.loads(extra_text, strict=False)
         # Pydantic non-strict: https://docs.pydantic.dev/latest/concepts/strict_mode/
-        return cls.model_validate(parsed, context=validation_context, strict=False)
+        return cls.model_validate(
+            parsed, context=validation_context, strict=False
+        )
+
+    @classmethod
+    def parse_bedrock_json(
+        cls: type[BaseModel],
+        completion: Any,
+        validation_context: Optional[dict[str, Any]] = None,
+        strict: Optional[bool] = None,
+    ) -> BaseModel:
+        if isinstance(completion, dict):
+            # Converse responses nest the generated text under
+            # output -> message -> content[0] -> text.
+            text = completion["output"]["message"]["content"][0]["text"]
+
+            # Strip a surrounding ```json ... ``` code fence if present.
+            match = re.search(r"```?json(.*?)```?", text, re.DOTALL)
+            if match:
+                text = match.group(1).strip()
+
+            text = re.sub(r"```?json|\\n", "", text).strip()
+        else:
+            text = completion.text
+        return cls.model_validate_json(
+            text, context=validation_context, strict=strict
+        )
 
     @classmethod
     def parse_gemini_json(
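To make the parsing concrete, a small self-contained sketch of the response shape `parse_bedrock_json` expects from Converse and how the fence-stripping regex behaves; the payload is invented for illustration.

```python
import re

# Illustrative Converse response; only output -> message -> content[0] -> text matters.
completion = {
    "output": {
        "message": {
            "role": "assistant",
            "content": [{"text": '```json\n{"name": "Jason", "age": 25}\n```'}],
        }
    }
}

text = completion["output"]["message"]["content"][0]["text"]
match = re.search(r"```?json(.*?)```?", text, re.DOTALL)
if match:
    text = match.group(1).strip()
print(text)  # {"name": "Jason", "age": 25}
```
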
@@ -259,7 +324,9 @@ def parse_gemini_json(
         try:
             extra_text = extract_json_from_codeblock(text)  # type: ignore
         except UnboundLocalError:
-            raise ValueError("Unable to extract JSON from completion text") from None
+            raise ValueError(
+                "Unable to extract JSON from completion text"
+            ) from None
 
         if strict:
             return cls.model_validate_json(
@@ -269,7 +336,9 @@
             # Allow control characters.
             parsed = json.loads(extra_text, strict=False)
             # Pydantic non-strict: https://docs.pydantic.dev/latest/concepts/strict_mode/
-            return cls.model_validate(parsed, context=validation_context, strict=False)
+            return cls.model_validate(
+                parsed, context=validation_context, strict=False
+            )
 
     @classmethod
     def parse_vertexai_tools(
@@ -282,7 +351,9 @@
         for field in tool_call:  # type: ignore
             model[field] = tool_call[field]
         # We enable strict=False because the conversion from protobuf -> dict often results in types like ints being cast to floats, as a result in order for model.validate to work we need to disable strict mode.
-        return cls.model_validate(model, context=validation_context, strict=False)
+        return cls.model_validate(
+            model, context=validation_context, strict=False
+        )
 
     @classmethod
     def parse_vertexai_json(
diff --git a/instructor/mode.py b/instructor/mode.py
index f36d11723..557f3f2f5 100644
--- a/instructor/mode.py
+++ b/instructor/mode.py
@@ -29,6 +29,8 @@ class Mode(enum.Enum):
     FIREWORKS_TOOLS = "fireworks_tools"
     FIREWORKS_JSON = "fireworks_json"
     WRITER_TOOLS = "writer_tools"
+    BEDROCK_TOOLS = "bedrock_tools"
+    BEDROCK_JSON = "bedrock_json"
     PERPLEXITY_JSON = "perplexity_json"
 
     @classmethod
diff --git a/instructor/patch.py b/instructor/patch.py
index cc5802672..bb6692156 100644
--- a/instructor/patch.py
+++ b/instructor/patch.py
@@ -228,6 +228,8 @@ def apatch(client: AsyncOpenAI, mode: Mode = Mode.TOOLS) -> AsyncOpenAI:
     import warnings
 
     warnings.warn(
-        "apatch is deprecated, use patch instead", DeprecationWarning, stacklevel=2
+        "apatch is deprecated, use patch instead",
+        DeprecationWarning,
+        stacklevel=2,
     )
     return patch(client, mode=mode)
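The two new Mode members above are what the rest of the pipeline branches on. For reference, a sketch of wiring `patch` to the Converse method directly, which is what `from_bedrock` does internally; the model id is an assumed example.

```python
import boto3
import instructor
from pydantic import BaseModel


class User(BaseModel):
    name: str
    age: int


bedrock = boto3.client("bedrock-runtime")
# Patch the bound converse method so it accepts response_model and retries.
create = instructor.patch(create=bedrock.converse, mode=instructor.Mode.BEDROCK_JSON)

user = create(
    modelId="anthropic.claude-3-sonnet-20240229-v1:0",  # assumed example id
    messages=[{"role": "user", "content": [{"text": "Jason is 25"}]}],
    response_model=User,
)
```
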
diff --git a/instructor/process_response.py b/instructor/process_response.py
index b9a2156bd..674b52554 100644
--- a/instructor/process_response.py
+++ b/instructor/process_response.py
@@ -25,7 +26,11 @@
     VertexAIParallelModel,
 )
 from instructor.dsl.partial import PartialBase
-from instructor.dsl.simple_type import AdapterBase, ModelAdapter, is_simple_type
+from instructor.dsl.simple_type import (
+    AdapterBase,
+    ModelAdapter,
+    is_simple_type,
+)
 from instructor.function_calls import OpenAISchema, openai_schema
 from instructor.utils import (
     merge_consecutive_messages,
@@ -210,7 +216,9 @@ def handle_functions(
 ) -> tuple[type[T], dict[str, Any]]:
     Mode.warn_mode_functions_deprecation()
     new_kwargs["functions"] = [response_model.openai_schema]
-    new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]}
+    new_kwargs["function_call"] = {
+        "name": response_model.openai_schema["name"]
+    }
     return response_model, new_kwargs
 
@@ -311,7 +319,9 @@ def handle_json_modes(
                 "content": "Return the correct JSON response within a ```json codeblock. not the JSON_SCHEMA",
             },
         )
-    new_kwargs["messages"] = merge_consecutive_messages(new_kwargs["messages"])
+    new_kwargs["messages"] = merge_consecutive_messages(
+        new_kwargs["messages"]
+    )
 
     if new_kwargs["messages"][0]["role"] != "system":
         new_kwargs["messages"].insert(
@@ -407,13 +417,16 @@ def handle_anthropic_json(
     )
 
     new_kwargs["system"] = combine_system_messages(
-        new_kwargs.get("system"), [{"type": "text", "text": json_schema_message}]
+        new_kwargs.get("system"),
+        [{"type": "text", "text": json_schema_message}],
     )
 
     return response_model, new_kwargs
 
 
-def handle_cohere_modes(new_kwargs: dict[str, Any]) -> tuple[None, dict[str, Any]]:
+def handle_cohere_modes(
+    new_kwargs: dict[str, Any]
+) -> tuple[None, dict[str, Any]]:
     messages = new_kwargs.pop("messages", [])
     chat_history = []
     for message in messages[:-1]:
@@ -483,13 +496,15 @@ def handle_gemini_json(
     )
 
     if new_kwargs["messages"][0]["role"] != "system":
-        new_kwargs["messages"].insert(0, {"role": "system", "content": message})
+        new_kwargs["messages"].insert(
+            0, {"role": "system", "content": message}
+        )
     else:
         new_kwargs["messages"][0]["content"] += f"\n\n{message}"
 
-    new_kwargs["generation_config"] = new_kwargs.get("generation_config", {}) | {
-        "response_mime_type": "application/json"
-    }
+    new_kwargs["generation_config"] = new_kwargs.get(
+        "generation_config", {}
+    ) | {"response_mime_type": "application/json"}
 
     new_kwargs = update_gemini_kwargs(new_kwargs)
     return response_model, new_kwargs
@@ -527,8 +542,9 @@ def handle_vertexai_parallel_tools(
     # Extract concrete types before passing to vertexai_process_response
     model_types = list(get_types_array(response_model))
-    contents, tools, tool_config = vertexai_process_response(new_kwargs, model_types)
-
+    contents, tools, tool_config = vertexai_process_response(
+        new_kwargs, model_types
+    )
     new_kwargs["contents"] = contents
     new_kwargs["tools"] = tools
     new_kwargs["tool_config"] = tool_config
@@ -541,7 +557,9 @@ def handle_vertexai_tools(
 ) -> tuple[type[T], dict[str, Any]]:
     from instructor.client_vertexai import vertexai_process_response
 
-    contents, tools, tool_config = vertexai_process_response(new_kwargs, response_model)
+    contents, tools, tool_config = vertexai_process_response(
+        new_kwargs, response_model
+    )
 
     new_kwargs["contents"] = contents
     new_kwargs["tools"] = tools
@@ -563,6 +581,43 @@ def handle_vertexai_json(
     return response_model, new_kwargs
 
 
+def handle_bedrock_json(
+    response_model: type[T], new_kwargs: dict[str, Any]
+) -> tuple[type[T], dict[str, Any]]:
+    json_message = dedent(
+        f"""
+        As a genius expert, your task is to understand the content and provide
+        the parsed objects in json that match the following json_schema:\n
+
+        {json.dumps(response_model.model_json_schema(), indent=2, ensure_ascii=False)}
+
+        Make sure to return an instance of the JSON, not the schema itself
+        and don't include any other text in the response apart from the json
+        """
+    )
+    system_message = new_kwargs.pop("system", None)
+    if not system_message:
+        new_kwargs["system"] = [{"text": json_message}]
+    else:
+        if not isinstance(system_message, list):
+            raise ValueError(
+                """system must be a list of system message blocks; see
+                https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html
+                """
+            )
+        system_message.append({"text": json_message})
+        new_kwargs["system"] = system_message
+
+    return response_model, new_kwargs
+
+
+def handle_bedrock_tools(
+    response_model: type[T], new_kwargs: dict[str, Any]
+) -> tuple[type[T], dict[str, Any]]:
+    # TODO: Bedrock tool calling is not implemented yet; pass through unchanged.
+    return response_model, new_kwargs
+
+
 def handle_cohere_json_schema(
     response_model: type[T], new_kwargs: dict[str, Any]
 ) -> tuple[type[T], dict[str, Any]]:
@@ -608,9 +663,9 @@ def handle_cerebras_json(
         Your response should consist only of a valid JSON object that `{response_model.__name__}.model_validate_json()` can successfully parse.
     """
-    new_kwargs["messages"] = [{"role": "system", "content": instruction}] + new_kwargs[
-        "messages"
-    ]
+    new_kwargs["messages"] = [
+        {"role": "system", "content": instruction}
+    ] + new_kwargs["messages"]
 
     return response_model, new_kwargs
 
@@ -743,7 +799,9 @@ def handle_response_model(
         )
         if mode in {Mode.ANTHROPIC_JSON, Mode.ANTHROPIC_TOOLS}:
             # Handle OpenAI style or Anthropic style messages
-            new_kwargs["messages"] = [m for m in messages if m["role"] != "system"]
+            new_kwargs["messages"] = [
+                m for m in messages if m["role"] != "system"
+            ]
             if "system" not in new_kwargs:
                 system_message = extract_system_messages(messages)
                 if system_message:
@@ -782,11 +840,15 @@
         Mode.FIREWORKS_JSON: handle_fireworks_json,
         Mode.FIREWORKS_TOOLS: handle_fireworks_tools,
         Mode.WRITER_TOOLS: handle_writer_tools,
+        Mode.BEDROCK_JSON: handle_bedrock_json,
+        Mode.BEDROCK_TOOLS: handle_bedrock_tools,
         Mode.PERPLEXITY_JSON: handle_perplexity_json,
     }
 
     if mode in mode_handlers:
-        response_model, new_kwargs = mode_handlers[mode](response_model, new_kwargs)
+        response_model, new_kwargs = mode_handlers[mode](
+            response_model, new_kwargs
+        )
     else:
         raise ValueError(f"Invalid patch mode: {mode}")
 
@@ -802,7 +864,8 @@
             "mode": mode.value,
             "response_model": (
                 response_model.__name__
-                if response_model is not None and hasattr(response_model, "__name__")
+                if response_model is not None
+                and hasattr(response_model, "__name__")
                 else str(response_model)
             ),
             "new_kwargs": new_kwargs,
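A quick sketch of the effect of `handle_bedrock_json` on the outgoing kwargs, assuming this branch is installed; the message content is illustrative.

```python
from pydantic import BaseModel

from instructor.process_response import handle_bedrock_json


class User(BaseModel):
    name: str
    age: int


kwargs = {"messages": [{"role": "user", "content": [{"text": "Jason is 25"}]}]}
_, new_kwargs = handle_bedrock_json(User, kwargs)

# Converse now receives a system block carrying the JSON-schema instruction.
assert new_kwargs["system"][0]["text"].lstrip().startswith("As a genius expert")
```
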
diff --git a/instructor/reask.py b/instructor/reask.py
index 45c72b683..1e23bba5e 100644
--- a/instructor/reask.py
+++ b/instructor/reask.py
@@ -24,7 +24,9 @@ def reask_anthropic_tools(
     kwargs = kwargs.copy()
     from anthropic.types import Message
 
-    assert isinstance(response, Message), "Response must be a Anthropic Message"
+    assert isinstance(
+        response, Message
+    ), "Response must be an Anthropic Message"
 
     assistant_content = []
     tool_use_id = None
@@ -71,7 +73,9 @@ def reask_anthropic_json(
     kwargs = kwargs.copy()
     from anthropic.types import Message
 
-    assert isinstance(response, Message), "Response must be a Anthropic Message"
+    assert isinstance(
+        response, Message
+    ), "Response must be an Anthropic Message"
 
     reask_msg = {
         "role": "user",
@@ -120,14 +124,18 @@ def reask_gemini_tools(
                 glm.Part(
                     function_response=glm.FunctionResponse(
                         name=response.parts[0].function_call.name,
-                        response={"error": f"Validation Error(s) found:\n{exception}"},
+                        response={
+                            "error": f"Validation Error(s) found:\n{exception}"
+                        },
                     )
                 ),
             ],
         },
         {
             "role": "user",
-            "parts": ["Recall the function arguments correctly and fix the errors"],
+            "parts": [
+                "Recall the function arguments correctly and fix the errors"
+            ],
         },
     ]
     kwargs["contents"].extend(reask_msgs)
@@ -253,6 +261,27 @@ def reask_md_json(
     return kwargs
 
 
+def reask_bedrock_json(
+    kwargs: dict[str, Any],
+    response: Any,
+    exception: Exception,
+):
+    kwargs = kwargs.copy()
+    reask_msgs = [response["output"]["message"]]
+    reask_msgs.append(
+        {
+            "role": "user",
+            "content": [
+                {
+                    "text": f"Correct your JSON-only response, based on the following errors:\n{exception}"
+                },
+            ],
+        }
+    )
+    kwargs["messages"].extend(reask_msgs)
+    return kwargs
+
+
 def reask_default(
     kwargs: dict[str, Any],
     response: Any,
@@ -272,7 +301,9 @@
     return kwargs
 
 
-def reask_fireworks_tools(kwargs: dict[str, Any], response: Any, exception: Exception):
+def reask_fireworks_tools(
+    kwargs: dict[str, Any], response: Any, exception: Exception
+):
     kwargs = kwargs.copy()
     reask_msgs = [dump_message(response.choices[0].message)]
     for tool_call in response.choices[0].message.tool_calls:
@@ -369,7 +400,10 @@ def handle_reask_kwargs(
         Mode.FIREWORKS_TOOLS: reask_fireworks_tools,
         Mode.FIREWORKS_JSON: reask_fireworks_json,
         Mode.WRITER_TOOLS: reask_writer_tools,
+        Mode.BEDROCK_JSON: reask_bedrock_json,
         Mode.PERPLEXITY_JSON: reask_perplexity_json,
     }
     reask_function = functions.get(mode, reask_default)
-    return reask_function(kwargs=kwargs, response=response, exception=exception)
+    return reask_function(
+        kwargs=kwargs, response=response, exception=exception
+    )
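To illustrate the retry path, a sketch of what `reask_bedrock_json` appends when validation fails: the failed assistant turn is echoed back, followed by a corrective user turn. The response payload and error are invented, assuming this branch is installed.

```python
from instructor.reask import reask_bedrock_json

# Invented failed response in Converse shape.
response = {
    "output": {
        "message": {
            "role": "assistant",
            "content": [{"text": '{"name": "Jason", "age": "unknown"}'}],
        }
    }
}

kwargs = {"messages": [{"role": "user", "content": [{"text": "Jason is 25"}]}]}
kwargs = reask_bedrock_json(kwargs, response, ValueError("age must be an int"))

assert kwargs["messages"][1]["role"] == "assistant"  # echoed failure
assert kwargs["messages"][2]["role"] == "user"  # correction request
```
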
[response["output"]["message"]] + reask_msgs.append( + { + "role": "user", + "content": [ + { + "text": f"Correct your JSON ONLY RESPONSE, based on the following errors:\n{exception}" + }, + ], + } + ) + kwargs["messages"].extend(reask_msgs) + return kwargs + + def reask_default( kwargs: dict[str, Any], response: Any, @@ -272,7 +301,9 @@ def reask_default( return kwargs -def reask_fireworks_tools(kwargs: dict[str, Any], response: Any, exception: Exception): +def reask_fireworks_tools( + kwargs: dict[str, Any], response: Any, exception: Exception +): kwargs = kwargs.copy() reask_msgs = [dump_message(response.choices[0].message)] for tool_call in response.choices[0].message.tool_calls: @@ -369,7 +400,10 @@ def handle_reask_kwargs( Mode.FIREWORKS_TOOLS: reask_fireworks_tools, Mode.FIREWORKS_JSON: reask_fireworks_json, Mode.WRITER_TOOLS: reask_writer_tools, + Mode.BEDROCK_JSON: reask_bedrock_json, Mode.PERPLEXITY_JSON: reask_perplexity_json, } reask_function = functions.get(mode, reask_default) - return reask_function(kwargs=kwargs, response=response, exception=exception) + return reask_function( + kwargs=kwargs, response=response, exception=exception + ) diff --git a/instructor/retry.py b/instructor/retry.py index fb374c4b7..e7eb6dd00 100644 --- a/instructor/retry.py +++ b/instructor/retry.py @@ -4,13 +4,16 @@ import logging from json import JSONDecodeError -from typing import Any, Callable, TypeVar +from typing import Any, Callable, TypeVar, Union from instructor.exceptions import InstructorRetryException from instructor.hooks import Hooks from instructor.mode import Mode from instructor.reask import handle_reask_kwargs -from instructor.process_response import process_response, process_response_async +from instructor.process_response import ( + process_response, + process_response_async, +) from instructor.utils import update_total_usage from instructor.validators import AsyncValidationError from openai.types.chat import ChatCompletion @@ -37,7 +40,9 @@ T = TypeVar("T") -def initialize_retrying(max_retries: int | Retrying | AsyncRetrying, is_async: bool): +def initialize_retrying( + max_retries: int | Retrying | AsyncRetrying, is_async: bool +): """ Initialize the retrying mechanism based on the type (synchronous or asynchronous). @@ -82,7 +87,9 @@ def initialize_usage(mode: Mode) -> CompletionUsage | Any: completion_tokens_details=CompletionTokensDetails( audio_tokens=0, reasoning_tokens=0 ), - prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0), + prompt_tokens_details=PromptTokensDetails( + audio_tokens=0, cached_tokens=0 + ), ) if mode in {Mode.ANTHROPIC_TOOLS, Mode.ANTHROPIC_JSON}: from anthropic.types import Usage as AnthropicUsage @@ -150,7 +157,9 @@ def retry_sync( response = None for attempt in max_retries: with attempt: - logger.debug(f"Retrying, attempt: {attempt.retry_state.attempt_number}") + logger.debug( + f"Retrying, attempt: {attempt.retry_state.attempt_number}" + ) try: hooks.emit_completion_arguments(*args, **kwargs) response = func(*args, **kwargs) @@ -158,6 +167,7 @@ def retry_sync( response = update_total_usage( response=response, total_usage=total_usage ) + return process_response( # type: ignore response=response, response_model=response_model, @@ -184,7 +194,8 @@ def retry_sync( n_attempts=attempt.retry_state.attempt_number, #! 
                     messages=kwargs.get(
-                        "messages", kwargs.get("contents", kwargs.get("chat_history", []))
+                        "messages",
+                        kwargs.get("contents", kwargs.get("chat_history", [])),
                     ),
                     create_kwargs=kwargs,
                     total_usage=total_usage,
@@ -229,7 +240,9 @@ async def retry_async(
     try:
         response = None
         async for attempt in max_retries:
-            logger.debug(f"Retrying, attempt: {attempt.retry_state.attempt_number}")
+            logger.debug(
+                f"Retrying, attempt: {attempt.retry_state.attempt_number}"
+            )
             with attempt:
                 try:
                     hooks.emit_completion_arguments(*args, **kwargs)
@@ -247,7 +260,11 @@
                         mode=mode,
                         stream=kwargs.get("stream", False),
                     )
-                except (ValidationError, JSONDecodeError, AsyncValidationError) as e:
+                except (
+                    ValidationError,
+                    JSONDecodeError,
+                    AsyncValidationError,
+                ) as e:
                     logger.debug(f"Parse error: {e}")
                     hooks.emit_parse_error(e)
                     kwargs = handle_reask_kwargs(
@@ -265,7 +282,8 @@
                         n_attempts=attempt.retry_state.attempt_number,
                         #! deprecate messages soon
                         messages=kwargs.get(
-                            "messages", kwargs.get("contents", kwargs.get("chat_history", []))
+                            "messages",
+                            kwargs.get("contents", kwargs.get("chat_history", [])),
                         ),
                         create_kwargs=kwargs,
                         total_usage=total_usage,
diff --git a/instructor/utils.py b/instructor/utils.py
index 9525677bf..f2d3cfd4a 100644
--- a/instructor/utils.py
+++ b/instructor/utils.py
@@ -55,6 +55,7 @@ class Provider(Enum):
     FIREWORKS = "fireworks"
     WRITER = "writer"
     UNKNOWN = "unknown"
+    BEDROCK = "bedrock"
     PERPLEXITY = "perplexity"
 
 
@@ -96,7 +97,9 @@ def extract_json_from_codeblock(content: str) -> str:
         return content[first_paren : last_paren + 1]
 
 
-def extract_json_from_stream(chunks: Iterable[str]) -> Generator[str, None, None]:
+def extract_json_from_stream(
+    chunks: Iterable[str],
+) -> Generator[str, None, None]:
     capturing = False
     brace_count = 0
     for chunk in chunks:
@@ -144,23 +147,33 @@ def update_total_usage(
         return None
 
     response_usage = getattr(response, "usage", None)
-    if isinstance(response_usage, OpenAIUsage) and isinstance(total_usage, OpenAIUsage):
+    if isinstance(response_usage, OpenAIUsage) and isinstance(
+        total_usage, OpenAIUsage
+    ):
         total_usage.completion_tokens += response_usage.completion_tokens or 0
         total_usage.prompt_tokens += response_usage.prompt_tokens or 0
         total_usage.total_tokens += response_usage.total_tokens or 0
         if (rtd := response_usage.completion_tokens_details) and (
             ttd := total_usage.completion_tokens_details
         ):
-            ttd.audio_tokens = (ttd.audio_tokens or 0) + (rtd.audio_tokens or 0)
+            ttd.audio_tokens = (ttd.audio_tokens or 0) + (
+                rtd.audio_tokens or 0
+            )
             ttd.reasoning_tokens = (ttd.reasoning_tokens or 0) + (
                 rtd.reasoning_tokens or 0
             )
         if (rpd := response_usage.prompt_tokens_details) and (
             tpd := total_usage.prompt_tokens_details
         ):
-            tpd.audio_tokens = (tpd.audio_tokens or 0) + (rpd.audio_tokens or 0)
-            tpd.cached_tokens = (tpd.cached_tokens or 0) + (rpd.cached_tokens or 0)
-        response.usage = total_usage  # Replace each response usage with the total usage
+            tpd.audio_tokens = (tpd.audio_tokens or 0) + (
+                rpd.audio_tokens or 0
+            )
+            tpd.cached_tokens = (tpd.cached_tokens or 0) + (
+                rpd.cached_tokens or 0
+            )
+        response.usage = (
+            total_usage  # Replace each response usage with the total usage
+        )
         return response
 
     # Anthropic usage.
@@ -189,7 +202,9 @@
     except ImportError:
         pass
 
-    logger.debug("No compatible response.usage found, token usage not updated.")
+    logger.debug(
+        "No compatible response.usage found, token usage not updated."
+    )
     return response
 
 
@@ -230,7 +245,9 @@ def is_async(func: Callable[..., Any]) -> bool:
     return is_coroutine
 
 
-def merge_consecutive_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+def merge_consecutive_messages(
+    messages: list[dict[str, Any]]
+) -> list[dict[str, Any]]:
     # merge all consecutive user messages into a single message
     new_messages: list[dict[str, Any]] = []
     # Detect whether all messages have a flat content (i.e. all string)
@@ -242,7 +259,10 @@
             # If content is not flat, transform it into a list of text
             new_content = [{"type": "text", "text": new_content}]
 
-        if len(new_messages) > 0 and message["role"] == new_messages[-1]["role"]:
+        if (
+            len(new_messages) > 0
+            and message["role"] == new_messages[-1]["role"]
+        ):
             if flat_string:
                 # New content is a string
                 new_messages[-1]["content"] += f"\n\n{new_content}"
@@ -314,7 +334,9 @@ def transform_to_gemini_prompt(
     if messages_gemini:
         messages_gemini[0]["parts"].insert(0, f"*{system_prompt}*")
     else:
-        messages_gemini.append({"role": "user", "parts": [f"*{system_prompt}*"]})
+        messages_gemini.append(
+            {"role": "user", "parts": [f"*{system_prompt}*"]}
+        )
 
     return messages_gemini
 
@@ -356,7 +378,9 @@
     schema = add_enum_format(schema)
 
-    return FunctionSchema(**schema).model_dump(exclude_none=True, exclude_unset=True)
+    return FunctionSchema(**schema).model_dump(
+        exclude_none=True, exclude_unset=True
+    )
 
 
 def update_gemini_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
@@ -431,8 +455,12 @@ def combine_system_messages(
     raise ValueError("Unsupported system message type combination")
 
 
-def extract_system_messages(messages: list[dict[str, Any]]) -> list[SystemMessage]:
-    def convert_message(content: Union[str, dict[str, Any]]) -> SystemMessage:  # noqa: UP007
+def extract_system_messages(
+    messages: list[dict[str, Any]]
+) -> list[SystemMessage]:
+    def convert_message(
+        content: Union[str, dict[str, Any]]
+    ) -> SystemMessage:  # noqa: UP007
         if isinstance(content, str):
             return SystemMessage(type="text", text=content)
         elif isinstance(content, dict):
@@ -444,7 +472,9 @@
     for m in messages:
         if m["role"] == "system":
             # System message must always be a string or list of dictionaries
-            content = cast(Union[str, list[dict[str, Any]]], m["content"])  # noqa: UP007
+            content = cast(
+                Union[str, list[dict[str, Any]]], m["content"]
+            )  # noqa: UP007
             if isinstance(content, list):
                 result.extend(convert_message(item) for item in content)
             else:
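Taken together, these pieces give Bedrock the same validate-reask-retry loop as the other providers. A closing end-to-end sketch under the same assumptions as above (illustrative model id; a deliberately strict validator so a reask round is plausible):

```python
import boto3
import instructor
from pydantic import BaseModel, field_validator


class User(BaseModel):
    name: str
    age: int

    @field_validator("name")
    @classmethod
    def name_is_uppercase(cls, v: str) -> str:
        assert v.isupper(), "Name must be uppercase"
        return v


client = instructor.from_bedrock(boto3.client("bedrock-runtime"))

user = client.chat.completions.create(
    modelId="anthropic.claude-3-sonnet-20240229-v1:0",  # assumed example id
    messages=[{"role": "user", "content": [{"text": "Extract: jason is 25"}]}],
    response_model=User,
    max_retries=2,  # a validation failure triggers reask_bedrock_json
)
assert user.name == "JASON"
```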