https://github.com/svkeerthy created 
https://github.com/llvm/llvm-project/pull/161713

None

>From abf89848938125eca35cb5ec0b6a13a7eea3bd20 Mon Sep 17 00:00:00 2001
From: svkeerthy <[email protected]>
Date: Thu, 2 Oct 2025 18:14:53 +0000
Subject: [PATCH] MIRVocabulary changes

---
 llvm/include/llvm/CodeGen/MIR2Vec.h    | 31 +++++++++++++++-----------
 llvm/lib/CodeGen/MIR2Vec.cpp           | 18 ++++++++++-----
 llvm/unittests/CodeGen/MIR2VecTest.cpp | 28 ++++++++++++++++++-----
 3 files changed, 52 insertions(+), 25 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h
index 8bb47e61624bf..fde61170fde50 100644
--- a/llvm/include/llvm/CodeGen/MIR2Vec.h
+++ b/llvm/include/llvm/CodeGen/MIR2Vec.h
@@ -8,8 +8,8 @@
 ///
 /// \file
 /// This file defines the MIR2Vec vocabulary
-/// analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::Embedder interface
-/// for generating Machine IR embeddings, and related utilities.
+/// analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::MIREmbedder
+/// interface for generating Machine IR embeddings, and related utilities.
 ///
 /// MIR2Vec extends IR2Vec to support Machine IR embeddings. It represents the
 /// LLVM Machine IR as embeddings which can be used as input to machine learning
@@ -71,25 +71,31 @@ class MIRVocabulary {
     unsigned TotalEntries = 0;
   } Layout;
 
+  enum class Section : unsigned { Opcodes = 0, MaxSections };
+
   ir2vec::VocabStorage Storage;
   mutable std::set<std::string> UniqueBaseOpcodeNames;
-  void generateStorage(const VocabMap &OpcodeMap, const TargetInstrInfo &TII);
-  void buildCanonicalOpcodeMapping(const TargetInstrInfo &TII);
+  const TargetInstrInfo &TII;
+  void generateStorage(const VocabMap &OpcodeMap);
+  void buildCanonicalOpcodeMapping();
+
+  /// Get canonical index for a machine opcode
+  unsigned getCanonicalOpcodeIndex(unsigned Opcode) const;
 
 public:
-  /// Static helper method for extracting base opcode names (public for testing)
+  /// Static method for extracting base opcode names (public for testing)
   static std::string extractBaseOpcodeName(StringRef InstrName);
 
-  /// Helper method for getting canonical index for base name (public for
-  /// testing)
+  /// Get canonical index for base name (public for testing)
   unsigned getCanonicalIndexForBaseName(StringRef BaseName) const;
 
   /// Get the string key for a vocabulary entry at the given position
   std::string getStringKey(unsigned Pos) const;
 
-  MIRVocabulary() = default;
+  MIRVocabulary() = delete;
   MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo *TII);
-  MIRVocabulary(ir2vec::VocabStorage &&Storage) : Storage(std::move(Storage)) {}
+  MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
+      : Storage(std::move(Storage)), TII(TII) {}
 
   bool isValid() const {
     return UniqueBaseOpcodeNames.size() > 0 &&
@@ -103,11 +109,10 @@ class MIRVocabulary {
   }
 
   // Accessor methods
-  const Embedding &operator[](unsigned Index) const {
+  const Embedding &operator[](unsigned Opcode) const {
     assert(isValid() && "MIR2Vec Vocabulary is invalid");
-    assert(Index < Layout.TotalEntries && "Index out of bounds");
-    // Fixme: For now, use section 0 for all entries
-    return Storage[0][Index];
+    unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode);
+    return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex];
   }
 
   // Iterator access
diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp
index 18cae0f51e8c3..d1d322dce7d1c 100644
--- a/llvm/lib/CodeGen/MIR2Vec.cpp
+++ b/llvm/lib/CodeGen/MIR2Vec.cpp
@@ -49,18 +49,19 @@ cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),

 //===----------------------------------------------------------------------===//

 MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
-                             const TargetInstrInfo *TII) {
+                             const TargetInstrInfo *TII)
+    : TII((assert(TII && "TargetInstrInfo must not be null"), *TII)) {
   // Early return for invalid inputs - creates empty/invalid vocabulary
-  if (!TII || OpcodeEntries.empty())
+  if (OpcodeEntries.empty())
     return;

-  buildCanonicalOpcodeMapping(*TII);
+  buildCanonicalOpcodeMapping();

   unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
   assert(CanonicalOpcodeCount > 0 &&
          "No canonical opcodes found for target - invalid vocabulary");
   Layout.OperandBase = CanonicalOpcodeCount;
-  generateStorage(OpcodeEntries, *TII);
+  generateStorage(OpcodeEntries);
   Layout.TotalEntries = Storage.size();
 }
 
@@ -103,6 +104,12 @@ unsigned MIRVocabulary::getCanonicalIndexForBaseName(StringRef BaseName) const {
   return std::distance(UniqueBaseOpcodeNames.begin(), It);
 }
 
+unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const {
+  assert(isValid() && "MIR2Vec Vocabulary is invalid");
+  auto BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
+  return getCanonicalIndexForBaseName(BaseOpcode);
+}
+
 std::string MIRVocabulary::getStringKey(unsigned Pos) const {
   assert(isValid() && "MIR2Vec Vocabulary is invalid");
   assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary");
@@ -119,8 +126,7 @@ std::string MIRVocabulary::getStringKey(unsigned Pos) const {
   return "";
 }
 
-void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap,
-                                    const TargetInstrInfo &TII) {
+void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap) {
 
   // Helper for handling missing entities in the vocabulary.
   // Currently, we use a zero vector. In the future, we will throw an error to
@@ -168,7 +174,7 @@ void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap,
   new (&Storage) ir2vec::VocabStorage(std::move(Sections));
 }
 
-void MIRVocabulary::buildCanonicalOpcodeMapping(const TargetInstrInfo &TII) {
+void MIRVocabulary::buildCanonicalOpcodeMapping() {
   // Check if already built
   if (!UniqueBaseOpcodeNames.empty())
     return;
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp
index adc7b0e41f075..232cf34e52415 100644
--- a/llvm/unittests/CodeGen/MIR2VecTest.cpp
+++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp
@@ -93,6 +93,15 @@ class MIR2VecVocabTestFixture : public ::testing::Test {
   }
 };
 
+// Function to find an opcode by name
+static int findOpcodeByName(const TargetInstrInfo *TII, StringRef Name) {
+  for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) {
+    if (TII->getName(Opcode) == Name)
+      return Opcode;
+  }
+  return -1; // Not found
+}
+
 TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
   // Test that same base opcodes get same canonical indices
   std::string baseName1 = MIRVocabulary::extractBaseOpcodeName("ADD16ri");
@@ -138,9 +147,19 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
             6880u); // X86 has >6880 unique base opcodes
 
   // Check that the embeddings for opcodes not in the vocab are zero vectors
-  EXPECT_TRUE(testVocab[addIndex].approximatelyEquals(Val));
-  EXPECT_TRUE(testVocab[subIndex].approximatelyEquals(Embedding(64, 0.0f)));
-  EXPECT_TRUE(testVocab[movIndex].approximatelyEquals(Embedding(64, 0.0f)));
+  int add32rrOpcode = findOpcodeByName(TII, "ADD32rr");
+  ASSERT_NE(add32rrOpcode, -1) << "ADD32rr opcode not found";
+  EXPECT_TRUE(testVocab[add32rrOpcode].approximatelyEquals(Val));
+
+  int sub32rrOpcode = findOpcodeByName(TII, "SUB32rr");
+  ASSERT_NE(sub32rrOpcode, -1) << "SUB32rr opcode not found";
+  EXPECT_TRUE(
+      testVocab[sub32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
+
+  int mov32rrOpcode = findOpcodeByName(TII, "MOV32rr");
+  ASSERT_NE(mov32rrOpcode, -1) << "MOV32rr opcode not found";
+  EXPECT_TRUE(
+      testVocab[mov32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
 }
 
 // Test deterministic mapping
@@ -170,9 +189,6 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
 
 // Test MIRVocabulary construction
 TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
-  // Test empty MIRVocabulary
-  MIRVocabulary emptyVocab;
-  EXPECT_FALSE(emptyVocab.isValid());
 
   // Test MIRVocabulary with embeddings via VocabMap
   VocabMap vocabMap;

_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to