================ @@ -339,6 +369,183 @@ class AArch64MCPlusBuilder : public MCPlusBuilder { } } + std::optional<std::pair<MCPhysReg, MCInst *>> + getAuthCheckedReg(BinaryBasicBlock &BB) const override { + // Match several possible hard-coded sequences of instructions which can be + // emitted by LLVM backend to check that the authenticated pointer is + // correct (see AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue). + // + // This function only matches sequences involving branch instructions. + // All these sequences have the form: + // + // (0) ... regular code that authenticates a pointer in Xn ... + // (1) analyze Xn + // (2) branch to .Lon_success if the pointer is correct + // (3) BRK #imm (fall-through basic block) + // + // In the above pseudocode, (1) + (2) is one of the following sequences: + // + // - eor Xtmp, Xn, Xn, lsl #1 + // tbz Xtmp, #62, .Lon_success + // + // - mov Xtmp, Xn + // xpac(i|d) Xn (or xpaclri if Xn is LR) + // cmp Xtmp, Xn + // b.eq .Lon_success + // + // Note that any branch destination operand is accepted as .Lon_success - + // it is the responsibility of the caller of getAuthCheckedReg to inspect + // the list of successors of this basic block as appropriate. + + // Any of the above code sequences assume the fall-through basic block + // is a dead-end BRK instruction (any immediate operand is accepted). + const BinaryBasicBlock *BreakBB = BB.getFallthrough(); + if (!BreakBB || BreakBB->empty() || + BreakBB->front().getOpcode() != AArch64::BRK) + return std::nullopt; + + // Iterate over the instructions of BB in reverse order, matching opcodes + // and operands. + MCPhysReg TestedReg = 0; + MCPhysReg ScratchReg = 0; + auto It = BB.end(); + auto StepAndGetOpcode = [&It, &BB]() -> int { + if (It == BB.begin()) + return -1; + --It; + return It->getOpcode(); + }; + + switch (StepAndGetOpcode()) { + default: + // Not matched the branch instruction. 
+ return std::nullopt; + case AArch64::Bcc: + // Bcc EQ, .Lon_success + if (It->getOperand(0).getImm() != AArch64CC::EQ) + return std::nullopt; + // Not checking .Lon_success (see above). + + // SUBSXrs XZR, TestedReg, ScratchReg, 0 (used by "CMP reg, reg" alias) + if (StepAndGetOpcode() != AArch64::SUBSXrs || + It->getOperand(0).getReg() != AArch64::XZR || + It->getOperand(3).getImm() != 0) + return std::nullopt; + TestedReg = It->getOperand(1).getReg(); + ScratchReg = It->getOperand(2).getReg(); + + // Either XPAC(I|D) ScratchReg, ScratchReg + // or XPACLRI + switch (StepAndGetOpcode()) { + default: + return std::nullopt; + case AArch64::XPACLRI: + // No operands to check, but using XPACLRI forces TestedReg to be X30. + if (TestedReg != AArch64::LR) + return std::nullopt; + break; + case AArch64::XPACI: + case AArch64::XPACD: + if (It->getOperand(0).getReg() != ScratchReg || + It->getOperand(1).getReg() != ScratchReg) + return std::nullopt; + break; + } + + // ORRXrs ScratchReg, XZR, TestedReg, 0 (used by "MOV reg, reg" alias) + if (StepAndGetOpcode() != AArch64::ORRXrs) + return std::nullopt; + if (It->getOperand(0).getReg() != ScratchReg || + It->getOperand(1).getReg() != AArch64::XZR || + It->getOperand(2).getReg() != TestedReg || + It->getOperand(3).getImm() != 0) + return std::nullopt; + + return std::make_pair(TestedReg, &*It); + + case AArch64::TBZX: + // TBZX ScratchReg, 62, .Lon_success + ScratchReg = It->getOperand(0).getReg(); + if (It->getOperand(1).getImm() != 62) + return std::nullopt; + // Not checking .Lon_success (see above). 
+ + // EORXrs ScratchReg, TestedReg, TestedReg, 1 + if (StepAndGetOpcode() != AArch64::EORXrs) + return std::nullopt; + TestedReg = It->getOperand(1).getReg(); + if (It->getOperand(0).getReg() != ScratchReg || + It->getOperand(2).getReg() != TestedReg || + It->getOperand(3).getImm() != 1) + return std::nullopt; + + return std::make_pair(TestedReg, &*It); + } + } + + MCPhysReg getAuthCheckedReg(const MCInst &Inst, + bool MayOverwrite) const override { + // Cannot trivially reuse AArch64InstrInfo::getMemOperandWithOffsetWidth() + // method as it accepts an instance of MachineInstr, not MCInst. + const MCInstrDesc &Desc = Info->get(Inst.getOpcode()); + + // If signing oracles are considered, the particular value left in the base + // register after this instruction is important. This function checks that + // if the base register was overwritten, it is due to address write-back. + // + // Note that this function is not needed for authentication oracles, as the + // particular value left in the register after a successful memory access + // is not important. + auto ClobbersBaseRegExceptWriteback = [&](unsigned BaseRegUseIndex) { + MCPhysReg BaseReg = Inst.getOperand(BaseRegUseIndex).getReg(); + unsigned WrittenBackDefIndex = Desc.getOperandConstraint( + BaseRegUseIndex, MCOI::OperandConstraint::TIED_TO); ---------------- atrosinenko wrote:
I think I now get the idea, but it looks like adding these helpers to `AArch64InstrInfo` properly would require some refactoring. As I understand it, ideally we would add something like

```cpp
class AArch64InstrInfo {
  // ...
  unsigned getLdStBaseOpIndex(unsigned Opc);
  unsigned getLdStWritebackOpIndex(unsigned Opc);
  // ...
};
```

There are already [`static` member functions](https://github.com/llvm/llvm-project/blob/96519028d514853d429c2d09482ba0bd9a899c57/llvm/lib/Target/AArch64/AArch64InstrInfo.h#L249) like these:

```cpp
/// Returns the base register operator of a load/store.
static const MachineOperand &getLdStBaseOp(const MachineInstr &MI);
```

This looks like one of the two functions we need, except that we don't have a `MachineInstr`. For this PR, it would have to be changed the same way as the `getMemScale` functions above, i.e. given an overload that takes an opcode:

```cpp
/// Scaling factor for (scaled or unscaled) load or store.
static int getMemScale(unsigned Opc);
static int getMemScale(const MachineInstr &MI) {
  return getMemScale(MI.getOpcode());
}
```

However, judging by the [definition](https://github.com/llvm/llvm-project/blob/cc6def4b7521676fd339936d027e48928e0ba398/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp#L4652) of `getLdStBaseOp`, it seems to handle only a subset of the possible cases.

I could add non-`static` functions that accept opcodes and return operand indices. These functions would be able to use the TableGen-generated instruction descriptions. However, it would be rather surprising to have two almost identical functions with significantly different implementations and behavior. Alternatively, I could define `isPostLd`/`isPostSt`/`isPostLdSt` counterparts. But then it seems I would also have to support the paired variants of pre- and post-incrementing instructions, which looks like a whole separate PR.
https://github.com/llvm/llvm-project/pull/134146 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits