https://github.com/Icohedron updated https://github.com/llvm/llvm-project/pull/183424
>From c33ea3969b9f527d18a37932751c625d9b3f71d1 Mon Sep 17 00:00:00 2001 From: Deric Cheung <[email protected]> Date: Wed, 25 Feb 2026 16:13:16 -0800 Subject: [PATCH 1/9] Support ConstantMatrixTypes and HLSL casts to bytecode constexpr evaluator This commit adds support for ConstantMatrixType and the HLSL casts CK_HLSLArrayRValue, CK_HLSLMatrixTruncation, CK_HLSLAggregateSplatCast, and CK_HLSLElementwiseCast to the bytecode constexpr evaluator. The implementations of CK_HLSLAggregateSplatCast and CK_HLSLElementwiseCast are incomplete, as they still need to support struct and array types. Assisted-by: claude-opus-4.6 --- clang/lib/AST/ByteCode/Compiler.cpp | 207 ++++++++++++++++++ clang/lib/AST/ByteCode/Compiler.h | 5 + clang/lib/AST/ByteCode/Pointer.cpp | 18 ++ clang/lib/AST/ByteCode/Program.cpp | 11 + .../BuiltinMatrix/MatrixConstantExpr.hlsl | 2 + 5 files changed, 243 insertions(+) diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp index 15e65a4d96581..87982e67dcb51 100644 --- a/clang/lib/AST/ByteCode/Compiler.cpp +++ b/clang/lib/AST/ByteCode/Compiler.cpp @@ -811,6 +811,180 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) { case CK_LValueBitCast: return this->emitInvalidCast(CastKind::ReinterpretLike, /*Fatal=*/true, CE); + case CK_HLSLArrayRValue: { + // Non-decaying array rvalue cast - creates an rvalue copy of an lvalue + // array, similar to LValueToRValue for composite types. + if (!Initializing) { + UnsignedOrNone LocalIndex = allocateLocal(CE); + if (!LocalIndex) + return false; + if (!this->emitGetPtrLocal(*LocalIndex, CE)) + return false; + } + if (!this->visit(SubExpr)) + return false; + return this->emitMemcpy(CE); + } + + case CK_HLSLMatrixTruncation: { + assert(SubExpr->getType()->isConstantMatrixType()); + if (OptPrimType ResultT = classify(CE)) { + assert(!DiscardResult); + // Result must be either a float or integer. Take the first element. 
+ if (!this->visit(SubExpr)) + return false; + return this->emitArrayElemPop(*ResultT, 0, CE); + } + // Otherwise, this truncates to a a constant matrix type. + assert(CE->getType()->isConstantMatrixType()); + + if (!Initializing) { + UnsignedOrNone LocalIndex = allocateTemporary(CE); + if (!LocalIndex) + return false; + if (!this->emitGetPtrLocal(*LocalIndex, CE)) + return false; + } + unsigned ToSize = + CE->getType()->getAs<ConstantMatrixType>()->getNumElementsFlattened(); + if (!this->visit(SubExpr)) + return false; + return this->emitCopyArray(classifyMatrixElementType(SubExpr->getType()), 0, + 0, ToSize, CE); + } + + case CK_HLSLAggregateSplatCast: { + // Aggregate splat cast: convert a scalar value to one of an aggregate type. + // TODO: Aggregate splat to struct and array types + assert(canClassify(SubExpr->getType())); + + unsigned NumElts; + PrimType DestElemT; + QualType DestElemType; + if (const auto *VT = CE->getType()->getAs<VectorType>()) { + NumElts = VT->getNumElements(); + DestElemType = VT->getElementType(); + } else if (const auto *MT = + CE->getType()->getAs<ConstantMatrixType>()) { + NumElts = MT->getNumElementsFlattened(); + DestElemType = MT->getElementType(); + } else { + return false; + } + DestElemT = classifyPrim(DestElemType); + + if (!Initializing) { + UnsignedOrNone LocalIndex = allocateLocal(CE); + if (!LocalIndex) + return false; + if (!this->emitGetPtrLocal(*LocalIndex, CE)) + return false; + } + + PrimType SrcElemT = classifyPrim(SubExpr->getType()); + unsigned SrcOffset = + allocateLocalPrimitive(SubExpr, DestElemT, /*IsConst=*/true); + + if (!this->visit(SubExpr)) + return false; + if (classifyPrim(SubExpr) == PT_Ptr && !this->emitLoadPop(SrcElemT, CE)) + return false; + if (SrcElemT != DestElemT) { + if (!this->emitPrimCast(SrcElemT, DestElemT, DestElemType, CE)) + return false; + } + if (!this->emitSetLocal(DestElemT, SrcOffset, CE)) + return false; + + for (unsigned I = 0; I != NumElts; ++I) { + if 
(!this->emitGetLocal(DestElemT, SrcOffset, CE)) + return false; + if (!this->emitInitElem(DestElemT, I, CE)) + return false; + } + return true; + } + + case CK_HLSLElementwiseCast: { + // Elementwise cast: flatten source elements of one aggregate type and store + // to a destination aggregate type of the same or fewer number of elements. + // TODO: Elementwise cast to structs, nested arrays, and arrays of composite + // types + QualType SrcType = SubExpr->getType(); + QualType DestType = CE->getType(); + + unsigned SrcNumElts; + PrimType SrcElemT; + if (const auto *VT = SrcType->getAs<VectorType>()) { + SrcNumElts = VT->getNumElements(); + SrcElemT = classifyPrim(VT->getElementType()); + } else if (const auto *MT = SrcType->getAs<ConstantMatrixType>()) { + SrcNumElts = MT->getNumElementsFlattened(); + SrcElemT = classifyPrim(MT->getElementType()); + } else if (const auto *AT = SrcType->getAsArrayTypeUnsafe()) { + if (const auto *CAT = dyn_cast<ConstantArrayType>(AT)) { + SrcNumElts = CAT->getZExtSize(); + SrcElemT = classifyPrim(CAT->getElementType()); + } else { + return false; + } + } else { + return false; + } + + unsigned DestNumElts; + PrimType DestElemT; + QualType DestElemType; + if (const auto *VT = DestType->getAs<VectorType>()) { + DestNumElts = VT->getNumElements(); + DestElemType = VT->getElementType(); + } else if (const auto *MT = DestType->getAs<ConstantMatrixType>()) { + DestNumElts = MT->getNumElementsFlattened(); + DestElemType = MT->getElementType(); + } else if (const auto *AT = DestType->getAsArrayTypeUnsafe()) { + if (const auto *CAT = dyn_cast<ConstantArrayType>(AT)) { + DestNumElts = CAT->getZExtSize(); + DestElemType = CAT->getElementType(); + } else { + return false; + } + } else { + return false; + } + DestElemT = classifyPrim(DestElemType); + + if (!Initializing) { + UnsignedOrNone LocalIndex = + classify(DestType) ? 
allocateLocal(CE) : allocateTemporary(CE); + if (!LocalIndex) + return false; + if (!this->emitGetPtrLocal(*LocalIndex, CE)) + return false; + } + + unsigned SrcOffset = + allocateLocalPrimitive(SubExpr, PT_Ptr, /*IsConst=*/true); + if (!this->visit(SubExpr)) + return false; + if (!this->emitSetLocal(PT_Ptr, SrcOffset, CE)) + return false; + + unsigned NumElts = std::min(SrcNumElts, DestNumElts); + for (unsigned I = 0; I != NumElts; ++I) { + if (!this->emitGetLocal(PT_Ptr, SrcOffset, CE)) + return false; + if (!this->emitArrayElemPop(SrcElemT, I, CE)) + return false; + if (SrcElemT != DestElemT) { + if (!this->emitPrimCast(SrcElemT, DestElemT, DestElemType, CE)) + return false; + } + if (!this->emitInitElem(DestElemT, I, CE)) + return false; + } + return true; + } + default: return this->emitInvalid(CE); } @@ -1813,6 +1987,20 @@ bool Compiler<Emitter>::VisitImplicitValueInitExpr( return true; } + if (const auto *MT = E->getType()->getAs<ConstantMatrixType>()) { + unsigned NumElts = MT->getNumElementsFlattened(); + QualType ElemQT = MT->getElementType(); + PrimType ElemT = classifyPrim(ElemQT); + + for (unsigned I = 0; I < NumElts; ++I) { + if (!this->visitZeroInitializer(ElemT, ElemQT, E)) + return false; + if (!this->emitInitElem(ElemT, I, E)) + return false; + } + return true; + } + return false; } @@ -2129,6 +2317,25 @@ bool Compiler<Emitter>::visitInitList(ArrayRef<const Expr *> Inits, return true; } + if (const auto *MT = QT->getAs<ConstantMatrixType>()) { + unsigned NumElts = MT->getNumElementsFlattened(); + assert(Inits.size() == NumElts); + + QualType ElemQT = MT->getElementType(); + PrimType ElemT = classifyPrim(ElemQT); + + // InitListExpr elements are in column-major order. + // Store in row-major order to match APValue convention. 
+ for (unsigned I = 0; I < NumElts; ++I) { + if (!this->visit(Inits[I])) + return false; + if (!this->emitInitElem(ElemT, + MT->mapColumnMajorToRowMajorFlattenedIndex(I), E)) + return false; + } + return true; + } + return false; } diff --git a/clang/lib/AST/ByteCode/Compiler.h b/clang/lib/AST/ByteCode/Compiler.h index 1bd15c3d79563..74ded47e88792 100644 --- a/clang/lib/AST/ByteCode/Compiler.h +++ b/clang/lib/AST/ByteCode/Compiler.h @@ -406,6 +406,11 @@ class Compiler : public ConstStmtVisitor<Compiler<Emitter>, bool>, return *this->classify(T->getAs<VectorType>()->getElementType()); } + PrimType classifyMatrixElementType(QualType T) const { + assert(T->isMatrixType()); + return *this->classify(T->getAs<MatrixType>()->getElementType()); + } + bool emitComplexReal(const Expr *SubExpr); bool emitComplexBoolCast(const Expr *E); bool emitComplexComparison(const Expr *LHS, const Expr *RHS, diff --git a/clang/lib/AST/ByteCode/Pointer.cpp b/clang/lib/AST/ByteCode/Pointer.cpp index e237013f4199c..a569a4221cf2f 100644 --- a/clang/lib/AST/ByteCode/Pointer.cpp +++ b/clang/lib/AST/ByteCode/Pointer.cpp @@ -934,6 +934,24 @@ std::optional<APValue> Pointer::toRValue(const Context &Ctx, return true; } + // Constant Matrix types. 
+ if (const auto *MT = Ty->getAs<ConstantMatrixType>()) { + assert(Ptr.getFieldDesc()->isPrimitiveArray()); + QualType ElemTy = MT->getElementType(); + PrimType ElemT = *Ctx.classify(ElemTy); + unsigned NumElts = MT->getNumElementsFlattened(); + + SmallVector<APValue> Values; + Values.reserve(NumElts); + for (unsigned I = 0; I != NumElts; ++I) { + TYPE_SWITCH(ElemT, + { Values.push_back(Ptr.elem<T>(I).toAPValue(ASTCtx)); }); + } + + R = APValue(Values.data(), MT->getNumRows(), MT->getNumColumns()); + return true; + } + llvm_unreachable("invalid value to return"); }; diff --git a/clang/lib/AST/ByteCode/Program.cpp b/clang/lib/AST/ByteCode/Program.cpp index 76fec63a8920d..8876ded409415 100644 --- a/clang/lib/AST/ByteCode/Program.cpp +++ b/clang/lib/AST/ByteCode/Program.cpp @@ -474,5 +474,16 @@ Descriptor *Program::createDescriptor(const DeclTy &D, const Type *Ty, IsTemporary, IsMutable); } + // Same with constant matrix types. + if (const auto *MT = Ty->getAs<ConstantMatrixType>()) { + OptPrimType ElemTy = Ctx.classify(MT->getElementType()); + if (!ElemTy) + return nullptr; + + return allocateDescriptor(D, *ElemTy, MDSize, + MT->getNumElementsFlattened(), IsConst, + IsTemporary, IsMutable); + } + return nullptr; } diff --git a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl index 64220980d9edc..608af75bff4bb 100644 --- a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl +++ b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl @@ -1,5 +1,7 @@ // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=column-major -verify %s // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=row-major -verify %s +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=column-major 
-fexperimental-new-constant-interpreter -verify %s +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=row-major -fexperimental-new-constant-interpreter -verify %s // expected-no-diagnostics >From cd61ecb70cf65b5df44154c418179db137399645 Mon Sep 17 00:00:00 2001 From: Deric Cheung <[email protected]> Date: Wed, 25 Feb 2026 16:59:02 -0800 Subject: [PATCH 2/9] Apply clang-format --- clang/lib/AST/ByteCode/Compiler.cpp | 3 +-- clang/lib/AST/ByteCode/Program.cpp | 5 ++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp index 87982e67dcb51..6f5ed3b7dd107 100644 --- a/clang/lib/AST/ByteCode/Compiler.cpp +++ b/clang/lib/AST/ByteCode/Compiler.cpp @@ -864,8 +864,7 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) { if (const auto *VT = CE->getType()->getAs<VectorType>()) { NumElts = VT->getNumElements(); DestElemType = VT->getElementType(); - } else if (const auto *MT = - CE->getType()->getAs<ConstantMatrixType>()) { + } else if (const auto *MT = CE->getType()->getAs<ConstantMatrixType>()) { NumElts = MT->getNumElementsFlattened(); DestElemType = MT->getElementType(); } else { diff --git a/clang/lib/AST/ByteCode/Program.cpp b/clang/lib/AST/ByteCode/Program.cpp index 8876ded409415..a2cba3da27675 100644 --- a/clang/lib/AST/ByteCode/Program.cpp +++ b/clang/lib/AST/ByteCode/Program.cpp @@ -480,9 +480,8 @@ Descriptor *Program::createDescriptor(const DeclTy &D, const Type *Ty, if (!ElemTy) return nullptr; - return allocateDescriptor(D, *ElemTy, MDSize, - MT->getNumElementsFlattened(), IsConst, - IsTemporary, IsMutable); + return allocateDescriptor(D, *ElemTy, MDSize, MT->getNumElementsFlattened(), + IsConst, IsTemporary, IsMutable); } return nullptr; >From 7bc9164b5d3a2977f55136844f7a39c9c952f1cd Mon Sep 17 00:00:00 2001 From: Deric Cheung <[email protected]> Date: Wed, 25 Feb 2026 22:02:56 -0800 Subject: [PATCH 
3/9] Rename NumElts to NumElems --- clang/lib/AST/ByteCode/Compiler.cpp | 38 ++++++++++++++--------------- clang/lib/AST/ByteCode/Pointer.cpp | 6 ++--- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp index 6f5ed3b7dd107..32ec16570a5b6 100644 --- a/clang/lib/AST/ByteCode/Compiler.cpp +++ b/clang/lib/AST/ByteCode/Compiler.cpp @@ -858,14 +858,14 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) { // TODO: Aggregate splat to struct and array types assert(canClassify(SubExpr->getType())); - unsigned NumElts; + unsigned NumElems; PrimType DestElemT; QualType DestElemType; if (const auto *VT = CE->getType()->getAs<VectorType>()) { - NumElts = VT->getNumElements(); + NumElems = VT->getNumElements(); DestElemType = VT->getElementType(); } else if (const auto *MT = CE->getType()->getAs<ConstantMatrixType>()) { - NumElts = MT->getNumElementsFlattened(); + NumElems = MT->getNumElementsFlattened(); DestElemType = MT->getElementType(); } else { return false; @@ -895,7 +895,7 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) { if (!this->emitSetLocal(DestElemT, SrcOffset, CE)) return false; - for (unsigned I = 0; I != NumElts; ++I) { + for (unsigned I = 0; I != NumElems; ++I) { if (!this->emitGetLocal(DestElemT, SrcOffset, CE)) return false; if (!this->emitInitElem(DestElemT, I, CE)) @@ -912,17 +912,17 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) { QualType SrcType = SubExpr->getType(); QualType DestType = CE->getType(); - unsigned SrcNumElts; + unsigned SrcNumElems; PrimType SrcElemT; if (const auto *VT = SrcType->getAs<VectorType>()) { - SrcNumElts = VT->getNumElements(); + SrcNumElems = VT->getNumElements(); SrcElemT = classifyPrim(VT->getElementType()); } else if (const auto *MT = SrcType->getAs<ConstantMatrixType>()) { - SrcNumElts = MT->getNumElementsFlattened(); + SrcNumElems = MT->getNumElementsFlattened(); SrcElemT = 
classifyPrim(MT->getElementType()); } else if (const auto *AT = SrcType->getAsArrayTypeUnsafe()) { if (const auto *CAT = dyn_cast<ConstantArrayType>(AT)) { - SrcNumElts = CAT->getZExtSize(); + SrcNumElems = CAT->getZExtSize(); SrcElemT = classifyPrim(CAT->getElementType()); } else { return false; @@ -931,18 +931,18 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) { return false; } - unsigned DestNumElts; + unsigned DestNumElems; PrimType DestElemT; QualType DestElemType; if (const auto *VT = DestType->getAs<VectorType>()) { - DestNumElts = VT->getNumElements(); + DestNumElems = VT->getNumElements(); DestElemType = VT->getElementType(); } else if (const auto *MT = DestType->getAs<ConstantMatrixType>()) { - DestNumElts = MT->getNumElementsFlattened(); + DestNumElems = MT->getNumElementsFlattened(); DestElemType = MT->getElementType(); } else if (const auto *AT = DestType->getAsArrayTypeUnsafe()) { if (const auto *CAT = dyn_cast<ConstantArrayType>(AT)) { - DestNumElts = CAT->getZExtSize(); + DestNumElems = CAT->getZExtSize(); DestElemType = CAT->getElementType(); } else { return false; @@ -968,8 +968,8 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) { if (!this->emitSetLocal(PT_Ptr, SrcOffset, CE)) return false; - unsigned NumElts = std::min(SrcNumElts, DestNumElts); - for (unsigned I = 0; I != NumElts; ++I) { + unsigned NumElems = std::min(SrcNumElems, DestNumElems); + for (unsigned I = 0; I != NumElems; ++I) { if (!this->emitGetLocal(PT_Ptr, SrcOffset, CE)) return false; if (!this->emitArrayElemPop(SrcElemT, I, CE)) @@ -1987,11 +1987,11 @@ bool Compiler<Emitter>::VisitImplicitValueInitExpr( } if (const auto *MT = E->getType()->getAs<ConstantMatrixType>()) { - unsigned NumElts = MT->getNumElementsFlattened(); + unsigned NumElems = MT->getNumElementsFlattened(); QualType ElemQT = MT->getElementType(); PrimType ElemT = classifyPrim(ElemQT); - for (unsigned I = 0; I < NumElts; ++I) { + for (unsigned I = 0; I < NumElems; ++I) { if 
(!this->visitZeroInitializer(ElemT, ElemQT, E)) return false; if (!this->emitInitElem(ElemT, I, E)) @@ -2317,15 +2317,15 @@ bool Compiler<Emitter>::visitInitList(ArrayRef<const Expr *> Inits, } if (const auto *MT = QT->getAs<ConstantMatrixType>()) { - unsigned NumElts = MT->getNumElementsFlattened(); - assert(Inits.size() == NumElts); + unsigned NumElems = MT->getNumElementsFlattened(); + assert(Inits.size() == NumElems); QualType ElemQT = MT->getElementType(); PrimType ElemT = classifyPrim(ElemQT); // InitListExpr elements are in column-major order. // Store in row-major order to match APValue convention. - for (unsigned I = 0; I < NumElts; ++I) { + for (unsigned I = 0; I < NumElems; ++I) { if (!this->visit(Inits[I])) return false; if (!this->emitInitElem(ElemT, diff --git a/clang/lib/AST/ByteCode/Pointer.cpp b/clang/lib/AST/ByteCode/Pointer.cpp index a569a4221cf2f..f4352e7edf5f8 100644 --- a/clang/lib/AST/ByteCode/Pointer.cpp +++ b/clang/lib/AST/ByteCode/Pointer.cpp @@ -939,11 +939,11 @@ std::optional<APValue> Pointer::toRValue(const Context &Ctx, assert(Ptr.getFieldDesc()->isPrimitiveArray()); QualType ElemTy = MT->getElementType(); PrimType ElemT = *Ctx.classify(ElemTy); - unsigned NumElts = MT->getNumElementsFlattened(); + unsigned NumElems = MT->getNumElementsFlattened(); SmallVector<APValue> Values; - Values.reserve(NumElts); - for (unsigned I = 0; I != NumElts; ++I) { + Values.reserve(NumElems); + for (unsigned I = 0; I != NumElems; ++I) { TYPE_SWITCH(ElemT, { Values.push_back(Ptr.elem<T>(I).toAPValue(ASTCtx)); }); } >From 4bd726d2c2dd898b8403103cb4e0ab962f8b9de4 Mon Sep 17 00:00:00 2001 From: Deric Cheung <[email protected]> Date: Wed, 25 Feb 2026 22:10:49 -0800 Subject: [PATCH 4/9] Remove unnecessary SubExpr ptr check and load This code was copied over from the CK_VectorSplat case but is not needed for CK_HLSLAggregateSplatCast. 
The `classifyPrim(SubExpr) == PT_Ptr && !this->emitLoadPop(SrcElemT, CE)` check is not necessary because an lvalue-to-rvalue conversion is always inserted with an aggregate splat cast, so the SubExpr will never be a ptr. See https://github.com/llvm/llvm-project/blob/143664fcd3df825befdb9586151d53aefef3d7d0/clang/lib/Sema/SemaCast.cpp#L2939-L2940 for confirmation that this is the case. --- clang/lib/AST/ByteCode/Compiler.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp index 32ec16570a5b6..9cdaa3cc31288 100644 --- a/clang/lib/AST/ByteCode/Compiler.cpp +++ b/clang/lib/AST/ByteCode/Compiler.cpp @@ -886,8 +886,6 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) { if (!this->visit(SubExpr)) return false; - if (classifyPrim(SubExpr) == PT_Ptr && !this->emitLoadPop(SrcElemT, CE)) - return false; if (SrcElemT != DestElemT) { if (!this->emitPrimCast(SrcElemT, DestElemT, DestElemType, CE)) return false; >From 952f2cb56d629a6f07ddce29a0b92ac5bcd3dc04 Mon Sep 17 00:00:00 2001 From: Deric Cheung <[email protected]> Date: Wed, 25 Feb 2026 23:16:09 -0800 Subject: [PATCH 5/9] Support scalar DestType for HLSLElementwiseCast Also edit the comments of HLSLElementwiseCast and HLSLAggregateSplatCast to be clearer. Assisted-by: claude-opus-4.6 --- clang/lib/AST/ByteCode/Compiler.cpp | 22 +++++++++++++++---- .../BuiltinMatrix/MatrixConstantExpr.hlsl | 2 ++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp index 9cdaa3cc31288..08c39538107df 100644 --- a/clang/lib/AST/ByteCode/Compiler.cpp +++ b/clang/lib/AST/ByteCode/Compiler.cpp @@ -854,7 +854,9 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) { } case CK_HLSLAggregateSplatCast: { - // Aggregate splat cast: convert a scalar value to one of an aggregate type.
+ // Aggregate splat cast: convert a scalar value to one of an aggregate type, + // inserting casts when necessary to convert the scalar to the aggregate's + // element type(s). // TODO: Aggregate splat to struct and array types assert(canClassify(SubExpr->getType())); @@ -904,7 +906,8 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) { case CK_HLSLElementwiseCast: { // Elementwise cast: flatten source elements of one aggregate type and store - // to a destination aggregate type of the same or fewer number of elements. + // to a destination scalar or aggregate type of the same or fewer number of + // elements, while inserting casts as necessary. // TODO: Elementwise cast to structs, nested arrays, and arrays of composite // types QualType SrcType = SubExpr->getType(); @@ -945,14 +948,25 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) { } else { return false; } + } else if (classify(DestType)) { + // Scalar destination: extract element 0 and cast. + PrimType DestT = classifyPrim(DestType); + if (!this->visit(SubExpr)) + return false; + if (!this->emitArrayElemPop(SrcElemT, 0, CE)) + return false; + if (SrcElemT != DestT) { + if (!this->emitPrimCast(SrcElemT, DestT, DestType, CE)) + return false; + } + return true; } else { return false; } DestElemT = classifyPrim(DestElemType); if (!Initializing) { - UnsignedOrNone LocalIndex = - classify(DestType) ? 
allocateLocal(CE) : allocateTemporary(CE); + UnsignedOrNone LocalIndex = allocateTemporary(CE); if (!LocalIndex) return false; if (!this->emitGetPtrLocal(*LocalIndex, CE)) diff --git a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl index 608af75bff4bb..2c55e1a0ee4b3 100644 --- a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl +++ b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl @@ -45,6 +45,8 @@ export void fn() { _Static_assert(FA4[1] == 2.5, "Woo!"); _Static_assert(FA4[2] == 3.5, "Woo!"); _Static_assert(FA4[3] == 4.5, "Woo!"); + constexpr float F = (float)FA4; + _Static_assert(F == 1.5, "Woo!"); } // Array cast to matrix to vector >From 974eec6611da919c1489f9d498e9be5913723660 Mon Sep 17 00:00:00 2001 From: Deric Cheung <[email protected]> Date: Wed, 25 Feb 2026 23:27:32 -0800 Subject: [PATCH 6/9] Change < to != in loop guards --- clang/lib/AST/ByteCode/Compiler.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp index 08c39538107df..914da43a3d714 100644 --- a/clang/lib/AST/ByteCode/Compiler.cpp +++ b/clang/lib/AST/ByteCode/Compiler.cpp @@ -2003,7 +2003,7 @@ bool Compiler<Emitter>::VisitImplicitValueInitExpr( QualType ElemQT = MT->getElementType(); PrimType ElemT = classifyPrim(ElemQT); - for (unsigned I = 0; I < NumElems; ++I) { + for (unsigned I = 0; I != NumElems; ++I) { if (!this->visitZeroInitializer(ElemT, ElemQT, E)) return false; if (!this->emitInitElem(ElemT, I, E)) @@ -2337,7 +2337,7 @@ bool Compiler<Emitter>::visitInitList(ArrayRef<const Expr *> Inits, // InitListExpr elements are in column-major order. // Store in row-major order to match APValue convention. 
- for (unsigned I = 0; I < NumElems; ++I) { + for (unsigned I = 0; I != NumElems; ++I) { if (!this->visit(Inits[I])) return false; if (!this->emitInitElem(ElemT, >From eb43998a7e2e081ced9c23fcf45307d26c74e052 Mon Sep 17 00:00:00 2001 From: Deric Cheung <[email protected]> Date: Thu, 26 Feb 2026 09:04:28 -0800 Subject: [PATCH 7/9] Replace PrimType DestT with OptPrimType DestT --- clang/lib/AST/ByteCode/Compiler.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp index 914da43a3d714..2f7026b961bbd 100644 --- a/clang/lib/AST/ByteCode/Compiler.cpp +++ b/clang/lib/AST/ByteCode/Compiler.cpp @@ -948,15 +948,14 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) { } else { return false; } - } else if (classify(DestType)) { + } else if (OptPrimType DestT = classify(DestType)) { // Scalar destination: extract element 0 and cast. - PrimType DestT = classifyPrim(DestType); if (!this->visit(SubExpr)) return false; if (!this->emitArrayElemPop(SrcElemT, 0, CE)) return false; - if (SrcElemT != DestT) { - if (!this->emitPrimCast(SrcElemT, DestT, DestType, CE)) + if (SrcElemT != *DestT) { + if (!this->emitPrimCast(SrcElemT, *DestT, DestType, CE)) return false; } return true; >From 6cb36787d65fe34087213c7a79c9d9203df9060c Mon Sep 17 00:00:00 2001 From: Deric Cheung <[email protected]> Date: Thu, 26 Feb 2026 09:59:06 -0800 Subject: [PATCH 8/9] Early return false from HLSLElementwiseCast if Src or Dest types are not allowed --- clang/lib/AST/ByteCode/Compiler.cpp | 85 ++++++++++++++++------------- 1 file changed, 47 insertions(+), 38 deletions(-) diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp index 2f7026b961bbd..990539446ce22 100644 --- a/clang/lib/AST/ByteCode/Compiler.cpp +++ b/clang/lib/AST/ByteCode/Compiler.cpp @@ -913,54 +913,63 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) { QualType SrcType = SubExpr->getType(); 
QualType DestType = CE->getType(); - unsigned SrcNumElems; - PrimType SrcElemT; - if (const auto *VT = SrcType->getAs<VectorType>()) { - SrcNumElems = VT->getNumElements(); - SrcElemT = classifyPrim(VT->getElementType()); - } else if (const auto *MT = SrcType->getAs<ConstantMatrixType>()) { - SrcNumElems = MT->getNumElementsFlattened(); - SrcElemT = classifyPrim(MT->getElementType()); - } else if (const auto *AT = SrcType->getAsArrayTypeUnsafe()) { - if (const auto *CAT = dyn_cast<ConstantArrayType>(AT)) { - SrcNumElems = CAT->getZExtSize(); - SrcElemT = classifyPrim(CAT->getElementType()); - } else { - return false; - } - } else { + // Allowed SrcTypes + const auto *SrcVT = SrcType->getAs<VectorType>(); + const auto *SrcMT = SrcType->getAs<ConstantMatrixType>(); + const auto *SrcAT = SrcType->getAsArrayTypeUnsafe(); + const auto *SrcCAT = SrcAT ? dyn_cast<ConstantArrayType>(SrcAT) : nullptr; + + // Allowed DestTypes + const auto *DestVT = DestType->getAs<VectorType>(); + const auto *DestMT = DestType->getAs<ConstantMatrixType>(); + const auto *DestAT = DestType->getAsArrayTypeUnsafe(); + const auto *DestCAT = + DestAT ? 
dyn_cast<ConstantArrayType>(DestAT) : nullptr; + const OptPrimType DestPT = classify(DestType); + + if (!SrcVT && !SrcMT && !SrcCAT) + return false; + if (!DestVT && !DestMT && !DestCAT && !DestPT) return false; - } - unsigned DestNumElems; - PrimType DestElemT; - QualType DestElemType; - if (const auto *VT = DestType->getAs<VectorType>()) { - DestNumElems = VT->getNumElements(); - DestElemType = VT->getElementType(); - } else if (const auto *MT = DestType->getAs<ConstantMatrixType>()) { - DestNumElems = MT->getNumElementsFlattened(); - DestElemType = MT->getElementType(); - } else if (const auto *AT = DestType->getAsArrayTypeUnsafe()) { - if (const auto *CAT = dyn_cast<ConstantArrayType>(AT)) { - DestNumElems = CAT->getZExtSize(); - DestElemType = CAT->getElementType(); - } else { - return false; - } - } else if (OptPrimType DestT = classify(DestType)) { + unsigned SrcNumElems; + PrimType SrcElemT; + if (SrcVT) { + SrcNumElems = SrcVT->getNumElements(); + SrcElemT = classifyPrim(SrcVT->getElementType()); + } else if (SrcMT) { + SrcNumElems = SrcMT->getNumElementsFlattened(); + SrcElemT = classifyPrim(SrcMT->getElementType()); + } else if (SrcCAT) { + SrcNumElems = SrcCAT->getZExtSize(); + SrcElemT = classifyPrim(SrcCAT->getElementType()); + } + + if (DestPT) { // Scalar destination: extract element 0 and cast. 
if (!this->visit(SubExpr)) return false; if (!this->emitArrayElemPop(SrcElemT, 0, CE)) return false; - if (SrcElemT != *DestT) { - if (!this->emitPrimCast(SrcElemT, *DestT, DestType, CE)) + if (SrcElemT != *DestPT) { + if (!this->emitPrimCast(SrcElemT, *DestPT, DestType, CE)) return false; } return true; - } else { - return false; + } + + unsigned DestNumElems; + PrimType DestElemT; + QualType DestElemType; + if (DestVT) { + DestNumElems = DestVT->getNumElements(); + DestElemType = DestVT->getElementType(); + } else if (DestMT) { + DestNumElems = DestMT->getNumElementsFlattened(); + DestElemType = DestMT->getElementType(); + } else if (DestCAT) { + DestNumElems = DestCAT->getZExtSize(); + DestElemType = DestCAT->getElementType(); } DestElemT = classifyPrim(DestElemType); >From 3665597b82776b8ee95453ec4120ba2c78d71e0f Mon Sep 17 00:00:00 2001 From: Deric Cheung <[email protected]> Date: Tue, 3 Mar 2026 11:56:40 -0800 Subject: [PATCH 9/9] Update visitInitList for row-major order matrix initializer list change in upstream --- clang/lib/AST/ByteCode/Compiler.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp index dfe7fa173152a..93ad8eb26f29e 100644 --- a/clang/lib/AST/ByteCode/Compiler.cpp +++ b/clang/lib/AST/ByteCode/Compiler.cpp @@ -2343,13 +2343,13 @@ bool Compiler<Emitter>::visitInitList(ArrayRef<const Expr *> Inits, QualType ElemQT = MT->getElementType(); PrimType ElemT = classifyPrim(ElemQT); - // InitListExpr elements are in column-major order. - // Store in row-major order to match APValue convention. + // Matrix initializer list elements are in row-major order, which matches + // the matrix APValue convention and therefore no index remapping is + // required. 
for (unsigned I = 0; I != NumElems; ++I) { if (!this->visit(Inits[I])) return false; - if (!this->emitInitElem(ElemT, - MT->mapColumnMajorToRowMajorFlattenedIndex(I), E)) + if (!this->emitInitElem(ElemT, I, E)) return false; } return true; _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
