first commit

newling committed Jun 27, 2024
1 parent 3d2bd60 commit 82facf6

Showing 2 changed files with 104 additions and 48 deletions.
File 1 of 2: the C++ pass (AMDAIEInsertLoopsForVectorizationPass).

@@ -7,6 +7,7 @@
#include "iree-amd-aie/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "iree-amd-aie/Transforms/AMDAIEUtils.h"

namespace mlir::iree_compiler::AMDAIE {

@@ -74,6 +75,22 @@ class AMDAIEInsertLoopsForVectorizationPass
// Matmul-like ops have 3 operands.
if (genericOp->getNumOperands() != 3) return failure();

+// Check that the operands and result have element types that are
+// vectorizable on AIE; if they are not, do not tile.
+auto hasAieVectorizableTypes = [genericOp]() -> bool {
+auto elType = [](Value v) {
+return cast<ShapedType>(v.getType()).getElementType();
+};
+auto lhsType = elType(genericOp->getOperand(0));
+auto rhsType = elType(genericOp->getOperand(1));
+auto resType = elType(genericOp->getResult(0));
+FailureOr<std::array<uint32_t, 3>> maybeSize =
+::mlir::iree_compiler::AMDAIE::getAIEMatmulInstructionSize(
+lhsType, rhsType, resType);
+return !failed(maybeSize);
+}();
+if (!hasAieVectorizableTypes) return failure();
+
// Don't transform to scf.for loops unless there is at least one
// non-singleton loop to construct. This isn't strictly necessary, but
// avoids generating a bunch of loops of size 1.
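The remainder of the pass is collapsed out of this hunk, including the guard described in the comment above. A minimal sketch of such a guard, assuming static loop ranges and that the three innermost loops are the matmul M, N and K (an illustration, not the commit's verbatim code):

// Only rewrite when some loop outside the innermost three (matmul)
// dimensions has extent greater than 1; otherwise every scf.for produced
// by the rewrite would iterate exactly once.
SmallVector<int64_t> loopRanges = genericOp.getStaticLoopRanges();
bool hasNonSingletonOuterLoop =
    llvm::any_of(ArrayRef<int64_t>(loopRanges).drop_back(3),
                 [](int64_t extent) { return extent != 1; });
if (!hasNonSingletonOuterLoop) return failure();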
File 2 of 2: the MLIR FileCheck test for iree-amdaie-insert-loops-for-vectorization.
@@ -1,28 +1,36 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-amdaie-insert-loops-for-vectorization))" %s | FileCheck %s

-!t2 = tensor<64x64xf32>
-!t3 = tensor<64x64x64xf32>
-!t4 = tensor<64x64x64x64xf32>
+!t2_bf16 = tensor<64x64xbf16>
+!t3_bf16 = tensor<64x64x64xbf16>
+!t4_bf16 = tensor<64x64x64x64xbf16>
+
+!t2_f32 = tensor<64x64xf32>
+!t3_f32 = tensor<64x64x64xf32>
+!t4_f32 = tensor<64x64x64x64xf32>
+
+
module {
// A generic that corresponds to a simple matmul (2 rank-2 operands)
// does NOT get tiled.
// CHECK-LABEL: vanilla
// CHECK-NOT: scf.for
-func.func @vanilla(%arg0: !t2, %arg1: !t2, %arg2: !t2) -> !t2 {
+func.func @vanilla(%arg0: !t2_bf16, %arg1: !t2_bf16, %arg2: !t2_f32) -> !t2_f32 {
%0 = linalg.generic {indexing_maps =
[
affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d2, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1)>
],
iterator_types = ["parallel", "parallel", "reduction"]}
-ins(%arg0, %arg1 : !t2, !t2) outs(%arg2 : !t2) {
-^bb0(%in: f32, %in_0: f32, %out: f32):
-%1 = arith.mulf %in, %in_0 : f32
+ins(%arg0, %arg1 : !t2_bf16, !t2_bf16) outs(%arg2 : !t2_f32) {
+^bb0(%in_0_bf16: bf16, %in_1_bf16: bf16, %out: f32):
+%in_0 = arith.extf %in_0_bf16 : bf16 to f32
+%in_1 = arith.extf %in_1_bf16 : bf16 to f32
+%1 = arith.mulf %in_0, %in_1 : f32
%2 = arith.addf %out, %1 : f32
linalg.yield %2 : f32
-} -> !t2
-return %0 : !t2
+} -> !t2_f32
+return %0 : !t2_f32
}

// A batched matmul gets the batch dimension converted to a single scf.for
@@ -31,25 +39,48 @@ module {
// CHECK: linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins
-// CHECK-SAME: tensor<1x64x64xf32>, tensor<1x64x64xf32>
+// CHECK-SAME: tensor<1x64x64xbf16>, tensor<1x64x64xbf16>
// CHECK-SAME: outs
// CHECK-SAME: tensor<1x64x64xf32>
// CHECK-NOT: scf.for
-func.func @batched0(%arg0: !t3, %arg1: !t3, %arg2: !t3) -> !t3 {
+func.func @batched0(%arg0: !t3_bf16, %arg1: !t3_bf16, %arg2: !t3_f32) -> !t3_f32 {
%0 = linalg.generic {indexing_maps =
[
affine_map<(b0, d0, d1, d2) -> (b0, d0, d2)>,
affine_map<(b0, d0, d1, d2) -> (b0, d2, d1)>,
affine_map<(b0, d0, d1, d2) -> (b0, d0, d1)>
],
iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+ins(%arg0, %arg1 : !t3_bf16, !t3_bf16) outs(%arg2 : !t3_f32) {
+^bb0(%in_0_bf16: bf16, %in_1_bf16: bf16, %out: f32):
+%in_0 = arith.extf %in_0_bf16 : bf16 to f32
+%in_1 = arith.extf %in_1_bf16 : bf16 to f32
+%1 = arith.mulf %in_0, %in_1 : f32
+%2 = arith.addf %out, %1 : f32
+linalg.yield %2 : f32
+} -> !t3_f32
+return %0 : !t3_f32
+}
+
+// A batched matmul whose element types are not supported for
+// vectorization on AIE does not get tiled:
+// CHECK-LABEL: batched_bad_element_types
+// CHECK-NOT: scf.for
+func.func @batched_bad_element_types(%arg0: !t3_f32, %arg1: !t3_f32, %arg2: !t3_f32) -> !t3_f32 {
+%0 = linalg.generic {indexing_maps =
+[
+affine_map<(b0, d0, d1, d2) -> (b0, d0, d2)>,
+affine_map<(b0, d0, d1, d2) -> (b0, d2, d1)>,
+affine_map<(b0, d0, d1, d2) -> (b0, d0, d1)>
+],
+iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-ins(%arg0, %arg1 : !t3, !t3) outs(%arg2 : !t3) {
-^bb0(%in: f32, %in_0: f32, %out: f32):
-%1 = arith.mulf %in, %in_0 : f32
+ins(%arg0, %arg1 : !t3_f32, !t3_f32) outs(%arg2 : !t3_f32) {
+^bb0(%in_0: f32, %in_1: f32, %out: f32):
+%1 = arith.mulf %in_0, %in_1 : f32
%2 = arith.addf %out, %1 : f32
linalg.yield %2 : f32
-} -> !t3
-return %0 : !t3
+} -> !t3_f32
+return %0 : !t3_f32
}
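Together with batched0 above, this new test pins down the element-type gate added in the C++ hunk: bf16 operands accumulating into f32 are accepted, while an all-f32 matmul is rejected. Restated against the helper the pass now calls (a sketch only; ctx is an assumed MLIRContext*, and the assertions merely encode what the two tests check):

// What the two tests imply about supported element-type triples.
Builder b(ctx);
assert(succeeded(getAIEMatmulInstructionSize(
    b.getBF16Type(), b.getBF16Type(), b.getF32Type())));  // batched0: tiled
assert(failed(getAIEMatmulInstructionSize(
    b.getF32Type(), b.getF32Type(), b.getF32Type())));    // this test: not tiled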

// A test like the above, but with a matmul_transpose_b instead of a matmul
@@ -58,25 +89,27 @@ module {
// CHECK: linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins
-// CHECK-SAME: tensor<1x64x64xf32>, tensor<1x64x64xf32>
+// CHECK-SAME: tensor<1x64x64xbf16>, tensor<1x64x64xbf16>
// CHECK-SAME: outs
// CHECK-SAME: tensor<1x64x64xf32>
// CHECK-NOT: scf.for
-func.func @batched_transpose_b(%arg0: !t3, %arg1: !t3, %arg2: !t3) -> !t3 {
+func.func @batched_transpose_b(%arg0: !t3_bf16, %arg1: !t3_bf16, %arg2: !t3_f32) -> !t3_f32 {
%0 = linalg.generic {indexing_maps =
[
affine_map<(b0, d0, d1, d2) -> (b0, d0, d2)>,
affine_map<(b0, d0, d1, d2) -> (b0, d1, d2)>,
affine_map<(b0, d0, d1, d2) -> (b0, d0, d1)>
],
iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-ins(%arg0, %arg1 : !t3, !t3) outs(%arg2 : !t3) {
-^bb0(%in: f32, %in_0: f32, %out: f32):
-%1 = arith.mulf %in, %in_0 : f32
+ins(%arg0, %arg1 : !t3_bf16, !t3_bf16) outs(%arg2 : !t3_f32) {
+^bb0(%in_0_bf16: bf16, %in_1_bf16: bf16, %out: f32):
+%in_0 = arith.extf %in_0_bf16 : bf16 to f32
+%in_1 = arith.extf %in_1_bf16 : bf16 to f32
+%1 = arith.mulf %in_0, %in_1 : f32
%2 = arith.addf %out, %1 : f32
linalg.yield %2 : f32
-} -> !t3
-return %0 : !t3
+} -> !t3_f32
+return %0 : !t3_f32
}

// Another test with a transposed matmul, but in this case A is transposed.
@@ -85,46 +118,48 @@ module {
// CHECK-LABEL: batched_transpose_a
// CHECK-NOT: scf.for

-func.func @batched_transpose_a(%arg0: !t3, %arg1: !t3, %arg2: !t3) -> !t3 {
+func.func @batched_transpose_a(%arg0: !t3_bf16, %arg1: !t3_bf16, %arg2: !t3_f32) -> !t3_f32 {
%0 = linalg.generic {indexing_maps =
[
affine_map<(b0, d0, d1, d2) -> (b0, d2, d1)>,
affine_map<(b0, d0, d1, d2) -> (b0, d2, d0)>,
affine_map<(b0, d0, d1, d2) -> (b0, d0, d1)>
],
iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-ins(%arg0, %arg1 : !t3, !t3) outs(%arg2 : !t3) {
-^bb0(%in: f32, %in_0: f32, %out: f32):
-%1 = arith.mulf %in, %in_0 : f32
+ins(%arg0, %arg1 : !t3_bf16, !t3_bf16) outs(%arg2 : !t3_f32) {
+^bb0(%in_0_bf16: bf16, %in_1_bf16: bf16, %out: f32):
+%in_0 = arith.extf %in_0_bf16 : bf16 to f32
+%in_1 = arith.extf %in_1_bf16 : bf16 to f32
+%1 = arith.mulf %in_0, %in_1 : f32
%2 = arith.addf %out, %1 : f32
linalg.yield %2 : f32
-} -> !t3
-return %0 : !t3
+} -> !t3_f32
+return %0 : !t3_f32
}


// A check that a linalg.generic whose number of operands is not 3 does not
// get transformed to use an scf.for.
// CHECK-LABEL: funcWithTwoOperands
// CHECK-NOT: scf.for
-func.func @funcWithTwoOperands(%arg0: !t4, %arg1: !t4) -> !t4 {
+func.func @funcWithTwoOperands(%arg0: !t4_bf16, %arg1: !t4_bf16) -> !t4_bf16 {
%0 = linalg.generic {indexing_maps =
[
affine_map<(b0, d0, d1, d2) -> (b0, d0, d1, d2)>,
affine_map<(b0, d0, d1, d2) -> (d0, d1, d2, b0)>
],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-ins(%arg0 : !t4) outs(%arg1 : !t4) {
-^bb0(%in: f32, %out: f32):
-linalg.yield %in : f32
-} -> !t4
-return %0 : !t4
+ins(%arg0 : !t4_bf16) outs(%arg1 : !t4_bf16) {
+^bb0(%in: bf16, %out: bf16):
+linalg.yield %in : bf16
+} -> !t4_bf16
+return %0 : !t4_bf16
}

// Check that the final 3 dimensions must have the pattern of a matmul (or matmul transpose); here they do not, so the op is not tiled.
// CHECK-LABEL: batched1
// CHECK-NOT: scf.for
-func.func @batched1(%arg0: !t3, %arg1: !t3, %arg2: !t3) -> !t3 {
+func.func @batched1(%arg0: !t3_bf16, %arg1: !t3_bf16, %arg2: !t3_f32) -> !t3_f32 {
%0 = linalg.generic {indexing_maps =
[
// This is like a matmul but the first operand is
@@ -134,13 +169,15 @@ module {
affine_map<(b0, d0, d1, d2) -> (b0, d0, d1)>
],
iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-ins(%arg0, %arg1 : !t3, !t3) outs(%arg2 : !t3) {
-^bb0(%in: f32, %in_0: f32, %out: f32):
-%1 = arith.mulf %in, %in_0 : f32
+ins(%arg0, %arg1 : !t3_bf16, !t3_bf16) outs(%arg2 : !t3_f32) {
+^bb0(%in_0_bf16: bf16, %in_1_bf16: bf16, %out: f32):
+%in_0 = arith.extf %in_0_bf16 : bf16 to f32
+%in_1 = arith.extf %in_1_bf16 : bf16 to f32
+%1 = arith.mulf %in_0, %in_1 : f32
%2 = arith.addf %out, %1 : f32
linalg.yield %2 : f32
-} -> !t3
-return %0 : !t3
+} -> !t3_f32
+return %0 : !t3_f32
}
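The structural requirement itself, that the trailing three loops form a matmul, is enforced by pass code outside the visible hunks. One generic way to test for contraction structure, shown purely as an illustration and not necessarily what the pass does, is upstream MLIR's contraction-dimension inference:

// Recognize matmul-like structure from the indexing maps alone;
// inferContractionDims returns the batch/M/N/K loop indices on success,
// or failure when the maps do not form a contraction.
auto linalgOp = cast<linalg::LinalgOp>(genericOp.getOperation());
FailureOr<linalg::ContractionDimensions> maybeDims =
    linalg::inferContractionDims(linalgOp);
if (failed(maybeDims)) return failure();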

// Check for a batched matmul where operand 0 is broadcast:
@@ -149,25 +186,27 @@ module {
// CHECK: linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins
-// CHECK-SAME: tensor<64x64xf32>, tensor<1x64x64xf32>
+// CHECK-SAME: tensor<64x64xbf16>, tensor<1x64x64xbf16>
// CHECK-SAME: outs
// CHECK-SAME: tensor<1x64x64xf32>
// CHECK-NOT: scf.for
-func.func @batched2(%arg0: !t2, %arg1: !t3, %arg2: !t3) -> !t3 {
+func.func @batched2(%arg0: !t2_bf16, %arg1: !t3_bf16, %arg2: !t3_f32) -> !t3_f32 {
%0 = linalg.generic {indexing_maps =
[
affine_map<(b0, d0, d1, d2) -> (d0, d2)>,
affine_map<(b0, d0, d1, d2) -> (b0, d2, d1)>,
affine_map<(b0, d0, d1, d2) -> (b0, d0, d1)>
],
iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-ins(%arg0, %arg1 : !t2, !t3) outs(%arg2 : !t3) {
-^bb0(%in: f32, %in_0: f32, %out: f32):
-%1 = arith.mulf %in, %in_0 : f32
+ins(%arg0, %arg1 : !t2_bf16, !t3_bf16) outs(%arg2 : !t3_f32) {
+^bb0(%in_0_bf16: bf16, %in_1_bf16: bf16, %out: f32):
+%in_0 = arith.extf %in_0_bf16 : bf16 to f32
+%in_1 = arith.extf %in_1_bf16 : bf16 to f32
+%1 = arith.mulf %in_0, %in_1 : f32
%2 = arith.addf %out, %1 : f32
linalg.yield %2 : f32
-} -> !t3
-return %0 : !t3
+} -> !t3_f32
+return %0 : !t3_f32
}

// A function which arises from the pack-based pipeline in iree-amd-aie,