https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/161862
>From 9ad453c194717784f0104e91d0decf34ae9110ea Mon Sep 17 00:00:00 2001 From: Sergio Afonso <[email protected]> Date: Fri, 12 Sep 2025 15:56:04 +0100 Subject: [PATCH 1/3] [Flang][MLIR][OpenMP] Add explicit shared memory (de-)allocation ops This patch introduces the `omp.alloc_shared_mem` and `omp.free_shared_mem` operations to represent explicit allocations and deallocations of shared memory across threads in a team, mirroring the existing `omp.target_allocmem` and `omp.target_freemem`. The `omp.alloc_shared_mem` op goes through the same Flang-specific transformations as `omp.target_allocmem`, so that the size of the buffer can be properly calculated when translating to LLVM IR. The corresponding runtime functions produced for these new operations are `__kmpc_alloc_shared` and `__kmpc_free_shared`, which previously could only be created for implicit allocations (e.g. privatized and reduction variables). --- flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp | 42 +++++++----- .../llvm/Frontend/OpenMP/OMPIRBuilder.h | 23 +++++++ llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 29 ++++++--- mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 61 ++++++++++++++++++ mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 11 ++++ .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 64 ++++++++++++++++--- mlir/test/Dialect/OpenMP/invalid.mlir | 21 ++++++ mlir/test/Dialect/OpenMP/ops.mlir | 31 ++++++++- 8 files changed, 249 insertions(+), 33 deletions(-) diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp index 3e1fe1d2b1613..13214a9e51161 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp @@ -222,36 +222,47 @@ static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter, return converter.convertType(firType); } -// FIR Op specific conversion for TargetAllocMemOp -struct TargetAllocMemOpConversion - : public OpenMPFIROpConversion<mlir::omp::TargetAllocMemOp> { - using 
OpenMPFIROpConversion::OpenMPFIROpConversion; +// FIR Op specific conversion for allocation operations +template <typename T> +struct AllocMemOpConversion : public OpenMPFIROpConversion<T> { + using OpenMPFIROpConversion<T>::OpenMPFIROpConversion; llvm::LogicalResult - matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor, + matchAndRewrite(T allocmemOp, + typename OpenMPFIROpConversion<T>::OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::Type heapTy = allocmemOp.getAllocatedType(); mlir::Location loc = allocmemOp.getLoc(); - auto ity = lowerTy().indexType(); + auto ity = OpenMPFIROpConversion<T>::lowerTy().indexType(); mlir::Type dataTy = fir::unwrapRefType(heapTy); - mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy); + mlir::Type llvmObjectTy = + convertObjectType(OpenMPFIROpConversion<T>::lowerTy(), dataTy); if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy))) - TODO(loc, "omp.target_allocmem codegen of derived type with length " - "parameters"); + TODO(loc, allocmemOp->getName().getStringRef() + + " codegen of derived type with length parameters"); mlir::Value size = fir::computeElementDistance( - loc, llvmObjectTy, ity, rewriter, lowerTy().getDataLayout()); + loc, llvmObjectTy, ity, rewriter, + OpenMPFIROpConversion<T>::lowerTy().getDataLayout()); if (auto scaleSize = fir::genAllocationScaleSize( loc, allocmemOp.getInType(), ity, rewriter)) size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, scaleSize); - for (mlir::Value opnd : adaptor.getOperands().drop_front()) + for (mlir::Value opnd : adaptor.getTypeparams()) + size = mlir::LLVM::MulOp::create( + rewriter, loc, ity, size, + integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter, ity, + opnd)); + for (mlir::Value opnd : adaptor.getShape()) size = mlir::LLVM::MulOp::create( rewriter, loc, ity, size, - integerCast(lowerTy(), loc, rewriter, ity, opnd)); - auto mallocTyWidth = lowerTy().getIndexTypeBitwidth(); 
+ integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter, ity, + opnd)); + auto mallocTyWidth = + OpenMPFIROpConversion<T>::lowerTy().getIndexTypeBitwidth(); auto mallocTy = mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth); if (mallocTyWidth != ity.getIntOrFloatBitWidth()) - size = integerCast(lowerTy(), loc, rewriter, mallocTy, size); + size = integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter, + mallocTy, size); rewriter.modifyOpInPlace(allocmemOp, [&]() { allocmemOp.setInType(rewriter.getI8Type()); allocmemOp.getTypeparamsMutable().clear(); @@ -281,6 +292,7 @@ void fir::populateOpenMPFIRToLLVMConversionPatterns( const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) { patterns.add<MapInfoOpConversion>(converter); patterns.add<PrivateClauseOpConversion>(converter); - patterns.add<TargetAllocMemOpConversion>(converter); patterns.add<DeclareMapperOpConversion>(converter); + patterns.add<AllocMemOpConversion<mlir::omp::TargetAllocMemOp>, + AllocMemOpConversion<mlir::omp::AllocSharedMemOp>>(converter); } diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index d73ed61c42235..5d7b446767d94 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -3231,6 +3231,17 @@ class OpenMPIRBuilder { LLVM_ABI CallInst *createOMPFree(const LocationDescription &Loc, Value *Addr, Value *Allocator, std::string Name = ""); + /// Create a runtime call for kmpc_alloc_shared. + /// + /// \param Loc The insert and source location description. + /// \param Size Size of allocated memory space. + /// \param Name Name of call Instruction. + /// + /// \returns CallInst to the kmpc_alloc_shared call. + LLVM_ABI CallInst *createOMPAllocShared(const LocationDescription &Loc, + Value *Size, + const Twine &Name = Twine("")); + /// Create a runtime call for kmpc_alloc_shared. 
/// /// \param Loc The insert and source location description. @@ -3242,6 +3253,18 @@ class OpenMPIRBuilder { Type *VarType, const Twine &Name = Twine("")); + /// Create a runtime call for kmpc_free_shared. + /// + /// \param Loc The insert and source location description. + /// \param Addr Value obtained from the corresponding kmpc_alloc_shared call. + /// \param Size Size of allocated memory space. + /// \param Name Name of call Instruction. + /// + /// \returns CallInst to the kmpc_free_shared call. + LLVM_ABI CallInst *createOMPFreeShared(const LocationDescription &Loc, + Value *Addr, Value *Size, + const Twine &Name = Twine("")); + /// Create a runtime call for kmpc_free_shared. /// /// \param Loc The insert and source location description. diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 481d3f4b21df5..14a058370fc1c 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -7836,32 +7836,45 @@ CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc, } CallInst *OpenMPIRBuilder::createOMPAllocShared(const LocationDescription &Loc, - Type *VarType, + Value *Size, const Twine &Name) { IRBuilder<>::InsertPointGuard IPG(Builder); updateToLocation(Loc); - const DataLayout &DL = M.getDataLayout(); - Value *Args[] = {Builder.getInt64(DL.getTypeAllocSize(VarType))}; + Value *Args[] = {Size}; Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_alloc_shared); CallInst *Call = Builder.CreateCall(Fn, Args, Name); - Call->addRetAttr( - Attribute::getWithAlignment(M.getContext(), DL.getPrefTypeAlign(Int64))); + Call->addRetAttr(Attribute::getWithAlignment( + M.getContext(), M.getDataLayout().getPrefTypeAlign(Int64))); return Call; } +CallInst *OpenMPIRBuilder::createOMPAllocShared(const LocationDescription &Loc, + Type *VarType, + const Twine &Name) { + return createOMPAllocShared( + Loc, 
Builder.getInt64(M.getDataLayout().getTypeAllocSize(VarType)), Name); +} + CallInst *OpenMPIRBuilder::createOMPFreeShared(const LocationDescription &Loc, - Value *Addr, Type *VarType, + Value *Addr, Value *Size, const Twine &Name) { IRBuilder<>::InsertPointGuard IPG(Builder); updateToLocation(Loc); - Value *Args[] = { - Addr, Builder.getInt64(M.getDataLayout().getTypeAllocSize(VarType))}; + Value *Args[] = {Addr, Size}; Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_free_shared); return Builder.CreateCall(Fn, Args, Name); } +CallInst *OpenMPIRBuilder::createOMPFreeShared(const LocationDescription &Loc, + Value *Addr, Type *VarType, + const Twine &Name) { + return createOMPFreeShared( + Loc, Addr, Builder.getInt64(M.getDataLayout().getTypeAllocSize(VarType)), + Name); +} + CallInst *OpenMPIRBuilder::createOMPInteropInit( const LocationDescription &Loc, Value *InteropVar, omp::OMPInteropType InteropType, Value *Device, Value *NumDependences, diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index a09b964a63901..c89bdef9d4d13 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -2235,6 +2235,67 @@ def TargetFreeMemOp : OpenMP_Op<"target_freemem", let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))"; } +//===----------------------------------------------------------------------===// +// AllocSharedMemOp +//===----------------------------------------------------------------------===// + +def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [ + AttrSizedOperandSegments + ], clauses = [ + OpenMP_HeapAllocClause + ]> { + let summary = "allocate storage on shared memory for an object of a given type"; + + let description = [{ + Allocates memory shared across threads of a team for an object of the given + type. Returns a pointer representing the allocated memory. 
The memory is + uninitialized after allocation. Operations must be paired with + `omp.free_shared` to avoid memory leaks. + + ```mlir + // Allocate a static 3x3 integer vector. + %ptr_shared = omp.alloc_shared_mem vector<3x3xi32> : !llvm.ptr + // ... + omp.free_shared_mem %ptr_shared : !llvm.ptr + ``` + }] # clausesDescription; + + let results = (outs OpenMP_PointerLikeType); + let assemblyFormat = clausesAssemblyFormat # " attr-dict `:` type(results)"; +} + +//===----------------------------------------------------------------------===// +// FreeSharedMemOp +//===----------------------------------------------------------------------===// + +def FreeSharedMemOp : OpenMP_Op<"free_shared_mem", [MemoryEffects<[MemFree]>]> { + let summary = "free shared memory"; + + let description = [{ + Deallocates shared memory that was previously allocated by an + `omp.alloc_shared_mem` operation. After this operation, the deallocated + memory is in an undefined state and should not be accessed. + It is crucial to ensure that all accesses to the memory region are completed + before `omp.free_shared_mem` is called to avoid undefined behavior. + + ```mlir + // Example of allocating and freeing shared memory. + %ptr_shared = omp.alloc_shared_mem vector<3x3xi32> : !llvm.ptr + // ... + omp.free_shared_mem %ptr_shared : !llvm.ptr + ``` + + The `heapref` operand represents the pointer to shared memory to be + deallocated, previously returned by `omp.alloc_shared_mem`.
+ }]; + + let arguments = (ins + Arg<OpenMP_PointerLikeType, "", [MemFree]>:$heapref + ); + let assemblyFormat = "$heapref attr-dict `:` type($heapref)"; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // workdistribute Construct //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 2aa63b498f3fa..11b0ca53a9ac1 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -4496,6 +4496,17 @@ LogicalResult AllocateDirOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// FreeSharedMemOp +//===----------------------------------------------------------------------===// + +LogicalResult FreeSharedMemOp::verify() { + return getHeapref().getDefiningOp<AllocSharedMemOp>() + ? success() + : emitOpError() << "'heapref' operand must be defined by an " + "'omp.alloc_shared_memory' op"; +} + //===----------------------------------------------------------------------===// // WorkdistributeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 74d458733dd03..1f2df60c8765b 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -7533,6 +7533,25 @@ static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder, return func; } +static llvm::Value * +getAllocationSize(llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, Type allocatedTy, + OperandRange typeparams, OperandRange shape) { + llvm::DataLayout dataLayout = + 
moduleTranslation.getLLVMModule()->getDataLayout(); + llvm::Type *llvmHeapTy = moduleTranslation.convertType(allocatedTy); + llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy); + llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue()); + for (auto typeParam : typeparams) { + allocSize = builder.CreateMul( + allocSize, + builder.CreateIntCast(moduleTranslation.lookupValue(typeParam), + builder.getInt64Ty(), + /*isSigned=*/false)); + } + return allocSize; +} + static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { @@ -7547,14 +7566,9 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, mlir::Value deviceNum = allocMemOp.getDevice(); llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum); // Get the allocation size. - llvm::DataLayout dataLayout = llvmModule->getDataLayout(); - mlir::Type heapTy = allocMemOp.getAllocatedType(); - llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy); - llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy); - llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue()); - for (auto typeParam : allocMemOp.getTypeparams()) - allocSize = - builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam)); + llvm::Value *allocSize = getAllocationSize( + builder, moduleTranslation, allocMemOp.getAllocatedType(), + allocMemOp.getTypeparams(), allocMemOp.getShape()); // Create call to "omp_target_alloc" with the args as translated llvm values. 
llvm::CallInst *call = builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum}); @@ -7565,6 +7579,19 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, return success(); } +static LogicalResult +convertAllocSharedMemOp(omp::AllocSharedMemOp allocMemOp, + llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + llvm::Value *size = getAllocationSize( + builder, moduleTranslation, allocMemOp.getAllocatedType(), + allocMemOp.getTypeparams(), allocMemOp.getShape()); + moduleTranslation.mapValue(allocMemOp.getResult(), + ompBuilder->createOMPAllocShared(builder, size)); + return success(); +} + static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule) { llvm::Type *ptrTy = builder.getPtrTy(0); @@ -7600,6 +7627,21 @@ convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, return success(); } +static LogicalResult +convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp, + llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + auto allocMemOp = + freeMemOp.getHeapref().getDefiningOp<omp::AllocSharedMemOp>(); + llvm::Value *size = getAllocationSize( + builder, moduleTranslation, allocMemOp.getAllocatedType(), + allocMemOp.getTypeparams(), allocMemOp.getShape()); + ompBuilder->createOMPFreeShared( + builder, moduleTranslation.lookupValue(freeMemOp.getHeapref()), size); + return success(); +} + /// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including /// OpenMP runtime calls). 
LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation( @@ -7795,6 +7837,12 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation( .Case([&](omp::TargetFreeMemOp) { return convertTargetFreeMemOp(*op, builder, moduleTranslation); }) + .Case([&](omp::AllocSharedMemOp op) { + return convertAllocSharedMemOp(op, builder, moduleTranslation); + }) + .Case([&](omp::FreeSharedMemOp op) { + return convertFreeSharedMemOp(op, builder, moduleTranslation); + }) .Default([&](Operation *inst) { return inst->emitError() << "not yet implemented: " << inst->getName(); diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 11ed6b94e0053..5a6cdb6599f8d 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -3263,3 +3263,24 @@ func.func @target_allocmem_invalid_bindc_name(%device : i32) -> () { %0 = omp.target_allocmem %device : i32, i64 {bindc_name=2} return } + +// ----- +func.func @alloc_shared_mem_invalid_uniq_name() -> () { + // expected-error @below {{op attribute 'uniq_name' failed to satisfy constraint: string attribute}} + %0 = omp.alloc_shared_mem i64 {uniq_name=2} + return +} + +// ----- +func.func @alloc_shared_mem_invalid_bindc_name() -> () { + // expected-error @below {{op attribute 'bindc_name' failed to satisfy constraint: string attribute}} + %0 = omp.alloc_shared_mem i64 {bindc_name=2} + return +} + +// ----- +func.func @free_shared_mem_invalid_ptr(%ptr : !llvm.ptr) -> () { + // expected-error @below {{op 'heapref' operand must be defined by an 'omp.alloc_shared_memory' op}} + omp.free_shared_mem %ptr : !llvm.ptr + return +} diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 6b2e79e0425d0..ff48673c1d7f0 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -3690,9 +3690,36 @@ func.func @omp_target_allocmem(%device: i32, %x: index, %y: index, %z: i32) { } // CHECK-LABEL: 
func.func @omp_target_freemem( -// CHECK-SAME: %[[DEVICE:.*]]: i32, %[[PTR:.*]]: i64) { -func.func @omp_target_freemem(%device : i32, %ptr : i64) { +// CHECK-SAME: %[[DEVICE:.*]]: i32) { +func.func @omp_target_freemem(%device : i32) { + // CHECK: %[[PTR:.*]] = omp.target_allocmem + %ptr = omp.target_allocmem %device : i32, i64 // CHECK: omp.target_freemem %[[DEVICE]], %[[PTR]] : i32, i64 omp.target_freemem %device, %ptr : i32, i64 return } + +// CHECK-LABEL: func.func @omp_alloc_shared_mem( +// CHECK-SAME: %[[X:.*]]: index, %[[Y:.*]]: index, %[[Z:.*]]: i32) { +func.func @omp_alloc_shared_mem(%x: index, %y: index, %z: i32) { + // CHECK: %{{.*}} = omp.alloc_shared_mem i64 : !llvm.ptr + %0 = omp.alloc_shared_mem i64 : !llvm.ptr + // CHECK: %{{.*}} = omp.alloc_shared_mem vector<16x16xf32> {bindc_name = "bindc", uniq_name = "uniq"} : !llvm.ptr + %1 = omp.alloc_shared_mem vector<16x16xf32> {uniq_name="uniq", bindc_name="bindc"} : !llvm.ptr + // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32) : !llvm.ptr + %2 = omp.alloc_shared_mem !llvm.ptr(%x, %y, %z : index, index, i32) : !llvm.ptr + // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr, %[[X]], %[[Y]] : !llvm.ptr + %3 = omp.alloc_shared_mem !llvm.ptr, %x, %y : !llvm.ptr + // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32), %[[X]], %[[Y]] : !llvm.ptr + %4 = omp.alloc_shared_mem !llvm.ptr(%x, %y, %z : index, index, i32), %x, %y : !llvm.ptr + return +} + +// CHECK-LABEL: func.func @omp_free_shared_mem() { +func.func @omp_free_shared_mem() { + // CHECK: %[[PTR:.*]] = omp.alloc_shared_mem + %0 = omp.alloc_shared_mem i64 : !llvm.ptr + // CHECK: omp.free_shared_mem %[[PTR]] : !llvm.ptr + omp.free_shared_mem %0 : !llvm.ptr + return +} >From 5f302872ff32b1270a14e38f2c0197fd43c374ba Mon Sep 17 00:00:00 2001 From: Sergio Afonso <[email protected]> Date: Thu, 5 Feb 2026 13:06:23 +0000 Subject: [PATCH 2/3] simplify omp.alloc_shared_mem --- 
flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp | 42 +++++++----------- mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 37 ++++++++++++---- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 23 +++++++--- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 44 +++++++++++++------ mlir/test/Dialect/OpenMP/invalid.mlir | 19 +++++--- mlir/test/Dialect/OpenMP/ops.mlir | 35 +++++++-------- .../LLVMIR/omptarget-device-shared-mem.mlir | 42 ++++++++++++++++++ 7 files changed, 160 insertions(+), 82 deletions(-) create mode 100644 mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp index 13214a9e51161..3e1fe1d2b1613 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp @@ -222,47 +222,36 @@ static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter, return converter.convertType(firType); } -// FIR Op specific conversion for allocation operations -template <typename T> -struct AllocMemOpConversion : public OpenMPFIROpConversion<T> { - using OpenMPFIROpConversion<T>::OpenMPFIROpConversion; +// FIR Op specific conversion for TargetAllocMemOp +struct TargetAllocMemOpConversion + : public OpenMPFIROpConversion<mlir::omp::TargetAllocMemOp> { + using OpenMPFIROpConversion::OpenMPFIROpConversion; llvm::LogicalResult - matchAndRewrite(T allocmemOp, - typename OpenMPFIROpConversion<T>::OpAdaptor adaptor, + matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::Type heapTy = allocmemOp.getAllocatedType(); mlir::Location loc = allocmemOp.getLoc(); - auto ity = OpenMPFIROpConversion<T>::lowerTy().indexType(); + auto ity = lowerTy().indexType(); mlir::Type dataTy = fir::unwrapRefType(heapTy); - mlir::Type llvmObjectTy = - convertObjectType(OpenMPFIROpConversion<T>::lowerTy(), dataTy); + mlir::Type llvmObjectTy = 
convertObjectType(lowerTy(), dataTy); if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy))) - TODO(loc, allocmemOp->getName().getStringRef() + - " codegen of derived type with length parameters"); + TODO(loc, "omp.target_allocmem codegen of derived type with length " + "parameters"); mlir::Value size = fir::computeElementDistance( - loc, llvmObjectTy, ity, rewriter, - OpenMPFIROpConversion<T>::lowerTy().getDataLayout()); + loc, llvmObjectTy, ity, rewriter, lowerTy().getDataLayout()); if (auto scaleSize = fir::genAllocationScaleSize( loc, allocmemOp.getInType(), ity, rewriter)) size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, scaleSize); - for (mlir::Value opnd : adaptor.getTypeparams()) - size = mlir::LLVM::MulOp::create( - rewriter, loc, ity, size, - integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter, ity, - opnd)); - for (mlir::Value opnd : adaptor.getShape()) + for (mlir::Value opnd : adaptor.getOperands().drop_front()) size = mlir::LLVM::MulOp::create( rewriter, loc, ity, size, - integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter, ity, - opnd)); - auto mallocTyWidth = - OpenMPFIROpConversion<T>::lowerTy().getIndexTypeBitwidth(); + integerCast(lowerTy(), loc, rewriter, ity, opnd)); + auto mallocTyWidth = lowerTy().getIndexTypeBitwidth(); auto mallocTy = mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth); if (mallocTyWidth != ity.getIntOrFloatBitWidth()) - size = integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter, - mallocTy, size); + size = integerCast(lowerTy(), loc, rewriter, mallocTy, size); rewriter.modifyOpInPlace(allocmemOp, [&]() { allocmemOp.setInType(rewriter.getI8Type()); allocmemOp.getTypeparamsMutable().clear(); @@ -292,7 +281,6 @@ void fir::populateOpenMPFIRToLLVMConversionPatterns( const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) { patterns.add<MapInfoOpConversion>(converter); patterns.add<PrivateClauseOpConversion>(converter); + 
patterns.add<TargetAllocMemOpConversion>(converter); patterns.add<DeclareMapperOpConversion>(converter); - patterns.add<AllocMemOpConversion<mlir::omp::TargetAllocMemOp>, - AllocMemOpConversion<mlir::omp::AllocSharedMemOp>>(converter); } diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index c89bdef9d4d13..86411d71a5787 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -2240,11 +2240,15 @@ def TargetFreeMemOp : OpenMP_Op<"target_freemem", //===----------------------------------------------------------------------===// def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [ - AttrSizedOperandSegments - ], clauses = [ - OpenMP_HeapAllocClause + MemoryEffects<[MemAlloc<DefaultResource>]> ]> { - let summary = "allocate storage on shared memory for an object of a given type"; + let summary = "allocate storage on shared memory for objects of a given type"; + + let arguments = (ins + TypeAttr:$elem_type, + AnySignlessInteger:$array_size, + ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$alignment + ); let description = [{ Allocates memory shared across threads of a team for an object of the given @@ -2253,15 +2257,30 @@ def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [ `omp.free_shared` to avoid memory leaks. ```mlir - // Allocate a static 3x3 integer vector. - %ptr_shared = omp.alloc_shared_mem vector<3x3xi32> : !llvm.ptr + // Allocate an i32 vector with %size elements and aligned to 8 bytes. + %ptr_shared = omp.alloc_shared_mem %size x i32 {alignment = 8} : (i64) -> (!llvm.ptr) // ... omp.free_shared_mem %ptr_shared : !llvm.ptr ``` - }] # clausesDescription; + + The `elem_type` is the type of the object for which memory is being + allocated. + + The `array_size` is the number of objects to allocate memory for. + + The optional `alignment` is used to specify the alignment for each element. 
+ If not set, the `DataLayout` defaults will be used instead. + }]; let results = (outs OpenMP_PointerLikeType); - let assemblyFormat = clausesAssemblyFormat # " attr-dict `:` type(results)"; + let assemblyFormat = [{ + $array_size `x` $elem_type attr-dict `:` `(` type($array_size) `)` `->` type(results) + }]; + + let extraClassDeclaration = [{ + mlir::Type getAllocatedType() { return getElemTypeAttr().getValue(); } + }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -2280,7 +2299,7 @@ def FreeSharedMemOp : OpenMP_Op<"free_shared_mem", [MemoryEffects<[MemFree]>]> { ```mlir // Example of allocating and freeing shared memory. - %ptr_shared = omp.alloc_shared_mem vector<3x3xi32> : !llvm.ptr + %ptr_shared = omp.alloc_shared_mem %size x i32 : (i64) -> (!llvm.ptr) // ... omp.free_shared_mem %ptr_shared : !llvm.ptr ``` diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 11b0ca53a9ac1..446fa1fecc122 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -4483,17 +4483,26 @@ LogicalResult ScanOp::verify() { } /// Verifies align clause in allocate directive +LogicalResult verifyAlignment(Operation &op, + std::optional<uint64_t> alignment) { + if (alignment.has_value()) { + if ((alignment.value() != 0) && !llvm::has_single_bit(alignment.value())) + return op.emitError() + << "ALIGN value : " << alignment.value() << " must be power of 2"; + } + return success(); +} LogicalResult AllocateDirOp::verify() { - std::optional<uint64_t> align = this->getAlign(); + return verifyAlignment(*getOperation(), getAlign()); +} - if (align.has_value()) { - if ((align.value() > 0) && !llvm::has_single_bit(align.value())) - return emitError() << "ALIGN value : " << align.value() - << " must be power of 2"; - } +//===----------------------------------------------------------------------===// +// AllocSharedMemOp 
+//===----------------------------------------------------------------------===// - return success(); +LogicalResult AllocSharedMemOp::verify() { + return verifyAlignment(*getOperation(), getAlignment()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 1f2df60c8765b..74decff68b321 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -7535,14 +7535,14 @@ static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder, static llvm::Value * getAllocationSize(llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation, Type allocatedTy, - OperandRange typeparams, OperandRange shape) { + LLVM::ModuleTranslation &moduleTranslation, + omp::TargetAllocMemOp op) { llvm::DataLayout dataLayout = moduleTranslation.getLLVMModule()->getDataLayout(); - llvm::Type *llvmHeapTy = moduleTranslation.convertType(allocatedTy); - llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy); + llvm::Type *llvmHeapTy = moduleTranslation.convertType(op.getAllocatedType()); + llvm::TypeSize typeSize = dataLayout.getTypeAllocSize(llvmHeapTy); llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue()); - for (auto typeParam : typeparams) { + for (auto typeParam : op.getTypeparams()) { allocSize = builder.CreateMul( allocSize, builder.CreateIntCast(moduleTranslation.lookupValue(typeParam), @@ -7552,6 +7552,27 @@ getAllocationSize(llvm::IRBuilderBase &builder, return allocSize; } +static llvm::Value * +getAllocationSize(llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + omp::AllocSharedMemOp op) { + llvm::DataLayout dataLayout = + moduleTranslation.getLLVMModule()->getDataLayout(); + llvm::Type *llvmHeapTy = 
moduleTranslation.convertType(op.getAllocatedType()); + + auto alignment = op.getAlignment(); + llvm::TypeSize typeSize = llvm::alignTo( + dataLayout.getTypeStoreSize(llvmHeapTy), + alignment ? *alignment : dataLayout.getABITypeAlign(llvmHeapTy).value()); + + llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue()); + return builder.CreateMul( + allocSize, + builder.CreateIntCast(moduleTranslation.lookupValue(op.getArraySize()), + builder.getInt64Ty(), + /*isSigned=*/false)); +} + static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { @@ -7566,9 +7587,8 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, mlir::Value deviceNum = allocMemOp.getDevice(); llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum); // Get the allocation size. - llvm::Value *allocSize = getAllocationSize( - builder, moduleTranslation, allocMemOp.getAllocatedType(), - allocMemOp.getTypeparams(), allocMemOp.getShape()); + llvm::Value *allocSize = + getAllocationSize(builder, moduleTranslation, allocMemOp); // Create call to "omp_target_alloc" with the args as translated llvm values. 
llvm::CallInst *call = builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum}); @@ -7584,9 +7604,7 @@ convertAllocSharedMemOp(omp::AllocSharedMemOp allocMemOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); - llvm::Value *size = getAllocationSize( - builder, moduleTranslation, allocMemOp.getAllocatedType(), - allocMemOp.getTypeparams(), allocMemOp.getShape()); + llvm::Value *size = getAllocationSize(builder, moduleTranslation, allocMemOp); moduleTranslation.mapValue(allocMemOp.getResult(), ompBuilder->createOMPAllocShared(builder, size)); return success(); @@ -7634,9 +7652,7 @@ convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp, llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); auto allocMemOp = freeMemOp.getHeapref().getDefiningOp<omp::AllocSharedMemOp>(); - llvm::Value *size = getAllocationSize( - builder, moduleTranslation, allocMemOp.getAllocatedType(), - allocMemOp.getTypeparams(), allocMemOp.getShape()); + llvm::Value *size = getAllocationSize(builder, moduleTranslation, allocMemOp); ompBuilder->createOMPFreeShared( builder, moduleTranslation.lookupValue(freeMemOp.getHeapref()), size); return success(); diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 5a6cdb6599f8d..c3f408761841c 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -3265,16 +3265,23 @@ func.func @target_allocmem_invalid_bindc_name(%device : i32) -> () { } // ----- -func.func @alloc_shared_mem_invalid_uniq_name() -> () { - // expected-error @below {{op attribute 'uniq_name' failed to satisfy constraint: string attribute}} - %0 = omp.alloc_shared_mem i64 {uniq_name=2} +func.func @alloc_shared_mem_invalid_alignment1(%n: i32) -> () { + // expected-error @below {{op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is 
positive}} + %0 = omp.alloc_shared_mem %n x i64 {alignment=-2} : (i32) -> !llvm.ptr return } // ----- -func.func @alloc_shared_mem_invalid_bindc_name() -> () { - // expected-error @below {{op attribute 'bindc_name' failed to satisfy constraint: string attribute}} - %0 = omp.alloc_shared_mem i64 {bindc_name=2} +func.func @alloc_shared_mem_invalid_alignment2(%n: i32) -> () { + // expected-error @below {{ALIGN value : 3 must be power of 2}} + %0 = omp.alloc_shared_mem %n x i64 {alignment=3} : (i32) -> !llvm.ptr + return +} + +// ----- +func.func @alloc_shared_mem_invalid_array_size(%n: f32) -> () { + // expected-error @below {{invalid kind of type specified: expected builtin.integer, but found 'f32'}} + %0 = omp.alloc_shared_mem %n x i64 : (f32) -> !llvm.ptr return } diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index ff48673c1d7f0..c5bf2e2bd5818 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -3700,25 +3700,22 @@ func.func @omp_target_freemem(%device : i32) { } // CHECK-LABEL: func.func @omp_alloc_shared_mem( -// CHECK-SAME: %[[X:.*]]: index, %[[Y:.*]]: index, %[[Z:.*]]: i32) { -func.func @omp_alloc_shared_mem(%x: index, %y: index, %z: i32) { - // CHECK: %{{.*}} = omp.alloc_shared_mem i64 : !llvm.ptr - %0 = omp.alloc_shared_mem i64 : !llvm.ptr - // CHECK: %{{.*}} = omp.alloc_shared_mem vector<16x16xf32> {bindc_name = "bindc", uniq_name = "uniq"} : !llvm.ptr - %1 = omp.alloc_shared_mem vector<16x16xf32> {uniq_name="uniq", bindc_name="bindc"} : !llvm.ptr - // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32) : !llvm.ptr - %2 = omp.alloc_shared_mem !llvm.ptr(%x, %y, %z : index, index, i32) : !llvm.ptr - // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr, %[[X]], %[[Y]] : !llvm.ptr - %3 = omp.alloc_shared_mem !llvm.ptr, %x, %y : !llvm.ptr - // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32), %[[X]], %[[Y]] : 
!llvm.ptr - %4 = omp.alloc_shared_mem !llvm.ptr(%x, %y, %z : index, index, i32), %x, %y : !llvm.ptr - return -} - -// CHECK-LABEL: func.func @omp_free_shared_mem() { -func.func @omp_free_shared_mem() { - // CHECK: %[[PTR:.*]] = omp.alloc_shared_mem - %0 = omp.alloc_shared_mem i64 : !llvm.ptr +// CHECK-SAME: %[[N:.*]]: i32) { +func.func @omp_alloc_shared_mem(%n: i32) { + // CHECK: %{{.*}} = omp.alloc_shared_mem %[[N]] x i64 : (i32) -> !llvm.ptr + %0 = omp.alloc_shared_mem %n x i64 : (i32) -> !llvm.ptr + // CHECK: %{{.*}} = omp.alloc_shared_mem %[[N]] x vector<16x16xf32> : (i32) -> !llvm.ptr + %1 = omp.alloc_shared_mem %n x vector<16x16xf32> : (i32) -> !llvm.ptr + // CHECK: %{{.*}} = omp.alloc_shared_mem %[[N]] x !llvm.ptr {alignment = 16 : i64} : (i32) -> !llvm.ptr + %2 = omp.alloc_shared_mem %n x !llvm.ptr {alignment = 16} : (i32) -> !llvm.ptr + return +} + +// CHECK-LABEL: func.func @omp_free_shared_mem( +// CHECK-SAME: %[[N:.*]]: i64) { +func.func @omp_free_shared_mem(%n: i64) { + // CHECK: %[[PTR:.*]] = omp.alloc_shared_mem %[[N]] x f32 : (i64) -> !llvm.ptr + %0 = omp.alloc_shared_mem %n x f32 : (i64) -> !llvm.ptr // CHECK: omp.free_shared_mem %[[PTR]] : !llvm.ptr omp.free_shared_mem %0 : !llvm.ptr return diff --git a/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir b/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir new file mode 100644 index 0000000000000..72b0a2daadfc3 --- /dev/null +++ b/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>>, llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9", llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = 
true} { + // CHECK-LABEL: define void @device_shared_mem( + // CHECK-SAME: i32 %[[N0:.*]], i64 %[[N1:.*]]) + llvm.func @device_shared_mem(%n0: i32, %n1: i64) attributes {omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>} { + // CHECK: %[[CAST_N0:.*]] = zext i32 %[[N0]] to i64 + // CHECK-NEXT: %[[ALLOC0_SZ:.*]] = mul i64 8, %[[CAST_N0]] + // CHECK-NEXT: %[[ALLOC0:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 %[[ALLOC0_SZ]]) + %0 = omp.alloc_shared_mem %n0 x i64 : (i32) -> !llvm.ptr + + // CHECK: %[[ALLOC1_SZ:.*]] = mul i64 8, %[[N1]] + // CHECK-NEXT: %[[ALLOC1:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 %[[ALLOC1_SZ]]) + %1 = omp.alloc_shared_mem %n1 x i64 : (i64) -> !llvm.ptr + + // CHECK: %[[ALLOC2_SZ:.*]] = mul i64 64, %[[N1]] + // CHECK-NEXT: %[[ALLOC2:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 %[[ALLOC2_SZ]]) + %2 = omp.alloc_shared_mem %n1 x vector<16xf32> : (i64) -> !llvm.ptr + + // CHECK: %[[ALLOC3_SZ:.*]] = mul i64 128, %[[N1]] + // CHECK-NEXT: %[[ALLOC3:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 %[[ALLOC3_SZ]]) + %3 = omp.alloc_shared_mem %n1 x vector<16xf32> {alignment = 128} : (i64) -> !llvm.ptr + + // CHECK: %[[CAST_N0_1:.*]] = zext i32 %[[N0]] to i64 + // CHECK-NEXT: %[[FREE0_SZ:.*]] = mul i64 8, %[[CAST_N0_1]] + // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC0]], i64 %[[FREE0_SZ]]) + omp.free_shared_mem %0 : !llvm.ptr + + // CHECK: %[[FREE1_SZ:.*]] = mul i64 8, %[[N1]] + // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC1]], i64 %[[FREE1_SZ]]) + omp.free_shared_mem %1 : !llvm.ptr + + // CHECK: %[[FREE2_SZ:.*]] = mul i64 64, %[[N1]] + // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC2]], i64 %[[FREE2_SZ]]) + omp.free_shared_mem %2 : !llvm.ptr + + // CHECK: %[[FREE3_SZ:.*]] = mul i64 128, %[[N1]] + // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC3]], i64 %[[FREE3_SZ]]) + omp.free_shared_mem %3 : !llvm.ptr + llvm.return + } +} >From 
a529c296d69d60f43dd6d9f0b6f574eb28632b6a Mon Sep 17 00:00:00 2001 From: Sergio Afonso <[email protected]> Date: Tue, 17 Feb 2026 10:49:35 +0000 Subject: [PATCH 3/3] address review comments: make omp.free_shared_mem self-contained, update alignment handling for shared memory allocations --- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 37 ++++++++++++++ mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 50 ++++++++---------- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 7 +-- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 51 +++++++++---------- mlir/test/Dialect/OpenMP/invalid.mlir | 23 ++++++--- mlir/test/Dialect/OpenMP/ops.mlir | 12 +++-- .../LLVMIR/omptarget-device-shared-mem.mlir | 10 ++-- 7 files changed, 112 insertions(+), 78 deletions(-) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 85fc392902a93..68b99828df1d4 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -910,6 +910,43 @@ class OpenMP_MapClauseSkip< def OpenMP_MapClause : OpenMP_MapClauseSkip<>; +//===----------------------------------------------------------------------===// +// Not in the spec: Clause-like structure to hold memory allocation information.
+//===----------------------------------------------------------------------===// + +class OpenMP_MemAllocationSizeClauseSkip< + bit traits = false, bit arguments = false, bit assemblyFormat = false, + bit description = false, bit extraClassDeclaration = false + > : OpenMP_Clause<traits, arguments, assemblyFormat, description, + extraClassDeclaration> { + + let arguments = (ins + TypeAttr:$mem_elem_type, + AnySignlessInteger:$mem_array_size, + ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$mem_alignment + ); + + let reqAssemblyFormat = [{ + $mem_array_size `x` $mem_elem_type `:` `(` type($mem_array_size) `)` + }]; + + let optAssemblyFormat = [{ + `align` `(` $mem_alignment `)` + }]; + + let description = [{ + The `mem_elem_type` is the type of the object the memory allocation refers + to. It is used to calculate the size of the allocation. + + The `mem_array_size` is the number of objects. + + The optional `mem_alignment` rounds the size of each element up to a + multiple of it. If not set, the ABI alignment from the `DataLayout` is used.
+ }]; +} + +def OpenMP_MemAllocationSizeClause : OpenMP_MemAllocationSizeClauseSkip<>; + //===----------------------------------------------------------------------===// // V5.2: [15.8.1] `memory-order` clause set //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 86411d71a5787..bbc07346d6915 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -2241,15 +2241,11 @@ def TargetFreeMemOp : OpenMP_Op<"target_freemem", def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [ MemoryEffects<[MemAlloc<DefaultResource>]> + ], clauses = [ + OpenMP_MemAllocationSizeClause ]> { let summary = "allocate storage on shared memory for objects of a given type"; - let arguments = (ins - TypeAttr:$elem_type, - AnySignlessInteger:$array_size, - ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$alignment - ); - let description = [{ Allocates memory shared across threads of a team for an object of the given type. Returns a pointer representing the allocated memory. The memory is @@ -2258,27 +2254,18 @@ def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [ ```mlir // Allocate an i32 vector with %size elements and aligned to 8 bytes. - %ptr_shared = omp.alloc_shared_mem %size x i32 {alignment = 8} : (i64) -> (!llvm.ptr) + %ptr_shared = omp.alloc_shared_mem %size x i32 : (i64) align(8) -> !llvm.ptr // ... - omp.free_shared_mem %ptr_shared : !llvm.ptr + omp.free_shared_mem [%size x i32 : (i64) align(8)] %ptr_shared : !llvm.ptr ``` - - The `elem_type` is the type of the object for which memory is being - allocated. - - The `array_size` is the number of objects to allocate memory for. - - The optional `alignment` is used to specify the alignment for each element. - If not set, the `DataLayout` defaults will be used instead. 
- }]; + }] # clausesDescription; let results = (outs OpenMP_PointerLikeType); - let assemblyFormat = [{ - $array_size `x` $elem_type attr-dict `:` `(` type($array_size) `)` `->` type(results) - }]; + let assemblyFormat = clausesReqAssemblyFormat # " oilist(" # + clausesOptAssemblyFormat # ") `->` type(results) attr-dict"; let extraClassDeclaration = [{ - mlir::Type getAllocatedType() { return getElemTypeAttr().getValue(); } + mlir::Type getAllocatedType() { return getMemElemTypeAttr().getValue(); } }]; let hasVerifier = 1; } @@ -2287,31 +2274,34 @@ def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [ // FreeSharedMemOp //===----------------------------------------------------------------------===// -def FreeSharedMemOp : OpenMP_Op<"free_shared_mem", [MemoryEffects<[MemFree]>]> { +def FreeSharedMemOp : OpenMP_Op<"free_shared_mem", traits = [ + MemoryEffects<[MemFree]> + ], clauses = [ + OpenMP_MemAllocationSizeClause + ]> { let summary = "free shared memory"; let description = [{ Deallocates shared memory that was previously allocated by an `omp.alloc_shared_mem` operation. After this operation, the deallocated memory is in an undefined state and should not be accessed. - It is crucial to ensure that all accesses to the memory region are completed - before `omp.alloc_shared_mem` is called to avoid undefined behavior. ```mlir // Example of allocating and freeing shared memory. - %ptr_shared = omp.alloc_shared_mem %size x i32 : (i64) -> (!llvm.ptr) + %ptr_shared = omp.alloc_shared_mem %size x i32 : (i64) -> !llvm.ptr // ... - omp.free_shared_mem %ptr_shared : !llvm.ptr + omp.free_shared_mem [%size x i32 : (i64)] %ptr_shared : !llvm.ptr ``` The `heapref` operand represents the pointer to shared memory to be deallocated, previously returned by `omp.alloc_shared_mem`. 
- }]; + }] # clausesDescription; - let arguments = (ins + let arguments = !con(clausesArgs, (ins Arg<OpenMP_PointerLikeType, "", [MemFree]>:$heapref - ); - let assemblyFormat = "$heapref attr-dict `:` type($heapref)"; + )); + let assemblyFormat = "` ` `[`" # clausesReqAssemblyFormat # " oilist(" # + clausesOptAssemblyFormat # ") `]` $heapref `:` type($heapref) attr-dict"; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 446fa1fecc122..55a5e79ba4407 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -4502,7 +4502,7 @@ LogicalResult AllocateDirOp::verify() { //===----------------------------------------------------------------------===// LogicalResult AllocSharedMemOp::verify() { - return verifyAlignment(*getOperation(), getAlignment()); + return verifyAlignment(*getOperation(), getMemAlignment()); } //===----------------------------------------------------------------------===// @@ -4510,10 +4510,7 @@ LogicalResult AllocSharedMemOp::verify() { //===----------------------------------------------------------------------===// LogicalResult FreeSharedMemOp::verify() { - return getHeapref().getDefiningOp<AllocSharedMemOp>() - ? 
success() - : emitOpError() << "'heapref' operand must be defined by an " - "'omp.alloc_shared_memory' op"; + return verifyAlignment(*getOperation(), getMemAlignment()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 74decff68b321..b73e8b2e62f49 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -7533,10 +7533,32 @@ static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder, return func; } +template <typename T> static llvm::Value * getAllocationSize(llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation, - omp::TargetAllocMemOp op) { + LLVM::ModuleTranslation &moduleTranslation, T op) { + llvm::DataLayout dataLayout = + moduleTranslation.getLLVMModule()->getDataLayout(); + llvm::Type *llvmHeapTy = + moduleTranslation.convertType(op.getMemElemTypeAttr().getValue()); + + auto alignment = op.getMemAlignment(); + llvm::TypeSize typeSize = llvm::alignTo( + dataLayout.getTypeStoreSize(llvmHeapTy), + alignment ? 
*alignment : dataLayout.getABITypeAlign(llvmHeapTy).value()); + + llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue()); + return builder.CreateMul( + allocSize, + builder.CreateIntCast(moduleTranslation.lookupValue(op.getMemArraySize()), + builder.getInt64Ty(), + /*isSigned=*/false)); +} + +template <> +llvm::Value *getAllocationSize(llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + omp::TargetAllocMemOp op) { llvm::DataLayout dataLayout = moduleTranslation.getLLVMModule()->getDataLayout(); llvm::Type *llvmHeapTy = moduleTranslation.convertType(op.getAllocatedType()); @@ -7552,27 +7574,6 @@ getAllocationSize(llvm::IRBuilderBase &builder, return allocSize; } -static llvm::Value * -getAllocationSize(llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation, - omp::AllocSharedMemOp op) { - llvm::DataLayout dataLayout = - moduleTranslation.getLLVMModule()->getDataLayout(); - llvm::Type *llvmHeapTy = moduleTranslation.convertType(op.getAllocatedType()); - - auto alignment = op.getAlignment(); - llvm::TypeSize typeSize = llvm::alignTo( - dataLayout.getTypeStoreSize(llvmHeapTy), - alignment ? 
*alignment : dataLayout.getABITypeAlign(llvmHeapTy).value()); - - llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue()); - return builder.CreateMul( - allocSize, - builder.CreateIntCast(moduleTranslation.lookupValue(op.getArraySize()), - builder.getInt64Ty(), - /*isSigned=*/false)); -} - static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { @@ -7650,9 +7651,7 @@ convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); - auto allocMemOp = - freeMemOp.getHeapref().getDefiningOp<omp::AllocSharedMemOp>(); - llvm::Value *size = getAllocationSize(builder, moduleTranslation, allocMemOp); + llvm::Value *size = getAllocationSize(builder, moduleTranslation, freeMemOp); ompBuilder->createOMPFreeShared( builder, moduleTranslation.lookupValue(freeMemOp.getHeapref()), size); return success(); diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index c3f408761841c..e37fc8902b548 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -3266,28 +3266,35 @@ func.func @target_allocmem_invalid_bindc_name(%device : i32) -> () { // ----- func.func @alloc_shared_mem_invalid_alignment1(%n: i32) -> () { - // expected-error @below {{op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}} - %0 = omp.alloc_shared_mem %n x i64 {alignment=-2} : (i32) -> !llvm.ptr + // expected-error @below {{op attribute 'mem_alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}} + %0 = omp.alloc_shared_mem %n x i64 : (i32) align(-2) -> !llvm.ptr return } // ----- func.func @alloc_shared_mem_invalid_alignment2(%n: i32) -> () { // expected-error @below {{ALIGN value : 3 must be power of 
2}} - %0 = omp.alloc_shared_mem %n x i64 {alignment=3} : (i32) -> !llvm.ptr + %0 = omp.alloc_shared_mem %n x i64 : (i32) align(3) -> !llvm.ptr return } // ----- -func.func @alloc_shared_mem_invalid_array_size(%n: f32) -> () { +func.func @free_shared_mem_invalid_array_size(%n: f32, %ptr : !llvm.ptr) -> () { // expected-error @below {{invalid kind of type specified: expected builtin.integer, but found 'f32'}} - %0 = omp.alloc_shared_mem %n x i64 : (f32) -> !llvm.ptr + %0 = omp.free_shared_mem [%n x i64 : (f32)] %ptr : !llvm.ptr return } // ----- -func.func @free_shared_mem_invalid_ptr(%ptr : !llvm.ptr) -> () { - // expected-error @below {{op 'heapref' operand must be defined by an 'omp.alloc_shared_memory' op}} - omp.free_shared_mem %ptr : !llvm.ptr +func.func @free_shared_mem_invalid_alignment1(%n: i32, %ptr : !llvm.ptr) -> () { + // expected-error @below {{op attribute 'mem_alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}} + omp.free_shared_mem [%n x i64 : (i32) align(-2)] %ptr : !llvm.ptr + return +} + +// ----- +func.func @free_shared_mem_invalid_alignment2(%n: i32, %ptr : !llvm.ptr) -> () { + // expected-error @below {{ALIGN value : 3 must be power of 2}} + omp.free_shared_mem [%n x i64 : (i32) align(3)] %ptr : !llvm.ptr return } diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index c5bf2e2bd5818..13900a4504d8c 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -3706,8 +3706,8 @@ func.func @omp_alloc_shared_mem(%n: i32) { %0 = omp.alloc_shared_mem %n x i64 : (i32) -> !llvm.ptr // CHECK: %{{.*}} = omp.alloc_shared_mem %[[N]] x vector<16x16xf32> : (i32) -> !llvm.ptr %1 = omp.alloc_shared_mem %n x vector<16x16xf32> : (i32) -> !llvm.ptr - // CHECK: %{{.*}} = omp.alloc_shared_mem %[[N]] x !llvm.ptr {alignment = 16 : i64} : (i32) -> !llvm.ptr - %2 = omp.alloc_shared_mem %n x !llvm.ptr {alignment = 16} : (i32) -> !llvm.ptr + // CHECK: %{{.*}} = 
omp.alloc_shared_mem %[[N]] x !llvm.ptr : (i32) align(16) -> !llvm.ptr + %2 = omp.alloc_shared_mem %n x !llvm.ptr : (i32) align(16) -> !llvm.ptr return } @@ -3716,7 +3716,11 @@ func.func @omp_alloc_shared_mem(%n: i32) { func.func @omp_free_shared_mem(%n: i64) { // CHECK: %[[PTR:.*]] = omp.alloc_shared_mem %[[N]] x f32 : (i64) -> !llvm.ptr %0 = omp.alloc_shared_mem %n x f32 : (i64) -> !llvm.ptr - // CHECK: omp.free_shared_mem %[[PTR]] : !llvm.ptr - omp.free_shared_mem %0 : !llvm.ptr + // CHECK: omp.free_shared_mem [%[[N]] x f32 : (i64)] %[[PTR]] : !llvm.ptr + omp.free_shared_mem [%n x f32 : (i64)] %0 : !llvm.ptr + // CHECK: %[[PTR:.*]] = omp.alloc_shared_mem %[[N]] x f32 : (i64) align(32) -> !llvm.ptr + %1 = omp.alloc_shared_mem %n x f32 : (i64) align(32) -> !llvm.ptr + // CHECK: omp.free_shared_mem [%[[N]] x f32 : (i64) align(32)] %[[PTR]] : !llvm.ptr + omp.free_shared_mem [%n x f32 : (i64) align(32)] %1 : !llvm.ptr return } diff --git a/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir b/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir index 72b0a2daadfc3..cdebebc3ed233 100644 --- a/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir @@ -19,24 +19,24 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // CHECK: %[[ALLOC3_SZ:.*]] = mul i64 128, %[[N1]] // CHECK-NEXT: %[[ALLOC3:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 %[[ALLOC3_SZ]]) - %3 = omp.alloc_shared_mem %n1 x vector<16xf32> {alignment = 128} : (i64) -> !llvm.ptr + %3 = omp.alloc_shared_mem %n1 x vector<16xf32> : (i64) align(128) -> !llvm.ptr // CHECK: %[[CAST_N0_1:.*]] = zext i32 %[[N0]] to i64 // CHECK-NEXT: %[[FREE0_SZ:.*]] = mul i64 8, %[[CAST_N0_1]] // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC0]], i64 %[[FREE0_SZ]]) - omp.free_shared_mem %0 : !llvm.ptr + omp.free_shared_mem [%n0 x i64 : (i32)] %0 : !llvm.ptr // CHECK: %[[FREE1_SZ:.*]] = mul i64 8, %[[N1]] // CHECK-NEXT: 
call void @__kmpc_free_shared(ptr %[[ALLOC1]], i64 %[[FREE1_SZ]]) - omp.free_shared_mem %1 : !llvm.ptr + omp.free_shared_mem [%n1 x i64 : (i64)] %1 : !llvm.ptr // CHECK: %[[FREE2_SZ:.*]] = mul i64 64, %[[N1]] // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC2]], i64 %[[FREE2_SZ]]) - omp.free_shared_mem %2 : !llvm.ptr + omp.free_shared_mem [%n1 x vector<16xf32> : (i64)] %2 : !llvm.ptr // CHECK: %[[FREE3_SZ:.*]] = mul i64 128, %[[N1]] // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC3]], i64 %[[FREE3_SZ]]) - omp.free_shared_mem %3 : !llvm.ptr + omp.free_shared_mem [%n1 x vector<16xf32> : (i64) align(128)] %3 : !llvm.ptr llvm.return } } _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
