https://github.com/atrosinenko updated https://github.com/llvm/llvm-project/pull/138883
>From 1c135a144d7f21e05c3598a992baa170cdde7950 Mon Sep 17 00:00:00 2001 From: Anatoly Trosinenko <atrosine...@accesssoftek.com> Date: Wed, 7 May 2025 16:42:00 +0300 Subject: [PATCH] [BOLT] Introduce helpers to match `MCInst`s one at a time (NFC) Introduce matchInst helper function to capture and/or match the operands of MCInst. Unlike the existing `MCPlusBuilder::MCInstMatcher` machinery, matchInst is intended for the use cases when precise control over the instruction order is required. For example, when validating PtrAuth hardening, all registers are usually considered unsafe after a function call, even though callee-saved registers should preserve their old values *under normal operation*. --- bolt/include/bolt/Core/MCInstUtils.h | 128 ++++++++++++++++++ .../Target/AArch64/AArch64MCPlusBuilder.cpp | 90 +++++------- 2 files changed, 162 insertions(+), 56 deletions(-) diff --git a/bolt/include/bolt/Core/MCInstUtils.h b/bolt/include/bolt/Core/MCInstUtils.h index a3912a8fb265a..b495eb8ef5eec 100644 --- a/bolt/include/bolt/Core/MCInstUtils.h +++ b/bolt/include/bolt/Core/MCInstUtils.h @@ -166,6 +166,134 @@ static inline raw_ostream &operator<<(raw_ostream &OS, return Ref.print(OS); } +/// Instruction-matching helpers operating on a single instruction at a time. 
+/// +/// Unlike MCPlusBuilder::MCInstMatcher, this matchInst() function focuses on +/// the cases where precise control over the instruction order is important: +/// +/// // Bring the short names into the local scope: +/// using namespace MCInstMatcher; +/// // Declare the registers to capture: +/// Reg Xn, Xm; +/// // Capture the 0th and 1st operands, match the 2nd operand against the +/// // just captured Xm register, match the 3rd operand against literal 0: +/// if (!matchInst(MaybeAdd, AArch64::ADDXrs, Xm, Xn, Xm, Imm(0))) +/// return AArch64::NoRegister; +/// // Match the 0th operand against Xm: +/// if (!matchInst(MaybeBr, AArch64::BR, Xm)) +/// return AArch64::NoRegister; +/// // Return the matched register: +/// return Xm.get(); +namespace MCInstMatcher { + +// The base class to match an operand of type T. +// +// The subclasses of OpMatcher are intended to be allocated on the stack and +// to only be used by passing them to matchInst() and by calling their get() +// function, thus the peculiar `mutable` specifiers: to make the calling code +// compact and readable, the templated matchInst() function has to accept both +// long-lived Imm/Reg wrappers declared as local variables (intended to capture +// the first operand's value and match the subsequent operands, whether inside +// a single instruction or across multiple instructions), as well as temporary +// wrappers around literal values to match, e.g. Imm(42) or Reg(AArch64::XZR). +template <typename T> class OpMatcher { + mutable std::optional<T> Value; + mutable std::optional<T> SavedValue; + + // Remember/restore the last Value - to be called by matchInst. + void remember() const { SavedValue = Value; } + void restore() const { Value = SavedValue; } + + template <class... 
OpMatchers> + friend bool matchInst(const MCInst &, unsigned, const OpMatchers &...); + +protected: + OpMatcher(std::optional<T> ValueToMatch) : Value(ValueToMatch) {} + + bool matchValue(T OpValue) const { + // Check that OpValue does not contradict the existing Value. + bool MatchResult = !Value || *Value == OpValue; + // If MatchResult is false, all matchers will be restored before returning from + // matchInst, including this one, thus no need to assign conditionally. + Value = OpValue; + + return MatchResult; + } + +public: + /// Returns the captured value. + T get() const { + assert(Value.has_value()); + return *Value; + } +}; + +class Reg : public OpMatcher<MCPhysReg> { + bool matches(const MCOperand &Op) const { + if (!Op.isReg()) + return false; + + return matchValue(Op.getReg()); + } + + template <class... OpMatchers> + friend bool matchInst(const MCInst &, unsigned, const OpMatchers &...); + +public: + Reg(std::optional<MCPhysReg> RegToMatch = std::nullopt) + : OpMatcher<MCPhysReg>(RegToMatch) {} +}; + +class Imm : public OpMatcher<int64_t> { + bool matches(const MCOperand &Op) const { + if (!Op.isImm()) + return false; + + return matchValue(Op.getImm()); + } + + template <class... OpMatchers> + friend bool matchInst(const MCInst &, unsigned, const OpMatchers &...); + +public: + Imm(std::optional<int64_t> ImmToMatch = std::nullopt) + : OpMatcher<int64_t>(ImmToMatch) {} +}; + +/// Tries to match Inst and updates Ops on success. +/// +/// If Inst has the specified Opcode and its operand list prefix matches Ops, +/// this function returns true and updates Ops, otherwise false is returned and +/// values of Ops are kept as before matchInst was called. +/// +/// Please note that while Ops are technically passed by a const reference to +/// make invocations like `matchInst(MI, Opcode, Imm(42))` possible, all their +/// fields are marked mutable. +template <class... 
OpMatchers> +bool matchInst(const MCInst &Inst, unsigned Opcode, const OpMatchers &...Ops) { + if (Inst.getOpcode() != Opcode) + return false; + assert(sizeof...(Ops) <= Inst.getNumOperands() && + "Too many operands are matched for the Opcode"); + + // Ask each matcher to remember its current value in case of rollback. + (Ops.remember(), ...); + + // Check operands left to right; && short-circuits on the first mismatch. + auto It = Inst.begin(); + auto AllMatched = (Ops.matches(*(It++)) && ... && true); + + // If match failed, restore the original captured values. + if (!AllMatched) { + (Ops.restore(), ...); + return false; + } + + return true; +} + +} // namespace MCInstMatcher + } // namespace bolt } // namespace llvm diff --git a/bolt/lib/Target/AArch64/AArch64MCPlusBuilder.cpp b/bolt/lib/Target/AArch64/AArch64MCPlusBuilder.cpp index 4d11c5b206eab..2522de7005c64 100644 --- a/bolt/lib/Target/AArch64/AArch64MCPlusBuilder.cpp +++ b/bolt/lib/Target/AArch64/AArch64MCPlusBuilder.cpp @@ -19,6 +19,7 @@ #include "Utils/AArch64BaseInfo.h" #include "bolt/Core/BinaryBasicBlock.h" #include "bolt/Core/BinaryFunction.h" +#include "bolt/Core/MCInstUtils.h" #include "bolt/Core/MCPlusBuilder.h" #include "llvm/BinaryFormat/ELF.h" #include "llvm/MC/MCContext.h" @@ -393,81 +394,58 @@ class AArch64MCPlusBuilder : public MCPlusBuilder { // Iterate over the instructions of BB in reverse order, matching opcodes // and operands. - MCPhysReg TestedReg = 0; - MCPhysReg ScratchReg = 0; + auto It = BB.end(); - auto StepAndGetOpcode = [&It, &BB]() -> int { - if (It == BB.begin()) - return -1; - --It; - return It->getOpcode(); + auto StepBack = [&]() { + while (It != BB.begin()) { + --It; + if (!isCFI(*It)) + return true; + } + return false; }; - - switch (StepAndGetOpcode()) { - default: - // Not matched the branch instruction. + // Step to the last non-CFI instruction. NOTE(review): the removed StepAndGetOpcode did not skip CFI, so "(NFC)" may be inaccurate - confirm. 
+ if (!StepBack()) return std::nullopt; - case AArch64::Bcc: - // Bcc EQ, .Lon_success - if (It->getOperand(0).getImm() != AArch64CC::EQ) - return std::nullopt; - // Not checking .Lon_success (see above). - // SUBSXrs XZR, TestedReg, ScratchReg, 0 (used by "CMP reg, reg" alias) - if (StepAndGetOpcode() != AArch64::SUBSXrs || - It->getOperand(0).getReg() != AArch64::XZR || - It->getOperand(3).getImm() != 0) + using namespace llvm::bolt::MCInstMatcher; + Reg TestedReg; + Reg ScratchReg; + + if (matchInst(*It, AArch64::Bcc, Imm(AArch64CC::EQ) /*, .Lon_success*/)) { + if (!StepBack() || !matchInst(*It, AArch64::SUBSXrs, Reg(AArch64::XZR), + TestedReg, ScratchReg, Imm(0))) return std::nullopt; - TestedReg = It->getOperand(1).getReg(); - ScratchReg = It->getOperand(2).getReg(); // Either XPAC(I|D) ScratchReg, ScratchReg // or XPACLRI - switch (StepAndGetOpcode()) { - default: + if (!StepBack()) return std::nullopt; - case AArch64::XPACLRI: + if (matchInst(*It, AArch64::XPACLRI)) { // No operands to check, but using XPACLRI forces TestedReg to be X30. 
- if (TestedReg != AArch64::LR) - return std::nullopt; - break; - case AArch64::XPACI: - case AArch64::XPACD: - if (It->getOperand(0).getReg() != ScratchReg || - It->getOperand(1).getReg() != ScratchReg) + if (TestedReg.get() != AArch64::LR) return std::nullopt; - break; + } else if (!matchInst(*It, AArch64::XPACI, ScratchReg, ScratchReg) && + !matchInst(*It, AArch64::XPACD, ScratchReg, ScratchReg)) { + return std::nullopt; } - // ORRXrs ScratchReg, XZR, TestedReg, 0 (used by "MOV reg, reg" alias) - if (StepAndGetOpcode() != AArch64::ORRXrs) + if (!StepBack() || !matchInst(*It, AArch64::ORRXrs, ScratchReg, + Reg(AArch64::XZR), TestedReg, Imm(0))) return std::nullopt; - if (It->getOperand(0).getReg() != ScratchReg || - It->getOperand(1).getReg() != AArch64::XZR || - It->getOperand(2).getReg() != TestedReg || - It->getOperand(3).getImm() != 0) - return std::nullopt; - - return std::make_pair(TestedReg, &*It); - case AArch64::TBZX: - // TBZX ScratchReg, 62, .Lon_success - ScratchReg = It->getOperand(0).getReg(); - if (It->getOperand(1).getImm() != 62) - return std::nullopt; - // Not checking .Lon_success (see above). 
+ return std::make_pair(TestedReg.get(), &*It); + } - // EORXrs ScratchReg, TestedReg, TestedReg, 1 - if (StepAndGetOpcode() != AArch64::EORXrs) - return std::nullopt; - TestedReg = It->getOperand(1).getReg(); - if (It->getOperand(0).getReg() != ScratchReg || - It->getOperand(2).getReg() != TestedReg || - It->getOperand(3).getImm() != 1) + if (matchInst(*It, AArch64::TBZX, ScratchReg, Imm(62) /*, .Lon_success*/)) { + if (!StepBack() || !matchInst(*It, AArch64::EORXrs, Reg(ScratchReg), + TestedReg, TestedReg, Imm(1))) return std::nullopt; - return std::make_pair(TestedReg, &*It); + return std::make_pair(TestedReg.get(), &*It); } + + return std::nullopt; } std::optional<MCPhysReg> getAuthCheckedReg(const MCInst &Inst, _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits