From 36871e6917decec3cfa0749b6d7f63b6351a3785 Mon Sep 17 00:00:00 2001
From: lazyhope <78585060+lazyhope@users.noreply.github.com>
Date: Sun, 14 Jan 2024 11:57:28 +0800
Subject: [PATCH] Introduce `total_usage` variable to track cumulative token usage (#343)

Co-authored-by: Jason Liu
---
 docs/concepts/raw_response.md |  2 +-
 instructor/patch.py           | 19 ++++++++++++++++++-
 tests/openai/test_patch.py    |  9 +++++++++
 tests/test_function_calls.py  | 13 +++++++++----
 4 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/docs/concepts/raw_response.md b/docs/concepts/raw_response.md
index bc535281c..3de76f51f 100644
--- a/docs/concepts/raw_response.md
+++ b/docs/concepts/raw_response.md
@@ -25,7 +25,7 @@ print(user._raw_response)
 
 !!! tip "Accessing tokens usage"
 
-    This is the recommended way to access the tokens usage, since it is a pydantic model you can use any of the pydantic model methods on it. For example, you can access the `total_tokens` by doing `user._raw_response.usage.total_tokens`.
+    This is the recommended way to access the tokens usage; since it is a pydantic model, you can use any of the pydantic model methods on it. For example, you can access the `total_tokens` by doing `user._raw_response.usage.total_tokens`. Note that this also includes the tokens used during any previous unsuccessful attempts.
 
     In the future, we may add additional hooks to the `raw_response` to make it easier to access the tokens usage.
 
diff --git a/instructor/patch.py b/instructor/patch.py
index cb8e6a11f..49ba2bae6 100644
--- a/instructor/patch.py
+++ b/instructor/patch.py
@@ -12,6 +12,7 @@
     ChatCompletionMessage,
     ChatCompletionMessageParam,
 )
+from openai.types.completion_usage import CompletionUsage
 from pydantic import BaseModel, ValidationError
 
 from instructor.dsl.multitask import MultiTask, MultiTaskBase
@@ -31,7 +32,7 @@
 
 If `stream=True` is specified, the response will be parsed using the `from_stream_response` method of the response model, if available, otherwise it will be parsed using the `from_response` method.
 
-If need to obtain the raw response from OpenAI's API, you can access it using the `_raw_response` attribute of the response model.
+If you need to obtain the raw response from OpenAI's API, you can access it using the `_raw_response` attribute of the response model. The `_raw_response.usage` attribute is modified to reflect the token usage from the last successful response as well as from any previous unsuccessful attempts.
 
 Parameters:
     response_model (Union[Type[BaseModel], Type[OpenAISchema]]): The response model to use for parsing the response from OpenAI's API, if available (default: None)
@@ -225,10 +226,18 @@ async def retry_async(
     mode: Mode = Mode.FUNCTIONS,
 ):
     retries = 0
+    total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
     while retries <= max_retries:
         try:
             response: ChatCompletion = await func(*args, **kwargs)
             stream = kwargs.get("stream", False)
+            if isinstance(response, ChatCompletion) and response.usage is not None:
+                total_usage.completion_tokens += response.usage.completion_tokens
+                total_usage.prompt_tokens += response.usage.prompt_tokens
+                total_usage.total_tokens += response.usage.total_tokens
+                response.usage = (
+                    total_usage  # Replace each response usage with the total usage
+                )
             return await process_response_async(
                 response,
                 response_model=response_model,
@@ -279,11 +288,19 @@ def retry_sync(
     mode: Mode = Mode.FUNCTIONS,
 ):
     retries = 0
+    total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
     while retries <= max_retries:
         # Excepts ValidationError, and JSONDecodeError
         try:
             response = func(*args, **kwargs)
             stream = kwargs.get("stream", False)
+            if isinstance(response, ChatCompletion) and response.usage is not None:
+                total_usage.completion_tokens += response.usage.completion_tokens
+                total_usage.prompt_tokens += response.usage.prompt_tokens
+                total_usage.total_tokens += response.usage.total_tokens
+                response.usage = (
+                    total_usage  # Replace each response usage with the total usage
+                )
             return process_response(
                 response,
                 response_model=response_model,
diff --git a/tests/openai/test_patch.py b/tests/openai/test_patch.py
index e56bcdc6c..a486936d4 100644
--- a/tests/openai/test_patch.py
+++ b/tests/openai/test_patch.py
@@ -1,5 +1,6 @@
 from itertools import product
 from pydantic import BaseModel, field_validator
+from openai.types.chat import ChatCompletion
 import pytest
 
 import instructor
@@ -29,6 +30,8 @@ def test_runmodel(model, mode, client):
         model, "_raw_response"
     ), "The raw response should be available from OpenAI"
 
+    ChatCompletion(**model._raw_response.model_dump())
+
 
 @pytest.mark.parametrize("model, mode", product(models, modes))
 @pytest.mark.asyncio
@@ -49,6 +52,8 @@ async def test_runmodel_async(model, mode, aclient):
         model, "_raw_response"
    ), "The raw response should be available from OpenAI"
 
+    ChatCompletion(**model._raw_response.model_dump())
+
 
 class UserExtractValidated(BaseModel):
     name: str
@@ -81,6 +86,8 @@ def test_runmodel_validator(model, mode, client):
         model, "_raw_response"
     ), "The raw response should be available from OpenAI"
 
+    ChatCompletion(**model._raw_response.model_dump())
+
 
 @pytest.mark.parametrize("model, mode", product(models, modes))
 @pytest.mark.asyncio
@@ -99,3 +106,5 @@ async def test_runmodel_async_validator(model, mode, aclient):
     assert hasattr(
         model, "_raw_response"
     ), "The raw response should be available from OpenAI"
+
+    ChatCompletion(**model._raw_response.model_dump())
diff --git a/tests/test_function_calls.py b/tests/test_function_calls.py
index bf0315a20..cd03126d3 100644
--- a/tests/test_function_calls.py
+++ b/tests/test_function_calls.py
@@ -3,6 +3,7 @@
 from openai.resources.chat.completions import ChatCompletion
 
 from instructor import openai_schema, OpenAISchema
+import instructor
 from instructor.exceptions import IncompleteOutputException
 
 
@@ -92,11 +93,13 @@ class Dummy(OpenAISchema):
 )
 def test_incomplete_output_exception(test_model, mock_completion):
     with pytest.raises(IncompleteOutputException):
-        test_model.from_response(mock_completion)
+        test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS)
 
 
 def test_complete_output_no_exception(test_model, mock_completion):
-    test_model_instance = test_model.from_response(mock_completion)
+    test_model_instance = test_model.from_response(
+        mock_completion, mode=instructor.Mode.FUNCTIONS
+    )
     assert test_model_instance.data == "complete data"
 
 
@@ -108,10 +111,12 @@ def test_complete_output_no_exception(test_model, mock_completion):
 )
 async def test_incomplete_output_exception_raise(test_model, mock_completion):
     with pytest.raises(IncompleteOutputException):
-        await test_model.from_response(mock_completion)
+        await test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS)
 
 
 @pytest.mark.asyncio
 async def test_async_complete_output_no_exception(test_model, mock_completion):
-    test_model_instance = await test_model.from_response_async(mock_completion)
+    test_model_instance = await test_model.from_response_async(
+        mock_completion, mode=instructor.Mode.FUNCTIONS
+    )
    assert test_model_instance.data == "complete data"
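
For context, the sketch below shows how the cumulative usage introduced by this patch surfaces to callers. It is not part of the commit: the `UserExtract` model, prompt, and model name are illustrative only; the `instructor.patch` / `response_model` / `max_retries` calls follow the library's documented usage at the time of this change.

```python
import instructor
from openai import OpenAI
from pydantic import BaseModel

# instructor.patch adds the response_model and max_retries arguments to create()
client = instructor.patch(OpenAI())


class UserExtract(BaseModel):
    name: str
    age: int


user = client.chat.completions.create(
    model="gpt-3.5-turbo",
    response_model=UserExtract,
    max_retries=2,  # validation failures trigger retries, which also consume tokens
    messages=[{"role": "user", "content": "Extract: Jason is 25 years old"}],
)

# With this patch, usage covers the last successful response plus any
# previous unsuccessful attempts, so this is the cumulative total.
print(user._raw_response.usage.total_tokens)
```

Before this change, `_raw_response.usage` reflected only the final attempt, so token counts were understated whenever validation retries occurred.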