llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-flang-fir-hlfir Author: Krzysztof Parzyszek (kparzysz) <details> <summary>Changes</summary> This improves the separation of the generic Fortran lowering and the lowering of OpenMP constructs. The mixin is intended to be derived from via CRTP: ``` class FirConverter : public OpenMPMixin<FirConverter> ... ``` The primary goal of the mixin is to implement `genFIR` functions that the derived converter can then call via ``` std::visit([this](auto &&s) { genFIR(s); }); ``` The mixin also expects a handful of functions to be present in the derived class, most importantly `genFIR(Evaluation&)`, plus getter functions for the op builder, symbol table, etc. The pre-existing PFT-lowering functionality is preserved. --- Full diff: https://github.com/llvm/llvm-project/pull/74866.diff 5 Files Affected: - (modified) flang/lib/Lower/Bridge.cpp (+2-82) - (added) flang/lib/Lower/ConverterMixin.h (+28) - (modified) flang/lib/Lower/FirConverter.h (+24-18) - (modified) flang/lib/Lower/OpenMP.cpp (+117-1) - (added) flang/lib/Lower/OpenMPMixin.h (+66) ``````````diff diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 885c9307b8caf..061f9f29ffb00 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -170,7 +170,7 @@ void FirConverter::run(Fortran::lower::pft::Program &pft) { }); finalizeOpenACCLowering(); - finalizeOpenMPLowering(globalOmpRequiresSymbol); + OpenMPBase::finalize(globalOmpRequiresSymbol); } /// Generate FIR for Evaluation \p eval. 
@@ -977,70 +977,6 @@ void FirConverter::genFIR(const Fortran::parser::OpenACCRoutineConstruct &acc) { // Handled by genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &) } -void FirConverter::genFIR(const Fortran::parser::OpenMPConstruct &omp) { - mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint(); - localSymbols.pushScope(); - genOpenMPConstruct(*this, bridge.getSemanticsContext(), getEval(), omp); - - const Fortran::parser::OpenMPLoopConstruct *ompLoop = - std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u); - const Fortran::parser::OpenMPBlockConstruct *ompBlock = - std::get_if<Fortran::parser::OpenMPBlockConstruct>(&omp.u); - - // If loop is part of an OpenMP Construct then the OpenMP dialect - // workshare loop operation has already been created. Only the - // body needs to be created here and the do_loop can be skipped. - // Skip the number of collapsed loops, which is 1 when there is a - // no collapse requested. - - Fortran::lower::pft::Evaluation *curEval = &getEval(); - const Fortran::parser::OmpClauseList *loopOpClauseList = nullptr; - if (ompLoop) { - loopOpClauseList = &std::get<Fortran::parser::OmpClauseList>( - std::get<Fortran::parser::OmpBeginLoopDirective>(ompLoop->t).t); - int64_t collapseValue = Fortran::lower::getCollapseValue(*loopOpClauseList); - - curEval = &curEval->getFirstNestedEvaluation(); - for (int64_t i = 1; i < collapseValue; i++) { - curEval = &*std::next(curEval->getNestedEvaluations().begin()); - } - } - - for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations()) - genFIR(e); - - if (ompLoop) { - genOpenMPReduction(*this, *loopOpClauseList); - } else if (ompBlock) { - const auto &blockStart = - std::get<Fortran::parser::OmpBeginBlockDirective>(ompBlock->t); - const auto &blockClauses = - std::get<Fortran::parser::OmpClauseList>(blockStart.t); - genOpenMPReduction(*this, blockClauses); - } - - localSymbols.popScope(); - builder->restoreInsertionPoint(insertPt); - - // Register if a
target region was found - ompDeviceCodeFound = - ompDeviceCodeFound || Fortran::lower::isOpenMPTargetConstruct(omp); -} - -void FirConverter::genFIR( - const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) { - mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint(); - // Register if a declare target construct intended for a target device was - // found - ompDeviceCodeFound = - ompDeviceCodeFound || - Fortran::lower::isOpenMPDeviceDeclareTarget(*this, getEval(), ompDecl); - genOpenMPDeclarativeConstruct(*this, getEval(), ompDecl); - for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations()) - genFIR(e); - builder->restoreInsertionPoint(insertPt); -} - void FirConverter::genFIR(const Fortran::parser::OpenStmt &stmt) { mlir::Value iostat = genOpenStatement(*this, stmt); genIoConditionBranches(getEval(), stmt.v, iostat); @@ -3752,13 +3688,7 @@ void FirConverter::instantiateVar(const Fortran::lower::pft::Variable &var, Fortran::lower::AggregateStoreMap &storeMap) { Fortran::lower::instantiateVariable(*this, var, localSymbols, storeMap); if (var.hasSymbol()) { - if (var.getSymbol().test( - Fortran::semantics::Symbol::Flag::OmpThreadprivate)) - Fortran::lower::genThreadprivateOp(*this, var); - - if (var.getSymbol().test( - Fortran::semantics::Symbol::Flag::OmpDeclareTarget)) - Fortran::lower::genDeclareTargetIntGlobal(*this, var); + OpenMPBase::instantiateVariable(*this, var); } } @@ -4443,16 +4373,6 @@ void FirConverter::finalizeOpenACCLowering() { accRoutineInfos); } -/// Performing OpenMP lowering actions that were deferred to the end of -/// lowering.
-void FirConverter::finalizeOpenMPLowering( - const Fortran::semantics::Symbol *globalOmpRequiresSymbol) { - // Set the module attribute related to OpenMP requires directives - if (ompDeviceCodeFound) - Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(), - globalOmpRequiresSymbol); -} - } // namespace Fortran::lower Fortran::evaluate::FoldingContext diff --git a/flang/lib/Lower/ConverterMixin.h b/flang/lib/Lower/ConverterMixin.h new file mode 100644 index 0000000000000..a873ff36d0f60 --- /dev/null +++ b/flang/lib/Lower/ConverterMixin.h @@ -0,0 +1,28 @@ +//===-- ConverterMixin.h --------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_LOWER_CONVERTERMIXIN_H +#define FORTRAN_LOWER_CONVERTERMIXIN_H + +namespace Fortran::lower { + +template <typename FirConverterT> class ConverterMixinBase { +public: + FirConverterT *This() { return static_cast<FirConverterT *>(this); } + const FirConverterT *This() const { + return static_cast<const FirConverterT *>(this); + } +}; + +} // namespace Fortran::lower + +#endif // FORTRAN_LOWER_CONVERTERMIXIN_H diff --git a/flang/lib/Lower/FirConverter.h b/flang/lib/Lower/FirConverter.h index 51b8bd4fa0b38..0214ea88b1e5b 100644 --- a/flang/lib/Lower/FirConverter.h +++ b/flang/lib/Lower/FirConverter.h @@ -13,6 +13,9 @@ #ifndef FORTRAN_LOWER_FIRCONVERTER_H #define FORTRAN_LOWER_FIRCONVERTER_H +#include "ConverterMixin.h" +#include "OpenMPMixin.h" + #include "flang/Common/Fortran.h" #include "flang/Lower/AbstractConverter.h" #include "flang/Lower/Bridge.h" @@ -74,7
+77,11 @@ namespace Fortran::lower { -class FirConverter : public Fortran::lower::AbstractConverter { +class FirConverter : public Fortran::lower::AbstractConverter, + public OpenMPMixin<FirConverter> { + using OpenMPBase = OpenMPMixin<FirConverter>; + using OpenMPBase::genFIR; + public: explicit FirConverter(Fortran::lower::LoweringBridge &bridge) : Fortran::lower::AbstractConverter(bridge.getLoweringOptions()), @@ -83,6 +90,20 @@ class FirConverter : public Fortran::lower::AbstractConverter { void run(Fortran::lower::pft::Program &pft); +public: + // The interface that the OpenMP mixin expects from the derived converter. + + Fortran::lower::LoweringBridge &getBridge() { return bridge; } + fir::FirOpBuilder &getBuilder() { + assert(builder); + return *builder; + } + Fortran::lower::pft::Evaluation &getEval() { + assert(evalPtr); + return *evalPtr; + } + Fortran::lower::SymMap &getSymTable() { return localSymbols; } + /// The core of the conversion: take an evaluation and generate FIR for it. /// The generation for each individual element of PFT is done via a specific /// genFIR function (see below).
@@ -141,8 +162,6 @@ class FirConverter : public Fortran::lower::AbstractConverter { void genFIR(const Fortran::parser::OpenACCConstruct &); void genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &); void genFIR(const Fortran::parser::OpenACCRoutineConstruct &); - void genFIR(const Fortran::parser::OpenMPConstruct &); - void genFIR(const Fortran::parser::OpenMPDeclarativeConstruct &); void genFIR(const Fortran::parser::OpenStmt &); void genFIR(const Fortran::parser::PauseStmt &); void genFIR(const Fortran::parser::PointerAssignmentStmt &); @@ -194,7 +213,6 @@ class FirConverter : public Fortran::lower::AbstractConverter { void genFIR(const Fortran::parser::IfStmt &) {} // nop void genFIR(const Fortran::parser::IfThenStmt &) {} // nop void genFIR(const Fortran::parser::NonLabelDoStmt &) {} // nop - void genFIR(const Fortran::parser::OmpEndLoopDirective &) {} // nop void genFIR(const Fortran::parser::SelectTypeStmt &) {} // nop void genFIR(const Fortran::parser::TypeGuardStmt &) {} // nop @@ -684,7 +702,6 @@ class FirConverter : public Fortran::lower::AbstractConverter { mlir::Location toLocation(); void setCurrentEval(Fortran::lower::pft::Evaluation &eval); - Fortran::lower::pft::Evaluation &getEval(); std::optional<Fortran::evaluate::Shape> getShape(const Fortran::lower::SomeExpr &expr); @@ -707,8 +724,8 @@ class FirConverter : public Fortran::lower::AbstractConverter { void analyzeExplicitSpace(const Fortran::parser::PointerAssignmentStmt &s); void analyzeExplicitSpace(const Fortran::parser::WhereBodyConstruct &body); void analyzeExplicitSpace(const Fortran::parser::WhereConstruct &c); - void analyzeExplicitSpace( - const Fortran::parser::WhereConstruct::Elsewhere *ew); + void + analyzeExplicitSpace(const Fortran::parser::WhereConstruct::Elsewhere *ew); void analyzeExplicitSpace( const Fortran::parser::WhereConstruct::MaskedElsewhere &ew); void analyzeExplicitSpace(const Fortran::parser::WhereConstructStmt &ws); @@ -727,8 +744,6 @@ class FirConverter :
public Fortran::lower::AbstractConverter { mlir::Type eleTy); void finalizeOpenACCLowering(); - void finalizeOpenMPLowering( - const Fortran::semantics::Symbol *globalOmpRequiresSymbol); //===--------------------------------------------------------------------===// @@ -776,10 +791,6 @@ class FirConverter : public Fortran::lower::AbstractConverter { /// Deferred OpenACC routine attachment. Fortran::lower::AccRoutineInfoMappingList accRoutineInfos; - /// Whether an OpenMP target region or declare target function/subroutine - /// intended for device offloading has been detected - bool ompDeviceCodeFound = false; - const Fortran::lower::ExprToValueMap *exprValueOverrides{nullptr}; }; @@ -1220,11 +1231,6 @@ FirConverter::setCurrentEval(Fortran::lower::pft::Evaluation &eval) { evalPtr = &eval; } -inline Fortran::lower::pft::Evaluation &FirConverter::getEval() { - assert(evalPtr); - return *evalPtr; -} - std::optional<Fortran::evaluate::Shape> inline FirConverter::getShape( const Fortran::lower::SomeExpr &expr) { return Fortran::evaluate::GetShape(foldingContext, expr); diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp index eeba87fcd1511..5ca7be5da26a6 100644 --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -10,12 +10,15 @@ // //===----------------------------------------------------------------------===// -#include "flang/Lower/OpenMP.h" +#include "FirConverter.h" +#include "OpenMPMixin.h" + #include "DirectivesCommon.h" #include "flang/Common/idioms.h" #include "flang/Lower/Bridge.h" #include "flang/Lower/ConvertExpr.h" #include "flang/Lower/ConvertVariable.h" +#include "flang/Lower/OpenMP.h" #include "flang/Lower/PFTBuilder.h" #include "flang/Lower/StatementContext.h" #include "flang/Optimizer/Builder/BoxValue.h" @@ -41,6 +44,25 @@ using DeclareTargetCapturePair = std::pair<mlir::omp::DeclareTargetCaptureClause, Fortran::semantics::Symbol>; +namespace Fortran::lower { + +template <> +Fortran::lower::LoweringBridge
&OpenMPMixin<FirConverter>::getBridge() { + return This()->FirConverter::getBridge(); +} +template <> fir::FirOpBuilder &OpenMPMixin<FirConverter>::getBuilder() { + return This()->FirConverter::getBuilder(); +} +template <> +Fortran::lower::pft::Evaluation &OpenMPMixin<FirConverter>::getEval() { + return This()->FirConverter::getEval(); +} +template <> Fortran::lower::SymMap &OpenMPMixin<FirConverter>::getSymTable() { + return This()->FirConverter::getSymTable(); +} + +} // namespace Fortran::lower + //===----------------------------------------------------------------------===// // Common helper functions //===----------------------------------------------------------------------===// @@ -3860,3 +3882,97 @@ void Fortran::lower::genOpenMPRequires( offloadMod.setRequires(mlirFlags); } } + +namespace Fortran::lower { + +template <> +void OpenMPMixin<FirConverter>::genFIR( + const Fortran::parser::OpenMPConstruct &omp) { + mlir::OpBuilder::InsertPoint insertPt = getBuilder().saveInsertionPoint(); + getSymTable().pushScope(); + genOpenMPConstruct(*This(), getBridge().getSemanticsContext(), getEval(), + omp); + + const Fortran::parser::OpenMPLoopConstruct *ompLoop = + std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u); + const Fortran::parser::OpenMPBlockConstruct *ompBlock = + std::get_if<Fortran::parser::OpenMPBlockConstruct>(&omp.u); + + // If loop is part of an OpenMP Construct then the OpenMP dialect + // workshare loop operation has already been created. Only the + // body needs to be created here and the do_loop can be skipped. + // Skip the number of collapsed loops, which is 1 when there is + // no collapse requested.
+ + Fortran::lower::pft::Evaluation *curEval = &getEval(); + const Fortran::parser::OmpClauseList *loopOpClauseList = nullptr; + if (ompLoop) { + loopOpClauseList = &std::get<Fortran::parser::OmpClauseList>( + std::get<Fortran::parser::OmpBeginLoopDirective>(ompLoop->t).t); + int64_t collapseValue = Fortran::lower::getCollapseValue(*loopOpClauseList); + + curEval = &curEval->getFirstNestedEvaluation(); + for (int64_t i = 1; i < collapseValue; i++) { + curEval = &*std::next(curEval->getNestedEvaluations().begin()); + } + } + + for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations()) + This()->genFIR(e); + + if (ompLoop) { + genOpenMPReduction(*This(), *loopOpClauseList); + } else if (ompBlock) { + const auto &blockStart = + std::get<Fortran::parser::OmpBeginBlockDirective>(ompBlock->t); + const auto &blockClauses = + std::get<Fortran::parser::OmpClauseList>(blockStart.t); + genOpenMPReduction(*This(), blockClauses); + } + + getSymTable().popScope(); + getBuilder().restoreInsertionPoint(insertPt); + + // Register if a target region was found + ompDeviceCodeFound = + ompDeviceCodeFound || Fortran::lower::isOpenMPTargetConstruct(omp); +} + +template <> +void OpenMPMixin<FirConverter>::genFIR( + const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) { + mlir::OpBuilder::InsertPoint insertPt = getBuilder().saveInsertionPoint(); + // Register if a declare target construct intended for a target device was + // found + ompDeviceCodeFound = + ompDeviceCodeFound || + Fortran::lower::isOpenMPDeviceDeclareTarget(*This(), getEval(), ompDecl); + genOpenMPDeclarativeConstruct(*This(), getEval(), ompDecl); + for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations()) + This()->genFIR(e); + getBuilder().restoreInsertionPoint(insertPt); +} + +template <> +void OpenMPMixin<FirConverter>::instantiateVariable( + Fortran::lower::AbstractConverter &converter, + const Fortran::lower::pft::Variable &var) { + assert(var.hasSymbol() && "Expecting
symbol"); + if (var.getSymbol().test(Fortran::semantics::Symbol::Flag::OmpThreadprivate)) + genThreadprivateOp(*This(), var); + + if (var.getSymbol().test(Fortran::semantics::Symbol::Flag::OmpDeclareTarget)) + genDeclareTargetIntGlobal(*This(), var); +} + +template <> +void OpenMPMixin<FirConverter>::finalize( + const Fortran::semantics::Symbol *globalOmpRequiresSymbol) { + // Set the module attribute related to OpenMP requires directives + if (ompDeviceCodeFound) { + genOpenMPRequires(This()->getModuleOp().getOperation(), + globalOmpRequiresSymbol); + } +} + +} // namespace Fortran::lower diff --git a/flang/lib/Lower/OpenMPMixin.h b/flang/lib/Lower/OpenMPMixin.h new file mode 100644 index 0000000000000..7339d9eb4fc61 --- /dev/null +++ b/flang/lib/Lower/OpenMPMixin.h @@ -0,0 +1,66 @@ +//===-- OpenMPMixin.h -----------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_LOWER_OPENMPMIXIN_H +#define FORTRAN_LOWER_OPENMPMIXIN_H + +#include "ConverterMixin.h" +#include "flang/Parser/parse-tree.h" + +namespace fir { +class FirOpBuilder; +} + +namespace Fortran::semantics { +class Symbol; +} + +namespace Fortran::lower { + +class AbstractConverter; +class LoweringBridge; +class SymMap; + +namespace pft { +class Evaluation; +class Variable; +} // namespace pft + +template <typename ConverterT> +class OpenMPMixin : public ConverterMixinBase<ConverterT> { +public: + void genFIR(const Fortran::parser::OpenMPConstruct &); + void genFIR(const Fortran::parser::OpenMPDeclarativeConstruct &); + + void genFIR(const Fortran::parser::OmpEndLoopDirective &) {} // nop + + void instantiateVariable(Fortran::lower::AbstractConverter &converter, + const Fortran::lower::pft::Variable &var); + void finalize(const Fortran::semantics::Symbol *globalOmpRequiresSymbol); + +private: + // Shortcuts to call ConverterT:: functions. They can't be defined here + // because the definition of ConverterT is not available at this point.
+ Fortran::lower::LoweringBridge &getBridge(); + fir::FirOpBuilder &getBuilder(); + Fortran::lower::pft::Evaluation &getEval(); + Fortran::lower::SymMap &getSymTable(); + +private: + /// Whether a target region or declare target function/subroutine + /// intended for device offloading has been detected + bool ompDeviceCodeFound = false; +}; + +} // namespace Fortran::lower + +#endif // FORTRAN_LOWER_OPENMPMIXIN_H `````````` </details> https://github.com/llvm/llvm-project/pull/74866 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits