diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 45102c655d0bf..5f6e7c7cca94d 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -681,9 +681,8 @@ def forward(
                         seq_lens,
                         make_attn_mask=False)  # type: ignore
                 full_scales = (
-                    1.0 / layer._q_scale.item(),
-                    1.0 / layer._k_scale.item(), 1.0 /
-                    layer._v_scale.item(), 1.0 / layer._prob_scale.item(),
+                    layer._q_scale.item(), layer._k_scale.item(),
+                    layer._v_scale.item(), layer._prob_scale.item(),
                     fp8_out_scale.item()) if (
                         fp8_out_scale and layer._q_scale and layer._prob_scale
diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py
index 828bdc2905957..bc550a85f5a92 100644
--- a/vllm/attention/ops/triton_flash_attention.py
+++ b/vllm/attention/ops/triton_flash_attention.py
@@ -390,6 +390,8 @@ def get_autotune_configs():
 
 autotune_configs, autotune_keys = get_autotune_configs()
 
+float8_info = torch.finfo(torch.float8_e4m3fnuz)
+
 
 @triton.autotune(
     configs=autotune_configs,
@@ -451,6 +453,8 @@ def attn_fwd(
     BIAS_TYPE: tl.constexpr,
     ENABLE_DROPOUT: tl.constexpr,
     RETURN_ENCODED_SOFTMAX: tl.constexpr,
+    FP8_MIN: tl.constexpr = float8_info.min,
+    FP8_MAX: tl.constexpr = float8_info.max,
 ):
     start_m = tl.program_id(0)
     off_h_q = tl.program_id(1)
@@ -733,6 +737,7 @@ def attn_fwd(
     causal_start_idx = seqlen_q - seqlen_k
     if USE_FP8:
         acc *= o_descale
+        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
     acc = acc.to(Out.type.element_ty)
     if IS_CAUSAL:  # noqa: SIM102
         if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
@@ -832,9 +837,9 @@ def forward(
 
         def check_and_convert(t, scale):
            if t.dtype != float8:
-                finfo = torch.finfo(float8)
                 descale = 1.0 / scale
-                ts = (t * descale).clamp(min=finfo.min, max=finfo.max)
+                ts = (t * descale).clamp(min=float8_info.min,
+                                         max=float8_info.max)
                 return ts.to(float8)
             else:
                 return t
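
For reference, a minimal standalone sketch of the pattern this patch consolidates: the backend now passes the raw per-tensor scales, the kernel-side helper inverts them itself (descale = 1.0 / scale), and values are clamped into the representable range of torch.float8_e4m3fnuz before the cast, matching the tl.clamp added to attn_fwd. The helper name quantize_to_fp8 and the sample tensor below are illustrative only (not part of vLLM), and the snippet assumes a PyTorch build (>= 2.1) where finfo and casts to float8_e4m3fnuz are available on the device it runs on.

import torch

FLOAT8 = torch.float8_e4m3fnuz      # fp8 format used by the ROCm Triton kernel
FLOAT8_INFO = torch.finfo(FLOAT8)   # .min / .max bound the representable range


def quantize_to_fp8(t: torch.Tensor, scale: float) -> torch.Tensor:
    # Multiply by the reciprocal of the scale, clamp into the fp8 range,
    # then cast -- the same steps as check_and_convert() in the patch above.
    descale = 1.0 / scale
    ts = (t * descale).clamp(min=FLOAT8_INFO.min, max=FLOAT8_INFO.max)
    return ts.to(FLOAT8)


if __name__ == "__main__":
    x = torch.randn(4, 8) * 500.0    # magnitudes that would exceed fp8 range unclamped
    q = quantize_to_fp8(x, scale=2.0)
    print(q.dtype)                   # torch.float8_e4m3fnuz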