https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/113382
>From 35731658c1769453f86dde6063b137a2c5aeca32 Mon Sep 17 00:00:00 2001 From: Finn Plummer <canadienf...@gmail.com> Date: Fri, 18 Oct 2024 15:48:29 -0700 Subject: [PATCH 1/4] [DXIL][SPIRV] Lower WaveActiveCountBits intrinsic - add codegen for llvm builtin to spirv/directx intrinsic in CGBuiltin.cpp - add lowering of spirv intrinsic to spirv backend in SPIRVInstructionSelector.cpp - add lowering of directx intrinsic to dxil op in DXIL.td - add test cases to illustrate passes - add test case for semantic analysis --- clang/lib/CodeGen/CGBuiltin.cpp | 7 ++++ clang/lib/CodeGen/CGHLSLRuntime.h | 1 + .../builtins/WaveActiveCountBits.hlsl | 22 +++++++++++ .../BuiltIns/WaveActiveCountBits-errors.hlsl | 18 +++++++++ llvm/include/llvm/IR/IntrinsicsDirectX.td | 1 + llvm/include/llvm/IR/IntrinsicsSPIRV.td | 1 + llvm/lib/Target/DirectX/DXIL.td | 9 +++++ .../Target/SPIRV/SPIRVInstructionSelector.cpp | 37 +++++++++++++++++++ .../CodeGen/DirectX/WaveActiveCountBits.ll | 10 +++++ .../hlsl-intrinsics/WaveActiveCountBits.ll | 19 ++++++++++ 10 files changed, 125 insertions(+) create mode 100755 clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl create mode 100644 clang/test/SemaHLSL/BuiltIns/WaveActiveCountBits-errors.hlsl create mode 100644 llvm/test/CodeGen/DirectX/WaveActiveCountBits.ll create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveCountBits.ll diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index 0ef9058640db6a..db6b8f80195691 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -19056,6 +19056,13 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: { /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(), ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step"); } + case Builtin::BI__builtin_hlsl_wave_active_count_bits: { + Value *OpExpr = EmitScalarExpr(E->getArg(0)); + Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic(); + return EmitRuntimeCall( + 
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), + ArrayRef{OpExpr}); + } case Builtin::BI__builtin_hlsl_wave_get_lane_index: { // We don't define a SPIR-V intrinsic, instead it is a SPIR-V built-in // defined in SPIRVBuiltins.td. So instead we manually get the matching name diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h index caf8777fd95a9f..167cc04baf159f 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.h +++ b/clang/lib/CodeGen/CGHLSLRuntime.h @@ -91,6 +91,7 @@ class CGHLSLRuntime { GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot) GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed) GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed) + GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits) GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane) GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane) GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitUHigh, firstbituhigh) diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl new file mode 100755 index 00000000000000..3e1f8fcaace9c2 --- /dev/null +++ b/clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl @@ -0,0 +1,22 @@ +// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \ +// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \ +// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL +// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \ +// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \ +// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV + +// Test basic lowering to runtime function call. 
+ +// CHECK-LABEL: test_bool +int test_bool(bool expr) { + // CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry() + // CHECK-SPIRV: %[[RET:.*]] = call spir_func i32 @llvm.spv.wave.active.countbits(i1 %{{.*}}) [ "convergencectrl"(token %[[#entry_tok]]) ] + // CHECK-DXIL: %[[RET:.*]] = call i32 @llvm.dx.wave.active.countbits(i1 %{{.*}}) + // CHECK: ret i32 %[[RET]] + return WaveActiveCountBits(expr); +} + +// CHECK-DXIL: declare i32 @llvm.dx.wave.active.countbits(i1) #[[#attr:]] +// CHECK-SPIRV: declare i32 @llvm.spv.wave.active.countbits(i1) #[[#attr:]] + +// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}} diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveCountBits-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveCountBits-errors.hlsl new file mode 100644 index 00000000000000..02f45eb30b377a --- /dev/null +++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveCountBits-errors.hlsl @@ -0,0 +1,18 @@ +// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify + +int test_too_few_arg() { + return __builtin_hlsl_wave_active_count_bits(); + // expected-error@-1 {{too few arguments to function call, expected 1, have 0}} +} + +int test_too_many_arg(bool x) { + return __builtin_hlsl_wave_active_count_bits(x, x); + // expected-error@-1 {{too many arguments to function call, expected 1, have 2}} +} + +struct S { float f; }; + +int test_bad_conversion(S x) { + return __builtin_hlsl_wave_active_count_bits(x); + // expected-error@-1 {{no viable conversion from 'S' to 'bool'}} +} diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index 43267033f024a7..191dc8ad208f93 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -86,6 +86,7 @@ def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], 
[LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>; def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>; def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>; +def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>; def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>; def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>; diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td index e93d6fa83de61b..b9b1e6ab89ddcc 100644 --- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td +++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td @@ -85,6 +85,7 @@ let TargetPrefix = "spv" in { [IntrNoMem, Commutative] >; def int_spv_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>; def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>; + def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>; def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>; def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>; diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 1aabff90e5ec6e..b8de926f4be017 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -873,3 +873,12 @@ def WaveGetLaneIndex : DXILOp<111, waveGetLaneIndex> { let stages = 
[Stages<DXIL1_0, [all_stages]>]; let attributes = [Attributes<DXIL1_0, [ReadNone]>]; } + +def WaveAllBitCount : DXILOp<135, waveAllOp> { + let Doc = "returns the count of bits set to 1 across the wave"; + let LLVMIntrinsic = int_dx_wave_active_countbits; + let arguments = [Int1Ty]; + let result = Int32Ty; + let stages = [Stages<DXIL1_0, [all_stages]>]; + let attributes = [Attributes<DXIL1_0, [ReadNone]>]; +} diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 414583aea91e64..e5cbc82fc5ab95 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -256,6 +256,9 @@ class SPIRVInstructionSelector : public InstructionSelector { bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; + bool selectWaveActiveCountBits(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + bool selectWaveReadLaneAt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; @@ -1917,6 +1920,38 @@ bool SPIRVInstructionSelector::selectSign(Register ResVReg, return Result; } +bool SPIRVInstructionSelector::selectWaveActiveCountBits(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + assert(I.getNumOperands() == 3); + assert(I.getOperand(2).isReg()); + MachineBasicBlock &BB = *I.getParent(); + + Register BallotReg = MRI->createVirtualRegister(&SPIRV::IDRegClass); + SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII); + SPIRVType *BallotType = GR.getOrCreateSPIRVVectorType(IntTy, 4, I, TII); + + bool Result = + BuildMI(BB, I, I.getDebugLoc(), + TII.get(SPIRV::OpGroupNonUniformBallot)) + .addDef(BallotReg) + .addUse(GR.getSPIRVTypeID(BallotType)) + .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII)) + .addUse(I.getOperand(2).getReg()); + + Result |= + BuildMI(BB, I, I.getDebugLoc(), + TII.get(SPIRV::OpGroupNonUniformBallotBitCount)) + .addDef(ResVReg) 
+ .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII)) + .addImm(0) + .addUse(BallotReg) + .constrainAllUses(TII, TRI, RBI); + + return Result; +} + bool SPIRVInstructionSelector::selectWaveReadLaneAt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const { @@ -2739,6 +2774,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, } break; case Intrinsic::spv_saturate: return selectSaturate(ResVReg, ResType, I); + case Intrinsic::spv_wave_active_countbits: + return selectWaveActiveCountBits(ResVReg, ResType, I); case Intrinsic::spv_wave_is_first_lane: { SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII); return BuildMI(BB, I, I.getDebugLoc(), diff --git a/llvm/test/CodeGen/DirectX/WaveActiveCountBits.ll b/llvm/test/CodeGen/DirectX/WaveActiveCountBits.ll new file mode 100644 index 00000000000000..5d321372433198 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/WaveActiveCountBits.ll @@ -0,0 +1,10 @@ +; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s + +define void @main(i1 %expr) { +entry: +; CHECK: call i32 @dx.op.waveAllOp(i32 135, i1 %expr) + %0 = call i32 @llvm.dx.wave.active.countbits(i1 %expr) + ret void +} + +declare i32 @llvm.dx.wave.active.countbits(i1) diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveCountBits.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveCountBits.ll new file mode 100644 index 00000000000000..29944054111ac0 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveCountBits.ll @@ -0,0 +1,19 @@ +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#ballot_type:]] = OpTypeVector %[[#uint]] 4 +; CHECK-DAG: %[[#bool:]] = OpTypeBool +; CHECK-DAG: %[[#scope:]] = OpConstant 
%[[#uint]] 3 +; CHECK-LABEL: Begin function test_fun +; CHECK: %[[#bexpr:]] = OpFunctionParameter %[[#bool]] +define i32 @test_fun(i1 %expr) { +entry: +; CHECK: %[[#ballot:]] = OpGroupNonUniformBallot %[[#ballot_type]] %[[#scope]] %[[#bexpr]] +; CHECK: %[[#ret:]] = OpGroupNonUniformBallotBitCount %[[#uint]] %[[#scope]] Reduce %[[#ballot]] + %0 = call i32 @llvm.spv.wave.active.countbits(i1 %expr) + ret i32 %0 +} + +declare i32 @llvm.spv.wave.active.countbits(i1) >From e3683d55c8dc64bb57e45a8f772332310fc37b5d Mon Sep 17 00:00:00 2001 From: Finn Plummer <finnplum...@microsoft.com> Date: Tue, 22 Oct 2024 21:06:50 +0000 Subject: [PATCH 2/4] clang-format --- .../Target/SPIRV/SPIRVInstructionSelector.cpp | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index e5cbc82fc5ab95..85a425aa4ae025 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -257,7 +257,7 @@ class SPIRVInstructionSelector : public InstructionSelector { MachineInstr &I) const; bool selectWaveActiveCountBits(Register ResVReg, const SPIRVType *ResType, - MachineInstr &I) const; + MachineInstr &I) const; bool selectWaveReadLaneAt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; @@ -1920,9 +1920,8 @@ bool SPIRVInstructionSelector::selectSign(Register ResVReg, return Result; } -bool SPIRVInstructionSelector::selectWaveActiveCountBits(Register ResVReg, - const SPIRVType *ResType, - MachineInstr &I) const { +bool SPIRVInstructionSelector::selectWaveActiveCountBits( + Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const { assert(I.getNumOperands() == 3); assert(I.getOperand(2).isReg()); MachineBasicBlock &BB = *I.getParent(); @@ -1932,22 +1931,21 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(Register ResVReg, SPIRVType *BallotType =
GR.getOrCreateSPIRVVectorType(IntTy, 4, I, TII); bool Result = - BuildMI(BB, I, I.getDebugLoc(), - TII.get(SPIRV::OpGroupNonUniformBallot)) + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpGroupNonUniformBallot)) .addDef(BallotReg) .addUse(GR.getSPIRVTypeID(BallotType)) .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII)) .addUse(I.getOperand(2).getReg()); Result |= - BuildMI(BB, I, I.getDebugLoc(), - TII.get(SPIRV::OpGroupNonUniformBallotBitCount)) - .addDef(ResVReg) - .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII)) - .addImm(0) - .addUse(BallotReg) - .constrainAllUses(TII, TRI, RBI); + BuildMI(BB, I, I.getDebugLoc(), + TII.get(SPIRV::OpGroupNonUniformBallotBitCount)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII)) + .addImm(0) + .addUse(BallotReg) + .constrainAllUses(TII, TRI, RBI); return Result; } >From 3180f0fcd9afa5951c8f49078a8bf70dac318652 Mon Sep 17 00:00:00 2001 From: Finn Plummer <canadienf...@gmail.com> Date: Thu, 24 Oct 2024 22:07:50 +0000 Subject: [PATCH 3/4] review comments: - add constrainAllUses to first spirv op - update testcase for ease of reading - use enum instead of int equivalent for documentation --- .../CodeGenHLSL/builtins/WaveActiveCountBits.hlsl | 12 ++++-------- llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 5 +++-- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl index 3e1f8fcaace9c2..086dd295ba938d 100755 --- a/clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl +++ b/clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl @@ -1,22 +1,18 @@ // RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \ // RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \ -// RUN: FileCheck %s 
--check-prefixes=CHECK,CHECK-DXIL +// RUN: FileCheck %s -DTARGET=dx // RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \ // RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \ -// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV +// RUN: FileCheck %s -DTARGET=spv // Test basic lowering to runtime function call. // CHECK-LABEL: test_bool int test_bool(bool expr) { - // CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry() - // CHECK-SPIRV: %[[RET:.*]] = call spir_func i32 @llvm.spv.wave.active.countbits(i1 %{{.*}}) [ "convergencectrl"(token %[[#entry_tok]]) ] - // CHECK-DXIL: %[[RET:.*]] = call i32 @llvm.dx.wave.active.countbits(i1 %{{.*}}) - // CHECK: ret i32 %[[RET]] + // CHECK: call {{.*}} @llvm.[[TARGET]].wave.active.countbits return WaveActiveCountBits(expr); } -// CHECK-DXIL: declare i32 @llvm.dx.wave.active.countbits(i1) #[[#attr:]] -// CHECK-SPIRV: declare i32 @llvm.spv.wave.active.countbits(i1) #[[#attr:]] +// CHECK: declare i32 @llvm.[[TARGET]].wave.active.countbits(i1) #[[#attr:]] // CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}} diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 85a425aa4ae025..04dff2a5a08b6b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -1935,7 +1935,8 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits( .addDef(BallotReg) .addUse(GR.getSPIRVTypeID(BallotType)) .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII)) - .addUse(I.getOperand(2).getReg()); + .addUse(I.getOperand(2).getReg()) + .constrainAllUses(TII, TRI, RBI); Result |= BuildMI(BB, I, I.getDebugLoc(), @@ -1943,7 +1944,7 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits( .addDef(ResVReg) .addUse(GR.getSPIRVTypeID(ResType)) .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII)) - .addImm(0) + 
.addImm(SPIRV::GroupOperation::Reduce) .addUse(BallotReg) .constrainAllUses(TII, TRI, RBI); >From cff8387169d07ae082af71e92511e43f5d092144 Mon Sep 17 00:00:00 2001 From: Finn Plummer <canadienf...@gmail.com> Date: Thu, 7 Nov 2024 22:26:03 +0000 Subject: [PATCH 4/4] review comments: - create the ballot register with the register class of its SPIR-V type - combine the two selection results with &= instead of |= --- llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 04dff2a5a08b6b..c17bbfc60b3954 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -1926,9 +1926,9 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits( assert(I.getOperand(2).isReg()); MachineBasicBlock &BB = *I.getParent(); - Register BallotReg = MRI->createVirtualRegister(&SPIRV::IDRegClass); SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII); SPIRVType *BallotType = GR.getOrCreateSPIRVVectorType(IntTy, 4, I, TII); + Register BallotReg = MRI->createVirtualRegister(GR.getRegClass(BallotType)); bool Result = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpGroupNonUniformBallot)) @@ -1938,7 +1938,7 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits( .addUse(I.getOperand(2).getReg()) .constrainAllUses(TII, TRI, RBI); - Result |= + Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpGroupNonUniformBallotBitCount)) .addDef(ResVReg) _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits