Skip to content

Commit

Permalink
Add logical objFifo placeholder op for connection reuse (#709)
Browse files Browse the repository at this point in the history
Adds the `amdaie.logicalobjectfifo.placeholder` operation that
represents a logical objectFifo to be filled in later. This enables
reuse of connections/circular DMAs/physical AIE channels for different
data packets, which helps with (fused) operations with more than 2
inputs. This is especially useful for reading from/writing to DDR.
  • Loading branch information
jtuyls authored Aug 27, 2024
1 parent e68cbd2 commit f8f31a8
Show file tree
Hide file tree
Showing 15 changed files with 782 additions and 327 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-amd-aie/IR/AMDAIELogicalObjFifoOpInterface.h"

/// Include the definitions of the logical-objFifo-like interfaces.
#include "iree-amd-aie/IR/AMDAIELogicalObjFifoOpInterface.cpp.inc"
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_COMPILER_AMDAIE_LOGICALOBJFIFOOPINTERFACE_H_
#define IREE_COMPILER_AMDAIE_LOGICALOBJFIFOOPINTERFACE_H_

#include "mlir/IR/OpImplementation.h"

// clang-format off
#include "iree-amd-aie/IR/AMDAIELogicalObjFifoOpInterface.h.inc"
// clang-format on

#endif // IREE_COMPILER_AMDAIE_LOGICALOBJFIFOOPINTERFACE_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_AMDAIE_DIALECT_LOGICALOBJFIFOOPINTERFACE
#define IREE_AMDAIE_DIALECT_LOGICALOBJFIFOOPINTERFACE

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/CopyOpInterface.td"

//===----------------------------------------------------------------------===//
// Defines the interface for logical objectFifo operations.
//===----------------------------------------------------------------------===//

// Op interface for operations that create a logical objectFifo. It exposes a
// single accessor, `getTiles`, so that code can query the assigned tiles
// through the interface without knowing the concrete operation type.
def LogicalObjFifoOpInterface : OpInterface<"LogicalObjFifoOpInterface"> {
let description = [{
Interface for operations creating a logical objectFifo.
}];
// C++ namespace in which the generated interface class is placed.
let cppNamespace = "mlir::iree_compiler::AMDAIE";

let methods = [
// `getTiles`: returns the tile operands as an operand range. No explicit
// method body is given; the default implementation simply forwards to the
// implementing op's own `getTiles()` accessor, so any op attaching this
// interface must provide an accessor with that name (or override the method).
InterfaceMethod<
/*desc=*/"Return the assigned tiles.",
/*retTy=*/"::mlir::OperandRange",
/*methodName=*/"getTiles",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op.getTiles();
}]
>
];
}

#endif // IREE_AMDAIE_DIALECT_LOGICALOBJFIFOOPINTERFACE
275 changes: 239 additions & 36 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -522,16 +522,14 @@ void LogicalObjectFifoRelease::build(OpBuilder &b, mlir::OperationState &result,
// AMDAIE_NpuDmaCpyNdOp
//===----------------------------------------------------------------------===//

// Build a NpuDmaCpyNdOp with mixed static and dynamic entries and target
// and source BD IDs.
void NpuDmaCpyNdOp::build(OpBuilder &b, OperationState &result, Value dma,
ArrayRef<OpFoldResult> targetOffsets,
ArrayRef<OpFoldResult> targetSizes,
ArrayRef<OpFoldResult> targetStrides,
ArrayRef<OpFoldResult> sourceOffsets,
ArrayRef<OpFoldResult> sourceSizes,
ArrayRef<OpFoldResult> sourceStrides,
mlir::Value targetBdId, mlir::Value sourceBdId) {
// Build a NpuDmaCpyNdOp with mixed static and dynamic entries and target and
// source BD IDs.
void NpuDmaCpyNdOp::build(
OpBuilder &b, OperationState &result, Value dma, Value target,
ArrayRef<OpFoldResult> targetOffsets, ArrayRef<OpFoldResult> targetSizes,
ArrayRef<OpFoldResult> targetStrides, Value targetBdId, Value source,
ArrayRef<OpFoldResult> sourceOffsets, ArrayRef<OpFoldResult> sourceSizes,
ArrayRef<OpFoldResult> sourceStrides, Value sourceBdId) {
SmallVector<int64_t> staticTargetOffsets, staticTargetSizes,
staticTargetStrides;
SmallVector<int64_t> staticSourceOffsets, staticSourceSizes,
Expand All @@ -552,22 +550,21 @@ void NpuDmaCpyNdOp::build(OpBuilder &b, OperationState &result, Value dma,
staticSourceSizes);
dispatchIndexOpFoldResults(sourceStrides, dynamicSourceStrides,
staticSourceStrides);
build(b, result, b.getIndexType(), dma, dynamicTargetOffsets,
build(b, result, b.getIndexType(), dma, target, dynamicTargetOffsets,
dynamicTargetSizes, dynamicTargetStrides, staticTargetOffsets,
staticTargetSizes, staticTargetStrides, dynamicSourceOffsets,
dynamicSourceSizes, dynamicSourceStrides, staticSourceOffsets,
staticSourceSizes, staticSourceStrides, targetBdId, sourceBdId);
staticTargetSizes, staticTargetStrides, targetBdId, source,
dynamicSourceOffsets, dynamicSourceSizes, dynamicSourceStrides,
staticSourceOffsets, staticSourceSizes, staticSourceStrides,
sourceBdId);
}

// Build a NpuDmaCpyNdOp with static entries.
void NpuDmaCpyNdOp::build(OpBuilder &b, OperationState &result, Value dma,
ArrayRef<int64_t> targetOffsets,
ArrayRef<int64_t> targetSizes,
ArrayRef<int64_t> targetStrides,
ArrayRef<int64_t> sourceOffsets,
ArrayRef<int64_t> sourceSizes,
ArrayRef<int64_t> sourceStrides,
mlir::Value targetBdId, mlir::Value sourceBdId) {
void NpuDmaCpyNdOp::build(
OpBuilder &b, OperationState &result, Value dma, Value target,
ArrayRef<int64_t> targetOffsets, ArrayRef<int64_t> targetSizes,
ArrayRef<int64_t> targetStrides, mlir::Value targetBdId, Value source,
ArrayRef<int64_t> sourceOffsets, ArrayRef<int64_t> sourceSizes,
ArrayRef<int64_t> sourceStrides, mlir::Value sourceBdId) {
SmallVector<OpFoldResult> targetOffsetValues = llvm::to_vector<4>(
llvm::map_range(targetOffsets, [&](int64_t v) -> OpFoldResult {
return b.getI64IntegerAttr(v);
Expand All @@ -592,17 +589,18 @@ void NpuDmaCpyNdOp::build(OpBuilder &b, OperationState &result, Value dma,
llvm::map_range(sourceStrides, [&](int64_t v) -> OpFoldResult {
return b.getI64IntegerAttr(v);
}));
build(b, result, dma, targetOffsetValues, targetSizeValues,
targetStrideValues, sourceOffsetValues, sourceSizeValues,
sourceStrideValues, targetBdId, sourceBdId);
build(b, result, dma, target, targetOffsetValues, targetSizeValues,
targetStrideValues, targetBdId, source, sourceOffsetValues,
sourceSizeValues, sourceStrideValues, sourceBdId);
}

// Build a NpuDmaCpyNdOp with dynamic entries.
void NpuDmaCpyNdOp::build(OpBuilder &b, OperationState &result, Value dma,
ValueRange targetOffsets, ValueRange targetSizes,
ValueRange targetStrides, ValueRange sourceOffsets,
ValueRange sourceSizes, ValueRange sourceStrides,
mlir::Value targetBdId, mlir::Value sourceBdId) {
Value target, ValueRange targetOffsets,
ValueRange targetSizes, ValueRange targetStrides,
mlir::Value targetBdId, Value source,
ValueRange sourceOffsets, ValueRange sourceSizes,
ValueRange sourceStrides, mlir::Value sourceBdId) {
SmallVector<OpFoldResult> targetOffsetValues =
llvm::to_vector<4>(llvm::map_range(
targetOffsets, [](Value v) -> OpFoldResult { return v; }));
Expand All @@ -619,9 +617,212 @@ void NpuDmaCpyNdOp::build(OpBuilder &b, OperationState &result, Value dma,
SmallVector<OpFoldResult> sourceStrideValues =
llvm::to_vector<4>(llvm::map_range(
sourceStrides, [](Value v) -> OpFoldResult { return v; }));
build(b, result, dma, targetOffsetValues, targetSizeValues,
targetStrideValues, sourceOffsetValues, sourceSizeValues,
sourceStrideValues, targetBdId, sourceBdId);
build(b, result, dma, target, targetOffsetValues, targetSizeValues,
targetStrideValues, targetBdId, source, sourceOffsetValues,
sourceSizeValues, sourceStrideValues, sourceBdId);
}

/// Prints the custom assembly form of the op:
///
///   dma `(` [target] offsets sizes strides [`bd_id = ` id] `, `
///           [source] offsets sizes strides [`bd_id = ` id] `)`
///   [`attributes` attr-dict]
///   [` :` [` target_type = ` type] [` source_type = ` type]]
///
/// The target/source operands, both BD IDs and the trailing type annotations
/// are only emitted when the corresponding optional operand is present.
void NpuDmaCpyNdOp::print(OpAsmPrinter &p) {
  Operation *op = getOperation();

  // Emits one side (target or source) of the copy: the optional logical
  // objectFifo operand, the three mixed static/dynamic index lists separated by
  // single spaces, and the optional BD ID.
  auto printSide = [&](Value logicalObjFifo, OperandRange offsets,
                       ArrayRef<int64_t> staticOffsets, OperandRange sizes,
                       ArrayRef<int64_t> staticSizes, OperandRange strides,
                       ArrayRef<int64_t> staticStrides, Value bdId) {
    if (logicalObjFifo) p << logicalObjFifo;
    printDynamicIndexList(p, op, offsets, staticOffsets);
    p << " ";
    printDynamicIndexList(p, op, sizes, staticSizes);
    p << " ";
    printDynamicIndexList(p, op, strides, staticStrides);
    if (bdId) p << " bd_id = " << bdId;
  };

  p << " " << getDma() << "(";
  printSide(getTarget(), getTargetOffsets(), getTargetStaticOffsets(),
            getTargetSizes(), getTargetStaticSizes(), getTargetStrides(),
            getTargetStaticStrides(), getTargetBdId());
  p << ", ";
  printSide(getSource(), getSourceOffsets(), getSourceStaticOffsets(),
            getSourceSizes(), getSourceStaticSizes(), getSourceStrides(),
            getSourceStaticStrides(), getSourceBdId());
  p << ")";

  // The segment sizes and the static index arrays are already encoded in the
  // operand lists printed above, so keep them out of the attribute dictionary.
  SmallVector<StringRef, 7> elidedAttrs = {
      "operandSegmentSizes",   "target_static_offsets",
      "target_static_sizes",   "target_static_strides",
      "source_static_offsets", "source_static_sizes",
      "source_static_strides"};
  p.printOptionalAttrDictWithKeyword(op->getAttrs(), elidedAttrs);

  // Trailing type annotations, one per optional operand that is present.
  const bool hasTarget = static_cast<bool>(getTarget());
  const bool hasSource = static_cast<bool>(getSource());
  if (hasTarget || hasSource) p << " :";
  if (hasTarget) p << " target_type = " << getTarget().getType();
  if (hasSource) p << " source_type = " << getSource().getType();
}

/// Parses the custom assembly form of the op (the inverse of
/// `NpuDmaCpyNdOp::print`):
///
///   dma `(` [target] offsets sizes strides [`bd_id` `=` id] `,`
///           [source] offsets sizes strides [`bd_id` `=` id] `)`
///   [attr-dict] [`:` [`target_type` `=` type] [`source_type` `=` type]]
///
/// The static parts of the index lists are stored in the op properties, the
/// dynamic parts become `index` operands, and `operandSegmentSizes` is filled
/// in from the number of operands parsed for each group. The op gets a single
/// `index` result.
///
/// The attribute dictionary is accepted both with the `attributes` keyword (the
/// form emitted by `print`) and as a bare `{...}` dictionary.
///
/// Returns failure if any mandatory piece of syntax is missing or malformed,
/// or if the number of parsed target/source operands does not match the number
/// of parsed `target_type`/`source_type` annotations.
ParseResult NpuDmaCpyNdOp::parse(OpAsmParser &parser, OperationState &result) {
  OpBuilder b(parser.getContext());
  auto indexType = b.getIndexType();

  SMLoc targetOperandsLoc, sourceOperandsLoc;
  OpAsmParser::UnresolvedOperand dma;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> targetOperands, sourceOperands,
      targetBdIdOperands, sourceBdIdOperands;
  DenseI64ArrayAttr targetStaticOffsets, targetStaticSizes, targetStaticStrides;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> targetDynamicOffsets,
      targetDynamicSizes, targetDynamicStrides;
  DenseI64ArrayAttr sourceStaticOffsets, sourceStaticSizes, sourceStaticStrides;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> sourceDynamicOffsets,
      sourceDynamicSizes, sourceDynamicStrides;
  SmallVector<Type, 1> targetTypes;
  SmallVector<Type, 1> sourceTypes;

  // Parses an optional SSA operand. `loc` is set to where the operand is
  // expected so that later operand/type count mismatches are reported at a
  // valid source location. A present-but-malformed operand is an error.
  auto parseOptionalOperandAt =
      [&](SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
          SMLoc &loc) -> ParseResult {
    loc = parser.getCurrentLocation();
    OpAsmParser::UnresolvedOperand operand;
    OptionalParseResult parsed = parser.parseOptionalOperand(operand);
    if (!parsed.has_value()) return success();
    if (failed(*parsed)) return failure();
    operands.push_back(operand);
    return success();
  };

  // Parses an optional `bd_id = <operand>` clause.
  auto parseOptionalBdId =
      [&](SmallVectorImpl<OpAsmParser::UnresolvedOperand> &bdIds)
      -> ParseResult {
    if (failed(parser.parseOptionalKeyword("bd_id"))) return success();
    OpAsmParser::UnresolvedOperand bdId;
    if (failed(parser.parseEqual()) || failed(parser.parseOperand(bdId)))
      return failure();
    bdIds.push_back(bdId);
    return success();
  };

  // Parses an optional `<keyword> = <type>` clause.
  auto parseOptionalTypeClause = [&](StringRef keyword,
                                     SmallVectorImpl<Type> &types)
      -> ParseResult {
    if (failed(parser.parseOptionalKeyword(keyword))) return success();
    Type type;
    if (failed(parser.parseEqual()) || failed(parser.parseType(type)))
      return failure();
    types.push_back(type);
    return success();
  };

  if (failed(parser.parseOperand(dma)) || failed(parser.parseLParen()))
    return failure();

  // Target side: [operand] offsets sizes strides [bd_id = id].
  if (failed(parseOptionalOperandAt(targetOperands, targetOperandsLoc)))
    return failure();
  if (failed(parseDynamicIndexList(parser, targetDynamicOffsets,
                                   targetStaticOffsets))) {
    return failure();
  }
  result.getOrAddProperties<NpuDmaCpyNdOp::Properties>().target_static_offsets =
      targetStaticOffsets;
  if (failed(parseDynamicIndexList(parser, targetDynamicSizes,
                                   targetStaticSizes))) {
    return failure();
  }
  result.getOrAddProperties<NpuDmaCpyNdOp::Properties>().target_static_sizes =
      targetStaticSizes;
  if (failed(parseDynamicIndexList(parser, targetDynamicStrides,
                                   targetStaticStrides))) {
    return failure();
  }
  result.getOrAddProperties<NpuDmaCpyNdOp::Properties>().target_static_strides =
      targetStaticStrides;
  if (failed(parseOptionalBdId(targetBdIdOperands))) return failure();

  if (failed(parser.parseComma())) return failure();

  // Source side: [operand] offsets sizes strides [bd_id = id].
  if (failed(parseOptionalOperandAt(sourceOperands, sourceOperandsLoc)))
    return failure();
  if (failed(parseDynamicIndexList(parser, sourceDynamicOffsets,
                                   sourceStaticOffsets))) {
    return failure();
  }
  result.getOrAddProperties<NpuDmaCpyNdOp::Properties>().source_static_offsets =
      sourceStaticOffsets;
  if (failed(parseDynamicIndexList(parser, sourceDynamicSizes,
                                   sourceStaticSizes))) {
    return failure();
  }
  result.getOrAddProperties<NpuDmaCpyNdOp::Properties>().source_static_sizes =
      sourceStaticSizes;
  if (failed(parseDynamicIndexList(parser, sourceDynamicStrides,
                                   sourceStaticStrides))) {
    return failure();
  }
  result.getOrAddProperties<NpuDmaCpyNdOp::Properties>().source_static_strides =
      sourceStaticStrides;
  if (failed(parseOptionalBdId(sourceBdIdOperands))) return failure();

  if (failed(parser.parseRParen())) return failure();
  {
    auto loc = parser.getCurrentLocation();
    // The printer uses the `attributes {...}` spelling; a bare `{...}`
    // dictionary is still accepted so existing textual IR keeps parsing.
    if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)) ||
        failed(parser.parseOptionalAttrDict(result.attributes))) {
      return failure();
    }
    if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
          return parser.emitError(loc)
                 << "'" << result.name.getStringRef() << "' op ";
        }))) {
      return failure();
    }
  }

  // Optional trailing `: target_type = T source_type = S`.
  if (succeeded(parser.parseOptionalColon())) {
    if (failed(parseOptionalTypeClause("target_type", targetTypes)) ||
        failed(parseOptionalTypeClause("source_type", sourceTypes))) {
      return failure();
    }
  }

  // Segment sizes follow the operand order used when resolving below: dma,
  // then target {operand, offsets, sizes, strides, bd id}, then the same five
  // groups for the source.
  llvm::copy(
      ArrayRef<int32_t>({1, static_cast<int32_t>(targetOperands.size()),
                         static_cast<int32_t>(targetDynamicOffsets.size()),
                         static_cast<int32_t>(targetDynamicSizes.size()),
                         static_cast<int32_t>(targetDynamicStrides.size()),
                         static_cast<int32_t>(targetBdIdOperands.size()),
                         static_cast<int32_t>(sourceOperands.size()),
                         static_cast<int32_t>(sourceDynamicOffsets.size()),
                         static_cast<int32_t>(sourceDynamicSizes.size()),
                         static_cast<int32_t>(sourceDynamicStrides.size()),
                         static_cast<int32_t>(sourceBdIdOperands.size())}),
      result.getOrAddProperties<NpuDmaCpyNdOp::Properties>()
          .operandSegmentSizes.begin());

  if (failed(parser.resolveOperand(dma, indexType, result.operands)))
    return failure();
  if (failed(parser.resolveOperands(targetOperands, targetTypes,
                                    targetOperandsLoc, result.operands))) {
    return failure();
  }
  if (failed(parser.resolveOperands(targetDynamicOffsets, indexType,
                                    result.operands))) {
    return failure();
  }
  if (failed(parser.resolveOperands(targetDynamicSizes, indexType,
                                    result.operands))) {
    return failure();
  }
  if (failed(parser.resolveOperands(targetDynamicStrides, indexType,
                                    result.operands))) {
    return failure();
  }
  if (failed(parser.resolveOperands(targetBdIdOperands, indexType,
                                    result.operands))) {
    return failure();
  }
  if (failed(parser.resolveOperands(sourceOperands, sourceTypes,
                                    sourceOperandsLoc, result.operands))) {
    return failure();
  }
  if (failed(parser.resolveOperands(sourceDynamicOffsets, indexType,
                                    result.operands))) {
    return failure();
  }
  if (failed(parser.resolveOperands(sourceDynamicSizes, indexType,
                                    result.operands))) {
    return failure();
  }
  if (failed(parser.resolveOperands(sourceDynamicStrides, indexType,
                                    result.operands))) {
    return failure();
  }
  if (failed(parser.resolveOperands(sourceBdIdOperands, indexType,
                                    result.operands))) {
    return failure();
  }

  result.addTypes(indexType);
  return success();
}

