fhahn created this revision.
Herald added a subscriber: tschuett.
Herald added a project: clang.

  rG LLVM Github Monorepo



Index: clang/test/CodeGen/builtin-matrix.c
--- clang/test/CodeGen/builtin-matrix.c
+++ clang/test/CodeGen/builtin-matrix.c
@@ -251,4 +251,24 @@
 // CHECK: declare <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double>, <25 x double>, i32 immarg, i32 immarg, i32 immarg) [[READNONE:#[0-9]]]
+void transpose1(dx5x5_t *a, dx5x5_t *b) {
+  *a = __builtin_matrix_transpose(*b);
+  // CHECK-LABEL: @transpose1(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    %b.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    store [25 x double]* %a, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    store [25 x double]* %b, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load [25 x double]*, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %0 to <25 x double>*
+  // CHECK-NEXT:    %2 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %3 = call <25 x double> @llvm.matrix.transpose.v25f64(<25 x double> %2, i32 5, i32 5)
+  // CHECK-NEXT:    %4 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    %5 = bitcast [25 x double]* %4 to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %3, <25 x double>* %5, align 8
+  // CHECK-NEXT:    ret void
+// CHECK: declare <25 x double> @llvm.matrix.transpose.v25f64(<25 x double>, i32 immarg, i32 immarg) [[READNONE]]
 // CHECK: attributes [[READNONE]] = { nounwind readnone speculatable willreturn }
Index: clang/lib/Sema/SemaChecking.cpp
--- clang/lib/Sema/SemaChecking.cpp
+++ clang/lib/Sema/SemaChecking.cpp
@@ -1618,6 +1618,7 @@
   case Builtin::BI__builtin_matrix_add:
   case Builtin::BI__builtin_matrix_subtract:
   case Builtin::BI__builtin_matrix_multiply:
+  case Builtin::BI__builtin_matrix_transpose:
     if (!getLangOpts().EnableMatrix) {
       Diag(TheCall->getBeginLoc(), diag::err_builtin_matrix_disabled);
       return ExprError();
@@ -1633,6 +1634,8 @@
       return SemaBuiltinMatrixEltwiseOverload(TheCall, TheCallResult);
     case Builtin::BI__builtin_matrix_multiply:
       return SemaBuiltinMatrixMultiplyOverload(TheCall, TheCallResult);
+    case Builtin::BI__builtin_matrix_transpose:
+      return SemaBuiltinMatrixTransposeOverload(TheCall, TheCallResult);
       llvm_unreachable("All matrix builtins should be handled here!");
@@ -15470,3 +15473,60 @@
   return CallResult;
+ExprResult Sema::SemaBuiltinMatrixTransposeOverload(CallExpr *TheCall,
+                                                    ExprResult CallResult) {
+  if (checkArgCount(*this, TheCall, 1))
+    return ExprError();
+  Expr *Arg = TheCall->getArg(0);
+  // Basic type checking: the argument must be a matrix.
+  if (!Arg->getType()->isMatrixType()) {
+    Diag(Arg->getBeginLoc(), diag::err_builtin_matrix_arg) << 0;
+    return ExprError();
+  }
+  MatrixType const *MType =
+      cast<MatrixType const>(Arg->getType().getCanonicalType());
+  unsigned R = MType->getNumRows();
+  unsigned C = MType->getNumColumns();
+  // If the argument is an lvalue, wrap it in an implicit LValueToRValue cast
+  // so the call receives the matrix value rather than the lvalue.
+  if (!Arg->isRValue()) {
+    ExprResult Res = ImplicitCastExpr::Create(
+        Context, Arg->getType(), CK_LValueToRValue, Arg, nullptr, VK_RValue);
+    assert(!Res.isInvalid() && "Matrix Cast failed");
+    TheCall->setArg(0, Res.get());
+  }
+  Expr *Callee = TheCall->getCallee();
+  DeclRefExpr *DRE = cast<DeclRefExpr>(Callee->IgnoreParenCasts());
+  FunctionDecl *FDecl = cast<FunctionDecl>(DRE->getDecl());
+  // Function Return Type
+  QualType ReturnElementType = MType->getElementType();
+  QualType ResultType = Context.getMatrixType(ReturnElementType, C, R);
+  // Create a new DeclRefExpr to refer to the new decl.
+  DeclRefExpr *NewDRE = DeclRefExpr::Create(
+      Context, DRE->getQualifierLoc(), SourceLocation(), FDecl,
+      /*enclosing*/ false, DRE->getLocation(), Context.BuiltinFnTy,
+      DRE->getValueKind(), nullptr, nullptr, DRE->isNonOdrUse());
+  // Set the callee in the CallExpr.
+  // FIXME: This loses syntactic information.
+  QualType CalleePtrTy = Context.getPointerType(FDecl->getType());
+  ExprResult PromotedCall = ImpCastExprToType(NewDRE, CalleePtrTy,
+                                              CK_BuiltinFnToFnPtr);
+  TheCall->setCallee(PromotedCall.get());
+  // Set the result type of the call to the transposed matrix type: the same
+  // element type as the argument, with the row and column counts swapped.
+  TheCall->setType(ResultType);
+  return CallResult;
Index: clang/lib/CodeGen/CGBuiltin.cpp
--- clang/lib/CodeGen/CGBuiltin.cpp
+++ clang/lib/CodeGen/CGBuiltin.cpp
@@ -2366,6 +2366,15 @@
     return RValue::get(Result);
+  case Builtin::BI__builtin_matrix_transpose: {
+    const MatrixType *MatrixTy = getMatrixTy(E->getArg(0)->getType());
+    Value *MatValue = EmitScalarExpr(E->getArg(0));
+    MatrixBuilder<CGBuilderTy> MB(Builder);
+    Value *Result = MB.CreateMatrixTranspose(
+        MatValue, MatrixTy->getNumRows(), MatrixTy->getNumColumns());
+    return RValue::get(Result);
+  }
   case Builtin::BI__builtin_matrix_add: {
     MatrixBuilder<CGBuilderTy> MB(Builder);
     Value *Matrix1 = EmitScalarExpr(E->getArg(0));
Index: clang/include/clang/Sema/Sema.h
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -11620,6 +11620,9 @@
                                               ExprResult CallResult);
   ExprResult SemaBuiltinMatrixMultiplyOverload(CallExpr *TheCall,
                                                ExprResult CallResult);
+  ExprResult SemaBuiltinMatrixTransposeOverload(CallExpr *TheCall,
+                                                ExprResult CallResult);
   enum FormatStringType {
Index: clang/include/clang/Basic/Builtins.def
--- clang/include/clang/Basic/Builtins.def
+++ clang/include/clang/Basic/Builtins.def
@@ -578,6 +578,7 @@
 BUILTIN(__builtin_matrix_subtract, "v.", "nt")
 BUILTIN(__builtin_matrix_add, "v.", "nt")
 BUILTIN(__builtin_matrix_multiply, "v.", "nt")
+BUILTIN(__builtin_matrix_transpose, "v.", "nFt")
 // "Overloaded" Atomic operator builtins.  These are overloaded to support data
 // types of i8, i16, i32, i64, and i128.  The front-end sees calls to the
cfe-commits mailing list
  • [PATCH] D72778: [Ma... Florian Hahn via Phabricator via cfe-commits

Reply via email to