Skip to content

Commit

Permalink
[AIE2P] Custom legalize G_CONCAT_VECTORS to only two input vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
konstantinschwarz committed Feb 4, 2025
1 parent cf9f33b commit cd205a0
Show file tree
Hide file tree
Showing 3 changed files with 861 additions and 89 deletions.
46 changes: 25 additions & 21 deletions llvm/lib/Target/AIE/AIELegalizerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1286,32 +1286,36 @@ bool AIELegalizerHelper::legalizeG_SELECT(LegalizerHelper &Helper,
return true;
}

// We legalize concat vector of 2 inputs. So, anything above we need to split
// it. So far expect only 4 input. 1024bit vector from 4 256bit register and
// 2048 accumulator register from 4 512bit registers.
/// Legalize the incoming \p MI G_CONCAT_VECTORS to half the number of inputs,
/// but at least 2 inputs.
bool AIELegalizerHelper::legalizeG_CONCAT_VECTORS(LegalizerHelper &Helper,
MachineInstr &MI) const {
MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
MachineRegisterInfo &MRI = *MIRBuilder.getMRI();

const Register DstReg = MI.getOperand(0).getReg();
const Register SrcReg = MI.getOperand(1).getReg();
const LLT DstTy = MRI.getType(DstReg);
const LLT SrcTy = MRI.getType(SrcReg);
const auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
assert(DstTy.isVector() && SrcTy.isVector() && "Expected vector types");
assert(SrcTy.getSizeInBits() >= 256 && "Input vector size does not match!");
assert(MI.getNumOperands() == 5 && "Expected 4 inputs!");

const LLT DstVecEltTy = DstTy.getElementType();
const unsigned ElTySize = DstVecEltTy.getSizeInBits();
const LLT SplitTy = LLT::fixed_vector(DstTy.getNumElements() / 2, ElTySize);
const Register DstRegLo = MRI.createGenericVirtualRegister(SplitTy);
const Register DstRegHi = MRI.createGenericVirtualRegister(SplitTy);
MIRBuilder.buildConcatVectors(
{DstRegLo}, {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()});
MIRBuilder.buildConcatVectors(
{DstRegHi}, {MI.getOperand(3).getReg(), MI.getOperand(4).getReg()});
MIRBuilder.buildConcatVectors({DstReg}, {DstRegLo, DstRegHi});
assert(SrcTy.getSizeInBits() >= 128 &&
"Vectors < 128-bit should be lowered to insert vector elt");

// Prevent infinite looping in the Legalizer. The base case should be legal
// and we should not reach this.
assert(DstTy.getSizeInBits() > 2 * SrcTy.getSizeInBits());

const LLT StepTy = SrcTy.multiplyElements(2);

// Concatenate pairs of source vector operands.
SmallVector<Register, 4> ConcatSteps;
for (size_t I = 1; I < MI.getNumOperands(); I += 2) {
const Register ConcatStep =
MIRBuilder
.buildConcatVectors({StepTy}, {MI.getOperand(I).getReg(),
MI.getOperand(I + 1).getReg()})
.getReg(0);
ConcatSteps.push_back(ConcatStep);
}

// Concatenate the resulting artifacts.
MIRBuilder.buildConcatVectors(DstReg, ConcatSteps);
MI.eraseFromParent();
return true;
}
Expand Down
26 changes: 14 additions & 12 deletions llvm/lib/Target/AIE/aie2p/AIE2PLegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,6 @@ isValidVectorMergeUnmergeOp(const unsigned BigVectorId,
};
}

static LegalityPredicate isValidVectorConcatOp(const unsigned BigVectorId,
const unsigned SmallVectorId) {
return [=](const LegalityQuery &Query) {
const LLT Big = Query.Types[BigVectorId];
const LLT Small = Query.Types[SmallVectorId];
return Big.isVector() && Small.isVector() &&
Big.getElementType() == Small.getElementType() &&
!(Big.getNumElements() % Small.getNumElements());
};
}

static LegalityPredicate isValidVectorAIEP(const unsigned TypeIdx) {
return [=](const LegalityQuery &Query) {
const LLT DstTy = Query.Types[TypeIdx];
Expand Down Expand Up @@ -583,7 +572,20 @@ AIE2PLegalizerInfo::AIE2PLegalizerInfo(const AIE2PSubtarget &ST)
getActionDefinitionsBuilder(G_CONCAT_VECTORS)
.unsupportedIf(IsNotValidDestinationVector)
.legalIf(isValidVectorMergeUnmergeOp(0, 1))
.customIf(isValidVectorConcatOp(0, 1));
.customIf([=](const LegalityQuery &Query) {
const LLT &DstTy = Query.Types[0];
const LLT &SrcTy = Query.Types[1];
if (!DstTy.isVector() || !SrcTy.isVector())
return false;

// Concatenating vectors <= 64-bit are not sub-vector operations.
// These should be lowered to insert vector elements.
if (SrcTy.getSizeInBits() <= 64)
return false;

// Legalize concat vectors to have excatly two inputs
return (DstTy.getNumElements() != 2 * SrcTy.getNumElements());
});

getActionDefinitionsBuilder(G_BUILD_VECTOR)
// Legacy legalization for bitcasts
Expand Down
Loading

0 comments on commit cd205a0

Please sign in to comment.