DoublyStridedOpInterface NpuDmaCpyNdOp::createDoublyStridedOp(
Expand All @@ -634,14 +835,15 @@ DoublyStridedOpInterface NpuDmaCpyNdOp::createDoublyStridedOp(
::llvm::SmallVector<OpFoldResult> &newSourceStrides) {
Location loc = (*this)->getLoc();
auto newOp = rewriter.create<AMDAIE::NpuDmaCpyNdOp>(
loc, getDma(),
loc, getDma(), getTarget(),
getValueOrCreateConstantIndexOp(rewriter, loc, newTargetOffsets),
getValueOrCreateConstantIndexOp(rewriter, loc, newTargetSizes),
getValueOrCreateConstantIndexOp(rewriter, loc, newTargetStrides),
getTargetBdId(), getSource(),
getValueOrCreateConstantIndexOp(rewriter, loc, newSourceOffsets),
getValueOrCreateConstantIndexOp(rewriter, loc, newSourceSizes),
getValueOrCreateConstantIndexOp(rewriter, loc, newSourceStrides),
getTargetBdId(), getSourceBdId());
getSourceBdId());
return cast<DoublyStridedOpInterface>(newOp.getOperation());
}

Expand All @@ -660,8 +862,9 @@ struct NpuDmaCpyNdOpReplacementBuilder {
ArrayRef<OpFoldResult> srcMixedSizes,
ArrayRef<OpFoldResult> srcMixedStrides) {
rewriter.replaceOpWithNewOp<NpuDmaCpyNdOp>(
dmaOp, dmaOp.getDma(), tgtMixedOffsets, tgtMixedSizes, tgtMixedStrides,
srcMixedOffsets, srcMixedSizes, srcMixedStrides, dmaOp.getTargetBdId(),
dmaOp, dmaOp.getDma(), dmaOp.getTarget(), tgtMixedOffsets,
tgtMixedSizes, tgtMixedStrides, dmaOp.getTargetBdId(),
dmaOp.getSource(), srcMixedOffsets, srcMixedSizes, srcMixedStrides,
dmaOp.getSourceBdId());
}
};
Expand Down
Loading

0 comments on commit f8f31a8

Please sign in to comment.