================ @@ -14175,27 +14222,350 @@ bool SemaOpenMP::checkTransformableLoopNest( return false; }, [&OriginalInits](OMPLoopBasedDirective *Transform) { - Stmt *DependentPreInits; - if (auto *Dir = dyn_cast<OMPTileDirective>(Transform)) - DependentPreInits = Dir->getPreInits(); - else if (auto *Dir = dyn_cast<OMPStripeDirective>(Transform)) - DependentPreInits = Dir->getPreInits(); - else if (auto *Dir = dyn_cast<OMPUnrollDirective>(Transform)) - DependentPreInits = Dir->getPreInits(); - else if (auto *Dir = dyn_cast<OMPReverseDirective>(Transform)) - DependentPreInits = Dir->getPreInits(); - else if (auto *Dir = dyn_cast<OMPInterchangeDirective>(Transform)) - DependentPreInits = Dir->getPreInits(); - else - llvm_unreachable("Unhandled loop transformation"); - - appendFlattenedStmtList(OriginalInits.back(), DependentPreInits); + updatePreInits(Transform, OriginalInits); }); assert(OriginalInits.back().empty() && "No preinit after innermost loop"); OriginalInits.pop_back(); return Result; } +// Counts the total number of nested loops, including the outermost loop (the +// original loop). PRECONDITION: this visitor must be invoked on the original +// loop to be analyzed. Traversal stops at Decls and Exprs, since they may +// contain inner loops that must not be counted. Only ForStmt and +// CXXForRangeStmt nodes are counted, and traversal only descends through +// CompoundStmt, ForStmt and CXXForRangeStmt nodes, so loops nested under any +// other statement (e.g. IfStmt, SwitchStmt, WhileStmt) are not counted.
+// +// Example AST structure for the code: +// +// int main() { +// #pragma omp fuse +// { +// for (int i = 0; i < 100; i++) { <-- Outer loop +// []() { +// for(int j = 0; j < 100; j++) {} <-- NOT A LOOP +// }; +// for(int j = 0; j < 5; ++j) {} <-- Inner loop +// } +// for (int r = 0; r < 100; r++) { <-- Outer loop +// struct LocalClass { +// void bar() { +// for(int j = 0; j < 100; j++) {} <-- NOT A LOOP +// } +// }; +// for(int k = 0; k < 10; ++k) {} <-- Inner loop +// ({x = 5; for(k = 0; k < 10; ++k) x += k; x}); <-- NOT A LOOP +// } +// } +// } +// Result (each count includes the outer loop itself): Loop 'i' contains 2 +// loops, Loop 'r' also contains 2 loops +class NestedLoopCounterVisitor : public DynamicRecursiveASTVisitor { +private: + unsigned NestedLoopCount = 0; + +public: + explicit NestedLoopCounterVisitor() {} + + unsigned getNestedLoopCount() const { return NestedLoopCount; } + + bool VisitForStmt(ForStmt *FS) override { + ++NestedLoopCount; + return true; + } + + bool VisitCXXForRangeStmt(CXXForRangeStmt *FRS) override { + ++NestedLoopCount; + return true; + } + + bool TraverseStmt(Stmt *S) override { + if (!S) + return true; + + // Skip traversal of all expressions, including special cases like + // LambdaExpr, StmtExpr, BlockExpr, and RequiresExpr. These expressions + // may contain inner statements (and even loops), but they are not part + // of the syntactic body of the surrounding loop structure. + // Therefore, they must not be counted. + if (isa<Expr>(S)) + return true; + + // Only recurse into CompoundStmt (block {}) and for/range-for loops. + if (isa<CompoundStmt>(S) || isa<ForStmt>(S) || isa<CXXForRangeStmt>(S)) { + return DynamicRecursiveASTVisitor::TraverseStmt(S); + } + + // Stop traversal of all other statements, which break perfect + // loop nesting, such as control flow (IfStmt, SwitchStmt...)
+ return true; + } + + bool TraverseDecl(Decl *D) override { + // Stop when a declaration is found; declarations are not relevant + // for finding nested loops (e.g. CXXRecordDecl, RecordDecl, + // FunctionDecl...). + return true; + } +}; + +bool SemaOpenMP::analyzeLoopSequence( + Stmt *LoopSeqStmt, unsigned &LoopSeqSize, unsigned &NumLoops, + SmallVectorImpl<OMPLoopBasedDirective::HelperExprs> &LoopHelpers, + SmallVectorImpl<Stmt *> &ForStmts, + SmallVectorImpl<SmallVector<Stmt *, 0>> &OriginalInits, + SmallVectorImpl<SmallVector<Stmt *, 0>> &TransformsPreInits, + SmallVectorImpl<SmallVector<Stmt *, 0>> &LoopSequencePreInits, + SmallVectorImpl<OMPLoopCategory> &LoopCategories, ASTContext &Context, + OpenMPDirectiveKind Kind) { + + VarsWithInheritedDSAType TmpDSA; + QualType BaseInductionVarType; + // Helper Lambda to handle storing initialization and body statements for both + // ForStmt and CXXForRangeStmt and checks for any possible mismatch between + // induction variables types + auto storeLoopStatements = [&OriginalInits, &ForStmts, &BaseInductionVarType, + this, &Context](Stmt *LoopStmt) { + if (auto *For = dyn_cast<ForStmt>(LoopStmt)) { + OriginalInits.back().push_back(For->getInit()); + ForStmts.push_back(For); + // Extract induction variable + if (auto *InitStmt = dyn_cast_or_null<DeclStmt>(For->getInit())) { + if (auto *InitDecl = dyn_cast<VarDecl>(InitStmt->getSingleDecl())) { + QualType InductionVarType = InitDecl->getType().getCanonicalType(); + + // Compare with first loop type + if (BaseInductionVarType.isNull()) { + BaseInductionVarType = InductionVarType; + } else if (!Context.hasSameType(BaseInductionVarType, + InductionVarType)) { + Diag(InitDecl->getBeginLoc(), + diag::warn_omp_different_loop_ind_var_types) + << getOpenMPDirectiveName(OMPD_fuse) << BaseInductionVarType + << InductionVarType; + } + } + } + } else { + auto *CXXFor = cast<CXXForRangeStmt>(LoopStmt); + OriginalInits.back().push_back(CXXFor->getBeginStmt()); +
ForStmts.push_back(CXXFor); + } + }; + + // Helper lambda functions to encapsulate the processing of different + // derivations of the canonical loop sequence grammar + // + // Modularized code for handling loop generation and transformations + auto analyzeLoopGeneration = [&storeLoopStatements, &LoopHelpers, + &OriginalInits, &TransformsPreInits, + &LoopCategories, &LoopSeqSize, &NumLoops, Kind, + &TmpDSA, &ForStmts, &Context, + &LoopSequencePreInits, this](Stmt *Child) { + auto LoopTransform = dyn_cast<OMPLoopTransformationDirective>(Child); + Stmt *TransformedStmt = LoopTransform->getTransformedStmt(); + unsigned NumGeneratedLoopNests = LoopTransform->getNumGeneratedLoopNests(); + unsigned NumGeneratedLoops = LoopTransform->getNumGeneratedLoops(); + // Handle the case where transformed statement is not available due to + // dependent contexts + if (!TransformedStmt) { + if (NumGeneratedLoopNests > 0) { + LoopSeqSize += NumGeneratedLoopNests; + NumLoops += NumGeneratedLoops; + return true; + } + // Unroll full (0 loops produced) + else { + Diag(Child->getBeginLoc(), diag::err_omp_not_for) + << 0 << getOpenMPDirectiveName(Kind); + return false; + } + } + // Handle loop transformations with multiple loop nests + // Unroll full + if (NumGeneratedLoopNests <= 0) { + Diag(Child->getBeginLoc(), diag::err_omp_not_for) + << 0 << getOpenMPDirectiveName(Kind); + return false; + } + // Loop transformatons such as split or loopranged fuse + else if (NumGeneratedLoopNests > 1) { ---------------- eZWALT wrote:
For the rest of the conditionals, I can understand removing the braces for conciseness, but this change worsens the readability of the code and can easily introduce errors if it is modified in the future, because this control flow is not identical to the one I proposed. Could you please elaborate a bit further on why this change is needed? https://github.com/llvm/llvm-project/pull/139293 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits