================ @@ -14236,28 +14279,282 @@ bool SemaOpenMP::checkTransformableLoopNest( return false; }, [&OriginalInits](OMPLoopBasedDirective *Transform) { - Stmt *DependentPreInits; - if (auto *Dir = dyn_cast<OMPTileDirective>(Transform)) - DependentPreInits = Dir->getPreInits(); - else if (auto *Dir = dyn_cast<OMPStripeDirective>(Transform)) - DependentPreInits = Dir->getPreInits(); - else if (auto *Dir = dyn_cast<OMPUnrollDirective>(Transform)) - DependentPreInits = Dir->getPreInits(); - else if (auto *Dir = dyn_cast<OMPReverseDirective>(Transform)) - DependentPreInits = Dir->getPreInits(); - else if (auto *Dir = dyn_cast<OMPInterchangeDirective>(Transform)) - DependentPreInits = Dir->getPreInits(); - else - llvm_unreachable("Unhandled loop transformation"); - - appendFlattenedStmtList(OriginalInits.back(), DependentPreInits); + updatePreInits(Transform, OriginalInits.back()); }); assert(OriginalInits.back().empty() && "No preinit after innermost loop"); OriginalInits.pop_back(); return Result; } -/// Add preinit statements that need to be propageted from the selected loop. +/// Counts the total number of nested loops, including the outermost loop (the +/// original loop). PRECONDITION of this visitor is that it must be invoked from +/// the original loop to be analyzed. The traversal stops for Decl's and +/// Expr's given that they may contain inner loops that must not be counted. 
+/// +/// Example AST structure for the code: +/// +/// int main() { +/// #pragma omp fuse +/// { +/// for (int i = 0; i < 100; i++) { <-- Outer loop +/// []() { +/// for(int j = 0; j < 100; j++) {} <-- NOT A LOOP +/// }; +/// for(int j = 0; j < 5; ++j) {} <-- Inner loop +/// } +/// for (int r = 0; r < 100; r++) { <-- Outer loop +/// struct LocalClass { +/// void bar() { +/// for(int j = 0; j < 100; j++) {} <-- NOT A LOOP +/// } +/// }; +/// for(int k = 0; k < 10; ++k) {} <-- Inner loop +/// ({x = 5; for(k = 0; k < 10; ++k) x += k; x}); <-- NOT A LOOP +/// } +/// } +/// } +/// Result: Loop 'i' contains 2 loops, Loop 'r' also contains 2 loops. +class NestedLoopCounterVisitor final : public DynamicRecursiveASTVisitor { +private: + unsigned NestedLoopCount = 0; + +public: + explicit NestedLoopCounterVisitor() = default; + + unsigned getNestedLoopCount() const { return NestedLoopCount; } + + bool VisitForStmt(ForStmt *FS) override { + ++NestedLoopCount; + return true; + } + + bool VisitCXXForRangeStmt(CXXForRangeStmt *FRS) override { + ++NestedLoopCount; + return true; + } + + bool TraverseStmt(Stmt *S) override { + if (!S) + return true; + + // Skip traversal of all expressions, including special cases like + // LambdaExpr, StmtExpr, BlockExpr, and RequiresExpr. These expressions + // may contain inner statements (and even loops), but they are not part + // of the syntactic body of the surrounding loop structure. + // Therefore, they must not be counted. + if (isa<Expr>(S)) + return true; + + // Only recurse into CompoundStmt (block {}) and loop bodies. + if (isa<CompoundStmt, ForStmt, CXXForRangeStmt>(S)) { + return DynamicRecursiveASTVisitor::TraverseStmt(S); + } + + // Stop traversal of the remaining statements that break perfect + // loop nesting, such as control flow (IfStmt, SwitchStmt...). 
+ return true; + } + + bool TraverseDecl(Decl *D) override { + // Stop in the case of finding a declaration, it is not important + // in order to find nested loops (Possible CXXRecordDecl, RecordDecl, + // FunctionDecl...). + return true; + } +}; + +bool SemaOpenMP::analyzeLoopSequence(Stmt *LoopSeqStmt, + LoopSequenceAnalysis &SeqAnalysis, + ASTContext &Context, + OpenMPDirectiveKind Kind) { + + VarsWithInheritedDSAType TmpDSA; + /// Helper Lambda to handle storing initialization and body statements for + /// both ForStmt and CXXForRangeStmt. + auto StoreLoopStatements = [](LoopAnalysis &Analysis, Stmt *LoopStmt) { + if (auto *For = dyn_cast<ForStmt>(LoopStmt)) { + Analysis.OriginalInits.push_back(For->getInit()); + Analysis.ForStmt = For; + } else { + auto *CXXFor = cast<CXXForRangeStmt>(LoopStmt); + Analysis.OriginalInits.push_back(CXXFor->getBeginStmt()); + Analysis.ForStmt = CXXFor; + } + }; + + /// Helper lambda functions to encapsulate the processing of different + /// derivations of the canonical loop sequence grammar + /// Modularized code for handling loop generation and transformations. 
+ auto AnalyzeLoopGeneration = [&](Stmt *Child) { + auto *LoopTransform = dyn_cast<OMPLoopTransformationDirective>(Child); + Stmt *TransformedStmt = LoopTransform->getTransformedStmt(); + unsigned NumGeneratedLoopNests = LoopTransform->getNumGeneratedLoopNests(); + unsigned NumGeneratedLoops = LoopTransform->getNumGeneratedLoops(); + // Handle the case where transformed statement is not available due to + // dependent contexts + if (!TransformedStmt) { + if (NumGeneratedLoopNests > 0) { + SeqAnalysis.LoopSeqSize += NumGeneratedLoopNests; + SeqAnalysis.NumLoops += NumGeneratedLoops; + return true; + } + // Unroll full (0 loops produced) + Diag(Child->getBeginLoc(), diag::err_omp_not_for) + << 0 << getOpenMPDirectiveName(Kind); + return false; + } + // Handle loop transformations with multiple loop nests + // Unroll full + if (!NumGeneratedLoopNests) { + Diag(Child->getBeginLoc(), diag::err_omp_not_for) + << 0 << getOpenMPDirectiveName(Kind); + return false; + } + // Loop transformations such as split or loopranged fuse + if (NumGeneratedLoopNests > 1) { + // Get the preinits related to this loop sequence generating + // loop transformation (e.g. loopranged fuse, split...) 
+ // These preinits differ slightly from regular inits/pre-inits related + // to single loop generating loop transformations (interchange, unroll) + // given that they are not bound to a particular loop nest + // so they need to be treated independently + updatePreInits(LoopTransform, SeqAnalysis.LoopSequencePreInits); + return analyzeLoopSequence(TransformedStmt, SeqAnalysis, Context, Kind); + } + // Vast majority: (Tile, Unroll, Stripe, Reverse, Interchange, Fuse all) + // Process the transformed loop statement + LoopAnalysis &NewTransformedSingleLoop = + SeqAnalysis.Loops.emplace_back(OMPLoopCategory::TransformSingleLoop); + unsigned IsCanonical = checkOpenMPLoop( + Kind, nullptr, nullptr, TransformedStmt, SemaRef, *DSAStack, TmpDSA, + NewTransformedSingleLoop.HelperExprs); + + if (!IsCanonical) { + Diag(TransformedStmt->getBeginLoc(), diag::err_omp_not_canonical_loop) + << getOpenMPDirectiveName(Kind); + return false; + } + StoreLoopStatements(NewTransformedSingleLoop, TransformedStmt); + updatePreInits(LoopTransform, NewTransformedSingleLoop.TransformsPreInits); + + SeqAnalysis.NumLoops += NumGeneratedLoops; + SeqAnalysis.LoopSeqSize++; + return true; + }; + + /// Modularized code for handling regular canonical loops. + auto AnalyzeRegularLoop = [&](Stmt *Child) { + LoopAnalysis &NewRegularLoop = + SeqAnalysis.Loops.emplace_back(OMPLoopCategory::RegularLoop); + unsigned IsCanonical = + checkOpenMPLoop(Kind, nullptr, nullptr, Child, SemaRef, *DSAStack, + TmpDSA, NewRegularLoop.HelperExprs); + + if (!IsCanonical) { + Diag(Child->getBeginLoc(), diag::err_omp_not_canonical_loop) + << getOpenMPDirectiveName(Kind); + return false; + } + + StoreLoopStatements(NewRegularLoop, Child); + auto NLCV = NestedLoopCounterVisitor(); ---------------- Meinersbur wrote:
```suggestion NestedLoopCounterVisitor NLCV; ``` [style] https://github.com/llvm/llvm-project/pull/139293 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits