Skip to content

Commit

Permalink
Towards vectorized convolution (second PR) (#866)
Browse files Browse the repository at this point in the history
This is just a copy of MLIR op definitions from mlir-aie. These ops are
needed to lower `transfer_read` with insufficient alignment correctly.
In future PR(s), the ops added in this PR will be used to lower
`transfer_read`.
  • Loading branch information
newling authored Oct 31, 2024
1 parent 38118dc commit c6156e3
Show file tree
Hide file tree
Showing 4 changed files with 579 additions and 7 deletions.
172 changes: 172 additions & 0 deletions compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,178 @@ ParseResult FMAElemOp::parse(OpAsmParser &parser, OperationState &result) {
return parseMulFMAElemOp(parser, result, true);
}


//===----------------------------------------------------------------------===//
// ExtOp
//===----------------------------------------------------------------------===//

// Custom assembly printer for aievec.ext. The emitted form is:
//   <source> <attr-dict> : <source type>, <result type>
void ExtOp::print(OpAsmPrinter &p) {
  auto source = getSource();
  auto resultType = getResult().getType();

  // Operand first, followed by whatever attributes the op carries.
  p << " " << source;
  p.printOptionalAttrDict((*this)->getAttrs());

  // Trailing type list: the source type, then the result type.
  p << " : " << source.getType();
  p << ", " << resultType;
}

// Verify Ext op.
//
// Checks, in order:
//   1. source and result are both vector types;
//   2. the source has at least twice as many lanes as the result;
//   3. the source lane count is an exact multiple of the result lane count;
//   4. `index` selects one of the (sourceLanes / resultLanes) sub-vectors;
//   5. source and result have the same element type.
// Returns success() only when every check passes; otherwise emits an error
// describing the first violated condition.
LogicalResult ExtOp::verify() {
  // Verify the types
  VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
  VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
  if (!sourceType || !resultType)
    return emitError("requires vector type");

  // Check the number of lanes
  unsigned sourceLanes = getVectorLaneSize(sourceType);
  unsigned resultLanes = getVectorLaneSize(resultType);
  // Source must have at least twice as many lanes as the result, i.e. the
  // (integer) ratio must be 2 or more.
  if (sourceLanes / resultLanes <= 1)
    return emitError("lanes in source vector must be at least "
                     "twice that of result vector");
  // Source lanes must be a multiple of result lanes. The diagnostic names the
  // operands in that order (the previous wording had them swapped).
  if (sourceLanes % resultLanes != 0)
    return emitError("lanes in source vector must be a multiple "
                     "of result vector lanes");

  // Verify validity of index: it must address one of the `factor` equally
  // sized pieces of the source vector.
  unsigned factor = sourceLanes / resultLanes;
  if (static_cast<unsigned>(getIndex()) >= factor)
    return emitError("index out of bounds");

  // The datatype of source and result must match
  Type stype = sourceType.getElementType();
  Type rtype = resultType.getElementType();
  if (stype != rtype)
    return emitError("source and result element type must be same");

  return success();
}

// Custom assembly parser for aievec.ext. The accepted form is:
//   <source> <attr-dict> : <source type>, <result type>
// On success the resolved source operand and the result type are appended to
// `result`.
ParseResult ExtOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand srcOperand;
  SmallVector<Type, 2> parsedTypes;
  llvm::SMLoc typeListLoc;

  // Operand, optional attribute dictionary, then the colon-prefixed type
  // list. The location just before the type list is remembered so later
  // diagnostics can point at it.
  if (parser.parseOperand(srcOperand) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typeListLoc) ||
      parser.parseColonTypeList(parsedTypes))
    return failure();

  // Exactly one attribute must have been supplied.
  if (result.attributes.getAttrs().size() != 1)
    return parser.emitError(typeListLoc, "requires one attribute");

  // Exactly two types are expected: the source, then the result.
  if (parsedTypes.size() != 2)
    return parser.emitError(typeListLoc, "requires two types");

  // Both of them have to be vector types.
  auto srcVecTy = llvm::dyn_cast<VectorType>(parsedTypes[0]);
  auto resVecTy = llvm::dyn_cast<VectorType>(parsedTypes[1]);
  if (!(srcVecTy && resVecTy))
    return parser.emitError(typeListLoc, "requires vector type");

  // Bind the source operand to its declared type.
  if (parser.resolveOperand(srcOperand, srcVecTy, result.operands))
    return failure();

  return parser.addTypeToList(resVecTy, result.types);
}

