
Commit

Introduce total_usage variable to track cumulative token usage (instructor-ai#343)

Co-authored-by: Jason Liu <[email protected]>
lazyhope and jxnl authored Jan 14, 2024
1 parent 3f901bc commit 36871e6
Showing 4 changed files with 37 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/concepts/raw_response.md
@@ -25,7 +25,7 @@ print(user._raw_response)

!!! tip "Accessing tokens usage"

-    This is the recommended way to access the tokens usage, since it is a pydantic model you can use any of the pydantic model methods on it. For example, you can access the `total_tokens` by doing `user._raw_response.usage.total_tokens`.
+    This is the recommended way to access the tokens usage, since it is a pydantic model you can use any of the pydantic model methods on it. For example, you can access the `total_tokens` by doing `user._raw_response.usage.total_tokens`. Note that this also includes the tokens used during any previous unsuccessful attempts.

In the future, we may add additional hooks to the `raw_response` to make it easier to access the tokens usage.

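The tip above is the user-facing surface of this change: cumulative usage is read straight off the returned model. A minimal sketch of doing so, assuming a patched client and a hypothetical `UserDetail` response model (the model name, prompt, and retry count are illustrative, not from this commit):

```python
import instructor
from openai import OpenAI
from pydantic import BaseModel


class UserDetail(BaseModel):
    """Hypothetical response model for illustration."""

    name: str
    age: int


# patch() augments create() with response_model, retries, and _raw_response.
client = instructor.patch(OpenAI())

user = client.chat.completions.create(
    model="gpt-3.5-turbo",
    response_model=UserDetail,
    max_retries=2,  # failed attempts also contribute tokens after this commit
    messages=[{"role": "user", "content": "Extract: Jason is 25 years old"}],
)

# Cumulative across the successful call and any failed retry attempts.
print(user._raw_response.usage.total_tokens)
```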
19 changes: 18 additions & 1 deletion instructor/patch.py
@@ -12,6 +12,7 @@
    ChatCompletionMessage,
    ChatCompletionMessageParam,
)
+from openai.types.completion_usage import CompletionUsage
from pydantic import BaseModel, ValidationError

from instructor.dsl.multitask import MultiTask, MultiTaskBase
@@ -31,7 +32,7 @@
    If `stream=True` is specified, the response will be parsed using the `from_stream_response` method of the response model, if available, otherwise it will be parsed using the `from_response` method.
-    If need to obtain the raw response from OpenAI's API, you can access it using the `_raw_response` attribute of the response model.
+    If need to obtain the raw response from OpenAI's API, you can access it using the `_raw_response` attribute of the response model. The `_raw_response.usage` attribute is modified to reflect the token usage from the last successful response as well as from any previous unsuccessful attempts.
    Parameters:
        response_model (Union[Type[BaseModel], Type[OpenAISchema]]): The response model to use for parsing the response from OpenAI's API, if available (default: None)
@@ -225,10 +226,18 @@ async def retry_async(
    mode: Mode = Mode.FUNCTIONS,
):
    retries = 0
+    total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
    while retries <= max_retries:
        try:
            response: ChatCompletion = await func(*args, **kwargs)
            stream = kwargs.get("stream", False)
+            if isinstance(response, ChatCompletion) and response.usage is not None:
+                total_usage.completion_tokens += response.usage.completion_tokens
+                total_usage.prompt_tokens += response.usage.prompt_tokens
+                total_usage.total_tokens += response.usage.total_tokens
+                response.usage = (
+                    total_usage  # Replace each response usage with the total usage
+                )
            return await process_response_async(
                response,
                response_model=response_model,
@@ -279,11 +288,19 @@ def retry_sync(
    mode: Mode = Mode.FUNCTIONS,
):
    retries = 0
+    total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
    while retries <= max_retries:
        # Excepts ValidationError, and JSONDecodeError
        try:
            response = func(*args, **kwargs)
            stream = kwargs.get("stream", False)
+            if isinstance(response, ChatCompletion) and response.usage is not None:
+                total_usage.completion_tokens += response.usage.completion_tokens
+                total_usage.prompt_tokens += response.usage.prompt_tokens
+                total_usage.total_tokens += response.usage.total_tokens
+                response.usage = (
+                    total_usage  # Replace each response usage with the total usage
+                )
            return process_response(
                response,
                response_model=response_model,
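Both retry loops share the same accumulation pattern: a single `CompletionUsage` accumulator is threaded through the attempts, and the returned response's `usage` is overwritten with the running total. A standalone sketch of just that step, with made-up token counts:

```python
from openai.types.completion_usage import CompletionUsage

total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)

# Pretend two attempts were made: the first failed validation, the second succeeded.
attempts = [
    CompletionUsage(completion_tokens=50, prompt_tokens=200, total_tokens=250),
    CompletionUsage(completion_tokens=60, prompt_tokens=210, total_tokens=270),
]

for usage in attempts:
    total_usage.completion_tokens += usage.completion_tokens
    total_usage.prompt_tokens += usage.prompt_tokens
    total_usage.total_tokens += usage.total_tokens

print(total_usage.total_tokens)  # 520: both attempts are counted
```

Because `response.usage` is reassigned to this accumulator, the response handed back to the caller reports the cumulative figure, which is exactly what the updated docstring promises.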
9 changes: 9 additions & 0 deletions tests/openai/test_patch.py
@@ -1,5 +1,6 @@
from itertools import product
from pydantic import BaseModel, field_validator
+from openai.types.chat import ChatCompletion
import pytest
import instructor

@@ -29,6 +30,8 @@ def test_runmodel(model, mode, client):
        model, "_raw_response"
    ), "The raw response should be available from OpenAI"

+    ChatCompletion(**model._raw_response.model_dump())
+

@pytest.mark.parametrize("model, mode", product(models, modes))
@pytest.mark.asyncio
@@ -49,6 +52,8 @@ async def test_runmodel_async(model, mode, aclient):
        model, "_raw_response"
    ), "The raw response should be available from OpenAI"

+    ChatCompletion(**model._raw_response.model_dump())
+

class UserExtractValidated(BaseModel):
    name: str
@@ -81,6 +86,8 @@ def test_runmodel_validator(model, mode, client):
        model, "_raw_response"
    ), "The raw response should be available from OpenAI"

+    ChatCompletion(**model._raw_response.model_dump())
+


@pytest.mark.parametrize("model, mode", product(models, modes))
@pytest.mark.asyncio
@@ -99,3 +106,5 @@ async def test_runmodel_async_validator(model, mode, aclient):
    assert hasattr(
        model, "_raw_response"
    ), "The raw response should be available from OpenAI"
+
+    ChatCompletion(**model._raw_response.model_dump())
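Each test now round-trips `_raw_response` through `ChatCompletion` to confirm that reassigning `usage` did not break the schema. The same check can be exercised offline with a hand-built completion (all field values below are made up):

```python
from openai.types.chat import ChatCompletion

# Minimal hand-built completion; values are illustrative only.
raw = ChatCompletion(
    id="chatcmpl-test",
    object="chat.completion",
    created=0,
    model="gpt-3.5-turbo",
    choices=[
        {
            "index": 0,
            "finish_reason": "stop",
            "message": {"role": "assistant", "content": "hello"},
        }
    ],
    usage={"completion_tokens": 1, "prompt_tokens": 2, "total_tokens": 3},
)

# The round-trip the tests perform: dump to a dict, then re-validate.
ChatCompletion(**raw.model_dump())
```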
13 changes: 9 additions & 4 deletions tests/test_function_calls.py
@@ -3,6 +3,7 @@
from openai.resources.chat.completions import ChatCompletion

from instructor import openai_schema, OpenAISchema
+import instructor
from instructor.exceptions import IncompleteOutputException

@@ -92,11 +93,13 @@ class Dummy(OpenAISchema):
)
def test_incomplete_output_exception(test_model, mock_completion):
    with pytest.raises(IncompleteOutputException):
-        test_model.from_response(mock_completion)
+        test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS)


def test_complete_output_no_exception(test_model, mock_completion):
-    test_model_instance = test_model.from_response(mock_completion)
+    test_model_instance = test_model.from_response(
+        mock_completion, mode=instructor.Mode.FUNCTIONS
+    )
    assert test_model_instance.data == "complete data"

@@ -108,10 +111,12 @@ def test_complete_output_no_exception(test_model, mock_completion):
)
async def test_incomplete_output_exception_raise(test_model, mock_completion):
    with pytest.raises(IncompleteOutputException):
-        await test_model.from_response(mock_completion)
+        await test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS)


@pytest.mark.asyncio
async def test_async_complete_output_no_exception(test_model, mock_completion):
-    test_model_instance = await test_model.from_response_async(mock_completion)
+    test_model_instance = await test_model.from_response_async(
+        mock_completion, mode=instructor.Mode.FUNCTIONS
+    )
    assert test_model_instance.data == "complete data"
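These tests now pass `mode` explicitly instead of relying on a default. A sketch of the same call outside pytest, with a hypothetical schema and a hand-built function-call completion (field values are illustrative, and the parsing path assumed here is the `Mode.FUNCTIONS` behavior the tests exercise):

```python
import instructor
from instructor import OpenAISchema
from openai.types.chat import ChatCompletion


class Dummy(OpenAISchema):
    """Hypothetical schema mirroring the test fixture."""

    data: str


# Hand-built completion whose function_call carries the arguments.
completion = ChatCompletion(
    id="chatcmpl-test",
    object="chat.completion",
    created=0,
    model="gpt-3.5-turbo",
    choices=[
        {
            "index": 0,
            "finish_reason": "stop",  # "length" would raise IncompleteOutputException
            "message": {
                "role": "assistant",
                "content": None,
                "function_call": {
                    "name": "Dummy",
                    "arguments": '{"data": "complete data"}',
                },
            },
        }
    ],
)

dummy = Dummy.from_response(completion, mode=instructor.Mode.FUNCTIONS)
print(dummy.data)  # "complete data"
```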
