Add batch splitting in attention layer to hide NIC latency (#14) #1640

Merged
6 commits merged on Feb 7, 2025
Changes from 3 commits
1 change: 1 addition & 0 deletions examples/text-generation/README.md
@@ -129,6 +129,7 @@ Here are a few settings you may be interested in:
- `--prompt` to benchmark the model on one or several prompts of your choice
- `--attn_softmax_bf16` to run attention softmax layer in bfloat16 precision provided that the model (such as Llama) supports it
- `--trim_logits` to calculate logits only for the last token in the first time step provided that the model (such as Llama) supports it
- `--attn_batch_split` specifies the number of smaller batches into which attention and MLP processing are split to improve parallelization. By default, no splitting is performed (value is 1). Splitting is enabled only for prompt processing. This configuration is most effective for batch sizes (BS) > 125 and tensor parallelism (TP) >= 2, with a recommended value of 3 splits; see the sketch below.

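As a sketch of how this flag might be combined with the options above (the launcher and every flag except `--attn_batch_split` mirror the multi-card examples elsewhere in this README and are assumptions here, not part of this change), a three-way split on an 8-card run could look like:
```bash
python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \
  --model_name_or_path meta-llama/Llama-2-70b-hf \
  --batch_size 220 \
  --max_new_tokens 128 \
  --use_hpu_graphs \
  --use_kv_cache \
  --bf16 \
  --attn_batch_split 3
```
With TP >= 2, the intent of the split (per this PR's title) is to overlap one sub-batch's all-reduce with another sub-batch's compute so that the interconnect latency is hidden.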
For example, you can reproduce the results presented in [this blog post](https://huggingface.co/blog/habana-gaudi-2-bloom) with the following command:
```bash
6 changes: 6 additions & 0 deletions examples/text-generation/run_generation.py
@@ -343,6 +343,12 @@ def setup_parser(parser):
default=None,
help="Path to neural-compressor quantized model, if set, the checkpoint will be loaded.",
)
parser.add_argument(
"--attn_batch_split",
default=1,
type=int,
help="Specify the batch size split for attention and mlp layers. 1 for no split. This is enabled only for prompt.",
)

args = parser.parse_args()

4 changes: 4 additions & 0 deletions examples/text-generation/utils.py
@@ -665,6 +665,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
generation_config.flash_attention_fast_softmax = args.flash_attention_fast_softmax
generation_config.trust_remote_code = args.trust_remote_code
generation_config.valid_sequence_lengths = None
generation_config.attn_batch_split = args.attn_batch_split

return generation_config

@@ -689,6 +690,9 @@ def exclude_hpu_graph_configs(args):
def initialize_model(args, logger):
init_start = time.perf_counter()
setup_distributed(args)
if not args.world_size > 0 and args.attn_batch_split > 1:
logger.warning("Disabling attention batch splitting as it's unnecessary for single-card execution")
args.attn_batch_split = 1
if exclude_hpu_graph_configs(args):
args.limit_hpu_graphs = False
override_prints(args.global_rank == 0 or args.verbose_workers, logger)
@@ -37,6 +37,8 @@ class GaudiGenerationConfig(GenerationConfig):
Whether to enable causal_mask if use Habana flash attention.
flash_attention_fast_softmax_mode (`bool`, *optional*):
Whether to use fast softmax with reduced precision if use Habana flash attention.
attn_batch_split (`int`, *optional*):
Specify the batch size split for attention and MLP layers. 1 means no split. This is enabled only for prompt processing.
"""

def __init__(self, **kwargs):
@@ -56,3 +58,4 @@ def __init__(self, **kwargs):
self.flash_attention_fast_softmax = kwargs.get("flash_attention_fast_softmax", None)
self.use_fused_rope = kwargs.get("use_fused_rope", None)
self.valid_sequence_lengths = kwargs.get("valid_sequence_lengths", None)
self.attn_batch_split = kwargs.get("attn_batch_split", 1)
3 changes: 3 additions & 0 deletions optimum/habana/transformers/generation/utils.py
@@ -1266,6 +1266,9 @@ def generate(
# prepare for allocate kv cache
model_kwargs["reuse_cache"] = generation_config.reuse_cache

# prepare for attention batch splitting
model_kwargs["attn_batch_split"] = generation_config.attn_batch_split

# determine whether flash attention needs to be used
model_kwargs["use_flash_attention"] = generation_config.use_flash_attention
model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False
156 changes: 128 additions & 28 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -893,7 +893,6 @@ class GaudiLlamaDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: LlamaConfig, layer_idx: int):
super(LlamaDecoderLayer, self).__init__()
self.hidden_size = config.hidden_size

self.self_attn = GaudiLlamaAttention(config=config, layer_idx=layer_idx)

self.mlp = GaudiLlamaMLP(config)
@@ -929,6 +928,8 @@ def forward(
valid_sequence_lengths: Optional[torch.Tensor] = None,
cache_idx: int = None,
num_virtual_tokens: int = None,
attn_batch_split: int = 1,
prev_layer_residual: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -942,40 +943,109 @@
- add new arg flash_attention_causal_mask
- add new arg flash_attention_fast_softmax
"""
residual = hidden_states
if attn_batch_split > 1 and past_key_value is None:
# Calculate split sizes to handle cases where batch size is not divisible by attn_batch_split
batch_size = attention_mask.size(0)
base_split_size = batch_size // attn_batch_split
remainder = batch_size % attn_batch_split

split_sizes = [base_split_size + 1 if i < remainder else base_split_size for i in range(attn_batch_split)]
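# e.g. batch_size=8 and attn_batch_split=3 gives split_sizes == [3, 3, 2]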

# Split tensors using the calculated sizes
sub_attention_mask = torch.split(attention_mask, split_sizes, dim=0)
sub_position_ids = torch.split(position_ids, split_sizes, dim=0)
sub_valid_sequence_lengths = torch.split(valid_sequence_lengths, split_sizes, dim=0)
split_attn_weights = []
split_present_key_values = []
split_hidden_states = [None] * attn_batch_split
residual = [None] * attn_batch_split

for i in range(attn_batch_split):
split_hidden_states[i] = hidden_states[i]
if self.self_attn.layer_idx != 0:
# Add the residual from the previous layer
split_hidden_states[i] = self.post_mlp(hidden_states[i], prev_layer_residual[i])

residual[i] = split_hidden_states[i]
split_hidden_states[i], self_attn_weights, present_key_value = self.pre_attn(
hidden_states=split_hidden_states[i],
attention_mask=sub_attention_mask[i],
position_ids=sub_position_ids[i],
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
token_idx=token_idx,
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
valid_sequence_lengths=sub_valid_sequence_lengths[i],
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
**kwargs,
)
self.self_attn.attention_all_reduce(split_hidden_states[i])
if output_attentions:
split_attn_weights.append(self_attn_weights)
if use_cache:
split_present_key_values.append(present_key_value)

hidden_states, self_attn_weights, present_key_value = self.pre_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
token_idx=token_idx,
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
valid_sequence_lengths=valid_sequence_lengths,
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
**kwargs,
)
self.self_attn.attention_all_reduce(hidden_states)
hidden_states, residual = self.post_attn_pre_mlp(hidden_states, residual)
self.mlp.mlp_all_reduce(hidden_states)
hidden_states = self.post_mlp(hidden_states, residual)
self_attn_weights = torch.cat(split_attn_weights, dim=0) if split_attn_weights else None
present_key_value = [torch.cat(tensors, dim=0) for tensors in zip(*split_present_key_values)]

int_residual_splits = []
for i in range(attn_batch_split):
split_hidden_states[i], int_residual = self.post_attn_pre_mlp(split_hidden_states[i], residual[i])
self.mlp.mlp_all_reduce(split_hidden_states[i])
int_residual_splits.append(int_residual)

if self.self_attn.layer_idx == (self.self_attn.config.num_hidden_layers - 1):
for i in range(attn_batch_split):
split_hidden_states[i] = self.post_mlp(split_hidden_states[i], int_residual_splits[i])

hidden_states = split_hidden_states

else:
residual = hidden_states
hidden_states, self_attn_weights, present_key_value = self.pre_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
token_idx=token_idx,
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
valid_sequence_lengths=valid_sequence_lengths,
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
**kwargs,
)
self.self_attn.attention_all_reduce(hidden_states)
hidden_states, residual = self.post_attn_pre_mlp(hidden_states, residual)
self.mlp.mlp_all_reduce(hidden_states)
hidden_states = self.post_mlp(hidden_states, residual)

outputs = (hidden_states,)

if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
# Store the residual splits to add them at the beginning of the next layer
if attn_batch_split > 1 and past_key_value is None:
outputs += (int_residual_splits,)

return outputs

@@ -1021,6 +1091,7 @@ def pre_attn(
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
)

return hidden_states, attn_weights, present_key_value

def post_attn_pre_mlp(self, hidden_states, residual):
@@ -1115,6 +1186,7 @@ def forward(
cache_idx: int = None,
lazy_mode: Optional[bool] = True,
num_virtual_tokens: int = None,
attn_batch_split: int = 1,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@@ -1222,6 +1294,18 @@
if lazy_mode:
htcore.mark_step()

split_prompt = False
prev_layer_residual = None
if attn_batch_split > 1 and past_key_values is None:
# Calculate split sizes to handle cases where batch size is not divisible by attn_batch_split
batch_size = hidden_states.size(0)
base_split_size = batch_size // attn_batch_split
remainder = batch_size % attn_batch_split
split_sizes = [base_split_size + 1 if i < remainder else base_split_size for i in range(attn_batch_split)]
# Split tensors using the calculated sizes
hidden_states = torch.split(hidden_states, split_sizes, dim=0)
split_prompt = True

for layer_idx, decoder_layer in enumerate(self.layers):
if (
lazy_mode
@@ -1231,7 +1315,10 @@
htcore.mark_step()

if output_hidden_states:
all_hidden_states += (hidden_states,)
if split_prompt:
all_hidden_states += (torch.cat(hidden_states, dim=0),)
else:
all_hidden_states += (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
@@ -1255,6 +1342,7 @@
None,
)
else:
use_prev_layer_residual = attn_batch_split > 1 and past_key_values is None
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
@@ -1274,7 +1362,13 @@
valid_sequence_lengths=valid_sequence_lengths,
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
attn_batch_split=attn_batch_split,
prev_layer_residual=prev_layer_residual,
Collaborator
From the previous review, it seems this needs to be changed to `prev_layer_residual=prev_layer_residual if use_prev_layer_residual else None`. And where is the `prev_layer_residual` value set after line 1298?

Contributor Author
`prev_layer_residual` is set at line 1370 (`prev_layer_residual = layer_outputs[index]`); for the first layer it will be None.

Contributor Author (@kalyank007, Jan 9, 2025)
Perf improvement on Gaudi3: [performance screenshot]

Collaborator
Can you please post the CI result also?

Contributor Author
[CI result screenshots]

Contributor Author
@yeonsily Can we mark this conversation resolved?

Collaborator
Can you please re-run test_examples.py with RUN_SLOW=true GAUDI2_CI=1? The Llama tests were all skipped and didn't run.

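A minimal sketch of such a re-run, assuming the suite lives at tests/test_examples.py and that pytest's `-k` filter is used to select the Llama cases (both assumptions, not confirmed in this PR):
```bash
# Assumed test path and filter; adjust to the repository's actual layout.
RUN_SLOW=true GAUDI2_CI=1 python -m pytest tests/test_examples.py -v -s -k "llama"
```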
)
if use_prev_layer_residual:
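# layer_outputs layout: (hidden_states, [self_attn_weights], [present_key_value], int_residual_splits),
# so the residual splits sit after the optional attention/cache entries.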
index = 1 + int(use_cache) + int(output_attentions)
prev_layer_residual = layer_outputs[index]

hidden_states = layer_outputs[0]

if use_cache:
Expand All @@ -1283,6 +1377,9 @@ def forward(
if output_attentions:
all_self_attns += (layer_outputs[1],)

if split_prompt:
hidden_states = torch.cat(hidden_states, dim=0)

hidden_states = self.norm(hidden_states)

# add hidden states from the last decoder layer
@@ -1354,6 +1451,7 @@ def forward(
cache_idx: int = None,
lazy_mode: Optional[bool] = True,
num_virtual_tokens: int = None,
attn_batch_split: int = 1,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -1387,6 +1485,7 @@
cache_idx=cache_idx,
lazy_mode=lazy_mode,
num_virtual_tokens=num_virtual_tokens,
attn_batch_split=attn_batch_split,
)
hidden_states = outputs[0]
_, seq_len, _ = hidden_states.shape
@@ -1542,6 +1641,7 @@ def prepare_inputs_for_generation(
"cache_idx": kwargs.get("cache_idx"),
"lazy_mode": kwargs.get("lazy_mode"),
"num_virtual_tokens": kwargs.get("num_virtual_tokens"),
"attn_batch_split": kwargs.get("attn_batch_split"),
}
)
return model_inputs