ChatBedrock: add usage metadata #85

Merged 6 commits on Jun 25, 2024
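
This PR threads token-usage information from Bedrock responses into usage_metadata on the returned AIMessage / AIMessageChunk. Below is a minimal sketch of what that enables from the caller's side, assuming AWS credentials and Bedrock model access are already configured; the model ID and prompt are taken from the tests in this PR, and the printed numbers are illustrative only.

from langchain_aws import ChatBedrock

chat = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    model_kwargs={"temperature": 0},
)

# Non-streaming: the AIMessage returned by invoke() now carries usage_metadata.
msg = chat.invoke("I'm Pickle Rick")
print(msg.usage_metadata)
# e.g. {'input_tokens': 13, 'output_tokens': 42, 'total_tokens': 55}

# Streaming: summing the chunks yields an AIMessageChunk whose usage_metadata
# reflects the invocation metrics emitted with the final stream event.
full = None
for chunk in chat.stream("I'm Pickle Rick"):
    full = chunk if full is None else full + chunk
print(full.usage_metadata)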
28 changes: 25 additions & 3 deletions libs/aws/langchain_aws/chat_models/bedrock.py
@@ -31,6 +31,7 @@
     HumanMessage,
     SystemMessage,
 )
+from langchain_core.messages.ai import UsageMetadata
 from langchain_core.messages.tool import ToolCall, ToolMessage
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import BaseModel, Extra
@@ -484,9 +485,15 @@ def _stream(
             **kwargs,
         ):
             delta = chunk.text
+            if generation_info := chunk.generation_info:
+                usage_metadata = generation_info.pop("usage_metadata", None)
+            else:
+                usage_metadata = None
             yield ChatGenerationChunk(
                 message=AIMessageChunk(
-                    content=delta, response_metadata=chunk.generation_info
+                    content=delta,
+                    response_metadata=chunk.generation_info,
+                    usage_metadata=usage_metadata,
                 )
                 if chunk.generation_info is not None
                 else AIMessageChunk(content=delta)
@@ -550,16 +557,31 @@ def _generate(
                 messages=formatted_messages,
                 **params,
             )
+
+        # usage metadata
+        if usage := llm_output.get("usage"):
+            input_tokens = usage.get("prompt_tokens", 0)
+            output_tokens = usage.get("completion_tokens", 0)
+            usage_metadata = UsageMetadata(
+                input_tokens=input_tokens,
+                output_tokens=output_tokens,
+                total_tokens=usage.get("total_tokens", input_tokens + output_tokens),
+            )
+        else:
+            usage_metadata = None
         llm_output["model_id"] = self.model_id
         if len(tool_calls) > 0:
             msg = AIMessage(
                 content=completion,
                 additional_kwargs=llm_output,
                 tool_calls=cast(List[ToolCall], tool_calls),
+                usage_metadata=usage_metadata,
             )
         else:
-            msg = AIMessage(content=completion, additional_kwargs=llm_output)
+            msg = AIMessage(
+                content=completion,
+                additional_kwargs=llm_output,
+                usage_metadata=usage_metadata,
+            )
         return ChatResult(
             generations=[
                 ChatGeneration(
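
The _stream change above pops usage_metadata out of the chunk's generation_info and attaches it to the AIMessageChunk. This matters because AIMessageChunk addition in langchain_core merges usage_metadata across chunks, so a caller that accumulates the stream ends up with the totals. A standalone illustration of that merging behavior — not code from this PR; the token counts are made up, and the exact merge semantics are those of recent langchain_core releases:

from langchain_core.messages import AIMessageChunk

a = AIMessageChunk(
    content="Hel",
    usage_metadata={"input_tokens": 13, "output_tokens": 1, "total_tokens": 14},
)
b = AIMessageChunk(
    content="lo",
    usage_metadata={"input_tokens": 0, "output_tokens": 1, "total_tokens": 1},
)

merged = a + b
print(merged.content)         # Hello
print(merged.usage_metadata)  # {'input_tokens': 13, 'output_tokens': 2, 'total_tokens': 15}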
15 changes: 15 additions & 0 deletions libs/aws/langchain_aws/llms/bedrock.py
@@ -180,6 +180,19 @@ def _combine_generation_info_for_llm_result(
     return {"usage": total_usage_info, "stop_reason": stop_reason}


+def _get_invocation_metrics_chunk(chunk: Dict[str, Any]) -> GenerationChunk:
+    generation_info = {}
+    if metrics := chunk.get("amazon-bedrock-invocationMetrics"):
+        input_tokens = metrics.get("inputTokenCount", 0)
+        output_tokens = metrics.get("outputTokenCount", 0)
+        generation_info["usage_metadata"] = {
+            "input_tokens": input_tokens,
+            "output_tokens": output_tokens,
+            "total_tokens": input_tokens + output_tokens,
+        }
+    return GenerationChunk(text="", generation_info=generation_info)
+
+
 def extract_tool_calls(content: List[dict]) -> List[ToolCall]:
     tool_calls = []
     for block in content:
@@ -330,9 +343,11 @@ def prepare_output_stream(
                 provider == "mistral"
                 and chunk_obj.get(output_key, [{}])[0].get("stop_reason", "") == "stop"
             ):
+                yield _get_invocation_metrics_chunk(chunk_obj)
                 return

            elif messages_api and (chunk_obj.get("type") == "message_stop"):
+                yield _get_invocation_metrics_chunk(chunk_obj)
                 return

             generation_chunk = _stream_response_to_generation_chunk(
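
For streaming, the new _get_invocation_metrics_chunk helper reads the amazon-bedrock-invocationMetrics block that Bedrock attaches to the final stream event and wraps the token counts in an empty-text GenerationChunk, which ChatBedrock._stream then pops onto the AIMessageChunk (see the first file above). A rough sketch of that data flow; the event payload below is hypothetical and only uses the field names visible in the diff:

# Hypothetical final stream event; the values are made up.
final_event = {
    "type": "message_stop",
    "amazon-bedrock-invocationMetrics": {
        "inputTokenCount": 13,
        "outputTokenCount": 42,
    },
}

# What the helper derives from such an event:
metrics = final_event.get("amazon-bedrock-invocationMetrics", {})
input_tokens = metrics.get("inputTokenCount", 0)
output_tokens = metrics.get("outputTokenCount", 0)
usage_metadata = {
    "input_tokens": input_tokens,
    "output_tokens": output_tokens,
    "total_tokens": input_tokens + output_tokens,
}
# It is yielded as GenerationChunk(text="", generation_info={"usage_metadata": usage_metadata}),
# so the streamed text is unchanged and only the metadata is appended.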
46 changes: 39 additions & 7 deletions libs/aws/tests/integration_tests/chat_models/test_bedrock.py
@@ -1,6 +1,6 @@
 """Test Bedrock chat model."""
 import json
-from typing import Any, cast
+from typing import Any

 import pytest
 from langchain_core.messages import (
@@ -124,22 +124,54 @@ def on_llm_end(


 @pytest.mark.scheduled
-def test_bedrock_streaming(chat: ChatBedrock) -> None:
-    """Test streaming tokens from OpenAI."""
-
+@pytest.mark.parametrize(
+    "model_id",
+    [
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "mistral.mistral-7b-instruct-v0:2",
+    ],
+)
+def test_bedrock_streaming(model_id: str) -> None:
+    chat = ChatBedrock(
+        model_id=model_id,
+        model_kwargs={"temperature": 0},
+    )  # type: ignore[call-arg]
     full = None
     for token in chat.stream("I'm Pickle Rick"):
         full = token if full is None else full + token  # type: ignore[operator]
         assert isinstance(token.content, str)
-    assert isinstance(cast(AIMessageChunk, full).content, str)
+    assert isinstance(full, AIMessageChunk)
+    assert isinstance(full.content, str)
+    assert full.usage_metadata is not None
+    assert full.usage_metadata["input_tokens"] > 0
+    assert full.usage_metadata["output_tokens"] > 0
+    assert full.usage_metadata["total_tokens"] > 0


 @pytest.mark.scheduled
-async def test_bedrock_astream(chat: ChatBedrock) -> None:
+@pytest.mark.parametrize(
+    "model_id",
+    [
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "mistral.mistral-7b-instruct-v0:2",
+    ],
+)
+async def test_bedrock_astream(model_id: str) -> None:
     """Test streaming tokens from OpenAI."""

+    chat = ChatBedrock(
+        model_id=model_id,
+        model_kwargs={"temperature": 0},
+    )  # type: ignore[call-arg]
+    full = None
     async for token in chat.astream("I'm Pickle Rick"):
+        full = token if full is None else full + token  # type: ignore[operator]
         assert isinstance(token.content, str)
+    assert isinstance(full, AIMessageChunk)
+    assert isinstance(full.content, str)
+    assert full.usage_metadata is not None
+    assert full.usage_metadata["input_tokens"] > 0
+    assert full.usage_metadata["output_tokens"] > 0
+    assert full.usage_metadata["total_tokens"] > 0


 @pytest.mark.scheduled
4 changes: 0 additions & 4 deletions libs/aws/tests/integration_tests/chat_models/test_standard.py
@@ -22,10 +22,6 @@ def chat_model_params(self) -> dict:
     def standard_chat_model_params(self) -> dict:
         return {}

-    @pytest.mark.xfail(reason="Not implemented.")
-    def test_usage_metadata(self, model: BaseChatModel) -> None:
-        super().test_usage_metadata(model)
-
     @pytest.mark.xfail(reason="Not implemented.")
     def test_stop_sequence(self, model: BaseChatModel) -> None:
         super().test_stop_sequence(model)