llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-arm

Author: Julius Alexandre (wizardengineer)

<details>
<summary>Changes</summary>

This patch implements architecture-specific lowering for ct.select on ARM
(both ARM32 and Thumb modes) using conditional move instructions and
bitwise operations for constant-time selection.

Implementation details:
- Uses pseudo-instructions that are expanded Post-RA to bitwise operations
- Post-RA expansion in ARMBaseInstrInfo of the CTSELECT pseudo-instructions into instruction BUNDLEs
- Handles scalar integer types, floating-point, and half-precision types
- Handles vector types with NEON when available
- Support for both ARM and Thumb instruction sets (Thumb1 and Thumb2)
- Special handling for Thumb1 which lacks conditional execution
- Comprehensive test coverage including half-precision and vectors

The implementation includes:
- ISelLowering: Custom lowering to CTSELECT pseudo-instructions
- ISelDAGToDAG: Selection of appropriate pseudo-instructions
- BaseInstrInfo: Post-RA expansion of the CTSELECT pseudo-instructions into bundled bitwise instruction sequences
- InstrInfo.td: Pseudo-instruction definitions for different types
- TargetMachine: Registration of Post-RA expansion pass
- Proper handling of condition codes and register allocation constraints

---

Patch is 166.38 KiB, truncated to 20.00 KiB below, full version: 
https://github.com/llvm/llvm-project/pull/166707.diff


10 Files Affected:

- (modified) llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp (+335-2) 
- (modified) llvm/lib/Target/ARM/ARMBaseInstrInfo.h (+6) 
- (modified) llvm/lib/Target/ARM/ARMISelDAGToDAG.cpp (+86) 
- (modified) llvm/lib/Target/ARM/ARMISelLowering.cpp (+164-20) 
- (modified) llvm/lib/Target/ARM/ARMISelLowering.h (+11-2) 
- (modified) llvm/lib/Target/ARM/ARMInstrInfo.td (+185) 
- (modified) llvm/lib/Target/ARM/ARMTargetMachine.cpp (+3-5) 
- (added) llvm/test/CodeGen/ARM/ctselect-half.ll (+975) 
- (added) llvm/test/CodeGen/ARM/ctselect-vector.ll (+2179) 
- (added) llvm/test/CodeGen/ARM/ctselect.ll (+555) 


``````````diff
diff --git a/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp b/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp
index 22769dbf38719..6d8a3b72244fe 100644
--- a/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp
@@ -1526,18 +1526,351 @@ void ARMBaseInstrInfo::expandMEMCPY(MachineBasicBlock::iterator MI) const {
   BB->erase(MI);
 }
 
