diff --git a/README.md b/README.md
index c4cf0b1..04c0760 100644
--- a/README.md
+++ b/README.md
@@ -72,6 +72,25 @@ _Average performance on the RULER dataset with 4k context length and Loogle Shor
 Please refer to the [evaluation](evaluation/README.md) directory for more details and results.
 
+## KV cache quantization
+
+We support KV cache quantization through the transformers `QuantizedCache` class (see [HF blog post](https://huggingface.co/blog/kv-cache-quantization#how-to-use-quantized-kv-cache-in-%F0%9F%A4%97-transformers)). To use it, simply pass a cache object to your pipeline:
+
+```python
+from transformers import QuantizedCacheConfig, QuantoQuantizedCache
+
+config = QuantizedCacheConfig(nbits=4)
+cache = QuantoQuantizedCache(config)
+
+pipe(..., cache=cache)
+```
+
+By default, the `DynamicCache` is used (no quantization).
+
+> [!IMPORTANT]
+> To use the `QuantizedCache`, you need to install additional dependencies (e.g. `pip install optimum-quanto==0.2.4`, see also [this issue](https://github.com/huggingface/transformers/issues/34848)).
+
+
 ## FAQ
@@ -165,10 +184,3 @@ Check the [demo notebook](notebooks/per_layer_compression_demo.ipynb) for more d
-
-### Is quantization supported ?
-
-
-We don't support quantization of the KV cache yet. Quantization can achieve up to 4x compression moving from (b)float16 to int4 and we believe it is orthogonal to the KV cache pruning strategies proposed in this repository.
-
-
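Taken together with the pipeline changes below, the new `cache` argument can be combined with a press in a single call. A hedged end-to-end sketch follows; the model name, the `ExpectedAttentionPress` choice, the `"kv-press-text-generation"` task string and the `"answer"` output key are taken from the repository's existing examples, not from this diff.

```python
# Hedged end-to-end sketch of the new `cache` argument; model name, press choice,
# task string and output key are assumptions based on the repo's examples, not this diff.
from transformers import pipeline, QuantizedCacheConfig, QuantoQuantizedCache
from kvpress import ExpectedAttentionPress  # importing kvpress registers the pipeline task

pipe = pipeline(
    "kv-press-text-generation",        # task name assumed from the repository's examples
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder: any supported causal LM
    device="cuda",
)

press = ExpectedAttentionPress(compression_ratio=0.5)        # prune half of the KV pairs
cache = QuantoQuantizedCache(QuantizedCacheConfig(nbits=4))  # store the rest in int4

context = "..."   # long document to compress
question = "..."  # question about that document
result = pipe(context, question=question, press=press, cache=cache)
print(result["answer"])  # single-question calls expose an "answer" key in the repo's examples (assumed)
```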
diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py
index 45bc2c3..1ae45ae 100644
--- a/kvpress/pipeline.py
+++ b/kvpress/pipeline.py
@@ -7,7 +7,7 @@
 from typing import Optional
 
 import torch
-from transformers import AutoModelForCausalLM, DynamicCache, Pipeline
+from transformers import AutoModelForCausalLM, Cache, DynamicCache, QuantizedCache, Pipeline
 from transformers.pipelines import PIPELINE_REGISTRY
 from transformers.pipelines.base import GenericTensor
@@ -32,6 +32,7 @@ def _sanitize_parameters(
         press: Optional[BasePress] = None,
         max_new_tokens: int = 50,
         max_context_length: Optional[int] = None,
+        cache: Optional[Cache] = None,
         **kwargs,
     ):
         """
@@ -42,7 +43,7 @@ def _sanitize_parameters(
         ----------
         question : str, optional
            The question to be asked about the context. Exclusive with `questions`.
-        questions : List[str], optional
+        questions : list[str], optional
            A list of questions to be asked about the context. Exclusive with `question`.
         answer_prefix : str, optional
            The prefix to be added to the generated answer.
@@ -52,12 +53,14 @@ def _sanitize_parameters(
            The maximum number of new tokens to generate for each answer.
         max_context_length : int, optional
            The maximum number of tokens in the context. By default will use the maximum length supported by the model.
+        cache : Cache, optional
+           The cache to use for the forward pass. Defaults to None (DynamicCache).
         **kwargs : dict
            Additional keyword arguments, currently ignored.
 
         Returns
         -------
-        Tuple[Dict, Dict, Dict]
+        Tuple[dict, dict, dict]
            A tuple containing three dictionaries:
                - preprocess_kwargs: The keyword arguments for the preprocess function.
                - forward_kwargs: The keyword arguments for the forward function.
@@ -75,7 +78,7 @@ def _sanitize_parameters(
             "answer_prefix": answer_prefix,
             "max_context_length": max_context_length,
         }
-        forward_kwargs = {"press": press, "max_new_tokens": max_new_tokens}
+        forward_kwargs = {"press": press, "max_new_tokens": max_new_tokens, "cache": cache}
         return preprocess_kwargs, forward_kwargs, postprocess_kwargs
 
     def preprocess(
@@ -90,7 +93,7 @@ def preprocess(
         Returns
         -------
-        Dict[str, GenericTensor]
+        dict[str, GenericTensor]
            A dictionary containing the tokenized context (key: "context_ids") and questions (key: "questions_ids").
         """
 
@@ -127,23 +130,29 @@ def preprocess(
         return {"context_ids": context_ids, "questions_ids": question_ids}
 
     def _forward(
-        self, input_tensors: dict[str, GenericTensor], max_new_tokens: int = 50, press: Optional[BasePress] = None
+        self,
+        input_tensors: dict[str, GenericTensor],
+        max_new_tokens: int = 50,
+        press: Optional[BasePress] = None,
+        cache: Optional[Cache] = None,
     ):
         """
         Forward pass of the kv-press pipeline.
 
         Parameters
         ----------
-        input_tensors : Dict[str, GenericTensor]
+        input_tensors : dict[str, GenericTensor]
            A dictionary containing the tokenized context and questions.
         max_new_tokens : int, optional
            The maximum number of new tokens to generate for each answer. Defaults to 50.
         press : BasePress, optional
            The key-value press to use for compression. Defaults to None.
+        cache : Cache, optional
+           The cache to use for the forward pass. Defaults to None (DynamicCache).
 
         Returns
         -------
-        List[str]
+        list[str]
            A list of generated answers.
         """
@@ -151,23 +160,26 @@ def _forward(
         context_length = context_ids.shape[1]
 
         # Prefilling using the press on the context
+        if cache is None:
+            cache = DynamicCache()
+
         with press(self.model) if press is not None else contextlib.nullcontext():
-            past_key_values = self.model(
+            self.model(
                 input_ids=context_ids,
-                past_key_values=DynamicCache(),
+                past_key_values=cache,
                 output_attentions=isinstance(press, ObservedAttentionPress),
                 num_logits_to_keep=1,
-            ).past_key_values
+            )
 
         logger.debug(f"Context Length: {context_length}")
-        logger.debug(f"Compressed Context Length: {past_key_values.get_seq_length()}")
+        logger.debug(f"Compressed Context Length: {cache.get_seq_length()}")
 
         # Greedy decoding for each question
         answers = []
         for question_ids in input_tensors["questions_ids"]:
             answer = self.generate_answer(
                 question_ids=question_ids.to(self.model.device),
-                past_key_values=past_key_values,
+                cache=cache,
                 context_length=context_length,
                 max_new_tokens=max_new_tokens,
             )
@@ -181,7 +193,7 @@ def postprocess(self, model_outputs, single_question):
         return {"answers": model_outputs}
 
     def generate_answer(
-        self, question_ids: torch.Tensor, past_key_values: DynamicCache, context_length: int, max_new_tokens: int
+        self, question_ids: torch.Tensor, cache: Cache, context_length: int, max_new_tokens: int
     ) -> str:
         """
         Generate an answer to a question using greedy decoding.
@@ -190,7 +202,7 @@ def generate_answer(
         ----------
         question_ids : torch.Tensor
            The tokenized question.
-        past_key_values : DynamicCache
+        cache : Cache
            The compressed key-value cache.
         context_length : int
            The length of the context.
@@ -203,10 +215,7 @@ def generate_answer(
            The generated answer.
         """
 
-        cache_seq_lengths = [
-            past_key_values.get_seq_length(layer_idx=layer_idx) for layer_idx in range(len(past_key_values))
-        ]
-
+        cache_seq_lengths = [cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))]
         position_ids = torch.arange(
             context_length, context_length + question_ids.shape[1], device=self.model.device
         ).unsqueeze(0)
@@ -214,7 +223,7 @@ def generate_answer(
         # if the user doesn't provide a question, skip forward pass
         outputs = self.model(
             input_ids=question_ids.to(self.model.device),
-            past_key_values=past_key_values,
+            past_key_values=cache,
             position_ids=position_ids,
             num_logits_to_keep=1,
         )
@@ -229,7 +238,7 @@ def generate_answer(
         for i in range(max_new_tokens - 1):
             outputs = self.model(
                 input_ids=generated_ids[-1].unsqueeze(0).unsqueeze(0),
-                past_key_values=outputs.past_key_values,
+                past_key_values=cache,
                 position_ids=position_ids + i,
             )
             new_id = outputs.logits[0, -1].argmax()
@@ -238,13 +247,15 @@ def generate_answer(
                 break
         answer = self.tokenizer.decode(torch.stack(generated_ids), skip_special_tokens=True)
 
-        # remove the generated tokens from the cache
-        past_key_values.key_cache = [
-            key[:, :, :cache_seq_len] for key, cache_seq_len in zip(past_key_values.key_cache, cache_seq_lengths)
-        ]
-        past_key_values.value_cache = [
-            value[:, :, :cache_seq_len] for value, cache_seq_len in zip(past_key_values.value_cache, cache_seq_lengths)
-        ]
+        # Remove the generated tokens from the cache
+        if isinstance(cache, QuantizedCache):
+            key_attr, value_attr = "_quantized_key_cache", "_quantized_value_cache"
+        else:
+            key_attr, value_attr = "key_cache", "value_cache"
+
+        setattr(cache, key_attr, [key[:, :, :c] for key, c in zip(getattr(cache, key_attr), cache_seq_lengths)])
+        setattr(cache, value_attr, [value[:, :, :c] for value, c in zip(getattr(cache, value_attr), cache_seq_lengths)])
+
         return answer
diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py
index 8d33cdd..553187a 100644
--- a/kvpress/presses/base_press.py
+++ b/kvpress/presses/base_press.py
@@ -8,7 +8,14 @@
 import torch
 from torch import nn
-from transformers import LlamaForCausalLM, MistralForCausalLM, Phi3ForCausalLM, PreTrainedModel, Qwen2ForCausalLM
+from transformers import (
+    LlamaForCausalLM,
+    MistralForCausalLM,
+    Phi3ForCausalLM,
+    PreTrainedModel,
+    Qwen2ForCausalLM,
+    QuantizedCache,
+)
 
 logger = logging.getLogger(__name__)
@@ -92,8 +99,12 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
         if (self.compression_ratio == 0) or (cache.seen_tokens > q_len):
             return output
 
-        keys = cache.key_cache[module.layer_idx]
-        values = cache.value_cache[module.layer_idx]
+        if isinstance(cache, QuantizedCache):
+            keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx])
+            values = cache._dequantize(cache._quantized_value_cache[module.layer_idx])
+        else:
+            keys = cache.key_cache[module.layer_idx]
+            values = cache.value_cache[module.layer_idx]
 
         with torch.no_grad():
             scores = self.score(module, hidden_states, keys, values, attentions, kwargs)
@@ -104,8 +115,14 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
         indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
 
         # Update cache
-        cache.key_cache[module.layer_idx] = keys.gather(2, indices)
-        cache.value_cache[module.layer_idx] = values.gather(2, indices)
+        keys = keys.gather(2, indices).contiguous()
+        values = values.gather(2, indices).contiguous()
+        if isinstance(cache, QuantizedCache):
+            cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
+            cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value)
+        else:
+            cache.key_cache[module.layer_idx] = keys
+            cache.value_cache[module.layer_idx] = values
 
         return output
diff --git a/pyproject.toml b/pyproject.toml
index fac7c9c..3deb82a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -2,7 +2,7 @@
 name = "kvpress"
 authors = ["Simon Jegou", "Maximilian Jeblick", "Jiwei Liu", "David Austin"]
 description = "Efficiently compress the KV cache of any pretrained transformer"
-version = "0.0.1"
+version = "0.0.2"
 readme = "README.md"
 
 [tool.poetry.dependencies]
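Stepping back, the `forward_hook` change in `base_press.py` amounts to a dequantize, score, prune, re-quantize round trip whenever the cache is a `QuantizedCache`. The sketch below isolates that round trip for a single layer; the helper name and the top-k selection are simplifications standing in for `self.score` and the surrounding hook machinery, while the private `QuantizedCache` attributes are used exactly as in the patch.

```python
# Illustrative sketch (not kvpress code) of the per-layer round trip performed by the
# patched forward_hook when the cache is quantized: dequantize, score, prune, re-quantize.
import torch
from transformers import QuantizedCache


def prune_quantized_layer(cache: QuantizedCache, layer_idx: int, scores: torch.Tensor, compression_ratio: float):
    """Keep the highest-scoring KV pairs of one layer and write them back quantized.

    `scores` is assumed to have shape (batch, num_kv_heads, seq_len); the helper and the
    top-k selection are simplifications, not part of the kvpress API.
    """
    # 1. Dequantize so keys/values can be gathered like ordinary tensors
    keys = cache._dequantize(cache._quantized_key_cache[layer_idx])
    values = cache._dequantize(cache._quantized_value_cache[layer_idx])

    # 2. Pick the indices of the tokens to keep according to the scores
    seq_len = keys.shape[2]
    n_kept = int(seq_len * (1 - compression_ratio))
    indices = scores.topk(n_kept, dim=-1).indices
    indices = indices.unsqueeze(-1).expand(-1, -1, -1, keys.shape[-1])

    # 3. Prune, then re-quantize along the cache's configured axes
    keys = keys.gather(2, indices).contiguous()
    values = values.gather(2, indices).contiguous()
    cache._quantized_key_cache[layer_idx] = cache._quantize(keys, axis=cache.axis_key)
    cache._quantized_value_cache[layer_idx] = cache._quantize(values, axis=cache.axis_value)
```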