fhahn updated this revision to Diff 262413.
fhahn marked 2 inline comments as done.
fhahn added a comment.

Fix template deduction.

In D72281#2022004 <https://reviews.llvm.org/D72281#2022004>, @rjmccall wrote:

> > The type of m1 matches the matrix argument in all 3 definitions of 
> > use_matrix and for some reason the return type is not used to disambiguate 
> > the definitions:
>
> C++ does not have the ability to use return types to disambiguate function 
> calls.
>
> > I am not sure where things are going wrong unfortunately. The matrix 
> > argument deduction should mirror the code for types like 
> > DependentSizedArrayType. Do you have any idea what could be missing?
>
> This is what function template partial ordering is supposed to do.  You'll 
> see that it works if you replace the definition of `matrix` to use array 
> types:


Ah right, thanks for clarifying. I think I managed to fix the remaining 
problems. Previously the patch did not handle DependentSizedMatrixTypes with 
non-dependent constant dimensions properly (e.g. a DependentSizedMatrixType 
could have one dependent and one non-dependent dimension). This differs from
DependentSizedArrayType. Now the example works as expected (I've extended it
a bit). Cases like the one below are rejected as ambiguous:

  template <class T, unsigned long R, unsigned long C>
  using matrix = T __attribute__((matrix_type(R, C)));
  
  template <class T, unsigned long R>
  void use_matrix(matrix<T, R, 10> &m) {}
  // expected-note@-1 {{candidate function [with T = float, R = 10]}}
  
  template <class T, unsigned long C>
  void use_matrix(matrix<T, 10, C> &m) {}
  // expected-note@-1 {{candidate function [with T = float, C = 10]}}
  
  void test_ambigous_deduction1() {
    matrix<float, 10, 10> m;
    use_matrix(m);
    // expected-error@-1 {{call to 'use_matrix' is ambiguous}}
  }


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D72281/new/

https://reviews.llvm.org/D72281

Files:
  clang/include/clang/AST/ASTContext.h
  clang/include/clang/AST/RecursiveASTVisitor.h
  clang/include/clang/AST/Type.h
  clang/include/clang/AST/TypeLoc.h
  clang/include/clang/AST/TypeProperties.td
  clang/include/clang/Basic/Attr.td
  clang/include/clang/Basic/DiagnosticSemaKinds.td
  clang/include/clang/Basic/Features.def
  clang/include/clang/Basic/LangOptions.def
  clang/include/clang/Basic/TypeNodes.td
  clang/include/clang/Driver/Options.td
  clang/include/clang/Sema/Sema.h
  clang/include/clang/Serialization/TypeBitCodes.def
  clang/lib/AST/ASTContext.cpp
  clang/lib/AST/ASTStructuralEquivalence.cpp
  clang/lib/AST/ExprConstant.cpp
  clang/lib/AST/ItaniumMangle.cpp
  clang/lib/AST/MicrosoftMangle.cpp
  clang/lib/AST/Type.cpp
  clang/lib/AST/TypePrinter.cpp
  clang/lib/CodeGen/CGDebugInfo.cpp
  clang/lib/CodeGen/CGDebugInfo.h
  clang/lib/CodeGen/CGExpr.cpp
  clang/lib/CodeGen/CodeGenFunction.cpp
  clang/lib/CodeGen/CodeGenTypes.cpp
  clang/lib/CodeGen/ItaniumCXXABI.cpp
  clang/lib/Driver/ToolChains/Clang.cpp
  clang/lib/Frontend/CompilerInvocation.cpp
  clang/lib/Sema/SemaExpr.cpp
  clang/lib/Sema/SemaLookup.cpp
  clang/lib/Sema/SemaTemplate.cpp
  clang/lib/Sema/SemaTemplateDeduction.cpp
  clang/lib/Sema/SemaType.cpp
  clang/lib/Sema/TreeTransform.h
  clang/lib/Serialization/ASTReader.cpp
  clang/lib/Serialization/ASTWriter.cpp
  clang/test/CodeGen/debug-info-matrix-types.c
  clang/test/CodeGen/matrix-type.c
  clang/test/CodeGenCXX/matrix-type.cpp
  clang/test/Parser/matrix-type-disabled.c
  clang/test/SemaCXX/matrix-type.cpp
  clang/tools/libclang/CIndex.cpp

Index: clang/tools/libclang/CIndex.cpp
===================================================================
--- clang/tools/libclang/CIndex.cpp
+++ clang/tools/libclang/CIndex.cpp
@@ -1795,6 +1795,8 @@
 DEFAULT_TYPELOC_IMPL(DependentSizedExtVector, Type)
 DEFAULT_TYPELOC_IMPL(Vector, Type)
 DEFAULT_TYPELOC_IMPL(ExtVector, VectorType)
+DEFAULT_TYPELOC_IMPL(ConstantMatrix, MatrixType)
+DEFAULT_TYPELOC_IMPL(DependentSizedMatrix, MatrixType)
 DEFAULT_TYPELOC_IMPL(FunctionProto, FunctionType)
 DEFAULT_TYPELOC_IMPL(FunctionNoProto, FunctionType)
 DEFAULT_TYPELOC_IMPL(Record, TagType)
Index: clang/test/SemaCXX/matrix-type.cpp
===================================================================
--- /dev/null
+++ clang/test/SemaCXX/matrix-type.cpp
@@ -0,0 +1,98 @@
+// RUN: %clang_cc1 -fsyntax-only -pedantic -fenable-matrix -std=c++11 -verify -triple x86_64-apple-darwin %s
+
+using matrix_double_t = double __attribute__((matrix_type(6, 6)));
+using matrix_float_t = float __attribute__((matrix_type(6, 6)));
+using matrix_int_t = int __attribute__((matrix_type(6, 6)));
+
+void matrix_var_dimensions(int Rows, unsigned Columns, char C) {
+  using matrix1_t = int __attribute__((matrix_type(Rows, 1)));    // expected-error{{matrix_type attribute requires an integer constant}}
+  using matrix2_t = int __attribute__((matrix_type(1, Columns))); // expected-error{{matrix_type attribute requires an integer constant}}
+  using matrix3_t = int __attribute__((matrix_type(C, C)));       // expected-error{{matrix_type attribute requires an integer constant}}
+  using matrix4_t = int __attribute__((matrix_type(-1, 1)));      // expected-error{{matrix row size too large}}
+  using matrix5_t = int __attribute__((matrix_type(1, -1)));      // expected-error{{matrix column size too large}}
+  using matrix6_t = int __attribute__((matrix_type(0, 1)));       // expected-error{{zero matrix size}}
+  using matrix7_t = int __attribute__((matrix_type(1, 0)));       // expected-error{{zero matrix size}}
+  using matrix7_t = int __attribute__((matrix_type(char, 0)));    // expected-error{{expected '(' for function-style cast or type construction}}
+  using matrix8_t = int __attribute__((matrix_type(1048576, 1))); // expected-error{{matrix row size too large}}
+}
+
+struct S1 {};
+
+enum TestEnum {
+  A,
+  B
+};
+
+void matrix_unsupported_element_type() {
+  using matrix1_t = char *__attribute__((matrix_type(1, 1)));    // expected-error{{invalid matrix element type 'char *'}}
+  using matrix2_t = S1 __attribute__((matrix_type(1, 1)));       // expected-error{{invalid matrix element type 'S1'}}
+  using matrix3_t = bool __attribute__((matrix_type(1, 1)));     // expected-error{{invalid matrix element type 'bool'}}
+  using matrix4_t = TestEnum __attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'TestEnum'}}
+}
+
+template <typename T> // expected-note{{declared here}}
+void matrix_template_1() {
+  using matrix1_t = float __attribute__((matrix_type(T, T))); // expected-error{{'T' does not refer to a value}}
+}
+
+template <class C> // expected-note{{declared here}}
+void matrix_template_2() {
+  using matrix1_t = float __attribute__((matrix_type(C, C))); // expected-error{{'C' does not refer to a value}}
+}
+
+template <unsigned Rows, unsigned Cols>
+void matrix_template_3() {
+  using matrix1_t = float __attribute__((matrix_type(Rows, Cols))); // expected-error{{zero matrix size}}
+}
+
+void instantiate_template_3() {
+  matrix_template_3<1, 10>();
+  matrix_template_3<0, 10>(); // expected-note{{in instantiation of function template specialization 'matrix_template_3<0, 10>' requested here}}
+}
+
+template <int Rows, unsigned Cols>
+void matrix_template_4() {
+  using matrix1_t = float __attribute__((matrix_type(Rows, Cols))); // expected-error{{matrix row size too large}}
+}
+
+void instantiate_template_4() {
+  matrix_template_4<2, 10>();
+  matrix_template_4<-3, 10>(); // expected-note{{in instantiation of function template specialization 'matrix_template_4<-3, 10>' requested here}}
+}
+
+template <class T, unsigned long R, unsigned long C>
+using matrix = T __attribute__((matrix_type(R, C)));
+
+template <class T, unsigned long R>
+void use_matrix(matrix<T, R, 10> &m) {}
+// expected-note@-1 {{candidate function [with T = float, R = 10]}}
+
+template <class T, unsigned long C>
+void use_matrix(matrix<T, 10, C> &m) {}
+// expected-note@-1 {{candidate function [with T = float, C = 10]}}
+
+void test_ambigous_deduction1() {
+  matrix<float, 10, 10> m;
+  use_matrix(m);
+  // expected-error@-1 {{call to 'use_matrix' is ambiguous}}
+}
+
+template <class T, long R>
+void invalid_template_arg_ty(matrix<T, R, 10> &m) {}
+// expected-note@-1 {{candidate template ignored: substitution failure [with T = float]: deduced non-type template argument does not have the same type as the corresponding template parameter ('unsigned long' vs 'long')}}
+
+void test_invalid_template_arg_ty() {
+  matrix<float, 10, 10> m;
+  invalid_template_arg_ty(m);
+  // expected-error@-1 {{no matching function for call to 'invalid_template_arg_ty'}}
+}
+
+template <class T, long R>
+void type_conflict(matrix<T, R, 10> &m, T x) {}
+// expected-note@-1 {{candidate template ignored: deduced conflicting types for parameter 'T' ('float' vs. 'char *')}}
+
+void test_type_conflict(char *p) {
+  matrix<float, 10, 10> m;
+  type_conflict(m, p);
+  // expected-error@-1 {{no matching function for call to 'type_conflict'}}
+}
Index: clang/test/Parser/matrix-type-disabled.c
===================================================================
--- /dev/null
+++ clang/test/Parser/matrix-type-disabled.c
@@ -0,0 +1,14 @@
+// RUN: %clang_cc1 %s -triple i686-apple-darwin -verify -fsyntax-only
+
+// Matrix types are disabled by default.
+
+#if __has_extension(matrix_types)
+#error Expected extension 'matrix_types' to be disabled
+#endif
+
+typedef double dx5x5_t __attribute__((matrix_type(5, 5)));
+// expected-error@-1 {{matrix types extension is disabled. Pass -fenable-matrix to enable it}}
+
+void load_store_double(dx5x5_t *a, dx5x5_t *b) {
+  *a = *b;
+}
Index: clang/test/CodeGenCXX/matrix-type.cpp
===================================================================
--- /dev/null
+++ clang/test/CodeGenCXX/matrix-type.cpp
@@ -0,0 +1,291 @@
+// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - -std=c++17 | FileCheck %s
+
+typedef double dx5x5_t __attribute__((matrix_type(5, 5)));
+typedef float fx3x4_t __attribute__((matrix_type(3, 4)));
+
+// CHECK: %struct.Matrix = type { i8, [12 x float], float }
+
+void load_store(dx5x5_t *a, dx5x5_t *b) {
+  // CHECK-LABEL:  define void @_Z10load_storePU11matrix_typeLm5ELm5EdS0_(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    %b.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    store [25 x double]* %a, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    store [25 x double]* %b, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load [25 x double]*, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %0 to <25 x double>*
+  // CHECK-NEXT:    %2 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %3 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    %4 = bitcast [25 x double]* %3 to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %2, <25 x double>* %4, align 8
+  // CHECK-NEXT:   ret void
+
+  *a = *b;
+}
+
+typedef float fx3x3_t __attribute__((matrix_type(3, 3)));
+
+void parameter_passing(fx3x3_t a, fx3x3_t *b) {
+  // CHECK-LABEL: define void @_Z17parameter_passingU11matrix_typeLm3ELm3EfPS_(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [9 x float], align 4
+  // CHECK-NEXT:    %b.addr = alloca [9 x float]*, align 8
+  // CHECK-NEXT:    %0 = bitcast [9 x float]* %a.addr to <9 x float>*
+  // CHECK-NEXT:    store <9 x float> %a, <9 x float>* %0, align 4
+  // CHECK-NEXT:    store [9 x float]* %b, [9 x float]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = load <9 x float>, <9 x float>* %0, align 4
+  // CHECK-NEXT:    %2 = load [9 x float]*, [9 x float]** %b.addr, align 8
+  // CHECK-NEXT:    %3 = bitcast [9 x float]* %2 to <9 x float>*
+  // CHECK-NEXT:    store <9 x float> %1, <9 x float>* %3, align 4
+  // CHECK-NEXT:    ret void
+  *b = a;
+}
+
+fx3x3_t return_matrix(fx3x3_t *a) {
+  // CHECK-LABEL: define <9 x float> @_Z13return_matrixPU11matrix_typeLm3ELm3Ef(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [9 x float]*, align 8
+  // CHECK-NEXT:    store [9 x float]* %a, [9 x float]** %a.addr, align 8
+  // CHECK-NEXT:    %0 = load [9 x float]*, [9 x float]** %a.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [9 x float]* %0 to <9 x float>*
+  // CHECK-NEXT:    %2 = load <9 x float>, <9 x float>* %1, align 4
+  // CHECK-NEXT:    ret <9 x float> %2
+  return *a;
+}
+
+struct Matrix {
+  char Tmp1;
+  fx3x4_t Data;
+  float Tmp2;
+};
+
+void matrix_struct_pointers(Matrix *a, Matrix *b) {
+  // CHECK-LABEL: define void @_Z22matrix_struct_pointersP6MatrixS0_(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    %b.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1
+  // CHECK-NEXT:    %1 = bitcast [12 x float]* %Data to <12 x float>*
+  // CHECK-NEXT:    %2 = load <12 x float>, <12 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+  // CHECK-NEXT:    store <12 x float> %2, <12 x float>* %4, align 4
+  // CHECK-NEXT:    ret void
+  b->Data = a->Data;
+}
+
+void matrix_struct_reference(Matrix &a, Matrix &b) {
+  // CHECK-LABEL: define void @_Z23matrix_struct_referenceR6MatrixS0_(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    %b.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1
+  // CHECK-NEXT:    %1 = bitcast [12 x float]* %Data to <12 x float>*
+  // CHECK-NEXT:    %2 = load <12 x float>, <12 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+  // CHECK-NEXT:    store <12 x float> %2, <12 x float>* %4, align 4
+  // CHECK-NEXT:    ret void
+  b.Data = a.Data;
+}
+
+class MatrixClass {
+public:
+  int Tmp1;
+  fx3x4_t Data;
+  long Tmp2;
+};
+
+void matrix_class_reference(MatrixClass &a, MatrixClass &b) {
+  // CHECK-LABEL: define void @_Z22matrix_class_referenceR11MatrixClassS0_(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca %class.MatrixClass*, align 8
+  // CHECK-NEXT:    %b.addr = alloca %class.MatrixClass*, align 8
+  // CHECK-NEXT:    store %class.MatrixClass* %a, %class.MatrixClass** %a.addr, align 8
+  // CHECK-NEXT:    store %class.MatrixClass* %b, %class.MatrixClass** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load %class.MatrixClass*, %class.MatrixClass** %a.addr, align 8
+  // CHECK-NEXT:    %Data = getelementptr inbounds %class.MatrixClass, %class.MatrixClass* %0, i32 0, i32 1
+  // CHECK-NEXT:    %1 = bitcast [12 x float]* %Data to <12 x float>*
+  // CHECK-NEXT:    %2 = load <12 x float>, <12 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %class.MatrixClass*, %class.MatrixClass** %b.addr, align 8
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %class.MatrixClass, %class.MatrixClass* %3, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+  // CHECK-NEXT:    store <12 x float> %2, <12 x float>* %4, align 4
+  // CHECK-NEXT:    ret void
+  b.Data = a.Data;
+}
+
+template <typename Ty, unsigned Rows, unsigned Cols>
+class MatrixClassTemplate {
+public:
+  using MatrixTy = Ty __attribute__((matrix_type(Rows, Cols)));
+  int Tmp1;
+  MatrixTy Data;
+  long Tmp2;
+};
+
+template <typename Ty, unsigned Rows, unsigned Cols>
+void matrix_template_reference(MatrixClassTemplate<Ty, Rows, Cols> &a, MatrixClassTemplate<Ty, Rows, Cols> &b) {
+  b.Data = a.Data;
+}
+
+MatrixClassTemplate<float, 10, 15> matrix_template_reference_caller(float *Data) {
+  // CHECK-LABEL: define void @_Z32matrix_template_reference_callerPf(%class.MatrixClassTemplate* noalias sret align 8 %agg.result, float* %Data
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %Data.addr = alloca float*, align 8
+  // CHECK-NEXT:    %Arg = alloca %class.MatrixClassTemplate, align 8
+  // CHECK-NEXT:    store float* %Data, float** %Data.addr, align 8
+  // CHECK-NEXT:    %0 = load float*, float** %Data.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast float* %0 to [150 x float]*
+  // CHECK-NEXT:    %2 = bitcast [150 x float]* %1 to <150 x float>*
+  // CHECK-NEXT:    %3 = load <150 x float>, <150 x float>* %2, align 4
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %Arg, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [150 x float]* %Data1 to <150 x float>*
+  // CHECK-NEXT:    store <150 x float> %3, <150 x float>* %4, align 4
+  // CHECK-NEXT:    call void @_Z25matrix_template_referenceIfLj10ELj15EEvR19MatrixClassTemplateIT_XT0_EXT1_EES3_(%class.MatrixClassTemplate* dereferenceable(616) %Arg, %class.MatrixClassTemplate* dereferenceable(616) %agg.result)
+  // CHECK-NEXT:    ret void
+
+  // CHECK-LABEL: define linkonce_odr void @_Z25matrix_template_referenceIfLj10ELj15EEvR19MatrixClassTemplateIT_XT0_EXT1_EES3_(%class.MatrixClassTemplate* dereferenceable(616) %a, %class.MatrixClassTemplate* dereferenceable(616) %b)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca %class.MatrixClassTemplate*, align 8
+  // CHECK-NEXT:    %b.addr = alloca %class.MatrixClassTemplate*, align 8
+  // CHECK-NEXT:    store %class.MatrixClassTemplate* %a, %class.MatrixClassTemplate** %a.addr, align 8
+  // CHECK-NEXT:    store %class.MatrixClassTemplate* %b, %class.MatrixClassTemplate** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load %class.MatrixClassTemplate*, %class.MatrixClassTemplate** %a.addr, align 8
+  // CHECK-NEXT:    %Data = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %0, i32 0, i32 1
+  // CHECK-NEXT:    %1 = bitcast [150 x float]* %Data to <150 x float>*
+  // CHECK-NEXT:    %2 = load <150 x float>, <150 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %class.MatrixClassTemplate*, %class.MatrixClassTemplate** %b.addr, align 8
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %3, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [150 x float]* %Data1 to <150 x float>*
+  // CHECK-NEXT:    store <150 x float> %2, <150 x float>* %4, align 4
+  // CHECK-NEXT:    ret void
+
+  MatrixClassTemplate<float, 10, 15> Result, Arg;
+  Arg.Data = *((MatrixClassTemplate<float, 10, 15>::MatrixTy *)Data);
+  matrix_template_reference(Arg, Result);
+  return Result;
+}
+
+template <class T, unsigned long R, unsigned long C>
+using matrix = T __attribute__((matrix_type(R, C)));
+
+template <int N>
+struct selector {};
+
+template <class T, unsigned long R, unsigned long C>
+selector<0> use_matrix(matrix<T, R, C> &m) {}
+
+template <class T, unsigned long R>
+selector<1> use_matrix(matrix<T, R, 10> &m) {}
+
+template <class T>
+selector<2> use_matrix(matrix<T, 10, 10> &m) {}
+
+template <class T, unsigned long C>
+selector<3> use_matrix(matrix<T, 10, C> &m) {}
+
+template <unsigned long R, unsigned long C>
+selector<4> use_matrix(matrix<float, R, C> &m) {}
+
+void test_template_deduction() {
+
+  // CHECK-LABEL: define void @_Z23test_template_deductionv()
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m0 = alloca [120 x i32], align 4
+  // CHECK-NEXT:    %w = alloca %struct.selector, align 1
+  // CHECK-NEXT:    %undef.agg.tmp = alloca %struct.selector, align 1
+  // CHECK-NEXT:    %m1 = alloca [100 x i32], align 4
+  // CHECK-NEXT:    %x = alloca %struct.selector.0, align 1
+  // CHECK-NEXT:    %undef.agg.tmp1 = alloca %struct.selector.0, align 1
+  // CHECK-NEXT:    %m2 = alloca [120 x i32], align 4
+  // CHECK-NEXT:    %y = alloca %struct.selector.1, align 1
+  // CHECK-NEXT:    %undef.agg.tmp2 = alloca %struct.selector.1, align 1
+  // CHECK-NEXT:    %m3 = alloca [144 x i32], align 4
+  // CHECK-NEXT:    %z = alloca %struct.selector.2, align 1
+  // CHECK-NEXT:    %undef.agg.tmp3 = alloca %struct.selector.2, align 1
+  // CHECK-NEXT:    %m4 = alloca [144 x float], align 4
+  // CHECK-NEXT:    %v = alloca %struct.selector.3, align 1
+  // CHECK-NEXT:    %undef.agg.tmp4 = alloca %struct.selector.3, align 1
+  // CHECK-NEXT:    call void @_Z10use_matrixIiLm12EE8selectorILi3EERU11matrix_typeXLm10EEXT0_ET_([120 x i32]* dereferenceable(480) %m0)
+  // CHECK-NEXT:    call void @_Z10use_matrixIiE8selectorILi2EERU11matrix_typeLm10ELm10ET_([100 x i32]* dereferenceable(400) %m1)
+  // CHECK-NEXT:    call void @_Z10use_matrixIiLm12EE8selectorILi1EERU11matrix_typeXT0_EXLm10EET_([120 x i32]* dereferenceable(480) %m2)
+  // CHECK-NEXT:    call void @_Z10use_matrixIiLm12ELm12EE8selectorILi0EERU11matrix_typeXT0_EXT1_ET_([144 x i32]* dereferenceable(576) %m3)
+  // CHECK-NEXT:    call void @_Z10use_matrixILm12ELm12EE8selectorILi4EERU11matrix_typeXT_EXT0_Ef([144 x float]* dereferenceable(576) %m4)
+  // CHECK-NEXT:    ret void
+
+  // CHECK-LABEL: define linkonce_odr void @_Z10use_matrixIiLm12EE8selectorILi3EERU11matrix_typeXLm10EEXT0_ET_([120 x i32]* dereferenceable(480) %m)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m.addr = alloca [120 x i32]*, align 8
+  // CHECK-NEXT:    store [120 x i32]* %m, [120 x i32]** %m.addr, align 8
+  // CHECK-NEXT:    call void @llvm.trap()
+  // CHECK-NEXT:    unreachable
+
+  // CHECK-LABEL: define linkonce_odr void @_Z10use_matrixIiE8selectorILi2EERU11matrix_typeLm10ELm10ET_([100 x i32]* dereferenceable(400) %m)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m.addr = alloca [100 x i32]*, align 8
+  // CHECK-NEXT:    store [100 x i32]* %m, [100 x i32]** %m.addr, align 8
+  // CHECK-NEXT:    call void @llvm.trap()
+  // CHECK-NEXT:    unreachable
+
+  // CHECK-LABEL: define linkonce_odr void @_Z10use_matrixIiLm12EE8selectorILi1EERU11matrix_typeXT0_EXLm10EET_([120 x i32]* dereferenceable(480) %m)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m.addr = alloca [120 x i32]*, align 8
+  // CHECK-NEXT:    store [120 x i32]* %m, [120 x i32]** %m.addr, align 8
+  // CHECK-NEXT:    call void @llvm.trap()
+  // CHECK-NEXT:    unreachable
+
+  // CHECK-LABEL: define linkonce_odr void @_Z10use_matrixIiLm12ELm12EE8selectorILi0EERU11matrix_typeXT0_EXT1_ET_([144 x i32]* dereferenceable(576) %m)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m.addr = alloca [144 x i32]*, align 8
+  // CHECK-NEXT:    store [144 x i32]* %m, [144 x i32]** %m.addr, align 8
+  // CHECK-NEXT:    call void @llvm.trap()
+  // CHECK-NEXT:    unreachable
+
+  // CHECK-LABEL: define linkonce_odr void @_Z10use_matrixILm12ELm12EE8selectorILi4EERU11matrix_typeXT_EXT0_Ef([144 x float]* dereferenceable(576)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m.addr = alloca [144 x float]*, align 8
+  // CHECK-NEXT:    store [144 x float]* %m, [144 x float]** %m.addr, align 8
+  // CHECK-NEXT:    call void @llvm.trap()
+  // CHECK-NEXT:    unreachable
+
+  matrix<int, 10, 12> m0;
+  selector<3> w = use_matrix(m0);
+  matrix<int, 10, 10> m1;
+  selector<2> x = use_matrix(m1);
+  matrix<int, 12, 10> m2;
+  selector<1> y = use_matrix(m2);
+  matrix<int, 12, 12> m3;
+  selector<0> z = use_matrix(m3);
+  matrix<float, 12, 12> m4;
+  selector<4> v = use_matrix(m4);
+}
+
+template <auto R>
+void foo(matrix<int, R, 10> &m) {
+}
+
+void test_auto_t() {
+  // CHECK-LABEL: define void @_Z11test_auto_tv()
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m = alloca [130 x i32], align 4
+  // CHECK-NEXT:    call void @_Z3fooILm13EEvRU11matrix_typeXT_EXLm10EEi([130 x i32]* dereferenceable(520) %m)
+  // CHECK-NEXT:    ret void
+
+  // CHECK-LABEL: define linkonce_odr void @_Z3fooILm13EEvRU11matrix_typeXT_EXLm10EEi([130 x i32]* dereferenceable(520) %m)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m.addr = alloca [130 x i32]*, align 8
+  // CHECK-NEXT:    store [130 x i32]* %m, [130 x i32]** %m.addr, align 8
+  // CHECK-NEXT:    ret void
+
+  matrix<int, 13, 10> m;
+  foo(m);
+}
Index: clang/test/CodeGen/matrix-type.c
===================================================================
--- /dev/null
+++ clang/test/CodeGen/matrix-type.c
@@ -0,0 +1,158 @@
+// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+#if !__has_extension(matrix_types)
+#error Expected extension 'matrix_types' to be enabled
+#endif
+
+typedef double dx5x5_t __attribute__((matrix_type(5, 5)));
+
+// CHECK: %struct.Matrix = type { i8, [12 x float], float }
+
+void load_store_double(dx5x5_t *a, dx5x5_t *b) {
+  // CHECK-LABEL:  define void @load_store_double(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    %b.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    store [25 x double]* %a, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    store [25 x double]* %b, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load [25 x double]*, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %0 to <25 x double>*
+  // CHECK-NEXT:    %2 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %3 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    %4 = bitcast [25 x double]* %3 to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %2, <25 x double>* %4, align 8
+  // CHECK-NEXT:   ret void
+
+  *a = *b;
+}
+
+typedef float fx3x4_t __attribute__((matrix_type(3, 4)));
+void load_store_float(fx3x4_t *a, fx3x4_t *b) {
+  // CHECK-LABEL:  define void @load_store_float(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [12 x float]*, align 8
+  // CHECK-NEXT:    %b.addr = alloca [12 x float]*, align 8
+  // CHECK-NEXT:    store [12 x float]* %a, [12 x float]** %a.addr, align 8
+  // CHECK-NEXT:    store [12 x float]* %b, [12 x float]** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load [12 x float]*, [12 x float]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [12 x float]* %0 to <12 x float>*
+  // CHECK-NEXT:    %2 = load <12 x float>, <12 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load [12 x float]*, [12 x float]** %a.addr, align 8
+  // CHECK-NEXT:    %4 = bitcast [12 x float]* %3 to <12 x float>*
+  // CHECK-NEXT:    store <12 x float> %2, <12 x float>* %4, align 4
+  // CHECK-NEXT:   ret void
+
+  *a = *b;
+}
+
+typedef int ix3x4_t __attribute__((matrix_type(4, 3)));
+void load_store_int(ix3x4_t *a, ix3x4_t *b) {
+  // CHECK-LABEL:  define void @load_store_int(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [12 x i32]*, align 8
+  // CHECK-NEXT:    %b.addr = alloca [12 x i32]*, align 8
+  // CHECK-NEXT:    store [12 x i32]* %a, [12 x i32]** %a.addr, align 8
+  // CHECK-NEXT:    store [12 x i32]* %b, [12 x i32]** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load [12 x i32]*, [12 x i32]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [12 x i32]* %0 to <12 x i32>*
+  // CHECK-NEXT:    %2 = load <12 x i32>, <12 x i32>* %1, align 4
+  // CHECK-NEXT:    %3 = load [12 x i32]*, [12 x i32]** %a.addr, align 8
+  // CHECK-NEXT:    %4 = bitcast [12 x i32]* %3 to <12 x i32>*
+  // CHECK-NEXT:    store <12 x i32> %2, <12 x i32>* %4, align 4
+  // CHECK-NEXT:   ret void
+
+  *a = *b;
+}
+
+typedef unsigned long long ullx3x4_t __attribute__((matrix_type(4, 3)));
+void load_store_ull(ullx3x4_t *a, ullx3x4_t *b) {
+  // CHECK-LABEL:  define void @load_store_ull(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [12 x i64]*, align 8
+  // CHECK-NEXT:    %b.addr = alloca [12 x i64]*, align 8
+  // CHECK-NEXT:    store [12 x i64]* %a, [12 x i64]** %a.addr, align 8
+  // CHECK-NEXT:    store [12 x i64]* %b, [12 x i64]** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load [12 x i64]*, [12 x i64]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [12 x i64]* %0 to <12 x i64>*
+  // CHECK-NEXT:    %2 = load <12 x i64>, <12 x i64>* %1, align 8
+  // CHECK-NEXT:    %3 = load [12 x i64]*, [12 x i64]** %a.addr, align 8
+  // CHECK-NEXT:    %4 = bitcast [12 x i64]* %3 to <12 x i64>*
+  // CHECK-NEXT:    store <12 x i64> %2, <12 x i64>* %4, align 8
+  // CHECK-NEXT:   ret void
+
+  *a = *b;
+}
+
+typedef __fp16 fp16x3x4_t __attribute__((matrix_type(4, 3)));
+void load_store_fp16(fp16x3x4_t *a, fp16x3x4_t *b) {
+  // CHECK-LABEL:  define void @load_store_fp16(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [12 x half]*, align 8
+  // CHECK-NEXT:    %b.addr = alloca [12 x half]*, align 8
+  // CHECK-NEXT:    store [12 x half]* %a, [12 x half]** %a.addr, align 8
+  // CHECK-NEXT:    store [12 x half]* %b, [12 x half]** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load [12 x half]*, [12 x half]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [12 x half]* %0 to <12 x half>*
+  // CHECK-NEXT:    %2 = load <12 x half>, <12 x half>* %1, align 2
+  // CHECK-NEXT:    %3 = load [12 x half]*, [12 x half]** %a.addr, align 8
+  // CHECK-NEXT:    %4 = bitcast [12 x half]* %3 to <12 x half>*
+  // CHECK-NEXT:    store <12 x half> %2, <12 x half>* %4, align 2
+  // CHECK-NEXT:   ret void
+
+  *a = *b;
+}
+
+typedef float fx3x3_t __attribute__((matrix_type(3, 3)));
+
+void parameter_passing(fx3x3_t a, fx3x3_t *b) {
+  // CHECK-LABEL: define void @parameter_passing(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [9 x float], align 4
+  // CHECK-NEXT:    %b.addr = alloca [9 x float]*, align 8
+  // CHECK-NEXT:    %0 = bitcast [9 x float]* %a.addr to <9 x float>*
+  // CHECK-NEXT:    store <9 x float> %a, <9 x float>* %0, align 4
+  // CHECK-NEXT:    store [9 x float]* %b, [9 x float]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = load <9 x float>, <9 x float>* %0, align 4
+  // CHECK-NEXT:    %2 = load [9 x float]*, [9 x float]** %b.addr, align 8
+  // CHECK-NEXT:    %3 = bitcast [9 x float]* %2 to <9 x float>*
+  // CHECK-NEXT:    store <9 x float> %1, <9 x float>* %3, align 4
+  // CHECK-NEXT:    ret void
+  *b = a;
+}
+
+fx3x3_t return_matrix(fx3x3_t *a) {
+  // CHECK-LABEL: define <9 x float> @return_matrix
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [9 x float]*, align 8
+  // CHECK-NEXT:    store [9 x float]* %a, [9 x float]** %a.addr, align 8
+  // CHECK-NEXT:    %0 = load [9 x float]*, [9 x float]** %a.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [9 x float]* %0 to <9 x float>*
+  // CHECK-NEXT:    %2 = load <9 x float>, <9 x float>* %1, align 4
+  // CHECK-NEXT:    ret <9 x float> %2
+  return *a;
+}
+
+typedef struct {
+  char Tmp1;
+  fx3x4_t Data;
+  float Tmp2;
+} Matrix;
+
+void matrix_struct(Matrix *a, Matrix *b) {
+  // CHECK-LABEL: define void @matrix_struct(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    %b.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1
+  // CHECK-NEXT:    %1 = bitcast [12 x float]* %Data to <12 x float>*
+  // CHECK-NEXT:    %2 = load <12 x float>, <12 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+  // CHECK-NEXT:    store <12 x float> %2, <12 x float>* %4, align 4
+  // CHECK-NEXT:    ret void
+  b->Data = a->Data;
+}
Index: clang/test/CodeGen/debug-info-matrix-types.c
===================================================================
--- /dev/null
+++ clang/test/CodeGen/debug-info-matrix-types.c
@@ -0,0 +1,19 @@
+// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -debug-info-kind=limited -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+typedef double dx2x3_t __attribute__((matrix_type(2, 3)));
+
+void load_store_double(dx2x3_t *a, dx2x3_t *b) {
+  // CHECK-DAG:  @llvm.dbg.declare(metadata [6 x double]** %a.addr, metadata [[EXPR_A:![0-9]+]]
+  // CHECK-DAG:  @llvm.dbg.declare(metadata [6 x double]** %b.addr, metadata [[EXPR_B:![0-9]+]]
+  // CHECK: [[PTR_TY:![0-9]+]] = !DIDerivedType(tag: DW_TAG_pointer_type, baseType: [[TYPEDEF:![0-9]+]], size: 64)
+  // CHECK: [[TYPEDEF]] = !DIDerivedType(tag: DW_TAG_typedef, name: "dx2x3_t", {{.+}} baseType: [[MATRIX_TY:![0-9]+]])
+  // CHECK: [[MATRIX_TY]] = !DICompositeType(tag: DW_TAG_array_type, baseType: [[ELT_TY:![0-9]+]], size: 384, elements: [[ELEMENTS:![0-9]+]])
+  // CHECK: [[ELT_TY]] = !DIBasicType(name: "double", size: 64, encoding: DW_ATE_float)
+  // CHECK: [[ELEMENTS]] = !{[[COLS:![0-9]+]], [[ROWS:![0-9]+]]}
+  // CHECK: [[COLS]] = !DISubrange(count: 3)
+  // CHECK: [[ROWS]] = !DISubrange(count: 2)
+  // CHECK: [[EXPR_A]] = !DILocalVariable(name: "a", arg: 1, {{.+}} type: [[PTR_TY]])
+  // CHECK: [[EXPR_B]] = !DILocalVariable(name: "b", arg: 2, {{.+}} type: [[PTR_TY]])
+
+  *a = *b;
+}
Index: clang/lib/Serialization/ASTWriter.cpp
===================================================================
--- clang/lib/Serialization/ASTWriter.cpp
+++ clang/lib/Serialization/ASTWriter.cpp
@@ -288,6 +288,25 @@
   Record.AddSourceLocation(TL.getNameLoc());
 }
 
+void TypeLocWriter::VisitConstantMatrixTypeLoc(ConstantMatrixTypeLoc TL) {
+  Record.AddSourceLocation(TL.getAttrNameLoc());
+  SourceRange range = TL.getAttrOperandParensRange();
+  Record.AddSourceLocation(range.getBegin());
+  Record.AddSourceLocation(range.getEnd());
+  Record.AddStmt(TL.getAttrRowOperand());
+  Record.AddStmt(TL.getAttrColumnOperand());
+}
+
+void TypeLocWriter::VisitDependentSizedMatrixTypeLoc(
+    DependentSizedMatrixTypeLoc TL) {
+  Record.AddSourceLocation(TL.getAttrNameLoc());
+  SourceRange range = TL.getAttrOperandParensRange();
+  Record.AddSourceLocation(range.getBegin());
+  Record.AddSourceLocation(range.getEnd());
+  Record.AddStmt(TL.getAttrRowOperand());
+  Record.AddStmt(TL.getAttrColumnOperand());
+}
+
 void TypeLocWriter::VisitFunctionTypeLoc(FunctionTypeLoc TL) {
   Record.AddSourceLocation(TL.getLocalRangeBegin());
   Record.AddSourceLocation(TL.getLParenLoc());
Index: clang/lib/Serialization/ASTReader.cpp
===================================================================
--- clang/lib/Serialization/ASTReader.cpp
+++ clang/lib/Serialization/ASTReader.cpp
@@ -6554,6 +6554,21 @@
   TL.setNameLoc(readSourceLocation());
 }
 
+void TypeLocReader::VisitConstantMatrixTypeLoc(ConstantMatrixTypeLoc TL) {
+  TL.setAttrNameLoc(readSourceLocation());
+  TL.setAttrOperandParensRange(Reader.readSourceRange());
+  TL.setAttrRowOperand(Reader.readExpr());
+  TL.setAttrColumnOperand(Reader.readExpr());
+}
+
+void TypeLocReader::VisitDependentSizedMatrixTypeLoc(
+    DependentSizedMatrixTypeLoc TL) {
+  TL.setAttrNameLoc(readSourceLocation());
+  TL.setAttrOperandParensRange(Reader.readSourceRange());
+  TL.setAttrRowOperand(Reader.readExpr());
+  TL.setAttrColumnOperand(Reader.readExpr());
+}
+
 void TypeLocReader::VisitFunctionTypeLoc(FunctionTypeLoc TL) {
   TL.setLocalRangeBegin(readSourceLocation());
   TL.setLParenLoc(readSourceLocation());
Index: clang/lib/Sema/TreeTransform.h
===================================================================
--- clang/lib/Sema/TreeTransform.h
+++ clang/lib/Sema/TreeTransform.h
@@ -894,6 +894,16 @@
                                               Expr *SizeExpr,
                                               SourceLocation AttributeLoc);
 
+  /// Build a new matrix type given the element type and dimensions.
+  QualType RebuildConstantMatrixType(QualType ElementType, unsigned NumRows,
+                                     unsigned NumColumns);
+
+  /// Build a new matrix type given the type and dependently-defined
+  /// dimensions.
+  QualType RebuildDependentSizedMatrixType(QualType ElementType, Expr *RowExpr,
+                                           Expr *ColumnExpr,
+                                           SourceLocation AttributeLoc);
+
   /// Build a new DependentAddressSpaceType or return the pointee
   /// type variable with the correct address space (retrieved from
   /// AddrSpaceExpr) applied to it. The former will be returned in cases
@@ -5179,6 +5189,86 @@
   return Result;
 }
 
+template <typename Derived>
+QualType
+TreeTransform<Derived>::TransformConstantMatrixType(TypeLocBuilder &TLB,
+                                                    ConstantMatrixTypeLoc TL) {
+  const ConstantMatrixType *T = TL.getTypePtr();
+  QualType ElementType = getDerived().TransformType(T->getElementType());
+  if (ElementType.isNull())
+    return QualType();
+
+  QualType Result = TL.getType();
+  if (getDerived().AlwaysRebuild() || ElementType != T->getElementType()) {
+    Result = getDerived().RebuildConstantMatrixType(
+        ElementType, T->getNumRows(), T->getNumColumns());
+    if (Result.isNull())
+      return QualType();
+  }
+
+  ConstantMatrixTypeLoc NewTL = TLB.push<ConstantMatrixTypeLoc>(Result);
+  NewTL.setAttrNameLoc(TL.getAttrNameLoc());
+  NewTL.setAttrOperandParensRange(TL.getAttrOperandParensRange());
+  NewTL.setAttrRowOperand(TL.getAttrRowOperand());
+  NewTL.setAttrColumnOperand(TL.getAttrColumnOperand());
+
+  return Result;
+}
+
+template <typename Derived>
+QualType TreeTransform<Derived>::TransformDependentSizedMatrixType(
+    TypeLocBuilder &TLB, DependentSizedMatrixTypeLoc TL) {
+  const DependentSizedMatrixType *T = TL.getTypePtr();
+
+  QualType ElementType = getDerived().TransformType(T->getElementType());
+  if (ElementType.isNull()) {
+    return QualType();
+  }
+
+  // Matrix dimensions are constant expressions.
+  EnterExpressionEvaluationContext Unevaluated(
+      SemaRef, Sema::ExpressionEvaluationContext::ConstantEvaluated);
+
+  Expr *origRows = TL.getAttrRowOperand();
+  if (!origRows)
+    origRows = T->getRowExpr();
+  Expr *origColumns = TL.getAttrColumnOperand();
+  if (!origColumns)
+    origColumns = T->getColumnExpr();
+
+  ExprResult rowResult = getDerived().TransformExpr(origRows);
+  rowResult = SemaRef.ActOnConstantExpression(rowResult);
+  if (rowResult.isInvalid())
+    return QualType();
+
+  ExprResult columnResult = getDerived().TransformExpr(origColumns);
+  columnResult = SemaRef.ActOnConstantExpression(columnResult);
+  if (columnResult.isInvalid())
+    return QualType();
+
+  Expr *rows = rowResult.get();
+  Expr *columns = columnResult.get();
+
+  QualType Result = TL.getType();
+  if (getDerived().AlwaysRebuild() || ElementType != T->getElementType() ||
+      rows != origRows || columns != origColumns) {
+    Result = getDerived().RebuildDependentSizedMatrixType(
+        ElementType, rows, columns, T->getAttributeLoc());
+
+    if (Result.isNull())
+      return QualType();
+  }
+
+  // We might have any sort of matrix type now, but fortunately they
+  // all have the same location layout.
+  MatrixTypeLoc NewTL = TLB.push<MatrixTypeLoc>(Result);
+  NewTL.setAttrNameLoc(TL.getAttrNameLoc());
+  NewTL.setAttrOperandParensRange(TL.getAttrOperandParensRange());
+  NewTL.setAttrRowOperand(rows);
+  NewTL.setAttrColumnOperand(columns);
+  return Result;
+}
+
 template <typename Derived>
 QualType TreeTransform<Derived>::TransformDependentAddressSpaceType(
     TypeLocBuilder &TLB, DependentAddressSpaceTypeLoc TL) {
@@ -13750,6 +13840,21 @@
   return SemaRef.BuildExtVectorType(ElementType, SizeExpr, AttributeLoc);
 }
 
