Enabled configurable auto Tensor Parallelism (TP) for inference of diverse models
gyou2021 committed Sep 18, 2024
1 parent 08f728d commit f6e8637
Showing 1 changed file with 17 additions and 25 deletions.
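
The commit replaces hard-coded, per-model layer rules in auto_tp.py with two environment variables, allReduceLinearItems and keepLinearItems. Each must contain a Python list literal, because the new code parses it with ast.literal_eval. A minimal sketch of how a run might be configured; the substring values and the init_inference call are illustrative assumptions, not settings mandated by the commit:

import os

# Each variable holds a Python list literal; auto_tp.py parses it with
# ast.literal_eval, so a plain comma-separated string would raise an error.

# Linear layers whose names contain any of these substrings are marked for
# the all-reduce (row-parallel) policy (illustrative, LLaMA-style names).
os.environ['allReduceLinearItems'] = "['o_proj', 'down_proj']"

# Linear layers whose names contain these substrings are meant to keep the
# default nn.Linear (illustrative values mirroring the removed qwen2_moe
# special case).
os.environ['keepLinearItems'] = "['shared_expert_gate', '.gate.']"

# Afterwards, initialize DeepSpeed automatic tensor parallelism as usual,
# e.g. engine = deepspeed.init_inference(model, mp_size=world_size)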
42 changes: 17 additions & 25 deletions deepspeed/module_inject/auto_tp.py
@@ -15,6 +15,8 @@
 from deepspeed.accelerator import get_accelerator
 from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_value_with_share_qk, shard_chunk_mlp
 from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
+import os
+import ast
 
 
 def move(tensor, device):
@@ -270,6 +272,7 @@ def kernel_supported(module_list):
                 return True
         return False
 
+    ## tp parser based on autoTP config in environment
     def tp_parser(model):
         policy_list = []
         module_list = []
@@ -279,40 +282,27 @@ def tp_parser(model):
         module_list = AutoTP.get_module_list(model)
         assert AutoTP.supported(model), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
             if AutoTP.kernel_supported(module_list) else "AutoTP not supported for model. Please provide policy."
-        norm_layer_name_list = ['LayerNorm', 'layer_norm', 'ln_1', 'ln_2']
-        #ln_1 , ln_2 for Qwen
+
+        allReduceLinearItems = os.environ['allReduceLinearItems']
+        allReduceLinearItems = ast.literal_eval(allReduceLinearItems)
 
         for module in module_list:
             for key, submodule in module._modules.items():
                 if isinstance(submodule, nn.Linear):
                     layer_list = layer_list + ["." + key]
-                elif isinstance(submodule, nn.LayerNorm) or key in norm_layer_name_list:
+                elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
                     layer_list = layer_list + ["ln"]
                 else:
                     layer_list = layer_list + AutoTP.get_layers(key, submodule)
 
             for i, layer in enumerate(layer_list):
                 if layer == 'ln':
                     if layer_list[i - 1] != 'ln':
                         gem_list = gem_list + [layer_list[i - 1]]
-                elif 'out_proj' in layer:
-                    gem_list = gem_list + [layer]
-                elif 'o_proj' in layer:
-                    gem_list = gem_list + [layer]
-                elif 'down_proj' in layer:
-                    gem_list = gem_list + [layer]
-                elif 'attention.dense' in layer and 'GPTNeoX' in str(model):
-                    gem_list = gem_list + [layer]
-                elif 'self_attention.dense' in layer and 'falcon' in str(
-                        type(module)):  # this is a hack to get the right linear layer for this model!
-                    gem_list = gem_list + [layer]
-                # Mixtral-7x8b used w2*act(w1*w3) linear. need to replace w2 to linearallreduce.
-                elif 'w2' in layer and 'Mixtral' in str(type(module)):
-                    gem_list = gem_list + [layer]
-                elif 'self_attn.dense' in layer and 'Phi' in str(type(module)):
-                    gem_list = gem_list + [layer]
-                elif 'self_attention.dense' in layer and 'ChatGLM' in str(model):
-                    gem_list = gem_list + [layer]
-                elif 'dense_4h_to_h' in layer and 'ChatGLM' in str(model):
-                    gem_list = gem_list + [layer]
+                    continue
+                for item in allReduceLinearItems:
+                    if item in layer:
+                        gem_list = gem_list + [layer]
 
             layer_list = []
             if gem_list != []:
@@ -473,8 +463,10 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
             if len(child._buffers) != 0 and self.state_dict is not None:
                 Loading.load_buffer(child, self.state_dict, checking_key)
             if child.__class__ in self.linear_policies:
-                if ('shared_expert_gate' not in checking_key and '.gate.' not in checking_key
-                        and 'qwen2_moe' in str(type(r_module))) or 'qwen2_moe' not in str(type(r_module)):
+                keepLinearItems = os.environ['keepLinearItems']
+                keepLinearItems = ast.literal_eval(keepLinearItems)
+
+                if any(item not in checking_key for item in keepLinearItems):
                     setattr(
                         r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name,
                                                                               self.conv_linear_layer))
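
To make the new tp_parser control flow concrete, here is a self-contained sketch of the matching loop shown in the diff, run over a hypothetical layer_list (the names are stand-ins, not output from a real model):

import ast
import os

# Hypothetical configuration, parsed exactly as tp_parser parses it.
os.environ['allReduceLinearItems'] = "['o_proj', 'down_proj']"
allReduceLinearItems = ast.literal_eval(os.environ['allReduceLinearItems'])

# Hypothetical per-module layer names, as tp_parser would collect them.
layer_list = ['.q_proj', '.k_proj', '.v_proj', '.o_proj', 'ln', '.gate_proj', '.down_proj', 'ln']

gem_list = []
for i, layer in enumerate(layer_list):
    if layer == 'ln':
        # As in the diff: the layer preceding a norm is recorded, then the
        # rest of the iteration is skipped via continue.
        if layer_list[i - 1] != 'ln':
            gem_list = gem_list + [layer_list[i - 1]]
        continue
    # Any layer whose name contains a configured substring is marked for
    # the all-reduce (row-parallel) replacement policy.
    for item in allReduceLinearItems:
        if item in layer:
            gem_list = gem_list + [layer]

print(sorted(set(gem_list)))  # ['.down_proj', '.o_proj']

This mirrors how the hard-coded out_proj/o_proj/down_proj branches and the model-specific special cases (GPTNeoX, Falcon, Mixtral, Phi, ChatGLM) collapse into a single substring scan driven by the environment.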
