diff --git a/.gitmodules b/.gitmodules
index ab23324ae..b15bd78f6 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,6 +1,6 @@
-[submodule "third_party/flash-attention"]
-	path = third_party/flash-attention
-	url = https://github.com/HazyResearch/flash-attention.git
 [submodule "third_party/cutlass"]
 	path = third_party/cutlass
 	url = https://github.com/NVIDIA/cutlass.git
+[submodule "third_party/flash-attention"]
+	path = third_party/flash-attention
+	url = https://github.com/Dao-AILab/flash-attention.git
diff --git a/setup.py b/setup.py
index b5741bb1a..e009cb7ff 100644
--- a/setup.py
+++ b/setup.py
@@ -104,14 +104,18 @@ def get_cuda_version(cuda_dir) -> int:
 
 
 def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
+    # XXX: Not supported on windows yet
+    # https://github.com/Dao-AILab/flash-attention/issues/345
+    if platform.system() != "Linux":
+        return []
     # Figure out default archs to target
     DEFAULT_ARCHS_LIST = ""
     if cuda_version >= 1108:
-        DEFAULT_ARCHS_LIST = "7.5;8.0;8.6;9.0"
+        DEFAULT_ARCHS_LIST = "8.0;8.6;9.0"
     elif cuda_version > 1100:
-        DEFAULT_ARCHS_LIST = "7.5;8.0;8.6"
+        DEFAULT_ARCHS_LIST = "8.0;8.6"
     elif cuda_version == 1100:
-        DEFAULT_ARCHS_LIST = "7.5;8.0"
+        DEFAULT_ARCHS_LIST = "8.0"
     else:
         return []
 
@@ -125,8 +129,8 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
         arch_arr = arch.split(".")
         num = 10 * int(arch_arr[0]) + int(arch_arr[1].partition("+")[0])
-        # Need at least 7.5
-        if num < 75:
+        # Need at least 8.0
+        if num < 80:
             continue
         nvcc_archs_flags.append(f"-gencode=arch=compute_{num},code=sm_{num}")
         if arch.endswith("+PTX"):
@@ -134,6 +138,10 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
     if not nvcc_archs_flags:
         return []
 
+    nvcc_windows_flags = []
+    if platform.system() == "Windows":
+        nvcc_windows_flags = ["-Xcompiler", "/permissive-"]
+
     flash_root = os.path.join(this_dir, "third_party", "flash-attention")
     if not os.path.exists(flash_root):
         raise RuntimeError(
@@ -141,35 +149,31 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
             "to run `git submodule update --init --recursive` ?"
         )
 
+    flash_root = os.path.join(this_dir, "third_party", "flash-attention")
+    sources = ["csrc/flash_attn/flash_api.cpp"]
+    for f in glob.glob(os.path.join(flash_root, "csrc", "flash_attn", "src", "*.cu")):
+        sources.append(str(Path(f).relative_to(flash_root)))
     return [
         CUDAExtension(
             name="xformers._C_flashattention",
-            sources=[
-                os.path.join("third_party", "flash-attention", path)
-                for path in [
-                    "csrc/flash_attn/fmha_api.cpp",
-                    "csrc/flash_attn/src/fmha_fwd_hdim32.cu",
-                    "csrc/flash_attn/src/fmha_fwd_hdim64.cu",
-                    "csrc/flash_attn/src/fmha_fwd_hdim128.cu",
-                    "csrc/flash_attn/src/fmha_bwd_hdim32.cu",
-                    "csrc/flash_attn/src/fmha_bwd_hdim64.cu",
-                    "csrc/flash_attn/src/fmha_bwd_hdim128.cu",
-                    "csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu",
-                    "csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
-                ]
-            ],
+            sources=[os.path.join(flash_root, path) for path in sources],
             extra_compile_args={
                 **extra_compile_args,
                 "nvcc": extra_compile_args.get("nvcc", [])
                 + [
                     "-O3",
                     "-std=c++17",
+                    "-U__CUDA_NO_HALF_OPERATORS__",
+                    "-U__CUDA_NO_HALF_CONVERSIONS__",
+                    "-U__CUDA_NO_HALF2_OPERATORS__",
+                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                     "--expt-relaxed-constexpr",
                     "--expt-extended-lambda",
                     "--use_fast_math",
                     "--ptxas-options=-v",
                 ]
                 + nvcc_archs_flags
+                + nvcc_windows_flags
                 + get_extra_nvcc_flags_for_build_type(),
             },
             include_dirs=[
@@ -177,7 +181,7 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
                 for p in [
                     Path(flash_root) / "csrc" / "flash_attn",
                     Path(flash_root) / "csrc" / "flash_attn" / "src",
-                    Path(this_dir) / "third_party" / "cutlass" / "include",
+                    Path(flash_root) / "csrc" / "cutlass" / "include",
                 ]
             ],
         )
@@ -230,10 +234,6 @@ def get_extensions():
             "-U__CUDA_NO_HALF_CONVERSIONS__",
             "--extended-lambda",
             "-D_ENABLE_EXTENDED_ALIGNED_STORAGE",
-            # Workaround for a regression with nvcc > 11.6
-            # See https://github.com/facebookresearch/xformers/issues/712
-            "--ptxas-options=-O2",
-            "--ptxas-options=-allow-expensive-optimizations=true",
         ] + get_extra_nvcc_flags_for_build_type()
         if os.getenv("XFORMERS_ENABLE_DEBUG_ASSERTIONS", "0") != "1":
            nvcc_flags.append("-DNDEBUG")
@@ -259,6 +259,15 @@ def get_extensions():
             cuda_version=cuda_version, extra_compile_args=extra_compile_args
         )
 
+        # NOTE: This should not be applied to Flash-Attention
+        # see https://github.com/Dao-AILab/flash-attention/issues/359
+        extra_compile_args["nvcc"] += [
+            # Workaround for a regression with nvcc > 11.6
+            # See https://github.com/facebookresearch/xformers/issues/712
+            "--ptxas-options=-O2",
+            "--ptxas-options=-allow-expensive-optimizations=true",
+        ]
+
     ext_modules.append(
         extension(
             "xformers._C",
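Note on the setup.py changes above: Flash-Attention v2 only ships sm80+ kernels, which is why 7.5 is dropped from the default arch lists and from the per-arch filter. A standalone sketch of that filter (hypothetical snippet, not part of setup.py; it assumes the `+PTX` branch appends an extra `compute_` target, as the surrounding code in setup.py does):

archs_list = "7.5;8.0;8.6+PTX;9.0"
nvcc_archs_flags = []
for arch in archs_list.split(";"):
    arch_arr = arch.split(".")
    num = 10 * int(arch_arr[0]) + int(arch_arr[1].partition("+")[0])
    if num < 80:  # Flash-Attention v2 needs at least sm80 (Ampere or newer)
        continue
    nvcc_archs_flags.append(f"-gencode=arch=compute_{num},code=sm_{num}")
    if arch.endswith("+PTX"):
        nvcc_archs_flags.append(f"-gencode=arch=compute_{num},code=compute_{num}")
# 7.5 is filtered out; 8.0, 8.6 (plus its PTX target) and 9.0 remain.
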
diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py
index 664a184b2..a53b33f10 100644
--- a/tests/test_mem_eff_attention.py
+++ b/tests/test_mem_eff_attention.py
@@ -82,7 +82,7 @@ def generate_test_shapes_B_Mq_Mkv_H_K_Kv(op):
     shapes = []
     for B in op._TEST_BATCH_SIZES:
         for Mq in [32, 256]:
-            for Mkv in [32, 64, 256]:
+            for Mkv in [32, 64, 256, 1024]:
                 for K in op._TEST_K:
                     shapes.append((B, Mq, Mkv, 1, K, K))
     Mq = 256
@@ -93,7 +93,7 @@ def generate_test_shapes_B_Mq_Mkv_H_K_Kv(op):
     for M in [2, 3, 15, 31, 32, 34, 68, 72, 90, 132, 136]:
         shapes.append((B, M, Mkv, H, K, K))
         shapes.append((B, Mq, M, H, K, K))
-    for _K in [1, 2, 3, 31, 34, 36, 38, 40, 64, 256 + 2, 256 + 8, 512]:
+    for _K in [1, 2, 3, 31, 34, 36, 38, 40, 64, 80, 160, 256 + 2, 256 + 8, 512]:
         if _K <= op.SUPPORTED_MAX_K:
             shapes.append((B, Mq, Mkv, H, _K, _K))
     # Different value for K / Kv
@@ -108,6 +117,17 @@ def generate_test_shapes_B_Mq_Mkv_H_K_Kv(op):
     # Some number of heads
     for H in [3, 5, 12]:
         shapes.append((max(1, B // H), Mq, Mkv, H, K, K))
+    # Filter-out not supported shapes
+    shapes = [
+        shape
+        for shape in shapes
+        if len(
+            op.shape_not_supported_reasons(
+                Mq=shape[1], Mkv=shape[2], K=shape[4], Kv=shape[5]
+            )
+        )
+        == 0
+    ]
     # Add some random shapes
     if op in [
         fmha.cutlass.FwOp,
@@ -116,7 +127,8 @@ def generate_test_shapes_B_Mq_Mkv_H_K_Kv(op):
     ]:
         K_CHOICES = [8 * i for i in range(1, 256 // 8)]
         r = random.Random(0)
-        for _ in range(20):
+        found_count = 0
+        while found_count < 20:
             B = r.randint(1, 400)
             Mq = r.randint(1, 500)
             Mkv = r.randint(1, 500)
@@ -126,6 +138,9 @@ def generate_test_shapes_B_Mq_Mkv_H_K_Kv(op):
             Kv = r.choice(K_CHOICES)
             if not op.SUPPORTS_DIFFERENT_VALUE_EMBED:
                 Kv = K
+            if len(op.shape_not_supported_reasons(Mq, Mkv, K, Kv)):
+                continue
+            found_count += 1
             shapes.append((B, Mq, Mkv, H, K, Kv))
     return shapes
 
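The test-shape generator above now asks each operator up front which (Mq, Mkv, K, Kv) combinations it can handle, using the `shape_not_supported_reasons` classmethod added to `AttentionOpBase` in xformers/ops/fmha/common.py further down. A minimal sketch of the same filtering idiom (assumes a build of this branch; the candidate shapes are made up):

from xformers.ops import fmha

candidates = [(1, 32, 32, 1, 64, 64), (1, 32, 32, 1, 512, 512)]
supported = [
    (B, Mq, Mkv, H, K, Kv)
    for (B, Mq, Mkv, H, K, Kv) in candidates
    if not fmha.flash.FwOp.shape_not_supported_reasons(Mq=Mq, Mkv=Mkv, K=K, Kv=Kv)
]
# The second candidate is dropped: K=512 exceeds flash.FwOp.SUPPORTED_MAX_K
# (256 after this change).
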
diff --git a/xformers/benchmarks/benchmark_mem_eff_attention.py b/xformers/benchmarks/benchmark_mem_eff_attention.py
index 8e532adf0..3b0c6061d 100644
--- a/xformers/benchmarks/benchmark_mem_eff_attention.py
+++ b/xformers/benchmarks/benchmark_mem_eff_attention.py
@@ -117,7 +117,11 @@ def T(t):
     (8, 2048, 20, 128),
     # LLaMa 70b - mp=8/16
     *sorted(list(itertools.product([1, 2], [2048, 4096, 8192], [4, 8], [128]))),
-    *sorted(list(itertools.product([16], [128, 512, 1024], [16], [16, 32, 64, 128]))),
+    *sorted(
+        list(
+            itertools.product([16], [128, 512, 1024], [16], [16, 32, 64, 128, 160, 256])
+        )
+    ),
 ]
 
 OPS = [
diff --git a/xformers/ops/fmha/common.py b/xformers/ops/fmha/common.py
index 98e2476ec..f6b4a821e 100644
--- a/xformers/ops/fmha/common.py
+++ b/xformers/ops/fmha/common.py
@@ -151,6 +151,7 @@ class AttentionOpBase(BaseOperator):
     SUPPORTS_DROPOUT: bool
     SUPPORTS_CUSTOM_SCALE: bool = False
     SUPPORTS_DIFFERENT_VALUE_EMBED: bool = False
+    IS_DETERMINISTIC: bool = True
     NAME: str
 
     OPERATOR_CATEGORY = "memory_efficient_attention"
@@ -161,13 +162,31 @@ class AttentionOpBase(BaseOperator):
     def supports(cls, d: Inputs) -> bool:
         return not cls.not_supported_reasons(d)
 
+    @classmethod
+    def shape_not_supported_reasons(
+        cls, Mq: int, Mkv: int, K: int, Kv: int
+    ) -> List[str]:
+        reasons = []
+        if not cls.SUPPORTS_DIFFERENT_VALUE_EMBED and K != Kv:
+            reasons.append("query.shape[-1] != value.shape[-1]")
+        if max(K, Kv) > cls.SUPPORTED_MAX_K:
+            reasons.append(
+                f"max(query.shape[-1] != value.shape[-1]) > {cls.SUPPORTED_MAX_K}"
+            )
+        return reasons
+
     @classmethod
     def not_supported_reasons(cls, d: Inputs) -> List[str]:
         """
         Returns a list of reasons why this is not supported.
         The kernel can run these inputs only if the returned list is empty
         """
-        reasons = []
+        reasons = cls.shape_not_supported_reasons(
+            Mq=d.query.shape[1],
+            Mkv=d.key.shape[1],
+            K=d.query.shape[-1],
+            Kv=d.value.shape[-1],
+        )
         device_type = d.query.device.type
         dtype = d.query.dtype
         if device_type not in cls.SUPPORTED_DEVICES:
@@ -176,15 +195,6 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
             reasons.append("xFormers wasn't build with CUDA support")
         if dtype not in cls.SUPPORTED_DTYPES:
             reasons.append(f"dtype={dtype} (supported: {cls.SUPPORTED_DTYPES})")
-        if (
-            not cls.SUPPORTS_DIFFERENT_VALUE_EMBED
-            and d.query.shape[-1] != d.value.shape[-1]
-        ):
-            reasons.append("query.shape[-1] != value.shape[-1]")
-        if max(d.query.shape[-1], d.value.shape[-1]) > cls.SUPPORTED_MAX_K:
-            reasons.append(
-                f"max(query.shape[-1] != value.shape[-1]) > {cls.SUPPORTED_MAX_K}"
-            )
         if type(d.attn_bias) not in cls.SUPPORTED_ATTN_BIAS_TYPES:
             reasons.append(f"attn_bias type is {type(d.attn_bias)}")
         if (d.p != 0.0) and not cls.SUPPORTS_DROPOUT:
@@ -201,7 +211,11 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
             reasons.append("bf16 is only supported on A100+ GPUs")
         if not cls.is_available():
             reasons.append(
-                "Operator wasn't built - see `python -m xformers.info` for more info"
+                "operator wasn't built - see `python -m xformers.info` for more info"
             )
+        if not cls.IS_DETERMINISTIC and torch.are_deterministic_algorithms_enabled():
+            reasons.append(
+                "operator is non-deterministic, but `torch.use_deterministic_algorithms` is set"
+            )
         return reasons
 
diff --git a/xformers/ops/fmha/dispatch.py b/xformers/ops/fmha/dispatch.py
index aa7b51b1e..3ed6dd1cb 100644
--- a/xformers/ops/fmha/dispatch.py
+++ b/xformers/ops/fmha/dispatch.py
@@ -13,14 +13,7 @@
 
 
 def _is_cutlass_fwd_faster_than_flash(inp: Inputs) -> bool:
-    # For dropout, we can't mix & match kernels
-    # Unfortunately, the dropout implementation in CUTLASS
-    # backward is pretty slow for the BW, so disable it here
-    if inp.p > 0.0:
-        return False
-
-    # Large values of K
-    return max(inp.query.shape[-1], inp.value.shape[-1]) > 64
+    return False
 
 
 def _is_triton_fwd_fastest(inp: Inputs) -> bool:
@@ -106,8 +99,7 @@ def _dispatch_fw(inp: Inputs, needs_gradient: bool) -> Type[AttentionFwOpBase]:
 
 
 def _is_cutlassB_faster_than_flash(inp: Inputs) -> bool:
-    embed_dim = max(inp.query.shape[-1], inp.value.shape[-1])
-    return embed_dim > 64 and inp.attn_bias is None and inp.p == 0.0
+    return False
 
 
 def _dispatch_bw(inp: Inputs) -> Type[AttentionBwOpBase]:
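With `_is_cutlass_fwd_faster_than_flash` and `_is_cutlassB_faster_than_flash` now returning False, the dispatcher simply prefers Flash v2 whenever it supports the inputs. Callers who want a specific kernel can still pin it through the `op=` argument of `memory_efficient_attention`; a minimal sketch (assumes a CUDA build of this branch and an sm80+ GPU):

import torch
import xformers.ops as xops

q, k, v = (
    torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16) for _ in range(3)
)

out = xops.memory_efficient_attention(q, k, v)  # dispatcher choice, typically Flash v2 here
out_cutlass = xops.memory_efficient_attention(
    q, k, v, op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
)
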
diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py
index 29df3517a..f8d608f8f 100644
--- a/xformers/ops/fmha/flash.py
+++ b/xformers/ops/fmha/flash.py
@@ -56,7 +56,15 @@ def _flash_fwd(
         return_softmax,
     ):
         out = query.new_empty(query.shape[0], query.shape[1], value.shape[2])
-        lse, rng_state = _C_flashattention.fwd(
+        (
+            out,
+            q_padded,
+            k_padded,
+            v_padded,
+            out_padded,
+            softmax_lse,
+            p,
+        ) = _C_flashattention.varlen_fwd(
             query,
             key,
             value,
@@ -70,10 +78,9 @@ def _flash_fwd(
             False,
             causal,
             return_softmax,
-            0,
             None,
         )
-        return out, lse, rng_state
+        return out, softmax_lse, None
 
     def _flash_bwd(
         grad,
@@ -94,7 +101,7 @@ def _flash_bwd(
         causal,
         rng_state,
     ):
-        _C_flashattention.bwd(
+        _C_flashattention.varlen_bwd(
             grad,
             query,
             key,
@@ -110,11 +117,9 @@ def _flash_bwd(
             max_seq_len_k,
             p,
             softmax_scale,
-            False,
+            False,  # zero_tensors
             causal,
-            0,
             None,
-            rng_state,
         )
         return dq
 
@@ -190,9 +195,9 @@ class FwOp(AttentionFwOpBase):
 
     OPERATOR = get_operator("xformers_flash", "flash_fwd")
     SUPPORTED_DEVICES: Set[str] = {"cuda"}
-    CUDA_MINIMUM_COMPUTE_CAPABILITY = (7, 5)
+    CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
     SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
-    SUPPORTED_MAX_K = 128
+    SUPPORTED_MAX_K = 256
     SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
         type(None),
         LowerTriangularMask,
@@ -202,16 +207,12 @@ class FwOp(AttentionFwOpBase):
     SUPPORTS_DROPOUT = True
     SUPPORTS_CUSTOM_SCALE = True
     SUPPORTS_DIFFERENT_VALUE_EMBED = False
-    NAME = "flshattF"
+    NAME = "flshattFv2"
 
     @classmethod
     def not_supported_reasons(cls, d: Inputs) -> List[str]:
         reasons = super(FwOp, cls).not_supported_reasons(d)
         check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
-        if d.device.type == "cuda":
-            device_capability = torch.cuda.get_device_capability(d.device)
-            if device_capability < (7, 5):
-                reasons.append("requires a GPU with compute capability > 7.5")
         return reasons
 
     @classmethod
@@ -293,23 +294,38 @@ class BwOp(AttentionBwOpBase):
     SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
     SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
     SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
-    NAME = "flshattB"
+    IS_DETERMINISTIC = False
+    NAME = "flshattBv2"
+
+    MAX_HEADDIM_SM8x = 192
+
+    @classmethod
+    def shape_not_supported_reasons(
+        cls, Mq: int, Mkv: int, K: int, Kv: int
+    ) -> List[str]:
+        reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
+        if (Mq % 128) or (Mkv % 128):
+            reasons.append(
+                "flashv2 beta: BW is incorrect when seqlen is not aligned on 128 "
+                "(https://github.com/Dao-AILab/flash-attention/issues/334)"
+            )
+        return reasons
 
     @classmethod
     def not_supported_reasons(cls, d: Inputs) -> List[str]:
         reasons = super(BwOp, cls).not_supported_reasons(d)
         check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
         if d.device.type == "cuda":
-            # We know `d.device` is cuda now
-            # d=128 is only supported on A100 for bw
-            # d > 64 is only supported on A100 for bw
+            # Due to limited shared-memory, some GPUs are limited in head dimension
             device_capability = torch.cuda.get_device_capability(d.device)
-            if device_capability < (7, 5):
-                reasons.append("requires a GPU with compute capability > 7.5")
             is_sm80_or_sm90 = device_capability in [(8, 0), (9, 0)]
-            if max(d.key.shape[-1], d.query.shape[-1]) > 64 and not is_sm80_or_sm90:
+            if (
+                max(d.key.shape[-1], d.query.shape[-1]) > cls.MAX_HEADDIM_SM8x
+                and not is_sm80_or_sm90
+            ):
                 reasons.append(
-                    "requires a GPU with compute capability 8.0 (A100) or 9.0 (H100) for 'query.shape[-1] > 64'"
+                    "requires a GPU with compute capability 8.0 "
+                    f"(A100) or 9.0 (H100) for 'query.shape[-1] > {cls.MAX_HEADDIM_SM8x}'"
                 )
         return reasons
 
@@ -324,6 +340,11 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
             cu_seqlens_k,
             max_seqlen_k,
         ) = _convert_input_format(inp)
+        assert ctx.lse.is_contiguous()
+        ctx_lse = ctx.lse
+        assert ctx_lse.shape[2] >= max_seqlen_q
+        if max_seqlen_q != ctx_lse.shape[2]:
+            ctx_lse = ctx_lse[:, :, :max_seqlen_q].contiguous()
         kernel_out_shape = [
             inp.query.shape[0],
             inp.query.shape[1],
@@ -368,7 +389,7 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
             inp.key,
             inp.value,
             ctx.out.reshape(kernel_out_shape),
-            ctx.lse,
+            ctx_lse,
             grads.dq,
             grads.dk,
             grads.dv,
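Two behavioural notes on the new Flash v2 backward above: it is flagged `IS_DETERMINISTIC = False`, so `not_supported_reasons` rejects it while `torch.use_deterministic_algorithms(True)` is active, and during the v2 beta it refuses sequence lengths that are not multiples of 128. A small sketch using the shape check (assumes this branch is installed; the classmethod is pure Python, so no GPU is needed to run it):

from xformers.ops import fmha

# Seqlens not aligned on 128 -> the backward reports the flashv2-beta restriction
print(fmha.flash.BwOp.shape_not_supported_reasons(Mq=130, Mkv=130, K=64, Kv=64))

# Aligned seqlens with K <= 256 -> no shape-related objections
print(fmha.flash.BwOp.shape_not_supported_reasons(Mq=128, Mkv=256, K=128, Kv=128))  # []
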
diff --git a/xformers/ops/fmha/triton.py b/xformers/ops/fmha/triton.py
index c2fe716f5..a5d7619d3 100644
--- a/xformers/ops/fmha/triton.py
+++ b/xformers/ops/fmha/triton.py
@@ -12,7 +12,8 @@
 from ... import _is_triton_available
 from ..common import register_operator
 
-if TYPE_CHECKING or _is_triton_available():
+# XXX: Disabled for now
+if TYPE_CHECKING or (False and _is_triton_available()):
     from ..._flash_attn.flash_attn_triton import (
         _flash_attn_backward,
         _flash_attn_forward,
@@ -75,6 +76,7 @@ class FwOp(AttentionFwOpBase):
 
     @classmethod
     def not_supported_reasons(cls, d: Inputs) -> List[str]:
+        return ["Triton implementation is disabled as we update to Flashv2"]
         reasons = super(FwOp, cls).not_supported_reasons(d)
         check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
         check_lastdim_alignment_stride1(reasons, "key", d.key, 8)
@@ -127,6 +129,7 @@ class BwOp(AttentionBwOpBase):
 
     @classmethod
     def not_supported_reasons(cls, d: Inputs) -> List[str]:
+        return ["Triton implementation is disabled as we update to Flashv2"]
         reasons = super(BwOp, cls).not_supported_reasons(d)
         check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
         check_lastdim_alignment_stride1(reasons, "key", d.key, 8)