Author: abataev
Date: Mon May  7 10:23:05 2018
New Revision: 331652

URL: http://llvm.org/viewvc/llvm-project?rev=331652&view=rev
Log:
[OPENMP, NVPTX] Codegen for critical construct.

Added correct codegen for the critical construct on NVPTX devices.

Modified:
    cfe/trunk/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
    cfe/trunk/lib/CodeGen/CGOpenMPRuntimeNVPTX.h
    cfe/trunk/test/OpenMP/nvptx_parallel_codegen.cpp

Modified: cfe/trunk/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
URL: http://llvm.org/viewvc/llvm-project/cfe/trunk/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp?rev=331652&r1=331651&r2=331652&view=diff
==============================================================================
--- cfe/trunk/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp (original)
+++ cfe/trunk/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp Mon May  7 10:23:05 2018
@@ -1837,6 +1837,66 @@ void CGOpenMPRuntimeNVPTX::emitSpmdParal
   emitOutlinedFunctionCall(CGF, Loc, OutlinedFn, OutlinedFnArgs);
 }
 
+void CGOpenMPRuntimeNVPTX::emitCriticalRegion(
+    CodeGenFunction &CGF, StringRef CriticalName,
+    const RegionCodeGenTy &CriticalOpGen, SourceLocation Loc,
+    const Expr *Hint) {
+  llvm::BasicBlock *LoopBB = CGF.createBasicBlock("omp.critical.loop");
+  llvm::BasicBlock *TestBB = CGF.createBasicBlock("omp.critical.test");
+  llvm::BasicBlock *SyncBB = CGF.createBasicBlock("omp.critical.sync");
+  llvm::BasicBlock *BodyBB = CGF.createBasicBlock("omp.critical.body");
+  llvm::BasicBlock *ExitBB = CGF.createBasicBlock("omp.critical.exit");
+  // Emitted code: threads of the team take turns; on iteration N of the loop
+  // below only the thread whose team-local id equals N runs the body.
+  llvm::Value *ThreadID = getNVPTXThreadID(CGF);
+
+  // Get the width of the team; it is the trip count of the loop.
+  llvm::Value *TeamWidth = getNVPTXNumThreads(CGF);
+
+  // Initialize the loop counter (an unsigned 32-bit temporary) to zero.
+  QualType Int32Ty =
+      CGF.getContext().getIntTypeForBitwidth(/*DestWidth=*/32, /*Signed=*/0);
+  Address Counter = CGF.CreateMemTemp(Int32Ty, "critical_counter");
+  LValue CounterLVal = CGF.MakeAddrLValue(Counter, Int32Ty);
+  CGF.EmitStoreOfScalar(llvm::Constant::getNullValue(CGM.Int32Ty), CounterLVal,
+                        /*isInit=*/true);
+
+  // Loop header: go to TestBB while counter < TeamWidth (signed), else exit.
+  CGF.EmitBlock(LoopBB);
+  llvm::Value *CounterVal = CGF.EmitLoadOfScalar(CounterLVal, Loc);
+  llvm::Value *CmpLoopBound = CGF.Builder.CreateICmpSLT(CounterVal, TeamWidth);
+  CGF.Builder.CreateCondBr(CmpLoopBound, TestBB, ExitBB);
+
+  // Block tests which single thread should execute region, and which threads
+  // should go straight to synchronisation point.
+  CGF.EmitBlock(TestBB);
+  CounterVal = CGF.EmitLoadOfScalar(CounterLVal, Loc);
+  llvm::Value *CmpThreadToCounter =
+      CGF.Builder.CreateICmpEQ(ThreadID, CounterVal);
+  CGF.Builder.CreateCondBr(CmpThreadToCounter, BodyBB, SyncBB);
+
+  // Block emits the body of the critical region.
+  CGF.EmitBlock(BodyBB);
+  // Output the critical statement. NOTE(review): CriticalName and Hint are
+  // unused; exclusion comes only from the turn-taking loop within this team.
+  CriticalOpGen(CGF);
+
+  // Every thread reaches SyncBB once per iteration: the selected thread after
+  // the body, the others directly from TestBB. All wait at a CTA-wide barrier,
+  // then increment the counter and go back to the loop header. NOTE(review):
+  // this assumes all CTA threads run the same number of iterations; confirm.
+  CGF.EmitBlock(SyncBB);
+  getNVPTXCTABarrier(CGF);
+  // Reuse CounterVal loaded in TestBB; TestBB dominates SyncBB.
+  llvm::Value *IncCounterVal =
+      CGF.Builder.CreateNSWAdd(CounterVal, CGF.Builder.getInt32(1));
+  CGF.EmitStoreOfScalar(IncCounterVal, CounterLVal);
+  CGF.EmitBranch(LoopBB);
+
+  // Block that is reached when all threads in the team complete the region.
+  CGF.EmitBlock(ExitBB, /*IsFinished=*/true);
+}
+
 /// Cast value to the specified type.
 static llvm::Value *castValueToType(CodeGenFunction &CGF, llvm::Value *Val,
                                     QualType ValTy, QualType CastTy,

Modified: cfe/trunk/lib/CodeGen/CGOpenMPRuntimeNVPTX.h
URL: http://llvm.org/viewvc/llvm-project/cfe/trunk/lib/CodeGen/CGOpenMPRuntimeNVPTX.h?rev=331652&r1=331651&r2=331652&view=diff
==============================================================================
--- cfe/trunk/lib/CodeGen/CGOpenMPRuntimeNVPTX.h (original)
+++ cfe/trunk/lib/CodeGen/CGOpenMPRuntimeNVPTX.h Mon May  7 10:23:05 2018
@@ -250,6 +250,16 @@ public:
                         ArrayRef<llvm::Value *> CapturedVars,
                         const Expr *IfCond) override;
 
+  /// Emits a critical region; on NVPTX the team's threads execute it in turn.
+  /// \param CriticalName Name of the critical region (unused on NVPTX).
+  /// \param CriticalOpGen Generator for the critical region's statement.
+  /// \param Loc Source location used for the emitted code.
+  /// \param Hint Value of the 'hint' clause (optional; unused on NVPTX).
+  void emitCriticalRegion(CodeGenFunction &CGF, StringRef CriticalName,
+                          const RegionCodeGenTy &CriticalOpGen,
+                          SourceLocation Loc,
+                          const Expr *Hint = nullptr) override;
+
   /// Emit a code for reduction clause.
   ///
   /// \param Privates List of private copies for original reduction arguments.

Modified: cfe/trunk/test/OpenMP/nvptx_parallel_codegen.cpp
URL: http://llvm.org/viewvc/llvm-project/cfe/trunk/test/OpenMP/nvptx_parallel_codegen.cpp?rev=331652&r1=331651&r2=331652&view=diff
==============================================================================
--- cfe/trunk/test/OpenMP/nvptx_parallel_codegen.cpp (original)
+++ cfe/trunk/test/OpenMP/nvptx_parallel_codegen.cpp Mon May  7 10:23:05 2018
@@ -70,8 +70,6 @@ int bar(int n){
   return a;
 }
 
-// CHECK: @"_gomp_critical_user_$var" = common global [8 x i32] zeroinitializer
-
 // CHECK-NOT: define {{.*}}void {{@__omp_offloading_.+template.+l17}}_worker()
 
 // CHECK-LABEL: define {{.*}}void {{@__omp_offloading_.+template.+l26}}_worker()
@@ -313,4 +311,28 @@ int bar(int n){
 // CHECK: [[A:%.+]] = alloca i[[SZ:32|64]],
 // CHECK: store i[[SZ]] 45, i[[SZ]]* %a,
 // CHECK: ret void
+
+// CHECK-LABEL: define {{.*}}void {{@__omp_offloading_.+template.+l54}}_worker()
+// CHECK-LABEL: define {{.*}}void {{@__omp_offloading_.+template.+l54}}(
+
+// CHECK-LABEL: define internal void @{{.+}}(i32* noalias %{{.+}}, i32* noalias %{{.+}}, i32* dereferenceable{{.*}})
+// CHECK:  [[CC:%.+]] = alloca i32,
+// CHECK:  [[TID:%.+]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(),
+// CHECK:  [[NUM_THREADS:%.+]] = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x(),
+// CHECK:  store i32 0, i32* [[CC]],
+// CHECK:  br label
+
+// CHECK:  [[CC_VAL:%.+]] = load i32, i32* [[CC]],
+// CHECK:  [[RES:%.+]] = icmp slt i32 [[CC_VAL]], [[NUM_THREADS]]
+// CHECK:  br i1 [[RES]], label
+
+// CHECK:  [[CC_VAL:%.+]] = load i32, i32* [[CC]],
+// CHECK:  [[RES:%.+]] = icmp eq i32 [[TID]], [[CC_VAL]]
+// CHECK:  br i1 [[RES]], label
+
+// CHECK:  call void @llvm.nvvm.barrier0()
+// CHECK:  [[NEW_CC_VAL:%.+]] = add nsw i32 [[CC_VAL]], 1
+// CHECK:  store i32 [[NEW_CC_VAL]], i32* [[CC]],
+// CHECK:  br label
+
 #endif


_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to