================
@@ -461,6 +465,198 @@ bool PreISelIntrinsicLowering::expandMemIntrinsicUses(
   return Changed;
 }
 
+namespace {
+
+enum class PointerEncoding { // NOTE(review): presumably the encoding scheme applied to a protected field pointer; not referenced in the visible part of this hunk - confirm against its uses
+  Rotate,
+  PACCopyable,
+  PACNonCopyable,
+};
+
+bool expandProtectedFieldPtr(Function &Intr) {
+  Module &M = *Intr.getParent();
+
+  std::set<GlobalValue *> DSsToDeactivate;
+  std::set<Instruction *> LoadsStores;
+
+  Type *Int8Ty = Type::getInt8Ty(M.getContext());
+  Type *Int64Ty = Type::getInt64Ty(M.getContext());
+  PointerType *PtrTy = PointerType::get(M.getContext(), 0);
+
+  Function *SignIntr =
+      Intrinsic::getOrInsertDeclaration(&M, Intrinsic::ptrauth_sign, {});
+  Function *AuthIntr =
+      Intrinsic::getOrInsertDeclaration(&M, Intrinsic::ptrauth_auth, {});
+
+  auto *EmuFnTy = FunctionType::get(Int64Ty, {Int64Ty, Int64Ty}, false);
+  FunctionCallee EmuSignIntr = M.getOrInsertFunction("__emupac_pacda", EmuFnTy);
+  FunctionCallee EmuAuthIntr = M.getOrInsertFunction("__emupac_autda", EmuFnTy);
+
+  auto CreateSign = [&](IRBuilder<> &B, Value *Val, Value *Disc,
+                       OperandBundleDef DSBundle) {
+    Function *F = B.GetInsertBlock()->getParent();
+    Attribute FSAttr = F->getFnAttribute("target-features");
+    if (FSAttr.isValid() && FSAttr.getValueAsString().contains("+pauth"))
+      return B.CreateCall(SignIntr, {Val, B.getInt32(2), Disc}, DSBundle);
+    return B.CreateCall(EmuSignIntr, {Val, Disc}, DSBundle);
+  };
+
+  auto CreateAuth = [&](IRBuilder<> &B, Value *Val, Value *Disc,
+                       OperandBundleDef DSBundle) {
+    Function *F = B.GetInsertBlock()->getParent();
+    Attribute FSAttr = F->getFnAttribute("target-features");
+    if (FSAttr.isValid() && FSAttr.getValueAsString().contains("+pauth"))
+      return B.CreateCall(AuthIntr, {Val, B.getInt32(2), Disc}, DSBundle);
+    return B.CreateCall(EmuAuthIntr, {Val, Disc}, DSBundle);
+  };
+
+  auto GetDeactivationSymbol = [&](CallInst *Call) -> GlobalValue * {
+    if (auto Bundle =
+            Call->getOperandBundle(LLVMContext::OB_deactivation_symbol))
+      return cast<GlobalValue>(Bundle->Inputs[0]);
+    return nullptr;
+  };
+
+  for (User *U : Intr.users()) {
+    auto *Call = cast<CallInst>(U);
+    auto *DS = GetDeactivationSymbol(Call);
+    std::set<PHINode *> VisitedPhis;
+
+    std::function<void(Instruction *)> FindLoadsStores;
+    FindLoadsStores = [&](Instruction *I) {
+      for (Use &U : I->uses()) {
+        if (auto *LI = dyn_cast<LoadInst>(U.getUser())) {
+          if (isa<PointerType>(LI->getType())) {
+            LoadsStores.insert(LI);
+            continue;
+          }
+        }
+        if (auto *SI = dyn_cast<StoreInst>(U.getUser())) {
+          if (U.getOperandNo() == 1 &&
+              isa<PointerType>(SI->getValueOperand()->getType())) {
+            LoadsStores.insert(SI);
+            continue;
+          }
+        }
+        if (auto *P = dyn_cast<PHINode>(U.getUser())) {
+          if (VisitedPhis.insert(P).second)
+            FindLoadsStores(P);
+          continue;
+        }
+        // Comparisons against null cannot be used to recover the original
+        // pointer so we allow them.
+        if (auto *CI = dyn_cast<ICmpInst>(U.getUser())) {
----------------
pcc wrote:

I also noticed that this was missing a test, so I added one.

https://github.com/llvm/llvm-project/pull/151647
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to