From df76647b11a5466f8cb3603f9f57337d73b1d5b4 Mon Sep 17 00:00:00 2001
From: ccurme
Date: Tue, 25 Jun 2024 12:32:46 -0400
Subject: [PATCH] ChatBedrock: add usage metadata (#85)

---
 libs/aws/langchain_aws/chat_models/bedrock.py | 28 +++++++++--
 libs/aws/langchain_aws/llms/bedrock.py        | 15 ++++++
 .../chat_models/test_bedrock.py               | 46 ++++++++++++++++---
 .../chat_models/test_standard.py              |  4 --
 4 files changed, 79 insertions(+), 14 deletions(-)

diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py
index fd5a4dc8..1f51b629 100644
--- a/libs/aws/langchain_aws/chat_models/bedrock.py
+++ b/libs/aws/langchain_aws/chat_models/bedrock.py
@@ -31,6 +31,7 @@
     HumanMessage,
     SystemMessage,
 )
+from langchain_core.messages.ai import UsageMetadata
 from langchain_core.messages.tool import ToolCall, ToolMessage
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import BaseModel, Extra
@@ -484,9 +485,15 @@ def _stream(
                 **kwargs,
             ):
                 delta = chunk.text
+                if generation_info := chunk.generation_info:
+                    usage_metadata = generation_info.pop("usage_metadata", None)
+                else:
+                    usage_metadata = None
                 yield ChatGenerationChunk(
                     message=AIMessageChunk(
-                        content=delta, response_metadata=chunk.generation_info
+                        content=delta,
+                        response_metadata=chunk.generation_info,
+                        usage_metadata=usage_metadata,
                     )
                     if chunk.generation_info is not None
                     else AIMessageChunk(content=delta)
@@ -550,16 +557,31 @@ def _generate(
                 messages=formatted_messages,
                 **params,
             )
-
+        # usage metadata
+        if usage := llm_output.get("usage"):
+            input_tokens = usage.get("prompt_tokens", 0)
+            output_tokens = usage.get("completion_tokens", 0)
+            usage_metadata = UsageMetadata(
+                input_tokens=input_tokens,
+                output_tokens=output_tokens,
+                total_tokens=usage.get("total_tokens", input_tokens + output_tokens),
+            )
+        else:
+            usage_metadata = None
         llm_output["model_id"] = self.model_id
         if len(tool_calls) > 0:
             msg = AIMessage(
                 content=completion,
                 additional_kwargs=llm_output,
                 tool_calls=cast(List[ToolCall], tool_calls),
+                usage_metadata=usage_metadata,
             )
         else:
-            msg = AIMessage(content=completion, additional_kwargs=llm_output)
+            msg = AIMessage(
+                content=completion,
+                additional_kwargs=llm_output,
+                usage_metadata=usage_metadata,
+            )
         return ChatResult(
             generations=[
                 ChatGeneration(
diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py
index 24416eac..2fcf346e 100644
--- a/libs/aws/langchain_aws/llms/bedrock.py
+++ b/libs/aws/langchain_aws/llms/bedrock.py
@@ -180,6 +180,19 @@ def _combine_generation_info_for_llm_result(
     return {"usage": total_usage_info, "stop_reason": stop_reason}
 
 
+def _get_invocation_metrics_chunk(chunk: Dict[str, Any]) -> GenerationChunk:
+    generation_info = {}
+    if metrics := chunk.get("amazon-bedrock-invocationMetrics"):
+        input_tokens = metrics.get("inputTokenCount", 0)
+        output_tokens = metrics.get("outputTokenCount", 0)
+        generation_info["usage_metadata"] = {
+            "input_tokens": input_tokens,
+            "output_tokens": output_tokens,
+            "total_tokens": input_tokens + output_tokens,
+        }
+    return GenerationChunk(text="", generation_info=generation_info)
+
+
 def extract_tool_calls(content: List[dict]) -> List[ToolCall]:
     tool_calls = []
     for block in content:
@@ -330,9 +343,11 @@ def prepare_output_stream(
                 provider == "mistral"
                 and chunk_obj.get(output_key, [{}])[0].get("stop_reason", "") == "stop"
             ):
+                yield _get_invocation_metrics_chunk(chunk_obj)
                 return
 
             elif messages_api and (chunk_obj.get("type") == "message_stop"):
+                yield _get_invocation_metrics_chunk(chunk_obj)
                 return
 
             generation_chunk = _stream_response_to_generation_chunk(
diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
index 1d1d8e96..5e05cddf 100644
--- a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
+++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
@@ -1,6 +1,6 @@
 """Test Bedrock chat model."""
 import json
-from typing import Any, cast
+from typing import Any
 
 import pytest
 from langchain_core.messages import (
@@ -124,22 +124,54 @@ def on_llm_end(
 
 
 @pytest.mark.scheduled
-def test_bedrock_streaming(chat: ChatBedrock) -> None:
-    """Test streaming tokens from OpenAI."""
-
+@pytest.mark.parametrize(
+    "model_id",
+    [
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "mistral.mistral-7b-instruct-v0:2",
+    ],
+)
+def test_bedrock_streaming(model_id: str) -> None:
+    chat = ChatBedrock(
+        model_id=model_id,
+        model_kwargs={"temperature": 0},
+    )  # type: ignore[call-arg]
     full = None
     for token in chat.stream("I'm Pickle Rick"):
         full = token if full is None else full + token  # type: ignore[operator]
         assert isinstance(token.content, str)
-    assert isinstance(cast(AIMessageChunk, full).content, str)
+    assert isinstance(full, AIMessageChunk)
+    assert isinstance(full.content, str)
+    assert full.usage_metadata is not None
+    assert full.usage_metadata["input_tokens"] > 0
+    assert full.usage_metadata["output_tokens"] > 0
+    assert full.usage_metadata["total_tokens"] > 0
 
 
 @pytest.mark.scheduled
-async def test_bedrock_astream(chat: ChatBedrock) -> None:
+@pytest.mark.parametrize(
+    "model_id",
+    [
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "mistral.mistral-7b-instruct-v0:2",
+    ],
+)
+async def test_bedrock_astream(model_id: str) -> None:
     """Test streaming tokens from OpenAI."""
-
+    chat = ChatBedrock(
+        model_id=model_id,
+        model_kwargs={"temperature": 0},
+    )  # type: ignore[call-arg]
+    full = None
     async for token in chat.astream("I'm Pickle Rick"):
+        full = token if full is None else full + token  # type: ignore[operator]
         assert isinstance(token.content, str)
+    assert isinstance(full, AIMessageChunk)
+    assert isinstance(full.content, str)
+    assert full.usage_metadata is not None
+    assert full.usage_metadata["input_tokens"] > 0
+    assert full.usage_metadata["output_tokens"] > 0
+    assert full.usage_metadata["total_tokens"] > 0
 
 
 @pytest.mark.scheduled
diff --git a/libs/aws/tests/integration_tests/chat_models/test_standard.py b/libs/aws/tests/integration_tests/chat_models/test_standard.py
index 6d60d769..43434169 100644
--- a/libs/aws/tests/integration_tests/chat_models/test_standard.py
+++ b/libs/aws/tests/integration_tests/chat_models/test_standard.py
@@ -22,10 +22,6 @@ def chat_model_params(self) -> dict:
     def standard_chat_model_params(self) -> dict:
         return {}
 
-    @pytest.mark.xfail(reason="Not implemented.")
-    def test_usage_metadata(self, model: BaseChatModel) -> None:
-        super().test_usage_metadata(model)
-
     @pytest.mark.xfail(reason="Not implemented.")
     def test_stop_sequence(self, model: BaseChatModel) -> None:
         super().test_stop_sequence(model)