https://github.com/Saieiei updated 
https://github.com/llvm/llvm-project/pull/199967

>From 31ed7583d7fb069d3f42dcfa1121808446472655 Mon Sep 17 00:00:00 2001
From: Sairudra More <[email protected]>
Date: Wed, 27 May 2026 06:56:02 -0500
Subject: [PATCH] [flang][OpenMP] Lower target in_reduction for host fallback

Teach Flang lowering and MLIR OpenMP translation to carry
in_reduction through omp.target for the host-fallback path.

The translation looks up task reduction-private storage with
__kmpc_task_reduction_get_th_data and binds the target region's
in_reduction block argument to that private pointer, so uses inside the
region do not keep referring to the original variable.

The patch also preserves in_reduction operands in the TargetOp builder
path and ensures target in_reduction list items are mapped into the
target region when needed.

The device/offload-entry path remains diagnosed as not yet implemented.
---
 flang/lib/Lower/OpenMP/OpenMP.cpp             |  80 ++++++++---
 .../Lower/OpenMP/Todo/target-inreduction.f90  |  15 ---
 .../OpenMP/target-inreduction-unused.f90      |  27 ++++
 .../test/Lower/OpenMP/target-inreduction.f90  |  28 ++++
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |  15 ++-
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 124 ++++++++++++++++--
 mlir/test/Dialect/OpenMP/invalid.mlir         |  60 +++++++++
 .../LLVMIR/openmp-target-in-reduction.mlir    |  50 +++++++
 mlir/test/Target/LLVMIR/openmp-todo.mlir      |  86 +++++++++++-
 9 files changed, 432 insertions(+), 53 deletions(-)
 delete mode 100644 flang/test/Lower/OpenMP/Todo/target-inreduction.f90
 create mode 100644 flang/test/Lower/OpenMP/target-inreduction-unused.f90
 create mode 100644 flang/test/Lower/OpenMP/target-inreduction.f90
 create mode 100644 mlir/test/Target/LLVMIR/openmp-target-in-reduction.mlir

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 7cb7e379eb503..7a96e903ef2ff 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -433,13 +433,16 @@ static void bindEntryBlockArgs(lower::AbstractConverter &converter,
               .first);
   };
 
-  // Process in clause name alphabetical order to match block arguments order.
   // Do not bind host_eval variables because they cannot be used inside of the
   // corresponding region, except for very specific cases handled separately.
+  // Bind map before in_reduction so that for target in_reduction list items
+  // (which are also implicitly mapped), the in_reduction binding wins and
+  // in-body references use the reduction-private block argument, not the
+  // mapped/original address.
   bindMapLike(args.hasDeviceAddr.objects, op.getHasDeviceAddrBlockArgs());
+  bindMapLike(args.map.objects, op.getMapBlockArgs());
   bindPrivateLike(args.inReduction.objects, args.inReduction.vars,
                   op.getInReductionBlockArgs());
-  bindMapLike(args.map.objects, op.getMapBlockArgs());
   bindPrivateLike(args.priv.objects, args.priv.vars, op.getPrivateBlockArgs());
   bindPrivateLike(args.reduction.objects, args.reduction.vars,
                   op.getReductionBlockArgs());
