https://github.com/jhuber6 created https://github.com/llvm/llvm-project/pull/174862
Summary: This patch adds an LLVM intrinsic and lowering for a subgroup ballot and a corresponding clang builtin. This uses the already present support but provides it in a way that is accessible to other targets. With this and https://github.com/llvm/llvm-project/pull/174655 we should be able to implement most of the basic functions, like shuffling, active masks, and reductions. More work will be needed for canonicalizing / exposing the SPIR-V functions, but these are the fundamental builtins I need. >From 01417cba0d54d0765220cf9058d6b19180582f76 Mon Sep 17 00:00:00 2001 From: Joseph Huber <[email protected]> Date: Wed, 7 Jan 2026 15:00:01 -0600 Subject: [PATCH] [SPIR-V] Add builtin/intrinsic for subgroup ballot Summary: This patch adds an LLVM intrinsic and lowering for a subgroup ballot and a corresponding clang builtin. This uses the already present support but provides it in a way that is accessible to other targets. With this and https://github.com/llvm/llvm-project/pull/174655 we should be able to implement most of the basic functions, like shuffling, active masks, and reductions. More work will be needed for canonicalizing / exposing the SPIR-V functions, but these are the fundamental builtins I need. 
--- clang/include/clang/Basic/BuiltinsSPIRVCommon.td | 2 ++ clang/test/CodeGenSPIRV/Builtins/subgroup.c | 15 +++++++++++++++ clang/test/SemaSPIRV/BuiltIns/subgroup-errors.c | 13 +++++++++++++ llvm/include/llvm/IR/IntrinsicsSPIRV.td | 2 ++ .../Target/SPIRV/SPIRVInstructionSelector.cpp | 15 +++++++++++++++ llvm/test/CodeGen/SPIRV/ballot.ll | 16 ++++++++++++++++ 6 files changed, 63 insertions(+) create mode 100644 clang/test/CodeGenSPIRV/Builtins/subgroup.c create mode 100644 clang/test/SemaSPIRV/BuiltIns/subgroup-errors.c create mode 100644 llvm/test/CodeGen/SPIRV/ballot.ll diff --git a/clang/include/clang/Basic/BuiltinsSPIRVCommon.td b/clang/include/clang/Basic/BuiltinsSPIRVCommon.td index d2ef6f99a0502..495851ed1727a 100644 --- a/clang/include/clang/Basic/BuiltinsSPIRVCommon.td +++ b/clang/include/clang/Basic/BuiltinsSPIRVCommon.td @@ -21,3 +21,5 @@ def subgroup_local_invocation_id : SPIRVBuiltin<"uint32_t()", [NoThrow, Const]>; def distance : SPIRVBuiltin<"void(...)", [NoThrow, Const]>; def length : SPIRVBuiltin<"void(...)", [NoThrow, Const]>; def smoothstep : SPIRVBuiltin<"void(...)", [NoThrow, Const, CustomTypeChecking]>; + +def subgroup_ballot : SPIRVBuiltin<"_ExtVector<4, uint32_t>(bool)", [NoThrow, Const]>; diff --git a/clang/test/CodeGenSPIRV/Builtins/subgroup.c b/clang/test/CodeGenSPIRV/Builtins/subgroup.c new file mode 100644 index 0000000000000..2ae2013c3c23e --- /dev/null +++ b/clang/test/CodeGenSPIRV/Builtins/subgroup.c @@ -0,0 +1,15 @@ +// RUN: %clang_cc1 -O1 -triple spirv64 -fsycl-is-device -x c++ %s -emit-llvm -o - | FileCheck %s --check-prefixes=CHECK +// RUN: %clang_cc1 -O1 -triple spirv64 -cl-std=CL3.0 -x cl %s -emit-llvm -o - | FileCheck %s --check-prefixes=CHECK +// RUN: %clang_cc1 -O1 -triple spirv32 -cl-std=CL3.0 -x cl %s -emit-llvm -o - | FileCheck %s --check-prefixes=CHECK + +#if defined(__cplusplus) +typedef bool _Bool; +#endif +typedef unsigned __attribute__((ext_vector_type(4))) int4; + +// CHECK: @{{.*}}test_subgroup_ballot{{.*}}( +// 
CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: tail call <4 x i32> @llvm.spv.wave.ballot(i1 %i) +[[clang::sycl_external]] int4 test_subgroup_ballot(_Bool i) { + return __builtin_spirv_subgroup_ballot(i); +} diff --git a/clang/test/SemaSPIRV/BuiltIns/subgroup-errors.c b/clang/test/SemaSPIRV/BuiltIns/subgroup-errors.c new file mode 100644 index 0000000000000..5ef9f499efd31 --- /dev/null +++ b/clang/test/SemaSPIRV/BuiltIns/subgroup-errors.c @@ -0,0 +1,13 @@ +// RUN: %clang_cc1 -O1 -Wno-unused-value -triple spirv64 -fsycl-is-device -verify %s -o - +// RUN: %clang_cc1 -O1 -Wno-unused-value -triple spirv64 -verify %s -cl-std=CL3.0 -x cl -o - +// RUN: %clang_cc1 -O1 -Wno-unused-value -triple spirv32 -verify %s -cl-std=CL3.0 -x cl -o - + +typedef unsigned __attribute__((ext_vector_type(4))) int4; + +void ballot(_Bool c) { + int4 x; + x = __builtin_spirv_subgroup_ballot(c); + x = __builtin_spirv_subgroup_ballot(1); + x = __builtin_spirv_subgroup_ballot(x); // expected-error{{parameter of incompatible type}} + int y = __builtin_spirv_subgroup_ballot(c); // expected-error{{with an expression of incompatible type}} +} diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td index 402235ec7cd9c..51e4151c2fdae 100644 --- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td +++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td @@ -118,6 +118,8 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty] def int_spv_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>; def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>; def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; + def int_spv_wave_ballot : ClangBuiltin<"__builtin_spirv_subgroup_ballot">, + DefaultAttrsIntrinsic<[llvm_v4i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_spv_wave_all : 
DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_spv_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index f991938c14dfe..1918f5701e3b7 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -307,6 +307,8 @@ class SPIRVInstructionSelector : public InstructionSelector { bool selectWaveOpInst(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, unsigned Opcode) const; + bool selectWaveBallot(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; bool selectWaveActiveCountBits(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; @@ -2710,6 +2712,17 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits( return Result; } +bool SPIRVInstructionSelector::selectWaveBallot(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + // The intrinsic's result type is already <4 x i32>, which is the uvec4 + // mask OpGroupNonUniformBallot produces, so select straight into ResVReg. + // Nothing else here defines ResVReg, so routing the ballot through a + // scratch register would leave the intrinsic's result undefined. + return selectWaveOpInst(ResVReg, ResType, I, + SPIRV::OpGroupNonUniformBallot); +} + bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, @@ -3797,6 +3810,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, return selectExtInst(ResVReg, ResType, I, CL::u_clamp, GL::UClamp); case Intrinsic::spv_sclamp: return selectExtInst(ResVReg, ResType, I, CL::s_clamp, GL::SClamp); + case Intrinsic::spv_wave_ballot: + return selectWaveBallot(ResVReg, ResType, I); case Intrinsic::spv_wave_active_countbits: return 
selectWaveActiveCountBits(ResVReg, ResType, I); case Intrinsic::spv_wave_all: diff --git a/llvm/test/CodeGen/SPIRV/ballot.ll b/llvm/test/CodeGen/SPIRV/ballot.ll new file mode 100644 index 0000000000000..3ca1a243feaea --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/ballot.ll @@ -0,0 +1,17 @@ +; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#ballot_type:]] = OpTypeVector %[[#uint]] 4 +; CHECK-DAG: %[[#bool:]] = OpTypeBool +; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3 + +; CHECK-LABEL: Begin function test_fun +; CHECK: %[[#bexpr:]] = OpFunctionParameter %[[#bool]] +define <4 x i32> @test_fun(i1 %expr) { +entry: +; CHECK: %[[#ballot:]] = OpGroupNonUniformBallot %[[#ballot_type]] %[[#scope]] %[[#bexpr]] +; CHECK: OpReturnValue %[[#ballot]] + %0 = call <4 x i32> @llvm.spv.wave.ballot(i1 %expr) + ret <4 x i32> %0 +} _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
