================ @@ -0,0 +1,357 @@ +//===- CIRGenCUDANV.cpp - Interface to NVIDIA CUDA Runtime -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This provides a class for CUDA code generation targeting the NVIDIA CUDA +// runtime library. +// +//===----------------------------------------------------------------------===// + +#include "CIRGenCUDARuntime.h" +#include "CIRGenFunction.h" +#include "CIRGenModule.h" +#include "mlir/IR/Operation.h" +#include "clang/AST/ASTContext.h" +#include "clang/AST/Decl.h" +#include "clang/AST/GlobalDecl.h" +#include "clang/Basic/AddressSpaces.h" +#include "clang/Basic/Cuda.h" +#include "clang/CIR/Dialect/IR/CIRDialect.h" +#include "clang/CIR/Dialect/IR/CIRTypes.h" +#include "llvm/Support/Casting.h" + +using namespace clang; +using namespace clang::CIRGen; + +namespace { + +class CIRGenNVCUDARuntime : public CIRGenCUDARuntime { +protected: + StringRef Prefix; + + // Map a device stub function to a symbol for identifying kernel in host + // code. For CUDA, the symbol for identifying the kernel is the same as the + // device stub function. For HIP, they are different. + llvm::DenseMap<StringRef, mlir::Operation *> kernelHandles; + + // Map a kernel handle to the kernel stub. + llvm::DenseMap<mlir::Operation *, mlir::Operation *> kernelStubs; + // Mangle context for device. 
+ std::unique_ptr<MangleContext> deviceMC; + +private: + void emitDeviceStubBodyNew(CIRGenFunction &cgf, cir::FuncOp fn, + FunctionArgList &args); + mlir::Value prepareKernelArgs(CIRGenFunction &cgf, mlir::Location loc, + FunctionArgList &args); + mlir::Operation *getKernelHandle(cir::FuncOp fn, GlobalDecl gd) override; + std::string addPrefixToName(StringRef funcName) const; + std::string addUnderscoredPrefixToName(StringRef funcName) const; + +public: + CIRGenNVCUDARuntime(CIRGenModule &cgm); + ~CIRGenNVCUDARuntime(); + + void emitDeviceStub(CIRGenFunction &cgf, cir::FuncOp fn, + FunctionArgList &args) override; +}; + +} // namespace + +std::string CIRGenNVCUDARuntime::addPrefixToName(StringRef funcName) const { + return (Prefix + funcName).str(); +} + +std::string +CIRGenNVCUDARuntime::addUnderscoredPrefixToName(StringRef funcName) const { + return ("__" + Prefix + funcName).str(); +} + +static std::unique_ptr<MangleContext> initDeviceMC(CIRGenModule &cgm) { + // If the host and device have different C++ ABIs, mark it as the device + // mangle context so that the mangling needs to retrieve the additional + // device lambda mangling number instead of the regular host one. 
+ if (cgm.getASTContext().getAuxTargetInfo() && + cgm.getASTContext().getTargetInfo().getCXXABI().isMicrosoft() && + cgm.getASTContext().getAuxTargetInfo()->getCXXABI().isItaniumFamily()) { + return std::unique_ptr<MangleContext>( + cgm.getASTContext().createDeviceMangleContext( + *cgm.getASTContext().getAuxTargetInfo())); + } + + return std::unique_ptr<MangleContext>(cgm.getASTContext().createMangleContext( + cgm.getASTContext().getAuxTargetInfo())); +} + +CIRGenNVCUDARuntime::CIRGenNVCUDARuntime(CIRGenModule &cgm) + : CIRGenCUDARuntime(cgm), deviceMC(initDeviceMC(cgm)) { + if (cgm.getLangOpts().OffloadViaLLVM) + llvm_unreachable("NYI"); + else if (cgm.getLangOpts().HIP) + Prefix = "hip"; + else + Prefix = "cuda"; +} + +mlir::Value CIRGenNVCUDARuntime::prepareKernelArgs(CIRGenFunction &cgf, + mlir::Location loc, + FunctionArgList &args) { + auto &builder = cgm.getBuilder(); + + // Build void *args[] and populate with the addresses of kernel arguments. + auto voidPtrArrayTy = cir::ArrayType::get(cgm.voidPtrTy, args.size()); + mlir::Value kernelArgs = builder.createAlloca( + loc, cir::PointerType::get(voidPtrArrayTy), voidPtrArrayTy, "kernel_args", + CharUnits::fromQuantity(16)); + + mlir::Value kernelArgsDecayed = + builder.createCast(cir::CastKind::array_to_ptrdecay, kernelArgs, + cir::PointerType::get(cgm.voidPtrTy)); + + for (auto [i, arg] : llvm::enumerate(args)) { + mlir::Value index = + builder.getConstInt(loc, llvm::APInt(/*numBits=*/32, i)); + mlir::Value storePos = + builder.createPtrStride(loc, kernelArgsDecayed, index); + + // Get the address of the argument and cast the store destination to match + // its pointer-to-pointer type. This is needed because upstream's + // createStore doesn't auto-bitcast like the incubator version. ---------------- koparasy wrote:
I would delete this part of the comment:

```
// This is needed because upstream's
// createStore doesn't auto-bitcast like the incubator version.
```

If this ends up being what we all agree on, we should not call it out in a comment; the comment should describe what the code does here, not how it differs from the incubator.

https://github.com/llvm/llvm-project/pull/177790
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
