Add fp8 kv cache for ROCm #2856

Merged on Jan 17, 2025 (5 commits). Changes from all commits are shown below.

server/text_generation_server/layers/attention/kv_cache.py (38 additions, 19 deletions)

@@ -52,13 +52,18 @@ def __init__(
         device: torch.device,
     ):
         """Construct the key-value cache for a layer."""
-
-        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
-            ATTENTION != "flashinfer" or SYSTEM != "cuda"
-        ):
-            raise ValueError(
-                "FP8 KV cache is currently only supported for flashinfer on CUDA"
-            )
+        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
+            if not (
+                (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+                or (ATTENTION == "paged" and SYSTEM == "rocm")
+            ):
+                raise ValueError(
+                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on ROCm. "
+                )
+            if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
+                raise ValueError(
+                    "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
+                )
 
         element_size = torch.tensor([], dtype=dtype).element_size()
         if SYSTEM == "ipex" and device.type == "xpu":
@@ -110,21 +115,17 @@ def can_scale(self, kv_scales: KVScales) -> bool:
         """Check if the cache can be scaled by the given scales."""
         if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
             return False
-        elif (
-            self.dtype == torch.float8_e4m3fn
-            and ATTENTION == "flashinfer"
-            and SYSTEM == "cuda"
+        elif self.dtype == torch.float8_e4m3fn and (
+            (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+            or (ATTENTION == "paged" and SYSTEM == "rocm")
         ):
-            log_once(
-                logger.info,
-                "Using FP8 KV cache scales",
-            )
+            log_once(logger.info, "Using FP8 KV cache scales")
             return True
         else:
             # We have scales, but not the correct FP8 cache type, so warn once.
             log_once(
                 logger.info,
-                "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
+                "Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm",
             )
             return False
 
@@ -158,7 +159,7 @@ def store(
         key_cache = self.kv_cache[0]
         value_cache = self.kv_cache[1]
 
-        if self.can_scale(kv_scales):
+        if self.can_scale(kv_scales) and SYSTEM == "cuda":
             if kv_scales.key_scale_cpu != 1.0:
                 key = fp8_quantize(
                     key.float(),
@@ -188,7 +189,15 @@ def store(
             key_cache.view(-1, shape[-2], shape[-1])[slots] = key
             value_cache.view(-1, shape[-2], shape[-1])[slots] = value
         else:
-            paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
+            paged_reshape_and_cache(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slots,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
+            )
 
 
 def paged_reshape_and_cache(
@@ -197,7 +206,10 @@ def paged_reshape_and_cache(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
+    k_scale: float = 1.0,
+    v_scale: float = 1.0,
 ):
+
     if SYSTEM == "cuda":
         try:
             import attention_kernels
@@ -215,8 +227,15 @@ def paged_reshape_and_cache(
             raise ImportError(
                 f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
            )
+
+        kv_cache_dtype = "auto"
+        if key_cache.dtype == torch.float8_e4m3fn:
+            key_cache = key_cache.view(torch.uint8)
+            value_cache = value_cache.view(torch.uint8)
+            kv_cache_dtype = "fp8"
+
         ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0
+            key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
         )
     elif SYSTEM == "ipex":
         import intel_extension_for_pytorch as ipex
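For readers skimming the diff, the new dtype gate in `__init__` can be summarized as a small standalone check. The sketch below is illustrative only: `attention` and `system` are stand-ins for the `ATTENTION` and `SYSTEM` values the server resolves at startup, and the function name is hypothetical rather than part of the codebase.

```python
import torch

# Hypothetical, self-contained restatement of the FP8 KV-cache support rule
# added in this PR. `attention` / `system` stand in for the ATTENTION and
# SYSTEM globals that text-generation-server derives from its environment.
def check_fp8_kv_cache_support(dtype: torch.dtype, attention: str, system: str) -> None:
    if dtype not in {torch.float8_e5m2, torch.float8_e4m3fn}:
        return  # non-FP8 caches are unaffected by the new checks
    if not (
        (attention == "flashinfer" and system == "cuda")
        or (attention == "paged" and system == "rocm")
    ):
        raise ValueError(
            "FP8 KV cache is currently only supported for flashinfer on CUDA "
            "and paged attention on ROCm."
        )
    if system == "rocm" and dtype == torch.float8_e5m2:
        raise ValueError("float8_e5m2 FP8 KV cache is not supported on AMD ROCm")


check_fp8_kv_cache_support(torch.float8_e4m3fn, "paged", "rocm")      # accepted
# check_fp8_kv_cache_support(torch.float8_e5m2, "paged", "rocm")      # raises ValueError
# check_fp8_kv_cache_support(torch.float8_e4m3fn, "paged", "cuda")    # raises ValueError
```

In short: FP8 caches are accepted only for flashinfer on CUDA or paged attention on ROCm, and on ROCm only the float8_e4m3fn variant passes, since float8_e5m2 is rejected there.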
server/text_generation_server/layers/attention/rocm.py (24 additions, 15 deletions)

@@ -133,6 +133,15 @@ def paged_attention(
 
     out = torch.empty_like(query)
 
+    if kv_cache.dtype == torch.float8_e4m3fn:
+        key = kv_cache.key.view(torch.uint8)
+        value = kv_cache.value.view(torch.uint8)
+        kv_cache_dtype = "fp8"
+    else:
+        key = kv_cache.key
+        value = kv_cache.value
+        kv_cache_dtype = "auto"
+
     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
     # V1 to avoid the overhead of reduction. Also, if the number of
@@ -147,18 +156,18 @@ def paged_attention(
         ops.paged_attention_v1(
             out,
             query,
-            kv_cache.key,
-            kv_cache.value,
+            key,
+            value,
             num_kv_heads,
             softmax_scale,
             block_tables,
             input_lengths,
             block_size,
             max_s,
             None,
-            "auto",
-            1.0,
-            1.0,
+            kv_cache_dtype,
+            kv_scales.key_scale_cpu,
+            kv_scales.value_scale_cpu,
         )
     else:
         # Run PagedAttention V2.
@@ -182,18 +191,18 @@ def paged_attention(
                 max_logits,
                 tmp_output,
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                key,
+                value,
                 num_kv_heads,
                 softmax_scale,
                 block_tables,
                 input_lengths,
                 block_size,
                 max_s,
                 None,
-                "auto",
-                1.0,
-                1.0,
+                kv_cache_dtype,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
             )
         else:
             ops.paged_attention_rocm(
@@ -202,18 +211,18 @@ def paged_attention(
                 max_logits,
                 tmp_output,
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                key,
+                value,
                 num_kv_heads,
                 softmax_scale,
                 block_tables,
                 input_lengths,
                 block_size,
                 max_s,
                 None,
-                "auto",
-                1.0,
-                1.0,
+                kv_cache_dtype,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
                 None,
                 _PARTITION_SIZE,
             )
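The call-site changes above all follow the same pattern: when the cache dtype is `float8_e4m3fn`, the key and value tensors are reinterpreted as `uint8` views, `kv_cache_dtype` switches from `"auto"` to `"fp8"`, and the per-tensor scales are forwarded instead of the previous hard-coded `1.0`. Below is a minimal sketch of that reinterpretation with a toy cache shape; the actual dequantization happens inside the `ops.paged_attention_*` kernels, not in Python.

```python
import torch

# Illustrative sketch of how an FP8 KV-cache tensor reaches the ROCm kernels:
# stored as float8_e4m3fn, reinterpreted as a uint8 view, and tagged with
# kv_cache_dtype="fp8" so the kernel dequantizes it using the given scale.
key_cache = torch.randn(4, 16, 64, dtype=torch.float16).to(torch.float8_e4m3fn)

if key_cache.dtype == torch.float8_e4m3fn:
    key_cache_arg = key_cache.view(torch.uint8)  # same memory, no copy
    kv_cache_dtype = "fp8"
else:
    key_cache_arg = key_cache
    kv_cache_dtype = "auto"

# Conceptually, the kernel recovers an approximate full-precision value as
#   real_value ~= float(fp8_value) * key_scale
key_scale = 0.05
approx = key_cache[0, 0].to(torch.float32) * key_scale

print(kv_cache_dtype, key_cache_arg.dtype, approx.dtype)  # fp8 torch.uint8 torch.float32
```

Passing the uint8 view plus the `"fp8"` dtype string keeps the existing vLLM kernel signatures unchanged while still letting the ROCm path consume the scaled FP8 cache.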