================ @@ -4898,6 +4898,274 @@ void CGOpenMPRuntime::emitSingleReductionCombiner(CodeGenFunction &CGF, } } +void CGOpenMPRuntime::emitPrivateReduction( + CodeGenFunction &CGF, SourceLocation Loc, const Expr *Privates, + const Expr *LHSExprs, const Expr *RHSExprs, const Expr *ReductionOps) { + + // Create a shared global variable (__shared_reduction_var) to accumulate the + // final result. + // + // Call __kmpc_barrier to synchronize threads before initialization. + // + // The master thread (thread_id == 0) initializes __shared_reduction_var + // with the identity value or initializer. + // + // Call __kmpc_barrier to synchronize before combining. + // For each i: + // - Thread enters critical section. + // - Reads its private value from LHSExprs[i]. + // - Updates __shared_reduction_var[i] = RedOp_i(__shared_reduction_var[i], + // LHSExprs[i]). + // - Exits critical section. + // + // Call __kmpc_barrier after combining. + // + // Each thread copies __shared_reduction_var[i] back to LHSExprs[i]. + // + // Final __kmpc_barrier to synchronize after broadcasting + QualType PrivateType = Privates->getType(); + llvm::Type *LLVMType = CGF.ConvertTypeForMem(PrivateType); + + llvm::Constant *InitVal = nullptr; + const OMPDeclareReductionDecl *UDR = getReductionInit(ReductionOps); + // Determine the initial value for the shared reduction variable + if (!UDR) { + InitVal = llvm::Constant::getNullValue(LLVMType); + if (const auto *DRE = dyn_cast<DeclRefExpr>(Privates)) { + if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) { + const Expr *InitExpr = VD->getInit(); + if (InitExpr) { + Expr::EvalResult Result; + if (InitExpr->EvaluateAsRValue(Result, CGF.getContext())) { + APValue &InitValue = Result.Val; + if (InitValue.isInt()) + InitVal = llvm::ConstantInt::get(LLVMType, InitValue.getInt()); + else if (InitValue.isFloat()) + InitVal = llvm::ConstantFP::get(LLVMType, InitValue.getFloat()); + else if (InitValue.isComplexInt()) { + // For complex int: create struct { real, imag } + llvm::Constant *Real = llvm::ConstantInt::get( + cast<llvm::StructType>(LLVMType)->getElementType(0), + InitValue.getComplexIntReal()); + llvm::Constant *Imag = llvm::ConstantInt::get( + cast<llvm::StructType>(LLVMType)->getElementType(1), + InitValue.getComplexIntImag()); + InitVal = llvm::ConstantStruct::get( + cast<llvm::StructType>(LLVMType), {Real, Imag}); + } else if (InitValue.isComplexFloat()) { + llvm::Constant *Real = llvm::ConstantFP::get( + cast<llvm::StructType>(LLVMType)->getElementType(0), + InitValue.getComplexFloatReal()); + llvm::Constant *Imag = llvm::ConstantFP::get( + cast<llvm::StructType>(LLVMType)->getElementType(1), + InitValue.getComplexFloatImag()); + InitVal = llvm::ConstantStruct::get( + cast<llvm::StructType>(LLVMType), {Real, Imag}); + } + } + } + } + } + } else { + InitVal = llvm::Constant::getNullValue(LLVMType); + } + std::string ReductionVarNameStr; + if (const auto *DRE = dyn_cast<DeclRefExpr>(Privates->IgnoreParenCasts())) + ReductionVarNameStr = DRE->getDecl()->getNameAsString(); + else + ReductionVarNameStr = "unnamed_priv_var"; + + // Create an internal shared variable + std::string SharedName = + CGM.getOpenMPRuntime().getName({"internal_pivate_", ReductionVarNameStr}); + llvm::GlobalVariable *SharedVar = new llvm::GlobalVariable( + CGM.getModule(), LLVMType, false, llvm::GlobalValue::InternalLinkage, + InitVal, ".omp.reduction." + SharedName, nullptr, + llvm::GlobalVariable::NotThreadLocal); + + SharedVar->setAlignment( + llvm::MaybeAlign(CGF.getContext().getTypeAlign(PrivateType) / 8)); + + Address SharedResult(SharedVar, SharedVar->getValueType(), + CGF.getContext().getTypeAlignInChars(PrivateType)); + + llvm::Value *ThreadId = getThreadID(CGF, Loc); + llvm::Value *BarrierLoc = emitUpdateLocation(CGF, Loc, OMP_ATOMIC_REDUCE); + llvm::Value *BarrierArgs[] = {BarrierLoc, ThreadId}; + + llvm::BasicBlock *InitBB = CGF.createBasicBlock("init"); + llvm::BasicBlock *InitEndBB = CGF.createBasicBlock("init.end"); + + llvm::Value *IsWorker = CGF.Builder.CreateICmpEQ( + ThreadId, llvm::ConstantInt::get(ThreadId->getType(), 0)); + CGF.Builder.CreateCondBr(IsWorker, InitBB, InitEndBB); + + CGF.EmitBlock(InitBB); + + auto EmitSharedInit = [&]() { + if (UDR) { // Check if it's a User-Defined Reduction + if (const Expr *UDRInitExpr = UDR->getInitializer()) { + std::pair<llvm::Function *, llvm::Function *> FnPair = + getUserDefinedReduction(UDR); + llvm::Function *InitializerFn = FnPair.second; + if (InitializerFn) { + if (const auto *CE = + dyn_cast<CallExpr>(UDRInitExpr->IgnoreParenImpCasts())) { + const auto *OutDRE = cast<DeclRefExpr>( + cast<UnaryOperator>(CE->getArg(0)->IgnoreParenImpCasts()) + ->getSubExpr()); + const VarDecl *OutVD = cast<VarDecl>(OutDRE->getDecl()); + + CodeGenFunction::OMPPrivateScope LocalScope(CGF); + LocalScope.addPrivate(OutVD, SharedResult); + + (void)LocalScope.Privatize(); + if (const auto *OVE = dyn_cast<OpaqueValueExpr>( + CE->getCallee()->IgnoreParenImpCasts())) { + CodeGenFunction::OpaqueValueMapping OpaqueMap( + CGF, OVE, RValue::get(InitializerFn)); + CGF.EmitIgnoredExpr(CE); + } else { + CGF.EmitAnyExprToMem(UDRInitExpr, SharedResult, + PrivateType.getQualifiers(), true); + } + } else { + CGF.EmitAnyExprToMem(UDRInitExpr, SharedResult, + PrivateType.getQualifiers(), true); + } + } else { + CGF.EmitAnyExprToMem(UDRInitExpr, SharedResult, + PrivateType.getQualifiers(), true); + } + } else { + // EmitNullInitialization handles default construction for C++ classes + // and zeroing for scalars, which is a reasonable default. + CGF.EmitNullInitialization(SharedResult, PrivateType); + } + return; // UDR initialization handled + } + if (const auto *DRE = dyn_cast<DeclRefExpr>(Privates)) { + if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) { + if (const Expr *InitExpr = VD->getInit()) { + CGF.EmitAnyExprToMem(InitExpr, SharedResult, + PrivateType.getQualifiers(), true); + return; + } + } + } + CGF.EmitNullInitialization(SharedResult, PrivateType); + }; + EmitSharedInit(); + CGF.Builder.CreateBr(InitEndBB); + CGF.EmitBlock(InitEndBB); + + CGF.EmitRuntimeCall(OMPBuilder.getOrCreateRuntimeFunction( + CGM.getModule(), OMPRTL___kmpc_barrier), + BarrierArgs); + + const Expr *ReductionOp = ReductionOps; + const OMPDeclareReductionDecl *CurrentUDR = getReductionInit(ReductionOp); + LValue SharedLV = CGF.MakeAddrLValue(SharedResult, PrivateType); + LValue LHSLV = CGF.EmitLValue(LHSExprs); + + auto EmitCriticalReduction = [&](auto ReductionGen) { + std::string CriticalName = getName({"reduction_critical"}); + emitCriticalRegion(CGF, CriticalName, ReductionGen, Loc); + }; + + if (CurrentUDR) { + // Handle user-defined reduction. + auto ReductionGen = [&](CodeGenFunction &CGF, PrePostActionTy &Action) { + Action.Enter(CGF); + std::pair<llvm::Function *, llvm::Function *> FnPair = + getUserDefinedReduction(CurrentUDR); + if (FnPair.first) { + if (const auto *CE = dyn_cast<CallExpr>(ReductionOp)) { + const auto *OutDRE = cast<DeclRefExpr>( + cast<UnaryOperator>(CE->getArg(0)->IgnoreParenImpCasts()) + ->getSubExpr()); + const auto *InDRE = cast<DeclRefExpr>( + cast<UnaryOperator>(CE->getArg(1)->IgnoreParenImpCasts()) + ->getSubExpr()); + CodeGenFunction::OMPPrivateScope LocalScope(CGF); + LocalScope.addPrivate(cast<VarDecl>(OutDRE->getDecl()), + SharedLV.getAddress()); + LocalScope.addPrivate(cast<VarDecl>(InDRE->getDecl()), + LHSLV.getAddress()); + (void)LocalScope.Privatize(); + emitReductionCombiner(CGF, ReductionOp); + } + } + }; + EmitCriticalReduction(ReductionGen); + } else { + // Handle built-in reduction operations. + const Expr *ReductionClauseExpr = ReductionOp->IgnoreParenCasts(); + if (const auto *Cleanup = dyn_cast<ExprWithCleanups>(ReductionClauseExpr)) + ReductionClauseExpr = Cleanup->getSubExpr()->IgnoreParenCasts(); + + const Expr *AssignRHS = nullptr; + if (const auto *BinOp = dyn_cast<BinaryOperator>(ReductionClauseExpr)) { + if (BinOp->getOpcode() == BO_Assign) + AssignRHS = BinOp->getRHS(); + } else if (const auto *OpCall = + dyn_cast<CXXOperatorCallExpr>(ReductionClauseExpr)) { + if (OpCall->getOperator() == OO_Equal) + AssignRHS = OpCall->getArg(1); + } + + if (!AssignRHS) + return; ---------------- alexey-bataev wrote:
Then it should be an assertion, not an exit. https://github.com/llvm/llvm-project/pull/134709 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits