diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
index c8fdc1c70f5c..82f1eb8b9283 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -131,6 +131,7 @@ compared to 1*64 when the hasLeadingOffset is false.
       if (mfmaEnc) {
         int kDimNum = dotOpEnc.getOpIdx() == 0 ? 1 : 0;
+        int nonKDimNum = 1 - kDimNum;
         if (needTrans)
           kDimNum = 1 - kDimNum;
         bool isKDimInner = (order[0] == kDimNum);
@@ -143,17 +144,39 @@ compared to 1*64 when the hasLeadingOffset is false.
         int innerDimLength = shape[order[0]];
         int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;
+        int mDim = mfmaEnc.getMDim();
+        int nDim = mfmaEnc.getNDim();
+        int nonKDim = dotOpEnc.getOpIdx() == 0 ? mDim : nDim;
+        if ((mDim == 4 && nDim == 64) || (nDim == 4 && mDim == 64)) {
+          // Operands of the layout have the following shapes
+          // Large operand:
+          //   - shape 64(non-k)x64(k) for 16-bit dtypes
+          //   - shape 64(non-k)x16(k) for 32-bit dtypes
+          // Small operand:
+          //   - shape 4(non-k)x64(k) for 16-bit dtypes
+          //   - shape 4(non-k)x16(k) for 32-bit dtypes
+          const int vecSize = bankBitWidth / typeWidthInBit;
+          const int perPhase = std::max(1, numBanks / innerDimLength);
+          const int maxPhase = std::min(numBanks, nonKDim) / perPhase;
+          return get(context, vecSize, perPhase, maxPhase, order, CTALayout);
+        }
         int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
         // vecSize is set to kWidth of the dotop layout
         int vecSize = dotOpEnc.getKWidth();
         // maxPhase is set to SIMDWidth / perPhase
         int maxPhase = std::min(SIMDWidth / perPhase, innerDimLength / vecSize);
         // TODO (zhanglx): figure out better parameters for mfma4
-        auto mDim = mfmaEnc.getMDim();
-        auto nDim = mfmaEnc.getNDim();
-        auto nonKDim = dotOpEnc.getOpIdx() == 0 ? mDim : nDim;
         if (4 == nonKDim)
           maxPhase = 4;
+        // If maxPhase * perPhase is larger than one block of warps,
+        // fall back to an unswizzled tensor.
+        // Shared-to-dot-op conversion requires that the swizzling pattern
+        // fits into one block of warps.
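+        // E.g. nonKDim = 16 with warpsPerCTA = [2, 2] gives a 32-row block of
+        // warps; a swizzling pattern with period maxPhase * perPhase = 64 rows
+        // would span two such blocks, so swizzling is disabled below.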
+        auto warpsPerCTA = mfmaEnc.getWarpsPerCTA();
+        if (maxPhase * perPhase > nonKDim * warpsPerCTA[nonKDimNum]) {
+          assert(isKDimInner);
+          maxPhase = 1;
+        }
         assert(maxPhase > 0);

         return get(context, vecSize, perPhase, maxPhase, order, CTALayout);
diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h
index 9cb4de97d744..5afa922665ef 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h
+++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h
@@ -175,7 +175,8 @@ struct MfmaInsnAttr {
   unsigned n;
   unsigned k;
   // k_base refers to the number of elements per thread
-  unsigned k_base;
+  unsigned k_base_a;
+  unsigned k_base_b;
   llvm::StringRef insn;
 };
@@ -223,7 +224,8 @@ class MfmaInsn {
   unsigned getMDim();
   unsigned getNDim();
   StringRef getInsnName();
-  unsigned getKBase();
+  unsigned getKBaseA();
+  unsigned getKBaseB();
 };

 } // namespace mlir
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 6dbc10b943a6..51c7ed03b1a9 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -571,7 +571,8 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
          dotOperandLayout.getOpIdx() == 0 &&
          dotOperandLayout.getKWidth() == 4 &&
          dotOperandLayout.getParent() == mfmaLayout &&
-         (mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) &&
+         (mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16 ||
+          (mfmaLayout.getMDim() == 4 && mfmaLayout.getNDim() == 64)) &&
          mfmaLayout.getIsTransposed() &&
          (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
 }
diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
index 86a4153603b2..77d5f6ca5160 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
@@ -158,14 +158,12 @@ llvm::SmallVector> computeTensorElemMappingInBlock(
   if (iNonKDim == 32)
     laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0);
   else {
-    // In this configuration wave contains 16 copies of same data
-    if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4) {
+    // Shortcut for the 64x64 tile size.
+    // In this case the warp does not wrap, so there is no need for this offset
+    if (iNonKDim == 64)
       laneHOffset = i32_val(0);
-    } else {
-      assert(iKDim * iNonKDim / numOfElems == 64 &&
-             "seems no all threads in wave contain unique elements");
+    else
       laneHOffset = mul(udiv(laneId, nonKDim), i32_val(numOfElems));
-    }
   }

   for (int loadId = 0; loadId < loadsPerThread; ++loadId) {
@@ -346,7 +344,7 @@ fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc,
   //   32 33 34 35 ... 63
   //   32 33 34 35 ... 63
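+  // When iNonKDim == 64 the second half of the wave maps to new rows rather
+  // than repeating the first half, so no half-wave offset is needed below.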
   Value halfOffset;
-  if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4)
+  if (iNonKDim == 64)
     halfOffset = i32_val(0);
   else
     halfOffset =
@@ -456,6 +454,8 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
   int numSubBlocks = 1;
   if ((mfmaInstrK == 4 || mfmaInstrK == 1) && mfmaInstrNonK == 4)
     numSubBlocks = 16;
+  assert(numSubBlocks == 1 &&
+         "after reworking layout, there should be no redundancy");
   int numOfElems = mfmaInstrNonK * mfmaInstrK * numSubBlocks / iWaveSize;
   assert(numOfElems >= 1);
diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp
index 10bec3614969..94809cbdf360 100644
--- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp
@@ -37,7 +37,12 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::MfmaEncodingAttr;
 using ::mlir::triton::gpu::SharedEncodingAttr;

-using ValueTable = std::map, Value>;
+// Mapping from a tuple to a vector of values.
+// The vector contains a single element for the MFMA32, MFMA16 and MFMA4
+// layouts; for the MFMA 4x64 and 64x4 layouts there are 16 vectors for one of
+// the arguments, because each repetition in these layouts requires 16 mfma
+// operations.
+using ValueTable = std::map,
+                            llvm::SmallVector>;

 struct DotOpMFMAConversionHelper {
   MfmaEncodingAttr mfmaLayout;
@@ -60,16 +65,114 @@ struct DotOpMFMAConversionHelper {
     return rewriter.create(loc, i32_ty, tid);
   }

+  /**
+   * @param mfmaInsnName
+   * @param valA
+   * @param valB
+   * @param valC
+   * @param cbsz Control Broadcast Size modifier
+   * @param abid A-matrix Broadcast Identifier
+   * @param blgp B-matrix Lane Group Pattern modifier
+   */
   Value generateMFMAOp(StringRef mfmaInsnName, Value valA, Value valB,
-                       Value valC) const {
+                       Value valC, int cbsz = 0, int abid = 0,
+                       int blgp = 0) const {
+    assert(cbsz >= 0 && cbsz <= 4);
+    assert(abid >= 0 && abid <= 15);
+    assert(blgp >= 0 && blgp <= 7);
     auto resType = valC.getType();
-    Value zeroFlag = i32_val(0);
+    Value zeroVal = i32_val(0);
+    Value cbszFlag = cbsz != 0 ? i32_val(cbsz) : zeroVal;
+    Value abidFlag = abid != 0 ? i32_val(abid) : zeroVal;
+    Value blgpFlag = blgp != 0 ?
i32_val(blgp) : zeroVal; OperationState loweredOp(loc, mfmaInsnName); loweredOp.addTypes(resType); - loweredOp.addOperands({valA, valB, valC, zeroFlag, zeroFlag, zeroFlag}); + loweredOp.addOperands({valA, valB, valC, cbszFlag, abidFlag, blgpFlag}); return rewriter.create(loweredOp)->getResult(0); } + Value broadcastGroup(Value val, int groupId, int numGroups) const { + constexpr int waveSize = 64; + const int groupSize = waveSize / numGroups; + + Value lane = getThreadId(); + // Multiply by 4, because permute requires offset in bytes + Value laneOffset = mul(urem(lane, i32_val(groupSize)), i32_val(4)); + Value permuteAddr = add(laneOffset, i32_val(groupId * groupSize * 4)); + Type valType = val.getType(); + Value broadcasted; + if (valType.isInteger(32)) + broadcasted = rewriter.create(loc, val.getType(), + permuteAddr, val); + if (valType.isF32()) { + val = bitcast(val, i32_ty); + broadcasted = rewriter.create(loc, val.getType(), + permuteAddr, val); + broadcasted = bitcast(broadcasted, f32_ty); + } + if (valType.isa()) { + auto vecTy = valType.dyn_cast(); + auto vecBitSize = vecTy.getElementType().getIntOrFloatBitWidth() * + vecTy.getNumElements(); + const int int32VecSize = vecBitSize / 32; + + Type int32VecTy = vec_ty(i32_ty, int32VecSize); + Value int32Val = bitcast(val, int32VecTy); + Value int32Broadcasted = undef(int32VecTy); + for (int i = 0; i < int32VecSize; ++i) { + Value int32Chunk = extract_element(i32_ty, int32Val, i32_val(i)); + Value broadcastedChunk = rewriter.create( + loc, i32_ty, permuteAddr, int32Chunk); + int32Broadcasted = insert_element(int32VecTy, int32Broadcasted, + broadcastedChunk, i32_val(i)); + } + broadcasted = bitcast(int32Broadcasted, valType); + } + assert(broadcasted); + return broadcasted; + } + + Value generateMFMATile(StringRef mfmaInsnName, SmallVector valA, + SmallVector valB, Value valC, int mDim, + int nDim, bool transpose) const { + + Value acc; + if (mDim == nDim) { + assert(valA.size() == 1 && valB.size() == 1); + acc = transpose ? 
generateMFMAOp(mfmaInsnName, valB[0], valA[0], valC) + : generateMFMAOp(mfmaInsnName, valA[0], valB[0], valC); + } + if (mDim == 4 && nDim == 64 || mDim == 64 && nDim == 4) { + // broadcast selected kRep A operand matrix to all A matrices(2^4=16) + constexpr int broadcastCtrl = 4; + constexpr int numRepeats = 16; + acc = valC; + for (int kRep = 0; kRep < numRepeats; kRep++) { + if (mDim == 4 && !transpose) { + assert(valA.size() == 1 && valB.size() == 16); + acc = generateMFMAOp(mfmaInsnName, valA[0], valB[kRep], acc, + broadcastCtrl, kRep); + } + if (mDim == 4 && transpose) { + assert(valA.size() == 1 && valB.size() == 16); + Value broadcastValA = broadcastGroup(valA[0], kRep, numRepeats); + acc = generateMFMAOp(mfmaInsnName, valB[kRep], broadcastValA, acc); + } + if (nDim == 4 && !transpose) { + assert(valA.size() == 16 && valB.size() == 1); + Value broadcastValB = broadcastGroup(valB[0], kRep, numRepeats); + acc = generateMFMAOp(mfmaInsnName, valA[kRep], broadcastValB, acc); + } + if (nDim == 4 && transpose) { + assert(valA.size() == 16 && valB.size() == 1); + acc = generateMFMAOp(mfmaInsnName, valB[0], valA[kRep], acc, + broadcastCtrl, kRep); + } + } + } + return acc; + } + int getNumSubmatrices(Type elementType, int mDim, int nDim) const { if (mDim == 64 && nDim == 4 || mDim == 4 && nDim == 64) return 1; @@ -187,13 +290,14 @@ struct DotOpMFMAConversionHelper { llvm::report_fatal_error("No match found in MFMA database\n"); mfmaInsnName = (*maybeMfmaInsn).getInsnName(); - unsigned k_base = (*maybeMfmaInsn).getKBase(); + unsigned kBaseA = (*maybeMfmaInsn).getKBaseA(); + unsigned kBaseB = (*maybeMfmaInsn).getKBaseB(); auto aEncoding = aTensorTy.getEncoding().cast(); auto bEncoding = bTensorTy.getEncoding().cast(); - auto kWidth = aEncoding.getKWidth(); - assert(kWidth == bEncoding.getKWidth()); + auto kWidthA = aEncoding.getKWidth(); + auto kWidthB = bEncoding.getKWidth(); auto repA = aEncoding.getMFMARep(aTensorTy.getShape()); auto repB = bEncoding.getMFMARep(bTensorTy.getShape()); @@ -209,9 +313,9 @@ struct DotOpMFMAConversionHelper { auto numRepK = repA[1]; auto operandA = getValuesFromDotOperandLayoutStruct( - loadedA, numRepM, numRepK, kWidth, k_base, aTensorTy.getElementType()); + loadedA, numRepM, numRepK, kWidthA, kBaseA, aTensorTy.getElementType()); auto operandB = getValuesFromDotOperandLayoutStruct( - loadedB, numRepN, numRepK, kWidth, k_base, aTensorTy.getElementType()); + loadedB, numRepN, numRepK, kWidthB, kBaseB, aTensorTy.getElementType()); auto dstElemTy = dTensorTy.getElementType(); auto fc = @@ -236,12 +340,10 @@ struct DotOpMFMAConversionHelper { acc = zeroAuxiliarBlocks(subBlocks, acc); for (size_t k = 0; k < numRepK; k++) - for (int kpack = 0; kpack < kWidth / k_base; ++kpack) - acc = mfmaLayout.getIsTransposed() - ? generateMFMAOp(mfmaInsnName, operandB[kpack][{n, k}], - operandA[kpack][{m, k}], acc) - : generateMFMAOp(mfmaInsnName, operandA[kpack][{m, k}], - operandB[kpack][{n, k}], acc); + for (int kpack = 0; kpack < kWidthA / kBaseA; ++kpack) + acc = generateMFMATile(mfmaInsnName, operandA[{kpack, m, k}], + operandB[{kpack, n, k}], acc, mDim, nDim, + mfmaLayout.getIsTransposed()); acc = reduceSubBlocks(subBlocks, acc); for (unsigned v = 0; v < elemsPerVec; ++v) { fc[m * numRepN * elemsPerVec + n * elemsPerVec + v] = @@ -260,30 +362,39 @@ struct DotOpMFMAConversionHelper { } /** - * @brief extract vector from rawElems based on kWidth and k_base + * @brief extract vector from rawElems based on kWidth and kBase * rawElems is a vector of kWidth elements. 
We need to prepare vector(s) of - * k_base elements for each mfma instruction + * kBase elements for each mfma instruction + * + * @param rawElems vector of "raw" elements for one mfma tile + * @param k id in k-pack + * @param kPack size of k-pack + * @param numIntrinsics number of operands we need to extract + * @param type type mfma intrinsic requires + * + * @return elements converted for one repetition */ - SmallVector extractOperands(Value rawElems, int kWidth, int k_base, - Type type) const { - int kpack = kWidth / k_base; + SmallVector extractOperands(Value rawElems, int k, int kPack, + int numIntrinsics, Type type) const { + assert(numIntrinsics == 1 || numIntrinsics == 16); + auto rawTy = rawElems.getType().cast(); + auto rawElemTy = rawTy.getElementType(); + // number of elements required by one mfma intrinsic + int intrinsicK = rawTy.getNumElements() / numIntrinsics / kPack; + int kBase = rawTy.getNumElements() / kPack; + SmallVector results; - auto vecTy = vec_ty(type, k_base); - for (int k = 0; k < kpack; ++k) { - Value vec = undef(vecTy); - for (int elemId = 0; elemId < k_base; ++elemId) { - auto val = - extract_element(type, rawElems, i32_val(elemId + k * k_base)); - vec = insert_element(vecTy, vec, val, i32_val(elemId)); + // extract needed elements in original dtype + auto typedVecTy = vec_ty(rawElemTy, intrinsicK); + for (int intrinsic = 0; intrinsic < numIntrinsics; ++intrinsic) { + Value typedVec = undef(typedVecTy); + for (int elemId = 0; elemId < intrinsicK; ++elemId) { + int elemOff = elemId + intrinsic * intrinsicK + k * kBase; + auto val = extract_element(rawElemTy, rawElems, i32_val(elemOff)); + typedVec = insert_element(typedVecTy, typedVec, val, i32_val(elemId)); } - if (type.getIntOrFloatBitWidth() == 8) { - if (4 == k_base) - // This is for int8 on pre- MI300 GPUs - results.push_back(bitcast(vec, i32_ty)); - if (8 == k_base) - results.push_back(bitcast(vec, i64_ty)); - } else - results.push_back(vec); + Value castedVec = bitcast(typedVec, type); + results.push_back(castedVec); } return results; } @@ -292,35 +403,38 @@ struct DotOpMFMAConversionHelper { * @brief Converts dot operand structure to value table and converts types * appropriate for mfma instructions */ - SmallVector - getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1, int kWidth, - int k_base, Type type) const { + ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1, + int kWidth, int kBase, + Type type) const { auto elems = typeConverter->unpackLLElements(loc, value, rewriter, type); - ValueTable vals; - ValueTable vals1; - int kpack = kWidth / k_base; - SmallVector dotOpVals(kpack); + int kpack = kWidth / kBase; + // "Wide operand" means that this operand is for mfma 4x64 layout + // This operand is 64x64 for fp16, bf16 and int8 data types and + // 16x64 for fp32 + bool wideOperand = kWidth >= 16; + // How many rocdl intrinsics will process one tile + int numIntrinsics = wideOperand ? 16 : 1; + int intrinsicKWidth = wideOperand ? 
kBase / numIntrinsics : kBase; + Type intrinsicDType; + if (type.isF32()) + intrinsicDType = f32_ty; + if (type.getIntOrFloatBitWidth() == 8) + intrinsicDType = rewriter.getIntegerType(intrinsicKWidth * 8); + if (type.isBF16()) + intrinsicDType = vec_ty(i16_ty, intrinsicKWidth); + if (type.isF16()) + intrinsicDType = vec_ty(f16_ty, intrinsicKWidth); + assert(intrinsicDType); + + ValueTable dotOpVals; for (int i = 0; i < n0; i++) { for (int j = 0; j < n1; j++) { auto rawElems = elems[n1 * i + j]; - - if (type.isF32()) { - for (int k = 0; k < kpack; ++k) { - dotOpVals[k][{i, j}] = extract_element(type, rawElems, i32_val(k)); - } - } else { - SmallVector vals; - if (type.getIntOrFloatBitWidth() == 8) { - vals = extractOperands(rawElems, kWidth, k_base, i8_ty); - } else if (type.isBF16()) { - vals = extractOperands(rawElems, kWidth, k_base, i16_ty); - } else { - assert(type.isF16() && "Unsupported data type"); - vals = extractOperands(rawElems, kWidth, k_base, f16_ty); - } - for (int k = 0; k < kpack; ++k) { - dotOpVals[k][{i, j}] = vals[k]; - } + for (int k = 0; k < kpack; k++) { + SmallVector vals = extractOperands( + rawElems, k, kpack, numIntrinsics, intrinsicDType); + assert(vals.size() == numIntrinsics); + dotOpVals[{k, i, j}] = vals; } } } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 87e8bb218bc9..48171dc43822 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -304,12 +304,17 @@ SmallVector getSizePerThread(Attribute layout) { llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); return {}; } - } else if (parentLayout.isa()) { + } else if (auto mfmaLayout = parentLayout.dyn_cast()) { auto opIdx = dotLayout.getOpIdx(); + auto kWidth = dotLayout.getKWidth(); if (opIdx == 0) { - return {4, 1}; + int repeats = + (mfmaLayout.getMDim() == 64 && mfmaLayout.getNDim() == 4) ? 16 : 1; + return {1, kWidth * repeats}; } else if (opIdx == 1) { - return {1, 4}; + int repeats = + (mfmaLayout.getMDim() == 4 && mfmaLayout.getNDim() == 64) ? 16 : 1; + return {kWidth * repeats, 1}; } else { assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1"); return {}; @@ -458,6 +463,8 @@ SmallVector getShapePerCTATile(Attribute layout, auto parentShapePerCTA = getShapePerCTATile(parentLayout, tensorShape); auto opIdx = dotLayout.getOpIdx(); + assert(parentMfmaLayout.getMDim() == 32); + if (opIdx == 0) { return {parentShapePerCTA[0], 32}; } else if (opIdx == 1) { @@ -1102,16 +1109,13 @@ DotOperandEncodingAttr::getMFMAElemsPerInstr() const { (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); int64_t kWidth = getKWidth(); constexpr int waveSize = 64; // MFMA is used on wave64 architectures only - int kGroups = -1; - if (mDim == nDim) - kGroups = waveSize / mDim; - if (mDim == 64 && nDim == 4 || mDim == 4 && nDim == 64) - kGroups = 1; + auto nonKDim = getOpIdx() == 0 ? mDim : nDim; + int kGroups = waveSize / nonKDim; int64_t kDim = kWidth * kGroups; if (getOpIdx() == 0) - return {mDim, kDim}; + return {nonKDim, kDim}; else - return {kDim, nDim}; + return {kDim, nonKDim}; } SmallVector @@ -1902,6 +1906,18 @@ struct TritonGPUInferLayoutInterface // Verify that the encodings are valid. 
if (!aEncoding || !bEncoding) return op->emitError("mismatching encoding between A and B operands"); +#ifdef USE_ROCM + auto aParentEncoding = + aEncoding.getParent().dyn_cast_or_null(); + auto bParentEncoding = + bEncoding.getParent().dyn_cast_or_null(); + if (aParentEncoding != bParentEncoding) + return op->emitError( + "mismatching parent encoding between A and B operands"); + if (aParentEncoding != nullptr && + aParentEncoding.getMDim() != aParentEncoding.getNDim()) + return success(); +#endif // USE_ROCM if (aEncoding.getKWidth() != bEncoding.getKWidth()) return op->emitError("mismatching kWidth between A and B operands"); return success(); diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp index 12fdbf23e4a4..32303ca748bc 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp @@ -158,9 +158,8 @@ class BlockedToMFMA : public mlir::RewritePattern { /// @brief Choose MFMA instruction parameters /// @param dot target dot operation - /// @return pair {nonKDim, kDim} sizes of one MFMA instruction arguments - std::tuple - chooseMfmaDimensions(tt::DotOp dot) const { + /// @return selected mfma instruction + MfmaInsn chooseMfmaDimensions(tt::DotOp dot) const { // number of matrix elements along k dim per one MFMA intruction unsigned kDim = 0; auto opType = dot.getA().getType().cast(); @@ -175,8 +174,20 @@ class BlockedToMFMA : public mlir::RewritePattern { unsigned mDim = 0; unsigned nDim = 0; if (enforcedNonKDim != 0) { - mDim = enforcedNonKDim; - nDim = enforcedNonKDim; + if (enforcedNonKDim == 32 || enforcedNonKDim == 16 || + enforcedNonKDim == 4) { + mDim = enforcedNonKDim; + nDim = enforcedNonKDim; + } else if (enforcedNonKDim == 464) { + mDim = 4; + nDim = 64; + } else if (enforcedNonKDim == 644) { + mDim = 64; + nDim = 4; + } else { + llvm::report_fatal_error("Invalid MFMA nonKDim option, supported " + "values are: 32, 16, 4, 464, 644"); + } } else { int minSize = std::min(resShape[0], resShape[1]); if (minSize >= 32) { @@ -188,6 +199,8 @@ class BlockedToMFMA : public mlir::RewritePattern { nDim = 16; } if (minSize < 16) { + assert(opType.getShape()[1] >= 64 && + "k should be at least 64 to use this layout"); if (resShape[0] < 16 && resShape[1] >= 64) { mDim = 4; nDim = 64; @@ -195,8 +208,6 @@ class BlockedToMFMA : public mlir::RewritePattern { mDim = 64; nDim = 4; } else { - assert(opType.getShape()[1] >= 64 && - "k should be at least 64 to use this layout"); mDim = 4; nDim = 4; } @@ -215,7 +226,7 @@ class BlockedToMFMA : public mlir::RewritePattern { assert(mDim != 0 && nDim != 0); assert(resShape[0] % mDim == 0 && resShape[1] % nDim == 0); assert(opType.getShape()[1] % kDim == 0); - return {mDim, nDim, kDim}; + return maybeMfmaInsn.value(); } mlir::LogicalResult @@ -247,7 +258,10 @@ class BlockedToMFMA : public mlir::RewritePattern { ttg::MfmaEncodingAttr mfmaEnc; - auto [mDim, nDim, kDim] = chooseMfmaDimensions(dotOp); + auto instr = chooseMfmaDimensions(dotOp); + auto mDim = instr.getMDim(); + auto nDim = instr.getNDim(); + auto kDim = instr.getKDim(); auto warpsPerTile = warpsPerTileMFMA(dotOp, retShape, numWarps, {mDim, nDim}); @@ -278,33 +292,24 @@ class BlockedToMFMA : public mlir::RewritePattern { // kWidth is initialized as k_base, which is the number of elements hold by // one thread per mfma instruction - auto kWidth = -1; - // in mfma 32x32 case argument matrix groups elements in 2 groups - // in mfma 16x16 case 
argument matrix groups elements in 4 groups - // in mfma 4x4 case argument matrix groups in 16 groups - if (mDim == 32 && nDim == 32) - kWidth = kDim / 2; - if (mDim == 16 && nDim == 16) - kWidth = kDim / 4; - if (mDim == 4 && nDim == 4) - kWidth = kDim / 16; - if (mDim == 4 && nDim == 64 || mDim == 64 && nDim == 4) - kWidth = kDim; - assert(kWidth != -1); + auto kWidthA = instr.getKBaseA(); + auto kWidthB = instr.getKBaseB(); // We want to extend kWidth by kpack (kpack=1 means no extension) // to increase ds_read vector size // However, in FA, the second dot can only use kWidth = k_bse since it's // limited by the result of the first dot, which is of mfmaLayout. - if (!isSecondDot(dotOp)) - kWidth *= kpack; + if (!isSecondDot(dotOp)) { + kWidthA *= kpack; + kWidthB *= kpack; + } auto newAType = RankedTensorType::get( oldAType.getShape(), oldAType.getElementType(), - ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth)); + ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidthA)); auto newBType = RankedTensorType::get( oldBType.getShape(), oldBType.getElementType(), - ttg::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidth)); + ttg::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidthB)); a = rewriter.create(a.getLoc(), newAType, a); b = rewriter.create(b.getLoc(), newBType, b); auto newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 23f1befd2617..5a5046b1f8f0 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -669,173 +669,183 @@ using MfmaInsnGroupMap = llvm::DenseMap const MfmaInsnGroupMap & { static MfmaInsnGroupMap MfmaInsnMap{ + // MFMA tile description: + // M N K k_base_a k_base_b instr_name // f32 // mfma_f32_32x32x2f32 {{32, 32, MfmaTypeId::Fp32TyId, 1}, - {32, 32, 2, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}}, + {32, 32, 2, 1, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}}, {{32, 32, MfmaTypeId::Fp32TyId, 2}, - {32, 32, 2, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}}, + {32, 32, 2, 1, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}}, {{32, 32, MfmaTypeId::Fp32TyId, 3}, - {32, 32, 2, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}}, + {32, 32, 2, 1, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}}, // mfma_f32_16x16x4f32 {{16, 16, MfmaTypeId::Fp32TyId, 1}, - {16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}}, + {16, 16, 4, 1, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}}, {{16, 16, MfmaTypeId::Fp32TyId, 2}, - {16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}}, + {16, 16, 4, 1, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}}, {{16, 16, MfmaTypeId::Fp32TyId, 3}, - {16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}}, + {16, 16, 4, 1, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}}, // mfma_f32_4x4x1f32 {{4, 4, MfmaTypeId::Fp32TyId, 1}, - {4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {4, 4, 16, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, {{4, 4, MfmaTypeId::Fp32TyId, 2}, - {4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {4, 4, 16, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, {{4, 64, MfmaTypeId::Fp32TyId, 1}, - {4, 64, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {4, 64, 16, 1, 16, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, {{4, 64, MfmaTypeId::Fp32TyId, 2}, - {4, 64, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {4, 64, 16, 1, 16, 
ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, {{64, 4, MfmaTypeId::Fp32TyId, 1}, - {64, 4, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {64, 4, 16, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, {{64, 4, MfmaTypeId::Fp32TyId, 2}, - {64, 4, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {64, 4, 16, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, // mfma_f32_4x4x1_16B_f32 {{4, 4, MfmaTypeId::Fp32TyId, 3}, - {4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {4, 4, 16, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, {{4, 64, MfmaTypeId::Fp32TyId, 3}, - {4, 64, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {4, 64, 16, 1, 16, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, {{64, 4, MfmaTypeId::Fp32TyId, 3}, - {64, 4, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {64, 4, 16, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, // f16 // mfma_f32_32x32x8f16 {{32, 32, MfmaTypeId::Fp16TyId, 1}, - {32, 32, 8, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}}, + {32, 32, 8, 4, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}}, {{32, 32, MfmaTypeId::Fp16TyId, 2}, - {32, 32, 8, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}}, + {32, 32, 8, 4, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}}, {{32, 32, MfmaTypeId::Fp16TyId, 3}, - {32, 32, 8, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}}, + {32, 32, 8, 4, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}}, // mfma_f32_16x16x16xf16 {{16, 16, MfmaTypeId::Fp16TyId, 1}, - {16, 16, 16, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}}, + {16, 16, 16, 4, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}}, {{16, 16, MfmaTypeId::Fp16TyId, 2}, - {16, 16, 16, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}}, + {16, 16, 16, 4, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}}, {{16, 16, MfmaTypeId::Fp16TyId, 3}, - {16, 16, 16, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}}, + {16, 16, 16, 4, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}}, // mfma_f32_4x4x4f16 {{4, 4, MfmaTypeId::Fp16TyId, 1}, - {4, 4, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{4, 4, MfmaTypeId::Fp16TyId, 2}, - {4, 4, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{4, 4, MfmaTypeId::Fp16TyId, 3}, - {4, 4, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{4, 64, MfmaTypeId::Fp16TyId, 1}, - {4, 64, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{4, 64, MfmaTypeId::Fp16TyId, 2}, - {4, 64, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{4, 64, MfmaTypeId::Fp16TyId, 3}, - {4, 64, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{64, 4, MfmaTypeId::Fp16TyId, 1}, - {64, 4, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{64, 4, MfmaTypeId::Fp16TyId, 2}, - {64, 4, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{64, 4, MfmaTypeId::Fp16TyId, 3}, - {64, 4, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, // bf16 // 
mfma_f32_32x32x4_bf16 {{32, 32, MfmaTypeId::Bf16TyId, 1}, - {32, 32, 4, 2, ROCDL::mfma_f32_32x32x4bf16::getOperationName()}}, + {32, 32, 4, 2, 2, ROCDL::mfma_f32_32x32x4bf16::getOperationName()}}, // mfma_f32_32x32x8_bf16_1K {{32, 32, MfmaTypeId::Bf16TyId, 2}, - {32, 32, 8, 4, ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName()}}, + {32, 32, 8, 4, 4, ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName()}}, {{32, 32, MfmaTypeId::Bf16TyId, 3}, - {32, 32, 8, 4, ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName()}}, + {32, 32, 8, 4, 4, ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName()}}, // mfma_f32_16x16x8_bf16 {{16, 16, MfmaTypeId::Bf16TyId, 1}, - {16, 16, 8, 2, ROCDL::mfma_f32_16x16x8bf16::getOperationName()}}, + {16, 16, 8, 2, 2, ROCDL::mfma_f32_16x16x8bf16::getOperationName()}}, // mfma_f32_16x16x16_bf16_1K {{16, 16, MfmaTypeId::Bf16TyId, 2}, - {16, 16, 16, 4, ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName()}}, + {16, 16, 16, 4, 4, ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName()}}, {{16, 16, MfmaTypeId::Bf16TyId, 3}, - {16, 16, 16, 4, ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName()}}, + {16, 16, 16, 4, 4, ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName()}}, // mfma_f32_4x4x2_bf16 {{4, 4, MfmaTypeId::Bf16TyId, 1}, - {4, 4, 32, 2, ROCDL::mfma_f32_4x4x2bf16::getOperationName()}}, + {4, 4, 32, 2, 2, ROCDL::mfma_f32_4x4x2bf16::getOperationName()}}, {{4, 64, MfmaTypeId::Bf16TyId, 1}, - {4, 64, 2, 2, ROCDL::mfma_f32_4x4x2bf16::getOperationName()}}, + {4, 64, 32, 2, 32, ROCDL::mfma_f32_4x4x2bf16::getOperationName()}}, {{64, 4, MfmaTypeId::Bf16TyId, 1}, - {64, 4, 2, 2, ROCDL::mfma_f32_4x4x2bf16::getOperationName()}}, + {64, 4, 32, 32, 2, ROCDL::mfma_f32_4x4x2bf16::getOperationName()}}, // mfma_f32_4x4x4_bf16_1K {{4, 4, MfmaTypeId::Bf16TyId, 2}, - {4, 4, 64, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, {{4, 4, MfmaTypeId::Bf16TyId, 3}, - {4, 4, 64, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, {{4, 64, MfmaTypeId::Bf16TyId, 2}, - {4, 64, 4, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, {{4, 64, MfmaTypeId::Bf16TyId, 3}, - {4, 64, 4, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, {{64, 4, MfmaTypeId::Bf16TyId, 2}, - {64, 4, 4, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, {{64, 4, MfmaTypeId::Bf16TyId, 3}, - {64, 4, 4, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, // int8 // mfma_i32_32x32x8i8 {{32, 32, MfmaTypeId::I8TyId, 1}, - {32, 32, 8, 4, ROCDL::mfma_i32_32x32x8i8::getOperationName()}}, + {32, 32, 8, 4, 4, ROCDL::mfma_i32_32x32x8i8::getOperationName()}}, {{32, 32, MfmaTypeId::I8TyId, 2}, - {32, 32, 8, 4, ROCDL::mfma_i32_32x32x8i8::getOperationName()}}, + {32, 32, 8, 4, 4, ROCDL::mfma_i32_32x32x8i8::getOperationName()}}, // mfma_i32_32x32x16i8 {{32, 32, MfmaTypeId::I8TyId, 3}, - {32, 32, 16, 8, ROCDL::mfma_i32_32x32x16_i8::getOperationName()}}, + {32, 32, 16, 8, 8, ROCDL::mfma_i32_32x32x16_i8::getOperationName()}}, // mfma_i32_16x16x16i8 {{16, 16, MfmaTypeId::I8TyId, 1}, - {16, 16, 16, 4, ROCDL::mfma_i32_16x16x16i8::getOperationName()}}, + {16, 16, 16, 4, 4, ROCDL::mfma_i32_16x16x16i8::getOperationName()}}, 
{{16, 16, MfmaTypeId::I8TyId, 2}, - {16, 16, 16, 4, ROCDL::mfma_i32_16x16x16i8::getOperationName()}}, + {16, 16, 16, 4, 4, ROCDL::mfma_i32_16x16x16i8::getOperationName()}}, // mfma_i32_16x16x32i8 {{16, 16, MfmaTypeId::I8TyId, 3}, - {16, 16, 32, 8, ROCDL::mfma_i32_16x16x32_i8::getOperationName()}}, + {16, 16, 32, 8, 8, ROCDL::mfma_i32_16x16x32_i8::getOperationName()}}, // mfma_i32_4x4x4i8 {{4, 4, MfmaTypeId::I8TyId, 1}, - {4, 4, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{4, 4, MfmaTypeId::I8TyId, 2}, - {4, 4, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{4, 4, MfmaTypeId::I8TyId, 3}, - {4, 4, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{4, 64, MfmaTypeId::I8TyId, 1}, - {4, 64, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{4, 64, MfmaTypeId::I8TyId, 2}, - {4, 64, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{4, 64, MfmaTypeId::I8TyId, 3}, - {4, 64, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{64, 4, MfmaTypeId::I8TyId, 1}, - {64, 4, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{64, 4, MfmaTypeId::I8TyId, 2}, - {64, 4, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{64, 4, MfmaTypeId::I8TyId, 3}, - {64, 4, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, // fp8 * pf8 // mfma_f32_32x32x16_FP8_FP8 {{32, 32, MfmaTypeId::Fp8Fp8TyId, 3}, - {32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName()}}, + {32, 32, 16, 8, 8, + ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName()}}, // mfma_f32_16x16x32_FP8_FP8 {{16, 16, MfmaTypeId::Fp8Fp8TyId, 3}, - {16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName()}}, + {16, 16, 32, 8, 8, + ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName()}}, // mfma_f32_32x32x16_FP8_BF8 {{32, 32, MfmaTypeId::Fp8Bf8TyId, 3}, - {32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName()}}, + {32, 32, 16, 8, 8, + ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName()}}, // mfma_f32_16x16x32_FP8_BF8 {{16, 16, MfmaTypeId::Fp8Bf8TyId, 3}, - {16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName()}}, + {16, 16, 32, 8, 8, + ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName()}}, // mfma_f32_32x32x16_BF8_FP8 {{32, 32, MfmaTypeId::Bf8Fp8TyId, 3}, - {32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName()}}, + {32, 32, 16, 8, 8, + ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName()}}, // mfma_f32_16x16x32_BF8_FP8 {{16, 16, MfmaTypeId::Bf8Fp8TyId, 3}, - {16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName()}}, + {16, 16, 32, 8, 8, + ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName()}}, // mfma_f32_32x32x16_BF8_BF8 {{32, 32, MfmaTypeId::Bf8Bf8TyId, 3}, - {32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName()}}, + {32, 32, 16, 8, 8, + ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName()}}, // mfma_f32_16x16x32_BF8_BF8 {{16, 16, MfmaTypeId::Bf8Bf8TyId, 3}, - {16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName()}}}; + 
{16, 16, 32, 8, 8, + ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName()}}}; return MfmaInsnMap; }; @@ -859,6 +869,7 @@ unsigned MfmaInsn::getKDim() { return attr.k; } unsigned MfmaInsn::getMDim() { return attr.m; } unsigned MfmaInsn::getNDim() { return attr.n; } StringRef MfmaInsn::getInsnName() { return attr.insn; } -unsigned MfmaInsn::getKBase() { return attr.k_base;} +unsigned MfmaInsn::getKBaseA() { return attr.k_base_a; } +unsigned MfmaInsn::getKBaseB() { return attr.k_base_b; } } // namespace mlir diff --git a/python/06-attention-decode.py b/python/06-attention-decode.py new file mode 100644 index 000000000000..04d985405a58 --- /dev/null +++ b/python/06-attention-decode.py @@ -0,0 +1,853 @@ +from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple +import pytest +import torch +import sys + +import triton +import triton.language as tl + +def _strides(x: torch.Tensor, *stride_names: str): + assert x.ndim == len(stride_names) + return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)} + + +@triton.jit +def _fwd_kernel_splitK( + Q, + K, + V, + sm_scale, + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Seq_len, + stride_qz, + stride_qm, + stride_qg, + stride_qh, + stride_qk, + stride_kz, + stride_kn, + stride_kg, + stride_kh, + stride_kk, + stride_vz, + stride_vn, + stride_vg, + stride_vh, + stride_vk, + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + Z, + N_CTX_Q, + N_CTX_K, + BLOCK_N_PER_SPLIT, + H: tl.constexpr, + G: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + BOUNDS_CHECKS_N: tl.constexpr, + USE_SEQ_LEN: tl.constexpr, + PACKED_PER_VAL: tl.constexpr = 1, + N_GROUPS: tl.constexpr = 1, +): + """This kernel can accept non-quantized or int4-quantized keys/values. + PACKED_PER_VAL determines the quantization type: + - PACKED_PER_VAL == 1 means no quantization + - PACKED_PER_VAL == 8 means 4-bit quantization (8 packed quantized values inside one int32) + For the quantized case K/V should be int32 tensors. + Quantization can be row-wise (when N_GROUPS = 1) or group-wise with N_GROUPS = 2, 4, or 8. + Quantization coefficients are stored at the beginning of the row along the last dimension of K/V + So K[B, H, M, :] has a form + [ quant_coef0, quant_coef1, ...| + group0_quant_value0, group0_quant_value1,... | + group1_quant_value0, group1_quant_value1,...] + where each quant_coef is an int32 which should be interpreted as 2 packed float16: scale and offset. 
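+    For example, with head dim 128 and N_GROUPS == 1 a row of K holds
+    1 + 128 // 8 == 17 int32 values: one packed (scale, shift) coefficient
+    followed by 16 int32 words that carry the 128 int4 values.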
+ + """ + tl.static_assert( + (PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32)) + or (PACKED_PER_VAL == 8 and tl.constexpr(K.dtype.element_ty == tl.int32)), + f"Only 4-bit quantization is supported, K/V should have dtype int32 in " + f"the quantized case: {PACKED_PER_VAL=} {tl.constexpr(K.dtype)=} {tl.constexpr(K.dtype.element_ty)=}", + ) + tl.static_assert( + (((N_GROUPS == 1 or N_GROUPS == 2) or N_GROUPS == 4) or N_GROUPS == 8), + "Number of quantization groups can be 1 (row-wise quantization), 2, 4, or 8.", + ) + + QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1 + PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_GROUPS + D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_GROUPS + + start_m = tl.program_id(0) + off_zhg = tl.program_id(1) + off_z = off_zhg // (H * G) + off_h = (off_zhg // G) % H + off_g = off_zhg % G + splitk_idx = tl.program_id(2) + + lo = splitk_idx * BLOCK_N_PER_SPLIT + if USE_SEQ_LEN: + kv_len = tl.load(Seq_len + off_z) + else: + kv_len = N_CTX_K + hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len) + + Q_block_ptr = tl.make_block_ptr( + base=Q + off_h * stride_qh + off_z * stride_qz + off_g * stride_qg, + shape=(N_CTX_Q, D_PER_GROUP), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, D_PER_GROUP), + order=(1, 0), + ) + + k_base = K + off_h * stride_kh + off_z * stride_kz + off_g * stride_kg + # Additional shift by 1 along the last dimension in the quantized case, since + # the first element along that dim contains packed quantization coefficients. + K_block_ptr = tl.make_block_ptr( + base=k_base + stride_kk * QUANTIZED * N_GROUPS, + shape=(PACKED_D_PER_GROUP, hi), + strides=(stride_kk, stride_kn), + offsets=(0, lo), + block_shape=(PACKED_D_PER_GROUP, BLOCK_N), + order=(0, 1), + ) + v_base = V + off_h * stride_vh + off_z * stride_vz + off_g * stride_vg + V_block_ptr = tl.make_block_ptr( + base=v_base + stride_vk * QUANTIZED * N_GROUPS, + shape=(hi, PACKED_D_PER_GROUP), + strides=(stride_vn, stride_vk), + offsets=(lo, 0), + block_shape=(BLOCK_N, PACKED_D_PER_GROUP), + order=(1, 0), + ) + + if QUANTIZED: + # Pointers to quantization coefficients. 
Even though they are 1D,
+        # we have to use block pointers, since usual pointers
+        # don't support boundary checks
+        K_scale_shift_block_ptr = tl.make_block_ptr(
+            base=k_base,
+            shape=(1, hi),
+            strides=(stride_kk, stride_kn),
+            offsets=(0, lo),
+            block_shape=(1, BLOCK_N),
+            order=(0, 1),
+        )
+        V_scale_shift_block_ptr = tl.make_block_ptr(
+            base=v_base,
+            shape=(hi, 1),
+            strides=(stride_vn, stride_vk),
+            offsets=(lo, 0),
+            block_shape=(BLOCK_N, 1),
+            order=(1, 0),
+        )
+    else:
+        K_scale_shift_block_ptr = None
+        V_scale_shift_block_ptr = None
+
+    # initialize pointer to m and l
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+
+    acc = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=tl.float32)  # noqa: F821
+
+    # scale sm_scale by log_2(e) and use
+    # 2^x instead of exp in the loop because CSE and LICM
+    # don't work as expected with `exp` in the loop
+    qk_scale = sm_scale * 1.44269504
+    # load q: it will stay in SRAM throughout
+    # q: "VAR_ARGS_ARRAY"  # noqa: F821
+    # for i in range(elem_num):  # noqa: F821
+    q = tl.load(  # noqa: F821
+        tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0,)
+    )
+    q = (q * qk_scale).to(tl.float16)
+
+    # loop over k, v and update accumulator
+    for start_n in range(lo, hi, BLOCK_N):
+        # k: "VAR_ARGS_ARRAY"  # noqa: F821
+        # v: "VAR_ARGS_ARRAY"  # noqa: F821
+        # for i in range(len(acc)):  # noqa: F821
+        k, v = load_dequantize_k_v_group(  # noqa: F821
+            K_block_ptr,
+            V_block_ptr,
+            K_scale_shift_block_ptr,
+            V_scale_shift_block_ptr,
+            BOUNDS_CHECKS_N,
+            PACKED_PER_VAL,
+            PACKED_D_PER_GROUP,
+            Q.dtype.element_ty,
+            0,
+        )
+        # k.append(k_tmp)
+        # v.append(v_tmp)

+        # -- compute qk ---
+        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+        # for i in range(elem_num):  # noqa: F821
+        qk += tl.dot(q, k)  # noqa: F821
+        # qk += tl.dot(q, k)  # noqa: F821

+        # TODO: This is slow, and only needed at the last iteration.
+        # Maybe we can unroll the last iteration instead?
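+        # Keys past `hi` are masked to -inf, so after the exp2 below their
+        # softmax weight p is exactly zero and they drop out of l_i and acc.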
+ if BOUNDS_CHECKS_N: + qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + p = p.to(Q.dtype.element_ty) + + # -- scale and update acc -- + # for i in range(elem_num): # noqa: F821 + acc *= alpha[:, None] # noqa: F821 + #acc += tl.dot(p, v) # noqa: F821 + acc += tl.dot(p, v) # noqa: F821 + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + if PACKED_PER_VAL > 1: + K_scale_shift_block_ptr = tl.advance( + K_scale_shift_block_ptr, (0, BLOCK_N) + ) + V_scale_shift_block_ptr = tl.advance( + V_scale_shift_block_ptr, (BLOCK_N, 0) + ) + + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s, + shape=(N_CTX_Q, D_PER_GROUP), + strides=(stride_osk_m, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, D_PER_GROUP), + order=(1, 0), + ) + # for i in range(elem_num): # noqa: F821 + tl.store( + tl.advance(O_block_ptr, (0, 0)), + acc, # noqa: F821 + boundary_check=(0,), + ) + # Write metadata for split-K reduction + Metadata_ptr = ( + Metadata + + off_zhg * stride_mzhg + + splitk_idx * stride_ms + + start_m * BLOCK_M + + tl.arange(0, BLOCK_M) + ) + tl.store(Metadata_ptr, m_i) + tl.store(Metadata_ptr + stride_m2, l_i) + +@triton.jit +def load_dequantize_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N: tl.constexpr, + PACKED_PER_VAL: tl.constexpr, + PACKED_D_PER_GROUP: tl.constexpr, + dtype: tl.constexpr, + group_id: tl.constexpr, +): + #Load K/V for a given block. In case of int4-quantized K/V, + # dequantize them after loading. If quantization is group-wise, + # use group_id to advance the pointers to the current group. 
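+    # Note: with row-wise quantization (N_GROUPS == 1) the caller passes
+    # group_id == 0, so the advances below are no-ops.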
+ + # Advance to the current quantization group + K_block_ptr = tl.advance(K_block_ptr, (PACKED_D_PER_GROUP * group_id, 0)) + V_block_ptr = tl.advance(V_block_ptr, (0, PACKED_D_PER_GROUP * group_id)) + + # -- load k, v -- + k = tl.load(K_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ()) + v = tl.load(V_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ()) + + if PACKED_PER_VAL > 1: + # K/V are quantized, load quantization coefficients and dequantize + + K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (group_id, 0)) + V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (0, group_id)) + + k_scale_shift = tl.load( + K_scale_shift_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else () + ) + v_scale_shift = tl.load( + V_scale_shift_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else () + ) + + k_scale, k_shift = cast_uint32_to_half2(k_scale_shift) + v_scale, v_shift = cast_uint32_to_half2(v_scale_shift) + v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL).to(dtype) + k_t = dequantize( + tl.trans(k), + tl.trans(k_scale), + tl.trans(k_shift), + PACKED_PER_VAL, + ).to(dtype) + k = tl.trans(k_t) + return k, v + +@triton.jit +def cast_uint32_to_half2(scale_shift): + # Extract two float16 packed into one int32 + scale = scale_shift & 0xFFFF + shift = scale_shift >> 16 + scale = scale.to(tl.uint16).to(tl.float16, bitcast=True) + shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) + return scale, shift + +@triton.jit +def dequantize( + x_, + scale, + shift, + PACKED_PER_VAL: tl.constexpr = 8, +): + # PACKED_PER_VAL is the number of values packed into + # each element x_. For example, for int4 quantization + #and x_ of type int32, PACKED_PER_VAL is 8. + + BLOCK_N: tl.constexpr = x_.shape[0] + BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] + offsets = tl.arange(0, PACKED_PER_VAL) * 4 + quant_offset = ( + x_[:, None, :] >> offsets[None, :, None] + ) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) + + quant_offset = tl.view( + quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL) + ) + # Trick - instead of converting int4 to float16 we view it as float16 + # and then multiply by 32768 * 512 == 2**24 + quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) + quant_offset = (quant_offset * 32768.0).to(tl.float16) + scale_512 = scale * 512 + + dequant = quant_offset * scale_512 + shift + return dequant + +@triton.jit +def _splitK_reduce( + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Out, # [B, H, M, K] + LSE, # [B, H, M] + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + stride_oz, + stride_oh, + stride_og, + stride_om, + stride_ok, + stride_lse_zhg, + stride_lse_m, + M_ceil:tl.constexpr, + BLOCK_SIZE: tl.constexpr, + H: tl.constexpr, + G: tl.constexpr, + split_k:tl.constexpr, + splitK_pow2:tl.constexpr, + use_mask:tl.constexpr, +): + off_zhg = tl.program_id(0) + off_z = off_zhg // (H * G) + off_h = (off_zhg // G) % H + off_g = off_zhg % G + off_m = tl.program_id(1) + off_k = tl.program_id(2) + + # read chunk + spk_idx = tl.arange(0, splitK_pow2) + kidx = tl.arange(0, BLOCK_SIZE) + + Metadata_ptr = ( + Metadata + + stride_mzhg * off_zhg + + spk_idx * stride_ms + + off_m * stride_mm + ) + + o_ptr = ( + Out_splitK + + off_zhg * stride_osk_zhg + + stride_osk_m * off_m + + off_k * BLOCK_SIZE + + stride_osk_s * spk_idx[:, None] + + kidx[None, :] * stride_osk_k + ) + + # read max values of each splitK + if 
use_mask: + spk_mask = spk_idx < split_k + l_m = tl.load(Metadata_ptr, mask=spk_mask, other=float("-inf")) + l_sum = tl.load(Metadata_ptr + stride_m2, mask=spk_mask, other=0.0) + acc = tl.load(o_ptr, mask=spk_mask[:,None], other=0.0) + else: + l_m = tl.load(Metadata_ptr) + l_sum = tl.load(Metadata_ptr + stride_m2) + acc = tl.load(o_ptr) + + g_m = tl.max(l_m, axis=0) + alpha = tl.math.exp2(l_m - g_m) + + # read sum + l_sum *= alpha + g_sum = tl.sum(l_sum, axis=0) + + alpha = tl.math.exp2(l_m - g_m) + acc = acc * alpha[:, None] + acc_out = tl.sum(acc, axis=0) / g_sum + Out_ptr = ( + Out + + stride_oz * off_z + + stride_oh * off_h + + stride_og * off_g + + stride_om * off_m + + off_k * BLOCK_SIZE + + tl.arange(0, BLOCK_SIZE) + ) + tl.store(Out_ptr, acc_out) + l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m + tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / 1.44269504) + + +def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + # Scale and shift are such that quantization linearly maps + # int4 values range [0..15] to input values range min(k)..max(k) + # individually for every row + k = k.reshape(*k.shape[:-1], num_groups, k.shape[-1] // num_groups) + max_vals = torch.max(k, dim=-1, keepdim=True).values + min_vals = torch.min(k, dim=-1, keepdim=True).values + scale_k: torch.Tensor = (max_vals - min_vals) / 15 + + shift_k = torch.min(k, dim=-1, keepdim=True).values + scale_k = scale_k.to(torch.float16) + shift_k = shift_k.to(torch.float16) + + in_bytes = ((k - shift_k.expand(k.shape)) / scale_k.expand(k.shape)) + 0.5 + in_bytes = in_bytes.to(torch.uint8) + in_int4 = in_bytes & 0xF + in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4) + scale_shift = torch.concat( + [scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1 + ) + k_quant = torch.concat( + [ + scale_shift.flatten(start_dim=-2), + in_int4_packed.flatten(start_dim=-2), + ], + dim=-1, + ).view(torch.int16) + return k_quant + + +def dequantize_kv_fp16(quant_k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + k_i16 = quant_k.view(torch.int16) + k_ui8 = k_i16.view(torch.uint8) + + ss_size = num_groups * 4 + scale_shift_ui8 = k_ui8[...,0:ss_size] + scale_shift_ui8 = scale_shift_ui8.reshape(*scale_shift_ui8.shape[:-1], num_groups, 4) + scale = scale_shift_ui8[...,0:2].view(torch.float16) + shift = scale_shift_ui8[...,2:4].view(torch.float16) + + kv_ui8 = k_ui8[...,ss_size:] + k_ui8 = kv_ui8.reshape(*kv_ui8.shape[:-1], num_groups, -1) + k1_i4 = k_ui8 & 0xF + k2_i4 = (k_ui8 & 0xF0) >> 4 + k_shape = k1_i4.shape + k1_f16 = k1_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) + k2_f16 = k2_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) + + out = torch.empty((*k1_f16.shape[:-1], k1_f16.shape[-1] * 2), dtype=torch.float16, device=quant_k.device) + out[...,::2] = k1_f16 + out[...,1::2] = k2_f16 + out = out.reshape(*k_shape[:-2], -1) + + return out + + +def get_split_k(B: int, G: int, H: int, Mk: int) -> int: + """Heuristic for the number of splits""" + bh = max(B * H, 1) # NOTE: Handle B*h=0 case + split_k = max(Mk, 1024) // bh + max_chunk_size = 64 + while split_k > 0 and Mk / split_k < max_chunk_size: + split_k = split_k // 2 + while B * H * G * split_k >= 1024: + split_k = split_k // 2 + split_k = min(split_k, 512) + split_k = max(split_k, 1) + return split_k + + +class _attention(torch.autograd.Function): + + OPERATOR = _fwd_kernel_splitK + SUPPORTED_DEVICES = {"cuda"} + CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) + SUPPORTED_DTYPES = { + torch.half, + 
torch.bfloat16, + } # Those are dtypes of Q. In the quantized case K/V has dtype int32 + SUPPORTED_MAX_K = 128 + # SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = { + # type(None), + # BlockDiagonalCausalWithOffsetPaddedKeysMask, + # } + SUPPORTS_DROPOUT = False + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_BMGHK = True + NAME = "triton_splitKF" + + @staticmethod + def forward(cls, q, k, v, scale_float): + + cls.SPLIT_K: Optional[int] = None + cls.BLOCK_M = 16 + cls.BLOCK_N = 64 + + cls.NUM_GROUPS = 1 # Default quantization is row-wise + + # attn_bias = inp.attn_bias + seq_len = None + + # Transpose in the case of MQA/GQA + mqa_swap_seqlen_head = False + if k.shape[3] > 1 and k.stride(3) == 0 and v.stride(3) == 0: + mqa_swap_seqlen_head = True + assert q.shape[1] == 1 + q = q.transpose(1, 3) + k = k[:, :, :, :1] + v = v[:, :, :, :1] + + if k.dtype == torch.int32: + # Quantized K/V + PACKED_PER_VAL = 8 + Lk = (k.shape[-1] - cls.NUM_GROUPS) * 8 + else: + Lk = k.shape[-1] + PACKED_PER_VAL = 1 + + B, Mk, G, H, Kkv = k.shape + B, M, G, H, Kq = q.shape + assert Lk == Kq, f"Keys have head dim {Lk} but queries have head dim {Kq}" + # print(f"B = {B}, M = {M}, G = {G}, H = {H}, Kkv = {Kkv}, Kq = {Kq}") + + BLOCK_M = cls.BLOCK_M + BLOCK_N = cls.BLOCK_N + if cls.SPLIT_K is not None: + split_k = cls.SPLIT_K + else: + # Use heuristics + split_k = get_split_k(B, G, H, Mk) + + M_ceil = (M + BLOCK_M - 1) // BLOCK_M * BLOCK_M + o_splitk = torch.empty( + [B * G * H, split_k, M_ceil, Kq], dtype=torch.float32, device=q.device + ) + metadata = torch.empty( + [B * G * H, 2, split_k, M_ceil], dtype=torch.float32, device=q.device + ) + lse = torch.empty((B * G * H, M), device=q.device, dtype=torch.float32) + grid = (triton.cdiv(M, BLOCK_M), B * G * H, split_k) + + split_size = (Mk + split_k - 1) // split_k + use_seq_len = seq_len is not None + + #print(f"B = {B}, G = {G}, H = {H}, split_k = {split_k}, M_ceil = {M_ceil}, Kq = {Kq}, num_of_wgs = {G * G * H * split_k}") + #print(grid) + #print(_strides(k, "kz", "kn", "kg", "kh", "kk")) + #print("BLOCK_N", BLOCK_N) + + pgm = _fwd_kernel_splitK[grid]( + Q=q, + K=k, + V=v, + sm_scale=scale_float, + Out_splitK=o_splitk, + Metadata=metadata, + Seq_len=seq_len, + **_strides(q, "qz", "qm", "qg", "qh", "qk"), + **_strides(k, "kz", "kn", "kg", "kh", "kk"), + **_strides(v, "vz", "vn", "vg", "vh", "vk"), + **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + Z=B, + H=H, + G=G, + N_CTX_Q=M, + N_CTX_K=Mk, + BLOCK_N_PER_SPLIT=split_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=Lk, + BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_seq_len, + USE_SEQ_LEN=use_seq_len, + num_warps=4, + num_stages=1, + PACKED_PER_VAL=PACKED_PER_VAL, + N_GROUPS=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1, + matrix_instr_nonkdim=464, + waves_per_eu=1 + ) + #print(f"kernel run B = {B}, G = {G}, H = {H}, split_k = {split_k}, M_ceil = {M_ceil}, Kq = {Kq}\n", pgm.asm["amdgcn"]) + + if mqa_swap_seqlen_head: + out = torch.empty( + (B, H, G, M, Kq), device=q.device, dtype=q.dtype + ).transpose(1, 3) + else: + out = torch.empty((B, M, G, H, Kq), device=q.device, dtype=q.dtype) + + # Merge together + splitK_pow2 = triton.next_power_of_2(split_k) + use_mask = splitK_pow2 > split_k + if B * G * H * M >= 512: + k_block_num = 1 + else: + k_block_num = 2 + assert out.shape[-1] % k_block_num == 0 + k_block_size = out.shape[-1] // k_block_num + grid = (B * G * H, M, k_block_num) + #print("reduce split", split_k, k_block_size, k_block_num) + 
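+        # One program instance per (batch * group * head, query row, k-block);
+        # each instance merges the split_k partial sums using the (m_i, l_i)
+        # metadata written by the split-K kernel above.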
_splitK_reduce[grid]( + o_splitk, + metadata, + out, + lse, + **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + **_strides(out, "oz", "om", "og", "oh", "ok"), + **_strides(lse, "lse_zhg", "lse_m"), + M_ceil=M_ceil, + BLOCK_SIZE=k_block_size, + G=G, + H=H, + # TODO: Tune num_warps + split_k=split_k, + splitK_pow2=splitK_pow2, + use_mask=use_mask, + num_warps=4 + ) + + lse = lse.reshape([B, G, H, M]) + if mqa_swap_seqlen_head: + # H/M dimensions have been swapped + out = out.transpose(1, 3) + lse = lse.transpose(2, 3) + if q.ndim == 4: + # BMGHK -> BMHK + assert G == 1 + out = out[:, :, 0] + lse = lse[:, 0] + if Mk == 0: + out.zero_() + if mqa_swap_seqlen_head: + out = out.reshape(B, -1, M * G, Kq).transpose(1, 2).contiguous() + else: + out = out.reshape(B, H * G, -1, Kq).contiguous() + + return out + +attention = _attention.apply + +def get_input_shapes(): +# cases = [ +# (max(1, 2 ** (16 - i)), 1, 2**i, 16, 1, 128) +# for i in range(13, 14) +# ] +# return cases + cases = [ + (max(1, 2 ** (16 - i)), 1, 2**i, 16, 1, 128) + for i in range(8, 14) + ] + [ + (max(1, 2 ** (16 - i)), 1, 2**i, 16, 2, 128) + for i in range(8, 14) + ] + cases += [(4, 1, 8192, 16, 1, 128)] + + return cases + + +@pytest.mark.parametrize('B, Mq, Mkv, Hq, Hkv, K', + get_input_shapes()) +def test_op_fwd(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16): + torch.manual_seed(20) + q = ( + torch.empty((B, Mq, Hkv, (Hq + Hkv - 1) // Hkv, K), dtype=dtype, device="cuda") + .normal_(mean=0., std=0.5) + .requires_grad_() + ) + k = ( + torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, device="cuda") + .normal_(mean=0., std=0.5) + .requires_grad_() + ).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + v = ( + torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, device="cuda") + .normal_(mean=0., std=0.5) + .requires_grad_() + ).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + scale = 1 / K**0.5 + tri_out = attention(q, k, v, scale) + + q = q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3) + k = k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + v = v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + attn = (q @ k.transpose(-1, -2) * scale).softmax(-1) + ref_out = attn @ v + + # compare + torch.testing.assert_close(ref_out, tri_out, atol=1e-3, rtol=0.01) + + +@pytest.mark.parametrize('B, Mq, Mkv, Hq, Hkv, K', + get_input_shapes()) +def test_op_fwd_int4_kv(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16): + torch.manual_seed(2) + q = ( + torch.empty((B, Mq, Hkv, (Hq + Hkv - 1) // Hkv, K), dtype=dtype, device="cuda") + .normal_(mean=1.0, std=0.5) + .requires_grad_() + ) + k = ( + torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, device="cuda") + .normal_(mean=1.0, std=0.5) + .requires_grad_() + ).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + v = ( + torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, device="cuda") + .normal_(mean=1.0, std=0.5) + .requires_grad_() + ).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + + num_groups = 1 + quant_k = ( + quantize_kv_int4(k, num_groups=num_groups) + .contiguous() + .view(torch.int32) + ) + quant_v = ( + quantize_kv_int4(v, num_groups=num_groups) + .contiguous() + .view(torch.int32) + ) + scale = 1 / K**0.5 + tri_out = attention(q, quant_k, quant_v, scale) + + q = q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3) + k = k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + v = v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + attn = (q @ k.transpose(-1, -2) * scale).softmax(-1) + ref_out = attn @ v + # compare + torch.testing.assert_close(ref_out, tri_out, atol=2.1e-2, rtol=0) + + # 
+
+
+@pytest.mark.parametrize('B, Mq, Mkv, Hq, Hkv, K',
+                         get_input_shapes())
+def test_op_fwd(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16):
+    torch.manual_seed(20)
+    q = (
+        torch.empty((B, Mq, Hkv, (Hq + Hkv - 1) // Hkv, K), dtype=dtype, device="cuda")
+        .normal_(mean=0., std=0.5)
+        .requires_grad_()
+    )
+    k = (
+        torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, device="cuda")
+        .normal_(mean=0., std=0.5)
+        .requires_grad_()
+    ).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1)
+    v = (
+        torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, device="cuda")
+        .normal_(mean=0., std=0.5)
+        .requires_grad_()
+    ).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1)
+    scale = 1 / K**0.5
+    tri_out = attention(q, k, v, scale)
+
+    q = q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3)
+    k = k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
+    v = v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
+    attn = (q @ k.transpose(-1, -2) * scale).softmax(-1)
+    ref_out = attn @ v
+
+    # compare
+    torch.testing.assert_close(ref_out, tri_out, atol=1e-3, rtol=0.01)
+
+
+@pytest.mark.parametrize('B, Mq, Mkv, Hq, Hkv, K',
+                         get_input_shapes())
+def test_op_fwd_int4_kv(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16):
+    torch.manual_seed(2)
+    q = (
+        torch.empty((B, Mq, Hkv, (Hq + Hkv - 1) // Hkv, K), dtype=dtype, device="cuda")
+        .normal_(mean=1.0, std=0.5)
+        .requires_grad_()
+    )
+    k = (
+        torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, device="cuda")
+        .normal_(mean=1.0, std=0.5)
+        .requires_grad_()
+    ).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1)
+    v = (
+        torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, device="cuda")
+        .normal_(mean=1.0, std=0.5)
+        .requires_grad_()
+    ).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1)
+
+    num_groups = 1
+    quant_k = (
+        quantize_kv_int4(k, num_groups=num_groups)
+        .contiguous()
+        .view(torch.int32)
+    )
+    quant_v = (
+        quantize_kv_int4(v, num_groups=num_groups)
+        .contiguous()
+        .view(torch.int32)
+    )
+    scale = 1 / K**0.5
+    tri_out = attention(q, quant_k, quant_v, scale)
+
+    q = q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3)
+    k = k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
+    v = v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
+    attn = (q @ k.transpose(-1, -2) * scale).softmax(-1)
+    ref_out = attn @ v
+    # compare
+    torch.testing.assert_close(ref_out, tri_out, atol=2.1e-2, rtol=0)
+
+    # since quantization introduces rounding error, use the
+    # dequantized kv as inputs to the ref implementation to reduce
+    # the tolerance to 1e-3
+    dqk = dequantize_kv_fp16(quant_k, num_groups=num_groups)
+    dqv = dequantize_kv_fp16(quant_v, num_groups=num_groups)
+    dqk = dqk.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
+    dqv = dqv.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
+    dq_attn = (q @ dqk.transpose(-1, -2) * scale).softmax(-1)
+    dq_ref_out = dq_attn @ dqv
+    torch.testing.assert_close(dq_ref_out, tri_out, atol=1e-3, rtol=0)
+
+
+def test_quantization():
+    a = torch.randn((2, 4, 32), dtype=torch.float16, device='cuda')
+    qa = quantize_kv_int4(a, num_groups=4)
+    dqa = dequantize_kv_fp16(qa, num_groups=4)
+    torch.testing.assert_close(a, dqa, atol=1.5e-1, rtol=1e-1)
+
+
+try:
+    from flash_attn.flash_attn_interface import \
+        flash_attn_qkvpacked_func as flash_attn_func
+    FLASH_VER = 2
+except BaseException:
+    try:
+        from flash_attn.flash_attn_interface import flash_attn_func
+        FLASH_VER = 1
+    except BaseException:
+        FLASH_VER = None
+HAS_FLASH = FLASH_VER is not None
+
+configs = []
+for mode in ['fwd']:
+    # for D_HEAD in [128]:
+    for causal in [False]:
+        configs.append(triton.testing.Benchmark(
+            x_names=['B', 'Mq', 'Mkv', 'Hq', 'Hkv', 'K'],
+            x_vals=get_input_shapes(),
+            line_arg='provider',
+            line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
+            line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
+            styles=[('red', '-'), ('blue', '-')],
+            ylabel='ms',
+            plot_name=f'fused-attention-d{128}-{mode}-causal={causal}',
+            args={
+                # 'D_HEAD': D_HEAD,
+                'dtype': torch.float16,
+                'mode': mode,
+                'causal': causal})
+        )
+
+
+@triton.testing.perf_report(configs)
+def bench_flash_attention(B, Mq, Mkv, Hq, Hkv, K, causal, mode, provider, dtype=torch.float16, device="cuda"):
+    assert mode in ['fwd', 'bwd']
+    warmup = 100
+    rep = 400
+    ms = 0
+    if provider == "triton":
+        q = torch.randn(
+            [B, Mq, Hkv, Hq // Hkv, K], device="cuda", dtype=dtype, requires_grad=False
+        )
+        k = torch.randn(
+            [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=False
+        ).expand(-1, -1, -1, Hq // Hkv, -1)
+        v = torch.randn(
+            [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=False
+        ).expand(-1, -1, -1, Hq // Hkv, -1)
+
+        sm_scale = 1.3
+        fn = lambda: attention(q, k, v, sm_scale)
+        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+
+    flops_per_matmul = 2 * B * Hq * (Mq * K * Mkv + Mq * Mkv * K)
+    total_flops = 2 * flops_per_matmul
+    totalBytes = ((B * Mkv * Hkv * K * 2) + (B * Mq * Hq * K) + (B * Mq * Hq * K)) * 2
+
+    # return totalBytes / ms * 1e-9
+    return ms * 1000
+
+
+def main():
+    bench_flash_attention.run(save_path='.', print_data=True)
+
+
+if __name__ == '__main__':
+    sys.exit(main())
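Before the test_core_amd.py hunks below, it is worth spelling out the layout trick the tests and the benchmark above share: K and V are materialized once per KV head with a singleton group dimension, expand() then broadcasts them across the Hq // Hkv query heads of each group without copying memory, and the reference implementation flattens the result to BHMK via reshape + permute. A standalone sketch of just the shape manipulation (variable names are illustrative):

import torch

B, Mkv, Hq, Hkv, K = 2, 16, 16, 2, 128
kv = torch.randn(B, Mkv, Hkv, 1, K)
# Each of the Hkv key/value heads is shared by Hq // Hkv query heads;
# expand() broadcasts the singleton group dim without allocating memory.
kv = kv.expand(-1, -1, -1, Hq // Hkv, -1)             # (B, Mkv, Hkv, Hq//Hkv, K)
flat = kv.reshape(B, Mkv, -1, K).permute(0, 2, 1, 3)  # (B, Hq, Mkv, K)
assert flat.shape == (B, Hq, Mkv, K)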
diff --git a/python/test/unit/language/test_core_amd.py b/python/test/unit/language/test_core_amd.py
index 2bf5c63dd613..f24bf301c6d0 100644
--- a/python/test/unit/language/test_core_amd.py
+++ b/python/test/unit/language/test_core_amd.py
@@ -1665,6 +1665,19 @@ def kernel(X, stride_xm, stride_xn,
                              for non_k_dim in [0, 4, 16, 32]
                              if not (allow_tf32 and (in_dtype in ['float16']))] +
+                            [(*shape, warps, False, False, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim, 1)
+                             for shape in [(64, 64, 128), (16, 64, 128)]
+                             for warps in [1, 4]
+                             for epilogue in ['none', 'trans', 'add-matrix', 'chain-dot', 'softmax']
+                             for allow_tf32 in [False]
+                             for in_dtype, out_dtype in [('float16', 'float16'),
+                                                         ('bfloat16', 'float32'),
+                                                         ('float8e5m2fnuz', 'float32'),
+                                                         ('float8e4m3fnuz', 'float32'),
+                                                         ('float16', 'float32'),
+                                                         ('float32', 'float32')]
+                             for non_k_dim in [464, 644]] +
+
                             [(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype, non_k_dim, kpack)
                              for shape_nw in [[128, 128, 32, 2],
                                               [128, 16, 32, 4],
@@ -1693,8 +1706,9 @@ def kernel(X, stride_xm, stride_xn,
                                               [4, 32, 64, 4],
                                               [32, 4, 64, 2],
                                               [16, 4, 64, 8],
-                                              [64, 4, 16, 1],
-                                              [4, 64, 16, 1],
+                                              [64, 4, 64, 1],
+                                              [4, 64, 64, 1],
+                                              [4, 64, 64, 4],
                                               ]
                              for allow_tf32 in [False, True]
                              for col_a in [True, False]
@@ -1728,6 +1742,9 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
         pytest.skip("incompatible non_k_dim == 4 with K size")
     if non_k_dim == 4 and (M > 16 or N > 16):
         pytest.skip("skipping large matrices for non_k_dim == 4 to speedup testing")
+    if (non_k_dim == 464 and N < 64) or (non_k_dim == 644 and M < 64):
+        pytest.skip(f"skipping non_k_dim={non_k_dim} specific test with incompatible matrix sizes")
+
     if capability[0] < 7:
         pytest.skip("Only test tl.dot() on devices with sm >= 70")
@@ -1852,7 +1869,7 @@ def kernel(X, stride_xm, stride_xk,
     z_tri = to_triton(z, device=device)
     if epilogue == 'trans':
-        z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
+        z_tri = torch.as_strided(z_tri, (M, N), [1, M])
     if out_dtype == 'int8':
         out_dtype = tl.int8
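One note on the hunks above before the generator script that follows: 464 and 644 are sentinel values for non_k_dim rather than literal tile sizes. Reading the digits as the rectangular 4x64 and 64x4 MFMA tiles matches the added skip conditions (464 requires N >= 64, 644 requires M >= 64); this reading is our inference from the skips, not stated in the patch. Relatedly, the script below multiplies the base k_width by 16 on whichever operand spans a 64-wide non-K dimension, which is why the regenerated CHECK lines jump from kWidth = 4 to 64 (f16/i8), 2 to 32 (bf16 on CDNA1), and 1 to 16 (f32). A tiny sketch of the sentinel decoding, purely illustrative:

# Illustrative decoding of the non_k_dim sentinels used above; the mapping
# is an assumption inferred from the added pytest.skip conditions.
def decode_non_k_dim(non_k_dim: int) -> tuple:
    """Return the (M, N) instruction tile implied by a non_k_dim value."""
    special = {464: (4, 64), 644: (64, 4)}
    if non_k_dim in special:
        return special[non_k_dim]
    return (non_k_dim, non_k_dim)  # 0/4/16/32 request a square tile (0 = auto)

assert decode_non_k_dim(464) == (4, 64)
assert decode_non_k_dim(644) == (64, 4)
assert decode_non_k_dim(32) == (32, 32)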
diff --git a/scripts/amd/lit_tests/generate_accelerate_matmul_tests.py b/scripts/amd/lit_tests/generate_accelerate_matmul_tests.py
new file mode 100755
index 000000000000..5bcc266634d1
--- /dev/null
+++ b/scripts/amd/lit_tests/generate_accelerate_matmul_tests.py
@@ -0,0 +1,182 @@
+import argparse
+import sys
+
+# M N K a_ty b_ty c_ty
+configs = [[32, 32, 32, "f16", "f16", "f32"],
+           [32, 32, 32, "bf16", "bf16", "f32"],
+           [32, 32, 32, "f32", "f32", "f32"],
+           [32, 32, 32, "i8", "i8", "i32"],
+           [32, 32, 32, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"],
+           [32, 32, 32, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"],
+           [32, 32, 32, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"],
+           [32, 32, 32, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"],
+
+           [16, 16, 32, "f16", "f16", "f32"],
+           [16, 16, 32, "bf16", "bf16", "f32"],
+           [16, 16, 32, "f32", "f32", "f32"],
+           [16, 16, 32, "i8", "i8", "i32"],
+           [16, 16, 32, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"],
+           [16, 16, 32, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"],
+           [16, 16, 32, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"],
+           [16, 16, 32, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"],
+
+           [4, 4, 64, "f16", "f16", "f32"],
+           [4, 4, 64, "bf16", "bf16", "f32"],
+           [4, 4, 64, "f32", "f32", "f32"],
+           [4, 4, 64, "i8", "i8", "i32"],
+           [4, 4, 64, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"],
+           [4, 4, 64, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"],
+           [4, 4, 64, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"],
+           [4, 4, 64, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"],
+
+           [64, 4, 64, "f16", "f16", "f32"],
+           [64, 4, 64, "bf16", "bf16", "f32"],
+           [64, 4, 64, "f32", "f32", "f32"],
+           [64, 4, 64, "i8", "i8", "i32"],
+           [64, 4, 64, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"],
+           [64, 4, 64, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"],
+           [64, 4, 64, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"],
+           [64, 4, 64, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"],
+
+           [4, 64, 64, "f16", "f16", "f32"],
+           [4, 64, 64, "bf16", "bf16", "f32"],
+           [4, 64, 64, "f32", "f32", "f32"],
+           [4, 64, 64, "i8", "i8", "i32"],
+           [4, 64, 64, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"],
+           [4, 64, 64, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"],
+           [4, 64, 64, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"],
+           [4, 64, 64, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"]
+           ]
+
+
+def generate(cdna_version, output_file):
+    arch_names = {0: "", 1: "gfx908", 2: "gfx90a", 3: "gfx940"}
+    arch_name = arch_names[cdna_version]
+    print(f"// This file is generated: $ python3 {' '.join(sys.argv)}", file=output_file)
+    print(f"// RUN: (! triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul=arch-generation-name={arch_name} --mlir-pass-pipeline-crash-reproducer=%t 2>/dev/null) | FileCheck --check-prefixes=CHECK %s", file=output_file)
+
+    for cfg_id in range(len(configs)):
+        cfg = configs[cfg_id]
+
+        cfg_name = "_".join([str(item) for item in cfg])
+
+        M, N, K, a_ty, b_ty, c_ty = cfg
+        if "i" in c_ty:
+            cst_val = "0"
+        else:
+            cst_val = "0.000000e+00"
+
+        supported = True
+        if cdna_version < 3 and ("f8" in a_ty or "f8" in b_ty):
+            supported = False
+
+        if M >= 32 and N >= 32:
+            m_dim = 32
+            n_dim = 32
+        elif M >= 16 and N >= 16:
+            m_dim = 16
+            n_dim = 16
+        elif M >= 64 and N < 16:
+            m_dim = 64
+            n_dim = 4
+        elif M < 16 and N >= 64:
+            m_dim = 4
+            n_dim = 64
+        elif M < 16 and N < 16:
+            m_dim = 4
+            n_dim = 4
+        if ("f8" in a_ty or "f8" in b_ty) and min(m_dim, n_dim) == 4:
+            supported = False
+
+        if cdna_version == 1:
+            if a_ty == "f16":
+                k_width0 = 4
+                k_width1 = 4
+            if a_ty == "bf16":
+                k_width0 = 2
+                k_width1 = 2
+            if a_ty == "i8":
+                k_width0 = 4
+                k_width1 = 4
+            if a_ty == "f32":
+                k_width0 = 1
+                k_width1 = 1
+        if cdna_version == 2:
+            if a_ty == "f16":
+                k_width0 = 4
+                k_width1 = 4
+            if a_ty == "bf16":
+                k_width0 = 4
+                k_width1 = 4
+            if a_ty == "i8":
+                k_width0 = 4
+                k_width1 = 4
+            if a_ty == "f32":
+                k_width0 = 1
+                k_width1 = 1
+        if cdna_version == 3:
+            if "f8" in a_ty:
+                k_width0 = 8
+                k_width1 = 8
+            if a_ty == "f16":
+                k_width0 = 4
+                k_width1 = 4
+            if a_ty == "bf16":
+                k_width0 = 4
+                k_width1 = 4
+            if a_ty == "i8":
+                if min(m_dim, n_dim) == 4:
+                    k_width0 = 4
+                    k_width1 = 4
+                else:
+                    k_width0 = 8
+                    k_width1 = 8
+            if a_ty == "f32":
+                k_width0 = 1
+                k_width1 = 1
+        if m_dim == 64:
+            k_width0 *= 16
+        if n_dim == 64:
+            k_width1 *= 16
+
+        if supported:
+            mfma_check = f"// CHECK: #mfma = #triton_gpu.mfma<{{versionMajor = {cdna_version}, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [{m_dim}, {n_dim}], isTransposed = false}}>"
+            label_check = f"// CHECK: convert_dot_{cfg_name}"
+            checks = f"""// CHECK: triton_gpu.convert_layout {{{{.*}}}} : (tensor<{{{{.*}}}}, #blocked>) -> tensor<{{{{.*}}}}, #mfma>
+// CHECK: triton_gpu.convert_layout {{{{.*}}}} : (tensor<{{{{.*}}}}, #triton_gpu.dot_op<{{opIdx = 0, parent = #blocked}}>>) -> tensor<{{{{.*}}}}, #triton_gpu.dot_op<{{opIdx = 0, parent = #mfma, kWidth = {k_width0}}}>>
+// CHECK: triton_gpu.convert_layout {{{{.*}}}} : (tensor<{{{{.*}}}}, #triton_gpu.dot_op<{{opIdx = 1, parent = #blocked}}>>) -> tensor<{{{{.*}}}}, #triton_gpu.dot_op<{{opIdx = 1, parent = #mfma, kWidth = {k_width1}}}>>"""
+        else:
+            mfma_check = ""
+            label_check = f"// CHECK-NOT: convert_dot_{cfg_name}"
+            checks = ""
+
+        case_text = f'''
+!a_ty = {a_ty}
+!b_ty = {b_ty}
+!c_ty = {c_ty}
+#blocked = #triton_gpu.blocked<{{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}}>
+#dot_operand_a = #triton_gpu.dot_op<{{opIdx=0, parent=#blocked}}>
+#dot_operand_b = #triton_gpu.dot_op<{{opIdx=1, parent=#blocked}}>
+module attributes {{"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32}} {{
+{mfma_check}
+{label_check}
+  tt.func @convert_dot_{cfg_name}(%a: tensor<{M}x{K}x!a_ty, #dot_operand_a>, %b: tensor<{K}x{N}x!b_ty, #dot_operand_b>) -> tensor<{M}x{N}x!c_ty, #blocked> {{
+    %cst_c = arith.constant dense<{cst_val}> : tensor<{M}x{N}x!c_ty, #blocked>
+{checks}
+    %D = tt.dot %a, %b, %cst_c {{allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false}} : tensor<{M}x{K}x!a_ty, #dot_operand_a> * tensor<{K}x{N}x!b_ty, #dot_operand_b> -> tensor<{M}x{N}x!c_ty, #blocked>
+    tt.return %D: tensor<{M}x{N}x!c_ty, #blocked>
+  }}
+}}
+
+'''
+        if cfg_id == len(configs) - 1:
+            print(case_text, end="", file=output_file)
+        else:
+            print(case_text, end="// -----\n", file=output_file)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("cdna_version", type=int)
+    parser.add_argument("output_file", type=str)
+    args = parser.parse_args()
+    with open(args.output_file, "w") as f:
+        generate(cdna_version=args.cdna_version, output_file=f)
diff --git a/test/TritonGPU/accelerate-matmul-cdna1.mlir b/test/TritonGPU/accelerate-matmul-cdna1.mlir
index 51956c590035..07886f2f2e21 100644
--- a/test/TritonGPU/accelerate-matmul-cdna1.mlir
+++ b/test/TritonGPU/accelerate-matmul-cdna1.mlir
@@ -1,3 +1,4 @@
+// This file is generated: $ python3 ../scripts/amd/lit_tests/generate_accelerate_matmul_tests.py 1 ../test/TritonGPU/accelerate-matmul-cdna1.mlir
 // RUN: (! triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul=arch-generation-name=gfx908 --mlir-pass-pipeline-crash-reproducer=%t 2>/dev/null) | FileCheck --check-prefixes=CHECK %s
 
 !a_ty = f16
@@ -488,13 +489,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
 // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 1, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}>
-// CHECK: convert_dot_64_4_4_f16_f16_f32
-  tt.func @convert_dot_64_4_4_f16_f16_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> {
+// CHECK: convert_dot_64_4_64_f16_f16_f32
+  tt.func @convert_dot_64_4_64_f16_f16_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> {
     %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked>
 // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma>
-// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 64}>>
 // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked>
     tt.return %D: tensor<64x4x!c_ty, #blocked>
   }
 }
@@ -509,13 +510,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module
attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 1, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}> -// CHECK: convert_dot_64_4_4_bf16_bf16_f32 - tt.func @convert_dot_64_4_4_bf16_bf16_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK: convert_dot_64_4_64_bf16_bf16_f32 + tt.func @convert_dot_64_4_64_bf16_bf16_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 2}>> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 32}>> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 2}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -530,13 +531,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 1, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}> -// CHECK: convert_dot_64_4_4_f32_f32_f32 - tt.func @convert_dot_64_4_4_f32_f32_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK: convert_dot_64_4_64_f32_f32_f32 + tt.func @convert_dot_64_4_64_f32_f32_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 1}>> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 16}>> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 1}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, 
maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -551,13 +552,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 1, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}> -// CHECK: convert_dot_64_4_4_i8_i8_i32 - tt.func @convert_dot_64_4_4_i8_i8_i32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK: convert_dot_64_4_64_i8_i8_i32 + tt.func @convert_dot_64_4_64_i8_i8_i32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0> : tensor<64x4x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 64}>> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -572,11 +573,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_64_4_4_f8E4M3FNUZ_f8E4M3FNUZ_f32 - tt.func @convert_dot_64_4_4_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_64_4_64_f8E4M3FNUZ_f8E4M3FNUZ_f32 + tt.func @convert_dot_64_4_64_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, 
%cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -591,11 +592,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_64_4_4_f8E4M3FNUZ_f8E5M2FNUZ_f32 - tt.func @convert_dot_64_4_4_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_64_4_64_f8E4M3FNUZ_f8E5M2FNUZ_f32 + tt.func @convert_dot_64_4_64_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -610,11 +611,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_64_4_4_f8E5M2FNUZ_f8E4M3FNUZ_f32 - tt.func @convert_dot_64_4_4_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_64_4_64_f8E5M2FNUZ_f8E4M3FNUZ_f32 + tt.func @convert_dot_64_4_64_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -629,11 +630,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_64_4_4_f8E5M2FNUZ_f8E5M2FNUZ_f32 - tt.func @convert_dot_64_4_4_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_64_4_64_f8E5M2FNUZ_f8E5M2FNUZ_f32 + tt.func @convert_dot_64_4_64_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, 
%b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -648,13 +649,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 1, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}> -// CHECK: convert_dot_4_64_4_f16_f16_f32 - tt.func @convert_dot_4_64_4_f16_f16_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK: convert_dot_4_64_64_f16_f16_f32 + tt.func @convert_dot_4_64_64_f16_f16_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 64}>> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -669,13 +670,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 1, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}> -// CHECK: convert_dot_4_64_4_bf16_bf16_f32 - tt.func @convert_dot_4_64_4_bf16_bf16_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK: convert_dot_4_64_64_bf16_bf16_f32 + tt.func @convert_dot_4_64_64_bf16_bf16_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = 
arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 2}>> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 2}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 32}>> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -690,13 +691,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 1, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}> -// CHECK: convert_dot_4_64_4_f32_f32_f32 - tt.func @convert_dot_4_64_4_f32_f32_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK: convert_dot_4_64_64_f32_f32_f32 + tt.func @convert_dot_4_64_64_f32_f32_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 1}>> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 1}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 16}>> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -711,13 +712,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, 
"triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 1, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}> -// CHECK: convert_dot_4_64_4_i8_i8_i32 - tt.func @convert_dot_4_64_4_i8_i8_i32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK: convert_dot_4_64_64_i8_i8_i32 + tt.func @convert_dot_4_64_64_i8_i8_i32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0> : tensor<4x64x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 64}>> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -732,11 +733,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_4_64_4_f8E4M3FNUZ_f8E4M3FNUZ_f32 - tt.func @convert_dot_4_64_4_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_4_64_64_f8E4M3FNUZ_f8E4M3FNUZ_f32 + tt.func @convert_dot_4_64_64_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -751,11 +752,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_4_64_4_f8E4M3FNUZ_f8E5M2FNUZ_f32 - tt.func @convert_dot_4_64_4_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x4x!a_ty, 
#dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_4_64_64_f8E4M3FNUZ_f8E5M2FNUZ_f32 + tt.func @convert_dot_4_64_64_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -770,11 +771,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_4_64_4_f8E5M2FNUZ_f8E4M3FNUZ_f32 - tt.func @convert_dot_4_64_4_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_4_64_64_f8E5M2FNUZ_f8E4M3FNUZ_f32 + tt.func @convert_dot_4_64_64_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -789,11 +790,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_4_64_4_f8E5M2FNUZ_f8E5M2FNUZ_f32 - tt.func @convert_dot_4_64_4_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_4_64_64_f8E5M2FNUZ_f8E5M2FNUZ_f32 + tt.func @convert_dot_4_64_64_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } diff --git 
a/test/TritonGPU/accelerate-matmul-cdna2.mlir b/test/TritonGPU/accelerate-matmul-cdna2.mlir index 0d853186bac3..4226f92ff0fd 100644 --- a/test/TritonGPU/accelerate-matmul-cdna2.mlir +++ b/test/TritonGPU/accelerate-matmul-cdna2.mlir @@ -1,3 +1,4 @@ +// This file is generated: $ python3 ../scripts/amd/lit_tests/generate_accelerate_matmul_tests.py 2 ../test/TritonGPU/accelerate-matmul-cdna2.mlir // RUN: (! triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul=arch-generation-name=gfx90a --mlir-pass-pipeline-crash-reproducer=%t 2>/dev/null) | FileCheck --check-prefixes=CHECK %s !a_ty = f16 @@ -488,13 +489,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}> -// CHECK: convert_dot_64_4_4_f16_f16_f32 - tt.func @convert_dot_64_4_4_f16_f16_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK: convert_dot_64_4_64_f16_f16_f32 + tt.func @convert_dot_64_4_64_f16_f16_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 64}>> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -509,13 +510,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}> -// CHECK: convert_dot_64_4_4_bf16_bf16_f32 - tt.func @convert_dot_64_4_4_bf16_bf16_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK: convert_dot_64_4_64_bf16_bf16_f32 + tt.func @convert_dot_64_4_64_bf16_bf16_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant 
dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 64}>> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -530,13 +531,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}> -// CHECK: convert_dot_64_4_4_f32_f32_f32 - tt.func @convert_dot_64_4_4_f32_f32_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK: convert_dot_64_4_64_f32_f32_f32 + tt.func @convert_dot_64_4_64_f32_f32_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 1}>> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 16}>> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 1}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -551,13 +552,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, 
"triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}> -// CHECK: convert_dot_64_4_4_i8_i8_i32 - tt.func @convert_dot_64_4_4_i8_i8_i32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK: convert_dot_64_4_64_i8_i8_i32 + tt.func @convert_dot_64_4_64_i8_i8_i32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0> : tensor<64x4x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 64}>> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -572,11 +573,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_64_4_4_f8E4M3FNUZ_f8E4M3FNUZ_f32 - tt.func @convert_dot_64_4_4_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_64_4_64_f8E4M3FNUZ_f8E4M3FNUZ_f32 + tt.func @convert_dot_64_4_64_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -591,11 +592,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_64_4_4_f8E4M3FNUZ_f8E5M2FNUZ_f32 - tt.func @convert_dot_64_4_4_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x4x!a_ty, 
#dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_64_4_64_f8E4M3FNUZ_f8E5M2FNUZ_f32 + tt.func @convert_dot_64_4_64_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -610,11 +611,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_64_4_4_f8E5M2FNUZ_f8E4M3FNUZ_f32 - tt.func @convert_dot_64_4_4_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_64_4_64_f8E5M2FNUZ_f8E4M3FNUZ_f32 + tt.func @convert_dot_64_4_64_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -629,11 +630,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_64_4_4_f8E5M2FNUZ_f8E5M2FNUZ_f32 - tt.func @convert_dot_64_4_4_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_64_4_64_f8E5M2FNUZ_f8E5M2FNUZ_f32 + tt.func @convert_dot_64_4_64_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -648,13 +649,13 @@ module attributes 
{"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}> -// CHECK: convert_dot_4_64_4_f16_f16_f32 - tt.func @convert_dot_4_64_4_f16_f16_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK: convert_dot_4_64_64_f16_f16_f32 + tt.func @convert_dot_4_64_64_f16_f16_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 64}>> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -669,13 +670,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}> -// CHECK: convert_dot_4_64_4_bf16_bf16_f32 - tt.func @convert_dot_4_64_4_bf16_bf16_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK: convert_dot_4_64_64_bf16_bf16_f32 + tt.func @convert_dot_4_64_64_bf16_bf16_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : 
tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
+// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 64}>>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
     tt.return %D: tensor<4x64x!c_ty, #blocked>
   }
 }
@@ -690,13 +691,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
 // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}>
-// CHECK: convert_dot_4_64_4_f32_f32_f32
-  tt.func @convert_dot_4_64_4_f32_f32_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
+// CHECK: convert_dot_4_64_64_f32_f32_f32
+  tt.func @convert_dot_4_64_64_f32_f32_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
     %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked>
 // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma>
 // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 1}>>
-// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 1}>>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
+// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 16}>>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
     tt.return %D: tensor<4x64x!c_ty, #blocked>
   }
 }
@@ -711,13 +712,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
 // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}>
-// CHECK: convert_dot_4_64_4_i8_i8_i32
-  tt.func @convert_dot_4_64_4_i8_i8_i32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
+// CHECK: convert_dot_4_64_64_i8_i8_i32
+  tt.func @convert_dot_4_64_64_i8_i8_i32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
     %cst_c = arith.constant dense<0> : tensor<4x64x!c_ty, #blocked>
 // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma>
 // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
-// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
+// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 64}>>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
     tt.return %D: tensor<4x64x!c_ty, #blocked>
   }
 }
@@ -732,11 +733,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
-// CHECK-NOT: convert_dot_4_64_4_f8E4M3FNUZ_f8E4M3FNUZ_f32
-  tt.func @convert_dot_4_64_4_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
+// CHECK-NOT: convert_dot_4_64_64_f8E4M3FNUZ_f8E4M3FNUZ_f32
+  tt.func @convert_dot_4_64_64_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
     %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
     tt.return %D: tensor<4x64x!c_ty, #blocked>
   }
 }
@@ -751,11 +752,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
-// CHECK-NOT: convert_dot_4_64_4_f8E4M3FNUZ_f8E5M2FNUZ_f32
-  tt.func @convert_dot_4_64_4_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
+// CHECK-NOT: convert_dot_4_64_64_f8E4M3FNUZ_f8E5M2FNUZ_f32
+  tt.func @convert_dot_4_64_64_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
     %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
     tt.return %D: tensor<4x64x!c_ty, #blocked>
   }
 }
@@ -770,11 +771,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
-// CHECK-NOT: convert_dot_4_64_4_f8E5M2FNUZ_f8E4M3FNUZ_f32
-  tt.func @convert_dot_4_64_4_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
+// CHECK-NOT: convert_dot_4_64_64_f8E5M2FNUZ_f8E4M3FNUZ_f32
+  tt.func @convert_dot_4_64_64_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
     %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
     tt.return %D: tensor<4x64x!c_ty, #blocked>
   }
 }
@@ -789,11 +790,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
-// CHECK-NOT: convert_dot_4_64_4_f8E5M2FNUZ_f8E5M2FNUZ_f32
-  tt.func @convert_dot_4_64_4_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
+// CHECK-NOT: convert_dot_4_64_64_f8E5M2FNUZ_f8E5M2FNUZ_f32
+  tt.func @convert_dot_4_64_64_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
     %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
     tt.return %D: tensor<4x64x!c_ty, #blocked>
   }
 }
diff --git a/test/TritonGPU/accelerate-matmul-cdna3.mlir b/test/TritonGPU/accelerate-matmul-cdna3.mlir
index 24d8ee993615..02c096550345 100644
--- a/test/TritonGPU/accelerate-matmul-cdna3.mlir
+++ b/test/TritonGPU/accelerate-matmul-cdna3.mlir
@@ -1,3 +1,4 @@
+// This file is generated: $ python3 ../scripts/amd/lit_tests/generate_accelerate_matmul_tests.py 3 ../test/TritonGPU/accelerate-matmul-cdna3.mlir
 // RUN: (! triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul=arch-generation-name=gfx940 --mlir-pass-pipeline-crash-reproducer=%t 2>/dev/null) | FileCheck --check-prefixes=CHECK %s
 
 !a_ty = f16
@@ -504,13 +505,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
 // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}>
-// CHECK: convert_dot_64_4_4_f16_f16_f32
-  tt.func @convert_dot_64_4_4_f16_f16_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> {
+// CHECK: convert_dot_64_4_64_f16_f16_f32
+  tt.func @convert_dot_64_4_64_f16_f16_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> {
     %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked>
 // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma>
-// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 64}>>
 // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked>
     tt.return %D: tensor<64x4x!c_ty, #blocked>
   }
 }
@@ -525,13 +526,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
 // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}>
-// CHECK: convert_dot_64_4_4_bf16_bf16_f32
-  tt.func @convert_dot_64_4_4_bf16_bf16_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> {
+// CHECK: convert_dot_64_4_64_bf16_bf16_f32
+  tt.func @convert_dot_64_4_64_bf16_bf16_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> {
     %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked>
 // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma>
-// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 64}>>
 // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked>
     tt.return %D: tensor<64x4x!c_ty, #blocked>
   }
 }
@@ -546,13 +547,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
 // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}>
-// CHECK: convert_dot_64_4_4_f32_f32_f32
-  tt.func @convert_dot_64_4_4_f32_f32_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> {
+// CHECK: convert_dot_64_4_64_f32_f32_f32
+  tt.func @convert_dot_64_4_64_f32_f32_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> {
     %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked>
 // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma>
-// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 1}>>
+// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 16}>>
 // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 1}>>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked>
     tt.return %D: tensor<64x4x!c_ty, #blocked>
   }
 }
@@ -567,13 +568,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
 // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}>
-// CHECK: convert_dot_64_4_4_i8_i8_i32
-  tt.func @convert_dot_64_4_4_i8_i8_i32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> {
+// CHECK: convert_dot_64_4_64_i8_i8_i32
+  tt.func @convert_dot_64_4_64_i8_i8_i32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> {
     %cst_c = arith.constant dense<0> : tensor<64x4x!c_ty, #blocked>
 // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma>
-// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 64}>>
 // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked>
     tt.return %D: tensor<64x4x!c_ty, #blocked>
   }
 }
@@ -588,11 +589,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
-// CHECK-NOT: convert_dot_64_4_4_f8E4M3FNUZ_f8E4M3FNUZ_f32
-  tt.func @convert_dot_64_4_4_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> {
+// CHECK-NOT: convert_dot_64_4_64_f8E4M3FNUZ_f8E4M3FNUZ_f32
+  tt.func @convert_dot_64_4_64_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> {
     %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked>
     tt.return %D: tensor<64x4x!c_ty, #blocked>
   }
 }
@@ -607,11 +608,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
-// CHECK-NOT: convert_dot_64_4_4_f8E4M3FNUZ_f8E5M2FNUZ_f32
-  tt.func @convert_dot_64_4_4_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> {
+// CHECK-NOT: convert_dot_64_4_64_f8E4M3FNUZ_f8E5M2FNUZ_f32
+  tt.func @convert_dot_64_4_64_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> {
     %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked>
     tt.return %D: tensor<64x4x!c_ty, #blocked>
   }
 }
@@ -626,11 +627,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
-// CHECK-NOT: convert_dot_64_4_4_f8E5M2FNUZ_f8E4M3FNUZ_f32
-  tt.func @convert_dot_64_4_4_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> {
+// CHECK-NOT: convert_dot_64_4_64_f8E5M2FNUZ_f8E4M3FNUZ_f32
+  tt.func @convert_dot_64_4_64_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> {
    %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked>
     tt.return %D: tensor<64x4x!c_ty, #blocked>
   }
 }
@@ -645,11 +646,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
-// CHECK-NOT: convert_dot_64_4_4_f8E5M2FNUZ_f8E5M2FNUZ_f32
-  tt.func @convert_dot_64_4_4_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> {
+// CHECK-NOT: convert_dot_64_4_64_f8E5M2FNUZ_f8E5M2FNUZ_f32
+  tt.func @convert_dot_64_4_64_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> {
     %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked>
     tt.return %D: tensor<64x4x!c_ty, #blocked>
   }
 }
@@ -664,13 +665,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
 // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}>
-// CHECK: convert_dot_4_64_4_f16_f16_f32
-  tt.func @convert_dot_4_64_4_f16_f16_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
+// CHECK: convert_dot_4_64_64_f16_f16_f32
+  tt.func @convert_dot_4_64_64_f16_f16_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
     %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked>
 // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma>
 // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
-// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
+// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 64}>>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
     tt.return %D: tensor<4x64x!c_ty, #blocked>
   }
 }
@@ -685,13 +686,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
 // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}>
-// CHECK: convert_dot_4_64_4_bf16_bf16_f32
-  tt.func @convert_dot_4_64_4_bf16_bf16_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
+// CHECK: convert_dot_4_64_64_bf16_bf16_f32
+  tt.func @convert_dot_4_64_64_bf16_bf16_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
     %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked>
 // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma>
 // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
-// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
+// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 64}>>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
     tt.return %D: tensor<4x64x!c_ty, #blocked>
   }
 }
@@ -706,13 +707,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
 // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}>
-// CHECK: convert_dot_4_64_4_f32_f32_f32
-  tt.func @convert_dot_4_64_4_f32_f32_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
+// CHECK: convert_dot_4_64_64_f32_f32_f32
+  tt.func @convert_dot_4_64_64_f32_f32_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
    %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked>
 // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma>
 // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 1}>>
-// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 1}>>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
+// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 16}>>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
     tt.return %D: tensor<4x64x!c_ty, #blocked>
   }
 }
@@ -727,13 +728,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
 // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}>
-// CHECK: convert_dot_4_64_4_i8_i8_i32
-  tt.func @convert_dot_4_64_4_i8_i8_i32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
+// CHECK: convert_dot_4_64_64_i8_i8_i32
+  tt.func @convert_dot_4_64_64_i8_i8_i32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
    %cst_c = arith.constant dense<0> : tensor<4x64x!c_ty, #blocked>
 // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma>
 // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
-// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
+// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 64}>>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
     tt.return %D: tensor<4x64x!c_ty, #blocked>
   }
 }
@@ -748,11 +749,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
-// CHECK-NOT: convert_dot_4_64_4_f8E4M3FNUZ_f8E4M3FNUZ_f32
-  tt.func @convert_dot_4_64_4_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
+// CHECK-NOT: convert_dot_4_64_64_f8E4M3FNUZ_f8E4M3FNUZ_f32
+  tt.func @convert_dot_4_64_64_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
    %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
     tt.return %D: tensor<4x64x!c_ty, #blocked>
   }
 }
@@ -767,11 +768,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
-// CHECK-NOT: convert_dot_4_64_4_f8E4M3FNUZ_f8E5M2FNUZ_f32
-  tt.func @convert_dot_4_64_4_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
+// CHECK-NOT: convert_dot_4_64_64_f8E4M3FNUZ_f8E5M2FNUZ_f32
+  tt.func @convert_dot_4_64_64_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
    %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
     tt.return %D: tensor<4x64x!c_ty, #blocked>
   }
 }
@@ -786,11 +787,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
-// CHECK-NOT: convert_dot_4_64_4_f8E5M2FNUZ_f8E4M3FNUZ_f32
-  tt.func @convert_dot_4_64_4_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
+// CHECK-NOT: convert_dot_4_64_64_f8E5M2FNUZ_f8E4M3FNUZ_f32
+  tt.func @convert_dot_4_64_64_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
    %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
     tt.return %D: tensor<4x64x!c_ty, #blocked>
   }
 }
@@ -805,11 +806,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
-// CHECK-NOT: convert_dot_4_64_4_f8E5M2FNUZ_f8E5M2FNUZ_f32
-  tt.func @convert_dot_4_64_4_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
+// CHECK-NOT: convert_dot_4_64_64_f8E5M2FNUZ_f8E5M2FNUZ_f32
+  tt.func @convert_dot_4_64_64_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> {
    %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked>
-    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
+    %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked>
     tt.return %D: tensor<4x64x!c_ty, #blocked>
   }
 }