
Commit

Merge branch 'main' into bedrock-token-count-callbacks
NAPTlME authored May 3, 2024
2 parents 64b3578 + 123c720 commit 1526f5a
Showing 4 changed files with 292 additions and 35 deletions.
69 changes: 67 additions & 2 deletions libs/aws/langchain_aws/chat_models/bedrock.py
@@ -1,11 +1,25 @@
import re
from collections import defaultdict
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)

from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
@@ -16,8 +30,11 @@
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Extra
from langchain_core.pydantic_v1 import BaseModel, Extra
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool

from langchain_aws.function_calling import convert_to_anthropic_tool, get_system_message
from langchain_aws.llms.bedrock import (
BedrockBase,
_combine_generation_info_for_llm_result,
@@ -267,6 +284,8 @@ def format_messages(
class ChatBedrock(BaseChatModel, BedrockBase):
"""A chat model that uses the Bedrock API."""

system_prompt_with_tools: str = ""

@property
def _llm_type(self) -> str:
"""Return type of chat model."""
@@ -310,6 +329,11 @@ def _stream(
system, formatted_messages = ChatPromptAdapter.format_messages(
provider, messages
)
if self.system_prompt_with_tools:
if system:
system = self.system_prompt_with_tools + f"\n{system}"
else:
system = self.system_prompt_with_tools
else:
prompt = ChatPromptAdapter.convert_messages_to_prompt(
provider=provider, messages=messages
@@ -361,6 +385,11 @@ def _generate(
system, formatted_messages = ChatPromptAdapter.format_messages(
provider, messages
)
if self.system_prompt_with_tools:
if system:
system = self.system_prompt_with_tools + f"\n{system}"
else:
system = self.system_prompt_with_tools
else:
prompt = ChatPromptAdapter.convert_messages_to_prompt(
provider=provider, messages=messages
@@ -412,6 +441,42 @@ def get_token_ids(self, text: str) -> List[int]:
else:
return super().get_token_ids(text)

def set_system_prompt_with_tools(self, xml_tools_system_prompt: str) -> None:
"""Workaround to bind. Sets the system prompt with tools"""
self.system_prompt_with_tools = xml_tools_system_prompt

def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
*,
tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.
Assumes model has a tool calling API.
Args:
tools: A list of tool definitions to bind to this chat model.
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
models, callables, and BaseTools will be automatically converted to
their schema dictionary representation.
tool_choice: Which tool to require the model to call.
Must be the name of the single provided function or
"auto" to automatically determine which function to call
(if any), or a dict of the form:
{"type": "function", "function": {"name": <<tool_name>>}}.
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
provider = self._get_provider()

if provider == "anthropic":
formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools]
system_formatted_tools = get_system_message(formatted_tools)
self.set_system_prompt_with_tools(system_formatted_tools)
return self
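
For context, a minimal usage sketch of the new bind_tools path; the GetWeather tool and the model_id are illustrative assumptions, not part of this diff:

# Sketch only: GetWeather and the model_id below are hypothetical examples.
from langchain_core.pydantic_v1 import BaseModel, Field

from langchain_aws import ChatBedrock


class GetWeather(BaseModel):
    """Get the current weather for a city."""

    city: str = Field(description="City to look up")


llm = ChatBedrock(model_id="anthropic.claude-3-sonnet-20240229-v1:0")
# For the "anthropic" provider, bind_tools renders the tools into an XML
# system prompt (see function_calling.py below) and returns the model itself.
llm_with_tools = llm.bind_tools([GetWeather])
response = llm_with_tools.invoke("What's the weather in Paris?")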


@deprecated(since="0.1.0", removal="0.2.0", alternative="ChatBedrock")
class BedrockChat(ChatBedrock):
139 changes: 139 additions & 0 deletions libs/aws/langchain_aws/function_calling.py
@@ -0,0 +1,139 @@
"""Methods for creating function specs in the style of Bedrock Functions
for supported model providers."""

import json
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Type,
Union,
)

from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from typing_extensions import TypedDict

PYTHON_TO_JSON_TYPES = {
"str": "string",
"int": "integer",
"float": "number",
"bool": "boolean",
}

SYSTEM_PROMPT_FORMAT = """In this environment you have access to a set of tools you can use to answer the user's question.
You may call them like this:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
Here are the tools available:
<tools>
{formatted_tools}
</tools>""" # noqa: E501

TOOL_FORMAT = """<tool_description>
<tool_name>{tool_name}</tool_name>
<description>{tool_description}</description>
<parameters>
{formatted_parameters}
</parameters>
</tool_description>"""

TOOL_PARAMETER_FORMAT = """<parameter>
<name>{parameter_name}</name>
<type>{parameter_type}</type>
<description>{parameter_description}</description>
</parameter>"""


class AnthropicTool(TypedDict):
name: str
description: str
input_schema: Dict[str, Any]


def _get_type(parameter: Dict[str, Any]) -> str:
if "type" in parameter:
return parameter["type"]
if "anyOf" in parameter:
return json.dumps({"anyOf": parameter["anyOf"]})
if "allOf" in parameter:
return json.dumps({"allOf": parameter["allOf"]})
return json.dumps(parameter)
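
For illustration, _get_type's precedence on a few hypothetical JSON-schema fragments:

# Hypothetical schema fragments illustrating _get_type's fallbacks.
_get_type({"type": "string"})
# -> "string"
_get_type({"anyOf": [{"type": "string"}, {"type": "null"}]})
# -> '{"anyOf": [{"type": "string"}, {"type": "null"}]}' (JSON-encoded)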


def get_system_message(tools: List[AnthropicTool]) -> str:
tools_data: List[Dict] = [
{
"tool_name": tool["name"],
"tool_description": tool["description"],
"formatted_parameters": "\n".join(
[
TOOL_PARAMETER_FORMAT.format(
parameter_name=name,
parameter_type=_get_type(parameter),
parameter_description=parameter.get("description"),
)
for name, parameter in tool["input_schema"]["properties"].items()
]
),
}
for tool in tools
]
tools_formatted = "\n".join(
[
TOOL_FORMAT.format(
tool_name=tool["tool_name"],
tool_description=tool["tool_description"],
formatted_parameters=tool["formatted_parameters"],
)
for tool in tools_data
]
)
return SYSTEM_PROMPT_FORMAT.format(formatted_tools=tools_formatted)
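
As a sketch of the output, a single made-up tool and the shape of the prompt get_system_message renders for it:

# Made-up tool definition; shows the XML the templates above produce.
tool: AnthropicTool = {
    "name": "get_weather",
    "description": "Get the current weather for a city.",
    "input_schema": {
        "type": "object",
        "properties": {
            "city": {"type": "string", "description": "City name"},
        },
    },
}
prompt = get_system_message([tool])
# `prompt` now contains the SYSTEM_PROMPT_FORMAT text with one
# <tool_description> block for get_weather, including a <parameter>
# entry for `city` of type "string".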


class FunctionDescription(TypedDict):
"""Representation of a callable function to send to an LLM."""

name: str
"""The name of the function."""
description: str
"""A description of the function."""
parameters: dict
"""The parameters of the function."""


class ToolDescription(TypedDict):
"""Representation of a callable function to the OpenAI API."""

type: Literal["function"]
function: FunctionDescription


def convert_to_anthropic_tool(
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
) -> AnthropicTool:
# already in Anthropic tool format
if isinstance(tool, dict) and all(
k in tool for k in ("name", "description", "input_schema")
):
return AnthropicTool(tool) # type: ignore
else:
formatted = convert_to_openai_tool(tool)["function"]
return AnthropicTool(
name=formatted["name"],
description=formatted["description"],
input_schema=formatted["parameters"],
)
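
And a short sketch of the conversion path for a pydantic tool (SearchWeb is a hypothetical example):

# Hypothetical pydantic tool; exercises the convert_to_openai_tool branch.
from langchain_core.pydantic_v1 import BaseModel, Field


class SearchWeb(BaseModel):
    """Search the web for a query."""

    query: str = Field(description="Search query")


anthropic_tool = convert_to_anthropic_tool(SearchWeb)
# -> {"name": "SearchWeb",
#     "description": "Search the web for a query.",
#     "input_schema": {...JSON schema with a "query" property...}}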
54 changes: 23 additions & 31 deletions libs/aws/langchain_aws/llms/bedrock.py
@@ -406,9 +406,9 @@ class BedrockBase(BaseLanguageModel, ABC):
}

guardrails: Optional[Mapping[str, Any]] = {
"id": None,
"version": None,
"trace": False,
"trace": None,
"guardrailIdentifier": None,
"guardrailVersion": None,
}
"""
An optional dictionary to configure guardrails for Bedrock.
@@ -509,7 +509,9 @@ def _identifying_params(self) -> Dict[str, Any]:
"model_id": self.model_id,
"provider": self._get_provider(),
"stream": self.streaming,
"guardrails": self.guardrails,
"trace": self.guardrails.get("trace"), # type: ignore[union-attr]
"guardrailIdentifier": self.guardrails.get("guardrailIdentifier", None), # type: ignore[union-attr]
"guardrailVersion": self.guardrails.get("guardrailVersion", None), # type: ignore[union-attr]
**_model_kwargs,
}
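
Given the renamed keys above, a guardrails configuration now looks like this (the identifier and version values are placeholders):

# Placeholder guardrail values; the key names match this diff.
from langchain_aws import BedrockLLM

llm = BedrockLLM(
    model_id="anthropic.claude-v2",
    guardrails={
        "guardrailIdentifier": "gr-abc123",  # placeholder id
        "guardrailVersion": "1",             # placeholder version
        "trace": True,  # sent as trace="ENABLED" in request_options
    },
)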

@@ -543,32 +545,16 @@ def _guardrails_enabled(self) -> bool:
try:
return (
isinstance(self.guardrails, dict)
and bool(self.guardrails["id"])
and bool(self.guardrails["version"])
and bool(self.guardrails["guardrailIdentifier"])
and bool(self.guardrails["guardrailVersion"])
)

except KeyError as e:
raise TypeError(
"Guardrails must be a dictionary with 'id' and 'version' keys."
"Guardrails must be a dictionary with 'guardrailIdentifier' \
and 'guardrailVersion' keys."
) from e

def _get_guardrails_canonical(self) -> Dict[str, Any]:
"""
The canonical way to pass in guardrails to the bedrock service
adheres to the following format:
"amazon-bedrock-guardrailDetails": {
"guardrailId": "string",
"guardrailVersion": "string"
}
"""
return {
"amazon-bedrock-guardrailDetails": {
"guardrailId": self.guardrails.get("id"), # type: ignore[union-attr]
"guardrailVersion": self.guardrails.get("version"), # type: ignore[union-attr]
}
}

def _prepare_input_and_invoke(
self,
prompt: Optional[str] = None,
@@ -582,8 +568,7 @@ def _prepare_input_and_invoke(

provider = self._get_provider()
params = {**_model_kwargs, **kwargs}
if self._guardrails_enabled:
params.update(self._get_guardrails_canonical())

input_body = LLMInputOutputAdapter.prepare_input(
provider=provider,
model_kwargs=params,
@@ -603,7 +588,12 @@
}

if self._guardrails_enabled:
request_options["guardrail"] = "ENABLED"
request_options["guardrailIdentifier"] = self.guardrails.get( # type: ignore[union-attr]
"guardrailIdentifier", ""
)
request_options["guardrailVersion"] = self.guardrails.get( # type: ignore[union-attr]
"guardrailVersion", ""
)
if self.guardrails.get("trace"): # type: ignore[union-attr]
request_options["trace"] = "ENABLED"

@@ -693,9 +683,6 @@ def _prepare_input_and_invoke_stream(

params = {**_model_kwargs, **kwargs}

if self._guardrails_enabled:
params.update(self._get_guardrails_canonical())

input_body = LLMInputOutputAdapter.prepare_input(
provider=provider,
prompt=prompt,
@@ -713,7 +700,12 @@
}

if self._guardrails_enabled:
request_options["guardrail"] = "ENABLED"
request_options["guardrailIdentifier"] = self.guardrails.get( # type: ignore[union-attr]
"guardrailIdentifier", ""
)
request_options["guardrailVersion"] = self.guardrails.get( # type: ignore[union-attr]
"guardrailVersion", ""
)
if self.guardrails.get("trace"): # type: ignore[union-attr]
request_options["trace"] = "ENABLED"
