https://github.com/AbdallahRashed updated https://github.com/llvm/llvm-project/pull/196573
>From 529abed9f9018ad62699b7e3cd18f105f9e6ec3f Mon Sep 17 00:00:00 2001
From: AbdallahRashed <[email protected]>
Date: Fri, 8 May 2026 18:19:23 +0200
Subject: [PATCH] [CIR][CUDA] Support device-side printf for NVPTX

Implement device-side printf lowering for NVPTX targets in CIR codegen.
The variadic arguments are packed into a stack-allocated struct and passed
to vprintf, matching the classic codegen behavior in CGGPUBuiltin.cpp.

When the target triple is NVPTX and the builtin is printf/__builtin_printf,
we route to emitNVPTXDevicePrintfCallExpr. The no-varargs case passes a
null pointer directly.

AMDGCN device printf (HIP) remains NYI; as in classic codegen, other
AMDGCN compilations keep emitting an ordinary printf library call.
---
 clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp      |   7 ++
 clang/lib/CIR/CodeGen/CIRGenBuiltinNVPTX.cpp | 104 +++++++++++++++++++
 clang/lib/CIR/CodeGen/CIRGenFunction.h       |   3 +
 clang/test/CIR/CodeGenCUDA/device-printf.cu  |  42 ++++++++
 4 files changed, 156 insertions(+)
 create mode 100644 clang/test/CIR/CodeGenCUDA/device-printf.cu

diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp
index fe932834e9b55..9752d2571ed86 100644
--- a/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp
@@ -2400,6 +2400,13 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl &gd, unsigned builtinID,
     return errorBuiltinNYI(*this, e, builtinID);
   case Builtin::BI__builtin_printf:
   case Builtin::BIprintf:
+    assert(e->getNumArgs() >= 1); // printf always has at least one arg.
+    if (getTarget().getTriple().isNVPTX())
+      return RValue::get(emitNVPTXDevicePrintfCallExpr(e));
+    // Like classic codegen, only HIP routes AMDGCN printf to a device-side
+    // lowering; elsewhere (e.g. OpenCL) printf stays a library call.
+    if (getTarget().getTriple().isAMDGCN() && getLangOpts().HIP)
+      return errorBuiltinNYI(*this, e, builtinID);
     break;
   case Builtin::BI__builtin_canonicalize:
   case Builtin::BI__builtin_canonicalizef:
diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltinNVPTX.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltinNVPTX.cpp
index 52f98af8028b4..bd0607261e88b 100644
--- a/clang/lib/CIR/CodeGen/CIRGenBuiltinNVPTX.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenBuiltinNVPTX.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/IR/Value.h"
 #include "clang/Basic/TargetBuiltins.h"
+#include "clang/CIR/Dialect/IR/CIRDataLayout.h"
 #include "clang/CIR/Dialect/IR/CIRDialect.h"
 
 using namespace clang;
@@ -1008,3 +1009,106 @@ CIRGenFunction::emitNVPTXBuiltinExpr(unsigned builtinId, const CallExpr *expr) {
     return std::nullopt;
   }
 }
