https://github.com/skatrak created 
https://github.com/llvm/llvm-project/pull/161862

This patch introduces the `omp.alloc_shared_mem` and `omp.free_shared_mem` 
operations to represent explicit allocations and deallocations of shared memory 
across threads in a team, mirroring the existing `omp.target_allocmem` and 
`omp.target_freemem`.

The `omp.alloc_shared_mem` op goes through the same Flang-specific 
transformations as `omp.target_allocmem`, so that the size of the buffer can be 
properly calculated when translating to LLVM IR.

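The corresponding runtime functions produced for these new operations are 
`__kmpc_alloc_shared` and `__kmpc_free_shared`, which previously could only be 
created for implicit allocations (e.g. privatized and reduction variables).

The patch also adds a verifier to the existing `omp.target_freemem` operation, 
which now requires its `heapref` operand to be defined by an 
`omp.target_allocmem` operation, matching the restriction placed on 
`omp.free_shared_mem`.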

>From d9139666d736567cafaa88ded7ecc2c6c169e182 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <[email protected]>
Date: Fri, 12 Sep 2025 15:56:04 +0100
Subject: [PATCH] [Flang][MLIR][OpenMP] Add explicit shared memory
 (de-)allocation ops

This patch introduces the `omp.alloc_shared_mem` and `omp.free_shared_mem`
operations to represent explicit allocations and deallocations of shared memory
across threads in a team, mirroring the existing `omp.target_allocmem` and
`omp.target_freemem`.

The `omp.alloc_shared_mem` op goes through the same Flang-specific
transformations as `omp.target_allocmem`, so that the size of the buffer can be
properly calculated when translating to LLVM IR.

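The corresponding runtime functions produced for these new operations are
`__kmpc_alloc_shared` and `__kmpc_free_shared`, which previously could only be
created for implicit allocations (e.g. privatized and reduction variables).

The patch also adds a verifier to the existing `omp.target_freemem` operation,
which now requires its `heapref` operand to be defined by an
`omp.target_allocmem` operation, matching the restriction placed on
`omp.free_shared_mem`.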
---
 flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp | 43 ++++++++----
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       | 23 +++++++
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 29 +++++---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 62 +++++++++++++++++
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 22 ++++++
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 68 +++++++++++++++----
 mlir/test/Dialect/OpenMP/invalid.mlir         | 28 ++++++++
 mlir/test/Dialect/OpenMP/ops.mlir             | 31 ++++++++-
 8 files changed, 268 insertions(+), 38 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp 
b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
index 381b2a29c517a..c1a6b06d6a52b 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
@@ -222,35 +222,47 @@ static mlir::Type convertObjectType(const 
fir::LLVMTypeConverter &converter,
   return converter.convertType(firType);
 }
 
-// FIR Op specific conversion for TargetAllocMemOp
-struct TargetAllocMemOpConversion
-    : public OpenMPFIROpConversion<mlir::omp::TargetAllocMemOp> {
-  using OpenMPFIROpConversion::OpenMPFIROpConversion;
+// FIR Op specific conversion for allocation operations
+template <typename T>
+struct AllocMemOpConversion : public OpenMPFIROpConversion<T> {
+  using OpenMPFIROpConversion<T>::OpenMPFIROpConversion;
 
   llvm::LogicalResult
-  matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor,
+  matchAndRewrite(T allocmemOp,
+                  typename OpenMPFIROpConversion<T>::OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
     mlir::Type heapTy = allocmemOp.getAllocatedType();
     mlir::Location loc = allocmemOp.getLoc();
-    auto ity = lowerTy().indexType();
+    auto ity = OpenMPFIROpConversion<T>::lowerTy().indexType();
     mlir::Type dataTy = fir::unwrapRefType(heapTy);
-    mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy);
+    mlir::Type llvmObjectTy =
+        convertObjectType(OpenMPFIROpConversion<T>::lowerTy(), dataTy);
     if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
-      TODO(loc, "omp.target_allocmem codegen of derived type with length "
-                "parameters");
+      TODO(loc, allocmemOp->getName().getStringRef() +
+                    " codegen of derived type with length parameters");
     mlir::Value size = fir::computeElementDistance(
-        loc, llvmObjectTy, ity, rewriter, lowerTy().getDataLayout());
+        loc, llvmObjectTy, ity, rewriter,
+        OpenMPFIROpConversion<T>::lowerTy().getDataLayout());
     if (auto scaleSize = fir::genAllocationScaleSize(
             loc, allocmemOp.getInType(), ity, rewriter))
       size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
-    for (mlir::Value opnd : adaptor.getOperands().drop_front())
+    for (mlir::Value opnd : adaptor.getTypeparams())
+      size = rewriter.create<mlir::LLVM::MulOp>(
+          loc, ity, size,
+          integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter, ity,
+                      opnd));
+    for (mlir::Value opnd : adaptor.getShape())
       size = rewriter.create<mlir::LLVM::MulOp>(
-          loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd));
-    auto mallocTyWidth = lowerTy().getIndexTypeBitwidth();
+          loc, ity, size,
+          integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter, ity,
+                      opnd));
+    auto mallocTyWidth =
+        OpenMPFIROpConversion<T>::lowerTy().getIndexTypeBitwidth();
     auto mallocTy =
         mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
     if (mallocTyWidth != ity.getIntOrFloatBitWidth())
-      size = integerCast(lowerTy(), loc, rewriter, mallocTy, size);
+      size = integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter,
+                         mallocTy, size);
     rewriter.modifyOpInPlace(allocmemOp, [&]() {
       allocmemOp.setInType(rewriter.getI8Type());
       allocmemOp.getTypeparamsMutable().clear();
@@ -265,5 +277,6 @@ void fir::populateOpenMPFIRToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
   patterns.add<MapInfoOpConversion>(converter);
   patterns.add<PrivateClauseOpConversion>(converter);
-  patterns.add<TargetAllocMemOpConversion>(converter);
+  patterns.add<AllocMemOpConversion<mlir::omp::TargetAllocMemOp>,
+               AllocMemOpConversion<mlir::omp::AllocSharedMemOp>>(converter);
 }
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h 
b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 02d61c1a3626a..d8e5f8cf5a45e 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2950,6 +2950,17 @@ class OpenMPIRBuilder {
   LLVM_ABI CallInst *createOMPFree(const LocationDescription &Loc, Value *Addr,
                                    Value *Allocator, std::string Name = "");
 
+  /// Create a runtime call for kmpc_alloc_shared.
+  ///
+  /// \param Loc The insert and source location description.
+  /// \param Size Size of allocated memory space.
+  /// \param Name Name of call Instruction.
+  ///
+  /// \returns CallInst to the kmpc_alloc_shared call.
+  LLVM_ABI CallInst *createOMPAllocShared(const LocationDescription &Loc,
+                                          Value *Size,
+                                          const Twine &Name = Twine(""));
+
   /// Create a runtime call for kmpc_alloc_shared.
   ///
   /// \param Loc The insert and source location description.
@@ -2961,6 +2972,18 @@ class OpenMPIRBuilder {
                                           Type *VarType,
                                           const Twine &Name = Twine(""));
 
+  /// Create a runtime call for kmpc_free_shared.
+  ///
+  /// \param Loc The insert and source location description.
+  /// \param Addr Value obtained from the corresponding kmpc_alloc_shared call.
+  /// \param Size Size of allocated memory space.
+  /// \param Name Name of call Instruction.
+  ///
+  /// \returns CallInst to the kmpc_free_shared call.
+  LLVM_ABI CallInst *createOMPFreeShared(const LocationDescription &Loc,
+                                         Value *Addr, Value *Size,
+                                         const Twine &Name = Twine(""));
+
   /// Create a runtime call for kmpc_free_shared.
   ///
   /// \param Loc The insert and source location description.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp 
b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index bd483aa2c5e02..a18db939b5876 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6855,32 +6855,45 @@ CallInst *OpenMPIRBuilder::createOMPFree(const 
LocationDescription &Loc,
 }
 
 CallInst *OpenMPIRBuilder::createOMPAllocShared(const LocationDescription &Loc,
-                                                Type *VarType,
+                                                Value *Size,
                                                 const Twine &Name) {
   IRBuilder<>::InsertPointGuard IPG(Builder);
   updateToLocation(Loc);
 
-  const DataLayout &DL = M.getDataLayout();
-  Value *Args[] = {Builder.getInt64(DL.getTypeStoreSize(VarType))};
+  Value *Args[] = {Size};
   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_alloc_shared);
   CallInst *Call = Builder.CreateCall(Fn, Args, Name);
-  Call->addRetAttr(
-      Attribute::getWithAlignment(M.getContext(), DL.getPrefTypeAlign(Int64)));
+  Call->addRetAttr(Attribute::getWithAlignment(
+      M.getContext(), M.getDataLayout().getPrefTypeAlign(Int64)));
   return Call;
 }
 
+CallInst *OpenMPIRBuilder::createOMPAllocShared(const LocationDescription &Loc,
+                                                Type *VarType,
+                                                const Twine &Name) {
+  return createOMPAllocShared(
+      Loc, Builder.getInt64(M.getDataLayout().getTypeStoreSize(VarType)), 
Name);
+}
+
 CallInst *OpenMPIRBuilder::createOMPFreeShared(const LocationDescription &Loc,
-                                               Value *Addr, Type *VarType,
+                                               Value *Addr, Value *Size,
                                                const Twine &Name) {
   IRBuilder<>::InsertPointGuard IPG(Builder);
   updateToLocation(Loc);
 
-  Value *Args[] = {
-      Addr, Builder.getInt64(M.getDataLayout().getTypeStoreSize(VarType))};
+  Value *Args[] = {Addr, Size};
   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_free_shared);
   return Builder.CreateCall(Fn, Args, Name);
 }
 
+CallInst *OpenMPIRBuilder::createOMPFreeShared(const LocationDescription &Loc,
+                                               Value *Addr, Type *VarType,
+                                               const Twine &Name) {
+  return createOMPFreeShared(
+      Loc, Addr, Builder.getInt64(M.getDataLayout().getTypeStoreSize(VarType)),
+      Name);
+}
+
 CallInst *OpenMPIRBuilder::createOMPInteropInit(
     const LocationDescription &Loc, Value *InteropVar,
     omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td 
b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 8b206f58c7733..fa037c2ff9496 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -2202,6 +2202,68 @@ def TargetFreeMemOp : OpenMP_Op<"target_freemem",
   Arg<I64, "", [MemFree]>:$heapref
   );
   let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` 
qualified(type($heapref))";
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// AllocSharedMemOp
+//===----------------------------------------------------------------------===//
+
+def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [
+    AttrSizedOperandSegments
+  ], clauses = [
+    OpenMP_HeapAllocClause
+  ]> {
+  let summary = "allocate storage on shared memory for an object of a given 
type";
+
+  let description = [{
+    Allocates memory shared across threads of a team for an object of the given
+    type. Returns a pointer representing the allocated memory. The memory is
+    uninitialized after allocation. Each allocation must be paired with an
+    `omp.free_shared_mem` operation to avoid memory leaks.
+
+    ```mlir
+      // Allocate a static 3x3 integer vector.
+      %ptr_shared = omp.alloc_shared_mem vector<3x3xi32> : !llvm.ptr
+      // ...
+      omp.free_shared_mem %ptr_shared : !llvm.ptr
+    ```
+  }] # clausesDescription;
+
+  let results = (outs OpenMP_PointerLikeType);
+  let assemblyFormat = clausesAssemblyFormat # " attr-dict `:` type(results)";
+}
+
+//===----------------------------------------------------------------------===//
+// FreeSharedMemOp
+//===----------------------------------------------------------------------===//
+
+def FreeSharedMemOp : OpenMP_Op<"free_shared_mem", [MemoryEffects<[MemFree]>]> 
{
+  let summary = "free shared memory";
+
+  let description = [{
+    Deallocates shared memory that was previously allocated by an
+    `omp.alloc_shared_mem` operation. After this operation, the deallocated
+    memory is in an undefined state and should not be accessed.
+    It is crucial to ensure that all accesses to the memory region are completed
+    before `omp.free_shared_mem` is called to avoid undefined behavior.
+
+    ```mlir
+      // Example of allocating and freeing shared memory.
+      %ptr_shared = omp.alloc_shared_mem vector<3x3xi32> : !llvm.ptr
+      // ...
+      omp.free_shared_mem %ptr_shared : !llvm.ptr
+    ```
+
+    The `heapref` operand represents the pointer to shared memory to be
+    deallocated, previously returned by `omp.alloc_shared_mem`.
+  }];
+
+  let arguments = (ins
+    Arg<OpenMP_PointerLikeType, "", [MemFree]>:$heapref
+  );
+  let assemblyFormat = "$heapref attr-dict `:` type($heapref)";
+  let hasVerifier = 1;
 }
 
 
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp 
b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index fabb1b8c173a2..3b48dce4b7989 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4161,6 +4161,28 @@ LogicalResult AllocateDirOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TargetFreeMemOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult TargetFreeMemOp::verify() {
+  return getHeapref().getDefiningOp<TargetAllocMemOp>()
+             ? success()
+             : emitOpError() << "'heapref' operand must be defined by an "
+                                "'omp.target_allocmem' op";
+}
+
+//===----------------------------------------------------------------------===//
+// FreeSharedMemOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult FreeSharedMemOp::verify() {
+  return getHeapref().getDefiningOp<AllocSharedMemOp>()
+             ? success()
+             : emitOpError() << "'heapref' operand must be defined by an "
+                                "'omp.alloc_shared_memory' op";
+}
+
 
//===----------------------------------------------------------------------===//
 // WorkdistributeOp
 
//===----------------------------------------------------------------------===//
diff --git 
a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp 
b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 80e052105dc4c..3accca891ba9c 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -6104,11 +6104,9 @@ static bool isTargetDeviceOp(Operation *op) {
   // by taking it in as an operand, so we must always lower these in
   // some manner or result in an ICE (whether they end up in a no-op
   // or otherwise).
-  if (mlir::isa<omp::ThreadprivateOp>(op))
-    return true;
-
-  if (mlir::isa<omp::TargetAllocMemOp>(op) ||
-      mlir::isa<omp::TargetFreeMemOp>(op))
+  if (mlir::isa<omp::ThreadprivateOp, omp::TargetAllocMemOp,
+                omp::TargetFreeMemOp, omp::AllocSharedMemOp,
+                omp::FreeSharedMemOp>(op))
     return true;
 
   if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
@@ -6135,6 +6133,21 @@ static llvm::Function 
*getOmpTargetAlloc(llvm::IRBuilderBase &builder,
   return func;
 }
 
+static llvm::Value *
+getAllocationSize(llvm::IRBuilderBase &builder,
+                  LLVM::ModuleTranslation &moduleTranslation, Type allocatedTy,
+                  OperandRange typeparams, OperandRange shape) {
+  llvm::DataLayout dataLayout =
+      moduleTranslation.getLLVMModule()->getDataLayout();
+  llvm::Type *llvmHeapTy = moduleTranslation.convertType(allocatedTy);
+  llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
+  llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
+  for (auto typeParam : typeparams)
+    allocSize =
+        builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
+  return allocSize;
+}
+
 static LogicalResult
 convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
                         LLVM::ModuleTranslation &moduleTranslation) {
@@ -6149,14 +6162,9 @@ convertTargetAllocMemOp(Operation &opInst, 
llvm::IRBuilderBase &builder,
   mlir::Value deviceNum = allocMemOp.getDevice();
   llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
   // Get the allocation size.
-  llvm::DataLayout dataLayout = llvmModule->getDataLayout();
-  mlir::Type heapTy = allocMemOp.getAllocatedType();
-  llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
-  llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
-  llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
-  for (auto typeParam : allocMemOp.getTypeparams())
-    allocSize =
-        builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
+  llvm::Value *allocSize = getAllocationSize(
+      builder, moduleTranslation, allocMemOp.getAllocatedType(),
+      allocMemOp.getTypeparams(), allocMemOp.getShape());
   // Create call to "omp_target_alloc" with the args as translated llvm values.
   llvm::CallInst *call =
       builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
@@ -6167,6 +6175,19 @@ convertTargetAllocMemOp(Operation &opInst, 
llvm::IRBuilderBase &builder,
   return success();
 }
 
+static LogicalResult
+convertAllocSharedMemOp(omp::AllocSharedMemOp allocMemOp,
+                        llvm::IRBuilderBase &builder,
+                        LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  llvm::Value *size = getAllocationSize(
+      builder, moduleTranslation, allocMemOp.getAllocatedType(),
+      allocMemOp.getTypeparams(), allocMemOp.getShape());
+  moduleTranslation.mapValue(allocMemOp.getResult(),
+                             ompBuilder->createOMPAllocShared(builder, size));
+  return success();
+}
+
 static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
                                         llvm::Module *llvmModule) {
   llvm::Type *ptrTy = builder.getPtrTy(0);
@@ -6202,6 +6223,21 @@ convertTargetFreeMemOp(Operation &opInst, 
llvm::IRBuilderBase &builder,
   return success();
 }
 
+static LogicalResult
+convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp,
+                       llvm::IRBuilderBase &builder,
+                       LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  auto allocMemOp =
+      freeMemOp.getHeapref().getDefiningOp<omp::AllocSharedMemOp>();
+  llvm::Value *size = getAllocationSize(
+      builder, moduleTranslation, allocMemOp.getAllocatedType(),
+      allocMemOp.getTypeparams(), allocMemOp.getShape());
+  ompBuilder->createOMPFreeShared(
+      builder, moduleTranslation.lookupValue(freeMemOp.getHeapref()), size);
+  return success();
+}
+
 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
 /// OpenMP runtime calls).
 static LogicalResult
@@ -6382,6 +6418,12 @@ convertHostOrTargetOperation(Operation *op, 
llvm::IRBuilderBase &builder,
           .Case([&](omp::TargetFreeMemOp) {
             return convertTargetFreeMemOp(*op, builder, moduleTranslation);
           })
+          .Case([&](omp::AllocSharedMemOp op) {
+            return convertAllocSharedMemOp(op, builder, moduleTranslation);
+          })
+          .Case([&](omp::FreeSharedMemOp op) {
+            return convertFreeSharedMemOp(op, builder, moduleTranslation);
+          })
           .Default([&](Operation *inst) {
             return inst->emitError()
                    << "not yet implemented: " << inst->getName();
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir 
b/mlir/test/Dialect/OpenMP/invalid.mlir
index 0cc4b522db466..9f28172161fa8 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -3153,3 +3153,31 @@ func.func @target_allocmem_invalid_bindc_name(%device : 
i32) -> () {
   %0 = omp.target_allocmem %device : i32, i64 {bindc_name=2}
   return
 }
+
+// -----
+func.func @target_freemem_invalid_ptr(%device : i32, %ptr : i64) -> () {
+  // expected-error @below {{op 'heapref' operand must be defined by an 
'omp.target_allocmem' op}}
+  omp.target_freemem %device, %ptr : i32, i64
+  return
+}
+
+// -----
+func.func @alloc_shared_mem_invalid_uniq_name() -> () {
+  // expected-error @below {{op attribute 'uniq_name' failed to satisfy 
constraint: string attribute}}
+  %0 = omp.alloc_shared_mem i64 {uniq_name=2}
+  return
+}
+
+// -----
+func.func @alloc_shared_mem_invalid_bindc_name() -> () {
+  // expected-error @below {{op attribute 'bindc_name' failed to satisfy 
constraint: string attribute}}
+  %0 = omp.alloc_shared_mem i64 {bindc_name=2}
+  return
+}
+
+// -----
+func.func @free_shared_mem_invalid_ptr(%ptr : !llvm.ptr) -> () {
+  // expected-error @below {{op 'heapref' operand must be defined by an 
'omp.alloc_shared_memory' op}}
+  omp.free_shared_mem %ptr : !llvm.ptr
+  return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir 
b/mlir/test/Dialect/OpenMP/ops.mlir
index 9e7287178ff66..55e6d77857972 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -3339,9 +3339,36 @@ func.func @omp_target_allocmem(%device: i32, %x: index, 
%y: index, %z: i32) {
 }
 
 // CHECK-LABEL: func.func @omp_target_freemem(
-// CHECK-SAME: %[[DEVICE:.*]]: i32, %[[PTR:.*]]: i64) {
-func.func @omp_target_freemem(%device : i32, %ptr : i64) {
+// CHECK-SAME: %[[DEVICE:.*]]: i32) {
+func.func @omp_target_freemem(%device : i32) {
+  // CHECK: %[[PTR:.*]] = omp.target_allocmem
+  %ptr = omp.target_allocmem %device : i32, i64
   // CHECK: omp.target_freemem %[[DEVICE]], %[[PTR]] : i32, i64
   omp.target_freemem %device, %ptr : i32, i64
   return
 }
+
+// CHECK-LABEL: func.func @omp_alloc_shared_mem(
+// CHECK-SAME: %[[X:.*]]: index, %[[Y:.*]]: index, %[[Z:.*]]: i32) {
+func.func @omp_alloc_shared_mem(%x: index, %y: index, %z: i32) {
+  // CHECK: %{{.*}} = omp.alloc_shared_mem i64 : !llvm.ptr
+  %0 = omp.alloc_shared_mem i64 : !llvm.ptr
+  // CHECK: %{{.*}} = omp.alloc_shared_mem vector<16x16xf32> {bindc_name = 
"bindc", uniq_name = "uniq"} : !llvm.ptr
+  %1 = omp.alloc_shared_mem vector<16x16xf32> {uniq_name="uniq", 
bindc_name="bindc"} : !llvm.ptr
+  // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : 
index, index, i32) : !llvm.ptr
+  %2 = omp.alloc_shared_mem !llvm.ptr(%x, %y, %z : index, index, i32) : 
!llvm.ptr
+  // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr, %[[X]], %[[Y]] : 
!llvm.ptr
+  %3 = omp.alloc_shared_mem !llvm.ptr, %x, %y : !llvm.ptr
+  // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : 
index, index, i32), %[[X]], %[[Y]] : !llvm.ptr
+  %4 = omp.alloc_shared_mem !llvm.ptr(%x, %y, %z : index, index, i32), %x, %y 
: !llvm.ptr
+  return
+}
+
+// CHECK-LABEL: func.func @omp_free_shared_mem() {
+func.func @omp_free_shared_mem() {
+  // CHECK: %[[PTR:.*]] = omp.alloc_shared_mem
+  %0 = omp.alloc_shared_mem i64 : !llvm.ptr
+  // CHECK: omp.free_shared_mem %[[PTR]] : !llvm.ptr
+  omp.free_shared_mem %0 : !llvm.ptr
+  return
+}

_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to