Author: Joshua Batista
Date: 2026-01-20T10:08:26-08:00
New Revision: 11b18362822759ac1592cee5b857943fa2320f8c

URL: 
https://github.com/llvm/llvm-project/commit/11b18362822759ac1592cee5b857943fa2320f8c
DIFF: 
https://github.com/llvm/llvm-project/commit/11b18362822759ac1592cee5b857943fa2320f8c.diff

LOG: [HLSL] Handle WaveActiveBallot struct return type appropriately (#175105)

The previous WaveActiveBallot implementation did not account for the
fact that the DXC implementation of the intrinsic returns a struct of
four uints rather than a vector of four uints. This must be respected;
otherwise, the validator will reject uses of WaveActiveBallot that
return a vector of four uints.
This PR updates the return type of the DirectX intrinsic and adds the
DXC-specific struct type `dx.types.fouri32` for the intrinsic to use.

Added: 
    llvm/test/tools/dxil-dis/waveactiveballot.ll

Modified: 
    clang/lib/CodeGen/CGHLSLBuiltins.cpp
    clang/lib/CodeGen/CGHLSLRuntime.h
    clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl
    clang/test/CodeGenSPIRV/Builtins/subgroup.c
    clang/test/Headers/gpuintrin.c
    llvm/include/llvm/IR/IntrinsicsDirectX.td
    llvm/include/llvm/IR/IntrinsicsSPIRV.td
    llvm/lib/Target/DirectX/DXIL.td
    llvm/lib/Target/DirectX/DXILOpBuilder.cpp
    llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
    llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
    llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBallot.ll

Removed: 
    


################################################################################
diff  --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp 
b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 1b6c3714f7821..75995ff940bc4 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -160,6 +160,42 @@ static Value *handleHlslSplitdouble(const CallExpr *E, 
CodeGenFunction *CGF) {
   return LastInst;
 }
 
+static Value *handleHlslWaveActiveBallot(CodeGenFunction &CGF,
+                                         const CallExpr *E) {
+  Value *Cond = CGF.EmitScalarExpr(E->getArg(0));
+  llvm::Type *I32 = CGF.Int32Ty;
+
+  llvm::Type *Vec4I32 = llvm::FixedVectorType::get(I32, 4);
+  llvm::StructType *Struct4I32 =
+      llvm::StructType::get(CGF.getLLVMContext(), {I32, I32, I32, I32});
+
+  if (CGF.CGM.getTarget().getTriple().isDXIL()) {
+    // Call DXIL intrinsic: returns { i32, i32, i32, i32 }
+    llvm::Function *Fn = CGF.CGM.getIntrinsic(Intrinsic::dx_wave_ballot, 
{I32});
+
+    Value *StructVal = CGF.EmitRuntimeCall(Fn, Cond);
+    assert(StructVal->getType() == Struct4I32 &&
+           "dx.wave.ballot must return {i32,i32,i32,i32}");
+
+    // Reassemble struct to <4 x i32>
+    llvm::Value *VecVal = llvm::PoisonValue::get(Vec4I32);
+    for (unsigned I = 0; I < 4; ++I) {
+      Value *Elt = CGF.Builder.CreateExtractValue(StructVal, I);
+      VecVal =
+          CGF.Builder.CreateInsertElement(VecVal, Elt, 
CGF.Builder.getInt32(I));
+    }
+
+    return VecVal;
+  }
+
+  if (CGF.CGM.getTarget().getTriple().isSPIRV())
+    return CGF.EmitRuntimeCall(
+        CGF.CGM.getIntrinsic(Intrinsic::spv_subgroup_ballot), Cond);
+
+  llvm_unreachable(
+      "WaveActiveBallot is only supported for DXIL and SPIRV targets");
+}
+
 static Value *handleElementwiseF16ToF32(CodeGenFunction &CGF,
                                         const CallExpr *E) {
   Value *Op0 = CGF.EmitScalarExpr(E->getArg(0));
@@ -834,9 +870,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned 
BuiltinID,
     assert(Op->getType()->isIntegerTy(1) &&
            "Intrinsic WaveActiveBallot operand must be a bool");
 
-    Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveBallotIntrinsic();
-    return EmitRuntimeCall(
-        Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
+    return handleHlslWaveActiveBallot(*this, E);
   }
   case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
     Value *OpExpr = EmitScalarExpr(E->getArg(0));

diff  --git a/clang/lib/CodeGen/CGHLSLRuntime.h 
b/clang/lib/CodeGen/CGHLSLRuntime.h
index 7a5643052ed84..ba2ca2c358388 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -146,7 +146,6 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
-  GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveBallot, wave_ballot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveGetLaneCount, wave_get_lane_count)

diff  --git a/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl 
b/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl
index 61b077eb1fead..df2d854a64247 100644
--- a/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl
@@ -10,8 +10,18 @@
 // CHECK-LABEL: define {{.*}}test
 uint4 test(bool p1) {
   // CHECK-SPIRV: %[[#entry_tok0:]] = call token 
@llvm.experimental.convergence.entry()
-  // CHECK-SPIRV:  %[[RET:.*]] = call spir_func <4 x i32> 
@llvm.spv.wave.ballot(i1 %{{[a-zA-Z0-9]+}}) [ "convergencectrl"(token 
%[[#entry_tok0]]) ]
-  // CHECK-DXIL:  %[[RET:.*]] = call <4 x i32> @llvm.dx.wave.ballot(i1 
%{{[a-zA-Z0-9]+}})
-  // CHECK:  ret <4 x i32> %[[RET]]
+  // CHECK-SPIRV: %[[SPIRVRET:.*]] = call spir_func <4 x i32> 
@llvm.spv.subgroup.ballot(i1 %{{[a-zA-Z0-9]+}}) [ "convergencectrl"(token 
%[[#entry_tok0]]) ]
+  // CHECK-DXIL: %[[WAB:.*]] = call { i32, i32, i32, i32 } 
@llvm.dx.wave.ballot.i32(i1 %{{[a-zA-Z0-9]+}})
+  // CHECK-DXIL-NEXT: extractvalue { i32, i32, i32, i32 } {{.*}} 0
+  // CHECK-DXIL-NEXT: insertelement <4 x i32> poison, i32 {{.*}}, i32 0
+  // CHECK-DXIL-NEXT: extractvalue { i32, i32, i32, i32 } {{.*}} 1
+  // CHECK-DXIL-NEXT: insertelement <4 x i32> {{.*}}, i32 {{.*}}, i32 1
+  // CHECK-DXIL-NEXT: extractvalue { i32, i32, i32, i32 } {{.*}} 2
+  // CHECK-DXIL-NEXT: insertelement <4 x i32> {{.*}}, i32 {{.*}}, i32 2
+  // CHECK-DXIL-NEXT: extractvalue { i32, i32, i32, i32 } {{.*}} 3
+  // CHECK-DXIL-NEXT: %[[DXILRET:.*]] = insertelement <4 x i32> {{.*}}, i32 
{{.*}}, i32 3
+  // CHECK-DXIL-NEXT: ret <4 x i32> %[[DXILRET]]
+  // CHECK-SPIRV: ret <4 x i32> %[[SPIRVRET]]
+
   return WaveActiveBallot(p1);
 }

diff  --git a/clang/test/CodeGenSPIRV/Builtins/subgroup.c 
b/clang/test/CodeGenSPIRV/Builtins/subgroup.c
index 78d41b7933f1f..ba6b48e3f3848 100644
--- a/clang/test/CodeGenSPIRV/Builtins/subgroup.c
+++ b/clang/test/CodeGenSPIRV/Builtins/subgroup.c
@@ -9,7 +9,7 @@ typedef unsigned __attribute__((ext_vector_type(4))) int4;
 
 // CHECK: @{{.*}}test_subgroup_ballot{{.*}}(
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    tail call <4 x i32> @llvm.spv.wave.ballot(i1 %i)
+// CHECK-NEXT:    tail call <4 x i32> @llvm.spv.subgroup.ballot(i1 %i)
 [[clang::sycl_external]] int4 test_subgroup_ballot(_Bool i) {
     return __builtin_spirv_subgroup_ballot(i);
 }

diff  --git a/clang/test/Headers/gpuintrin.c b/clang/test/Headers/gpuintrin.c
index c8fe721c8c37c..891a5abf7a72a 100644
--- a/clang/test/Headers/gpuintrin.c
+++ b/clang/test/Headers/gpuintrin.c
@@ -1267,7 +1267,7 @@ __gpu_kernel void foo() {
 // SPIRV-NEXT:  [[ENTRY:.*:]]
 // SPIRV-NEXT:    [[__MASK:%.*]] = alloca <4 x i32>, align 16
 // SPIRV-NEXT:    [[REF_TMP:%.*]] = alloca <2 x i32>, align 8
-// SPIRV-NEXT:    [[TMP0:%.*]] = call <4 x i32> @llvm.spv.wave.ballot(i1 true)
+// SPIRV-NEXT:    [[TMP0:%.*]] = call <4 x i32> @llvm.spv.subgroup.ballot(i1 
true)
 // SPIRV-NEXT:    store <4 x i32> [[TMP0]], ptr [[__MASK]], align 16
 // SPIRV-NEXT:    [[TMP1:%.*]] = load <4 x i32>, ptr [[__MASK]], align 16
 // SPIRV-NEXT:    [[TMP2:%.*]] = load <4 x i32>, ptr [[__MASK]], align 16
@@ -1335,7 +1335,7 @@ __gpu_kernel void foo() {
 // SPIRV-NEXT:    store i8 [[STOREDV]], ptr [[__X_ADDR]], align 1
 // SPIRV-NEXT:    [[TMP0:%.*]] = load i8, ptr [[__X_ADDR]], align 1
 // SPIRV-NEXT:    [[LOADEDV:%.*]] = trunc i8 [[TMP0]] to i1
-// SPIRV-NEXT:    [[TMP1:%.*]] = call <4 x i32> @llvm.spv.wave.ballot(i1 
[[LOADEDV]])
+// SPIRV-NEXT:    [[TMP1:%.*]] = call <4 x i32> @llvm.spv.subgroup.ballot(i1 
[[LOADEDV]])
 // SPIRV-NEXT:    store <4 x i32> [[TMP1]], ptr [[__MASK]], align 16
 // SPIRV-NEXT:    [[TMP2:%.*]] = load i64, ptr [[__LANE_MASK_ADDR]], align 8
 // SPIRV-NEXT:    [[TMP3:%.*]] = load <4 x i32>, ptr [[__MASK]], align 16

diff  --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td 
b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 6e6eb2d0ece9d..f79945785566c 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -153,7 +153,7 @@ def int_dx_rsqrt  : 
DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]
 def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], 
[llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], 
[IntrConvergent, IntrNoMem]>;
 def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], 
[IntrConvergent, IntrNoMem]>;
-def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_v4i32_ty], [llvm_i1_ty], 
[IntrConvergent, IntrNoMem]>;
+def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_anyint_ty, 
LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [llvm_i1_ty], 
[IntrConvergent, IntrNoMem]>;
 def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], 
[IntrConvergent, IntrNoMem]>;
 def int_dx_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], 
[LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_anyint_ty], 
[LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;

diff  --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td 
b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index d782d4f5fae0b..293cb750cea98 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -120,7 +120,7 @@ def int_spv_rsqrt : 
DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
   def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], 
[llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], 
[IntrConvergent, IntrNoMem]>;
   def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], 
[IntrConvergent, IntrNoMem]>;
-  def int_spv_wave_ballot : ClangBuiltin<"__builtin_spirv_subgroup_ballot">,
+  def int_spv_subgroup_ballot : 
ClangBuiltin<"__builtin_spirv_subgroup_ballot">,
     DefaultAttrsIntrinsic<[llvm_v4i32_ty], [llvm_i1_ty], [IntrConvergent, 
IntrNoMem]>;
   def int_spv_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_any_ty], 
[LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], 
[LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;

diff  --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 6d04732d92ecf..3a40d2c36139d 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -58,6 +58,7 @@ def ResPropsTy : DXILOpParamType;
 def SplitDoubleTy : DXILOpParamType;
 def BinaryWithCarryTy : DXILOpParamType;
 def DimensionsTy : DXILOpParamType;
+def Fouri32s : DXILOpParamType;
 
 class DXILOpClass;
 
@@ -212,13 +213,12 @@ defset list<DXILOpClass> OpClasses = {
   def unpack4x8 : DXILOpClass;
   def viewID : DXILOpClass;
   def waveActiveAllEqual : DXILOpClass;
-  def waveActiveBallot : DXILOpClass;
   def waveActiveBit : DXILOpClass;
   def waveActiveOp : DXILOpClass;
   def waveAllOp : DXILOpClass;
   def waveAllTrue : DXILOpClass;
   def waveAnyTrue : DXILOpClass;
-  def waveBallot : DXILOpClass;
+  def waveActiveBallot : DXILOpClass;
   def waveGetLaneCount : DXILOpClass;
   def waveGetLaneIndex : DXILOpClass;
   def waveIsFirstLane : DXILOpClass;
@@ -1062,6 +1062,14 @@ def WaveActiveAllTrue : DXILOp<114, waveAllTrue> {
   let stages = [Stages<DXIL1_0, [all_stages]>];
 }
 
+def WaveActiveBallot : DXILOp<116, waveActiveBallot> {
+  let Doc = "returns uint4 containing a bitmask of the evaluation of the 
boolean expression for all active lanes in the current wave.";
+  let intrinsics = [IntrinSelect<int_dx_wave_ballot>];
+  let arguments = [Int1Ty];
+  let result = Fouri32s;
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+}
+
 def WaveReadLaneAt : DXILOp<117, waveReadLaneAt> {
   let Doc = "returns the value from the specified lane";
   let intrinsics = [IntrinSelect<int_dx_wave_readlane>];
@@ -1072,14 +1080,6 @@ def WaveReadLaneAt : DXILOp<117, waveReadLaneAt> {
   let stages = [Stages<DXIL1_0, [all_stages]>];
 }
 
-def WaveActiveBallot : DXILOp<118, waveBallot> {
-  let Doc = "returns uint4 containing a bitmask of the evaluation of the 
boolean expression for all active lanes in the current wave.";
-  let intrinsics = [IntrinSelect<int_dx_wave_ballot>];
-  let arguments = [Int1Ty];
-  let result = OverloadTy;
-  let stages = [Stages<DXIL1_0, [all_stages]>];
-}
-
 def WaveActiveOp : DXILOp<119, waveActiveOp> {
   let Doc = "returns the result of the operation across waves";
   let intrinsics = [

diff  --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp 
b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 944b2e6433988..1f41d2457e5bc 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -261,10 +261,18 @@ static StructType *getBinaryWithCarryType(LLVMContext 
&Context) {
   return StructType::create({Int32Ty, Int1Ty}, "dx.types.i32c");
 }
 
-static StructType *getDimensionsType(LLVMContext &Ctx) {
-  Type *Int32Ty = Type::getInt32Ty(Ctx);
+static StructType *getDimensionsType(LLVMContext &Context) {
+  Type *Int32Ty = Type::getInt32Ty(Context);
   return getOrCreateStructType("dx.types.Dimensions",
-                               {Int32Ty, Int32Ty, Int32Ty, Int32Ty}, Ctx);
+                               {Int32Ty, Int32Ty, Int32Ty, Int32Ty}, Context);
+}
+
+static StructType *getFouri32sType(LLVMContext &Context) {
+  if (auto *ST = StructType::getTypeByName(Context, "dx.types.fouri32"))
+    return ST;
+  Type *Int32Ty = Type::getInt32Ty(Context);
+  return getOrCreateStructType("dx.types.fouri32",
+                               {Int32Ty, Int32Ty, Int32Ty, Int32Ty}, Context);
 }
 
 static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
@@ -326,7 +334,10 @@ static Type *getTypeFromOpParamType(OpParamType Kind, 
LLVMContext &Ctx,
     return getBinaryWithCarryType(Ctx);
   case OpParamType::DimensionsTy:
     return getDimensionsType(Ctx);
+  case OpParamType::Fouri32s:
+    return getFouri32sType(Ctx);
   }
+
   llvm_unreachable("Invalid parameter kind");
   return nullptr;
 }

diff  --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp 
b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 98b5bfd678135..626393d4ecb40 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -3815,7 +3815,7 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register 
ResVReg,
     return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAll);
   case Intrinsic::spv_wave_any:
     return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAny);
-  case Intrinsic::spv_wave_ballot:
+  case Intrinsic::spv_subgroup_ballot:
     return selectWaveOpInst(ResVReg, ResType, I,
                             SPIRV::OpGroupNonUniformBallot);
   case Intrinsic::spv_wave_is_first_lane:

diff  --git a/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll 
b/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
index cf6255de3a734..f0440cb4e6183 100644
--- a/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
+++ b/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
@@ -1,10 +1,37 @@
-; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | 
FileCheck %s
-
-define noundef <4 x i32> @wave_ballot_simple(i1 noundef %p1) {
-entry:
-; CHECK: call <4 x i32> @dx.op.waveBallot.void(i32 118, i1 %p1)
-  %ret = call <4 x i32> @llvm.dx.wave.ballot(i1 %p1)
-  ret <4 x i32> %ret
-}
-
-declare <4 x i32> @llvm.dx.wave.ballot(i1)
+; RUN: opt -S -scalarizer -dxil-op-lower %s | FileCheck %s
+
+target datalayout = 
"e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64-v48:16:16-v96:32:32-v192:64:64"
+target triple = "dxilv1.3-pc-shadermodel6.3-compute"
+
+; The definition of the custom type should be added
+; CHECK: %dx.types.fouri32 = type { i32, i32, i32, i32 }
+
+; Function Attrs: alwaysinline convergent mustprogress norecurse nounwind
+define hidden noundef <4 x i32> @_Z4testb(i1 noundef %p1) {
+entry:
+  %p1.addr = alloca i32, align 4
+  %storedv = zext i1 %p1 to i32
+  store i32 %storedv, ptr %p1.addr, align 4
+  %0 = load i32, ptr %p1.addr, align 4
+  %loadedv = trunc i32 %0 to i1
+  %1 = load i32, ptr %p1.addr, align 4
+  %loadedv1 = trunc i32 %1 to i1
+
+  ; CHECK: call %dx.types.fouri32 @dx.op.waveActiveBallot(i32 116, i1 
%loadedv1)
+
+  %2 = call { i32, i32, i32, i32 } @llvm.dx.wave.ballot.i32(i1 %loadedv1)
+  %3 = extractvalue { i32, i32, i32, i32 } %2, 0
+  %4 = insertelement <4 x i32> poison, i32 %3, i32 0
+  %5 = extractvalue { i32, i32, i32, i32 } %2, 1
+  %6 = insertelement <4 x i32> %4, i32 %5, i32 1
+  %7 = extractvalue { i32, i32, i32, i32 } %2, 2
+  %8 = insertelement <4 x i32> %6, i32 %7, i32 2
+  %9 = extractvalue { i32, i32, i32, i32 } %2, 3
+  %10 = insertelement <4 x i32> %8, i32 %9, i32 3
+
+  ; CHECK-NOT: ret %dx.types.fouri32
+  ; CHECK: ret <4 x i32>
+  ret <4 x i32> %10
+}
+
+declare { i32, i32, i32, i32 } @llvm.dx.wave.ballot.i32(i1)

diff  --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBallot.ll 
b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBallot.ll
index 6831888f038fd..e38d77360631b 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBallot.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBallot.ll
@@ -13,10 +13,10 @@ entry:
 ; CHECK: %[[#param:]] = OpFunctionParameter %[[#bool]]
 ; CHECK: %{{.+}} = OpGroupNonUniformBallot %[[#bitmask]] %[[#scope]] 
%[[#param]]
   %0 = call token @llvm.experimental.convergence.entry()
-  %ret = call <4 x i32> @llvm.spv.wave.ballot(i1 %p1) [ 
"convergencectrl"(token %0) ]
+  %ret = call <4 x i32> @llvm.spv.subgroup.ballot(i1 %p1) [ 
"convergencectrl"(token %0) ]
   ret <4 x i32> %ret
 }
 
-declare <4 x i32> @llvm.spv.wave.ballot(i1) #0
+declare <4 x i32> @llvm.spv.subgroup.ballot(i1) #0
 
 attributes #0 = { convergent }

diff  --git a/llvm/test/tools/dxil-dis/waveactiveballot.ll 
b/llvm/test/tools/dxil-dis/waveactiveballot.ll
new file mode 100644
index 0000000000000..2bdb4ec98a3db
--- /dev/null
+++ b/llvm/test/tools/dxil-dis/waveactiveballot.ll
@@ -0,0 +1,31 @@
+; RUN: llc %s --filetype=obj -o - | dxil-dis -o - | FileCheck %s
+
+; CHECK-NOT: llvm.dx.wave.ballot
+
+; CHECK: call %dx.types.fouri32 @dx.op.waveActiveBallot(i32 116, i1 %p1)
+; CHECK-NOT: ret %dx.types.fouri32
+; CHECK: ret <4 x i32>
+
+
+target triple = "dxil-unknown-shadermodel6.3-library"
+
+%dx.types.fouri32 = type { i32, i32, i32, i32 }
+
+define <4 x i32> @wave_ballot_simple(i1 %p1) {
+entry:
+  %s = call %dx.types.fouri32 @llvm.dx.wave.ballot(i1 %p1)
+
+  %v0 = extractvalue %dx.types.fouri32 %s, 0
+  %v1 = extractvalue %dx.types.fouri32 %s, 1
+  %v2 = extractvalue %dx.types.fouri32 %s, 2
+  %v3 = extractvalue %dx.types.fouri32 %s, 3
+
+  %vec0 = insertelement <4 x i32> poison, i32 %v0, i32 0
+  %vec1 = insertelement <4 x i32> %vec0, i32 %v1, i32 1
+  %vec2 = insertelement <4 x i32> %vec1, i32 %v2, i32 2
+  %vec3 = insertelement <4 x i32> %vec2, i32 %v3, i32 3
+
+  ret <4 x i32> %vec3
+}
+
+declare %dx.types.fouri32 @llvm.dx.wave.ballot(i1)


        
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to