fhahn updated this revision to Diff 266303.
fhahn marked an inline comment as done.
fhahn added a comment.

Add support for user-defined conversion functions, use PrepareScalarCast, and add
overloads for the matrix +/- operators.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D76793/new/

https://reviews.llvm.org/D76793

Files:
  clang/include/clang/Sema/Sema.h
  clang/lib/CodeGen/CGExprScalar.cpp
  clang/lib/Sema/SemaExpr.cpp
  clang/lib/Sema/SemaOverload.cpp
  clang/test/CodeGen/matrix-type-operators.c
  clang/test/CodeGenCXX/matrix-type-operators.cpp
  clang/test/Sema/matrix-type-operators.c
  clang/test/SemaCXX/matrix-type-operators.cpp
  llvm/include/llvm/IR/MatrixBuilder.h

Index: llvm/include/llvm/IR/MatrixBuilder.h
===================================================================
--- llvm/include/llvm/IR/MatrixBuilder.h
+++ llvm/include/llvm/IR/MatrixBuilder.h
@@ -127,6 +127,16 @@
   /// Add matrixes \p LHS and \p RHS. Support both integer and floating point
   /// matrixes.
   Value *CreateAdd(Value *LHS, Value *RHS) {
+    assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
+    if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy())
+      RHS = B.CreateVectorSplat(
+          cast<VectorType>(LHS->getType())->getNumElements(), RHS,
+          "scalar.splat");
+    else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy())
+      LHS = B.CreateVectorSplat(
+          cast<VectorType>(RHS->getType())->getNumElements(), LHS,
+          "scalar.splat");
+
     return cast<VectorType>(LHS->getType())
                    ->getElementType()
                    ->isFloatingPointTy()
@@ -137,6 +147,16 @@
   /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating
   /// point matrixes.
   Value *CreateSub(Value *LHS, Value *RHS) {
+    assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
+    if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy())
+      RHS = B.CreateVectorSplat(
+          cast<VectorType>(LHS->getType())->getNumElements(), RHS,
+          "scalar.splat");
+    else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy())
+      LHS = B.CreateVectorSplat(
+          cast<VectorType>(RHS->getType())->getNumElements(), LHS,
+          "scalar.splat");
+
     return cast<VectorType>(LHS->getType())
                    ->getElementType()
                    ->isFloatingPointTy()
Index: clang/test/SemaCXX/matrix-type-operators.cpp
===================================================================
--- clang/test/SemaCXX/matrix-type-operators.cpp
+++ clang/test/SemaCXX/matrix-type-operators.cpp
@@ -84,3 +84,96 @@
   a[2] = f;
   // expected-error@-1 {{single subscript expressions are not allowed for matrix values}}
 }
+
+template <typename EltTy, unsigned Rows, unsigned Columns>
+struct MyMatrix {
+  using matrix_t = EltTy __attribute__((matrix_type(Rows, Columns)));
+
+  matrix_t value;
+};
+
+template <typename EltTy0, unsigned R0, unsigned C0, typename EltTy1, unsigned R1, unsigned C1, typename EltTy2, unsigned R2, unsigned C2>
+typename MyMatrix<EltTy2, R2, C2>::matrix_t add(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy1, R1, C1> &B) {
+  char *v1 = A.value + B.value;
+  // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))'))}}
+  // expected-error@-3 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}}
+
+  return A.value + B.value;
+  // expected-error@-1 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))'))}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}}
+}
+
+void test_add_template(unsigned *Ptr1, float *Ptr2) {
+  MyMatrix<unsigned, 2, 2> Mat1;
+  MyMatrix<unsigned, 3, 3> Mat2;
+  MyMatrix<float, 2, 2> Mat3;
+  Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1);
+  unsigned v1 = add<unsigned, 2, 2, unsigned, 2, 2, unsigned, 2, 2>(Mat1, Mat1);
+  // expected-error@-1 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix<unsigned int, 2U, 2U>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}}
+  // expected-note@-2 {{in instantiation of function template specialization 'add<unsigned int, 2, 2, unsigned int, 2, 2, unsigned int, 2, 2>' requested here}}
+
+  Mat1.value = add<unsigned, 2, 2, unsigned, 3, 3, unsigned, 2, 2>(Mat1, Mat2);
+  // expected-note@-1 {{in instantiation of function template specialization 'add<unsigned int, 2, 2, unsigned int, 3, 3, unsigned int, 2, 2>' requested here}}
+
+  Mat1.value = add<unsigned, 3, 3, float, 2, 2, unsigned, 2, 2>(Mat2, Mat3);
+  // expected-note@-1 {{in instantiation of function template specialization 'add<unsigned int, 3, 3, float, 2, 2, unsigned int, 2, 2>' requested here}}
+}
+
+template <typename EltTy0, unsigned R0, unsigned C0, typename EltTy1, unsigned R1, unsigned C1, typename EltTy2, unsigned R2, unsigned C2>
+typename MyMatrix<EltTy2, R2, C2>::matrix_t subtract(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy1, R1, C1> &B) {
+  char *v1 = A.value - B.value;
+  // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))'))}}
+  // expected-error@-3 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}}
+
+  return A.value - B.value;
+  // expected-error@-1 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))'))}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}}
+}
+
+void test_subtract_template(unsigned *Ptr1, float *Ptr2) {
+  MyMatrix<unsigned, 2, 2> Mat1;
+  MyMatrix<unsigned, 3, 3> Mat2;
+  MyMatrix<float, 2, 2> Mat3;
+  Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1);
+  unsigned v1 = subtract<unsigned, 2, 2, unsigned, 2, 2, unsigned, 2, 2>(Mat1, Mat1);
+  // expected-error@-1 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix<unsigned int, 2U, 2U>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}}
+  // expected-note@-2 {{in instantiation of function template specialization 'subtract<unsigned int, 2, 2, unsigned int, 2, 2, unsigned int, 2, 2>' requested here}}
+
+  Mat1.value = subtract<unsigned, 2, 2, unsigned, 3, 3, unsigned, 2, 2>(Mat1, Mat2);
+  // expected-note@-1 {{in instantiation of function template specialization 'subtract<unsigned int, 2, 2, unsigned int, 3, 3, unsigned int, 2, 2>' requested here}}
+
+  Mat1.value = subtract<unsigned, 3, 3, float, 2, 2, unsigned, 2, 2>(Mat2, Mat3);
+  // expected-note@-1 {{in instantiation of function template specialization 'subtract<unsigned int, 3, 3, float, 2, 2, unsigned int, 2, 2>' requested here}}
+}
+
+struct UserT {};
+
+struct StructWithC {
+  operator UserT() {
+    // expected-note@-1 {{candidate function}}
+    // expected-note@-2 {{candidate function}}
+    // expected-note@-3 {{candidate function}}
+    // expected-note@-4 {{candidate function}}
+    return {};
+  }
+};
+
+void test_DoubleWrapper(MyMatrix<double, 10, 9> &m, StructWithC &c) {
+  m.value = m.value + c;
+  // expected-error@-1 {{no viable conversion from 'StructWithC' to 'double'}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<double, 10, 9>::matrix_t' (aka 'double __attribute__((matrix_type(10, 9)))') and 'StructWithC')}}
+
+  m.value = c + m.value;
+  // expected-error@-1 {{no viable conversion from 'StructWithC' to 'double'}}
+  // expected-error@-2 {{invalid operands to binary expression ('StructWithC' and 'MyMatrix<double, 10, 9>::matrix_t' (aka 'double __attribute__((matrix_type(10, 9)))'))}}
+
+  m.value = m.value - c;
+  // expected-error@-1 {{no viable conversion from 'StructWithC' to 'double'}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<double, 10, 9>::matrix_t' (aka 'double __attribute__((matrix_type(10, 9)))') and 'StructWithC')}}
+
+  m.value = c - m.value;
+  // expected-error@-1 {{no viable conversion from 'StructWithC' to 'double'}}
+  // expected-error@-2 {{invalid operands to binary expression ('StructWithC' and 'MyMatrix<double, 10, 9>::matrix_t' (aka 'double __attribute__((matrix_type(10, 9)))'))}}
+}
Index: clang/test/Sema/matrix-type-operators.c
===================================================================
--- clang/test/Sema/matrix-type-operators.c
+++ clang/test/Sema/matrix-type-operators.c
@@ -91,3 +91,34 @@
   float v12 = a[3];
   // expected-error@-1 {{single subscript expressions are not allowed for matrix values}}
 }
+
+typedef float sx10x5_t __attribute__((matrix_type(10, 5)));
+typedef float sx10x10_t __attribute__((matrix_type(10, 10)));
+
+void add(sx10x10_t a, sx5x10_t b, sx10x5_t c) {
+  a = b + c;
+  // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t' (aka 'float __attribute__((matrix_type(10, 5)))'))}}
+
+  a = b + b; // expected-error {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}}
+
+  a = 10 + b;
+  // expected-error@-1 {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}}
+
+  a = b + &c;
+  // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t *' (aka 'float  __attribute__((matrix_type(10, 5)))*'))}}
+  // expected-error@-2 {{casting 'sx10x5_t *' (aka 'float  __attribute__((matrix_type(10, 5)))*') to incompatible type 'float'}}
+}
+
+void sub(sx10x10_t a, sx5x10_t b, sx10x5_t c) {
+  a = b - c;
+  // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t' (aka 'float __attribute__((matrix_type(10, 5)))'))}}
+
+  a = b - b; // expected-error {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}}
+
+  a = 10 - b;
+  // expected-error@-1 {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}}
+
+  a = b - &c;
+  // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t *' (aka 'float  __attribute__((matrix_type(10, 5)))*'))}}
+  // expected-error@-2 {{casting 'sx10x5_t *' (aka 'float  __attribute__((matrix_type(10, 5)))*') to incompatible type 'float'}}
+}
Index: clang/test/CodeGenCXX/matrix-type-operators.cpp
===================================================================
--- clang/test/CodeGenCXX/matrix-type-operators.cpp
+++ clang/test/CodeGenCXX/matrix-type-operators.cpp
@@ -242,3 +242,252 @@
 
   return matrix_subscript(m, 1, 2);
 }
+
+template <typename EltTy0, unsigned R0, unsigned C0>
+typename MyMatrix<EltTy0, R0, C0>::matrix_t add(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy0, R0, C0> &B) {
+  return A.value + B.value;
+}
+
+void test_add_template() {
+  // CHECK-LABEL:    define void @_Z17test_add_templatev()
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %Mat1 = alloca %struct.MyMatrix.1, align 4
+  // CHECK-NEXT:    %Mat2 = alloca %struct.MyMatrix.1, align 4
+  // CHECK-NEXT:    %call = call <10 x float> @_Z3addIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %Mat1, %struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %Mat2)
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %Mat1, i32 0, i32 0
+  // CHECK-NEXT:    %0 = bitcast [10 x float]* %value to <10 x float>*
+  // CHECK-NEXT:    store <10 x float> %call, <10 x float>* %0, align 4
+  // CHECK-NEXT:    ret void
+
+  // CHECK-LABEL: define linkonce_odr <10 x float> @_Z3addIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %A, %struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %B)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %A.addr = alloca %struct.MyMatrix.1*, align 8
+  // CHECK-NEXT:    %B.addr = alloca %struct.MyMatrix.1*, align 8
+  // CHECK-NEXT:    store %struct.MyMatrix.1* %A, %struct.MyMatrix.1** %A.addr, align 8
+  // CHECK-NEXT:    store %struct.MyMatrix.1* %B, %struct.MyMatrix.1** %B.addr, align 8
+  // CHECK-NEXT:    %0 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %A.addr, align 8
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %0, i32 0, i32 0
+  // CHECK-NEXT:    %1 = bitcast [10 x float]* %value to <10 x float>*
+  // CHECK-NEXT:    %2 = load <10 x float>, <10 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %B.addr, align 8
+  // CHECK-NEXT:    %value1 = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %3, i32 0, i32 0
+  // CHECK-NEXT:    %4 = bitcast [10 x float]* %value1 to <10 x float>*
+  // CHECK-NEXT:    %5 = load <10 x float>, <10 x float>* %4, align 4
+  // CHECK-NEXT:    %6 = fadd <10 x float> %2, %5
+  // CHECK-NEXT:    ret <10 x float> %6
+
+  MyMatrix<float, 2, 5> Mat1;
+  MyMatrix<float, 2, 5> Mat2;
+  Mat1.value = add(Mat1, Mat2);
+}
+
+template <typename EltTy0, unsigned R0, unsigned C0>
+typename MyMatrix<EltTy0, R0, C0>::matrix_t subtract(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy0, R0, C0> &B) {
+  return A.value - B.value;
+}
+
+void test_subtract_template() {
+  // CHECK-LABEL: define void @_Z22test_subtract_templatev()
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %Mat1 = alloca %struct.MyMatrix.1, align 4
+  // CHECK-NEXT:    %Mat2 = alloca %struct.MyMatrix.1, align 4
+  // CHECK-NEXT:    %call = call <10 x float> @_Z8subtractIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %Mat1, %struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %Mat2)
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %Mat1, i32 0, i32 0
+  // CHECK-NEXT:    %0 = bitcast [10 x float]* %value to <10 x float>*
+  // CHECK-NEXT:    store <10 x float> %call, <10 x float>* %0, align 4
+  // CHECK-NEXT:    ret void
+
+  // CHECK-LABEL: define linkonce_odr <10 x float> @_Z8subtractIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %A, %struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %B)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %A.addr = alloca %struct.MyMatrix.1*, align 8
+  // CHECK-NEXT:    %B.addr = alloca %struct.MyMatrix.1*, align 8
+  // CHECK-NEXT:    store %struct.MyMatrix.1* %A, %struct.MyMatrix.1** %A.addr, align 8
+  // CHECK-NEXT:    store %struct.MyMatrix.1* %B, %struct.MyMatrix.1** %B.addr, align 8
+  // CHECK-NEXT:    %0 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %A.addr, align 8
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %0, i32 0, i32 0
+  // CHECK-NEXT:    %1 = bitcast [10 x float]* %value to <10 x float>*
+  // CHECK-NEXT:    %2 = load <10 x float>, <10 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %B.addr, align 8
+  // CHECK-NEXT:    %value1 = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %3, i32 0, i32 0
+  // CHECK-NEXT:    %4 = bitcast [10 x float]* %value1 to <10 x float>*
+  // CHECK-NEXT:    %5 = load <10 x float>, <10 x float>* %4, align 4
+  // CHECK-NEXT:    %6 = fsub <10 x float> %2, %5
+  // CHECK-NEXT:    ret <10 x float> %6
+
+  MyMatrix<float, 2, 5> Mat1;
+  MyMatrix<float, 2, 5> Mat2;
+  Mat1.value = subtract(Mat1, Mat2);
+}
+
+struct DoubleWrapper1 {
+  int x;
+  operator double() {
+    return x;
+  }
+};
+
+struct DoubleWrapper2 {
+  int x;
+  operator double() {
+    return x;
+  }
+};
+
+struct IntWrapper {
+  char x;
+  operator int() {
+    return x;
+  }
+};
+
+void test_DoubleWrapper(MyMatrix<double, 10, 9> &m, MyMatrix<int, 3, 4> &m2) {
+  // CHECK-LABEL:  define void @_Z18test_DoubleWrapperR8MyMatrixIdLj10ELj9EERS_IiLj3ELj4EE(%struct.MyMatrix.2* nonnull align 8 dereferenceable(720) %m, %struct.MyMatrix.3* nonnull align 4 dereferenceable(48) %m2)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m.addr = alloca %struct.MyMatrix.2*, align 8
+  // CHECK-NEXT:    %m2.addr = alloca %struct.MyMatrix.3*, align 8
+  // CHECK-NEXT:    %w1 = alloca %struct.DoubleWrapper1, align 4
+  // CHECK-NEXT:    %w2 = alloca %struct.DoubleWrapper2, align 4
+  // CHECK-NEXT:    %w3 = alloca %struct.IntWrapper, align 1
+  // CHECK-NEXT:    store %struct.MyMatrix.2* %m, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    store %struct.MyMatrix.3* %m2, %struct.MyMatrix.3** %m2.addr, align 8
+  // CHECK-NEXT:    %x = getelementptr inbounds %struct.DoubleWrapper1, %struct.DoubleWrapper1* %w1, i32 0, i32 0
+  // CHECK-NEXT:    store i32 10, i32* %x, align 4
+  // CHECK-NEXT:    %0 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %0, i32 0, i32 0
+  // CHECK-NEXT:    %1 = bitcast [90 x double]* %value to <90 x double>*
+  // CHECK-NEXT:    %2 = load <90 x double>, <90 x double>* %1, align 8
+  // CHECK-NEXT:    %call = call double @_ZN14DoubleWrapper1cvdEv(%struct.DoubleWrapper1* %w1)
+  // CHECK-NEXT:    %scalar.splat.splatinsert = insertelement <90 x double> undef, double %call, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat = shufflevector <90 x double> %scalar.splat.splatinsert, <90 x double> undef, <90 x i32> zeroinitializer
+  // CHECK-NEXT:    %3 = fadd <90 x double> %2, %scalar.splat.splat
+  // CHECK-NEXT:    %4 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value1 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %4, i32 0, i32 0
+  // CHECK-NEXT:    %5 = bitcast [90 x double]* %value1 to <90 x double>*
+  // CHECK-NEXT:    store <90 x double> %3, <90 x double>* %5, align 8
+  // CHECK-NEXT:    %call2 = call double @_ZN14DoubleWrapper1cvdEv(%struct.DoubleWrapper1* %w1)
+  // CHECK-NEXT:    %6 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value3 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %6, i32 0, i32 0
+  // CHECK-NEXT:    %7 = bitcast [90 x double]* %value3 to <90 x double>*
+  // CHECK-NEXT:    %8 = load <90 x double>, <90 x double>* %7, align 8
+  // CHECK-NEXT:    %scalar.splat.splatinsert4 = insertelement <90 x double> undef, double %call2, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat5 = shufflevector <90 x double> %scalar.splat.splatinsert4, <90 x double> undef, <90 x i32> zeroinitializer
+  // CHECK-NEXT:    %9 = fadd <90 x double> %scalar.splat.splat5, %8
+  // CHECK-NEXT:    %10 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value6 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %10, i32 0, i32 0
+  // CHECK-NEXT:    %11 = bitcast [90 x double]* %value6 to <90 x double>*
+  // CHECK-NEXT:    store <90 x double> %9, <90 x double>* %11, align 8
+  // CHECK-NEXT:    %call7 = call double @_ZN14DoubleWrapper1cvdEv(%struct.DoubleWrapper1* %w1)
+  // CHECK-NEXT:    %12 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value8 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %12, i32 0, i32 0
+  // CHECK-NEXT:    %13 = bitcast [90 x double]* %value8 to <90 x double>*
+  // CHECK-NEXT:    %14 = load <90 x double>, <90 x double>* %13, align 8
+  // CHECK-NEXT:    %scalar.splat.splatinsert9 = insertelement <90 x double> undef, double %call7, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat10 = shufflevector <90 x double> %scalar.splat.splatinsert9, <90 x double> undef, <90 x i32> zeroinitializer
+  // CHECK-NEXT:    %15 = fsub <90 x double> %scalar.splat.splat10, %14
+  // CHECK-NEXT:    %16 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value11 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %16, i32 0, i32 0
+  // CHECK-NEXT:    %17 = bitcast [90 x double]* %value11 to <90 x double>*
+  // CHECK-NEXT:    store <90 x double> %15, <90 x double>* %17, align 8
+  // CHECK-NEXT:    %x12 = getelementptr inbounds %struct.DoubleWrapper2, %struct.DoubleWrapper2* %w2, i32 0, i32 0
+  // CHECK-NEXT:    store i32 20, i32* %x12, align 4
+
+  DoubleWrapper1 w1;
+  w1.x = 10;
+  m.value = m.value + w1;
+  m.value = w1 + m.value;
+  m.value = w1 - m.value;
+
+  // CHECK-NEXT:    %18 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value13 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %18, i32 0, i32 0
+  // CHECK-NEXT:    %19 = bitcast [90 x double]* %value13 to <90 x double>*
+  // CHECK-NEXT:    %20 = load <90 x double>, <90 x double>* %19, align 8
+  // CHECK-NEXT:    %call14 = call double @_ZN14DoubleWrapper2cvdEv(%struct.DoubleWrapper2* %w2)
+  // CHECK-NEXT:    %scalar.splat.splatinsert15 = insertelement <90 x double> undef, double %call14, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat16 = shufflevector <90 x double> %scalar.splat.splatinsert15, <90 x double> undef, <90 x i32> zeroinitializer
+  // CHECK-NEXT:    %21 = fadd <90 x double> %20, %scalar.splat.splat16
+  // CHECK-NEXT:    %22 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value17 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %22, i32 0, i32 0
+  // CHECK-NEXT:    %23 = bitcast [90 x double]* %value17 to <90 x double>*
+  // CHECK-NEXT:    store <90 x double> %21, <90 x double>* %23, align 8
+  // CHECK-NEXT:    %call18 = call double @_ZN14DoubleWrapper2cvdEv(%struct.DoubleWrapper2* %w2)
+  // CHECK-NEXT:    %24 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value19 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %24, i32 0, i32 0
+  // CHECK-NEXT:    %25 = bitcast [90 x double]* %value19 to <90 x double>*
+  // CHECK-NEXT:    %26 = load <90 x double>, <90 x double>* %25, align 8
+  // CHECK-NEXT:    %scalar.splat.splatinsert20 = insertelement <90 x double> undef, double %call18, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat21 = shufflevector <90 x double> %scalar.splat.splatinsert20, <90 x double> undef, <90 x i32> zeroinitializer
+  // CHECK-NEXT:    %27 = fadd <90 x double> %scalar.splat.splat21, %26
+  // CHECK-NEXT:    %28 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value22 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %28, i32 0, i32 0
+  // CHECK-NEXT:    %29 = bitcast [90 x double]* %value22 to <90 x double>*
+  // CHECK-NEXT:    store <90 x double> %27, <90 x double>* %29, align 8
+  DoubleWrapper2 w2;
+  w2.x = 20;
+  m.value = m.value + w2;
+  m.value = w2 + m.value;
+
+  // CHECK-NEXT:    %x23 = getelementptr inbounds %struct.IntWrapper, %struct.IntWrapper* %w3, i32 0, i32 0
+  // CHECK-NEXT:    store i8 99, i8* %x23, align 1
+  // CHECK-NEXT:    %30 = load %struct.MyMatrix.3*, %struct.MyMatrix.3** %m2.addr, align 8
+  // CHECK-NEXT:    %value24 = getelementptr inbounds %struct.MyMatrix.3, %struct.MyMatrix.3* %30, i32 0, i32 0
+  // CHECK-NEXT:    %31 = bitcast [12 x i32]* %value24 to <12 x i32>*
+  // CHECK-NEXT:    %32 = load <12 x i32>, <12 x i32>* %31, align 4
+  // CHECK-NEXT:    %call25 = call i32 @_ZN10IntWrappercviEv(%struct.IntWrapper* %w3)
+  // CHECK-NEXT:    %scalar.splat.splatinsert26 = insertelement <12 x i32> undef, i32 %call25, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat27 = shufflevector <12 x i32> %scalar.splat.splatinsert26, <12 x i32> undef, <12 x i32> zeroinitializer
+  // CHECK-NEXT:    %33 = add <12 x i32> %32, %scalar.splat.splat27
+  // CHECK-NEXT:    %34 = load %struct.MyMatrix.3*, %struct.MyMatrix.3** %m2.addr, align 8
+  // CHECK-NEXT:    %value28 = getelementptr inbounds %struct.MyMatrix.3, %struct.MyMatrix.3* %34, i32 0, i32 0
+  // CHECK-NEXT:    %35 = bitcast [12 x i32]* %value28 to <12 x i32>*
+  // CHECK-NEXT:    store <12 x i32> %33, <12 x i32>* %35, align 4
+  // CHECK-NEXT:    %call29 = call i32 @_ZN10IntWrappercviEv(%struct.IntWrapper* %w3)
+  // CHECK-NEXT:    %36 = load %struct.MyMatrix.3*, %struct.MyMatrix.3** %m2.addr, align 8
+  // CHECK-NEXT:    %value30 = getelementptr inbounds %struct.MyMatrix.3, %struct.MyMatrix.3* %36, i32 0, i32 0
+  // CHECK-NEXT:    %37 = bitcast [12 x i32]* %value30 to <12 x i32>*
+  // CHECK-NEXT:    %38 = load <12 x i32>, <12 x i32>* %37, align 4
+  // CHECK-NEXT:    %scalar.splat.splatinsert31 = insertelement <12 x i32> undef, i32 %call29, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat32 = shufflevector <12 x i32> %scalar.splat.splatinsert31, <12 x i32> undef, <12 x i32> zeroinitializer
+  // CHECK-NEXT:    %39 = add <12 x i32> %scalar.splat.splat32, %38
+  // CHECK-NEXT:    %40 = load %struct.MyMatrix.3*, %struct.MyMatrix.3** %m2.addr, align 8
+  // CHECK-NEXT:    %value33 = getelementptr inbounds %struct.MyMatrix.3, %struct.MyMatrix.3* %40, i32 0, i32 0
+  // CHECK-NEXT:    %41 = bitcast [12 x i32]* %value33 to <12 x i32>*
+  // CHECK-NEXT:    store <12 x i32> %39, <12 x i32>* %41, align 4
+
+  IntWrapper w3;
+  w3.x = 'c';
+  m2.value = m2.value + w3;
+  m2.value = w3 + m2.value;
+
+  // IntWrapper converts to int via its conversion function; the result is then implicitly cast to the matrix element type double.
+  // CHECK-NEXT:    %42 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value34 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %42, i32 0, i32 0
+  // CHECK-NEXT:    %43 = bitcast [90 x double]* %value34 to <90 x double>*
+  // CHECK-NEXT:    %44 = load <90 x double>, <90 x double>* %43, align 8
+  // CHECK-NEXT:    %call35 = call i32 @_ZN10IntWrappercviEv(%struct.IntWrapper* %w3)
+  // CHECK-NEXT:    %conv = sitofp i32 %call35 to double
+  // CHECK-NEXT:    %scalar.splat.splatinsert36 = insertelement <90 x double> undef, double %conv, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat37 = shufflevector <90 x double> %scalar.splat.splatinsert36, <90 x double> undef, <90 x i32> zeroinitializer
+  // CHECK-NEXT:    %45 = fsub <90 x double> %44, %scalar.splat.splat37
+  // CHECK-NEXT:    %46 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value38 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %46, i32 0, i32 0
+  // CHECK-NEXT:    %47 = bitcast [90 x double]* %value38 to <90 x double>*
+  // CHECK-NEXT:    store <90 x double> %45, <90 x double>* %47, align 8
+  // CHECK-NEXT:    %call39 = call i32 @_ZN10IntWrappercviEv(%struct.IntWrapper* %w3)
+  // CHECK-NEXT:    %conv40 = sitofp i32 %call39 to double
+  // CHECK-NEXT:    %48 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value41 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %48, i32 0, i32 0
+  // CHECK-NEXT:    %49 = bitcast [90 x double]* %value41 to <90 x double>*
+  // CHECK-NEXT:    %50 = load <90 x double>, <90 x double>* %49, align 8
+  // CHECK-NEXT:    %scalar.splat.splatinsert42 = insertelement <90 x double> undef, double %conv40, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat43 = shufflevector <90 x double> %scalar.splat.splatinsert42, <90 x double> undef, <90 x i32> zeroinitializer
+  // CHECK-NEXT:    %51 = fsub <90 x double> %scalar.splat.splat43, %50
+  // CHECK-NEXT:    %52 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8
+  // CHECK-NEXT:    %value44 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %52, i32 0, i32 0
+  // CHECK-NEXT:    %53 = bitcast [90 x double]* %value44 to <90 x double>*
+  // CHECK-NEXT:    store <90 x double> %51, <90 x double>* %53, align 8
+  // CHECK-NEXT:    ret void
+  // CHECK-NEXT:  }
+
+  m.value = m.value - w3;
+  m.value = w3 - m.value;
+}
Index: clang/test/CodeGen/matrix-type-operators.c
===================================================================
--- clang/test/CodeGen/matrix-type-operators.c
+++ clang/test/CodeGen/matrix-type-operators.c
@@ -312,3 +312,311 @@
   // CHECK-NEXT:    ret void
   b[2][j] = b[0][k];
 }
+
+void add_matrix_matrix(dx5x5_t a, dx5x5_t b, dx5x5_t c, ix9x3_t ai, ix9x3_t bi, ix9x3_t ci) {
+  a = b + c;
+  ai = bi + ci;
+
+  // CHECK-LABEL: @add_matrix_matrix(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %b.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %c.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %ai.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %bi.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %ci.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %0 = bitcast [25 x double]* %a.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %a, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %b.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %b, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %2 = bitcast [25 x double]* %c.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %c, <25 x double>* %2, align 8
+  // CHECK-NEXT:    %3 = bitcast [27 x i32]* %ai.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %ai, <27 x i32>* %3, align 4
+  // CHECK-NEXT:    %4 = bitcast [27 x i32]* %bi.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %bi, <27 x i32>* %4, align 4
+  // CHECK-NEXT:    %5 = bitcast [27 x i32]* %ci.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %ci, <27 x i32>* %5, align 4
+  // CHECK-NEXT:    %6 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %7 = load <25 x double>, <25 x double>* %2, align 8
+  // CHECK-NEXT:    %8 = fadd <25 x double> %6, %7
+  // CHECK-NEXT:    store <25 x double> %8, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %9 = load <27 x i32>, <27 x i32>* %4, align 4
+  // CHECK-NEXT:    %10 = load <27 x i32>, <27 x i32>* %5, align 4
+  // CHECK-NEXT:    %11 = add <27 x i32> %9, %10
+  // CHECK-NEXT:    store <27 x i32> %11, <27 x i32>* %3, align 4
+  // CHECK-NEXT:    ret void
+}
+
+void add_matrix_scalar_float(dx5x5_t a, fx2x3_t b, float vf, double vd) {
+  a = a + vf;
+  a = a + vd;
+
+  // CHECK-LABEL: define void @add_matrix_scalar_float(<25 x double> %a, <6 x float> %b, float %vf, double %vd)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %b.addr = alloca [6 x float], align 4
+  // CHECK-NEXT:    %vf.addr = alloca float, align 4
+  // CHECK-NEXT:    %vd.addr = alloca double, align 8
+  // CHECK-NEXT:    %0 = bitcast [25 x double]* %a.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %a, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %1 = bitcast [6 x float]* %b.addr to <6 x float>*
+  // CHECK-NEXT:    store <6 x float> %b, <6 x float>* %1, align 4
+  // CHECK-NEXT:    store float %vf, float* %vf.addr, align 4
+  // CHECK-NEXT:    store double %vd, double* %vd.addr, align 8
+  // CHECK-NEXT:    %2 = load <25 x double>, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %3 = load float, float* %vf.addr, align 4
+  // CHECK-NEXT:    %conv = fpext float %3 to double
+  // CHECK-NEXT:    %scalar.splat.splatinsert = insertelement <25 x double> undef, double %conv, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat = shufflevector <25 x double> %scalar.splat.splatinsert, <25 x double> undef, <25 x i32> zeroinitializer
+  // CHECK-NEXT:    %4 = fadd <25 x double> %2, %scalar.splat.splat
+  // CHECK-NEXT:    store <25 x double> %4, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %5 = load <25 x double>, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %6 = load double, double* %vd.addr, align 8
+  // CHECK-NEXT:    %scalar.splat.splatinsert1 = insertelement <25 x double> undef, double %6, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat2 = shufflevector <25 x double> %scalar.splat.splatinsert1, <25 x double> undef, <25 x i32> zeroinitializer
+  // CHECK-NEXT:    %7 = fadd <25 x double> %5, %scalar.splat.splat2
+  // CHECK-NEXT:    store <25 x double> %7, <25 x double>* %0, align 8
+
+  b = b + vf;
+  b = b + vd;
+
+  // CHECK-NEXT:    %8 = load <6 x float>, <6 x float>* %1, align 4
+  // CHECK-NEXT:    %9 = load float, float* %vf.addr, align 4
+  // CHECK-NEXT:    %scalar.splat.splatinsert3 = insertelement <6 x float> undef, float %9, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat4 = shufflevector <6 x float> %scalar.splat.splatinsert3, <6 x float> undef, <6 x i32> zeroinitializer
+  // CHECK-NEXT:    %10 = fadd <6 x float> %8, %scalar.splat.splat4
+  // CHECK-NEXT:    store <6 x float> %10, <6 x float>* %1, align 4
+  // CHECK-NEXT:    %11 = load <6 x float>, <6 x float>* %1, align 4
+  // CHECK-NEXT:    %12 = load double, double* %vd.addr, align 8
+  // CHECK-NEXT:    %conv5 = fptrunc double %12 to float
+  // CHECK-NEXT:    %scalar.splat.splatinsert6 = insertelement <6 x float> undef, float %conv5, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat7 = shufflevector <6 x float> %scalar.splat.splatinsert6, <6 x float> undef, <6 x i32> zeroinitializer
+  // CHECK-NEXT:    %13 = fadd <6 x float> %11, %scalar.splat.splat7
+  // CHECK-NEXT:    store <6 x float> %13, <6 x float>* %1, align 4
+  // CHECK-NEXT:    ret void
+}
+
+typedef int llix9x3_t __attribute__((matrix_type(9, 3)));
+
+void add_matrix_scalar_ints(ix9x3_t a, llix9x3_t b, short vs, long int vli, unsigned long long int vulli) {
+  a = a + vs;
+  a = a + vli;
+  a = a + vulli;
+
+  // CHECK-LABEL: define void @add_matrix_scalar_ints(<27 x i32> %a, <27 x i32> %b, i16 signext %vs, i64 %vli, i64 %vulli)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %b.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %vs.addr = alloca i16, align 2
+  // CHECK-NEXT:    %vli.addr = alloca i64, align 8
+  // CHECK-NEXT:    %vulli.addr = alloca i64, align 8
+  // CHECK-NEXT:    %0 = bitcast [27 x i32]* %a.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %a, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %1 = bitcast [27 x i32]* %b.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %b, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    store i16 %vs, i16* %vs.addr, align 2
+  // CHECK-NEXT:    store i64 %vli, i64* %vli.addr, align 8
+  // CHECK-NEXT:    store i64 %vulli, i64* %vulli.addr, align 8
+  // CHECK-NEXT:    %2 = load <27 x i32>, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %3 = load i16, i16* %vs.addr, align 2
+  // CHECK-NEXT:    %conv = sext i16 %3 to i32
+  // CHECK-NEXT:    %scalar.splat.splatinsert = insertelement <27 x i32> undef, i32 %conv, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat = shufflevector <27 x i32> %scalar.splat.splatinsert, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %4 = add <27 x i32> %2, %scalar.splat.splat
+  // CHECK-NEXT:    store <27 x i32> %4, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %5 = load <27 x i32>, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %6 = load i64, i64* %vli.addr, align 8
+  // CHECK-NEXT:    %conv1 = trunc i64 %6 to i32
+  // CHECK-NEXT:    %scalar.splat.splatinsert2 = insertelement <27 x i32> undef, i32 %conv1, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat3 = shufflevector <27 x i32> %scalar.splat.splatinsert2, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %7 = add <27 x i32> %5, %scalar.splat.splat3
+  // CHECK-NEXT:    store <27 x i32> %7, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %8 = load <27 x i32>, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %9 = load i64, i64* %vulli.addr, align 8
+  // CHECK-NEXT:    %conv4 = trunc i64 %9 to i32
+  // CHECK-NEXT:    %scalar.splat.splatinsert5 = insertelement <27 x i32> undef, i32 %conv4, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat6 = shufflevector <27 x i32> %scalar.splat.splatinsert5, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %10 = add <27 x i32> %8, %scalar.splat.splat6
+  // CHECK-NEXT:    store <27 x i32> %10, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %11 = load i16, i16* %vs.addr, align 2
+
+  b = vs + b;
+  b = vli + b;
+  b = vulli + b;
+
+  // CHECK-NEXT:    %conv7 = sext i16 %11 to i32
+  // CHECK-NEXT:    %12 = load <27 x i32>, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    %scalar.splat.splatinsert8 = insertelement <27 x i32> undef, i32 %conv7, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat9 = shufflevector <27 x i32> %scalar.splat.splatinsert8, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %13 = add <27 x i32> %scalar.splat.splat9, %12
+  // CHECK-NEXT:    store <27 x i32> %13, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    %14 = load i64, i64* %vli.addr, align 8
+  // CHECK-NEXT:    %conv10 = trunc i64 %14 to i32
+  // CHECK-NEXT:    %15 = load <27 x i32>, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    %scalar.splat.splatinsert11 = insertelement <27 x i32> undef, i32 %conv10, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat12 = shufflevector <27 x i32> %scalar.splat.splatinsert11, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %16 = add <27 x i32> %scalar.splat.splat12, %15
+  // CHECK-NEXT:    store <27 x i32> %16, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    %17 = load i64, i64* %vulli.addr, align 8
+  // CHECK-NEXT:    %conv13 = trunc i64 %17 to i32
+  // CHECK-NEXT:    %18 = load <27 x i32>, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    %scalar.splat.splatinsert14 = insertelement <27 x i32> undef, i32 %conv13, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat15 = shufflevector <27 x i32> %scalar.splat.splatinsert14, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %19 = add <27 x i32> %scalar.splat.splat15, %18
+  // CHECK-NEXT:    store <27 x i32> %19, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    ret void
+}
+
+void sub_matrix_matrix(dx5x5_t a, dx5x5_t b, dx5x5_t c, ix9x3_t ai, ix9x3_t bi, ix9x3_t ci) {
+  a = b - c;
+  ai = bi - ci;
+
+  // CHECK-LABEL: @sub_matrix_matrix(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %b.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %c.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %ai.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %bi.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %ci.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %0 = bitcast [25 x double]* %a.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %a, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %b.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %b, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %2 = bitcast [25 x double]* %c.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %c, <25 x double>* %2, align 8
+  // CHECK-NEXT:    %3 = bitcast [27 x i32]* %ai.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %ai, <27 x i32>* %3, align 4
+  // CHECK-NEXT:    %4 = bitcast [27 x i32]* %bi.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %bi, <27 x i32>* %4, align 4
+  // CHECK-NEXT:    %5 = bitcast [27 x i32]* %ci.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %ci, <27 x i32>* %5, align 4
+  // CHECK-NEXT:    %6 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %7 = load <25 x double>, <25 x double>* %2, align 8
+  // CHECK-NEXT:    %8 = fsub <25 x double> %6, %7
+  // CHECK-NEXT:    store <25 x double> %8, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %9 = load <27 x i32>, <27 x i32>* %4, align 4
+  // CHECK-NEXT:    %10 = load <27 x i32>, <27 x i32>* %5, align 4
+  // CHECK-NEXT:    %11 = sub <27 x i32> %9, %10
+  // CHECK-NEXT:    store <27 x i32> %11, <27 x i32>* %3, align 4
+  // CHECK-NEXT:    ret void
+}
+
+void sub_matrix_scalar_float(dx5x5_t a, fx2x3_t b, float vf, double vd) {
+  a = a - vf;
+  a = a - vd;
+
+  // CHECK-LABEL: define void @sub_matrix_scalar_float(<25 x double> %a, <6 x float> %b, float %vf, double %vd)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %b.addr = alloca [6 x float], align 4
+  // CHECK-NEXT:    %vf.addr = alloca float, align 4
+  // CHECK-NEXT:    %vd.addr = alloca double, align 8
+  // CHECK-NEXT:    %0 = bitcast [25 x double]* %a.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %a, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %1 = bitcast [6 x float]* %b.addr to <6 x float>*
+  // CHECK-NEXT:    store <6 x float> %b, <6 x float>* %1, align 4
+  // CHECK-NEXT:    store float %vf, float* %vf.addr, align 4
+  // CHECK-NEXT:    store double %vd, double* %vd.addr, align 8
+  // CHECK-NEXT:    %2 = load <25 x double>, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %3 = load float, float* %vf.addr, align 4
+  // CHECK-NEXT:    %conv = fpext float %3 to double
+  // CHECK-NEXT:    %scalar.splat.splatinsert = insertelement <25 x double> undef, double %conv, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat = shufflevector <25 x double> %scalar.splat.splatinsert, <25 x double> undef, <25 x i32> zeroinitializer
+  // CHECK-NEXT:    %4 = fsub <25 x double> %2, %scalar.splat.splat
+  // CHECK-NEXT:    store <25 x double> %4, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %5 = load <25 x double>, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %6 = load double, double* %vd.addr, align 8
+  // CHECK-NEXT:    %scalar.splat.splatinsert1 = insertelement <25 x double> undef, double %6, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat2 = shufflevector <25 x double> %scalar.splat.splatinsert1, <25 x double> undef, <25 x i32> zeroinitializer
+  // CHECK-NEXT:    %7 = fsub <25 x double> %5, %scalar.splat.splat2
+  // CHECK-NEXT:    store <25 x double> %7, <25 x double>* %0, align 8
+
+  b = b - vf;
+  b = b - vd;
+
+  // CHECK-NEXT:    %8 = load <6 x float>, <6 x float>* %1, align 4
+  // CHECK-NEXT:    %9 = load float, float* %vf.addr, align 4
+  // CHECK-NEXT:    %scalar.splat.splatinsert3 = insertelement <6 x float> undef, float %9, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat4 = shufflevector <6 x float> %scalar.splat.splatinsert3, <6 x float> undef, <6 x i32> zeroinitializer
+  // CHECK-NEXT:    %10 = fsub <6 x float> %8, %scalar.splat.splat4
+  // CHECK-NEXT:    store <6 x float> %10, <6 x float>* %1, align 4
+  // CHECK-NEXT:    %11 = load <6 x float>, <6 x float>* %1, align 4
+  // CHECK-NEXT:    %12 = load double, double* %vd.addr, align 8
+  // CHECK-NEXT:    %conv5 = fptrunc double %12 to float
+  // CHECK-NEXT:    %scalar.splat.splatinsert6 = insertelement <6 x float> undef, float %conv5, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat7 = shufflevector <6 x float> %scalar.splat.splatinsert6, <6 x float> undef, <6 x i32> zeroinitializer
+  // CHECK-NEXT:    %13 = fsub <6 x float> %11, %scalar.splat.splat7
+  // CHECK-NEXT:    store <6 x float> %13, <6 x float>* %1, align 4
+  // CHECK-NEXT:    ret void
+}
+
+void sub_matrix_scalar_ints(ix9x3_t a, llix9x3_t b, short vs, long int vli, unsigned long long int vulli) {
+  a = a - vs;
+  a = a - vli;
+  a = a - vulli;
+
+  // CHECK-LABEL: define void @sub_matrix_scalar_ints(<27 x i32> %a, <27 x i32> %b, i16 signext %vs, i64 %vli, i64 %vulli)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %b.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %vs.addr = alloca i16, align 2
+  // CHECK-NEXT:    %vli.addr = alloca i64, align 8
+  // CHECK-NEXT:    %vulli.addr = alloca i64, align 8
+  // CHECK-NEXT:    %0 = bitcast [27 x i32]* %a.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %a, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %1 = bitcast [27 x i32]* %b.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %b, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    store i16 %vs, i16* %vs.addr, align 2
+  // CHECK-NEXT:    store i64 %vli, i64* %vli.addr, align 8
+  // CHECK-NEXT:    store i64 %vulli, i64* %vulli.addr, align 8
+  // CHECK-NEXT:    %2 = load <27 x i32>, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %3 = load i16, i16* %vs.addr, align 2
+  // CHECK-NEXT:    %conv = sext i16 %3 to i32
+  // CHECK-NEXT:    %scalar.splat.splatinsert = insertelement <27 x i32> undef, i32 %conv, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat = shufflevector <27 x i32> %scalar.splat.splatinsert, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %4 = sub <27 x i32> %2, %scalar.splat.splat
+  // CHECK-NEXT:    store <27 x i32> %4, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %5 = load <27 x i32>, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %6 = load i64, i64* %vli.addr, align 8
+  // CHECK-NEXT:    %conv1 = trunc i64 %6 to i32
+  // CHECK-NEXT:    %scalar.splat.splatinsert2 = insertelement <27 x i32> undef, i32 %conv1, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat3 = shufflevector <27 x i32> %scalar.splat.splatinsert2, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %7 = sub <27 x i32> %5, %scalar.splat.splat3
+  // CHECK-NEXT:    store <27 x i32> %7, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %8 = load <27 x i32>, <27 x i32>* %0, align 4
+  // CHECK-NEXT:    %9 = load i64, i64* %vulli.addr, align 8
+  // CHECK-NEXT:    %conv4 = trunc i64 %9 to i32
+  // CHECK-NEXT:    %scalar.splat.splatinsert5 = insertelement <27 x i32> undef, i32 %conv4, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat6 = shufflevector <27 x i32> %scalar.splat.splatinsert5, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %10 = sub <27 x i32> %8, %scalar.splat.splat6
+  // CHECK-NEXT:    store <27 x i32> %10, <27 x i32>* %0, align 4
+
+  b = vs - b;
+  b = vli - b;
+  b = vulli - b;
+
+  // CHECK-NEXT:    %11 = load i16, i16* %vs.addr, align 2
+  // CHECK-NEXT:    %conv7 = sext i16 %11 to i32
+  // CHECK-NEXT:    %12 = load <27 x i32>, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    %scalar.splat.splatinsert8 = insertelement <27 x i32> undef, i32 %conv7, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat9 = shufflevector <27 x i32> %scalar.splat.splatinsert8, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %13 = sub <27 x i32> %scalar.splat.splat9, %12
+  // CHECK-NEXT:    store <27 x i32> %13, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    %14 = load i64, i64* %vli.addr, align 8
+  // CHECK-NEXT:    %conv10 = trunc i64 %14 to i32
+  // CHECK-NEXT:    %15 = load <27 x i32>, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    %scalar.splat.splatinsert11 = insertelement <27 x i32> undef, i32 %conv10, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat12 = shufflevector <27 x i32> %scalar.splat.splatinsert11, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %16 = sub <27 x i32> %scalar.splat.splat12, %15
+  // CHECK-NEXT:    store <27 x i32> %16, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    %17 = load i64, i64* %vulli.addr, align 8
+  // CHECK-NEXT:    %conv13 = trunc i64 %17 to i32
+  // CHECK-NEXT:    %18 = load <27 x i32>, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    %scalar.splat.splatinsert14 = insertelement <27 x i32> undef, i32 %conv13, i32 0
+  // CHECK-NEXT:    %scalar.splat.splat15 = shufflevector <27 x i32> %scalar.splat.splatinsert14, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT:    %19 = sub <27 x i32> %scalar.splat.splat15, %18
+  // CHECK-NEXT:    store <27 x i32> %19, <27 x i32>* %1, align 4
+  // CHECK-NEXT:    ret void
+}
Index: clang/lib/Sema/SemaOverload.cpp
===================================================================
--- clang/lib/Sema/SemaOverload.cpp
+++ clang/lib/Sema/SemaOverload.cpp
@@ -7687,6 +7687,10 @@
   /// candidates.
   TypeSet VectorTypes;
 
+  /// The set of matrix types that will be used in the built-in
+  /// candidates.
+  TypeSet MatrixTypes;
+
   /// A flag indicating non-record types are viable candidates
   bool HasNonRecordTypes;
 
@@ -7747,6 +7751,10 @@
   iterator vector_begin() { return VectorTypes.begin(); }
   iterator vector_end() { return VectorTypes.end(); }
 
+  llvm::iterator_range<iterator> matrix_types() { return MatrixTypes; }
+  iterator matrix_begin() { return MatrixTypes.begin(); }
+  iterator matrix_end() { return MatrixTypes.end(); }
+
   bool hasNonRecordTypes() { return HasNonRecordTypes; }
   bool hasArithmeticOrEnumeralTypes() { return HasArithmeticOrEnumeralTypes; }
   bool hasNullPtrType() const { return HasNullPtrType; }
@@ -7921,6 +7929,11 @@
     // extension.
     HasArithmeticOrEnumeralTypes = true;
     VectorTypes.insert(Ty);
+  } else if (Ty->isMatrixType()) {
+    // Similar to vector types, we treat matrix types as arithmetic types in
+    // many contexts as an extension.
+    HasArithmeticOrEnumeralTypes = true;
+    MatrixTypes.insert(Ty);
   } else if (Ty->isNullPtrType()) {
     HasNullPtrType = true;
   } else if (AllowUserConversions && TyRec) {
@@ -8541,30 +8554,42 @@
     if (!HasArithmeticOrEnumeralCandidateType)
       return;
 
+    auto AddCandidate = [this](QualType L, QualType R) {
+      QualType LandR[2] = {L, R};
+      S.AddBuiltinCandidate(LandR, Args, CandidateSet);
+    };
     for (unsigned Left = FirstPromotedArithmeticType;
-         Left < LastPromotedArithmeticType; ++Left) {
+         Left < LastPromotedArithmeticType; ++Left)
       for (unsigned Right = FirstPromotedArithmeticType;
-           Right < LastPromotedArithmeticType; ++Right) {
-        QualType LandR[2] = { ArithmeticTypes[Left],
-                              ArithmeticTypes[Right] };
-        S.AddBuiltinCandidate(LandR, Args, CandidateSet);
-      }
-    }
+           Right < LastPromotedArithmeticType; ++Right)
+        AddCandidate(ArithmeticTypes[Left], ArithmeticTypes[Right]);
 
     // Extension: Add the binary operators ==, !=, <, <=, >=, >, *, /, and the
     // conditional operator for vector types.
     for (BuiltinCandidateTypeSet::iterator
-              Vec1 = CandidateTypes[0].vector_begin(),
-           Vec1End = CandidateTypes[0].vector_end();
-         Vec1 != Vec1End; ++Vec1) {
+             Vec1 = CandidateTypes[0].vector_begin(),
+             Vec1End = CandidateTypes[0].vector_end();
+         Vec1 != Vec1End; ++Vec1)
       for (BuiltinCandidateTypeSet::iterator
-                Vec2 = CandidateTypes[1].vector_begin(),
-             Vec2End = CandidateTypes[1].vector_end();
-           Vec2 != Vec2End; ++Vec2) {
-        QualType LandR[2] = { *Vec1, *Vec2 };
-        S.AddBuiltinCandidate(LandR, Args, CandidateSet);
-      }
-    }
+               Vec2 = CandidateTypes[1].vector_begin(),
+               Vec2End = CandidateTypes[1].vector_end();
+           Vec2 != Vec2End; ++Vec2)
+        AddCandidate(*Vec1, *Vec2);
+
+    // Extension: Add the following binary operator overloads for each
+    // candidate type M1, M2:
+    //  * (M1, M2) -> M1, if M1 == M2
+    //  * (M1, M1.getElementType()) -> M1
+    //  * (M2.getElementType(), M2) -> M2
+    for (const QualType &M1 : CandidateTypes[0].matrix_types()) {
+      for (const QualType &M2 : CandidateTypes[1].matrix_types())
+        if (S.Context.hasSameType(M1, M2))
+          AddCandidate(M1, M2);
+
+      AddCandidate(M1, cast<MatrixType>(M1)->getElementType());
+    }
+    for (const QualType &M2 : CandidateTypes[1].matrix_types())
+      AddCandidate(cast<MatrixType>(M2)->getElementType(), M2);
   }
 
   // C++2a [over.built]p14:
Index: clang/lib/Sema/SemaExpr.cpp
===================================================================
--- clang/lib/Sema/SemaExpr.cpp
+++ clang/lib/Sema/SemaExpr.cpp
@@ -10304,6 +10304,11 @@
     return compType;
   }
 
+  if (LHS.get()->getType()->isConstantMatrixType() ||
+      RHS.get()->getType()->isConstantMatrixType()) {
+    return CheckMatrixElementwiseOperands(LHS, RHS, Loc, CompLHSTy);
+  }
+
   QualType compType = UsualArithmeticConversions(
       LHS, RHS, Loc, CompLHSTy ? ACK_CompAssign : ACK_Arithmetic);
   if (LHS.isInvalid() || RHS.isInvalid())
@@ -10399,6 +10404,11 @@
     return compType;
   }
 
+  if (LHS.get()->getType()->isConstantMatrixType() ||
+      RHS.get()->getType()->isConstantMatrixType()) {
+    return CheckMatrixElementwiseOperands(LHS, RHS, Loc, CompLHSTy);
+  }
+
   QualType compType = UsualArithmeticConversions(
       LHS, RHS, Loc, CompLHSTy ? ACK_CompAssign : ACK_Arithmetic);
   if (LHS.isInvalid() || RHS.isInvalid())
@@ -11994,6 +12004,71 @@
   return GetSignedVectorType(LHS.get()->getType());
 }
 
+static bool tryConvertScalarToMatrixElementTy(Sema &S, QualType ElementType,
+                                              ExprResult *Scalar) {
+  QualType ScalarTy = Scalar->get()->getType().getUnqualifiedType();
+  if (!ScalarTy->isArithmeticType()) {
+
+    InitializedEntity Entity =
+        InitializedEntity::InitializeTemporary(ElementType);
+    InitializationKind Kind = InitializationKind::CreateCopy(
+        Scalar->get()->getBeginLoc(), SourceLocation());
+    Expr *Arg = Scalar->get();
+    InitializationSequence InitSeq(S, Entity, Kind, Arg);
+    *Scalar = InitSeq.Perform(S, Entity, Kind, Arg);
+    return !Scalar->isInvalid();
+  }
+
+  CastKind CK = S.PrepareScalarCast(*Scalar, ElementType);
+  *Scalar = S.ImpCastExprToType(Scalar->get(), ElementType, CK);
+  return true;
+}
+
+QualType Sema::CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS,
+                                              SourceLocation Loc,
+                                              bool IsCompAssign) {
+  if (!IsCompAssign) {
+    LHS = DefaultFunctionArrayLvalueConversion(LHS.get());
+    if (LHS.isInvalid())
+      return QualType();
+  }
+  RHS = DefaultFunctionArrayLvalueConversion(RHS.get());
+  if (RHS.isInvalid())
+    return QualType();
+
+  // For conversion purposes, we ignore any qualifiers.
+  // For example, "const float" and "float" are equivalent.
+  QualType LHSType = LHS.get()->getType().getUnqualifiedType();
+  QualType RHSType = RHS.get()->getType().getUnqualifiedType();
+
+  const MatrixType *LHSMatType = LHSType->getAs<MatrixType>();
+  const MatrixType *RHSMatType = RHSType->getAs<MatrixType>();
+  assert((LHSMatType || RHSMatType) && "At least one operand must be a matrix");
+
+  if (Context.hasSameType(LHSType, RHSType))
+    return LHSType;
+
+  // Type conversion may change LHS/RHS. Keep copies of the original results, in
+  // case we have to return InvalidOperands.
+  ExprResult OriginalLHS = LHS;
+  ExprResult OriginalRHS = RHS;
+  if (LHSMatType && !RHSMatType) {
+    if (tryConvertScalarToMatrixElementTy(*this, LHSMatType->getElementType(),
+                                          &RHS))
+      return LHSType;
+    return InvalidOperands(Loc, OriginalLHS, OriginalRHS);
+  }
+
+  if (!LHSMatType && RHSMatType) {
+    if (tryConvertScalarToMatrixElementTy(*this, RHSMatType->getElementType(),
+                                          &LHS))
+      return RHSType;
+    return InvalidOperands(Loc, OriginalLHS, OriginalRHS);
+  }
+
+  return InvalidOperands(Loc, LHS, RHS);
+}
+
 inline QualType Sema::CheckBitwiseOperands(ExprResult &LHS, ExprResult &RHS,
                                            SourceLocation Loc,
                                            BinaryOperatorKind Opc) {
Index: clang/lib/CodeGen/CGExprScalar.cpp
===================================================================
--- clang/lib/CodeGen/CGExprScalar.cpp
+++ clang/lib/CodeGen/CGExprScalar.cpp
@@ -3554,6 +3554,11 @@
     }
   }
 
+  if (op.Ty->isConstantMatrixType()) {
+    llvm::MatrixBuilder<CGBuilderTy> MB(Builder);
+    return MB.CreateAdd(op.LHS, op.RHS);
+  }
+
   if (op.Ty->isUnsignedIntegerType() &&
       CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) &&
       !CanElideOverflowCheck(CGF.getContext(), op))
@@ -3738,6 +3743,11 @@
       }
     }
 
+    if (op.Ty->isConstantMatrixType()) {
+      llvm::MatrixBuilder<CGBuilderTy> MB(Builder);
+      return MB.CreateSub(op.LHS, op.RHS);
+    }
+
     if (op.Ty->isUnsignedIntegerType() &&
         CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) &&
         !CanElideOverflowCheck(CGF.getContext(), op))
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -11210,6 +11210,11 @@
   QualType CheckVectorLogicalOperands(ExprResult &LHS, ExprResult &RHS,
                                       SourceLocation Loc);
 
+  /// Type checking for matrix binary operators.
+  QualType CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS,
+                                          SourceLocation Loc,
+                                          bool IsCompAssign);
+
   bool areLaxCompatibleVectorTypes(QualType srcType, QualType destType);
   bool isLaxVectorConversion(QualType srcType, QualType destType);
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to