//===----------------------------------------------------------------------===//
// ShiftOp
//===----------------------------------------------------------------------===//

// Print out Shift op. The emitted form is:
//   <lhs>, <rhs>, <shift> <attr-dict>
//       : <lhs type>, <rhs type>, <shift type>, <result type>
void ShiftOp::print(OpAsmPrinter &p) {
  // Print the lhs and rhs vectors
  p << " " << getLhs() << ", " << getRhs();

  // Print shift
  p << ", " << getShift();

  // Print the attributes
  p.printOptionalAttrDict((*this)->getAttrs());

  // And now print the types. The second entry is the rhs type; it used to
  // repeat the lhs type, which only looked right because the verifier forces
  // both to be identical.
  p << " : " << getLhs().getType() << ", " << getRhs().getType() << ", "
    << getShift().getType() << ", " << getResult().getType();
}

// Verifier for aievec.shift: lhs, rhs and result must all be vectors of one
// and the same type, and the shift amount must have an integer type.
LogicalResult ShiftOp::verify() {
  auto resVecTy = llvm::dyn_cast<VectorType>(getResult().getType());
  auto lhsVecTy = llvm::dyn_cast<VectorType>(getLhs().getType());
  auto rhsVecTy = llvm::dyn_cast<VectorType>(getRhs().getType());

  // Every vector-valued operand/result really has to be a vector.
  if (!(resVecTy && lhsVecTy && rhsVecTy))
    return emitError("requires vector type");

  // ... and they must agree on the exact vector type.
  if (!(lhsVecTy == resVecTy && rhsVecTy == resVecTy))
    return emitError("All vectors must have same type");

  // The shift amount is the only non-vector operand.
  if (isa<IntegerType>(getShift().getType()))
    return success();

  return emitError("requires integer type");
}

// Custom assembly parser for aievec.shift. The accepted form is:
//   <lhs>, <rhs>, <shift> <attr-dict>
//       : <lhs type>, <rhs type>, <shift type>, <result type>
// On success the three resolved operands and the result type are appended to
// `result`.
ParseResult ShiftOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand lhsOperand, rhsOperand, shiftOperand;
  SmallVector<Type, 4> parsedTypes;
  llvm::SMLoc typeListLoc;

  // The three comma-separated operands.
  if (parser.parseOperand(lhsOperand) || parser.parseComma() ||
      parser.parseOperand(rhsOperand) || parser.parseComma() ||
      parser.parseOperand(shiftOperand))
    return failure();

  // Optional attribute dictionary followed by the colon-prefixed type list.
  // The location just before the type list is kept for diagnostics.
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typeListLoc) ||
      parser.parseColonTypeList(parsedTypes))
    return failure();

  // Exactly one attribute must have been supplied.
  if (result.attributes.getAttrs().size() != 1)
    return parser.emitError(typeListLoc, "expects one attribute");

  // Exactly four types are expected: lhs, rhs, shift and result.
  if (parsedTypes.size() != 4)
    return parser.emitError(typeListLoc, "requires four types");

  // lhs, rhs and result must be vectors; the shift must be an integer.
  auto lhsVecTy = llvm::dyn_cast<VectorType>(parsedTypes[0]);
  auto rhsVecTy = llvm::dyn_cast<VectorType>(parsedTypes[1]);
  auto shiftIntTy = llvm::dyn_cast<IntegerType>(parsedTypes[2]);
  auto resVecTy = llvm::dyn_cast<VectorType>(parsedTypes[3]);
  if (!(lhsVecTy && rhsVecTy && resVecTy))
    return parser.emitError(typeListLoc, "requires vector type");

  if (!shiftIntTy)
    return parser.emitError(typeListLoc, "requires integer type");

  // Bind each operand to its declared type, in operand order.
  if (parser.resolveOperand(lhsOperand, lhsVecTy, result.operands) ||
      parser.resolveOperand(rhsOperand, rhsVecTy, result.operands) ||
      parser.resolveOperand(shiftOperand, shiftIntTy, result.operands))
    return failure();

  return parser.addTypeToList(resVecTy, result.types);
}

#define GET_ATTRDEF_CLASSES
#include "aievec/AIEVecAttributes.cpp.inc"

Expand Down
39 changes: 39 additions & 0 deletions compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,45 @@ class AIEVec_Op<string mnemonic, list<Trait> traits = []> :
let hasVerifier = 1;
}


// aievec.shift: concatenate-and-shift of two vectors.
//
// Operands:   $lhs, $rhs  - any vector type
//             $shift      - i32 SSA value (the shift amount)
// Attributes: $isAcc      - bool attribute, defaults to "false"
// Result:     $result     - any vector type
//
// The op is marked Pure. The ODS constraints here only require "any vector";
// the stricter rule stated in the description (all vectors share one type) is
// enforced by the C++ verifier, ShiftOp::verify.
def AIEVec_ShiftOp:
AIEVec_Op<"shift", [
Pure
]>,
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, I32:$shift, DefaultValuedAttr<BoolAttr, "false">:$isAcc)>,
Results<(outs AnyVector:$result)> {
let summary = "AIE2 concat and shift";
let description = [{
AMD-specific shift intrinsic. Concatenates two
vectors into a bigger vector, interprets them as a vector of 128 bytes
and returns v1::v2[shift: shift+64]. `shift` is the number of bytes to
be shifted. The verifier confirms that all the input and result vectors
have the same number of lanes and element types.
`$result = shift($lhs, $rhs, $shift)`
}];
}

// aievec.ext: extract a contiguous sub-vector from a larger vector.
//
// Operands:   $source - any vector type
// Attributes: $index  - 8-bit integer attribute, constrained here to [0, 8]
// Result:     $result - any vector type
//
// The op is marked Pure.
// NOTE(review): the ODS bound on $index (0..8) is looser than the 0..3 range
// the description talks about; the effective bound is applied by the C++
// verifier (ExtOp::verify), which requires
// index < (source lanes / result lanes). Confirm whether 8 is intentional.
def AIEVec_ExtOp:
AIEVec_Op<"ext", [
Pure
]>,
Arguments<(ins AnyVector:$source,
// Previous form of the constraint, kept for reference (used AIEI8Attr):
// ConfinedAttr<AIEI8Attr, [IntMinValue<0>, IntMaxValue<8>]>:$index)>,
ConfinedAttr<I8Attr, [IntMinValue<0>, IntMaxValue<8>]>:$index)>,
Results<(outs AnyVector:$result)> {
let summary = "AIE ext";
let description = [{
AMD-specific vector extract intrinsic. Selects contiguous lanes from
the source vector, and transfers the data from those lanes to the
result. The lane selection is controlled by index. There are two cases:
1. Extracted vector fills half of the original vector lanes (e.g. extract v64int8 from v128int8)
2. Extracted vector fills a fourth of the original vector lanes (e.g. extract v32int8 from v128int8)
In the first case, index can be 0 or 1. Index 0 extracts the lower half, and index 1 extracts the upper half.
In the second case, index can be 0 to 3. Index 0 extracts the lowest quarter, index 1 the next quarter, and so on.
`$result = ext($source, $index)`
}];
}

def AIEVec_UPSOp:
AIEVec_Op<"ups", [
Pure
Expand Down
163 changes: 161 additions & 2 deletions compiler/plugins/target/AMD-AIE/aievec/AIEVecToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -797,11 +797,170 @@ class ShuffleOpConversion
}
};

// Lowers aievec.shift to an xllvm vector-shift intrinsic. Only results whose
// total size (element bit width * lane count) is 512 bits are handled; for any
// other size a warning is emitted and the pattern reports failure.
class ShiftOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::ShiftOp> {
public:
  using ConvertOpToLLVMPattern<aievec::ShiftOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(aievec::ShiftOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Total bit size of the result vector.
    VectorType resVecTy = cast<VectorType>(op.getResult().getType());
    Type elemTy = resVecTy.getElementType();
    unsigned elemBits = elemTy.getIntOrFloatBitWidth();
    int numLanes = getVectorLaneSize(resVecTy);
    int totalBits = elemBits * numLanes;

    if (totalBits != 512) {
      op.emitWarning() << "aievec.shift conversion with result vector size "
                       << totalBits << " is not implemented.\n";
      return failure();
    }

    auto i32Ty = rewriter.getI32Type();

    // The step operand of the intrinsic is assumed to always be zero.
    auto zeroStep = rewriter.create<LLVM::ConstantOp>(
        loc, i32Ty, rewriter.getI32IntegerAttr(0));

    // Intrinsic operands, in signature order: lhs, rhs, step, shift.
    SmallVector<Value> shiftArgs(
        {adaptor.getLhs(), adaptor.getRhs(), zeroStep, adaptor.getShift()});

    Value intrinsic = nullptr;
    if (llvm::isa<IntegerType>(elemTy)) {
      // Integer element types go through the 16 x i32 flavour.
      auto v16i32Ty = VectorType::get({16}, i32Ty);
      intrinsic = rewriter.create<xllvm::VectorShiftI512I512IntrOp>(
          loc, v16i32Ty,
          forceCastOperandsToSignature(rewriter, loc, shiftArgs,
                                       {v16i32Ty, v16i32Ty, i32Ty, i32Ty}));
    } else {
      // Everything else goes through the 32 x bf16 flavour.
      auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type());
      intrinsic = rewriter.create<xllvm::VectorShiftBF512BF512IntrOp>(
          loc, v32bf16Ty,
          forceCastOperandsToSignature(rewriter, loc, shiftArgs,
                                       {v32bf16Ty, v32bf16Ty, i32Ty, i32Ty}));
    }

    // Bitcast the intrinsic's value back to the op's declared result type.
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
                                                 intrinsic);

    return success();
  }
};

