fhahn updated this revision to Diff 269378.
fhahn marked 5 inline comments as done.
fhahn added a comment.
Herald added a project: LLVM.
Herald added a subscriber: llvm-commits.

Simplified the code as suggested, added a check that the matrix types extension is
enabled (with a test), and set the align attribute on the pointer argument.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D72781/new/

https://reviews.llvm.org/D72781

Files:
  clang/include/clang/AST/Type.h
  clang/include/clang/Basic/Builtins.def
  clang/include/clang/Basic/DiagnosticSemaKinds.td
  clang/include/clang/Sema/Sema.h
  clang/lib/CodeGen/CGBuiltin.cpp
  clang/lib/Sema/SemaChecking.cpp
  clang/lib/Sema/SemaExpr.cpp
  clang/test/CodeGen/matrix-type-builtins.c
  clang/test/CodeGenCXX/matrix-type-builtins.cpp
  clang/test/CodeGenObjC/matrix-type-builtins.m
  clang/test/Sema/matrix-type-builtins.c
  clang/test/SemaCXX/matrix-type-builtins-disabled.cpp
  clang/test/SemaCXX/matrix-type-builtins.cpp
  llvm/include/llvm/IR/MatrixBuilder.h

Index: llvm/include/llvm/IR/MatrixBuilder.h
===================================================================
--- llvm/include/llvm/IR/MatrixBuilder.h
+++ llvm/include/llvm/IR/MatrixBuilder.h
@@ -56,10 +56,10 @@
+  /// \p Alignment - Alignment of \p DataPtr in bytes
   /// \p Rows    - Number of rows in matrix (must be a constant)
   /// \p Columns - Number of columns in matrix (must be a constant)
   /// \p Stride  - Space between columns
-  CallInst *CreateMatrixColumnwiseLoad(Value *DataPtr, unsigned Rows,
-                                       unsigned Columns, Value *Stride,
-                                       const Twine &Name = "") {
-
+  CallInst *CreateMatrixColumnwiseLoad(Value *DataPtr, unsigned Alignment,
+                                       unsigned Rows, unsigned Columns,
+                                       Value *Stride, const Twine &Name = "") {
     // Deal with the pointer
     PointerType *PtrTy = cast<PointerType>(DataPtr->getType());
     Type *EltTy = PtrTy->getElementType();
@@ -72,7 +72,12 @@
     Function *TheFn = Intrinsic::getDeclaration(
         getModule(), Intrinsic::matrix_columnwise_load, OverloadedTypes);
 
-    return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
+    CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
+    Attribute AlignAttr =
+        Attribute::getWithAlignment(Call->getContext(), Align(Alignment));
+    // Attribute index 1 refers to the first call argument, i.e. DataPtr.
+    Call->addAttribute(1, AlignAttr);
+    return Call;
   }
 
   /// Create a columnwise, strided matrix store.
Index: clang/test/SemaCXX/matrix-type-builtins.cpp
===================================================================
--- clang/test/SemaCXX/matrix-type-builtins.cpp
+++ clang/test/SemaCXX/matrix-type-builtins.cpp
@@ -39,3 +39,66 @@
   Mat3.value = transpose<unsigned, 3, 3, float, 3, 3>(Mat2);
   // expected-note@-1 {{in instantiation of function template specialization 'transpose<unsigned int, 3, 3, float, 3, 3>' requested here}}
 }
+
+template <typename EltTy0, unsigned R0, unsigned C0, typename EltTy1, unsigned R1, unsigned C1>
+typename MyMatrix<EltTy1, R1, C1>::matrix_t column_major_load(MyMatrix<EltTy0, R0, C0> &A, EltTy0 *Ptr) {
+  char *v1 = __builtin_matrix_column_major_load(Ptr, 9, 4, 10);
+  // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'unsigned int __attribute__((matrix_type(9, 4)))'}}
+  // expected-error@-2 {{cannot initialize a variable of type 'char *' with an rvalue of type 'unsigned int __attribute__((matrix_type(9, 4)))'}}
+  // expected-error@-3 {{cannot initialize a variable of type 'char *' with an rvalue of type 'float __attribute__((matrix_type(9, 4)))'}}
+
+  return __builtin_matrix_column_major_load(Ptr, R0, C0, R0);
+  // expected-error@-1 {{cannot initialize return object of type 'typename MyMatrix<unsigned int, 5U, 5U>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(5, 5)))') with an rvalue of type 'unsigned int __attribute__((matrix_type(2, 3)))'}}
+  // expected-error@-2 {{cannot initialize return object of type 'typename MyMatrix<unsigned int, 2U, 3U>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 3)))') with an rvalue of type 'float __attribute__((matrix_type(2, 3)))'}}
+}
+
+void test_column_major_loads_template(unsigned *Ptr1, float *Ptr2) {
+  MyMatrix<unsigned, 2, 3> Mat1;
+  Mat1.value = column_major_load<unsigned, 2, 3, unsigned, 2, 3>(Mat1, Ptr1);
+  // expected-note@-1 {{in instantiation of function template specialization 'column_major_load<unsigned int, 2, 3, unsigned int, 2, 3>' requested here}}
+  column_major_load<unsigned, 2, 3, unsigned, 5, 5>(Mat1, Ptr1);
+  // expected-note@-1 {{in instantiation of function template specialization 'column_major_load<unsigned int, 2, 3, unsigned int, 5, 5>' requested here}}
+
+  MyMatrix<float, 2, 3> Mat2;
+  Mat1.value = column_major_load<float, 2, 3, unsigned, 2, 3>(Mat2, Ptr2);
+  // expected-note@-1 {{in instantiation of function template specialization 'column_major_load<float, 2, 3, unsigned int, 2, 3>' requested here}}
+}
+
+constexpr int constexpr1() { return 1; }
+constexpr int constexpr_neg1() { return -1; }
+
+void test_column_major_load_constexpr(unsigned *Ptr) {
+  (void)__builtin_matrix_column_major_load(Ptr, 2, 2, constexpr1());
+  // expected-error@-1 {{stride must be greater or equal to the number of rows}}
+  (void)__builtin_matrix_column_major_load(Ptr, constexpr_neg1(), 2, 4);
+  // expected-error@-1 {{row dimension is outside the allowed range [1, 1048575]}}
+  (void)__builtin_matrix_column_major_load(Ptr, 2, constexpr_neg1(), 4);
+  // expected-error@-1 {{column dimension is outside the allowed range [1, 1048575]}}
+}
+
+struct IntWrapper {
+  operator int() {
+    return 1;
+  }
+};
+
+void test_column_major_load_wrapper(unsigned *Ptr, IntWrapper &W) {
+  (void)__builtin_matrix_column_major_load(Ptr, W, 2, 2);
+  // expected-error@-1 {{row argument must be a constant unsigned integer expression}}
+  (void)__builtin_matrix_column_major_load(Ptr, 2, W, 2);
+  // expected-error@-1 {{column argument must be a constant unsigned integer expression}}
+}
+
+template <typename T, unsigned R, unsigned C, unsigned S>
+void test_column_major_load_temp(T Ptr) {
+  (void)__builtin_matrix_column_major_load(Ptr, R, C, S);
+}
+
+void call_column_major_load_temp(unsigned *Ptr, unsigned X) {
+  (void)__builtin_matrix_column_major_load(Ptr, X, X, X);
+  // expected-error@-1 {{row argument must be a constant unsigned integer expression}}
+  // expected-error@-2 {{column argument must be a constant unsigned integer expression}}
+  (void)__builtin_matrix_column_major_load(X, 2, 2, 2);
+  // expected-error@-1 {{first argument must be a pointer to a valid matrix element type}}
+  test_column_major_load_temp<unsigned *, 2, 3, 4>(Ptr);
+}
Index: clang/test/SemaCXX/matrix-type-builtins-disabled.cpp
===================================================================
--- /dev/null
+++ clang/test/SemaCXX/matrix-type-builtins-disabled.cpp
@@ -0,0 +1,8 @@
+// RUN: %clang_cc1 %s -pedantic -std=c++11 -verify -triple=x86_64-apple-darwin9
+
+// Make sure __builtin_matrix_column_major_load is rejected when matrix types
+// are not enabled via -fenable-matrix.
+void column_major_load_with_stride(int *Ptr) {
+  auto m = __builtin_matrix_column_major_load(Ptr, 2, 2, 2);
+  // expected-error@-1 {{matrix types extension is disabled. Pass -fenable-matrix to enable it}}
+}
Index: clang/test/Sema/matrix-type-builtins.c
===================================================================
--- clang/test/Sema/matrix-type-builtins.c
+++ clang/test/Sema/matrix-type-builtins.c
@@ -20,3 +20,49 @@
   ix3x3 m = __builtin_matrix_transpose(c);
   // expected-error@-1 {{initializing 'ix3x3' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') with an expression of incompatible type 'double __attribute__((matrix_type(3, 3)))'}}
 }
+
+struct Foo {
+  unsigned x;
+};
+
+void column_major_load(float *p1, int *p2, _Bool *p3, struct Foo *p4) {
+  sx5x10_t a1 = __builtin_matrix_column_major_load(p1, 5, 11, 5);
+  // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'float __attribute__((matrix_type(5, 11)))'}}
+  sx5x10_t a2 = __builtin_matrix_column_major_load(p1, 5, 9, 5);
+  // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'float __attribute__((matrix_type(5, 9)))'}}
+  sx5x10_t a3 = __builtin_matrix_column_major_load(p1, 6, 10, 6);
+  // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'float __attribute__((matrix_type(6, 10)))'}}
+  sx5x10_t a4 = __builtin_matrix_column_major_load(p1, 4, 10, 4);
+  // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'float __attribute__((matrix_type(4, 10)))'}}
+  sx5x10_t a5 = __builtin_matrix_column_major_load(p1, 6, 9, 6);
+  // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'float __attribute__((matrix_type(6, 9)))'}}
+  sx5x10_t a6 = __builtin_matrix_column_major_load(p2, 5, 10, 6);
+  // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'int __attribute__((matrix_type(5, 10)))'}}
+
+  sx5x10_t a7 = __builtin_matrix_column_major_load(p1, 5, 10, 3);
+  // expected-error@-1 {{stride must be greater or equal to the number of rows}}
+
+  sx5x10_t a8 = __builtin_matrix_column_major_load(p3, 5, 10, 6);
+  // expected-error@-1 {{first argument must be a pointer to a valid matrix element type}}
+
+  sx5x10_t a9 = __builtin_matrix_column_major_load(p4, 5, 10, 6);
+  // expected-error@-1 {{first argument must be a pointer to a valid matrix element type}}
+
+  sx5x10_t a10 = __builtin_matrix_column_major_load(p1, 1ull << 21, 10, 6);
+  // expected-error@-1 {{row dimension is outside the allowed range [1, 1048575]}}
+  sx5x10_t a11 = __builtin_matrix_column_major_load(p1, 10, 1ull << 21, 10);
+  // expected-error@-1 {{column dimension is outside the allowed range [1, 1048575]}}
+
+  sx5x10_t a12 = __builtin_matrix_column_major_load(
+      10,         // expected-error {{first argument must be a pointer to a valid matrix element type}}
+      1ull << 21, // expected-error {{row dimension is outside the allowed range [1, 1048575]}}
+      1ull << 21, // expected-error {{column dimension is outside the allowed range [1, 1048575]}}
+      "");        // expected-warning {{incompatible pointer to integer conversion casting 'char [1]' to type 'unsigned long'}}
+
+  sx5x10_t a13 = __builtin_matrix_column_major_load(
+      10,  // expected-error {{first argument must be a pointer to a valid matrix element type}}
+      *p4, // expected-error {{casting 'struct Foo' to incompatible type 'unsigned long'}}
+      "",  // expected-error {{column argument must be a constant unsigned integer expression}}
+           // expected-warning@-1 {{incompatible pointer to integer conversion casting 'char [1]' to type 'unsigned long'}}
+      10);
+}
Index: clang/test/CodeGenObjC/matrix-type-builtins.m
===================================================================
--- clang/test/CodeGenObjC/matrix-type-builtins.m
+++ clang/test/CodeGenObjC/matrix-type-builtins.m
@@ -40,3 +40,23 @@
 
   m.value = __builtin_matrix_transpose(*r);
 }
+
+__attribute__((objc_root_class))
+@interface PtrValue
+@property unsigned *value;
+@end
+
+__attribute__((objc_root_class))
+@interface IntValue
+@property int value;
+@end
+
+void test_column_major_load(PtrValue *Ptr, IntValue *Stride) {
+  // CHECK-LABEL: define void @test_column_major_load(%{{[0-9]+}}* %Ptr, %{{[0-9]+}}* %Stride)
+  // CHECK:         [[STRIDE:%.*]] = call i32 bitcast (i8* (i8*, i8*, ...)* @objc_msgSend to i32 (i8*, i8*)*)
+  // CHECK-NEXT:    [[STRIDE_EXT:%.*]] = sext i32 [[STRIDE]] to i64
+  // CHECK:         [[PTR:%.*]] = call i32* bitcast (i8* (i8*, i8*, ...)* @objc_msgSend to i32* (i8*, i8*)*)
+  // CHECK-NEXT:    call <12 x i32> @llvm.matrix.columnwise.load.v12i32.p0i32(i32* align 4 [[PTR]], i64 [[STRIDE_EXT]], i64 3, i64 4)
+
+  u3x4 m = __builtin_matrix_column_major_load(Ptr.value, 3, 4, Stride.value);
+}
Index: clang/test/CodeGenCXX/matrix-type-builtins.cpp
===================================================================
--- clang/test/CodeGenCXX/matrix-type-builtins.cpp
+++ clang/test/CodeGenCXX/matrix-type-builtins.cpp
@@ -74,3 +74,109 @@
   // CHECK-NEXT:    store <9 x float> [[M_T]], <9 x float>* [[M_T_ADDR]], align 4
   matrix_t<float, 3, 3> m_t = __builtin_matrix_transpose(m);
 }
+
+template <typename T, unsigned R, unsigned C, unsigned S>
+matrix_t<T, R, C> column_major_load_with_stride(T *Ptr) {
+  return __builtin_matrix_column_major_load(Ptr, R, C, S);
+}
+
+void test_column_major_load_with_stride_template_double(double *Ptr) {
+  // CHECK-LABEL: define void @_Z50test_column_major_load_with_stride_template_doublePd(double* %Ptr)
+  // CHECK:         [[PTR:%.*]] = load double*, double** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <40 x double> @_Z29column_major_load_with_strideIdLj10ELj4ELj15EEU11matrix_typeXT0_EXT1_ET_PS0_(double* [[PTR]])
+
+  // CHECK-LABEL:  define linkonce_odr <40 x double> @_Z29column_major_load_with_strideIdLj10ELj4ELj15EEU11matrix_typeXT0_EXT1_ET_PS0_(double* %Ptr)
+  // CHECK:         [[PTR:%.*]] = load double*, double** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <40 x double> @llvm.matrix.columnwise.load.v40f64.p0f64(double* align 8 [[PTR]], i64 15, i64 10, i64 4)
+
+  matrix_t<double, 10, 4> M1 = column_major_load_with_stride<double, 10, 4, 15>(Ptr);
+}
+
+void test_column_major_load_with_stride_template_int(int *Ptr) {
+  // CHECK-LABEL: define void @_Z47test_column_major_load_with_stride_template_intPi(i32* %Ptr)
+  // CHECK:         [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <6 x i32> @_Z29column_major_load_with_strideIiLj3ELj2ELj12EEU11matrix_typeXT0_EXT1_ET_PS0_(i32* [[PTR]])
+
+  // CHECK-LABEL: define linkonce_odr <6 x i32> @_Z29column_major_load_with_strideIiLj3ELj2ELj12EEU11matrix_typeXT0_EXT1_ET_PS0_(i32* %Ptr)
+  // CHECK:         [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <6 x i32> @llvm.matrix.columnwise.load.v6i32.p0i32(i32* align 4 [[PTR]], i64 12, i64 3, i64 2)
+
+  matrix_t<int, 3, 2> M1 = column_major_load_with_stride<int, 3, 2, 12>(Ptr);
+}
+
+struct UnsignedWrapper {
+  char x;
+  operator unsigned() {
+    return x;
+  }
+};
+
+void test_column_major_load_stride_wrapper(int *Ptr, UnsignedWrapper &W) {
+  // CHECK-LABEL:  define void @_Z37test_column_major_load_stride_wrapperPiR15UnsignedWrapper(i32* %Ptr, %struct.UnsignedWrapper* nonnull align 1 dereferenceable(1) %W)
+  // CHECK:         [[W:%.*]] = load %struct.UnsignedWrapper*, %struct.UnsignedWrapper** %W.addr, align 8
+  // CHECK-NEXT:    [[STRIDE:%.*]] = call i32 @_ZN15UnsignedWrappercvjEv(%struct.UnsignedWrapper* [[W]])
+  // CHECK-NEXT:    [[STRIDE_EXT:%.*]] = zext i32 [[STRIDE]] to i64
+  // CHECK-NEXT:    [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <4 x i32> @llvm.matrix.columnwise.load.v4i32.p0i32(i32* align 4 [[PTR]], i64 [[STRIDE_EXT]], i64 2, i64 2)
+  matrix_t<int, 2, 2> M1 = __builtin_matrix_column_major_load(Ptr, 2, 2, W);
+}
+
+constexpr int constexpr3() { return 3; }
+
+void test_column_major_load_constexpr_num_rows(int *Ptr) {
+  // CHECK-LABEL: define void @_Z41test_column_major_load_constexpr_num_rowsPi(i32* %Ptr)
+  // CHECK:         [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <6 x i32> @llvm.matrix.columnwise.load.v6i32.p0i32(i32* align 4 [[PTR]], i64 3, i64 3, i64 2)
+
+  matrix_t<int, 3, 2> M1 = __builtin_matrix_column_major_load(Ptr, constexpr3(), 2, 3);
+}
+
+constexpr int constexpr1() { return 1; }
+
+void test_column_major_load_constexpr_num_columns(int *Ptr) {
+  // CHECK-LABEL: define void @_Z44test_column_major_load_constexpr_num_columnsPi(i32* %Ptr)
+  // CHECK:         [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <2 x i32> @llvm.matrix.columnwise.load.v2i32.p0i32(i32* align 4 [[PTR]], i64 3, i64 2, i64 1)
+  matrix_t<int, 2, 1> M1 = __builtin_matrix_column_major_load(Ptr, 2, constexpr1(), 3);
+}
+
+template <unsigned N>
+constexpr int constexpr_plus1() { return N + 1; }
+
+void test_column_major_load_constexpr_num_columns_temp(int *Ptr) {
+  // CHECK-LABEL:  define void @_Z49test_column_major_load_constexpr_num_columns_tempPi(i32* %Ptr)
+  // CHECK:         [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <10 x i32> @llvm.matrix.columnwise.load.v10i32.p0i32(i32* align 4 [[PTR]], i64 3, i64 2, i64 5)
+  matrix_t<int, 2, 5> M1 = __builtin_matrix_column_major_load(Ptr, 2, constexpr_plus1<4>(), 3);
+}
+
+void test_column_major_load_constexpr_stride_constexpr(int *Ptr) {
+  // CHECK-LABEL: define void @_Z49test_column_major_load_constexpr_stride_constexprPi(i32* %Ptr)
+  // CHECK:         [[STRIDE:%.*]] = call i32 @_Z10constexpr3v()
+  // CHECK-NEXT:    [[STRIDE_EXT:%.*]] = sext i32 [[STRIDE]] to i64
+  // CHECK-NEXT:    [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <4 x i32> @llvm.matrix.columnwise.load.v4i32.p0i32(i32* align 4 [[PTR]], i64 [[STRIDE_EXT]], i64 2, i64 2)
+
+  matrix_t<int, 2, 2> M1 = __builtin_matrix_column_major_load(Ptr, 2, 2, constexpr3());
+}
+
+// TODO(review): call_column_major_load_with_stride2 below has no CHECK lines.
+template <typename T>
+struct remove_pointer {
+  typedef T type;
+};
+
+template <typename T>
+struct remove_pointer<T *> {
+  typedef typename remove_pointer<T>::type type;
+};
+
+// Same as column_major_load_with_stride, but with the PtrT argument itself being a pointer type.
+template <typename PtrT, unsigned R, unsigned C, unsigned S>
+matrix_t<typename remove_pointer<PtrT>::type, R, C> column_major_load_with_stride2(PtrT Ptr) {
+  return __builtin_matrix_column_major_load(Ptr, R, C, S);
+}
+
+void call_column_major_load_with_stride2(float *Ptr) {
+  matrix_t<float, 2, 2> m = column_major_load_with_stride2<float *, 2, 2, 2>(Ptr);
+}
Index: clang/test/CodeGen/matrix-type-builtins.c
===================================================================
--- clang/test/CodeGen/matrix-type-builtins.c
+++ clang/test/CodeGen/matrix-type-builtins.c
@@ -96,3 +96,83 @@
 
   dx5x5_t m_t = __builtin_matrix_transpose(global_matrix);
 }
+
+void column_major_load_with_const_stride_double(double *Ptr) {
+  // CHECK-LABEL: define void @column_major_load_with_const_stride_double(double* %Ptr)
+  // CHECK:         [[PTR:%.*]] = load double*, double** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <25 x double> @llvm.matrix.columnwise.load.v25f64.p0f64(double* align 8 [[PTR]], i64 5, i64 5, i64 5)
+
+  dx5x5_t m_a1 = __builtin_matrix_column_major_load(Ptr, 5, 5, 5);
+}
+
+void column_major_load_with_const_stride2_double(double *Ptr) {
+  // CHECK-LABEL: define void @column_major_load_with_const_stride2_double(double* %Ptr)
+  // CHECK:         [[PTR:%.*]] = load double*, double** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <25 x double> @llvm.matrix.columnwise.load.v25f64.p0f64(double* align 8 [[PTR]], i64 15, i64 5, i64 5)
+
+  dx5x5_t m_a2 = __builtin_matrix_column_major_load(Ptr, 5, 5, 2 * 3 + 9);
+}
+
+void column_major_load_with_variable_stride_ull_float(float *Ptr, unsigned long long S) {
+  // CHECK-LABEL: define void @column_major_load_with_variable_stride_ull_float(float* %Ptr, i64 %S)
+  // CHECK:         [[S:%.*]] = load i64, i64* %S.addr, align 8
+  // CHECK-NEXT:    [[PTR:%.*]] = load float*, float** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <6 x float> @llvm.matrix.columnwise.load.v6f32.p0f32(float* align 4 [[PTR]], i64 [[S]], i64 2, i64 3)
+
+  fx2x3_t m_b = __builtin_matrix_column_major_load(Ptr, 2, 3, S);
+}
+
+void column_major_load_with_stride_math_int(int *Ptr, int S) {
+  // CHECK-LABEL: define void @column_major_load_with_stride_math_int(i32* %Ptr, i32 %S)
+  // CHECK:         [[S:%.*]] = load i32, i32* %S.addr, align 4
+  // CHECK-NEXT:    [[STRIDE:%.*]] = add nsw i32 [[S]], 32
+  // CHECK-NEXT:    [[STRIDE_EXT:%.*]] = sext i32 [[STRIDE]] to i64
+  // CHECK-NEXT:    [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <80 x i32> @llvm.matrix.columnwise.load.v80i32.p0i32(i32* align 4 [[PTR]], i64 [[STRIDE_EXT]], i64 4, i64 20)
+
+  ix4x20_t m_c = __builtin_matrix_column_major_load(Ptr, 4, 20, S + 32);
+}
+
+void column_major_load_with_stride_math_s_int(int *Ptr, short S) {
+  // CHECK-LABEL:  define void @column_major_load_with_stride_math_s_int(i32* %Ptr, i16 signext %S)
+  // CHECK:         [[S:%.*]] = load i16, i16* %S.addr, align 2
+  // CHECK-NEXT:    [[S_EXT:%.*]] = sext i16 [[S]] to i32
+  // CHECK-NEXT:    [[STRIDE:%.*]] = add nsw i32 [[S_EXT]], 32
+  // CHECK-NEXT:    [[STRIDE_EXT:%.*]] = sext i32 [[STRIDE]] to i64
+  // CHECK-NEXT:    [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <80 x i32> @llvm.matrix.columnwise.load.v80i32.p0i32(i32* align 4 [[PTR]], i64 [[STRIDE_EXT]], i64 4, i64 20)
+
+  ix4x20_t m_c = __builtin_matrix_column_major_load(Ptr, 4, 20, S + 32);
+}
+
+void column_major_load_array1(double Ptr[25]) {
+  // CHECK-LABEL: define void @column_major_load_array1(double* %Ptr)
+  // CHECK:         [[ADDR:%.*]] = load double*, double** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <25 x double> @llvm.matrix.columnwise.load.v25f64.p0f64(double* align 8 [[ADDR]], i64 5, i64 5, i64 5)
+
+  dx5x5_t m = __builtin_matrix_column_major_load(Ptr, 5, 5, 5);
+}
+
+void column_major_load_array2() {
+  // CHECK-LABEL: define void @column_major_load_array2()
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    [[PTR:%.*]] = alloca [25 x double], align 16
+  // CHECK:         [[ARRAY_DEC:%.*]] = getelementptr inbounds [25 x double], [25 x double]* [[PTR]], i64 0, i64 0
+  // CHECK-NEXT:    call <25 x double> @llvm.matrix.columnwise.load.v25f64.p0f64(double* align 16 [[ARRAY_DEC]], i64 5, i64 5, i64 5)
+
+  double Ptr[25];
+  dx5x5_t m = __builtin_matrix_column_major_load(Ptr, 5, 5, 5);
+}
+
+void column_major_load_const(const double *Ptr) {
+  // CHECK-LABEL: define void @column_major_load_const(double* %Ptr)
+  // CHECK:         [[PTR:%.*]] = load double*, double** %Ptr.addr, align 8
+  // CHECK-NEXT:    call <25 x double> @llvm.matrix.columnwise.load.v25f64.p0f64(double* align 8 [[PTR]], i64 5, i64 5, i64 5)
+
+  dx5x5_t m_a1 = __builtin_matrix_column_major_load(Ptr, 5, 5, 5);
+}
+
+// TODO: Support volatile.
+void column_major_load_volatile(volatile double *Ptr) {
+  //dx5x5_t m_a1 = __builtin_matrix_column_major_load(Ptr, 5, 5, 5);
+}
Index: clang/lib/Sema/SemaExpr.cpp
===================================================================
--- clang/lib/Sema/SemaExpr.cpp
+++ clang/lib/Sema/SemaExpr.cpp
@@ -4674,15 +4674,12 @@
   return Res;
 }
 
-static bool tryConvertToTy(Sema &S, QualType ElementType, ExprResult *Scalar) {
-  InitializedEntity Entity =
-      InitializedEntity::InitializeTemporary(ElementType);
-  InitializationKind Kind = InitializationKind::CreateCopy(
-      Scalar->get()->getBeginLoc(), SourceLocation());
-  Expr *Arg = Scalar->get();
-  InitializationSequence InitSeq(S, Entity, Kind, Arg);
-  *Scalar = InitSeq.Perform(S, Entity, Kind, Arg);
-  return !Scalar->isInvalid();
+ExprResult Sema::tryConvertExprToTy(Expr *E, QualType Ty) {
+  InitializedEntity Entity = InitializedEntity::InitializeTemporary(Ty);
+  InitializationKind Kind =
+      InitializationKind::CreateCopy(E->getBeginLoc(), SourceLocation());
+  InitializationSequence InitSeq(*this, Entity, Kind, E);
+  return InitSeq.Perform(*this, Entity, Kind, E);
 }
 
 ExprResult Sema::CreateBuiltinMatrixSubscriptExpr(Expr *Base, Expr *RowIdx,
@@ -4733,11 +4730,9 @@
       return nullptr;
     }
 
-    ExprResult ConvExpr = IndexExpr;
-    bool ConversionOk = tryConvertToTy(*this, Context.getSizeType(), &ConvExpr);
-    assert(ConversionOk &&
+    ExprResult ConvExpr = tryConvertExprToTy(IndexExpr, Context.getSizeType());
+    assert(!ConvExpr.isInvalid() &&
            "should be able to convert any integer type to size type");
-    (void)ConversionOk;
     return ConvExpr.get();
   };
 
@@ -12109,13 +12104,16 @@
   ExprResult OriginalLHS = LHS;
   ExprResult OriginalRHS = RHS;
   if (LHSMatType && !RHSMatType) {
-    if (tryConvertToTy(*this, LHSMatType->getElementType(), &RHS))
+    RHS = tryConvertExprToTy(RHS.get(), LHSMatType->getElementType());
+    if (!RHS.isInvalid())
       return LHSType;
+
     return InvalidOperands(Loc, OriginalLHS, OriginalRHS);
   }
 
   if (!LHSMatType && RHSMatType) {
-    if (tryConvertToTy(*this, RHSMatType->getElementType(), &LHS))
+    LHS = tryConvertExprToTy(LHS.get(), RHSMatType->getElementType());
+    if (!LHS.isInvalid())
       return RHSType;
     return InvalidOperands(Loc, OriginalLHS, OriginalRHS);
   }
Index: clang/lib/Sema/SemaChecking.cpp
===================================================================
--- clang/lib/Sema/SemaChecking.cpp
+++ clang/lib/Sema/SemaChecking.cpp
@@ -1915,6 +1915,9 @@
 
   case Builtin::BI__builtin_matrix_transpose:
     return SemaBuiltinMatrixTransposeOverload(TheCall, TheCallResult);
+
+  case Builtin::BI__builtin_matrix_column_major_load:
+    return SemaBuiltinMatrixColumnMajorLoadOverload(TheCall, TheCallResult);
   }
 
   // Since the target specific builtins for each arch overlap, only check those
@@ -15066,3 +15069,133 @@
   TheCall->setArg(0, Matrix);
   return CallResult;
 }
