================
@@ -8702,61 +8738,264 @@ SDValue SystemZTargetLowering::combineSETCC(
   return SDValue();
 }
 
+/// Expand the low four bits of \p Mask into the set of condition-code values
+/// they select.
+///
+/// Bit 3 (value 8) selects CC 0, bit 2 selects CC 1, bit 1 selects CC 2 and
+/// bit 0 selects CC 3. Any bits above bit 3 are ignored, so a malformed or
+/// negative mask can neither insert out-of-range values nor keep the loop
+/// running forever.
+///
+/// \param Mask condition-code mask to decode.
+/// \returns the CC values (each in the range 0..3) whose mask bit is set.
+static SmallSet<int, 4> convertCCMaskToCCValsSet(int Mask) {
+  SmallSet<int, 4> CCVals;
+  for (int CCVal = 0; CCVal < 4; ++CCVal) {
+    // CC value N is represented by mask bit (3 - N).
+    if (Mask & (1 << (3 - CCVal)))
+      CCVals.insert(CCVal);
+  }
+  return CCVals;
+}
+
+/// Walk from \p Val towards the node that supplies the condition code it is
+/// computed from.
+///
+/// The search looks through IPM, SRL, ICMP, TM and SELECT_CCMASK nodes, and
+/// through ADD / AND / OR / XOR when either the second operand is a constant
+/// or both operands lead to the same CC source with the same valid mask.
+///
+/// \param Val value whose CC origin is wanted.
+/// \returns a pair of (CC source, CC-valid mask). The first member is an empty
+/// SDValue when no single CC source could be identified.
+static std::pair<SDValue, int> findCCUse(const SDValue &Val) {
+  auto *N = Val.getNode();
+  if (!N)
+    return {Val, SystemZ::CCMASK_NONE};
+  // A constant does not depend on any condition code.
+  if (isa<ConstantSDNode>(Val))
+    return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
+
+  switch (N->getOpcode()) {
+  case ISD::CopyFromReg:
+    // A direct copy out of the CC register is itself the CC source.
+    // dyn_cast (not cast) so the null test below is meaningful.
+    if (N->getNumOperands() > 1) {
+      if (auto *RN = dyn_cast<RegisterSDNode>(N->getOperand(1)))
+        if (RN->getReg() == SystemZ::CC)
+          return {Val, SystemZ::CCMASK_ANY};
+    }
+    break;
+  case SystemZISD::IPM:
+    // The operand of IPM is reported as the CC source.
+    return std::make_pair(N->getOperand(0), SystemZ::CCMASK_ANY);
+  case ISD::SRL:
+  case SystemZISD::ICMP:
+  case SystemZISD::TM:
+    // Look through to whatever feeds the first operand.
+    return findCCUse(N->getOperand(0));
+  case SystemZISD::SELECT_CCMASK: {
+    SDValue SelectCCReg = N->getOperand(4);
+    auto [OpCC, OpCCValid] = findCCUse(SelectCCReg);
+    auto *OpCCNode = OpCC.getNode();
+    // Prefer a CC source found further up the chain of the select's CC
+    // operand over that operand itself.
+    if (OpCCNode && OpCCNode != SelectCCReg.getNode())
+      return std::make_pair(OpCC, OpCCValid);
+    // Otherwise report the select's own CC operand with its CCValid mask.
+    if (auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2)))
+      return std::make_pair(SelectCCReg, CCValid->getZExtValue());
+    break;
+  }
+  case ISD::ADD:
+  case ISD::AND:
+  case ISD::OR:
+  case ISD::XOR: {
+    auto [Op0CC, Op0CCValid] = findCCUse(N->getOperand(0));
+    // With a constant second operand only the first one can carry a CC.
+    if (isa<ConstantSDNode>(N->getOperand(1)))
+      return std::make_pair(Op0CC, Op0CCValid);
+    auto [Op1CC, Op1CCValid] = findCCUse(N->getOperand(1));
+    auto *N0 = Op0CC.getNode(), *N1 = Op1CC.getNode();
+    // Both operands must agree on the CC source and on its valid mask.
+    if (N0 && N1 && N0 == N1 && Op0CCValid == Op1CCValid)
+      return std::make_pair(Op0CC, Op0CCValid);
+    break;
+  }
+  default:
+    break;
+  }
+  // No single CC source could be identified.
+  return {SDValue(), SystemZ::CCMASK_ANY};
+}
+
+static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask);
+
+/// Evaluate \p Val once for every possible value (0..3) of the condition code
+/// supplied by \p CC.
+///
+/// Handles constants, the CC node itself, IPM, SRL by a constant amount,
+/// SELECT_CCMASK with constant masks, and ADD / AND / OR / XOR.
+///
+/// \param Val value to evaluate.
+/// \param CC  node regarded as the condition-code source.
+/// \returns a four-element vector whose entry I is the value of \p Val when
+/// the condition code is I, or an empty vector when \p Val cannot be
+/// evaluated this way.
+static SmallVector<int, 4> simplifyAssumingCCVal(SDValue &Val, SDValue &CC) {
+  // Reject a null value before it is handed to dyn_cast<>.
+  auto *N = Val.getNode();
+  if (!N)
+    return {};
+  if (auto *C = dyn_cast<ConstantSDNode>(Val)) {
+    // A constant evaluates to itself whatever the condition code is.
+    int ConstVal = C->getZExtValue();
+    return {ConstVal, ConstVal, ConstVal, ConstVal};
+  }
+  auto *CCNode = CC.getNode();
+  if (!CCNode)
+    return {};
+  // The CC source itself evaluates to the condition code.
+  if (N == CCNode)
+    return {0, 1, 2, 3};
+
+  switch (N->getOpcode()) {
+  case SystemZISD::IPM: {
+    SDValue IPMOp0 = N->getOperand(0);
+    SmallVector<int, 4> CCVals = simplifyAssumingCCVal(IPMOp0, CC);
+    // Shift each value left by IPM_CC; an empty result stays empty.
+    for (int &V : CCVals)
+      V <<= SystemZ::IPM_CC;
+    return CCVals;
+  }
+  case ISD::SRL: {
+    auto *SRLCount = dyn_cast<ConstantSDNode>(N->getOperand(1));
+    if (!SRLCount)
+      return {};
+    const uint64_t SRLCountVal = SRLCount->getZExtValue();
+    // Shifting an int by its full width or more is undefined; give up.
+    if (SRLCountVal >= sizeof(int) * 8)
+      return {};
+    SDValue Op0 = N->getOperand(0);
+    SmallVector<int, 4> CCVals = simplifyAssumingCCVal(Op0, CC);
+    for (int &V : CCVals)
+      V >>= SRLCountVal;
+    return CCVals;
+  }
+  case SystemZISD::SELECT_CCMASK: {
+    SDValue TrueVal = N->getOperand(0), FalseVal = N->getOperand(1);
+    auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2));
+    auto *CCMask = dyn_cast<ConstantSDNode>(N->getOperand(3));
+    if (!TrueVal.getNode() || !FalseVal.getNode() || !CCValid || !CCMask)
+      return {};
+
+    int CCValidVal = CCValid->getZExtValue();
+    int CCMaskVal = CCMask->getZExtValue();
+    // Evaluate each arm against its own CC source when it has one, and
+    // against CC otherwise.
+    auto [Op0CC, Op0CCValid] = findCCUse(TrueVal);
+    SmallVector<int, 4> TrueVals =
+        simplifyAssumingCCVal(TrueVal, Op0CC == SDValue() ? CC : Op0CC);
+    auto [Op1CC, Op1CCValid] = findCCUse(FalseVal);
+    SmallVector<int, 4> FalseVals =
+        simplifyAssumingCCVal(FalseVal, Op1CC == SDValue() ? CC : Op1CC);
+    SDValue SelectCCReg = N->getOperand(4);
+    auto [Op4CC, Op4CCValid] = findCCUse(SelectCCReg);
+    auto *Op4CCNode = Op4CC.getNode();
+    // When the select tests something derived from CC rather than CC
+    // itself, let combineCCMask update the local masks.
+    if (Op4CCNode && Op4CCNode == CCNode && Op4CCNode != SelectCCReg.getNode())
+      combineCCMask(SelectCCReg, CCValidVal, CCMaskVal);
+
+    auto CCMaskValsSet = convertCCMaskToCCValsSet(CCMaskVal);
+    auto CCValidValsSet = convertCCMaskToCCValsSet(CCValidVal);
+    SmallVector<int, 4> CCVals;
+    for (int CCVal = 0; CCVal < 4; ++CCVal) {
+      // A CC value takes the true arm only if it is in both mask sets.
+      const bool UseTrue =
+          CCMaskValsSet.count(CCVal) && CCValidValsSet.count(CCVal);
+      const SmallVector<int, 4> &Chosen = UseTrue ? TrueVals : FalseVals;
+      // The arm actually needed could not be evaluated: give up instead of
+      // indexing an empty vector.
+      if (Chosen.empty())
+        return {};
+      CCVals.push_back(Chosen[CCVal]);
+    }
+    return CCVals;
+  }
+  case ISD::ADD:
+  case ISD::AND:
+  case ISD::OR:
+  case ISD::XOR: {
+    const unsigned Opcode = N->getOpcode();
+    SDValue Op0 = N->getOperand(0), Op1 = N->getOperand(1);
+    SmallVector<int, 4> CC0Vals = simplifyAssumingCCVal(Op0, CC);
+    SmallVector<int, 4> CC1Vals = simplifyAssumingCCVal(Op1, CC);
+    // Either operand may be impossible to evaluate; never index an empty
+    // result.
+    if (CC0Vals.empty() || CC1Vals.empty())
+      return {};
+    const auto Apply = [Opcode](int LHS, int RHS) {
+      switch (Opcode) {
+      case ISD::ADD:
+        return LHS + RHS;
+      case ISD::AND:
+        return LHS & RHS;
+      case ISD::OR:
+        return LHS | RHS;
+      default: // ISD::XOR, the only remaining opcode that reaches here.
+        return LHS ^ RHS;
+      }
+    };
+    SmallVector<int, 4> CCVals;
+    for (int CCVal = 0; CCVal < 4; ++CCVal)
+      CCVals.push_back(Apply(CC0Vals[CCVal], CC1Vals[CCVal]));
+    return CCVals;
+  }
+  default:
+    break;
+  }
+  // Any other node kind cannot be evaluated.
+  return {};
+}
+
 static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask) {
   // We have a SELECT_CCMASK or BR_CCMASK comparing the condition code
   // set by the CCReg instruction using the CCValid / CCMask masks,
-  // If the CCReg instruction is itself a ICMP testing the condition
+  // If the CCReg instruction is itself an ICMP / TM testing the condition
   // code set by some other instruction, see whether we can directly
   // use that condition code.
-
-  // Verify that we have an ICMP against some constant.
-  if (CCValid != SystemZ::CCMASK_ICMP)
-    return false;
-  auto *ICmp = CCReg.getNode();
-  if (ICmp->getOpcode() != SystemZISD::ICMP)
-    return false;
-  auto *CompareLHS = ICmp->getOperand(0).getNode();
-  auto *CompareRHS = dyn_cast<ConstantSDNode>(ICmp->getOperand(1));
-  if (!CompareRHS)
+  auto *CCNode = CCReg.getNode();
+  if (!CCNode)
     return false;
-
-  // Optimize the case where CompareLHS is a SELECT_CCMASK.
-  if (CompareLHS->getOpcode() == SystemZISD::SELECT_CCMASK) {
-    // Verify that we have an appropriate mask for a EQ or NE comparison.
-    bool Invert = false;
-    if (CCMask == SystemZ::CCMASK_CMP_NE)
-      Invert = !Invert;
-    else if (CCMask != SystemZ::CCMASK_CMP_EQ)
+  if (CCNode->getOpcode() == SystemZISD::TM) {
+    if (CCValid != SystemZ::CCMASK_TM)
       return false;
-
-    // Verify that the ICMP compares against one of select values.
-    auto *TrueVal = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(0));
-    if (!TrueVal)
+    if ((CCMask != SystemZ::CCMASK_TM_SOME_1) &&
+        (CCMask != SystemZ::CCMASK_TM_ALL_0))
       return false;
-    auto *FalseVal = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(1));
-    if (!FalseVal)
+    SDValue Op0 = CCNode->getOperand(0);
+    SDValue Op1 = CCNode->getOperand(1);
+    auto [Op0CC, Op0CCValid] = findCCUse(Op0);
+    if (Op0CC == SDValue())
       return false;
-    if (CompareRHS->getAPIntValue() == FalseVal->getAPIntValue())
-      Invert = !Invert;
-    else if (CompareRHS->getAPIntValue() != TrueVal->getAPIntValue())
+    int TMConstVal;
+    auto *N = dyn_cast<ConstantSDNode>(Op1);
+    if (N) {
+      TMConstVal = N->getZExtValue();
+      if ((TMConstVal != (1 << SystemZ::IPM_CC)) && (TMConstVal != 1))
+        return false;
----------------
uweigand wrote:

This should likewise not be necessary if we fully implement TM semantics.

https://github.com/llvm/llvm-project/pull/125970
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to