From c6156e3b203401c25d348ae995fd2f1ffc138787 Mon Sep 17 00:00:00 2001
From: James Newling
Date: Thu, 31 Oct 2024 11:10:03 -0700
Subject: [PATCH] Towards vectorized convolution (second PR) (#866)

This is just a copy of MLIR op definitions from mlir-aie. These ops are
needed to lower `transfer_read` with insufficient alignment correctly.
In future PR(s), the ops added in this PR will be used to lower
`transfer_read`.
---
 .../target/AMD-AIE/aievec/AIEVecOps.cpp       | 172 ++++++++++++++
 .../target/AMD-AIE/aievec/AIEVecOps.td        |  39 ++++
 .../target/AMD-AIE/aievec/AIEVecToLLVM.cpp    | 163 +++++++++++++-
 .../target/AMD-AIE/aievec/XLLVMAIE2IntrOps.td | 212 +++++++++++++++++-
 4 files changed, 579 insertions(+), 7 deletions(-)

diff --git a/compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.cpp b/compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.cpp
index 1b03a9561..7d9343df2 100644
--- a/compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.cpp
+++ b/compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.cpp
@@ -605,6 +605,178 @@ ParseResult FMAElemOp::parse(OpAsmParser &parser, OperationState &result) {
   return parseMulFMAElemOp(parser, result, true);
 }

+
+//===----------------------------------------------------------------------===//
+// ExtOp
+//===----------------------------------------------------------------------===//
+
+// Print out Ext op.
+void ExtOp::print(OpAsmPrinter &p) {
+  // Print the source vector
+  p << " " << getSource();
+
+  // Print the attributes
+  p.printOptionalAttrDict((*this)->getAttrs());
+
+  // And now print the types
+  p << " : " << getSource().getType() << ", " << getResult().getType();
+}
+
+// Verify Ext op.
+LogicalResult ExtOp::verify() {
+  // Verify the types
+  VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
+  VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
+  if (!sourceType || !resultType)
+    return emitError("requires vector type");
+
+  // Check the number of lanes
+  unsigned sourceLanes = getVectorLaneSize(sourceType);
+  unsigned resultLanes = getVectorLaneSize(resultType);
+  // Source lanes must be greater than result lanes
+  if (sourceLanes / resultLanes <= 1)
+    return emitError("lanes in source vector must be at least "
+                     "twice that of result vector");
+  // Source lanes must be a multiple of result lanes
+  if (sourceLanes % resultLanes != 0)
+    return emitError("lanes in source vector must be a multiple "
+                     "of result vector lanes");
+
+  // Verify validity of index
+  unsigned factor = sourceLanes / resultLanes;
+  if (static_cast<unsigned>(getIndex()) >= factor)
+    return emitError("index out of bounds");
+
+  // The datatype of source and result must match
+  Type stype = sourceType.getElementType();
+  Type rtype = resultType.getElementType();
+  if (stype != rtype)
+    return emitError("source and result element type must be same");
+
+  return success();
+}
+
+// Parse Ext op.
+ParseResult ExtOp::parse(OpAsmParser &parser, OperationState &result) {
+  llvm::SMLoc typesLoc;
+  SmallVector<Type, 2> types;
+  OpAsmParser::UnresolvedOperand source;
+
+  // Parse the source vector
+  if (parser.parseOperand(source))
+    return failure();
+
+  // Parse all the attributes and types
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
+    return failure();
+
+  if (result.attributes.getAttrs().size() != 1)
+    return parser.emitError(typesLoc, "requires one attribute");
+
+  // Assert that there are two types (source and result)
+  if (types.size() != 2)
+    return parser.emitError(typesLoc, "requires two types");
+
+  // Some verification
+  VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
+  VectorType resultType = llvm::dyn_cast<VectorType>(types[1]);
+  if (!sourceType || !resultType)
+    return parser.emitError(typesLoc, "requires vector type");
+
+  // Populate the source in result
+  if (parser.resolveOperand(source, sourceType, result.operands))
+    return failure();
+
+  return parser.addTypeToList(resultType, result.types);
+}
+
+//===----------------------------------------------------------------------===//
+// ShiftOp
+//===----------------------------------------------------------------------===//
+
+// Print out Shift op.
+void ShiftOp::print(OpAsmPrinter &p) {
+  // Print the lhs and rhs vectors
+  p << " " << getLhs() << ", " << getRhs();
+
+  // Print shift
+  p << ", " << getShift();
+
+  // Print the attributes
+  p.printOptionalAttrDict((*this)->getAttrs());
+
+  // And now print the types
+  p << " : " << getLhs().getType() << ", " << getLhs().getType() << ", "
+    << getShift().getType() << ", " << getResult().getType();
+}
+
+// Verify Shift op.
+LogicalResult ShiftOp::verify() {
+  // Verify the types
+  VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
+  if (!resultType)
+    return emitError("requires vector type");
+
+  // lhs, rhs and result must have the same type
+  VectorType lhsType = llvm::dyn_cast<VectorType>(getLhs().getType());
+  VectorType rhsType = llvm::dyn_cast<VectorType>(getRhs().getType());
+
+  if (!lhsType || !rhsType)
+    return emitError("requires vector type");
+  if (lhsType != resultType || rhsType != resultType)
+    return emitError("All vectors must have same type");
+
+  if (!isa<IntegerType>(getShift().getType()))
+    return emitError("requires integer type");
+
+  return success();
+}
+
+// Parse Shift op.
+ParseResult ShiftOp::parse(OpAsmParser &parser, OperationState &result) {
+  llvm::SMLoc typesLoc;
+  SmallVector<Type, 4> types;
+  OpAsmParser::UnresolvedOperand lhs, rhs, shift;
+
+  // Parse the source vectors
+  if (parser.parseOperand(lhs) || parser.parseComma() ||
+      parser.parseOperand(rhs) || parser.parseComma() ||
+      parser.parseOperand(shift))
+    return failure();
+
+  // Parse all the attributes and types
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
+    return failure();
+
+  if (result.attributes.getAttrs().size() != 1)
+    return parser.emitError(typesLoc, "expects one attribute");
+
+  // Assert that there are four types (lhs, rhs, shift, and result)
+  if (types.size() != 4)
+    return parser.emitError(typesLoc, "requires four types");
+
+  // Some verification
+  VectorType lhsType = llvm::dyn_cast<VectorType>(types[0]);
+  VectorType rhsType = llvm::dyn_cast<VectorType>(types[1]);
+  IntegerType shiftType = llvm::dyn_cast<IntegerType>(types[2]);
+  VectorType resultType = llvm::dyn_cast<VectorType>(types[3]);
+  if (!lhsType || !rhsType || !resultType)
+    return parser.emitError(typesLoc, "requires vector type");
+
+  if (!shiftType)
+    return parser.emitError(typesLoc, "requires integer type");
+
+  // Populate the lhs vector, rhs vector and shift in result
+  if (parser.resolveOperand(lhs, lhsType, result.operands) ||
+      parser.resolveOperand(rhs, rhsType, result.operands) ||
+      parser.resolveOperand(shift, shiftType, result.operands))
+    return failure();
+
+  return parser.addTypeToList(resultType, result.types);
+}
+

 #define GET_ATTRDEF_CLASSES
 #include "aievec/AIEVecAttributes.cpp.inc"
diff --git a/compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.td b/compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.td
index 9d8a76123..568637ed2 100644
--- a/compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.td
+++ b/compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.td
@@ -33,6 +33,45 @@ class AIEVec_Op<string mnemonic, list<Trait> traits = []> :
   let hasVerifier = 1;
 }

+
+def AIEVec_ShiftOp:
+  AIEVec_Op<"shift", [
+    Pure
+  ]>,
+  Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, I32:$shift, DefaultValuedAttr<BoolAttr, "false">:$isAcc)>,
+  Results<(outs AnyVector:$result)> {
+  let summary = "AIE2 concat and shift";
+  let description = [{
+    AMD-specific shift intrinsic. Concatenates two
+    vectors into a bigger vector, interprets them as a vector of 128 bytes
+    and returns v1::v2[shift: shift+64]. `shift` is the number of bytes to
+    be shifted. The verifier confirms that all the input and result vectors
+    have the same number of lanes and element types.
+    `$result = shift($lhs, $rhs, $shift)`
+  }];
+}
+
+def AIEVec_ExtOp:
+  AIEVec_Op<"ext", [
+    Pure
+  ]>,
+  Arguments<(ins AnyVector:$source,
+             // ConfinedAttr<AIEI8Attr, [IntMinValue<0>, IntMaxValue<8>]>:$index)>,
+             ConfinedAttr<I8Attr, [IntMinValue<0>, IntMaxValue<8>]>:$index)>,
+  Results<(outs AnyVector:$result)> {
+  let summary = "AIE ext";
+  let description = [{
+    AMD-specific vector extract intrinsic. Selects contiguous lanes from
+    the source vector, and transfers the data from those lanes to the
+    result. The lane selection is controlled by index. There are two cases:
+    1. Extracted vector fills half of the original vector lanes (e.g. extract v64int8 from v128int8)
+    2. Extracted vector fills a fourth of the original vector lanes (e.g. extract v32int8 from v128int8)
+    In the first case, index can be 0 or 1. Index 0 extracts the lower half, and index 1 extracts the upper half.
+    In the second case, index can be 0 to 3. Index 0 extracts the lowest quarter, index 1 the next quarter, and so on.
+    `$result = ext($source, $index)`
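+
+    For illustration, with the custom assembly format implemented in
+    AIEVecOps.cpp this op reads roughly as follows (a sketch; the `i8` type
+    printed for the index attribute assumes the 8-bit attribute storage used
+    above):
+      // Case 1: extract the upper half of a 512-bit vector.
+      %1 = aievec.ext %0 {index = 1 : i8} : vector<16xi32>, vector<8xi32>
+      // Case 2: extract the second quarter of a 1024-bit vector.
+      %3 = aievec.ext %2 {index = 1 : i8} : vector<32xi32>, vector<8xi32>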
+  }];
+}
+
 def AIEVec_UPSOp:
   AIEVec_Op<"ups", [
     Pure
diff --git a/compiler/plugins/target/AMD-AIE/aievec/AIEVecToLLVM.cpp b/compiler/plugins/target/AMD-AIE/aievec/AIEVecToLLVM.cpp
index e6f753b9a..226947f83 100644
--- a/compiler/plugins/target/AMD-AIE/aievec/AIEVecToLLVM.cpp
+++ b/compiler/plugins/target/AMD-AIE/aievec/AIEVecToLLVM.cpp
@@ -797,11 +797,170 @@ class ShuffleOpConversion
   }
 };

+class ShiftOpConversion
+    : public mlir::ConvertOpToLLVMPattern<aievec::ShiftOp> {
+public:
+  using ConvertOpToLLVMPattern<aievec::ShiftOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(aievec::ShiftOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    Value result = op.getResult();
+    VectorType resultType = cast<VectorType>(result.getType());
+    Type resultScaTy = resultType.getElementType();
+    unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
+    int resultLanes = getVectorLaneSize(resultType);
+    int resultVectorSize = resultBitWidth * resultLanes;
+
+    if (resultVectorSize != 512) {
+      op.emitWarning() << "aievec.shift conversion with result vector size "
+                       << resultVectorSize << " is not implemented.\n";
+      return failure();
+    }
+
+    // assume step is always zero
+    auto stepCst = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
+
+    // create xllvm intrinsic
+    Value shiftOp = nullptr;
+    SmallVector<Value> operands(
+        {adaptor.getLhs(), adaptor.getRhs(), stepCst, adaptor.getShift()});
+    if (llvm::isa<IntegerType>(resultScaTy)) {
+      // Integer types
+      shiftOp = rewriter.create<xllvm::VectorShiftI512I512IntrOp>(
+          loc, VectorType::get({16}, rewriter.getI32Type()),
+          forceCastOperandsToSignature(
+              rewriter, loc, operands,
+              {VectorType::get({16}, rewriter.getI32Type()),
+               VectorType::get({16}, rewriter.getI32Type()),
+               rewriter.getI32Type(), rewriter.getI32Type()}));
+    } else {
+      // Float types
+      shiftOp = rewriter.create<xllvm::VectorShiftBF512BF512IntrOp>(
+          loc, VectorType::get({32}, rewriter.getBF16Type()),
+          forceCastOperandsToSignature(
+              rewriter, loc, operands,
+              {VectorType::get({32}, rewriter.getBF16Type()),
+               VectorType::get({32}, rewriter.getBF16Type()),
+               rewriter.getI32Type(), rewriter.getI32Type()}));
+    }
+
+    // create bitcast for result
+    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
+                                                 shiftOp);
+
+    return success();
+  }
+};
+
+class ExtOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::ExtOp> {
+public:
+  using ConvertOpToLLVMPattern<aievec::ExtOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(aievec::ExtOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    Value src = adaptor.getSource();
+    VectorType srcType = cast<VectorType>(src.getType());
+    Type srcScalarType = srcType.getElementType();
+    unsigned srcBitWidth = srcScalarType.getIntOrFloatBitWidth();
+    int srcLanes = getVectorLaneSize(srcType);
+    int srcVectorSize = srcBitWidth * srcLanes;
+
+    Value result = op.getResult();
+    VectorType resultType = cast<VectorType>(result.getType());
+    Type resultScaTy = resultType.getElementType();
+    unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
+    int resultLanes = getVectorLaneSize(resultType);
+    int resultVectorSize = resultBitWidth * resultLanes;
+
+    // create constant for index
+    auto indexCst = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(op.getIndex()));
+
+    // create xllvm intrinsic
+    SmallVector<Value> operands({adaptor.getSource(), indexCst});
+    Value extOp = nullptr;
+    // Integer types
+    if (resultVectorSize == 256 && srcVectorSize == 512) {
+      extOp = rewriter.create<xllvm::ExtI256I512IntrOp>(
+          loc, VectorType::get({8}, rewriter.getI32Type()),
+          forceCastOperandsToSignature(
+              rewriter, loc, operands,
+              {VectorType::get({16}, rewriter.getI32Type()),
+               rewriter.getI32Type()}));
+    } else if (resultVectorSize == 512 && srcVectorSize == 1024) {
+      extOp = rewriter.create<xllvm::ExtI512I1024IntrOp>(
+          loc, VectorType::get({16}, rewriter.getI32Type()),
+          forceCastOperandsToSignature(
+              rewriter, loc, operands,
+              {VectorType::get({32}, rewriter.getI32Type()),
+               rewriter.getI32Type()}));
+    } else if (resultVectorSize == 256 && srcVectorSize == 1024) {
+      extOp = rewriter.create<xllvm::ExtI256I1024IntrOp>(
+          loc, VectorType::get({8}, rewriter.getI32Type()),
+          forceCastOperandsToSignature(
+              rewriter, loc, operands,
+              {VectorType::get({32}, rewriter.getI32Type()),
+               rewriter.getI32Type()}));
+    } else if (resultVectorSize == 128 && srcVectorSize == 512) {
+      auto shiftOp = adaptor.getSource();
+      if (op.getIndex() > 0) {
+        auto undefOp = rewriter.create<xllvm::UndefV16I32IntrOp>(
+            loc, VectorType::get({16}, rewriter.getI32Type()));
+        auto stepCst = rewriter.create<LLVM::ConstantOp>(
+            loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
+        auto shiftCst = rewriter.create<LLVM::ConstantOp>(
+            loc, rewriter.getI32Type(),
+            rewriter.getI32IntegerAttr(op.getIndex() * 16));
+        SmallVector<Value> shiftOperands{adaptor.getSource(), undefOp, stepCst,
+                                         shiftCst};
+        // Right shift the source vector by index * 16 bytes (i.e. by index *
+        // 128 bits). The integer index is expected to be 0 to 3.
+        shiftOp = rewriter.create<xllvm::VectorShiftI512I512IntrOp>(
+            loc, VectorType::get({16}, rewriter.getI32Type()),
+            forceCastOperandsToSignature(
+                rewriter, loc, shiftOperands,
+                {VectorType::get({16}, rewriter.getI32Type()),
+                 VectorType::get({16}, rewriter.getI32Type()),
+                 rewriter.getI32Type(), rewriter.getI32Type()}));
+      }
+      // The underlying intrinsic takes a source vector and extracts the lowest
+      // 128 bits, i.e. it always extracts the input vector with index = 0.
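+      // Worked example (illustrative): with a v16i32 source and index = 2,
+      // the vshift above moves lanes 8..11 down to lanes 0..3, and the
+      // extract below then returns those four lanes as the v4i32 result.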
+      extOp = rewriter.create<xllvm::ExtI128I512IntrOp>(
+          loc, VectorType::get({4}, rewriter.getI32Type()),
+          forceCastOperandsToSignature(
+              rewriter, loc, /*operands=*/{shiftOp},
+              {VectorType::get({16}, rewriter.getI32Type())}));
+    } else {
+      op.emitWarning() << "aievec.ext with " << srcVectorSize
+                       << "-bit source, and " << resultVectorSize
+                       << "-bit result is not supported.\n";
+      return failure();
+    }
+
+    // create bitcast for result
+    if (op.getResult().getType() != extOp.getType()) {
+      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
+          op, op.getResult().getType(), extOp);
+    } else {
+      rewriter.replaceOp(op, extOp);
+    }
+
+    return success();
+  }
+};
+
 void populateAIEVecToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter,
                                             mlir::RewritePatternSet &patterns) {
   patterns.add(
-      converter);
+      FMAElemOpConversion, MatMulOpConversion, ShuffleOpConversion,
+      ExtOpConversion, ShiftOpConversion>(converter);
 }

 struct ConvertAIEVecToLLVMPass
diff --git a/compiler/plugins/target/AMD-AIE/aievec/XLLVMAIE2IntrOps.td b/compiler/plugins/target/AMD-AIE/aievec/XLLVMAIE2IntrOps.td
index dd4c11ec5..5a6d87d7f 100644
--- a/compiler/plugins/target/AMD-AIE/aievec/XLLVMAIE2IntrOps.td
+++ b/compiler/plugins/target/AMD-AIE/aievec/XLLVMAIE2IntrOps.td
@@ -30,28 +30,28 @@ class AIEVec2_IntrOp
 class AIE2BF16MACConf :
     Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$lhs,
                    VectorOfLengthAndType<[32], [BF16]>:$rhs,
                    VectorOfLengthAndType<[8], [I64]>:$acc,
                    I32:$conf)>;

-class AIE2I8MinMaxElem :
+class AIE2I8MinMaxElem :
     Arguments<(ins VectorOfLengthAndType<[64], [I8]>:$lhs,
                    VectorOfLengthAndType<[64], [I8]>:$rhs,
                    I32:$cmp)>
 ;

-class AIE2I16MinMaxElem :
+class AIE2I16MinMaxElem :
     Arguments<(ins VectorOfLengthAndType<[32], [I16]>:$lhs,
                    VectorOfLengthAndType<[32], [I16]>:$rhs,
                    I32:$cmp)>
 ;

-class AIE2I32MinMaxElem :
+class AIE2I32MinMaxElem :
     Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$lhs,
                    VectorOfLengthAndType<[16], [I32]>:$rhs,
                    I32:$cmp)>
 ;

-class AIE2BF16MinMaxElem :
+class AIE2BF16MinMaxElem :
     Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$lhs,
                    VectorOfLengthAndType<[32], [BF16]>:$rhs)>
 ;
@@ -78,6 +78,36 @@ def MacConfBF16IntrOp :
         [TypeIs<"res", VectorOfLengthAndType<[8], [I64]>>]>,
     AIE2BF16MACConf;

+// ----- MSC -----
+
+def MscConfBF16IntrOp :
+    AIEVec2_IntrOp<"bf.msc16.conf",
+        [TypeIs<"res", VectorOfLengthAndType<[8], [I64]>>]>,
+    AIE2BF16MACConf;
+
+// ----- MUL -----
+
+def MulConfAcc32IntrOp :
+    AIEVec2_IntrOp<"I512.I512.acc32.mul.conf",
+        [TypeIs<"res", VectorOfLengthAndType<[16], [I64]>>]>,
+    Arguments<(ins VectorOfLengthAndType<[64], [I8]>:$lhs,
+                   VectorOfLengthAndType<[16], [I32]>:$rhs,
+                   I32:$conf)>;
+
+def MulConfAcc64IntrOp :
+    AIEVec2_IntrOp<"I512.I512.acc64.mul.conf",
+        [TypeIs<"res", VectorOfLengthAndType<[16], [I64]>>]>,
+    Arguments<(ins VectorOfLengthAndType<[64], [I8]>:$lhs,
+                   VectorOfLengthAndType<[16], [I32]>:$rhs,
+                   I32:$conf)>;
+
+def MulConfBF16IntrOp :
+    AIEVec2_IntrOp<"bf.mul16.conf",
+        [TypeIs<"res", VectorOfLengthAndType<[8], [I64]>>]>,
+    Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$lhs,
+                   VectorOfLengthAndType<[32], [BF16]>:$rhs,
+                   I32:$conf)>;
+
 // ----- SET -----

 def VectorSetI512I128IntrOp :
@@ -140,6 +170,33 @@ def Vector16AccFloatToV16BF16IntrOp :
         [TypeIs<"res", VectorOfLengthAndType<[16], [BF16]>>]>,
     Arguments<(ins VectorOfLengthAndType<[8], [I64]>:$src)>;

+// ----- BROADCAST -----
+
+def VectorBroadcast8I512IntrOp :
+    AIEVec2_IntrOp<"vbroadcast8.I512",
+        [TypeIs<"res", VectorOfLengthAndType<[64], [I8]>>]>,
+    Arguments<(ins I32:$src)>;
+
+def VectorBroadcast16I512IntrOp :
+    AIEVec2_IntrOp<"vbroadcast16.I512",
+        [TypeIs<"res", VectorOfLengthAndType<[32], [I16]>>]>,
+    Arguments<(ins I32:$src)>;
+
+def VectorBroadcast32I512IntrOp :
+    AIEVec2_IntrOp<"vbroadcast32.I512",
+        [TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>,
+    Arguments<(ins I32:$src)>;
+
+def VectorBroadcast16BF512IntrOp :
+    AIEVec2_IntrOp<"vbroadcast16.bf512",
+        [TypeIs<"res", VectorOfLengthAndType<[32], [BF16]>>]>,
+    Arguments<(ins BF16:$src)>;
+
+def VectorBroadcastfloatI512IntrOp :
+    AIEVec2_IntrOp<"vbroadcastfloat.I512",
+        [TypeIs<"res", VectorOfLengthAndType<[16], [F32]>>]>,
+    Arguments<(ins F32:$src)>;
+
 // ----- EXT -----

 def ExtI256I512IntrOp :
@@ -154,6 +211,17 @@ def ExtI512I1024IntrOp :
     Arguments<(ins VectorOfLengthAndType<[32], [I32]>:$src,
                    I32:$idx)>;

+def ExtI256I1024IntrOp :
+    AIEVec2_IntrOp<"ext.I256.I1024",
+        [TypeIs<"res", VectorOfLengthAndType<[8], [I32]>>]>,
+    Arguments<(ins VectorOfLengthAndType<[32], [I32]>:$src,
+                   I32:$idx)>;
+
+def ExtI128I512IntrOp :
+    AIEVec2_IntrOp<"extract.I128.I512",
+        [TypeIs<"res", VectorOfLengthAndType<[4], [I32]>>]>,
+    Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$src)>;
+
 // ----- CONCAT -----

 def ConcatI512I256IntrOp :
@@ -162,6 +230,14 @@ def ConcatI512I256IntrOp :
     Arguments<(ins VectorOfLengthAndType<[8], [I32]>:$lhs,
                    VectorOfLengthAndType<[8], [I32]>:$rhs)>;

+def ConcatI1024I256IntrOp :
+    AIEVec2_IntrOp<"concat.I1024.I256",
+        [TypeIs<"res", VectorOfLengthAndType<[32], [I32]>>]>,
+    Arguments<(ins VectorOfLengthAndType<[8], [I32]>:$src0,
+                   VectorOfLengthAndType<[8], [I32]>:$src1,
+                   VectorOfLengthAndType<[8], [I32]>:$src2,
+                   VectorOfLengthAndType<[8], [I32]>:$src3)>;
+
 def ConcatI1024I512IntrOp :
     AIEVec2_IntrOp<"concat.I1024.I512",
         [TypeIs<"res", VectorOfLengthAndType<[32], [I32]>>]>,
@@ -183,6 +259,15 @@ def UndefV16I32IntrOp :
     AIEVec2_IntrOp<"v16int32",
         [TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>;

+// ----- UPD -----
+
+def UpdBF512BF256IntrOp :
+    AIEVec2_IntrOp<"upd.bf512.bf256",
+        [TypeIs<"res", VectorOfLengthAndType<[32], [BF16]>>]>,
+    Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$dst,
+                   VectorOfLengthAndType<[16], [BF16]>:$src,
+                   I32:$idx)>;
+
 // ----- UPS -----

 def Acc32V16I256UpsIntrOp :
@@ -232,4 +317,121 @@ def Vector16BF16ToV16AccFloatIntrOp :
         [TypeIs<"res", VectorOfLengthAndType<[8], [I64]>>]>,
     Arguments<(ins VectorOfLengthAndType<[16], [BF16]>:$src)>;

+// ----- SHIFT -----
+
+def VectorShiftI512I512IntrOp :
+    AIEVec2_IntrOp<"vshift.I512.I512",
+        [TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>,
+    Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$lhs,
+                   VectorOfLengthAndType<[16], [I32]>:$rhs,
+                   I32:$step,
+                   I32:$shift)>;
+
+def VectorShiftBF512BF512IntrOp :
+    AIEVec2_IntrOp<"vshift.bf512.bf512",
+        [TypeIs<"res", VectorOfLengthAndType<[32], [BF16]>>]>,
+    Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$lhs,
+                   VectorOfLengthAndType<[32], [BF16]>:$rhs,
+                   I32:$step,
+                   I32:$shift)>;
+
+// ----- EXTRACT ELEMENT -----
+
+def VectorExtractElem8I512IntrOp :
+    AIEVec2_IntrOp<"vextract.elem8.I512",
+        [TypeIs<"res", I32>]>,
+    Arguments<(ins VectorOfLengthAndType<[64], [I8]>:$src,
+                   I32:$idx,
+                   I32:$sign)>;
+
+def VectorExtractElem16I512IntrOp :
+    AIEVec2_IntrOp<"vextract.elem16.I512",
+        [TypeIs<"res", I32>]>,
+    Arguments<(ins VectorOfLengthAndType<[32], [I16]>:$src,
+                   I32:$idx,
+                   I32:$sign)>;
+
+def VectorExtractElem32I512IntrOp :
+    AIEVec2_IntrOp<"vextract.elem32.I512",
+        [TypeIs<"res", I32>]>,
+    Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$src,
+                   I32:$idx,
+                   I32:$sign)>;
+
+// ----- MAX ELEMENT -----
+
+def VectorMaxLt8IntrOp :
+    AIEVec2_IntrOp<"vmax.lt8",
+        [TypeIs<"res",
+            LLVM_StructOf<[
+                VectorOfLengthAndType<[64], [I8]>,
+                VectorOfLengthAndType<[2], [I32]>]>
+            >], /*numResults=*/2>,
+    AIE2I8MinMaxElem;
+
+def VectorMaxLt16IntrOp :
+    AIEVec2_IntrOp<"vmax.lt16",
+        [TypeIs<"res",
+            LLVM_StructOf<[
+                VectorOfLengthAndType<[32], [I16]>,
+                I32]>
+            >], /*numResults=*/2>,
+    AIE2I16MinMaxElem;
+
+def VectorMaxLt32IntrOp :
+    AIEVec2_IntrOp<"vmax.lt32",
+        [TypeIs<"res",
+            LLVM_StructOf<[
+                VectorOfLengthAndType<[16], [I32]>,
+                I32]>
+            >], /*numResults=*/2>,
+    AIE2I32MinMaxElem;
+
+def VectorMaxLtBf16IntrOp :
+    AIEVec2_IntrOp<"vmax.ltbf16",
+        [TypeIs<"res",
+            LLVM_StructOf<[
+                VectorOfLengthAndType<[32], [BF16]>,
+                I32]>
+            >], /*numResults=*/2>,
+    AIE2BF16MinMaxElem;
+
+// ----- MIN ELEMENT -----
+
+def VectorMinGe8IntrOp :
+    AIEVec2_IntrOp<"vmin.ge8",
+        [TypeIs<"res",
+            LLVM_StructOf<[
+                VectorOfLengthAndType<[64], [I8]>,
+                VectorOfLengthAndType<[2], [I32]>]>
+            >], /*numResults=*/2>,
+    AIE2I8MinMaxElem;
+
+def VectorMinGe16IntrOp :
+    AIEVec2_IntrOp<"vmin.ge16",
+        [TypeIs<"res",
+            LLVM_StructOf<[
+                VectorOfLengthAndType<[32], [I16]>,
+                I32]>
+            >], /*numResults=*/2>,
+    AIE2I16MinMaxElem;
+
+def VectorMinGe32IntrOp :
+    AIEVec2_IntrOp<"vmin.ge32",
+        [TypeIs<"res",
+            LLVM_StructOf<[
+                VectorOfLengthAndType<[16], [I32]>,
+                I32]>
+            >], /*numResults=*/2>,
+    AIE2I32MinMaxElem;
+
+def VectorMinGeBf16IntrOp :
+    AIEVec2_IntrOp<"vmin.gebf16",
+        [TypeIs<"res",
+            LLVM_StructOf<[
+                VectorOfLengthAndType<[32], [BF16]>,
+                I32]>
+            >], /*numResults=*/2>,
+    AIE2BF16MinMaxElem;
+
 #endif // AIE_DIALECT_XLLVM_IR_XLLVMAIE2INTROPS_TD
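
For reference, a sketch of how the new `aievec.shift` op reads in IR with the
printer and parser added above (illustrative only; the SSA value names and the
exact attribute syntax are assumptions, not part of this patch):

    // Concatenate %a and %b (512 bits each), then take the 512 bits starting
    // at byte offset %off of the 1024-bit concatenation.
    %s = aievec.shift %a, %b, %off {isAcc = false} : vector<16xi32>, vector<16xi32>, i32, vector<16xi32>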