+
+// vprintf takes two args: A format string, and a pointer to a buffer containing
+// the varargs.
+//
+// For example, the call
+//
+//   printf("format string", arg1, arg2, arg3);
+//
+// is converted into something resembling
+//
+//   struct Tmp {
+//     Arg1 a1;
+//     Arg2 a2;
+//     Arg3 a3;
+//   };
+//   char* buf = alloca(sizeof(Tmp));
+//   *(Tmp*)buf = {a1, a2, a3};
+//   vprintf("format string", buf);
+//
+// `buf` is aligned to the max of {alignof(Arg1), ...}. Furthermore, each of
+// the args is itself stored at its ABI alignment.
+//
+// Note that by the time this function runs, the arguments have already
+// undergone the standard C vararg promotion (short -> int, float -> double
+// etc). In this function we pack the arguments into the buffer described above.
+static mlir::Value packArgsIntoNVPTXFormatBuffer(CIRGenFunction &cgf,
+                                                 const CallArgList &args,
+                                                 mlir::Location loc) {
+  const cir::CIRDataLayout dataLayout = cgf.cgm.getDataLayout();
+  CIRGenBuilderTy &builder = cgf.getBuilder();
+
+  if (args.size() <= 1)
+    // If there are no arguments other than the format string,
+    // pass a nullptr to vprintf.
+    return builder.getNullPtr(builder.getVoidPtrTy(), loc);
+
+  // Fetch each vararg once: its value is stored into the buffer and its type
+  // becomes the corresponding member type of the buffer record.
+  llvm::SmallVector<mlir::Value, 8> argValues;
+  llvm::SmallVector<mlir::Type, 8> argTypes;
+  for (const CallArg &arg : llvm::drop_begin(args)) {
+    argValues.push_back(arg.getRValue(cgf, loc).getValue());
+    argTypes.push_back(argValues.back().getType());
+  }
+
+  // Modeling the buffer as a record is correct only because vprintf's
+  // varargs are all scalars, so each member's natural ABI alignment matches
+  // the layout shown above. Aggregates would need manually computed offsets.
+  mlir::Type allocaTy = builder.getAnonRecordTy(argTypes);
+  auto allocaAlign = clang::CharUnits::fromQuantity(
+      dataLayout.getABITypeAlign(allocaTy).value());
+  Address allocaAddr =
+      cgf.createTempAlloca(allocaTy, allocaAlign, loc, "printf_args");
+  mlir::Value alloca = allocaAddr.getPointer();
+
+  for (auto [i, value] : llvm::enumerate(argValues)) {
+    mlir::Value member = builder.createGetMember(
+        loc, cir::PointerType::get(argTypes[i]), alloca, /*name=*/"",
+        /*index=*/i);
+    auto abiAlign = clang::CharUnits::fromQuantity(
+        dataLayout.getABITypeAlign(argTypes[i]).value());
+    cir::StoreOp::create(builder, loc, value, member, /*is_volatile=*/false,
+                         builder.getAlignmentAttr(abiAlign),
+                         /*sync_scope=*/cir::SyncScopeKindAttr{},
+                         /*mem_order=*/cir::MemOrderAttr{});
+  }
+
+  return builder.createBitcast(alloca, builder.getVoidPtrTy());
+}
+
+mlir::Value
+CIRGenFunction::emitNVPTXDevicePrintfCallExpr(const CallExpr *expr) {
+  assert(cgm.getTriple().isNVPTX());
+  CallArgList args;
+  emitCallArgs(args,
+               expr->getDirectCallee()->getType()->getAs<FunctionProtoType>(),
+               expr->arguments(), expr->getDirectCallee());
+
+  mlir::Location loc = getLoc(expr->getBeginLoc());
+
+  // Other than the format string, device-side printf only supports scalar
+  // arguments; diagnose anything else instead of packing it incorrectly.
+  bool hasNonScalar =
+      llvm::any_of(llvm::drop_begin(args), [&](const CallArg &a) {
+        return !a.getRValue(*this, loc).isScalar();
+      });
+  if (hasNonScalar) {
+    cgm.errorUnsupported(expr, "non-scalar args to printf");
+    return builder.getConstInt(loc, builder.getSInt32Ty(), 0);
+  }
+
+  mlir::Value packedData = packArgsIntoNVPTXFormatBuffer(*this, args, loc);
+
+  // int vprintf(char *format, void *packedData);
+  auto vprintf = cgm.createRuntimeFunction(
+      cir::FuncType::get(
+          {cir::PointerType::get(builder.getSInt8Ty()), builder.getVoidPtrTy()},
+          builder.getSInt32Ty()),
+      "vprintf");
+  mlir::Value formatString = args[0].getRValue(*this, loc).getValue();
+  llvm::SmallVector<mlir::Value, 2> callArgs = {formatString, packedData};
+  return builder.createCallOp(loc, vprintf, callArgs).getResult();
+}
diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h
index de91b2f903018..dcc80817c866b 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunction.h
+++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h
@@ -2029,6 +2029,9 @@ class CIRGenFunction : public CIRGenTypeCache {
   std::optional<mlir::Value> emitNVPTXBuiltinExpr(unsigned builtinID,
                                                   const CallExpr *expr);
 
+  /// Emit an NVPTX device-side printf call by lowering it to vprintf.
+  mlir::Value emitNVPTXDevicePrintfCallExpr(const CallExpr *expr);
+
   LValue emitOpaqueValueLValue(const OpaqueValueExpr *e);
 
   LValue emitConditionalOperatorLValue(const AbstractConditionalOperator *expr);
diff --git a/clang/test/CIR/CodeGenCUDA/device-printf.cu b/clang/test/CIR/CodeGenCUDA/device-printf.cu
new file mode 100644
index 0000000000000..8528cb39f0450
--- /dev/null
+++ b/clang/test/CIR/CodeGenCUDA/device-printf.cu
@@ -0,0 +1,42 @@
+#include "Inputs/cuda.h"
+
+// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -target-cpu sm_80 -x cuda \
+// RUN:   -fcuda-is-device -fclangir -emit-cir %s -o %t.cir
+// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
+
+// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -target-cpu sm_80 -x cuda \
+// RUN:   -fcuda-is-device -fclangir -emit-llvm %s -o %t.ll
+// RUN: FileCheck --check-prefix=LLVM --input-file=%t.ll %s
+
+__device__ void print_int() {
+  printf("%d", 42);
+}
+
+// CIR: cir.func no_inline dso_local @_Z9print_intv()
+// CIR:   %[[#ALLOCA:]] = cir.alloca !rec_anon_struct
+// CIR:   %[[#VAL:]] = cir.const #cir.int<42> : !s32i
+// CIR:   %[[#FIELD:]] = cir.get_member %[[#ALLOCA]][0]
+// CIR:   cir.store align(4) %[[#VAL]], %[[#FIELD]]
+// CIR:   %[[#BUF:]] = cir.cast bitcast %[[#ALLOCA]] : !cir.ptr<!rec_anon_struct> -> !cir.ptr<!void>
+// CIR:   cir.call @vprintf(%{{.+}}, %[[#BUF]])
+// CIR:   cir.return
+
+// LLVM: define dso_local void @_Z9print_intv()
+// LLVM:   %[[#PACKED:]] = alloca { i32 }
+// LLVM:   %[[#GEP:]] = getelementptr inbounds nuw { i32 }, ptr %[[#PACKED]], i32 0, i32 0
+// LLVM:   store i32 42, ptr %[[#GEP]], align 4
+// LLVM:   call i32 @vprintf(ptr @{{.*}}, ptr %[[#PACKED]])
+// LLVM:   ret void
+
+__device__ void print_no_args() {
+  printf("hello world");
+}
+
+// CIR: cir.func no_inline dso_local @_Z13print_no_argsv()
+// CIR:   %[[#NULL:]] = cir.const #cir.ptr<null> : !cir.ptr<!void>
+// CIR:   cir.call @vprintf(%{{.+}}, %[[#NULL]])
+// CIR:   cir.return
+
+// LLVM: define dso_local void @_Z13print_no_argsv()
+// LLVM:   call i32 @vprintf(ptr @{{.*}}, ptr null)
+// LLVM:   ret void

_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
