diff --git a/lib/nnc/mfa/v2/AttentionKernel.cpp b/lib/nnc/mfa/v2/AttentionKernel.cpp
index 3dcc8aab7..ee7273d17 100644
--- a/lib/nnc/mfa/v2/AttentionKernel.cpp
+++ b/lib/nnc/mfa/v2/AttentionKernel.cpp
@@ -499,6 +499,29 @@ std::string AttentionKernel::createBufferBindings() const noexcept {
   return output;
 }
 
+std::string AttentionKernel::operandLocationWithHeadOffsetValue(AttentionOperand operand) const noexcept {
+  CodeWriter source;
+  source.SetValue("OPERAND", operand.name());
+  if (operand.value == AttentionOperand::L || operand.value == AttentionOperand::D) {
+    source += "{{OPERAND}} + (gid.z * Hq + gid.y) * R\\";
+  } else if (Hq > 1) {
+    source.SetValue("HEAD_DIMENSION", std::to_string(headDimension));
+    if (!transposed(operand)) {
+      source += "{{OPERAND}} + gid.z * {{OPERAND}}_batch_stride + gid.y * {{HEAD_DIMENSION}}\\";
+    } else {
+      source.SetValue("SEQUENCE_LENGTH", sequenceLength(operand));
+      source += "{{OPERAND}} + gid.z * {{OPERAND}}_batch_stride + gid.y * {{HEAD_DIMENSION}} * {{SEQUENCE_LENGTH}}\\";
+    }
+  } else {
+    source += "{{OPERAND}} + gid.z * {{OPERAND}}_batch_stride\\";
+  }
+  return source.ToString();
+}
+
+std::string AttentionKernel::operandLocationValue(AttentionOperand operand) const noexcept {
+  return operand.name();
+}
+
 std::string AttentionKernel::createAdjustOffsets() const noexcept {
   std::vector<AttentionOperand> operands;
   switch (type.value) {
@@ -515,20 +538,10 @@ std::string AttentionKernel::createAdjustOffsets() const noexcept {
   CodeWriter source;
   for (const auto& operand : operands) {
     source.SetValue("OPERAND", operand.name());
-    if (operand.value == AttentionOperand::L || operand.value == AttentionOperand::D) {
+    source.SetValue("OPERAND_LOCATION", operandLocationWithHeadOffsetValue(operand));
     source += R"(
-  {{OPERAND}} = {{OPERAND}} + (gid.z * Hq + gid.y) * R;
+  {{OPERAND}} = {{OPERAND_LOCATION}};
 )";
-    } else {
-      if (!transposed(operand)) {
-        source.SetValue("HEAD_DIMENSION", std::to_string(headDimension));
-      } else {
-        source.SetValue("HEAD_DIMENSION", "1");
-      }
-      source += R"(
-  {{OPERAND}} = {{OPERAND}} + gid.z * {{OPERAND}}_batch_stride + gid.y * {{HEAD_DIMENSION}};
-)";
-    }
   }
   return source.ToString();
 }
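
[Editorial sketch, not part of the patch: the snippet below restates the branch logic of the new operandLocationWithHeadOffsetValue() helper using plain std::string concatenation in place of CodeWriter, to show the three addressing shapes it can emit. The trailing "\\" the real helper appends (a line splice for the newline CodeWriter adds) is omitted, and the sample call at the bottom uses hypothetical values.]

#include <string>

std::string locationExpr(const std::string& name, bool isLorD, bool isTransposed,
                         unsigned Hq, unsigned headDim, const std::string& seqLen) {
  if (isLorD) {
    // L and D hold one value per row: flatten (batch, head) into a row offset.
    return name + " + (gid.z * Hq + gid.y) * R";
  } else if (Hq > 1) {
    // Multi-head matrix operand: batch stride, plus a per-head step that is
    // scaled by the sequence length when the operand is stored transposed.
    std::string expr = name + " + gid.z * " + name + "_batch_stride + gid.y * "
      + std::to_string(headDim);
    if (isTransposed) {
      expr += " * " + seqLen;
    }
    return expr;
  } else {
    // Single head: only the batch offset applies.
    return name + " + gid.z * " + name + "_batch_stride";
  }
}

// locationExpr("Q", false, false, 8, 64, "R")
//   -> "Q + gid.z * Q_batch_stride + gid.y * 64"
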
source.SetValue("BLOCK_DIMENSIONS_PARALLELIZATION", std::to_string(blockDimensions[0])); - source.SetValue("BLOCK_DIMENSIONS_HEAD", std::to_string(blockDimensions[2])); - source.SetValue("HEAD_DIMENSION", std::to_string(headDimension)); + source.SetValue("C_LOCATION", operandLocationValue(C)); + source.SetValue("PARALLELIZATION_GROUP_OFFSET", parallelizationGroupOffsetValue()); + source.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue()); + source.SetValue("BLOCK_DIMENSIONS_PARALLELIZATION", std::to_string(blockDimensions[0])); + source.SetValue("BLOCK_DIMENSIONS_HEAD", std::to_string(blockDimensions[2])); + source.SetValue("HEAD_DIMENSION", std::to_string(headDimension)); source.SetValue("MEMORY_NAME_C", memoryName(C)); source.SetValue("LEADING_DIMENSION_C", leadingDimension(C)); source.SetValue("LEADING_BLOCK_DIMENSION_C", std::to_string(leadingBlockDimension(C))); @@ -924,7 +939,7 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc uint2 {{C}}_offset(d_outer, {{PARALLELIZATION_GROUP_OFFSET}}); auto src = simdgroup_matrix_storage<{{MEMORY_NAME_C}}> ::apply_offset( - {{C}}, {{LEADING_DIMENSION_C}}, + {{C_LOCATION}}, {{LEADING_DIMENSION_C}}, {{C}}_offset, {{TRANSPOSED_C}}); auto dst = (threadgroup {{MEMORY_NAME_C}}*)(threadgroup_block); @@ -951,11 +966,12 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc [=]() -> std::string { CodeWriter source; source.SetValue("C", C.name()); - source.SetValue("PARALLELIZATION_GROUP_OFFSET", parallelizationGroupOffsetValue()); - source.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue()); - source.SetValue("BLOCK_DIMENSIONS_PARALLELIZATION", std::to_string(blockDimensions[0])); - source.SetValue("BLOCK_DIMENSIONS_HEAD", std::to_string(blockDimensions[2])); - source.SetValue("HEAD_DIMENSION", std::to_string(headDimension)); + source.SetValue("C_LOCATION", operandLocationValue(C)); + source.SetValue("PARALLELIZATION_GROUP_OFFSET", parallelizationGroupOffsetValue()); + source.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue()); + source.SetValue("BLOCK_DIMENSIONS_PARALLELIZATION", std::to_string(blockDimensions[0])); + source.SetValue("BLOCK_DIMENSIONS_HEAD", std::to_string(blockDimensions[2])); + source.SetValue("HEAD_DIMENSION", std::to_string(headDimension)); source.SetValue("MEMORY_NAME_C", memoryName(C)); source.SetValue("LEADING_DIMENSION_C", leadingDimension(C)); source.SetValue("LEADING_BLOCK_DIMENSION_C", std::to_string(leadingBlockDimension(C))); @@ -968,7 +984,7 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc auto src = (threadgroup {{MEMORY_NAME_C}}*)(threadgroup_block); auto dst = simdgroup_matrix_storage<{{MEMORY_NAME_C}}> ::apply_offset( - {{C}}, {{LEADING_DIMENSION_C}}, + {{C_LOCATION}}, {{LEADING_DIMENSION_C}}, {{C}}_offset, {{TRANSPOSED_C}}); ushort D_dimension = min( @@ -1106,6 +1122,7 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc [=](LoopIterationDescriptor descriptor) -> std::string { CodeWriter source; source.SetValue("B", B.name()); + source.SetValue("B_LOCATION", operandLocationValue(B)); source.SetValue("MEMORY_NAME_B", memoryName(B)); source.SetValue("LEADING_DIMENSION_B", leadingDimension(B)); source.SetValue("LEADING_BLOCK_DIMENSION_B", std::to_string(leadingBlockDimension(B))); @@ -1120,7 +1137,7 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc morton_offset.y + {{TRAVERSAL_OFFSET}}); auto {{B}}_src = 
       simdgroup_matrix_storage<{{MEMORY_NAME_B}}>
       ::apply_offset(
-        {{B}}, {{LEADING_DIMENSION_B}},
+        {{B_LOCATION}}, {{LEADING_DIMENSION_B}},
         {{B}}_src_offset, {{TRANSPOSED_B}});
 
 )";
@@ -1152,6 +1169,7 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc
     case MTLAddressSpace::threadgroup:
       CodeWriter source;
       source.SetValue("B", B.name());
+      source.SetValue("B_LOCATION", operandLocationValue(B));
       source.SetValue("MEMORY_NAME_B", memoryName(B));
       source.SetValue("LEADING_DIMENSION_B", leadingDimension(B));
       source.SetValue("LEADING_BLOCK_DIMENSION_B", std::to_string(leadingBlockDimension(B)));
@@ -1170,7 +1188,7 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc
       uint2 {{B}}_offset(d_outer, {{TRAVERSAL_OFFSET}});
       auto src = simdgroup_matrix_storage<{{MEMORY_NAME_B}}>
       ::apply_offset(
-        {{B}}, {{LEADING_DIMENSION_B}},
+        {{B_LOCATION}}, {{LEADING_DIMENSION_B}},
         {{B}}_offset, {{TRANSPOSED_B}});
       auto dst = (threadgroup {{MEMORY_NAME_B}}*)(threadgroup_block);
@@ -1262,7 +1280,7 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc
     if (descriptor.addressSpaceLHS == MTLAddressSpace::device ||
         descriptor.addressSpaceRHS == MTLAddressSpace::device) {
       auto blockDim = blockDimensions[1];
-      source.SetValue("INNER_LOOP_TRAVERSAL", innerLoopTraversal("0", std::to_string(blockDim), descriptor));
+      source.SetValue("INNER_LOOP_TRAVERSAL", innerLoopTraversal("0", std::to_string(blockDim), descriptor));
       source.SetValue("BLOCK_DIM", std::to_string(blockDim));
       source.SetValue("TRAVERSAL_OFFSET", traversalOffsetValue());
       source.SetValue("TRAVERSAL_DIMENSION", traversalDimensionValue());
@@ -1279,8 +1297,8 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc
 )";
 
     } else {
-      source.SetValue("INNER_LOOP_TRAVERSAL_0", innerLoopTraversal("0", paddedTraversalEdgeValue(), descriptor));
-      source.SetValue("INNER_LOOP_TRAVERSAL_1", innerLoopTraversal(paddedTraversalEdgeValue(), std::to_string(blockDimensions[1]), descriptor));
+      source.SetValue("INNER_LOOP_TRAVERSAL_0", innerLoopTraversal("0", paddedTraversalEdgeValue(), descriptor));
+      source.SetValue("INNER_LOOP_TRAVERSAL_1", innerLoopTraversal(paddedTraversalEdgeValue(), std::to_string(blockDimensions[1]), descriptor));
       source.SetValue("BLOCK_DIMENSIONS_TRAVERSAL", std::to_string(blockDimensions[1]));
       source.SetValue("TRAVERSAL_OFFSET", traversalOffsetValue());
       source.SetValue("TRAVERSAL_DIMENSION", traversalDimensionValue());
@@ -1307,10 +1325,10 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc
     if (cached(C)) {
       source.SetValue("LOAD_ACCUMULATOR", "");
       source.SetValue("STORE_ACCUMULATOR", "");
-    } else {
+    } else {
       source.SetValue("LOAD_ACCUMULATOR", loadAccumulator(descriptor));
       source.SetValue("STORE_ACCUMULATOR", storeAccumulator(descriptor));
-    }
+    }
     source.SetValue("LOAD_RHS", loadRHS(descriptor));
     source.SetValue("MULTIPLY_AB", multiplyAB());
     source.SetValue("SCALE_ACCUMULATOR", scaleAccumulator(accumulateDesc.everyIterationScale, descriptor));
@@ -1325,7 +1343,7 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc
     }
     {{LOAD_RHS}}
     {{MULTIPLY_AB}}
-    {{STORE_ACCUMULATOR}}
+    {{STORE_ACCUMULATOR}}
 
 )";
     return source.ToString();
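
[Editorial sketch, not part of the patch: in this commit operandLocationValue() simply returns operand.name(), so every {{C_LOCATION}}/{{B_LOCATION}} substitution above renders exactly the same Metal as the {{C}}/{{B}} it replaced; the head/batch offset is still applied once, up front, by createAdjustOffsets(). The value of the refactor is the seam: all apply_offset() call sites now route through one helper, so per-head addressing could later be emitted inline without touching the templates again. The check below states that invariant; the harness function is hypothetical.]

#include <cassert>

// Invariant this refactor relies on: location == name, so the generated
// kernels are unchanged by the template edits above.
void checkLocationMatchesName(const AttentionKernel& kernel, AttentionOperand operand) {
  assert(kernel.operandLocationValue(operand) == operand.name());
}
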
source.SetValue("OPERAND", operand.name()); - source.SetValue("PADDED_HEAD_DIMENSION_8", std::to_string(paddedHeadDimensionValue() / 8)); + source.SetValue("REGISTER_NAME_OPERAND", registerName(operand)); + source.SetValue("OPERAND", operand.name()); + source.SetValue("PADDED_HEAD_DIMENSION_8", std::to_string(paddedHeadDimensionValue() / 8)); source += R"( simdgroup_matrix_storage<{{REGISTER_NAME_OPERAND}}> {{OPERAND}}_sram[{{PADDED_HEAD_DIMENSION_8}}]; @@ -1489,17 +1507,18 @@ std::string AttentionKernel::cache(AttentionOperand operand, CachingOperationTyp [=]() -> std::string { if (type == CachingOperationType::load) { CodeWriter source; - source.SetValue("MEMORY_NAME_OPERAND", memoryName(operand)); - source.SetValue("OPERAND", operand.name()); - source.SetValue("LEADING_BLOCK_DIMENSION_OPERAND", std::to_string(leadingBlockDimension(operand))); - source.SetValue("LEADING_DIMENSION_OPERAND", leadingDimension(operand)); - source.SetValue("TRANSPOSED_OPERAND", transposed(operand) ? "true" : "false"); - source.SetValue("HEAD_DIMENSION", std::to_string(headDimension)); - source.SetValue("PADDED_HEAD_DIMENSION", std::to_string(paddedHeadDimensionValue())); - source.SetValue("BLOCK_DIMENSIONS_HEAD", std::to_string(blockDimensions[2])); - source.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue()); - source.SetValue("PARALLELIZATION_GROUP_OFFSET", parallelizationGroupOffsetValue()); - source.SetValue("BLOCK_DIMENSIONS_PARALLELIZATION", std::to_string(blockDimensions[0])); + source.SetValue("MEMORY_NAME_OPERAND", memoryName(operand)); + source.SetValue("OPERAND", operand.name()); + source.SetValue("OPERAND_LOCATION", operandLocationValue(operand)); + source.SetValue("LEADING_BLOCK_DIMENSION_OPERAND", std::to_string(leadingBlockDimension(operand))); + source.SetValue("LEADING_DIMENSION_OPERAND", leadingDimension(operand)); + source.SetValue("TRANSPOSED_OPERAND", transposed(operand) ? "true" : "false"); + source.SetValue("HEAD_DIMENSION", std::to_string(headDimension)); + source.SetValue("PADDED_HEAD_DIMENSION", std::to_string(paddedHeadDimensionValue())); + source.SetValue("BLOCK_DIMENSIONS_HEAD", std::to_string(blockDimensions[2])); + source.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue()); + source.SetValue("PARALLELIZATION_GROUP_OFFSET", parallelizationGroupOffsetValue()); + source.SetValue("BLOCK_DIMENSIONS_PARALLELIZATION", std::to_string(blockDimensions[0])); source += R"( threadgroup_barrier(mem_flags::mem_threadgroup); @@ -1507,7 +1526,7 @@ std::string AttentionKernel::cache(AttentionOperand operand, CachingOperationTyp uint2 {{OPERAND}}_offset(d_outer, {{PARALLELIZATION_GROUP_OFFSET}}); auto src = simdgroup_matrix_storage<{{MEMORY_NAME_OPERAND}}> ::apply_offset( - {{OPERAND}}, {{LEADING_DIMENSION_OPERAND}}, + {{OPERAND_LOCATION}}, {{LEADING_DIMENSION_OPERAND}}, {{OPERAND}}_offset, {{TRANSPOSED_OPERAND}}); auto dst = (threadgroup {{MEMORY_NAME_OPERAND}}*)(threadgroup_block); @@ -1535,16 +1554,17 @@ std::string AttentionKernel::cache(AttentionOperand operand, CachingOperationTyp return source.ToString(); } else { CodeWriter source; - source.SetValue("MEMORY_NAME_OPERAND", memoryName(operand)); - source.SetValue("OPERAND", operand.name()); - source.SetValue("LEADING_BLOCK_DIMENSION_OPERAND", std::to_string(leadingBlockDimension(operand))); - source.SetValue("LEADING_DIMENSION_OPERAND", leadingDimension(operand)); - source.SetValue("TRANSPOSED_OPERAND", transposed(operand) ? 
"true" : "false"); - source.SetValue("HEAD_DIMENSION", std::to_string(headDimension)); - source.SetValue("BLOCK_DIMENSIONS_HEAD", std::to_string(blockDimensions[2])); - source.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue()); - source.SetValue("PARALLELIZATION_GROUP_OFFSET", parallelizationGroupOffsetValue()); - source.SetValue("BLOCK_DIMENSIONS_PARALLELIZATION", std::to_string(blockDimensions[0])); + source.SetValue("MEMORY_NAME_OPERAND", memoryName(operand)); + source.SetValue("OPERAND", operand.name()); + source.SetValue("OPERAND_LOCATION", operandLocationValue(operand)); + source.SetValue("LEADING_BLOCK_DIMENSION_OPERAND", std::to_string(leadingBlockDimension(operand))); + source.SetValue("LEADING_DIMENSION_OPERAND", leadingDimension(operand)); + source.SetValue("TRANSPOSED_OPERAND", transposed(operand) ? "true" : "false"); + source.SetValue("HEAD_DIMENSION", std::to_string(headDimension)); + source.SetValue("BLOCK_DIMENSIONS_HEAD", std::to_string(blockDimensions[2])); + source.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue()); + source.SetValue("PARALLELIZATION_GROUP_OFFSET", parallelizationGroupOffsetValue()); + source.SetValue("BLOCK_DIMENSIONS_PARALLELIZATION", std::to_string(blockDimensions[0])); source += R"( threadgroup_barrier(mem_flags::mem_threadgroup); @@ -1553,7 +1573,7 @@ std::string AttentionKernel::cache(AttentionOperand operand, CachingOperationTyp auto src = (threadgroup {{MEMORY_NAME_OPERAND}}*)(threadgroup_block); auto dst = simdgroup_matrix_storage<{{MEMORY_NAME_OPERAND}}> ::apply_offset( - {{OPERAND}}, {{LEADING_DIMENSION_OPERAND}}, + {{OPERAND_LOCATION}}, {{LEADING_DIMENSION_OPERAND}}, {{OPERAND}}_offset, {{TRANSPOSED_OPERAND}}); ushort D_dimension = min( @@ -1594,11 +1614,12 @@ std::string AttentionKernel::cache(AttentionOperand operand, CachingOperationTyp [=](LoopIterationDescriptor descriptor) -> std::string { if (descriptor.addressSpace == MTLAddressSpace::device) { CodeWriter source; - source.SetValue("MEMORY_NAME_OPERAND", memoryName(operand)); - source.SetValue("OPERAND", operand.name()); - source.SetValue("LEADING_DIMENSION_OPERAND", leadingDimension(operand)); - source.SetValue("TRANSPOSED_OPERAND", transposed(operand) ? "true" : "false"); - source.SetValue("CLAMPED_PARALLELIZATION_THREAD_OFFSET", clampedParallelizationThreadOffsetValue()); + source.SetValue("MEMORY_NAME_OPERAND", memoryName(operand)); + source.SetValue("OPERAND", operand.name()); + source.SetValue("OPERAND_LOCATION", operandLocationValue(operand)); + source.SetValue("LEADING_DIMENSION_OPERAND", leadingDimension(operand)); + source.SetValue("TRANSPOSED_OPERAND", transposed(operand) ? 
"true" : "false"); + source.SetValue("CLAMPED_PARALLELIZATION_THREAD_OFFSET", clampedParallelizationThreadOffsetValue()); source += R"( uint2 {{OPERAND}}_src_offset( @@ -1606,17 +1627,17 @@ std::string AttentionKernel::cache(AttentionOperand operand, CachingOperationTyp {{CLAMPED_PARALLELIZATION_THREAD_OFFSET}}); auto {{OPERAND}}_src = simdgroup_matrix_storage<{{MEMORY_NAME_OPERAND}}> ::apply_offset( - {{OPERAND}}, {{LEADING_DIMENSION_OPERAND}}, + {{OPERAND_LOCATION}}, {{LEADING_DIMENSION_OPERAND}}, {{OPERAND}}_src_offset, {{TRANSPOSED_OPERAND}}); )"; return source.ToString(); } else { CodeWriter source; - source.SetValue("MEMORY_NAME_OPERAND", memoryName(operand)); - source.SetValue("OPERAND", operand.name()); - source.SetValue("LEADING_BLOCK_DIMENSION_OPERAND", std::to_string(leadingBlockDimension(operand))); - source.SetValue("TRANSPOSED_OPERAND", transposed(operand) ? "true" : "false"); + source.SetValue("MEMORY_NAME_OPERAND", memoryName(operand)); + source.SetValue("OPERAND", operand.name()); + source.SetValue("LEADING_BLOCK_DIMENSION_OPERAND", std::to_string(leadingBlockDimension(operand))); + source.SetValue("TRANSPOSED_OPERAND", transposed(operand) ? "true" : "false"); source += R"( ushort2 {{OPERAND}}_block_offset( @@ -1673,7 +1694,7 @@ std::string AttentionKernel::cache(AttentionOperand operand, CachingOperationTyp )"; } - return source.ToString(); + return source.ToString(); }; // MARK: - Outer Loop @@ -1738,7 +1759,7 @@ std::string AttentionKernel::cache(AttentionOperand operand, CachingOperationTyp {{STORE_OPERAND}} )"; - return source.ToString(); + return source.ToString(); } }; @@ -1749,11 +1770,11 @@ std::string AttentionKernel::cache(AttentionOperand operand, CachingOperationTyp descriptorDevice.addressSpace = MTLAddressSpace::device; descriptorThreadgroup.addressSpace = MTLAddressSpace::threadgroup; CodeWriter source; - source.SetValue("NOT_PREFER_ASYNC_CACHE", !preferAsyncCache ? "true" : "false"); - source.SetValue("HEAD_DIMENSION", std::to_string(headDimension)); - source.SetValue("BLOCK_DIMENSIONS_HEAD", std::to_string(blockDimensions[2])); - source.SetValue("LOOP_ITERATION_DEVICE", loopIteration(descriptorDevice)); - source.SetValue("LOOP_ITERATION_THREADGROUP", loopIteration(descriptorThreadgroup)); + source.SetValue("NOT_PREFER_ASYNC_CACHE", !preferAsyncCache ? "true" : "false"); + source.SetValue("HEAD_DIMENSION", std::to_string(headDimension)); + source.SetValue("BLOCK_DIMENSIONS_HEAD", std::to_string(blockDimensions[2])); + source.SetValue("LOOP_ITERATION_DEVICE", loopIteration(descriptorDevice)); + source.SetValue("LOOP_ITERATION_THREADGROUP", loopIteration(descriptorThreadgroup)); source += R"( @@ -1796,9 +1817,9 @@ std::string AttentionKernel::createSetup() const noexcept { auto allocate = [=](AttentionOperand operand) -> std::string { CodeWriter source; - source.SetValue("REGISTER_NAME_OPERAND", registerName(operand)); - source.SetValue("OPERAND", operand.name()); - source.SetValue("PADDED_HEAD_DIMENSION_8", std::to_string(paddedHeadDimensionValue() / 8)); + source.SetValue("REGISTER_NAME_OPERAND", registerName(operand)); + source.SetValue("OPERAND", operand.name()); + source.SetValue("PADDED_HEAD_DIMENSION_8", std::to_string(paddedHeadDimensionValue() / 8)); source += R"( simdgroup_matrix_storage<{{REGISTER_NAME_OPERAND}}> {{OPERAND}}_sram[{{PADDED_HEAD_DIMENSION_8}}]; @@ -1843,8 +1864,8 @@ std::string AttentionKernel::createSetup() const noexcept { // L is always either FP16 or FP32, so we don't need custom type // conversion code here. 
- output.SetValue("CLAMPED_PARALLELIZATION_THREAD_OFFSET", clampedParallelizationThreadOffsetValue()); - output.SetValue("COMPUTE_D", computeD()); + output.SetValue("CLAMPED_PARALLELIZATION_THREAD_OFFSET", clampedParallelizationThreadOffsetValue()); + output.SetValue("COMPUTE_D", computeD()); output += R"( float L_sram = L[{{CLAMPED_PARALLELIZATION_THREAD_OFFSET}}]; @@ -1884,15 +1905,16 @@ std::string AttentionKernel::createCleanup(const AttentionKernelType type) const // L is always either FP16 or FP32, so we don't need custom type // conversion code here. - output.SetValue("UNSAFE_PARALLELIZATION_THREAD_OFFSET", unsafeParallelizationThreadOffsetValue()); - output.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue()); - output.SetValue("CLAMPED_PARALLELIZATION_THREAD_OFFSET", clampedParallelizationThreadOffsetValue()); + output.SetValue("L_LOCATION", operandLocationValue(AttentionOperand::L)); + output.SetValue("UNSAFE_PARALLELIZATION_THREAD_OFFSET", unsafeParallelizationThreadOffsetValue()); + output.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue()); + output.SetValue("CLAMPED_PARALLELIZATION_THREAD_OFFSET", clampedParallelizationThreadOffsetValue()); output += R"( if ({{UNSAFE_PARALLELIZATION_THREAD_OFFSET}} < {{PARALLELIZATION_DIMENSION}}) { // Premultiplied by log_base_2(e). float L_sram = m + fast::log2(l); - L[{{CLAMPED_PARALLELIZATION_THREAD_OFFSET}}] = L_sram; + ({{L_LOCATION}})[{{CLAMPED_PARALLELIZATION_THREAD_OFFSET}}] = L_sram; } )"; @@ -1906,12 +1928,13 @@ std::string AttentionKernel::createCleanup(const AttentionKernelType type) const auto storeD = [=]() -> std::string { CodeWriter source; - source.SetValue("CLAMPED_PARALLELIZATION_THREAD_OFFSET", clampedParallelizationThreadOffsetValue()); + source.SetValue("D_LOCATION", operandLocationValue(AttentionOperand::D)); + source.SetValue("CLAMPED_PARALLELIZATION_THREAD_OFFSET", clampedParallelizationThreadOffsetValue()); switch (memoryPrecisions[AttentionOperand::D].value().value) { case GEMMOperandPrecision::FP32: source += R"( - D[{{CLAMPED_PARALLELIZATION_THREAD_OFFSET}}] = D_sram; + ({{D_LOCATION}})[{{CLAMPED_PARALLELIZATION_THREAD_OFFSET}}] = D_sram; )"; break; @@ -1920,19 +1943,19 @@ std::string AttentionKernel::createCleanup(const AttentionKernelType type) const bfloat2 registerForm = *(thread bfloat2*)(&D_sram); bfloat memoryForm = registerForm[1]; - D[{{CLAMPED_PARALLELIZATION_THREAD_OFFSET}}] = memoryForm; + ({{D_LOCATION}})[{{CLAMPED_PARALLELIZATION_THREAD_OFFSET}}] = memoryForm; )"; break; default: - CCV_NNC_MFA_PRECONDITION(false); + CCV_NNC_MFA_PRECONDITION(false); break; } return source.ToString(); }; - output.SetValue("UNSAFE_PARALLELIZATION_THREAD_OFFSET", unsafeParallelizationThreadOffsetValue()); - output.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue()); - output.SetValue("STORE_D", storeD()); + output.SetValue("UNSAFE_PARALLELIZATION_THREAD_OFFSET", unsafeParallelizationThreadOffsetValue()); + output.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue()); + output.SetValue("STORE_D", storeD()); output += R"( if ({{UNSAFE_PARALLELIZATION_THREAD_OFFSET}} < {{PARALLELIZATION_DIMENSION}}) { @@ -1967,9 +1990,9 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& auto allocateAccumulator = [=]() -> std::string { CodeWriter source; - source.SetValue("C", C.name()); - source.SetValue("REGISTER_NAME_C", registerName(C)); - source.SetValue("BLOCK_DIMENSIONS_TRAVERSAL", std::to_string(blockDimensions[1])); + 
source.SetValue("C", C.name()); + source.SetValue("REGISTER_NAME_C", registerName(C)); + source.SetValue("BLOCK_DIMENSIONS_TRAVERSAL", std::to_string(blockDimensions[1])); source += R"( simdgroup_matrix_storage<{{REGISTER_NAME_C}}> {{C}}_sram[{{BLOCK_DIMENSIONS_TRAVERSAL}} / 8]; @@ -1981,9 +2004,9 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& auto initializeAccumulator = [=]() -> std::string { CodeWriter source; - source.SetValue("C", C.name()); - source.SetValue("REGISTER_NAME_C", registerName(C)); - source.SetValue("BLOCK_DIMENSIONS_TRAVERSAL", std::to_string(blockDimensions[1])); + source.SetValue("C", C.name()); + source.SetValue("REGISTER_NAME_C", registerName(C)); + source.SetValue("BLOCK_DIMENSIONS_TRAVERSAL", std::to_string(blockDimensions[1])); source += R"( #pragma clang loop unroll(full) @@ -1993,7 +2016,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& } )"; - return source.ToString(); + return source.ToString(); }; struct LoopIterationDescriptor { @@ -2011,15 +2034,15 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& return ""; } CodeWriter source; - source.SetValue("A", A.name()); - source.SetValue("REGISTER_NAME_A", registerName(A)); - source.SetValue("DESCRIPTOR_REGISTER_SIZE", std::to_string(descriptor.registerSize)); + source.SetValue("A", A.name()); + source.SetValue("REGISTER_NAME_A", registerName(A)); + source.SetValue("DESCRIPTOR_REGISTER_SIZE", std::to_string(descriptor.registerSize)); source += R"( simdgroup_matrix_storage<{{REGISTER_NAME_A}}> {{A}}_sram[{{DESCRIPTOR_REGISTER_SIZE}} / 8]; )"; - return source.ToString(); + return source.ToString(); }; // MARK: - Load LHS @@ -2027,13 +2050,14 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& auto declareLHSLocation = [=](LoopIterationDescriptor descriptor) -> std::string { CodeWriter source; - source.SetValue("A", A.name()); - source.SetValue("MEMORY_NAME_A", memoryName(A)); - source.SetValue("CLAMPED_PARALLELIZATION_THREAD_OFFSET", clampedParallelizationThreadOffsetValue()); - source.SetValue("TRANSPOSED_A", transposed(A) ? "true" : "false"); + source.SetValue("A", A.name()); + source.SetValue("MEMORY_NAME_A", memoryName(A)); + source.SetValue("CLAMPED_PARALLELIZATION_THREAD_OFFSET", clampedParallelizationThreadOffsetValue()); + source.SetValue("TRANSPOSED_A", transposed(A) ? 
"true" : "false"); switch (descriptor.addressSpaceLHS.value) { case MTLAddressSpace::device: - source.SetValue("LEADING_DIMENSION_A", leadingDimension(A)); + source.SetValue("LEADING_DIMENSION_A", leadingDimension(A)); + source.SetValue("A_LOCATION", operandLocationValue(A)); source += R"( uint2 {{A}}_src_offset( @@ -2041,13 +2065,13 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& {{CLAMPED_PARALLELIZATION_THREAD_OFFSET}}); auto {{A}}_src = simdgroup_matrix_storage<{{MEMORY_NAME_A}}> ::apply_offset( - {{A}}, {{LEADING_DIMENSION_A}}, + {{A_LOCATION}}, {{LEADING_DIMENSION_A}}, {{A}}_src_offset, {{TRANSPOSED_A}}); )"; return source.ToString(); case MTLAddressSpace::threadgroup: - source.SetValue("LEADING_BLOCK_DIMENSION_A", std::to_string(leadingBlockDimension(A))); + source.SetValue("LEADING_BLOCK_DIMENSION_A", std::to_string(leadingBlockDimension(A))); source += R"( ushort2 {{A}}_block_offset( @@ -2068,18 +2092,19 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& auto asyncLoadLHS = [=](LoopIterationDescriptor descriptor) -> std::string { CodeWriter source; - source.SetValue("A", A.name()); - source.SetValue("MEMORY_NAME_A", memoryName(A)); - source.SetValue("CLAMPED_PARALLELIZATION_THREAD_OFFSET", clampedParallelizationThreadOffsetValue()); - source.SetValue("TRANSPOSED_A", transposed(A) ? "true" : "false"); + source.SetValue("A", A.name()); + source.SetValue("A_LOCATION", operandLocationValue(A)); + source.SetValue("MEMORY_NAME_A", memoryName(A)); + source.SetValue("CLAMPED_PARALLELIZATION_THREAD_OFFSET", clampedParallelizationThreadOffsetValue()); + source.SetValue("TRANSPOSED_A", transposed(A) ? "true" : "false"); source.SetValue("LEADING_DIMENSION_A", leadingDimension(A)); - source.SetValue("LEADING_BLOCK_DIMENSION_A", std::to_string(leadingBlockDimension(A))); - source.SetValue("BLOCK_DIMENSIONS_HEAD", std::to_string(blockDimensions[2])); - source.SetValue("HEAD_DIMENSION", std::to_string(headDimension)); - source.SetValue("DESCRIPTOR_REGISTER_SIZE", std::to_string(descriptor.registerSize)); - source.SetValue("BLOCK_DIMENSIONS_PARALLELIZATION", std::to_string(blockDimensions[0])); - source.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue()); - source.SetValue("PARALLELIZATION_GROUP_OFFSET", parallelizationGroupOffsetValue()); + source.SetValue("LEADING_BLOCK_DIMENSION_A", std::to_string(leadingBlockDimension(A))); + source.SetValue("BLOCK_DIMENSIONS_HEAD", std::to_string(blockDimensions[2])); + source.SetValue("HEAD_DIMENSION", std::to_string(headDimension)); + source.SetValue("DESCRIPTOR_REGISTER_SIZE", std::to_string(descriptor.registerSize)); + source.SetValue("BLOCK_DIMENSIONS_PARALLELIZATION", std::to_string(blockDimensions[0])); + source.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue()); + source.SetValue("PARALLELIZATION_GROUP_OFFSET", parallelizationGroupOffsetValue()); source += R"( threadgroup_barrier(mem_flags::mem_threadgroup); @@ -2087,7 +2112,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& uint2 {{A}}_offset(d_outer, {{PARALLELIZATION_GROUP_OFFSET}}); auto src = simdgroup_matrix_storage<{{MEMORY_NAME_A}}> ::apply_offset( - {{A}}, {{LEADING_DIMENSION_A}}, + {{A_LOCATION}}, {{LEADING_DIMENSION_A}}, {{A}}_offset, {{TRANSPOSED_A}}); auto dst = (threadgroup {{MEMORY_NAME_A}}*)(threadgroup_block); @@ -2109,7 +2134,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& } )"; - return source.ToString(); 
+    return source.ToString();
   };
 
   auto loadLHS =
@@ -2175,14 +2200,15 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor&
   auto declareRHSLocation =
   [=](LoopIterationDescriptor descriptor) -> std::string {
-    CodeWriter source;
-    source.SetValue("B", B.name());
-    source.SetValue("MEMORY_NAME_B", memoryName(B));
+    CodeWriter source;
+    source.SetValue("B", B.name());
+    source.SetValue("MEMORY_NAME_B", memoryName(B));
     switch (descriptor.addressSpaceRHS.value) {
     case MTLAddressSpace::device:
-      source.SetValue("TRAVERSAL_OFFSET", traversalOffsetValue());
-      source.SetValue("LEADING_DIMENSION_B", leadingDimension(B));
-      source.SetValue("TRANSPOSED_B", transposed(B) ? "true" : "false");
+      source.SetValue("TRAVERSAL_OFFSET", traversalOffsetValue());
+      source.SetValue("LEADING_DIMENSION_B", leadingDimension(B));
+      source.SetValue("TRANSPOSED_B", transposed(B) ? "true" : "false");
+      source.SetValue("B_LOCATION", operandLocationValue(B));
       source += R"(
 
       uint2 {{B}}_src_offset(
@@ -2190,14 +2216,14 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor&
         morton_offset.x + {{TRAVERSAL_OFFSET}});
       auto {{B}}_src = simdgroup_matrix_storage<{{MEMORY_NAME_B}}>
       ::apply_offset(
-        {{B}}, {{LEADING_DIMENSION_B}},
+        {{B_LOCATION}}, {{LEADING_DIMENSION_B}},
         {{B}}_src_offset, {{TRANSPOSED_B}});
 
 )";
       break;
     case MTLAddressSpace::threadgroup:
-      source.SetValue("LEADING_BLOCK_DIMENSION_B", std::to_string(leadingBlockDimension(B)));
-      source.SetValue("NOT_TRANSPOSED_B", !transposed(B) ? "true" : "false");
+      source.SetValue("LEADING_BLOCK_DIMENSION_B", std::to_string(leadingBlockDimension(B)));
+      source.SetValue("NOT_TRANSPOSED_B", !transposed(B) ? "true" : "false");
       source += R"(
 
       ushort2 {{B}}_block_offset(
@@ -2213,7 +2239,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor&
 )";
       break;
     }
-    return source.ToString();
+    return source.ToString();
   };
 
   auto loadRHS =
@@ -2223,27 +2249,28 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor&
     case MTLAddressSpace::device:
       return declareRHSLocation(descriptor);
     case MTLAddressSpace::threadgroup:
-      source.SetValue("B", B.name());
-      source.SetValue("MEMORY_NAME_B", memoryName(B));
-      source.SetValue("TRAVERSAL_OFFSET", traversalOffsetValue());
-      source.SetValue("TRAVERSAL_DIMENSION", traversalDimensionValue());
-      source.SetValue("BLOCK_DIMENSIONS_TRAVERSAL", std::to_string(blockDimensions[1]));
-      source.SetValue("BLOCK_DIMENSIONS_HEAD", std::to_string(blockDimensions[2]));
-      source.SetValue("HEAD_DIMENSION", std::to_string(headDimension));
-      source.SetValue("PADDED_TRAVERSAL_EDGE", paddedTraversalEdgeValue());
-      source.SetValue("LEADING_DIMENSION_B", leadingDimension(B));
-      source.SetValue("TRANSPOSED_B", transposed(B) ? "true" : "false");
"true" : "false"); - source.SetValue("LEADING_BLOCK_DIMENSION_B", std::to_string(leadingBlockDimension(B))); - source.SetValue("DESCRIPTOR_REGISTER_SIZE", std::to_string(descriptor.registerSize)); - source.SetValue("DECLARE_RHS_LOCATION", declareRHSLocation(descriptor)); - source += R"( + source.SetValue("B", B.name()); + source.SetValue("B_LOCATION", operandLocationValue(B)); + source.SetValue("MEMORY_NAME_B", memoryName(B)); + source.SetValue("TRAVERSAL_OFFSET", traversalOffsetValue()); + source.SetValue("TRAVERSAL_DIMENSION", traversalDimensionValue()); + source.SetValue("BLOCK_DIMENSIONS_TRAVERSAL", std::to_string(blockDimensions[1])); + source.SetValue("BLOCK_DIMENSIONS_HEAD", std::to_string(blockDimensions[2])); + source.SetValue("HEAD_DIMENSION", std::to_string(headDimension)); + source.SetValue("PADDED_TRAVERSAL_EDGE", paddedTraversalEdgeValue()); + source.SetValue("LEADING_DIMENSION_B", leadingDimension(B)); + source.SetValue("TRANSPOSED_B", transposed(B) ? "true" : "false"); + source.SetValue("LEADING_BLOCK_DIMENSION_B", std::to_string(leadingBlockDimension(B))); + source.SetValue("DESCRIPTOR_REGISTER_SIZE", std::to_string(descriptor.registerSize)); + source.SetValue("DECLARE_RHS_LOCATION", declareRHSLocation(descriptor)); + source += R"( threadgroup_barrier(mem_flags::mem_threadgroup); if (sidx == 0) { uint2 {{B}}_offset(d_outer, {{TRAVERSAL_OFFSET}}); auto src = simdgroup_matrix_storage<{{MEMORY_NAME_B}}> ::apply_offset( - {{B}}, {{LEADING_DIMENSION_B}}, + {{B_LOCATION}}, {{LEADING_DIMENSION_B}}, {{B}}_offset, {{TRANSPOSED_B}}); auto dst = (threadgroup {{MEMORY_NAME_B}}*)(threadgroup_block); @@ -2270,9 +2297,9 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& {{DECLARE_RHS_LOCATION}} )"; - break; + break; } - return source.ToString(); + return source.ToString(); }; @@ -2281,18 +2308,18 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& auto innerLoopTraversal = [=](std::string traversalStart, std::string traversalEnd, LoopIterationDescriptor descriptor) -> std::string { CodeWriter source; - source.SetValue("TRAVERSAL_START", traversalStart); - source.SetValue("TRAVERSAL_END", traversalEnd); - source.SetValue("A", A.name()); - source.SetValue("B", B.name()); - source.SetValue("C", C.name()); - source.SetValue("REGISTER_NAME_B", registerName(B)); - source.SetValue("LOAD_FUNCTION_B", loadFunction(B)); - source.SetValue("LEADING_DIMENSION_RHS", leadingDimensionRHS(descriptor)); - source.SetValue("NOT_TRANSPOSED_B", !transposed(B) ? "true" : "false"); - source.SetValue("DESCRIPTOR_REGISTER_OFFSET", descriptor.registerOffset); - source.SetValue("DESCRIPTOR_ACCUMULATE_CONDITIONAL", descriptor.accumulateConditional); - source += R"( + source.SetValue("TRAVERSAL_START", traversalStart); + source.SetValue("TRAVERSAL_END", traversalEnd); + source.SetValue("A", A.name()); + source.SetValue("B", B.name()); + source.SetValue("C", C.name()); + source.SetValue("REGISTER_NAME_B", registerName(B)); + source.SetValue("LOAD_FUNCTION_B", loadFunction(B)); + source.SetValue("LEADING_DIMENSION_RHS", leadingDimensionRHS(descriptor)); + source.SetValue("NOT_TRANSPOSED_B", !transposed(B) ? 
"true" : "false"); + source.SetValue("DESCRIPTOR_REGISTER_OFFSET", descriptor.registerOffset); + source.SetValue("DESCRIPTOR_ACCUMULATE_CONDITIONAL", descriptor.accumulateConditional); + source += R"( #pragma clang loop unroll(full) for (ushort c = {{TRAVERSAL_START}}; c < {{TRAVERSAL_END}}; c += 8) { @@ -2310,7 +2337,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& } )"; - return source.ToString(); + return source.ToString(); }; auto innerLoopHead = @@ -2319,7 +2346,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& source.SetValue("DESCRIPTOR_REGISTER_SIZE", std::to_string(descriptor.registerSize)); if (descriptor.addressSpaceLHS == MTLAddressSpace::device || descriptor.addressSpaceRHS == MTLAddressSpace::device) { - source.SetValue("INNER_LOOP_TRAVERSAL", innerLoopTraversal("0", std::to_string(blockDimensions[1]), descriptor)); + source.SetValue("INNER_LOOP_TRAVERSAL", innerLoopTraversal("0", std::to_string(blockDimensions[1]), descriptor)); source += R"( #pragma clang loop unroll(full) @@ -2329,11 +2356,11 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& )"; } else { - source.SetValue("INNER_LOOP_TRAVERSAL_0", innerLoopTraversal("0", paddedTraversalEdgeValue(), descriptor)); - source.SetValue("INNER_LOOP_TRAVERSAL_1", innerLoopTraversal(paddedTraversalEdgeValue(), std::to_string(blockDimensions[1]), descriptor)); - source.SetValue("TRAVERSAL_OFFSET", traversalOffsetValue()); - source.SetValue("TRAVERSAL_DIMENSION", traversalDimensionValue()); - source.SetValue("BLOCK_DIMENSIONS_TRAVERSAL", std::to_string(blockDimensions[1])); + source.SetValue("INNER_LOOP_TRAVERSAL_0", innerLoopTraversal("0", paddedTraversalEdgeValue(), descriptor)); + source.SetValue("INNER_LOOP_TRAVERSAL_1", innerLoopTraversal(paddedTraversalEdgeValue(), std::to_string(blockDimensions[1]), descriptor)); + source.SetValue("TRAVERSAL_OFFSET", traversalOffsetValue()); + source.SetValue("TRAVERSAL_DIMENSION", traversalDimensionValue()); + source.SetValue("BLOCK_DIMENSIONS_TRAVERSAL", std::to_string(blockDimensions[1])); source += R"( #pragma clang loop unroll(full) @@ -2347,7 +2374,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& )"; } - return source.ToString(); + return source.ToString(); }; // MARK: - Outer Loop @@ -2379,14 +2406,14 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor& } auto blockDim = blockDimensions[1]; - CodeWriter source; - source.SetValue("BLOCK_DIM", std::to_string(blockDim)); - source.SetValue("TRAVERSAL_DIMENSION", traversalDimensionValue()); - source.SetValue("TRAVERSAL_OFFSET", traversalOffsetValue()); - source.SetValue("HEAD_DIMENSION", std::to_string(headDimension)); - source.SetValue("DESCRIPTOR_REGISTER_SIZE", std::to_string(descriptor.registerSize)); - source.SetValue("LOOP_ITERATION_DEVICE", loopIteration(descriptorDevice)); - source.SetValue("LOOP_ITERATION_THREADGROUP", loopIteration(descriptorThreadgroup)); + CodeWriter source; + source.SetValue("BLOCK_DIM", std::to_string(blockDim)); + source.SetValue("TRAVERSAL_DIMENSION", traversalDimensionValue()); + source.SetValue("TRAVERSAL_OFFSET", traversalOffsetValue()); + source.SetValue("HEAD_DIMENSION", std::to_string(headDimension)); + source.SetValue("DESCRIPTOR_REGISTER_SIZE", std::to_string(descriptor.registerSize)); + source.SetValue("LOOP_ITERATION_DEVICE", loopIteration(descriptorDevice)); + source.SetValue("LOOP_ITERATION_THREADGROUP", 
     source += R"(
 
@@ -2403,7 +2430,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor&
   }
 
 )";
-    return source.ToString();
+    return source.ToString();
   };
 
   // MARK: - Top Level Specification
@@ -2466,10 +2493,10 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor&
     descriptor.registerSize = blockDimensions[2];
 
     CodeWriter source;
-    source.SetValue("UNROLL_STATEMENT", unrollStatement());
-    source.SetValue("LOOP_END_FLOOR", std::to_string(loopEndFloor()));
-    source.SetValue("BLOCK_DIMENSIONS_HEAD", std::to_string(blockDimensions[2]));
-    source.SetValue("GATED_LOOP_ITERATION", gatedLoopIteration(descriptor));
+    source.SetValue("UNROLL_STATEMENT", unrollStatement());
+    source.SetValue("LOOP_END_FLOOR", std::to_string(loopEndFloor()));
+    source.SetValue("BLOCK_DIMENSIONS_HEAD", std::to_string(blockDimensions[2]));
+    source.SetValue("GATED_LOOP_ITERATION", gatedLoopIteration(descriptor));
     source += R"(
 
     {{UNROLL_STATEMENT}}
@@ -2482,7 +2509,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor&
     }
 
 )";
-    return source.ToString();
+    return source.ToString();
   };
 
   auto lastIteration =
@@ -2493,9 +2520,9 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor&
     descriptor.registerSize = paddedHeadEdgeValue();
 
     CodeWriter source;
-    source.SetValue("LOOP_END_FLOOR", std::to_string(loopEndFloor()));
-    source.SetValue("LOOP_END_FLOOR_LESS_LOOP_END", (loopEndFloor() < loopEnd()) ? "true" : "false");
-    source.SetValue("GATED_LOOP_ITERATION", gatedLoopIteration(descriptor));
+    source.SetValue("LOOP_END_FLOOR", std::to_string(loopEndFloor()));
+    source.SetValue("LOOP_END_FLOOR_LESS_LOOP_END", (loopEndFloor() < loopEnd()) ? "true" : "false");
+    source.SetValue("GATED_LOOP_ITERATION", gatedLoopIteration(descriptor));
     source += R"(
 
     if ({{LOOP_END_FLOOR_LESS_LOOP_END}}) {
@@ -2504,7 +2531,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor&
     }
 
 )";
-    return source.ToString();
+    return source.ToString();
   };
 
   // Collect all of the statements into one string.
@@ -2539,20 +2566,21 @@ std::string AttentionKernel::computeD() const noexcept {
   if (cached(AttentionOperand::dO)) {
     return "";
   } else {
-    CodeWriter source;
-    source.SetValue("MEMORY_NAME_DO", memoryName(AttentionOperand::dO));
-    source.SetValue("LEADING_DIMENSION_DO", leadingDimension(AttentionOperand::dO));
-    source.SetValue("TRANSPOSED_DO", transposed(AttentionOperand::dO) ? "true" : "false");
+    CodeWriter source;
+    source.SetValue("DO_LOCATION", operandLocationValue(AttentionOperand::dO));
+    source.SetValue("MEMORY_NAME_DO", memoryName(AttentionOperand::dO));
+    source.SetValue("LEADING_DIMENSION_DO", leadingDimension(AttentionOperand::dO));
+    source.SetValue("TRANSPOSED_DO", transposed(AttentionOperand::dO) ? "true" : "false");
     source += R"(
 
     // Where the dO data will be read from.
     auto dO_src = simdgroup_matrix_storage<{{MEMORY_NAME_DO}}>
     ::apply_offset(
-      dO, {{LEADING_DIMENSION_DO}},
+      {{DO_LOCATION}}, {{LEADING_DIMENSION_DO}},
       offset_src, {{TRANSPOSED_DO}});
 
 )";
-    return source.ToString();
+    return source.ToString();
   }
 };
@@ -2565,11 +2593,11 @@ std::string AttentionKernel::computeD() const noexcept {
 )";
   } else {
-    CodeWriter source;
-    source.SetValue("REGISTER_NAME_DO", registerName(AttentionOperand::dO));
-    source.SetValue("LOAD_FUNCTION_DO", loadFunction(AttentionOperand::dO));
-    source.SetValue("LEADING_DIMENSION_DO", leadingDimension(AttentionOperand::dO));
-    source.SetValue("TRANSPOSED_DO", transposed(AttentionOperand::dO) ? "true" : "false");
+    CodeWriter source;
+    source.SetValue("REGISTER_NAME_DO", registerName(AttentionOperand::dO));
+    source.SetValue("LOAD_FUNCTION_DO", loadFunction(AttentionOperand::dO));
+    source.SetValue("LEADING_DIMENSION_DO", leadingDimension(AttentionOperand::dO));
+    source.SetValue("TRANSPOSED_DO", transposed(AttentionOperand::dO) ? "true" : "false");
     source += R"(
 
     simdgroup_matrix_storage<{{REGISTER_NAME_DO}}> dO;
@@ -2578,20 +2606,21 @@ std::string AttentionKernel::computeD() const noexcept {
       ushort2(d, 0), {{TRANSPOSED_DO}});
 
 )";
-    return source.ToString();
+    return source.ToString();
   }
 };
 
-  CodeWriter source;
-  source.SetValue("CLAMPED_PARALLELIZATION_THREAD_OFFSET", clampedParallelizationThreadOffsetValue());
-  source.SetValue("DECLARE_DERIVATIVE_O_LOCATION", declareDerivativeOLocation());
-  source.SetValue("LOAD_DERIVATIVE_O", loadDerivativeO());
-  source.SetValue("MEMORY_NAME_O", memoryName(AttentionOperand::O));
-  source.SetValue("REGISTER_NAME_O", registerName(AttentionOperand::O));
-  source.SetValue("LOAD_FUNCTION_O", loadFunction(AttentionOperand::O));
-  source.SetValue("LEADING_DIMENSION_O", leadingDimension(AttentionOperand::O));
-  source.SetValue("TRANSPOSED_O", transposed(AttentionOperand::O) ? "true" : "false");
-  source.SetValue("TRUNCATED_HEAD_DIMENSION", std::to_string(truncatedHeadDimension));
+  CodeWriter source;
+  source.SetValue("CLAMPED_PARALLELIZATION_THREAD_OFFSET", clampedParallelizationThreadOffsetValue());
+  source.SetValue("DECLARE_DERIVATIVE_O_LOCATION", declareDerivativeOLocation());
+  source.SetValue("LOAD_DERIVATIVE_O", loadDerivativeO());
+  source.SetValue("O_LOCATION", operandLocationValue(AttentionOperand::O));
+  source.SetValue("MEMORY_NAME_O", memoryName(AttentionOperand::O));
+  source.SetValue("REGISTER_NAME_O", registerName(AttentionOperand::O));
+  source.SetValue("LOAD_FUNCTION_O", loadFunction(AttentionOperand::O));
+  source.SetValue("LEADING_DIMENSION_O", leadingDimension(AttentionOperand::O));
+  source.SetValue("TRANSPOSED_O", transposed(AttentionOperand::O) ? "true" : "false");
+  source.SetValue("TRUNCATED_HEAD_DIMENSION", std::to_string(truncatedHeadDimension));
   source += R"(
 
   // Threads outside of the matrix along the row dimension,
@@ -2605,7 +2634,7 @@ std::string AttentionKernel::computeD() const noexcept {
     // Where the O data will be read from.
     auto O_src = simdgroup_matrix_storage<{{MEMORY_NAME_O}}>
     ::apply_offset(
-      O, {{LEADING_DIMENSION_O}},
+      {{O_LOCATION}}, {{LEADING_DIMENSION_O}},
      offset_src, {{TRANSPOSED_O}});
 
     // Going to use async copy to handle the matrix edge.
@@ -2624,7 +2653,7 @@ std::string AttentionKernel::computeD() const noexcept {
     D_accumulator += float2(dO_value) * float2(O_value);
   }
 )";
-  return source.ToString();
+  return source.ToString();
 };
 
 // Parts of the dO * O reduction that fall on an indivisible edge.
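
[Editorial sketch, not part of the patch: the computeD() templates above and below accumulate D = rowwise dot(dO, O); in the real kernel each thread holds two lanes at a time (float2). The plain C++ below collapses that to one scalar row so the reduction is easy to see; array sizes are hypothetical.]

#include <cstddef>

// D for one row: the dot product of that row of dO with the same row of O.
float computeDRow(const float* dO_row, const float* O_row, std::size_t headDim) {
  float acc = 0.0f;
  for (std::size_t d = 0; d < headDim; ++d) {
    acc += dO_row[d] * O_row[d];  // mirrors: D_accumulator += float2(dO_value) * float2(O_value)
  }
  return acc;
}
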
@@ -2653,26 +2682,28 @@ std::string AttentionKernel::computeD() const noexcept {
     return blockDimensions[0] * 8 * size;
   };
 
-  CodeWriter source;
-  source.SetValue("TRUNCATED_HEAD_DIMENSION", std::to_string(truncatedHeadDimension));
-  source.SetValue("PARALLELIZATION_GROUP_OFFSET", parallelizationGroupOffsetValue());
-  source.SetValue("MEMORY_NAME_DO", memoryName(AttentionOperand::dO));
-  source.SetValue("REGISTER_NAME_DO", registerName(AttentionOperand::dO));
-  source.SetValue("LOAD_FUNCTION_DO", registerName(AttentionOperand::dO));
-  source.SetValue("LEADING_DIMENSION_DO", leadingDimension(AttentionOperand::dO));
-  source.SetValue("LEADING_BLOCK_DIMENSION_DO", std::to_string(leadingBlockDimension(AttentionOperand::dO)));
-  source.SetValue("TRANSPOSED_DO", transposed(AttentionOperand::dO) ? "true" : "false");
-  source.SetValue("MEMORY_NAME_O", memoryName(AttentionOperand::O));
-  source.SetValue("REGISTER_NAME_O", registerName(AttentionOperand::O));
-  source.SetValue("LOAD_FUNCTION_O", registerName(AttentionOperand::O));
-  source.SetValue("LEADING_DIMENSION_O", leadingDimension(AttentionOperand::O));
-  source.SetValue("LEADING_BLOCK_DIMENSION_O", std::to_string(leadingBlockDimension(AttentionOperand::O)));
-  source.SetValue("TRANSPOSED_O", transposed(AttentionOperand::O) ? "true" : "false");
-  source.SetValue("BLOCK_BYTES_DERIVATIVE_O", std::to_string(blockBytesDerivativeO()));
-  source.SetValue("BLOCK_DIMENSIONS_PARALLELIZATION", std::to_string(blockDimensions[0]));
-  source.SetValue("HEAD_DIMENSION", std::to_string(headDimension));
-  source.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue());
-  source.SetValue("PARALLELIZATION_GROUP_OFFSET", parallelizationGroupOffsetValue());
+  CodeWriter source;
+  source.SetValue("TRUNCATED_HEAD_DIMENSION", std::to_string(truncatedHeadDimension));
+  source.SetValue("PARALLELIZATION_GROUP_OFFSET", parallelizationGroupOffsetValue());
+  source.SetValue("DO_LOCATION", operandLocationValue(AttentionOperand::dO));
+  source.SetValue("MEMORY_NAME_DO", memoryName(AttentionOperand::dO));
+  source.SetValue("REGISTER_NAME_DO", registerName(AttentionOperand::dO));
+  source.SetValue("LOAD_FUNCTION_DO", loadFunction(AttentionOperand::dO));
+  source.SetValue("LEADING_DIMENSION_DO", leadingDimension(AttentionOperand::dO));
+  source.SetValue("LEADING_BLOCK_DIMENSION_DO", std::to_string(leadingBlockDimension(AttentionOperand::dO)));
+  source.SetValue("TRANSPOSED_DO", transposed(AttentionOperand::dO) ? "true" : "false");
+  source.SetValue("O_LOCATION", operandLocationValue(AttentionOperand::O));
+  source.SetValue("MEMORY_NAME_O", memoryName(AttentionOperand::O));
+  source.SetValue("REGISTER_NAME_O", registerName(AttentionOperand::O));
+  source.SetValue("LOAD_FUNCTION_O", loadFunction(AttentionOperand::O));
+  source.SetValue("LEADING_DIMENSION_O", leadingDimension(AttentionOperand::O));
+  source.SetValue("LEADING_BLOCK_DIMENSION_O", std::to_string(leadingBlockDimension(AttentionOperand::O)));
+  source.SetValue("TRANSPOSED_O", transposed(AttentionOperand::O) ? "true" : "false");
"true" : "false"); + source.SetValue("BLOCK_BYTES_DERIVATIVE_O", std::to_string(blockBytesDerivativeO())); + source.SetValue("BLOCK_DIMENSIONS_PARALLELIZATION", std::to_string(blockDimensions[0])); + source.SetValue("HEAD_DIMENSION", std::to_string(headDimension)); + source.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue()); + source.SetValue("PARALLELIZATION_GROUP_OFFSET", parallelizationGroupOffsetValue()); source += R"( threadgroup_barrier(mem_flags::mem_threadgroup); @@ -2683,11 +2714,11 @@ std::string AttentionKernel::computeD() const noexcept { auto dO_src = simdgroup_matrix_storage<{{MEMORY_NAME_DO}}> ::apply_offset( - dO, {{LEADING_DIMENSION_DO}}, + {{DO_LOCATION}}, {{LEADING_DIMENSION_DO}}, offset_src, {{TRANSPOSED_DO}}); auto O_src = simdgroup_matrix_storage<{{MEMORY_NAME_O}}> ::apply_offset( - O, {{LEADING_DIMENSION_O}}, + {{O_LOCATION}}, {{LEADING_DIMENSION_O}}, offset_src, {{TRANSPOSED_O}}); auto dO_dst = (threadgroup{{MEMORY_NAME_DO}})*)(threadgroup_block); @@ -2745,8 +2776,8 @@ std::string AttentionKernel::computeD() const noexcept { auto O_value = *(O.thread_elements()); D_accumulator += float2(dO_value) * float2(O_value); - )"; - return source.ToString(); + )"; + return source.ToString(); }; // Outer loop over the head dimension. @@ -2888,40 +2919,41 @@ std::string AttentionKernel::softmax(bool derivative) const noexcept { auto allocateOutput = [=]() -> std::string { auto blockDim = blockDimensions[1]; - CodeWriter source; - source.SetValue("BLOCK_DIM", std::to_string(blockDim)); + CodeWriter source; + source.SetValue("BLOCK_DIM", std::to_string(blockDim)); if (!derivative) { - source.SetValue("REGISTER_NAME_P", registerName(AttentionOperand::P)); + source.SetValue("REGISTER_NAME_P", registerName(AttentionOperand::P)); source += R"( simdgroup_matrix_storage<{{REGISTER_NAME_P}}> P_sram[{{BLOCK_DIM}} / 8]; )"; } else { - source.SetValue("REGISTER_NAME_DS", registerName(AttentionOperand::dS)); + source.SetValue("REGISTER_NAME_DS", registerName(AttentionOperand::dS)); source += R"( simdgroup_matrix_storage<{{REGISTER_NAME_DS}}> dS_sram[{{BLOCK_DIM}} / 8]; )"; } - return source.ToString(); + return source.ToString(); }; auto loadOperand = [=]() -> std::string { - CodeWriter source; - source.SetValue("OPERAND", operand.name()); - source.SetValue("TRAVERSAL_OFFSET", traversalOffsetValue()); - source.SetValue("MEMORY_NAME_OPERAND", memoryName(operand)); - source.SetValue("BLOCK_DIMENSIONS_TRAVERSAL", std::to_string(blockDimensions[1])); - source.SetValue("TRAVERSAL_DIMENSION", traversalDimensionValue()); - source.SetValue("PADDED_TRAVERSAL_EDGE", paddedTraversalEdgeValue()); + CodeWriter source; + source.SetValue("OPERAND", operand.name()); + source.SetValue("OPERAND_LOCATION", operandLocationValue(operand)); + source.SetValue("TRAVERSAL_OFFSET", traversalOffsetValue()); + source.SetValue("MEMORY_NAME_OPERAND", memoryName(operand)); + source.SetValue("BLOCK_DIMENSIONS_TRAVERSAL", std::to_string(blockDimensions[1])); + source.SetValue("TRAVERSAL_DIMENSION", traversalDimensionValue()); + source.SetValue("PADDED_TRAVERSAL_EDGE", paddedTraversalEdgeValue()); source += R"( threadgroup_barrier(mem_flags::mem_threadgroup); if (sidx == 0) { - auto {{OPERAND}}_src = {{OPERAND}} + {{TRAVERSAL_OFFSET}}; + auto {{OPERAND}}_src = {{OPERAND_LOCATION}} + {{TRAVERSAL_OFFSET}}; auto {{OPERAND}}_dst = (threadgroup {{MEMORY_NAME_OPERAND}}*)(threadgroup_block); @@ -2941,7 +2973,7 @@ std::string AttentionKernel::softmax(bool derivative) const noexcept { } )"; - return 
+    return source.ToString();
   };
 
   // Declares the source of L or D.
@@ -2950,14 +2982,15 @@ std::string AttentionKernel::softmax(bool derivative) const noexcept {
   auto declareOperandLocation =
   [=](MTLAddressSpace addressSpace) -> std::string {
-    CodeWriter source;
-    source.SetValue("OPERAND", operand.name());
-    source.SetValue("TRAVERSAL_OFFSET", traversalOffsetValue());
-    source.SetValue("MEMORY_NAME_OPERAND", memoryName(operand));
+    CodeWriter source;
+    source.SetValue("OPERAND", operand.name());
+    source.SetValue("OPERAND_LOCATION", operandLocationValue(operand));
+    source.SetValue("TRAVERSAL_OFFSET", traversalOffsetValue());
+    source.SetValue("MEMORY_NAME_OPERAND", memoryName(operand));
     if (addressSpace == MTLAddressSpace::device) {
       source += R"(
 
-  auto {{OPERAND}}_src = {{OPERAND}};
+  auto {{OPERAND}}_src = {{OPERAND_LOCATION}};
   {{OPERAND}}_src += {{TRAVERSAL_OFFSET}} + morton_offset.x;
 
 )";
     } else {
       source += R"(
@@ -2971,13 +3004,13 @@ std::string AttentionKernel::softmax(bool derivative) const noexcept {
 
 )";
     }
-    return source.ToString();
+    return source.ToString();
   };
 
   auto overwriteAttentionMatrixElements =
   [=]() -> std::string {
-    CodeWriter source;
-    source.SetValue("SCALE", dotProductScale(scale, derivative, headDimension));
+    CodeWriter source;
+    source.SetValue("SCALE", dotProductScale(scale, derivative, headDimension));
 
     if (!derivative) {
       source.SetValue("REGISTER_NAME_P", registerName(AttentionOperand::P));
@@ -3001,17 +3034,17 @@ std::string AttentionKernel::softmax(bool derivative) const noexcept {
 
 )";
     }
-    return source.ToString();
+    return source.ToString();
   };
 
   auto innerLoop = [=]() -> std::string {
     CodeWriter source;
-    source.SetValue("BLOCK_DIMENSIONS_TRAVERSAL", std::to_string(blockDimensions[1]));
-    source.SetValue("OVERWRITE_ATTENTION_MATRIX_ELEMENTS", overwriteAttentionMatrixElements());
-    source.SetValue("OPERAND", operand.name());
-    source.SetValue("LOAD_FUNCTION_OPERAND", loadFunction(operand));
-    source.SetValue("REGISTER_NAME_OPERAND", registerName(operand));
+    source.SetValue("BLOCK_DIMENSIONS_TRAVERSAL", std::to_string(blockDimensions[1]));
+    source.SetValue("OVERWRITE_ATTENTION_MATRIX_ELEMENTS", overwriteAttentionMatrixElements());
+    source.SetValue("OPERAND", operand.name());
+    source.SetValue("LOAD_FUNCTION_OPERAND", loadFunction(operand));
+    source.SetValue("REGISTER_NAME_OPERAND", registerName(operand));
     switch (type.value) {
     case AttentionKernelType::forward:
      source += R"(
@@ -3023,7 +3056,7 @@ std::string AttentionKernel::softmax(bool derivative) const noexcept {
   }
 
 )";
-      break;
+      break;
 
     case AttentionKernelType::backwardQuery:
       source += R"(
@@ -3034,7 +3067,7 @@ std::string AttentionKernel::softmax(bool derivative) const noexcept {
   }
 
 )";
-      break;
+      break;
 
     case AttentionKernelType::backwardKeyValue:
       source += R"(
@@ -3051,9 +3084,9 @@ std::string AttentionKernel::softmax(bool derivative) const noexcept {
   }
 
 )";
-      break;
+      break;
     }
     return source.ToString();
   };
 
   CodeWriter source;
@@ -3070,16 +3103,16 @@ std::string AttentionKernel::softmax(bool derivative) const noexcept {
   }
 
 )";
-    break;
+    break;
   case AttentionKernelType::backwardKeyValue:
     auto blockDim = blockDimensions[1];
-    source.SetValue("BLOCK_DIM", std::to_string(blockDim));
-    source.SetValue("NOT_PREFER_ASYNC_LOAD", !preferAsyncLoad ? "true" : "false");
"true" : "false"); - source.SetValue("TRAVERSAL_DIMENSION", traversalDimensionValue()); - source.SetValue("TRAVERSAL_OFFSET", traversalOffsetValue()); - source.SetValue("LOAD_OPERAND", loadOperand()); - source.SetValue("DECLARE_OPERAND_LOCATION_DEVICE", declareOperandLocation(MTLAddressSpace::device)); - source.SetValue("DECLARE_OPERAND_LOCATION_THREADGROUP", declareOperandLocation(MTLAddressSpace::threadgroup)); + source.SetValue("BLOCK_DIM", std::to_string(blockDim)); + source.SetValue("NOT_PREFER_ASYNC_LOAD", !preferAsyncLoad ? "true" : "false"); + source.SetValue("TRAVERSAL_DIMENSION", traversalDimensionValue()); + source.SetValue("TRAVERSAL_OFFSET", traversalOffsetValue()); + source.SetValue("LOAD_OPERAND", loadOperand()); + source.SetValue("DECLARE_OPERAND_LOCATION_DEVICE", declareOperandLocation(MTLAddressSpace::device)); + source.SetValue("DECLARE_OPERAND_LOCATION_THREADGROUP", declareOperandLocation(MTLAddressSpace::threadgroup)); source += R"( @@ -3097,7 +3130,7 @@ std::string AttentionKernel::softmax(bool derivative) const noexcept { } )"; - break; + break; } return source.ToString(); } diff --git a/lib/nnc/mfa/v2/AttentionKernel.hpp b/lib/nnc/mfa/v2/AttentionKernel.hpp index 4b1795699..afc6a0cc5 100644 --- a/lib/nnc/mfa/v2/AttentionKernel.hpp +++ b/lib/nnc/mfa/v2/AttentionKernel.hpp @@ -74,6 +74,8 @@ struct AttentionKernel { unsigned short paddedHeadEdgeValue() const noexcept; unsigned short threadgroupSizeValue() const noexcept; unsigned short createThreadgroupMemoryAllocation() const noexcept; + std::string operandLocationValue(AttentionOperand operand) const noexcept; + std::string operandLocationWithHeadOffsetValue(AttentionOperand operand) const noexcept; /// AttentionKernel+Source std::string createSource() const noexcept;