From 39a8a114eca7319a6ee5af781125bf2cd73d86b6 Mon Sep 17 00:00:00 2001
From: gc-fu
Date: Thu, 21 Dec 2023 11:25:34 +0800
Subject: [PATCH] fix format

---
 .../bigdl/llm/transformers/models/llama.py     | 39 ++++++++++---------
 .../vllm/model_executor/models/bigdl_llama.py  |  5 +--
 2 files changed, 23 insertions(+), 21 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index a373c73a1cc..860b970d9f0 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -344,7 +344,7 @@ def llama_attention_selective_batching_forward_4_31(
     # is_q4_0 = self.q_proj.qtype == SYM_INT4
     # no_tp = not self.config.pretraining_tp > 1
     # decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
-    #                   enough_kv_room and bsz * q_len == 1)
+    #                       enough_kv_room and bsz * q_len == 1)
 
     # single batch decoding fast path
     # forward_qkv takes will perform QKV projection, rotary position embedding
@@ -376,11 +376,11 @@ def llama_attention_selective_batching_forward_4_31(
         value_states = self.v_proj(hidden_states)
 
     query_states = query_states.view(bsz, q_len,
-                            self.num_heads, self.head_dim).transpose(1, 2)
+                                     self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len,
-                            self.num_key_value_heads, self.head_dim).transpose(1, 2)
+                                 self.num_key_value_heads, self.head_dim).transpose(1, 2)
     value_states = value_states.view(bsz, q_len,
-                            self.num_key_value_heads, self.head_dim).transpose(1, 2)
+                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
@@ -389,7 +389,7 @@ def llama_attention_selective_batching_forward_4_31(
     # TODO: fuse_rope
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                        cos, sin, position_ids, "llama")
+                                                    cos, sin, position_ids, "llama")
 
     updated_past_key_values = []
     if past_key_value is not None:
@@ -399,8 +399,10 @@ def llama_attention_selective_batching_forward_4_31(
             past_k, past_v = past_key_value[batch]
             current_kv_len = past_k.shape[-2] + 1
 
-            current_key_states = torch.cat([past_k, key_states[batch: batch + 1, : , :, :]], dim=2)
-            current_value_states = torch.cat([past_v, value_states[batch: batch + 1, :, :, :]], dim=2)
+            current_key_states = torch.cat([past_k,
+                                            key_states[batch: batch + 1, :, :, :]], dim=2)
+            current_value_states = torch.cat([past_v,
+                                              value_states[batch: batch + 1, :, :, :]], dim=2)
 
             updated_past_key_values.append((current_key_states, current_value_states))
 
@@ -419,9 +421,9 @@ def llama_attention_selective_batching_forward_4_31(
                                                    self.num_heads)
             if attn_output.size() != (1, self.num_heads, 1, self.head_dim):
                 invalidInputError(False,
-                                  f"`attn_output` should be of size {(1, self.num_heads, 1, self.head_dim)}, but is"
-                                  f" {attn_output.size()}"
-                                  )
+                                  f"`attn_output` should be of size "
+                                  f"{(1, self.num_heads, 1, self.head_dim)}, but is"
+                                  f" {attn_output.size()}")
             batched_attention_output.append(attn_output)
         # For loop ends
         # TODO: handle attention_weights later
@@ -429,9 +431,9 @@ def llama_attention_selective_batching_forward_4_31(
         batched_attention_output.clear()
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             invalidInputError(False,
-                              f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                              f" {attn_output.size()}"
-                              )
+                              f"`attn_output` should be of size "
+                              f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                              f" {attn_output.size()}")
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
         attn_output = self.o_proj(attn_output)
@@ -440,7 +442,8 @@ def llama_attention_selective_batching_forward_4_31(
     # TODO: Assume always use_cache
     print(f"prefill with batch size {bsz}")
    for batch in range(bsz):
-        updated_past_key_values.append((key_states[batch: batch + 1, :, :, :], value_states[batch: batch+1, :, :, :]))
+        updated_past_key_values.append((key_states[batch: batch + 1, :, :, :],
+                                        value_states[batch: batch+1, :, :, :]))
 
     # repeat k/v heads if n_kv_heads < n_heads
     key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
@@ -459,9 +462,9 @@ def llama_attention_selective_batching_forward_4_31(
 
     if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
         invalidInputError(False,
-                          f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                          f" {attn_output.size()}"
-                          )
+                          f"`attn_output` should be of size "
+                          f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                          f" {attn_output.size()}")
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -637,7 +640,7 @@ def llama_model_selective_batching_forward_4_31(
                 attn_mask, (1, seq_length), inputs_embeds, past_key_value_length
             )
             attention_mask[i] = new_mask
-            i+=1
+            i += 1
 
     hidden_states = inputs_embeds
 
diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py
index 949ae3c05a2..b0751e55430 100644
--- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py
+++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py
@@ -182,7 +182,6 @@ def forward(
         # TODO: this could be deleted after prefill stage is also selective_batched
         decoding_attention_mask_list = []
         decoding_position_ids = []
-        # Attention_mask for decoding could also be a list of tensors due to inconsistent length of kv_cache
         # num_layers x len(seq_id) x (2 x torch.Tensor)
         if is_decoding_stage:
             batch = 0
@@ -198,12 +197,12 @@ def forward(
                 cur_attention_mask = [0] * (current_seq_len - cur_pos + 1) + [1] * (cur_pos)
                 decoding_attention_mask_list.append(cur_attention_mask)
 
-
         bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)
 
         # TODO: prefill requests could also be sbed, so that we can remove attention_mask forever
         if is_decoding_stage:
-            attention_mask = [torch.tensor(x, device=self.device).unsqueeze(0) for x in decoding_attention_mask_list]
+            attention_mask = [torch.tensor(x, device=self.device).unsqueeze(0)
+                              for x in decoding_attention_mask_list]
             position_ids = torch.tensor(decoding_position_ids).long().unsqueeze(-1)
             kwargs = {
                 "input_ids": bigdl_input_ids,