https://github.com/kparzysz created https://github.com/llvm/llvm-project/pull/81623
…essor Rename `findRepeatableClause` to `findRepeatableClause2`, and make the new `findRepeatableClause` operate on new `omp::Clause` objects. Leave `Map` unchanged, because it will require more changes for it to work. >From be33fa2419d24490a221f78fbba4f2b7097b6011 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek <krzysztof.parzys...@amd.com> Date: Tue, 6 Feb 2024 17:06:29 -0600 Subject: [PATCH] [flang][OpenMP] Convert repeatable clauses (except Map) in ClauseProcessor Rename `findRepeatableClause` to `findRepeatableClause2`, and make the new `findRepeatableClause` operate on new `omp::Clause` objects. Leave `Map` unchanged, because it will require more changes for it to work. --- flang/include/flang/Evaluate/tools.h | 23 + flang/lib/Lower/OpenMP.cpp | 632 +++++++++++++-------------- 2 files changed, 328 insertions(+), 327 deletions(-) diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h index d257da1a709642..e9999974944e88 100644 --- a/flang/include/flang/Evaluate/tools.h +++ b/flang/include/flang/Evaluate/tools.h @@ -430,6 +430,29 @@ template <typename A> std::optional<CoarrayRef> ExtractCoarrayRef(const A &x) { } } +struct ExtractSubstringHelper { + template <typename T> static std::optional<Substring> visit(T &&) { + return std::nullopt; + } + + static std::optional<Substring> visit(const Substring &e) { return e; } + + template <typename T> + static std::optional<Substring> visit(const Designator<T> &e) { + return std::visit([](auto &&s) { return visit(s); }, e.u); + } + + template <typename T> + static std::optional<Substring> visit(const Expr<T> &e) { + return std::visit([](auto &&s) { return visit(s); }, e.u); + } +}; + +template <typename A> +std::optional<Substring> ExtractSubstring(const A &x) { + return ExtractSubstringHelper::visit(x); +} + // If an expression is simply a whole symbol data designator, // extract and return that symbol, else null. template <typename A> const Symbol *UnwrapWholeSymbolDataRef(const A &x) { diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp index d7a93db15a4bb8..4b21ab934c9393 100644 --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -72,9 +72,9 @@ getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) { return sym; } -static void genObjectList(const Fortran::parser::OmpObjectList &objectList, - Fortran::lower::AbstractConverter &converter, - llvm::SmallVectorImpl<mlir::Value> &operands) { +static void genObjectList2(const Fortran::parser::OmpObjectList &objectList, + Fortran::lower::AbstractConverter &converter, + llvm::SmallVectorImpl<mlir::Value> &operands) { auto addOperands = [&](Fortran::lower::SymbolRef sym) { const mlir::Value variable = converter.getSymbolAddress(sym); if (variable) { @@ -93,27 +93,6 @@ static void genObjectList(const Fortran::parser::OmpObjectList &objectList, } } -static void gatherFuncAndVarSyms( - const Fortran::parser::OmpObjectList &objList, - mlir::omp::DeclareTargetCaptureClause clause, - llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) { - for (const Fortran::parser::OmpObject &ompObject : objList.v) { - Fortran::common::visit( - Fortran::common::visitors{ - [&](const Fortran::parser::Designator &designator) { - if (const Fortran::parser::Name *name = - Fortran::semantics::getDesignatorNameIfDataRef( - designator)) { - symbolAndClause.emplace_back(clause, *name->symbol); - } - }, - [&](const Fortran::parser::Name &name) { - symbolAndClause.emplace_back(clause, *name.symbol); - }}, - ompObject.u); - } -} - static Fortran::lower::pft::Evaluation * getCollapsedLoopEval(Fortran::lower::pft::Evaluation &eval, int collapseValue) { // Return the Evaluation of the innermost collapsed loop, or the current one @@ -1257,6 +1236,32 @@ List<Clause> makeList(const parser::OmpClauseList &clauses, } } // namespace omp +static void genObjectList(const omp::ObjectList &objects, + Fortran::lower::AbstractConverter &converter, + llvm::SmallVectorImpl<mlir::Value> &operands) { + for (const omp::Object &object : objects) { + const Fortran::semantics::Symbol *sym = object.sym; + assert(sym && "Expected Symbol"); + if (mlir::Value variable = converter.getSymbolAddress(*sym)) { + operands.push_back(variable); + } else { + if (const auto *details = + sym->detailsIf<Fortran::semantics::HostAssocDetails>()) { + operands.push_back(converter.getSymbolAddress(details->symbol())); + converter.copySymbolBinding(details->symbol(), *sym); + } + } + } +} + +static void gatherFuncAndVarSyms( + const omp::ObjectList &objects, + mlir::omp::DeclareTargetCaptureClause clause, + llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) { + for (const omp::Object &object : objects) + symbolAndClause.emplace_back(clause, *object.sym); +} + //===----------------------------------------------------------------------===// // DataSharingProcessor //===----------------------------------------------------------------------===// @@ -1718,9 +1723,8 @@ class ClauseProcessor { llvm::SmallVectorImpl<mlir::Value> &dependOperands) const; bool processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; - bool - processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName, - mlir::Value &result) const; + bool processIf(omp::clause::If::DirectiveNameModifier directiveName, + mlir::Value &result) const; bool processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; @@ -1815,6 +1819,26 @@ class ClauseProcessor { /// if at least one instance was found. template <typename T> bool findRepeatableClause( + std::function<void(const T &, const Fortran::parser::CharBlock &source)> + callbackFn) const { + bool found = false; + ClauseIterator nextIt, endIt = clauses.end(); + for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) { + nextIt = findClause<T>(it, endIt); + + if (nextIt != endIt) { + callbackFn(std::get<T>(nextIt->u), nextIt->source); + found = true; + ++nextIt; + } + } + return found; + } + + /// Call `callbackFn` for each occurrence of the given clause. Return `true` + /// if at least one instance was found. + template <typename T> + bool findRepeatableClause2( std::function<void(const T *, const Fortran::parser::CharBlock &source)> callbackFn) const { bool found = false; @@ -1880,9 +1904,9 @@ class ReductionProcessor { IEOR }; static ReductionIdentifier - getReductionType(const Fortran::parser::ProcedureDesignator &pd) { + getReductionType(const omp::clause::ProcedureDesignator &pd) { auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>( - getRealName(pd).ToString()) + getRealName(pd.v.sym).ToString()) .Case("max", ReductionIdentifier::MAX) .Case("min", ReductionIdentifier::MIN) .Case("iand", ReductionIdentifier::IAND) @@ -1894,35 +1918,33 @@ class ReductionProcessor { } static ReductionIdentifier getReductionType( - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) { + omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) { switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + case omp::clause::DefinedOperator::IntrinsicOperator::Add: return ReductionIdentifier::ADD; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract: + case omp::clause::DefinedOperator::IntrinsicOperator::Subtract: return ReductionIdentifier::SUBTRACT; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: return ReductionIdentifier::MULTIPLY; - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + case omp::clause::DefinedOperator::IntrinsicOperator::AND: return ReductionIdentifier::AND; - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: + case omp::clause::DefinedOperator::IntrinsicOperator::EQV: return ReductionIdentifier::EQV; - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: + case omp::clause::DefinedOperator::IntrinsicOperator::OR: return ReductionIdentifier::OR; - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: return ReductionIdentifier::NEQV; default: llvm_unreachable("unexpected intrinsic operator in reduction"); } } - static bool supportedIntrinsicProcReduction( - const Fortran::parser::ProcedureDesignator &pd) { - const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)}; - assert(name && "Invalid Reduction Intrinsic."); - if (!name->symbol->GetUltimate().attrs().test( - Fortran::semantics::Attr::INTRINSIC)) + static bool + supportedIntrinsicProcReduction(const omp::clause::ProcedureDesignator &pd) { + Fortran::semantics::Symbol *sym = pd.v.sym; + if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC)) return false; - auto redType = llvm::StringSwitch<bool>(getRealName(name).ToString()) + auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString()) .Case("max", true) .Case("min", true) .Case("iand", true) @@ -1933,15 +1955,13 @@ class ReductionProcessor { } static const Fortran::semantics::SourceName - getRealName(const Fortran::parser::Name *name) { - return name->symbol->GetUltimate().name(); + getRealName(const Fortran::semantics::Symbol *symbol) { + return symbol->GetUltimate().name(); } static const Fortran::semantics::SourceName - getRealName(const Fortran::parser::ProcedureDesignator &pd) { - const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)}; - assert(name && "Invalid Reduction Intrinsic."); - return getRealName(name); + getRealName(const omp::clause::ProcedureDesignator &pd) { + return getRealName(pd.v.sym); } static std::string getReductionName(llvm::StringRef name, mlir::Type ty) { @@ -1951,25 +1971,25 @@ class ReductionProcessor { .str(); } - static std::string getReductionName( - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, - mlir::Type ty) { + static std::string + getReductionName(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, + mlir::Type ty) { std::string reductionName; switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + case omp::clause::DefinedOperator::IntrinsicOperator::Add: reductionName = "add_reduction"; break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: reductionName = "multiply_reduction"; break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + case omp::clause::DefinedOperator::IntrinsicOperator::AND: return "and_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: + case omp::clause::DefinedOperator::IntrinsicOperator::EQV: return "eqv_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: + case omp::clause::DefinedOperator::IntrinsicOperator::OR: return "or_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: return "neqv_reduction"; default: reductionName = "other_reduction"; @@ -2213,7 +2233,7 @@ class ReductionProcessor { static void addReductionDecl(mlir::Location currentLocation, Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpReductionClause &reduction, + const omp::clause::Reduction &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars, llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> @@ -2221,13 +2241,12 @@ class ReductionProcessor { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::omp::ReductionDeclareOp decl; const auto &redOperator{ - std::get<Fortran::parser::OmpReductionOperator>(reduction.t)}; - const auto &objectList{ - std::get<Fortran::parser::OmpObjectList>(reduction.t)}; + std::get<omp::clause::ReductionOperator>(reduction.t)}; + const auto &objectList{std::get<omp::ObjectList>(reduction.t)}; if (const auto &redDefinedOp = - std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) { + std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) { const auto &intrinsicOp{ - std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>( + std::get<omp::clause::DefinedOperator::IntrinsicOperator>( redDefinedOp->u)}; ReductionIdentifier redId = getReductionType(intrinsicOp); switch (redId) { @@ -2243,10 +2262,41 @@ class ReductionProcessor { "Reduction of some intrinsic operators is not supported"); break; } - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { + for (const omp::Object &object : objectList) { + if (const Fortran::semantics::Symbol *symbol = object.sym) { + if (reductionSymbols) + reductionSymbols->push_back(symbol); + mlir::Value symVal = converter.getSymbolAddress(*symbol); + if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) + symVal = declOp.getBase(); + mlir::Type redType = + symVal.getType().cast<fir::ReferenceType>().getEleTy(); + reductionVars.push_back(symVal); + if (redType.isa<fir::LogicalType>()) + decl = createReductionDecl( + firOpBuilder, + getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId, + redType, currentLocation); + else if (redType.isIntOrIndexOrFloat()) { + decl = createReductionDecl(firOpBuilder, + getReductionName(intrinsicOp, redType), + redId, redType, currentLocation); + } else { + TODO(currentLocation, "Reduction of some types is not supported"); + } + reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( + firOpBuilder.getContext(), decl.getSymName())); + } + } + } else if (const auto *reductionIntrinsic = + std::get_if<omp::clause::ProcedureDesignator>( + &redOperator.u)) { + if (ReductionProcessor::supportedIntrinsicProcReduction( + *reductionIntrinsic)) { + ReductionProcessor::ReductionIdentifier redId = + ReductionProcessor::getReductionType(*reductionIntrinsic); + for (const omp::Object &object : objectList) { + if (const Fortran::semantics::Symbol *symbol = object.sym) { if (reductionSymbols) reductionSymbols->push_back(symbol); mlir::Value symVal = converter.getSymbolAddress(*symbol); @@ -2255,55 +2305,18 @@ class ReductionProcessor { mlir::Type redType = symVal.getType().cast<fir::ReferenceType>().getEleTy(); reductionVars.push_back(symVal); - if (redType.isa<fir::LogicalType>()) - decl = createReductionDecl( - firOpBuilder, - getReductionName(intrinsicOp, firOpBuilder.getI1Type()), - redId, redType, currentLocation); - else if (redType.isIntOrIndexOrFloat()) { - decl = createReductionDecl(firOpBuilder, - getReductionName(intrinsicOp, redType), - redId, redType, currentLocation); - } else { - TODO(currentLocation, "Reduction of some types is not supported"); - } + assert(redType.isIntOrIndexOrFloat() && + "Unsupported reduction type"); + decl = createReductionDecl( + firOpBuilder, + getReductionName(getRealName(*reductionIntrinsic).ToString(), + redType), + redId, redType, currentLocation); reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( firOpBuilder.getContext(), decl.getSymName())); } } } - } else if (const auto *reductionIntrinsic = - std::get_if<Fortran::parser::ProcedureDesignator>( - &redOperator.u)) { - if (ReductionProcessor::supportedIntrinsicProcReduction( - *reductionIntrinsic)) { - ReductionProcessor::ReductionIdentifier redId = - ReductionProcessor::getReductionType(*reductionIntrinsic); - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - if (reductionSymbols) - reductionSymbols->push_back(symbol); - mlir::Value symVal = converter.getSymbolAddress(*symbol); - if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) - symVal = declOp.getBase(); - mlir::Type redType = - symVal.getType().cast<fir::ReferenceType>().getEleTy(); - reductionVars.push_back(symVal); - assert(redType.isIntOrIndexOrFloat() && - "Unsupported reduction type"); - decl = createReductionDecl( - firOpBuilder, - getReductionName(getRealName(*reductionIntrinsic).ToString(), - redType), - redId, redType, currentLocation); - reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( - firOpBuilder.getContext(), decl.getSymName())); - } - } - } - } } } }; @@ -2365,7 +2378,7 @@ getSimdModifier(const omp::clause::Schedule &clause) { static void genAllocateClause(Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpAllocateClause &ompAllocateClause, + const omp::clause::Allocate &clause, llvm::SmallVectorImpl<mlir::Value> &allocatorOperands, llvm::SmallVectorImpl<mlir::Value> &allocateOperands) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); @@ -2373,21 +2386,18 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter, Fortran::lower::StatementContext stmtCtx; mlir::Value allocatorOperand; - const Fortran::parser::OmpObjectList &ompObjectList = - std::get<Fortran::parser::OmpObjectList>(ompAllocateClause.t); - const auto &allocateModifier = std::get< - std::optional<Fortran::parser::OmpAllocateClause::AllocateModifier>>( - ompAllocateClause.t); + const omp::ObjectList &objectList = std::get<omp::ObjectList>(clause.t); + const auto &modifier = + std::get<std::optional<omp::clause::Allocate::Modifier>>(clause.t); // If the allocate modifier is present, check if we only use the allocator // submodifier. ALIGN in this context is unimplemented const bool onlyAllocator = - allocateModifier && - std::holds_alternative< - Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>( - allocateModifier->u); + modifier && + std::holds_alternative<omp::clause::Allocate::Modifier::Allocator>( + modifier->u); - if (allocateModifier && !onlyAllocator) { + if (modifier && !onlyAllocator) { TODO(currentLocation, "OmpAllocateClause ALIGN modifier"); } @@ -2395,20 +2405,17 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter, // to list of allocators, otherwise, add default allocator to // list of allocators. if (onlyAllocator) { - const auto &allocatorValue = std::get< - Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>( - allocateModifier->u); - allocatorOperand = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(allocatorValue.v), stmtCtx)); - allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), - allocatorOperand); + const auto &value = + std::get<omp::clause::Allocate::Modifier::Allocator>(modifier->u); + mlir::Value operand = + fir::getBase(converter.genExprValue(value.v, stmtCtx)); + allocatorOperands.append(objectList.size(), operand); } else { - allocatorOperand = firOpBuilder.createIntegerConstant( + mlir::Value operand = firOpBuilder.createIntegerConstant( currentLocation, firOpBuilder.getI32Type(), 1); - allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), - allocatorOperand); + allocatorOperands.append(objectList.size(), operand); } - genObjectList(ompObjectList, converter, allocateOperands); + genObjectList(objectList, converter, allocateOperands); } static mlir::omp::ClauseProcBindKindAttr @@ -2435,20 +2442,17 @@ genProcBindKindAttr(fir::FirOpBuilder &firOpBuilder, static mlir::omp::ClauseTaskDependAttr genDependKindAttr(fir::FirOpBuilder &firOpBuilder, - const Fortran::parser::OmpClause::Depend *dependClause) { + const omp::clause::Depend &clause) { mlir::omp::ClauseTaskDepend pbKind; - switch ( - std::get<Fortran::parser::OmpDependenceType>( - std::get<Fortran::parser::OmpDependClause::InOut>(dependClause->v.u) - .t) - .v) { - case Fortran::parser::OmpDependenceType::Type::In: + const auto &inOut = std::get<omp::clause::Depend::InOut>(clause.u); + switch (std::get<omp::clause::Depend::Type>(inOut.t)) { + case omp::clause::Depend::Type::In: pbKind = mlir::omp::ClauseTaskDepend::taskdependin; break; - case Fortran::parser::OmpDependenceType::Type::Out: + case omp::clause::Depend::Type::Out: pbKind = mlir::omp::ClauseTaskDepend::taskdependout; break; - case Fortran::parser::OmpDependenceType::Type::Inout: + case omp::clause::Depend::Type::Inout: pbKind = mlir::omp::ClauseTaskDepend::taskdependinout; break; default: @@ -2459,45 +2463,41 @@ genDependKindAttr(fir::FirOpBuilder &firOpBuilder, pbKind); } -static mlir::Value getIfClauseOperand( - Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpClause::If *ifClause, - Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName, - mlir::Location clauseLocation) { +static mlir::Value +getIfClauseOperand(Fortran::lower::AbstractConverter &converter, + const omp::clause::If &clause, + omp::clause::If::DirectiveNameModifier directiveName, + mlir::Location clauseLocation) { // Only consider the clause if it's intended for the given directive. - auto &directive = std::get< - std::optional<Fortran::parser::OmpIfClause::DirectiveNameModifier>>( - ifClause->v.t); + auto &directive = + std::get<std::optional<omp::clause::If::DirectiveNameModifier>>(clause.t); if (directive && directive.value() != directiveName) return nullptr; Fortran::lower::StatementContext stmtCtx; fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t); mlir::Value ifVal = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)); + converter.genExprValue(std::get<omp::SomeExpr>(clause.t), stmtCtx)); return firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(), ifVal); } static void addUseDeviceClause(Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpObjectList &useDeviceClause, + const omp::ObjectList &objects, llvm::SmallVectorImpl<mlir::Value> &operands, llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes, llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols) { - genObjectList(useDeviceClause, converter, operands); + genObjectList(objects, converter, operands); for (mlir::Value &operand : operands) { checkMapType(operand.getLoc(), operand.getType()); useDeviceTypes.push_back(operand.getType()); useDeviceLocs.push_back(operand.getLoc()); } - for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); - useDeviceSymbols.push_back(sym); - } + for (const omp::Object &object : objects) + useDeviceSymbols.push_back(object.sym); } //===----------------------------------------------------------------------===// @@ -2806,10 +2806,10 @@ bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const { bool ClauseProcessor::processAllocate( llvm::SmallVectorImpl<mlir::Value> &allocatorOperands, llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const { - return findRepeatableClause<ClauseTy::Allocate>( - [&](const ClauseTy::Allocate *allocateClause, + return findRepeatableClause<omp::clause::Allocate>( + [&](const omp::clause::Allocate &clause, const Fortran::parser::CharBlock &) { - genAllocateClause(converter, allocateClause->v, allocatorOperands, + genAllocateClause(converter, clause, allocatorOperands, allocateOperands); }); } @@ -2826,12 +2826,12 @@ bool ClauseProcessor::processCopyin() const { if (converter.isPresentShallowLookup(*sym)) converter.copyHostAssociateVar(*sym, copyAssignIP); }; - bool hasCopyin = findRepeatableClause<ClauseTy::Copyin>( - [&](const ClauseTy::Copyin *copyinClause, + bool hasCopyin = findRepeatableClause<omp::clause::Copyin>( + [&](const omp::clause::Copyin &clause, const Fortran::parser::CharBlock &) { - const Fortran::parser::OmpObjectList &ompObjectList = copyinClause->v; - for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); + for (const omp::Object &object : clause.v) { + Fortran::semantics::Symbol *sym = object.sym; + assert(sym && "Expecting symbol"); if (const auto *commonDetails = sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) { for (const auto &mem : commonDetails->objects()) @@ -2864,38 +2864,30 @@ bool ClauseProcessor::processDepend( llvm::SmallVectorImpl<mlir::Value> &dependOperands) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - return findRepeatableClause<ClauseTy::Depend>( - [&](const ClauseTy::Depend *dependClause, + return findRepeatableClause<omp::clause::Depend>( + [&](const omp::clause::Depend &clause, const Fortran::parser::CharBlock &) { - const std::list<Fortran::parser::Designator> &depVal = - std::get<std::list<Fortran::parser::Designator>>( - std::get<Fortran::parser::OmpDependClause::InOut>( - dependClause->v.u) - .t); + assert(std::holds_alternative<omp::clause::Depend::InOut>(clause.u) && + "Only InOut is handled at the moment"); + const auto &inOut = std::get<omp::clause::Depend::InOut>(clause.u); + const auto &objects = std::get<omp::ObjectList>(inOut.t); + mlir::omp::ClauseTaskDependAttr dependTypeOperand = - genDependKindAttr(firOpBuilder, dependClause); - dependTypeOperands.insert(dependTypeOperands.end(), depVal.size(), - dependTypeOperand); - for (const Fortran::parser::Designator &ompObject : depVal) { - Fortran::semantics::Symbol *sym = nullptr; - std::visit( - Fortran::common::visitors{ - [&](const Fortran::parser::DataRef &designator) { - if (const Fortran::parser::Name *name = - std::get_if<Fortran::parser::Name>(&designator.u)) { - sym = name->symbol; - } else if (std::get_if<Fortran::common::Indirection< - Fortran::parser::ArrayElement>>( - &designator.u)) { - TODO(converter.getCurrentLocation(), - "array sections not supported for task depend"); - } - }, - [&](const Fortran::parser::Substring &designator) { - TODO(converter.getCurrentLocation(), - "substring not supported for task depend"); - }}, - (ompObject).u); + genDependKindAttr(firOpBuilder, clause); + dependTypeOperands.append(objects.size(), dependTypeOperand); + + for (const omp::Object &object : objects) { + assert(object.dsg && "Expecting designator"); + + if (Fortran::evaluate::ExtractSubstring(*object.dsg)) { + TODO(converter.getCurrentLocation(), + "substring not supported for task depend"); + } else if (Fortran::evaluate::IsArrayElement(*object.dsg)) { + TODO(converter.getCurrentLocation(), + "array sections not supported for task depend"); + } + + Fortran::semantics::Symbol *sym = object.sym; const mlir::Value variable = converter.getSymbolAddress(*sym); dependOperands.push_back(variable); } @@ -2903,14 +2895,14 @@ bool ClauseProcessor::processDepend( } bool ClauseProcessor::processIf( - Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName, + omp::clause::If::DirectiveNameModifier directiveName, mlir::Value &result) const { bool found = false; - findRepeatableClause<ClauseTy::If>( - [&](const ClauseTy::If *ifClause, + findRepeatableClause<omp::clause::If>( + [&](const omp::clause::If &clause, const Fortran::parser::CharBlock &source) { mlir::Location clauseLocation = converter.genLocation(source); - mlir::Value operand = getIfClauseOperand(converter, ifClause, + mlir::Value operand = getIfClauseOperand(converter, clause, directiveName, clauseLocation); // Assume that, at most, a single 'if' clause will be applicable to the // given directive. @@ -2924,12 +2916,11 @@ bool ClauseProcessor::processIf( bool ClauseProcessor::processLink( llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const { - return findRepeatableClause<ClauseTy::Link>( - [&](const ClauseTy::Link *linkClause, - const Fortran::parser::CharBlock &) { + return findRepeatableClause<omp::clause::Link>( + [&](const omp::clause::Link &clause, const Fortran::parser::CharBlock &) { // Case: declare target link(var1, var2)... gatherFuncAndVarSyms( - linkClause->v, mlir::omp::DeclareTargetCaptureClause::link, result); + clause.v, mlir::omp::DeclareTargetCaptureClause::link, result); }); } @@ -2966,7 +2957,7 @@ bool ClauseProcessor::processMap( llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - return findRepeatableClause<ClauseTy::Map>( + return findRepeatableClause2<ClauseTy::Map>( [&](const ClauseTy::Map *mapClause, const Fortran::parser::CharBlock &source) { mlir::Location clauseLocation = converter.genLocation(source); @@ -3058,43 +3049,41 @@ bool ClauseProcessor::processReduction( llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSymbols) const { - return findRepeatableClause<ClauseTy::Reduction>( - [&](const ClauseTy::Reduction *reductionClause, + return findRepeatableClause<omp::clause::Reduction>( + [&](const omp::clause::Reduction &clause, const Fortran::parser::CharBlock &) { ReductionProcessor rp; - rp.addReductionDecl(currentLocation, converter, reductionClause->v, - reductionVars, reductionDeclSymbols, - reductionSymbols); + rp.addReductionDecl(currentLocation, converter, clause, reductionVars, + reductionDeclSymbols, reductionSymbols); }); } bool ClauseProcessor::processSectionsReduction( mlir::Location currentLocation) const { - return findRepeatableClause<ClauseTy::Reduction>( - [&](const ClauseTy::Reduction *, const Fortran::parser::CharBlock &) { + return findRepeatableClause<omp::clause::Reduction>( + [&](const omp::clause::Reduction &, const Fortran::parser::CharBlock &) { TODO(currentLocation, "OMPC_Reduction"); }); } bool ClauseProcessor::processTo( llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const { - return findRepeatableClause<ClauseTy::To>( - [&](const ClauseTy::To *toClause, const Fortran::parser::CharBlock &) { + return findRepeatableClause<omp::clause::To>( + [&](const omp::clause::To &clause, const Fortran::parser::CharBlock &) { // Case: declare target to(func, var1, var2)... - gatherFuncAndVarSyms(toClause->v, + gatherFuncAndVarSyms(clause.v, mlir::omp::DeclareTargetCaptureClause::to, result); }); } bool ClauseProcessor::processEnter( llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const { - return findRepeatableClause<ClauseTy::Enter>( - [&](const ClauseTy::Enter *enterClause, + return findRepeatableClause<omp::clause::Enter>( + [&](const omp::clause::Enter &clause, const Fortran::parser::CharBlock &) { // Case: declare target enter(func, var1, var2)... - gatherFuncAndVarSyms(enterClause->v, - mlir::omp::DeclareTargetCaptureClause::enter, - result); + gatherFuncAndVarSyms( + clause.v, mlir::omp::DeclareTargetCaptureClause::enter, result); }); } @@ -3104,11 +3093,11 @@ bool ClauseProcessor::processUseDeviceAddr( llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols) const { - return findRepeatableClause<ClauseTy::UseDeviceAddr>( - [&](const ClauseTy::UseDeviceAddr *devAddrClause, + return findRepeatableClause<omp::clause::UseDeviceAddr>( + [&](const omp::clause::UseDeviceAddr &clause, const Fortran::parser::CharBlock &) { - addUseDeviceClause(converter, devAddrClause->v, operands, - useDeviceTypes, useDeviceLocs, useDeviceSymbols); + addUseDeviceClause(converter, clause.v, operands, useDeviceTypes, + useDeviceLocs, useDeviceSymbols); }); } @@ -3118,10 +3107,10 @@ bool ClauseProcessor::processUseDevicePtr( llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols) const { - return findRepeatableClause<ClauseTy::UseDevicePtr>( - [&](const ClauseTy::UseDevicePtr *devPtrClause, + return findRepeatableClause<omp::clause::UseDevicePtr>( + [&](const omp::clause::UseDevicePtr &clause, const Fortran::parser::CharBlock &) { - addUseDeviceClause(converter, devPtrClause->v, operands, useDeviceTypes, + addUseDeviceClause(converter, clause.v, operands, useDeviceTypes, useDeviceLocs, useDeviceSymbols); }); } @@ -3130,7 +3119,7 @@ template <typename T> bool ClauseProcessor::processMotionClauses( Fortran::lower::StatementContext &stmtCtx, llvm::SmallVectorImpl<mlir::Value> &mapOperands) { - return findRepeatableClause<T>( + return findRepeatableClause2<T>( [&](const T *motionClause, const Fortran::parser::CharBlock &source) { mlir::Location clauseLocation = converter.genLocation(source); fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); @@ -3700,7 +3689,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel, + cp.processIf(omp::clause::If::DirectiveNameModifier::Parallel, ifClauseOperand); cp.processNumThreads(stmtCtx, numThreadsClauseOperand); cp.processProcBind(procBindKindAttr); @@ -3796,8 +3785,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter, dependOperands; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Task, - ifClauseOperand); + cp.processIf(omp::clause::If::DirectiveNameModifier::Task, ifClauseOperand); cp.processAllocate(allocatorOperands, allocateOperands); cp.processDefault(); cp.processFinal(stmtCtx, finalClauseOperand); @@ -3858,7 +3846,7 @@ genDataOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData, + cp.processIf(omp::clause::If::DirectiveNameModifier::TargetData, ifClauseOperand); cp.processDevice(stmtCtx, deviceOperand); cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs, @@ -3889,19 +3877,16 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter, mlir::UnitAttr nowaitAttr; llvm::SmallVector<mlir::Value> mapOperands; - Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName; + omp::clause::If::DirectiveNameModifier directiveName; llvm::omp::Directive directive; if constexpr (std::is_same_v<OpTy, mlir::omp::EnterDataOp>) { - directiveName = - Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetEnterData; + directiveName = omp::clause::If::DirectiveNameModifier::TargetEnterData; directive = llvm::omp::Directive::OMPD_target_enter_data; } else if constexpr (std::is_same_v<OpTy, mlir::omp::ExitDataOp>) { - directiveName = - Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetExitData; + directiveName = omp::clause::If::DirectiveNameModifier::TargetExitData; directive = llvm::omp::Directive::OMPD_target_exit_data; } else if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) { - directiveName = - Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetUpdate; + directiveName = omp::clause::If::DirectiveNameModifier::TargetUpdate; directive = llvm::omp::Directive::OMPD_target_update; } else { return nullptr; @@ -4100,8 +4085,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Target, - ifClauseOperand); + cp.processIf(omp::clause::If::DirectiveNameModifier::Target, ifClauseOperand); cp.processDevice(stmtCtx, deviceOperand); cp.processThreadLimit(stmtCtx, threadLimitOperand); cp.processNowait(nowaitAttr); @@ -4214,8 +4198,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector<mlir::Attribute> reductionDeclSymbols; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams, - ifClauseOperand); + cp.processIf(omp::clause::If::DirectiveNameModifier::Teams, ifClauseOperand); cp.processAllocate(allocatorOperands, allocateOperands); cp.processDefault(); cp.processNumTeams(stmtCtx, numTeamsClauseOperand); @@ -4254,8 +4237,9 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo( if (const auto *objectList{ Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) { + omp::ObjectList objects{omp::makeList(*objectList, semaCtx)}; // Case: declare target(func, var1, var2) - gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to, + gatherFuncAndVarSyms(objects, mlir::omp::DeclareTargetCaptureClause::to, symbolAndClause); } else if (const auto *clauseList{ Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>( @@ -4369,7 +4353,7 @@ genOmpFlush(Fortran::lower::AbstractConverter &converter, if (const auto &ompObjectList = std::get<std::optional<Fortran::parser::OmpObjectList>>( flushConstruct.t)) - genObjectList(*ompObjectList, converter, operandRange); + genObjectList2(*ompObjectList, converter, operandRange); const auto &memOrderClause = std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>( flushConstruct.t); @@ -4479,8 +4463,7 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter, loopVarTypeSize); cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand); cp.processReduction(loc, reductionVars, reductionDeclSymbols); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Simd, - ifClauseOperand); + cp.processIf(omp::clause::If::DirectiveNameModifier::Simd, ifClauseOperand); cp.processSimdlen(simdlenClauseOperand); cp.processSafelen(safelenClauseOperand); cp.processTODO<Fortran::parser::OmpClause::Aligned, @@ -5274,106 +5257,101 @@ void Fortran::lower::genOpenMPReduction( const Fortran::parser::OmpClauseList &clauseList) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - for (const Fortran::parser::OmpClause &clause : clauseList.v) { + omp::List<omp::Clause> clauses{omp::makeList(clauseList, semaCtx)}; + + for (const omp::Clause &clause : clauses) { if (const auto &reductionClause = - std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) { - const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>( - reductionClause->v.t)}; - const auto &objectList{ - std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)}; + std::get_if<omp::clause::Reduction>(&clause.u)) { + const auto &redOperator{ + std::get<omp::clause::ReductionOperator>(reductionClause->t)}; + const auto &objectList{std::get<omp::ObjectList>(reductionClause->t)}; if (const auto *reductionOp = - std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) { + std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) { const auto &intrinsicOp{ - std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>( + std::get<omp::clause::DefinedOperator::IntrinsicOperator>( reductionOp->u)}; switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + case omp::clause::DefinedOperator::IntrinsicOperator::Add: + case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: + case omp::clause::DefinedOperator::IntrinsicOperator::AND: + case omp::clause::DefinedOperator::IntrinsicOperator::EQV: + case omp::clause::DefinedOperator::IntrinsicOperator::OR: + case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: break; default: continue; } - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - mlir::Value reductionVal = converter.getSymbolAddress(*symbol); - if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>()) - reductionVal = declOp.getBase(); - mlir::Type reductionType = - reductionVal.getType().cast<fir::ReferenceType>().getEleTy(); - if (!reductionType.isa<fir::LogicalType>()) { - if (!reductionType.isIntOrIndexOrFloat()) - continue; - } - for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) { - if (auto loadOp = mlir::dyn_cast<fir::LoadOp>( - reductionValUse.getOwner())) { - mlir::Value loadVal = loadOp.getRes(); - if (reductionType.isa<fir::LogicalType>()) { - mlir::Operation *reductionOp = findReductionChain(loadVal); - fir::ConvertOp convertOp = - getConvertFromReductionOp(reductionOp, loadVal); - updateReduction(reductionOp, firOpBuilder, loadVal, - reductionVal, &convertOp); - removeStoreOp(reductionOp, reductionVal); - } else if (mlir::Operation *reductionOp = - findReductionChain(loadVal, &reductionVal)) { - updateReduction(reductionOp, firOpBuilder, loadVal, - reductionVal); - } + for (const omp::Object &object : objectList) { + if (const Fortran::semantics::Symbol *symbol = object.sym) { + mlir::Value reductionVal = converter.getSymbolAddress(*symbol); + if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>()) + reductionVal = declOp.getBase(); + mlir::Type reductionType = + reductionVal.getType().cast<fir::ReferenceType>().getEleTy(); + if (!reductionType.isa<fir::LogicalType>()) { + if (!reductionType.isIntOrIndexOrFloat()) + continue; + } + for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) { + if (auto loadOp = + mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) { + mlir::Value loadVal = loadOp.getRes(); + if (reductionType.isa<fir::LogicalType>()) { + mlir::Operation *reductionOp = findReductionChain(loadVal); + fir::ConvertOp convertOp = + getConvertFromReductionOp(reductionOp, loadVal); + updateReduction(reductionOp, firOpBuilder, loadVal, + reductionVal, &convertOp); + removeStoreOp(reductionOp, reductionVal); + } else if (mlir::Operation *reductionOp = + findReductionChain(loadVal, &reductionVal)) { + updateReduction(reductionOp, firOpBuilder, loadVal, + reductionVal); } } } } } } else if (const auto *reductionIntrinsic = - std::get_if<Fortran::parser::ProcedureDesignator>( + std::get_if<omp::clause::ProcedureDesignator>( &redOperator.u)) { if (!ReductionProcessor::supportedIntrinsicProcReduction( *reductionIntrinsic)) continue; ReductionProcessor::ReductionIdentifier redId = ReductionProcessor::getReductionType(*reductionIntrinsic); - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - mlir::Value reductionVal = converter.getSymbolAddress(*symbol); - if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>()) - reductionVal = declOp.getBase(); - for (const mlir::OpOperand &reductionValUse : - reductionVal.getUses()) { - if (auto loadOp = mlir::dyn_cast<fir::LoadOp>( - reductionValUse.getOwner())) { - mlir::Value loadVal = loadOp.getRes(); - // Max is lowered as a compare -> select. - // Match the pattern here. - mlir::Operation *reductionOp = - findReductionChain(loadVal, &reductionVal); - if (reductionOp == nullptr) - continue; - - if (redId == ReductionProcessor::ReductionIdentifier::MAX || - redId == ReductionProcessor::ReductionIdentifier::MIN) { - assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) && - "Selection Op not found in reduction intrinsic"); - mlir::Operation *compareOp = - getCompareFromReductionOp(reductionOp, loadVal); - updateReduction(compareOp, firOpBuilder, loadVal, - reductionVal); - } - if (redId == ReductionProcessor::ReductionIdentifier::IOR || - redId == ReductionProcessor::ReductionIdentifier::IEOR || - redId == ReductionProcessor::ReductionIdentifier::IAND) { - updateReduction(reductionOp, firOpBuilder, loadVal, - reductionVal); - } + for (const omp::Object &object : objectList) { + if (const Fortran::semantics::Symbol *symbol = object.sym) { + mlir::Value reductionVal = converter.getSymbolAddress(*symbol); + if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>()) + reductionVal = declOp.getBase(); + for (const mlir::OpOperand &reductionValUse : + reductionVal.getUses()) { + if (auto loadOp = + mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) { + mlir::Value loadVal = loadOp.getRes(); + // Max is lowered as a compare -> select. + // Match the pattern here. + mlir::Operation *reductionOp = + findReductionChain(loadVal, &reductionVal); + if (reductionOp == nullptr) + continue; + + if (redId == ReductionProcessor::ReductionIdentifier::MAX || + redId == ReductionProcessor::ReductionIdentifier::MIN) { + assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) && + "Selection Op not found in reduction intrinsic"); + mlir::Operation *compareOp = + getCompareFromReductionOp(reductionOp, loadVal); + updateReduction(compareOp, firOpBuilder, loadVal, + reductionVal); + } + if (redId == ReductionProcessor::ReductionIdentifier::IOR || + redId == ReductionProcessor::ReductionIdentifier::IEOR || + redId == ReductionProcessor::ReductionIdentifier::IAND) { + updateReduction(reductionOp, firOpBuilder, loadVal, + reductionVal); } } } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits