diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp index 3f28a8293a6d..2e2c0c237bb0 100644 --- a/lib/Target/X86/X86ISelLowering.cpp +++ b/lib/Target/X86/X86ISelLowering.cpp @@ -41658,6 +41658,19 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG, N1.getOperand(0).getScalarValueSizeInBits() <= 8)) return SDValue(); + // Check if only high 16 bits of signed 16-bit multiplication are used + bool high_only = true; + + for (auto *User : N->uses()) { + if (User->getOpcode() == ISD::SRL || User->getOpcode() == ISD::SRA) { + if (DAG.MaskedValueIsAllOnes(User->getOperand(1), {32, 16})) { + continue; + } + } + high_only = false; + break; + } + APInt Mask17 = APInt::getHighBitsSet(32, 17); if (N0.getOpcode() == ISD::SRA && N1.getOpcode() == ISD::SRA) { // If both arguments are sign-extended, try to replace sign extends @@ -41671,10 +41684,17 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG, DAG.isSplatValue(N1.getOperand(1))) { // Nullify mask to pass the following check Mask17 = 0; - N0 = DAG.getNode(ISD::SRL, N0.getNode(), VT, N0.getOperand(0), - N0.getOperand(1)); - N1 = DAG.getNode(ISD::SRL, N1.getNode(), VT, N1.getOperand(0), - N1.getOperand(1)); + + if (high_only) { + // Bypass shifts to keep values in high bits + N0 = N0.getOperand(0); + N1 = N1.getOperand(0); + } else { + N0 = DAG.getNode(ISD::SRL, N0.getNode(), VT, N0.getOperand(0), + N0.getOperand(1)); + N1 = DAG.getNode(ISD::SRL, N1.getNode(), VT, N1.getOperand(0), + N1.getOperand(1)); + } } } @@ -41684,8 +41704,12 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG, DAG.isSplatValue(N0.getOperand(1)) && N0.getOperand(1).getConstantOperandVal(0) == 16) { Mask17 = 0; - N0 = DAG.getNode(ISD::SRL, N0.getNode(), VT, N0.getOperand(0), - N0.getOperand(1)); + + if (high_only) + N0 = N0.getOperand(0); + else + N0 = DAG.getNode(ISD::SRL, N0.getNode(), VT, N0.getOperand(0), + N0.getOperand(1)); } } @@ -41693,6 +41717,19 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG, !DAG.MaskedValueIsZero(N0, Mask17))) return SDValue(); + // Use PMULHW if applicable + if (high_only && !Mask17) { + auto MULHSBuilder = [=](SelectionDAG &DAG, const SDLoc &DL, + ArrayRef Ops) { + MVT RT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32); + MVT OpVT = Ops[0].getSimpleValueType(); + return DAG.getBitcast(RT, DAG.getNode(ISD::MULHS, DL, OpVT, Ops)); + }; + return SplitOpsAndApply(DAG, Subtarget, SDLoc(N), VT, + {DAG.getBitcast(WVT, N0), DAG.getBitcast(WVT, N1)}, + MULHSBuilder); + } + // Use SplitOpsAndApply to handle AVX splitting. auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL, ArrayRef Ops) {