Simplified kv caching #20

Draft
wants to merge 4 commits into base: add_flash_attention
10 changes: 3 additions & 7 deletions src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py
@@ -139,9 +139,7 @@ def __init__(
inference_runner=InferenceRunnerType.NO_RUNNER,
validate_runner_input=True,
pre_allocate_kv_cache=False,
max_sequence_length=None,
max_batch_size=None,
pad_key_length=True,
pad_key_length=None,
predict_last_token: bool = False,
**kwargs,
):
@@ -172,12 +170,10 @@ def __init__(
# Set to False to disable input validation of safe inputs, for a small speedup.
self.validate_runner_input = validate_runner_input

# Pre-allocate to sequence length `n_positions` (`True`) or the specified value (`int`)
self.pre_allocate_kv_cache = pre_allocate_kv_cache
# The max sequence length for the pre-allocated KV cache (`n_positions` if not provided).
self.max_sequence_length = max_sequence_length
# The max batch size for the pre-allocated KV cache, (deduce from input if not provided).
self.max_batch_size = max_batch_size
# Pad key length to a multiple of 8 (requires pre_allocate_kv_cache).
# Pad key length to a multiple of 8.
self.pad_key_length = pad_key_length

# Predict only the last token in inference even if the input is bigger.
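With these changes, cache behaviour is driven by just two knobs. A minimal usage sketch, assuming `GPTBigCodeConfig` is importable from this fork's top-level package (values are illustrative):

```python
from transformers import GPTBigCodeConfig

# Pre-allocate the per-layer KV cache for up to 2048 key positions and keep
# key-length padding enabled. Passing `True` pre-allocates to the default size,
# and leaving `pad_key_length=None` makes it follow `pre_allocate_kv_cache`.
config = GPTBigCodeConfig(
    pre_allocate_kv_cache=2048,
    pad_key_length=True,
)
```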
233 changes: 130 additions & 103 deletions src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -138,7 +138,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
self.mask_value = None

self.multi_query = config.multi_query
# TODO: check availability
self.seq_dim = -2 if self.multi_query else -1
self.flash_attention = config.flash_attention
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
@@ -163,12 +163,11 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
self.fused_softmax = config.fused_softmax

# KV caching and padding
self.kv_cache = None
self.kv_cache_max_batch_size = config.max_batch_size or 0
self.kv_cache_max_sequence_length = config.max_sequence_length or config.n_positions
self.pre_allocate_kv_cache = config.pre_allocate_kv_cache
self.pad_key_length = config.pad_key_length and config.pre_allocate_kv_cache
self._frozen_kv_cache = False
self.pre_allocate_kv_cache = (
config.n_embd if config.pre_allocate_kv_cache is True else config.pre_allocate_kv_cache
)
pad_key_length = config.pre_allocate_kv_cache if config.pad_key_length is None else config.pad_key_length
self._tuple_cache_format = self.pre_allocate_kv_cache or pad_key_length or self.flash_attention

if self.is_cross_attention:
raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.")
@@ -188,12 +187,6 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
if not self.multi_query:
# TODO: Flash Attention is implemented but not tested for MHA
raise ValueError("Flash Attention is not supported with multi-head attention.")
if self.pre_allocate_kv_cache:
raise ValueError("KV cache pre-allocation is not supported with Flash Attention")
assert not self.pre_allocate_kv_cache
self._attn_fn = self._attn_flash
else:
self._attn_fn = self._attn_mqa if self.multi_query else self._attn_mha

def _get_mask_value(self, device, dtype):
# torch.where expects a tensor. We use a cache to avoid recreating it every time.
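The body of `_get_mask_value` is collapsed in this view. For context, the caching pattern the comment refers to looks roughly like this in the existing model code (a sketch, not part of this diff):

```python
def _get_mask_value(self, device, dtype):
    # torch.where expects a tensor, so keep one filled with the dtype's minimum
    # and only rebuild it when the requested device or dtype changes.
    if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
        self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
    return self.mask_value
```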
@@ -248,22 +241,15 @@ def _attn(self, query, key, value, attention_mask, head_mask=None):
beta = 0
attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape)

if upcast:
# Use a fused kernel to prevent a large overhead from casting and scaling.
# Sub-optimal when the key length is not a multiple of 8.
if attention_mask is None:
attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype)
else:
mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)
attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype)
else:
if attention_mask is not None:
mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)

# The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion.
attn_weights = torch.where(attention_mask, attn_weights, mask_value)

attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_weights = softmax_function(
attn_weights,
attention_mask,
None if attention_mask is None else self._get_mask_value(attn_weights.device, softmax_dtype),
unscale,
softmax_dtype,
upcast,
self.fused_softmax,
)

attn_weights = self.attn_dropout(attn_weights)

@@ -312,40 +298,105 @@ def _attn_flash(self, query, key, value, attention_mask, head_mask=None):

return attn_output, None

def freeze_kv_cache(self, enable=True):
if self.kv_cache is None:
raise RuntimeError("KV cache not found.")
# Prevent re-allocation of the KV cache.
self._frozen_kv_cache = enable

def get_kv_cache(self, batch_size, sequence_length, device, dtype, allocate=True):
if (
self.kv_cache is None
or self.kv_cache.dtype != dtype
or self.kv_cache.device != device
or batch_size > self.kv_cache_max_batch_size
or sequence_length > self.kv_cache_max_sequence_length
):
if self._frozen_kv_cache or not allocate:
if self.kv_cache is None:
raise RuntimeError("KV cache not found.")
else:
raise RuntimeError(
f"Invalid KV cache: "
f"existing = {(self.kv_cache.dtype,self.kv_cache.device,self.kv_cache_max_batch_size,self.kv_cache_max_sequence_length)}, "
f"requested = {(dtype,device,batch_size,sequence_length)}"
)
# Free memory first.
self.kv_cache = None
self.kv_cache_max_sequence_length = max(sequence_length, self.kv_cache_max_sequence_length)
self.kv_cache_max_batch_size = max(batch_size, self.kv_cache_max_batch_size)
kv_cache_size = 2 * self.kv_cache_max_batch_size * self.kv_cache_max_sequence_length * self.kv_dim
self.kv_cache = torch.empty([kv_cache_size], device=device, dtype=dtype)
# This view ensures the cache is contiguous for all batch sizes.
kv_cache = self.kv_cache[: 2 * batch_size * self.kv_cache_max_sequence_length * self.kv_dim].view(
batch_size, self.kv_heads, self.kv_cache_max_sequence_length, 2 * self.head_dim
)
return kv_cache[:, 0, :sequence_length, :] if self.multi_query else kv_cache[:, :, :sequence_length, :]
def _re_allocate_kv_cache(self, kv_cache, key_length, padded_key_length, allocate_key_length):
batch_size = kv_cache.size(0)
assert not self.training
if self.multi_query:
allocated_kv_cache = torch.empty(
[batch_size, allocate_key_length, self.head_dim], dtype=kv_cache.dtype, device=kv_cache.device
)
allocated_kv_cache[:, :key_length].copy_(kv_cache)
padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
else:
allocated_kv_cache = torch.empty(
[batch_size, self.num_heads, allocate_key_length, self.head_dim],
dtype=kv_cache.dtype,
device=kv_cache.device,
)
allocated_kv_cache[:, :, :key_length].copy_(kv_cache)
padded_kv_cache = allocated_kv_cache[:, :, :padded_key_length]
return allocated_kv_cache, padded_kv_cache

def _merge_kv_caches(self, key_value, use_cache, layer_past, attention_mask):
flash_attention = self.flash_attention and layer_past is None

# Convert to standard KV cache format.
if flash_attention and use_cache:
_, padding_index, batch_size, max_sequence_length = attention_mask
current_kv_cache = pad_input(key_value, padding_index, batch_size, max_sequence_length)
if not self.multi_query:
current_kv_cache = current_kv_cache.view(
batch_size, max_sequence_length, self.num_heads, 2 * self.head_dim
).transpose(1, 2)
else:
current_kv_cache = key_value

# Calculate dimensions and recover layer_past
batch_size = current_kv_cache.size(0)
query_length = current_kv_cache.size(self.seq_dim)
if layer_past is None:
allocated_kv_cache, last_key_length = None, 0
last_kv_cache = None
key_length = query_length
allocated_key_length = key_length
else:
allocated_kv_cache, last_key_length = layer_past
last_kv_cache = (
allocated_kv_cache[:, :last_key_length]
if self.multi_query
else allocated_kv_cache[:, :, :last_key_length]
)
key_length = query_length + last_key_length
allocated_key_length = allocated_kv_cache.size(self.seq_dim)

padded_key_length = key_length if flash_attention else attention_mask.size(-1)
allocate_key_length = padded_key_length if use_cache else max(self.pre_allocate_kv_cache, padded_key_length)

# Re-allocate kv cache and copy last value
if allocate_key_length > allocated_key_length:
if self.multi_query:
allocated_kv_cache = torch.empty(
[batch_size, allocate_key_length, 2 * self.head_dim],
dtype=current_kv_cache.dtype,
device=current_kv_cache.device,
)
if layer_past is not None:
allocated_kv_cache[:, :last_key_length].copy_(last_kv_cache)
if allocate_key_length > key_length:
# NaNs in `value` can propagate through the matrix multiplication,
# so we set the remaining values to zero. (`last_key_length:key_length` is set below.)
allocated_kv_cache[:, key_length:, self.head_dim :].zero_()
else:
allocated_kv_cache = torch.empty(
[batch_size, self.num_heads, allocate_key_length, 2 * self.head_dim],
dtype=current_kv_cache.dtype,
device=current_kv_cache.device,
)
if layer_past is not None:
allocated_kv_cache[:, :, :last_key_length].copy_(last_kv_cache)
if allocate_key_length > key_length:
allocated_kv_cache[:, :, key_length:, self.head_dim :].zero_()

# Copy the new values.
if allocate_key_length > allocated_key_length or layer_past is not None:
if self.multi_query:
allocated_kv_cache[:, last_key_length:key_length].copy_(current_kv_cache)
padded_kv_cache = allocated_kv_cache[:, :padded_key_length]
else:
allocated_kv_cache[:, :, last_key_length:key_length].copy_(current_kv_cache)
padded_kv_cache = allocated_kv_cache[:, :, :padded_key_length]
if not flash_attention:
# Use the merged KV cache.
# Not needed when layer_past is None but frees some memory.
key_value = padded_kv_cache

if use_cache:
if allocated_kv_cache is None:
allocated_kv_cache = current_kv_cache
present = allocated_kv_cache, key_length
else:
present = None
return key_value, present
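`present` in the tuple format is now `(allocated_kv_cache, key_length)`, so a layer can over-allocate once and then copy new key/value rows into the same buffer on later steps. A toy illustration of how the tuple evolves during decoding (multi-query layout, illustrative sizes):

```python
import torch

batch, head_dim, allocated = 2, 64, 16

# Step 1: a 10-token prompt; the cache is over-allocated to 16 key positions.
allocated_kv_cache = torch.zeros(batch, allocated, 2 * head_dim)
key_length = 10
allocated_kv_cache[:, :key_length] = torch.randn(batch, key_length, 2 * head_dim)
present = (allocated_kv_cache, key_length)  # what the layer returns as `present`

# Step 2: one generated token; its key/value row is copied in place, only the length grows.
allocated_kv_cache, last_key_length = present
new_key_value = torch.randn(batch, 1, 2 * head_dim)
allocated_kv_cache[:, last_key_length : last_key_length + 1] = new_key_value
present = (allocated_kv_cache, last_key_length + 1)  # no re-allocation until 16 is outgrown
```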

def forward(
self,
@@ -359,9 +410,9 @@ def forward(
Tuple[torch.Tensor, Optional[torch.Tensor]],
Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
]:
flash_attention = self.flash_attention and not layer_past
flash_attention = self.flash_attention and layer_past is None
if self.multi_query or flash_attention:
query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=-1)
else:
# Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
# i.e., the memory layout is not the same as GPT2.
@@ -373,39 +424,14 @@ def forward(
.split((self.head_dim, 2 * self.head_dim), dim=3)
)

present = None

if flash_attention and use_cache:
# Unpad and convert to KV cache format.
# Todo: unpadding is only needed if the cache is reused.
_, padding_index, batch_size, max_sequence_length = attention_mask
kv_cache = pad_input(key_value, padding_index, batch_size, max_sequence_length)
if not self.multi_query:
kv_cache = kv_cache.view(*hidden_states.shape[:2], self.num_heads, 2 * self.head_dim).transpose(1, 2)
else:
kv_cache = key_value
if self.pre_allocate_kv_cache:
if use_cache or layer_past is not None:
last_key_length = layer_past or 0
batch_size = kv_cache.size(0)
key_length = last_key_length + kv_cache.size(-2 if self.multi_query else -1)
padded_key_length = key_length + -key_length % (8 if self.pad_key_length else 1)
kv_cache_ = self.get_kv_cache(
batch_size, padded_key_length, kv_cache.device, kv_cache.dtype, allocate=last_key_length == 0
)
if self.multi_query:
kv_cache_[:, last_key_length:key_length, :].copy_(kv_cache)
else:
kv_cache_[:, :, last_key_length:key_length, :].copy_(kv_cache)
present = key_length if use_cache else None
if not flash_attention:
# Not needed when layer_past is None but frees some memory.
key_value = kv_cache_
if self._tuple_cache_format:
# present = (allocated_kv_cache, key_length)
key_value, present = self._merge_kv_caches(key_value, use_cache, layer_past, attention_mask)
else:
# present = key_value
if layer_past is not None:
kv_cache = torch.cat((layer_past, key_value), dim=-2)
key_value = kv_cache
present = kv_cache if use_cache else None
key_value = torch.cat((layer_past, key_value), dim=-2)
present = key_value if use_cache else None

key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)

@@ -662,8 +688,8 @@ def __init__(self, config):
self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

self.pre_allocate_kv_cache = config.pre_allocate_kv_cache
self.pad_key_length = config.pad_key_length and self.pre_allocate_kv_cache
self.pad_key_length = config.pre_allocate_kv_cache if config.pad_key_length is None else config.pad_key_length
self._tuple_cache_format = config.pre_allocate_kv_cache or self.pad_key_length or config.flash_attention
self.inference_runner_type = InferenceRunnerType(config.inference_runner)

self.flash_attention = config.flash_attention
@@ -788,11 +814,12 @@ def forward(
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")

flash_attention = self.flash_attention and past_key_values is None
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
elif self.pre_allocate_kv_cache:
past_length = past_key_values[0]
elif self._tuple_cache_format:
past_length = past_key_values[0][1]
else:
past_length = past_key_values[0].size(-2)
key_length = past_length + query_length
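`past_length` is now read from the second element of each layer's tuple rather than from the cache tensor's shape. A quick illustration of both read paths (toy tensors, single layer):

```python
import torch

# Tuple cache format (pre-allocation, key padding, or flash attention enabled):
past_key_values = ((torch.empty(2, 16, 128), 10),)  # (allocated_kv_cache, key_length)
past_length = past_key_values[0][1]                  # -> 10

# Legacy format: the cache entry is the key/value tensor itself.
past_key_values = (torch.empty(2, 10, 128),)
past_length = past_key_values[0].size(-2)            # -> 10
```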
@@ -802,7 +829,7 @@ def forward(
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, query_length)

if not self.flash_attention:
if not flash_attention:
# Self-attention mask (padding + causal).
attention_mask = self._get_causal_mask(attention_mask, query_length, key_length)
if self.pad_key_length:
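The body of the `pad_key_length` branch is collapsed in this view; as in the pre-existing code, the idea is to round the key length up to a multiple of 8 so the fused softmax kernel stays on its fast path. A minimal sketch of the arithmetic (hypothetical helper name, not the PR's exact code):

```python
def pad_key_length_to_multiple_of_8(key_length: int) -> int:
    # 13 -> 16, 16 -> 16, 17 -> 24: fused kernels are slow on non-multiples of 8.
    return key_length + -key_length % 8
```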
@@ -831,7 +858,7 @@ def forward(
hidden_states = self.drop(hidden_states)

# TODO: Unpad earlier (input ids), support unpadded input?
if self.flash_attention:
if flash_attention:
hidden_states, padding_index, sequence_lengths, max_sequence_length = unpad_input(
hidden_states, attention_mask
)
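`unpad_input` and `pad_input` come from flash-attn's `bert_padding` utilities: padding tokens are stripped before the flash kernels run and restored afterwards. A hedged round-trip sketch (assumes the flash-attn package is installed and the four-value `unpad_input` signature used in this diff):

```python
import torch
from flash_attn.bert_padding import pad_input, unpad_input  # assumed dependency

batch, seq, hidden = 2, 8, 16
hidden_states = torch.randn(batch, seq, hidden)
attention_mask = torch.tensor([[1] * 8, [1] * 5 + [0] * 3])  # second sequence has 3 padding tokens

# Flatten to (total_tokens, hidden), remembering where each real token came from.
unpadded, padding_index, sequence_lengths, max_sequence_length = unpad_input(hidden_states, attention_mask)

# ... flash-attention kernels operate on the unpadded tokens here ...

# Restore the original (batch, seq, hidden) layout for the rest of the model.
restored = pad_input(unpadded, padding_index, batch, seq)
```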
@@ -883,7 +910,7 @@ def custom_forward(*inputs):

hidden_states = self.ln_f(hidden_states)

if self.flash_attention:
if flash_attention:
hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length)

hidden_states = hidden_states.view(input_shape + (hidden_states.size(-1),))