================
@@ -213,6 +202,409 @@ class HostEvalInfo {
   llvm::SmallVector<const semantics::Symbol *> iv;
   bool loopNestApplied = false, parallelApplied = false;
 };
+
+/// A base class to help iterate over OpenMP constructs based on an expected
+/// sequence.
+///
+/// The main entry point process() will call processDirective() for the
+/// OpenMP directive associated to the initial given evaluation based on
+/// whether it is part of the initialDirsToProcess() set. A nested OpenMP
+/// evaluation might optionally be also visited by the pattern if it meets all
+/// of the following conditions:
+///   - It is the only nested evaluation, apart from end statements (such as
+///     the END directive of the same construct).
+///   - It is an OpenMP construct whose directive is part of the directive set
+///     returned by the processDirective() call for the parent.
+///
+/// This check is applied recursively, so a chain of singly-nested constructs
+/// can be visited. Subclasses define the expected pattern by implementing the
+/// initialDirsToProcess() and processDirective() methods, and users are
+/// expected to use process() to trigger the complete pattern visit.
+class OpenMPPatternProcessor {
+public:
+  explicit OpenMPPatternProcessor(semantics::SemanticsContext &semaCtx)
+      : semaCtx{semaCtx} {}
+  virtual ~OpenMPPatternProcessor() = default;
+
+  /// Run the pattern from the given evaluation.
+  ///
+  /// The set of directives of interest is reset to initialDirsToProcess() on
+  /// every call, so the same processor instance can be run multiple times.
+  void process(lower::pft::Evaluation &eval) {
+    dirsToProcess = initialDirsToProcess();
+    processEval(eval);
+  }
+
+protected:
+  /// Returns the set of directives of interest at the beginning of the
+  /// pattern.
+  virtual OmpDirectiveSet initialDirsToProcess() const = 0;
+
+  /// Processes a single directive and, based on it, returns the set of other
+  /// directives of interest that would be part of the pattern if nested inside
+  /// of it. Returning an empty set ends the pattern at this directive.
+  virtual OmpDirectiveSet processDirective(lower::pft::Evaluation &eval,
+                                           llvm::omp::Directive dir) = 0;
+
+  /// Append the clauses of the given OpenMP block or loop construct evaluation
+  /// to \c clauses: first those of the begin directive, then those of the end
+  /// directive, if present.
+  ///
+  /// If \p eval is not an OpenMP block or loop construct, no modifications are
+  /// made to the \c clauses output argument. The pattern visits any OpenMP
+  /// construct whose directive is in a set returned by processDirective(), so
+  /// other kinds of OpenMP constructs can reach this point and must not be
+  /// treated as an error.
+  void extractClauses(lower::pft::Evaluation &eval, List<Clause> &clauses) {
+    const auto *ompEval{eval.getIf<parser::OpenMPConstruct>()};
+    if (!ompEval)
+      return;
+
+    const parser::OmpClauseList *beginClauseList{nullptr};
+    const parser::OmpClauseList *endClauseList{nullptr};
+    common::visit(
+        [&](const auto &construct) {
+          using Type = llvm::remove_cvref_t<decltype(construct)>;
+          if constexpr (std::is_same_v<Type, parser::OmpBlockConstruct> ||
+                        std::is_same_v<Type, parser::OpenMPLoopConstruct>) {
+            beginClauseList = &construct.BeginDir().Clauses();
+            if (auto &endSpec{construct.EndDir()})
+              endClauseList = &endSpec->Clauses();
+          }
+        },
+        ompEval->u);
+
+    // Only block and loop constructs set beginClauseList above; leave the
+    // output untouched for any other kind of OpenMP construct.
+    if (!beginClauseList)
+      return;
+
+    clauses.append(makeClauses(*beginClauseList, semaCtx));
+    if (endClauseList)
+      clauses.append(makeClauses(*endClauseList, semaCtx));
+  }
+
+private:
+  /// Decide whether an evaluation must be processed as part of the pattern.
+  ///
+  /// This is the case whenever it's an OpenMP construct and the associated
+  /// directive is part of the current set of directives of interest.
+  bool shouldProcessEval(lower::pft::Evaluation &eval) const {
+    const auto *ompEval{eval.getIf<parser::OpenMPConstruct>()};
+    if (!ompEval)
+      return false;
+
+    return dirsToProcess.test(parser::omp::GetOmpDirectiveName(*ompEval).v);
+  }
+
+  /// Processes an evaluation and, potentially, recursively processes a single
+  /// nested evaluation.
+  ///
+  /// For a nested evaluation to be recursively processed, it must be an OpenMP
+  /// construct, have no sibling evaluations other than end statements, and
+  /// match one of the next-directives of interest set returned by a call to
+  /// processDirective() on the parent evaluation.
+  void processEval(lower::pft::Evaluation &eval) {
+    if (!shouldProcessEval(eval))
+      return;
+
+    const auto &ompEval{eval.get<parser::OpenMPConstruct>()};
+    OmpDirectiveSet processNested{
+        processDirective(eval, parser::omp::GetOmpDirectiveName(ompEval).v)};
+
+    // An empty set means the pattern ends at this directive.
+    if (processNested.empty())
+      return;
+
+    if (lower::pft::Evaluation *nestedEval{extractOnlyOmpNestedEval(eval)}) {
+      // The nested evaluation is matched against the set returned for its
+      // parent; the previous set is restored once it has been visited.
+      OmpDirectiveSet prevDirs{dirsToProcess};
+      dirsToProcess = processNested;
+      processEval(*nestedEval);
+      dirsToProcess = prevDirs;
+    }
+  }
+
+  /// Return the evaluation that is immediately nested inside of the given
+  /// \c parent evaluation, if it is its only non-end-statement nested
+  /// evaluation and it represents an OpenMP construct. Returns \c nullptr
+  /// otherwise.
+  static lower::pft::Evaluation *
+  extractOnlyOmpNestedEval(lower::pft::Evaluation &parent) {
+    if (!parent.hasNestedEvaluations())
+      return nullptr;
+
+    auto &nested{parent.getFirstNestedEvaluation()};
+    if (!nested.isA<parser::OpenMPConstruct>())
+      return nullptr;
+
+    for (auto &sibling : parent.getNestedEvaluations())
+      if (&sibling != &nested && !sibling.isEndStmt())
+        return nullptr;
+
+    return &nested;
+  }
+
+protected:
+  semantics::SemanticsContext &semaCtx;
+
+private:
+  /// Directives of interest for the evaluation currently being visited.
+  OmpDirectiveSet dirsToProcess;
+};
+
+/// Pattern that starts at a target construct and follows the single-nesting
+/// chains that can make up an SPMD kernel.
+class TargetSPMDPatternProcessor : public OpenMPPatternProcessor {
+public:
+  using OpenMPPatternProcessor::OpenMPPatternProcessor;
+  ~TargetSPMDPatternProcessor() override = default;
+
+protected:
+  /// Any directive from the target set may start the pattern.
+  OmpDirectiveSet initialDirsToProcess() const override {
+    return llvm::omp::allTargetSet;
+  }
+
+  /// Performs no work of its own; it only reports which directives may appear
+  /// as the single nested construct of \p dir in an SPMD kernel. Subclasses
+  /// overriding this method can call it to keep walking SPMD patterns.
+  ///
+  /// Nestings recognized as SPMD:
+  ///   - target teams distribute parallel do [simd]
+  ///   - target teams loop
+  ///   - target parallel do [simd]
+  ///   - target parallel loop
+  OmpDirectiveSet processDirective(lower::pft::Evaluation &,
+                                   llvm::omp::Directive dir) override {
+    using namespace llvm::omp;
+
+    // A bare target may be followed by either a teams or a parallel construct.
+    if (dir == OMPD_target)
+      return topTeamsSet | topParallelSet;
+
+    // Teams regions continue into distribute or loop constructs.
+    if (dir == OMPD_target_teams || dir == OMPD_teams)
+      return topDistributeSet | topLoopSet;
+
+    // Parallel regions continue into loop or do constructs.
+    if (dir == OMPD_target_parallel || dir == OMPD_parallel)
+      return topLoopSet | topDoSet;
+
+    // Every other directive terminates the pattern.
+    return {};
+  }
+};
+
+/// Populates the given HostEvalInfo structure after processing clauses for
+/// the OpenMP target construct the pattern is run on, or nested constructs,
+/// if these must be evaluated outside of the target region per the spec.
+///
+/// In particular, this will ensure that the following clauses are evaluated
+/// in the host: \c num_teams of <tt>teams</tt> constructs combined with or
+/// nested in the target construct, \c thread_limit of nested (non-combined)
+/// <tt>teams</tt> constructs, and \c num_threads of <tt>parallel</tt>
+/// constructs that are part of the pattern. Additionally, loop bounds and
+/// steps will also be evaluated in the host if a
+/// <tt>target teams distribute</tt> or target SPMD construct is detected (i.e.
+/// <tt>target teams distribute parallel do [simd]</tt>,
+/// <tt>target parallel do [simd]</tt> or equivalent nesting).
+///
+/// The resulting updated HostEvalInfo structure is intended to be used to
+/// populate the \c host_eval operands of the associated \c omp.target
+/// operation, and also to be checked and used by later lowering steps to
+/// populate the corresponding operands of the \c omp.teams, \c omp.parallel or
+/// \c omp.loop_nest operations.
+class HostEvalPatternProcessor : public TargetSPMDPatternProcessor {
+public:
+  HostEvalPatternProcessor(lower::AbstractConverter &converter,
+                           semantics::SemanticsContext &semaCtx,
+                           lower::StatementContext &stmtCtx, mlir::Location loc,
+                           HostEvalInfo &hostEvalInfo)
+      : TargetSPMDPatternProcessor{semaCtx}, converter{converter},
+        stmtCtx{stmtCtx}, loc{loc}, hostEvalInfo{hostEvalInfo} {}
+  ~HostEvalPatternProcessor() override = default;
+
+protected:
+  /// Processes in the host the clauses of \p dir that must be evaluated there,
+  /// storing the results in the associated HostEvalInfo. Returns the set of
+  /// directives that may continue the pattern when singly nested in \p dir.
+  OmpDirectiveSet processDirective(lower::pft::Evaluation &eval,
+                                   llvm::omp::Directive dir) override {
+    using namespace llvm::omp;
+
+    List<lower::omp::Clause> clauses;
+    extractClauses(eval, clauses);
+    ClauseProcessor cp{converter, semaCtx, clauses};
+
+    // Currently, we treat e.g. `target parallel workshare` differently from
+    // `target parallel` with a single nested `workshare`. The first case would
+    // result in no clauses being evaluated in the host, as there's not a case
+    // for it in the below switch statement. The second case would evaluate
+    // `num_threads` clauses in the host, because `target parallel` could be
+    // followed by a `do` construct, which would make this an SPMD target
+    // region.
+    //
+    // TODO: We probably don't want to have such divergent behavior when
+    // dealing with combined directives. We need to revisit this logic without
+    // listing every possible combined directive containing a clause we'd
+    // otherwise evaluate in the host if the directive was split into its
+    // leaves.
+    switch (dir) {
+    // Constructs that end in a `do` or `distribute` loop. Each fallthrough
+    // step adds the clauses carried by the longer combinations above it, and
+    // all of them evaluate the associated loop nest in the host. Note that
+    // `thread_limit` is only collected when `teams` is not combined with
+    // `target` (presumably handled along with `target` itself otherwise;
+    // TODO confirm).
+    case OMPD_teams_distribute_parallel_do:
+    case OMPD_teams_distribute_parallel_do_simd:
+      cp.processThreadLimit(stmtCtx, hostEvalInfo.ops);
+      [[fallthrough]];
+    case OMPD_target_teams_distribute_parallel_do:
+    case OMPD_target_teams_distribute_parallel_do_simd:
+      cp.processNumTeams(stmtCtx, hostEvalInfo.ops);
+      [[fallthrough]];
+    case OMPD_distribute_parallel_do:
+    case OMPD_distribute_parallel_do_simd:
+    case OMPD_target_parallel_do:
+    case OMPD_target_parallel_do_simd:
+    case OMPD_target_parallel_loop:
+    case OMPD_parallel_do:
+    case OMPD_parallel_do_simd:
+    case OMPD_parallel_loop:
+      cp.processNumThreads(stmtCtx, hostEvalInfo.ops);
+      [[fallthrough]];
+    case OMPD_distribute:
+    case OMPD_distribute_simd:
+    case OMPD_do:
+    case OMPD_do_simd:
+      cp.processCollapse(loc, eval, hostEvalInfo.ops, hostEvalInfo.ops,
+                         hostEvalInfo.iv);
+      return {};
+
+    // A `teams` construct on its own: evaluate its clauses and let the SPMD
+    // pattern continue into a nested construct.
+    case OMPD_teams:
+      cp.processThreadLimit(stmtCtx, hostEvalInfo.ops);
+      [[fallthrough]];
+    case OMPD_target_teams:
+      cp.processNumTeams(stmtCtx, hostEvalInfo.ops);
+      break;
+
+    // `teams distribute [simd]` combinations (no `parallel`, so no
+    // `num_threads`), which end the pattern.
+    case OMPD_teams_distribute:
+    case OMPD_teams_distribute_simd:
+      cp.processThreadLimit(stmtCtx, hostEvalInfo.ops);
+      [[fallthrough]];
+    case OMPD_target_teams_distribute:
+    case OMPD_target_teams_distribute_simd:
+      cp.processCollapse(loc, eval, hostEvalInfo.ops, hostEvalInfo.ops,
+                         hostEvalInfo.iv);
+      cp.processNumTeams(stmtCtx, hostEvalInfo.ops);
+      return {};
+
+    // `loop`-based constructs, which end the pattern.
+    case OMPD_teams_loop:
+      cp.processThreadLimit(stmtCtx, hostEvalInfo.ops);
+      [[fallthrough]];
+    case OMPD_target_teams_loop:
+      cp.processNumTeams(stmtCtx, hostEvalInfo.ops);
+      [[fallthrough]];
+    case OMPD_loop:
+      cp.processCollapse(loc, eval, hostEvalInfo.ops, hostEvalInfo.ops,
+                         hostEvalInfo.iv);
+      return {};
+
+    // `workdistribute` combinations: only the `teams` clauses are evaluated.
+    case OMPD_teams_workdistribute:
+      cp.processThreadLimit(stmtCtx, hostEvalInfo.ops);
+      [[fallthrough]];
+    case OMPD_target_teams_workdistribute:
+      cp.processNumTeams(stmtCtx, hostEvalInfo.ops);
+      break;
+
+    // A `parallel` construct on its own: a nested `do`/`loop` would make this
+    // an SPMD region, so `num_threads` is evaluated in the host.
+    case OMPD_target_parallel:
+    case OMPD_parallel:
+      cp.processNumThreads(stmtCtx, hostEvalInfo.ops);
+      break;
+
+    // No clauses of `target` itself are processed here; only continue the
+    // pattern.
+    case OMPD_target:
+      break;
+
+    // Any other directive is not part of the pattern.
+    default:
+      return {};
+    }
+
+    // Visit nested directives as per the SPMD pattern.
+    return TargetSPMDPatternProcessor::processDirective(eval, dir);
+  }
+
+private:
+  lower::AbstractConverter &converter;
+  /// Statement context used when evaluating clause expressions.
+  lower::StatementContext &stmtCtx;
+  /// Location passed to processCollapse().
+  mlir::Location loc;
+  /// Output structure populated while walking the pattern.
+  HostEvalInfo &hostEvalInfo;
+};
+
+/// Checks a target region and, based on the directives and clauses
+/// encountered, determines its associated kernel type.
+class KernelTypePatternProcessor : protected TargetSPMDPatternProcessor {
+public:
+  KernelTypePatternProcessor(semantics::SemanticsContext &semaCtx,
+                             mlir::ModuleOp moduleOp)
+      : TargetSPMDPatternProcessor{semaCtx}, moduleOp{moduleOp} {}
+  virtual ~KernelTypePatternProcessor() = default;
+
+  /// Executes the pattern and returns the kernel type of the given target
+  /// region, or \c mlir::omp::TargetExecMode::generic by default for 
non-target
+  /// evaluations.
+  mlir::omp::TargetExecMode getKernelType(lower::pft::Evaluation &eval) {
+    execMode = mlir::omp::TargetExecMode::generic;
+    process(eval);
+    return execMode;
+  }
+
+protected:
+  virtual OmpDirectiveSet processDirective(lower::pft::Evaluation &eval,
+                                           llvm::omp::Directive dir) override {
+    using namespace llvm::omp;
+
+    switch (dir) {
+    case OMPD_target:
+    case OMPD_target_parallel:
+    case OMPD_parallel:
+    case OMPD_teams:
+      break;
+    case OMPD_target_teams:
+      if (hasOmpxBareClause(eval)) {
+        execMode = mlir::omp::TargetExecMode::bare;
+        return {};
+      }
+      break;
+    case OMPD_target_teams_distribute_parallel_do:
+    case OMPD_target_teams_distribute_parallel_do_simd:
+    case OMPD_target_teams_loop:
+    case OMPD_target_parallel_do:
+    case OMPD_target_parallel_do_simd:
+    case OMPD_target_parallel_loop:
+    case OMPD_teams_distribute_parallel_do:
+    case OMPD_teams_distribute_parallel_do_simd:
+    case OMPD_teams_loop:
+    case OMPD_distribute_parallel_do:
+    case OMPD_distribute_parallel_do_simd:
+    case OMPD_loop:
+    case OMPD_parallel_do:
+    case OMPD_parallel_do_simd:
+    case OMPD_do:
+    case OMPD_do_simd:
+      execMode = canPromoteSPMDToNoLoop(eval)
+                     ? mlir::omp::TargetExecMode::spmd_no_loop
+                     : mlir::omp::TargetExecMode::spmd;
+      return {};
+    default:
+      return {};
+    }
+
+    // Visit nested directives as per the SPMD pattern.
+    return TargetSPMDPatternProcessor::processDirective(eval, dir);
----------------
Meinersbur wrote:

Depending on what `dir` is, shouldn't there be fewer choices for the nested
constructs we allow? For example, `OMPD_target_teams` does not allow any
nested ones.

For instance, assuming it were possible (I can't think of a real case), take
`!$omp target parallel`, which would `break;` in the switch above, i.e. stay
generic. The nested construct could then be `!$omp teams loop`, which could
promote `execMode` to spmd, but that would be a very weird spmd mode.

https://github.com/llvm/llvm-project/pull/186166
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to