-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[LinalgFunctionOutlining] Create a pass to outline linalg compute ops #862
base: main
Are you sure you want to change the base?
Changes from all commits
493e08d
9e686fd
1f04f8e
48cbd72
daeb581
436a402
9bf1eb3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree-amd-aie/IR/AMDAIEOps.h" | ||
#include "iree-amd-aie/Transforms/AMDAIEUtils.h" | ||
#include "iree-amd-aie/Transforms/Passes.h" | ||
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" | ||
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" | ||
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" | ||
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" | ||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" | ||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
|
||
#define DEBUG_TYPE "iree-amdaie-linalg-function-outlining" | ||
|
||
namespace mlir::iree_compiler::AMDAIE { | ||
|
||
namespace { | ||
|
||
// Running count of matmul functions created by `outlinedToAFunction`; its | ||
// current value is embedded in the name chosen for the next matmul function. | ||
// NOTE(review): this is mutable namespace-scope state that nothing in this | ||
// file resets, so it is shared by every run of the pass - consider making it | ||
// pass-local. | ||
unsigned uniqueOutlinedMatmul = 0; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Global values like this aren't a good idea. This can cause race conditions when run in a multi-threaded context. I think LLVM already has a way to deduplicate symbols. I don't know how to call it though. |
||
// Running count of elementwise functions created by `outlinedToAFunction`; | ||
// its current value is embedded in the name chosen for the next elementwise | ||
// function. | ||
// NOTE(review): this is mutable namespace-scope state that nothing in this | ||
// file resets, so it is shared by every run of the pass - consider making it | ||
// pass-local. | ||
unsigned uniqueOutlinedElementwise = 0; | ||
|
||
/// Utility to outline the linalg compute op. | ||
/// Returns a private `func.func` in `moduleOp` whose body is a clone of
/// `computeOp`, creating the function if it does not exist yet.
///
/// \param rewriter         Used to create the function, the cloned op and the
///                         terminator. When a new function is created the
///                         insertion point is left inside it; callers must
///                         reset it before creating further ops.
/// \param moduleOp         Module that holds (or will hold) the function.
/// \param computeOp        Linalg op to outline. Its DPS inputs followed by
///                         its DPS inits become the function arguments; the
///                         function has no results.
/// \param outlineFuncName  Name to use if a new function has to be created.
/// \param computeOpToOutlinedFuncMap
///                         Maps already outlined ops to the name of their
///                         function. If `computeOp` is equivalent to one of
///                         these ops (ignoring operand values and locations),
///                         that op's function name is used instead of
///                         `outlineFuncName`. Updated when a function is
///                         created.
/// \returns the existing or newly created function, or failure if the chosen
///          name is already taken by a symbol that is not a `func.func`.
static FailureOr<func::FuncOp> outlinedToAFunction(
    IRRewriter &rewriter, ModuleOp moduleOp, linalg::LinalgOp computeOp,
    std::string outlineFuncName,
    DenseMap<Operation *, std::string> &computeOpToOutlinedFuncMap) {
  // If an equivalent compute op was outlined before, adopt the name of its
  // function so that the symbol lookup below finds and reuses it.
  for (auto &[op, funcName] : computeOpToOutlinedFuncMap) {
    if (OperationEquivalence::isEquivalentTo(
            computeOp.getOperation(), op,
            OperationEquivalence::ignoreValueEquivalence,
            /*markEquivalent=*/nullptr,
            OperationEquivalence::IgnoreLocations)) {
      outlineFuncName = funcName;
      break;
    }
  }

  if (Operation *existingSymbol = moduleOp.lookupSymbol(outlineFuncName)) {
    if (auto outlinedFuncOp = dyn_cast<func::FuncOp>(existingSymbol))
      return outlinedFuncOp;
    // Creating the function anyway would define the same symbol twice and
    // leave the module invalid, so report the clash instead.
    computeOp->emitOpError()
        << "cannot be outlined: symbol '" << outlineFuncName
        << "' already exists and is not a func.func";
    return failure();
  }

  // Form the outlined FunctionType: DPS inputs first, then DPS inits, and no
  // results.
  SmallVector<Type> inputTypes;
  inputTypes.reserve(computeOp->getNumOperands());
  for (Value val : computeOp.getDpsInputs()) inputTypes.push_back(val.getType());
  for (Value val : computeOp.getDpsInits()) inputTypes.push_back(val.getType());
  auto outlinedFuncType =
      FunctionType::get(rewriter.getContext(), inputTypes, /*outputTypes=*/{});

  // Create the private function at the start of the module.
  rewriter.setInsertionPointToStart(moduleOp.getBody());
  auto outlinedFunc = rewriter.create<func::FuncOp>(
      moduleOp.getLoc(), outlineFuncName, outlinedFuncType);
  outlinedFunc.setPrivate();

  // Create the entry block and map the original operands of the compute op to
  // the block arguments, in the same order as the argument types above.
  Block *outlinedFuncBody = outlinedFunc.addEntryBlock();
  rewriter.setInsertionPointToStart(outlinedFuncBody);
  IRMapping operandMap;
  unsigned bbArgIndex = 0;
  for (Value origOperand : computeOp.getDpsInputs())
    operandMap.map(origOperand, outlinedFuncBody->getArgument(bbArgIndex++));
  for (Value origOperand : computeOp.getDpsInits())
    operandMap.map(origOperand, outlinedFuncBody->getArgument(bbArgIndex++));

  // Clone the compute op so that it operates on the function arguments.
  Operation *clonedComputeOp =
      rewriter.clone(*computeOp.getOperation(), operandMap);

  // Terminate the function. It has no results, so nothing is returned.
  rewriter.setInsertionPointToEnd(outlinedFuncBody);
  rewriter.create<func::ReturnOp>(clonedComputeOp->getLoc(), ValueRange({}));

  computeOpToOutlinedFuncMap[computeOp] = outlineFuncName;

  // A new function was created, so bump the counter of the matching kind to
  // keep the next generated name unique.
  if (isMatmul(computeOp)) {
    ++uniqueOutlinedMatmul;
  } else if (isElementwise(computeOp)) {
    ++uniqueOutlinedElementwise;
  }
  return outlinedFunc;
}
|
||
class AMDAIELinalgFunctionOutliningPass | ||
: public impl::AMDAIELinalgFunctionOutliningBase< | ||
AMDAIELinalgFunctionOutliningPass> { | ||
public: | ||
AMDAIELinalgFunctionOutliningPass() = default; | ||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry.insert<AMDAIEDialect, linalg::LinalgDialect>(); | ||
} | ||
|
||
void runOnOperation() override; | ||
}; | ||
|
||
void AMDAIELinalgFunctionOutliningPass::runOnOperation() { | ||
ModuleOp moduleOp = getOperation(); | ||
MLIRContext *context = &getContext(); | ||
IRRewriter rewriter(context); | ||
|
||
DenseMap<Operation *, std::string> computeOpToOutlinedFuncMap; | ||
SmallVector<Operation *> toBeErased; | ||
moduleOp.walk([&](linalg::LinalgOp computeOp) { | ||
// Form outlined FuncName. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also add comments of what linalg ops are currently outlined. |
||
std::string computeName = ""; | ||
if (isMatmul(computeOp)) { | ||
computeName = "_matmul_" + std::to_string(uniqueOutlinedMatmul); | ||
} else if (isElementwise(computeOp)) { | ||
computeName = "_elementwise_" + std::to_string(uniqueOutlinedElementwise); | ||
} else { | ||
return WalkResult::skip(); | ||
} | ||
std::string outlineFuncName = | ||
computeOp->getName().stripDialect().str() + computeName + "_outlined"; | ||
FailureOr<func::FuncOp> outlinedFuncOp = | ||
outlinedToAFunction(rewriter, moduleOp, computeOp, outlineFuncName, | ||
computeOpToOutlinedFuncMap); | ||
if (failed(outlinedFuncOp)) return WalkResult::interrupt(); | ||
rewriter.setInsertionPoint(computeOp); | ||
rewriter.create<func::CallOp>(computeOp.getLoc(), *outlinedFuncOp, | ||
computeOp->getOperands()); | ||
// We cannot immediately erase the compute op because it'd be used for | ||
// equivalence check. | ||
toBeErased.push_back(computeOp); | ||
return WalkResult::advance(); | ||
}); | ||
for (Operation *op : toBeErased) { | ||
op->dropAllUses(); | ||
rewriter.eraseOp(op); | ||
} | ||
} | ||
|
||
} // namespace | ||
|
||
std::unique_ptr<Pass> createAMDAIELinalgFunctionOutliningPass() { | ||
return std::make_unique<AMDAIELinalgFunctionOutliningPass>(); | ||
} | ||
} // namespace mlir::iree_compiler::AMDAIE |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -243,6 +243,19 @@ def AMDAIEFlattenLogicalObjectFifo : | |
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEFlattenLogicalObjectFifoPass()"; | ||
} | ||
|
||
def AMDAIELinalgFunctionOutlining :
    Pass<"iree-amdaie-linalg-function-outlining", "ModuleOp"> {
  let summary = "Outlining of linalg compute ops";
  let description = [{
    Outlines matmul and elementwise linalg compute ops into private functions
    and replaces each outlined op with a call to its function. Other linalg
    ops are not outlined. This minimises the overall code footprint, as loop
    unrolling in later passes otherwise leads to lots of repeated code.
    Compute ops that are equivalent (ignoring their operand values and
    locations) share a single outlined function; ops that differ, e.g. in
    their body, are outlined into separate functions.
  }];
  let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIELinalgFunctionOutliningPass()";
}
|
||
def AMDAIEFuseConsumerIntoLoop : | ||
InterfacePass<"iree-amdaie-fuse-consumer-into-loop", "mlir::FunctionOpInterface"> { | ||
let summary = "Fuse the consumer operation into the innermost last scf loop."; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
// RUN: iree-opt --split-input-file --iree-amdaie-linalg-function-outlining --verify-diagnostics --split-input-file %s | FileCheck %s | ||
|
||
// Test demonstrating multiple Matmul using different SSAs. | ||
|
||
// CHECK-LABEL: func.func private @generic_matmul_0_outlined | ||
// CHECK-SAME: (%[[LHS:.*]]: memref<4x8xbf16>, | ||
// CHECK-SAME: %[[RHS:.*]]: memref<8x4xbf16>, | ||
// CHECK-SAME: %[[OUT:.*]]: memref<4x4xf32>) { | ||
// CHECK: linalg.generic | ||
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : | ||
// CHECK-SAME: outs(%[[OUT]] : | ||
// CHECK: return | ||
// CHECK: } | ||
// CHECK-LABEL: func.func @matmul_example | ||
// CHECK-SAME: (%[[A:.*]]: memref<4x8xbf16>, | ||
// CHECK-SAME: %[[B:.*]]: memref<8x4xbf16>, | ||
// CHECK-SAME: %[[C:.*]]: memref<4x4xf32>) { | ||
// CHECK: amdaie.core | ||
// CHECK: func.call @generic_matmul_0_outlined(%[[A]], %[[B]], %[[C]]) | ||
// CHECK-NOT: linalg.generic | ||
// CHECK: amdaie.end | ||
// CHECK: } | ||
// CHECK: amdaie.core | ||
// CHECK: func.call @generic_matmul_0_outlined(%[[A]], %[[B]], %[[C]]) | ||
// CHECK-NOT: linalg.generic | ||
// CHECK: amdaie.end | ||
// CHECK: } | ||
// CHECK: return | ||
// CHECK: } | ||
// Test input: two `amdaie.core` regions on the same tile, each containing an | ||
// identical matmul-style `linalg.generic` (bf16 x bf16 accumulated into f32) | ||
// on %A, %B and %C. | ||
func.func @matmul_example(%A: memref<4x8xbf16>, %B: memref<8x4xbf16>, %C: memref<4x4xf32>) { | ||
%c2 = arith.constant 2 : index | ||
%c1 = arith.constant 1 : index | ||
%tile = amdaie.tile(%c1, %c2) | ||
// First core. | ||
%0 = amdaie.core(%tile, in : [], out : []) { | ||
linalg.generic { | ||
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, | ||
affine_map<(d0, d1, d2) -> (d2, d1)>, | ||
affine_map<(d0, d1, d2) -> (d0, d1)> | ||
], | ||
iterator_types = ["parallel", "parallel", "reduction"] | ||
} ins(%A, %B : memref<4x8xbf16>, memref<8x4xbf16>) | ||
outs(%C : memref<4x4xf32>) { | ||
^bb0(%in: bf16, %in_17: bf16, %out: f32): | ||
%1 = arith.extf %in : bf16 to f32 | ||
%2 = arith.extf %in_17 : bf16 to f32 | ||
%3 = arith.mulf %1, %2 : f32 | ||
%4 = arith.addf %out, %3 : f32 | ||
linalg.yield %4 : f32 | ||
} | ||
amdaie.end | ||
} | ||
// Second core: same op as in the first core. | ||
%1 = amdaie.core(%tile, in : [], out : []) { | ||
linalg.generic { | ||
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, | ||
affine_map<(d0, d1, d2) -> (d2, d1)>, | ||
affine_map<(d0, d1, d2) -> (d0, d1)> | ||
], | ||
iterator_types = ["parallel", "parallel", "reduction"] | ||
} ins(%A, %B : memref<4x8xbf16>, memref<8x4xbf16>) | ||
outs(%C : memref<4x4xf32>) { | ||
^bb0(%in: bf16, %in_17: bf16, %out: f32): | ||
%1 = arith.extf %in : bf16 to f32 | ||
%2 = arith.extf %in_17 : bf16 to f32 | ||
%3 = arith.mulf %1, %2 : f32 | ||
%4 = arith.addf %out, %3 : f32 | ||
linalg.yield %4 : f32 | ||
} | ||
amdaie.end | ||
} | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
// Test demonstrating different kind of elementwise operations being mapped to a | ||
// unique corresponding outlined function. | ||
|
||
// CHECK-LABEL: func.func private @generic_elementwise_1_outlined | ||
// CHECK-SAME: (%[[INPUT:.*]]: memref<4xf32>, | ||
// CHECK-SAME: %[[OUTPUT:.*]]: memref<4xbf16>) { | ||
// CHECK: linalg.generic | ||
// CHECK-SAME: ins(%[[INPUT]] : | ||
// CHECK-SAME: outs(%[[OUTPUT]] : | ||
// CHECK: arith.truncf | ||
// CHECK: arith.addf | ||
// CHECK: return | ||
// CHECK: } | ||
// CHECK: func.func private @generic_elementwise_0_outlined | ||
// CHECK-SAME: (%[[INPUT:.*]]: memref<4xf32>, | ||
// CHECK-SAME: %[[OUTPUT:.*]]: memref<4xbf16>) { | ||
// CHECK: linalg.generic | ||
// CHECK-SAME: ins(%[[INPUT]] : | ||
// CHECK-SAME: outs(%[[OUTPUT]] : | ||
// CHECK: arith.truncf | ||
// CHECK: return | ||
// CHECK: } | ||
// CHECK-LABEL: func.func @elemwise_example | ||
// CHECK-SAME: (%[[A:.*]]: memref<4xf32>, | ||
// CHECK-SAME: %[[C:.*]]: memref<4xbf16>, | ||
// CHECK-SAME: %[[B:.*]]: memref<4xf32>) { | ||
// CHECK: amdaie.core | ||
// CHECK: func.call @generic_elementwise_0_outlined(%[[A]], %[[C]]) | ||
// CHECK-NOT: linalg.generic | ||
// CHECK: amdaie.end | ||
// CHECK: } | ||
// CHECK: amdaie.core | ||
// CHECK: func.call @generic_elementwise_0_outlined(%[[B]], %[[C]]) | ||
// CHECK-NOT: linalg.generic | ||
// CHECK: amdaie.end | ||
// CHECK: } | ||
// CHECK: amdaie.core | ||
// CHECK: func.call @generic_elementwise_1_outlined(%[[A]], %[[C]]) | ||
// CHECK-NOT: linalg.generic | ||
// CHECK: amdaie.end | ||
// CHECK: } | ||
// CHECK: amdaie.core | ||
// CHECK: func.call @generic_elementwise_0_outlined(%[[A]], %[[C]]) | ||
// CHECK-NOT: linalg.generic | ||
// CHECK: amdaie.end | ||
// CHECK: } | ||
// CHECK: return | ||
// CHECK: } | ||
// Test input: four `amdaie.core` regions on the same tile, each containing a | ||
// 1-D elementwise `linalg.generic` writing to %C. The first, second and | ||
// fourth only truncate f32 to bf16; the third also adds the output element. | ||
func.func @elemwise_example(%A: memref<4xf32>, %C: memref<4xbf16>, %B: memref<4xf32>) { | ||
%c2 = arith.constant 2 : index | ||
%c1 = arith.constant 1 : index | ||
%tile = amdaie.tile(%c1, %c2) | ||
// Core 1: truncf only, reading %A. | ||
%0 = amdaie.core(%tile, in : [], out : []) { | ||
linalg.generic { | ||
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], | ||
iterator_types = ["parallel"] | ||
} ins(%A : memref<4xf32>) | ||
outs(%C : memref<4xbf16>) { | ||
^bb0(%in: f32, %out: bf16): | ||
%1 = arith.truncf %in : f32 to bf16 | ||
linalg.yield %1 : bf16 | ||
} | ||
amdaie.end | ||
} | ||
Abhishek-Varma marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// Core 2: same body as core 1, but reading %B. | ||
%2 = amdaie.core(%tile, in : [], out : []) { | ||
linalg.generic { | ||
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], | ||
iterator_types = ["parallel"] | ||
} ins(%B : memref<4xf32>) | ||
outs(%C : memref<4xbf16>) { | ||
^bb0(%in: f32, %out: bf16): | ||
%3 = arith.truncf %in : f32 to bf16 | ||
linalg.yield %3 : bf16 | ||
} | ||
amdaie.end | ||
} | ||
// Core 3: different body (truncf followed by addf with the output element). | ||
%4 = amdaie.core(%tile, in : [], out : []) { | ||
linalg.generic { | ||
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], | ||
iterator_types = ["parallel"] | ||
} ins(%A : memref<4xf32>) | ||
outs(%C : memref<4xbf16>) { | ||
^bb0(%in: f32, %out: bf16): | ||
%5 = arith.truncf %in : f32 to bf16 | ||
%6 = arith.addf %5, %out : bf16 | ||
linalg.yield %6 : bf16 | ||
} | ||
amdaie.end | ||
} | ||
// Core 4: same op as in core 1. | ||
%7 = amdaie.core(%tile, in : [], out : []) { | ||
linalg.generic { | ||
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], | ||
iterator_types = ["parallel"] | ||
} ins(%A : memref<4xf32>) | ||
outs(%C : memref<4xbf16>) { | ||
^bb0(%in: f32, %out: bf16): | ||
%8 = arith.truncf %in : f32 to bf16 | ||
linalg.yield %8 : bf16 | ||
} | ||
amdaie.end | ||
} | ||
Comment on lines
+164
to
+175
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the above three elementwise cores are good enough for test purpose. Why don't you delete this or change this to a matmul. |
||
return | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Design nit: I would make these (private) members of the Pass, and make outlinedToAFunction a member function of the Pass too. Just because I'm a bit unsure about these non-static global variables, not seen this C++ design before.