Brgemm register tiling for bf16 type #1005
base: main
Conversation
@rengolin Requesting your review of this PR.
//===- BrgemmLinalgTiling.cpp -----------------------------------------*- C++-*-===//
nit: tweak to fit in single line
// Check whether the tile sizes are valid
if (options.registerTileShape.size() != 3 &&
    options.registerTileShape.size() != 2)
  return failure();
nit: it's easier to debug in the future with rewriter.notifyMatchFailure
that provides some feedback instead of failure
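For example, the size check above could surface its reason like this (a sketch against the MLIR rewriter API; `brgemmOp` stands in for the matched op in the surrounding pass):

```cpp
// notifyMatchFailure records why the pattern did not apply; the message is
// visible with --debug and pattern-application diagnostics, which makes
// failures much easier to trace than a bare failure().
if (options.registerTileShape.size() != 3 &&
    options.registerTileShape.size() != 2)
  return rewriter.notifyMatchFailure(
      brgemmOp, "expected 2 (fp32) or 3 (bf16) register tile sizes");
```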
// Set the K tile to 1 if not provided by the user (i.e., an fp32 target)
if (options.registerTileShape.size() == 2)
  mxnxkTile[2] = 1;
I'd just force the user to always provide m,n,k tiles. It'll simplify the verification logic and make usage more explicit.
// k-tile size adjusted based on the VNNI layout for bf16 type
This has baked-in assumptions that are not verified.
As the pass now operates on linalg.generic, we need to strictly filter the ops that are accepted. I think you need to at least ensure it is a VNNI contraction first; there should be some suitable helpers in VnniUtils.
If an f32 generic should be supported as well, it might need some extra checks there too.
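One possible shape for that gating check (a sketch only; `isVnniContraction` is a placeholder for whatever helper VnniUtils actually provides, and `genericOp`/`elemType` are assumed from the surrounding pattern):

```cpp
// Accept only ops this pass understands: a contraction either in plain fp32
// form or in VNNI-packed bf16 form. Anything else bails out with a reason.
if (!linalg::isaContractionOpInterface(genericOp))
  return rewriter.notifyMatchFailure(genericOp, "not a contraction");
if (elemType.isBF16() && !isVnniContraction(genericOp))
  return rewriter.notifyMatchFailure(
      genericOp, "bf16 operand is not in VNNI layout");
```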
for (auto itrShapeMNK = mxnxkTile.begin(); itrShapeMNK != mxnxkTile.end();
     itrShapeMNK++, i++) {
nit: you could use llvm::enumerate for this, like
for (auto [idx, itrShape] : llvm::enumerate(mxnxkTile)) {
}
}
// Helper data structure to assist in creating new subviews with correct indices and shapes
SmallVector<int64_t> mxkTile(2);
nit: you could directly brace-initialize it, as mxkTile = {val1, val2};
auto subview = rewriter.create<memref::SubViewOp>(
    brgemmOp.getLoc(), MemRefType(), input, offsets, shape, strides);
nit: you can skip the result MemRefType and use this builder:
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value source, ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides, ArrayRef<NamedAttribute> attrs = {});
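With that builder the result type is inferred from the source memref and the offsets/sizes/strides, so the call collapses to (sketch, operand names as in the snippet above):

```cpp
// No explicit MemRefType needed; SubViewOp infers the (possibly
// rank-reduced) result type from the source and the static/dynamic
// offsets, sizes, and strides.
auto subview = rewriter.create<memref::SubViewOp>(
    brgemmOp.getLoc(), input, offsets, shape, strides);
```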
@@ -0,0 +1,30 @@
// RUN: tpp-run -e register_tile_bf16 --entry-point-result=void -print %s > %t.1
// RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=32,32,32" -convert-linalg-to-xsmm | tpp-run -e register_tile_bf16 --entry-point-result=void -print > %t.2
// RUN: diff %t.1 %t.2
It might be more robust to use fpcmp
// RUN: tpp-run -e register_tile_bf16 --entry-point-result=void -print %s > %t.1
// RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=32,32,32" -convert-linalg-to-xsmm | tpp-run -e register_tile_bf16 --entry-point-result=void -print > %t.2
// RUN: diff %t.1 %t.2
// RUN: rm %t.1 %t.2
nit: the RUN lines already create temporary files; there should be no need to explicitly delete them.
// CONF1-LABEL: memref.global "private" constant @__constant_48x32x32xf32 : memref<48x32x32xf32> = dense<1.000000e+00> {alignment = 64 : i64}
// CONF1-LABEL: func.func @chainned_gemm_do_register_tiling(
// CONF1-SAME: %[[VAL_0:.*]]: memref<8x48x32x32xf32>) -> memref<8x48x32x32xf32> {
// CONF1: %[[VAL_1:.*]] = arith.constant 1 : index
Could you use more descriptive names for the captured values?
Also, these checks feel too explicit; maybe you could omit some details.
// Creates M, N, and K tile loops
scf::ForOp loopOp = rewriter.create<scf::ForOp>(
    brgemmOp.getLoc(), zeroCst, ubCstTiledLoop, stepCstTiledLoop);
If I am understanding right, this transform is meant to operate on linalg ops. As I expect all the ops you want to support will implement TilingInterface, would it be possible to just use the TileUsingFor transform instead of manually implementing tiling?
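Roughly, that could look like the following (a sketch against the upstream `scf::tileUsingSCF` driver; exact signatures vary across MLIR versions, and `brgemmOp`/`options` are assumed from the surrounding pass):

```cpp
// Drive tiling through TilingInterface instead of hand-rolling the loops.
// The register tile sizes become the static tile sizes of the driver.
scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizes(getAsIndexOpFoldResult(
    rewriter.getContext(), options.registerTileShape));
FailureOr<scf::SCFTilingResult> tiled = scf::tileUsingSCF(
    rewriter, cast<TilingInterface>(brgemmOp.getOperation()), tilingOptions);
if (failed(tiled))
  return failure();
rewriter.replaceOp(brgemmOp, tiled->replacements);
```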
%0 = memref.get_global @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16>
%expand_shape = memref.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [8, 32, 32, 16, 2] : memref<8x32x32x32xbf16> into memref<8x32x32x16x2xbf16>
scf.forall (%arg1, %arg2) in (8, 32) {
If these scf.forall ops are not needed by the (matcher of the) transform, can we please get rid of them? The same goes for all the unit tests in this file and any other surrounding IR that does not influence the code under test.
This PR extends the brgemm register tiling pass to support the bf16 type. The changes handle linalg.batch_reduce_matmul for fp32 and linalg.generic for VNNI-packed bf16.