Commit

Add an SDPA dispatcher for nested tensors with jagged layouts (pytorch#114164)

Pull Request resolved: pytorch#114164
Approved by: https://github.com/jbschlosser
ani300 authored and pytorchmergebot committed Dec 5, 2023
1 parent fb92983 commit 1dc4588
Showing 7 changed files with 1,180 additions and 123 deletions.
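For context, the sketch below is not part of this commit; it illustrates the kind of call the new dispatcher is meant to serve: scaled_dot_product_attention on jagged-layout nested tensors. It assumes the Python-side torch.nested.nested_tensor(..., layout=torch.jagged) constructor, a CUDA device with fp16 support, and illustrative sizes — a minimal sketch, not an API reference.

import torch
import torch.nn.functional as F

n_heads, head_dim = 8, 64   # illustrative sizes
seq_lens = [5, 11, 7]       # ragged sequence lengths across the batch

def make_njt():
    # Build a (batch, ragged_seq_len, n_heads, head_dim) jagged nested tensor,
    # then transpose so each component is (n_heads, ragged_seq_len, head_dim).
    return torch.nested.nested_tensor(
        [torch.randn(s, n_heads, head_dim) for s in seq_lens],
        layout=torch.jagged, device="cuda", dtype=torch.float16,
    ).transpose(1, 2)

query, key, value = make_njt(), make_njt(), make_njt()

# With this commit, jagged-layout inputs can be routed to the fused
# flash / memory-efficient SDPA kernels when the gate checks below pass.
out = F.scaled_dot_product_attention(query, key, value)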
120 changes: 83 additions & 37 deletions aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -83,23 +83,6 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
const auto value_size_last = params.value.sym_size(-1);
bool same_head_dim_size =
query_size_last == key_size_last && query_size_last == value_size_last;
if (has_for_nested_inputs(params)) {
if (!(same_head_dim_size && (query_size_last % 8 == 0) &&
(query_size_last <= max_size))) {
if (debug) {
TORCH_WARN(
"For NestedTensor inputs, Flash attention requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 256.",
" Got Query.size(-1): ",
query_size_last,
", Key.size(-1): ",
params.key.sym_size(-1),
", Value.size(-1): ",
params.value.sym_size(-1),
" instead.");
}
return false;
}
}
if (!(same_head_dim_size && (query_size_last <= max_size))) {
if (debug) {
TORCH_WARN(
@@ -117,6 +100,31 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
return true;
}

bool check_head_dim_size_flash_nested(sdp_params const& params, bool debug) {
const auto max_size = c10::SymInt(256);
const auto query_size_last = params.query.sym_size(-1);
const auto key_size_last = params.key.sym_size(-1);
const auto value_size_last = params.value.sym_size(-1);
bool same_head_dim_size =
query_size_last == key_size_last && query_size_last == value_size_last;
if (!(same_head_dim_size && (query_size_last % 8 == 0) &&
(query_size_last <= max_size))) {
if (debug) {
TORCH_WARN(
"For NestedTensor inputs, Flash attention requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 256.",
" Got Query.size(-1): ",
query_size_last,
", Key.size(-1): ",
params.key.sym_size(-1),
", Value.size(-1): ",
params.value.sym_size(-1),
" instead.");
}
return false;
}
return true;
}

bool check_head_dim_size_mem_efficient(sdp_params const& params, bool debug) {
const auto query_size_last = params.query.sym_size(-1);
const auto value_size_last = params.value.sym_size(-1);
@@ -210,7 +218,7 @@ bool check_requires_grad_and_head_dim_gt192_and_sm_ge86_lt90(
sdp_params const& params,
bool debug) {
// Flash Attention will raise an error in the backward pass if the head_dim
// size is greater than 64 And the device is between in the range [sm86, sm89]
// size is greater than 192 And the device is between in the range [sm86, sm89]
using sm86 = SMVersion<8, 6>;
using sm89 = SMVersion<8, 9>;
auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -235,7 +243,9 @@ bool check_flash_causal_non_square_seqlens(sdp_params const& params, bool debug)
// FlashAttention 2 updated the default mask meaning for causal in this PR:
// 9e5e8bc91e it is now aligned to lower_right which would be a BC break
// for non-square masks. We will not support non-square masks for causal w/ FAV2
if (params.is_causal && params.query.sym_size(-2) != params.key.sym_size(-2)) {
if (params.is_causal &&
!params.query.is_nested() && !params.key.is_nested() &&
params.query.sym_size(-2) != params.key.sym_size(-2)) {
if (debug) {
TORCH_WARN(
"Flash attention does not support the is_causal flag when seqlen_q != seqlen_k. ",
@@ -256,25 +266,43 @@ TORCH_API bool can_use_flash_attention(sdp_params const& params, bool debug) {

// Define gate functions that determine if a flash kernel can be ran
// Replace with std::to_array when we migrate to c++20
constexpr auto constraints = array_of<bool (*)(sdp_params const&, bool)>(
constexpr auto general_constraints = array_of<bool (*)(sdp_params const&, bool)>(
check_runtime_disabled_flash,
check_tensor_shapes,
check_batch_size_and_num_heads,
check_for_attn_mask,
check_head_dim_size_flash,
check_for_seq_len_0_nested_tensor,
check_nonzero_sequence_lengths,
check_last_dim_stride_equals_1,
check_flash_attention_hardware_support,
check_requires_grad_and_head_dim_gt192_and_sm_ge86_lt90,
check_flash_causal_non_square_seqlens,
check_for_seq_len_0_nested_tensor);
for (auto& constraint : constraints) {
check_flash_causal_non_square_seqlens);
for (auto& constraint : general_constraints) {
if (!constraint(params, debug)) {
return false;
}
}

if (has_for_nested_inputs(params)) {
constexpr auto nested_constraints = array_of<bool (*)(sdp_params const&, bool)>(
check_batch_size_nested,
check_head_dim_size_flash_nested,
check_for_seq_len_0_nested_tensor);
for (auto& constraint : nested_constraints) {
if (!constraint(params, debug)) {
return false;
}
}
}
if (has_only_dense_inputs(params)) {
constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
check_batch_size_and_num_heads_dense,
check_nonzero_sequence_lengths_dense,
check_last_dim_stride_equals_1_dense);
for (auto& constraint : dense_constraints) {
if (!constraint(params, debug)) {
return false;
}
}
}

auto dprop = at::cuda::getCurrentDeviceProperties();
if (dprop->major >= 8) {
constexpr auto sm80_flash_dtypes =
@@ -297,23 +325,41 @@ TORCH_API bool can_use_mem_efficient_attention(sdp_params const& params, bool de
constexpr auto sm50_mem_efficient_dtypes =
array_of<at::ScalarType>(at::kHalf, at::kFloat);

// Define gate functions that determine if a flash kernel can be ran
constexpr auto constraints = array_of<bool (*)(sdp_params const&, bool)>(
// Define gate functions that determine if a mem efficient kernel can be ran
constexpr auto general_constraints = array_of<bool (*)(sdp_params const&, bool)>(
check_runtime_disabled_mem_efficient,
check_mem_efficient_hardware_support,
check_requires_grad_and_nested,
check_tensor_shapes,
check_batch_size_and_num_heads,
check_head_dim_size_mem_efficient,
check_for_seq_len_0_nested_tensor,
check_nonzero_sequence_lengths,
check_last_dim_stride_equals_1);
for (auto& constraint : constraints) {
check_head_dim_size_mem_efficient);
for (auto& constraint : general_constraints) {
if (!constraint(params, debug)) {
return false;
}
}

if (has_for_nested_inputs(params)) {
constexpr auto nested_constraints = array_of<bool (*)(sdp_params const&, bool)>(
check_requires_grad_and_nested,
check_batch_size_nested,
check_for_seq_len_0_nested_tensor);
for (auto& constraint : nested_constraints) {
if (!constraint(params, debug)) {
return false;
}
}
}
if (has_only_dense_inputs(params)) {
constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
check_batch_size_and_num_heads_dense,
check_nonzero_sequence_lengths_dense,
check_last_dim_stride_equals_1_dense);
for (auto& constraint : dense_constraints) {
if (!constraint(params, debug)) {
return false;
}
}
}

auto dprop = at::cuda::getCurrentDeviceProperties();
if (dprop->major == 5) {
return check_tensor_dtype(params, sm50_mem_efficient_dtypes, debug);
@@ -370,7 +416,7 @@ SDPBackend select_sdp_backend(sdp_params const& kernel_params) {
sdp::can_use_mem_efficient_attention(kernel_params, print_debug);
TORCH_WARN("Flash attention kernel not used because:");
sdp::can_use_flash_attention(kernel_params, print_debug);
TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.")
TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.")
return SDPBackend::error;
}

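To see the gate functions above in action, the hedged sketch below (also not part of the commit) restricts SDPA to the flash backend via the torch.backends.cuda.sdp_kernel context manager available at the time; query, key, and value are assumed to be the jagged-layout nested tensors from the earlier sketch. If a check such as check_head_dim_size_flash_nested rejects the inputs, select_sdp_backend falls through to the error path shown above and the TORCH_WARN messages report which constraint failed.

import torch
import torch.nn.functional as F

# Force the flash kernel only; the math and memory-efficient fallbacks are
# disabled so constraint failures surface instead of being silently bypassed.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(query, key, value)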
6 changes: 3 additions & 3 deletions aten/src/ATen/native/transformers/sdp_utils_cpp.cpp
@@ -42,11 +42,11 @@ bool use_flash_attention_cpp(sdp_params const& params, bool debug) {
check_nested_tensor,
check_for_dropout,
check_tensor_shapes,
check_batch_size_and_num_heads,
check_batch_size_and_num_heads_dense,
check_for_attn_mask,
check_head_dim_size_cpp,
check_nonzero_sequence_lengths,
check_last_dim_stride_equals_1);
check_nonzero_sequence_lengths_dense,
check_last_dim_stride_equals_1_dense);
for (auto& constraint : constraints) {
if (!constraint(params, debug)) {
return false;
121 changes: 63 additions & 58 deletions aten/src/ATen/native/transformers/sdp_utils_cpp.h
@@ -73,9 +73,18 @@ inline bool input_requires_grad(sdp_params const& params) {
}

inline bool has_for_nested_inputs(sdp_params const& params) {
return (
params.query.is_nested() || params.key.is_nested() ||
params.value.is_nested());
return
(params.query.is_nested() && params.query.layout() == c10::kStrided) ||
(params.key.is_nested() && params.key.layout() == c10::kStrided) ||
(params.value.is_nested() && params.value.layout() == c10::kStrided);
}

inline bool has_for_dense_inputs(sdp_params const& params) {
return !params.query.is_nested() || !params.key.is_nested() || !params.value.is_nested();
}

inline bool has_only_dense_inputs(sdp_params const& params) {
return !params.query.is_nested() && !params.key.is_nested() && !params.value.is_nested();
}

template <typename dtype_vector>
@@ -176,10 +185,6 @@ inline bool check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(

inline bool check_for_seq_len_0_nested_tensor(sdp_params const& params, bool debug) {
// When this function is called we are assured that the nt is dim==4
if (!has_for_nested_inputs(params)) {
return true;
}

bool q_is_safe = params.query.is_nested()
? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
params.query, "query ", debug)
@@ -230,10 +235,10 @@ inline bool check_for_seq_len_0_nested_tensor(sdp_params const& params, bool deb

inline bool check_nested_tensor(sdp_params const& params, bool debug) {
// Return false if have nested tensor
if (has_for_nested_inputs(params)) {
if (!has_only_dense_inputs(params)) {
if (debug) {
TORCH_WARN(
"Both fused kernels of cpp version currently do support Nested Tensor inputs.");
"Both fused kernels of cpp version currently do not support Nested Tensor inputs.");
}
return false;
}
@@ -251,8 +256,7 @@ inline bool check_for_dropout(sdp_params const& params, bool debug) {
}

inline bool check_requires_grad_and_nested(sdp_params const& params, bool debug) {
// If we fail both checks then we return false
if (has_for_nested_inputs(params) && input_requires_grad(params)) {
if (input_requires_grad(params)) {
if (debug) {
TORCH_WARN(
"Memory efficient attention currently doesn't support training with NT inputs.");
@@ -306,50 +310,17 @@ inline bool check_safe_kv_broadcast(at::Tensor const& param, bool debug) {
return true;
}

inline bool check_batch_size_and_num_heads(sdp_params const& params, bool debug) {
inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool debug) {
// This is expected to be called after check_tensor_shapes ensuring that the
// size() calls won't error since the inputs are all 4 dimensional

auto q_batch_size = params.query.sym_size(0);
auto k_batch_size = params.key.sym_size(0);
auto v_batch_size = params.value.sym_size(0);

bool has_nested_input = has_for_nested_inputs(params);
bool same_batch_size =
q_batch_size == k_batch_size && q_batch_size == v_batch_size;

// num_heads logic for nested input is checked in
// check_for_seq_len_0_nested_tensor as there is handling there to make sure
// num_heads is not ragged
if (has_nested_input) {
bool broadcastable_batch_size = true;
if (!same_batch_size) {
if (input_requires_grad(params)){
if (debug) {
TORCH_WARN(
"Both fused kernels do not support training with broadcasted NT inputs.");
}
return false;
}
// try to broadcast batchsize
broadcastable_batch_size = try_broadcast_param_size(
q_batch_size, k_batch_size, v_batch_size, "batch size ", debug);

// if only one of k or v require broadcasting of batch size, the other
// must have a consistent seq_len dim
if (broadcastable_batch_size) {
if (k_batch_size == 1 && v_batch_size != 1 &&
!check_safe_kv_broadcast(params.value, debug)) {
return false;
}
if (v_batch_size == 1 && k_batch_size != 1 &&
!check_safe_kv_broadcast(params.key, debug)) {
return false;
}
}
}
return broadcastable_batch_size;
}

auto q_num_heads = params.query.sym_size(1);
auto k_num_heads = params.key.sym_size(1);
auto v_num_heads = params.value.sym_size(1);
@@ -373,13 +344,49 @@ inline bool check_batch_size_and_num_heads(sdp_params const& params, bool debug)
return true;
}

inline bool check_nonzero_sequence_lengths(sdp_params const& params, bool debug) {
if (has_for_nested_inputs(params)){
// Currently we do not support any masking with NestedTensors
// This is checked in validate_sdpa_input so this filter func
// Should have no actually bearing on the kernel selection
return true;
inline bool check_batch_size_nested(sdp_params const& params, bool debug) {
// This is expected to be called after check_tensor_shapes ensuring that the
// size() calls won't error since the inputs are all 4 dimensional
auto q_batch_size = params.query.sym_size(0);
auto k_batch_size = params.key.sym_size(0);
auto v_batch_size = params.value.sym_size(0);

bool same_batch_size =
q_batch_size == k_batch_size && q_batch_size == v_batch_size;

// num_heads logic for nested input is checked in
// check_for_seq_len_0_nested_tensor as there is handling there to make sure
// num_heads is not ragged
bool broadcastable_batch_size = true;
if (!same_batch_size) {
if (input_requires_grad(params)){
if (debug) {
TORCH_WARN(
"Both fused kernels do not support training with broadcasted NT inputs.");
}
return false;
}
// try to broadcast batchsize
broadcastable_batch_size = try_broadcast_param_size(
q_batch_size, k_batch_size, v_batch_size, "batch size ", debug);

// if only one of k or v require broadcasting of batch size, the other
// must have a consistent seq_len dim
if (broadcastable_batch_size) {
if (k_batch_size == 1 && v_batch_size != 1 &&
!check_safe_kv_broadcast(params.value, debug)) {
return false;
}
if (v_batch_size == 1 && k_batch_size != 1 &&
!check_safe_kv_broadcast(params.key, debug)) {
return false;
}
}
}
return broadcastable_batch_size;
}

inline bool check_nonzero_sequence_lengths_dense(sdp_params const& params, bool debug) {
// In some cases people will pass in 0 sized tensors, this will
// cause the fused path to error with unaligned mask
bool zero_seq_len_q = params.query.sym_size(-2) == 0;
@@ -394,12 +401,10 @@ inline bool check_nonzero_sequence_lengths(sdp_params const& params, bool debug)
return true;
}

inline bool check_last_dim_stride_equals_1(sdp_params const& params, bool debug) {
if (has_for_nested_inputs(params)){
// The stride checking for NestedTensors is done within the kernel
// And .contiguous will be called if needed
return true;
}
inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool debug) {
// The stride checking for NestedTensors is done within the kernel
// And .contiguous will be called if needed

// This function checks that the last dimension of the inputs to
// fused_attention have stride 1
bool qkv_strides_equal_1 = params.query.sym_stride(-1) == 1 &&