Closed by commit rL219597 (authored by @ABataev).

REPOSITORY
  rL LLVM

http://reviews.llvm.org/D4716

Files:
  cfe/trunk/include/clang/AST/StmtOpenMP.h
  cfe/trunk/lib/AST/Stmt.cpp
  cfe/trunk/lib/CodeGen/CGOpenMPRuntime.cpp
  cfe/trunk/lib/CodeGen/CGOpenMPRuntime.h
  cfe/trunk/lib/CodeGen/CGStmtOpenMP.cpp
  cfe/trunk/test/OpenMP/parallel_if_codegen.cpp
Index: cfe/trunk/lib/CodeGen/CGOpenMPRuntime.h
===================================================================
--- cfe/trunk/lib/CodeGen/CGOpenMPRuntime.h
+++ cfe/trunk/lib/CodeGen/CGOpenMPRuntime.h
@@ -74,7 +74,13 @@
     // kmp_critical_name *crit);
     OMPRTL__kmpc_end_critical,
     // Call to void __kmpc_barrier(ident_t *loc, kmp_int32 global_tid);
-    OMPRTL__kmpc_barrier
+    OMPRTL__kmpc_barrier,
+    // Call to void __kmpc_serialized_parallel(ident_t *loc, kmp_int32
+    // global_tid);
+    OMPRTL__kmpc_serialized_parallel,
+    // Call to void __kmpc_end_serialized_parallel(ident_t *loc, kmp_int32
+    // global_tid);
+    OMPRTL__kmpc_end_serialized_parallel
   };
 
 private:
@@ -156,17 +162,22 @@
   EmitOpenMPUpdateLocation(CodeGenFunction &CGF, SourceLocation Loc,
                            OpenMPLocationFlags Flags = OMP_IDENT_KMPC);
 
-  /// \brief Returns pointer to ident_t type;
+  /// \brief Returns pointer to ident_t type.
   llvm::Type *getIdentTyPointerTy();
 
-  /// \brief Returns pointer to kmpc_micro type;
+  /// \brief Returns pointer to kmpc_micro type.
   llvm::Type *getKmpc_MicroPointerTy();
 
   /// \brief Returns specified OpenMP runtime function.
   /// \param Function OpenMP runtime function.
   /// \return Specified function.
   llvm::Constant *CreateRuntimeFunction(OpenMPRTLFunction Function);
 
+  /// \brief Emits the address of the word in memory where the current thread
+  /// id is stored.
+  virtual llvm::Value *EmitThreadIDAddress(CodeGenFunction &CGF,
+                                           SourceLocation Loc);
+
   /// \brief Gets thread id value for the current thread.
   ///
   llvm::Value *GetOpenMPThreadID(CodeGenFunction &CGF, SourceLocation Loc);
@@ -201,6 +212,16 @@
                                    llvm::Value *OutlinedFn,
                                    llvm::Value *CapturedStruct);
 
+  /// \brief Emits code for serial call of the \a OutlinedFn with variables
+  /// captured in a record whose address is stored in \a CapturedStruct.
+  /// \param OutlinedFn Outlined function to be run in serial mode.
+  /// \param CapturedStruct A pointer to the record with the references to
+  /// variables used in \a OutlinedFn function.
+  ///
+  virtual void EmitOMPSerialCall(CodeGenFunction &CGF, SourceLocation Loc,
+                                 llvm::Value *OutlinedFn,
+                                 llvm::Value *CapturedStruct);
+
   /// \brief Returns corresponding lock object for the specified critical region
   /// name. If the lock object does not exist it is created, otherwise the
   /// reference to the existing copy is returned.
Index: cfe/trunk/lib/CodeGen/CGOpenMPRuntime.cpp
===================================================================
--- cfe/trunk/lib/CodeGen/CGOpenMPRuntime.cpp
+++ cfe/trunk/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -16,6 +16,7 @@
 #include "clang/AST/StmtOpenMP.h"
 #include "clang/AST/Decl.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/IR/CallSite.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/GlobalValue.h"
 #include "llvm/IR/Value.h"
