Add ability in Token Counter to retrieve OpenAI cached_tokens #17372 #17380

Open · wants to merge 1 commit into base: main
32 changes: 24 additions & 8 deletions llama-index-core/llama_index/core/callbacks/token_counting.py
@@ -30,6 +30,7 @@ class TokenCountingEvent:
completion_token_count: int
prompt_token_count: int
total_token_count: int = 0
cached_tokens: int = 0
Collaborator:
This works for OpenAI, but for LLMs that charge based on cache_read and cache_write (like Anthropic), this approach doesn't translate well 🤔
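
For illustration only (not part of this PR): one possible provider-agnostic shape would track cache reads and cache writes separately. The sketch below assumes OpenAI reports cached tokens under `prompt_tokens_details.cached_tokens` and Anthropic reports `cache_read_input_tokens` / `cache_creation_input_tokens` in its usage payload; the field names are assumptions and may differ across SDK versions.

```python
# Illustrative sketch only — not the PR's implementation. Provider payload
# field names are assumptions and may differ across SDK versions.
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class CacheTokenUsage:
    cache_read_tokens: int = 0   # tokens served from a prompt cache
    cache_write_tokens: int = 0  # tokens written into a prompt cache


def get_cache_usage(usage: Dict[str, Any]) -> CacheTokenUsage:
    """Best-effort extraction of cache token counts from a raw usage dict."""
    # OpenAI-style: cached tokens are nested under prompt_tokens_details.
    details = usage.get("prompt_tokens_details") or {}
    if isinstance(details, dict) and "cached_tokens" in details:
        return CacheTokenUsage(cache_read_tokens=details["cached_tokens"])

    # Anthropic-style: cache reads and cache writes are separate usage fields.
    return CacheTokenUsage(
        cache_read_tokens=usage.get("cache_read_input_tokens", 0),
        cache_write_tokens=usage.get("cache_creation_input_tokens", 0),
    )
```

A TokenCountingEvent could then carry both fields, keeping per-provider cost math outside the counter.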

event_id: str = ""

def __post_init__(self) -> None:
@@ -56,7 +57,8 @@ def get_tokens_from_response(

possible_input_keys = ("prompt_tokens", "input_tokens")
possible_output_keys = ("completion_tokens", "output_tokens")

openai_prompt_tokens_details_key = "prompt_tokens_details"

prompt_tokens = 0
for input_key in possible_input_keys:
if input_key in usage:
@@ -68,8 +70,12 @@
if output_key in usage:
completion_tokens = usage[output_key]
break

return prompt_tokens, completion_tokens

cached_tokens = 0
if openai_prompt_tokens_details_key in usage:
cached_tokens = usage[openai_prompt_tokens_details_key]["cached_tokens"]

return prompt_tokens, completion_tokens, cached_tokens


def get_llm_token_counts(
@@ -83,9 +89,9 @@ def get_llm_token_counts(

if completion:
# get from raw or additional_kwargs
prompt_tokens, completion_tokens = get_tokens_from_response(completion)
prompt_tokens, completion_tokens, cached_tokens = get_tokens_from_response(completion)
else:
prompt_tokens, completion_tokens = 0, 0
prompt_tokens, completion_tokens, cached_tokens = 0, 0, 0

if prompt_tokens == 0:
prompt_tokens = token_counter.get_string_tokens(str(prompt))
@@ -99,6 +105,7 @@
prompt_token_count=prompt_tokens,
completion=str(completion),
completion_token_count=completion_tokens,
cached_tokens=cached_tokens,
)

elif EventPayload.MESSAGES in payload:
@@ -109,9 +116,9 @@
response_str = str(response)

if response:
prompt_tokens, completion_tokens = get_tokens_from_response(response)
prompt_tokens, completion_tokens, cached_tokens = get_tokens_from_response(response)
else:
prompt_tokens, completion_tokens = 0, 0
prompt_tokens, completion_tokens, cached_tokens = 0, 0, 0

if prompt_tokens == 0:
prompt_tokens = token_counter.estimate_tokens_in_messages(messages)
@@ -125,6 +132,7 @@
prompt_token_count=prompt_tokens,
completion=response_str,
completion_token_count=completion_tokens,
cached_tokens=cached_tokens,
)
else:
return TokenCountingEvent(
@@ -133,6 +141,7 @@
prompt_token_count=0,
completion="",
completion_token_count=0,
cached_tokens=0,
)


@@ -214,7 +223,9 @@ def on_event_end(
"LLM Prompt Token Usage: "
f"{self.llm_token_counts[-1].prompt_token_count}\n"
"LLM Completion Token Usage: "
f"{self.llm_token_counts[-1].completion_token_count}",
f"{self.llm_token_counts[-1].completion_token_count}"
"LLM Cached Tokens: "
f"{self.llm_token_counts[-1].cached_tokens}",
)
elif (
event_type == CBEventType.EMBEDDING
@@ -251,6 +262,11 @@ def prompt_llm_token_count(self) -> int:
def completion_llm_token_count(self) -> int:
"""Get the current total LLM completion token count."""
return sum([x.completion_token_count for x in self.llm_token_counts])

@property
def total_cached_token_count(self) -> int:
"""Get the current total cached token count."""
return sum([x.cached_tokens for x in self.llm_token_counts])

@property
def total_embedding_token_count(self) -> int:
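
For reference, a minimal usage sketch of the new counter, assuming this branch is installed. The handler wiring follows the existing TokenCountingHandler callback pattern from the file above; the model and tokenizer choices are illustrative, not prescribed by the PR.

```python
# Minimal usage sketch — assumes this PR's branch is installed. The model and
# tokenizer choices are illustrative, not prescribed by the PR.
import tiktoken

from llama_index.core import Settings
from llama_index.core.callbacks import CallbackManager, TokenCountingHandler
from llama_index.llms.openai import OpenAI

token_counter = TokenCountingHandler(
    tokenizer=tiktoken.encoding_for_model("gpt-4o-mini").encode
)
Settings.callback_manager = CallbackManager([token_counter])
Settings.llm = OpenAI(model="gpt-4o-mini")

response = Settings.llm.complete("Summarize prompt caching in one sentence.")

print("prompt tokens:    ", token_counter.prompt_llm_token_count)
print("completion tokens:", token_counter.completion_llm_token_count)
# New in this PR: cached prompt tokens reported by OpenAI, if any.
print("cached tokens:    ", token_counter.total_cached_token_count)
```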