Skip to content

Commit

Permalink
[MFMA] Swizzled operands (#285)
Browse files Browse the repository at this point in the history
This pr enables generation of swizzled tensors for mfma dot operands
  • Loading branch information
binarman authored Aug 9, 2023
1 parent af05f01 commit a1f4ee6
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,29 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
auto mfmaEnc = dotOpEnc.getParent().dyn_cast<MfmaEncodingAttr>();

if (mfmaEnc) {
// Swizzling is currently disabled for MFMA
return $_get(context, 1, 1, 1, order);
int kDimNum = dotOpEnc.getOpIdx() == 0 ? 1 : 0;
bool isKDimInner = (order[0] == kDimNum);
if (isKDimInner) {
const int numBanks = 32;
const int bankBitWidth = 32;

// number of inner dimension rows per one pattern repeat
int outerDimGranularity = mfmaEnc.getNonKDim();
int typeBitWidth = eltTy.getIntOrFloatBitWidth();
int innerDimLength = shape[order[0]];
int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeBitWidth;

int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
int maxPhase = outerDimGranularity / perPhase;
int vecSize = innerDimLength / maxPhase;
assert(vecSize > 0);

return $_get(context, vecSize, perPhase, maxPhase, order);
} else {
// Do not swizzle in case k dimension is not innermost.
// In this case accesses will go in different banks even without swizzling.
return $_get(context, 1, 1, 1, order);
}
}
#endif
auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();
Expand Down

0 comments on commit a1f4ee6

Please sign in to comment.