From 396bf6f72ba24968c082b30c0d06b66e9f1fa423 Mon Sep 17 00:00:00 2001
From: Kunwar Grover
Date: Thu, 23 Jan 2025 19:07:13 +0000
Subject: [PATCH] [Llama] Do not allow configurable partitions for KVCache

---
 .../sharktank/export_layer/export_kv_cache.py |  3 +-
 sharktank/sharktank/layers/kv_cache.py        | 36 +++++++++----------
 .../layers/paged_llama_attention_block.py     |  9 +++--
 sharktank/sharktank/utils/create_cache.py     |  1 -
 sharktank/tests/layers/kv_cache_test.py       | 12 ++++---
 .../paged_llama_attention_block_test.py       |  3 --
 .../layers/sharded_paged_kv_cache_test.py     | 19 +++-----
 .../sharded_paged_llama_attention_block.py    |  2 --
 .../tests/models/llama/attention_test.py      |  1 -
 sharktank/tests/models/llama/kv_cache_test.py |  1 -
 10 files changed, 42 insertions(+), 45 deletions(-)

diff --git a/sharktank/sharktank/export_layer/export_kv_cache.py b/sharktank/sharktank/export_layer/export_kv_cache.py
index 09f0a1c15..49ad1f03a 100644
--- a/sharktank/sharktank/export_layer/export_kv_cache.py
+++ b/sharktank/sharktank/export_layer/export_kv_cache.py
@@ -105,7 +105,8 @@ def _(model, state, partition_0, write_page_ids: torch.Tensor) -> torch.Tensor:
             write_page_ids = replicate(write_page_ids, count=args.sharding)
         cache.write(
             state,
-            cache_partitions=[partition_0, partition_0],
+            partition_0,
+            partition_0,
             transformer_block_index=1,
             page_ids=write_page_ids,
         )
diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py
index be8c66fb4..af01d68ba 100644
--- a/sharktank/sharktank/layers/kv_cache.py
+++ b/sharktank/sharktank/layers/kv_cache.py
@@ -55,7 +55,6 @@ def __init__(
         transformer_block_count: int,
         attn_head_count: int,
         attn_head_dim: int,
-        cache_partition_count: int = 2,
         block_seq_stride: int = 16,
         dtype: torch.dtype = torch.float32,
         device: Optional[torch.device] = None,
@@ -64,7 +63,7 @@
         self.transformer_block_count = transformer_block_count
         self.attn_head_count = attn_head_count
         self.attn_head_dim = attn_head_dim
-        self.cache_partition_count = cache_partition_count
+        self.cache_partition_count = 2
         self.block_seq_stride = block_seq_stride
         self.shard_count = shard_count
         if attn_head_count % shard_count != 0:
@@ -225,8 +224,10 @@ def read_cache_partition(index: int):
     def write_timestep(
         self,
         state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
-        # List of [bs, 1, attn_head_count, attn_head_dim]
-        cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+        # [bs, 1, attn_head_count, attn_head_dim]
+        key: Union[torch.Tensor, SplitPrimitiveTensor],
+        # [bs, 1, attn_head_count, attn_head_dim]
+        value: Union[torch.Tensor, SplitPrimitiveTensor],
         *,
         transformer_block_index: int,
         # [bs]
@@ -242,10 +243,9 @@
         device = self.device
         page_table = self.unflatten_page_table(state)  # 6D
         bs, *_ = seq_positions.shape
-        assert len(cache_partitions) == self.cache_partition_count
 
         # [bs, 1, atten_head_count, attn_head_dim]
-        for idx, cache_partition in enumerate(cache_partitions):
+        for idx, cache_partition in enumerate([key, value]):
             # [bs, 1]
             page_index = seq_positions // self.block_seq_stride
 
@@ -277,12 +277,11 @@
             indices = (page_id, transformer_block, partitions, page_offset)
             page_table.index_put_(indices=indices, values=cache_partition)
 
-        return
-
     def write(
         self,
         state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
-        cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+        key: Union[torch.Tensor, SplitPrimitiveTensor],
+        value: Union[torch.Tensor, SplitPrimitiveTensor],
        *,
         transformer_block_index: int,
         page_ids: Union[torch.Tensor, ReplicatedTensor],
@@ -310,14 +309,15 @@ def write(
             transformer_block_index * transformer_block_stride
         )
 
-        for index, partition in enumerate(cache_partitions):
-            part_block_view = partition.unflatten(
-                1, (block_seq_len, self.block_seq_stride)
-            )
-            part_block_view = part_block_view.flatten(0, 1)
+        key_reshaped = key.unflatten(1, (block_seq_len, self.block_seq_stride)).flatten(
+            0, 1
+        )
+        value_reshaped = value.unflatten(
+            1, (block_seq_len, self.block_seq_stride)
+        ).flatten(0, 1)
 
-            subblock_ids = (
-                (base_subblock_ids + index) if index > 0 else base_subblock_ids
-            ).flatten(0, 1)
+        key_ids = base_subblock_ids.flatten(0, 1)
+        value_ids = base_subblock_ids.flatten(0, 1) + 1
 
-            subblock_table.index_copy_(0, subblock_ids, part_block_view)
+        subblock_table.index_copy_(0, key_ids, key_reshaped)
+        subblock_table.index_copy_(0, value_ids, value_reshaped)
diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py
index 64c7c65d9..2672409f9 100644
--- a/sharktank/sharktank/layers/paged_llama_attention_block.py
+++ b/sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -248,7 +248,8 @@ def transact_cache(
             # Prefill: Write the entire cache.
             cache.write(
                 cache_state,
-                cache_partitions=[xk_cache_update, xv_cache_update],
+                xk_cache_update,
+                xv_cache_update,
                 transformer_block_index=self.block_index,
                 page_ids=seq_block_ids,
             )
@@ -267,10 +268,8 @@
         # Write our one updated cache row into the cache.
         cache.write_timestep(
             cache_state,
-            cache_partitions=[
-                xk_cache_update,
-                xv_cache_update,
-            ],
+            xk_cache_update,
+            xv_cache_update,
             transformer_block_index=self.block_index,
             seq_positions=start_positions,
             page_ids=seq_block_ids,
diff --git a/sharktank/sharktank/utils/create_cache.py b/sharktank/sharktank/utils/create_cache.py
index f462d9c00..03eb637b6 100644
--- a/sharktank/sharktank/utils/create_cache.py
+++ b/sharktank/sharktank/utils/create_cache.py
@@ -16,7 +16,6 @@ def create_paged_kv_cache(config: LlamaModelConfig) -> PagedKVCache:
         transformer_block_count=hp.block_count,
         attn_head_count=hp.attention_head_count_kv,
         attn_head_dim=hp.attn_head_dim,
-        cache_partition_count=2,  # One for each of K/V.
         block_seq_stride=config.block_seq_stride,
         device=config.device,
         dtype=config.attention_dtype,
diff --git a/sharktank/tests/layers/kv_cache_test.py b/sharktank/tests/layers/kv_cache_test.py
index d59d0a85b..fd0d2da25 100644
--- a/sharktank/tests/layers/kv_cache_test.py
+++ b/sharktank/tests/layers/kv_cache_test.py
@@ -48,7 +48,8 @@ def test_paged():
 
     cache.write(
         allocation,
-        cache_partitions=[write_ones, write_twos],
+        write_ones,
+        write_twos,
         transformer_block_index=1,
         page_ids=write_page_ids,
     )
@@ -85,7 +86,8 @@
     write_pos = torch.full((bs,), write_seq_length, dtype=torch.int64)
     cache.write_timestep(
         allocation,
-        cache_partitions=[write_threes, write_fours],
+        write_threes,
+        write_fours,
         transformer_block_index=1,
         seq_positions=write_pos,
         page_ids=page_ids,
@@ -154,7 +156,8 @@ def test_sharded_paged():
 
     cache.write(
         allocation,
-        cache_partitions=[write_ones, write_twos],
+        write_ones,
+        write_twos,
         transformer_block_index=1,
         page_ids=write_page_ids,
     )
@@ -187,7 +190,8 @@
 
     cache.write_timestep(
         allocation,
-        cache_partitions=[write_threes, write_fours],
+        write_threes,
+        write_fours,
         transformer_block_index=1,
         seq_positions=write_pos,
         page_ids=page_ids,
diff --git a/sharktank/tests/layers/paged_llama_attention_block_test.py b/sharktank/tests/layers/paged_llama_attention_block_test.py
index e74a14ad5..2e66c0313 100644
--- a/sharktank/tests/layers/paged_llama_attention_block_test.py
+++ b/sharktank/tests/layers/paged_llama_attention_block_test.py
@@ -35,7 +35,6 @@ def setUp(self):
         self.attention_head_dim = 11 * 2
         self.rms_epsilon = 0.01
         self.block_seq_stride = 17
-        self.cache_partition_count = 2
         self.page_count = 23
         self.embedding_length = self.attention_head_count * self.attention_head_dim
         self.rope_dimension_count = self.attention_head_dim
@@ -52,7 +51,6 @@ def testExportDecomposed(self):
             transformer_block_count=self.transformer_block_count,
             attn_head_count=self.head_count_kv,
             attn_head_dim=self.attention_head_dim,
-            cache_partition_count=self.cache_partition_count,
             block_seq_stride=self.block_seq_stride,
             dtype=dtype,
         )
@@ -135,7 +133,6 @@ def testExportNondecomposed(self):
             transformer_block_count=self.transformer_block_count,
             attn_head_count=self.head_count_kv,
             attn_head_dim=self.attention_head_dim,
-            cache_partition_count=self.cache_partition_count,
             block_seq_stride=self.block_seq_stride,
             dtype=dtype,
         )
diff --git a/sharktank/tests/layers/sharded_paged_kv_cache_test.py b/sharktank/tests/layers/sharded_paged_kv_cache_test.py
index 833e2ce71..04b3d1052 100644
--- a/sharktank/tests/layers/sharded_paged_kv_cache_test.py
+++ b/sharktank/tests/layers/sharded_paged_kv_cache_test.py
@@ -26,7 +26,6 @@ def setUp(self):
         self.attn_head_count = self.shard_count * 7
         self.block_seq_stride = 19
         self.attn_head_dim = 17
-        self.cache_partition_count = 2
         self.page_count = 23
         self.batch_size = 11
         self.block_seq_len = 2
@@ -37,7 +36,6 @@
             attn_head_count=self.attn_head_count,
             block_seq_stride=self.block_seq_stride,
             attn_head_dim=self.attn_head_dim,
-            cache_partition_count=self.cache_partition_count,
             dtype=self.dtype,
         )
         self.sharded_cache = PagedKVCache(
@@ -46,7 +44,6 @@
             attn_head_count=self.attn_head_count,
             block_seq_stride=self.block_seq_stride,
             attn_head_dim=self.attn_head_dim,
-            cache_partition_count=self.cache_partition_count,
             dtype=self.dtype,
         )
 
@@ -137,7 +134,7 @@ def testWriteTimestep(self):
                 self.attn_head_count,
                 self.attn_head_dim,
             )
-            for _ in range(self.cache_partition_count)
+            for _ in range(2)
         ]
         transformer_block_index = 1
         seq_positions = torch.randint(
@@ -148,7 +145,8 @@
         )
         self.cache.write_timestep(
             state=cache_state,
-            cache_partitions=cache_partitions,
+            key=cache_partitions[0],
+            value=cache_partitions[1],
             transformer_block_index=transformer_block_index,
             seq_positions=seq_positions,
             page_ids=page_ids,
@@ -163,7 +161,8 @@
         sharded_page_ids = ops.replicate(page_ids, count=self.shard_count)
         self.sharded_cache.write_timestep(
             state=sharded_cache_state,
-            cache_partitions=sharded_cache_partitions,
+            key=sharded_cache_partitions[0],
+            value=sharded_cache_partitions[1],
             transformer_block_index=transformer_block_index,
             seq_positions=sharded_seq_positions,
             page_ids=sharded_page_ids,
@@ -185,7 +184,7 @@ def testWrite(self):
                 self.attn_head_count,
                 self.attn_head_dim,
             )
-            for _ in range(self.cache_partition_count)
+            for _ in range(2)
         ]
         transformer_block_index = 1
         assert self.batch_size * self.block_seq_len <= self.page_count
@@ -194,7 +193,8 @@
         )
         self.cache.write(
             state=cache_state,
-            cache_partitions=cache_partitions,
+            key=cache_partitions[0],
+            value=cache_partitions[1],
             transformer_block_index=transformer_block_index,
             page_ids=page_ids,
         )
@@ -207,7 +207,8 @@
         sharded_page_ids = ops.replicate(page_ids, count=self.shard_count)
         self.sharded_cache.write(
             state=sharded_cache_state,
-            cache_partitions=sharded_cache_partitions,
+            key=sharded_cache_partitions[0],
+            value=sharded_cache_partitions[1],
             transformer_block_index=transformer_block_index,
             page_ids=sharded_page_ids,
         )
diff --git a/sharktank/tests/layers/sharded_paged_llama_attention_block.py b/sharktank/tests/layers/sharded_paged_llama_attention_block.py
index 11a2d90a7..f6af27d86 100644
--- a/sharktank/tests/layers/sharded_paged_llama_attention_block.py
+++ b/sharktank/tests/layers/sharded_paged_llama_attention_block.py
@@ -33,7 +33,6 @@ def setUp(self):
         self.attention_head_dim = 11 * 2
         self.rms_epsilon = 0.01
         self.block_seq_stride = 17
-        self.cache_partition_count = 2
         self.page_count = 23
         self.embedding_length = self.attention_head_count * self.attention_head_dim
         self.rope_dimension_count = self.attention_head_dim
@@ -62,7 +61,6 @@ def make_paged_kv_cache(shard_count: int) -> PagedKVCache:
                 transformer_block_count=self.transformer_block_count,
                 attn_head_count=self.head_count_kv,
                 attn_head_dim=self.attention_head_dim,
-                cache_partition_count=self.cache_partition_count,
                 block_seq_stride=self.block_seq_stride,
                 dtype=dtype,
                 shard_count=shard_count,
diff --git a/sharktank/tests/models/llama/attention_test.py b/sharktank/tests/models/llama/attention_test.py
index 22013635b..dffe2173c 100644
--- a/sharktank/tests/models/llama/attention_test.py
+++ b/sharktank/tests/models/llama/attention_test.py
@@ -46,7 +46,6 @@ def test(self):
             transformer_block_count=head_count,
             attn_head_count=head_count,
             attn_head_dim=head_dim,
-            cache_partition_count=2,  # One for each of K/V.
             block_seq_stride=block_seq_stride,
             device="cpu",
             dtype=torch.float32,
diff --git a/sharktank/tests/models/llama/kv_cache_test.py b/sharktank/tests/models/llama/kv_cache_test.py
index 3d43243b0..f80d40c9c 100644
--- a/sharktank/tests/models/llama/kv_cache_test.py
+++ b/sharktank/tests/models/llama/kv_cache_test.py
@@ -43,7 +43,6 @@ def setUp(self):
             transformer_block_count=self.head_count,
             attn_head_count=self.head_count,
             attn_head_dim=self.head_dim,
-            cache_partition_count=2,  # One for each of K/V.
             block_seq_stride=self.block_seq_stride,
             device=self.device,
             dtype=self.attention_dtype,
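
Not part of the patch: a minimal sketch of the updated call pattern, for reviewers. It assumes the `allocate(page_count=...)` helper the existing tests use to build cache state, assumes the import path follows this repository layout (`sharktank.layers.kv_cache`), and infers the expected shapes from the hunks above (`[bs, 1, attn_head_count, attn_head_dim]` per timestep, `[bs, block_seq_len * block_seq_stride, attn_head_count, attn_head_dim]` for a full prefill write). All sizes and variable names below are illustrative only.

    import torch

    from sharktank.layers.kv_cache import PagedKVCache

    # The partition count is now fixed at 2 (one K partition, one V partition);
    # the former cache_partition_count constructor argument no longer exists.
    cache = PagedKVCache(
        transformer_block_count=4,
        attn_head_count=8,
        attn_head_dim=16,
        block_seq_stride=16,
        dtype=torch.float32,
    )

    bs, block_seq_len = 2, 3
    page_ids = torch.arange(bs * block_seq_len).view(bs, block_seq_len)
    # Assumed helper (as used by the tests) returning the flat cache state list.
    state = cache.allocate(page_count=bs * block_seq_len)

    # Prefill: key and value are now passed positionally instead of
    # cache_partitions=[key, value].
    key = torch.rand(bs, block_seq_len * 16, 8, 16)
    value = torch.rand(bs, block_seq_len * 16, 8, 16)
    cache.write(
        state,
        key,
        value,
        transformer_block_index=1,
        page_ids=page_ids,
    )

    # Decode: write_timestep takes one [bs, 1, heads, dim] key/value pair plus
    # the current position of each sequence.
    key_step = torch.rand(bs, 1, 8, 16)
    value_step = torch.rand(bs, 1, 8, 16)
    cache.write_timestep(
        state,
        key_step,
        value_step,
        transformer_block_index=1,
        seq_positions=torch.full((bs,), block_seq_len * 16 - 1, dtype=torch.int64),
        page_ids=page_ids,
    )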