https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/143479
>From 6085399a3d002b78a093d951c6a9b6c5bb3fb243 Mon Sep 17 00:00:00 2001 From: svkeerthy <venkatakeer...@google.com> Date: Tue, 10 Jun 2025 05:40:38 +0000 Subject: [PATCH] [MLInliner][IR2Vec] Integrating IR2Vec with MLInliner --- .../Analysis/FunctionPropertiesAnalysis.h | 26 ++- llvm/include/llvm/Analysis/IR2Vec.h | 2 +- llvm/include/llvm/Analysis/InlineAdvisor.h | 5 + .../llvm/Analysis/InlineModelFeatureMaps.h | 6 +- llvm/include/llvm/Analysis/MLInlineAdvisor.h | 1 + .../Analysis/FunctionPropertiesAnalysis.cpp | 115 ++++++++++- llvm/lib/Analysis/IR2Vec.cpp | 4 +- llvm/lib/Analysis/InlineAdvisor.cpp | 29 +++ llvm/lib/Analysis/MLInlineAdvisor.cpp | 34 +++- .../FunctionPropertiesAnalysisTest.cpp | 179 +++++++++++++++--- 10 files changed, 363 insertions(+), 38 deletions(-) diff --git a/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h b/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h index babb6d9d6cf0c..06dbfc35a5294 100644 --- a/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h +++ b/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h @@ -15,6 +15,7 @@ #define LLVM_ANALYSIS_FUNCTIONPROPERTIESANALYSIS_H #include "llvm/ADT/DenseSet.h" +#include "llvm/Analysis/IR2Vec.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/PassManager.h" #include "llvm/Support/Compiler.h" @@ -32,17 +33,19 @@ class FunctionPropertiesInfo { void updateAggregateStats(const Function &F, const LoopInfo &LI); void reIncludeBB(const BasicBlock &BB); + ir2vec::Embedding FunctionEmbedding = ir2vec::Embedding(0.0); + std::optional<ir2vec::Vocab> IR2VecVocab; + public: LLVM_ABI static FunctionPropertiesInfo getFunctionPropertiesInfo(const Function &F, const DominatorTree &DT, - const LoopInfo &LI); + const LoopInfo &LI, + const IR2VecVocabResult *VocabResult); LLVM_ABI static FunctionPropertiesInfo getFunctionPropertiesInfo(Function &F, FunctionAnalysisManager &FAM); - bool operator==(const FunctionPropertiesInfo &FPI) const { - return std::memcmp(this, &FPI,
sizeof(FunctionPropertiesInfo)) == 0; - } + bool operator==(const FunctionPropertiesInfo &FPI) const; bool operator!=(const FunctionPropertiesInfo &FPI) const { return !(*this == FPI); @@ -137,6 +140,19 @@ class FunctionPropertiesInfo { int64_t CallReturnsVectorPointerCount = 0; int64_t CallWithManyArgumentsCount = 0; int64_t CallWithPointerArgumentCount = 0; + + const ir2vec::Embedding &getFunctionEmbedding() const { + return FunctionEmbedding; + } + + const std::optional<ir2vec::Vocab> &getIR2VecVocab() const { + return IR2VecVocab; + } + + // Test-only helper: overwrites the function embedding directly. + void setFunctionEmbeddingForTest(const ir2vec::Embedding &Embedding) { + FunctionEmbedding = Embedding; + } }; // Analysis pass @@ -192,7 +208,7 @@ class FunctionPropertiesUpdater { DominatorTree &getUpdatedDominatorTree(FunctionAnalysisManager &FAM) const; - DenseSet<const BasicBlock *> Successors; + DenseSet<const BasicBlock *> Successors, CallUsers; // Edges we might potentially need to remove from the dominator tree.
SmallVector<DominatorTree::UpdateType, 2> DomTreeUpdates; diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h index 3a6f47ded8ca4..9acffb996283c 100644 --- a/llvm/include/llvm/Analysis/IR2Vec.h +++ b/llvm/include/llvm/Analysis/IR2Vec.h @@ -239,7 +239,7 @@ class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> { public: static AnalysisKey Key; IR2VecVocabAnalysis() = default; - explicit IR2VecVocabAnalysis(ir2vec::Vocab &&Vocab); + explicit IR2VecVocabAnalysis(ir2vec::Vocab Vocab); using Result = IR2VecVocabResult; Result run(Module &M, ModuleAnalysisManager &MAM); }; diff --git a/llvm/include/llvm/Analysis/InlineAdvisor.h b/llvm/include/llvm/Analysis/InlineAdvisor.h index 9d15136e81d10..d2cad4717cbdb 100644 --- a/llvm/include/llvm/Analysis/InlineAdvisor.h +++ b/llvm/include/llvm/Analysis/InlineAdvisor.h @@ -331,6 +331,11 @@ class InlineAdvisorAnalysis : public AnalysisInfoMixin<InlineAdvisorAnalysis> { }; Result run(Module &M, ModuleAnalysisManager &MAM) { return Result(M, MAM); } + +private: + /// If -ml-inliner-ir2vec-vocab-file is set, computes IR2VecVocabAnalysis for + /// \p M (caching it); emits an error and returns false if it is invalid. + static bool initializeIR2VecVocab(Module &M, ModuleAnalysisManager &MAM); }; /// Printer pass for the InlineAdvisorAnalysis results.
diff --git a/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h b/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h index 961d5091bf9f3..91d3378565fc5 100644 --- a/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h +++ b/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h @@ -142,6 +142,10 @@ enum class FeatureIndex : size_t { INLINE_FEATURE_ITERATOR(POPULATE_INDICES) #undef POPULATE_INDICES +// IR2Vec embeddings; only filled in when the ML inliner runs with IR2Vec. + callee_embedding, + caller_embedding, // NOTE(review): still counted in NumberOfFeatures without IR2Vec. + NumberOfFeatures }; // clang-format on @@ -154,7 +158,7 @@ inlineCostFeatureToMlFeature(InlineCostFeatureIndex Feature) { constexpr size_t NumberOfFeatures = static_cast<size_t>(FeatureIndex::NumberOfFeatures); -LLVM_ABI extern const std::vector<TensorSpec> FeatureMap; +LLVM_ABI extern std::vector<TensorSpec> FeatureMap; // Non-const: MLInlineAdvisor appends the IR2Vec specs. LLVM_ABI extern const char *const DecisionName; LLVM_ABI extern const TensorSpec InlineDecisionSpec; diff --git a/llvm/include/llvm/Analysis/MLInlineAdvisor.h b/llvm/include/llvm/Analysis/MLInlineAdvisor.h index 580dd5e95d760..935e4c56dfce6 100644 --- a/llvm/include/llvm/Analysis/MLInlineAdvisor.h +++ b/llvm/include/llvm/Analysis/MLInlineAdvisor.h @@ -82,6 +82,7 @@ class MLInlineAdvisor : public InlineAdvisor { int64_t NodeCount = 0; int64_t EdgeCount = 0; int64_t EdgesOfLastSeenNodes = 0; + bool UseIR2Vec = false; // Set when a valid IR2Vec vocabulary is cached. std::map<const LazyCallGraph::Node *, unsigned> FunctionLevels; const int32_t InitialIRSize = 0; diff --git a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp index 9d044c8a35910..29d3aaf46dc06 100644 --- a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp +++ b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp @@ -199,6 +199,29 @@ void FunctionPropertiesInfo::updateForBB(const BasicBlock &BB, #undef CHECK_OPERAND } } + + if (IR2VecVocab) { + // We instantiate the IR2Vec embedder each time, as having a unique + // pointer to the embedder as member of the class would make it + // non-copyable.
Instantiating the embedder is assumed to be cheap. + auto EmbOrErr = ir2vec::Embedder::create(IR2VecKind::Symbolic, + *BB.getParent(), *IR2VecVocab); + if (Error Err = EmbOrErr.takeError()) { + handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) { + BB.getContext().emitError("Error creating IR2Vec embeddings: " + + EI.message()); + }); + return; + } + auto Embedder = std::move(*EmbOrErr); + const auto &BBEmbedding = Embedder->getBBVector(BB); + // Subtract BBEmbedding from the function embedding when Direction is -1; + // otherwise add it (Direction is expected to be +1). + if (Direction == -1) + FunctionEmbedding -= BBEmbedding; + else + FunctionEmbedding += BBEmbedding; + } } void FunctionPropertiesInfo::updateAggregateStats(const Function &F, @@ -220,14 +243,24 @@ void FunctionPropertiesInfo::updateAggregateStats(const Function &F, FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo( Function &F, FunctionAnalysisManager &FAM) { + // We use the cached result of the IR2VecVocabAnalysis run by + // InlineAdvisorAnalysis. If IR2VecVocabAnalysis has not been run, we don't + // use IR2Vec embeddings.
+ auto *VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F) + .getCachedResult<IR2VecVocabAnalysis>(*F.getParent()); return getFunctionPropertiesInfo(F, FAM.getResult<DominatorTreeAnalysis>(F), - FAM.getResult<LoopAnalysis>(F)); + FAM.getResult<LoopAnalysis>(F), VocabResult); } FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo( - const Function &F, const DominatorTree &DT, const LoopInfo &LI) { + const Function &F, const DominatorTree &DT, const LoopInfo &LI, + const IR2VecVocabResult *VocabResult) { FunctionPropertiesInfo FPI; + if (VocabResult && VocabResult->isValid()) { + FPI.IR2VecVocab = VocabResult->getVocabulary(); + FPI.FunctionEmbedding = ir2vec::Embedding(VocabResult->getDimension(), 0.0); + } for (const auto &BB : F) if (DT.isReachableFromEntry(&BB)) FPI.reIncludeBB(BB); @@ -235,6 +268,66 @@ FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo( return FPI; } +bool FunctionPropertiesInfo::operator==( + const FunctionPropertiesInfo &FPI) const { + if (BasicBlockCount != FPI.BasicBlockCount || + BlocksReachedFromConditionalInstruction != + FPI.BlocksReachedFromConditionalInstruction || + Uses != FPI.Uses || + DirectCallsToDefinedFunctions != FPI.DirectCallsToDefinedFunctions || + LoadInstCount != FPI.LoadInstCount || + StoreInstCount != FPI.StoreInstCount || + MaxLoopDepth != FPI.MaxLoopDepth || + TopLevelLoopCount != FPI.TopLevelLoopCount || + TotalInstructionCount != FPI.TotalInstructionCount || + BasicBlocksWithSingleSuccessor != FPI.BasicBlocksWithSingleSuccessor || + BasicBlocksWithTwoSuccessors != FPI.BasicBlocksWithTwoSuccessors || + BasicBlocksWithMoreThanTwoSuccessors != + FPI.BasicBlocksWithMoreThanTwoSuccessors || + BasicBlocksWithSinglePredecessor != + FPI.BasicBlocksWithSinglePredecessor || + BasicBlocksWithTwoPredecessors != FPI.BasicBlocksWithTwoPredecessors || + BasicBlocksWithMoreThanTwoPredecessors != + FPI.BasicBlocksWithMoreThanTwoPredecessors || + BigBasicBlocks !=
FPI.BigBasicBlocks || + MediumBasicBlocks != FPI.MediumBasicBlocks || + SmallBasicBlocks != FPI.SmallBasicBlocks || + CastInstructionCount != FPI.CastInstructionCount || + FloatingPointInstructionCount != FPI.FloatingPointInstructionCount || + IntegerInstructionCount != FPI.IntegerInstructionCount || + ConstantIntOperandCount != FPI.ConstantIntOperandCount || + ConstantFPOperandCount != FPI.ConstantFPOperandCount || + ConstantOperandCount != FPI.ConstantOperandCount || + InstructionOperandCount != FPI.InstructionOperandCount || + BasicBlockOperandCount != FPI.BasicBlockOperandCount || + GlobalValueOperandCount != FPI.GlobalValueOperandCount || + InlineAsmOperandCount != FPI.InlineAsmOperandCount || + ArgumentOperandCount != FPI.ArgumentOperandCount || + UnknownOperandCount != FPI.UnknownOperandCount || + CriticalEdgeCount != FPI.CriticalEdgeCount || + ControlFlowEdgeCount != FPI.ControlFlowEdgeCount || + UnconditionalBranchCount != FPI.UnconditionalBranchCount || + IntrinsicCount != FPI.IntrinsicCount || + DirectCallCount != FPI.DirectCallCount || + IndirectCallCount != FPI.IndirectCallCount || + CallReturnsIntegerCount != FPI.CallReturnsIntegerCount || + CallReturnsFloatCount != FPI.CallReturnsFloatCount || + CallReturnsPointerCount != FPI.CallReturnsPointerCount || + CallReturnsVectorIntCount != FPI.CallReturnsVectorIntCount || + CallReturnsVectorFloatCount != FPI.CallReturnsVectorFloatCount || + CallReturnsVectorPointerCount != FPI.CallReturnsVectorPointerCount || + CallWithManyArgumentsCount != FPI.CallWithManyArgumentsCount || + CallWithPointerArgumentCount != FPI.CallWithPointerArgumentCount) { + return false; + } + // Compare the function embeddings (approximately: they are floating-point + // sums). The vocabulary is not compared as it remains the same.
+ if (!FunctionEmbedding.approximatelyEquals(FPI.FunctionEmbedding)) + return false; + + return true; +} + void FunctionPropertiesInfo::print(raw_ostream &OS) const { #define PRINT_PROPERTY(PROP_NAME) OS << #PROP_NAME ": " << PROP_NAME << "\n"; @@ -322,6 +415,16 @@ FunctionPropertiesUpdater::FunctionPropertiesUpdater( // The caller's entry BB may change due to new alloca instructions. LikelyToChangeBBs.insert(&*Caller.begin()); + // Inlining replaces the value returned by the call, so the BBs using it may + // change too (and, with them, their IR2Vec embeddings). Conservatively add + // those BBs to LikelyToChangeBBs. Every user of a call is an Instruction. + for (const auto *U : CB.users()) + CallUsers.insert(cast<Instruction>(U)->getParent()); + // CallSiteBB is handled separately, so drop it from CallUsers if it was + // added above. + CallUsers.erase(&CallSiteBB); + LikelyToChangeBBs.insert_range(CallUsers); + // The successors may become unreachable in the case of `invoke` inlining. // We track successors separately, too, because they form a boundary, together // with the CB BB ('Entry') between which the inlined callee will be pasted. @@ -435,6 +538,9 @@ void FunctionPropertiesUpdater::finish(FunctionAnalysisManager &FAM) const { if (&CallSiteBB != &*Caller.begin()) Reinclude.insert(&*Caller.begin()); + // Reinclude the BBs that used the value returned by the inlined call. + Reinclude.insert_range(CallUsers); + // Distribute the successors to the 2 buckets.
for (const auto *Succ : Successors) if (DT.isReachableFromEntry(Succ)) @@ -486,6 +592,9 @@ bool FunctionPropertiesUpdater::isUpdateValid(Function &F, return false; DominatorTree DT(F); LoopInfo LI(DT); - auto Fresh = FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI); + auto *VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F) + .getCachedResult<IR2VecVocabAnalysis>(*F.getParent()); + auto Fresh = + FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI, VocabResult); return FPI == Fresh; } diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index 6d939421ea3d8..f8676e5842faf 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -291,8 +291,8 @@ Error IR2VecVocabAnalysis::readVocabulary() { return Error::success(); } -IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&Vocabulary) +IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab Vocabulary) : Vocabulary(std::move(Vocabulary)) {} void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) { handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) { diff --git a/llvm/lib/Analysis/InlineAdvisor.cpp b/llvm/lib/Analysis/InlineAdvisor.cpp index 3d30f3d10a9d0..2e869dfd91713 100644 --- a/llvm/lib/Analysis/InlineAdvisor.cpp +++ b/llvm/lib/Analysis/InlineAdvisor.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/EphemeralValuesCache.h" +#include "llvm/Analysis/IR2Vec.h" #include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ProfileSummaryInfo.h" @@ -64,6 +65,13 @@ static cl::opt<bool> cl::desc("If true, annotate inline advisor remarks " "with LTO and pass information.")); +// Setting this enables IR2Vec embeddings (ML inliner only). NOTE(review): the +// patch only tests it for emptiness; confirm the vocabulary is read from it.
+static cl::opt<std::string> IR2VecVocabFile( + "ml-inliner-ir2vec-vocab-file", cl::Hidden, + cl::desc("Vocab file for IR2Vec; Setting this enables " + "configuring the model to use IR2Vec embeddings.")); + namespace llvm { extern cl::opt<InlinerFunctionImportStatsOpts> InlinerFunctionImportStats; } // namespace llvm @@ -206,6 +214,20 @@ void InlineAdvice::recordInliningWithCalleeDeleted() { AnalysisKey InlineAdvisorAnalysis::Key; AnalysisKey PluginInlineAdvisorAnalysis::Key; +bool InlineAdvisorAnalysis::initializeIR2VecVocab(Module &M, + ModuleAnalysisManager &MAM) { + if (!IR2VecVocabFile.empty()) { + const auto &VocabResult = MAM.getResult<IR2VecVocabAnalysis>(M); + if (!VocabResult.isValid()) { + M.getContext().emitError("Failed to load IR2Vec vocabulary"); + return false; + } + } + // No vocab file specified is OK; we just don't use IR2Vec + // embeddings. + return true; +} + bool InlineAdvisorAnalysis::Result::tryCreate( InlineParams Params, InliningAdvisorMode Mode, const ReplayInlinerSettings &ReplaySettings, InlineContext IC) { @@ -231,14 +253,21 @@ bool InlineAdvisorAnalysis::Result::tryCreate( /* EmitRemarks =*/true, IC); } break; + // Both ML advisor modes below run IR2VecVocabAnalysis once per module, + // before creating the advisor, so its (immutable) result is cached and is + // not recomputed multiple times.
case InliningAdvisorMode::Development: #ifdef LLVM_HAVE_TFLITE LLVM_DEBUG(dbgs() << "Using development-mode inliner policy.\n"); + if (!InlineAdvisorAnalysis::initializeIR2VecVocab(M, MAM)) + return false; Advisor = llvm::getDevelopmentModeAdvisor(M, MAM, GetDefaultAdvice); #endif break; case InliningAdvisorMode::Release: LLVM_DEBUG(dbgs() << "Using release-mode inliner policy.\n"); + if (!InlineAdvisorAnalysis::initializeIR2VecVocab(M, MAM)) + return false; Advisor = llvm::getReleaseModeAdvisor(M, MAM, GetDefaultAdvice); break; } diff --git a/llvm/lib/Analysis/MLInlineAdvisor.cpp b/llvm/lib/Analysis/MLInlineAdvisor.cpp index 81a3bc94a6ad8..3a9a68670e852 100644 --- a/llvm/lib/Analysis/MLInlineAdvisor.cpp +++ b/llvm/lib/Analysis/MLInlineAdvisor.cpp @@ -107,7 +107,7 @@ static cl::opt<bool> KeepFPICache( cl::init(false)); // clang-format off -const std::vector<TensorSpec> llvm::FeatureMap{ +std::vector<TensorSpec> llvm::FeatureMap{ #define POPULATE_NAMES(DTYPE, SHAPE, NAME, __) TensorSpec::createSpec<DTYPE>(#NAME, SHAPE), // InlineCost features - these must come first INLINE_COST_FEATURE_ITERATOR(POPULATE_NAMES) @@ -186,6 +186,20 @@ MLInlineAdvisor::MLInlineAdvisor( EdgeCount += getLocalCalls(KVP.first->getFunction()); } NodeCount = AllNodes.size(); + + if (auto *VocabResult = MAM.getCachedResult<IR2VecVocabAnalysis>(M)) { + if (!VocabResult->isValid()) { + M.getContext().emitError("IR2VecVocabAnalysis is not valid"); + return; + } + // FeatureMap is a mutable global shared by all advisors (unsynchronized): + // append the IR2Vec specs only once so later advisors don't duplicate them. + auto IR2VecDim = VocabResult->getDimension(); + if (FeatureMap.size() < NumberOfFeatures) + for (const char *Name : {"callee_embedding", "caller_embedding"}) + FeatureMap.push_back(TensorSpec::createSpec<float>(Name, {IR2VecDim})); + UseIR2Vec = true; + } } unsigned MLInlineAdvisor::getInitialFunctionLevel(const Function &F) const { @@ -433,6 +447,24 @@ std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) {
*ModelRunner->getTensor<int64_t>(FeatureIndex::is_caller_avail_external) = Caller.hasAvailableExternallyLinkage(); + if (UseIR2Vec) { + // Python side expects float embeddings. The IR2Vec embeddings are doubles + // as of now due to the restriction of fromJSON method used by + // IR2VecVocabAnalysis::readVocabulary. + auto setEmbedding = [&](const ir2vec::Embedding &Embedding, + FeatureIndex Index) { + // Narrow each double to float while copying straight into the tensor; + // no temporary std::vector<float> or separate memcpy is needed. + float *Tensor = ModelRunner->getTensor<float>(Index); + std::copy(Embedding.begin(), Embedding.end(), Tensor); + }; + + setEmbedding(CalleeBefore.getFunctionEmbedding(), + FeatureIndex::callee_embedding); + setEmbedding(CallerBefore.getFunctionEmbedding(), + FeatureIndex::caller_embedding); + } + // Add the cost features for (size_t I = 0; I < static_cast<size_t>(InlineCostFeatureIndex::NumberOfFeatures); ++I) { diff --git a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp index 0720d935b0362..3ef2964f2d170 100644 --- a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp +++ b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp @@ -8,6 +8,7 @@ #include "llvm/Analysis/FunctionPropertiesAnalysis.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/IR2Vec.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/AsmParser/Parser.h" #include "llvm/IR/Dominators.h" @@ -20,15 +21,20 @@ #include "llvm/Support/Compiler.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" #include <cstring> using namespace llvm; +using namespace testing; namespace llvm { LLVM_ABI extern cl::opt<bool> EnableDetailedFunctionProperties; LLVM_ABI extern cl::opt<bool> BigBasicBlockInstructionThreshold; LLVM_ABI extern cl::opt<bool> MediumBasicBlockInstrutionThreshold; +LLVM_ABI extern cl::opt<float> ir2vec::OpcWeight; +LLVM_ABI
extern cl::opt<float> ir2vec::TypeWeight; +LLVM_ABI extern cl::opt<float> ir2vec::ArgWeight; } // namespace llvm namespace { @@ -36,17 +42,81 @@ namespace { class FunctionPropertiesAnalysisTest : public testing::Test { public: FunctionPropertiesAnalysisTest() { + createTestVocabulary(1); + MAM.registerPass([&] { return IR2VecVocabAnalysis(Vocabulary); }); + MAM.registerPass([&] { return PassInstrumentationAnalysis(); }); + FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); }); FAM.registerPass([&] { return DominatorTreeAnalysis(); }); FAM.registerPass([&] { return LoopAnalysis(); }); FAM.registerPass([&] { return PassInstrumentationAnalysis(); }); + + ir2vec::OpcWeight = 1.0; + ir2vec::TypeWeight = 1.0; + ir2vec::ArgWeight = 1.0; + } + +private: + float OriginalOpcWeight = ir2vec::OpcWeight; + float OriginalTypeWeight = ir2vec::TypeWeight; + float OriginalArgWeight = ir2vec::ArgWeight; + + void createTestVocabulary(unsigned Dim) { + Vocabulary["add"] = ir2vec::Embedding(Dim, 0.1); + Vocabulary["sub"] = ir2vec::Embedding(Dim, 0.2); + Vocabulary["mul"] = ir2vec::Embedding(Dim, 0.3); + Vocabulary["icmp"] = ir2vec::Embedding(Dim, 0.4); + Vocabulary["br"] = ir2vec::Embedding(Dim, 0.5); + Vocabulary["ret"] = ir2vec::Embedding(Dim, 0.6); + Vocabulary["store"] = ir2vec::Embedding(Dim, 0.7); + Vocabulary["load"] = ir2vec::Embedding(Dim, 0.8); + Vocabulary["alloca"] = ir2vec::Embedding(Dim, 0.9); + Vocabulary["phi"] = ir2vec::Embedding(Dim, 1.0); + Vocabulary["call"] = ir2vec::Embedding(Dim, 1.1); + Vocabulary["voidTy"] = ir2vec::Embedding(Dim, 1.3); + Vocabulary["floatTy"] = ir2vec::Embedding(Dim, 1.4); + Vocabulary["integerTy"] = ir2vec::Embedding(Dim, 1.5); + Vocabulary["functionTy"] = ir2vec::Embedding(Dim, 1.6); + Vocabulary["structTy"] = ir2vec::Embedding(Dim, 1.7); + Vocabulary["arrayTy"] = ir2vec::Embedding(Dim, 1.8); + Vocabulary["pointerTy"] = ir2vec::Embedding(Dim, 1.9); + Vocabulary["vectorTy"] = ir2vec::Embedding(Dim, 2.0); +
Vocabulary["emptyTy"] = ir2vec::Embedding(Dim, 2.1); + Vocabulary["labelTy"] = ir2vec::Embedding(Dim, 2.2); + Vocabulary["tokenTy"] = ir2vec::Embedding(Dim, 2.3); + Vocabulary["metadataTy"] = ir2vec::Embedding(Dim, 2.4); + Vocabulary["unknownTy"] = ir2vec::Embedding(Dim, 2.5); + Vocabulary["function"] = ir2vec::Embedding(Dim, 3.1); + Vocabulary["pointer"] = ir2vec::Embedding(Dim, 3.2); + Vocabulary["constant"] = ir2vec::Embedding(Dim, 3.3); + Vocabulary["variable"] = ir2vec::Embedding(Dim, 3.4); + Vocabulary["getelementptr"] = ir2vec::Embedding(Dim, 3.5); + Vocabulary["invoke"] = ir2vec::Embedding(Dim, 3.6); + Vocabulary["landingpad"] = ir2vec::Embedding(Dim, 3.7); + Vocabulary["resume"] = ir2vec::Embedding(Dim, 3.8); + Vocabulary["catch"] = ir2vec::Embedding(Dim, 3.9); + Vocabulary["cleanup"] = ir2vec::Embedding(Dim, 4.0); + return; + } protected: std::unique_ptr<DominatorTree> DT; std::unique_ptr<LoopInfo> LI; FunctionAnalysisManager FAM; + ModuleAnalysisManager MAM; + ir2vec::Vocab Vocabulary; + + void TearDown() override { + // Restore original IR2Vec weights + ir2vec::OpcWeight = OriginalOpcWeight; + ir2vec::TypeWeight = OriginalTypeWeight; + ir2vec::ArgWeight = OriginalArgWeight; + } FunctionPropertiesInfo buildFPI(Function &F) { + // Compute (and cache) IR2VecVocabAnalysis first: FunctionPropertiesInfo + // only uses IR2Vec when a cached vocabulary result is available.
+ auto &VocabResult = MAM.getResult<IR2VecVocabAnalysis>(*F.getParent()); + (void)VocabResult; return FunctionPropertiesInfo::getFunctionPropertiesInfo(F, FAM); } @@ -62,15 +132,22 @@ class FunctionPropertiesAnalysisTest : public testing::Test { Err.print("MLAnalysisTests", errs()); return Mod; } - - CallBase* findCall(Function& F, const char* Name = nullptr) { + + CallBase *findCall(Function &F, const char *Name = nullptr) { for (auto &BB : F) - for (auto &I : BB ) + for (auto &I : BB) if (auto *CB = dyn_cast<CallBase>(&I)) if (!Name || CB->getName() == Name) return CB; return nullptr; } + + std::unique_ptr<ir2vec::Embedder> createEmbedder(const Function &F) { + auto EmbResult = + ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary); + EXPECT_TRUE(static_cast<bool>(EmbResult)); + return std::move(*EmbResult); + } }; TEST_F(FunctionPropertiesAnalysisTest, BasicTest) { @@ -113,6 +190,8 @@ define internal i32 @top() { EXPECT_EQ(BranchesFeatures.StoreInstCount, 0); EXPECT_EQ(BranchesFeatures.MaxLoopDepth, 0); EXPECT_EQ(BranchesFeatures.TopLevelLoopCount, 0); + EXPECT_TRUE(BranchesFeatures.getFunctionEmbedding().approximatelyEquals( + createEmbedder(*BranchesFunction)->getFunctionVector())); Function *TopFunction = M->getFunction("top"); FunctionPropertiesInfo TopFeatures = buildFPI(*TopFunction); @@ -120,6 +199,8 @@ define internal i32 @top() { EXPECT_EQ(TopFeatures.BlocksReachedFromConditionalInstruction, 0); EXPECT_EQ(TopFeatures.Uses, 0); EXPECT_EQ(TopFeatures.DirectCallsToDefinedFunctions, 1); + EXPECT_TRUE(TopFeatures.getFunctionEmbedding().approximatelyEquals( + createEmbedder(*TopFunction)->getFunctionVector())); EXPECT_EQ(BranchesFeatures.LoadInstCount, 0); EXPECT_EQ(BranchesFeatures.StoreInstCount, 0); EXPECT_EQ(BranchesFeatures.MaxLoopDepth, 0); @@ -159,6 +240,9 @@ define internal i32 @top() { EXPECT_EQ(DetailedBranchesFeatures.CallReturnsPointerCount, 0); EXPECT_EQ(DetailedBranchesFeatures.CallWithManyArgumentsCount, 0);
EXPECT_EQ(DetailedBranchesFeatures.CallWithPointerArgumentCount, 0); + EXPECT_TRUE( + DetailedBranchesFeatures.getFunctionEmbedding().approximatelyEquals( + createEmbedder(*BranchesFunction)->getFunctionVector())); EnableDetailedFunctionProperties.setValue(false); } @@ -210,6 +294,8 @@ define i64 @f1() { EXPECT_EQ(DetailedF1Properties.CallReturnsPointerCount, 0); EXPECT_EQ(DetailedF1Properties.CallWithManyArgumentsCount, 0); EXPECT_EQ(DetailedF1Properties.CallWithPointerArgumentCount, 0); + EXPECT_TRUE(DetailedF1Properties.getFunctionEmbedding().approximatelyEquals( + createEmbedder(*F1)->getFunctionVector())); EnableDetailedFunctionProperties.setValue(false); } @@ -232,28 +318,29 @@ define i32 @f2(i32 %a) { )IR"); Function *F1 = M->getFunction("f1"); - CallBase* CB = findCall(*F1, "b"); + CallBase *CB = findCall(*F1, "b"); EXPECT_NE(CB, nullptr); - FunctionPropertiesInfo ExpectedInitial; - ExpectedInitial.BasicBlockCount = 1; - ExpectedInitial.TotalInstructionCount = 3; - ExpectedInitial.Uses = 1; - ExpectedInitial.DirectCallsToDefinedFunctions = 1; - - FunctionPropertiesInfo ExpectedFinal = ExpectedInitial; - ExpectedFinal.DirectCallsToDefinedFunctions = 0; - auto FPI = buildFPI(*F1); - EXPECT_EQ(FPI, ExpectedInitial); + EXPECT_EQ(FPI.BasicBlockCount, 1); + EXPECT_EQ(FPI.TotalInstructionCount, 3); + EXPECT_EQ(FPI.Uses, 1); + EXPECT_EQ(FPI.DirectCallsToDefinedFunctions, 1); + EXPECT_THAT(FPI.getFunctionEmbedding(), ElementsAre(DoubleNear(22.7, 1e-6))); // One element since the fixture vocabulary has Dim = 1; NOTE(review): explain how 22.7 is derived. FunctionPropertiesUpdater FPU(FPI, *CB); InlineFunctionInfo IFI; auto IR = llvm::InlineFunction(*CB, IFI); EXPECT_TRUE(IR.isSuccess()); invalidate(*F1); + EXPECT_TRUE(FPU.finishAndTest(FAM)); - EXPECT_EQ(FPI, ExpectedFinal); + EXPECT_EQ(FPI.BasicBlockCount, 1); + EXPECT_EQ(FPI.TotalInstructionCount, 3); + EXPECT_EQ(FPI.Uses, 1); + EXPECT_EQ(FPI.DirectCallsToDefinedFunctions, 0); + EXPECT_TRUE(FPI.getFunctionEmbedding().approximatelyEquals( + createEmbedder(*F1)->getFunctionVector())); }
TEST_F(FunctionPropertiesAnalysisTest, InlineSameBBLargerCFG) { @@ -285,7 +372,7 @@ define i32 @f2(i32 %a) { )IR"); Function *F1 = M->getFunction("f1"); - CallBase* CB = findCall(*F1, "b"); + CallBase *CB = findCall(*F1, "b"); EXPECT_NE(CB, nullptr); FunctionPropertiesInfo ExpectedInitial; @@ -294,6 +381,8 @@ define i32 @f2(i32 %a) { ExpectedInitial.TotalInstructionCount = 9; ExpectedInitial.Uses = 1; ExpectedInitial.DirectCallsToDefinedFunctions = 1; + ExpectedInitial.setFunctionEmbeddingForTest( + createEmbedder(*F1)->getFunctionVector()); FunctionPropertiesInfo ExpectedFinal = ExpectedInitial; ExpectedFinal.DirectCallsToDefinedFunctions = 0; @@ -307,6 +396,9 @@ define i32 @f2(i32 %a) { EXPECT_TRUE(IR.isSuccess()); invalidate(*F1); EXPECT_TRUE(FPU.finishAndTest(FAM)); + // Expected embedding: recomputed from F1's post-inlining IR. + ExpectedFinal.setFunctionEmbeddingForTest( + createEmbedder(*F1)->getFunctionVector()); EXPECT_EQ(FPI, ExpectedFinal); } @@ -347,7 +439,7 @@ define i32 @f2(i32 %a) { )IR"); Function *F1 = M->getFunction("f1"); - CallBase* CB = findCall(*F1, "b"); + CallBase *CB = findCall(*F1, "b"); EXPECT_NE(CB, nullptr); FunctionPropertiesInfo ExpectedInitial; @@ -356,6 +448,8 @@ define i32 @f2(i32 %a) { ExpectedInitial.TotalInstructionCount = 9; ExpectedInitial.Uses = 1; ExpectedInitial.DirectCallsToDefinedFunctions = 1; + ExpectedInitial.setFunctionEmbeddingForTest( + createEmbedder(*F1)->getFunctionVector()); FunctionPropertiesInfo ExpectedFinal; ExpectedFinal.BasicBlockCount = 6; @@ -374,6 +468,9 @@ define i32 @f2(i32 %a) { EXPECT_TRUE(IR.isSuccess()); invalidate(*F1); EXPECT_TRUE(FPU.finishAndTest(FAM)); + + ExpectedFinal.setFunctionEmbeddingForTest( + createEmbedder(*F1)->getFunctionVector()); EXPECT_EQ(FPI, ExpectedFinal); } @@ -409,7 +506,7 @@ declare i32 @__gxx_personality_v0(...) )IR"); Function *F1 = M->getFunction("caller"); - CallBase* CB = findCall(*F1); + CallBase *CB = findCall(*F1); EXPECT_NE(CB, nullptr); auto FPI = buildFPI(*F1); @@ -422,6 +519,8 @@ declare i32 @__gxx_personality_v0(...)
EXPECT_EQ(static_cast<size_t>(FPI.BasicBlockCount), F1->size()); EXPECT_EQ(static_cast<size_t>(FPI.TotalInstructionCount), F1->getInstructionCount()); + EXPECT_TRUE(FPI.getFunctionEmbedding().approximatelyEquals( + createEmbedder(*F1)->getFunctionVector())); } TEST_F(FunctionPropertiesAnalysisTest, InvokeUnreachableHandler) { @@ -462,7 +561,7 @@ declare i32 @__gxx_personality_v0(...) )IR"); Function *F1 = M->getFunction("caller"); - CallBase* CB = findCall(*F1); + CallBase *CB = findCall(*F1); EXPECT_NE(CB, nullptr); auto FPI = buildFPI(*F1); @@ -475,6 +574,8 @@ declare i32 @__gxx_personality_v0(...) EXPECT_EQ(static_cast<size_t>(FPI.BasicBlockCount), F1->size() - 1); EXPECT_EQ(static_cast<size_t>(FPI.TotalInstructionCount), F1->getInstructionCount() - 2); + EXPECT_TRUE(FPI.getFunctionEmbedding().approximatelyEquals( + createEmbedder(*F1)->getFunctionVector())); EXPECT_EQ(FPI, FunctionPropertiesInfo::getFunctionPropertiesInfo(*F1, FAM)); } @@ -516,7 +617,7 @@ declare i32 @__gxx_personality_v0(...)
)IR"); Function *F1 = M->getFunction("caller"); - CallBase* CB = findCall(*F1); + CallBase *CB = findCall(*F1); EXPECT_NE(CB, nullptr); auto FPI = buildFPI(*F1); @@ -568,7 +669,7 @@ define void @outer() personality i8* null { )IR"); Function *F1 = M->getFunction("outer"); - CallBase* CB = findCall(*F1); + CallBase *CB = findCall(*F1); EXPECT_NE(CB, nullptr); auto FPI = buildFPI(*F1); @@ -581,6 +682,8 @@ define void @outer() personality i8* null { EXPECT_EQ(static_cast<size_t>(FPI.BasicBlockCount), F1->size() - 1); EXPECT_EQ(static_cast<size_t>(FPI.TotalInstructionCount), F1->getInstructionCount() - 2); + EXPECT_TRUE(FPI.getFunctionEmbedding().approximatelyEquals( + createEmbedder(*F1)->getFunctionVector())); EXPECT_EQ(FPI, FunctionPropertiesInfo::getFunctionPropertiesInfo(*F1, FAM)); } @@ -624,7 +727,7 @@ if.then: )IR"); Function *F1 = M->getFunction("outer"); - CallBase* CB = findCall(*F1); + CallBase *CB = findCall(*F1); EXPECT_NE(CB, nullptr); auto FPI = buildFPI(*F1); @@ -637,6 +740,8 @@ if.then: EXPECT_EQ(static_cast<size_t>(FPI.BasicBlockCount), F1->size() - 1); EXPECT_EQ(static_cast<size_t>(FPI.TotalInstructionCount), F1->getInstructionCount() - 2); + EXPECT_TRUE(FPI.getFunctionEmbedding().approximatelyEquals( + createEmbedder(*F1)->getFunctionVector())); EXPECT_EQ(FPI, FunctionPropertiesInfo::getFunctionPropertiesInfo(*F1, FAM)); } @@ -689,6 +794,8 @@ define i32 @f2(i32 %a) { ExpectedInitial.DirectCallsToDefinedFunctions = 1; ExpectedInitial.MaxLoopDepth = 1; ExpectedInitial.TopLevelLoopCount = 1; + ExpectedInitial.setFunctionEmbeddingForTest( + createEmbedder(*F1)->getFunctionVector()); FunctionPropertiesInfo ExpectedFinal = ExpectedInitial; ExpectedFinal.BasicBlockCount = 6; @@ -705,6 +812,9 @@ define i32 @f2(i32 %a) { EXPECT_TRUE(IR.isSuccess()); invalidate(*F1); EXPECT_TRUE(FPU.finishAndTest(FAM)); + + ExpectedFinal.setFunctionEmbeddingForTest( + createEmbedder(*F1)->getFunctionVector()); EXPECT_EQ(FPI, ExpectedFinal); } @@ -733,7 +843,7 @@ cond.false:
; preds = %entry extra2: br label %cond.end -cond.end: ; preds = %cond.false, %cond.true +cond.end: ; preds = %extra2, %cond.true %cond = phi i64 [ %conv2, %cond.true ], [ %call3, %extra ] ret i64 %cond } @@ -757,7 +867,9 @@ declare void @llvm.trap() ExpectedInitial.BlocksReachedFromConditionalInstruction = 2; ExpectedInitial.Uses = 1; ExpectedInitial.DirectCallsToDefinedFunctions = 1; - + ExpectedInitial.setFunctionEmbeddingForTest( + createEmbedder(*F1)->getFunctionVector()); + FunctionPropertiesInfo ExpectedFinal = ExpectedInitial; ExpectedFinal.BasicBlockCount = 4; ExpectedFinal.DirectCallsToDefinedFunctions = 0; @@ -772,6 +884,9 @@ declare void @llvm.trap() EXPECT_TRUE(IR.isSuccess()); invalidate(*F1); EXPECT_TRUE(FPU.finishAndTest(FAM)); + + ExpectedFinal.setFunctionEmbeddingForTest( + createEmbedder(*F1)->getFunctionVector()); EXPECT_EQ(FPI, ExpectedFinal); } @@ -817,6 +932,8 @@ declare void @f3() ExpectedInitial.BlocksReachedFromConditionalInstruction = 0; ExpectedInitial.Uses = 1; ExpectedInitial.DirectCallsToDefinedFunctions = 1; + ExpectedInitial.setFunctionEmbeddingForTest( + createEmbedder(*F1)->getFunctionVector()); FunctionPropertiesInfo ExpectedFinal = ExpectedInitial; ExpectedFinal.BasicBlockCount = 6; @@ -832,6 +949,9 @@ declare void @f3() EXPECT_TRUE(IR.isSuccess()); invalidate(*F1); EXPECT_TRUE(FPU.finishAndTest(FAM)); + + ExpectedFinal.setFunctionEmbeddingForTest( + createEmbedder(*F1)->getFunctionVector()); EXPECT_EQ(FPI, ExpectedFinal); } @@ -885,6 +1005,8 @@ define i64 @f1(i64 %e) { EXPECT_EQ(DetailedF1Properties.CallReturnsPointerCount, 0); EXPECT_EQ(DetailedF1Properties.CallWithManyArgumentsCount, 0); EXPECT_EQ(DetailedF1Properties.CallWithPointerArgumentCount, 0); + EXPECT_TRUE(DetailedF1Properties.getFunctionEmbedding().approximatelyEquals( + createEmbedder(*F1)->getFunctionVector())); EnableDetailedFunctionProperties.setValue(false); } @@ -910,6 +1032,8 @@ declare float @llvm.cos.f32(float)
EXPECT_EQ(DetailedF1Properties.CallReturnsPointerCount, 0); EXPECT_EQ(DetailedF1Properties.CallWithManyArgumentsCount, 0); EXPECT_EQ(DetailedF1Properties.CallWithPointerArgumentCount, 0); + EXPECT_TRUE(DetailedF1Properties.getFunctionEmbedding().approximatelyEquals( + createEmbedder(*F1)->getFunctionVector())); EnableDetailedFunctionProperties.setValue(false); } @@ -943,6 +1067,8 @@ declare float @f5() EXPECT_EQ(DetailedF1Properties.CallReturnsPointerCount, 1); EXPECT_EQ(DetailedF1Properties.CallWithManyArgumentsCount, 1); EXPECT_EQ(DetailedF1Properties.CallWithPointerArgumentCount, 1); + EXPECT_TRUE(DetailedF1Properties.getFunctionEmbedding().approximatelyEquals( + createEmbedder(*F1)->getFunctionVector())); EnableDetailedFunctionProperties.setValue(false); } @@ -972,10 +1098,11 @@ define i64 @f1(i64 %a) { EnableDetailedFunctionProperties.setValue(true); FunctionPropertiesInfo DetailedF1Properties = buildFPI(*F1); EXPECT_EQ(DetailedF1Properties.CriticalEdgeCount, 1); + EXPECT_TRUE(DetailedF1Properties.getFunctionEmbedding().approximatelyEquals( + createEmbedder(*F1)->getFunctionVector())); EnableDetailedFunctionProperties.setValue(false); } - TEST_F(FunctionPropertiesAnalysisTest, FunctionReturnVectors) { LLVMContext C; std::unique_ptr<Module> M = makeLLVMModule(C, @@ -998,6 +1125,8 @@ declare <4 x ptr> @f4() EXPECT_EQ(DetailedF1Properties.CallReturnsVectorIntCount, 1); EXPECT_EQ(DetailedF1Properties.CallReturnsVectorFloatCount, 1); EXPECT_EQ(DetailedF1Properties.CallReturnsVectorPointerCount, 1); + EXPECT_TRUE(DetailedF1Properties.getFunctionEmbedding().approximatelyEquals( + createEmbedder(*F1)->getFunctionVector())); EnableDetailedFunctionProperties.setValue(false); } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits