From 3cc2117750b2b83755aab9b6072ca941f69b3931 Mon Sep 17 00:00:00 2001
From: Liu Liu
Date: Fri, 13 Sep 2024 14:23:41 -0400
Subject: [PATCH] Pass square_attention_test.

---
 bin/nnc/square_attention_test.cpp            | 444 +++++++++++++++++++
 lib/nnc/mfa/v2/AttentionDescriptor.cpp       |  14 +-
 lib/nnc/mfa/v2/AttentionDescriptor.hpp       |   2 +
 lib/nnc/mfa/v2/AttentionKernel.cpp           | 105 +++--
 lib/nnc/mfa/v2/AttentionKernel.hpp           |   5 +
 lib/nnc/mfa/v2/AttentionKernelDescriptor.cpp |   3 +-
 lib/nnc/mfa/v2/AttentionKernelDescriptor.hpp |   4 +-
 7 files changed, 520 insertions(+), 57 deletions(-)
 create mode 100644 bin/nnc/square_attention_test.cpp

diff --git a/bin/nnc/square_attention_test.cpp b/bin/nnc/square_attention_test.cpp
new file mode 100644
index 000000000..64da1de14
--- /dev/null
+++ b/bin/nnc/square_attention_test.cpp
@@ -0,0 +1,444 @@
+extern "C" {
+#include <ccv.h>
+#include <ccv_internal.h>
+#include <nnc/ccv_nnc.h>
+#include <nnc/ccv_nnc_easy.h>
+}
+#include "nnc/mfa/v2/ShaderCache.hpp"
+#include "nnc/mfa/v2/AttentionDescriptor.hpp"
+#include "nnc/mfa/v2/AttentionKernelDescriptor.hpp"
+#include "nnc/mfa/v2/AttentionKernel.hpp"
+#include "3rdparty/dsfmt/dSFMT.h"
+#include <simd/simd.h>
+
+#include <vector>
+#include <random>
+#include <cmath>
+#include <algorithm>
+#include <iostream>
+
+struct NetworkDescriptor {
+  int rowDimension;
+  int columnDimension;
+  int headDimension;
+  float scale;
+};
+
+class Network {
+private:
+  int rowDimension;
+  int columnDimension;
+  int headDimension;
+  float scale;
+
+  static std::pair<float, float> boxMullerTransform() {
+    std::random_device rd;
+    std::mt19937 gen(rd());
+    std::uniform_real_distribution<> dis(0.0, 1.0);
+
+    float u1 = dis(gen);
+    float u2 = dis(gen);
+
+    float magnitudePart = std::sqrt(-2.0f * std::log(u1));
+    float anglePart = 2.0f * M_PI * u2;
+
+    return {
+      magnitudePart * std::cos(anglePart),
+      magnitudePart * std::sin(anglePart)
+    };
+  }
+
+public:
+  std::vector<float> Q;
+  std::vector<float> K;
+  std::vector<float> V;
+  std::vector<float> dO;
+
+  Network(const NetworkDescriptor& descriptor)
+    : rowDimension(descriptor.rowDimension),
+      columnDimension(descriptor.columnDimension),
+      headDimension(descriptor.headDimension),
+      scale(descriptor.scale),
+      Q(rowDimension * headDimension),
+      K(columnDimension * headDimension),
+      V(columnDimension * headDimension),
+      dO(rowDimension * headDimension)
+  {
+    if (rowDimension <= 0 || columnDimension <= 0 || headDimension <= 0) {
+      throw std::runtime_error("Descriptor was incomplete.");
+    }
+
+    for (int rowID = 0; rowID < rowDimension; ++rowID) {
+      for (int d = 0; d < headDimension; ++d) {
+        int matrixAddress = rowID * headDimension + d;
+        auto [r1, r2] = boxMullerTransform();
+        Q[matrixAddress] = r1;
+        dO[matrixAddress] = r2;
+      }
+    }
+
+    for (int columnID = 0; columnID < columnDimension; ++columnID) {
+      for (int d = 0; d < headDimension; ++d) {
+        int matrixAddress = columnID * headDimension + d;
+        auto [r1, r2] = boxMullerTransform();
+        K[matrixAddress] = r1;
+        V[matrixAddress] = r2;
+      }
+    }
+  }
+
+  std::vector<float> createMatrixSRow(int rowID) const {
+    std::vector<float> output(columnDimension, 0.0f);
+
+    for (int columnID = 0; columnID < columnDimension; ++columnID) {
+      float dotProduct = 0.0f;
+      for (int d = 0; d < headDimension; ++d) {
+        int addressQ = rowID * headDimension + d;
+        int addressK = columnID * headDimension + d;
+        dotProduct += Q[addressQ] * K[addressK];
+      }
+      output[columnID] = dotProduct;
+    }
+
+    return output;
+  }
+
+  std::vector<float> createMatrixPRow(int rowID) const {
+    std::vector<float> output = createMatrixSRow(rowID);
+    float scaleFactor = scale;
+
+    float maximum = *std::max_element(output.begin(), output.end()) * scaleFactor;
+
+    float sum = 0.0f;
+    for (float& value : output) {
+      value *= scaleFactor;
+      float expTerm = std::exp(value - maximum);
+      sum += expTerm;
+    }
+
+    float lse = maximum + std::log(sum);
+    for (float& value : output) {
+      value = std::exp(value - lse);
+    }
+
+    return output;
+  }
+
+  float createLTerm(int rowID) const {
+    std::vector<float> matrixSRow = createMatrixSRow(rowID);
+    float scaleFactor = 1.0f / std::sqrt(static_cast<float>(headDimension));
+
+    float maximum = *std::max_element(matrixSRow.begin(), matrixSRow.end()) * scaleFactor;
+
+    float sum = 0.0f;
+    for (float value : matrixSRow) {
+      value *= scaleFactor;
+      float expTerm = std::exp(value - maximum);
+      sum += expTerm;
+    }
+
+    return maximum + std::log(sum);
+  }
+
+  std::vector<float> createDerivativePRow(int rowID) const {
+    std::vector<float> output(columnDimension, 0.0f);
+    for (int columnID = 0; columnID < columnDimension; ++columnID) {
+      float dotProduct = 0.0f;
+      for (int d = 0; d < headDimension; ++d) {
+        int addressO = rowID * headDimension + d;
+        int addressV = columnID * headDimension + d;
+        dotProduct += dO[addressO] * V[addressV];
+      }
+      output[columnID] = dotProduct;
+    }
+    return output;
+  }
+
+  std::vector<float> createDerivativeSRow(int rowID) const {
+    std::vector<float> matrixPRow = createMatrixPRow(rowID);
+    std::vector<float> matrixORow(headDimension, 0.0f);
+
+    for (int d = 0; d < headDimension; ++d) {
+      float dotProduct = 0.0f;
+      for (int columnID = 0; columnID < columnDimension; ++columnID) {
+        float valueP = matrixPRow[columnID];
+        int addressV = columnID * headDimension + d;
+        dotProduct += valueP * V[addressV];
+      }
+      matrixORow[d] = dotProduct;
+    }
+
+    float termD = 0.0f;
+    for (int d = 0; d < headDimension; ++d) {
+      int addressDerivativeO = rowID * headDimension + d;
+      termD += matrixORow[d] * dO[addressDerivativeO];
+    }
+
+    std::vector<float> derivativeSRow(columnDimension, 0.0f);
+    std::vector<float> derivativePRow = createDerivativePRow(rowID);
+    float scaleFactor = 1.0f / std::sqrt(static_cast<float>(headDimension));
+
+    for (int columnID = 0; columnID < columnDimension; ++columnID) {
+      float valueP = matrixPRow[columnID];
+      float valueDerivativeP = derivativePRow[columnID];
+      float valueS = valueP * (valueDerivativeP - termD);
+      valueS *= scaleFactor;
+      derivativeSRow[columnID] = valueS;
+    }
+
+    return derivativeSRow;
+  }
+
+  float createDTerm(int rowID) const {
+    std::vector<float> matrixPRow = createMatrixPRow(rowID);
+    std::vector<float> matrixORow(headDimension, 0.0f);
+
+    for (int d = 0; d < headDimension; ++d) {
+      float dotProduct = 0.0f;
+      for (int columnID = 0; columnID < columnDimension; ++columnID) {
+        float valueP = matrixPRow[columnID];
+        int addressV = columnID * headDimension + d;
+        dotProduct += valueP * V[addressV];
+      }
+      matrixORow[d] = dotProduct;
+    }
+
+    float termD = 0.0f;
+    for (int d = 0; d < headDimension; ++d) {
+      int addressDerivativeO = rowID * headDimension + d;
+      termD += matrixORow[d] * dO[addressDerivativeO];
+    }
+    return termD;
+  }
+
+  std::vector<float> inferenceAttention() const {
+    std::vector<float> output(rowDimension * headDimension, 0.0f);
+    for (int rowID = 0; rowID < rowDimension; ++rowID) {
+      std::vector<float> matrixPRow = createMatrixPRow(rowID);
+      std::vector<float> matrixORow(headDimension, 0.0f);
+
+      for (int d = 0; d < headDimension; ++d) {
+        float dotProduct = 0.0f;
+        for (int columnID = 0; columnID < columnDimension; ++columnID) {
+          float valueP = matrixPRow[columnID];
+          int addressV = columnID * headDimension + d;
+          dotProduct += valueP * V[addressV];
+        }
+        matrixORow[d] = dotProduct;
+      }
+
+      for (int d = 0; d < headDimension; ++d) {
+        float valueO = matrixORow[d];
+        int addressO = rowID * headDimension + d;
+        output[addressO] = valueO;
+      }
+    }
+
+    return output;
+  }
+
+  float loss() const {
+    std::vector<float> O = inferenceAttention();
+    float output = 0.0f;
+
+    for (int rowID = 0; rowID < rowDimension; ++rowID) {
+      for (int d = 0; d < headDimension; ++d) {
+        int address = rowID * headDimension + d;
+        output += dO[address] * O[address];
+      }
+    }
+    return output;
+  }
+
+  std::vector<float> derivativeV() const {
+    std::vector<float> output(columnDimension * headDimension, 0.0f);
+
+    for (int rowID = 0; rowID < rowDimension; ++rowID) {
+      std::vector<float> matrixPRow = createMatrixPRow(rowID);
+
+      for (int columnID = 0; columnID < columnDimension; ++columnID) {
+        for (int d = 0; d < headDimension; ++d) {
+          int addressV = columnID * headDimension + d;
+          int addressDerivativeO = rowID * headDimension + d;
+
+          output[addressV] += matrixPRow[columnID] * dO[addressDerivativeO];
+        }
+      }
+    }
+    return output;
+  }
+
+  std::vector<float> derivativeK() const {
+    std::vector<float> output(columnDimension * headDimension, 0.0f);
+
+    for (int rowID = 0; rowID < rowDimension; ++rowID) {
+      std::vector<float> derivativeSRow = createDerivativeSRow(rowID);
+
+      for (int columnID = 0; columnID < columnDimension; ++columnID) {
+        for (int d = 0; d < headDimension; ++d) {
+          int addressK = columnID * headDimension + d;
+          int addressQ = rowID * headDimension + d;
+
+          output[addressK] += derivativeSRow[columnID] * Q[addressQ];
+        }
+      }
+    }
+    return output;
+  }
+
+  std::vector<float> derivativeQ() const {
+    std::vector<float> output(rowDimension * headDimension, 0.0f);
+
+    for (int rowID = 0; rowID < rowDimension; ++rowID) {
+      std::vector<float> derivativeSRow = createDerivativeSRow(rowID);
+      std::vector<float> derivativeQRow(headDimension, 0.0f);
+
+      for (int d = 0; d < headDimension; ++d) {
+        float dotProduct = 0.0f;
+        for (int columnID = 0; columnID < columnDimension; ++columnID) {
+          float derivativeSValue = derivativeSRow[columnID];
+          int addressK = columnID * headDimension + d;
+          dotProduct += derivativeSValue * K[addressK];
+        }
+        derivativeQRow[d] = dotProduct;
+      }
+
+      for (int d = 0; d < headDimension; ++d) {
+        float derivativeQValue = derivativeQRow[d];
+        int addressQ = rowID * headDimension + d;
+        output[addressQ] = derivativeQValue;
+      }
+    }
+
+    return output;
+  }
+};
+
+ShaderCache shaderCache;
+
+void validateProblemSize(int sequenceDimension, int headDimension)
+{
+  NetworkDescriptor networkDesc;
+  networkDesc.rowDimension = sequenceDimension;
+  networkDesc.columnDimension = sequenceDimension;
+  networkDesc.headDimension = headDimension;
+  networkDesc.scale = 1.0 / sqrtf((float)headDimension);
+  Network network(networkDesc);
+  AttentionDescriptor attentionDesc;
+  attentionDesc.lowPrecisionInputs = false;
+  attentionDesc.lowPrecisionIntermediates = false;
+  attentionDesc.matrixDimensions[0] = sequenceDimension;
+  attentionDesc.matrixDimensions[1] = sequenceDimension;
+  attentionDesc.matrixDimensions[2] = headDimension;
+  attentionDesc.transposeState[0] = false;
+  attentionDesc.transposeState[1] = false;
+  attentionDesc.transposeState[2] = false;
+  attentionDesc.transposeState[3] = false;
+  attentionDesc.type = AttentionKernelType::forward;
+  attentionDesc.scale = 1.0 / sqrtf((float)headDimension);
+
+  DeviceProperties dprops;
+  dprops.coreCount = 18;
+  NS::SharedPtr<MTL::Device> device = NS::TransferPtr(MTL::CreateSystemDefaultDevice());
+  NS::SharedPtr<MTL::CommandQueue> queue = NS::TransferPtr(device->newCommandQueue());
+  {
+    // Generate the kernel.
+    auto pipelineValue = shaderCache.findKernel(attentionDesc, device.get(), dprops);
+    NS::SharedPtr<MTL::Buffer> bufferQ = NS::TransferPtr(device->newBuffer(network.Q.data(), sizeof(float) * sequenceDimension * headDimension, MTL::ResourceStorageModeShared | MTL::ResourceHazardTrackingModeTracked));
+    NS::SharedPtr<MTL::Buffer> bufferK = NS::TransferPtr(device->newBuffer(network.K.data(), sizeof(float) * sequenceDimension * headDimension, MTL::ResourceStorageModeShared | MTL::ResourceHazardTrackingModeTracked));
+    NS::SharedPtr<MTL::Buffer> bufferV = NS::TransferPtr(device->newBuffer(network.V.data(), sizeof(float) * sequenceDimension * headDimension, MTL::ResourceStorageModeShared | MTL::ResourceHazardTrackingModeTracked));
+    float* resultO = (float*)ccmalloc(sizeof(float) * sequenceDimension * headDimension);
+    resultO[0] = NAN;
+    NS::SharedPtr<MTL::Buffer> bufferO = NS::TransferPtr(device->newBuffer(resultO, sizeof(float) * sequenceDimension * headDimension, MTL::ResourceStorageModeShared | MTL::ResourceHazardTrackingModeTracked));
+    NS::SharedPtr<MTL::CommandBuffer> commandBuffer = NS::TransferPtr(queue->commandBuffer());
+    NS::SharedPtr<MTL::ComputeCommandEncoder> encoder = NS::TransferPtr(commandBuffer->computeCommandEncoder());
+    encoder->setComputePipelineState(pipelineValue->pipeline.get());
+    encoder->setThreadgroupMemoryLength(pipelineValue->kernel->threadgroupMemoryAllocation, 0);
+    encoder->setBuffer(bufferQ.get(), 0, 0);
+    encoder->setBuffer(bufferK.get(), 0, 1);
+    encoder->setBuffer(bufferV.get(), 0, 2);
+    encoder->setBuffer(bufferO.get(), 0, 3);
+    encoder->useResource(bufferQ.get(), MTL::ResourceUsageRead);
+    encoder->useResource(bufferK.get(), MTL::ResourceUsageRead);
+    encoder->useResource(bufferV.get(), MTL::ResourceUsageRead);
+    encoder->useResource(bufferO.get(), MTL::ResourceUsageWrite);
+    auto ceilDivide =
+      [=](int64_t target, uint16_t granularity) -> int64_t {
+        return (target + int64_t(granularity) - 1) / int64_t(granularity);
+      };
+    MTL::Size gridSize = MTL::Size(ceilDivide(sequenceDimension, pipelineValue->kernel->blockDimensions[0]), 1, 1);
+    MTL::Size groupSize = MTL::Size(pipelineValue->kernel->threadgroupSize, 1, 1);
+    encoder->dispatchThreadgroups(gridSize, groupSize);
+    encoder->endEncoding();
+    commandBuffer->commit();
+    commandBuffer->waitUntilCompleted();
+    auto start = commandBuffer->GPUStartTime();
+    auto end = commandBuffer->GPUEndTime();
+    auto latency = end - start;
+    auto O = network.inferenceAttention();
+    auto raw = bufferO->contents();
+    for (int rowID = 0; rowID < sequenceDimension; rowID++)
+    {
+      for (int columnID = 0; columnID < headDimension; columnID++)
+      {
+        const int address = rowID * headDimension + columnID;
+        float entry32;
+        entry32 = ((float*)raw)[address];
+        resultO[address] = entry32;
+      }
+    }
+    auto check = [=](std::vector<float> expected, float* actual, float tolerance) {
+      int errorCount = 0;
+      for (int i = 0; i < expected.size(); i++) {
+        auto error = fabs(expected[i] - actual[i]);
+        if (error > tolerance || isnan(error)) {
+          // Don't report errors in this case.
+          if ((isnan(expected[i]) || isinf(expected[i])) && (isnan(actual[i]) || isinf(actual[i]))) {
+            continue;
+          }
+
+          // Update the error count in the outer scope.
+ if (errorCount < 10) { + errorCount += 1; + std::cerr << "error: "<< error << " / ~1.000" << std::endl; + std::cerr << "- expected[" << i << "] =" << expected[i] << std::endl; + std::cerr << "- actual[" << i << "] =" << actual[i] << std::endl; + } + } + } + }; + if (attentionDesc.lowPrecisionInputs || attentionDesc.lowPrecisionIntermediates) { + check(O, resultO, 5e-2); + } else { + check(O, resultO, 2e-5); + } + } +} + +int main(int argc, char** argv) +{ + ccv_nnc_init(); + { + validateProblemSize(10, 3); + validateProblemSize(10, 80); + validateProblemSize(8, 2); + validateProblemSize(9, 2); + validateProblemSize(23, 2); + validateProblemSize(24, 2); + validateProblemSize(25, 2); + validateProblemSize(192, 77); + validateProblemSize(192, 80); + validateProblemSize(93, 32); + validateProblemSize(99, 35); + validateProblemSize(64, 32); + validateProblemSize(64, 34); + validateProblemSize(64, 36); + validateProblemSize(64, 40); + validateProblemSize(32, 64); + validateProblemSize(4, 1); + validateProblemSize(4, 2); + validateProblemSize(384, 95); + validateProblemSize(777, 199); + } + return 0; +} diff --git a/lib/nnc/mfa/v2/AttentionDescriptor.cpp b/lib/nnc/mfa/v2/AttentionDescriptor.cpp index 6fd30d3d2..909aed059 100644 --- a/lib/nnc/mfa/v2/AttentionDescriptor.cpp +++ b/lib/nnc/mfa/v2/AttentionDescriptor.cpp @@ -81,9 +81,9 @@ AttentionKernelDescriptor AttentionDescriptor::kernelDescriptor(MTL::Device *con }; if (device->supportsFamily(MTL::GPUFamily(1009))) { - return AttentionKernelDescriptor(createBlockDimensions(), createCacheState(), createHeadDimension(), createMemoryPrecisions(), true, false, createRegisterPrecisions(device), createTransposeState(), type); + return AttentionKernelDescriptor(createBlockDimensions(), createCacheState(), createHeadDimension(), createMemoryPrecisions(), true, false, createRegisterPrecisions(device), createTransposeState(), type, scale); } else { - return AttentionKernelDescriptor(createBlockDimensions(), createCacheState(), createHeadDimension(), createMemoryPrecisions(), false, true, createRegisterPrecisions(device), createTransposeState(), type); + return AttentionKernelDescriptor(createBlockDimensions(), createCacheState(), createHeadDimension(), createMemoryPrecisions(), false, true, createRegisterPrecisions(device), createTransposeState(), type, scale); } } @@ -100,12 +100,14 @@ std::pair *> Attention NS::String* swiftName = NS::String::string("attention", NS::UTF8StringEncoding); NS::Error* error = nil; - - auto function = NS::TransferPtr - (library->newFunction(swiftName, constants.get(), &error)); + + auto pipelineDesc = NS::TransferPtr(MTL::ComputePipelineDescriptor::alloc()->init()); + pipelineDesc->setComputeFunction(NS::TransferPtr + (library->newFunction(swiftName, constants.get(), &error)).get()); + pipelineDesc->setMaxTotalThreadsPerThreadgroup(1024); CCV_NNC_MFA_CHECK_ERROR(error); - auto pipeline = device->newComputePipelineState(function.get(), &error); + auto pipeline = device->newComputePipelineState(pipelineDesc.get(), MTL::PipelineOptionNone, NULL, &error); CCV_NNC_MFA_CHECK_ERROR(error); return pipeline; }; diff --git a/lib/nnc/mfa/v2/AttentionDescriptor.hpp b/lib/nnc/mfa/v2/AttentionDescriptor.hpp index eb6e4562a..42fb8463a 100644 --- a/lib/nnc/mfa/v2/AttentionDescriptor.hpp +++ b/lib/nnc/mfa/v2/AttentionDescriptor.hpp @@ -29,6 +29,8 @@ struct AttentionDescriptor { AttentionKernelType type; + float scale; + bool operator==(const AttentionDescriptor& rhs) const; std::pair *> findKernel(MTL::Device* const device, const 
DeviceProperties &dprops, std::unordered_map> *const libraryCache) const noexcept; diff --git a/lib/nnc/mfa/v2/AttentionKernel.cpp b/lib/nnc/mfa/v2/AttentionKernel.cpp index bdfc7f387..1f1ac7f76 100644 --- a/lib/nnc/mfa/v2/AttentionKernel.cpp +++ b/lib/nnc/mfa/v2/AttentionKernel.cpp @@ -4,25 +4,28 @@ #include "../ccv_nnc_mfa.hpp" #include +#include AttentionKernel::AttentionKernel(AttentionKernelDescriptor descriptor, MTL::Device *const device) { - this->type = descriptor.type; - this->cacheState = descriptor.cacheState; - this->memoryPrecisions = descriptor.memoryPrecisions; - this->preferAsyncCache = descriptor.preferAsyncCache; - this->preferAsyncLoad = descriptor.preferAsyncLoad; - this->registerPrecisions = descriptor.registerPrecisions; - this->transposeState = descriptor.transposeState; + type = descriptor.type; + cacheState = descriptor.cacheState; + memoryPrecisions = descriptor.memoryPrecisions; + preferAsyncCache = descriptor.preferAsyncCache; + preferAsyncLoad = descriptor.preferAsyncLoad; + registerPrecisions = descriptor.registerPrecisions; + transposeState = descriptor.transposeState; - this->blockDimensions = descriptor.blockDimensions; - this->headDimension = descriptor.headDimension; + blockDimensions = descriptor.blockDimensions; + headDimension = descriptor.headDimension; - source = createSource(); + scale = descriptor.scale; - std::cout << source << std::endl; + source = createSource(); threadgroupMemoryAllocation = createThreadgroupMemoryAllocation(); + threadgroupSize = 32 * (blockDimensions[0] / 8); + // Compile the shader source. { auto string = NS::String::string(source.c_str(), NS::UTF8StringEncoding); @@ -843,7 +846,7 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc source.SetValue("MEMORY_NAME_C", memoryName(C)); source.SetValue("LEADING_DIMENSION_C", leadingDimension(C)); source.SetValue("LEADING_BLOCK_DIMENSION_C", std::to_string(leadingBlockDimension(C))); - source.SetValue("TRANSPOSED_C", std::to_string(transposed(C))); + source.SetValue("TRANSPOSED_C", transposed(C) ? "true" : "false"); switch (descriptor.addressSpaceLHS.value) { case MTLAddressSpace::device: source += R"( @@ -888,7 +891,7 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc source.SetValue("MEMORY_NAME_C", memoryName(C)); source.SetValue("LEADING_DIMENSION_C", leadingDimension(C)); source.SetValue("LEADING_BLOCK_DIMENSION_C", std::to_string(leadingBlockDimension(C))); - source.SetValue("TRANSPOSED_C", std::to_string(transposed(C))); + source.SetValue("TRANSPOSED_C", transposed(C) ? "true" : "false"); source += R"( threadgroup_barrier(mem_flags::mem_threadgroup); @@ -931,7 +934,7 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc source.SetValue("MEMORY_NAME_C", memoryName(C)); source.SetValue("LEADING_DIMENSION_C", leadingDimension(C)); source.SetValue("LEADING_BLOCK_DIMENSION_C", std::to_string(leadingBlockDimension(C))); - source.SetValue("TRANSPOSED_C", std::to_string(transposed(C))); + source.SetValue("TRANSPOSED_C", transposed(C) ? 
"true" : "false"); source += R"( threadgroup_barrier(mem_flags::mem_threadgroup); @@ -972,7 +975,7 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc source.SetValue("LOAD_FUNCTION_C", loadFunction(C)); source.SetValue("LEADING_DIMENSION_C", leadingDimension(C)); source.SetValue("LEADING_BLOCK_DIMENSION_C", std::to_string(leadingBlockDimension(C))); - source.SetValue("TRANSPOSED_C", std::to_string(transposed(C))); + source.SetValue("TRANSPOSED_C", transposed(C) ? "true" : "false"); switch (descriptor.addressSpaceLHS.value) { case MTLAddressSpace::device: source += R"( @@ -1019,7 +1022,7 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc source.SetValue("STORE_FUNCTION_C", storeFunction(C)); source.SetValue("LEADING_DIMENSION_C", leadingDimension(C)); source.SetValue("LEADING_BLOCK_DIMENSION_C", std::to_string(leadingBlockDimension(C))); - source.SetValue("TRANSPOSED_C", std::to_string(transposed(C))); + source.SetValue("TRANSPOSED_C", transposed(C) ? "true" : "false"); source.SetValue("UNSAFE_PARALLELIZATION_THREAD_OFFSET", unsafeParallelizationThreadOffsetValue()); source.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue()); switch (descriptor.addressSpaceLHS.value) { @@ -1081,7 +1084,7 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc source.SetValue("MEMORY_NAME_B", memoryName(B)); source.SetValue("LEADING_DIMENSION_B", leadingDimension(B)); source.SetValue("LEADING_BLOCK_DIMENSION_B", std::to_string(leadingBlockDimension(B))); - source.SetValue("TRANSPOSED_B", std::to_string(transposed(B))); + source.SetValue("TRANSPOSED_B", transposed(B) ? "true" : "false"); source.SetValue("TRAVERSAL_OFFSET", traversalOffsetValue()); switch (descriptor.addressSpaceRHS.value) { case MTLAddressSpace::device: @@ -1127,7 +1130,7 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc source.SetValue("MEMORY_NAME_B", memoryName(B)); source.SetValue("LEADING_DIMENSION_B", leadingDimension(B)); source.SetValue("LEADING_BLOCK_DIMENSION_B", std::to_string(leadingBlockDimension(B))); - source.SetValue("TRANSPOSED_B", std::to_string(transposed(B))); + source.SetValue("TRANSPOSED_B", transposed(B) ? "true" : "false"); source.SetValue("TRAVERSAL_OFFSET", traversalOffsetValue()); source.SetValue("BLOCK_DIMENSIONS_HEAD", std::to_string(blockDimensions[2])); source.SetValue("HEAD_DIMENSION", std::to_string(headDimension)); @@ -1185,7 +1188,7 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc source.SetValue("B", B.name()); source.SetValue("C", C.name()); source.SetValue("LOAD_FUNCTION_B", loadFunction(B)); - source.SetValue("TRANSPOSED_B", std::to_string(transposed(B))); + source.SetValue("TRANSPOSED_B", transposed(B) ? "true" : "false"); source.SetValue("LEADING_DIMENSION_RHS", leadingDimensionRHS(descriptor)); source += R"( @@ -1416,7 +1419,7 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc CodeWriter source; source.SetValue("LOOP_END_FLOOR", std::to_string(loopEndFloor())); - source.SetValue("LOOP_END_FLOOR_LESS_LOOP_END", std::to_string(loopEndFloor() < loopEnd())); + source.SetValue("LOOP_END_FLOOR_LESS_LOOP_END", (loopEndFloor() < loopEnd()) ? 
"true" : "false"); source.SetValue("GATED_LOOP_ITERATION", gatedLoopIteration(descriptor)); source += R"( @@ -1465,7 +1468,7 @@ std::string AttentionKernel::cache(AttentionOperand operand, CachingOperationTyp source.SetValue("OPERAND", operand.name()); source.SetValue("LEADING_BLOCK_DIMENSION_OPERAND", std::to_string(leadingBlockDimension(operand))); source.SetValue("LEADING_DIMENSION_OPERAND", leadingDimension(operand)); - source.SetValue("TRANSPOSED_OPERAND", std::to_string(transposed(operand))); + source.SetValue("TRANSPOSED_OPERAND", transposed(operand) ? "true" : "false"); source.SetValue("HEAD_DIMENSION", std::to_string(headDimension)); source.SetValue("PADED_HEAD_DIMENSION", std::to_string(paddedHeadDimensionValue())); source.SetValue("BLOCK_DIMENSIONS_HEAD", std::to_string(blockDimensions[2])); @@ -1511,7 +1514,7 @@ std::string AttentionKernel::cache(AttentionOperand operand, CachingOperationTyp source.SetValue("OPERAND", operand.name()); source.SetValue("LEADING_BLOCK_DIMENSION_OPERAND", std::to_string(leadingBlockDimension(operand))); source.SetValue("LEADING_DIMENSION_OPERAND", leadingDimension(operand)); - source.SetValue("TRANSPOSED_OPERAND", std::to_string(transposed(operand))); + source.SetValue("TRANSPOSED_OPERAND", transposed(operand) ? "true" : "false"); source.SetValue("HEAD_DIMENSION", std::to_string(headDimension)); source.SetValue("BLOCK_DIMENSIONS_HEAD", std::to_string(blockDimensions[2])); source.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue()); @@ -1569,7 +1572,7 @@ std::string AttentionKernel::cache(AttentionOperand operand, CachingOperationTyp source.SetValue("MEMORY_NAME_OPERAND", memoryName(operand)); source.SetValue("OPERAND", operand.name()); source.SetValue("LEADING_DIMENSION_OPERAND", leadingDimension(operand)); - source.SetValue("TRANSPOSED_OPERAND", std::to_string(transposed(operand))); + source.SetValue("TRANSPOSED_OPERAND", transposed(operand) ? "true" : "false"); source.SetValue("CLAMPED_PARALLELIZATION_THREAD_OFFSET", clampedParallelizationThreadOffsetValue()); source += R"( @@ -1588,7 +1591,7 @@ std::string AttentionKernel::cache(AttentionOperand operand, CachingOperationTyp source.SetValue("MEMORY_NAME_OPERAND", memoryName(operand)); source.SetValue("OPERAND", operand.name()); source.SetValue("LEADING_BLOCK_DIMENSION_OPERAND", std::to_string(leadingBlockDimension(operand))); - source.SetValue("TRANSPOSED_OPERAND", std::to_string(transposed(operand))); + source.SetValue("TRANSPOSED_OPERAND", transposed(operand) ? "true" : "false"); source += R"( ushort2 {{OPERAND}}_block_offset( @@ -1617,7 +1620,7 @@ std::string AttentionKernel::cache(AttentionOperand operand, CachingOperationTyp source.SetValue("HEAD_END", std::to_string(headEnd)); source.SetValue("OPERAND", operand.name()); source.SetValue("LEADING_DIMENSION_OPERAND", leadingDimensionOperand(descriptor)); - source.SetValue("TRANSPOSED_OPERAND", std::to_string(transposed(operand))); + source.SetValue("TRANSPOSED_OPERAND", transposed(operand) ? 
"true" : "false"); if (type == CachingOperationType::load) { source.SetValue("LOAD_FUNCTION_OPERAND", loadFunction(operand)); source += R"( @@ -1673,7 +1676,7 @@ std::string AttentionKernel::cache(AttentionOperand operand, CachingOperationTyp if (descriptor.addressSpace == MTLAddressSpace::device) { CodeWriter source; source.SetValue("DECLARE_OPERAND_LOCATION", declareOperandLocation(descriptor)); - source.SetValue("TYPE_IS_LOAD", std::to_string(type == CachingOperationType::load)); + source.SetValue("TYPE_IS_LOAD", (type == CachingOperationType::load) ? "true" : "false"); source.SetValue("UNSAFE_PARALLELIZATION_THREAD_OFFSET", unsafeParallelizationThreadOffsetValue()); source.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue()); source.SetValue("INNER_LOOP_HEAD", innerLoopHead(0, blockDimensions[2], descriptor)); @@ -1721,7 +1724,7 @@ std::string AttentionKernel::cache(AttentionOperand operand, CachingOperationTyp descriptorDevice.addressSpace = MTLAddressSpace::device; descriptorThreadgroup.addressSpace = MTLAddressSpace::threadgroup; CodeWriter source; - source.SetValue("NOT_PREFER_ASYNC_CACHE", std::to_string(!preferAsyncCache)); + source.SetValue("NOT_PREFER_ASYNC_CACHE", !preferAsyncCache ? "true" : "false"); source.SetValue("HEAD_DIMENSION", std::to_string(headDimension)); source.SetValue("BLOCK_DIMENSIONS_HEAD", std::to_string(blockDimensions[2])); source.SetValue("LOOP_ITERATION_DEVICE", loopIteration(descriptorDevice)); @@ -2002,7 +2005,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& source.SetValue("A", A.name()); source.SetValue("MEMORY_NAME_A", memoryName(A)); source.SetValue("CLAMPED_PARALLELIZATION_THREAD_OFFSET", clampedParallelizationThreadOffsetValue()); - source.SetValue("TRANSPOSED_A", std::to_string(transposed(A))); + source.SetValue("TRANSPOSED_A", transposed(A) ? "true" : "false"); switch (descriptor.addressSpaceLHS.value) { case MTLAddressSpace::device: source.SetValue("LEADING_DIMENSION_A", leadingDimension(A)); @@ -2043,7 +2046,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& source.SetValue("A", A.name()); source.SetValue("MEMORY_NAME_A", memoryName(A)); source.SetValue("CLAMPED_PARALLELIZATION_THREAD_OFFSET", clampedParallelizationThreadOffsetValue()); - source.SetValue("TRANSPOSED_A", std::to_string(transposed(A))); + source.SetValue("TRANSPOSED_A", transposed(A) ? "true" : "false"); source.SetValue("LEADING_DIMENSION_A", leadingDimension(A)); source.SetValue("LEADING_BLOCK_DIMENSION_A", std::to_string(leadingBlockDimension(A))); source.SetValue("BLOCK_DIMENSIONS_HEAD", std::to_string(blockDimensions[2])); @@ -2093,7 +2096,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& source.SetValue("A", A.name()); source.SetValue("DESCRIPTOR_REGISTER_SIZE", std::to_string(descriptor.registerSize)); source.SetValue("LOAD_FUNCTION_A", loadFunction(A)); - source.SetValue("TRANSPOSED_A", std::to_string(transposed(A))); + source.SetValue("TRANSPOSED_A", transposed(A) ? 
"true" : "false"); source.SetValue("DECLARE_LHS_LOCATION", declareLHSLocation(descriptor)); switch (descriptor.addressSpaceLHS.value) { case MTLAddressSpace::device: @@ -2154,7 +2157,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& case MTLAddressSpace::device: source.SetValue("TRAVERSAL_OFFSET", traversalOffsetValue()); source.SetValue("LEADING_DIMENSION_B", leadingDimension(B)); - source.SetValue("TRANSPOSED_B", std::to_string(transposed(B))); + source.SetValue("TRANSPOSED_B", transposed(B) ? "true" : "false"); source += R"( uint2 {{B}}_src_offset( @@ -2169,7 +2172,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& break; case MTLAddressSpace::threadgroup: source.SetValue("LEADING_BLOCK_DIMENSION_B", std::to_string(leadingBlockDimension(B))); - source.SetValue("NOT_TRANSPOSED_B", std::to_string(!transposed(B))); + source.SetValue("NOT_TRANSPOSED_B", !transposed(B) ? "true" : "false"); source += R"( ushort2 {{B}}_block_offset( @@ -2204,7 +2207,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& source.SetValue("HEAD_DIMENSION", std::to_string(headDimension)); source.SetValue("PADDED_TRAVERSAL_EDGE", paddedTraversalEdgeValue()); source.SetValue("LEADING_DIMENSION_B", leadingDimension(B)); - source.SetValue("TRANSPOSED_B", std::to_string(transposed(B))); + source.SetValue("TRANSPOSED_B", transposed(B) ? "true" : "false"); source.SetValue("LEADING_BLOCK_DIMENSION_B", std::to_string(leadingBlockDimension(B))); source.SetValue("DESCRIPTOR_REGISTER_SIZE", std::to_string(descriptor.registerSize)); source.SetValue("DECLARE_RHS_LOCATION", declareRHSLocation(descriptor)); @@ -2261,7 +2264,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& source.SetValue("REGISTER_NAME_B", registerName(B)); source.SetValue("LOAD_FUNCTION_B", loadFunction(B)); source.SetValue("LEADING_DIMENSION_RHS", leadingDimensionRHS(descriptor)); - source.SetValue("NOT_TRANSPOSED_B", std::to_string(!transposed(B))); + source.SetValue("NOT_TRANSPOSED_B", !transposed(B) ? "true" : "false"); source.SetValue("DESCRIPTOR_REGISTER_OFFSET", descriptor.registerOffset); source.SetValue("DESCRIPTOR_ACCUMULATE_CONDITIONAL", descriptor.accumulateConditional); source += R"( @@ -2466,7 +2469,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& CodeWriter source; source.SetValue("LOOP_END_FLOOR", std::to_string(loopEndFloor())); - source.SetValue("LOOP_END_FLOOR_LESS_LOOP_END", std::to_string(loopEndFloor() < loopEnd())); + source.SetValue("LOOP_END_FLOOR_LESS_LOOP_END", (loopEndFloor() < loopEnd()) ? 
"true" : "false"); source.SetValue("GATED_LOOP_ITERATION", gatedLoopIteration(descriptor)); source += R"( @@ -2485,14 +2488,19 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& // MARK: - AttentionKernel+Softmax -static std::string dotProductScale(bool derivative, unsigned short headDimension) { +static std::string high_precision_to_string(float value) { + std::ostringstream oss; + oss << std::setprecision(std::numeric_limits::max_digits10) << value; + return oss.str(); +} + +static std::string dotProductScale(float rsqrtD, bool derivative, unsigned short headDimension) { float logBase2E = 1.442695041; - float rsqrtD = 1 / sqrt((float)headDimension); if (!derivative) { - return std::to_string(logBase2E * rsqrtD); + return high_precision_to_string(logBase2E * rsqrtD); } else { - return std::to_string(rsqrtD); + return high_precision_to_string(rsqrtD); } } std::string AttentionKernel::computeD() const noexcept { @@ -2509,7 +2517,7 @@ std::string AttentionKernel::computeD() const noexcept { CodeWriter source; source.SetValue("MEMORY_NAME_DO", memoryName(AttentionOperand::dO)); source.SetValue("LEADING_DIMENSION_DO", leadingDimension(AttentionOperand::dO)); - source.SetValue("TRANSPOSED_DO", std::to_string(transposed(AttentionOperand::dO))); + source.SetValue("TRANSPOSED_DO", transposed(AttentionOperand::dO) ? "true" : "false"); source += R"( // Where the dO data will be read from. @@ -2536,7 +2544,7 @@ std::string AttentionKernel::computeD() const noexcept { source.SetValue("REGISTER_NAME_DO", registerName(AttentionOperand::dO)); source.SetValue("LOAD_FUNCTION_DO", loadFunction(AttentionOperand::dO)); source.SetValue("LEADING_DIMENSION_DO", leadingDimension(AttentionOperand::dO)); - source.SetValue("TRANSPOSED_DO", std::to_string(transposed(AttentionOperand::dO))); + source.SetValue("TRANSPOSED_DO", transposed(AttentionOperand::dO) ? "true" : "false"); source += R"( simdgroup_matrix_storage<{{REGISTER_NAME_DO}}> dO; @@ -2557,7 +2565,7 @@ std::string AttentionKernel::computeD() const noexcept { source.SetValue("REGISTER_NAME_O", registerName(AttentionOperand::O)); source.SetValue("LOAD_FUNCTION_O", loadFunction(AttentionOperand::O)); source.SetValue("LEADING_DIMENSION_O", leadingDimension(AttentionOperand::O)); - source.SetValue("TRANSPOSED_O", std::to_string(transposed(AttentionOperand::O))); + source.SetValue("TRANSPOSED_O", transposed(AttentionOperand::O) ? "true" : "false"); source.SetValue("TRUNCATED_HEAD_DIMENSION", std::to_string(truncatedHeadDimension)); source += R"( @@ -2628,13 +2636,13 @@ std::string AttentionKernel::computeD() const noexcept { source.SetValue("LOAD_FUNCTION_DO", registerName(AttentionOperand::dO)); source.SetValue("LEADING_DIMENSION_DO", leadingDimension(AttentionOperand::dO)); source.SetValue("LEADING_BLOCK_DIMENSION_DO", std::to_string(leadingBlockDimension(AttentionOperand::dO))); - source.SetValue("TRANSPOSED_DO", std::to_string(transposed(AttentionOperand::dO))); + source.SetValue("TRANSPOSED_DO", transposed(AttentionOperand::dO) ? 
"true" : "false"); source.SetValue("MEMORY_NAME_O", memoryName(AttentionOperand::O)); source.SetValue("REGISTER_NAME_O", registerName(AttentionOperand::O)); source.SetValue("LOAD_FUNCTION_O", registerName(AttentionOperand::O)); source.SetValue("LEADING_DIMENSION_O", leadingDimension(AttentionOperand::O)); source.SetValue("LEADING_BLOCK_DIMENSION_O", std::to_string(leadingBlockDimension(AttentionOperand::O))); - source.SetValue("TRANSPOSED_O", std::to_string(transposed(AttentionOperand::O))); + source.SetValue("TRANSPOSED_O", transposed(AttentionOperand::O) ? "true" : "false"); source.SetValue("BLOCK_BYTES_DERIVATIVE_O", std::to_string(blockBytesDerivativeO())); source.SetValue("BLOCK_DIMENSIONS_PARALLELIZATION", std::to_string(blockDimensions[0])); source.SetValue("HEAD_DIMENSION", std::to_string(headDimension)); @@ -2721,7 +2729,7 @@ std::string AttentionKernel::computeD() const noexcept { CodeWriter source; source.SetValue("BULK_CONTRIBUTIONS", bulkContributions(loopEndFloor)); source.SetValue("EDGE_CONTRIBUTIONS", edgeContributions(loopEndFloor)); - source.SetValue("DOT_PRODUCT_SCALE", dotProductScale(true, headDimension)); + source.SetValue("DOT_PRODUCT_SCALE", dotProductScale(scale, true, headDimension)); source += R"( float2 D_accumulator(0); @@ -2788,7 +2796,7 @@ std::string AttentionKernel::onlineReduceMaximum() const noexcept { CodeWriter source; source.SetValue("REGISTER_NAME_S", registerName(AttentionOperand::S)); source.SetValue("BLOCK_DIMENSIONS_TRAVERSAL", std::to_string(blockDimensions[1])); - source.SetValue("DOT_PRODUCT_SCALE", dotProductScale(false, headDimension)); + source.SetValue("DOT_PRODUCT_SCALE", dotProductScale(scale, false, headDimension)); source += R"( // update 'm' @@ -2944,8 +2952,7 @@ std::string AttentionKernel::softmax(bool derivative) const noexcept { auto overwriteAttentionMatrixElements = [=]() -> std::string { CodeWriter source; - auto scale = dotProductScale(derivative, headDimension); - source.SetValue("SCALE", scale); + source.SetValue("SCALE", dotProductScale(scale, derivative, headDimension)); if (!derivative) { source.SetValue("REGISTER_NAME_P", registerName(AttentionOperand::P)); @@ -3042,7 +3049,7 @@ std::string AttentionKernel::softmax(bool derivative) const noexcept { case AttentionKernelType::backwardKeyValue: auto blockDim = blockDimensions[1]; source.SetValue("BLOCK_DIM", std::to_string(blockDim)); - source.SetValue("NOT_PREFER_ASYNC_LOAD", std::to_string(!preferAsyncLoad)); + source.SetValue("NOT_PREFER_ASYNC_LOAD", !preferAsyncLoad ? "true" : "false"); source.SetValue("TRAVERSAL_DIMENSION", traversalDimensionValue()); source.SetValue("TRAVERSAL_OFFSET", traversalOffsetValue()); source.SetValue("LOAD_OPERAND", loadOperand()); diff --git a/lib/nnc/mfa/v2/AttentionKernel.hpp b/lib/nnc/mfa/v2/AttentionKernel.hpp index c4f227798..20e0998e9 100644 --- a/lib/nnc/mfa/v2/AttentionKernel.hpp +++ b/lib/nnc/mfa/v2/AttentionKernel.hpp @@ -17,6 +17,8 @@ struct AttentionKernel { AttentionKernelType type; + float scale; + AttentionOperands cacheState; AttentionOperands memoryPrecisions; @@ -36,6 +38,9 @@ struct AttentionKernel { unsigned short threadgroupMemoryAllocation; + /// The number of threads per group. 
+ uint16_t threadgroupSize; + AttentionKernel(AttentionKernelDescriptor descriptor, MTL::Device *const device); private: diff --git a/lib/nnc/mfa/v2/AttentionKernelDescriptor.cpp b/lib/nnc/mfa/v2/AttentionKernelDescriptor.cpp index e9f1efd6d..f81291858 100644 --- a/lib/nnc/mfa/v2/AttentionKernelDescriptor.cpp +++ b/lib/nnc/mfa/v2/AttentionKernelDescriptor.cpp @@ -28,7 +28,7 @@ std::size_t std::hash::operator()(const AttentionKern // MARK: - Initializer -AttentionKernelDescriptor::AttentionKernelDescriptor(simd::ushort3 blockDimensions, AttentionOperands cacheState, unsigned short headDimension, AttentionOperands memoryPrecisions, bool preferAsyncCache, bool preferAsyncLoad, AttentionOperands registerPrecisions, AttentionOperands transposeState, AttentionKernelType type) noexcept { +AttentionKernelDescriptor::AttentionKernelDescriptor(simd::ushort3 blockDimensions, AttentionOperands cacheState, unsigned short headDimension, AttentionOperands memoryPrecisions, bool preferAsyncCache, bool preferAsyncLoad, AttentionOperands registerPrecisions, AttentionOperands transposeState, AttentionKernelType type, float scale) noexcept { this->blockDimensions = blockDimensions; this->cacheState = cacheState; this->headDimension = headDimension; @@ -38,4 +38,5 @@ AttentionKernelDescriptor::AttentionKernelDescriptor(simd::ushort3 blockDimensio this->registerPrecisions = registerPrecisions; this->transposeState = transposeState; this->type = type; + this->scale = scale; } diff --git a/lib/nnc/mfa/v2/AttentionKernelDescriptor.hpp b/lib/nnc/mfa/v2/AttentionKernelDescriptor.hpp index 87bc0e8cc..63989c7f4 100644 --- a/lib/nnc/mfa/v2/AttentionKernelDescriptor.hpp +++ b/lib/nnc/mfa/v2/AttentionKernelDescriptor.hpp @@ -48,12 +48,14 @@ struct AttentionKernelDescriptor { AttentionKernelType type; + float scale; + // MARK: - Functionality from AttentionDescriptor AttentionKernelDescriptor() = delete; /// Initialize the kernel descriptor. - AttentionKernelDescriptor(simd::ushort3 blockDimensions, AttentionOperands cacheState, unsigned short headDimension, AttentionOperands memoryPrecisions, bool preferAsyncCache, bool preferAsyncLoad, AttentionOperands registerPrecisions, AttentionOperands transposeState, AttentionKernelType type) noexcept; + AttentionKernelDescriptor(simd::ushort3 blockDimensions, AttentionOperands cacheState, unsigned short headDimension, AttentionOperands memoryPrecisions, bool preferAsyncCache, bool preferAsyncLoad, AttentionOperands registerPrecisions, AttentionOperands transposeState, AttentionKernelType type, float scale) noexcept; bool operator==(const AttentionKernelDescriptor& rhs) const; };
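
Note (not part of the patch): the change above routes the runtime attention scale into dotProductScale and formats it with max_digits10 instead of std::to_string, so the constant baked into the generated Metal source keeps full float precision. The standalone sketch below illustrates that formatting difference; it mirrors the high_precision_to_string helper added by the patch (with std::numeric_limits<float> read from context) and uses a head dimension of 80 purely as an illustrative value.

// precision_sketch.cpp -- standalone illustration, not part of the patch.
#include <cmath>
#include <iomanip>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>

// Mirrors the helper the patch adds to AttentionKernel.cpp: print a float
// with enough significant digits (max_digits10 == 9) to round-trip exactly.
static std::string high_precision_to_string(float value) {
  std::ostringstream oss;
  oss << std::setprecision(std::numeric_limits<float>::max_digits10) << value;
  return oss.str();
}

int main() {
  const unsigned short headDimension = 80;  // illustrative head size
  const float rsqrtD = 1.0f / std::sqrt((float)headDimension);
  const float logBase2E = 1.442695041f;     // same constant as dotProductScale

  // std::to_string keeps only six fractional digits, so the softmax scale
  // emitted into the shader would be rounded; max_digits10 does not lose bits.
  std::cout << "to_string:    " << std::to_string(logBase2E * rsqrtD) << "\n";
  std::cout << "max_digits10: " << high_precision_to_string(logBase2E * rsqrtD) << "\n";
  return 0;
}

The nine-significant-digit form round-trips the float exactly, so the constant the Metal compiler sees is the same value the host code computed.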