[LinalgFunctionOutlining] Create a pass to outline linalg compute ops #862

Open: wants to merge 7 commits into base: main
@@ -0,0 +1,149 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "iree-amd-aie/Transforms/AMDAIEUtils.h"
#include "iree-amd-aie/Transforms/Passes.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

#define DEBUG_TYPE "iree-amdaie-linalg-function-outlining"

namespace mlir::iree_compiler::AMDAIE {

namespace {

unsigned uniqueOutlinedMatmul = 0;
Contributor

Design nit: I would make these (private) members of the Pass, and make outlinedToAFunction a member function of the Pass too. I'm a bit unsure about these non-static global variables; I haven't seen this C++ design before.

Collaborator

Global values like this aren't a good idea. They can cause race conditions when the pass runs in a multi-threaded context.

I think LLVM already has a way to deduplicate symbols. I don't know how to call it, though.
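
One way to act on both comments is to keep the counters inside the object that owns the outlining logic, so each instance carries its own state. The sketch below is a standalone C++ model, not MLIR code: `OutlinedNameGenerator`, `OpKind`, and `nextName` are hypothetical names, and the name format only mirrors the `generic_matmul_0_outlined` pattern seen in the tests of this PR.

```cpp
#include <string>

// Hypothetical stand-in for the two op categories the pass distinguishes.
enum class OpKind { Matmul, Elementwise };

// Hypothetical helper: the counters are members instead of namespace-scope
// variables, so two instances never share mutable state.
class OutlinedNameGenerator {
 public:
  // Returns a name such as "generic_matmul_0_outlined" and bumps the
  // counter for that kind.
  std::string nextName(const std::string &opName, OpKind kind) {
    unsigned &counter = kind == OpKind::Matmul ? numMatmul : numElementwise;
    const char *tag = kind == OpKind::Matmul ? "_matmul_" : "_elementwise_";
    return opName + tag + std::to_string(counter++) + "_outlined";
  }

 private:
  unsigned numMatmul = 0;
  unsigned numElementwise = 0;
};
```

Unlike the PR, this sketch bumps the counter on every call; in the pass the bump would presumably stay tied to the creation of a new outlined function.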

unsigned uniqueOutlinedElementwise = 0;

/// Utility to outline the linalg compute op.
static FailureOr<func::FuncOp> outlinedToAFunction(
    IRRewriter &rewriter, ModuleOp moduleOp, linalg::LinalgOp computeOp,
    std::string outlineFuncName,
    DenseMap<Operation *, std::string> &computeOpToOutlinedFuncMap) {
  // Check if the compute op is equivalent to a previously outlined compute op.
Contributor

At this point we expect moduleOp's symbol table to not contain outlineFuncName. I think this is worth asserting, to avoid the unlikely case that there is already a function with the name you've just created. I suppose you could increment the uniqueOutlined index in a while loop until an unused name is found.
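
The while-loop idea from this comment could look roughly like the sketch below. It is a standalone model under stated assumptions: `pickUnusedName` is a hypothetical helper, and a `std::set` of names stands in for the module's symbol table, which is not how the real lookup is done.

```cpp
#include <set>
#include <string>

// Hypothetical helper: bump `index` until `prefix + index + suffix` is not
// present in `existingSymbols`, then return that name. `existingSymbols`
// stands in for the module's symbol table.
std::string pickUnusedName(const std::set<std::string> &existingSymbols,
                           const std::string &prefix,
                           const std::string &suffix, unsigned &index) {
  std::string name = prefix + std::to_string(index) + suffix;
  while (existingSymbols.count(name) != 0) {
    ++index;
    name = prefix + std::to_string(index) + suffix;
  }
  return name;
}
```

Because `index` is taken by reference, the caller's counter ends up at the first free slot, so the next outlined function continues from there.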

  // If yes, we replace the `outlineFuncName` of the current compute op to be
  // same as the previous equivalent outlined compute op in order to lookup the
  // Symbol table.
  for (auto &[op, funcName] : computeOpToOutlinedFuncMap) {
    if (!OperationEquivalence::isEquivalentTo(
            computeOp.getOperation(), op,
            OperationEquivalence::ignoreValueEquivalence, /*flags=*/nullptr,
            OperationEquivalence::IgnoreLocations))
      continue;
    outlineFuncName = funcName;
    break;
Comment on lines +36 to +42
Contributor

Why don't you check directly whether the isEquivalentTo condition holds, and then update outlineFuncName and break?
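
The control-flow change suggested here can be sketched without MLIR as follows. Everything in the block is a hypothetical stand-in: `ToyOp` replaces the linalg op, `isEquivalent` replaces the `OperationEquivalence::isEquivalentTo` call, and `resolveFuncName` is an illustrative name.

```cpp
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an op: only its "signature" matters here.
struct ToyOp {
  std::string signature;
};

// Stand-in for the equivalence predicate used in the pass.
bool isEquivalent(const ToyOp &a, const ToyOp &b) {
  return a.signature == b.signature;
}

// Returns the name of a previously outlined equivalent op, or `fallback`
// if there is none. The positive check replaces `if (!eq) continue;`.
std::string resolveFuncName(
    const ToyOp &op,
    const std::vector<std::pair<ToyOp, std::string>> &outlined,
    const std::string &fallback) {
  for (const auto &[prevOp, funcName] : outlined) {
    if (isEquivalent(op, prevOp)) return funcName;
  }
  return fallback;
}
```

The behavior is the same as the negated check followed by `continue`; the positive form just keeps the update next to the condition that triggers it.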

  }
  if (auto outlinedFuncOp = dyn_cast_if_present<func::FuncOp>(
          moduleOp.lookupSymbol(outlineFuncName))) {
    return outlinedFuncOp;
  }

  // Form outlined FunctionType.
  SmallVector<Type> inputTypes = llvm::map_to_vector(
      computeOp.getDpsInputs(), [](Value v) { return v.getType(); });
  for (Value val : computeOp.getDpsInits()) inputTypes.push_back(val.getType());
  auto outlinedFuncType =
      FunctionType::get(rewriter.getContext(), inputTypes, /*outputTypes=*/{});

  // Form outlined FuncSignature
  rewriter.setInsertionPointToStart(moduleOp.getBody());
  auto outlinedFunc = rewriter.create<func::FuncOp>(
      moduleOp.getLoc(), outlineFuncName, outlinedFuncType);
  outlinedFunc.setPrivate();

  // Create an entry func block and map the original operands of the compute
  // op to the block arguments.
  Block *outlinedFuncBody = outlinedFunc.addEntryBlock();
  rewriter.setInsertionPointToStart(outlinedFuncBody);
  SmallVector<BlockArgument> outlinedFuncArgs = llvm::map_to_vector(
      outlinedFunc.getArguments(), [&](BlockArgument bbArg) { return bbArg; });
  unsigned bbArgIndex = 0;
  IRMapping operandMap;
  for (Value origOperand : computeOp.getDpsInputs())
    operandMap.map(origOperand, outlinedFuncArgs[bbArgIndex++]);
  for (Value origOperand : computeOp.getDpsInits())
    operandMap.map(origOperand, outlinedFuncArgs[bbArgIndex++]);

  // Clone the compute op while mapping the operand to the function block
  // arguments.
  Operation *clonedComputeOp = rewriter.clone(*computeOp, operandMap);

  // Create terminator op returning the cloned compute op's results.
  rewriter.setInsertionPointToEnd(outlinedFuncBody);
  rewriter.create<func::ReturnOp>(clonedComputeOp->getLoc(), ValueRange({}));

  computeOpToOutlinedFuncMap[computeOp] = outlineFuncName;

  // Since we have created a new outlined function, we will increase the
  // corresponding unique count.
  if (isMatmul(computeOp)) {
    ++uniqueOutlinedMatmul;
  } else if (isElementwise(computeOp)) {
    ++uniqueOutlinedElementwise;
  }
  return outlinedFunc;
}

class AMDAIELinalgFunctionOutliningPass
    : public impl::AMDAIELinalgFunctionOutliningBase<
          AMDAIELinalgFunctionOutliningPass> {
 public:
  AMDAIELinalgFunctionOutliningPass() = default;
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AMDAIEDialect, linalg::LinalgDialect>();
  }

  void runOnOperation() override;
};

void AMDAIELinalgFunctionOutliningPass::runOnOperation() {
  ModuleOp moduleOp = getOperation();
  MLIRContext *context = &getContext();
  IRRewriter rewriter(context);

  DenseMap<Operation *, std::string> computeOpToOutlinedFuncMap;
  SmallVector<Operation *> toBeErased;
  moduleOp.walk([&](linalg::LinalgOp computeOp) {
    // Form outlined FuncName.
Contributor

Also add comments stating which linalg ops are currently outlined.

    std::string computeName = "";
    if (isMatmul(computeOp)) {
      computeName = "_matmul_" + std::to_string(uniqueOutlinedMatmul);
    } else if (isElementwise(computeOp)) {
      computeName = "_elementwise_" + std::to_string(uniqueOutlinedElementwise);
    } else {
      return WalkResult::skip();
    }
    std::string outlineFuncName =
        computeOp->getName().stripDialect().str() + computeName + "_outlined";
    FailureOr<func::FuncOp> outlinedFuncOp =
        outlinedToAFunction(rewriter, moduleOp, computeOp, outlineFuncName,
                            computeOpToOutlinedFuncMap);
    if (failed(outlinedFuncOp)) return WalkResult::interrupt();
    rewriter.setInsertionPoint(computeOp);
    rewriter.create<func::CallOp>(computeOp.getLoc(), *outlinedFuncOp,
                                  computeOp->getOperands());
    // We cannot immediately erase the compute op because it'd be used for
    // equivalence check.
    toBeErased.push_back(computeOp);
    return WalkResult::advance();
  });
  for (Operation *op : toBeErased) {
    op->dropAllUses();
    rewriter.eraseOp(op);
  }
}

} // namespace

std::unique_ptr<Pass> createAMDAIELinalgFunctionOutliningPass() {
return std::make_unique<AMDAIELinalgFunctionOutliningPass>();
}
} // namespace mlir::iree_compiler::AMDAIE
@@ -297,6 +297,7 @@ LogicalResult AIEDeviceBuilder::coreFuncCallOpToAIE(
SymbolTable::setSymbolVisibility(newFnDecl,
SymbolTable::Visibility::Private);
newFnDecl->setAttr("llvm.bareptr", rewriter.getBoolAttr(true));
fnDecl.getBody().cloneInto(&(newFnDecl.getBody()), mapper);
mapper.map(fnDecl.getOperation(), newFnDecl.getOperation());
fnDecl = newFnDecl;
}
@@ -70,6 +70,7 @@ iree_cc_library(
"AMDAIEDmaToCircularDma.cpp"
"AMDAIEDmaUtils.cpp"
"AMDAIEFlattenLogicalObjectFifo.cpp"
"AMDAIELinalgFunctionOutlining.cpp"
"AMDAIEFuseConsumerIntoLoop.cpp"
"AMDAIEFuseFillIntoForall.cpp"
"AMDAIEFusePackIntoLoop.cpp"
@@ -47,6 +47,7 @@ namespace mlir::iree_compiler::AMDAIE {
#define GEN_PASS_DEF_AMDAIEDMALOOPSUBSUMPTION
#define GEN_PASS_DEF_AMDAIEDMATOCIRCULARDMA
#define GEN_PASS_DEF_AMDAIEFLATTENLOGICALOBJECTFIFO
#define GEN_PASS_DEF_AMDAIELINALGFUNCTIONOUTLINING
#define GEN_PASS_DEF_AMDAIEFUSECONSUMERINTOLOOP
#define GEN_PASS_DEF_AMDAIEFUSEFILLINTOFORALL
#define GEN_PASS_DEF_AMDAIEFUSEPACKINTOLOOP
@@ -161,6 +161,9 @@ std::unique_ptr<Pass> createAMDAIEDmaToCircularDmaPass();
/// Create a pass to flatten the logical objectFifos.
std::unique_ptr<Pass> createAMDAIEFlattenLogicalObjectFifoPass();

/// Create a pass for function outlining.
std::unique_ptr<Pass> createAMDAIELinalgFunctionOutliningPass();

/// Create a pass to fuse the consumer op into the innermost last scf loop.
std::unique_ptr<Pass> createAMDAIEFuseConsumerIntoLoopPass(
AMDAIEFuseConsumerIntoLoopOptions options = {});
@@ -243,6 +243,19 @@ def AMDAIEFlattenLogicalObjectFifo :
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEFlattenLogicalObjectFifoPass()";
}

def AMDAIELinalgFunctionOutlining :
    Pass<"iree-amdaie-linalg-function-outlining", "ModuleOp"> {
  let summary = "Outlining of linalg compute ops";
Contributor

Can you add some additional information in a description = [{ }] block about which linalg ops specifically are outlined (e.g. linalg.fill is not), what the motivation for this pass is, and what its assumptions are? Sorry, I know that most of our passes in this file don't contain descriptions, but I think it'd be a good thing if we all started adding high-level information here for people without context.

Contributor Author

I have added that. Please check whether the wording is correct.

Contributor

Thanks.

Is

One assumption this pass currently makes: All elementwise/matmul linalg ops within a
    dispatch have same body content.

still true? Otherwise looks good.

Collaborator

Same question. Doesn't operation equivalence check the region as well?
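
The test expectations in this PR point the same way: the elementwise op whose body adds an `arith.addf` gets its own `generic_elementwise_1_outlined`, while ops that differ only in their operands share `generic_elementwise_0_outlined`. The toy model below illustrates that kind of comparison. It is not the MLIR implementation; `ToyGeneric` and `equivalentIgnoringOperands` are hypothetical names.

```cpp
#include <string>
#include <vector>

// Hypothetical toy op: operand names plus the op names inside its body.
struct ToyGeneric {
  std::vector<std::string> operands;
  std::vector<std::string> body;
};

// Structural comparison that ignores which values are passed in (only the
// operand count must match) but still compares the body contents.
bool equivalentIgnoringOperands(const ToyGeneric &a, const ToyGeneric &b) {
  return a.operands.size() == b.operands.size() && a.body == b.body;
}
```

If the real check behaves like this, the "same body content" assumption in the description would no longer be needed, since ops with different bodies simply map to different outlined functions.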

  let description = [{
    Outlines matmul/elementwise linalg compute ops only. This pass essentially minimises
    the code footprint overall as the unrolling in the later passes lead to lot repeated
Contributor

Nit: 1) "loop unrolling"; 2) "lead to lots of repeated codes".

    lines of codes.
    One assumption this pass currently makes: All elementwise/matmul linalg ops within a
    dispatch have same body content.
  }];
  let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIELinalgFunctionOutliningPass()";
}

def AMDAIEFuseConsumerIntoLoop :
InterfacePass<"iree-amdaie-fuse-consumer-into-loop", "mlir::FunctionOpInterface"> {
let summary = "Fuse the consumer operation into the innermost last scf loop.";
@@ -35,6 +35,7 @@ iree_lit_test_suite(
"dma_loop_subsumption.mlir"
"dma_to_circular_dma.mlir"
"flatten_logical_objectfifo.mlir"
"linalg_function_outlining.mlir"
"fuse_consumer_into_loop_scf_for.mlir"
"fuse_consumer_into_loop_scf_forall.mlir"
"fuse_fill_into_forall.mlir"
@@ -0,0 +1,177 @@
// RUN: iree-opt --split-input-file --iree-amdaie-linalg-function-outlining --verify-diagnostics --split-input-file %s | FileCheck %s

// Test demonstrating multiple Matmul using different SSAs.

// CHECK-LABEL: func.func private @generic_matmul_0_outlined
// CHECK-SAME: (%[[LHS:.*]]: memref<4x8xbf16>,
// CHECK-SAME: %[[RHS:.*]]: memref<8x4xbf16>,
// CHECK-SAME: %[[OUT:.*]]: memref<4x4xf32>) {
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
// CHECK-SAME: outs(%[[OUT]] :
// CHECK: return
// CHECK: }
// CHECK-LABEL: func.func @matmul_example
// CHECK-SAME: (%[[A:.*]]: memref<4x8xbf16>,
// CHECK-SAME: %[[B:.*]]: memref<8x4xbf16>,
// CHECK-SAME: %[[C:.*]]: memref<4x4xf32>) {
// CHECK: amdaie.core
// CHECK: func.call @generic_matmul_0_outlined(%[[A]], %[[B]], %[[C]])
// CHECK-NOT: linalg.generic
// CHECK: amdaie.end
// CHECK: }
// CHECK: amdaie.core
// CHECK: func.call @generic_matmul_0_outlined(%[[A]], %[[B]], %[[C]])
// CHECK-NOT: linalg.generic
// CHECK: amdaie.end
// CHECK: }
// CHECK: return
// CHECK: }
func.func @matmul_example(%A: memref<4x8xbf16>, %B: memref<8x4xbf16>, %C: memref<4x4xf32>) {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%tile = amdaie.tile(%c1, %c2)
%0 = amdaie.core(%tile, in : [], out : []) {
linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d2, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1)>
],
iterator_types = ["parallel", "parallel", "reduction"]
} ins(%A, %B : memref<4x8xbf16>, memref<8x4xbf16>)
outs(%C : memref<4x4xf32>) {
^bb0(%in: bf16, %in_17: bf16, %out: f32):
%1 = arith.extf %in : bf16 to f32
%2 = arith.extf %in_17 : bf16 to f32
%3 = arith.mulf %1, %2 : f32
%4 = arith.addf %out, %3 : f32
linalg.yield %4 : f32
}
amdaie.end
}
%1 = amdaie.core(%tile, in : [], out : []) {
linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d2, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1)>
],
iterator_types = ["parallel", "parallel", "reduction"]
} ins(%A, %B : memref<4x8xbf16>, memref<8x4xbf16>)
outs(%C : memref<4x4xf32>) {
^bb0(%in: bf16, %in_17: bf16, %out: f32):
%1 = arith.extf %in : bf16 to f32
%2 = arith.extf %in_17 : bf16 to f32
%3 = arith.mulf %1, %2 : f32
%4 = arith.addf %out, %3 : f32
linalg.yield %4 : f32
}
amdaie.end
}
return
}

// -----

// Test demonstrating different kind of elementwise operations being mapped to a
// unique corresponding outlined function.

// CHECK-LABEL: func.func private @generic_elementwise_1_outlined
// CHECK-SAME: (%[[INPUT:.*]]: memref<4xf32>,
// CHECK-SAME: %[[OUTPUT:.*]]: memref<4xbf16>) {
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[INPUT]] :
// CHECK-SAME: outs(%[[OUTPUT]] :
// CHECK: arith.truncf
// CHECK: arith.addf
// CHECK: return
// CHECK: }
// CHECK: func.func private @generic_elementwise_0_outlined
// CHECK-SAME: (%[[INPUT:.*]]: memref<4xf32>,
// CHECK-SAME: %[[OUTPUT:.*]]: memref<4xbf16>) {
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[INPUT]] :
// CHECK-SAME: outs(%[[OUTPUT]] :
// CHECK: arith.truncf
// CHECK: return
// CHECK: }
// CHECK-LABEL: func.func @elemwise_example
// CHECK-SAME: (%[[A:.*]]: memref<4xf32>,
// CHECK-SAME: %[[C:.*]]: memref<4xbf16>,
// CHECK-SAME: %[[B:.*]]: memref<4xf32>) {
// CHECK: amdaie.core
// CHECK: func.call @generic_elementwise_0_outlined(%[[A]], %[[C]])
// CHECK-NOT: linalg.generic
// CHECK: amdaie.end
// CHECK: }
// CHECK: amdaie.core
// CHECK: func.call @generic_elementwise_0_outlined(%[[B]], %[[C]])
// CHECK-NOT: linalg.generic
// CHECK: amdaie.end
// CHECK: }
// CHECK: amdaie.core
// CHECK: func.call @generic_elementwise_1_outlined(%[[A]], %[[C]])
// CHECK-NOT: linalg.generic
// CHECK: amdaie.end
// CHECK: }
// CHECK: amdaie.core
// CHECK: func.call @generic_elementwise_0_outlined(%[[A]], %[[C]])
// CHECK-NOT: linalg.generic
// CHECK: amdaie.end
// CHECK: }
// CHECK: return
// CHECK: }
func.func @elemwise_example(%A: memref<4xf32>, %C: memref<4xbf16>, %B: memref<4xf32>) {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%tile = amdaie.tile(%c1, %c2)
%0 = amdaie.core(%tile, in : [], out : []) {
linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]
} ins(%A : memref<4xf32>)
outs(%C : memref<4xbf16>) {
^bb0(%in: f32, %out: bf16):
%1 = arith.truncf %in : f32 to bf16
linalg.yield %1 : bf16
}
amdaie.end
}
Abhishek-Varma marked this conversation as resolved.
%2 = amdaie.core(%tile, in : [], out : []) {
linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]
} ins(%B : memref<4xf32>)
outs(%C : memref<4xbf16>) {
^bb0(%in: f32, %out: bf16):
%3 = arith.truncf %in : f32 to bf16
linalg.yield %3 : bf16
}
amdaie.end
}
%4 = amdaie.core(%tile, in : [], out : []) {
linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]
} ins(%A : memref<4xf32>)
outs(%C : memref<4xbf16>) {
^bb0(%in: f32, %out: bf16):
%5 = arith.truncf %in : f32 to bf16
%6 = arith.addf %5, %out : bf16
linalg.yield %6 : bf16
}
amdaie.end
}
%7 = amdaie.core(%tile, in : [], out : []) {
linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]
} ins(%A : memref<4xf32>)
outs(%C : memref<4xbf16>) {
^bb0(%in: f32, %out: bf16):
%8 = arith.truncf %in : f32 to bf16
linalg.yield %8 : bf16
}
amdaie.end
}
Comment on lines +164 to +175
Contributor

I think the three elementwise cores above are good enough for test purposes. Why don't you delete this one or change it to a matmul?

return
}