Explicit CUDA kernel for decoding
This is another implementation of the forward pass of fused
multi-head attention, callable via the existing
memory_efficient_attention_forward.

It is only applicable when the bias is BlockDiagonalCausalWithOffsetPaddedKeysMask
and when there is a single query per logical batch element.

Its biggest wins are when multiquery is not in use.
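
A minimal usage sketch follows; the shapes and the fmha.decoder.FwOp entry
point mirror the test and benchmark added below, so treat it as illustrative
rather than a spec:

import torch
import xformers.ops.fmha as fmha

B, padding, n_heads, d = 4, 32, 8, 128  # logical batch, kv padding, heads, head dim
k_seqlen = [5, 8, 7, 1]                 # actual number of keys per batch element

q = torch.randn(1, B, n_heads, d, device="cuda", dtype=torch.float16)
k = torch.randn(1, B * padding, n_heads, d, device="cuda", dtype=torch.float16)
v = torch.randn(1, B * padding, n_heads, d, device="cuda", dtype=torch.float16)

attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
    q_seqlen=[1] * B,  # exactly one query per logical batch element
    kv_seqlen=k_seqlen,
    kv_padding=padding,
)

out = fmha.memory_efficient_attention_forward(
    q, k, v, attn_bias, op=fmha.decoder.FwOp
)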

These times are in microseconds (us). For some reason the new op
has more CPU overhead than the existing one; the CUDA-graphed numbers
below largely remove that overhead.

[---------------------- attention -----------------------]
                                  |  cutlassF  |  decoderF
1 threads: -----------------------------------------------
      3batch-2keys-1heads-mq      |     77.5   |    113.3
      3batch-2keys-1heads         |     77.5   |    113.5
      3batch-2keys-8heads-mq      |     77.8   |    121.8
      3batch-2keys-8heads         |     77.6   |    112.9
      3batch-2keys-64heads-mq     |     77.7   |    121.9
      3batch-2keys-64heads        |     77.8   |    113.5
      500batch-32keys-1heads-mq   |     77.9   |    113.5
      500batch-32keys-1heads      |     78.1   |    114.1
      500batch-32keys-8heads-mq   |     78.1   |    122.5
      500batch-32keys-8heads      |    164.4   |    148.5
      500batch-32keys-64heads-mq  |    152.3   |    892.4
      500batch-32keys-64heads     |   1244.5   |   1130.9
      2batch-1000keys-1heads-mq   |     77.5   |    112.2
      2batch-1000keys-1heads      |     77.6   |    111.4
      2batch-1000keys-8heads-mq   |     77.6   |    120.5
      2batch-1000keys-8heads      |     77.5   |    111.9
      2batch-1000keys-64heads-mq  |     77.9   |    120.8
      2batch-1000keys-64heads     |     77.8   |    112.3

[---------------- cuda graphed attention ----------------]
                                  |  cutlassF  |  decoderF
1 threads: -----------------------------------------------
      3batch-2keys-1heads-mq      |     12.3   |     12.2
      3batch-2keys-1heads         |     12.3   |     12.2
      3batch-2keys-8heads-mq      |     12.0   |     12.1
      3batch-2keys-8heads         |     12.3   |     12.1
      3batch-2keys-64heads-mq     |     12.8   |     13.5
      3batch-2keys-64heads        |     15.5   |     13.5
      500batch-32keys-1heads-mq   |     28.2   |     23.3
      500batch-32keys-1heads      |     28.3   |     23.3
      500batch-32keys-8heads-mq   |     44.5   |    120.5
      500batch-32keys-8heads      |    168.4   |    152.2
      500batch-32keys-64heads-mq  |    156.8   |    895.1
      500batch-32keys-64heads     |   1247.3   |   1133.4
      2batch-1000keys-1heads-mq   |     44.6   |     25.9
      2batch-1000keys-1heads      |     44.5   |     26.0
      2batch-1000keys-8heads-mq   |     44.8   |     26.8
      2batch-1000keys-8heads      |     44.8   |     26.8
      2batch-1000keys-64heads-mq  |     46.2   |     49.0
      2batch-1000keys-64heads     |     76.4   |     88.6
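
The graphed rows are obtained by capturing a single forward call in a CUDA
graph and timing graph.replay(), which strips most of the CPU launch overhead.
A condensed sketch of that measurement, mirroring the benchmark script below
(q, k, v, attn_bias as in the sketch above):

import torch
from torch.utils import benchmark
import xformers.ops
import xformers.ops.fmha as fmha

def fwd():
    return xformers.ops.memory_efficient_attention_forward(
        q, k, v, attn_bias, op=fmha.decoder.FwOp
    )

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):  # capture one forward call into the graph
    fwd()

timer = benchmark.Timer(stmt="graph.replay()", globals={"graph": graph})
print(timer.blocked_autorange(min_run_time=0.5))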

ghstack-source-id: 529cd1152a92ce79c383f43bb80cf78543d0d357
Pull Request resolved: fairinternal/xformers#676

__original_commit__ = fairinternal/xformers@293daad
bottler authored and xFormers Bot committed Jul 18, 2023
1 parent c979299 commit ec5e4d3
Showing 11 changed files with 882 additions and 31 deletions.
47 changes: 47 additions & 0 deletions tests/test_mem_eff_attention.py
@@ -26,6 +26,9 @@
sm75_or_better_only = pytest.mark.skipif(
compute_capability < (7, 5), reason="requires sm75+"
)
sm80_or_better_only = pytest.mark.skipif(
compute_capability < (8, 0), reason="requires sm80+"
)
_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]

ALL_FW_OPS: Sequence[Type[fmha.common.AttentionFwOpBase]] = [
@@ -1621,6 +1624,50 @@ def test_attn_bias_padded() -> None:
)


@sm80_or_better_only
@pytest.mark.parametrize("multiquery", [True, False], ids=lambda x: "mq" if x else "")
@pytest.mark.parametrize("n_heads", [1, 32])
@pytest.mark.parametrize("bsz", [1, 8])
@pytest.mark.parametrize("dtype", ["f16", "bf16", "f32"])
def test_decoder(multiquery: bool, n_heads: int, bsz: int, dtype: str) -> None:
dtype_ = {"f16": torch.float16, "bf16": torch.bfloat16, "f32": torch.float32}[dtype]
torch.manual_seed(1)
d, padding = 128, 32
k_shape = (1, bsz * padding, n_heads, d)
# TODO: support 2 kv heads etc.
k = torch.randn(k_shape, dtype=dtype_).cuda()
k_seqlen = [5, 8, 7, 1, 9, 3, 12, 32][:bsz]
v = torch.randn(k_shape, dtype=dtype_).cuda()
q = torch.randn((1, bsz, n_heads, d), dtype=dtype_).cuda()
causal_diagonal = torch.tensor( # TODO: make unnecessary
[i - 1 for i in k_seqlen], dtype=torch.int32
).cuda()

if multiquery:
k = k[:, :, :1].expand(k_shape)
v = v[:, :, :1].expand(k_shape)

attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
q_seqlen=[1] * bsz,
kv_seqlen=k_seqlen,
causal_diagonal=causal_diagonal,
kv_padding=padding,
)

cutlass_output = fmha.memory_efficient_attention_forward(
q, k, v, attn_bias, op=fmha.cutlass.FwOp
)
decoder_output = fmha.memory_efficient_attention_forward(
q, k, v, attn_bias, op=fmha.decoder.FwOp
)
assert_allclose(
decoder_output,
cutlass_output,
atol=fmha.cutlass.FwOp.ERROR_ATOL[torch.float16] * 4,
rtol=fmha.cutlass.FwOp.ERROR_RTOL[torch.float16],
)


def test_attn_bias_from_seqlens() -> None:
bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens([3, 5, 1])
out = bias.split(torch.randn([1, 3 + 5 + 1, 16]))
179 changes: 179 additions & 0 deletions xformers/benchmarks/benchmark_mem_eff_attn_decoder.py
@@ -0,0 +1,179 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import itertools
from functools import partial

import torch
from torch.utils import benchmark
from utils import benchmark_main_helper

import xformers.ops
import xformers.ops.fmha as fmha

torch.backends.cuda.matmul.allow_tf32 = False

# Run with
# python xformers/benchmarks/benchmark_mem_eff_attn_decoder.py --omit-baselines --quiet
# The baselines for these benchmarks are really slow because there is
# so much padding in the inputs, so there is no point running them.


def ref_attention_bmk(q, k, v, attn_bias=None):
if isinstance(attn_bias, xformers.ops.AttentionMask):
attn_bias = (
attn_bias.materialize((q.shape[0], 1, q.shape[1], k.shape[1]))
.to(q)
.squeeze()
)
q = q * (1.0 / q.shape[-1] ** 0.5)
if attn_bias is None:
attn = q @ k.transpose(-2, -1)
else:
# equivalent to (q @ k.transpose(-2, -1) + m).softmax(-1) @ v
# but faster, and is what is used in PyTorch now
attn = torch.baddbmm(attn_bias, q, k.transpose(-2, -1))
attn = attn.softmax(-1)
return attn @ v


def ref_attention(q, k, v, attn_bias):
assert q.ndim == 4

def T(t):
return t.permute((0, 2, 1, 3)).reshape(
[t.shape[0] * t.shape[2], t.shape[1], t.shape[3]]
)

out = ref_attention_bmk(T(q), T(k), T(v), attn_bias)
out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]])
return out.permute((0, 2, 1, 3))


min_run_time = 0.5
device = torch.device("cuda")

NUM_THREADS = [1] if device.type == "cuda" else [1, 40]

OPS = [
xformers.ops.fmha.cutlass.FwOp,
xformers.ops.fmha.decoder.FwOp,
]

KV_SHAPES = [
# list of n_keys, padding_length, batchsize
(2, 64, 3),
(32, 1024, 500),
(1000, 1024, 2),
(8000, 8192, 1),
]

N_HEADS = [8, 64]


def product_dict(**kwargs):
keys = kwargs.keys()
vals = kwargs.values()
for instance in itertools.product(*vals):
yield dict(zip(keys, instance))


CASES = list(
product_dict(
kv_shape=KV_SHAPES,
n_heads=N_HEADS,
num_threads=NUM_THREADS,
multiquery=[True, False],
)
)


def mem_eff_attention_decoder(
kv_shape, n_heads: int, num_threads: int, multiquery: bool
):
n_keys, padding, B = kv_shape
k_seqlen = [n_keys] * B
K = 128

q = torch.rand(1, B, n_heads, K, device=device, dtype=torch.bfloat16)
if multiquery:
k = torch.rand(
1, B * padding, 1, K, device=device, dtype=torch.bfloat16
).expand(1, B * padding, n_heads, K)
v = torch.rand(
1, B * padding, 1, K, device=device, dtype=torch.bfloat16
).expand(1, B * padding, n_heads, K)
else:
k = torch.rand(1, B * padding, n_heads, K, device=device, dtype=torch.bfloat16)
v = torch.rand(1, B * padding, n_heads, K, device=device, dtype=torch.bfloat16)

bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
q_seqlen=[1] * B,
kv_seqlen=k_seqlen,
kv_padding=padding,
)

sub_label = f"{B}batch-{k_seqlen[0]}keys-{n_heads}heads"
if multiquery:
sub_label += "-mq"

has_run = False
for fw_op in OPS:
fn = partial(xformers.ops.memory_efficient_attention_forward, op=fw_op)

yield benchmark.Timer(
stmt="fn(q, k, v, attn_bias)",
globals={
"q": q,
"k": k,
"v": v,
"attn_bias": bias,
"fn": fn,
},
label="attention",
description=fw_op.NAME,
sub_label=sub_label,
num_threads=num_threads,
)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
fn(q, k, v, bias)
yield benchmark.Timer(
stmt="graph.replay()",
globals={
"graph": graph,
},
label="cuda graphed attention",
description=fw_op.NAME,
sub_label=sub_label,
num_threads=num_threads,
)

has_run = True

if not has_run:
return

RUN_BASELINES = False
if RUN_BASELINES:
yield benchmark.Timer(
stmt="fn(q, k, v, attn_bias)",
globals={
"q": q,
"k": k,
"v": v,
"attn_bias": bias,
"fn": ref_attention,
},
label="attention",
description="eager",
sub_label=sub_label,
num_threads=num_threads,
)


benchmark_main_helper(mem_eff_attention_decoder, CASES, min_run_time=min_run_time)
31 changes: 18 additions & 13 deletions xformers/benchmarks/utils.py
@@ -68,7 +68,8 @@ def pretty_plot(
results, title, units: str, filename=None, dash_key="", legend_loc="lower right"
):
"""Graph out the contents of a dict.
Dash key means that if the result label has this key, then it will be displayed with a dash"""
Dash key means that if the result label has this key, then it will be displayed with a dash
"""

if not filename:
filename = title + ".png"
@@ -139,7 +140,8 @@ def bench_functions(

def pretty_barplot(results, title, units: str, filename=None, dash_key=""):
"""Graph out the contents of a dict.
Dash key means that if the result label has this key, then it will be displayed with a dash"""
Dash key means that if the result label has this key, then it will be displayed with a dash
"""

if not filename:
filename = title + ".png"
@@ -295,7 +297,7 @@ def _finalize_results(results: List[Tuple[Dict[str, Any], Any]]) -> List[Any]:
"""
all_algorithms: Set[str] = set()
all_description: Set[str] = set()
for (metadata, r) in results:
for metadata, r in results:
algo = metadata.get(META_ALGORITHM, None)
if algo is not None:
all_algorithms.add(algo)
@@ -304,7 +306,7 @@ def _finalize_results(results: List[Tuple[Dict[str, Any], Any]]) -> List[Any]:
display_descr = len(all_description) > 1

display_results = []
for (metadata, r) in results:
for metadata, r in results:
algo = metadata.get(META_ALGORITHM, None)
if algo is None:
display_results.append(r)
@@ -343,14 +345,13 @@ def _render_bar_plot(results: List[Any], store_results_folder: str) -> None:
all_data_run: List[Any] = []
for key, runtime_values in runtime.items():
memory_values = memory_usage[key]
all_data_mem.append(
[key]
+ [
memory_values.get(d, 0)
/ memory_values.get(all_descriptions[0], math.inf)
for d in all_descriptions
]
)
denom = memory_values.get(all_descriptions[0], math.inf)
if denom == 0:
all_data_mem.append([key] + [0] * len(all_descriptions))
else:
all_data_mem.append(
[key] + [memory_values.get(d, 0) / denom for d in all_descriptions]
)
all_data_run.append(
[key]
+ [
@@ -409,7 +410,11 @@ def benchmark_main_helper(benchmark_fn, cases: List[Dict[str, Any]], **kwargs) -
type=str,
help="Compare to previously stored benchmarks (coma separated)",
)
parser.add_argument("--omit-baselines", action="store_true")
parser.add_argument(
"--omit-baselines",
action="store_true",
help="Do not run the (potentially slow) baselines",
)
parser.add_argument(
"--quiet",
action="store_true",
2 changes: 1 addition & 1 deletion xformers/components/attention/ortho.py
@@ -321,4 +321,4 @@ def _compute_orthogonal_landmarks(self, q: torch.Tensor) -> torch.Tensor:
landmarks = torch.masked_select(q_samples, selected_mask.bool()).reshape(
B, -1, D
)
return landmarks #  (B, M, D)
return landmarks # (B, M, D)
2 changes: 2 additions & 0 deletions xformers/csrc/attention/attention.cpp
@@ -25,6 +25,8 @@ TORCH_LIBRARY_FRAGMENT(xformers, m) {
"xformers::efficient_attention_forward_small_k(Tensor query, Tensor key, Tensor value, bool compute_logsumexp, Tensor? attn_bias, float p) -> (Tensor, Tensor, int, int)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"xformers::efficient_attention_forward_cutlass(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? seqstart_q, Tensor? seqstart_k, int? max_seqlen_q, float dropout_p, bool compute_logsumexp, int custom_mask_type, float? scale, Tensor? seqlen_k) -> (Tensor, Tensor, int, int)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"xformers::efficient_attention_forward_decoder(Tensor query, Tensor key, Tensor value, Tensor seq_positions, float scale) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA(
"xformers::efficient_attention_backward_small_k(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor output, Tensor? attn_bias, float p, int rng_seed, int rng_offset) -> (Tensor, Tensor, Tensor)"));
m.def(TORCH_SELECTIVE_SCHEMA(