Skip to content

Commit

Permalink
Adapt simple pack pipeline to work with large GEMM (nod-ai#141)
Browse files Browse the repository at this point in the history
With this PR, matmul size 2048x2048x512 can be compiled and run
correctly on hardware. 

Co-authored-by: erweiw <[email protected]>
  • Loading branch information
yzhang93 and erwei-xilinx authored Feb 13, 2024
1 parent 3d60e43 commit 53abdcc
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 46 deletions.
1 change: 1 addition & 0 deletions compiler/plugins/target/AMD-AIE/air/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ iree_cc_library(
"TransformPasses.cpp"
"${IREE_MLIR_AIR_SOURCE_DIR}/lib/Transform/AIRHerdPlacementPass.cpp"
"${IREE_MLIR_AIR_SOURCE_DIR}/lib/Transform/AIRLinalgCodegen.cpp"
"${IREE_MLIR_AIR_SOURCE_DIR}/lib/Transform/AffineLoopOptPass.cpp"
"${IREE_MLIR_AIR_SOURCE_DIR}/lib/Transform/AIRMiscPasses.cpp"
"${IREE_MLIR_AIR_SOURCE_DIR}/lib/Transform/AIRDependency.cpp"
"${IREE_MLIR_AIR_SOURCE_DIR}/lib/Transform/AIRDependencyScheduleOpt.cpp"
Expand Down
2 changes: 2 additions & 0 deletions compiler/plugins/target/AMD-AIE/air/TransformPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,7 @@ void registerAIRTransformPasses() {
registerAIRSpecializeDmaBroadcast();
registerAIRUnrollLoopForPipeliningPattern();
registerAIRCollapseHerdPass();
registerAIRUnrollOuterPerfectlyNestedLoopsPass();
registerAffineLoopOptPass();
}
} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ struct PackConfig {
SmallVector<SmallVector<int64_t>> outerPerm;
};

static FailureOr<PackConfig> getPackConfig(RewriterBase &rewriter,
int packLevel,
AIEPassPipeline passPipeline) {
static FailureOr<PackConfig> getPackConfig(
RewriterBase &rewriter, int packLevel, AIEPassPipeline passPipeline,
IREE::Codegen::LoweringConfigAttr lowerConfig, int64_t kSize) {
PackConfig config;
if (packLevel == 0) {
// packed size for [M, N, K]
Expand All @@ -40,9 +40,18 @@ static FailureOr<PackConfig> getPackConfig(RewriterBase &rewriter,
rewriter.getI64IntegerAttr(64)};

} else if (passPipeline == AIEPassPipeline::SimplePackPipeline) {
config.packedSizes = {rewriter.getI64IntegerAttr(8),
rewriter.getI64IntegerAttr(16),
rewriter.getI64IntegerAttr(16)};
// Set constraints for pack size [M, N] from first level of tile sizes
// Currently set pack size k as the input size K to avoid failure.
int64_t tileM = 64;
int64_t tileN = 64;
if (lowerConfig) {
auto tileSizes = lowerConfig.getTilingLevels()[0].getSizes();
tileM = tileSizes[0];
tileN = tileSizes[1];
}
config.packedSizes = {rewriter.getI64IntegerAttr(tileM),
rewriter.getI64IntegerAttr(tileN),
rewriter.getI64IntegerAttr(kSize)};
} else {
return failure();
}
Expand Down Expand Up @@ -133,8 +142,10 @@ void AMDAIEPackAndTransposePass::runOnOperation() {

// Step 2. Pack the operation
IRRewriter rewriter(context);
auto lhsType = linalgOp->getOperand(0).getType();
int64_t kSize = llvm::cast<ShapedType>(lhsType).getShape()[1];
FailureOr<PackConfig> packCfg =
getPackConfig(rewriter, packLevel, usePassPipeline);
getPackConfig(rewriter, packLevel, usePassPipeline, config, kSize);
if (failed(packCfg)) {
funcOp->emitOpError("failed to get pack configs");
return signalPassFailure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,57 @@

namespace mlir::iree_compiler::AMDAIE {

/// Sets the lowering configuration (tile sizes + translation info) on the
/// matmul for the pad-based pass pipeline. The tile sizes are fixed:
/// level 0 tiles [M, N] = [8, 8], level 1 tiles [M, N] = [4, 4], and
/// level 2 tiles the reduction dimension K = 4.
static LogicalResult setRootConfigForPadPipeline(func::FuncOp entryPointFn,
                                                 linalg::MatmulOp matmulOp) {
  TileSizesListType tileSizes = {/*level 0 [M, N]=*/{8, 8},
                                 /*level 1 [M, N]=*/{4, 4},
                                 /*level 2 [M, N, K]=*/{0, 0, 4}};
  return setOpConfigAndEntryPointFnTranslation(
      entryPointFn, matmulOp, tileSizes,
      IREE::Codegen::DispatchLoweringPassPipeline::None);
}

/// Sets the lowering configuration (tile sizes + translation info) on the
/// matmul for the simple-pack pass pipeline.
///
/// Assume working on a 2x2 AIE array and make sure the tile size is not
/// larger than the input size: level 0 tiles [M, N] to at most 64, level 1
/// halves those tiles (floored at 1) for the 2x2 core split, and level 2
/// tiles the reduction dimension to min(K / 8, 4).
///
/// Shape dimensions are int64_t; keep the arithmetic in int64_t (instead of
/// the previous C-style (int) casts) so dimensions that do not fit in a
/// 32-bit int cannot silently truncate.
static LogicalResult setRootConfigForSimplePackPipeline(
    func::FuncOp entryPointFn, linalg::MatmulOp matmulOp) {
  auto initType = matmulOp.getDpsInitOperand(0)->get().getType();
  auto initShape = llvm::cast<ShapedType>(initType).getShape();
  // First level: whole-array tile, capped by the output shape [M, N].
  int64_t tileM0 = std::min<int64_t>(initShape[0], 64);
  int64_t tileN0 = std::min<int64_t>(initShape[1], 64);
  // Second level: split across the 2x2 core array; never tile below 1.
  int64_t tileM1 = std::max<int64_t>(tileM0 / 2, 1);
  int64_t tileN1 = std::max<int64_t>(tileN0 / 2, 1);
  auto lhsType = matmulOp.getDpsInputOperand(0)->get().getType();
  auto lhsShape = llvm::cast<ShapedType>(lhsType).getShape();
  // Third level: reduction tile derived from K (lhs dim 1), capped at 4.
  int64_t tileK = std::min<int64_t>(lhsShape[1] / 8, 4);

  SmallVector<int64_t> TileSizeLevel0 = {tileM0, tileN0};
  SmallVector<int64_t> TileSizeLevel1 = {0, 0, 0, tileM1, tileN1};
  SmallVector<int64_t> TileSizeLevel2 = {0, 0, 0, 0, 0, tileK};
  TileSizesListType tileSizes = {TileSizeLevel0, TileSizeLevel1,
                                 TileSizeLevel2};
  return setOpConfigAndEntryPointFnTranslation(
      entryPointFn, matmulOp, tileSizes,
      IREE::Codegen::DispatchLoweringPassPipeline::None);
}

/// Sets the lowering configuration (tile sizes + translation info) on the
/// matmul for the pack-based pass pipeline. The N tile scales with the
/// number of cores; only 1, 2, or 4 cores are supported.
static LogicalResult setRootConfigForPackPipeline(func::FuncOp entryPointFn,
                                                  linalg::MatmulOp matmulOp,
                                                  AIEConfig cfg) {
  const int numCores = cfg.num_cores;
  if (numCores != 1 && numCores != 2 && numCores != 4)
    return matmulOp.emitOpError("unhandled number of cores");
  TileSizesListType tileSizes = {/*level 0 [M, N]=*/{16, 64 * numCores},
                                 /*level 1 [M, N, K]=*/{0, 0, 64},
                                 /*level 2 [M, N]=*/{1, 1}};
  return setOpConfigAndEntryPointFnTranslation(
      entryPointFn, matmulOp, tileSizes,
      IREE::Codegen::DispatchLoweringPassPipeline::None);
}

/// Sets the lowering configuration for dispatch region with root op that
/// implements the contraction operation interface.
static LogicalResult setRootConfig(func::FuncOp entryPointFn,
Expand All @@ -37,36 +88,12 @@ static LogicalResult setRootConfig(func::FuncOp entryPointFn,
// TODO (nmeshram) : This needs to be moved in a separate more generalized
// logic. Also, need a flag to experiment between pad based and pack based
// approach which will have different tile sizes and pass pipelines
if (usePassPipeline == AIEPassPipeline::PadPipeline) {
SmallVector<int64_t> TileSizeLevel0 = {8, 8};
SmallVector<int64_t> TileSizeLevel1 = {4, 4};
SmallVector<int64_t> TileSizeLevel2 = {0, 0, 4};
TileSizesListType tileSizes = {TileSizeLevel0, TileSizeLevel1,
TileSizeLevel2};
return setOpConfigAndEntryPointFnTranslation(
entryPointFn, matmulOp, tileSizes,
IREE::Codegen::DispatchLoweringPassPipeline::None);
} else if (usePassPipeline == AIEPassPipeline::SimplePackPipeline) {
SmallVector<int64_t> TileSizeLevel0 = {8, 16};
SmallVector<int64_t> TileSizeLevel1 = {1, 1};
SmallVector<int64_t> TileSizeLevel2 = {0, 0, 1};
TileSizesListType tileSizes = {TileSizeLevel0, TileSizeLevel1,
TileSizeLevel2};
return setOpConfigAndEntryPointFnTranslation(
entryPointFn, matmulOp, tileSizes,
IREE::Codegen::DispatchLoweringPassPipeline::None);
} else if (usePassPipeline == AIEPassPipeline::PackPipeline) {
if (!(cfg.num_cores == 1 || cfg.num_cores == 2 || cfg.num_cores == 4))
return matmulOp.emitOpError("unhandled number of cores");
SmallVector<int64_t> TileSizeLevel0 = {16, 64 * cfg.num_cores};
SmallVector<int64_t> TileSizeLevel1 = {0, 0, 64};
SmallVector<int64_t> TileSizeLevel2 = {1, 1};
TileSizesListType tileSizes = {TileSizeLevel0, TileSizeLevel1,
TileSizeLevel2};
return setOpConfigAndEntryPointFnTranslation(
entryPointFn, matmulOp, tileSizes,
IREE::Codegen::DispatchLoweringPassPipeline::None);
}
if (usePassPipeline == AIEPassPipeline::PadPipeline)
return setRootConfigForPadPipeline(entryPointFn, matmulOp);
if (usePassPipeline == AIEPassPipeline::SimplePackPipeline)
return setRootConfigForSimplePackPipeline(entryPointFn, matmulOp);
if (usePassPipeline == AIEPassPipeline::PackPipeline)
return setRootConfigForPackPipeline(entryPointFn, matmulOp, cfg);
return matmulOp.emitOpError("unhandled pass pipeline");
}

Expand Down
34 changes: 32 additions & 2 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "iree/compiler/Codegen/Common/CPU/Passes.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Common/Transforms.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
Expand Down Expand Up @@ -168,13 +169,13 @@ void addSimplePackBasedPassPipeline(OpPassManager &pm,
modulePassManager.addPass(createCanonicalizerPass());
modulePassManager.addPass(createCSEPass());

// Second level packing and bufferize to allocation
// Second level packing and only promote the result to local memory
packOptions.packLevel = 1;
packOptions.usePassPipeline = AIEPassPipeline::SimplePackPipeline;
modulePassManager.addNestedPass<func::FuncOp>(
createAMDAIEPackAndTransposePass(packOptions));
bufferizeOptions.memorySpace = 2;
bufferizeOptions.bufferizeLevel = -1;
bufferizeOptions.bufferizeLevel = 1;
modulePassManager.addNestedPass<func::FuncOp>(
createAMDAIEBufferizeToAllocationPass(bufferizeOptions));

Expand All @@ -187,6 +188,19 @@ void addSimplePackBasedPassPipeline(OpPassManager &pm,
modulePassManager.addPass(createCanonicalizerPass());
modulePassManager.addPass(createCSEPass());

// Fuse pack ops into for loop
modulePassManager.addNestedPass<func::FuncOp>(
createAMDAIEFusePackIntoForLoopPass());
modulePassManager.addNestedPass<func::FuncOp>(createAMDAIECleanupPass());
modulePassManager.addPass(createCanonicalizerPass());
modulePassManager.addPass(createCSEPass());

// Promote the inputs to local memory
bufferizeOptions.memorySpace = 2;
bufferizeOptions.bufferizeLevel = 2;
modulePassManager.addNestedPass<func::FuncOp>(
createAMDAIEBufferizeToAllocationPass(bufferizeOptions));

// Comprehensive bufferization
addAMDAIEBufferizePasses(modulePassManager);
modulePassManager.addPass(memref::createFoldMemRefAliasOpsPass());
Expand Down Expand Up @@ -386,7 +400,23 @@ void addMLIRAIRAIELoweringPasses(OpPassManager &passManager) {
options.clEmitWhileLoop = true;
passManager.addPass(xilinx::air::createAIRToAIEPass(options));
}
passManager.addPass(createCanonicalizerPass());
passManager.addPass(xilinx::air::createAIRLoweringPass());
{
xilinx::air::AffineLoopOptPassOptions options;
const std::vector<unsigned> tile_sizes = {4, 4};
options.clTileSizes = ArrayRef(tile_sizes);
passManager.addNestedPass<func::FuncOp>(
xilinx::air::createAffineLoopOptPass(options));
}
{
xilinx::air::AIRUnrollOuterPerfectlyNestedLoopsPassOptions options;
options.clDepth = 2;
passManager.addNestedPass<func::FuncOp>(
xilinx::air::createAIRUnrollOuterPerfectlyNestedLoopsPass(options));
}
passManager.addPass(mlir::affine::createAffineExpandIndexOpsPass());

passManager.addPass(xilinx::airrt::createAIRRtToIpuPass());
passManager.addPass(createCanonicalizerPass());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ func.func @matmul_example_dispatch_0_matmul_16x256x256_i8xi8xi32(%arg0 : tensor<
%c0_i32 = arith.constant 0 : i32
%5 = tensor.empty() : tensor<16x256xi32>
%6 = linalg.fill ins(%c0_i32 : i32) outs(%5 : tensor<16x256xi32>) -> tensor<16x256xi32>
// CHECK: tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 16] into %{{.*}} : tensor<16x256xi8> -> tensor<2x16x8x16xi8>
// CHECK: tensor.pack %{{.*}} outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %{{.*}} : tensor<256x256xi8> -> tensor<16x16x16x16xi8>
// CHECK: tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 16] into %{{.*}} : tensor<16x256xi32> -> tensor<2x16x8x16xi32>
// CHECK: tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 256] into %{{.*}} : tensor<16x256xi8> -> tensor<2x1x8x256xi8>
// CHECK: tensor.pack %{{.*}} outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [256, 8] into %{{.*}} : tensor<256x256xi8> -> tensor<1x32x256x8xi8>
// CHECK: tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %{{.*}} : tensor<16x256xi32> -> tensor<2x32x8x8xi32>
// CHECK: linalg.generic
// CHECK-SAME: attrs = {lowering_config = #config}
%7 = linalg.matmul {lowering_config = #config} ins(%arg0, %arg1 : tensor<16x256xi8>, tensor<256x256xi8>) outs(%6 : tensor<16x256xi32>) -> tensor<16x256xi32>
Expand Down
30 changes: 26 additions & 4 deletions tests/samples/simple_pack_pipeline_e2e.mlir
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
// RUN: iree-compile --iree-hal-target-backends=amd-aie --compile-to=executable-sources %s | iree-opt --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-hal-translate-target-executable-variants{target=amd-aie})))" | FileCheck %s --check-prefix=CPP
// RUN: iree-compile --iree-hal-target-backends=amd-aie --compile-to=executable-sources %s | iree-opt --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-hal-translate-target-executable-variants{target=amd-aie})))" --split-input-file | FileCheck %s --check-prefix=CPP

// This test demonstrates Pack pipeline based e2e lowering.

// To check the cpp path equivalent to the transform dialect script.
// CPP-LABEL: hal.executable.export public @matmul_static_dispatch_0_matmul_8x32x16_i32
// CPP-LABEL: hal.executable.export public @matmul_small_dispatch_0_matmul_8x32x16_i32
// CPP: aie.device(ipu)
// CPP: aie.shim_dma_allocation
// CPP: aie.shim_dma_allocation
// CPP: aie.shim_dma_allocation
// CPP: func.func @matmul_static_dispatch_0_matmul_8x32x16_i32(%arg0: memref<8x16xi32>, %arg1: memref<16x32xi32>, %arg2: memref<8x32xi32>)
// CPP: func.func @matmul_small_dispatch_0_matmul_8x32x16_i32(%arg0: memref<8x16xi32>, %arg1: memref<16x32xi32>, %arg2: memref<8x32xi32>)
// CPP: aiex.ipu.dma_memcpy_nd
// CPP: aiex.ipu.dma_memcpy_nd
// CPP: aiex.ipu.dma_memcpy_nd
// CPP: aiex.ipu.sync
func.func @matmul_static(%lhs : tensor<8x16xi32>,
func.func @matmul_small(%lhs : tensor<8x16xi32>,
%rhs : tensor<16x32xi32>) -> tensor<8x32xi32> {
%empty = tensor.empty() : tensor<8x32xi32>
%cst = arith.constant 0 : i32
Expand All @@ -22,3 +22,25 @@ func.func @matmul_static(%lhs : tensor<8x16xi32>,
outs(%fill : tensor<8x32xi32>) -> tensor<8x32xi32>
return %2 : tensor<8x32xi32>
}

// -----

// CPP-LABEL: hal.executable.export public @matmul_large_dispatch_0_matmul_2048x2048x512_i32
// CPP: aie.device(ipu)
// CPP: aie.shim_dma_allocation
// CPP: aie.shim_dma_allocation
// CPP: aie.shim_dma_allocation
// CPP: func.func @matmul_large_dispatch_0_matmul_2048x2048x512_i32(%arg0: memref<2048x512xi32>, %arg1: memref<512x2048xi32>, %arg2: memref<2048x2048xi32>)
// CPP: aiex.ipu.dma_memcpy_nd
// CPP: aiex.ipu.dma_memcpy_nd
// CPP: aiex.ipu.dma_memcpy_nd
// CPP: aiex.ipu.sync

// Large GEMM added by this commit: a 2048x2048x512 i32 matmul whose output
// is first zero-filled, exercising the simple-pack pipeline at a size that
// previously failed to compile (see commit description).
func.func @matmul_large(%lhs: tensor<2048x512xi32>, %rhs: tensor<512x2048xi32>) -> tensor<2048x2048xi32> {
  %empty = tensor.empty() : tensor<2048x2048xi32>
  %cst = arith.constant 0 : i32
  // Zero-initialize the accumulator before the matmul.
  %fill = linalg.fill ins(%cst : i32) outs(%empty : tensor<2048x2048xi32>) -> tensor<2048x2048xi32>
  %res = linalg.matmul ins(%lhs, %rhs: tensor<2048x512xi32>, tensor<512x2048xi32>)
                       outs(%fill: tensor<2048x2048xi32>) -> tensor<2048x2048xi32>
  return %res : tensor<2048x2048xi32>
}

0 comments on commit 53abdcc

Please sign in to comment.