+template <typename Derived>
+QualType TreeTransform<Derived>::RebuildConstantMatrixType(
+    QualType ElementType, unsigned NumRows, unsigned NumColumns) {
+  return SemaRef.Context.getConstantMatrixType(ElementType, NumRows,
+                                               NumColumns);
+}
+
+template <typename Derived>
+QualType TreeTransform<Derived>::RebuildDependentSizedMatrixType(
+    QualType ElementType, Expr *RowExpr, Expr *ColumnExpr,
+    SourceLocation AttributeLoc) {
+  return SemaRef.BuildMatrixType(ElementType, RowExpr, ColumnExpr,
+                                 AttributeLoc);
+}
+
 template<typename Derived>
 QualType TreeTransform<Derived>::RebuildFunctionProtoType(
     QualType T,
Index: clang/lib/Sema/SemaType.cpp
===================================================================
--- clang/lib/Sema/SemaType.cpp
+++ clang/lib/Sema/SemaType.cpp
@@ -2492,14 +2492,15 @@
   if (!VecSize.isIntN(61)) {
     // Bit size will overflow uint64.
     Diag(AttrLoc, diag::err_attribute_size_too_large)
-        << SizeExpr->getSourceRange();
+        << SizeExpr->getSourceRange() << "vector";
     return QualType();
   }
   uint64_t VectorSizeBits = VecSize.getZExtValue() * 8;
   unsigned TypeSize = static_cast<unsigned>(Context.getTypeSize(CurType));
 
   if (VectorSizeBits == 0) {
-    Diag(AttrLoc, diag::err_attribute_zero_size) << SizeExpr->getSourceRange();
+    Diag(AttrLoc, diag::err_attribute_zero_size)
+        << SizeExpr->getSourceRange() << "vector";
     return QualType();
   }
 
@@ -2511,7 +2512,7 @@
 
   if (VectorSizeBits / TypeSize > std::numeric_limits<uint32_t>::max()) {
     Diag(AttrLoc, diag::err_attribute_size_too_large)
-        << SizeExpr->getSourceRange();
+        << SizeExpr->getSourceRange() << "vector";
     return QualType();
   }
 
@@ -2549,7 +2550,7 @@
 
     if (!vecSize.isIntN(32)) {
       Diag(AttrLoc, diag::err_attribute_size_too_large)
-          << ArraySize->getSourceRange();
+          << ArraySize->getSourceRange() << "vector";
       return QualType();
     }
     // Unlike gcc's vector_size attribute, the size is specified as the
@@ -2558,7 +2559,7 @@
 
     if (vectorSize == 0) {
       Diag(AttrLoc, diag::err_attribute_zero_size)
-      << ArraySize->getSourceRange();
+          << ArraySize->getSourceRange() << "vector";
       return QualType();
     }
 
@@ -2568,6 +2569,84 @@
   return Context.getDependentSizedExtVectorType(T, ArraySize, AttrLoc);
 }
 
+QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
+                               SourceLocation AttrLoc) {
+  assert(Context.getLangOpts().MatrixTypes &&
+         "Should never build a matrix type when it is disabled");
+
+  if (NumRows->isTypeDependent() || NumCols->isTypeDependent() ||
+      NumRows->isValueDependent() || NumCols->isValueDependent())
+    return Context.getDependentSizedMatrixType(ElementTy, NumRows, NumCols,
+                                               AttrLoc);
+
+  // Check element type, if it is not dependent.
+  if (!ElementTy->isDependentType() &&
+      !MatrixType::isValidElementType(ElementTy)) {
+    Diag(AttrLoc, diag::err_attribute_invalid_matrix_type) << ElementTy;
+    return QualType();
+  }
+
+  // Both row and column values can only be 20 bit wide currently.
+  llvm::APSInt ValueRows(32), ValueColumns(32);
+
+  bool const RowsIsInteger = NumRows->isIntegerConstantExpr(ValueRows, Context);
+  bool const ColumnsIsInteger =
+      NumCols->isIntegerConstantExpr(ValueColumns, Context);
+
+  auto const RowRange = NumRows->getSourceRange();
+  auto const ColRange = NumCols->getSourceRange();
+
+  // Both the row and column expressions are invalid.
+  if (!RowsIsInteger && !ColumnsIsInteger) {
+    Diag(AttrLoc, diag::err_attribute_argument_type)
+        << "matrix_type" << AANT_ArgumentIntegerConstant << RowRange
+        << ColRange;
+    return QualType();
+  }
+
+  // Only the row expression is invalid.
+  if (!RowsIsInteger) {
+    Diag(AttrLoc, diag::err_attribute_argument_type)
+        << "matrix_type" << AANT_ArgumentIntegerConstant << RowRange;
+    return QualType();
+  }
+
+  // Only the column expression is invalid.
+  if (!ColumnsIsInteger) {
+    Diag(AttrLoc, diag::err_attribute_argument_type)
+        << "matrix_type" << AANT_ArgumentIntegerConstant << ColRange;
+    return QualType();
+  }
+
+  // Check the dimensions. NOTE(review): the casts may truncate wide values.
+  unsigned MatrixRows = static_cast<unsigned>(ValueRows.getZExtValue());
+  unsigned MatrixColumns = static_cast<unsigned>(ValueColumns.getZExtValue());
+  if (MatrixRows == 0 && MatrixColumns == 0) {
+    Diag(AttrLoc, diag::err_attribute_zero_size)
+        << "matrix" << RowRange << ColRange;
+    return QualType();
+  }
+  if (MatrixRows == 0) {
+    Diag(AttrLoc, diag::err_attribute_zero_size) << "matrix" << RowRange;
+    return QualType();
+  }
+  if (MatrixColumns == 0) {
+    Diag(AttrLoc, diag::err_attribute_zero_size) << "matrix" << ColRange;
+    return QualType();
+  }
+  if (!ConstantMatrixType::isDimensionValid(MatrixRows)) {
+    Diag(AttrLoc, diag::err_attribute_size_too_large)
+        << RowRange << "matrix row";
+    return QualType();
+  }
+  if (!ConstantMatrixType::isDimensionValid(MatrixColumns)) {
+    Diag(AttrLoc, diag::err_attribute_size_too_large)
+        << ColRange << "matrix column";
+    return QualType();
+  }
+  return Context.getConstantMatrixType(ElementTy, MatrixRows, MatrixColumns);
+}
+
 bool Sema::CheckFunctionReturnType(QualType T, SourceLocation Loc) {
   if (T->isArrayType() || T->isFunctionType()) {
     Diag(Loc, diag::err_func_returning_array_function)
@@ -6013,6 +6092,21 @@
       "no address_space attribute found at the expected location!");
 }
 
+static void fillMatrixTypeLoc(MatrixTypeLoc MTL,
+                              const ParsedAttributesView &Attrs) {
+  for (const ParsedAttr &AL : Attrs) {
+    if (AL.getKind() == ParsedAttr::AT_MatrixType) {
+      MTL.setAttrNameLoc(AL.getLoc());
+      MTL.setAttrRowOperand(AL.getArgAsExpr(0));
+      MTL.setAttrColumnOperand(AL.getArgAsExpr(1));
+      MTL.setAttrOperandParensRange(SourceRange());
+      return;
+    }
+  }
+
+  llvm_unreachable("no matrix_type attribute found at the expected location!");
+}
+
 /// Create and instantiate a TypeSourceInfo with type source information.
 ///
 /// \param T QualType referring to the type as written in source code.
@@ -6061,6 +6155,9 @@
       CurrTL = TL.getPointeeTypeLoc().getUnqualifiedLoc();
     }
 
+    if (MatrixTypeLoc TL = CurrTL.getAs<MatrixTypeLoc>())
+      fillMatrixTypeLoc(TL, D.getTypeObject(i).getAttrs());
+
     // FIXME: Ordering here?
     while (AdjustedTypeLoc TL = CurrTL.getAs<AdjustedTypeLoc>())
       CurrTL = TL.getNextTypeLoc().getUnqualifiedLoc();
@@ -7706,6 +7803,68 @@
   }
 }
 
+/// HandleMatrixTypeAttr - apply "matrix_type(rows, columns)" to \p CurType.
+static void HandleMatrixTypeAttr(QualType &CurType, const ParsedAttr &Attr,
+                                 Sema &S) {
+  if (!S.getLangOpts().MatrixTypes) {
+    S.Diag(Attr.getLoc(), diag::err_builtin_matrix_disabled);
+    return;
+  }
+
+  if (Attr.getNumArgs() != 2) {
+    S.Diag(Attr.getLoc(), diag::err_attribute_wrong_number_arguments)
+        << Attr << 2;
+    return;
+  }
+
+  Expr *RowsExpr = nullptr;
+  Expr *ColsExpr = nullptr;
+
+  // TODO: Refactor parameter extraction into separate function
+  // Get the number of rows: resolve an identifier argument, else use the expr.
+  if (Attr.isArgIdent(0)) {
+    CXXScopeSpec SS;
+    SourceLocation TemplateKeywordLoc;
+    UnqualifiedId id;
+    id.setIdentifier(Attr.getArgAsIdent(0)->Ident, Attr.getLoc());
+    ExprResult Rows = S.ActOnIdExpression(S.getCurScope(), SS,
+                                          TemplateKeywordLoc, id, false, false);
+
+    if (Rows.isInvalid())
+      // TODO: maybe a good error message would be nice here
+      return;
+    RowsExpr = Rows.get();
+  } else {
+    assert(Attr.isArgExpr(0) &&
+           "Argument to should either be an identity or expression");
+    RowsExpr = Attr.getArgAsExpr(0);
+  }
+
+  // Get the number of columns, handled the same way as the rows above.
+  if (Attr.isArgIdent(1)) {
+    CXXScopeSpec SS;
+    SourceLocation TemplateKeywordLoc;
+    UnqualifiedId id;
+    id.setIdentifier(Attr.getArgAsIdent(1)->Ident, Attr.getLoc());
+    ExprResult Columns = S.ActOnIdExpression(
+        S.getCurScope(), SS, TemplateKeywordLoc, id, false, false);
+
+    if (Columns.isInvalid())
+      // TODO: a good error message would be nice here
+      return;
+    ColsExpr = Columns.get();
+  } else {
+    assert(Attr.isArgExpr(1) &&
+           "Argument to should either be an identity or expression");
+    ColsExpr = Attr.getArgAsExpr(1);
+  }
+
+  // Create the matrix type; CurType is left unchanged if that fails.
+  QualType T = S.BuildMatrixType(CurType, RowsExpr, ColsExpr, Attr.getLoc());
+  if (!T.isNull())
+    CurType = T;
+}
+
 static void HandleLifetimeBoundAttr(TypeProcessingState &State,
                                     QualType &CurType,
                                     ParsedAttr &Attr) {
@@ -7857,6 +8016,11 @@
       break;
     }
 
+    case ParsedAttr::AT_MatrixType:
+      HandleMatrixTypeAttr(type, attr, state.getSema());
+      attr.setUsedAsTypeAttr();
+      break;
+
     MS_TYPE_ATTRS_CASELIST:
       if (!handleMSPointerTypeQualifierAttr(state, attr, type))
         attr.setUsedAsTypeAttr();
Index: clang/lib/Sema/SemaTemplateDeduction.cpp
===================================================================
--- clang/lib/Sema/SemaTemplateDeduction.cpp
+++ clang/lib/Sema/SemaTemplateDeduction.cpp
@@ -2055,6 +2055,116 @@
       return Sema::TDK_NonDeducedMismatch;
     }
 
+    //     (clang extension)
+    //
+    //     T __attribute__((matrix_type(<integral constant>,
+    //                                  <integral constant>)))
+    case Type::ConstantMatrix: {
+      const ConstantMatrixType *MatrixArg = dyn_cast<ConstantMatrixType>(Arg);
+      if (!MatrixArg)
+        return Sema::TDK_NonDeducedMismatch;
+
+      const ConstantMatrixType *MatrixParam = cast<ConstantMatrixType>(Param);
+      // Check that the dimensions are the same
+      if (MatrixParam->getNumRows() != MatrixArg->getNumRows() ||
+          MatrixParam->getNumColumns() != MatrixArg->getNumColumns()) {
+        return Sema::TDK_NonDeducedMismatch;
+      }
+      // Perform deduction on element types.
+      unsigned SubTDF = TDF & TDF_IgnoreQualifiers;
+      return DeduceTemplateArgumentsByTypeMatch(
+          S, TemplateParams, MatrixParam->getElementType(),
+          MatrixArg->getElementType(), Info, Deduced, SubTDF);
+    }
+
+    case Type::DependentSizedMatrix: {
+      const MatrixType *MatrixArg = dyn_cast<MatrixType>(Arg);
+      if (!MatrixArg)
+        return Sema::TDK_NonDeducedMismatch;
+
+      unsigned SubTDF = TDF & TDF_IgnoreQualifiers;
+
+      // Check the element type of the matrices.
+      const DependentSizedMatrixType *MatrixParam =
+          cast<DependentSizedMatrixType>(Param);
+      if (Sema::TemplateDeductionResult Result =
+              DeduceTemplateArgumentsByTypeMatch(
+                  S, TemplateParams, MatrixParam->getElementType(),
+                  MatrixArg->getElementType(), Info, Deduced, SubTDF))
+        return Result;
+
+      // Determine if the number of rows and columns is something we can deduce.
+      NonTypeTemplateParmDecl *RowNTTP =
+          getDeducedParameterFromExpr(Info, MatrixParam->getRowExpr());
+      NonTypeTemplateParmDecl *ColumnNTTP =
+          getDeducedParameterFromExpr(Info, MatrixParam->getColumnExpr());
+
+      // Try to deduce a matrix dimension. Note that DependentSizedMatrixType
+      // can have some dimensions that are non-dependent constant integer
+      // expressions.
+      auto DeduceMatrixArg =
+          [&](NonTypeTemplateParmDecl *NTTP,
+              std::function<unsigned(const ConstantMatrixType *)> GetDimension,
+              std::function<Expr *(const DependentSizedMatrixType *)>
+                  GetDimensionExpr) {
+            if (!NTTP) {
+              // If we did not deduce an NTTP from MatrixParam, fail if
+              // MatrixArg is a DependentSizedMatrixType and the dimension is
+              // not a constant.
+              if (const DependentSizedMatrixType *DepMatrixArg =
+                      dyn_cast<DependentSizedMatrixType>(MatrixArg)) {
+                llvm::APSInt NumAPSInt(
+                    S.Context.getTypeSize(S.Context.getSizeType()));
+                Expr *DimExpr = GetDimensionExpr(DepMatrixArg);
+                if (DimExpr->isValueDependent() ||
+                    !DimExpr->isIntegerConstantExpr(NumAPSInt, S.Context))
+                  return Sema::TDK_NonDeducedMismatch;
+              }
+
+              return Sema::TDK_Success;
+            }
+            assert(NTTP->getDepth() == Info.getDeducedDepth() &&
+                   "saw non-type template parameter with wrong depth");
+            assert((isa<ConstantMatrixType>(MatrixArg) ||
+                    isa<DependentSizedMatrixType>(MatrixArg)) &&
+                   "should only pass a subclass of MatrixType");
+
+            // The constant value for the dimension, if MatrixArg is a
+            // ConstantMatrixType or a DependentSizedMatrixType and the
+            // dimension is an integer constant expression.
+            llvm::APSInt ConstIntVal(
+                S.Context.getTypeSize(S.Context.getSizeType()));
+            if (const DependentSizedMatrixType *DepMatrixArg =
+                    dyn_cast<DependentSizedMatrixType>(MatrixArg)) {
+              Expr *DimExpr = GetDimensionExpr(DepMatrixArg);
+              if (DimExpr->isValueDependent() ||
+                  !DimExpr->isIntegerConstantExpr(ConstIntVal, S.Context))
+                return DeduceNonTypeTemplateArgument(S, TemplateParams, NTTP,
+                                                     DimExpr, Info, Deduced);
+            } else if (const ConstantMatrixType *ConstantMatrixArg =
+                           dyn_cast<ConstantMatrixType>(MatrixArg))
+              ConstIntVal = GetDimension(ConstantMatrixArg);
+
+            return DeduceNonTypeTemplateArgument(
+                S, TemplateParams, NTTP, ConstIntVal, S.Context.getSizeType(),
+                /*ArrayBound=*/false, Info, Deduced);
+          };
+
+      auto Result = DeduceMatrixArg(
+          RowNTTP,
+          [](const ConstantMatrixType *MT) { return MT->getNumRows(); },
+          [](const DependentSizedMatrixType *MT) { return MT->getRowExpr(); });
+      if (Result != Sema::TDK_Success)
+        return Result;
+
+      return DeduceMatrixArg(
+          ColumnNTTP,
+          [](const ConstantMatrixType *MT) { return MT->getNumColumns(); },
+          [](const DependentSizedMatrixType *MT) {
+            return MT->getColumnExpr();
+          });
+    }
+
     //     (clang extension)
     //
     //     T __attribute__(((address_space(N))))
@@ -5723,6 +5833,24 @@
     break;
   }
 
+  case Type::ConstantMatrix: {
+    const ConstantMatrixType *MatType = cast<ConstantMatrixType>(T);
+    MarkUsedTemplateParameters(Ctx, MatType->getElementType(), OnlyDeduced,
+                               Depth, Used);
+    break;
+  }
+
+  case Type::DependentSizedMatrix: {
+    const DependentSizedMatrixType *MatType = cast<DependentSizedMatrixType>(T);
+    MarkUsedTemplateParameters(Ctx, MatType->getElementType(), OnlyDeduced,
+                               Depth, Used);
+    MarkUsedTemplateParameters(Ctx, MatType->getRowExpr(), OnlyDeduced, Depth,
+                               Used);
+    MarkUsedTemplateParameters(Ctx, MatType->getColumnExpr(), OnlyDeduced,
+                               Depth, Used);
+    break;
+  }
+
   case Type::FunctionProto: {
     const FunctionProtoType *Proto = cast<FunctionProtoType>(T);
     MarkUsedTemplateParameters(Ctx, Proto->getReturnType(), OnlyDeduced, Depth,
Index: clang/lib/Sema/SemaTemplate.cpp
===================================================================
--- clang/lib/Sema/SemaTemplate.cpp
+++ clang/lib/Sema/SemaTemplate.cpp
@@ -5867,6 +5867,11 @@
   return Visit(T->getElementType());
 }
 
+bool UnnamedLocalNoLinkageFinder::VisitDependentSizedMatrixType(
+    const DependentSizedMatrixType *T) {
+  return Visit(T->getElementType());
+}
+
 bool UnnamedLocalNoLinkageFinder::VisitDependentAddressSpaceType(
     const DependentAddressSpaceType *T) {
   return Visit(T->getPointeeType());
@@ -5885,6 +5890,11 @@
   return Visit(T->getElementType());
 }
 
+bool UnnamedLocalNoLinkageFinder::VisitConstantMatrixType(
+    const ConstantMatrixType *T) {
+  return Visit(T->getElementType());
+}
+
 bool UnnamedLocalNoLinkageFinder::VisitFunctionProtoType(
                                                   const FunctionProtoType* T) {
   for (const auto &A : T->param_types()) {
Index: clang/lib/Sema/SemaLookup.cpp
===================================================================
--- clang/lib/Sema/SemaLookup.cpp
+++ clang/lib/Sema/SemaLookup.cpp
@@ -2966,6 +2966,7 @@
     // These are fundamental types.
     case Type::Vector:
     case Type::ExtVector:
+    case Type::ConstantMatrix:
     case Type::Complex:
     case Type::ExtInt:
       break;
Index: clang/lib/Sema/SemaExpr.cpp
===================================================================
--- clang/lib/Sema/SemaExpr.cpp
+++ clang/lib/Sema/SemaExpr.cpp
@@ -4257,6 +4257,7 @@
     case Type::Complex:
     case Type::Vector:
     case Type::ExtVector:
+    case Type::ConstantMatrix:
     case Type::Record:
     case Type::Enum:
     case Type::Elaborated:
Index: clang/lib/Frontend/CompilerInvocation.cpp
===================================================================
--- clang/lib/Frontend/CompilerInvocation.cpp
+++ clang/lib/Frontend/CompilerInvocation.cpp
@@ -3336,6 +3336,8 @@
   Opts.CompleteMemberPointers = Args.hasArg(OPT_fcomplete_member_pointers);
   Opts.BuildingPCHWithObjectFile = Args.hasArg(OPT_building_pch_with_obj);
 
+  Opts.MatrixTypes = Args.hasArg(OPT_fenable_matrix);
+
   Opts.MaxTokens = getLastArgIntValue(Args, OPT_fmax_tokens_EQ, 0, Diags);
 
   if (Arg *A = Args.getLastArg(OPT_msign_return_address_EQ)) {
Index: clang/lib/Driver/ToolChains/Clang.cpp
===================================================================
--- clang/lib/Driver/ToolChains/Clang.cpp
+++ clang/lib/Driver/ToolChains/Clang.cpp
@@ -4565,6 +4565,13 @@
   if (Args.hasFlag(options::OPT_mrtd, options::OPT_mno_rtd, false))
     CmdArgs.push_back("-fdefault-calling-conv=stdcall");
 
+  if (Args.hasArg(options::OPT_fenable_matrix)) {
+    // enable-matrix is needed by both the LangOpts and by LLVM.
+    CmdArgs.push_back("-fenable-matrix");
+    CmdArgs.push_back("-mllvm");
+    CmdArgs.push_back("-enable-matrix");
+  }
+
   CodeGenOptions::FramePointerKind FPKeepKind =
                   getFramePointerKind(Args, RawTriple);
   const char *FPKeepKindStr = nullptr;
Index: clang/lib/CodeGen/ItaniumCXXABI.cpp
===================================================================
--- clang/lib/CodeGen/ItaniumCXXABI.cpp
+++ clang/lib/CodeGen/ItaniumCXXABI.cpp
@@ -3223,6 +3223,7 @@
   // GCC treats vector and complex types as fundamental types.
   case Type::Vector:
   case Type::ExtVector:
+  case Type::ConstantMatrix:
   case Type::Complex:
   case Type::Atomic:
   // FIXME: GCC treats block pointers as fundamental types?!
@@ -3458,6 +3459,7 @@
   case Type::Builtin:
   case Type::Vector:
   case Type::ExtVector:
+  case Type::ConstantMatrix:
   case Type::Complex:
   case Type::BlockPointer:
     // Itanium C++ ABI 2.9.5p4:
Index: clang/lib/CodeGen/CodeGenTypes.cpp
===================================================================
--- clang/lib/CodeGen/CodeGenTypes.cpp
+++ clang/lib/CodeGen/CodeGenTypes.cpp
@@ -82,6 +82,13 @@
 /// a type.  For example, the scalar representation for _Bool is i1, but the
 /// memory representation is usually i8 or i32, depending on the target.
 llvm::Type *CodeGenTypes::ConvertTypeForMem(QualType T, bool ForBitField) {
+  if (T->isConstantMatrixType()) {
+    const Type *Ty = Context.getCanonicalType(T).getTypePtr();
+    const ConstantMatrixType *MT = cast<ConstantMatrixType>(Ty);
+    return llvm::ArrayType::get(ConvertType(MT->getElementType()),
+                                MT->getNumRows() * MT->getNumColumns());
+  }
+
   llvm::Type *R = ConvertType(T);
 
   // If this is a bool type, or an ExtIntType in a bitfield representation,
@@ -646,6 +653,12 @@
                                        VT->getNumElements());
     break;
   }
+  case Type::ConstantMatrix: {
+    const ConstantMatrixType *MT = cast<ConstantMatrixType>(Ty);
+    ResultType = llvm::VectorType::get(ConvertType(MT->getElementType()),
+                                       MT->getNumRows() * MT->getNumColumns());
+    break;
+  }
   case Type::FunctionNoProto:
   case Type::FunctionProto:
     ResultType = ConvertFunctionTypeInternal(T);
Index: clang/lib/CodeGen/CodeGenFunction.cpp
===================================================================
--- clang/lib/CodeGen/CodeGenFunction.cpp
+++ clang/lib/CodeGen/CodeGenFunction.cpp
@@ -247,6 +247,7 @@
     case Type::MemberPointer:
     case Type::Vector:
     case Type::ExtVector:
+    case Type::ConstantMatrix:
     case Type::FunctionProto:
     case Type::FunctionNoProto:
     case Type::Enum:
@@ -2000,6 +2001,7 @@
     case Type::Complex:
     case Type::Vector:
     case Type::ExtVector:
+    case Type::ConstantMatrix:
     case Type::Record:
     case Type::Enum:
     case Type::Elaborated:
Index: clang/lib/CodeGen/CGExpr.cpp
===================================================================
--- clang/lib/CodeGen/CGExpr.cpp
+++ clang/lib/CodeGen/CGExpr.cpp
@@ -145,8 +145,19 @@
 
 Address CodeGenFunction::CreateMemTemp(QualType Ty, CharUnits Align,
                                        const Twine &Name, Address *Alloca) {
-  return CreateTempAlloca(ConvertTypeForMem(Ty), Align, Name,
-                          /*ArraySize=*/nullptr, Alloca);
+  Address Result = CreateTempAlloca(ConvertTypeForMem(Ty), Align, Name,
+                                    /*ArraySize=*/nullptr, Alloca);
+
+  if (Ty->isConstantMatrixType()) {
+    auto *ArrayTy = cast<llvm::ArrayType>(Result.getType()->getElementType());
+    auto *VectorTy = llvm::VectorType::get(ArrayTy->getElementType(),
+                                           ArrayTy->getNumElements());
+
+    Result = Address(
+        Builder.CreateBitCast(Result.getPointer(), VectorTy->getPointerTo()),
+        Result.getAlignment());
+  }
+  return Result;
 }
 
 Address CodeGenFunction::CreateMemTempWithoutCast(QualType Ty, CharUnits Align,
@@ -1732,6 +1743,42 @@
   return Value;
 }
 
+// Bitcast \p Addr between the matrix memory type (array) and value type
+// (vector): array -> vector if \p IsVector, vector -> array otherwise.
+static Address MaybeConvertMatrixAddress(Address Addr, CodeGenFunction &CGF,
+                                         bool IsVector = true) {
+  auto *ArrayTy = dyn_cast<llvm::ArrayType>(
+      cast<llvm::PointerType>(Addr.getPointer()->getType())->getElementType());
+  if (ArrayTy && IsVector) {
+    auto *VectorTy = llvm::VectorType::get(ArrayTy->getElementType(),
+                                           ArrayTy->getNumElements());
+
+    return Address(CGF.Builder.CreateElementBitCast(Addr, VectorTy));
+  }
+  auto *VectorTy = dyn_cast<llvm::VectorType>(
+      cast<llvm::PointerType>(Addr.getPointer()->getType())->getElementType());
+  if (VectorTy && !IsVector) {
+    auto *ArrayTy = llvm::ArrayType::get(VectorTy->getElementType(),
+                                         VectorTy->getNumElements());
+
+    return Address(CGF.Builder.CreateElementBitCast(Addr, ArrayTy));
+  }
+
+  return Addr;
+}
+
+// Emit a store of a matrix LValue. This may require casting the original
+// pointer to memory address (ArrayType) to a pointer to the value type
+// (VectorType).
+static void EmitStoreOfMatrixScalar(llvm::Value *value, LValue lvalue,
+                                    bool isInit, CodeGenFunction &CGF) {
+  Address Addr = MaybeConvertMatrixAddress(lvalue.getAddress(CGF), CGF,
+                                           value->getType()->isVectorTy());
+  CGF.EmitStoreOfScalar(value, Addr, lvalue.isVolatile(), lvalue.getType(),
+                        lvalue.getBaseInfo(), lvalue.getTBAAInfo(), isInit,
+                        lvalue.isNontemporal());
+}
+
 void CodeGenFunction::EmitStoreOfScalar(llvm::Value *Value, Address Addr,
                                         bool Volatile, QualType Ty,
                                         LValueBaseInfo BaseInfo,
@@ -1779,11 +1826,26 @@
 
 void CodeGenFunction::EmitStoreOfScalar(llvm::Value *value, LValue lvalue,
                                         bool isInit) {
+  if (lvalue.getType()->isConstantMatrixType()) {
+    EmitStoreOfMatrixScalar(value, lvalue, isInit, *this);
+    return;
+  }
+
   EmitStoreOfScalar(value, lvalue.getAddress(*this), lvalue.isVolatile(),
                     lvalue.getType(), lvalue.getBaseInfo(),
                     lvalue.getTBAAInfo(), isInit, lvalue.isNontemporal());
 }
 
+// Emit a load of a LValue of matrix type. This may require casting the pointer
+// to memory address (ArrayType) to a pointer to the value type (VectorType).
+static RValue EmitLoadOfMatrixLValue(LValue LV, SourceLocation Loc,
+                                     CodeGenFunction &CGF) {
+  assert(LV.getType()->isConstantMatrixType());
+  Address Addr = MaybeConvertMatrixAddress(LV.getAddress(CGF), CGF);
+  LV.setAddress(Addr);
+  return RValue::get(CGF.EmitLoadOfScalar(LV, Loc));
+}
+
 /// EmitLoadOfLValue - Given an expression that represents a value lvalue, this
 /// method emits the address of the lvalue, then loads the result as an rvalue,
 /// returning the rvalue.
@@ -1809,6 +1871,9 @@
   if (LV.isSimple()) {
     assert(!LV.getType()->isFunctionType());
 
+    if (LV.getType()->isConstantMatrixType())
+      return EmitLoadOfMatrixLValue(LV, Loc, *this);
+
     // Everything needs a load.
     return RValue::get(EmitLoadOfScalar(LV, Loc));
   }
Index: clang/lib/CodeGen/CGDebugInfo.h
===================================================================
--- clang/lib/CodeGen/CGDebugInfo.h
+++ clang/lib/CodeGen/CGDebugInfo.h
@@ -192,6 +192,7 @@
   llvm::DIType *CreateType(const ObjCTypeParamType *Ty, llvm::DIFile *Unit);
 
   llvm::DIType *CreateType(const VectorType *Ty, llvm::DIFile *F);
+  llvm::DIType *CreateType(const ConstantMatrixType *Ty, llvm::DIFile *F);
   llvm::DIType *CreateType(const ArrayType *Ty, llvm::DIFile *F);
   llvm::DIType *CreateType(const LValueReferenceType *Ty, llvm::DIFile *F);
   llvm::DIType *CreateType(const RValueReferenceType *Ty, llvm::DIFile *Unit);
Index: clang/lib/CodeGen/CGDebugInfo.cpp
===================================================================
--- clang/lib/CodeGen/CGDebugInfo.cpp
+++ clang/lib/CodeGen/CGDebugInfo.cpp
@@ -2736,6 +2736,23 @@
   return DBuilder.createVectorType(Size, Align, ElementTy, SubscriptArray);
 }
 
+llvm::DIType *CGDebugInfo::CreateType(const ConstantMatrixType *Ty,
+                                      llvm::DIFile *Unit) {
+  // FIXME: Create a dedicated debug type for matrices.
+  // For the time being, a matrix is described like a nested ArrayType.
+
+  llvm::DIType *ElementTy = getOrCreateType(Ty->getElementType(), Unit);
+  uint64_t Size = CGM.getContext().getTypeSize(Ty);
+  uint32_t Align = getTypeAlignIfRequired(Ty, CGM.getContext());
+
+  // Create ranges for both dimensions: the columns first, then the rows.
+  llvm::SmallVector<llvm::Metadata *, 2> Subscripts;
+  Subscripts.push_back(DBuilder.getOrCreateSubrange(0, Ty->getNumColumns()));
+  Subscripts.push_back(DBuilder.getOrCreateSubrange(0, Ty->getNumRows()));
+  llvm::DINodeArray SubscriptArray = DBuilder.getOrCreateArray(Subscripts);
+  return DBuilder.createArrayType(Size, Align, ElementTy, SubscriptArray);
+}
+
 llvm::DIType *CGDebugInfo::CreateType(const ArrayType *Ty, llvm::DIFile *Unit) {
   uint64_t Size;
   uint32_t Align;
@@ -3129,6 +3146,8 @@
   case Type::ExtVector:
   case Type::Vector:
     return CreateType(cast<VectorType>(Ty), Unit);
+  case Type::ConstantMatrix:
+    return CreateType(cast<ConstantMatrixType>(Ty), Unit);
   case Type::ObjCObjectPointer:
     return CreateType(cast<ObjCObjectPointerType>(Ty), Unit);
   case Type::ObjCObject:
Index: clang/lib/AST/TypePrinter.cpp
===================================================================
--- clang/lib/AST/TypePrinter.cpp
+++ clang/lib/AST/TypePrinter.cpp
@@ -256,6 +256,8 @@
     case Type::DependentSizedExtVector:
     case Type::Vector:
     case Type::ExtVector:
+    case Type::ConstantMatrix:
+    case Type::DependentSizedMatrix:
     case Type::FunctionProto:
     case Type::FunctionNoProto:
     case Type::Paren:
@@ -720,6 +722,38 @@
   OS << ")))";
 }
 
+void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T,
+                                            raw_ostream &OS) {
+  printBefore(T->getElementType(), OS);
+  OS << " __attribute__((matrix_type(";
+  OS << T->getNumRows() << ", " << T->getNumColumns();
+  OS << ")))";
+}
+
+void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T,
+                                           raw_ostream &OS) {
+  printAfter(T->getElementType(), OS);
+}
+
+void TypePrinter::printDependentSizedMatrixBefore(
+    const DependentSizedMatrixType *T, raw_ostream &OS) {
+  printBefore(T->getElementType(), OS);
+  OS << " __attribute__((matrix_type(";
+  if (T->getRowExpr()) {
+    T->getRowExpr()->printPretty(OS, nullptr, Policy);
+  }
+  OS << ", ";
+  if (T->getColumnExpr()) {
+    T->getColumnExpr()->printPretty(OS, nullptr, Policy);
+  }
+  OS << ")))";
+}
+
+void TypePrinter::printDependentSizedMatrixAfter(
+    const DependentSizedMatrixType *T, raw_ostream &OS) {
+  printAfter(T->getElementType(), OS);
+}
+
 void
 FunctionProtoType::printExceptionSpecification(raw_ostream &OS,
                                                const PrintingPolicy &Policy)
Index: clang/lib/AST/Type.cpp
===================================================================
--- clang/lib/AST/Type.cpp
+++ clang/lib/AST/Type.cpp
@@ -282,6 +282,53 @@
   AddrSpaceExpr->Profile(ID, Context, true);
 }
 
+MatrixType::MatrixType(TypeClass tc, QualType matrixType, QualType canonType,
+                       const Expr *RowExpr, const Expr *ColumnExpr)
+    : Type(tc, canonType,
+           (RowExpr ? (matrixType->getDependence() | TypeDependence::Dependent |
+                       TypeDependence::Instantiation |
+                       (matrixType->isVariablyModifiedType()
+                            ? TypeDependence::VariablyModified
+                            : TypeDependence::None) |
+                       (matrixType->containsUnexpandedParameterPack() ||
+                                (RowExpr &&
+                                 RowExpr->containsUnexpandedParameterPack()) ||
+                                (ColumnExpr &&
+                                 ColumnExpr->containsUnexpandedParameterPack())
+                            ? TypeDependence::UnexpandedPack
+                            : TypeDependence::None))
+                    : matrixType->getDependence())),
+      ElementType(matrixType) {}
+
+ConstantMatrixType::ConstantMatrixType(QualType matrixType, unsigned nRows,
+                                       unsigned nColumns, QualType canonType)
+    : ConstantMatrixType(ConstantMatrix, matrixType, nRows, nColumns,
+                         canonType) {}
+
+ConstantMatrixType::ConstantMatrixType(TypeClass tc, QualType matrixType,
+                                       unsigned nRows, unsigned nColumns,
+                                       QualType canonType)
+    : MatrixType(tc, matrixType, canonType) {
+  ConstantMatrixTypeBits.NumRows = nRows;
+  ConstantMatrixTypeBits.NumColumns = nColumns;
+}
+
+DependentSizedMatrixType::DependentSizedMatrixType(
+    const ASTContext &CTX, QualType ElementType, QualType CanonicalType,
+    Expr *RowExpr, Expr *ColumnExpr, SourceLocation loc)
+    : MatrixType(DependentSizedMatrix, ElementType, CanonicalType, RowExpr,
+                 ColumnExpr),
+      Context(CTX), RowExpr(RowExpr), ColumnExpr(ColumnExpr), loc(loc) {}
+
+void DependentSizedMatrixType::Profile(llvm::FoldingSetNodeID &ID,
+                                       const ASTContext &CTX,
+                                       QualType ElementType, Expr *RowExpr,
+                                       Expr *ColumnExpr) {
+  ID.AddPointer(ElementType.getAsOpaquePtr());
+  RowExpr->Profile(ID, CTX, true);
+  ColumnExpr->Profile(ID, CTX, true);
+}
+
 VectorType::VectorType(QualType vecType, unsigned nElements, QualType canonType,
                        VectorKind vecKind)
     : VectorType(Vector, vecType, nElements, canonType, vecKind) {}
@@ -971,6 +1018,17 @@
     return Ctx.getExtVectorType(elementType, T->getNumElements());
   }
 
+  QualType VisitConstantMatrixType(const ConstantMatrixType *T) {
+    QualType elementType = recurse(T->getElementType());
+    if (elementType.isNull())
+      return {};
+    if (elementType.getAsOpaquePtr() == T->getElementType().getAsOpaquePtr())
+      return QualType(T, 0);
+
+    return Ctx.getConstantMatrixType(elementType, T->getNumRows(),
+                                     T->getNumColumns());
+  }
+
   QualType VisitFunctionNoProtoType(const FunctionNoProtoType *T) {
     QualType returnType = recurse(T->getReturnType());
     if (returnType.isNull())
@@ -1790,6 +1848,14 @@
       return Visit(T->getElementType());
     }
 
