fhahn updated this revision to Diff 269249.
fhahn added a comment.

Ping.

Applied the feedback from D72778 <https://reviews.llvm.org/D72778> to this patch,
improved the tests, and added support for conversions/placeholders.

One thing I am not sure about is how to properly handle template substitutions
for the pointer expression in code like the one below, where substitutions need
to be applied to get the actual pointer type. Currently the patch looks through
SubstTemplateTypeParmType types in Sema to construct the result type. Should we
also look through SubstTemplateTypeParmType in IRGen to decide whether to call
EmitPointerWithAlignment or EmitArrayToPointerDecay? Or is there a place in
Sema that should get rid of the substitution (perhaps in SemaChecking.cpp)?

  template <typename T> struct remove_pointer {
  typedef T type;
  };
  
  template <typename T> struct remove_pointer<T *>{
  typedef typename remove_pointer<T>::type type;
  };
  
  // Same as column_major_load_with_stride, but with the PtrT argument itself 
begin a pointer type.
  template <typename PtrT, unsigned R, unsigned C, unsigned S>
  matrix_t<typename remove_pointer<PtrT>::type, R, C> 
column_major_load_with_stride2(PtrT Ptr) {
  return __builtin_matrix_column_major_load(Ptr, R, C, S);
  }
  
  void call_column_major_load_with_stride2(float *Ptr) {
  matrix_t<float, 2, 2> m = column_major_load_with_stride2<float *, 2, 2, 
2>(Ptr);
  }


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D72781/new/

https://reviews.llvm.org/D72781

Files:
  clang/include/clang/AST/Type.h
  clang/include/clang/Basic/Builtins.def
  clang/include/clang/Basic/DiagnosticSemaKinds.td
  clang/include/clang/Sema/Sema.h
  clang/lib/CodeGen/CGBuiltin.cpp
  clang/lib/Sema/SemaChecking.cpp
  clang/lib/Sema/SemaExpr.cpp
  clang/test/CodeGen/matrix-type-builtins.c
  clang/test/CodeGenCXX/matrix-type-builtins.cpp
  clang/test/CodeGenObjC/matrix-type-builtins.m
  clang/test/Sema/matrix-type-builtins.c
  clang/test/SemaCXX/matrix-type-builtins.cpp

Index: clang/test/SemaCXX/matrix-type-builtins.cpp
===================================================================
--- clang/test/SemaCXX/matrix-type-builtins.cpp
+++ clang/test/SemaCXX/matrix-type-builtins.cpp
@@ -39,3 +39,65 @@
   Mat3.value = transpose<unsigned, 3, 3, float, 3, 3>(Mat2);
   // expected-note@-1 {{in instantiation of function template specialization 'transpose<unsigned int, 3, 3, float, 3, 3>' requested here}}
 }
+
+template <typename EltTy0, unsigned R0, unsigned C0, typename EltTy1, unsigned R1, unsigned C1>
+typename MyMatrix<EltTy1, R1, C1>::matrix_t column_major_load(MyMatrix<EltTy0, R0, C0> &A, EltTy0 *Ptr) {
+  char *v1 = __builtin_matrix_column_major_load(Ptr, 9, 4, 10);
+  // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'unsigned int __attribute__((matrix_type(9, 4)))'}}
+  // expected-error@-2 {{cannot initialize a variable of type 'char *' with an rvalue of type 'unsigned int __attribute__((matrix_type(9, 4)))'}}
+  // expected-error@-3 {{cannot initialize a variable of type 'char *' with an rvalue of type 'float __attribute__((matrix_type(9, 4)))'}}
+
+  return __builtin_matrix_column_major_load(Ptr, R0, C0, R0);
+  // expected-error@-1 {{cannot initialize return object of type 'typename MyMatrix<unsigned int, 5U, 5U>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(5, 5)))') with an rvalue of type 'unsigned int __attribute__((matrix_type(2, 3)))'}}
+  // expected-error@-2 {{cannot initialize return object of type 'typename MyMatrix<unsigned int, 2U, 3U>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 3)))') with an rvalue of type 'float __attribute__((matrix_type(2, 3)))'}}
+}
+
+void test_column_major_loads_template(unsigned *Ptr1, float *Ptr2) {
+  MyMatrix<unsigned, 2, 3> Mat1;
+  Mat1.value = column_major_load<unsigned, 2, 3, unsigned, 2, 3>(Mat1, Ptr1);
+  // expected-note@-1 {{in instantiation of function template specialization 'column_major_load<unsigned int, 2, 3, unsigned int, 2, 3>' requested here}}
+  column_major_load<unsigned, 2, 3, unsigned, 5, 5>(Mat1, Ptr1);
+  // expected-note@-1 {{in instantiation of function template specialization 'column_major_load<unsigned int, 2, 3, unsigned int, 5, 5>' requested here}}
+
+  MyMatrix<float, 2, 3> Mat2;
+  Mat1.value = column_major_load<float, 2, 3, unsigned, 2, 3>(Mat2, Ptr2);
+  // expected-note@-1 {{in instantiation of function template specialization 'column_major_load<float, 2, 3, unsigned int, 2, 3>' requested here}}
+}
+
+constexpr int constexpr1() { return 1; }
+constexpr int constexpr_neg1() { return -1; }
+
+void test_column_major_load_constexpr(unsigned *Ptr) {
+  (void)__builtin_matrix_column_major_load(Ptr, 2, 2, constexpr1());
+  // expected-error@-1 {{stride must be greater or equal to the number of rows}}
+  (void)__builtin_matrix_column_major_load(Ptr, constexpr_neg1(), 2, 4);
+  // expected-error@-1 {{row dimension is outside the allowed range [1, 1048575]}}
+  (void)__builtin_matrix_column_major_load(Ptr, 2, constexpr_neg1(), 4);
+  // expected-error@-1 {{column dimension is outside the allowed range [1, 1048575]}}
+}
+
+struct IntWrapper {
+  operator int() {
+    return 1;
+  }
+};
+
+void test_column_major_load_wrapper(unsigned *Ptr, IntWrapper &W) {
+  (void)__builtin_matrix_column_major_load(Ptr, W, 2, 2);
+  // expected-error@-1 {{row argument must be a constant unsigned integer expression}}
+  (void)__builtin_matrix_column_major_load(Ptr, 2, W, 2);
+  // expected-error@-1 {{column argument must be a constant unsigned integer expression}}
+}
+
+template <typename T, unsigned R, unsigned C, unsigned S>
+void test_column_major_load_temp(T Ptr) {
+  (void)__builtin_matrix_column_major_load(Ptr, R, C, S);
+}
+
+void call_column_major_load_temp(unsigned *Ptr, unsigned X) {
+  (void)__builtin_matrix_column_major_load(Ptr, X, X, X);
+  // expected-error@-1 {{row argument must be a constant unsigned integer expression}}
+  // expected-error@-2 {{column argument must be a constant unsigned integer expression}}
+  (void)__builtin_matrix_column_major_load(X, 2, 2, 2);
+  // expected-error@-1 {{first argument must be a pointer to a valid matrix element type}}
+}
Index: clang/test/Sema/matrix-type-builtins.c
===================================================================
--- clang/test/Sema/matrix-type-builtins.c
+++ clang/test/Sema/matrix-type-builtins.c
@@ -20,3 +20,49 @@
   ix3x3 m = __builtin_matrix_transpose(c);
   // expected-error@-1 {{initializing 'ix3x3' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') with an expression of incompatible type 'double __attribute__((matrix_type(3, 3)))'}}
 }
+
+struct Foo {
+  unsigned x;
+};
+
+void column_major_load(float *p1, int *p2, _Bool *p3, struct Foo *p4) {
+  sx5x10_t a1 = __builtin_matrix_column_major_load(p1, 5, 11, 5);
+  // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'float __attribute__((matrix_type(5, 11)))'}}
+  sx5x10_t a2 = __builtin_matrix_column_major_load(p1, 5, 9, 5);
+  // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'float __attribute__((matrix_type(5, 9)))'}}
+  sx5x10_t a3 = __builtin_matrix_column_major_load(p1, 6, 10, 6);
+  // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'float __attribute__((matrix_type(6, 10)))'}}
+  sx5x10_t a4 = __builtin_matrix_column_major_load(p1, 4, 10, 4);
+  // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'float __attribute__((matrix_type(4, 10)))'}}
+  sx5x10_t a5 = __builtin_matrix_column_major_load(p1, 6, 9, 6);
+  // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'float __attribute__((matrix_type(6, 9)))'}}
+  sx5x10_t a6 = __builtin_matrix_column_major_load(p2, 5, 10, 6);
+  // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'int __attribute__((matrix_type(5, 10)))'}}
+
+  sx5x10_t a7 = __builtin_matrix_column_major_load(p1, 5, 10, 3);
+  // expected-error@-1 {{stride must be greater or equal to the number of rows}}
+
+  sx5x10_t a8 = __builtin_matrix_column_major_load(p3, 5, 10, 6);
+  // expected-error@-1 {{first argument must be a pointer to a valid matrix element type}}
+
+  sx5x10_t a9 = __builtin_matrix_column_major_load(p4, 5, 10, 6);
+  // expected-error@-1 {{first argument must be a pointer to a valid matrix element type}}
+
+  sx5x10_t a10 = __builtin_matrix_column_major_load(p1, 1ull << 21, 10, 6);
+  // expected-error@-1 {{row dimension is outside the allowed range [1, 1048575}}
+  sx5x10_t a11 = __builtin_matrix_column_major_load(p1, 10, 1ull << 21, 10);
+  // expected-error@-1 {{column dimension is outside the allowed range [1, 1048575}}
+
+  sx5x10_t a12 = __builtin_matrix_column_major_load(
+      10,         // expected-error {{first argument must be a pointer to a valid matrix element type}}
+      1ull << 21, // expected-error {{row dimension is outside the allowed range [1, 1048575]}}
+      1ull << 21, // expected-error {{column dimension is outside the allowed range [1, 1048575]}}
+      "");        // expected-warning {{incompatible pointer to integer conversion casting 'char [1]' to type 'unsigned long'}}
+
+  sx5x10_t a13 = __builtin_matrix_column_major_load(
+      10,  // expected-error {{first argument must be a pointer to a valid matrix element type}}
+      *p4, // expected-error {{casting 'struct Foo' to incompatible type 'unsigned long'}}
+      "",  // expected-error {{column argument must be a constant unsigned integer expression}}
+           // expected-warning@-1 {{incompatible pointer to integer conversion casting 'char [1]' to type 'unsigned long'}}
+      10);
+}
Index: clang/test/CodeGenObjC/matrix-type-builtins.m
===================================================================
--- clang/test/CodeGenObjC/matrix-type-builtins.m
+++ clang/test/CodeGenObjC/matrix-type-builtins.m
@@ -40,3 +40,23 @@
 
   m.value = __builtin_matrix_transpose(*r);
 }
+
+__attribute__((objc_root_class))
+@interface PtrValue
+@property unsigned *value;
+@end
+
+__attribute__((objc_root_class))
+@interface IntValue
+@property int value;
+@end
+
+void test_column_major_load(PtrValue *Ptr, IntValue *Stride) {
+  // CHECK-LABEL: define void @test_column_major_load(%2* %Ptr, %3* %Stride) #4 {
+  // CHECK:         [[STRIDE:%.*]] = call i32 bitcast (i8* (i8*, i8*, ...)* @objc_msgSend to i32 (i8*, i8*)*)
+  // CHECK-NEXT:    [[STRIDE_EXT:%.*]] = sext i32 [[STRIDE]] to i64
+  // CHECK:         [[PTR:%.*]] = call i32* bitcast (i8* (i8*, i8*, ...)* @objc_msgSend to i32* (i8*, i8*)*)
+  // CHECK-NEXT:    call <12 x i32> @llvm.matrix.columnwise.load.v12i32.p0i32(i32* [[PTR]], i64 [[STRIDE_EXT]], i64 3, i64 4)
+
+  u3x4 m = __builtin_matrix_column_major_load(Ptr.value, 3, 4, Stride.value);
+}
Index: clang/test/CodeGenCXX/matrix-type-builtins.cpp
===================================================================
--- clang/test/CodeGenCXX/matrix-type-builtins.cpp
+++ clang/test/CodeGenCXX/matrix-type-builtins.cpp
@@ -74,3 +74,107 @@
   // CHECK-NEXT:    store <9 x float> [[M_T]], <9 x float>* [[M_T_ADDR]], align 4
   matrix_t<float, 3, 3> m_t = __builtin_matrix_transpose(m);
 }
