diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index db575a9558d7..39d6a348d455 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -84,8 +84,29 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] / auto mfmaEnc = dotOpEnc.getParent().dyn_cast(); if (mfmaEnc) { - // Swizzling is currently disabled for MFMA - return $_get(context, 1, 1, 1, order); + int kDimNum = dotOpEnc.getOpIdx() == 0 ? 1 : 0; + bool isKDimInner = (order[0] == kDimNum); + if (isKDimInner) { + const int numBanks = 32; + const int bankBitWidth = 32; + + // number of inner dimension rows per one pattern repeat + int outerDimGranularity = mfmaEnc.getNonKDim(); + int typeBitWidth = eltTy.getIntOrFloatBitWidth(); + int innerDimLength = shape[order[0]]; + int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeBitWidth; + + int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); + int maxPhase = outerDimGranularity / perPhase; + int vecSize = innerDimLength / maxPhase; + assert(vecSize > 0); + + return $_get(context, vecSize, perPhase, maxPhase, order); + } else { + // Do not swizzle in case k dimension is not innermost. + // In this case accesses will go in different banks even without swizzling. + return $_get(context, 1, 1, 1, order); + } } #endif auto mmaEnc = dotOpEnc.getParent().dyn_cast();