@@ -1873,6 +1876,7 @@ genTargetClauses(lower::AbstractConverter &converter,
                  mlir::omp::TargetOperands &clauseOps,
                  DefaultMapsTy &defaultMaps,
                  llvm::SmallVectorImpl<Object> &hasDeviceAddrObjects,
+                 llvm::SmallVectorImpl<Object> &inReductionObjects,
                  llvm::SmallVectorImpl<Object> &isDevicePtrObjects,
                  llvm::SmallVectorImpl<Object> &mapObjects) {
   ClauseProcessor cp(converter, semaCtx, clauses);
@@ -1887,13 +1891,14 @@ genTargetClauses(lower::AbstractConverter &converter,
     hostEvalInfo->collectValues(clauseOps.hostEvalVars);
   }
   cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
+  cp.processInReduction(loc, clauseOps, inReductionObjects);
   cp.processIsDevicePtr(stmtCtx, clauseOps, isDevicePtrObjects);
   cp.processMap(loc, stmtCtx, clauseOps, llvm::omp::Directive::OMPD_unknown,
                 &mapObjects);
   cp.processNowait(clauseOps);
   cp.processThreadLimit(stmtCtx, clauseOps);
 
-  cp.processTODO<clause::Allocate, clause::InReduction, clause::UsesAllocators>(
+  cp.processTODO<clause::Allocate, clause::UsesAllocators>(
       loc, llvm::omp::Directive::OMPD_target);
 
   // `target private(..)` is only supported in delayed privatization mode.
@@ -2932,10 +2937,10 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
   mlir::omp::TargetOperands clauseOps;
   DefaultMapsTy defaultMaps;
   llvm::SmallVector<Object> mapObjects, hasDeviceAddrObjects,
-      isDevicePtrObjects;
+      inReductionObjects, isDevicePtrObjects;
   genTargetClauses(converter, semaCtx, symTable, stmtCtx, eval, item->clauses,
                    loc, clauseOps, defaultMaps, hasDeviceAddrObjects,
-                   isDevicePtrObjects, mapObjects);
+                   inReductionObjects, isDevicePtrObjects, mapObjects);
 
   if (!isDevicePtrObjects.empty()) {
     // is_device_ptr maps get duplicated so the clause and synthesized
@@ -2989,7 +2994,16 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
   // symbols used inside the region that do not have explicit data-environment
   // attribute clauses (neither data-sharing; e.g. `private`, nor `map`
   // clauses).
-  auto captureImplicitMap = [&](const semantics::Symbol &sym) {
+  //
+  // When `forceAddressPreserving` is set, the symbol is force-mapped as an
+  // address-preserving `capture(ByRef)` with implicit `tofrom` flags,
+  // bypassing the scalar default capture rules. This is used for `target
+  // in_reduction` list items, whose mapped pointer is passed as the `orig`
+  // argument of `__kmpc_task_reduction_get_th_data`; a ByCopy scalar capture
+  // would break the runtime lookup against the enclosing taskgroup's
+  // task_reduction descriptor.
+  auto captureImplicitMap = [&](const semantics::Symbol &sym,
+                                bool forceAddressPreserving = false) {
     // Structure component symbols don't have bindings, and can only be
     // explicitly mapped individually. If a member is captured implicitly
     // we map the entirety of the derived type when we find its symbol.
@@ -2998,12 +3012,13 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
 
     // if the symbol is part of an already mapped common block, do not make a
     // map for it.
-    if (const Fortran::semantics::Symbol *common =
-            Fortran::semantics::FindCommonBlockContaining(sym.GetUltimate()))
-      if (llvm::any_of(mapObjects, [=](const Object &object) {
-            return object.sym() == common;
-          }))
-        return;
+    if (!forceAddressPreserving)
+      if (const Fortran::semantics::Symbol *common =
+              Fortran::semantics::FindCommonBlockContaining(sym.GetUltimate()))
+        if (llvm::any_of(mapObjects, [=](const Object &object) {
+              return object.sym() == common;
+            }))
+          return;
 
     // If we come across a symbol without a symbol address, we
     // return as we cannot process it, this is intended as a
@@ -3018,7 +3033,8 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
     // dynamic indices on the device (e.g., const_array(runtime_index)).
     // Also, character scalar parameters must be mapped if they have dynamic
     // substring access.
-    if (semantics::IsNamedConstant(sym) && sym.Rank() == 0 &&
+    if (!forceAddressPreserving && semantics::IsNamedConstant(sym) &&
+        sym.Rank() == 0 &&
         !symbolsWithDynamicSubstring.contains(&sym.GetUltimate()))
       return;
 
@@ -3047,14 +3063,32 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType()))
         eleType = refType.getElementType();
 
+      // `target in_reduction` list items must keep the original variable
+      // address (ByRef + implicit tofrom) so the runtime lookup receives the
+      // variable address; all other implicit captures follow the scalar
+      // default mapping rules.
       std::pair<mlir::omp::ClauseMapFlags, mlir::omp::VariableCaptureKind>
-          mapFlagAndKind = getImplicitMapTypeAndKind(
-              firOpBuilder, converter, defaultMaps, eleType, loc, sym);
+          mapFlagAndKind =
+              forceAddressPreserving
+                  ? std::pair<
+                        mlir::omp::ClauseMapFlags,
+                        mlir::omp::
+                            VariableCaptureKind>{mlir::omp::ClauseMapFlags::
+                                                         implicit |
+                                                     mlir::omp::ClauseMapFlags::
+                                                         to |
+                                                     mlir::omp::ClauseMapFlags::
+                                                         from,
+                                                 mlir::omp::
+                                                     VariableCaptureKind::ByRef}
+                  : getImplicitMapTypeAndKind(firOpBuilder, converter,
+                                              defaultMaps, eleType, loc, sym);
 
       mlir::FlatSymbolRefAttr mapperId;
       auto defaultmapBehaviour = getDefaultmapIfPresent(defaultMaps, eleType);
-      if (defaultmapBehaviour ==
-          clause::Defaultmap::ImplicitBehavior::Default) {
+      if (!forceAddressPreserving &&
+          defaultmapBehaviour ==
+              clause::Defaultmap::ImplicitBehavior::Default) {
         const semantics::DerivedTypeSpec *typeSpec =
             sym.GetType() ? sym.GetType()->AsDerived() : nullptr;
         if (typeSpec) {
@@ -3108,6 +3142,15 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
           Object{const_cast<semantics::Symbol *>(&sym), std::nullopt});
     }
   };
+  // OpenMP requires `in_reduction` list items on `target` to be implicitly
+  // data-mapped. Force-map them as address-preserving captures before the
+  // generic implicit-map walk so that walk treats the symbols as already
+  // mapped via `isDuplicateMappedSymbol` and does not downgrade them to
+  // ByCopy.
+  for (const Object &object : inReductionObjects)
+    if (const semantics::Symbol *sym = object.sym())
+      captureImplicitMap(*sym, /*forceAddressPreserving=*/true);
+
   lower::pft::visitAllSymbols(eval, captureImplicitMap);
 
   auto targetOp = mlir::omp::TargetOp::create(firOpBuilder, loc, clauseOps);
@@ -3120,7 +3163,8 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
   args.hasDeviceAddr.objects = hasDeviceAddrObjects;
   args.hasDeviceAddr.vars = hasDeviceAddrBaseValues;
   args.hostEvalVars = clauseOps.hostEvalVars;
-  // TODO: Add in_reduction syms and vars.
+  args.inReduction.objects = inReductionObjects;
+  args.inReduction.vars = clauseOps.inReductionVars;
   args.map.objects = mapObjects;
   args.map.vars = mapBaseValues;
   args.priv.objects = makeObjects(dsp.getDelayedPrivSymbols());
diff --git a/flang/test/Lower/OpenMP/Todo/target-inreduction.f90 b/flang/test/Lower/OpenMP/Todo/target-inreduction.f90
deleted file mode 100644
index e5a9cffac5a11..0000000000000
--- a/flang/test/Lower/OpenMP/Todo/target-inreduction.f90
+++ /dev/null
@@ -1,15 +0,0 @@
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
-
-!===============================================================================
-! `mergeable` clause
-!===============================================================================
-
-! CHECK: not yet implemented: Unhandled clause IN_REDUCTION in TARGET construct
-subroutine omp_target_inreduction()
-  integer i
-  i = 0
-  !$omp target in_reduction(+:i)
-  i = i + 1
-  !$omp end target
-end subroutine omp_target_inreduction
diff --git a/flang/test/Lower/OpenMP/target-inreduction-unused.f90 b/flang/test/Lower/OpenMP/target-inreduction-unused.f90
new file mode 100644
index 0000000000000..cf0d39db3e9a7
--- /dev/null
+++ b/flang/test/Lower/OpenMP/target-inreduction-unused.f90
@@ -0,0 +1,27 @@
+! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+
+! Per the OpenMP spec, an in_reduction list item on a target construct is
+! implicitly data-mapped. The lowering must not rely on the variable being
+! referenced inside the target body to discover that map: here `i` only
+! appears in the in_reduction clause and is never read or written inside
+! the region. Verify that an omp.map.info for `i` is still emitted and
+! flows into the omp.target's map_entries.
+
+!CHECK-LABEL: func.func @_QPomp_target_in_reduction_unused()
+!CHECK:       %[[IDECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFomp_target_in_reduction_unusedEi"}
+!CHECK:       %[[IMAP:.*]] = omp.map.info var_ptr(%[[IDECL]]#1 : !fir.ref<i32>, i32) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<i32> {name = "i"}
+!CHECK:       omp.target in_reduction(@{{[^ ]+}} %[[IDECL]]#0 -> %{{[^ ]+}} : !fir.ref<i32>)
+!CHECK-SAME:    map_entries(%[[IMAP]] -> %{{[^ ]+}} : !fir.ref<i32>)
+
+subroutine omp_target_in_reduction_unused()
+  interface
+    subroutine sub()
+    end subroutine
+  end interface
+  integer i
+  i = 0
+  !$omp target in_reduction(+:i)
+  call sub()
+  !$omp end target
+end subroutine omp_target_in_reduction_unused
diff --git a/flang/test/Lower/OpenMP/target-inreduction.f90 b/flang/test/Lower/OpenMP/target-inreduction.f90
new file mode 100644
index 0000000000000..3955cacb744c2
--- /dev/null
+++ b/flang/test/Lower/OpenMP/target-inreduction.f90
@@ -0,0 +1,28 @@
+! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+
+! Verify that in_reduction on a target construct is lowered to an
+! omp.target with both an in_reduction clause and an implicit map_entries
+! entry for the same variable. The implicit map captures the original
+! pointer into the target region so the MLIR -> LLVM IR translation can
+! pass it to __kmpc_task_reduction_get_th_data.
+
+!CHECK-LABEL: omp.declare_reduction
+!CHECK-SAME:  @[[RED_I32_NAME:.*]] : i32 init {
+
+!CHECK-LABEL: func.func @_QPomp_target_in_reduction()
+!CHECK:       %[[IDECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFomp_target_in_reductionEi"}
+!CHECK:       %[[IMAP:.*]] = omp.map.info var_ptr(%[[IDECL]]#1 : !fir.ref<i32>, i32) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<i32> {name = "i"}
+!CHECK:       omp.target in_reduction(@[[RED_I32_NAME]] %[[IDECL]]#0 -> %[[INARG:[^ ]+]] : !fir.ref<i32>)
+!CHECK-SAME:    map_entries(%[[IMAP]] -> %{{[^ ]+}} : !fir.ref<i32>)
+!CHECK:         hlfir.declare %[[INARG]]
+!CHECK:         omp.terminator
+!CHECK:       }
+
+subroutine omp_target_in_reduction()
+  integer i
+  i = 0
+  !$omp target in_reduction(+:i)
+  i = i + 1
+  !$omp end target
+end subroutine omp_target_in_reduction
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 7cef23bdfef18..9daef3368ec4c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2545,8 +2545,7 @@ LogicalResult TargetUpdateOp::verify() {
 void TargetOp::build(OpBuilder &builder, OperationState &state,
                      const TargetOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
-  // inReductionByref, inReductionSyms.
+  // TODO Store clauses in op: allocateVars, allocatorVars.
   TargetOp::build(
       builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.bare,
       makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
@@ -2554,9 +2553,10 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
       clauses.device, clauses.dynGroupprivateAccessGroup,
       clauses.dynGroupprivateFallback, clauses.dynGroupprivateSize,
       clauses.hasDeviceAddrVars, clauses.hostEvalVars, clauses.ifExpr,
-      /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
-      /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars, clauses.mapVars,
-      clauses.nowait, clauses.privateVars,
+      clauses.inReductionVars,
+      makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
+      makeArrayAttr(ctx, clauses.inReductionSyms), clauses.isDevicePtrVars,
+      clauses.mapVars, clauses.nowait, clauses.privateVars,
       makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
       clauses.threadLimitVars,
       /*private_maps=*/nullptr);
@@ -2583,6 +2583,11 @@ LogicalResult TargetOp::verify() {
   if (failed(verifyPrivateVarList(*this)))
     return failure();
 
+  if (failed(verifyReductionVarList(*this, getInReductionSyms(),
+                                    getInReductionVars(),
+                                    getInReductionByref())))
+    return failure();
+
   return verifyPrivateVarsMapping(*this);
 }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 03cc2505cbbd8..08bf79dee6948 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -344,10 +344,46 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.getHint())
       op.emitWarning("hint clause discarded");
   };
+  // in_reduction support varies by operation:
+  //   - omp.task does not implement in_reduction at all yet.
+  //   - omp.taskloop.context and omp.target implement the non-byref form; the
+  //     byref form is not implemented yet.
+  //   - omp.target additionally does not implement declare reductions that use
+  //     a cleanup region or a two-argument (alloc) initializer.
   auto checkInReduction = [&todo](auto op, LogicalResult &result) {
-    if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
-        op.getInReductionSyms())
-      result = todo("in_reduction");
+    if (isa<omp::TaskOp>(op.getOperation())) {
+      if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
+          op.getInReductionSyms())
+        result = todo("in_reduction");
+      return;
+    }
+    if (auto byrefAttr = op.getInReductionByref())
+      for (bool isByRef : *byrefAttr)
+        if (isByRef) {
+          result = todo("in_reduction with byref modifier");
+          return;
+        }
+    if (isa<omp::TargetOp>(op.getOperation())) {
+      if (auto inReductionSyms = op.getInReductionSyms()) {
+        for (auto sym :
+             (*inReductionSyms).template getAsRange<SymbolRefAttr>()) {
+          auto decl =
+              SymbolTable::lookupNearestSymbolFrom<omp::DeclareReductionOp>(
+                  op, sym);
+          // Symbol resolution is guaranteed by the op verifier.
+          if (!decl)
+            continue;
+          if (decl.getInitializerRegion().front().getNumArguments() != 1) {
+            result = todo("in_reduction with two-argument initializer");
+            return;
+          }
+          if (!decl.getCleanupRegion().empty()) {
+            result = todo("in_reduction with cleanup region");
+            return;
+          }
+        }
+      }
+    }
   };
   auto checkNowait = [&todo](auto op, LogicalResult &result) {
     if (op.getNowait())
@@ -386,14 +422,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
           return;
         }
   };
-  auto checkInReductionByref = [&todo](auto op, LogicalResult &result) {
-    if (auto byrefAttr = op.getInReductionByref())
-      for (bool isByRef : *byrefAttr)
-        if (isByRef) {
-          result = todo("in_reduction with byref modifier");
-          return;
-        }
-  };
   auto checkNumTeams = [&todo](auto op, LogicalResult &result) {
     if (op.hasNumTeamsMultiDim())
       result = todo("num_teams with multi-dimensional values");
@@ -453,7 +481,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       })
       .Case([&](omp::TaskloopContextOp op) {
         checkAllocate(op, result);
-        checkInReductionByref(op, result);
+        checkInReduction(op, result);
         checkReduction(op, result);
         checkReductionByref(op, result);
       })
@@ -490,6 +518,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       .Case([&](omp::TargetOp op) {
         checkAllocate(op, result);
         checkBare(op, result);
+        // The byref / cleanup-region / two-argument-initializer in_reduction
+        // shapes on omp.target are not implemented yet (handled by
+        // checkInReduction). The device-side / offload-entry cases are
+        // diagnosed inline in convertOmpTarget.
         checkInReduction(op, result);
         checkThreadLimit(op, result);
       })
@@ -8222,6 +8254,44 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   bool isOffloadEntry =
       isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
 
+  // Validate and resolve in_reduction clauses on omp.target. We currently
+  // only support the non-offload host-fallback path: the per-task private
+  // pointer is obtained by calling __kmpc_task_reduction_get_th_data inside
+  // the to-be-outlined target task body. Threading that pointer through the
+  // device kernel argument list is left as follow-up work.
+  SmallVector<llvm::Value *> inRedOrigPtrs;
+  if (!targetOp.getInReductionVars().empty()) {
+    if (isTargetDevice || isOffloadEntry)
+      return opInst.emitError(
+          "not yet implemented: in_reduction clause on omp.target with "
+          "offload / target-device compilation");
+    // The byref / cleanup-region / two-argument-initializer in_reduction
+    // shapes are rejected earlier by checkImplementationStatus, and symbol
+    // resolution is guaranteed by verifyReductionVarList.
+    //
+    // Each in_reduction variable must also be captured by the target via a
+    // map_entries entry referring to the same outer SSA value. OMPIRBuilder
+    // outlines the target body and only rewires uses of values that enter
+    // the kernel through the map-derived input set. The runtime call below
+    // uses that same outer SSA value as its `orig` argument, so without a
+    // matching map entry the outlined kernel would reference a value defined
+    // in the host function and fail IR verification. This is checked here
+    // rather than in the op verifier because the two values only coincide
+    // after lowering to the LLVM dialect; at the FIR level they may differ.
+    llvm::SmallPtrSet<Value, 4> mappedVarPtrs;
+    for (Value mapV : targetOp.getMapVars())
+      if (auto mapInfo = mapV.getDefiningOp<omp::MapInfoOp>())
+        mappedVarPtrs.insert(mapInfo.getVarPtr());
+    inRedOrigPtrs.reserve(targetOp.getInReductionVars().size());
+    for (Value v : targetOp.getInReductionVars()) {
+      if (!mappedVarPtrs.contains(v))
+        return targetOp.emitError()
+               << "not yet implemented: in_reduction variable on omp.target "
+                  "must also be captured by a matching map_entries entry";
+      inRedOrigPtrs.push_back(moduleTranslation.lookupValue(v));
+    }
+  }
+
   // For some private variables, the MapsForPrivatizedVariablesPass
   // creates MapInfoOp instances. Go through the private variables and
   // the mapped variables so that during codegeneration we are able
@@ -8334,6 +8404,36 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
             targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
       return llvm::make_error<PreviouslyReportedError>();
 
+    // Map in_reduction block arguments to the per-task private storage
+    // returned by __kmpc_task_reduction_get_th_data. The lookup must run
+    // inside the target task body so the gtid corresponds to the executing
+    // thread. The descriptor argument is NULL: the runtime walks enclosing
+    // taskgroups to locate the matching task_reduction registration for
+    // `origPtr`. Mirrors the in_reduction handling on omp.taskloop.context.
+    ArrayRef<BlockArgument> inRedBlockArgs = argIface.getInReductionBlockArgs();
+    if (!inRedBlockArgs.empty()) {
+      llvm::OpenMPIRBuilder &ompB = *ompBuilder;
+      llvm::Module *m = moduleTranslation.getLLVMModule();
+      llvm::LLVMContext &llvmCtx = m->getContext();
+      uint32_t srcLocSize;
+      llvm::Constant *srcLocStr = ompB.getOrCreateDefaultSrcLocStr(srcLocSize);
+      llvm::Value *bodyIdent = ompB.getOrCreateIdent(srcLocStr, srcLocSize);
+      llvm::Function *gtidFn = ompB.getOrCreateRuntimeFunctionPtr(
+          llvm::omp::OMPRTL___kmpc_global_thread_num);
+      llvm::Value *bodyGtid =
+          builder.CreateCall(gtidFn, {bodyIdent}, "omp_global_thread_num");
+      llvm::FunctionCallee getThData = ompB.getOrCreateRuntimeFunction(
+          *m, llvm::omp::OMPRTL___kmpc_task_reduction_get_th_data);
+      llvm::Type *ptrTy = llvm::PointerType::getUnqual(llvmCtx);
+      llvm::Value *nullDesc = llvm::ConstantPointerNull::get(ptrTy);
+      for (auto [blockArg, origPtr] :
+           llvm::zip_equal(inRedBlockArgs, inRedOrigPtrs)) {
+        llvm::Value *priv = builder.CreateCall(
+            getThData, {bodyGtid, nullDesc, origPtr}, "omp.inred.priv");
+        moduleTranslation.mapValue(blockArg, priv);
+      }
+    }
+
     LLVM::ModuleTranslation::SaveStack<OpenMPAllocStackFrame> frame(
         moduleTranslation, allocaIP, deallocBlocks);
     llvm::Expected<llvm::BasicBlock *> exitBlock = convertOmpOpRegions(
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 06ad3d60ea635..7e6793d23ac7d 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -3129,6 +3129,66 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
 
 // -----
 
+func.func @omp_target_in_reduction_unresolved(%ptr: !llvm.ptr) {
+  // expected-error @below {{op expected symbol reference @add_f32 to point to a reduction declaration}}
+  omp.target in_reduction(@add_f32 %ptr -> %arg0 : !llvm.ptr) {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+omp.declare_reduction @add_f32 : f32
+init {
+^bb0(%arg: f32):
+  %0 = arith.constant 0.0 : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = arith.addf %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+
+func.func @omp_target_in_reduction_duplicate(%ptr: !llvm.ptr) {
+  // expected-error @below {{op accumulator variable used more than once}}
+  omp.target in_reduction(@add_f32 %ptr -> %arg0, @add_f32 %ptr -> %arg1 : !llvm.ptr, !llvm.ptr) {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+omp.declare_reduction @add_i32 : i32
+init {
+^bb0(%arg: i32):
+  %0 = arith.constant 0 : i32
+  omp.yield (%0 : i32)
+}
+combiner {
+^bb1(%arg0: i32, %arg1: i32):
+  %1 = arith.addi %arg0, %arg1 : i32
+  omp.yield (%1 : i32)
+}
+atomic {
+^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
+  %2 = llvm.load %arg3 : !llvm.ptr -> i32
+  llvm.atomicrmw add %arg2, %2 monotonic : !llvm.ptr, i32
+  omp.yield
+}
+
+func.func @omp_target_in_reduction_type_mismatch(%mem: memref<1xf32>) {
+  // expected-error @below {{op expected accumulator ('memref<1xf32>') to be the same type as reduction declaration ('!llvm.ptr')}}
+  omp.target in_reduction(@add_i32 %mem -> %arg0 : memref<1xf32>) {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
 func.func @omp_distribute_schedule(%chunk_size : i32, %lb : i32, %ub : i32, %step : i32) -> () {
   // expected-error @below {{op chunk size set without dist_schedule_static being present}}
   "omp.distribute"(%chunk_size) <{operandSegmentSizes = array<i32: 0, 0, 1, 0>}> ({
diff --git a/mlir/test/Target/LLVMIR/openmp-target-in-reduction.mlir b/mlir/test/Target/LLVMIR/openmp-target-in-reduction.mlir
new file mode 100644
index 0000000000000..2b3cfd514d82e
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-target-in-reduction.mlir
@@ -0,0 +1,50 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// in_reduction on omp.target: the in_reduction variable is also captured
+// into the target region as a map entry (the Flang front-end emits this
+// implicit map). Inside the outlined target body the captured pointer is
+// passed to __kmpc_task_reduction_get_th_data with a NULL descriptor;
+// the runtime walks enclosing taskgroups to locate the matching
+// task_reduction registration. The returned pointer is bound to the
+// in_reduction region block argument so subsequent loads/stores inside
+// the region use the private copy.
+
+omp.declare_reduction @add_i32 : i32
+init {
+^bb0(%arg0: i32):
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  omp.yield(%c0 : i32)
+}
+combiner {
+^bb0(%arg0: i32, %arg1: i32):
+  %s = llvm.add %arg0, %arg1 : i32
+  omp.yield(%s : i32)
+}
+
+llvm.func @target_inreduction(%x : !llvm.ptr) {
+  %m = omp.map.info var_ptr(%x : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr
+  omp.target in_reduction(@add_i32 %x -> %prv : !llvm.ptr) map_entries(%m -> %marg : !llvm.ptr) {
+    %v = llvm.load %prv : !llvm.ptr -> i32
+    %c1 = llvm.mlir.constant(1 : i32) : i32
+    %s = llvm.add %v, %c1 : i32
+    llvm.store %s, %prv : i32, !llvm.ptr
+    omp.terminator
+  }
+  llvm.return
+}
+
+// The host stub forwards the captured pointer into the outlined target
+// kernel.
+// CHECK-LABEL: define void @target_inreduction(
+// CHECK:         call void @__omp_offloading_{{.*}}_target_inreduction_{{.*}}(ptr %{{.+}}, ptr null)
+
+// In the outlined target body the in_reduction private pointer is
+// obtained from the runtime using the captured original pointer; that
+// pointer is then the base of the load and store inside the region.
+// CHECK-LABEL: define internal void @__omp_offloading_{{.*}}_target_inreduction_
+// CHECK-SAME:    (ptr %[[CAPT:.+]], ptr %{{.+}})
+// CHECK:         %[[GTID:.+]] = call i32 @__kmpc_global_thread_num(
+// CHECK:         %[[PRIV:.+]] = call ptr @__kmpc_task_reduction_get_th_data(i32 %[[GTID]], ptr null, ptr %[[CAPT]])
+// CHECK:         %[[LOADED:.+]] = load i32, ptr %[[PRIV]]
+// CHECK:         %[[SUM:.+]] = add i32 %[[LOADED]], 1
+// CHECK:         store i32 %[[SUM]], ptr %[[PRIV]]
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 5c22f7f081bb5..fccf132a8ebcb 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -190,10 +190,90 @@ atomic {
   llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
   omp.yield
 }
-llvm.func @target_in_reduction(%x : !llvm.ptr) {
-  // expected-error@below {{not yet implemented: Unhandled clause in_reduction in omp.target operation}}
+llvm.func @target_in_reduction_byref(%x : !llvm.ptr) {
+  // expected-error@below {{not yet implemented: Unhandled clause in_reduction with byref modifier in omp.target operation}}
   // expected-error@below {{LLVM Translation failed for operation: omp.target}}
-  omp.target in_reduction(@add_f32 %x -> %prv : !llvm.ptr) {
+  omp.target in_reduction(byref @add_f32 %x -> %prv : !llvm.ptr) {
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
+omp.declare_reduction @add_cleanup_f32 : f32
+init {
+^bb0(%arg: f32):
+  %0 = llvm.mlir.constant(0.0 : f32) : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = llvm.fadd %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+cleanup {
+^bb2(%arg2: f32):
+  omp.yield
+}
+llvm.func @target_in_reduction_cleanup(%x : !llvm.ptr) {
+  // expected-error@below {{not yet implemented: Unhandled clause in_reduction with cleanup region in omp.target operation}}
+  // expected-error@below {{LLVM Translation failed for operation: omp.target}}
+  omp.target in_reduction(@add_cleanup_f32 %x -> %prv : !llvm.ptr) {
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
+omp.declare_reduction @add_two_arg_init_i32 : !llvm.ptr alloc {
+^bb0(%arg: !llvm.ptr):
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x i32 : (i64) -> !llvm.ptr
+  omp.yield(%1 : !llvm.ptr)
+} init {
+^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  llvm.store %0, %arg1 : i32, !llvm.ptr
+  omp.yield(%arg1 : !llvm.ptr)
+} combiner {
+^bb1(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+  %0 = llvm.load %arg0 : !llvm.ptr -> i32
+  %1 = llvm.load %arg1 : !llvm.ptr -> i32
+  %2 = llvm.add %0, %1 : i32
+  llvm.store %2, %arg0 : i32, !llvm.ptr
+  omp.yield(%arg0 : !llvm.ptr)
+}
+llvm.func @target_in_reduction_two_arg_init(%x : !llvm.ptr) {
+  // expected-error@below {{not yet implemented: Unhandled clause in_reduction with two-argument initializer in omp.target operation}}
+  // expected-error@below {{LLVM Translation failed for operation: omp.target}}
+  omp.target in_reduction(@add_two_arg_init_i32 %x -> %prv : !llvm.ptr) {
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
+omp.declare_reduction @add_no_map_f32 : f32
+init {
+^bb0(%arg: f32):
+  %0 = llvm.mlir.constant(0.0 : f32) : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = llvm.fadd %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+llvm.func @target_in_reduction_no_map(%x : !llvm.ptr) {
+  // The in_reduction variable %x has no matching map_entries entry. The
+  // outlined target kernel would otherwise reference %x across function
+  // boundaries; the translation must reject this up front.
+  // expected-error@below {{not yet implemented: in_reduction variable on omp.target must also be captured by a matching map_entries entry}}
+  // expected-error@below {{LLVM Translation failed for operation: omp.target}}
+  omp.target in_reduction(@add_no_map_f32 %x -> %prv : !llvm.ptr) {
     omp.terminator
   }
   llvm.return

_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to