diff --git a/instructor/utils.py b/instructor/utils.py index 088cebbaf..29effb9df 100644 --- a/instructor/utils.py +++ b/instructor/utils.py @@ -167,20 +167,36 @@ def update_total_usage( if isinstance(response_usage, AnthropicUsage) and isinstance( total_usage, AnthropicUsage ): - if not total_usage.cache_creation_input_tokens: - total_usage.cache_creation_input_tokens = 0 + # update input_tokens / output_tokens + if hasattr(total_usage, "input_tokens") and hasattr( + response_usage, "input_tokens" + ): + total_usage.input_tokens += response_usage.input_tokens or 0 + if hasattr(total_usage, "output_tokens") and hasattr( + response_usage, "output_tokens" + ): + total_usage.output_tokens += response_usage.output_tokens or 0 + + # Update cache_creation_input_tokens if both have that field + if hasattr(total_usage, "cache_creation_input_tokens") and hasattr( + response_usage, "cache_creation_input_tokens" + ): + if not total_usage.cache_creation_input_tokens: + total_usage.cache_creation_input_tokens = 0 + total_usage.cache_creation_input_tokens += ( + response_usage.cache_creation_input_tokens or 0 + ) + + # Update cache_read_input_tokens if both have that field + if hasattr(total_usage, "cache_read_input_tokens") and hasattr( + response_usage, "cache_read_input_tokens" + ): + if not total_usage.cache_read_input_tokens: + total_usage.cache_read_input_tokens = 0 + total_usage.cache_read_input_tokens += ( + response_usage.cache_read_input_tokens or 0 + ) - if not total_usage.cache_read_input_tokens: - total_usage.cache_read_input_tokens = 0 - - total_usage.input_tokens += response_usage.input_tokens or 0 - total_usage.output_tokens += response_usage.output_tokens or 0 - total_usage.cache_creation_input_tokens += ( - response_usage.cache_creation_input_tokens or 0 - ) - total_usage.cache_read_input_tokens += ( - response_usage.cache_read_input_tokens or 0 - ) response.usage = total_usage return response except ImportError: diff --git a/tests/llm/test_new_client.py b/tests/llm/test_new_client.py index 5acf4e6a7..d20eb64d4 100644 --- a/tests/llm/test_new_client.py +++ b/tests/llm/test_new_client.py @@ -180,7 +180,6 @@ def test_client_anthropic_response(): assert user.age == 10 -@pytest.mark.skip(reason="Skip for now") def test_client_anthropic_bedrock_response(): client = anthropic.AnthropicBedrock( aws_access_key=os.getenv("AWS_ACCESS_KEY_ID"), @@ -222,7 +221,6 @@ async def test_async_client_anthropic_response(): assert user.age == 10 -@pytest.mark.skip(reason="Skip for now") @pytest.mark.asyncio async def test_async_client_anthropic_bedrock_response(): client = anthropic.AsyncAnthropicBedrock(