Commit

Add AttentionDescriptor+Parameters; existing tests still pass.
liuliu committed Sep 13, 2024
1 parent 3cc2117 commit 290be87
Showing 4 changed files with 191 additions and 5 deletions.
1 change: 1 addition & 0 deletions bin/nnc/square_attention_test.cpp
@@ -412,6 +412,7 @@ void validateProblemSize(int sequenceDimension, int headDimension)
} else {
check(O, resultO, 2e-5);
}
ccfree(resultO);
}
}
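The single added line above releases the per-iteration result buffer. A minimal sketch of the pattern being fixed, assuming resultO comes from the matching ccmalloc allocator and that runAttention stands in for the kernel dispatch (both sit outside this hunk and are hypothetical here):

float *resultO = (float *)ccmalloc(sizeof(float) * sequenceDimension * headDimension);
runAttention(Q, K, V, resultO);  // hypothetical helper; the real dispatch is elided above
check(O, resultO, 2e-5);         // tolerance taken from the surrounding code
ccfree(resultO);                 // the added line: free each pass, closing the leak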

172 changes: 168 additions & 4 deletions lib/nnc/mfa/v2/AttentionDescriptor.cpp
@@ -29,11 +29,13 @@ AttentionKernelDescriptor AttentionDescriptor::kernelDescriptor(MTL::Device *const device, const DeviceProperties &dprops) const noexcept {
[=]() -> unsigned short {
return matrixDimensions[2];
};
std::vector table = parameterFile(type, device);
auto row = this->row(table);
auto createBlockDimensions =
[=]() -> simd::ushort3 {
- unsigned short parallelization = 16;
- unsigned short traversal = 64; // 128;
- unsigned short originalHead = 16;
+ unsigned short parallelization = row.parallelization;
+ unsigned short traversal = row.traversal;
+ unsigned short originalHead = row.head;
// Enforce the rule that head block dimension <= head dimension.
unsigned short headDimension = createHeadDimension();
unsigned short paddedHeadDimension = (headDimension + 7) / 8 * 8;
@@ -48,7 +50,7 @@ AttentionKernelDescriptor AttentionDescriptor::kernelDescriptor(MTL::Device *const device, const DeviceProperties &dprops) const noexcept {
switch (type.value) {
case AttentionKernelType::forward:
output[AttentionOperand::Q] = false;
- output[AttentionOperand::O] = true; // false;
+ output[AttentionOperand::O] = false;
break;
case AttentionKernelType::backwardQuery:
output[AttentionOperand::Q] = false;
@@ -62,6 +64,10 @@ AttentionKernelDescriptor AttentionDescriptor::kernelDescriptor(MTL::Device *const device, const DeviceProperties &dprops) const noexcept {
output[AttentionOperand::dK] = false;
break;
}
auto cachedOperands = row.cachedOperands;
for (const auto& operand : cachedOperands) {
output[operand] = true;
}
return output;
};

@@ -135,6 +141,8 @@ std::pair<AttentionKernelDescriptor, PipelineValue<AttentionKernel> *> Attention
return std::make_pair(kernelDesc, output);
}

// MARK: - AttentionDescriptor+Precisions

AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createMemoryPrecisions() const noexcept {
AttentionOperands<GEMMOperandPrecision> memoryPrecisions;

@@ -340,3 +348,159 @@ AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createRegisterPrecisions(MTL::Device *const device) const noexcept {
return registerPrecisions;

}

// MARK: - AttentionDescriptor+Parameters

std::vector<AttentionParameterRow> AttentionDescriptor::parameterFile(AttentionKernelType type, MTL::Device *const device) const noexcept {
if (lowPrecisionInputs && lowPrecisionIntermediates) {
switch (type.value) {
case AttentionKernelType::forward:
return forwardMixed(device);
case AttentionKernelType::backwardQuery:
return backwardQueryMixed(device);
case AttentionKernelType::backwardKeyValue:
return backwardKeyValueMixed(device);
}
} else {
switch (type.value) {
case AttentionKernelType::forward:
return forward(device);
case AttentionKernelType::backwardQuery:
return backwardQuery(device);
case AttentionKernelType::backwardKeyValue:
return backwardKeyValue(device);
}
}
return defaultParameters(device);
}

AttentionParameterRow AttentionDescriptor::row(const std::vector<AttentionParameterRow>& table) const noexcept {
auto headDimension = matrixDimensions[2];
int matchedRowID = table.size() - 1;
for (int i = 0; i < table.size(); i++) {
if (headDimension <= table[i].maximumHeadDimension) {
matchedRowID = i;
break;
}
}
return table[matchedRowID];
}
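// Editor's illustration (not part of the commit): with the Apple9 forwardMixed
// table below, headDimension = 96 matches the second row (96 <= 96) and caches
// Q and O; headDimension = 100 falls through to the 160 row; any value above
// 384 exceeds every bound, so the size() - 1 initializer keeps the last row.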

std::vector<AttentionParameterRow> AttentionDescriptor::defaultParameters(MTL::Device *const device) const noexcept {
if (device->supportsFamily(MTL::GPUFamily(1009))) {
return { AttentionParameterRow(0, 16, 128, 16, {}) };
} else {
return { AttentionParameterRow(0, 32, 80, 16, {}) };
}
}
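// Editor's note: MTL::GPUFamily(1009) is MTLGPUFamilyApple9 (Metal numbers the
// Apple-family enumerators 1001...1009), so each table below splits into an
// Apple9-and-newer branch and a fallback for older GPUs. If the metal-cpp
// headers in use define the named enumerator, the readable spelling would be:
//   device->supportsFamily(MTL::GPUFamilyApple9)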

std::vector<AttentionParameterRow> AttentionDescriptor::forwardMixed(MTL::Device *const device) const noexcept {
if (device->supportsFamily(MTL::GPUFamily(1009))) {
return {
AttentionParameterRow(32, 16, 128, 16, { AttentionOperand::Q, AttentionOperand::O }),
AttentionParameterRow(96, 16, 128, 32, { AttentionOperand::Q, AttentionOperand::O }),
AttentionParameterRow(160, 16, 128, 32, { AttentionOperand::O }),
AttentionParameterRow(224, 16, 128, 32, { AttentionOperand::Q }),
AttentionParameterRow(384, 16, 128, 32, {})
};
} else {
return {
AttentionParameterRow(96, 32, 128, 32, { AttentionOperand::Q, AttentionOperand::O }),
AttentionParameterRow(128, 32, 128, 32, { AttentionOperand::Q }),
AttentionParameterRow(384, 32, 128, 32, {})
};
}
}

std::vector<AttentionParameterRow> AttentionDescriptor::forward(MTL::Device *const device) const noexcept {
if (device->supportsFamily(MTL::GPUFamily(1009))) {
return {
AttentionParameterRow(8, 16, 128, 16, { AttentionOperand::Q, AttentionOperand::O }),
AttentionParameterRow(16, 16, 64, 16, { AttentionOperand::Q, AttentionOperand::O }),
AttentionParameterRow(48, 16, 32, 8, { AttentionOperand::Q, AttentionOperand::O }),
AttentionParameterRow(192, 16, 64, 16, { AttentionOperand::O }),
AttentionParameterRow(384, 16, 128, 16, {})
};
} else {
return {
AttentionParameterRow(24, 32, 64, 24, { AttentionOperand::Q, AttentionOperand::O }),
AttentionParameterRow(32, 32, 64, 32, { AttentionOperand::O }),
AttentionParameterRow(56, 32, 32, 56, { AttentionOperand::Q }),
AttentionParameterRow(384, 32, 80, 16, {})
};
}
}

std::vector<AttentionParameterRow> AttentionDescriptor::backwardQueryMixed(MTL::Device *const device) const noexcept {
if (device->supportsFamily(MTL::GPUFamily(1009))) {
return {
AttentionParameterRow(80, 16, 64, 8, { AttentionOperand::Q, AttentionOperand::dO, AttentionOperand::dQ }),
AttentionParameterRow(192, 16, 64, 32, { AttentionOperand::Q, AttentionOperand::dQ }),
AttentionParameterRow(384, 16, 128, 32, {})
};
} else {
return {
AttentionParameterRow(32, 32, 64, 32, { AttentionOperand::Q, AttentionOperand::dQ }),
AttentionParameterRow(96, 32, 64, 32, { AttentionOperand::dQ }),
AttentionParameterRow(384, 32, 64, 32, {})
};
}
}

std::vector<AttentionParameterRow> AttentionDescriptor::backwardQuery(MTL::Device *const device) const noexcept {
if (device->supportsFamily(MTL::GPUFamily(1009))) {
return {
AttentionParameterRow(16, 16, 64, 8, { AttentionOperand::Q, AttentionOperand::dO, AttentionOperand::dQ }),
AttentionParameterRow(32, 16, 64, 16, { AttentionOperand::Q, AttentionOperand::dQ }),
AttentionParameterRow(192, 16, 64, 32, { AttentionOperand::Q, AttentionOperand::dQ }),
AttentionParameterRow(384, 16, 128, 16, {})
};
} else {
return {
AttentionParameterRow(16, 32, 64, 16, { AttentionOperand::Q, AttentionOperand::dQ }),
AttentionParameterRow(32, 32, 64, 32, { AttentionOperand::dQ }),
AttentionParameterRow(56, 32, 64, 24, { AttentionOperand::dQ }),
AttentionParameterRow(384, 32, 80, 16, {})
};
}
}

std::vector<AttentionParameterRow> AttentionDescriptor::backwardKeyValueMixed(MTL::Device *const device) const noexcept {
if (device->supportsFamily(MTL::GPUFamily(1009))) {
return {
AttentionParameterRow(56, 16, 64, 8, { AttentionOperand::K, AttentionOperand::V, AttentionOperand::dV, AttentionOperand::dK }),
AttentionParameterRow(80, 16, 32, 16, { AttentionOperand::V, AttentionOperand::dV, AttentionOperand::dK }),
AttentionParameterRow(144, 16, 128, 16, { AttentionOperand::dV, AttentionOperand::dK }),
AttentionParameterRow(224, 16, 128, 16, { AttentionOperand::dV }),
AttentionParameterRow(384, 16, 128, 32, {})
};
} else {
return {
AttentionParameterRow(16, 32, 64, 16, { AttentionOperand::V, AttentionOperand::dV, AttentionOperand::dK }),
AttentionParameterRow(32, 32, 64, 32, { AttentionOperand::dV, AttentionOperand::dK }),
AttentionParameterRow(56, 32, 80, 32, { AttentionOperand::dV }),
AttentionParameterRow(96, 32, 64, 32, { AttentionOperand::dV }),
AttentionParameterRow(384, 32, 64, 32, {})
};
}
}

std::vector<AttentionParameterRow> AttentionDescriptor::backwardKeyValue(MTL::Device *const device) const noexcept {
if (device->supportsFamily(MTL::GPUFamily(1009))) {
return {
AttentionParameterRow(16, 16, 64, 8, { AttentionOperand::K, AttentionOperand::V, AttentionOperand::dV, AttentionOperand::dK }),
AttentionParameterRow(32, 16, 32, 16, { AttentionOperand::K, AttentionOperand::V, AttentionOperand::dV, AttentionOperand::dK }),
AttentionParameterRow(64, 16, 32, 16, { AttentionOperand::V, AttentionOperand::dV, AttentionOperand::dK }),
AttentionParameterRow(128, 16, 128, 16, { AttentionOperand::dV, AttentionOperand::dK }),
AttentionParameterRow(160, 16, 128, 16, { AttentionOperand::dV }),
AttentionParameterRow(384, 16, 128, 16, {})
};
} else {
return {
AttentionParameterRow(16, 32, 32, 16, { AttentionOperand::V, AttentionOperand::dV, AttentionOperand::dK }),
AttentionParameterRow(24, 32, 64, 24, { AttentionOperand::dV, AttentionOperand::dK }),
AttentionParameterRow(56, 32, 80, 16, { AttentionOperand::dV }),
AttentionParameterRow(384, 32, 80, 16, {})
};
}
}
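Taken together, these tables replace the previously hard-coded 16/64/16 block dimensions: parameterFile picks a table by kernel type and precision regime, row picks an entry by head dimension, and the createBlockDimensions lambda clamps the head block so it never exceeds the (8-aligned) head dimension. A minimal free-function restatement of that path, as a sketch with simplified structure (this is not the code the commit adds):

#include <algorithm>
#include <simd/simd.h>

simd::ushort3 chooseBlockDimensions(const AttentionParameterRow &row,
                                    unsigned short headDimension) {
  // Round the head dimension up to the next multiple of 8, as the kernel does.
  unsigned short paddedHeadDimension = (headDimension + 7) / 8 * 8;
  // Enforce the rule that head block dimension <= head dimension.
  unsigned short headBlock = std::min(row.head, paddedHeadDimension);
  return simd::ushort3{row.parallelization, row.traversal, headBlock};
}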
21 changes: 21 additions & 0 deletions lib/nnc/mfa/v2/AttentionDescriptor.hpp
@@ -12,6 +12,16 @@
struct AttentionKernelDescriptor;
struct AttentionKernel;

struct AttentionParameterRow {
unsigned short maximumHeadDimension;
unsigned short parallelization;
unsigned short traversal;
unsigned short head;
std::vector<AttentionOperand> cachedOperands;
AttentionParameterRow() = delete;
AttentionParameterRow(unsigned short maximumHeadDimension, unsigned short parallelization,
                      unsigned short traversal, unsigned short head,
                      std::vector<AttentionOperand> cachedOperands) noexcept
  : maximumHeadDimension(maximumHeadDimension), parallelization(parallelization),
    traversal(traversal), head(head), cachedOperands(cachedOperands) {}
};
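// Editor's usage sketch (hypothetical values, not a shipped table): tables are
// brace-built vectors of rows, ordered by maximumHeadDimension, with the last
// row acting as the catch-all sentinel that row() falls back to:
//   std::vector<AttentionParameterRow> table = {
//     AttentionParameterRow(64, 16, 64, 16, { AttentionOperand::Q, AttentionOperand::O }),
//     AttentionParameterRow(384, 16, 128, 32, {}),
//   };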

struct AttentionDescriptor {
/// Q, K, V, dO
bool lowPrecisionInputs;
@@ -37,8 +47,19 @@ struct AttentionDescriptor {

private:
AttentionKernelDescriptor kernelDescriptor(MTL::Device *const device, const DeviceProperties &dprops) const noexcept;
/// AttentionDescriptor+Precisions
AttentionOperands<GEMMOperandPrecision> createMemoryPrecisions() const noexcept;
AttentionOperands<GEMMOperandPrecision> createRegisterPrecisions(MTL::Device *const device) const noexcept;
/// AttentionDescriptor+Parameters
std::vector<AttentionParameterRow> parameterFile(AttentionKernelType type, MTL::Device *const device) const noexcept;
AttentionParameterRow row(const std::vector<AttentionParameterRow>& table) const noexcept;
std::vector<AttentionParameterRow> defaultParameters(MTL::Device *const device) const noexcept;
std::vector<AttentionParameterRow> forwardMixed(MTL::Device *const device) const noexcept;
std::vector<AttentionParameterRow> forward(MTL::Device *const device) const noexcept;
std::vector<AttentionParameterRow> backwardQueryMixed(MTL::Device *const device) const noexcept;
std::vector<AttentionParameterRow> backwardQuery(MTL::Device *const device) const noexcept;
std::vector<AttentionParameterRow> backwardKeyValueMixed(MTL::Device *const device) const noexcept;
std::vector<AttentionParameterRow> backwardKeyValue(MTL::Device *const device) const noexcept;
};

template<>
2 changes: 1 addition & 1 deletion lib/nnc/mfa/v2/AttentionKernel.cpp
@@ -1470,7 +1470,7 @@ std::string AttentionKernel::cache(AttentionOperand operand, CachingOperationType
source.SetValue("LEADING_DIMENSION_OPERAND", leadingDimension(operand));
source.SetValue("TRANSPOSED_OPERAND", transposed(operand) ? "true" : "false");
source.SetValue("HEAD_DIMENSION", std::to_string(headDimension));
source.SetValue("PADED_HEAD_DIMENSION", std::to_string(paddedHeadDimensionValue()));
source.SetValue("PADDED_HEAD_DIMENSION", std::to_string(paddedHeadDimensionValue()));
source.SetValue("BLOCK_DIMENSIONS_HEAD", std::to_string(blockDimensions[2]));
source.SetValue("PARALLELIZATION_DIMENSION", parallelizationDimensionValue());
source.SetValue("PARALLELIZATION_GROUP_OFFSET", parallelizationGroupOffsetValue());
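The rename above matters because substitution keys must match the placeholders in the kernel-source template exactly; a value registered under the misspelled PADED_HEAD_DIMENSION is simply never looked up. A map-based sketch of that failure mode, assuming a {{KEY}} placeholder style (the real templating class in this codebase may differ):

#include <map>
#include <string>

// Replace every {{KEY}} that has a registered value; unknown keys survive verbatim.
std::string substitute(std::string text,
                       const std::map<std::string, std::string> &values) {
  for (const auto &[key, value] : values) {
    const std::string token = "{{" + key + "}}";
    for (size_t pos = text.find(token); pos != std::string::npos;
         pos = text.find(token, pos + value.size())) {
      text.replace(pos, token.size(), value);
    }
  }
  return text;  // a misspelled key leaves its placeholder in the generated source
}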