+// Expands a vector CTSELECT pseudo (and CTSELECTf64 on non-Thumb1-only
+// subtargets) post-RA into a branch-free mask-and-merge sequence:
+//   mask  = 0 - cond                      (RSB / t2RSB on a GPR)
+//   vmask = broadcast(mask)               (VDUP32)
+//   dst   = (src1 & vmask) | (src2 & ~vmask)
+// The new instructions are then grouped with finalizeBundle and the pseudo is
+// erased. Always returns true.
+bool ARMBaseInstrInfo::expandCtSelectVector(MachineInstr &MI) const {
+  MachineBasicBlock *MBB = MI.getParent();
+  DebugLoc DL = MI.getDebugLoc();
+
+  Register DestReg = MI.getOperand(0).getReg();
+  Register MaskReg = MI.getOperand(1).getReg();
+
+  // Default to the 64-bit (D-register) NEON forms; switched to the 128-bit
+  // (Q-register) forms below when the destination is in QPR.
+  unsigned AndOp = ARM::VANDd;
+  unsigned BicOp = ARM::VBICd;
+  unsigned OrrOp = ARM::VORRd;
+  unsigned BroadcastOp = ARM::VDUP32d;
+
+  const TargetRegisterInfo *TRI = &getRegisterInfo();
+  const TargetRegisterClass *RC = TRI->getMinimalPhysRegClass(DestReg);
+
+  if (ARM::QPRRegClass.hasSubClassEq(RC)) {
+    AndOp = ARM::VANDq;
+    BicOp = ARM::VBICq;
+    OrrOp = ARM::VORRq;
+    BroadcastOp = ARM::VDUP32q;
+  }
+
+  // The scalar mask is computed in a GPR, so pick the ARM or Thumb2 RSB.
+  unsigned RsbOp = Subtarget.isThumb2() ? ARM::t2RSBri : ARM::RSBri;
+
+  // Any vector pseudo has: (outs $dst, $tmp_mask, $bcast_mask),
+  // (ins $src1, $src2, $cond). This matches the result list built for vector
+  // and f64 types in ARMISelDAGToDAG.
+  Register VectorMaskReg = MI.getOperand(2).getReg();
+  Register Src1Reg = MI.getOperand(3).getReg();
+  Register Src2Reg = MI.getOperand(4).getReg();
+  Register CondReg = MI.getOperand(5).getReg();
+
+  // The following sequence of steps yields: (src1 & mask) | (src2 & ~mask)
+  //
+  // NOTE(review): $dst is written in step 2 before $src2 is read in step 3,
+  // and $bcast_mask is written before $src1/$src2 are read. The sequence is
+  // only correct if the pseudo's outputs are constrained (e.g. earlyclobber)
+  // not to share registers with those inputs. Confirm in ARMInstrInfo.td.
+
+  // 1. mask = 0 - cond
+  // When cond = 0: mask = 0x00000000.
+  // When cond = 1: mask = 0xFFFFFFFF.
+  // (Assumes cond is 0 or 1; any other value yields a partial mask.)
+
+  MachineInstr *FirstNewMI = BuildMI(*MBB, MI, DL, get(RsbOp), MaskReg)
+                                 .addReg(CondReg)
+                                 .addImm(0)
+                                 .add(predOps(ARMCC::AL))
+                                 .add(condCodeOp())
+                                 .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // 2. A = src1 & mask
+  // For vectors, broadcast the scalar mask so it matches operand size.
+  // Replicating an all-zeros/all-ones 32-bit lane yields an all-zeros/all-ones
+  // register, so VDUP32 is correct whatever the vector's element type is.
+  BuildMI(*MBB, MI, DL, get(BroadcastOp), VectorMaskReg)
+      .addReg(MaskReg)
+      .add(predOps(ARMCC::AL))
+      .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  BuildMI(*MBB, MI, DL, get(AndOp), DestReg)
+      .addReg(Src1Reg)
+      .addReg(VectorMaskReg)
+      .add(predOps(ARMCC::AL))
+      .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // 3. B = src2 & ~mask (VBIC), written back into $bcast_mask, which is no
+  // longer needed as a mask afterwards.
+  BuildMI(*MBB, MI, DL, get(BicOp), VectorMaskReg)
+      .addReg(Src2Reg)
+      .addReg(VectorMaskReg)
+      .add(predOps(ARMCC::AL))
+      .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // 4. result = A | B
+  auto LastNewMI = BuildMI(*MBB, MI, DL, get(OrrOp), DestReg)
+                       .addReg(DestReg)
+                       .addReg(VectorMaskReg)
+                       .add(predOps(ARMCC::AL))
+                       .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  auto BundleStart = FirstNewMI->getIterator();
+  auto BundleEnd = LastNewMI->getIterator();
+
+  // Bundle everything from the RSB through the final ORR (end is exclusive,
+  // hence std::next).
+  finalizeBundle(*MBB, BundleStart, std::next(BundleEnd));
+
+  MI.eraseFromParent();
+  return true;
+}
+
+// Expands a CTSELECT pseudo on Thumb1-only subtargets, post-RA, using only
+// Thumb1 register forms:
+//   mask = (cond << (RegSize-1)) >>arith (RegSize-1)   ; 0 or all-ones
+//   dst  = src2 ^ ((src1 ^ src2) & mask)
+// When mask is all-ones this gives src1; when mask is zero it gives src2. The
+// new instructions are bundled and the pseudo is erased. Always returns true.
+//
+// NOTE(review): expandPostRAPseudo also routes CTSELECTf16/bf16/f32/f64 here
+// on Thumb1-only subtargets. ARMISelDAGToDAG builds those pseudos with extra
+// results ($scratch1/$scratch2 or $bcast_mask), which would shift the operand
+// indices read below. Presumably this is unreachable because those FP types
+// are not legal on Thumb1-only targets. Confirm.
+bool ARMBaseInstrInfo::expandCtSelectThumb(MachineInstr &MI) const {
+  MachineBasicBlock *MBB = MI.getParent();
+  DebugLoc DL = MI.getDebugLoc();
+
+  // Pseudos in thumb1 mode have: (outs $dst, $tmp_mask),
+  // (ins $src1, $src2, $cond). The register class here is always tGPR.
+  Register DestReg = MI.getOperand(0).getReg();
+  Register MaskReg = MI.getOperand(1).getReg();
+  Register Src1Reg = MI.getOperand(2).getReg();
+  Register Src2Reg = MI.getOperand(3).getReg();
+  Register CondReg = MI.getOperand(4).getReg();
+
+  // Access register info
+  MachineFunction *MF = MBB->getParent();
+  const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
+  MachineRegisterInfo &MRI = MF->getRegInfo();
+
+  // Shifting left then arithmetic-right by (RegSize - 1) (31 for a 32-bit
+  // tGPR) replicates bit 0 of cond into every bit of the mask.
+  unsigned RegSize = TRI->getRegSizeInBits(MaskReg, MRI);
+  unsigned ShiftAmount = RegSize - 1;
+
+  // NOTE(review): $dst is written (tMOVr from $src1) before $src2 is read,
+  // and $tmp_mask is written before $src1/$src2 are read. This is correct
+  // only if the pseudo's outputs are constrained (e.g. earlyclobber) not to
+  // share registers with those inputs. Confirm in ARMInstrInfo.td.
+
+  // 1. Shift-based mask: copy cond into the mask register first.
+  MachineInstr *FirstNewMI = BuildMI(*MBB, MI, DL, get(ARM::tMOVr), MaskReg)
+                                 .addReg(CondReg)
+                                 .add(predOps(ARMCC::AL))
+                                 .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // Build the mask with LSL followed by ASR instead of RSB. Unlike 0 - cond,
+  // this depends only on bit 0 of cond. Note that these Thumb1 shift forms
+  // still write CPSR (the s_cc_out operand is added below as a dead def), so
+  // this does not avoid a flags update.
+  // tLSLri: (outs tGPR:$Rd, s_cc_out:$s), (ins tGPR:$Rm, imm0_31:$imm5, pred:$p)
+  BuildMI(*MBB, MI, DL, get(ARM::tLSLri), MaskReg)
+      .addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
+      .addReg(MaskReg)                                      // $Rm
+      .addImm(ShiftAmount)                                  // imm0_31:$imm5
+      .add(predOps(ARMCC::AL))                              // pred:$p
+      .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // tASRri: (outs tGPR:$Rd, s_cc_out:$s),
+  //         (ins tGPR:$Rm, imm_sr:$imm5, pred:$p)
+  BuildMI(*MBB, MI, DL, get(ARM::tASRri), MaskReg)
+      .addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
+      .addReg(MaskReg)                                      // $Rm
+      .addImm(ShiftAmount)                                  // imm_sr:$imm5
+      .add(predOps(ARMCC::AL))                              // pred:$p
+      .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // 2. xor_diff = src1 ^ src2 (copy src1 into dst, then xor in place)
+  BuildMI(*MBB, MI, DL, get(ARM::tMOVr), DestReg)
+      .addReg(Src1Reg)
+      .add(predOps(ARMCC::AL))
+      .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // tEOR has tied operands: (outs tGPR:$Rdn, s_cc_out:$s),
+  // (ins tGPR:$Rn, tGPR:$Rm, pred:$p) with constraint "$Rn = $Rdn"
+  BuildMI(*MBB, MI, DL, get(ARM::tEOR), DestReg)
+      .addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
+      .addReg(DestReg)                                      // tied input $Rn
+      .addReg(Src2Reg)                                      // $Rm
+      .add(predOps(ARMCC::AL))                              // pred:$p
+      .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // 3. masked_xor = xor_diff & mask (last use of the mask register)
+  // tAND has tied operands: (outs tGPR:$Rdn, s_cc_out:$s),
+  // (ins tGPR:$Rn, tGPR:$Rm, pred:$p) with constraint "$Rn = $Rdn"
+  BuildMI(*MBB, MI, DL, get(ARM::tAND), DestReg)
+      .addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
+      .addReg(DestReg)                                      // tied input $Rn
+      .addReg(MaskReg, RegState::Kill)                      // $Rm
+      .add(predOps(ARMCC::AL))                              // pred:$p
+      .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // 4. result = src2 ^ masked_xor
+  // tEOR has tied operands: (outs tGPR:$Rdn, s_cc_out:$s),
+  // (ins tGPR:$Rn, tGPR:$Rm, pred:$p) with constraint "$Rn = $Rdn"
+  auto LastMI =
+      BuildMI(*MBB, MI, DL, get(ARM::tEOR), DestReg)
+          .addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
+          .addReg(DestReg)         // tied input $Rn
+          .addReg(Src2Reg)         // $Rm
+          .add(predOps(ARMCC::AL)) // pred:$p
+          .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // Bundle everything from the first tMOVr through the final tEOR (end is
+  // exclusive, hence std::next).
+  auto BundleStart = FirstNewMI->getIterator();
+  finalizeBundle(*MBB, BundleStart, std::next(LastMI->getIterator()));
+
+  MI.eraseFromParent();
+  return true;
+}
+
+// Expands a scalar CTSELECT pseudo (CTSELECTint, CTSELECTf16, CTSELECTbf16,
+// CTSELECTf32) post-RA on ARM/Thumb2 into a branch-free sequence:
+//   mask = 0 - cond
+//   dst  = (src1 & mask) | (src2 & ~mask)
+// For the floating-point pseudos, both sources are first moved to GPR scratch
+// registers (VMOVRS), selected there, and the result is moved back (VMOVSR).
+// The new instructions are bundled and the pseudo is erased. Always returns
+// true.
+bool ARMBaseInstrInfo::expandCtSelect(MachineInstr &MI) const {
+  MachineBasicBlock *MBB = MI.getParent();
+  DebugLoc DL = MI.getDebugLoc();
+
+  Register DestReg = MI.getOperand(0).getReg();
+  Register MaskReg = MI.getOperand(1).getReg();
+  // The pseudo's real destination. For float types DestReg is redirected to a
+  // GPR scratch below, and the result is copied back into this register last.
+  Register DestRegSavedRef = DestReg;
+  Register Src1Reg, Src2Reg, CondReg;
+
+  // ARM-mode opcodes by default; replaced with their Thumb2 equivalents below.
+  unsigned RsbOp = ARM::RSBri;
+  unsigned AndOp = ARM::ANDrr;
+  unsigned BicOp = ARM::BICrr;
+  unsigned OrrOp = ARM::ORRrr;
+
+  if (Subtarget.isThumb2()) {
+    RsbOp = ARM::t2RSBri;
+    AndOp = ARM::t2ANDrr;
+    BicOp = ARM::t2BICrr;
+    OrrOp = ARM::t2ORRrr;
+  }
+
+  unsigned Opcode = MI.getOpcode();
+  bool IsFloat = Opcode == ARM::CTSELECTf32 || Opcode == ARM::CTSELECTf16 ||
+                 Opcode == ARM::CTSELECTbf16;
+  MachineInstr *FirstNewMI = nullptr;
+  if (IsFloat) {
+    // Each float pseudo has: (outs $dst, $tmp_mask, $scratch1, $scratch2),
+    // (ins $src1, $src2, $cond). The two GPR scratch registers let the
+    // bitwise select run on the integer side for float types.
+    Register GPRScratch1 = MI.getOperand(2).getReg();
+    Register GPRScratch2 = MI.getOperand(3).getReg();
+
+    // choice a from __builtin_ct_select(cond, a, b)
+    Src1Reg = MI.getOperand(4).getReg();
+    // choice b from __builtin_ct_select(cond, a, b)
+    Src2Reg = MI.getOperand(5).getReg();
+    // cond from __builtin_ct_select(cond, a, b)
+    CondReg = MI.getOperand(6).getReg();
+
+    // NOTE(review): the scratch GPRs are written here before $cond (also a
+    // GPR) is read by the RSB below. This is only correct if they cannot share
+    // a register with $cond (e.g. earlyclobber outputs). Confirm in
+    // ARMInstrInfo.td.
+
+    // Move fp src1 to GPR scratch1 so we can do our bitwise ops
+    FirstNewMI = BuildMI(*MBB, MI, DL, get(ARM::VMOVRS), GPRScratch1)
+                     .addReg(Src1Reg)
+                     .add(predOps(ARMCC::AL))
+                     .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+    // Move src2 to scratch2
+    BuildMI(*MBB, MI, DL, get(ARM::VMOVRS), GPRScratch2)
+        .addReg(Src2Reg)
+        .add(predOps(ARMCC::AL))
+        .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+    Src1Reg = GPRScratch1;
+    Src2Reg = GPRScratch2;
+    // Reuse GPRScratch1 for dest after we are done working with src1.
+    DestReg = GPRScratch1;
+  } else {
+    // Any non-float, non-vector pseudo has: (outs $dst, $tmp_mask),
+    // (ins $src1, $src2, $cond).
+    Src1Reg = MI.getOperand(2).getReg();
+    Src2Reg = MI.getOperand(3).getReg();
+    CondReg = MI.getOperand(4).getReg();
+  }
+
+  // The following sequence of steps yields: (src1 & mask) | (src2 & ~mask)
+  //
+  // NOTE(review): $tmp_mask is written before src1/src2 are read, and dst is
+  // written in step 2 before src2 is read in step 3. In the integer path this
+  // needs the outputs to be constrained (e.g. earlyclobber) not to share
+  // registers with the inputs. Confirm in ARMInstrInfo.td.
+
+  // 1. mask = 0 - cond
+  // When cond = 0: mask = 0x00000000.
+  // When cond = 1: mask = 0xFFFFFFFF.
+  // (Assumes cond is 0 or 1; any other value yields a partial mask.)
+  auto TmpNewMI = BuildMI(*MBB, MI, DL, get(RsbOp), MaskReg)
+                      .addReg(CondReg)
+                      .addImm(0)
+                      .add(predOps(ARMCC::AL))
+                      .add(condCodeOp())
+                      .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // The bundle starts at the first emitted instruction: the VMOVRS for float
+  // types, otherwise this RSB.
+  if (!FirstNewMI)
+    FirstNewMI = TmpNewMI;
+
+  // 2. A = src1 & mask
+  BuildMI(*MBB, MI, DL, get(AndOp), DestReg)
+      .addReg(Src1Reg)
+      .addReg(MaskReg)
+      .add(predOps(ARMCC::AL))
+      .add(condCodeOp())
+      .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // 3. B = src2 & ~mask, written back into the mask register.
+  BuildMI(*MBB, MI, DL, get(BicOp), MaskReg)
+      .addReg(Src2Reg)
+      .addReg(MaskReg)
+      .add(predOps(ARMCC::AL))
+      .add(condCodeOp())
+      .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // 4. result = A | B
+  auto LastNewMI = BuildMI(*MBB, MI, DL, get(OrrOp), DestReg)
+                       .addReg(DestReg)
+                       .addReg(MaskReg)
+                       .add(predOps(ARMCC::AL))
+                       .add(condCodeOp())
+                       .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  if (IsFloat) {
+    // Return our result from GPR to the correct register type.
+    LastNewMI = BuildMI(*MBB, MI, DL, get(ARM::VMOVSR), DestRegSavedRef)
+                    .addReg(DestReg)
+                    .add(predOps(ARMCC::AL))
+                    .setMIFlag(MachineInstr::MIFlag::NoMerge);
+  }
+
+  auto BundleStart = FirstNewMI->getIterator();
+  auto BundleEnd = LastNewMI->getIterator();
+
+  // Bundle every emitted instruction (end is exclusive, hence std::next).
+  finalizeBundle(*MBB, BundleStart, std::next(BundleEnd));
+
+  MI.eraseFromParent();
+  return true;
+}
+
 bool ARMBaseInstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
