
Brgemm register tiling for bf16 type #1005

Open · wants to merge 4 commits into base: main

Conversation

arun-thmn
Contributor

@arun-thmn arun-thmn commented Feb 3, 2025

This PR extends the brgemm register tiling pass to support the bf16 type. The changes:

  1. Template the existing pass so it executes on linalg.batch_reduce_matmul for fp32 and on linalg.generic for VNNI-packed bf16.
  2. Add test cases for the bf16 type.

@arun-thmn arun-thmn added the benchmark-all Benchmark all targets label Feb 3, 2025
@arun-thmn arun-thmn marked this pull request as ready for review February 3, 2025 03:38
@arun-thmn
Contributor Author

@rengolin Requesting your review of this PR for bf16 register tile support. I have rewritten the tiling pass with new logic (a template and more checks) to tile both fp32 and bf16 (VNNI). If you have time, please review it as a new pass (I wrote the existing fp32 tiling right after joining Intel, with a weaker grasp of the concepts).

@arun-thmn arun-thmn added benchmark-all Benchmark all targets and removed benchmark-all Benchmark all targets labels Feb 3, 2025
Comment on lines +1 to +2
//===- BrgemmLinalgTiling.cpp -----------------------------------------*-
//C++-*-===//
Contributor

nit: tweak to fit in single line
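For example, the header could be reflowed onto one line in the usual LLVM style (dash count adjusted to fit 80 columns):

```cpp
//===- BrgemmLinalgTiling.cpp ---------------------------------*- C++ -*-===//
```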

// Check whether the tile sizes are valid
if (options.registerTileShape.size() != 3 &&
options.registerTileShape.size() != 2)
return failure();
Contributor

nit: it's easier to debug in the future with rewriter.notifyMatchFailure, which provides some feedback, instead of a bare failure()
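A minimal sketch of the suggested change on the quoted check (the diagnostic message is illustrative):

```cpp
// Check whether the tile sizes are valid.
if (options.registerTileShape.size() != 3 &&
    options.registerTileShape.size() != 2)
  return rewriter.notifyMatchFailure(
      brgemmOp, "expected 2 or 3 register tile sizes");
```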

Comment on lines +81 to +83
// Set the K tile to 1 if the user did not provide one (fp32 target)
if (options.registerTileShape.size() == 2)
mxnxkTile[2] = 1;
Contributor

I'd just force the user to always provide m,n,k tiles. It'll simplify the verification logic and make usage more explicit.

if (options.registerTileShape.size() == 2)
mxnxkTile[2] = 1;

// k-tile size adjusted based on the vnni layout for bf16 type
Contributor

This has baked-in assumptions that are not verified.
As the pass now operates on generic, we need to strictly filter ops that are accepted. I think you need to at least ensure it is a VNNI contraction first - there should be some suitable helpers in VnniUtils.
If f32 generic should be supported as well, it might need some extra checks there too.

Comment on lines +99 to +100
for (auto itrShapeMNK = mxnxkTile.begin(); itrShapeMNK != mxnxkTile.end();
itrShapeMNK++, i++) {
Contributor

nit: you could use llvm::enumerate for this, like

for (auto [idx, itrShape] : llvm::enumerate(mxnxkTile)) {

}
}

// DS to assist while creating new subviews with correct indices and shapes
SmallVector<int64_t> mxkTile(2);
Contributor

nit: you could directly brace initialize it as mxkTile = {val1, val2};

}
}

auto subview = rewriter.create<memref::SubViewOp>(
    brgemmOp.getLoc(), MemRefType(), input, offsets, shape, strides);
Contributor

nit: you can skip the result MemRefType and use this builder:
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value source, ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides, ArrayRef<NamedAttribute> attrs = {});
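A sketch of the call using that builder (assuming offsets, shape, and strides are OpFoldResult ranges, so the result type can be inferred from the source):

```cpp
auto subview = rewriter.create<memref::SubViewOp>(
    brgemmOp.getLoc(), input, offsets, shape, strides);
```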

@@ -0,0 +1,30 @@
// RUN: tpp-run -e register_tile_bf16 --entry-point-result=void -print %s > %t.1
// RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=32,32,32" -convert-linalg-to-xsmm | tpp-run -e register_tile_bf16 --entry-point-result=void -print > %t.2
// RUN: diff %t.1 %t.2
Contributor

It might be more robust to use fpcmp
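For instance, the diff line could be replaced with fpcmp from the LLVM test-suite, which compares numeric output within a tolerance (the tolerance value here is illustrative):

```
// RUN: fpcmp -r 0.001 %t.1 %t.2
```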

// RUN: tpp-run -e register_tile_bf16 --entry-point-result=void -print %s > %t.1
// RUN: tpp-opt %s --tile-brgemm-linalg="registerBlocking=32,32,32" -convert-linalg-to-xsmm | tpp-run -e register_tile_bf16 --entry-point-result=void -print > %t.2
// RUN: diff %t.1 %t.2
// RUN: rm %t.1 %t.2
Contributor

nit: lit already creates these as temporary files; there should be no need to explicitly delete them

// CONF1-LABEL: memref.global "private" constant @__constant_48x32x32xf32 : memref<48x32x32xf32> = dense<1.000000e+00> {alignment = 64 : i64}
// CONF1-LABEL: func.func @chainned_gemm_do_register_tiling(
// CONF1-SAME: %[[VAL_0:.*]]: memref<8x48x32x32xf32>) -> memref<8x48x32x32xf32> {
// CONF1: %[[VAL_1:.*]] = arith.constant 1 : index
Contributor

Could you use more descriptive names for the captured values?

Also, these checks feel too explicit; maybe you could omit some details

// Creates M, N, and K tile loops
scf::ForOp loopOp = rewriter.create<scf::ForOp>(
    brgemmOp.getLoc(), zeroCst, ubCstTiledLoop, stepCstTiledLoop);
Contributor

If I am understanding right, this transform is meant to operate on linalg ops. As I expect all the ops you want to support will implement TilingInterface, would it be possible to just use the TileUsingFor transform instead of manually implementing tiling?
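A rough sketch of what that could look like with the upstream tiling helper (assuming the op implements TilingInterface; names follow recent MLIR's scf::tileUsingSCF, and the exact API varies by version):

```cpp
scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizes(
    getAsIndexOpFoldResult(rewriter.getContext(), mxnxkTile));
FailureOr<scf::SCFTilingResult> tiled = scf::tileUsingSCF(
    rewriter, cast<TilingInterface>(brgemmOp.getOperation()), tilingOptions);
if (failed(tiled))
  return failure();
rewriter.replaceOp(brgemmOp, tiled->replacements);
```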

%0 = memref.get_global @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16>
%expand_shape = memref.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [8, 32, 32, 16, 2] : memref<8x32x32x32xbf16> into memref<8x32x32x16x2xbf16>
scf.forall (%arg1, %arg2) in (8, 32) {
Contributor

If these scf.forall are not needed by the (matcher of the) transform, can we please get rid of them? Same goes for all the unit tests in this file and any other surrounding IR that does not influence the code under test.

Labels
benchmark-all Benchmark all targets
3 participants