Add Support for Prompt Caching in SpiceMessages Class #106

Closed

mentatai bot opened this issue Aug 25, 2024 · 2 comments
mentatai bot commented Aug 25, 2024

Summary

Add support for Anthropic's prompt caching feature to the SpiceMessages class in the spice library. This will enable faster and more cost-efficient API calls by reusing cached prompt prefixes. Additionally, track cache performance metrics to verify cache hits and the number of input tokens cached.
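
For context, Anthropic marks a cacheable prompt prefix with a cache_control block of type "ephemeral" (at the time this issue was filed, the feature was in beta behind an anthropic-beta request header). A rough sketch of the raw API payload, independent of spice and with illustrative values:

# Sketch of a raw Anthropic Messages API request using prompt caching.
# The shape follows Anthropic's documented format; values are illustrative.
payload = {
    "model": "claude-3-5-sonnet-20240620",
    "max_tokens": 1024,
    "system": [
        {
            "type": "text",
            "text": "<large, stable instructions reused across calls>",
            # Marks everything up to and including this block as a cacheable prefix
            "cache_control": {"type": "ephemeral"},
        }
    ],
    "messages": [{"role": "user", "content": "First question about the cached context."}],
}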

Changes Required

  1. Update SpiceMessages Class:

    • Add a cache argument to the message creation methods.
    • Set the cache_control parameter based on the cache argument.
  2. Modify Message Creation Functions:

    • Update the message creation functions to handle the cache argument.
  3. Track Cache Performance Metrics:

    • Update the get_response method in the Spice class to handle the new API response fields related to caching.
    • Log the number of input tokens cached and verify cache hits using the client.extract_text_and_tokens method.

Implementation Details

1. Update SpiceMessages Class

Modify the SpiceMessages class to include the cache argument:

class SpiceMessages(UserList[SpiceMessage]):
    ...
    def add_message(self, role: Literal["user", "assistant", "system"], content: str, cache: bool = False):
        """Appends a message with the given role and content, optionally marking it for prompt caching."""
        self.data.append(create_message(role, content, cache))
    
    def add_user_message(self, content: str, cache: bool = False):
        """Appends a user message with the given content."""
        self.data.append(user_message(content, cache))
    
    def add_system_message(self, content: str, cache: bool = False):
        """Appends a system message with the given content."""
        self.data.append(system_message(content, cache))
    
    def add_assistant_message(self, content: str, cache: bool = False):
        """Appends an assistant message with the given content."""
        self.data.append(assistant_message(content, cache))
    ...

2. Modify Message Creation Functions

Update the message creation functions to handle the cache argument:

def create_message(role: Literal["user", "assistant", "system"], content: str, cache: bool = False) -> ChatCompletionMessageParam:
    message = {"role": role, "content": content}
    if cache:
        # Mark this message as a cacheable prefix using Anthropic's ephemeral cache type
        message["cache_control"] = {"type": "ephemeral"}
    return message

def user_message(content: str, cache: bool = False) -> ChatCompletionUserMessageParam:
    """Creates a user message with the given content."""
    return create_message("user", content, cache)

def system_message(content: str, cache: bool = False) -> ChatCompletionSystemMessageParam:
    """Creates a system message with the given content."""
    return create_message("system", content, cache)

def assistant_message(content: str, cache: bool = False) -> ChatCompletionAssistantMessageParam:
    """Creates an assistant message with the given content."""
    return create_message("assistant", content, cache)

3. Track Cache Performance Metrics

Update the get_response method in the Spice class to handle the new API response fields related to caching:

async def get_response(
    self,
    messages: Collection[SpiceMessage],
    model: Optional[TextModel | str] = None,
    provider: Optional[Provider | str] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    response_format: Optional[ResponseFormat] = None,
    name: Optional[str] = None,
    validator: Optional[Callable[[str], bool]] = None,
    converter: Callable[[str], T] = string_identity,
    streaming_callback: Optional[Callable[[str], None]] = None,
    retries: int = 0,
    retry_strategy: Optional[RetryStrategy[T]] = None,
    cache_control: Optional[Dict[str, Any]] = None,  # New parameter
) -> SpiceResponse[T]:
    ...
    call_args = self._fix_call_args(
        messages, text_model, streaming_callback is not None, temperature, max_tokens, response_format
    )
    ...
    while True:
        ...
        with client.catch_and_convert_errors():
            if streaming_callback is not None:
                stream = await client.get_chat_completion_or_stream(call_args)
                stream = cast(AsyncIterator, stream)
                streaming_spice_response = StreamingSpiceResponse(
                    text_model, call_args, client, stream, None, streaming_callback
                )
                chat_completion = await streaming_spice_response.complete_response()
                text, input_tokens, output_tokens = (
                    chat_completion.text,
                    chat_completion.input_tokens,
                    chat_completion.output_tokens,
                )
            else:
                chat_completion = await client.get_chat_completion_or_stream(call_args)
                text, input_tokens, output_tokens = client.extract_text_and_tokens(chat_completion, call_args)
        
        # Handle cache performance metrics: Anthropic reports these fields in the
        # response usage when prompt caching is active; default to 0 otherwise
        usage = getattr(chat_completion, "usage", None)
        cache_creation_input_tokens = getattr(usage, "cache_creation_input_tokens", 0) or 0
        cache_read_input_tokens = getattr(usage, "cache_read_input_tokens", 0) or 0
        print(f"Cache creation input tokens: {cache_creation_input_tokens}")
        print(f"Cache read input tokens: {cache_read_input_tokens}")
        ...
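
On a cache write, Anthropic reports the tokens written via cache_creation_input_tokens; on a cache hit, cache_read_input_tokens is nonzero. A minimal sketch of interpreting the two counters (log_cache_metrics is a hypothetical helper, not part of spice):

def log_cache_metrics(cache_creation_input_tokens: int, cache_read_input_tokens: int) -> None:
    """Hypothetical helper: interpret Anthropic's cache usage counters."""
    if cache_read_input_tokens > 0:
        # Cache hit: this many input tokens were read from the cache at reduced cost
        print(f"Cache hit: {cache_read_input_tokens} input tokens read from cache")
    elif cache_creation_input_tokens > 0:
        # Cache miss that wrote a new cache entry for subsequent calls
        print(f"Cache write: {cache_creation_input_tokens} input tokens cached")
    else:
        print("No prompt caching activity on this call")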

Example Usage

Here's an example of how you might use the updated SpiceMessages class with caching:

import asyncio

from spice import Spice
from spice.spice_message import SpiceMessages

async def main():
    client = Spice()

    messages = SpiceMessages(client)
    messages.add_system_message("You are an AI assistant tasked with analyzing literary works.", cache=True)
    messages.add_user_message("Analyze the major themes in 'Pride and Prejudice'.", cache=True)

    response = await client.get_response(messages=messages, model="claude-3-5-sonnet-20240620")
    print(response.text)

asyncio.run(main())

Acceptance Criteria

  • The SpiceMessages class should support the cache argument.
  • The get_response method should log cache performance metrics.
  • The implementation should be backward compatible: non-Anthropic clients should simply ignore the cache argument (one possible approach is sketched below).
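
For the last criterion, one possible approach (a sketch, not part of the proposal above; strip_cache_control is a hypothetical helper) is to drop the cache_control key before sending messages to providers that do not recognize it:

def strip_cache_control(messages: list[dict]) -> list[dict]:
    """Hypothetical helper: remove cache_control so non-Anthropic providers
    (e.g. OpenAI) receive only the fields they expect."""
    return [{key: value for key, value in m.items() if key != "cache_control"} for m in messages]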

mentatai bot commented Aug 25, 2024

I will start working on this issue

biobootloader commented

done in #108
