[llm] Generalize the KV cache to support both direct and paged caches. (#607)

I still want to do more work on configs and how to construct models, but
this at least unifies the implementations for paged and direct.
stellaraccident authored Apr 11, 2024
1 parent e158bae commit 9484484
Showing 7 changed files with 314 additions and 100 deletions.
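At the call sites, the unification means the separate read/write cache handles disappear in favor of a single `cache_state` argument. A minimal before/after sketch of a decode step, based on the hunks below (the first positional argument is named `tokens` here only for illustration):

```python
# Before this commit: the same tensors were passed twice.
logits = model.decode(
    tokens,
    attention_mask=attention_mask,
    start_positions=start_positions,
    seq_block_ids=seq_block_ids,
    read_cache_state=cache_state,
    write_cache_state=cache_state,
)

# After: a single cache_state list, whether the model was built with a paged
# or a direct KV cache.
logits = model.decode(
    tokens,
    attention_mask=attention_mask,
    start_positions=start_positions,
    seq_block_ids=seq_block_ids,
    cache_state=cache_state,
)
```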
3 changes: 1 addition & 2 deletions llm/turbine_llm/examples/export_paged_llm_v1.py
@@ -120,8 +120,7 @@ def _(
attention_mask=attention_mask,
start_positions=start_positions,
seq_block_ids=seq_block_ids,
- read_cache_state=cache_state,
- write_cache_state=cache_state,
+ cache_state=cache_state,
)
return logits

50 changes: 39 additions & 11 deletions llm/turbine_llm/examples/paged_llm_v1.py
@@ -14,7 +14,7 @@
from ..layers import *

# TODO: Should be using a base class with the protocol supported.
- from ..models.llama.llama import PagedLlamaModelV1
+ from ..models.llama.llama import *
from ..utils.debugging import trace_tensor
from ..utils.tokenizer import InferenceTokenizer, load_tokenizer

@@ -32,7 +32,12 @@ def __init__(
):
self.model = model
self.tokenizer = tokenizer
- self.cache_state = model.cache.allocate(page_cache_size, dtype=torch.float32)
+ if model.cache.is_paged:
+     self.shared_cache_state = model.cache.paged.allocate(
+         page_cache_size, dtype=torch.float32
+     )
+ else:
+     self.shared_cache_state = None
self.free_pages = list(range(1, 128))
self.end_token = end_token

@@ -42,36 +47,55 @@ def block_seq_stride(self) -> int:

def begin_batch(self, prompts: list[str]):
token_ids, seq_lens = self.tokenizer.encode(
- prompts, pad_to_multiple_of=self.model.cache.block_seq_stride
+ prompts, pad_to_multiple_of=self.model.cache.pad_sequence_stride
)
token_ids = torch.tensor(token_ids)
seq_lens = torch.tensor(seq_lens)
- return Batch(self, token_ids, seq_lens)
+ if self.shared_cache_state is not None:
+     cache_state = self.shared_cache_state
+ else:
+     cache_state = self.model.cache.direct.allocate(
+         bs=len(prompts), dtype=torch.float32
+     )
+ return Batch(self, token_ids, seq_lens, cache_state)

def alloc_page(self) -> int:
+ if self.model.cache.is_direct:
+     # We don't allocate block ids for the direct cache.
+     return 0
+
return self.free_pages.pop()

def release_page(self, index: int):
+ if self.model.cache.is_direct:
+     return
self.free_pages.append(index)


class Batch:
def __init__(
- self, parent: TorchGenerator, token_ids: torch.Tensor, seq_lens: torch.Tensor
+ self,
+ parent: TorchGenerator,
+ token_ids: torch.Tensor,
+ seq_lens: torch.Tensor,
+ cache_state: list[torch.Tensor],
):
self.bs = token_ids.shape[0]
assert seq_lens.shape[0] == self.bs
self.parent = parent
self.token_ids = token_ids
self.seq_lens = seq_lens
+ self.cache_state = cache_state
self.results: list[list[int]] = [[] for _ in range(self.bs)]
self.done_result_indices: set[int] = set()

# Assemble the batch.
seq_stride = self.parent.block_seq_stride
self.seq_block_ids: list[list[int]] = []
for seq_len in self.seq_lens:
- blocks_needed = int(math.ceil(seq_len / seq_stride))
+ blocks_needed = (
+     int(math.ceil(seq_len / seq_stride)) if seq_stride > 0 else 0
+ )
row = []
for _ in range(blocks_needed):
row.append(self.parent.alloc_page())
@@ -126,7 +150,7 @@ def prefill(self):
self.token_ids,
attention_mask=attention_mask,
seq_block_ids=seq_block_ids_tensor,
- cache_state=self.parent.cache_state,
+ cache_state=self.cache_state,
)
# TODO: Normalize the output of extract_tokens_from_logits into
# tensor [bs, 1].
@@ -160,8 +184,7 @@ def decode(self):
attention_mask=decode_attention_mask,
start_positions=start_positions,
seq_block_ids=seq_block_ids_tensor,
- read_cache_state=self.parent.cache_state,
- write_cache_state=self.parent.cache_state,
+ cache_state=self.cache_state,
)
trace_tensor("decode.logits", logits)
# TODO: Normalize the output of extract_tokens_from_logits into
@@ -183,6 +206,7 @@ def main():

parser = cli.create_parser()
parser.add_argument("prompt", nargs="+", help="Prompt strings")
parser.add_argument("--kv-cache-type", default="paged", help="KV cache type")
cli.add_gguf_dataset_options(parser)
cli.add_tokenizer_options(parser)
args = cli.parse(parser)
@@ -192,8 +216,12 @@ def main():
dataset = gguf.load_file(data_files["gguf"])
prompts = args.prompt

- hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
- model = PagedLlamaModelV1(dataset.root_theta, hp)
+ config = LlamaModelConfig(
+     hp=configs.LlamaHParams.from_gguf_props(dataset.properties),
+     block_seq_stride=16,
+     kv_cache_type=args.kv_cache_type,
+ )
+ model = PagedLlamaModelV1(dataset.root_theta, config)
generator = TorchGenerator(model, tokenizer)

print(f":: Prompting:")
5 changes: 2 additions & 3 deletions llm/turbine_llm/examples/validate_paged_llama_model.py
@@ -17,7 +17,7 @@ def main(args: list[str]):
config = gguf.load_file(args[0])
hp = configs.LlamaHParams.from_gguf_props(config.properties)
model = PagedLlamaModelV1(config.root_theta, hp)
- cache_state = model.cache.allocate(128, torch.float32)
+ cache_state = model.cache.paged.allocate(128, torch.float32)
start_index = 0
next_batch = torch.tensor(
[
@@ -118,8 +118,7 @@ def main(args: list[str]):
attention_mask=decode_attention_mask,
start_positions=start_positions,
seq_block_ids=seq_block_ids,
- read_cache_state=cache_state,
- write_cache_state=cache_state,
+ cache_state=cache_state,
)
tokens = torch.tensor(
model.extract_tokens_from_logits(logits, [1, 1, 1, 1])
2 changes: 1 addition & 1 deletion llm/turbine_llm/layers/__init__.py
@@ -5,7 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .base import BaseLayer, ThetaLayer
- from .kv_cache import PagedKVCache
+ from .kv_cache import BaseKVCache, DirectKVCache, PagedKVCache
from .causal_llm import BaseCausalLMModel
from .data import (
Dataset,
2 changes: 2 additions & 0 deletions llm/turbine_llm/layers/configs/llm_configs.py
@@ -37,6 +37,7 @@ class LlamaHParams:
feed_forward_length: int
rope_dimension_count: int
attention_head_count: int
+ attn_head_dim: int
attention_layer_norm_rms_epsilon: float
attention_head_count_kv: int

@@ -50,6 +51,7 @@ def from_gguf_props(p: dict[str, Any]):
embedding_length=_int_prop(p, "llama.embedding_length"),
block_count=_int_prop(p, "llama.block_count"),
feed_forward_length=_int_prop(p, "llama.feed_forward_length"),
+ attn_head_dim=_int_prop(p, "llama.rope.dimension_count"),
rope_dimension_count=_int_prop(p, "llama.rope.dimension_count"),
attention_head_count=attention_head_count,
attention_layer_norm_rms_epsilon=_float_prop(
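Worth noting: the new `attn_head_dim` field is populated from the same GGUF property as `rope_dimension_count`, so for a standard Llama checkpoint the two are equal (e.g. a 4096-dim model with 32 heads gives 4096 / 32 = 128, which matches the RoPE dimension count). A tiny hedged illustration:

```python
hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
# Both fields are read from "llama.rope.dimension_count" in this commit.
assert hp.attn_head_dim == hp.rope_dimension_count
```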
98 changes: 97 additions & 1 deletion llm/turbine_llm/layers/kv_cache.py
@@ -11,18 +11,110 @@
and dims floating around everywhere.
"""

+ import abc
import math

import torch

from ..utils.debugging import trace_tensor

+ __all__ = [
+     "BaseKVCache",
+     "DirectKVCache",
+     "PagedKVCache",
+ ]


- class PagedKVCache:
+ class BaseKVCache(abc.ABC):
+     """Base class for a KV cache.
+     This doesn't do much on its own except to serve as a type-safe base class
+     unifying the PagedKVCache and DirectKVCache:
+     * PagedKVCache is a shared cache which can be used across an arbitrary
+       number of batches/sequences with random mapping of blocks within a
+       sequence to a backing "page".
+     * DirectKVCache is a single-batch cache with a fixed batch size and
+       sequence length where the K/V cache tensors for each transformer block
+       are densely laid out in memory.
+     """
+
+     block_seq_stride: int
+     transformer_block_count: int
+     attn_head_count: int
+     attn_head_dim: int
+
+     @property
+     @abc.abstractmethod
+     def pad_sequence_stride(self) -> int:
+         """Stride that a sequence must be padded to in order to be valid for
+         the cache. For paged caches, this will typically be a multiple of the
+         block_seq_stride. For direct caches it may be 1 or a multiple that
+         is chosen for performance reasons.
+         """
+         ...
+
+     @property
+     def is_paged(self) -> bool:
+         return isinstance(self, PagedKVCache)
+
+     @property
+     def is_direct(self) -> bool:
+         return isinstance(self, DirectKVCache)
+
+     @property
+     def paged(self) -> "PagedKVCache":
+         assert isinstance(
+             self, PagedKVCache
+         ), f"Attempt to access cache {type(self)} as paged but it is not"
+         return self
+
+     @property
+     def direct(self) -> "DirectKVCache":
+         assert isinstance(
+             self, DirectKVCache
+         ), f"Attempt to access cache {type(self)} as direct but it is not"
+         return self
+
+
+ class DirectKVCache(BaseKVCache):
+     """KVCache for a single batch where the cache tensors are densely laid out."""
+
+     def __init__(
+         self,
+         *,
+         block_seq_stride: int,
+         transformer_block_count: int,
+         attn_head_count: int,
+         attn_head_dim: int,
+         seq_length: int,
+     ):
+         self.block_seq_stride = block_seq_stride
+         self.transformer_block_count = transformer_block_count
+         self.attn_head_count = attn_head_count
+         self.attn_head_dim = attn_head_dim
+         self.seq_length = seq_length
+
+     @property
+     def pad_sequence_stride(self) -> int:
+         return self.block_seq_stride
+
+     def allocate(self, *, bs: int, dtype: torch.dtype) -> list[torch.Tensor]:
+         """Allocates 2*transformer_block_count K/V cache tensors for the
+         given batch size and sequence length.
+         Each tensor has shape: [bs, sl, attn_head_count, attn_head_dim]
+         """
+         return [
+             torch.empty(
+                 [bs, self.seq_length, self.attn_head_count, self.attn_head_dim],
+                 dtype=dtype,
+             )
+             for _ in range(2 * self.transformer_block_count)
+         ]
+
+
+ class PagedKVCache(BaseKVCache):
"""Implementation of a KV cache on top of a 'page table'.
The page table slab is physically represented as a 2D tensor:
@@ -81,6 +173,10 @@ def unflatten_page_table(self, state: list[torch.Tensor]) -> torch.Tensor:
]
)

+     @property
+     def pad_sequence_stride(self) -> int:
+         return self.block_seq_stride
+
def allocate(self, page_count: int, dtype: torch.dtype) -> list[torch.Tensor]:
"""Allocates tensor state for a page table for the given capacity in
pages.
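To make the new hierarchy concrete, here is a small usage sketch. The `DirectKVCache` constructor arguments come from the class definition above; the concrete sizes (32 transformer blocks, 32 heads, head dim 128, sequence length 2048) are illustrative, and the `turbine_llm.layers` import path assumes the package layout shown in this commit:

```python
import torch

from turbine_llm.layers import DirectKVCache

# A dense, single-batch cache: allocate() returns 2 * transformer_block_count
# tensors, each shaped [bs, seq_length, attn_head_count, attn_head_dim].
cache = DirectKVCache(
    block_seq_stride=16,
    transformer_block_count=32,
    attn_head_count=32,
    attn_head_dim=128,
    seq_length=2048,
)
assert cache.is_direct and not cache.is_paged
assert cache.pad_sequence_stride == 16

state = cache.direct.allocate(bs=2, dtype=torch.float32)
assert len(state) == 2 * 32
assert list(state[0].shape) == [2, 2048, 32, 128]
# cache.paged would fail with an AssertionError, since this cache is not paged.
```

Callers that only need to pad sequences can stay agnostic to the cache kind by using `cache.pad_sequence_stride`, which is exactly what `begin_batch()` in the example script above does.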
(Diff for the remaining changed file is not shown in this view.)
