llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-flang-fir-hlfir Author: Krzysztof Parzyszek (kparzysz) <details> <summary>Changes</summary> The semantic analysis of the ATOMIC construct will require additional rewriting (reassociation of certain expressions for user convenience), and that rewriting will be driven by diagnoses made in the semantic checks. While the rewriting of min/max is not required to be done in semantic analysis, moving it there will keep all rewriting for the ATOMIC construct in a single location. --- Full diff: https://github.com/llvm/llvm-project/pull/153038.diff 3 Files Affected: - (modified) flang/include/flang/Semantics/openmp-utils.h (+8) - (modified) flang/lib/Lower/OpenMP/Atomic.cpp (-271) - (modified) flang/lib/Semantics/check-omp-atomic.cpp (+126-1) ``````````diff diff --git a/flang/include/flang/Semantics/openmp-utils.h b/flang/include/flang/Semantics/openmp-utils.h index b8ad9ed17c720..1c54124a5738a 100644 --- a/flang/include/flang/Semantics/openmp-utils.h +++ b/flang/include/flang/Semantics/openmp-utils.h @@ -22,6 +22,8 @@ #include <optional> #include <string> +#include <type_traits> +#include <utility> namespace Fortran::semantics { class SemanticsContext; @@ -29,6 +31,12 @@ class Symbol; // Add this namespace to avoid potential conflicts namespace omp { +template <typename T, typename U = std::remove_const_t<T>> U AsRvalue(T &t) { + return U(t); +} + +template <typename T> T &&AsRvalue(T &&t) { return std::move(t); } + // There is no consistent way to get the source of an ActionStmt, but there // is "source" in Statement<T>. This structure keeps the ActionStmt with the // extracted source for further use. 
diff --git a/flang/lib/Lower/OpenMP/Atomic.cpp b/flang/lib/Lower/OpenMP/Atomic.cpp index ed0bff04ed889..ff82a36951bfa 100644 --- a/flang/lib/Lower/OpenMP/Atomic.cpp +++ b/flang/lib/Lower/OpenMP/Atomic.cpp @@ -43,179 +43,6 @@ namespace omp { using namespace Fortran::lower::omp; } -namespace { -// An example of a type that can be used to get the return value from -// the visitor: -// visitor(type_identity<Xyz>) -> result_type -using SomeArgType = evaluate::Type<common::TypeCategory::Integer, 4>; - -struct GetProc - : public evaluate::Traverse<GetProc, const evaluate::ProcedureDesignator *, - false> { - using Result = const evaluate::ProcedureDesignator *; - using Base = evaluate::Traverse<GetProc, Result, false>; - GetProc() : Base(*this) {} - - using Base::operator(); - - static Result Default() { return nullptr; } - - Result operator()(const evaluate::ProcedureDesignator &p) const { return &p; } - static Result Combine(Result a, Result b) { return a != nullptr ? a : b; } -}; - -struct WithType { - WithType(const evaluate::DynamicType &t) : type(t) { - assert(type.category() != common::TypeCategory::Derived && - "Type cannot be a derived type"); - } - - template <typename VisitorTy> // - auto visit(VisitorTy &&visitor) const - -> std::invoke_result_t<VisitorTy, SomeArgType> { - switch (type.category()) { - case common::TypeCategory::Integer: - switch (type.kind()) { - case 1: - return visitor(llvm::type_identity<evaluate::Type<Integer, 1>>{}); - case 2: - return visitor(llvm::type_identity<evaluate::Type<Integer, 2>>{}); - case 4: - return visitor(llvm::type_identity<evaluate::Type<Integer, 4>>{}); - case 8: - return visitor(llvm::type_identity<evaluate::Type<Integer, 8>>{}); - case 16: - return visitor(llvm::type_identity<evaluate::Type<Integer, 16>>{}); - } - break; - case common::TypeCategory::Unsigned: - switch (type.kind()) { - case 1: - return visitor(llvm::type_identity<evaluate::Type<Unsigned, 1>>{}); - case 2: - return 
visitor(llvm::type_identity<evaluate::Type<Unsigned, 2>>{}); - case 4: - return visitor(llvm::type_identity<evaluate::Type<Unsigned, 4>>{}); - case 8: - return visitor(llvm::type_identity<evaluate::Type<Unsigned, 8>>{}); - case 16: - return visitor(llvm::type_identity<evaluate::Type<Unsigned, 16>>{}); - } - break; - case common::TypeCategory::Real: - switch (type.kind()) { - case 2: - return visitor(llvm::type_identity<evaluate::Type<Real, 2>>{}); - case 3: - return visitor(llvm::type_identity<evaluate::Type<Real, 3>>{}); - case 4: - return visitor(llvm::type_identity<evaluate::Type<Real, 4>>{}); - case 8: - return visitor(llvm::type_identity<evaluate::Type<Real, 8>>{}); - case 10: - return visitor(llvm::type_identity<evaluate::Type<Real, 10>>{}); - case 16: - return visitor(llvm::type_identity<evaluate::Type<Real, 16>>{}); - } - break; - case common::TypeCategory::Complex: - switch (type.kind()) { - case 2: - return visitor(llvm::type_identity<evaluate::Type<Complex, 2>>{}); - case 3: - return visitor(llvm::type_identity<evaluate::Type<Complex, 3>>{}); - case 4: - return visitor(llvm::type_identity<evaluate::Type<Complex, 4>>{}); - case 8: - return visitor(llvm::type_identity<evaluate::Type<Complex, 8>>{}); - case 10: - return visitor(llvm::type_identity<evaluate::Type<Complex, 10>>{}); - case 16: - return visitor(llvm::type_identity<evaluate::Type<Complex, 16>>{}); - } - break; - case common::TypeCategory::Logical: - switch (type.kind()) { - case 1: - return visitor(llvm::type_identity<evaluate::Type<Logical, 1>>{}); - case 2: - return visitor(llvm::type_identity<evaluate::Type<Logical, 2>>{}); - case 4: - return visitor(llvm::type_identity<evaluate::Type<Logical, 4>>{}); - case 8: - return visitor(llvm::type_identity<evaluate::Type<Logical, 8>>{}); - } - break; - case common::TypeCategory::Character: - switch (type.kind()) { - case 1: - return visitor(llvm::type_identity<evaluate::Type<Character, 1>>{}); - case 2: - return 
visitor(llvm::type_identity<evaluate::Type<Character, 2>>{}); - case 4: - return visitor(llvm::type_identity<evaluate::Type<Character, 4>>{}); - } - break; - case common::TypeCategory::Derived: - (void)Derived; - break; - } - llvm_unreachable("Unhandled type"); - } - - const evaluate::DynamicType &type; - -private: - // Shorter names. - static constexpr auto Character = common::TypeCategory::Character; - static constexpr auto Complex = common::TypeCategory::Complex; - static constexpr auto Derived = common::TypeCategory::Derived; - static constexpr auto Integer = common::TypeCategory::Integer; - static constexpr auto Logical = common::TypeCategory::Logical; - static constexpr auto Real = common::TypeCategory::Real; - static constexpr auto Unsigned = common::TypeCategory::Unsigned; -}; - -template <typename T, typename U = std::remove_const_t<T>> -U AsRvalue(T &t) { - U copy{t}; - return std::move(copy); -} - -template <typename T> -T &&AsRvalue(T &&t) { - return std::move(t); -} - -struct ArgumentReplacer - : public evaluate::Traverse<ArgumentReplacer, bool, false> { - using Base = evaluate::Traverse<ArgumentReplacer, bool, false>; - using Result = bool; - - Result Default() const { return false; } - - ArgumentReplacer(evaluate::ActualArguments &&newArgs) - : Base(*this), args_(std::move(newArgs)) {} - - using Base::operator(); - - template <typename T> - Result operator()(const evaluate::FunctionRef<T> &x) { - assert(!done_); - auto &mut = const_cast<evaluate::FunctionRef<T> &>(x); - mut.arguments() = args_; - done_ = true; - return true; - } - - Result Combine(Result &&a, Result &&b) { return a || b; } - -private: - bool done_{false}; - evaluate::ActualArguments &&args_; -}; -} // namespace - [[maybe_unused]] static void dumpAtomicAnalysis(const parser::OpenMPAtomicConstruct::Analysis &analysis) { auto whatStr = [](int k) { @@ -412,85 +239,6 @@ makeMemOrderAttr(lower::AbstractConverter &converter, return nullptr; } -static bool replaceArgs(semantics::SomeExpr 
&expr, - evaluate::ActualArguments &&newArgs) { - return ArgumentReplacer(std::move(newArgs))(expr); -} - -static semantics::SomeExpr makeCall(const evaluate::DynamicType &type, - const evaluate::ProcedureDesignator &proc, - const evaluate::ActualArguments &args) { - return WithType(type).visit([&](auto &&s) -> semantics::SomeExpr { - using Type = typename llvm::remove_cvref_t<decltype(s)>::type; - return evaluate::AsGenericExpr( - evaluate::FunctionRef<Type>(AsRvalue(proc), AsRvalue(args))); - }); -} - -static const evaluate::ProcedureDesignator & -getProcedureDesignator(const semantics::SomeExpr &call) { - const evaluate::ProcedureDesignator *proc = GetProc{}(call); - assert(proc && "Call has no procedure designator"); - return *proc; -} - -static semantics::SomeExpr // -genReducedMinMax(const semantics::SomeExpr &orig, - const semantics::SomeExpr *atomArg, - const std::vector<semantics::SomeExpr> &args) { - // Take a list of arguments to a min/max operation, e.g. [a0, a1, ...] - // One of the a_i's, say a_t, must be atomArg. - // Generate tmp = min/max(a0, a1, ... [except a_t]). Then generate - // call = min/max(a_t, tmp). - // Return "call". - - // The min/max intrinsics have 2 mandatory arguments, the rest is optional. - // Make sure that the "tmp = min/max(...)" doesn't promote an optional - // argument to a non-optional position. This could happen if a_t is at - // position 0 or 1. - if (args.size() <= 2) - return orig; - - evaluate::ActualArguments nonAtoms; - - auto AsActual = [](const semantics::SomeExpr &x) { - semantics::SomeExpr copy = x; - return evaluate::ActualArgument(std::move(copy)); - }; - // Semantic checks guarantee that the "atom" shows exactly once in the - // argument list (with potential conversions around it). - // For the first two (non-optional) arguments, if "atom" is among them, - // replace it with another occurrence of the other non-optional argument. - if (atomArg == &args[0]) { - // (atom, x, y...) -> (x, x, y...) 
- nonAtoms.push_back(AsActual(args[1])); - nonAtoms.push_back(AsActual(args[1])); - } else if (atomArg == &args[1]) { - // (x, atom, y...) -> (x, x, y...) - nonAtoms.push_back(AsActual(args[0])); - nonAtoms.push_back(AsActual(args[0])); - } else { - // (x, y, z...) -> unchanged - nonAtoms.push_back(AsActual(args[0])); - nonAtoms.push_back(AsActual(args[1])); - } - - // The rest of arguments are optional, so we can just skip "atom". - for (size_t i = 2, e = args.size(); i != e; ++i) { - if (atomArg != &args[i]) - nonAtoms.push_back(AsActual(args[i])); - } - - // The type of the intermediate min/max is the same as the type of its - // arguments, which may be different from the type of the original - // expression. The original expression may have additional coverts. - auto tmp = - makeCall(*atomArg->GetType(), getProcedureDesignator(orig), nonAtoms); - semantics::SomeExpr call = orig; - replaceArgs(call, {AsActual(*atomArg), AsActual(tmp)}); - return call; -} - static mlir::Operation * // genAtomicRead(lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, mlir::Location loc, @@ -610,25 +358,6 @@ genAtomicUpdate(lower::AbstractConverter &converter, auto [opcode, args] = evaluate::GetTopLevelOperationIgnoreResizing(input); assert(!args.empty() && "Update operation without arguments"); - // Pass args as an argument to avoid capturing a structured binding. - const semantics::SomeExpr *atomArg = [&](auto &args) { - for (const semantics::SomeExpr &e : args) { - if (evaluate::IsSameOrConvertOf(e, atom)) - return &e; - } - llvm_unreachable("Atomic variable not in argument list"); - }(args); - - if (opcode == evaluate::operation::Operator::Min || - opcode == evaluate::operation::Operator::Max) { - // Min and max operations are expanded inline, so reduce them to - // operations with exactly two (non-optional) arguments. 
- rhs = genReducedMinMax(rhs, atomArg, args); - input = *evaluate::GetConvertInput(rhs); - std::tie(opcode, args) = - evaluate::GetTopLevelOperationIgnoreResizing(input); - atomArg = nullptr; // No longer valid. - } for (auto &arg : args) { if (!evaluate::IsSameOrConvertOf(arg, atom)) { mlir::Value val = fir::getBase(converter.genExprValue(arg, naCtx, &loc)); diff --git a/flang/lib/Semantics/check-omp-atomic.cpp b/flang/lib/Semantics/check-omp-atomic.cpp index a5fe820b1069b..0c0e6158485e9 100644 --- a/flang/lib/Semantics/check-omp-atomic.cpp +++ b/flang/lib/Semantics/check-omp-atomic.cpp @@ -14,6 +14,7 @@ #include "flang/Common/indirection.h" #include "flang/Evaluate/expression.h" +#include "flang/Evaluate/rewrite.h" #include "flang/Evaluate/tools.h" #include "flang/Parser/char-block.h" #include "flang/Parser/parse-tree.h" @@ -42,6 +43,8 @@ using namespace Fortran::semantics::omp; namespace operation = Fortran::evaluate::operation; +static MaybeExpr PostSemaRewrite(const SomeExpr &atom, const SomeExpr &expr); + template <typename T, typename U> static bool operator!=(const evaluate::Expr<T> &e, const evaluate::Expr<U> &f) { return !(e == f); @@ -284,7 +287,15 @@ struct AtomicAnalysis { AtomicAnalysis &addOp(Op &op, int what, const std::optional<evaluate::Assignment> &maybeAssign) { op.what = what; - op.assign = maybeAssign; + if (maybeAssign) { + if (MaybeExpr rewritten{PostSemaRewrite(atom_, maybeAssign->rhs)}) { + op.assign = evaluate::Assignment( + AsRvalue(maybeAssign->lhs), std::move(*rewritten)); + op.assign->u = std::move(maybeAssign->u); + } else { + op.assign = *maybeAssign; + } + } return *this; } @@ -1293,4 +1304,118 @@ void OmpStructureChecker::Leave(const parser::OpenMPAtomicConstruct &) { dirContext_.pop_back(); } +// Rewrite min/max: +// Min and max intrinsics in Fortran take an arbitrary number of arguments +// (two or more). The first two are mandatory, the rest is optional. 
That +// means that arguments beyond the first two may be optional dummy arguments +// from the caller. In that case, a reference to such an argument will +// cause a presence test to be emitted, which cannot go inside of the atomic +// operation. Since the atom operand must be present, rewrite the min/max +// operation in a way that avoids the presence tests in the atomic code. +// For example, in +// subroutine f(atom, x, y, z) +// integer :: atom, x +// integer, optional :: y, z +// !$omp atomic update +// atom = min(atom, x, y, z) +// end +// the min operation will become +// atom = min(atom, min(x, y, z)) +// and in the final code +// // Presence check is fine here. +// tmp = min(x, y, z) +// atomic update { +// // Both operands are mandatory, no presence check needed. +// atom = min(atom, tmp) +// } +struct MinMaxRewriter : public evaluate::rewrite::Identity { + using Id = evaluate::rewrite::Identity; + using Id::operator(); + + MinMaxRewriter(const SomeExpr &atom) : atom_(atom) {} + + static bool IsMinMax(const evaluate::ProcedureDesignator &p) { + if (auto *intrin{p.GetSpecificIntrinsic()}) { + return intrin->name == "min" || intrin->name == "max"; + } + return false; + } + + // Take a list of arguments to a min/max operation, e.g. [a0, a1, ...] + // One of the a_i's, say a_t, must be the atom. + // Generate + // min/max(a_t, min/max(a0, a1, ... [except a_t])) + template <typename T> + evaluate::Expr<T> operator()( + evaluate::Expr<T> &&x, const evaluate::FunctionRef<T> &f) { + const evaluate::ProcedureDesignator &proc = f.proc(); + if (!IsMinMax(proc) || f.arguments().size() <= 2) { + return Id::operator()(std::move(x), f); + } + + // Collect arguments as SomeExpr's and find out which argument + // corresponds to atom.
+ const SomeExpr *atomArg{nullptr}; + std::vector<const SomeExpr *> args; + for (const std::optional<evaluate::ActualArgument> &a : f.arguments()) { + if (!a) { + continue; + } + if (const SomeExpr *e{a->UnwrapExpr()}) { + if (evaluate::IsSameOrConvertOf(*e, atom_)) { + atomArg = e; + } + args.push_back(e); + } + } + if (!atomArg) { + return Id::operator()(std::move(x), f); + } + + evaluate::ActualArguments nonAtoms; + + auto AsActual = [](const SomeExpr &z) { + SomeExpr copy = z; + return evaluate::ActualArgument(std::move(copy)); + }; + // Semantic checks guarantee that the "atom" shows exactly once in the + // argument list (with potential conversions around it). + // For the first two (non-optional) arguments, if "atom" is among them, + // replace it with another occurrence of the other non-optional argument. + if (atomArg == args[0]) { + // (atom, x, y...) -> (x, x, y...) + nonAtoms.push_back(AsActual(*args[1])); + nonAtoms.push_back(AsActual(*args[1])); + } else if (atomArg == args[1]) { + // (x, atom, y...) -> (x, x, y...) + nonAtoms.push_back(AsActual(*args[0])); + nonAtoms.push_back(AsActual(*args[0])); + } else { + // (x, y, z...) -> unchanged + nonAtoms.push_back(AsActual(*args[0])); + nonAtoms.push_back(AsActual(*args[1])); + } + + // The rest of arguments are optional, so we can just skip "atom". 
+ for (size_t i = 2, e = args.size(); i != e; ++i) { + if (atomArg != args[i]) + nonAtoms.push_back(AsActual(*args[i])); + } + + SomeExpr tmp = evaluate::AsGenericExpr( + evaluate::FunctionRef<T>(AsRvalue(proc), AsRvalue(nonAtoms))); + + return evaluate::Expr<T>(evaluate::FunctionRef<T>( + AsRvalue(proc), {AsActual(*atomArg), AsActual(tmp)})); + } + +private: + const SomeExpr &atom_; +}; + +static MaybeExpr PostSemaRewrite(const SomeExpr &atom, const SomeExpr &expr) { + MinMaxRewriter rewriter(atom); + return evaluate::rewrite::Mutator(rewriter)(expr); +} + } // namespace Fortran::semantics `````````` </details> https://github.com/llvm/llvm-project/pull/153038 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits