From a58acd2f822dfe0ca70da9aa51bae5e7afb1076e Mon Sep 17 00:00:00 2001
From: Liu Liu
Date: Sun, 15 Sep 2024 12:57:28 -0400
Subject: [PATCH] Pass in lse from the op.

---
 ...ccv_nnc_scaled_dot_product_attention_mps.m |  7 +-
 lib/nnc/mfa/ccv_nnc_mfa_attention.cpp         | 83 +++++++++++--------
 2 files changed, 55 insertions(+), 35 deletions(-)

diff --git a/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m b/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m
index 3a2deaed8..7d6fc3958 100644
--- a/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m
+++ b/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m
@@ -36,6 +36,7 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
 	const int o_nd = ccv_nnc_tensor_nd(o->info.dim);
 	assert(o_nd == 3 || o_nd == 4);
 	assert(q_nd == k_nd && k_nd == v_nd && v_nd == o_nd);
+	ccv_nnc_tensor_view_t* const lse = output_size > 1 ? (ccv_nnc_tensor_view_t*)outputs[1] : 0;
 
 	int qdim[CCV_NNC_MAX_DIM_ALLOC];
 	int kdim[CCV_NNC_MAX_DIM_ALLOC];
@@ -185,20 +186,22 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
 		if (params.masked) {
 			mask_buffer = mpgetbuffer((ccv_nnc_tensor_t*)attn_mask);
 		}
-		mtl_buffer_t* tensors[6] = {
+		mtl_buffer_t* tensors[7] = {
 			mpgetbuffer((ccv_nnc_tensor_t*)q),
 			mpgetbuffer((ccv_nnc_tensor_t*)k),
 			mpgetbuffer((ccv_nnc_tensor_t*)v),
 			mpgetbuffer((ccv_nnc_tensor_t*)o),
 			mask_buffer,
+			lse ? mpgetbuffer((ccv_nnc_tensor_t*)lse) : 0,
 			NULL,
 		};
-		size_t tensor_offsets[5] = {
+		size_t tensor_offsets[6] = {
 			q->dataof,
 			k->dataof,
 			v->dataof,
 			o->dataof,
 			attn_mask ? attn_mask->dataof : 0,
+			lse ? lse->dataof : 0,
 		};
 		ccv_nnc_mfa_encode_attention(context, params, command_batch, tensors, tensor_offsets);
 
diff --git a/lib/nnc/mfa/ccv_nnc_mfa_attention.cpp b/lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
index 5aab76377..dadbc8a6f 100644
--- a/lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
+++ b/lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
@@ -20,34 +20,7 @@ void ccv_nnc_mfa_prepare_attention(mfa::context* context, ccv_nnc_mfa_attention_
 void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_params_t params, MTL::CommandBatch* command_batch, MTL::Buffer** tensors, size_t* tensor_offsets)
 {
   mfa::attention::hash hash(params);
-  auto iterator = context->attention_cache.map.find(hash);
-  if (iterator == context->attention_cache.map.end()) {
-    mfa::precondition_failure("Attention hash not cached.", __LINE__, __FILE__, __FUNCTION__);
-  }
-
-  auto* pipeline = iterator->second;
-  auto encoder = command_batch->startCommand();
-
-  int num_tensors = 0;
-  while (tensors[num_tensors] != nullptr) {
-    num_tensors += 1;
-  }
-  CCV_NNC_MFA_PRECONDITION(num_tensors == (hash.masked ? 5 : 4));
-
-  uint16_t data_type_size = 0;
-  switch (params.data_type) {
-    case MTL::DataTypeHalf: {
-      data_type_size = 2;
-      break;
-    }
-    case MTL::DataTypeFloat: {
-      data_type_size = 4;
-      break;
-    }
-    default:
-      CCV_NNC_MFA_PRECONDITION(false);
-      break;
-  }
+
   if (!params.masked && params.Hq == params.Hk) {
     simd::ushort2 num_batch_dims(0);
     simd::uint2 batch_sizes(1);
@@ -121,6 +94,7 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
     auto pipeline = pipelineValue->pipeline;
 
     // Allocate a new command.
+    auto encoder = command_batch->startCommand();
     encoder->setComputePipelineState(pipeline.get());
     encoder->setThreadgroupMemoryLength(kernel->threadgroupMemoryAllocation, 0);
 
@@ -128,17 +102,26 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
     encoder->useResource(tensors[0], MTL::ResourceUsageRead);
     encoder->useResource(tensors[1], MTL::ResourceUsageRead);
     encoder->useResource(tensors[2], MTL::ResourceUsageRead);
-    auto scratch_size = sizeof(float) * hash.R * hash.Hq;
+    auto scratch_size = 0;
     if (attentionDesc.lowPrecisionInputs) {
       // Need scratch space for FP16 output.
       scratch_size += sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension;
     }
-    auto scratch = context->request_scratch(scratch_size);
+    if (!tensors[5]) {
+      // Need scratch space for LSE.
+      scratch_size += sizeof(float) * hash.R * hash.Hq;
+    }
+    auto scratch = scratch_size > 0 ? context->request_scratch(scratch_size) : NULL;
     if (attentionDesc.lowPrecisionInputs) {
       encoder->useResource(scratch, MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
     } else {
       encoder->useResource(tensors[3], MTL::ResourceUsageWrite);
-      encoder->useResource(scratch, MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
+      if (!tensors[5]) {
+        encoder->useResource(scratch, MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
+      }
+    }
+    if (tensors[5]) {
+      encoder->useResource(tensors[5], MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
     }
 
     encoder->setBuffer(tensors[0], tensor_offsets[0], AttentionOperand(AttentionOperand::Q).bufferIndex());
@@ -146,10 +129,18 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
     encoder->setBuffer(tensors[2], tensor_offsets[2], AttentionOperand(AttentionOperand::V).bufferIndex());
     if (attentionDesc.lowPrecisionInputs) {
       encoder->setBuffer(scratch, 0, AttentionOperand(AttentionOperand::O).bufferIndex());
-      encoder->setBuffer(scratch, sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension, AttentionOperand(AttentionOperand::L).bufferIndex());
+      if (tensors[5]) {
+        encoder->setBuffer(tensors[5], tensor_offsets[5], AttentionOperand(AttentionOperand::L).bufferIndex());
+      } else {
+        encoder->setBuffer(scratch, sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension, AttentionOperand(AttentionOperand::L).bufferIndex());
+      }
     } else {
       encoder->setBuffer(tensors[3], tensor_offsets[3], AttentionOperand(AttentionOperand::O).bufferIndex());
-      encoder->setBuffer(scratch, 0, AttentionOperand(AttentionOperand::L).bufferIndex());
+      if (tensors[5]) {
+        encoder->setBuffer(tensors[5], tensor_offsets[5], AttentionOperand(AttentionOperand::L).bufferIndex());
+      } else {
+        encoder->setBuffer(scratch, 0, AttentionOperand(AttentionOperand::L).bufferIndex());
+      }
     }
 
     // Calculate the grid size.
@@ -190,6 +181,29 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
     }
     return;
   }
+
+  auto iterator = context->attention_cache.map.find(hash);
+  if (iterator == context->attention_cache.map.end()) {
+    mfa::precondition_failure("Attention hash not cached.", __LINE__, __FILE__, __FUNCTION__);
+  }
+
+  auto* pipeline = iterator->second;
+  auto encoder = command_batch->startCommand();
+
+  uint16_t data_type_size = 0;
+  switch (params.data_type) {
+    case MTL::DataTypeHalf: {
+      data_type_size = 2;
+      break;
+    }
+    case MTL::DataTypeFloat: {
+      data_type_size = 4;
+      break;
+    }
+    default:
+      CCV_NNC_MFA_PRECONDITION(false);
+      break;
+  }
 
   // Simple broadcasting rules; not yet support for NumPy broadcasting rules.
   simd::ushort2 num_batch_dims(0);
@@ -368,6 +382,9 @@ std::size_t std::hash<mfa::attention::hash>::operator()(const mfa::attention::ha
 }
 
 mfa::attention::pipeline::pipeline(mfa::context* context, mfa::attention::hash hash) {
+  if (!hash.masked && hash.Hq == hash.Hk) { // Avoid pipeline setup if we use v2.
+    return;
+  }
   CCV_NNC_MFA_PRECONDITION((hash.data_type == MTL::DataTypeFloat) || (hash.data_type == MTL::DataTypeHalf))
 
   auto* pool = NS::AutoreleasePool::alloc()->init();
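
Note (not part of the patch): the caller-visible change is that the per-row log-sum-exp (LSE) now travels as the command's second output, and the Metal backend binds it to AttentionOperand::L instead of backend-internal scratch whenever it is present. A minimal sketch of such a call follows; it assumes the scaled dot product attention forward cmd and the GPU tensors are created elsewhere, and the helper name is hypothetical, not taken from the repository.

#include <nnc/ccv_nnc.h>

// Hypothetical helper: run scaled dot product attention forward and also
// request the per-row log-sum-exp. `cmd` is assumed to be the SDPA forward
// command; q/k/v/o/lse are assumed to be GPU tensors of compatible shapes.
static int attention_forward_with_lse(const ccv_nnc_cmd_t cmd, ccv_nnc_tensor_t* const q, ccv_nnc_tensor_t* const k, ccv_nnc_tensor_t* const v, ccv_nnc_tensor_t* const o, ccv_nnc_tensor_t* const lse, ccv_nnc_stream_context_t* const stream_context)
{
	ccv_nnc_tensor_t* const inputs[3] = { q, k, v }; // no attention mask, so params.masked stays false
	ccv_nnc_tensor_t* const outputs[2] = { o, lse }; // outputs[1] is picked up as `lse` by the MPS backend
	// Passing output_size == 2 makes `output_size > 1` true in the patched
	// _ccv_nnc_scaled_dot_product_attention_forw, so the LSE tensor is bound to
	// the AttentionOperand::L buffer instead of backend-internal scratch.
	return ccv_nnc_cmd_exec(cmd, ccv_nnc_no_hint, 0, inputs, 3, outputs, 2, stream_context);
}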