
Further fix translation errors; now the forward pass with this particular configuration will match.
liuliu committed Sep 13, 2024
1 parent 20bba89 commit d71a4f5
Showing 2 changed files with 20 additions and 22 deletions.
4 changes: 2 additions & 2 deletions lib/nnc/mfa/v2/AttentionDescriptor.cpp
@@ -32,7 +32,7 @@ AttentionKernelDescriptor AttentionDescriptor::kernelDescriptor(MTL::Device *con
auto createBlockDimensions =
[=]() -> simd::ushort3 {
unsigned short parallelization = 16;
- unsigned short traversal = 128;
+ unsigned short traversal = 64; // 128;
unsigned short originalHead = 16;
// Enforce the rule that head block dimension <= head dimension.
unsigned short headDimension = createHeadDimension();
@@ -48,7 +48,7 @@ AttentionKernelDescriptor AttentionDescriptor::kernelDescriptor(MTL::Device *con
switch (type.value) {
case AttentionKernelType::forward:
output[AttentionOperand::Q] = false;
- output[AttentionOperand::O] = false;
+ output[AttentionOperand::O] = true; // false;
break;
case AttentionKernelType::backwardQuery:
output[AttentionOperand::Q] = false;
38 changes: 18 additions & 20 deletions lib/nnc/mfa/v2/AttentionKernel.cpp
@@ -786,8 +786,7 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc
source.SetValue("DESCRIPTOR_REGISTER_SIZE", std::to_string(descriptor.registerSize));
source += R"(
- simdgroup_matrix_storage<{{REGISTER_NAME_C}}> \
- {{C}}_sram[{{DESCRIPTOR_REGISTER_SIZE}} / 8];
+ simdgroup_matrix_storage<{{REGISTER_NAME_C}}> {{C}}_sram[{{DESCRIPTOR_REGISTER_SIZE}} / 8];
)";
return source.ToString();
@@ -1277,8 +1276,13 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc
source.SetValue("ALLOCATE_ACCUMULATOR", allocateAccumulator(descriptor));
source.SetValue("TRAVERSAL_OFFSET", traversalOffsetValue());
source.SetValue("INITIALIZE_ACCUMULATOR", initializeAccumulator(descriptor));
source.SetValue("LOAD_ACCUMULATOR", loadAccumulator(descriptor));
source.SetValue("STORE_ACCUMULATOR", storeAccumulator(descriptor));
if (cached(C)) {
source.SetValue("LOAD_ACCUMULATOR", "");
source.SetValue("STORE_ACCUMULATOR", "");
} else {
source.SetValue("LOAD_ACCUMULATOR", loadAccumulator(descriptor));
source.SetValue("STORE_ACCUMULATOR", storeAccumulator(descriptor));
}
source.SetValue("LOAD_RHS", loadRHS(descriptor));
source.SetValue("MULTIPLY_AB", multiplyAB());
source.SetValue("SCALE_ACCUMULATOR", scaleAccumulator(accumulateDesc.everyIterationScale, descriptor));
@@ -1357,7 +1361,7 @@ std::string AttentionKernel::accumulate(const AttentionAccumulateDescriptor& acc

auto loopEndFloor =
[=]() -> unsigned short {
- return loopEnd() - loopEnd() % blockDimensions[1];
+ return loopEnd() - loopEnd() % blockDimensions[2];
};

auto unrollStatement =
@@ -1444,8 +1448,7 @@ std::string AttentionKernel::cache(AttentionOperand operand, CachingOperationTyp
source.SetValue("PADDED_HEAD_DIMENSION_8", std::to_string(paddedHeadDimensionValue() / 8));
source += R"(
- simdgroup_matrix_storage<{{REGISTER_NAME_OPERAND}}> \
- {{OPERAND}}_sram[{{PADDED_HEAD_DIMENSION_8}}];
+ simdgroup_matrix_storage<{{REGISTER_NAME_OPERAND}}> {{OPERAND}}_sram[{{PADDED_HEAD_DIMENSION_8}}];
)";
return source.ToString();
@@ -1770,8 +1773,7 @@ std::string AttentionKernel::createSetup() const noexcept {
source.SetValue("PADDED_HEAD_DIMENSION_8", std::to_string(paddedHeadDimensionValue() / 8));
source += R"(
- simdgroup_matrix_storage<{{REGISTER_NAME_OPERAND}}> \
- {{OPERAND}}_sram[{{PADDED_HEAD_DIMENSION_8}}];
+ simdgroup_matrix_storage<{{REGISTER_NAME_OPERAND}}> {{OPERAND}}_sram[{{PADDED_HEAD_DIMENSION_8}}];
)";
return source.ToString();
@@ -1942,8 +1944,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor&
source.SetValue("BLOCK_DIMENSIONS_TRAVERSAL", std::to_string(blockDimensions[1]));
source += R"(
- simdgroup_matrix_storage<{{REGISTER_NAME_C}}> \
- {{C}}_sram[{{BLOCK_DIMENSIONS_TRAVERSAL}} / 8];
+ simdgroup_matrix_storage<{{REGISTER_NAME_C}}> {{C}}_sram[{{BLOCK_DIMENSIONS_TRAVERSAL}} / 8];
)";
return source.ToString();
@@ -1987,8 +1988,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor&
source.SetValue("DESCRIPTOR_REGISTER_SIZE", std::to_string(descriptor.registerSize));
source += R"(
- simdgroup_matrix_storage<{{REGISTER_NAME_A}}> \
- {{A}}_sram[{{DESCRIPTOR_REGISTER_SIZE}} / 8];
+ simdgroup_matrix_storage<{{REGISTER_NAME_A}}> {{A}}_sram[{{DESCRIPTOR_REGISTER_SIZE}} / 8];
)";
return source.ToString();
@@ -2168,7 +2168,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor&
)";
break;
case MTLAddressSpace::threadgroup:
source.SetValue("LEADING_BLOCK_DIMENSION_B", leadingDimension(B));
source.SetValue("LEADING_BLOCK_DIMENSION_B", std::to_string(leadingBlockDimension(B)));
source.SetValue("NOT_TRANSPOSED_B", std::to_string(!transposed(B)));
source += R"(
@@ -2205,7 +2205,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor&
source.SetValue("PADDED_TRAVERSAL_EDGE", paddedTraversalEdgeValue());
source.SetValue("LEADING_DIMENSION_B", leadingDimension(B));
source.SetValue("TRANSPOSED_B", std::to_string(transposed(B)));
source.SetValue("LEADING_BLOCK_DIMENSION_B", leadingDimension(B));
source.SetValue("LEADING_BLOCK_DIMENSION_B", std::to_string(leadingBlockDimension(B)));
source.SetValue("DESCRIPTOR_REGISTER_SIZE", std::to_string(descriptor.registerSize));
source.SetValue("DECLARE_RHS_LOCATION", declareRHSLocation(descriptor));
source += R"(
@@ -2261,7 +2261,7 @@ std::string AttentionKernel::outerProduct(const AttentionOuterProductDescriptor&
source.SetValue("REGISTER_NAME_B", registerName(B));
source.SetValue("LOAD_FUNCTION_B", loadFunction(B));
source.SetValue("LEADING_DIMENSION_RHS", leadingDimensionRHS(descriptor));
source.SetValue("NOT_TRANSPOSED_B", std::to_string(transposed(B)));
source.SetValue("NOT_TRANSPOSED_B", std::to_string(!transposed(B)));
source.SetValue("DESCRIPTOR_REGISTER_OFFSET", descriptor.registerOffset);
source.SetValue("DESCRIPTOR_ACCUMULATE_CONDITIONAL", descriptor.accumulateConditional);
source += R"(
@@ -2861,16 +2861,14 @@ std::string AttentionKernel::softmax(bool derivative) const noexcept {
source.SetValue("REGISTER_NAME_P", registerName(AttentionOperand::P));
source += R"(
- simdgroup_matrix_storage<{{REGISTER_NAME_P}}> \
- P_sram[{{BLOCK_DIM}} / 8];
+ simdgroup_matrix_storage<{{REGISTER_NAME_P}}> P_sram[{{BLOCK_DIM}} / 8];
)";
} else {
source.SetValue("REGISTER_NAME_DS", registerName(AttentionOperand::dS));
source += R"(
- simdgroup_matrix_storage<{{REGISTER_NAME_DS}}> \
- dS_sram[{{BLOCK_DIM}} / 8];
+ simdgroup_matrix_storage<{{REGISTER_NAME_DS}}> dS_sram[{{BLOCK_DIM}} / 8];
)";
}
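The main behavioral change in AttentionKernel.cpp is gating the accumulator load/store substitutions on whether the accumulator operand is cached in registers. Below is a minimal, self-contained C++ sketch of that pattern; Source, setAccumulatorIO, loadAccumulator, and storeAccumulator are hypothetical stand-ins for illustration, not the repository's actual types or signatures.

#include <map>
#include <string>

// Hypothetical stand-in for the code-writer object used by the kernel
// generator; only the SetValue substitution map is sketched here.
struct Source {
  std::map<std::string, std::string> values;
  void SetValue(const std::string &key, const std::string &value) {
    values[key] = value;
  }
};

// Hypothetical snippet producers standing in for loadAccumulator(descriptor)
// and storeAccumulator(descriptor).
std::string loadAccumulator() { return "/* load C from memory */"; }
std::string storeAccumulator() { return "/* store C back to memory */"; }

// When the accumulator C is register-cached, substitute empty strings so the
// generated kernel emits no load/store code for it; otherwise substitute the
// usual load/store snippets.
void setAccumulatorIO(Source &source, bool accumulatorIsCached) {
  if (accumulatorIsCached) {
    source.SetValue("LOAD_ACCUMULATOR", "");
    source.SetValue("STORE_ACCUMULATOR", "");
  } else {
    source.SetValue("LOAD_ACCUMULATOR", loadAccumulator());
    source.SetValue("STORE_ACCUMULATOR", storeAccumulator());
  }
}

int main() {
  Source source;
  setAccumulatorIO(source, /*accumulatorIsCached=*/true);
  // The cached path leaves the load substitution empty.
  return source.values.at("LOAD_ACCUMULATOR").empty() ? 0 : 1;
}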