@@ -253,15 +254,15 @@
     llvm::Type *TypeParams[] = {getIdentTyPointerTy(), CGM.Int32Ty,
                                 getKmpc_MicroPointerTy()};
     llvm::FunctionType *FnTy =
-        llvm::FunctionType::get(CGM.VoidTy, TypeParams, true);
+        llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ true);
     RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_fork_call");
     break;
   }
   case OMPRTL__kmpc_global_thread_num: {
     // Build kmp_int32 __kmpc_global_thread_num(ident_t *loc);
     llvm::Type *TypeParams[] = {getIdentTyPointerTy()};
     llvm::FunctionType *FnTy =
-        llvm::FunctionType::get(CGM.Int32Ty, TypeParams, false);
+        llvm::FunctionType::get(CGM.Int32Ty, TypeParams, /*isVarArg*/ false);
     RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_global_thread_num");
     break;
   }
@@ -295,6 +296,24 @@
     RTLFn = CGM.CreateRuntimeFunction(FnTy, /*Name*/ "__kmpc_barrier");
     break;
   }
+  case OMPRTL__kmpc_serialized_parallel: {
+    // Build void __kmpc_serialized_parallel(ident_t *loc, kmp_int32
+    // global_tid);
+    llvm::Type *TypeParams[] = {getIdentTyPointerTy(), CGM.Int32Ty};
+    llvm::FunctionType *FnTy =
+        llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
+    RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_serialized_parallel");
+    break;
+  }
+  case OMPRTL__kmpc_end_serialized_parallel: {
+    // Build void __kmpc_end_serialized_parallel(ident_t *loc, kmp_int32
+    // global_tid);
+    llvm::Type *TypeParams[] = {getIdentTyPointerTy(), CGM.Int32Ty};
+    llvm::FunctionType *FnTy =
+        llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
+    RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_end_serialized_parallel");
+    break;
+  }
   }
   return RTLFn;
 }
@@ -314,6 +333,56 @@
   CGF.EmitRuntimeCall(RTLFn, Args);
 }
 
+void CGOpenMPRuntime::EmitOMPSerialCall(CodeGenFunction &CGF,
+                                        SourceLocation Loc,
+                                        llvm::Value *OutlinedFn,
+                                        llvm::Value *CapturedStruct) {
+  auto ThreadID = GetOpenMPThreadID(CGF, Loc);
+  // Build calls:
+  // __kmpc_serialized_parallel(&Loc, GTid);
+  llvm::Value *SerArgs[] = {EmitOpenMPUpdateLocation(CGF, Loc), ThreadID};
+  auto RTLFn =
+      CreateRuntimeFunction(CGOpenMPRuntime::OMPRTL__kmpc_serialized_parallel);
+  CGF.EmitRuntimeCall(RTLFn, SerArgs);
+
+  // OutlinedFn(&GTid, &zero, CapturedStruct);
+  auto ThreadIDAddr = EmitThreadIDAddress(CGF, Loc);
+  auto Int32Ty =
+      CGF.getContext().getIntTypeForBitwidth(/*DestWidth*/ 32, /*Signed*/ true);
+  auto ZeroAddr = CGF.CreateMemTemp(Int32Ty, /*Name*/ ".zero.addr");
+  CGF.InitTempAlloca(ZeroAddr, CGF.Builder.getInt32(/*C*/ 0));
+  llvm::Value *OutlinedFnArgs[] = {ThreadIDAddr, ZeroAddr, CapturedStruct};
+  CGF.EmitCallOrInvoke(OutlinedFn, OutlinedFnArgs);
+
+  // __kmpc_end_serialized_parallel(&Loc, GTid);
+  llvm::Value *EndSerArgs[] = {EmitOpenMPUpdateLocation(CGF, Loc), ThreadID};
+  RTLFn = CreateRuntimeFunction(
+      CGOpenMPRuntime::OMPRTL__kmpc_end_serialized_parallel);
+  CGF.EmitRuntimeCall(RTLFn, EndSerArgs);
+}
+
+// If we're inside an (outlined) parallel region, use the region info's
+// thread-ID variable (it is passed as the first argument of the outlined
+// function as "kmp_int32 *gtid"). Otherwise, if we're not inside a parallel
+// region but in a regular serial code region, get the thread ID by calling
+// kmp_int32 __kmpc_global_thread_num(ident_t *loc), stash this thread ID in a
+// temporary and return the address of that temp.
+llvm::Value *CGOpenMPRuntime::EmitThreadIDAddress(CodeGenFunction &CGF,
+                                                  SourceLocation Loc) {
+  if (auto OMPRegionInfo =
+          dyn_cast_or_null<CGOpenMPRegionInfo>(CGF.CapturedStmtInfo))
+    return CGF.EmitLoadOfLValue(OMPRegionInfo->getThreadIDVariableLValue(CGF),
+                                SourceLocation()).getScalarVal();
+  auto ThreadID = GetOpenMPThreadID(CGF, Loc);
+  auto Int32Ty =
+      CGF.getContext().getIntTypeForBitwidth(/*DestWidth*/ 32, /*Signed*/ true);
+  auto ThreadIDTemp = CGF.CreateMemTemp(Int32Ty, /*Name*/ ".threadid_temp.");
+  CGF.EmitStoreOfScalar(ThreadID,
+                        CGF.MakeNaturalAlignAddrLValue(ThreadIDTemp, Int32Ty));
+
+  return ThreadIDTemp;
+}
+
 llvm::Value *CGOpenMPRuntime::GetCriticalRegionLock(StringRef CriticalName) {
   SmallString<256> Buffer;
   llvm::raw_svector_ostream Out(Buffer);
Index: cfe/trunk/lib/CodeGen/CGStmtOpenMP.cpp
===================================================================
--- cfe/trunk/lib/CodeGen/CGStmtOpenMP.cpp
+++ cfe/trunk/lib/CodeGen/CGStmtOpenMP.cpp
@@ -24,6 +24,52 @@
 //                              OpenMP Directive Emission
 //===----------------------------------------------------------------------===//
 
+/// \brief Emits code for OpenMP 'if' clause using specified \a CodeGen
+/// function. Here is the logic:
+/// if (Cond) {
+///   CodeGen(true);
+/// } else {
+///   CodeGen(false);
+/// }
+static void EmitOMPIfClause(CodeGenFunction &CGF, const Expr *Cond,
+                            const std::function<void(bool)> &CodeGen) {
+  CodeGenFunction::LexicalScope ConditionScope(CGF, Cond->getSourceRange());
+
+  // If the condition constant folds and can be elided, try to avoid emitting
+  // the condition and the dead arm of the if/else.
+  bool CondConstant;
+  if (CGF.ConstantFoldsToSimpleInteger(Cond, CondConstant)) {
+    CodeGen(CondConstant);
+    return;
+  }
+
+  // Otherwise, the condition did not fold, or we couldn't elide it.  Just
+  // emit the conditional branch.
+  auto ThenBlock = CGF.createBasicBlock(/*name*/ "omp_if.then");
+  auto ElseBlock = CGF.createBasicBlock(/*name*/ "omp_if.else");
+  auto ContBlock = CGF.createBasicBlock(/*name*/ "omp_if.end");
+  CGF.EmitBranchOnBoolExpr(Cond, ThenBlock, ElseBlock, /*TrueCount*/ 0);
+
+  // Emit the 'then' code.
+  CGF.EmitBlock(ThenBlock);
+  CodeGen(/*ThenBlock*/ true);
+  CGF.EmitBranch(ContBlock);
+  // Emit the 'else' code.
+  {
+    // There is no need to emit line number for unconditional branch.
+    SuppressDebugLocation SDL(CGF.Builder);
+    CGF.EmitBlock(ElseBlock);
+  }
+  CodeGen(/*ThenBlock*/ false);
+  {
+    // There is no need to emit line number for unconditional branch.
+    SuppressDebugLocation SDL(CGF.Builder);
+    CGF.EmitBranch(ContBlock);
+  }
+  // Emit the continuation block for code after the if.
+  CGF.EmitBlock(ContBlock, /*IsFinished*/ true);
+}
+
 void CodeGenFunction::EmitOMPAggregateAssign(LValue OriginalAddr,
                                              llvm::Value *PrivateAddr,
                                              const Expr *AssignExpr,
@@ -142,8 +188,20 @@
   auto CapturedStruct = GenerateCapturedStmtArgument(*CS);
   auto OutlinedFn = CGM.getOpenMPRuntime().EmitOpenMPOutlinedFunction(
       S, *CS->getCapturedDecl()->param_begin());
-  CGM.getOpenMPRuntime().EmitOMPParallelCall(*this, S.getLocStart(), OutlinedFn,
-                                             CapturedStruct);
+  if (auto C = S.getSingleClause(/*K*/ OMPC_if)) {
+    auto Cond = cast<OMPIfClause>(C)->getCondition();
+    EmitOMPIfClause(*this, Cond, [&](bool ThenBlock) {
+      if (ThenBlock)
+        CGM.getOpenMPRuntime().EmitOMPParallelCall(*this, S.getLocStart(),
+                                                   OutlinedFn, CapturedStruct);
+      else
+        CGM.getOpenMPRuntime().EmitOMPSerialCall(*this, S.getLocStart(),
+                                                 OutlinedFn, CapturedStruct);
+    });
+  } else {
+    CGM.getOpenMPRuntime().EmitOMPParallelCall(*this, S.getLocStart(),
+                                               OutlinedFn, CapturedStruct);
+  }
 }
 
 void CodeGenFunction::EmitOMPLoopBody(const OMPLoopDirective &S,
Index: cfe/trunk/lib/AST/Stmt.cpp
===================================================================
--- cfe/trunk/lib/AST/Stmt.cpp
+++ cfe/trunk/lib/AST/Stmt.cpp
@@ -1434,6 +1434,21 @@
   return new (Mem) OMPFlushClause(N);
 }
 
+const OMPClause *
+OMPExecutableDirective::getSingleClause(OpenMPClauseKind K) const {
+  auto ClauseFilter =
+      [=](const OMPClause *C) -> bool { return C->getClauseKind() == K; };
+  OMPExecutableDirective::filtered_clause_iterator<decltype(ClauseFilter)> I(
+      clauses(), ClauseFilter);
+
+  if (I) {
+    auto *Clause = *I;
+    assert(!++I && "There are at least 2 clauses of the specified kind");
+    return Clause;
+  }
+  return nullptr;
+}
+
 OMPParallelDirective *OMPParallelDirective::Create(
                                               const ASTContext &C,
                                               SourceLocation StartLoc,
Index: cfe/trunk/test/OpenMP/parallel_if_codegen.cpp
===================================================================
--- cfe/trunk/test/OpenMP/parallel_if_codegen.cpp
+++ cfe/trunk/test/OpenMP/parallel_if_codegen.cpp
@@ -0,0 +1,124 @@
+// RUN: %clang_cc1 -verify -fopenmp=libiomp5 -x c++ -triple %itanium_abi_triple -emit-llvm %s -o - | FileCheck %s
+// RUN: %clang_cc1 -fopenmp=libiomp5 -x c++ -std=c++11 -triple %itanium_abi_triple -emit-pch -o %t %s
+// RUN: %clang_cc1 -fopenmp=libiomp5 -x c++ -triple %itanium_abi_triple -std=c++11 -include-pch %t -verify %s -emit-llvm -o - | FileCheck --check-prefix=CHECK %s
+// expected-no-diagnostics
+#ifndef HEADER
+#define HEADER
+
+void fn1();
+void fn2();
+void fn3();
+void fn4();
+void fn5();
+void fn6();
+
+int Arg;
+
+// CHECK-LABEL: define void @{{.+}}gtid_test
+void gtid_test() {
+// CHECK:  call void {{.+}}* @__kmpc_fork_call(%{{.+}}* @{{.+}}, i{{.+}} 1, {{.+}}* [[GTID_TEST_REGION1:@.+]] to void
+#pragma omp parallel
+#pragma omp parallel if (false)
+  gtid_test();
+// CHECK: ret void
+}
+
+// CHECK: define internal void [[GTID_TEST_REGION1]](i{{.+}}* [[GTID_PARAM:%.+]], i
+// CHECK: store i{{[0-9]+}}* [[GTID_PARAM]], i{{[0-9]+}}** [[GTID_ADDR_REF:%.+]],
+// CHECK: [[GTID_ADDR:%.+]] = load i{{[0-9]+}}** [[GTID_ADDR_REF]]
+// CHECK: [[GTID:%.+]] = load i{{[0-9]+}}* [[GTID_ADDR]]
+// CHECK: call void @__kmpc_serialized_parallel(%{{.+}}* @{{.+}}, i{{.+}} [[GTID]])
+// CHECK: [[GTID_ADDR:%.+]] = load i{{[0-9]+}}** [[GTID_ADDR_REF]]
+// CHECK: call void [[GTID_TEST_REGION2:@.+]](i{{[0-9]+}}* [[GTID_ADDR]]
+// CHECK: call void @__kmpc_end_serialized_parallel(%{{.+}}* @{{.+}}, i{{.+}} [[GTID]])
+// CHECK: ret void
+
+// CHECK: define internal void [[GTID_TEST_REGION2]](
+// CHECK: call void @{{.+}}gtid_test
+// CHECK: ret void
+
+template <typename T>
+int tmain(T Arg) {
+#pragma omp parallel if (true)
+  fn1();
+#pragma omp parallel if (false)
+  fn2();
+#pragma omp parallel if (Arg)
+  fn3();
+  return 0;
+}
+
+// CHECK-LABEL: define {{.*}}i{{[0-9]+}} @main()
+int main() {
+// CHECK: [[GTID:%.+]] = call i32 @__kmpc_global_thread_num(
+// CHECK: call void {{.+}} @__kmpc_fork_call(%{{.+}}* @{{.+}}, i{{.+}} 1, void {{.+}}* [[CAP_FN4:@.+]] to void
+#pragma omp parallel if (true)
+  fn4();
+// CHECK: call void @__kmpc_serialized_parallel(%{{.+}}* @{{.+}}, i32 [[GTID]])
+// CHECK: store i32 [[GTID]], i32* [[GTID_ADDR:%.+]],
+// CHECK: call void [[CAP_FN5:@.+]](i32* [[GTID_ADDR]],
+// CHECK: call void @__kmpc_end_serialized_parallel(%{{.+}}* @{{.+}}, i32 [[GTID]])
+#pragma omp parallel if (false)
+  fn5();
+
+// CHECK: br i1 %{{.+}}, label %[[OMP_THEN:.+]], label %[[OMP_ELSE:.+]]
+// CHECK: [[OMP_THEN]]
+// CHECK: call void {{.+}} @__kmpc_fork_call(%{{.+}}* @{{.+}}, i{{.+}} 1, void {{.+}}* [[CAP_FN6:@.+]] to void
+// CHECK: br label %[[OMP_END:.+]]
+// CHECK: [[OMP_ELSE]]
+// CHECK: call void @__kmpc_serialized_parallel(%{{.+}}* @{{.+}}, i32 %0)
+// CHECK: store i32 [[GTID]], i32* [[GTID_ADDR:%.+]],
+// CHECK: call void [[CAP_FN6]](i32* [[GTID_ADDR]],
+// CHECK: call void @__kmpc_end_serialized_parallel(%{{.+}}* @{{.+}}, i32 [[GTID]])
+// CHECK: br label %[[OMP_END]]
+// CHECK: [[OMP_END]]
+#pragma omp parallel if (Arg)
+  fn6();
+  // CHECK: = call i{{.+}} @{{.+}}tmain
+  return tmain(Arg);
+}
+
+// CHECK: define internal void [[CAP_FN4]]
+// CHECK: call void @{{.+}}fn4
+// CHECK: ret void
+
+// CHECK: define internal void [[CAP_FN5]]
+// CHECK: call void @{{.+}}fn5
+// CHECK: ret void
+
+// CHECK: define internal void [[CAP_FN6]]
+// CHECK: call void @{{.+}}fn6
+// CHECK: ret void
+
+// CHECK-LABEL: define {{.+}} @{{.+}}tmain
+// CHECK: [[GTID:%.+]] = call i32 @__kmpc_global_thread_num(
+// CHECK: call void {{.+}} @__kmpc_fork_call(%{{.+}}* @{{.+}}, i{{.+}} 1, void {{.+}}* [[CAP_FN1:@.+]] to void
+// CHECK: call void @__kmpc_serialized_parallel(%{{.+}}* @{{.+}}, i32 [[GTID]])
+// CHECK: store i32 [[GTID]], i32* [[GTID_ADDR:%.+]],
+// CHECK: call void [[CAP_FN2:@.+]](i32* [[GTID_ADDR]],
+// CHECK: call void @__kmpc_end_serialized_parallel(%{{.+}}* @{{.+}}, i32 [[GTID]])
+// CHECK: br i1 %{{.+}}, label %[[OMP_THEN:.+]], label %[[OMP_ELSE:.+]]
+// CHECK: [[OMP_THEN]]
+// CHECK: call void {{.+}} @__kmpc_fork_call(%{{.+}}* @{{.+}}, i{{.+}} 1, void {{.+}}* [[CAP_FN3:@.+]] to void
+// CHECK: br label %[[OMP_END:.+]]
+// CHECK: [[OMP_ELSE]]
+// CHECK: call void @__kmpc_serialized_parallel(%{{.+}}* @{{.+}}, i32 %0)
+// CHECK: store i32 [[GTID]], i32* [[GTID_ADDR:%.+]],
+// CHECK: call void [[CAP_FN3]](i32* [[GTID_ADDR]],
+// CHECK: call void @__kmpc_end_serialized_parallel(%{{.+}}* @{{.+}}, i32 [[GTID]])
+// CHECK: br label %[[OMP_END]]
+// CHECK: [[OMP_END]]
+
+// CHECK: define internal void [[CAP_FN1]]
+// CHECK: call void @{{.+}}fn1
+// CHECK: ret void
+
+// CHECK: define internal void [[CAP_FN2]]
+// CHECK: call void @{{.+}}fn2
+// CHECK: ret void
+
+// CHECK: define internal void [[CAP_FN3]]
+// CHECK: call void @{{.+}}fn3
+// CHECK: ret void
+
+#endif
Index: cfe/trunk/include/clang/AST/StmtOpenMP.h
===================================================================
--- cfe/trunk/include/clang/AST/StmtOpenMP.h
+++ cfe/trunk/include/clang/AST/StmtOpenMP.h
@@ -128,6 +128,13 @@
     operator bool() { return Current != End; }
   };
 
+  /// \brief Gets a single clause of the specified kind \a K associated with the
+  /// current directive iff there is only one clause of this kind (an assertion
+  /// fires if more than one clause of this kind is associated with the
+  /// directive). Returns nullptr if no clause of kind \a K is associated with
+  /// the directive.
+  const OMPClause *getSingleClause(OpenMPClauseKind K) const;
+
   /// \brief Returns starting location of directive kind.
   SourceLocation getLocStart() const { return StartLoc; }
   /// \brief Returns ending location of directive.
_______________________________________________
cfe-commits mailing list
[email protected]
http://lists.cs.uiuc.edu/mailman/listinfo/cfe-commits

Reply via email to