llvmorg-github-actions[bot] wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clangir

Author: AbdallahRashed (AbdallahRashed)

<details>
<summary>Changes</summary>

Implement device-side printf lowering for NVPTX targets in CIR codegen. The
variadic arguments are packed into a stack-allocated struct and passed to
`vprintf`, matching the classic codegen behavior in `CGGPUBuiltin.cpp`.

When the target triple is NVPTX and the builtin is `printf`/`__builtin_printf`,
we route to `emitNVPTXDevicePrintfCallExpr`. When there are no variadic
arguments, a null pointer is passed to `vprintf` directly.

AMDGCN device printf remains NYI.
Part of https://github.com/llvm/llvm-project/issues/179278

---
Full diff: https://github.com/llvm/llvm-project/pull/196573.diff


4 Files Affected:

- (modified) clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp (+7) 
- (modified) clang/lib/CIR/CodeGen/CIRGenBuiltinNVPTX.cpp (+100) 
- (modified) clang/lib/CIR/CodeGen/CIRGenFunction.h (+3) 
- (added) clang/test/CIR/CodeGenCUDA/device-printf.cu (+42) 


``````````diff
diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp
index fe932834e9b55..9752d2571ed86 100644
--- a/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp
@@ -2400,6 +2400,13 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl &gd, unsigned builtinID,
     return errorBuiltinNYI(*this, e, builtinID);
   case Builtin::BI__builtin_printf:
   case Builtin::BIprintf:
+    assert(e->getNumArgs() >= 1); // printf always has at least one arg.
+    if (getTarget().getTriple().isAMDGCN()) {
+      return errorBuiltinNYI(*this, e, builtinID);
+    }
+    if (getTarget().getTriple().isNVPTX()) {
+      return RValue::get(emitNVPTXDevicePrintfCallExpr(e));
+    }
     break;
   case Builtin::BI__builtin_canonicalize:
   case Builtin::BI__builtin_canonicalizef:
diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltinNVPTX.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltinNVPTX.cpp
index 52f98af8028b4..bd0607261e88b 100644
--- a/clang/lib/CIR/CodeGen/CIRGenBuiltinNVPTX.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenBuiltinNVPTX.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/IR/Value.h"
 #include "clang/Basic/TargetBuiltins.h"
+#include "clang/CIR/Dialect/IR/CIRDataLayout.h"
 #include "clang/CIR/Dialect/IR/CIRDialect.h"
 
 using namespace clang;
@@ -1008,3 +1009,102 @@ CIRGenFunction::emitNVPTXBuiltinExpr(unsigned builtinId, const CallExpr *expr) {
     return std::nullopt;
   }
 }
+
+// vprintf takes two args: A format string, and a pointer to a buffer containing
+// the varargs.
+//
+// For example, the call
+//
+//   printf("format string", arg1, arg2, arg3);
+//
+// is converted into something resembling
+//
+//   struct Tmp {
+//     Arg1 a1;
+//     Arg2 a2;
+//     Arg3 a3;
+//   };
+//   char* buf = alloca(sizeof(Tmp));
+//   *(Tmp*)buf = {a1, a2, a3};
+//   vprintf("format string", buf);
+//
+// `buf` is aligned to the max of {alignof(Arg1), ...}. Furthermore, each of
+// the args is itself aligned to its ABI alignment.
+//
+// Note that by the time this function runs, the arguments have already
+// undergone the standard C vararg promotion (short -> int, float -> double
+// etc). In this function we pack the arguments into the buffer described above.
+static mlir::Value packArgsIntoNVPTXFormatBuffer(CIRGenFunction &cgf,
+                                                 const CallArgList &args,
+                                                 mlir::Location loc) {
+  const cir::CIRDataLayout dataLayout = cgf.cgm.getDataLayout();
+  CIRGenBuilderTy &builder = cgf.getBuilder();
+
+  if (args.size() <= 1)
+    // If there are no arguments other than the format string,
+    // pass a nullptr to vprintf.
+    return builder.getNullPtr(builder.getVoidPtrTy(), loc);
+
+  llvm::SmallVector<mlir::Type, 8> argTypes;
+  for (const auto &arg : llvm::drop_begin(args))
+    argTypes.push_back(arg.getRValue(cgf, loc).getValue().getType());
+
+  // We can directly store the arguments into a struct, and the alignment
+  // would automatically be correct. That's because vprintf does not
+  // accept aggregates.
+  mlir::Type allocaTy = builder.getAnonRecordTy(argTypes);
+  auto allocaAlign = clang::CharUnits::fromQuantity(
+      dataLayout.getABITypeAlign(allocaTy).value());
+  Address allocaAddr =
+      cgf.createTempAlloca(allocaTy, allocaAlign, loc, "printf_args");
+  mlir::Value alloca = allocaAddr.getPointer();
+
+  for (auto [i, arg] : llvm::enumerate(llvm::drop_begin(args))) {
+    mlir::Value member = builder.createGetMember(
+        loc, cir::PointerType::get(argTypes[i]), alloca, /*name=*/"",
+        /*index=*/i);
+    auto abiAlign = clang::CharUnits::fromQuantity(
+        dataLayout.getABITypeAlign(argTypes[i]).value());
+    cir::StoreOp::create(builder, loc, arg.getRValue(cgf, loc).getValue(),
+                         member, /*is_volatile=*/false,
+                         builder.getAlignmentAttr(abiAlign),
+                         /*sync_scope=*/cir::SyncScopeKindAttr{},
+                         /*mem_order=*/cir::MemOrderAttr{});
+  }
+
+  return builder.createBitcast(alloca, builder.getVoidPtrTy());
+}
+
+mlir::Value
+CIRGenFunction::emitNVPTXDevicePrintfCallExpr(const CallExpr *expr) {
+  assert(cgm.getTriple().isNVPTX());
+  CallArgList args;
+  emitCallArgs(args,
+               expr->getDirectCallee()->getType()->getAs<FunctionProtoType>(),
+               expr->arguments(), expr->getDirectCallee());
+
+  mlir::Location loc = getLoc(expr->getBeginLoc());
+
+  // Apart from the format string, no non-scalar arguments are allowed for
+  // device-side printf.
+  bool hasNonScalar =
+      llvm::any_of(llvm::drop_begin(args), [&](const CallArg &a) {
+        return !a.getRValue(*this, loc).isScalar();
+      });
+  if (hasNonScalar) {
+    cgm.errorUnsupported(expr, "non-scalar args to printf");
+    return builder.getConstInt(loc, builder.getSInt32Ty(), 0);
+  }
+
+  mlir::Value packedData = packArgsIntoNVPTXFormatBuffer(*this, args, loc);
+
+  // int vprintf(char *format, void *packedData);
+  auto vprintf = cgm.createRuntimeFunction(
+      cir::FuncType::get(
+          {cir::PointerType::get(builder.getSInt8Ty()), builder.getVoidPtrTy()},
+          builder.getSInt32Ty()),
+      "vprintf");
+  mlir::Value formatString = args[0].getRValue(*this, loc).getValue();
+  llvm::SmallVector<mlir::Value, 2> callArgs = {formatString, packedData};
+  return builder.createCallOp(loc, vprintf, callArgs).getResult();
+}
diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h
index de91b2f903018..dcc80817c866b 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunction.h
+++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h
@@ -2029,6 +2029,9 @@ class CIRGenFunction : public CIRGenTypeCache {
   std::optional<mlir::Value> emitNVPTXBuiltinExpr(unsigned builtinID,
                                                   const CallExpr *expr);
 
+  /// Emit a device-side printf call for NVPTX targets.
+  mlir::Value emitNVPTXDevicePrintfCallExpr(const CallExpr *expr);
+
   LValue emitOpaqueValueLValue(const OpaqueValueExpr *e);
 
   LValue emitConditionalOperatorLValue(const AbstractConditionalOperator *expr);
diff --git a/clang/test/CIR/CodeGenCUDA/device-printf.cu b/clang/test/CIR/CodeGenCUDA/device-printf.cu
new file mode 100644
index 0000000000000..8528cb39f0450
--- /dev/null
+++ b/clang/test/CIR/CodeGenCUDA/device-printf.cu
@@ -0,0 +1,42 @@
+#include "Inputs/cuda.h"
+
+// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -target-cpu sm_80 -x cuda \
+// RUN:            -fcuda-is-device -fclangir -emit-cir %s -o %t.cir
+// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
+
+// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -target-cpu sm_80 -x cuda \
+// RUN:            -fcuda-is-device -fclangir -emit-llvm %s -o %t.ll
+// RUN: FileCheck --check-prefix=LLVM --input-file=%t.ll %s
+
+__device__ void print_int() {
+  printf("%d", 42);
+}
+
+// CIR-LABEL: cir.func no_inline dso_local @_Z9print_intv()
+// CIR:   %[[#ALLOCA:]] = cir.alloca !rec_anon_struct
+// CIR:   %[[#VAL:]] = cir.const #cir.int<42> : !s32i
+// CIR:   %[[#FIELD:]] = cir.get_member %[[#ALLOCA]][0]
+// CIR:   cir.store align(4) %[[#VAL]], %[[#FIELD]]
+// CIR:   %[[#BUF:]] = cir.cast bitcast %[[#ALLOCA]] : !cir.ptr<!rec_anon_struct> -> !cir.ptr<!void>
+// CIR:   cir.call @vprintf(%{{.+}}, %[[#BUF]])
+// CIR:   cir.return
+
+// LLVM-LABEL: define dso_local void @_Z9print_intv()
+// LLVM:   %[[#PACKED:]] = alloca { i32 }
+// LLVM:   %[[#GEP:]] = getelementptr inbounds nuw { i32 }, ptr %[[#PACKED]], i32 0, i32 0
+// LLVM:   store i32 42, ptr %[[#GEP]], align 4
+// LLVM:   call i32 @vprintf(ptr @{{.*}}, ptr %[[#PACKED]])
+// LLVM:   ret void
+
+__device__ void print_no_args() {
+  printf("hello world");
+}
+
+// CIR-LABEL: cir.func no_inline dso_local @_Z13print_no_argsv()
+// CIR:   %[[#NULL:]] = cir.const #cir.ptr<null> : !cir.ptr<!void>
+// CIR:   cir.call @vprintf(%{{.+}}, %[[#NULL]])
+// CIR:   cir.return
+
+// LLVM-LABEL: define dso_local void @_Z13print_no_argsv()
+// LLVM:   call i32 @vprintf(ptr @{{.*}}, ptr null)
+// LLVM:   ret void

``````````

</details>


https://github.com/llvm/llvm-project/pull/196573
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to