Rework after review: codegen for the 'if' clause of '#pragma omp parallel'.

http://reviews.llvm.org/D4716

Files:
  include/clang/AST/StmtOpenMP.h
  lib/AST/Stmt.cpp
  lib/CodeGen/CGOpenMPRuntime.cpp
  lib/CodeGen/CGOpenMPRuntime.h
  lib/CodeGen/CGStmtOpenMP.cpp
  lib/CodeGen/CodeGenFunction.h
  test/OpenMP/parallel_if_codegen.cpp
Index: test/OpenMP/parallel_if_codegen.cpp
===================================================================
--- test/OpenMP/parallel_if_codegen.cpp
+++ test/OpenMP/parallel_if_codegen.cpp
@@ -0,0 +1,128 @@
+// RUN: %clang_cc1 -verify -fopenmp=libiomp5 -x c++ -emit-llvm %s -o - | FileCheck %s
+// RUN: %clang_cc1 -fopenmp=libiomp5 -x c++ -std=c++11 -triple x86_64-unknown-unknown -emit-pch -o %t %s
+// RUN: %clang_cc1 -fopenmp=libiomp5 -x c++ -triple x86_64-unknown-unknown -std=c++11 -include-pch %t -verify %s -emit-llvm -o - | FileCheck --check-prefix=CHECK %s
+// expected-no-diagnostics
+#ifndef HEADER
+#define HEADER
+
+void fn1();
+void fn2();
+void fn3();
+void fn4();
+void fn5();
+void fn6();
+
+int Arg;
+
+// A nested 'parallel if (false)' inside an outlined region runs serially and
+// reuses the thread id passed to the enclosing outlined function.
+// CHECK-LABEL: define void @{{.+}}gtid_test
+void gtid_test() {
+// CHECK:  call void {{.+}}* @__kmpc_fork_call(%{{.+}}* @{{.+}}, i{{.+}} 1, {{.+}}* [[GTID_TEST_REGION1:@.+]] to void
+#pragma omp parallel
+#pragma omp parallel if (false)
+  gtid_test();
+// CHECK: ret void
+}
+
+// CHECK: define internal void [[GTID_TEST_REGION1]](i{{.+}}* [[GTID_PARAM:%.+]], i
+// CHECK: store i{{[0-9]+}}* [[GTID_PARAM]], i{{[0-9]+}}** [[GTID_ADDR_REF:%.+]],
+// CHECK: [[GTID_ADDR:%.+]] = load i{{[0-9]+}}** [[GTID_ADDR_REF]]
+// CHECK: [[GTID:%.+]] = load i{{[0-9]+}}* [[GTID_ADDR]]
+// CHECK: call void @__kmpc_serialized_parallel(%{{.+}}* @{{.+}}, i{{.+}} [[GTID]])
+// CHECK: [[GTID_ADDR:%.+]] = load i{{[0-9]+}}** [[GTID_ADDR_REF]]
+// CHECK: call void [[GTID_TEST_REGION2:@.+]](i{{[0-9]+}}* [[GTID_ADDR]]
+// CHECK: call void @__kmpc_end_serialized_parallel(%{{.+}}* @{{.+}}, i{{.+}} [[GTID]])
+// CHECK: ret void
+
+// CHECK: define internal void [[GTID_TEST_REGION2]](
+// CHECK: call void @{{.+}}gtid_test
+// CHECK: ret void
+
+template <typename T>
+int tmain(T Arg) {
+#pragma omp parallel if (true)
+  fn1();
+#pragma omp parallel if (false)
+  fn2();
+#pragma omp parallel if (Arg)
+  fn3();
+  return 0;
+}
+
+// 'if (true)' always forks, 'if (false)' calls the outlined function serially
+// and a non-constant condition branches between both paths at run time.
+// CHECK-LABEL: define {{[a-z]*[ ]?i32}} @main()
+int main() {
+// CHECK: [[GTID:%.+]] = call i32 @__kmpc_global_thread_num(
+// CHECK: call void {{.+}} @__kmpc_fork_call(%{{.+}}* @{{.+}}, i{{.+}} 1, void {{.+}}* [[CAP_FN4:@.+]] to void
+#pragma omp parallel if (true)
+  fn4();
+// CHECK: call void @__kmpc_serialized_parallel(%{{.+}}* @{{.+}}, i32 [[GTID]])
+// CHECK: store i32 [[GTID]], i32* [[GTID_ADDR:%.+]],
+// CHECK: call void [[CAP_FN5:@.+]](i32* [[GTID_ADDR]],
+// CHECK: call void @__kmpc_end_serialized_parallel(%{{.+}}* @{{.+}}, i32 [[GTID]])
+#pragma omp parallel if (false)
+  fn5();
+
+// CHECK: br i1 %{{.+}}, label %[[OMP_THEN:.+]], label %[[OMP_ELSE:.+]]
+// CHECK: [[OMP_THEN]]:
+// CHECK: call void {{.+}} @__kmpc_fork_call(%{{.+}}* @{{.+}}, i{{.+}} 1, void {{.+}}* [[CAP_FN6:@.+]] to void
+// CHECK: br label %[[OMP_END:.+]]
+// CHECK: [[OMP_ELSE]]:
+// CHECK: call void @__kmpc_serialized_parallel(%{{.+}}* @{{.+}}, i32 [[GTID]])
+// CHECK: store i32 [[GTID]], i32* [[GTID_ADDR:%.+]],
+// CHECK: call void [[CAP_FN6]](i32* [[GTID_ADDR]],
+// CHECK: call void @__kmpc_end_serialized_parallel(%{{.+}}* @{{.+}}, i32 [[GTID]])
+// CHECK: br label %[[OMP_END]]
+// CHECK: [[OMP_END]]:
+#pragma omp parallel if (Arg)
+  fn6();
+  // CHECK: = call i{{.+}} @{{.+}}tmain
+  return tmain(Arg);
+}
+
+// CHECK: define internal void [[CAP_FN4]]
+// CHECK: call void @{{.+}}fn4
+// CHECK: ret void
+
+// CHECK: define internal void [[CAP_FN5]]
+// CHECK: call void @{{.+}}fn5
+// CHECK: ret void
+
+// CHECK: define internal void [[CAP_FN6]]
+// CHECK: call void @{{.+}}fn6
+// CHECK: ret void
+
+// CHECK-LABEL: define {{.+}} @{{.+}}tmain
+// CHECK: [[GTID:%.+]] = call i32 @__kmpc_global_thread_num(
+// CHECK: call void {{.+}} @__kmpc_fork_call(%{{.+}}* @{{.+}}, i{{.+}} 1, void {{.+}}* [[CAP_FN1:@.+]] to void
+// CHECK: call void @__kmpc_serialized_parallel(%{{.+}}* @{{.+}}, i32 [[GTID]])
+// CHECK: store i32 [[GTID]], i32* [[GTID_ADDR:%.+]],
+// CHECK: call void [[CAP_FN2:@.+]](i32* [[GTID_ADDR]],
+// CHECK: call void @__kmpc_end_serialized_parallel(%{{.+}}* @{{.+}}, i32 [[GTID]])
+// CHECK: br i1 %{{.+}}, label %[[OMP_THEN:.+]], label %[[OMP_ELSE:.+]]
+// CHECK: [[OMP_THEN]]:
+// CHECK: call void {{.+}} @__kmpc_fork_call(%{{.+}}* @{{.+}}, i{{.+}} 1, void {{.+}}* [[CAP_FN3:@.+]] to void
+// CHECK: br label %[[OMP_END:.+]]
+// CHECK: [[OMP_ELSE]]:
+// CHECK: call void @__kmpc_serialized_parallel(%{{.+}}* @{{.+}}, i32 [[GTID]])
+// CHECK: store i32 [[GTID]], i32* [[GTID_ADDR:%.+]],
+// CHECK: call void [[CAP_FN3]](i32* [[GTID_ADDR]],
+// CHECK: call void @__kmpc_end_serialized_parallel(%{{.+}}* @{{.+}}, i32 [[GTID]])
+// CHECK: br label %[[OMP_END]]
+// CHECK: [[OMP_END]]:
+
+// CHECK: define internal void [[CAP_FN1]]
+// CHECK: call void @{{.+}}fn1
+// CHECK: ret void
+
+// CHECK: define internal void [[CAP_FN2]]
+// CHECK: call void @{{.+}}fn2
+// CHECK: ret void
+
+// CHECK: define internal void [[CAP_FN3]]
+// CHECK: call void @{{.+}}fn3
+// CHECK: ret void
+
+#endif
Index: include/clang/AST/StmtOpenMP.h
===================================================================
--- include/clang/AST/StmtOpenMP.h
+++ include/clang/AST/StmtOpenMP.h
@@ -128,6 +128,10 @@
     operator bool() { return Current != End; }
   };
 
+  /// \brief Gets the single clause of kind \a K associated with the current
+  /// directive, or nullptr if there is none (more than one is an error).
+  const OMPClause *getSingleClause(OpenMPClauseKind K) const;
+
   /// \brief Returns starting location of directive kind.
   SourceLocation getLocStart() const { return StartLoc; }
   /// \brief Returns ending location of directive.
Index: lib/AST/Stmt.cpp
===================================================================
--- lib/AST/Stmt.cpp
+++ lib/AST/Stmt.cpp
@@ -1350,6 +1350,21 @@
   return new (Mem) OMPFlushClause(N);
 }
 
+const OMPClause *
+OMPExecutableDirective::getSingleClause(OpenMPClauseKind K) const {
+  // Returns the only clause of kind K, or nullptr if there is none.
+  auto ClauseFilter =
+      [K](const OMPClause *C) -> bool { return C->getClauseKind() == K; };
+  OMPExecutableDirective::filtered_clause_iterator<decltype(ClauseFilter)> I(
+      clauses(), ClauseFilter);
+  if (!I)
+    return nullptr;
+  auto *Clause = *I;
+  // More than one clause of kind K is not supported by this accessor.
+  assert(!++I && "There are at least 2 clauses of the specified kind");
+  return Clause;
+}
+
 OMPParallelDirective *OMPParallelDirective::Create(
                                               const ASTContext &C,
                                               SourceLocation StartLoc,
Index: lib/CodeGen/CGOpenMPRuntime.cpp
===================================================================
--- lib/CodeGen/CGOpenMPRuntime.cpp
+++ lib/CodeGen/CGOpenMPRuntime.cpp
@@ -15,6 +15,7 @@
 #include "CodeGenFunction.h"
 #include "clang/AST/Decl.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/IR/CallSite.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/GlobalValue.h"
 #include "llvm/IR/Value.h"
@@ -24,6 +25,12 @@
 using namespace clang;
 using namespace CodeGen;
 
+LValue CGOpenMPRegionInfo::getThreadIDVariableLValue(CodeGenFunction &CGF) {
+  // The lvalue designates the thread-id parameter itself (a kmp_int32 *).
+  return CGF.MakeNaturalAlignAddrLValue(CGF.GetAddrOfLocalVar(ThreadIDVar),
+                                        ThreadIDVar->getType());
+}
+
 CGOpenMPRuntime::CGOpenMPRuntime(CodeGenModule &CGM)
     : CGM(CGM), DefaultOpenMPPSource(nullptr) {
   IdentTy = llvm::StructType::create(
@@ -119,49 +126,41 @@
   return LocValue;
 }
 
-llvm::Value *CGOpenMPRuntime::GetOpenMPGlobalThreadNum(CodeGenFunction &CGF,
-                                                       SourceLocation Loc) {
+llvm::Value *CGOpenMPRuntime::GetOpenMPThreadID(CodeGenFunction &CGF,
+                                                SourceLocation Loc) {
   assert(CGF.CurFn && "No function in current CodeGenFunction.");
 
-  llvm::Value *GTid = nullptr;
-  OpenMPGtidMapTy::iterator I = OpenMPGtidMap.find(CGF.CurFn);
-  if (I != OpenMPGtidMap.end()) {
-    GTid = I->second;
+  llvm::Value *ThreadID = nullptr;
+  OpenMPThreadIDMapTy::iterator I = OpenMPThreadIDMap.find(CGF.CurFn);
+  if (I != OpenMPThreadIDMap.end()) {
+    ThreadID = I->second;
+  } else if (auto OMPRegionInfo =
+                 dyn_cast_or_null<CGOpenMPRegionInfo>(CGF.CapturedStmtInfo)) {
+    // Load the kmp_int32* parameter, then the thread id it points to.
+    auto ThreadIDVar = OMPRegionInfo->getThreadIDVariable();
+    auto IDTy = ThreadIDVar->getType()->castAs<PointerType>()->getPointeeType();
+    auto LVal = OMPRegionInfo->getThreadIDVariableLValue(CGF);
+    LVal = CGF.MakeNaturalAlignAddrLValue(
+        CGF.EmitLoadOfLValue(LVal, Loc).getScalarVal(), IDTy);
+    ThreadID = CGF.EmitLoadOfLValue(LVal, Loc).getScalarVal();
+    if (CGF.Builder.GetInsertBlock() == CGF.AllocaInsertPt->getParent())
+      OpenMPThreadIDMap[CGF.CurFn] = ThreadID; // Entry block: reuse everywhere.
   } else {
-    // Check if current function is a function which has first parameter
-    // with type int32 and name ".global_tid.".
-    if (!CGF.CurFn->arg_empty() &&
-        CGF.CurFn->arg_begin()->getType()->isPointerTy() &&
-        CGF.CurFn->arg_begin()
-            ->getType()
-            ->getPointerElementType()
-            ->isIntegerTy() &&
-        CGF.CurFn->arg_begin()
-                ->getType()
-                ->getPointerElementType()
-                ->getIntegerBitWidth() == 32 &&
-        CGF.CurFn->arg_begin()->hasName() &&
-        CGF.CurFn->arg_begin()->getName() == ".global_tid.") {
-      CGBuilderTy::InsertPointGuard IPG(CGF.Builder);
-      CGF.Builder.SetInsertPoint(CGF.AllocaInsertPt);
-      GTid = CGF.Builder.CreateLoad(CGF.CurFn->arg_begin());
-    } else {
-      // Generate "int32 .kmpc_global_thread_num.addr;"
-      CGBuilderTy::InsertPointGuard IPG(CGF.Builder);
-      CGF.Builder.SetInsertPoint(CGF.AllocaInsertPt);
-      llvm::Value *Args[] = {EmitOpenMPUpdateLocation(CGF, Loc)};
-      GTid = CGF.EmitRuntimeCall(
-          CreateRuntimeFunction(OMPRTL__kmpc_global_thread_num), Args);
-    }
-    OpenMPGtidMap[CGF.CurFn] = GTid;
+    // Serial code: call __kmpc_global_thread_num once, at the alloca point.
+    CGBuilderTy::InsertPointGuard IPG(CGF.Builder);
+    CGF.Builder.SetInsertPoint(CGF.AllocaInsertPt);
+    llvm::Value *Args[] = {EmitOpenMPUpdateLocation(CGF, Loc)};
+    ThreadID = CGF.EmitRuntimeCall(
+        CreateRuntimeFunction(OMPRTL__kmpc_global_thread_num), Args);
+    OpenMPThreadIDMap[CGF.CurFn] = ThreadID;
   }
-  return GTid;
+  return ThreadID;
 }
 
 void CGOpenMPRuntime::FunctionFinished(CodeGenFunction &CGF) {
   assert(CGF.CurFn && "No function in current CodeGenFunction.");
-  if (OpenMPGtidMap.count(CGF.CurFn))
-    OpenMPGtidMap.erase(CGF.CurFn);
+  // Drop any cached thread id; DenseMap::erase ignores a missing key.
+  OpenMPThreadIDMap.erase(CGF.CurFn);
   if (OpenMPLocMap.count(CGF.CurFn))
     OpenMPLocMap.erase(CGF.CurFn);
 }
@@ -184,18 +183,102 @@
     llvm::Type *TypeParams[] = {getIdentTyPointerTy(), CGM.Int32Ty,
                                 getKmpc_MicroPointerTy()};
     llvm::FunctionType *FnTy =
-        llvm::FunctionType::get(CGM.VoidTy, TypeParams, true);
+        llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ true);
     RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_fork_call");
     break;
   }
   case OMPRTL__kmpc_global_thread_num: {
     // Build kmp_int32 __kmpc_global_thread_num(ident_t *loc);
     llvm::Type *TypeParams[] = {getIdentTyPointerTy()};
     llvm::FunctionType *FnTy =
-        llvm::FunctionType::get(CGM.Int32Ty, TypeParams, false);
+        llvm::FunctionType::get(CGM.Int32Ty, TypeParams, /*isVarArg*/ false);
     RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_global_thread_num");
     break;
   }
+  case OMPRTL__kmpc_serialized_parallel: {
+    // Build void __kmpc_serialized_parallel(ident_t *loc, kmp_int32
+    // global_tid);
+    llvm::Type *TypeParams[] = {getIdentTyPointerTy(), CGM.Int32Ty};
+    llvm::FunctionType *FnTy =
+        llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
+    RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_serialized_parallel");
+    break;
+  }
+  case OMPRTL__kmpc_end_serialized_parallel: {
+    // Build void __kmpc_end_serialized_parallel(ident_t *loc, kmp_int32
+    // global_tid);
+    llvm::Type *TypeParams[] = {getIdentTyPointerTy(), CGM.Int32Ty};
+    llvm::FunctionType *FnTy =
+        llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
+    RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_end_serialized_parallel");
+    break;
+  }
   }
   return RTLFn;
 }
+
+void CGOpenMPRuntime::EmitOMPParallelCall(CodeGenFunction &CGF,
+                                          SourceLocation Loc,
+                                          llvm::Value *OutlinedFn,
+                                          llvm::Value *CapturedStruct) {
+  // __kmpc_fork_call(loc, 1, microtask, captured_struct/*context*/); the '1'
+  // counts the arguments after 'microtask' (only the 'context' pointer).
+  llvm::Value *IdentLoc = EmitOpenMPUpdateLocation(CGF, Loc);
+  llvm::Value *NumArgs = CGF.Builder.getInt32(1);
+  llvm::Value *Microtask =
+      CGF.Builder.CreateBitCast(OutlinedFn, getKmpc_MicroPointerTy());
+  llvm::Value *Context = CGF.EmitCastToVoidPtr(CapturedStruct);
+  llvm::Value *Args[] = {IdentLoc, NumArgs, Microtask, Context};
+  CGF.EmitRuntimeCall(CreateRuntimeFunction(OMPRTL__kmpc_fork_call), Args);
+}
+
+void CGOpenMPRuntime::EmitOMPSerialCall(CodeGenFunction &CGF,
+                                        SourceLocation Loc,
+                                        llvm::Value *OutlinedFn,
+                                        llvm::Value *CapturedStruct) {
+  // Calls the outlined region directly on the current thread, bracketed by
+  // the runtime's serialized-parallel entry and exit calls.
+  llvm::Value *GTid = GetOpenMPThreadID(CGF, Loc);
+  // __kmpc_serialized_parallel(&Loc, GTid);
+  llvm::Value *BeginArgs[] = {EmitOpenMPUpdateLocation(CGF, Loc), GTid};
+  CGF.EmitRuntimeCall(CreateRuntimeFunction(OMPRTL__kmpc_serialized_parallel),
+                      BeginArgs);
+
+  // OutlinedFn(&GTid, &zero, CapturedStruct);
+  llvm::Value *GTidAddr = EmitThreadIDAddress(CGF, Loc);
+  QualType KmpInt32Ty =
+      CGF.getContext().getIntTypeForBitwidth(/*DestWidth*/ 32, /*Signed*/ true);
+  auto *ZeroAddr = CGF.CreateMemTemp(KmpInt32Ty, /*Name*/ ".zero.addr");
+  CGF.InitTempAlloca(ZeroAddr, CGF.Builder.getInt32(/*C*/ 0));
+  llvm::Value *CallArgs[] = {GTidAddr, ZeroAddr, CapturedStruct};
+  CGF.EmitCallOrInvoke(OutlinedFn, CallArgs);
+  // __kmpc_end_serialized_parallel(&Loc, GTid);
+  llvm::Value *EndArgs[] = {EmitOpenMPUpdateLocation(CGF, Loc), GTid};
+  CGF.EmitRuntimeCall(
+      CreateRuntimeFunction(OMPRTL__kmpc_end_serialized_parallel), EndArgs);
+}
+
+// Returns the address of the current thread id, suitable for passing as the
+// first argument ("kmp_int32 *gtid") of an outlined function.
+// Inside an (outlined) parallel region, that address is the value of the
+// region info's thread-id variable (the outlined function's first argument).
+// Otherwise, in regular serial code, the thread id is obtained through
+// GetOpenMPThreadID() (ultimately "kmp_int32 __kmpc_global_thread_num(ident_t
+// *loc)"), stored in a temporary, and the address of that temporary is
+// returned.
+llvm::Value *CGOpenMPRuntime::EmitThreadIDAddress(CodeGenFunction &CGF,
+                                                  SourceLocation Loc) {
+  if (auto OMPRegionInfo =
+          dyn_cast_or_null<CGOpenMPRegionInfo>(CGF.CapturedStmtInfo))
+    return CGF.EmitLoadOfLValue(OMPRegionInfo->getThreadIDVariableLValue(CGF),
+                                Loc).getScalarVal();
+  auto ThreadID = GetOpenMPThreadID(CGF, Loc);
+  auto Int32Ty =
+      CGF.getContext().getIntTypeForBitwidth(/*DestWidth*/ 32, /*Signed*/ true);
+  auto ThreadIDTemp = CGF.CreateMemTemp(Int32Ty, /*Name*/ ".threadid_temp.");
+  CGF.EmitStoreOfScalar(ThreadID,
+                        CGF.MakeNaturalAlignAddrLValue(ThreadIDTemp, Int32Ty));
+
+  return ThreadIDTemp;
+}
+
Index: lib/CodeGen/CodeGenFunction.h
===================================================================
--- lib/CodeGen/CodeGenFunction.h
+++ lib/CodeGen/CodeGenFunction.h
@@ -224,6 +224,8 @@
     /// \brief Get the name of the capture helper.
     virtual StringRef getHelperName() const { return "__captured_stmt"; }
 
+    /// \brief LLVM-style RTTI root: every info is a CGCapturedStmtInfo.
+    static bool classof(const CGCapturedStmtInfo *) { return true; }
   private:
     /// \brief The kind of captured statement being generated.
     CapturedRegionKind Kind;
Index: lib/CodeGen/CGStmtOpenMP.cpp
===================================================================
--- lib/CodeGen/CGStmtOpenMP.cpp
+++ lib/CodeGen/CGStmtOpenMP.cpp
@@ -23,29 +23,71 @@
 //                              OpenMP Directive Emission
 //===----------------------------------------------------------------------===//
 
+static void EmitOMPIfClause(CodeGenFunction &CGF, const Expr *Cond,
+                            std::function<void(bool)> CodeGen) {
+  // Emits an OpenMP 'if' clause: CodeGen(true) produces the code for a true
+  // condition and CodeGen(false) the code for a false one.
+  CodeGenFunction::LexicalScope ConditionScope(CGF, Cond->getSourceRange());
+
+  // A condition that folds to a constant selects its arm statically, so
+  // neither the branch nor the dead arm is emitted.
+  bool FoldedValue;
+  if (CGF.ConstantFoldsToSimpleInteger(Cond, FoldedValue)) {
+    CodeGen(FoldedValue);
+    return;
+  }
+
+  // Otherwise branch at run time between the two arms.
+  auto *TrueBB = CGF.createBasicBlock(/*name*/ "omp_if.then");
+  auto *FalseBB = CGF.createBasicBlock(/*name*/ "omp_if.else");
+  auto *MergeBB = CGF.createBasicBlock(/*name*/ "omp_if.end");
+  CGF.EmitBranchOnBoolExpr(Cond, TrueBB, FalseBB, /*TrueCount*/ 0);
+
+  // True arm.
+  CGF.EmitBlock(TrueBB);
+  CodeGen(/*ThenBlock*/ true);
+  CGF.EmitBranch(MergeBB);
+  // False arm (always emitted); no debug location on its entry or exit.
+  {
+    SuppressDebugLocation NoDebugLoc(CGF.Builder);
+    CGF.EmitBlock(FalseBB);
+  }
+  CodeGen(/*ThenBlock*/ false);
+  {
+    SuppressDebugLocation NoDebugLoc(CGF.Builder);
+    CGF.EmitBranch(MergeBB);
+  }
+
+  // Code after the directive continues in the merge block.
+  CGF.EmitBlock(MergeBB, /*IsFinished*/ true);
+}
+
 void CodeGenFunction::EmitOMPParallelDirective(const OMPParallelDirective &S) {
   const CapturedStmt *CS = cast<CapturedStmt>(S.getAssociatedStmt());
   llvm::Value *CapturedStruct = GenerateCapturedStmtArgument(*CS);
 
   llvm::Value *OutlinedFn;
   {
-    CodeGenFunction CGF(CGM, true);
-    CGCapturedStmtInfo CGInfo(*CS, CS->getCapturedRegionKind());
+    CodeGenFunction CGF(CGM, /*suppressNewContext*/ true);
+    CGOpenMPRegionInfo CGInfo(*CS, *CS->getCapturedDecl()->param_begin());
     CGF.CapturedStmtInfo = &CGInfo;
     OutlinedFn = CGF.GenerateCapturedStmtFunction(*CS);
   }
 
-  // Build call __kmpc_fork_call(loc, 1, microtask, captured_struct/*context*/)
-  llvm::Value *Args[] = {
-      CGM.getOpenMPRuntime().EmitOpenMPUpdateLocation(*this, S.getLocStart()),
-      Builder.getInt32(1), // Number of arguments after 'microtask' argument
-      // (there is only one additional argument - 'context')
-      Builder.CreateBitCast(OutlinedFn,
-                            CGM.getOpenMPRuntime().getKmpc_MicroPointerTy()),
-      EmitCastToVoidPtr(CapturedStruct)};
-  llvm::Constant *RTLFn = CGM.getOpenMPRuntime().CreateRuntimeFunction(
-      CGOpenMPRuntime::OMPRTL__kmpc_fork_call);
-  EmitRuntimeCall(RTLFn, Args);
+  if (auto C = S.getSingleClause(/*K*/ OMPC_if)) {
+    auto Cond = cast<OMPIfClause>(C)->getCondition();
+    EmitOMPIfClause(*this, Cond, [&](bool ThenBlock) {
+      if (ThenBlock)
+        CGM.getOpenMPRuntime().EmitOMPParallelCall(*this, S.getLocStart(),
+                                                   OutlinedFn, CapturedStruct);
+      else
+        CGM.getOpenMPRuntime().EmitOMPSerialCall(*this, S.getLocStart(),
+                                                 OutlinedFn, CapturedStruct);
+    });
+  } else {
+    CGM.getOpenMPRuntime().EmitOMPParallelCall(*this, S.getLocStart(),
+                                               OutlinedFn, CapturedStruct);
+  }
 }
 
 void CodeGenFunction::EmitOMPSimdDirective(const OMPSimdDirective &S) {
Index: lib/CodeGen/CGOpenMPRuntime.h
===================================================================
--- lib/CodeGen/CGOpenMPRuntime.h
+++ lib/CodeGen/CGOpenMPRuntime.h
@@ -14,31 +14,40 @@
 #ifndef CLANG_CODEGEN_OPENMPRUNTIME_H
 #define CLANG_CODEGEN_OPENMPRUNTIME_H
 
+#include "CodeGenFunction.h"
 #include "clang/AST/Type.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 
-namespace llvm {
-class AllocaInst;
-class CallInst;
-class GlobalVariable;
-class Constant;
-class Function;
-class Module;
-class StructLayout;
-class FunctionType;
-class StructType;
-class Type;
-class Value;
-} // namespace llvm
-
 namespace clang {
 
 namespace CodeGen {
 
-class CodeGenFunction;
-class CodeGenModule;
+/// \brief API for captured statement code generation in OpenMP constructs.
+class CGOpenMPRegionInfo : public CodeGenFunction::CGCapturedStmtInfo {
+public:
+  CGOpenMPRegionInfo(const CapturedStmt &S, const VarDecl *ThreadIDVar)
+      : CGCapturedStmtInfo(S, CR_OpenMP), ThreadIDVar(ThreadIDVar) {
+    assert(ThreadIDVar != nullptr && "No ThreadID in OpenMP region.");
+  }
+
+  /// \brief Gets a variable or parameter holding a pointer to the global
+  /// thread identity of the thread executing the OpenMP construct. The type
+  /// of this variable is a pointer to kmp_int32, not kmp_int32 itself.
+  const VarDecl *getThreadIDVariable() const { return ThreadIDVar; }
+  /// \brief Gets an LValue for the current ThreadID variable (the pointer).
+  LValue getThreadIDVariableLValue(CodeGenFunction &CGF);
+
+  static bool classof(const CGCapturedStmtInfo *Info) {
+    return Info->getKind() == CR_OpenMP;
+  }
+
+private:
+  /// \brief A variable or parameter storing a pointer to the global thread id
+  /// for OpenMP constructs.
+  const VarDecl *ThreadIDVar;
+};
 
 class CGOpenMPRuntime {
 public:
@@ -68,7 +77,13 @@
     // microtask, ...);
     OMPRTL__kmpc_fork_call,
     // Call to kmp_int32 kmpc_global_thread_num(ident_t *loc);
-    OMPRTL__kmpc_global_thread_num
+    OMPRTL__kmpc_global_thread_num,
+    // Call to void __kmpc_serialized_parallel(ident_t *loc, kmp_int32
+    // global_tid);
+    OMPRTL__kmpc_serialized_parallel,
+    // Call to void __kmpc_end_serialized_parallel(ident_t *loc, kmp_int32
+    // global_tid);
+    OMPRTL__kmpc_end_serialized_parallel
   };
 
 private:
@@ -131,18 +146,21 @@
   /// \brief Map of local debug location and functions.
   typedef llvm::DenseMap<llvm::Function *, llvm::Value *> OpenMPLocMapTy;
   OpenMPLocMapTy OpenMPLocMap;
-  /// \brief Map of local gtid and functions.
-  typedef llvm::DenseMap<llvm::Function *, llvm::Value *> OpenMPGtidMapTy;
-  OpenMPGtidMapTy OpenMPGtidMap;
+  /// \brief Map from a function to the thread id value cached for it.
+  typedef llvm::DenseMap<llvm::Function *, llvm::Value *> OpenMPThreadIDMapTy;
+  OpenMPThreadIDMapTy OpenMPThreadIDMap;
 
-public:
-  explicit CGOpenMPRuntime(CodeGenModule &CGM);
-  ~CGOpenMPRuntime() {}
+  /// \brief Returns pointer to ident_t type.
+  llvm::Type *getIdentTyPointerTy();
 
-  /// \brief Cleans up references to the objects in finished function.
-  /// \param CGF Reference to finished CodeGenFunction.
+  /// \brief Returns pointer to kmpc_micro type.
+  llvm::Type *getKmpc_MicroPointerTy();
+
+  /// \brief Gets the thread id value (kmp_int32) for the current thread.
+  /// \param CGF Reference to current CodeGenFunction.
+  /// \param Loc Clang source location.
   ///
-  void FunctionFinished(CodeGenFunction &CGF);
+  llvm::Value *GetOpenMPThreadID(CodeGenFunction &CGF, SourceLocation Loc);
 
   /// \brief Emits object of ident_t type with info for source location.
   /// \param CGF Reference to current CodeGenFunction.
@@ -153,23 +171,50 @@
   EmitOpenMPUpdateLocation(CodeGenFunction &CGF, SourceLocation Loc,
                            OpenMPLocationFlags Flags = OMP_IDENT_KMPC);
 
-  /// \brief Generates global thread number value.
+  /// \brief Returns specified OpenMP runtime function.
+  /// \param Function OpenMP runtime function.
+  /// \return Specified function.
+  llvm::Constant *CreateRuntimeFunction(OpenMPRTLFunction Function);
+
+public:
+  explicit CGOpenMPRuntime(CodeGenModule &CGM);
+  virtual ~CGOpenMPRuntime() {}
+
+  /// \brief Erases the cached thread id and location entries of the finished
+  /// function.
+  /// \param CGF Reference to finished CodeGenFunction.
+  void FunctionFinished(CodeGenFunction &CGF);
+
+  /// \brief Emits a __kmpc_fork_call that runs \a OutlinedFn in parallel with
+  /// variables captured in a record whose address is \a CapturedStruct.
   /// \param CGF Reference to current CodeGenFunction.
   /// \param Loc Clang source location.
+  /// \param OutlinedFn Outlined function to be run in parallel threads.
+  /// \param CapturedStruct A pointer to the record with the references to
+  /// variables used in \a OutlinedFn function.
   ///
-  llvm::Value *GetOpenMPGlobalThreadNum(CodeGenFunction &CGF,
-                                        SourceLocation Loc);
-
-  /// \brief Returns pointer to ident_t type;
-  llvm::Type *getIdentTyPointerTy();
+  virtual void EmitOMPParallelCall(CodeGenFunction &CGF, SourceLocation Loc,
+                                   llvm::Value *OutlinedFn,
+                                   llvm::Value *CapturedStruct);
 
-  /// \brief Returns pointer to kmpc_micro type;
-  llvm::Type *getKmpc_MicroPointerTy();
+  /// \brief Emits a serial call of \a OutlinedFn on the current thread,
+  /// wrapped in __kmpc_serialized_parallel/__kmpc_end_serialized_parallel.
+  /// \param CGF Reference to current CodeGenFunction.
+  /// \param Loc Clang source location.
+  /// \param OutlinedFn Outlined function to be run in serial mode.
+  /// \param CapturedStruct A pointer to the record with the references to
+  /// variables used in \a OutlinedFn function.
+  ///
+  virtual void EmitOMPSerialCall(CodeGenFunction &CGF, SourceLocation Loc,
+                                 llvm::Value *OutlinedFn,
+                                 llvm::Value *CapturedStruct);
 
-  /// \brief Returns specified OpenMP runtime function.
-  /// \param Function OpenMP runtime function.
-  /// \return Specified function.
-  llvm::Constant *CreateRuntimeFunction(OpenMPRTLFunction Function);
+  /// \brief Emits the address of the memory location holding the current
+  /// thread id (a kmp_int32), e.g. for passing to an outlined function.
+  /// \param CGF Reference to current CodeGenFunction.
+  /// \param Loc Clang source location.
+  virtual llvm::Value *EmitThreadIDAddress(CodeGenFunction &CGF,
+                                           SourceLocation Loc);
 };
 } // namespace CodeGen
 } // namespace clang
_______________________________________________
cfe-commits mailing list
[email protected]
http://lists.cs.uiuc.edu/mailman/listinfo/cfe-commits

Reply via email to