Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ObjectFifo][NFC] Refactor DmaUtils + SplitLogicalObjectFifoForReuse #759

Merged
merged 3 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,60 +52,6 @@ int64_t calculateNbIterations(int64_t lowerBound, int64_t upperBound,

namespace {

/// Utility affine expression visitor to retrieve the scale and optional bias
/// from the expression.
struct RetrieveScaleAndBias
: public AffineExprVisitor<RetrieveScaleAndBias, LogicalResult> {
std::optional<int64_t> scale;
std::optional<int64_t> bias;
LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr /*expr*/) {
return failure();
}
LogicalResult visitConstantExpr(AffineConstantExpr /*expr*/) {
return failure();
}
LogicalResult visitDimExpr(AffineDimExpr /*expr*/) { return failure(); }
LogicalResult visitSymbolExpr(AffineSymbolExpr /*expr*/) { return failure(); }
LogicalResult visitMulExpr(AffineBinaryOpExpr expr) {
if (auto rhsSize = dyn_cast<AffineConstantExpr>(expr.getRHS());
isa<AffineDimExpr>(expr.getLHS())) {
scale = rhsSize.getValue();
} else if (auto lhsSize = dyn_cast<AffineConstantExpr>(expr.getLHS());
isa<AffineDimExpr>(expr.getRHS())) {
scale = lhsSize.getValue();
}
return success();
}
LogicalResult visitAddExpr(AffineBinaryOpExpr expr) {
if (bias) return failure();
if (auto rhsSize = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
bias = rhsSize.getValue();
if (bias.value() < 0) return failure();
if (isa<AffineBinaryOpExpr>(expr.getLHS())) {
return visit(expr.getLHS());
} else if (isa<AffineDimExpr>(expr.getLHS())) {
scale = 1;
return success();
} else {
return failure();
}
} else if (auto lhsSize = dyn_cast<AffineConstantExpr>(expr.getLHS())) {
bias = lhsSize.getValue();
if (bias.value() < 0) return failure();
if (isa<AffineBinaryOpExpr>(expr.getRHS())) {
return visit(expr.getRHS());
} else if (isa<AffineDimExpr>(expr.getRHS())) {
scale = 1;
return success();
} else {
return failure();
}
} else {
return failure();
}
}
};

struct SubsumeLoopIntoDMA
: public OpInterfaceRewritePattern<AMDAIE::DoublyStridedOpInterface> {
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,71 @@
#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "iree-amd-aie/aie_runtime/iree_aie_runtime.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir::iree_compiler::AMDAIE {

/// Utility to retrieve a constant index from an OpFoldResult.
int64_t getConstantIndexOrAssert(OpFoldResult dim);

/// Utility affine expression visitor to retrieve the scale and optional bias
/// from the expression.
struct RetrieveScaleAndBias
: public AffineExprVisitor<RetrieveScaleAndBias, LogicalResult> {
std::optional<int64_t> scale;
std::optional<int64_t> bias;
LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr /*expr*/) {
return failure();
}
LogicalResult visitConstantExpr(AffineConstantExpr /*expr*/) {
return failure();
}
LogicalResult visitDimExpr(AffineDimExpr /*expr*/) { return failure(); }
LogicalResult visitSymbolExpr(AffineSymbolExpr /*expr*/) { return failure(); }
LogicalResult visitMulExpr(AffineBinaryOpExpr expr) {
if (auto rhsSize = dyn_cast<AffineConstantExpr>(expr.getRHS());
isa<AffineDimExpr>(expr.getLHS())) {
scale = rhsSize.getValue();
} else if (auto lhsSize = dyn_cast<AffineConstantExpr>(expr.getLHS());
isa<AffineDimExpr>(expr.getRHS())) {
scale = lhsSize.getValue();
}
return success();
}
LogicalResult visitAddExpr(AffineBinaryOpExpr expr) {
if (bias) return failure();
if (auto rhsSize = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
bias = rhsSize.getValue();
if (bias.value() < 0) return failure();
if (isa<AffineBinaryOpExpr>(expr.getLHS())) {
return visit(expr.getLHS());
} else if (isa<AffineDimExpr>(expr.getLHS())) {
scale = 1;
return success();
} else {
return failure();
}
} else if (auto lhsSize = dyn_cast<AffineConstantExpr>(expr.getLHS())) {
bias = lhsSize.getValue();
if (bias.value() < 0) return failure();
if (isa<AffineBinaryOpExpr>(expr.getRHS())) {
return visit(expr.getRHS());
} else if (isa<AffineDimExpr>(expr.getRHS())) {
scale = 1;
return success();
} else {
return failure();
}
} else {
return failure();
}
}
};

// Constant specifying the number of inter-iteration dimension for DMA
// operations.
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <numeric>

#include "iree-amd-aie/Transforms/AMDAIEDmaUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
Expand All @@ -21,6 +22,56 @@

namespace mlir::iree_compiler::AMDAIE {

/// Utility to create a new logical objectfifo based on shape defined by
/// `newSizesOpFoldResultArr`.
static AMDAIE::LogicalObjectFifoFromMemrefOp createNewLogicalObjectFifo(
IRRewriter &rewriter,
AMDAIE::LogicalObjectFifoFromMemrefOp &oldLogicalObjectFifo,
SmallVectorImpl<OpFoldResult> &newSizesOpFoldResultArr) {
OpBuilder::InsertionGuard guard(rewriter);
SmallVector<int64_t> newSizes = llvm::map_to_vector(
newSizesOpFoldResultArr,
[](OpFoldResult sizeVal) { return getConstantIndexOrAssert(sizeVal); });
Value oldAllocOp = oldLogicalObjectFifo.getMemref();
auto oldMemRefType = cast<MemRefType>(oldAllocOp.getType());
MemRefType newAllocType = MemRefType::get(
newSizes, oldMemRefType.getElementType(), MemRefLayoutAttrInterface{},
oldMemRefType.getMemorySpace());
assert(oldAllocOp.getDefiningOp() && "expected a defining op for the value");
rewriter.setInsertionPoint(oldAllocOp.getDefiningOp());
auto newAllocOp =
rewriter.create<memref::AllocOp>(rewriter.getUnknownLoc(), newAllocType);
auto newDeallocOp =
rewriter.create<memref::DeallocOp>(rewriter.getUnknownLoc(), newAllocOp);
newDeallocOp->moveBefore(&newAllocOp->getBlock()->back());
auto type = cast<MemRefType>(newAllocOp.getType());
// Create new logical objectfifo.
rewriter.setInsertionPoint(oldLogicalObjectFifo);
auto newLogicalObjectFifo =
rewriter.create<AMDAIE::LogicalObjectFifoFromMemrefOp>(
rewriter.getUnknownLoc(), LogicalObjectFifoType::get(type),
newAllocOp.getResult(), oldLogicalObjectFifo.getTiles());
return newLogicalObjectFifo;
}

/// Utility to help fetch those input DmaCpyNd Ops which needs to be split.
SmallVector<AMDAIE::DmaCpyNdOp> fetchDmaCpyNdOpsToSplitOrCombine(
Abhishek-Varma marked this conversation as resolved.
Show resolved Hide resolved
Operation *op) {
SmallVector<AMDAIE::DmaCpyNdOp> l2ToL1DmaOps;
// We are currently walking through CoreOps gathering 3rd Input DmaOp (if
// applicable) from them.
// TODO(avarma): We will generalize this later.
op->walk([&](AMDAIE::CoreOp coreOp) {
SmallVector<Value> inputDmas = coreOp.getInputDmas();
if (inputDmas.size() != 3) return WalkResult::skip();
auto dmaCpyNdOp = inputDmas[2].getDefiningOp<AMDAIE::DmaCpyNdOp>();
assert(dmaCpyNdOp && "expected an amdaie.dma_cpy_nd op");
l2ToL1DmaOps.push_back(dmaCpyNdOp);
return WalkResult::advance();
});
return l2ToL1DmaOps;
}

/// Utility to verify that the split dimensions for L2 are contiguous.
static LogicalResult checkIsRangeFromZero(
SmallVector<size_t> &splitDimsSetForL2) {
Expand Down Expand Up @@ -124,6 +175,33 @@ static FailureOr<OpFoldResult> updateL3SourceOffset(IRRewriter &rewriter,
return newL3AsSourceOffset;
}

/// Given a L2->L1 DmaCpyNd op, find the unique L3->L2 DmaCpyNd op.
static FailureOr<AMDAIE::DmaCpyNdOp> fetchL3ToL2DmaCpyNdOp(
AMDAIE::DmaCpyNdOp l2ToL1DmaOp) {
LogicalObjectFifoFromMemrefOp sourceObjectFifo =
l2ToL1DmaOp.getSourceObjectFifo();
SmallVector<AMDAIE::DmaCpyNdOp> l3ToL2DmaOps;
AMDAIE::DmaCpyNdOp l3ToL2DmaOp;
for (Operation *objFifoUserOp : sourceObjectFifo->getUsers()) {
if (auto dmaOp = dyn_cast<AMDAIE::DmaCpyNdOp>(objFifoUserOp);
dmaOp.getTargetObjectFifo() == sourceObjectFifo) {
l3ToL2DmaOps.push_back(dmaOp);
}
}
if (l3ToL2DmaOps.size() == 0) {
LLVM_DEBUG(llvm::dbgs() << "no corresponding L3->L2 dma op found for "
<< sourceObjectFifo << "\n");
return failure();
}
if (l3ToL2DmaOps.size() > 1) {
LLVM_DEBUG(llvm::dbgs() << "found more than one L3->L2 dma ops for "
<< sourceObjectFifo << "\n");
return failure();
}
l3ToL2DmaOp = l3ToL2DmaOps[0];
return l3ToL2DmaOp;
}

/// A struct utility to encapsulate all the data required to perform splitting
/// of logicalobjectfifos.
struct SplittingLogicalObjectFifoData {
Expand Down Expand Up @@ -186,25 +264,10 @@ static LogicalResult checkWhetherSplitIsPossible(
}

// Fetch the L3 -> L2 Dma Op corresponding to the L2 buffer as target.
SmallVector<AMDAIE::DmaCpyNdOp> l3ToL2DmaOps;
AMDAIE::DmaCpyNdOp l3ToL2DmaOp;
for (Operation *objFifoUserOp : sourceObjectFifo->getUsers()) {
if (auto dmaOp = dyn_cast<AMDAIE::DmaCpyNdOp>(objFifoUserOp);
dmaOp.getTargetObjectFifo() == sourceObjectFifo) {
l3ToL2DmaOps.push_back(dmaOp);
}
}
if (l3ToL2DmaOps.size() == 0) {
LLVM_DEBUG(llvm::dbgs() << "no corresponding L3->L2 dma op found for "
<< sourceObjectFifo << "\n");
return failure();
}
if (l3ToL2DmaOps.size() > 1) {
LLVM_DEBUG(llvm::dbgs() << "found more than one L3->L2 dma ops for "
<< sourceObjectFifo << "\n");
return failure();
}
l3ToL2DmaOp = l3ToL2DmaOps[0];
FailureOr<AMDAIE::DmaCpyNdOp> maybeL3ToL2DmaOp =
fetchL3ToL2DmaCpyNdOp(l2ToL1DmaOps[0]);
if (failed(maybeL3ToL2DmaOp)) return failure();
AMDAIE::DmaCpyNdOp l3ToL2DmaOp = maybeL3ToL2DmaOp.value();
if ((l3ToL2DmaOp.getTargetMixedOffsets().size() !=
l3ToL2DmaOp.getSourceMixedOffsets().size()) ||
(l3ToL2DmaOp.getTargetMixedSizes().size() !=
Expand Down Expand Up @@ -293,9 +356,6 @@ LogicalResult splitLogicalObjectFifos(
l3ToL2DmaOp.getTargetMixedOffsets();
SmallVector<OpFoldResult, 4> staticL2AsTargetSizes =
l3ToL2DmaOp.getTargetMixedSizes();
SmallVector<int64_t, 4> l2ShapeAsTarget = llvm::to_vector(
cast<MemRefType>(l3ToL2DmaOp.getTargetObjectFifo().getMemref().getType())
.getShape());
SmallVector<OpFoldResult, 4> staticL3AsSourceOffsets =
l3ToL2DmaOp.getSourceMixedOffsets();
SmallVector<OpFoldResult, 4> staticL3AsSourceSizes =
Expand All @@ -310,7 +370,6 @@ LogicalResult splitLogicalObjectFifos(
staticL2AsTargetSizes[dim] = oneVal;
staticL3AsSourceOffsets[dim] = zeroVal;
staticL3AsSourceSizes[dim] = oneVal;
l2ShapeAsTarget[dim] = 1;
}

// Traverse each L2->L1 DmaCpyNd op and split them.
Expand All @@ -321,34 +380,18 @@ LogicalResult splitLogicalObjectFifos(
l2ToL1DmaOp.getSourceMixedSizes();

// Now we'll create a new L2 buffer based on the new shape inferred earlier
// via `l2ShapeAsTarget`.
rewriter.setInsertionPoint(sourceAllocOp);
LogicalObjectFifoFromMemrefOp targetObjectFifo =
l2ToL1DmaOp.getTargetObjectFifo();
Value targetAllocOp = targetObjectFifo.getMemref();
auto oldSourceMemRefType = cast<MemRefType>(sourceAllocOp.getType());
auto targetMemRefType = cast<MemRefType>(targetAllocOp.getType());
MemRefType newAllocType = MemRefType::get(
l2ShapeAsTarget, targetMemRefType.getElementType(),
MemRefLayoutAttrInterface{}, oldSourceMemRefType.getMemorySpace());
auto newAllocOp = rewriter.create<memref::AllocOp>(rewriter.getUnknownLoc(),
newAllocType);
auto newDeallocOp = rewriter.create<memref::DeallocOp>(
rewriter.getUnknownLoc(), newAllocOp);
newDeallocOp->moveBefore(&newAllocOp->getBlock()->back());
auto type = cast<MemRefType>(newAllocOp.getType());
// Create new logicalobjectfifo.from_memref for the newly created L2 buffer.
rewriter.setInsertionPoint(l2ToL1DmaOp.getSourceObjectFifo());
auto source = rewriter.create<AMDAIE::LogicalObjectFifoFromMemrefOp>(
rewriter.getUnknownLoc(), LogicalObjectFifoType::get(type),
newAllocOp.getResult(), sourceObjectFifo.getTiles());
// via `staticL2AsTargetSizes`.
LogicalObjectFifoFromMemrefOp oldL2ObjectFifo =
l2ToL1DmaOp.getSourceObjectFifo();
AMDAIE::LogicalObjectFifoFromMemrefOp source = createNewLogicalObjectFifo(
rewriter, oldL2ObjectFifo, staticL2AsTargetSizes);

// --------------------------------------------
// ---------- L3 -> L2 splitting --------------
// --------------------------------------------
// Update L3 source offsets for non-split dimensions. Refer doc comment of
// `updateL3SourceOffset` for the computation rationale involved.
SmallVector<OpFoldResult, 4> staticL3AsSourceOffsets =
SmallVector<OpFoldResult> staticL3AsSourceOffsets =
l3ToL2DmaOp.getSourceMixedOffsets();
for (auto &&[splitDim, nonSplitdim] :
llvm::zip_equal(splitDimsForL2, nonSplitDimsForL2)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@

namespace mlir::iree_compiler::AMDAIE {

/// Utility to split logicalobjectfifos given a struct
/// `SplittingLogicalObjectFifoData` which contains all the required data to
/// perform the splitting.
/// Utility to help fetch those input DmaCpyNd Ops which needs to be split.
SmallVector<AMDAIE::DmaCpyNdOp> fetchDmaCpyNdOpsToSplitOrCombine(Operation *op);

/// Utility to split logicalobjectfifos given a vector of L2->L1 dma ops.
LogicalResult splitLogicalObjectFifos(
IRRewriter &rewriter, SmallVector<AMDAIE::DmaCpyNdOp> &l2ToL1DmaOps,
MLIRContext *context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,6 @@ namespace mlir::iree_compiler::AMDAIE {

namespace {

/// Utility to help fetch those input DmaCpyNd Ops which needs to be split.
static SmallVector<AMDAIE::DmaCpyNdOp> fetchDmaCpyNdOpsToSplit(
ModuleOp moduleOp) {
SmallVector<AMDAIE::DmaCpyNdOp> l2ToL1DmaOps;
// We are currently walking through CoreOps gathering 3rd Input DmaOp (if
// applicable) from them.
// TODO(avarma): We will generalize this later.
moduleOp.walk([&](AMDAIE::CoreOp coreOp) {
SmallVector<Value> inputDmas = coreOp.getInputDmas();
if (inputDmas.size() != 3) return WalkResult::skip();
auto dmaCpyNdOp = inputDmas[2].getDefiningOp<AMDAIE::DmaCpyNdOp>();
assert(dmaCpyNdOp && "expected an amdaie.dma_cpy_nd op");
l2ToL1DmaOps.push_back(dmaCpyNdOp);
return WalkResult::advance();
});
return l2ToL1DmaOps;
}

class AMDAIESplitLogicalObjFifosForConnectionReusePass
: public impl::AMDAIESplitLogicalObjFifosForConnectionReuseBase<
AMDAIESplitLogicalObjFifosForConnectionReusePass> {
Expand All @@ -53,7 +35,7 @@ void AMDAIESplitLogicalObjFifosForConnectionReusePass::runOnOperation() {
IRRewriter rewriter(context);

SmallVector<AMDAIE::DmaCpyNdOp> l2ToL1DmaOps =
fetchDmaCpyNdOpsToSplit(moduleOp);
fetchDmaCpyNdOpsToSplitOrCombine(moduleOp);

if (failed(splitLogicalObjectFifos(rewriter, l2ToL1DmaOps, context))) {
LLVM_DEBUG(llvm::dbgs()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,8 @@ void addAMDAIEObjectFifoLoweringPasses(OpPassManager &passManager) {
passManager.addPass(createCSEPass());
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createAMDAIESplitLogicalObjFifosForConnectionReusePass());
passManager.addPass(createCSEPass());
passManager.addPass(createCanonicalizerPass());

passManager.addPass(createAMDAIEDmaToCircularDmaPass());
passManager.addNestedPass<func::FuncOp>(createAMDAIECreateAIEWorkgroupPass());
Expand Down
Loading