================
@@ -339,6 +369,183 @@ class AArch64MCPlusBuilder : public MCPlusBuilder {
     }
   }
 
+  std::optional<std::pair<MCPhysReg, MCInst *>>
+  getAuthCheckedReg(BinaryBasicBlock &BB) const override {
+    // Match several possible hard-coded sequences of instructions which can be
+    // emitted by LLVM backend to check that the authenticated pointer is
+    // correct (see AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue).
+    // Returns {checked register, first instruction of match}, or std::nullopt.
+    // This function only matches sequences involving branch instructions.
+    // All these sequences have the form:
+    //
+    // (0) ... regular code that authenticates a pointer in Xn ...
+    // (1) analyze Xn
+    // (2) branch to .Lon_success if the pointer is correct
+    // (3) BRK #imm (fall-through basic block)
+    //
+    // In the above pseudocode, (1) + (2) is one of the following sequences:
+    //
+    // - eor Xtmp, Xn, Xn, lsl #1
+    //   tbz Xtmp, #62, .Lon_success
+    //
+    // - mov Xtmp, Xn
+    //   xpac(i|d) Xtmp (or xpaclri if Xn is LR)
+    //   cmp Xn, Xtmp
+    //   b.eq .Lon_success
+    //
+    // Note that any branch destination operand is accepted as .Lon_success -
+    // it is the responsibility of the caller of getAuthCheckedReg to inspect
+    // the list of successors of this basic block as appropriate.
+
+    // Any of the above code sequences assume the fall-through basic block
+    // is a dead-end BRK instruction (any immediate operand is accepted).
+    const BinaryBasicBlock *BreakBB = BB.getFallthrough();
+    if (!BreakBB || BreakBB->empty() ||
+        BreakBB->front().getOpcode() != AArch64::BRK)
+      return std::nullopt;
+
+    // Iterate over the instructions of BB in reverse order, matching opcodes
+    // and operands.
+    MCPhysReg TestedReg = 0;
+    MCPhysReg ScratchReg = 0;
+    auto It = BB.end();
+    auto StepAndGetOpcode = [&It, &BB]() -> int {
+      if (It == BB.begin())
+        return -1; // No instructions left to step back to.
+      --It;
+      return It->getOpcode();
+    };
+
+    switch (StepAndGetOpcode()) {
+    default:
+      // Not matched the branch instruction.
+      return std::nullopt;
+    case AArch64::Bcc:
+      // Bcc EQ, .Lon_success
+      if (It->getOperand(0).getImm() != AArch64CC::EQ)
+        return std::nullopt;
+      // Not checking .Lon_success (see above).
+
+      // SUBSXrs XZR, TestedReg, ScratchReg, 0 (used by "CMP reg, reg" alias)
+      if (StepAndGetOpcode() != AArch64::SUBSXrs ||
+          It->getOperand(0).getReg() != AArch64::XZR ||
+          It->getOperand(3).getImm() != 0)
+        return std::nullopt;
+      TestedReg = It->getOperand(1).getReg();
+      ScratchReg = It->getOperand(2).getReg();
+
+      // Either XPAC(I|D) ScratchReg, ScratchReg
+      // or     XPACLRI
+      switch (StepAndGetOpcode()) {
+      default:
+        return std::nullopt;
+      case AArch64::XPACLRI:
+        // No operands to check, but using XPACLRI forces TestedReg to be X30.
+        if (TestedReg != AArch64::LR)
+          return std::nullopt;
+        break;
+      case AArch64::XPACI:
+      case AArch64::XPACD:
+        if (It->getOperand(0).getReg() != ScratchReg ||
+            It->getOperand(1).getReg() != ScratchReg)
+          return std::nullopt;
+        break;
+      }
+
+      // ORRXrs ScratchReg, XZR, TestedReg, 0 (used by "MOV reg, reg" alias)
+      if (StepAndGetOpcode() != AArch64::ORRXrs)
+        return std::nullopt;
+      if (It->getOperand(0).getReg() != ScratchReg ||
+          It->getOperand(1).getReg() != AArch64::XZR ||
+          It->getOperand(2).getReg() != TestedReg ||
+          It->getOperand(3).getImm() != 0)
+        return std::nullopt;
+
+      return std::make_pair(TestedReg, &*It); // It points at the ORRXrs.
+
+    case AArch64::TBZX:
+      // TBZX ScratchReg, 62, .Lon_success
+      ScratchReg = It->getOperand(0).getReg();
+      if (It->getOperand(1).getImm() != 62)
+        return std::nullopt;
+      // Not checking .Lon_success (see above).
+
+      // EORXrs ScratchReg, TestedReg, TestedReg, 1
+      if (StepAndGetOpcode() != AArch64::EORXrs)
+        return std::nullopt;
+      TestedReg = It->getOperand(1).getReg();
+      if (It->getOperand(0).getReg() != ScratchReg ||
+          It->getOperand(2).getReg() != TestedReg ||
+          It->getOperand(3).getImm() != 1)
+        return std::nullopt;
+
+      return std::make_pair(TestedReg, &*It); // It points at the EORXrs.
+    }
+  }
+
+  MCPhysReg getAuthCheckedReg(const MCInst &Inst,
+                              bool MayOverwrite) const override {
+    // Cannot trivially reuse AArch64InstrInfo::getMemOperandWithOffsetWidth()
+    // method as it accepts an instance of MachineInstr, not MCInst.
+    const MCInstrDesc &Desc = Info->get(Inst.getOpcode());
+
+    // If signing oracles are considered, the particular value left in the base
+    // register after this instruction is important. This function checks that
+    // if the base register was overwritten, it is due to address write-back.
+    //
+    // Note that this function is not needed for authentication oracles, as the
+    // particular value left in the register after a successful memory access
+    // is not important.
+    auto ClobbersBaseRegExceptWriteback = [&](unsigned BaseRegUseIndex) {
+      MCPhysReg BaseReg = Inst.getOperand(BaseRegUseIndex).getReg();
+      unsigned WrittenBackDefIndex = Desc.getOperandConstraint(
+          BaseRegUseIndex, MCOI::OperandConstraint::TIED_TO);
----------------
atrosinenko wrote:

Now I probably got the idea, but it looks like adding these helpers to 
`AArch64InstrInfo` the right way would require some amount of refactoring.

My understanding is that it would be great to add something like
```cpp
class AArch64InstrInfo {
  // ...
  unsigned getLdStBaseOpIndex(unsigned Opc);
  unsigned getLdStWritebackOpIndex(unsigned Opc);
  // ...
};
```
There are already [`static` member 
functions](https://github.com/llvm/llvm-project/blob/96519028d514853d429c2d09482ba0bd9a899c57/llvm/lib/Target/AArch64/AArch64InstrInfo.h#L249)
 like these:
```cpp
  /// Returns the base register operator of a load/store.
  static const MachineOperand &getLdStBaseOp(const MachineInstr &MI);
```
This looks like one of the two functions we need, with the exception that we 
don't have a `MachineInstr`. For this PR, it would have to be updated similarly 
to the `getMemScale` functions above:
```cpp
  /// Scaling factor for (scaled or unscaled) load or store.
  static int getMemScale(unsigned Opc);
  static int getMemScale(const MachineInstr &MI) {
    return getMemScale(MI.getOpcode());
  }
```

But according to the 
[definition](https://github.com/llvm/llvm-project/blob/cc6def4b7521676fd339936d027e48928e0ba398/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp#L4652)
 of `getLdStBaseOp`, it seems that it only handles a subset of possible cases. 
I could add non-`static` functions accepting opcodes and returning indices - 
these functions would be able to use TableGen-generated instruction descriptions, 
but it would be rather surprising to have almost identical functions with 
significantly different implementations and behavior. Or I could define 
`isPostLd`/`isPostSt`/`isPostLdSt` counterparts, but it seems that I would then 
have to support paired variants of pre- and post-incrementing instructions - 
this looks like a whole separate PR.

https://github.com/llvm/llvm-project/pull/134146
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to