Skip to content

Commit

Permalink
[ObjectFifo][NFC] Refactor DmaUtils + SplitLogicalObjectFifoForReuse (#…
Browse files Browse the repository at this point in the history
…759)

-- This commit refactors a few utilities in DmaUtils as well as
a few utilities used by the SplitLogicalObjectFifosForReuse pass.
-- This is required for the follow-up PR that adds a new pass
`--iree-amdaie-combine-logical-objectfifos-for-connection-reuse`.

Signed-off-by: Abhishek Varma <[email protected]>
  • Loading branch information
Abhishek-Varma authored Sep 10, 2024
1 parent 09576c8 commit c041358
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 121 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,60 +52,6 @@ int64_t calculateNbIterations(int64_t lowerBound, int64_t upperBound,

namespace {

/// Utility affine expression visitor to retrieve the scale and optional bias
/// from the expression.
///
/// Recognized shapes:
///   - a multiplication of an `AffineDimExpr` with an `AffineConstantExpr`
///     (in either operand order), which records the constant in `scale`;
///   - an addition of a non-negative `AffineConstantExpr` to either an
///     `AffineDimExpr` (records `bias` and `scale = 1`) or to a nested binary
///     expression (records `bias`, then visits the nested expression).
/// Bare dim, symbol and constant expressions are rejected with `failure()`.
///
/// NOTE: `bias` is written before the non-negativity check, so the fields may
/// hold partial state after a failed visit; only read them on success.
struct RetrieveScaleAndBias
    : public AffineExprVisitor<RetrieveScaleAndBias, LogicalResult> {
  /// Constant multiplier of the dimension; unset until a match records it.
  std::optional<int64_t> scale;
  /// Constant addend; unset until an addition has been visited.
  std::optional<int64_t> bias;

  /// Generic binary expressions are not supported.
  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr /*expr*/) {
    return failure();
  }
  /// A lone constant carries neither a scale nor a bias on a dimension.
  LogicalResult visitConstantExpr(AffineConstantExpr /*expr*/) {
    return failure();
  }
  /// A lone dimension is not supported.
  LogicalResult visitDimExpr(AffineDimExpr /*expr*/) { return failure(); }
  /// A lone symbol is not supported.
  LogicalResult visitSymbolExpr(AffineSymbolExpr /*expr*/) { return failure(); }

  /// Records `scale` for `dim * constant` or `constant * dim`.
  ///
  /// Returns `failure()` when one operand is a dimension but the other one is
  /// not a constant (e.g. `d0 * s0`); previously this path dereferenced a null
  /// `AffineConstantExpr`. When neither operand is a dimension, `scale` is left
  /// untouched and `success()` is returned, as before.
  LogicalResult visitMulExpr(AffineBinaryOpExpr expr) {
    if (isa<AffineDimExpr>(expr.getLHS())) {
      auto rhsSize = dyn_cast<AffineConstantExpr>(expr.getRHS());
      // The `dyn_cast` result must be checked before use.
      if (!rhsSize) return failure();
      scale = rhsSize.getValue();
    } else if (isa<AffineDimExpr>(expr.getRHS())) {
      auto lhsSize = dyn_cast<AffineConstantExpr>(expr.getLHS());
      if (!lhsSize) return failure();
      scale = lhsSize.getValue();
    }
    return success();
  }

  /// Records `bias` for `X + constant` or `constant + X`, where `X` is either a
  /// dimension (then `scale = 1`) or a nested binary expression that is visited
  /// in turn. Fails if a bias was already recorded, if the constant is
  /// negative, if neither operand is a constant, or if `X` has another form.
  LogicalResult visitAddExpr(AffineBinaryOpExpr expr) {
    // Only a single addition is supported.
    if (bias) return failure();
    if (auto rhsSize = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
      bias = rhsSize.getValue();
      if (bias.value() < 0) return failure();
      if (isa<AffineBinaryOpExpr>(expr.getLHS())) {
        return visit(expr.getLHS());
      } else if (isa<AffineDimExpr>(expr.getLHS())) {
        scale = 1;
        return success();
      } else {
        return failure();
      }
    } else if (auto lhsSize = dyn_cast<AffineConstantExpr>(expr.getLHS())) {
      bias = lhsSize.getValue();
      if (bias.value() < 0) return failure();
      if (isa<AffineBinaryOpExpr>(expr.getRHS())) {
        return visit(expr.getRHS());
      } else if (isa<AffineDimExpr>(expr.getRHS())) {
        scale = 1;
        return success();
      } else {
        return failure();
      }
    } else {
      return failure();
    }
  }
};

struct SubsumeLoopIntoDMA
: public OpInterfaceRewritePattern<AMDAIE::DoublyStridedOpInterface> {
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,71 @@
#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "iree-amd-aie/aie_runtime/iree_aie_runtime.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir::iree_compiler::AMDAIE {

/// Utility to retrieve a constant index from an OpFoldResult.
int64_t getConstantIndexOrAssert(OpFoldResult dim);

/// Utility affine expression visitor to retrieve the scale and optional bias
/// from the expression.
///
/// Recognized shapes:
///   - a multiplication of an `AffineDimExpr` with an `AffineConstantExpr`
///     (in either operand order), which records the constant in `scale`;
///   - an addition of a non-negative `AffineConstantExpr` to either an
///     `AffineDimExpr` (records `bias` and `scale = 1`) or to a nested binary
///     expression (records `bias`, then visits the nested expression).
/// Bare dim, symbol and constant expressions are rejected with `failure()`.
///
/// NOTE: `bias` is written before the non-negativity check, so the fields may
/// hold partial state after a failed visit; only read them on success.
struct RetrieveScaleAndBias
    : public AffineExprVisitor<RetrieveScaleAndBias, LogicalResult> {
  /// Constant multiplier of the dimension; unset until a match records it.
  std::optional<int64_t> scale;
  /// Constant addend; unset until an addition has been visited.
  std::optional<int64_t> bias;

  /// Generic binary expressions are not supported.
  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr /*expr*/) {
    return failure();
  }
  /// A lone constant carries neither a scale nor a bias on a dimension.
  LogicalResult visitConstantExpr(AffineConstantExpr /*expr*/) {
    return failure();
  }
  /// A lone dimension is not supported.
  LogicalResult visitDimExpr(AffineDimExpr /*expr*/) { return failure(); }
  /// A lone symbol is not supported.
  LogicalResult visitSymbolExpr(AffineSymbolExpr /*expr*/) { return failure(); }

  /// Records `scale` for `dim * constant` or `constant * dim`.
  ///
  /// Returns `failure()` when one operand is a dimension but the other one is
  /// not a constant (e.g. `d0 * s0`); previously this path dereferenced a null
  /// `AffineConstantExpr`. When neither operand is a dimension, `scale` is left
  /// untouched and `success()` is returned, as before.
  LogicalResult visitMulExpr(AffineBinaryOpExpr expr) {
    if (isa<AffineDimExpr>(expr.getLHS())) {
      auto rhsSize = dyn_cast<AffineConstantExpr>(expr.getRHS());
      // The `dyn_cast` result must be checked before use.
      if (!rhsSize) return failure();
      scale = rhsSize.getValue();
    } else if (isa<AffineDimExpr>(expr.getRHS())) {
      auto lhsSize = dyn_cast<AffineConstantExpr>(expr.getLHS());
      if (!lhsSize) return failure();
      scale = lhsSize.getValue();
    }
    return success();
  }

  /// Records `bias` for `X + constant` or `constant + X`, where `X` is either a
  /// dimension (then `scale = 1`) or a nested binary expression that is visited
  /// in turn. Fails if a bias was already recorded, if the constant is
  /// negative, if neither operand is a constant, or if `X` has another form.
  LogicalResult visitAddExpr(AffineBinaryOpExpr expr) {
    // Only a single addition is supported.
    if (bias) return failure();
    if (auto rhsSize = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
      bias = rhsSize.getValue();
      if (bias.value() < 0) return failure();
      if (isa<AffineBinaryOpExpr>(expr.getLHS())) {
        return visit(expr.getLHS());
      } else if (isa<AffineDimExpr>(expr.getLHS())) {
        scale = 1;
        return success();
      } else {
        return failure();
      }
    } else if (auto lhsSize = dyn_cast<AffineConstantExpr>(expr.getLHS())) {
      bias = lhsSize.getValue();
      if (bias.value() < 0) return failure();
      if (isa<AffineBinaryOpExpr>(expr.getRHS())) {
        return visit(expr.getRHS());
      } else if (isa<AffineDimExpr>(expr.getRHS())) {
        scale = 1;
        return success();
      } else {
        return failure();
      }
    } else {
      return failure();
    }
  }
};

// Constant specifying the number of inter-iteration dimension for DMA
// operations.
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <numeric>

#include "iree-amd-aie/Transforms/AMDAIEDmaUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
Expand All @@ -21,6 +22,56 @@

namespace mlir::iree_compiler::AMDAIE {

/// Builds a new logical objectfifo backed by a freshly allocated memref whose
/// static shape is taken from `newSizesOpFoldResultArr`. Element type and
/// memory space are copied from the memref of `oldLogicalObjectFifo`, and the
/// tiles of the old objectfifo are reused. The new alloc is inserted before
/// the op defining the old memref, its dealloc is placed before the last op of
/// the same block, and the new objectfifo is created right before the old one.
/// The caller's insertion point is restored on return.
static AMDAIE::LogicalObjectFifoFromMemrefOp createNewLogicalObjectFifo(
    IRRewriter &rewriter,
    AMDAIE::LogicalObjectFifoFromMemrefOp &oldLogicalObjectFifo,
    SmallVectorImpl<OpFoldResult> &newSizesOpFoldResultArr) {
  OpBuilder::InsertionGuard guard(rewriter);
  Location loc = rewriter.getUnknownLoc();

  // Resolve every requested size to a constant.
  SmallVector<int64_t> staticShape;
  staticShape.reserve(newSizesOpFoldResultArr.size());
  for (OpFoldResult size : newSizesOpFoldResultArr)
    staticShape.push_back(getConstantIndexOrAssert(size));

  // Derive the new buffer type from the old one, swapping in the new shape.
  Value oldBuffer = oldLogicalObjectFifo.getMemref();
  auto oldBufferType = cast<MemRefType>(oldBuffer.getType());
  MemRefType newBufferType = MemRefType::get(
      staticShape, oldBufferType.getElementType(), MemRefLayoutAttrInterface{},
      oldBufferType.getMemorySpace());
  Operation *oldBufferDef = oldBuffer.getDefiningOp();
  assert(oldBufferDef && "expected a defining op for the value");

  // Allocate next to the old buffer; deallocate just before the block's last
  // operation.
  rewriter.setInsertionPoint(oldBufferDef);
  auto newAlloc = rewriter.create<memref::AllocOp>(loc, newBufferType);
  auto newDealloc = rewriter.create<memref::DeallocOp>(loc, newAlloc);
  newDealloc->moveBefore(&newAlloc->getBlock()->back());

  // Create the new logical objectfifo.
  rewriter.setInsertionPoint(oldLogicalObjectFifo);
  return rewriter.create<AMDAIE::LogicalObjectFifoFromMemrefOp>(
      loc, LogicalObjectFifoType::get(cast<MemRefType>(newAlloc.getType())),
      newAlloc.getResult(), oldLogicalObjectFifo.getTiles());
}

/// Collects the DmaCpyNd ops that are candidates for splitting or combining:
/// for every CoreOp nested under `op` that has exactly three input DMAs, the
/// DmaCpyNd op defining the third input is gathered. CoreOps with any other
/// number of input DMAs are skipped.
SmallVector<AMDAIE::DmaCpyNdOp> fetchDmaCpyNdOpsToSplitOrCombine(
    Operation *op) {
  SmallVector<AMDAIE::DmaCpyNdOp> candidateDmaOps;
  // Only the third input DMA of each CoreOp is considered for now.
  // TODO(avarma): We will generalize this later.
  op->walk([&](AMDAIE::CoreOp coreOp) {
    SmallVector<Value> coreInputDmas = coreOp.getInputDmas();
    if (coreInputDmas.size() == 3) {
      auto thirdInputDma =
          coreInputDmas[2].getDefiningOp<AMDAIE::DmaCpyNdOp>();
      assert(thirdInputDma && "expected an amdaie.dma_cpy_nd op");
      candidateDmaOps.push_back(thirdInputDma);
      return WalkResult::advance();
    }
    return WalkResult::skip();
  });
  return candidateDmaOps;
}

/// Utility to verify that the split dimensions for L2 are contiguous.
static LogicalResult checkIsRangeFromZero(
SmallVector<size_t> &splitDimsSetForL2) {
Expand Down Expand Up @@ -124,6 +175,33 @@ static FailureOr<OpFoldResult> updateL3SourceOffset(IRRewriter &rewriter,
return newL3AsSourceOffset;
}

/// Given a L2->L1 DmaCpyNd op, find the unique L3->L2 DmaCpyNd op, i.e. the
/// single DmaCpyNd user of the L2->L1 op's source logical objectfifo that has
/// this objectfifo as its target.
///
/// Returns `failure()` (with a debug message) if there is no such op or if
/// there is more than one. Users that are not DmaCpyNd ops are ignored.
static FailureOr<AMDAIE::DmaCpyNdOp> fetchL3ToL2DmaCpyNdOp(
    AMDAIE::DmaCpyNdOp l2ToL1DmaOp) {
  LogicalObjectFifoFromMemrefOp sourceObjectFifo =
      l2ToL1DmaOp.getSourceObjectFifo();
  SmallVector<AMDAIE::DmaCpyNdOp> l3ToL2DmaOps;
  for (Operation *objFifoUserOp : sourceObjectFifo->getUsers()) {
    // `dyn_cast` yields a null op for users of another kind, so the result
    // must be checked before it is queried.
    auto dmaOp = dyn_cast<AMDAIE::DmaCpyNdOp>(objFifoUserOp);
    if (dmaOp && dmaOp.getTargetObjectFifo() == sourceObjectFifo) {
      l3ToL2DmaOps.push_back(dmaOp);
    }
  }
  if (l3ToL2DmaOps.empty()) {
    LLVM_DEBUG(llvm::dbgs() << "no corresponding L3->L2 dma op found for "
                            << sourceObjectFifo << "\n");
    return failure();
  }
  if (l3ToL2DmaOps.size() > 1) {
    LLVM_DEBUG(llvm::dbgs() << "found more than one L3->L2 dma ops for "
                            << sourceObjectFifo << "\n");
    return failure();
  }
  AMDAIE::DmaCpyNdOp l3ToL2DmaOp = l3ToL2DmaOps.front();
  return l3ToL2DmaOp;
}

/// A struct utility to encapsulate all the data required to perform splitting
/// of logicalobjectfifos.
struct SplittingLogicalObjectFifoData {
Expand Down Expand Up @@ -186,25 +264,10 @@ static LogicalResult checkWhetherSplitIsPossible(
}

// Fetch the L3 -> L2 Dma Op corresponding to the L2 buffer as target.
SmallVector<AMDAIE::DmaCpyNdOp> l3ToL2DmaOps;
AMDAIE::DmaCpyNdOp l3ToL2DmaOp;
for (Operation *objFifoUserOp : sourceObjectFifo->getUsers()) {
if (auto dmaOp = dyn_cast<AMDAIE::DmaCpyNdOp>(objFifoUserOp);
dmaOp.getTargetObjectFifo() == sourceObjectFifo) {
l3ToL2DmaOps.push_back(dmaOp);
}
}
if (l3ToL2DmaOps.size() == 0) {
LLVM_DEBUG(llvm::dbgs() << "no corresponding L3->L2 dma op found for "
<< sourceObjectFifo << "\n");
return failure();
}
if (l3ToL2DmaOps.size() > 1) {
LLVM_DEBUG(llvm::dbgs() << "found more than one L3->L2 dma ops for "
<< sourceObjectFifo << "\n");
return failure();
}
l3ToL2DmaOp = l3ToL2DmaOps[0];
FailureOr<AMDAIE::DmaCpyNdOp> maybeL3ToL2DmaOp =
fetchL3ToL2DmaCpyNdOp(l2ToL1DmaOps[0]);
if (failed(maybeL3ToL2DmaOp)) return failure();
AMDAIE::DmaCpyNdOp l3ToL2DmaOp = maybeL3ToL2DmaOp.value();
if ((l3ToL2DmaOp.getTargetMixedOffsets().size() !=
l3ToL2DmaOp.getSourceMixedOffsets().size()) ||
(l3ToL2DmaOp.getTargetMixedSizes().size() !=
Expand Down Expand Up @@ -293,9 +356,6 @@ LogicalResult splitLogicalObjectFifos(
l3ToL2DmaOp.getTargetMixedOffsets();
SmallVector<OpFoldResult, 4> staticL2AsTargetSizes =
l3ToL2DmaOp.getTargetMixedSizes();
SmallVector<int64_t, 4> l2ShapeAsTarget = llvm::to_vector(
cast<MemRefType>(l3ToL2DmaOp.getTargetObjectFifo().getMemref().getType())
.getShape());
SmallVector<OpFoldResult, 4> staticL3AsSourceOffsets =
l3ToL2DmaOp.getSourceMixedOffsets();
SmallVector<OpFoldResult, 4> staticL3AsSourceSizes =
Expand All @@ -310,7 +370,6 @@ LogicalResult splitLogicalObjectFifos(
staticL2AsTargetSizes[dim] = oneVal;
staticL3AsSourceOffsets[dim] = zeroVal;
staticL3AsSourceSizes[dim] = oneVal;
l2ShapeAsTarget[dim] = 1;
}

// Traverse each L2->L1 DmaCpyNd op and split them.
Expand All @@ -321,34 +380,18 @@ LogicalResult splitLogicalObjectFifos(
l2ToL1DmaOp.getSourceMixedSizes();

// Now we'll create a new L2 buffer based on the new shape inferred earlier
// via `l2ShapeAsTarget`.
rewriter.setInsertionPoint(sourceAllocOp);
LogicalObjectFifoFromMemrefOp targetObjectFifo =
l2ToL1DmaOp.getTargetObjectFifo();
Value targetAllocOp = targetObjectFifo.getMemref();
auto oldSourceMemRefType = cast<MemRefType>(sourceAllocOp.getType());
auto targetMemRefType = cast<MemRefType>(targetAllocOp.getType());
MemRefType newAllocType = MemRefType::get(
l2ShapeAsTarget, targetMemRefType.getElementType(),
MemRefLayoutAttrInterface{}, oldSourceMemRefType.getMemorySpace());
auto newAllocOp = rewriter.create<memref::AllocOp>(rewriter.getUnknownLoc(),
newAllocType);
auto newDeallocOp = rewriter.create<memref::DeallocOp>(
rewriter.getUnknownLoc(), newAllocOp);
newDeallocOp->moveBefore(&newAllocOp->getBlock()->back());
auto type = cast<MemRefType>(newAllocOp.getType());
// Create new logicalobjectfifo.from_memref for the newly created L2 buffer.
rewriter.setInsertionPoint(l2ToL1DmaOp.getSourceObjectFifo());
auto source = rewriter.create<AMDAIE::LogicalObjectFifoFromMemrefOp>(
rewriter.getUnknownLoc(), LogicalObjectFifoType::get(type),
newAllocOp.getResult(), sourceObjectFifo.getTiles());
// via `staticL2AsTargetSizes`.
LogicalObjectFifoFromMemrefOp oldL2ObjectFifo =
l2ToL1DmaOp.getSourceObjectFifo();
AMDAIE::LogicalObjectFifoFromMemrefOp source = createNewLogicalObjectFifo(
rewriter, oldL2ObjectFifo, staticL2AsTargetSizes);

// --------------------------------------------
// ---------- L3 -> L2 splitting --------------
// --------------------------------------------
// Update L3 source offsets for non-split dimensions. Refer doc comment of
// `updateL3SourceOffset` for the computation rationale involved.
SmallVector<OpFoldResult, 4> staticL3AsSourceOffsets =
SmallVector<OpFoldResult> staticL3AsSourceOffsets =
l3ToL2DmaOp.getSourceMixedOffsets();
for (auto &&[splitDim, nonSplitdim] :
llvm::zip_equal(splitDimsForL2, nonSplitDimsForL2)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@

namespace mlir::iree_compiler::AMDAIE {

/// Utility to split logicalobjectfifos given a struct
/// `SplittingLogicalObjectFifoData` which contains all the required data to
/// perform the splitting.
/// Utility to fetch the input DmaCpyNd ops that need to be split or combined.
SmallVector<AMDAIE::DmaCpyNdOp> fetchDmaCpyNdOpsToSplitOrCombine(Operation *op);

/// Utility to split logicalobjectfifos given a vector of L2->L1 dma ops.
LogicalResult splitLogicalObjectFifos(
IRRewriter &rewriter, SmallVector<AMDAIE::DmaCpyNdOp> &l2ToL1DmaOps,
MLIRContext *context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,6 @@ namespace mlir::iree_compiler::AMDAIE {

namespace {

/// Collects the DmaCpyNd ops that are candidates for splitting: for every
/// CoreOp nested under `moduleOp` that has exactly three input DMAs, the
/// DmaCpyNd op defining the third input is gathered. CoreOps with any other
/// number of input DMAs are skipped.
static SmallVector<AMDAIE::DmaCpyNdOp> fetchDmaCpyNdOpsToSplit(
    ModuleOp moduleOp) {
  SmallVector<AMDAIE::DmaCpyNdOp> candidateDmaOps;
  // Only the third input DMA of each CoreOp is considered for now.
  // TODO(avarma): We will generalize this later.
  moduleOp.walk([&](AMDAIE::CoreOp coreOp) {
    SmallVector<Value> coreInputDmas = coreOp.getInputDmas();
    if (coreInputDmas.size() == 3) {
      auto thirdInputDma =
          coreInputDmas[2].getDefiningOp<AMDAIE::DmaCpyNdOp>();
      assert(thirdInputDma && "expected an amdaie.dma_cpy_nd op");
      candidateDmaOps.push_back(thirdInputDma);
      return WalkResult::advance();
    }
    return WalkResult::skip();
  });
  return candidateDmaOps;
}

class AMDAIESplitLogicalObjFifosForConnectionReusePass
: public impl::AMDAIESplitLogicalObjFifosForConnectionReuseBase<
AMDAIESplitLogicalObjFifosForConnectionReusePass> {
Expand All @@ -53,7 +35,7 @@ void AMDAIESplitLogicalObjFifosForConnectionReusePass::runOnOperation() {
IRRewriter rewriter(context);

SmallVector<AMDAIE::DmaCpyNdOp> l2ToL1DmaOps =
fetchDmaCpyNdOpsToSplit(moduleOp);
fetchDmaCpyNdOpsToSplitOrCombine(moduleOp);

if (failed(splitLogicalObjectFifos(rewriter, l2ToL1DmaOps, context))) {
LLVM_DEBUG(llvm::dbgs()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,8 @@ void addAMDAIEObjectFifoLoweringPasses(OpPassManager &passManager) {
passManager.addPass(createCSEPass());
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createAMDAIESplitLogicalObjFifosForConnectionReusePass());
passManager.addPass(createCSEPass());
passManager.addPass(createCanonicalizerPass());

passManager.addPass(createAMDAIEDmaToCircularDmaPass());
passManager.addNestedPass<func::FuncOp>(createAMDAIECreateAIEWorkgroupPass());
Expand Down

0 comments on commit c041358

Please sign in to comment.