diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md
index a884877bee..5d399f65dd 100755
--- a/examples/text-generation/README.md
+++ b/examples/text-generation/README.md
@@ -132,6 +132,7 @@ Here are a few settings you may be interested in:
 - `--prompt` to benchmark the model on one or several prompts of your choice
 - `--attn_softmax_bf16` to run attention softmax layer in bfloat16 precision provided that the model (such as Llama) supports it
 - `--trim_logits` to calculate logits only for the last token in the first time step provided that the model (such as Llama) supports it
+- `--attn_batch_split` specifies the number of smaller batches into which attention and MLP processing are split to improve parallelization. By default, no splitting is performed (value is 1). Splitting is enabled only for prompt processing. This configuration is most effective for batch sizes (BS) > 125 and tensor parallelism (TP) >= 2, with a recommended value of 3 splits.

 For example, you can reproduce the results presented in [this blog post](https://huggingface.co/blog/habana-gaudi-2-bloom) with the following command:
 ```bash
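As a point of reference for the README entry above, a multi-card run that exercises the new flag could look like the sketch below. The model name, batch size, and launcher arguments are illustrative assumptions and not part of this change; only `--attn_batch_split` comes from this diff.

```bash
# Illustrative invocation only; model, batch size, and launcher flags are assumptions.
python ../gaudi_spawn.py --use_deepspeed --world_size 2 run_generation.py \
    --model_name_or_path meta-llama/Llama-2-70b-hf \
    --use_hpu_graphs \
    --use_kv_cache \
    --bf16 \
    --batch_size 256 \
    --max_new_tokens 128 \
    --attn_batch_split 3
```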
""" def __init__(self, **kwargs): @@ -59,3 +61,4 @@ def __init__(self, **kwargs): self.flash_attention_fast_softmax = kwargs.get("flash_attention_fast_softmax", None) self.use_fused_rope = kwargs.get("use_fused_rope", None) self.valid_sequence_lengths = kwargs.get("valid_sequence_lengths", None) + self.attn_batch_split = kwargs.get("attn_batch_split", 1) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index f7fc1a34b3..bcb4d74e5b 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -1270,6 +1270,9 @@ def generate( # prepare for allocate kv cache model_kwargs["reuse_cache"] = generation_config.reuse_cache + # prepare for attention batch splitting + model_kwargs["attn_batch_split"] = generation_config.attn_batch_split + # determine whether flash attention needs to be used model_kwargs["use_flash_attention"] = generation_config.use_flash_attention model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 18867ff8a4..130490d5c9 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -911,7 +911,6 @@ class GaudiLlamaDecoderLayer(LlamaDecoderLayer): def __init__(self, config: LlamaConfig, layer_idx: int): super(LlamaDecoderLayer, self).__init__() self.hidden_size = config.hidden_size - self.self_attn = GaudiLlamaAttention(config=config, layer_idx=layer_idx) self.mlp = GaudiLlamaMLP(config) @@ -947,6 +946,8 @@ def forward( valid_sequence_lengths: Optional[torch.Tensor] = None, cache_idx: int = None, num_virtual_tokens: int = None, + attn_batch_split: int = 1, + prev_layer_residual: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -960,33 +961,99 @@ def forward( - add new arg flash_attention_causal_mask - add new arg flash_attention_fast_softmax """ - residual = hidden_states + if attn_batch_split > 1 and past_key_value is None: + # Calculate split sizes to handle cases where batch size is not divisible by attn_batch_split + batch_size = attention_mask.size(0) + base_split_size = batch_size // attn_batch_split + remainder = batch_size % attn_batch_split + + split_sizes = [base_split_size + 1 if i < remainder else base_split_size for i in range(attn_batch_split)] + + # Split tensors using the calculated sizes + sub_attention_mask = torch.split(attention_mask, split_sizes, dim=0) + sub_position_ids = torch.split(position_ids, split_sizes, dim=0) + sub_valid_sequence_lengths = torch.split(valid_sequence_lengths, split_sizes, dim=0) + split_attn_weights = [] + split_present_key_values = [] + split_hidden_states = [None] * attn_batch_split + residual = [None] * attn_batch_split + + for i in range(attn_batch_split): + split_hidden_states[i] = hidden_states[i] + if self.self_attn.layer_idx != 0: + # Add the residual from the previous layer + split_hidden_states[i] = self.post_mlp(hidden_states[i], prev_layer_residual[i]) + + residual[i] = split_hidden_states[i] + split_hidden_states[i], self_attn_weights, present_key_value = self.pre_attn( + hidden_states=split_hidden_states[i], + attention_mask=sub_attention_mask[i], + position_ids=sub_position_ids[i], + past_key_value=past_key_value, + output_attentions=output_attentions, + 
diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 18867ff8a4..130490d5c9 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -911,7 +911,6 @@ class GaudiLlamaDecoderLayer(LlamaDecoderLayer):
     def __init__(self, config: LlamaConfig, layer_idx: int):
         super(LlamaDecoderLayer, self).__init__()
         self.hidden_size = config.hidden_size
-
         self.self_attn = GaudiLlamaAttention(config=config, layer_idx=layer_idx)

         self.mlp = GaudiLlamaMLP(config)
@@ -947,6 +946,8 @@ def forward(
         valid_sequence_lengths: Optional[torch.Tensor] = None,
         cache_idx: int = None,
         num_virtual_tokens: int = None,
+        attn_batch_split: int = 1,
+        prev_layer_residual: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -960,33 +961,99 @@ def forward(
         - add new arg flash_attention_causal_mask
         - add new arg flash_attention_fast_softmax
         """
-        residual = hidden_states
+        if attn_batch_split > 1 and past_key_value is None:
+            # Calculate split sizes to handle cases where batch size is not divisible by attn_batch_split
+            batch_size = attention_mask.size(0)
+            base_split_size = batch_size // attn_batch_split
+            remainder = batch_size % attn_batch_split
+
+            split_sizes = [base_split_size + 1 if i < remainder else base_split_size for i in range(attn_batch_split)]
+
+            # Split tensors using the calculated sizes
+            sub_attention_mask = torch.split(attention_mask, split_sizes, dim=0)
+            sub_position_ids = torch.split(position_ids, split_sizes, dim=0)
+            sub_valid_sequence_lengths = torch.split(valid_sequence_lengths, split_sizes, dim=0)
+            split_attn_weights = []
+            split_present_key_values = []
+            split_hidden_states = [None] * attn_batch_split
+            residual = [None] * attn_batch_split
+
+            for i in range(attn_batch_split):
+                split_hidden_states[i] = hidden_states[i]
+                if self.self_attn.layer_idx != 0:
+                    # Add the residual from the previous layer
+                    split_hidden_states[i] = self.post_mlp(hidden_states[i], prev_layer_residual[i])
+
+                residual[i] = split_hidden_states[i]
+                split_hidden_states[i], self_attn_weights, present_key_value = self.pre_attn(
+                    hidden_states=split_hidden_states[i],
+                    attention_mask=sub_attention_mask[i],
+                    position_ids=sub_position_ids[i],
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                    cache_position=cache_position,
+                    position_embeddings=position_embeddings,
+                    token_idx=token_idx,
+                    attn_softmax_bf16=attn_softmax_bf16,
+                    reuse_cache=reuse_cache,
+                    use_flash_attention=use_flash_attention,
+                    flash_attention_recompute=flash_attention_recompute,
+                    flash_attention_causal_mask=flash_attention_causal_mask,
+                    flash_attention_fast_softmax=flash_attention_fast_softmax,
+                    valid_sequence_lengths=sub_valid_sequence_lengths[i],
+                    cache_idx=cache_idx,
+                    num_virtual_tokens=num_virtual_tokens,
+                    **kwargs,
+                )
+                self.self_attn.attention_all_reduce(split_hidden_states[i])
+                if output_attentions:
+                    split_attn_weights.append(self_attn_weights)
+                if use_cache:
+                    split_present_key_values.append(present_key_value)
-        hidden_states, self_attn_weights, present_key_value = self.pre_attn(
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-            cache_position=cache_position,
-            position_embeddings=position_embeddings,
-            token_idx=token_idx,
-            attn_softmax_bf16=attn_softmax_bf16,
-            reuse_cache=reuse_cache,
-            use_flash_attention=use_flash_attention,
-            flash_attention_recompute=flash_attention_recompute,
-            flash_attention_causal_mask=flash_attention_causal_mask,
-            flash_attention_fast_softmax=flash_attention_fast_softmax,
-            valid_sequence_lengths=valid_sequence_lengths,
-            cache_idx=cache_idx,
-            num_virtual_tokens=num_virtual_tokens,
-            **kwargs,
-        )
-        self.self_attn.attention_all_reduce(hidden_states)
-        hidden_states, residual = self.post_attn_pre_mlp(hidden_states, residual)
-        self.mlp.mlp_all_reduce(hidden_states)
-        hidden_states = self.post_mlp(hidden_states, residual)
+            self_attn_weights = torch.cat(split_attn_weights, dim=0) if split_attn_weights else None
+            present_key_value = [torch.cat(tensors, dim=0) for tensors in zip(*split_present_key_values)]
+
+            int_residual_splits = []
+            for i in range(attn_batch_split):
+                split_hidden_states[i], int_residual = self.post_attn_pre_mlp(split_hidden_states[i], residual[i])
+                self.mlp.mlp_all_reduce(split_hidden_states[i])
+                int_residual_splits.append(int_residual)
+
+            if self.self_attn.layer_idx == (self.self_attn.config.num_hidden_layers - 1):
+                for i in range(attn_batch_split):
+                    split_hidden_states[i] = self.post_mlp(split_hidden_states[i], int_residual_splits[i])
+
+            hidden_states = split_hidden_states
+
+        else:
+            residual = hidden_states
+            hidden_states, self_attn_weights, present_key_value = self.pre_attn(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                token_idx=token_idx,
+                attn_softmax_bf16=attn_softmax_bf16,
+                reuse_cache=reuse_cache,
+                use_flash_attention=use_flash_attention,
+                flash_attention_recompute=flash_attention_recompute,
+                flash_attention_causal_mask=flash_attention_causal_mask,
+                flash_attention_fast_softmax=flash_attention_fast_softmax,
+                valid_sequence_lengths=valid_sequence_lengths,
+                cache_idx=cache_idx,
+                num_virtual_tokens=num_virtual_tokens,
+                **kwargs,
+            )
+            self.self_attn.attention_all_reduce(hidden_states)
+            hidden_states, residual = self.post_attn_pre_mlp(hidden_states, residual)
+            self.mlp.mlp_all_reduce(hidden_states)
+            hidden_states = self.post_mlp(hidden_states, residual)

         outputs = (hidden_states,)
@@ -994,6 +1061,9 @@ def forward(
             outputs += (self_attn_weights,)

         if use_cache:
             outputs += (present_key_value,)
+        # Store the residual splits to add them in the beginning of the next layer
+        if attn_batch_split > 1 and past_key_value is None:
+            outputs += (int_residual_splits,)

         return outputs
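Both the decoder layer above and `GaudiLlamaModel.forward` below use the same remainder-aware split so that batch sizes not divisible by `attn_batch_split` still cover every sample. A standalone sketch of that arithmetic (the helper name is hypothetical):

```python
# Standalone sketch of the split-size calculation used in the hunks above and below.
def compute_split_sizes(batch_size: int, attn_batch_split: int) -> list:
    base_split_size = batch_size // attn_batch_split
    remainder = batch_size % attn_batch_split
    # The first `remainder` sub-batches each take one extra sample.
    return [base_split_size + 1 if i < remainder else base_split_size for i in range(attn_batch_split)]


print(compute_split_sizes(128, 3))  # [43, 43, 42]; the sizes always sum back to 128
```

These sizes are then handed to `torch.split(tensor, split_sizes, dim=0)`, which returns one view per sub-batch along the batch dimension.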
@@ -1039,6 +1109,7 @@ def pre_attn(
             cache_idx=cache_idx,
             num_virtual_tokens=num_virtual_tokens,
         )
+
         return hidden_states, attn_weights, present_key_value

     def post_attn_pre_mlp(self, hidden_states, residual):
@@ -1133,6 +1204,7 @@ def forward(
         cache_idx: int = None,
         lazy_mode: Optional[bool] = True,
         num_virtual_tokens: int = None,
+        attn_batch_split: int = 1,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         """
         Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@@ -1240,6 +1312,18 @@
         if lazy_mode:
             htcore.mark_step()

+        split_prompt = False
+        prev_layer_residual = None
+        if attn_batch_split > 1 and past_key_values is None:
+            # Calculate split sizes to handle cases where batch size is not divisible by attn_batch_split
+            batch_size = hidden_states.size(0)
+            base_split_size = batch_size // attn_batch_split
+            remainder = batch_size % attn_batch_split
+            split_sizes = [base_split_size + 1 if i < remainder else base_split_size for i in range(attn_batch_split)]
+            # Split tensors using the calculated sizes
+            hidden_states_split = torch.split(hidden_states, split_sizes, dim=0)
+            split_prompt = True
+
         for layer_idx, decoder_layer in enumerate(self.layers):
             if (
                 lazy_mode
@@ -1272,9 +1356,11 @@
                     valid_sequence_lengths,
                     None,
                 )
+                hidden_states = layer_outputs[0]
             else:
+                use_prev_layer_residual = attn_batch_split > 1 and past_key_values is None
                 layer_outputs = decoder_layer(
-                    hidden_states,
+                    hidden_states=hidden_states_split if split_prompt else hidden_states,
                     attention_mask=causal_mask,
                     position_ids=position_ids,
                     past_key_value=None if past_key_values is None else past_key_values[layer_idx],
@@ -1292,8 +1378,16 @@
                     valid_sequence_lengths=valid_sequence_lengths,
                     cache_idx=cache_idx,
                     num_virtual_tokens=num_virtual_tokens,
+                    attn_batch_split=attn_batch_split,
+                    prev_layer_residual=prev_layer_residual if use_prev_layer_residual else None,
                 )
-            hidden_states = layer_outputs[0]
+            if use_prev_layer_residual:
+                index = 1 + int(use_cache) + int(output_attentions)
+                prev_layer_residual = layer_outputs[index]
+            if split_prompt:
+                hidden_states_split = layer_outputs[0]
+            else:
+                hidden_states = layer_outputs[0]

             if use_cache:
                 next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
@@ -1372,6 +1466,7 @@ def forward(
         cache_idx: int = None,
         lazy_mode: Optional[bool] = True,
         num_virtual_tokens: int = None,
+        attn_batch_split: int = 1,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1405,6 +1500,7 @@
             cache_idx=cache_idx,
             lazy_mode=lazy_mode,
             num_virtual_tokens=num_virtual_tokens,
+            attn_batch_split=attn_batch_split,
         )
         hidden_states = outputs[0]
         _, seq_len, _ = hidden_states.shape
@@ -1560,6 +1656,7 @@ def prepare_inputs_for_generation(
                 "cache_idx": kwargs.get("cache_idx"),
                 "lazy_mode": kwargs.get("lazy_mode"),
                 "num_virtual_tokens": kwargs.get("num_virtual_tokens"),
+                "attn_batch_split": kwargs.get("attn_batch_split"),
             }
         )
         return model_inputs
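One detail that ties the two forward passes together: when splitting is active, the decoder layer appends `int_residual_splits` as the last element of its output tuple, after the optional attention weights and KV-cache entries, and `GaudiLlamaModel.forward` recovers it with `index = 1 + int(use_cache) + int(output_attentions)`. A small sketch of that index arithmetic under the tuple layout shown above:

```python
# Sketch of locating the residual splits in the decoder layer's output tuple.
# Layout from the code above: (hidden_states, [self_attn_weights], [present_key_value], int_residual_splits)
def residual_index(use_cache: bool, output_attentions: bool) -> int:
    return 1 + int(use_cache) + int(output_attentions)


# Example: use_cache=True, output_attentions=False -> the splits sit at index 2.
layer_outputs = ("hidden_states", "present_key_value", "int_residual_splits")
assert layer_outputs[residual_index(use_cache=True, output_attentions=False)] == "int_residual_splits"
```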