https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/159446
>From 26980a21413801f0fd0fbe7ad1643670ed830d1b Mon Sep 17 00:00:00 2001 From: Farzon Lotfi <farzonlo...@microsoft.com> Date: Wed, 17 Sep 2025 12:16:12 -0400 Subject: [PATCH 1/2] [HLSL] Add support for the HLSL matrix type fixes #109839 This change is really simple. It creates a matrix alias that will let HLSL use the existing clang `matrix_type` infra. The only additional change was to add explicit aliases for the typed dimensions of 1-4 inclusive matrices available in HLSL. Testing therefore is limited to exercising the alias. The main difference in this attempt is the type printer. --- clang/include/clang/Driver/Options.td | 2 +- .../clang/Sema/HLSLExternalSemaSource.h | 1 + clang/lib/AST/TypePrinter.cpp | 37 ++- clang/lib/Headers/hlsl/hlsl_basic_types.h | 233 ++++++++++++++++++ clang/lib/Sema/HLSLExternalSemaSource.cpp | 72 ++++++ clang/lib/Sema/SemaHLSL.cpp | 1 - clang/test/AST/HLSL/matrix-alias.hlsl | 49 ++++ .../builtins/transpose-builtin.hlsl | 30 +++ .../BuiltIns/matrix-basic_types-errors.hlsl | 5 + .../test/SemaHLSL/BuiltIns/matrix-errors.hlsl | 29 +++ 10 files changed, 451 insertions(+), 8 deletions(-) create mode 100644 clang/test/AST/HLSL/matrix-alias.hlsl create mode 100644 clang/test/CodeGenHLSL/builtins/transpose-builtin.hlsl create mode 100644 clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl create mode 100644 clang/test/SemaHLSL/BuiltIns/matrix-errors.hlsl diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td index a7c514e809aa9..e8869e04390f3 100644 --- a/clang/include/clang/Driver/Options.td +++ b/clang/include/clang/Driver/Options.td @@ -4582,7 +4582,7 @@ defm ptrauth_block_descriptor_pointers : OptInCC1FFlag<"ptrauth-block-descriptor def fenable_matrix : Flag<["-"], "fenable-matrix">, Group<f_Group>, Visibility<[ClangOption, CC1Option]>, HelpText<"Enable matrix data type and related builtin functions">, - MarshallingInfoFlag<LangOpts<"MatrixTypes">>; + MarshallingInfoFlag<LangOpts<"MatrixTypes">, 
hlsl.KeyPath>; defm raw_string_literals : BoolFOption<"raw-string-literals", LangOpts<"RawStringLiterals">, Default<std#".hasRawStringLiterals()">, diff --git a/clang/include/clang/Sema/HLSLExternalSemaSource.h b/clang/include/clang/Sema/HLSLExternalSemaSource.h index d93fb8c8eef6b..049fc7b8fe3f2 100644 --- a/clang/include/clang/Sema/HLSLExternalSemaSource.h +++ b/clang/include/clang/Sema/HLSLExternalSemaSource.h @@ -44,6 +44,7 @@ class HLSLExternalSemaSource : public ExternalSemaSource { private: void defineTrivialHLSLTypes(); void defineHLSLVectorAlias(); + void defineHLSLMatrixAlias(); void defineHLSLTypesWithForwardDeclarations(); void onCompletion(CXXRecordDecl *Record, CompletionFunction Fn); }; diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp index cd59678d67f2f..82859c4015772 100644 --- a/clang/lib/AST/TypePrinter.cpp +++ b/clang/lib/AST/TypePrinter.cpp @@ -846,16 +846,41 @@ void TypePrinter::printExtVectorAfter(const ExtVectorType *T, raw_ostream &OS) { } } -void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T, - raw_ostream &OS) { - printBefore(T->getElementType(), OS); - OS << " __attribute__((matrix_type("; +static void printDims(const ConstantMatrixType *T, raw_ostream &OS) { OS << T->getNumRows() << ", " << T->getNumColumns(); +} + +static void printHLSLMatrixBefore(TypePrinter &TP, const ConstantMatrixType *T, raw_ostream &OS) { + OS << "matrix<"; + TP.printBefore(T->getElementType(), OS); +} + +static void printHLSLMatrixAfter(const ConstantMatrixType *T, raw_ostream &OS) { + OS << ", "; + printDims(T, OS); + OS << ">"; +} + +static void printClangMatrixBefore(TypePrinter &TP, const ConstantMatrixType *T, raw_ostream &OS) { + TP.printBefore(T->getElementType(), OS); + OS << " __attribute__((matrix_type("; + printDims(T, OS); OS << ")))"; } -void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T, - raw_ostream &OS) { +void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T, 
raw_ostream &OS) { + if (Policy.UseHLSLTypes) { + printHLSLMatrixBefore(*this, T, OS); + return; + } + printClangMatrixBefore(*this, T, OS); +} + +void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T, raw_ostream &OS) { + if (Policy.UseHLSLTypes) { + printHLSLMatrixAfter(T, OS); + return; + } printAfter(T->getElementType(), OS); } diff --git a/clang/lib/Headers/hlsl/hlsl_basic_types.h b/clang/lib/Headers/hlsl/hlsl_basic_types.h index eff94e0d7f950..c750261952d4f 100644 --- a/clang/lib/Headers/hlsl/hlsl_basic_types.h +++ b/clang/lib/Headers/hlsl/hlsl_basic_types.h @@ -115,6 +115,239 @@ typedef vector<float64_t, 2> float64_t2; typedef vector<float64_t, 3> float64_t3; typedef vector<float64_t, 4> float64_t4; +ifdef __HLSL_ENABLE_16_BIT +typedef matrix<int16_t, 1, 1> int16_t1x1; +typedef matrix<int16_t, 1, 2> int16_t1x2; +typedef matrix<int16_t, 1, 3> int16_t1x3; +typedef matrix<int16_t, 1, 4> int16_t1x4; +typedef matrix<int16_t, 2, 1> int16_t2x1; +typedef matrix<int16_t, 2, 2> int16_t2x2; +typedef matrix<int16_t, 2, 3> int16_t2x3; +typedef matrix<int16_t, 2, 4> int16_t2x4; +typedef matrix<int16_t, 3, 1> int16_t3x1; +typedef matrix<int16_t, 3, 2> int16_t3x2; +typedef matrix<int16_t, 3, 3> int16_t3x3; +typedef matrix<int16_t, 3, 4> int16_t3x4; +typedef matrix<int16_t, 4, 1> int16_t4x1; +typedef matrix<int16_t, 4, 2> int16_t4x2; +typedef matrix<int16_t, 4, 3> int16_t4x3; +typedef matrix<int16_t, 4, 4> int16_t4x4; +typedef matrix<uint16_t, 1, 1> uint16_t1x1; +typedef matrix<uint16_t, 1, 2> uint16_t1x2; +typedef matrix<uint16_t, 1, 3> uint16_t1x3; +typedef matrix<uint16_t, 1, 4> uint16_t1x4; +typedef matrix<uint16_t, 2, 1> uint16_t2x1; +typedef matrix<uint16_t, 2, 2> uint16_t2x2; +typedef matrix<uint16_t, 2, 3> uint16_t2x3; +typedef matrix<uint16_t, 2, 4> uint16_t2x4; +typedef matrix<uint16_t, 3, 1> uint16_t3x1; +typedef matrix<uint16_t, 3, 2> uint16_t3x2; +typedef matrix<uint16_t, 3, 3> uint16_t3x3; +typedef matrix<uint16_t, 3, 4> uint16_t3x4; +typedef 
matrix<uint16_t, 4, 1> uint16_t4x1; +typedef matrix<uint16_t, 4, 2> uint16_t4x2; +typedef matrix<uint16_t, 4, 3> uint16_t4x3; +typedef matrix<uint16_t, 4, 4> uint16_t4x4; +#endif + +typedef matrix<int, 1, 1> int1x1; +typedef matrix<int, 1, 2> int1x2; +typedef matrix<int, 1, 3> int1x3; +typedef matrix<int, 1, 4> int1x4; +typedef matrix<int, 2, 1> int2x1; +typedef matrix<int, 2, 2> int2x2; +typedef matrix<int, 2, 3> int2x3; +typedef matrix<int, 2, 4> int2x4; +typedef matrix<int, 3, 1> int3x1; +typedef matrix<int, 3, 2> int3x2; +typedef matrix<int, 3, 3> int3x3; +typedef matrix<int, 3, 4> int3x4; +typedef matrix<int, 4, 1> int4x1; +typedef matrix<int, 4, 2> int4x2; +typedef matrix<int, 4, 3> int4x3; +typedef matrix<int, 4, 4> int4x4; +typedef matrix<uint, 1, 1> uint1x1; +typedef matrix<uint, 1, 2> uint1x2; +typedef matrix<uint, 1, 3> uint1x3; +typedef matrix<uint, 1, 4> uint1x4; +typedef matrix<uint, 2, 1> uint2x1; +typedef matrix<uint, 2, 2> uint2x2; +typedef matrix<uint, 2, 3> uint2x3; +typedef matrix<uint, 2, 4> uint2x4; +typedef matrix<uint, 3, 1> uint3x1; +typedef matrix<uint, 3, 2> uint3x2; +typedef matrix<uint, 3, 3> uint3x3; +typedef matrix<uint, 3, 4> uint3x4; +typedef matrix<uint, 4, 1> uint4x1; +typedef matrix<uint, 4, 2> uint4x2; +typedef matrix<uint, 4, 3> uint4x3; +typedef matrix<uint, 4, 4> uint4x4; +typedef matrix<int32_t, 1, 1> int32_t1x1; +typedef matrix<int32_t, 1, 2> int32_t1x2; +typedef matrix<int32_t, 1, 3> int32_t1x3; +typedef matrix<int32_t, 1, 4> int32_t1x4; +typedef matrix<int32_t, 2, 1> int32_t2x1; +typedef matrix<int32_t, 2, 2> int32_t2x2; +typedef matrix<int32_t, 2, 3> int32_t2x3; +typedef matrix<int32_t, 2, 4> int32_t2x4; +typedef matrix<int32_t, 3, 1> int32_t3x1; +typedef matrix<int32_t, 3, 2> int32_t3x2; +typedef matrix<int32_t, 3, 3> int32_t3x3; +typedef matrix<int32_t, 3, 4> int32_t3x4; +typedef matrix<int32_t, 4, 1> int32_t4x1; +typedef matrix<int32_t, 4, 2> int32_t4x2; +typedef matrix<int32_t, 4, 3> int32_t4x3; +typedef 
matrix<int32_t, 4, 4> int32_t4x4; +typedef matrix<uint32_t, 1, 1> uint32_t1x1; +typedef matrix<uint32_t, 1, 2> uint32_t1x2; +typedef matrix<uint32_t, 1, 3> uint32_t1x3; +typedef matrix<uint32_t, 1, 4> uint32_t1x4; +typedef matrix<uint32_t, 2, 1> uint32_t2x1; +typedef matrix<uint32_t, 2, 2> uint32_t2x2; +typedef matrix<uint32_t, 2, 3> uint32_t2x3; +typedef matrix<uint32_t, 2, 4> uint32_t2x4; +typedef matrix<uint32_t, 3, 1> uint32_t3x1; +typedef matrix<uint32_t, 3, 2> uint32_t3x2; +typedef matrix<uint32_t, 3, 3> uint32_t3x3; +typedef matrix<uint32_t, 3, 4> uint32_t3x4; +typedef matrix<uint32_t, 4, 1> uint32_t4x1; +typedef matrix<uint32_t, 4, 2> uint32_t4x2; +typedef matrix<uint32_t, 4, 3> uint32_t4x3; +typedef matrix<uint32_t, 4, 4> uint32_t4x4; +typedef matrix<int64_t, 1, 1> int64_t1x1; +typedef matrix<int64_t, 1, 2> int64_t1x2; +typedef matrix<int64_t, 1, 3> int64_t1x3; +typedef matrix<int64_t, 1, 4> int64_t1x4; +typedef matrix<int64_t, 2, 1> int64_t2x1; +typedef matrix<int64_t, 2, 2> int64_t2x2; +typedef matrix<int64_t, 2, 3> int64_t2x3; +typedef matrix<int64_t, 2, 4> int64_t2x4; +typedef matrix<int64_t, 3, 1> int64_t3x1; +typedef matrix<int64_t, 3, 2> int64_t3x2; +typedef matrix<int64_t, 3, 3> int64_t3x3; +typedef matrix<int64_t, 3, 4> int64_t3x4; +typedef matrix<int64_t, 4, 1> int64_t4x1; +typedef matrix<int64_t, 4, 2> int64_t4x2; +typedef matrix<int64_t, 4, 3> int64_t4x3; +typedef matrix<int64_t, 4, 4> int64_t4x4; +typedef matrix<uint64_t, 1, 1> uint64_t1x1; +typedef matrix<uint64_t, 1, 2> uint64_t1x2; +typedef matrix<uint64_t, 1, 3> uint64_t1x3; +typedef matrix<uint64_t, 1, 4> uint64_t1x4; +typedef matrix<uint64_t, 2, 1> uint64_t2x1; +typedef matrix<uint64_t, 2, 2> uint64_t2x2; +typedef matrix<uint64_t, 2, 3> uint64_t2x3; +typedef matrix<uint64_t, 2, 4> uint64_t2x4; +typedef matrix<uint64_t, 3, 1> uint64_t3x1; +typedef matrix<uint64_t, 3, 2> uint64_t3x2; +typedef matrix<uint64_t, 3, 3> uint64_t3x3; +typedef matrix<uint64_t, 3, 4> uint64_t3x4; +typedef 
matrix<uint64_t, 4, 1> uint64_t4x1; +typedef matrix<uint64_t, 4, 2> uint64_t4x2; +typedef matrix<uint64_t, 4, 3> uint64_t4x3; +typedef matrix<uint64_t, 4, 4> uint64_t4x4; + +typedef matrix<half, 1, 1> half1x1; +typedef matrix<half, 1, 2> half1x2; +typedef matrix<half, 1, 3> half1x3; +typedef matrix<half, 1, 4> half1x4; +typedef matrix<half, 2, 1> half2x1; +typedef matrix<half, 2, 2> half2x2; +typedef matrix<half, 2, 3> half2x3; +typedef matrix<half, 2, 4> half2x4; +typedef matrix<half, 3, 1> half3x1; +typedef matrix<half, 3, 2> half3x2; +typedef matrix<half, 3, 3> half3x3; +typedef matrix<half, 3, 4> half3x4; +typedef matrix<half, 4, 1> half4x1; +typedef matrix<half, 4, 2> half4x2; +typedef matrix<half, 4, 3> half4x3; +typedef matrix<half, 4, 4> half4x4; +typedef matrix<float, 1, 1> float1x1; +typedef matrix<float, 1, 2> float1x2; +typedef matrix<float, 1, 3> float1x3; +typedef matrix<float, 1, 4> float1x4; +typedef matrix<float, 2, 1> float2x1; +typedef matrix<float, 2, 2> float2x2; +typedef matrix<float, 2, 3> float2x3; +typedef matrix<float, 2, 4> float2x4; +typedef matrix<float, 3, 1> float3x1; +typedef matrix<float, 3, 2> float3x2; +typedef matrix<float, 3, 3> float3x3; +typedef matrix<float, 3, 4> float3x4; +typedef matrix<float, 4, 1> float4x1; +typedef matrix<float, 4, 2> float4x2; +typedef matrix<float, 4, 3> float4x3; +typedef matrix<float, 4, 4> float4x4; +typedef matrix<double, 1, 1> double1x1; +typedef matrix<double, 1, 2> double1x2; +typedef matrix<double, 1, 3> double1x3; +typedef matrix<double, 1, 4> double1x4; +typedef matrix<double, 2, 1> double2x1; +typedef matrix<double, 2, 2> double2x2; +typedef matrix<double, 2, 3> double2x3; +typedef matrix<double, 2, 4> double2x4; +typedef matrix<double, 3, 1> double3x1; +typedef matrix<double, 3, 2> double3x2; +typedef matrix<double, 3, 3> double3x3; +typedef matrix<double, 3, 4> double3x4; +typedef matrix<double, 4, 1> double4x1; +typedef matrix<double, 4, 2> double4x2; +typedef matrix<double, 4, 3> 
double4x3; +typedef matrix<double, 4, 4> double4x4; + +#ifdef __HLSL_ENABLE_16_BIT +typedef matrix<float16_t, 1, 1> float16_t1x1; +typedef matrix<float16_t, 1, 2> float16_t1x2; +typedef matrix<float16_t, 1, 3> float16_t1x3; +typedef matrix<float16_t, 1, 4> float16_t1x4; +typedef matrix<float16_t, 2, 1> float16_t2x1; +typedef matrix<float16_t, 2, 2> float16_t2x2; +typedef matrix<float16_t, 2, 3> float16_t2x3; +typedef matrix<float16_t, 2, 4> float16_t2x4; +typedef matrix<float16_t, 3, 1> float16_t3x1; +typedef matrix<float16_t, 3, 2> float16_t3x2; +typedef matrix<float16_t, 3, 3> float16_t3x3; +typedef matrix<float16_t, 3, 4> float16_t3x4; +typedef matrix<float16_t, 4, 1> float16_t4x1; +typedef matrix<float16_t, 4, 2> float16_t4x2; +typedef matrix<float16_t, 4, 3> float16_t4x3; +typedef matrix<float16_t, 4, 4> float16_t4x4; +#endif + +typedef matrix<float32_t, 1, 1> float32_t1x1; +typedef matrix<float32_t, 1, 2> float32_t1x2; +typedef matrix<float32_t, 1, 3> float32_t1x3; +typedef matrix<float32_t, 1, 4> float32_t1x4; +typedef matrix<float32_t, 2, 1> float32_t2x1; +typedef matrix<float32_t, 2, 2> float32_t2x2; +typedef matrix<float32_t, 2, 3> float32_t2x3; +typedef matrix<float32_t, 2, 4> float32_t2x4; +typedef matrix<float32_t, 3, 1> float32_t3x1; +typedef matrix<float32_t, 3, 2> float32_t3x2; +typedef matrix<float32_t, 3, 3> float32_t3x3; +typedef matrix<float32_t, 3, 4> float32_t3x4; +typedef matrix<float32_t, 4, 1> float32_t4x1; +typedef matrix<float32_t, 4, 2> float32_t4x2; +typedef matrix<float32_t, 4, 3> float32_t4x3; +typedef matrix<float32_t, 4, 4> float32_t4x4; +typedef matrix<float64_t, 1, 1> float64_t1x1; +typedef matrix<float64_t, 1, 2> float64_t1x2; +typedef matrix<float64_t, 1, 3> float64_t1x3; +typedef matrix<float64_t, 1, 4> float64_t1x4; +typedef matrix<float64_t, 2, 1> float64_t2x1; +typedef matrix<float64_t, 2, 2> float64_t2x2; +typedef matrix<float64_t, 2, 3> float64_t2x3; +typedef matrix<float64_t, 2, 4> float64_t2x4; +typedef matrix<float64_t, 
3, 1> float64_t3x1; +typedef matrix<float64_t, 3, 2> float64_t3x2; +typedef matrix<float64_t, 3, 3> float64_t3x3; +typedef matrix<float64_t, 3, 4> float64_t3x4; +typedef matrix<float64_t, 4, 1> float64_t4x1; +typedef matrix<float64_t, 4, 2> float64_t4x2; +typedef matrix<float64_t, 4, 3> float64_t4x3; +typedef matrix<float64_t, 4, 4> float64_t4x4; + } // namespace hlsl #endif //_HLSL_HLSL_BASIC_TYPES_H_ diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp index 3386d8da281e9..9ac60712b87b8 100644 --- a/clang/lib/Sema/HLSLExternalSemaSource.cpp +++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp @@ -121,8 +121,80 @@ void HLSLExternalSemaSource::defineHLSLVectorAlias() { HLSLNamespace->addDecl(Template); } +void HLSLExternalSemaSource::defineHLSLMatrixAlias() { + ASTContext &AST = SemaPtr->getASTContext(); + llvm::SmallVector<NamedDecl *> TemplateParams; + + auto *TypeParam = TemplateTypeParmDecl::Create( + AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 0, + &AST.Idents.get("element", tok::TokenKind::identifier), false, false); + TypeParam->setDefaultArgument( + AST, SemaPtr->getTrivialTemplateArgumentLoc( + TemplateArgument(AST.FloatTy), QualType(), SourceLocation())); + + TemplateParams.emplace_back(TypeParam); + + // these should be 64 bit to be consistent with other clang matrices. 
+ auto *RowsParam = NonTypeTemplateParmDecl::Create( + AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 1, + &AST.Idents.get("rows_count", tok::TokenKind::identifier), AST.IntTy, + false, AST.getTrivialTypeSourceInfo(AST.IntTy)); + llvm::APInt RVal(AST.getIntWidth(AST.IntTy), 4); + TemplateArgument RDefault(AST, llvm::APSInt(std::move(RVal)), AST.IntTy, + /*IsDefaulted=*/true); + RowsParam->setDefaultArgument( + AST, SemaPtr->getTrivialTemplateArgumentLoc(RDefault, AST.IntTy, + SourceLocation(), RowsParam)); + TemplateParams.emplace_back(RowsParam); + + auto *ColsParam = NonTypeTemplateParmDecl::Create( + AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 2, + &AST.Idents.get("cols_count", tok::TokenKind::identifier), AST.IntTy, + false, AST.getTrivialTypeSourceInfo(AST.IntTy)); + llvm::APInt CVal(AST.getIntWidth(AST.IntTy), 4); + TemplateArgument CDefault(AST, llvm::APSInt(std::move(CVal)), AST.IntTy, + /*IsDefaulted=*/true); + ColsParam->setDefaultArgument( + AST, SemaPtr->getTrivialTemplateArgumentLoc(CDefault, AST.IntTy, + SourceLocation(), ColsParam)); + TemplateParams.emplace_back(ColsParam); + + auto *ParamList = + TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(), + TemplateParams, SourceLocation(), nullptr); + + IdentifierInfo &II = AST.Idents.get("matrix", tok::TokenKind::identifier); + + QualType AliasType = AST.getDependentSizedMatrixType( + AST.getTemplateTypeParmType(0, 0, false, TypeParam), + DeclRefExpr::Create( + AST, NestedNameSpecifierLoc(), SourceLocation(), RowsParam, false, + DeclarationNameInfo(RowsParam->getDeclName(), SourceLocation()), + AST.IntTy, VK_LValue), + DeclRefExpr::Create( + AST, NestedNameSpecifierLoc(), SourceLocation(), ColsParam, false, + DeclarationNameInfo(ColsParam->getDeclName(), SourceLocation()), + AST.IntTy, VK_LValue), + SourceLocation()); + + auto *Record = TypeAliasDecl::Create(AST, HLSLNamespace, SourceLocation(), + SourceLocation(), &II, + 
AST.getTrivialTypeSourceInfo(AliasType)); + Record->setImplicit(true); + + auto *Template = + TypeAliasTemplateDecl::Create(AST, HLSLNamespace, SourceLocation(), + Record->getIdentifier(), ParamList, Record); + + Record->setDescribedAliasTemplate(Template); + Template->setImplicit(true); + Template->setLexicalDeclContext(Record->getDeclContext()); + HLSLNamespace->addDecl(Template); +} + void HLSLExternalSemaSource::defineTrivialHLSLTypes() { defineHLSLVectorAlias(); + defineHLSLMatrixAlias(); } /// Set up common members and attributes for buffer types diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 0af38472b0fec..38f81a24945c5 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -3285,7 +3285,6 @@ static void BuildFlattenedTypeList(QualType BaseTy, while (!WorkList.empty()) { QualType T = WorkList.pop_back_val(); T = T.getCanonicalType().getUnqualifiedType(); - assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL"); if (const auto *AT = dyn_cast<ConstantArrayType>(T)) { llvm::SmallVector<QualType, 16> ElementFields; // Generally I've avoided recursion in this algorithm, but arrays of diff --git a/clang/test/AST/HLSL/matrix-alias.hlsl b/clang/test/AST/HLSL/matrix-alias.hlsl new file mode 100644 index 0000000000000..2758b6f0d202f --- /dev/null +++ b/clang/test/AST/HLSL/matrix-alias.hlsl @@ -0,0 +1,49 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -ast-dump -o - %s | FileCheck %s + +// Test that matrix aliases are set up properly for HLSL + +// CHECK: NamespaceDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit hlsl +// CHECK-NEXT: TypeAliasTemplateDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector +// CHECK-NEXT: TemplateTypeParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> class depth 0 index 0 element +// CHECK-NEXT: TemplateArgument type 'float' +// CHECK-NEXT: BuiltinType 0x{{[0-9a-fA-F]+}} 'float' +// CHECK-NEXT: 
NonTypeTemplateParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> 'int' depth 0 index 1 element_count +// CHECK-NEXT: TemplateArgument expr +// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' 4 +// CHECK-NEXT: TypeAliasDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector 'vector<element, element_count>' +// CHECK-NEXT: DependentSizedExtVectorType 0x{{[0-9a-fA-F]+}} 'vector<element, element_count>' dependent <invalid sloc> +// CHECK-NEXT: TemplateTypeParmType 0x{{[0-9a-fA-F]+}} 'element' dependent depth 0 index 0 +// CHECK-NEXT: TemplateTypeParm 0x{{[0-9a-fA-F]+}} 'element' +// CHECK-NEXT: DeclRefExpr 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' lvalue +// CHECK-SAME: NonTypeTemplateParm 0x{{[0-9a-fA-F]+}} 'element_count' 'int' + +// Make sure we got a using directive at the end. +// CHECK: UsingDirectiveDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> Namespace 0x{{[0-9a-fA-F]+}} 'hlsl' + +[numthreads(1,1,1)] +int entry() { + // Verify that the alias is generated inside the hlsl namespace. + hlsl::matrix<float, 2, 2> Mat2x2f; + + // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:26:3, col:36> + // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:29> col:29 Mat2x2f 'hlsl::matrix<float, 2, 2>' + + // Verify that you don't need to specify the namespace. + matrix<int, 2, 2> Mat2x2i; + + // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:32:3, col:28> + // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:21> col:21 Mat2x2i 'matrix<int, 2, 2>' + + // Build a bigger matrix. + matrix<double, 4, 4> Mat4x4d; + + // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:38:3, col:31> + // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:24> col:24 Mat4x4d 'matrix<double, 4, 4>' + + // Verify that the implicit arguments generate the correct type. 
+ matrix<> ImpMat4x4; + + // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:44:3, col:21> + // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:12> col:12 ImpMat4x4 'matrix<>':'matrix<float, 4, 4>' + return 1; +} diff --git a/clang/test/CodeGenHLSL/builtins/transpose-builtin.hlsl b/clang/test/CodeGenHLSL/builtins/transpose-builtin.hlsl new file mode 100644 index 0000000000000..86aa7cd6985dd --- /dev/null +++ b/clang/test/CodeGenHLSL/builtins/transpose-builtin.hlsl @@ -0,0 +1,30 @@ +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s + +// NOTE: This test is only to confirm we can do codgen with the matrix alias. + +// CHECK-LABEL: define {{.*}}transpose_half_2x2 +void transpose_half_2x2(half2x2 a) { + // CHECK: [[A:%.*]] = load <4 x half>, ptr {{.*}}, align 2 + // CHECK-NEXT: [[TRANS:%.*]] = call {{.*}}<4 x half> @llvm.matrix.transpose.v4f16(<4 x half> [[A]], i32 2, i32 2) + // CHECK-NEXT: store <4 x half> [[TRANS]], ptr %a_t, align 2 + + half2x2 a_t = __builtin_matrix_transpose(a); +} + +// CHECK-LABEL: define {{.*}}transpose_float_3x2 +void transpose_float_3x2(float3x2 a) { + // CHECK: [[A:%.*]] = load <6 x float>, ptr {{.*}}, align 4 + // CHECK-NEXT: [[TRANS:%.*]] = call {{.*}}<6 x float> @llvm.matrix.transpose.v6f32(<6 x float> [[A]], i32 3, i32 2) + // CHECK-NEXT: store <6 x float> [[TRANS]], ptr %a_t, align 4 + + float2x3 a_t = __builtin_matrix_transpose(a); +} + +// CHECK-LABEL: define {{.*}}transpose_int_4x3 +void transpose_int_4x3(int4x3 a) { + // CHECK: [[A:%.*]] = load <12 x i32>, ptr {{.*}}, align 4 + // CHECK-NEXT: [[TRANS:%.*]] = call <12 x i32> @llvm.matrix.transpose.v12i32(<12 x i32> [[A]], i32 4, i32 3) + // CHECK-NEXT: store <12 x i32> [[TRANS]], ptr %a_t, align 4 + + int3x4 a_t = __builtin_matrix_transpose(a); +} diff --git a/clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl 
b/clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl new file mode 100644 index 0000000000000..e9d37ce22cc07 --- /dev/null +++ b/clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl @@ -0,0 +1,5 @@ + +// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify + +uint64_t5x5 mat; +// expected-error@-1 {{unknown type name 'uint64_t5x5'}} diff --git a/clang/test/SemaHLSL/BuiltIns/matrix-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/matrix-errors.hlsl new file mode 100644 index 0000000000000..9a820c2205843 --- /dev/null +++ b/clang/test/SemaHLSL/BuiltIns/matrix-errors.hlsl @@ -0,0 +1,29 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -fsyntax-only -verify %s + +// Some bad declarations +hlsl::matrix ShouldWorkSomeday; // expected-error{{use of alias template 'hlsl::matrix' requires template arguments}} +// expected-note@*:* {{template declaration from hidden source: template <class element = float, int rows_count = 4, int cols_count = 4> using matrix = element __attribute__((matrix_type(rows_count, cols_count)))}} + +hlsl::matrix<1,1,1> BadMat; // expected-error{{template argument for template type parameter must be a type}} +// expected-note@*:* {{template parameter from hidden source: class element = float}} + +hlsl::matrix<int, float,4> AnotherBadMat; // expected-error{{template argument for non-type template parameter must be an expression}} +// expected-note@*:* {{template parameter from hidden source: int rows_count = 4}} + +hlsl::matrix<int, 2, 3, 2> YABV; // expected-error{{too many template arguments for alias template 'matrix'}} +// expected-note@*:* {{template declaration from hidden source: template <class element = float, int rows_count = 4, int cols_count = 4> using matrix = element __attribute__((matrix_type(rows_count, cols_count)))}} + +// This code is rejected by clang because clang puts the HLSL built-in types +// into the 
HLSL namespace. +namespace hlsl { + struct matrix {}; // expected-error {{redefinition of 'matrix'}} +} + +// This code is rejected by dxc because dxc puts the HLSL built-in types +// into the global space, but clang will allow it even though it will shadow the +// matrix template. +struct matrix {}; // expected-note {{candidate found by name lookup is 'matrix'}} + +matrix<int,2,2> matInt2x2; // expected-error {{reference to 'matrix' is ambiguous}} + +// expected-note@*:* {{candidate found by name lookup is 'hlsl::matrix'}} >From 9e18ebbd44be81798e875159e7800b772ba58bde Mon Sep 17 00:00:00 2001 From: Farzon Lotfi <farzonlo...@microsoft.com> Date: Wed, 17 Sep 2025 16:26:43 -0400 Subject: [PATCH 2/2] fix formatting --- clang/lib/AST/TypePrinter.cpp | 12 ++++++++---- clang/lib/Headers/hlsl/hlsl_basic_types.h | 2 +- clang/lib/Sema/HLSLExternalSemaSource.cpp | 2 +- .../SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl | 7 +++++-- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp index 82859c4015772..f3448af5f8f50 100644 --- a/clang/lib/AST/TypePrinter.cpp +++ b/clang/lib/AST/TypePrinter.cpp @@ -850,7 +850,8 @@ static void printDims(const ConstantMatrixType *T, raw_ostream &OS) { OS << T->getNumRows() << ", " << T->getNumColumns(); } -static void printHLSLMatrixBefore(TypePrinter &TP, const ConstantMatrixType *T, raw_ostream &OS) { +static void printHLSLMatrixBefore(TypePrinter &TP, const ConstantMatrixType *T, + raw_ostream &OS) { OS << "matrix<"; TP.printBefore(T->getElementType(), OS); } @@ -861,14 +862,16 @@ static void printHLSLMatrixAfter(const ConstantMatrixType *T, raw_ostream &OS) { OS << ">"; } -static void printClangMatrixBefore(TypePrinter &TP, const ConstantMatrixType *T, raw_ostream &OS) { +static void printClangMatrixBefore(TypePrinter &TP, const ConstantMatrixType *T, + raw_ostream &OS) { TP.printBefore(T->getElementType(), OS); OS << " __attribute__((matrix_type("; printDims(T, 
OS); OS << ")))"; } -void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T, raw_ostream &OS) { +void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T, + raw_ostream &OS) { if (Policy.UseHLSLTypes) { printHLSLMatrixBefore(*this, T, OS); return; @@ -876,7 +879,8 @@ void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T, raw_ost printClangMatrixBefore(*this, T, OS); } -void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T, raw_ostream &OS) { +void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T, + raw_ostream &OS) { if (Policy.UseHLSLTypes) { printHLSLMatrixAfter(T, OS); return; diff --git a/clang/lib/Headers/hlsl/hlsl_basic_types.h b/clang/lib/Headers/hlsl/hlsl_basic_types.h index c750261952d4f..fc1e265067714 100644 --- a/clang/lib/Headers/hlsl/hlsl_basic_types.h +++ b/clang/lib/Headers/hlsl/hlsl_basic_types.h @@ -115,7 +115,7 @@ typedef vector<float64_t, 2> float64_t2; typedef vector<float64_t, 3> float64_t3; typedef vector<float64_t, 4> float64_t4; -ifdef __HLSL_ENABLE_16_BIT +#ifdef __HLSL_ENABLE_16_BIT typedef matrix<int16_t, 1, 1> int16_t1x1; typedef matrix<int16_t, 1, 2> int16_t1x2; typedef matrix<int16_t, 1, 3> int16_t1x3; diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp index 9ac60712b87b8..a30a440d07e6a 100644 --- a/clang/lib/Sema/HLSLExternalSemaSource.cpp +++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp @@ -123,7 +123,7 @@ void HLSLExternalSemaSource::defineHLSLVectorAlias() { void HLSLExternalSemaSource::defineHLSLMatrixAlias() { ASTContext &AST = SemaPtr->getASTContext(); - llvm::SmallVector<NamedDecl *> TemplateParams; + llvm::SmallVector<NamedDecl *> TemplateParams; auto *TypeParam = TemplateTypeParmDecl::Create( AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 0, diff --git a/clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl index 
e9d37ce22cc07..db43a61774760 100644 --- a/clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl +++ b/clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl @@ -1,5 +1,8 @@ - -// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify +// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify uint64_t5x5 mat; // expected-error@-1 {{unknown type name 'uint64_t5x5'}} + +// Note: this one only fails because -fnative-half-type is not set +uint16_t4x4 mat; +// expected-error@-1 {{unknown type name 'uint16_t4x4'}} _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits