diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py index 307ae849d..664a184b2 100644 --- a/tests/test_mem_eff_attention.py +++ b/tests/test_mem_eff_attention.py @@ -26,6 +26,9 @@ sm75_or_better_only = pytest.mark.skipif( compute_capability < (7, 5), reason="requires sm75+" ) +sm80_or_better_only = pytest.mark.skipif( + compute_capability < (8, 0), reason="requires sm80+" +) _devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] ALL_FW_OPS: Sequence[Type[fmha.common.AttentionFwOpBase]] = [ @@ -1621,6 +1624,50 @@ def test_attn_bias_padded() -> None: ) +@sm80_or_better_only +@pytest.mark.parametrize("multiquery", [True, False], ids=lambda x: "mq" if x else "") +@pytest.mark.parametrize("n_heads", [1, 32]) +@pytest.mark.parametrize("bsz", [1, 8]) +@pytest.mark.parametrize("dtype", ["f16", "bf16", "f32"]) +def test_decoder(multiquery: bool, n_heads: int, bsz: int, dtype: str) -> None: + dtype_ = {"f16": torch.float16, "bf16": torch.bfloat16, "f32": torch.float32}[dtype] + torch.manual_seed(1) + d, padding = 128, 32 + k_shape = (1, bsz * padding, n_heads, d) + # TODO: support 2 kv heads etc. + k = torch.randn(k_shape, dtype=dtype_).cuda() + k_seqlen = [5, 8, 7, 1, 9, 3, 12, 32][:bsz] + v = torch.randn(k_shape, dtype=dtype_).cuda() + q = torch.randn((1, bsz, n_heads, d), dtype=dtype_).cuda() + causal_diagonal = torch.tensor( # TODO: make unnecessary + [i - 1 for i in k_seqlen], dtype=torch.int32 + ).cuda() + + if multiquery: + k = k[:, :, :1].expand(k_shape) + v = v[:, :, :1].expand(k_shape) + + attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( + q_seqlen=[1] * bsz, + kv_seqlen=k_seqlen, + causal_diagonal=causal_diagonal, + kv_padding=padding, + ) + + cutlass_output = fmha.memory_efficient_attention_forward( + q, k, v, attn_bias, op=fmha.cutlass.FwOp + ) + decoder_output = fmha.memory_efficient_attention_forward( + q, k, v, attn_bias, op=fmha.decoder.FwOp + ) + assert_allclose( + decoder_output, + cutlass_output, + atol=fmha.cutlass.FwOp.ERROR_ATOL[torch.float16] * 4, + rtol=fmha.cutlass.FwOp.ERROR_RTOL[torch.float16], + ) + + def test_attn_bias_from_seqlens() -> None: bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens([3, 5, 1]) out = bias.split(torch.randn([1, 3 + 5 + 1, 16])) diff --git a/xformers/benchmarks/benchmark_mem_eff_attn_decoder.py b/xformers/benchmarks/benchmark_mem_eff_attn_decoder.py new file mode 100644 index 000000000..6e4e351fc --- /dev/null +++ b/xformers/benchmarks/benchmark_mem_eff_attn_decoder.py @@ -0,0 +1,179 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import itertools +from functools import partial + +import torch +from torch.utils import benchmark +from utils import benchmark_main_helper + +import xformers.ops +import xformers.ops.fmha as fmha + +torch.backends.cuda.matmul.allow_tf32 = False + +# Run with +# python xformers/benchmarks/benchmark_mem_eff_attn_decoder.py --omit-baselines --quiet +# The baselines for these benchmarks are really slow because there is +# so much padding in the inputs, so there is no point running them. 
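Both the new test above and the benchmark defined below drive the decoder path through `fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask`, which packs every sequence's key/value cache into a single padded buffer of shape `(1, batch * padding, heads, head_dim)`. A minimal, hedged sketch of what that mask allows, assuming the class exposes the same `materialize()` helper as the other block-diagonal biases:

```python
# Sketch only: materialize the padded-keys causal mask for a tiny case to see
# which cache slots each single-token query may attend to.
import torch
from xformers.ops import fmha

bsz, padding = 2, 4
k_seqlen = [2, 3]  # valid keys per batch element; remaining slots are padding

bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
    q_seqlen=[1] * bsz,  # decoding: one new query token per sequence
    kv_seqlen=k_seqlen,
    kv_padding=padding,
)
# Materialized shape is (total_queries, total_padded_keys) = (2, 8).
print(bias.materialize((bsz, bsz * padding)))
# Row i is 0 over the first k_seqlen[i] slots of block i and -inf everywhere
# else, so padded slots and other batch elements are never attended to.
```

Each query row therefore touches only the valid prefix of its own padded block, which is what lets the decoder kernel skip the padding entirely.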
+ + +def ref_attention_bmk(q, k, v, attn_bias=None): + if isinstance(attn_bias, xformers.ops.AttentionMask): + attn_bias = ( + attn_bias.materialize((q.shape[0], 1, q.shape[1], k.shape[1])) + .to(q) + .squeeze() + ) + q = q * (1.0 / q.shape[-1] ** 0.5) + if attn_bias is None: + attn = q @ k.transpose(-2, -1) + else: + # equivalent to (q @ k.transpose(-2, -1) + m).softmax(-1) @ v + # but faster, and is what is used in PyTorch now + attn = torch.baddbmm(attn_bias, q, k.transpose(-2, -1)) + attn = attn.softmax(-1) + return attn @ v + + +def ref_attention(q, k, v, attn_bias): + assert q.ndim == 4 + + def T(t): + return t.permute((0, 2, 1, 3)).reshape( + [t.shape[0] * t.shape[2], t.shape[1], t.shape[3]] + ) + + out = ref_attention_bmk(T(q), T(k), T(v), attn_bias) + out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]]) + return out.permute((0, 2, 1, 3)) + + +min_run_time = 0.5 +device = torch.device("cuda") + +NUM_THREADS = [1] if device.type == "cuda" else [1, 40] + +OPS = [ + xformers.ops.fmha.cutlass.FwOp, + xformers.ops.fmha.decoder.FwOp, +] + +KV_SHAPES = [ + # list of n_keys, padding_length, batchsize + (2, 64, 3), + (32, 1024, 500), + (1000, 1024, 2), + (8000, 8192, 1), +] + +N_HEADS = [8, 64] + + +def product_dict(**kwargs): + keys = kwargs.keys() + vals = kwargs.values() + for instance in itertools.product(*vals): + yield dict(zip(keys, instance)) + + +CASES = list( + product_dict( + kv_shape=KV_SHAPES, + n_heads=N_HEADS, + num_threads=NUM_THREADS, + multiquery=[True, False], + ) +) + + +def mem_eff_attention_decoder( + kv_shape, n_heads: int, num_threads: int, multiquery: bool +): + n_keys, padding, B = kv_shape + k_seqlen = [n_keys] * B + K = 128 + + q = torch.rand(1, B, n_heads, K, device=device, dtype=torch.bfloat16) + if multiquery: + k = torch.rand( + 1, B * padding, 1, K, device=device, dtype=torch.bfloat16 + ).expand(1, B * padding, n_heads, K) + v = torch.rand( + 1, B * padding, 1, K, device=device, dtype=torch.bfloat16 + ).expand(1, B * padding, n_heads, K) + else: + k = torch.rand(1, B * padding, n_heads, K, device=device, dtype=torch.bfloat16) + v = torch.rand(1, B * padding, n_heads, K, device=device, dtype=torch.bfloat16) + + bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( + q_seqlen=[1] * B, + kv_seqlen=k_seqlen, + kv_padding=padding, + ) + + sub_label = f"{B}batch-{k_seqlen[0]}keys-{n_heads}heads" + if multiquery: + sub_label += "-mq" + + has_run = False + for fw_op in OPS: + fn = partial(xformers.ops.memory_efficient_attention_forward, op=fw_op) + + yield benchmark.Timer( + stmt="fn(q, k, v, attn_bias)", + globals={ + "q": q, + "k": k, + "v": v, + "attn_bias": bias, + "fn": fn, + }, + label="attention", + description=fw_op.NAME, + sub_label=sub_label, + num_threads=num_threads, + ) + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + fn(q, k, v, bias) + yield benchmark.Timer( + stmt="graph.replay()", + globals={ + "graph": graph, + }, + label="cuda graphed attention", + description=fw_op.NAME, + sub_label=sub_label, + num_threads=num_threads, + ) + + has_run = True + + if not has_run: + return + + RUN_BASELINES = False + if RUN_BASELINES: + yield benchmark.Timer( + stmt="fn(q, k, v, attn_bias)", + globals={ + "q": q, + "k": k, + "v": v, + "attn_bias": bias, + "fn": ref_attention, + }, + label="attention", + description="eager", + sub_label=sub_label, + num_threads=num_threads, + ) + + +benchmark_main_helper(mem_eff_attention_decoder, CASES, min_run_time=min_run_time) diff --git a/xformers/benchmarks/utils.py 
b/xformers/benchmarks/utils.py index 1231ee7d8..a3d10d63d 100644 --- a/xformers/benchmarks/utils.py +++ b/xformers/benchmarks/utils.py @@ -68,7 +68,8 @@ def pretty_plot( results, title, units: str, filename=None, dash_key="", legend_loc="lower right" ): """Graph out the contents of a dict. - Dash key means that if the result label has this key, then it will be displayed with a dash""" + Dash key means that if the result label has this key, then it will be displayed with a dash + """ if not filename: filename = title + ".png" @@ -139,7 +140,8 @@ def bench_functions( def pretty_barplot(results, title, units: str, filename=None, dash_key=""): """Graph out the contents of a dict. - Dash key means that if the result label has this key, then it will be displayed with a dash""" + Dash key means that if the result label has this key, then it will be displayed with a dash + """ if not filename: filename = title + ".png" @@ -295,7 +297,7 @@ def _finalize_results(results: List[Tuple[Dict[str, Any], Any]]) -> List[Any]: """ all_algorithms: Set[str] = set() all_description: Set[str] = set() - for (metadata, r) in results: + for metadata, r in results: algo = metadata.get(META_ALGORITHM, None) if algo is not None: all_algorithms.add(algo) @@ -304,7 +306,7 @@ def _finalize_results(results: List[Tuple[Dict[str, Any], Any]]) -> List[Any]: display_descr = len(all_description) > 1 display_results = [] - for (metadata, r) in results: + for metadata, r in results: algo = metadata.get(META_ALGORITHM, None) if algo is None: display_results.append(r) @@ -343,14 +345,13 @@ def _render_bar_plot(results: List[Any], store_results_folder: str) -> None: all_data_run: List[Any] = [] for key, runtime_values in runtime.items(): memory_values = memory_usage[key] - all_data_mem.append( - [key] - + [ - memory_values.get(d, 0) - / memory_values.get(all_descriptions[0], math.inf) - for d in all_descriptions - ] - ) + denom = memory_values.get(all_descriptions[0], math.inf) + if denom == 0: + all_data_mem.append([key] + [0] * len(all_descriptions)) + else: + all_data_mem.append( + [key] + [memory_values.get(d, 0) / denom for d in all_descriptions] + ) all_data_run.append( [key] + [ @@ -409,7 +410,11 @@ def benchmark_main_helper(benchmark_fn, cases: List[Dict[str, Any]], **kwargs) - type=str, help="Compare to previously stored benchmarks (coma separated)", ) - parser.add_argument("--omit-baselines", action="store_true") + parser.add_argument( + "--omit-baselines", + action="store_true", + help="Do not run the (potentially slow) baselines", + ) parser.add_argument( "--quiet", action="store_true", diff --git a/xformers/components/attention/ortho.py b/xformers/components/attention/ortho.py index 3737f6cdd..3d6de43a3 100644 --- a/xformers/components/attention/ortho.py +++ b/xformers/components/attention/ortho.py @@ -321,4 +321,4 @@ def _compute_orthogonal_landmarks(self, q: torch.Tensor) -> torch.Tensor: landmarks = torch.masked_select(q_samples, selected_mask.bool()).reshape( B, -1, D ) - return landmarks #  (B, M, D) + return landmarks # (B, M, D) diff --git a/xformers/csrc/attention/attention.cpp b/xformers/csrc/attention/attention.cpp index f51c8f00e..237a12ea5 100644 --- a/xformers/csrc/attention/attention.cpp +++ b/xformers/csrc/attention/attention.cpp @@ -25,6 +25,8 @@ TORCH_LIBRARY_FRAGMENT(xformers, m) { "xformers::efficient_attention_forward_small_k(Tensor query, Tensor key, Tensor value, bool compute_logsumexp, Tensor? 
attn_bias, float p) -> (Tensor, Tensor, int, int)")); m.def(TORCH_SELECTIVE_SCHEMA( "xformers::efficient_attention_forward_cutlass(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? seqstart_q, Tensor? seqstart_k, int? max_seqlen_q, float dropout_p, bool compute_logsumexp, int custom_mask_type, float? scale, Tensor? seqlen_k) -> (Tensor, Tensor, int, int)")); + m.def(TORCH_SELECTIVE_SCHEMA( + "xformers::efficient_attention_forward_decoder(Tensor query, Tensor key, Tensor value, Tensor seq_positions, float scale) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA( "xformers::efficient_attention_backward_small_k(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor output, Tensor? attn_bias, float p, int rng_seed, int rng_offset) -> (Tensor, Tensor, Tensor)")); m.def(TORCH_SELECTIVE_SCHEMA( diff --git a/xformers/csrc/attention/cuda/fmha/decoder.cu b/xformers/csrc/attention/cuda/fmha/decoder.cu new file mode 100644 index 000000000..b13257618 --- /dev/null +++ b/xformers/csrc/attention/cuda/fmha/decoder.cu @@ -0,0 +1,499 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace { + +// Each block handles a single batch and head + +// Each warp handles separate D dimension. + +// Load Q into registers in all warps. +// Split T across warps in a block +// Compute S[T_MAX] = for i in range(T): S[t] = sum(Q[d] * K[t, d]) +// Use shared reduction to compute max and compute softmax on shared memory. + +// Split T across warps in a block + +// each warp compute sum(t_subset) P[t] * V[t_subset, d] +// outputs are of size float[D] + +constexpr int32_t kThreadsPerWarp = 32; +constexpr int32_t kWarpsPerBlock = 32; +constexpr int32_t D_H = 128; +constexpr int32_t T_MAX = 8192; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +struct __align__(16) fx4 { + float x; + float y; + float z; + float w; + __host__ __device__ fx4() { + x = 0; + y = 0; + z = 0; + w = 0; + } +}; + +template +struct scalar4; + +template <> // bfx4 +struct __align__(8) scalar4 { + __nv_bfloat162 vals[2]; + using whole_int_t = uint2; +}; +template <> +struct __align__(8) scalar4 { + __half2 vals[2]; + using whole_int_t = uint2; +}; +template <> +struct scalar4 { + fx4 v; + using whole_int_t = uint4; +}; + +// bfx4_dot +__device__ __forceinline__ float scalar4_dot( + scalar4 a, + scalar4 b) { + // float2 acc = {0, 0}; + // __nv_bfloat162 acc; + // acc.x = static_cast(0); + // acc.y = static_cast(0); + // TODO: need to be performed in float32? 
+ auto a0 = __bfloat1622float2(a.vals[0]); + auto a1 = __bfloat1622float2(a.vals[1]); + auto b0 = __bfloat1622float2(b.vals[0]); + auto b1 = __bfloat1622float2(b.vals[1]); + return a0.x * b0.x + a0.y * b0.y + a1.x * b1.x + a1.y * b1.y; + + // acc = __hfma2(a.vals[0], b.vals[0], acc); + // acc = __hfma2(a.vals[1], b.vals[1], acc); + // auto r = __bfloat1622float2(acc); + // return r.x + r.y; +} +__device__ __forceinline__ float scalar4_dot( + scalar4 a, + scalar4 b) { + auto a0 = __half22float2(a.vals[0]); + auto a1 = __half22float2(a.vals[1]); + auto b0 = __half22float2(b.vals[0]); + auto b1 = __half22float2(b.vals[1]); + return a0.x * b0.x + a0.y * b0.y + a1.x * b1.x + a1.y * b1.y; +} +__device__ __forceinline__ float scalar4_dot( + scalar4 a, + scalar4 b) { + return a.v.x * b.v.x + a.v.y * b.v.y + a.v.z * b.v.z + a.v.w * b.v.w; +} + +// bfx4_scale_acc +__device__ __forceinline__ fx4 +scalar4_scale_acc(fx4 acc, scalar4 a, float b) { + auto axy = __bfloat1622float2(a.vals[0]); + auto azw = __bfloat1622float2(a.vals[1]); + acc.x += axy.x * b; + acc.y += axy.y * b; + acc.z += azw.x * b; + acc.w += azw.y * b; + return acc; +} +__device__ __forceinline__ fx4 +scalar4_scale_acc(fx4 acc, scalar4 a, float b) { + auto axy = __half22float2(a.vals[0]); + auto azw = __half22float2(a.vals[1]); + acc.x += axy.x * b; + acc.y += axy.y * b; + acc.z += azw.x * b; + acc.w += azw.y * b; + return acc; +} +__device__ __forceinline__ fx4 +scalar4_scale_acc(fx4 acc, scalar4 a, float b) { + acc.x += a.v.x * b; + acc.y += a.v.y * b; + acc.z += a.v.z * b; + acc.w += a.v.w * b; + return acc; +} +__device__ __forceinline__ fx4 fx4_acc(fx4 a, fx4 b) { + a.x += b.x; + a.y += b.y; + a.z += b.z; + a.w += b.w; + return a; +} + +template +scalar4 fx4_to_scalar4(fx4 a); + +template <> // fx4_to_bfx4 +__device__ __forceinline__ scalar4 fx4_to_scalar4( + fx4 a) { + scalar4 r; + r.vals[0] = __floats2bfloat162_rn(a.x, a.y); + r.vals[1] = __floats2bfloat162_rn(a.z, a.w); + return r; +} +template <> +__device__ __forceinline__ scalar4 fx4_to_scalar4(fx4 a) { + scalar4 r; + r.vals[0] = __floats2half2_rn(a.x, a.y); + r.vals[1] = __floats2half2_rn(a.z, a.w); + return r; +} +template <> +__device__ __forceinline__ scalar4 fx4_to_scalar4(fx4 a) { + return {a}; +} +#define FINAL_MASK 0xffffffff + +template +__inline__ __device__ T warpReduceSum(T val) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val += __shfl_xor_sync(FINAL_MASK, val, mask, 32); + return val; +} + +template +__inline__ __device__ T warpReduceMax(T val) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); + return val; +} + +#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +// TODO: can also fuse RoPe into this kernel. Doesn't seem worth it. +template < + typename scalar_t, + // Offset from values read from seq_positions. + // Never nonzero in Python xformers library. 
+ int seq_positions_shift = 0> +__global__ void mqa_attn_kernel( + at::PackedTensorAccessor32 XQ, + at::PackedTensorAccessor64 cache_K, + at::PackedTensorAccessor64 cache_V, + at::PackedTensorAccessor32 O, + at::PackedTensorAccessor32 seq_positions, + float qk_scale) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using whole_int_t = typename scalar4::whole_int_t; + static_assert(4 * kThreadsPerWarp == D_H, ""); + static_assert(kWarpsPerBlock <= kThreadsPerWarp, ""); + + extern __shared__ __align__(16) float smem[]; + + // Each block handles a single batch and head + int32_t b = blockIdx.x; + int32_t h = blockIdx.y; + + // Note: this is decoding case where we attent to current and all previous + // tokens. + int32_t t_max = seq_positions[b] + seq_positions_shift; + + int32_t warp_idx = threadIdx.y; + // need kWarpsPerBlock == blockDim.y; + // Need D_H == 128 + auto* q_ = &(XQ[b][0][h][0]); + + bool multiquery = cache_K.size(2) == 1; + auto* cache_K_base = &cache_K[b][0][multiquery ? 0 : h][0]; + auto* cache_V_base = &cache_V[b][0][multiquery ? 0 : h][0]; + + // Load Q into registers in all warps. + // Each thread handles 4 D dimensions + scalar4 q_thread; + *reinterpret_cast(&q_thread) = + *(reinterpret_cast(q_) + threadIdx.x); + + // Each block computes different B value + float max_qk_acc = std::numeric_limits::lowest(); + + // Compute S[T_MAX] = for i in range(T): S[t] = sum(Q[d] * K[t, d]) + // Split T across warps in a block, unroll loads to expose more + // parallelism. + + constexpr int32_t kTimeUnroll = 1; + scalar4 k_loads[kTimeUnroll]; + + int32_t t_max_unroll = + (t_max / (kWarpsPerBlock * kTimeUnroll)) * (kWarpsPerBlock * kTimeUnroll); + for (auto tt = warp_idx * kTimeUnroll; tt < t_max_unroll; + tt += kWarpsPerBlock * kTimeUnroll) { +#pragma unroll kTimeUnroll + for (auto ttt = 0; ttt < kTimeUnroll; ++ttt) { + int32_t t = tt + ttt; + auto* k_ = cache_K_base + t * cache_K.stride(1); + // scalar4 k_thread; + *reinterpret_cast(&k_loads[ttt]) = + *(reinterpret_cast(k_) + threadIdx.x); + } +#pragma unroll kTimeUnroll + for (auto ttt = 0; ttt < kTimeUnroll; ++ttt) { + float qk_acc = 0; + int32_t t = tt + ttt; + qk_acc += scalar4_dot(q_thread, k_loads[ttt]) * qk_scale; + + qk_acc = warpReduceSum(qk_acc); + max_qk_acc = max(qk_acc, max_qk_acc); + + // write accumulated sums to smem. + if (threadIdx.x == 0) { + smem[t] = qk_acc; + } + } + } + + constexpr int32_t kTimeUnroll1 = 1; + for (auto tt = t_max_unroll + warp_idx; tt < t_max; + tt += kWarpsPerBlock * kTimeUnroll1) { +#pragma unroll kTimeUnroll1 + for (auto ttt = 0; ttt < kTimeUnroll1; ++ttt) { + int32_t t = tt + ttt; + // &(cache_K[b][t][0][0]); + auto* k_ = cache_K_base + t * cache_K.stride(1); + // scalar4 k_thread; + *reinterpret_cast(&k_loads[ttt]) = + *(reinterpret_cast(k_) + threadIdx.x); + } +#pragma unroll kTimeUnroll1 + for (auto ttt = 0; ttt < kTimeUnroll1; ++ttt) { + float qk_acc = 0; + int32_t t = tt + ttt; + qk_acc += scalar4_dot(q_thread, k_loads[ttt]) * qk_scale; + + qk_acc = warpReduceSum(qk_acc); + max_qk_acc = max(qk_acc, max_qk_acc); + + // write accumulated sums to smem. + if (threadIdx.x == 0) { + smem[t] = qk_acc; + } + } + } + + // Use shared reduction to compute max and compute softmax on shared memory. 
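The reduction announced by the comment above happens in two stages: each warp first collapses its lane-local values with shuffle instructions (`warpReduceMax` / `warpReduceSum`), then one value per warp is exchanged through the shared-memory scratch area past the `T_MAX` score region and reduced again. A hedged, plain-Python illustration of that structure (not the CUDA code itself):

```python
# Sketch of the two-stage block reduction used for the softmax statistics.
import math

def block_max(scores_per_warp):
    """scores_per_warp: one list of partial q.k scores per warp."""
    warp_maxes = [max(s, default=-math.inf) for s in scores_per_warp]  # intra-warp shuffle reduce
    return max(warp_maxes)  # inter-warp reduce via the shared-memory scratch area

def block_softmax(scores_per_warp):
    m = block_max(scores_per_warp)
    # Each warp sums exp(s - m) over its own subset, partial sums are combined
    # the same way, and every score is finally normalized in place.
    denom = sum(math.exp(s - m) for warp in scores_per_warp for s in warp)
    return [[math.exp(s - m) / denom for s in warp] for warp in scores_per_warp]
```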
+ // write max acc + if (threadIdx.x == 0) { + smem[T_MAX + warp_idx] = max_qk_acc; + } + __syncthreads(); + if (threadIdx.x < kWarpsPerBlock) { + max_qk_acc = max(max_qk_acc, smem[T_MAX + threadIdx.x]); + } + // shared across all threads in block + max_qk_acc = warpReduceMax(max_qk_acc); + // each warp computes partial sum of exp. + float softmax_denominator = 0.0f; + for (int32_t t = threadIdx.x + warp_idx * kThreadsPerWarp; t < t_max; + t += kWarpsPerBlock * kThreadsPerWarp) { + softmax_denominator += __expf(smem[t] - max_qk_acc); + } + softmax_denominator = warpReduceSum(softmax_denominator); + + __syncthreads(); + if (threadIdx.x == 0) { + smem[T_MAX + warp_idx] = softmax_denominator; + } + __syncthreads(); + // now, compute sum of exp(x - max(x)) over all intermediate results. + softmax_denominator = 0.0; + if (threadIdx.x < kWarpsPerBlock) { + softmax_denominator = smem[T_MAX + threadIdx.x]; + } + softmax_denominator = warpReduceSum(softmax_denominator); + + // now, compute the normalization across all threads. + for (int32_t t = threadIdx.x + warp_idx * kThreadsPerWarp; t < t_max; + t += kWarpsPerBlock * kThreadsPerWarp) { + smem[t] = __expf(smem[t] - max_qk_acc) / softmax_denominator; + } + __syncthreads(); + + // Now, we can comute the softmax and write the outputs. + + // Split T across warps in a block + // each warp compute sum(t_subset) P[t] * V[t_subset, d] + // outputs are of size float[D] + + float ps[kTimeUnroll]; + fx4 o_acc; + for (auto tt = warp_idx * kTimeUnroll; tt < t_max_unroll; + tt += kWarpsPerBlock * kTimeUnroll) { +#pragma unroll kTimeUnroll + for (auto ttt = 0; ttt < kTimeUnroll; ++ttt) { + int32_t t = tt + ttt; + // &(cache_V[b][t][0][0]); + auto* v_ = cache_V_base + t * cache_V.stride(1); + // scalar4 v_thread; + *reinterpret_cast(&k_loads[ttt]) = + *(reinterpret_cast(v_) + threadIdx.x); + ps[ttt] = smem[t]; + } + +#pragma unroll kTimeUnroll + for (auto ttt = 0; ttt < kTimeUnroll; ++ttt) { + o_acc = scalar4_scale_acc(o_acc, k_loads[ttt], ps[ttt]); + } + } + + for (auto tt = t_max_unroll + warp_idx; tt < t_max; + tt += kWarpsPerBlock * kTimeUnroll1) { +#pragma unroll kTimeUnroll1 + for (auto ttt = 0; ttt < kTimeUnroll1; ++ttt) { + int32_t t = tt + ttt; + // &(cache_V[b][t][0][0]); + auto* v_ = cache_V_base + t * cache_V.stride(1); + // scalar4 v_thread; + *reinterpret_cast(&k_loads[ttt]) = + *(reinterpret_cast(v_) + threadIdx.x); + ps[ttt] = smem[t]; + } + +#pragma unroll kTimeUnroll1 + for (auto ttt = 0; ttt < kTimeUnroll1; ++ttt) { + o_acc = scalar4_scale_acc(o_acc, k_loads[ttt], ps[ttt]); + } + } + + // now, each thread has partial sums. Write to smem and get accumulated + // results back. 
+ __syncthreads(); + *(reinterpret_cast(&smem[0]) + warp_idx * kThreadsPerWarp + + threadIdx.x) = o_acc; + __syncthreads(); + // sum up partial D rows from other warps + if (warp_idx == 0) { + fx4 r; + for (int32_t w = 0; w < kWarpsPerBlock; ++w) { + auto partial_r = *( + reinterpret_cast(&smem[0]) + w * kThreadsPerWarp + threadIdx.x); + r = fx4_acc(r, partial_r); + } + // write output D row + auto* o_ = (&O[b][0][h][0]); + auto bf_r = fx4_to_scalar4(r); + *(reinterpret_cast(o_) + threadIdx.x) = + *reinterpret_cast(&bf_r); + } +#else + printf("FATAL: kernel is for sm80+ only"); +#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +} + +at::Tensor mqa_attn( + at::Tensor XQ, // [B, 1, H, D] + at::Tensor cache_K, // [B, T_MAX, H or 1, D] + at::Tensor cache_V, // [B, T_MAX, H or 1, D] + at::Tensor seq_positions, // [B] + double qk_scale) { + at::OptionalDeviceGuard guard(XQ.device()); + TORCH_CHECK(XQ.is_cuda()); + TORCH_CHECK(cache_K.is_cuda()); + TORCH_CHECK(cache_V.is_cuda()); + + TORCH_CHECK(seq_positions.is_cuda()); + + TORCH_CHECK(cache_K.size(1) <= T_MAX); + TORCH_CHECK(cache_K.size(3) == D_H); + + auto O = at::empty_like(XQ); + auto B = XQ.size(0); + auto H = XQ.size(2); + dim3 blocks(B, H); + dim3 threads(kThreadsPerWarp, kWarpsPerBlock); + + int32_t smem_softmax = T_MAX * sizeof(float) + kWarpsPerBlock * sizeof(float); + int32_t smem_output = D_H * sizeof(float) * kWarpsPerBlock; + int32_t smem = max(smem_softmax, smem_output); + + if (XQ.scalar_type() == at::ScalarType::Half) { + if (smem > 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + mqa_attn_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem)); + } + mqa_attn_kernel + <<>>( + XQ.packed_accessor32(), + cache_K.packed_accessor64(), + cache_V.packed_accessor64(), + O.packed_accessor32(), + seq_positions + .packed_accessor32(), + qk_scale); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else if (XQ.scalar_type() == at::ScalarType::BFloat16) { + if (smem > 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + mqa_attn_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem)); + } + mqa_attn_kernel + <<>>( + XQ.packed_accessor32(), + cache_K.packed_accessor64(), + cache_V.packed_accessor64(), + O.packed_accessor32(), + seq_positions + .packed_accessor32(), + qk_scale); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + TORCH_CHECK( + XQ.scalar_type() == at::ScalarType::Float, + "Only supports bf16/f16/f32"); + if (smem > 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + mqa_attn_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem)); + } + mqa_attn_kernel + <<>>( + XQ.packed_accessor32(), + cache_K.packed_accessor64(), + cache_V.packed_accessor64(), + O.packed_accessor32(), + seq_positions + .packed_accessor32(), + qk_scale); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + return O; +} + +} // namespace + +TORCH_LIBRARY_IMPL(xformers, CUDA, m) { + m.impl( + TORCH_SELECTIVE_NAME("xformers::efficient_attention_forward_decoder"), + TORCH_FN(mqa_attn)); +} diff --git a/xformers/ops/fmha/__init__.py b/xformers/ops/fmha/__init__.py index 2101eaa6b..73bbd6c15 100644 --- a/xformers/ops/fmha/__init__.py +++ b/xformers/ops/fmha/__init__.py @@ -7,7 +7,7 @@ import torch -from . import cutlass, flash, small_k, triton +from . 
import cutlass, decoder, flash, small_k, triton from .attn_bias import AttentionBias, BlockDiagonalMask, LowerTriangularMask from .common import ( AttentionBwOpBase, @@ -24,6 +24,7 @@ MemoryEfficientAttentionCutlassOp = (cutlass.FwOp, cutlass.BwOp) MemoryEfficientAttentionCutlassFwdFlashBwOp = (cutlass.FwOp, flash.BwOp) +MemoryEfficientAttentionDecoderOp = (decoder.FwOp, cutlass.BwOp) MemoryEfficientAttentionTritonFwdFlashBwOp = (triton.FwOp, flash.BwOp) MemoryEfficientAttentionFlashAttentionOp = (flash.FwOp, flash.BwOp) MemoryEfficientAttentionOp = (small_k.FwOp, small_k.BwOp) @@ -303,7 +304,7 @@ def _memory_efficient_attention_forward( inp.validate_inputs() output_shape = inp.normalize_bmhk() if op is None: - op = _dispatch_fw(inp) + op = _dispatch_fw(inp, False) else: _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp) @@ -317,7 +318,7 @@ def _memory_efficient_attention_forward_requires_grad( inp.validate_inputs() output_shape = inp.normalize_bmhk() if op is None: - op = _dispatch_fw(inp) + op = _dispatch_fw(inp, True) else: _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp) out = op.apply(inp, needs_gradient=True) diff --git a/xformers/ops/fmha/attn_bias.py b/xformers/ops/fmha/attn_bias.py index cfff13496..13d0ce9bd 100644 --- a/xformers/ops/fmha/attn_bias.py +++ b/xformers/ops/fmha/attn_bias.py @@ -198,10 +198,12 @@ class _PaddedSeqLenInfo(_SeqLenInfo): seqstart: torch.IntTensor([0, 4, 8, 12]) seqlen_py: [2, 3, 2] seqlen: torch.IntTensor([2, 3, 2]) + padding: 4 """ seqlen: torch.Tensor seqlen_py: Sequence[int] + padding: int # From parent: seqstart[i] contains the start position # of the i-th sequence # seqstart: torch.Tensor @@ -240,6 +242,7 @@ def from_seqlens_padded( max_seqlen=max(seqlens), seqstart=torch.tensor(seqstart_py, dtype=torch.int32), seqstart_py=seqstart_py, + padding=padding, ) def split( @@ -600,7 +603,8 @@ def from_seqlens( kv_seqlen: Sequence[int], causal_diagonal: Any = None, ) -> "BlockDiagonalCausalWithOffsetPaddedKeysMask": - """Creates a :attr:`BlockDiagonalCausalWithOffsetPaddedKeysMask` from a list of tensors lengths for query and key/value. + """Creates a :attr:`BlockDiagonalCausalWithOffsetPaddedKeysMask` from a list of tensor + lengths for query and key/value. Args: q_seqlen (Sequence[int]): List or tensor of sequence lengths for query tensors diff --git a/xformers/ops/fmha/common.py b/xformers/ops/fmha/common.py index c9c599da6..98e2476ec 100644 --- a/xformers/ops/fmha/common.py +++ b/xformers/ops/fmha/common.py @@ -388,7 +388,7 @@ def from_arguments( p=p, scale=scale, ) - return AttentionOpDispatch(op=(_dispatch_fw(inp), _dispatch_bw(inp))) + return AttentionOpDispatch(op=(_dispatch_fw(inp, True), _dispatch_bw(inp))) def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor: diff --git a/xformers/ops/fmha/decoder.py b/xformers/ops/fmha/decoder.py new file mode 100644 index 000000000..e9d72823c --- /dev/null +++ b/xformers/ops/fmha/decoder.py @@ -0,0 +1,105 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
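The Python wrapper for the new kernel starts below. As a rough reference for what `efficient_attention_forward_decoder` computes (a readability sketch under the shape conventions checked in `decoder.cu`, not the actual implementation):

```python
# Hedged reference: single-query attention over the valid prefix of a padded
# KV cache. q: [B, 1, H, D], k/v: [B, padding, H or 1, D], seq_positions: [B].
import torch

def decoder_attention_reference(q, k, v, seq_positions, scale):
    B, _, H, D = q.shape
    out = torch.empty_like(q)
    for b in range(B):
        t = int(seq_positions[b])  # number of valid cache entries for this sequence
        kb, vb = k[b, :t], v[b, :t]  # [t, H or 1, D]
        if kb.shape[1] == 1:  # multiquery: one shared KV head
            kb, vb = kb.expand(t, H, D), vb.expand(t, H, D)
        # One query token per sequence; softmax runs over the valid prefix only.
        logits = torch.einsum("hd,thd->ht", q[b, 0].float(), kb.float()) * scale
        p = logits.softmax(dim=-1)  # [H, t]
        out[b, 0] = torch.einsum("ht,thd->hd", p, vb.float()).to(q.dtype)
    return out
```

`FwOp.apply` below reshapes the packed `(1, B * padding, H, D)` inputs into this `(B, padding, ...)` layout and passes the per-sequence lengths as `seq_positions`.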
+
+from typing import Any, List, Optional, Set, Tuple
+
+import numpy as np
+import torch
+
+from ..common import get_xformers_operator, register_operator
+from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
+from .common import AttentionFwOpBase, Context, Inputs
+
+
+@register_operator
+class FwOp(AttentionFwOpBase):
+    """An operator optimized for decoding: a single query token per sequence
+    attending to a padded key/value cache, described by a
+    :attr:`BlockDiagonalCausalWithOffsetPaddedKeysMask`.
+
+    Requires ``head_dim == 128``, runs on sm80+ in bf16/f16/f32, supports
+    multiquery (key/value heads expanded with stride 0), and is forward-only.
+    """
+
+    OPERATOR = get_xformers_operator("efficient_attention_forward_decoder")
+    SUPPORTED_DEVICES = {"cuda"}
+    SUPPORTED_DTYPES = {torch.bfloat16, torch.half, torch.float32}
+    CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
+    SUPPORTED_MAX_K: float = 128
+    SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {BlockDiagonalCausalWithOffsetPaddedKeysMask}
+    SUPPORTS_DROPOUT = False
+    SUPPORTS_CUSTOM_SCALE = True
+    NAME = "decoderF"
+
+    @classmethod
+    def not_supported_reasons(cls, d: Inputs) -> List[str]:
+        reasons = super(FwOp, cls).not_supported_reasons(d)
+
+        attn_bias = d.attn_bias
+        if isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask):
+            # If we don't get here, the unsupported bias type is reported elsewhere
+            if d.query.ndim != 4 or d.key.ndim != 4:
+                reasons.append("Inputs must be BMHK. BMK not supported")
+
+            if d.query.shape[0] != 1:
+                reasons.append("One formal batch element expected")
+
+            if d.query.shape[-1] != 128:
+                reasons.append("Only head_dim==128 for now.")
+
+            if d.key.stride(-1) != 1:
+                reasons.append("expect keys to have last dim contiguous")
+
+            if d.value.stride(-1) != 1:
+                reasons.append("expect values to have last dim contiguous")
+
+            if attn_bias.q_seqinfo.max_seqlen != 1:
+                reasons.append("decoding expects one query")
+
+            if attn_bias.k_seqinfo.padding > 8192:
+                reasons.append("key padding exceeds 8192")
+
+        return reasons
+
+    @classmethod
+    def apply(
+        cls, inp: Inputs, needs_gradient: bool
+    ) -> Tuple[torch.Tensor, Optional[Context]]:
+        if needs_gradient:
+            raise NotImplementedError("gradient")
+        attn_bias = inp.attn_bias
+        assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
+
+        attn_bias.k_seqinfo.to(inp.query.device)
+        attn_bias.q_seqinfo.to(inp.query.device)
+
+        padding = attn_bias.k_seqinfo.padding
+        multiquery = inp.key.stride(2) == 0
+        if multiquery:
+            key = inp.key[0, :, :1].unflatten(0, (-1, padding))
+            value = inp.value[0, :, :1].unflatten(0, (-1, padding))
+        else:
+            key = inp.key[0].unflatten(0, (-1, padding))
+            value = inp.value[0].unflatten(0, (-1, padding))
+
+        seq_positions = attn_bias.k_seqinfo.seqlen
+
+        query = inp.query[0, :, None]
+
+        if inp.scale is not None:
+            qk_scale = inp.scale
+        else:
+            qk_scale = 1.0 / np.sqrt(key.shape[-1])
+
+        out = cls.OPERATOR(
+            query=query,
+            key=key,
+            value=value,
+            seq_positions=seq_positions,
+            scale=qk_scale,
+        )
+        return out, None
diff --git a/xformers/ops/fmha/dispatch.py b/xformers/ops/fmha/dispatch.py
index e3bd96a4f..aa7b51b1e 100644
--- a/xformers/ops/fmha/dispatch.py
+++ b/xformers/ops/fmha/dispatch.py
@@ -5,9 +5,10 @@
 import textwrap
-from typing import List, Type, TypeVar
+from collections import deque
+from typing import List, Sequence, Type, TypeVar
-from . import cutlass, flash, small_k, triton
+from . 
import cutlass, decoder, flash, small_k, triton from .common import AttentionBwOpBase, AttentionFwOpBase, Inputs @@ -53,7 +54,7 @@ def _format_not_supported_reasons(op, reasons: List[str]) -> str: return f"`{op.NAME}` is not supported because:\n " + "\n ".join(reasons) -def _run_priority_list(name: str, priority_list: List[T], inp: Inputs) -> T: +def _run_priority_list(name: str, priority_list: Sequence[T], inp: Inputs) -> T: not_supported_reasons: List[List[str]] = [] for op in priority_list: not_supported = op.not_supported_reasons(inp) @@ -69,7 +70,7 @@ def _run_priority_list(name: str, priority_list: List[T], inp: Inputs) -> T: raise NotImplementedError(msg) -def _dispatch_fw(inp: Inputs) -> Type[AttentionFwOpBase]: +def _dispatch_fw(inp: Inputs, needs_gradient: bool) -> Type[AttentionFwOpBase]: """Computes the best operator for forward Raises: @@ -79,18 +80,26 @@ def _dispatch_fw(inp: Inputs) -> Type[AttentionFwOpBase]: AttentionOp: The best operator for the configuration """ - priority_list_ops: List[Type[AttentionFwOpBase]] = [ - flash.FwOp, - triton.FwOp, - cutlass.FwOp, - small_k.FwOp, - ] + priority_list_ops = deque( + [ + flash.FwOp, + triton.FwOp, + cutlass.FwOp, + small_k.FwOp, + ] + ) if _is_cutlass_fwd_faster_than_flash(inp): priority_list_ops.remove(cutlass.FwOp) - priority_list_ops.insert(0, cutlass.FwOp) + priority_list_ops.appendleft(cutlass.FwOp) if _is_triton_fwd_fastest(inp): priority_list_ops.remove(triton.FwOp) - priority_list_ops.insert(0, triton.FwOp) + priority_list_ops.appendleft(triton.FwOp) + if not needs_gradient: + multiquery = inp.key.stride(2) == 0 + if not multiquery: + # With multiquery, cutlass is sometimes faster than decoder + # but it's not currently clear when. + priority_list_ops.appendleft(decoder.FwOp) return _run_priority_list( "memory_efficient_attention_forward", priority_list_ops, inp )
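With the dispatch change above, `decoder.FwOp` is tried first whenever no gradient is needed and the key/value heads are not expanded, so an inference call shaped like the unit test can pick up the new kernel without naming an op. A hedged usage sketch, assuming an sm80+ GPU and the constraints listed in `decoder.FwOp.not_supported_reasons` (head_dim 128, one query token per sequence, padding <= 8192):

```python
# Sketch: decoding one token per sequence against a padded KV cache.
import torch
import xformers.ops.fmha as fmha

B, H, D, padding = 4, 8, 128, 32
k_seqlen = [5, 8, 7, 1]  # valid cache entries per sequence

q = torch.randn(1, B, H, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, B * padding, H, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, B * padding, H, D, device="cuda", dtype=torch.bfloat16)

bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
    q_seqlen=[1] * B, kv_seqlen=k_seqlen, kv_padding=padding
)

# Automatic dispatch: forward-only, non-multiquery, so decoder.FwOp is preferred.
out = fmha.memory_efficient_attention_forward(q, k, v, bias)
# The op can also be forced explicitly (e.g. for multiquery inputs, where
# automatic dispatch keeps cutlass first), as the unit test does:
out_decoder = fmha.memory_efficient_attention_forward(
    q, k, v, bias, op=fmha.decoder.FwOp
)
```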