From e18f7540de3b4e83465c46ec92bdaa2732504d13 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 24 Apr 2023 16:39:37 -0400 Subject: [PATCH 1/4] Simplified kv caching --- .../gpt_bigcode/configuration_gpt_bigcode.py | 10 +- .../gpt_bigcode/modeling_gpt_bigcode.py | 188 ++++++++++-------- 2 files changed, 112 insertions(+), 86 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index f25dca27b9..5183f55705 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -139,9 +139,7 @@ def __init__( inference_runner=InferenceRunnerType.NO_RUNNER, validate_runner_input=True, pre_allocate_kv_cache=False, - max_sequence_length=None, - max_batch_size=None, - pad_key_length=True, + pad_key_length=None, predict_last_token: bool = False, **kwargs, ): @@ -172,12 +170,10 @@ def __init__( # Set to False to disable input validation of safe inputs, for a small speedup. self.validate_runner_input = validate_runner_input + # Pre-allocate to sequence length `n_positions` (`True`) or the specified value (`int`) self.pre_allocate_kv_cache = pre_allocate_kv_cache # The max sequence length for the pre-allocated KV cache (`n_positions` if not provided). - self.max_sequence_length = max_sequence_length - # The max batch size for the pre-allocated KV cache, (deduce from input if not provided). - self.max_batch_size = max_batch_size - # Pad key length to a multiple of 8 (requires pre_allocate_kv_cache). + # Pad key length to a multiple of 8. self.pad_key_length = pad_key_length # Predict only the last token in inference even if the input is bigger. diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index b51eb67eb7..c1fcf3954e 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -138,7 +138,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.mask_value = None self.multi_query = config.multi_query - # TODO: chack availability + self.seq_dim = -2 if self.multi_query else -1 self.flash_attention = config.flash_attention self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads @@ -163,12 +163,11 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.fused_softmax = config.fused_softmax # KV caching and padding - self.kv_cache = None - self.kv_cache_max_batch_size = config.max_batch_size or 0 - self.kv_cache_max_sequence_length = config.max_sequence_length or config.n_positions - self.pre_allocate_kv_cache = config.pre_allocate_kv_cache - self.pad_key_length = config.pad_key_length and config.pre_allocate_kv_cache - self._frozen_kv_cache = False + self.pre_allocate_kv_cache = ( + config.n_embd if config.pre_allocate_kv_cache is True else config.pre_allocate_kv_cache + ) + self.pad_key_length = config.pre_allocate_kv_cache if config.pad_key_length is None else config.pad_key_length + self._tuple_cache_format = self.pre_allocate_kv_cache or self.pad_key_length or self.flash_attention if self.is_cross_attention: raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.") @@ -191,9 +190,6 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): if self.pre_allocate_kv_cache: raise ValueError("KV cache pre-allocation is not supported with Flash 
Attention") assert not self.pre_allocate_kv_cache - self._attn_fn = self._attn_flash - else: - self._attn_fn = self._attn_mqa if self.multi_query else self._attn_mha def _get_mask_value(self, device, dtype): # torch.where expects a tensor. We use a cache to avoid recreating it every time. @@ -312,40 +308,99 @@ def _attn_flash(self, query, key, value, attention_mask, head_mask=None): return attn_output, None - def freeze_kv_cache(self, enable=True): - if self.kv_cache is None: - raise RuntimeError("KV cache not found.") - # Prevent re-allocation of the KV cache. - self._frozen_kv_cache = enable - - def get_kv_cache(self, batch_size, sequence_length, device, dtype, allocate=True): - if ( - self.kv_cache is None - or self.kv_cache.dtype != dtype - or self.kv_cache.device != device - or batch_size > self.kv_cache_max_batch_size - or sequence_length > self.kv_cache_max_sequence_length - ): - if self._frozen_kv_cache or not allocate: - if self.kv_cache is None: - raise RuntimeError("KV cache not found.") - else: - raise RuntimeError( - f"Invalid KV cache: " - f"existing = {(self.kv_cache.dtype,self.kv_cache.device,self.kv_cache_max_batch_size,self.kv_cache_max_sequence_length)}, " - f"requested = {(dtype,device,batch_size,sequence_length)}" - ) - # Free memory first. - self.kv_cache = None - self.kv_cache_max_sequence_length = max(sequence_length, self.kv_cache_max_sequence_length) - self.kv_cache_max_batch_size = max(batch_size, self.kv_cache_max_batch_size) - kv_cache_size = 2 * self.kv_cache_max_batch_size * self.kv_cache_max_sequence_length * self.kv_dim - self.kv_cache = torch.empty([kv_cache_size], device=device, dtype=dtype) - # This view ensures the cache is contiguous for all batch sizes. - kv_cache = self.kv_cache[: 2 * batch_size * self.kv_cache_max_sequence_length * self.kv_dim].view( - batch_size, self.kv_heads, self.kv_cache_max_sequence_length, 2 * self.head_dim - ) - return kv_cache[:, 0, :sequence_length, :] if self.multi_query else kv_cache[:, :, :sequence_length, :] + def _re_allocate_kv_cache(self, kv_cache, key_length, padded_key_length, allocate_key_length): + batch_size = kv_cache.size(-1) + assert not self.training + if self.multi_query: + allocated_kv_cache = torch.empty( + [batch_size, allocate_key_length, self.head_dim], dtype=kv_cache.dtype, device=kv_cache.device + ) + allocated_kv_cache[:, :key_length].copy_(kv_cache) + padded_kv_cache = allocated_kv_cache[:, :padded_key_length] + else: + allocated_kv_cache = torch.empty( + [batch_size, self.num_heads, allocate_key_length, self.head_dim], + dtype=kv_cache.dtype, + device=kv_cache.device, + ) + allocated_kv_cache[:, :, key_length].copy_(kv_cache) + padded_kv_cache = allocated_kv_cache[:, :, :padded_key_length] + return allocated_kv_cache, padded_kv_cache + + def _merge_kv_caches(self, key_value, use_cache, layer_past, attention_mask): + flash_attention = self.flash_attention and not layer_past + + # Convert to standard KV cache format. 
+ if flash_attention and use_cache: + _, padding_index, batch_size, max_sequence_length = attention_mask + current_kv_cache = pad_input(key_value, padding_index, batch_size, max_sequence_length) + if not self.multi_query: + current_kv_cache = current_kv_cache.view( + batch_size, max_sequence_length, self.num_heads, 2 * self.head_dim + ).transpose(1, 2) + else: + current_kv_cache = key_value + + # Calculate dimensions and recover layer_past + batch_size = current_kv_cache.size(0) + query_length = current_kv_cache.size(self.seq_dim) + if layer_past is None: + allocated_kv_cache, last_key_length = None, 0 + last_kv_cache = None + key_length = query_length + allocated_key_length = key_length + else: + allocated_kv_cache, last_key_length = layer_past + last_kv_cache = ( + allocated_kv_cache[:, :last_key_length] + if self.multi_query + else allocated_kv_cache[:, :, :last_key_length] + ) + key_length = query_length + last_key_length + allocated_key_length = allocated_kv_cache.size(self.seq_dim) + + padded_key_length = key_length if flash_attention else attention_mask.size(self.seq_dim) + allocate_key_length = padded_key_length if use_cache else max(self.pre_allocate_kv_cache, padded_key_length) + + # Re-allocate kv cache and copy last value + if allocate_key_length > allocated_key_length: + if self.multi_query: + allocated_kv_cache = torch.empty( + [batch_size, allocate_key_length, self.head_dim], + dtype=current_kv_cache.dtype, + device=current_kv_cache.device, + ) + if layer_past is not None: + allocated_kv_cache[:, :last_key_length].copy_(last_kv_cache) + else: + allocated_kv_cache = torch.empty( + [batch_size, self.num_heads, allocate_key_length, self.head_dim], + dtype=current_kv_cache.dtype, + device=current_kv_cache.device, + ) + if layer_past is not None: + allocated_kv_cache[:, :, :last_key_length].copy_(last_kv_cache) + + # Copy the new values. + if allocate_key_length > allocated_key_length or layer_past is not None: + if self.multi_query: + allocated_kv_cache[:, last_key_length:key_length].copy_(current_kv_cache) + padded_kv_cache = allocated_kv_cache[:, :padded_key_length] + else: + allocated_kv_cache[:, :, last_key_length:key_length].copy_(current_kv_cache) + padded_kv_cache = allocated_kv_cache[:, :, :padded_key_length] + if not flash_attention: + # Use the merged KV cache. + # Not needed when layer_past is None but frees some memory. + key_value = padded_kv_cache + + if use_cache: + if allocated_kv_cache is None: + allocated_kv_cache = current_kv_cache + present = allocated_kv_cache, key_length + else: + present = None + return key_value, present def forward( self, @@ -373,39 +428,14 @@ def forward( .split((self.head_dim, 2 * self.head_dim), dim=3) ) - present = None - - if flash_attention and use_cache: - # Unpad and convert to KV cache format. - # Todo: unpadding is only needed if the cache is reused. 
- _, padding_index, batch_size, max_sequence_length = attention_mask - kv_cache = pad_input(key_value, padding_index, batch_size, max_sequence_length) - if not self.multi_query: - kv_cache = kv_cache.view(*hidden_states.shape[:2], self.num_heads, 2 * self.head_dim).transpose(1, 2) - else: - kv_cache = key_value - if self.pre_allocate_kv_cache: - if use_cache or layer_past is not None: - last_key_length = layer_past or 0 - batch_size = kv_cache.size(0) - key_length = last_key_length + kv_cache.size(-2 if self.multi_query else -1) - padded_key_length = key_length + -key_length % (8 if self.pad_key_length else 1) - kv_cache_ = self.get_kv_cache( - batch_size, padded_key_length, kv_cache.device, kv_cache.dtype, allocate=last_key_length == 0 - ) - if self.multi_query: - kv_cache_[:, last_key_length:key_length, :].copy_(kv_cache) - else: - kv_cache_[:, :, last_key_length:key_length, :].copy_(kv_cache) - present = key_length if use_cache else None - if not flash_attention: - # Not needed when layer_past is None but frees some memory. - key_value = kv_cache_ + if self._tuple_cache_format: + # present = (allocated_kv_cache, key_length) + key_value, present = self._merge_kv_caches(key_value, use_cache, layer_past, attention_mask) else: + # present = key_value if layer_past is not None: - kv_cache = torch.cat((layer_past, key_value), dim=-2) - key_value = kv_cache - present = kv_cache if use_cache else None + key_value = torch.cat((layer_past, key_value), dim=-2) + present = key_value if use_cache else None key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) @@ -662,8 +692,8 @@ def __init__(self, config): self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) - self.pre_allocate_kv_cache = config.pre_allocate_kv_cache - self.pad_key_length = config.pad_key_length and self.pre_allocate_kv_cache + self.pad_key_length = config.pre_allocate_kv_cache if config.pad_key_length is None else config.pad_key_length + self._tuple_cache_format = config.pre_allocate_kv_cache or self.pad_key_length or config.flash_attention self.inference_runner_type = InferenceRunnerType(config.inference_runner) self.flash_attention = config.flash_attention @@ -791,8 +821,8 @@ def forward( if past_key_values is None: past_length = 0 past_key_values = tuple([None] * len(self.h)) - elif self.pre_allocate_kv_cache: - past_length = past_key_values[0] + elif self._tuple_cache_format: + past_length = past_key_values[0][1] else: past_length = past_key_values[0].size(-2) key_length = past_length + query_length From 42513fbb31ab262befbb07a5d69e83c1cc8b8303 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 25 Apr 2023 05:59:03 -0400 Subject: [PATCH 2/4] misc --- .../gpt_bigcode/modeling_gpt_bigcode.py | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index c1fcf3954e..9d6f16ffd2 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -166,8 +166,8 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.pre_allocate_kv_cache = ( config.n_embd if config.pre_allocate_kv_cache is True else config.pre_allocate_kv_cache ) - self.pad_key_length = config.pre_allocate_kv_cache if config.pad_key_length is None else config.pad_key_length - 
self._tuple_cache_format = self.pre_allocate_kv_cache or self.pad_key_length or self.flash_attention + pad_key_length = config.pre_allocate_kv_cache if config.pad_key_length is None else config.pad_key_length + self._tuple_cache_format = self.pre_allocate_kv_cache or pad_key_length or self.flash_attention if self.is_cross_attention: raise NotImplementedError("Cross-attention is not supported for gpt_bigcode.") @@ -328,7 +328,7 @@ def _re_allocate_kv_cache(self, kv_cache, key_length, padded_key_length, allocat return allocated_kv_cache, padded_kv_cache def _merge_kv_caches(self, key_value, use_cache, layer_past, attention_mask): - flash_attention = self.flash_attention and not layer_past + flash_attention = self.flash_attention and layer_past is None # Convert to standard KV cache format. if flash_attention and use_cache: @@ -359,22 +359,26 @@ def _merge_kv_caches(self, key_value, use_cache, layer_past, attention_mask): key_length = query_length + last_key_length allocated_key_length = allocated_kv_cache.size(self.seq_dim) - padded_key_length = key_length if flash_attention else attention_mask.size(self.seq_dim) + + padded_key_length = key_length if flash_attention else attention_mask.size(-1) allocate_key_length = padded_key_length if use_cache else max(self.pre_allocate_kv_cache, padded_key_length) + print("A", key_length, padded_key_length, allocate_key_length,allocate_key_length)#, attention_mask.shape) # Re-allocate kv cache and copy last value if allocate_key_length > allocated_key_length: + print(f"Allocate {allocate_key_length}") if self.multi_query: allocated_kv_cache = torch.empty( - [batch_size, allocate_key_length, self.head_dim], + [batch_size, allocate_key_length, 2*self.head_dim], dtype=current_kv_cache.dtype, device=current_kv_cache.device, ) + print(f"Copy 0:{last_key_length}") if layer_past is not None: allocated_kv_cache[:, :last_key_length].copy_(last_kv_cache) else: allocated_kv_cache = torch.empty( - [batch_size, self.num_heads, allocate_key_length, self.head_dim], + [batch_size, self.num_heads, allocate_key_length, 2*self.head_dim], dtype=current_kv_cache.dtype, device=current_kv_cache.device, ) @@ -384,6 +388,7 @@ def _merge_kv_caches(self, key_value, use_cache, layer_past, attention_mask): # Copy the new values. if allocate_key_length > allocated_key_length or layer_past is not None: if self.multi_query: + print(f"Copy {last_key_length}:{key_length}") allocated_kv_cache[:, last_key_length:key_length].copy_(current_kv_cache) padded_kv_cache = allocated_kv_cache[:, :padded_key_length] else: @@ -393,6 +398,7 @@ def _merge_kv_caches(self, key_value, use_cache, layer_past, attention_mask): # Use the merged KV cache. # Not needed when layer_past is None but frees some memory. 
key_value = padded_kv_cache + #print("B", key_value.shape, allocated_kv_cache.shape)#, padded_kv_cache.shape if use_cache: if allocated_kv_cache is None: @@ -414,9 +420,9 @@ def forward( Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], ]: - flash_attention = self.flash_attention and not layer_past + flash_attention = self.flash_attention and layer_past is None if self.multi_query or flash_attention: - query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=-1) else: # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), # i.e., the memory layout is not the same as GPT2. @@ -818,6 +824,7 @@ def forward( if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") + flash_attention = self.flash_attention and past_key_values is None if past_key_values is None: past_length = 0 past_key_values = tuple([None] * len(self.h)) @@ -832,7 +839,7 @@ def forward( if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, query_length) - if not self.flash_attention: + if not flash_attention: # Self-attention mask (padding + causal). attention_mask = self._get_causal_mask(attention_mask, query_length, key_length) if self.pad_key_length: @@ -861,7 +868,7 @@ def forward( hidden_states = self.drop(hidden_states) # TODO: Unpad earlier (input ids), support unpadded input? - if self.flash_attention: + if flash_attention: hidden_states, padding_index, sequence_lengths, max_sequence_length = unpad_input( hidden_states, attention_mask ) @@ -913,7 +920,7 @@ def custom_forward(*inputs): hidden_states = self.ln_f(hidden_states) - if self.flash_attention: + if flash_attention: hidden_states = pad_input(hidden_states, padding_index, batch_size, query_length) hidden_states = hidden_states.view(input_shape + (hidden_states.size(-1),)) From 4427b7dc47e0e3938b2832e510e82b63acece47c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 25 Apr 2023 07:47:08 -0400 Subject: [PATCH 3/4] Fixes and cleanup --- .../gpt_bigcode/modeling_gpt_bigcode.py | 41 ++++++++----------- 1 file changed, 17 insertions(+), 24 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 9d6f16ffd2..3f8ffa8390 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -244,22 +244,15 @@ def _attn(self, query, key, value, attention_mask, head_mask=None): beta = 0 attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape) - if upcast: - # Use a fused kernel to prevent a large overhead from casting and scaling. - # Sub-optimal when the key length is not a multiple of 8. - if attention_mask is None: - attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype) - else: - mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) - attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype) - else: - if attention_mask is not None: - mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) - - # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion. 
- attn_weights = torch.where(attention_mask, attn_weights, mask_value) - - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + attn_weights = softmax_function( + attn_weights, + attention_mask, + None if attention_mask is None else self._get_mask_value(attn_weights.device, softmax_dtype), + unscale, + softmax_dtype, + upcast, + self.fused_softmax, + ) attn_weights = self.attn_dropout(attn_weights) @@ -359,36 +352,37 @@ def _merge_kv_caches(self, key_value, use_cache, layer_past, attention_mask): key_length = query_length + last_key_length allocated_key_length = allocated_kv_cache.size(self.seq_dim) - padded_key_length = key_length if flash_attention else attention_mask.size(-1) allocate_key_length = padded_key_length if use_cache else max(self.pre_allocate_kv_cache, padded_key_length) - print("A", key_length, padded_key_length, allocate_key_length,allocate_key_length)#, attention_mask.shape) # Re-allocate kv cache and copy last value if allocate_key_length > allocated_key_length: - print(f"Allocate {allocate_key_length}") if self.multi_query: allocated_kv_cache = torch.empty( - [batch_size, allocate_key_length, 2*self.head_dim], + [batch_size, allocate_key_length, 2 * self.head_dim], dtype=current_kv_cache.dtype, device=current_kv_cache.device, ) - print(f"Copy 0:{last_key_length}") if layer_past is not None: allocated_kv_cache[:, :last_key_length].copy_(last_kv_cache) + if allocate_key_length > key_length: + # Nans in `value` can propagate through the matrix multiplication, + # so we set the remaining values to zero. (`last_key_length:key_length` is set below.) + allocated_kv_cache[:, key_length:, self.head_dim :].zero_() else: allocated_kv_cache = torch.empty( - [batch_size, self.num_heads, allocate_key_length, 2*self.head_dim], + [batch_size, self.num_heads, allocate_key_length, 2 * self.head_dim], dtype=current_kv_cache.dtype, device=current_kv_cache.device, ) if layer_past is not None: allocated_kv_cache[:, :, :last_key_length].copy_(last_kv_cache) + if allocate_key_length > key_length: + allocated_kv_cache[:, :, key_length:, self.head_dim :].zero_() # Copy the new values. if allocate_key_length > allocated_key_length or layer_past is not None: if self.multi_query: - print(f"Copy {last_key_length}:{key_length}") allocated_kv_cache[:, last_key_length:key_length].copy_(current_kv_cache) padded_kv_cache = allocated_kv_cache[:, :padded_key_length] else: @@ -398,7 +392,6 @@ def _merge_kv_caches(self, key_value, use_cache, layer_past, attention_mask): # Use the merged KV cache. # Not needed when layer_past is None but frees some memory. 
key_value = padded_kv_cache - #print("B", key_value.shape, allocated_kv_cache.shape)#, padded_kv_cache.shape if use_cache: if allocated_kv_cache is None: From b50afe022715ce94502dfda2679c559a7dad8595 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 25 Apr 2023 08:06:32 -0400 Subject: [PATCH 4/4] Remove error --- src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 3f8ffa8390..fd12371182 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -187,9 +187,6 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): if not self.multi_query: # TODO: Flash Attention is implemented but not tested for MHA raise ValueError("Flash Attention is not supported with multi-head attention.") - if self.pre_allocate_kv_cache: - raise ValueError("KV cache pre-allocation is not supported with Flash Attention") - assert not self.pre_allocate_kv_cache def _get_mask_value(self, device, dtype): # torch.where expects a tensor. We use a cache to avoid recreating it every time.
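
For reference, a minimal sketch of the tuple cache format these patches move to: each layer's `present` becomes `(allocated_kv_cache, key_length)`, where the cached tensor can be longer than `key_length` because the key dimension is padded to a multiple of 8 and/or pre-allocated up front. This is a plain-PyTorch approximation of `_merge_kv_caches` for the multi-query case only; the helper name `merge_kv_cache` and its signature are illustrative rather than the model's actual API, and the Flash Attention re-padding path is left out.

import torch


def merge_kv_cache(key_value, layer_past, head_dim, pad_to=8, pre_allocate=0):
    """Append this step's keys/values to a (possibly padded, possibly pre-allocated) cache."""
    batch_size, query_length, _ = key_value.shape  # MQA layout: (batch, query_length, 2 * head_dim)
    if layer_past is None:
        allocated, last_key_length, allocated_key_length = None, 0, 0
    else:
        allocated, last_key_length = layer_past
        allocated_key_length = allocated.size(1)

    key_length = last_key_length + query_length
    # Pad the visible key length up to a multiple of `pad_to` (8 keeps the fused kernels fast).
    padded_key_length = key_length + -key_length % pad_to
    allocate_key_length = max(padded_key_length, pre_allocate)

    if allocate_key_length > allocated_key_length:
        # (Re-)allocate a larger buffer and copy the previous entries over.
        new_buffer = torch.empty(
            batch_size, allocate_key_length, 2 * head_dim, dtype=key_value.dtype, device=key_value.device
        )
        if allocated is not None:
            new_buffer[:, :last_key_length].copy_(allocated[:, :last_key_length])
        # Zero the value half of the padding so uninitialized memory (possibly NaN)
        # cannot leak through the attention matmul (see the demonstration further below).
        new_buffer[:, key_length:, head_dim:].zero_()
        allocated = new_buffer

    allocated[:, last_key_length:key_length].copy_(key_value)
    # Attention reads the padded slice; the tuple records the true key length.
    return allocated[:, :padded_key_length], (allocated, key_length)

A generation step under this scheme then looks like:

kv = torch.randn(2, 5, 2 * 64)                         # prompt of length 5, head_dim = 64
reads, present = merge_kv_cache(kv, None, head_dim=64)
assert reads.size(1) == 8 and present[1] == 5          # padded to a multiple of 8, true length 5
reads, present = merge_kv_cache(torch.randn(2, 1, 2 * 64), present, head_dim=64)
assert reads.size(1) == 8 and present[1] == 6

This also mirrors how the model now recovers the past length as `past_key_values[0][1]` under the tuple format, instead of `past_key_values[0].size(-2)` for a plain concatenated tensor.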
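
The single `softmax_function(...)` call that replaces the removed upcast/masked branches in `_attn` can be read as the unfused reference below. This is a sketch only: assume the real helper additionally dispatches to a fused kernel when `fused_softmax` is set, and the name `softmax_reference` is illustrative, not the module's.

import torch


def softmax_reference(attn_weights, attention_mask, mask_value, unscale, softmax_dtype, upcast):
    input_dtype = attn_weights.dtype
    if upcast:
        # Undo the 1 / unscale factor folded into the attention baddbmm while
        # upcasting to float32, so the softmax stays numerically stable.
        attn_weights = attn_weights.to(softmax_dtype) * unscale
    if attention_mask is not None:
        # Masked positions are replaced by the cached large-negative mask value.
        attn_weights = torch.where(attention_mask, attn_weights, mask_value)
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
    return attn_weights.to(input_dtype) if upcast else attn_weights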
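
The `.zero_()` applied to the value half of a freshly allocated cache exists because padded key positions are masked out of the softmax, yet a zero attention weight does not neutralize an uninitialized (possibly NaN) value. A small standalone demonstration, independent of the model code:

import torch

probs = torch.tensor([[0.6, 0.4, 0.0, 0.0]])                   # last two positions are padding
values = torch.tensor([[1.0], [2.0], [float("nan")], [float("nan")]])
print(probs @ values)                                          # tensor([[nan]]): 0 * nan is still nan
values[2:] = 0.0                                               # zeroing the padded values fixes it
print(probs @ values)                                          # tensor([[1.4000]])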