================ @@ -14175,27 +14222,350 @@ bool SemaOpenMP::checkTransformableLoopNest( return false; }, [&OriginalInits](OMPLoopBasedDirective *Transform) { - Stmt *DependentPreInits; - if (auto *Dir = dyn_cast<OMPTileDirective>(Transform)) - DependentPreInits = Dir->getPreInits(); - else if (auto *Dir = dyn_cast<OMPStripeDirective>(Transform)) - DependentPreInits = Dir->getPreInits(); - else if (auto *Dir = dyn_cast<OMPUnrollDirective>(Transform)) - DependentPreInits = Dir->getPreInits(); - else if (auto *Dir = dyn_cast<OMPReverseDirective>(Transform)) - DependentPreInits = Dir->getPreInits(); - else if (auto *Dir = dyn_cast<OMPInterchangeDirective>(Transform)) - DependentPreInits = Dir->getPreInits(); - else - llvm_unreachable("Unhandled loop transformation"); - - appendFlattenedStmtList(OriginalInits.back(), DependentPreInits); + updatePreInits(Transform, OriginalInits); }); assert(OriginalInits.back().empty() && "No preinit after innermost loop"); OriginalInits.pop_back(); return Result; } +// Counts the total number of nested loops, including the outermost loop (the +// original loop). PRECONDITION of this visitor is that it must be invoked from +// the original loop to be analyzed. The traversal is stop for Decl's and +// Expr's given that they may contain inner loops that must not be counted. +// +// Example AST structure for the code: +// +// int main() { +// #pragma omp fuse +// { +// for (int i = 0; i < 100; i++) { <-- Outer loop +// []() { +// for(int j = 0; j < 100; j++) {} <-- NOT A LOOP +// }; +// for(int j = 0; j < 5; ++j) {} <-- Inner loop +// } +// for (int r = 0; i < 100; i++) { <-- Outer loop +// struct LocalClass { +// void bar() { +// for(int j = 0; j < 100; j++) {} <-- NOT A LOOP +// } +// }; +// for(int k = 0; k < 10; ++k) {} <-- Inner loop +// {x = 5; for(k = 0; k < 10; ++k) x += k; x}; <-- NOT A LOOP +// } +// } +// } +// Result: Loop 'i' contains 2 loops, Loop 'r' also contains 2 loops +class NestedLoopCounterVisitor : public DynamicRecursiveASTVisitor { +private: + unsigned NestedLoopCount = 0; + +public: + explicit NestedLoopCounterVisitor() {} + + unsigned getNestedLoopCount() const { return NestedLoopCount; } + + bool VisitForStmt(ForStmt *FS) override { + ++NestedLoopCount; + return true; + } + + bool VisitCXXForRangeStmt(CXXForRangeStmt *FRS) override { + ++NestedLoopCount; + return true; + } + + bool TraverseStmt(Stmt *S) override { + if (!S) + return true; + + // Skip traversal of all expressions, including special cases like + // LambdaExpr, StmtExpr, BlockExpr, and RequiresExpr. These expressions + // may contain inner statements (and even loops), but they are not part + // of the syntactic body of the surrounding loop structure. + // Therefore must not be counted + if (isa<Expr>(S)) + return true; + + // Only recurse into CompoundStmt (block {}) and loop bodies + if (isa<CompoundStmt>(S) || isa<ForStmt>(S) || isa<CXXForRangeStmt>(S)) { + return DynamicRecursiveASTVisitor::TraverseStmt(S); + } + + // Stop traversal of the rest of statements, that break perfect + // loop nesting, such as control flow (IfStmt, SwitchStmt...) + return true; + } + + bool TraverseDecl(Decl *D) override { + // Stop in the case of finding a declaration, it is not important + // in order to find nested loops (Possible CXXRecordDecl, RecordDecl, + // FunctionDecl...) + return true; + } +}; + +bool SemaOpenMP::analyzeLoopSequence( + Stmt *LoopSeqStmt, unsigned &LoopSeqSize, unsigned &NumLoops, + SmallVectorImpl<OMPLoopBasedDirective::HelperExprs> &LoopHelpers, + SmallVectorImpl<Stmt *> &ForStmts, + SmallVectorImpl<SmallVector<Stmt *, 0>> &OriginalInits, + SmallVectorImpl<SmallVector<Stmt *, 0>> &TransformsPreInits, + SmallVectorImpl<SmallVector<Stmt *, 0>> &LoopSequencePreInits, + SmallVectorImpl<OMPLoopCategory> &LoopCategories, ASTContext &Context, + OpenMPDirectiveKind Kind) { + + VarsWithInheritedDSAType TmpDSA; + QualType BaseInductionVarType; + // Helper Lambda to handle storing initialization and body statements for both + // ForStmt and CXXForRangeStmt and checks for any possible mismatch between + // induction variables types + auto storeLoopStatements = [&OriginalInits, &ForStmts, &BaseInductionVarType, + this, &Context](Stmt *LoopStmt) { + if (auto *For = dyn_cast<ForStmt>(LoopStmt)) { + OriginalInits.back().push_back(For->getInit()); + ForStmts.push_back(For); + // Extract induction variable + if (auto *InitStmt = dyn_cast_or_null<DeclStmt>(For->getInit())) { + if (auto *InitDecl = dyn_cast<VarDecl>(InitStmt->getSingleDecl())) { + QualType InductionVarType = InitDecl->getType().getCanonicalType(); + + // Compare with first loop type + if (BaseInductionVarType.isNull()) { + BaseInductionVarType = InductionVarType; + } else if (!Context.hasSameType(BaseInductionVarType, + InductionVarType)) { + Diag(InitDecl->getBeginLoc(), + diag::warn_omp_different_loop_ind_var_types) + << getOpenMPDirectiveName(OMPD_fuse) << BaseInductionVarType + << InductionVarType; + } + } + } + } else { + auto *CXXFor = cast<CXXForRangeStmt>(LoopStmt); + OriginalInits.back().push_back(CXXFor->getBeginStmt()); + ForStmts.push_back(CXXFor); + } + }; + + // Helper lambda functions to encapsulate the processing of different + // derivations of the canonical loop sequence grammar + // + // Modularized code for handling loop generation and transformations + auto analyzeLoopGeneration = [&storeLoopStatements, &LoopHelpers, + &OriginalInits, &TransformsPreInits, + &LoopCategories, &LoopSeqSize, &NumLoops, Kind, + &TmpDSA, &ForStmts, &Context, + &LoopSequencePreInits, this](Stmt *Child) { + auto LoopTransform = dyn_cast<OMPLoopTransformationDirective>(Child); + Stmt *TransformedStmt = LoopTransform->getTransformedStmt(); + unsigned NumGeneratedLoopNests = LoopTransform->getNumGeneratedLoopNests(); + unsigned NumGeneratedLoops = LoopTransform->getNumGeneratedLoops(); + // Handle the case where transformed statement is not available due to + // dependent contexts + if (!TransformedStmt) { + if (NumGeneratedLoopNests > 0) { + LoopSeqSize += NumGeneratedLoopNests; + NumLoops += NumGeneratedLoops; + return true; + } + // Unroll full (0 loops produced) + else { + Diag(Child->getBeginLoc(), diag::err_omp_not_for) + << 0 << getOpenMPDirectiveName(Kind); + return false; + } + } + // Handle loop transformations with multiple loop nests + // Unroll full + if (NumGeneratedLoopNests <= 0) { + Diag(Child->getBeginLoc(), diag::err_omp_not_for) + << 0 << getOpenMPDirectiveName(Kind); + return false; + } + // Loop transformatons such as split or loopranged fuse + else if (NumGeneratedLoopNests > 1) { + // Get the preinits related to this loop sequence generating + // loop transformation (i.e loopranged fuse, split...) + LoopSequencePreInits.emplace_back(); + // These preinits differ slightly from regular inits/pre-inits related + // to single loop generating loop transformations (interchange, unroll) + // given that they are not bounded to a particular loop nest + // so they need to be treated independently + updatePreInits(LoopTransform, LoopSequencePreInits); + return analyzeLoopSequence(TransformedStmt, LoopSeqSize, NumLoops, + LoopHelpers, ForStmts, OriginalInits, + TransformsPreInits, LoopSequencePreInits, + LoopCategories, Context, Kind); + } + // Vast majority: (Tile, Unroll, Stripe, Reverse, Interchange, Fuse all) + else { + // Process the transformed loop statement + OriginalInits.emplace_back(); + TransformsPreInits.emplace_back(); + LoopHelpers.emplace_back(); + LoopCategories.push_back(OMPLoopCategory::TransformSingleLoop); + + unsigned IsCanonical = + checkOpenMPLoop(Kind, nullptr, nullptr, TransformedStmt, SemaRef, + *DSAStack, TmpDSA, LoopHelpers[LoopSeqSize]); + + if (!IsCanonical) { + Diag(TransformedStmt->getBeginLoc(), diag::err_omp_not_canonical_loop) + << getOpenMPDirectiveName(Kind); + return false; + } + storeLoopStatements(TransformedStmt); + updatePreInits(LoopTransform, TransformsPreInits); + + NumLoops += NumGeneratedLoops; + ++LoopSeqSize; + return true; + } + }; + + // Modularized code for handling regular canonical loops + auto analyzeRegularLoop = [&storeLoopStatements, &LoopHelpers, &OriginalInits, + &LoopSeqSize, &NumLoops, Kind, &TmpDSA, + &LoopCategories, this](Stmt *Child) { + OriginalInits.emplace_back(); + LoopHelpers.emplace_back(); + LoopCategories.push_back(OMPLoopCategory::RegularLoop); + + unsigned IsCanonical = + checkOpenMPLoop(Kind, nullptr, nullptr, Child, SemaRef, *DSAStack, + TmpDSA, LoopHelpers[LoopSeqSize]); + + if (!IsCanonical) { + Diag(Child->getBeginLoc(), diag::err_omp_not_canonical_loop) + << getOpenMPDirectiveName(Kind); + return false; + } + + storeLoopStatements(Child); + auto NLCV = NestedLoopCounterVisitor(); + NLCV.TraverseStmt(Child); + NumLoops += NLCV.getNestedLoopCount(); + return true; + }; + + // Helper functions to validate canonical loop sequence grammar is valid + auto isLoopSequenceDerivation = [](auto *Child) { + return isa<ForStmt>(Child) || isa<CXXForRangeStmt>(Child) || + isa<OMPLoopTransformationDirective>(Child); + }; + auto isLoopGeneratingStmt = [](auto *Child) { ---------------- alexey-bataev wrote:
```suggestion auto IsLoopGeneratingStmt = [](auto *Child) { ``` https://github.com/llvm/llvm-project/pull/139293 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits