llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clangir

Author: Jiahao Guo (E00N777)

<details>
<summary>Changes</summary>

Part of https://github.com/llvm/llvm-project/issues/185382.

Lower:
- [vdup_n_bf16](https://developer.arm.com/architectures/instruction-sets/intrinsics/vdup_n_bf16)
- [vdupq_n_bf16](https://developer.arm.com/architectures/instruction-sets/intrinsics/vdupq_n_bf16)
- [vdup_lane_bf16](https://developer.arm.com/architectures/instruction-sets/intrinsics/vdup_lane_bf16)
- [vdupq_lane_bf16](https://developer.arm.com/architectures/instruction-sets/intrinsics/vdupq_lane_bf16)
- [vdup_laneq_bf16](https://developer.arm.com/architectures/instruction-sets/intrinsics/vdup_laneq_bf16)
- [vdupq_laneq_bf16](https://developer.arm.com/architectures/instruction-sets/intrinsics/vdupq_laneq_bf16)

and add tests in [bf16-getset.c](https://github.com/llvm/llvm-project/blob/main/clang/test/CodeGen/AArch64/neon/bf16-getset.c).

## Approach

### `vdup_n_bf16` / `vdupq_n_bf16`

These are not NEON builtins: they are regular `always_inline` functions defined in `arm_neon.h` that expand to vector aggregate initialization (`{v, v, v, v}`), so they work through the existing generic vector codegen path without requiring any builtin-specific handling. I just added CHECK lines in `bf16-getset.c` to verify the existing output is correct.

### `vdup_lane_bf16` / `vdupq_lane_bf16` / `vdup_laneq_bf16` / `vdupq_laneq_bf16`

These are mapped (via `NEONEquivalentIntrinsicMap`) to the generic `splat_lane_v` / `splatq_lane_v` / `splat_laneq_v` / `splatq_laneq_v` builtins, which are handled in `emitCommonNeonBuiltinExpr`.
I followed the approach used in both the OG codegen (`ARM.cpp`) and the [clangir incubator](https://github.com/nicovank/clangir):

**OG codegen in `ARM.cpp`:**

```cpp
switch (BuiltinID) {
default:
  break;
case NEON::BI__builtin_neon_splat_lane_v:
case NEON::BI__builtin_neon_splat_laneq_v:
case NEON::BI__builtin_neon_splatq_lane_v:
case NEON::BI__builtin_neon_splatq_laneq_v: {
  auto NumElements = VTy->getElementCount();
  if (BuiltinID == NEON::BI__builtin_neon_splatq_lane_v)
    NumElements = NumElements * 2;
  if (BuiltinID == NEON::BI__builtin_neon_splat_laneq_v)
    NumElements = NumElements.divideCoefficientBy(2);
  Ops[0] = Builder.CreateBitCast(Ops[0], VTy);
  return EmitNeonSplat(Ops[0], cast<ConstantInt>(Ops[1]), NumElements);
}
```

**clangir incubator in `CIRGenBuiltinAArch64.cpp`:**

```cpp
case NEON::BI__builtin_neon_splat_lane_v:
case NEON::BI__builtin_neon_splat_laneq_v:
case NEON::BI__builtin_neon_splatq_lane_v:
case NEON::BI__builtin_neon_splatq_laneq_v: {
  uint64_t numElements = vTy.getSize();
  if (builtinID == NEON::BI__builtin_neon_splatq_lane_v)
    numElements = numElements << 1;
  if (builtinID == NEON::BI__builtin_neon_splat_laneq_v)
    numElements = numElements >> 1;
  ops[0] = builder.createBitcast(ops[0], vTy);
  return emitNeonSplat(builder, getLoc(e->getExprLoc()), ops[0], ops[1],
                       numElements);
}
```

The call site for `splat_lane_v` already existed in `emitCommonNeonBuiltinExpr`, but had two issues:

1. **`emitNeonSplat` was called but never defined.** I added two helper functions (ported from the clangir incubator): `getIntValueFromConstOp` to extract the integer lane index from a CIR constant, and `emitNeonSplat` to build a splat shuffle mask and perform a `cir.vec.shuffle`.
2. **The call site used `getLoc(e->getExprLoc())`, which is invalid** because `emitCommonNeonBuiltinExpr` is a static free function, not a `CIRGenFunction` member. Fixed to use `cgf.getBuilder()` and the pre-computed `loc` variable.

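
As a rough illustration of the splat lowering described above, here is a minimal standalone sketch in plain C++. It does not use the CIR or MLIR APIs: `SplatKind`, `resultLanes`, `splatMask`, and `shuffle` are hypothetical names introduced only for this example, and lanes are modelled as `uint16_t` values in a `std::vector`. It only mirrors the two steps in the snippets: adjust the result lane count for the `splatq_lane` and `splat_laneq` variants, then build a mask that repeats the chosen lane index.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for the four generic splat builtins.
enum class SplatKind { Lane, LaneQ, QLane, QLaneQ };

// Result lane count, given the lane count of the source vector type.
// Mirrors the adjustment in the snippets: splatq_lane doubles the count,
// splat_laneq halves it, and the other two variants keep it unchanged.
std::uint64_t resultLanes(SplatKind kind, std::uint64_t sourceLanes) {
  if (kind == SplatKind::QLane)
    return sourceLanes << 1;
  if (kind == SplatKind::LaneQ)
    return sourceLanes >> 1;
  return sourceLanes;
}

// A splat mask repeats the selected lane index once per result lane.
std::vector<std::int64_t> splatMask(std::uint64_t count, std::int64_t lane) {
  return std::vector<std::int64_t>(count, lane);
}

// Plain-vector model of a single-input shuffle: out[i] = in[mask[i]].
// at() throws std::out_of_range if a mask entry is not a valid lane.
std::vector<std::uint16_t> shuffle(const std::vector<std::uint16_t> &in,
                                   const std::vector<std::int64_t> &mask) {
  std::vector<std::uint16_t> out;
  out.reserve(mask.size());
  for (std::int64_t idx : mask)
    out.push_back(in.at(static_cast<std::size_t>(idx)));
  return out;
}
```

With a 4-lane source and lane index 1, the `QLane` case gives an 8-entry mask of ones, which is the same shape as the shuffle mask the `test_vdupq_lane_bf16` test in this PR expects.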
Additionally, I found that `NeonTypeFlags::BFloat16` and `NeonTypeFlags::Float16` were unhandled in `getNeonType`, which would cause the vector type to be unresolved for bf16/f16 intrinsics. I added the handling following the same pattern as the OG codegen:

```cpp
case NeonTypeFlags::BFloat16:
  if (allowBFloatArgsAndRet)
    return cir::VectorType::get(cgf->getCIRGenModule().bFloat16Ty, v1Ty ? 1 : (4 << isQuad));
  return cir::VectorType::get(cgf->uInt16Ty, v1Ty ? 1 : (4 << isQuad));
case NeonTypeFlags::Float16:
  if (hasLegalHalfType)
    return cir::VectorType::get(cgf->getCIRGenModule().fP16Ty, v1Ty ? 1 : (4 << isQuad));
  return cir::VectorType::get(cgf->uInt16Ty, v1Ty ? 1 : (4 << isQuad));
```

When `allowBFloatArgsAndRet` is true, we use the native `cir::BF16Type`; otherwise we fall back to `u16i`. The same logic applies to `Float16` with `hasLegalHalfType`.

---

Full diff: https://github.com/llvm/llvm-project/pull/187460.diff

3 Files Affected:

- (modified) clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp (+25-9)
- (modified) clang/test/CodeGen/AArch64/bf16-getset-intrinsics.c (-76)
- (modified) clang/test/CodeGen/AArch64/neon/bf16-getset.c (+47)

``````````diff
diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp
index 5d7b8d839fa84..2119d37c68a64 100644
--- a/clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp
@@ -796,16 +796,12 @@ static cir::VectorType getNeonType(CIRGenFunction *cgf, NeonTypeFlags typeFlags,
                                 v1Ty ? 1 : (4 << isQuad));
   case NeonTypeFlags::BFloat16:
     if (allowBFloatArgsAndRet)
-      cgf->getCIRGenModule().errorNYI(loc, std::string("NEON type: BFloat16"));
-    else
-      cgf->getCIRGenModule().errorNYI(loc, std::string("NEON type: BFloat16"));
-    [[fallthrough]];
+      return cir::VectorType::get(cgf->getCIRGenModule().bFloat16Ty, v1Ty ? 1 : (4 << isQuad));
+    return cir::VectorType::get(cgf->uInt16Ty, v1Ty ? 1 : (4 << isQuad));
   case NeonTypeFlags::Float16:
     if (hasLegalHalfType)
-      cgf->getCIRGenModule().errorNYI(loc, std::string("NEON type: Float16"));
-    else
-      cgf->getCIRGenModule().errorNYI(loc, std::string("NEON type: Float16"));
-    [[fallthrough]];
+      return cir::VectorType::get(cgf->getCIRGenModule().fP16Ty, v1Ty ? 1 : (4 << isQuad));
+    return cir::VectorType::get(cgf->uInt16Ty, v1Ty ? 1 : (4 << isQuad));
   case NeonTypeFlags::Int32:
     return cir::VectorType::get(typeFlags.isUnsigned() ? cgf->uInt32Ty
                                                        : cgf->sInt32Ty,
@@ -831,6 +827,18 @@ static cir::VectorType getNeonType(CIRGenFunction *cgf, NeonTypeFlags typeFlags,
   llvm_unreachable("Unknown vector element type!");
 }
 
+static int64_t getIntValueFromConstOp(mlir::Value val) {
+  return val.getDefiningOp<cir::ConstantOp>().getIntValue().getSExtValue();
+}
+
+static mlir::Value emitNeonSplat(CIRGenBuilderTy &builder, mlir::Location loc,
+                                 mlir::Value splatVec, mlir::Value splatLane,
+                                 unsigned int splatCnt) {
+  int64_t splatValInt = getIntValueFromConstOp(splatLane);
+  llvm::SmallVector<int64_t, 4> splatMask(splatCnt, splatValInt);
+  return builder.createVecShuffle(loc, splatVec, splatMask);
+}
+
 static mlir::Value emitCommonNeonBuiltinExpr(
     CIRGenFunction &cgf, unsigned builtinID, unsigned llvmIntrinsic,
     unsigned altLLVMIntrinsic, const char *nameHint, unsigned modifier,
@@ -867,7 +875,15 @@ static mlir::Value emitCommonNeonBuiltinExpr(
   case NEON::BI__builtin_neon_splat_lane_v:
   case NEON::BI__builtin_neon_splat_laneq_v:
   case NEON::BI__builtin_neon_splatq_lane_v:
-  case NEON::BI__builtin_neon_splatq_laneq_v:
+  case NEON::BI__builtin_neon_splatq_laneq_v: {
+    uint64_t numElements = vTy.getSize();
+    if (builtinID == NEON::BI__builtin_neon_splatq_lane_v)
+      numElements = numElements << 1;
+    if (builtinID == NEON::BI__builtin_neon_splat_laneq_v)
+      numElements = numElements >> 1;
+    ops[0] = cgf.getBuilder().createBitcast(loc, ops[0], vTy);
+    return emitNeonSplat(cgf.getBuilder(), loc, ops[0], ops[1], numElements);
+  }
   case NEON::BI__builtin_neon_vpadd_v:
   case NEON::BI__builtin_neon_vpaddq_v:
   case NEON::BI__builtin_neon_vabs_v:
diff --git a/clang/test/CodeGen/AArch64/bf16-getset-intrinsics.c b/clang/test/CodeGen/AArch64/bf16-getset-intrinsics.c
index 55eb5210829d2..69171902c7e69 100644
--- a/clang/test/CodeGen/AArch64/bf16-getset-intrinsics.c
+++ b/clang/test/CodeGen/AArch64/bf16-getset-intrinsics.c
@@ -14,82 +14,6 @@ bfloat16x4_t test_vcreate_bf16(uint64_t a) {
   return vcreate_bf16(a);
 }
 
-// CHECK-LABEL: @test_vdup_n_bf16(
-// CHECK-NEXT:  entry:
-// CHECK-NEXT:    [[VECINIT_I:%.*]] = insertelement <4 x bfloat> poison, bfloat [[V:%.*]], i32 0
-// CHECK-NEXT:    [[VECINIT1_I:%.*]] = insertelement <4 x bfloat> [[VECINIT_I]], bfloat [[V]], i32 1
-// CHECK-NEXT:    [[VECINIT2_I:%.*]] = insertelement <4 x bfloat> [[VECINIT1_I]], bfloat [[V]], i32 2
-// CHECK-NEXT:    [[VECINIT3_I:%.*]] = insertelement <4 x bfloat> [[VECINIT2_I]], bfloat [[V]], i32 3
-// CHECK-NEXT:    ret <4 x bfloat> [[VECINIT3_I]]
-//
-bfloat16x4_t test_vdup_n_bf16(bfloat16_t v) {
-  return vdup_n_bf16(v);
-}
-
-// CHECK-LABEL: @test_vdupq_n_bf16(
-// CHECK-NEXT:  entry:
-// CHECK-NEXT:    [[VECINIT_I:%.*]] = insertelement <8 x bfloat> poison, bfloat [[V:%.*]], i32 0
-// CHECK-NEXT:    [[VECINIT1_I:%.*]] = insertelement <8 x bfloat> [[VECINIT_I]], bfloat [[V]], i32 1
-// CHECK-NEXT:    [[VECINIT2_I:%.*]] = insertelement <8 x bfloat> [[VECINIT1_I]], bfloat [[V]], i32 2
-// CHECK-NEXT:    [[VECINIT3_I:%.*]] = insertelement <8 x bfloat> [[VECINIT2_I]], bfloat [[V]], i32 3
-// CHECK-NEXT:    [[VECINIT4_I:%.*]] = insertelement <8 x bfloat> [[VECINIT3_I]], bfloat [[V]], i32 4
-// CHECK-NEXT:    [[VECINIT5_I:%.*]] = insertelement <8 x bfloat> [[VECINIT4_I]], bfloat [[V]], i32 5
-// CHECK-NEXT:    [[VECINIT6_I:%.*]] = insertelement <8 x bfloat> [[VECINIT5_I]], bfloat [[V]], i32 6
-// CHECK-NEXT:    [[VECINIT7_I:%.*]] = insertelement <8 x bfloat> [[VECINIT6_I]], bfloat [[V]], i32 7
-// CHECK-NEXT:    ret <8 x bfloat> [[VECINIT7_I]]
-//
-bfloat16x8_t test_vdupq_n_bf16(bfloat16_t v) {
-  return vdupq_n_bf16(v);
-}
-
-// CHECK-LABEL: @test_vdup_lane_bf16(
-// CHECK-NEXT:  entry:
-// CHECK-NEXT:    [[TMP0:%.*]] = bitcast <4 x bfloat> [[V:%.*]] to <4 x i16>
-// CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i16> [[TMP0]] to <8 x i8>
-// CHECK-NEXT:    [[TMP2:%.*]] = bitcast <8 x i8> [[TMP1]] to <4 x bfloat>
-// CHECK-NEXT:    [[LANE:%.*]] = shufflevector <4 x bfloat> [[TMP2]], <4 x bfloat> [[TMP2]], <4 x i32> <i32 1, i32 1, i32 1, i32 1>
-// CHECK-NEXT:    ret <4 x bfloat> [[LANE]]
-//
-bfloat16x4_t test_vdup_lane_bf16(bfloat16x4_t v) {
-  return vdup_lane_bf16(v, 1);
-}
-
-// CHECK-LABEL: @test_vdupq_lane_bf16(
-// CHECK-NEXT:  entry:
-// CHECK-NEXT:    [[TMP0:%.*]] = bitcast <4 x bfloat> [[V:%.*]] to <4 x i16>
-// CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i16> [[TMP0]] to <8 x i8>
-// CHECK-NEXT:    [[TMP2:%.*]] = bitcast <8 x i8> [[TMP1]] to <4 x bfloat>
-// CHECK-NEXT:    [[LANE:%.*]] = shufflevector <4 x bfloat> [[TMP2]], <4 x bfloat> [[TMP2]], <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>
-// CHECK-NEXT:    ret <8 x bfloat> [[LANE]]
-//
-bfloat16x8_t test_vdupq_lane_bf16(bfloat16x4_t v) {
-  return vdupq_lane_bf16(v, 1);
-}
-
-// CHECK-LABEL: @test_vdup_laneq_bf16(
-// CHECK-NEXT:  entry:
-// CHECK-NEXT:    [[TMP0:%.*]] = bitcast <8 x bfloat> [[V:%.*]] to <8 x i16>
-// CHECK-NEXT:    [[TMP1:%.*]] = bitcast <8 x i16> [[TMP0]] to <16 x i8>
-// CHECK-NEXT:    [[TMP2:%.*]] = bitcast <16 x i8> [[TMP1]] to <8 x bfloat>
-// CHECK-NEXT:    [[LANE:%.*]] = shufflevector <8 x bfloat> [[TMP2]], <8 x bfloat> [[TMP2]], <4 x i32> <i32 7, i32 7, i32 7, i32 7>
-// CHECK-NEXT:    ret <4 x bfloat> [[LANE]]
-//
-bfloat16x4_t test_vdup_laneq_bf16(bfloat16x8_t v) {
-  return vdup_laneq_bf16(v, 7);
-}
-
-// CHECK-LABEL: @test_vdupq_laneq_bf16(
-// CHECK-NEXT:  entry:
-// CHECK-NEXT:    [[TMP0:%.*]] = bitcast <8 x bfloat> [[V:%.*]] to <8 x i16>
-// CHECK-NEXT:    [[TMP1:%.*]] = bitcast <8 x i16> [[TMP0]] to <16 x i8>
-// CHECK-NEXT:    [[TMP2:%.*]] = bitcast <16 x i8> [[TMP1]] to <8 x bfloat>
-// CHECK-NEXT:    [[LANE:%.*]] = shufflevector <8 x bfloat> [[TMP2]], <8 x bfloat> [[TMP2]], <8 x i32> <i32 7, i32 7, i32 7, i32 7, i32 7, i32 7, i32 7, i32 7>
-// CHECK-NEXT:    ret <8 x bfloat> [[LANE]]
-//
-bfloat16x8_t test_vdupq_laneq_bf16(bfloat16x8_t v) {
-  return vdupq_laneq_bf16(v, 7);
-}
-
 // CHECK-LABEL: @test_vcombine_bf16(
 // CHECK-NEXT:  entry:
 // CHECK-NEXT:    [[SHUFFLE_I:%.*]] = shufflevector <4 x bfloat> [[LOW:%.*]], <4 x bfloat> [[HIGH:%.*]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
diff --git a/clang/test/CodeGen/AArch64/neon/bf16-getset.c b/clang/test/CodeGen/AArch64/neon/bf16-getset.c
index faae31cb013dd..816ad11ae223a 100644
--- a/clang/test/CodeGen/AArch64/neon/bf16-getset.c
+++ b/clang/test/CodeGen/AArch64/neon/bf16-getset.c
@@ -34,3 +34,50 @@ bfloat16_t test_vduph_laneq_bf16(bfloat16x8_t v) {
   // LLVM: ret bfloat [[VGETQ_LANE]]
   return vduph_laneq_bf16(v, 7);
 }
+
+// ALL-LABEL: @test_vdup_lane_bf16(
+bfloat16x4_t test_vdup_lane_bf16(bfloat16x4_t v) {
+  // CIR: cir.vec.shuffle({{%.*}}, {{%.*}} : !cir.vector<4 x !u16i>) [#cir.int<1> : !s32i, #cir.int<1> : !s32i, #cir.int<1> : !s32i, #cir.int<1> : !s32i] : !cir.vector<4 x !u16i>
+  // LLVM: shufflevector <4 x {{.*}}> {{.*}}, <4 x {{.*}}> {{.*}}, <4 x i32> <i32 1, i32 1, i32 1, i32 1>
+  return vdup_lane_bf16(v, 1);
+}
+
+// ALL-LABEL: @test_vdupq_lane_bf16(
+bfloat16x8_t test_vdupq_lane_bf16(bfloat16x4_t v) {
+  // CIR: cir.vec.shuffle({{%.*}}, {{%.*}} : !cir.vector<4 x !u16i>) [#cir.int<1> : !s32i, #cir.int<1> : !s32i, #cir.int<1> : !s32i, #cir.int<1> : !s32i, #cir.int<1> : !s32i, #cir.int<1> : !s32i, #cir.int<1> : !s32i, #cir.int<1> : !s32i] : !cir.vector<8 x !u16i>
+  // LLVM: shufflevector <4 x {{.*}}> {{.*}}, <4 x {{.*}}> {{.*}}, <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>
+  return vdupq_lane_bf16(v, 1);
+}
+
+// ALL-LABEL: @test_vdup_laneq_bf16(
+bfloat16x4_t test_vdup_laneq_bf16(bfloat16x8_t v) {
+  // CIR: cir.vec.shuffle({{%.*}}, {{%.*}} : !cir.vector<8 x !u16i>) [#cir.int<7> : !s32i, #cir.int<7> : !s32i, #cir.int<7> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !u16i>
+  // LLVM: shufflevector <8 x {{.*}}> {{.*}}, <8 x {{.*}}> {{.*}}, <4 x i32> <i32 7, i32 7, i32 7, i32 7>
+  return vdup_laneq_bf16(v, 7);
+}
+
+// ALL-LABEL: @test_vdupq_laneq_bf16(
+bfloat16x8_t test_vdupq_laneq_bf16(bfloat16x8_t v) {
+  // CIR: cir.vec.shuffle({{%.*}}, {{%.*}} : !cir.vector<8 x !u16i>) [#cir.int<7> : !s32i, #cir.int<7> : !s32i, #cir.int<7> : !s32i, #cir.int<7> : !s32i, #cir.int<7> : !s32i, #cir.int<7> : !s32i, #cir.int<7> : !s32i, #cir.int<7> : !s32i] : !cir.vector<8 x !u16i>
+  // LLVM: shufflevector <8 x {{.*}}> {{.*}}, <8 x {{.*}}> {{.*}}, <8 x i32> <i32 7, i32 7, i32 7, i32 7, i32 7, i32 7, i32 7, i32 7>
+  return vdupq_laneq_bf16(v, 7);
+}
+
+// ALL-LABEL: @test_vdup_n_bf16(
+bfloat16x4_t test_vdup_n_bf16(bfloat16_t v) {
+  // CIR: cir.call @vdup_n_bf16
+  // LLVM: insertelement <4 x bfloat> poison, bfloat %{{.*}}, i{{32|64}} 0
+  // LLVM: insertelement <4 x bfloat> %{{.*}}, bfloat %{{.*}}, i{{32|64}} 3
+  // LLVM: ret <4 x bfloat>
+  return vdup_n_bf16(v);
+}
+
+// ALL-LABEL: @test_vdupq_n_bf16(
+bfloat16x8_t test_vdupq_n_bf16(bfloat16_t v) {
+  // CIR: cir.call @vdupq_n_bf16
+  // LLVM: insertelement <8 x bfloat> poison, bfloat %{{.*}}, i{{32|64}} 0
+  // LLVM: insertelement <8 x bfloat> %{{.*}}, bfloat %{{.*}}, i{{32|64}} 7
+  // LLVM: ret <8 x bfloat>
+  return vdupq_n_bf16(v);
+}
+
\ No newline at end of file
``````````

</details>

https://github.com/llvm/llvm-project/pull/187460
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
