https://github.com/svkeerthy created https://github.com/llvm/llvm-project/pull/161713
None >From abf89848938125eca35cb5ec0b6a13a7eea3bd20 Mon Sep 17 00:00:00 2001 From: svkeerthy <[email protected]> Date: Thu, 2 Oct 2025 18:14:53 +0000 Subject: [PATCH] MIRVocabulary changes --- llvm/include/llvm/CodeGen/MIR2Vec.h | 31 +++++++++++++++----------- llvm/lib/CodeGen/MIR2Vec.cpp | 18 ++++++++++----- llvm/unittests/CodeGen/MIR2VecTest.cpp | 28 ++++++++++++++++++----- 3 files changed, 52 insertions(+), 25 deletions(-)

Reviewer notes (kept in the diffstat area, before the first "diff --git" line):
 - MIR2Vec.cpp, MIRVocabulary(VocabMap &&OpcodeEntries, const TargetInstrInfo *TII):
   the new member initializer ": TII(*TII)" dereferences the TII *parameter*
   (it shadows the member inside the constructor) before the body's
   "if (!TII || OpcodeEntries.empty()) return;" runs. A null TII is therefore
   dereferenced before it is tested, and the early return can no longer guard
   against it. NOTE(review): either require a non-null argument (assert / take
   it by reference) or keep the member as a pointer.
 - MIR2Vec.h: "const TargetInstrInfo &TII;" is a reference data member. That
   makes MIRVocabulary's implicitly-declared copy and move assignment operators
   deleted. NOTE(review): confirm no caller assigns MIRVocabulary objects.
 - MIR2Vec.h, MIRVocabulary(ir2vec::VocabStorage &&, const TargetInstrInfo &):
   the constructor body is empty and does not call buildCanonicalOpcodeMapping().
   isValid() begins with "UniqueBaseOpcodeNames.size() > 0 &&", and operator[]
   asserts isValid(). NOTE(review): unless UniqueBaseOpcodeNames is populated
   somewhere not shown in this patch, vocabularies built through this
   constructor would report invalid.
 - MIR2Vec.h, operator[]: the "Index < Layout.TotalEntries" bounds assert is
   removed. The index now comes from getCanonicalOpcodeIndex() ->
   getCanonicalIndexForBaseName(), whose visible tail is
   "return std::distance(UniqueBaseOpcodeNames.begin(), It);".
   NOTE(review): confirm that an opcode whose base name is not in the set
   cannot produce size() (one past the end) here.
 - MIR2VecTest.cpp: the context comment "Check that the embeddings for opcodes
   not in the vocab are zero vectors" sits directly above the ADD32rr check,
   which expects Val rather than a zero vector.

diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h index 8bb47e61624bf..fde61170fde50 100644 --- a/llvm/include/llvm/CodeGen/MIR2Vec.h +++ b/llvm/include/llvm/CodeGen/MIR2Vec.h @@ -8,8 +8,8 @@ /// /// \file /// This file defines the MIR2Vec vocabulary -/// analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::Embedder interface -/// for generating Machine IR embeddings, and related utilities. +/// analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::MIREmbedder +/// interface for generating Machine IR embeddings, and related utilities. /// /// MIR2Vec extends IR2Vec to support Machine IR embeddings.
It represents the /// LLVM Machine IR as embeddings which can be used as input to machine learning @@ -71,25 +71,31 @@ class MIRVocabulary { unsigned TotalEntries = 0; } Layout; + enum class Section : unsigned { Opcodes = 0, MaxSections }; + ir2vec::VocabStorage Storage; mutable std::set<std::string> UniqueBaseOpcodeNames; - void generateStorage(const VocabMap &OpcodeMap, const TargetInstrInfo &TII); - void buildCanonicalOpcodeMapping(const TargetInstrInfo &TII); + const TargetInstrInfo &TII; + void generateStorage(const VocabMap &OpcodeMap); + void buildCanonicalOpcodeMapping(); + + /// Get canonical index for a machine opcode + unsigned getCanonicalOpcodeIndex(unsigned Opcode) const; public: - /// Static helper method for extracting base opcode names (public for testing) + /// Static method for extracting base opcode names (public for testing) static std::string extractBaseOpcodeName(StringRef InstrName); - /// Helper method for getting canonical index for base name (public for - /// testing) + /// Get canonical index for base name (public for testing) unsigned getCanonicalIndexForBaseName(StringRef BaseName) const; /// Get the string key for a vocabulary entry at the given position std::string getStringKey(unsigned Pos) const; - MIRVocabulary() = default; + MIRVocabulary() = delete; MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo *TII); - MIRVocabulary(ir2vec::VocabStorage &&Storage) : Storage(std::move(Storage)) {} + MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII) + : Storage(std::move(Storage)), TII(TII) {} bool isValid() const { return UniqueBaseOpcodeNames.size() > 0 && @@ -103,11 +109,10 @@ class MIRVocabulary { } // Accessor methods - const Embedding &operator[](unsigned Index) const { + const Embedding &operator[](unsigned Opcode) const { assert(isValid() && "MIR2Vec Vocabulary is invalid"); - assert(Index < Layout.TotalEntries && "Index out of bounds"); - // Fixme: For now, use section 0 for all entries - return
Storage[0][Index]; + unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode); + return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex]; } // Iterator access diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp index 18cae0f51e8c3..d1d322dce7d1c 100644 --- a/llvm/lib/CodeGen/MIR2Vec.cpp +++ b/llvm/lib/CodeGen/MIR2Vec.cpp @@ -49,18 +49,19 @@ cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0), //===----------------------------------------------------------------------===// MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries, - const TargetInstrInfo *TII) { + const TargetInstrInfo *TII) + : TII(*TII) { // Early return for invalid inputs - creates empty/invalid vocabulary if (!TII || OpcodeEntries.empty()) return; - buildCanonicalOpcodeMapping(*TII); + buildCanonicalOpcodeMapping(); unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size(); assert(CanonicalOpcodeCount > 0 && "No canonical opcodes found for target - invalid vocabulary"); Layout.OperandBase = CanonicalOpcodeCount; - generateStorage(OpcodeEntries, *TII); + generateStorage(OpcodeEntries); Layout.TotalEntries = Storage.size(); } @@ -103,6 +104,12 @@ unsigned MIRVocabulary::getCanonicalIndexForBaseName(StringRef BaseName) const { return std::distance(UniqueBaseOpcodeNames.begin(), It); } +unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const { + assert(isValid() && "MIR2Vec Vocabulary is invalid"); + auto BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode)); + return getCanonicalIndexForBaseName(BaseOpcode); +} + std::string MIRVocabulary::getStringKey(unsigned Pos) const { assert(isValid() && "MIR2Vec Vocabulary is invalid"); assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary"); @@ -119,8 +126,7 @@ std::string MIRVocabulary::getStringKey(unsigned Pos) const { return ""; } -void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap, - const TargetInstrInfo &TII) { +void
MIRVocabulary::generateStorage(const VocabMap &OpcodeMap) { // Helper for handling missing entities in the vocabulary. // Currently, we use a zero vector. In the future, we will throw an error to @@ -168,7 +174,7 @@ void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap, new (&Storage) ir2vec::VocabStorage(std::move(Sections)); } -void MIRVocabulary::buildCanonicalOpcodeMapping(const TargetInstrInfo &TII) { +void MIRVocabulary::buildCanonicalOpcodeMapping() { // Check if already built if (!UniqueBaseOpcodeNames.empty()) return; diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp index adc7b0e41f075..232cf34e52415 100644 --- a/llvm/unittests/CodeGen/MIR2VecTest.cpp +++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp @@ -93,6 +93,15 @@ class MIR2VecVocabTestFixture : public ::testing::Test { } }; +// Function to find an opcode by name +static int findOpcodeByName(const TargetInstrInfo *TII, StringRef Name) { + for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) { + if (TII->getName(Opcode) == Name) + return Opcode; + } + return -1; // Not found +} + TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { // Test that same base opcodes get same canonical indices std::string baseName1 = MIRVocabulary::extractBaseOpcodeName("ADD16ri"); @@ -138,9 +147,19 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { 6880u); // X86 has >6880 unique base opcodes // Check that the embeddings for opcodes not in the vocab are zero vectors - EXPECT_TRUE(testVocab[addIndex].approximatelyEquals(Val)); - EXPECT_TRUE(testVocab[subIndex].approximatelyEquals(Embedding(64, 0.0f))); - EXPECT_TRUE(testVocab[movIndex].approximatelyEquals(Embedding(64, 0.0f))); + int add32rrOpcode = findOpcodeByName(TII, "ADD32rr"); + ASSERT_NE(add32rrOpcode, -1) << "ADD32rr opcode not found"; + EXPECT_TRUE(testVocab[add32rrOpcode].approximatelyEquals(Val)); + + int sub32rrOpcode = findOpcodeByName(TII, "SUB32rr"); + ASSERT_NE(sub32rrOpcode,
-1) << "SUB32rr opcode not found"; + EXPECT_TRUE( + testVocab[sub32rrOpcode].approximatelyEquals(Embedding(64, 0.0f))); + + int mov32rrOpcode = findOpcodeByName(TII, "MOV32rr"); + ASSERT_NE(mov32rrOpcode, -1) << "MOV32rr opcode not found"; + EXPECT_TRUE( + testVocab[mov32rrOpcode].approximatelyEquals(Embedding(64, 0.0f))); } // Test deterministic mapping @@ -170,9 +189,6 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) { // Test MIRVocabulary construction TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) { - // Test empty MIRVocabulary - MIRVocabulary emptyVocab; - EXPECT_FALSE(emptyVocab.isValid()); // Test MIRVocabulary with embeddings via VocabMap VocabMap vocabMap; _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