// Lowers aievec.ext to xllvm extract intrinsics.
//
// The intrinsic is chosen purely from the total bit sizes (element bit width *
// lane count) of the source and result vectors. Supported combinations:
//   source 512  -> result 256   (ExtI256I512IntrOp)
//   source 1024 -> result 512   (ExtI512I1024IntrOp)
//   source 1024 -> result 256   (ExtI256I1024IntrOp)
//   source 512  -> result 128   (optional VectorShiftI512I512IntrOp, then
//                                ExtI128I512IntrOp)
// Any other combination emits a warning and the pattern reports failure.
class ExtOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::ExtOp> {
public:
using ConvertOpToLLVMPattern<aievec::ExtOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(aievec::ExtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();

// Total bit size of the (already converted) source operand.
Value src = adaptor.getSource();
VectorType srcType = cast<VectorType>(src.getType());
Type srcScalarType = srcType.getElementType();
unsigned srcBitWidth = srcScalarType.getIntOrFloatBitWidth();
int srcLanes = getVectorLaneSize(srcType);
int srcVectorSize = srcBitWidth * srcLanes;

// Total bit size of the op's result.
Value result = op.getResult();
VectorType resultType = cast<VectorType>(result.getType());
Type resultScaTy = resultType.getElementType();
unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
int resultLanes = getVectorLaneSize(resultType);
int resultVectorSize = resultBitWidth * resultLanes;

// create constant for index
// NOTE: this constant is created before the size dispatch below, so it also
// exists on the 512 -> 128 path (which does not pass it to any intrinsic) and
// on the unsupported path.
auto indexCst = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(op.getIndex()));

// create xllvm intrinsic
SmallVector<Value> operands({adaptor.getSource(), indexCst});
Value extOp = nullptr;
// Dispatch on (result size, source size) in bits. The element type is not
// inspected; operands are passed through forceCastOperandsToSignature with an
// i32-vector signature in every branch.
if (resultVectorSize == 256 && srcVectorSize == 512) {
extOp = rewriter.create<xllvm::ExtI256I512IntrOp>(
loc, VectorType::get({8}, rewriter.getI32Type()),
forceCastOperandsToSignature(
rewriter, loc, operands,
{VectorType::get({16}, rewriter.getI32Type()),
rewriter.getI32Type()}));
} else if (resultVectorSize == 512 && srcVectorSize == 1024) {
extOp = rewriter.create<xllvm::ExtI512I1024IntrOp>(
loc, VectorType::get({16}, rewriter.getI32Type()),
forceCastOperandsToSignature(
rewriter, loc, operands,
{VectorType::get({32}, rewriter.getI32Type()),
rewriter.getI32Type()}));
} else if (resultVectorSize == 256 && srcVectorSize == 1024) {
extOp = rewriter.create<xllvm::ExtI256I1024IntrOp>(
loc, VectorType::get({8}, rewriter.getI32Type()),
forceCastOperandsToSignature(
rewriter, loc, operands,
{VectorType::get({32}, rewriter.getI32Type()),
rewriter.getI32Type()}));
} else if (resultVectorSize == 128 && srcVectorSize == 512) {
// For index 0 the source is fed to the extract intrinsic directly; for a
// non-zero index it is first shifted so the requested 128 bits end up in the
// lowest position.
auto shiftOp = adaptor.getSource();
if (op.getIndex() > 0) {
// Second operand of the shift is an undef 16 x i32 vector.
auto undefOp = rewriter.create<xllvm::UndefV16I32IntrOp>(
loc, VectorType::get({16}, rewriter.getI32Type()));
auto stepCst = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
auto shiftCst = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(),
rewriter.getI32IntegerAttr(op.getIndex() * 16));
SmallVector<Value> shiftOperands{adaptor.getSource(), undefOp, stepCst,
shiftCst};
// Right shift the source vector in index * 16 bytes (i.e. in index *
// 128 bits). The integer index is expected to be 0 to 3.
shiftOp = rewriter.create<xllvm::VectorShiftI512I512IntrOp>(
loc, VectorType::get({16}, rewriter.getI32Type()),
forceCastOperandsToSignature(
rewriter, loc, shiftOperands,
{VectorType::get({16}, rewriter.getI32Type()),
VectorType::get({16}, rewriter.getI32Type()),
rewriter.getI32Type(), rewriter.getI32Type()}));
}
// The underlying intrinsic takes a source vector and extract the lowest
// 128-bit. i.e. it always extracts the input vector with index = 0.
extOp = rewriter.create<xllvm::ExtI128I512IntrOp>(
loc, VectorType::get({4}, rewriter.getI32Type()),
forceCastOperandsToSignature(
rewriter, loc, /*operands=*/{shiftOp},
{VectorType::get({16}, rewriter.getI32Type())}));
} else {
op.emitWarning() << "aievec.ext with " << srcVectorSize
<< "-bit source, and " << resultVectorSize
<< "-bit result is not supported.\n";
return failure();
}

// create bitcast for result
// A bitcast is only inserted when the intrinsic's type differs from the op's
// declared result type; otherwise the intrinsic value replaces the op as-is.
if (op.getResult().getType() != extOp.getType()) {
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
extOp);
} else {
rewriter.replaceOp(op, extOp);
}

return success();
}
};


// Adds the AIEVec -> LLVM conversion patterns listed below to `patterns`,
// each constructed with `converter`. The list includes the Ext and Shift
// lowerings. (The diff view had left the superseded, shorter pattern list in
// place next to the new one; only the new list is kept.)
void populateAIEVecToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter,
                                            mlir::RewritePatternSet &patterns) {
  patterns.add<UPSOpConversion, SRSOpConversion, FoldAIECastOps,
               FMAElemOpConversion, MatMulOpConversion, ShuffleOpConversion,
               ExtOpConversion, ShiftOpConversion>(converter);
}

struct ConvertAIEVecToLLVMPass
Expand Down
Loading

0 comments on commit c6156e3

Please sign in to comment.