fhahn created this revision.
fhahn added reviewers: rjmccall, anemet, Bigcheese, rsmith, martong.
Herald added subscribers: tschuett, dexonsmith, rnkovacs.
Herald added a project: clang.

This patch implements the + and - binary operators for values of
MatrixType. It adds support for matrix +/- matrix, scalar +/- matrix and
matrix +/- scalar.

For the matrix/matrix case, the two operand types must initially be
structurally equivalent. For the scalar/matrix variants, the element type
of the matrix must match the scalar type.

Repository: rG LLVM Github Monorepo



Index: clang/test/SemaCXX/matrix-type-operators.cpp
--- clang/test/SemaCXX/matrix-type-operators.cpp
+++ clang/test/SemaCXX/matrix-type-operators.cpp
@@ -59,3 +59,66 @@
   float v11 = a[5][10.0];
   // expected-error@-1 {{matrix row index is outside the allowed range [0, 5)}}
+template <typename EltTy, unsigned Rows, unsigned Columns>
+struct MyMatrix {
+  using matrix_t = EltTy __attribute__((matrix_type(Rows, Columns)));
+  matrix_t value;
+template <typename EltTy0, unsigned R0, unsigned C0, typename EltTy1, unsigned R1, unsigned C1, typename EltTy2, unsigned R2, unsigned C2>
+typename MyMatrix<EltTy2, R2, C2>::matrix_t add(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy1, R1, C1> &B) {
+  char *v1 = A.value + B.value;
+  // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))'))}}
+  // expected-error@-3 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}}
+  return A.value + B.value;
+  // expected-error@-1 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))'))}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}}
+void test_add_template(unsigned *Ptr1, float *Ptr2) {
+  MyMatrix<unsigned, 2, 2> Mat1;
+  MyMatrix<unsigned, 3, 3> Mat2;
+  MyMatrix<float, 2, 2> Mat3;
+  Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1);
+  unsigned v1 = add<unsigned, 2, 2, unsigned, 2, 2, unsigned, 2, 2>(Mat1, Mat1);
+  // expected-error@-1 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix<unsigned int, 2U, 2U>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}}
+  // expected-note@-2 {{in instantiation of function template specialization 'add<unsigned int, 2, 2, unsigned int, 2, 2, unsigned int, 2, 2>' requested here}}
+  Mat1.value = add<unsigned, 2, 2, unsigned, 3, 3, unsigned, 2, 2>(Mat1, Mat2);
+  // expected-note@-1 {{in instantiation of function template specialization 'add<unsigned int, 2, 2, unsigned int, 3, 3, unsigned int, 2, 2>' requested here}}
+  Mat1.value = add<unsigned, 3, 3, float, 2, 2, unsigned, 2, 2>(Mat2, Mat3);
+  // expected-note@-1 {{in instantiation of function template specialization 'add<unsigned int, 3, 3, float, 2, 2, unsigned int, 2, 2>' requested here}}
+template <typename EltTy0, unsigned R0, unsigned C0, typename EltTy1, unsigned R1, unsigned C1, typename EltTy2, unsigned R2, unsigned C2>
+typename MyMatrix<EltTy2, R2, C2>::matrix_t subtract(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy1, R1, C1> &B) {
+  char *v1 = A.value - B.value;
+  // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))')}}
+  // expected-error@-3 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))')}}
+  return A.value - B.value;
+  // expected-error@-1 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))')}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))')}}
+void test_subtract_template(unsigned *Ptr1, float *Ptr2) {
+  MyMatrix<unsigned, 2, 2> Mat1;
+  MyMatrix<unsigned, 3, 3> Mat2;
+  MyMatrix<float, 2, 2> Mat3;
+  Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1);
+  unsigned v1 = subtract<unsigned, 2, 2, unsigned, 2, 2, unsigned, 2, 2>(Mat1, Mat1);
+  // expected-error@-1 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix<unsigned int, 2U, 2U>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}}
+  // expected-note@-2 {{in instantiation of function template specialization 'subtract<unsigned int, 2, 2, unsigned int, 2, 2, unsigned int, 2, 2>' requested here}}
+  Mat1.value = subtract<unsigned, 2, 2, unsigned, 3, 3, unsigned, 2, 2>(Mat1, Mat2);
+  // expected-note@-1 {{in instantiation of function template specialization 'subtract<unsigned int, 2, 2, unsigned int, 3, 3, unsigned int, 2, 2>' requested here}}
+  Mat1.value = subtract<unsigned, 3, 3, float, 2, 2, unsigned, 2, 2>(Mat2, Mat3);
+  // expected-note@-1 {{in instantiation of function template specialization 'subtract<unsigned int, 3, 3, float, 2, 2, unsigned int, 2, 2>' requested here}}
Index: clang/test/Sema/matrix-type-operators.c
--- clang/test/Sema/matrix-type-operators.c
+++ clang/test/Sema/matrix-type-operators.c
@@ -65,3 +65,34 @@
   float v12 = a[3];
   // expected-error@-1 {{initializing 'float' with an expression of incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}}
+typedef float sx10x5_t __attribute__((matrix_type(10, 5)));
+typedef float sx10x10_t __attribute__((matrix_type(10, 10)));
+void add(sx10x10_t a, sx5x10_t b, sx10x5_t c) {
+  a = b + c;
+  // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t' (aka 'float __attribute__((matrix_type(10, 5)))'))}}
+  a = b + b; // expected-error {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}}
+  // TODO: Implement scalar & matrix binary add.
+  a = 10 + b;
+  // expected-error@-1 {{invalid operands to binary expression ('int' and 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))'))}}
+  a = b + &c;
+  // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t *' (aka 'float  __attribute__((matrix_type(10, 5)))*'))}}
+void sub(sx10x10_t a, sx5x10_t b, sx10x5_t c) {
+  a = b - c;
+  // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t' (aka 'float __attribute__((matrix_type(10, 5)))'))}}
+  a = b - b; // expected-error {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}}
+  // TODO: Implement scalar & matrix binary subtract.
+  a = 10 - b;
+  // expected-error@-1 {{invalid operands to binary expression ('int' and 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))'))}}
+  a = b - &c;
+  // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t *' (aka 'float  __attribute__((matrix_type(10, 5)))*'))}}
Index: clang/test/CodeGenCXX/matrix-type-operators.cpp
--- clang/test/CodeGenCXX/matrix-type-operators.cpp
+++ clang/test/CodeGenCXX/matrix-type-operators.cpp
@@ -209,3 +209,79 @@
   Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1);
   unsigned v1 = extract(Mat1);
+template <typename EltTy0, unsigned R0, unsigned C0>
+typename MyMatrix<EltTy0, R0, C0>::matrix_t add(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy0, R0, C0> &B) {
+  return A.value + B.value;
+void test_add_template() {
+  // CHECK-LABEL:    define void @_Z17test_add_templatev()
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %Mat1 = alloca %struct.MyMatrix.1, align 4
+  // CHECK-NEXT:    %Mat2 = alloca %struct.MyMatrix.1, align 4
+  // CHECK-NEXT:    %call = call <10 x float> @_Z3addIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* dereferenceable(40) %Mat1, %struct.MyMatrix.1* dereferenceable(40) %Mat2)
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %Mat1, i32 0, i32 0
+  // CHECK-NEXT:    %0 = bitcast [10 x float]* %value to <10 x float>*
+  // CHECK-NEXT:    store <10 x float> %call, <10 x float>* %0, align 4
+  // CHECK-NEXT:    ret void
+  // CHECK-LABEL: define linkonce_odr <10 x float> @_Z3addIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* dereferenceable(40) %A, %struct.MyMatrix.1* dereferenceable(40) %B)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %A.addr = alloca %struct.MyMatrix.1*, align 8
+  // CHECK-NEXT:    %B.addr = alloca %struct.MyMatrix.1*, align 8
+  // CHECK-NEXT:    store %struct.MyMatrix.1* %A, %struct.MyMatrix.1** %A.addr, align 8
+  // CHECK-NEXT:    store %struct.MyMatrix.1* %B, %struct.MyMatrix.1** %B.addr, align 8
+  // CHECK-NEXT:    %0 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %A.addr, align 8
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %0, i32 0, i32 0
+  // CHECK-NEXT:    %1 = bitcast [10 x float]* %value to <10 x float>*
+  // CHECK-NEXT:    %2 = load <10 x float>, <10 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %B.addr, align 8
+  // CHECK-NEXT:    %value1 = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %3, i32 0, i32 0
+  // CHECK-NEXT:    %4 = bitcast [10 x float]* %value1 to <10 x float>*
+  // CHECK-NEXT:    %5 = load <10 x float>, <10 x float>* %4, align 4
+  // CHECK-NEXT:    %6 = fadd <10 x float> %2, %5
+  // CHECK-NEXT:    ret <10 x float> %6
+  MyMatrix<float, 2, 5> Mat1;
+  MyMatrix<float, 2, 5> Mat2;
+  Mat1.value = add(Mat1, Mat2);
+template <typename EltTy0, unsigned R0, unsigned C0>
+typename MyMatrix<EltTy0, R0, C0>::matrix_t subtract(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy0, R0, C0> &B) {
+  return A.value - B.value;
+void test_subtract_template() {
+  // CHECK-LABEL: define void @_Z22test_subtract_templatev()
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %Mat1 = alloca %struct.MyMatrix.1, align 4
+  // CHECK-NEXT:    %Mat2 = alloca %struct.MyMatrix.1, align 4
+  // CHECK-NEXT:    %call = call <10 x float> @_Z8subtractIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* dereferenceable(40) %Mat1, %struct.MyMatrix.1* dereferenceable(40) %Mat2)
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %Mat1, i32 0, i32 0
+  // CHECK-NEXT:    %0 = bitcast [10 x float]* %value to <10 x float>*
+  // CHECK-NEXT:    store <10 x float> %call, <10 x float>* %0, align 4
+  // CHECK-NEXT:    ret void
+  // CHECK-LABEL: define linkonce_odr <10 x float> @_Z8subtractIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* dereferenceable(40) %A, %struct.MyMatrix.1* dereferenceable(40) %B)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %A.addr = alloca %struct.MyMatrix.1*, align 8
+  // CHECK-NEXT:    %B.addr = alloca %struct.MyMatrix.1*, align 8
+  // CHECK-NEXT:    store %struct.MyMatrix.1* %A, %struct.MyMatrix.1** %A.addr, align 8
+  // CHECK-NEXT:    store %struct.MyMatrix.1* %B, %struct.MyMatrix.1** %B.addr, align 8
+  // CHECK-NEXT:    %0 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %A.addr, align 8
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %0, i32 0, i32 0
+  // CHECK-NEXT:    %1 = bitcast [10 x float]* %value to <10 x float>*
+  // CHECK-NEXT:    %2 = load <10 x float>, <10 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %B.addr, align 8
+  // CHECK-NEXT:    %value1 = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %3, i32 0, i32 0
+  // CHECK-NEXT:    %4 = bitcast [10 x float]* %value1 to <10 x float>*
+  // CHECK-NEXT:    %5 = load <10 x float>, <10 x float>* %4, align 4
+  // CHECK-NEXT:    %6 = fsub <10 x float> %2, %5
+  // CHECK-NEXT:    ret <10 x float> %6
+  MyMatrix<float, 2, 5> Mat1;
+  MyMatrix<float, 2, 5> Mat2;
+  Mat1.value = subtract(Mat1, Mat2);
Index: clang/test/CodeGen/matrix-type-operators.c
--- clang/test/CodeGen/matrix-type-operators.c
+++ clang/test/CodeGen/matrix-type-operators.c
@@ -155,3 +155,73 @@
   // CHECK-NEXT:    store i32 %matext2, i32* %v3, align 4
   // CHECK-NEXT:    ret void
+void add1(dx5x5_t a, dx5x5_t b, dx5x5_t c, ix9x3_t ai, ix9x3_t bi, ix9x3_t ci) {
+  a = b + c;
+  ai = bi + ci;
+  // CHECK-LABEL: @add1(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %b.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %c.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %ai.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %bi.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %ci.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %0 = bitcast [25 x double]* %a.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %a, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %b.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %b, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %2 = bitcast [25 x double]* %c.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %c, <25 x double>* %2, align 8
+  // CHECK-NEXT:    %3 = bitcast [27 x i32]* %ai.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %ai, <27 x i32>* %3, align 4
+  // CHECK-NEXT:    %4 = bitcast [27 x i32]* %bi.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %bi, <27 x i32>* %4, align 4
+  // CHECK-NEXT:    %5 = bitcast [27 x i32]* %ci.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %ci, <27 x i32>* %5, align 4
+  // CHECK-NEXT:    %6 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %7 = load <25 x double>, <25 x double>* %2, align 8
+  // CHECK-NEXT:    %8 = fadd <25 x double> %6, %7
+  // CHECK-NEXT:    store <25 x double> %8, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %9 = load <27 x i32>, <27 x i32>* %4, align 4
+  // CHECK-NEXT:    %10 = load <27 x i32>, <27 x i32>* %5, align 4
+  // CHECK-NEXT:    %11 = add <27 x i32> %9, %10
+  // CHECK-NEXT:    store <27 x i32> %11, <27 x i32>* %3, align 4
+  // CHECK-NEXT:    ret void
+void sub1(dx5x5_t a, dx5x5_t b, dx5x5_t c, ix9x3_t ai, ix9x3_t bi, ix9x3_t ci) {
+  a = b - c;
+  ai = bi - ci;
+  // CHECK-LABEL: @sub1(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %b.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %c.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %ai.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %bi.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %ci.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %0 = bitcast [25 x double]* %a.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %a, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %b.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %b, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %2 = bitcast [25 x double]* %c.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %c, <25 x double>* %2, align 8
+  // CHECK-NEXT:    %3 = bitcast [27 x i32]* %ai.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %ai, <27 x i32>* %3, align 4
+  // CHECK-NEXT:    %4 = bitcast [27 x i32]* %bi.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %bi, <27 x i32>* %4, align 4
+  // CHECK-NEXT:    %5 = bitcast [27 x i32]* %ci.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %ci, <27 x i32>* %5, align 4
+  // CHECK-NEXT:    %6 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %7 = load <25 x double>, <25 x double>* %2, align 8
+  // CHECK-NEXT:    %8 = fsub <25 x double> %6, %7
+  // CHECK-NEXT:    store <25 x double> %8, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %9 = load <27 x i32>, <27 x i32>* %4, align 4
+  // CHECK-NEXT:    %10 = load <27 x i32>, <27 x i32>* %5, align 4
+  // CHECK-NEXT:    %11 = sub <27 x i32> %9, %10
+  // CHECK-NEXT:    store <27 x i32> %11, <27 x i32>* %3, align 4
+  // CHECK-NEXT:    ret void
Index: clang/lib/Sema/SemaExpr.cpp
--- clang/lib/Sema/SemaExpr.cpp
+++ clang/lib/Sema/SemaExpr.cpp
@@ -9865,6 +9865,11 @@
     return compType;
+  if (LHS.get()->getType()->isMatrixType() ||
+      RHS.get()->getType()->isMatrixType()) {
+    return CheckMatrixElementwiseOperands(LHS, RHS, Loc);
+  }
   QualType compType = UsualArithmeticConversions(
       LHS, RHS, Loc, CompLHSTy ? ACK_CompAssign : ACK_Arithmetic);
   if (LHS.isInvalid() || RHS.isInvalid())
@@ -9960,6 +9965,11 @@
     return compType;
+  if (LHS.get()->getType()->isMatrixType() ||
+      RHS.get()->getType()->isMatrixType()) {
+    return CheckMatrixElementwiseOperands(LHS, RHS, Loc);
+  }
   QualType compType = UsualArithmeticConversions(
       LHS, RHS, Loc, CompLHSTy ? ACK_CompAssign : ACK_Arithmetic);
   if (LHS.isInvalid() || RHS.isInvalid())
@@ -11551,6 +11561,23 @@
   return GetSignedVectorType(LHS.get()->getType());
+QualType Sema::CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS,
+                                              SourceLocation Loc) {
+  // For conversion purposes, we ignore any qualifiers.
+  // For example, "const float" and "float" are equivalent.
+  QualType LHSType = LHS.get()->getType().getUnqualifiedType();
+  QualType RHSType = RHS.get()->getType().getUnqualifiedType();
+  // If the matrix types are identical, return.
+  if (Context.hasSameType(LHSType, RHSType))
+    return LHSType;
+  return InvalidOperands(Loc, LHS, RHS);
+  /*const MatrixType *LHSMatType = LHSType->getAs<MatrixType>();*/
+  // const MatrixType *RHSMatType = RHSType->getAs<MatrixType>();
+  // assert(LHSMatType || RHSMatType);
 inline QualType Sema::CheckBitwiseOperands(ExprResult &LHS, ExprResult &RHS,
                                            SourceLocation Loc,
                                            BinaryOperatorKind Opc) {
Index: clang/lib/CodeGen/CGExprScalar.cpp
--- clang/lib/CodeGen/CGExprScalar.cpp
+++ clang/lib/CodeGen/CGExprScalar.cpp
@@ -37,6 +37,7 @@
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsPowerPC.h"
+#include "llvm/IR/MatrixBuilder.h"
 #include "llvm/IR/Module.h"
 #include <cstdarg>
@@ -3469,6 +3470,11 @@
+  if (op.Ty->isMatrixType()) {
+    llvm::MatrixBuilder<CGBuilderTy> MB(Builder);
+    return MB.CreateAdd(op.LHS, op.RHS);
+  }
   if (op.Ty->isUnsignedIntegerType() &&
       CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) &&
       !CanElideOverflowCheck(CGF.getContext(), op))
@@ -3614,6 +3620,11 @@
+    if (op.Ty->isMatrixType()) {
+      llvm::MatrixBuilder<CGBuilderTy> MB(Builder);
+      return MB.CreateSub(op.LHS, op.RHS);
+    }
     if (op.Ty->isUnsignedIntegerType() &&
         CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) &&
         !CanElideOverflowCheck(CGF.getContext(), op))
Index: clang/include/clang/Sema/Sema.h
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -11074,6 +11074,10 @@
   QualType CheckVectorLogicalOperands(ExprResult &LHS, ExprResult &RHS,
                                       SourceLocation Loc);
+  /// Type checking for matrix binary operators.
+  QualType CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS,
+                                          SourceLocation Loc);
   bool areLaxCompatibleVectorTypes(QualType srcType, QualType destType);
   bool isLaxVectorConversion(QualType srcType, QualType destType);
cfe-commits mailing list
  • [PATCH] D76793: [Matrix] Impl... Florian Hahn via Phabricator via cfe-commits

Reply via email to