Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
munday-tech authored Jun 23, 2024
2 parents d95bac6 + 046efe5 commit bc300d6
Show file tree
Hide file tree
Showing 11 changed files with 1,118 additions and 99 deletions.
3 changes: 2 additions & 1 deletion libs/aws/langchain_aws/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from langchain_aws.chat_models import BedrockChat, ChatBedrock
from langchain_aws.chat_models import BedrockChat, ChatBedrock, ChatBedrockConverse
from langchain_aws.embeddings import BedrockEmbeddings
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
from langchain_aws.llms import Bedrock, BedrockLLM, SagemakerEndpoint
Expand All @@ -13,6 +13,7 @@
"BedrockLLM",
"BedrockChat",
"ChatBedrock",
"ChatBedrockConverse",
"SagemakerEndpoint",
"AmazonKendraRetriever",
"AmazonKnowledgeBasesRetriever",
Expand Down
3 changes: 2 additions & 1 deletion libs/aws/langchain_aws/chat_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from langchain_aws.chat_models.bedrock import BedrockChat, ChatBedrock
from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse

__all__ = ["BedrockChat", "ChatBedrock"]
__all__ = ["BedrockChat", "ChatBedrock", "ChatBedrockConverse"]
43 changes: 43 additions & 0 deletions libs/aws/langchain_aws/chat_models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool

from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse
from langchain_aws.function_calling import (
ToolsOutputParser,
_lc_tool_calls_to_anthropic_tool_use_blocks,
Expand Down Expand Up @@ -387,6 +388,9 @@ class ChatBedrock(BaseChatModel, BedrockBase):
"""A chat model that uses the Bedrock API."""

system_prompt_with_tools: str = ""
beta_use_converse_api: bool = False
"""Use the new Bedrock ``converse`` API which provides a standardized interface to
all Bedrock models. Support still in beta. See ChatBedrockConverse docs for more."""

@property
def _llm_type(self) -> str:
Expand Down Expand Up @@ -424,6 +428,11 @@ def _stream(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
if self.beta_use_converse_api:
yield from self._as_converse._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return
provider = self._get_provider()
prompt, system, formatted_messages = None, None, None

Expand Down Expand Up @@ -491,7 +500,14 @@ def _generate(
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:

should_stream = stream if stream is not None else self.streaming

if self.beta_use_converse_api:
return self._as_converse._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)

completion = ""
llm_output: Dict[str, Any] = {}
tool_calls: List[Dict[str, Any]] = []
Expand Down Expand Up @@ -610,6 +626,12 @@ def bind_tools(
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
if self.beta_use_converse_api:
if isinstance(tool_choice, bool):
tool_choice = "any" if tool_choice else None
return self._as_converse.bind_tools(
tools, tool_choice=tool_choice, **kwargs
)
if self._get_provider() == "anthropic":
formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools]

Expand Down Expand Up @@ -747,6 +769,10 @@ class AnswerWithJustification(BaseModel):
# }
""" # noqa: E501
if self.beta_use_converse_api:
return self._as_converse.with_structured_output(
schema, include_raw=include_raw, **kwargs
)
if "claude-3" not in self._get_model():
ValueError(
f"Structured output is not supported for model {self._get_model()}"
Expand All @@ -771,6 +797,23 @@ class AnswerWithJustification(BaseModel):
else:
return llm | output_parser

@property
def _as_converse(self) -> ChatBedrockConverse:
kwargs = {
k: v
for k, v in (self.model_kwargs or {}).items()
if k in ("stop", "stop_sequences", "max_tokens", "temperature", "top_p")
}
return ChatBedrockConverse(
model=self.model_id,
region_name=self.region_name,
credentials_profile_name=self.credentials_profile_name,
config=self.config,
provider=self.provider or "",
base_url=self.endpoint_url,
**kwargs,
)


@deprecated(since="0.1.0", removal="0.2.0", alternative="ChatBedrock")
class BedrockChat(ChatBedrock):
Expand Down
Loading

0 comments on commit bc300d6

Please sign in to comment.