For visibility: conversion scripts from fast-llm #29

Open · wants to merge 24 commits into huggingface_main
15 changes: 14 additions & 1 deletion src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py
@@ -14,6 +14,7 @@
# limitations under the License.
""" GPTBigCode configuration"""

import math
from ...configuration_utils import PretrainedConfig
from ...utils import logging

@@ -105,6 +106,7 @@ def __init__(
n_embd=768,
n_layer=12,
n_head=12,
head_groups=None,
n_inner=None,
activation_function="gelu_pytorch_tanh",
resid_pdrop=0.1,
@@ -119,6 +121,10 @@
attention_softmax_in_fp32=True,
scale_attention_softmax_in_fp32=True,
multi_query=True,
use_rotary_embeddings=False,
rotary_embedding_scale=-math.log(10000),  # ≈ -9.210
use_position_embeddings=None,
attention_window_size=None,
**kwargs,
):
self.vocab_size = vocab_size
@@ -137,7 +143,14 @@ def __init__(
self.use_cache = use_cache
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
self.multi_query = multi_query
self.use_rotary_embeddings = use_rotary_embeddings
self.rotary_embedding_scale = rotary_embedding_scale
self.use_position_embeddings = use_position_embeddings if use_position_embeddings is not None else not use_rotary_embeddings
self.attention_window_size = attention_window_size
if head_groups is None:
self.head_groups = 1 if multi_query else n_head
else:
self.head_groups = head_groups

self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
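A minimal sketch of how the new configuration fields resolve (illustrative values only; GPTBigCodeConfig is the class patched above):

from transformers.models.gpt_bigcode import GPTBigCodeConfig

# With multi_query=True and no explicit head_groups, the config falls back to a
# single key/value head group; position embeddings default to the opposite of
# use_rotary_embeddings.
config = GPTBigCodeConfig(n_head=12, multi_query=True, use_rotary_embeddings=True)
assert config.head_groups == 1
assert config.use_position_embeddings is False

# An explicit head_groups value overrides the multi_query-based default
# (grouped-query attention style).
config = GPTBigCodeConfig(n_head=12, head_groups=4, multi_query=False)
assert config.head_groups == 4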
162 changes: 162 additions & 0 deletions src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py
@@ -0,0 +1,162 @@
import argparse
import os
from pathlib import Path
import re

import torch
from transformers.models.gpt_bigcode.merge_fast_llm_checkpoint import merge_checkpoint
from transformers.models.gpt_bigcode import GPTBigCodeConfig


def convert_fast_llm_checkpoint(state_dict, config, set_attn_dense_bias_zero, set_mlp_2_bias_zero, version=1):
if set_attn_dense_bias_zero:
print("Will set attention output layer biases to zero")
if set_mlp_2_bias_zero:
print("Will set MLP layer-2 biases to zero")
# The converted output state dict.
output_state_dict = {}
if "window_size" in config:
attention_window_size = config["window_size"]
else:
attention_window_size = config.get("attention_window_size", None)

config = GPTBigCodeConfig(
architectures=["GPTBigCodeLMHeadModel"],
vocab_size=config["vocab_size"],
n_positions=config["max_position_embeddings"],
n_embd=config["hidden_size"],
n_layer=config["num_layers"],
n_head=config["num_attention_heads"],
head_groups=config.get("head_groups", None),
n_inner=config["ffn_hidden_size"],
activation_function="gelu", # TODO
multi_query=True, # TODO
resid_pdrop=0.1,
embd_pdrop=0.1,
attn_pdrop=0.1,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
summary_type="cls_index",
summary_use_proj=True,
summary_activation=None,
summary_proj_to_labels=True,
summary_first_dropout=0.1,
scale_attn_weights=True,
use_cache=True,
bos_token_id=0, # TODO: can we remove these?
eos_token_id=0,
attention_softmax_in_fp32=True,
scale_attention_softmax_in_fp32=True,
use_rotary_embeddings=config["use_rotary_embeddings"],
rotary_embedding_scale=config["rotary_embedding_scale"],
use_position_embeddings=config["use_position_embeddings"],
attention_window_size=attention_window_size
)

# Truncate the word embeddings to the vocab-size
u = "_" if version == 0 else ""  # version-0 checkpoints prefix parameter names with an underscore
word_embeddings = state_dict.pop(f"{u}layers.0.{u}word_embeddings_weight")[:config.vocab_size, :]
output_state_dict["transformer.wte.weight"] = word_embeddings
if config.use_position_embeddings:
output_state_dict["transformer.wpe.weight"] = state_dict.pop(f"{u}layers.0.{u}position_embeddings_weight")

# Layer-0 is the word/position embeddings
# Layers 1 to n_layer need to be re-mapped from 0 to n_layer-1.
# _layers.{layer_index}.{op}.{w/b}

# Concatenate QKV matrix
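# The fused c_attn parameter packs the query projection first, then the (grouped)
# key/value projection, along the output dimension (dim 0 of the weight/bias).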
for layer_index in range(1, config.n_layer + 1):
for weight_or_bias in ["weight", "bias"]:
query = state_dict.pop(f"{u}layers.{layer_index}.self_attn.query.{weight_or_bias}")
key_value = state_dict.pop(f"{u}layers.{layer_index}.self_attn.key_value.{weight_or_bias}")
output_state_dict[f"transformer.h.{layer_index - 1}.attn.c_attn.{weight_or_bias}"] = torch.cat([query, key_value], dim=0)

# The simple map of names for "automated" rules.
name_map = {
f"{u}mlp.{u}layer_1": "mlp.c_fc",
f"{u}mlp.{u}layer_2": "mlp.c_proj",
"layer_norm_1": "ln_1",
"layer_norm_2": "ln_2",
# "attention.dense": "attn.c_proj",
"self_attn.dense": "attn.c_proj",
# "self_attention.query_key_value": "attn.c_attn",
}
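# Example: with the version-0 prefix, "_layers.3.self_attn.dense.weight" is mapped
# to "transformer.h.2.attn.c_proj.weight" (layer indices shift down by one because
# layer 0 holds the embeddings).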
# Extract the other ops
layer_re = re.compile(rf"{u}layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
for name, value in state_dict.items():
m = layer_re.match(name)
assert m is not None, f"Invalid layer name: {name}"

# The index of the layer.
layer_index = int(m.group(1))
# The name of the operation.
op_name = m.group(2)
# Is it a weight or a bias?
weight_or_bias = m.group(3)

# Final layernorm
if op_name == "final_layernorm":
assert layer_index == config.n_layer + 1
output_state_dict[f"transformer.ln_f.{weight_or_bias}"] = value
# Bias was not used in training for InputParallel layers
elif op_name == "self_attn.dense" and weight_or_bias == "bias" and set_attn_dense_bias_zero:
output_state_dict[f"transformer.h.{layer_index-1}.{name_map[op_name]}.{weight_or_bias}"] = torch.zeros_like(value)
# MLP layer-2 is also InputParallel
elif op_name == f"{u}mlp.{u}layer_2" and weight_or_bias == "bias" and set_mlp_2_bias_zero:
output_state_dict[f"transformer.h.{layer_index-1}.{name_map[op_name]}.{weight_or_bias}"] = torch.zeros_like(value)
else:
output_state_dict[f"transformer.h.{layer_index-1}.{name_map[op_name]}.{weight_or_bias}"] = value

# For the LM head, transformers expects the weight to be tied to the word embeddings.
output_state_dict["lm_head.weight"] = word_embeddings

return output_state_dict, config


def main(argv=None):
parser = argparse.ArgumentParser()
parser.add_argument(
"--checkpoint_dir",
type=Path,
help="Path to the experiment directory",
)
parser.add_argument(
"--save_dir",
type=Path,
help="Path where the converted model is saved"
)
parser.add_argument(
"--set_attn_dense_bias_zero",
action='store_true',
default=False,
help="Set the attention output layer bias to zero and ignore the value from the checkpoint. Shouldn't be used except to fix a bug from training."
)
parser.add_argument(
"--set_mlp_2_bias_zero",
action='store_true',
default=False,
help="Set the MLP second layer bias to zero and ignore the value from the checkpoint. Shouldn't be used except to fix a bug from training."
)

args = parser.parse_args(argv)

state_dict, config = merge_checkpoint(
args.checkpoint_dir,
dummy_experiment_dir=None
)

output_state_dict, output_config = convert_fast_llm_checkpoint(state_dict, config, args.set_attn_dense_bias_zero, args.set_mlp_2_bias_zero)

print("Saving config")
save_dir = args.save_dir or args.checkpoint_dir / "converted"
output_config.save_pretrained(save_dir)

# Store the state_dict to file.
output_checkpoint_file = os.path.join(save_dir, "pytorch_model.bin")
print(f'Saving checkpoint to "{output_checkpoint_file}"')
torch.save(output_state_dict, output_checkpoint_file)
print("Done!")


if __name__ == "__main__":
main()
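A possible way to drive the converter programmatically (placeholder paths; the flags mirror the argparse options defined above):

from transformers.models.gpt_bigcode.convert_fast_llm_checkpoint import main

# --checkpoint_dir points at experiment_dir/checkpoints/{iteration}; the converted
# config and pytorch_model.bin are written to --save_dir.
main([
    "--checkpoint_dir", "/experiments/my_run/checkpoints/10000",
    "--save_dir", "/experiments/my_run/converted",
])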
134 changes: 134 additions & 0 deletions src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py
@@ -0,0 +1,134 @@
import re
from tqdm import tqdm
from pathlib import Path

import numpy as np
import torch
import yaml


def get_all_checkpoint_paths(experiment_path):
checkpoints = (Path(experiment_path) / "checkpoints").glob("*")
# Sort checkpoints by iteration number
checkpoints = sorted(checkpoints, key=lambda x: int(x.name))
return [get_checkpoint_paths(checkpoint) for checkpoint in checkpoints]


def get_checkpoint_paths(checkpoint_dir: Path):
return [c_name for c_name in checkpoint_dir.glob("*") if re.match(r"\d+", c_name.name)]


def extract_stage_shards(state):
# Extract the weight shard and split it into the stage shards
# Reproduce the split done in MultiStageModelBase.setup
total_shard_size = sum(state['stage_shard_sizes'])
if len(state['shard'].shape) == 1:
# Flat buffer
weight_shard = state['shard'][:total_shard_size]
elif len(state['shard'].shape) == 2:
# 2D buffer
weight_shard = state['shard'][0]
else:
raise ValueError(f"Unrecognized buffer shape {state['shard'].shape}")
return weight_shard.split(state['stage_shard_sizes'])


def extract_individual_weights(merged_stage_shard, stage_content):
# Get individual weights from shards that are merged across data-parallel
weights_numel = [np.prod(weight_meta['shape']) for weight_meta in stage_content]
weights = merged_stage_shard[:sum(weights_numel)].split(weights_numel)
return [weight.reshape(weight_meta['shape']) for weight, weight_meta in zip(weights, stage_content)]


def concatenate_tp_shards(stage_tp_shards, stage_content):
# Concatenate the tp-shards in a given stage
# stage_tp_shards contains the individual weight shards for each TP rank:
# [[weight1, weight2, ...] for rank in range(tp_size)]
concatenated_weights = []
# Concatenate each individual weight along their TP dimension if they have one.
for weight_tp_shards, weight_meta in zip(zip(*stage_tp_shards), stage_content):
if weight_meta["tensor_parallel_dim"] is not None:
weight = torch.cat(weight_tp_shards, dim=weight_meta["tensor_parallel_dim"])
else:
weight = weight_tp_shards[0]
concatenated_weights.append(weight)
return concatenated_weights


def merge_checkpoint(checkpoint_dir: Path, dummy_experiment_dir=None):
"""Load a fast-llm checkpoint and merge the data, tensor, and pipeline-parallel shards"""
# checkpoint_dir=experiment_dir/checkpoints/{iteration}
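# Each per-rank file is expected to hold (as accessed below):
#   "shard":             the flat (1D) or stacked (2D) weight buffer of the rank
#   "stage_shard_sizes": split sizes of the weight shard, one per stage
#   "stages":            per-stage metadata with "on_device", "tp_rank", "fsdp_rank"
#                        and "content" (name/shape/tensor_parallel_dim per weight)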
experiment_dir = checkpoint_dir.parent.parent
checkpoint_paths = get_checkpoint_paths(checkpoint_dir)
config = yaml.safe_load((experiment_dir / "config.yaml").read_text())

# Load the states from all the ranks
states = {
int(c_name.name): torch.load(c_name)
for c_name in tqdm(checkpoint_paths)
}
num_stages = len(states[0]["stages"])
tensor_parallel = config["tensor_parallel"]
data_parallel_size = int(config["world_size"] / (tensor_parallel * config["pipeline_parallel"]))

if dummy_experiment_dir is not None:
# Use the meta from the dummy checkpoint, and the shard from the actual checkpoint
dummy_checkpoint_paths = get_all_checkpoint_paths(dummy_experiment_dir)
dummy_states = {
int(c_name.name): torch.load(c_name)
for c_name in tqdm(dummy_checkpoint_paths[-1])
}
for rank, state in dummy_states.items():
state['shard'] = states[rank]['shard']
states = dummy_states

# Gather the data-parallel shards
# {tp_rank: [[stage_0_shard_0, stage_0_shard_1, ...], [stage_1_shard_0, ...], ...]}
# {tp_rank: [{fsdp_rank: shard}, ...]}
fsdp_shards = {
i: [[None for _ in range(data_parallel_size)] for _ in range(num_stages)]
for i in range(tensor_parallel)
}

for rank, state in states.items():
on_device_stage_shards = extract_stage_shards(state)
on_device_stage_indices = [i for (i, stage_meta) in enumerate(state["stages"]) if stage_meta["on_device"]]
for stage_index, stage_shard in zip(on_device_stage_indices, on_device_stage_shards):
stage_meta = state["stages"][stage_index]
# fsdp_shards[stage_meta["tp_rank"]][stage_index].append((stage_meta, stage_shard))
fsdp_shards[stage_meta["tp_rank"]][stage_index][stage_meta["fsdp_rank"]] = stage_shard

# Concatenate the data-parallel shards
# and get individual weights
dp_concatenated_shards = {
tp_rank: [
extract_individual_weights(
torch.cat(stage_shards, dim=0),
states[0]["stages"][stage_index]['content']
)
for stage_index, stage_shards in enumerate(fsdp_shards[tp_rank])
]
for tp_rank in range(config["tensor_parallel"])
}

# In the tensor-parallel case, concatenate the TP tensors along their TP dimensions.
tp_concatenated_shards = []
for stage_index, stage_tp_shards in enumerate(zip(*(dp_concatenated_shards[i] for i in range(tensor_parallel)))):
stage_content = states[0]["stages"][stage_index]["content"]
tp_concatenated_shards.append(concatenate_tp_shards(stage_tp_shards, stage_content))

# In the pipeline-parallel case, merge the stages
state_dict = {
weight_meta["name"]: weight
for stage_meta, stage_weights in zip(states[0]["stages"], tp_concatenated_shards)
for weight_meta, weight in zip(stage_meta["content"], stage_weights)
}

print(f"Total number of parameters: {sum([weight.numel() for weight in state_dict.values()])}")
return state_dict, config


if __name__ == "__main__":
merge_checkpoint("/toolkit_infiniband_example_checkpoints/ngc_checkpoints/sc2_ablations/1B_repo_context_Top-level-Depth-first_pp2_64k_64k_2023_10_17_16_35_27/",
dummy_experiment_dir="/toolkit_infiniband_example_checkpoints/ngc_checkpoints/sc2_ablations/dev_1B_repo_context_Random_pp2_64k_64k_2023_10_18_22_20_36/")
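For reference, a tiny self-contained sketch of the splitting step (toy shapes and names, not taken from a real checkpoint):

import torch
from transformers.models.gpt_bigcode.merge_fast_llm_checkpoint import extract_individual_weights

# Metadata for two toy weights in one stage: a 4-element vector and a 4x4 matrix.
stage_content = [
    {"name": "_layers.1.layer_norm_1.weight", "shape": [4], "tensor_parallel_dim": None},
    {"name": "_layers.1.self_attn.query.weight", "shape": [4, 4], "tensor_parallel_dim": 0},
]
# A merged data-parallel shard holding both weights back to back (padding may follow).
merged_shard = torch.arange(24, dtype=torch.float32)
weights = extract_individual_weights(merged_shard, stage_content)
assert [tuple(w.shape) for w in weights] == [(4,), (4, 4)]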
