https://github.com/svkeerthy created https://github.com/llvm/llvm-project/pull/156952
None >From 6185e40a9a6731955e190131067cc3c5bc90595e Mon Sep 17 00:00:00 2001 From: svkeerthy <[email protected]> Date: Wed, 3 Sep 2025 22:56:08 +0000 Subject: [PATCH] Support predicates --- llvm/include/llvm/Analysis/IR2Vec.h | 43 +++++++++++++--- llvm/lib/Analysis/IR2Vec.cpp | 70 ++++++++++++++++++++++---- llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 2 +- 3 files changed, 98 insertions(+), 17 deletions(-) diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h index b7b881999241e..d49854e2d06a8 100644 --- a/llvm/include/llvm/Analysis/IR2Vec.h +++ b/llvm/include/llvm/Analysis/IR2Vec.h @@ -36,6 +36,7 @@ #define LLVM_ANALYSIS_IR2VEC_H #include "llvm/ADT/DenseMap.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" #include "llvm/Support/CommandLine.h" @@ -162,16 +163,25 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>; /// embeddings. class Vocabulary { friend class llvm::IR2VecVocabAnalysis; + // Slot layout: + // [0 .. MaxOpcodes-1] => Instruction opcodes + // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types + // [MaxOpcodes+MaxCanonicalTypeIDs .. end of operands) => Operands + // Within Operands: the OperandKind entries come first, followed by the + // compare predicates (FCmp predicates, then ICmp predicates), stored + // densely without gaps. using VocabVector = std::vector<ir2vec::Embedding>; VocabVector Vocab; + bool Valid = false; + static constexpr unsigned NumICmpPredicates = + static_cast<unsigned>(CmpInst::LAST_ICMP_PREDICATE) - + static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) + 1; + static constexpr unsigned NumFCmpPredicates = + static_cast<unsigned>(CmpInst::LAST_FCMP_PREDICATE) - + static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + 1; public: - // Slot layout: - // [0 .. MaxOpcodes-1] => Instruction opcodes - // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types - // [MaxOpcodes+MaxCanonicalTypeIDs .. 
NumCanonicalEntries-1] => Operand kinds - /// Canonical type IDs supported by IR2Vec Vocabulary enum class CanonicalTypeID : unsigned { FloatTy, @@ -208,13 +218,18 @@ class Vocabulary { static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType); static constexpr unsigned MaxOperandKinds = static_cast<unsigned>(OperandKind::MaxOperandKind); + // CmpInst::Predicate has gaps. We want the vocabulary to be dense without + // empty slots. + static constexpr unsigned MaxPredicateKinds = + NumICmpPredicates + NumFCmpPredicates; Vocabulary() = default; LLVM_ABI Vocabulary(VocabVector &&Vocab); LLVM_ABI bool isValid() const; LLVM_ABI unsigned getDimension() const; - /// Total number of entries (opcodes + canonicalized types + operand kinds) + /// Total number of entries (opcodes + canonicalized types + operand kinds + + /// predicates) static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; } /// Function to get vocabulary key for a given Opcode @@ -229,16 +244,21 @@ class Vocabulary { /// Function to classify an operand into OperandKind LLVM_ABI static OperandKind getOperandKind(const Value *Op); + /// Function to get vocabulary key for a given predicate + LLVM_ABI static StringRef getVocabKeyForPredicate(CmpInst::Predicate P); + /// Functions to return the slot index or position of a given Opcode, TypeID, /// or OperandKind in the vocabulary. LLVM_ABI static unsigned getSlotIndex(unsigned Opcode); LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID); LLVM_ABI static unsigned getSlotIndex(const Value &Op); + LLVM_ABI static unsigned getSlotIndex(CmpInst::Predicate P); /// Accessors to get the embedding for a given entity. 
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const; LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const; LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const; + LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const; /// Const Iterator type aliases using const_iterator = VocabVector::const_iterator; @@ -275,7 +295,13 @@ class Vocabulary { private: constexpr static unsigned NumCanonicalEntries = - MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds; + MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds; + + // Base offsets for slot layout to simplify index computation + constexpr static unsigned OperandBaseOffset = + MaxOpcodes + MaxCanonicalTypeIDs; + constexpr static unsigned PredicateBaseOffset = + OperandBaseOffset + MaxOperandKinds; /// String mappings for CanonicalTypeID values static constexpr StringLiteral CanonicalTypeNames[] = { @@ -327,6 +353,9 @@ class Vocabulary { /// Function to convert TypeID to CanonicalTypeID LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID); + + /// Function to get the predicate enum value for a given index + LLVM_ABI static CmpInst::Predicate getPredicate(unsigned Index); }; /// Embedder provides the interface to generate embeddings (vector diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index 98849fd922843..c79c9c1ed493d 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -216,6 +216,8 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const { ArgEmb += Vocab[*Op]; auto InstVector = Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb; + if (const auto *IC = dyn_cast<CmpInst>(&I)) + InstVector += Vocab[IC->getPredicate()]; InstVecMap[&I] = InstVector; BBVector += InstVector; } @@ -250,6 +252,9 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const { // embeddings auto InstVector = Vocab[I.getOpcode()] + 
Vocab[I.getType()->getTypeID()] + ArgEmb; + // Add compare predicate embedding as an additional operand if applicable + if (const auto *IC = dyn_cast<CmpInst>(&I)) + InstVector += Vocab[IC->getPredicate()]; InstVecMap[&I] = InstVector; BBVector += InstVector; } @@ -285,7 +290,17 @@ unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) { unsigned Vocabulary::getSlotIndex(const Value &Op) { unsigned Index = static_cast<unsigned>(getOperandKind(&Op)); assert(Index < MaxOperandKinds && "Invalid OperandKind"); - return MaxOpcodes + MaxCanonicalTypeIDs + Index; + return OperandBaseOffset + Index; +} + +unsigned Vocabulary::getSlotIndex(CmpInst::Predicate P) { + unsigned PU = static_cast<unsigned>(P); + unsigned FirstFC = static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE); + unsigned FirstIC = static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE); + + unsigned PredIdx = + (PU >= FirstIC) ? (NumFCmpPredicates + (PU - FirstIC)) : (PU - FirstFC); + return PredicateBaseOffset + PredIdx; } const Embedding &Vocabulary::operator[](unsigned Opcode) const { @@ -300,6 +315,10 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const { return Vocab[getSlotIndex(Arg)]; } +const ir2vec::Embedding &Vocabulary::operator[](CmpInst::Predicate P) const { + return Vocab[getSlotIndex(P)]; +} + StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) { assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode"); #define HANDLE_INST(NUM, OPCODE, CLASS) \ @@ -345,18 +364,35 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) { return OperandKind::VariableID; } +CmpInst::Predicate Vocabulary::getPredicate(unsigned Index) { + assert(Index < MaxPredicateKinds && "Invalid predicate index"); + unsigned PredEnumVal = + (Index < NumFCmpPredicates) + ? 
(static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + Index) + : (static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) + + (Index - NumFCmpPredicates)); + return static_cast<CmpInst::Predicate>(PredEnumVal); +} + +StringRef Vocabulary::getVocabKeyForPredicate(CmpInst::Predicate Pred) { + return CmpInst::getPredicateName(Pred); +} + StringRef Vocabulary::getStringKey(unsigned Pos) { assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary"); // Opcode if (Pos < MaxOpcodes) return getVocabKeyForOpcode(Pos + 1); // Type - if (Pos < MaxOpcodes + MaxCanonicalTypeIDs) + if (Pos < OperandBaseOffset) return getVocabKeyForCanonicalTypeID( static_cast<CanonicalTypeID>(Pos - MaxOpcodes)); // Operand - return getVocabKeyForOperandKind( - static_cast<OperandKind>(Pos - MaxOpcodes - MaxCanonicalTypeIDs)); + if (Pos < PredicateBaseOffset) + return getVocabKeyForOperandKind( + static_cast<OperandKind>(Pos - OperandBaseOffset)); + // Predicates + return getVocabKeyForPredicate(getPredicate(Pos - PredicateBaseOffset)); } // For now, assume vocabulary is stable unless explicitly invalidated. @@ -370,11 +406,9 @@ Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) { VocabVector DummyVocab; DummyVocab.reserve(NumCanonicalEntries); float DummyVal = 0.1f; - // Create a dummy vocabulary with entries for all opcodes, types, and - // operands - for ([[maybe_unused]] unsigned _ : - seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs + - Vocabulary::MaxOperandKinds)) { + // Create a dummy vocabulary with entries for all opcodes, types, operands + // and predicates + for ([[maybe_unused]] unsigned _ : seq(0u, Vocabulary::NumCanonicalEntries)) { DummyVocab.push_back(Embedding(Dim, DummyVal)); DummyVal += 0.1f; } @@ -517,6 +551,24 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { } Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(), NumericArgEmbeddings.end()); + + // Handle Predicates: part of Operands section. 
We look up predicate keys + // in ArgVocab. + std::vector<Embedding> NumericPredEmbeddings(Vocabulary::MaxPredicateKinds, + Embedding(Dim, 0)); + NumericPredEmbeddings.reserve(Vocabulary::MaxPredicateKinds); + for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) { + StringRef VocabKey = + Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK)); + auto It = ArgVocab.find(VocabKey.str()); + if (It != ArgVocab.end()) { + NumericPredEmbeddings[PK] = It->second; + continue; + } + handleMissingEntity(VocabKey.str()); + } + Vocab.insert(Vocab.end(), NumericPredEmbeddings.begin(), + NumericPredEmbeddings.end()); } IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab) diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp index aabebf0cc90a9..1c656b8fcf4e7 100644 --- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp +++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp @@ -184,7 +184,7 @@ class IR2VecTool { // Add "Arg" relationships unsigned ArgIndex = 0; for (const Use &U : I.operands()) { - unsigned OperandID = Vocabulary::getSlotIndex(*U); + unsigned OperandID = Vocabulary::getSlotIndex(*U.get()); unsigned RelationID = ArgRelation + ArgIndex; OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n'; _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
