Author: Deric C.
Date: 2026-02-23T12:10:09-08:00
New Revision: 8910926363b0cf2e692c63736eddd671b6453e5f

URL: 
https://github.com/llvm/llvm-project/commit/8910926363b0cf2e692c63736eddd671b6453e5f
DIFF: 
https://github.com/llvm/llvm-project/commit/8910926363b0cf2e692c63736eddd671b6453e5f.diff

LOG: [HLSL][Matrix] Add implicit type conversions for constant matrix types (#181939)

Fixes #175853

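This PR extends the implicit type conversion support in Clang's HLSL
frontend so that it handles ConstantMatrix types in addition to vectors.
The conversion logic for a ConstantMatrix is nearly identical to the
logic for a vector, so the change is small. The `adjustVectorType` helper
in SemaExprCXX.cpp is generalized to `adjustVectorOrConstantMatrixType`.
Its callers for integral, floating-point, floating-integral, and boolean
conversions now also apply it when either side is a constant matrix type.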

Assisted-by: claude-opus-4.6

Added: 
    clang/test/SemaHLSL/Language/MatrixSplatCasts.hlsl

Modified: 
    clang/lib/Sema/SemaExprCXX.cpp
    clang/test/CodeGenHLSL/BasicFeatures/MatrixSplat.hlsl
    clang/test/SemaHLSL/MatrixElementOverloadResolution.hlsl

Removed: 
    


################################################################################
diff  --git a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp
index 3db8d3a10252e..31b3a06bf10d0 100644
--- a/clang/lib/Sema/SemaExprCXX.cpp
+++ b/clang/lib/Sema/SemaExprCXX.cpp
@@ -4687,20 +4687,31 @@ Sema::PerformImplicitConversion(Expr *From, QualType 
ToType,
   return From;
 }
 
-// adjustVectorType - Compute the intermediate cast type casting elements of the
-// from type to the elements of the to type without resizing the vector.
-static QualType adjustVectorType(ASTContext &Context, QualType FromTy,
-                                 QualType ToType, QualType *ElTy = nullptr) {
+// adjustVectorOrConstantMatrixType - Compute the intermediate cast type casting
+// elements of the from type to the elements of the to type without resizing the
+// vector or matrix.
+static QualType adjustVectorOrConstantMatrixType(ASTContext &Context,
+                                                 QualType FromTy,
+                                                 QualType ToType,
+                                                 QualType *ElTy = nullptr) {
   QualType ElType = ToType;
   if (auto *ToVec = ToType->getAs<VectorType>())
     ElType = ToVec->getElementType();
+  else if (auto *ToMat = ToType->getAs<ConstantMatrixType>())
+    ElType = ToMat->getElementType();
 
   if (ElTy)
     *ElTy = ElType;
-  if (!FromTy->isVectorType())
-    return ElType;
-  auto *FromVec = FromTy->castAs<VectorType>();
-  return Context.getExtVectorType(ElType, FromVec->getNumElements());
+  if (FromTy->isVectorType()) {
+    auto *FromVec = FromTy->castAs<VectorType>();
+    return Context.getExtVectorType(ElType, FromVec->getNumElements());
+  }
+  if (FromTy->isConstantMatrixType()) {
+    auto *FromMat = FromTy->castAs<ConstantMatrixType>();
+    return Context.getConstantMatrixType(ElType, FromMat->getNumRows(),
+                                         FromMat->getNumColumns());
+  }
+  return ElType;
 }
 
 /// Check if an integral conversion involves incompatible overflow behavior
@@ -4884,8 +4895,10 @@ Sema::PerformImplicitConversion(Expr *From, QualType 
ToType,
   case ICK_Integral_Conversion: {
     QualType ElTy = ToType;
     QualType StepTy = ToType;
-    if (FromType->isVectorType() || ToType->isVectorType())
-      StepTy = adjustVectorType(Context, FromType, ToType, &ElTy);
+    if (FromType->isVectorType() || ToType->isVectorType() ||
+        FromType->isConstantMatrixType() || ToType->isConstantMatrixType())
+      StepTy =
+          adjustVectorOrConstantMatrixType(Context, FromType, ToType, &ElTy);
 
     // Check for incompatible OBT kinds before converting
     if (checkIncompatibleOBTConversion(*this, FromType, StepTy, From))
@@ -4909,8 +4922,9 @@ Sema::PerformImplicitConversion(Expr *From, QualType 
ToType,
   case ICK_Floating_Promotion:
   case ICK_Floating_Conversion: {
     QualType StepTy = ToType;
-    if (FromType->isVectorType() || ToType->isVectorType())
-      StepTy = adjustVectorType(Context, FromType, ToType);
+    if (FromType->isVectorType() || ToType->isVectorType() ||
+        FromType->isConstantMatrixType() || ToType->isConstantMatrixType())
+      StepTy = adjustVectorOrConstantMatrixType(Context, FromType, ToType);
     From = ImpCastExprToType(From, StepTy, CK_FloatingCast, VK_PRValue,
                              /*BasePath=*/nullptr, CCK)
                .get();
@@ -4941,8 +4955,10 @@ Sema::PerformImplicitConversion(Expr *From, QualType 
ToType,
   case ICK_Floating_Integral: {
     QualType ElTy = ToType;
     QualType StepTy = ToType;
-    if (FromType->isVectorType() || ToType->isVectorType())
-      StepTy = adjustVectorType(Context, FromType, ToType, &ElTy);
+    if (FromType->isVectorType() || ToType->isVectorType() ||
+        FromType->isConstantMatrixType() || ToType->isConstantMatrixType())
+      StepTy =
+          adjustVectorOrConstantMatrixType(Context, FromType, ToType, &ElTy);
     if (ElTy->isRealFloatingType())
       From = ImpCastExprToType(From, StepTy, CK_IntegralToFloating, VK_PRValue,
                                /*BasePath=*/nullptr, CCK)
@@ -5096,9 +5112,13 @@ Sema::PerformImplicitConversion(Expr *From, QualType 
ToType,
     QualType StepTy = ToType;
     if (FromType->isVectorType())
       ElTy = FromType->castAs<VectorType>()->getElementType();
-    if (getLangOpts().HLSL &&
-        (FromType->isVectorType() || ToType->isVectorType()))
-      StepTy = adjustVectorType(Context, FromType, ToType);
+    else if (FromType->isConstantMatrixType())
+      ElTy = FromType->castAs<ConstantMatrixType>()->getElementType();
+    if (getLangOpts().HLSL) {
+      if (FromType->isVectorType() || ToType->isVectorType() ||
+          FromType->isConstantMatrixType() || ToType->isConstantMatrixType())
+        StepTy = adjustVectorOrConstantMatrixType(Context, FromType, ToType);
+    }
 
     From = ImpCastExprToType(From, StepTy, ScalarTypeToBooleanCastKind(ElTy),
                              VK_PRValue,

diff  --git a/clang/test/CodeGenHLSL/BasicFeatures/MatrixSplat.hlsl 
b/clang/test/CodeGenHLSL/BasicFeatures/MatrixSplat.hlsl
index 5edb8a3dd4690..9ae13e3dc04b0 100644
--- a/clang/test/CodeGenHLSL/BasicFeatures/MatrixSplat.hlsl
+++ b/clang/test/CodeGenHLSL/BasicFeatures/MatrixSplat.hlsl
@@ -153,3 +153,40 @@ void ExplicitFloatToBoolCastThenSplat(float2 Value) {
 void ExplicitBoolToFloatCastThenSplat(bool Value) {
     float3x2 M = (float) Value;
 }
+
+// CHECK-LABEL: define hidden void @_Z32ImplicitFloatToBoolCastThenSplatf(
+// CHECK-SAME: float noundef nofpclass(nan inf) [[VALUE:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[VALUE_ADDR:%.*]] = alloca float, align 4
+// CHECK-NEXT:    [[M:%.*]] = alloca [3 x <2 x i32>], align 4
+// CHECK-NEXT:    store float [[VALUE]], ptr [[VALUE_ADDR]], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = load float, ptr [[VALUE_ADDR]], align 4
+// CHECK-NEXT:    [[TOBOOL:%.*]] = fcmp reassoc nnan ninf nsz arcp afn une 
float [[TMP0]], 0.000000e+00
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT:%.*]] = insertelement <6 x i1> poison, 
i1 [[TOBOOL]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT:%.*]] = shufflevector <6 x i1> 
[[SPLAT_SPLATINSERT]], <6 x i1> poison, <6 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP1:%.*]] = zext <6 x i1> [[SPLAT_SPLAT]] to <6 x i32>
+// CHECK-NEXT:    store <6 x i32> [[TMP1]], ptr [[M]], align 4
+// CHECK-NEXT:    ret void
+//
+void ImplicitFloatToBoolCastThenSplat(float Value) {
+    bool2x3 M = Value;
+}
+
+// CHECK-LABEL: define hidden void @_Z32ImplicitBoolToFloatCastThenSplatb(
+// CHECK-SAME: i1 noundef [[VALUE:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[VALUE_ADDR:%.*]] = alloca i32, align 4
+// CHECK-NEXT:    [[M:%.*]] = alloca [2 x <3 x float>], align 4
+// CHECK-NEXT:    [[STOREDV:%.*]] = zext i1 [[VALUE]] to i32
+// CHECK-NEXT:    store i32 [[STOREDV]], ptr [[VALUE_ADDR]], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = load i32, ptr [[VALUE_ADDR]], align 4
+// CHECK-NEXT:    [[LOADEDV:%.*]] = trunc i32 [[TMP0]] to i1
+// CHECK-NEXT:    [[CONV:%.*]] = uitofp i1 [[LOADEDV]] to float
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT:%.*]] = insertelement <6 x float> 
poison, float [[CONV]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT:%.*]] = shufflevector <6 x float> 
[[SPLAT_SPLATINSERT]], <6 x float> poison, <6 x i32> zeroinitializer
+// CHECK-NEXT:    store <6 x float> [[SPLAT_SPLAT]], ptr [[M]], align 4
+// CHECK-NEXT:    ret void
+//
+void ImplicitBoolToFloatCastThenSplat(bool Value) {
+    float3x2 M = Value;
+}

diff  --git a/clang/test/SemaHLSL/Language/MatrixSplatCasts.hlsl 
b/clang/test/SemaHLSL/Language/MatrixSplatCasts.hlsl
new file mode 100644
index 0000000000000..985b18a3cf705
--- /dev/null
+++ b/clang/test/SemaHLSL/Language/MatrixSplatCasts.hlsl
@@ -0,0 +1,35 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library 
-finclude-default-header -fnative-half-type %s -ast-dump | FileCheck %s
+
+// Test matrix splats where the initializer scalar type differs from matrix element type.
+
+// Bool to int matrix splat
+// CHECK-LABEL: FunctionDecl {{.*}} fn0 'int4x4 (bool)'
+// CHECK: ImplicitCastExpr {{.*}} 'int4x4':'matrix<int, 4, 4>' 
<HLSLAggregateSplatCast>
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'int' <IntegralCast>
+export int4x4 fn0(bool b) {
+    return b;
+}
+
+// Float to int matrix splat
+// CHECK-LABEL: FunctionDecl {{.*}} fn1 'int4x4 (float)'
+// CHECK: ImplicitCastExpr {{.*}} 'int4x4':'matrix<int, 4, 4>' 
<HLSLAggregateSplatCast>
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'int' <FloatingToIntegral>
+export int4x4 fn1(float f) {
+    return f;
+}
+
+// Int to float matrix splat
+// CHECK-LABEL: FunctionDecl {{.*}} fn2 'float4x4 (int)'
+// CHECK: ImplicitCastExpr {{.*}} 'float4x4':'matrix<float, 4, 4>' 
<HLSLAggregateSplatCast>
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'float' <IntegralToFloating>
+export float4x4 fn2(int i) {
+    return i;
+}
+
+// Bool to float matrix splat
+// CHECK-LABEL: FunctionDecl {{.*}} fn3 'float4x4 (bool)'
+// CHECK: ImplicitCastExpr {{.*}} 'float4x4':'matrix<float, 4, 4>' 
<HLSLAggregateSplatCast>
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'float' <IntegralToFloating>
+export float4x4 fn3(bool b) {
+    return b;
+}

diff  --git a/clang/test/SemaHLSL/MatrixElementOverloadResolution.hlsl 
b/clang/test/SemaHLSL/MatrixElementOverloadResolution.hlsl
index 51500a3bcc145..b377c8548c334 100644
--- a/clang/test/SemaHLSL/MatrixElementOverloadResolution.hlsl
+++ b/clang/test/SemaHLSL/MatrixElementOverloadResolution.hlsl
@@ -238,9 +238,10 @@ void matOrVec3(float4x4 F) {}
 
 export void Case8(float2x3 f23, float4x4 f44, float3x3 f33, float3x2 f32) {
   int2x2 i22 = f23;
-  // expected-warning@-1{{implicit conversion truncates matrix: 'float2x3' 
(aka 'matrix<float, 2, 3>') to 'int2x2' (aka 'matrix<int, 2, 2>')}}
+  // expected-warning@-1{{implicit conversion truncates matrix: 'float2x3' 
(aka 'matrix<float, 2, 3>') to 'matrix<int, 2, 2>'}}
   //CHECK: VarDecl {{.*}} i22 'int2x2':'matrix<int, 2, 2>' cinit
-  //CHECK-NEXT: ImplicitCastExpr {{.*}} 'int2x2':'matrix<int, 2, 2>' 
<FloatingToIntegral>
+  //CHECK-NEXT: ImplicitCastExpr {{.*}} 'matrix<int, 2, 2>' 
<HLSLMatrixTruncation>
+  //CHECK-NEXT: ImplicitCastExpr {{.*}} 'matrix<int, 2, 3>' 
<FloatingToIntegral>
   //CHECK-NEXT: ImplicitCastExpr {{.*}} 'float2x3':'matrix<float, 2, 3>' 
<LValueToRValue>
 #ifdef ERROR
   int3x2 i32 = f23; // expected-error{{cannot initialize a variable of type 
'matrix<int, 3, 2>' with an lvalue of type 'matrix<float, 2, 3>'}}


        
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to