https://github.com/svkeerthy created https://github.com/llvm/llvm-project/pull/158376

None

>From 81a84b27f4b2aeaf6ca1421b2abb2a960c4e7a50 Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeer...@google.com>
Date: Fri, 12 Sep 2025 22:06:44 +0000
Subject: [PATCH] VocabStorage

---
 llvm/include/llvm/Analysis/IR2Vec.h           | 145 +++++++--
 llvm/lib/Analysis/IR2Vec.cpp                  | 230 +++++++++----
 llvm/lib/Analysis/InlineAdvisor.cpp           |   2 +-
 llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp        |   6 +-
 .../FunctionPropertiesAnalysisTest.cpp        |   8 +-
 llvm/unittests/Analysis/IR2VecTest.cpp        | 301 ++++++++++++++++--
 6 files changed, 570 insertions(+), 122 deletions(-)

diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 4a6db5d895a62..7d51a7320d194 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -45,6 +45,7 @@
 #include "llvm/Support/JSON.h"
 #include <array>
 #include <map>
+#include <optional>
 
 namespace llvm {
 
@@ -144,6 +145,94 @@ struct Embedding {
 using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
 using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
 
+/// Generic storage class for section-based vocabularies.
+/// VocabStorage provides a generic foundation for storing and accessing
+/// embeddings organized into sections. Entries are addressed by
+/// (section id, local index) through operator[] and visited in section order
+/// through const_iterator.
+class VocabStorage {
+private:
+  /// Section-based storage
+  std::vector<std::vector<Embedding>> Sections;
+
+  /// Number of embeddings across all sections.
+  size_t TotalSize = 0;
+  /// Dimension of the embeddings (0 for empty storage).
+  unsigned Dimension = 0;
+
+public:
+  /// Default constructor creates empty storage (invalid state).
+  VocabStorage() = default;
+
+  /// Create a VocabStorage with pre-organized section data.
+  LLVM_ABI VocabStorage(std::vector<std::vector<Embedding>> &&SectionData);
+
+  /// Moving transfers all sections and leaves \p Other empty (size() == 0,
+  /// isValid() == false), so a moved-from storage is never mistaken for a
+  /// populated one.
+  VocabStorage(VocabStorage &&Other) noexcept
+      : Sections(std::move(Other.Sections)), TotalSize(Other.TotalSize),
+        Dimension(Other.Dimension) {
+    Other.Sections.clear();
+    Other.TotalSize = 0;
+    Other.Dimension = 0;
+  }
+  LLVM_ABI VocabStorage &operator=(VocabStorage &&Other);
+
+  VocabStorage(const VocabStorage &) = delete;
+  VocabStorage &operator=(const VocabStorage &) = delete;
+
+  /// Get total number of entries across all sections
+  size_t size() const { return TotalSize; }
+
+  /// Get number of sections
+  unsigned getNumSections() const {
+    return static_cast<unsigned>(Sections.size());
+  }
+
+  /// Section-based access: Storage[sectionId][localIndex]
+  const std::vector<Embedding> &operator[](unsigned SectionId) const {
+    assert(SectionId < Sections.size() && "Invalid section ID");
+    return Sections[SectionId];
+  }
+
+  /// Get vocabulary dimension
+  unsigned getDimension() const { return Dimension; }
+
+  /// Check if vocabulary is valid (has data)
+  bool isValid() const { return TotalSize > 0; }
+
+  /// Read-only iterator over all embeddings in section order; empty sections
+  /// are skipped when advancing.
+  class const_iterator {
+    const VocabStorage *Storage;
+    unsigned SectionId;
+    size_t LocalIndex;
+
+  public:
+    const_iterator(const VocabStorage *Storage, unsigned SectionId,
+                   size_t LocalIndex)
+        : Storage(Storage), SectionId(SectionId), LocalIndex(LocalIndex) {}
+
+    LLVM_ABI const Embedding &operator*() const;
+    LLVM_ABI const_iterator &operator++();
+    LLVM_ABI bool operator==(const const_iterator &Other) const;
+    LLVM_ABI bool operator!=(const const_iterator &Other) const;
+  };
+
+  /// Iterator to the first entry. Leading empty sections are skipped so that
+  /// begin() == end() whenever the storage holds no entries.
+  const_iterator begin() const {
+    unsigned FirstSection = 0;
+    while (FirstSection < getNumSections() && Sections[FirstSection].empty())
+      ++FirstSection;
+    return const_iterator(this, FirstSection, 0);
+  }
+  const_iterator end() const {
+    return const_iterator(this, getNumSections(), 0);
+  }
+};
+
 /// Class for storing and accessing the IR2Vec vocabulary.
 /// The Vocabulary class manages seed embeddings for LLVM IR entities. The
 /// seed embeddings are the initial learned representations of the entities
@@ -164,7 +232,7 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
 class Vocabulary {
   friend class llvm::IR2VecVocabAnalysis;
 
-  // Vocabulary Slot Layout:
+  // Vocabulary Layout:
   // +----------------+------------------------------------------------------+
   // | Entity Type    | Index Range                                          |
   // +----------------+------------------------------------------------------+
@@ -175,8 +243,16 @@ class Vocabulary {
   // Note: "Similar" LLVM Types are grouped/canonicalized together.
   //       Operands include Comparison predicates (ICmp/FCmp).
   //       This can be extended to include other specializations in future.
-  using VocabVector = std::vector<ir2vec::Embedding>;
-  VocabVector Vocab;
+  enum class Section : unsigned {
+    Opcodes = 0,
+    CanonicalTypes = 1,
+    Operands = 2,
+    Predicates = 3,
+    MaxSections
+  };
+
+  // Use section-based storage for better organization and efficiency
+  VocabStorage Storage;
 
   static constexpr unsigned NumICmpPredicates =
       static_cast<unsigned>(CmpInst::LAST_ICMP_PREDICATE) -
@@ -228,9 +304,18 @@ class Vocabulary {
       NumICmpPredicates + NumFCmpPredicates;
 
   Vocabulary() = default;
-  LLVM_ABI Vocabulary(VocabVector &&Vocab) : Vocab(std::move(Vocab)) {}
+  LLVM_ABI Vocabulary(VocabStorage &&Storage) : Storage(std::move(Storage)) {}
+
+  Vocabulary(const Vocabulary &) = delete;
+  Vocabulary &operator=(const Vocabulary &) = delete;
+
+  Vocabulary(Vocabulary &&) = default;
+  Vocabulary &operator=(Vocabulary &&Other);
+
+  LLVM_ABI bool isValid() const {
+    return Storage.size() == NumCanonicalEntries;
+  }
 
-  LLVM_ABI bool isValid() const { return Vocab.size() == NumCanonicalEntries; 
};
   LLVM_ABI unsigned getDimension() const;
   /// Total number of entries (opcodes + canonicalized types + operand kinds +
   /// predicates)
@@ -251,12 +336,11 @@ class Vocabulary {
   /// Function to get vocabulary key for a given predicate
   LLVM_ABI static StringRef getVocabKeyForPredicate(CmpInst::Predicate P);
 
-  /// Functions to return the slot index or position of a given Opcode, TypeID,
-  /// or OperandKind in the vocabulary.
-  LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
-  LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
-  LLVM_ABI static unsigned getSlotIndex(const Value &Op);
-  LLVM_ABI static unsigned getSlotIndex(CmpInst::Predicate P);
+  /// Functions to return flat index
+  LLVM_ABI static unsigned getIndex(unsigned Opcode);
+  LLVM_ABI static unsigned getIndex(Type::TypeID TypeID);
+  LLVM_ABI static unsigned getIndex(const Value &Op);
+  LLVM_ABI static unsigned getIndex(CmpInst::Predicate P);
 
   /// Accessors to get the embedding for a given entity.
   LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
@@ -265,26 +349,21 @@ class Vocabulary {
   LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const;
 
   /// Const Iterator type aliases
-  using const_iterator = VocabVector::const_iterator;
+  using const_iterator = VocabStorage::const_iterator;
+
   const_iterator begin() const {
     assert(isValid() && "IR2Vec Vocabulary is invalid");
-    return Vocab.begin();
+    return Storage.begin();
   }
 
-  const_iterator cbegin() const {
-    assert(isValid() && "IR2Vec Vocabulary is invalid");
-    return Vocab.cbegin();
-  }
+  const_iterator cbegin() const { return begin(); }
 
   const_iterator end() const {
     assert(isValid() && "IR2Vec Vocabulary is invalid");
-    return Vocab.end();
+    return Storage.end();
   }
 
-  const_iterator cend() const {
-    assert(isValid() && "IR2Vec Vocabulary is invalid");
-    return Vocab.cend();
-  }
+  const_iterator cend() const { return end(); }
 
   /// Returns the string key for a given index position in the vocabulary.
   /// This is useful for debugging or printing the vocabulary. Do not use this
@@ -292,7 +371,7 @@ class Vocabulary {
   LLVM_ABI static StringRef getStringKey(unsigned Pos);
 
   /// Create a dummy vocabulary for testing purposes.
-  LLVM_ABI static VocabVector createDummyVocabForTest(unsigned Dim = 1);
+  LLVM_ABI static VocabStorage createDummyVocabForTest(unsigned Dim = 1);
 
   LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA,
                            ModuleAnalysisManager::Invalidator &Inv) const;
@@ -301,12 +380,16 @@ class Vocabulary {
   constexpr static unsigned NumCanonicalEntries =
       MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;
 
-  // Base offsets for slot layout to simplify index computation
+  // Base offsets for flat index computation
   constexpr static unsigned OperandBaseOffset =
       MaxOpcodes + MaxCanonicalTypeIDs;
   constexpr static unsigned PredicateBaseOffset =
       OperandBaseOffset + MaxOperandKinds;
 
+  /// Functions for predicate index calculations
+  static unsigned getPredicateLocalIndex(CmpInst::Predicate P);
+  static CmpInst::Predicate getPredicateFromLocalIndex(unsigned LocalIndex);
+
   /// String mappings for CanonicalTypeID values
   static constexpr StringLiteral CanonicalTypeNames[] = {
       "FloatTy",   "VoidTy",   "LabelTy",   "MetadataTy",
@@ -452,22 +535,22 @@ class LLVM_ABI FlowAwareEmbedder : public Embedder {
 /// mapping between an entity of the IR (like opcode, type, argument, etc.) and
 /// its corresponding embedding.
 class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
-  using VocabVector = std::vector<ir2vec::Embedding>;
   using VocabMap = std::map<std::string, ir2vec::Embedding>;
-  VocabMap OpcVocab, TypeVocab, ArgVocab;
-  VocabVector Vocab;
+  std::optional<ir2vec::VocabStorage> Vocab;
 
-  Error readVocabulary();
+  Error readVocabulary(VocabMap &OpcVocab, VocabMap &TypeVocab,
+                       VocabMap &ArgVocab);
   Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue,
                           VocabMap &TargetVocab, unsigned &Dim);
-  void generateNumMappedVocab();
+  void generateVocabStorage(VocabMap &OpcVocab, VocabMap &TypeVocab,
+                            VocabMap &ArgVocab);
   void emitError(Error Err, LLVMContext &Ctx);
 
 public:
   LLVM_ABI static AnalysisKey Key;
   IR2VecVocabAnalysis() = default;
-  LLVM_ABI explicit IR2VecVocabAnalysis(const VocabVector &Vocab);
-  LLVM_ABI explicit IR2VecVocabAnalysis(VocabVector &&Vocab);
+  LLVM_ABI explicit IR2VecVocabAnalysis(ir2vec::VocabStorage &&Vocab)
+      : Vocab(std::move(Vocab)) {}
   using Result = ir2vec::Vocabulary;
   LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM);
 };
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index f51f0898cb37e..eeba109eb7dbd 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -15,6 +15,7 @@
 
 #include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Module.h"
@@ -261,55 +262,121 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
   BBVecMap[&BB] = BBVector;
 }
 
+//===----------------------------------------------------------------------===//
+// VocabStorage
+//===----------------------------------------------------------------------===//
+
+VocabStorage::VocabStorage(std::vector<std::vector<Embedding>> &&SectionData)
+    : Sections(std::move(SectionData)) {
+  assert(!Sections.empty() && "Vocabulary has no sections");
+  assert(!Sections[0].empty() && "First section of vocabulary is empty");
+  // Guarded so that builds without assertions never index an empty section.
+  if (!Sections.empty() && !Sections[0].empty())
+    Dimension = static_cast<unsigned>(Sections[0][0].size());
+  // Accumulate the total size; all embeddings must share one dimension.
+  for (const auto &Section : Sections) {
+    TotalSize += Section.size();
+    for ([[maybe_unused]] const auto &Emb : Section)
+      assert(Emb.size() == Dimension && "Inconsistent embedding dimensions");
+  }
+}
+
+VocabStorage &VocabStorage::operator=(VocabStorage &&Other) {
+  if (this != &Other) {
+    Sections = std::move(Other.Sections);
+    TotalSize = Other.TotalSize;
+    Dimension = Other.Dimension;
+    Other.Sections.clear();
+    Other.TotalSize = 0;
+    Other.Dimension = 0;
+  }
+  return *this;
+}
+
+const Embedding &VocabStorage::const_iterator::operator*() const {
+  assert(SectionId < Storage->Sections.size() && "Invalid section ID");
+  assert(LocalIndex < Storage->Sections[SectionId].size() &&
+         "Local index out of range");
+  return Storage->Sections[SectionId][LocalIndex];
+}
+
+VocabStorage::const_iterator &VocabStorage::const_iterator::operator++() {
+  ++LocalIndex;
+  // Check if we need to move to the next section
+  while (SectionId < Storage->getNumSections() &&
+         LocalIndex >= Storage->Sections[SectionId].size()) {
+    LocalIndex = 0;
+    ++SectionId;
+  }
+  return *this;
+}
+
+bool VocabStorage::const_iterator::operator==(
+    const const_iterator &Other) const {
+  return Storage == Other.Storage && SectionId == Other.SectionId &&
+         LocalIndex == Other.LocalIndex;
+}
+
+bool VocabStorage::const_iterator::operator!=(
+    const const_iterator &Other) const {
+  return !(*this == Other);
+}
+
 //===----------------------------------------------------------------------===//
 // Vocabulary
 //===----------------------------------------------------------------------===//
 
+Vocabulary &Vocabulary::operator=(Vocabulary &&Other) {
+  if (this != &Other)
+    Storage = std::move(Other.Storage);
+  return *this;
+}
+
 unsigned Vocabulary::getDimension() const {
   assert(isValid() && "IR2Vec Vocabulary is invalid");
-  return Vocab[0].size();
+  return Storage.getDimension();
 }
 
-unsigned Vocabulary::getSlotIndex(unsigned Opcode) {
+unsigned Vocabulary::getIndex(unsigned Opcode) {
   assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
   return Opcode - 1; // Convert to zero-based index
 }
 
-unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) {
+unsigned Vocabulary::getIndex(Type::TypeID TypeID) {
   assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
   return MaxOpcodes + static_cast<unsigned>(getCanonicalTypeID(TypeID));
 }
 
-unsigned Vocabulary::getSlotIndex(const Value &Op) {
+unsigned Vocabulary::getIndex(const Value &Op) {
   unsigned Index = static_cast<unsigned>(getOperandKind(&Op));
   assert(Index < MaxOperandKinds && "Invalid OperandKind");
   return OperandBaseOffset + Index;
 }
 
-unsigned Vocabulary::getSlotIndex(CmpInst::Predicate P) {
-  unsigned PU = static_cast<unsigned>(P);
-  unsigned FirstFC = static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE);
-  unsigned FirstIC = static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE);
-
-  unsigned PredIdx =
-      (PU >= FirstIC) ? (NumFCmpPredicates + (PU - FirstIC)) : (PU - FirstFC);
-  return PredicateBaseOffset + PredIdx;
+unsigned Vocabulary::getIndex(CmpInst::Predicate P) {
+  return PredicateBaseOffset + getPredicateLocalIndex(P);
 }
 
 const Embedding &Vocabulary::operator[](unsigned Opcode) const {
-  return Vocab[getSlotIndex(Opcode)];
+  assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
+  return Storage[static_cast<unsigned>(Section::Opcodes)][Opcode - 1];
 }
 
 const Embedding &Vocabulary::operator[](Type::TypeID TypeID) const {
-  return Vocab[getSlotIndex(TypeID)];
+  assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
+  unsigned LocalIndex = static_cast<unsigned>(getCanonicalTypeID(TypeID));
+  return Storage[static_cast<unsigned>(Section::CanonicalTypes)][LocalIndex];
 }
 
 const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const {
-  return Vocab[getSlotIndex(Arg)];
+  unsigned LocalIndex = static_cast<unsigned>(getOperandKind(&Arg));
+  assert(LocalIndex < MaxOperandKinds && "Invalid OperandKind");
+  return Storage[static_cast<unsigned>(Section::Operands)][LocalIndex];
 }
 
 const ir2vec::Embedding &Vocabulary::operator[](CmpInst::Predicate P) const {
-  return Vocab[getSlotIndex(P)];
+  unsigned LocalIndex = getPredicateLocalIndex(P);
+  return Storage[static_cast<unsigned>(Section::Predicates)][LocalIndex];
 }
 
 StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
@@ -359,12 +426,27 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
 
 CmpInst::Predicate Vocabulary::getPredicate(unsigned Index) {
   assert(Index < MaxPredicateKinds && "Invalid predicate index");
-  unsigned PredEnumVal =
-      (Index < NumFCmpPredicates)
-          ? (static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + Index)
-          : (static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) +
-             (Index - NumFCmpPredicates));
-  return static_cast<CmpInst::Predicate>(PredEnumVal);
+  return getPredicateFromLocalIndex(Index);
+}
+
+// FCmp predicates occupy local slots [0, NumFCmpPredicates); ICmp predicates
+// follow, starting at local slot NumFCmpPredicates.
+unsigned Vocabulary::getPredicateLocalIndex(CmpInst::Predicate P) {
+  if (CmpInst::isFPPredicate(P))
+    return static_cast<unsigned>(P - CmpInst::FIRST_FCMP_PREDICATE);
+  assert(CmpInst::isIntPredicate(P) && "Invalid predicate");
+  return NumFCmpPredicates +
+         static_cast<unsigned>(P - CmpInst::FIRST_ICMP_PREDICATE);
+}
+
+// Inverse of getPredicateLocalIndex.
+CmpInst::Predicate Vocabulary::getPredicateFromLocalIndex(unsigned LocalIndex) {
+  assert(LocalIndex < MaxPredicateKinds && "Invalid predicate index");
+  if (LocalIndex < NumFCmpPredicates)
+    return static_cast<CmpInst::Predicate>(CmpInst::FIRST_FCMP_PREDICATE +
+                                           LocalIndex);
+  return static_cast<CmpInst::Predicate>(CmpInst::FIRST_ICMP_PREDICATE +
+                                         (LocalIndex - NumFCmpPredicates));
 }
 
 StringRef Vocabulary::getVocabKeyForPredicate(CmpInst::Predicate Pred) {
@@ -401,17 +482,31 @@ bool Vocabulary::invalidate(Module &M, const PreservedAnalyses &PA,
   return !(PAC.preservedWhenStateless());
 }
 
-Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
-  VocabVector DummyVocab;
-  DummyVocab.reserve(NumCanonicalEntries);
+VocabStorage Vocabulary::createDummyVocabForTest(unsigned Dim) {
   float DummyVal = 0.1f;
-  // Create a dummy vocabulary with entries for all opcodes, types, operands
-  // and predicates
-  for ([[maybe_unused]] unsigned _ : seq(0u, Vocabulary::NumCanonicalEntries)) {
-    DummyVocab.push_back(Embedding(Dim, DummyVal));
-    DummyVal += 0.1f;
-  }
-  return DummyVocab;
+
+  // Build one section per Vocabulary::Section, in enum order, with one
+  // embedding per entity of that kind. Values increase by 0.1 per entry
+  // across all sections, starting at 0.1.
+  constexpr unsigned SectionSizes[] = {MaxOpcodes, MaxCanonicalTypeIDs,
+                                       MaxOperandKinds, MaxPredicateKinds};
+  static_assert(std::size(SectionSizes) ==
+                    static_cast<size_t>(Section::MaxSections),
+                "Dummy vocabulary must provide every Vocabulary::Section");
+
+  std::vector<std::vector<Embedding>> Sections;
+  Sections.reserve(std::size(SectionSizes));
+  for (unsigned NumEntries : SectionSizes) {
+    std::vector<Embedding> Sec;
+    Sec.reserve(NumEntries);
+    for (unsigned I = 0; I < NumEntries; ++I) {
+      Sec.emplace_back(Dim, DummyVal);
+      DummyVal += 0.1f;
+    }
+    Sections.push_back(std::move(Sec));
+  }
+
+  return VocabStorage(std::move(Sections));
 }
 
 //===----------------------------------------------------------------------===//
@@ -457,7 +572,9 @@ Error IR2VecVocabAnalysis::parseVocabSection(
 
 // FIXME: Make this optional. We can avoid file reads
 // by auto-generating a default vocabulary during the build time.
-Error IR2VecVocabAnalysis::readVocabulary() {
+Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab,
+                                          VocabMap &TypeVocab,
+                                          VocabMap &ArgVocab) {
   auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
   if (!BufOrError)
     return createFileError(VocabFile, BufOrError.getError());
@@ -488,7 +605,9 @@ Error IR2VecVocabAnalysis::readVocabulary() {
   return Error::success();
 }
 
-void IR2VecVocabAnalysis::generateNumMappedVocab() {
+void IR2VecVocabAnalysis::generateVocabStorage(VocabMap &OpcVocab,
+                                               VocabMap &TypeVocab,
+                                               VocabMap &ArgVocab) {
 
   // Helper for handling missing entities in the vocabulary.
   // Currently, we use a zero vector. In the future, we will throw an error to
@@ -506,7 +625,6 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
   // Handle Opcodes
   std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
                                                  Embedding(Dim));
-  NumericOpcodeEmbeddings.reserve(Vocabulary::MaxOpcodes);
   for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
     StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
     auto It = OpcVocab.find(VocabKey.str());
@@ -515,13 +633,10 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
     else
       handleMissingEntity(VocabKey.str());
   }
-  Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(),
-               NumericOpcodeEmbeddings.end());
 
   // Handle Types - only canonical types are present in vocabulary
   std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs,
                                                Embedding(Dim));
-  NumericTypeEmbeddings.reserve(Vocabulary::MaxCanonicalTypeIDs);
   for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) {
     StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID(
         static_cast<Vocabulary::CanonicalTypeID>(CTypeID));
@@ -531,13 +646,10 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
     }
     handleMissingEntity(VocabKey.str());
   }
-  Vocab.insert(Vocab.end(), NumericTypeEmbeddings.begin(),
-               NumericTypeEmbeddings.end());
 
   // Handle Arguments/Operands
   std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,
                                               Embedding(Dim));
-  NumericArgEmbeddings.reserve(Vocabulary::MaxOperandKinds);
   for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {
     Vocabulary::OperandKind Kind = 
static_cast<Vocabulary::OperandKind>(OpKind);
     StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind);
@@ -548,14 +660,11 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
     }
     handleMissingEntity(VocabKey.str());
   }
-  Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(),
-               NumericArgEmbeddings.end());
 
   // Handle Predicates: part of Operands section. We look up predicate keys
   // in ArgVocab.
   std::vector<Embedding> NumericPredEmbeddings(Vocabulary::MaxPredicateKinds,
                                                Embedding(Dim, 0));
-  NumericPredEmbeddings.reserve(Vocabulary::MaxPredicateKinds);
   for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) {
     StringRef VocabKey =
         Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK));
