jdoerfert created this revision.
jdoerfert added reviewers: kiranchandramohan, ABataev, RaviNarayanaswamy, 
gtbercea, grokos, sdmitriev, JonChesterfield, fghanim.
Herald added subscribers: guansong, bollu, hiraditya.
Herald added a project: LLVM.
jdoerfert added a child revision: D70290: [OpenMP] Use the OpenMPIRBuilder for 
"omp parallel".

An `omp cancel parallel` needs to be emitted by the OpenMPIRBuilder if
the `parallel` was emitted by the OpenMPIRBuilder. This patch makes
this possible. The cancellation-check logic is shared with cancel barriers.
Testing is done via unit tests now, and via the clang cancel_codegen.cpp
test once D70290 <https://reviews.llvm.org/D70290> lands.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D71948

Files:
  llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
  llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
  llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
  llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Index: llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
===================================================================
--- llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -99,6 +99,122 @@
   EXPECT_FALSE(verifyModule(*M));
 }
 
+TEST_F(OpenMPIRBuilderTest, CreateCancel) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+
+  BasicBlock *CBB = BasicBlock::Create(Ctx, "", F);
+  new UnreachableInst(Ctx, CBB); // CBB: target of FiniCB's branch.
+  auto FiniCB = [&](InsertPointTy IP) {
+    ASSERT_NE(IP.getBlock(), nullptr);
+    ASSERT_EQ(IP.getBlock()->end(), IP.getPoint());
+    BranchInst::Create(CBB, IP.getBlock());
+  };
+  OMPBuilder.pushFinalizationCB({FiniCB, OMPD_parallel, true});
+
+  IRBuilder<> Builder(BB);
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP()});
+  auto NewIP = OMPBuilder.CreateCancel(Loc, nullptr, OMPD_parallel);
+  Builder.restoreIP(NewIP);
+  EXPECT_FALSE(M->global_empty());
+  EXPECT_EQ(M->size(), 3U);
+  EXPECT_EQ(F->size(), 4U);
+  EXPECT_EQ(BB->size(), 4U); // gtid call, cancel call, icmp, br
+
+  CallInst *GTID = dyn_cast<CallInst>(&BB->front());
+  ASSERT_NE(GTID, nullptr);
+  EXPECT_EQ(GTID->getNumArgOperands(), 1U);
+  EXPECT_EQ(GTID->getCalledFunction()->getName(), "__kmpc_global_thread_num");
+  EXPECT_FALSE(GTID->getCalledFunction()->doesNotAccessMemory());
+  EXPECT_FALSE(GTID->getCalledFunction()->doesNotFreeMemory());
+
+  CallInst *Cancel = dyn_cast<CallInst>(GTID->getNextNode());
+  ASSERT_NE(Cancel, nullptr);
+  EXPECT_EQ(Cancel->getNumArgOperands(), 3U);
+  EXPECT_EQ(Cancel->getCalledFunction()->getName(), "__kmpc_cancel");
+  EXPECT_FALSE(Cancel->getCalledFunction()->doesNotAccessMemory());
+  EXPECT_FALSE(Cancel->getCalledFunction()->doesNotFreeMemory());
+  EXPECT_EQ(Cancel->getNumUses(), 1U); // Only the null check uses it.
+  Instruction *CancelBBTI = Cancel->getParent()->getTerminator();
+  EXPECT_EQ(CancelBBTI->getNumSuccessors(), 2U);
+  EXPECT_EQ(CancelBBTI->getSuccessor(0), NewIP.getBlock()); // No-cancel path.
+  EXPECT_EQ(CancelBBTI->getSuccessor(1)->size(), 1U); // Cancel path.
+  EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getNumSuccessors(),
+            1U);
+  EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getSuccessor(0),
+            CBB);
+
+  EXPECT_EQ(Cancel->getArgOperand(1), GTID);
+  EXPECT_EQ(Cancel->getArgOperand(2), Builder.getInt32(1)); // OMPD_parallel
+  OMPBuilder.popFinalizationCB();
+
+  Builder.CreateUnreachable();
+  EXPECT_FALSE(verifyModule(*M));
+}
+
+TEST_F(OpenMPIRBuilderTest, CreateCancelIfCond) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+
+  BasicBlock *CBB = BasicBlock::Create(Ctx, "", F);
+  new UnreachableInst(Ctx, CBB); // CBB: target of FiniCB's branch.
+  auto FiniCB = [&](InsertPointTy IP) {
+    ASSERT_NE(IP.getBlock(), nullptr);
+    ASSERT_EQ(IP.getBlock()->end(), IP.getPoint());
+    BranchInst::Create(CBB, IP.getBlock());
+  };
+  OMPBuilder.pushFinalizationCB({FiniCB, OMPD_parallel, true});
+
+  IRBuilder<> Builder(BB);
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP()});
+  auto NewIP = OMPBuilder.CreateCancel(Loc, Builder.getTrue(), OMPD_parallel);
+  Builder.restoreIP(NewIP);
+  EXPECT_FALSE(M->global_empty());
+  EXPECT_EQ(M->size(), 3U);
+  EXPECT_EQ(F->size(), 7U);
+  EXPECT_EQ(BB->size(), 1U);
+  ASSERT_TRUE(isa<BranchInst>(BB->getTerminator()));
+  ASSERT_EQ(BB->getTerminator()->getNumSuccessors(), 2U);
+  BB = BB->getTerminator()->getSuccessor(0); // Follow the 'then' edge.
+  EXPECT_EQ(BB->size(), 4U); // gtid call, cancel call, icmp, br
+
+  CallInst *GTID = dyn_cast<CallInst>(&BB->front());
+  ASSERT_NE(GTID, nullptr);
+  EXPECT_EQ(GTID->getNumArgOperands(), 1U);
+  EXPECT_EQ(GTID->getCalledFunction()->getName(), "__kmpc_global_thread_num");
+  EXPECT_FALSE(GTID->getCalledFunction()->doesNotAccessMemory());
+  EXPECT_FALSE(GTID->getCalledFunction()->doesNotFreeMemory());
+
+  CallInst *Cancel = dyn_cast<CallInst>(GTID->getNextNode());
+  ASSERT_NE(Cancel, nullptr);
+  EXPECT_EQ(Cancel->getNumArgOperands(), 3U);
+  EXPECT_EQ(Cancel->getCalledFunction()->getName(), "__kmpc_cancel");
+  EXPECT_FALSE(Cancel->getCalledFunction()->doesNotAccessMemory());
+  EXPECT_FALSE(Cancel->getCalledFunction()->doesNotFreeMemory());
+  EXPECT_EQ(Cancel->getNumUses(), 1U); // Only the null check uses it.
+  Instruction *CancelBBTI = Cancel->getParent()->getTerminator();
+  EXPECT_EQ(CancelBBTI->getNumSuccessors(), 2U);
+  EXPECT_EQ(CancelBBTI->getSuccessor(0)->size(), 1U); // No-cancel path.
+  EXPECT_EQ(CancelBBTI->getSuccessor(0)->getUniqueSuccessor(),
+            NewIP.getBlock());
+  EXPECT_EQ(CancelBBTI->getSuccessor(1)->size(), 1U); // Cancel path.
+  EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getNumSuccessors(),
+            1U);
+  EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getSuccessor(0),
+            CBB);
+
+  EXPECT_EQ(Cancel->getArgOperand(1), GTID);
+  EXPECT_EQ(Cancel->getArgOperand(2), Builder.getInt32(1)); // OMPD_parallel
+  OMPBuilder.popFinalizationCB();
+
+  Builder.CreateUnreachable();
+  EXPECT_FALSE(verifyModule(*M));
+}
+
 TEST_F(OpenMPIRBuilderTest, CreateCancelBarrier) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);
Index: llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
===================================================================
--- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -216,41 +216,99 @@
                                                   : OMPRTL___kmpc_barrier),
       Args);
 
-  if (UseCancelBarrier && CheckCancelFlag) {
-    // For a cancel barrier we create two new blocks.
-    BasicBlock *BB = Builder.GetInsertBlock();
-    BasicBlock *NonCancellationBlock;
-    if (Builder.GetInsertPoint() == BB->end()) {
-      // TODO: This branch will not be needed once we moved to the
-      // OpenMPIRBuilder codegen completely.
-      NonCancellationBlock = BasicBlock::Create(
-          BB->getContext(), BB->getName() + ".cont", BB->getParent());
-    } else {
-      NonCancellationBlock = SplitBlock(BB, &*Builder.GetInsertPoint());
-      BB->getTerminator()->eraseFromParent();
-      Builder.SetInsertPoint(BB);
-    }
-    BasicBlock *CancellationBlock = BasicBlock::Create(
-        BB->getContext(), BB->getName() + ".cncl", BB->getParent());
-
-    // Jump to them based on the return value.
-    Value *Cmp = Builder.CreateIsNull(Result);
-    Builder.CreateCondBr(Cmp, NonCancellationBlock, CancellationBlock,
-                         /* TODO weight */ nullptr, nullptr);
-
-    // From the cancellation block we finalize all variables and go to the
-    // post finalization block that is known to the FiniCB callback.
-    Builder.SetInsertPoint(CancellationBlock);
-    auto &FI = FinalizationStack.back();
-    FI.FiniCB(Builder.saveIP());
-
-    // The continuation block is where code generation continues.
-    Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
+  if (UseCancelBarrier && CheckCancelFlag)
+    emitCancelationCheckImpl(Result, OMPD_parallel);
+
+  return Builder.saveIP();
+}
+
+OpenMPIRBuilder::InsertPointTy
+OpenMPIRBuilder::CreateCancel(const LocationDescription &Loc,
+                              Value *IfCondition,
+                              omp::Directive CanceledDirective) {
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+
+  // LLVM block-splitting utilities expect a terminator; add a placeholder.
+  auto *UI = Builder.CreateUnreachable();
+
+  Instruction *ThenTI = UI, *ElseTI = nullptr;
+  if (IfCondition)
+    SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI);
+  Builder.SetInsertPoint(ThenTI); // The else path (if any) stays empty.
+
+  // Cancel-kind ids passed to __kmpc_cancel (presumably libomp's kmp.h enum).
+  // Used only here, so a table in OMPKinds.def does not seem necessary yet.
+  Value *CancelKind = nullptr;
+  switch (CanceledDirective) {
+  case OMPD_parallel:
+    CancelKind = Builder.getInt32(1);
+    break;
+  case OMPD_for:
+    CancelKind = Builder.getInt32(2);
+    break;
+  case OMPD_sections:
+    CancelKind = Builder.getInt32(3);
+    break;
+  case OMPD_taskgroup:
+    CancelKind = Builder.getInt32(4);
+    break;
+  default:
+    llvm_unreachable("Unknown cancel kind!");
   }
 
+  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
+  Value *Ident = getOrCreateIdent(SrcLocStr);
+  Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
+  Value *Result = Builder.CreateCall(
+      getOrCreateRuntimeFunction(OMPRTL___kmpc_cancel), Args);
+
+  // The cancellation check is shared with others, e.g., cancel barriers.
+  emitCancelationCheckImpl(Result, CanceledDirective);
+
+  // Resume at the end of the placeholder's block, then erase the placeholder.
+  Builder.SetInsertPoint(UI->getParent());
+  UI->eraseFromParent();
+
   return Builder.saveIP();
 }
 
+void OpenMPIRBuilder::emitCancelationCheckImpl(
+    Value *CancelFlag, omp::Directive CanceledDirective) {
+  assert(isLastFinalizationInfoCancellable(CanceledDirective) &&
+         "Unexpected cancellation!");
+
+  // Create a continuation block and a cancellation block.
+  BasicBlock *BB = Builder.GetInsertBlock();
+  BasicBlock *NonCancellationBlock;
+  if (Builder.GetInsertPoint() == BB->end()) {
+    // TODO: This branch will not be needed once we have moved to the
+    // OpenMPIRBuilder codegen completely.
+    NonCancellationBlock = BasicBlock::Create(
+        BB->getContext(), BB->getName() + ".cont", BB->getParent());
+  } else {
+    NonCancellationBlock = SplitBlock(BB, &*Builder.GetInsertPoint());
+    BB->getTerminator()->eraseFromParent(); // Replaced by the cond. br below.
+    Builder.SetInsertPoint(BB);
+  }
+  BasicBlock *CancellationBlock = BasicBlock::Create(
+      BB->getContext(), BB->getName() + ".cncl", BB->getParent());
+
+  // A zero flag continues normally; any other value takes the cancel path.
+  Value *Cmp = Builder.CreateIsNull(CancelFlag);
+  Builder.CreateCondBr(Cmp, NonCancellationBlock, CancellationBlock,
+                       /* TODO weight */ nullptr, nullptr);
+
+  // From the cancellation block we finalize all variables and go to the
+  // post finalization block that is known to the FiniCB callback.
+  Builder.SetInsertPoint(CancellationBlock);
+  auto &FI = FinalizationStack.back();
+  FI.FiniCB(Builder.saveIP());
+
+  // The continuation block is where code generation continues.
+  Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
+}
+
 IRBuilder<>::InsertPoint OpenMPIRBuilder::CreateParallel(
     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
     PrivatizeCallbackTy PrivCB, FinalizeCallbackTy FiniCB, Value *IfCondition,
Index: llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -164,6 +164,7 @@
   OMP_RTL(OMPRTL_##Name, #Name, IsVarArg, ReturnType, __VA_ARGS__)
 
 __OMP_RTL(__kmpc_barrier, false, Void, IdentPtr, Int32)
+__OMP_RTL(__kmpc_cancel, false, Int32, IdentPtr, Int32, Int32)
 __OMP_RTL(__kmpc_cancel_barrier, false, Int32, IdentPtr, Int32)
 __OMP_RTL(__kmpc_global_thread_num, false, Int32, IdentPtr)
 __OMP_RTL(__kmpc_fork_call, true, Void, IdentPtr, Int32, ParallelTaskPtr)
Index: llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -139,6 +139,17 @@
                               bool ForceSimpleCall = false,
                               bool CheckCancelFlag = true);
 
+  /// Generator for '#omp cancel'
+  ///
+  /// \param Loc The location where the directive was encountered.
+  /// \param IfCondition The evaluated 'if' clause expression, or nullptr.
+  /// \param CanceledDirective The kind of directive that is canceled.
+  ///
+  /// \returns The insertion point where execution continues if not canceled.
+  InsertPointTy CreateCancel(const LocationDescription &Loc,
+                             Value *IfCondition,
+                             omp::Directive CanceledDirective);
+
   /// Generator for '#omp parallel'
   ///
   /// \param Loc The insert and source location description.
@@ -183,6 +194,13 @@
   Value *getOrCreateIdent(Constant *SrcLocStr,
                           omp::IdentFlag Flags = omp::IdentFlag(0));
 
+  /// Generate control flow and cleanup for cancellation.
+  ///
+  /// \param CancelFlag Runtime call result; non-zero takes the cancel path.
+  /// \param CanceledDirective The kind of directive that is canceled.
+  void emitCancelationCheckImpl(Value *CancelFlag,
+                                omp::Directive CanceledDirective);
+
   /// Generate a barrier runtime call.
   ///
   /// \param Loc The location at which the request originated and is fulfilled.
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
  • [PATCH] D71948: [Op... Johannes Doerfert via Phabricator via cfe-commits

Reply via email to