https://github.com/wizardengineer created https://github.com/llvm/llvm-project/pull/166706
This patch implements architecture-specific lowering for ct.select on AArch64
using CSEL (conditional select) instructions for constant-time selection.

Implementation details:
- Uses CSEL family of instructions for scalar integer types
- Uses FCSEL for floating-point types (F16, BF16, F32, F64)
- Post-RA MC lowering to convert pseudo-instructions to real CSEL/FCSEL
- Handles vector types appropriately
- Comprehensive test coverage for AArch64

The implementation includes:
- ISelLowering: Custom lowering to CTSELECT pseudo-instructions
- InstrInfo: Pseudo-instruction definitions and patterns
- MCInstLower: Post-RA lowering of pseudo-instructions to actual CSEL/FCSEL
- Proper handling of condition codes for constant-time guarantees

From 071428b7a6eed7a800364cc4b9a7e25e1d8e310e Mon Sep 17 00:00:00 2001
From: wizardengineer <[email protected]>
Date: Wed, 5 Nov 2025 17:09:45 -0500
Subject: [PATCH] [LLVM][AArch64] Add native ct.select support for ARM64

This patch implements architecture-specific lowering for ct.select on AArch64
using CSEL (conditional select) instructions for constant-time selection.

Implementation details:
- Uses CSEL family of instructions for scalar integer types
- Uses FCSEL for floating-point types (F16, BF16, F32, F64)
- Post-RA MC lowering to convert pseudo-instructions to real CSEL/FCSEL
- Handles vector types appropriately
- Comprehensive test coverage for AArch64

The implementation includes:
- ISelLowering: Custom lowering to CTSELECT pseudo-instructions
- InstrInfo: Pseudo-instruction definitions and patterns
- MCInstLower: Post-RA lowering of pseudo-instructions to actual CSEL/FCSEL
- Proper handling of condition codes for constant-time guarantees

---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  53 +++++
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  12 ++
 llvm/lib/Target/AArch64/AArch64InstrInfo.cpp  | 200 ++++++++----------
 llvm/lib/Target/AArch64/AArch64InstrInfo.td   |  45 ++++
 .../lib/Target/AArch64/AArch64MCInstLower.cpp |  18 ++
 llvm/test/CodeGen/AArch64/ctselect.ll         | 153 ++++++++++++++
 6 files changed, 371 insertions(+), 110 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/ctselect.ll

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 60aa61e993b26..a86aac88b94a8 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -511,12 +511,35 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::BR_CC, MVT::f64, Custom); setOperationAction(ISD::SELECT, MVT::i32, Custom); setOperationAction(ISD::SELECT, MVT::i64, Custom); + setOperationAction(ISD::CTSELECT, MVT::i8, Promote); + setOperationAction(ISD::CTSELECT, MVT::i16, Promote); + setOperationAction(ISD::CTSELECT, MVT::i32, Custom); + setOperationAction(ISD::CTSELECT, MVT::i64, Custom); if (Subtarget->hasFPARMv8()) { setOperationAction(ISD::SELECT, MVT::f16, Custom); setOperationAction(ISD::SELECT, MVT::bf16, Custom); } + if (Subtarget->hasFullFP16()) { + setOperationAction(ISD::CTSELECT, MVT::f16, Custom); + setOperationAction(ISD::CTSELECT, MVT::bf16, Custom); + } else { + setOperationAction(ISD::CTSELECT, MVT::f16, Promote); + setOperationAction(ISD::CTSELECT, MVT::bf16, Promote); + } setOperationAction(ISD::SELECT, MVT::f32, Custom); setOperationAction(ISD::SELECT, MVT::f64, Custom); + setOperationAction(ISD::CTSELECT, MVT::f32, Custom); + setOperationAction(ISD::CTSELECT, MVT::f64, Custom); + for (MVT VT : MVT::vector_valuetypes()) { + 
MVT elemType = VT.getVectorElementType(); + if (elemType == MVT::i8 || elemType == MVT::i16) { + setOperationAction(ISD::CTSELECT, VT, Promote); + } else if ((elemType == MVT::f16 || elemType == MVT::bf16) && !Subtarget->hasFullFP16()) { + setOperationAction(ISD::CTSELECT, VT, Promote); + } else { + setOperationAction(ISD::CTSELECT, VT, Expand); + } + } setOperationAction(ISD::SELECT_CC, MVT::i32, Custom); setOperationAction(ISD::SELECT_CC, MVT::i64, Custom); setOperationAction(ISD::SELECT_CC, MVT::f16, Custom); @@ -3328,6 +3351,18 @@ void AArch64TargetLowering::fixupPtrauthDiscriminator( IntDiscOp.setImm(IntDisc); } +MachineBasicBlock *AArch64TargetLowering::EmitCTSELECT(MachineInstr &MI, MachineBasicBlock *MBB, unsigned Opcode) const { + const TargetInstrInfo *TII = Subtarget->getInstrInfo(); + DebugLoc DL = MI.getDebugLoc(); + MachineInstrBuilder Builder = BuildMI(*MBB, MI, DL, TII->get(Opcode)); + for (unsigned Idx = 0; Idx < MI.getNumOperands(); ++Idx) { + Builder.add(MI.getOperand(Idx)); + } + Builder->setFlag(MachineInstr::NoMerge); + MBB->remove_instr(&MI); + return MBB; +} + MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter( MachineInstr &MI, MachineBasicBlock *BB) const { @@ -7590,6 +7625,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, return LowerSELECT(Op, DAG); case ISD::SELECT_CC: return LowerSELECT_CC(Op, DAG); + case ISD::CTSELECT: + return LowerCTSELECT(Op, DAG); case ISD::JumpTable: return LowerJumpTable(Op, DAG); case ISD::BR_JT: @@ -12149,6 +12186,22 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op, return Res; } +SDValue AArch64TargetLowering::LowerCTSELECT(SDValue Op, + SelectionDAG &DAG) const { + SDValue CCVal = Op->getOperand(0); + SDValue TVal = Op->getOperand(1); + SDValue FVal = Op->getOperand(2); + SDLoc DL(Op); + + EVT VT = Op.getValueType(); + + SDValue Zero = DAG.getConstant(0, DL, CCVal.getValueType()); + SDValue CC; + SDValue Cmp = getAArch64Cmp(CCVal, Zero, ISD::SETNE, CC, DAG, DL); + + return DAG.getNode(AArch64ISD::CTSELECT, DL, VT, TVal, FVal, CC, Cmp); +} + SDValue AArch64TargetLowering::LowerJumpTable(SDValue Op, SelectionDAG &DAG) const { // Jump table entries as PC relative offsets. 
No additional tweaking diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index 2cb8ed29f252a..d14d64ffe88b6 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -23,6 +23,11 @@ namespace llvm { +namespace AArch64ISD { +// Forward declare the enum from the generated file +enum GenNodeType : unsigned; +} // namespace AArch64ISD + class AArch64TargetMachine; namespace AArch64 { @@ -202,6 +207,8 @@ class AArch64TargetLowering : public TargetLowering { MachineOperand &AddrDiscOp, const TargetRegisterClass *AddrDiscRC) const; + MachineBasicBlock *EmitCTSELECT(MachineInstr &MI, MachineBasicBlock *BB, unsigned Opcode) const; + MachineBasicBlock * EmitInstrWithCustomInserter(MachineInstr &MI, MachineBasicBlock *MBB) const override; @@ -684,6 +691,7 @@ class AArch64TargetLowering : public TargetLowering { iterator_range<SDNode::user_iterator> Users, SDNodeFlags Flags, const SDLoc &dl, SelectionDAG &DAG) const; + SDValue LowerCTSELECT(SDValue Op, SelectionDAG &DAG) const; SDValue LowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const; SDValue LowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const; SDValue LowerJumpTable(SDValue Op, SelectionDAG &DAG) const; @@ -919,6 +927,10 @@ class AArch64TargetLowering : public TargetLowering { bool hasMultipleConditionRegisters(EVT VT) const override { return VT.isScalableVector(); } + + bool isSelectSupported(SelectSupportKind Kind) const override { + return true; + } }; namespace AArch64 { diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp index ccc8eb8a9706d..227e5d59610d6 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp @@ -700,7 +700,7 @@ static unsigned removeCopies(const MachineRegisterInfo &MRI, unsigned VReg) { // csel instruction. If so, return the folded opcode, and the replacement // register. static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg, - unsigned *NewReg = nullptr) { + unsigned *NewVReg = nullptr) { VReg = removeCopies(MRI, VReg); if (!Register::isVirtualRegister(VReg)) return 0; @@ -708,37 +708,8 @@ static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg, bool Is64Bit = AArch64::GPR64allRegClass.hasSubClassEq(MRI.getRegClass(VReg)); const MachineInstr *DefMI = MRI.getVRegDef(VReg); unsigned Opc = 0; - unsigned SrcReg = 0; + unsigned SrcOpNum = 0; switch (DefMI->getOpcode()) { - case AArch64::SUBREG_TO_REG: - // Check for the following way to define an 64-bit immediate: - // %0:gpr32 = MOVi32imm 1 - // %1:gpr64 = SUBREG_TO_REG 0, %0:gpr32, %subreg.sub_32 - if (!DefMI->getOperand(1).isImm() || DefMI->getOperand(1).getImm() != 0) - return 0; - if (!DefMI->getOperand(2).isReg()) - return 0; - if (!DefMI->getOperand(3).isImm() || - DefMI->getOperand(3).getImm() != AArch64::sub_32) - return 0; - DefMI = MRI.getVRegDef(DefMI->getOperand(2).getReg()); - if (DefMI->getOpcode() != AArch64::MOVi32imm) - return 0; - if (!DefMI->getOperand(1).isImm() || DefMI->getOperand(1).getImm() != 1) - return 0; - assert(Is64Bit); - SrcReg = AArch64::XZR; - Opc = AArch64::CSINCXr; - break; - - case AArch64::MOVi32imm: - case AArch64::MOVi64imm: - if (!DefMI->getOperand(1).isImm() || DefMI->getOperand(1).getImm() != 1) - return 0; - SrcReg = Is64Bit ? AArch64::XZR : AArch64::WZR; - Opc = Is64Bit ? 
AArch64::CSINCXr : AArch64::CSINCWr; - break; - case AArch64::ADDSXri: case AArch64::ADDSWri: // if NZCV is used, do not fold. @@ -753,7 +724,7 @@ static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg, if (!DefMI->getOperand(2).isImm() || DefMI->getOperand(2).getImm() != 1 || DefMI->getOperand(3).getImm() != 0) return 0; - SrcReg = DefMI->getOperand(1).getReg(); + SrcOpNum = 1; Opc = Is64Bit ? AArch64::CSINCXr : AArch64::CSINCWr; break; @@ -763,7 +734,7 @@ static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg, unsigned ZReg = removeCopies(MRI, DefMI->getOperand(1).getReg()); if (ZReg != AArch64::XZR && ZReg != AArch64::WZR) return 0; - SrcReg = DefMI->getOperand(2).getReg(); + SrcOpNum = 2; Opc = Is64Bit ? AArch64::CSINVXr : AArch64::CSINVWr; break; } @@ -782,17 +753,17 @@ static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg, unsigned ZReg = removeCopies(MRI, DefMI->getOperand(1).getReg()); if (ZReg != AArch64::XZR && ZReg != AArch64::WZR) return 0; - SrcReg = DefMI->getOperand(2).getReg(); + SrcOpNum = 2; Opc = Is64Bit ? AArch64::CSNEGXr : AArch64::CSNEGWr; break; } default: return 0; } - assert(Opc && SrcReg && "Missing parameters"); + assert(Opc && SrcOpNum && "Missing parameters"); - if (NewReg) - *NewReg = SrcReg; + if (NewVReg) + *NewVReg = DefMI->getOperand(SrcOpNum).getReg(); return Opc; } @@ -993,34 +964,28 @@ void AArch64InstrInfo::insertSelect(MachineBasicBlock &MBB, // Try folding simple instructions into the csel. if (TryFold) { - unsigned NewReg = 0; - unsigned FoldedOpc = canFoldIntoCSel(MRI, TrueReg, &NewReg); + unsigned NewVReg = 0; + unsigned FoldedOpc = canFoldIntoCSel(MRI, TrueReg, &NewVReg); if (FoldedOpc) { // The folded opcodes csinc, csinc and csneg apply the operation to // FalseReg, so we need to invert the condition. CC = AArch64CC::getInvertedCondCode(CC); TrueReg = FalseReg; } else - FoldedOpc = canFoldIntoCSel(MRI, FalseReg, &NewReg); + FoldedOpc = canFoldIntoCSel(MRI, FalseReg, &NewVReg); // Fold the operation. Leave any dead instructions for DCE to clean up. if (FoldedOpc) { - FalseReg = NewReg; + FalseReg = NewVReg; Opc = FoldedOpc; - // Extend the live range of NewReg. - MRI.clearKillFlags(NewReg); + // The extends the live range of NewVReg. + MRI.clearKillFlags(NewVReg); } } // Pull all virtual register into the appropriate class. MRI.constrainRegClass(TrueReg, RC); - // FalseReg might be WZR or XZR if the folded operand is a literal 1. - assert( - (FalseReg.isVirtual() || FalseReg == AArch64::WZR || - FalseReg == AArch64::XZR) && - "FalseReg was folded into a non-virtual register other than WZR or XZR"); - if (FalseReg.isVirtual()) - MRI.constrainRegClass(FalseReg, RC); + MRI.constrainRegClass(FalseReg, RC); // Insert the csel. 
BuildMI(MBB, I, DL, get(Opc), DstReg) @@ -2148,16 +2113,46 @@ bool AArch64InstrInfo::removeCmpToZeroOrOne( return true; } -bool AArch64InstrInfo::expandPostRAPseudo(MachineInstr &MI) const { - if (MI.getOpcode() != TargetOpcode::LOAD_STACK_GUARD && - MI.getOpcode() != AArch64::CATCHRET) - return false; +static inline void expandCtSelect(MachineBasicBlock &MBB, MachineInstr &MI, DebugLoc &DL, const MCInstrDesc &MCID) { + MachineInstrBuilder Builder = BuildMI(MBB, MI, DL, MCID); + for (unsigned Idx = 0; Idx < MI.getNumOperands(); ++Idx) { + Builder.add(MI.getOperand(Idx)); + } + Builder->setFlag(MachineInstr::NoMerge); + MBB.remove_instr(&MI); +} +bool AArch64InstrInfo::expandPostRAPseudo(MachineInstr &MI) const { MachineBasicBlock &MBB = *MI.getParent(); auto &Subtarget = MBB.getParent()->getSubtarget<AArch64Subtarget>(); auto TRI = Subtarget.getRegisterInfo(); DebugLoc DL = MI.getDebugLoc(); + switch (MI.getOpcode()) { + case AArch64::I32CTSELECT: + expandCtSelect(MBB, MI, DL, get(AArch64::CSELWr)); + return true; + case AArch64::I64CTSELECT: + expandCtSelect(MBB, MI, DL, get(AArch64::CSELXr)); + return true; + case AArch64::BF16CTSELECT: + expandCtSelect(MBB, MI, DL, get(AArch64::FCSELHrrr)); + return true; + case AArch64::F16CTSELECT: + expandCtSelect(MBB, MI, DL, get(AArch64::FCSELHrrr)); + return true; + case AArch64::F32CTSELECT: + expandCtSelect(MBB, MI, DL, get(AArch64::FCSELSrrr)); + return true; + case AArch64::F64CTSELECT: + expandCtSelect(MBB, MI, DL, get(AArch64::FCSELDrrr)); + return true; + } + + if (MI.getOpcode() != TargetOpcode::LOAD_STACK_GUARD && + MI.getOpcode() != AArch64::CATCHRET) + return false; + if (MI.getOpcode() == AArch64::CATCHRET) { // Skip to the first instruction before the epilog. const TargetInstrInfo *TII = @@ -5098,7 +5093,7 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB, bool RenamableDest, bool RenamableSrc) const { if (AArch64::GPR32spRegClass.contains(DestReg) && - AArch64::GPR32spRegClass.contains(SrcReg)) { + (AArch64::GPR32spRegClass.contains(SrcReg) || SrcReg == AArch64::WZR)) { if (DestReg == AArch64::WSP || SrcReg == AArch64::WSP) { // If either operand is WSP, expand to ADD #0. if (Subtarget.hasZeroCycleRegMoveGPR64() && @@ -5123,14 +5118,31 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB, .addImm(0) .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); } + } else if (SrcReg == AArch64::WZR && + Subtarget.hasZeroCycleZeroingGPR64() && + !Subtarget.hasZeroCycleZeroingGPR32()) { + // Use 64-bit zeroing when available but 32-bit zeroing is not + MCRegister DestRegX = RI.getMatchingSuperReg(DestReg, AArch64::sub_32, + &AArch64::GPR64spRegClass); + assert(DestRegX.isValid() && "Destination super-reg not valid"); + BuildMI(MBB, I, DL, get(AArch64::MOVZXi), DestRegX) + .addImm(0) + .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); + } else if (SrcReg == AArch64::WZR && Subtarget.hasZeroCycleZeroingGPR32()) { + BuildMI(MBB, I, DL, get(AArch64::MOVZWi), DestReg) + .addImm(0) + .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); } else if (Subtarget.hasZeroCycleRegMoveGPR64() && !Subtarget.hasZeroCycleRegMoveGPR32()) { // Cyclone recognizes "ORR Xd, XZR, Xm" as a zero-cycle register move. MCRegister DestRegX = RI.getMatchingSuperReg(DestReg, AArch64::sub_32, &AArch64::GPR64spRegClass); assert(DestRegX.isValid() && "Destination super-reg not valid"); - MCRegister SrcRegX = RI.getMatchingSuperReg(SrcReg, AArch64::sub_32, - &AArch64::GPR64spRegClass); + MCRegister SrcRegX = + SrcReg == AArch64::WZR + ? 
AArch64::XZR + : RI.getMatchingSuperReg(SrcReg, AArch64::sub_32, + &AArch64::GPR64spRegClass); assert(SrcRegX.isValid() && "Source super-reg not valid"); // This instruction is reading and writing X registers. This may upset // the register scavenger and machine verifier, so we need to indicate @@ -5149,59 +5161,6 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB, return; } - // GPR32 zeroing - if (AArch64::GPR32spRegClass.contains(DestReg) && SrcReg == AArch64::WZR) { - if (Subtarget.hasZeroCycleZeroingGPR64() && - !Subtarget.hasZeroCycleZeroingGPR32()) { - MCRegister DestRegX = RI.getMatchingSuperReg(DestReg, AArch64::sub_32, - &AArch64::GPR64spRegClass); - assert(DestRegX.isValid() && "Destination super-reg not valid"); - BuildMI(MBB, I, DL, get(AArch64::MOVZXi), DestRegX) - .addImm(0) - .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); - } else if (Subtarget.hasZeroCycleZeroingGPR32()) { - BuildMI(MBB, I, DL, get(AArch64::MOVZWi), DestReg) - .addImm(0) - .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); - } else { - BuildMI(MBB, I, DL, get(AArch64::ORRWrr), DestReg) - .addReg(AArch64::WZR) - .addReg(AArch64::WZR); - } - return; - } - - if (AArch64::GPR64spRegClass.contains(DestReg) && - AArch64::GPR64spRegClass.contains(SrcReg)) { - if (DestReg == AArch64::SP || SrcReg == AArch64::SP) { - // If either operand is SP, expand to ADD #0. - BuildMI(MBB, I, DL, get(AArch64::ADDXri), DestReg) - .addReg(SrcReg, getKillRegState(KillSrc)) - .addImm(0) - .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); - } else { - // Otherwise, expand to ORR XZR. - BuildMI(MBB, I, DL, get(AArch64::ORRXrr), DestReg) - .addReg(AArch64::XZR) - .addReg(SrcReg, getKillRegState(KillSrc)); - } - return; - } - - // GPR64 zeroing - if (AArch64::GPR64spRegClass.contains(DestReg) && SrcReg == AArch64::XZR) { - if (Subtarget.hasZeroCycleZeroingGPR64()) { - BuildMI(MBB, I, DL, get(AArch64::MOVZXi), DestReg) - .addImm(0) - .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); - } else { - BuildMI(MBB, I, DL, get(AArch64::ORRXrr), DestReg) - .addReg(AArch64::XZR) - .addReg(AArch64::XZR); - } - return; - } - // Copy a Predicate register by ORRing with itself. if (AArch64::PPRRegClass.contains(DestReg) && AArch64::PPRRegClass.contains(SrcReg)) { @@ -5286,6 +5245,27 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB, return; } + if (AArch64::GPR64spRegClass.contains(DestReg) && + (AArch64::GPR64spRegClass.contains(SrcReg) || SrcReg == AArch64::XZR)) { + if (DestReg == AArch64::SP || SrcReg == AArch64::SP) { + // If either operand is SP, expand to ADD #0. + BuildMI(MBB, I, DL, get(AArch64::ADDXri), DestReg) + .addReg(SrcReg, getKillRegState(KillSrc)) + .addImm(0) + .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); + } else if (SrcReg == AArch64::XZR && Subtarget.hasZeroCycleZeroingGPR64()) { + BuildMI(MBB, I, DL, get(AArch64::MOVZXi), DestReg) + .addImm(0) + .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); + } else { + // Otherwise, expand to ORR XZR. + BuildMI(MBB, I, DL, get(AArch64::ORRXrr), DestReg) + .addReg(AArch64::XZR) + .addReg(SrcReg, getKillRegState(KillSrc)); + } + return; + } + // Copy a DDDD register quad by copying the individual sub-registers. 
if (AArch64::DDDDRegClass.contains(DestReg) && AArch64::DDDDRegClass.contains(SrcReg)) { diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index 2871a20e28b65..3a8fdc329f129 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -476,6 +476,11 @@ def SDT_AArch64cbz : SDTypeProfile<0, 2, [SDTCisInt<0>, SDTCisVT<1, OtherVT>]>; def SDT_AArch64tbz : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>, SDTCisVT<2, OtherVT>]>; +def SDT_AArch64CtSelect : SDTypeProfile<1, 4, + [SDTCisSameAs<0, 1>, + SDTCisSameAs<0, 2>, + SDTCisInt<3>, + SDTCisVT<4, i32>]>; def SDT_AArch64CSel : SDTypeProfile<1, 4, [SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, @@ -843,6 +848,7 @@ def AArch64tbz : SDNode<"AArch64ISD::TBZ", SDT_AArch64tbz, def AArch64tbnz : SDNode<"AArch64ISD::TBNZ", SDT_AArch64tbz, [SDNPHasChain]>; +def AArch64ctselect : SDNode<"AArch64ISD::CTSELECT", SDT_AArch64CtSelect>; def AArch64csel : SDNode<"AArch64ISD::CSEL", SDT_AArch64CSel>; // Conditional select invert. @@ -5644,6 +5650,45 @@ def F128CSEL : Pseudo<(outs FPR128:$Rd), let hasNoSchedulingInfo = 1; } +//===----------------------------------------------------------------------===// +// Constant-time conditional selection instructions +//===----------------------------------------------------------------------===// + +let hasSideEffects = 1, isPseudo = 1, hasNoSchedulingInfo = 1, Uses = [NZCV] in { + def I32CTSELECT : Pseudo<(outs GPR32:$dst), + (ins GPR32:$tval, GPR32:$fval, i32imm:$cc), + [(set (i32 GPR32:$dst), + (AArch64ctselect GPR32:$tval, GPR32:$fval, + (i32 imm:$cc), NZCV))]>; + def I64CTSELECT : Pseudo<(outs GPR64:$dst), + (ins GPR64:$tval, GPR64:$fval, i32imm:$cc), + [(set (i64 GPR64:$dst), + (AArch64ctselect GPR64:$tval, GPR64:$fval, + (i32 imm:$cc), NZCV))]>; + let Predicates = [HasFullFP16] in { + def F16CTSELECT : Pseudo<(outs FPR16:$dst), + (ins FPR16:$tval, FPR16:$fval, i32imm:$cc), + [(set (f16 FPR16:$dst), + (AArch64ctselect (f16 FPR16:$tval), (f16 FPR16:$fval), + (i32 imm:$cc), NZCV))]>; + def BF16CTSELECT : Pseudo<(outs FPR16:$dst), + (ins FPR16:$tval, FPR16:$fval, i32imm:$cc), + [(set (bf16 FPR16:$dst), + (AArch64ctselect (bf16 FPR16:$tval), (bf16 FPR16:$fval), + (i32 imm:$cc), NZCV))]>; + } + def F32CTSELECT : Pseudo<(outs FPR32:$dst), + (ins FPR32:$tval, FPR32:$fval, i32imm:$cc), + [(set (f32 FPR32:$dst), + (AArch64ctselect FPR32:$tval, FPR32:$fval, + (i32 imm:$cc), NZCV))]>; + def F64CTSELECT : Pseudo<(outs FPR64:$dst), + (ins FPR64:$tval, FPR64:$fval, i32imm:$cc), + [(set (f64 FPR64:$dst), + (AArch64ctselect FPR64:$tval, FPR64:$fval, + (i32 imm:$cc), NZCV))]>; +} + //===----------------------------------------------------------------------===// // Instructions used for emitting unwind opcodes on ARM64 Windows. 
//===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/AArch64/AArch64MCInstLower.cpp b/llvm/lib/Target/AArch64/AArch64MCInstLower.cpp index 39946633603f6..e2ec9118eb5ee 100644 --- a/llvm/lib/Target/AArch64/AArch64MCInstLower.cpp +++ b/llvm/lib/Target/AArch64/AArch64MCInstLower.cpp @@ -393,5 +393,23 @@ void AArch64MCInstLower::Lower(const MachineInstr *MI, MCInst &OutMI) const { OutMI.setOpcode(AArch64::RET); OutMI.addOperand(MCOperand::createReg(AArch64::LR)); break; + case AArch64::I32CTSELECT: + OutMI.setOpcode(AArch64::CSELWr); + break; + case AArch64::I64CTSELECT: + OutMI.setOpcode(AArch64::CSELXr); + break; + case AArch64::BF16CTSELECT: + OutMI.setOpcode(AArch64::FCSELHrrr); + break; + case AArch64::F16CTSELECT: + OutMI.setOpcode(AArch64::FCSELHrrr); + break; + case AArch64::F32CTSELECT: + OutMI.setOpcode(AArch64::FCSELSrrr); + break; + case AArch64::F64CTSELECT: + OutMI.setOpcode(AArch64::FCSELDrrr); + break; } } diff --git a/llvm/test/CodeGen/AArch64/ctselect.ll b/llvm/test/CodeGen/AArch64/ctselect.ll new file mode 100644 index 0000000000000..77e9cf24e56cf --- /dev/null +++ b/llvm/test/CodeGen/AArch64/ctselect.ll @@ -0,0 +1,153 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 +; RUN: llc < %s -verify-machineinstrs -mtriple=aarch64-none-eabi | FileCheck %s --check-prefixes=DEFAULT,NOFP16 +; RUN: llc < %s -verify-machineinstrs -mtriple=aarch64-none-eabi -mattr=+fullfp16 | FileCheck %s --check-prefixes=DEFAULT,FP16 + +define i1 @ct_i1(i1 %cond, i1 %a, i1 %b) { +; DEFAULT-LABEL: ct_i1: +; DEFAULT: // %bb.0: +; DEFAULT-NEXT: tst w0, #0x1 +; DEFAULT-NEXT: csel w8, w1, w2, ne +; DEFAULT-NEXT: and w0, w8, #0x1 +; DEFAULT-NEXT: ret + %1 = call i1 @llvm.ct.select.i1(i1 %cond, i1 %a, i1 %b) + ret i1 %1 +} + +define i8 @ct_i8(i1 %cond, i8 %a, i8 %b) { +; DEFAULT-LABEL: ct_i8: +; DEFAULT: // %bb.0: +; DEFAULT-NEXT: tst w0, #0x1 +; DEFAULT-NEXT: csel w0, w1, w2, ne +; DEFAULT-NEXT: ret + %1 = call i8 @llvm.ct.select.i8(i1 %cond, i8 %a, i8 %b) + ret i8 %1 +} + +define i16 @ct_i16(i1 %cond, i16 %a, i16 %b) { +; DEFAULT-LABEL: ct_i16: +; DEFAULT: // %bb.0: +; DEFAULT-NEXT: tst w0, #0x1 +; DEFAULT-NEXT: csel w0, w1, w2, ne +; DEFAULT-NEXT: ret + %1 = call i16 @llvm.ct.select.i16(i1 %cond, i16 %a, i16 %b) + ret i16 %1 +} + +define i32 @ct_i32(i1 %cond, i32 %a, i32 %b) { +; DEFAULT-LABEL: ct_i32: +; DEFAULT: // %bb.0: +; DEFAULT-NEXT: tst w0, #0x1 +; DEFAULT-NEXT: csel w0, w1, w2, ne +; DEFAULT-NEXT: ret + %1 = call i32 @llvm.ct.select.i32(i1 %cond, i32 %a, i32 %b) + ret i32 %1 +} + +define i64 @ct_i64(i1 %cond, i64 %a, i64 %b) { +; DEFAULT-LABEL: ct_i64: +; DEFAULT: // %bb.0: +; DEFAULT-NEXT: tst w0, #0x1 +; DEFAULT-NEXT: csel x0, x1, x2, ne +; DEFAULT-NEXT: ret + %1 = call i64 @llvm.ct.select.i64(i1 %cond, i64 %a, i64 %b) + ret i64 %1 +} + +define i128 @ct_i128(i1 %cond, i128 %a, i128 %b) { +; DEFAULT-LABEL: ct_i128: +; DEFAULT: // %bb.0: +; DEFAULT-NEXT: tst w0, #0x1 +; DEFAULT-NEXT: csel x0, x2, x4, ne +; DEFAULT-NEXT: csel x1, x3, x5, ne +; DEFAULT-NEXT: ret + %1 = call i128 @llvm.ct.select.i128(i1 %cond, i128 %a, i128 %b) + ret i128 %1 +} + +define half @ct_f16(i1 %cond, half %a, half %b) { +; NOFP16-LABEL: ct_f16: +; NOFP16: // %bb.0: +; NOFP16-NEXT: fcvt s1, h1 +; NOFP16-NEXT: fcvt s0, h0 +; NOFP16-NEXT: tst w0, #0x1 +; NOFP16-NEXT: fcsel s0, s0, s1, ne +; NOFP16-NEXT: fcvt h0, s0 +; NOFP16-NEXT: ret +; +; FP16-LABEL: ct_f16: +; FP16: // %bb.0: +; FP16-NEXT: tst w0, #0x1 +; FP16-NEXT: 
fcsel h0, h0, h1, ne +; FP16-NEXT: ret + %1 = call half @llvm.ct.select.f16(i1 %cond, half %a, half %b) + ret half %1 +} + +define float @ct_f32(i1 %cond, float %a, float %b) { +; DEFAULT-LABEL: ct_f32: +; DEFAULT: // %bb.0: +; DEFAULT-NEXT: tst w0, #0x1 +; DEFAULT-NEXT: fcsel s0, s0, s1, ne +; DEFAULT-NEXT: ret + %1 = call float @llvm.ct.select.f32(i1 %cond, float %a, float %b) + ret float %1 +} + +define double @ct_f64(i1 %cond, double %a, double %b) { +; DEFAULT-LABEL: ct_f64: +; DEFAULT: // %bb.0: +; DEFAULT-NEXT: tst w0, #0x1 +; DEFAULT-NEXT: fcsel d0, d0, d1, ne +; DEFAULT-NEXT: ret + %1 = call double @llvm.ct.select.f64(i1 %cond, double %a, double %b) + ret double %1 +} + +define <4 x i32> @ct_v4i32(i1 %cond, <4 x i32> %a, <4 x i32> %b) { +; DEFAULT-LABEL: ct_v4i32: +; DEFAULT: // %bb.0: +; DEFAULT-NEXT: mov w8, v1.s[1] +; DEFAULT-NEXT: mov w9, v0.s[1] +; DEFAULT-NEXT: tst w0, #0x1 +; DEFAULT-NEXT: fmov w10, s1 +; DEFAULT-NEXT: fmov w11, s0 +; DEFAULT-NEXT: csel w8, w9, w8, ne +; DEFAULT-NEXT: csel w9, w11, w10, ne +; DEFAULT-NEXT: mov w10, v1.s[2] +; DEFAULT-NEXT: fmov s2, w9 +; DEFAULT-NEXT: mov w11, v0.s[2] +; DEFAULT-NEXT: mov w9, v0.s[3] +; DEFAULT-NEXT: mov v2.s[1], w8 +; DEFAULT-NEXT: mov w8, v1.s[3] +; DEFAULT-NEXT: csel w10, w11, w10, ne +; DEFAULT-NEXT: mov v2.s[2], w10 +; DEFAULT-NEXT: csel w8, w9, w8, ne +; DEFAULT-NEXT: mov v2.s[3], w8 +; DEFAULT-NEXT: mov v0.16b, v2.16b +; DEFAULT-NEXT: ret + %1 = call <4 x i32> @llvm.ct.select.v4i32(i1 %cond, <4 x i32> %a, <4 x i32> %b) + ret <4 x i32> %1 +} + +define <4 x float> @ct_v4f32(i1 %cond, <4 x float> %a, <4 x float> %b) { +; DEFAULT-LABEL: ct_v4f32: +; DEFAULT: // %bb.0: +; DEFAULT-NEXT: mov s2, v1.s[1] +; DEFAULT-NEXT: mov s3, v0.s[1] +; DEFAULT-NEXT: tst w0, #0x1 +; DEFAULT-NEXT: mov s4, v1.s[2] +; DEFAULT-NEXT: mov s5, v0.s[2] +; DEFAULT-NEXT: fcsel s3, s3, s2, ne +; DEFAULT-NEXT: fcsel s2, s0, s1, ne +; DEFAULT-NEXT: mov s1, v1.s[3] +; DEFAULT-NEXT: mov s0, v0.s[3] +; DEFAULT-NEXT: mov v2.s[1], v3.s[0] +; DEFAULT-NEXT: fcsel s3, s5, s4, ne +; DEFAULT-NEXT: fcsel s0, s0, s1, ne +; DEFAULT-NEXT: mov v2.s[2], v3.s[0] +; DEFAULT-NEXT: mov v2.s[3], v0.s[0] +; DEFAULT-NEXT: mov v0.16b, v2.16b +; DEFAULT-NEXT: ret + %1 = call <4 x float> @llvm.ct.select.v4f32(i1 %cond, <4 x float> %a, <4 x float> %b) + ret <4 x float> %1 +} _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
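
For readers who want to see the intended end-to-end behaviour without walking the whole diff, here is a minimal IR-level sketch distilled from the ct_i64 case in llvm/test/CodeGen/AArch64/ctselect.ll above. The function name ct_i64_example and the explicit declare line are illustrative additions for completeness, not part of the patch; the expected assembly and the llc flags are taken from the test's RUN and CHECK lines.

  ; Minimal ct.select sketch (mirrors the ct_i64 test case in this patch)
  declare i64 @llvm.ct.select.i64(i1, i64, i64)

  define i64 @ct_i64_example(i1 %cond, i64 %a, i64 %b) {
    ; Select %a when %cond is true, %b otherwise, without a data-dependent branch
    %sel = call i64 @llvm.ct.select.i64(i1 %cond, i64 %a, i64 %b)
    ret i64 %sel
  }

With this patch, llc < ctselect.ll -verify-machineinstrs -mtriple=aarch64-none-eabi is expected to produce a branch-free sequence along the lines of:

  tst  w0, #0x1
  csel x0, x1, x2, ne
  ret

That is, the condition is materialized into NZCV and the selection is a single CSEL, which is what gives the constant-time property the patch description claims.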
