ChatBedrock: add usage metadata (#85)
ccurme authored Jun 25, 2024
1 parent fa37438 commit df76647
Showing 4 changed files with 79 additions and 14 deletions.
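In effect, messages returned by `ChatBedrock` now carry token counts in the typed `usage_metadata` field rather than only in provider-specific metadata. A minimal sketch of reading it (the model ID is one used in the integration tests below; AWS credentials are assumed to be configured):

```python
from langchain_aws import ChatBedrock

chat = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    model_kwargs={"temperature": 0},
)  # type: ignore[call-arg]
msg = chat.invoke("I'm Pickle Rick")
# With this commit the counts surface in a typed field, e.g.:
# {'input_tokens': 12, 'output_tokens': 34, 'total_tokens': 46}
print(msg.usage_metadata)
```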
28 changes: 25 additions & 3 deletions libs/aws/langchain_aws/chat_models/bedrock.py
```diff
@@ -31,6 +31,7 @@
     HumanMessage,
     SystemMessage,
 )
+from langchain_core.messages.ai import UsageMetadata
 from langchain_core.messages.tool import ToolCall, ToolMessage
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import BaseModel, Extra
```
```diff
@@ -484,9 +485,15 @@ def _stream(
                 **kwargs,
             ):
                 delta = chunk.text
+                if generation_info := chunk.generation_info:
+                    usage_metadata = generation_info.pop("usage_metadata", None)
+                else:
+                    usage_metadata = None
                 yield ChatGenerationChunk(
                     message=AIMessageChunk(
-                        content=delta, response_metadata=chunk.generation_info
+                        content=delta,
+                        response_metadata=chunk.generation_info,
+                        usage_metadata=usage_metadata,
                     )
                     if chunk.generation_info is not None
                     else AIMessageChunk(content=delta)
```
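The counts for a streamed response arrive inside the final chunk's `generation_info` (populated by the LLM layer, below) and are popped out here so they land on the typed `usage_metadata` field rather than being duplicated into `response_metadata`. A standalone sketch of that hand-off, with invented counts:

```python
from langchain_core.messages import AIMessageChunk

# Hypothetical generation_info of the final streamed chunk (see
# _get_invocation_metrics_chunk in llms/bedrock.py below):
generation_info = {
    "usage_metadata": {"input_tokens": 12, "output_tokens": 34, "total_tokens": 46}
}
usage_metadata = generation_info.pop("usage_metadata", None)
chunk = AIMessageChunk(
    content="",  # the metrics chunk carries no text
    response_metadata=generation_info,  # now without the usage entry
    usage_metadata=usage_metadata,
)
print(chunk.usage_metadata)
```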
```diff
@@ -550,16 +557,31 @@ def _generate(
                 messages=formatted_messages,
                 **params,
             )
-
+        # usage metadata
+        if usage := llm_output.get("usage"):
+            input_tokens = usage.get("prompt_tokens", 0)
+            output_tokens = usage.get("completion_tokens", 0)
+            usage_metadata = UsageMetadata(
+                input_tokens=input_tokens,
+                output_tokens=output_tokens,
+                total_tokens=usage.get("total_tokens", input_tokens + output_tokens),
+            )
+        else:
+            usage_metadata = None
         llm_output["model_id"] = self.model_id
         if len(tool_calls) > 0:
             msg = AIMessage(
                 content=completion,
                 additional_kwargs=llm_output,
                 tool_calls=cast(List[ToolCall], tool_calls),
+                usage_metadata=usage_metadata,
             )
         else:
-            msg = AIMessage(content=completion, additional_kwargs=llm_output)
+            msg = AIMessage(
+                content=completion,
+                additional_kwargs=llm_output,
+                usage_metadata=usage_metadata,
+            )
         return ChatResult(
             generations=[
                 ChatGeneration(
```
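The `usage` dict follows the OpenAI-style `prompt_tokens` / `completion_tokens` names that the Bedrock LLM layer normalizes to, so the block above remaps them onto LangChain's `input_tokens` / `output_tokens`, computing the total when the provider does not report one. A standalone sketch with invented numbers:

```python
from langchain_core.messages.ai import UsageMetadata

usage = {"prompt_tokens": 12, "completion_tokens": 34}  # hypothetical llm_output["usage"]
input_tokens = usage.get("prompt_tokens", 0)
output_tokens = usage.get("completion_tokens", 0)
usage_metadata = UsageMetadata(
    input_tokens=input_tokens,
    output_tokens=output_tokens,
    # fall back to a computed total when none is reported
    total_tokens=usage.get("total_tokens", input_tokens + output_tokens),
)
print(usage_metadata)  # {'input_tokens': 12, 'output_tokens': 34, 'total_tokens': 46}
```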
15 changes: 15 additions & 0 deletions libs/aws/langchain_aws/llms/bedrock.py
```diff
@@ -180,6 +180,19 @@ def _combine_generation_info_for_llm_result(
     return {"usage": total_usage_info, "stop_reason": stop_reason}
 
 
+def _get_invocation_metrics_chunk(chunk: Dict[str, Any]) -> GenerationChunk:
+    generation_info = {}
+    if metrics := chunk.get("amazon-bedrock-invocationMetrics"):
+        input_tokens = metrics.get("inputTokenCount", 0)
+        output_tokens = metrics.get("outputTokenCount", 0)
+        generation_info["usage_metadata"] = {
+            "input_tokens": input_tokens,
+            "output_tokens": output_tokens,
+            "total_tokens": input_tokens + output_tokens,
+        }
+    return GenerationChunk(text="", generation_info=generation_info)
+
+
 def extract_tool_calls(content: List[dict]) -> List[ToolCall]:
     tool_calls = []
     for block in content:
```
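Bedrock attaches an `amazon-bedrock-invocationMetrics` object to the last event of a response stream; the helper above folds it into an empty-text `GenerationChunk` whose `generation_info` carries the counts. A sketch of its behavior on an invented final event (the latency fields are illustrative extras that the helper ignores):

```python
chunk_obj = {
    "type": "message_stop",
    "amazon-bedrock-invocationMetrics": {
        "inputTokenCount": 12,
        "outputTokenCount": 34,
        "invocationLatency": 901,  # ignored by the helper
        "firstByteLatency": 250,   # ignored by the helper
    },
}
metrics_chunk = _get_invocation_metrics_chunk(chunk_obj)
assert metrics_chunk.text == ""
assert metrics_chunk.generation_info["usage_metadata"]["total_tokens"] == 46
```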
```diff
@@ -330,9 +343,11 @@ def prepare_output_stream(
                 provider == "mistral"
                 and chunk_obj.get(output_key, [{}])[0].get("stop_reason", "") == "stop"
             ):
+                yield _get_invocation_metrics_chunk(chunk_obj)
                 return
 
             elif messages_api and (chunk_obj.get("type") == "message_stop"):
+                yield _get_invocation_metrics_chunk(chunk_obj)
                 return
 
             generation_chunk = _stream_response_to_generation_chunk(
```
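Both termination paths (Mistral's `stop` and the Anthropic messages API's `message_stop`) now yield the metrics chunk immediately before returning, so the stream gains one trailing empty-text chunk carrying the usage numbers. A sketch of what a consumer sees (a hand-built stream stands in for the real one):

```python
from langchain_core.outputs import GenerationChunk

# Hypothetical yield sequence from prepare_output_stream: text chunks,
# then the empty metrics chunk emitted at stream termination.
stream = [
    GenerationChunk(text="Hello"),
    GenerationChunk(text=" world"),
    GenerationChunk(
        text="",
        generation_info={
            "usage_metadata": {"input_tokens": 12, "output_tokens": 34, "total_tokens": 46}
        },
    ),
]
usage = None
for chunk in stream:
    usage = (chunk.generation_info or {}).get("usage_metadata", usage)
print(usage)  # the trailing metrics chunk supplies the totals
```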
46 changes: 39 additions & 7 deletions libs/aws/tests/integration_tests/chat_models/test_bedrock.py
```diff
@@ -1,6 +1,6 @@
 """Test Bedrock chat model."""
 import json
-from typing import Any, cast
+from typing import Any
 
 import pytest
 from langchain_core.messages import (
```
```diff
@@ -124,22 +124,54 @@ def on_llm_end(
 
 
 @pytest.mark.scheduled
-def test_bedrock_streaming(chat: ChatBedrock) -> None:
-    """Test streaming tokens from OpenAI."""
-
+@pytest.mark.parametrize(
+    "model_id",
+    [
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "mistral.mistral-7b-instruct-v0:2",
+    ],
+)
+def test_bedrock_streaming(model_id: str) -> None:
+    chat = ChatBedrock(
+        model_id=model_id,
+        model_kwargs={"temperature": 0},
+    )  # type: ignore[call-arg]
     full = None
     for token in chat.stream("I'm Pickle Rick"):
         full = token if full is None else full + token  # type: ignore[operator]
         assert isinstance(token.content, str)
-    assert isinstance(cast(AIMessageChunk, full).content, str)
+    assert isinstance(full, AIMessageChunk)
+    assert isinstance(full.content, str)
+    assert full.usage_metadata is not None
+    assert full.usage_metadata["input_tokens"] > 0
+    assert full.usage_metadata["output_tokens"] > 0
+    assert full.usage_metadata["total_tokens"] > 0
 
 
 @pytest.mark.scheduled
-async def test_bedrock_astream(chat: ChatBedrock) -> None:
+@pytest.mark.parametrize(
+    "model_id",
+    [
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "mistral.mistral-7b-instruct-v0:2",
+    ],
+)
+async def test_bedrock_astream(model_id: str) -> None:
     """Test streaming tokens from OpenAI."""
 
+    chat = ChatBedrock(
+        model_id=model_id,
+        model_kwargs={"temperature": 0},
+    )  # type: ignore[call-arg]
     full = None
     async for token in chat.astream("I'm Pickle Rick"):
         full = token if full is None else full + token  # type: ignore[operator]
         assert isinstance(token.content, str)
-    assert isinstance(full.content, str)
+    assert isinstance(full, AIMessageChunk)
+    assert isinstance(full.content, str)
+    assert full.usage_metadata is not None
+    assert full.usage_metadata["input_tokens"] > 0
+    assert full.usage_metadata["output_tokens"] > 0
+    assert full.usage_metadata["total_tokens"] > 0
 
 
 @pytest.mark.scheduled
```
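The streaming assertions rely on `AIMessageChunk` addition merging `usage_metadata` across chunks; only the trailing metrics chunk carries counts, so the accumulated message ends up with exactly those totals. A sketch of the merge, with invented counts:

```python
from langchain_core.messages import AIMessageChunk

text_chunk = AIMessageChunk(content="Hello")
metrics_chunk = AIMessageChunk(
    content="",
    usage_metadata={"input_tokens": 12, "output_tokens": 34, "total_tokens": 46},
)
full = text_chunk + metrics_chunk  # __add__ merges usage_metadata
print(full.usage_metadata)  # {'input_tokens': 12, 'output_tokens': 34, 'total_tokens': 46}
```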
4 changes: 0 additions & 4 deletions libs/aws/tests/integration_tests/chat_models/test_standard.py
```diff
@@ -22,10 +22,6 @@ def chat_model_params(self) -> dict:
     def standard_chat_model_params(self) -> dict:
         return {}
 
-    @pytest.mark.xfail(reason="Not implemented.")
-    def test_usage_metadata(self, model: BaseChatModel) -> None:
-        super().test_usage_metadata(model)
-
     @pytest.mark.xfail(reason="Not implemented.")
     def test_stop_sequence(self, model: BaseChatModel) -> None:
         super().test_stop_sequence(model)
```