Brgemm register tiling for bf16 type #1005
base: main
@@ -1,4 +1,5 @@
-//===- BrgemmLinalgTiling.cpp -----------------------------------------*- C++-*-===//
+//===- BrgemmLinalgTiling.cpp -----------------------------------------*-
+//C++-*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -43,160 +44,152 @@ using namespace mlir::tpp;

 namespace mlir {
 namespace tpp {
-struct LinalgOpTiling : OpRewritePattern<linalg::BatchReduceMatmulOp> {
-  using OpRewritePattern<linalg::BatchReduceMatmulOp>::OpRewritePattern;
+template <typename BrgemmOp>
+struct LinalgOpTiling : OpRewritePattern<BrgemmOp> {
+  using OpRewritePattern<BrgemmOp>::OpRewritePattern;

   LinalgOpTiling(MLIRContext *ctx, BrgemmLinalgTilingOptions tilingoptions)
-      : OpRewritePattern(ctx), options(tilingoptions) {}
+      : OpRewritePattern<BrgemmOp>(ctx), options(tilingoptions) {}

-  LogicalResult matchAndRewrite(linalg::BatchReduceMatmulOp brgemmOp,
+  LogicalResult matchAndRewrite(BrgemmOp brgemmOp,
                                 PatternRewriter &rewriter) const override {

     if (!brgemmOp.hasPureBufferSemantics())
       return failure();
-    // Get the register blocking tile shape from the user input
-    SmallVector<int64_t> tileShapeM(options.registerTileShape.begin(),
-                                    options.registerTileShape.end());
-
-    if (tileShapeM.size() != 2)
-      return failure();
+    // Check whether the tile sizes are valid
+    if (options.registerTileShape.size() != 3 &&
+        options.registerTileShape.size() != 2)
+      return failure();
Review comment: nit: it's easier to debug in the future with …
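(The comment is truncated in the captured page. Assuming it refers to rewriter.notifyMatchFailure — a common MLIR review suggestion, not confirmed by the source — the bare failure() returns could become a sketch like:)

    // Hypothetical: attach a reason to the failure so pattern debug
    // logging explains why the rewrite bailed out.
    if (options.registerTileShape.size() != 3 &&
        options.registerTileShape.size() != 2)
      return rewriter.notifyMatchFailure(
          brgemmOp, "expected 2 or 3 register tile sizes");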
-    SmallVector<int64_t> tileShapeN(2);
-    tileShapeN[0] = 1;
-    tileShapeN[1] = tileShapeM[1];
-    tileShapeM[1] = 1;
+    // Check whether the operation is brmatmul fp32 or bf16 type using the
+    // reduction count
+    SmallVector<utils::IteratorType> brgemmIteratorTypes =
+        brgemmOp.getIteratorTypesArray();
+    int reductionCount =
+        std::count(brgemmIteratorTypes.begin(), brgemmIteratorTypes.end(),
+                   utils::IteratorType::reduction);
+    if (reductionCount != 2 && reductionCount != 3)
+      return failure();

-    // Stores the M, N, and K Tile Sizes
+    // Get the register blocking tile shape from the user input
     SmallVector<int64_t> mxnxkTile(3);
-    // Stores the M, and N Tile Sizes
-    SmallVector<int64_t> mxnTile(2);
+    for (size_t i = 0; i < options.registerTileShape.size(); i++) {
+      mxnxkTile[i] = options.registerTileShape[i];
+    }

+    // Set the K tile to 1 if the user has not provided one (fp32 target)
+    if (options.registerTileShape.size() == 2)
+      mxnxkTile[2] = 1;
Review comment on lines +81 to +83: I'd just force the user to always provide m,n,k tiles. It'll simplify the verification logic and make usage more explicit.
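(A minimal sketch of the stricter check the reviewer suggests, assuming the pattern keeps its current option plumbing:)

    // Require exactly m, n, and k tile sizes; no implicit k = 1 default.
    if (options.registerTileShape.size() != 3)
      return rewriter.notifyMatchFailure(
          brgemmOp, "expected register tile sizes for m, n, and k");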
+    // k-tile size adjusted based on the vnni layout for bf16 type
+    auto tensorShape =
+        dyn_cast<MemRefType>(brgemmOp.getOperand(0).getType()).getShape();
+    if (tensorShape.size() == 4 && options.registerTileShape.size() == 3) {
+      mxnxkTile[2] = mxnxkTile[2] / tensorShape[3];
+    }

Review comment: This has baked-in assumptions that are not verified.
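(One way to make those assumptions explicit — an illustrative sketch, not part of the patch; it assumes the trailing dimension of a 4-D operand 0 is the VNNI factor:)

    auto memrefType = dyn_cast<MemRefType>(brgemmOp.getOperand(0).getType());
    if (!memrefType)
      return rewriter.notifyMatchFailure(brgemmOp, "expected memref operand");
    auto tensorShape = memrefType.getShape();
    if (tensorShape.size() == 4 && options.registerTileShape.size() == 3) {
      int64_t vnniFactor = tensorShape[3];
      // Verify the k tile divides evenly before adjusting it.
      if (vnniFactor <= 0 || mxnxkTile[2] % vnniFactor != 0)
        return rewriter.notifyMatchFailure(
            brgemmOp, "k tile must be a positive multiple of the vnni factor");
      mxnxkTile[2] /= vnniFactor;
    }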
-    mxnxkTile[0] = tileShapeM[0];
-    mxnxkTile[1] = tileShapeN[1];
-    mxnxkTile[2] = tileShapeM[1];
-    mxnTile[0] = tileShapeM[0];
-    mxnTile[1] = tileShapeN[1];
-
-    // To assist in calculating the argument and step values for the tiled loop.
-    SmallVector<int64_t> boundariesOne{1,
-                                       static_cast<long>(tileShapeM.size() - 1),
-                                       static_cast<long>(mxnxkTile.size() - 1)};
-
-    SmallVector<int64_t> tileSizesIndex{static_cast<long>(tileShapeM.size()),
-                                        static_cast<long>(tileShapeN.size()),
-                                        static_cast<long>(mxnTile.size())};
-    SmallVector<SmallVector<int64_t>> tileshapes{tileShapeM, tileShapeN, mxnTile};
-    SmallVector<int> swap_i = {0, 2, 1};
     size_t i = 0;
+    SmallVector<int> swap_i = {0, 2, 1};
     std::map<int, std::map<int, Value>> inductionVars;

     // For M, N, and K loops
     scf::ForOp innermostForLoop;
     // For brgemm reduction loop
     scf::ForOp reductionForLoop;
     // Creating the tiled loops
-    for (auto itrShapeM = mxnxkTile.begin(); itrShapeM != mxnxkTile.end();
-         itrShapeM++, i++) {
-      int index = swap_i[i] / boundariesOne[swap_i[i]];
-      int offset = swap_i[i] / (mxnxkTile.size() - 1);
-
-      int operandSize =
-          dyn_cast<MemRefType>(brgemmOp.getOperand(index).getType())
-              .getShape()
-              .size();
-      int effectiveOffset = operandSize - tileSizesIndex[index] + offset;
+    for (auto itrShapeMNK = mxnxkTile.begin(); itrShapeMNK != mxnxkTile.end();
+         itrShapeMNK++, i++) {
Review comment on lines +99 to +100: nit: you could use …
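(The nit is truncated; if it refers to llvm::enumerate — an assumption — the hand-maintained counter could be dropped:)

    // Hypothetical rewrite with llvm::enumerate (llvm/ADT/STLExtras.h),
    // which yields (index, value) pairs over mxnxkTile.
    for (auto [idx, tileSize] : llvm::enumerate(mxnxkTile)) {
      // use `idx` in place of the manual `i`, and `tileSize`
      // in place of *itrShapeMNK
    }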
       auto upperBound =
-          dyn_cast<MemRefType>(brgemmOp.getOperand(index).getType())
-              .getShape()[effectiveOffset];
+          dyn_cast<MemRefType>(brgemmOp.getOperand(swap_i[i]).getType())
+              .getShape()[1];
+      // Tile size should not be greater than the upperBound
+      if ((*itrShapeMNK) > upperBound)
+        return failure();

       Location loc = brgemmOp.getLoc();
       Value zeroCst = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-      Value ubCstTiledLoop = rewriter.create<arith::ConstantIndexOp>(loc, upperBound);
-      // Tile size should not be greater than the upperBound
-      if ((*itrShapeM) > upperBound)
-        return failure();
-      Value stepCstTiledLoop = rewriter.create<arith::ConstantIndexOp>(loc, *itrShapeM);
+      Value ubCstTiledLoop =
+          rewriter.create<arith::ConstantIndexOp>(loc, upperBound);
+      Value stepCstTiledLoop =
+          rewriter.create<arith::ConstantIndexOp>(loc, *itrShapeMNK);
       // Creates M, N, and K tile loops
-      scf::ForOp loopOp = rewriter.create<scf::ForOp>(brgemmOp.getLoc(),
-                                  zeroCst, ubCstTiledLoop, stepCstTiledLoop);
+      scf::ForOp loopOp = rewriter.create<scf::ForOp>(
Review comment: If I am understanding right, this transform is meant to operate on linalg ops. As I expect all the ops you want to support will implement TilingInterface, would it be possible to just use the …
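(The comment is cut off; assuming it points at the upstream TilingInterface utilities — e.g. scf::tileUsingSCF from mlir/Dialect/SCF/Transforms/TileUsingInterface.h, an assumption since the exact suggestion is not visible — the loop nest could be produced roughly like this:)

    // Hypothetical sketch: let the TilingInterface driver build the loops
    // instead of constructing scf.for ops by hand.
    scf::SCFTilingOptions tilingOptions;
    tilingOptions.setTileSizes(
        getAsIndexOpFoldResult(rewriter.getContext(), mxnxkTile));
    FailureOr<scf::SCFTilingResult> tiled = scf::tileUsingSCF(
        rewriter, cast<TilingInterface>(brgemmOp.getOperation()),
        tilingOptions);
    if (failed(tiled))
      return failure();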
+          brgemmOp.getLoc(), zeroCst, ubCstTiledLoop, stepCstTiledLoop);
       rewriter.setInsertionPointToStart(loopOp.getBody());
-      int indexTwo = offset;
-      int operandSizeTwo =
-          dyn_cast<MemRefType>(brgemmOp.getOperand(indexTwo).getType())
-              .getShape()
-              .size();
-      int effectiveOffsetTwo = operandSizeTwo - tileSizesIndex[index] + index;
-
-      inductionVars[index][effectiveOffset] = loopOp.getInductionVar();
-
-      inductionVars[indexTwo][effectiveOffsetTwo] = loopOp.getInductionVar();
-      int indexThree = mxnTile.size();
-      int effectiveOffsetThree =
-          index +
-          dyn_cast<MemRefType>(brgemmOp.getOperand(indexThree).getType())
-              .getShape()
-              .size() -
-          tileSizesIndex[indexThree];
-      if (inductionVars[indexThree][effectiveOffsetThree] == NULL) {
-        inductionVars[indexThree][effectiveOffsetThree] =
-            loopOp.getInductionVar();
-      }
-
       innermostForLoop = loopOp;
-      if ((mxnxkTile.size() - 1) == (i + 1)) {
-        // Creates the brgemm reduction loop
+      // Stores the induction variable with respect to the operands, mapping
+      // its subview.
+      if (i == 0) { // Stores iv for the M loop
+        inductionVars[0][1] = loopOp.getInductionVar();
+        inductionVars[2][0] = loopOp.getInductionVar();
+      } else if (i == 1) { // Stores iv for the N loop, creates the batch
+                           // loop, and maps the iv of the batch loop
+        inductionVars[1][2] = loopOp.getInductionVar();
+        inductionVars[2][1] = loopOp.getInductionVar();
+        // Creates reduction loop after the N loop
         Value ubCstReduction = rewriter.create<arith::ConstantIndexOp>(
             loc, dyn_cast<MemRefType>(brgemmOp.getOperand(0).getType())
                      .getShape()[0]);
-        Value stepCstReduction = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+        Value stepCstReduction =
+            rewriter.create<arith::ConstantIndexOp>(loc, 1);
         scf::ForOp redloopOp = rewriter.create<scf::ForOp>(
             brgemmOp.getLoc(), zeroCst, ubCstReduction, stepCstReduction);
         rewriter.setInsertionPointToStart(redloopOp.getBody());
         reductionForLoop = redloopOp;
+        inductionVars[0][0] = redloopOp.getInductionVar();
+        inductionVars[1][0] = redloopOp.getInductionVar();
+      } else if (i == 2) { // Stores iv for the K loop
+        inductionVars[0][2] = loopOp.getInductionVar();
+        inductionVars[1][1] = loopOp.getInductionVar();
       }
     }
+    // DS to assist while creating new subviews with correct indices and shapes
+    SmallVector<int64_t> mxkTile(2);
+    SmallVector<int64_t> kxnTile(2);
+    SmallVector<int64_t> mxnTile(2);
+    mxkTile[0] = mxnxkTile[0];
+    mxkTile[1] = mxnxkTile[2];
+    kxnTile[0] = mxnxkTile[2];
+    kxnTile[1] = mxnxkTile[1];
+    mxnTile[0] = mxnxkTile[0];
+    mxnTile[1] = mxnxkTile[1];

Review comment: nit: you could directly brace initialize it as …
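(The nit is truncated; presumably brace initialization along these lines, which simply restates the assignments above:)

    SmallVector<int64_t> mxkTile{mxnxkTile[0], mxnxkTile[2]};
    SmallVector<int64_t> kxnTile{mxnxkTile[2], mxnxkTile[1]};
    SmallVector<int64_t> mxnTile{mxnxkTile[0], mxnxkTile[1]};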
+    SmallVector<SmallVector<int64_t>> tileshapes{mxkTile, kxnTile, mxnTile};
     // Creating subviews
-    SmallVector<SmallVector<int64_t>> tiles = {tileShapeM, tileShapeN};
     for (size_t i = 0; i < brgemmOp.getNumOperands(); i++) {
-      SmallVector<int64_t> indices;
-      auto input = brgemmOp.getOperand(i);
-      auto operandType = input.getType();
       SmallVector<OpFoldResult> offsets;
       size_t k = 0;
-      auto tileItr = tileshapes[i].begin();
-      auto tensorShape = dyn_cast<MemRefType>(operandType).getShape();
+      SmallVector<int64_t> indices;
       SmallVector<OpFoldResult> shape;
       SmallVector<OpFoldResult> strides;

+      auto input = brgemmOp.getOperand(i);
+      auto tensorShape = dyn_cast<MemRefType>(input.getType()).getShape();
+      auto tileItr = tileshapes[i].begin();
       // Iterates over the shape of each tensor and updates its offsets,
       // indices, shapes, and strides with respect to the tile sizes
       for (size_t j = 0; j < tensorShape.size(); j++) {
-        if (j < tensorShape.size() - tileSizesIndex[i]) {
-          if (j == ((tensorShape.size() - tileSizesIndex[i]) - 1) &&
-              i < (brgemmOp.getNumOperands() - 1)) {
-            offsets.push_back(reductionForLoop.getInductionVar());
-            indices.push_back(tensorShape[j] / tensorShape[j]);
-            shape.push_back(rewriter.getIndexAttr(tensorShape[j] / tensorShape[j]));
-            strides.push_back(rewriter.getIndexAttr(1));
-          } else {
-            offsets.push_back(rewriter.getIndexAttr(0));
-            indices.push_back(tensorShape[j]);
-            shape.push_back(rewriter.getIndexAttr(tensorShape[j]));
-            strides.push_back(rewriter.getIndexAttr(1));
-          }
-        } else {
-          shape.push_back(rewriter.getIndexAttr(*tileItr));
-          offsets.push_back(
-              inductionVars[i][tensorShape.size() - tileSizesIndex[i] + k]);
-          k++;
-          tileItr++;
-        }
+        if (j == 0 && (i < 2)) { // Updates the batch dimension
+          offsets.push_back(inductionVars[i][j]);
+          indices.push_back(1);
+          shape.push_back(rewriter.getIndexAttr(1));
+          strides.push_back(rewriter.getIndexAttr(1));
+        } else if (j < 3) { // Updates the M, N, and K dimensions
+          offsets.push_back(inductionVars[i][j]);
+          indices.push_back((*tileItr));
+          shape.push_back(rewriter.getIndexAttr(*tileItr));
+          strides.push_back(rewriter.getIndexAttr(1));
+          k++;
+          tileItr++;
+        } else { // Just copies the vnni layout dimensions
+          offsets.push_back(rewriter.getIndexAttr(0));
+          indices.push_back(tensorShape[j]);
+          shape.push_back(rewriter.getIndexAttr(tensorShape[j]));
+          strides.push_back(rewriter.getIndexAttr(1));
+        }
       }
       auto subview = rewriter.create<memref::SubViewOp>(
-          brgemmOp.getLoc(), MemRefType(),
-          input, offsets, shape, strides);
+          brgemmOp.getLoc(), MemRefType(), input, offsets, shape, strides);

Review comment: nit: you can skip the result MemRefType and use this builder: …
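(The referenced builder is not visible in the capture; memref::SubViewOp does provide an overload that infers the result type from the source memref and the offsets/sizes/strides, so the call could plausibly become:)

      // Result type is inferred from `input` and the offset/size/stride lists.
      auto subview = rewriter.create<memref::SubViewOp>(
          brgemmOp.getLoc(), input, offsets, shape, strides);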
       brgemmOp.setOperand(i, subview);
     }
@@ -214,11 +207,14 @@ struct LinalgOpTiling : OpRewritePattern<linalg::BatchReduceMatmulOp> {
 };

 void populateBrgemmLinalgTilingPatterns(RewritePatternSet &patterns,
-                                        BrgemmLinalgTilingOptions options) {
-  patterns.add<LinalgOpTiling>(patterns.getContext(), options);
+                                       BrgemmLinalgTilingOptions options) {
+  patterns.add<LinalgOpTiling<linalg::GenericOp>,
+               LinalgOpTiling<linalg::BatchReduceMatmulOp>>(
+      patterns.getContext(), options);
 }

-struct BrgemmLinalgTiling : public tpp::impl::BrgemmLinalgTilingBase<BrgemmLinalgTiling> {
+struct BrgemmLinalgTiling
+    : public tpp::impl::BrgemmLinalgTilingBase<BrgemmLinalgTiling> {

   using BrgemmLinalgTilingBase::BrgemmLinalgTilingBase;
@@ -0,0 +1,30 @@
+// RUN: tpp-run -e register_tile_bf16 --entry-point-result=void -print %s > %t.1
+// RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=32,32,32" -convert-linalg-to-xsmm | tpp-run -e register_tile_bf16 --entry-point-result=void -print > %t.2
+// RUN: diff %t.1 %t.2

Review comment: It might be more robust to use …
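(The suggestion is truncated; if it means checking the tiled IR with FileCheck instead of diffing two runtime outputs — an assumption — a sketch might be:)

// RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=32,32,32" | FileCheck %s
// CHECK: scf.for
// CHECK:   memref.subview
// CHECK:   linalg.generic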
+// RUN: rm %t.1 %t.2

Review comment: nit: it already creates temporary files, there should be no need to explicitly delete them.
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+module {
+  memref.global "private" constant @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64}
+  func.func @register_tile_bf16(%arg0: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> {
+    %cst = arith.constant 0.000000e+00 : bf16
+    %0 = memref.get_global @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16>
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16>
+    %expand_shape = memref.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [8, 32, 32, 16, 2] : memref<8x32x32x32xbf16> into memref<8x32x32x16x2xbf16>
+    scf.forall (%arg1, %arg2) in (8, 32) {
+      %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
+      linalg.fill ins(%cst : bf16) outs(%subview : memref<32x32xbf16, strided<[32, 1], offset: ?>>)
+      %subview_0 = memref.subview %expand_shape[%arg1, 0, 0, 0, 0] [1, 32, 32, 16, 2] [1, 1, 1, 1, 1] : memref<8x32x32x16x2xbf16> to memref<32x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>>
+      linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%subview_0, %0 : memref<32x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>>, memref<32x16x32x2xbf16>) outs(%subview : memref<32x32xbf16, strided<[32, 1], offset: ?>>) {
+      ^bb0(%in: bf16, %in_1: bf16, %out: bf16):
+        %1 = arith.mulf %in, %in_1 : bf16
+        %2 = arith.addf %out, %1 : bf16
+        linalg.yield %2 : bf16
+      }
+    }
+    return %alloc : memref<8x32x32x32xbf16>
+  }
+}
Review comment: nit: tweak to fit in a single line