
Commit

Build FAv3
ghstack-source-id: dc28f4bb77751ffc4f8d67a5ba088b262830659c
Pull Request resolved: fairinternal/xformers#1217

__original_commit__ = fairinternal/xformers@eec3f5f
danthe3rd authored and xFormers Bot committed Sep 11, 2024
1 parent 0dd19df commit 1705ffe
Showing 4 changed files with 119 additions and 19 deletions.
126 changes: 111 additions & 15 deletions setup.py
@@ -158,7 +158,16 @@ def get_hip_version(rocm_dir) -> Optional[str]:
     return None
 
 
-def get_flash_attention_nvcc_archs_flags(cuda_version: int):
+######################################
+# FLASH-ATTENTION v2
+######################################
+# Supports `9.0`, `9.0+PTX`, `9.0a+PTX` etc...
+PARSE_CUDA_ARCH_RE = re.compile(
+    r"(?P<major>[0-9]+)\.(?P<minor>[0-9])(?P<suffix>[a-zA-Z]{0,1})(?P<ptx>\+PTX){0,1}"
+)
+
+
+def get_flash_attention2_nvcc_archs_flags(cuda_version: int):
     # XXX: Not supported on windows for cuda<12
     # https://github.com/Dao-AILab/flash-attention/issues/345
     if platform.system() != "Linux" and cuda_version < 1200:
@@ -177,10 +186,6 @@ def get_flash_attention_nvcc_archs_flags(cuda_version: int):
     if os.getenv("XFORMERS_DISABLE_FLASH_ATTN", "0") != "0":
         return []
 
-    # Supports `9.0`, `9.0+PTX`, `9.0a+PTX` etc...
-    PARSE_CUDA_ARCH_RE = re.compile(
-        r"(?P<major>[0-9]+)\.(?P<minor>[0-9])(?P<suffix>[a-zA-Z]{0,1})(?P<ptx>\+PTX){0,1}"
-    )
     archs_list = os.environ.get("TORCH_CUDA_ARCH_LIST", DEFAULT_ARCHS_LIST)
     nvcc_archs_flags = []
     for arch in archs_list.replace(" ", ";").split(";"):
@@ -205,8 +210,8 @@ def get_flash_attention_nvcc_archs_flags(cuda_version: int):
     return nvcc_archs_flags
 
 
-def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
-    nvcc_archs_flags = get_flash_attention_nvcc_archs_flags(cuda_version)
+def get_flash_attention2_extensions(cuda_version: int, extra_compile_args):
+    nvcc_archs_flags = get_flash_attention2_nvcc_archs_flags(cuda_version)
 
     if not nvcc_archs_flags:
         return []
@@ -260,6 +265,101 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
     ]
 
 
+######################################
+# FLASH-ATTENTION v3
+######################################
+def get_flash_attention3_nvcc_archs_flags(cuda_version: int):
+    if os.getenv("XFORMERS_DISABLE_FLASH_ATTN", "0") != "0":
+        return []
+    if platform.system() != "Linux" or cuda_version < 1203:
+        return []
+    archs_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
+    if archs_list is None:
+        if torch.cuda.get_device_capability("cuda") != (9, 0):
+            return []
+        archs_list = "9.0a"
+    nvcc_archs_flags = []
+    for arch in archs_list.replace(" ", ";").split(";"):
+        match = PARSE_CUDA_ARCH_RE.match(arch)
+        assert match is not None, f"Invalid sm version: {arch}"
+        num = 10 * int(match.group("major")) + int(match.group("minor"))
+        if num != 90:  # only support Sm90
+            continue
+        suffix = match.group("suffix")
+        nvcc_archs_flags.append(
+            f"-gencode=arch=compute_{num}{suffix},code=sm_{num}{suffix}"
+        )
+        if match.group("ptx") is not None:
+            nvcc_archs_flags.append(
+                f"-gencode=arch=compute_{num}{suffix},code=compute_{num}{suffix}"
+            )
+    return nvcc_archs_flags
+
+
+def get_flash_attention3_extensions(cuda_version: int, extra_compile_args):
+    nvcc_archs_flags = get_flash_attention3_nvcc_archs_flags(cuda_version)
+
+    if not nvcc_archs_flags:
+        return []
+
+    flash_root = os.path.join(this_dir, "third_party", "flash-attention")
+    cutlass_inc = os.path.join(flash_root, "csrc", "cutlass", "include")
+    if not os.path.exists(flash_root) or not os.path.exists(cutlass_inc):
+        raise RuntimeError(
+            "flashattention submodule not found. Did you forget "
+            "to run `git submodule update --init --recursive` ?"
+        )
+
+    sources = [
+        str(Path(f).relative_to(flash_root))
+        for f in glob.glob(os.path.join(flash_root, "hopper", "*.cu"))
+        + glob.glob(os.path.join(flash_root, "hopper", "*.cpp"))
+    ]
+    sources = [s for s in sources if "flash_bwd_hdim256_fp16_sm90.cu" not in s]
+    return [
+        CUDAExtension(
+            name="xformers._C_flashattention3",
+            sources=[os.path.join(flash_root, path) for path in sources],
+            extra_compile_args={
+                "cxx": extra_compile_args.get("cxx", []),
+                "nvcc": extra_compile_args.get("nvcc", [])
+                + [
+                    "-O3",
+                    # "-O0",
+                    "-std=c++17",
+                    "-U__CUDA_NO_HALF_OPERATORS__",
+                    "-U__CUDA_NO_HALF_CONVERSIONS__",
+                    "-U__CUDA_NO_BFLOAT16_OPERATORS__",
+                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+                    "-U__CUDA_NO_BFLOAT162_OPERATORS__",
+                    "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
+                    "--expt-relaxed-constexpr",
+                    "--expt-extended-lambda",
+                    "--use_fast_math",
+                    # "--ptxas-options=-v", # printing out number of registers
+                    "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage",
+                    # "-lineinfo", # xformers: save binary size
+                    "-DCUTLASS_DEBUG_TRACE_LEVEL=0", # Can toggle for debugging
+                    "-DNDEBUG", # Important, otherwise performance is severely impacted
+                    "-DQBLKSIZE=128",
+                    "-DKBLKSIZE=128",
+                    "-DCTA256",
+                    "-DDQINRMEM",
+                ]
+                + nvcc_archs_flags
+                + get_extra_nvcc_flags_for_build_type(cuda_version),
+            },
+            include_dirs=[
+                p.absolute()
+                for p in [
+                    Path(flash_root) / "csrc" / "cutlass" / "include",
+                    Path(flash_root) / "hopper",
+                ]
+            ],
+        )
+    ]
+
+
 def rename_cpp_cu(cpp_files):
     for entry in cpp_files:
         shutil.copy(entry, os.path.splitext(entry)[0] + ".cu")
@@ -380,11 +480,11 @@ def get_extensions():
     ]
     extra_compile_args["nvcc"] = nvcc_flags
 
-    flash_extensions = []
    xformers_pt_flash_attn = os.getenv("XFORMERS_PT_FLASH_ATTN")
 
     # check if the current device supports flash_attention
-    nvcc_archs_flags = get_flash_attention_nvcc_archs_flags(cuda_version)
+    flash_version = get_flash_version()
+    nvcc_archs_flags = get_flash_attention2_nvcc_archs_flags(cuda_version)
     if not nvcc_archs_flags:
         if xformers_pt_flash_attn == "1":
             raise ValueError(
@@ -399,16 +499,12 @@ def get_extensions():
     ) and attn_compat_module.is_pt_flash_compatible(
         force=xformers_pt_flash_attn == "1"
     ):
-        flash_version = torch.nn.attention._get_flash_version() + "-pt"
         use_pt_flash = True
     else:
-        flash_extensions = get_flash_attention_extensions(
+        ext_modules += get_flash_attention2_extensions(
             cuda_version=cuda_version, extra_compile_args=extra_compile_args
         )
-        if flash_extensions:
-            flash_version = get_flash_version()
-
-    ext_modules += flash_extensions
+    ext_modules += get_flash_attention3_extensions(cuda_version, extra_compile_args)
 
     # NOTE: This should not be applied to Flash-Attention
     # see https://github.com/Dao-AILab/flash-attention/issues/359
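For illustration, here is a small standalone sketch of how the shared `PARSE_CUDA_ARCH_RE` pattern and the Sm90-only filter in `get_flash_attention3_nvcc_archs_flags` turn a `TORCH_CUDA_ARCH_LIST` value into nvcc `-gencode` flags. The helper name and the example arch lists below are illustrative only; the actual logic is the one added in the setup.py diff above.

```python
import re

# Same pattern as in the diff: accepts `9.0`, `9.0+PTX`, `9.0a+PTX`, etc.
PARSE_CUDA_ARCH_RE = re.compile(
    r"(?P<major>[0-9]+)\.(?P<minor>[0-9])(?P<suffix>[a-zA-Z]{0,1})(?P<ptx>\+PTX){0,1}"
)


def sketch_fa3_gencode_flags(archs_list: str) -> list:
    """Simplified stand-in for get_flash_attention3_nvcc_archs_flags:
    keep only Sm90 entries and emit the matching nvcc -gencode flags."""
    flags = []
    for arch in archs_list.replace(" ", ";").split(";"):
        match = PARSE_CUDA_ARCH_RE.match(arch)
        assert match is not None, f"Invalid sm version: {arch}"
        num = 10 * int(match.group("major")) + int(match.group("minor"))
        if num != 90:  # FAv3 kernels are Hopper (Sm90) only
            continue
        suffix = match.group("suffix")
        flags.append(f"-gencode=arch=compute_{num}{suffix},code=sm_{num}{suffix}")
        if match.group("ptx") is not None:
            flags.append(
                f"-gencode=arch=compute_{num}{suffix},code=compute_{num}{suffix}"
            )
    return flags


# "9.0a" is the default the diff falls back to when TORCH_CUDA_ARCH_LIST is
# unset and an Sm90 GPU is detected.
print(sketch_fa3_gencode_flags("9.0a"))
# ['-gencode=arch=compute_90a,code=sm_90a']
print(sketch_fa3_gencode_flags("8.0;9.0a+PTX"))
# ['-gencode=arch=compute_90a,code=sm_90a',
#  '-gencode=arch=compute_90a,code=compute_90a']
```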
2 changes: 1 addition & 1 deletion third_party/flash-attention
Submodule flash-attention updated 40 files
+33 −36 README.md
+43 −0 benchmarks/benchmark_gemm.py
+3 −2 flash_attn/losses/cross_entropy.py
+1 −1 flash_attn/modules/mha.py
+43 −36 flash_attn/ops/triton/cross_entropy.py
+35 −9 flash_attn/ops/triton/layer_norm.py
+59 −18 hopper/benchmark_attn.py
+351 −0 hopper/benchmark_flash_attention_fp8.py
+270 −0 hopper/epilogue_bwd_sm90_tma.hpp
+82 −5 hopper/epilogue_fwd_sm90_tma.hpp
+4 −0 hopper/flash.h
+294 −50 hopper/flash_api.cpp
+62 −31 hopper/flash_attn_interface.py
+9 −0 hopper/flash_bwd_hdim128_bf16_sm90.cu
+9 −0 hopper/flash_bwd_hdim64_bf16_sm90.cu
+9 −0 hopper/flash_bwd_hdim96_bf16_sm90.cu
+9 −0 hopper/flash_bwd_hdim96_fp16_sm90.cu
+242 −1,976 hopper/flash_bwd_kernel.h
+159 −204 hopper/flash_bwd_launch_template.h
+251 −0 hopper/flash_bwd_postprocess_kernel.h
+228 −421 hopper/flash_bwd_preprocess_kernel.h
+9 −0 hopper/flash_fwd_hdim128_e4m3_sm90.cu
+9 −0 hopper/flash_fwd_hdim256_e4m3_sm90.cu
+9 −0 hopper/flash_fwd_hdim64_e4m3_sm90.cu
+202 −3 hopper/flash_fwd_kernel.h
+101 −10 hopper/flash_fwd_launch_template.h
+146 −2 hopper/kernel_traits.h
+841 −0 hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp
+581 −51 hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp
+19 −1 hopper/named_barrier.hpp
+11 −9 hopper/setup.py
+22 −13 hopper/softmax.h
+134 −78 hopper/test_flash_attn.py
+14 −10 hopper/tile_scheduler.hpp
+92 −0 hopper/tile_scheduler_bwd.hpp
+33 −1 hopper/utils.h
+2 −2 setup.py
+21 −6 tests/losses/test_cross_entropy.py
+21 −5 tests/losses/test_cross_entropy_parallel.py
+11 −2 tests/test_util.py
4 changes: 3 additions & 1 deletion xformers/ops/fmha/flash.py
@@ -79,6 +79,7 @@
         raise
     assert is_pt_flash_compatible(force=True)
     FLASH_VERSION = torch.nn.attention._get_flash_version() # type: ignore
+    FLASH_VERSION = f"v{FLASH_VERSION}"
     VARLEN_LSE_PACKED = False
     _USE_PT_FLASH_ATTN = True
 
@@ -551,6 +552,7 @@ def _post_process_lse(
     lse: torch.Tensor,
     inp: Inputs,
     original_query_shape: Tuple[int, ...],
+    varlen_lse_packed: bool = VARLEN_LSE_PACKED,
 ) -> torch.Tensor:
     # Easy case: no varlen
     if not isinstance(inp.attn_bias, VARLEN_BIASES):
@@ -560,7 +562,7 @@ def _post_process_lse(
         return lse
 
     # Already packed: just bring back the batch dimension
-    if VARLEN_LSE_PACKED:
+    if varlen_lse_packed:
         if len(original_query_shape) == 5:
             # (1, G, H, total_q)
             return lse.unflatten(0, original_query_shape[2:4]).unsqueeze(0)
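As a rough illustration of what the new `varlen_lse_packed` toggle controls in `_post_process_lse`: for varlen inputs with a packed log-sum-exp, the batch dimension is restored by unflattening the leading `G * H` axis. The shapes below are inferred from the comments visible in the diff and are an assumption, not a documented contract of the function.

```python
import torch

# Assumed shapes (5-dim query case from the diff): the query is
# (1, total_q, G, H, K) and the packed LSE comes back as (G * H, total_q).
G, H, total_q = 2, 4, 128
original_query_shape = (1, total_q, G, H, 64)
lse_packed = torch.randn(G * H, total_q)

# Mirrors the packed branch shown above: bring back the batch dimension,
# giving (1, G, H, total_q).
lse_out = lse_packed.unflatten(0, original_query_shape[2:4]).unsqueeze(0)
print(lse_out.shape)  # torch.Size([1, 2, 4, 128])
```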
6 changes: 4 additions & 2 deletions xformers/ops/fmha/flash3.py
@@ -80,7 +80,7 @@ def mha_fwd(
             softmax_lse,
             p,
         ) = _C_flashattention3.fwd(
-            query, key, value, None, softmax_scale, is_causal
+            query, key, value, None, softmax_scale, None, None, None, is_causal
         )
     else:
         out, q, k, v, out_padded, softmax_lse = _C_flashattention3.varlen_fwd(
@@ -316,7 +316,9 @@ def apply(
             return out, None
         ctx = Context(
             out=out,
-            lse=_post_process_lse(softmax_lse, inp, tuple(original_query_shape)),
+            lse=_post_process_lse(
+                softmax_lse, inp, tuple(original_query_shape), varlen_lse_packed=True
+            ),
         )
         return (out, ctx)
 
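For completeness, a hedged usage sketch of the resulting build: once the `xformers._C_flashattention3` extension has been compiled, the FAv3 kernels can be requested through the regular attention entry point. The explicit op pair below assumes the `FwOp`/`BwOp` names exported by `xformers/ops/fmha/flash3.py` and an Sm90 GPU; adjust if your build or version differs.

```python
import torch

import xformers.ops as xops
from xformers.ops.fmha import flash3  # module touched by this commit

# Assumes a Hopper (Sm90) GPU and a source build that produced
# xformers._C_flashattention3; the op names are an assumption based on the
# module layout, not something this commit documents.
q, k, v = (
    torch.randn(1, 4096, 8, 128, device="cuda", dtype=torch.bfloat16)
    for _ in range(3)
)
out = xops.memory_efficient_attention(q, k, v, op=(flash3.FwOp, flash3.BwOp))
print(out.shape)  # torch.Size([1, 4096, 8, 128])
```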
