Skip to content

Commit

Permalink
X86: allow combineMulToPMADDWD to emit PMULHW (MULHS)
Browse files Browse the repository at this point in the history
If only high bits of single multiplication are used.
  • Loading branch information
Nekotekina committed May 17, 2021
1 parent 5836324 commit 1f23ff6
Showing 1 changed file with 43 additions and 6 deletions.
49 changes: 43 additions & 6 deletions lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41658,6 +41658,19 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG,
N1.getOperand(0).getScalarValueSizeInBits() <= 8))
return SDValue();

// Check if only high 16 bits of signed 16-bit multiplication are used
bool high_only = true;

for (auto *User : N->uses()) {
if (User->getOpcode() == ISD::SRL || User->getOpcode() == ISD::SRA) {
if (DAG.MaskedValueIsAllOnes(User->getOperand(1), {32, 16})) {
continue;
}
}
high_only = false;
break;
}

APInt Mask17 = APInt::getHighBitsSet(32, 17);
if (N0.getOpcode() == ISD::SRA && N1.getOpcode() == ISD::SRA) {
// If both arguments are sign-extended, try to replace sign extends
Expand All @@ -41671,10 +41684,17 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG,
DAG.isSplatValue(N1.getOperand(1))) {
// Nullify mask to pass the following check
Mask17 = 0;
N0 = DAG.getNode(ISD::SRL, N0.getNode(), VT, N0.getOperand(0),
N0.getOperand(1));
N1 = DAG.getNode(ISD::SRL, N1.getNode(), VT, N1.getOperand(0),
N1.getOperand(1));

if (high_only) {
// Bypass shifts to keep values in high bits
N0 = N0.getOperand(0);
N1 = N1.getOperand(0);
} else {
N0 = DAG.getNode(ISD::SRL, N0.getNode(), VT, N0.getOperand(0),
N0.getOperand(1));
N1 = DAG.getNode(ISD::SRL, N1.getNode(), VT, N1.getOperand(0),
N1.getOperand(1));
}
}
}

Expand All @@ -41684,15 +41704,32 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG,
DAG.isSplatValue(N0.getOperand(1)) &&
N0.getOperand(1).getConstantOperandVal(0) == 16) {
Mask17 = 0;
N0 = DAG.getNode(ISD::SRL, N0.getNode(), VT, N0.getOperand(0),
N0.getOperand(1));

if (high_only)
N0 = N0.getOperand(0);
else
N0 = DAG.getNode(ISD::SRL, N0.getNode(), VT, N0.getOperand(0),
N0.getOperand(1));
}
}

if (!!Mask17 && (!DAG.MaskedValueIsZero(N1, Mask17) ||
!DAG.MaskedValueIsZero(N0, Mask17)))
return SDValue();

// Use PMULHW if applicable
if (high_only && !Mask17) {
auto MULHSBuilder = [=](SelectionDAG &DAG, const SDLoc &DL,
ArrayRef<SDValue> Ops) {
MVT RT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32);
MVT OpVT = Ops[0].getSimpleValueType();
return DAG.getBitcast(RT, DAG.getNode(ISD::MULHS, DL, OpVT, Ops));
};
return SplitOpsAndApply(DAG, Subtarget, SDLoc(N), VT,
{DAG.getBitcast(WVT, N0), DAG.getBitcast(WVT, N1)},
MULHSBuilder);
}

// Use SplitOpsAndApply to handle AVX splitting.
auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
ArrayRef<SDValue> Ops) {
Expand Down

0 comments on commit 1f23ff6

Please sign in to comment.