+    Type *VisitDependentSizedMatrixType(const DependentSizedMatrixType *T) {
+      return Visit(T->getElementType());
+    }
+
+    Type *VisitConstantMatrixType(const ConstantMatrixType *T) {
+      return Visit(T->getElementType());
+    }
+
     Type *VisitFunctionProtoType(const FunctionProtoType *T) {
       if (Syntactic && T->hasTrailingReturn())
         return const_cast<FunctionProtoType*>(T);
@@ -3744,6 +3810,8 @@
   case Type::Vector:
   case Type::ExtVector:
     return Cache::get(cast<VectorType>(T)->getElementType());
+  case Type::ConstantMatrix:
+    return Cache::get(cast<ConstantMatrixType>(T)->getElementType());
   case Type::FunctionNoProto:
     return Cache::get(cast<FunctionType>(T)->getReturnType());
   case Type::FunctionProto: {
@@ -3830,6 +3898,9 @@
   case Type::Vector:
   case Type::ExtVector:
     return computeTypeLinkageInfo(cast<VectorType>(T)->getElementType());
+  case Type::ConstantMatrix:
+    return computeTypeLinkageInfo(
+        cast<ConstantMatrixType>(T)->getElementType());
   case Type::FunctionNoProto:
     return computeTypeLinkageInfo(cast<FunctionType>(T)->getReturnType());
   case Type::FunctionProto: {
@@ -3993,6 +4064,8 @@
   case Type::DependentSizedExtVector:
   case Type::Vector:
   case Type::ExtVector:
+  case Type::ConstantMatrix:
+  case Type::DependentSizedMatrix:
   case Type::DependentAddressSpace:
   case Type::FunctionProto:
   case Type::FunctionNoProto:
Index: clang/lib/AST/MicrosoftMangle.cpp
===================================================================
--- clang/lib/AST/MicrosoftMangle.cpp
+++ clang/lib/AST/MicrosoftMangle.cpp
@@ -2730,6 +2730,23 @@
     << Range;
 }
 
+void MicrosoftCXXNameMangler::mangleType(const ConstantMatrixType *T,
+                                         Qualifiers quals, SourceRange Range) {
+  DiagnosticsEngine &Diags = Context.getDiags();
+  unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error,
+                                          "Cannot mangle this matrix type yet");
+  Diags.Report(Range.getBegin(), DiagID) << Range;
+}
+
+void MicrosoftCXXNameMangler::mangleType(const DependentSizedMatrixType *T,
+                                         Qualifiers quals, SourceRange Range) {
+  DiagnosticsEngine &Diags = Context.getDiags();
+  unsigned DiagID = Diags.getCustomDiagID(
+      DiagnosticsEngine::Error,
+      "Cannot mangle this dependent-sized matrix type yet");
+  Diags.Report(Range.getBegin(), DiagID) << Range;
+}
+
 void MicrosoftCXXNameMangler::mangleType(const DependentAddressSpaceType *T,
                                          Qualifiers, SourceRange Range) {
   DiagnosticsEngine &Diags = Context.getDiags();
Index: clang/lib/AST/ItaniumMangle.cpp
===================================================================
--- clang/lib/AST/ItaniumMangle.cpp
+++ clang/lib/AST/ItaniumMangle.cpp
@@ -2079,6 +2079,8 @@
   case Type::DependentSizedExtVector:
   case Type::Vector:
   case Type::ExtVector:
+  case Type::ConstantMatrix:
+  case Type::DependentSizedMatrix:
   case Type::FunctionProto:
   case Type::FunctionNoProto:
   case Type::Paren:
@@ -3343,6 +3345,31 @@
   mangleType(T->getElementType());
 }
 
+void CXXNameMangler::mangleType(const ConstantMatrixType *T) {
+  // Mangle matrix types using a vendor extended type qualifier:
+  // U<Len>matrix_type<Rows><Columns><element type>
+  StringRef VendorQualifier = "matrix_type";
+  Out << "U" << VendorQualifier.size() << VendorQualifier;
+  auto &ASTCtx = getASTContext();
+  unsigned BitWidth = ASTCtx.getTypeSize(ASTCtx.getSizeType());
+  llvm::APSInt Rows(BitWidth);
+  Rows = T->getNumRows();
+  mangleIntegerLiteral(ASTCtx.getSizeType(), Rows);
+  llvm::APSInt Columns(BitWidth);
+  Columns = T->getNumColumns();
+  mangleIntegerLiteral(ASTCtx.getSizeType(), Columns);
+  mangleType(T->getElementType());
+}
+
+void CXXNameMangler::mangleType(const DependentSizedMatrixType *T) {
+  // U<Len>matrix_type<row expr><column expr><element type>
+  StringRef VendorQualifier = "matrix_type";
+  Out << "U" << VendorQualifier.size() << VendorQualifier;
+  mangleTemplateArg(T->getRowExpr());
+  mangleTemplateArg(T->getColumnExpr());
+  mangleType(T->getElementType());
+}
+
 void CXXNameMangler::mangleType(const DependentAddressSpaceType *T) {
   SplitQualType split = T->getPointeeType().split();
   mangleQualifiers(split.Quals, T);
Index: clang/lib/AST/ExprConstant.cpp
===================================================================
--- clang/lib/AST/ExprConstant.cpp
+++ clang/lib/AST/ExprConstant.cpp
@@ -10350,6 +10350,7 @@
   case Type::BlockPointer:
   case Type::Vector:
   case Type::ExtVector:
+  case Type::ConstantMatrix:
   case Type::ObjCObject:
   case Type::ObjCInterface:
   case Type::ObjCObjectPointer:
Index: clang/lib/AST/ASTStructuralEquivalence.cpp
===================================================================
--- clang/lib/AST/ASTStructuralEquivalence.cpp
+++ clang/lib/AST/ASTStructuralEquivalence.cpp
@@ -617,6 +617,34 @@
     break;
   }
 
+  case Type::DependentSizedMatrix: {
+    const DependentSizedMatrixType *Mat1 = cast<DependentSizedMatrixType>(T1);
+    const DependentSizedMatrixType *Mat2 = cast<DependentSizedMatrixType>(T2);
+    // The element types, row and column expressions must be structurally
+    // equivalent.
+    if (!IsStructurallyEquivalent(Context, Mat1->getRowExpr(),
+                                  Mat2->getRowExpr()) ||
+        !IsStructurallyEquivalent(Context, Mat1->getColumnExpr(),
+                                  Mat2->getColumnExpr()) ||
+        !IsStructurallyEquivalent(Context, Mat1->getElementType(),
+                                  Mat2->getElementType()))
+      return false;
+    break;
+  }
+
+  case Type::ConstantMatrix: {
+    const ConstantMatrixType *Mat1 = cast<ConstantMatrixType>(T1);
+    const ConstantMatrixType *Mat2 = cast<ConstantMatrixType>(T2);
+    // The element types must be structurally equivalent and the number of rows
+    // and columns must match.
+    if (!IsStructurallyEquivalent(Context, Mat1->getElementType(),
+                                  Mat2->getElementType()) ||
+        Mat1->getNumRows() != Mat2->getNumRows() ||
+        Mat1->getNumColumns() != Mat2->getNumColumns())
+      return false;
+    break;
+  }
+
   case Type::FunctionProto: {
     const auto *Proto1 = cast<FunctionProtoType>(T1);
     const auto *Proto2 = cast<FunctionProtoType>(T2);
Index: clang/lib/AST/ASTContext.cpp
===================================================================
--- clang/lib/AST/ASTContext.cpp
+++ clang/lib/AST/ASTContext.cpp
@@ -1932,6 +1932,17 @@
     break;
   }
 
+  case Type::ConstantMatrix: {
+    const auto *MT = cast<ConstantMatrixType>(T);
+    TypeInfo ElementInfo = getTypeInfo(MT->getElementType());
+    // The internal layout of a matrix value is implementation defined.
+    // Initially be ABI compatible with arrays with respect to alignment and
+    // size.
+    Width = ElementInfo.Width * MT->getNumRows() * MT->getNumColumns();
+    Align = ElementInfo.Align;
+    break;
+  }
+
   case Type::Builtin:
     switch (cast<BuiltinType>(T)->getKind()) {
     default: llvm_unreachable("Unknown builtin type!");
@@ -3362,6 +3373,8 @@
   case Type::DependentVector:
   case Type::ExtVector:
   case Type::DependentSizedExtVector:
+  case Type::ConstantMatrix:
+  case Type::DependentSizedMatrix:
   case Type::DependentAddressSpace:
   case Type::ObjCObject:
   case Type::ObjCInterface:
@@ -3775,6 +3788,78 @@
   return QualType(New, 0);
 }
 
+QualType ASTContext::getConstantMatrixType(QualType ElementTy, unsigned NumRows,
+                                           unsigned NumColumns) const {
+  llvm::FoldingSetNodeID ID;
+  ConstantMatrixType::Profile(ID, ElementTy, NumRows, NumColumns,
+                              Type::ConstantMatrix);
+
+  assert(MatrixType::isValidElementType(ElementTy) &&
+         "need a valid element type");
+  assert(ConstantMatrixType::isDimensionValid(NumRows) &&
+         ConstantMatrixType::isDimensionValid(NumColumns) &&
+         "need valid matrix dimensions");
+  void *InsertPos = nullptr;
+  if (ConstantMatrixType *MTP = MatrixTypes.FindNodeOrInsertPos(ID, InsertPos))
+    return QualType(MTP, 0);
+
+  QualType Canonical;
+  if (!ElementTy.isCanonical()) {
+    Canonical =
+        getConstantMatrixType(getCanonicalType(ElementTy), NumRows, NumColumns);
+
+    ConstantMatrixType *NewIP = MatrixTypes.FindNodeOrInsertPos(ID, InsertPos);
+    assert(!NewIP && "Matrix type shouldn't already exist in the map");
+    (void)NewIP;
+  }
+
+  auto *New = new (*this, TypeAlignment)
+      ConstantMatrixType(ElementTy, NumRows, NumColumns, Canonical);
+  MatrixTypes.InsertNode(New, InsertPos);
+  Types.push_back(New);
+  return QualType(New, 0);
+}
+
+QualType ASTContext::getDependentSizedMatrixType(QualType ElementTy,
+                                                 Expr *RowExpr,
+                                                 Expr *ColumnExpr,
+                                                 SourceLocation AttrLoc) const {
+  QualType CanonElementTy = getCanonicalType(ElementTy);
+  llvm::FoldingSetNodeID ID;
+  DependentSizedMatrixType::Profile(ID, *this, CanonElementTy, RowExpr,
+                                    ColumnExpr);
+
+  void *InsertPos = nullptr;
+  DependentSizedMatrixType *Canon =
+      DependentSizedMatrixTypes.FindNodeOrInsertPos(ID, InsertPos);
+
+  if (!Canon) {
+    Canon = new (*this, TypeAlignment) DependentSizedMatrixType(
+        *this, CanonElementTy, QualType(), RowExpr, ColumnExpr, AttrLoc);
+#ifndef NDEBUG
+    DependentSizedMatrixType *CanonCheck =
+        DependentSizedMatrixTypes.FindNodeOrInsertPos(ID, InsertPos);
+    assert(!CanonCheck && "Dependent-sized matrix canonical type broken");
+#endif
+    DependentSizedMatrixTypes.InsertNode(Canon, InsertPos);
+    Types.push_back(Canon);
+  }
+
+  // At this point a canonical version of the matrix type exists. If it exactly
+  // matches the requested type (same element type, row expression and column
+  // expression), use it directly.
+  if (Canon->getElementType() == ElementTy && Canon->getRowExpr() == RowExpr &&
+      Canon->getColumnExpr() == ColumnExpr)
+    return QualType(Canon, 0);
+
+  // Use Canon as the canonical type for newly-built type.
+  DependentSizedMatrixType *New = new (*this, TypeAlignment)
+      DependentSizedMatrixType(*this, ElementTy, QualType(Canon, 0), RowExpr,
+                               ColumnExpr, AttrLoc);
+  Types.push_back(New);
+  return QualType(New, 0);
+}
+
 QualType ASTContext::getDependentAddressSpaceType(QualType PointeeType,
                                                   Expr *AddrSpaceExpr,
                                                   SourceLocation AttrLoc) const {
@@ -7338,6 +7423,11 @@
       *NotEncodedT = T;
     return;
 
+  case Type::ConstantMatrix:
+    if (NotEncodedT)
+      *NotEncodedT = T;
+    return;
+
   // We could see an undeduced auto type here during error recovery.
   // Just ignore it.
   case Type::Auto:
@@ -8217,6 +8307,16 @@
          LHS->getNumElements() == RHS->getNumElements();
 }
 
+/// areCompatMatrixTypes - Return true if the two specified matrix types are
+/// compatible.
+static bool areCompatMatrixTypes(const ConstantMatrixType *LHS,
+                                 const ConstantMatrixType *RHS) {
+  assert(LHS->isCanonicalUnqualified() && RHS->isCanonicalUnqualified());
+  return LHS->getElementType() == RHS->getElementType() &&
+         LHS->getNumRows() == RHS->getNumRows() &&
+         LHS->getNumColumns() == RHS->getNumColumns();
+}
+
 bool ASTContext::areCompatibleVectorTypes(QualType FirstVec,
                                           QualType SecondVec) {
   assert(FirstVec->isVectorType() && "FirstVec should be a vector type");
@@ -9414,6 +9514,11 @@
                              RHSCan->castAs<VectorType>()))
       return LHS;
     return {};
+  case Type::ConstantMatrix:
+    if (areCompatMatrixTypes(LHSCan->castAs<ConstantMatrixType>(),
+                             RHSCan->castAs<ConstantMatrixType>()))
+      return LHS;
+    return {};
   case Type::ObjCObject: {
     // Check if the types are assignment compatible.
     // FIXME: This should be type compatibility, e.g. whether
Index: clang/include/clang/Serialization/TypeBitCodes.def
===================================================================
--- clang/include/clang/Serialization/TypeBitCodes.def
+++ clang/include/clang/Serialization/TypeBitCodes.def
@@ -60,5 +60,7 @@
 TYPE_BIT_CODE(MacroQualified, MACRO_QUALIFIED, 49)
 TYPE_BIT_CODE(ExtInt, EXT_INT, 50)
 TYPE_BIT_CODE(DependentExtInt, DEPENDENT_EXT_INT, 51)
+TYPE_BIT_CODE(ConstantMatrix, CONSTANT_MATRIX, 52)
+TYPE_BIT_CODE(DependentSizedMatrix, DEPENDENT_SIZE_MATRIX, 53)
 
 #undef TYPE_BIT_CODE
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -1627,6 +1627,9 @@
   QualType BuildVectorType(QualType T, Expr *VecSize, SourceLocation AttrLoc);
   QualType BuildExtVectorType(QualType T, Expr *ArraySize,
                               SourceLocation AttrLoc);
+  QualType BuildMatrixType(QualType T, Expr *NumRows, Expr *NumColumns,
+                           SourceLocation AttrLoc);
+
   QualType BuildAddressSpaceAttr(QualType &T, LangAS ASIdx, Expr *AddrSpace,
                                  SourceLocation AttrLoc);
 
Index: clang/include/clang/Driver/Options.td
===================================================================
--- clang/include/clang/Driver/Options.td
+++ clang/include/clang/Driver/Options.td
@@ -2007,6 +2007,10 @@
 def fno_strict_return : Flag<["-"], "fno-strict-return">, Group<f_Group>,
   Flags<[CC1Option]>;
 
+def fenable_matrix : Flag<["-"], "fenable-matrix">, Group<f_Group>,
+    Flags<[CC1Option]>,
+    HelpText<"Enable matrix data type and related builtin functions">;
+
 def fallow_editor_placeholders : Flag<["-"], "fallow-editor-placeholders">,
   Group<f_Group>, Flags<[CC1Option]>,
   HelpText<"Treat editor placeholders as valid source code">;
Index: clang/include/clang/Basic/TypeNodes.td
===================================================================
--- clang/include/clang/Basic/TypeNodes.td
+++ clang/include/clang/Basic/TypeNodes.td
@@ -69,6 +69,9 @@
 def VectorType : TypeNode<Type>;
 def DependentVectorType : TypeNode<Type>, AlwaysDependent;
 def ExtVectorType : TypeNode<VectorType>;
+def MatrixType : TypeNode<Type, 1>;
+def ConstantMatrixType : TypeNode<MatrixType>;
+def DependentSizedMatrixType : TypeNode<MatrixType>, AlwaysDependent;
 def FunctionType : TypeNode<Type, 1>;
 def FunctionProtoType : TypeNode<FunctionType>;
 def FunctionNoProtoType : TypeNode<FunctionType>;
Index: clang/include/clang/Basic/LangOptions.def
===================================================================
--- clang/include/clang/Basic/LangOptions.def
+++ clang/include/clang/Basic/LangOptions.def
@@ -357,6 +357,8 @@
 
 LANGOPT(RegisterStaticDestructors, 1, 1, "Register C++ static destructors")
 
+LANGOPT(MatrixTypes, 1, 0, "Enable or disable the builtin matrix type")
+
 COMPATIBLE_VALUE_LANGOPT(MaxTokens, 32, 0, "Max number of tokens per TU or 0")
 
 ENUM_LANGOPT(SignReturnAddressScope, SignReturnAddressScopeKind, 2, SignReturnAddressScopeKind::None,
Index: clang/include/clang/Basic/Features.def
===================================================================
--- clang/include/clang/Basic/Features.def
+++ clang/include/clang/Basic/Features.def
@@ -253,6 +253,7 @@
 EXTENSION(pragma_clang_attribute_external_declaration, true)
 EXTENSION(gnu_asm, LangOpts.GNUAsm)
 EXTENSION(gnu_asm_goto_with_outputs, LangOpts.GNUAsm)
+EXTENSION(matrix_types, LangOpts.MatrixTypes)
 
 #undef EXTENSION
 #undef FEATURE
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -2776,6 +2776,7 @@
 def err_attribute_too_few_arguments : Error<
   "%0 attribute takes at least %1 argument%s1">;
 def err_attribute_invalid_vector_type : Error<"invalid vector element type %0">;
+def err_attribute_invalid_matrix_type : Error<"invalid matrix element type %0">;
 def err_attribute_bad_neon_vector_size : Error<
   "Neon vector size must be 64 or 128 bits">;
 def err_attribute_requires_positive_integer : Error<
@@ -2879,8 +2880,8 @@
   "init methods must return an object pointer type, not %0">;
 def err_attribute_invalid_size : Error<
   "vector size not an integral multiple of component size">;
-def err_attribute_zero_size : Error<"zero vector size">;
-def err_attribute_size_too_large : Error<"vector size too large">;
+def err_attribute_zero_size : Error<"zero %0 size">;
+def err_attribute_size_too_large : Error<"%0 size too large">;
 def err_typecheck_vector_not_convertable_implict_truncation : Error<
    "cannot convert between %select{scalar|vector}0 type %1 and vector type"
    " %2 as implicit conversion would cause truncation">;
@@ -10727,6 +10728,9 @@
   "%select{non-pointer|function pointer|void pointer}0 argument to "
   "'__builtin_launder' is not allowed">;
 
+def err_builtin_matrix_disabled: Error<
+  "matrix types extension is disabled. Pass -fenable-matrix to enable it">;
+
 def err_preserve_field_info_not_field : Error<
   "__builtin_preserve_field_info argument %0 not a field access">;
 def err_preserve_field_info_not_const: Error<
Index: clang/include/clang/Basic/Attr.td
===================================================================
--- clang/include/clang/Basic/Attr.td
+++ clang/include/clang/Basic/Attr.td
@@ -2460,6 +2460,15 @@
   let Documentation = [Undocumented];
 }
 
+def MatrixType : TypeAttr {
+  let Spellings = [Clang<"matrix_type">];
+  let Subjects = SubjectList<[TypedefName], ErrorDiag>;
+  let Args = [ExprArgument<"NumRows">, ExprArgument<"NumColumns">];
+  let Documentation = [Undocumented];
+  let ASTNode = 0;
+  let PragmaAttributeSupport = 0;
+}
+
 def Visibility : InheritableAttr {
   let Clone = 0;
   let Spellings = [GCC<"visibility">];
Index: clang/include/clang/AST/TypeProperties.td
===================================================================
--- clang/include/clang/AST/TypeProperties.td
+++ clang/include/clang/AST/TypeProperties.td
@@ -224,6 +224,41 @@
   }]>;
 }
 
+let Class = MatrixType in {
+  def : Property<"elementType", QualType> {
+    let Read = [{ node->getElementType() }];
+  }
+}
+
+let Class = ConstantMatrixType in {
+  def : Property<"numRows", UInt32> {
+    let Read = [{ node->getNumRows() }];
+  }
+  def : Property<"numColumns", UInt32> {
+    let Read = [{ node->getNumColumns() }];
+  }
+
+  def : Creator<[{
+    return ctx.getConstantMatrixType(elementType, numRows, numColumns);
+  }]>;
+}
+
+let Class = DependentSizedMatrixType in {
+  def : Property<"rows", ExprRef> {
+    let Read = [{ node->getRowExpr() }];
+  }
+  def : Property<"columns", ExprRef> {
+    let Read = [{ node->getColumnExpr() }];
+  }
+  def : Property<"attributeLoc", SourceLocation> {
+    let Read = [{ node->getAttributeLoc() }];
+  }
+
+  def : Creator<[{
+    return ctx.getDependentSizedMatrixType(elementType, rows, columns, attributeLoc);
+  }]>;
+}
+
 let Class = FunctionType in {
   def : Property<"returnType", QualType> {
     let Read = [{ node->getReturnType() }];
Index: clang/include/clang/AST/TypeLoc.h
===================================================================
--- clang/include/clang/AST/TypeLoc.h
+++ clang/include/clang/AST/TypeLoc.h
@@ -1735,6 +1735,7 @@
 
   void initializeLocal(ASTContext &Context, SourceLocation loc) {
     setAttrNameLoc(loc);
+    setAttrOperandParensRange(loc);
     setAttrOperandParensRange(SourceRange(loc));
     setAttrExprOperand(getTypePtr()->getAddrSpaceExpr());
   }
@@ -1774,6 +1775,68 @@
                                      DependentSizedExtVectorType> {
 };
 
+struct MatrixTypeLocInfo {
+  SourceLocation AttrLoc;
+  SourceRange OperandParens;
+  Expr *RowOperand;
+  Expr *ColumnOperand;
+};
+
+class MatrixTypeLoc : public ConcreteTypeLoc<UnqualTypeLoc, MatrixTypeLoc,
+                                             MatrixType, MatrixTypeLocInfo> {
+public:
+  /// The location of the attribute name, i.e.
+  ///    float __attribute__((matrix_type(4, 2)))
+  ///                         ^~~~~~~~~~~~~~~~~
+  SourceLocation getAttrNameLoc() const { return getLocalData()->AttrLoc; }
+  void setAttrNameLoc(SourceLocation loc) { getLocalData()->AttrLoc = loc; }
+
+  /// The attribute's row operand, if it has one.
+  ///    float __attribute__((matrix_type(4, 2)))
+  ///                                     ^
+  Expr *getAttrRowOperand() const { return getLocalData()->RowOperand; }
+  void setAttrRowOperand(Expr *e) { getLocalData()->RowOperand = e; }
+
+  /// The attribute's column operand, if it has one.
+  ///    float __attribute__((matrix_type(4, 2)))
+  ///                                        ^
+  Expr *getAttrColumnOperand() const { return getLocalData()->ColumnOperand; }
+  void setAttrColumnOperand(Expr *e) { getLocalData()->ColumnOperand = e; }
+
+  /// The location of the parentheses around the operand, if there is
+  /// an operand.
+  ///    float __attribute__((matrix_type(4, 2)))
+  ///                                    ^    ^
+  SourceRange getAttrOperandParensRange() const {
+    return getLocalData()->OperandParens;
+  }
+  void setAttrOperandParensRange(SourceRange range) {
+    getLocalData()->OperandParens = range;
+  }
+
+  SourceRange getLocalSourceRange() const {
+    SourceRange range(getAttrNameLoc());
+    range.setEnd(getAttrOperandParensRange().getEnd());
+    return range;
+  }
+
+  void initializeLocal(ASTContext &Context, SourceLocation loc) {
+    setAttrNameLoc(loc);
+    setAttrOperandParensRange(loc);
+    setAttrRowOperand(nullptr);
+    setAttrColumnOperand(nullptr);
+  }
+};
+
+class ConstantMatrixTypeLoc
+    : public InheritingConcreteTypeLoc<MatrixTypeLoc, ConstantMatrixTypeLoc,
+                                       ConstantMatrixType> {};
+
+class DependentSizedMatrixTypeLoc
+    : public InheritingConcreteTypeLoc<MatrixTypeLoc,
+                                       DependentSizedMatrixTypeLoc,
+                                       DependentSizedMatrixType> {};
+
 // FIXME: location of the '_Complex' keyword.
 class ComplexTypeLoc : public InheritingConcreteTypeLoc<TypeSpecTypeLoc,
                                                         ComplexTypeLoc,
Index: clang/include/clang/AST/Type.h
===================================================================
--- clang/include/clang/AST/Type.h
+++ clang/include/clang/AST/Type.h
@@ -1654,6 +1654,19 @@
     uint32_t NumElements;
   };
 
+  class ConstantMatrixTypeBitfields {
+    friend class ConstantMatrixType;
+
+    unsigned : NumTypeBits;
+
+    /// Number of rows and columns. Using 20 bits allows supporting very large
+    /// matrixes, while keeping 24 bits to accommodate NumTypeBits.
+    unsigned NumRows : 20;
+    unsigned NumColumns : 20;
+
+    static constexpr uint32_t MaxElementsPerDimension = (1 << 20) - 1;
+  };
+
   class AttributedTypeBitfields {
     friend class AttributedType;
 
@@ -1763,6 +1776,7 @@
     TypeWithKeywordBitfields TypeWithKeywordBits;
     ElaboratedTypeBitfields ElaboratedTypeBits;
     VectorTypeBitfields VectorTypeBits;
+    ConstantMatrixTypeBitfields ConstantMatrixTypeBits;
     SubstTemplateTypeParmPackTypeBitfields SubstTemplateTypeParmPackTypeBits;
     TemplateSpecializationTypeBitfields TemplateSpecializationTypeBits;
     DependentTemplateSpecializationTypeBitfields
@@ -2021,6 +2035,7 @@
   bool isComplexIntegerType() const;            // GCC _Complex integer type.
   bool isVectorType() const;                    // GCC vector type.
   bool isExtVectorType() const;                 // Extended vector type.
+  bool isConstantMatrixType() const;            // Matrix type.
   bool isDependentAddressSpaceType() const;     // value-dependent address space qualifier
   bool isObjCObjectPointerType() const;         // pointer to ObjC object
   bool isObjCRetainableType() const;            // ObjC object or block pointer
@@ -3390,6 +3405,131 @@
   }
 };
 
+/// Represents a matrix type, as defined in the Matrix Types clang extensions.
+/// __attribute__((matrix_type(rows, columns))), where "rows" specifies the
+/// number of rows and "columns" specifies the number of columns.
+class MatrixType : public Type, public llvm::FoldingSetNode {
+protected:
+  friend class ASTContext;
+
+  /// The element type of the matrix.
+  QualType ElementType;
+
+  MatrixType(QualType ElementTy, QualType CanonElementTy);
+
+  MatrixType(TypeClass TypeClass, QualType ElementTy, QualType CanonElementTy,
+             const Expr *RowExpr = nullptr, const Expr *ColumnExpr = nullptr);
+
+public:
+  /// Returns the type of the elements stored in the matrix.
+  QualType getElementType() const { return ElementType; }
+
+  /// Valid element types are the following:
+  /// * an integer type (as in C2x 6.2.5p19), but excluding enumerated types
+  ///   and _Bool
+  /// * the standard floating types float or double
+  /// * a half-precision floating point type, if one is supported on the target
+  static bool isValidElementType(QualType T) {
+    return T->isDependentType() ||
+           (T->isRealType() && !T->isBooleanType() && !T->isEnumeralType());
+  }
+
+  bool isSugared() const { return false; }
+  QualType desugar() const { return QualType(this, 0); }
+
+  static bool classof(const Type *T) {
+    return T->getTypeClass() == ConstantMatrix ||
+           T->getTypeClass() == DependentSizedMatrix;
+  }
+};
+
+/// Represents a concrete matrix type with a constant number of rows and columns
+class ConstantMatrixType final : public MatrixType {
+protected:
+  friend class ASTContext;
+
+  /// The element type of the matrix.
+  QualType ElementType;
+
+  ConstantMatrixType(QualType MatrixElementType, unsigned NRows,
+                     unsigned NColumns, QualType CanonElementType);
+
+  ConstantMatrixType(TypeClass typeClass, QualType MatrixType, unsigned NRows,
+                     unsigned NColumns, QualType CanonElementType);
+
+public:
+  /// Returns the number of rows in the matrix.
+  unsigned getNumRows() const { return ConstantMatrixTypeBits.NumRows; }
+
+  /// Returns the number of columns in the matrix.
+  unsigned getNumColumns() const { return ConstantMatrixTypeBits.NumColumns; }
+
+  /// Returns the number of elements required to embed the matrix into a vector.
+  unsigned getNumElementsFlattened() const {
+    return ConstantMatrixTypeBits.NumRows * ConstantMatrixTypeBits.NumColumns;
+  }
+
+  /// Returns true if \p NumElements is a valid matrix dimension.
+  static bool isDimensionValid(uint64_t NumElements) {
+    return NumElements > 0 &&
+           NumElements <= ConstantMatrixTypeBitfields::MaxElementsPerDimension;
+  }
+
+  void Profile(llvm::FoldingSetNodeID &ID) {
+    Profile(ID, getElementType(), getNumRows(), getNumColumns(),
+            getTypeClass());
+  }
+
+  static void Profile(llvm::FoldingSetNodeID &ID, QualType ElementType,
+                      unsigned NumRows, unsigned NumColumns,
+                      TypeClass TypeClass) {
+    ID.AddPointer(ElementType.getAsOpaquePtr());
+    ID.AddInteger(NumRows);
+    ID.AddInteger(NumColumns);
+    ID.AddInteger(TypeClass);
+  }
+
+  static bool classof(const Type *T) {
+    return T->getTypeClass() == ConstantMatrix;
+  }
+};
+
+/// Represents a matrix type where the type and the number of rows and columns
+/// are dependent on a template.
+class DependentSizedMatrixType final : public MatrixType {
+  friend class ASTContext;
+
+  const ASTContext &Context;
+  Expr *RowExpr;
+  Expr *ColumnExpr;
+
+  SourceLocation loc;
+
+  DependentSizedMatrixType(const ASTContext &Context, QualType ElementType,
+                           QualType CanonicalType, Expr *RowExpr,
+                           Expr *ColumnExpr, SourceLocation loc);
+
+public:
+  QualType getElementType() const { return ElementType; }
+  Expr *getRowExpr() const { return RowExpr; }
+  Expr *getColumnExpr() const { return ColumnExpr; }
+  SourceLocation getAttributeLoc() const { return loc; }
+
+  bool isSugared() const { return false; }
+  QualType desugar() const { return QualType(this, 0); }
+
+  static bool classof(const Type *T) {
+    return T->getTypeClass() == DependentSizedMatrix;
+  }
+
+  void Profile(llvm::FoldingSetNodeID &ID) {
+    Profile(ID, Context, getElementType(), getRowExpr(), getColumnExpr());
+  }
+
+  static void Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context,
+                      QualType ElementType, Expr *RowExpr, Expr *ColumnExpr);
+};
+
 /// FunctionType - C99 6.7.5.3 - Function Declarators.  This is the common base
 /// class of FunctionNoProtoType and FunctionProtoType.
 class FunctionType : public Type {
@@ -6605,6 +6745,10 @@
   return isa<ExtVectorType>(CanonicalType);
 }
 
+inline bool Type::isConstantMatrixType() const {
+  return isa<ConstantMatrixType>(CanonicalType);
+}
+
 inline bool Type::isDependentAddressSpaceType() const {
   return isa<DependentAddressSpaceType>(CanonicalType);
 }
Index: clang/include/clang/AST/RecursiveASTVisitor.h
===================================================================
--- clang/include/clang/AST/RecursiveASTVisitor.h
+++ clang/include/clang/AST/RecursiveASTVisitor.h
@@ -1006,6 +1006,17 @@
 
 DEF_TRAVERSE_TYPE(ExtVectorType, { TRY_TO(TraverseType(T->getElementType())); })
 
+DEF_TRAVERSE_TYPE(ConstantMatrixType,
+                  { TRY_TO(TraverseType(T->getElementType())); })
+
+DEF_TRAVERSE_TYPE(DependentSizedMatrixType, {
+  if (T->getRowExpr())
+    TRY_TO(TraverseStmt(T->getRowExpr()));
+  if (T->getColumnExpr())
+    TRY_TO(TraverseStmt(T->getColumnExpr()));
+  TRY_TO(TraverseType(T->getElementType()));
+})
+
 DEF_TRAVERSE_TYPE(FunctionNoProtoType,
                   { TRY_TO(TraverseType(T->getReturnType())); })
 
@@ -1258,6 +1269,18 @@
   TRY_TO(TraverseType(TL.getTypePtr()->getElementType()));
 })
 
+DEF_TRAVERSE_TYPELOC(ConstantMatrixType, {
+  TRY_TO(TraverseStmt(TL.getAttrRowOperand()));
+  TRY_TO(TraverseStmt(TL.getAttrColumnOperand()));
+  TRY_TO(TraverseType(TL.getTypePtr()->getElementType()));
+})
+
+DEF_TRAVERSE_TYPELOC(DependentSizedMatrixType, {
+  TRY_TO(TraverseStmt(TL.getAttrRowOperand()));
+  TRY_TO(TraverseStmt(TL.getAttrColumnOperand()));
+  TRY_TO(TraverseType(TL.getTypePtr()->getElementType()));
+})
+
 DEF_TRAVERSE_TYPELOC(FunctionNoProtoType,
                      { TRY_TO(TraverseTypeLoc(TL.getReturnLoc())); })
 
Index: clang/include/clang/AST/ASTContext.h
===================================================================
--- clang/include/clang/AST/ASTContext.h
+++ clang/include/clang/AST/ASTContext.h
@@ -194,6 +194,8 @@
       DependentAddressSpaceTypes;
   mutable llvm::FoldingSet<VectorType> VectorTypes;
   mutable llvm::FoldingSet<DependentVectorType> DependentVectorTypes;
+  mutable llvm::FoldingSet<ConstantMatrixType> MatrixTypes;
+  mutable llvm::FoldingSet<DependentSizedMatrixType> DependentSizedMatrixTypes;
   mutable llvm::FoldingSet<FunctionNoProtoType> FunctionNoProtoTypes;
   mutable llvm::ContextualFoldingSet<FunctionProtoType, ASTContext&>
     FunctionProtoTypes;
@@ -1326,6 +1328,20 @@
                                           Expr *SizeExpr,
                                           SourceLocation AttrLoc) const;
 
+  /// Return the unique reference to the matrix type of the specified element
+  /// type and dimensions (\p NumRows x \p NumColumns).
+  ///
+  /// \pre \p ElementType must be a valid matrix element type (see
+  /// MatrixType::isValidElementType).
+  QualType getConstantMatrixType(QualType ElementType, unsigned NumRows,
+                                 unsigned NumColumns) const;
+
+  /// Return the unique reference to the dependent-sized matrix type with the
+  /// specified element type and row/column dimension expressions.
+  QualType getDependentSizedMatrixType(QualType ElementType, Expr *RowExpr,
+                                       Expr *ColumnExpr,
+                                       SourceLocation AttrLoc) const;
+
   QualType getDependentAddressSpaceType(QualType PointeeType,
                                         Expr *AddrSpaceExpr,
                                         SourceLocation AttrLoc) const;
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to