fhahn created this revision.
fhahn added reviewers: rjmccall, anemet, Bigcheese, rsmith, martong.
Herald added subscribers: tschuett, dexonsmith, rnkovacs.
Herald added a project: clang.

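This patch implements the + and - binary operators for values of
MatrixType. It adds support for matrix +/- matrix. Scalar +/- matrix and
matrix +/- scalar are not implemented yet and are still rejected as
invalid operands (see the TODOs in the Sema tests).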

For the matrix +/- matrix case, both operands must currently have the
same matrix type (qualifiers are ignored). For the planned scalar/matrix
variants, the element type of the matrix must match the scalar type.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D76793

Files:
  clang/include/clang/Sema/Sema.h
  clang/lib/CodeGen/CGExprScalar.cpp
  clang/lib/Sema/SemaExpr.cpp
  clang/test/CodeGen/matrix-type-operators.c
  clang/test/CodeGenCXX/matrix-type-operators.cpp
  clang/test/Sema/matrix-type-operators.c
  clang/test/SemaCXX/matrix-type-operators.cpp

Index: clang/test/SemaCXX/matrix-type-operators.cpp
===================================================================
--- clang/test/SemaCXX/matrix-type-operators.cpp
+++ clang/test/SemaCXX/matrix-type-operators.cpp
@@ -59,3 +59,66 @@
   float v11 = a[5][10.0];
   // expected-error@-1 {{matrix row index is outside the allowed range [0, 5)}}
 }
+
+template <typename EltTy, unsigned Rows, unsigned Columns>
+struct MyMatrix {
+  using matrix_t = EltTy __attribute__((matrix_type(Rows, Columns)));
+  // Rows x Columns matrix of EltTy, formed from the template arguments.
+  matrix_t value;
+};
+
+template <typename EltTy0, unsigned R0, unsigned C0, typename EltTy1, unsigned R1, unsigned C1, typename EltTy2, unsigned R2, unsigned C2>
+typename MyMatrix<EltTy2, R2, C2>::matrix_t add(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy1, R1, C1> &B) {
+  char *v1 = A.value + B.value;
+  // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))'))}}
+  // expected-error@-3 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}}
+  // The two mismatched-operand instantiations are diagnosed again here.
+  return A.value + B.value;
+  // expected-error@-1 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))'))}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}}
+}
+
+void test_add_template(unsigned *Ptr1, float *Ptr2) {
+  MyMatrix<unsigned, 2, 2> Mat1;
+  MyMatrix<unsigned, 3, 3> Mat2;
+  MyMatrix<float, 2, 2> Mat3;
+  Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1);
+  unsigned v1 = add<unsigned, 2, 2, unsigned, 2, 2, unsigned, 2, 2>(Mat1, Mat1);
+  // expected-error@-1 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix<unsigned int, 2U, 2U>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}}
+  // expected-note@-2 {{in instantiation of function template specialization 'add<unsigned int, 2, 2, unsigned int, 2, 2, unsigned int, 2, 2>' requested here}}
+  // Shape mismatch (unsigned 2x2 + unsigned 3x3), diagnosed inside add.
+  Mat1.value = add<unsigned, 2, 2, unsigned, 3, 3, unsigned, 2, 2>(Mat1, Mat2);
+  // expected-note@-1 {{in instantiation of function template specialization 'add<unsigned int, 2, 2, unsigned int, 3, 3, unsigned int, 2, 2>' requested here}}
+  // Element type and shape mismatch (unsigned 3x3 + float 2x2).
+  Mat1.value = add<unsigned, 3, 3, float, 2, 2, unsigned, 2, 2>(Mat2, Mat3);
+  // expected-note@-1 {{in instantiation of function template specialization 'add<unsigned int, 3, 3, float, 2, 2, unsigned int, 2, 2>' requested here}}
+}
+
+template <typename EltTy0, unsigned R0, unsigned C0, typename EltTy1, unsigned R1, unsigned C1, typename EltTy2, unsigned R2, unsigned C2>
+typename MyMatrix<EltTy2, R2, C2>::matrix_t subtract(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy1, R1, C1> &B) {
+  char *v1 = A.value - B.value;
+  // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))')}}
+  // expected-error@-3 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))')}}
+  // The two mismatched-operand instantiations are diagnosed again here.
+  return A.value - B.value;
+  // expected-error@-1 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))')}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))')}}
+}
+
+void test_subtract_template(unsigned *Ptr1, float *Ptr2) {
+  MyMatrix<unsigned, 2, 2> Mat1;
+  MyMatrix<unsigned, 3, 3> Mat2;
+  MyMatrix<float, 2, 2> Mat3;
+  Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1);
+  unsigned v1 = subtract<unsigned, 2, 2, unsigned, 2, 2, unsigned, 2, 2>(Mat1, Mat1);
+  // expected-error@-1 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix<unsigned int, 2U, 2U>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}}
+  // expected-note@-2 {{in instantiation of function template specialization 'subtract<unsigned int, 2, 2, unsigned int, 2, 2, unsigned int, 2, 2>' requested here}}
+  // Shape mismatch (unsigned 2x2 - unsigned 3x3), diagnosed inside subtract.
+  Mat1.value = subtract<unsigned, 2, 2, unsigned, 3, 3, unsigned, 2, 2>(Mat1, Mat2);
+  // expected-note@-1 {{in instantiation of function template specialization 'subtract<unsigned int, 2, 2, unsigned int, 3, 3, unsigned int, 2, 2>' requested here}}
+  // Element type and shape mismatch (unsigned 3x3 - float 2x2).
+  Mat1.value = subtract<unsigned, 3, 3, float, 2, 2, unsigned, 2, 2>(Mat2, Mat3);
+  // expected-note@-1 {{in instantiation of function template specialization 'subtract<unsigned int, 3, 3, float, 2, 2, unsigned int, 2, 2>' requested here}}
+}
Index: clang/test/Sema/matrix-type-operators.c
===================================================================
--- clang/test/Sema/matrix-type-operators.c
+++ clang/test/Sema/matrix-type-operators.c
@@ -65,3 +65,34 @@
   float v12 = a[3];
   // expected-error@-1 {{initializing 'float' with an expression of incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}}
 }
+
+typedef float sx10x5_t __attribute__((matrix_type(10, 5)));   // Transposed shape of sx5x10_t.
+typedef float sx10x10_t __attribute__((matrix_type(10, 10))); // Assignment target below.
+
+void add(sx10x10_t a, sx5x10_t b, sx10x5_t c) {
+  a = b + c;
+  // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t' (aka 'float __attribute__((matrix_type(10, 5)))'))}}
+
+  a = b + b; // expected-error {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}}
+
+  // TODO: Support scalar + matrix; until then it is rejected as below.
+  a = 10 + b;
+  // expected-error@-1 {{invalid operands to binary expression ('int' and 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))'))}}
+  // Matrix + pointer is rejected.
+  a = b + &c;
+  // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t *' (aka 'float  __attribute__((matrix_type(10, 5)))*'))}}
+}
+
+void sub(sx10x10_t a, sx5x10_t b, sx10x5_t c) {
+  a = b - c;
+  // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t' (aka 'float __attribute__((matrix_type(10, 5)))'))}}
+
+  a = b - b; // expected-error {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}}
+
+  // TODO: Support scalar - matrix; until then it is rejected as below.
+  a = 10 - b;
+  // expected-error@-1 {{invalid operands to binary expression ('int' and 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))'))}}
+  // Matrix - pointer is rejected.
+  a = b - &c;
+  // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t *' (aka 'float  __attribute__((matrix_type(10, 5)))*'))}}
+}
Index: clang/test/CodeGenCXX/matrix-type-operators.cpp
===================================================================
--- clang/test/CodeGenCXX/matrix-type-operators.cpp
+++ clang/test/CodeGenCXX/matrix-type-operators.cpp
@@ -209,3 +209,79 @@
   Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1);
   unsigned v1 = extract(Mat1);
 }
+
+template <typename EltTy0, unsigned R0, unsigned C0>
+typename MyMatrix<EltTy0, R0, C0>::matrix_t add(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy0, R0, C0> &B) {
+  return A.value + B.value; // Element-wise; emitted as fadd for float elements.
+}
+
+void test_add_template() {
+  // CHECK-LABEL:    define void @_Z17test_add_templatev()
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %Mat1 = alloca %struct.MyMatrix.1, align 4
+  // CHECK-NEXT:    %Mat2 = alloca %struct.MyMatrix.1, align 4
+  // CHECK-NEXT:    %call = call <10 x float> @_Z3addIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* dereferenceable(40) %Mat1, %struct.MyMatrix.1* dereferenceable(40) %Mat2)
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %Mat1, i32 0, i32 0
+  // CHECK-NEXT:    %0 = bitcast [10 x float]* %value to <10 x float>*
+  // CHECK-NEXT:    store <10 x float> %call, <10 x float>* %0, align 4
+  // CHECK-NEXT:    ret void
+  // The instantiated add<float, 2, 5> follows, with linkonce_odr linkage.
+  // CHECK-LABEL: define linkonce_odr <10 x float> @_Z3addIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* dereferenceable(40) %A, %struct.MyMatrix.1* dereferenceable(40) %B)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %A.addr = alloca %struct.MyMatrix.1*, align 8
+  // CHECK-NEXT:    %B.addr = alloca %struct.MyMatrix.1*, align 8
+  // CHECK-NEXT:    store %struct.MyMatrix.1* %A, %struct.MyMatrix.1** %A.addr, align 8
+  // CHECK-NEXT:    store %struct.MyMatrix.1* %B, %struct.MyMatrix.1** %B.addr, align 8
+  // CHECK-NEXT:    %0 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %A.addr, align 8
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %0, i32 0, i32 0
+  // CHECK-NEXT:    %1 = bitcast [10 x float]* %value to <10 x float>*
+  // CHECK-NEXT:    %2 = load <10 x float>, <10 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %B.addr, align 8
+  // CHECK-NEXT:    %value1 = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %3, i32 0, i32 0
+  // CHECK-NEXT:    %4 = bitcast [10 x float]* %value1 to <10 x float>*
+  // CHECK-NEXT:    %5 = load <10 x float>, <10 x float>* %4, align 4
+  // CHECK-NEXT:    %6 = fadd <10 x float> %2, %5
+  // CHECK-NEXT:    ret <10 x float> %6
+
+  MyMatrix<float, 2, 5> Mat1;
+  MyMatrix<float, 2, 5> Mat2;
+  Mat1.value = add(Mat1, Mat2);
+}
+
+template <typename EltTy0, unsigned R0, unsigned C0>
+typename MyMatrix<EltTy0, R0, C0>::matrix_t subtract(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy0, R0, C0> &B) {
+  return A.value - B.value; // Element-wise; emitted as fsub for float elements.
+}
+
+void test_subtract_template() {
+  // CHECK-LABEL: define void @_Z22test_subtract_templatev()
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %Mat1 = alloca %struct.MyMatrix.1, align 4
+  // CHECK-NEXT:    %Mat2 = alloca %struct.MyMatrix.1, align 4
+  // CHECK-NEXT:    %call = call <10 x float> @_Z8subtractIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* dereferenceable(40) %Mat1, %struct.MyMatrix.1* dereferenceable(40) %Mat2)
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %Mat1, i32 0, i32 0
+  // CHECK-NEXT:    %0 = bitcast [10 x float]* %value to <10 x float>*
+  // CHECK-NEXT:    store <10 x float> %call, <10 x float>* %0, align 4
+  // CHECK-NEXT:    ret void
+  // The instantiated subtract<float, 2, 5> follows, with linkonce_odr linkage.
+  // CHECK-LABEL: define linkonce_odr <10 x float> @_Z8subtractIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* dereferenceable(40) %A, %struct.MyMatrix.1* dereferenceable(40) %B)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %A.addr = alloca %struct.MyMatrix.1*, align 8
+  // CHECK-NEXT:    %B.addr = alloca %struct.MyMatrix.1*, align 8
+  // CHECK-NEXT:    store %struct.MyMatrix.1* %A, %struct.MyMatrix.1** %A.addr, align 8
+  // CHECK-NEXT:    store %struct.MyMatrix.1* %B, %struct.MyMatrix.1** %B.addr, align 8
+  // CHECK-NEXT:    %0 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %A.addr, align 8
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %0, i32 0, i32 0
+  // CHECK-NEXT:    %1 = bitcast [10 x float]* %value to <10 x float>*
+  // CHECK-NEXT:    %2 = load <10 x float>, <10 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %B.addr, align 8
+  // CHECK-NEXT:    %value1 = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %3, i32 0, i32 0
+  // CHECK-NEXT:    %4 = bitcast [10 x float]* %value1 to <10 x float>*
+  // CHECK-NEXT:    %5 = load <10 x float>, <10 x float>* %4, align 4
+  // CHECK-NEXT:    %6 = fsub <10 x float> %2, %5
+  // CHECK-NEXT:    ret <10 x float> %6
+
+  MyMatrix<float, 2, 5> Mat1;
+  MyMatrix<float, 2, 5> Mat2;
+  Mat1.value = subtract(Mat1, Mat2);
+}
Index: clang/test/CodeGen/matrix-type-operators.c
===================================================================
--- clang/test/CodeGen/matrix-type-operators.c
+++ clang/test/CodeGen/matrix-type-operators.c
@@ -155,3 +155,73 @@
   // CHECK-NEXT:    store i32 %matext2, i32* %v3, align 4
   // CHECK-NEXT:    ret void
 }
+
+void add1(dx5x5_t a, dx5x5_t b, dx5x5_t c, ix9x3_t ai, ix9x3_t bi, ix9x3_t ci) {
+  a = b + c;
+  ai = bi + ci;
+  // The double matrices lower to fadd, the i32 matrices to a plain add.
+  // CHECK-LABEL: @add1(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %b.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %c.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %ai.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %bi.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %ci.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %0 = bitcast [25 x double]* %a.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %a, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %b.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %b, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %2 = bitcast [25 x double]* %c.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %c, <25 x double>* %2, align 8
+  // CHECK-NEXT:    %3 = bitcast [27 x i32]* %ai.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %ai, <27 x i32>* %3, align 4
+  // CHECK-NEXT:    %4 = bitcast [27 x i32]* %bi.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %bi, <27 x i32>* %4, align 4
+  // CHECK-NEXT:    %5 = bitcast [27 x i32]* %ci.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %ci, <27 x i32>* %5, align 4
+  // CHECK-NEXT:    %6 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %7 = load <25 x double>, <25 x double>* %2, align 8
+  // CHECK-NEXT:    %8 = fadd <25 x double> %6, %7
+  // CHECK-NEXT:    store <25 x double> %8, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %9 = load <27 x i32>, <27 x i32>* %4, align 4
+  // CHECK-NEXT:    %10 = load <27 x i32>, <27 x i32>* %5, align 4
+  // CHECK-NEXT:    %11 = add <27 x i32> %9, %10
+  // CHECK-NEXT:    store <27 x i32> %11, <27 x i32>* %3, align 4
+  // CHECK-NEXT:    ret void
+}
+
+void sub1(dx5x5_t a, dx5x5_t b, dx5x5_t c, ix9x3_t ai, ix9x3_t bi, ix9x3_t ci) {
+  a = b - c;
+  ai = bi - ci;
+  // The double matrices lower to fsub, the i32 matrices to a plain sub.
+  // CHECK-LABEL: @sub1(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %b.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %c.addr = alloca [25 x double], align 8
+  // CHECK-NEXT:    %ai.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %bi.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %ci.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT:    %0 = bitcast [25 x double]* %a.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %a, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %b.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %b, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %2 = bitcast [25 x double]* %c.addr to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %c, <25 x double>* %2, align 8
+  // CHECK-NEXT:    %3 = bitcast [27 x i32]* %ai.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %ai, <27 x i32>* %3, align 4
+  // CHECK-NEXT:    %4 = bitcast [27 x i32]* %bi.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %bi, <27 x i32>* %4, align 4
+  // CHECK-NEXT:    %5 = bitcast [27 x i32]* %ci.addr to <27 x i32>*
+  // CHECK-NEXT:    store <27 x i32> %ci, <27 x i32>* %5, align 4
+  // CHECK-NEXT:    %6 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %7 = load <25 x double>, <25 x double>* %2, align 8
+  // CHECK-NEXT:    %8 = fsub <25 x double> %6, %7
+  // CHECK-NEXT:    store <25 x double> %8, <25 x double>* %0, align 8
+  // CHECK-NEXT:    %9 = load <27 x i32>, <27 x i32>* %4, align 4
+  // CHECK-NEXT:    %10 = load <27 x i32>, <27 x i32>* %5, align 4
+  // CHECK-NEXT:    %11 = sub <27 x i32> %9, %10
+  // CHECK-NEXT:    store <27 x i32> %11, <27 x i32>* %3, align 4
+  // CHECK-NEXT:    ret void
+}
Index: clang/lib/Sema/SemaExpr.cpp
===================================================================
--- clang/lib/Sema/SemaExpr.cpp
+++ clang/lib/Sema/SemaExpr.cpp
@@ -9865,6 +9865,16 @@
     return compType;
   }
 
+  if (LHS.get()->getType()->isMatrixType() ||
+      RHS.get()->getType()->isMatrixType()) {
+    // CompLHSTy is non-null for compound assignment (see ACK_CompAssign
+    // below); '+=' needs the computation type recorded as well.
+    QualType compType = CheckMatrixElementwiseOperands(LHS, RHS, Loc);
+    if (CompLHSTy)
+      *CompLHSTy = compType;
+    return compType;
+  }
+
   QualType compType = UsualArithmeticConversions(
       LHS, RHS, Loc, CompLHSTy ? ACK_CompAssign : ACK_Arithmetic);
   if (LHS.isInvalid() || RHS.isInvalid())
@@ -9960,6 +9965,16 @@
     return compType;
   }
 
+  if (LHS.get()->getType()->isMatrixType() ||
+      RHS.get()->getType()->isMatrixType()) {
+    // CompLHSTy is non-null for compound assignment (see ACK_CompAssign
+    // below); '-=' needs the computation type recorded as well.
+    QualType compType = CheckMatrixElementwiseOperands(LHS, RHS, Loc);
+    if (CompLHSTy)
+      *CompLHSTy = compType;
+    return compType;
+  }
+
   QualType compType = UsualArithmeticConversions(
       LHS, RHS, Loc, CompLHSTy ? ACK_CompAssign : ACK_Arithmetic);
   if (LHS.isInvalid() || RHS.isInvalid())
@@ -11551,6 +11561,26 @@
   return GetSignedVectorType(LHS.get()->getType());
 }
 
+// Returns the common matrix type of LHS and RHS, or diagnoses invalid operands.
+QualType Sema::CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS,
+                                              SourceLocation Loc) {
+  // For conversion purposes, we ignore any qualifiers.
+  // For example, "const float" and "float" are equivalent.
+  QualType LHSType = LHS.get()->getType().getUnqualifiedType();
+  QualType RHSType = RHS.get()->getType().getUnqualifiedType();
+  assert((LHSType->isMatrixType() || RHSType->isMatrixType()) &&
+         "At least one operand must be a matrix");
+
+  // Matrix +/- matrix: both operands must have the same matrix type.
+  if (Context.hasSameType(LHSType, RHSType))
+    return LHSType;
+
+  // TODO: Support scalar +/- matrix and matrix +/- scalar. Also consider
+  // applying lvalue-to-rvalue conversions to the operands (except the LHS
+  // of a compound assignment).
+  return InvalidOperands(Loc, LHS, RHS);
+}
+
 inline QualType Sema::CheckBitwiseOperands(ExprResult &LHS, ExprResult &RHS,
                                            SourceLocation Loc,
                                            BinaryOperatorKind Opc) {
Index: clang/lib/CodeGen/CGExprScalar.cpp
===================================================================
--- clang/lib/CodeGen/CGExprScalar.cpp
+++ clang/lib/CodeGen/CGExprScalar.cpp
@@ -37,6 +37,7 @@
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsPowerPC.h"
+#include "llvm/IR/MatrixBuilder.h"
 #include "llvm/IR/Module.h"
 #include <cstdarg>
 
@@ -3469,6 +3470,11 @@
     }
   }
 
+  if (op.Ty->isMatrixType()) {
+    llvm::MatrixBuilder<CGBuilderTy> MB(Builder);
+    return MB.CreateAdd(op.LHS, op.RHS);
+  }
+
   if (op.Ty->isUnsignedIntegerType() &&
       CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) &&
       !CanElideOverflowCheck(CGF.getContext(), op))
@@ -3614,6 +3620,11 @@
       }
     }
 
+    if (op.Ty->isMatrixType()) {
+      llvm::MatrixBuilder<CGBuilderTy> MB(Builder);
+      return MB.CreateSub(op.LHS, op.RHS);
+    }
+
     if (op.Ty->isUnsignedIntegerType() &&
         CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) &&
         !CanElideOverflowCheck(CGF.getContext(), op))
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -11074,6 +11074,10 @@
   QualType CheckVectorLogicalOperands(ExprResult &LHS, ExprResult &RHS,
                                       SourceLocation Loc);
 
+  /// Type checking for element-wise matrix binary operators (+ and -).
+  QualType CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS,
+                                          SourceLocation Loc);
+
   bool areLaxCompatibleVectorTypes(QualType srcType, QualType destType);
   bool isLaxVectorConversion(QualType srcType, QualType destType);
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to