Author: Deric C.
Date: 2026-01-12T08:57:51-08:00
New Revision: 50c1a69db0f551d204d2cb29abd8dabb185a8392

URL: 
https://github.com/llvm/llvm-project/commit/50c1a69db0f551d204d2cb29abd8dabb185a8392
DIFF: 
https://github.com/llvm/llvm-project/commit/50c1a69db0f551d204d2cb29abd8dabb185a8392.diff

LOG: [HLSL][Matrix] Load and store ConstantMatrixBoolTypes as i32 
FixedVectorTypes (#175245)

Fixes #175236 

This pull request improves support for HLSL constant matrix types with
boolean elements in Clang's code generation. The main changes ensure
that boolean (i1) matrices are correctly represented and stored as i32
vectors in LLVM IR. This includes updates to both the code generation
logic and related tests.

### Code generation improvements for HLSL boolean matrices

* Updated `convertTypeForLoadStore` in `CodeGenTypes.cpp` to represent
constant matrix types with boolean elements as a `FixedVectorType` of
integers, so that the whole matrix is loaded and stored as a single
vector (rather than as the `ArrayType` returned by `ConvertTypeForMem`)
and the element type is converted correctly for HLSL.
* Modified `EmitToMemory` in `CGExpr.cpp` to handle
`ConstantMatrixBoolType` in addition to `ExtVectorBoolType`, so that a
boolean matrix value is widened to its load/store type (e.g. `<N x i1>`
zero-extended to `<N x i32>`) before it is stored to memory.

### Test updates for boolean matrix codegen

* Adjusted test expectations in `BoolMatrix.hlsl` to reflect the new
representation: boolean matrices are now stored as `<N x i32>` instead
of `<N x i1>`, with a zero-extension added where necessary.
* Added a new test (`fn8`) that takes and returns a 4x4 boolean matrix,
to verify that the initial store of a boolean matrix parameter into its
alloca is generated correctly.

Added: 
    

Modified: 
    clang/lib/CodeGen/CGExpr.cpp
    clang/lib/CodeGen/CodeGenTypes.cpp
    clang/test/CodeGenHLSL/BoolMatrix.hlsl

Removed: 
    


################################################################################
diff  --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 6309c37788f0c..999726340aaed 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -2216,7 +2216,7 @@ llvm::Value *CodeGenFunction::EmitToMemory(llvm::Value 
*Value, QualType Ty) {
   if (auto *AtomicTy = Ty->getAs<AtomicType>())
     Ty = AtomicTy->getValueType();
 
-  if (Ty->isExtVectorBoolType()) {
+  if (Ty->isExtVectorBoolType() || Ty->isConstantMatrixBoolType()) {
     llvm::Type *StoreTy = convertTypeForLoadStore(Ty, Value->getType());
     if (StoreTy->isVectorTy() && StoreTy->getScalarSizeInBits() >
                                      Value->getType()->getScalarSizeInBits())

diff  --git a/clang/lib/CodeGen/CodeGenTypes.cpp 
b/clang/lib/CodeGen/CodeGenTypes.cpp
index 4239552d1299e..0e1131d586433 100644
--- a/clang/lib/CodeGen/CodeGenTypes.cpp
+++ b/clang/lib/CodeGen/CodeGenTypes.cpp
@@ -107,8 +107,7 @@ llvm::Type *CodeGenTypes::ConvertTypeForMem(QualType T) {
     llvm::Type *IRElemTy = ConvertType(MT->getElementType());
     if (Context.getLangOpts().HLSL && T->isConstantMatrixBoolType())
       IRElemTy = ConvertTypeForMem(Context.BoolTy);
-    return llvm::ArrayType::get(IRElemTy,
-                                MT->getNumRows() * MT->getNumColumns());
+    return llvm::ArrayType::get(IRElemTy, MT->getNumElementsFlattened());
   }
 
   llvm::Type *R = ConvertType(T);
@@ -180,6 +179,16 @@ llvm::Type *CodeGenTypes::convertTypeForLoadStore(QualType 
T,
     return llvm::IntegerType::get(getLLVMContext(),
                                   (unsigned)Context.getTypeSize(T));
 
+  if (T->isConstantMatrixBoolType()) {
+    // Matrices are loaded and stored atomically as vectors. Therefore we
+    // construct a FixedVectorType here instead of returning
+    // ConvertTypeForMem(T) which would return an ArrayType instead.
+    const Type *Ty = Context.getCanonicalType(T).getTypePtr();
+    const ConstantMatrixType *MT = cast<ConstantMatrixType>(Ty);
+    llvm::Type *IRElemTy = ConvertTypeForMem(MT->getElementType());
+    return llvm::FixedVectorType::get(IRElemTy, MT->getNumElementsFlattened());
+  }
+
   if (T->isExtVectorBoolType())
     return ConvertTypeForMem(T);
 

diff  --git a/clang/test/CodeGenHLSL/BoolMatrix.hlsl 
b/clang/test/CodeGenHLSL/BoolMatrix.hlsl
index 71186f775b241..05c9ad4b926e6 100644
--- a/clang/test/CodeGenHLSL/BoolMatrix.hlsl
+++ b/clang/test/CodeGenHLSL/BoolMatrix.hlsl
@@ -12,7 +12,7 @@ struct S {
 // CHECK-NEXT:  [[ENTRY:.*:]]
 // CHECK-NEXT:    [[RETVAL:%.*]] = alloca i1, align 4
 // CHECK-NEXT:    [[B:%.*]] = alloca [4 x i32], align 4
-// CHECK-NEXT:    store <4 x i1> splat (i1 true), ptr [[B]], align 4
+// CHECK-NEXT:    store <4 x i32> splat (i32 1), ptr [[B]], align 4
 // CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[B]], align 4
 // CHECK-NEXT:    [[MATRIXEXT:%.*]] = extractelement <4 x i32> [[TMP0]], i32 0
 // CHECK-NEXT:    store i32 [[MATRIXEXT]], ptr [[RETVAL]], align 4
@@ -40,11 +40,12 @@ bool fn1() {
 // CHECK-NEXT:    [[VECINIT2:%.*]] = insertelement <4 x i1> [[VECINIT]], i1 
[[LOADEDV1]], i32 1
 // CHECK-NEXT:    [[VECINIT3:%.*]] = insertelement <4 x i1> [[VECINIT2]], i1 
true, i32 2
 // CHECK-NEXT:    [[VECINIT4:%.*]] = insertelement <4 x i1> [[VECINIT3]], i1 
false, i32 3
-// CHECK-NEXT:    store <4 x i1> [[VECINIT4]], ptr [[A]], align 4
-// CHECK-NEXT:    [[TMP2:%.*]] = load <4 x i32>, ptr [[A]], align 4
-// CHECK-NEXT:    store <4 x i32> [[TMP2]], ptr [[RETVAL]], align 4
-// CHECK-NEXT:    [[TMP3:%.*]] = load <4 x i1>, ptr [[RETVAL]], align 4
-// CHECK-NEXT:    ret <4 x i1> [[TMP3]]
+// CHECK-NEXT:    [[TMP2:%.*]] = zext <4 x i1> [[VECINIT4]] to <4 x i32>
+// CHECK-NEXT:    store <4 x i32> [[TMP2]], ptr [[A]], align 4
+// CHECK-NEXT:    [[TMP3:%.*]] = load <4 x i32>, ptr [[A]], align 4
+// CHECK-NEXT:    store <4 x i32> [[TMP3]], ptr [[RETVAL]], align 4
+// CHECK-NEXT:    [[TMP4:%.*]] = load <4 x i1>, ptr [[RETVAL]], align 4
+// CHECK-NEXT:    ret <4 x i1> [[TMP4]]
 //
 bool2x2 fn2(bool V) {
   bool2x2 A = {V, true, V, false};
@@ -57,7 +58,7 @@ bool2x2 fn2(bool V) {
 // CHECK-NEXT:    [[RETVAL:%.*]] = alloca i1, align 4
 // CHECK-NEXT:    [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 1
 // CHECK-NEXT:    [[BM:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr 
[[S]], i32 0, i32 0
-// CHECK-NEXT:    store <4 x i1> <i1 true, i1 false, i1 true, i1 false>, ptr 
[[BM]], align 1
+// CHECK-NEXT:    store <4 x i32> <i32 1, i32 0, i32 1, i32 0>, ptr [[BM]], 
align 1
 // CHECK-NEXT:    [[F:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr 
[[S]], i32 0, i32 1
 // CHECK-NEXT:    store float 1.000000e+00, ptr [[F]], align 1
 // CHECK-NEXT:    [[BM1:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr 
[[S]], i32 0, i32 0
@@ -77,9 +78,9 @@ bool fn3() {
 // CHECK-NEXT:  [[ENTRY:.*:]]
 // CHECK-NEXT:    [[RETVAL:%.*]] = alloca i1, align 4
 // CHECK-NEXT:    [[ARR:%.*]] = alloca [2 x [4 x i32]], align 4
-// CHECK-NEXT:    store <4 x i1> splat (i1 true), ptr [[ARR]], align 4
+// CHECK-NEXT:    store <4 x i32> splat (i32 1), ptr [[ARR]], align 4
 // CHECK-NEXT:    [[ARRAYINIT_ELEMENT:%.*]] = getelementptr inbounds [4 x 
i32], ptr [[ARR]], i32 1
-// CHECK-NEXT:    store <4 x i1> zeroinitializer, ptr [[ARRAYINIT_ELEMENT]], 
align 4
+// CHECK-NEXT:    store <4 x i32> zeroinitializer, ptr [[ARRAYINIT_ELEMENT]], 
align 4
 // CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds [2 x [4 x i32]], 
ptr [[ARR]], i32 0, i32 0
 // CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[ARRAYIDX]], align 4
 // CHECK-NEXT:    [[MATRIXEXT:%.*]] = extractelement <4 x i32> [[TMP0]], i32 1
@@ -96,7 +97,7 @@ bool fn4() {
 // CHECK-SAME: ) #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
 // CHECK-NEXT:    [[M:%.*]] = alloca [4 x i32], align 4
-// CHECK-NEXT:    store <4 x i1> splat (i1 true), ptr [[M]], align 4
+// CHECK-NEXT:    store <4 x i32> splat (i32 1), ptr [[M]], align 4
 // CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[M]], align 4
 // CHECK-NEXT:    [[MATINS:%.*]] = insertelement <4 x i32> [[TMP0]], i32 0, 
i32 3
 // CHECK-NEXT:    store <4 x i32> [[MATINS]], ptr [[M]], align 4
@@ -114,7 +115,7 @@ void fn5() {
 // CHECK-NEXT:    [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 1
 // CHECK-NEXT:    store i32 0, ptr [[V]], align 4
 // CHECK-NEXT:    [[BM:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr 
[[S]], i32 0, i32 0
-// CHECK-NEXT:    store <4 x i1> <i1 true, i1 false, i1 true, i1 false>, ptr 
[[BM]], align 1
+// CHECK-NEXT:    store <4 x i32> <i32 1, i32 0, i32 1, i32 0>, ptr [[BM]], 
align 1
 // CHECK-NEXT:    [[F:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr 
[[S]], i32 0, i32 1
 // CHECK-NEXT:    store float 1.000000e+00, ptr [[F]], align 1
 // CHECK-NEXT:    [[TMP0:%.*]] = load i32, ptr [[V]], align 4
@@ -136,9 +137,9 @@ void fn6() {
 // CHECK-SAME: ) #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
 // CHECK-NEXT:    [[ARR:%.*]] = alloca [2 x [4 x i32]], align 4
-// CHECK-NEXT:    store <4 x i1> splat (i1 true), ptr [[ARR]], align 4
+// CHECK-NEXT:    store <4 x i32> splat (i32 1), ptr [[ARR]], align 4
 // CHECK-NEXT:    [[ARRAYINIT_ELEMENT:%.*]] = getelementptr inbounds [4 x 
i32], ptr [[ARR]], i32 1
-// CHECK-NEXT:    store <4 x i1> zeroinitializer, ptr [[ARRAYINIT_ELEMENT]], 
align 4
+// CHECK-NEXT:    store <4 x i32> zeroinitializer, ptr [[ARRAYINIT_ELEMENT]], 
align 4
 // CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds [2 x [4 x i32]], 
ptr [[ARR]], i32 0, i32 0
 // CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[ARRAYIDX]], align 4
 // CHECK-NEXT:    [[MATINS:%.*]] = insertelement <4 x i32> [[TMP0]], i32 0, 
i32 1
@@ -149,3 +150,19 @@ void fn7() {
   bool2x2 Arr[2] = {{true,true,true,true}, {false,false,false,false}};
   Arr[0][1][0] = false;
 }
+
+// CHECK-LABEL: define hidden noundef <16 x i1> 
@_Z3fn8u11matrix_typeILm4ELm4EbE(
+// CHECK-SAME: <16 x i1> noundef [[M:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[RETVAL:%.*]] = alloca <16 x i1>, align 4
+// CHECK-NEXT:    [[M_ADDR:%.*]] = alloca [16 x i32], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = zext <16 x i1> [[M]] to <16 x i32>
+// CHECK-NEXT:    store <16 x i32> [[TMP0]], ptr [[M_ADDR]], align 4
+// CHECK-NEXT:    [[TMP1:%.*]] = load <16 x i32>, ptr [[M_ADDR]], align 4
+// CHECK-NEXT:    store <16 x i32> [[TMP1]], ptr [[RETVAL]], align 4
+// CHECK-NEXT:    [[TMP2:%.*]] = load <16 x i1>, ptr [[RETVAL]], align 4
+// CHECK-NEXT:    ret <16 x i1> [[TMP2]]
+//
+bool4x4 fn8(bool4x4 m) {
+  return m;
+}


        
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to