[Llama] Do not allow configurable partitions for KVCache #856

Open · wants to merge 2 commits into main
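This PR removes the configurable `cache_partition_count` from `PagedKVCache` (the partition count is now fixed at 2, one partition each for keys and values) and changes `write` / `write_timestep` to take explicit `key` and `value` tensors instead of a `cache_partitions` list. A minimal before/after sketch of a caller, using the argument names that appear in the diff; the surrounding setup is elided and illustrative only:

# Before: K and V passed as a list of cache partitions.
cache.write(
    cache_state,
    cache_partitions=[xk_cache_update, xv_cache_update],
    transformer_block_index=block_index,
    page_ids=seq_block_ids,
)

# After: K and V are explicit positional arguments; the cache always
# holds exactly two partitions internally.
cache.write(
    cache_state,
    xk_cache_update,
    xv_cache_update,
    transformer_block_index=block_index,
    page_ids=seq_block_ids,
)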
3 changes: 2 additions & 1 deletion sharktank/sharktank/export_layer/export_kv_cache.py
@@ -105,7 +105,8 @@ def _(model, state, partition_0, write_page_ids: torch.Tensor) -> torch.Tensor:
write_page_ids = replicate(write_page_ids, count=args.sharding)
cache.write(
state,
- cache_partitions=[partition_0, partition_0],
+ partition_0,
+ partition_0,
transformer_block_index=1,
page_ids=write_page_ids,
)
36 changes: 18 additions & 18 deletions sharktank/sharktank/layers/kv_cache.py
@@ -55,7 +55,6 @@ def __init__(
transformer_block_count: int,
attn_head_count: int,
attn_head_dim: int,
- cache_partition_count: int = 2,
block_seq_stride: int = 16,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
Expand All @@ -64,7 +63,7 @@ def __init__(
self.transformer_block_count = transformer_block_count
self.attn_head_count = attn_head_count
self.attn_head_dim = attn_head_dim
- self.cache_partition_count = cache_partition_count
+ self.cache_partition_count = 2
self.block_seq_stride = block_seq_stride
self.shard_count = shard_count
if attn_head_count % shard_count != 0:
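Fixing `self.cache_partition_count = 2` means the paged cache always carries one key plane and one value plane per transformer block. A small sketch of the page-table shape this implies; the dimension ordering is an assumption inferred from the 6-D `unflatten_page_table` and the sub-block arithmetic below, not a quote of the implementation:

import torch

# Assumed 6-D layout:
# [page_count, transformer_block_count, 2 (K/V), block_seq_stride, attn_head_count, attn_head_dim]
page_count, block_count, stride, heads, head_dim = 8, 4, 16, 2, 32
page_table = torch.zeros(page_count, block_count, 2, stride, heads, head_dim)

KEY, VALUE = 0, 1  # the two fixed cache partitions
k_plane = page_table[3, 1, KEY]      # keys for page 3, transformer block 1
v_plane = page_table[3, 1, VALUE]    # values for the same page and block
print(k_plane.shape, v_plane.shape)  # torch.Size([16, 2, 32]) for both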
@@ -225,8 +224,10 @@ def read_cache_partition(index: int):
def write_timestep(
self,
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
- # List of [bs, 1, attn_head_count, attn_head_dim]
- cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+ # [bs, 1, attn_head_count, attn_head_dim]
+ key: Union[torch.Tensor, SplitPrimitiveTensor],
+ # [bs, 1, attn_head_count, attn_head_dim]
+ value: Union[torch.Tensor, SplitPrimitiveTensor],
*,
transformer_block_index: int,
# [bs]
@@ -242,10 +243,9 @@
device = self.device
page_table = self.unflatten_page_table(state) # 6D
bs, *_ = seq_positions.shape
- assert len(cache_partitions) == self.cache_partition_count

# [bs, 1, atten_head_count, attn_head_dim]
- for idx, cache_partition in enumerate(cache_partitions):
+ for idx, cache_partition in enumerate([key, value]):
# [bs, 1]
page_index = seq_positions // self.block_seq_stride

@@ -277,12 +277,11 @@
indices = (page_id, transformer_block, partitions, page_offset)
page_table.index_put_(indices=indices, values=cache_partition)

- return
-
def write(
self,
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
- cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+ key: Union[torch.Tensor, SplitPrimitiveTensor],
+ value: Union[torch.Tensor, SplitPrimitiveTensor],
*,
transformer_block_index: int,
page_ids: Union[torch.Tensor, ReplicatedTensor],
@@ -310,14 +309,15 @@
transformer_block_index * transformer_block_stride
)

- for index, partition in enumerate(cache_partitions):
- part_block_view = partition.unflatten(
- 1, (block_seq_len, self.block_seq_stride)
- )
- part_block_view = part_block_view.flatten(0, 1)
+ key_reshaped = key.unflatten(1, (block_seq_len, self.block_seq_stride)).flatten(
+ 0, 1
+ )
+ value_reshaped = value.unflatten(
+ 1, (block_seq_len, self.block_seq_stride)
+ ).flatten(0, 1)

- subblock_ids = (
- (base_subblock_ids + index) if index > 0 else base_subblock_ids
- ).flatten(0, 1)
+ key_ids = base_subblock_ids.flatten(0, 1)
+ value_ids = base_subblock_ids.flatten(0, 1) + 1

- subblock_table.index_copy_(0, subblock_ids, part_block_view)
+ subblock_table.index_copy_(0, key_ids, key_reshaped)
+ subblock_table.index_copy_(0, value_ids, value_reshaped)
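The rewritten `write` relies on keys and values occupying adjacent sub-blocks of the flattened page table, which is why `value_ids` is simply the key sub-block index plus one. A self-contained sketch of that index arithmetic; the stride formula and toy shapes are assumptions for illustration, not the exact implementation:

import torch

transformer_block_count = 4
cache_partition_count = 2                       # fixed: key and value
transformer_block_stride = cache_partition_count
page_stride = transformer_block_count * cache_partition_count  # sub-blocks per page (assumed)

page_ids = torch.tensor([[5, 9]])               # [bs, block_seq_len]
transformer_block_index = 1

base_subblock_ids = page_ids * page_stride + transformer_block_index * transformer_block_stride
key_ids = base_subblock_ids.flatten(0, 1)        # sub-blocks that hold keys
value_ids = base_subblock_ids.flatten(0, 1) + 1  # adjacent sub-blocks that hold values
print(key_ids.tolist(), value_ids.tolist())      # [42, 74] [43, 75]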
9 changes: 4 additions & 5 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -248,7 +248,8 @@ def transact_cache(
# Prefill: Write the entire cache.
cache.write(
cache_state,
- cache_partitions=[xk_cache_update, xv_cache_update],
+ xk_cache_update,
+ xv_cache_update,
transformer_block_index=self.block_index,
page_ids=seq_block_ids,
)
@@ -267,10 +268,8 @@ def transact_cache(
# Write our one updated cache row into the cache.
cache.write_timestep(
cache_state,
- cache_partitions=[
- xk_cache_update,
- xv_cache_update,
- ],
+ xk_cache_update,
+ xv_cache_update,
transformer_block_index=self.block_index,
seq_positions=start_positions,
page_ids=seq_block_ids,
1 change: 0 additions & 1 deletion sharktank/sharktank/utils/create_cache.py
@@ -16,7 +16,6 @@ def create_paged_kv_cache(config: LlamaModelConfig) -> PagedKVCache:
transformer_block_count=hp.block_count,
attn_head_count=hp.attention_head_count_kv,
attn_head_dim=hp.attn_head_dim,
- cache_partition_count=2, # One for each of K/V.
block_seq_stride=config.block_seq_stride,
device=config.device,
dtype=config.attention_dtype,
12 changes: 8 additions & 4 deletions sharktank/tests/layers/kv_cache_test.py
@@ -48,7 +48,8 @@ def test_paged():

cache.write(
allocation,
- cache_partitions=[write_ones, write_twos],
+ write_ones,
+ write_twos,
transformer_block_index=1,
page_ids=write_page_ids,
)
@@ -85,7 +86,8 @@ def test_paged():
write_pos = torch.full((bs,), write_seq_length, dtype=torch.int64)
cache.write_timestep(
allocation,
- cache_partitions=[write_threes, write_fours],
+ write_threes,
+ write_fours,
transformer_block_index=1,
seq_positions=write_pos,
page_ids=page_ids,
@@ -154,7 +156,8 @@ def test_sharded_paged():

cache.write(
allocation,
- cache_partitions=[write_ones, write_twos],
+ write_ones,
+ write_twos,
transformer_block_index=1,
page_ids=write_page_ids,
)
@@ -187,7 +190,8 @@ def test_sharded_paged():

cache.write_timestep(
allocation,
- cache_partitions=[write_threes, write_fours],
+ write_threes,
+ write_fours,
transformer_block_index=1,
seq_positions=write_pos,
page_ids=page_ids,
3 changes: 0 additions & 3 deletions sharktank/tests/layers/paged_llama_attention_block_test.py
@@ -35,7 +35,6 @@ def setUp(self):
self.attention_head_dim = 11 * 2
self.rms_epsilon = 0.01
self.block_seq_stride = 17
- self.cache_partition_count = 2
self.page_count = 23
self.embedding_length = self.attention_head_count * self.attention_head_dim
self.rope_dimension_count = self.attention_head_dim
@@ -52,7 +51,6 @@ def testExportDecomposed(self):
transformer_block_count=self.transformer_block_count,
attn_head_count=self.head_count_kv,
attn_head_dim=self.attention_head_dim,
- cache_partition_count=self.cache_partition_count,
block_seq_stride=self.block_seq_stride,
dtype=dtype,
)
@@ -135,7 +133,6 @@ def testExportNondecomposed(self):
transformer_block_count=self.transformer_block_count,
attn_head_count=self.head_count_kv,
attn_head_dim=self.attention_head_dim,
- cache_partition_count=self.cache_partition_count,
block_seq_stride=self.block_seq_stride,
dtype=dtype,
)
19 changes: 10 additions & 9 deletions sharktank/tests/layers/sharded_paged_kv_cache_test.py
@@ -26,7 +26,6 @@ def setUp(self):
self.attn_head_count = self.shard_count * 7
self.block_seq_stride = 19
self.attn_head_dim = 17
- self.cache_partition_count = 2
self.page_count = 23
self.batch_size = 11
self.block_seq_len = 2
@@ -37,7 +36,6 @@ def setUp(self):
attn_head_count=self.attn_head_count,
block_seq_stride=self.block_seq_stride,
attn_head_dim=self.attn_head_dim,
- cache_partition_count=self.cache_partition_count,
dtype=self.dtype,
)
self.sharded_cache = PagedKVCache(
@@ -46,7 +44,6 @@ def setUp(self):
attn_head_count=self.attn_head_count,
block_seq_stride=self.block_seq_stride,
attn_head_dim=self.attn_head_dim,
- cache_partition_count=self.cache_partition_count,
dtype=self.dtype,
)

@@ -137,7 +134,7 @@ def testWriteTimestep(self):
self.attn_head_count,
self.attn_head_dim,
)
- for _ in range(self.cache_partition_count)
+ for _ in range(2)
]
transformer_block_index = 1
seq_positions = torch.randint(
@@ -148,7 +145,8 @@
)
self.cache.write_timestep(
state=cache_state,
- cache_partitions=cache_partitions,
+ key=cache_partitions[0],
+ value=cache_partitions[1],
transformer_block_index=transformer_block_index,
seq_positions=seq_positions,
page_ids=page_ids,
@@ -163,7 +161,8 @@
sharded_page_ids = ops.replicate(page_ids, count=self.shard_count)
self.sharded_cache.write_timestep(
state=sharded_cache_state,
- cache_partitions=sharded_cache_partitions,
+ key=sharded_cache_partitions[0],
+ value=sharded_cache_partitions[1],
transformer_block_index=transformer_block_index,
seq_positions=sharded_seq_positions,
page_ids=sharded_page_ids,
@@ -185,7 +184,7 @@ def testWrite(self):
self.attn_head_count,
self.attn_head_dim,
)
- for _ in range(self.cache_partition_count)
+ for _ in range(2)
]
transformer_block_index = 1
assert self.batch_size * self.block_seq_len <= self.page_count
@@ -194,7 +193,8 @@
)
self.cache.write(
state=cache_state,
- cache_partitions=cache_partitions,
+ key=cache_partitions[0],
+ value=cache_partitions[1],
transformer_block_index=transformer_block_index,
page_ids=page_ids,
)
@@ -207,7 +207,8 @@
sharded_page_ids = ops.replicate(page_ids, count=self.shard_count)
self.sharded_cache.write(
state=sharded_cache_state,
- cache_partitions=sharded_cache_partitions,
+ key=sharded_cache_partitions[0],
+ value=sharded_cache_partitions[1],
transformer_block_index=transformer_block_index,
page_ids=sharded_page_ids,
)
2 changes: 0 additions & 2 deletions sharktank/tests/layers/sharded_paged_llama_attention_block.py
@@ -33,7 +33,6 @@ def setUp(self):
self.attention_head_dim = 11 * 2
self.rms_epsilon = 0.01
self.block_seq_stride = 17
- self.cache_partition_count = 2
self.page_count = 23
self.embedding_length = self.attention_head_count * self.attention_head_dim
self.rope_dimension_count = self.attention_head_dim
@@ -62,7 +61,6 @@ def make_paged_kv_cache(shard_count: int) -> PagedKVCache:
transformer_block_count=self.transformer_block_count,
attn_head_count=self.head_count_kv,
attn_head_dim=self.attention_head_dim,
- cache_partition_count=self.cache_partition_count,
block_seq_stride=self.block_seq_stride,
dtype=dtype,
shard_count=shard_count,
1 change: 0 additions & 1 deletion sharktank/tests/models/llama/attention_test.py
@@ -46,7 +46,6 @@ def test(self):
transformer_block_count=head_count,
attn_head_count=head_count,
attn_head_dim=head_dim,
- cache_partition_count=2, # One for each of K/V.
block_seq_stride=block_seq_stride,
device="cpu",
dtype=torch.float32,
1 change: 0 additions & 1 deletion sharktank/tests/models/llama/kv_cache_test.py
@@ -43,7 +43,6 @@ def setUp(self):
transformer_block_count=self.head_count,
attn_head_count=self.head_count,
attn_head_dim=self.head_dim,
- cache_partition_count=2, # One for each of K/V.
block_seq_stride=self.block_seq_stride,
device=self.device,
dtype=self.attention_dtype,