================ @@ -461,6 +465,198 @@ bool PreISelIntrinsicLowering::expandMemIntrinsicUses( return Changed; } +namespace { + +enum class PointerEncoding { + Rotate, + PACCopyable, + PACNonCopyable, +}; + +bool expandProtectedFieldPtr(Function &Intr) { + Module &M = *Intr.getParent(); + + std::set<GlobalValue *> DSsToDeactivate; + std::set<Instruction *> LoadsStores; + + Type *Int8Ty = Type::getInt8Ty(M.getContext()); + Type *Int64Ty = Type::getInt64Ty(M.getContext()); + PointerType *PtrTy = PointerType::get(M.getContext(), 0); + + Function *SignIntr = + Intrinsic::getOrInsertDeclaration(&M, Intrinsic::ptrauth_sign, {}); + Function *AuthIntr = + Intrinsic::getOrInsertDeclaration(&M, Intrinsic::ptrauth_auth, {}); + + auto *EmuFnTy = FunctionType::get(Int64Ty, {Int64Ty, Int64Ty}, false); + FunctionCallee EmuSignIntr = M.getOrInsertFunction("__emupac_pacda", EmuFnTy); + FunctionCallee EmuAuthIntr = M.getOrInsertFunction("__emupac_autda", EmuFnTy); + + auto CreateSign = [&](IRBuilder<> &B, Value *Val, Value *Disc, + OperandBundleDef DSBundle) { + Function *F = B.GetInsertBlock()->getParent(); + Attribute FSAttr = F->getFnAttribute("target-features"); + if (FSAttr.isValid() && FSAttr.getValueAsString().contains("+pauth")) + return B.CreateCall(SignIntr, {Val, B.getInt32(2), Disc}, DSBundle); + return B.CreateCall(EmuSignIntr, {Val, Disc}, DSBundle); + }; + + auto CreateAuth = [&](IRBuilder<> &B, Value *Val, Value *Disc, + OperandBundleDef DSBundle) { + Function *F = B.GetInsertBlock()->getParent(); + Attribute FSAttr = F->getFnAttribute("target-features"); + if (FSAttr.isValid() && FSAttr.getValueAsString().contains("+pauth")) + return B.CreateCall(AuthIntr, {Val, B.getInt32(2), Disc}, DSBundle); + return B.CreateCall(EmuAuthIntr, {Val, Disc}, DSBundle); + }; + + auto GetDeactivationSymbol = [&](CallInst *Call) -> GlobalValue * { + if (auto Bundle = + Call->getOperandBundle(LLVMContext::OB_deactivation_symbol)) + return cast<GlobalValue>(Bundle->Inputs[0]); + return 
nullptr; + }; + + for (User *U : Intr.users()) { + auto *Call = cast<CallInst>(U); + auto *DS = GetDeactivationSymbol(Call); + std::set<PHINode *> VisitedPhis; + + std::function<void(Instruction *)> FindLoadsStores; + FindLoadsStores = [&](Instruction *I) { + for (Use &U : I->uses()) { + if (auto *LI = dyn_cast<LoadInst>(U.getUser())) { + if (isa<PointerType>(LI->getType())) { + LoadsStores.insert(LI); + continue; + } + } + if (auto *SI = dyn_cast<StoreInst>(U.getUser())) { + if (U.getOperandNo() == 1 && + isa<PointerType>(SI->getValueOperand()->getType())) { + LoadsStores.insert(SI); + continue; + } + } + if (auto *P = dyn_cast<PHINode>(U.getUser())) { + if (VisitedPhis.insert(P).second) + FindLoadsStores(P); + continue; + } + // Comparisons against null cannot be used to recover the original + // pointer so we allow them. + if (auto *CI = dyn_cast<ICmpInst>(U.getUser())) { ---------------- pcc wrote:
Also noticed that this was missing a test, so I added one: https://github.com/llvm/llvm-project/pull/151647 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits