================
@@ -2774,6 +2779,88 @@ void
LoweringPreparePass::buildCUDARegisterGlobalFunctions(
}
}
+// Emit `__{cuda|hip}RegisterVar` calls inside `__{cuda|hip}_register_globals`
+// for every device-side shadow that carries a `cu.var_registration` attribute
+// (attached by `CIRGenNVCUDARuntime::handleVarRegistration`).
+void LoweringPreparePass::buildCUDARegisterVars(cir::CIRBaseBuilderTy &builder,
+ FuncOp regGlobalFunc) {
+ mlir::Location loc = mlirModule.getLoc();
+ llvm::StringRef cudaPrefix = getCUDAPrefix(astCtx);
+ cir::CIRDataLayout dataLayout(mlirModule);
+
+ PointerType voidPtrTy = builder.getVoidPtrTy();
+ PointerType voidPtrPtrTy = builder.getPointerTo(voidPtrTy);
+ IntType intTy = builder.getSIntNTy(32);
+ IntType sizeTy =
+ builder.getUIntNTy(astCtx->getTargetInfo().getMaxPointerWidth());
+ IntType charTy = cir::IntType::get(&getContext(), astCtx->getCharWidth(),
+ /*isSigned=*/false);
+
+ if (cudaDeviceVars.empty())
+ return;
+
+ cir::CIRBaseBuilderTy globalBuilder(getContext());
+ globalBuilder.setInsertionPointToStart(mlirModule.getBody());
+
+ // void __{cuda|hip}RegisterVar(void **fatbinHandle,
+ // char *hostVar, char *deviceAddress,
+ // const char *deviceName, int ext,
+ // size_t size, int constant, int normalized);
+ // OG (classic Clang CodeGen) ignores parameter types, treating all pointers as void*.
+ cir::VoidType voidTy = builder.getVoidTy();
+ FuncOp cudaRegisterVar = buildRuntimeFunction(
+ globalBuilder, addUnderscoredPrefix(cudaPrefix, "RegisterVar"), loc,
+ FuncType::get({voidPtrPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, intTy,
+ sizeTy, intTy, intTy},
+ voidTy));
+
+ auto makeConstantString = [&](llvm::StringRef str) -> GlobalOp {
+ auto strType = ArrayType::get(&getContext(), charTy, 1 + str.size());
+ auto tmpString = cir::GlobalOp::create(
+ globalBuilder, loc, (".str" + str).str(), strType,
+ /*isConstant=*/true, {},
+ /*linkage=*/cir::GlobalLinkageKind::PrivateLinkage);
+ tmpString.setInitialValueAttr(
+ ConstArrayAttr::get(strType, StringAttr::get(str + "\0", strType)));
+ tmpString.setPrivate();
+ return tmpString;
+ };
+
+ mlir::Value fatbinHandle = *regGlobalFunc.args_begin();
+
+ for (auto &[global, regAttr] : cudaDeviceVars) {
+ switch (regAttr.getKind()) {
+ case cir::CUDADeviceVarKind::Variable:
+ break;
+ case cir::CUDADeviceVarKind::Surface:
+ llvm_unreachable("Surface registration NYI");
+ case cir::CUDADeviceVarKind::Texture:
+ llvm_unreachable("Texture registration NYI");
+ }
+
+ if (regAttr.getIsManaged())
+ llvm_unreachable("Managed variable registration NYI");
+
+ GlobalOp deviceNameStr = makeConstantString(regAttr.getDeviceSideName());
+ mlir::Value deviceName = builder.createBitcast(
+ builder.createGetGlobal(deviceNameStr), voidPtrTy);
+ mlir::Value hostVar =
+ builder.createBitcast(builder.createGetGlobal(global), voidPtrTy);
+
+ auto isExtern = ConstantOp::create(
+ builder, loc, IntAttr::get(intTy, regAttr.getIsExtern() ? 1 : 0));
+ llvm::TypeSize size = dataLayout.getTypeSizeInBits(global.getSymType());
+ auto varSize = ConstantOp::create(
+ builder, loc, IntAttr::get(sizeTy, size.getFixedValue() / 8));
----------------
RiverDave wrote:
Great eye for detail, thanks!
https://github.com/llvm/llvm-project/pull/199270
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits