From 8a48bbe5778bfb17e7dd3a2170ea9da53df69876 Mon Sep 17 00:00:00 2001
From: Kunwar Grover
Date: Tue, 21 Jan 2025 23:42:48 +0000
Subject: [PATCH] [Llama] Do not allow configurable partitions for KVCache

---
 .../sharktank/export_layer/export_kv_cache.py |  3 +-
 sharktank/sharktank/layers/kv_cache.py        | 47 +++++++++----------
 .../layers/paged_llama_attention_block.py     |  9 ++--
 sharktank/tests/layers/kv_cache_test.py       | 12 +++--
 .../layers/sharded_paged_kv_cache_test.py     | 19 ++++----
 5 files changed, 47 insertions(+), 43 deletions(-)

diff --git a/sharktank/sharktank/export_layer/export_kv_cache.py b/sharktank/sharktank/export_layer/export_kv_cache.py
index 09f0a1c15..49ad1f03a 100644
--- a/sharktank/sharktank/export_layer/export_kv_cache.py
+++ b/sharktank/sharktank/export_layer/export_kv_cache.py
@@ -105,7 +105,8 @@ def _(model, state, partition_0, write_page_ids: torch.Tensor) -> torch.Tensor:
         write_page_ids = replicate(write_page_ids, count=args.sharding)
         cache.write(
             state,
-            cache_partitions=[partition_0, partition_0],
+            partition_0,
+            partition_0,
             transformer_block_index=1,
             page_ids=write_page_ids,
         )
diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py
index be8c66fb4..6e292477b 100644
--- a/sharktank/sharktank/layers/kv_cache.py
+++ b/sharktank/sharktank/layers/kv_cache.py
@@ -55,7 +55,6 @@ def __init__(
         transformer_block_count: int,
         attn_head_count: int,
         attn_head_dim: int,
-        cache_partition_count: int = 2,
         block_seq_stride: int = 16,
         dtype: torch.dtype = torch.float32,
         device: Optional[torch.device] = None,
@@ -64,7 +63,6 @@ def __init__(
         self.transformer_block_count = transformer_block_count
         self.attn_head_count = attn_head_count
         self.attn_head_dim = attn_head_dim
-        self.cache_partition_count = cache_partition_count
         self.block_seq_stride = block_seq_stride
         self.shard_count = shard_count
         if attn_head_count % shard_count != 0:
@@ -75,7 +73,7 @@ def __init__(
         # Some derived values based on attributes.
         self.sub_page_dims = [
             self.transformer_block_count,
-            self.cache_partition_count,
+            2,
             self.block_seq_stride,
             self.attn_head_count // self.shard_count,
             self.attn_head_dim,
@@ -115,7 +113,7 @@ def shard_state(
             [
                 -1,
                 self.transformer_block_count,
-                self.cache_partition_count,
+                2,
                 self.block_seq_stride,
                 self.attn_head_count,
                 self.attn_head_dim,
@@ -195,8 +193,8 @@ def read(
         # [page, attn_layer, cache_partition]
         # Where the cache line can be 0 (k) or 1 (v).
         subblock_table = page_table.flatten(start_dim=0, end_dim=2)
-        page_stride = self.transformer_block_count * self.cache_partition_count
-        transformer_block_stride = self.cache_partition_count
+        page_stride = self.transformer_block_count * 2
+        transformer_block_stride = 2
         base_subblock_ids = page_ids * page_stride + (
             transformer_block_index * transformer_block_stride
         )
@@ -225,8 +223,10 @@ def read_cache_partition(index: int):
     def write_timestep(
         self,
         state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
-        # List of [bs, 1, attn_head_count, attn_head_dim]
-        cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+        # [bs, 1, attn_head_count, attn_head_dim]
+        key: Union[torch.Tensor, SplitPrimitiveTensor],
+        # [bs, 1, attn_head_count, attn_head_dim]
+        value: Union[torch.Tensor, SplitPrimitiveTensor],
         *,
         transformer_block_index: int,
         # [bs]
@@ -242,10 +242,9 @@ def write_timestep(
         device = self.device
         page_table = self.unflatten_page_table(state)  # 6D
         bs, *_ = seq_positions.shape
-        assert len(cache_partitions) == self.cache_partition_count
 
         # [bs, 1, atten_head_count, attn_head_dim]
-        for idx, cache_partition in enumerate(cache_partitions):
+        for idx, cache_partition in enumerate([key, value]):
             # [bs, 1]
             page_index = seq_positions // self.block_seq_stride
 
@@ -277,12 +276,11 @@ def write_timestep(
             indices = (page_id, transformer_block, partitions, page_offset)
             page_table.index_put_(indices=indices, values=cache_partition)
 
-        return
-
     def write(
         self,
         state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
-        cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+        key: Union[torch.Tensor, SplitPrimitiveTensor],
+        value: Union[torch.Tensor, SplitPrimitiveTensor],
         *,
         transformer_block_index: int,
         page_ids: Union[torch.Tensor, ReplicatedTensor],
@@ -304,20 +302,21 @@ def write(
         # [page, attn_layer, cache_partition]
         # Where the cache line can be 0 (k) or 1 (v).
         subblock_table = page_table.flatten(start_dim=0, end_dim=2)
-        page_stride = self.transformer_block_count * self.cache_partition_count
-        transformer_block_stride = self.cache_partition_count
+        page_stride = self.transformer_block_count * 2
+        transformer_block_stride = 2
         base_subblock_ids = page_ids * page_stride + (
             transformer_block_index * transformer_block_stride
         )
 
-        for index, partition in enumerate(cache_partitions):
-            part_block_view = partition.unflatten(
-                1, (block_seq_len, self.block_seq_stride)
-            )
-            part_block_view = part_block_view.flatten(0, 1)
+        key_reshaped = key.unflatten(1, (block_seq_len, self.block_seq_stride)).flatten(
+            0, 1
+        )
+        value_reshaped = value.unflatten(
+            1, (block_seq_len, self.block_seq_stride)
+        ).flatten(0, 1)
 
-            subblock_ids = (
-                (base_subblock_ids + index) if index > 0 else base_subblock_ids
-            ).flatten(0, 1)
+        key_ids = base_subblock_ids.flatten(0, 1)
+        value_ids = base_subblock_ids.flatten(0, 1) + 1
 
-            subblock_table.index_copy_(0, subblock_ids, part_block_view)
+        subblock_table.index_copy_(0, key_ids, key_reshaped)
+        subblock_table.index_copy_(0, value_ids, value_reshaped)
diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py
index 64c7c65d9..2672409f9 100644
--- a/sharktank/sharktank/layers/paged_llama_attention_block.py
+++ b/sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -248,7 +248,8 @@ def transact_cache(
             # Prefill: Write the entire cache.
             cache.write(
                 cache_state,
-                cache_partitions=[xk_cache_update, xv_cache_update],
+                xk_cache_update,
+                xv_cache_update,
                 transformer_block_index=self.block_index,
                 page_ids=seq_block_ids,
             )
@@ -267,10 +268,8 @@ def transact_cache(
             # Write our one updated cache row into the cache.
             cache.write_timestep(
                 cache_state,
-                cache_partitions=[
-                    xk_cache_update,
-                    xv_cache_update,
-                ],
+                xk_cache_update,
+                xv_cache_update,
                 transformer_block_index=self.block_index,
                 seq_positions=start_positions,
                 page_ids=seq_block_ids,
diff --git a/sharktank/tests/layers/kv_cache_test.py b/sharktank/tests/layers/kv_cache_test.py
index d59d0a85b..fd0d2da25 100644
--- a/sharktank/tests/layers/kv_cache_test.py
+++ b/sharktank/tests/layers/kv_cache_test.py
@@ -48,7 +48,8 @@ def test_paged():
 
     cache.write(
         allocation,
-        cache_partitions=[write_ones, write_twos],
+        write_ones,
+        write_twos,
         transformer_block_index=1,
         page_ids=write_page_ids,
     )
@@ -85,7 +86,8 @@ def test_paged():
     write_pos = torch.full((bs,), write_seq_length, dtype=torch.int64)
     cache.write_timestep(
         allocation,
-        cache_partitions=[write_threes, write_fours],
+        write_threes,
+        write_fours,
         transformer_block_index=1,
         seq_positions=write_pos,
         page_ids=page_ids,
@@ -154,7 +156,8 @@ def test_sharded_paged():
 
     cache.write(
         allocation,
-        cache_partitions=[write_ones, write_twos],
+        write_ones,
+        write_twos,
         transformer_block_index=1,
         page_ids=write_page_ids,
     )
@@ -187,7 +190,8 @@ def test_sharded_paged():
 
     cache.write_timestep(
         allocation,
-        cache_partitions=[write_threes, write_fours],
+        write_threes,
+        write_fours,
         transformer_block_index=1,
         seq_positions=write_pos,
         page_ids=page_ids,
diff --git a/sharktank/tests/layers/sharded_paged_kv_cache_test.py b/sharktank/tests/layers/sharded_paged_kv_cache_test.py
index 833e2ce71..04b3d1052 100644
--- a/sharktank/tests/layers/sharded_paged_kv_cache_test.py
+++ b/sharktank/tests/layers/sharded_paged_kv_cache_test.py
@@ -26,7 +26,6 @@ def setUp(self):
         self.attn_head_count = self.shard_count * 7
         self.block_seq_stride = 19
         self.attn_head_dim = 17
-        self.cache_partition_count = 2
        self.page_count = 23
         self.batch_size = 11
         self.block_seq_len = 2
@@ -37,7 +36,6 @@ def setUp(self):
             attn_head_count=self.attn_head_count,
             block_seq_stride=self.block_seq_stride,
             attn_head_dim=self.attn_head_dim,
-            cache_partition_count=self.cache_partition_count,
             dtype=self.dtype,
         )
         self.sharded_cache = PagedKVCache(
@@ -46,7 +44,6 @@ def setUp(self):
             attn_head_count=self.attn_head_count,
             block_seq_stride=self.block_seq_stride,
             attn_head_dim=self.attn_head_dim,
-            cache_partition_count=self.cache_partition_count,
             dtype=self.dtype,
         )
 
@@ -137,7 +134,7 @@ def testWriteTimestep(self):
                 self.attn_head_count,
                 self.attn_head_dim,
             )
-            for _ in range(self.cache_partition_count)
+            for _ in range(2)
         ]
         transformer_block_index = 1
         seq_positions = torch.randint(
@@ -148,7 +145,8 @@ def testWriteTimestep(self):
         )
         self.cache.write_timestep(
             state=cache_state,
-            cache_partitions=cache_partitions,
+            key=cache_partitions[0],
+            value=cache_partitions[1],
             transformer_block_index=transformer_block_index,
             seq_positions=seq_positions,
             page_ids=page_ids,
@@ -163,7 +161,8 @@ def testWriteTimestep(self):
         sharded_page_ids = ops.replicate(page_ids, count=self.shard_count)
         self.sharded_cache.write_timestep(
             state=sharded_cache_state,
-            cache_partitions=sharded_cache_partitions,
+            key=sharded_cache_partitions[0],
+            value=sharded_cache_partitions[1],
             transformer_block_index=transformer_block_index,
             seq_positions=sharded_seq_positions,
             page_ids=sharded_page_ids,
@@ -185,7 +184,7 @@ def testWrite(self):
                 self.attn_head_count,
                 self.attn_head_dim,
             )
-            for _ in range(self.cache_partition_count)
+            for _ in range(2)
         ]
         transformer_block_index = 1
         assert self.batch_size * self.block_seq_len <= self.page_count
@@ -194,7 +193,8 @@ def testWrite(self):
         )
         self.cache.write(
             state=cache_state,
-            cache_partitions=cache_partitions,
+            key=cache_partitions[0],
+            value=cache_partitions[1],
             transformer_block_index=transformer_block_index,
             page_ids=page_ids,
         )
@@ -207,7 +207,8 @@ def testWrite(self):
         sharded_page_ids = ops.replicate(page_ids, count=self.shard_count)
         self.sharded_cache.write(
             state=sharded_cache_state,
-            cache_partitions=sharded_cache_partitions,
+            key=sharded_cache_partitions[0],
+            value=sharded_cache_partitions[1],
             transformer_block_index=transformer_block_index,
             page_ids=sharded_page_ids,
         )
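Usage sketch (illustrative only, not part of the committed diff): after this
patch, PagedKVCache.write() and write_timestep() take the K and V cache
partitions as two explicit key/value arguments instead of a cache_partitions
list, and the partition count is fixed at 2. The sketch below is a minimal
calling example assuming the constructor keywords shown in the tests above;
the import path, the allocate(page_count=...) state helper, and all shapes
and sizes are assumptions for illustration, not taken from the patch.

    import torch
    from sharktank.layers.kv_cache import PagedKVCache  # assumed import path

    cache = PagedKVCache(
        transformer_block_count=4,
        attn_head_count=8,
        attn_head_dim=16,
        block_seq_stride=16,
        dtype=torch.float32,
    )
    state = cache.allocate(page_count=32)  # assumed allocation helper

    bs, seq_len = 2, 32                    # seq_len is a multiple of block_seq_stride
    blocks = seq_len // 16
    key = torch.rand(bs, seq_len, 8, 16)   # [bs, seq_len, attn_head_count, attn_head_dim]
    value = torch.rand(bs, seq_len, 8, 16)
    page_ids = torch.arange(bs * blocks).view(bs, blocks)

    # Prefill-style write: key and value are passed directly, no list.
    cache.write(
        state,
        key,
        value,
        transformer_block_index=1,
        page_ids=page_ids,
    )

    # Decode-style write of a single timestep at sequence position 31.
    cache.write_timestep(
        state,
        torch.rand(bs, 1, 8, 16),          # key row for this step
        torch.rand(bs, 1, 8, 16),          # value row for this step
        transformer_block_index=1,
        seq_positions=torch.full((bs,), 31, dtype=torch.int64),
        page_ids=page_ids,
    )

The calls mirror how the updated tests drive the API: two tensors where the
old code took cache_partitions=[k, v], with everything after them keyword-only.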