https://github.com/farzonl updated 
https://github.com/llvm/llvm-project/pull/159446

>From 26980a21413801f0fd0fbe7ad1643670ed830d1b Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlo...@microsoft.com>
Date: Wed, 17 Sep 2025 12:16:12 -0400
Subject: [PATCH 1/2] [HLSL] Add support for the HLSL matrix type

fixes #109839

This change is really simple. It creates a matrix alias
that will let HLSL use the existing clang `matrix_type` infra.

The only additional change was to add explicit aliases for the
typed dimensions of the 1-4 inclusive matrices available in HLSL.

Testing therefore is limited to exercising the alias.

The main difference in this attempt is the type printer.
---
 clang/include/clang/Driver/Options.td         |   2 +-
 .../clang/Sema/HLSLExternalSemaSource.h       |   1 +
 clang/lib/AST/TypePrinter.cpp                 |  37 ++-
 clang/lib/Headers/hlsl/hlsl_basic_types.h     | 233 ++++++++++++++++++
 clang/lib/Sema/HLSLExternalSemaSource.cpp     |  72 ++++++
 clang/lib/Sema/SemaHLSL.cpp                   |   1 -
 clang/test/AST/HLSL/matrix-alias.hlsl         |  49 ++++
 .../builtins/transpose-builtin.hlsl           |  30 +++
 .../BuiltIns/matrix-basic_types-errors.hlsl   |   5 +
 .../test/SemaHLSL/BuiltIns/matrix-errors.hlsl |  29 +++
 10 files changed, 451 insertions(+), 8 deletions(-)
 create mode 100644 clang/test/AST/HLSL/matrix-alias.hlsl
 create mode 100644 clang/test/CodeGenHLSL/builtins/transpose-builtin.hlsl
 create mode 100644 clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl
 create mode 100644 clang/test/SemaHLSL/BuiltIns/matrix-errors.hlsl

diff --git a/clang/include/clang/Driver/Options.td 
b/clang/include/clang/Driver/Options.td
index a7c514e809aa9..e8869e04390f3 100644
--- a/clang/include/clang/Driver/Options.td
+++ b/clang/include/clang/Driver/Options.td
@@ -4582,7 +4582,7 @@ defm ptrauth_block_descriptor_pointers : 
OptInCC1FFlag<"ptrauth-block-descriptor
 def fenable_matrix : Flag<["-"], "fenable-matrix">, Group<f_Group>,
     Visibility<[ClangOption, CC1Option]>,
     HelpText<"Enable matrix data type and related builtin functions">,
-    MarshallingInfoFlag<LangOpts<"MatrixTypes">>;
+    MarshallingInfoFlag<LangOpts<"MatrixTypes">, hlsl.KeyPath>;
 
 defm raw_string_literals : BoolFOption<"raw-string-literals",
     LangOpts<"RawStringLiterals">, Default<std#".hasRawStringLiterals()">,
diff --git a/clang/include/clang/Sema/HLSLExternalSemaSource.h 
b/clang/include/clang/Sema/HLSLExternalSemaSource.h
index d93fb8c8eef6b..049fc7b8fe3f2 100644
--- a/clang/include/clang/Sema/HLSLExternalSemaSource.h
+++ b/clang/include/clang/Sema/HLSLExternalSemaSource.h
@@ -44,6 +44,7 @@ class HLSLExternalSemaSource : public ExternalSemaSource {
 private:
   void defineTrivialHLSLTypes();
   void defineHLSLVectorAlias();
+  void defineHLSLMatrixAlias();
   void defineHLSLTypesWithForwardDeclarations();
   void onCompletion(CXXRecordDecl *Record, CompletionFunction Fn);
 };
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index cd59678d67f2f..82859c4015772 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -846,16 +846,41 @@ void TypePrinter::printExtVectorAfter(const ExtVectorType 
*T, raw_ostream &OS) {
   }
 }
 
-void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T,
-                                            raw_ostream &OS) {
-  printBefore(T->getElementType(), OS);
-  OS << " __attribute__((matrix_type(";
+static void printDims(const ConstantMatrixType *T, raw_ostream &OS) {
   OS << T->getNumRows() << ", " << T->getNumColumns();
+}
+
+static void printHLSLMatrixBefore(TypePrinter &TP, const ConstantMatrixType 
*T, raw_ostream &OS) {
+  OS << "matrix<";
+  TP.printBefore(T->getElementType(), OS);
+}
+
+static void printHLSLMatrixAfter(const ConstantMatrixType *T, raw_ostream &OS) 
{
+  OS << ", ";
+  printDims(T, OS);
+  OS << ">";
+}
+
+static void printClangMatrixBefore(TypePrinter &TP, const ConstantMatrixType 
*T, raw_ostream &OS) {
+  TP.printBefore(T->getElementType(), OS);
+  OS << " __attribute__((matrix_type(";
+  printDims(T, OS);
   OS << ")))";
 }
 
-void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T,
-                                           raw_ostream &OS) {
+void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T, 
raw_ostream &OS) {
+  if (Policy.UseHLSLTypes) {
+    printHLSLMatrixBefore(*this, T, OS);
+    return;
+  }
+  printClangMatrixBefore(*this, T, OS);
+}
+
+void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T, 
raw_ostream &OS) {
+  if (Policy.UseHLSLTypes) {
+    printHLSLMatrixAfter(T, OS);
+    return;
+  }
   printAfter(T->getElementType(), OS);
 }
 
diff --git a/clang/lib/Headers/hlsl/hlsl_basic_types.h 
b/clang/lib/Headers/hlsl/hlsl_basic_types.h
index eff94e0d7f950..c750261952d4f 100644
--- a/clang/lib/Headers/hlsl/hlsl_basic_types.h
+++ b/clang/lib/Headers/hlsl/hlsl_basic_types.h
@@ -115,6 +115,239 @@ typedef vector<float64_t, 2> float64_t2;
 typedef vector<float64_t, 3> float64_t3;
 typedef vector<float64_t, 4> float64_t4;
 
+ifdef __HLSL_ENABLE_16_BIT
+typedef matrix<int16_t, 1, 1> int16_t1x1;
+typedef matrix<int16_t, 1, 2> int16_t1x2;
+typedef matrix<int16_t, 1, 3> int16_t1x3;
+typedef matrix<int16_t, 1, 4> int16_t1x4;
+typedef matrix<int16_t, 2, 1> int16_t2x1;
+typedef matrix<int16_t, 2, 2> int16_t2x2;
+typedef matrix<int16_t, 2, 3> int16_t2x3;
+typedef matrix<int16_t, 2, 4> int16_t2x4;
+typedef matrix<int16_t, 3, 1> int16_t3x1;
+typedef matrix<int16_t, 3, 2> int16_t3x2;
+typedef matrix<int16_t, 3, 3> int16_t3x3;
+typedef matrix<int16_t, 3, 4> int16_t3x4;
+typedef matrix<int16_t, 4, 1> int16_t4x1;
+typedef matrix<int16_t, 4, 2> int16_t4x2;
+typedef matrix<int16_t, 4, 3> int16_t4x3;
+typedef matrix<int16_t, 4, 4> int16_t4x4;
+typedef matrix<uint16_t, 1, 1> uint16_t1x1;
+typedef matrix<uint16_t, 1, 2> uint16_t1x2;
+typedef matrix<uint16_t, 1, 3> uint16_t1x3;
+typedef matrix<uint16_t, 1, 4> uint16_t1x4;
+typedef matrix<uint16_t, 2, 1> uint16_t2x1;
+typedef matrix<uint16_t, 2, 2> uint16_t2x2;
+typedef matrix<uint16_t, 2, 3> uint16_t2x3;
+typedef matrix<uint16_t, 2, 4> uint16_t2x4;
+typedef matrix<uint16_t, 3, 1> uint16_t3x1;
+typedef matrix<uint16_t, 3, 2> uint16_t3x2;
+typedef matrix<uint16_t, 3, 3> uint16_t3x3;
+typedef matrix<uint16_t, 3, 4> uint16_t3x4;
+typedef matrix<uint16_t, 4, 1> uint16_t4x1;
+typedef matrix<uint16_t, 4, 2> uint16_t4x2;
+typedef matrix<uint16_t, 4, 3> uint16_t4x3;
+typedef matrix<uint16_t, 4, 4> uint16_t4x4;
+#endif
+
+typedef matrix<int, 1, 1> int1x1;
+typedef matrix<int, 1, 2> int1x2;
+typedef matrix<int, 1, 3> int1x3;
+typedef matrix<int, 1, 4> int1x4;
+typedef matrix<int, 2, 1> int2x1;
+typedef matrix<int, 2, 2> int2x2;
+typedef matrix<int, 2, 3> int2x3;
+typedef matrix<int, 2, 4> int2x4;
+typedef matrix<int, 3, 1> int3x1;
+typedef matrix<int, 3, 2> int3x2;
+typedef matrix<int, 3, 3> int3x3;
+typedef matrix<int, 3, 4> int3x4;
+typedef matrix<int, 4, 1> int4x1;
+typedef matrix<int, 4, 2> int4x2;
+typedef matrix<int, 4, 3> int4x3;
+typedef matrix<int, 4, 4> int4x4;
+typedef matrix<uint, 1, 1> uint1x1;
+typedef matrix<uint, 1, 2> uint1x2;
+typedef matrix<uint, 1, 3> uint1x3;
+typedef matrix<uint, 1, 4> uint1x4;
+typedef matrix<uint, 2, 1> uint2x1;
+typedef matrix<uint, 2, 2> uint2x2;
+typedef matrix<uint, 2, 3> uint2x3;
+typedef matrix<uint, 2, 4> uint2x4;
+typedef matrix<uint, 3, 1> uint3x1;
+typedef matrix<uint, 3, 2> uint3x2;
+typedef matrix<uint, 3, 3> uint3x3;
+typedef matrix<uint, 3, 4> uint3x4;
+typedef matrix<uint, 4, 1> uint4x1;
+typedef matrix<uint, 4, 2> uint4x2;
+typedef matrix<uint, 4, 3> uint4x3;
+typedef matrix<uint, 4, 4> uint4x4;
+typedef matrix<int32_t, 1, 1> int32_t1x1;
+typedef matrix<int32_t, 1, 2> int32_t1x2;
+typedef matrix<int32_t, 1, 3> int32_t1x3;
+typedef matrix<int32_t, 1, 4> int32_t1x4;
+typedef matrix<int32_t, 2, 1> int32_t2x1;
+typedef matrix<int32_t, 2, 2> int32_t2x2;
+typedef matrix<int32_t, 2, 3> int32_t2x3;
+typedef matrix<int32_t, 2, 4> int32_t2x4;
+typedef matrix<int32_t, 3, 1> int32_t3x1;
+typedef matrix<int32_t, 3, 2> int32_t3x2;
+typedef matrix<int32_t, 3, 3> int32_t3x3;
+typedef matrix<int32_t, 3, 4> int32_t3x4;
+typedef matrix<int32_t, 4, 1> int32_t4x1;
+typedef matrix<int32_t, 4, 2> int32_t4x2;
+typedef matrix<int32_t, 4, 3> int32_t4x3;
+typedef matrix<int32_t, 4, 4> int32_t4x4;
+typedef matrix<uint32_t, 1, 1> uint32_t1x1;
+typedef matrix<uint32_t, 1, 2> uint32_t1x2;
+typedef matrix<uint32_t, 1, 3> uint32_t1x3;
+typedef matrix<uint32_t, 1, 4> uint32_t1x4;
+typedef matrix<uint32_t, 2, 1> uint32_t2x1;
+typedef matrix<uint32_t, 2, 2> uint32_t2x2;
+typedef matrix<uint32_t, 2, 3> uint32_t2x3;
+typedef matrix<uint32_t, 2, 4> uint32_t2x4;
+typedef matrix<uint32_t, 3, 1> uint32_t3x1;
+typedef matrix<uint32_t, 3, 2> uint32_t3x2;
+typedef matrix<uint32_t, 3, 3> uint32_t3x3;
+typedef matrix<uint32_t, 3, 4> uint32_t3x4;
+typedef matrix<uint32_t, 4, 1> uint32_t4x1;
+typedef matrix<uint32_t, 4, 2> uint32_t4x2;
+typedef matrix<uint32_t, 4, 3> uint32_t4x3;
+typedef matrix<uint32_t, 4, 4> uint32_t4x4;
+typedef matrix<int64_t, 1, 1> int64_t1x1;
+typedef matrix<int64_t, 1, 2> int64_t1x2;
+typedef matrix<int64_t, 1, 3> int64_t1x3;
+typedef matrix<int64_t, 1, 4> int64_t1x4;
+typedef matrix<int64_t, 2, 1> int64_t2x1;
+typedef matrix<int64_t, 2, 2> int64_t2x2;
+typedef matrix<int64_t, 2, 3> int64_t2x3;
+typedef matrix<int64_t, 2, 4> int64_t2x4;
+typedef matrix<int64_t, 3, 1> int64_t3x1;
+typedef matrix<int64_t, 3, 2> int64_t3x2;
+typedef matrix<int64_t, 3, 3> int64_t3x3;
+typedef matrix<int64_t, 3, 4> int64_t3x4;
+typedef matrix<int64_t, 4, 1> int64_t4x1;
+typedef matrix<int64_t, 4, 2> int64_t4x2;
+typedef matrix<int64_t, 4, 3> int64_t4x3;
+typedef matrix<int64_t, 4, 4> int64_t4x4;
+typedef matrix<uint64_t, 1, 1> uint64_t1x1;
+typedef matrix<uint64_t, 1, 2> uint64_t1x2;
+typedef matrix<uint64_t, 1, 3> uint64_t1x3;
+typedef matrix<uint64_t, 1, 4> uint64_t1x4;
+typedef matrix<uint64_t, 2, 1> uint64_t2x1;
+typedef matrix<uint64_t, 2, 2> uint64_t2x2;
+typedef matrix<uint64_t, 2, 3> uint64_t2x3;
+typedef matrix<uint64_t, 2, 4> uint64_t2x4;
+typedef matrix<uint64_t, 3, 1> uint64_t3x1;
+typedef matrix<uint64_t, 3, 2> uint64_t3x2;
+typedef matrix<uint64_t, 3, 3> uint64_t3x3;
+typedef matrix<uint64_t, 3, 4> uint64_t3x4;
+typedef matrix<uint64_t, 4, 1> uint64_t4x1;
+typedef matrix<uint64_t, 4, 2> uint64_t4x2;
+typedef matrix<uint64_t, 4, 3> uint64_t4x3;
+typedef matrix<uint64_t, 4, 4> uint64_t4x4;
+
+typedef matrix<half, 1, 1> half1x1;
+typedef matrix<half, 1, 2> half1x2;
+typedef matrix<half, 1, 3> half1x3;
+typedef matrix<half, 1, 4> half1x4;
+typedef matrix<half, 2, 1> half2x1;
+typedef matrix<half, 2, 2> half2x2;
+typedef matrix<half, 2, 3> half2x3;
+typedef matrix<half, 2, 4> half2x4;
+typedef matrix<half, 3, 1> half3x1;
+typedef matrix<half, 3, 2> half3x2;
+typedef matrix<half, 3, 3> half3x3;
+typedef matrix<half, 3, 4> half3x4;
+typedef matrix<half, 4, 1> half4x1;
+typedef matrix<half, 4, 2> half4x2;
+typedef matrix<half, 4, 3> half4x3;
+typedef matrix<half, 4, 4> half4x4;
+typedef matrix<float, 1, 1> float1x1;
+typedef matrix<float, 1, 2> float1x2;
+typedef matrix<float, 1, 3> float1x3;
+typedef matrix<float, 1, 4> float1x4;
+typedef matrix<float, 2, 1> float2x1;
+typedef matrix<float, 2, 2> float2x2;
+typedef matrix<float, 2, 3> float2x3;
+typedef matrix<float, 2, 4> float2x4;
+typedef matrix<float, 3, 1> float3x1;
+typedef matrix<float, 3, 2> float3x2;
+typedef matrix<float, 3, 3> float3x3;
+typedef matrix<float, 3, 4> float3x4;
+typedef matrix<float, 4, 1> float4x1;
+typedef matrix<float, 4, 2> float4x2;
+typedef matrix<float, 4, 3> float4x3;
+typedef matrix<float, 4, 4> float4x4;
+typedef matrix<double, 1, 1> double1x1;
+typedef matrix<double, 1, 2> double1x2;
+typedef matrix<double, 1, 3> double1x3;
+typedef matrix<double, 1, 4> double1x4;
+typedef matrix<double, 2, 1> double2x1;
+typedef matrix<double, 2, 2> double2x2;
+typedef matrix<double, 2, 3> double2x3;
+typedef matrix<double, 2, 4> double2x4;
+typedef matrix<double, 3, 1> double3x1;
+typedef matrix<double, 3, 2> double3x2;
+typedef matrix<double, 3, 3> double3x3;
+typedef matrix<double, 3, 4> double3x4;
+typedef matrix<double, 4, 1> double4x1;
+typedef matrix<double, 4, 2> double4x2;
+typedef matrix<double, 4, 3> double4x3;
+typedef matrix<double, 4, 4> double4x4;
+
+#ifdef __HLSL_ENABLE_16_BIT
+typedef matrix<float16_t, 1, 1> float16_t1x1;
+typedef matrix<float16_t, 1, 2> float16_t1x2;
+typedef matrix<float16_t, 1, 3> float16_t1x3;
+typedef matrix<float16_t, 1, 4> float16_t1x4;
+typedef matrix<float16_t, 2, 1> float16_t2x1;
+typedef matrix<float16_t, 2, 2> float16_t2x2;
+typedef matrix<float16_t, 2, 3> float16_t2x3;
+typedef matrix<float16_t, 2, 4> float16_t2x4;
+typedef matrix<float16_t, 3, 1> float16_t3x1;
+typedef matrix<float16_t, 3, 2> float16_t3x2;
+typedef matrix<float16_t, 3, 3> float16_t3x3;
+typedef matrix<float16_t, 3, 4> float16_t3x4;
+typedef matrix<float16_t, 4, 1> float16_t4x1;
+typedef matrix<float16_t, 4, 2> float16_t4x2;
+typedef matrix<float16_t, 4, 3> float16_t4x3;
+typedef matrix<float16_t, 4, 4> float16_t4x4;
+#endif
+
+typedef matrix<float32_t, 1, 1> float32_t1x1;
+typedef matrix<float32_t, 1, 2> float32_t1x2;
+typedef matrix<float32_t, 1, 3> float32_t1x3;
+typedef matrix<float32_t, 1, 4> float32_t1x4;
+typedef matrix<float32_t, 2, 1> float32_t2x1;
+typedef matrix<float32_t, 2, 2> float32_t2x2;
+typedef matrix<float32_t, 2, 3> float32_t2x3;
+typedef matrix<float32_t, 2, 4> float32_t2x4;
+typedef matrix<float32_t, 3, 1> float32_t3x1;
+typedef matrix<float32_t, 3, 2> float32_t3x2;
+typedef matrix<float32_t, 3, 3> float32_t3x3;
+typedef matrix<float32_t, 3, 4> float32_t3x4;
+typedef matrix<float32_t, 4, 1> float32_t4x1;
+typedef matrix<float32_t, 4, 2> float32_t4x2;
+typedef matrix<float32_t, 4, 3> float32_t4x3;
+typedef matrix<float32_t, 4, 4> float32_t4x4;
+typedef matrix<float64_t, 1, 1> float64_t1x1;
+typedef matrix<float64_t, 1, 2> float64_t1x2;
+typedef matrix<float64_t, 1, 3> float64_t1x3;
+typedef matrix<float64_t, 1, 4> float64_t1x4;
+typedef matrix<float64_t, 2, 1> float64_t2x1;
+typedef matrix<float64_t, 2, 2> float64_t2x2;
+typedef matrix<float64_t, 2, 3> float64_t2x3;
+typedef matrix<float64_t, 2, 4> float64_t2x4;
+typedef matrix<float64_t, 3, 1> float64_t3x1;
+typedef matrix<float64_t, 3, 2> float64_t3x2;
+typedef matrix<float64_t, 3, 3> float64_t3x3;
+typedef matrix<float64_t, 3, 4> float64_t3x4;
+typedef matrix<float64_t, 4, 1> float64_t4x1;
+typedef matrix<float64_t, 4, 2> float64_t4x2;
+typedef matrix<float64_t, 4, 3> float64_t4x3;
+typedef matrix<float64_t, 4, 4> float64_t4x4;
+
 } // namespace hlsl
 
 #endif //_HLSL_HLSL_BASIC_TYPES_H_
diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp 
b/clang/lib/Sema/HLSLExternalSemaSource.cpp
index 3386d8da281e9..9ac60712b87b8 100644
--- a/clang/lib/Sema/HLSLExternalSemaSource.cpp
+++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -121,8 +121,80 @@ void HLSLExternalSemaSource::defineHLSLVectorAlias() {
   HLSLNamespace->addDecl(Template);
 }
 
+void HLSLExternalSemaSource::defineHLSLMatrixAlias() {
+  ASTContext &AST = SemaPtr->getASTContext();
+ llvm::SmallVector<NamedDecl *> TemplateParams;
+
+  auto *TypeParam = TemplateTypeParmDecl::Create(
+      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 0,
+      &AST.Idents.get("element", tok::TokenKind::identifier), false, false);
+  TypeParam->setDefaultArgument(
+      AST, SemaPtr->getTrivialTemplateArgumentLoc(
+               TemplateArgument(AST.FloatTy), QualType(), SourceLocation()));
+
+  TemplateParams.emplace_back(TypeParam);
+
+  // these should be 64 bit to be consistent with other clang matrices.
+  auto *RowsParam = NonTypeTemplateParmDecl::Create(
+      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 1,
+      &AST.Idents.get("rows_count", tok::TokenKind::identifier), AST.IntTy,
+      false, AST.getTrivialTypeSourceInfo(AST.IntTy));
+  llvm::APInt RVal(AST.getIntWidth(AST.IntTy), 4);
+  TemplateArgument RDefault(AST, llvm::APSInt(std::move(RVal)), AST.IntTy,
+                            /*IsDefaulted=*/true);
+  RowsParam->setDefaultArgument(
+      AST, SemaPtr->getTrivialTemplateArgumentLoc(RDefault, AST.IntTy,
+                                                  SourceLocation(), 
RowsParam));
+  TemplateParams.emplace_back(RowsParam);
+
+  auto *ColsParam = NonTypeTemplateParmDecl::Create(
+      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 2,
+      &AST.Idents.get("cols_count", tok::TokenKind::identifier), AST.IntTy,
+      false, AST.getTrivialTypeSourceInfo(AST.IntTy));
+  llvm::APInt CVal(AST.getIntWidth(AST.IntTy), 4);
+  TemplateArgument CDefault(AST, llvm::APSInt(std::move(CVal)), AST.IntTy,
+                            /*IsDefaulted=*/true);
+  ColsParam->setDefaultArgument(
+      AST, SemaPtr->getTrivialTemplateArgumentLoc(CDefault, AST.IntTy,
+                                                  SourceLocation(), 
ColsParam));
+  TemplateParams.emplace_back(ColsParam);
+
+  auto *ParamList =
+      TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(),
+                                    TemplateParams, SourceLocation(), nullptr);
+
+  IdentifierInfo &II = AST.Idents.get("matrix", tok::TokenKind::identifier);
+
+  QualType AliasType = AST.getDependentSizedMatrixType(
+      AST.getTemplateTypeParmType(0, 0, false, TypeParam),
+      DeclRefExpr::Create(
+          AST, NestedNameSpecifierLoc(), SourceLocation(), RowsParam, false,
+          DeclarationNameInfo(RowsParam->getDeclName(), SourceLocation()),
+          AST.IntTy, VK_LValue),
+      DeclRefExpr::Create(
+          AST, NestedNameSpecifierLoc(), SourceLocation(), ColsParam, false,
+          DeclarationNameInfo(ColsParam->getDeclName(), SourceLocation()),
+          AST.IntTy, VK_LValue),
+      SourceLocation());
+
+  auto *Record = TypeAliasDecl::Create(AST, HLSLNamespace, SourceLocation(),
+                                       SourceLocation(), &II,
+                                       
AST.getTrivialTypeSourceInfo(AliasType));
+  Record->setImplicit(true);
+
+  auto *Template =
+      TypeAliasTemplateDecl::Create(AST, HLSLNamespace, SourceLocation(),
+                                    Record->getIdentifier(), ParamList, 
Record);
+
+  Record->setDescribedAliasTemplate(Template);
+  Template->setImplicit(true);
+  Template->setLexicalDeclContext(Record->getDeclContext());
+  HLSLNamespace->addDecl(Template);
+}
+
 void HLSLExternalSemaSource::defineTrivialHLSLTypes() {
   defineHLSLVectorAlias();
+  defineHLSLMatrixAlias();
 }
 
 /// Set up common members and attributes for buffer types
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 0af38472b0fec..38f81a24945c5 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3285,7 +3285,6 @@ static void BuildFlattenedTypeList(QualType BaseTy,
   while (!WorkList.empty()) {
     QualType T = WorkList.pop_back_val();
     T = T.getCanonicalType().getUnqualifiedType();
-    assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL");
     if (const auto *AT = dyn_cast<ConstantArrayType>(T)) {
       llvm::SmallVector<QualType, 16> ElementFields;
       // Generally I've avoided recursion in this algorithm, but arrays of
diff --git a/clang/test/AST/HLSL/matrix-alias.hlsl 
b/clang/test/AST/HLSL/matrix-alias.hlsl
new file mode 100644
index 0000000000000..2758b6f0d202f
--- /dev/null
+++ b/clang/test/AST/HLSL/matrix-alias.hlsl
@@ -0,0 +1,49 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -ast-dump -o - %s | 
FileCheck %s
+
+// Test that matrix aliases are set up properly for HLSL
+
+// CHECK: NamespaceDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> 
implicit hlsl
+// CHECK-NEXT: TypeAliasTemplateDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 
<invalid sloc> implicit vector
+// CHECK-NEXT: TemplateTypeParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 
<invalid sloc> class depth 0 index 0 element
+// CHECK-NEXT: TemplateArgument type 'float'
+// CHECK-NEXT: BuiltinType 0x{{[0-9a-fA-F]+}} 'float'
+// CHECK-NEXT: NonTypeTemplateParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 
<invalid sloc> 'int' depth 0 index 1 element_count
+// CHECK-NEXT: TemplateArgument expr
+// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' 4
+// CHECK-NEXT: TypeAliasDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid 
sloc> implicit vector 'vector<element, element_count>'
+// CHECK-NEXT: DependentSizedExtVectorType 0x{{[0-9a-fA-F]+}} 'vector<element, 
element_count>' dependent <invalid sloc>
+// CHECK-NEXT: TemplateTypeParmType 0x{{[0-9a-fA-F]+}} 'element' dependent 
depth 0 index 0
+// CHECK-NEXT: TemplateTypeParm 0x{{[0-9a-fA-F]+}} 'element'
+// CHECK-NEXT: DeclRefExpr 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' lvalue
+// CHECK-SAME: NonTypeTemplateParm 0x{{[0-9a-fA-F]+}} 'element_count' 'int'
+
+// Make sure we got a using directive at the end.
+// CHECK: UsingDirectiveDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid 
sloc> Namespace 0x{{[0-9a-fA-F]+}} 'hlsl'
+
+[numthreads(1,1,1)]
+int entry() {
+  // Verify that the alias is generated inside the hlsl namespace.
+  hlsl::matrix<float, 2, 2> Mat2x2f;
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:26:3, col:36>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:29> col:29 Mat2x2f 
'hlsl::matrix<float, 2, 2>'
+
+  // Verify that you don't need to specify the namespace.
+  matrix<int, 2, 2> Mat2x2i;
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:32:3, col:28>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:21> col:21 Mat2x2i 
'matrix<int, 2, 2>'
+
+  // Build a bigger matrix.
+  matrix<double, 4, 4> Mat4x4d;
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:38:3, col:31>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:24> col:24 Mat4x4d 
'matrix<double, 4, 4>'
+
+  // Verify that the implicit arguments generate the correct type.
+  matrix<> ImpMat4x4;
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:44:3, col:21>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:12> col:12 ImpMat4x4 
'matrix<>':'matrix<float, 4, 4>'
+  return 1;
+}
diff --git a/clang/test/CodeGenHLSL/builtins/transpose-builtin.hlsl 
b/clang/test/CodeGenHLSL/builtins/transpose-builtin.hlsl
new file mode 100644
index 0000000000000..86aa7cd6985dd
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/transpose-builtin.hlsl
@@ -0,0 +1,30 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple 
dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm 
-disable-llvm-passes -o - | FileCheck %s
+
+// NOTE: This test is only to confirm we can do codegen with the matrix alias.
+
+// CHECK-LABEL: define {{.*}}transpose_half_2x2
+void transpose_half_2x2(half2x2 a) {
+  // CHECK:        [[A:%.*]] = load <4 x half>, ptr {{.*}}, align 2
+  // CHECK-NEXT:   [[TRANS:%.*]] = call {{.*}}<4 x half> 
@llvm.matrix.transpose.v4f16(<4 x half> [[A]], i32 2, i32 2)
+  // CHECK-NEXT:   store <4 x half> [[TRANS]], ptr %a_t, align 2
+
+  half2x2 a_t = __builtin_matrix_transpose(a);
+}
+
+// CHECK-LABEL: define {{.*}}transpose_float_3x2
+void transpose_float_3x2(float3x2 a) {
+  // CHECK:        [[A:%.*]] = load <6 x float>, ptr {{.*}}, align 4
+  // CHECK-NEXT:   [[TRANS:%.*]] = call {{.*}}<6 x float> 
@llvm.matrix.transpose.v6f32(<6 x float> [[A]], i32 3, i32 2)
+  // CHECK-NEXT:   store <6 x float> [[TRANS]], ptr %a_t, align 4
+
+  float2x3 a_t = __builtin_matrix_transpose(a);
+}
+
+// CHECK-LABEL: define {{.*}}transpose_int_4x3
+void transpose_int_4x3(int4x3 a) {
+  // CHECK:         [[A:%.*]] = load <12 x i32>, ptr {{.*}}, align 4
+  // CHECK-NEXT:    [[TRANS:%.*]] = call <12 x i32> 
@llvm.matrix.transpose.v12i32(<12 x i32> [[A]], i32 4, i32 3)
+  // CHECK-NEXT:    store <12 x i32> [[TRANS]], ptr %a_t, align 4
+
+  int3x4 a_t = __builtin_matrix_transpose(a);
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl 
b/clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl
new file mode 100644
index 0000000000000..e9d37ce22cc07
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl
@@ -0,0 +1,5 @@
+
+// RUN: %clang_cc1 -finclude-default-header -triple 
dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only 
-disable-llvm-passes -verify
+
+uint64_t5x5 mat;
+// expected-error@-1  {{unknown type name 'uint64_t5x5'}}
diff --git a/clang/test/SemaHLSL/BuiltIns/matrix-errors.hlsl 
b/clang/test/SemaHLSL/BuiltIns/matrix-errors.hlsl
new file mode 100644
index 0000000000000..9a820c2205843
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/matrix-errors.hlsl
@@ -0,0 +1,29 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl 
-fsyntax-only -verify %s
+
+// Some bad declarations
+hlsl::matrix ShouldWorkSomeday; // expected-error{{use of alias template 
'hlsl::matrix' requires template arguments}}
+// expected-note@*:* {{template declaration from hidden source: template 
<class element = float, int rows_count = 4, int cols_count = 4> using matrix = 
element __attribute__((matrix_type(rows_count, cols_count)))}}
+
+hlsl::matrix<1,1,1> BadMat; // expected-error{{template argument for template 
type parameter must be a type}}
+// expected-note@*:* {{template parameter from hidden source: class element = 
float}}
+
+hlsl::matrix<int, float,4> AnotherBadMat; // expected-error{{template argument 
for non-type template parameter must be an expression}}
+// expected-note@*:* {{template parameter from hidden source: int rows_count = 
4}}
+
+hlsl::matrix<int, 2, 3, 2> YABV; // expected-error{{too many template 
arguments for alias template 'matrix'}}
+// expected-note@*:* {{template declaration from hidden source: template 
<class element = float, int rows_count = 4, int cols_count = 4> using matrix = 
element __attribute__((matrix_type(rows_count, cols_count)))}}
+
+// This code is rejected by clang because clang puts the HLSL built-in types
+// into the HLSL namespace.
+namespace hlsl {
+  struct matrix {}; // expected-error {{redefinition of 'matrix'}}
+}
+
+// This code is rejected by dxc because dxc puts the HLSL built-in types
+// into the global space, but clang will allow it even though it will shadow 
the
+// matrix template.
+struct matrix {}; // expected-note {{candidate found by name lookup is 
'matrix'}}
+
+matrix<int,2,2> matInt2x2; // expected-error {{reference to 'matrix' is 
ambiguous}}
+
+// expected-note@*:* {{candidate found by name lookup is 'hlsl::matrix'}}

>From 9e18ebbd44be81798e875159e7800b772ba58bde Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlo...@microsoft.com>
Date: Wed, 17 Sep 2025 16:26:43 -0400
Subject: [PATCH 2/2] fix formatting

---
 clang/lib/AST/TypePrinter.cpp                        | 12 ++++++++----
 clang/lib/Headers/hlsl/hlsl_basic_types.h            |  2 +-
 clang/lib/Sema/HLSLExternalSemaSource.cpp            |  2 +-
 .../SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl |  7 +++++--
 4 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index 82859c4015772..f3448af5f8f50 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -850,7 +850,8 @@ static void printDims(const ConstantMatrixType *T, 
raw_ostream &OS) {
   OS << T->getNumRows() << ", " << T->getNumColumns();
 }
 
-static void printHLSLMatrixBefore(TypePrinter &TP, const ConstantMatrixType 
*T, raw_ostream &OS) {
+static void printHLSLMatrixBefore(TypePrinter &TP, const ConstantMatrixType *T,
+                                  raw_ostream &OS) {
   OS << "matrix<";
   TP.printBefore(T->getElementType(), OS);
 }
@@ -861,14 +862,16 @@ static void printHLSLMatrixAfter(const ConstantMatrixType 
*T, raw_ostream &OS) {
   OS << ">";
 }
 
-static void printClangMatrixBefore(TypePrinter &TP, const ConstantMatrixType 
*T, raw_ostream &OS) {
+static void printClangMatrixBefore(TypePrinter &TP, const ConstantMatrixType 
*T,
+                                   raw_ostream &OS) {
   TP.printBefore(T->getElementType(), OS);
   OS << " __attribute__((matrix_type(";
   printDims(T, OS);
   OS << ")))";
 }
 
-void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T, 
raw_ostream &OS) {
+void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T,
+                                            raw_ostream &OS) {
   if (Policy.UseHLSLTypes) {
     printHLSLMatrixBefore(*this, T, OS);
     return;
@@ -876,7 +879,8 @@ void TypePrinter::printConstantMatrixBefore(const 
ConstantMatrixType *T, raw_ost
   printClangMatrixBefore(*this, T, OS);
 }
 
-void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T, 
raw_ostream &OS) {
+void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T,
+                                           raw_ostream &OS) {
   if (Policy.UseHLSLTypes) {
     printHLSLMatrixAfter(T, OS);
     return;
diff --git a/clang/lib/Headers/hlsl/hlsl_basic_types.h 
b/clang/lib/Headers/hlsl/hlsl_basic_types.h
index c750261952d4f..fc1e265067714 100644
--- a/clang/lib/Headers/hlsl/hlsl_basic_types.h
+++ b/clang/lib/Headers/hlsl/hlsl_basic_types.h
@@ -115,7 +115,7 @@ typedef vector<float64_t, 2> float64_t2;
 typedef vector<float64_t, 3> float64_t3;
 typedef vector<float64_t, 4> float64_t4;
 
-ifdef __HLSL_ENABLE_16_BIT
+#ifdef __HLSL_ENABLE_16_BIT
 typedef matrix<int16_t, 1, 1> int16_t1x1;
 typedef matrix<int16_t, 1, 2> int16_t1x2;
 typedef matrix<int16_t, 1, 3> int16_t1x3;
diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp 
b/clang/lib/Sema/HLSLExternalSemaSource.cpp
index 9ac60712b87b8..a30a440d07e6a 100644
--- a/clang/lib/Sema/HLSLExternalSemaSource.cpp
+++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -123,7 +123,7 @@ void HLSLExternalSemaSource::defineHLSLVectorAlias() {
 
 void HLSLExternalSemaSource::defineHLSLMatrixAlias() {
   ASTContext &AST = SemaPtr->getASTContext();
- llvm::SmallVector<NamedDecl *> TemplateParams;
+  llvm::SmallVector<NamedDecl *> TemplateParams;
 
   auto *TypeParam = TemplateTypeParmDecl::Create(
       AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 0,
diff --git a/clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl 
b/clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl
index e9d37ce22cc07..db43a61774760 100644
--- a/clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl
@@ -1,5 +1,8 @@
-
-// RUN: %clang_cc1 -finclude-default-header -triple 
dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only 
-disable-llvm-passes -verify
+// RUN: %clang_cc1 -finclude-default-header -triple 
dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
 
 uint64_t5x5 mat;
 // expected-error@-1  {{unknown type name 'uint64_t5x5'}}
+
+// Note: this one only fails because -fnative-half-type is not set
+uint16_t4x4 mat;
+// expected-error@-1  {{unknown type name 'uint16_t4x4'}}

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to