ftynse updated this revision to Diff 308643.
ftynse added a comment.
Herald added a project: clang.
Herald added a subscriber: cfe-commits.

Fix clang tests. The order of arguments is switched in the internal outlined 
function.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D92189/new/

https://reviews.llvm.org/D92189

Files:
  clang/test/OpenMP/parallel_codegen.cpp
  llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
  llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Index: llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
===================================================================
--- llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -60,6 +60,25 @@
   DebugLoc DL;
 };
 
+// Returns the value stored in the given allocation, or null if the given value
+// is not an alloca or if nothing is stored to it. More than one store is
+// reported as a non-fatal test failure; the last store visited is then used.
+static Value *findStoredValue(Value *AllocaValue) {
+  auto *Alloca = dyn_cast<AllocaInst>(AllocaValue);
+  if (!Alloca)
+    return nullptr;
+  StoreInst *Store = nullptr;
+  for (Use &U : Alloca->uses()) {
+    if (auto *CandidateStore = dyn_cast<StoreInst>(U.getUser())) {
+      EXPECT_EQ(Store, nullptr);
+      Store = CandidateStore;
+    }
+  }
+  if (!Store)
+    return nullptr;
+  return Store->getValueOperand();
+}
+
 TEST_F(OpenMPIRBuilderTest, CreateBarrier) {
   OpenMPIRBuilder OMPBuilder(*M);
   OMPBuilder.initialize();
@@ -401,7 +420,7 @@
   EXPECT_EQ(ForkCI->getArgOperand(1),
             ConstantInt::get(Type::getInt32Ty(Ctx), 1U));
   EXPECT_EQ(ForkCI->getArgOperand(2), Usr);
-  EXPECT_EQ(ForkCI->getArgOperand(3), F->arg_begin());
+  EXPECT_EQ(findStoredValue(ForkCI->getArgOperand(3)), F->arg_begin());
 }
 
 TEST_F(OpenMPIRBuilderTest, ParallelNested) {
@@ -708,13 +727,15 @@
   EXPECT_TRUE(isa<GlobalVariable>(ForkCI->getArgOperand(0)));
   EXPECT_EQ(ForkCI->getArgOperand(1),
             ConstantInt::get(Type::getInt32Ty(Ctx), 1));
-  EXPECT_EQ(ForkCI->getArgOperand(3), F->arg_begin());
+  Value *StoredForkArg = findStoredValue(ForkCI->getArgOperand(3));
+  EXPECT_EQ(StoredForkArg, F->arg_begin());
 
   EXPECT_EQ(DirectCI->getCalledFunction(), OutlinedFn);
   EXPECT_EQ(DirectCI->getNumArgOperands(), 3U);
   EXPECT_TRUE(isa<AllocaInst>(DirectCI->getArgOperand(0)));
   EXPECT_TRUE(isa<AllocaInst>(DirectCI->getArgOperand(1)));
-  EXPECT_EQ(DirectCI->getArgOperand(2), F->arg_begin());
+  Value *StoredDirectArg = findStoredValue(DirectCI->getArgOperand(2));
+  EXPECT_EQ(StoredDirectArg, F->arg_begin());
 }
 
 TEST_F(OpenMPIRBuilderTest, ParallelCancelBarrier) {
@@ -829,6 +850,85 @@
   }
 }
 
