Enable caching for 'generate' and 'stream_generate' functions to ensure persistence of cache across multiple requests #989

Closed
wants to merge 3 commits

Conversation

@nath1295 commented Sep 17, 2024

  1. Add two new data classes, CacheHistory and StepOutput, for storing the cache history along with the token history.
  2. Add the option to return the cache from "generate" and "stream_generate" for further cache reuse.
  3. Add two functions to save and load the cache from disk.
  4. The "prompt" argument in "generate" and "stream_generate" is no longer treated as a suffix of the cache history; it is now the full prompt. In "generate_step", a check finds the length of the longest shared prefix between the token ids of the new prompt and the token ids of the cached prompt, so only the remaining tokens need to be processed (see the sketch after this list).

Usage

from mlx_lm import load, stream_generate
from mlx_lm.utils import save_cache, load_cache

model, tokenizer = load('/Path/to/model')

prompt = 'Your long prompt here...'

# First generation without prompt cache history
for i, cache in stream_generate(model=model, 
        tokenizer=tokenizer, prompt=prompt, max_tokens=100, return_cache=True, verbose=True):
    print(i, end='')
# Processing prompt (1431/1431): 100%|██████████| 3/3 [00:02<00:00,  1.50it/s]
# Prompt preprocessing time for 1431 tokens: 2.007s (713.1801 tok/sec)

# Second generation with prompt cache history
new = ''  # accumulate the generated text
for i, cache in stream_generate(model=model, 
        tokenizer=tokenizer, prompt=prompt, max_tokens=100, return_cache=True, verbose=True, cache_history=cache):
    new += i
    print(i, end='')
# Processing prompt (1/1): 100%|██████████| 1/1 [00:00<00:00, 595.61it/s]
# Prompt preprocessing time for 1 tokens: 0.001921s (520.6299 tok/sec)

# Save the cache history to use later
save_cache(cache, filename='test.safetensors', metadata=dict(model_id='My random model'))

# Load an existing cache from disk
cache, metadata = load_cache(filename='test.safetensors')
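
For reference, a hedged sketch of what save_cache and load_cache could do internally, assuming the cache is a list of per-layer objects exposing keys and values mlx arrays; the field names and layout here are assumptions, not necessarily the PR's actual implementation.

import mlx.core as mx

def save_cache_sketch(cache, filename, metadata=None):
    # Flatten each layer's key/value arrays into one dict for safetensors.
    arrays = {}
    for i, layer in enumerate(cache):
        arrays[f'layer_{i}_keys'] = layer.keys
        arrays[f'layer_{i}_values'] = layer.values
    mx.save_safetensors(filename, arrays, metadata=metadata or {})

def load_cache_sketch(filename):
    # Returns the flat dict of arrays plus the stored metadata.
    arrays, metadata = mx.load(filename, return_metadata=True)
    return arrays, metadata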

@nath1295 changed the title from "Cache prompt with "generate" and "generate_stream" in python function" to "Enable caching for 'generate' and 'stream_generate' functions to ensure persistence of cache across multiple requests" on Sep 18, 2024
@nath1295 (Author)

Just updating the title of the PR for clarity. With these changes, the KV cache from any generation can now be reused for other requests.

@nath1295 (Author)

The code in server.py has been updated to match the changes to generate_step, so prompt caching is enabled in server.py by default (a client-side illustration follows below).
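
As an illustration only (not part of this PR's diff), a client-side sketch of how server-side prompt caching would help; it assumes the server is running via python -m mlx_lm.server on the default 127.0.0.1:8080 and uses its /v1/chat/completions endpoint.

import requests

URL = 'http://127.0.0.1:8080/v1/chat/completions'
long_system_prompt = 'Your long system prompt here...'

def ask(question):
    # Both requests share the same long system prompt prefix.
    payload = {
        'messages': [
            {'role': 'system', 'content': long_system_prompt},
            {'role': 'user', 'content': question},
        ],
        'max_tokens': 100,
    }
    return requests.post(URL, json=payload).json()

# The first request processes the full prompt; the second shares the cached
# prefix, so only the new user message needs to be processed.
print(ask('Summarize the document.'))
print(ask('List three key points.'))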

@awni (Member) commented Oct 12, 2024

Thanks for the PR! However, most of this functionality should already be included in #1015 and #1026, so I will close this.

If there is anything here that those don't address, please feel free to submit a follow-up PR rebased on the latest. Thanks!

@awni closed this Oct 12, 2024