@@ -566,15 +675,22 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
     }
     handleMissingEntity(VocabKey.str());
   }
-  Vocab.insert(Vocab.end(), NumericPredEmbeddings.begin(),
-               NumericPredEmbeddings.end());
-}
 
-IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab)
-    : Vocab(Vocab) {}
+  // Create section-based storage instead of flat vocabulary
+  // Order must match Vocabulary::Section enum
+  std::vector<std::vector<Embedding>> Sections(4);
+  Sections[static_cast<unsigned>(Vocabulary::Section::Opcodes)] =
+      std::move(NumericOpcodeEmbeddings); // Section::Opcodes
+  Sections[static_cast<unsigned>(Vocabulary::Section::CanonicalTypes)] =
+      std::move(NumericTypeEmbeddings); // Section::CanonicalTypes
+  Sections[static_cast<unsigned>(Vocabulary::Section::Operands)] =
+      std::move(NumericArgEmbeddings); // Section::Operands
+  Sections[static_cast<unsigned>(Vocabulary::Section::Predicates)] =
+      std::move(NumericPredEmbeddings); // Section::Predicates
 
-IR2VecVocabAnalysis::IR2VecVocabAnalysis(VocabVector &&Vocab)
-    : Vocab(std::move(Vocab)) {}
+  // Create VocabStorage from organized sections
+  Vocab.emplace(std::move(Sections));
+}
 
 void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
   handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
