================ @@ -8701,95 +8734,266 @@ SDValue SystemZTargetLowering::combineSETCC( return SDValue(); } -static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask) { +static std::pair<SDValue, int> findCCUse(const SDValue &Val) { + auto *N = Val.getNode(); + if (!N) + return std::make_pair(SDValue(), SystemZ::CCMASK_NONE); + switch (N->getOpcode()) { + default: + return std::make_pair(SDValue(), SystemZ::CCMASK_NONE); + case SystemZISD::IPM: + if (N->getOperand(0).getOpcode() == SystemZISD::CLC || + N->getOperand(0).getOpcode() == SystemZISD::STRCMP) + return std::make_pair(N->getOperand(0), SystemZ::CCMASK_ICMP); + return std::make_pair(N->getOperand(0), SystemZ::CCMASK_ANY); + case SystemZISD::SELECT_CCMASK: { + SDValue Op4CCReg = N->getOperand(4); + auto *Op4CCNode = Op4CCReg.getNode(); + auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2)); + if (!CCValid || !Op4CCNode) + return std::make_pair(SDValue(), SystemZ::CCMASK_NONE); + int CCValidVal = CCValid->getZExtValue(); + if (Op4CCNode->getOpcode() == SystemZISD::ICMP || + Op4CCNode->getOpcode() == SystemZISD::TM) { + auto [OpCC, OpCCValid] = findCCUse(Op4CCNode->getOperand(0)); + if (OpCC != SDValue()) + return std::make_pair(OpCC, OpCCValid); + } + return std::make_pair(Op4CCReg, CCValidVal); + } + case ISD::ADD: + case ISD::AND: + case ISD::OR: + case ISD::XOR: + case ISD::SHL: + case ISD::SRA: + case ISD::SRL: + auto [Op0CC, Op0CCValid] = findCCUse(N->getOperand(0)); + if (Op0CC != SDValue()) + return std::make_pair(Op0CC, Op0CCValid); + return findCCUse(N->getOperand(1)); + } +} + +static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask, + SelectionDAG &DAG); + +SmallVector<SDValue, 4> static simplifyAssumingCCVal(SDValue &Val, SDValue &CC, + SelectionDAG &DAG) { + auto *N = Val.getNode(), *CCNode = CC.getNode(); + if (!N || !CCNode) + return {}; + SDLoc DL(N); + auto Opcode = N->getOpcode(); + switch (Opcode) { + default: + return {}; + case ISD::Constant: + return {Val, 
Val, Val, Val}; + case SystemZISD::IPM: { + auto *IPMOp0Node = N->getOperand(0).getNode(); + if (!IPMOp0Node || IPMOp0Node != CCNode) + return {}; + SmallVector<SDValue, 4> ShiftedCCVals; + for (auto CC : {0, 1, 2, 3}) + ShiftedCCVals.emplace_back( + DAG.getConstant((CC << SystemZ::IPM_CC), DL, MVT::i32)); + return ShiftedCCVals; + } + case SystemZISD::SELECT_CCMASK: { + SDValue TrueVal = N->getOperand(0), FalseVal = N->getOperand(1); + auto *TrueOp = TrueVal.getNode(); + auto *FalseOp = FalseVal.getNode(); + auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2)); + auto *CCMask = dyn_cast<ConstantSDNode>(N->getOperand(3)); + if (!TrueOp || !FalseOp || !CCValid || !CCMask) + return {}; + + int CCValidVal = CCValid->getZExtValue(); + int CCMaskVal = CCMask->getZExtValue(); + const auto &&TrueSDVals = simplifyAssumingCCVal(TrueVal, CC, DAG); + const auto &&FalseSDVals = simplifyAssumingCCVal(FalseVal, CC, DAG); + if (TrueSDVals.empty() || FalseSDVals.empty()) + return {}; + SDValue Op4CCReg = N->getOperand(4); + auto *Op4CCNode = Op4CCReg.getNode(); + if (Op4CCNode && Op4CCNode != CCNode) + combineCCMask(Op4CCReg, CCValidVal, CCMaskVal, DAG); + Op4CCNode = Op4CCReg.getNode(); + if (!Op4CCNode || Op4CCNode != CCNode) + return {}; + SmallVector<SDValue, 4> MergedSDVals; + for (auto &CCVal : {0, 1, 2, 3}) + MergedSDVals.emplace_back(((CCMaskVal & (1 << (3 - CCVal))) != 0) + ? 
TrueSDVals[CCVal] + : FalseSDVals[CCVal]); + return MergedSDVals; + } + case ISD::ADD: + case ISD::AND: + case ISD::OR: + case ISD::XOR: + case ISD::SRA: + if (!N->hasOneUse()) + return {}; + [[fallthrough]]; + case ISD::SHL: + case ISD::SRL: + SDValue Op0 = N->getOperand(0), Op1 = N->getOperand(1); + const auto &&Op0SDVals = simplifyAssumingCCVal(Op0, CC, DAG); + const auto &&Op1SDVals = simplifyAssumingCCVal(Op1, CC, DAG); + if (Op0SDVals.empty() || Op1SDVals.empty()) + return {}; + SmallVector<SDValue, 4> BinaryOpSDVals; + for (auto CCVal : {0, 1, 2, 3}) + BinaryOpSDVals.emplace_back(DAG.getNode( + Opcode, DL, Val.getValueType(), Op0SDVals[CCVal], Op1SDVals[CCVal])); + return BinaryOpSDVals; + } +} + +static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask, + SelectionDAG &DAG) { // We have a SELECT_CCMASK or BR_CCMASK comparing the condition code // set by the CCReg instruction using the CCValid / CCMask masks, - // If the CCReg instruction is itself a ICMP testing the condition + // If the CCReg instruction is itself a ICMP / TM testing the condition // code set by some other instruction, see whether we can directly // use that condition code. - - // Verify that we have an ICMP against some constant. - if (CCValid != SystemZ::CCMASK_ICMP) - return false; - auto *ICmp = CCReg.getNode(); - if (ICmp->getOpcode() != SystemZISD::ICMP) - return false; - auto *CompareLHS = ICmp->getOperand(0).getNode(); - auto *CompareRHS = dyn_cast<ConstantSDNode>(ICmp->getOperand(1)); - if (!CompareRHS) + auto *CCNode = CCReg.getNode(); + if (!CCNode) return false; + const auto getAPIntSDVals = [](const SmallVector<SDValue, 4> &Vals) { + SmallVector<APInt, 4> APIntVals; + for (const auto &Val : Vals) { + auto *ConstValNode = dyn_cast<ConstantSDNode>(Val.getNode()); + if (!ConstValNode) + return SmallVector<APInt, 4>(); + APIntVals.emplace_back(ConstValNode->getAPIntValue()); + } + return APIntVals; + }; - // Optimize the case where CompareLHS is a SELECT_CCMASK. 
- if (CompareLHS->getOpcode() == SystemZISD::SELECT_CCMASK) { - // Verify that we have an appropriate mask for a EQ or NE comparison. - bool Invert = false; - if (CCMask == SystemZ::CCMASK_CMP_NE) - Invert = !Invert; - else if (CCMask != SystemZ::CCMASK_CMP_EQ) + if (CCNode->getOpcode() == SystemZISD::TM) { + if (CCValid != SystemZ::CCMASK_TM) return false; - - // Verify that the ICMP compares against one of select values. - auto *TrueVal = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(0)); - if (!TrueVal) + auto emulateTMCCMask = [](const APInt &Op0Val, const APInt &Op1Val) { + auto Result = Op0Val & Op1Val; + bool AllOnes = Result == Op1Val; + bool AllZeros = Result == 0; + int MSBPos = Op1Val.countl_zero(); + bool IsLeftMostBitSet = (Result & (1 << MSBPos)) != 0; + return AllOnes ? 3 : AllZeros ? 0 : IsLeftMostBitSet ? 2 : 1; + }; + SDValue Op0 = CCNode->getOperand(0); + SDValue Op1 = CCNode->getOperand(1); + auto [Op0CC, Op0CCValid] = findCCUse(Op0); + if (Op0CC == SDValue()) return false; - auto *FalseVal = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(1)); - if (!FalseVal) + const auto &&Op0SDVals = simplifyAssumingCCVal(Op0, Op0CC, DAG); + const auto &&Op1SDVals = simplifyAssumingCCVal(Op1, Op0CC, DAG); + if (Op0SDVals.empty() || Op1SDVals.empty()) return false; - if (CompareRHS->getAPIntValue() == FalseVal->getAPIntValue()) - Invert = !Invert; - else if (CompareRHS->getAPIntValue() != TrueVal->getAPIntValue()) + auto &&Op0APInts = getAPIntSDVals(Op0SDVals); + const auto &&Op1APInts = getAPIntSDVals(Op1SDVals); + if (Op0APInts.empty() || Op1APInts.empty()) return false; - - // Compute the effective CC mask for the new branch or select. 
- auto *NewCCValid = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(2)); - auto *NewCCMask = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(3)); - if (!NewCCValid || !NewCCMask) + SmallVector<int, 4> CCVals; + std::transform(Op0APInts.begin(), Op0APInts.end(), Op1APInts.begin(), + std::back_inserter(CCVals), emulateTMCCMask); + if (CCVals.empty()) return false; - CCValid = NewCCValid->getZExtValue(); - CCMask = NewCCMask->getZExtValue(); - if (Invert) - CCMask ^= CCValid; - - // Return the updated CCReg link. - CCReg = CompareLHS->getOperand(4); + int NewCCMask = 0; + for (auto CC : CCVals) { + NewCCMask <<= 1; + NewCCMask |= (CCMask & (1 << (3 - CC))) != 0; + } + CCReg = Op0CC; + CCMask = NewCCMask; + CCValid = Op0CCValid; return true; } + if (CCNode->getOpcode() != SystemZISD::ICMP || + CCValid != SystemZ::CCMASK_ICMP) + return false; - // Optimize the case where CompareRHS is (SRA (SHL (IPM))). - if (CompareLHS->getOpcode() == ISD::SRA) { - auto *SRACount = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(1)); - if (!SRACount || SRACount->getZExtValue() != 30) + SDValue CmpOp0 = CCNode->getOperand(0); + SDValue CmpOp1 = CCNode->getOperand(1); + SDValue CmpOp2 = CCNode->getOperand(2); + auto [Op0CC, Op0CCValid] = findCCUse(CmpOp0); + if (Op0CC != SDValue()) { + const auto &&Op0SDVals = simplifyAssumingCCVal(CmpOp0, Op0CC, DAG); + const auto &&Op1SDVals = simplifyAssumingCCVal(CmpOp1, Op0CC, DAG); + if (Op0SDVals.empty() || Op1SDVals.empty()) return false; - auto *SHL = CompareLHS->getOperand(0).getNode(); - if (SHL->getOpcode() != ISD::SHL) - return false; - auto *SHLCount = dyn_cast<ConstantSDNode>(SHL->getOperand(1)); - if (!SHLCount || SHLCount->getZExtValue() != 30 - SystemZ::IPM_CC) - return false; - auto *IPM = SHL->getOperand(0).getNode(); - if (IPM->getOpcode() != SystemZISD::IPM) + auto &&Op0APInts = getAPIntSDVals(Op0SDVals); + const auto &&Op1APInts = getAPIntSDVals(Op1SDVals); + if (Op0APInts.empty() || Op1APInts.empty()) return false; - // Avoid 
introducing CC spills (because SRA would clobber CC). - if (!CompareLHS->hasOneUse()) - return false; - // Verify that the ICMP compares against zero. - if (CompareRHS->getZExtValue() != 0) + auto *CmpType = dyn_cast<ConstantSDNode>(CmpOp2); + auto CmpTypeVal = CmpType->getZExtValue(); + const auto compareCCSigned = [&CmpTypeVal](const APInt &Op0Val, + const APInt &Op1Val) { + if (CmpTypeVal == SystemZICMP::SignedOnly) + return Op0Val == Op1Val ? 0 : Op0Val.slt(Op1Val) ? 1 : 2; + return Op0Val == Op1Val ? 0 : Op0Val.ult(Op1Val) ? 1 : 2; + }; + SmallVector<int, 4> CCVals; + std::transform(Op0APInts.begin(), Op0APInts.end(), Op1APInts.begin(), + std::back_inserter(CCVals), compareCCSigned); + if (CCVals.empty()) return false; ---------------- uweigand wrote:
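As above, I don't think this can ever be empty: `CCVals` is filled by `std::transform` over `Op0APInts`, which has already been checked to be non-empty at this point, so the `CCVals.empty()` check is redundant. https://github.com/llvm/llvm-project/pull/125970 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits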