+
+// Evaluate a row or column argument and verify it is a valid matrix dimension.
+static llvm::Optional<unsigned>
+getAndVerifyMatrixDimension(Expr *E, StringRef Name, Sema &S) {
+  llvm::APSInt Value(64);
+  SourceLocation ErrorPos;
+  if (!E->isIntegerConstantExpr(Value, S.Context, &ErrorPos)) {
+    S.Diag(E->getBeginLoc(), diag::err_builtin_matrix_scalar_unsigned_arg)
+        << Name;
+    return {};
+  }
+  uint64_t Dim = Value.getZExtValue();
+  if (!ConstantMatrixType::isDimensionValid(Dim)) {
+    S.Diag(E->getBeginLoc(), diag::err_builtin_matrix_invalid_dimension)
+        << Name << ConstantMatrixType::getMaxElementsPerDimension();
+    return {};
+  }
+  return Dim;
+}
+
+ExprResult
+Sema::SemaBuiltinMatrixColumnMajorLoadOverload(CallExpr *TheCall,
+                                               ExprResult CallResult) {
+  if (!getLangOpts().MatrixTypes) {
+    Diag(TheCall->getBeginLoc(), diag::err_builtin_matrix_disabled);
+    return ExprError();
+  }
+
+  if (checkArgCount(*this, TheCall, 4))
+    return ExprError();
+
+  Expr *PtrExpr = TheCall->getArg(0);
+  Expr *RowsExpr = TheCall->getArg(1);
+  Expr *ColumnsExpr = TheCall->getArg(2);
+  Expr *StrideExpr = TheCall->getArg(3);
+
+  bool ArgError = false;
+
+  // Check pointer argument.
+  {
+    ExprResult PtrConv = DefaultFunctionArrayLvalueConversion(PtrExpr);
+    if (PtrConv.isInvalid())
+      return PtrConv;
+    PtrExpr = PtrConv.get();
+    TheCall->setArg(0, PtrExpr);
+    if (PtrExpr->isTypeDependent()) {
+      TheCall->setType(Context.DependentTy);
+      return TheCall;
+    }
+  }
+
+  auto *PtrTy = PtrExpr->getType()->getAs<PointerType>();
+  QualType ElementTy;
+  if (!PtrTy) {
+    Diag(PtrExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg) << 0;
+    ArgError = true;
+  } else {
+    ElementTy = PtrTy->getPointeeType().getUnqualifiedType();
+
+    if (!ConstantMatrixType::isValidElementType(ElementTy)) {
+      Diag(PtrExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg) << 0;
+      ArgError = true;
+    }
+  }
+
+  // Apply default Lvalue conversions and convert the expression to size_t.
+  auto ApplyArgumentConversions = [this](Expr *E) {
+    ExprResult Conv = DefaultLvalueConversion(E);
+    if (Conv.isInvalid())
+      return Conv;
+
+    return tryConvertExprToTy(Conv.get(), Context.getSizeType());
+  };
+
+  // Apply conversion to row and column expressions.
+  ExprResult RowsConv = ApplyArgumentConversions(RowsExpr);
+  if (!RowsConv.isInvalid()) {
+    RowsExpr = RowsConv.get();
+    TheCall->setArg(1, RowsExpr);
+  } else
+    RowsExpr = nullptr;
+
+  ExprResult ColumnsConv = ApplyArgumentConversions(ColumnsExpr);
+  if (!ColumnsConv.isInvalid()) {
+    ColumnsExpr = ColumnsConv.get();
+    TheCall->setArg(2, ColumnsExpr);
+  } else
+    ColumnsExpr = nullptr;
+
+  // If any part of the result matrix type is still pending, just use
+  // Context.DependentTy, until all parts are resolved. Value-dependent
+  // dimensions (e.g. non-type template parameters) cannot be evaluated yet.
+  if ((RowsExpr && RowsExpr->isValueDependent()) ||
+      (ColumnsExpr && ColumnsExpr->isValueDependent())) {
+    TheCall->setType(Context.DependentTy);
+    return CallResult;
+  }
+
+  // Check row and column dimensions.
+  llvm::Optional<unsigned> MaybeRows;
+  if (RowsExpr)
+    MaybeRows = getAndVerifyMatrixDimension(RowsExpr, "row", *this);
+
+  llvm::Optional<unsigned> MaybeColumns;
+  if (ColumnsExpr)
+    MaybeColumns = getAndVerifyMatrixDimension(ColumnsExpr, "column", *this);
+
+  // Check stride argument.
+  ExprResult StrideConv = ApplyArgumentConversions(StrideExpr);
+  if (StrideConv.isInvalid())
+    return ExprError();
+  StrideExpr = StrideConv.get();
+  TheCall->setArg(3, StrideExpr);
+
+  // Only a non-dependent constant stride can be checked against the rows.
+  llvm::APSInt Value(64);
+  if (MaybeRows && !StrideExpr->isValueDependent() &&
+      StrideExpr->isIntegerConstantExpr(Value, Context) &&
+      Value.getZExtValue() < *MaybeRows) {
+    Diag(StrideExpr->getBeginLoc(), diag::err_builtin_matrix_stride_too_small);
+    ArgError = true;
+  }
+
+  if (ArgError || !MaybeRows || !MaybeColumns)
+    return ExprError();
+
+  TheCall->setType(
+      Context.getConstantMatrixType(ElementTy, *MaybeRows, *MaybeColumns));
+  return CallResult;
+}
Index: clang/lib/CodeGen/CGBuiltin.cpp
===================================================================
--- clang/lib/CodeGen/CGBuiltin.cpp
+++ clang/lib/CodeGen/CGBuiltin.cpp
@@ -2385,6 +2385,27 @@
     return RValue::get(Result);
   }
 
+  case Builtin::BI__builtin_matrix_column_major_load: {
+    MatrixBuilder<CGBuilderTy> MB(Builder);
+    // Emit everything that isn't dependent on the first parameter type
+    Value *Stride = EmitScalarExpr(E->getArg(3));
+    const auto *ResultTy = E->getType()->getAs<ConstantMatrixType>();
+    auto *PtrTy = E->getArg(0)->getType()->getAs<PointerType>();
+    assert(PtrTy && "arg0 must be of pointer type");
+    // TODO: Pass through volatile flag to matrix builtin. Until then, report
+    // volatile pointers as unsupported rather than dropping the qualifier.
+    if (PtrTy->getPointeeType().isVolatileQualified())
+      CGM.ErrorUnsupported(E, "volatile matrix column major load");
+
+    Address Src = EmitPointerWithAlignment(E->getArg(0));
+    EmitNonNullArgCheck(RValue::get(Src.getPointer()), E->getArg(0)->getType(),
+                        E->getArg(0)->getExprLoc(), FD, 0);
+    Value *Result = MB.CreateMatrixColumnwiseLoad(
+        Src.getPointer(), Src.getAlignment().getQuantity(),
+        ResultTy->getNumRows(), ResultTy->getNumColumns(), Stride, "matrix");
+    return RValue::get(Result);
+  }
+
   case Builtin::BIfinite:
   case Builtin::BI__finite:
   case Builtin::BIfinitef:
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -4703,6 +4703,10 @@
   bool tryExprAsCall(Expr &E, QualType &ZeroArgCallReturnTy,
                      UnresolvedSetImpl &NonTemplateOverloads);
 
+  /// Try to convert an expression \p E to type \p Ty via copy-initialization.
+  /// Returns an invalid result (possibly after diagnosing) if that fails.
+  ExprResult tryConvertExprToTy(Expr *E, QualType Ty);
+
   /// Conditionally issue a diagnostic based on the current
   /// evaluation context.
   ///
@@ -12118,6 +12122,10 @@
   // Matrix builtin handling.
   ExprResult SemaBuiltinMatrixTransposeOverload(CallExpr *TheCall,
                                                 ExprResult CallResult);
+  /// Check a call to __builtin_matrix_column_major_load and compute its result
+  /// matrix type; the type stays dependent while its operands are dependent.
+  ExprResult SemaBuiltinMatrixColumnMajorLoadOverload(CallExpr *TheCall,
+                                                      ExprResult CallResult);
 
 public:
   enum FormatStringType {
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10782,6 +10782,18 @@
 def err_builtin_matrix_arg: Error<
   "%select{first|second}0 argument must be a matrix">;
 
+def err_builtin_matrix_scalar_unsigned_arg: Error<
+  "%0 argument must be a constant unsigned integer expression">;
+
+def err_builtin_matrix_pointer_arg: Error<
+  "%select{first|second}0 argument must be a pointer to a valid matrix element type">;
+
+def err_builtin_matrix_stride_too_small: Error<
+  "stride must be greater or equal to the number of rows">;
+
+def err_builtin_matrix_invalid_dimension: Error<
+  "%0 dimension is outside the allowed range [1, %1]">;
 def err_preserve_field_info_not_field : Error<
   "__builtin_preserve_field_info argument %0 not a field access">;
 def err_preserve_field_info_not_const: Error<
Index: clang/include/clang/Basic/Builtins.def
===================================================================
--- clang/include/clang/Basic/Builtins.def
+++ clang/include/clang/Basic/Builtins.def
@@ -578,6 +578,7 @@
 BUILTIN(__builtin_call_with_static_chain, "v.", "nt")
 
 BUILTIN(__builtin_matrix_transpose, "v.", "nFt")
+BUILTIN(__builtin_matrix_column_major_load, "v.", "nFt")
 
 // "Overloaded" Atomic operator builtins.  These are overloaded to support data
 // types of i8, i16, i32, i64, and i128.  The front-end sees calls to the
Index: clang/include/clang/AST/Type.h
===================================================================
--- clang/include/clang/AST/Type.h
+++ clang/include/clang/AST/Type.h
@@ -3476,6 +3476,11 @@
            NumElements <= ConstantMatrixTypeBitfields::MaxElementsPerDimension;
   }
 
+  /// Returns the largest valid row or column count (see isDimensionValid).
+  static unsigned getMaxElementsPerDimension() {
+    return ConstantMatrixTypeBitfields::MaxElementsPerDimension;
+  }
+
   void Profile(llvm::FoldingSetNodeID &ID) {
     Profile(ID, getElementType(), getNumRows(), getNumColumns(),
             getTypeClass());
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to