Commit a58acd2

Pass in lse from the op.
liuliu committed Sep 15, 2024
1 parent c67441f · commit a58acd2
Showing 2 changed files with 55 additions and 35 deletions.
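
For context: the lse threaded through below is bound to AttentionOperand::L in the diff, so it is, assuming the usual FlashAttention convention, the per-query-row log-sum-exp of the attention logits, which the backward pass needs in order to recompute the softmax without storing the full score matrix:

    L_i = \log \sum_j \exp\!\left( \frac{q_i \cdot k_j}{\sqrt{D}} \right)

Before this commit the kernel always wrote L into internal scratch; now the op can pass its own buffer for it as an optional second output.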
@@ -36,6 +36,7 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
 	const int o_nd = ccv_nnc_tensor_nd(o->info.dim);
 	assert(o_nd == 3 || o_nd == 4);
 	assert(q_nd == k_nd && k_nd == v_nd && v_nd == o_nd);
+	ccv_nnc_tensor_view_t* const lse = output_size > 1 ? (ccv_nnc_tensor_view_t*)outputs[1] : 0;
 
 	int qdim[CCV_NNC_MAX_DIM_ALLOC];
 	int kdim[CCV_NNC_MAX_DIM_ALLOC];
@@ -185,20 +186,22 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
 	if (params.masked) {
 		mask_buffer = mpgetbuffer((ccv_nnc_tensor_t*)attn_mask);
 	}
-	mtl_buffer_t* tensors[6] = {
+	mtl_buffer_t* tensors[7] = {
 		mpgetbuffer((ccv_nnc_tensor_t*)q),
 		mpgetbuffer((ccv_nnc_tensor_t*)k),
 		mpgetbuffer((ccv_nnc_tensor_t*)v),
 		mpgetbuffer((ccv_nnc_tensor_t*)o),
 		mask_buffer,
+		lse ? mpgetbuffer((ccv_nnc_tensor_t*)lse) : 0,
 		NULL,
 	};
-	size_t tensor_offsets[5] = {
+	size_t tensor_offsets[6] = {
 		q->dataof,
 		k->dataof,
 		v->dataof,
 		o->dataof,
 		attn_mask ? attn_mask->dataof : 0,
+		lse ? lse->dataof : 0,
 	};
 	ccv_nnc_mfa_encode_attention(context, params, command_batch, tensors, tensor_offsets);
 
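Note that the kernel-side encoder used to validate its inputs by counting array entries up to the first NULL (see the removed lines in the next file). With the new { q, k, v, o, mask, lse, NULL } layout, optional slots can be NULL in the middle of the array, so that count becomes ambiguous, which is presumably why the precondition is dropped. A minimal, self-contained sketch of the ambiguity, using stand-in values rather than the ccv/MFA types:

#include <cstdio>

int main() {
  const void* q = "q"; const void* k = "k"; const void* v = "v"; const void* o = "o";
  const void* lse = "lse";
  // An unmasked call that still wants lse: slot 4 (mask) is NULL, slot 5 is not.
  const void* tensors[7] = { q, k, v, o, NULL, lse, NULL };
  int num_tensors = 0;
  while (tensors[num_tensors] != NULL) // the counting loop removed by this commit
    num_tensors += 1;
  printf("counted %d tensors\n", num_tensors); // prints 4; the lse in slot 5 is never seen
  return 0;
}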
lib/nnc/mfa/ccv_nnc_mfa_attention.cpp (83 changes: 50 additions & 33 deletions)
@@ -20,34 +20,7 @@ void ccv_nnc_mfa_prepare_attention(mfa::context* context, ccv_nnc_mfa_attention_
 void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_params_t params, MTL::CommandBatch* command_batch, MTL::Buffer** tensors, size_t* tensor_offsets)
 {
   mfa::attention::hash hash(params);
-  auto iterator = context->attention_cache.map.find(hash);
-  if (iterator == context->attention_cache.map.end()) {
-    mfa::precondition_failure("Attention hash not cached.", __LINE__, __FILE__, __FUNCTION__);
-  }
-
-  auto* pipeline = iterator->second;
-  auto encoder = command_batch->startCommand();
-
-  int num_tensors = 0;
-  while (tensors[num_tensors] != nullptr) {
-    num_tensors += 1;
-  }
-  CCV_NNC_MFA_PRECONDITION(num_tensors == (hash.masked ? 5 : 4));
-
-  uint16_t data_type_size = 0;
-  switch (params.data_type) {
-    case MTL::DataTypeHalf: {
-      data_type_size = 2;
-      break;
-    }
-    case MTL::DataTypeFloat: {
-      data_type_size = 4;
-      break;
-    }
-    default:
-      CCV_NNC_MFA_PRECONDITION(false);
-      break;
-  }
-
   if (!params.masked && params.Hq == params.Hk) {
     simd::ushort2 num_batch_dims(0);
     simd::uint2 batch_sizes(1);
@@ -121,35 +94,53 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
   auto pipeline = pipelineValue->pipeline;
 
   // Allocate a new command.
+  auto encoder = command_batch->startCommand();
   encoder->setComputePipelineState(pipeline.get());
   encoder->setThreadgroupMemoryLength(kernel->threadgroupMemoryAllocation, 0);
 
   // Bind the function arguments.
   encoder->useResource(tensors[0], MTL::ResourceUsageRead);
   encoder->useResource(tensors[1], MTL::ResourceUsageRead);
   encoder->useResource(tensors[2], MTL::ResourceUsageRead);
-  auto scratch_size = sizeof(float) * hash.R * hash.Hq;
+  auto scratch_size = 0;
   if (attentionDesc.lowPrecisionInputs) {
     // Need scratch space for FP16 output.
     scratch_size += sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension;
   }
-  auto scratch = context->request_scratch(scratch_size);
+  if (!tensors[5]) {
+    // Need scratch space for LSE.
+    scratch_size += sizeof(float) * hash.R * hash.Hq;
+  }
+  auto scratch = scratch_size > 0 ? context->request_scratch(scratch_size) : NULL;
   if (attentionDesc.lowPrecisionInputs) {
     encoder->useResource(scratch, MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
   } else {
     encoder->useResource(tensors[3], MTL::ResourceUsageWrite);
-    encoder->useResource(scratch, MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
+    if (!tensors[5]) {
+      encoder->useResource(scratch, MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
+    }
   }
+  if (tensors[5]) {
+    encoder->useResource(tensors[5], MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
+  }
 
   encoder->setBuffer(tensors[0], tensor_offsets[0], AttentionOperand(AttentionOperand::Q).bufferIndex());
   encoder->setBuffer(tensors[1], tensor_offsets[1], AttentionOperand(AttentionOperand::K).bufferIndex());
   encoder->setBuffer(tensors[2], tensor_offsets[2], AttentionOperand(AttentionOperand::V).bufferIndex());
   if (attentionDesc.lowPrecisionInputs) {
     encoder->setBuffer(scratch, 0, AttentionOperand(AttentionOperand::O).bufferIndex());
-    encoder->setBuffer(scratch, sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension, AttentionOperand(AttentionOperand::L).bufferIndex());
+    if (tensors[5]) {
+      encoder->setBuffer(tensors[5], tensor_offsets[5], AttentionOperand(AttentionOperand::L).bufferIndex());
+    } else {
+      encoder->setBuffer(scratch, sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension, AttentionOperand(AttentionOperand::L).bufferIndex());
+    }
   } else {
     encoder->setBuffer(tensors[3], tensor_offsets[3], AttentionOperand(AttentionOperand::O).bufferIndex());
-    encoder->setBuffer(scratch, 0, AttentionOperand(AttentionOperand::L).bufferIndex());
+    if (tensors[5]) {
+      encoder->setBuffer(tensors[5], tensor_offsets[5], AttentionOperand(AttentionOperand::L).bufferIndex());
+    } else {
+      encoder->setBuffer(scratch, 0, AttentionOperand(AttentionOperand::L).bufferIndex());
+    }
   }
 
   // Calculate the grid size.
@@ -190,6 +181,29 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
     }
     return;
   }
 
+  auto iterator = context->attention_cache.map.find(hash);
+  if (iterator == context->attention_cache.map.end()) {
+    mfa::precondition_failure("Attention hash not cached.", __LINE__, __FILE__, __FUNCTION__);
+  }
+
+  auto* pipeline = iterator->second;
+  auto encoder = command_batch->startCommand();
+
+  uint16_t data_type_size = 0;
+  switch (params.data_type) {
+    case MTL::DataTypeHalf: {
+      data_type_size = 2;
+      break;
+    }
+    case MTL::DataTypeFloat: {
+      data_type_size = 4;
+      break;
+    }
+    default:
+      CCV_NNC_MFA_PRECONDITION(false);
+      break;
+  }
+
   // Simple broadcasting rules; not yet support for NumPy broadcasting rules.
   simd::ushort2 num_batch_dims(0);
@@ -368,6 +382,9 @@ std::size_t std::hash<mfa::attention::hash>::operator()(const mfa::attention::ha
 }
 
 mfa::attention::pipeline::pipeline(mfa::context* context, mfa::attention::hash hash) {
+  if (!hash.masked && hash.Hq == hash.Hk) { // Avoid pipeline setup if we use v2.
+    return;
+  }
   CCV_NNC_MFA_PRECONDITION((hash.data_type == MTL::DataTypeFloat) || (hash.data_type == MTL::DataTypeHalf))
 
   auto* pool = NS::AutoreleasePool::alloc()->init();
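Summing up the encode-path change: scratch is now requested only when something actually needs it, and the L (log-sum-exp) operand is bound either to the caller-provided tensors[5] or to internal scratch. A minimal sketch of the sizing rule as it reads above, with stand-in names rather than the MFA API:

#include <cstddef>
#include <cstdio>

// Mirrors the logic in ccv_nnc_mfa_encode_attention (names are stand-ins):
// an FP32 staging area for O when the inputs are low precision, plus an
// internal L buffer only when the op did not supply tensors[5].
static size_t attention_scratch_size(bool low_precision_inputs, bool has_lse_output,
                                     size_t R, size_t D, size_t Hq, size_t batch) {
  size_t size = 0;
  if (low_precision_inputs)
    size += sizeof(float) * R * D * Hq * batch; // FP32 staging for the FP16 output O
  if (!has_lse_output)
    size += sizeof(float) * R * Hq; // internal log-sum-exp buffer
  return size;
}

int main() {
  // FP32 inputs with an lse output requested: no scratch allocation at all.
  printf("%zu\n", attention_scratch_size(false, true, 1024, 64, 8, 1)); // 0
  // FP16 inputs without an lse output: staging for O plus the internal L buffer.
  printf("%zu\n", attention_scratch_size(true, false, 1024, 64, 8, 1)); // 2129920
  return 0;
}

Previously the scratch request was unconditional and always at least sizeof(float) * hash.R * hash.Hq (the removed line above).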
