From 1dc4588c6aa861bbb6539c5fca5afb493a8635fd Mon Sep 17 00:00:00 2001 From: Antoni Viros Date: Tue, 5 Dec 2023 03:38:26 +0000 Subject: [PATCH] Add an SDPA dispatcher for nested tensors with jagged layouts (#114164) Pull Request resolved: https://github.com/pytorch/pytorch/pull/114164 Approved by: https://github.com/jbschlosser --- .../native/transformers/cuda/sdp_utils.cpp | 120 ++- .../native/transformers/sdp_utils_cpp.cpp | 6 +- .../ATen/native/transformers/sdp_utils_cpp.h | 121 +-- test/test_nestedtensor.py | 177 +++++ torch/nested/_internal/nested_tensor.py | 99 ++- torch/nested/_internal/ops.py | 51 +- torch/nested/_internal/sdpa.py | 729 ++++++++++++++++++ 7 files changed, 1180 insertions(+), 123 deletions(-) create mode 100644 torch/nested/_internal/sdpa.py diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index a05284163e2ab..924a53922efcd 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -83,23 +83,6 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) { const auto value_size_last = params.value.sym_size(-1); bool same_head_dim_size = query_size_last == key_size_last && query_size_last == value_size_last; - if (has_for_nested_inputs(params)) { - if (!(same_head_dim_size && (query_size_last % 8 == 0) && - (query_size_last <= max_size))) { - if (debug) { - TORCH_WARN( - "For NestedTensor inputs, Flash attention requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 256.", - " Got Query.size(-1): ", - query_size_last, - ", Key.size(-1): ", - params.key.sym_size(-1), - ", Value.size(-1): ", - params.value.sym_size(-1), - " instead."); - } - return false; - } - } if (!(same_head_dim_size && (query_size_last <= max_size))) { if (debug) { TORCH_WARN( @@ -117,6 +100,31 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) { return true; } +bool check_head_dim_size_flash_nested(sdp_params const& params, bool debug) { + const auto max_size = c10::SymInt(256); + const auto query_size_last = params.query.sym_size(-1); + const auto key_size_last = params.key.sym_size(-1); + const auto value_size_last = params.value.sym_size(-1); + bool same_head_dim_size = + query_size_last == key_size_last && query_size_last == value_size_last; + if (!(same_head_dim_size && (query_size_last % 8 == 0) && + (query_size_last <= max_size))) { + if (debug) { + TORCH_WARN( + "For NestedTensor inputs, Flash attention requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 256.", + " Got Query.size(-1): ", + query_size_last, + ", Key.size(-1): ", + params.key.sym_size(-1), + ", Value.size(-1): ", + params.value.sym_size(-1), + " instead."); + } + return false; + } + return true; +} + bool check_head_dim_size_mem_efficient(sdp_params const& params, bool debug) { const auto query_size_last = params.query.sym_size(-1); const auto value_size_last = params.value.sym_size(-1); @@ -210,7 +218,7 @@ bool check_requires_grad_and_head_dim_gt192_and_sm_ge86_lt90( sdp_params const& params, bool debug) { // Flash Attention will raise an error in the backward pass if the head_dim - // size is greater than 64 And the device is between in the range [sm86, sm89] + // size is greater than 192 And the device is between in the range [sm86, sm89] using sm86 = SMVersion<8, 6>; using sm89 = SMVersion<8, 9>; auto dprops = at::cuda::getCurrentDeviceProperties(); @@ -235,7 
+243,9 @@ bool check_flash_causal_non_square_seqlens(sdp_params const& params, bool debug) // FlashAttention 2 updated the default mask meaning for causal in this PR: // 9e5e8bc91e it is now aligned to lower_right which would be a BC break // for non-square masks. We will not support non-square masks for causal w/ FAV2 - if (params.is_causal && params.query.sym_size(-2) != params.key.sym_size(-2)) { + if (params.is_causal && + !params.query.is_nested() && !params.key.is_nested() && + params.query.sym_size(-2) != params.key.sym_size(-2)) { if (debug) { TORCH_WARN( "Flash attention does not support the is_causal flag when seqlen_q != seqlen_k. ", @@ -256,25 +266,43 @@ TORCH_API bool can_use_flash_attention(sdp_params const& params, bool debug) { // Define gate functions that determine if a flash kernel can be ran // Replace with std::to_array when we migrate to c++20 - constexpr auto constraints = array_of( + constexpr auto general_constraints = array_of( check_runtime_disabled_flash, check_tensor_shapes, - check_batch_size_and_num_heads, check_for_attn_mask, check_head_dim_size_flash, - check_for_seq_len_0_nested_tensor, - check_nonzero_sequence_lengths, - check_last_dim_stride_equals_1, check_flash_attention_hardware_support, check_requires_grad_and_head_dim_gt192_and_sm_ge86_lt90, - check_flash_causal_non_square_seqlens, - check_for_seq_len_0_nested_tensor); - for (auto& constraint : constraints) { + check_flash_causal_non_square_seqlens); + for (auto& constraint : general_constraints) { if (!constraint(params, debug)) { return false; } } + if (has_for_nested_inputs(params)) { + constexpr auto nested_constraints = array_of( + check_batch_size_nested, + check_head_dim_size_flash_nested, + check_for_seq_len_0_nested_tensor); + for (auto& constraint : nested_constraints) { + if (!constraint(params, debug)) { + return false; + } + } + } + if (has_only_dense_inputs(params)) { + constexpr auto dense_constraints = array_of( + check_batch_size_and_num_heads_dense, + check_nonzero_sequence_lengths_dense, + check_last_dim_stride_equals_1_dense); + for (auto& constraint : dense_constraints) { + if (!constraint(params, debug)) { + return false; + } + } + } + auto dprop = at::cuda::getCurrentDeviceProperties(); if (dprop->major >= 8) { constexpr auto sm80_flash_dtypes = @@ -297,23 +325,41 @@ TORCH_API bool can_use_mem_efficient_attention(sdp_params const& params, bool de constexpr auto sm50_mem_efficient_dtypes = array_of(at::kHalf, at::kFloat); - // Define gate functions that determine if a flash kernel can be ran - constexpr auto constraints = array_of( + // Define gate functions that determine if a mem efficient kernel can be ran + constexpr auto general_constraints = array_of( check_runtime_disabled_mem_efficient, check_mem_efficient_hardware_support, - check_requires_grad_and_nested, check_tensor_shapes, - check_batch_size_and_num_heads, - check_head_dim_size_mem_efficient, - check_for_seq_len_0_nested_tensor, - check_nonzero_sequence_lengths, - check_last_dim_stride_equals_1); - for (auto& constraint : constraints) { + check_head_dim_size_mem_efficient); + for (auto& constraint : general_constraints) { if (!constraint(params, debug)) { return false; } } + if (has_for_nested_inputs(params)) { + constexpr auto nested_constraints = array_of( + check_requires_grad_and_nested, + check_batch_size_nested, + check_for_seq_len_0_nested_tensor); + for (auto& constraint : nested_constraints) { + if (!constraint(params, debug)) { + return false; + } + } + } + if (has_only_dense_inputs(params)) { + constexpr 
auto dense_constraints = array_of( + check_batch_size_and_num_heads_dense, + check_nonzero_sequence_lengths_dense, + check_last_dim_stride_equals_1_dense); + for (auto& constraint : dense_constraints) { + if (!constraint(params, debug)) { + return false; + } + } + } + auto dprop = at::cuda::getCurrentDeviceProperties(); if (dprop->major == 5) { return check_tensor_dtype(params, sm50_mem_efficient_dtypes, debug); @@ -370,7 +416,7 @@ SDPBackend select_sdp_backend(sdp_params const& kernel_params) { sdp::can_use_mem_efficient_attention(kernel_params, print_debug); TORCH_WARN("Flash attention kernel not used because:"); sdp::can_use_flash_attention(kernel_params, print_debug); - TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.") + TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.") return SDPBackend::error; } diff --git a/aten/src/ATen/native/transformers/sdp_utils_cpp.cpp b/aten/src/ATen/native/transformers/sdp_utils_cpp.cpp index 308f94cc12578..beb6862daa84b 100644 --- a/aten/src/ATen/native/transformers/sdp_utils_cpp.cpp +++ b/aten/src/ATen/native/transformers/sdp_utils_cpp.cpp @@ -42,11 +42,11 @@ bool use_flash_attention_cpp(sdp_params const& params, bool debug) { check_nested_tensor, check_for_dropout, check_tensor_shapes, - check_batch_size_and_num_heads, + check_batch_size_and_num_heads_dense, check_for_attn_mask, check_head_dim_size_cpp, - check_nonzero_sequence_lengths, - check_last_dim_stride_equals_1); + check_nonzero_sequence_lengths_dense, + check_last_dim_stride_equals_1_dense); for (auto& constraint : constraints) { if (!constraint(params, debug)) { return false; diff --git a/aten/src/ATen/native/transformers/sdp_utils_cpp.h b/aten/src/ATen/native/transformers/sdp_utils_cpp.h index 6dafe7b2cb5ef..270c9a90d6664 100644 --- a/aten/src/ATen/native/transformers/sdp_utils_cpp.h +++ b/aten/src/ATen/native/transformers/sdp_utils_cpp.h @@ -73,9 +73,18 @@ inline bool input_requires_grad(sdp_params const& params) { } inline bool has_for_nested_inputs(sdp_params const& params) { - return ( - params.query.is_nested() || params.key.is_nested() || - params.value.is_nested()); + return + (params.query.is_nested() && params.query.layout() == c10::kStrided) || + (params.key.is_nested() && params.key.layout() == c10::kStrided) || + (params.value.is_nested() && params.value.layout() == c10::kStrided); +} + +inline bool has_for_dense_inputs(sdp_params const& params) { + return !params.query.is_nested() || !params.key.is_nested() || !params.value.is_nested(); +} + +inline bool has_only_dense_inputs(sdp_params const& params) { + return !params.query.is_nested() && !params.key.is_nested() && !params.value.is_nested(); } template @@ -176,10 +185,6 @@ inline bool check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper( inline bool check_for_seq_len_0_nested_tensor(sdp_params const& params, bool debug) { // When this function is called we are assured that the nt is dim==4 - if (!has_for_nested_inputs(params)) { - return true; - } - bool q_is_safe = params.query.is_nested() ? 
check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper( params.query, "query ", debug) @@ -230,10 +235,10 @@ inline bool check_for_seq_len_0_nested_tensor(sdp_params const& params, bool deb inline bool check_nested_tensor(sdp_params const& params, bool debug) { // Return false if have nested tensor - if (has_for_nested_inputs(params)) { + if (!has_only_dense_inputs(params)) { if (debug) { TORCH_WARN( - "Both fused kernels of cpp version currently do support Nested Tensor inputs."); + "Both fused kernels of cpp version currently do not support Nested Tensor inputs."); } return false; } @@ -251,8 +256,7 @@ inline bool check_for_dropout(sdp_params const& params, bool debug) { } inline bool check_requires_grad_and_nested(sdp_params const& params, bool debug) { - // If we fail both checks then we return false - if (has_for_nested_inputs(params) && input_requires_grad(params)) { + if (input_requires_grad(params)) { if (debug) { TORCH_WARN( "Memory efficient attention currently doesn't support training with NT inputs."); @@ -306,50 +310,17 @@ inline bool check_safe_kv_broadcast(at::Tensor const& param, bool debug) { return true; } -inline bool check_batch_size_and_num_heads(sdp_params const& params, bool debug) { +inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool debug) { // This is expected to be called after check_tensor_shapes ensuring that the // size() calls won't error since the inputs are all 4 dimensional + auto q_batch_size = params.query.sym_size(0); auto k_batch_size = params.key.sym_size(0); auto v_batch_size = params.value.sym_size(0); - bool has_nested_input = has_for_nested_inputs(params); bool same_batch_size = q_batch_size == k_batch_size && q_batch_size == v_batch_size; - // num_heads logic for nested input is checked in - // check_for_seq_len_0_nested_tensor as there is handling there to make sure - // num_heads is not ragged - if (has_nested_input) { - bool broadcastable_batch_size = true; - if (!same_batch_size) { - if (input_requires_grad(params)){ - if (debug) { - TORCH_WARN( - "Both fused kernels do not support training with broadcasted NT inputs."); - } - return false; - } - // try to broadcast batchsize - broadcastable_batch_size = try_broadcast_param_size( - q_batch_size, k_batch_size, v_batch_size, "batch size ", debug); - - // if only one of k or v require broadcasting of batch size, the other - // must have a consistent seq_len dim - if (broadcastable_batch_size) { - if (k_batch_size == 1 && v_batch_size != 1 && - !check_safe_kv_broadcast(params.value, debug)) { - return false; - } - if (v_batch_size == 1 && k_batch_size != 1 && - !check_safe_kv_broadcast(params.key, debug)) { - return false; - } - } - } - return broadcastable_batch_size; - } - auto q_num_heads = params.query.sym_size(1); auto k_num_heads = params.key.sym_size(1); auto v_num_heads = params.value.sym_size(1); @@ -373,13 +344,49 @@ inline bool check_batch_size_and_num_heads(sdp_params const& params, bool debug) return true; } -inline bool check_nonzero_sequence_lengths(sdp_params const& params, bool debug) { - if (has_for_nested_inputs(params)){ - // Currently we do not support any masking with NestedTensors - // This is checked in validate_sdpa_input so this filter func - // Should have no actually bearing on the kernel selection - return true; +inline bool check_batch_size_nested(sdp_params const& params, bool debug) { + // This is expected to be called after check_tensor_shapes ensuring that the + // size() calls won't error since the inputs are all 4 
dimensional + auto q_batch_size = params.query.sym_size(0); + auto k_batch_size = params.key.sym_size(0); + auto v_batch_size = params.value.sym_size(0); + + bool same_batch_size = + q_batch_size == k_batch_size && q_batch_size == v_batch_size; + + // num_heads logic for nested input is checked in + // check_for_seq_len_0_nested_tensor as there is handling there to make sure + // num_heads is not ragged + bool broadcastable_batch_size = true; + if (!same_batch_size) { + if (input_requires_grad(params)){ + if (debug) { + TORCH_WARN( + "Both fused kernels do not support training with broadcasted NT inputs."); + } + return false; + } + // try to broadcast batchsize + broadcastable_batch_size = try_broadcast_param_size( + q_batch_size, k_batch_size, v_batch_size, "batch size ", debug); + + // if only one of k or v require broadcasting of batch size, the other + // must have a consistent seq_len dim + if (broadcastable_batch_size) { + if (k_batch_size == 1 && v_batch_size != 1 && + !check_safe_kv_broadcast(params.value, debug)) { + return false; + } + if (v_batch_size == 1 && k_batch_size != 1 && + !check_safe_kv_broadcast(params.key, debug)) { + return false; + } + } } + return broadcastable_batch_size; +} + +inline bool check_nonzero_sequence_lengths_dense(sdp_params const& params, bool debug) { // In some cases people will pass in 0 sized tensors, this will // cause the fused path to error with unaligned mask bool zero_seq_len_q = params.query.sym_size(-2) == 0; @@ -394,12 +401,10 @@ inline bool check_nonzero_sequence_lengths(sdp_params const& params, bool debug) return true; } -inline bool check_last_dim_stride_equals_1(sdp_params const& params, bool debug) { - if (has_for_nested_inputs(params)){ - // The stride checking for NestedTensors is done within the kernel - // And .contiguous will be called if needed - return true; - } +inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool debug) { + // The stride checking for NestedTensors is done within the kernel + // And .contiguous will be called if needed + // This function checks that the last dimension of the inputs to // fused_attention have stride 1 bool qkv_strides_equal_1 = params.query.sym_stride(-1) == 1 && diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 84814998a9e8f..26136fdae25fb 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -2,12 +2,14 @@ import io import itertools +from typing import Optional, Tuple import unittest from functools import partial import numpy as np import torch import torch.nn +from torch.testing._internal.common_cuda import SM80OrLater from torch.testing._internal.common_device_type import ( dtypes, dtypesIfCUDA, @@ -28,6 +30,7 @@ skipIfSlowGradcheckEnv, skipIfTorchDynamo, subtest, + TEST_WITH_ROCM, TestCase, ) @@ -2907,6 +2910,42 @@ def grad_test_func(a, b, c): data = (a, b, c) assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) +# Found in torch/testing/_comparison.py +default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5} +default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6} + +def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float: + deviation = true_value - computed_value + deviation = torch.abs(deviation / true_value) + # Fill in the nans with the default rtol + torch.nan_to_num_(deviation, nan=default_rtol[computed_value.dtype]) + return deviation.max().item() + + +def get_atol(true_value: torch.Tensor, computed_value: torch.Tensor) 
-> float: + deviation = true_value - computed_value + atol = torch.abs(deviation).max().item() + return atol + + +def get_tolerances( + true_value: torch.Tensor, + computed_value: torch.Tensor, + fudge_factor: Optional[float] = None, +) -> Tuple[float, float]: + """Returns the absolute and relative tolerances for comparing two tensors.""" + fudge_factor = fudge_factor if fudge_factor is not None else 1.0 + atol = get_atol(true_value, computed_value) + rtol = get_rtol(true_value, computed_value) + + atol = fudge_factor * max(atol, default_atol[computed_value.dtype]) + rtol = fudge_factor * max(rtol, default_rtol[computed_value.dtype]) + # torch.isclose() has weird behavior around see: + # https://github.com/pytorch/pytorch/issues/102400 + if rtol > 1e30: + rtol = default_rtol[computed_value.dtype] + return atol, rtol + # We can probably parametrizing existing tests instead of having a separate # test class as we begin to support more ops. Also maybe rewrite with OpInfos. class TestNestedTensorSubclass(NestedTestCase): @@ -3386,6 +3425,144 @@ def test_is_contiguous(self, device): self.assertTrue(not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format)) self.assertTrue(nt_contiguous_narrow.is_contiguous(memory_format=torch.contiguous_format)) + # Note 1: CPU Fused kernels do not support nested, Math is missing ops to work with NT jagged + # Note 2: Unless running on newer GPUs, only mem-effn or math are available, and mem-effn + # will fail with gradients and math has ops that aren't implemented. Therefore, in + # order to get some kernel to work with most GPUs, we have to disable gradients in + # this more general test + # Note 3: ROCm only supports the math kernel, which doesn't work with jagged NTs + @onlyCUDA + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts") + @torch.set_grad_enabled(False) + @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if + SM80OrLater else [torch.float16, torch.float32]) + def test_sdpa(self, device, dtype): + batch_size = 1 + emb_dims = 1024 + n_heads = 8 + head_dims = emb_dims // n_heads + + sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device) + sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device) + + query = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) + key = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) + value = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) + + # Simplest case: 1 sentence, no batching + x_d1 = sen1.unsqueeze(0) + x_nt = torch.nested.as_nested_tensor([sen1], layout=torch.jagged) + + q_d1 = query(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) + k_d1 = key(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) + v_d1 = value(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) + + q_nt = query(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2) + k_nt = key(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2) + v_nt = value(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2) + + # High Precision Math Reference + q_d1_f32 = q_d1.to(torch.float32) + k_d1_f32 = k_d1.to(torch.float32) + v_d1_f32 = v_d1.to(torch.float32) + out_ref = torch.ops.aten._scaled_dot_product_attention_math(q_d1_f32, k_d1_f32, v_d1_f32)[0] + # Low Precision Math Reference + out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(q_d1, k_d1, v_d1)[0] + output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref) + + 
attn_d1 = torch.nn.functional.scaled_dot_product_attention(q_d1, k_d1, v_d1).transpose(1, 2) + attn_nt = torch.nn.functional.scaled_dot_product_attention(q_nt, k_nt, v_nt).transpose(1, 2) + + self.assertEqual(attn_d1, attn_nt.unbind()[0].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol) + + # Simple case: 2 sentences, no extra params + x_d2 = sen2.unsqueeze(0) + x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged) + + q_d2 = query(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) + k_d2 = key(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) + v_d2 = value(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) + + q_nt = query(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2) + k_nt = key(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2) + v_nt = value(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2) + + attn_d2 = torch.nn.functional.scaled_dot_product_attention(q_d2, k_d2, v_d2).transpose(1, 2) + attn_nt = torch.nn.functional.scaled_dot_product_attention(q_nt, k_nt, v_nt).transpose(1, 2) + + attn_nts = attn_nt.unbind() + self.assertEqual(attn_d1, attn_nts[0].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol) + self.assertEqual(attn_d2, attn_nts[1].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol) + + # Test dispatcher works by calling only mem-effn and math (as they are safe for all devices) + with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=True, enable_math=False): + attn_nt = torch.nn.functional.scaled_dot_product_attention(q_nt, k_nt, v_nt).transpose(1, 2) + + attn_nts = attn_nt.unbind() + self.assertEqual(attn_d1, attn_nts[0].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol) + self.assertEqual(attn_d2, attn_nts[1].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol) + + # Will fail bc unsupported ops + # TODO: Add remaining ops, or implement a different math dispatch for jagged + with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): + with self.assertRaises(RuntimeError): + attn_nt = torch.nn.functional.scaled_dot_product_attention(q_nt, k_nt, v_nt).transpose(1, 2) + + # This requires NT -> NT views to work in inductor, which is a TODO + @unittest.expectedFailure + @onlyCUDA + @torch.set_grad_enabled(False) + @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if + SM80OrLater else [torch.float16, torch.float32]) + def test_sdpa_compile(self, device, dtype): + batch_size = 1 + emb_dims = 1024 + n_heads = 8 + head_dims = emb_dims // n_heads + + sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device) + sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device) + + query = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) + key = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) + value = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) + + # Simplest case: 1 sentence, no batching + x_d1 = sen1.unsqueeze(0) + x_d2 = sen2.unsqueeze(0) + x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged) + + q_d1 = query(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) + k_d1 = key(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) + v_d1 = value(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) + q_d2 = query(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) + k_d2 = key(x_d2).view(batch_size, -1, n_heads, 
head_dims).transpose(1, 2) + v_d2 = value(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) + + q_nt = query(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2) + k_nt = key(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2) + v_nt = value(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2) + + # High Precision Math Reference + q_d1_f32 = q_d1.to(torch.float32) + k_d1_f32 = k_d1.to(torch.float32) + v_d1_f32 = v_d1.to(torch.float32) + out_ref = torch.ops.aten._scaled_dot_product_attention_math(q_d1_f32, k_d1_f32, v_d1_f32)[0] + # Low Precision Math Reference + out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(q_d1, k_d1, v_d1)[0] + output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref) + + attn_d1 = torch.nn.functional.scaled_dot_product_attention(q_d1, k_d1, v_d1).transpose(1, 2) + attn_d2 = torch.nn.functional.scaled_dot_product_attention(q_d2, k_d2, v_d2).transpose(1, 2) + + compiled_sdpa = torch.compile(torch.nn.functional.scaled_dot_product_attention) + attn_nt = compiled_sdpa(q_nt, k_nt, v_nt).transpose(1, 2) + + attn_nts = attn_nt.unbind() + self.assertEqual(attn_d1, attn_nts[0].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol) + self.assertEqual(attn_d2, attn_nts[1].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol) + + instantiate_parametrized_tests(TestNestedTensor) instantiate_device_type_tests(TestNestedTensorDeviceType, globals()) diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index 2991a4efa1ee0..5db05333f1ca9 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -43,6 +43,9 @@ class NestedTensor(torch.Tensor): _stride: Tuple[int, ...] # Indicates that the nth dimension is ragged _ragged_idx: int + # SDPA Metadata + _max_seqlen: int + _min_seqlen: int @staticmethod def __new__( @@ -84,12 +87,18 @@ def __init__(self, values, offsets, *, lengths=None, **kwargs): # (create a new one if needed). 
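# Worked example of the generalized size computation added below (illustrative): with the
# default _ragged_idx == 1 and a values buffer of shape (sum_seqlen, n_heads, head_dim),
# Ds == (n_heads, head_dim) and self._size == (B, ragged_size, n_heads, head_dim); after
# transposing dims 1 and 2 the values buffer has shape (n_heads, sum_seqlen, head_dim),
# _ragged_idx == 2, and self._size == (B, n_heads, ragged_size, head_dim).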
ragged_source = offsets if lengths is None else lengths ragged_size = get_tensor_symint(ragged_source, coeff=1) + self._ragged_idx = kwargs.get("_ragged_idx", 1) B = offsets.shape[0] - 1 - Ds = values.shape[1:] - self._size = (B, ragged_size, *Ds) + Ds = values.shape[: self._ragged_idx - 1] + values.shape[self._ragged_idx :] + + nested_size = [B] + nested_size.extend(Ds[: self._ragged_idx - 1]) + nested_size.append(ragged_size) + nested_size.extend(Ds[self._ragged_idx - 1 :]) + self._size = tuple(nested_size) + stride = values.stride() - self._strides = (ragged_size * stride[0], *stride) - self._ragged_idx = 1 + self._strides = (ragged_size * stride[self._ragged_idx - 1], *stride) if values.requires_grad: raise ValueError( @@ -100,6 +109,27 @@ def __init__(self, values, offsets, *, lengths=None, **kwargs): self._offsets = offsets self._lengths = lengths + # SDPA metadata + def get_sdpa_extreme_seqlen(func, tensor): + return int(func(tensor).item()) + + # Note: Not using kwargs.get to avoid execution of get_sdpa_extreme_seqlen + # unless it is really needed + self._max_seqlen = ( + kwargs["_max_seqlen"] + if "_max_seqlen" in kwargs + else get_sdpa_extreme_seqlen( + torch.max, offsets.diff() if lengths is None else lengths + ) + ) + self._min_seqlen = ( + kwargs["_min_seqlen"] + if "_min_seqlen" in kwargs + else get_sdpa_extreme_seqlen( + torch.min, offsets.diff() if lengths is None else lengths + ) + ) + def values(self): return self._values @@ -135,6 +165,9 @@ def __tensor_flatten__(self): ctx = { "requires_grad": self.requires_grad, "ragged_size": self._size[self._ragged_idx], + "max_seqlen": self._max_seqlen, + "min_seqlen": self._min_seqlen, + "ragged_idx": self._ragged_idx, } inner_tensors = ["_values", "_offsets"] if self._lengths is not None: @@ -187,6 +220,9 @@ def __tensor_unflatten__(inner_tensors: Dict, meta): offsets=offsets, lengths=lengths, requires_grad=meta["requires_grad"], + _max_seqlen=meta["max_seqlen"], + _min_seqlen=meta["min_seqlen"], + _ragged_idx=meta["ragged_idx"], ) @classmethod @@ -222,35 +258,55 @@ class ViewBufferFromNested(torch.autograd.Function): @staticmethod def forward(ctx, x: NestedTensor): # type: ignore[override] ctx.save_for_backward(x.offsets()) + ctx.max_seqlen = x._max_seqlen + ctx.min_seqlen = x._min_seqlen + ctx._ragged_idx = x._ragged_idx return x.values() @staticmethod def backward(ctx, gO: torch.Tensor): # type: ignore[override] (offsets,) = ctx.saved_tensors - return NestedTensor(gO, offsets=offsets) + return NestedTensor( + gO, + offsets=offsets, + _max_seqlen=ctx.max_seqlen, + _min_seqlen=ctx.min_seqlen, + _ragged_idx=ctx._ragged_idx, + ) # Not actually a view! class ViewNestedFromBuffer(torch.autograd.Function): @staticmethod - def forward(ctx, values: torch.Tensor, offsets: torch.Tensor): # type: ignore[override] - return NestedTensor(values.detach(), offsets=offsets) + def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, max_seqlen: int, min_seqlen: int): # type: ignore[override] + return NestedTensor( + values.detach(), + offsets=offsets, + _max_seqlen=max_seqlen, + _min_seqlen=min_seqlen, + ) @staticmethod def backward(ctx, gO: NestedTensor): # type: ignore[override] - return gO.values(), None, None + return gO.values(), None, None, None # Not actually a view! 
# NOTE: @jbschlosser is working on making it a view class ViewNonContiguousNestedFromBuffer(torch.autograd.Function): @staticmethod - def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, lengths: torch.Tensor): # type: ignore[override] - return NestedTensor(values.detach(), offsets=offsets, lengths=lengths) + def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, lengths: torch.Tensor, max_seqlen: int, min_seqlen: int): # type: ignore[override] + return NestedTensor( + values.detach(), + offsets=offsets, + lengths=lengths, + _max_seqlen=max_seqlen, + _min_seqlen=min_seqlen, + ) @staticmethod def backward(ctx, gO: NestedTensor): # type: ignore[override] - return gO.values(), None, None + return gO.values(), None, None, None, None # Need to make it obvious that users should be passing in offsets @@ -303,7 +359,10 @@ def jagged_from_list( ] ) - return ViewNestedFromBuffer.apply(values, offsets), offsets # type: ignore[call-overload] + max_seqlen = max([t.shape[0] for t in tensors]) + min_seqlen = min([t.shape[0] for t in tensors]) + + return ViewNestedFromBuffer.apply(values, offsets, max_seqlen, min_seqlen), offsets # type: ignore[call-overload] def jagged_from_tensor_and_lengths( @@ -354,16 +413,28 @@ def jagged_from_tensor_and_lengths( if offsets[0] + length_list[0] != orig_dim: is_contiguous = False + actual_max_seqlen = int(torch.max(lengths).item()) + min_seqlen = int(torch.min(lengths).item()) + if is_contiguous: return ( ViewNestedFromBuffer.apply( - values[offsets[0] : offsets[-1]], offsets - offsets[0] + values[offsets[0] : offsets[-1]], + offsets - offsets[0], + actual_max_seqlen, + min_seqlen, ), offsets, None, ) - return ViewNonContiguousNestedFromBuffer.apply(values, offsets, length_list), offsets, length_list # type: ignore[call-overload] + return ( + ViewNonContiguousNestedFromBuffer.apply( + values, offsets, length_list, actual_max_seqlen, min_seqlen + ), + offsets, + length_list, + ) # type: ignore[call-overload] def buffer_from_jagged(jagged): diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 0700b08ed6a87..df5ec993414fa 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -2,6 +2,7 @@ import math import torch +from torch.nested._internal.sdpa import jagged_scaled_dot_product_attention from .nested_tensor import NestedTensor from typing import * # noqa: F403 @@ -184,6 +185,8 @@ def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]: def extract_kwargs(arg): kwargs = { "offsets": arg.offsets(), + "_max_seqlen": arg._max_seqlen, + "_min_seqlen": arg._min_seqlen, } return kwargs @@ -256,18 +259,10 @@ def jagged_binary_pointwise(func, *args, **kwargs): def jagged_torch_function(func, *args, **kwargs): - # Handle SDPA specially since it's CompositeImplicit. We don't want - # the nestedness of the inputs to affect the kernel choice, so unwrap - # the NTs here before passing to SDPA -> rewrap the output as NT. + # SDPA has special kernels that handle nested tensors. + # Dispatch to the correct implementation here if func is torch._C._nn.scaled_dot_product_attention: - t_args = [t._values if isinstance(t, NestedTensor) else t for t in args] - t_kwargs = { - k: v._values if isinstance(v, NestedTensor) else v - for k, v in kwargs.items() - } - - output = func(*t_args, **t_kwargs) - return NestedTensor(output, **extract_kwargs(args[0])) + return jagged_scaled_dot_product_attention(*args, **kwargs) # Handle flatten() here because it's CompositeImplicit. 
if func.__name__ == "flatten": @@ -355,6 +350,10 @@ def is_contiguous_general(func, *args, **kwargs): if inp.lengths() is not None: return False + # If jagged dim is not 1 it's not contiguous + if inp._ragged_idx != 1: + return False + new_kwargs["memory_format"] = new_kwargs.get( "memory_format", torch.contiguous_format ) @@ -537,6 +536,11 @@ def unbind_int(func, *args, **kwargs): offsets = inp.offsets() lengths = inp.lengths() + if inp._ragged_idx != 1: + raise RuntimeError( + "unbind(): only supported for NestedTensor when jagged dimension is 1" + ) + if lengths is None: return torch.split(values, offsets.diff().tolist()) return [ @@ -713,7 +717,32 @@ def transpose_int(func, *args, **kwargs): func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True ) + from torch._prims_common import canonicalize_dims + inp = new_kwargs.pop("input") + dim0, dim1 = canonicalize_dims(inp.dim(), (new_kwargs["dim0"], new_kwargs["dim1"])) + + # To support the SDPA API, inputs need to have the ragged idx transposed to dim 2 + # instead of 1, although the internal Flash and mem-effn implementations will + # use the inputs with raggedness in dim 1. + if dim0 == inp._ragged_idx or dim1 == inp._ragged_idx: + if dim0 == 0 or dim1 == 0: + raise ValueError( + "Transpose is not supported on the batch dimension for jagged NT" + ) + if dim0 == inp._ragged_idx: + to_dim = dim1 + else: + to_dim = dim0 + return NestedTensor( + inp.values().transpose( + _outer_to_inner_dim(len(inp._size), dim0), + _outer_to_inner_dim(len(inp._size), dim1), + ), + **extract_kwargs(inp), + _ragged_idx=to_dim, + ) + new_kwargs["dim0"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim0"], "transpose") new_kwargs["dim1"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim1"], "transpose") diff --git a/torch/nested/_internal/sdpa.py b/torch/nested/_internal/sdpa.py new file mode 100644 index 0000000000000..437b4b1fb6e0f --- /dev/null +++ b/torch/nested/_internal/sdpa.py @@ -0,0 +1,729 @@ +import logging +import math +from typing import Optional, Tuple + +import torch +import torch.nn +from torch.backends.cuda import ( + can_use_efficient_attention, + can_use_flash_attention, + flash_sdp_enabled, + math_sdp_enabled, + mem_efficient_sdp_enabled, + SDPAParams, + SDPBackend, +) + +from .nested_tensor import NestedTensor + +log = logging.getLogger(__name__) + + +def _validate_sdpa_input( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p=0.0, + is_causal=False, + scale=None, +): + if ( + not isinstance(query, NestedTensor) + or not isinstance(key, NestedTensor) + or not isinstance(value, NestedTensor) + ): + raise ValueError( + f"Expected query, key, and value to be nested tensors, " + f"but got query.is_nested: {query.is_nested}, key.is_nested: {key.is_nested}, " + f"and value.is_nested: {value.is_nested} instead." + ) + if query.dtype != key.dtype or query.dtype != value.dtype: + raise ValueError( + f"Expected query, key, and value to have the same dtype, " + f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, " + f"and value.dtype: {value.dtype} instead." + ) + if query.device != key.device or query.device != value.device: + raise ValueError( + f"Expected query, key, and value to have the same device type, " + f"but got query.device: {query.device}, key.device: {key.device}, " + f"and value.device: {value.device} instead." 
+ ) + if query.dim() < 2 or key.dim() < 2 or value.dim() < 2: + raise ValueError( + f"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: " + f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead." + ) + if query._ragged_idx != 2 or key._ragged_idx != 2 or value._ragged_idx != 2: + raise ValueError( + f"Expected query, key, and value to all be jagged at dimension 2, but got query._ragged_idx: " + f"{query._ragged_idx}, key._ragged_idx: {key._ragged_idx} and value._ragged_idx: {value._ragged_idx} instead." + ) + if attn_mask is not None: + # TODO: Figure out whether masks are actually supported for this layout or not + raise ValueError("Masks are not yet supported!") + if attn_mask.dtype != torch.bool and attn_mask.dtype != query.dtype: + raise ValueError( + f"Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: " + f"{attn_mask.dtype}, and query.dtype: {query.dtype} instead." + ) + + + def _check_batch_size_nested(params: SDPAParams, debug=False) -> bool: + # This is expected to be called after check_tensor_shapes ensuring that the + # size() calls won't error since the inputs are all 4 dimensional + q_batch_size = params.query.size(0) + k_batch_size = params.key.size(0) + v_batch_size = params.value.size(0) + + # num_heads logic for nested input is checked in + # check_for_seq_len_0_nested_tensor as there is handling there to make sure + # num_heads is not ragged + return q_batch_size == k_batch_size and q_batch_size == v_batch_size + + + def _check_head_dim_size_flash_nested(params: SDPAParams, debug=False) -> bool: + max_size = 256 + query_size_last = params.query.size(-1) + key_size_last = params.key.size(-1) + value_size_last = params.value.size(-1) + same_head_dim_size = ( + query_size_last == key_size_last and query_size_last == value_size_last + ) + if not ( + same_head_dim_size + and (query_size_last % 8 == 0) + and (query_size_last <= max_size) + ): + if debug: + log.warning( + "For NestedTensor inputs, Flash attention requires q,k,v to have the same " + "last dimension and to be a multiple of 8 and less than or equal to 256. 
" + "Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.", + query_size_last, + key_size_last, + value_size_last, + ) + return False + return True + + +def _check_for_seq_len_0_and_consistent_head_dim_nested_helper( + param: torch.Tensor, param_name: str, debug=False +) -> bool: + assert isinstance(param, NestedTensor), "param should be a jagged NT" + + if param._ragged_idx == 1: + # num_head_dims is ragged + if debug: + log.warning( + "Fused kernels do not support ragged num_head_dims, %s has a ragged num_heads.", + param_name, + ) + return False + + # This is being called inside sdp with shape [batch, heads, {seq_len}, dim] + if param._min_seqlen == 0: + if debug: + log.warning( + "Fused kernels do not support seq_len == 0, %s has a seq len of 0.", + param_name, + ) + return False + + return True + + +def _try_broadcast_param_size(q_size, k_size, v_size, param_name, debug=False) -> bool: + max_size = max(q_size, k_size, v_size) + if ( + (q_size != max_size and q_size != 1) + or (k_size != max_size and k_size != 1) + or (v_size != max_size and v_size != 1) + ): + if debug: + log.warning( + "Both fused kernels require query, key and value to have broadcastable %s, " + "got Query %s %d, Key %s %d, Value %s %d instead.", + param_name, + param_name, + q_size, + param_name, + k_size, + param_name, + v_size, + ) + return False + return True + + +def _check_for_seq_len_0_nested(params: SDPAParams, debug=False) -> bool: + # When this function is called we are assured that the nt is dim==4 + q_is_safe = ( + _check_for_seq_len_0_and_consistent_head_dim_nested_helper( + params.query, "query", debug + ) + if params.query.is_nested + else True + ) + # short circuit if any is unsafe + if not q_is_safe: + return False + + k_is_safe = ( + _check_for_seq_len_0_and_consistent_head_dim_nested_helper( + params.key, "key", debug + ) + if params.key.is_nested + else True + ) + # short circuit if any is unsafe + if not k_is_safe: + return False + + v_is_safe = ( + _check_for_seq_len_0_and_consistent_head_dim_nested_helper( + params.value, "value", debug + ) + if params.value.is_nested + else True + ) + # short circuit if any is unsafe + if not v_is_safe: + return False + + # We now know none of the inputs have ragged num_heads, so we can safely + # access .size(1) + q_num_heads = params.query.size(1) + k_num_heads = params.key.size(1) + v_num_heads = params.value.size(1) + same_num_heads = q_num_heads == k_num_heads and q_num_heads == v_num_heads + + if not same_num_heads: + if ( + params.query.requires_grad + or params.key.requires_grad + or params.value.requires_grad + ): + if debug: + log.warning( + "Both fused kernels do not support training with broadcasted NT inputs." + ) + return False + return _try_broadcast_param_size( + q_num_heads, k_num_heads, v_num_heads, "num heads", debug + ) + return True + + +def _check_requires_grad_nested(params: SDPAParams, debug=False) -> bool: + if ( + params.query.requires_grad + or params.key.requires_grad + or params.value.requires_grad + ): + # TODO: This can be done, it just isn't written yet + if debug: + log.warning( + "Memory efficient attention currently doesn't support training with NT inputs." 
+ ) + return False + return True + + +def _can_use_flash_sdpa_jagged(params: SDPAParams, debug=False) -> bool: + constraints = ( + _check_batch_size_nested, + _check_head_dim_size_flash_nested, + _check_for_seq_len_0_nested, + ) + for constraint in constraints: + if not constraint(params, debug): + return False + return True + + +def _can_use_efficient_sdpa_jagged(params: SDPAParams, debug=False) -> bool: + constraints = ( + _check_requires_grad_nested, + _check_batch_size_nested, + _check_for_seq_len_0_nested, + ) + for constraint in constraints: + if not constraint(params, debug): + return False + return True + + +def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool: + if ( + not params.query.transpose(1, 2).is_contiguous() + or not params.key.transpose(1, 2).is_contiguous() + or not params.value.transpose(1, 2).is_contiguous() + ): + if debug: + log.warning( + "If inputs are nested tensors they must be contiguous after transposing." + ) + return False + if params.is_causal: + if debug: + log.warning( + "Nested tensors for query / key are not supported when is_causal=True." + ) + return False + return True + + +def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal): + if ( + not flash_sdp_enabled() + and not mem_efficient_sdp_enabled() + and not math_sdp_enabled() + ): + return SDPBackend.ERROR + + ordering = ( + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, + ) + + params = SDPAParams(query, key, value, attn_mask, dropout, is_causal) + + for backend in ordering: + if backend == SDPBackend.FLASH_ATTENTION: + if can_use_flash_attention(params) and _can_use_flash_sdpa_jagged(params): + return SDPBackend.FLASH_ATTENTION + if backend == SDPBackend.EFFICIENT_ATTENTION: + if can_use_efficient_attention(params) and _can_use_efficient_sdpa_jagged( + params + ): + return SDPBackend.EFFICIENT_ATTENTION + if backend == SDPBackend.MATH: + if math_sdp_enabled() and _can_use_math_sdpa_jagged(params): + return SDPBackend.MATH + + log.warning("Memory efficient kernel not used because:") + can_use_efficient_attention(params, debug=True) + _can_use_efficient_sdpa_jagged(params, debug=True) + log.warning("Flash attention kernel not used because:") + can_use_flash_attention(params, debug=True) + _can_use_flash_sdpa_jagged(params, debug=True) + log.warning("Math attention kernel not used because:") + _can_use_math_sdpa_jagged(params, debug=True) + return SDPBackend.ERROR + + +def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, int, int]: + # This function is used to calculate two pieces of metadata that are needed + # for use with flash-attention and efficient_attention kernels. They are the + # cumulative sequence_length over a batch of sequences and the maximum + # sequence length. 
+ + # It returns a tuple of the cumulative sequence lengths, the maximum sequence + # length, and the last element of the cumulative sequence lengths (i.e. the total + # number of elements, Nnz) + if not isinstance(qkv, NestedTensor): + raise ValueError("QKV must be nested for flash cumulative_seq_len calculation.") + + if qkv.lengths() is None: + # TODO: Explore performance impact of copying + cumulative_seqlen = qkv.offsets().to(dtype=torch.int32, device=qkv.device) + max_seqlen = qkv._max_seqlen + n_elem = qkv.values().shape[0] + else: + # TODO: Explore performance impact of copying + cumulative_seqlen = ( + qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device) + ) + batch_size = qkv.size(0) + max_seqlen = qkv._max_seqlen + # TODO: Explore performance impact when compiling + n_elem = int(cumulative_seqlen[-1].item()) + return cumulative_seqlen, max_seqlen, n_elem + + + def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor): + # This function checks if a nested tensor is valid for + # use with the flash-attention and efficient_attention kernels without + # needing to call contiguous on the nested tensor input. + # It checks that the storage offsets' adjacent_differences are a constant + # multiple of the previous tensor in the nested tensor and that the strides + # are monotonically decreasing. This check is done after calling transpose on + # the nested tensor resulting in an NT of shape [bsz, {seq_len}, num_heads, dim] + + # Returns a boolean indicating if contiguous needs to be called for input + assert isinstance(tensor, NestedTensor) + offsets = tensor.offsets() + strides = tensor._stride + + n_tensors = offsets.size(0) - 1 + if n_tensors <= 1: + return True + + # Check initially that the tensor strides are in strictly descending order + prev_stride = strides[1] + for stride in strides[2:]: + if prev_stride <= stride: + # This would mean that the last stride is greater than the seq_len + # stride + return False + prev_stride = stride + + # Congrats you made it! 
+ return True + + +def _view_as_dense( + tensor: torch.Tensor, Nnz: int, num_heads: int, head_dim: int +) -> torch.Tensor: + if tensor.is_nested: + return tensor.values() + return tensor.view(Nnz, num_heads, head_dim) + + +# TODO: Next iteration should add test cases and check it works +# def _sdpa_nested_preprocessing_with_broadcast(query, key, value): +# # Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head) +# # Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head) +# # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head) +# q_batch_size = query.size(0) +# k_batch_size = key.size(0) +# v_batch_size = value.size(0) + +# output_batch_size = max(q_batch_size, k_batch_size, v_batch_size) + +# q_num_heads = query.size(1) +# k_num_heads = key.size(1) +# v_num_heads = value.size(1) + +# output_num_heads = max(q_num_heads, k_num_heads, v_num_heads) + +# head_dim_qk = query.size(3) +# head_dim_v = value.size(3) + +# q_t = query.transpose(1, 2) +# k_t = key.transpose(1, 2) +# v_t = value.transpose(1, 2) + +# # Checks in sdp_utils ensure that if {*}_batch_size/{*}_num_heads != +# # output_batch_size/num_heads then they are 1 +# q_batch_size_needs_broadcast = q_batch_size != output_batch_size +# k_batch_size_needs_broadcast = k_batch_size != output_batch_size +# v_batch_size_needs_broadcast = v_batch_size != output_batch_size + +# # If {*}_batch_size_needs_broadcast, then +# # (1) max_seqlen_batch_{*} is given by {*}_t.size(1) +# # this is because needs_broadcast indicates that the batch_size is 1 +# # and hence there is only 1 value for seq_len +# # (2) The cum_seq_lens are given by [0, {*}_t.size(1), 2 * {*}_t.size(1), +# # ..., outut_batch_size * {*}_t.size(1)] +# # (3) Nnz_{*} is given by output_batch_size * {*}_t.size(1) + +# if q_batch_size_needs_broadcast or not q_t.is_nested: +# max_seqlen_batch_q = q_t.size(1) +# cumulative_sequence_length_q = torch.arange( +# 0, +# (output_batch_size + 1) * max_seqlen_batch_q, +# max_seqlen_batch_q, +# device=q_t.device, +# dtype=torch.int32, +# ) +# Nnz_q = output_batch_size * max_seqlen_batch_q +# else: +# ( +# cumulative_sequence_length_q, +# max_seqlen_batch_q, +# Nnz_q, +# ) = _cumulative_and_max_seq_len_nnz(q_t) + +# if k_batch_size_needs_broadcast and v_batch_size_needs_broadcast: +# assert k_t.size(1) == v_t.size(1) +# max_seqlen_batch_kv = k_t.size(1) +# cumulative_sequence_length_kv = torch.arange( +# 0, +# (output_batch_size + 1) * max_seqlen_batch_kv, +# max_seqlen_batch_kv, +# device=k_t.device, +# dtype=torch.int32, +# ) +# Nnz_kv = output_batch_size * max_seqlen_batch_kv +# else: +# cumulative_sequence_length_kv, max_seqlen_batch_kv, Nnz_kv = ( +# _cumulative_and_max_seq_len_nnz(v_t) +# if k_batch_size_needs_broadcast +# else _cumulative_and_max_seq_len_nnz(k_t) +# ) + +# q_num_heads_needs_broadcast = q_num_heads != output_num_heads +# k_num_heads_needs_broadcast = k_num_heads != output_num_heads +# v_num_heads_needs_broadcast = v_num_heads != output_num_heads + +# if not q_t.is_nested: +# query_buffer_reshaped = q_t.expand( +# output_batch_size, q_t.size(1), output_num_heads, head_dim_qk +# ) +# query_buffer_reshaped = query_buffer_reshaped.reshape( +# Nnz_q, output_num_heads, head_dim_qk +# ) +# else: +# if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t): +# q_t = q_t.contiguous() +# # If we are broadcasting then Nnz_q will be the output_batch_size since +# # seq_len is 1 +# effective_batch_size_q = ( +# output_batch_size if q_batch_size_needs_broadcast else Nnz_q +# ) +# query_buffer_reshaped = _view_as_dense( 
+# q_t, effective_batch_size_q, output_num_heads, head_dim_qk +# ) + +# # If the physical layout of the NestedTensor's storage +# # is not: batch, {seq_len}, num_heads, head_dim then we need +# # to call contiguous +# if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t): +# k_t = k_t.contiguous() +# if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t): +# v_t = v_t.contiguous() + +# effective_batch_size_k = ( +# output_batch_size if k_batch_size_needs_broadcast else Nnz_kv +# ) +# key_buffer_reshaped = _view_as_dense( +# k_t, effective_batch_size_k, output_num_heads, head_dim_qk +# ) + +# effective_batch_size_v = ( +# output_batch_size if v_batch_size_needs_broadcast else Nnz_kv +# ) +# value_buffer_reshaped = _view_as_dense( +# v_t, effective_batch_size_v, output_num_heads, head_dim_v +# ) + +# if not q_batch_size_needs_broadcast: +# output_shape = q_t._size +# if head_dim_v != head_dim_qk: +# output_shape[-1] = head_dim_v +# if q_num_heads_needs_broadcast: +# output_shape[1] = output_num_heads +# else: +# output_shape = torch.empty(3, dtype=torch.int64, device=torch.device("cpu")) +# output_shape[0] = q_t.size(1) +# output_shape[1] = output_num_heads +# output_shape[2] = head_dim_v + +# return ( +# query_buffer_reshaped, +# key_buffer_reshaped, +# value_buffer_reshaped, +# cumulative_sequence_length_q, +# cumulative_sequence_length_kv, +# max_seqlen_batch_q, +# max_seqlen_batch_kv, +# output_shape, +# ) + + +def _sdpa_nested_preprocessing(query, key, value): + # Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head) + # Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head) + # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head) + q_batch_size = query.size(0) + k_batch_size = key.size(0) + v_batch_size = value.size(0) + + q_num_heads = query.size(1) + k_num_heads = key.size(1) + v_num_heads = value.size(1) + + if not (q_batch_size == k_batch_size and q_batch_size == v_batch_size) or not ( + q_num_heads == k_num_heads and k_num_heads == v_num_heads + ): + raise RuntimeError( + "This path is currently not implemented for jagged layout NT." 
+ ) + # return _sdpa_nested_preprocessing_with_broadcast(query, key, value) + + num_heads = query.size(1) + head_dim_qk = query.size(3) + head_dim_v = value.size(3) + q_t = query.transpose(1, 2) + k_t = key.transpose(1, 2) + v_t = value.transpose(1, 2) + + ( + cumulative_sequence_length_q, + max_seqlen_batch_q, + Nnz_q, + ) = _cumulative_and_max_seq_len_nnz(q_t) + ( + cumulative_sequence_length_kv, + max_seqlen_batch_kv, + Nnz_kv, + ) = _cumulative_and_max_seq_len_nnz(k_t) + + # [TODO] K and V have to have the same Nnz, should probably torch_check + # assume in order to not iterate over v + + # If the physical layout of the NestedTensor's storage + # is not: batch, {seq_len}, num_heads, head_dim then we need + # to call contiguous + if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t): + q_t = q_t.contiguous() + if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t): + k_t = k_t.contiguous() + if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t): + v_t = v_t.contiguous() + + query_buffer_reshaped = _view_as_dense(q_t, Nnz_q, num_heads, head_dim_qk) + key_buffer_reshaped = _view_as_dense(k_t, Nnz_kv, num_heads, head_dim_qk) + value_buffer_reshaped = _view_as_dense(v_t, Nnz_kv, num_heads, head_dim_v) + + output_nt_info = { + "offsets": q_t.offsets(), + "_max_seqlen": q_t._max_seqlen, + "_min_seqlen": q_t._min_seqlen, + } + + return ( + query_buffer_reshaped, + key_buffer_reshaped, + value_buffer_reshaped, + cumulative_sequence_length_q, + cumulative_sequence_length_kv, + max_seqlen_batch_q, + max_seqlen_batch_kv, + output_nt_info, + ) + + +def _pad_last_dim( + tensor: torch.Tensor, alignment_size: int, slice: bool +) -> torch.Tensor: + # FlashAttentionV2 requires that head dimension be a multiple of 8 + # This was previously done within the kernel, however + # This causes the kernel to maybe alias query, key, value + # So instead we pad the head_dimensions to be a multiple of 8 + # in the composite region + last_dim_size = tensor.size(-1) + if last_dim_size % alignment_size == 0: + return tensor + pad_count = alignment_size - (last_dim_size % alignment_size) + tensor = torch.nn.functional.pad(tensor, [0, pad_count]) + if slice: + return tensor[..., 0:last_dim_size] + return tensor + + +# TODO: coalesce with torch/nn/utils/attention.py +def _calculate_scale(query, scale): + softmax_scale = scale if scale is not None else math.sqrt(1.0 / query.size(-1)) + return softmax_scale + + +def _post_process_flash_output(out: torch.Tensor, og_size): + if not out.is_nested and out.size(-1) != og_size: + out = out[..., 0:og_size] + return out + + +def jagged_scaled_dot_product_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p=0.0, + is_causal=False, + scale=None, +): + _validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale) + compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad + + backend_choice = _select_sdp_backend( + query, key, value, attn_mask, dropout_p, is_causal + ) + + if backend_choice == SDPBackend.FLASH_ATTENTION: + og_size = query.size(-1) + query_padded = _pad_last_dim(query, 8, False) + key_padded = _pad_last_dim(key, 8, False) + value_padded = _pad_last_dim(value, 8, False) + # We need to calculate the scale based off the OG head dim size + og_scale = _calculate_scale(query, scale) + ( + query_buffer_reshaped, + key_buffer_reshaped, + value_buffer_reshaped, + cumulative_sequence_length_q, + 
cumulative_sequence_length_kv, + max_seqlen_batch_q, + max_seqlen_batch_kv, + output_nt_info, + ) = _sdpa_nested_preprocessing(query_padded, key_padded, value_padded) + + ( + attention, + logsumexp, + philox_seed, + philox_offset, + debug_attn_mask, + ) = torch.ops.aten._flash_attention_forward( + query_buffer_reshaped, + key_buffer_reshaped, + value_buffer_reshaped, + cumulative_sequence_length_q, + cumulative_sequence_length_kv, + max_seqlen_batch_q, + max_seqlen_batch_kv, + dropout_p, + is_causal, + False, + scale=og_scale, + ) + + # Reshape output to convert nnz to batch_size and seq_len + attention = NestedTensor(attention, **output_nt_info).transpose(1, 2) + return _post_process_flash_output(attention, og_size) + elif backend_choice == SDPBackend.EFFICIENT_ATTENTION: + ( + query_reshaped, + key_reshaped, + value_reshaped, + cumulative_sequence_length_q, + cumulative_sequence_length_kv, + max_seqlen_batch_q, + _, + output_nt_info, + ) = _sdpa_nested_preprocessing(query, key, value) + ( + attention, + log_sumexp, + seed, + offset, + max_seqlen_q, + max_seqlen_batch_kv, + ) = torch.ops.aten._efficient_attention_forward( + query_reshaped.unsqueeze(0), + key_reshaped.unsqueeze(0), + value_reshaped.unsqueeze(0), + None, + cumulative_sequence_length_q, + cumulative_sequence_length_kv, + max_seqlen_batch_q, + dropout_p, + int(is_causal), + compute_logsumexp, + scale=scale, + ) + + # Reshape output to convert nnz to batch_size and seq_len + return NestedTensor(attention.squeeze(0), **output_nt_info).transpose(1, 2) + elif backend_choice == SDPBackend.MATH: + return torch._scaled_dot_product_attention_math( + query, key, value, attn_mask, dropout_p, is_causal, scale=scale + )[0] + else: + raise RuntimeError( + "No viable backend for scaled_dot_product_attention was found." + )
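
For reference, below is a minimal usage sketch of the jagged-layout SDPA path introduced by this patch, condensed from the test_sdpa test added above. It assumes a CUDA device where the flash or memory-efficient kernel is eligible and runs with gradients disabled, mirroring the test; the sizes and projection modules are illustrative.

import torch

device, dtype = "cuda", torch.float16
emb_dims, n_heads = 1024, 8
head_dims = emb_dims // n_heads

# Two variable-length "sentences" packed into a jagged NestedTensor of shape (B, j1, emb_dims)
sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device)
sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device)
x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged)

query = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype)
key = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype)
value = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype)

# Project, split heads, and transpose so the ragged dimension lands at dim 2, giving the
# (B, n_heads, j1, head_dims) layout that jagged_scaled_dot_product_attention expects
q_nt = query(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2)
k_nt = key(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2)
v_nt = value(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2)

with torch.no_grad():
    attn_nt = torch.nn.functional.scaled_dot_product_attention(q_nt, k_nt, v_nt)

# attn_nt is a jagged NestedTensor of shape (B, n_heads, j1, head_dims); after transposing
# back, each unbound component has shape (seq_len_i, n_heads, head_dims)
for component in attn_nt.transpose(1, 2).unbind():
    print(component.shape)

The dispatcher in torch/nested/_internal/sdpa.py tries flash attention first, then memory-efficient attention, then the math fallback, so the same call works wherever at least one backend passes its constraints; note that, as the test comments point out, the math path can still hit ops that are not yet implemented for jagged NTs.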