+// Checks that non-pointer values used inside a parallel region are passed to
+// the outlined function via pointers, while pointer values keep their type.
+TEST_F(OpenMPIRBuilderTest, ParallelForwardAsPointers) {
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+
+  Type *I32Ty = Type::getInt32Ty(M->getContext());
+  Type *I32PtrTy = Type::getInt32PtrTy(M->getContext());
+  Type *StructTy = StructType::get(I32Ty, I32PtrTy);
+  Type *StructPtrTy = StructTy->getPointerTo();
+  Type *VoidTy = Type::getVoidTy(M->getContext());
+  FunctionCallee RetI32Func = M->getOrInsertFunction("ret_i32", I32Ty);
+  FunctionCallee TakeI32Func =
+      M->getOrInsertFunction("take_i32", VoidTy, I32Ty);
+  FunctionCallee RetI32PtrFunc = M->getOrInsertFunction("ret_i32ptr", I32PtrTy);
+  FunctionCallee TakeI32PtrFunc =
+      M->getOrInsertFunction("take_i32ptr", VoidTy, I32PtrTy);
+  FunctionCallee RetStructFunc = M->getOrInsertFunction("ret_struct", StructTy);
+  FunctionCallee TakeStructFunc =
+      M->getOrInsertFunction("take_struct", VoidTy, StructTy);
+  FunctionCallee RetStructPtrFunc =
+      M->getOrInsertFunction("ret_structptr", StructPtrTy);
+  FunctionCallee TakeStructPtrFunc =
+      M->getOrInsertFunction("take_structPtr", VoidTy, StructPtrTy);
+  Value *I32Val = Builder.CreateCall(RetI32Func);
+  Value *I32PtrVal = Builder.CreateCall(RetI32PtrFunc);
+  Value *StructVal = Builder.CreateCall(RetStructFunc);
+  Value *StructPtrVal = Builder.CreateCall(RetStructPtrFunc);
+
+  // Set by the body callback; used afterwards to find the outlined function.
+  Instruction *Internal = nullptr;
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+                       BasicBlock &ContinuationBB) {
+    IRBuilder<>::InsertPointGuard Guard(Builder);
+    Builder.restoreIP(CodeGenIP);
+    Internal = Builder.CreateCall(TakeI32Func, I32Val);
+    Builder.CreateCall(TakeI32PtrFunc, I32PtrVal);
+    Builder.CreateCall(TakeStructFunc, StructVal);
+    Builder.CreateCall(TakeStructPtrFunc, StructPtrVal);
+  };
+  auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+                    Value &VPtr, Value *&ReplacementValue) {
+    ReplacementValue = &VPtr;
+    return CodeGenIP;
+  };
+  auto FiniCB = [](InsertPointTy) {};
+
+  IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
+                                    F->getEntryBlock().getFirstInsertionPt());
+  IRBuilder<>::InsertPoint AfterIP =
+      OMPBuilder.createParallel(Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB,
+                                nullptr, nullptr, OMP_PROC_BIND_default, false);
+  Builder.restoreIP(AfterIP);
+  Builder.CreateRetVoid();
+
+  OMPBuilder.finalize();
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+  ASSERT_NE(Internal, nullptr);
+  Function *OutlinedFn = Internal->getFunction();
+
+  // Arguments that need to be passed through pointers and reloaded will get
+  // used earlier in the functions and therefore will appear first in the
+  // argument list after outlining.
+  Type *Arg2Type = OutlinedFn->getArg(2)->getType();
+  EXPECT_TRUE(Arg2Type->isPointerTy());
+  EXPECT_EQ(Arg2Type->getPointerElementType(), I32Ty);
+  Type *Arg3Type = OutlinedFn->getArg(3)->getType();
+  EXPECT_TRUE(Arg3Type->isPointerTy());
+  EXPECT_EQ(Arg3Type->getPointerElementType(), StructTy);
+
+  EXPECT_EQ(OutlinedFn->getArg(4)->getType(), I32PtrTy);
+  EXPECT_EQ(OutlinedFn->getArg(5)->getType(), StructPtrTy);
+}
+
 TEST_F(OpenMPIRBuilderTest, CanonicalLoopSimple) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);
Index: llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
===================================================================
--- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -522,7 +522,8 @@
 
   // Add some fake uses for OpenMP provided arguments.
   ToBeDeleted.push_back(Builder.CreateLoad(TIDAddr, "tid.addr.use"));
-  ToBeDeleted.push_back(Builder.CreateLoad(ZeroAddr, "zero.addr.use"));
+  Instruction *ZeroAddrUse = Builder.CreateLoad(ZeroAddr, "zero.addr.use");
+  ToBeDeleted.push_back(ZeroAddrUse);
 
   // ThenBB
   //   |
@@ -687,15 +688,41 @@
   FunctionCallee TIDRTLFn =
       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num);
 
+  // Define the insertion point for loading the values wrapped into pointers for
+  // passing into the to-be-outlined region. Insert them immediately after the
+  // fake use of zero address so that they are available in the generated body
+  // and so that the OpenMP-related values (thread ID and zero address pointers)
+  // remain leading in the argument list.
+  IRBuilder<>::InsertPoint ReloadIP(ZeroAddrUse->getParent(),
+                                    ZeroAddrUse->getNextNode()->getIterator());
+
   auto PrivHelper = [&](Value &V) {
     if (&V == TIDAddr || &V == ZeroAddr)
       return;
 
-    SmallVector<Use *, 8> Uses;
+    SetVector<Use *> Uses;
     for (Use &U : V.uses())
       if (auto *UserI = dyn_cast<Instruction>(U.getUser()))
         if (ParallelRegionBlockSet.count(UserI->getParent()))
-          Uses.push_back(&U);
+          Uses.insert(&U);
+
+    Value *Reloaded = nullptr;
+    if (!V.getType()->isPointerTy()) {
+      IRBuilder<>::InsertPointGuard Guard(Builder);
+
+      // Store to stack at end of the block that currently branches to the entry
+      // block of the to-be-outlined region.
+      Builder.SetInsertPoint(InsertBB,
+                             InsertBB->getTerminator()->getIterator());
+      Value *Ptr =
+          Builder.CreateAlloca(V.getType(), nullptr, V.getName() + ".reloaded");
+      Builder.CreateStore(&V, Ptr);
+
+      // Load back next to allocations in the to-be-outlined region.
+      Builder.restoreIP(ReloadIP);
+      Reloaded = Builder.CreateLoad(Ptr);
+      InnerAllocaIP = Builder.saveIP();
+    }
 
     Value *ReplacementValue = nullptr;
     CallInst *CI = dyn_cast<CallInst>(&V);
@@ -706,10 +733,47 @@
           PrivCB(InnerAllocaIP, Builder.saveIP(), V, ReplacementValue));
       assert(ReplacementValue &&
              "Expected copy/create callback to set replacement value!");
-      if (ReplacementValue == &V)
-        return;
     }
 
+    // __kmpc_fork_call expects extra arguments as pointers. If the input
+    // already has a pointer type, everything is fine, only use the replacement
+    // value inside the function. This also captures the TID case because it is
+    // passed in as a pointer.
+    if (V.getType()->isPointerTy()) {
+      if (ReplacementValue != &V)
+        for (Use *UPtr : Uses)
+          UPtr->set(ReplacementValue);
+
+      return;
+    }
+
+    // Otherwise, store the value onto stack and load it back inside the
+    // to-be-outlined region. This will ensure only the pointer will be passed
+    // to the function.
+    LLVM_DEBUG(llvm::dbgs() << "Forwarding input as pointer: " << V << "\n");
+    assert(Reloaded && "Expected non-pointer argument to be loaded back!");
+
+    // Find new uses created by the privatization.
+    SmallVector<Use *, 4> PrivatizationUses;
+    for (Use &U : V.uses()) {
+      if (Uses.contains(&U))
+        continue;
+      if (auto *UserI = dyn_cast<Instruction>(U.getUser()))
+        if (ParallelRegionBlockSet.count(UserI->getParent()))
+          PrivatizationUses.push_back(&U);
+    }
+
+    // Any uses of the original value introduced by the privatization callback
+    // should use the loaded-back value instead.
+    for (Use *UPtr : PrivatizationUses)
+      UPtr->set(Reloaded);
+
+    // Replace original uses of the value with the replacement value. If the
+    // callback returned the original value, use the loaded-back value instead
+    // because all uses of the original value in the to-be-outlined region must
+    // disappear.
+    if (ReplacementValue == &V)
+      ReplacementValue = Reloaded;
     for (Use *UPtr : Uses)
       UPtr->set(ReplacementValue);
   };
Index: clang/test/OpenMP/parallel_codegen.cpp
===================================================================
--- clang/test/OpenMP/parallel_codegen.cpp
+++ clang/test/OpenMP/parallel_codegen.cpp
@@ -136,19 +136,19 @@
 // ALL:       define linkonce_odr {{[a-z\_\b]*[ ]?i32}} [[TMAIN]](i8** %argc)
 // ALL:       store i8** %argc, i8*** [[ARGC_ADDR:%.+]],
 // CHECK:       call {{.*}}void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* [[DEF_LOC_2]], i32 2, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i8***, i{{64|32}})* [[OMP_OUTLINED:@.+]] to void (i32*, i32*, ...)*), i8*** [[ARGC_ADDR]], i{{64|32}} %{{.+}})
