https://github.com/clementval created https://github.com/llvm/llvm-project/pull/131396
Convert the `cuf.shared_memory` operation to an `llvm.addressof` operation plus an `llvm.getelementptr` with the appropriate offset.

>From bd44073dc01ff6b5dce02490eeceb8a52f2130b4 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clement...@gmail.com>
Date: Fri, 14 Mar 2025 14:40:07 -0700
Subject: [PATCH] [flang][cuda] Convert cuf.shared_memory operation to LLVM ops

---
 .../flang/Optimizer/Builder/CUFCommon.h      |  1 +
 .../Transforms/CUFGPUToLLVMConversion.cpp    | 70 ++++++++++++++++++-
 flang/test/Fir/CUDA/cuda-shared-to-llvm.mlir | 20 ++++++
 3 files changed, 90 insertions(+), 1 deletion(-)
 create mode 100644 flang/test/Fir/CUDA/cuda-shared-to-llvm.mlir

diff --git a/flang/include/flang/Optimizer/Builder/CUFCommon.h b/flang/include/flang/Optimizer/Builder/CUFCommon.h
index e3c7b5098b83f..65b9cce1d2021 100644
--- a/flang/include/flang/Optimizer/Builder/CUFCommon.h
+++ b/flang/include/flang/Optimizer/Builder/CUFCommon.h
@@ -14,6 +14,7 @@
 #include "mlir/IR/BuiltinOps.h"
 
 static constexpr llvm::StringRef cudaDeviceModuleName = "cuda_device_mod";
+static constexpr llvm::StringRef cudaSharedMemSuffix = "__shared_mem";
 
 namespace fir {
 class FirOpBuilder;
diff --git a/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp b/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp
index 2a95e41944f3f..b54332b6694c4 100644
--- a/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp
@@ -7,12 +7,15 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Optimizer/Transforms/CUFGPUToLLVMConversion.h"
+#include "flang/Optimizer/Builder/CUFCommon.h"
 #include "flang/Optimizer/CodeGen/TypeConverter.h"
+#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
 #include "flang/Optimizer/Support/DataLayout.h"
 #include "flang/Runtime/CUDA/common.h"
 #include "flang/Support/Fortran.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -175,6 +178,69 @@ struct GPULaunchKernelConversion
   }
 };
 
+static std::string getFuncName(cuf::SharedMemoryOp op) {
+  if (auto gpuFuncOp = op->getParentOfType<mlir::gpu::GPUFuncOp>())
+    return gpuFuncOp.getName().str();
+  if (auto funcOp = op->getParentOfType<mlir::func::FuncOp>())
+    return funcOp.getName().str();
+  if (auto llvmFuncOp = op->getParentOfType<mlir::LLVM::LLVMFuncOp>())
+    return llvmFuncOp.getSymName().str();
+  return "";
+}
+
+static mlir::Value createAddressOfOp(mlir::ConversionPatternRewriter &rewriter,
+                                     mlir::Location loc,
+                                     gpu::GPUModuleOp gpuMod,
+                                     std::string &sharedGlobalName) {
+  auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(
+      rewriter.getContext(), mlir::NVVM::NVVMMemorySpace::kSharedMemorySpace);
+  if (auto g = gpuMod.lookupSymbol<fir::GlobalOp>(sharedGlobalName))
+    return rewriter.create<mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
+                                                    g.getSymName());
+  if (auto g = gpuMod.lookupSymbol<mlir::LLVM::GlobalOp>(sharedGlobalName))
+    return rewriter.create<mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
+                                                    g.getSymName());
+  return {};
+}
+
+struct CUFSharedMemoryOpConversion
+    : public mlir::ConvertOpToLLVMPattern<cuf::SharedMemoryOp> {
+  explicit CUFSharedMemoryOpConversion(
+      const fir::LLVMTypeConverter &typeConverter, mlir::PatternBenefit benefit)
+      : mlir::ConvertOpToLLVMPattern<cuf::SharedMemoryOp>(typeConverter,
+                                                          benefit) {}
+  using OpAdaptor = typename cuf::SharedMemoryOp::Adaptor;
+
+  mlir::LogicalResult
+  matchAndRewrite(cuf::SharedMemoryOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    mlir::Location loc = op->getLoc();
+    if (!op.getOffset())
+      mlir::emitError(loc,
+                      "cuf.shared_memory must have an offset for code gen");
+
+    auto gpuMod = op->getParentOfType<gpu::GPUModuleOp>();
+    std::string sharedGlobalName =
+        (getFuncName(op) + llvm::Twine(cudaSharedMemSuffix)).str();
+    mlir::Value sharedGlobalAddr =
+        createAddressOfOp(rewriter, loc, gpuMod, sharedGlobalName);
+
+    if (!sharedGlobalAddr)
+      mlir::emitError(loc, "Could not find the shared global operation\n");
+
+    auto castPtr = rewriter.create<mlir::LLVM::AddrSpaceCastOp>(
+        loc, mlir::LLVM::LLVMPointerType::get(rewriter.getContext()),
+        sharedGlobalAddr);
+    mlir::Type baseType = castPtr->getResultTypes().front();
+    llvm::SmallVector<mlir::LLVM::GEPArg> gepArgs = {
+        static_cast<int32_t>(*op.getOffset())};
+    mlir::Value shmemPtr = rewriter.create<mlir::LLVM::GEPOp>(
+        loc, baseType, rewriter.getI8Type(), castPtr, gepArgs);
+    rewriter.replaceOp(op, {shmemPtr});
+    return mlir::success();
+  }
+};
+
 class CUFGPUToLLVMConversion
     : public fir::impl::CUFGPUToLLVMConversionBase<CUFGPUToLLVMConversion> {
 public:
@@ -194,6 +260,7 @@ class CUFGPUToLLVMConversion
         /*forceUnifiedTBAATree=*/false, *dl);
     cuf::populateCUFGPUToLLVMConversionPatterns(typeConverter, patterns);
     target.addIllegalOp<mlir::gpu::LaunchFuncOp>();
+    target.addIllegalOp<cuf::SharedMemoryOp>();
     target.addLegalDialect<mlir::LLVM::LLVMDialect>();
     if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
                                                   std::move(patterns)))) {
@@ -208,5 +275,6 @@ class CUFGPUToLLVMConversion
 void cuf::populateCUFGPUToLLVMConversionPatterns(
     const fir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns,
     mlir::PatternBenefit benefit) {
-  patterns.add<GPULaunchKernelConversion>(converter, benefit);
+  patterns.add<CUFSharedMemoryOpConversion, GPULaunchKernelConversion>(
+      converter, benefit);
 }
diff --git a/flang/test/Fir/CUDA/cuda-shared-to-llvm.mlir b/flang/test/Fir/CUDA/cuda-shared-to-llvm.mlir
new file mode 100644
index 0000000000000..478ca92b63b60
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-shared-to-llvm.mlir
@@ -0,0 +1,20 @@
+// RUN: fir-opt --split-input-file --cuf-gpu-convert-to-llvm %s | FileCheck %s
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", gpu.container_module, llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.ident = "flang version 20.0.0 (https://github.com/llvm/llvm-project.git cae351f3453a0a26ec8eb2ddaf773c24a29d929e)", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
+  gpu.module @cuda_device_mod {
+    llvm.func @_QPshared_static() {
+      %0 = cuf.shared_memory i32 {bindc_name = "a", offset = 0 : i32, uniq_name = "_QFshared_staticEa"} -> !fir.ref<i32>
+      %1 = cuf.shared_memory i32 {bindc_name = "b", offset = 4 : i32, uniq_name = "_QFshared_staticEb"} -> !fir.ref<i32>
+      llvm.return
+    }
+    llvm.mlir.global common @_QPshared_static__shared_mem(dense<0> : vector<28xi8>) {addr_space = 3 : i32, alignment = 8 : i64} : !llvm.array<28 x i8>
+  }
+}
+
+// CHECK-LABEL: llvm.func @_QPshared_static()
+// CHECK: %[[ADDR0:.*]] = llvm.mlir.addressof @_QPshared_static__shared_mem : !llvm.ptr<3>
+// CHECK: %[[ADDRCAST0:.*]] = llvm.addrspacecast %[[ADDR0]] : !llvm.ptr<3> to !llvm.ptr
+// CHECK: %[[A:.*]] = llvm.getelementptr %[[ADDRCAST0]][0] : (!llvm.ptr) -> !llvm.ptr, i8
+// CHECK: %[[ADDR1:.*]] = llvm.mlir.addressof @_QPshared_static__shared_mem : !llvm.ptr<3>
+// CHECK: %[[ADDRCAST1:.*]] = llvm.addrspacecast %[[ADDR1]] : !llvm.ptr<3> to !llvm.ptr
+// CHECK: %[[B:.*]] = llvm.getelementptr %[[ADDRCAST1]][4] : (!llvm.ptr) -> !llvm.ptr, i8
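For reviewers skimming the flattened diff: the net effect of the new CUFSharedMemoryOpConversion pattern, as exercised by the test case above, is to rewrite each `cuf.shared_memory` op into an `llvm.mlir.addressof` of the per-function `__shared_mem` global (assumed to already exist in the GPU module, as it does in the test), followed by an `llvm.addrspacecast` and an i8 `llvm.getelementptr` at the op's recorded offset. A minimal before/after sketch taken from the test (SSA value names are illustrative):

  // Before (inside the kernel)
  %0 = cuf.shared_memory i32 {bindc_name = "a", offset = 0 : i32, uniq_name = "_QFshared_staticEa"} -> !fir.ref<i32>

  // After lowering (offset 0 into the shared-memory global)
  %1 = llvm.mlir.addressof @_QPshared_static__shared_mem : !llvm.ptr<3>
  %2 = llvm.addrspacecast %1 : !llvm.ptr<3> to !llvm.ptr
  %3 = llvm.getelementptr %2[0] : (!llvm.ptr) -> !llvm.ptr, i8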
"x86_64-unknown-linux-gnu"} { + gpu.module @cuda_device_mod { + llvm.func @_QPshared_static() { + %0 = cuf.shared_memory i32 {bindc_name = "a", offset = 0 : i32, uniq_name = "_QFshared_staticEa"} -> !fir.ref<i32> + %1 = cuf.shared_memory i32 {bindc_name = "b", offset = 4 : i32, uniq_name = "_QFshared_staticEb"} -> !fir.ref<i32> + llvm.return + } + llvm.mlir.global common @_QPshared_static__shared_mem(dense<0> : vector<28xi8>) {addr_space = 3 : i32, alignment = 8 : i64} : !llvm.array<28 x i8> + } +} + +// CHECK-LABEL: llvm.func @_QPshared_static() +// CHECK: %[[ADDR0:.*]] = llvm.mlir.addressof @_QPshared_static__shared_mem : !llvm.ptr<3> +// CHECK: %[[ADDRCAST0:.*]] = llvm.addrspacecast %[[ADDR0]] : !llvm.ptr<3> to !llvm.ptr +// CHECK: %[[A:.*]] = llvm.getelementptr %[[ADDRCAST0]][0] : (!llvm.ptr) -> !llvm.ptr, i8 +// CHECK: %[[ADDR1:.*]] = llvm.mlir.addressof @_QPshared_static__shared_mem : !llvm.ptr<3> +// CHECK: %[[ADDRCAST1:.*]] = llvm.addrspacecast %[[ADDR1]] : !llvm.ptr<3> to !llvm.ptr +// CHECK: %[[B:.*]] = llvm.getelementptr %[[ADDRCAST1]][4] : (!llvm.ptr) -> !llvm.ptr, i8 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits