================ @@ -8702,61 +8738,264 @@ SDValue SystemZTargetLowering::combineSETCC( return SDValue(); } +static SmallSet<int, 4> convertCCMaskToCCValsSet(int Mask) { + SmallSet<int, 4> CCVals; + size_t Pos = 0; + while (Mask) { + if (Mask & 0x1) + CCVals.insert(3 - Pos); + Mask >>= 1; + ++Pos; + } + return CCVals; +} + +static std::pair<SDValue, int> findCCUse(const SDValue &Val) { + auto *N = Val.getNode(); + if (!N) + return {Val, SystemZ::CCMASK_NONE}; + if (isa<ConstantSDNode>(Val)) + return std::make_pair(SDValue(), SystemZ::CCMASK_NONE); + else if (N->getOpcode() == ISD::CopyFromReg && N->getNumOperands() > 1) { + if (auto *RN = cast<RegisterSDNode>(N->getOperand(1))) { + if (RN->getReg() == SystemZ::CC) + return {Val, SystemZ::CCMASK_ANY}; + } + } else if (N->getOpcode() == SystemZISD::IPM) + return std::make_pair(N->getOperand(0), SystemZ::CCMASK_ANY); + else if (N->getOpcode() == ISD::SRL) + return findCCUse(N->getOperand(0)); + else if (N->getOpcode() == SystemZISD::ICMP) + return findCCUse(N->getOperand(0)); + else if (N->getOpcode() == SystemZISD::TM) + return findCCUse(N->getOperand(0)); + else if (N->getOpcode() == SystemZISD::SELECT_CCMASK) { + SDValue SelectCCReg = N->getOperand(4); + auto [OpCC, OpCCValid] = findCCUse(SelectCCReg); + auto *OpCCNode = OpCC.getNode(); + if (OpCCNode && OpCCNode != SelectCCReg.getNode()) + return std::make_pair(OpCC, OpCCValid); + auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2)); + if (CCValid) + return std::make_pair(SelectCCReg, CCValid->getZExtValue()); + } else if (N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::AND || + N->getOpcode() == ISD::OR || N->getOpcode() == ISD::XOR || + N->getOpcode() == ISD::AND) { + auto [Op0CC, Op0CCValid] = findCCUse(N->getOperand(0)); + if (isa<ConstantSDNode>(N->getOperand(1))) + return std::make_pair(Op0CC, Op0CCValid); + auto [Op1CC, Op1CCValid] = findCCUse(N->getOperand(1)); + auto *N0 = Op0CC.getNode(), *N1 = Op1CC.getNode(); + if (N0 && N1 && N0 == N1 && Op0CCValid == Op1CCValid) + return std::make_pair(Op0CC, Op0CCValid); + } + return {SDValue(), SystemZ::CCMASK_ANY}; +} + +static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask); + +static SmallVector<int, 4> simplifyAssumingCCVal(SDValue &Val, SDValue &CC) { + const auto isValidBinaryOperation = [](const SDValue &Op, SDValue &Op0, + SDValue &Op1, unsigned &Opcode) { + auto *N = Op.getNode(); + if (!N) + return false; + Opcode = N->getOpcode(); + if (Opcode != ISD::ADD && Opcode != ISD::AND && Opcode != ISD::OR && + Opcode != ISD::XOR) + return false; + Op0 = N->getOperand(0); + Op1 = N->getOperand(1); + return true; + }; + auto *C = dyn_cast<ConstantSDNode>(Val); + if (C) { + int ConstVal = C->getZExtValue(); + return {ConstVal, ConstVal, ConstVal, ConstVal}; + } + auto *N = Val.getNode(), *CCNode = CC.getNode(); + if (!N || !CCNode) + return {}; + if (N == CCNode) + return {0, 1, 2, 3}; + if (N->getOpcode() == SystemZISD::IPM) { + SDValue IPMOp0 = N->getOperand(0); + auto &&CCVals = simplifyAssumingCCVal(IPMOp0, CC); + if (CCVals.empty()) + return CCVals; + auto ShiftAmount = SystemZ::IPM_CC; + std::for_each(CCVals.begin(), CCVals.end(), + [&ShiftAmount](auto &V) { V <<= ShiftAmount; }); + return CCVals; + } + if (N->getOpcode() == ISD::SRL) { + SDValue Op0 = N->getOperand(0); + auto *SRLCount = dyn_cast<ConstantSDNode>(N->getOperand(1)); + if (!SRLCount) + return {}; + auto SRLCountVal = SRLCount->getZExtValue(); + auto &&CCVals = simplifyAssumingCCVal(Op0, CC); + if (CCVals.empty()) + return CCVals; + std::for_each(CCVals.begin(), CCVals.end(), + [SRLCountVal](auto &V) { V >>= SRLCountVal; }); + return CCVals; + } + DenseMap<unsigned, std::function<int(int, int)>> BinaryOPS = { + {ISD::ADD, [](int Op1, int Op2) { return Op1 + Op2; }}, + {ISD::AND, [](int Op1, int Op2) { return Op1 & Op2; }}, + {ISD::OR, [](int Op1, int Op2) { return Op1 | Op2; }}, + {ISD::XOR, [](int Op1, int Op2) { return Op1 ^ Op2; }}, + }; + if (N->getOpcode() == SystemZISD::SELECT_CCMASK) { + SDValue TrueVal = N->getOperand(0), FalseVal = N->getOperand(1); + auto *TrueOp = TrueVal.getNode(); + auto *FalseOp = FalseVal.getNode(); + auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2)); + auto *CCMask = dyn_cast<ConstantSDNode>(N->getOperand(3)); + if (!TrueOp || !FalseOp || !CCValid || !CCMask) + return {}; + + int CCValidVal = CCValid->getZExtValue(); + int CCMaskVal = CCMask->getZExtValue(); + auto [Op0CC, Op0CCValid] = findCCUse(TrueVal); + auto &&TrueVals = + simplifyAssumingCCVal(TrueVal, Op0CC == SDValue() ? CC : Op0CC); + auto [Op1CC, Op1CCValid] = findCCUse(FalseVal); + auto &&FalseVals = + simplifyAssumingCCVal(FalseVal, Op1CC == SDValue() ? CC : Op1CC); + SDValue SelectCCReg = N->getOperand(4); + auto [Op4CC, Op4CCValid] = findCCUse(SelectCCReg); + auto Op4CCNode = Op4CC.getNode(); + if (Op4CCNode && Op4CCNode == CCNode && Op4CCNode != SelectCCReg.getNode()) + combineCCMask(SelectCCReg, CCValidVal, CCMaskVal); + SmallVector<int, 4> CCVals; + auto CCMaskValsSet = convertCCMaskToCCValsSet(CCMaskVal); + auto CCValidValsSet = convertCCMaskToCCValsSet(CCValidVal); + for (auto &CCVal : {0, 1, 2, 3}) + CCVals.emplace_back( + (CCMaskValsSet.count(CCVal) && CCValidValsSet.count(CCVal)) + ? TrueVals[CCVal] + : FalseVals[CCVal]); + return CCVals; + } + SDValue Op0, Op1; + unsigned Opcode; + if (isValidBinaryOperation(Val, Op0, Op1, Opcode)) { + auto &&CC0Vals = simplifyAssumingCCVal(Op0, CC); + auto &&CC1Vals = simplifyAssumingCCVal(Op1, CC); + SmallVector<int, 4> CCVals; + for (auto CCVal : {0, 1, 2, 3}) + CCVals.emplace_back(BinaryOPS[Opcode](CC0Vals[CCVal], CC1Vals[CCVal])); + return CCVals; + } + return {}; +} + static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask) { // We have a SELECT_CCMASK or BR_CCMASK comparing the condition code // set by the CCReg instruction using the CCValid / CCMask masks, - // If the CCReg instruction is itself a ICMP testing the condition + // If the CCReg instruction is itself a ICMP / TM testing the condition // code set by some other instruction, see whether we can directly // use that condition code. - - // Verify that we have an ICMP against some constant. - if (CCValid != SystemZ::CCMASK_ICMP) - return false; - auto *ICmp = CCReg.getNode(); - if (ICmp->getOpcode() != SystemZISD::ICMP) - return false; - auto *CompareLHS = ICmp->getOperand(0).getNode(); - auto *CompareRHS = dyn_cast<ConstantSDNode>(ICmp->getOperand(1)); - if (!CompareRHS) + auto *CCNode = CCReg.getNode(); + if (!CCNode) return false; - - // Optimize the case where CompareLHS is a SELECT_CCMASK. - if (CompareLHS->getOpcode() == SystemZISD::SELECT_CCMASK) { - // Verify that we have an appropriate mask for a EQ or NE comparison. - bool Invert = false; - if (CCMask == SystemZ::CCMASK_CMP_NE) - Invert = !Invert; - else if (CCMask != SystemZ::CCMASK_CMP_EQ) + if (CCNode->getOpcode() == SystemZISD::TM) { + if (CCValid != SystemZ::CCMASK_TM) return false; - - // Verify that the ICMP compares against one of select values. - auto *TrueVal = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(0)); - if (!TrueVal) + if ((CCMask != SystemZ::CCMASK_TM_SOME_1) && + (CCMask != SystemZ::CCMASK_TM_ALL_0)) return false; - auto *FalseVal = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(1)); - if (!FalseVal) + SDValue Op0 = CCNode->getOperand(0); + SDValue Op1 = CCNode->getOperand(1); + auto [Op0CC, Op0CCValid] = findCCUse(Op0); + if (Op0CC == SDValue()) return false; - if (CompareRHS->getAPIntValue() == FalseVal->getAPIntValue()) - Invert = !Invert; - else if (CompareRHS->getAPIntValue() != TrueVal->getAPIntValue()) + int TMConstVal; + auto *N = dyn_cast<ConstantSDNode>(Op1); + if (N) { + TMConstVal = N->getZExtValue(); + if ((TMConstVal != (1 << SystemZ::IPM_CC)) && (TMConstVal != 1)) + return false; ---------------- uweigand wrote:
This should likewise not be necessary if we fully implement TM semantics. https://github.com/llvm/llvm-project/pull/125970 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits