Author: Brandon Wu
Date: 2026-06-22T10:23:43+08:00
New Revision: fc7bcd0ce864c631b7429f795249ea8cef6634a7

URL: 
https://github.com/llvm/llvm-project/commit/fc7bcd0ce864c631b7429f795249ea8cef6634a7
DIFF: 
https://github.com/llvm/llvm-project/commit/fc7bcd0ce864c631b7429f795249ea8cef6634a7.diff

LOG: [clang][RISCV] Handle VLS CC on unsupported primitive type in aggregate type (#203898)

We previously handled this for plain vector types but missed aggregate
types. This patch applies the same mechanism to them: vectors whose
element type is unsupported by the enabled extensions (e.g. bf16 without
zvfbfmin/zvfbfa) are converted to i8 vectors of the same size.

Added: 
    

Modified: 
    clang/lib/CodeGen/Targets/RISCV.cpp
    clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.c
    clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.cpp

Removed: 
    


################################################################################
diff  --git a/clang/lib/CodeGen/Targets/RISCV.cpp b/clang/lib/CodeGen/Targets/RISCV.cpp
index ffe1cc6086215..bc3be06d176bc 100644
--- a/clang/lib/CodeGen/Targets/RISCV.cpp
+++ b/clang/lib/CodeGen/Targets/RISCV.cpp
@@ -67,6 +67,11 @@ class RISCVABIInfo : public DefaultABIInfo {
                                                CharUnits Field2Off) const;
 
   ABIArgInfo coerceVLSVector(QualType Ty, unsigned ABIVLen = 0) const;
+  // Vectors whose element type is unsupported (e.g. bf16 without zvfbfmin or
+  // zvfbfa) must be passed as an i8 vector of the same size. This function
+  // checks for that case and returns the appropriate fixed vector type.
+  llvm::FixedVectorType *
+  getVLSCCCompatibleType(llvm::FixedVectorType *FixedVecTy) const;
 
   using ABIInfo::appendAttributeMangling;
   void appendAttributeMangling(TargetClonesAttr *Attr, unsigned Index,
@@ -495,10 +500,10 @@ llvm::Type *RISCVABIInfo::detectVLSCCEligibleStruct(QualType Ty,
   // Turn them into scalable vector type or vector tuple type if legal.
   if (NumElts == 1) {
     // Handle single fixed-length vector.
+    llvm::FixedVectorType *VLSTy = getVLSCCCompatibleType(FixedVecTy);
     return llvm::ScalableVectorType::get(
-        FixedVecTy->getElementType(),
-        llvm::divideCeil(FixedVecTy->getNumElements() *
-                             llvm::RISCV::RVVBitsPerBlock,
+        VLSTy->getElementType(),
+        llvm::divideCeil(VLSTy->getNumElements() * llvm::RISCV::RVVBitsPerBlock,
                          ABIVLen));
   }
 
@@ -520,6 +525,23 @@ llvm::Type *RISCVABIInfo::detectVLSCCEligibleStruct(QualType Ty,
       NumElts);
 }
 
+llvm::FixedVectorType *
+RISCVABIInfo::getVLSCCCompatibleType(llvm::FixedVectorType *FixedVecTy) const {
+  llvm::Type *EltType = FixedVecTy->getElementType();
+  const TargetInfo &TI = getContext().getTargetInfo();
+  if ((EltType->isHalfTy() && !TI.hasFeature("zvfhmin")) ||
+      (EltType->isBFloatTy() &&
+       !(TI.hasFeature("zvfbfmin") || TI.hasFeature("experimental-zvfbfa"))) ||
+      (EltType->isFloatTy() && !TI.hasFeature("zve32f")) ||
+      (EltType->isDoubleTy() && !TI.hasFeature("zve64d")) ||
+      (EltType->isIntegerTy(64) && !TI.hasFeature("zve64x")) ||
+      EltType->isIntegerTy(128))
+    return llvm::FixedVectorType::get(llvm::Type::getInt8Ty(getVMContext()),
+                                      FixedVecTy->getNumElements() *
+                                          EltType->getScalarSizeInBits() / 8);
+  return FixedVecTy;
+}
+
 // Fixed-length RVV vectors are represented as scalable vectors in function
 // args/return and must be coerced from fixed vectors.
 ABIArgInfo RISCVABIInfo::coerceVLSVector(QualType Ty, unsigned ABIVLen) const {
@@ -569,27 +591,12 @@ ABIArgInfo RISCVABIInfo::coerceVLSVector(QualType Ty, unsigned ABIVLen) const {
 
     // Generic vector
     // The number of elements needs to be at least 1.
+    llvm::FixedVectorType *VLSTy =
+        getVLSCCCompatibleType(llvm::FixedVectorType::get(EltType, NumElts));
     ResType = llvm::ScalableVectorType::get(
-        EltType,
-        llvm::divideCeil(NumElts * llvm::RISCV::RVVBitsPerBlock, ABIVLen));
-
-    // If the corresponding extension is not supported, just make it an i8
-    // vector with same LMUL.
-    const TargetInfo &TI = getContext().getTargetInfo();
-    if ((EltType->isHalfTy() && !TI.hasFeature("zvfhmin")) ||
-        (EltType->isBFloatTy() && !(TI.hasFeature("zvfbfmin") ||
-                                    TI.hasFeature("experimental-zvfbfa"))) ||
-        (EltType->isFloatTy() && !TI.hasFeature("zve32f")) ||
-        (EltType->isDoubleTy() && !TI.hasFeature("zve64d")) ||
-        (EltType->isIntegerTy(64) && !TI.hasFeature("zve64x")) ||
-        EltType->isIntegerTy(128)) {
-      // The number of elements needs to be at least 1.
-      ResType = llvm::ScalableVectorType::get(
-          llvm::Type::getInt8Ty(getVMContext()),
-          llvm::divideCeil(EltType->getScalarSizeInBits() * NumElts *
-                               llvm::RISCV::RVVBitsPerBlock,
-                           8 * ABIVLen));
-    }
+        VLSTy->getElementType(),
+        llvm::divideCeil(VLSTy->getNumElements() * llvm::RISCV::RVVBitsPerBlock,
+                         ABIVLen));
   }
 
   return ABIArgInfo::getDirect(ResType);
@@ -826,11 +833,16 @@ llvm::Value *RISCVABIInfo::createCoercedLoad(Address Src, const ABIArgInfo &AI,
     for (unsigned i = 0; i < NumElts; ++i) {
       // Extract from struct
       llvm::Value *ExtractFromLoad = CGF.Builder.CreateExtractValue(Load, i);
+      auto *FixedVecTy =
+          cast<llvm::FixedVectorType>(ExtractFromLoad->getType());
+      llvm::FixedVectorType *VLSTy = getVLSCCCompatibleType(FixedVecTy);
+      if (VLSTy != FixedVecTy)
+        ExtractFromLoad = CGF.Builder.CreateBitCast(ExtractFromLoad, VLSTy);
       // Element in vector tuple type is always i8, so we need to cast back to
       // it's original element type.
       EltTy =
           cast<llvm::ScalableVectorType>(llvm::VectorType::getWithSizeAndScalar(
-              cast<llvm::VectorType>(EltTy), ExtractFromLoad->getType()));
+              cast<llvm::VectorType>(EltTy), VLSTy));
       llvm::Value *VectorVal = llvm::PoisonValue::get(EltTy);
       // Insert to scalable vector
       VectorVal = CGF.Builder.CreateInsertVector(
@@ -863,9 +875,11 @@ llvm::Value *RISCVABIInfo::createCoercedLoad(Address Src, const ABIArgInfo &AI,
   if (auto *ArrayTy = dyn_cast<llvm::ArrayType>(SrcTy))
     SrcTy = ArrayTy->getElementType();
   Src = Src.withElementType(SrcTy);
-  [[maybe_unused]] auto *FixedSrcTy = cast<llvm::FixedVectorType>(SrcTy);
-  assert(ScalableDstTy->getElementType() == FixedSrcTy->getElementType());
-  auto *Load = CGF.Builder.CreateLoad(Src);
+  auto *FixedSrcTy = cast<llvm::FixedVectorType>(SrcTy);
+  llvm::Value *Load = CGF.Builder.CreateLoad(Src);
+  llvm::FixedVectorType *VLSTy = getVLSCCCompatibleType(FixedSrcTy);
+  if (VLSTy != FixedSrcTy)
+    Load = CGF.Builder.CreateBitCast(Load, VLSTy);
   auto *VectorVal = llvm::PoisonValue::get(ScalableDstTy);
   llvm::Value *Result = CGF.Builder.CreateInsertVector(
       ScalableDstTy, VectorVal, Load, uint64_t(0), "cast.scalable");
@@ -906,21 +920,26 @@ void RISCVABIInfo::createCoercedStore(llvm::Value *Val, Address Dst,
       FixedVecTy = ArrayTy->getArrayElementType();
     }
 
+    llvm::FixedVectorType *VLSTy =
+        getVLSCCCompatibleType(cast<llvm::FixedVectorType>(FixedVecTy));
+
     // Perform extract element and store
     for (unsigned i = 0; i < NumElts; ++i) {
       // Element in vector tuple type is always i8, so we need to cast back
       // to it's original element type.
       EltTy =
           cast<llvm::ScalableVectorType>(llvm::VectorType::getWithSizeAndScalar(
-              cast<llvm::VectorType>(EltTy), FixedVecTy));
+              cast<llvm::VectorType>(EltTy), VLSTy));
       // Extract scalable vector from tuple
       llvm::Value *Idx = CGF.Builder.getInt32(i);
       auto *TupleElement = CGF.Builder.CreateIntrinsic(
           llvm::Intrinsic::riscv_tuple_extract, {EltTy, TupTy}, {Val, Idx});
 
       // Extract fixed vector from scalable vector
-      auto *ExtractVec = CGF.Builder.CreateExtractVector(
-          FixedVecTy, TupleElement, uint64_t(0));
+      llvm::Value *ExtractVec =
+          CGF.Builder.CreateExtractVector(VLSTy, TupleElement, uint64_t(0));
+      if (VLSTy != FixedVecTy)
+        ExtractVec = CGF.Builder.CreateBitCast(ExtractVec, FixedVecTy);
       // Store fixed vector to corresponding address
       Address EltPtr = Address::invalid();
       if (Dst.getElementType()->isStructTy())
@@ -952,8 +971,12 @@ void RISCVABIInfo::createCoercedStore(llvm::Value *Val, Address Dst,
     assert(ArrayTy->getNumElements() == 1);
     EltTy = ArrayTy->getElementType();
   }
-  auto *Coerced = CGF.Builder.CreateExtractVector(
-      cast<llvm::FixedVectorType>(EltTy), Val, uint64_t(0));
+  auto *FixedVecTy = cast<llvm::FixedVectorType>(EltTy);
+  llvm::FixedVectorType *VLSTy = getVLSCCCompatibleType(FixedVecTy);
+  llvm::Value *Coerced =
+      CGF.Builder.CreateExtractVector(VLSTy, Val, uint64_t(0));
+  if (VLSTy != FixedVecTy)
+    Coerced = CGF.Builder.CreateBitCast(Coerced, FixedVecTy);
   auto *I = CGF.Builder.CreateStore(Coerced, Dst, DestIsVolatile);
   CGF.addInstToCurrentSourceAtom(I, Val);
 }

diff  --git a/clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.c b/clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.c
index 0e5b76e7d024d..695bba284597d 100644
--- a/clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.c
+++ b/clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.c
@@ -150,6 +150,15 @@ struct st_i32x4x9 {
     __attribute__((vector_size(16))) int i32_9;
 };
 
+struct st_bf16x8 {
+    __attribute__((vector_size(16))) __bf16 bf16;
+};
+
+struct st_bf16x8x2 {
+    __attribute__((vector_size(16))) __bf16 bf16_1;
+    __attribute__((vector_size(16))) __bf16 bf16_2;
+};
+
 typedef int __attribute__((vector_size(256))) int32x64_t;
 
 // CHECK-LLVM: define dso_local riscv_vls_cc(128) void @test_too_large(ptr noundef dead_on_return %0)
@@ -207,6 +216,20 @@ void __attribute__((riscv_vls_cc)) test_st_i32x4x9(struct st_i32x4x9 arg) {}
 // CHECK-LLVM: define dso_local riscv_vls_cc(256) void @test_st_i32x4x9_256(ptr noundef dead_on_return %arg)
 void __attribute__((riscv_vls_cc(256))) test_st_i32x4x9_256(struct st_i32x4x9 arg) {}
 
+// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @test_st_bf16x8(<vscale x 8 x i8> %arg.target_coerce)
+// CHECK-LLVM-ZVFBFA: define dso_local riscv_vls_cc(128) void @test_st_bf16x8(<vscale x 4 x bfloat> %arg.target_coerce)
+void __attribute__((riscv_vls_cc)) test_st_bf16x8(struct st_bf16x8 arg) {}
+// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @test_st_bf16x8_256(<vscale x 4 x i8> %arg.target_coerce)
+// CHECK-LLVM-ZVFBFA: define dso_local riscv_vls_cc(256) void @test_st_bf16x8_256(<vscale x 2 x bfloat> %arg.target_coerce)
+void __attribute__((riscv_vls_cc(256))) test_st_bf16x8_256(struct st_bf16x8 arg) {}
+
+// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @test_st_bf16x8x2(target("riscv.vector.tuple", <vscale x 8 x i8>, 2) %arg.target_coerce)
+// CHECK-LLVM-ZVFBFA: define dso_local riscv_vls_cc(128) void @test_st_bf16x8x2(target("riscv.vector.tuple", <vscale x 8 x i8>, 2) %arg.target_coerce)
+void __attribute__((riscv_vls_cc)) test_st_bf16x8x2(struct st_bf16x8x2 arg) {}
+// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @test_st_bf16x8x2_256(target("riscv.vector.tuple", <vscale x 4 x i8>, 2) %arg.target_coerce)
+// CHECK-LLVM-ZVFBFA: define dso_local riscv_vls_cc(256) void @test_st_bf16x8x2_256(target("riscv.vector.tuple", <vscale x 4 x i8>, 2) %arg.target_coerce)
+void __attribute__((riscv_vls_cc(256))) test_st_bf16x8x2_256(struct st_bf16x8x2 arg) {}
+
 // CHECK-LLVM-LABEL: define dso_local riscv_vls_cc(128) target("riscv.vector.tuple", <vscale x 8 x i8>, 4) @test_function_prolog_epilog(target("riscv.vector.tuple", <vscale x 8 x i8>, 4) %arg.target_coerce) #0 {
 // CHECK-LLVM-NEXT: entry:
 // CHECK-LLVM-NEXT:   %retval = alloca %struct.st_i32x4_arr4, align 16

diff  --git a/clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.cpp b/clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.cpp
index 96a4c9741f738..da94574827123 100644
--- a/clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.cpp
+++ b/clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.cpp
@@ -124,6 +124,15 @@ struct st_i32x4x9 {
     __attribute__((vector_size(16))) int i32_9;
 };
 
+struct st_bf16x8 {
+    __attribute__((vector_size(16))) __bf16 bf16;
+};
+
+struct st_bf16x8x2 {
+    __attribute__((vector_size(16))) __bf16 bf16_1;
+    __attribute__((vector_size(16))) __bf16 bf16_2;
+};
+
 typedef int __attribute__((vector_size(256))) int32x64_t;
 
 // CHECK-LLVM: define dso_local riscv_vls_cc(128) void @_Z14test_too_largeDv64_i(ptr noundef dead_on_return %0)
@@ -180,3 +189,17 @@ typedef int __attribute__((vector_size(256))) int32x64_t;
 [[riscv::vls_cc]] void test_st_i32x4x9(struct st_i32x4x9 arg) {}
 // CHECK-LLVM: define dso_local riscv_vls_cc(256) void @_Z19test_st_i32x4x9_25610st_i32x4x9(ptr noundef dead_on_return %arg)
 [[riscv::vls_cc(256)]] void test_st_i32x4x9_256(struct st_i32x4x9 arg) {}
+
+// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @_Z14test_st_bf16x89st_bf16x8(<vscale x 8 x i8> %arg.target_coerce)
+// CHECK-LLVM-ZVFBFA: define dso_local riscv_vls_cc(128) void @_Z14test_st_bf16x89st_bf16x8(<vscale x 4 x bfloat> %arg.target_coerce)
+[[riscv::vls_cc]] void test_st_bf16x8(struct st_bf16x8 arg) {}
+// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @_Z18test_st_bf16x8_2569st_bf16x8(<vscale x 4 x i8> %arg.target_coerce)
+// CHECK-LLVM-ZVFBFA: define dso_local riscv_vls_cc(256) void @_Z18test_st_bf16x8_2569st_bf16x8(<vscale x 2 x bfloat> %arg.target_coerce)
+[[riscv::vls_cc(256)]] void test_st_bf16x8_256(struct st_bf16x8 arg) {}
+
+// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @_Z16test_st_bf16x8x211st_bf16x8x2(target("riscv.vector.tuple", <vscale x 8 x i8>, 2) %arg.target_coerce)
+// CHECK-LLVM-ZVFBFA: define dso_local riscv_vls_cc(128) void @_Z16test_st_bf16x8x211st_bf16x8x2(target("riscv.vector.tuple", <vscale x 8 x i8>, 2) %arg.target_coerce)
+[[riscv::vls_cc]] void test_st_bf16x8x2(struct st_bf16x8x2 arg) {}
+// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @_Z20test_st_bf16x8x2_25611st_bf16x8x2(target("riscv.vector.tuple", <vscale x 4 x i8>, 2) %arg.target_coerce)
+// CHECK-LLVM-ZVFBFA: define dso_local riscv_vls_cc(256) void @_Z20test_st_bf16x8x2_25611st_bf16x8x2(target("riscv.vector.tuple", <vscale x 4 x i8>, 2) %arg.target_coerce)
+[[riscv::vls_cc(256)]] void test_st_bf16x8x2_256(struct st_bf16x8x2 arg) {}


        
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to