diff --git a/libs/aws/langchain_aws/__init__.py b/libs/aws/langchain_aws/__init__.py index 68921e44..384e5b9c 100644 --- a/libs/aws/langchain_aws/__init__.py +++ b/libs/aws/langchain_aws/__init__.py @@ -1,4 +1,4 @@ -from langchain_aws.chat_models import BedrockChat, ChatBedrock +from langchain_aws.chat_models import BedrockChat, ChatBedrock, ChatBedrockConverse from langchain_aws.embeddings import BedrockEmbeddings from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph from langchain_aws.llms import Bedrock, BedrockLLM, SagemakerEndpoint @@ -13,6 +13,7 @@ "BedrockLLM", "BedrockChat", "ChatBedrock", + "ChatBedrockConverse", "SagemakerEndpoint", "AmazonKendraRetriever", "AmazonKnowledgeBasesRetriever", diff --git a/libs/aws/langchain_aws/chat_models/__init__.py b/libs/aws/langchain_aws/chat_models/__init__.py index e334788a..5fcdd69f 100644 --- a/libs/aws/langchain_aws/chat_models/__init__.py +++ b/libs/aws/langchain_aws/chat_models/__init__.py @@ -1,3 +1,4 @@ from langchain_aws.chat_models.bedrock import BedrockChat, ChatBedrock +from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse -__all__ = ["BedrockChat", "ChatBedrock"] +__all__ = ["BedrockChat", "ChatBedrock", "ChatBedrockConverse"] diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py index 4914c45b..fd5a4dc8 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock.py +++ b/libs/aws/langchain_aws/chat_models/bedrock.py @@ -37,6 +37,7 @@ from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool +from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse from langchain_aws.function_calling import ( ToolsOutputParser, _lc_tool_calls_to_anthropic_tool_use_blocks, @@ -387,6 +388,9 @@ class ChatBedrock(BaseChatModel, BedrockBase): """A chat model that uses the Bedrock API.""" system_prompt_with_tools: str = "" + beta_use_converse_api: bool = False + """Use the new Bedrock ``converse`` API which provides a standardized interface to + all Bedrock models. Support still in beta. See ChatBedrockConverse docs for more.""" @property def _llm_type(self) -> str: @@ -424,6 +428,11 @@ def _stream( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: + if self.beta_use_converse_api: + yield from self._as_converse._stream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return provider = self._get_provider() prompt, system, formatted_messages = None, None, None @@ -490,6 +499,10 @@ def _generate( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: + if self.beta_use_converse_api: + return self._as_converse._generate( + messages, stop=stop, run_manager=run_manager, **kwargs + ) completion = "" llm_output: Dict[str, Any] = {} tool_calls: List[Dict[str, Any]] = [] @@ -608,6 +621,12 @@ def bind_tools( **kwargs: Any additional parameters to pass to the :class:`~langchain.runnable.Runnable` constructor. 
""" + if self.beta_use_converse_api: + if isinstance(tool_choice, bool): + tool_choice = "any" if tool_choice else None + return self._as_converse.bind_tools( + tools, tool_choice=tool_choice, **kwargs + ) if self._get_provider() == "anthropic": formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools] @@ -745,6 +764,10 @@ class AnswerWithJustification(BaseModel): # } """ # noqa: E501 + if self.beta_use_converse_api: + return self._as_converse.with_structured_output( + schema, include_raw=include_raw, **kwargs + ) if "claude-3" not in self._get_model(): ValueError( f"Structured output is not supported for model {self._get_model()}" @@ -769,6 +792,23 @@ class AnswerWithJustification(BaseModel): else: return llm | output_parser + @property + def _as_converse(self) -> ChatBedrockConverse: + kwargs = { + k: v + for k, v in (self.model_kwargs or {}).items() + if k in ("stop", "stop_sequences", "max_tokens", "temperature", "top_p") + } + return ChatBedrockConverse( + model=self.model_id, + region_name=self.region_name, + credentials_profile_name=self.credentials_profile_name, + config=self.config, + provider=self.provider or "", + base_url=self.endpoint_url, + **kwargs, + ) + @deprecated(since="0.1.0", removal="0.2.0", alternative="ChatBedrock") class BedrockChat(ChatBedrock): diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py new file mode 100644 index 00000000..c1b3748f --- /dev/null +++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py @@ -0,0 +1,863 @@ +import base64 +import json +import re +from operator import itemgetter +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Literal, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +import boto3 +from langchain_core._api import beta +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models import BaseChatModel, LanguageModelInput +from langchain_core.language_models.chat_models import LangSmithParams +from langchain_core.messages import ( + AIMessage, + BaseMessage, + BaseMessageChunk, + HumanMessage, + HumanMessageChunk, + SystemMessage, + ToolCall, + ToolCallChunk, + ToolMessage, +) +from langchain_core.messages.ai import AIMessageChunk, UsageMetadata +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator +from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough +from langchain_core.tools import BaseTool +from langchain_core.utils import get_from_dict_or_env +from langchain_core.utils.function_calling import convert_to_openai_function + +from langchain_aws.function_calling import ToolsOutputParser + + +@beta() +class ChatBedrockConverse(BaseChatModel): + """Bedrock chat model integration built on the ``converse`` api. + + Setup: + To use Amazon Bedrock make sure you've gone through all the steps described + here: https://docs.aws.amazon.com/bedrock/latest/userguide/setting-up.html + + Once that's completed, install the LangChain integration: + + .. code-block:: bash + + pip install -U langchain-aws + + Key init args — completion params: + model: str + Name of BedrockConverse model to use. + temperature: float + Sampling temperature. + max_tokens: Optional[int] + Max number of tokens to generate. + + + Key init args — client params: + region_name: Optional[str] + AWS region to use, e.g. 'us-west-2'. 
+ base_url: Optional[str] + Bedrock endpoint to use. Needed if you don't want to default to us-east- + 1 endpoint. + credentials_profile_name: Optional[str] + The name of the profile in the ~/.aws/credentials or ~/.aws/config files. + + See full list of supported init args and their descriptions in the params section. + + # TODO: Replace with relevant init params. + Instantiate: + .. code-block:: python + + from langchain_aws import ChatBedrockConverse + + llm = ChatBedrockConverse( + model="anthropic.claude-3-sonnet-20240229-v1:0", + temperature=0, + max_tokens=None, + # other params... + ) + + Invoke: + .. code-block:: python + + messages = [ + ("system", "You are a helpful translator. Translate the user sentence to French."), + ("human", "I love programming."), + ] + llm.invoke(messages) + + .. code-block:: python + + AIMessage(content=[{'type': 'text', 'text': "J'aime la programmation."}], response_metadata={'ResponseMetadata': {'RequestId': '9ef1e313-a4c1-4f79-b631-171f658d3c0e', 'HTTPStatusCode': 200, 'HTTPHeaders': {'date': 'Sat, 15 Jun 2024 01:19:24 GMT', 'content-type': 'application/json', 'content-length': '205', 'connection': 'keep-alive', 'x-amzn-requestid': '9ef1e313-a4c1-4f79-b631-171f658d3c0e'}, 'RetryAttempts': 0}, 'stopReason': 'end_turn', 'metrics': {'latencyMs': 609}}, id='run-754e152b-2b41-4784-9538-d40d71a5c3bc-0', usage_metadata={'input_tokens': 25, 'output_tokens': 11, 'total_tokens': 36}) + + Stream: + .. code-block:: python + + for chunk in llm.stream(messages): + print(chunk) + + .. code-block:: python + + AIMessageChunk(content=[], id='run-da3c2606-4792-440a-ac66-72e0d1f6d117') + AIMessageChunk(content=[{'type': 'text', 'text': 'J', 'index': 0}], id='run-da3c2606-4792-440a-ac66-72e0d1f6d117') + AIMessageChunk(content=[{'text': "'", 'index': 0}], id='run-da3c2606-4792-440a-ac66-72e0d1f6d117') + AIMessageChunk(content=[{'text': 'a', 'index': 0}], id='run-da3c2606-4792-440a-ac66-72e0d1f6d117') + AIMessageChunk(content=[{'text': 'ime', 'index': 0}], id='run-da3c2606-4792-440a-ac66-72e0d1f6d117') + AIMessageChunk(content=[{'text': ' la', 'index': 0}], id='run-da3c2606-4792-440a-ac66-72e0d1f6d117') + AIMessageChunk(content=[{'text': ' programm', 'index': 0}], id='run-da3c2606-4792-440a-ac66-72e0d1f6d117') + AIMessageChunk(content=[{'text': 'ation', 'index': 0}], id='run-da3c2606-4792-440a-ac66-72e0d1f6d117') + AIMessageChunk(content=[{'text': '.', 'index': 0}], id='run-da3c2606-4792-440a-ac66-72e0d1f6d117') + AIMessageChunk(content=[{'index': 0}], id='run-da3c2606-4792-440a-ac66-72e0d1f6d117') + AIMessageChunk(content=[], response_metadata={'stopReason': 'end_turn'}, id='run-da3c2606-4792-440a-ac66-72e0d1f6d117') + AIMessageChunk(content=[], response_metadata={'metrics': {'latencyMs': 581}}, id='run-da3c2606-4792-440a-ac66-72e0d1f6d117', usage_metadata={'input_tokens': 25, 'output_tokens': 11, 'total_tokens': 36}) + + .. code-block:: python + + stream = llm.stream(messages) + full = next(stream) + for chunk in stream: + full += chunk + full + + .. code-block:: python + + AIMessageChunk(content=[{'type': 'text', 'text': "J'aime la programmation.", 'index': 0}], response_metadata={'stopReason': 'end_turn', 'metrics': {'latencyMs': 554}}, id='run-56a5a5e0-de86-412b-9835-624652dc3539', usage_metadata={'input_tokens': 25, 'output_tokens': 11, 'total_tokens': 36}) + + Tool calling: + .. 
code-block:: python + + from langchain_core.pydantic_v1 import BaseModel, Field + + class GetWeather(BaseModel): + '''Get the current weather in a given location''' + + location: str = Field(..., description="The city and state, e.g. San Francisco, CA") + + class GetPopulation(BaseModel): + '''Get the current population in a given location''' + + location: str = Field(..., description="The city and state, e.g. San Francisco, CA") + + llm_with_tools = llm.bind_tools([GetWeather, GetPopulation]) + ai_msg = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?") + ai_msg.tool_calls + + .. code-block:: python + + [{'name': 'GetWeather', + 'args': {'location': 'Los Angeles, CA'}, + 'id': 'tooluse_Mspi2igUTQygp-xbX6XGVw'}, + {'name': 'GetWeather', + 'args': {'location': 'New York, NY'}, + 'id': 'tooluse_tOPHiDhvR2m0xF5_5tyqWg'}, + {'name': 'GetPopulation', + 'args': {'location': 'Los Angeles, CA'}, + 'id': 'tooluse__gcY_klbSC-GqB-bF_pxNg'}, + {'name': 'GetPopulation', + 'args': {'location': 'New York, NY'}, + 'id': 'tooluse_-1HSoGX0TQCSaIg7cdFy8Q'}] + + See ``ChatBedrockConverse.bind_tools()`` method for more. + + Structured output: + .. code-block:: python + + from typing import Optional + + from langchain_core.pydantic_v1 import BaseModel, Field + + class Joke(BaseModel): + '''Joke to tell user.''' + + setup: str = Field(description="The setup of the joke") + punchline: str = Field(description="The punchline to the joke") + rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10") + + structured_llm = llm.with_structured_output(Joke) + structured_llm.invoke("Tell me a joke about cats") + + .. code-block:: python + + Joke(setup='What do you call a cat that gets all dressed up?', punchline='A purrfessional!', rating=7) + + See ``ChatBedrockConverse.with_structured_output()`` for more. + + Image input: + .. code-block:: python + + import base64 + import httpx + from langchain_core.messages import HumanMessage + + image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8") + message = HumanMessage( + content=[ + {"type": "text", "text": "describe the weather in this image"}, + { + "type": "image", + "source": {"type": "base64", "media_type": "image/jpeg", "data": image_data}, + }, + ], + ) + ai_msg = llm.invoke([message]) + ai_msg.content + + .. code-block:: python + + [{'type': 'text', + 'text': 'The image depicts a sunny day with a partly cloudy sky. The sky is a brilliant blue color with scattered white clouds drifting across. The lighting and cloud patterns suggest pleasant, mild weather conditions. The scene shows an open grassy field or meadow, indicating warm temperatures conducive for vegetation growth. Overall, the weather portrayed in this scenic outdoor image appears to be sunny with some clouds, likely representing a nice, comfortable day.'}] + + Token usage: + .. code-block:: python + + ai_msg = llm.invoke(messages) + ai_msg.usage_metadata + + .. code-block:: python + + {'input_tokens': 25, 'output_tokens': 11, 'total_tokens': 36} + + Response metadata + .. code-block:: python + + ai_msg = llm.invoke(messages) + ai_msg.response_metadata + + .. 
code-block:: python

+        {'ResponseMetadata': {'RequestId': '776a2a26-5946-45ae-859e-82dc5f12017c',
+          'HTTPStatusCode': 200,
+          'HTTPHeaders': {'date': 'Mon, 17 Jun 2024 01:37:05 GMT',
+           'content-type': 'application/json',
+           'content-length': '206',
+           'connection': 'keep-alive',
+           'x-amzn-requestid': '776a2a26-5946-45ae-859e-82dc5f12017c'},
+          'RetryAttempts': 0},
+         'stopReason': 'end_turn',
+         'metrics': {'latencyMs': 1290}}
+    """  # noqa: E501
+
+    client: Any = Field(default=None, exclude=True)  #: :meta private:
+
+    model_id: str = Field(alias="model")
+    """ID of the model to call.
+
+    e.g., ``"anthropic.claude-3-sonnet-20240229-v1:0"``. This is equivalent to the
+    modelId property in the list-foundation-models API. For custom and provisioned
+    models, an ARN value is expected. See
+    https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html#model-ids-arns
+    for a list of all supported built-in models.
+    """
+
+    max_tokens: Optional[int] = None
+    """Max tokens to generate."""
+
+    stop_sequences: Optional[List[str]] = Field(None, alias="stop")
+    """Stop generation if any of these substrings occurs."""
+
+    temperature: Optional[float] = None
+    """Sampling temperature. Must be 0 to 1."""
+
+    top_p: Optional[float] = None
+    """The percentage of most-likely candidates that are considered for the next token.
+
+    Must be 0 to 1.
+
+    For example, if you choose a value of 0.8 for topP, the model selects from
+    the top 80% of the probability distribution of tokens that could be next in the
+    sequence."""
+
+    region_name: Optional[str] = None
+    """The AWS region, e.g., `us-west-2`.
+
+    Falls back to AWS_DEFAULT_REGION env variable or region specified in ~/.aws/config
+    in case it is not provided here.
+    """
+
+    credentials_profile_name: Optional[str] = Field(default=None, exclude=True)
+    """The name of the profile in the ~/.aws/credentials or ~/.aws/config files.
+
+    Profile should either have access keys or role information specified.
+    If not specified, the default credential profile or, if on an EC2 instance,
+    credentials from IMDS will be used. See:
+    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+    """
+
+    provider: str = ""
+    """The model provider, e.g., amazon, cohere, ai21, etc.
+
+    When not supplied, provider is extracted from the first part of the model_id, e.g.
+    'amazon' in 'amazon.titan-text-express-v1'. This value should be provided for model
+    ids that do not have the provider in them, like custom and provisioned models that
+    have an ARN associated with them.
+ """ + + endpoint_url: Optional[str] = Field(None, alias="base_url") + """Needed if you don't want to default to us-east-1 endpoint""" + + config: Any = None + """An optional botocore.config.Config instance to pass to the client.""" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + allow_population_by_field_name = True + + @root_validator(pre=False, skip_on_failure=True) + def validate_environment(cls, values: Dict) -> Dict: + """Validate that AWS credentials to and python package exists in environment.""" + values["provider"] = values["provider"] or values["model_id"].split(".")[0] + + if values["client"] is not None: + return values + + try: + if values["credentials_profile_name"] is not None: + session = boto3.Session(profile_name=values["credentials_profile_name"]) + else: + session = boto3.Session() + except ValueError as e: + raise ValueError(f"Error raised by bedrock service: {e}") + except Exception as e: + raise ValueError( + "Could not load credentials to authenticate with AWS client. " + "Please check that credentials in the specified " + f"profile name are valid. Bedrock error: {e}" + ) from e + + values["region_name"] = get_from_dict_or_env( + values, + "region_name", + "AWS_DEFAULT_REGION", + default=session.region_name, + ) + + client_params = {} + if values["region_name"]: + client_params["region_name"] = values["region_name"] + if values["endpoint_url"]: + client_params["endpoint_url"] = values["endpoint_url"] + if values["config"]: + client_params["config"] = values["config"] + + try: + values["client"] = session.client("bedrock-runtime", **client_params) + except ValueError as e: + raise ValueError(f"Error raised by bedrock service: {e}") + except Exception as e: + raise ValueError( + "Could not load credentials to authenticate with AWS client. " + "Please check that credentials in the specified " + f"profile name are valid. Bedrock error: {e}" + ) from e + + return values + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """Top Level call""" + bedrock_messages, system = _messages_to_bedrock(messages) + params = self._converse_params(stop=stop, **_snake_to_camel_keys(kwargs)) + response = self.client.converse( + messages=bedrock_messages, system=system, **params + ) + response_message = _parse_response(response) + return ChatResult(generations=[ChatGeneration(message=response_message)]) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + bedrock_messages, system = _messages_to_bedrock(messages) + params = self._converse_params(stop=stop, **_snake_to_camel_keys(kwargs)) + response = self.client.converse_stream( + messages=bedrock_messages, system=system, **params + ) + for event in response["stream"]: + if message_chunk := _parse_stream_event(event): + yield ChatGenerationChunk(message=message_chunk) + + # TODO: Add async support once there are async bedrock.converse methods. 
+
+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+        *,
+        tool_choice: Optional[Union[dict, str, Literal["auto", "any"]]] = None,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        if tool_choice:
+            kwargs["tool_choice"] = _format_tool_choice(tool_choice)
+        return self.bind(tools=_format_tools(tools), **kwargs)
+
+    def with_structured_output(
+        self,
+        schema: Union[Dict, Type[BaseModel]],
+        *,
+        include_raw: bool = False,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
+        llm = self.bind_tools([schema], tool_choice="any")
+        if isinstance(schema, type) and issubclass(schema, BaseModel):
+            output_parser = ToolsOutputParser(
+                first_tool_only=True, pydantic_schemas=[schema]
+            )
+        else:
+            output_parser = ToolsOutputParser(first_tool_only=True, args_only=True)
+
+        if include_raw:
+            parser_assign = RunnablePassthrough.assign(
+                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
+            )
+            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
+            parser_with_fallback = parser_assign.with_fallbacks(
+                [parser_none], exception_key="parsing_error"
+            )
+            return RunnableMap(raw=llm) | parser_with_fallback
+        else:
+            return llm | output_parser
+
+    def _converse_params(
+        self,
+        *,
+        stop: Optional[List[str]] = None,
+        stopSequences: Optional[List[str]] = None,
+        maxTokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        topP: Optional[float] = None,
+        tools: Optional[List] = None,
+        toolChoice: Optional[dict] = None,
+        modelId: Optional[str] = None,
+        inferenceConfig: Optional[dict] = None,
+        toolConfig: Optional[dict] = None,
+        additionalModelRequestFields: Optional[dict] = None,
+        additionalModelResponseFieldPaths: Optional[List[str]] = None,
+    ) -> Dict[str, Any]:
+        if not inferenceConfig:
+            inferenceConfig = {
+                "maxTokens": maxTokens or self.max_tokens,
+                "temperature": temperature or self.temperature,
+                "topP": topP or self.top_p,
+                "stopSequences": stop or stopSequences or self.stop_sequences,
+            }
+        if not toolConfig and tools:
+            toolChoice = _format_tool_choice(toolChoice) if toolChoice else None
+            toolConfig = {"tools": _format_tools(tools), "toolChoice": toolChoice}
+
+        return _drop_none(
+            {
+                "modelId": modelId or self.model_id,
+                "inferenceConfig": inferenceConfig,
+                "toolConfig": toolConfig,
+                "additionalModelRequestFields": additionalModelRequestFields,
+                "additionalModelResponseFieldPaths": additionalModelResponseFieldPaths,
+            }
+        )
+
+    def _get_ls_params(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> LangSmithParams:
+        """Get standard params for tracing."""
+        params = self._get_invocation_params(stop=stop, **kwargs)
+        ls_params = LangSmithParams(
+            ls_provider="amazon_bedrock",
+            ls_model_name=self.model_id,
+            ls_model_type="chat",
+            ls_temperature=params.get("temperature", self.temperature),
+        )
+        if ls_max_tokens := params.get("max_tokens", self.max_tokens):
+            ls_params["ls_max_tokens"] = ls_max_tokens
+        if ls_stop := stop or params.get("stop", None):
+            ls_params["ls_stop"] = ls_stop
+        return ls_params
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of chat model."""
+        return "amazon_bedrock_converse_chat"
+
+
+def _messages_to_bedrock(
+    messages: List[BaseMessage],
+) -> Tuple[List[Dict[str, Any]], List[Dict[Literal["text"], str]]]:
+    """Handle Bedrock converse and Anthropic style content blocks"""
+    bedrock_messages: List[Dict[str, Any]] = []
+    bedrock_system: List[Dict[Literal["text"], str]] = []
+    for msg in messages:
+        content = _anthropic_to_bedrock(msg.content)
+        if isinstance(msg, HumanMessage):
+            bedrock_messages.append({"role": "user", "content": content})
+        elif isinstance(msg, AIMessage):
+            content = _upsert_tool_calls_to_bedrock_content(content, msg.tool_calls)
+            bedrock_messages.append({"role": "assistant", "content": content})
+        elif isinstance(msg, SystemMessage):
+            if isinstance(msg.content, str):
+                bedrock_system.append({"text": msg.content})
+            else:
+                bedrock_system.extend(
+                    [
+                        {"text": block["text"] if isinstance(block, dict) else block}
+                        for block in msg.content
+                    ]
+                )
+        elif isinstance(msg, ToolMessage):
+            if bedrock_messages and bedrock_messages[-1]["role"] == "user":
+                curr = bedrock_messages.pop()
+            else:
+                curr = {"role": "user", "content": []}
+
+            # TODO: Add status once we have ToolMessage.status support.
+            curr["content"].append(
+                {"toolResult": {"content": content, "toolUseId": msg.tool_call_id}}
+            )
+            bedrock_messages.append(curr)
+        else:
+            raise ValueError(f"Unsupported message type {type(msg)}")
+    return bedrock_messages, bedrock_system
+
+
+def _parse_response(response: Dict[str, Any]) -> AIMessage:
+    anthropic_content = _bedrock_to_anthropic(
+        response.pop("output")["message"]["content"]
+    )
+    tool_calls = _extract_tool_calls(anthropic_content)
+    usage = UsageMetadata(_camel_to_snake_keys(response.pop("usage")))  # type: ignore[misc]
+    return AIMessage(
+        content=_str_if_single_text_block(anthropic_content),  # type: ignore[arg-type]
+        usage_metadata=usage,
+        response_metadata=response,
+        tool_calls=tool_calls,
+    )
+
+
+def _parse_stream_event(event: Dict[str, Any]) -> Optional[BaseMessageChunk]:
+    if "messageStart" in event:
+        # TODO: needed?
+        return (
+            AIMessageChunk(content=[])
+            if event["messageStart"]["role"] == "assistant"
+            else HumanMessageChunk(content=[])
+        )
+    elif "contentBlockStart" in event:
+        block = {
+            **_bedrock_to_anthropic([event["contentBlockStart"]["start"]])[0],
+            "index": event["contentBlockStart"]["contentBlockIndex"],
+        }
+        tool_call_chunks = []
+        if block["type"] == "tool_use":
+            tool_call_chunks.append(
+                ToolCallChunk(
+                    name=block.get("name"),
+                    id=block.get("id"),
+                    args=block.get("input"),
+                    index=event["contentBlockStart"]["contentBlockIndex"],
+                )
+            )
+        return AIMessageChunk(content=[block], tool_call_chunks=tool_call_chunks)
+    elif "contentBlockDelta" in event:
+        block = {
+            **_bedrock_to_anthropic([event["contentBlockDelta"]["delta"]])[0],
+            "index": event["contentBlockDelta"]["contentBlockIndex"],
+        }
+        tool_call_chunks = []
+        if block["type"] == "tool_use":
+            tool_call_chunks.append(
+                ToolCallChunk(
+                    name=block.get("name"),
+                    id=block.get("id"),
+                    args=block.get("input"),
+                    index=event["contentBlockDelta"]["contentBlockIndex"],
+                )
+            )
+        return AIMessageChunk(content=[block], tool_call_chunks=tool_call_chunks)
+    elif "contentBlockStop" in event:
+        # TODO: needed?
+        return AIMessageChunk(
+            content=[{"index": event["contentBlockStop"]["contentBlockIndex"]}]
+        )
+    elif "messageStop" in event:
+        # TODO: snake case response metadata?
+        return AIMessageChunk(content=[], response_metadata=event["messageStop"])
+    elif "metadata" in event:
+        usage = UsageMetadata(_camel_to_snake_keys(event["metadata"].pop("usage")))  # type: ignore[misc]
+        return AIMessageChunk(
+            content=[], response_metadata=event["metadata"], usage_metadata=usage
+        )
+    elif "Exception" in list(event.keys())[0]:
+        name, info = list(event.items())[0]
+        raise ValueError(
+            f"Received AWS exception {name}:\n\n{json.dumps(info, indent=2)}"
+        )
+    else:
+        raise ValueError(f"Received unsupported stream event:\n\n{event}")
+
+
+def _anthropic_to_bedrock(
+    content: Union[str, List[Union[str, Dict[str, Any]]]],
+) -> List[Dict[str, Any]]:
+    if isinstance(content, str):
+        content = [{"text": content}]
+    bedrock_content: List[Dict[str, Any]] = []
+    for block in _snake_to_camel_keys(content):
+        if isinstance(block, str):
+            bedrock_content.append({"text": block})
+        # Assume block is already in bedrock format.
+        elif "type" not in block:
+            bedrock_content.append(block)
+        elif block["type"] == "text":
+            bedrock_content.append({"text": block["text"]})
+        elif block["type"] == "image":
+            # Assume block is already in bedrock format.
+            if "image" in block:
+                bedrock_content.append(block)
+            else:
+                bedrock_content.append(
+                    {
+                        "image": {
+                            "format": block["source"]["mediaType"].split("/")[1],
+                            "source": {
+                                "bytes": _b64str_to_bytes(block["source"]["data"])
+                            },
+                        }
+                    }
+                )
+        elif block["type"] == "image_url":
+            # Support OpenAI image format as well.
+            bedrock_content.append(
+                {"image": _format_openai_image_url(block["imageUrl"]["url"])}
+            )
+        elif block["type"] == "tool_use":
+            bedrock_content.append(
+                {
+                    "toolUse": {
+                        "toolUseId": block["id"],
+                        "input": block["input"],
+                        "name": block["name"],
+                    }
+                }
+            )
+        elif block["type"] == "tool_result":
+            bedrock_content.append(
+                {
+                    "toolResult": {
+                        "toolUseId": block["toolUseId"],
+                        "content": _anthropic_to_bedrock(block["content"]),
+                    }
+                }
+            )
+        # Only needed for tool_result content blocks.
+        elif block["type"] == "json":
+            bedrock_content.append({"json": block["json"]})
+        else:
+            raise ValueError(f"Unsupported content block type:\n{block}")
+    # drop empty text blocks
+    return [block for block in bedrock_content if block.get("text", True)]
+
+
+def _bedrock_to_anthropic(content: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    anthropic_content = []
+    for block in _camel_to_snake_keys(content):
+        if "text" in block:
+            anthropic_content.append({"type": "text", "text": block["text"]})
+        elif "tool_use" in block:
+            block["tool_use"]["id"] = block["tool_use"].pop("tool_use_id", None)
+            anthropic_content.append({"type": "tool_use", **block["tool_use"]})
+        elif "image" in block:
+            anthropic_content.append(
+                {
+                    "type": "image",
+                    "source": {
+                        "media_type": f"image/{block['image']['format']}",
+                        "type": "base64",
+                        "data": _bytes_to_b64_str(block["image"]["source"]["bytes"]),
+                    },
+                }
+            )
+        elif "tool_result" in block:
+            anthropic_content.append(
+                {
+                    "type": "tool_result",
+                    "tool_use_id": block["tool_result"]["tool_use_id"],
+                    "is_error": block["tool_result"]["status"] == "error",
+                    "content": _bedrock_to_anthropic(block["tool_result"]["content"]),
+                }
+            )
+        # Only occurs in content blocks of a tool_result:
+        elif "json" in block:
+            anthropic_content.append({"type": "json", **block})
+        else:
+            raise ValueError(
+                "Unexpected content block type in content. Expected to have one of "
+                "'text', 'tool_use', 'image', or 'tool_result' keys. 
Received:\n\n" + f"{block}" + ) + return anthropic_content + + +def _format_tools( + tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],], +) -> List[Dict[Literal["toolSpec"], Dict[str, Union[Dict[str, Any], str]]]]: + formatted_tools: List = [] + for tool in tools: + if isinstance(tool, dict) and "toolSpec" in tool: + formatted_tools.append(tool) + else: + spec = convert_to_openai_function(tool) + spec["inputSchema"] = {"json": spec.pop("parameters")} + formatted_tools.append({"toolSpec": spec}) + return formatted_tools + + +def _format_tool_choice( + tool_choice: Union[Dict[str, Dict], Literal["auto", "any"], str], +) -> Dict[str, Dict[str, str]]: + if isinstance(tool_choice, dict): + return tool_choice + elif tool_choice in ("auto", "any"): + return {tool_choice: {}} + else: + return {"tool": {"name": tool_choice}} + + +def _extract_tool_calls(anthropic_content: List[dict]) -> List[ToolCall]: + tool_calls = [] + for block in anthropic_content: + if block["type"] == "tool_use": + tool_calls.append( + ToolCall(name=block["name"], args=block["input"], id=block["id"]) + ) + return tool_calls + + +def _snake_to_camel(text: str) -> str: + split = text.split("_") + return "".join(split[:1] + [s.title() for s in split[1:]]) + + +def _camel_to_snake(text: str) -> str: + pattern = re.compile(r"(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])") + return pattern.sub("_", text).lower() + + +_T = TypeVar("_T") + + +def _camel_to_snake_keys(obj: _T) -> _T: + if isinstance(obj, list): + return cast(_T, [_camel_to_snake_keys(e) for e in obj]) + elif isinstance(obj, dict): + return cast( + _T, {_camel_to_snake(k): _camel_to_snake_keys(v) for k, v in obj.items()} + ) + else: + return obj + + +def _snake_to_camel_keys(obj: _T) -> _T: + if isinstance(obj, list): + return cast(_T, [_snake_to_camel_keys(e) for e in obj]) + elif isinstance(obj, dict): + return cast( + _T, {_snake_to_camel(k): _snake_to_camel_keys(v) for k, v in obj.items()} + ) + else: + return obj + + +def _drop_none(obj: Any) -> Any: + if isinstance(obj, dict): + new = {k: _drop_none(v) for k, v in obj.items() if _drop_none(v) is not None} + return new or None + else: + return obj + + +def _b64str_to_bytes(base64_str: str) -> bytes: + return base64.b64decode(base64_str.encode("utf-8")) + + +def _bytes_to_b64_str(bytes_: bytes) -> str: + return base64.b64encode(bytes_).decode("utf-8") + + +def _str_if_single_text_block( + anthropic_content: List[Dict[str, Any]], +) -> Union[str, List[Dict[str, Any]]]: + if len(anthropic_content) == 1 and anthropic_content[0]["type"] == "text": + return anthropic_content[0]["text"] + return anthropic_content + + +def _upsert_tool_calls_to_bedrock_content( + content: List[Dict[str, Any]], tool_calls: List[ToolCall] +) -> List[Dict[str, Any]]: + existing_tc_blocks = [block for block in content if "toolUse" in block] + for tool_call in tool_calls: + if tool_call["id"] in [ + block["toolUse"]["toolUseId"] for block in existing_tc_blocks + ]: + tc_block = next( + block + for block in existing_tc_blocks + if block["toolUse"]["toolUseId"] == tool_call["id"] + ) + tc_block["toolUse"]["input"] = tool_call["args"] + tc_block["toolUse"]["name"] = tool_call["name"] + else: + content.append( + { + "toolUse": { + "toolUseId": tool_call["id"], + "input": tool_call["args"], + "name": tool_call["name"], + } + } + ) + return content + + +def _format_openai_image_url(image_url: str) -> Dict: + """ + Formats an image of format data:image/jpeg;base64,{b64_string} + to a dict for bedrock api. 
+
+    And throws an error if url is not a b64 image.
+    """
+    regex = r"^data:image/(?P<media_type>.+);base64,(?P<data>.+)$"
+    match = re.match(regex, image_url)
+    if match is None:
+        raise ValueError(
+            "Bedrock does not currently support OpenAI-format image URLs, only "
+            "base64-encoded images. Example: data:image/png;base64,'/9j/4AAQSk'..."
+        )
+    return {
+        "format": match.group("media_type"),
+        "source": {"bytes": _b64str_to_bytes(match.group("data"))},
+    }
diff --git a/libs/aws/langchain_aws/function_calling.py b/libs/aws/langchain_aws/function_calling.py
index 1e2c53e1..c33f709b 100644
--- a/libs/aws/langchain_aws/function_calling.py
+++ b/libs/aws/langchain_aws/function_calling.py
@@ -175,17 +175,14 @@ def parse_result(self, result: List[Generation], *, partial: bool = False) -> An
         Returns:
             Structured output.
         """
-        if not result or not isinstance(result[0], ChatGeneration):
+        if (
+            not result
+            or not isinstance(result[0], ChatGeneration)
+            or not isinstance(result[0].message, AIMessage)
+            or not result[0].message.tool_calls
+        ):
             return None if self.first_tool_only else []
-        message = result[0].message
-        if len(message.content) > 0:
-            tool_calls: List = []
-        else:
-            content = cast(AIMessage, message)
-            _tool_calls = [dict(tc) for tc in content.tool_calls]
-            # Map tool call id to index
-            id_to_index = {block["id"]: i for i, block in enumerate(_tool_calls)}
-            tool_calls = [{**tc, "index": id_to_index[tc["id"]]} for tc in _tool_calls]
+        tool_calls: Any = result[0].message.tool_calls
         if self.pydantic_schemas:
             tool_calls = [self._pydantic_parse(tc) for tc in tool_calls]
         elif self.args_only:
@@ -194,11 +191,11 @@ def parse_result(self, result: List[Generation], *, partial: bool = False) -> An
             pass
 
         if self.first_tool_only:
-            return tool_calls[0] if tool_calls else None
+            return tool_calls[0]
         else:
-            return [tool_call for tool_call in tool_calls]
+            return tool_calls
 
-    def _pydantic_parse(self, tool_call: dict) -> BaseModel:
+    def _pydantic_parse(self, tool_call: ToolCall) -> BaseModel:
         cls_ = {schema.__name__: schema for schema in self.pydantic_schemas or []}[
             tool_call["name"]
         ]
diff --git a/libs/aws/poetry.lock b/libs/aws/poetry.lock
index d0f665c2..f24b1f11 100644
--- a/libs/aws/poetry.lock
+++ b/libs/aws/poetry.lock
@@ -16,17 +16,17 @@ typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""}
 
 [[package]]
 name = "boto3"
-version = "1.34.106"
+version = "1.34.127"
 description = "The AWS SDK for Python"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "boto3-1.34.106-py3-none-any.whl", hash = "sha256:d3be4e1dd5d546a001cd4da805816934cbde9d395316546e9411fec341ade5cf"},
-    {file = "boto3-1.34.106.tar.gz", hash = "sha256:6165b8cf1c7e625628ab28b32f9027064c8f5e5fca1c38d7fc228cd22069a19f"},
+    {file = "boto3-1.34.127-py3-none-any.whl", hash = "sha256:d370befe4fb7aea5bc383057d7dad18dda5d0cf3cd3295915bcc8c8c4191905c"},
+    {file = "boto3-1.34.127.tar.gz", hash = "sha256:58ccdeae3a96811ecc9d5d866d8226faadbd0ee1891756e4a04d5186e9a57a64"},
 ]
 
 [package.dependencies]
-botocore = ">=1.34.106,<1.35.0"
+botocore = ">=1.34.127,<1.35.0"
 jmespath = ">=0.7.1,<2.0.0"
 s3transfer = ">=0.10.0,<0.11.0"
 
 [package.extras]
 crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
 
 [[package]]
 name = "botocore"
-version = "1.34.106"
+version = "1.34.127"
 description = "Low-level, data-driven core of boto 3."
optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.34.106-py3-none-any.whl", hash = "sha256:4baf0e27c2dfc4f4d0dee7c217c716e0782f9b30e8e1fff983fce237d88f73ae"}, - {file = "botocore-1.34.106.tar.gz", hash = "sha256:921fa5202f88c3e58fdcb4b3acffd56d65b24bca47092ee4b27aa988556c0be6"}, + {file = "botocore-1.34.127-py3-none-any.whl", hash = "sha256:e14fa28c8bb141de965e700f88b196d17c67a703c7f0f5c7e14f7dd1cf636011"}, + {file = "botocore-1.34.127.tar.gz", hash = "sha256:a377871742c40603d559103f19acb7bc93cfaf285e68f21b81637ec396099877"}, ] [package.dependencies] @@ -53,7 +53,7 @@ urllib3 = [ ] [package.extras] -crt = ["awscrt (==0.20.9)"] +crt = ["awscrt (==0.20.11)"] [[package]] name = "certifi" @@ -334,7 +334,7 @@ files = [ [[package]] name = "langchain-core" -version = "0.2.5" +version = "0.2.6" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -344,7 +344,7 @@ develop = false [package.dependencies] jsonpatch = "^1.33" langsmith = "^0.1.75" -packaging = "^23.2" +packaging = ">=23.2,<25" pydantic = ">=1,<3" PyYAML = ">=5.3" tenacity = "^8.1.0" @@ -353,12 +353,12 @@ tenacity = "^8.1.0" type = "git" url = "https://github.com/langchain-ai/langchain.git" reference = "HEAD" -resolved_reference = "00ad19750255008e6f7a86b4c0e89530a4b2a0cc" +resolved_reference = "7234fd0f51fbcc3444253fa528208a1f7a8829c3" subdirectory = "libs/core" [[package]] name = "langchain-standard-tests" -version = "0.1.0" +version = "0.1.1" description = "Standard tests for LangChain implementations" optional = false python-versions = ">=3.8.1,<4.0" @@ -373,7 +373,7 @@ pytest = ">=7,<9" type = "git" url = "https://github.com/langchain-ai/langchain.git" reference = "HEAD" -resolved_reference = "936aedd10cdc22cb19d33b255ca65a8ddab3ab50" +resolved_reference = "6605ae22f6001d7428eca57c55d9f22c521abe6f" subdirectory = "libs/standard-tests" [[package]] @@ -1011,4 +1011,4 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "fd4ce90ec2f2c93efaf779201bdef7e4ac8ae76b74b15693ddc80311f37f5f71" +content-hash = "d0e65eb18a8405ef09838722e42c2c4b76007fd7062cadadc4391f9476372751" diff --git a/libs/aws/pyproject.toml b/libs/aws/pyproject.toml index 7c34a6e0..04a519c2 100644 --- a/libs/aws/pyproject.toml +++ b/libs/aws/pyproject.toml @@ -12,8 +12,8 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" -langchain-core = ">=0.2.2,<0.3" -boto3 = ">=1.34.51,<1.35.0" +langchain-core = ">=0.2.6,<0.3" +boto3 = ">=1.34.127,<1.35.0" numpy = "^1" [tool.poetry.group.test] @@ -27,6 +27,7 @@ pytest-asyncio = "^0.23.2" langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" } langchain-standard-tests = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/standard-tests"} + [tool.poetry.group.codespell] optional = true @@ -45,7 +46,7 @@ optional = true ruff = "^0.1.8" [tool.poetry.group.typing.dependencies] -mypy = "^1.7.1" +mypy = "^1.7" types-requests = "^2.28.11.5" langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" } diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py new file mode 100644 index 00000000..7c0429b0 --- /dev/null +++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py @@ -0,0 +1,30 @@ +"""Standard LangChain interface tests""" + +from 
typing import Type + +from langchain_core.language_models import BaseChatModel +from langchain_standard_tests.integration_tests import ChatModelIntegrationTests + +from langchain_aws import ChatBedrockConverse + + +class TestBedrockStandard(ChatModelIntegrationTests): + @property + def chat_model_class(self) -> Type[BaseChatModel]: + return ChatBedrockConverse + + @property + def chat_model_params(self) -> dict: + return {"model": "anthropic.claude-3-sonnet-20240229-v1:0"} + + @property + def standard_chat_model_params(self) -> dict: + return { + "temperature": 0, + "max_tokens": 100, + "stop": [], + } + + @property + def supports_image_inputs(self) -> bool: + return True diff --git a/libs/aws/tests/integration_tests/chat_models/test_standard.py b/libs/aws/tests/integration_tests/chat_models/test_standard.py index f9b211cf..6d60d769 100644 --- a/libs/aws/tests/integration_tests/chat_models/test_standard.py +++ b/libs/aws/tests/integration_tests/chat_models/test_standard.py @@ -10,67 +10,68 @@ class TestBedrockStandard(ChatModelIntegrationTests): - @pytest.fixture + @property def chat_model_class(self) -> Type[BaseChatModel]: return ChatBedrock - @pytest.fixture + @property def chat_model_params(self) -> dict: - return { - "model_id": "anthropic.claude-3-sonnet-20240229-v1:0", - } + return {"model_id": "anthropic.claude-3-sonnet-20240229-v1:0"} + + @property + def standard_chat_model_params(self) -> dict: + return {} @pytest.mark.xfail(reason="Not implemented.") - def test_usage_metadata( - self, - chat_model_class: Type[BaseChatModel], - chat_model_params: dict, - ) -> None: - super().test_usage_metadata( - chat_model_class, - chat_model_params, - ) + def test_usage_metadata(self, model: BaseChatModel) -> None: + super().test_usage_metadata(model) @pytest.mark.xfail(reason="Not implemented.") - def test_stop_sequence( - self, - chat_model_class: Type[BaseChatModel], - chat_model_params: dict, - ) -> None: - super().test_stop_sequence( - chat_model_class, - chat_model_params, - ) + def test_stop_sequence(self, model: BaseChatModel) -> None: + super().test_stop_sequence(model) @pytest.mark.xfail(reason="Not yet implemented.") - def test_tool_message_histories_string_content( - self, - chat_model_class: Type[BaseChatModel], - chat_model_params: dict, - chat_model_has_tool_calling: bool, - ) -> None: - super().test_tool_message_histories_string_content( - chat_model_class, chat_model_params, chat_model_has_tool_calling - ) + def test_tool_message_histories_string_content(self, model: BaseChatModel) -> None: + super().test_tool_message_histories_string_content(model) @pytest.mark.xfail(reason="Not yet implemented.") - def test_tool_message_histories_list_content( - self, - chat_model_class: Type[BaseChatModel], - chat_model_params: dict, - chat_model_has_tool_calling: bool, - ) -> None: - super().test_tool_message_histories_list_content( - chat_model_class, chat_model_params, chat_model_has_tool_calling - ) + def test_tool_message_histories_list_content(self, model: BaseChatModel) -> None: + super().test_tool_message_histories_list_content(model) @pytest.mark.xfail(reason="Not yet implemented.") def test_structured_few_shot_examples( self, - chat_model_class: Type[BaseChatModel], - chat_model_params: dict, - chat_model_has_tool_calling: bool, + model: BaseChatModel, ) -> None: - super().test_structured_few_shot_examples( - chat_model_class, chat_model_params, chat_model_has_tool_calling - ) + super().test_structured_few_shot_examples(model) + + +class 
TestBedrockUseConverseStandard(ChatModelIntegrationTests): + @property + def chat_model_class(self) -> Type[BaseChatModel]: + return ChatBedrock + + @property + def chat_model_params(self) -> dict: + return { + "model_id": "anthropic.claude-3-sonnet-20240229-v1:0", + "beta_use_converse_api": True, + } + + @property + def standard_chat_model_params(self) -> dict: + return { + "model_kwargs": { + "temperature": 0, + "max_tokens": 100, + "stop": [], + } + } + + @property + def supports_image_inputs(self) -> bool: + return True + + @pytest.mark.xfail(reason="Not implemented.") + def test_stop_sequence(self, model: BaseChatModel) -> None: + super().test_stop_sequence(model) diff --git a/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py new file mode 100644 index 00000000..facfe06c --- /dev/null +++ b/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py @@ -0,0 +1,68 @@ +"""Test chat model integration.""" + +from typing import Type, cast + +import pytest +from langchain_core.language_models import BaseChatModel +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.runnables import RunnableBinding +from langchain_standard_tests.unit_tests import ChatModelUnitTests + +from langchain_aws import ChatBedrockConverse + + +class TestBedrockStandard(ChatModelUnitTests): + @property + def chat_model_class(self) -> Type[BaseChatModel]: + return ChatBedrockConverse + + @property + def chat_model_params(self) -> dict: + return { + "model": "anthropic.claude-3-sonnet-20240229-v1:0", + "region_name": "us-west-1", + } + + @property + def standard_chat_model_params(self) -> dict: + return { + "temperature": 0, + "max_tokens": 100, + "stop": [], + } + + @pytest.mark.xfail() + def test_init_streaming(self) -> None: + super().test_init_streaming() + + +class GetWeather(BaseModel): + """Get the current weather in a given location""" + + location: str = Field(..., description="The city and state, e.g. 
San Francisco, CA") + + +def test_anthropic_bind_tools_tool_choice() -> None: + chat_model = ChatBedrockConverse( + model="anthropic.claude-3-sonnet-20240229-v1:0", region_name="us-west-2" + ) # type: ignore[call-arg] + chat_model_with_tools = chat_model.bind_tools( + [GetWeather], tool_choice={"tool": {"name": "GetWeather"}} + ) + assert cast(RunnableBinding, chat_model_with_tools).kwargs["tool_choice"] == { + "tool": {"name": "GetWeather"} + } + chat_model_with_tools = chat_model.bind_tools( + [GetWeather], tool_choice="GetWeather" + ) + assert cast(RunnableBinding, chat_model_with_tools).kwargs["tool_choice"] == { + "tool": {"name": "GetWeather"} + } + chat_model_with_tools = chat_model.bind_tools([GetWeather], tool_choice="auto") + assert cast(RunnableBinding, chat_model_with_tools).kwargs["tool_choice"] == { + "auto": {} + } + chat_model_with_tools = chat_model.bind_tools([GetWeather], tool_choice="any") + assert cast(RunnableBinding, chat_model_with_tools).kwargs["tool_choice"] == { + "any": {} + } diff --git a/libs/aws/tests/unit_tests/test_standard.py b/libs/aws/tests/unit_tests/test_standard.py index a414e16e..951636e6 100644 --- a/libs/aws/tests/unit_tests/test_standard.py +++ b/libs/aws/tests/unit_tests/test_standard.py @@ -10,35 +10,49 @@ class TestBedrockStandard(ChatModelUnitTests): - @pytest.fixture + @property def chat_model_class(self) -> Type[BaseChatModel]: return ChatBedrock - @pytest.fixture + @property def chat_model_params(self) -> dict: return { "model_id": "anthropic.claude-3-sonnet-20240229-v1:0", "region_name": "us-east-1", } + @property + def standard_chat_model_params(self) -> dict: + return {} + @pytest.mark.xfail(reason="Not implemented.") - def test_chat_model_init_api_key( - self, - chat_model_class: Type[BaseChatModel], - chat_model_params: dict, - ) -> None: - super().test_chat_model_init_api_key( - chat_model_class, - chat_model_params, - ) + def test_standard_params(self, model: BaseChatModel) -> None: + super().test_standard_params(model) + + +class TestBedrockAsConverseStandard(ChatModelUnitTests): + @property + def chat_model_class(self) -> Type[BaseChatModel]: + return ChatBedrock + + @property + def chat_model_params(self) -> dict: + return { + "model_id": "anthropic.claude-3-sonnet-20240229-v1:0", + "region_name": "us-east-1", + "beta_use_converse_api": True, + } + + @property + def standard_chat_model_params(self) -> dict: + return { + "model_kwargs": { + "temperature": 0, + "max_tokens": 100, + "stop": [], + } + } @pytest.mark.xfail(reason="Not implemented.") - def test_standard_params( - self, - chat_model_class: Type[BaseChatModel], - chat_model_params: dict, - ) -> None: - super().test_standard_params( - chat_model_class, - chat_model_params, - ) + def test_standard_params(self, model: BaseChatModel) -> None: + super().test_standard_params(model)