From 1705ffe1bb64a51469f6b188125695f889ee0ef0 Mon Sep 17 00:00:00 2001
From: danthe3rd <43445237+danthe3rd@users.noreply.github.com>
Date: Wed, 11 Sep 2024 13:38:40 +0000
Subject: [PATCH] Build FAv3

ghstack-source-id: dc28f4bb77751ffc4f8d67a5ba088b262830659c
Pull Request resolved: https://github.com/fairinternal/xformers/pull/1217

__original_commit__ = fairinternal/xformers@eec3f5f8b70964524fd9919b3e927bda3f1fe7d4
---
 setup.py                    | 126 +++++++++++++++++++++++++++++++-----
 third_party/flash-attention |   2 +-
 xformers/ops/fmha/flash.py  |   4 +-
 xformers/ops/fmha/flash3.py |   6 +-
 4 files changed, 119 insertions(+), 19 deletions(-)

diff --git a/setup.py b/setup.py
index 9c3314673..4f0da97fb 100644
--- a/setup.py
+++ b/setup.py
@@ -158,7 +158,16 @@ def get_hip_version(rocm_dir) -> Optional[str]:
     return None
 
 
-def get_flash_attention_nvcc_archs_flags(cuda_version: int):
+######################################
+# FLASH-ATTENTION v2
+######################################
+# Supports `9.0`, `9.0+PTX`, `9.0a+PTX` etc...
+PARSE_CUDA_ARCH_RE = re.compile(
+    r"(?P<major>[0-9]+)\.(?P<minor>[0-9])(?P<suffix>[a-zA-Z]{0,1})(?P<ptx>\+PTX){0,1}"
+)
+
+
+def get_flash_attention2_nvcc_archs_flags(cuda_version: int):
     # XXX: Not supported on windows for cuda<12
     # https://github.com/Dao-AILab/flash-attention/issues/345
     if platform.system() != "Linux" and cuda_version < 1200:
@@ -177,10 +186,6 @@ def get_flash_attention_nvcc_archs_flags(cuda_version: int):
     if os.getenv("XFORMERS_DISABLE_FLASH_ATTN", "0") != "0":
         return []
 
-    # Supports `9.0`, `9.0+PTX`, `9.0a+PTX` etc...
-    PARSE_CUDA_ARCH_RE = re.compile(
-        r"(?P<major>[0-9]+)\.(?P<minor>[0-9])(?P<suffix>[a-zA-Z]{0,1})(?P<ptx>\+PTX){0,1}"
-    )
     archs_list = os.environ.get("TORCH_CUDA_ARCH_LIST", DEFAULT_ARCHS_LIST)
     nvcc_archs_flags = []
     for arch in archs_list.replace(" ", ";").split(";"):
@@ -205,8 +210,8 @@ def get_flash_attention_nvcc_archs_flags(cuda_version: int):
     return nvcc_archs_flags
 
 
-def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
-    nvcc_archs_flags = get_flash_attention_nvcc_archs_flags(cuda_version)
+def get_flash_attention2_extensions(cuda_version: int, extra_compile_args):
+    nvcc_archs_flags = get_flash_attention2_nvcc_archs_flags(cuda_version)
     if not nvcc_archs_flags:
         return []
 
@@ -260,6 +265,101 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
     ]
 
 
+######################################
+# FLASH-ATTENTION v3
+######################################
+def get_flash_attention3_nvcc_archs_flags(cuda_version: int):
+    if os.getenv("XFORMERS_DISABLE_FLASH_ATTN", "0") != "0":
+        return []
+    if platform.system() != "Linux" or cuda_version < 1203:
+        return []
+    archs_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
+    if archs_list is None:
+        if torch.cuda.get_device_capability("cuda") != (9, 0):
+            return []
+        archs_list = "9.0a"
+    nvcc_archs_flags = []
+    for arch in archs_list.replace(" ", ";").split(";"):
+        match = PARSE_CUDA_ARCH_RE.match(arch)
+        assert match is not None, f"Invalid sm version: {arch}"
+        num = 10 * int(match.group("major")) + int(match.group("minor"))
+        if num != 90:  # only support Sm90
+            continue
+        suffix = match.group("suffix")
+        nvcc_archs_flags.append(
+            f"-gencode=arch=compute_{num}{suffix},code=sm_{num}{suffix}"
+        )
+        if match.group("ptx") is not None:
+            nvcc_archs_flags.append(
+                f"-gencode=arch=compute_{num}{suffix},code=compute_{num}{suffix}"
+            )
+    return nvcc_archs_flags
+
+
+def get_flash_attention3_extensions(cuda_version: int, extra_compile_args):
+    nvcc_archs_flags = get_flash_attention3_nvcc_archs_flags(cuda_version)
+
+    if not nvcc_archs_flags:
+        return []
+
+    flash_root = os.path.join(this_dir, "third_party", "flash-attention")
+    cutlass_inc = os.path.join(flash_root, "csrc", "cutlass", "include")
+    if not os.path.exists(flash_root) or not os.path.exists(cutlass_inc):
+        raise RuntimeError(
+            "flashattention submodule not found. Did you forget "
+            "to run `git submodule update --init --recursive` ?"
+        )
+
+    sources = [
+        str(Path(f).relative_to(flash_root))
+        for f in glob.glob(os.path.join(flash_root, "hopper", "*.cu"))
+        + glob.glob(os.path.join(flash_root, "hopper", "*.cpp"))
+    ]
+    sources = [s for s in sources if "flash_bwd_hdim256_fp16_sm90.cu" not in s]
+    return [
+        CUDAExtension(
+            name="xformers._C_flashattention3",
+            sources=[os.path.join(flash_root, path) for path in sources],
+            extra_compile_args={
+                "cxx": extra_compile_args.get("cxx", []),
+                "nvcc": extra_compile_args.get("nvcc", [])
+                + [
+                    "-O3",
+                    # "-O0",
+                    "-std=c++17",
+                    "-U__CUDA_NO_HALF_OPERATORS__",
+                    "-U__CUDA_NO_HALF_CONVERSIONS__",
+                    "-U__CUDA_NO_BFLOAT16_OPERATORS__",
+                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+                    "-U__CUDA_NO_BFLOAT162_OPERATORS__",
+                    "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
+                    "--expt-relaxed-constexpr",
+                    "--expt-extended-lambda",
+                    "--use_fast_math",
+                    # "--ptxas-options=-v",  # printing out number of registers
+                    "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage",
+                    # "-lineinfo",  # xformers: save binary size
+                    "-DCUTLASS_DEBUG_TRACE_LEVEL=0",  # Can toggle for debugging
+                    "-DNDEBUG",  # Important, otherwise performance is severely impacted
+                    "-DQBLKSIZE=128",
+                    "-DKBLKSIZE=128",
+                    "-DCTA256",
+                    "-DDQINRMEM",
+                ]
+                + nvcc_archs_flags
+                + get_extra_nvcc_flags_for_build_type(cuda_version),
+            },
+            include_dirs=[
+                p.absolute()
+                for p in [
+                    Path(flash_root) / "csrc" / "cutlass" / "include",
+                    Path(flash_root) / "hopper",
+                ]
+            ],
+        )
+    ]
+
+
 def rename_cpp_cu(cpp_files):
     for entry in cpp_files:
         shutil.copy(entry, os.path.splitext(entry)[0] + ".cu")
@@ -380,11 +480,11 @@ def get_extensions():
         ]
         extra_compile_args["nvcc"] = nvcc_flags
 
-        flash_extensions = []
         xformers_pt_flash_attn = os.getenv("XFORMERS_PT_FLASH_ATTN")
 
         # check if the current device supports flash_attention
-        nvcc_archs_flags = get_flash_attention_nvcc_archs_flags(cuda_version)
+        flash_version = get_flash_version()
+        nvcc_archs_flags = get_flash_attention2_nvcc_archs_flags(cuda_version)
         if not nvcc_archs_flags:
             if xformers_pt_flash_attn == "1":
                 raise ValueError(
@@ -399,16 +499,12 @@
             ) and attn_compat_module.is_pt_flash_compatible(
                 force=xformers_pt_flash_attn == "1"
             ):
-                flash_version = torch.nn.attention._get_flash_version() + "-pt"
                 use_pt_flash = True
             else:
-                flash_extensions = get_flash_attention_extensions(
+                ext_modules += get_flash_attention2_extensions(
                     cuda_version=cuda_version, extra_compile_args=extra_compile_args
                 )
-                if flash_extensions:
-                    flash_version = get_flash_version()
-
-        ext_modules += flash_extensions
+        ext_modules += get_flash_attention3_extensions(cuda_version, extra_compile_args)
 
         # NOTE: This should not be applied to Flash-Attention
         # see https://github.com/Dao-AILab/flash-attention/issues/359
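
For reference, a minimal standalone sketch (not part of this patch; the helper
name fa3_gencode_flags is made up here) of the Sm90 gating added above: it
reuses the same PARSE_CUDA_ARCH_RE pattern to turn TORCH_CUDA_ARCH_LIST
entries into -gencode flags, so only 9.0 / 9.0a / 9.0a+PTX entries contribute
to the FAv3 build.

    import re

    # Same pattern as PARSE_CUDA_ARCH_RE in the setup.py hunk above.
    PARSE_CUDA_ARCH_RE = re.compile(
        r"(?P<major>[0-9]+)\.(?P<minor>[0-9])(?P<suffix>[a-zA-Z]{0,1})(?P<ptx>\+PTX){0,1}"
    )

    def fa3_gencode_flags(archs_list: str) -> list:
        flags = []
        for arch in archs_list.replace(" ", ";").split(";"):
            match = PARSE_CUDA_ARCH_RE.match(arch)
            assert match is not None, f"Invalid sm version: {arch}"
            num = 10 * int(match.group("major")) + int(match.group("minor"))
            if num != 90:  # FAv3 only targets Sm90 (Hopper)
                continue
            suffix = match.group("suffix")
            flags.append(f"-gencode=arch=compute_{num}{suffix},code=sm_{num}{suffix}")
            if match.group("ptx") is not None:
                flags.append(
                    f"-gencode=arch=compute_{num}{suffix},code=compute_{num}{suffix}"
                )
        return flags

    print(fa3_gencode_flags("8.0;9.0a+PTX"))
    # ['-gencode=arch=compute_90a,code=sm_90a',
    #  '-gencode=arch=compute_90a,code=compute_90a']
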
diff --git a/third_party/flash-attention b/third_party/flash-attention
index 418d67719..bdf733be5 160000
--- a/third_party/flash-attention
+++ b/third_party/flash-attention
@@ -1 +1 @@
-Subproject commit 418d677192b483dfc1decfdf9aadca40b402485d
+Subproject commit bdf733be55f0b323a8cf7cc6745a81c3f43cd7f0
diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py
index 3631f5505..f598dbb74 100644
--- a/xformers/ops/fmha/flash.py
+++ b/xformers/ops/fmha/flash.py
@@ -79,6 +79,7 @@
         raise
     assert is_pt_flash_compatible(force=True)
     FLASH_VERSION = torch.nn.attention._get_flash_version()  # type: ignore
+    FLASH_VERSION = f"v{FLASH_VERSION}"
     VARLEN_LSE_PACKED = False
    _USE_PT_FLASH_ATTN = True
 
@@ -551,6 +552,7 @@ def _post_process_lse(
     lse: torch.Tensor,
     inp: Inputs,
     original_query_shape: Tuple[int, ...],
+    varlen_lse_packed: bool = VARLEN_LSE_PACKED,
 ) -> torch.Tensor:
     # Easy case: no varlen
     if not isinstance(inp.attn_bias, VARLEN_BIASES):
@@ -560,7 +562,7 @@ def _post_process_lse(
         return lse
 
     # Already packed: just bring back the batch dimension
-    if VARLEN_LSE_PACKED:
+    if varlen_lse_packed:
         if len(original_query_shape) == 5:
             # (1, G, H, total_q)
             return lse.unflatten(0, original_query_shape[2:4]).unsqueeze(0)
diff --git a/xformers/ops/fmha/flash3.py b/xformers/ops/fmha/flash3.py
index e03a9fcea..88e8e29fc 100644
--- a/xformers/ops/fmha/flash3.py
+++ b/xformers/ops/fmha/flash3.py
@@ -80,7 +80,7 @@ def mha_fwd(
             softmax_lse,
             p,
         ) = _C_flashattention3.fwd(
-            query, key, value, None, softmax_scale, is_causal
+            query, key, value, None, softmax_scale, None, None, None, is_causal
         )
     else:
         out, q, k, v, out_padded, softmax_lse = _C_flashattention3.varlen_fwd(
@@ -316,7 +316,9 @@ def apply(
             return out, None
         ctx = Context(
             out=out,
-            lse=_post_process_lse(softmax_lse, inp, tuple(original_query_shape)),
+            lse=_post_process_lse(
+                softmax_lse, inp, tuple(original_query_shape), varlen_lse_packed=True
+            ),
         )
         return (out, ctx)
 
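
A small sketch (not part of this patch) of what the new varlen_lse_packed=True
path does for the 5D query case, assuming xformers' (B, M, G, H, K) query
layout: the packed LSE of shape (G * H, total_q) is unflattened back to
(1, G, H, total_q), mirroring the `if varlen_lse_packed:` branch of
_post_process_lse in the flash.py hunk above.

    import torch

    G, H, total_q = 2, 4, 16
    original_query_shape = (1, total_q, G, H, 64)  # (B, M, G, H, K), B == 1 for varlen
    lse_packed = torch.randn(G * H, total_q)

    # Same reshaping as the packed 5D branch in _post_process_lse.
    lse = lse_packed.unflatten(0, original_query_shape[2:4]).unsqueeze(0)
    print(lse.shape)  # torch.Size([1, 2, 4, 16])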