Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MFMA] MFMA 4x64 64x4 version 2 #539

Draft
wants to merge 4 commits into
base: triton-mlir
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ compared to 1*64 when the hasLeadingOffset is false.

if (mfmaEnc) {
int kDimNum = dotOpEnc.getOpIdx() == 0 ? 1 : 0;
int nonKDimNum = 1 - kDimNum;
if (needTrans)
kDimNum = 1 - kDimNum;
bool isKDimInner = (order[0] == kDimNum);
Expand All @@ -143,17 +144,39 @@ compared to 1*64 when the hasLeadingOffset is false.
int innerDimLength = shape[order[0]];
int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;

int mDim = mfmaEnc.getMDim();
int nDim = mfmaEnc.getNDim();
int nonKDim = dotOpEnc.getOpIdx() == 0 ? mDim : nDim;
if ((mDim == 4 && nDim == 64) || (nDim == 4 && mDim == 64)) {
// Operands of the layout have the following shapes:
// Large operand:
// - shape 64(non-k)x64(k) for 16 bit dtypes
// - shape 64(non-k)x16(k) for 32 bit dtypes
// Small operand:
// - shape 4(non-k)x64(k) for 16 bit dtypes
// - shape 4(non-k)x16(k) for 32 bit dtypes
const int vecSize = bankBitWidth / typeWidthInBit;
const int perPhase = std::max(1, numBanks / innerDimLength);
const int maxPhase = std::min(numBanks, nonKDim) / perPhase;
return get(context, vecSize, perPhase, maxPhase, order, CTALayout);
}
int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
// vecSize is set to kWidth of the dotop layout
int vecSize = dotOpEnc.getKWidth();
// maxPhase is set to SIMDWidth / perPhase
int maxPhase = std::min(SIMDWidth / perPhase, innerDimLength / vecSize);
// TODO (zhanglx): figure out better parameters for mfma4
auto mDim = mfmaEnc.getMDim();
auto nDim = mfmaEnc.getNDim();
auto nonKDim = dotOpEnc.getOpIdx() == 0 ? mDim : nDim;
if (4 == nonKDim)
maxPhase = 4;
// If maxPhase * perPhase is larger than one block of warps,
// fall back to an unswizzled tensor.
// Shared to dot op conversion requires that the swizzling pattern
// fits into one block of warps.
auto warpsPerCTA = mfmaEnc.getWarpsPerCTA();
if (maxPhase * perPhase > nonKDim * warpsPerCTA[nonKDimNum]) {
assert(isKDimInner);
maxPhase = 1;
}
assert(maxPhase > 0);

return get(context, vecSize, perPhase, maxPhase, order, CTALayout);
Expand Down
6 changes: 4 additions & 2 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ struct MfmaInsnAttr {
unsigned n;
unsigned k;
// k_base refers to the number of elements per thread
unsigned k_base;
unsigned k_base_a;
unsigned k_base_b;
llvm::StringRef insn;
};

Expand Down Expand Up @@ -223,7 +224,8 @@ class MfmaInsn {
unsigned getMDim();
unsigned getNDim();
StringRef getInsnName();
unsigned getKBase();
unsigned getKBaseA();
unsigned getKBaseB();
};
} // namespace mlir

Expand Down
3 changes: 2 additions & 1 deletion lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,8 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
dotOperandLayout.getOpIdx() == 0 &&
dotOperandLayout.getKWidth() == 4 &&
dotOperandLayout.getParent() == mfmaLayout &&
(mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) &&
(mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16 ||
(mfmaLayout.getMDim() == 4 && mfmaLayout.getNDim() == 64)) &&
mfmaLayout.getIsTransposed() &&
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,12 @@ llvm::SmallVector<llvm::SmallVector<Value>> computeTensorElemMappingInBlock(
if (iNonKDim == 32)
laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0);
else {
// In this configuration wave contains 16 copies of same data
if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4) {
// shortcut for 64x64 tile size.
// In this case the warp does not wrap, so there is no need to introduce this offset
if (iNonKDim == 64)
laneHOffset = i32_val(0);
} else {
assert(iKDim * iNonKDim / numOfElems == 64 &&
"seems no all threads in wave contain unique elements");
else
laneHOffset = mul(udiv(laneId, nonKDim), i32_val(numOfElems));
}
}

for (int loadId = 0; loadId < loadsPerThread; ++loadId) {
Expand Down Expand Up @@ -346,7 +344,7 @@ fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc,
// 32 33 34 35 ... 63
// 32 33 34 35 ... 63
Value halfOffset;
if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4)
if (iNonKDim == 64)
halfOffset = i32_val(0);
else
halfOffset =
Expand Down Expand Up @@ -456,6 +454,8 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
int numSubBlocks = 1;
if ((mfmaInstrK == 4 || mfmaInstrK == 1) && mfmaInstrNonK == 4)
numSubBlocks = 16;
assert(numSubBlocks == 1 &&
"after reworking layout, there should be no redundency");
int numOfElems = mfmaInstrNonK * mfmaInstrK * numSubBlocks / iWaveSize;
assert(numOfElems >= 1);

Expand Down
Loading