https://github.com/svkeerthy updated 
https://github.com/llvm/llvm-project/pull/143200

From 7f2012cd56db0fc6e1c430a8d5b38d360b33145f Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeer...@google.com>
Date: Fri, 6 Jun 2025 20:32:32 +0000
Subject: [PATCH] [IR2Vec] Expose embedding weights, accept a pre-populated vocabulary, and report vocabulary errors

---
 llvm/include/llvm/Analysis/IR2Vec.h    |  10 ++
 llvm/lib/Analysis/IR2Vec.cpp           |  82 +++++++++------
 llvm/unittests/Analysis/IR2VecTest.cpp | 137 ++++++++++++++++++-------
 3 files changed, 163 insertions(+), 66 deletions(-)

diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 14f28999b174c..3d32942670785 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -31,7 +31,9 @@
 
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/ErrorOr.h"
+#include "llvm/Support/JSON.h"
 #include <map>
 
 namespace llvm {
@@ -43,6 +45,7 @@ class Function;
 class Type;
 class Value;
 class raw_ostream;
+class LLVMContext;
 
 /// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
 /// Symbolic embeddings capture the "syntactic" and "statistical correlation"
@@ -53,6 +56,11 @@ class raw_ostream;
 enum class IR2VecKind { Symbolic };
 
 namespace ir2vec {
+// Embedding weights; command-line options -ir2vec-{opc,type,arg}-weight.
+LLVM_ABI extern cl::opt<float> OpcWeight;
+LLVM_ABI extern cl::opt<float> TypeWeight;
+LLVM_ABI extern cl::opt<float> ArgWeight;
+
 /// Embedding is a ADT that wraps std::vector<double>. It provides
 /// additional functionality for arithmetic and comparison operations.
 /// It is meant to be used *like* std::vector<double> but is more restrictive
@@ -224,10 +232,12 @@ class IR2VecVocabResult {
 class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
   ir2vec::Vocab Vocabulary;
   Error readVocabulary();
+  void emitError(Error Err, LLVMContext &Ctx);
 
 public:
   static AnalysisKey Key;
   IR2VecVocabAnalysis() = default;
+  explicit IR2VecVocabAnalysis(ir2vec::Vocab &&Vocab);
   using Result = IR2VecVocabResult;
   Result run(Module &M, ModuleAnalysisManager &AM);
 };
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 25ce35d4ace37..2ad65c2f40c33 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -16,13 +16,11 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
-#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/Errc.h"
 #include "llvm/Support/Error.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/Format.h"
-#include "llvm/Support/JSON.h"
 #include "llvm/Support/MemoryBuffer.h"
 
 using namespace llvm;
@@ -33,6 +31,8 @@ using namespace ir2vec;
 STATISTIC(VocabMissCounter,
           "Number of lookups to entites not present in the vocabulary");
 
+namespace llvm {
+namespace ir2vec {
 static cl::OptionCategory IR2VecCategory("IR2Vec Options");
 
 // FIXME: Use a default vocab when not specified
@@ -40,18 +40,20 @@ static cl::opt<std::string>
     VocabFile("ir2vec-vocab-path", cl::Optional,
               cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
               cl::cat(IR2VecCategory));
-static cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional,
-                                cl::init(1.0),
-                                cl::desc("Weight for opcode embeddings"),
-                                cl::cat(IR2VecCategory));
-static cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional,
-                                 cl::init(0.5),
-                                 cl::desc("Weight for type embeddings"),
-                                 cl::cat(IR2VecCategory));
-static cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
-                                cl::init(0.2),
-                                cl::desc("Weight for argument embeddings"),
-                                cl::cat(IR2VecCategory));
+LLVM_ABI cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional,
+                                  cl::init(1.0),
+                                  cl::desc("Weight for opcode embeddings"),
+                                  cl::cat(IR2VecCategory));
+LLVM_ABI cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional,
+                                   cl::init(0.5),
+                                   cl::desc("Weight for type embeddings"),
+                                   cl::cat(IR2VecCategory));
+LLVM_ABI cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
+                                  cl::init(0.2),
+                                  cl::desc("Weight for argument embeddings"),
+                                  cl::cat(IR2VecCategory));
+} // namespace ir2vec
+} // namespace llvm
 
 AnalysisKey IR2VecVocabAnalysis::Key;
 
@@ -251,9 +253,9 @@ bool IR2VecVocabResult::invalidate(
 // by auto-generating a default vocabulary during the build time.
 Error IR2VecVocabAnalysis::readVocabulary() {
   auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
-  if (!BufOrError) {
+  if (!BufOrError)
     return createFileError(VocabFile, BufOrError.getError());
-  }
+
   auto Content = BufOrError.get()->getBuffer();
   json::Path::Root Path("");
   Expected<json::Value> ParsedVocabValue = json::parse(Content);
@@ -261,39 +263,57 @@ Error IR2VecVocabAnalysis::readVocabulary() {
     return ParsedVocabValue.takeError();
 
   bool Res = json::fromJSON(*ParsedVocabValue, Vocabulary, Path);
-  if (!Res) {
+  if (!Res)
     return createStringError(errc::illegal_byte_sequence,
                              "Unable to parse the vocabulary");
-  }
-  assert(Vocabulary.size() > 0 && "Vocabulary is empty");
+
+  if (Vocabulary.empty())
+    return createStringError(errc::illegal_byte_sequence,
+                             "Vocabulary is empty");
 
   unsigned Dim = Vocabulary.begin()->second.size();
-  assert(Dim > 0 && "Dimension of vocabulary is zero");
-  (void)Dim;
-  assert(std::all_of(Vocabulary.begin(), Vocabulary.end(),
-                     [Dim](const std::pair<StringRef, Embedding> &Entry) {
-                       return Entry.second.size() == Dim;
-                     }) &&
-         "All vectors in the vocabulary are not of the same dimension");
+  if (Dim == 0)
+    return createStringError(errc::illegal_byte_sequence,
+                             "Dimension of vocabulary is zero");
+
+  if (!std::all_of(Vocabulary.begin(), Vocabulary.end(),
+                   [Dim](const auto &Entry) {
+                     return Entry.second.size() == Dim;
+                   }))
+    return createStringError(
+        errc::illegal_byte_sequence,
+        "All vectors in the vocabulary are not of the same dimension");
+
   return Error::success();
 }
 
+IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&InitialVocab)
+    : Vocabulary(std::move(InitialVocab)) {}
+
+void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
+  handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
+    Ctx.emitError("Error reading vocabulary: " + EI.message());
+  });
+}
+
 IR2VecVocabAnalysis::Result
 IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
   auto Ctx = &M.getContext();
+  // FIXME: Scale the vocabulary once. This would avoid scaling per use later.
+  // A constructor-provided vocabulary is copied rather than moved so that it
+  // is still available if the analysis is re-run after invalidation.
+  if (!Vocabulary.empty())
+    return IR2VecVocabResult(ir2vec::Vocab(Vocabulary));
+  // Otherwise, try to read from the vocabulary file.
   if (VocabFile.empty()) {
     // FIXME: Use default vocabulary
     Ctx->emitError("IR2Vec vocabulary file path not specified");
     return IR2VecVocabResult(); // Return invalid result
   }
   if (auto Err = readVocabulary()) {
-    handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
-      Ctx->emitError("Error reading vocabulary: " + EI.message());
-    });
+    emitError(std::move(Err), *Ctx);
     return IR2VecVocabResult();
   }
-  // FIXME: Scale the vocabulary here once. This would avoid scaling per use
-  // later.
   return IR2VecVocabResult(std::move(Vocabulary));
 }
 
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 46e9c71c58250..c2c65c92cfb07 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -261,25 +261,30 @@ TEST(IR2VecTest, IR2VecVocabResultValidity) {
   EXPECT_EQ(validResult.getDimension(), 2u);
 }
 
-// Helper to create a minimal function and embedder for getter tests
-struct GetterTestEnv {
-  Vocab V = {};
+// Fixture for IR2Vec tests requiring IR setup and weight management.
+class IR2VecTestFixture : public ::testing::Test {
+protected:
+  Vocab V;
   LLVMContext Ctx;
-  std::unique_ptr<Module> M = nullptr;
+  std::unique_ptr<Module> M;
   Function *F = nullptr;
   BasicBlock *BB = nullptr;
-  Instruction *Add = nullptr;
-  Instruction *Ret = nullptr;
-  std::unique_ptr<Embedder> Emb = nullptr;
+  Instruction *AddInst = nullptr;
+  Instruction *RetInst = nullptr;
 
-  GetterTestEnv() {
+  float OriginalOpcWeight = ::OpcWeight;
+  float OriginalTypeWeight = ::TypeWeight;
+  float OriginalArgWeight = ::ArgWeight;
+
+  void SetUp() override {
     V = {{"add", {1.0, 2.0}},
          {"integerTy", {0.5, 0.5}},
          {"constant", {0.2, 0.3}},
          {"variable", {0.0, 0.0}},
          {"unknownTy", {0.0, 0.0}}};
 
-    M = std::make_unique<Module>("M", Ctx);
+    // Setup IR
+    M = std::make_unique<Module>("TestM", Ctx);
     FunctionType *FTy = FunctionType::get(
         Type::getInt32Ty(Ctx), {Type::getInt32Ty(Ctx), Type::getInt32Ty(Ctx)},
         false);
@@ -288,61 +293,82 @@ struct GetterTestEnv {
     Argument *Arg = F->getArg(0);
     llvm::Value *Const = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
 
-    Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
-    Ret = ReturnInst::Create(Ctx, Add, BB);
+    AddInst = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
+    RetInst = ReturnInst::Create(Ctx, AddInst, BB);
+  }
+
+  void setWeights(float OpcWeight, float TypeWeight, float ArgWeight) {
+    ::OpcWeight = OpcWeight;
+    ::TypeWeight = TypeWeight;
+    ::ArgWeight = ArgWeight;
+  }
 
-    auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
-    EXPECT_TRUE(static_cast<bool>(Result));
-    Emb = std::move(*Result);
+  void TearDown() override {
+    // Restore original global weights
+    ::OpcWeight = OriginalOpcWeight;
+    ::TypeWeight = OriginalTypeWeight;
+    ::ArgWeight = OriginalArgWeight;
   }
 };
 
-TEST(IR2VecTest, GetInstVecMap) {
-  GetterTestEnv Env;
-  const auto &InstMap = Env.Emb->getInstVecMap();
+TEST_F(IR2VecTestFixture, GetInstVecMap) {
+  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Result));
+  auto Emb = std::move(*Result);
+
+  const auto &InstMap = Emb->getInstVecMap();
 
   EXPECT_EQ(InstMap.size(), 2u);
-  EXPECT_TRUE(InstMap.count(Env.Add));
-  EXPECT_TRUE(InstMap.count(Env.Ret));
+  ASSERT_TRUE(InstMap.count(AddInst));
+  ASSERT_TRUE(InstMap.count(RetInst));
 
-  EXPECT_EQ(InstMap.at(Env.Add).size(), 2u);
-  EXPECT_EQ(InstMap.at(Env.Ret).size(), 2u);
+  EXPECT_EQ(InstMap.at(AddInst).size(), 2u);
+  EXPECT_EQ(InstMap.at(RetInst).size(), 2u);
 
   // Check values for add: {1.29, 2.31}
-  EXPECT_THAT(InstMap.at(Env.Add),
+  EXPECT_THAT(InstMap.at(AddInst),
               ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
 
   // Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in
   // vocab
-  EXPECT_THAT(InstMap.at(Env.Ret), ElementsAre(0.0, 0.0));
+  EXPECT_THAT(InstMap.at(RetInst), ElementsAre(0.0, 0.0));
 }
 
-TEST(IR2VecTest, GetBBVecMap) {
-  GetterTestEnv Env;
-  const auto &BBMap = Env.Emb->getBBVecMap();
+TEST_F(IR2VecTestFixture, GetBBVecMap) {
+  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Result));
+  auto Emb = std::move(*Result);
+
+  const auto &BBMap = Emb->getBBVecMap();
 
   EXPECT_EQ(BBMap.size(), 1u);
-  EXPECT_TRUE(BBMap.count(Env.BB));
-  EXPECT_EQ(BBMap.at(Env.BB).size(), 2u);
+  ASSERT_TRUE(BBMap.count(BB));
+  EXPECT_EQ(BBMap.at(BB).size(), 2u);
 
   // BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} =
   // {1.29, 2.31}
-  EXPECT_THAT(BBMap.at(Env.BB),
+  EXPECT_THAT(BBMap.at(BB),
               ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
 }
 
-TEST(IR2VecTest, GetBBVector) {
-  GetterTestEnv Env;
-  const auto &BBVec = Env.Emb->getBBVector(*Env.BB);
+TEST_F(IR2VecTestFixture, GetBBVector) {
+  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Result));
+  auto Emb = std::move(*Result);
+
+  const auto &BBVec = Emb->getBBVector(*BB);
 
   EXPECT_EQ(BBVec.size(), 2u);
   EXPECT_THAT(BBVec,
               ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
 }
 
-TEST(IR2VecTest, GetFunctionVector) {
-  GetterTestEnv Env;
-  const auto &FuncVec = Env.Emb->getFunctionVector();
+TEST_F(IR2VecTestFixture, GetFunctionVector) {
+  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Result));
+  auto Emb = std::move(*Result);
+
+  const auto &FuncVec = Emb->getFunctionVector();
 
   EXPECT_EQ(FuncVec.size(), 2u);
 
@@ -351,4 +377,45 @@ TEST(IR2VecTest, GetFunctionVector) {
               ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
 }
 
+TEST_F(IR2VecTestFixture, GetFunctionVectorWithCustomWeights) {
+  setWeights(1.0, 1.0, 1.0);
+
+  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Result));
+  auto Emb = std::move(*Result);
+
+  const auto &FuncVec = Emb->getFunctionVector();
+
+  EXPECT_EQ(FuncVec.size(), 2u);
+
+  // Expected: 1*([1.0 2.0] + [0.0 0.0]) + 1*([0.5 0.5] + [0.0 0.0]) +
+  //           1*([0.2 0.3] + [0.0 0.0]) = [1.7 2.8]
+  EXPECT_THAT(FuncVec,
+              ElementsAre(DoubleNear(1.7, 1e-6), DoubleNear(2.8, 1e-6)));
+}
+
+TEST(IR2VecTest, IR2VecVocabAnalysisWithPrepopulatedVocab) {
+  Vocab InitialVocab = {{"key1", {1.1, 2.2}}, {"key2", {3.3, 4.4}}};
+  Vocab ExpectedVocab = InitialVocab;
+  unsigned ExpectedDim = InitialVocab.begin()->second.size();
+
+  IR2VecVocabAnalysis VocabAnalysis(std::move(InitialVocab));
+
+  LLVMContext TestCtx;
+  Module TestMod("TestModuleForVocabAnalysis", TestCtx);
+  ModuleAnalysisManager MAM;
+  IR2VecVocabResult Result = VocabAnalysis.run(TestMod, MAM);
+
+  ASSERT_TRUE(Result.isValid());
+  ASSERT_FALSE(Result.getVocabulary().empty());
+  EXPECT_EQ(Result.getDimension(), ExpectedDim);
+
+  const auto &ResultVocab = Result.getVocabulary();
+  EXPECT_EQ(ResultVocab.size(), ExpectedVocab.size());
+  for (const auto &Entry : ExpectedVocab) {
+    ASSERT_TRUE(ResultVocab.count(Entry.first));
+    EXPECT_THAT(ResultVocab.at(Entry.first), ElementsAreArray(Entry.second));
+  }
+}
+
 } // end anonymous namespace

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to