[Llama] Do not allow configurable partitions for KVCache
Groverkss committed Jan 22, 2025
1 parent d4298df commit a470e7c
Showing 10 changed files with 47 additions and 51 deletions.
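
In short: PagedKVCache now always holds exactly two cache partitions, K and V. The cache_partition_count constructor argument is removed, and write / write_timestep take explicit key and value tensors instead of a cache_partitions list. A minimal before/after sketch of a call site, using illustrative tensor names (xk, xv) and the keyword arguments that appear in the diffs below:

# Before this commit: K and V were passed as a list of cache partitions.
cache.write(
    cache_state,
    cache_partitions=[xk, xv],
    transformer_block_index=block_index,
    page_ids=seq_block_ids,
)

# After this commit: K and V are explicit positional arguments.
cache.write(
    cache_state,
    xk,  # key
    xv,  # value
    transformer_block_index=block_index,
    page_ids=seq_block_ids,
)
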
3 changes: 2 additions & 1 deletion sharktank/sharktank/export_layer/export_kv_cache.py
@@ -105,7 +105,8 @@ def _(model, state, partition_0, write_page_ids: torch.Tensor) -> torch.Tensor:
write_page_ids = replicate(write_page_ids, count=args.sharding)
cache.write(
state,
- cache_partitions=[partition_0, partition_0],
+ partition_0,
+ partition_0,
transformer_block_index=1,
page_ids=write_page_ids,
)
47 changes: 23 additions & 24 deletions sharktank/sharktank/layers/kv_cache.py
@@ -55,7 +55,6 @@ def __init__(
transformer_block_count: int,
attn_head_count: int,
attn_head_dim: int,
- cache_partition_count: int = 2,
block_seq_stride: int = 16,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
@@ -64,7 +63,6 @@
self.transformer_block_count = transformer_block_count
self.attn_head_count = attn_head_count
self.attn_head_dim = attn_head_dim
- self.cache_partition_count = cache_partition_count
self.block_seq_stride = block_seq_stride
self.shard_count = shard_count
if attn_head_count % shard_count != 0:
@@ -75,7 +73,7 @@ def __init__(
# Some derived values based on attributes.
self.sub_page_dims = [
self.transformer_block_count,
- self.cache_partition_count,
+ 2,
self.block_seq_stride,
self.attn_head_count // self.shard_count,
self.attn_head_dim,
@@ -115,7 +113,7 @@ def shard_state(
[
-1,
self.transformer_block_count,
- self.cache_partition_count,
+ 2,
self.block_seq_stride,
self.attn_head_count,
self.attn_head_dim,
@@ -195,8 +193,8 @@ def read(
# [page, attn_layer, cache_partition]
# Where the cache line can be 0 (k) or 1 (v).
subblock_table = page_table.flatten(start_dim=0, end_dim=2)
- page_stride = self.transformer_block_count * self.cache_partition_count
- transformer_block_stride = self.cache_partition_count
+ page_stride = self.transformer_block_count * 2
+ transformer_block_stride = 2
base_subblock_ids = page_ids * page_stride + (
transformer_block_index * transformer_block_stride
)
@@ -225,8 +223,10 @@ def read_cache_partition(index: int):
def write_timestep(
self,
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
- # List of [bs, 1, attn_head_count, attn_head_dim]
- cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+ # [bs, 1, attn_head_count, attn_head_dim]
+ key: Union[torch.Tensor, SplitPrimitiveTensor],
+ # [bs, 1, attn_head_count, attn_head_dim]
+ value: Union[torch.Tensor, SplitPrimitiveTensor],
*,
transformer_block_index: int,
# [bs]
@@ -242,10 +242,9 @@ def write_timestep(
device = self.device
page_table = self.unflatten_page_table(state) # 6D
bs, *_ = seq_positions.shape
- assert len(cache_partitions) == self.cache_partition_count

# [bs, 1, atten_head_count, attn_head_dim]
- for idx, cache_partition in enumerate(cache_partitions):
+ for idx, cache_partition in enumerate([key, value]):
# [bs, 1]
page_index = seq_positions // self.block_seq_stride

@@ -277,12 +276,11 @@ def write_timestep(
indices = (page_id, transformer_block, partitions, page_offset)
page_table.index_put_(indices=indices, values=cache_partition)

- return
-
def write(
self,
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
- cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+ key: Union[torch.Tensor, SplitPrimitiveTensor],
+ value: Union[torch.Tensor, SplitPrimitiveTensor],
*,
transformer_block_index: int,
page_ids: Union[torch.Tensor, ReplicatedTensor],
@@ -304,20 +302,21 @@ def write(
# [page, attn_layer, cache_partition]
# Where the cache line can be 0 (k) or 1 (v).
subblock_table = page_table.flatten(start_dim=0, end_dim=2)
- page_stride = self.transformer_block_count * self.cache_partition_count
- transformer_block_stride = self.cache_partition_count
+ page_stride = self.transformer_block_count * 2
+ transformer_block_stride = 2
base_subblock_ids = page_ids * page_stride + (
transformer_block_index * transformer_block_stride
)

- for index, partition in enumerate(cache_partitions):
- part_block_view = partition.unflatten(
- 1, (block_seq_len, self.block_seq_stride)
- )
- part_block_view = part_block_view.flatten(0, 1)
+ key_reshaped = key.unflatten(1, (block_seq_len, self.block_seq_stride)).flatten(
+ 0, 1
+ )
+ value_reshaped = value.unflatten(
+ 1, (block_seq_len, self.block_seq_stride)
+ ).flatten(0, 1)

- subblock_ids = (
- (base_subblock_ids + index) if index > 0 else base_subblock_ids
- ).flatten(0, 1)
+ key_ids = base_subblock_ids.flatten(0, 1)
+ value_ids = base_subblock_ids.flatten(0, 1) + 1

- subblock_table.index_copy_(0, subblock_ids, part_block_view)
+ subblock_table.index_copy_(0, key_ids, key_reshaped)
+ subblock_table.index_copy_(0, value_ids, value_reshaped)
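
For context on the hard-coded 2s in this file: the page table's leading [page, transformer_block, partition] dimensions are flattened, and with exactly two partitions the K row of a block always lands at an even offset with its V row immediately after it. The snippet below is a small standalone sketch of that index arithmetic, not repository code; the sizes and page_ids are made up for illustration.

import torch

# Hypothetical sizes, chosen only to make the arithmetic visible.
transformer_block_count = 4
page_ids = torch.tensor([[0, 2]])  # [bs=1, block_seq_len=2]
transformer_block_index = 1

# Strides over the flattened [page, transformer_block, partition] prefix,
# with partition 0 = K and partition 1 = V (as in the hunks above).
page_stride = transformer_block_count * 2
transformer_block_stride = 2

base_subblock_ids = page_ids * page_stride + transformer_block_index * transformer_block_stride
key_ids = base_subblock_ids.flatten(0, 1)        # flat rows holding K for this block
value_ids = base_subblock_ids.flatten(0, 1) + 1  # V always directly follows its K row

print(key_ids.tolist(), value_ids.tolist())  # [2, 18] [3, 19]
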
9 changes: 4 additions & 5 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -248,7 +248,8 @@ def transact_cache(
# Prefill: Write the entire cache.
cache.write(
cache_state,
- cache_partitions=[xk_cache_update, xv_cache_update],
+ xk_cache_update,
+ xv_cache_update,
transformer_block_index=self.block_index,
page_ids=seq_block_ids,
)
@@ -267,10 +268,8 @@ def transact_cache(
# Write our one updated cache row into the cache.
cache.write_timestep(
cache_state,
- cache_partitions=[
- xk_cache_update,
- xv_cache_update,
- ],
+ xk_cache_update,
+ xv_cache_update,
transformer_block_index=self.block_index,
seq_positions=start_positions,
page_ids=seq_block_ids,
1 change: 0 additions & 1 deletion sharktank/sharktank/utils/create_cache.py
@@ -16,7 +16,6 @@ def create_paged_kv_cache(config: LlamaModelConfig) -> PagedKVCache:
transformer_block_count=hp.block_count,
attn_head_count=hp.attention_head_count_kv,
attn_head_dim=hp.attn_head_dim,
- cache_partition_count=2,  # One for each of K/V.
block_seq_stride=config.block_seq_stride,
device=config.device,
dtype=config.attention_dtype,
12 changes: 8 additions & 4 deletions sharktank/tests/layers/kv_cache_test.py
@@ -48,7 +48,8 @@ def test_paged():

cache.write(
allocation,
- cache_partitions=[write_ones, write_twos],
+ write_ones,
+ write_twos,
transformer_block_index=1,
page_ids=write_page_ids,
)
@@ -85,7 +86,8 @@ def test_paged():
write_pos = torch.full((bs,), write_seq_length, dtype=torch.int64)
cache.write_timestep(
allocation,
- cache_partitions=[write_threes, write_fours],
+ write_threes,
+ write_fours,
transformer_block_index=1,
seq_positions=write_pos,
page_ids=page_ids,
@@ -154,7 +156,8 @@ def test_sharded_paged():

cache.write(
allocation,
- cache_partitions=[write_ones, write_twos],
+ write_ones,
+ write_twos,
transformer_block_index=1,
page_ids=write_page_ids,
)
@@ -187,7 +190,8 @@ def test_sharded_paged():

cache.write_timestep(
allocation,
- cache_partitions=[write_threes, write_fours],
+ write_threes,
+ write_fours,
transformer_block_index=1,
seq_positions=write_pos,
page_ids=page_ids,
3 changes: 0 additions & 3 deletions sharktank/tests/layers/paged_llama_attention_block_test.py
@@ -35,7 +35,6 @@ def setUp(self):
self.attention_head_dim = 11 * 2
self.rms_epsilon = 0.01
self.block_seq_stride = 17
- self.cache_partition_count = 2
self.page_count = 23
self.embedding_length = self.attention_head_count * self.attention_head_dim
self.rope_dimension_count = self.attention_head_dim
@@ -52,7 +51,6 @@ def testExportDecomposed(self):
transformer_block_count=self.transformer_block_count,
attn_head_count=self.head_count_kv,
attn_head_dim=self.attention_head_dim,
- cache_partition_count=self.cache_partition_count,
block_seq_stride=self.block_seq_stride,
dtype=dtype,
)
@@ -135,7 +133,6 @@ def testExportNondecomposed(self):
transformer_block_count=self.transformer_block_count,
attn_head_count=self.head_count_kv,
attn_head_dim=self.attention_head_dim,
- cache_partition_count=self.cache_partition_count,
block_seq_stride=self.block_seq_stride,
dtype=dtype,
)
19 changes: 10 additions & 9 deletions sharktank/tests/layers/sharded_paged_kv_cache_test.py
@@ -26,7 +26,6 @@ def setUp(self):
self.attn_head_count = self.shard_count * 7
self.block_seq_stride = 19
self.attn_head_dim = 17
- self.cache_partition_count = 2
self.page_count = 23
self.batch_size = 11
self.block_seq_len = 2
@@ -37,7 +36,6 @@ def setUp(self):
attn_head_count=self.attn_head_count,
block_seq_stride=self.block_seq_stride,
attn_head_dim=self.attn_head_dim,
- cache_partition_count=self.cache_partition_count,
dtype=self.dtype,
)
self.sharded_cache = PagedKVCache(
@@ -46,7 +44,6 @@ def setUp(self):
attn_head_count=self.attn_head_count,
block_seq_stride=self.block_seq_stride,
attn_head_dim=self.attn_head_dim,
- cache_partition_count=self.cache_partition_count,
dtype=self.dtype,
)

@@ -137,7 +134,7 @@ def testWriteTimestep(self):
self.attn_head_count,
self.attn_head_dim,
)
- for _ in range(self.cache_partition_count)
+ for _ in range(2)
]
transformer_block_index = 1
seq_positions = torch.randint(
@@ -148,7 +145,8 @@ def testWriteTimestep(self):
)
self.cache.write_timestep(
state=cache_state,
- cache_partitions=cache_partitions,
+ key=cache_partitions[0],
+ value=cache_partitions[1],
transformer_block_index=transformer_block_index,
seq_positions=seq_positions,
page_ids=page_ids,
@@ -163,7 +161,8 @@ def testWriteTimestep(self):
sharded_page_ids = ops.replicate(page_ids, count=self.shard_count)
self.sharded_cache.write_timestep(
state=sharded_cache_state,
- cache_partitions=sharded_cache_partitions,
+ key=sharded_cache_partitions[0],
+ value=sharded_cache_partitions[1],
transformer_block_index=transformer_block_index,
seq_positions=sharded_seq_positions,
page_ids=sharded_page_ids,
@@ -185,7 +184,7 @@ def testWrite(self):
self.attn_head_count,
self.attn_head_dim,
)
- for _ in range(self.cache_partition_count)
+ for _ in range(2)
]
transformer_block_index = 1
assert self.batch_size * self.block_seq_len <= self.page_count
@@ -194,7 +193,8 @@ def testWrite(self):
)
self.cache.write(
state=cache_state,
- cache_partitions=cache_partitions,
+ key=cache_partitions[0],
+ value=cache_partitions[1],
transformer_block_index=transformer_block_index,
page_ids=page_ids,
)
@@ -207,7 +207,8 @@ def testWrite(self):
sharded_page_ids = ops.replicate(page_ids, count=self.shard_count)
self.sharded_cache.write(
state=sharded_cache_state,
- cache_partitions=sharded_cache_partitions,
+ key=sharded_cache_partitions[0],
+ value=sharded_cache_partitions[1],
transformer_block_index=transformer_block_index,
page_ids=sharded_page_ids,
)
2 changes: 0 additions & 2 deletions sharktank/tests/layers/sharded_paged_llama_attention_block.py
@@ -33,7 +33,6 @@ def setUp(self):
self.attention_head_dim = 11 * 2
self.rms_epsilon = 0.01
self.block_seq_stride = 17
- self.cache_partition_count = 2
self.page_count = 23
self.embedding_length = self.attention_head_count * self.attention_head_dim
self.rope_dimension_count = self.attention_head_dim
@@ -62,7 +61,6 @@ def make_paged_kv_cache(shard_count: int) -> PagedKVCache:
transformer_block_count=self.transformer_block_count,
attn_head_count=self.head_count_kv,
attn_head_dim=self.attention_head_dim,
- cache_partition_count=self.cache_partition_count,
block_seq_stride=self.block_seq_stride,
dtype=dtype,
shard_count=shard_count,
1 change: 0 additions & 1 deletion sharktank/tests/models/llama/attention_test.py
@@ -46,7 +46,6 @@ def test(self):
transformer_block_count=head_count,
attn_head_count=head_count,
attn_head_dim=head_dim,
- cache_partition_count=2,  # One for each of K/V.
block_seq_stride=block_seq_stride,
device="cpu",
dtype=torch.float32,
1 change: 0 additions & 1 deletion sharktank/tests/models/llama/kv_cache_test.py
@@ -43,7 +43,6 @@ def setUp(self):
transformer_block_count=self.head_count,
attn_head_count=self.head_count,
attn_head_dim=self.head_dim,
- cache_partition_count=2,  # One for each of K/V.
block_seq_stride=self.block_seq_stride,
device=self.device,
dtype=self.attention_dtype,