@@ -586,8 +702,14 @@ IR2VecVocabAnalysis::Result
 IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
   auto Ctx = &M.getContext();
   // If vocabulary is already populated by the constructor, use it.
-  if (!Vocab.empty())
-    return Vocabulary(std::move(Vocab));
+  // Hand the storage over exactly once and disengage the optional, so a later
+  // run() (e.g. after the result is invalidated) falls back to reading the
+  // vocabulary file instead of returning a moved-from storage.
+  if (Vocab.has_value()) {
+    ir2vec::VocabStorage Storage = std::move(*Vocab);
+    Vocab.reset();
+    return Vocabulary(std::move(Storage));
+  }
 
   // Otherwise, try to read from the vocabulary file.
   if (VocabFile.empty()) {
@@ -596,7 +712,9 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
                    "set it using --ir2vec-vocab-path");
     return Vocabulary(); // Return invalid result
   }
-  if (auto Err = readVocabulary()) {
+
+  VocabMap OpcVocab, TypeVocab, ArgVocab;
+  if (auto Err = readVocabulary(OpcVocab, TypeVocab, ArgVocab)) {
     emitError(std::move(Err), *Ctx);
     return Vocabulary();
   }
@@ -611,9 +729,13 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
   scaleVocabSection(ArgVocab, ArgWeight);
 
   // Generate the numeric lookup vocabulary
-  generateNumMappedVocab();
+  generateVocabStorage(OpcVocab, TypeVocab, ArgVocab);
 
-  return Vocabulary(std::move(Vocab));
+  // Hand the generated storage over and disengage the optional, so the next
+  // run() regenerates it instead of reusing a moved-from storage.
+  ir2vec::VocabStorage Storage = std::move(*Vocab);
+  Vocab.reset();
+  return Vocabulary(std::move(Storage));
 }
 
 //===----------------------------------------------------------------------===//
@@ -622,7 +740,7 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
 
 PreservedAnalyses IR2VecPrinterPass::run(Module &M,
                                          ModuleAnalysisManager &MAM) {
-  auto Vocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
+  auto &Vocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
   assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid");
 
   for (Function &F : M) {
@@ -664,7 +782,7 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
 
 PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,
                                               ModuleAnalysisManager &MAM) {
-  auto IR2VecVocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
+  auto &IR2VecVocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
   assert(IR2VecVocabulary.isValid() && "IR2Vec Vocabulary is invalid");
 
   // Print each entry
diff --git a/llvm/lib/Analysis/InlineAdvisor.cpp b/llvm/lib/Analysis/InlineAdvisor.cpp
index 28b14c2562df1..0fa804f2959e8 100644
--- a/llvm/lib/Analysis/InlineAdvisor.cpp
+++ b/llvm/lib/Analysis/InlineAdvisor.cpp
@@ -217,7 +217,7 @@ AnalysisKey PluginInlineAdvisorAnalysis::Key;
 bool InlineAdvisorAnalysis::initializeIR2VecVocabIfRequested(
     Module &M, ModuleAnalysisManager &MAM) {
   if (!IR2VecVocabFile.empty()) {
-    auto IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
+    auto &IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
     if (!IR2VecVocabResult.isValid()) {
       M.getContext().emitError("Failed to load IR2Vec vocabulary");
       return false;
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 1c656b8fcf4e7..434449c7c5117 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -162,8 +162,8 @@ class IR2VecTool {
 
     for (const BasicBlock &BB : F) {
       for (const auto &I : BB.instructionsWithoutDebug()) {
-        unsigned Opcode = Vocabulary::getSlotIndex(I.getOpcode());
-        unsigned TypeID = Vocabulary::getSlotIndex(I.getType()->getTypeID());
+        unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
+        unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
 
         // Add "Next" relationship with previous instruction
         if (HasPrevOpcode) {
@@ -184,7 +184,7 @@ class IR2VecTool {
         // Add "Arg" relationships
         unsigned ArgIndex = 0;
         for (const Use &U : I.operands()) {
-          unsigned OperandID = Vocabulary::getSlotIndex(*U.get());
+          unsigned OperandID = Vocabulary::getIndex(*U.get());
           unsigned RelationID = ArgRelation + ArgIndex;
           OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
 
diff --git a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
index dc6059dcf6827..442f703f08d0c 100644
--- a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
+++ b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
@@ -43,8 +43,13 @@ class FunctionPropertiesAnalysisTest : public testing::Test {
 public:
   FunctionPropertiesAnalysisTest() {
     auto VocabVector = ir2vec::Vocabulary::createDummyVocabForTest(1);
-    MAM.registerPass([&] { return IR2VecVocabAnalysis(VocabVector); });
-    IR2VecVocab = ir2vec::Vocabulary(std::move(VocabVector));
+    // VocabStorage is move-only: the analysis takes ownership of this storage
+    // and the fixture keeps its own, separately created, dummy vocabulary.
+    MAM.registerPass([VocabVector = std::move(VocabVector)]() mutable {
+      return IR2VecVocabAnalysis(std::move(VocabVector));
+    });
+    IR2VecVocab =
+        ir2vec::Vocabulary(ir2vec::Vocabulary::createDummyVocabForTest(1));
     MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
     FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });
     FAM.registerPass([&] { return DominatorTreeAnalysis(); });
@@ -78,7 +80,7 @@ class FunctionPropertiesAnalysisTest : public testing::Test {
   FunctionPropertiesInfo buildFPI(Function &F) {
     // FunctionPropertiesInfo assumes IR2VecVocabAnalysis has been run to
     // use IR2Vec.
-    auto VocabResult = MAM.getResult<IR2VecVocabAnalysis>(*F.getParent());
+    auto &VocabResult = MAM.getResult<IR2VecVocabAnalysis>(*F.getParent());
     (void)VocabResult;
     return FunctionPropertiesInfo::getFunctionPropertiesInfo(F, FAM);
   }
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 9bc48e45eab5e..d915920eccda0 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -464,7 +464,10 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
     EXPECT_EQ(VocabVecSize, MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands +
                                 MaxPredicateKinds);
 
-    auto ExpectedVocab = VocabVec;
+    // Collect embeddings for later comparison before moving VocabVec
+    std::vector<Embedding> ExpectedVocab;
+    for (const auto &Emb : VocabVec)
+      ExpectedVocab.push_back(Emb);
 
     IR2VecVocabAnalysis VocabAnalysis(std::move(VocabVec));
     LLVMContext TestCtx;
@@ -482,17 +485,17 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
 }
 
 TEST(IR2VecVocabularyTest, SlotIdxMapping) {
-  // Test getSlotIndex for Opcodes
+  // Test getIndex for Opcodes
 #define EXPECT_OPCODE_SLOT(NUM, OPCODE, CLASS)                                 
\
-  EXPECT_EQ(Vocabulary::getSlotIndex(NUM), static_cast<unsigned>(NUM - 1));
+  EXPECT_EQ(Vocabulary::getIndex(NUM), static_cast<unsigned>(NUM - 1));
 #define HANDLE_INST(NUM, OPCODE, CLASS) EXPECT_OPCODE_SLOT(NUM, OPCODE, CLASS)
 #include "llvm/IR/Instruction.def"
 #undef HANDLE_INST
 #undef EXPECT_OPCODE_SLOT
 
-  // Test getSlotIndex for Types
+  // Test getIndex for Types
 #define EXPECT_TYPE_SLOT(TypeIDTok, CanonEnum, CanonStr)                       
\
-  EXPECT_EQ(Vocabulary::getSlotIndex(Type::TypeIDTok),                         
\
+  EXPECT_EQ(Vocabulary::getIndex(Type::TypeIDTok),                             
\
             MaxOpcodes + static_cast<unsigned>(                                
\
                              Vocabulary::CanonicalTypeID::CanonEnum));
 
@@ -500,7 +503,7 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) {
 
 #undef EXPECT_TYPE_SLOT
 
-  // Test getSlotIndex for Value operands
+  // Test getIndex for Value operands
   LLVMContext Ctx;
   Module M("TestM", Ctx);
   FunctionType *FTy =
@@ -510,27 +513,27 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) {
 #define EXPECTED_VOCAB_OPERAND_SLOT(X)                                         
\
   MaxOpcodes + MaxCanonicalTypeIDs + static_cast<unsigned>(X)
   // Test Function operand
-  EXPECT_EQ(Vocabulary::getSlotIndex(*F),
+  EXPECT_EQ(Vocabulary::getIndex(*F),
             EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::FunctionID));
 
   // Test Constant operand
   Constant *C = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
-  EXPECT_EQ(Vocabulary::getSlotIndex(*C),
+  EXPECT_EQ(Vocabulary::getIndex(*C),
             EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::ConstantID));
 
   // Test Pointer operand
   BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
   AllocaInst *PtrVal = new AllocaInst(Type::getInt32Ty(Ctx), 0, "ptr", BB);
-  EXPECT_EQ(Vocabulary::getSlotIndex(*PtrVal),
+  EXPECT_EQ(Vocabulary::getIndex(*PtrVal),
             EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::PointerID));
 
   // Test Variable operand (function argument)
   Argument *Arg = F->getArg(0);
-  EXPECT_EQ(Vocabulary::getSlotIndex(*Arg),
+  EXPECT_EQ(Vocabulary::getIndex(*Arg),
             EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::VariableID));
 #undef EXPECTED_VOCAB_OPERAND_SLOT
 
-  // Test getSlotIndex for predicates
+  // Test getIndex for predicates
 #define EXPECTED_VOCAB_PREDICATE_SLOT(X)                                       
\
   MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands + static_cast<unsigned>(X)
   for (unsigned P = CmpInst::FIRST_FCMP_PREDICATE;
@@ -538,7 +541,7 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) {
     CmpInst::Predicate Pred = static_cast<CmpInst::Predicate>(P);
     unsigned ExpectedIdx =
         EXPECTED_VOCAB_PREDICATE_SLOT((P - CmpInst::FIRST_FCMP_PREDICATE));
-    EXPECT_EQ(Vocabulary::getSlotIndex(Pred), ExpectedIdx);
+    EXPECT_EQ(Vocabulary::getIndex(Pred), ExpectedIdx);
   }
   auto ICMP_Start = CmpInst::LAST_FCMP_PREDICATE + 1;
   for (unsigned P = CmpInst::FIRST_ICMP_PREDICATE;
@@ -546,7 +549,7 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) {
     CmpInst::Predicate Pred = static_cast<CmpInst::Predicate>(P);
     unsigned ExpectedIdx = EXPECTED_VOCAB_PREDICATE_SLOT(
         ICMP_Start + P - CmpInst::FIRST_ICMP_PREDICATE);
-    EXPECT_EQ(Vocabulary::getSlotIndex(Pred), ExpectedIdx);
+    EXPECT_EQ(Vocabulary::getIndex(Pred), ExpectedIdx);
   }
 #undef EXPECTED_VOCAB_PREDICATE_SLOT
 }
@@ -555,15 +558,14 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) {
 #ifndef NDEBUG
 TEST(IR2VecVocabularyTest, NumericIDMapInvalidInputs) {
   // Test invalid opcode IDs
-  EXPECT_DEATH(Vocabulary::getSlotIndex(0u), "Invalid opcode");
-  EXPECT_DEATH(Vocabulary::getSlotIndex(MaxOpcodes + 1), "Invalid opcode");
+  EXPECT_DEATH(Vocabulary::getIndex(0u), "Invalid opcode");
+  EXPECT_DEATH(Vocabulary::getIndex(MaxOpcodes + 1), "Invalid opcode");
 
   // Test invalid type IDs
-  EXPECT_DEATH(Vocabulary::getSlotIndex(static_cast<Type::TypeID>(MaxTypeIDs)),
+  EXPECT_DEATH(Vocabulary::getIndex(static_cast<Type::TypeID>(MaxTypeIDs)),
+               "Invalid type ID");
+  EXPECT_DEATH(Vocabulary::getIndex(static_cast<Type::TypeID>(MaxTypeIDs + 10)),
                "Invalid type ID");
-  EXPECT_DEATH(
-      Vocabulary::getSlotIndex(static_cast<Type::TypeID>(MaxTypeIDs + 10)),
-      "Invalid type ID");
 }
 #endif // NDEBUG
 #endif // GTEST_HAS_DEATH_TEST
@@ -573,7 +575,7 @@ TEST(IR2VecVocabularyTest, StringKeyGeneration) {
   EXPECT_EQ(Vocabulary::getStringKey(12), "Add");
 
 #define EXPECT_OPCODE(NUM, OPCODE, CLASS)                                      
\
-  EXPECT_EQ(Vocabulary::getStringKey(Vocabulary::getSlotIndex(NUM)),           
\
+  EXPECT_EQ(Vocabulary::getStringKey(Vocabulary::getIndex(NUM)),               
\
             Vocabulary::getVocabKeyForOpcode(NUM));
 #define HANDLE_INST(NUM, OPCODE, CLASS) EXPECT_OPCODE(NUM, OPCODE, CLASS)
 #include "llvm/IR/Instruction.def"
@@ -672,10 +674,12 @@ TEST(IR2VecVocabularyTest, InvalidAccess) {
 #endif // GTEST_HAS_DEATH_TEST
 
 TEST(IR2VecVocabularyTest, TypeIDStringKeyMapping) {
+  Vocabulary V = Vocabulary(Vocabulary::createDummyVocabForTest());
 #define EXPECT_TYPE_TO_CANONICAL(TypeIDTok, CanonEnum, CanonStr)               
\
-  EXPECT_EQ(                                                                   
\
-      Vocabulary::getStringKey(Vocabulary::getSlotIndex(Type::TypeIDTok)),     
\
-      CanonStr);
+  do {                                                                         
\
+    unsigned FlatIdx = V.getIndex(Type::TypeIDTok);                            
\
+    EXPECT_EQ(Vocabulary::getStringKey(FlatIdx), CanonStr);                    
\
+  } while (0);
 
   IR2VEC_HANDLE_TYPE_BIMAP(EXPECT_TYPE_TO_CANONICAL)
 
@@ -683,14 +687,20 @@ TEST(IR2VecVocabularyTest, TypeIDStringKeyMapping) {
 }
 
 TEST(IR2VecVocabularyTest, InvalidVocabularyConstruction) {
-  std::vector<Embedding> InvalidVocab;
-  InvalidVocab.push_back(Embedding(2, 1.0));
-  InvalidVocab.push_back(Embedding(2, 2.0));
-
-  Vocabulary V(std::move(InvalidVocab));
+  // Test 1: Create invalid VocabStorage with insufficient sections
+  std::vector<std::vector<Embedding>> InvalidSectionData;
+  // Only add one section with 2 embeddings, but the vocabulary needs 4 
sections
+  std::vector<Embedding> Section1;
+  Section1.push_back(Embedding(2, 1.0));
+  Section1.push_back(Embedding(2, 2.0));
+  InvalidSectionData.push_back(std::move(Section1));
+
+  VocabStorage InvalidStorage(std::move(InvalidSectionData));
+  Vocabulary V(std::move(InvalidStorage));
   EXPECT_FALSE(V.isValid());
 
   {
+    // Test 2: Default-constructed vocabulary should be invalid
     Vocabulary InvalidResult;
     EXPECT_FALSE(InvalidResult.isValid());
 #if GTEST_HAS_DEATH_TEST
@@ -701,4 +711,239 @@ TEST(IR2VecVocabularyTest, InvalidVocabularyConstruction) {
   }
 }
 
+TEST(VocabStorageTest, DefaultConstructor) {
+  VocabStorage Storage;
+  // Default construction yields an empty storage in the invalid state.
+  EXPECT_EQ(Storage.size(), 0u);
+  EXPECT_EQ(Storage.getNumSections(), 0u);
+  EXPECT_EQ(Storage.getDimension(), 0u);
+  EXPECT_FALSE(Storage.isValid());
+
+  // Iterating an empty storage must produce an empty range.
+  EXPECT_EQ(Storage.begin(), Storage.end());
+}
+
+TEST(VocabStorageTest, BasicConstruction) {
+  // Build three sections of 3-D embeddings holding 2, 1 and 3 entries.
+  std::vector<std::vector<Embedding>> SectionData;
+
+  // Section 0: 2 embeddings of dimension 3
+  std::vector<Embedding> Section0;
+  Section0.emplace_back(std::vector<double>{1.0, 2.0, 3.0});
+  Section0.emplace_back(std::vector<double>{4.0, 5.0, 6.0});
+  SectionData.push_back(std::move(Section0));
+
+  // Section 1: 1 embedding of dimension 3
+  std::vector<Embedding> Section1;
+  Section1.emplace_back(std::vector<double>{7.0, 8.0, 9.0});
+  SectionData.push_back(std::move(Section1));
+
+  // Section 2: 3 embeddings of dimension 3
+  std::vector<Embedding> Section2;
+  Section2.emplace_back(std::vector<double>{10.0, 11.0, 12.0});
+  Section2.emplace_back(std::vector<double>{13.0, 14.0, 15.0});
+  Section2.emplace_back(std::vector<double>{16.0, 17.0, 18.0});
+  SectionData.push_back(std::move(Section2));
+
+  VocabStorage Storage(std::move(SectionData));
+  // size() is the total embedding count summed over every section.
+  EXPECT_EQ(Storage.size(), 6u); // Total: 2 + 1 + 3 = 6
+  EXPECT_EQ(Storage.getNumSections(), 3u);
+  EXPECT_EQ(Storage.getDimension(), 3u);
+  EXPECT_TRUE(Storage.isValid());
+}
+
+TEST(VocabStorageTest, SectionAccess) {
+  // Two sections of 2-D embeddings: section 0 has two, section 1 has one.
+  std::vector<std::vector<Embedding>> SectionData;
+
+  std::vector<Embedding> Section0;
+  Section0.emplace_back(std::vector<double>{1.0, 2.0});
+  Section0.emplace_back(std::vector<double>{3.0, 4.0});
+  SectionData.push_back(std::move(Section0));
+
+  std::vector<Embedding> Section1;
+  Section1.emplace_back(std::vector<double>{5.0, 6.0});
+  SectionData.push_back(std::move(Section1));
+
+  VocabStorage Storage(std::move(SectionData));
+
+  // operator[] yields a whole section, keeping its original size.
+  EXPECT_EQ(Storage[0].size(), 2u);
+  EXPECT_EQ(Storage[1].size(), 1u);
+
+  // Embeddings keep their insertion order and values within each section.
+  EXPECT_THAT(Storage[0][0].getData(), ElementsAre(1.0, 2.0));
+  EXPECT_THAT(Storage[0][1].getData(), ElementsAre(3.0, 4.0));
+  EXPECT_THAT(Storage[1][0].getData(), ElementsAre(5.0, 6.0));
+}
+
+#if GTEST_HAS_DEATH_TEST
+#ifndef NDEBUG
+TEST(VocabStorageTest, InvalidSectionAccess) {
+  std::vector<std::vector<Embedding>> SectionData;
+  std::vector<Embedding> Section0;
+  Section0.emplace_back(std::vector<double>{1.0, 2.0});
+  SectionData.push_back(std::move(Section0));
+
+  VocabStorage Storage(std::move(SectionData));
+  // Only section 0 exists, so any larger section index must abort.
+  EXPECT_DEATH(Storage[1], "Invalid section ID");
+  EXPECT_DEATH(Storage[10], "Invalid section ID");
+}
+
+TEST(VocabStorageTest, EmptySection) {
+  std::vector<std::vector<Embedding>> SectionData;
+  std::vector<Embedding> Section0; // Empty section
+  SectionData.push_back(std::move(Section0));
+
+  std::vector<Embedding> Section1;
+  Section1.emplace_back(std::vector<double>{1.0});
+  SectionData.push_back(std::move(Section1));
+  // Construction must abort when section 0 is empty, even if later ones aren't.
+  EXPECT_DEATH(VocabStorage(std::move(SectionData)),
+               "First section of vocabulary is empty");
+}
+
+TEST(VocabStorageTest, NoSections) {
+  std::vector<std::vector<Embedding>> SectionData; // No sections
+  // Constructing from an empty list of sections must abort.
+  EXPECT_DEATH(VocabStorage(std::move(SectionData)),
+               "Vocabulary has no sections");
+}
+#endif // NDEBUG
+#endif // GTEST_HAS_DEATH_TEST
+
+TEST(VocabStorageTest, MoveAssignment) {
+  // Source: one section holding a single 2-D embedding.
+  std::vector<std::vector<Embedding>> SrcData;
+  std::vector<Embedding> SrcSection;
+  SrcSection.emplace_back(std::vector<double>{1.0, 2.0});
+  SrcData.push_back(std::move(SrcSection));
+  VocabStorage Source(std::move(SrcData));
+
+  // Destination: one section holding a single 3-D embedding.
+  std::vector<std::vector<Embedding>> DestData;
+  std::vector<Embedding> DestSection;
+  DestSection.emplace_back(std::vector<double>{5.0, 6.0, 7.0});
+  DestData.push_back(std::move(DestSection));
+  VocabStorage Dest(std::move(DestData));
+
+  EXPECT_EQ(Dest.getDimension(), 3u); // Initially 3D
+
+  // Move-assign must replace Dest's contents with Source's.
+  Dest = std::move(Source);
+
+  // Dest now reports Source's size, dimension, validity and values.
+  EXPECT_EQ(Dest.size(), 1u);
+  EXPECT_EQ(Dest.getDimension(), 2u); // Now 2D from source
+  EXPECT_TRUE(Dest.isValid());
+  EXPECT_THAT(Dest[0][0].getData(), ElementsAre(1.0, 2.0));
+}
+
+TEST(VocabStorageTest, IteratorBasics) {
+  std::vector<std::vector<Embedding>> SectionData;
+  // Section 0 holds two 2-D embeddings; section 1 holds one.
+  std::vector<Embedding> Section0;
+  Section0.emplace_back(std::vector<double>{1.0, 2.0});
+  Section0.emplace_back(std::vector<double>{3.0, 4.0});
+  SectionData.push_back(std::move(Section0));
+
+  std::vector<Embedding> Section1;
+  Section1.emplace_back(std::vector<double>{5.0, 6.0});
+  SectionData.push_back(std::move(Section1));
+
+  VocabStorage Storage(std::move(SectionData));
+
+  // Iteration visits every embedding in order, crossing section borders.
+  auto It = Storage.begin();
+  auto End = Storage.end();
+
+  EXPECT_NE(It, End);
+
+  // Check first embedding
+  EXPECT_THAT((*It).getData(), ElementsAre(1.0, 2.0));
+
+  // Advance to second embedding
+  ++It;
+  EXPECT_NE(It, End);
+  EXPECT_THAT((*It).getData(), ElementsAre(3.0, 4.0));
+
+  // Advance to third embedding (in section 1)
+  ++It;
+  EXPECT_NE(It, End);
+  EXPECT_THAT((*It).getData(), ElementsAre(5.0, 6.0));
+
+  // Stepping past the last embedding must reach end().
+  ++It;
+  EXPECT_EQ(It, End);
+}
+
+TEST(VocabStorageTest, IteratorTraversal) {
+  std::vector<std::vector<Embedding>> SectionData;
+
+  // Section 0: 2 embeddings
+  std::vector<Embedding> Section0;
+  Section0.emplace_back(std::vector<double>{10.0});
+  Section0.emplace_back(std::vector<double>{20.0});
+  SectionData.push_back(std::move(Section0));
+
+  // Section 1: empty section (to test section skipping)
+  std::vector<Embedding> Section1; // Empty
+  SectionData.push_back(std::move(Section1));
+
+  // Section 2: 3 embeddings
+  std::vector<Embedding> Section2;
+  Section2.emplace_back(std::vector<double>{30.0});
+  Section2.emplace_back(std::vector<double>{40.0});
+  Section2.emplace_back(std::vector<double>{50.0});
+  SectionData.push_back(std::move(Section2));
+
+  VocabStorage Storage(std::move(SectionData));
+
+  // Range-for over the storage; every embedding here is 1-dimensional.
+  std::vector<double> Values;
+  for (const auto &Emb : Storage) {
+    EXPECT_EQ(Emb.size(), 1u);
+    Values.push_back(Emb[0]);
+  }
+
+  // Section 1 is empty, so iteration must go straight from 20.0 to 30.0.
+  EXPECT_THAT(Values, ElementsAre(10.0, 20.0, 30.0, 40.0, 50.0));
+}
+
+TEST(VocabStorageTest, IteratorComparison) {
+  std::vector<std::vector<Embedding>> SectionData;
+  std::vector<Embedding> Section0;
+  Section0.emplace_back(std::vector<double>{1.0});
+  Section0.emplace_back(std::vector<double>{2.0});
+  SectionData.push_back(std::move(Section0));
+
+  VocabStorage Storage(std::move(SectionData));
+
+  auto It1 = Storage.begin();
+  auto It2 = Storage.begin();
+  auto End = Storage.end();
+
+  // Iterators from separate begin() calls compare equal; neither is end().
+  EXPECT_EQ(It1, It2);
+  EXPECT_NE(It1, End);
+
+  // Advance one iterator
+  ++It1;
+  EXPECT_NE(It1, It2);
+  EXPECT_NE(It1, End);
+
+  // Advance second iterator to match
+  ++It2;
+  EXPECT_EQ(It1, It2);
+
+  // Advancing both past the last embedding must make them equal to end().
+  ++It1;
+  ++It2;
+  EXPECT_EQ(It1, End);
+  EXPECT_EQ(It2, End);
+  EXPECT_EQ(It1, It2);
+}
+
 } // end anonymous namespace

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to