Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

aws[patch]: Add ToolCall "type" #111

Merged
merged 1 commit into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions libs/aws/langchain_aws/chat_models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,8 @@
)

from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
Expand All @@ -32,7 +29,7 @@
SystemMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.messages.tool import ToolCall, ToolMessage
from langchain_core.messages.tool import ToolCall, ToolMessage, tool_call_chunk
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Extra
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
Expand Down Expand Up @@ -445,12 +442,12 @@ def _stream(
message = result.generations[0].message
if isinstance(message, AIMessage) and message.tool_calls is not None:
tool_call_chunks = [
{
"name": tool_call["name"],
"args": json.dumps(tool_call["args"]),
"id": tool_call["id"],
"index": idx,
}
tool_call_chunk(
name=tool_call["name"],
args=json.dumps(tool_call["args"]),
id=tool_call["id"],
index=idx,
)
for idx, tool_call in enumerate(message.tool_calls)
]
message_chunk = AIMessageChunk(
Expand Down Expand Up @@ -512,7 +509,7 @@ def _generate(
)
completion = ""
llm_output: Dict[str, Any] = {}
tool_calls: List[Dict[str, Any]] = []
tool_calls: List[ToolCall] = []
provider_stop_reason_code = self.provider_stop_reason_key_map.get(
self._get_provider(), "stop_reason"
)
Expand Down
11 changes: 7 additions & 4 deletions libs/aws/langchain_aws/chat_models/bedrock_converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@
HumanMessageChunk,
SystemMessage,
ToolCall,
ToolCallChunk,
ToolMessage,
)
from langchain_core.messages.ai import AIMessageChunk, UsageMetadata
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
Expand Down Expand Up @@ -591,7 +592,7 @@ def _parse_stream_event(event: Dict[str, Any]) -> Optional[BaseMessageChunk]:
tool_call_chunks = []
if block["type"] == "tool_use":
tool_call_chunks.append(
ToolCallChunk(
tool_call_chunk(
name=block.get("name"),
id=block.get("id"),
args=block.get("input"),
Expand All @@ -607,7 +608,7 @@ def _parse_stream_event(event: Dict[str, Any]) -> Optional[BaseMessageChunk]:
tool_call_chunks = []
if block["type"] == "tool_use":
tool_call_chunks.append(
ToolCallChunk(
tool_call_chunk(
name=block.get("name"),
id=block.get("id"),
args=block.get("input"),
Expand Down Expand Up @@ -782,7 +783,9 @@ def _extract_tool_calls(anthropic_content: List[dict]) -> List[ToolCall]:
for block in anthropic_content:
if block["type"] == "tool_use":
tool_calls.append(
ToolCall(name=block["name"], args=block["input"], id=block["id"])
create_tool_call(
name=block["name"], args=block["input"], id=block["id"]
)
)
return tool_calls

Expand Down
5 changes: 3 additions & 2 deletions libs/aws/langchain_aws/llms/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from langchain_core.language_models import LLM, BaseLanguageModel
from langchain_core.messages import ToolCall
from langchain_core.messages.tool import tool_call
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.utils import get_from_dict_or_env
Expand Down Expand Up @@ -199,7 +200,7 @@ def extract_tool_calls(content: List[dict]) -> List[ToolCall]:
if block["type"] != "tool_use":
continue
tool_calls.append(
ToolCall(name=block["name"], args=block["input"], id=block["id"])
tool_call(name=block["name"], args=block["input"], id=block["id"])
)
return tool_calls

Expand Down Expand Up @@ -632,7 +633,7 @@ def _prepare_input_and_invoke(
**kwargs: Any,
) -> Tuple[
str,
List[dict],
List[ToolCall],
Dict[str, Any],
]:
_model_kwargs = self.model_kwargs or {}
Expand Down
Loading
Loading