https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/143200
>From 7f2012cd56db0fc6e1c430a8d5b38d360b33145f Mon Sep 17 00:00:00 2001 From: svkeerthy <venkatakeer...@google.com> Date: Fri, 6 Jun 2025 20:32:32 +0000 Subject: [PATCH] Vocab changes1 --- llvm/include/llvm/Analysis/IR2Vec.h | 10 ++ llvm/lib/Analysis/IR2Vec.cpp | 82 +++++++++------ llvm/unittests/Analysis/IR2VecTest.cpp | 137 ++++++++++++++++++------- 3 files changed, 163 insertions(+), 66 deletions(-) diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h index 14f28999b174c..3d32942670785 100644 --- a/llvm/include/llvm/Analysis/IR2Vec.h +++ b/llvm/include/llvm/Analysis/IR2Vec.h @@ -31,7 +31,9 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/IR/PassManager.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorOr.h" +#include "llvm/Support/JSON.h" #include <map> namespace llvm { @@ -43,6 +45,7 @@ class Function; class Type; class Value; class raw_ostream; +class LLVMContext; /// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware. /// Symbolic embeddings capture the "syntactic" and "statistical correlation" @@ -53,6 +56,11 @@ class raw_ostream; enum class IR2VecKind { Symbolic }; namespace ir2vec { + +LLVM_ABI extern cl::opt<float> OpcWeight; +LLVM_ABI extern cl::opt<float> TypeWeight; +LLVM_ABI extern cl::opt<float> ArgWeight; + /// Embedding is a ADT that wraps std::vector<double>. It provides /// additional functionality for arithmetic and comparison operations. /// It is meant to be used *like* std::vector<double> but is more restrictive @@ -224,10 +232,12 @@ class IR2VecVocabResult { class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> { ir2vec::Vocab Vocabulary; Error readVocabulary(); + void emitError(Error Err, LLVMContext &Ctx); public: static AnalysisKey Key; IR2VecVocabAnalysis() = default; + explicit IR2VecVocabAnalysis(ir2vec::Vocab &&Vocab); using Result = IR2VecVocabResult; Result run(Module &M, ModuleAnalysisManager &MAM); }; diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index 25ce35d4ace37..2ad65c2f40c33 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -16,13 +16,11 @@ #include "llvm/ADT/Statistic.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" -#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Errc.h" #include "llvm/Support/Error.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/Format.h" -#include "llvm/Support/JSON.h" #include "llvm/Support/MemoryBuffer.h" using namespace llvm; @@ -33,6 +31,8 @@ using namespace ir2vec; STATISTIC(VocabMissCounter, "Number of lookups to entites not present in the vocabulary"); +namespace llvm { +namespace ir2vec { static cl::OptionCategory IR2VecCategory("IR2Vec Options"); // FIXME: Use a default vocab when not specified @@ -40,18 +40,20 @@ static cl::opt<std::string> VocabFile("ir2vec-vocab-path", cl::Optional, cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""), cl::cat(IR2VecCategory)); -static cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional, - cl::init(1.0), - cl::desc("Weight for opcode embeddings"), - cl::cat(IR2VecCategory)); -static cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, - cl::init(0.5), - cl::desc("Weight for type embeddings"), - cl::cat(IR2VecCategory)); -static cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, - cl::init(0.2), - cl::desc("Weight for argument embeddings"), - cl::cat(IR2VecCategory)); +LLVM_ABI cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional, + cl::init(1.0), + cl::desc("Weight for opcode embeddings"), + cl::cat(IR2VecCategory)); +LLVM_ABI cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, + cl::init(0.5), + cl::desc("Weight for type embeddings"), + cl::cat(IR2VecCategory)); +LLVM_ABI cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, + cl::init(0.2), + cl::desc("Weight for argument embeddings"), + cl::cat(IR2VecCategory)); +} // namespace ir2vec +} // namespace llvm AnalysisKey IR2VecVocabAnalysis::Key; @@ -251,9 +253,9 @@ bool IR2VecVocabResult::invalidate( // by auto-generating a default vocabulary during the build time. Error IR2VecVocabAnalysis::readVocabulary() { auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true); - if (!BufOrError) { + if (!BufOrError) return createFileError(VocabFile, BufOrError.getError()); - } + auto Content = BufOrError.get()->getBuffer(); json::Path::Root Path(""); Expected<json::Value> ParsedVocabValue = json::parse(Content); @@ -261,39 +263,57 @@ Error IR2VecVocabAnalysis::readVocabulary() { return ParsedVocabValue.takeError(); bool Res = json::fromJSON(*ParsedVocabValue, Vocabulary, Path); - if (!Res) { + if (!Res) return createStringError(errc::illegal_byte_sequence, "Unable to parse the vocabulary"); - } - assert(Vocabulary.size() > 0 && "Vocabulary is empty"); + + if (Vocabulary.empty()) + return createStringError(errc::illegal_byte_sequence, + "Vocabulary is empty"); unsigned Dim = Vocabulary.begin()->second.size(); - assert(Dim > 0 && "Dimension of vocabulary is zero"); - (void)Dim; - assert(std::all_of(Vocabulary.begin(), Vocabulary.end(), - [Dim](const std::pair<StringRef, Embedding> &Entry) { - return Entry.second.size() == Dim; - }) && - "All vectors in the vocabulary are not of the same dimension"); + if (Dim == 0) + return createStringError(errc::illegal_byte_sequence, + "Dimension of vocabulary is zero"); + + if (!std::all_of(Vocabulary.begin(), Vocabulary.end(), + [Dim](const std::pair<StringRef, Embedding> &Entry) { + return Entry.second.size() == Dim; + })) + return createStringError( + errc::illegal_byte_sequence, + "All vectors in the vocabulary are not of the same dimension"); + return Error::success(); } +IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&Vocabulary) + : Vocabulary(std::move(Vocabulary)) {} + +void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) { + handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) { + Ctx.emitError("Error reading vocabulary: " + EI.message()); + }); +} + IR2VecVocabAnalysis::Result IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) { auto Ctx = &M.getContext(); + // FIXME: Scale the vocabulary once. This would avoid scaling per use later. + // If vocabulary is already populated by the constructor, use it. + if (!Vocabulary.empty()) + return IR2VecVocabResult(std::move(Vocabulary)); + + // Otherwise, try to read from the vocabulary file. if (VocabFile.empty()) { // FIXME: Use default vocabulary Ctx->emitError("IR2Vec vocabulary file path not specified"); return IR2VecVocabResult(); // Return invalid result } if (auto Err = readVocabulary()) { - handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) { - Ctx->emitError("Error reading vocabulary: " + EI.message()); - }); + emitError(std::move(Err), *Ctx); return IR2VecVocabResult(); } - // FIXME: Scale the vocabulary here once. This would avoid scaling per use - // later. return IR2VecVocabResult(std::move(Vocabulary)); } diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp index 46e9c71c58250..c2c65c92cfb07 100644 --- a/llvm/unittests/Analysis/IR2VecTest.cpp +++ b/llvm/unittests/Analysis/IR2VecTest.cpp @@ -261,25 +261,30 @@ TEST(IR2VecTest, IR2VecVocabResultValidity) { EXPECT_EQ(validResult.getDimension(), 2u); } -// Helper to create a minimal function and embedder for getter tests -struct GetterTestEnv { - Vocab V = {}; +// Fixture for IR2Vec tests requiring IR setup and weight management. +class IR2VecTestFixture : public ::testing::Test { +protected: + Vocab V; LLVMContext Ctx; - std::unique_ptr<Module> M = nullptr; + std::unique_ptr<Module> M; Function *F = nullptr; BasicBlock *BB = nullptr; - Instruction *Add = nullptr; - Instruction *Ret = nullptr; - std::unique_ptr<Embedder> Emb = nullptr; + Instruction *AddInst = nullptr; + Instruction *RetInst = nullptr; - GetterTestEnv() { + float OriginalOpcWeight = ::OpcWeight; + float OriginalTypeWeight = ::TypeWeight; + float OriginalArgWeight = ::ArgWeight; + + void SetUp() override { V = {{"add", {1.0, 2.0}}, {"integerTy", {0.5, 0.5}}, {"constant", {0.2, 0.3}}, {"variable", {0.0, 0.0}}, {"unknownTy", {0.0, 0.0}}}; - M = std::make_unique<Module>("M", Ctx); + // Setup IR + M = std::make_unique<Module>("TestM", Ctx); FunctionType *FTy = FunctionType::get( Type::getInt32Ty(Ctx), {Type::getInt32Ty(Ctx), Type::getInt32Ty(Ctx)}, false); @@ -288,61 +293,82 @@ struct GetterTestEnv { Argument *Arg = F->getArg(0); llvm::Value *Const = ConstantInt::get(Type::getInt32Ty(Ctx), 42); - Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB); - Ret = ReturnInst::Create(Ctx, Add, BB); + AddInst = BinaryOperator::CreateAdd(Arg, Const, "add", BB); + RetInst = ReturnInst::Create(Ctx, AddInst, BB); + } + + void setWeights(float OpcWeight, float TypeWeight, float ArgWeight) { + ::OpcWeight = OpcWeight; + ::TypeWeight = TypeWeight; + ::ArgWeight = ArgWeight; + } - auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V); - EXPECT_TRUE(static_cast<bool>(Result)); - Emb = std::move(*Result); + void TearDown() override { + // Restore original global weights + ::OpcWeight = OriginalOpcWeight; + ::TypeWeight = OriginalTypeWeight; + ::ArgWeight = OriginalArgWeight; } }; -TEST(IR2VecTest, GetInstVecMap) { - GetterTestEnv Env; - const auto &InstMap = Env.Emb->getInstVecMap(); +TEST_F(IR2VecTestFixture, GetInstVecMap) { + auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V); + ASSERT_TRUE(static_cast<bool>(Result)); + auto Emb = std::move(*Result); + + const auto &InstMap = Emb->getInstVecMap(); EXPECT_EQ(InstMap.size(), 2u); - EXPECT_TRUE(InstMap.count(Env.Add)); - EXPECT_TRUE(InstMap.count(Env.Ret)); + EXPECT_TRUE(InstMap.count(AddInst)); + EXPECT_TRUE(InstMap.count(RetInst)); - EXPECT_EQ(InstMap.at(Env.Add).size(), 2u); - EXPECT_EQ(InstMap.at(Env.Ret).size(), 2u); + EXPECT_EQ(InstMap.at(AddInst).size(), 2u); + EXPECT_EQ(InstMap.at(RetInst).size(), 2u); // Check values for add: {1.29, 2.31} - EXPECT_THAT(InstMap.at(Env.Add), + EXPECT_THAT(InstMap.at(AddInst), ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6))); // Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in // vocab - EXPECT_THAT(InstMap.at(Env.Ret), ElementsAre(0.0, 0.0)); + EXPECT_THAT(InstMap.at(RetInst), ElementsAre(0.0, 0.0)); } -TEST(IR2VecTest, GetBBVecMap) { - GetterTestEnv Env; - const auto &BBMap = Env.Emb->getBBVecMap(); +TEST_F(IR2VecTestFixture, GetBBVecMap) { + auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V); + ASSERT_TRUE(static_cast<bool>(Result)); + auto Emb = std::move(*Result); + + const auto &BBMap = Emb->getBBVecMap(); EXPECT_EQ(BBMap.size(), 1u); - EXPECT_TRUE(BBMap.count(Env.BB)); - EXPECT_EQ(BBMap.at(Env.BB).size(), 2u); + EXPECT_TRUE(BBMap.count(BB)); + EXPECT_EQ(BBMap.at(BB).size(), 2u); // BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} = // {1.29, 2.31} - EXPECT_THAT(BBMap.at(Env.BB), + EXPECT_THAT(BBMap.at(BB), ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6))); } -TEST(IR2VecTest, GetBBVector) { - GetterTestEnv Env; - const auto &BBVec = Env.Emb->getBBVector(*Env.BB); +TEST_F(IR2VecTestFixture, GetBBVector) { + auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V); + ASSERT_TRUE(static_cast<bool>(Result)); + auto Emb = std::move(*Result); + + const auto &BBVec = Emb->getBBVector(*BB); EXPECT_EQ(BBVec.size(), 2u); EXPECT_THAT(BBVec, ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6))); } -TEST(IR2VecTest, GetFunctionVector) { - GetterTestEnv Env; - const auto &FuncVec = Env.Emb->getFunctionVector(); +TEST_F(IR2VecTestFixture, GetFunctionVector) { + auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V); + ASSERT_TRUE(static_cast<bool>(Result)); + auto Emb = std::move(*Result); + + const auto &FuncVec = Emb->getFunctionVector(); EXPECT_EQ(FuncVec.size(), 2u); @@ -351,4 +377,45 @@ TEST(IR2VecTest, GetFunctionVector) { ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6))); } +TEST_F(IR2VecTestFixture, GetFunctionVectorWithCustomWeights) { + setWeights(1.0, 1.0, 1.0); + + auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V); + ASSERT_TRUE(static_cast<bool>(Result)); + auto Emb = std::move(*Result); + + const auto &FuncVec = Emb->getFunctionVector(); + + EXPECT_EQ(FuncVec.size(), 2u); + + // Expected: 1*([1.0 2.0] + [0.0 0.0]) + 1*([0.5 0.5] + [0.0 0.0]) + 1*([0.2 + // 0.3] + [0.0 0.0]) + EXPECT_THAT(FuncVec, + ElementsAre(DoubleNear(1.7, 1e-6), DoubleNear(2.8, 1e-6))); +} + +TEST(IR2VecTest, IR2VecVocabAnalysisWithPrepopulatedVocab) { + Vocab InitialVocab = {{"key1", {1.1, 2.2}}, {"key2", {3.3, 4.4}}}; + Vocab ExpectedVocab = InitialVocab; + unsigned ExpectedDim = InitialVocab.begin()->second.size(); + + IR2VecVocabAnalysis VocabAnalysis(std::move(InitialVocab)); + + LLVMContext TestCtx; + Module TestMod("TestModuleForVocabAnalysis", TestCtx); + ModuleAnalysisManager MAM; + IR2VecVocabResult Result = VocabAnalysis.run(TestMod, MAM); + + EXPECT_TRUE(Result.isValid()); + ASSERT_FALSE(Result.getVocabulary().empty()); + EXPECT_EQ(Result.getDimension(), ExpectedDim); + + const auto &ResultVocab = Result.getVocabulary(); + EXPECT_EQ(ResultVocab.size(), ExpectedVocab.size()); + for (const auto &pair : ExpectedVocab) { + EXPECT_TRUE(ResultVocab.count(pair.first)); + EXPECT_THAT(ResultVocab.at(pair.first), ElementsAreArray(pair.second)); + } +} + } // end anonymous namespace _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits