Skip to content

Commit

Permalink
Rectify L2's offset/size for L3->L2 and L2->L1
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhishek-Varma committed Aug 14, 2024
1 parent 73d011f commit 8ad64a0
Showing 1 changed file with 61 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,21 @@ void AMDAIESplitLogicalObjectFifosPass::runOnOperation() {
return WalkResult::advance();
});

if (l2ToL1DmaOps.size() == 0) return;

SmallVector<Value> baseSourceOffsets = l2ToL1DmaOps[0].getSourceOffsets();
DenseSet<unsigned> splitDimensions;
for (unsigned i = 1, n = l2ToL1DmaOps.size(); i < n; i++) {
SmallVector<Value> sourceOffsets = l2ToL1DmaOps[i].getSourceOffsets();
for (unsigned j = 0, m = baseSourceOffsets.size(); j < m; j++) {
if (baseSourceOffsets[j] != sourceOffsets[j]) {
splitDimensions.insert(j);
}
}
}

OpFoldResult zeroVal = getAsIndexOpFoldResult(context, 0);
OpFoldResult oneVal = getAsIndexOpFoldResult(context, 1);
DenseSet<Operation *> toBeErased;
for (AMDAIE::DmaCpyNdOp l2ToL1DmaOp : l2ToL1DmaOps) {
LogicalObjectFifoFromMemrefOp sourceObjectFifo =
Expand All @@ -56,19 +71,6 @@ void AMDAIESplitLogicalObjectFifosPass::runOnOperation() {
l2ToL1DmaOp.getTargetObjectFifo();
Value targetAllocOp = targetObjectFifo.getMemref();

// Now we'll create a narrowed linearized L2 buffer.
rewriter.setInsertionPoint(sourceAllocOp);
auto oldSourceMemRefType = cast<MemRefType>(sourceAllocOp.getType());
auto targetMemRefType = cast<MemRefType>(targetAllocOp.getType());
MemRefType newAllocType = MemRefType::get(
targetMemRefType.getNumElements(), targetMemRefType.getElementType(),
MemRefLayoutAttrInterface{}, oldSourceMemRefType.getMemorySpace());
auto newAllocOp = rewriter.create<memref::AllocOp>(rewriter.getUnknownLoc(),
newAllocType);
auto newDeallocOp = rewriter.create<memref::DeallocOp>(
rewriter.getUnknownLoc(), newAllocOp);
newDeallocOp->moveBefore(&newAllocOp->getBlock()->back());

// Fetch the L3 -> L2 Dma Op corresponding to the L2 buffer as target.
AMDAIE::DmaCpyNdOp l3ToL2DmaOp;
for (Operation *objFifoUserOp : sourceObjectFifo->getUsers()) {
Expand All @@ -82,43 +84,68 @@ void AMDAIESplitLogicalObjectFifosPass::runOnOperation() {
toBeErased.insert(sourceAllocOp);
toBeErased.insert(sourceObjectFifo);

SmallVector<OpFoldResult, 6> staticL2AsSourceOffsets =
l2ToL1DmaOp.getSourceMixedOffsets();
SmallVector<OpFoldResult, 6> staticL2AsSourceSizes =
l2ToL1DmaOp.getSourceMixedSizes();
SmallVector<OpFoldResult, 6> staticL2AsSourceStrides =
l2ToL1DmaOp.getSourceMixedStrides();
SmallVector<OpFoldResult, 4> staticL2AsTargetOffsets =
l3ToL2DmaOp.getTargetMixedOffsets();
SmallVector<OpFoldResult, 4> staticL2AsTargetSizes =
l3ToL2DmaOp.getTargetMixedSizes();
SmallVector<OpFoldResult, 4> staticL2AsTargetStrides =
l3ToL2DmaOp.getTargetMixedStrides();
SmallVector<int64_t, 4> l2ShapeAsTarget = llvm::to_vector(
cast<MemRefType>(
l3ToL2DmaOp.getTargetObjectFifo().getMemref().getType())
.getShape());
for (unsigned dim : splitDimensions) {
staticL2AsSourceOffsets[dim] = zeroVal;
staticL2AsSourceSizes[dim] = oneVal;
staticL2AsTargetOffsets[dim] = zeroVal;
staticL2AsTargetSizes[dim] = oneVal;
l2ShapeAsTarget[dim] = 1;
}

// Now we'll create a narrowed linearized L2 buffer.
rewriter.setInsertionPoint(sourceAllocOp);
auto oldSourceMemRefType = cast<MemRefType>(sourceAllocOp.getType());
auto targetMemRefType = cast<MemRefType>(targetAllocOp.getType());
MemRefType newAllocType = MemRefType::get(
l2ShapeAsTarget, targetMemRefType.getElementType(),
MemRefLayoutAttrInterface{}, oldSourceMemRefType.getMemorySpace());
auto newAllocOp = rewriter.create<memref::AllocOp>(rewriter.getUnknownLoc(),
newAllocType);
auto newDeallocOp = rewriter.create<memref::DeallocOp>(
rewriter.getUnknownLoc(), newAllocOp);
newDeallocOp->moveBefore(&newAllocOp->getBlock()->back());

auto type = cast<MemRefType>(newAllocOp.getType());
// Create new logicalobjectfifo.from_memref for the newly created L2 buffer.
rewriter.setInsertionPoint(l2ToL1DmaOp.getSourceObjectFifo());
auto source = rewriter.create<AMDAIE::LogicalObjectFifoFromMemrefOp>(
rewriter.getUnknownLoc(), LogicalObjectFifoType::get(type),
newAllocOp.getResult(), sourceObjectFifo.getTiles());

// Create new L3 -> L2 Dma Op. Since the narrowed L2 buffer is linearized,
// we need to form offset/size/stride corresponding to the linearized
// buffer.
SmallVector<OpFoldResult, 4> staticOffsets(
4, getAsIndexOpFoldResult(context, 0));
SmallVector<OpFoldResult, 4> staticSizes(
4, getAsIndexOpFoldResult(context, 1));
SmallVector<OpFoldResult, 4> staticStrides(
4, getAsIndexOpFoldResult(context, 0));
OpFoldResult linearizedShape =
getAsIndexOpFoldResult(context, newAllocType.getNumElements());
staticSizes[staticSizes.size() - 1] = linearizedShape;
staticStrides[staticStrides.size() - 1] =
getAsIndexOpFoldResult(context, 1);
staticStrides[staticStrides.size() - 2] = linearizedShape;
// Create new L3 -> L2 Dma Op.
rewriter.setInsertionPoint(l3ToL2DmaOp);
rewriter.create<AMDAIE::DmaCpyNdOp>(
l3ToL2DmaOp.getLoc(), source, llvm::ArrayRef(staticOffsets),
llvm::ArrayRef(staticSizes), llvm::ArrayRef(staticStrides),
l3ToL2DmaOp.getSource(), l3ToL2DmaOp.getSourceMixedOffsets(),
l3ToL2DmaOp.getSourceMixedSizes(), l3ToL2DmaOp.getSourceMixedStrides());
l3ToL2DmaOp.getLoc(), source, llvm::ArrayRef(staticL2AsTargetOffsets),
llvm::ArrayRef(staticL2AsTargetSizes),
llvm::ArrayRef(staticL2AsTargetStrides), l3ToL2DmaOp.getSource(),
l3ToL2DmaOp.getSourceMixedOffsets(), l3ToL2DmaOp.getSourceMixedSizes(),
l3ToL2DmaOp.getSourceMixedStrides());

// Create new L2 -> L1 Input DmaOp.
rewriter.setInsertionPoint(l2ToL1DmaOp);
auto newL2ToL1DmaOp = rewriter.create<AMDAIE::DmaCpyNdOp>(
l2ToL1DmaOp.getLoc(), l2ToL1DmaOp.getTarget(),
l2ToL1DmaOp.getTargetMixedOffsets(), l2ToL1DmaOp.getTargetMixedSizes(),
l2ToL1DmaOp.getTargetMixedStrides(), source,
llvm::ArrayRef(staticOffsets), llvm::ArrayRef(staticSizes),
llvm::ArrayRef(staticStrides));
llvm::ArrayRef(staticL2AsSourceOffsets),
llvm::ArrayRef(staticL2AsSourceSizes),
llvm::ArrayRef(staticL2AsSourceStrides));
rewriter.replaceOp(l2ToL1DmaOp, newL2ToL1DmaOp);
// We have to discard non-zero offsets as subview has been replaced by a
// dedicated allocated memref.
Expand Down

0 comments on commit 8ad64a0

Please sign in to comment.