Commit

fix format
gc-fu committed Dec 21, 2023
1 parent 888eb70 commit 39a8a11
Showing 2 changed files with 23 additions and 21 deletions.
39 changes: 21 additions & 18 deletions python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -344,7 +344,7 @@ def llama_attention_selective_batching_forward_4_31(
# is_q4_0 = self.q_proj.qtype == SYM_INT4
# no_tp = not self.config.pretraining_tp > 1
# decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
-    #                      enough_kv_room and bsz * q_len == 1)
+    #                      enough_kv_room and bsz * q_len == 1)

# single batch decoding fast path
# forward_qkv takes will perform QKV projection, rotary position embedding
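The hunk above only re-aligns a commented-out condition, but it documents when the single-batch decoding fast path would apply. Read as a predicate, that disabled gate looks roughly like the sketch below; the names and the SYM_INT4 placeholder are lifted from (or assumed for) the commented snippet, and none of this is the active code path.

```python
SYM_INT4 = 2  # placeholder value for illustration only; the real constant comes from bigdl-llm

def can_use_decoding_fast_path(qtype, pretraining_tp, use_fuse_rope,
                               enough_kv_room, bsz, q_len):
    is_q4_0 = qtype == SYM_INT4        # weights quantized to 4-bit symmetric int4
    no_tp = not pretraining_tp > 1     # projections are not split for tensor parallelism
    # the fast path only covers single-sequence, single-token decoding steps
    return no_tp and is_q4_0 and use_fuse_rope and enough_kv_room and bsz * q_len == 1
```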
@@ -376,11 +376,11 @@ def llama_attention_selective_batching_forward_4_31(
value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len,
-                                        self.num_heads, self.head_dim).transpose(1, 2)
+                                        self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len,
-                                    self.num_key_value_heads, self.head_dim).transpose(1, 2)
+                                    self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len,
-                                        self.num_key_value_heads, self.head_dim).transpose(1, 2)
+                                        self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
@@ -389,7 +389,7 @@ def llama_attention_selective_batching_forward_4_31(
# TODO: fuse_rope
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                        cos, sin, position_ids, "llama")
+                                                        cos, sin, position_ids, "llama")

updated_past_key_values = []
if past_key_value is not None:
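For context, the `apply_rotary_pos_emb(..., "llama")` call re-indented above applies standard rotary position embeddings to the query and key states. A minimal sketch of the rotate-half formulation it follows (mirroring the upstream Hugging Face helper; BigDL's "llama" variant may differ in how it caches and slices cos/sin):

```python
import torch

def rotate_half(x):
    # swap the two halves of the last (head_dim) axis and negate the second half
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_sketch(q, k, cos, sin, position_ids):
    # assumes cos/sin are (seq_len, head_dim) lookup tables from the rotary module
    cos = cos[position_ids].unsqueeze(1)  # -> (bsz, 1, q_len, head_dim)
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```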
@@ -399,8 +399,10 @@ def llama_attention_selective_batching_forward_4_31(
past_k, past_v = past_key_value[batch]
current_kv_len = past_k.shape[-2] + 1

-                current_key_states = torch.cat([past_k, key_states[batch: batch + 1, : , :, :]], dim=2)
-                current_value_states = torch.cat([past_v, value_states[batch: batch + 1, :, :, :]], dim=2)
+                current_key_states = torch.cat([past_k,
+                                                key_states[batch: batch + 1, :, :, :]], dim=2)
+                current_value_states = torch.cat([past_v,
+                                                  value_states[batch: batch + 1, :, :, :]], dim=2)

updated_past_key_values.append((current_key_states, current_value_states))

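Functionally, the re-wrapped lines above are the decode-path cache update: selective batching keeps one (key, value) cache per sequence, and each sequence's new single-token slice is appended to its own cache along the sequence-length dimension. A standalone sketch of that update (the function name and shapes are assumptions based on the visible code):

```python
import torch

def update_kv_caches(past_key_values, key_states, value_states):
    """past_key_values: list of (past_k, past_v), each (1, num_kv_heads, past_len_i, head_dim).
    key_states / value_states: (bsz, num_kv_heads, 1, head_dim) for the current decode step."""
    updated = []
    for batch, (past_k, past_v) in enumerate(past_key_values):
        # slice out this sequence's new token and append it to its own cache
        cur_k = torch.cat([past_k, key_states[batch: batch + 1, :, :, :]], dim=2)
        cur_v = torch.cat([past_v, value_states[batch: batch + 1, :, :, :]], dim=2)
        updated.append((cur_k, cur_v))
    return updated
```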
@@ -419,19 +421,19 @@ def llama_attention_selective_batching_forward_4_31(
self.num_heads)
if attn_output.size() != (1, self.num_heads, 1, self.head_dim):
invalidInputError(False,
f"`attn_output` should be of size {(1, self.num_heads, 1, self.head_dim)}, but is"
f" {attn_output.size()}"
)
f"`attn_output` should be of size "
f"{(1, self.num_heads, 1, self.head_dim)}, but is"
f" {attn_output.size()}")
batched_attention_output.append(attn_output)
# For loop ends
# TODO: handle attention_weights later
attn_output = torch.concat(batched_attention_output, dim=0)
batched_attention_output.clear()
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
invalidInputError(False,
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
f"`attn_output` should be of size "
f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}")
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
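The surrounding decode loop runs attention once per sequence (each against its own cache length), and the lines above stitch the per-sequence results back into a single batch before the output projection. A sketch of that merge, using the shapes implied by the size checks (q_len is 1 in this decode path):

```python
import torch

def merge_attention_outputs(per_sequence_outputs, bsz, num_heads, q_len, head_dim):
    # per_sequence_outputs: list of (1, num_heads, q_len, head_dim) tensors, one per sequence
    attn_output = torch.cat(per_sequence_outputs, dim=0)
    assert attn_output.size() == (bsz, num_heads, q_len, head_dim)
    # back to (bsz, q_len, hidden_size) for the o_proj linear layer
    return attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, num_heads * head_dim)
```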
@@ -440,7 +442,8 @@ def llama_attention_selective_batching_forward_4_31(
# TODO: Assume always use_cache
print(f"prefill with batch size {bsz}")
for batch in range(bsz):
-            updated_past_key_values.append((key_states[batch: batch + 1, :, :, :], value_states[batch: batch+1, :, :, :]))
+            updated_past_key_values.append((key_states[batch: batch + 1, :, :, :],
+                                            value_states[batch: batch+1, :, :, :]))

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
@@ -459,9 +462,9 @@ def llama_attention_selective_batching_forward_4_31(

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
invalidInputError(False,
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
f"`attn_output` should be of size "
f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}")
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

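The prefill branch above expands the grouped key/value heads with `repeat_kv` before attention is computed. For reference, this helper conventionally works as in the sketch below, which mirrors the upstream transformers implementation; BigDL's version may add device/dtype handling, as the `.to(device, ...)` calls suggest.

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """(bsz, num_kv_heads, seq_len, head_dim) -> (bsz, num_kv_heads * n_rep, seq_len, head_dim)."""
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # insert a repeat axis, broadcast it, then fold it into the head axis
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)
```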
@@ -637,7 +640,7 @@ def llama_model_selective_batching_forward_4_31(
attn_mask, (1, seq_length), inputs_embeds, past_key_value_length
)
attention_mask[i] = new_mask
-            i+=1
+            i += 1

hidden_states = inputs_embeds

5 changes: 2 additions & 3 deletions (second changed file)
@@ -182,7 +182,6 @@ def forward(
# TODO: this could be deleted after prefill stage is also selective_batched
decoding_attention_mask_list = []
decoding_position_ids = []
-        # Attention_mask for decoding could also be a list of tensors due to inconsistent length of kv_cache
# num_layers x len(seq_id) x (2 x torch.Tensor)
if is_decoding_stage:
batch = 0
@@ -198,12 +197,12 @@ def forward(
cur_attention_mask = [0] * (current_seq_len - cur_pos + 1) + [1] * (cur_pos)
decoding_attention_mask_list.append(cur_attention_mask)

-
bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)

# TODO: prefill requests could also be sbed, so that we can remove attention_mask forever
if is_decoding_stage:
-            attention_mask = [torch.tensor(x, device=self.device).unsqueeze(0) for x in decoding_attention_mask_list]
+            attention_mask = [torch.tensor(x, device=self.device).unsqueeze(0)
+                              for x in decoding_attention_mask_list]
position_ids = torch.tensor(decoding_position_ids).long().unsqueeze(-1)
kwargs = {
"input_ids": bigdl_input_ids,
