diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp index b955b9033fab..3f28a8293a6d 100644 --- a/lib/Target/X86/X86ISelLowering.cpp +++ b/lib/Target/X86/X86ISelLowering.cpp @@ -41664,9 +41664,11 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG, // with zero extends, which should qualify for the optimization. // Otherwise just fallback to zero-extension check. if (isa(N0.getOperand(1).getOperand(0)) && - N0.getOperand(1).getConstantOperandVal(0) == 16 && isa(N1.getOperand(1).getOperand(0)) && - N1.getOperand(1).getConstantOperandVal(0) == 16) { + N0.getOperand(1).getConstantOperandVal(0) == 16 && + N1.getOperand(1).getConstantOperandVal(0) == 16 && + DAG.isSplatValue(N0.getOperand(1)) && + DAG.isSplatValue(N1.getOperand(1))) { // Nullify mask to pass the following check Mask17 = 0; N0 = DAG.getNode(ISD::SRL, N0.getNode(), VT, N0.getOperand(0), @@ -41675,8 +41677,20 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG, N1.getOperand(1)); } } - if (!DAG.MaskedValueIsZero(N1, Mask17) || - !DAG.MaskedValueIsZero(N0, Mask17)) + + if (!!Mask17 && N0.getOpcode() == ISD::SRA) { + if (isa(N0.getOperand(1).getOperand(0)) && + DAG.ComputeNumSignBits(N1) >= 17 && + DAG.isSplatValue(N0.getOperand(1)) && + N0.getOperand(1).getConstantOperandVal(0) == 16) { + Mask17 = 0; + N0 = DAG.getNode(ISD::SRL, N0.getNode(), VT, N0.getOperand(0), + N0.getOperand(1)); + } + } + + if (!!Mask17 && (!DAG.MaskedValueIsZero(N1, Mask17) || + !DAG.MaskedValueIsZero(N0, Mask17))) return SDValue(); // Use SplitOpsAndApply to handle AVX splitting.