Support safetensors export
xu-song authored Sep 27, 2024
1 parent d45cfd3 commit 9213192
Showing 1 changed file with 54 additions and 14 deletions.
68 changes: 54 additions & 14 deletions deepspeed/utils/zero_to_fp32.py
@@ -10,14 +10,19 @@
 # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
 # application.
 #
-# example: python zero_to_fp32.py . pytorch_model.bin
+# example:
+# python zero_to_fp32.py . output_dir/
+# or
+# python zero_to_fp32.py . output_dir/ --safe_serialization
 
 import argparse
 import torch
 import glob
 import math
 import os
 import re
+import json
+from tqdm import tqdm
 from collections import OrderedDict
 from dataclasses import dataclass
 
@@ -27,6 +32,9 @@
 from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
                                             FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
                                             FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_NAME
+from transformers.modeling_utils import shard_checkpoint
+from safetensors.torch import save_file
 
 
 @dataclass
@@ -139,7 +147,6 @@ def parse_model_states(files):


 def parse_optim_states(files, ds_checkpoint_dir):
-
     total_files = len(files)
     state_dicts = []
     for f in files:
@@ -420,12 +427,10 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
     offset = 0
     total_numel = 0
     total_params = 0
-    for name, shape in param_shapes.items():
-
+    for name, shape in tqdm(param_shapes.items(), desc='Gather Sharded Weights'):
         unpartitioned_numel = shape.numel()
         total_numel += unpartitioned_numel
         total_params += 1
-
         partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
 
         if debug:
@@ -521,21 +526,41 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_f
     return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
 
 
-def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False):
+def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_dir,
+                                               max_shard_size="5GB", safe_serialization=False,
+                                               tag=None, exclude_frozen_parameters=False):
     """
     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
     loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
 
     Args:
         - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
-        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
+        - ``output_dir``: directory for the pytorch fp32 state_dict output files
+        - ``max_shard_size``: the maximum size of a checkpoint before it is sharded; defaults to 5GB
+        - ``safe_serialization``: whether to save the model using ``safetensors`` or the traditional PyTorch way (which uses ``pickle``)
         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
         - ``exclude_frozen_parameters``: exclude frozen parameters
     """

     state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
-    print(f"Saving fp32 state dict to {output_file}")
-    torch.save(state_dict, output_file)
+    print(f"Saving fp32 state dict to {output_dir}")
+    weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
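+    # shard_checkpoint splits the state dict into pieces no larger than
+    # max_shard_size and returns an index mapping each tensor to its shard
+    # file; the index is None when a single shard suffices.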
+    shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
+
+    # Save the model
+    for shard_file, shard in shards.items():
+        output_path = os.path.join(output_dir, shard_file)
+        if safe_serialization:
+            save_file(shard, output_path, metadata={"format": "pt"})
+        else:
+            torch.save(shard, output_path)
+
+    # Save the index as well
+    if index is not None:
+        save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+        save_index_file = os.path.join(output_dir, save_index_file)
+        with open(save_index_file, "w", encoding="utf-8") as f:
+            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+            f.write(content)


 def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
@@ -578,15 +603,28 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):


 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser()
     parser.add_argument("checkpoint_dir",
                         type=str,
                         help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
     parser.add_argument(
-        "output_file",
+        "output_dir",
         type=str,
-        help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
+        help="directory to store the pytorch fp32 state_dict output files "
+        "(e.g. path/checkpoint-12-output/)")
+    parser.add_argument(
+        "--max_shard_size",
+        type=str,
+        default="5GB",
+        help="The maximum size of a checkpoint before it is sharded. Each checkpoint shard will then be "
+        "smaller than this size. If expressed as a string, it needs to be digits followed by a unit "
+        "(like `5MB`). We default to 5GB so that models can run easily on free-tier Google Colab "
+        "instances without CPU OOM issues.")
+    parser.add_argument(
+        "--safe_serialization",
+        default=False,
+        action='store_true',
+        help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
     parser.add_argument("-t",
                         "--tag",
                         type=str,
@@ -599,6 +637,8 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
     debug = args.debug
 
     convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
-                                               args.output_file,
+                                               args.output_dir,
+                                               max_shard_size=args.max_shard_size,
+                                               safe_serialization=args.safe_serialization,
                                                tag=args.tag,
                                                exclude_frozen_parameters=args.exclude_frozen_parameters)

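A note for readers consuming the new output: when more than one shard is written, the index file maps each tensor name to the shard that stores it. Below is a minimal sketch of loading a sharded safetensors export back into a single state dict. It assumes the transformers default file names that the imported constants resolve to (`model.safetensors`, `model.safetensors.index.json`); the helper name is hypothetical, not part of the patch.

import json
import os

from safetensors.torch import load_file


def load_sharded_safetensors(output_dir):
    """Merge a (possibly sharded) safetensors export back into one state dict."""
    index_path = os.path.join(output_dir, "model.safetensors.index.json")
    if not os.path.exists(index_path):
        # Single-shard case: no index file is written, just one weights file.
        return load_file(os.path.join(output_dir, "model.safetensors"))
    with open(index_path, encoding="utf-8") as f:
        index = json.load(f)
    state_dict = {}
    # "weight_map" maps every tensor name to the shard file that stores it;
    # load each distinct shard once and merge the results.
    for shard_file in sorted(set(index["weight_map"].values())):
        state_dict.update(load_file(os.path.join(output_dir, shard_file)))
    return state_dict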