https://github.com/wizardengineer created https://github.com/llvm/llvm-project/pull/166706

This patch implements architecture-specific lowering for ct.select on AArch64,
using CSEL (conditional select) instructions for constant-time selection.

Implementation details:
- Uses the CSEL family of instructions for scalar integer types (see the
  example after this list)
- Uses FCSEL for floating-point types (F16, BF16, F32, F64)
- Post-RA lowering converts the CTSELECT pseudo-instructions into real
  CSEL/FCSEL instructions
- Expands vector types to per-element CSEL/FCSEL selections
- Adds AArch64 codegen tests covering scalar integer, i128, floating-point,
  and vector cases
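
For example, the ct_i32 case from the test added in this patch compiles to a
compare of the condition followed by a single branch-free CSEL:

  define i32 @ct_i32(i1 %cond, i32 %a, i32 %b) {
    %1 = call i32 @llvm.ct.select.i32(i1 %cond, i32 %a, i32 %b)
    ret i32 %1
  }

  ; AArch64 output (from the CHECK lines in the test):
  ;   tst w0, #0x1
  ;   csel w0, w1, w2, ne
  ;   ret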

The implementation includes:
- ISelLowering: custom lowering of ct.select to AArch64ISD::CTSELECT nodes,
  which are selected to CTSELECT pseudo-instructions; f16/bf16 are promoted
  when +fullfp16 is unavailable (see the example after this list)
- InstrInfo: pseudo-instruction definitions, selection patterns, and post-RA
  expansion to CSEL/FCSEL
- MCInstLower: lowering of any remaining pseudo-instructions to the actual
  CSEL/FCSEL opcodes
- The condition is materialized as a compare against zero, so the selection is
  a single conditional-select instruction with no data-dependent branches
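
To illustrate the promotion path, the ct_f16 case from the new test shows both
configurations: without +fullfp16 the half values are promoted to f32 around
the FCSEL, while with +fullfp16 a half-precision FCSEL is used directly:

  define half @ct_f16(i1 %cond, half %a, half %b) {
    %1 = call half @llvm.ct.select.f16(i1 %cond, half %a, half %b)
    ret half %1
  }

  ; without +fullfp16 (NOFP16 check lines):
  ;   fcvt s1, h1
  ;   fcvt s0, h0
  ;   tst w0, #0x1
  ;   fcsel s0, s0, s1, ne
  ;   fcvt h0, s0
  ;   ret
  ;
  ; with +fullfp16 (FP16 check lines):
  ;   tst w0, #0x1
  ;   fcsel h0, h0, h1, ne
  ;   ret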

From 071428b7a6eed7a800364cc4b9a7e25e1d8e310e Mon Sep 17 00:00:00 2001
From: wizardengineer <[email protected]>
Date: Wed, 5 Nov 2025 17:09:45 -0500
Subject: [PATCH] [LLVM][AArch64] Add native ct.select support for ARM64

This patch implements architecture-specific lowering for ct.select on AArch64
using CSEL (conditional select) instructions for constant-time selection.

Implementation details:
- Uses CSEL family of instructions for scalar integer types
- Uses FCSEL for floating-point types (F16, BF16, F32, F64)
- Post-RA MC lowering to convert pseudo-instructions to real CSEL/FCSEL
- Handles vector types appropriately
- Comprehensive test coverage for AArch64

The implementation includes:
- ISelLowering: Custom lowering to CTSELECT pseudo-instructions
- InstrInfo: Pseudo-instruction definitions and patterns
- MCInstLower: Post-RA lowering of pseudo-instructions to actual CSEL/FCSEL
- Proper handling of condition codes for constant-time guarantees
---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  53 +++++
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  12 ++
 llvm/lib/Target/AArch64/AArch64InstrInfo.cpp  | 200 ++++++++----------
 llvm/lib/Target/AArch64/AArch64InstrInfo.td   |  45 ++++
 .../lib/Target/AArch64/AArch64MCInstLower.cpp |  18 ++
 llvm/test/CodeGen/AArch64/ctselect.ll         | 153 ++++++++++++++
 6 files changed, 371 insertions(+), 110 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/ctselect.ll

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 60aa61e993b26..a86aac88b94a8 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -511,12 +511,35 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::BR_CC, MVT::f64, Custom);
   setOperationAction(ISD::SELECT, MVT::i32, Custom);
   setOperationAction(ISD::SELECT, MVT::i64, Custom);
+  setOperationAction(ISD::CTSELECT, MVT::i8,  Promote);
+  setOperationAction(ISD::CTSELECT, MVT::i16, Promote);
+  setOperationAction(ISD::CTSELECT, MVT::i32, Custom);
+  setOperationAction(ISD::CTSELECT, MVT::i64, Custom);
   if (Subtarget->hasFPARMv8()) {
     setOperationAction(ISD::SELECT, MVT::f16, Custom);
     setOperationAction(ISD::SELECT, MVT::bf16, Custom);
   }
+  if (Subtarget->hasFullFP16()) {
+    setOperationAction(ISD::CTSELECT, MVT::f16, Custom);
+    setOperationAction(ISD::CTSELECT, MVT::bf16, Custom);
+  } else {
+    setOperationAction(ISD::CTSELECT, MVT::f16, Promote);
+    setOperationAction(ISD::CTSELECT, MVT::bf16, Promote);
+  }
   setOperationAction(ISD::SELECT, MVT::f32, Custom);
   setOperationAction(ISD::SELECT, MVT::f64, Custom);
+  setOperationAction(ISD::CTSELECT, MVT::f32, Custom);
+  setOperationAction(ISD::CTSELECT, MVT::f64, Custom);
+  for (MVT VT : MVT::vector_valuetypes()) {
+    MVT elemType = VT.getVectorElementType();
+    if (elemType == MVT::i8 || elemType == MVT::i16) {
+      setOperationAction(ISD::CTSELECT, VT, Promote);
+    } else if ((elemType == MVT::f16 || elemType == MVT::bf16) && !Subtarget->hasFullFP16()) {
+      setOperationAction(ISD::CTSELECT, VT, Promote);
+    } else {
+      setOperationAction(ISD::CTSELECT, VT, Expand);
+    }
+  }
   setOperationAction(ISD::SELECT_CC, MVT::i32, Custom);
   setOperationAction(ISD::SELECT_CC, MVT::i64, Custom);
   setOperationAction(ISD::SELECT_CC, MVT::f16, Custom);
@@ -3328,6 +3351,18 @@ void AArch64TargetLowering::fixupPtrauthDiscriminator(
   IntDiscOp.setImm(IntDisc);
 }
 
+MachineBasicBlock *AArch64TargetLowering::EmitCTSELECT(MachineInstr &MI, MachineBasicBlock *MBB, unsigned Opcode) const {
+  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+  DebugLoc DL = MI.getDebugLoc();
+  MachineInstrBuilder Builder = BuildMI(*MBB, MI, DL, TII->get(Opcode));
+  for (unsigned Idx = 0; Idx < MI.getNumOperands(); ++Idx) {
+    Builder.add(MI.getOperand(Idx));
+  }
+  Builder->setFlag(MachineInstr::NoMerge);
+  MBB->remove_instr(&MI);
+  return MBB;
+}
+
 MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     MachineInstr &MI, MachineBasicBlock *BB) const {
 
@@ -7590,6 +7625,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
     return LowerSELECT(Op, DAG);
   case ISD::SELECT_CC:
     return LowerSELECT_CC(Op, DAG);
+  case ISD::CTSELECT:
+    return LowerCTSELECT(Op, DAG);
   case ISD::JumpTable:
     return LowerJumpTable(Op, DAG);
   case ISD::BR_JT:
@@ -12149,6 +12186,22 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
   return Res;
 }
 
+SDValue AArch64TargetLowering::LowerCTSELECT(SDValue Op,
+                                             SelectionDAG &DAG) const {
+  SDValue CCVal = Op->getOperand(0);
+  SDValue TVal = Op->getOperand(1);
+  SDValue FVal = Op->getOperand(2);
+  SDLoc DL(Op);
+
+  EVT VT = Op.getValueType();
+
+  SDValue Zero = DAG.getConstant(0, DL, CCVal.getValueType());
+  SDValue CC;
+  SDValue Cmp = getAArch64Cmp(CCVal, Zero, ISD::SETNE, CC, DAG, DL);
+
+  return DAG.getNode(AArch64ISD::CTSELECT, DL, VT, TVal, FVal, CC, Cmp);
+}
+
 SDValue AArch64TargetLowering::LowerJumpTable(SDValue Op,
                                               SelectionDAG &DAG) const {
   // Jump table entries as PC relative offsets. No additional tweaking
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 2cb8ed29f252a..d14d64ffe88b6 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -23,6 +23,11 @@
 
 namespace llvm {
 
+namespace AArch64ISD {
+// Forward declare the enum from the generated file
+enum GenNodeType : unsigned;
+} // namespace AArch64ISD
+
 class AArch64TargetMachine;
 
 namespace AArch64 {
@@ -202,6 +207,8 @@ class AArch64TargetLowering : public TargetLowering {
                                  MachineOperand &AddrDiscOp,
                                  const TargetRegisterClass *AddrDiscRC) const;
 
+  MachineBasicBlock *EmitCTSELECT(MachineInstr &MI, MachineBasicBlock *BB, unsigned Opcode) const;
+
   MachineBasicBlock *
   EmitInstrWithCustomInserter(MachineInstr &MI,
                               MachineBasicBlock *MBB) const override;
@@ -684,6 +691,7 @@ class AArch64TargetLowering : public TargetLowering {
                          iterator_range<SDNode::user_iterator> Users,
                          SDNodeFlags Flags, const SDLoc &dl,
                          SelectionDAG &DAG) const;
+  SDValue LowerCTSELECT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerJumpTable(SDValue Op, SelectionDAG &DAG) const;
@@ -919,6 +927,10 @@ class AArch64TargetLowering : public TargetLowering {
   bool hasMultipleConditionRegisters(EVT VT) const override {
     return VT.isScalableVector();
   }
+
+  bool isSelectSupported(SelectSupportKind Kind) const override {
+    return true;
+  }
 };
 
 namespace AArch64 {
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index ccc8eb8a9706d..227e5d59610d6 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -700,7 +700,7 @@ static unsigned removeCopies(const MachineRegisterInfo &MRI, unsigned VReg) {
 // csel instruction. If so, return the folded opcode, and the replacement
 // register.
 static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg,
-                                unsigned *NewReg = nullptr) {
+                                unsigned *NewVReg = nullptr) {
   VReg = removeCopies(MRI, VReg);
   if (!Register::isVirtualRegister(VReg))
     return 0;
@@ -708,37 +708,8 @@ static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg,
   bool Is64Bit = AArch64::GPR64allRegClass.hasSubClassEq(MRI.getRegClass(VReg));
   const MachineInstr *DefMI = MRI.getVRegDef(VReg);
   unsigned Opc = 0;
-  unsigned SrcReg = 0;
+  unsigned SrcOpNum = 0;
   switch (DefMI->getOpcode()) {
-  case AArch64::SUBREG_TO_REG:
-    // Check for the following way to define an 64-bit immediate:
-    //   %0:gpr32 = MOVi32imm 1
-    //   %1:gpr64 = SUBREG_TO_REG 0, %0:gpr32, %subreg.sub_32
-    if (!DefMI->getOperand(1).isImm() || DefMI->getOperand(1).getImm() != 0)
-      return 0;
-    if (!DefMI->getOperand(2).isReg())
-      return 0;
-    if (!DefMI->getOperand(3).isImm() ||
-        DefMI->getOperand(3).getImm() != AArch64::sub_32)
-      return 0;
-    DefMI = MRI.getVRegDef(DefMI->getOperand(2).getReg());
-    if (DefMI->getOpcode() != AArch64::MOVi32imm)
-      return 0;
-    if (!DefMI->getOperand(1).isImm() || DefMI->getOperand(1).getImm() != 1)
-      return 0;
-    assert(Is64Bit);
-    SrcReg = AArch64::XZR;
-    Opc = AArch64::CSINCXr;
-    break;
-
-  case AArch64::MOVi32imm:
-  case AArch64::MOVi64imm:
-    if (!DefMI->getOperand(1).isImm() || DefMI->getOperand(1).getImm() != 1)
-      return 0;
-    SrcReg = Is64Bit ? AArch64::XZR : AArch64::WZR;
-    Opc = Is64Bit ? AArch64::CSINCXr : AArch64::CSINCWr;
-    break;
-
   case AArch64::ADDSXri:
   case AArch64::ADDSWri:
     // if NZCV is used, do not fold.
@@ -753,7 +724,7 @@ static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg,
     if (!DefMI->getOperand(2).isImm() || DefMI->getOperand(2).getImm() != 1 ||
         DefMI->getOperand(3).getImm() != 0)
       return 0;
-    SrcReg = DefMI->getOperand(1).getReg();
+    SrcOpNum = 1;
     Opc = Is64Bit ? AArch64::CSINCXr : AArch64::CSINCWr;
     break;
 
@@ -763,7 +734,7 @@ static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg,
     unsigned ZReg = removeCopies(MRI, DefMI->getOperand(1).getReg());
     if (ZReg != AArch64::XZR && ZReg != AArch64::WZR)
       return 0;
-    SrcReg = DefMI->getOperand(2).getReg();
+    SrcOpNum = 2;
     Opc = Is64Bit ? AArch64::CSINVXr : AArch64::CSINVWr;
     break;
   }
@@ -782,17 +753,17 @@ static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg,
     unsigned ZReg = removeCopies(MRI, DefMI->getOperand(1).getReg());
     if (ZReg != AArch64::XZR && ZReg != AArch64::WZR)
       return 0;
-    SrcReg = DefMI->getOperand(2).getReg();
+    SrcOpNum = 2;
     Opc = Is64Bit ? AArch64::CSNEGXr : AArch64::CSNEGWr;
     break;
   }
   default:
     return 0;
   }
-  assert(Opc && SrcReg && "Missing parameters");
+  assert(Opc && SrcOpNum && "Missing parameters");
 
-  if (NewReg)
-    *NewReg = SrcReg;
+  if (NewVReg)
+    *NewVReg = DefMI->getOperand(SrcOpNum).getReg();
   return Opc;
 }
 
@@ -993,34 +964,28 @@ void AArch64InstrInfo::insertSelect(MachineBasicBlock &MBB,
 
   // Try folding simple instructions into the csel.
   if (TryFold) {
-    unsigned NewReg = 0;
-    unsigned FoldedOpc = canFoldIntoCSel(MRI, TrueReg, &NewReg);
+    unsigned NewVReg = 0;
+    unsigned FoldedOpc = canFoldIntoCSel(MRI, TrueReg, &NewVReg);
     if (FoldedOpc) {
       // The folded opcodes csinc, csinc and csneg apply the operation to
       // FalseReg, so we need to invert the condition.
       CC = AArch64CC::getInvertedCondCode(CC);
       TrueReg = FalseReg;
     } else
-      FoldedOpc = canFoldIntoCSel(MRI, FalseReg, &NewReg);
+      FoldedOpc = canFoldIntoCSel(MRI, FalseReg, &NewVReg);
 
     // Fold the operation. Leave any dead instructions for DCE to clean up.
     if (FoldedOpc) {
-      FalseReg = NewReg;
+      FalseReg = NewVReg;
       Opc = FoldedOpc;
-      // Extend the live range of NewReg.
-      MRI.clearKillFlags(NewReg);
+      // This extends the live range of NewVReg.
+      MRI.clearKillFlags(NewVReg);
     }
   }
 
   // Pull all virtual register into the appropriate class.
   MRI.constrainRegClass(TrueReg, RC);
-  // FalseReg might be WZR or XZR if the folded operand is a literal 1.
-  assert(
-      (FalseReg.isVirtual() || FalseReg == AArch64::WZR ||
-       FalseReg == AArch64::XZR) &&
-      "FalseReg was folded into a non-virtual register other than WZR or XZR");
-  if (FalseReg.isVirtual())
-    MRI.constrainRegClass(FalseReg, RC);
+  MRI.constrainRegClass(FalseReg, RC);
 
   // Insert the csel.
   BuildMI(MBB, I, DL, get(Opc), DstReg)
@@ -2148,16 +2113,46 @@ bool AArch64InstrInfo::removeCmpToZeroOrOne(
   return true;
 }
 
-bool AArch64InstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
-  if (MI.getOpcode() != TargetOpcode::LOAD_STACK_GUARD &&
-      MI.getOpcode() != AArch64::CATCHRET)
-    return false;
+static inline void expandCtSelect(MachineBasicBlock &MBB, MachineInstr &MI, DebugLoc &DL, const MCInstrDesc &MCID) {
+  MachineInstrBuilder Builder = BuildMI(MBB, MI, DL, MCID);
+  for (unsigned Idx = 0; Idx < MI.getNumOperands(); ++Idx) {
+    Builder.add(MI.getOperand(Idx));
+  }
+  Builder->setFlag(MachineInstr::NoMerge);
+  MBB.remove_instr(&MI);
+}
 
+bool AArch64InstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
   MachineBasicBlock &MBB = *MI.getParent();
   auto &Subtarget = MBB.getParent()->getSubtarget<AArch64Subtarget>();
   auto TRI = Subtarget.getRegisterInfo();
   DebugLoc DL = MI.getDebugLoc();
 
+  switch (MI.getOpcode()) {
+    case AArch64::I32CTSELECT:
+      expandCtSelect(MBB, MI, DL, get(AArch64::CSELWr));
+      return true;
+    case AArch64::I64CTSELECT:
+      expandCtSelect(MBB, MI, DL, get(AArch64::CSELXr));
+      return true;
+    case AArch64::BF16CTSELECT:
+      expandCtSelect(MBB, MI, DL, get(AArch64::FCSELHrrr));
+      return true;
+    case AArch64::F16CTSELECT:
+      expandCtSelect(MBB, MI, DL, get(AArch64::FCSELHrrr));
+      return true;
+    case AArch64::F32CTSELECT:
+      expandCtSelect(MBB, MI, DL, get(AArch64::FCSELSrrr));
+      return true;
+    case AArch64::F64CTSELECT:
+      expandCtSelect(MBB, MI, DL, get(AArch64::FCSELDrrr));
+      return true;
+  }
+
+  if (MI.getOpcode() != TargetOpcode::LOAD_STACK_GUARD &&
+      MI.getOpcode() != AArch64::CATCHRET)
+    return false;
+
   if (MI.getOpcode() == AArch64::CATCHRET) {
     // Skip to the first instruction before the epilog.
     const TargetInstrInfo *TII =
@@ -5098,7 +5093,7 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
                                    bool RenamableDest,
                                    bool RenamableSrc) const {
   if (AArch64::GPR32spRegClass.contains(DestReg) &&
-      AArch64::GPR32spRegClass.contains(SrcReg)) {
+      (AArch64::GPR32spRegClass.contains(SrcReg) || SrcReg == AArch64::WZR)) {
     if (DestReg == AArch64::WSP || SrcReg == AArch64::WSP) {
       // If either operand is WSP, expand to ADD #0.
       if (Subtarget.hasZeroCycleRegMoveGPR64() &&
@@ -5123,14 +5118,31 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
             .addImm(0)
             .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
       }
+    } else if (SrcReg == AArch64::WZR &&
+               Subtarget.hasZeroCycleZeroingGPR64() &&
+               !Subtarget.hasZeroCycleZeroingGPR32()) {
+      // Use 64-bit zeroing when available but 32-bit zeroing is not
+      MCRegister DestRegX = RI.getMatchingSuperReg(DestReg, AArch64::sub_32,
+                                                   &AArch64::GPR64spRegClass);
+      assert(DestRegX.isValid() && "Destination super-reg not valid");
+      BuildMI(MBB, I, DL, get(AArch64::MOVZXi), DestRegX)
+          .addImm(0)
+          .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
+    } else if (SrcReg == AArch64::WZR && Subtarget.hasZeroCycleZeroingGPR32()) {
+      BuildMI(MBB, I, DL, get(AArch64::MOVZWi), DestReg)
+          .addImm(0)
+          .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
     } else if (Subtarget.hasZeroCycleRegMoveGPR64() &&
                !Subtarget.hasZeroCycleRegMoveGPR32()) {
       // Cyclone recognizes "ORR Xd, XZR, Xm" as a zero-cycle register move.
       MCRegister DestRegX = RI.getMatchingSuperReg(DestReg, AArch64::sub_32,
                                                    &AArch64::GPR64spRegClass);
       assert(DestRegX.isValid() && "Destination super-reg not valid");
-      MCRegister SrcRegX = RI.getMatchingSuperReg(SrcReg, AArch64::sub_32,
-                                                  &AArch64::GPR64spRegClass);
+      MCRegister SrcRegX =
+          SrcReg == AArch64::WZR
+              ? AArch64::XZR
+              : RI.getMatchingSuperReg(SrcReg, AArch64::sub_32,
+                                       &AArch64::GPR64spRegClass);
       assert(SrcRegX.isValid() && "Source super-reg not valid");
       // This instruction is reading and writing X registers.  This may upset
       // the register scavenger and machine verifier, so we need to indicate
@@ -5149,59 +5161,6 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
     return;
   }
 
-  // GPR32 zeroing
-  if (AArch64::GPR32spRegClass.contains(DestReg) && SrcReg == AArch64::WZR) {
-    if (Subtarget.hasZeroCycleZeroingGPR64() &&
-        !Subtarget.hasZeroCycleZeroingGPR32()) {
-      MCRegister DestRegX = RI.getMatchingSuperReg(DestReg, AArch64::sub_32,
-                                                   &AArch64::GPR64spRegClass);
-      assert(DestRegX.isValid() && "Destination super-reg not valid");
-      BuildMI(MBB, I, DL, get(AArch64::MOVZXi), DestRegX)
-          .addImm(0)
-          .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
-    } else if (Subtarget.hasZeroCycleZeroingGPR32()) {
-      BuildMI(MBB, I, DL, get(AArch64::MOVZWi), DestReg)
-          .addImm(0)
-          .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
-    } else {
-      BuildMI(MBB, I, DL, get(AArch64::ORRWrr), DestReg)
-          .addReg(AArch64::WZR)
-          .addReg(AArch64::WZR);
-    }
-    return;
-  }
-
-  if (AArch64::GPR64spRegClass.contains(DestReg) &&
-      AArch64::GPR64spRegClass.contains(SrcReg)) {
-    if (DestReg == AArch64::SP || SrcReg == AArch64::SP) {
-      // If either operand is SP, expand to ADD #0.
-      BuildMI(MBB, I, DL, get(AArch64::ADDXri), DestReg)
-          .addReg(SrcReg, getKillRegState(KillSrc))
-          .addImm(0)
-          .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
-    } else {
-      // Otherwise, expand to ORR XZR.
-      BuildMI(MBB, I, DL, get(AArch64::ORRXrr), DestReg)
-          .addReg(AArch64::XZR)
-          .addReg(SrcReg, getKillRegState(KillSrc));
-    }
-    return;
-  }
-
-  // GPR64 zeroing
-  if (AArch64::GPR64spRegClass.contains(DestReg) && SrcReg == AArch64::XZR) {
-    if (Subtarget.hasZeroCycleZeroingGPR64()) {
-      BuildMI(MBB, I, DL, get(AArch64::MOVZXi), DestReg)
-          .addImm(0)
-          .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
-    } else {
-      BuildMI(MBB, I, DL, get(AArch64::ORRXrr), DestReg)
-          .addReg(AArch64::XZR)
-          .addReg(AArch64::XZR);
-    }
-    return;
-  }
-
   // Copy a Predicate register by ORRing with itself.
   if (AArch64::PPRRegClass.contains(DestReg) &&
       AArch64::PPRRegClass.contains(SrcReg)) {
@@ -5286,6 +5245,27 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
     return;
   }
 
+  if (AArch64::GPR64spRegClass.contains(DestReg) &&
+      (AArch64::GPR64spRegClass.contains(SrcReg) || SrcReg == AArch64::XZR)) {
+    if (DestReg == AArch64::SP || SrcReg == AArch64::SP) {
+      // If either operand is SP, expand to ADD #0.
+      BuildMI(MBB, I, DL, get(AArch64::ADDXri), DestReg)
+          .addReg(SrcReg, getKillRegState(KillSrc))
+          .addImm(0)
+          .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
+    } else if (SrcReg == AArch64::XZR && Subtarget.hasZeroCycleZeroingGPR64()) {
+      BuildMI(MBB, I, DL, get(AArch64::MOVZXi), DestReg)
+          .addImm(0)
+          .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
+    } else {
+      // Otherwise, expand to ORR XZR.
+      BuildMI(MBB, I, DL, get(AArch64::ORRXrr), DestReg)
+          .addReg(AArch64::XZR)
+          .addReg(SrcReg, getKillRegState(KillSrc));
+    }
+    return;
+  }
+
   // Copy a DDDD register quad by copying the individual sub-registers.
   if (AArch64::DDDDRegClass.contains(DestReg) &&
       AArch64::DDDDRegClass.contains(SrcReg)) {
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 2871a20e28b65..3a8fdc329f129 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -476,6 +476,11 @@ def SDT_AArch64cbz : SDTypeProfile<0, 2, [SDTCisInt<0>, SDTCisVT<1, OtherVT>]>;
 def SDT_AArch64tbz : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>,
                                         SDTCisVT<2, OtherVT>]>;
 
+def SDT_AArch64CtSelect : SDTypeProfile<1, 4,
+                                   [SDTCisSameAs<0, 1>,
+                                    SDTCisSameAs<0, 2>,
+                                    SDTCisInt<3>,
+                                    SDTCisVT<4, i32>]>;
 def SDT_AArch64CSel  : SDTypeProfile<1, 4,
                                    [SDTCisSameAs<0, 1>,
                                     SDTCisSameAs<0, 2>,
@@ -843,6 +848,7 @@ def AArch64tbz           : SDNode<"AArch64ISD::TBZ", SDT_AArch64tbz,
 def AArch64tbnz           : SDNode<"AArch64ISD::TBNZ", SDT_AArch64tbz,
                                 [SDNPHasChain]>;
 
+def AArch64ctselect      : SDNode<"AArch64ISD::CTSELECT", SDT_AArch64CtSelect>;
 
 def AArch64csel          : SDNode<"AArch64ISD::CSEL", SDT_AArch64CSel>;
 // Conditional select invert.
@@ -5644,6 +5650,45 @@ def F128CSEL : Pseudo<(outs FPR128:$Rd),
   let hasNoSchedulingInfo = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Constant-time conditional selection instructions
+//===----------------------------------------------------------------------===//
+
+let hasSideEffects = 1, isPseudo = 1, hasNoSchedulingInfo = 1, Uses = [NZCV] in {
+  def I32CTSELECT : Pseudo<(outs GPR32:$dst),
+                           (ins GPR32:$tval, GPR32:$fval, i32imm:$cc),
+                           [(set (i32 GPR32:$dst),
+                                 (AArch64ctselect GPR32:$tval, GPR32:$fval,
+                                                (i32 imm:$cc), NZCV))]>;
+  def I64CTSELECT : Pseudo<(outs GPR64:$dst),
+                           (ins GPR64:$tval, GPR64:$fval, i32imm:$cc),
+                           [(set (i64 GPR64:$dst),
+                                 (AArch64ctselect GPR64:$tval, GPR64:$fval,
+                                                (i32 imm:$cc), NZCV))]>;
+  let Predicates = [HasFullFP16] in {
+    def F16CTSELECT : Pseudo<(outs FPR16:$dst),
+                            (ins FPR16:$tval, FPR16:$fval, i32imm:$cc),
+                            [(set (f16 FPR16:$dst),
+                                  (AArch64ctselect (f16 FPR16:$tval), (f16 FPR16:$fval),
+                                                  (i32 imm:$cc), NZCV))]>;
+    def BF16CTSELECT : Pseudo<(outs FPR16:$dst),
+                            (ins FPR16:$tval, FPR16:$fval, i32imm:$cc),
+                            [(set (bf16 FPR16:$dst),
+                                  (AArch64ctselect (bf16 FPR16:$tval), (bf16 FPR16:$fval),
+                                                  (i32 imm:$cc), NZCV))]>;
+  }
+  def F32CTSELECT : Pseudo<(outs FPR32:$dst),
+                           (ins FPR32:$tval, FPR32:$fval, i32imm:$cc),
+                           [(set (f32 FPR32:$dst),
+                                 (AArch64ctselect FPR32:$tval, FPR32:$fval,
+                                                (i32 imm:$cc), NZCV))]>;
+  def F64CTSELECT : Pseudo<(outs FPR64:$dst),
+                           (ins FPR64:$tval, FPR64:$fval, i32imm:$cc),
+                           [(set (f64 FPR64:$dst),
+                                 (AArch64ctselect FPR64:$tval, FPR64:$fval,
+                                                (i32 imm:$cc), NZCV))]>;
+}
+
 //===----------------------------------------------------------------------===//
 // Instructions used for emitting unwind opcodes on ARM64 Windows.
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/AArch64/AArch64MCInstLower.cpp b/llvm/lib/Target/AArch64/AArch64MCInstLower.cpp
index 39946633603f6..e2ec9118eb5ee 100644
--- a/llvm/lib/Target/AArch64/AArch64MCInstLower.cpp
+++ b/llvm/lib/Target/AArch64/AArch64MCInstLower.cpp
@@ -393,5 +393,23 @@ void AArch64MCInstLower::Lower(const MachineInstr *MI, MCInst &OutMI) const {
     OutMI.setOpcode(AArch64::RET);
     OutMI.addOperand(MCOperand::createReg(AArch64::LR));
     break;
+  case AArch64::I32CTSELECT:
+    OutMI.setOpcode(AArch64::CSELWr);
+    break;
+  case AArch64::I64CTSELECT:
+    OutMI.setOpcode(AArch64::CSELXr);
+    break;
+  case AArch64::BF16CTSELECT:
+    OutMI.setOpcode(AArch64::FCSELHrrr);
+    break;
+  case AArch64::F16CTSELECT:
+    OutMI.setOpcode(AArch64::FCSELHrrr);
+    break;
+  case AArch64::F32CTSELECT:
+    OutMI.setOpcode(AArch64::FCSELSrrr);
+    break;
+  case AArch64::F64CTSELECT:
+    OutMI.setOpcode(AArch64::FCSELDrrr);
+    break;
   }
 }
diff --git a/llvm/test/CodeGen/AArch64/ctselect.ll b/llvm/test/CodeGen/AArch64/ctselect.ll
new file mode 100644
index 0000000000000..77e9cf24e56cf
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/ctselect.ll
@@ -0,0 +1,153 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s -verify-machineinstrs -mtriple=aarch64-none-eabi | FileCheck %s --check-prefixes=DEFAULT,NOFP16
+; RUN: llc < %s -verify-machineinstrs -mtriple=aarch64-none-eabi -mattr=+fullfp16 | FileCheck %s --check-prefixes=DEFAULT,FP16
+
+define i1 @ct_i1(i1 %cond, i1 %a, i1 %b) {
+; DEFAULT-LABEL: ct_i1:
+; DEFAULT:       // %bb.0:
+; DEFAULT-NEXT:    tst w0, #0x1
+; DEFAULT-NEXT:    csel w8, w1, w2, ne
+; DEFAULT-NEXT:    and w0, w8, #0x1
+; DEFAULT-NEXT:    ret
+  %1 = call i1 @llvm.ct.select.i1(i1 %cond, i1 %a, i1 %b)
+  ret i1 %1
+}
+
+define i8 @ct_i8(i1 %cond, i8 %a, i8 %b) {
+; DEFAULT-LABEL: ct_i8:
+; DEFAULT:       // %bb.0:
+; DEFAULT-NEXT:    tst w0, #0x1
+; DEFAULT-NEXT:    csel w0, w1, w2, ne
+; DEFAULT-NEXT:    ret
+  %1 = call i8 @llvm.ct.select.i8(i1 %cond, i8 %a, i8 %b)
+  ret i8 %1
+}
+
+define i16 @ct_i16(i1 %cond, i16 %a, i16 %b) {
+; DEFAULT-LABEL: ct_i16:
+; DEFAULT:       // %bb.0:
+; DEFAULT-NEXT:    tst w0, #0x1
+; DEFAULT-NEXT:    csel w0, w1, w2, ne
+; DEFAULT-NEXT:    ret
+  %1 = call i16 @llvm.ct.select.i16(i1 %cond, i16 %a, i16 %b)
+  ret i16 %1
+}
+
+define i32 @ct_i32(i1 %cond, i32 %a, i32 %b) {
+; DEFAULT-LABEL: ct_i32:
+; DEFAULT:       // %bb.0:
+; DEFAULT-NEXT:    tst w0, #0x1
+; DEFAULT-NEXT:    csel w0, w1, w2, ne
+; DEFAULT-NEXT:    ret
+  %1 = call i32 @llvm.ct.select.i32(i1 %cond, i32 %a, i32 %b)
+  ret i32 %1
+}
+
+define i64 @ct_i64(i1 %cond, i64 %a, i64 %b) {
+; DEFAULT-LABEL: ct_i64:
+; DEFAULT:       // %bb.0:
+; DEFAULT-NEXT:    tst w0, #0x1
+; DEFAULT-NEXT:    csel x0, x1, x2, ne
+; DEFAULT-NEXT:    ret
+  %1 = call i64 @llvm.ct.select.i64(i1 %cond, i64 %a, i64 %b)
+  ret i64 %1
+}
+
+define i128 @ct_i128(i1 %cond, i128 %a, i128 %b) {
+; DEFAULT-LABEL: ct_i128:
+; DEFAULT:       // %bb.0:
+; DEFAULT-NEXT:    tst w0, #0x1
+; DEFAULT-NEXT:    csel x0, x2, x4, ne
+; DEFAULT-NEXT:    csel x1, x3, x5, ne
+; DEFAULT-NEXT:    ret
+  %1 = call i128 @llvm.ct.select.i128(i1 %cond, i128 %a, i128 %b)
+  ret i128 %1
+}
+
+define half @ct_f16(i1 %cond, half %a, half %b) {
+; NOFP16-LABEL: ct_f16:
+; NOFP16:       // %bb.0:
+; NOFP16-NEXT:    fcvt s1, h1
+; NOFP16-NEXT:    fcvt s0, h0
+; NOFP16-NEXT:    tst w0, #0x1
+; NOFP16-NEXT:    fcsel s0, s0, s1, ne
+; NOFP16-NEXT:    fcvt h0, s0
+; NOFP16-NEXT:    ret
+;
+; FP16-LABEL: ct_f16:
+; FP16:       // %bb.0:
+; FP16-NEXT:    tst w0, #0x1
+; FP16-NEXT:    fcsel h0, h0, h1, ne
+; FP16-NEXT:    ret
+  %1 = call half @llvm.ct.select.f16(i1 %cond, half %a, half %b)
+  ret half %1
+}
+
+define float @ct_f32(i1 %cond, float %a, float %b) {
+; DEFAULT-LABEL: ct_f32:
+; DEFAULT:       // %bb.0:
+; DEFAULT-NEXT:    tst w0, #0x1
+; DEFAULT-NEXT:    fcsel s0, s0, s1, ne
+; DEFAULT-NEXT:    ret
+  %1 = call float @llvm.ct.select.f32(i1 %cond, float %a, float %b)
+  ret float %1
+}
+
+define double @ct_f64(i1 %cond, double %a, double %b) {
+; DEFAULT-LABEL: ct_f64:
+; DEFAULT:       // %bb.0:
+; DEFAULT-NEXT:    tst w0, #0x1
+; DEFAULT-NEXT:    fcsel d0, d0, d1, ne
+; DEFAULT-NEXT:    ret
+  %1 = call double @llvm.ct.select.f64(i1 %cond, double %a, double %b)
+  ret double %1
+}
+
+define <4 x i32> @ct_v4i32(i1 %cond, <4 x i32> %a, <4 x i32> %b) {
+; DEFAULT-LABEL: ct_v4i32:
+; DEFAULT:       // %bb.0:
+; DEFAULT-NEXT:    mov w8, v1.s[1]
+; DEFAULT-NEXT:    mov w9, v0.s[1]
+; DEFAULT-NEXT:    tst w0, #0x1
+; DEFAULT-NEXT:    fmov w10, s1
+; DEFAULT-NEXT:    fmov w11, s0
+; DEFAULT-NEXT:    csel w8, w9, w8, ne
+; DEFAULT-NEXT:    csel w9, w11, w10, ne
+; DEFAULT-NEXT:    mov w10, v1.s[2]
+; DEFAULT-NEXT:    fmov s2, w9
+; DEFAULT-NEXT:    mov w11, v0.s[2]
+; DEFAULT-NEXT:    mov w9, v0.s[3]
+; DEFAULT-NEXT:    mov v2.s[1], w8
+; DEFAULT-NEXT:    mov w8, v1.s[3]
+; DEFAULT-NEXT:    csel w10, w11, w10, ne
+; DEFAULT-NEXT:    mov v2.s[2], w10
+; DEFAULT-NEXT:    csel w8, w9, w8, ne
+; DEFAULT-NEXT:    mov v2.s[3], w8
+; DEFAULT-NEXT:    mov v0.16b, v2.16b
+; DEFAULT-NEXT:    ret
+  %1 = call <4 x i32> @llvm.ct.select.v4i32(i1 %cond, <4 x i32> %a, <4 x i32> %b)
+  ret <4 x i32> %1
+}
+
+define <4 x float> @ct_v4f32(i1 %cond, <4 x float> %a, <4 x float> %b) {
+; DEFAULT-LABEL: ct_v4f32:
+; DEFAULT:       // %bb.0:
+; DEFAULT-NEXT:    mov s2, v1.s[1]
+; DEFAULT-NEXT:    mov s3, v0.s[1]
+; DEFAULT-NEXT:    tst w0, #0x1
+; DEFAULT-NEXT:    mov s4, v1.s[2]
+; DEFAULT-NEXT:    mov s5, v0.s[2]
+; DEFAULT-NEXT:    fcsel s3, s3, s2, ne
+; DEFAULT-NEXT:    fcsel s2, s0, s1, ne
+; DEFAULT-NEXT:    mov s1, v1.s[3]
+; DEFAULT-NEXT:    mov s0, v0.s[3]
+; DEFAULT-NEXT:    mov v2.s[1], v3.s[0]
+; DEFAULT-NEXT:    fcsel s3, s5, s4, ne
+; DEFAULT-NEXT:    fcsel s0, s0, s1, ne
+; DEFAULT-NEXT:    mov v2.s[2], v3.s[0]
+; DEFAULT-NEXT:    mov v2.s[3], v0.s[0]
+; DEFAULT-NEXT:    mov v0.16b, v2.16b
+; DEFAULT-NEXT:    ret
+  %1 = call <4 x float> @llvm.ct.select.v4f32(i1 %cond, <4 x float> %a, <4 x float> %b)
+  ret <4 x float> %1
+}
