From 9213192debe6bb450b2a8e0877eaaf8a2b9a5190 Mon Sep 17 00:00:00 2001
From: Xu Song
Date: Fri, 27 Sep 2024 16:39:54 +0800
Subject: [PATCH] Support safetensors export

---
 deepspeed/utils/zero_to_fp32.py | 68 ++++++++++++++++++++++++++-------
 1 file changed, 54 insertions(+), 14 deletions(-)

diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py
index 24cc342e78d1..2a428a6ababd 100755
--- a/deepspeed/utils/zero_to_fp32.py
+++ b/deepspeed/utils/zero_to_fp32.py
@@ -10,7 +10,10 @@
 # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
 # application.
 #
-# example: python zero_to_fp32.py . pytorch_model.bin
+# example:
+# python zero_to_fp32.py . output_dir/
+# or
+# python zero_to_fp32.py . output_dir/ --safe_serialization
 
 import argparse
 import torch
@@ -18,6 +21,8 @@
 import math
 import os
 import re
+import json
+from tqdm import tqdm
 from collections import OrderedDict
 from dataclasses import dataclass
 
@@ -27,6 +32,9 @@
 from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
                                             FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
                                             FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_NAME
+from transformers.modeling_utils import shard_checkpoint
+from safetensors.torch import save_file
 
 
 @dataclass
@@ -139,7 +147,6 @@ def parse_model_states(files):
 
 
 def parse_optim_states(files, ds_checkpoint_dir):
-
     total_files = len(files)
     state_dicts = []
     for f in files:
@@ -420,12 +427,10 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
     offset = 0
     total_numel = 0
     total_params = 0
-    for name, shape in param_shapes.items():
-
+    for name, shape in tqdm(param_shapes.items(), desc='Gather Sharded Weights'):
         unpartitioned_numel = shape.numel()
         total_numel += unpartitioned_numel
         total_params += 1
-
         partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
 
         if debug:
@@ -521,21 +526,41 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_f
     return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
 
 
-def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False):
+def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_dir,
+                                               max_shard_size="5GB", safe_serialization=False,
+                                               tag=None, exclude_frozen_parameters=False):
     """
     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
     loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
 
     Args:
         - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
-        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
+        - ``output_dir``: directory to the pytorch fp32 state_dict output files
+        - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
+        - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
         - ``exclude_frozen_parameters``: exclude frozen parameters
     """
-
     state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
-    print(f"Saving fp32 state dict to {output_file}")
-    torch.save(state_dict, output_file)
+    print(f"Saving fp32 state dict to {output_dir}")
+    weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+    shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
+
+    # Save the model
+    for shard_file, shard in shards.items():
+        output_path = os.path.join(output_dir, shard_file)
+        if safe_serialization:
+            save_file(shard, output_path, metadata={"format": "pt"})
+        else:
+            torch.save(shard, output_path)
+
+    # Save the index as well
+    if index is not None:
+        save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+        save_index_file = os.path.join(output_dir, save_index_file)
+        with open(save_index_file, "w", encoding="utf-8") as f:
+            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+            f.write(content)
 
 
 def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
@@ -578,15 +603,28 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
 
 
 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser()
     parser.add_argument("checkpoint_dir",
                         type=str,
                         help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
     parser.add_argument(
-        "output_file",
+        "output_dir",
         type=str,
-        help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
+        help="directory to the pytorch fp32 state_dict output files "
+        "(e.g. path/checkpoint-12-output/)")
+    parser.add_argument(
+        "--max_shard_size",
+        type=str,
+        default="5GB",
+        help="The maximum size for a checkpoint before being sharded. Checkpoint shards will then each be of a size "
+        "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`). "
+        "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances "
+        "without CPU OOM issues.")
+    parser.add_argument(
+        "--safe_serialization",
+        default=False,
+        action='store_true',
+        help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
     parser.add_argument("-t",
                         "--tag",
                         type=str,
@@ -599,6 +637,8 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
     debug = args.debug
 
     convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
-                                               args.output_file,
+                                               args.output_dir,
+                                               max_shard_size=args.max_shard_size,
+                                               safe_serialization=args.safe_serialization,
                                                tag=args.tag,
                                                exclude_frozen_parameters=args.exclude_frozen_parameters)
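
Note: as a quick sanity check of the new export path, the sharded safetensors output written by the patched script can be loaded back into a plain state dict. The snippet below is a minimal sketch and not part of the patch; it assumes the checkpoint was converted with --safe_serialization into a hypothetical output_dir/, that safetensors is installed, and that the files follow the transformers naming used above (model.safetensors / model.safetensors.index.json).

# Minimal sketch (illustrative, not part of the patch): rebuild a full state dict
# from the sharded safetensors export produced by the patched zero_to_fp32.py.
import json
import os

from safetensors.torch import load_file

output_dir = "output_dir"  # hypothetical path used in the usage examples above
index_path = os.path.join(output_dir, "model.safetensors.index.json")

if os.path.exists(index_path):
    # Sharded export: the index maps each parameter name to the shard that holds it.
    with open(index_path, "r", encoding="utf-8") as f:
        weight_map = json.load(f)["weight_map"]
    state_dict = {}
    for shard_file in sorted(set(weight_map.values())):
        state_dict.update(load_file(os.path.join(output_dir, shard_file)))
else:
    # Small models fit in a single shard, so no index file is written.
    state_dict = load_file(os.path.join(output_dir, "model.safetensors"))

# state_dict can now be passed to model.load_state_dict(...).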