https://github.com/kparzysz updated https://github.com/llvm/llvm-project/pull/81626
>From 87437159da37749ad395d84a3fc1b729bd9e2480 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek <krzysztof.parzys...@amd.com> Date: Thu, 8 Feb 2024 08:33:40 -0600 Subject: [PATCH] [flang][Lower] Convert OMP Map and related functions to evaluate::Expr The related functions are `gatherDataOperandAddrAndBounds` and `genBoundsOps`. The former is also used in OpenACC, whose callers were updated to pass evaluate::Expr instead of parser objects. The difference in the test case comes from unfolded conversions of index expressions, which are explicitly of type integer(kind=8). Delete the now-unused `findRepeatableClause2` and `findClause2`. Add an `AsGenericExpr` overload that takes a std::optional. `AsGenericExpr` already returns an optional Expr; making it accept an optional value as input as well reduces the number of checks needed when handling the optional values that are common in the evaluator. --- flang/include/flang/Evaluate/tools.h | 8 + flang/lib/Lower/DirectivesCommon.h | 389 ++++++++++++++++----------- flang/lib/Lower/OpenACC.cpp | 54 ++-- flang/lib/Lower/OpenMP.cpp | 105 +++----- 4 files changed, 311 insertions(+), 245 deletions(-) diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h index e9999974944e88..d5713cfe420a2e 100644 --- a/flang/include/flang/Evaluate/tools.h +++ b/flang/include/flang/Evaluate/tools.h @@ -148,6 +148,14 @@ inline Expr<SomeType> AsGenericExpr(Expr<SomeType> &&x) { return std::move(x); } std::optional<Expr<SomeType>> AsGenericExpr(DataRef &&); std::optional<Expr<SomeType>> AsGenericExpr(const Symbol &); +// Propagate std::optional from input to output.
+template <typename A> +std::optional<Expr<SomeType>> AsGenericExpr(std::optional<A> &&x) { + if (!x) + return std::nullopt; + return AsGenericExpr(std::move(*x)); +} + template <typename A> common::IfNoLvalue<Expr<SomeKind<ResultType<A>::category>>, A> AsCategoryExpr( A &&x) { diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h index 8d560db34e05bf..2fa90572bc63eb 100644 --- a/flang/lib/Lower/DirectivesCommon.h +++ b/flang/lib/Lower/DirectivesCommon.h @@ -808,6 +808,75 @@ genBaseBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, return bounds; } +namespace detail { +template <typename T> // +static T &&AsRvalueRef(T &&t) { + return std::move(t); +} +template <typename T> // +static T AsRvalueRef(T &t) { + return t; +} +template <typename T> // +static T AsRvalueRef(const T &t) { + return t; +} + +// Helper class for stripping enclosing parentheses and a conversion that +// preserves type category. This is used for triplet elements, which are +// always of type integer(kind=8). The lower/upper bounds are converted to +// an "index" type, which is 64-bit, so the explicit conversion to kind=8 +// (if present) is not needed. When it's present, though, it causes generated +// names to contain "int(..., kind=8)". 
+struct PeelConvert { + template <Fortran::common::TypeCategory Category, int Kind> + static Fortran::semantics::MaybeExpr visit_with_category( + const Fortran::evaluate::Expr<Fortran::evaluate::Type<Category, Kind>> + &expr) { + return std::visit( + [](auto &&s) { return visit_with_category<Category, Kind>(s); }, + expr.u); + } + template <Fortran::common::TypeCategory Category, int Kind> + static Fortran::semantics::MaybeExpr visit_with_category( + const Fortran::evaluate::Convert<Fortran::evaluate::Type<Category, Kind>, + Category> &expr) { + return AsGenericExpr(AsRvalueRef(expr.left())); + } + template <Fortran::common::TypeCategory Category, int Kind, typename T> + static Fortran::semantics::MaybeExpr visit_with_category(const T &) { + return std::nullopt; // + } + template <Fortran::common::TypeCategory Category, typename T> + static Fortran::semantics::MaybeExpr visit_with_category(const T &) { + return std::nullopt; // + } + + template <Fortran::common::TypeCategory Category> + static Fortran::semantics::MaybeExpr + visit(const Fortran::evaluate::Expr<Fortran::evaluate::SomeKind<Category>> + &expr) { + return std::visit([](auto &&s) { return visit_with_category<Category>(s); }, + expr.u); + } + static Fortran::semantics::MaybeExpr + visit(const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr) { + return std::visit([](auto &&s) { return visit(s); }, expr.u); + } + template <typename T> // + static Fortran::semantics::MaybeExpr visit(const T &) { + return std::nullopt; + } +}; + +static Fortran::semantics::SomeExpr +peelOuterConvert(Fortran::semantics::SomeExpr &expr) { + if (auto peeled = PeelConvert::visit(expr)) + return *peeled; + return expr; +} +} // namespace detail + /// Generate bounds operations for an array section when subscripts are /// provided. 
template <typename BoundsOp, typename BoundsType> @@ -815,7 +884,7 @@ llvm::SmallVector<mlir::Value> genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, Fortran::lower::AbstractConverter &converter, Fortran::lower::StatementContext &stmtCtx, - const std::list<Fortran::parser::SectionSubscript> &subscripts, + const std::vector<Fortran::evaluate::Subscript> &subscripts, std::stringstream &asFortran, fir::ExtendedValue &dataExv, bool dataExvIsAssumedSize, AddrAndBoundsInfo &info, bool treatIndexAsSection = false) { @@ -828,8 +897,7 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); const int dataExvRank = static_cast<int>(dataExv.rank()); for (const auto &subscript : subscripts) { - const auto *triplet{ - std::get_if<Fortran::parser::SubscriptTriplet>(&subscript.u)}; + const auto *triplet{std::get_if<Fortran::evaluate::Triplet>(&subscript.u)}; if (triplet || treatIndexAsSection) { if (dimension != 0) asFortran << ','; @@ -868,13 +936,18 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, strideInBytes = true; } - const Fortran::lower::SomeExpr *lower{nullptr}; + Fortran::semantics::MaybeExpr lower; if (triplet) { - if (const auto &tripletLb{std::get<0>(triplet->t)}) - lower = Fortran::semantics::GetExpr(*tripletLb); + if ((lower = Fortran::evaluate::AsGenericExpr(triplet->lower()))) + lower = detail::peelOuterConvert(*lower); } else { - const auto &index{std::get<Fortran::parser::IntExpr>(subscript.u)}; - lower = Fortran::semantics::GetExpr(index); + // Case of IndirectSubscriptIntegerExpr + using IndirectSubscriptIntegerExpr = + Fortran::evaluate::IndirectSubscriptIntegerExpr; + using SubscriptInteger = Fortran::evaluate::SubscriptInteger; + Fortran::evaluate::Expr<SubscriptInteger> oneInt = + std::get<IndirectSubscriptIntegerExpr>(subscript.u).value(); + lower = Fortran::evaluate::AsGenericExpr(std::move(oneInt)); if (lower->Rank() > 0) { mlir::emitError( loc, "vector 
subscript cannot be used for an array section"); @@ -912,10 +985,12 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, extent = one; } else { asFortran << ':'; - const auto &upper{std::get<1>(triplet->t)}; + Fortran::semantics::MaybeExpr upper = + Fortran::evaluate::AsGenericExpr(triplet->upper()); if (upper) { - uval = Fortran::semantics::GetIntValue(upper); + upper = detail::peelOuterConvert(*upper); + uval = Fortran::evaluate::ToInt64(*upper); if (uval) { if (defaultLb) { ubound = builder.createIntegerConstant(loc, idxTy, *uval - 1); @@ -925,22 +1000,21 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, } asFortran << *uval; } else { - const Fortran::lower::SomeExpr *uexpr = - Fortran::semantics::GetExpr(*upper); mlir::Value ub = - fir::getBase(converter.genExprValue(loc, *uexpr, stmtCtx)); + fir::getBase(converter.genExprValue(loc, *upper, stmtCtx)); ub = builder.createConvert(loc, baseLb.getType(), ub); ubound = builder.create<mlir::arith::SubIOp>(loc, ub, baseLb); - asFortran << uexpr->AsFortran(); + asFortran << upper->AsFortran(); } } if (lower && upper) { if (lval && uval && *uval < *lval) { mlir::emitError(loc, "zero sized array section"); break; - } else if (std::get<2>(triplet->t)) { - const auto &strideExpr{std::get<2>(triplet->t)}; - if (strideExpr) { + } else { + // Stride is mandatory in evaluate::Triplet. Make sure it's 1. 
+ auto val = Fortran::evaluate::ToInt64(triplet->GetStride()); + if (!val || *val != 1) { mlir::emitError(loc, "stride cannot be specified on " "an array section"); break; @@ -993,150 +1067,157 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, return bounds; } -template <typename ObjectType, typename BoundsOp, typename BoundsType> +namespace detail { +template <typename Ref, typename Expr> // +std::optional<Ref> getRef(Expr &&expr) { + if constexpr (std::is_same_v<llvm::remove_cvref_t<Expr>, + Fortran::evaluate::DataRef>) { + if (auto *ref = std::get_if<Ref>(&expr.u)) + return *ref; + return std::nullopt; + } else { + auto maybeRef = Fortran::evaluate::ExtractDataRef(expr); + if (!maybeRef || !std::holds_alternative<Ref>(maybeRef->u)) + return std::nullopt; + return std::get<Ref>(maybeRef->u); + } +} +} // namespace detail + +template <typename BoundsOp, typename BoundsType> AddrAndBoundsInfo gatherDataOperandAddrAndBounds( Fortran::lower::AbstractConverter &converter, fir::FirOpBuilder &builder, - Fortran::semantics::SemanticsContext &semanticsContext, - Fortran::lower::StatementContext &stmtCtx, const ObjectType &object, + semantics::SemanticsContext &semaCtx, + Fortran::lower::StatementContext &stmtCtx, + Fortran::semantics::SymbolRef symbol, + const Fortran::semantics::MaybeExpr &maybeDesignator, mlir::Location operandLocation, std::stringstream &asFortran, llvm::SmallVector<mlir::Value> &bounds, bool treatIndexAsSection = false) { + using namespace Fortran; + AddrAndBoundsInfo info; - std::visit( - Fortran::common::visitors{ - [&](const Fortran::parser::Designator &designator) { - if (auto expr{Fortran::semantics::AnalyzeExpr(semanticsContext, - designator)}) { - if (((*expr).Rank() > 0 || treatIndexAsSection) && - Fortran::parser::Unwrap<Fortran::parser::ArrayElement>( - designator)) { - const auto *arrayElement = - Fortran::parser::Unwrap<Fortran::parser::ArrayElement>( - designator); - const auto *dataRef = - 
std::get_if<Fortran::parser::DataRef>(&designator.u); - fir::ExtendedValue dataExv; - bool dataExvIsAssumedSize = false; - if (Fortran::parser::Unwrap< - Fortran::parser::StructureComponent>( - arrayElement->base)) { - auto exprBase = Fortran::semantics::AnalyzeExpr( - semanticsContext, arrayElement->base); - dataExv = converter.genExprAddr(operandLocation, *exprBase, - stmtCtx); - info.addr = fir::getBase(dataExv); - info.rawInput = info.addr; - asFortran << (*exprBase).AsFortran(); - } else { - const Fortran::parser::Name &name = - Fortran::parser::GetLastName(*dataRef); - dataExvIsAssumedSize = Fortran::semantics::IsAssumedSizeArray( - name.symbol->GetUltimate()); - info = getDataOperandBaseAddr(converter, builder, - *name.symbol, operandLocation); - dataExv = converter.getSymbolExtendedValue(*name.symbol); - asFortran << name.ToString(); - } - - if (!arrayElement->subscripts.empty()) { - asFortran << '('; - bounds = genBoundsOps<BoundsOp, BoundsType>( - builder, operandLocation, converter, stmtCtx, - arrayElement->subscripts, asFortran, dataExv, - dataExvIsAssumedSize, info, treatIndexAsSection); - } - asFortran << ')'; - } else if (auto structComp = Fortran::parser::Unwrap< - Fortran::parser::StructureComponent>(designator)) { - fir::ExtendedValue compExv = - converter.genExprAddr(operandLocation, *expr, stmtCtx); - info.addr = fir::getBase(compExv); - info.rawInput = info.addr; - if (fir::unwrapRefType(info.addr.getType()) - .isa<fir::SequenceType>()) - bounds = genBaseBoundsOps<BoundsOp, BoundsType>( - builder, operandLocation, converter, compExv, - /*isAssumedSize=*/false); - asFortran << (*expr).AsFortran(); - - bool isOptional = Fortran::semantics::IsOptional( - *Fortran::parser::GetLastName(*structComp).symbol); - if (isOptional) - info.isPresent = builder.create<fir::IsPresentOp>( - operandLocation, builder.getI1Type(), info.rawInput); - - if (auto loadOp = mlir::dyn_cast_or_null<fir::LoadOp>( - info.addr.getDefiningOp())) { - if 
(fir::isAllocatableType(loadOp.getType()) || - fir::isPointerType(loadOp.getType())) - info.addr = builder.create<fir::BoxAddrOp>(operandLocation, - info.addr); - info.rawInput = info.addr; - } - - // If the component is an allocatable or pointer the result of - // genExprAddr will be the result of a fir.box_addr operation or - // a fir.box_addr has been inserted just before. - // Retrieve the box so we handle it like other descriptor. - if (auto boxAddrOp = mlir::dyn_cast_or_null<fir::BoxAddrOp>( - info.addr.getDefiningOp())) { - info.addr = boxAddrOp.getVal(); - info.rawInput = info.addr; - bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>( - builder, operandLocation, converter, compExv, info); - } - } else { - if (Fortran::parser::Unwrap<Fortran::parser::ArrayElement>( - designator)) { - // Single array element. - const auto *arrayElement = - Fortran::parser::Unwrap<Fortran::parser::ArrayElement>( - designator); - (void)arrayElement; - fir::ExtendedValue compExv = - converter.genExprAddr(operandLocation, *expr, stmtCtx); - info.addr = fir::getBase(compExv); - info.rawInput = info.addr; - asFortran << (*expr).AsFortran(); - } else if (const auto *dataRef{ - std::get_if<Fortran::parser::DataRef>( - &designator.u)}) { - // Scalar or full array. 
- const Fortran::parser::Name &name = - Fortran::parser::GetLastName(*dataRef); - fir::ExtendedValue dataExv = - converter.getSymbolExtendedValue(*name.symbol); - info = getDataOperandBaseAddr(converter, builder, - *name.symbol, operandLocation); - if (fir::unwrapRefType(info.addr.getType()) - .isa<fir::BaseBoxType>()) { - bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>( - builder, operandLocation, converter, dataExv, info); - } - bool dataExvIsAssumedSize = - Fortran::semantics::IsAssumedSizeArray( - name.symbol->GetUltimate()); - if (fir::unwrapRefType(info.addr.getType()) - .isa<fir::SequenceType>()) - bounds = genBaseBoundsOps<BoundsOp, BoundsType>( - builder, operandLocation, converter, dataExv, - dataExvIsAssumedSize); - asFortran << name.ToString(); - } else { // Unsupported - llvm::report_fatal_error( - "Unsupported type of OpenACC operand"); - } - } - } - }, - [&](const Fortran::parser::Name &name) { - info = getDataOperandBaseAddr(converter, builder, *name.symbol, - operandLocation); - asFortran << name.ToString(); - }}, - object.u); + + if (!maybeDesignator) { + info = getDataOperandBaseAddr(converter, builder, symbol, operandLocation); + asFortran << symbol->name().ToString(); + return info; + } + + semantics::SomeExpr designator = *maybeDesignator; + + if ((designator.Rank() > 0 || treatIndexAsSection) && + IsArrayElement(designator)) { + auto arrayRef = detail::getRef<evaluate::ArrayRef>(designator); + // This shouldn't fail after IsArrayElement(designator). 
+ assert(arrayRef && "Expecting ArrayRef"); + + fir::ExtendedValue dataExv; + bool dataExvIsAssumedSize = false; + + auto toMaybeExpr = [&](auto &&base) { + using BaseType = llvm::remove_cvref_t<decltype(base)>; + evaluate::ExpressionAnalyzer ea{semaCtx}; + + if constexpr (std::is_same_v<evaluate::NamedEntity, BaseType>) { + if (auto *ref = base.UnwrapSymbolRef()) + return ea.Designate(evaluate::DataRef{*ref}); + if (auto *ref = base.UnwrapComponent()) + return ea.Designate(evaluate::DataRef{*ref}); + llvm_unreachable("Unexpected NamedEntity"); + } else { + static_assert(std::is_same_v<semantics::SymbolRef, BaseType>); + return ea.Designate(evaluate::DataRef{base}); + } + }; + + auto arrayBase = toMaybeExpr(arrayRef->base()); + assert(arrayBase); + + if (detail::getRef<evaluate::Component>(*arrayBase)) { + dataExv = converter.genExprAddr(operandLocation, *arrayBase, stmtCtx); + info.addr = fir::getBase(dataExv); + info.rawInput = info.addr; + asFortran << arrayBase->AsFortran(); + } else { + const semantics::Symbol &sym = arrayRef->GetLastSymbol(); + dataExvIsAssumedSize = + Fortran::semantics::IsAssumedSizeArray(sym.GetUltimate()); + info = getDataOperandBaseAddr(converter, builder, sym, operandLocation); + dataExv = converter.getSymbolExtendedValue(sym); + asFortran << sym.name().ToString(); + } + + if (!arrayRef->subscript().empty()) { + asFortran << '('; + bounds = genBoundsOps<BoundsOp, BoundsType>( + builder, operandLocation, converter, stmtCtx, arrayRef->subscript(), + asFortran, dataExv, dataExvIsAssumedSize, info, treatIndexAsSection); + } + asFortran << ')'; + } else if (auto compRef = detail::getRef<evaluate::Component>(designator)) { + fir::ExtendedValue compExv = + converter.genExprAddr(operandLocation, designator, stmtCtx); + info.addr = fir::getBase(compExv); + info.rawInput = info.addr; + if (fir::unwrapRefType(info.addr.getType()).isa<fir::SequenceType>()) + bounds = genBaseBoundsOps<BoundsOp, BoundsType>(builder, operandLocation, + converter, 
compExv, + /*isAssumedSize=*/false); + asFortran << designator.AsFortran(); + + if (semantics::IsOptional(compRef->GetLastSymbol())) { + info.isPresent = builder.create<fir::IsPresentOp>( + operandLocation, builder.getI1Type(), info.rawInput); + } + + if (auto loadOp = + mlir::dyn_cast_or_null<fir::LoadOp>(info.addr.getDefiningOp())) { + if (fir::isAllocatableType(loadOp.getType()) || + fir::isPointerType(loadOp.getType())) + info.addr = builder.create<fir::BoxAddrOp>(operandLocation, info.addr); + info.rawInput = info.addr; + } + + // If the component is an allocatable or pointer the result of + // genExprAddr will be the result of a fir.box_addr operation or + // a fir.box_addr has been inserted just before. + // Retrieve the box so we handle it like other descriptor. + if (auto boxAddrOp = + mlir::dyn_cast_or_null<fir::BoxAddrOp>(info.addr.getDefiningOp())) { + info.addr = boxAddrOp.getVal(); + info.rawInput = info.addr; + bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>( + builder, operandLocation, converter, compExv, info); + } + } else { + if (detail::getRef<evaluate::ArrayRef>(designator)) { + fir::ExtendedValue compExv = + converter.genExprAddr(operandLocation, designator, stmtCtx); + info.addr = fir::getBase(compExv); + info.rawInput = info.addr; + asFortran << designator.AsFortran(); + } else if (auto symRef = detail::getRef<semantics::SymbolRef>(designator)) { + // Scalar or full array. 
+ fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(*symRef); + info = + getDataOperandBaseAddr(converter, builder, *symRef, operandLocation); + if (fir::unwrapRefType(info.addr.getType()).isa<fir::BaseBoxType>()) { + bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>( + builder, operandLocation, converter, dataExv, info); + } + bool dataExvIsAssumedSize = + Fortran::semantics::IsAssumedSizeArray(symRef->get().GetUltimate()); + if (fir::unwrapRefType(info.addr.getType()).isa<fir::SequenceType>()) + bounds = genBaseBoundsOps<BoundsOp, BoundsType>( + builder, operandLocation, converter, dataExv, dataExvIsAssumedSize); + asFortran << symRef->get().name().ToString(); + } else { // Unsupported + llvm::report_fatal_error("Unsupported type of OpenACC operand"); + } + } + return info; } - } // namespace lower } // namespace Fortran diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 6ae270f63f5cf4..a444682306ac20 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -269,6 +269,11 @@ getSymbolFromAccObject(const Fortran::parser::AccObject &accObject) { Fortran::parser::GetLastName(arrayElement->base); return *name.symbol; } + if (const auto *component = + Fortran::parser::Unwrap<Fortran::parser::StructureComponent>( + *designator)) { + return *component->component.symbol; + } } else if (const auto *name = std::get_if<Fortran::parser::Name>(&accObject.u)) { return *name->symbol; @@ -286,17 +291,20 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList, mlir::acc::DataClause dataClause, bool structured, bool implicit, bool setDeclareAttr = false) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext}; for (const auto &accObject : objectList.v) { llvm::SmallVector<mlir::Value> bounds; std::stringstream asFortran; mlir::Location operandLocation = genOperandLocation(converter, accObject); + Fortran::semantics::Symbol &symbol = 
getSymbolFromAccObject(accObject); + Fortran::semantics::MaybeExpr designator = + std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u); Fortran::lower::AddrAndBoundsInfo info = Fortran::lower::gatherDataOperandAddrAndBounds< - Fortran::parser::AccObject, mlir::acc::DataBoundsOp, - mlir::acc::DataBoundsType>(converter, builder, semanticsContext, - stmtCtx, accObject, operandLocation, - asFortran, bounds, - /*treatIndexAsSection=*/true); + mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>( + converter, builder, semanticsContext, stmtCtx, symbol, designator, + operandLocation, asFortran, bounds, + /*treatIndexAsSection=*/true); // If the input value is optional and is not a descriptor, we use the // rawInput directly. @@ -321,16 +329,19 @@ static void genDeclareDataOperandOperations( llvm::SmallVectorImpl<mlir::Value> &dataOperands, mlir::acc::DataClause dataClause, bool structured, bool implicit) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext}; for (const auto &accObject : objectList.v) { llvm::SmallVector<mlir::Value> bounds; std::stringstream asFortran; mlir::Location operandLocation = genOperandLocation(converter, accObject); + Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject); + Fortran::semantics::MaybeExpr designator = + std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u); Fortran::lower::AddrAndBoundsInfo info = Fortran::lower::gatherDataOperandAddrAndBounds< - Fortran::parser::AccObject, mlir::acc::DataBoundsOp, - mlir::acc::DataBoundsType>(converter, builder, semanticsContext, - stmtCtx, accObject, operandLocation, - asFortran, bounds); + mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>( + converter, builder, semanticsContext, stmtCtx, symbol, designator, + operandLocation, asFortran, bounds); EntryOp op = createDataEntryOp<EntryOp>( builder, operandLocation, info.addr, asFortran, bounds, structured, implicit, dataClause, 
info.addr.getType()); @@ -339,8 +350,7 @@ static void genDeclareDataOperandOperations( if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(info.addr.getType()))) { mlir::OpBuilder modBuilder(builder.getModule().getBodyRegion()); modBuilder.setInsertionPointAfter(builder.getFunction()); - std::string prefix = - converter.mangleName(getSymbolFromAccObject(accObject)); + std::string prefix = converter.mangleName(symbol); createDeclareAllocFuncWithArg<EntryOp>( modBuilder, builder, operandLocation, info.addr.getType(), prefix, asFortran, dataClause); @@ -783,16 +793,19 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList, llvm::SmallVectorImpl<mlir::Value> &dataOperands, llvm::SmallVector<mlir::Attribute> &privatizations) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext}; for (const auto &accObject : objectList.v) { llvm::SmallVector<mlir::Value> bounds; std::stringstream asFortran; mlir::Location operandLocation = genOperandLocation(converter, accObject); + Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject); + Fortran::semantics::MaybeExpr designator = + std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u); Fortran::lower::AddrAndBoundsInfo info = Fortran::lower::gatherDataOperandAddrAndBounds< - Fortran::parser::AccObject, mlir::acc::DataBoundsOp, - mlir::acc::DataBoundsType>(converter, builder, semanticsContext, - stmtCtx, accObject, operandLocation, - asFortran, bounds); + mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>( + converter, builder, semanticsContext, stmtCtx, symbol, designator, + operandLocation, asFortran, bounds); RecipeOp recipe; mlir::Type retTy = getTypeFromBounds(bounds, info.addr.getType()); if constexpr (std::is_same_v<RecipeOp, mlir::acc::PrivateRecipeOp>) { @@ -1361,16 +1374,19 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList, const auto &op = 
std::get<Fortran::parser::AccReductionOperator>(objectList.t); mlir::acc::ReductionOperator mlirOp = getReductionOperator(op); + Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext}; for (const auto &accObject : objects.v) { llvm::SmallVector<mlir::Value> bounds; std::stringstream asFortran; mlir::Location operandLocation = genOperandLocation(converter, accObject); + Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject); + Fortran::semantics::MaybeExpr designator = + std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u); Fortran::lower::AddrAndBoundsInfo info = Fortran::lower::gatherDataOperandAddrAndBounds< - Fortran::parser::AccObject, mlir::acc::DataBoundsOp, - mlir::acc::DataBoundsType>(converter, builder, semanticsContext, - stmtCtx, accObject, operandLocation, - asFortran, bounds); + mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>( + converter, builder, semanticsContext, stmtCtx, symbol, designator, + operandLocation, asFortran, bounds); mlir::Type reductionTy = fir::unwrapRefType(info.addr.getType()); if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(reductionTy)) diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp index caae5c0cef9251..4309d69434839f 100644 --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -1789,18 +1789,6 @@ class ClauseProcessor { return end; } - /// Utility to find a clause within a range in the clause list. - template <typename T> - static ClauseIterator2 findClause2(ClauseIterator2 begin, - ClauseIterator2 end) { - for (ClauseIterator2 it = begin; it != end; ++it) { - if (std::get_if<T>(&it->u)) - return it; - } - - return end; - } - /// Return the first instance of the given clause found in the clause list or /// `nullptr` if not present. If more than one instance is expected, use /// `findRepeatableClause` instead. @@ -1836,26 +1824,6 @@ class ClauseProcessor { return found; } - /// Call `callbackFn` for each occurrence of the given clause. 
Return `true` - /// if at least one instance was found. - template <typename T> - bool findRepeatableClause2( - std::function<void(const T *, const Fortran::parser::CharBlock &source)> - callbackFn) const { - bool found = false; - ClauseIterator2 nextIt, endIt = clauses2.v.end(); - for (ClauseIterator2 it = clauses2.v.begin(); it != endIt; it = nextIt) { - nextIt = findClause2<T>(it, endIt); - - if (nextIt != endIt) { - callbackFn(&std::get<T>(nextIt->u), nextIt->source); - found = true; - ++nextIt; - } - } - return found; - } - /// Set the `result` to a new `mlir::UnitAttr` if the clause is present. template <typename T> bool markClauseOccurrence(mlir::UnitAttr &result) const { @@ -2958,65 +2926,61 @@ bool ClauseProcessor::processMap( llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - return findRepeatableClause2<ClauseTy::Map>( - [&](const ClauseTy::Map *mapClause, + return findRepeatableClause<omp::clause::Map>( + [&](const omp::clause::Map &clause, const Fortran::parser::CharBlock &source) { + using Map = omp::clause::Map; mlir::Location clauseLocation = converter.genLocation(source); - const auto &oMapType = - std::get<std::optional<Fortran::parser::OmpMapType>>( - mapClause->v.t); + const auto &oMapType = std::get<std::optional<Map::MapType>>(clause.t); llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; // If the map type is specified, then process it else Tofrom is the // default. 
if (oMapType) { - const Fortran::parser::OmpMapType::Type &mapType = - std::get<Fortran::parser::OmpMapType::Type>(oMapType->t); + const Map::MapType::Type &mapType = + std::get<Map::MapType::Type>(oMapType->t); switch (mapType) { - case Fortran::parser::OmpMapType::Type::To: + case Map::MapType::Type::To: mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; break; - case Fortran::parser::OmpMapType::Type::From: + case Map::MapType::Type::From: mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; break; - case Fortran::parser::OmpMapType::Type::Tofrom: + case Map::MapType::Type::Tofrom: mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; break; - case Fortran::parser::OmpMapType::Type::Alloc: - case Fortran::parser::OmpMapType::Type::Release: + case Map::MapType::Type::Alloc: + case Map::MapType::Type::Release: // alloc and release is the default map_type for the Target Data // Ops, i.e. if no bits for map_type is supplied then alloc/release // is implicitly assumed based on the target directive. Default // value for Target Data and Enter Data is alloc and for Exit Data // it is release. 
break; - case Fortran::parser::OmpMapType::Type::Delete: + case Map::MapType::Type::Delete: mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE; } - if (std::get<std::optional<Fortran::parser::OmpMapType::Always>>( - oMapType->t)) + if (std::get<std::optional<Map::MapType::Always>>(oMapType->t)) mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS; } else { mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; } - for (const Fortran::parser::OmpObject &ompObject : - std::get<Fortran::parser::OmpObjectList>(mapClause->v.t).v) { + for (const omp::Object &object : std::get<omp::ObjectList>(clause.t)) { llvm::SmallVector<mlir::Value> bounds; std::stringstream asFortran; Fortran::lower::AddrAndBoundsInfo info = Fortran::lower::gatherDataOperandAddrAndBounds< - Fortran::parser::OmpObject, mlir::omp::DataBoundsOp, - mlir::omp::DataBoundsType>( - converter, firOpBuilder, semaCtx, stmtCtx, ompObject, - clauseLocation, asFortran, bounds, treatIndexAsSection); + mlir::omp::DataBoundsOp, mlir::omp::DataBoundsType>( + converter, firOpBuilder, semaCtx, stmtCtx, *object.sym, + object.dsg, clauseLocation, asFortran, bounds, + treatIndexAsSection); - auto origSymbol = - converter.getSymbolAddress(*getOmpObjectSymbol(ompObject)); + auto origSymbol = converter.getSymbolAddress(*object.sym); mlir::Value symAddr = info.addr; if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType())) symAddr = origSymbol; @@ -3039,7 +3003,7 @@ bool ClauseProcessor::processMap( mapSymLocs->push_back(symAddr.getLoc()); if (mapSymbols) - mapSymbols->push_back(getOmpObjectSymbol(ompObject)); + mapSymbols->push_back(object.sym); } }); } @@ -3120,32 +3084,31 @@ template <typename T> bool ClauseProcessor::processMotionClauses( Fortran::lower::StatementContext &stmtCtx, llvm::SmallVectorImpl<mlir::Value> &mapOperands) { - return findRepeatableClause2<T>( - [&](const T *motionClause, const 
Fortran::parser::CharBlock &source) { + return findRepeatableClause<T>( + [&](const T &clause, const Fortran::parser::CharBlock &source) { mlir::Location clauseLocation = converter.genLocation(source); fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - static_assert(std::is_same_v<T, ClauseProcessor::ClauseTy::To> || - std::is_same_v<T, ClauseProcessor::ClauseTy::From>); + static_assert(std::is_same_v<T, omp::clause::To> || + std::is_same_v<T, omp::clause::From>); // TODO Support motion modifiers: present, mapper, iterator. constexpr llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = - std::is_same_v<T, ClauseProcessor::ClauseTy::To> + std::is_same_v<T, omp::clause::To> ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; - for (const Fortran::parser::OmpObject &ompObject : motionClause->v.v) { + for (const omp::Object &object : clause.v) { llvm::SmallVector<mlir::Value> bounds; std::stringstream asFortran; Fortran::lower::AddrAndBoundsInfo info = Fortran::lower::gatherDataOperandAddrAndBounds< - Fortran::parser::OmpObject, mlir::omp::DataBoundsOp, - mlir::omp::DataBoundsType>( - converter, firOpBuilder, semaCtx, stmtCtx, ompObject, - clauseLocation, asFortran, bounds, treatIndexAsSection); + mlir::omp::DataBoundsOp, mlir::omp::DataBoundsType>( + converter, firOpBuilder, semaCtx, stmtCtx, *object.sym, + object.dsg, clauseLocation, asFortran, bounds, + treatIndexAsSection); - auto origSymbol = - converter.getSymbolAddress(*getOmpObjectSymbol(ompObject)); + auto origSymbol = converter.getSymbolAddress(*object.sym); mlir::Value symAddr = info.addr; if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType())) symAddr = origSymbol; @@ -3899,10 +3862,8 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter, cp.processNowait(nowaitAttr); if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) { - cp.processMotionClauses<Fortran::parser::OmpClause::To>(stmtCtx, - mapOperands); - 
cp.processMotionClauses<Fortran::parser::OmpClause::From>(stmtCtx, - mapOperands); + cp.processMotionClauses<omp::clause::To>(stmtCtx, mapOperands); + cp.processMotionClauses<omp::clause::From>(stmtCtx, mapOperands); } else { cp.processMap(currentLocation, directive, stmtCtx, mapOperands); _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits