llitchev updated this revision to Diff 313077.
llitchev added a comment.

Fixed the casing of a local variable name.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D91556/new/

https://reviews.llvm.org/D91556

Files:
  clang/test/OpenMP/parallel_codegen.cpp
  llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
  llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
  llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
  mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
  mlir/test/Conversion/OpenMPToLLVM/openmp_float-parallel_param.mlir

Index: mlir/test/Conversion/OpenMPToLLVM/openmp_float-parallel_param.mlir
===================================================================
--- /dev/null
+++ mlir/test/Conversion/OpenMPToLLVM/openmp_float-parallel_param.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-translate  --mlir-to-llvmir %s | FileCheck %s
+// The float %17 is defined outside omp.parallel and captured via a struct.
+module {
+  llvm.func @malloc(!llvm.i64) -> !llvm.ptr<i8>
+  llvm.func @main() {
+    %0 = llvm.mlir.constant(4 : index) : !llvm.i64
+    %1 = llvm.mlir.constant(4 : index) : !llvm.i64
+    %2 = llvm.mlir.null : !llvm.ptr<float>
+    %3 = llvm.mlir.constant(1 : index) : !llvm.i64
+    %4 = llvm.getelementptr %2[%3] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
+    %5 = llvm.ptrtoint %4 : !llvm.ptr<float> to !llvm.i64
+    %6 = llvm.mul %1, %5 : !llvm.i64
+    %7 = llvm.call @malloc(%6) : (!llvm.i64) -> !llvm.ptr<i8>
+    %8 = llvm.bitcast %7 : !llvm.ptr<i8> to !llvm.ptr<float>
+    %9 = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+    %10 = llvm.insertvalue %8, %9[0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+    %11 = llvm.insertvalue %8, %10[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+    %12 = llvm.mlir.constant(0 : index) : !llvm.i64
+    %13 = llvm.insertvalue %12, %11[2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+    %14 = llvm.mlir.constant(1 : index) : !llvm.i64
+    %15 = llvm.insertvalue %1, %13[3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+    %16 = llvm.insertvalue %14, %15[4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+    %17 = llvm.mlir.constant(4.200000e+01 : f32) : !llvm.float
+    // CHECK: %CaptureStructAlloca = alloca %CapturedStructType
+    // CHECK: %{{.*}} = insertvalue %CapturedStructType undef, {{.*}}, 0
+    // CHECK: store %CapturedStructType %{{.*}}, %CapturedStructType* %CaptureStructAlloca
+    omp.parallel num_threads(%0 : !llvm.i64) {
+      // CHECK: %{{.*}} = load %CapturedStructType, %CapturedStructType* %CaptureStructAlloca
+      // CHECK: %{{.*}} = extractvalue %CapturedStructType %{{.*}}, 0
+      %27 = llvm.mlir.constant(1 : i64) : !llvm.i64
+      %28 = llvm.extractvalue %16[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+      %29 = llvm.mlir.constant(0 : index) : !llvm.i64
+      %30 = llvm.mlir.constant(1 : index) : !llvm.i64
+      %31 = llvm.mul %27, %30 : !llvm.i64
+      %32 = llvm.add %29, %31 : !llvm.i64
+      %33 = llvm.getelementptr %28[%32] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
+      llvm.store %17, %33 : !llvm.ptr<float>
+      omp.terminator
+    }
+    llvm.return
+  }
+}
Index: mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
===================================================================
--- mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -6,7 +6,7 @@
   %end = constant 0 : index
   // CHECK: omp.parallel
   omp.parallel {
-    // CHECK-NEXT: llvm.br ^[[BB1:.*]](%{{[0-9]+}}, %{{[0-9]+}} : !llvm.i64, !llvm.i64
+    // CHECK: llvm.br ^[[BB1:.*]](%{{[0-9]+}}, %{{[0-9]+}} : !llvm.i64, !llvm.i64
     br ^bb1(%start, %end : index, index)
   // CHECK-NEXT: ^[[BB1]](%[[ARG1:[0-9]+]]: !llvm.i64, %[[ARG2:[0-9]+]]: !llvm.i64):{{.*}}
   ^bb1(%0: index, %1: index):
Index: llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
===================================================================
--- llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -60,25 +60,6 @@
   DebugLoc DL;
 };
 
-// Returns the value stored in the given allocation. Returns null if the given
-// value is not a result of an allocation, if no value is stored or if there is
-// more than one store.
-static Value *findStoredValue(Value *AllocaValue) {
-  Instruction *Alloca = dyn_cast<AllocaInst>(AllocaValue);
-  if (!Alloca)
-    return nullptr;
-  StoreInst *Store = nullptr;
-  for (Use &U : Alloca->uses()) {
-    if (auto *CandidateStore = dyn_cast<StoreInst>(U.getUser())) {
-      EXPECT_EQ(Store, nullptr);
-      Store = CandidateStore;
-    }
-  }
-  if (!Store)
-    return nullptr;
-  return Store->getValueOperand();
-}
-
 TEST_F(OpenMPIRBuilderTest, CreateBarrier) {
   OpenMPIRBuilder OMPBuilder(*M);
   OMPBuilder.initialize();
@@ -425,7 +406,6 @@
   EXPECT_EQ(ForkCI->getArgOperand(1),
             ConstantInt::get(Type::getInt32Ty(Ctx), 1U));
   EXPECT_EQ(ForkCI->getArgOperand(2), Usr);
-  EXPECT_EQ(findStoredValue(ForkCI->getArgOperand(3)), F->arg_begin());
 }
 
 TEST_F(OpenMPIRBuilderTest, ParallelNested) {
@@ -739,15 +719,13 @@
   EXPECT_TRUE(isa<GlobalVariable>(ForkCI->getArgOperand(0)));
   EXPECT_EQ(ForkCI->getArgOperand(1),
             ConstantInt::get(Type::getInt32Ty(Ctx), 1));
-  Value *StoredForkArg = findStoredValue(ForkCI->getArgOperand(3));
-  EXPECT_EQ(StoredForkArg, F->arg_begin());
+  EXPECT_EQ(ForkCI->getArgOperand(3), F->arg_begin());
 
   EXPECT_EQ(DirectCI->getCalledFunction(), OutlinedFn);
   EXPECT_EQ(DirectCI->getNumArgOperands(), 3U);
   EXPECT_TRUE(isa<AllocaInst>(DirectCI->getArgOperand(0)));
   EXPECT_TRUE(isa<AllocaInst>(DirectCI->getArgOperand(1)));
-  Value *StoredDirectArg = findStoredValue(DirectCI->getArgOperand(2));
-  EXPECT_EQ(StoredDirectArg, F->arg_begin());
+  EXPECT_EQ(DirectCI->getArgOperand(2), F->arg_begin());
 }
 
 TEST_F(OpenMPIRBuilderTest, ParallelCancelBarrier) {
@@ -862,85 +840,6 @@
   }
 }
 
-TEST_F(OpenMPIRBuilderTest, ParallelForwardAsPointers) {
-  OpenMPIRBuilder OMPBuilder(*M);
-  OMPBuilder.initialize();
-  F->setName("func");
-  IRBuilder<> Builder(BB);
-  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
-  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
-
-  Type *I32Ty = Type::getInt32Ty(M->getContext());
-  Type *I32PtrTy = Type::getInt32PtrTy(M->getContext());
-  Type *StructTy = StructType::get(I32Ty, I32PtrTy);
-  Type *StructPtrTy = StructTy->getPointerTo();
-  Type *VoidTy = Type::getVoidTy(M->getContext());
-  FunctionCallee RetI32Func = M->getOrInsertFunction("ret_i32", I32Ty);
-  FunctionCallee TakeI32Func =
-      M->getOrInsertFunction("take_i32", VoidTy, I32Ty);
-  FunctionCallee RetI32PtrFunc = M->getOrInsertFunction("ret_i32ptr", I32PtrTy);
-  FunctionCallee TakeI32PtrFunc =
-      M->getOrInsertFunction("take_i32ptr", VoidTy, I32PtrTy);
-  FunctionCallee RetStructFunc = M->getOrInsertFunction("ret_struct", StructTy);
-  FunctionCallee TakeStructFunc =
-      M->getOrInsertFunction("take_struct", VoidTy, StructTy);
-  FunctionCallee RetStructPtrFunc =
-      M->getOrInsertFunction("ret_structptr", StructPtrTy);
-  FunctionCallee TakeStructPtrFunc =
-      M->getOrInsertFunction("take_structPtr", VoidTy, StructPtrTy);
-  Value *I32Val = Builder.CreateCall(RetI32Func);
-  Value *I32PtrVal = Builder.CreateCall(RetI32PtrFunc);
-  Value *StructVal = Builder.CreateCall(RetStructFunc);
-  Value *StructPtrVal = Builder.CreateCall(RetStructPtrFunc);
-
-  Instruction *Internal;
-  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
-                       BasicBlock &ContinuationBB) {
-    IRBuilder<>::InsertPointGuard Guard(Builder);
-    Builder.restoreIP(CodeGenIP);
-    Internal = Builder.CreateCall(TakeI32Func, I32Val);
-    Builder.CreateCall(TakeI32PtrFunc, I32PtrVal);
-    Builder.CreateCall(TakeStructFunc, StructVal);
-    Builder.CreateCall(TakeStructPtrFunc, StructPtrVal);
-  };
-  auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
-                    Value &Inner, Value *&ReplacementValue) {
-    ReplacementValue = &Inner;
-    return CodeGenIP;
-  };
-  auto FiniCB = [](InsertPointTy) {};
-
-  IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
-                                    F->getEntryBlock().getFirstInsertionPt());
-  IRBuilder<>::InsertPoint AfterIP =
-      OMPBuilder.createParallel(Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB,
-                                nullptr, nullptr, OMP_PROC_BIND_default, false);
-  Builder.restoreIP(AfterIP);
-  Builder.CreateRetVoid();
-
-  OMPBuilder.finalize();
-
-  EXPECT_FALSE(verifyModule(*M, &errs()));
-  Function *OutlinedFn = Internal->getFunction();
-
-  Type *Arg2Type = OutlinedFn->getArg(2)->getType();
-  EXPECT_TRUE(Arg2Type->isPointerTy());
-  EXPECT_EQ(Arg2Type->getPointerElementType(), I32Ty);
-
-  // Arguments that need to be passed through pointers and reloaded will get
-  // used earlier in the functions and therefore will appear first in the
-  // argument list after outlining.
-  Type *Arg3Type = OutlinedFn->getArg(3)->getType();
-  EXPECT_TRUE(Arg3Type->isPointerTy());
-  EXPECT_EQ(Arg3Type->getPointerElementType(), StructTy);
-
-  Type *Arg4Type = OutlinedFn->getArg(4)->getType();
-  EXPECT_EQ(Arg4Type, I32PtrTy);
-
-  Type *Arg5Type = OutlinedFn->getArg(5)->getType();
-  EXPECT_EQ(Arg5Type, StructPtrTy);
-}
-
 TEST_F(OpenMPIRBuilderTest, CanonicalLoopSimple) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);
@@ -1437,4 +1336,90 @@
   EXPECT_EQ(SingleEndCI->getArgOperand(1), SingleEntryCI->getArgOperand(1));
 }
 
+// Non-pointer values defined outside the parallel region are packed into a
+// "CapturedStructType" struct passed to the outlined function as argument 2.
+TEST_F(OpenMPIRBuilderTest, ParallelCaptureUpperDefinedParameters) {
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+
+  Type *I32Ty = Type::getInt32Ty(M->getContext());
+  Type *I32PtrTy = Type::getInt32PtrTy(M->getContext());
+  Type *StructTy = StructType::get(I32Ty, I32PtrTy);
+  Type *StructPtrTy = StructTy->getPointerTo();
+  Type *VoidTy = Type::getVoidTy(M->getContext());
+  FunctionCallee RetI32Func = M->getOrInsertFunction("ret_i32", I32Ty);
+  FunctionCallee TakeI32Func =
+      M->getOrInsertFunction("take_i32", VoidTy, I32Ty);
+  FunctionCallee RetI32PtrFunc = M->getOrInsertFunction("ret_i32ptr", I32PtrTy);
+  FunctionCallee TakeI32PtrFunc =
+      M->getOrInsertFunction("take_i32ptr", VoidTy, I32PtrTy);
+  FunctionCallee RetStructFunc = M->getOrInsertFunction("ret_struct", StructTy);
+  FunctionCallee TakeStructFunc =
+      M->getOrInsertFunction("take_struct", VoidTy, StructTy);
+  FunctionCallee RetStructPtrFunc =
+      M->getOrInsertFunction("ret_structptr", StructPtrTy);
+  FunctionCallee TakeStructPtrFunc =
+      M->getOrInsertFunction("take_structPtr", VoidTy, StructPtrTy);
+  Value *I32Val = Builder.CreateCall(RetI32Func);
+  Value *I32PtrVal = Builder.CreateCall(RetI32PtrFunc);
+  Value *StructVal = Builder.CreateCall(RetStructFunc);
+  Value *StructPtrVal = Builder.CreateCall(RetStructPtrFunc);
+
+  // Set by the body callback; used to locate the outlined function.
+  Instruction *Internal = nullptr;
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+                       BasicBlock &ContinuationBB) {
+    IRBuilder<>::InsertPointGuard Guard(Builder);
+    Builder.restoreIP(CodeGenIP);
+    Internal = Builder.CreateCall(TakeI32Func, I32Val);
+    Builder.CreateCall(TakeI32PtrFunc, I32PtrVal);
+    Builder.CreateCall(TakeStructFunc, StructVal);
+    Builder.CreateCall(TakeStructPtrFunc, StructPtrVal);
+  };
+  auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
+                    Value &Inner, Value *&ReplacementValue) {
+    ReplacementValue = &Inner;
+    return CodeGenIP;
+  };
+  auto FiniCB = [](InsertPointTy) {};
+
+  IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
+                                    F->getEntryBlock().getFirstInsertionPt());
+  IRBuilder<>::InsertPoint AfterIP =
+      OMPBuilder.createParallel(Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB,
+                                nullptr, nullptr, OMP_PROC_BIND_default, false);
+
+  Builder.restoreIP(AfterIP);
+  Builder.CreateRetVoid();
+
+  OMPBuilder.finalize();
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+  ASSERT_NE(Internal, nullptr);
+  Function *OutlinedFn = Internal->getFunction();
+
+  // Argument 2 must be a pointer to the named capture struct type.
+  Type *Arg2Type = OutlinedFn->getArg(2)->getType();
+  ASSERT_TRUE(Arg2Type->isPointerTy());
+  Type *StructElemTy = Arg2Type->getPointerElementType();
+  ASSERT_TRUE(StructElemTy->isStructTy());
+  EXPECT_EQ(StructElemTy->getStructName(), "CapturedStructType");
+  ASSERT_EQ(StructElemTy->getStructNumElements(), 2U);
+
+  // Element 0 is the captured i32, element 1 the captured { i32, i32* }.
+  StructType *StructTypeTy = cast<StructType>(StructElemTy);
+  EXPECT_TRUE(StructTypeTy->getElementType(0)->isIntegerTy(32));
+  ASSERT_TRUE(StructTypeTy->getElementType(1)->isStructTy());
+  StructType *InnerStructType =
+      cast<StructType>(StructTypeTy->getElementType(1));
+  EXPECT_TRUE(InnerStructType->getElementType(0)->isIntegerTy(32));
+  ASSERT_TRUE(InnerStructType->getElementType(1)->isPointerTy());
+  EXPECT_TRUE(
+      InnerStructType->getElementType(1)->getPointerElementType()->isIntegerTy(
+          32));
+}
 } // namespace
Index: llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
===================================================================
--- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -455,10 +455,6 @@
   BasicBlock *InsertBB = Builder.GetInsertBlock();
   Function *OuterFn = InsertBB->getParent();
 
-  // Save the outer alloca block because the insertion iterator may get
-  // invalidated and we still need this later.
-  BasicBlock *OuterAllocaBlock = OuterAllocaIP.getBlock();
-
   // Vector to remember instructions we used only during the modeling but which
   // we want to delete at the end.
   SmallVector<Instruction *, 4> ToBeDeleted;
@@ -526,8 +522,7 @@
 
   // Add some fake uses for OpenMP provided arguments.
   ToBeDeleted.push_back(Builder.CreateLoad(TIDAddr, "tid.addr.use"));
-  Instruction *ZeroAddrUse = Builder.CreateLoad(ZeroAddr, "zero.addr.use");
-  ToBeDeleted.push_back(ZeroAddrUse);
+  ToBeDeleted.push_back(Builder.CreateLoad(ZeroAddr, "zero.addr.use"));
 
   // ThenBB
   //   |
@@ -692,41 +687,19 @@
   FunctionCallee TIDRTLFn =
       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num);
 
+  // Capture the outer parameters for the ParallelRegions.
+  captureParallelRegionParameters(InsertBB->getTerminator(), OuterFn, Blocks,
+                                  TIDAddr, ZeroAddr);
+
   auto PrivHelper = [&](Value &V) {
     if (&V == TIDAddr || &V == ZeroAddr)
       return;
 
-    SetVector<Use *> Uses;
+    SmallVector<Use *, 8> Uses;
     for (Use &U : V.uses())
       if (auto *UserI = dyn_cast<Instruction>(U.getUser()))
         if (ParallelRegionBlockSet.count(UserI->getParent()))
-          Uses.insert(&U);
-
-    // __kmpc_fork_call expects extra arguments as pointers. If the input
-    // already has a pointer type, everything is fine. Otherwise, store the
-    // value onto stack and load it back inside the to-be-outlined region. This
-    // will ensure only the pointer will be passed to the function.
-    // FIXME: if there are more than 15 trailing arguments, they must be
-    // additionally packed in a struct.
-    Value *Inner = &V;
-    if (!V.getType()->isPointerTy()) {
-      IRBuilder<>::InsertPointGuard Guard(Builder);
-      LLVM_DEBUG(llvm::dbgs() << "Forwarding input as pointer: " << V << "\n");
-
-      Builder.restoreIP(OuterAllocaIP);
-      Value *Ptr =
-          Builder.CreateAlloca(V.getType(), nullptr, V.getName() + ".reloaded");
-
-      // Store to stack at end of the block that currently branches to the entry
-      // block of the to-be-outlined region.
-      Builder.SetInsertPoint(InsertBB,
-                             InsertBB->getTerminator()->getIterator());
-      Builder.CreateStore(&V, Ptr);
-
-      // Load back next to allocations in the to-be-outlined region.
-      Builder.restoreIP(InnerAllocaIP);
-      Inner = Builder.CreateLoad(Ptr);
-    }
+          Uses.push_back(&U);
 
     Value *ReplacementValue = nullptr;
     CallInst *CI = dyn_cast<CallInst>(&V);
@@ -734,7 +707,7 @@
       ReplacementValue = PrivTID;
     } else {
       Builder.restoreIP(
-          PrivCB(InnerAllocaIP, Builder.saveIP(), V, *Inner, ReplacementValue));
+          PrivCB(InnerAllocaIP, Builder.saveIP(), V, V, ReplacementValue));
       assert(ReplacementValue &&
              "Expected copy/create callback to set replacement value!");
       if (ReplacementValue == &V)
@@ -745,20 +718,6 @@
       UPtr->set(ReplacementValue);
   };
 
-  // Reset the inner alloca insertion as it will be used for loading the values
-  // wrapped into pointers before passing them into the to-be-outlined region.
-  // Configure it to insert immediately after the fake use of zero address so
-  // that they are available in the generated body and so that the
-  // OpenMP-related values (thread ID and zero address pointers) remain leading
-  // in the argument list.
-  InnerAllocaIP = IRBuilder<>::InsertPoint(
-      ZeroAddrUse->getParent(), ZeroAddrUse->getNextNode()->getIterator());
-
-  // Reset the outer alloca insertion point to the entry of the relevant block
-  // in case it was invalidated.
-  OuterAllocaIP = IRBuilder<>::InsertPoint(
-      OuterAllocaBlock, OuterAllocaBlock->getFirstInsertionPt());
-
   for (Value *Input : Inputs) {
     LLVM_DEBUG(dbgs() << "Captured input: " << *Input << "\n");
     PrivHelper(*Input);
@@ -785,6 +744,90 @@
   return AfterIP;
 }
 
+void OpenMPIRBuilder::captureParallelRegionParameters(
+    Instruction *InsertBeforeInst, Function *OuterFn,
+    const SmallVectorImpl<BasicBlock *> &Blocks, const Value *TIDAddr,
+    const Value *ZeroAddr) {
+  // Membership set of the parallel region blocks.
+  SetVector<BasicBlock *> BlockParents;
+  BlockParents.insert(Blocks.begin(), Blocks.end());
+
+  // Collect the non-pointer values that are used inside the region but
+  // defined by an instruction outside of it.
+  SetVector<Value *> CapturedValues;
+  for (BasicBlock *ParallelRegionBlock : Blocks) {
+    for (Instruction &I : *ParallelRegionBlock) {
+      for (Use &U : I.operands()) {
+        Value *V = U.get();
+        // The OpenMP thread ID and zero addresses are never captured.
+        if (V == TIDAddr || V == ZeroAddr)
+          continue;
+
+        // Pointer-typed values are not packed into the capture struct.
+        if (V->getType()->isPointerTy())
+          continue;
+
+        // Constants and arguments have no defining instruction, and detached
+        // instructions have no parent block; none of them are captured here.
+        auto *DefInst = dyn_cast<Instruction>(V);
+        if (!DefInst || !DefInst->getParent())
+          continue;
+
+        // A definition outside the region block set lives in an enclosing
+        // (upper) block, so the value has to be captured.
+        if (!BlockParents.contains(DefInst->getParent()))
+          CapturedValues.insert(V);
+      }
+    }
+  }
+
+  // If there are captured values, allocate the capture struct on the stack
+  // and store the values into it. Then load the struct, extract the elements
+  // and replace the in-region uses of the captured values with them.
+  if (CapturedValues.empty())
+    return;
+
+  // Create the struct type; element I has the type of captured value I.
+  SmallVector<Type *, 8> StructTypes;
+  StructTypes.reserve(CapturedValues.size());
+  for (Value *Captured : CapturedValues)
+    StructTypes.push_back(Captured->getType());
+
+  Type *CaptureStructType =
+      StructType::create(StructTypes, "CapturedStructType");
+
+  AllocaInst *CaptureAlloca;
+  {
+    llvm::IRBuilder<>::InsertPointGuard Guard(Builder);
+    Builder.SetInsertPoint(InsertBeforeInst);
+
+    // Allocate and populate the capture struct.
+    CaptureAlloca =
+        Builder.CreateAlloca(CaptureStructType, nullptr, "CaptureStructAlloca");
+    Value *InsertValue = UndefValue::get(CaptureStructType);
+    for (auto SrcIdx : enumerate(CapturedValues))
+      InsertValue = Builder.CreateInsertValue(InsertValue, SrcIdx.value(),
+                                              SrcIdx.index());
+    Builder.CreateStore(InsertValue, CaptureAlloca);
+  }
+
+  // The reload is emitted at the builder's current insertion point, which the
+  // guard above has restored.
+  Value *LoadedAlloca = Builder.CreateLoad(CaptureAlloca);
+  for (auto SrcIdx : enumerate(CapturedValues)) {
+    Value *LoadedValue =
+        Builder.CreateExtractValue(LoadedAlloca, SrcIdx.index());
+
+    // Rewrite only the uses whose user instruction lives in one of the
+    // parallel region blocks. Walking the value's use list avoids rescanning
+    // every region instruction once per captured value.
+    SrcIdx.value()->replaceUsesWithIf(LoadedValue, [&](Use &U) {
+      auto *UserI = dyn_cast<Instruction>(U.getUser());
+      return UserI && BlockParents.contains(UserI->getParent());
+    });
+  }
+}
+
 void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) {
   // Build call void __kmpc_flush(ident_t *loc)
   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
Index: llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -124,9 +124,6 @@
   ///              an equivalent but different value.
   /// \param ReplVal The replacement value, thus a copy or new created version
   ///                of \p Inner.
-  ///
-  /// \returns The new insertion point where code generation continues and
-  ///          \p ReplVal the replacement value.
   using PrivatizeCallbackTy = function_ref<InsertPointTy(
       InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &Original,
       Value &Inner, Value *&ReplVal)>;
@@ -677,6 +674,19 @@
                                         BasicBlock *PreInsertBefore,
                                         BasicBlock *PostInsertBefore,
                                         const Twine &Name = {});
+
+  /// Capture values defined outside the parallel region (non-pointers only).
+  ///
+  /// \param InsertBeforeInst The instruction before which the capture
+  /// alloca, insertvalue chain and store are emitted.
+  /// \param OuterFn The function containing the parallel region (unused).
+  /// \param Blocks The parallel region blocks whose uses are rewritten.
+  /// \param TIDAddr The thread ID address; never captured.
+  /// \param ZeroAddr The zero address; never captured.
+  void captureParallelRegionParameters(
+      Instruction *InsertBeforeInst, Function *OuterFn,
+      const SmallVectorImpl<BasicBlock *> &Blocks, const Value *const TIDAddr,
+      const Value *const ZeroAddr);
 };
 
 /// Class to represented the control flow structure of an OpenMP canonical loop.
Index: clang/test/OpenMP/parallel_codegen.cpp
===================================================================
--- clang/test/OpenMP/parallel_codegen.cpp
+++ clang/test/OpenMP/parallel_codegen.cpp
@@ -133,26 +133,22 @@
 // CHECK-DEBUG-DAG:       define internal void [[OMP_OUTLINED]](i32* noalias %.global_tid., i32* noalias %.bound_tid., i64 [[VLA_SIZE:%.+]], i32* {{.+}} [[VLA_ADDR:%[^)]+]])
 // CHECK-DEBUG-DAG:       call void [[OMP_OUTLINED_DEBUG]]
 
-// Note that OpenMPIRBuilder puts the trailing arguments in a different order:
-// arguments that are wrapped into additional pointers precede the other
-// arguments. This is expected and not problematic because both the call and the
-// function are generated from the same place, and the function is internal.
 // ALL:       define linkonce_odr {{[a-z\_\b]*[ ]?i32}} [[TMAIN]](i8** %argc)
 // ALL:       store i8** %argc, i8*** [[ARGC_ADDR:%.+]],
 // CHECK:       call {{.*}}void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* [[DEF_LOC_2]], i32 2, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i8***, i{{64|32}})* [[OMP_OUTLINED:@.+]] to void (i32*, i32*, ...)*), i8*** [[ARGC_ADDR]], i{{64|32}} %{{.+}})
-// IRBUILDER:   call {{.*}}void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* [[DEF_LOC_2]], i32 2, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i{{64|32}}*, i8***)* [[OMP_OUTLINED:@.+]] to void (i32*, i32*, ...)*), i{{64|32}}* %{{.+}}, i8*** [[ARGC_ADDR]])
+// IRBUILDER:   call {{.*}}void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* [[DEF_LOC_2]], i32 2, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, %CapturedStructType*, i8***)* [[OMP_OUTLINED:@.+]] to void (i32*, i32*, ...)*), %CapturedStructType* %CaptureStructAlloca, i8*** [[ARGC_ADDR]])
 // ALL:  ret i32 0
 // ALL-NEXT:  }
 // ALL-DEBUG:       define linkonce_odr i32 [[TMAIN]](i8** %argc)
 
 // CHECK-DEBUG:       store i8** %argc, i8*** [[ARGC_ADDR:%.+]],
 // CHECK-DEBUG:       call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* @{{.*}}, i32 2, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i8***, i64)* [[OMP_OUTLINED:@.+]] to void (i32*, i32*, ...)*), i8*** [[ARGC_ADDR]], i64 %{{.+}})
-// IRBUILDER-DEBUG:   call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* @{{.*}}, i32 2, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i64*, i8***)* [[OMP_OUTLINED:@.+]] to void (i32*, i32*, ...)*), i64* %{{.+}}, i8*** [[ARGC_ADDR]])
+// IRBUILDER-DEBUG:   call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* @{{.*}}, i32 2, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, %CapturedStructType*, i8***)* [[OMP_OUTLINED:@.+]] to void (i32*, i32*, ...)*), %CapturedStructType* %CaptureStructAlloca, i8*** [[ARGC_ADDR]])
 // ALL-DEBUG:  ret i32 0
 // ALL-DEBUG-NEXT:  }
 
 // CHECK:       define internal {{.*}}void [[OMP_OUTLINED]](i32* noalias %.global_tid., i32* noalias %.bound_tid., i8*** nonnull align {{[0-9]+}} dereferenceable({{4|8}}) %argc, i{{64|32}}{{.*}} %{{.+}})
-// IRBUILDER:   define internal {{.*}}void [[OMP_OUTLINED]](i32* noalias %{{.*}}, i32* noalias %{{.*}}, i{{64|32}}*{{.*}} %{{.+}}, i8*** [[ARGC_REF:%.*]])
+// IRBUILDER:   define internal {{.*}}void [[OMP_OUTLINED]](i32* noalias %{{.*}}, i32* noalias %{{.*}}, %CapturedStructType* %CaptureStructAlloca, i8*** [[ARGC_REF:%.*]])
 // CHECK:       store i8*** %argc, i8**** [[ARGC_PTR_ADDR:%.+]],
 // CHECK:       [[ARGC_REF:%.+]] = load i8***, i8**** [[ARGC_PTR_ADDR]]
 // ALL:         [[ARGC:%.+]] = load i8**, i8*** [[ARGC_REF]]
@@ -163,7 +159,7 @@
 // CHECK-NEXT:  unreachable
 // CHECK-NEXT:  }
 // CHECK-DEBUG:       define internal void [[OMP_OUTLINED_DEBUG:@.+]](i32* noalias %.global_tid., i32* noalias %.bound_tid., i8*** nonnull align {{[0-9]+}} dereferenceable({{4|8}}) %argc, i64 %{{.+}})
-// IRBUILDER-DEBUG:   define internal void [[OMP_OUTLINED_DEBUG:@.+]](i32* noalias %{{.*}}, i32* noalias %{{.*}}, i64* %{{.+}}, i8*** [[ARGC_REF:%.*]])
+// IRBUILDER-DEBUG:   define internal void [[OMP_OUTLINED_DEBUG:@.+]](i32* noalias %{{.*}}, i32* noalias %{{.*}}, %CapturedStructType* %CaptureStructAlloca, i8*** [[ARGC_REF:%.*]])
 // CHECK-DEBUG:       store i8*** %argc, i8**** [[ARGC_PTR_ADDR:%.+]],
 // CHECK-DEBUG:       [[ARGC_REF:%.+]] = load i8***, i8**** [[ARGC_PTR_ADDR]]
 // ALL-DEBUG:         [[ARGC:%.+]] = load i8**, i8*** [[ARGC_REF]]
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to