From 9361434ebfe728b2b2ea79654021d44d6cca7b4a Mon Sep 17 00:00:00 2001 From: Kalyan Kumar Date: Tue, 10 Dec 2024 14:04:42 +0530 Subject: [PATCH 1/5] [SW-207965] Add batch splitting in attention layer to hide NIC latency (#14) - Introduced the `--attn_batch_split` parameter to enable batch splitting in the attention and mlp layer. - This approach aims to overlap communication and computation, effectively hiding NIC latency during distributed attention operations. - Perform the add in the beginning of the next layer for better pipelining - Updated Readme - [SW-212702] Fix the attn_batch_split argument specific to llama config (#74) Co-authored-by: Kalyan --- examples/text-generation/README.md | 1 + examples/text-generation/run_generation.py | 6 + examples/text-generation/utils.py | 4 + .../generation/configuration_utils.py | 3 + .../habana/transformers/generation/utils.py | 3 + .../models/llama/modeling_llama.py | 217 ++++++++++++++---- 6 files changed, 186 insertions(+), 48 deletions(-) diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 6da9bc8470..d8c5f9e5d4 100755 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -129,6 +129,7 @@ Here are a few settings you may be interested in: - `--prompt` to benchmark the model on one or several prompts of your choice - `--attn_softmax_bf16` to run attention softmax layer in bfloat16 precision provided that the model (such as Llama) supports it - `--trim_logits` to calculate logits only for the last token in the first time step provided that the model (such as Llama) supports it +- `--attn_batch_split` specifies the number of smaller batches to split the attention and MLP processing into for better parallelization.By default, no splitting is performed (value is 1). Splitting is enabled only for prompt. For example, you can reproduce the results presented in [this blog post](https://huggingface.co/blog/habana-gaudi-2-bloom) with the following command: ```bash diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 5355ceb1b6..6c0a4491fe 100755 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -343,6 +343,12 @@ def setup_parser(parser): default=None, help="Path to neural-compressor quantized model, if set, the checkpoint will be loaded.", ) + quant_parser_group.add_argument( + "--attn_batch_split", + default=1, + type=int, + help="Specify the batch size split for attention and mlp layers. 1 for no split. 
This is enabled only for prompt.", + ) args = parser.parse_args() diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index adfcb48b08..637318b7be 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -665,6 +665,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer): generation_config.flash_attention_fast_softmax = args.flash_attention_fast_softmax generation_config.trust_remote_code = args.trust_remote_code generation_config.valid_sequence_lengths = None + generation_config.attn_batch_split = args.attn_batch_split return generation_config @@ -689,6 +690,9 @@ def exclude_hpu_graph_configs(args): def initialize_model(args, logger): init_start = time.perf_counter() setup_distributed(args) + if not args.world_size > 0 and args.attn_batch_split > 1: + logger.warning("Disabling attention batch splitting as it's unnecessary for single-card execution") + args.attn_batch_split = 1 if exclude_hpu_graph_configs(args): args.limit_hpu_graphs = False override_prints(args.global_rank == 0 or args.verbose_workers, logger) diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py index ec04f139c9..5a4c797a2b 100644 --- a/optimum/habana/transformers/generation/configuration_utils.py +++ b/optimum/habana/transformers/generation/configuration_utils.py @@ -37,6 +37,8 @@ class GaudiGenerationConfig(GenerationConfig): Whether to enable causal_mask if use Habana flash attention. flash_attention_fast_softmax_mode (`bool`, *optional*): Whether to use fast softmax with reduced precision if use Habana flash attention. + attn_batch_split (`int`, *optional*): + Specify the batch size split for attention and mlp layers. 1 for no split. This is enabled only for prompt. 
""" def __init__(self, **kwargs): @@ -56,3 +58,4 @@ def __init__(self, **kwargs): self.flash_attention_fast_softmax = kwargs.get("flash_attention_fast_softmax", None) self.use_fused_rope = kwargs.get("use_fused_rope", None) self.valid_sequence_lengths = kwargs.get("valid_sequence_lengths", None) + self.attn_batch_split = kwargs.get("attn_batch_split", 1) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index d81e0d179a..f371b51c9a 100644 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -1266,6 +1266,9 @@ def generate( # prepare for allocate kv cache model_kwargs["reuse_cache"] = generation_config.reuse_cache + # prepare for attention batch splitting + model_kwargs["attn_batch_split"] = generation_config.attn_batch_split + # determine whether flash attention needs to be used model_kwargs["use_flash_attention"] = generation_config.use_flash_attention model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 55d4475a87..a9aff11c00 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -893,7 +893,6 @@ class GaudiLlamaDecoderLayer(LlamaDecoderLayer): def __init__(self, config: LlamaConfig, layer_idx: int): super(LlamaDecoderLayer, self).__init__() self.hidden_size = config.hidden_size - self.self_attn = GaudiLlamaAttention(config=config, layer_idx=layer_idx) self.mlp = GaudiLlamaMLP(config) @@ -929,6 +928,8 @@ def forward( valid_sequence_lengths: Optional[torch.Tensor] = None, cache_idx: int = None, num_virtual_tokens: int = None, + attn_batch_split: int = 1, + prev_layer_residual: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -942,33 +943,99 @@ def forward( - add new arg flash_attention_causal_mask - add new arg flash_attention_fast_softmax """ - residual = hidden_states + if attn_batch_split > 1 and past_key_value is None: + # Calculate split sizes to handle cases where batch size is not divisible by attn_batch_split + batch_size = attention_mask.size(0) + base_split_size = batch_size // attn_batch_split + remainder = batch_size % attn_batch_split + + split_sizes = [base_split_size + 1 if i < remainder else base_split_size for i in range(attn_batch_split)] + + # Split tensors using the calculated sizes + sub_attention_mask = torch.split(attention_mask, split_sizes, dim=0) + sub_position_ids = torch.split(position_ids, split_sizes, dim=0) + sub_valid_sequence_lengths = torch.split(valid_sequence_lengths, split_sizes, dim=0) + split_attn_weights = [] + split_present_key_values = [] + split_hidden_states = [None] * attn_batch_split + residual = [None] * attn_batch_split + + for i in range(attn_batch_split): + split_hidden_states[i] = hidden_states[i] + if self.self_attn.layer_idx != 0: + # Add the residual from the previous layer + split_hidden_states[i] = self.post_mlp(hidden_states[i], prev_layer_residual[i]) + + residual[i] = split_hidden_states[i] + split_hidden_states[i], self_attn_weights, present_key_value = self.pre_attn( + hidden_states=split_hidden_states[i], + attention_mask=sub_attention_mask[i], + position_ids=sub_position_ids[i], + past_key_value=past_key_value, + output_attentions=output_attentions, + 
use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + flash_attention_fast_softmax=flash_attention_fast_softmax, + valid_sequence_lengths=sub_valid_sequence_lengths[i], + cache_idx=cache_idx, + num_virtual_tokens=num_virtual_tokens, + **kwargs, + ) + self.self_attn.attention_all_reduce(split_hidden_states[i]) + if output_attentions: + split_attn_weights.append(self_attn_weights) + if use_cache: + split_present_key_values.append(present_key_value) - hidden_states, self_attn_weights, present_key_value = self.pre_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - token_idx=token_idx, - attn_softmax_bf16=attn_softmax_bf16, - reuse_cache=reuse_cache, - use_flash_attention=use_flash_attention, - flash_attention_recompute=flash_attention_recompute, - flash_attention_causal_mask=flash_attention_causal_mask, - flash_attention_fast_softmax=flash_attention_fast_softmax, - valid_sequence_lengths=valid_sequence_lengths, - cache_idx=cache_idx, - num_virtual_tokens=num_virtual_tokens, - **kwargs, - ) - self.self_attn.attention_all_reduce(hidden_states) - hidden_states, residual = self.post_attn_pre_mlp(hidden_states, residual) - self.mlp.mlp_all_reduce(hidden_states) - hidden_states = self.post_mlp(hidden_states, residual) + self_attn_weights = torch.cat(split_attn_weights, dim=0) if split_attn_weights else None + present_key_value = [torch.cat(tensors, dim=0) for tensors in zip(*split_present_key_values)] + + int_residual_splits = [] + for i in range(attn_batch_split): + split_hidden_states[i], int_residual = self.post_attn_pre_mlp(split_hidden_states[i], residual[i]) + self.mlp.mlp_all_reduce(split_hidden_states[i]) + int_residual_splits.append(int_residual) + + if self.self_attn.layer_idx == (self.self_attn.config.num_hidden_layers - 1): + for i in range(attn_batch_split): + split_hidden_states[i] = self.post_mlp(split_hidden_states[i], int_residual_splits[i]) + + hidden_states = split_hidden_states + + else: + residual = hidden_states + hidden_states, self_attn_weights, present_key_value = self.pre_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + flash_attention_fast_softmax=flash_attention_fast_softmax, + valid_sequence_lengths=valid_sequence_lengths, + cache_idx=cache_idx, + num_virtual_tokens=num_virtual_tokens, + **kwargs, + ) + self.self_attn.attention_all_reduce(hidden_states) + hidden_states, residual = self.post_attn_pre_mlp(hidden_states, residual) + self.mlp.mlp_all_reduce(hidden_states) + hidden_states = self.post_mlp(hidden_states, residual) outputs = (hidden_states,) @@ -976,6 +1043,9 @@ def forward( outputs += 
(self_attn_weights,) if use_cache: outputs += (present_key_value,) + # Store the residual spits to add them in the beginning of the next layer + if attn_batch_split > 1 and past_key_value is None: + outputs += (int_residual_splits,) return outputs @@ -1021,6 +1091,7 @@ def pre_attn( cache_idx=cache_idx, num_virtual_tokens=num_virtual_tokens, ) + return hidden_states, attn_weights, present_key_value def post_attn_pre_mlp(self, hidden_states, residual): @@ -1115,6 +1186,7 @@ def forward( cache_idx: int = None, lazy_mode: Optional[bool] = True, num_virtual_tokens: int = None, + attn_batch_split: int = 1, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -1222,6 +1294,18 @@ def forward( if lazy_mode: htcore.mark_step() + split_prompt = False + if attn_batch_split > 1 and past_key_values is None: + # Calculate split sizes to handle cases where batch size is not divisible by attn_batch_split + batch_size = hidden_states.size(0) + base_split_size = batch_size // attn_batch_split + remainder = batch_size % attn_batch_split + split_sizes = [base_split_size + 1 if i < remainder else base_split_size for i in range(attn_batch_split)] + # Split tensors using the calculated sizes + hidden_states = torch.split(hidden_states, split_sizes, dim=0) + split_prompt = True + prev_layer_residual = None + for layer_idx, decoder_layer in enumerate(self.layers): if ( lazy_mode @@ -1231,7 +1315,10 @@ def forward( htcore.mark_step() if output_hidden_states: - all_hidden_states += (hidden_states,) + if split_prompt: + all_hidden_states += (torch.cat(hidden_states, dim=0),) + else: + all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -1255,26 +1342,54 @@ def forward( None, ) else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=None if past_key_values is None else past_key_values[layer_idx], - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - token_idx=token_idx, - attn_softmax_bf16=attn_softmax_bf16, - reuse_cache=reuse_cache, - use_flash_attention=use_flash_attention, - flash_attention_recompute=flash_attention_recompute, - flash_attention_causal_mask=flash_attention_causal_mask, - flash_attention_fast_softmax=flash_attention_fast_softmax, - valid_sequence_lengths=valid_sequence_lengths, - cache_idx=cache_idx, - num_virtual_tokens=num_virtual_tokens, - ) + if attn_batch_split > 1 and past_key_values is None: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=None if past_key_values is None else past_key_values[layer_idx], + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + flash_attention_fast_softmax=flash_attention_fast_softmax, + valid_sequence_lengths=valid_sequence_lengths, + cache_idx=cache_idx, + num_virtual_tokens=num_virtual_tokens, + attn_batch_split=attn_batch_split, + prev_layer_residual=prev_layer_residual, + ) + index 
= 1 + int(use_cache) + int(output_attentions) + prev_layer_residual = layer_outputs[index] + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=None if past_key_values is None else past_key_values[layer_idx], + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + flash_attention_fast_softmax=flash_attention_fast_softmax, + valid_sequence_lengths=valid_sequence_lengths, + cache_idx=cache_idx, + num_virtual_tokens=num_virtual_tokens, + attn_batch_split=attn_batch_split, + ) + hidden_states = layer_outputs[0] if use_cache: @@ -1283,6 +1398,9 @@ def forward( if output_attentions: all_self_attns += (layer_outputs[1],) + if split_prompt: + hidden_states = torch.cat(hidden_states, dim=0) + hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer @@ -1354,6 +1472,7 @@ def forward( cache_idx: int = None, lazy_mode: Optional[bool] = True, num_virtual_tokens: int = None, + attn_batch_split: int = 1, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1387,6 +1506,7 @@ def forward( cache_idx=cache_idx, lazy_mode=lazy_mode, num_virtual_tokens=num_virtual_tokens, + attn_batch_split=attn_batch_split, ) hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape @@ -1542,6 +1662,7 @@ def prepare_inputs_for_generation( "cache_idx": kwargs.get("cache_idx"), "lazy_mode": kwargs.get("lazy_mode"), "num_virtual_tokens": kwargs.get("num_virtual_tokens"), + "attn_batch_split": kwargs.get("attn_batch_split"), } ) return model_inputs From 624dcfa5ae0ec37dd130ec0cf639d32e0d0ff8bc Mon Sep 17 00:00:00 2001 From: Kalyan Date: Fri, 20 Dec 2024 11:17:16 +0200 Subject: [PATCH 2/5] Updated readme for typo and argument group for attn_batch_split --- examples/text-generation/README.md | 2 +- examples/text-generation/run_generation.py | 2 +- .../models/llama/modeling_llama.py | 72 +++++++------------ 3 files changed, 28 insertions(+), 48 deletions(-) diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index d8c5f9e5d4..bd81c7f7fe 100755 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -129,7 +129,7 @@ Here are a few settings you may be interested in: - `--prompt` to benchmark the model on one or several prompts of your choice - `--attn_softmax_bf16` to run attention softmax layer in bfloat16 precision provided that the model (such as Llama) supports it - `--trim_logits` to calculate logits only for the last token in the first time step provided that the model (such as Llama) supports it -- `--attn_batch_split` specifies the number of smaller batches to split the attention and MLP processing into for better parallelization.By default, no splitting is performed (value is 1). Splitting is enabled only for prompt. +- `--attn_batch_split` specifies the number of smaller batches into which attention and MLP processing are split to improve parallelization. By default, no splitting is performed (value is 1). Splitting is enabled only for prompt processing. 
This configuration is most effective for batch sizes (BS) > 125 and tensor parallelism (TP) >= 2, with a recommended value of '3' splits. For example, you can reproduce the results presented in [this blog post](https://huggingface.co/blog/habana-gaudi-2-bloom) with the following command: ```bash diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 6c0a4491fe..a20783511d 100755 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -343,7 +343,7 @@ def setup_parser(parser): default=None, help="Path to neural-compressor quantized model, if set, the checkpoint will be loaded.", ) - quant_parser_group.add_argument( + parser.add_argument( "--attn_batch_split", default=1, type=int, diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index a9aff11c00..facd76552a 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -944,6 +944,7 @@ def forward( - add new arg flash_attention_fast_softmax """ if attn_batch_split > 1 and past_key_value is None: + print(" ########################## PRROMPT") # Calculate split sizes to handle cases where batch size is not divisible by attn_batch_split batch_size = attention_mask.size(0) base_split_size = batch_size // attn_batch_split @@ -1295,6 +1296,7 @@ def forward( htcore.mark_step() split_prompt = False + prev_layer_residual = None if attn_batch_split > 1 and past_key_values is None: # Calculate split sizes to handle cases where batch size is not divisible by attn_batch_split batch_size = hidden_states.size(0) @@ -1304,7 +1306,6 @@ def forward( # Split tensors using the calculated sizes hidden_states = torch.split(hidden_states, split_sizes, dim=0) split_prompt = True - prev_layer_residual = None for layer_idx, decoder_layer in enumerate(self.layers): if ( @@ -1342,53 +1343,32 @@ def forward( None, ) else: - if attn_batch_split > 1 and past_key_values is None: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=None if past_key_values is None else past_key_values[layer_idx], - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - token_idx=token_idx, - attn_softmax_bf16=attn_softmax_bf16, - reuse_cache=reuse_cache, - use_flash_attention=use_flash_attention, - flash_attention_recompute=flash_attention_recompute, - flash_attention_causal_mask=flash_attention_causal_mask, - flash_attention_fast_softmax=flash_attention_fast_softmax, - valid_sequence_lengths=valid_sequence_lengths, - cache_idx=cache_idx, - num_virtual_tokens=num_virtual_tokens, - attn_batch_split=attn_batch_split, - prev_layer_residual=prev_layer_residual, - ) + use_prev_layer_residual = attn_batch_split > 1 and past_key_values is None + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=None if past_key_values is None else past_key_values[layer_idx], + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + 
flash_attention_causal_mask=flash_attention_causal_mask, + flash_attention_fast_softmax=flash_attention_fast_softmax, + valid_sequence_lengths=valid_sequence_lengths, + cache_idx=cache_idx, + num_virtual_tokens=num_virtual_tokens, + attn_batch_split=attn_batch_split, + prev_layer_residual=prev_layer_residual, + ) + if use_prev_layer_residual: index = 1 + int(use_cache) + int(output_attentions) prev_layer_residual = layer_outputs[index] - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=None if past_key_values is None else past_key_values[layer_idx], - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - token_idx=token_idx, - attn_softmax_bf16=attn_softmax_bf16, - reuse_cache=reuse_cache, - use_flash_attention=use_flash_attention, - flash_attention_recompute=flash_attention_recompute, - flash_attention_causal_mask=flash_attention_causal_mask, - flash_attention_fast_softmax=flash_attention_fast_softmax, - valid_sequence_lengths=valid_sequence_lengths, - cache_idx=cache_idx, - num_virtual_tokens=num_virtual_tokens, - attn_batch_split=attn_batch_split, - ) hidden_states = layer_outputs[0] From 4cd5e8f83c4734a4a75e8e67baf681721638c503 Mon Sep 17 00:00:00 2001 From: Kalyan Kumar Date: Fri, 20 Dec 2024 22:55:16 +0530 Subject: [PATCH 3/5] Remove debug prints --- optimum/habana/transformers/models/llama/modeling_llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index facd76552a..f4bc485d9e 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -944,7 +944,6 @@ def forward( - add new arg flash_attention_fast_softmax """ if attn_batch_split > 1 and past_key_value is None: - print(" ########################## PRROMPT") # Calculate split sizes to handle cases where batch size is not divisible by attn_batch_split batch_size = attention_mask.size(0) base_split_size = batch_size // attn_batch_split From e542cbf25ce3d532f9b7ff6679b8ced8b805c70f Mon Sep 17 00:00:00 2001 From: Kalyan Date: Thu, 9 Jan 2025 06:54:23 +0200 Subject: [PATCH 4/5] Updated check for prev_layer_residual add. 
--- optimum/habana/transformers/models/llama/modeling_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index f4bc485d9e..10cf33cc3c 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -1363,7 +1363,7 @@ def forward( cache_idx=cache_idx, num_virtual_tokens=num_virtual_tokens, attn_batch_split=attn_batch_split, - prev_layer_residual=prev_layer_residual, + prev_layer_residual=prev_layer_residual if use_prev_layer_residual else None, ) if use_prev_layer_residual: index = 1 + int(use_cache) + int(output_attentions) From 7151f37bdcc048fc012cea61fe2b73c1e0fa1362 Mon Sep 17 00:00:00 2001 From: Kalyan Date: Wed, 15 Jan 2025 08:35:22 +0200 Subject: [PATCH 5/5] Remove redundant cat op from attn_split flow --- .../models/llama/modeling_llama.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 10cf33cc3c..f069b8d300 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -1303,7 +1303,7 @@ def forward( remainder = batch_size % attn_batch_split split_sizes = [base_split_size + 1 if i < remainder else base_split_size for i in range(attn_batch_split)] # Split tensors using the calculated sizes - hidden_states = torch.split(hidden_states, split_sizes, dim=0) + hidden_states_split = torch.split(hidden_states, split_sizes, dim=0) split_prompt = True for layer_idx, decoder_layer in enumerate(self.layers): @@ -1315,10 +1315,7 @@ def forward( htcore.mark_step() if output_hidden_states: - if split_prompt: - all_hidden_states += (torch.cat(hidden_states, dim=0),) - else: - all_hidden_states += (hidden_states,) + all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -1341,10 +1338,11 @@ def forward( valid_sequence_lengths, None, ) + hidden_states = layer_outputs[0] else: use_prev_layer_residual = attn_batch_split > 1 and past_key_values is None layer_outputs = decoder_layer( - hidden_states, + hidden_states=hidden_states_split if split_prompt else hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=None if past_key_values is None else past_key_values[layer_idx], @@ -1368,8 +1366,10 @@ def forward( if use_prev_layer_residual: index = 1 + int(use_cache) + int(output_attentions) prev_layer_residual = layer_outputs[index] - - hidden_states = layer_outputs[0] + if split_prompt: + hidden_states_split = layer_outputs[0] + else: + hidden_states = layer_outputs[0] if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) @@ -1377,9 +1377,6 @@ def forward( if output_attentions: all_self_attns += (layer_outputs[1],) - if split_prompt: - hidden_states = torch.cat(hidden_states, dim=0) - hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer
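
Illustrative sketch (not part of the patches above): the following self-contained PyTorch script mimics the control flow these patches add to `GaudiLlamaDecoderLayer` and `GaudiLlamaModel`, so the `attn_batch_split` idea can be read outside the diff context. `compute_split_sizes`, `ToyDecoderLayer`, and the no-op `*_all_reduce` hooks are stand-ins invented for this sketch, not optimum-habana APIs; layer norms, attention masks, rotary embeddings, and the KV cache are omitted. As in the patches, splitting is assumed to apply only to the prompt phase (no past key values), and the all-reduce hooks only mark where the tensor-parallel collectives would overlap with the compute of the other sub-batches on Gaudi.

```python
# Minimal sketch of the attn_batch_split flow, using plain PyTorch stand-ins.
import torch


def compute_split_sizes(batch_size: int, attn_batch_split: int) -> list:
    """Same arithmetic as the patch: spread the remainder over the first splits
    so the sizes stay valid when batch_size is not divisible by attn_batch_split."""
    base = batch_size // attn_batch_split
    remainder = batch_size % attn_batch_split
    return [base + 1 if i < remainder else base for i in range(attn_batch_split)]


class ToyDecoderLayer(torch.nn.Module):
    """Stand-in for GaudiLlamaDecoderLayer: attention and MLP are single Linears,
    and the all-reduce hooks are no-ops that only mark where the collectives sit."""

    def __init__(self, hidden_size: int, layer_idx: int, num_layers: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.num_layers = num_layers
        self.attn = torch.nn.Linear(hidden_size, hidden_size)
        self.mlp = torch.nn.Linear(hidden_size, hidden_size)

    def attention_all_reduce(self, x):
        pass  # placeholder: TP all-reduce, overlapped with the next sub-batch's compute

    def mlp_all_reduce(self, x):
        pass  # placeholder: TP all-reduce, overlapped with the next sub-batch's compute

    def forward(self, hidden_splits, prev_layer_residual=None):
        n = len(hidden_splits)
        attn_out, residuals = [None] * n, [None] * n

        # Attention pass, one sub-batch at a time: the all-reduce of sub-batch i
        # can be hidden behind the attention compute of sub-batch i + 1.
        for i in range(n):
            h = hidden_splits[i]
            if self.layer_idx != 0:
                # Deferred add: the previous layer's post-MLP residual is applied
                # at the beginning of this layer for better pipelining.
                h = h + prev_layer_residual[i]
            residuals[i] = h
            h = self.attn(h)               # pre_attn stand-in (norm omitted)
            self.attention_all_reduce(h)
            attn_out[i] = h

        # MLP pass: the final residual add is postponed to the next layer,
        # except on the last layer, where nothing is left to overlap.
        out, int_residuals = [None] * n, []
        for i in range(n):
            h = attn_out[i] + residuals[i]  # post_attn_pre_mlp stand-in
            int_residuals.append(h)
            h = self.mlp(h)
            self.mlp_all_reduce(h)
            out[i] = h
        if self.layer_idx == self.num_layers - 1:
            out = [o + r for o, r in zip(out, int_residuals)]
        return out, int_residuals


if __name__ == "__main__":
    attn_batch_split, batch, seq, hidden, num_layers = 3, 8, 16, 32, 2
    x = torch.randn(batch, seq, hidden)
    # Prompt phase only: split once, keep the splits across all layers.
    splits = list(torch.split(x, compute_split_sizes(batch, attn_batch_split), dim=0))
    layers = [ToyDecoderLayer(hidden, i, num_layers) for i in range(num_layers)]
    prev_residual = None
    for layer in layers:
        splits, prev_residual = layer(splits, prev_residual)
    y = torch.cat(splits, dim=0)  # single concat before the final norm / LM head
    print(y.shape)                # torch.Size([8, 16, 32])
```

The sketch mirrors the shape of the final series: sub-batches stay split across all layers and are concatenated once before the final norm (as in PATCH 5/5, which removes the redundant per-layer cat), and each layer's closing residual add is carried to the start of the next layer via `prev_layer_residual` (PATCH 1/5, "perform the add in the beginning of the next layer"), with the last layer performing it immediately.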