Improve HALF_DTYPE selection
Skylion007 committed Jan 4, 2024
1 parent 6ecb798 commit eb93f3b
Showing 1 changed file with 8 additions and 11 deletions.
examples/benchmarks/bert/src/bert_layers.py (19 changes: 8 additions & 11 deletions)
@@ -39,7 +39,6 @@
 import os
 import sys
 import warnings
-from functools import lru_cache
 from typing import List, Optional, Tuple, Union

 # Add folder root to path to allow us to use relative imports regardless of what directory the script is run from
@@ -79,12 +78,11 @@

 logger = logging.getLogger(__name__)


-@lru_cache
-def _get_half_dtype() -> torch.dtype:
-    if flash_attn_qkvpacked_func is not None:
-        if torch.cuda.is_bf16_supported():
-            return torch.bfloat16
-    return torch.float16
+if flash_attn_qkvpacked_func is not None:
+    HALF_DTYPE = torch.bfloat16
+else:
+    HALF_DTYPE = torch.float16


 class BertEmbeddings(nn.Module):
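
Note: the net effect of this hunk is that the cached _get_half_dtype() helper is replaced by a module-level constant chosen once at import time. A minimal self-contained sketch of the new selection logic (the try/except guard here is illustrative; the real import in bert_layers.py may be structured differently):

    import torch

    # Illustrative guard: bert_layers.py leaves flash_attn_qkvpacked_func as
    # None when the flash-attn (FA2) package is unavailable.
    try:
        from flash_attn import flash_attn_qkvpacked_func
    except ImportError:
        flash_attn_qkvpacked_func = None

    # Per the comment later in this diff, FA2 2.4.2 no longer supports Turing
    # GPUs, so any device that can run FA2 also supports bfloat16. That makes
    # the old torch.cuda.is_bf16_supported() probe (and its lru_cache) redundant.
    if flash_attn_qkvpacked_func is not None:
        HALF_DTYPE = torch.bfloat16
    else:
        HALF_DTYPE = torch.float16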
@@ -264,9 +262,9 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                 # If FA2 is supported, bfloat16 must be supported
                 # as of FA2 2.4.2. (Turing GPUs not supported)
                 orig_dtype = qkv.dtype
-                qkv = qkv.to(torch.bfloat16)
+                qkv = qkv.to(HALF_DTYPE)
                 bias_dtype = bias.dtype
-                bias = bias.to(torch.bfloat16)
+                bias = bias.to(HALF_DTYPE)

                 attention = flash_attn_qkvpacked_func(
                     qkv, dropout_p=self.p_dropout, alibi_slopes=slopes)
@@ -279,13 +277,12 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
             # Triton implementation only supports 0 attention dropout
             convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
             if convert_dtype:
-                half = _get_half_dtype()
-
                 # Triton implementation only supports fp16 and bf16
                 orig_dtype = qkv.dtype
-                qkv = qkv.to(half)
+                qkv = qkv.to(HALF_DTYPE)
                 bias_dtype = bias.dtype
-                bias = bias.to(half)
+                bias = bias.to(HALF_DTYPE)
                 attention = flash_attn_qkvpacked_func(qkv, bias)
                 attention = attention.to(orig_dtype)
                 bias = bias.to(bias_dtype)
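
Note: both call sites now share the same cast-and-restore pattern around the fused kernel: record the incoming dtypes, cast qkv and bias to HALF_DTYPE, run the kernel, and cast the results back. A runnable sketch of that round trip, with a toy kernel standing in for flash_attn_qkvpacked_func so it works without flash-attn installed:

    import torch

    HALF_DTYPE = torch.bfloat16  # as selected above; torch.float16 without FA2

    def _toy_kernel(qkv: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in for flash_attn_qkvpacked_func; it exists only
        # so this sketch runs anywhere.
        return qkv + bias

    def attn_in_half(qkv: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
        # Mirrors forward() in bert_layers.py: record dtypes, cast to
        # HALF_DTYPE for the kernel, then restore them afterwards.
        orig_dtype = qkv.dtype
        bias_dtype = bias.dtype
        qkv = qkv.to(HALF_DTYPE)
        bias = bias.to(HALF_DTYPE)
        attention = _toy_kernel(qkv, bias)
        bias = bias.to(bias_dtype)  # restored for downstream reuse, as in the diff
        return attention.to(orig_dtype)

    qkv = torch.randn(2, 8, 16)          # float32 on purpose
    bias = torch.zeros(2, 8, 16)
    out = attn_in_half(qkv, bias)
    assert out.dtype == torch.float32    # caller-visible dtype is unchanged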
