Brgemm register tiling for bf16 type #1005

Status: Open. Wants to merge 4 commits into base: main.
222 changes: 109 additions & 113 deletions lib/TPP/Transforms/BrgemmLinalgTiling.cpp
@@ -1,4 +1,5 @@
//===- BrgemmLinalgTiling.cpp -----------------------------------------*- C++-*-===//
//===- BrgemmLinalgTiling.cpp -----------------------------------------*-
//C++-*-===//
Reviewer comment (Contributor, on lines +1 to +2): nit: tweak to fit in single line

//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -43,160 +44,152 @@ using namespace mlir::tpp;

namespace mlir {
namespace tpp {
struct LinalgOpTiling : OpRewritePattern<linalg::BatchReduceMatmulOp> {
using OpRewritePattern<linalg::BatchReduceMatmulOp>::OpRewritePattern;

template <typename BrgemmOp>
struct LinalgOpTiling : OpRewritePattern<BrgemmOp> {
using OpRewritePattern<BrgemmOp>::OpRewritePattern;

LinalgOpTiling(MLIRContext *ctx, BrgemmLinalgTilingOptions tilingoptions)
: OpRewritePattern(ctx), options(tilingoptions) {}
: OpRewritePattern<BrgemmOp>(ctx), options(tilingoptions) {}

LogicalResult matchAndRewrite(linalg::BatchReduceMatmulOp brgemmOp,
LogicalResult matchAndRewrite(BrgemmOp brgemmOp,
PatternRewriter &rewriter) const override {

if (!brgemmOp.hasPureBufferSemantics())
return failure();
// Get the register blocking tile shape from the user input
SmallVector<int64_t> tileShapeM(options.registerTileShape.begin(),
options.registerTileShape.end());

if (tileShapeM.size() != 2)
return failure();
// Check whether the tile sizes are valid
if (options.registerTileShape.size() != 3 &&
options.registerTileShape.size() != 2)
return failure();
Reviewer comment (Contributor): nit: it's easier to debug in the future with rewriter.notifyMatchFailure, which provides some feedback, instead of a bare failure().
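A minimal sketch of that suggestion applied to the size check above; the diagnostic text is illustrative, not taken from the patch:

  if (options.registerTileShape.size() != 3 &&
      options.registerTileShape.size() != 2)
    return rewriter.notifyMatchFailure(
        brgemmOp, "expected 2 (fp32) or 3 (bf16) register tile sizes");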


SmallVector<int64_t> tileShapeN(2);
tileShapeN[0] = 1;
tileShapeN[1] = tileShapeM[1];
tileShapeM[1] = 1;
// Check whether the operation is a batch-reduce matmul of fp32 or bf16 type
// using the reduction count
SmallVector<utils::IteratorType> brgemmIteratorTypes =
brgemmOp.getIteratorTypesArray();
int reductionCount =
std::count(brgemmIteratorTypes.begin(), brgemmIteratorTypes.end(),
utils::IteratorType::reduction);
if (reductionCount != 2 && reductionCount != 3)
return failure();

// Stores the M, N, and K Tile Sizes
// Get the register blocking tile shape from the user input
SmallVector<int64_t> mxnxkTile(3);
// Stores the M, and N Tile Sizes
SmallVector<int64_t> mxnTile(2);
for (size_t i = 0; i < options.registerTileShape.size(); i++) {
mxnxkTile[i] = options.registerTileShape[i];
}

// Set the K tile to 1 if the user did not provide it (i.e., an fp32 target)
if (options.registerTileShape.size() == 2)
mxnxkTile[2] = 1;
Reviewer comment (Contributor, on lines +81 to +83): I'd just force the user to always provide M, N, K tiles. It'll simplify the verification logic and make usage more explicit.


// k-tile size adjusted based on the vnni layout for bf16 type
Reviewer comment (Contributor): This has baked-in assumptions that are not verified. As the pass now operates on linalg.generic, we need to strictly filter the ops that are accepted. I think you need to at least ensure it is a VNNI contraction first; there should be some suitable helpers in VnniUtils. If f32 generic should be supported as well, it might need some extra checks there too.
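A hedged sketch of such a filter; the helper vnni::utils::isInVnniLayout and its signature are assumptions here and would need to be checked against the actual VnniUtils in the repository:

  // Assumed helper; verify the real name/signature in TPP's VNNI utilities.
  if (auto genericOp = dyn_cast<linalg::GenericOp>(brgemmOp.getOperation())) {
    if (!vnni::utils::isInVnniLayout(genericOp))
      return rewriter.notifyMatchFailure(
          genericOp, "expected a VNNI-layout BRGEMM contraction");
  }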

auto tensorShape =
dyn_cast<MemRefType>(brgemmOp.getOperand(0).getType()).getShape();
if (tensorShape.size() == 4 && options.registerTileShape.size() == 3) {
mxnxkTile[2] = mxnxkTile[2] / tensorShape[3];
}

mxnxkTile[0] = tileShapeM[0];
mxnxkTile[1] = tileShapeN[1];
mxnxkTile[2] = tileShapeM[1];
mxnTile[0] = tileShapeM[0];
mxnTile[1] = tileShapeN[1];

// To assist in calculating the argument and step values for the tiled loop.
SmallVector<int64_t> boundariesOne{1,
static_cast<long>(tileShapeM.size() - 1),
static_cast<long>(mxnxkTile.size() - 1)};

SmallVector<int64_t> tileSizesIndex{static_cast<long>(tileShapeM.size()),
static_cast<long>(tileShapeN.size()),
static_cast<long>(mxnTile.size())};
SmallVector<SmallVector<int64_t>> tileshapes{tileShapeM, tileShapeN, mxnTile};
SmallVector<int> swap_i = {0, 2, 1};
size_t i = 0;
SmallVector<int> swap_i = {0, 2, 1};
std::map<int, std::map<int, Value>> inductionVars;

// For M, N, and K loops
scf::ForOp innermostForLoop;
// For brgemm reduction loop
scf::ForOp reductionForLoop;

// Creating the tiled loops
for (auto itrShapeM = mxnxkTile.begin(); itrShapeM != mxnxkTile.end();
itrShapeM++, i++) {
int index = swap_i[i] / boundariesOne[swap_i[i]];
int offset = swap_i[i] / (mxnxkTile.size() - 1);

int operandSize =
dyn_cast<MemRefType>(brgemmOp.getOperand(index).getType())
.getShape()
.size();
int effectiveOffset = operandSize - tileSizesIndex[index] + offset;
for (auto itrShapeMNK = mxnxkTile.begin(); itrShapeMNK != mxnxkTile.end();
itrShapeMNK++, i++) {
Reviewer comment (Contributor, on lines +99 to +100): nit: you could use llvm::enumerate for this, like

for (auto [idx, itrShape] : llvm::enumerate(mxnxkTile)) {

auto upperBound =
dyn_cast<MemRefType>(brgemmOp.getOperand(index).getType())
.getShape()[effectiveOffset];
dyn_cast<MemRefType>(brgemmOp.getOperand(swap_i[i]).getType())
.getShape()[1];
// Tile size should not be greater than the upperBound
if ((*itrShapeMNK) > upperBound)
return failure();

Location loc = brgemmOp.getLoc();
Value zeroCst = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value ubCstTiledLoop = rewriter.create<arith::ConstantIndexOp>(loc, upperBound);
//Tile size should not be greater than the upperBound
if ((*itrShapeM) > upperBound)
return failure();
Value stepCstTiledLoop = rewriter.create<arith::ConstantIndexOp>(loc, *itrShapeM);
Value ubCstTiledLoop =
rewriter.create<arith::ConstantIndexOp>(loc, upperBound);
Value stepCstTiledLoop =
rewriter.create<arith::ConstantIndexOp>(loc, *itrShapeMNK);
// Creates M, N, and K tile loops
scf::ForOp loopOp = rewriter.create<scf::ForOp>(brgemmOp.getLoc(),
zeroCst, ubCstTiledLoop, stepCstTiledLoop);
scf::ForOp loopOp = rewriter.create<scf::ForOp>(
Reviewer comment (Contributor): If I am understanding right, this transform is meant to operate on linalg ops. As I expect all the ops you want to support will implement TilingInterface, would it be possible to just use the TileUsingFor transform instead of manually implementing tiling?
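A rough sketch of what the upstream tiling driver could look like here; the options API and result fields vary across MLIR versions, so treat this as an assumption to verify rather than a drop-in replacement:

  scf::SCFTilingOptions tilingOpts;
  tilingOpts.setTileSizes(
      getAsIndexOpFoldResult(rewriter.getContext(), mxnxkTile));
  FailureOr<scf::SCFTilingResult> tiled = scf::tileUsingSCF(
      rewriter, cast<TilingInterface>(brgemmOp.getOperation()), tilingOpts);
  if (failed(tiled))
    return rewriter.notifyMatchFailure(brgemmOp, "tiling failed");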

brgemmOp.getLoc(), zeroCst, ubCstTiledLoop, stepCstTiledLoop);
rewriter.setInsertionPointToStart(loopOp.getBody());
int indexTwo = offset;
int operandSizeTwo =
dyn_cast<MemRefType>(brgemmOp.getOperand(indexTwo).getType())
.getShape()
.size();
int effectiveOffsetTwo = operandSizeTwo - tileSizesIndex[index] + index;

inductionVars[index][effectiveOffset] = loopOp.getInductionVar();

inductionVars[indexTwo][effectiveOffsetTwo] = loopOp.getInductionVar();
int indexThree = mxnTile.size();
int effectiveOffsetThree =
index +
dyn_cast<MemRefType>(brgemmOp.getOperand(indexThree).getType())
.getShape()
.size() -
tileSizesIndex[indexThree];
if (inductionVars[indexThree][effectiveOffsetThree] == NULL) {
inductionVars[indexThree][effectiveOffsetThree] =
loopOp.getInductionVar();
}

innermostForLoop = loopOp;
if ((mxnxkTile.size() - 1) == (i + 1)) {
//Creates the brgemm reduction loop

// Stores the induction variable with respect to the operands, mapping it to
// its subview.
if (i == 0) { // Stores iv for M loop
inductionVars[0][1] = loopOp.getInductionVar();
inductionVars[2][0] = loopOp.getInductionVar();
} else if (i == 1) { // Stores iv for N loop, creates batch loop, and maps iv of batch loop
inductionVars[1][2] = loopOp.getInductionVar();
inductionVars[2][1] = loopOp.getInductionVar();
// Creates reduction loop after the N loop
Value ubCstReduction = rewriter.create<arith::ConstantIndexOp>(
loc, dyn_cast<MemRefType>(brgemmOp.getOperand(0).getType())
.getShape()[0]);
Value stepCstReduction = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Value stepCstReduction =
rewriter.create<arith::ConstantIndexOp>(loc, 1);
scf::ForOp redloopOp = rewriter.create<scf::ForOp>(
brgemmOp.getLoc(), zeroCst, ubCstReduction, stepCstReduction);
rewriter.setInsertionPointToStart(redloopOp.getBody());
reductionForLoop = redloopOp;
inductionVars[0][0] = redloopOp.getInductionVar();
inductionVars[1][0] = redloopOp.getInductionVar();

} else if (i == 2) { // stores iv for k-loop
inductionVars[0][2] = loopOp.getInductionVar();
inductionVars[1][1] = loopOp.getInductionVar();
}
}

// DS to assist while creating new subviews with correct indices and shapes
SmallVector<int64_t> mxkTile(2);
Reviewer comment (Contributor): nit: you could directly brace-initialize it, as in mxkTile = {val1, val2};
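Concretely, a sketch of that suggestion using the same values as the element-wise assignments below:

  SmallVector<int64_t> mxkTile = {mxnxkTile[0], mxnxkTile[2]};
  SmallVector<int64_t> kxnTile = {mxnxkTile[2], mxnxkTile[1]};
  SmallVector<int64_t> mxnTile = {mxnxkTile[0], mxnxkTile[1]};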

SmallVector<int64_t> kxnTile(2);
SmallVector<int64_t> mxnTile(2);
mxkTile[0] = mxnxkTile[0];
mxkTile[1] = mxnxkTile[2];
kxnTile[0] = mxnxkTile[2];
kxnTile[1] = mxnxkTile[1];
mxnTile[0] = mxnxkTile[0];
mxnTile[1] = mxnxkTile[1];

SmallVector<SmallVector<int64_t>> tileshapes{mxkTile, kxnTile, mxnTile};
// Creating subviews
SmallVector<SmallVector<int64_t>> tiles = {tileShapeM, tileShapeN};
for (size_t i = 0; i < brgemmOp.getNumOperands(); i++) {
SmallVector<int64_t> indices;
auto input = brgemmOp.getOperand(i);
auto operandType = input.getType();
SmallVector<OpFoldResult> offsets;
size_t k = 0;
auto tileItr = tileshapes[i].begin();
auto tensorShape = dyn_cast<MemRefType>(operandType).getShape();
SmallVector<int64_t> indices;
SmallVector<OpFoldResult> shape;
SmallVector<OpFoldResult> strides;

auto input = brgemmOp.getOperand(i);
auto tensorShape = dyn_cast<MemRefType>(input.getType()).getShape();
auto tileItr = tileshapes[i].begin();

// Iterates over the shape of each tensor and updates its offsets, indices,
// shapes, and strides with respect to the tile sizes
for (size_t j = 0; j < tensorShape.size(); j++) {
if (j < tensorShape.size() - tileSizesIndex[i]) {
if (j == ((tensorShape.size() - tileSizesIndex[i]) - 1) &&
i < (brgemmOp.getNumOperands() - 1)) {
offsets.push_back(reductionForLoop.getInductionVar());
indices.push_back(tensorShape[j] / tensorShape[j]);
shape.push_back(rewriter.getIndexAttr(tensorShape[j] / tensorShape[j]));
strides.push_back(rewriter.getIndexAttr(1));

} else {
offsets.push_back(rewriter.getIndexAttr(0));
indices.push_back(tensorShape[j]);
shape.push_back(rewriter.getIndexAttr(tensorShape[j]));
strides.push_back(rewriter.getIndexAttr(1));
}
} else {
shape.push_back(rewriter.getIndexAttr(*tileItr));
if (j == 0 && (i < 2)) { // Updates the batch dimension
offsets.push_back(inductionVars[i][j]);
indices.push_back(1);
shape.push_back(rewriter.getIndexAttr(1));
strides.push_back(rewriter.getIndexAttr(1));
} else if (j < 3) { // Updates the M, N, and K dimensions
offsets.push_back(inductionVars[i][j]);
indices.push_back((*tileItr));
shape.push_back(rewriter.getIndexAttr(*tileItr));
strides.push_back(rewriter.getIndexAttr(1));
offsets.push_back(
inductionVars[i][tensorShape.size() - tileSizesIndex[i] + k]);
k++;
tileItr++;
} else { // Just copies the vnni layout dimensions
offsets.push_back(rewriter.getIndexAttr(0));
indices.push_back(tensorShape[j]);
shape.push_back(rewriter.getIndexAttr(tensorShape[j]));
strides.push_back(rewriter.getIndexAttr(1));
}
}

auto subview = rewriter.create<memref::SubViewOp>(
brgemmOp.getLoc(), MemRefType(),
input, offsets, shape, strides);
brgemmOp.getLoc(), MemRefType(), input, offsets, shape, strides);
Reviewer comment (Contributor): nit: you can skip the result MemRefType and use this builder:
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value source, ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides, ArrayRef<NamedAttribute> attrs = {});
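With that builder, the call would reduce to something like the sketch below; the inferred result type should be double-checked against the current uses of the subview:

  auto subview = rewriter.create<memref::SubViewOp>(
      brgemmOp.getLoc(), input, offsets, shape, strides);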

brgemmOp.setOperand(i, subview);
}

@@ -214,11 +207,14 @@ struct LinalgOpTiling : OpRewritePattern<linalg::BatchReduceMatmulOp> {
};

void populateBrgemmLinalgTilingPatterns(RewritePatternSet &patterns,
BrgemmLinalgTilingOptions options) {
patterns.add<LinalgOpTiling>(patterns.getContext(), options);
BrgemmLinalgTilingOptions options) {
patterns.add<LinalgOpTiling<linalg::GenericOp>,
LinalgOpTiling<linalg::BatchReduceMatmulOp>>(
patterns.getContext(), options);
}

struct BrgemmLinalgTiling : public tpp::impl::BrgemmLinalgTilingBase<BrgemmLinalgTiling> {
struct BrgemmLinalgTiling
: public tpp::impl::BrgemmLinalgTilingBase<BrgemmLinalgTiling> {

using BrgemmLinalgTilingBase::BrgemmLinalgTilingBase;

30 changes: 30 additions & 0 deletions test/Integration/tile-brgemm-linalg-matmul-bf16.mlir
@@ -0,0 +1,30 @@
// RUN: tpp-run -e register_tile_bf16 --entry-point-result=void -print %s > %t.1
// RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=32,32,32" -convert-linalg-to-xsmm | tpp-run -e register_tile_bf16 --entry-point-result=void -print > %t.2
// RUN: diff %t.1 %t.2
Reviewer comment (Contributor): It might be more robust to use fpcmp.
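A sketch of that suggestion, assuming an fpcmp-style floating-point comparison tool is available in the test environment; the tolerance value is a placeholder:

// RUN: fpcmp -r 0.001 %t.1 %t.2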

// RUN: rm %t.1 %t.2
Reviewer comment (Contributor): nit: it already creates temporary files; there should be no need to explicitly delete them.


#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
module {
memref.global "private" constant @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64}
func.func @register_tile_bf16(%arg0: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> {
%cst = arith.constant 0.000000e+00 : bf16
%0 = memref.get_global @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16>
%expand_shape = memref.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [8, 32, 32, 16, 2] : memref<8x32x32x32xbf16> into memref<8x32x32x16x2xbf16>
scf.forall (%arg1, %arg2) in (8, 32) {
%subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
linalg.fill ins(%cst : bf16) outs(%subview : memref<32x32xbf16, strided<[32, 1], offset: ?>>)
%subview_0 = memref.subview %expand_shape[%arg1, 0, 0, 0, 0] [1, 32, 32, 16, 2] [1, 1, 1, 1, 1] : memref<8x32x32x16x2xbf16> to memref<32x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>>
linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%subview_0, %0 : memref<32x32x16x2xbf16, strided<[1024, 32, 2, 1], offset: ?>>, memref<32x16x32x2xbf16>) outs(%subview : memref<32x32xbf16, strided<[32, 1], offset: ?>>) {
^bb0(%in: bf16, %in_1: bf16, %out: bf16):
%1 = arith.mulf %in, %in_1 : bf16
%2 = arith.addf %out, %1 : bf16
linalg.yield %2 : bf16
}
}
return %alloc : memref<8x32x32x32xbf16>
}
}
