Skip to content

Commit

Permalink
[aievec][nfc] Clean-up aievec to llvm conversion
Browse files Browse the repository at this point in the history
This code needed updating its use of a couple of constructs, and
namespaces.
  • Loading branch information
jsetoain committed Sep 20, 2023
1 parent dcb5083 commit 01c79ae
Showing 1 changed file with 80 additions and 92 deletions.
172 changes: 80 additions & 92 deletions lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
#include <sstream>

using namespace mlir;
using namespace xilinx;
using namespace xilinx::aievec;

namespace xilinx {
namespace aievec {
Expand All @@ -40,9 +38,9 @@ std::string getVectorTypeString(VectorType type, bool abbrev = false,
std::stringstream ss;
auto size = getVectorLaneSize(type);
ss << "v" << size;
if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
if (auto intType = dyn_cast<IntegerType>(type.getElementType())) {
ss << (acc ? "acc" : abbrev ? "i" : "int") << intType.getWidth();
} else if (auto floatType = type.getElementType().dyn_cast<FloatType>()) {
} else if (auto floatType = dyn_cast<FloatType>(type.getElementType())) {
ss << (abbrev ? "f" : "float");
}
return ss.str();
Expand All @@ -51,26 +49,26 @@ std::string getVectorTypeString(VectorType type, bool abbrev = false,
std::string getMulOrFMAIntrinsicName(Operation *op) {
std::string baseName;
Value lhs, rhs, result;
if (auto mulOp = dyn_cast<xilinx::aievec::MulOp>(op)) {
if (auto mulOp = dyn_cast<aievec::MulOp>(op)) {
baseName = "mul";
lhs = mulOp.getLhs();
rhs = mulOp.getRhs();
result = mulOp.getResult();
} else if (auto fmaOp = dyn_cast<xilinx::aievec::FMAOp>(op)) {
} else if (auto fmaOp = dyn_cast<aievec::FMAOp>(op)) {
baseName = "mac";
lhs = fmaOp.getLhs();
rhs = fmaOp.getRhs();
result = fmaOp.getResult();
}
VectorType resultType = result.getType().cast<VectorType>();
VectorType resultType = cast<VectorType>(result.getType());
int resultSize = getVectorLaneSize(resultType);
std::stringstream ss;
ss << "llvm.aie.";
if (auto intType = resultType.getElementType().dyn_cast<IntegerType>()) {
if (auto intType = dyn_cast<IntegerType>(resultType.getElementType())) {
ss << baseName;
ss << resultSize << "."
<< getVectorTypeString(lhs.getType().cast<VectorType>());
} else if (resultType.getElementType().dyn_cast<FloatType>()) {
<< getVectorTypeString(cast<VectorType>(lhs.getType()));
} else if (dyn_cast<FloatType>(resultType.getElementType())) {
ss << "vfp" << baseName;
}
return ss.str();
Expand All @@ -96,39 +94,36 @@ void encodeConf(uint32_t conf[2], const BufferParams &x, const BufferParams &z,
conf[1] |= sub << 17;
}

class AddOpConversion
: public mlir::ConvertOpToLLVMPattern<xilinx::aievec::AddOp> {
class AddOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::AddOp> {
public:
using ConvertOpToLLVMPattern<xilinx::aievec::AddOp>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<aievec::AddOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(xilinx::aievec::AddOp op, OpAdaptor adaptor,
matchAndRewrite(aievec::AddOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
op.emitWarning() << "aie.add conversion is not implemented\n";
return failure();
}
};

class SubOpConversion
: public mlir::ConvertOpToLLVMPattern<xilinx::aievec::SubOp> {
class SubOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::SubOp> {
public:
using ConvertOpToLLVMPattern<xilinx::aievec::SubOp>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<aievec::SubOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(xilinx::aievec::SubOp op, OpAdaptor adaptor,
matchAndRewrite(aievec::SubOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
op.emitWarning() << "aie.sub conversion is not implemented\n";
return failure();
}
};

class FMAOpConversion
: public mlir::ConvertOpToLLVMPattern<xilinx::aievec::FMAOp> {
class FMAOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::FMAOp> {
public:
using ConvertOpToLLVMPattern<xilinx::aievec::FMAOp>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<aievec::FMAOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(xilinx::aievec::FMAOp op, OpAdaptor adaptor,
matchAndRewrite(aievec::FMAOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto module = op->getParentOfType<ModuleOp>();
MLIRContext *context = rewriter.getContext();
Expand Down Expand Up @@ -200,13 +195,12 @@ class FMAOpConversion
}
};

class MulOpConversion
: public mlir::ConvertOpToLLVMPattern<xilinx::aievec::MulOp> {
class MulOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::MulOp> {
public:
using ConvertOpToLLVMPattern<xilinx::aievec::MulOp>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<aievec::MulOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(xilinx::aievec::MulOp op, OpAdaptor adaptor,
matchAndRewrite(aievec::MulOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto module = op->getParentOfType<ModuleOp>();
MLIRContext *context = rewriter.getContext();
Expand Down Expand Up @@ -278,33 +272,31 @@ class MulOpConversion
}
};

class UPSOpConversion
: public mlir::ConvertOpToLLVMPattern<xilinx::aievec::UPSOp> {
class UPSOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::UPSOp> {
public:
using ConvertOpToLLVMPattern<xilinx::aievec::UPSOp>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<aievec::UPSOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(xilinx::aievec::UPSOp op, OpAdaptor adaptor,
matchAndRewrite(aievec::UPSOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
op.emitWarning() << "aie.ups conversion is not implemented\n";
return failure();
}
};

class SRSOpConversion
: public mlir::ConvertOpToLLVMPattern<xilinx::aievec::SRSOp> {
class SRSOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::SRSOp> {
public:
using ConvertOpToLLVMPattern<xilinx::aievec::SRSOp>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<aievec::SRSOp>::ConvertOpToLLVMPattern;

static std::string getIntrinsicName(xilinx::aievec::SRSOp op) {
static std::string getIntrinsicName(aievec::SRSOp op) {
std::stringstream ss;
ss << "llvm.aie.";

// Determine the prefix
auto sourceType = op.getSource().getType().cast<VectorType>();
auto resultType = op.getResult().getType().cast<VectorType>();
auto sourceElType = sourceType.getElementType().cast<IntegerType>();
auto resultElType = resultType.getElementType().cast<IntegerType>();
auto sourceType = cast<VectorType>(op.getSource().getType());
auto resultType = cast<VectorType>(op.getResult().getType());
auto sourceElType = cast<IntegerType>(sourceType.getElementType());
auto resultElType = cast<IntegerType>(resultType.getElementType());

auto sourceElWidth = sourceElType.getWidth();
auto resultElWidth = resultElType.getWidth();
Expand All @@ -322,7 +314,7 @@ class SRSOpConversion
}

LogicalResult
matchAndRewrite(xilinx::aievec::SRSOp op, OpAdaptor adaptor,
matchAndRewrite(aievec::SRSOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// If the intrinsic declaration doesn't exist, create it
std::string intrinsicName = getIntrinsicName(op);
Expand Down Expand Up @@ -350,13 +342,12 @@ class SRSOpConversion
}
};

class UPDOpConversion
: public mlir::ConvertOpToLLVMPattern<xilinx::aievec::UPDOp> {
class UPDOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::UPDOp> {
public:
using ConvertOpToLLVMPattern<xilinx::aievec::UPDOp>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<aievec::UPDOp>::ConvertOpToLLVMPattern;

static std::string getIntrinsicName(xilinx::aievec::UPDOp op, int loadSize) {
auto resultType = op.getResult().getType().cast<VectorType>();
static std::string getIntrinsicName(aievec::UPDOp op, int loadSize) {
auto resultType = cast<VectorType>(op.getResult().getType());
std::stringstream ss;
ss << "llvm.aie.upd.";
ss << (loadSize == 128 ? 'v' : loadSize == 256 ? 'w' : 'x') << ".";
Expand All @@ -367,7 +358,7 @@ class UPDOpConversion
}

LogicalResult
matchAndRewrite(xilinx::aievec::UPDOp op, OpAdaptor adaptor,
matchAndRewrite(aievec::UPDOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto module = op->getParentOfType<ModuleOp>();
MLIRContext *context = rewriter.getContext();
Expand All @@ -376,10 +367,10 @@ class UPDOpConversion
// AIE1 is capable of 128-bit on one bank and 256-bit loads on even-odd
// banks Identify size of update
int vecSizeInBits =
getVectorSizeInBits(op.getResult().getType().cast<VectorType>());
getVectorSizeInBits(cast<VectorType>(op.getResult().getType()));

auto ptr = this->getStridedElementPtr(
op->getLoc(), op.getSource().getType().cast<MemRefType>(),
op->getLoc(), cast<MemRefType>(op.getSource().getType()),
adaptor.getSource(), adaptor.getIndices(), rewriter);

// TODO: handle the offset field
Expand All @@ -389,8 +380,8 @@ class UPDOpConversion
// we can do a direct load into the vector register
// look at the indices to calculate the address
auto vectorPtrType = LLVM::LLVMPointerType::get(
op.getResult().getType().cast<VectorType>(),
op.getSource().getType().cast<MemRefType>().getMemorySpaceAsInt());
cast<VectorType>(op.getResult().getType()),
cast<MemRefType>(op.getSource().getType()).getMemorySpaceAsInt());
auto castedPtr =
rewriter.create<LLVM::BitcastOp>(op->getLoc(), vectorPtrType, ptr);
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedPtr, 1);
Expand All @@ -407,15 +398,15 @@ class UPDOpConversion

// Create a vectorType for the load proper
// Load half of the final result vector
auto resultType = op.getResult().getType().cast<VectorType>();
auto resultType = cast<VectorType>(op.getResult().getType());
int lanes = getVectorLaneSize(resultType);
auto loadType =
VectorType::get({(int64_t)lanes / 2}, resultType.getElementType());

// Load the vector
auto vectorPtrType = LLVM::LLVMPointerType::get(
loadType,
op.getSource().getType().cast<MemRefType>().getMemorySpaceAsInt());
cast<MemRefType>(op.getSource().getType()).getMemorySpaceAsInt());
auto castedPtr =
rewriter.create<LLVM::BitcastOp>(op->getLoc(), vectorPtrType, ptr);
auto loadValue =
Expand Down Expand Up @@ -478,21 +469,20 @@ class UPDOpConversion
};

class ConcatOpConversion
: public mlir::ConvertOpToLLVMPattern<xilinx::aievec::ConcatOp> {
: public mlir::ConvertOpToLLVMPattern<aievec::ConcatOp> {
public:
using ConvertOpToLLVMPattern<
xilinx::aievec::ConcatOp>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<aievec::ConcatOp>::ConvertOpToLLVMPattern;

static std::string getIntrinsicName(xilinx::aievec::ConcatOp op) {
auto sourceType = op.getSources()[0].getType().cast<VectorType>();
static std::string getIntrinsicName(aievec::ConcatOp op) {
auto sourceType = cast<VectorType>(op.getSources()[0].getType());
std::stringstream ss;
ss << "llvm.aie.concat.";
ss << getVectorTypeString(sourceType, true);
return ss.str();
}

LogicalResult
matchAndRewrite(xilinx::aievec::ConcatOp op, OpAdaptor adaptor,
matchAndRewrite(aievec::ConcatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto module = op->getParentOfType<ModuleOp>();
MLIRContext *context = rewriter.getContext();
Expand All @@ -519,14 +509,13 @@ class ConcatOpConversion
}
};

class ExtOpConversion
: public mlir::ConvertOpToLLVMPattern<xilinx::aievec::ExtOp> {
class ExtOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::ExtOp> {
public:
using ConvertOpToLLVMPattern<xilinx::aievec::ExtOp>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<aievec::ExtOp>::ConvertOpToLLVMPattern;

static std::string getIntrinsicName(xilinx::aievec::ExtOp op) {
auto sourceType = op.getSource().getType().cast<VectorType>();
auto resultType = op.getResult().getType().cast<VectorType>();
static std::string getIntrinsicName(aievec::ExtOp op) {
auto sourceType = cast<VectorType>(op.getSource().getType());
auto resultType = cast<VectorType>(op.getResult().getType());
int resultSize = getVectorSizeInBits(resultType);
std::stringstream ss;
ss << "llvm.aie.ext.";
Expand All @@ -538,7 +527,7 @@ class ExtOpConversion
}

LogicalResult
matchAndRewrite(xilinx::aievec::ExtOp op, OpAdaptor adaptor,
matchAndRewrite(aievec::ExtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto module = op->getParentOfType<ModuleOp>();
MLIRContext *context = rewriter.getContext();
Expand All @@ -564,20 +553,19 @@ class ExtOpConversion
};

class SelectOpConversion
: public mlir::ConvertOpToLLVMPattern<xilinx::aievec::SelectOp> {
: public mlir::ConvertOpToLLVMPattern<aievec::SelectOp> {
public:
using ConvertOpToLLVMPattern<
xilinx::aievec::SelectOp>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<aievec::SelectOp>::ConvertOpToLLVMPattern;

static std::string getIntrinsicName(xilinx::aievec::SelectOp op) {
auto xbuffType = op.getXbuff().getType().cast<VectorType>();
static std::string getIntrinsicName(aievec::SelectOp op) {
auto xbuffType = cast<VectorType>(op.getXbuff().getType());
std::stringstream ss;
ss << "llvm.aie.prim." << getVectorTypeString(xbuffType) << ".select";
return ss.str();
}

LogicalResult
matchAndRewrite(xilinx::aievec::SelectOp op, OpAdaptor adaptor,
matchAndRewrite(aievec::SelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto module = op->getParentOfType<ModuleOp>();
MLIRContext *context = rewriter.getContext();
Expand Down Expand Up @@ -651,20 +639,19 @@ class SelectOpConversion
}
};

class PackOpConversion
: public mlir::ConvertOpToLLVMPattern<xilinx::aievec::PackOp> {
class PackOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::PackOp> {
public:
using ConvertOpToLLVMPattern<xilinx::aievec::PackOp>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<aievec::PackOp>::ConvertOpToLLVMPattern;

static std::string getIntrinsicName(xilinx::aievec::PackOp op) {
auto sourceType = op.getSource().getType().cast<VectorType>();
static std::string getIntrinsicName(aievec::PackOp op) {
auto sourceType = cast<VectorType>(op.getSource().getType());
std::stringstream ss;
ss << "llvm.aie.pack." << getVectorTypeString(sourceType);
return ss.str();
}

LogicalResult
matchAndRewrite(xilinx::aievec::PackOp op, OpAdaptor adaptor,
matchAndRewrite(aievec::PackOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto module = op->getParentOfType<ModuleOp>();
MLIRContext *context = rewriter.getContext();
Expand All @@ -690,13 +677,12 @@ class PackOpConversion
};

class UnpackOpConversion
: public mlir::ConvertOpToLLVMPattern<xilinx::aievec::UnpackOp> {
: public mlir::ConvertOpToLLVMPattern<aievec::UnpackOp> {
public:
using ConvertOpToLLVMPattern<
xilinx::aievec::UnpackOp>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<aievec::UnpackOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(xilinx::aievec::UnpackOp op, OpAdaptor adaptor,
matchAndRewrite(aievec::UnpackOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
op.emitWarning() << "aie.unpack conversion is not implemented\n";
return failure();
Expand All @@ -705,18 +691,20 @@ class UnpackOpConversion

void populateAIEVecToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter,
mlir::RewritePatternSet &patterns) {
patterns.add<xilinx::aievec::AddOpConversion>(converter);
patterns.add<xilinx::aievec::SubOpConversion>(converter);
patterns.add<xilinx::aievec::FMAOpConversion>(converter);
patterns.add<xilinx::aievec::MulOpConversion>(converter);
patterns.add<xilinx::aievec::UPSOpConversion>(converter);
patterns.add<xilinx::aievec::SRSOpConversion>(converter);
patterns.add<xilinx::aievec::UPDOpConversion>(converter);
patterns.add<xilinx::aievec::ConcatOpConversion>(converter);
patterns.add<xilinx::aievec::ExtOpConversion>(converter);
patterns.add<xilinx::aievec::SelectOpConversion>(converter);
patterns.add<xilinx::aievec::PackOpConversion>(converter);
patterns.add<xilinx::aievec::UnpackOpConversion>(converter);
// clang-format off
patterns.add<AddOpConversion,
SubOpConversion,
FMAOpConversion,
MulOpConversion,
UPSOpConversion,
SRSOpConversion,
UPDOpConversion,
ConcatOpConversion,
ExtOpConversion,
SelectOpConversion,
PackOpConversion,
UnpackOpConversion>(converter);
// clang-format on
}

struct ConvertAIEVecToLLVMPass
Expand Down

0 comments on commit 01c79ae

Please sign in to comment.