From 54aa05347e304af95ba11c2ed86ae154badda6ce Mon Sep 17 00:00:00 2001 From: Bagatur Date: Sun, 16 Jun 2024 19:55:45 -0700 Subject: [PATCH] fmt --- .../chat_models/bedrock_converse.py | 48 +++++++++++++++++-- .../test_bedrock_converse_standard.py | 2 +- 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py index 737d55a5..4063280e 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py +++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py @@ -493,6 +493,7 @@ def _messages_to_bedrock( if isinstance(msg, HumanMessage): bedrock_messages.append({"role": "user", "content": content}) elif isinstance(msg, AIMessage): + content = _upsert_tool_calls_to_bedrock_content(content, msg.tool_calls) bedrock_messages.append({"role": "assistant", "content": content}) elif isinstance(msg, SystemMessage): if isinstance(msg.content, str): @@ -514,7 +515,7 @@ def _messages_to_bedrock( # TODO: Add status once we have ToolMessage.status support. curr["content"].append( - {"toolResult": {"content": content, "toolUseID": msg.tool_call_id}} + {"toolResult": {"content": content, "toolUseId": msg.tool_call_id}} ) bedrock_messages.append(curr) else: @@ -529,7 +530,7 @@ def _parse_response(response: Dict[str, Any]) -> AIMessage: tool_calls = _extract_tool_calls(anthropic_content) usage = UsageMetadata(_camel_to_snake_keys(response.pop("usage"))) # type: ignore[misc] return AIMessage( - content=anthropic_content, # type: ignore[arg-type] + content=_str_if_single_text_block(anthropic_content), # type: ignore[arg-type] usage_metadata=usage, response_metadata=response, tool_calls=tool_calls, @@ -602,7 +603,7 @@ def _anthropic_to_bedrock( content: Union[str, List[Union[str, Dict[str, Any]]]], ) -> List[Dict[str, Any]]: if isinstance(content, str): - return [{"text": content}] + content = [{"text": content}] bedrock_content: List[Dict[str, Any]] = [] for block in _snake_to_camel_keys(content): if isinstance(block, str): @@ -641,7 +642,7 @@ def _anthropic_to_bedrock( bedrock_content.append( { "toolResult": { - "toolUseID": block["toolUseId"], + "toolUseId": block["toolUseId"], "content": _anthropic_to_bedrock(content), } } @@ -651,7 +652,8 @@ def _anthropic_to_bedrock( bedrock_content.append({"json": block["json"]}) else: raise ValueError(f"Unsupported content block type:\n{block}") - return bedrock_content + # drop empty text blocks + return [block for block in bedrock_content if block.get("text", True)] def _bedrock_to_anthropic(content: List[Dict[str, Any]]) -> List[Dict[str, Any]]: @@ -778,3 +780,39 @@ def _b64str_to_bytes(base64_str: str) -> bytes: def _bytes_to_b64_str(bytes_: bytes) -> str: return base64.b64encode(bytes_).decode("utf-8") + + +def _str_if_single_text_block( + anthropic_content: List[Dict[str, Any]], +) -> Union[str, List[Dict[str, Any]]]: + if len(anthropic_content) == 1 and anthropic_content[0]["type"] == "text": + return anthropic_content[0]["text"] + return anthropic_content + + +def _upsert_tool_calls_to_bedrock_content( + content: List[Dict[str, Any]], tool_calls: List[ToolCall] +) -> List[Dict[str, Any]]: + existing_tc_blocks = [block for block in content if "toolUse" in block] + for tool_call in tool_calls: + if tool_call["id"] in [ + block["toolUse"]["toolUseId"] for block in existing_tc_blocks + ]: + tc_block = next( + block + for block in existing_tc_blocks + if block["toolUse"]["toolUseId"] == tool_call["id"] + ) + tc_block["toolUse"]["input"] = tool_call["args"] + tc_block["toolUse"]["name"] = tool_call["name"] + else: + content.append( + { + "toolUse": { + "toolUseId": tool_call["id"], + "input": tool_call["args"], + "name": tool_call["name"], + } + } + ) + return content diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse_standard.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse_standard.py index 77b107e9..528a64d8 100644 --- a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse_standard.py +++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse_standard.py @@ -18,4 +18,4 @@ def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_params(self) -> dict: return { "model_id": "anthropic.claude-3-sonnet-20240229-v1:0", - } \ No newline at end of file + }