JonChesterfield created this revision.
JonChesterfield added reviewers: jdoerfert, tianshilei1992, jhuber6, ye-luo, 
ronlieb, pdhaliwal.
Herald added subscribers: guansong, yaxunl.
JonChesterfield requested review of this revision.
Herald added subscribers: cfe-commits, sstefan1.
Herald added a project: clang.

A refactor to share code in CGGPUBuiltin renders poorly on Phabricator,
so that part is split out into this diff to make review easier.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D112682

Files:
  clang/lib/CodeGen/CGGPUBuiltin.cpp

Index: clang/lib/CodeGen/CGGPUBuiltin.cpp
===================================================================
--- clang/lib/CodeGen/CGGPUBuiltin.cpp
+++ clang/lib/CodeGen/CGGPUBuiltin.cpp
@@ -21,13 +21,14 @@
 using namespace clang;
 using namespace CodeGen;
 
-static llvm::Function *GetVprintfDeclaration(llvm::Module &M) {
+namespace {
+llvm::Function *GetVprintfDeclaration(llvm::Module &M) {
   llvm::Type *ArgTypes[] = {llvm::Type::getInt8PtrTy(M.getContext()),
                             llvm::Type::getInt8PtrTy(M.getContext())};
   llvm::FunctionType *VprintfFuncType = llvm::FunctionType::get(
       llvm::Type::getInt32Ty(M.getContext()), ArgTypes, false);
 
-  if (auto* F = M.getFunction("vprintf")) {
+  if (auto *F = M.getFunction("vprintf")) {
     // Our CUDA system header declares vprintf with the right signature, so
     // nobody else should have been able to declare vprintf with a bogus
     // signature.
@@ -66,39 +67,24 @@
 //
 // Note that by the time this function runs, E's args have already undergone the
 // standard C vararg promotion (short -> int, float -> double, etc.).
-RValue
-CodeGenFunction::EmitNVPTXDevicePrintfCallExpr(const CallExpr *E,
-                                               ReturnValueSlot ReturnValue) {
-  assert(getTarget().getTriple().isNVPTX());
-  assert(E->getBuiltinCallee() == Builtin::BIprintf);
-  assert(E->getNumArgs() >= 1); // printf always has at least one arg.
 
-  const llvm::DataLayout &DL = CGM.getDataLayout();
-  llvm::LLVMContext &Ctx = CGM.getLLVMContext();
+std::pair<llvm::Value *, llvm::TypeSize>
+packArgsIntoNVPTXFormatBuffer(CodeGenFunction *CGF, const CallArgList &Args) {
 
-  CallArgList Args;
-  EmitCallArgs(Args,
-               E->getDirectCallee()->getType()->getAs<FunctionProtoType>(),
-               E->arguments(), E->getDirectCallee(),
-               /* ParamsToSkip = */ 0);
-
-  // We don't know how to emit non-scalar varargs.
-  if (llvm::any_of(llvm::drop_begin(Args), [&](const CallArg &A) {
-        return !A.getRValue(*this).isScalar();
-      })) {
-    CGM.ErrorUnsupported(E, "non-scalar arg to printf");
-    return RValue::get(llvm::ConstantInt::get(IntTy, 0));
-  }
+  const llvm::DataLayout &DL = CGF->CGM.getDataLayout();
+  llvm::LLVMContext &Ctx = CGF->CGM.getLLVMContext();
+  CGBuilderTy &Builder = CGF->Builder;
 
   // Construct and fill the args buffer that we'll pass to vprintf.
   llvm::Value *BufferPtr;
   if (Args.size() <= 1) {
-    // If there are no args, pass a null pointer to vprintf.
+    // If there are no args, pass a null pointer and size 0
     BufferPtr = llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(Ctx));
+    return {BufferPtr, llvm::TypeSize::Fixed(0)};
   } else {
     llvm::SmallVector<llvm::Type *, 8> ArgTypes;
     for (unsigned I = 1, NumArgs = Args.size(); I < NumArgs; ++I)
-      ArgTypes.push_back(Args[I].getRValue(*this).getScalarVal()->getType());
+      ArgTypes.push_back(Args[I].getRValue(*CGF).getScalarVal()->getType());
 
     // Using llvm::StructType is correct only because printf doesn't accept
     // aggregates.  If we had to handle aggregates here, we'd have to manually
@@ -106,18 +92,43 @@
     // that the alignment of the llvm type was the same as the alignment of the
     // clang type.
     llvm::Type *AllocaTy = llvm::StructType::create(ArgTypes, "printf_args");
-    llvm::Value *Alloca = CreateTempAlloca(AllocaTy);
+    llvm::Value *Alloca = CGF->CreateTempAlloca(AllocaTy);
 
     for (unsigned I = 1, NumArgs = Args.size(); I < NumArgs; ++I) {
       llvm::Value *P = Builder.CreateStructGEP(AllocaTy, Alloca, I - 1);
-      llvm::Value *Arg = Args[I].getRValue(*this).getScalarVal();
+      llvm::Value *Arg = Args[I].getRValue(*CGF).getScalarVal();
       Builder.CreateAlignedStore(Arg, P, DL.getPrefTypeAlign(Arg->getType()));
     }
     BufferPtr = Builder.CreatePointerCast(Alloca, llvm::Type::getInt8PtrTy(Ctx));
+    return {BufferPtr, DL.getTypeAllocSize(AllocaTy)};
+  }
+}
+} // namespace
+
+RValue
+CodeGenFunction::EmitNVPTXDevicePrintfCallExpr(const CallExpr *E,
+                                               ReturnValueSlot ReturnValue) {
+  assert(getTarget().getTriple().isNVPTX());
+  assert(E->getBuiltinCallee() == Builtin::BIprintf);
+  assert(E->getNumArgs() >= 1); // printf always has at least one arg.
+
+  CallArgList Args;
+  EmitCallArgs(Args,
+               E->getDirectCallee()->getType()->getAs<FunctionProtoType>(),
+               E->arguments(), E->getDirectCallee(),
+               /* ParamsToSkip = */ 0);
+
+  // We don't know how to emit non-scalar varargs.
+  if (llvm::any_of(llvm::drop_begin(Args), [&](const CallArg &A) {
+        return !A.getRValue(*this).isScalar();
+      })) {
+    CGM.ErrorUnsupported(E, "non-scalar arg to printf");
+    return RValue::get(llvm::ConstantInt::get(IntTy, 0));
   }
 
+  llvm::Value *BufferPtr = packArgsIntoNVPTXFormatBuffer(this, Args).first;
   // Invoke vprintf and return.
-  llvm::Function* VprintfFunc = GetVprintfDeclaration(CGM.getModule());
+  llvm::Function *VprintfFunc = GetVprintfDeclaration(CGM.getModule());
   return RValue::get(Builder.CreateCall(
       VprintfFunc, {Args[0].getRValue(*this).getScalarVal(), BufferPtr}));
 }
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to