fix: fixed token counting issues in Converse API-based models, Claude 3 and Cohere models (#221)
adubovik authored Feb 3, 2025
1 parent a6f3189 commit 5ae0215
Showing 7 changed files with 80 additions and 79 deletions.
29 changes: 28 additions & 1 deletion aidial_adapter_bedrock/bedrock.py
@@ -128,7 +128,7 @@ class ResponseWithInvocationMetricsMixin(ABC, BaseModel):
alias="amazon-bedrock-invocationMetrics"
)

def usage_by_metrics(self) -> TokenUsage:
def usage_from_metrics(self) -> TokenUsage:
metrics = self.invocation_metrics
if metrics is None:
return TokenUsage()
@@ -137,3 +137,30 @@ def usage_by_metrics(self) -> TokenUsage:
prompt_tokens=metrics.inputTokenCount,
completion_tokens=metrics.outputTokenCount,
)


def prompt_tokens_from_headers(headers: Headers) -> int | None:
try:
return int(headers["x-amzn-bedrock-input-token-count"])
except Exception:
log.error(
"Failed to extract prompt token usage from the response headers"
)
return None


def completion_tokens_from_headers(headers: Headers) -> int | None:
try:
return int(headers["x-amzn-bedrock-output-token-count"])
except Exception:
log.error(
"Failed to extract completion token usage from the response headers"
)
return None


def usage_from_headers(response_headers: Headers) -> TokenUsage:
return TokenUsage(
prompt_tokens=prompt_tokens_from_headers(response_headers) or 0,
completion_tokens=completion_tokens_from_headers(response_headers) or 0,
)
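
A minimal usage sketch of the new header-based helpers (a plain dict stands in for the Headers mapping, and the token counts are made up):

from aidial_adapter_bedrock.bedrock import usage_from_headers

headers = {
    "x-amzn-bedrock-input-token-count": "42",
    "x-amzn-bedrock-output-token-count": "7",
}

usage = usage_from_headers(headers)
assert usage.prompt_tokens == 42
assert usage.completion_tokens == 7

# A missing or malformed header is logged and treated as 0 instead of raising.
assert usage_from_headers({}).prompt_tokens == 0
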
8 changes: 2 additions & 6 deletions aidial_adapter_bedrock/embedding/cohere/response.py
@@ -2,8 +2,7 @@

from pydantic import BaseModel

from aidial_adapter_bedrock.bedrock import Bedrock
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log
from aidial_adapter_bedrock.bedrock import Bedrock, prompt_tokens_from_headers


class CohereResponse(BaseModel):
@@ -22,8 +21,5 @@ async def call_embedding_model(
body, headers = await client.ainvoke_non_streaming(model, request)
response = CohereResponse.parse_obj(body)

input_tokens = int(headers.get("x-amzn-bedrock-input-token-count", "0"))
if input_tokens == 0:
log.warning("Can't extract input tokens from embeddings response")

input_tokens = prompt_tokens_from_headers(headers) or 0
return response.embeddings, input_tokens
16 changes: 13 additions & 3 deletions aidial_adapter_bedrock/llm/converse/adapter.py
@@ -1,3 +1,4 @@
from logging import DEBUG
from typing import Any, Awaitable, Callable, List, Tuple

from aidial_sdk.chat_completion import Message as DialMessage
@@ -35,9 +36,10 @@
DiscardedMessages,
truncate_prompt,
)
from aidial_adapter_bedrock.utils.json import remove_nones
from aidial_adapter_bedrock.utils.json import json_dumps_short, remove_nones
from aidial_adapter_bedrock.utils.list import omit_by_indices
from aidial_adapter_bedrock.utils.list_projection import ListProjection
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log

ConverseMessages = List[Tuple[ConverseMessage, Any]]

@@ -166,12 +168,20 @@ async def chat(

consumer.set_discarded_messages(discarded_messages)

request = converse_params.to_request()

if log.isEnabledFor(DEBUG):
msg = json_dumps_short(
{"deployment": self.deployment, "request": request}
)
log.debug(f"request: {msg}")

if self.is_stream(params):
await process_streaming(
params=params,
stream=(
await self.bedrock.aconverse_streaming(
self.deployment, **converse_params.to_request()
self.deployment, **request
)
),
consumer=consumer,
@@ -180,7 +190,7 @@ async def chat(
process_non_streaming(
params=params,
response=await self.bedrock.aconverse_non_streaming(
self.deployment, **converse_params.to_request()
self.deployment, **request
),
consumer=consumer,
)
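
The request dict is now built once, logged only when DEBUG is enabled, and reused for both the streaming and the non-streaming call. A rough sketch of that guarded-logging pattern (plain %s formatting stands in for the project's json_dumps_short helper, and the logger name is illustrative):

import logging

log = logging.getLogger("aidial_adapter_bedrock")


def log_converse_request(deployment: str, request: dict) -> None:
    # Build the potentially large debug payload only if DEBUG is enabled.
    if log.isEnabledFor(logging.DEBUG):
        log.debug("request: %s", {"deployment": deployment, "request": request})
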
20 changes: 19 additions & 1 deletion aidial_adapter_bedrock/llm/converse/output.py
@@ -1,4 +1,5 @@
import json
from logging import DEBUG
from typing import Any, AsyncIterator, Dict, assert_never

from aidial_sdk.chat_completion import FinishReason as DialFinishReason
@@ -14,6 +15,8 @@
)
from aidial_adapter_bedrock.llm.converse.types import ConverseStopReason
from aidial_adapter_bedrock.llm.tools.tools_config import ToolsMode
from aidial_adapter_bedrock.utils.json import json_dumps_short
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log


def to_dial_finish_reason(
@@ -34,6 +37,19 @@ async def process_streaming(
current_tool_use = None

async for event in stream:
if log.isEnabledFor(DEBUG):
log.debug(f"response event: {json_dumps_short(event)}")

if (metadata := event.get("metadata")) and (
usage := metadata.get("usage")
):
consumer.add_usage(
TokenUsage(
prompt_tokens=usage.get("inputTokens") or 0,
completion_tokens=usage.get("outputTokens") or 0,
)
)

if (content_block_start := event.get("contentBlockStart")) and (
tool_use := content_block_start.get("start", {}).get("toolUse")
):
@@ -57,7 +73,6 @@

elif event.get("contentBlockStop"):
if current_tool_use:

match params.tools_mode:
case ToolsMode.TOOLS:
consumer.create_function_tool_call(
@@ -99,6 +114,9 @@ def process_non_streaming(
response: Dict[str, Any],
consumer: Consumer,
) -> None:
if log.isEnabledFor(DEBUG):
log.debug(f"response: {json_dumps_short(response)}")

message = response["output"]["message"]
for content_block in message.get("content", []):
if "text" in content_block:
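
For reference, the Converse stream reports token usage in a metadata event; a minimal sketch of the extraction added to process_streaming above (the event contents here are hypothetical):

from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage

event = {"metadata": {"usage": {"inputTokens": 42, "outputTokens": 7}}}

if (metadata := event.get("metadata")) and (usage := metadata.get("usage")):
    token_usage = TokenUsage(
        prompt_tokens=usage.get("inputTokens") or 0,
        completion_tokens=usage.get("outputTokens") or 0,
    )
    # In the adapter this value is passed to consumer.add_usage(...).
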
44 changes: 7 additions & 37 deletions aidial_adapter_bedrock/llm/model/claude/v3/adapter.py
@@ -6,7 +6,6 @@
from anthropic import NOT_GIVEN, MessageStopEvent, NotGiven
from anthropic.lib.bedrock import AsyncAnthropicBedrock
from anthropic.lib.streaming import (
AsyncMessageStream,
ContentBlockStopEvent,
InputJsonEvent,
TextEvent,
@@ -17,12 +16,7 @@
MessageDeltaEvent,
)
from anthropic.types import MessageParam as ClaudeMessage
from anthropic.types import (
MessageStartEvent,
MessageStreamEvent,
TextBlock,
ToolUseBlock,
)
from anthropic.types import MessageStartEvent, TextBlock, ToolUseBlock
from anthropic.types.message_create_params import ToolChoice

from aidial_adapter_bedrock.adapter_deployments import AdapterDeployment
@@ -51,7 +45,6 @@
from aidial_adapter_bedrock.llm.errors import ValidationError
from aidial_adapter_bedrock.llm.message import parse_dial_message
from aidial_adapter_bedrock.llm.model.claude.v3.converters import (
ClaudeFinishReason,
to_claude_messages,
to_claude_tool_config,
to_dial_finish_reason,
@@ -76,19 +69,6 @@
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log


class UsageEventHandler(AsyncMessageStream):
prompt_tokens: int = 0
completion_tokens: int = 0
stop_reason: Optional[ClaudeFinishReason] = None

async def on_stream_event(self, event: MessageStreamEvent):
if isinstance(event, MessageStartEvent):
self.prompt_tokens = event.message.usage.input_tokens
elif isinstance(event, MessageDeltaEvent):
self.completion_tokens += event.usage.output_tokens
self.stop_reason = event.delta.stop_reason


# NOTE: it's not pydantic BaseModel, because
# ClaudeMessage.content is of Iterable type and
# pydantic automatically converts lists into
@@ -246,15 +226,11 @@ async def invoke_streaming(
request: ClaudeRequest,
discarded_messages: DiscardedMessages | None,
):

if log.isEnabledFor(DEBUG):
msg = json_dumps_short(
{
"deployment": self.deployment,
"request": request,
}
{"deployment": self.deployment, "request": request}
)
log.debug(f"Streaming request: {msg}")
log.debug(f"request: {msg}")

async with self.client.messages.stream(
messages=request.messages.raw_list,
@@ -266,9 +242,7 @@
stop_reason = None
async for event in stream:
if log.isEnabledFor(DEBUG):
log.debug(
f"claude response event: {json_dumps_short(event)}"
)
log.debug(f"response event: {json_dumps_short(event)}")

match event:
case MessageStartEvent(message=message):
@@ -289,7 +263,6 @@
case _:
assert_never(content_block)
case MessageStopEvent(message=message):
completion_tokens += message.usage.output_tokens
stop_reason = message.stop_reason
case (
InputJsonEvent()
@@ -323,12 +296,9 @@ async def invoke_non_streaming(

if log.isEnabledFor(DEBUG):
msg = json_dumps_short(
{
"deployment": self.deployment,
"request": request,
}
{"deployment": self.deployment, "request": request}
)
log.debug(f"Request: {msg}")
log.debug(f"request: {msg}")

message = await self.client.messages.create(
messages=request.messages.raw_list,
Expand All @@ -338,7 +308,7 @@ async def invoke_non_streaming(
)

if log.isEnabledFor(DEBUG):
log.debug(f"claude response message: {json_dumps_short(message)}")
log.debug(f"response: {json_dumps_short(message)}")

for content in message.content:
match content:
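
The removed UsageEventHandler and the dropped accumulation in the MessageStopEvent branch point at the same problem: completion tokens were being counted twice. A rough, simplified sketch of the intended accounting inside the streaming loop, assuming (as the removed handler did) that MessageStartEvent carries the prompt token count and MessageDeltaEvent carries output token increments; the event classes are the ones already imported at the top of this module:

prompt_tokens = 0
completion_tokens = 0
stop_reason = None

async for event in stream:
    match event:
        case MessageStartEvent(message=message):
            prompt_tokens = message.usage.input_tokens
        case MessageDeltaEvent():
            completion_tokens += event.usage.output_tokens
        case MessageStopEvent(message=message):
            # The stop event's message already reflects the final output token
            # total, so only the stop reason is taken here.
            stop_reason = message.stop_reason
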
38 changes: 8 additions & 30 deletions aidial_adapter_bedrock/llm/model/cohere.py
@@ -5,7 +5,9 @@

from aidial_adapter_bedrock.bedrock import (
Bedrock,
Headers,
ResponseWithInvocationMetricsMixin,
usage_from_headers,
)
from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
@@ -40,7 +42,6 @@
default_tools_emulator,
)
from aidial_adapter_bedrock.utils.list_projection import ListProjection
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log


class CohereResult(BaseModel):
@@ -75,27 +76,6 @@ def tokens(self) -> List[str]:
"""Includes prompt and completion tokens"""
return [lh.token for lh in self.generations[0].token_likelihoods]

def usage_by_tokens(self) -> TokenUsage:
special_tokens = 7
total_tokens = len(self.tokens) - special_tokens

# The structure for the response:
# ["<BOS_TOKEN>", "User", ":", *<prompt>, "\n", "Chat", "bot", ":", "<EOP_TOKEN>", *<completion>]
# prompt_tokens = len(<prompt>)
# completion_tokens = len(["<EOP_TOKEN>"] + <completion>)

separator = "<EOP_TOKEN>"
if separator in self.tokens:
prompt_tokens = self.tokens.index(separator) - special_tokens
else:
log.error(f"Separator '{separator}' not found in tokens")
prompt_tokens = total_tokens // 2

return TokenUsage(
prompt_tokens=prompt_tokens,
completion_tokens=total_tokens - prompt_tokens,
)


def convert_params(params: ModelParameters) -> Dict[str, Any]:
ret = {}
Expand Down Expand Up @@ -125,17 +105,15 @@ async def chunks_to_stream(
) -> AsyncIterator[str]:
async for chunk in chunks:
resp = CohereResponse.parse_obj(chunk)
usage.accumulate(resp.usage_by_metrics())
log.debug(f"tokens: {'|'.join(resp.tokens)!r}")
usage.accumulate(resp.usage_from_metrics())
yield resp.content()


async def response_to_stream(
response: dict, usage: TokenUsage
response_body: dict, response_headers: Headers, usage: TokenUsage
) -> AsyncIterator[str]:
resp = CohereResponse.parse_obj(response)
usage.accumulate(resp.usage_by_tokens())
log.debug(f"tokens: {'|'.join(resp.tokens)!r}")
resp = CohereResponse.parse_obj(response_body)
usage.accumulate(usage_from_headers(response_headers))
yield resp.content()


@@ -197,10 +175,10 @@ async def predict(
chunks = self.client.ainvoke_streaming(self.model, args)
stream = chunks_to_stream(chunks, usage)
else:
response, _headers = await self.client.ainvoke_non_streaming(
response, headers = await self.client.ainvoke_non_streaming(
self.model, args
)
stream = response_to_stream(response, usage)
stream = response_to_stream(response, headers, usage)

stream = post_process_completion_stream(params, cohere_emulator, stream)

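
With usage_by_tokens gone, streaming Cohere completions keep taking usage from the invocation metrics, while the non-streaming path now reads the response headers. A rough sketch of that non-streaming path (the helper name is made up for illustration):

from aidial_adapter_bedrock.bedrock import Bedrock, usage_from_headers
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage


async def non_streaming_usage(client: Bedrock, model: str, args: dict) -> TokenUsage:
    # The headers, previously discarded as `_headers`, now feed the usage counters.
    _body, headers = await client.ainvoke_non_streaming(model, args)
    usage = TokenUsage()
    usage.accumulate(usage_from_headers(headers))
    return usage
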
4 changes: 3 additions & 1 deletion tests/integration_tests/test_chat_completion.py
@@ -437,7 +437,9 @@ def test_case(
name="pinocchio in one token",
max_tokens=1,
messages=[user("tell me the full story of Pinocchio")],
expected=lambda s: len(s.content.split()) <= 1,
expected=lambda s: len(s.content.split()) <= 1
and s.usage is not None
and s.usage.completion_tokens == 1,
)

# ai21 models do not support more than one stop word