-// IRBUILDER:   call {{.*}}void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* [[DEF_LOC_2]], i32 2, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i8***, i{{64|32}})* [[OMP_OUTLINED:@.+]] to void (i32*, i32*, ...)*), i8*** [[ARGC_ADDR]], i{{64|32}} %{{.+}})
+// IRBUILDER:   call {{.*}}void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* [[DEF_LOC_2]], i32 2, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i{{64|32}}*, i8***)* [[OMP_OUTLINED:@.+]] to void (i32*, i32*, ...)*), i{{64|32}}* %{{.+}}, i8*** [[ARGC_ADDR]])
 // ALL:  ret i32 0
 // ALL-NEXT:  }
 // ALL-DEBUG:       define linkonce_odr i32 [[TMAIN]](i8** %argc)
 
 // CHECK-DEBUG:       store i8** %argc, i8*** [[ARGC_ADDR:%.+]],
 // CHECK-DEBUG:       call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* @{{.*}}, i32 2, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i8***, i64)* [[OMP_OUTLINED:@.+]] to void (i32*, i32*, ...)*), i8*** [[ARGC_ADDR]], i64 %{{.+}})
-// IRBUILDER-DEBUG:   call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* @{{.*}}, i32 2, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i8***, i64)* [[OMP_OUTLINED:@.+]] to void (i32*, i32*, ...)*), i8*** [[ARGC_ADDR]], i64 %{{.+}})
+// IRBUILDER-DEBUG:   call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* @{{.*}}, i32 2, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i64*, i8***)* [[OMP_OUTLINED:@.+]] to void (i32*, i32*, ...)*), i64* %{{.+}}, i8*** [[ARGC_ADDR]])
 // ALL-DEBUG:  ret i32 0
 // ALL-DEBUG-NEXT:  }
 
 // CHECK:       define internal {{.*}}void [[OMP_OUTLINED]](i32* noalias %.global_tid., i32* noalias %.bound_tid., i8*** nonnull align {{[0-9]+}} dereferenceable({{4|8}}) %argc, i{{64|32}}{{.*}} %{{.+}})
-// IRBUILDER:   define internal {{.*}}void [[OMP_OUTLINED]](i32* noalias %{{.*}}, i32* noalias %{{.*}}, i8*** [[ARGC_REF:%.*]], i{{64|32}}{{.*}} %{{.+}})
+// IRBUILDER:   define internal {{.*}}void [[OMP_OUTLINED]](i32* noalias %{{.*}}, i32* noalias %{{.*}}, i{{64|32}}*{{.*}} %{{.+}}, i8*** [[ARGC_REF:%.*]])
 // CHECK:       store i8*** %argc, i8**** [[ARGC_PTR_ADDR:%.+]],
 // CHECK:       [[ARGC_REF:%.+]] = load i8***, i8**** [[ARGC_PTR_ADDR]]
 // ALL:         [[ARGC:%.+]] = load i8**, i8*** [[ARGC_REF]]
@@ -159,7 +159,7 @@
 // CHECK-NEXT:  unreachable
 // CHECK-NEXT:  }
 // CHECK-DEBUG:       define internal void [[OMP_OUTLINED_DEBUG:@.+]](i32* noalias %.global_tid., i32* noalias %.bound_tid., i8*** nonnull align {{[0-9]+}} dereferenceable({{4|8}}) %argc, i64 %{{.+}})
-// IRBUILDER-DEBUG:   define internal void [[OMP_OUTLINED_DEBUG:@.+]](i32* noalias %{{.*}}, i32* noalias %{{.*}}, i8*** [[ARGC_REF:%.*]], i64 %{{.+}})
+// IRBUILDER-DEBUG:   define internal void [[OMP_OUTLINED_DEBUG:@.+]](i32* noalias %{{.*}}, i32* noalias %{{.*}}, i64* %{{.+}}, i8*** [[ARGC_REF:%.*]])
 // CHECK-DEBUG:       store i8*** %argc, i8**** [[ARGC_PTR_ADDR:%.+]],
 // CHECK-DEBUG:       [[ARGC_REF:%.+]] = load i8***, i8**** [[ARGC_PTR_ADDR]]
 // ALL-DEBUG:         [[ARGC:%.+]] = load i8**, i8*** [[ARGC_REF]]
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to