
Adds support for function calling with Anthropic models on Bedrock #37

Merged · 4 commits merged into langchain-ai:main · May 3, 2024

Conversation

@bigbernnn (Contributor)

Workaround for Bedrock support of function calling with Anthropic models. This change adds a bind_tools method to ChatBedrock.

from langchain_aws import ChatBedrock
from langchain_core.messages import HumanMessage
from langchain_core.pydantic_v1 import BaseModel, Field

chat = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",  # any Anthropic Claude model id
    model_kwargs={"temperature": 0.1},
)

class GetWeather(BaseModel):
    """Get the current weather in a given location"""

    location: str = Field(..., description="The city and state, e.g. San Francisco, CA")

llm_with_tools = chat.bind_tools([GetWeather])

messages = [
    HumanMessage(content="what is the weather like in San Francisco")
]
ai_msg = llm_with_tools.invoke(messages)
ai_msg

The workaround is implemented similarly to its equivalent when calling Anthropic directly, where the underlying tools feature is currently in beta.
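For illustration, here is a minimal sketch (not the PR's actual code; the helper name is made up) of how a Pydantic tool definition can be translated into the name/description/input_schema shape that Anthropic's tool-use API expects, which is what a bind_tools-style workaround has to convey to the model:

from langchain_core.pydantic_v1 import BaseModel, Field

class GetWeather(BaseModel):
    """Get the current weather in a given location"""

    location: str = Field(..., description="The city and state, e.g. San Francisco, CA")

def to_anthropic_tool(model_cls) -> dict:
    # Convert a Pydantic v1 model into an Anthropic tool-use schema dict.
    schema = model_cls.schema()
    return {
        "name": schema["title"],
        "description": schema.get("description", ""),
        "input_schema": {
            "type": "object",
            "properties": schema.get("properties", {}),
            "required": schema.get("required", []),
        },
    }

to_anthropic_tool(GetWeather)
# {'name': 'GetWeather', 'description': 'Get the current weather in a given location', ...}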

@3coins (Collaborator) left a comment:

@bigbernnn
Thanks for working on this change. Can you add integration tests or examples to test this?

Review thread on libs/aws/langchain_aws/chat_models/bedrock.py (outdated, resolved)
@bigbernnn (Contributor, Author)

> @bigbernnn Thanks for working on this change. Can you add integration tests or examples to test this?

Integration tests added for both generate and stream responses.
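For reference, a hypothetical sketch of what such an integration test might look like (test names and assertions are illustrative, not the PR's actual tests; running it requires AWS credentials with Bedrock access):

from langchain_aws import ChatBedrock
from langchain_core.messages import HumanMessage
from langchain_core.pydantic_v1 import BaseModel, Field

class GetWeather(BaseModel):
    """Get the current weather in a given location"""

    location: str = Field(..., description="The city and state, e.g. San Francisco, CA")

def test_bind_tools_generate() -> None:
    chat = ChatBedrock(
        model_id="anthropic.claude-3-sonnet-20240229-v1:0",
        model_kwargs={"temperature": 0.1},
    )
    llm_with_tools = chat.bind_tools([GetWeather])
    response = llm_with_tools.invoke(
        [HumanMessage(content="what is the weather like in San Francisco")]
    )
    # The model should answer with a tool-use payload in the message content.
    assert response.content

def test_bind_tools_stream() -> None:
    chat = ChatBedrock(model_id="anthropic.claude-3-sonnet-20240229-v1:0")
    llm_with_tools = chat.bind_tools([GetWeather])
    chunks = list(llm_with_tools.stream("what is the weather like in San Francisco"))
    assert chunks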

@3coins (Collaborator) left a comment:

Looks good!

@3coins 3coins merged commit 123c720 into langchain-ai:main May 3, 2024
12 checks passed
@zhongyu09

Hi @bigbernnn, @3coins, on my side it seems that tool calling doesn't work with langgraph, which relies on tool_calls in AIMessage. At line 399 of bedrock.py, when the AIMessage is created, no tool_calls value is passed to additional_kwargs.
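A hedged sketch of the kind of change being requested (the helper name and surrounding plumbing are assumptions, not the actual bedrock.py code): when constructing the AIMessage, surface the parsed tool calls so downstream consumers such as langgraph can read them.

from typing import List
from langchain_core.messages import AIMessage

def make_ai_message(completion_text: str, parsed_tool_calls: List[dict]) -> AIMessage:
    # Expose tool calls both as the first-class tool_calls field and in
    # additional_kwargs, so frameworks reading either location keep working.
    return AIMessage(
        content=completion_text,
        tool_calls=parsed_tool_calls,
        additional_kwargs={"tool_calls": parsed_tool_calls},
    )

msg = make_ai_message(
    "",
    [{"name": "GetWeather", "args": {"location": "San Francisco, CA"}, "id": "toolu_01"}],
)
msg.tool_calls
# [{'name': 'GetWeather', 'args': {'location': 'San Francisco, CA'}, 'id': 'toolu_01', ...}]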

@thiagotps (Contributor)

Hi @3coins @bigbernnn. The current implementation also seems to ignore the stop_reason and stop_sequence fields returned by the Bedrock API when using the Claude 3 models.
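A minimal sketch of what propagating those fields could look like (the field names follow the Anthropic Messages API response; the helper itself is hypothetical):

from langchain_core.messages import AIMessage

def attach_stop_metadata(text: str, api_response: dict) -> AIMessage:
    # Carry stop_reason / stop_sequence through on response_metadata
    # instead of silently dropping them.
    return AIMessage(
        content=text,
        response_metadata={
            "stop_reason": api_response.get("stop_reason"),
            "stop_sequence": api_response.get("stop_sequence"),
        },
    )

attach_stop_metadata("Done.", {"stop_reason": "end_turn", "stop_sequence": None})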

@mhussar commented Jun 6, 2024

Has any progress been made on this? There doesn't seem to be a standard result that can be parsed to extract the required response; it seems we would have to rely on a hack using regular expressions.
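To make that concrete, this is the kind of regex hack being described (the raw content shape is an assumption for illustration; nothing here is a supported API):

import json
import re

# Assumed example of raw model output with the tool invocation embedded as text.
raw = 'Calling tool: {"name": "GetWeather", "args": {"location": "San Francisco, CA"}}'

match = re.search(r"\{.*\}", raw, re.DOTALL)
if match:
    call = json.loads(match.group(0))
    print(call["name"], call["args"])
# GetWeather {'location': 'San Francisco, CA'}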

@ROZBEH commented Jun 25, 2024

Hi,
Thanks for working on this.
Has any progress been made on this?
The examples listed on the tool-calling documentation page don't work with the model below:

from langchain_aws import ChatBedrock

llm = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    model_kwargs={"temperature": 0.0},
)

The error is something like:

KeyError: 'tool'

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[2], line 41
     33 few_shot_prompt = ChatPromptTemplate.from_messages(
     34     [
     35         ("system", system),
   (...)
     38     ]
     39 )
     40 chain = {"query": RunnablePassthrough()} | few_shot_prompt | llm_with_tools
---> 41 ai_msg = chain.invoke("Whats 119 times 8 minus 20")
     42 msgs = []
     43 msgs.append(HumanMessage("Whats 119 times 8 minus 20"))

File ~/.pyenv/versions/3.11.7/envs/3.11/lib/python3.11/site-packages/langchain_core/runnables/base.py:2504, in RunnableSequence.invoke(self, input, config, **kwargs)
   2502             input = step.invoke(input, config, **kwargs)
   2503         else:
-> 2504             input = step.invoke(input, config)
   2505 # finish the root run
   2506 except BaseException as e:

File ~/.pyenv/versions/3.11.7/envs/3.11/lib/python3.11/site-packages/langchain_core/language_models/chat_models.py:170, in BaseChatModel.invoke(self, input, config, stop, **kwargs)
    159 def invoke(
    160     self,
    161     input: LanguageModelInput,
   (...)
    165     **kwargs: Any,
    166 ) -> BaseMessage:
    167     config = ensure_config(config)
    168     return cast(
    169         ChatGeneration,
--> 170         self.generate_prompt(
    171             [self._convert_input(input)],
    172             stop=stop,
    173             callbacks=config.get("callbacks"),
    174             tags=config.get("tags"),
    175             metadata=config.get("metadata"),
    176             run_name=config.get("run_name"),
    177             run_id=config.pop("run_id", None),
    178             **kwargs,
    179         ).generations[0][0],
    180     ).message

File ~/.pyenv/versions/3.11.7/envs/3.11/lib/python3.11/site-packages/langchain_core/language_models/chat_models.py:599, in BaseChatModel.generate_prompt(self, prompts, stop, callbacks, **kwargs)
    591 def generate_prompt(
    592     self,
    593     prompts: List[PromptValue],
   (...)
    596     **kwargs: Any,
    597 ) -> LLMResult:
    598     prompt_messages = [p.to_messages() for p in prompts]
--> 599     return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)

File ~/.pyenv/versions/3.11.7/envs/3.11/lib/python3.11/site-packages/langchain_core/language_models/chat_models.py:456, in BaseChatModel.generate(self, messages, stop, callbacks, tags, metadata, run_name, run_id, **kwargs)
    454         if run_managers:
    455             run_managers[i].on_llm_error(e, response=LLMResult(generations=[]))
--> 456         raise e
    457 flattened_outputs = [
    458     LLMResult(generations=[res.generations], llm_output=res.llm_output)  # type: ignore[list-item]
    459     for res in results
    460 ]
    461 llm_output = self._combine_llm_outputs([res.llm_output for res in results])

File ~/.pyenv/versions/3.11.7/envs/3.11/lib/python3.11/site-packages/langchain_core/language_models/chat_models.py:446, in BaseChatModel.generate(self, messages, stop, callbacks, tags, metadata, run_name, run_id, **kwargs)
    443 for i, m in enumerate(messages):
    444     try:
    445         results.append(
--> 446             self._generate_with_cache(
    447                 m,
    448                 stop=stop,
    449                 run_manager=run_managers[i] if run_managers else None,
    450                 **kwargs,
    451             )
    452         )
    453     except BaseException as e:
    454         if run_managers:

File ~/.pyenv/versions/3.11.7/envs/3.11/lib/python3.11/site-packages/langchain_core/language_models/chat_models.py:671, in BaseChatModel._generate_with_cache(self, messages, stop, run_manager, **kwargs)
    669 else:
    670     if inspect.signature(self._generate).parameters.get(\"run_manager\"):
--> 671         result = self._generate(
    672             messages, stop=stop, run_manager=run_manager, **kwargs
    673         )
    674     else:
    675         result = self._generate(messages, stop=stop, **kwargs)

File ~/.pyenv/versions/3.11.7/envs/3.11/lib/python3.11/site-packages/langchain_aws/chat_models/bedrock.py:423, in ChatBedrock._generate(self, messages, stop, run_manager, **kwargs)
    420 params: Dict[str, Any] = {**kwargs}
    422 if provider == "anthropic":
--> 423     system, formatted_messages = ChatPromptAdapter.format_messages(
    424         provider, messages
    425     )
    426     if self.system_prompt_with_tools:
    427         if system:

File ~/.pyenv/versions/3.11.7/envs/3.11/lib/python3.11/site-packages/langchain_aws/chat_models/bedrock.py:312, in ChatPromptAdapter.format_messages(cls, provider, messages)
    307 @classmethod
    308 def format_messages(
    309     cls, provider: str, messages: List[BaseMessage]
    310 ) -> Tuple[Optional[str], List[Dict]]:
    311     if provider == "anthropic":
--> 312         return _format_anthropic_messages(messages)
    314     raise NotImplementedError(
    315         f"Provider {provider} not supported for format_messages"
    316     )

File ~/.pyenv/versions/3.11.7/envs/3.11/lib/python3.11/site-packages/langchain_aws/chat_models/bedrock.py:228, in _format_anthropic_messages(messages)
    225     system = message.content
    226     continue
--> 228 role = _message_type_lookups[message.type]
    229 content: Union[str, List[Dict]]
    231 if not isinstance(message.content, str):
    232     # parse as dict

KeyError: 'tool'

@vb-rob commented Jun 27, 2024

Try the new ChatBedrockConverse model instead (still in beta). It supports tool calling, streaming, structured outputs, etc. See the langchain-aws repo.
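A short usage sketch of ChatBedrockConverse as suggested above (parameter values are illustrative; the class was still marked beta at the time):

from langchain_aws import ChatBedrockConverse
from langchain_core.pydantic_v1 import BaseModel, Field

class GetWeather(BaseModel):
    """Get the current weather in a given location"""

    location: str = Field(..., description="The city and state, e.g. San Francisco, CA")

llm = ChatBedrockConverse(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    temperature=0,
)
llm_with_tools = llm.bind_tools([GetWeather])
ai_msg = llm_with_tools.invoke("what is the weather like in San Francisco")
ai_msg.tool_calls
# The Converse API populates tool_calls natively, e.g.
# [{'name': 'GetWeather', 'args': {'location': 'San Francisco, CA'}, 'id': '...'}]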
