Update FlashAttention v2
ghstack-source-id: 380addf7adba5c555a1a993c03f3558becd390ac
Pull Request resolved: https://github.com/fairinternal/xformers/pull/720

__original_commit__ = fairinternal/xformers@3aaa74860c02bd802ca6636fc55d78300bbc93fc
danthe3rd authored and xFormers Bot committed Jul 20, 2023
1 parent b206d71 commit cfea89f
Showing 8 changed files with 134 additions and 76 deletions.
6 changes: 3 additions & 3 deletions .gitmodules
@@ -1,6 +1,6 @@
[submodule "third_party/flash-attention"]
path = third_party/flash-attention
url = https://github.com/HazyResearch/flash-attention.git
[submodule "third_party/cutlass"]
path = third_party/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "third_party/flash-attention"]
path = third_party/flash-attention
url = https://github.com/Dao-AILab/flash-attention.git
57 changes: 33 additions & 24 deletions setup.py
@@ -104,14 +104,18 @@ def get_cuda_version(cuda_dir) -> int:


def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
# XXX: Not supported on windows yet
# https://github.com/Dao-AILab/flash-attention/issues/345
if platform.system() != "Linux":
return []
# Figure out default archs to target
DEFAULT_ARCHS_LIST = ""
if cuda_version >= 1108:
DEFAULT_ARCHS_LIST = "7.5;8.0;8.6;9.0"
DEFAULT_ARCHS_LIST = "8.0;8.6;9.0"
elif cuda_version > 1100:
DEFAULT_ARCHS_LIST = "7.5;8.0;8.6"
DEFAULT_ARCHS_LIST = "8.0;8.6"
elif cuda_version == 1100:
DEFAULT_ARCHS_LIST = "7.5;8.0"
DEFAULT_ARCHS_LIST = "8.0"
else:
return []

@@ -125,59 +129,59 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):

arch_arr = arch.split(".")
num = 10 * int(arch_arr[0]) + int(arch_arr[1].partition("+")[0])
# Need at least 7.5
if num < 75:
# Need at least 8.0
if num < 80:
continue
nvcc_archs_flags.append(f"-gencode=arch=compute_{num},code=sm_{num}")
if arch.endswith("+PTX"):
nvcc_archs_flags.append(f"-gencode=arch=compute_{num},code=compute_{num}")
if not nvcc_archs_flags:
return []

nvcc_windows_flags = []
if platform.system() == "Windows":
nvcc_windows_flags = ["-Xcompiler", "/permissive-"]

flash_root = os.path.join(this_dir, "third_party", "flash-attention")
if not os.path.exists(flash_root):
raise RuntimeError(
"flashattention submodule not found. Did you forget "
"to run `git submodule update --init --recursive` ?"
)

flash_root = os.path.join(this_dir, "third_party", "flash-attention")
sources = ["csrc/flash_attn/flash_api.cpp"]
for f in glob.glob(os.path.join(flash_root, "csrc", "flash_attn", "src", "*.cu")):
sources.append(str(Path(f).relative_to(flash_root)))
return [
CUDAExtension(
name="xformers._C_flashattention",
sources=[
os.path.join("third_party", "flash-attention", path)
for path in [
"csrc/flash_attn/fmha_api.cpp",
"csrc/flash_attn/src/fmha_fwd_hdim32.cu",
"csrc/flash_attn/src/fmha_fwd_hdim64.cu",
"csrc/flash_attn/src/fmha_fwd_hdim128.cu",
"csrc/flash_attn/src/fmha_bwd_hdim32.cu",
"csrc/flash_attn/src/fmha_bwd_hdim64.cu",
"csrc/flash_attn/src/fmha_bwd_hdim128.cu",
"csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu",
"csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
]
],
sources=[os.path.join(flash_root, path) for path in sources],
extra_compile_args={
**extra_compile_args,
"nvcc": extra_compile_args.get("nvcc", [])
+ [
"-O3",
"-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
"--ptxas-options=-v",
]
+ nvcc_archs_flags
+ nvcc_windows_flags
+ get_extra_nvcc_flags_for_build_type(),
},
include_dirs=[
p.absolute()
for p in [
Path(flash_root) / "csrc" / "flash_attn",
Path(flash_root) / "csrc" / "flash_attn" / "src",
Path(this_dir) / "third_party" / "cutlass" / "include",
Path(flash_root) / "csrc" / "cutlass" / "include",
]
],
)
@@ -230,10 +234,6 @@ def get_extensions():
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--extended-lambda",
"-D_ENABLE_EXTENDED_ALIGNED_STORAGE",
# Workaround for a regression with nvcc > 11.6
# See https://github.com/facebookresearch/xformers/issues/712
"--ptxas-options=-O2",
"--ptxas-options=-allow-expensive-optimizations=true",
] + get_extra_nvcc_flags_for_build_type()
if os.getenv("XFORMERS_ENABLE_DEBUG_ASSERTIONS", "0") != "1":
nvcc_flags.append("-DNDEBUG")
@@ -259,6 +259,15 @@ def get_extensions():
cuda_version=cuda_version, extra_compile_args=extra_compile_args
)

# NOTE: This should not be applied to Flash-Attention
# see https://github.com/Dao-AILab/flash-attention/issues/359
extra_compile_args["nvcc"] += [
# Workaround for a regression with nvcc > 11.6
# See https://github.com/facebookresearch/xformers/issues/712
"--ptxas-options=-O2",
"--ptxas-options=-allow-expensive-optimizations=true",
]

ext_modules.append(
extension(
"xformers._C",
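The arch handling above is the heart of the build change: FlashAttention v2 only ships sm80+ kernels, so the default arch lists drop 7.5 and anything below compute capability 8.0 is skipped when parsing an arch list, while a "+PTX" suffix still adds a compute_XX target. A minimal self-contained sketch of that mapping, assuming the same semantics as the hunk above (the function name and min_arch argument are illustrative, not part of setup.py):

def archs_to_gencode_flags(archs_list: str, min_arch: int = 80) -> list:
    # Translate an arch list string such as the DEFAULT_ARCHS_LIST values
    # above (e.g. "8.0;9.0+PTX") into nvcc -gencode flags, skipping
    # architectures below min_arch.
    flags = []
    for arch in archs_list.split(";"):
        if not arch:
            continue
        major, _, rest = arch.partition(".")
        num = 10 * int(major) + int(rest.partition("+")[0])
        if num < min_arch:
            continue  # FlashAttention v2 kernels need sm80 or newer
        flags.append(f"-gencode=arch=compute_{num},code=sm_{num}")
        if arch.endswith("+PTX"):
            flags.append(f"-gencode=arch=compute_{num},code=compute_{num}")
    return flags

# archs_to_gencode_flags("8.0;9.0+PTX")
# -> ['-gencode=arch=compute_80,code=sm_80',
#     '-gencode=arch=compute_90,code=sm_90',
#     '-gencode=arch=compute_90,code=compute_90']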
21 changes: 18 additions & 3 deletions tests/test_mem_eff_attention.py
@@ -82,7 +82,7 @@ def generate_test_shapes_B_Mq_Mkv_H_K_Kv(op):
shapes = []
for B in op._TEST_BATCH_SIZES:
for Mq in [32, 256]:
for Mkv in [32, 64, 256]:
for Mkv in [32, 64, 256, 1024]:
for K in op._TEST_K:
shapes.append((B, Mq, Mkv, 1, K, K))
Mq = 256
@@ -93,7 +93,7 @@ def generate_test_shapes_B_Mq_Mkv_H_K_Kv(op):
for M in [2, 3, 15, 31, 32, 34, 68, 72, 90, 132, 136]:
shapes.append((B, M, Mkv, H, K, K))
shapes.append((B, Mq, M, H, K, K))
for _K in [1, 2, 3, 31, 34, 36, 38, 40, 64, 256 + 2, 256 + 8, 512]:
for _K in [1, 2, 3, 31, 34, 36, 38, 40, 64, 80, 160, 256 + 2, 256 + 8, 512]:
if _K <= op.SUPPORTED_MAX_K:
shapes.append((B, Mq, Mkv, H, _K, _K))
# Different value for K / Kv
@@ -108,6 +108,17 @@ def generate_test_shapes_B_Mq_Mkv_H_K_Kv(op):
# Some number of heads
for H in [3, 5, 12]:
shapes.append((max(1, B // H), Mq, Mkv, H, K, K))
# Filter-out not supported shapes
shapes = [
shape
for shape in shapes
if len(
op.shape_not_supported_reasons(
Mq=shape[1], Mkv=shape[2], K=shape[4], Kv=shape[5]
)
)
== 0
]
# Add some random shapes
if op in [
fmha.cutlass.FwOp,
@@ -116,7 +127,8 @@ def generate_test_shapes_B_Mq_Mkv_H_K_Kv(op):
]:
K_CHOICES = [8 * i for i in range(1, 256 // 8)]
r = random.Random(0)
for _ in range(20):
found_count = 0
while found_count < 20:
B = r.randint(1, 400)
Mq = r.randint(1, 500)
Mkv = r.randint(1, 500)
@@ -126,6 +138,9 @@
Kv = r.choice(K_CHOICES)
if not op.SUPPORTS_DIFFERENT_VALUE_EMBED:
Kv = K
if len(op.shape_not_supported_reasons(Mq, Mkv, K, Kv)):
continue
found_count += 1
shapes.append((B, Mq, Mkv, H, K, Kv))
return shapes

6 changes: 5 additions & 1 deletion xformers/benchmarks/benchmark_mem_eff_attention.py
@@ -117,7 +117,11 @@ def T(t):
(8, 2048, 20, 128),
# LLaMa 70b - mp=8/16
*sorted(list(itertools.product([1, 2], [2048, 4096, 8192], [4, 8], [128]))),
*sorted(list(itertools.product([16], [128, 512, 1024], [16], [16, 32, 64, 128]))),
*sorted(
list(
itertools.product([16], [128, 512, 1024], [16], [16, 32, 64, 128, 160, 256])
)
),
]

OPS = [
36 changes: 25 additions & 11 deletions xformers/ops/fmha/common.py
@@ -151,6 +151,7 @@ class AttentionOpBase(BaseOperator):
SUPPORTS_DROPOUT: bool
SUPPORTS_CUSTOM_SCALE: bool = False
SUPPORTS_DIFFERENT_VALUE_EMBED: bool = False
IS_DETERMINISTIC: bool = True
NAME: str
OPERATOR_CATEGORY = "memory_efficient_attention"

@@ -161,13 +162,31 @@
def supports(cls, d: Inputs) -> bool:
return not cls.not_supported_reasons(d)

@classmethod
def shape_not_supported_reasons(
cls, Mq: int, Mkv: int, K: int, Kv: int
) -> List[str]:
reasons = []
if not cls.SUPPORTS_DIFFERENT_VALUE_EMBED and K != Kv:
reasons.append("query.shape[-1] != value.shape[-1]")
if max(K, Kv) > cls.SUPPORTED_MAX_K:
reasons.append(
f"max(query.shape[-1] != value.shape[-1]) > {cls.SUPPORTED_MAX_K}"
)
return reasons

@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
"""
Returns a list of reasons why this is not supported.
The kernel can run these inputs only if the returned list is empty
"""
reasons = []
reasons = cls.shape_not_supported_reasons(
Mq=d.query.shape[1],
Mkv=d.key.shape[1],
K=d.query.shape[-1],
Kv=d.value.shape[-1],
)
device_type = d.query.device.type
dtype = d.query.dtype
if device_type not in cls.SUPPORTED_DEVICES:
@@ -176,15 +195,6 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons.append("xFormers wasn't build with CUDA support")
if dtype not in cls.SUPPORTED_DTYPES:
reasons.append(f"dtype={dtype} (supported: {cls.SUPPORTED_DTYPES})")
if (
not cls.SUPPORTS_DIFFERENT_VALUE_EMBED
and d.query.shape[-1] != d.value.shape[-1]
):
reasons.append("query.shape[-1] != value.shape[-1]")
if max(d.query.shape[-1], d.value.shape[-1]) > cls.SUPPORTED_MAX_K:
reasons.append(
f"max(query.shape[-1] != value.shape[-1]) > {cls.SUPPORTED_MAX_K}"
)
if type(d.attn_bias) not in cls.SUPPORTED_ATTN_BIAS_TYPES:
reasons.append(f"attn_bias type is {type(d.attn_bias)}")
if (d.p != 0.0) and not cls.SUPPORTS_DROPOUT:
@@ -201,7 +211,11 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons.append("bf16 is only supported on A100+ GPUs")
if not cls.is_available():
reasons.append(
"Operator wasn't built - see `python -m xformers.info` for more info"
"operator wasn't built - see `python -m xformers.info` for more info"
)
if not cls.IS_DETERMINISTIC and torch.are_deterministic_algorithms_enabled():
reasons.append(
"operator is non-deterministic, but `torch.use_deterministic_algorithms` is set"
)
return reasons

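The new shape_not_supported_reasons classmethod makes shape-level support checks available without constructing tensors, which is exactly how the test-shape generator above filters its candidate (Mq, Mkv, K, Kv) tuples. A small usage sketch, reusing an operator already named in this commit's test diff (fmha.cutlass.FwOp); the helper name is illustrative:

from xformers.ops import fmha

def shape_is_supported(op, Mq: int, Mkv: int, K: int, Kv: int) -> bool:
    # An empty list of reasons means the operator accepts this shape.
    return len(op.shape_not_supported_reasons(Mq=Mq, Mkv=Mkv, K=K, Kv=Kv)) == 0

# e.g. shape_is_supported(fmha.cutlass.FwOp, Mq=256, Mkv=256, K=64, Kv=64)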
12 changes: 2 additions & 10 deletions xformers/ops/fmha/dispatch.py
@@ -13,14 +13,7 @@


def _is_cutlass_fwd_faster_than_flash(inp: Inputs) -> bool:
# For dropout, we can't mix & match kernels
# Unfortunately, the dropout implementation in CUTLASS
# backward is pretty slow for the BW, so disable it here
if inp.p > 0.0:
return False

# Large values of K
return max(inp.query.shape[-1], inp.value.shape[-1]) > 64
return False


def _is_triton_fwd_fastest(inp: Inputs) -> bool:
@@ -106,8 +99,7 @@ def _dispatch_fw(inp: Inputs, needs_gradient: bool) -> Type[AttentionFwOpBase]:


def _is_cutlassB_faster_than_flash(inp: Inputs) -> bool:
embed_dim = max(inp.query.shape[-1], inp.value.shape[-1])
return embed_dim > 64 and inp.attn_bias is None and inp.p == 0.0
return False


def _dispatch_bw(inp: Inputs) -> Type[AttentionBwOpBase]:
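With _is_cutlass_fwd_faster_than_flash and _is_cutlassB_faster_than_flash both returning False, the dispatcher no longer reroutes large-K (or bias-free, dropout-free) cases to the CUTLASS kernels; it effectively walks its priority list and takes the first operator whose supports() check passes. A rough sketch of that idea, with an illustrative candidate list (not the actual _dispatch_fw/_dispatch_bw bodies):

def pick_op(inp, priority_list):
    # e.g. priority_list could be [fmha.flash.FwOp, fmha.cutlass.FwOp, ...]
    for op in priority_list:
        if op.supports(inp):  # AttentionOpBase.supports, see common.py above
            return op
    raise NotImplementedError(f"No operator found for these inputs: {inp}")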