Commit d43526b
Merge branch 'langchain-ai:main' into main
bigbernnn authored May 29, 2024
2 parents a73f2f6 + 46f2a7f commit d43526b
Showing 12 changed files with 504 additions and 70 deletions.
2 changes: 1 addition & 1 deletion libs/aws/Makefile
@@ -18,7 +18,7 @@ test tests integration_test integration_tests:
 PYTHON_FILES=.
 MYPY_CACHE=.mypy_cache
 lint format: PYTHON_FILES=.
-lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/aws --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
+lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/aws --name-only --diff-filter=d main | grep -E '\.py$$|\.ipynb$$')
 lint_package: PYTHON_FILES=langchain_aws
 lint_tests: PYTHON_FILES=tests
 lint_tests: MYPY_CACHE=.mypy_cache_test
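The updated lint_diff/format_diff target now diffs against main (not master) and scopes paths to libs/aws, keeping only changed .py and .ipynb files. A rough Python equivalent of that selection logic, purely for illustration (assumes a git checkout with a main branch; this script is not part of the repo):

# Illustrative sketch of the lint_diff file selection above; mirrors the
# Makefile's `git diff | grep` pipeline, not actual repo code.
import re
import subprocess

def changed_python_files(base: str = "main", relative: str = "libs/aws") -> list[str]:
    # --diff-filter=d excludes deleted files, as in the Makefile.
    out = subprocess.run(
        ["git", "diff", f"--relative={relative}", "--name-only",
         "--diff-filter=d", base],
        capture_output=True, text=True, check=True,
    ).stdout
    # Same filter as the grep: keep only .py and .ipynb files.
    return [f for f in out.splitlines() if re.search(r"\.(py|ipynb)$", f)]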
35 changes: 24 additions & 11 deletions libs/aws/langchain_aws/chat_models/bedrock.py
@@ -36,7 +36,10 @@
 from langchain_core.tools import BaseTool
 
 from langchain_aws.function_calling import convert_to_anthropic_tool, get_system_message
-from langchain_aws.llms.bedrock import BedrockBase
+from langchain_aws.llms.bedrock import (
+    BedrockBase,
+    _combine_generation_info_for_llm_result,
+)
 from langchain_aws.utils import (
     get_num_tokens_anthropic,
     get_token_ids_anthropic,
@@ -379,7 +382,13 @@ def _stream(
             **kwargs,
         ):
             delta = chunk.text
-            yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
+            yield ChatGenerationChunk(
+                message=AIMessageChunk(
+                    content=delta, response_metadata=chunk.generation_info
+                )
+                if chunk.generation_info is not None
+                else AIMessageChunk(content=delta)
+            )
 
     def _generate(
         self,
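Each streamed chunk can now carry the provider's generation_info as response_metadata on the message chunk. A minimal consumption sketch (the model_id and credential setup are placeholder assumptions, not part of this diff):

# Minimal sketch: reading per-chunk response_metadata from a streamed reply.
from langchain_aws import ChatBedrock

chat = ChatBedrock(model_id="anthropic.claude-v2", streaming=True)
for chunk in chat.stream("Tell me a short joke"):
    # response_metadata is populated only when the provider returned
    # generation_info for this chunk (see the conditional above).
    print(chunk.content, chunk.response_metadata)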
@@ -389,11 +398,18 @@ def _generate(
         **kwargs: Any,
     ) -> ChatResult:
         completion = ""
-        llm_output: Dict[str, Any] = {"model_id": self.model_id}
-        usage_info: Dict[str, Any] = {}
+        llm_output: Dict[str, Any] = {}
+        provider_stop_reason_code = self.provider_stop_reason_key_map.get(
+            self._get_provider(), "stop_reason"
+        )
         if self.streaming:
+            response_metadata: List[Dict[str, Any]] = []
             for chunk in self._stream(messages, stop, run_manager, **kwargs):
                 completion += chunk.text
+                response_metadata.append(chunk.message.response_metadata)
+            llm_output = _combine_generation_info_for_llm_result(
+                response_metadata, provider_stop_reason_code
+            )
         else:
             provider = self._get_provider()
             prompt, system, formatted_messages = None, None, None
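In the streaming path, _generate now collects each chunk's response_metadata and folds it into a single llm_output via the _combine_generation_info_for_llm_result helper imported above. The helper's body is not shown in this diff; below is a purely illustrative sketch of the reduction its call site implies (names and behavior here are assumptions, not the library's code):

# Illustrative only: the real helper lives in langchain_aws.llms.bedrock and
# may differ. This sketch just shows the shape of the reduction implied by
# the call site above.
from typing import Any, Dict, List

def combine_generation_info(
    chunks_metadata: List[Dict[str, Any]], stop_reason_key: str
) -> Dict[str, Any]:
    usage: Dict[str, int] = {}
    stop_reason = None
    for md in chunks_metadata:
        # Sum token counts across all chunks that reported usage.
        for k, v in md.get("usage", {}).items():
            usage[k] = usage.get(k, 0) + v
        # The last chunk reporting a stop reason wins.
        stop_reason = md.get(stop_reason_key, stop_reason)
    return {"usage": usage, "stop_reason": stop_reason}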
@@ -416,7 +432,7 @@
             if stop:
                 params["stop_sequences"] = stop
 
-            completion, usage_info = self._prepare_input_and_invoke(
+            completion, llm_output = self._prepare_input_and_invoke(
                 prompt=prompt,
                 stop=stop,
                 run_manager=run_manager,
@@ -425,14 +441,11 @@
                 **params,
             )
 
-        llm_output["usage"] = usage_info
-
+        llm_output["model_id"] = self.model_id
         return ChatResult(
             generations=[
                 ChatGeneration(
-                    message=AIMessage(
-                        content=completion, additional_kwargs={"usage": usage_info}
-                    )
+                    message=AIMessage(content=completion, additional_kwargs=llm_output)
                 )
             ],
             llm_output=llm_output,
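With both branches now writing into llm_output, the model_id plus the combined usage and stop-reason details surface on the ChatResult (and in the message's additional_kwargs). A quick way to inspect this, reusing the chat instance from the earlier sketch:

# Sketch: inspecting the combined llm_output after a generate() call.
from langchain_core.messages import HumanMessage

result = chat.generate([[HumanMessage(content="Hello")]])
# Includes "model_id" plus whatever usage/stop-reason keys were combined.
print(result.llm_output)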
@@ -443,7 +456,7 @@ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
         final_output = {}
         for output in llm_outputs:
             output = output or {}
-            usage = output.pop("usage", {})
+            usage = output.get("usage", {})
             for token_type, token_count in usage.items():
                 final_usage[token_type] += token_count
             final_output.update(output)
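Switching pop to get means each output still contains its usage dict when final_output.update(output) runs, so the usage key survives in the merged output while token counts are also summed. A standalone toy illustration (final_usage is assumed to be a defaultdict, as the += in the method implies):

# Toy illustration of the get-vs-pop difference in _combine_llm_outputs.
from collections import defaultdict

llm_outputs = [
    {"model_id": "m", "usage": {"prompt_tokens": 3, "completion_tokens": 5}},
    {"model_id": "m", "usage": {"prompt_tokens": 2, "completion_tokens": 4}},
]
final_usage: defaultdict = defaultdict(int)
final_output: dict = {}
for output in llm_outputs:
    usage = output.get("usage", {})  # .pop() would strip usage from final_output
    for token_type, token_count in usage.items():
        final_usage[token_type] += token_count
    final_output.update(output)
print(dict(final_usage))  # {'prompt_tokens': 5, 'completion_tokens': 9}
print(final_output)       # still carries the last output's "usage"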
2 changes: 2 additions & 0 deletions libs/aws/langchain_aws/llms/__init__.py
@@ -3,6 +3,7 @@
     Bedrock,
     BedrockBase,
     BedrockLLM,
+    LLMInputOutputAdapter,
 )
 from langchain_aws.llms.sagemaker_endpoint import SagemakerEndpoint
 
@@ -11,5 +12,6 @@
"Bedrock",
"BedrockBase",
"BedrockLLM",
"LLMInputOutputAdapter",
"SagemakerEndpoint",
]
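With this re-export, the adapter is importable from the package-level llms module instead of only the bedrock submodule:

# The adapter can now be imported from the public llms module.
from langchain_aws.llms import LLMInputOutputAdapter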