+
+template <typename T, unsigned R, unsigned C, unsigned S>
+matrix_t<T, R, C> column_major_load_with_stride(T *Ptr) {
+  return __builtin_matrix_column_major_load(Ptr, R, C, S);
+}
+
+void test_column_major_load_with_stride_template_double(double *Ptr) {
+  // CHECK-LABEL: define void @_Z50test_column_major_load_with_stride_template_doublePd(double* %Ptr)
+  // CHECK:         [[PTR:%.*]] = load double*, double** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <40 x double> @_Z29column_major_load_with_strideIdLj10ELj4ELj15EEU11matrix_typeXT0_EXT1_ET_PS0_(double* [[PTR]])
+
+  // CHECK-LABEL:  define linkonce_odr <40 x double> @_Z29column_major_load_with_strideIdLj10ELj4ELj15EEU11matrix_typeXT0_EXT1_ET_PS0_(double* %Ptr)
+  // CHECK:         [[PTR:%.*]] = load double*, double** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <40 x double> @llvm.matrix.columnwise.load.v40f64.p0f64(double* [[PTR]], i64 15, i64 10, i64 4)
+
+  matrix_t<double, 10, 4> M1 = column_major_load_with_stride<double, 10, 4, 15>(Ptr);
+}
+
+void test_column_major_load_with_stride_template_int(int *Ptr) {
+  // CHECK-LABEL: define void @_Z47test_column_major_load_with_stride_template_intPi(i32* %Ptr) #5 {
+  // CHECK:         [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <6 x i32> @_Z29column_major_load_with_strideIiLj3ELj2ELj12EEU11matrix_typeXT0_EXT1_ET_PS0_(i32* [[PTR]])
+
+  // CHECK-LABEL: define linkonce_odr <6 x i32> @_Z29column_major_load_with_strideIiLj3ELj2ELj12EEU11matrix_typeXT0_EXT1_ET_PS0_(i32* %Ptr)
+  // CHECK:         [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <6 x i32> @llvm.matrix.columnwise.load.v6i32.p0i32(i32* [[PTR]], i64 12, i64 3, i64 2)
+
+  matrix_t<int, 3, 2> M1 = column_major_load_with_stride<int, 3, 2, 12>(Ptr);
+}
+
+struct UnsignedWrapper {
+  char x;
+  operator unsigned() {
+    return x;
+  }
+};
+
+void test_column_major_load_stride_wrapper(int *Ptr, UnsignedWrapper &W) {
+  // CHECK-LABEL:  define void @_Z37test_column_major_load_stride_wrapperPiR15UnsignedWrapper(i32* %Ptr, %struct.UnsignedWrapper* nonnull align 1 dereferenceable(1) %W)
+  // CHECK:         [[W:%.*]] = load %struct.UnsignedWrapper*, %struct.UnsignedWrapper** %W.addr, align 8
+  // CHECK-NEXT:    [[STRIDE:%.*]] = call i32 @_ZN15UnsignedWrappercvjEv(%struct.UnsignedWrapper* [[W]])
+  // CHECK-NEXT:    [[STRIDE_EXT:%.*]] = zext i32 [[STRIDE]] to i64
+  // CHECK-NEXT:    [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <4 x i32> @llvm.matrix.columnwise.load.v4i32.p0i32(i32* [[PTR]], i64 [[STRIDE_EXT]], i64 2, i64 2)
+  matrix_t<int, 2, 2> M1 = __builtin_matrix_column_major_load(Ptr, 2, 2, W);
+}
+
+constexpr int constexpr3() { return 3; }
+
+void test_column_major_load_constexpr_num_rows(int *Ptr) {
+  // CHECK-LABEL: define void @_Z41test_column_major_load_constexpr_num_rowsPi(i32* %Ptr)
+  // CHECK:         [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <6 x i32> @llvm.matrix.columnwise.load.v6i32.p0i32(i32* [[PTR]], i64 3, i64 3, i64 2)
+
+  matrix_t<int, 3, 2> M1 = __builtin_matrix_column_major_load(Ptr, constexpr3(), 2, 3);
+}
+
+constexpr int constexpr1() { return 1; }
+
+void test_column_major_load_constexpr_num_columns(int *Ptr) {
+  // CHECK-LABEL: define void @_Z44test_column_major_load_constexpr_num_columnsPi(i32* %Ptr)
+  // CHECK:         [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <2 x i32> @llvm.matrix.columnwise.load.v2i32.p0i32(i32* [[PTR]], i64 3, i64 2, i64 1)
+  matrix_t<int, 2, 1> M1 = __builtin_matrix_column_major_load(Ptr, 2, constexpr1(), 3);
+}
+
+template <unsigned N>
+constexpr int constexpr_plus1() { return N + 1; }
+
+void test_column_major_load_constexpr_num_columns_temp(int *Ptr) {
+  // CHECK-LABEL:  define void @_Z49test_column_major_load_constexpr_num_columns_tempPi(i32* %Ptr)
+  // CHECK:         [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <10 x i32> @llvm.matrix.columnwise.load.v10i32.p0i32(i32* [[PTR]], i64 3, i64 2, i64 5)
+  matrix_t<int, 2, 5> M1 = __builtin_matrix_column_major_load(Ptr, 2, constexpr_plus1<4>(), 3);
+}
+
+void test_column_major_load_constexpr_stride_constexpr(int *Ptr) {
+  // CHECK-LABEL: define void @_Z49test_column_major_load_constexpr_stride_constexprPi(i32* %Ptr)
+  // CHECK:         [[STRIDE:%.*]] = call i32 @_Z10constexpr3v()
+  // CHECK-NEXT:    [[STRIDE_EXT:%.*]] = sext i32 [[STRIDE]] to i64
+  // CHECK-NEXT:    [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <4 x i32> @llvm.matrix.columnwise.load.v4i32.p0i32(i32* [[PTR]], i64 [[STRIDE_EXT]], i64 2, i64 2)
+
+  matrix_t<int, 2, 2> M1 = __builtin_matrix_column_major_load(Ptr, 2, 2, constexpr3());
+}
+
+// TODO: Enable once template substitutions in the pointer argument's type are handled.
+//template <typename T> struct remove_pointer {
+//  typedef T type;
+//};
+
+//template <typename T> struct remove_pointer<T *> {
+//  typedef typename remove_pointer<T>::type type;
+//};
+
+//// Same as column_major_load_with_stride, but with the PtrT argument itself being a pointer type.
+//template <typename PtrT, unsigned R, unsigned C, unsigned S>
+//matrix_t<typename remove_pointer<PtrT>::type, R, C> column_major_load_with_stride2(PtrT Ptr) {
+//  return __builtin_matrix_column_major_load(Ptr, R, C, S);
+//}
+
+//void call_column_major_load_with_stride2(float *Ptr) {
+//  matrix_t<float, 2, 2> m = column_major_load_with_stride2<float *, 2, 2, 2>(Ptr);
+//}
Index: clang/test/CodeGen/matrix-type-builtins.c
===================================================================
--- clang/test/CodeGen/matrix-type-builtins.c
+++ clang/test/CodeGen/matrix-type-builtins.c
@@ -96,3 +96,50 @@
 
   dx5x5_t m_t = __builtin_matrix_transpose(global_matrix);
 }
+
+void column_major_load_with_const_stride_double(double *Ptr) {
+  // CHECK-LABEL: define void @column_major_load_with_const_stride_double(double* %Ptr)
+  // CHECK:         [[PTR:%.*]] = load double*, double** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <25 x double> @llvm.matrix.columnwise.load.v25f64.p0f64(double* [[PTR]], i64 5, i64 5, i64 5)
+
+  dx5x5_t m_a1 = __builtin_matrix_column_major_load(Ptr, 5, 5, 5);
+}
+
+void column_major_load_with_const_stride2_double(double *Ptr) {
+  // CHECK-LABEL: define void @column_major_load_with_const_stride2_double(double* %Ptr)
+  // CHECK:         [[PTR:%.*]] = load double*, double** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <25 x double> @llvm.matrix.columnwise.load.v25f64.p0f64(double* [[PTR]], i64 15, i64 5, i64 5)
+
+  dx5x5_t m_a2 = __builtin_matrix_column_major_load(Ptr, 5, 5, 2 * 3 + 9);
+}
+
+void column_major_load_with_variable_stride_ull_float(float *Ptr, unsigned long long S) {
+  // CHECK-LABEL: define void @column_major_load_with_variable_stride_ull_float(float* %Ptr, i64 %S)
+  // CHECK:         [[S:%.*]] = load i64, i64* %S.addr, align 8
+  // CHECK-NEXT:    [[PTR:%.*]] = load float*, float** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <6 x float> @llvm.matrix.columnwise.load.v6f32.p0f32(float* [[PTR]], i64 [[S]], i64 2, i64 3)
+
+  fx2x3_t m_b = __builtin_matrix_column_major_load(Ptr, 2, 3, S);
+}
+
+void column_major_load_with_stride_math_int(int *Ptr, int S) {
+  // CHECK-LABEL: define void @column_major_load_with_stride_math_int(i32* %Ptr, i32 %S)
+  // CHECK:         [[S:%.*]] = load i32, i32* %S.addr, align 4
+  // CHECK-NEXT:    [[STRIDE:%.*]] = add nsw i32 [[S]], 32
+  // CHECK-NEXT:    [[STRIDE_EXT:%.*]] = sext i32 [[STRIDE]] to i64
+  // CHECK-NEXT:    [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <80 x i32> @llvm.matrix.columnwise.load.v80i32.p0i32(i32* [[PTR]], i64 [[STRIDE_EXT]], i64 4, i64 20)
+
+  ix4x20_t m_c = __builtin_matrix_column_major_load(Ptr, 4, 20, S + 32);
+}
+
+void column_major_load_with_stride_math_s_int(int *Ptr, short S) {
+  // CHECK-LABEL:  define void @column_major_load_with_stride_math_s_int(i32* %Ptr, i16 signext %S)
+  // CHECK:         [[S:%.*]] = load i16, i16* %S.addr, align 2
+  // CHECK-NEXT:    [[S_EXT:%.*]] = sext i16 [[S]] to i32
+  // CHECK-NEXT:    [[STRIDE:%.*]] = add nsw i32 [[S_EXT]], 32
+  // CHECK-NEXT:    [[STRIDE_EXT:%.*]] = sext i32 [[STRIDE]] to i64
+  // CHECK-NEXT:    [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8
+  // CHECK-NEXT:    %matrix = call <80 x i32> @llvm.matrix.columnwise.load.v80i32.p0i32(i32* [[PTR]], i64 [[STRIDE_EXT]], i64 4, i64 20)
+  ix4x20_t m_c = __builtin_matrix_column_major_load(Ptr, 4, 20, S + 32);
+}
Index: clang/lib/Sema/SemaExpr.cpp
===================================================================
--- clang/lib/Sema/SemaExpr.cpp
+++ clang/lib/Sema/SemaExpr.cpp
@@ -4674,15 +4674,12 @@
   return Res;
 }
 
-static bool tryConvertToTy(Sema &S, QualType ElementType, ExprResult *Scalar) {
-  InitializedEntity Entity =
-      InitializedEntity::InitializeTemporary(ElementType);
-  InitializationKind Kind = InitializationKind::CreateCopy(
-      Scalar->get()->getBeginLoc(), SourceLocation());
-  Expr *Arg = Scalar->get();
-  InitializationSequence InitSeq(S, Entity, Kind, Arg);
-  *Scalar = InitSeq.Perform(S, Entity, Kind, Arg);
-  return !Scalar->isInvalid();
+ExprResult Sema::tryConvertExprToTy(Expr *E, QualType Ty) {
+  InitializedEntity Entity = InitializedEntity::InitializeTemporary(Ty);
+  InitializationKind Kind =
+      InitializationKind::CreateCopy(E->getBeginLoc(), SourceLocation());
+  InitializationSequence InitSeq(*this, Entity, Kind, E);
+  return InitSeq.Perform(*this, Entity, Kind, E);
 }
 
 ExprResult Sema::CreateBuiltinMatrixSubscriptExpr(Expr *Base, Expr *RowIdx,
@@ -4733,11 +4730,9 @@
       return nullptr;
     }
 
-    ExprResult ConvExpr = IndexExpr;
-    bool ConversionOk = tryConvertToTy(*this, Context.getSizeType(), &ConvExpr);
-    assert(ConversionOk &&
+    ExprResult ConvExpr = tryConvertExprToTy(IndexExpr, Context.getSizeType());
+    assert(!ConvExpr.isInvalid() &&
            "should be able to convert any integer type to size type");
-    (void)ConversionOk;
     return ConvExpr.get();
   };
 
@@ -12109,13 +12104,16 @@
   ExprResult OriginalLHS = LHS;
   ExprResult OriginalRHS = RHS;
   if (LHSMatType && !RHSMatType) {
-    if (tryConvertToTy(*this, LHSMatType->getElementType(), &RHS))
+    RHS = tryConvertExprToTy(RHS.get(), LHSMatType->getElementType());
+    if (!RHS.isInvalid())
       return LHSType;
+
     return InvalidOperands(Loc, OriginalLHS, OriginalRHS);
   }
 
   if (!LHSMatType && RHSMatType) {
-    if (tryConvertToTy(*this, RHSMatType->getElementType(), &LHS))
+    LHS = tryConvertExprToTy(LHS.get(), RHSMatType->getElementType());
+    if (!LHS.isInvalid())
       return RHSType;
     return InvalidOperands(Loc, OriginalLHS, OriginalRHS);
   }
Index: clang/lib/Sema/SemaChecking.cpp
===================================================================
--- clang/lib/Sema/SemaChecking.cpp
+++ clang/lib/Sema/SemaChecking.cpp
@@ -1915,6 +1915,9 @@
 
   case Builtin::BI__builtin_matrix_transpose:
     return SemaBuiltinMatrixTransposeOverload(TheCall, TheCallResult);
+
+  case Builtin::BI__builtin_matrix_column_major_load:
+    return SemaBuiltinMatrixColumnMajorLoadOverload(TheCall, TheCallResult);
   }
 
   // Since the target specific builtins for each arch overlap, only check those
@@ -15066,3 +15069,141 @@
   TheCall->setArg(0, Matrix);
   return CallResult;
 }
