https://github.com/jhuber6 created 
https://github.com/llvm/llvm-project/pull/174862

Summary:
This patch adds an LLVM intrinsic and lowering for a subgroup ballot and
a corresponding clang builtin. This uses the already present support but
provides it in a way that is accessible to other targets. With this and
https://github.com/llvm/llvm-project/pull/174655 we should be able to
implement most of the basic functions, like shuffling, active masks, and
reductions. More work will be needed for canonicalizing / exposing the
SPIR-V functions, but these are the fundamental builtins I need.


>From 01417cba0d54d0765220cf9058d6b19180582f76 Mon Sep 17 00:00:00 2001
From: Joseph Huber <[email protected]>
Date: Wed, 7 Jan 2026 15:00:01 -0600
Subject: [PATCH] [SPIR-V] Add builtin/intrinsic for subgroup ballot

Summary:
This patch adds an LLVM intrinsic and lowering for a subgroup ballot and
a corresponding clang builtin. This uses the already present support but
provides it in a way that is accessible to other targets. With this and
https://github.com/llvm/llvm-project/pull/174655 we should be able to
implement most of the basic functions, like shuffling, active masks, and
reductions. More work will be needed for canonicalizing / exposing the
SPIR-V functions, but these are the fundamental builtins I need.
---
 clang/include/clang/Basic/BuiltinsSPIRVCommon.td |  2 ++
 clang/test/CodeGenSPIRV/Builtins/subgroup.c      | 15 +++++++++++++++
 clang/test/SemaSPIRV/BuiltIns/subgroup-errors.c  | 13 +++++++++++++
 llvm/include/llvm/IR/IntrinsicsSPIRV.td          |  2 ++
 .../Target/SPIRV/SPIRVInstructionSelector.cpp    | 15 +++++++++++++++
 llvm/test/CodeGen/SPIRV/ballot.ll                | 16 ++++++++++++++++
 6 files changed, 63 insertions(+)
 create mode 100644 clang/test/CodeGenSPIRV/Builtins/subgroup.c
 create mode 100644 clang/test/SemaSPIRV/BuiltIns/subgroup-errors.c
 create mode 100644 llvm/test/CodeGen/SPIRV/ballot.ll

diff --git a/clang/include/clang/Basic/BuiltinsSPIRVCommon.td 
b/clang/include/clang/Basic/BuiltinsSPIRVCommon.td
index d2ef6f99a0502..495851ed1727a 100644
--- a/clang/include/clang/Basic/BuiltinsSPIRVCommon.td
+++ b/clang/include/clang/Basic/BuiltinsSPIRVCommon.td
@@ -21,3 +21,5 @@ def subgroup_local_invocation_id : SPIRVBuiltin<"uint32_t()", 
[NoThrow, Const]>;
 def distance : SPIRVBuiltin<"void(...)", [NoThrow, Const]>;
 def length : SPIRVBuiltin<"void(...)", [NoThrow, Const]>;
 def smoothstep : SPIRVBuiltin<"void(...)", [NoThrow, Const, 
CustomTypeChecking]>;
+
+def subgroup_ballot : SPIRVBuiltin<"_ExtVector<4, uint32_t>(bool)", [NoThrow, 
Const]>; // Lowered via llvm.spv.wave.ballot to SPIR-V OpGroupNonUniformBallot.
diff --git a/clang/test/CodeGenSPIRV/Builtins/subgroup.c 
b/clang/test/CodeGenSPIRV/Builtins/subgroup.c
new file mode 100644
index 0000000000000..2ae2013c3c23e
--- /dev/null
+++ b/clang/test/CodeGenSPIRV/Builtins/subgroup.c
@@ -0,0 +1,15 @@
+// RUN: %clang_cc1 -O1 -triple spirv64 -fsycl-is-device -x c++ %s -emit-llvm 
-o - | FileCheck %s --check-prefixes=CHECK
+// RUN: %clang_cc1 -O1 -triple spirv64 -cl-std=CL3.0 -x cl %s -emit-llvm -o - 
| FileCheck %s --check-prefixes=CHECK
+// RUN: %clang_cc1 -O1 -triple spirv32 -cl-std=CL3.0 -x cl %s -emit-llvm -o - 
| FileCheck %s --check-prefixes=CHECK
+
+#if defined(__cplusplus)
+typedef bool _Bool;
+#endif
+typedef unsigned __attribute__((ext_vector_type(4))) uint4;
+
+// CHECK: @{{.*}}test_subgroup_ballot{{.*}}(
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    tail call <4 x i32> @llvm.spv.wave.ballot(i1 %i)
+[[clang::sycl_external]] uint4 test_subgroup_ballot(_Bool i) {
+    return __builtin_spirv_subgroup_ballot(i);
+}
diff --git a/clang/test/SemaSPIRV/BuiltIns/subgroup-errors.c 
b/clang/test/SemaSPIRV/BuiltIns/subgroup-errors.c
new file mode 100644
index 0000000000000..5ef9f499efd31
--- /dev/null
+++ b/clang/test/SemaSPIRV/BuiltIns/subgroup-errors.c
@@ -0,0 +1,13 @@
+// RUN: %clang_cc1 -O1 -Wno-unused-value -triple spirv64 -fsycl-is-device 
-verify %s -o -
+// RUN: %clang_cc1 -O1 -Wno-unused-value -triple spirv64 -verify %s 
-cl-std=CL3.0 -x cl -o -
+// RUN: %clang_cc1 -O1 -Wno-unused-value -triple spirv32 -verify %s 
-cl-std=CL3.0 -x cl -o -
+
+typedef unsigned __attribute__((ext_vector_type(4))) uint4;
+
+void ballot(_Bool c) {
+  uint4 x;
+  x = __builtin_spirv_subgroup_ballot(c);
+  x = __builtin_spirv_subgroup_ballot(1);
+  x = __builtin_spirv_subgroup_ballot(x); // expected-error{{parameter of 
incompatible type}}
+  int y = __builtin_spirv_subgroup_ballot(c); // expected-error{{with an 
expression of incompatible type}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td 
b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 402235ec7cd9c..51e4151c2fdae 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -118,6 +118,8 @@ def int_spv_rsqrt : 
DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
   def int_spv_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], 
[llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
   def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], 
[llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
   def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], 
[llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
+  def int_spv_wave_ballot : ClangBuiltin<"__builtin_spirv_subgroup_ballot">,
+    DefaultAttrsIntrinsic<[llvm_v4i32_ty], [llvm_i1_ty], [IntrConvergent, 
IntrNoMem]>; // Ballot of an i1 predicate; selected to OpGroupNonUniformBallot.
   def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], 
[IntrConvergent, IntrNoMem]>;
   def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], 
[IntrConvergent, IntrNoMem]>;
   def int_spv_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_any_ty], 
[LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp 
b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index f991938c14dfe..1918f5701e3b7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -307,6 +307,8 @@ class SPIRVInstructionSelector : public InstructionSelector 
{
   bool selectWaveOpInst(Register ResVReg, const SPIRVType *ResType,
                         MachineInstr &I, unsigned Opcode) const;
 
+  bool selectWaveBallot(Register ResVReg, const SPIRVType *ResType,
+                        MachineInstr &I) const;
   bool selectWaveActiveCountBits(Register ResVReg, const SPIRVType *ResType,
                                  MachineInstr &I) const;
 
@@ -2710,6 +2712,17 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
   return Result;
 }
 
+bool SPIRVInstructionSelector::selectWaveBallot(Register ResVReg,
+                                                const SPIRVType *ResType,
+                                                MachineInstr &I) const {
+  // The ballot mask *is* the intrinsic's <4 x i32> result, so it must be
+  // emitted directly into ResVReg with ResType. Selecting into a scratch
+  // vreg instead would leave ResVReg without a definition, and every user
+  // of the call would read an undefined value.
+  return selectWaveOpInst(ResVReg, ResType, I,
+                          SPIRV::OpGroupNonUniformBallot);
+}
+
 bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg,
                                                    const SPIRVType *ResType,
                                                    MachineInstr &I,
@@ -3797,6 +3810,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register 
ResVReg,
     return selectExtInst(ResVReg, ResType, I, CL::u_clamp, GL::UClamp);
   case Intrinsic::spv_sclamp:
     return selectExtInst(ResVReg, ResType, I, CL::s_clamp, GL::SClamp);
+  case Intrinsic::spv_wave_ballot:
+    return selectWaveBallot(ResVReg, ResType, I);
   case Intrinsic::spv_wave_active_countbits:
     return selectWaveActiveCountBits(ResVReg, ResType, I);
   case Intrinsic::spv_wave_all:
diff --git a/llvm/test/CodeGen/SPIRV/ballot.ll 
b/llvm/test/CodeGen/SPIRV/ballot.ll
new file mode 100644
index 0000000000000..3ca1a243feaea
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/ballot.ll
@@ -0,0 +1,17 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - 
-filetype=obj | spirv-val %}
+
+; CHECK-DAG:   %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG:   %[[#ballot_type:]] = OpTypeVector %[[#uint]] 4
+; CHECK-DAG:   %[[#bool:]] = OpTypeBool
+; CHECK-DAG:   %[[#scope:]] = OpConstant %[[#uint]] 3
+
+; CHECK-LABEL: Begin function test_fun
+; CHECK:   %[[#bexpr:]] = OpFunctionParameter %[[#bool]]
+define <4 x i32> @test_fun(i1 %expr) {
+entry:
+; CHECK:   %[[#ballot:]] = OpGroupNonUniformBallot %[[#ballot_type]] 
%[[#scope]] %[[#bexpr]]
+  %0 = call <4 x i32> @llvm.spv.wave.ballot(i1 %expr)
+; CHECK:   OpReturnValue %[[#ballot]]
+  ret <4 x i32> %0
+}

_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to