-  if (MI.getOpcode() == TargetOpcode::LOAD_STACK_GUARD) {
+  auto opcode = MI.getOpcode();
+
+  if (opcode == TargetOpcode::LOAD_STACK_GUARD) {
     expandLoadStackGuard(MI);
     MI.getParent()->erase(MI);
     return true;
   }
 
-  if (MI.getOpcode() == ARM::MEMCPY) {
+  if (opcode == ARM::MEMCPY) {
     expandMEMCPY(MI);
     return true;
   }
 
+  if (opcode == ARM::CTSELECTf64) {
+    if (Subtarget.isThumb1Only()) {
+      LLVM_DEBUG(dbgs() << "Opcode (thumb1 subtarget) " << opcode
+                        << "replaced by: " << MI);
+      return expandCtSelectThumb(MI);
+    } else {
+      LLVM_DEBUG(dbgs() << "Opcode (vector) " << opcode
+                        << "replaced by: " << MI);
+      return expandCtSelectVector(MI);
+    }
+  }
+
+  if (opcode == ARM::CTSELECTv8i8 || opcode == ARM::CTSELECTv4i16 ||
+      opcode == ARM::CTSELECTv2i32 || opcode == ARM::CTSELECTv1i64 ||
+      opcode == ARM::CTSELECTv2f32 || opcode == ARM::CTSELECTv4f16 ||
+      opcode == ARM::CTSELECTv4bf16 || opcode == ARM::CTSELECTv16i8 ||
+      opcode == ARM::CTSELECTv8i16 || opcode == ARM::CTSELECTv4i32 ||
+      opcode == ARM::CTSELECTv2i64 || opcode == ARM::CTSELECTv4f32 ||
+      opcode == ARM::CTSELECTv2f64 || opcode == ARM::CTSELECTv8f16 ||
+      opcode == ARM::CTSELECTv8bf16) {
+    LLVM_DEBUG(dbgs() << "Opcode (vector) " << opcode << "replaced by: " << 
MI);
+    return expandCtSelectVector(MI);
+  }
+
+  if (opcode == ARM::CTSELECTint || opcode == ARM::CTSELECTf16 ||
+      opcode == ARM::CTSELECTbf16 || opcode == ARM::CTSELECTf32) {
+    if (Subtarget.isThumb1Only()) {
+      LLVM_DEBUG(dbgs() << "Opcode (thumb1 subtarget) " << opcode
+                        << "replaced by: " << MI);
+      return expandCtSelectThumb(MI);
+    } else {
+      LLVM_DEBUG(dbgs() << "Opcode " << opcode << "replaced by: " << MI);
+      return expandCtSelect(MI);
+    }
+  }
+
   // This hook gets to expand COPY instructions before they become
   // copyPhysReg() calls.  Look for VMOVS instructions that can legally be
   // widened to VMOVD.  We prefer the VMOVD when possible because it may be
diff --git a/llvm/lib/Target/ARM/ARMBaseInstrInfo.h b/llvm/lib/Target/ARM/ARMBaseInstrInfo.h
index 2869e7f708046..f0e090f09f5dc 100644
--- a/llvm/lib/Target/ARM/ARMBaseInstrInfo.h
+++ b/llvm/lib/Target/ARM/ARMBaseInstrInfo.h
@@ -221,6 +221,12 @@ class ARMBaseInstrInfo : public ARMGenInstrInfo {
       const TargetRegisterInfo *TRI, Register VReg,
       MachineInstr::MIFlag Flags = MachineInstr::NoFlags) const override;
 
+  /// Post-RA expansion of a vector CTSELECT pseudo (and of CTSELECTf64 on
+  /// non-Thumb1-only subtargets) into a bundled RSB/VDUP32/VAND/VBIC/VORR
+  /// sequence. Erases \p MI and returns true.
+  bool expandCtSelectVector(MachineInstr &MI) const;
+
+  /// Post-RA expansion of a CTSELECT pseudo on Thumb1-only subtargets into a
+  /// bundled tMOVr/tLSLri/tASRri/tEOR/tAND sequence. Erases \p MI and returns
+  /// true.
+  bool expandCtSelectThumb(MachineInstr &MI) const;
+
+  /// Post-RA expansion of a scalar CTSELECT pseudo (int, f16, bf16, f32) on
+  /// ARM/Thumb2 into a bundled RSB/AND/BIC/ORR sequence. Floating-point
+  /// operands are moved through GPR scratches with VMOVRS/VMOVSR. Erases \p MI
+  /// and returns true.
+  bool expandCtSelect(MachineInstr &MI) const;
+
   bool expandPostRAPseudo(MachineInstr &MI) const override;
 
   bool shouldSink(const MachineInstr &MI) const override;
diff --git a/llvm/lib/Target/ARM/ARMISelDAGToDAG.cpp b/llvm/lib/Target/ARM/ARMISelDAGToDAG.cpp
index 847b7af5a9b11..3fdc5734baaa5 100644
--- a/llvm/lib/Target/ARM/ARMISelDAGToDAG.cpp
+++ b/llvm/lib/Target/ARM/ARMISelDAGToDAG.cpp
@@ -4200,6 +4200,92 @@ void ARMDAGToDAGISel::Select(SDNode *N) {
     // Other cases are autogenerated.
     break;
   }
+  case ARMISD::CTSELECT: {
+    EVT VT = N->getValueType(0);
+    unsigned PseudoOpcode;
+    bool IsFloat = false;
+    bool IsVector = false;
+
+    if (VT == MVT::f16) {
+      PseudoOpcode = ARM::CTSELECTf16;
+      IsFloat = true;
+    } else if (VT == MVT::bf16) {
+      PseudoOpcode = ARM::CTSELECTbf16;
+      IsFloat = true;
+    } else if (VT == MVT::f32) {
+      PseudoOpcode = ARM::CTSELECTf32;
+      IsFloat = true;
+    } else if (VT == MVT::f64) {
+      PseudoOpcode = ARM::CTSELECTf64;
+      IsVector = true;
+    } else if (VT == MVT::v8i8) {
+      PseudoOpcode = ARM::CTSELECTv8i8;
+      IsVector = true;
+    } else if (VT == MVT::v4i16) {
+      PseudoOpcode = ARM::CTSELECTv4i16;
+      IsVector = true;
+    } else if (VT == MVT::v2i32) {
+      PseudoOpcode = ARM::CTSELECTv2i32;
+      IsVector = true;
+    } else if (VT == MVT::v1i64) {
+      PseudoOpcode = ARM::CTSELECTv1i64;
+      IsVector = true;
+    } else if (VT == MVT::v2f32) {
+      PseudoOpcode = ARM::CTSELECTv2f32;
+      IsVector = true;
+    } else if (VT == MVT::v4f16) {
+      PseudoOpcode = ARM::CTSELECTv4f16;
+      IsVector = true;
+    } else if (VT == MVT::v4bf16) {
+      PseudoOpcode = ARM::CTSELECTv4bf16;
+      IsVector = true;
+    } else if (VT == MVT::v16i8) {
+      PseudoOpcode = ARM::CTSELECTv16i8;
+      IsVector = true;
+    } else if (VT == MVT::v8i16) {
+      PseudoOpcode = ARM::CTSELECTv8i16;
+      IsVector = true;
+    } else if (VT == MVT::v4i32) {
+      PseudoOpcode = ARM::CTSELECTv4i32;
+      IsVector = true;
+    } else if (VT == MVT::v2i64) {
+      PseudoOpcode = ARM::CTSELECTv2i64;
+      IsVector = true;
+    } else if (VT == MVT::v4f32) {
+      PseudoOpcode = ARM::CTSELECTv4f32;
+      IsVector = true;
+    } else if (VT == MVT::v2f64) {
+      PseudoOpcode = ARM::CTSELECTv2f64;
+      IsVector = true;
+    } else if (VT == MVT::v8f16) {
+      PseudoOpcode = ARM::CTSELECTv8f16;
+      IsVector = true;
+    } else if (VT == MVT::v8bf16) {
+      PseudoOpcode = ARM::CTSELECTv8bf16;
+      IsVector = true;
+    } else {
+      // i1, i8, i16, i32, i64
+      PseudoOpcode = ARM::CTSELECTint;
+    }
+
+    SmallVector<EVT, 4> VTs;
+    VTs.push_back(VT);       // $dst
+    VTs.push_back(MVT::i32); // $tmp_mask (always GPR)
+
+    if (IsVector) {
+      VTs.push_back(VT); // $bcast_mask (same type as dst for vectors)
+    } else if (IsFloat) {
+      VTs.push_back(MVT::i32); // $scratch1 (GPR)
+      VTs.push_back(MVT::i32); // $scratch2 (GPR)
+    }
+
+    // src1, src2, cond
+    SDValue Ops[] = {N->getOperand(0), N->getOperand(1), N->getOperand(2)};
+
+    SDNode *ResNode = CurDAG->getMachineNode(PseudoOpcode, SDLoc(N), VTs, Ops);
+    ReplaceNode(N, ResNode);
+    return;
+  }
   case ARMISD::VZIP: {
     EVT VT = N->getValueType(0);
     // vzip.32 Dd, Dm is a pseudo-instruction expanded to vtrn.32 Dd, Dm.
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 6b0653457cbaf..63005f1c9f989 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -203,6 +203,7 @@ void ARMTargetLowering::addTypeForNEON(MVT VT, MVT PromotedLdStVT) {
   setOperationAction(ISD::SELECT,            VT, Expand);
   setOperationAction(ISD::SELECT_CC,         VT, Expand);
   setOperationAction(ISD::VSELECT,           VT, Expand);
+  setOperationAction(ISD::CTSELECT, VT, Custom);
   setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Expand);
   if (VT.isInteger()) {
     setOperationAction(ISD::SHL, VT, Custom);
@@ -304,6 +305,7 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) {
     setOperationAction(ISD::CTPOP, VT, Expand);
     setOperationAction(ISD::SELECT, VT, Expand);
     setOperationAction(ISD::SELECT_CC, VT, Expand);
+    setOperationAction(ISD::CTSELECT, VT, Custom);
 
     // Vector reductions
     setOperationAction(ISD::VECREDUCE_ADD, VT, Legal);
@@ -355,6 +357,7 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) {
     setOperationAction(ISD::MSTORE, VT, Legal);
     setOperationAction(ISD::SELECT, VT, Expand);
     setOperationAction(ISD::SELECT_CC, VT, Expand);
+    setOperationAction(ISD::CTSELECT, VT, Custom);
 
     // Pre and Post inc are supported on loads and stores
     for (unsigned im = (unsigned)ISD::PRE_INC;
@@ -408,6 +411,28 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) {
   setOperationAction(ISD::VECREDUCE_FMIN, MVT::v2f16, Custom);
   setOperationAction(ISD::VECREDUCE_FMAX, MVT::v2f16, Custom);
 
+  if (Subtarget->hasFullFP16()) {
+    setOperationAction(ISD::CTSELECT, MVT::v4f16, Custom);
+    setOperationAction(ISD::CTSELECT, MVT::v8f16, Custom);
+  }
+
+  if (Subtarget->hasBF16()) {
+    setOperationAction(ISD::CTSELECT, MVT::v4bf16, Custom);
+    setOperationAction(ISD::CTSELECT, MVT::v8bf16, Custom);
+  }
+
+  // small exotic vectors get scalarised for ctselect
+  setOperationAction(ISD::CTSELECT, MVT::v1i8, Expand);
+  setOperationAction(ISD::CTSELECT, MVT::v1i16, Expand);
+  setOperationAction(ISD::CTSELECT, MVT::v1i32, Expand);
+  setOperationAction(ISD::CTSELECT, MVT::v1f32, Expand);
+  setOperationAction(ISD::CTSELECT, MVT::v2i8, Expand);
+
+  setOperationAction(ISD::CTSELECT, MVT::v2i16, Promote);
+  setOperationPromotedToType(ISD::CTSELECT, MVT::v2i16, MVT::v4i16);
+  setOperationAction(IS...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/166707
_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to