regroup k/v weight. #219

Open · wants to merge 4 commits into base: main
62 changes: 46 additions & 16 deletions chatlearn/synchronizer/megatron_vllm.py
@@ -33,6 +33,7 @@ class MegatronVllmSync(BaseSync):
     def __init__(self, src_model, dst_model):
         super().__init__(src_model, dst_model)
         self.src_module_args = src_model.module_args
+        self.dst_module_args = dst_model.module_args
         self.is_parameter_changed = True
 
     @abstractmethod
@@ -322,15 +323,17 @@ def transform_parameters(self, params_to_sync_list):
         params_to_sync_list = self.fix_shared_expert_ordering(params_to_sync_list)
         return params_to_sync_list
 
-    def regroup_qkv_tp_slices(self, name, param_data, tp_divition):
+    def regroup_qkv_tp_slices(self, name, param_data, tp_division):
         param_data_shape = param_data.shape
         # Regroup qkv tensors into different tp slices only for inference model which enables vLLM backend.
         to_fix_qkv_ordering_dict = self.sync_map.to_fix_qkv_ordering_dict
         # pylint: disable=too-many-nested-blocks
         if "attention.query_key_value" in name or \
                 "self_attention.query_key_value" in name or \
                 "self_attention.linear_qkv" in name:
-            tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"]
-            heads = self.src_module_args.args_dict["num_attention_heads"] // tp_size
+            src_tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"]
+            dst_tp_size = self.dst_module_args.args_dict["tensor_model_parallel_size"]
+            heads = self.src_module_args.args_dict["num_attention_heads"] // src_tp_size
             hidden_size_per_head = self.src_module_args.args_dict["hidden_size"] // self.src_module_args.args_dict["num_attention_heads"]
 
             param_shape = (3, heads, hidden_size_per_head) + param_data_shape[1:]
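
For concreteness, here is the shape bookkeeping from the hunk above worked out on example numbers; the configuration values are illustrative assumptions, not values taken from this PR:

```python
# Worked example of the shape arithmetic above (illustrative numbers only).
num_attention_heads = 32
hidden_size = 4096
src_tp_size = 4   # trainer-side (Megatron) tensor-parallel size

heads = num_attention_heads // src_tp_size                  # 8 query heads per src rank
hidden_size_per_head = hidden_size // num_attention_heads   # 128

# The fused QKV weight slice held by one src rank is viewed as
# (3, heads, head_dim, hidden_size) before being re-sliced per dst rank:
param_shape = (3, heads, hidden_size_per_head, hidden_size)  # (3, 8, 128, 4096)
```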
@@ -340,31 +343,58 @@ def regroup_qkv_tp_slices(self, name, param_data, tp_divition):
                 if to_fix_qkv_ordering_dict is not None:
                     param_data = param_data.view(param_shape)
                     param_data_list = []
-                    head_offset = heads // tp_divition
-                    for idx in range(tp_divition):
+                    head_offset = heads // tp_division
+                    for idx in range(tp_division):
                         start = idx * head_offset
                         end = start + head_offset
                         param_data_list.append(param_data[:,start:end])
                     param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
                     del param_data_list
             else:
-                _num_query_groups = self.src_module_args.args_dict["num_query_groups"]//tp_size \
-                    if self.src_module_args.args_dict["group_query_attention"] else heads
-                if to_fix_qkv_ordering_dict is not None or _num_query_groups == 1:
+                if self.src_module_args.args_dict["group_query_attention"]:
+                    num_query_groups = self.src_module_args.args_dict["num_query_groups"]
+                    assert num_query_groups == self.dst_module_args.args_dict["num_query_groups"], (
+                        f"num_query_groups of src model ({num_query_groups}) must be equal to num_query_groups of "
+                        f"dst model ({self.dst_module_args.args_dict['num_query_groups']}). Please double-check your config."
+                    )
+                    src_num_query_groups_per_replica = num_query_groups // src_tp_size
+                    if dst_tp_size >= num_query_groups:
+                        num_dst_kv_head_replicas = dst_tp_size // num_query_groups
+                    else:
+                        num_dst_kv_head_replicas = 1
+                else:
+                    src_num_query_groups_per_replica = heads
+                    num_dst_kv_head_replicas = 1
+
+                if to_fix_qkv_ordering_dict is not None or src_num_query_groups_per_replica == 1:
                     if len(param_data_shape) == 1:
-                        param_data = param_data.view((heads + 2 * _num_query_groups, hidden_size_per_head))
+                        param_data = param_data.view((heads + 2 * src_num_query_groups_per_replica, hidden_size_per_head))
                     else:
                         param_data = param_data.view(
-                            (heads + 2 * _num_query_groups, hidden_size_per_head, self.src_module_args.args_dict["hidden_size"]))
+                            (heads + 2 * src_num_query_groups_per_replica, hidden_size_per_head, self.src_module_args.args_dict["hidden_size"]))
                     param_data_list = []
-                    head_offset = heads // tp_divition
-                    for idx in range(tp_divition):
+                    head_offset = heads // tp_division
+                    for idx in range(tp_division):
                         q_start = idx * head_offset
                         q_end = q_start + head_offset
-                        k_start = (heads + idx) if _num_query_groups // tp_divition else heads
-                        k_end = k_start + 1
-                        v_start = k_start + _num_query_groups
-                        v_end = v_start + 1
+                        if num_dst_kv_head_replicas == 1:
+                            if src_num_query_groups_per_replica > tp_division:
+                                assert src_num_query_groups_per_replica % tp_division == 0, (
+                                    f"num_query_groups per replica of src model ({src_num_query_groups_per_replica}) "
+                                    f"must be divisible by tp_division ({tp_division}). Please double-check your config."
+                                )
+                                kv_offset = src_num_query_groups_per_replica // tp_division
+                            else:
+                                kv_offset = 1
+                            k_start = (heads + idx) if src_num_query_groups_per_replica // tp_division else heads
+                            k_end = k_start + kv_offset
+                            v_start = k_start + src_num_query_groups_per_replica
+                            v_end = v_start + kv_offset
+                        else:
+                            k_start = heads + idx // num_dst_kv_head_replicas
+                            k_end = k_start + 1
+                            v_start = k_start + src_num_query_groups_per_replica
+                            v_end = v_start + 1
 
                         q_proj = param_data[q_start:q_end].contiguous()
                         k_proj = param_data[k_start:k_end].contiguous()
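
Below is a minimal standalone sketch (not part of the diff) of the k/v slice-index arithmetic the new branch introduces; the function name and example configuration are assumptions for illustration. It assumes the fused per-rank layout [q_0..q_{heads-1}, k_0..k_{G-1}, v_0..v_{G-1}] along dim 0 that the view above produces:

```python
def kv_slice_bounds(idx, heads, src_groups_per_replica, tp_division,
                    num_dst_kv_head_replicas):
    """Simplified mirror of the k/v index bookkeeping added by this PR.

    heads: query heads held by one src TP rank.
    src_groups_per_replica: KV groups (G) held by one src TP rank.
    tp_division: number of dst slices this rank's weight is split into.
    """
    head_offset = heads // tp_division
    q_start = idx * head_offset
    q_end = q_start + head_offset
    if num_dst_kv_head_replicas == 1:
        # Each dst slice gets its own (possibly multi-group) k/v chunk.
        kv_offset = (src_groups_per_replica // tp_division
                     if src_groups_per_replica > tp_division else 1)
        k_start = (heads + idx) if src_groups_per_replica // tp_division else heads
        k_end = k_start + kv_offset
    else:
        # dst TP wider than num_query_groups: several dst ranks share one k/v head.
        k_start = heads + idx // num_dst_kv_head_replicas
        k_end = k_start + 1
    v_start = k_start + src_groups_per_replica
    v_end = v_start + (k_end - k_start)
    return (q_start, q_end), (k_start, k_end), (v_start, v_end)

# Example: 8 query heads and 2 KV groups on this src rank, split into 4 dst
# slices, with each KV head replicated on 2 dst ranks.
for idx in range(4):
    print(kv_slice_bounds(idx, heads=8, src_groups_per_replica=2,
                          tp_division=4, num_dst_kv_head_replicas=2))
```

In the replicated branch, consecutive dst slices receive the same k/v head (`idx // num_dst_kv_head_replicas`), which matches how vLLM replicates KV heads when the tensor-parallel size exceeds the number of KV heads.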