Skip to content

Commit

Permalink
fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan committed Jun 17, 2024
1 parent 3d315b8 commit 54aa053
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 6 deletions.
48 changes: 43 additions & 5 deletions libs/aws/langchain_aws/chat_models/bedrock_converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ def _messages_to_bedrock(
if isinstance(msg, HumanMessage):
bedrock_messages.append({"role": "user", "content": content})
elif isinstance(msg, AIMessage):
content = _upsert_tool_calls_to_bedrock_content(content, msg.tool_calls)
bedrock_messages.append({"role": "assistant", "content": content})
elif isinstance(msg, SystemMessage):
if isinstance(msg.content, str):
Expand All @@ -514,7 +515,7 @@ def _messages_to_bedrock(

# TODO: Add status once we have ToolMessage.status support.
curr["content"].append(
{"toolResult": {"content": content, "toolUseID": msg.tool_call_id}}
{"toolResult": {"content": content, "toolUseId": msg.tool_call_id}}
)
bedrock_messages.append(curr)
else:
Expand All @@ -529,7 +530,7 @@ def _parse_response(response: Dict[str, Any]) -> AIMessage:
tool_calls = _extract_tool_calls(anthropic_content)
usage = UsageMetadata(_camel_to_snake_keys(response.pop("usage"))) # type: ignore[misc]
return AIMessage(
content=anthropic_content, # type: ignore[arg-type]
content=_str_if_single_text_block(anthropic_content), # type: ignore[arg-type]
usage_metadata=usage,
response_metadata=response,
tool_calls=tool_calls,
Expand Down Expand Up @@ -602,7 +603,7 @@ def _anthropic_to_bedrock(
content: Union[str, List[Union[str, Dict[str, Any]]]],
) -> List[Dict[str, Any]]:
if isinstance(content, str):
return [{"text": content}]
content = [{"text": content}]
bedrock_content: List[Dict[str, Any]] = []
for block in _snake_to_camel_keys(content):
if isinstance(block, str):
Expand Down Expand Up @@ -641,7 +642,7 @@ def _anthropic_to_bedrock(
bedrock_content.append(
{
"toolResult": {
"toolUseID": block["toolUseId"],
"toolUseId": block["toolUseId"],
"content": _anthropic_to_bedrock(content),
}
}
Expand All @@ -651,7 +652,8 @@ def _anthropic_to_bedrock(
bedrock_content.append({"json": block["json"]})
else:
raise ValueError(f"Unsupported content block type:\n{block}")
return bedrock_content
# drop empty text blocks
return [block for block in bedrock_content if block.get("text", True)]


def _bedrock_to_anthropic(content: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
Expand Down Expand Up @@ -778,3 +780,39 @@ def _b64str_to_bytes(base64_str: str) -> bytes:

def _bytes_to_b64_str(bytes_: bytes) -> str:
return base64.b64encode(bytes_).decode("utf-8")


def _str_if_single_text_block(
anthropic_content: List[Dict[str, Any]],
) -> Union[str, List[Dict[str, Any]]]:
if len(anthropic_content) == 1 and anthropic_content[0]["type"] == "text":
return anthropic_content[0]["text"]
return anthropic_content


def _upsert_tool_calls_to_bedrock_content(
content: List[Dict[str, Any]], tool_calls: List[ToolCall]
) -> List[Dict[str, Any]]:
existing_tc_blocks = [block for block in content if "toolUse" in block]
for tool_call in tool_calls:
if tool_call["id"] in [
block["toolUse"]["toolUseId"] for block in existing_tc_blocks
]:
tc_block = next(
block
for block in existing_tc_blocks
if block["toolUse"]["toolUseId"] == tool_call["id"]
)
tc_block["toolUse"]["input"] = tool_call["args"]
tc_block["toolUse"]["name"] = tool_call["name"]
else:
content.append(
{
"toolUse": {
"toolUseId": tool_call["id"],
"input": tool_call["args"],
"name": tool_call["name"],
}
}
)
return content
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ def chat_model_class(self) -> Type[BaseChatModel]:
def chat_model_params(self) -> dict:
return {
"model_id": "anthropic.claude-3-sonnet-20240229-v1:0",
}
}

0 comments on commit 54aa053

Please sign in to comment.