From b30eac758a4b89fcbe20cb877f0e3a917054a9b9 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Thu, 7 Dec 2023 16:22:14 +0800 Subject: [PATCH 01/26] add control to load hf model --- .../vllm/model_executor/models/bigdl_llama.py | 89 ++++++++++++------- 1 file changed, 55 insertions(+), 34 deletions(-) diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py index cecf4df616c..6c542eeb63d 100644 --- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py +++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py @@ -27,6 +27,7 @@ import math import time from bigdl.llm.vllm.model_executor.input_metadata import InputMetadata +import os from transformers.generation.logits_process import ( LogitsProcessorList, RepetitionPenaltyLogitsProcessor, @@ -62,42 +63,62 @@ def __init__( super().__init__(config, device, max_model_len) self.config = config # TODO(gc): later change this to a switch? - if True: + use_bigdl_lowbit = os.getenv("VLLM_USE_BIGDL_LOWBIT") + if use_bigdl_lowbit.lower() == "true": from bigdl.llm.transformers import AutoModelForCausalLM from bigdl.llm import optimize_model - - # low_bit = 'sym_int4' - if device == 'cpu': - model = AutoModelForCausalLM.from_pretrained( - config._name_or_path, - low_cpu_mem_usage=True, - trust_remote_code=True, - use_cache=True, - ) - self.model = optimize_model(model) - self.sampler = BigDLSampler(config.vocab_size, device) - elif device == 'xpu': - try: - import intel_extension_for_pytorch as ipex - except ImportError: - print("Intel Extension for PyTorch is not installed, \ - but is required for xpu inference.") - - low_bit = 'sym_int4' - model = AutoModelForCausalLM.from_pretrained( - config._name_or_path, - load_in_low_bit=low_bit, - trust_remote_code=True, - use_cache=True, - ) - self.model = model.to('xpu') - self.sampler = BigDLSampler(config.vocab_size, device).to('xpu') - - if device is None: - self.device = torch.device( - "cuda" if torch.cuda.is_available() else "cpu") + if device == 'cpu': + model = AutoModelForCausalLM.from_pretrained( + config._name_or_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + use_cache=True, + ) + self.model = optimize_model(model) + self.sampler = BigDLSampler(config.vocab_size, device) + elif device == 'xpu': + try: + import intel_extension_for_pytorch as ipex + except ImportError: + print("Intel Extension for PyTorch is not installed, \ + but is required for xpu inference.") + + low_bit = 'sym_int4' + model = AutoModelForCausalLM.from_pretrained( + config._name_or_path, + load_in_low_bit=low_bit, + trust_remote_code=True, + use_cache=True, + ) + self.model = model.to('xpu') + self.sampler = BigDLSampler(config.vocab_size, device).to('xpu') else: - self.device = torch.device(device) + from transformers import AutoModelForCausalLM + if device == 'cpu': + self.model = AutoModelForCausalLM.from_pretrained( + config._name_or_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + use_cache=True, + ) + self.sampler = BigDLSampler(config.vocab_size, device) + elif device == 'xpu': + try: + import intel_extension_for_pytorch as ipex + except ImportError: + print("Intel Extension for PyTorch is not installed, \ + but is required for xpu inference.") + model = AutoModelForCausalLM.from_pretrained( + config._name_or_path, + load_in_low_bit=low_bit, + trust_remote_code=True, + use_cache=True, + ) + self.model = model.to('xpu') + self.sampler = BigDLSampler( + config.vocab_size, device).to('xpu') 
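Note on the branch above: model loading is now gated on the VLLM_USE_BIGDL_LOWBIT environment variable. A minimal sketch of the intended toggle, assuming the empty-string default that PATCH 02/26 adds below (the helper name is illustrative and not part of the patch):

import os

def _use_bigdl_lowbit() -> bool:
    # PATCH 02/26 adds the "" default; without it, .lower() raises
    # AttributeError whenever the variable is unset.
    return os.getenv("VLLM_USE_BIGDL_LOWBIT", "").lower() == "true"

# A caller would opt in before constructing the model, e.g.
# os.environ["VLLM_USE_BIGDL_LOWBIT"] = "true"
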
+ + self.device = torch.device(device) self.dtype = self.model.dtype self.last_seq_ids = [] self.tmp_kv_cache = None @@ -116,7 +137,7 @@ def forward( decoder_kv_size = 2 bigdl_input_ids = [] - bigdl_position_ids = [] + # bigdl_position_ids = [] bigdl_attention_mask = [] cur_seq_ids = [] From 7dbe7ce6a73035517b521fe0413f38092aaeff54 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Tue, 12 Dec 2023 15:58:56 +0800 Subject: [PATCH 02/26] finish initial version of selective_batching --- .../vllm/model_executor/models/bigdl_llama.py | 42 +++++++++++-------- .../vllm/model_executor/models/bigdl_model.py | 32 +++++++++++--- 2 files changed, 52 insertions(+), 22 deletions(-) diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py index 6c542eeb63d..e7537e7d629 100644 --- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py +++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py @@ -63,7 +63,7 @@ def __init__( super().__init__(config, device, max_model_len) self.config = config # TODO(gc): later change this to a switch? - use_bigdl_lowbit = os.getenv("VLLM_USE_BIGDL_LOWBIT") + use_bigdl_lowbit = os.getenv("VLLM_USE_BIGDL_LOWBIT", "") if use_bigdl_lowbit.lower() == "true": from bigdl.llm.transformers import AutoModelForCausalLM from bigdl.llm import optimize_model @@ -110,7 +110,6 @@ def __init__( but is required for xpu inference.") model = AutoModelForCausalLM.from_pretrained( config._name_or_path, - load_in_low_bit=low_bit, trust_remote_code=True, use_cache=True, ) @@ -125,6 +124,12 @@ def __init__( self.pad_token_id = config.pad_token_id self.max_seq_limit = max_model_len + # GC: Note for selective batching + # KV_CACHE in the format of num_layers x 2 x (seq_id -> torch.Tensor) + # past_key_values in the format of num_layers x len(seq_id) x (2 x torch.Tensor) + # If we set num_layers to 9, have 10 sequences in total. + # then, for the kv_cache, we get 9 x 2 x 10 = 180 tensors + # for past_key_values, we get 9 x 10 x 2 = 180 tensors def forward( self, seq_group_meta_data_lists: List[SequenceGroupMetadata], @@ -165,8 +170,8 @@ def forward( # 1. Assemble bigdl_input_ids end if is_decoding_stage: - bigdl_kv_cache = self.prepare_kv_cache(cur_seq_ids, seq_group_meta_data_lists, - kv_cache, num_layers, decoder_kv_size) + bigdl_kv_cache = self.prepare_kv_cache_llama(cur_seq_ids, + kv_cache, num_layers) else: bigdl_attention_mask = _get_attention_mask_for_prompts(bigdl_input_ids, max_prompt_len) bigdl_input_ids = [ @@ -174,26 +179,28 @@ def forward( for input_ids in bigdl_input_ids ] - if is_decoding_stage: - cur_seq_len = bigdl_kv_cache[0][0].size(2) - for seq_group_meta_data in seq_group_meta_data_lists: - seq_ids = list(seq_group_meta_data.seq_data.keys()) - seq_id = seq_ids[0] - seq_data = seq_group_meta_data.seq_data[seq_id] - cur_pos = seq_data.get_len() - # bigdl_position_ids.append([cur_pos - 1]) - cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos) - bigdl_attention_mask.append(cur_attention_mask) + # GC(co): by using selective_batching, we no longer need to create + # attention_mask for decoding. 
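To illustrate the comment above (all lengths and sizes here are made up for the example): padded batching needs a per-step decode mask that zeroes out left padding, while selective batching keeps one exact-length KV tensor per sequence, so there is nothing left to mask.

import torch

# Padded batching: two sequences of 5 and 7 tokens share one mask of width
# 7 = padded kv length (6) + the newly decoded token; zeros mark the left
# padding, mirroring [0] * (cur_seq_len - cur_pos + 1) + [1] * cur_pos above.
padded_mask = torch.tensor([[0, 0, 1, 1, 1, 1, 1],
                            [1, 1, 1, 1, 1, 1, 1]])

# Selective batching: each sequence keeps its own exact-length KV tensor
# (4 and 6 past tokens here), so no padding remains to be masked.
num_kv_heads, head_dim = 32, 128  # illustrative sizes
kv_per_seq = [torch.zeros(1, num_kv_heads, 4, head_dim),
              torch.zeros(1, num_kv_heads, 6, head_dim)]
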
+ # if is_decoding_stage: + # cur_seq_len = bigdl_kv_cache[0][0].size(2) + # for seq_group_meta_data in seq_group_meta_data_lists: + # seq_ids = list(seq_group_meta_data.seq_data.keys()) + # seq_id = seq_ids[0] + # seq_data = seq_group_meta_data.seq_data[seq_id] + # cur_pos = seq_data.get_len() + # # bigdl_position_ids.append([cur_pos - 1]) + # cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos) + # bigdl_attention_mask.append(cur_attention_mask) bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device) if is_decoding_stage: - # bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device) - bigdl_attention_mask = torch.tensor(bigdl_attention_mask, device=self.device) kwargs = { "input_ids": bigdl_input_ids, + # gc(co): we rely on underlying model to generate position_ids # "position_ids": bigdl_position_ids, - "attention_mask": bigdl_attention_mask, + # gc(co): we no longer need attention_mask + # "attention_mask": bigdl_attention_mask, "past_key_values": bigdl_kv_cache, "use_cache": True, # "return_dict": True, @@ -207,6 +214,7 @@ def forward( "use_cache": True, # "return_dict": True, } + # Prefill may need additional space, which forces us to delete the last_kv_cache if self.last_kv_cache: del self.last_kv_cache # pdb.set_trace() diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py index 6694a3f1f4b..658f2155db0 100644 --- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py +++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py @@ -137,6 +137,30 @@ def prepare_kv_cache( return bigdl_kv_cache + # This is an implementation for models that KV Cache shape in (batch_size, num_heads, + # sequence_length, embed_size_per_head). + def prepare_kv_cache_llama( + self, + cur_seq_ids: List[int], + kv_cache: Dict, + num_layers: int, + ): + # Return bigdl_kv_cache in the format of Tuple(List[Tuple(torch.Tensor)]) + bigdl_kv_cache = [] + for i in range(num_layers): + # Construct a list of tuple(tensor) + temp_cache = [] + for seq_id in cur_seq_ids: + key = kv_cache[i][0][seq_id] + value = kv_cache[i][1][seq_id] + temp_cache.append((key, value)) + bigdl_kv_cache.append(temp_cache) + return bigdl_kv_cache + + # for i in range(len(cur_seq_ids)): + # current_kv = [] + # current_kv.append(kv_cache) + # This is an implementation for models that KV Cache shape in (batch_size, num_heads, # sequence_length, embed_size_per_head). 
def update_kv_cache( @@ -147,11 +171,9 @@ def update_kv_cache( kv_cache_size_1: int, ) -> None: for i in range(layer): - for j in range(kv_cache_size_1): - batch_dim = 0 - for seq_id in cur_seq_ids: - kv_cache[i][j][seq_id] = self.last_kv_cache[i][j][batch_dim] - batch_dim = batch_dim + 1 + for j in range(len(cur_seq_ids)): + kv_cache[i][0][cur_seq_ids[j]] = self.last_kv_cache[i][j][0] + kv_cache[i][1][cur_seq_ids[j]] = self.last_kv_cache[i][j][1] def forward( self, From dec985adda9741e18d7a5c87751ad0500e83d7ff Mon Sep 17 00:00:00 2001 From: gc-fu Date: Tue, 19 Dec 2023 11:15:57 +0800 Subject: [PATCH 03/26] temp --- .../llm/src/bigdl/llm/transformers/convert.py | 15 +- .../bigdl/llm/transformers/models/llama.py | 387 +++++++++++++++++- .../vllm/model_executor/models/bigdl_llama.py | 4 +- 3 files changed, 402 insertions(+), 4 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index bc53d4eea19..4f7cb681c8a 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -374,7 +374,12 @@ def _optimize_post(model, lightweight_bmm=False): from packaging import version from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31 from bigdl.llm.transformers.models.llama import llama_rms_norm_forward +<<<<<<< HEAD from bigdl.llm.transformers.models.llama import llama_mlp_forward +======= + from bigdl.llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31 + from bigdl.llm.transformers.models.llama import llama_model_selective_batching_forward_4_31 +>>>>>>> 372980fd83 (temp) from transformers.modeling_utils import PreTrainedModel # All huggingface format models are inherited from `PreTrainedModel` @@ -388,14 +393,22 @@ def _optimize_post(model, lightweight_bmm=False): convert_forward( model, transformers.models.llama.modeling_llama.LlamaAttention, - llama_attention_forward_4_31,) + llama_attention_selective_batching_forward_4_31,) convert_forward( model, transformers.models.llama.modeling_llama.LlamaRMSNorm, llama_rms_norm_forward,) +<<<<<<< HEAD convert_forward(model, transformers.models.llama.modeling_llama.LlamaMLP, llama_mlp_forward) +======= + convert_forward( + model, + transformers.models.llama.modeling_llama.LlamaModel, + llama_model_selective_batching_forward_4_31, + ) +>>>>>>> 372980fd83 (temp) else: # todo implement 4.28.0 ~ 4.30.2 pass diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index 7594c74baef..07b9605ef33 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -34,13 +34,14 @@ import torch import importlib import torch.nn as nn -from typing import Optional, Tuple +from typing import Optional, Tuple, Union, List import math import torch.nn.functional as F from bigdl.llm.utils.common import invalidInputError from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu +from transformers.modeling_outputs import BaseModelOutputWithPast from bigdl.llm.transformers.low_bit_linear import SYM_INT4 from bigdl.llm.ggml.quantize import ggml_tensor_qtype @@ -310,6 +311,234 @@ def llama_attention_forward_4_31( return attn_output.to(original_dtype), attn_weights, past_key_value 
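Before the new attention forward added below, a minimal sketch of the two past_key_value layouts it distinguishes (shapes are illustrative; the per-sequence layout matches what prepare_kv_cache_llama from PATCH 02/26 assembles):

import torch

bsz, num_kv_heads, head_dim = 2, 32, 128  # illustrative sizes
seq_lens = [4, 7]                         # per-sequence KV lengths may differ

# Stock HF layout: per layer one batched (key, value) pair, which forces all
# sequences in the batch to share the same (padded) kv length.
hf_layer_kv = (torch.zeros(bsz, num_kv_heads, max(seq_lens), head_dim),
               torch.zeros(bsz, num_kv_heads, max(seq_lens), head_dim))

# Selective-batching layout: per layer a list with one (key, value) pair per
# sequence, each kept at its exact length; this is the format that
# prepare_kv_cache_llama builds from the vLLM-side kv_cache dict.
sb_layer_kv = [(torch.zeros(1, num_kv_heads, n, head_dim),
                torch.zeros(1, num_kv_heads, n, head_dim)) for n in seq_lens]
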
+# dev-note: this differs to original llama model in the aspect +# that it has different format for past_key_value +def llama_attention_selective_batching_forward_4_31( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + device = hidden_states.device + # for flash attention + original_dtype = hidden_states.dtype + if not self.training and not hidden_states.requires_grad: + fsdp_flag = check_flash_attention_available(hidden_states) + else: + fsdp_flag = False + if fsdp_flag and q_len > 1: + attention_dtype = torch.float16 # use fp16 for flash attention + else: + attention_dtype = original_dtype + + # TODO: delete + # print("use fsdp_flag:" + str(fsdp_flag)) + # TODO: consider this branch later + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) + // self.config.pretraining_tp, dim=0) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) + for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) + for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) + for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, + self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, + self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, + self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + updated_past_key_values = [] + # Apply rotary_embeddings + max_kv_length = max(kv_pair[0].shape[-2] for kv_pair in past_key_value) + max_kv_length += key_states.shape[-2] + + # TODO: decide if we need to use_fuse_rope + cos, sin = self.rotary_emb(value_states, seq_len=max_kv_length) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, + position_ids, "llama") + # End of applying rotary_embedding + + batched_attention_output = [] + for batch in range(bsz): + # 2. 
concat key_states, value_states + # Get current key_states, value_states from previous cache + past_k, past_v = past_key_value[batch] + # Should be len + 1 + current_kv_len = past_k.shape[-2] + 1 + if past_k.stride()[1] <= past_k.size(2) * past_k.size(3): + new_past_k, new_past_v = extend_kv_cache(1, + self.num_key_value_heads, + self.head_dim, + past_k.size(2), + current_kv_len + + KV_CACHE_ALLOC_BLOCK_LENGTH, + dtype=past_k.dtype, + device=device) + new_past_k[:] = past_k + new_past_v[:] = past_v + past_k = new_past_k + past_v = new_past_v + current_key_states, current_value_states = append_kv_cache(past_k, past_v, # noqa + key_states[batch:batch + 1, :, :, :], # noqa + value_states[batch:batch + 1, :, :, :] # noqa + ) + # 2. concat key_states, value_states ends + + # 3. Record key_states, and value_states + updated_past_key_values.append((current_key_states, current_value_states)) + # 3. Record key_states, value_states end + + # 4. repeat kv + # repeat k/v heads if n_kv_heads < n_heads + current_key_states = repeat_kv(current_key_states, self.num_key_value_groups) + current_value_states = repeat_kv(current_value_states, self.num_key_value_groups) + # 4. repeat kv ends + + # TODO: decide if we want to use flash attention + # 5. Attention calculation + # TODO: decide if we need to apply attention mask + # TODO: fix attention weight + attn_output, attn_weights = native_sdp(query_states[batch:batch + 1, :, :, :], + current_key_states, + current_value_states, + None, + 1, + 1, + current_kv_len, + self.head_dim, + self.num_heads) + batched_attention_output.append(attn_output) + # 5. Attention calculation ends + # Concat attention output together + # (1, num_heads, 1, head_dim) + attn_output = torch.concat(batched_attention_output, dim=0) + batched_attention_output.clear() + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( # noqa + "`attn_output` should be of size " + f"{repr((bsz, self.num_heads, q_len, self.head_dim))}, but is " + f"{repr(attn_output.size())}" + ) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + # Apply rotary_embedding first in a batch + + # TODO: this output_attentions is not correct, we are not concat attention weight + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, updated_past_key_values if use_cache else None + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + use_fuse_rope = query_states.device.type == "xpu" + use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad) + use_fuse_rope = use_fuse_rope and self.config.rope_scaling is None + + # TODO: we do not know if this have the same effect + if use_fuse_rope: + query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, + key_states, + position_ids, + "llama") + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids, "llama") + + # past_key_value is None + if use_cache: + past_key_value = [] + max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH + new_key_states, new_value_states = init_kv_cache(bsz, + self.num_key_value_heads, + self.head_dim, + kv_seq_len, + max_cache_length, + dtype=key_states.dtype, + device=device) + new_key_states[:] = key_states + new_value_states[:] = value_states + 
key_states = new_key_states + value_states = new_value_states + for batch in range(bsz): + past_key_value.append((key_states[batch: batch + 1, :, :, :], + value_states[batch: batch+1, :, :, :])) + # past_key_value = (key_states, value_states) if use_cache else None + else: + past_key_value = None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups).to(device, + dtype=attention_dtype) + value_states = repeat_kv(value_states, self.num_key_value_groups).to(device, + dtype=attention_dtype) + + if fsdp_flag and q_len > 1: + # now only use flash attention for first token + attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype), + key_states, + value_states, + is_causal=True) + attn_weights = None + else: + # otherwise, use native attention + attn_output, attn_weights = native_sdp(query_states, key_states, value_states, + attention_mask, + bsz, q_len, kv_seq_len, + self.head_dim, self.num_heads) + + attn_output_size = (bsz, self.num_heads, q_len, self.head_dim) + if attn_output.size() != attn_output_size: + invalidInputError(False, + f"`attn_output` should be of size {attn_output_size}," + f" but is {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, + dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) + for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output.to(original_dtype), attn_weights, past_key_value + + def check_flash_attention_available(query): # check whether ipex flash attention can be used if query.device.type != "xpu": @@ -376,3 +605,159 @@ def native_sdp(query, key, value, attention_mask, dtype=torch.float32).to(value.dtype) attn_output = torch.matmul(attn_weights, value) return attn_output, attn_weights + + +def llama_model_selective_batching_forward_4_31( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + print("#########################We are using the newest code!!!") + if output_attentions is not None: + output_attentions = output_attentions + else: + output_attentions = self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids" # noqa + " and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ 
= inputs_embeds.shape + else: + raise ValueError("You have to specify either " # noqa + "decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + # The original position_ids in the format of [1, 1] + # However, this only applies when kv_len is the same for all the sequences + # We should set it to format of [batch, position_id] + # TODO: validate correctness + device = input_ids.device if input_ids is not None else inputs_embeds.device + if past_key_values is None: + # For prefill + position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + past_key_values_length = [] + for sequence_kv in past_key_values[0]: + key = sequence_kv[0] + past_key_values_length.append(key.shape[-2]) + position_ids = torch.tensor(past_key_values_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, 1) + + if past_key_values is not None: + # past_key_values in the format of num_layers x num_seqs x 2 + # TODO: this may be incorrect + past_key_values_length = past_key_values[0][0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + # if position_ids is None: + # device = input_ids.device if input_ids is not None else inputs_embeds.device + # # [start, end) + # position_ids = torch.arange( + # past_key_values_length, seq_length + + # past_key_values_length, dtype=torch.long, device=device + # ) + # position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + # else: + # position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + # TODO: only generate attention_mask for prefilling + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + # TODO: decide if we need this attention_mask, + # we are not using the attention mask when decoding + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # 
add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) # noqa + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py index e7537e7d629..6d982a82c31 100644 --- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py +++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py @@ -120,7 +120,7 @@ def __init__( self.device = torch.device(device) self.dtype = self.model.dtype self.last_seq_ids = [] - self.tmp_kv_cache = None + self.last_kv_cache = None self.pad_token_id = config.pad_token_id self.max_seq_limit = max_model_len @@ -216,7 +216,7 @@ def forward( } # Prefill may need additional space, which forces us to delete the last_kv_cache if self.last_kv_cache: - del self.last_kv_cache + self.last_kv_cache = None # pdb.set_trace() if self.device.type == 'xpu': From 9132a5ade8b746676150cd6fc62c05a50f433d59 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Tue, 19 Dec 2023 11:28:04 +0800 Subject: [PATCH 04/26] finish --- python/llm/src/bigdl/llm/transformers/convert.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 4f7cb681c8a..6b1ccec874e 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -374,12 +374,9 @@ def _optimize_post(model, lightweight_bmm=False): from packaging import version from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31 from bigdl.llm.transformers.models.llama import llama_rms_norm_forward -<<<<<<< HEAD from bigdl.llm.transformers.models.llama import llama_mlp_forward -======= from bigdl.llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31 from bigdl.llm.transformers.models.llama import llama_model_selective_batching_forward_4_31 ->>>>>>> 372980fd83 (temp) from transformers.modeling_utils import PreTrainedModel # All huggingface format models are inherited from `PreTrainedModel` @@ -398,17 +395,14 @@ def _optimize_post(model, lightweight_bmm=False): model, transformers.models.llama.modeling_llama.LlamaRMSNorm, llama_rms_norm_forward,) -<<<<<<< HEAD convert_forward(model, transformers.models.llama.modeling_llama.LlamaMLP, llama_mlp_forward) -======= convert_forward( model, transformers.models.llama.modeling_llama.LlamaModel, llama_model_selective_batching_forward_4_31, ) ->>>>>>> 372980fd83 (temp) else: # todo implement 4.28.0 ~ 4.30.2 pass From 4baaf9e641b95b73380ade85669c02b92b166919 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Tue, 19 Dec 2023 13:43:59 +0800 Subject: [PATCH 05/26] Remove print statement --- python/llm/src/bigdl/llm/transformers/models/llama.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index 07b9605ef33..40fab1320ca 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -103,6 +103,11 @@ def 
llama_mlp_forward( and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \ and not (self.training and x.requires_grad): import linear_q4_0 +<<<<<<< HEAD +======= + x_2d = x.view(-1, x.shape[-1]) + print(x_2d.shape) +>>>>>>> 22c9b3c573 (Remove print statement) if not x_2d.is_contiguous(): x_2d = x_2d.contiguous() return self.down_proj(linear_q4_0.mlp_forward_q4_0_xpu( @@ -619,7 +624,6 @@ def llama_model_selective_batching_forward_4_31( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - print("#########################We are using the newest code!!!") if output_attentions is not None: output_attentions = output_attentions else: From 59ad16a034229c37fbdc7b7c1498e85278c83e7e Mon Sep 17 00:00:00 2001 From: gc-fu Date: Tue, 19 Dec 2023 16:14:19 +0800 Subject: [PATCH 06/26] fix error --- python/llm/src/bigdl/llm/transformers/models/llama.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index 40fab1320ca..292647924e8 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -103,11 +103,6 @@ def llama_mlp_forward( and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \ and not (self.training and x.requires_grad): import linear_q4_0 -<<<<<<< HEAD -======= - x_2d = x.view(-1, x.shape[-1]) - print(x_2d.shape) ->>>>>>> 22c9b3c573 (Remove print statement) if not x_2d.is_contiguous(): x_2d = x_2d.contiguous() return self.down_proj(linear_q4_0.mlp_forward_q4_0_xpu( From 35a0ef8add9345adfbd87eda5dbd617e3e00dbd1 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Tue, 19 Dec 2023 16:17:52 +0800 Subject: [PATCH 07/26] Apply yang's optimization --- .../bigdl/llm/transformers/models/llama.py | 376 ++++++++++-------- 1 file changed, 205 insertions(+), 171 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index 292647924e8..c5dbc93ae1b 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -197,7 +197,6 @@ def llama_attention_forward_4_31( value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] value_states = torch.cat(value_states, dim=-1) - else: query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) @@ -311,6 +310,39 @@ def llama_attention_forward_4_31( return attn_output.to(original_dtype), attn_weights, past_key_value +# Return attn_output, attn_weights +def calculate_xpu_sdp(fsdp_flag, q_len, head_dim, num_heads, q_states, k_states, v_states, current_kv_length): + if fsdp_flag and q_len > 1: + # TODO: delete + print("F.scaled_dot_product_attention") + attn_output = F.scaled_dot_product_attention(q_states, + k_states, + v_states, + is_causal=True) + return attn_output, None + elif use_esimd_sdp(q_len, head_dim, q_states): + import linear_fp16_esimd + # TODO: delete + print("linear_fp16_esimd") + attn_output = linear_fp16_esimd.sdp_forward(q_states, + k_states.contiguous(), + v_states.contiguous()) + attn_output = attn_output.view(q_states.shape) + return attn_output, None + else: + # TODO: delete + print("native sdp") + attn_output, attn_weights = native_sdp(q_states, + k_states, + v_states, + None, + 1, + 1, + current_kv_length, + head_dim, + num_heads) + return attn_output, attn_weights + # dev-note: this differs to 
original llama model in the aspect # that it has different format for past_key_value def llama_attention_selective_batching_forward_4_31( @@ -337,183 +369,189 @@ def llama_attention_selective_batching_forward_4_31( else: attention_dtype = original_dtype - # TODO: delete - # print("use fsdp_flag:" + str(fsdp_flag)) - # TODO: consider this branch later - if self.config.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) - // self.config.pretraining_tp, dim=0) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) - for i in range(self.config.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) - for i in range(self.config.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) - for i in range(self.config.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) + # dev: new changes from yang's pr + use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) + # We have unusual past_key_value format we have only one batch + enough_kv_room = is_enough_kv_cache_room(past_key_value[0][0]) + is_q4_0 = self.q_proj.qtype == SYM_INT4 + no_tp = not self.config.pretraining_tp > 1 + decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and + enough_kv_room and bsz * q_len == 1) + # we put the new kv_caches into this + updated_past_key_values = [] + if decoding_fast_path: + # 1, 1, hidden_dimension + # TODO: delete + print("Debug: Decoding fast path") + hidden_states = hidden_states.view(1, -1) + kv_seq_len = past_key_value[0][0].shape[-2] + cache_k = past_key_value[0][0] + cache_v = past_key_value[0][1] + import linear_q4_0 + query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states, + self.q_proj.weight, + self.k_proj.weight, + self.v_proj.weight, + position_ids, + cache_k, cache_v, + self.q_proj.weight.qtype, + kv_seq_len, + self.head_dim) + kv_seq_len += 1 + # Append kv_cache + updated_past_key_values.append((key_states, value_states)) else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, - self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, - self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, - self.num_key_value_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - updated_past_key_values = [] - # Apply rotary_embeddings - max_kv_length = max(kv_pair[0].shape[-2] for kv_pair in past_key_value) - max_kv_length += key_states.shape[-2] - - # TODO: decide if we need to use_fuse_rope - cos, sin = self.rotary_emb(value_states, seq_len=max_kv_length) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, - position_ids, "llama") - # End of applying rotary_embedding - - batched_attention_output = [] - for batch in range(bsz): - # 2. 
concat key_states, value_states - # Get current key_states, value_states from previous cache - past_k, past_v = past_key_value[batch] - # Should be len + 1 - current_kv_len = past_k.shape[-2] + 1 - if past_k.stride()[1] <= past_k.size(2) * past_k.size(3): - new_past_k, new_past_v = extend_kv_cache(1, - self.num_key_value_heads, - self.head_dim, - past_k.size(2), - current_kv_len + - KV_CACHE_ALLOC_BLOCK_LENGTH, - dtype=past_k.dtype, - device=device) - new_past_k[:] = past_k - new_past_v[:] = past_v - past_k = new_past_k - past_v = new_past_v - current_key_states, current_value_states = append_kv_cache(past_k, past_v, # noqa - key_states[batch:batch + 1, :, :, :], # noqa - value_states[batch:batch + 1, :, :, :] # noqa - ) - # 2. concat key_states, value_states ends - - # 3. Record key_states, and value_states - updated_past_key_values.append((current_key_states, current_value_states)) - # 3. Record key_states, value_states end - - # 4. repeat kv - # repeat k/v heads if n_kv_heads < n_heads - current_key_states = repeat_kv(current_key_states, self.num_key_value_groups) - current_value_states = repeat_kv(current_value_states, self.num_key_value_groups) - # 4. repeat kv ends - - # TODO: decide if we want to use flash attention - # 5. Attention calculation - # TODO: decide if we need to apply attention mask - # TODO: fix attention weight - attn_output, attn_weights = native_sdp(query_states[batch:batch + 1, :, :, :], - current_key_states, - current_value_states, - None, - 1, - 1, - current_kv_len, - self.head_dim, - self.num_heads) - batched_attention_output.append(attn_output) - # 5. Attention calculation ends - # Concat attention output together - # (1, num_heads, 1, head_dim) - attn_output = torch.concat(batched_attention_output, dim=0) - batched_attention_output.clear() - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( # noqa - "`attn_output` should be of size " - f"{repr((bsz, self.num_heads, q_len, self.head_dim))}, but is " - f"{repr(attn_output.size())}" - ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - # Apply rotary_embedding first in a batch + if self.config.pretraining_tp > 1: + # TODO: implement the case with pretraining_tp + raise ValueError("selective batching: we have not implemented feature with" + " pretraining_tp > 1") + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) - # TODO: this output_attentions is not correct, we are not concat attention weight - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, updated_past_key_values if use_cache else None + query_states = query_states.view(bsz, q_len, + self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, + self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, + self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + if past_key_value is not None: + # Decoding + # Apply rotary_embeddings + max_kv_length = max(kv_pair[0].shape[-2] for kv_pair in past_key_value) + max_kv_length += key_states.shape[-2] + + # TODO: decide if we need to use_fuse_rope + cos, sin = self.rotary_emb(value_states, seq_len=max_kv_length) + query_states, key_states = 
apply_rotary_pos_emb(query_states, key_states, cos, sin, + position_ids, "llama") + # End of applying rotary_embedding + batched_attention_output = [] + for batch in range(bsz): + # 2. concat key_states, value_states + # Get current key_states, value_states from previous cache + past_k, past_v = past_key_value[batch] + # Should be len + 1 + current_kv_len = past_k.shape[-2] + 1 + if past_k.stride()[1] <= past_k.size(2) * past_k.size(3): + new_past_k, new_past_v = extend_kv_cache(1, + self.num_key_value_heads, + self.head_dim, + past_k.size(2), + current_kv_len + + KV_CACHE_ALLOC_BLOCK_LENGTH, + dtype=past_k.dtype, + device=device) + new_past_k[:] = past_k + new_past_v[:] = past_v + past_k = new_past_k + past_v = new_past_v + current_key_states, current_value_states = append_kv_cache(past_k, past_v, # noqa + key_states[batch:batch + 1, :, :, :], # noqa + value_states[batch:batch + 1, :, :, :] # noqa + ) + # 2. concat key_states, value_states ends + + # 3. Record key_states, and value_states + updated_past_key_values.append((current_key_states, current_value_states)) + # 3. Record key_states, value_states end + + # 4. repeat kv + # repeat k/v heads if n_kv_heads < n_heads + current_key_states = repeat_kv(current_key_states, self.num_key_value_groups) + current_value_states = repeat_kv(current_value_states, self.num_key_value_groups) + # 4. repeat kv ends + + # 5. Attention calculation + # TODO: decide if we need to apply attention mask + # TODO: fix attention weight + attn_output, attn_weights = calculate_xpu_sdp(fsdp_flag, + q_len, + self.head_dim, + self.num_heads, + query_states[batch:batch + 1, :, :, :], + current_key_states, + current_value_states, + current_kv_len) + batched_attention_output.append(attn_output) + # 5. Attention calculation ends + # Concat attention output together + # (1, num_heads, 1, head_dim) + attn_output = torch.concat(batched_attention_output, dim=0) + batched_attention_output.clear() + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( # noqa + "`attn_output` should be of size " + f"{repr((bsz, self.num_heads, q_len, self.head_dim))}, but is " + f"{repr(attn_output.size())}" + ) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + # Apply rotary_embedding first in a batch + + # TODO: this output_attentions is not correct, we are not concat attention weight + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, updated_past_key_values if use_cache else None + + # past_key_values is None + + # TODO: we assume this is prefill stage, but this may not + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] - use_fuse_rope = query_states.device.type == "xpu" - use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad) - use_fuse_rope = use_fuse_rope and self.config.rope_scaling is None + # TODO: we do not know if this have the same effect + if use_fuse_rope: + query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, + key_states, + position_ids, + "llama") + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids, "llama") - # TODO: we do not know if this have the same effect - if use_fuse_rope: - query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, - 
key_states, - position_ids, - "llama") - else: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, position_ids, "llama") - - # past_key_value is None - if use_cache: - past_key_value = [] - max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH - new_key_states, new_value_states = init_kv_cache(bsz, - self.num_key_value_heads, - self.head_dim, - kv_seq_len, - max_cache_length, - dtype=key_states.dtype, - device=device) - new_key_states[:] = key_states - new_value_states[:] = value_states - key_states = new_key_states - value_states = new_value_states - for batch in range(bsz): - past_key_value.append((key_states[batch: batch + 1, :, :, :], - value_states[batch: batch+1, :, :, :])) - # past_key_value = (key_states, value_states) if use_cache else None - else: - past_key_value = None + # past_key_value is None + if use_cache: + # past_key_value = [] + max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH + new_key_states, new_value_states = init_kv_cache(bsz, + self.num_key_value_heads, + self.head_dim, + kv_seq_len, + max_cache_length, + dtype=key_states.dtype, + device=device) + new_key_states[:] = key_states + new_value_states[:] = value_states + key_states = new_key_states + value_states = new_value_states + for batch in range(bsz): + updated_past_key_values.append((key_states[batch: batch + 1, :, :, :], + value_states[batch: batch+1, :, :, :])) + # past_key_value = (key_states, value_states) if use_cache else None + else: + # past_key_value = None + updated_past_key_values = None # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups).to(device, dtype=attention_dtype) value_states = repeat_kv(value_states, self.num_key_value_groups).to(device, dtype=attention_dtype) - - if fsdp_flag and q_len > 1: - # now only use flash attention for first token - attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype), - key_states, - value_states, - is_causal=True) - attn_weights = None - else: - # otherwise, use native attention - attn_output, attn_weights = native_sdp(query_states, key_states, value_states, - attention_mask, - bsz, q_len, kv_seq_len, - self.head_dim, self.num_heads) + attn_output, attn_weights = calculate_xpu_sdp(fsdp_flag, + q_len, + self.head_dim, + self.num_heads, + query_states, + key_states, + value_states, + kv_seq_len) attn_output_size = (bsz, self.num_heads, q_len, self.head_dim) if attn_output.size() != attn_output_size: @@ -525,11 +563,7 @@ def llama_attention_selective_batching_forward_4_31( attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) if self.config.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, - dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) - for i in range(self.config.pretraining_tp)]) + raise ValueError("Not Implemented") else: attn_output = self.o_proj(attn_output) From 196e878ec9a763def2c9e33a1c0b95ac8c25ba92 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Tue, 19 Dec 2023 17:57:04 +0800 Subject: [PATCH 08/26] a version that works --- .../bigdl/llm/transformers/models/llama.py | 37 +++++++++++++------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index 
c5dbc93ae1b..ab985f49f77 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -311,7 +311,7 @@ def llama_attention_forward_4_31( # Return attn_output, attn_weights -def calculate_xpu_sdp(fsdp_flag, q_len, head_dim, num_heads, q_states, k_states, v_states, current_kv_length): +def calculate_xpu_sdp(fsdp_flag, bsz, q_len, head_dim, num_heads, q_states, k_states, v_states, current_kv_length, attention_mask): if fsdp_flag and q_len > 1: # TODO: delete print("F.scaled_dot_product_attention") @@ -335,9 +335,9 @@ def calculate_xpu_sdp(fsdp_flag, q_len, head_dim, num_heads, q_states, k_states, attn_output, attn_weights = native_sdp(q_states, k_states, v_states, - None, - 1, - 1, + attention_mask, + bsz, + q_len, current_kv_length, head_dim, num_heads) @@ -372,7 +372,7 @@ def llama_attention_selective_batching_forward_4_31( # dev: new changes from yang's pr use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) # We have unusual past_key_value format we have only one batch - enough_kv_room = is_enough_kv_cache_room(past_key_value[0][0]) + enough_kv_room = past_key_value is not None and is_enough_kv_cache_room(past_key_value[0]) is_q4_0 = self.q_proj.qtype == SYM_INT4 no_tp = not self.config.pretraining_tp > 1 decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and @@ -381,6 +381,7 @@ def llama_attention_selective_batching_forward_4_31( # we put the new kv_caches into this updated_past_key_values = [] if decoding_fast_path: + # This decoding fast path? # 1, 1, hidden_dimension # TODO: delete print("Debug: Decoding fast path") @@ -469,13 +470,15 @@ def llama_attention_selective_batching_forward_4_31( # TODO: decide if we need to apply attention mask # TODO: fix attention weight attn_output, attn_weights = calculate_xpu_sdp(fsdp_flag, - q_len, + 1, + 1, self.head_dim, self.num_heads, query_states[batch:batch + 1, :, :, :], current_key_states, current_value_states, - current_kv_len) + current_kv_len, + None) batched_attention_output.append(attn_output) # 5. 
Attention calculation ends # Concat attention output together @@ -483,6 +486,7 @@ def llama_attention_selective_batching_forward_4_31( attn_output = torch.concat(batched_attention_output, dim=0) batched_attention_output.clear() if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + print("we are here 2") raise ValueError( # noqa "`attn_output` should be of size " f"{repr((bsz, self.num_heads, q_len, self.head_dim))}, but is " @@ -502,8 +506,6 @@ def llama_attention_selective_batching_forward_4_31( # TODO: we assume this is prefill stage, but this may not kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] # TODO: we do not know if this have the same effect if use_fuse_rope: @@ -544,17 +546,28 @@ def llama_attention_selective_batching_forward_4_31( dtype=attention_dtype) value_states = repeat_kv(value_states, self.num_key_value_groups).to(device, dtype=attention_dtype) + print("q_len is " + str(q_len)) + print("kv_seq_len is " + str(kv_seq_len)) + + print(query_states[0][0][0][0]) + print(key_states[0][0][0][0]) + print(value_states[0][0][0][0]) + # TODO: this might be wrong for decoding + # TODO: for decoding we do not want to apply attention_mask attn_output, attn_weights = calculate_xpu_sdp(fsdp_flag, + bsz, q_len, self.head_dim, self.num_heads, query_states, key_states, value_states, - kv_seq_len) - + kv_seq_len, + attention_mask) + print(attn_output[0][0][0][0]) attn_output_size = (bsz, self.num_heads, q_len, self.head_dim) if attn_output.size() != attn_output_size: + print("we are at here 1") invalidInputError(False, f"`attn_output` should be of size {attn_output_size}," f" but is {attn_output.size()}") @@ -570,7 +583,7 @@ def llama_attention_selective_batching_forward_4_31( if not output_attentions: attn_weights = None - return attn_output.to(original_dtype), attn_weights, past_key_value + return attn_output.to(original_dtype), attn_weights, updated_past_key_values def check_flash_attention_available(query): From 6fa0f68a85ea480cf0a42ec2f552a5d02b4e087a Mon Sep 17 00:00:00 2001 From: gc-fu Date: Wed, 20 Dec 2023 12:31:42 +0800 Subject: [PATCH 09/26] We need to check kv_cache passed in, this could be an error. TODO: add fast decoding path --- .../bigdl/llm/transformers/models/llama.py | 74 ++++++++++++------- .../vllm/model_executor/models/bigdl_llama.py | 47 +++++++----- 2 files changed, 77 insertions(+), 44 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index ab985f49f77..54de57134bc 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -331,7 +331,7 @@ def calculate_xpu_sdp(fsdp_flag, bsz, q_len, head_dim, num_heads, q_states, k_st return attn_output, None else: # TODO: delete - print("native sdp") + #print("native sdp") attn_output, attn_weights = native_sdp(q_states, k_states, v_states, @@ -380,11 +380,13 @@ def llama_attention_selective_batching_forward_4_31( # we put the new kv_caches into this updated_past_key_values = [] + # TODO: change this later + decoding_fast_path = False if decoding_fast_path: # This decoding fast path? # 1, 1, hidden_dimension # TODO: delete - print("Debug: Decoding fast path") + #print("Debug: Decoding fast path") hidden_states = hidden_states.view(1, -1) kv_seq_len = past_key_value[0][0].shape[-2] cache_k = past_key_value[0][0] @@ -467,7 +469,6 @@ def llama_attention_selective_batching_forward_4_31( # 4. 
repeat kv ends # 5. Attention calculation - # TODO: decide if we need to apply attention mask # TODO: fix attention weight attn_output, attn_weights = calculate_xpu_sdp(fsdp_flag, 1, @@ -478,7 +479,7 @@ def llama_attention_selective_batching_forward_4_31( current_key_states, current_value_states, current_kv_len, - None) + attention_mask[batch]) batched_attention_output.append(attn_output) # 5. Attention calculation ends # Concat attention output together @@ -486,7 +487,6 @@ def llama_attention_selective_batching_forward_4_31( attn_output = torch.concat(batched_attention_output, dim=0) batched_attention_output.clear() if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - print("we are here 2") raise ValueError( # noqa "`attn_output` should be of size " f"{repr((bsz, self.num_heads, q_len, self.head_dim))}, but is " @@ -542,16 +542,20 @@ def llama_attention_selective_batching_forward_4_31( updated_past_key_values = None # repeat k/v heads if n_kv_heads < n_heads + + if not decoding_fast_path: + print(f"Prefill with batching size {bsz}") + key_states = repeat_kv(key_states, self.num_key_value_groups).to(device, dtype=attention_dtype) value_states = repeat_kv(value_states, self.num_key_value_groups).to(device, dtype=attention_dtype) - print("q_len is " + str(q_len)) - print("kv_seq_len is " + str(kv_seq_len)) + #print("q_len is " + str(q_len)) + #print("kv_seq_len is " + str(kv_seq_len)) - print(query_states[0][0][0][0]) - print(key_states[0][0][0][0]) - print(value_states[0][0][0][0]) + #print(query_states[0][0][0][0]) + #print(key_states[0][0][0][0]) + #print(value_states[0][0][0][0]) # TODO: this might be wrong for decoding # TODO: for decoding we do not want to apply attention_mask attn_output, attn_weights = calculate_xpu_sdp(fsdp_flag, @@ -564,10 +568,10 @@ def llama_attention_selective_batching_forward_4_31( value_states, kv_seq_len, attention_mask) - print(attn_output[0][0][0][0]) + #print(attn_output[0][0][0][0]) attn_output_size = (bsz, self.num_heads, q_len, self.head_dim) if attn_output.size() != attn_output_size: - print("we are at here 1") + #print("we are at here 1") invalidInputError(False, f"`attn_output` should be of size {attn_output_size}," f" but is {attn_output.size()}") @@ -698,17 +702,25 @@ def llama_model_selective_batching_forward_4_31( # We should set it to format of [batch, position_id] # TODO: validate correctness device = input_ids.device if input_ids is not None else inputs_embeds.device - if past_key_values is None: - # For prefill - position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + if position_ids is None: + # This should never happened in our case + print("position_ids is None!!!") + raise ValueError("Position_ids should never be None") else: - past_key_values_length = [] - for sequence_kv in past_key_values[0]: - key = sequence_kv[0] - past_key_values_length.append(key.shape[-2]) - position_ids = torch.tensor(past_key_values_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).view(-1, 1) + print(f"Original position_ids is {position_ids}") + position_ids = position_ids.view(-1, seq_length) + print(f"after position_ids is {position_ids}") + # if past_key_values is None: + # # For prefill + # position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device) + # position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + # else: + # past_key_values_length = [] + # for sequence_kv in past_key_values[0]: + # key 
= sequence_kv[0] + # past_key_values_length.append(key.shape[-2]) + # position_ids = torch.tensor(past_key_values_length, dtype=torch.long, device=device) + # position_ids = position_ids.unsqueeze(0).view(-1, 1) if past_key_values is not None: # past_key_values in the format of num_layers x num_seqs x 2 @@ -732,12 +744,20 @@ def llama_model_selective_batching_forward_4_31( # embed positions # TODO: only generate attention_mask for prefilling if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + raise ValueError("attention_mask should never be None") + if past_key_values is None: + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) + else: + i = 0 + for attn_mask in attention_mask: + past_key_value_length = past_key_values[0][i][0].shape[2] + new_mask = self._prepare_decoder_attention_mask( + attn_mask, (1, seq_length), inputs_embeds, past_key_value_length + ) + attention_mask[i] = new_mask + i+=1 hidden_states = inputs_embeds diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py index 6d982a82c31..949ae3c05a2 100644 --- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py +++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py @@ -179,37 +179,50 @@ def forward( for input_ids in bigdl_input_ids ] - # GC(co): by using selective_batching, we no longer need to create - # attention_mask for decoding. - # if is_decoding_stage: - # cur_seq_len = bigdl_kv_cache[0][0].size(2) - # for seq_group_meta_data in seq_group_meta_data_lists: - # seq_ids = list(seq_group_meta_data.seq_data.keys()) - # seq_id = seq_ids[0] - # seq_data = seq_group_meta_data.seq_data[seq_id] - # cur_pos = seq_data.get_len() - # # bigdl_position_ids.append([cur_pos - 1]) - # cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos) - # bigdl_attention_mask.append(cur_attention_mask) + # TODO: this could be deleted after prefill stage is also selective_batched + decoding_attention_mask_list = [] + decoding_position_ids = [] + # Attention_mask for decoding could also be a list of tensors due to inconsistent length of kv_cache + # num_layers x len(seq_id) x (2 x torch.Tensor) + if is_decoding_stage: + batch = 0 + for seq_group_meta_data in seq_group_meta_data_lists: + # Get current seq_len in kv_cache + current_seq_len = bigdl_kv_cache[0][batch][0].size(2) + batch += 1 + seq_ids = list(seq_group_meta_data.seq_data.keys()) + seq_data = seq_group_meta_data.seq_data[seq_ids[0]] + cur_pos = seq_data.get_len() + decoding_position_ids.append(cur_pos - 1) + # Total length: current_seq_len + 1 + cur_attention_mask = [0] * (current_seq_len - cur_pos + 1) + [1] * (cur_pos) + decoding_attention_mask_list.append(cur_attention_mask) + bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device) + # TODO: prefill requests could also be sbed, so that we can remove attention_mask forever if is_decoding_stage: + attention_mask = [torch.tensor(x, device=self.device).unsqueeze(0) for x in decoding_attention_mask_list] + position_ids = torch.tensor(decoding_position_ids).long().unsqueeze(-1) kwargs = { "input_ids": bigdl_input_ids, # gc(co): we rely on underlying model to generate 
position_ids - # "position_ids": bigdl_position_ids, - # gc(co): we no longer need attention_mask - # "attention_mask": bigdl_attention_mask, + "position_ids": position_ids, + "attention_mask": attention_mask, "past_key_values": bigdl_kv_cache, "use_cache": True, # "return_dict": True, } else: + # Prefill stage + attention_mask = torch.tensor(bigdl_attention_mask, device=self.device) + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) kwargs = { "input_ids": bigdl_input_ids, - "attention_mask": torch.tensor(bigdl_attention_mask, device=self.device), - # "position_ids": bigdl_position_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, "past_key_values": None, "use_cache": True, # "return_dict": True, From 6e883c471f1c3e27ced365186d29c3bc5cead900 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Wed, 20 Dec 2023 13:19:42 +0800 Subject: [PATCH 10/26] format --- python/llm/src/bigdl/llm/transformers/models/llama.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index 54de57134bc..a53c595cabd 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -453,9 +453,9 @@ def llama_attention_selective_batching_forward_4_31( past_k = new_past_k past_v = new_past_v current_key_states, current_value_states = append_kv_cache(past_k, past_v, # noqa - key_states[batch:batch + 1, :, :, :], # noqa - value_states[batch:batch + 1, :, :, :] # noqa - ) + key_states[batch:batch + 1, :, :, :], # noqa + value_states[batch:batch + 1, :, :, :] # noqa + ) # 2. concat key_states, value_states ends # 3. Record key_states, and value_states @@ -745,6 +745,7 @@ def llama_model_selective_batching_forward_4_31( # TODO: only generate attention_mask for prefilling if attention_mask is None: raise ValueError("attention_mask should never be None") + print(f"attention_mask before expanding: {attention_mask}") if past_key_values is None: attention_mask = self._prepare_decoder_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length From 39ff8980a46d4ce2f41c23f01a5ffa30d539820b Mon Sep 17 00:00:00 2001 From: gc-fu Date: Wed, 20 Dec 2023 14:36:29 +0800 Subject: [PATCH 11/26] temp solution: not batching prefill requests --- python/llm/src/bigdl/llm/vllm/core/scheduler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/llm/src/bigdl/llm/vllm/core/scheduler.py b/python/llm/src/bigdl/llm/vllm/core/scheduler.py index b41ea166d45..be5c8103af2 100644 --- a/python/llm/src/bigdl/llm/vllm/core/scheduler.py +++ b/python/llm/src/bigdl/llm/vllm/core/scheduler.py @@ -226,6 +226,8 @@ def _schedule(self) -> SchedulerOutputs: num_batched_tokens += num_prompt_tokens num_curr_seqs += num_new_seqs scheduled.append(seq_group) + # TODO: we choose to not batching the prefill requests + break if scheduled or ignored_seq_groups: scheduler_outputs = SchedulerOutputs( From 6ebc5805909996f301662170153d129afa565b56 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Wed, 20 Dec 2023 16:18:17 +0800 Subject: [PATCH 12/26] a version that works for prefill batching --- .../llm/src/bigdl/llm/transformers/convert.py | 4 +- .../bigdl/llm/transformers/models/llama.py | 263 ++++++++++++++++++ .../llm/src/bigdl/llm/vllm/core/scheduler.py | 2 +- 3 files changed, 266 insertions(+), 3 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/convert.py 
b/python/llm/src/bigdl/llm/transformers/convert.py index 6b1ccec874e..d8512efc12a 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -375,7 +375,7 @@ def _optimize_post(model, lightweight_bmm=False): from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31 from bigdl.llm.transformers.models.llama import llama_rms_norm_forward from bigdl.llm.transformers.models.llama import llama_mlp_forward - from bigdl.llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31 + from bigdl.llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31, llama_attention_forward_4_31_so_sb from bigdl.llm.transformers.models.llama import llama_model_selective_batching_forward_4_31 from transformers.modeling_utils import PreTrainedModel @@ -390,7 +390,7 @@ def _optimize_post(model, lightweight_bmm=False): convert_forward( model, transformers.models.llama.modeling_llama.LlamaAttention, - llama_attention_selective_batching_forward_4_31,) + llama_attention_forward_4_31_so_sb,) convert_forward( model, transformers.models.llama.modeling_llama.LlamaRMSNorm, diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index a53c595cabd..189ee6e8934 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -310,6 +310,269 @@ def llama_attention_forward_4_31( return attn_output.to(original_dtype), attn_weights, past_key_value +def llama_attention_forward_4_31_so_sb( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + device = hidden_states.device + # for flash attention + original_dtype = hidden_states.dtype + # TODO: consider this later + # if not self.training and not hidden_states.requires_grad: + # fsdp_flag = check_flash_attention_available(hidden_states) + # else: + # fsdp_flag = False + # if fsdp_flag and q_len > 1: + # attention_dtype = torch.float16 # use fp16 for flash attention + # else: + # attention_dtype = original_dtype + + attention_dtype = original_dtype + + # use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) + # enough_kv_room = is_enough_kv_cache_room(past_key_value[0]) + # is_q4_0 = self.q_proj.qtype == SYM_INT4 + # no_tp = not self.config.pretraining_tp > 1 + # decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and + # enough_kv_room and bsz * q_len == 1) + + # single batch decoding fast path + # forward_qkv takes will perform QKV projection, rotary position embedding + # and save the key/value states to cache, then return query states and the + # extended key/value cache + # if decoding_fast_path: + # hidden_states = hidden_states.view(1, -1) + # kv_seq_len = past_key_value[0].shape[-2] + # cache_k = past_key_value[0] + # cache_v = past_key_value[1] + # import linear_q4_0 + # query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states, + # self.q_proj.weight, + # self.k_proj.weight, + # self.v_proj.weight, + # position_ids, + # cache_k, cache_v, + # self.q_proj.weight.qtype, + # kv_seq_len, + # 
self.head_dim) + # kv_seq_len += 1 + + # else: + if self.config.pretraining_tp > 1: + raise NotImplementedError("config.pretraining_tp not implemented") + # key_value_slicing = ((self.num_key_value_heads * self.head_dim) // + # self.config.pretraining_tp) + # query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) + # // self.config.pretraining_tp, dim=0) + # key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + # value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + # query_states = [F.linear(hidden_states, query_slices[i]) + # for i in range(self.config.pretraining_tp)] + # query_states = torch.cat(query_states, dim=-1) + + # key_states = [F.linear(hidden_states, key_slices[i]) + # for i in range(self.config.pretraining_tp)] + # key_states = torch.cat(key_states, dim=-1) + + # value_states = [F.linear(hidden_states, value_slices[i]) + # for i in range(self.config.pretraining_tp)] + # value_states = torch.cat(value_states, dim=-1) + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, + self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, + self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, + self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += max(kv_pair[0].shape[-2] for kv_pair in past_key_value) + + # if use_fuse_rope: + # query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, + # key_states, + # position_ids, + # "llama") + # else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids, "llama") + + updated_past_key_values = [] + if past_key_value is not None: + batched_attention_output = [] + for batch in range(bsz): + past_k, past_v = past_key_value[batch] + current_kv_len = past_k.shape[-2] + 1 + + current_key_states = torch.cat([past_k, key_states[batch: batch + 1, : , :, :]], dim=2) + current_value_states = torch.cat([past_v, value_states[batch: batch + 1, :, :, :]], dim=2) + + updated_past_key_values.append((current_key_states, current_value_states)) + + current_key_states = repeat_kv(current_key_states, self.num_key_value_groups) + current_value_states = repeat_kv(current_value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states[batch: batch + 1, :, :, :], current_key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attn_weights.size() != (1, self.num_heads, 1, current_kv_len): + raise ValueError( + f"Attention weights should be of size {(1, self.num_heads, 1, current_kv_len)}, but is" + f" {attn_weights.size()}" + ) + # TODO: decide if we need to apply attention mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + # (1, num_heads, 1, kv_len + 1) x (1, num_heads, kv_len + 1, head_dim) + # (1, num_heads, 1, head_dim) + attn_output = torch.matmul(attn_weights, current_value_states) + if attn_output.size() != (1, self.num_heads, 1, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(1, self.num_heads, 1, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + batched_attention_output.append(attn_output) + # For loop ends + attn_output = torch.concat(batched_attention_output, dim=0) + 
batched_attention_output.clear() + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output, None, updated_past_key_values + # TODO: apply later + # reuse k, v, self_attention + # cache_k = past_key_value[0] + # cache_v = past_key_value[1] + # if not enough_kv_room: + # # allocate new + # new_cache_k, new_cache_v = extend_kv_cache(bsz, + # self.num_key_value_heads, # Support GQA + # self.head_dim, + # cache_k.size(2), + # kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, + # dtype=cache_k.dtype, + # device=device) + # new_cache_k[:] = cache_k + # new_cache_v[:] = cache_v + # cache_k = new_cache_k + # cache_v = new_cache_v + + # key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states) + + # elif use_cache: + # # Must be prefill + # max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH + # new_key_states, new_value_states = init_kv_cache(bsz, + # self.num_key_value_heads, + # self.head_dim, + # kv_seq_len, + # max_cache_length, + # dtype=key_states.dtype, + # device=device) + # new_key_states[:] = key_states + # new_value_states[:] = value_states + # key_states = new_key_states + # value_states = new_value_states + + # TODO: Assume always use_cache + for batch in range(bsz): + updated_past_key_values.append((key_states[batch: batch + 1, :, :, :], value_states[batch: batch+1, :, :, :])) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups).to(device, + dtype=attention_dtype) + value_states = repeat_kv(value_states, self.num_key_value_groups).to(device, + dtype=attention_dtype) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + # Apply attention_mask in four dimensions + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + # if fsdp_flag and q_len > 1: + # # now only use flash attention for first token + # attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype), + # key_states, + # value_states, + # is_causal=True) + # attn_weights = None + # elif use_esimd_sdp(q_len, self.head_dim, query_states): + # import linear_fp16_esimd + # attn_output = linear_fp16_esimd.sdp_forward(query_states, + # key_states.contiguous(), + # 
value_states.contiguous()) + # attn_output = attn_output.view(query_states.shape) + # attn_weights = None + # else: + # # otherwise, use native attention + # attn_output, attn_weights = native_sdp(query_states, key_states, value_states, + # attention_mask, + # bsz, q_len, kv_seq_len, + # self.head_dim, self.num_heads) + + # attn_output_size = (bsz, self.num_heads, q_len, self.head_dim) + # if attn_output.size() != attn_output_size: + # invalidInputError(False, + # f"`attn_output` should be of size {attn_output_size}," + # f" but is {attn_output.size()}") + + # attn_output = attn_output.transpose(1, 2).contiguous() + # attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + # if self.config.pretraining_tp > 1: + # attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + # o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, + # dim=1) + # attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) + # for i in range(self.config.pretraining_tp)]) + # else: + # attn_output = self.o_proj(attn_output) + + # if not output_attentions: + # attn_weights = None + + return attn_output.to(original_dtype), attn_weights, updated_past_key_values + + # Return attn_output, attn_weights def calculate_xpu_sdp(fsdp_flag, bsz, q_len, head_dim, num_heads, q_states, k_states, v_states, current_kv_length, attention_mask): if fsdp_flag and q_len > 1: diff --git a/python/llm/src/bigdl/llm/vllm/core/scheduler.py b/python/llm/src/bigdl/llm/vllm/core/scheduler.py index be5c8103af2..cc67638a576 100644 --- a/python/llm/src/bigdl/llm/vllm/core/scheduler.py +++ b/python/llm/src/bigdl/llm/vllm/core/scheduler.py @@ -227,7 +227,7 @@ def _schedule(self) -> SchedulerOutputs: num_curr_seqs += num_new_seqs scheduled.append(seq_group) # TODO: we choose to not batching the prefill requests - break + # break if scheduled or ignored_seq_groups: scheduler_outputs = SchedulerOutputs( From 9f4435a1c584f29d2c16dd759edb510f75c517a5 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Thu, 21 Dec 2023 10:26:11 +0800 Subject: [PATCH 13/26] format --- .../bigdl/llm/transformers/models/llama.py | 100 +++++++++--------- 1 file changed, 51 insertions(+), 49 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index 189ee6e8934..5e96a1a02f5 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -44,6 +44,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPast from bigdl.llm.transformers.low_bit_linear import SYM_INT4 from bigdl.llm.ggml.quantize import ggml_tensor_qtype +from bigdl.llm.utils.common import invalidInputError def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -367,25 +368,7 @@ def llama_attention_forward_4_31_so_sb( # else: if self.config.pretraining_tp > 1: - raise NotImplementedError("config.pretraining_tp not implemented") - # key_value_slicing = ((self.num_key_value_heads * self.head_dim) // - # self.config.pretraining_tp) - # query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) - # // self.config.pretraining_tp, dim=0) - # key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - # value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - # query_states = [F.linear(hidden_states, query_slices[i]) - # for i in range(self.config.pretraining_tp)] - # query_states = torch.cat(query_states, dim=-1) - - # key_states = 
[F.linear(hidden_states, key_slices[i]) - # for i in range(self.config.pretraining_tp)] - # key_states = torch.cat(key_states, dim=-1) - - # value_states = [F.linear(hidden_states, value_slices[i]) - # for i in range(self.config.pretraining_tp)] - # value_states = torch.cat(value_states, dim=-1) + invalidInputError(False, f"vLLM: config.pretraining_tp > 1 not supported yet") else: query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) @@ -415,6 +398,7 @@ def llama_attention_forward_4_31_so_sb( updated_past_key_values = [] if past_key_value is not None: batched_attention_output = [] + # print(f"type of attention_mask is {type(attention_mask)}") for batch in range(bsz): past_k, past_v = past_key_value[batch] current_kv_len = past_k.shape[-2] + 1 @@ -427,30 +411,47 @@ def llama_attention_forward_4_31_so_sb( current_key_states = repeat_kv(current_key_states, self.num_key_value_groups) current_value_states = repeat_kv(current_value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states[batch: batch + 1, :, :, :], current_key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attn_weights.size() != (1, self.num_heads, 1, current_kv_len): - raise ValueError( - f"Attention weights should be of size {(1, self.num_heads, 1, current_kv_len)}, but is" - f" {attn_weights.size()}" - ) - # TODO: decide if we need to apply attention mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + current_query_states = query_states[batch: batch + 1, :, :, :] + attn_output, aattn_weights = native_sdp(current_query_states, + current_key_states, + current_value_states, + attention_mask[batch], + 1, + 1, + current_kv_len, + self.head_dim, + self.num_heads) + # attn_weights = torch.matmul(query_states[batch: batch + 1, :, :, :], current_key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + # if attn_weights.size() != (1, self.num_heads, 1, current_kv_len): + # invalidInputError(False, + # f"Attention weights should be of size {(1, self.num_heads, 1, current_kv_len)}, but is" + # f" {attn_weights.size()}" + # ) + # if attention_mask is not None: + # if attention_mask[batch].size() != (1, 1, q_len, current_kv_len): + # invalidInputError(False, + # f"Attention mask should be of size {(1, 1, q_len, current_kv_len)}, but is {attention_mask.size()}" + # ) + # # print(f"added attention_mask") + # attn_weights = attn_weights + attention_mask[batch] + # attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) # (1, num_heads, 1, kv_len + 1) x (1, num_heads, kv_len + 1, head_dim) # (1, num_heads, 1, head_dim) - attn_output = torch.matmul(attn_weights, current_value_states) + # attn_output = torch.matmul(attn_weights, current_value_states) if attn_output.size() != (1, self.num_heads, 1, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(1, self.num_heads, 1, self.head_dim)}, but is" - f" {attn_output.size()}" + invalidInputError(False, + f"`attn_output` should be of size {(1, self.num_heads, 1, self.head_dim)}, but is" + f" {attn_output.size()}" ) batched_attention_output.append(attn_output) # For loop ends + # TODO: handle attention_weights attn_output = torch.concat(batched_attention_output, dim=0) batched_attention_output.clear() if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" + 
invalidInputError(False, + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -492,6 +493,7 @@ def llama_attention_forward_4_31_so_sb( # value_states = new_value_states # TODO: Assume always use_cache + print(f"prefill with batch size {bsz}") for batch in range(bsz): updated_past_key_values.append((key_states[batch: batch + 1, :, :, :], value_states[batch: batch+1, :, :, :])) @@ -503,14 +505,14 @@ def llama_attention_forward_4_31_so_sb( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( + invalidInputError(False, f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" f" {attn_weights.size()}" ) # Apply attention_mask in four dimensions if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( + invalidInputError(False, f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights + attention_mask @@ -519,7 +521,7 @@ def llama_attention_forward_4_31_so_sb( attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( + invalidInputError(False, f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}" ) @@ -806,8 +808,8 @@ def llama_attention_selective_batching_forward_4_31( # repeat k/v heads if n_kv_heads < n_heads - if not decoding_fast_path: - print(f"Prefill with batching size {bsz}") + # if not decoding_fast_path: + # print(f"Prefill with batching size {bsz}") key_states = repeat_kv(key_states, self.num_key_value_groups).to(device, dtype=attention_dtype) @@ -947,15 +949,17 @@ def llama_model_selective_batching_forward_4_31( # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids" # noqa - " and decoder_inputs_embeds at the same time") + invalidInputError(False, + "You cannot specify both decoder_input_ids" + " and decoder_inputs_embeds at the same time") elif input_ids is not None: batch_size, seq_length = input_ids.shape elif inputs_embeds is not None: batch_size, seq_length, _ = inputs_embeds.shape else: - raise ValueError("You have to specify either " # noqa - "decoder_input_ids or decoder_inputs_embeds") + invalidInputError(False, + "You have to specify either " + "decoder_input_ids or decoder_inputs_embeds") seq_length_with_past = seq_length past_key_values_length = 0 @@ -966,13 +970,11 @@ def llama_model_selective_batching_forward_4_31( # TODO: validate correctness device = input_ids.device if input_ids is not None else inputs_embeds.device if position_ids is None: - # This should never happened in our case - print("position_ids is None!!!") - raise ValueError("Position_ids should never be None") + invalidInputError("vLLM: position_ids should never be None") else: - print(f"Original position_ids is {position_ids}") + # print(f"Original position_ids is {position_ids}") position_ids = position_ids.view(-1, seq_length) - print(f"after position_ids is {position_ids}") + # print(f"after position_ids is {position_ids}") # if past_key_values is None: # # For prefill # position_ids = 
torch.arange(0, seq_length, dtype=torch.long, device=device) @@ -1007,7 +1009,7 @@ def llama_model_selective_batching_forward_4_31( # embed positions # TODO: only generate attention_mask for prefilling if attention_mask is None: - raise ValueError("attention_mask should never be None") + invalidInputError(False, "attention_mask should never be None") print(f"attention_mask before expanding: {attention_mask}") if past_key_values is None: attention_mask = self._prepare_decoder_attention_mask( From 16bd217dbc695b8f67af2e4315380b5ed76eb0e3 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Thu, 21 Dec 2023 10:29:52 +0800 Subject: [PATCH 14/26] a solid version: works normally --- .../bigdl/llm/transformers/models/llama.py | 44 +++++++++++-------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index 5e96a1a02f5..941faac8768 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -445,7 +445,7 @@ def llama_attention_forward_4_31_so_sb( ) batched_attention_output.append(attn_output) # For loop ends - # TODO: handle attention_weights + # TODO: handle attention_weights later attn_output = torch.concat(batched_attention_output, dim=0) batched_attention_output.clear() if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): @@ -502,23 +502,31 @@ def llama_attention_forward_4_31_so_sb( dtype=attention_dtype) value_states = repeat_kv(value_states, self.num_key_value_groups).to(device, dtype=attention_dtype) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - invalidInputError(False, - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - # Apply attention_mask in four dimensions - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - invalidInputError(False, - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) + attn_output, attn_weights = native_sdp(query_states, + key_states, + value_states, + attention_mask, + bsz, + q_len, + kv_seq_len, + self.head_dim, + self.num_heads) + # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + # if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + # invalidInputError(False, + # f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + # f" {attn_weights.size()}" + # ) + # # Apply attention_mask in four dimensions + # if attention_mask is not None: + # if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + # invalidInputError(False, + # f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + # ) + # attn_weights = attn_weights + attention_mask + # # upcast attention to fp32 + # attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + # attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 
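# --- Editorial aside (not part of the patch): `native_sdp` used in this commit is
# --- assumed to be BigDL's helper that performs the same masked scaled-dot-product
# --- attention as the inline code the commit comments out. A minimal, runnable
# --- reference sketch; the extra size arguments are assumed to be used only for
# --- shape checks in the real helper:
import math
import torch
import torch.nn as nn

def sdp_reference(query, key, value, attention_mask,
                  bsz, q_len, kv_seq_len, head_dim, num_heads):
    # query: (bsz, num_heads, q_len, head_dim); key/value: (bsz, num_heads, kv_seq_len, head_dim)
    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
    if attention_mask is not None:
        # additive mask, broadcastable to (bsz, 1, q_len, kv_seq_len)
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                         dtype=torch.float32).to(query.dtype)
    attn_output = torch.matmul(attn_weights, value)  # (bsz, num_heads, q_len, head_dim)
    return attn_output, attn_weights
# --- End of editorial aside.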
invalidInputError(False, From cc4cc9d57949064b77ab897b2653d41419f18bf0 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Thu, 21 Dec 2023 11:04:27 +0800 Subject: [PATCH 15/26] a temp version --- .../llm/src/bigdl/llm/transformers/convert.py | 2 +- .../bigdl/llm/transformers/models/llama.py | 402 +----------------- 2 files changed, 5 insertions(+), 399 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index d8512efc12a..e01441b6a04 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -375,7 +375,7 @@ def _optimize_post(model, lightweight_bmm=False): from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31 from bigdl.llm.transformers.models.llama import llama_rms_norm_forward from bigdl.llm.transformers.models.llama import llama_mlp_forward - from bigdl.llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31, llama_attention_forward_4_31_so_sb + from bigdl.llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31 from bigdl.llm.transformers.models.llama import llama_model_selective_batching_forward_4_31 from transformers.modeling_utils import PreTrainedModel diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index 941faac8768..a373c73a1cc 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -311,7 +311,7 @@ def llama_attention_forward_4_31( return attn_output.to(original_dtype), attn_weights, past_key_value -def llama_attention_forward_4_31_so_sb( +def llama_attention_selective_batching_forward_4_31( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, @@ -326,7 +326,7 @@ def llama_attention_forward_4_31_so_sb( device = hidden_states.device # for flash attention original_dtype = hidden_states.dtype - # TODO: consider this later + # TODO: consider this later - flash attention # if not self.training and not hidden_states.requires_grad: # fsdp_flag = check_flash_attention_available(hidden_states) # else: @@ -338,6 +338,7 @@ def llama_attention_forward_4_31_so_sb( attention_dtype = original_dtype + # TODO: decoding fast path # use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) # enough_kv_room = is_enough_kv_cache_room(past_key_value[0]) # is_q4_0 = self.q_proj.qtype == SYM_INT4 @@ -385,12 +386,7 @@ def llama_attention_forward_4_31_so_sb( if past_key_value is not None: kv_seq_len += max(kv_pair[0].shape[-2] for kv_pair in past_key_value) - # if use_fuse_rope: - # query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, - # key_states, - # position_ids, - # "llama") - # else: + # TODO: fuse_rope cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids, "llama") @@ -421,23 +417,6 @@ def llama_attention_forward_4_31_so_sb( current_kv_len, self.head_dim, self.num_heads) - # attn_weights = torch.matmul(query_states[batch: batch + 1, :, :, :], current_key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - # if attn_weights.size() != (1, self.num_heads, 1, current_kv_len): - # invalidInputError(False, - # f"Attention weights should be of size {(1, self.num_heads, 1, current_kv_len)}, but is" - # f" {attn_weights.size()}" - # ) - # if attention_mask is not None: - # if 
attention_mask[batch].size() != (1, 1, q_len, current_kv_len): - # invalidInputError(False, - # f"Attention mask should be of size {(1, 1, q_len, current_kv_len)}, but is {attention_mask.size()}" - # ) - # # print(f"added attention_mask") - # attn_weights = attn_weights + attention_mask[batch] - # attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - # (1, num_heads, 1, kv_len + 1) x (1, num_heads, kv_len + 1, head_dim) - # (1, num_heads, 1, head_dim) - # attn_output = torch.matmul(attn_weights, current_value_states) if attn_output.size() != (1, self.num_heads, 1, self.head_dim): invalidInputError(False, f"`attn_output` should be of size {(1, self.num_heads, 1, self.head_dim)}, but is" @@ -457,40 +436,6 @@ def llama_attention_forward_4_31_so_sb( attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output, None, updated_past_key_values - # TODO: apply later - # reuse k, v, self_attention - # cache_k = past_key_value[0] - # cache_v = past_key_value[1] - # if not enough_kv_room: - # # allocate new - # new_cache_k, new_cache_v = extend_kv_cache(bsz, - # self.num_key_value_heads, # Support GQA - # self.head_dim, - # cache_k.size(2), - # kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, - # dtype=cache_k.dtype, - # device=device) - # new_cache_k[:] = cache_k - # new_cache_v[:] = cache_v - # cache_k = new_cache_k - # cache_v = new_cache_v - - # key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states) - - # elif use_cache: - # # Must be prefill - # max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH - # new_key_states, new_value_states = init_kv_cache(bsz, - # self.num_key_value_heads, - # self.head_dim, - # kv_seq_len, - # max_cache_length, - # dtype=key_states.dtype, - # device=device) - # new_key_states[:] = key_states - # new_value_states[:] = value_states - # key_states = new_key_states - # value_states = new_value_states # TODO: Assume always use_cache print(f"prefill with batch size {bsz}") @@ -511,22 +456,6 @@ def llama_attention_forward_4_31_so_sb( kv_seq_len, self.head_dim, self.num_heads) - # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - # if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - # invalidInputError(False, - # f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - # f" {attn_weights.size()}" - # ) - # # Apply attention_mask in four dimensions - # if attention_mask is not None: - # if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - # invalidInputError(False, - # f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - # ) - # attn_weights = attn_weights + attention_mask - # # upcast attention to fp32 - # attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - # attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): invalidInputError(False, @@ -537,329 +466,6 @@ def llama_attention_forward_4_31_so_sb( attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) - - # if fsdp_flag and q_len > 1: - # # now only use flash attention for first token - # attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype), - # key_states, - # value_states, - # is_causal=True) - # attn_weights = None - # elif 
use_esimd_sdp(q_len, self.head_dim, query_states): - # import linear_fp16_esimd - # attn_output = linear_fp16_esimd.sdp_forward(query_states, - # key_states.contiguous(), - # value_states.contiguous()) - # attn_output = attn_output.view(query_states.shape) - # attn_weights = None - # else: - # # otherwise, use native attention - # attn_output, attn_weights = native_sdp(query_states, key_states, value_states, - # attention_mask, - # bsz, q_len, kv_seq_len, - # self.head_dim, self.num_heads) - - # attn_output_size = (bsz, self.num_heads, q_len, self.head_dim) - # if attn_output.size() != attn_output_size: - # invalidInputError(False, - # f"`attn_output` should be of size {attn_output_size}," - # f" but is {attn_output.size()}") - - # attn_output = attn_output.transpose(1, 2).contiguous() - # attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - # if self.config.pretraining_tp > 1: - # attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - # o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, - # dim=1) - # attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) - # for i in range(self.config.pretraining_tp)]) - # else: - # attn_output = self.o_proj(attn_output) - - # if not output_attentions: - # attn_weights = None - - return attn_output.to(original_dtype), attn_weights, updated_past_key_values - - -# Return attn_output, attn_weights -def calculate_xpu_sdp(fsdp_flag, bsz, q_len, head_dim, num_heads, q_states, k_states, v_states, current_kv_length, attention_mask): - if fsdp_flag and q_len > 1: - # TODO: delete - print("F.scaled_dot_product_attention") - attn_output = F.scaled_dot_product_attention(q_states, - k_states, - v_states, - is_causal=True) - return attn_output, None - elif use_esimd_sdp(q_len, head_dim, q_states): - import linear_fp16_esimd - # TODO: delete - print("linear_fp16_esimd") - attn_output = linear_fp16_esimd.sdp_forward(q_states, - k_states.contiguous(), - v_states.contiguous()) - attn_output = attn_output.view(q_states.shape) - return attn_output, None - else: - # TODO: delete - #print("native sdp") - attn_output, attn_weights = native_sdp(q_states, - k_states, - v_states, - attention_mask, - bsz, - q_len, - current_kv_length, - head_dim, - num_heads) - return attn_output, attn_weights - -# dev-note: this differs to original llama model in the aspect -# that it has different format for past_key_value -def llama_attention_selective_batching_forward_4_31( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - padding_mask: Optional[torch.LongTensor] = None, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - device = hidden_states.device - # for flash attention - original_dtype = hidden_states.dtype - if not self.training and not hidden_states.requires_grad: - fsdp_flag = check_flash_attention_available(hidden_states) - else: - fsdp_flag = False - if fsdp_flag and q_len > 1: - attention_dtype = torch.float16 # use fp16 for flash attention - else: - attention_dtype = original_dtype - - # dev: new changes from yang's pr - use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) - # We have unusual past_key_value format we have only one batch - enough_kv_room = past_key_value is not None 
and is_enough_kv_cache_room(past_key_value[0]) - is_q4_0 = self.q_proj.qtype == SYM_INT4 - no_tp = not self.config.pretraining_tp > 1 - decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and - enough_kv_room and bsz * q_len == 1) - - # we put the new kv_caches into this - updated_past_key_values = [] - # TODO: change this later - decoding_fast_path = False - if decoding_fast_path: - # This decoding fast path? - # 1, 1, hidden_dimension - # TODO: delete - #print("Debug: Decoding fast path") - hidden_states = hidden_states.view(1, -1) - kv_seq_len = past_key_value[0][0].shape[-2] - cache_k = past_key_value[0][0] - cache_v = past_key_value[0][1] - import linear_q4_0 - query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states, - self.q_proj.weight, - self.k_proj.weight, - self.v_proj.weight, - position_ids, - cache_k, cache_v, - self.q_proj.weight.qtype, - kv_seq_len, - self.head_dim) - kv_seq_len += 1 - # Append kv_cache - updated_past_key_values.append((key_states, value_states)) - else: - if self.config.pretraining_tp > 1: - # TODO: implement the case with pretraining_tp - raise ValueError("selective batching: we have not implemented feature with" - " pretraining_tp > 1") - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, - self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, - self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, - self.num_key_value_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - # Decoding - # Apply rotary_embeddings - max_kv_length = max(kv_pair[0].shape[-2] for kv_pair in past_key_value) - max_kv_length += key_states.shape[-2] - - # TODO: decide if we need to use_fuse_rope - cos, sin = self.rotary_emb(value_states, seq_len=max_kv_length) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, - position_ids, "llama") - # End of applying rotary_embedding - batched_attention_output = [] - for batch in range(bsz): - # 2. concat key_states, value_states - # Get current key_states, value_states from previous cache - past_k, past_v = past_key_value[batch] - # Should be len + 1 - current_kv_len = past_k.shape[-2] + 1 - if past_k.stride()[1] <= past_k.size(2) * past_k.size(3): - new_past_k, new_past_v = extend_kv_cache(1, - self.num_key_value_heads, - self.head_dim, - past_k.size(2), - current_kv_len + - KV_CACHE_ALLOC_BLOCK_LENGTH, - dtype=past_k.dtype, - device=device) - new_past_k[:] = past_k - new_past_v[:] = past_v - past_k = new_past_k - past_v = new_past_v - current_key_states, current_value_states = append_kv_cache(past_k, past_v, # noqa - key_states[batch:batch + 1, :, :, :], # noqa - value_states[batch:batch + 1, :, :, :] # noqa - ) - # 2. concat key_states, value_states ends - - # 3. Record key_states, and value_states - updated_past_key_values.append((current_key_states, current_value_states)) - # 3. Record key_states, value_states end - - # 4. repeat kv - # repeat k/v heads if n_kv_heads < n_heads - current_key_states = repeat_kv(current_key_states, self.num_key_value_groups) - current_value_states = repeat_kv(current_value_states, self.num_key_value_groups) - # 4. repeat kv ends - - # 5. 
Attention calculation - # TODO: fix attention weight - attn_output, attn_weights = calculate_xpu_sdp(fsdp_flag, - 1, - 1, - self.head_dim, - self.num_heads, - query_states[batch:batch + 1, :, :, :], - current_key_states, - current_value_states, - current_kv_len, - attention_mask[batch]) - batched_attention_output.append(attn_output) - # 5. Attention calculation ends - # Concat attention output together - # (1, num_heads, 1, head_dim) - attn_output = torch.concat(batched_attention_output, dim=0) - batched_attention_output.clear() - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( # noqa - "`attn_output` should be of size " - f"{repr((bsz, self.num_heads, q_len, self.head_dim))}, but is " - f"{repr(attn_output.size())}" - ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - # Apply rotary_embedding first in a batch - - # TODO: this output_attentions is not correct, we are not concat attention weight - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, updated_past_key_values if use_cache else None - - # past_key_values is None - - # TODO: we assume this is prefill stage, but this may not - kv_seq_len = key_states.shape[-2] - - # TODO: we do not know if this have the same effect - if use_fuse_rope: - query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, - key_states, - position_ids, - "llama") - else: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, position_ids, "llama") - - # past_key_value is None - if use_cache: - # past_key_value = [] - max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH - new_key_states, new_value_states = init_kv_cache(bsz, - self.num_key_value_heads, - self.head_dim, - kv_seq_len, - max_cache_length, - dtype=key_states.dtype, - device=device) - new_key_states[:] = key_states - new_value_states[:] = value_states - key_states = new_key_states - value_states = new_value_states - for batch in range(bsz): - updated_past_key_values.append((key_states[batch: batch + 1, :, :, :], - value_states[batch: batch+1, :, :, :])) - # past_key_value = (key_states, value_states) if use_cache else None - else: - # past_key_value = None - updated_past_key_values = None - - # repeat k/v heads if n_kv_heads < n_heads - - # if not decoding_fast_path: - # print(f"Prefill with batching size {bsz}") - - key_states = repeat_kv(key_states, self.num_key_value_groups).to(device, - dtype=attention_dtype) - value_states = repeat_kv(value_states, self.num_key_value_groups).to(device, - dtype=attention_dtype) - #print("q_len is " + str(q_len)) - #print("kv_seq_len is " + str(kv_seq_len)) - - #print(query_states[0][0][0][0]) - #print(key_states[0][0][0][0]) - #print(value_states[0][0][0][0]) - # TODO: this might be wrong for decoding - # TODO: for decoding we do not want to apply attention_mask - attn_output, attn_weights = calculate_xpu_sdp(fsdp_flag, - bsz, - q_len, - self.head_dim, - self.num_heads, - query_states, - key_states, - value_states, - kv_seq_len, - attention_mask) - #print(attn_output[0][0][0][0]) - attn_output_size = (bsz, self.num_heads, q_len, self.head_dim) - if attn_output.size() != attn_output_size: - #print("we are at here 1") - invalidInputError(False, - f"`attn_output` should be of size {attn_output_size}," - f" but is {attn_output.size()}") - - attn_output = 
attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - if self.config.pretraining_tp > 1: - raise ValueError("Not Implemented") - else: - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - return attn_output.to(original_dtype), attn_weights, updated_past_key_values From 888eb703e5b599fc3306ba1db020e18a78589d9e Mon Sep 17 00:00:00 2001 From: gc-fu Date: Thu, 21 Dec 2023 11:08:12 +0800 Subject: [PATCH 16/26] Solid version: remove redundant functions --- python/llm/src/bigdl/llm/transformers/convert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index e01441b6a04..6b1ccec874e 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -390,7 +390,7 @@ def _optimize_post(model, lightweight_bmm=False): convert_forward( model, transformers.models.llama.modeling_llama.LlamaAttention, - llama_attention_forward_4_31_so_sb,) + llama_attention_selective_batching_forward_4_31,) convert_forward( model, transformers.models.llama.modeling_llama.LlamaRMSNorm, From 39a8a114eca7319a6ee5af781125bf2cd73d86b6 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Thu, 21 Dec 2023 11:25:34 +0800 Subject: [PATCH 17/26] fix format --- .../bigdl/llm/transformers/models/llama.py | 39 ++++++++++--------- .../vllm/model_executor/models/bigdl_llama.py | 5 +-- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index a373c73a1cc..860b970d9f0 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -344,7 +344,7 @@ def llama_attention_selective_batching_forward_4_31( # is_q4_0 = self.q_proj.qtype == SYM_INT4 # no_tp = not self.config.pretraining_tp > 1 # decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and - # enough_kv_room and bsz * q_len == 1) + # enough_kv_room and bsz * q_len == 1) # single batch decoding fast path # forward_qkv takes will perform QKV projection, rotary position embedding @@ -376,11 +376,11 @@ def llama_attention_selective_batching_forward_4_31( value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, - self.num_heads, self.head_dim).transpose(1, 2) + self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, - self.num_key_value_heads, self.head_dim).transpose(1, 2) + self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, - self.num_key_value_heads, self.head_dim).transpose(1, 2) + self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: @@ -389,7 +389,7 @@ def llama_attention_selective_batching_forward_4_31( # TODO: fuse_rope cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, position_ids, "llama") + cos, sin, position_ids, "llama") updated_past_key_values = [] if past_key_value is not None: @@ -399,8 +399,10 @@ def llama_attention_selective_batching_forward_4_31( past_k, past_v = past_key_value[batch] current_kv_len = past_k.shape[-2] + 1 - current_key_states = torch.cat([past_k, key_states[batch: batch + 1, : , :, :]], dim=2) - current_value_states = torch.cat([past_v, 
value_states[batch: batch + 1, :, :, :]], dim=2) + current_key_states = torch.cat([past_k, + key_states[batch: batch + 1, :, :, :]], dim=2) + current_value_states = torch.cat([past_v, + value_states[batch: batch + 1, :, :, :]], dim=2) updated_past_key_values.append((current_key_states, current_value_states)) @@ -419,9 +421,9 @@ def llama_attention_selective_batching_forward_4_31( self.num_heads) if attn_output.size() != (1, self.num_heads, 1, self.head_dim): invalidInputError(False, - f"`attn_output` should be of size {(1, self.num_heads, 1, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + f"`attn_output` should be of size " + f"{(1, self.num_heads, 1, self.head_dim)}, but is" + f" {attn_output.size()}") batched_attention_output.append(attn_output) # For loop ends # TODO: handle attention_weights later @@ -429,9 +431,9 @@ def llama_attention_selective_batching_forward_4_31( batched_attention_output.clear() if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): invalidInputError(False, - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + f"`attn_output` should be of size " + f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}") attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) @@ -440,7 +442,8 @@ def llama_attention_selective_batching_forward_4_31( # TODO: Assume always use_cache print(f"prefill with batch size {bsz}") for batch in range(bsz): - updated_past_key_values.append((key_states[batch: batch + 1, :, :, :], value_states[batch: batch+1, :, :, :])) + updated_past_key_values.append((key_states[batch: batch + 1, :, :, :], + value_states[batch: batch+1, :, :, :])) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups).to(device, @@ -459,9 +462,9 @@ def llama_attention_selective_batching_forward_4_31( if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): invalidInputError(False, - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + f"`attn_output` should be of size " + f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}") attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -637,7 +640,7 @@ def llama_model_selective_batching_forward_4_31( attn_mask, (1, seq_length), inputs_embeds, past_key_value_length ) attention_mask[i] = new_mask - i+=1 + i += 1 hidden_states = inputs_embeds diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py index 949ae3c05a2..b0751e55430 100644 --- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py +++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py @@ -182,7 +182,6 @@ def forward( # TODO: this could be deleted after prefill stage is also selective_batched decoding_attention_mask_list = [] decoding_position_ids = [] - # Attention_mask for decoding could also be a list of tensors due to inconsistent length of kv_cache # num_layers x len(seq_id) x (2 x torch.Tensor) if is_decoding_stage: batch = 0 @@ -198,12 +197,12 @@ def forward( cur_attention_mask = [0] * (current_seq_len - cur_pos + 1) + [1] * (cur_pos) decoding_attention_mask_list.append(cur_attention_mask) - 
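# --- Editorial aside (not part of the patch): a worked example of the per-sequence
# --- decode-time mask and position id built in the loop above, assuming
# --- `current_seq_len` is the length of that sequence's cached keys/values and
# --- `cur_pos` is the token count reported by `seq_data.get_len()`:
def decode_mask_and_position(current_seq_len: int, cur_pos: int):
    # Mask covers the cached tokens plus the one token being decoded
    # (current_seq_len + 1 entries): zeros over unused slots, ones over real tokens.
    mask = [0] * (current_seq_len - cur_pos + 1) + [1] * cur_pos
    position_id = cur_pos - 1  # position of the token being decoded
    return mask, position_id

# e.g. current_seq_len=10, cur_pos=8 -> mask = [0, 0, 0] + [1] * 8 (11 entries), position_id = 7
# --- End of editorial aside.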
bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device) # TODO: prefill requests could also be sbed, so that we can remove attention_mask forever if is_decoding_stage: - attention_mask = [torch.tensor(x, device=self.device).unsqueeze(0) for x in decoding_attention_mask_list] + attention_mask = [torch.tensor(x, device=self.device).unsqueeze(0) + for x in decoding_attention_mask_list] position_ids = torch.tensor(decoding_position_ids).long().unsqueeze(-1) kwargs = { "input_ids": bigdl_input_ids, From 861c072b5d1e1e8548e8bee028cc901cf71800f3 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Thu, 21 Dec 2023 13:29:46 +0800 Subject: [PATCH 18/26] format --- .../llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py index b0751e55430..cda6218f13c 100644 --- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py +++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py @@ -206,7 +206,6 @@ def forward( position_ids = torch.tensor(decoding_position_ids).long().unsqueeze(-1) kwargs = { "input_ids": bigdl_input_ids, - # gc(co): we rely on underlying model to generate position_ids "position_ids": position_ids, "attention_mask": attention_mask, "past_key_values": bigdl_kv_cache, From 29b2e606b3f8f61f0bbe8cdc6ad7c75f41018e66 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Thu, 21 Dec 2023 14:05:53 +0800 Subject: [PATCH 19/26] solid: add option to enable selective_batching --- .../llm/src/bigdl/llm/transformers/convert.py | 27 +++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 6b1ccec874e..3db80682f16 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -46,6 +46,7 @@ from .utils import logger from typing import Union import numpy as np +import os from bigdl.llm.utils.common import invalidInputError @@ -375,8 +376,6 @@ def _optimize_post(model, lightweight_bmm=False): from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31 from bigdl.llm.transformers.models.llama import llama_rms_norm_forward from bigdl.llm.transformers.models.llama import llama_mlp_forward - from bigdl.llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31 - from bigdl.llm.transformers.models.llama import llama_model_selective_batching_forward_4_31 from transformers.modeling_utils import PreTrainedModel # All huggingface format models are inherited from `PreTrainedModel` @@ -385,12 +384,16 @@ def _optimize_post(model, lightweight_bmm=False): "supported for further optimizations") return model + enable_vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING") + enable_vllm_selective_batching = True if enable_vllm_selective_batching is not None \ + and enable_vllm_selective_batching.lower()=="true" \ + else False trans_version = transformers.__version__ if version.parse(trans_version) >= version.parse("4.31.0"): convert_forward( model, transformers.models.llama.modeling_llama.LlamaAttention, - llama_attention_selective_batching_forward_4_31,) + llama_attention_forward_4_31,) convert_forward( model, transformers.models.llama.modeling_llama.LlamaRMSNorm, @@ -398,11 +401,19 @@ def _optimize_post(model, lightweight_bmm=False): convert_forward(model, 
transformers.models.llama.modeling_llama.LlamaMLP, llama_mlp_forward) - convert_forward( - model, - transformers.models.llama.modeling_llama.LlamaModel, - llama_model_selective_batching_forward_4_31, - ) + if enable_vllm_selective_batching: + from bigdl.llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31 + from bigdl.llm.transformers.models.llama import llama_model_selective_batching_forward_4_31 + convert_forward( + model, + transformers.models.llama.modeling_llama.LlamaModel, + llama_model_selective_batching_forward_4_31, + ) + convert_forward( + model, + transformers.models.llama.modeling_llama.LlamaAttention, + llama_attention_selective_batching_forward_4_31, + ) else: # todo implement 4.28.0 ~ 4.30.2 pass From fa685429480377a38ab3171091a12b941e16bb86 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Thu, 21 Dec 2023 14:12:59 +0800 Subject: [PATCH 20/26] remove logic for using transformer models --- .../vllm/model_executor/models/bigdl_llama.py | 106 +++++++++--------- 1 file changed, 53 insertions(+), 53 deletions(-) diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py index cda6218f13c..a02e2552c52 100644 --- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py +++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py @@ -63,59 +63,59 @@ def __init__( super().__init__(config, device, max_model_len) self.config = config # TODO(gc): later change this to a switch? - use_bigdl_lowbit = os.getenv("VLLM_USE_BIGDL_LOWBIT", "") - if use_bigdl_lowbit.lower() == "true": - from bigdl.llm.transformers import AutoModelForCausalLM - from bigdl.llm import optimize_model - if device == 'cpu': - model = AutoModelForCausalLM.from_pretrained( - config._name_or_path, - low_cpu_mem_usage=True, - trust_remote_code=True, - use_cache=True, - ) - self.model = optimize_model(model) - self.sampler = BigDLSampler(config.vocab_size, device) - elif device == 'xpu': - try: - import intel_extension_for_pytorch as ipex - except ImportError: - print("Intel Extension for PyTorch is not installed, \ - but is required for xpu inference.") - - low_bit = 'sym_int4' - model = AutoModelForCausalLM.from_pretrained( - config._name_or_path, - load_in_low_bit=low_bit, - trust_remote_code=True, - use_cache=True, - ) - self.model = model.to('xpu') - self.sampler = BigDLSampler(config.vocab_size, device).to('xpu') - else: - from transformers import AutoModelForCausalLM - if device == 'cpu': - self.model = AutoModelForCausalLM.from_pretrained( - config._name_or_path, - low_cpu_mem_usage=True, - trust_remote_code=True, - use_cache=True, - ) - self.sampler = BigDLSampler(config.vocab_size, device) - elif device == 'xpu': - try: - import intel_extension_for_pytorch as ipex - except ImportError: - print("Intel Extension for PyTorch is not installed, \ - but is required for xpu inference.") - model = AutoModelForCausalLM.from_pretrained( - config._name_or_path, - trust_remote_code=True, - use_cache=True, - ) - self.model = model.to('xpu') - self.sampler = BigDLSampler( - config.vocab_size, device).to('xpu') + # use_bigdl_lowbit = os.getenv("VLLM_USE_BIGDL_LOWBIT", "") + # if use_bigdl_lowbit.lower() == "true": + from bigdl.llm.transformers import AutoModelForCausalLM + from bigdl.llm import optimize_model + if device == 'cpu': + model = AutoModelForCausalLM.from_pretrained( + config._name_or_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + use_cache=True, + ) + self.model = 
optimize_model(model) + self.sampler = BigDLSampler(config.vocab_size, device) + elif device == 'xpu': + try: + import intel_extension_for_pytorch as ipex + except ImportError: + print("Intel Extension for PyTorch is not installed, \ + but is required for xpu inference.") + + low_bit = 'sym_int4' + model = AutoModelForCausalLM.from_pretrained( + config._name_or_path, + load_in_low_bit=low_bit, + trust_remote_code=True, + use_cache=True, + ) + self.model = model.to('xpu') + self.sampler = BigDLSampler(config.vocab_size, device).to('xpu') + # else: + # from transformers import AutoModelForCausalLM + # if device == 'cpu': + # self.model = AutoModelForCausalLM.from_pretrained( + # config._name_or_path, + # low_cpu_mem_usage=True, + # trust_remote_code=True, + # use_cache=True, + # ) + # self.sampler = BigDLSampler(config.vocab_size, device) + # elif device == 'xpu': + # try: + # import intel_extension_for_pytorch as ipex + # except ImportError: + # print("Intel Extension for PyTorch is not installed, \ + # but is required for xpu inference.") + # model = AutoModelForCausalLM.from_pretrained( + # config._name_or_path, + # trust_remote_code=True, + # use_cache=True, + # ) + # self.model = model.to('xpu') + # self.sampler = BigDLSampler( + # config.vocab_size, device).to('xpu') self.device = torch.device(device) self.dtype = self.model.dtype From b7d300ae648c3c5f2d01bb4a0c654b42286e0e5a Mon Sep 17 00:00:00 2001 From: gc-fu Date: Thu, 21 Dec 2023 14:23:02 +0800 Subject: [PATCH 21/26] format --- python/llm/src/bigdl/llm/transformers/convert.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 3db80682f16..b3a480a9b27 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -374,6 +374,8 @@ def convert_forward(m, target_m, new_forward): def _optimize_post(model, lightweight_bmm=False): from packaging import version from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31 + from bigdl.llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31 + from bigdl.llm.transformers.models.llama import llama_model_selective_batching_forward_4_31 from bigdl.llm.transformers.models.llama import llama_rms_norm_forward from bigdl.llm.transformers.models.llama import llama_mlp_forward from transformers.modeling_utils import PreTrainedModel @@ -384,10 +386,10 @@ def _optimize_post(model, lightweight_bmm=False): "supported for further optimizations") return model - enable_vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING") - enable_vllm_selective_batching = True if enable_vllm_selective_batching is not None \ - and enable_vllm_selective_batching.lower()=="true" \ - else False + vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING") + enable_vllm_se_batching = vllm_selective_batching is not None + enable_vllm_se_batching = enable_vllm_se_batching and vllm_selective_batching.lower() == "true" + trans_version = transformers.__version__ if version.parse(trans_version) >= version.parse("4.31.0"): convert_forward( @@ -401,9 +403,8 @@ def _optimize_post(model, lightweight_bmm=False): convert_forward(model, transformers.models.llama.modeling_llama.LlamaMLP, llama_mlp_forward) - if enable_vllm_selective_batching: - from bigdl.llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31 - from bigdl.llm.transformers.models.llama 
import llama_model_selective_batching_forward_4_31 + if enable_vllm_se_batching: + convert_forward( model, transformers.models.llama.modeling_llama.LlamaModel, From 2f4abf1b8e6cb5a4e3ec2ef4b78f7c02470b2b2e Mon Sep 17 00:00:00 2001 From: gc-fu Date: Thu, 21 Dec 2023 14:31:09 +0800 Subject: [PATCH 22/26] format --- python/llm/src/bigdl/llm/transformers/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index 860b970d9f0..27e35b5f064 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -440,7 +440,7 @@ def llama_attention_selective_batching_forward_4_31( return attn_output, None, updated_past_key_values # TODO: Assume always use_cache - print(f"prefill with batch size {bsz}") + # print(f"prefill with batch size {bsz}") for batch in range(bsz): updated_past_key_values.append((key_states[batch: batch + 1, :, :, :], value_states[batch: batch+1, :, :, :])) @@ -627,7 +627,7 @@ def llama_model_selective_batching_forward_4_31( # TODO: only generate attention_mask for prefilling if attention_mask is None: invalidInputError(False, "attention_mask should never be None") - print(f"attention_mask before expanding: {attention_mask}") + # print(f"attention_mask before expanding: {attention_mask}") if past_key_values is None: attention_mask = self._prepare_decoder_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length From 3e908f793cef1fbe615f41788d78916a00455a23 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Thu, 21 Dec 2023 15:44:06 +0800 Subject: [PATCH 23/26] solid: enable argument VLLM_ENABLE_SELECTIVE_BATCHING --- .../llm/src/bigdl/llm/transformers/convert.py | 1 - .../vllm/model_executor/models/bigdl_llama.py | 92 +++++++++---------- .../vllm/model_executor/models/bigdl_model.py | 28 +++++- 3 files changed, 69 insertions(+), 52 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index b3a480a9b27..484f9bd2f6f 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -404,7 +404,6 @@ def _optimize_post(model, lightweight_bmm=False): transformers.models.llama.modeling_llama.LlamaMLP, llama_mlp_forward) if enable_vllm_se_batching: - convert_forward( model, transformers.models.llama.modeling_llama.LlamaModel, diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py index a02e2552c52..727ee8c71e3 100644 --- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py +++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py @@ -51,6 +51,10 @@ def _get_attention_mask_for_prompts( ] return attention_mask +vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING") +enable_vllm_se_batching = vllm_selective_batching is not None +enable_vllm_se_batching = enable_vllm_se_batching and vllm_selective_batching.lower() == "true" + class BigDLLlamaForCausalLM(BigDLModelForCausalLM): @@ -62,9 +66,7 @@ def __init__( ): super().__init__(config, device, max_model_len) self.config = config - # TODO(gc): later change this to a switch? 
- # use_bigdl_lowbit = os.getenv("VLLM_USE_BIGDL_LOWBIT", "") - # if use_bigdl_lowbit.lower() == "true": + # Always enable bigdl-llm model from bigdl.llm.transformers import AutoModelForCausalLM from bigdl.llm import optimize_model if device == 'cpu': @@ -92,30 +94,6 @@ def __init__( ) self.model = model.to('xpu') self.sampler = BigDLSampler(config.vocab_size, device).to('xpu') - # else: - # from transformers import AutoModelForCausalLM - # if device == 'cpu': - # self.model = AutoModelForCausalLM.from_pretrained( - # config._name_or_path, - # low_cpu_mem_usage=True, - # trust_remote_code=True, - # use_cache=True, - # ) - # self.sampler = BigDLSampler(config.vocab_size, device) - # elif device == 'xpu': - # try: - # import intel_extension_for_pytorch as ipex - # except ImportError: - # print("Intel Extension for PyTorch is not installed, \ - # but is required for xpu inference.") - # model = AutoModelForCausalLM.from_pretrained( - # config._name_or_path, - # trust_remote_code=True, - # use_cache=True, - # ) - # self.model = model.to('xpu') - # self.sampler = BigDLSampler( - # config.vocab_size, device).to('xpu') self.device = torch.device(device) self.dtype = self.model.dtype @@ -170,8 +148,12 @@ def forward( # 1. Assemble bigdl_input_ids end if is_decoding_stage: - bigdl_kv_cache = self.prepare_kv_cache_llama(cur_seq_ids, - kv_cache, num_layers) + construct_kv_cache_func = self.get_construct_kv_cache_func(enable_vllm_se_batching) + bigdl_kv_cache = construct_kv_cache_func(cur_seq_ids, + seq_group_meta_data_lists, + kv_cache, + num_layers, + 2) else: bigdl_attention_mask = _get_attention_mask_for_prompts(bigdl_input_ids, max_prompt_len) bigdl_input_ids = [ @@ -179,30 +161,44 @@ def forward( for input_ids in bigdl_input_ids ] - # TODO: this could be deleted after prefill stage is also selective_batched decoding_attention_mask_list = [] decoding_position_ids = [] # num_layers x len(seq_id) x (2 x torch.Tensor) if is_decoding_stage: - batch = 0 - for seq_group_meta_data in seq_group_meta_data_lists: - # Get current seq_len in kv_cache - current_seq_len = bigdl_kv_cache[0][batch][0].size(2) - batch += 1 - seq_ids = list(seq_group_meta_data.seq_data.keys()) - seq_data = seq_group_meta_data.seq_data[seq_ids[0]] - cur_pos = seq_data.get_len() - decoding_position_ids.append(cur_pos - 1) - # Total length: current_seq_len + 1 - cur_attention_mask = [0] * (current_seq_len - cur_pos + 1) + [1] * (cur_pos) - decoding_attention_mask_list.append(cur_attention_mask) + if enable_vllm_se_batching: + batch = 0 + for seq_group_meta_data in seq_group_meta_data_lists: + # Get current seq_len in kv_cache + current_seq_len = bigdl_kv_cache[0][batch][0].size(2) + batch += 1 + seq_ids = list(seq_group_meta_data.seq_data.keys()) + seq_data = seq_group_meta_data.seq_data[seq_ids[0]] + cur_pos = seq_data.get_len() + decoding_position_ids.append(cur_pos - 1) + # Total length: current_seq_len + 1 + cur_attention_mask = [0] * (current_seq_len - cur_pos + 1) + [1] * (cur_pos) + decoding_attention_mask_list.append(cur_attention_mask) + else: + cur_seq_len = bigdl_kv_cache[0][0].size(2) + for seq_group_meta_data in seq_group_meta_data_lists: + seq_ids = list(seq_group_meta_data.seq_data.keys()) + seq_id = seq_ids[0] + seq_data = seq_group_meta_data.seq_data[seq_id] + cur_pos = seq_data.get_len() + # bigdl_position_ids.append([cur_pos - 1]) + decoding_position_ids.append(cur_pos - 1) + cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos) + decoding_attention_mask_list.append(cur_attention_mask) 
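# A minimal, self-contained sketch (not part of the diff above) of the decode-stage
# attention-mask layout that the added branches construct. The helper name
# `build_decode_mask` and the literal lengths are illustrative assumptions; only the
# zero/one layout mirrors the patch: zeros left-pad unused cache slots, ones cover the
# `cur_pos` real tokens, and the total length is the cached length plus one slot for
# the token fed in this decoding step.
def build_decode_mask(cached_len: int, cur_pos: int) -> list:
    # cached_len: size(2) of the sequence's KV cache; cur_pos: seq_data.get_len()
    return [0] * (cached_len - cur_pos + 1) + [1] * cur_pos

# Example: cache padded to 7 slots, sequence currently 5 tokens long ->
# [0, 0, 0, 1, 1, 1, 1, 1], i.e. 8 = cached_len + 1 positions.
assert build_decode_mask(cached_len=7, cur_pos=5) == [0, 0, 0, 1, 1, 1, 1, 1]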
bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device) # TODO: prefill requests could also be sbed, so that we can remove attention_mask forever if is_decoding_stage: - attention_mask = [torch.tensor(x, device=self.device).unsqueeze(0) - for x in decoding_attention_mask_list] + if enable_vllm_se_batching: + attention_mask = [torch.tensor(x, device=self.device).unsqueeze(0) + for x in decoding_attention_mask_list] + else: + attention_mask = torch.tensor(decoding_attention_mask_list, device=self.device) position_ids = torch.tensor(decoding_position_ids).long().unsqueeze(-1) kwargs = { "input_ids": bigdl_input_ids, @@ -247,8 +243,12 @@ def forward( # tmp = torch.xpu.memory_stats() # logger.info(f"before: {tmp['allocated_bytes.all.current']}") - self.update_kv_cache(cur_seq_ids, - kv_cache, num_layers, decoder_kv_size) + if enable_vllm_se_batching: + self.update_kv_cache_selective_batching( + cur_seq_ids, kv_cache, num_layers, decoder_kv_size) + self.last_kv_cache = None + else: + self.update_kv_cache(cur_seq_ids, kv_cache, num_layers, decoder_kv_size) # tmp = torch.xpu.memory_stats() # logger.info(f"after: {tmp['allocated_bytes.all.current']}") diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py index 658f2155db0..a81993dceae 100644 --- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py +++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py @@ -137,13 +137,21 @@ def prepare_kv_cache( return bigdl_kv_cache + def get_construct_kv_cache_func(self, enable_selective_batching): + if enable_selective_batching: + return self.prepare_kv_cache_selective_batching + else: + return self.prepare_kv_cache + # This is an implementation for models that KV Cache shape in (batch_size, num_heads, # sequence_length, embed_size_per_head). - def prepare_kv_cache_llama( + def prepare_kv_cache_selective_batching( self, cur_seq_ids: List[int], + seq_group_meta_data_lists: List[SequenceGroupMetadata], kv_cache: Dict, num_layers: int, + kv_cache_size_1: int, ): # Return bigdl_kv_cache in the format of Tuple(List[Tuple(torch.Tensor)]) bigdl_kv_cache = [] @@ -157,10 +165,6 @@ def prepare_kv_cache_llama( bigdl_kv_cache.append(temp_cache) return bigdl_kv_cache - # for i in range(len(cur_seq_ids)): - # current_kv = [] - # current_kv.append(kv_cache) - # This is an implementation for models that KV Cache shape in (batch_size, num_heads, # sequence_length, embed_size_per_head). 
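# A minimal sketch (not part of the diff above) of the flag-plus-dispatch pattern this
# patch introduces. The environment variable and method names come from the patch; the
# helper `_read_bool_env`, the `DummyModel` class, and the placeholder return strings
# are illustrative assumptions standing in for the real cache tensors.
import os

def _read_bool_env(name: str) -> bool:
    # Unset, or any value other than "true" (case-insensitive), leaves the feature off.
    value = os.getenv(name)
    return value is not None and value.lower() == "true"

class DummyModel:
    def prepare_kv_cache(self, *args):
        return "padded batch cache"

    def prepare_kv_cache_selective_batching(self, *args):
        return "per-sequence cache list"

    def get_construct_kv_cache_func(self, enable_selective_batching: bool):
        # Pick the bound method once, so forward() calls it without re-branching.
        if enable_selective_batching:
            return self.prepare_kv_cache_selective_batching
        return self.prepare_kv_cache

# Usage (an assumption, not prescribed by the patch): set
# VLLM_ENABLE_SELECTIVE_BATCHING=true before launch, since bigdl_llama.py reads the
# flag at import time.
enabled = _read_bool_env("VLLM_ENABLE_SELECTIVE_BATCHING")
builder = DummyModel().get_construct_kv_cache_func(enabled)
print(builder())  # "per-sequence cache list" when the flag is "true"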
def update_kv_cache( @@ -169,6 +173,20 @@ def update_kv_cache( kv_cache, layer: int, kv_cache_size_1: int, + ) -> None: + for i in range(layer): + for j in range(kv_cache_size_1): + batch_dim = 0 + for seq_id in cur_seq_ids: + kv_cache[i][j][seq_id] = self.last_kv_cache[i][j][batch_dim] + batch_dim = batch_dim + 1 + + def update_kv_cache_selective_batching( + self, + cur_seq_ids: List[int], + kv_cache, + layer: int, + kv_cache_size_1: int, ) -> None: for i in range(layer): for j in range(len(cur_seq_ids)): From f1216aebad99cc951d93dc884f70d06c7d0b131e Mon Sep 17 00:00:00 2001 From: gc-fu Date: Thu, 21 Dec 2023 15:50:05 +0800 Subject: [PATCH 24/26] format --- .../llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py index 727ee8c71e3..0be27eccb9d 100644 --- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py +++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py @@ -196,7 +196,7 @@ def forward( if is_decoding_stage: if enable_vllm_se_batching: attention_mask = [torch.tensor(x, device=self.device).unsqueeze(0) - for x in decoding_attention_mask_list] + for x in decoding_attention_mask_list] else: attention_mask = torch.tensor(decoding_attention_mask_list, device=self.device) position_ids = torch.tensor(decoding_position_ids).long().unsqueeze(-1) From ec4a84e1a81fc85368a9afd4e8d38746fb7d7dc0 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Fri, 22 Dec 2023 10:39:05 +0800 Subject: [PATCH 25/26] finish --- .../bigdl/llm/transformers/models/llama.py | 28 ++++++++----------- .../llm/src/bigdl/llm/vllm/core/scheduler.py | 2 -- .../vllm/model_executor/models/bigdl_llama.py | 10 +++++-- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index 27e35b5f064..682af89f4ed 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -410,15 +410,15 @@ def llama_attention_selective_batching_forward_4_31( current_value_states = repeat_kv(current_value_states, self.num_key_value_groups) current_query_states = query_states[batch: batch + 1, :, :, :] - attn_output, aattn_weights = native_sdp(current_query_states, - current_key_states, - current_value_states, - attention_mask[batch], - 1, - 1, - current_kv_len, - self.head_dim, - self.num_heads) + attn_output, attn_weights = native_sdp(current_query_states, + current_key_states, + current_value_states, + attention_mask[batch], + 1, + 1, + current_kv_len, + self.head_dim, + self.num_heads) if attn_output.size() != (1, self.num_heads, 1, self.head_dim): invalidInputError(False, f"`attn_output` should be of size " @@ -578,7 +578,7 @@ def llama_model_selective_batching_forward_4_31( "You have to specify either " "decoder_input_ids or decoder_inputs_embeds") - seq_length_with_past = seq_length + # seq_length_with_past = seq_length past_key_values_length = 0 # The original position_ids in the format of [1, 1] @@ -608,7 +608,7 @@ def llama_model_selective_batching_forward_4_31( # past_key_values in the format of num_layers x num_seqs x 2 # TODO: this may be incorrect past_key_values_length = past_key_values[0][0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length + # seq_length_with_past = seq_length_with_past + 
past_key_values_length # if position_ids is None: # device = input_ids.device if input_ids is not None else inputs_embeds.device @@ -624,7 +624,6 @@ def llama_model_selective_batching_forward_4_31( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) # embed positions - # TODO: only generate attention_mask for prefilling if attention_mask is None: invalidInputError(False, "attention_mask should never be None") # print(f"attention_mask before expanding: {attention_mask}") @@ -645,8 +644,7 @@ def llama_model_selective_batching_forward_4_31( hidden_states = inputs_embeds if self.gradient_checkpointing and self.training: - if use_cache: - use_cache = False + invalidInputError(False, "gradient_checkpointing is not supported") # decoder layers all_hidden_states = () if output_hidden_states else None @@ -678,8 +676,6 @@ def custom_forward(*inputs): else: layer_outputs = decoder_layer( hidden_states, - # TODO: decide if we need this attention_mask, - # we are not using the attention mask when decoding attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, diff --git a/python/llm/src/bigdl/llm/vllm/core/scheduler.py b/python/llm/src/bigdl/llm/vllm/core/scheduler.py index cc67638a576..b41ea166d45 100644 --- a/python/llm/src/bigdl/llm/vllm/core/scheduler.py +++ b/python/llm/src/bigdl/llm/vllm/core/scheduler.py @@ -226,8 +226,6 @@ def _schedule(self) -> SchedulerOutputs: num_batched_tokens += num_prompt_tokens num_curr_seqs += num_new_seqs scheduled.append(seq_group) - # TODO: we choose to not batching the prefill requests - # break if scheduled or ignored_seq_groups: scheduler_outputs = SchedulerOutputs( diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py index 0be27eccb9d..fa5e0a9bd59 100644 --- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py +++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py @@ -197,9 +197,10 @@ def forward( if enable_vllm_se_batching: attention_mask = [torch.tensor(x, device=self.device).unsqueeze(0) for x in decoding_attention_mask_list] + position_ids = torch.tensor(decoding_position_ids).long().unsqueeze(-1) else: attention_mask = torch.tensor(decoding_attention_mask_list, device=self.device) - position_ids = torch.tensor(decoding_position_ids).long().unsqueeze(-1) + position_ids = None kwargs = { "input_ids": bigdl_input_ids, "position_ids": position_ids, @@ -211,8 +212,11 @@ def forward( else: # Prefill stage attention_mask = torch.tensor(bigdl_attention_mask, device=self.device) - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) + if enable_vllm_se_batching: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + else: + position_ids = None kwargs = { "input_ids": bigdl_input_ids, "attention_mask": attention_mask, From a473fbcde50c54b813fe6fa701a5949008b45115 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Fri, 22 Dec 2023 10:59:33 +0800 Subject: [PATCH 26/26] format --- .../src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py index fa5e0a9bd59..eb6fa282c7f 100644 --- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py +++ 
b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py @@ -186,13 +186,12 @@ def forward( seq_data = seq_group_meta_data.seq_data[seq_id] cur_pos = seq_data.get_len() # bigdl_position_ids.append([cur_pos - 1]) - decoding_position_ids.append(cur_pos - 1) + # decoding_position_ids.append(cur_pos - 1) cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos) decoding_attention_mask_list.append(cur_attention_mask) bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device) - # TODO: prefill requests could also be sbed, so that we can remove attention_mask forever if is_decoding_stage: if enable_vllm_se_batching: attention_mask = [torch.tensor(x, device=self.device).unsqueeze(0)