+
+/// Evaluate \p E as an integer constant expression and verify that it is a
+/// valid matrix dimension. \p ErrIdx selects the wording of the diagnostics
+/// (0 = row, 1 = column). On failure a diagnostic is emitted and None is
+/// returned.
+static llvm::Optional<unsigned>
+getAndVerifyMatrixDimension(Expr *E, unsigned ErrIdx, Sema &S) {
+  llvm::APSInt Value(64);
+  if (!E->isIntegerConstantExpr(Value, S.Context)) {
+    S.Diag(E->getBeginLoc(), diag::err_builtin_matrix_scalar_int_arg)
+        << ErrIdx << 1;
+    return {};
+  }
+  uint64_t Dim = Value.getZExtValue();
+  if (!ConstantMatrixType::isDimensionValid(Dim)) {
+    S.Diag(E->getBeginLoc(), diag::err_builtin_matrix_invalid_dimension)
+        << ErrIdx << ConstantMatrixType::getMaxElementsPerDimension();
+    return {};
+  }
+  return Dim;
+}
+
+/// Semantic checking for __builtin_matrix_column_major_load(Ptr, Rows,
+/// Columns, Stride). On success the call's type is set to a constant matrix
+/// type whose element type is the (unqualified) pointee type of Ptr, and the
+/// converted arguments are stored back into \p TheCall.
+ExprResult
+Sema::SemaBuiltinMatrixColumnMajorLoadOverload(CallExpr *TheCall,
+                                               ExprResult CallResult) {
+  if (checkArgCount(*this, TheCall, 4))
+    return ExprError();
+
+  Expr *PtrExpr = TheCall->getArg(0);
+  Expr *RowsExpr = TheCall->getArg(1);
+  Expr *ColumnsExpr = TheCall->getArg(2);
+  Expr *StrideExpr = TheCall->getArg(3);
+
+  bool ArgError = false;
+
+  // Check pointer argument. Arrays and functions decay to pointers here, so
+  // afterwards only pointer types need to be handled.
+  {
+    ExprResult PtrConv = DefaultFunctionArrayLvalueConversion(PtrExpr);
+    if (PtrConv.isInvalid())
+      return PtrConv;
+    PtrExpr = PtrConv.get();
+  }
+
+  // Use getAs<PointerType>() rather than dyn_cast so that type sugar
+  // (typedefs, SubstTemplateTypeParmType from template instantiation, ...) is
+  // looked through.
+  QualType ElementTy;
+  if (const auto *PtrTy = PtrExpr->getType()->getAs<PointerType>()) {
+    // The result is a matrix value, so drop the pointee's qualifiers,
+    // including any hidden behind typedef sugar.
+    ElementTy = PtrTy->getPointeeType().getUnqualifiedType();
+    if (!ConstantMatrixType::isValidElementType(ElementTy)) {
+      Diag(PtrExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg) << 0;
+      ArgError = true;
+    }
+  } else {
+    Diag(PtrExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg) << 0;
+    ArgError = true;
+  }
+
+  if (RowsExpr->isValueDependent() || RowsExpr->isTypeDependent() ||
+      ColumnsExpr->isValueDependent() || ColumnsExpr->isTypeDependent()) {
+    // Building a dependent matrix type requires a valid element type, which
+    // is not available if the pointer argument was rejected above.
+    if (ArgError)
+      return ExprError();
+    QualType ReturnType = Context.getDependentSizedMatrixType(
+        ElementTy, RowsExpr, ColumnsExpr, {});
+    TheCall->setType(ReturnType);
+    return CallResult;
+  }
+
+  // Apply default Lvalue conversions and convert the expression to size_t.
+  auto ApplyArgumentConversions = [this](Expr *E) {
+    ExprResult Conv = DefaultLvalueConversion(E);
+    if (Conv.isInvalid())
+      return Conv;
+
+    return tryConvertExprToTy(Conv.get(), Context.getSizeType());
+  };
+
+  // Check rows argument.
+  llvm::Optional<unsigned> MaybeRows;
+  ExprResult RowsConv = ApplyArgumentConversions(RowsExpr);
+  if (!RowsConv.isInvalid()) {
+    RowsExpr = RowsConv.get();
+    MaybeRows = getAndVerifyMatrixDimension(RowsExpr, 0, *this);
+  }
+
+  // Check columns argument.
+  llvm::Optional<unsigned> MaybeColumns;
+  ExprResult ColumnsConv = ApplyArgumentConversions(ColumnsExpr);
+  if (!ColumnsConv.isInvalid()) {
+    ColumnsExpr = ColumnsConv.get();
+    MaybeColumns = getAndVerifyMatrixDimension(ColumnsExpr, 1, *this);
+  }
+
+  // Check stride argument.
+  ExprResult StrideConv = ApplyArgumentConversions(StrideExpr);
+  if (StrideConv.isInvalid())
+    return ExprError();
+  StrideExpr = StrideConv.get();
+
+  if (!StrideExpr->getType()->isIntegralType(Context)) {
+    Diag(StrideExpr->getBeginLoc(), diag::err_builtin_matrix_scalar_int_arg)
+        << 2 << 1;
+    ArgError = true;
+  } else if (!StrideExpr->isValueDependent()) {
+    // A value-dependent stride (e.g. a non-type template parameter) must not
+    // be constant-evaluated; the call is checked again on instantiation.
+    llvm::APSInt Value(64);
+    if (StrideExpr->isIntegerConstantExpr(Value, Context)) {
+      uint64_t Stride = Value.getZExtValue();
+      if (MaybeRows && Stride < *MaybeRows) {
+        Diag(StrideExpr->getBeginLoc(),
+             diag::err_builtin_matrix_stride_too_small);
+        ArgError = true;
+      }
+    }
+  }
+
+  if (ArgError || !MaybeRows || !MaybeColumns)
+    return ExprError();
+
+  QualType ReturnType =
+      Context.getConstantMatrixType(ElementTy, *MaybeRows, *MaybeColumns);
+  TheCall->setType(ReturnType);
+  TheCall->setArg(0, PtrExpr);
+  TheCall->setArg(1, RowsExpr);
+  TheCall->setArg(2, ColumnsExpr);
+  TheCall->setArg(3, StrideExpr);
+  return CallResult;
+}
Index: clang/lib/CodeGen/CGBuiltin.cpp
===================================================================
--- clang/lib/CodeGen/CGBuiltin.cpp
+++ clang/lib/CodeGen/CGBuiltin.cpp
@@ -2383,6 +2383,31 @@
     return RValue::get(Result);
   }
 
+  case Builtin::BI__builtin_matrix_column_major_load: {
+    MatrixBuilder<CGBuilderTy> MB(Builder);
+    // Emit everything that isn't dependent on the first parameter type.
+    Value *Stride = EmitScalarExpr(E->getArg(3));
+    const auto *ResultTy = E->getType()->getAs<ConstantMatrixType>();
+
+    // Use the sugar-aware is*Type() predicates: the argument type may be a
+    // typedef or a substituted template parameter wrapping the pointer/array.
+    QualType PtrTy = E->getArg(0)->getType();
+    Address Src = Address::invalid();
+    if (PtrTy->isPointerType())
+      Src = EmitPointerWithAlignment(E->getArg(0));
+    else if (PtrTy->isArrayType())
+      Src = EmitArrayToPointerDecay(E->getArg(0));
+    else
+      llvm_unreachable("first argument must either be a pointer or an array");
+
+    EmitNonNullArgCheck(RValue::get(Src.getPointer()), PtrTy,
+                        E->getArg(0)->getExprLoc(), FD, 0);
+    Value *Result = MB.CreateMatrixColumnwiseLoad(
+        Src.getPointer(), ResultTy->getNumRows(), ResultTy->getNumColumns(),
+        Stride, "matrix");
+    return RValue::get(Result);
+  }
+
   case Builtin::BIfinite:
   case Builtin::BI__finite:
   case Builtin::BIfinitef:
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -4703,6 +4703,10 @@
   bool tryExprAsCall(Expr &E, QualType &ZeroArgCallReturnTy,
                      UnresolvedSetImpl &NonTemplateOverloads);
 
+  /// Try to convert an expression \p E to type \p Ty. Returns the result of the
+  /// conversion.
+  ExprResult tryConvertExprToTy(Expr *E, QualType Ty);
+
   /// Conditionally issue a diagnostic based on the current
   /// evaluation context.
   ///
@@ -12118,6 +12122,8 @@
   // Matrix builtin handling.
   ExprResult SemaBuiltinMatrixTransposeOverload(CallExpr *TheCall,
                                                 ExprResult CallResult);
+  ExprResult SemaBuiltinMatrixColumnMajorLoadOverload(CallExpr *TheCall,
+                                                      ExprResult CallResult);
 
 public:
   enum FormatStringType {
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10781,6 +10781,18 @@
 def err_builtin_matrix_arg: Error<
   "%select{first|second}0 argument must be a matrix">;
 
+def err_builtin_matrix_scalar_int_arg: Error<
+  "%select{row|column|stride}0 argument must be %select{an unsigned integer|a constant unsigned integer expression}1">;
+
+def err_builtin_matrix_pointer_arg: Error<
+  "%select{first|second}0 argument must be a pointer to a valid matrix element type">;
+
+def err_builtin_matrix_stride_too_small: Error<
+  "stride must be greater or equal to the number of rows">;
+
+def err_builtin_matrix_invalid_dimension: Error<
+  "%select{row|column}0 dimension is outside the allowed range [1, %1]">;
+
 def err_preserve_field_info_not_field : Error<
   "__builtin_preserve_field_info argument %0 not a field access">;
 def err_preserve_field_info_not_const: Error<
Index: clang/include/clang/Basic/Builtins.def
===================================================================
--- clang/include/clang/Basic/Builtins.def
+++ clang/include/clang/Basic/Builtins.def
@@ -578,6 +578,7 @@
 BUILTIN(__builtin_call_with_static_chain, "v.", "nt")
 
 BUILTIN(__builtin_matrix_transpose, "v.", "nFt")
+BUILTIN(__builtin_matrix_column_major_load, "v.", "nFt")
 
 // "Overloaded" Atomic operator builtins.  These are overloaded to support data
 // types of i8, i16, i32, i64, and i128.  The front-end sees calls to the
Index: clang/include/clang/AST/Type.h
===================================================================
--- clang/include/clang/AST/Type.h
+++ clang/include/clang/AST/Type.h
@@ -3476,6 +3476,11 @@
            NumElements <= ConstantMatrixTypeBitfields::MaxElementsPerDimension;
   }
 
+  /// Returns the maximum number of elements per dimension.
+  static unsigned getMaxElementsPerDimension() {
+    return ConstantMatrixTypeBitfields::MaxElementsPerDimension;
+  }
+
   void Profile(llvm::FoldingSetNodeID &ID) {
     Profile(ID, getElementType(), getNumRows(), getNumColumns(),
             getTypeClass());
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to