https://github.com/atrosinenko created https://github.com/llvm/llvm-project/pull/135661
* use more flexible `const ArrayRef<T>` and `StringRef` types instead of `const std::vector<T> &` and `const std::string &`, respectively, for function arguments * return plain `const SrcState &` instead of `ErrorOr<const SrcState &>` from `SrcSafetyAnalysis::getStateBefore`, as absent state is not handled gracefully by any caller >From 51373db0c000ad32a91eb4097ccc4404a6e54d25 Mon Sep 17 00:00:00 2001 From: Anatoly Trosinenko <atrosine...@accesssoftek.com> Date: Mon, 14 Apr 2025 14:35:56 +0300 Subject: [PATCH] [BOLT] Gadget scanner: use more appropriate types (NFC) * use more flexible `const ArrayRef<T>` and `StringRef` types instead of `const std::vector<T> &` and `const std::string &`, respectively, for function arguments * return plain `const SrcState &` instead of `ErrorOr<const SrcState &>` from `SrcSafetyAnalysis::getStateBefore`, as absent state is not handled gracefully by any caller --- bolt/include/bolt/Passes/PAuthGadgetScanner.h | 8 +--- bolt/lib/Passes/PAuthGadgetScanner.cpp | 39 ++++++++----------- 2 files changed, 19 insertions(+), 28 deletions(-) diff --git a/bolt/include/bolt/Passes/PAuthGadgetScanner.h b/bolt/include/bolt/Passes/PAuthGadgetScanner.h index 6765e2aff414f..3e39b64e59e0f 100644 --- a/bolt/include/bolt/Passes/PAuthGadgetScanner.h +++ b/bolt/include/bolt/Passes/PAuthGadgetScanner.h @@ -12,7 +12,6 @@ #include "bolt/Core/BinaryContext.h" #include "bolt/Core/BinaryFunction.h" #include "bolt/Passes/BinaryPasses.h" -#include "llvm/ADT/SmallSet.h" #include "llvm/Support/raw_ostream.h" #include <memory> @@ -199,9 +198,6 @@ raw_ostream &operator<<(raw_ostream &OS, const MCInstReference &); namespace PAuthGadgetScanner { -class SrcSafetyAnalysis; -struct SrcState; - /// Description of a gadget kind that can be detected. Intended to be /// statically allocated to be attached to reports by reference.
class GadgetKind { @@ -210,7 +206,7 @@ class GadgetKind { public: GadgetKind(const char *Description) : Description(Description) {} - const StringRef getDescription() const { return Description; } + StringRef getDescription() const { return Description; } }; /// Base report located at some instruction, without any additional information. @@ -261,7 +257,7 @@ struct GadgetReport : public Report { /// Report with a free-form message attached. struct GenericReport : public Report { std::string Text; - GenericReport(MCInstReference Location, const std::string &Text) + GenericReport(MCInstReference Location, StringRef Text) : Report(Location), Text(Text) {} virtual void generateReport(raw_ostream &OS, const BinaryContext &BC) const override; diff --git a/bolt/lib/Passes/PAuthGadgetScanner.cpp b/bolt/lib/Passes/PAuthGadgetScanner.cpp index ad47bdff753c8..ed89471cbb8d3 100644 --- a/bolt/lib/Passes/PAuthGadgetScanner.cpp +++ b/bolt/lib/Passes/PAuthGadgetScanner.cpp @@ -91,14 +91,14 @@ class TrackedRegisters { const std::vector<MCPhysReg> Registers; std::vector<uint16_t> RegToIndexMapping; - static size_t getMappingSize(const std::vector<MCPhysReg> &RegsToTrack) { + static size_t getMappingSize(const ArrayRef<MCPhysReg> RegsToTrack) { if (RegsToTrack.empty()) return 0; return 1 + *llvm::max_element(RegsToTrack); } public: - TrackedRegisters(const std::vector<MCPhysReg> &RegsToTrack) + TrackedRegisters(const ArrayRef<MCPhysReg> RegsToTrack) : Registers(RegsToTrack), RegToIndexMapping(getMappingSize(RegsToTrack), NoIndex) { for (unsigned I = 0; I < RegsToTrack.size(); ++I) @@ -234,7 +234,7 @@ struct SrcState { static void printLastInsts( raw_ostream &OS, - const std::vector<SmallPtrSet<const MCInst *, 4>> &LastInstWritingReg) { + const ArrayRef<SmallPtrSet<const MCInst *, 4>> LastInstWritingReg) { OS << "Insts: "; for (unsigned I = 0; I < LastInstWritingReg.size(); ++I) { auto &Set = LastInstWritingReg[I]; @@ -295,7 +295,7 @@ void SrcStatePrinter::print(raw_ostream &OS, const 
SrcState &S) const { class SrcSafetyAnalysis { public: SrcSafetyAnalysis(BinaryFunction &BF, - const std::vector<MCPhysReg> &RegsToTrackInstsFor) + const ArrayRef<MCPhysReg> RegsToTrackInstsFor) : BC(BF.getBinaryContext()), NumRegs(BC.MRI->getNumRegs()), RegsToTrackInstsFor(RegsToTrackInstsFor) {} @@ -303,11 +303,10 @@ class SrcSafetyAnalysis { static std::shared_ptr<SrcSafetyAnalysis> create(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId, - const std::vector<MCPhysReg> &RegsToTrackInstsFor); + const ArrayRef<MCPhysReg> RegsToTrackInstsFor); virtual void run() = 0; - virtual ErrorOr<const SrcState &> - getStateBefore(const MCInst &Inst) const = 0; + virtual const SrcState &getStateBefore(const MCInst &Inst) const = 0; protected: BinaryContext &BC; @@ -348,7 +347,7 @@ class SrcSafetyAnalysis { } BitVector getClobberedRegs(const MCInst &Point) const { - BitVector Clobbered(NumRegs, false); + BitVector Clobbered(NumRegs); // Assume a call can clobber all registers, including callee-saved // registers. There's a good chance that callee-saved registers will be // saved on the stack at some point during execution of the callee. @@ -409,8 +408,7 @@ class SrcSafetyAnalysis { // FirstCheckerInst should belong to the same basic block, meaning // it was deterministically processed a few steps before this instruction. - const SrcState &StateBeforeChecker = - getStateBefore(*FirstCheckerInst).get(); + const SrcState &StateBeforeChecker = getStateBefore(*FirstCheckerInst); if (StateBeforeChecker.SafeToDerefRegs[CheckedReg]) Regs.push_back(CheckedReg); } @@ -523,10 +521,7 @@ class SrcSafetyAnalysis { const ArrayRef<MCPhysReg> UsedDirtyRegs) const { if (RegsToTrackInstsFor.empty()) return {}; - auto MaybeState = getStateBefore(Inst); - if (!MaybeState) - llvm_unreachable("Expected state to be present"); - const SrcState &S = *MaybeState; + const SrcState &S = getStateBefore(Inst); // Due to aliasing registers, multiple registers may have been tracked. 
std::set<const MCInst *> LastWritingInsts; for (MCPhysReg TrackedReg : UsedDirtyRegs) { @@ -537,7 +532,7 @@ class SrcSafetyAnalysis { for (const MCInst *Inst : LastWritingInsts) { MCInstReference Ref = MCInstReference::get(Inst, BF); assert(Ref && "Expected Inst to be found"); - Result.push_back(MCInstReference(Ref)); + Result.push_back(Ref); } return Result; } @@ -557,11 +552,11 @@ class DataflowSrcSafetyAnalysis public: DataflowSrcSafetyAnalysis(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId, - const std::vector<MCPhysReg> &RegsToTrackInstsFor) + const ArrayRef<MCPhysReg> RegsToTrackInstsFor) : SrcSafetyAnalysis(BF, RegsToTrackInstsFor), DFParent(BF, AllocId) {} - ErrorOr<const SrcState &> getStateBefore(const MCInst &Inst) const override { - return DFParent::getStateBefore(Inst); + const SrcState &getStateBefore(const MCInst &Inst) const override { + return DFParent::getStateBefore(Inst).get(); } void run() override { @@ -670,7 +665,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis { public: CFGUnawareSrcSafetyAnalysis(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId, - const std::vector<MCPhysReg> &RegsToTrackInstsFor) + const ArrayRef<MCPhysReg> RegsToTrackInstsFor) : SrcSafetyAnalysis(BF, RegsToTrackInstsFor), BF(BF), AllocId(AllocId) { StateAnnotationIndex = BC.MIB->getOrCreateAnnotationIndex("CFGUnawareSrcSafetyAnalysis"); @@ -704,7 +699,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis { } } - ErrorOr<const SrcState &> getStateBefore(const MCInst &Inst) const override { + const SrcState &getStateBefore(const MCInst &Inst) const override { return BC.MIB->getAnnotationAs<SrcState>(Inst, StateAnnotationIndex); } @@ -714,7 +709,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis { std::shared_ptr<SrcSafetyAnalysis> SrcSafetyAnalysis::create(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId, - const std::vector<MCPhysReg> &RegsToTrackInstsFor) { + const ArrayRef<MCPhysReg> 
RegsToTrackInstsFor) { if (BF.hasCFG()) return std::make_shared<DataflowSrcSafetyAnalysis>(BF, AllocId, RegsToTrackInstsFor); @@ -821,7 +816,7 @@ Analysis::findGadgets(BinaryFunction &BF, BinaryContext &BC = BF.getBinaryContext(); iterateOverInstrs(BF, [&](MCInstReference Inst) { - const SrcState &S = *Analysis->getStateBefore(Inst); + const SrcState &S = Analysis->getStateBefore(Inst); // If non-empty state was never propagated from the entry basic block // to Inst, assume it to be unreachable and report a warning. _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits