Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP into support-deepseek-v3
yuanlehome committed Jan 24, 2025
2 parents a5a16d4 + 96856bd commit a8f3839
Showing 5 changed files with 111 additions and 33 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -95,6 +95,7 @@
| [ChatGLM3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm2) | THUDM/chatglm3-6b |
| [DeepSeekV2](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-V2, deepseek-ai/DeepSeek-V2-Chat, deepseek-ai/DeepSeek-V2-Lite, deepseek-ai/DeepSeek-V2-Lite-Chat, deepseek-ai/DeepSeek-Coder-V2-Base, deepseek-ai/DeepSeek-Coder-V2-Instruct, deepseek-ai/DeepSeek-Coder-V2-Lite-Base, deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct |
| [DeepSeekV3](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-V3, deepseek-ai/DeepSeek-V3-Base |
| [DeepSeek-R1](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-R1, deepseek-ai/DeepSeek-R1-Zero, deepseek-ai/DeepSeek-R1-Distill-Llama-70B, deepseek-ai/DeepSeek-R1-Distill-Llama-8B, deepseek-ai/DeepSeek-R1-Distill-Qwen-14B, deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B, deepseek-ai/DeepSeek-R1-Distill-Qwen-32B, deepseek-ai/DeepSeek-R1-Distill-Qwen-7B |
| [Gemma](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/gemma) | google/gemma-7b, google/gemma-7b-it, google/gemma-2b, google/gemma-2b-it |
| [Mistral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mistral) | mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-7B-v0.1 |
| [Mixtral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mixtral) | mistralai/Mixtral-8x7B-Instruct-v0.1 |
32 changes: 17 additions & 15 deletions paddlenlp/trainer/unified_checkpoint/load_local.py
@@ -149,14 +149,6 @@ def _remove_unused_keys(


def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoint, safe_serialization=False):
# Special process with split param.
if is_sharding_split_param_mode(args):
returned_optim_state_dict = load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint)
return returned_optim_state_dict

# init and get optimizer LR_Scheduler
returned_optim_state_dict = nested_copy(optimizer.state_dict())

if not safe_serialization:
index_filename, index_filename_master_weights = (
PADDLE_OPTIMIZER_INDEX_NAME,
@@ -165,6 +157,23 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin
else:
index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME

with open(os.path.join(resume_from_checkpoint, index_filename), "r") as f:
index = json.loads(f.read())

ckpt_quant_stage = "O0"
if "ckpt_quant_stage" in index:
ckpt_quant_stage = index["ckpt_quant_stage"]

# Special process with split param.
if is_sharding_split_param_mode(args):
returned_optim_state_dict = load_unified_optimizer_split_param(
args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage
)
return returned_optim_state_dict

# init and get optimizer LR_Scheduler
returned_optim_state_dict = nested_copy(optimizer.state_dict())

resolved_archive_file, sharded_metadata = get_optimizer_shard_files(
optimizer_path=resume_from_checkpoint,
index_filename=os.path.join(resume_from_checkpoint, index_filename),
Expand All @@ -184,13 +193,6 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin
if len(resolved_archive_file) > 1:
resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards")

with open(os.path.join(resume_from_checkpoint, index_filename), "r") as f:
index = json.loads(f.read())

ckpt_quant_stage = "O0"
if "ckpt_quant_stage" in index:
ckpt_quant_stage = index["ckpt_quant_stage"]

# update has_master_weights and index_filename_master_weights
# 1. if the master weight exists, only has_master_weights is set True and loaded when needed
# 2. if master weight does not exist, convert model weight to master weight when needed
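The reordering in load_local.py reads the optimizer index file, and its ckpt_quant_stage entry, before branching into the split-param path, so the quantization stage recorded at save time can be forwarded to load_unified_optimizer_split_param. A minimal sketch of the resulting control flow, assuming the helper names from the diff; the wrapper name resolve_stage_and_dispatch is illustrative only and error handling is omitted:

import json
import os

def resolve_stage_and_dispatch(args, model, optimizer, resume_from_checkpoint, index_filename):
    # Read the optimizer index first; "O0" means the checkpoint was saved without quantization.
    with open(os.path.join(resume_from_checkpoint, index_filename), "r") as f:
        index = json.load(f)
    ckpt_quant_stage = index.get("ckpt_quant_stage", "O0")

    # The split-param loader now receives the stage so it can de-quantize shards while loading.
    if is_sharding_split_param_mode(args):
        return load_unified_optimizer_split_param(
            args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage
        )
    # Otherwise fall through to the regular unified-optimizer loading path.
    return None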
71 changes: 59 additions & 12 deletions paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py
@@ -36,18 +36,25 @@
get_expected_state_dict,
get_optimizer_shard_files,
mapping_optimizer_tp_actions,
update_master_weight_status,
)

__all__ = ["gather_splited_param_for_optimizer", "load_unified_optimizer_split_param"]


def merge_splited_param(
state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False
state_dict,
partial_tensor_list,
param_shape_info,
send_table,
recv_table,
is_master_weights=False,
ckpt_quant_stage="O0",
):
"""Merge the splited param in sharding group."""
global_rank = dist.get_rank()
for key in list(state_dict.keys()):
if state_dict[key].numel().item() == 1: # for example: beta1, beta2
if int(state_dict[key].numel()) == 1: # for example: beta1, beta2
continue

static_name = key if is_master_weights else generate_base_static_name(key)[0]
@@ -89,10 +96,21 @@ def merge_splited_param(
)
dist.stream.send(tensor, dst=recv_rank)
state_dict.pop(key)

if ckpt_quant_stage != "O0":
for key in list(state_dict.keys()):
if int(state_dict[key].numel()) == 1: # for example: beta1, beta2
static_name = key if is_master_weights else generate_base_static_name(key)[0]
if static_name in partial_tensor_list:
recv_rank = recv_table[static_name]
send_info = send_table[static_name]
if global_rank != recv_rank:
state_dict.pop(key)

return state_dict


def gather_splited_param_for_optimizer(optimizer):
def gather_splited_param_for_optimizer(optimizer, ckpt_quant_stage="O0"):
hcg = fleet.get_hybrid_communicate_group()
sharding_group = hcg.get_sharding_parallel_group()
global_rank = dist.get_rank()
@@ -127,7 +145,7 @@ def gather_splited_param_for_optimizer(optimizer):
for key in list(optim_state_dict.keys()):
static_name, _ = generate_base_static_name(key)
if static_name in param_slice_info.keys():
if optim_state_dict[key].numel().item() == 1: # for example: beta1, beta2
if int(optim_state_dict[key].numel()) == 1: # for example: beta1, beta2
continue
begin, end = param_slice_info[static_name]
shape, numel, _, _ = param_shape_info[static_name]
@@ -149,13 +167,15 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
recv_table[key] = sharding_ranklist[0][0] # which sharding_rank to recv the splited tensor
send_table[key] = [(rank, begin, end) for rank, begin, end in sharding_ranklist]

merge_splited_param(optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False)
merge_splited_param(
optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False, ckpt_quant_stage
)
if master_weights is not None:
merge_splited_param(master_weights, partial_tensor_list, param_shape_info, send_table, recv_table, True)
return optim_state_dict, master_weights


def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint):
def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
returned_optim_state_dict = nested_copy(optimizer.state_dict())

index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME
@@ -208,6 +228,10 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
if len(resolved_archive_file) > 1:
resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards")

has_master_weights, index_filename_master_weights = update_master_weight_status(
args, optimizer, has_master_weights, safe_serialization=True
)

if has_master_weights:
returned_optim_state_dict["master_weights"] = {}
resolved_archive_file_mw, sharded_metadata_mw = get_optimizer_shard_files(
@@ -217,7 +241,9 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
if len(resolved_archive_file_mw) > 1:
resolved_archive_file_mw = tqdm(resolved_archive_file_mw, desc="Loading master weights shards")

def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False):
def load_resolved_archive_file(
resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False, ckpt_quant_stage="O0"
):
returned_state_dict = {}

if model.config.tensor_parallel_degree > 1:
@@ -232,24 +258,38 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]):
continue
if model.config.tensor_parallel_degree > 1:
state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="cpu")
state_dict = load_state_dict(
shard_file,
tp_actions,
expected_keys,
device="cpu",
ckpt_quant_stage=ckpt_quant_stage,
)
else:
state_dict = load_state_dict(shard_file, None, expected_keys, device="cpu")
state_dict = load_state_dict(
shard_file,
None,
expected_keys,
device="cpu",
ckpt_quant_stage=ckpt_quant_stage,
)
returned_state_dict.update(state_dict)
del state_dict
gc.collect()

return returned_state_dict

# get tp params
state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys_optim)
state_dict_optim = load_resolved_archive_file(
resolved_archive_file, sharded_metadata, expected_keys_optim, ckpt_quant_stage=ckpt_quant_stage
)

# need to split param for different sharding rank, maybe need to deal with oom issue.
for key in list(state_dict_optim.keys()):
key_name = key.split("/")
static_name = struct2static_name_mappings.get(key_name[0], None)

if state_dict_optim[key].numel().item() > 1:
if int(state_dict_optim[key].numel()) > 1:
begin, end = param_slice_info[static_name]
shape, numel, index, padded_size = param_shape_info[static_name]
state_dict_optim[key] = state_dict_optim[key].reshape([-1])
Expand Down Expand Up @@ -284,7 +324,7 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected

for key in list(state_dict_master_weight.keys()):
static_name = struct2static_name_mappings.get(key, None)
if state_dict_master_weight[key].numel().item() > 1:
if int(state_dict_master_weight[key].numel()) > 1:
begin, end = param_slice_info[static_name]
shape, numel, index, padded_size = param_shape_info[static_name]
state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1])
@@ -303,6 +343,13 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
paddle.framework._current_expected_place(), False
)
returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key)

# master weight cast (only in remove_master_weight)
if returned_optim_state_dict["master_weights"][static_name].dtype != paddle.float32:
returned_optim_state_dict["master_weights"][static_name] = paddle.cast(
returned_optim_state_dict["master_weights"][static_name], dtype=paddle.float32
)

returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])

return returned_optim_state_dict
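
In sharding_split_param_utils.py the new ckpt_quant_stage argument is threaded through gather_splited_param_for_optimizer, merge_splited_param and load_unified_optimizer_split_param, and master weights loaded from the checkpoint are cast back to float32 when they were stored in a lower precision. A hedged sketch of that cast in isolation; the helper name ensure_fp32_master_weight is illustrative, not part of the diff:

import paddle

def ensure_fp32_master_weight(tensor):
    # Optimizer master weights are kept in float32; cast only when the stored dtype
    # differs, e.g. when they were reconstructed from bf16/fp16 model weights.
    if tensor.dtype != paddle.float32:
        tensor = paddle.cast(tensor, dtype=paddle.float32)
    return tensor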
4 changes: 3 additions & 1 deletion paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
@@ -344,7 +344,9 @@ def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
return

if is_sharding_split_param_mode(self.args):
optim_state_dict, master_weights = gather_splited_param_for_optimizer(optimizer)
optim_state_dict, master_weights = gather_splited_param_for_optimizer(
optimizer, self.args.ckpt_quant_stage if "quant_reach_limit" not in infohub else "O0"
)
else:
optim_state_dict = nested_copy(optimizer.state_dict())
master_weights = None
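In unified_checkpoint.py the save path now passes a quantization stage to gather_splited_param_for_optimizer, falling back to "O0" (no quantization) once infohub records that the quantization budget has been reached. A small sketch of that call site, assuming self.args.ckpt_quant_stage and the infohub flag behave as in the diff:

# Quantize gathered optimizer state only while "quant_reach_limit" is absent from infohub.
quant_stage = self.args.ckpt_quant_stage if "quant_reach_limit" not in infohub else "O0"
optim_state_dict, master_weights = gather_splited_param_for_optimizer(optimizer, quant_stage)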
36 changes: 31 additions & 5 deletions paddlenlp/transformers/deepseek_v2/modeling.py
@@ -567,7 +567,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):


class DeepseekV2MLP(nn.Layer):
def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None):
def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None, is_moe=False):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
@@ -580,7 +580,7 @@ def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear

if config.tensor_parallel_degree > 1:
if config.tensor_parallel_degree > 1 and not is_moe:
self.gate_proj = ColumnParallelLinear(
self.hidden_size,
self.intermediate_size,
@@ -753,14 +753,14 @@ def __init__(self, config):
self.ep_rank = 0
self.experts = nn.LayerList(
[
DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size)
DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size, is_moe=True)
for i in range(config.n_routed_experts)
]
)
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size)
self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size, is_moe=True)

def forward(self, hidden_states):
identity = hidden_states
@@ -1158,7 +1158,8 @@ def _get_name_mappings(cls, config: DeepseekV2Config) -> list[StateDictNameMappi
["embed_tokens.weight"],
["norm.weight"],
]
for layer_index in range(config.num_hidden_layers):
# last one layer contains MTP (eagle) parameters for inference
for layer_index in range(config.num_hidden_layers + config.num_nextn_predict_layers):
layer_mappings = [
[f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"],
[f"layers.{layer_index}.self_attn.q_a_proj.weight", None, "transpose"],
@@ -1178,6 +1179,7 @@ def _get_name_mappings(cls, config: DeepseekV2Config) -> list[StateDictNameMappi

# MoE parameters
model_mappings.append([f"layers.{layer_index}.mlp.gate.weight", None, "transpose"])
model_mappings.append([f"layers.{layer_index}.mlp.gate.e_score_correction_bias"])
for expert_idx in range(config.n_routed_experts):
expert_mappings = [
[f"layers.{layer_index}.mlp.experts.{expert_idx}.gate_proj.weight", None, "transpose"],
@@ -1189,6 +1191,15 @@ def _get_name_mappings(cls, config: DeepseekV2Config) -> list[StateDictNameMappi
model_mappings.append([f"layers.{layer_index}.mlp.shared_experts.up_proj.weight", None, "transpose"])
model_mappings.append([f"layers.{layer_index}.mlp.shared_experts.down_proj.weight", None, "transpose"])

# MTP (eagle) parameters for inference
if layer_index >= config.num_hidden_layers:
model_mappings.append([f"layers.{layer_index}.embed_tokens.weight"])
model_mappings.append([f"layers.{layer_index}.enorm.weight"])
model_mappings.append([f"layers.{layer_index}.hnorm.weight"])
model_mappings.append([f"layers.{layer_index}.eh_proj.weight", None, "transpose"])
model_mappings.append([f"layers.{layer_index}.shared_head.norm.weight"])
model_mappings.append([f"layers.{layer_index}.shared_head.head.weight", None, "transpose"])

init_name_mappings(mappings=model_mappings)
if cls.base_model_class.__name__ not in config.architectures:
for mapping in model_mappings:
@@ -1251,6 +1262,21 @@ def get_tensor_parallel_split_mappings(num_layers):
final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
final_actions[key] = action

# for MTP (eagle) parameters for inference
base_actions.pop("embed_tokens.weight")
base_actions.pop("lm_head.weight")
base_actions["layers.0.embed_tokens.weight"] = partial(fn, is_column=False)
base_actions["layers.0.eh_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.shared_head.head.weight"] = partial(fn, is_column=True)
for key, action in base_actions.items():
if "layers.0." in key:
for i in range(
config.num_hidden_layers, config.num_hidden_layers + config.num_nextn_predict_layers
):
final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
else:
final_actions[key] = action

return final_actions

mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
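
The modeling.py changes extend the DeepseekV2 name mappings and tensor-parallel split actions to the MTP ("eagle") layers that DeepSeek-V3 appends after the regular decoder layers: layer indices num_hidden_layers through num_hidden_layers + num_nextn_predict_layers - 1 each carry their own embed_tokens, enorm, hnorm, eh_proj and shared_head parameters. A hedged sketch of how those extra mappings are enumerated, reusing the naming scheme from the diff:

# Sketch: MTP layers sit after the ordinary decoder layers in the checkpoint layout.
num_layers = config.num_hidden_layers
num_mtp = config.num_nextn_predict_layers  # typically 1 for DeepSeek-V3

for layer_index in range(num_layers, num_layers + num_mtp):
    model_mappings.extend(
        [
            [f"layers.{layer_index}.embed_tokens.weight"],
            [f"layers.{layer_index}.enorm.weight"],
            [f"layers.{layer_index}.hnorm.weight"],
            [f"layers.{layer_index}.eh_proj.weight", None, "transpose"],
            [f"layers.{layer_index}.shared_head.norm.weight"],
            [f"layers.{layer_index}.shared_head.head.weight", None, "transpose"],
        ]
    )

The same diff also adds an is_moe flag to DeepseekV2MLP so that per-expert MLPs skip the column/row-parallel linear layers used by the dense MLP under tensor parallelism.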
