Commit

[ascend]Zcx/llama2 infer 910b (#1254)
* optimize lightllm

* fix promptFlashAttention on a+x

* add check for incre flash attention

* add description of the added function
zhaochaoxing authored and yangbofun committed Jun 12, 2024
1 parent 20decf8 commit e915517
Showing 13 changed files with 861 additions and 95 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/_runs-on-nv-step1.yml
@@ -78,7 +78,7 @@ jobs:
&& python main.py --mode gen_data" \
|| ( cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1} && git clean -xdf ${GEN_DATA} && exit 1 )
source ~/Aoss_env.sh
-ads-cli cp s3://${Platform_ci_aoss_name}:${Platform_ci_aoss_url}@platform.aoss.cn-sh-01c.sensecoreapi-oss.cn${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/
+ads-cli cp ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ s3://${Platform_ci_aoss_name}:${Platform_ci_aoss_url}@platform.aoss.cn-sh-01c.sensecoreapi-oss.cn${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/
elif [[ "${GETRUNNER}" == *diopi* ]];then
ssh SH1424 """
set -e
@@ -87,7 +87,7 @@ jobs:
srun --job-name=${GITHUB_JOB} --partition=${SLURM_PAR_V100} --time=20 --gres=gpu:1 bash -c 'python main.py --mode gen_data' \
|| ( cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1} && git clean -xdf ${GEN_DATA} && exit 1 )
source ~/Aoss_env.sh
-ads-cli cp s3://${Platform_ci_aoss_name}:${Platform_ci_aoss_url}@platform.aoss.cn-sh-01c.sensecoreapi-oss.cn${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/
+ads-cli cp ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ s3://${Platform_ci_aoss_name}:${Platform_ci_aoss_url}@platform.aoss.cn-sh-01c.sensecoreapi-oss.cn${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/
"""
else
ln -s ${GEN_DATA_PATH}/${GEN_DATA}/diopi ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/
158 changes: 158 additions & 0 deletions diopi_test/python/configs/diopi_configs.py
@@ -8305,6 +8305,36 @@
        ),
    ),

    'rotary_emb_v2': dict(
        name=['rotary_emb_v2'],
        interface=['CustomizedTest'],
        dtype=[np.float32, np.float16],
        para=dict(
            dim=[128,]
        ),
        tensor_para=dict(
            gen_fn='Genfunc.randn',
            args=[
                {
                    "ins": ['query'],
                    "shape": ((8, 4096),),
                },
                {
                    "ins": ['key'],
                    "shape": ((8, 4096),),
                },
                {
                    "ins": ['cos'],
                    "shape": ((8, 1, 128),),
                },
                {
                    "ins": ['sin'],
                    "shape": ((8, 1, 128),),
                },
            ],
        ),
    ),

    'rms_norm_default': dict(
        name=['rms_norm'],
        atol=1e-4,
@@ -8551,6 +8581,134 @@
        ),
    ),

    'prompt_flash_attention': dict(
        name=['prompt_flash_attention'],
        interface=['CustomizedTest'],
        atol=1e-2,
        rtol=1e-2,
        para=dict(
            maxInputLen=[2,],
            actualSeqLengths=[[2, 2],],
            numHeads=[32,],
            numKeyValueHeads=[32,],
            dim=[128,],
        ),
        tensor_para=dict(
            args=[
                {
                    "ins": ["query"],
                    "shape": ((4, 4096),),
                    "dtype": [np.float16,],
                },
                {
                    "ins": ["key"],
                    "shape": ((4, 4096),),
                    "dtype": [np.float16,],
                },
                {
                    "ins": ["value"],
                    "shape": ((4, 4096),),
                    "dtype": [np.float16,],
                },
                {
                    "ins": ["attenMask"],
                    "value": ([[[False, True],
                                [False, False]],
                               [[False, True],
                                [False, False]]],),
                    "dtype": [np.bool_,],
                    "gen_policy": "gen_tensor_by_value"
                },
            ]
        ),
    ),

    'paged_attention': dict(
        name=['paged_attention'],
        interface=['CustomizedTest'],
        atol=1e-2,
        rtol=1e-2,
        para=dict(
            actualSeqLengths=[[150,],],
            numHeads=[32,],
            numKeyValueHeads=[32,],
            dim=[128,],
            blockSize=[128,],
        ),
        tensor_para=dict(
            args=[
                {
                    "ins": ["query"],
                    "shape": ((1, 4096),),
                    "dtype": [np.float16,],
                },
                {
                    "ins": ["key"],
                    "shape": ((1026, 4096),),
                    "dtype": [np.float16,],
                },
                {
                    "ins": ["value"],
                    "shape": ((1026, 4096),),
                    "dtype": [np.float16,],
                },
                {
                    "ins": ["blockTable"],
                    "value": ([[0, 1],],),
                    "dtype": [np.int32,],
                    "gen_policy": "gen_tensor_by_value"
                },
            ]
        ),
    ),

    'apply_penalty_v2': dict(
        name=['apply_penalty_v2'],
        interface=['CustomizedTest'],
        tensor_para=dict(
            args=[
                {
                    "ins": ['logits'],
                    "value": ([[0.1, 0.5, 0.4, 0.3, 0.5],
                               [0.2, 0.4, 0.0, 0.0, 0.0],
                               [0.3, 0.4, 0.5, 0.3, 0.0]],),
                    "dtype": [np.float16, np.float32],
                    "gen_policy": "gen_tensor_by_value"
                },
                {
                    "ins": ["presence_penalty"],
                    "value": ([0.1, 0.1, 0.1, 0.1, 0.8, 0.8, 0.8, 0.8, 1.0, 1.0, 1.0],),
                    "dtype": [np.float16, np.float32],
                    "gen_policy": "gen_tensor_by_value"
                },
                {
                    "ins": ["frequency_penalty"],
                    "value": ([0.1, 0.1, 0.1, 0.1, 0.8, 0.8, 0.8, 0.8, 1.0, 1.0, 1.0],),
                    "dtype": [np.float16, np.float32],
                    "gen_policy": "gen_tensor_by_value"
                },
                {
                    "ins": ["repetition_penalty"],
                    "value": ([0.1, 0.1, 0.1, 0.1, 0.8, 0.8, 0.8, 0.8, 1.0, 1.0, 1.0],),
                    "dtype": [np.float16, np.float32],
                    "gen_policy": "gen_tensor_by_value"
                },
                {
                    "ins": ["p_token_ids"],
                    "value": ([0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11],),
                    "dtype": [np.int32, np.int32],
                    "gen_policy": "gen_tensor_by_value"
                },
                {
                    "ins": ["p_token_counts"],
                    "value": ([3, 3, 2, 2, 1, 3, 3, 3, 3, 2, 2],),
                    "dtype": [np.int32, np.int32],
                    "gen_policy": "gen_tensor_by_value"
                },
            ]
        )
    ),

    'token_attention': dict(
        name=['token_attention'],
        interface=['CustomizedTest'],
112 changes: 111 additions & 1 deletion diopi_test/python/conformance/customized_test.py
@@ -626,6 +626,108 @@ def context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len):
    )
    return out

def prompt_flash_attention(
    query,
    key,
    value,
    attenMask,
    actualSeqLengths,
    maxInputLen,
    numHeads,
    numKeyValueHeads,
    dim,
):
    bs = len(actualSeqLengths)
    xq = query.view(bs, maxInputLen, numHeads, dim).cuda()
    keys = key.view(bs, maxInputLen, numKeyValueHeads, dim).cuda()
    values = value.view(bs, maxInputLen, numKeyValueHeads, dim).cuda()
    # Additive causal mask: 0.0 where attention is allowed, a large
    # negative value where it is masked out.
    mask = (
        torch.tril(torch.ones(maxInputLen, maxInputLen), diagonal=0)
        .unsqueeze(0)
        .unsqueeze(0)
        .cuda()
    )
    mask = mask.masked_fill(mask == 0.0, -100000000.0)
    mask = mask.masked_fill(mask == 1.0, 0.0)
    mask = mask.repeat(bs, numHeads, 1, 1)
    xq = xq.transpose(1, 2)
    keys = keys.transpose(1, 2)
    values = values.transpose(1, 2)
    scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(dim)
    scores = F.softmax(scores.float() + mask, dim=-1).type_as(xq)
    out = torch.matmul(scores, values).transpose(1, 2).contiguous()
    return out.reshape(bs * maxInputLen, numHeads * dim)
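
# Illustrative usage sketch (not part of the commit): exercising the reference
# above with the shapes from the 'prompt_flash_attention' config entry.
# Requires a CUDA device, since the reference moves tensors to .cuda().
q = torch.randn(4, 4096, dtype=torch.float16)  # 2 seqs of len 2; 32 heads x dim 128
k = torch.randn(4, 4096, dtype=torch.float16)
v = torch.randn(4, 4096, dtype=torch.float16)
out = prompt_flash_attention(
    q, k, v,
    attenMask=None,  # unused: the reference builds its own causal mask
    actualSeqLengths=[2, 2],
    maxInputLen=2,
    numHeads=32,
    numKeyValueHeads=32,
    dim=128,
)
assert out.shape == (4, 4096)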

def paged_attention(
    query,
    key,
    value,
    actualSeqLengths,
    numHeads,
    numKeyValueHeads,
    dim,
    blockTable,
    blockSize,
):
    # q: BSH. blockTable/blockSize mirror the paged-KV kernel interface but
    # are not needed here: this dense reference indexes the flat KV cache
    # directly via b_loc.
    b_loc = torch.arange(key.shape[0], dtype=torch.int32).reshape(1, -1).cuda()
    batch = b_loc.shape[0]
    xq = query.view(batch, 1, numHeads, dim).transpose(1, 2).cuda()
    k = key.view(-1, numKeyValueHeads, dim).cuda()
    v = value.view(-1, numKeyValueHeads, dim).cuda()
    out = torch.empty([batch, numHeads, dim], device="cuda", dtype=query.dtype)
    max_input_len = max(actualSeqLengths)
    b_seq_len = torch.tensor(actualSeqLengths, dtype=torch.int32).cuda()
    for i in range(batch):
        # Gather the last b_seq_len[i] cache slots belonging to sequence i.
        k_loc = b_loc[i][
            max_input_len
            - b_seq_len[i]
            + torch.arange(0, b_seq_len[i], device="cuda", dtype=torch.int32)
        ]
        key = k[k_loc, :].view(1, b_seq_len[i], numHeads, dim).transpose(1, 2)
        logics = (
            torch.matmul(xq[i, :], key.transpose(2, 3)) / math.sqrt(dim)
        ).reshape(numHeads, b_seq_len[i])
        v_loc = b_loc[i][
            max_input_len
            - b_seq_len[i]
            + torch.arange(0, b_seq_len[i], device=logics.device, dtype=torch.int32)
        ]
        P = logics.softmax(-1).reshape(1, numHeads, 1, b_seq_len[i])
        V = v[v_loc, :].view(1, b_seq_len[i], numHeads, dim).transpose(1, 2)
        out[i, :] = torch.matmul(P, V).view(numHeads, dim)
    return out.view(-1, numHeads * dim)
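
# Illustrative usage sketch (not part of the commit): a single decode step
# against a flat KV cache of 1026 slots, matching the 'paged_attention' config
# entry. blockTable/blockSize are accepted for API parity but unused above.
q = torch.randn(1, 4096, dtype=torch.float16)
k_cache = torch.randn(1026, 4096, dtype=torch.float16)
v_cache = torch.randn(1026, 4096, dtype=torch.float16)
out = paged_attention(
    q, k_cache, v_cache,
    actualSeqLengths=[150],
    numHeads=32,
    numKeyValueHeads=32,
    dim=128,
    blockTable=torch.tensor([[0, 1]], dtype=torch.int32),
    blockSize=128,
)
assert out.shape == (1, 4096)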

def apply_penalty_v2(
    logits,
    presence_penalty,
    frequency_penalty,
    repetition_penalty,
    p_token_ids,
    p_token_counts,
):
    batch = logits.shape[0]
    # Flatten so p_token_ids can address the (batch, vocab) table as one axis.
    logits = logits.view(-1)
    cur_logits = logits.index_select(0, p_token_ids)
    # Repetition penalty: divide positive logits and multiply negative ones,
    # so values > 1 always push already-seen tokens toward lower probability.
    rep_logits = torch.where(
        cur_logits > 0,
        cur_logits / repetition_penalty,
        cur_logits * repetition_penalty,
    )
    rep_logits = rep_logits - p_token_counts * frequency_penalty - presence_penalty
    logits[p_token_ids] = rep_logits
    return logits.view(batch, -1)
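
# Hand-check of the formula on two tokens (illustrative values, not from the
# config): positive logits are divided by repetition_penalty, negative ones
# multiplied; then count * frequency_penalty and presence_penalty are subtracted.
x = torch.tensor([[0.5, -0.5]])
out = apply_penalty_v2(
    x,
    presence_penalty=torch.tensor([0.2, 0.2]),
    frequency_penalty=torch.tensor([0.1, 0.1]),
    repetition_penalty=torch.tensor([2.0, 2.0]),
    p_token_ids=torch.tensor([0, 1]),
    p_token_counts=torch.tensor([3, 1]),
)
# token 0: 0.5 / 2.0 - 3 * 0.1 - 0.2 = -0.25
# token 1: -0.5 * 2.0 - 1 * 0.1 - 0.2 = -1.3
assert torch.allclose(out, torch.tensor([[-0.25, -1.3]]))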

def rotary_emb_v2(query, key, cos, sin, dim):
    # Rotate-half RoPE over heads of size dim: x -> x * cos + rotate_half(x) * sin.
    query = query.view(query.shape[0], -1, dim)
    key = key.view(key.shape[0], -1, dim)
    q1, q2 = query.chunk(2, dim=-1)
    query_rotate = torch.cat((-q2, q1), dim=-1)
    query = query * cos + query_rotate * sin
    k1, k2 = key.chunk(2, dim=-1)
    key_rotate = torch.cat((-k2, k1), dim=-1)
    key = key * cos + key_rotate * sin
    return query.view(query.shape[0], -1), key.view(key.shape[0], -1)
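
# Illustrative usage sketch (not part of the commit), shapes per the
# 'rotary_emb_v2' config entry: 8 tokens, 32 heads of size 128. Real cos/sin
# come from position-dependent frequency tables; randn here only checks shapes.
q = torch.randn(8, 4096)
k = torch.randn(8, 4096)
cos = torch.randn(8, 1, 128)  # broadcast over the 32 heads
sin = torch.randn(8, 1, 128)
q_out, k_out = rotary_emb_v2(q, k, cos, sin, dim=128)
assert q_out.shape == (8, 4096) and k_out.shape == (8, 4096)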

def attention(
    query,
    key,
@@ -742,5 +844,13 @@ def attention_varlen(
        start_idx = cu_seqlens_q[i]
        end_idx = cu_seqlens_q[i + 1]
        actual_seq_len = end_idx - start_idx
-       out[start_idx:end_idx, :, :] = out_paded[i, :actual_seq_len, :, :]  # BSND->TND
+       out[start_idx:end_idx, :, :] = out_paded[
+           i, :actual_seq_len, :, :
+       ]  # BSND->TND
    return out

def nll_loss_v2(input, target, weight=None, ignore_index=-100, reduction="mean"):
    out = torch.nn.functional.nll_loss(
        input, target, weight, None, ignore_index, None, reduction
    )
    return out
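
# Illustrative usage sketch (not part of the commit): the wrapper simply
# forwards to torch.nn.functional.nll_loss with size_average/reduce left None.
logp = torch.log_softmax(torch.randn(3, 5), dim=1)  # 3 samples, 5 classes
target = torch.tensor([1, 0, 4])
loss = nll_loss_v2(logp, target)  # scalar mean NLL over the batch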