================
@@ -262,41 +270,106 @@ static bool IsLoopTransforming(llvm::omp::Directive dir) 
{
   }
 }
 
-void OmpStructureChecker::CheckNestedBlock(const parser::OpenMPLoopConstruct 
&x,
-    const parser::Block &body, size_t &nestedCount) {
+void OmpStructureChecker::CheckNestedBlock( // Diagnose statements in `body` that may not appear inside OpenMP loop construct `x`.
+    const parser::OpenMPLoopConstruct &x, const parser::Block &body) {
   for (auto &stmt : body) {
     if (auto *dir{parser::Unwrap<parser::CompilerDirective>(stmt)}) {
       context_.Say(dir->source,
           "Compiler directives are not allowed inside OpenMP loop 
constructs"_warn_en_US);
-    } else if (parser::Unwrap<parser::DoConstruct>(stmt)) {
-      ++nestedCount;
     } else if (auto *omp{parser::Unwrap<parser::OpenMPLoopConstruct>(stmt)}) {
       if (!IsLoopTransforming(omp->BeginDir().DirName().v)) {
         context_.Say(omp->source,
             "Only loop-transforming OpenMP constructs are allowed inside 
OpenMP loop constructs"_err_en_US);
       }
-      ++nestedCount;
     } else if (auto *block{parser::Unwrap<parser::BlockConstruct>(stmt)}) {
-      CheckNestedBlock(x, std::get<parser::Block>(block->t), nestedCount);
-    } else {
+      CheckNestedBlock(x, std::get<parser::Block>(block->t)); // Recurse into the BLOCK construct's body.
+    } else if (!parser::Unwrap<parser::DoConstruct>(stmt)) { // DO constructs are accepted silently; anything else is an error.
       parser::CharBlock source{parser::GetSource(stmt).value_or(x.source)};
       context_.Say(source,
           "OpenMP loop construct can only contain DO loops or 
loop-nest-generating OpenMP constructs"_err_en_US);
     }
   }
 }
 
+static bool IsFullUnroll(const parser::OpenMPLoopConstruct &x) { // True iff `x` is an UNROLL construct with no PARTIAL clause.
+  const parser::OmpDirectiveSpecification &beginSpec{x.BeginDir()};
+  // Only UNROLL qualifies, and only when none of its clauses is PARTIAL.
+  if (beginSpec.DirName().v == llvm::omp::Directive::OMPD_unroll) {
+    return llvm::none_of(beginSpec.Clauses().v, [](const parser::OmpClause &c) 
{
+      return c.Id() == llvm::omp::Clause::OMPC_partial;
+    });
+  }
+  return false;
+}
+
+static std::optional<size_t> CountGeneratedLoops(
+    const parser::ExecutionPartConstruct &epc) {
+  if (parser::Unwrap<parser::DoConstruct>(epc)) {
+    return 1;
+  }
+
+  auto &omp{DEREF(parser::Unwrap<parser::OpenMPLoopConstruct>(epc))};
+  const parser::OmpDirectiveSpecification &beginSpec{omp.BeginDir()};
+  llvm::omp::Directive dir{beginSpec.DirName().v};
+
+  // TODO: Handle split, apply.
+  if (IsFullUnroll(omp)) {
+    return std::nullopt;
+  }
+  if (dir == llvm::omp::Directive::OMPD_fuse) {
+    auto rangeAt{
+        llvm::find_if(beginSpec.Clauses().v, [](const parser::OmpClause &c) {
+          return c.Id() == llvm::omp::Clause::OMPC_looprange;
+        })};
+    if (rangeAt == beginSpec.Clauses().v.end()) {
+      return std::nullopt;
+    }
+
+    auto *loopRange{parser::Unwrap<parser::OmpLooprangeClause>(*rangeAt)};
+    std::optional<int64_t> count{GetIntValue(std::get<1>(loopRange->t))};
+    if (!count || *count <= 0) {
+      return std::nullopt;
+    }
+    if (auto nestedCount{CountGeneratedLoops(std::get<parser::Block>(omp.t))}) 
{
+      return 1 + *nestedCount - static_cast<size_t>(*count);
----------------
tblah wrote:

Could this subtraction wrap for erroneous code with a bad looprange clause?

https://github.com/llvm/llvm-project/pull/170735
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to