This commit adds another implementation of the forward pass of fused multi-head attention, callable via the existing memory_efficient_attention_forward. It is only applicable when the bias is BlockDiagonalCausalWithOffsetPaddedKeysMask and there is a single query per logical batch element. Its biggest wins come when multiquery is not in use. The times below are in microseconds (us). For reasons not yet understood, the new op has more CPU overhead than the existing one.

[---------------------- attention -----------------------]
                                  |  cutlassF  |  decoderF
1 threads: -----------------------------------------------
      3batch-2keys-1heads-mq      |    77.5    |    113.3
      3batch-2keys-1heads         |    77.5    |    113.5
      3batch-2keys-8heads-mq      |    77.8    |    121.8
      3batch-2keys-8heads         |    77.6    |    112.9
      3batch-2keys-64heads-mq     |    77.7    |    121.9
      3batch-2keys-64heads        |    77.8    |    113.5
      500batch-32keys-1heads-mq   |    77.9    |    113.5
      500batch-32keys-1heads      |    78.1    |    114.1
      500batch-32keys-8heads-mq   |    78.1    |    122.5
      500batch-32keys-8heads      |   164.4    |    148.5
      500batch-32keys-64heads-mq  |   152.3    |    892.4
      500batch-32keys-64heads     |  1244.5    |   1130.9
      2batch-1000keys-1heads-mq   |    77.5    |    112.2
      2batch-1000keys-1heads      |    77.6    |    111.4
      2batch-1000keys-8heads-mq   |    77.6    |    120.5
      2batch-1000keys-8heads      |    77.5    |    111.9
      2batch-1000keys-64heads-mq  |    77.9    |    120.8
      2batch-1000keys-64heads     |    77.8    |    112.3

[---------------- cuda graphed attention ----------------]
                                  |  cutlassF  |  decoderF
1 threads: -----------------------------------------------
      3batch-2keys-1heads-mq      |    12.3    |     12.2
      3batch-2keys-1heads         |    12.3    |     12.2
      3batch-2keys-8heads-mq      |    12.0    |     12.1
      3batch-2keys-8heads         |    12.3    |     12.1
      3batch-2keys-64heads-mq     |    12.8    |     13.5
      3batch-2keys-64heads        |    15.5    |     13.5
      500batch-32keys-1heads-mq   |    28.2    |     23.3
      500batch-32keys-1heads      |    28.3    |     23.3
      500batch-32keys-8heads-mq   |    44.5    |    120.5
      500batch-32keys-8heads      |   168.4    |    152.2
      500batch-32keys-64heads-mq  |   156.8    |    895.1
      500batch-32keys-64heads     |  1247.3    |   1133.4
      2batch-1000keys-1heads-mq   |    44.6    |     25.9
      2batch-1000keys-1heads      |    44.5    |     26.0
      2batch-1000keys-8heads-mq   |    44.8    |     26.8
      2batch-1000keys-8heads      |    44.8    |     26.8
      2batch-1000keys-64heads-mq  |    46.2    |     49.0
      2batch-1000keys-64heads     |    76.4    |     88.6

ghstack-source-id: 529cd1152a92ce79c383f43bb80cf78543d0d357
Pull Request resolved: fairinternal/xformers#676
__original_commit__ = fairinternal/xformers@293daad
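For reference, here is a minimal sketch of how the new path is reached, distilled from the benchmark added in this commit (the batch size, padding, and per-element key counts are illustrative):

import torch
import xformers.ops.fmha as fmha
from xformers.ops import memory_efficient_attention_forward

B, padding, n_heads, K = 3, 64, 8, 128  # illustrative sizes
kv_seqlen = [2, 5, 9]  # actual number of keys per logical batch element

# One query per logical batch element, packed along the sequence axis;
# each element's keys/values occupy a block padded to `padding`.
q = torch.rand(1, B, n_heads, K, device="cuda", dtype=torch.bfloat16)
k = torch.rand(1, B * padding, n_heads, K, device="cuda", dtype=torch.bfloat16)
v = torch.rand(1, B * padding, n_heads, K, device="cuda", dtype=torch.bfloat16)

attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
    q_seqlen=[1] * B,
    kv_seqlen=kv_seqlen,
    kv_padding=padding,
)

out = memory_efficient_attention_forward(q, k, v, attn_bias, op=fmha.decoder.FwOp)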
Showing 11 changed files with 882 additions and 31 deletions.
xformers/benchmarks/benchmark_mem_eff_attn_decoder.py
@@ -0,0 +1,179 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import itertools
from functools import partial

import torch
from torch.utils import benchmark
from utils import benchmark_main_helper

import xformers.ops
import xformers.ops.fmha as fmha

torch.backends.cuda.matmul.allow_tf32 = False

# Run with
# python xformers/benchmarks/benchmark_mem_eff_attn_decoder.py --omit-baselines --quiet
# The baselines for these benchmarks are really slow because there is
# so much padding in the inputs, so there is no point running them.

def ref_attention_bmk(q, k, v, attn_bias=None):
    if isinstance(attn_bias, xformers.ops.AttentionMask):
        attn_bias = (
            attn_bias.materialize((q.shape[0], 1, q.shape[1], k.shape[1]))
            .to(q)
            .squeeze()
        )
    q = q * (1.0 / q.shape[-1] ** 0.5)
    if attn_bias is None:
        attn = q @ k.transpose(-2, -1)
    else:
        # equivalent to (q @ k.transpose(-2, -1) + attn_bias).softmax(-1) @ v
        # but faster, and is what is used in PyTorch now
        attn = torch.baddbmm(attn_bias, q, k.transpose(-2, -1))
    attn = attn.softmax(-1)
    return attn @ v


def ref_attention(q, k, v, attn_bias):
    assert q.ndim == 4

    def T(t):
        # (B, M, H, K) -> (B * H, M, K): fold heads into the batch dimension
        return t.permute((0, 2, 1, 3)).reshape(
            [t.shape[0] * t.shape[2], t.shape[1], t.shape[3]]
        )

    out = ref_attention_bmk(T(q), T(k), T(v), attn_bias)
    out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]])
    return out.permute((0, 2, 1, 3))

min_run_time = 0.5
device = torch.device("cuda")

NUM_THREADS = [1] if device.type == "cuda" else [1, 40]

OPS = [
    xformers.ops.fmha.cutlass.FwOp,
    xformers.ops.fmha.decoder.FwOp,
]

KV_SHAPES = [
    # list of (n_keys, padding_length, batchsize)
    (2, 64, 3),
    (32, 1024, 500),
    (1000, 1024, 2),
    (8000, 8192, 1),
]

N_HEADS = [8, 64]


def product_dict(**kwargs):
    keys = kwargs.keys()
    vals = kwargs.values()
    for instance in itertools.product(*vals):
        yield dict(zip(keys, instance))

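# Example: product_dict(a=[1, 2], b=["x"]) yields {"a": 1, "b": "x"} and then
# {"a": 2, "b": "x"}, so CASES below holds one configuration dict per
# (kv_shape, n_heads, num_threads, multiquery) combination.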
CASES = list(
    product_dict(
        kv_shape=KV_SHAPES,
        n_heads=N_HEADS,
        num_threads=NUM_THREADS,
        multiquery=[True, False],
    )
)

def mem_eff_attention_decoder(
    kv_shape, n_heads: int, num_threads: int, multiquery: bool
):
    n_keys, padding, B = kv_shape
    k_seqlen = [n_keys] * B
    K = 128

    # One query per logical batch element, packed along the sequence axis.
    q = torch.rand(1, B, n_heads, K, device=device, dtype=torch.bfloat16)
    if multiquery:
        # Multiquery: a single K/V head expanded (stride 0) across all heads.
        k = torch.rand(
            1, B * padding, 1, K, device=device, dtype=torch.bfloat16
        ).expand(1, B * padding, n_heads, K)
        v = torch.rand(
            1, B * padding, 1, K, device=device, dtype=torch.bfloat16
        ).expand(1, B * padding, n_heads, K)
    else:
        k = torch.rand(1, B * padding, n_heads, K, device=device, dtype=torch.bfloat16)
        v = torch.rand(1, B * padding, n_heads, K, device=device, dtype=torch.bfloat16)

    # Each batch element attends causally to its own n_keys keys inside a
    # block padded to `padding`.
    bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
        q_seqlen=[1] * B,
        kv_seqlen=k_seqlen,
        kv_padding=padding,
    )

    sub_label = f"{B}batch-{k_seqlen[0]}keys-{n_heads}heads"
    if multiquery:
        sub_label += "-mq"

||
has_run = False | ||
for fw_op in OPS: | ||
fn = partial(xformers.ops.memory_efficient_attention_forward, op=fw_op) | ||
|
||
yield benchmark.Timer( | ||
stmt="fn(q, k, v, attn_bias)", | ||
globals={ | ||
"q": q, | ||
"k": k, | ||
"v": v, | ||
"attn_bias": bias, | ||
"fn": fn, | ||
}, | ||
label="attention", | ||
description=fw_op.NAME, | ||
sub_label=sub_label, | ||
num_threads=num_threads, | ||
) | ||
|
||
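        # Replaying a captured CUDA graph re-launches the recorded kernels
        # without the Python/dispatcher work, so the graphed timings isolate
        # GPU time from the CPU overhead noted in the commit message.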
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            fn(q, k, v, bias)
        yield benchmark.Timer(
            stmt="graph.replay()",
            globals={
                "graph": graph,
            },
            label="cuda graphed attention",
            description=fw_op.NAME,
            sub_label=sub_label,
            num_threads=num_threads,
        )

        has_run = True

    if not has_run:
        return

    RUN_BASELINES = False
    if RUN_BASELINES:
        yield benchmark.Timer(
            stmt="fn(q, k, v, attn_bias)",
            globals={
                "q": q,
                "k": k,
                "v": v,
                "attn_bias": bias,
                "fn": ref_attention,
            },
            label="attention",
            description="eager",
            sub_label=sub_label,
            num_threads=num_threads,
        )


benchmark_main_helper(mem_eff_attention_decoder, CASES, min_run_time=min_run_time)