[AMD] Support int DotOp for RDNA3 (triton-lang#3904)
- Fixed int8 tests for Navi31
- Generate WMMA instructions for int8 or int4 operands
- Convert operands to fp32 if the initial type is wider than int8 (sketched below)
- Convert the result back to the initial type after the dot
- Add lit tests
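
For integer element types wider than 8 bits there is no integer WMMA instruction, so such dots are legalized for the FMA path by round-tripping through fp32. A minimal sketch of the rewrite, mirroring the fma_dot_i16_i16 lit test added below (tensor encodings elided; unsigned types would use arith.uitofp/arith.fptoui instead):

    %a32 = arith.sitofp %a : tensor<128x64xi16> to tensor<128x64xf32>
    %b32 = arith.sitofp %b : tensor<64x32xi16> to tensor<64x32xf32>
    %c32 = arith.sitofp %c : tensor<128x32xi16> to tensor<128x32xf32>
    %d32 = tt.dot %a32, %b32, %c32 : tensor<128x64xf32> * tensor<64x32xf32> -> tensor<128x32xf32>
    %d = arith.fptosi %d32 : tensor<128x32xf32> to tensor<128x32xi16>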

Signed-off-by: Ilya Veselov <[email protected]>
joviliast authored May 22, 2024
1 parent 74ad278 commit cfc14ec
Showing 4 changed files with 106 additions and 7 deletions.
lib/Analysis/Utility.cpp (4 changes: 2 additions & 2 deletions)
@@ -491,15 +491,15 @@ static bool supportWMMATypes(Type a, Type b, Type c, Type d) {
   if (a.isIntOrIndex()) {
     if (!c.isIntOrIndex())
       return false;
-    bool aValid = a.isUnsignedInteger() && aWidth <= 8;
+    bool aValid = aWidth <= 8;
     bool cValid = cWidth <= 32;
     return aValid && cValid;
   } else if (isa<FloatType>(a) && isa<FloatType>(c)) {
     if (a.isBF16())
       return c.isBF16() || c.isF32();
     if (a.isF16())
       return c.isF16() || c.isF32();
-    return aWidth <= cWidth;
+    return aWidth <= cWidth && aWidth <= 16;
   }
   return false;
 }
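With the unsigned-only requirement dropped above, 8-bit signed (or signless) integer operands now also report WMMA support, so a dot such as the one in the new lit test below is routed to WMMA (encodings elided in this sketch):

    %d = tt.dot %a, %b, %c : tensor<32x64xi8> * tensor<64x32xi8> -> tensor<32x32xi32>

The float branch additionally caps the operand width at 16 bits, so fp32 operands no longer satisfy this predicate.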
@@ -73,6 +73,50 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     // CHECK: triton_gpu.convert_layout %[[DOT2_WMMA_RES]]
     // CHECK-SAME: -> tensor<32x64xf16, #[[DOT_OP_PARENT]]>
     tt.store %2, %4 : tensor<32x64x!tt.ptr<f16>, #blocked>
     tt.return
   }
+  tt.func public @wmma_dot_i8_i32(
+      // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>>
+      %0: tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>,
+      // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>>
+      %1: tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>,
+      %2: tensor<32x32x!tt.ptr<i32>, #blocked>) {
+    // CHECK: %[[DOT1_ARG_C:.+]] = arith.constant dense<0> : tensor<32x32xi32, #[[DOT_OP_PARENT]]>
+    // CHECK: %[[DOT1_OP_C:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_C]]
+    // CHECK-SAME: -> tensor<32x32xi32, #[[WMMA_1]]
+    %3 = arith.constant dense<0> : tensor<32x32xi32, #blocked>
+    // CHECK: %[[DOT1_OP_A:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_A]]
+    // CHECK-SAME: -> tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_1]]
+    // CHECK: %[[DOT1_OP_B:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_B]]
+    // CHECK-SAME: -> tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_1]]
+    // CHECK: %[[DOT1_WMMA_RES:.+]] = tt.dot %[[DOT1_OP_A]], %[[DOT1_OP_B]], %[[DOT1_OP_C]]
+    // CHECK-SAME: -> tensor<32x32xi32, #[[WMMA_1]]
+    %4 = tt.dot %0, %1, %3 : tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xi32, #blocked>
+    // CHECK: triton_gpu.convert_layout %[[DOT1_WMMA_RES]]
+    // CHECK-SAME: -> tensor<32x32xi32, #[[DOT_OP_PARENT]]>
+    tt.store %2, %4 : tensor<32x32x!tt.ptr<i32>, #blocked>
+    tt.return
+  }
+  tt.func public @fma_dot_i16_i16(
+      // CHECK: %[[DOT3_ARG_A:.+]]: tensor<128x64xi16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>>
+      %0: tensor<128x64xi16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>,
+      // CHECK-SAME: %[[DOT3_ARG_B:.+]]: tensor<64x32xi16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>>
+      %1: tensor<64x32xi16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>,
+      %2: tensor<128x32x!tt.ptr<i16>, #blocked>) {
+    // CHECK: %[[DOT3_ARG_C:.+]] = arith.constant dense<0> : tensor<128x32xi16, #[[DOT_OP_PARENT]]>
+    %3 = arith.constant dense<0> : tensor<128x32xi16, #blocked>
+    // CHECK: %[[DOT3_OP_A:.+]] = arith.sitofp %[[DOT3_ARG_A]]
+    // CHECK-SAME: to tensor<128x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]
+    // CHECK: %[[DOT3_OP_B:.+]] = arith.sitofp %[[DOT3_ARG_B]]
+    // CHECK-SAME: to tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]
+    // CHECK: %[[DOT3_OP_C:.+]] = arith.sitofp %[[DOT3_ARG_C]]
+    // CHECK-SAME: to tensor<128x32xf32, #[[DOT_OP_PARENT]]
+    // CHECK: %[[DOT3_FMA_RES:.+]] = tt.dot %[[DOT3_OP_A]], %[[DOT3_OP_B]], %[[DOT3_OP_C]]
+    // CHECK-SAME: -> tensor<128x32xf32, #[[DOT_OP_PARENT]]>
+    %4 = tt.dot %0, %1, %3 : tensor<128x64xi16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xi16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x32xi16, #blocked>
+    // CHECK: arith.fptosi %[[DOT3_FMA_RES]]
+    // CHECK-SAME: to tensor<128x32xi16, #[[DOT_OP_PARENT]]>
+    tt.store %2, %4 : tensor<128x32x!tt.ptr<i16>, #blocked>
+    tt.return
+  }
 }
third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp (15 changes: 10 additions & 5 deletions)
@@ -109,8 +109,8 @@ static WMMAInstrType getWMMAInstrTypeFromDot(DotOp op) {
 }
 
 Value generateWMMAOp(ConversionPatternRewriter &rewriter, Location loc,
-                     WMMAInstrType wmmaType, Value valA, Value valB,
-                     Value valC) {
+                     WMMAInstrType wmmaType, Value valA, Value valB, Value valC,
+                     Type aElType, Type bElType) {
   auto resType = valC.getType();
   Value falseFlag = int_val(1, false);
   switch (wmmaType) {
@@ -129,11 +129,15 @@ Value generateWMMAOp(ConversionPatternRewriter &rewriter, Location loc,
   case WMMAInstrType::INT32_IU8:
     return rewriter.create<ROCDL::wmma_i32_16x16x16_iu8>(
         loc, TypeRange{resType},
-        ValueRange{falseFlag, valA, falseFlag, valB, valC, falseFlag});
+        ValueRange{int_val(1, !aElType.isUnsignedInteger()), valA,
+                   int_val(1, !bElType.isUnsignedInteger()), valB, valC,
+                   falseFlag});
   case WMMAInstrType::INT32_IU4:
     return rewriter.create<ROCDL::wmma_i32_16x16x16_iu4>(
         loc, TypeRange{resType},
-        ValueRange{falseFlag, valA, falseFlag, valB, valC, falseFlag});
+        ValueRange{int_val(1, !aElType.isUnsignedInteger()), valA,
+                   int_val(1, !bElType.isUnsignedInteger()), valB, valC,
+                   falseFlag});
   default:
     llvm::report_fatal_error("WMMA data type not supported");
   }
@@ -207,7 +211,8 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor,
       }
       for (size_t k = 0; k < numRepK; k++) {
         acc = generateWMMAOp(rewriter, loc, wmmaInstrType, ha[{m, k}],
-                             hb[{n, k}], acc);
+                             hb[{n, k}], acc, aTensorTy.getElementType(),
+                             bTensorTy.getElementType());
       }
       for (unsigned v = 0; v < dElemsToStorePerThread; ++v) {
         fc[m * numRepN * dElemsToStorePerThread + n * dElemsToStorePerThread +
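The i1 flags that now precede valA and valB select the signed versus unsigned interpretation of each operand; previously both were hard-coded to false (unsigned). A hedged pseudo-IR sketch of the resulting op for signed i8 inputs (assembly syntax approximate; vector operand types depend on wave size and are elided):

    %d = rocdl.wmma.i32.16x16x16.iu8 %a_signed, %a, %b_signed, %b, %c, %clamp

Here %a_signed and %b_signed are true for signed or signless element types, and the trailing clamp flag remains false.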
third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp (50 changes: 50 additions & 0 deletions)
@@ -450,6 +450,56 @@ static void decomposeMixedModeDotOp(ModuleOp mod) {
     // FMA case.
     Type AElType = dotOp.getA().getType().getElementType();
     Type DElType = D.getType().getElementType();
+
+    // Convert int operands to FP32 to apply the FMA case.
+    // Do it here instead of introducing a new pattern because this pass is
+    // more about MMA dots.
+    // TODO: Introduce a new pass for FMA dot legalization.
+    if (AElType.isIntOrIndex()) {
+      assert(dotOp.getB().getType().getElementType().isIntOrIndex() &&
+             dotOp.getC().getType().getElementType().isIntOrIndex() &&
+             DElType.isIntOrIndex());
+      auto convertTensorIToFP = [&](Value v) -> Value {
+        RankedTensorType vTy = cast<RankedTensorType>(v.getType());
+        Type dstType = vTy.cloneWith(std::nullopt, builder.getF32Type());
+        Type srcElType = vTy.getElementType();
+        return !srcElType.isUnsignedInteger()
+                   ? builder
+                         .create<mlir::arith::SIToFPOp>(dotOp.getLoc(),
+                                                        dstType, v)
+                         .getResult()
+                   : builder
+                         .create<mlir::arith::UIToFPOp>(dotOp.getLoc(),
+                                                        dstType, v)
+                         .getResult();
+      };
+      auto convertTensorFPToI = [&](Type dstElType, Value v) -> Value {
+        RankedTensorType vTy = cast<RankedTensorType>(v.getType());
+        Type dstType = vTy.cloneWith(std::nullopt, dstElType);
+        return !dstElType.isUnsignedInteger()
+                   ? builder
+                         .create<mlir::arith::FPToSIOp>(dotOp.getLoc(),
+                                                        dstType, v)
+                         .getResult()
+                   : builder
+                         .create<mlir::arith::FPToUIOp>(dotOp.getLoc(),
+                                                        dstType, v)
+                         .getResult();
+      };
+
+      auto newAOperand = convertTensorIToFP(dotOp.getA());
+      auto newBOperand = convertTensorIToFP(dotOp.getB());
+      auto newCOperand = convertTensorIToFP(dotOp.getC());
+      auto newDot = builder.create<tt::DotOp>(
+          dotOp.getLoc(), newCOperand.getType(), newAOperand, newBOperand,
+          newCOperand, dotOp.getInputPrecision(),
+          dotOp.getMaxNumImpreciseAcc());
+      auto newD = convertTensorFPToI(DElType, newDot.getResult());
+      D.replaceAllUsesWith(newD);
+      dotOp.erase();
+      return;
+    }
+
     if (AElType == DElType)
       return;
     promoteType = DElType;