diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index 9cbaf3e184..fbcb58a634 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -14,6 +14,7 @@ # limitations under the License. """ GPTBigCode configuration""" +import math from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -105,6 +106,7 @@ def __init__( n_embd=768, n_layer=12, n_head=12, + head_groups=None, n_inner=None, activation_function="gelu_pytorch_tanh", resid_pdrop=0.1, @@ -119,6 +121,10 @@ def __init__( attention_softmax_in_fp32=True, scale_attention_softmax_in_fp32=True, multi_query=True, + use_rotary_embeddings=False, + rotary_embedding_scale=-math.log(10000), # - 9.210 + use_position_embeddings=None, + attention_window_size=None, **kwargs, ): self.vocab_size = vocab_size @@ -137,7 +143,14 @@ def __init__( self.use_cache = use_cache self.attention_softmax_in_fp32 = attention_softmax_in_fp32 self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32 - self.multi_query = multi_query + self.use_rotary_embeddings = use_rotary_embeddings + self.rotary_embedding_scale = rotary_embedding_scale + self.use_position_embeddings = use_position_embeddings if use_position_embeddings is not None else not use_rotary_embeddings + self.attention_window_size = attention_window_size + if head_groups is None: + self.head_groups = 1 if multi_query else n_head + else: + self.head_groups = head_groups self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id diff --git a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py new file mode 100644 index 0000000000..d9e78e97aa --- /dev/null +++ b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py @@ -0,0 +1,162 @@ +import argparse +import os +from pathlib import Path +import re + +import torch +from transformers.models.gpt_bigcode.merge_fast_llm_checkpoint import merge_checkpoint +from transformers.models.gpt_bigcode import GPTBigCodeConfig + + +def convert_fast_llm_checkpoint(state_dict, config, set_attn_dense_bias_zero, set_mlp_2_bias_zero, version=1): + if set_attn_dense_bias_zero: + print("Will set attention output layer biases to zero") + if set_mlp_2_bias_zero: + print("Will set MLP layer-2 biases to zero") + # The converted output model. + output_state_dict = {} + if "window_size" in config: + attention_window_size = config["window_size"] + else: + attention_window_size = config.get("attention_window_size", None) + + config = GPTBigCodeConfig( + architectures=["GPTBigCodeLMHeadModel"], + vocab_size=config["vocab_size"], + n_positions=config["max_position_embeddings"], + n_embd=config["hidden_size"], + n_layer=config["num_layers"], + n_head=config["num_attention_heads"], + head_groups=config.get("head_groups", None), + n_inner=config["ffn_hidden_size"], + activation_function="gelu", # TODO + multi_query=True, # TODO + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + scale_attn_weights=True, + use_cache=True, + bos_token_id=0, # TODO: can we remove these? 
+        eos_token_id=0,
+        attention_softmax_in_fp32=True,
+        scale_attention_softmax_in_fp32=True,
+        use_rotary_embeddings=config["use_rotary_embeddings"],
+        rotary_embedding_scale=config["rotary_embedding_scale"],
+        use_position_embeddings=config["use_position_embeddings"],
+        attention_window_size=attention_window_size
+    )
+
+    # Truncate the word embeddings to the vocab-size
+    # Version-0 checkpoints prefix parameter names with an underscore.
+    u = "_" if version == 0 else ""
+    word_embeddings = state_dict.pop(f"{u}layers.0.{u}word_embeddings_weight")[:config.vocab_size, :]
+    output_state_dict["transformer.wte.weight"] = word_embeddings
+    if config.use_position_embeddings:
+        output_state_dict["transformer.wpe.weight"] = state_dict.pop(f"{u}layers.0.{u}position_embeddings_weight")
+
+    # Layer-0 is the word/position embeddings
+    # Layers 1 to n_layer need to be re-mapped from 0 to n_layer-1.
+    # _layers.{layer_index}.{op}.{w/b}
+
+    # Concatenate QKV matrix
+    for layer_index in range(1, config.n_layer + 1):
+        for weight_or_bias in ["weight", "bias"]:
+            query = state_dict.pop(f"{u}layers.{layer_index}.self_attn.query.{weight_or_bias}")
+            key_value = state_dict.pop(f"{u}layers.{layer_index}.self_attn.key_value.{weight_or_bias}")
+            output_state_dict[f"transformer.h.{layer_index - 1}.attn.c_attn.{weight_or_bias}"] = torch.cat([query, key_value], dim=0)
+
+    # The simple map of names for "automated" rules.
+    name_map = {
+        f"{u}mlp.{u}layer_1": "mlp.c_fc",
+        f"{u}mlp.{u}layer_2": "mlp.c_proj",
+        "layer_norm_1": "ln_1",
+        "layer_norm_2": "ln_2",
+        # "attention.dense": "attn.c_proj",
+        "self_attn.dense": "attn.c_proj",
+        # "self_attention.query_key_value": "attn.c_attn",
+    }
+    # Extract the other ops
+    layer_re = re.compile(rf"{u}layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
+    for name, value in state_dict.items():
+        m = layer_re.match(name)
+        assert m is not None, f"Invalid layer name: {name}"
+
+        # The index of the layer.
+        layer_index = int(m.group(1))
+        # The name of the operation.
+        op_name = m.group(2)
+        # Is it a weight or a bias?
+        weight_or_bias = m.group(3)
+
+        # Final layernorm
+        if op_name == "final_layernorm":
+            assert layer_index == config.n_layer + 1
+            output_state_dict[f"transformer.ln_f.{weight_or_bias}"] = value
+        # Bias was not used in training for InputParallel layers
+        elif op_name == "self_attn.dense" and weight_or_bias == "bias" and set_attn_dense_bias_zero:
+            output_state_dict[f"transformer.h.{layer_index-1}.{name_map[op_name]}.{weight_or_bias}"] = torch.zeros_like(value)
+        # MLP layer-2 is also InputParallel
+        elif op_name == f"{u}mlp.{u}layer_2" and weight_or_bias == "bias" and set_mlp_2_bias_zero:
+            output_state_dict[f"transformer.h.{layer_index-1}.{name_map[op_name]}.{weight_or_bias}"] = torch.zeros_like(value)
+        else:
+            output_state_dict[f"transformer.h.{layer_index-1}.{name_map[op_name]}.{weight_or_bias}"] = value
+
+    # The LM head weight is tied to the word embeddings.
+    output_state_dict["lm_head.weight"] = word_embeddings
+
+    return output_state_dict, config
+
+
+def main(argv=None):
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--checkpoint_dir",
+        type=Path,
+        help="Path to the checkpoint directory ({experiment_dir}/checkpoints/{iteration})",
+    )
+    parser.add_argument(
+        "--save_dir",
+        type=Path,
+        help="Path where the converted model is saved"
+    )
+    parser.add_argument(
+        "--set_attn_dense_bias_zero",
+        action='store_true',
+        default=False,
+        help="Set the attention output layer bias to zero and ignore the value from the checkpoint. Shouldn't be used except to fix a bug from training."
+ ) + parser.add_argument( + "--set_mlp_2_bias_zero", + action='store_true', + default=False, + help="Set the MLP second layer bias to zero and ignore the value from the checkpoint. Shouldn't be used except to fix a bug from training." + ) + + args = parser.parse_args(argv) + + state_dict, config = merge_checkpoint( + args.checkpoint_dir, + dummy_experiment_dir=None + ) + + output_state_dict, output_config = convert_fast_llm_checkpoint(state_dict, config, args.set_attn_dense_bias_zero, args.set_mlp_2_bias_zero) + + print("Saving config") + save_dir = args.save_dir or args.checkpoint_dir / "converted" + output_config.save_pretrained(save_dir) + + # Store the state_dict to file. + output_checkpoint_file = os.path.join(save_dir, "pytorch_model.bin") + print(f'Saving checkpoint to "{output_checkpoint_file}"') + torch.save(output_state_dict, output_checkpoint_file) + print(f'Done!') + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py new file mode 100644 index 0000000000..71731559c7 --- /dev/null +++ b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py @@ -0,0 +1,134 @@ +import re +from tqdm import tqdm +from pathlib import Path + +import numpy as np +import torch +import yaml + + +def get_all_checkpoint_paths(experiment_path): + checkpoints = (Path(experiment_path) / "checkpoints").glob("*") + # Sort checkpoints by iteration number + checkpoints = sorted(checkpoints, key=lambda x: int(x.name)) + return [get_checkpoint_paths(checkpoint) for checkpoint in checkpoints] + + +def get_checkpoint_paths(checkpoint_dir: Path): + return [c_name for c_name in checkpoint_dir.glob("*") if re.match(r"\d+", c_name.name)] + + +def extract_stage_shards(state): + # Extract the weight shard and split it into the stage shards + # Reproduce the split done in MultiStageModelBase.setup + total_shard_size = sum(state['stage_shard_sizes']) + if len(state['shard'].shape) == 1: + # Flat buffer + weight_shard = state['shard'][:total_shard_size] + elif len(state['shard'].shape) == 2: + # 2D buffer + weight_shard = state['shard'][0] + else: + raise ValueError(f"Unrecognized buffer shape {state['shard'].shape}") + return weight_shard.split(state['stage_shard_sizes']) + + +def extract_individual_weights(merged_stage_shard, stage_content): + # Get individual weights from shards that are merged across data-parallel + weights_numel = [np.prod(weight_meta['shape']) for weight_meta in stage_content] + weights = merged_stage_shard[:sum(weights_numel)].split(weights_numel) + return [weight.reshape(weight_meta['shape']) for weight, weight_meta in zip(weights, stage_content)] + + +def concatenate_tp_shards(stage_tp_shards, stage_content): + # Concatenate the tp-shards in a given stage + # Stage_tp_shards: contains the individual weight shards for each rank + # [[weight1, weight2, ...] for rank in range(tp_size)] + concatenated_weights = [] + # Concatenate each individual weight along their TP dimension if they have one. 
+ for weight_tp_shards, weight_meta in zip(zip(*stage_tp_shards), stage_content): + if weight_meta["tensor_parallel_dim"] is not None: + weight = torch.cat(weight_tp_shards, dim=weight_meta["tensor_parallel_dim"]) + else: + weight = weight_tp_shards[0] + concatenated_weights.append(weight) + return concatenated_weights + + +def merge_checkpoint(checkpoint_dir: Path, dummy_experiment_dir=None): + """Load a fast-llm checkpoint and merge the data, tensor, and pipeline-parallel shards""" + # checkpoint_dir=experiment_dir/checkpoints/{iteration} + experiment_dir = checkpoint_dir.parent.parent + checkpoint_paths = get_checkpoint_paths(checkpoint_dir) + config = yaml.safe_load((experiment_dir / "config.yaml").read_text()) + + # Load the states from all the ranks + states = { + int(c_name.name): torch.load(c_name) + for c_name in tqdm(checkpoint_paths) + } + num_stages = len(states[0]["stages"]) + tensor_parallel = config["tensor_parallel"] + data_parallel_size = int(config["world_size"] / (tensor_parallel * config["pipeline_parallel"])) + + if dummy_experiment_dir is not None: + # Use the meta from the dummy checkpoint, and the shard from the actual checkpoint + dummy_checkpoint_paths = get_all_checkpoint_paths(dummy_experiment_dir) + dummy_states = { + int(c_name.name): torch.load(c_name) + for c_name in tqdm(dummy_checkpoint_paths[-1]) + } + for rank, state in dummy_states.items(): + state['shard'] = states[rank]['shard'] + states = dummy_states + + # Gather the data-parallel shards + # {tp_rank: [[stage_0_shard_0, stage_0_shard_1, ...], [stage_1_shard_0, ...], ...]} + # {tp_rank: [{fsdp_rank: shard}, ...]} + fsdp_shards = { + i: [[None for _ in range(data_parallel_size)] for _ in range(num_stages)] + for i in range(tensor_parallel) + } + + for rank, state in states.items(): + on_device_stage_shards = extract_stage_shards(state) + on_device_stage_indices = [i for (i, stage_meta) in enumerate(state["stages"]) if stage_meta["on_device"]] + for stage_index, stage_shard in zip(on_device_stage_indices, on_device_stage_shards): + stage_meta = state["stages"][stage_index] + # fsdp_shards[stage_meta["tp_rank"]][stage_index].append((stage_meta, stage_shard)) + fsdp_shards[stage_meta["tp_rank"]][stage_index][stage_meta["fsdp_rank"]] = stage_shard + + # Concatenate the data-parallel shards + # and get individual weights + dp_concatenated_shards = { + tp_rank: [ + extract_individual_weights( + torch.cat(stage_shards, dim=0), + states[0]["stages"][stage_index]['content'] + ) + for stage_index, stage_shards in enumerate(fsdp_shards[tp_rank]) + ] + for tp_rank in range(config["tensor_parallel"]) + } + + # In the tensor-parallel case, concatenate the TP tensors along their TP dimensions. 
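+    # Illustrative example (shapes are made up): with tensor_parallel=2, a weight stored as two
+    # (2048, 6144) shards with tensor_parallel_dim=0 comes out of this step as one (4096, 6144) tensor.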
+    tp_concatenated_shards = []
+    for stage_index, stage_tp_shards in enumerate(zip(*(dp_concatenated_shards[i] for i in range(tensor_parallel)))):
+        stage_content = states[0]["stages"][stage_index]["content"]
+        tp_concatenated_shards.append(concatenate_tp_shards(stage_tp_shards, stage_content))
+
+    # In the pipeline-parallel case, merge the stages
+    state_dict = {
+        weight_meta["name"]: weight
+        for stage_meta, stage_weights in zip(states[0]["stages"], tp_concatenated_shards)
+        for weight_meta, weight in zip(stage_meta["content"], stage_weights)
+    }
+
+    print(f"Total number of parameters: {sum([weight.numel() for weight in state_dict.values()])}")
+    return state_dict, config
+
+
+if __name__ == "__main__":
+    merge_checkpoint("/toolkit_infiniband_example_checkpoints/ngc_checkpoints/sc2_ablations/1B_repo_context_Top-level-Depth-first_pp2_64k_64k_2023_10_17_16_35_27/",
+                     dummy_experiment_dir="/toolkit_infiniband_example_checkpoints/ngc_checkpoints/sc2_ablations/dev_1B_repo_context_Random_pp2_64k_64k_2023_10_18_22_20_36/")
+
diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index 1c34f28a5c..56ce2358c4 100644
--- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -78,16 +78,30 @@ def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor
     return x
 
 
+def _apply_rotary_embeddings(
+    tensor: torch.Tensor,
+    rope_frequencies: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Apply rotary embeddings to a tensor:
+    * Convert it to a complex, full-precision tensor
+    * Multiply by the frequencies
+    * Convert back to the input format.
+    # TODO: Full precision only needed for bfloat16? (Doesn't support complex numbers)
+    """
+    complex_tensor = torch.view_as_complex(tensor.float().view(*tensor.shape[:-1], -1, rope_frequencies.size(-1), 2))
+    return torch.view_as_real(complex_tensor * rope_frequencies).view_as(tensor).type_as(tensor)
+
+
 class GPTBigCodeAttention(nn.Module):
     def __init__(self, config, is_cross_attention=False, layer_idx=None):
         super().__init__()
         self.mask_value = None
 
-        self.multi_query = config.multi_query
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
-        self.kv_heads = 1 if self.multi_query else self.num_heads
+        self.kv_heads = config.head_groups
         self.kv_dim = self.kv_heads * self.head_dim
         self.split_size = self.embed_dim
         if self.head_dim * self.num_heads != self.embed_dim:
@@ -95,6 +109,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
                 f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                 f" {self.num_heads})."
) + self.use_rotary_embeddings = config.use_rotary_embeddings self.scale_attn_weights = config.scale_attn_weights self.is_cross_attention = is_cross_attention @@ -106,8 +121,8 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): ) if self.is_cross_attention: - if self.multi_query: - raise NotImplementedError("Multi-Query Attention not supported for cross_attention") + if self.kv_heads != self.num_heads: + raise NotImplementedError("MQA / GQA not supported for cross_attention") self.c_attn = nn.Linear(self.embed_dim, 2 * self.embed_dim) self.q_attn = nn.Linear(self.embed_dim, self.embed_dim) @@ -135,29 +150,29 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): if self.scale_attn_weights: scale_factor /= self.head_dim**0.5 - # MQA models: (batch_size, query_length, num_heads * head_dim) - # MHA models: (batch_size, num_heads, query_length, head_dim) + # query: (batch_size, query_length, num_heads * head_dim) query_shape = query.shape batch_size = query_shape[0] + query_length = query_shape[1] key_length = key.size(-1) - if self.multi_query: + # MQA + if self.kv_heads == 1: # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length) # -> (batch_size, query_length, num_heads, key_length) - query_length = query_shape[1] attn_shape = (batch_size, query_length, self.num_heads, key_length) attn_view = (batch_size, query_length * self.num_heads, key_length) # No copy needed for MQA 2, or when layer_past is provided. query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim) else: - # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length) - # -> (batch_size, num_heads, query_length, key_length) - query_length = query_shape[2] - attn_shape = (batch_size, self.num_heads, query_length, key_length) - attn_view = (batch_size * self.num_heads, query_length, key_length) - # Always copies - query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim) - # No copy when layer_past is provided. 
- key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length) + heads_per_group = self.num_heads // self.kv_heads + attn_shape = (batch_size, self.kv_heads, query_length, heads_per_group, key_length) + attn_view = (batch_size * self.kv_heads, query_length * heads_per_group, key_length) + query = query.reshape(batch_size, query_length, self.kv_heads, heads_per_group, self.head_dim).transpose(1, 2) + query = query.reshape(batch_size * self.kv_heads, query_length * heads_per_group, self.head_dim) + key = key.reshape(batch_size * self.kv_heads, self.head_dim, key_length) + value = value.transpose(1, 2) # (batch, kv_heads * head_dim, key_length) + value = value.reshape(batch_size * self.kv_heads, self.head_dim, key_length).transpose(1, 2) + # Attention Mask: (batch_size, query_length, 1, key_length) attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype) if query.device.type == "cpu": @@ -191,14 +206,18 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): # Mask heads if we want to if head_mask is not None: - if self.multi_query: + if self.kv_heads == 1: head_mask = head_mask.transpose(1, 2) attn_weights = attn_weights * head_mask - if self.multi_query: + # MQA + if self.kv_heads == 1: attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape) else: - attn_output = torch.matmul(attn_weights, value) + # -> (batch_size * self.kv_heads, query_length * heads_per_group, head_dim) + attn_output = torch.bmm(attn_weights.view(attn_view), value) + attn_output = attn_output.reshape(batch_size, self.kv_heads, query_length, heads_per_group, self.head_dim).transpose(1, 2) + attn_output = attn_output.reshape(batch_size, query_length, self.kv_heads * heads_per_group * self.head_dim) return attn_output, attn_weights @@ -212,10 +231,13 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + rotary_embedding_frequencies_q: Optional[torch.Tensor] = None, + rotary_embedding_frequencies_k: Optional[torch.Tensor] = None ) -> Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], ]: + # hidden: (batch, sequence, hidden_size) if encoder_hidden_states is not None: if not hasattr(self, "q_attn") or not self.is_cross_attention: raise ValueError( @@ -226,37 +248,57 @@ def forward( query = self.q_attn(hidden_states) key_value = self.c_attn(encoder_hidden_states) attention_mask = encoder_attention_mask - elif self.multi_query: + + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + present = key_value if use_cache else None + + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + elif self.kv_heads == 1: query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + present = key_value if use_cache else None + + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) else: # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), # i.e., the memory layout is not the same as GPT2. # This makes the concatenation with past_key_value more efficient. 
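+            # With grouped heads (head_groups < num_heads), c_attn packs the projections as
+            # (batch, sequence, embed_dim + 2 * kv_dim): all num_heads query heads first, followed by
+            # kv_heads key heads and kv_heads value heads, each shared by num_heads // kv_heads query heads.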
-            query, key_value = (
-                self.c_attn(hidden_states)
-                .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
-                .transpose(1, 2)
-                .split((self.head_dim, 2 * self.head_dim), dim=3)
-            )
+            # query, key_value = (
+            #     self.c_attn(hidden_states)
+            #     .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)  # (batch, sequence, num_heads, 3*head_dim)
+            #     .transpose(1, 2)  # (batch, num_heads, sequence, 3*head_dim)
+            #     .split((self.head_dim, 2 * self.head_dim), dim=3)
+            # )
+
+            query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
+            # key_value: (batch, sequence, 2 * kv_heads * head_dim)
+
+            if layer_past is not None:
+                key_value = torch.cat((layer_past, key_value), dim=-2)
+            present = key_value if use_cache else None
 
-        if layer_past is not None:
-            key_value = torch.cat((layer_past, key_value), dim=-2)
-        present = key_value if use_cache else None
+            key, value = key_value.split(self.kv_dim, dim=-1)  # two chunks of size kv_dim: key, value
 
-        key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
+        if self.use_rotary_embeddings:
+            query = _apply_rotary_embeddings(query, rotary_embedding_frequencies_q)
+            key = _apply_rotary_embeddings(key, rotary_embedding_frequencies_k)
 
         attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask)
 
-        if not self.multi_query:
-            attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
 
         outputs = (attn_output, present)
         if output_attentions:
-            if self.multi_query:
+            if self.kv_heads == 1:
                 # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
                 attn_weights = attn_weights.transpose(1, 2)
+            else:
+                # (batch_size, self.kv_heads, query_length, heads_per_group, key_length)
+                attn_weights = attn_weights.transpose(2, 3)
             outputs += (attn_weights,)
 
         return outputs  # a, present, (attentions)
@@ -291,8 +333,8 @@ def __init__(self, config, layer_idx=None):
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
 
         if config.add_cross_attention:
-            if config.multi_query:
-                raise NotImplementedError("Cross-attention not implemented for MQA")
+            if config.head_groups < config.num_attention_heads:
+                raise NotImplementedError("Cross-attention not implemented for MQA / GQA")
             self.crossattention = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx)
             self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
 
@@ -308,6 +350,8 @@ def forward(
         encoder_attention_mask: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
+        rotary_embedding_frequencies_q: Optional[torch.Tensor] = None,
+        rotary_embedding_frequencies_k: Optional[torch.Tensor] = None
     ) -> Union[
         Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
     ]:
@@ -320,6 +364,8 @@ def forward(
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
+            rotary_embedding_frequencies_q=rotary_embedding_frequencies_q,
+            rotary_embedding_frequencies_k=rotary_embedding_frequencies_k
         )
         attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
         outputs = attn_outputs[1:]
@@ -342,6 +388,8 @@ def forward(
                 encoder_hidden_states=encoder_hidden_states,
                 encoder_attention_mask=encoder_attention_mask,
                 output_attentions=output_attentions,
+                rotary_embedding_frequencies_q=rotary_embedding_frequencies_q,
+
rotary_embedding_frequencies_k=rotary_embedding_frequencies_k ) attn_output = cross_attn_outputs[0] # residual connection @@ -502,11 +550,26 @@ def _set_gradient_checkpointing(self, module, value=False): class GPTBigCodeModel(GPTBigCodePreTrainedModel): def __init__(self, config): super().__init__(config) - self.multi_query = config.multi_query + self.kv_heads = config.head_groups self.embed_dim = config.hidden_size self.wte = nn.Embedding(config.vocab_size, self.embed_dim) - self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + if config.use_position_embeddings: + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + if config.use_rotary_embeddings: + # Calculate the complex frequencies (https://blog.eleuther.ai/rotary-embeddings/) + # `exp(i * n * a) = cos(n * a) + i sin(n * a)`, + # `a = theta ** - (2 * (channel // 2) / kv_channels)`, + # where n is the position in the sequence. + kv_channels = config.n_embd / config.n_head + angles = torch.outer( + torch.arange(config.max_position_embeddings, dtype=torch.float32), + torch.exp( + config.rotary_embedding_scale + * torch.arange(0, 1, 2 / kv_channels, dtype=torch.float32) + ), + ) + self._rotary_embedding_frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :] self.drop = nn.Dropout(config.embd_pdrop) self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) @@ -595,20 +658,32 @@ def forward( elif position_ids is None: position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # Rotary frequencies + rotary_embedding_frequencies_q = None + rotary_embedding_frequencies_k = None + if self.config.use_rotary_embeddings: + rotary_embedding_frequencies_q = self._rotary_embedding_frequencies[:, past_length : past_length + input_shape[-1]].to(device=device) + rotary_embedding_frequencies_k = self._rotary_embedding_frequencies[:, :past_length + input_shape[-1], :, :].to(device=device) # Self-attention mask. 
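+        # The causal mask below has shape (batch_size, query_length, key_length) before the head
+        # dimensions are added. When `attention_window_size` is set, keys that are `attention_window_size`
+        # or more positions behind their query are masked out as well (sliding-window attention).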
query_length = input_shape[-1] key_length = past_length + query_length self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] + # Sliding window attention + if self.config.attention_window_size is not None: + self_attention_mask.triu_(-self.config.attention_window_size + 1) if attention_mask is not None: self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to( dtype=torch.bool, device=self_attention_mask.device ) - # MQA models: (batch_size, query_length, n_heads, key_length) - # MHA models: (batch_size, n_heads, query_length, key_length) - attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) + # Attention-shape: (batch_size, query_length, n_heads, key_length) + attention_mask = self_attention_mask.unsqueeze(2) + if self.kv_heads > 1: + # (batch_size, self.kv_heads, query_length, heads_per_group, key_length) + attention_mask = attention_mask.unsqueeze(1) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -620,7 +695,7 @@ def forward( if encoder_attention_mask.dim() == 2: encoder_attention_mask.unsqueeze(1) assert encoder_attention_mask.dim() == 3 - encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1) + encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2) else: encoder_attention_mask = None @@ -632,8 +707,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds + if self.config.use_position_embeddings: + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + else: + hidden_states = inputs_embeds if token_type_ids is not None: token_type_embeds = self.wte(token_type_ids) @@ -656,7 +734,7 @@ def forward( def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, use_cache, output_attentions) + return module(*inputs, use_cache, output_attentions, rotary_embedding_frequencies_q, rotary_embedding_frequencies_k) return custom_forward @@ -679,6 +757,8 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, use_cache=use_cache, output_attentions=output_attentions, + rotary_embedding_frequencies_q=rotary_embedding_frequencies_q, + rotary_embedding_frequencies_k=rotary_embedding_frequencies_k ) hidden_states = outputs[0] diff --git a/src/transformers/models/gpt_bigcode/push_checkpoints.py b/src/transformers/models/gpt_bigcode/push_checkpoints.py new file mode 100644 index 0000000000..3abd551b71 --- /dev/null +++ b/src/transformers/models/gpt_bigcode/push_checkpoints.py @@ -0,0 +1,106 @@ +import os +import argparse +import re +import subprocess +from pathlib import Path + +from huggingface_hub import Repository + +from transformers.models.gpt_bigcode.convert_fast_llm_checkpoint import main as convert + + +""" +Script to upload Fast-llm checkpoints to a HF repo on the Hub. The script clones/creates a repo on the Hub, checks out +a branch `--branch_name`, and converts each `iter_` checkpoint and saves it as a commit on that branch. 
+""" + + +def get_iter_number(iter_dir: str): + m = re.match(r"(\d+)", iter_dir) + if m is not None: + return int(m.group(1)) + else: + raise ValueError(f"Invalid directory name: {iter_dir}") + + +def main(argv=None): + parser = argparse.ArgumentParser() + parser.add_argument("--exp_dir", type=Path, required=True, help="Path to experiment folder.") + parser.add_argument("--repo_name", required=True, help="Name of repository on the Hub in 'ORG/NAME' format.") + parser.add_argument("--branch_name", required=True, help="Name of branch in repository to save experiments.") + parser.add_argument( + "--save_dir", + type=Path, + help="Path where repository is cloned to locally. Will use {exp_dir}/hf_checkpoints if not provided", + ) + parser.add_argument( + "--iter_interval", + type=int, + default=1, + help="Iteration number must be divisble by iter_interval in order to be pushed", + ) + parser.add_argument( + "--iters", + type=int, + nargs='+', + default=None, + help="Specify a list of iterations to push. If None (default), will potentially push all the checkpoints (subject to iter_interval)", + ) + parser.add_argument( + "--tokenizer", + type=str, + default=None, + help="Path to tokenizer file to commit before the checkoints.", + ) + parser.add_argument( + "--push_past_iters", + action="store_true", + default=False, + help="If True, also push iterations that are lower than the last commit.", + ) + args, argv = parser.parse_known_args(argv) + + save_dir = args.save_dir or args.exp_dir / "hf_checkpoints" + + hf_repo = Repository(save_dir, clone_from=args.repo_name) + hf_repo.git_checkout(args.branch_name, create_branch_ok=True) + + # Pull latest changes + env = os.environ.copy() + env["GIT_LFS_SKIP_SMUDGE"] = "1" + git_pull_output = subprocess.run(["git", "pull"], cwd=save_dir, capture_output=True, env=env) + print(git_pull_output) + + # Find last checkpoint that was uploaded + head_hash = hf_repo.git_head_hash() + commit_msg = subprocess.check_output(["git", "show", "-s", "--format=%B", head_hash], cwd=save_dir).decode() + try: + last_commit_iter = get_iter_number(commit_msg.strip()) + print(f"Last commit iteration: {last_commit_iter}") + except ValueError: + last_commit_iter = -1 + + # The checkpoint dirs should be in ascending iteration order, so that the last commit corresponds to the latest checkpoint + ckpt_dirs = [x for x in (args.exp_dir / "checkpoints").iterdir() if re.match(r"(\d+)", x.name) and x.is_dir()] + if args.iters is not None: + args.iters = [int(n) for n in args.iters] + ckpt_dirs = [p for p in ckpt_dirs if get_iter_number(p.name) in args.iters] + ckpt_dirs = sorted(ckpt_dirs, key=lambda p: get_iter_number(p.name)) + print(f"Found the following checkpoints: {ckpt_dirs}") + + if args.tokenizer is not None: + raise NotImplementedError("Push tokenizer not implemented yet") + + for ckpt_dir in ckpt_dirs: + iter_number = get_iter_number(ckpt_dir.name) + if not args.push_past_iters and iter_number <= last_commit_iter: + continue + if iter_number % args.iter_interval == 0: + print(f"Converting iteration {iter_number}") + convert(argv + [f"--save_dir={str(save_dir)}", f"--checkpoint_dir={ckpt_dir}"]) + print(f"Pushing iteration {iter_number}") + hf_repo.push_to_hub(commit_message=f"{ckpt_dir.name}") + + +if __name__ == "__main__": + main() \ No newline at end of file