Skip to content

Commit

Permalink
FP8 FA fixes (ROCm#381)
Browse files (browse the repository at this point in the history)
* FP8 FA fixes

Summary:
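Add the missing clamp of the descaled FP8 accumulator to the float8_e4m3fnuz range before it is cast to the output type, and fix the reciprocal scale computation by building `full_scales` from the q/k/v/prob scales directly instead of from their reciprocals.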

* linter
  • Loading branch information
ilia-cher authored Jan 23, 2025
1 parent b5839a1 commit a600e9f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
5 changes: 2 additions & 3 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,9 +681,8 @@ def forward(
seq_lens,
make_attn_mask=False) # type: ignore
full_scales = (
1.0 / layer._q_scale.item(),
1.0 / layer._k_scale.item(), 1.0 /
layer._v_scale.item(), 1.0 / layer._prob_scale.item(),
layer._q_scale.item(), layer._k_scale.item(),
layer._v_scale.item(), layer._prob_scale.item(),
fp8_out_scale.item()) if (
fp8_out_scale and layer._q_scale
and layer._prob_scale
Expand Down
9 changes: 7 additions & 2 deletions vllm/attention/ops/triton_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,8 @@ def get_autotune_configs():

autotune_configs, autotune_keys = get_autotune_configs()

float8_info = torch.finfo(torch.float8_e4m3fnuz)


@triton.autotune(
configs=autotune_configs,
Expand Down Expand Up @@ -451,6 +453,8 @@ def attn_fwd(
BIAS_TYPE: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr,
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max,
):
start_m = tl.program_id(0)
off_h_q = tl.program_id(1)
Expand Down Expand Up @@ -733,6 +737,7 @@ def attn_fwd(
causal_start_idx = seqlen_q - seqlen_k
if USE_FP8:
acc *= o_descale
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
acc = acc.to(Out.type.element_ty)
if IS_CAUSAL: # noqa: SIM102
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
Expand Down Expand Up @@ -832,9 +837,9 @@ def forward(

def check_and_convert(t, scale):
if t.dtype != float8:
finfo = torch.finfo(float8)
descale = 1.0 / scale
ts = (t * descale).clamp(min=finfo.min, max=finfo.max)
ts = (t * descale).clamp(min=float8_info.min,
max=float8_info.max)
return ts.to(float8)
else:
return t
Expand Down

0 comments on commit a600e9f

Please sign in to comment.