https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/155323
>From 267e2daa5fb42cf72c21ef7f7c3fdd02db7cfc5f Mon Sep 17 00:00:00 2001 From: svkeerthy <venkatakeer...@google.com> Date: Mon, 25 Aug 2025 22:58:43 +0000 Subject: [PATCH] Canonicalized type --- llvm/include/llvm/Analysis/IR2Vec.h | 135 +++++++++-- llvm/lib/Analysis/IR2Vec.cpp | 164 ++++++-------- .../Inputs/reference_default_vocab_print.txt | 11 +- .../Inputs/reference_wtd1_vocab_print.txt | 11 +- .../Inputs/reference_wtd2_vocab_print.txt | 11 +- llvm/test/tools/llvm-ir2vec/entities.ll | 41 ++-- llvm/test/tools/llvm-ir2vec/triplets.ll | 58 ++--- llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 14 +- llvm/unittests/Analysis/IR2VecTest.cpp | 213 ++++++++++-------- 9 files changed, 350 insertions(+), 308 deletions(-) diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h index 7ace83ba1d053..26a3d303e7ab0 100644 --- a/llvm/include/llvm/Analysis/IR2Vec.h +++ b/llvm/include/llvm/Analysis/IR2Vec.h @@ -36,6 +36,7 @@ #include "llvm/Support/Compiler.h" #include "llvm/Support/ErrorOr.h" #include "llvm/Support/JSON.h" +#include <array> #include <map> namespace llvm { @@ -137,13 +138,51 @@ using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>; using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>; /// Class for storing and accessing the IR2Vec vocabulary. -/// Encapsulates all vocabulary-related constants, logic, and access methods. +/// The Vocabulary class manages seed embeddings for LLVM IR entities. The +/// seed embeddings are the initial learned representations of the entities +/// of LLVM IR. The IR2Vec representation for a given IR is derived from these +/// seed embeddings. +/// +/// The vocabulary contains the seed embeddings for three types of entities: +/// instruction opcodes, types, and operands. Types are grouped/canonicalized +/// for better learning (e.g., all float variants map to FloatTy). The +/// vocabulary abstracts away the canonicalization effectively, the exposed APIs +/// handle all the known LLVM IR opcodes, types and operands. +/// +/// This class helps populate the seed embeddings in an internal vector-based +/// ADT. It provides logic to map every IR entity to a specific slot index or +/// position in this vector, enabling O(1) embedding lookup while avoiding +/// unnecessary computations involving string based lookups while generating the +/// embeddings. class Vocabulary { friend class llvm::IR2VecVocabAnalysis; using VocabVector = std::vector<ir2vec::Embedding>; VocabVector Vocab; bool Valid = false; +public: + // Slot layout: + // [0 .. MaxOpcodes-1] => Instruction opcodes + // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types + // [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds + + /// Canonical type IDs supported by IR2Vec Vocabulary + enum class CanonicalTypeID : unsigned { + FloatTy, + VoidTy, + LabelTy, + MetadataTy, + VectorTy, + TokenTy, + IntegerTy, + FunctionTy, + PointerTy, + StructTy, + ArrayTy, + UnknownTy, + MaxCanonicalType + }; + /// Operand kinds supported by IR2Vec Vocabulary enum class OperandKind : unsigned { FunctionID, @@ -152,20 +191,15 @@ class Vocabulary { VariableID, MaxOperandKind }; - /// String mappings for OperandKind values - static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer", - "Constant", "Variable"}; - static_assert(std::size(OperandKindNames) == - static_cast<unsigned>(OperandKind::MaxOperandKind), - "OperandKindNames array size must match MaxOperandKind"); -public: /// Vocabulary layout constants #define LAST_OTHER_INST(NUM) static constexpr unsigned MaxOpcodes = NUM; #include "llvm/IR/Instruction.def" #undef LAST_OTHER_INST static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1; + static constexpr unsigned MaxCanonicalTypeIDs = + static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType); static constexpr unsigned MaxOperandKinds = static_cast<unsigned>(OperandKind::MaxOperandKind); @@ -174,33 +208,31 @@ class Vocabulary { LLVM_ABI bool isValid() const; LLVM_ABI unsigned getDimension() const; - LLVM_ABI size_t size() const; + /// Total number of entries (opcodes + canonicalized types + operand kinds) + static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; } - static size_t expectedSize() { - return MaxOpcodes + MaxTypeIDs + MaxOperandKinds; - } - - /// Helper function to get vocabulary key for a given Opcode + /// Function to get vocabulary key for a given Opcode LLVM_ABI static StringRef getVocabKeyForOpcode(unsigned Opcode); - /// Helper function to get vocabulary key for a given TypeID + /// Function to get vocabulary key for a given TypeID LLVM_ABI static StringRef getVocabKeyForTypeID(Type::TypeID TypeID); - /// Helper function to get vocabulary key for a given OperandKind + /// Function to get vocabulary key for a given OperandKind LLVM_ABI static StringRef getVocabKeyForOperandKind(OperandKind Kind); - /// Helper function to classify an operand into OperandKind + /// Function to classify an operand into OperandKind LLVM_ABI static OperandKind getOperandKind(const Value *Op); - /// Helpers to return the IDs of a given Opcode, TypeID, or OperandKind - LLVM_ABI static unsigned getNumericID(unsigned Opcode); - LLVM_ABI static unsigned getNumericID(Type::TypeID TypeID); - LLVM_ABI static unsigned getNumericID(const Value *Op); + /// Functions to return the slot index or position of a given Opcode, TypeID, + /// or OperandKind in the vocabulary. + LLVM_ABI static unsigned getSlotIndex(unsigned Opcode); + LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID); + LLVM_ABI static unsigned getSlotIndex(const Value *Op); /// Accessors to get the embedding for a given entity. LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const; LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const; - LLVM_ABI const ir2vec::Embedding &operator[](const Value *Arg) const; + LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const; /// Const Iterator type aliases using const_iterator = VocabVector::const_iterator; @@ -234,6 +266,61 @@ class Vocabulary { LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA, ModuleAnalysisManager::Invalidator &Inv) const; + +private: + constexpr static unsigned NumCanonicalEntries = + MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds; + + /// String mappings for CanonicalTypeID values + static constexpr StringLiteral CanonicalTypeNames[] = { + "FloatTy", "VoidTy", "LabelTy", "MetadataTy", + "VectorTy", "TokenTy", "IntegerTy", "FunctionTy", + "PointerTy", "StructTy", "ArrayTy", "UnknownTy"}; + static_assert(std::size(CanonicalTypeNames) == + static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType), + "CanonicalTypeNames array size must match MaxCanonicalType"); + + /// String mappings for OperandKind values + static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer", + "Constant", "Variable"}; + static_assert(std::size(OperandKindNames) == + static_cast<unsigned>(OperandKind::MaxOperandKind), + "OperandKindNames array size must match MaxOperandKind"); + + /// Every known TypeID defined in llvm/IR/Type.h is expected to have a + /// corresponding mapping here in the same order as enum Type::TypeID. + static constexpr std::array<CanonicalTypeID, MaxTypeIDs> TypeIDMapping = {{ + CanonicalTypeID::FloatTy, // HalfTyID = 0 + CanonicalTypeID::FloatTy, // BFloatTyID + CanonicalTypeID::FloatTy, // FloatTyID + CanonicalTypeID::FloatTy, // DoubleTyID + CanonicalTypeID::FloatTy, // X86_FP80TyID + CanonicalTypeID::FloatTy, // FP128TyID + CanonicalTypeID::FloatTy, // PPC_FP128TyID + CanonicalTypeID::VoidTy, // VoidTyID + CanonicalTypeID::LabelTy, // LabelTyID + CanonicalTypeID::MetadataTy, // MetadataTyID + CanonicalTypeID::VectorTy, // X86_AMXTyID + CanonicalTypeID::TokenTy, // TokenTyID + CanonicalTypeID::IntegerTy, // IntegerTyID + CanonicalTypeID::FunctionTy, // FunctionTyID + CanonicalTypeID::PointerTy, // PointerTyID + CanonicalTypeID::StructTy, // StructTyID + CanonicalTypeID::ArrayTy, // ArrayTyID + CanonicalTypeID::VectorTy, // FixedVectorTyID + CanonicalTypeID::VectorTy, // ScalableVectorTyID + CanonicalTypeID::PointerTy, // TypedPointerTyID + CanonicalTypeID::UnknownTy // TargetExtTyID + }}; + static_assert(TypeIDMapping.size() == MaxTypeIDs, + "TypeIDMapping must cover all Type::TypeID values"); + + /// Function to get vocabulary key for canonical type by enum + LLVM_ABI static StringRef + getVocabKeyForCanonicalTypeID(CanonicalTypeID CType); + + /// Function to convert TypeID to CanonicalTypeID + LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID); }; /// Embedder provides the interface to generate embeddings (vector @@ -262,11 +349,11 @@ class Embedder { LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab); - /// Helper function to compute embeddings. It generates embeddings for all + /// Function to compute embeddings. It generates embeddings for all /// the instructions and basic blocks in the function F. void computeEmbeddings() const; - /// Helper function to compute the embedding for a given basic block. + /// Function to compute the embedding for a given basic block. /// Specific to the kind of embeddings being computed. virtual void computeEmbeddings(const BasicBlock &BB) const = 0; diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index e28938b64bfdb..c578c5d606bdc 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -32,7 +32,7 @@ using namespace ir2vec; #define DEBUG_TYPE "ir2vec" STATISTIC(VocabMissCounter, - "Number of lookups to entites not present in the vocabulary"); + "Number of lookups to entities not present in the vocabulary"); namespace llvm { namespace ir2vec { @@ -213,7 +213,7 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const { for (const auto &I : BB.instructionsWithoutDebug()) { Embedding ArgEmb(Dimension, 0); for (const auto &Op : I.operands()) - ArgEmb += Vocab[Op]; + ArgEmb += Vocab[*Op]; auto InstVector = Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb; InstVecMap[&I] = InstVector; @@ -242,8 +242,8 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const { // If the operand is not defined by an instruction, we use the vocabulary else { LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: " - << *Op << "=" << Vocab[Op][0] << "\n"); - ArgEmb += Vocab[Op]; + << *Op << "=" << Vocab[*Op][0] << "\n"); + ArgEmb += Vocab[*Op]; } } // Create the instruction vector by combining opcode, type, and arguments @@ -264,12 +264,7 @@ Vocabulary::Vocabulary(VocabVector &&Vocab) : Vocab(std::move(Vocab)), Valid(true) {} bool Vocabulary::isValid() const { - return Vocab.size() == Vocabulary::expectedSize() && Valid; -} - -size_t Vocabulary::size() const { - assert(Valid && "IR2Vec Vocabulary is invalid"); - return Vocab.size(); + return Vocab.size() == NumCanonicalEntries && Valid; } unsigned Vocabulary::getDimension() const { @@ -277,19 +272,32 @@ unsigned Vocabulary::getDimension() const { return Vocab[0].size(); } -const Embedding &Vocabulary::operator[](unsigned Opcode) const { +unsigned Vocabulary::getSlotIndex(unsigned Opcode) { assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode"); - return Vocab[Opcode - 1]; + return Opcode - 1; // Convert to zero-based index +} + +unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) { + assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID"); + return MaxOpcodes + static_cast<unsigned>(getCanonicalTypeID(TypeID)); +} + +unsigned Vocabulary::getSlotIndex(const Value *Op) { + unsigned Index = static_cast<unsigned>(getOperandKind(Op)); + assert(Index < MaxOperandKinds && "Invalid OperandKind"); + return MaxOpcodes + MaxCanonicalTypeIDs + Index; +} + +const Embedding &Vocabulary::operator[](unsigned Opcode) const { + return Vocab[getSlotIndex(Opcode)]; } -const Embedding &Vocabulary::operator[](Type::TypeID TypeId) const { - assert(static_cast<unsigned>(TypeId) < MaxTypeIDs && "Invalid type ID"); - return Vocab[MaxOpcodes + static_cast<unsigned>(TypeId)]; +const Embedding &Vocabulary::operator[](Type::TypeID TypeID) const { + return Vocab[getSlotIndex(TypeID)]; } -const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const { - OperandKind ArgKind = getOperandKind(Arg); - return Vocab[MaxOpcodes + MaxTypeIDs + static_cast<unsigned>(ArgKind)]; +const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const { + return Vocab[getSlotIndex(&Arg)]; } StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) { @@ -303,43 +311,21 @@ StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) { return "UnknownOpcode"; } +StringRef Vocabulary::getVocabKeyForCanonicalTypeID(CanonicalTypeID CType) { + unsigned Index = static_cast<unsigned>(CType); + assert(Index < MaxCanonicalTypeIDs && "Invalid CanonicalTypeID"); + return CanonicalTypeNames[Index]; +} + +Vocabulary::CanonicalTypeID +Vocabulary::getCanonicalTypeID(Type::TypeID TypeID) { + unsigned Index = static_cast<unsigned>(TypeID); + assert(Index < MaxTypeIDs && "Invalid TypeID"); + return TypeIDMapping[Index]; +} + StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) { - switch (TypeID) { - case Type::VoidTyID: - return "VoidTy"; - case Type::HalfTyID: - case Type::BFloatTyID: - case Type::FloatTyID: - case Type::DoubleTyID: - case Type::X86_FP80TyID: - case Type::FP128TyID: - case Type::PPC_FP128TyID: - return "FloatTy"; - case Type::IntegerTyID: - return "IntegerTy"; - case Type::FunctionTyID: - return "FunctionTy"; - case Type::StructTyID: - return "StructTy"; - case Type::ArrayTyID: - return "ArrayTy"; - case Type::PointerTyID: - case Type::TypedPointerTyID: - return "PointerTy"; - case Type::FixedVectorTyID: - case Type::ScalableVectorTyID: - return "VectorTy"; - case Type::LabelTyID: - return "LabelTy"; - case Type::TokenTyID: - return "TokenTy"; - case Type::MetadataTyID: - return "MetadataTy"; - case Type::X86_AMXTyID: - case Type::TargetExtTyID: - return "UnknownTy"; - } - return "UnknownTy"; + return getVocabKeyForCanonicalTypeID(getCanonicalTypeID(TypeID)); } StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) { @@ -348,20 +334,6 @@ StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) { return OperandKindNames[Index]; } -Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) { - VocabVector DummyVocab; - float DummyVal = 0.1f; - // Create a dummy vocabulary with entries for all opcodes, types, and - // operand - for ([[maybe_unused]] unsigned _ : - seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxTypeIDs + - Vocabulary::MaxOperandKinds)) { - DummyVocab.push_back(Embedding(Dim, DummyVal)); - DummyVal += 0.1f; - } - return DummyVocab; -} - // Helper function to classify an operand into OperandKind Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) { if (isa<Function>(Op)) @@ -373,34 +345,18 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) { return OperandKind::VariableID; } -unsigned Vocabulary::getNumericID(unsigned Opcode) { - assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode"); - return Opcode - 1; // Convert to zero-based index -} - -unsigned Vocabulary::getNumericID(Type::TypeID TypeID) { - assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID"); - return MaxOpcodes + static_cast<unsigned>(TypeID); -} - -unsigned Vocabulary::getNumericID(const Value *Op) { - unsigned Index = static_cast<unsigned>(getOperandKind(Op)); - assert(Index < MaxOperandKinds && "Invalid OperandKind"); - return MaxOpcodes + MaxTypeIDs + Index; -} - StringRef Vocabulary::getStringKey(unsigned Pos) { - assert(Pos < Vocabulary::expectedSize() && - "Position out of bounds in vocabulary"); + assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary"); // Opcode if (Pos < MaxOpcodes) return getVocabKeyForOpcode(Pos + 1); // Type - if (Pos < MaxOpcodes + MaxTypeIDs) - return getVocabKeyForTypeID(static_cast<Type::TypeID>(Pos - MaxOpcodes)); + if (Pos < MaxOpcodes + MaxCanonicalTypeIDs) + return getVocabKeyForCanonicalTypeID( + static_cast<CanonicalTypeID>(Pos - MaxOpcodes)); // Operand return getVocabKeyForOperandKind( - static_cast<OperandKind>(Pos - MaxOpcodes - MaxTypeIDs)); + static_cast<OperandKind>(Pos - MaxOpcodes - MaxCanonicalTypeIDs)); } // For now, assume vocabulary is stable unless explicitly invalidated. @@ -410,6 +366,21 @@ bool Vocabulary::invalidate(Module &M, const PreservedAnalyses &PA, return !(PAC.preservedWhenStateless()); } +Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) { + VocabVector DummyVocab; + DummyVocab.reserve(NumCanonicalEntries); + float DummyVal = 0.1f; + // Create a dummy vocabulary with entries for all opcodes, types, and + // operands + for ([[maybe_unused]] unsigned _ : + seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs + + Vocabulary::MaxOperandKinds)) { + DummyVocab.push_back(Embedding(Dim, DummyVal)); + DummyVal += 0.1f; + } + return DummyVocab; +} + // ==----------------------------------------------------------------------===// // IR2VecVocabAnalysis //===----------------------------------------------------------------------===// @@ -502,6 +473,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { // Handle Opcodes std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes, Embedding(Dim, 0)); + NumericOpcodeEmbeddings.reserve(Vocabulary::MaxOpcodes); for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) { StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1); auto It = OpcVocab.find(VocabKey.str()); @@ -513,14 +485,15 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(), NumericOpcodeEmbeddings.end()); - // Handle Types - std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxTypeIDs, + // Handle Types - only canonical types are present in vocabulary + std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs, Embedding(Dim, 0)); - for (unsigned TypeID : seq(0u, Vocabulary::MaxTypeIDs)) { - StringRef VocabKey = - Vocabulary::getVocabKeyForTypeID(static_cast<Type::TypeID>(TypeID)); + NumericTypeEmbeddings.reserve(Vocabulary::MaxCanonicalTypeIDs); + for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) { + StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID( + static_cast<Vocabulary::CanonicalTypeID>(CTypeID)); if (auto It = TypeVocab.find(VocabKey.str()); It != TypeVocab.end()) { - NumericTypeEmbeddings[TypeID] = It->second; + NumericTypeEmbeddings[CTypeID] = It->second; continue; } handleMissingEntity(VocabKey.str()); @@ -531,6 +504,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { // Handle Arguments/Operands std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds, Embedding(Dim, 0)); + NumericArgEmbeddings.reserve(Vocabulary::MaxOperandKinds); for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) { Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind); StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind); diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt index 1b9b3c2acd8a5..df7769c9c6a65 100644 --- a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt +++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt @@ -67,25 +67,16 @@ Key: InsertValue: [ 129.00 130.00 ] Key: LandingPad: [ 131.00 132.00 ] Key: Freeze: [ 133.00 134.00 ] Key: FloatTy: [ 0.50 1.00 ] -Key: FloatTy: [ 0.50 1.00 ] -Key: FloatTy: [ 0.50 1.00 ] -Key: FloatTy: [ 0.50 1.00 ] -Key: FloatTy: [ 0.50 1.00 ] -Key: FloatTy: [ 0.50 1.00 ] -Key: FloatTy: [ 0.50 1.00 ] Key: VoidTy: [ 1.50 2.00 ] Key: LabelTy: [ 2.50 3.00 ] Key: MetadataTy: [ 3.50 4.00 ] -Key: UnknownTy: [ 4.50 5.00 ] +Key: VectorTy: [ 11.50 12.00 ] Key: TokenTy: [ 5.50 6.00 ] Key: IntegerTy: [ 6.50 7.00 ] Key: FunctionTy: [ 7.50 8.00 ] Key: PointerTy: [ 8.50 9.00 ] Key: StructTy: [ 9.50 10.00 ] Key: ArrayTy: [ 10.50 11.00 ] -Key: VectorTy: [ 11.50 12.00 ] -Key: VectorTy: [ 11.50 12.00 ] -Key: PointerTy: [ 8.50 9.00 ] Key: UnknownTy: [ 4.50 5.00 ] Key: Function: [ 0.20 0.40 ] Key: Pointer: [ 0.60 0.80 ] diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt index 9673e7f23fa5c..f3ce809fd2fd2 100644 --- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt +++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt @@ -67,25 +67,16 @@ Key: InsertValue: [ 64.50 65.00 ] Key: LandingPad: [ 65.50 66.00 ] Key: Freeze: [ 66.50 67.00 ] Key: FloatTy: [ 0.50 1.00 ] -Key: FloatTy: [ 0.50 1.00 ] -Key: FloatTy: [ 0.50 1.00 ] -Key: FloatTy: [ 0.50 1.00 ] -Key: FloatTy: [ 0.50 1.00 ] -Key: FloatTy: [ 0.50 1.00 ] -Key: FloatTy: [ 0.50 1.00 ] Key: VoidTy: [ 1.50 2.00 ] Key: LabelTy: [ 2.50 3.00 ] Key: MetadataTy: [ 3.50 4.00 ] -Key: UnknownTy: [ 4.50 5.00 ] +Key: VectorTy: [ 11.50 12.00 ] Key: TokenTy: [ 5.50 6.00 ] Key: IntegerTy: [ 6.50 7.00 ] Key: FunctionTy: [ 7.50 8.00 ] Key: PointerTy: [ 8.50 9.00 ] Key: StructTy: [ 9.50 10.00 ] Key: ArrayTy: [ 10.50 11.00 ] -Key: VectorTy: [ 11.50 12.00 ] -Key: VectorTy: [ 11.50 12.00 ] -Key: PointerTy: [ 8.50 9.00 ] Key: UnknownTy: [ 4.50 5.00 ] Key: Function: [ 0.50 1.00 ] Key: Pointer: [ 1.50 2.00 ] diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt index 1f575d29092dd..72b25b9bd3d9c 100644 --- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt +++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt @@ -67,25 +67,16 @@ Key: InsertValue: [ 12.90 13.00 ] Key: LandingPad: [ 13.10 13.20 ] Key: Freeze: [ 13.30 13.40 ] Key: FloatTy: [ 0.00 0.00 ] -Key: FloatTy: [ 0.00 0.00 ] -Key: FloatTy: [ 0.00 0.00 ] -Key: FloatTy: [ 0.00 0.00 ] -Key: FloatTy: [ 0.00 0.00 ] -Key: FloatTy: [ 0.00 0.00 ] -Key: FloatTy: [ 0.00 0.00 ] Key: VoidTy: [ 0.00 0.00 ] Key: LabelTy: [ 0.00 0.00 ] Key: MetadataTy: [ 0.00 0.00 ] -Key: UnknownTy: [ 0.00 0.00 ] +Key: VectorTy: [ 0.00 0.00 ] Key: TokenTy: [ 0.00 0.00 ] Key: IntegerTy: [ 0.00 0.00 ] Key: FunctionTy: [ 0.00 0.00 ] Key: PointerTy: [ 0.00 0.00 ] Key: StructTy: [ 0.00 0.00 ] Key: ArrayTy: [ 0.00 0.00 ] -Key: VectorTy: [ 0.00 0.00 ] -Key: VectorTy: [ 0.00 0.00 ] -Key: PointerTy: [ 0.00 0.00 ] Key: UnknownTy: [ 0.00 0.00 ] Key: Function: [ 0.00 0.00 ] Key: Pointer: [ 0.00 0.00 ] diff --git a/llvm/test/tools/llvm-ir2vec/entities.ll b/llvm/test/tools/llvm-ir2vec/entities.ll index 4ed6400d7a195..4b51adf30bf74 100644 --- a/llvm/test/tools/llvm-ir2vec/entities.ll +++ b/llvm/test/tools/llvm-ir2vec/entities.ll @@ -1,6 +1,6 @@ ; RUN: llvm-ir2vec entities | FileCheck %s -CHECK: 93 +CHECK: 84 CHECK-NEXT: Ret 0 CHECK-NEXT: Br 1 CHECK-NEXT: Switch 2 @@ -70,27 +70,18 @@ CHECK-NEXT: InsertValue 65 CHECK-NEXT: LandingPad 66 CHECK-NEXT: Freeze 67 CHECK-NEXT: FloatTy 68 -CHECK-NEXT: FloatTy 69 -CHECK-NEXT: FloatTy 70 -CHECK-NEXT: FloatTy 71 -CHECK-NEXT: FloatTy 72 -CHECK-NEXT: FloatTy 73 -CHECK-NEXT: FloatTy 74 -CHECK-NEXT: VoidTy 75 -CHECK-NEXT: LabelTy 76 -CHECK-NEXT: MetadataTy 77 -CHECK-NEXT: UnknownTy 78 -CHECK-NEXT: TokenTy 79 -CHECK-NEXT: IntegerTy 80 -CHECK-NEXT: FunctionTy 81 -CHECK-NEXT: PointerTy 82 -CHECK-NEXT: StructTy 83 -CHECK-NEXT: ArrayTy 84 -CHECK-NEXT: VectorTy 85 -CHECK-NEXT: VectorTy 86 -CHECK-NEXT: PointerTy 87 -CHECK-NEXT: UnknownTy 88 -CHECK-NEXT: Function 89 -CHECK-NEXT: Pointer 90 -CHECK-NEXT: Constant 91 -CHECK-NEXT: Variable 92 +CHECK-NEXT: VoidTy 69 +CHECK-NEXT: LabelTy 70 +CHECK-NEXT: MetadataTy 71 +CHECK-NEXT: VectorTy 72 +CHECK-NEXT: TokenTy 73 +CHECK-NEXT: IntegerTy 74 +CHECK-NEXT: FunctionTy 75 +CHECK-NEXT: PointerTy 76 +CHECK-NEXT: StructTy 77 +CHECK-NEXT: ArrayTy 78 +CHECK-NEXT: UnknownTy 79 +CHECK-NEXT: Function 80 +CHECK-NEXT: Pointer 81 +CHECK-NEXT: Constant 82 +CHECK-NEXT: Variable 83 diff --git a/llvm/test/tools/llvm-ir2vec/triplets.ll b/llvm/test/tools/llvm-ir2vec/triplets.ll index 6f64bab888f6b..7b476f60a07b3 100644 --- a/llvm/test/tools/llvm-ir2vec/triplets.ll +++ b/llvm/test/tools/llvm-ir2vec/triplets.ll @@ -25,41 +25,41 @@ entry: } ; TRIPLETS: MAX_RELATION=3 -; TRIPLETS-NEXT: 12 80 0 -; TRIPLETS-NEXT: 12 92 2 -; TRIPLETS-NEXT: 12 92 3 +; TRIPLETS-NEXT: 12 74 0 +; TRIPLETS-NEXT: 12 83 2 +; TRIPLETS-NEXT: 12 83 3 ; TRIPLETS-NEXT: 12 0 1 -; TRIPLETS-NEXT: 0 75 0 -; TRIPLETS-NEXT: 0 92 2 -; TRIPLETS-NEXT: 16 80 0 -; TRIPLETS-NEXT: 16 92 2 -; TRIPLETS-NEXT: 16 92 3 +; TRIPLETS-NEXT: 0 69 0 +; TRIPLETS-NEXT: 0 83 2 +; TRIPLETS-NEXT: 16 74 0 +; TRIPLETS-NEXT: 16 83 2 +; TRIPLETS-NEXT: 16 83 3 ; TRIPLETS-NEXT: 16 0 1 -; TRIPLETS-NEXT: 0 75 0 -; TRIPLETS-NEXT: 0 92 2 -; TRIPLETS-NEXT: 30 82 0 -; TRIPLETS-NEXT: 30 91 2 +; TRIPLETS-NEXT: 0 69 0 +; TRIPLETS-NEXT: 0 83 2 +; TRIPLETS-NEXT: 30 76 0 +; TRIPLETS-NEXT: 30 82 2 ; TRIPLETS-NEXT: 30 30 1 -; TRIPLETS-NEXT: 30 82 0 -; TRIPLETS-NEXT: 30 91 2 +; TRIPLETS-NEXT: 30 76 0 +; TRIPLETS-NEXT: 30 82 2 ; TRIPLETS-NEXT: 30 32 1 -; TRIPLETS-NEXT: 32 75 0 -; TRIPLETS-NEXT: 32 92 2 -; TRIPLETS-NEXT: 32 90 3 +; TRIPLETS-NEXT: 32 69 0 +; TRIPLETS-NEXT: 32 83 2 +; TRIPLETS-NEXT: 32 81 3 ; TRIPLETS-NEXT: 32 32 1 -; TRIPLETS-NEXT: 32 75 0 -; TRIPLETS-NEXT: 32 92 2 -; TRIPLETS-NEXT: 32 90 3 +; TRIPLETS-NEXT: 32 69 0 +; TRIPLETS-NEXT: 32 83 2 +; TRIPLETS-NEXT: 32 81 3 ; TRIPLETS-NEXT: 32 31 1 -; TRIPLETS-NEXT: 31 80 0 -; TRIPLETS-NEXT: 31 90 2 +; TRIPLETS-NEXT: 31 74 0 +; TRIPLETS-NEXT: 31 81 2 ; TRIPLETS-NEXT: 31 31 1 -; TRIPLETS-NEXT: 31 80 0 -; TRIPLETS-NEXT: 31 90 2 +; TRIPLETS-NEXT: 31 74 0 +; TRIPLETS-NEXT: 31 81 2 ; TRIPLETS-NEXT: 31 12 1 -; TRIPLETS-NEXT: 12 80 0 -; TRIPLETS-NEXT: 12 92 2 -; TRIPLETS-NEXT: 12 92 3 +; TRIPLETS-NEXT: 12 74 0 +; TRIPLETS-NEXT: 12 83 2 +; TRIPLETS-NEXT: 12 83 3 ; TRIPLETS-NEXT: 12 0 1 -; TRIPLETS-NEXT: 0 75 0 -; TRIPLETS-NEXT: 0 92 2 +; TRIPLETS-NEXT: 0 69 0 +; TRIPLETS-NEXT: 0 83 2 \ No newline at end of file diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp index c065aaeedd395..461ded77d9609 100644 --- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp +++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp @@ -162,8 +162,8 @@ class IR2VecTool { for (const BasicBlock &BB : F) { for (const auto &I : BB.instructionsWithoutDebug()) { - unsigned Opcode = Vocabulary::getNumericID(I.getOpcode()); - unsigned TypeID = Vocabulary::getNumericID(I.getType()->getTypeID()); + unsigned Opcode = Vocabulary::getSlotIndex(I.getOpcode()); + unsigned TypeID = Vocabulary::getSlotIndex(I.getType()->getTypeID()); // Add "Next" relationship with previous instruction if (HasPrevOpcode) { @@ -184,7 +184,7 @@ class IR2VecTool { // Add "Arg" relationships unsigned ArgIndex = 0; for (const Use &U : I.operands()) { - unsigned OperandID = Vocabulary::getNumericID(U.get()); + unsigned OperandID = Vocabulary::getSlotIndex(U.get()); unsigned RelationID = ArgRelation + ArgIndex; OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n'; @@ -211,13 +211,7 @@ class IR2VecTool { /// Dump entity ID to string mappings static void generateEntityMappings(raw_ostream &OS) { - // FIXME: Currently, the generated entity mappings are not one-to-one; - // Multiple TypeIDs map to same string key (Like Half, BFloat, etc. map to - // FloatTy). This would hinder learning good seed embeddings. - // We should fix this in the future by ensuring unique string keys either by - // post-processing here without changing the mapping in ir2vec::Vocabulary, - // or by changing the Vocabulary generation logic to ensure unique keys. - auto EntityLen = Vocabulary::expectedSize(); + auto EntityLen = Vocabulary::getCanonicalSize(); OS << EntityLen << "\n"; for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID) OS << Vocabulary::getStringKey(EntityID) << '\t' << EntityID << '\n'; diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp index f0c81e160ca15..9f5428758d64c 100644 --- a/llvm/unittests/Analysis/IR2VecTest.cpp +++ b/llvm/unittests/Analysis/IR2VecTest.cpp @@ -336,8 +336,8 @@ TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) { EXPECT_EQ(AddEmb.size(), 2u); EXPECT_EQ(RetEmb.size(), 2u); - EXPECT_TRUE(AddEmb.approximatelyEquals(Embedding(2, 27.9))); - EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(2, 17.0))); + EXPECT_TRUE(AddEmb.approximatelyEquals(Embedding(2, 25.5))); + EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(2, 15.5))); } TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) { @@ -353,8 +353,8 @@ TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) { EXPECT_EQ(InstMap.at(AddInst).size(), 2u); EXPECT_EQ(InstMap.at(RetInst).size(), 2u); - EXPECT_TRUE(InstMap.at(AddInst).approximatelyEquals(Embedding(2, 27.9))); - EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 35.6))); + EXPECT_TRUE(InstMap.at(AddInst).approximatelyEquals(Embedding(2, 25.5))); + EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 32.6))); } TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) { @@ -367,9 +367,9 @@ TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) { EXPECT_TRUE(BBMap.count(BB)); EXPECT_EQ(BBMap.at(BB).size(), 2u); - // BB vector should be sum of add and ret: {27.9, 27.9} + {17.0, 17.0} = - // {44.9, 44.9} - EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 44.9))); + // BB vector should be sum of add and ret: {25.5, 25.5} + {15.5, 15.5} = + // {41.0, 41.0} + EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 41.0))); } TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) { @@ -382,9 +382,9 @@ TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) { EXPECT_TRUE(BBMap.count(BB)); EXPECT_EQ(BBMap.at(BB).size(), 2u); - // BB vector should be sum of add and ret: {27.9, 27.9} + {35.6, 35.6} = - // {63.5, 63.5} - EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 63.5))); + // BB vector should be sum of add and ret: {25.5, 25.5} + {32.6, 32.6} = + // {58.1, 58.1} + EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 58.1))); } TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) { @@ -394,7 +394,7 @@ TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) { const auto &BBVec = Emb->getBBVector(*BB); EXPECT_EQ(BBVec.size(), 2u); - EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 44.9))); + EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 41.0))); } TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) { @@ -404,7 +404,7 @@ TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) { const auto &BBVec = Emb->getBBVector(*BB); EXPECT_EQ(BBVec.size(), 2u); - EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 63.5))); + EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 58.1))); } TEST_F(IR2VecTestFixture, GetFunctionVector_Symbolic) { @@ -415,8 +415,8 @@ TEST_F(IR2VecTestFixture, GetFunctionVector_Symbolic) { EXPECT_EQ(FuncVec.size(), 2u); - // Function vector should match BB vector (only one BB): {44.9, 44.9} - EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 44.9))); + // Function vector should match BB vector (only one BB): {41.0, 41.0} + EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 41.0))); } TEST_F(IR2VecTestFixture, GetFunctionVector_FlowAware) { @@ -426,24 +426,40 @@ TEST_F(IR2VecTestFixture, GetFunctionVector_FlowAware) { const auto &FuncVec = Emb->getFunctionVector(); EXPECT_EQ(FuncVec.size(), 2u); - // Function vector should match BB vector (only one BB): {63.5, 63.5} - EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 63.5))); + // Function vector should match BB vector (only one BB): {58.1, 58.1} + EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 58.1))); } static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes; static constexpr unsigned MaxTypeIDs = Vocabulary::MaxTypeIDs; +static constexpr unsigned MaxCanonicalTypeIDs = Vocabulary::MaxCanonicalTypeIDs; static constexpr unsigned MaxOperands = Vocabulary::MaxOperandKinds; +// Mapping between LLVM Type::TypeID tokens and Vocabulary::CanonicalTypeID +// names and their canonical string keys. +#define IR2VEC_HANDLE_TYPE_BIMAP(X) \ + X(VoidTyID, VoidTy, "VoidTy") \ + X(IntegerTyID, IntegerTy, "IntegerTy") \ + X(FloatTyID, FloatTy, "FloatTy") \ + X(PointerTyID, PointerTy, "PointerTy") \ + X(FunctionTyID, FunctionTy, "FunctionTy") \ + X(StructTyID, StructTy, "StructTy") \ + X(ArrayTyID, ArrayTy, "ArrayTy") \ + X(FixedVectorTyID, VectorTy, "VectorTy") \ + X(LabelTyID, LabelTy, "LabelTy") \ + X(TokenTyID, TokenTy, "TokenTy") \ + X(MetadataTyID, MetadataTy, "MetadataTy") + TEST(IR2VecVocabularyTest, DummyVocabTest) { for (unsigned Dim = 1; Dim <= 10; ++Dim) { auto VocabVec = Vocabulary::createDummyVocabForTest(Dim); - + auto VocabVecSize = VocabVec.size(); // All embeddings should have the same dimension for (const auto &Emb : VocabVec) EXPECT_EQ(Emb.size(), Dim); // Should have the correct total number of embeddings - EXPECT_EQ(VocabVec.size(), MaxOpcodes + MaxTypeIDs + MaxOperands); + EXPECT_EQ(VocabVecSize, MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands); auto ExpectedVocab = VocabVec; @@ -454,7 +470,7 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) { Vocabulary Result = VocabAnalysis.run(TestMod, MAM); EXPECT_TRUE(Result.isValid()); EXPECT_EQ(Result.getDimension(), Dim); - EXPECT_EQ(Result.size(), MaxOpcodes + MaxTypeIDs + MaxOperands); + EXPECT_EQ(Result.getCanonicalSize(), VocabVecSize); unsigned CurPos = 0; for (const auto &Entry : Result) @@ -462,64 +478,68 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) { } } -TEST(IR2VecVocabularyTest, NumericIDMap) { - // Test getNumericID for opcodes - EXPECT_EQ(Vocabulary::getNumericID(1u), 0u); - EXPECT_EQ(Vocabulary::getNumericID(13u), 12u); - EXPECT_EQ(Vocabulary::getNumericID(MaxOpcodes), MaxOpcodes - 1); - - // Test getNumericID for Type IDs - EXPECT_EQ(Vocabulary::getNumericID(Type::VoidTyID), - MaxOpcodes + static_cast<unsigned>(Type::VoidTyID)); - EXPECT_EQ(Vocabulary::getNumericID(Type::HalfTyID), - MaxOpcodes + static_cast<unsigned>(Type::HalfTyID)); - EXPECT_EQ(Vocabulary::getNumericID(Type::FloatTyID), - MaxOpcodes + static_cast<unsigned>(Type::FloatTyID)); - EXPECT_EQ(Vocabulary::getNumericID(Type::IntegerTyID), - MaxOpcodes + static_cast<unsigned>(Type::IntegerTyID)); - EXPECT_EQ(Vocabulary::getNumericID(Type::PointerTyID), - MaxOpcodes + static_cast<unsigned>(Type::PointerTyID)); - - // Test getNumericID for Value operands +TEST(IR2VecVocabularyTest, SlotIdxMapping) { + // Test getSlotIndex for Opcodes +#define EXPECT_OPCODE_SLOT(NUM, OPCODE, CLASS) \ + EXPECT_EQ(Vocabulary::getSlotIndex(NUM), static_cast<unsigned>(NUM - 1)); +#define HANDLE_INST(NUM, OPCODE, CLASS) EXPECT_OPCODE_SLOT(NUM, OPCODE, CLASS) +#include "llvm/IR/Instruction.def" +#undef HANDLE_INST +#undef EXPECT_OPCODE_SLOT + + // Test getSlotIndex for Types +#define EXPECT_TYPE_SLOT(TypeIDTok, CanonEnum, CanonStr) \ + EXPECT_EQ(Vocabulary::getSlotIndex(Type::TypeIDTok), \ + MaxOpcodes + static_cast<unsigned>( \ + Vocabulary::CanonicalTypeID::CanonEnum)); + + IR2VEC_HANDLE_TYPE_BIMAP(EXPECT_TYPE_SLOT) + +#undef EXPECT_TYPE_SLOT + + // Test getSlotIndex for Value operands LLVMContext Ctx; Module M("TestM", Ctx); FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), {Type::getInt32Ty(Ctx)}, false); Function *F = Function::Create(FTy, Function::ExternalLinkage, "testFunc", M); +#define EXPECTED_VOCAB_OPERAND_SLOT(X) \ + MaxOpcodes + MaxCanonicalTypeIDs + static_cast<unsigned>(X) // Test Function operand - EXPECT_EQ(Vocabulary::getNumericID(F), - MaxOpcodes + MaxTypeIDs + 0u); // Function = 0 + EXPECT_EQ(Vocabulary::getSlotIndex(F), + EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::FunctionID)); // Test Constant operand Constant *C = ConstantInt::get(Type::getInt32Ty(Ctx), 42); - EXPECT_EQ(Vocabulary::getNumericID(C), - MaxOpcodes + MaxTypeIDs + 2u); // Constant = 2 + EXPECT_EQ(Vocabulary::getSlotIndex(C), + EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::ConstantID)); // Test Pointer operand BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F); AllocaInst *PtrVal = new AllocaInst(Type::getInt32Ty(Ctx), 0, "ptr", BB); - EXPECT_EQ(Vocabulary::getNumericID(PtrVal), - MaxOpcodes + MaxTypeIDs + 1u); // Pointer = 1 + EXPECT_EQ(Vocabulary::getSlotIndex(PtrVal), + EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::PointerID)); // Test Variable operand (function argument) Argument *Arg = F->getArg(0); - EXPECT_EQ(Vocabulary::getNumericID(Arg), - MaxOpcodes + MaxTypeIDs + 3u); // Variable = 3 + EXPECT_EQ(Vocabulary::getSlotIndex(Arg), + EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::VariableID)); +#undef EXPECTED_VOCAB_OPERAND_SLOT } #if GTEST_HAS_DEATH_TEST #ifndef NDEBUG TEST(IR2VecVocabularyTest, NumericIDMapInvalidInputs) { // Test invalid opcode IDs - EXPECT_DEATH(Vocabulary::getNumericID(0u), "Invalid opcode"); - EXPECT_DEATH(Vocabulary::getNumericID(MaxOpcodes + 1), "Invalid opcode"); + EXPECT_DEATH(Vocabulary::getSlotIndex(0u), "Invalid opcode"); + EXPECT_DEATH(Vocabulary::getSlotIndex(MaxOpcodes + 1), "Invalid opcode"); // Test invalid type IDs - EXPECT_DEATH(Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs)), + EXPECT_DEATH(Vocabulary::getSlotIndex(static_cast<Type::TypeID>(MaxTypeIDs)), "Invalid type ID"); EXPECT_DEATH( - Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs + 10)), + Vocabulary::getSlotIndex(static_cast<Type::TypeID>(MaxTypeIDs + 10)), "Invalid type ID"); } #endif // NDEBUG @@ -529,18 +549,46 @@ TEST(IR2VecVocabularyTest, StringKeyGeneration) { EXPECT_EQ(Vocabulary::getStringKey(0), "Ret"); EXPECT_EQ(Vocabulary::getStringKey(12), "Add"); - StringRef HalfTypeKey = Vocabulary::getStringKey(MaxOpcodes + 0); - StringRef FloatTypeKey = Vocabulary::getStringKey(MaxOpcodes + 2); - StringRef VoidTypeKey = Vocabulary::getStringKey(MaxOpcodes + 7); - StringRef IntTypeKey = Vocabulary::getStringKey(MaxOpcodes + 12); - - EXPECT_EQ(HalfTypeKey, "FloatTy"); - EXPECT_EQ(FloatTypeKey, "FloatTy"); - EXPECT_EQ(VoidTypeKey, "VoidTy"); - EXPECT_EQ(IntTypeKey, "IntegerTy"); - - StringRef FuncArgKey = Vocabulary::getStringKey(MaxOpcodes + MaxTypeIDs + 0); - StringRef PtrArgKey = Vocabulary::getStringKey(MaxOpcodes + MaxTypeIDs + 1); +#define EXPECT_OPCODE(NUM, OPCODE, CLASS) \ + EXPECT_EQ(Vocabulary::getStringKey(Vocabulary::getSlotIndex(NUM)), \ + Vocabulary::getVocabKeyForOpcode(NUM)); +#define HANDLE_INST(NUM, OPCODE, CLASS) EXPECT_OPCODE(NUM, OPCODE, CLASS) +#include "llvm/IR/Instruction.def" +#undef HANDLE_INST +#undef EXPECT_OPCODE + + // Verify CanonicalTypeID -> string mapping +#define EXPECT_CANONICAL_TYPE_NAME(TypeIDTok, CanonEnum, CanonStr) \ + EXPECT_EQ(Vocabulary::getStringKey( \ + MaxOpcodes + static_cast<unsigned>( \ + Vocabulary::CanonicalTypeID::CanonEnum)), \ + CanonStr); + + IR2VEC_HANDLE_TYPE_BIMAP(EXPECT_CANONICAL_TYPE_NAME) + +#undef EXPECT_CANONICAL_TYPE_NAME + +#define HANDLE_OPERAND_KINDS(X) \ + X(FunctionID, "Function") \ + X(PointerID, "Pointer") \ + X(ConstantID, "Constant") \ + X(VariableID, "Variable") + +#define EXPECT_OPERAND_KIND(EnumName, Str) \ + EXPECT_EQ(Vocabulary::getStringKey( \ + MaxOpcodes + MaxCanonicalTypeIDs + \ + static_cast<unsigned>(Vocabulary::OperandKind::EnumName)), \ + Str); + + HANDLE_OPERAND_KINDS(EXPECT_OPERAND_KIND) + +#undef EXPECT_OPERAND_KIND +#undef HANDLE_OPERAND_KINDS + + StringRef FuncArgKey = + Vocabulary::getStringKey(MaxOpcodes + MaxCanonicalTypeIDs + 0); + StringRef PtrArgKey = + Vocabulary::getStringKey(MaxOpcodes + MaxCanonicalTypeIDs + 1); EXPECT_EQ(FuncArgKey, "Function"); EXPECT_EQ(PtrArgKey, "Pointer"); } @@ -578,39 +626,14 @@ TEST(IR2VecVocabularyTest, InvalidAccess) { #endif // GTEST_HAS_DEATH_TEST TEST(IR2VecVocabularyTest, TypeIDStringKeyMapping) { - EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes + - static_cast<unsigned>(Type::VoidTyID)), - "VoidTy"); - EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes + - static_cast<unsigned>(Type::IntegerTyID)), - "IntegerTy"); - EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes + - static_cast<unsigned>(Type::FloatTyID)), - "FloatTy"); - EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes + - static_cast<unsigned>(Type::PointerTyID)), - "PointerTy"); - EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes + - static_cast<unsigned>(Type::FunctionTyID)), - "FunctionTy"); - EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes + - static_cast<unsigned>(Type::StructTyID)), - "StructTy"); - EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes + - static_cast<unsigned>(Type::ArrayTyID)), - "ArrayTy"); - EXPECT_EQ(Vocabulary::getStringKey( - MaxOpcodes + static_cast<unsigned>(Type::FixedVectorTyID)), - "VectorTy"); - EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes + - static_cast<unsigned>(Type::LabelTyID)), - "LabelTy"); - EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes + - static_cast<unsigned>(Type::TokenTyID)), - "TokenTy"); - EXPECT_EQ(Vocabulary::getStringKey(MaxOpcodes + - static_cast<unsigned>(Type::MetadataTyID)), - "MetadataTy"); +#define EXPECT_TYPE_TO_CANONICAL(TypeIDTok, CanonEnum, CanonStr) \ + EXPECT_EQ( \ + Vocabulary::getStringKey(Vocabulary::getSlotIndex(Type::TypeIDTok)), \ + CanonStr); + + IR2VEC_HANDLE_TYPE_BIMAP(EXPECT_TYPE_TO_CANONICAL) + +#undef EXPECT_TYPE_TO_CANONICAL } TEST(IR2VecVocabularyTest, InvalidVocabularyConstruction) { _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits