Author: Deric C.
Date: 2026-01-14T10:17:13-08:00
New Revision: 6212c87ab9890b3c8cc829867c0cf5be37fbccf9

URL: 
https://github.com/llvm/llvm-project/commit/6212c87ab9890b3c8cc829867c0cf5be37fbccf9
DIFF: 
https://github.com/llvm/llvm-project/commit/6212c87ab9890b3c8cc829867c0cf5be37fbccf9.diff

LOG: [HLSL][Matrix] Add Matrix splat support for booleans (#175809)

Fixes #175808

This PR adds support for boolean matrix splats by adding tests and
fixing a bug in `CodeGenFunction::EmitToMemory` that occurs when the type
of a boolean matrix value already matches the type expected for a load/store.

This PR also addresses the TODO comment in `clang/lib/Sema/SemaExpr.cpp`
regarding support for boolean matrix splats by removing the comment
altogether, since it is no longer necessary.

---------

Co-authored-by: Farzon Lotfi <[email protected]>

Added: 
    

Modified: 
    clang/lib/CodeGen/CGExpr.cpp
    clang/lib/Sema/SemaExpr.cpp
    clang/test/CodeGenHLSL/BasicFeatures/MatrixSplat.hlsl

Removed: 
    


################################################################################
diff  --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 91407b233e890..896c60b13c160 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -2218,6 +2218,10 @@ llvm::Value *CodeGenFunction::EmitToMemory(llvm::Value *Value, QualType Ty) {
 
   if (Ty->isExtVectorBoolType() || Ty->isConstantMatrixBoolType()) {
     llvm::Type *StoreTy = convertTypeForLoadStore(Ty, Value->getType());
+
+    if (Value->getType() == StoreTy)
+      return Value;
+
     if (StoreTy->isVectorTy() && StoreTy->getScalarSizeInBits() >
                                      Value->getType()->getScalarSizeInBits())
       return Builder.CreateZExt(Value, StoreTy);

diff  --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index 4d787a60eba3b..51739c3b49ac9 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -7899,8 +7899,6 @@ ExprResult Sema::prepareMatrixSplat(QualType MatrixTy, Expr *SplattedExpr) {
   assert(DestElemTy->isFloatingType() ||
          DestElemTy->isIntegralOrEnumerationType());
 
-  // TODO: Add support for boolean matrix once exposed
-  // https://github.com/llvm/llvm-project/issues/170920
   ExprResult CastExprRes = SplattedExpr;
   CastKind CK = PrepareScalarCast(CastExprRes, DestElemTy);
   if (CastExprRes.isInvalid())

diff  --git a/clang/test/CodeGenHLSL/BasicFeatures/MatrixSplat.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/MatrixSplat.hlsl
index 802c418f1dad5..9b9538e0afdd1 100644
--- a/clang/test/CodeGenHLSL/BasicFeatures/MatrixSplat.hlsl
+++ b/clang/test/CodeGenHLSL/BasicFeatures/MatrixSplat.hlsl
@@ -23,6 +23,28 @@ void ConstantFloatSplat() {
     float2x2 M = 3.25;
 }
 
+// CHECK-LABEL: define hidden void @_Z21ConstantTrueBoolSplatv(
+// CHECK-SAME: ) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[M:%.*]] = alloca [9 x i32], align 4
+// CHECK-NEXT:    store <9 x i32> splat (i32 1), ptr [[M]], align 4
+// CHECK-NEXT:    ret void
+//
+void ConstantTrueBoolSplat() {
+    bool3x3 M = true;
+}
+
+// CHECK-LABEL: define hidden void @_Z22ConstantFalseBoolSplatv(
+// CHECK-SAME: ) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[M:%.*]] = alloca [9 x i32], align 4
+// CHECK-NEXT:    store <9 x i32> zeroinitializer, ptr [[M]], align 4
+// CHECK-NEXT:    ret void
+//
+void ConstantFalseBoolSplat() {
+    bool3x3 M = false;
+}
+
 // CHECK-LABEL: define hidden void @_Z12DynamicSplatf(
 // CHECK-SAME: float noundef nofpclass(nan inf) [[VALUE:%.*]]) #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
@@ -39,6 +61,25 @@ void DynamicSplat(float Value) {
     float3x3 M = Value;
 }
 
+// CHECK-LABEL: define hidden void @_Z16DynamicBoolSplatb(
+// CHECK-SAME: i1 noundef [[VALUE:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[VALUE_ADDR:%.*]] = alloca i32, align 4
+// CHECK-NEXT:    [[M:%.*]] = alloca [16 x i32], align 4
+// CHECK-NEXT:    [[STOREDV:%.*]] = zext i1 [[VALUE]] to i32
+// CHECK-NEXT:    store i32 [[STOREDV]], ptr [[VALUE_ADDR]], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = load i32, ptr [[VALUE_ADDR]], align 4
+// CHECK-NEXT:    [[LOADEDV:%.*]] = trunc i32 [[TMP0]] to i1
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT:%.*]] = insertelement <16 x i1> poison, i1 [[LOADEDV]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT:%.*]] = shufflevector <16 x i1> [[SPLAT_SPLATINSERT]], <16 x i1> poison, <16 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP1:%.*]] = zext <16 x i1> [[SPLAT_SPLAT]] to <16 x i32>
+// CHECK-NEXT:    store <16 x i32> [[TMP1]], ptr [[M]], align 4
+// CHECK-NEXT:    ret void
+//
+void DynamicBoolSplat(bool Value) {
+    bool4x4 M = Value;
+}
+
 // CHECK-LABEL: define hidden void @_Z13CastThenSplatDv4_f(
 // CHECK-SAME: <4 x float> noundef nofpclass(nan inf) [[VALUE:%.*]]) #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
@@ -55,3 +96,60 @@ void DynamicSplat(float Value) {
 void CastThenSplat(float4 Value) {
     float3x3 M = (float) Value;
 }
+
+// CHECK-LABEL: define hidden void @_Z30ExplicitIntToBoolCastThenSplatDv3_i(
+// CHECK-SAME: <3 x i32> noundef [[VALUE:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[VALUE_ADDR:%.*]] = alloca <3 x i32>, align 16
+// CHECK-NEXT:    [[M:%.*]] = alloca [4 x i32], align 4
+// CHECK-NEXT:    store <3 x i32> [[VALUE]], ptr [[VALUE_ADDR]], align 16
+// CHECK-NEXT:    [[TMP0:%.*]] = load <3 x i32>, ptr [[VALUE_ADDR]], align 16
+// CHECK-NEXT:    [[TOBOOL:%.*]] = icmp ne <3 x i32> [[TMP0]], zeroinitializer
+// CHECK-NEXT:    [[CAST_VTRUNC:%.*]] = extractelement <3 x i1> [[TOBOOL]], i32 0
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT:%.*]] = insertelement <4 x i1> poison, i1 [[CAST_VTRUNC]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT:%.*]] = shufflevector <4 x i1> [[SPLAT_SPLATINSERT]], <4 x i1> poison, <4 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP1:%.*]] = zext <4 x i1> [[SPLAT_SPLAT]] to <4 x i32>
+// CHECK-NEXT:    store <4 x i32> [[TMP1]], ptr [[M]], align 4
+// CHECK-NEXT:    ret void
+//
+void ExplicitIntToBoolCastThenSplat(int3 Value) {
+    bool2x2 M = (bool) Value;
+}
+
+// CHECK-LABEL: define hidden void @_Z32ExplicitFloatToBoolCastThenSplatDv2_f(
+// CHECK-SAME: <2 x float> noundef nofpclass(nan inf) [[VALUE:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[VALUE_ADDR:%.*]] = alloca <2 x float>, align 8
+// CHECK-NEXT:    [[M:%.*]] = alloca [6 x i32], align 4
+// CHECK-NEXT:    store <2 x float> [[VALUE]], ptr [[VALUE_ADDR]], align 8
+// CHECK-NEXT:    [[TMP0:%.*]] = load <2 x float>, ptr [[VALUE_ADDR]], align 8
+// CHECK-NEXT:    [[TOBOOL:%.*]] = fcmp reassoc nnan ninf nsz arcp afn une <2 x float> [[TMP0]], zeroinitializer
+// CHECK-NEXT:    [[CAST_VTRUNC:%.*]] = extractelement <2 x i1> [[TOBOOL]], i32 0
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT:%.*]] = insertelement <6 x i1> poison, i1 [[CAST_VTRUNC]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT:%.*]] = shufflevector <6 x i1> [[SPLAT_SPLATINSERT]], <6 x i1> poison, <6 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP1:%.*]] = zext <6 x i1> [[SPLAT_SPLAT]] to <6 x i32>
+// CHECK-NEXT:    store <6 x i32> [[TMP1]], ptr [[M]], align 4
+// CHECK-NEXT:    ret void
+//
+void ExplicitFloatToBoolCastThenSplat(float2 Value) {
+    bool2x3 M = (bool) Value;
+}
+
+// CHECK-LABEL: define hidden void @_Z32ExplicitBoolToFloatCastThenSplatb(
+// CHECK-SAME: i1 noundef [[VALUE:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[VALUE_ADDR:%.*]] = alloca i32, align 4
+// CHECK-NEXT:    [[M:%.*]] = alloca [6 x float], align 4
+// CHECK-NEXT:    [[STOREDV:%.*]] = zext i1 [[VALUE]] to i32
+// CHECK-NEXT:    store i32 [[STOREDV]], ptr [[VALUE_ADDR]], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = load i32, ptr [[VALUE_ADDR]], align 4
+// CHECK-NEXT:    [[LOADEDV:%.*]] = trunc i32 [[TMP0]] to i1
+// CHECK-NEXT:    [[CONV:%.*]] = uitofp i1 [[LOADEDV]] to float
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT:%.*]] = insertelement <6 x float> poison, float [[CONV]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT:%.*]] = shufflevector <6 x float> [[SPLAT_SPLATINSERT]], <6 x float> poison, <6 x i32> zeroinitializer
+// CHECK-NEXT:    store <6 x float> [[SPLAT_SPLAT]], ptr [[M]], align 4
+// CHECK-NEXT:    ret void
+//
+void ExplicitBoolToFloatCastThenSplat(bool Value) {
+    float3x2 M = (float) Value;
+}


        
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to