-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
16 changed files
with
761 additions
and
634 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
191 changes: 191 additions & 0 deletions
191
compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIECanonicalizeNpuDmaCpyNd.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include <memory> | ||
|
||
#include "iree-amd-aie/IR/AMDAIEDialect.h" | ||
#include "iree-amd-aie/IR/AMDAIEOps.h" | ||
#include "iree-amd-aie/Transforms/Passes.h" | ||
#include "mlir/Dialect/Utils/StaticValueUtils.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Pass/PassManager.h" | ||
|
||
#define DEBUG_TYPE "iree-amdaie-canonicalize-npu-dma-cpy-nd" | ||
|
||
namespace mlir::iree_compiler::AMDAIE { | ||
|
||
class AMDAIECanonicalizeNpuDmaCpyNdPass | ||
: public impl::AMDAIECanonicalizeNpuDmaCpyNdBase< | ||
AMDAIECanonicalizeNpuDmaCpyNdPass> { | ||
public: | ||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry.insert<AMDAIEDialect>(); | ||
} | ||
|
||
AMDAIECanonicalizeNpuDmaCpyNdPass() = default; | ||
AMDAIECanonicalizeNpuDmaCpyNdPass( | ||
const AMDAIECanonicalizeNpuDmaCpyNdPass &){}; | ||
|
||
void runOnOperation() override { | ||
MLIRContext *context = &getContext(); | ||
ModuleOp moduleOp = getOperation(); | ||
IRRewriter rewriter(context); | ||
Attribute zero = rewriter.getIndexAttr(0); | ||
Attribute one = rewriter.getIndexAttr(1); | ||
|
||
WalkResult walkResult = moduleOp->walk([&](NpuDmaCpyNdOp dmaOp) { | ||
SmallVector<OpFoldResult> srcOffsets = dmaOp.getSourceMixedOffsets(); | ||
SmallVector<OpFoldResult> srcSizes = dmaOp.getSourceMixedSizes(); | ||
SmallVector<OpFoldResult> srcStrides = dmaOp.getSourceMixedStrides(); | ||
|
||
SmallVector<OpFoldResult> tgtOffsets = dmaOp.getTargetMixedOffsets(); | ||
SmallVector<OpFoldResult> tgtSizes = dmaOp.getTargetMixedSizes(); | ||
SmallVector<OpFoldResult> tgtStrides = dmaOp.getTargetMixedStrides(); | ||
|
||
bool allValidRanks = srcOffsets.size() <= nbDimensions && | ||
srcSizes.size() <= nbDimensions && | ||
srcStrides.size() <= nbDimensions && | ||
tgtOffsets.size() <= nbDimensions && | ||
tgtSizes.size() <= nbDimensions && | ||
tgtStrides.size() <= nbDimensions; | ||
if (!allValidRanks) { | ||
dmaOp.emitOpError() << " has offsets/sizes/strides attributes that are " | ||
"larger than the target dimension of " | ||
<< nbDimensions << "."; | ||
return WalkResult::interrupt(); | ||
} | ||
|
||
if (dmaOp.getSourceMemorySpaceAsUInt() == 0) { | ||
if (!dmaOp.hasSourceAddressing()) { | ||
dmaOp.emitOpError() | ||
<< "has source in L3, but does not have source addressing. " | ||
"Source addressing is required to canonicalize here."; | ||
return WalkResult::interrupt(); | ||
} | ||
srcOffsets = getPrepended(srcOffsets, zero); | ||
srcSizes = getPrepended(srcSizes, one); | ||
srcStrides = getPrepended(srcStrides, zero); | ||
std::optional<uint32_t> maybeSwapIndex = | ||
verifyAndGetZeroStrideIndex(srcSizes, srcStrides, dmaOp); | ||
if (!maybeSwapIndex.has_value()) { | ||
return WalkResult::interrupt(); | ||
} | ||
uint32_t swapIndex = maybeSwapIndex.value(); | ||
bubble(srcOffsets, swapIndex); | ||
bubble(srcSizes, swapIndex); | ||
bubble(srcStrides, swapIndex); | ||
} | ||
|
||
if (dmaOp.getTargetMemorySpaceAsUInt() == 0) { | ||
if (!dmaOp.hasTargetAddressing()) { | ||
dmaOp.emitOpError() | ||
<< "has target in L3, but does not have target addressing. " | ||
"Target addressing is required to canonicalize here."; | ||
return WalkResult::interrupt(); | ||
} | ||
tgtOffsets = getPrepended(tgtOffsets, zero); | ||
tgtSizes = getPrepended(tgtSizes, one); | ||
tgtStrides = getPrepended(tgtStrides, zero); | ||
std::optional<uint32_t> maybeSwapIndex = | ||
verifyAndGetZeroStrideIndex(tgtSizes, tgtStrides, dmaOp); | ||
if (!maybeSwapIndex.has_value()) { | ||
return WalkResult::interrupt(); | ||
} | ||
uint32_t swapIndex = maybeSwapIndex.value(); | ||
bubble(tgtOffsets, swapIndex); | ||
bubble(tgtSizes, swapIndex); | ||
bubble(tgtStrides, swapIndex); | ||
} | ||
|
||
rewriter.setInsertionPoint(dmaOp); | ||
|
||
// Replace the npu.dma_cpy_nd with the canonicalized version. | ||
dmaOp = rewriter.replaceOpWithNewOp<AMDAIE::NpuDmaCpyNdOp>( | ||
dmaOp, dmaOp.getDma(), dmaOp.getTarget(), tgtOffsets, tgtSizes, | ||
tgtStrides, dmaOp.getTargetBdId(), dmaOp.getSource(), srcOffsets, | ||
srcSizes, srcStrides, dmaOp.getSourceBdId()); | ||
|
||
return WalkResult::advance(); | ||
}); | ||
|
||
if (walkResult.wasInterrupted()) { | ||
return signalPassFailure(); | ||
} | ||
} | ||
|
||
private: | ||
// Repeat prepend 'def' to 'tail' to make 'tail' have nbDimensions elements. | ||
SmallVector<OpFoldResult> getPrepended(ArrayRef<OpFoldResult> tail, | ||
Attribute def) { | ||
assert(tail.size() <= nbDimensions); | ||
SmallVector<OpFoldResult> res(nbDimensions, def); | ||
std::copy(tail.begin(), tail.end(), | ||
res.begin() + nbDimensions - tail.size()); | ||
return res; | ||
} | ||
|
||
static size_t getLowestIndexMaybeAboveOne(ArrayRef<OpFoldResult> v) { | ||
for (size_t i = 0; i < v.size(); i++) { | ||
std::optional<int64_t> maybe = getConstantIntValue(v[i]); | ||
if (!maybe.has_value() || maybe.value() > 1) { | ||
return i; | ||
} | ||
} | ||
return v.size(); | ||
} | ||
|
||
static size_t getHighestIndexMaybeZero(ArrayRef<OpFoldResult> v) { | ||
for (size_t i = v.size(); i > 0; i--) { | ||
std::optional<int64_t> maybe = getConstantIntValue(v[i - 1]); | ||
if (!maybe.has_value() || maybe.value() == 0) { | ||
return i - 1; | ||
} | ||
} | ||
return 0; | ||
} | ||
|
||
/// Get the highest index where the stride is 0. If this index is greater | ||
/// than the lowest index where the size is greater than 1, then fail. | ||
std::optional<uint32_t> verifyAndGetZeroStrideIndex( | ||
ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides, | ||
NpuDmaCpyNdOp dmaOp) { | ||
assert(strides.size() == sizes.size() && strides.size() == nbDimensions); | ||
|
||
size_t firstNonUnitDim = getLowestIndexMaybeAboveOne(sizes); | ||
size_t lastZeroStrideDim = getHighestIndexMaybeZero(strides); | ||
|
||
if (firstNonUnitDim < lastZeroStrideDim) { | ||
// Limitation until AIE-4. | ||
dmaOp.emitOpError("might have stride=0 in dimension ") | ||
<< lastZeroStrideDim << ", and size>1 in dimension " | ||
<< firstNonUnitDim << ". As " << firstNonUnitDim << " < " | ||
<< lastZeroStrideDim | ||
<< ", this cannot be supported -- the zero stride cannot be moved " | ||
"to the outer-most (slowest) dimension, as required by current " | ||
"AIE architecture."; | ||
return {}; | ||
} | ||
return lastZeroStrideDim; | ||
} | ||
|
||
// Example, for swapIndex = 2. | ||
// Input | ||
// [0 1 7 13] | ||
// is mutated to | ||
// [7 0 1 13] | ||
static void bubble(MutableArrayRef<OpFoldResult> arr, size_t swapIndex) { | ||
if (swapIndex > 0) { | ||
std::rotate(arr.begin(), arr.begin() + swapIndex, | ||
arr.begin() + swapIndex + 1); | ||
} | ||
} | ||
}; | ||
|
||
std::unique_ptr<Pass> createAMDAIECanonicalizeNpuDmaCpyNdPass() { | ||
return std::make_unique<AMDAIECanonicalizeNpuDmaCpyNdPass>(); | ||
} | ||
|
||
} // namespace mlir::iree_compiler::AMDAIE |
Oops, something went wrong.