https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/162161
>From c58a41517b9358e9a56f4ea0b304dd8c7fe48741 Mon Sep 17 00:00:00 2001 From: svkeerthy <[email protected]> Date: Mon, 6 Oct 2025 21:15:14 +0000 Subject: [PATCH] MIR2Vec embedding --- llvm/include/llvm/CodeGen/MIR2Vec.h | 108 ++++++ llvm/include/llvm/CodeGen/Passes.h | 4 + llvm/include/llvm/InitializePasses.h | 1 + llvm/lib/CodeGen/CodeGen.cpp | 1 + llvm/lib/CodeGen/MIR2Vec.cpp | 195 ++++++++++- .../Inputs/mir2vec_dummy_3D_vocab.json | 22 ++ llvm/test/CodeGen/MIR2Vec/if-else.mir | 144 ++++++++ .../MIR2Vec/mir2vec-basic-symbolic.mir | 76 ++++ llvm/tools/llc/llc.cpp | 15 + llvm/unittests/CodeGen/MIR2VecTest.cpp | 324 ++++++++++++++++-- 10 files changed, 863 insertions(+), 27 deletions(-) create mode 100644 llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json create mode 100644 llvm/test/CodeGen/MIR2Vec/if-else.mir create mode 100644 llvm/test/CodeGen/MIR2Vec/mir2vec-basic-symbolic.mir diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h index ea68b4594a2ad..ebafe4ccddff3 100644 --- a/llvm/include/llvm/CodeGen/MIR2Vec.h +++ b/llvm/include/llvm/CodeGen/MIR2Vec.h @@ -51,11 +51,21 @@ class LLVMContext; class MIR2VecVocabLegacyAnalysis; class TargetInstrInfo; +enum class MIR2VecKind { Symbolic }; + namespace mir2vec { + +// Forward declarations +class MIREmbedder; +class SymbolicMIREmbedder; + extern llvm::cl::OptionCategory MIR2VecCategory; extern cl::opt<float> OpcWeight; using Embedding = ir2vec::Embedding; +using MachineInstEmbeddingsMap = DenseMap<const MachineInstr *, Embedding>; +using MachineBlockEmbeddingsMap = + DenseMap<const MachineBasicBlock *, Embedding>; /// Class for storing and accessing the MIR2Vec vocabulary. /// The MIRVocabulary class manages seed embeddings for LLVM Machine IR @@ -132,6 +142,79 @@ class MIRVocabulary { assert(isValid() && "Invalid vocabulary"); return Storage.size(); } + + /// Create a dummy vocabulary for testing purposes. 
+ static MIRVocabulary createDummyVocabForTest(const TargetInstrInfo &TII, + unsigned Dim = 1); +}; + +/// Base class for MIR embedders +class MIREmbedder { +protected: + const MachineFunction &MF; + const MIRVocabulary &Vocab; + + /// Dimension of the embeddings; Captured from the vocabulary + const unsigned Dimension; + + /// Weight for opcode embeddings + const float OpcWeight; + + // Utility maps - these are used to store the vector representations of + // instructions, basic blocks and functions. + mutable Embedding MFuncVector; + mutable MachineBlockEmbeddingsMap MBBVecMap; + mutable MachineInstEmbeddingsMap MInstVecMap; + + MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab); + + /// Function to compute embeddings. It generates embeddings for all + /// the instructions and basic blocks in the machine function MF. + void computeEmbeddings() const; + + /// Function to compute the embedding for a given basic block. + /// Specific to the kind of embeddings being computed. + virtual void computeEmbeddings(const MachineBasicBlock &MBB) const = 0; + +public: + virtual ~MIREmbedder() = default; + + /// Factory method to create an Embedder object of the specified kind + /// Returns nullptr if the requested kind is not supported. + static std::unique_ptr<MIREmbedder> create(MIR2VecKind Mode, + const MachineFunction &MF, + const MIRVocabulary &Vocab); + + /// Returns a map containing machine instructions and the corresponding + /// embeddings for the machine function MF if it has been computed. If not, it + /// computes the embeddings for MF and returns the map. + const MachineInstEmbeddingsMap &getMInstVecMap() const; + + /// Returns a map containing machine basic block and the corresponding + /// embeddings for the machine function MF if it has been computed. If not, it + /// computes the embeddings for MF and returns the map. 
+ const MachineBlockEmbeddingsMap &getMBBVecMap() const; + + /// Returns the embedding for a given machine basic block in the machine + /// function MF if it has been computed. If not, it computes the embedding for + /// MBB and returns it. + const Embedding &getMBBVector(const MachineBasicBlock &MBB) const; + + /// Computes and returns the embedding for the current machine function. + const Embedding &getMFunctionVector() const; +}; + +/// Class for computing Symbolic embeddings +/// Symbolic embeddings are constructed based on the entity-level +/// representations obtained from the MIR Vocabulary. +class SymbolicMIREmbedder : public MIREmbedder { +private: + void computeEmbeddings(const MachineBasicBlock &MBB) const override; + +public: + SymbolicMIREmbedder(const MachineFunction &F, const MIRVocabulary &Vocab); + static std::unique_ptr<SymbolicMIREmbedder> + create(const MachineFunction &MF, const MIRVocabulary &Vocab); }; } // namespace mir2vec @@ -181,6 +264,31 @@ class MIR2VecVocabPrinterLegacyPass : public MachineFunctionPass { } }; +/// This pass prints the MIR2Vec embeddings for machine functions, basic blocks, +/// and instructions +class MIR2VecPrinterLegacyPass : public MachineFunctionPass { + raw_ostream &OS; + +public: + static char ID; + explicit MIR2VecPrinterLegacyPass(raw_ostream &OS) + : MachineFunctionPass(ID), OS(OS) {} + + bool runOnMachineFunction(MachineFunction &MF) override; + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<MIR2VecVocabLegacyAnalysis>(); + AU.setPreservesAll(); + MachineFunctionPass::getAnalysisUsage(AU); + } + + StringRef getPassName() const override { + return "MIR2Vec Embedder Printer Pass"; + } +}; + +/// Create a machine pass that prints MIR2Vec embeddings +MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS); + } // namespace llvm #endif // LLVM_CODEGEN_MIR2VEC_H \ No newline at end of file diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h 
index 272b4acf950c5..7fae550d8d170 100644 --- a/llvm/include/llvm/CodeGen/Passes.h +++ b/llvm/include/llvm/CodeGen/Passes.h @@ -93,6 +93,10 @@ createMachineFunctionPrinterPass(raw_ostream &OS, LLVM_ABI MachineFunctionPass * createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS); +/// MIR2VecPrinter pass - This pass prints out the MIR2Vec embeddings for +/// machine functions, basic blocks and instructions. +LLVM_ABI MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS); + /// StackFramePrinter pass - This pass prints out the machine function's /// stack frame to the given stream as a debugging tool. LLVM_ABI MachineFunctionPass *createStackFrameLayoutAnalysisPass(); diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h index cd774e7888e64..d507ba267d791 100644 --- a/llvm/include/llvm/InitializePasses.h +++ b/llvm/include/llvm/InitializePasses.h @@ -222,6 +222,7 @@ LLVM_ABI void initializeMachineSanitizerBinaryMetadataLegacyPass(PassRegistry &); LLVM_ABI void initializeMIR2VecVocabLegacyAnalysisPass(PassRegistry &); LLVM_ABI void initializeMIR2VecVocabPrinterLegacyPassPass(PassRegistry &); +LLVM_ABI void initializeMIR2VecPrinterLegacyPassPass(PassRegistry &); LLVM_ABI void initializeMachineSchedulerLegacyPass(PassRegistry &); LLVM_ABI void initializeMachineSinkingLegacyPass(PassRegistry &); LLVM_ABI void initializeMachineTraceMetricsWrapperPassPass(PassRegistry &); diff --git a/llvm/lib/CodeGen/CodeGen.cpp b/llvm/lib/CodeGen/CodeGen.cpp index c438eaeb29d1e..9795a0b707fd3 100644 --- a/llvm/lib/CodeGen/CodeGen.cpp +++ b/llvm/lib/CodeGen/CodeGen.cpp @@ -98,6 +98,7 @@ void llvm::initializeCodeGen(PassRegistry &Registry) { initializeMachineUniformityAnalysisPassPass(Registry); initializeMIR2VecVocabLegacyAnalysisPass(Registry); initializeMIR2VecVocabPrinterLegacyPassPass(Registry); + initializeMIR2VecPrinterLegacyPassPass(Registry); initializeMachineUniformityInfoPrinterPassPass(Registry); 
initializeMachineVerifierLegacyPassPass(Registry); initializeObjCARCContractLegacyPassPass(Registry); diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp index 87565c0c77115..ffd076331c7d2 100644 --- a/llvm/lib/CodeGen/MIR2Vec.cpp +++ b/llvm/lib/CodeGen/MIR2Vec.cpp @@ -41,11 +41,18 @@ static cl::opt<std::string> cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0), cl::desc("Weight for machine opcode embeddings"), cl::cat(MIR2VecCategory)); +cl::opt<MIR2VecKind> MIR2VecEmbeddingKind( + "mir2vec-kind", cl::Optional, + cl::values(clEnumValN(MIR2VecKind::Symbolic, "symbolic", + "Generate symbolic embeddings for MIR")), + cl::init(MIR2VecKind::Symbolic), cl::desc("MIR2Vec embedding kind"), + cl::cat(MIR2VecCategory)); + } // namespace mir2vec } // namespace llvm //===----------------------------------------------------------------------===// -// Vocabulary Implementation +// Vocabulary //===----------------------------------------------------------------------===// MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries, @@ -190,6 +197,29 @@ void MIRVocabulary::buildCanonicalOpcodeMapping() { << " unique base opcodes\n"); } +MIRVocabulary MIRVocabulary::createDummyVocabForTest(const TargetInstrInfo &TII, + unsigned Dim) { + assert(Dim > 0 && "Dimension must be greater than zero"); + + float DummyVal = 0.1f; + + // Create a temporary vocabulary instance to build canonical mapping + MIRVocabulary TempVocab({}, &TII); + TempVocab.buildCanonicalOpcodeMapping(); + + // Create dummy embeddings for all canonical opcode names + VocabMap DummyVocabMap; + for (const auto &COpcodeName : TempVocab.UniqueBaseOpcodeNames) { + // Create dummy embedding filled with DummyVal + Embedding DummyEmbedding(Dim, DummyVal); + DummyVocabMap[COpcodeName] = DummyEmbedding; + DummyVal += 0.1f; + } + + // Create and return vocabulary with dummy embeddings + return MIRVocabulary(std::move(DummyVocabMap), &TII); +} + 
//===----------------------------------------------------------------------===// // MIR2VecVocabLegacyAnalysis Implementation //===----------------------------------------------------------------------===// @@ -267,7 +297,104 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) { } //===----------------------------------------------------------------------===// -// Printer Passes Implementation +// MIREmbedder and its subclasses +//===----------------------------------------------------------------------===// + +MIREmbedder::MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab) + : MF(MF), Vocab(Vocab), Dimension(Vocab.getDimension()), + OpcWeight(::OpcWeight), MFuncVector(Embedding(Dimension)) {} + +std::unique_ptr<MIREmbedder> MIREmbedder::create(MIR2VecKind Mode, + const MachineFunction &MF, + const MIRVocabulary &Vocab) { + switch (Mode) { + case MIR2VecKind::Symbolic: + return std::make_unique<SymbolicMIREmbedder>(MF, Vocab); + } + return nullptr; +} + +const MachineInstEmbeddingsMap &MIREmbedder::getMInstVecMap() const { + if (MInstVecMap.empty()) + computeEmbeddings(); + return MInstVecMap; +} + +const MachineBlockEmbeddingsMap &MIREmbedder::getMBBVecMap() const { + if (MBBVecMap.empty()) + computeEmbeddings(); + return MBBVecMap; +} + +const Embedding &MIREmbedder::getMBBVector(const MachineBasicBlock &BB) const { + auto It = MBBVecMap.find(&BB); + if (It != MBBVecMap.end()) + return It->second; + computeEmbeddings(BB); + return MBBVecMap[&BB]; +} + +const Embedding &MIREmbedder::getMFunctionVector() const { + // Currently, we always (re)compute the embeddings for the function. + // This is cheaper than caching the vector. 
+ computeEmbeddings(); + return MFuncVector; +} + +void MIREmbedder::computeEmbeddings() const { + // Reset function vector to zero before recomputing + MFuncVector = Embedding(Dimension, 0.0); + + // Consider all machine basic blocks in the function + for (const auto &MBB : MF) { + computeEmbeddings(MBB); + MFuncVector += MBBVecMap[&MBB]; + } +} + +SymbolicMIREmbedder::SymbolicMIREmbedder(const MachineFunction &MF, + const MIRVocabulary &Vocab) + : MIREmbedder(MF, Vocab) {} + +std::unique_ptr<SymbolicMIREmbedder> +SymbolicMIREmbedder::create(const MachineFunction &MF, + const MIRVocabulary &Vocab) { + return std::make_unique<SymbolicMIREmbedder>(MF, Vocab); +} + +void SymbolicMIREmbedder::computeEmbeddings( + const MachineBasicBlock &MBB) const { + Embedding MBBVector(Dimension, 0); + + // Get instruction info for opcode name resolution + const auto &Subtarget = MF.getSubtarget(); + const auto *TII = Subtarget.getInstrInfo(); + if (!TII) { + MF.getFunction().getContext().emitError( + "MIR2Vec: No TargetInstrInfo available; cannot compute embeddings"); + return; + } + + // Process each machine instruction in the basic block + for (const auto &MI : MBB) { + // Skip debug instructions and other metadata + if (MI.isDebugInstr()) + continue; + + // Todo: Add operand/argument contributions + + // Store the instruction embedding + auto InstVector = Vocab[MI.getOpcode()]; + MInstVecMap[&MI] = InstVector; + MBBVector += InstVector; + } + + // Store the basic block embedding + MBBVecMap[&MBB] = MBBVector; +} + +//===----------------------------------------------------------------------===// +// Printer Passes //===----------------------------------------------------------------------===// char MIR2VecVocabPrinterLegacyPass::ID = 0; @@ -304,3 +431,67 @@ MachineFunctionPass * llvm::createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS) { return new MIR2VecVocabPrinterLegacyPass(OS); } + +char MIR2VecPrinterLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(MIR2VecPrinterLegacyPass, 
"print-mir2vec", + "MIR2Vec Embedder Printer Pass", false, true) +INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis) +INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass) +INITIALIZE_PASS_END(MIR2VecPrinterLegacyPass, "print-mir2vec", + "MIR2Vec Embedder Printer Pass", false, true) + +bool MIR2VecPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) { + auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>(); + auto MIRVocab = Analysis.getMIR2VecVocabulary(*MF.getFunction().getParent()); + + if (!MIRVocab.isValid()) { + OS << "MIR2Vec Embedder Printer: Invalid vocabulary for function " + << MF.getName() << "\n"; + return false; + } + + auto Emb = mir2vec::MIREmbedder::create(MIR2VecEmbeddingKind, MF, MIRVocab); + if (!Emb) { + OS << "Error creating MIR2Vec embeddings for function " << MF.getName() + << "\n"; + return false; + } + + OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n"; + OS << "Machine Function vector: "; + Emb->getMFunctionVector().print(OS); + + OS << "Machine basic block vectors:\n"; + const auto &MBBMap = Emb->getMBBVecMap(); + for (const MachineBasicBlock &MBB : MF) { + auto It = MBBMap.find(&MBB); + if (It != MBBMap.end()) { + OS << "Machine basic block: " << MBB.getFullName() << ":\n"; + It->second.print(OS); + } + } + + OS << "Machine instruction vectors:\n"; + const auto &MInstMap = Emb->getMInstVecMap(); + for (const MachineBasicBlock &MBB : MF) { + for (const MachineInstr &MI : MBB) { + // Skip debug instructions as they are not + // embedded + if (MI.isDebugInstr()) + continue; + + auto It = MInstMap.find(&MI); + if (It != MInstMap.end()) { + OS << "Machine instruction: "; + MI.print(OS); + It->second.print(OS); + } + } + } + + return false; +} + +MachineFunctionPass *llvm::createMIR2VecPrinterLegacyPass(raw_ostream &OS) { + return new MIR2VecPrinterLegacyPass(OS); +} diff --git a/llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json 
b/llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json new file mode 100644 index 0000000000000..5de715bf80917 --- /dev/null +++ b/llvm/test/CodeGen/MIR2Vec/Inputs/mir2vec_dummy_3D_vocab.json @@ -0,0 +1,22 @@ +{ + "entities": { + "KILL": [0.1, 0.2, 0.3], + "MOV": [0.4, 0.5, 0.6], + "LEA": [0.7, 0.8, 0.9], + "RET": [1.0, 1.1, 1.2], + "ADD": [1.3, 1.4, 1.5], + "SUB": [1.6, 1.7, 1.8], + "IMUL": [1.9, 2.0, 2.1], + "AND": [2.2, 2.3, 2.4], + "OR": [2.5, 2.6, 2.7], + "XOR": [2.8, 2.9, 3.0], + "CMP": [3.1, 3.2, 3.3], + "TEST": [3.4, 3.5, 3.6], + "JMP": [3.7, 3.8, 3.9], + "CALL": [4.0, 4.1, 4.2], + "PUSH": [4.3, 4.4, 4.5], + "POP": [4.6, 4.7, 4.8], + "NOP": [4.9, 5.0, 5.1], + "COPY": [5.2, 5.3, 5.4] + } +} \ No newline at end of file diff --git a/llvm/test/CodeGen/MIR2Vec/if-else.mir b/llvm/test/CodeGen/MIR2Vec/if-else.mir new file mode 100644 index 0000000000000..2accf476f7c4d --- /dev/null +++ b/llvm/test/CodeGen/MIR2Vec/if-else.mir @@ -0,0 +1,144 @@ +# REQUIRES: x86_64-linux +# RUN: llc -run-pass=none -print-mir2vec -mir2vec-vocab-path=%S/Inputs/mir2vec_dummy_3D_vocab.json %s -o /dev/null 2>&1 | FileCheck %s + +--- | + target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128" + + define dso_local i32 @abc(i32 noundef %a, i32 noundef %b) { + entry: + %retval = alloca i32, align 4 + %a.addr = alloca i32, align 4 + %b.addr = alloca i32, align 4 + store i32 %a, ptr %a.addr, align 4 + store i32 %b, ptr %b.addr, align 4 + %0 = load i32, ptr %a.addr, align 4 + %1 = load i32, ptr %b.addr, align 4 + %cmp = icmp sgt i32 %0, %1 + br i1 %cmp, label %if.then, label %if.else + + if.then: ; preds = %entry + %2 = load i32, ptr %b.addr, align 4 + store i32 %2, ptr %retval, align 4 + br label %return + + if.else: ; preds = %entry + %3 = load i32, ptr %a.addr, align 4 + store i32 %3, ptr %retval, align 4 + br label %return + + return: ; preds = %if.else, %if.then + %4 = load i32, ptr %retval, align 4 + ret i32 %4 + } +... 
+--- +name: abc +alignment: 16 +exposesReturnsTwice: false +legalized: false +regBankSelected: false +selected: false +failedISel: false +tracksRegLiveness: true +hasWinCFI: false +noPhis: false +isSSA: true +noVRegs: false +hasFakeUses: false +callsEHReturn: false +callsUnwindInit: false +hasEHContTarget: false +hasEHScopes: false +hasEHFunclets: false +isOutlined: false +debugInstrRef: true +failsVerification: false +tracksDebugUserValues: false +registers: + - { id: 0, class: gr32, preferred-register: '', flags: [ ] } + - { id: 1, class: gr32, preferred-register: '', flags: [ ] } + - { id: 2, class: gr32, preferred-register: '', flags: [ ] } + - { id: 3, class: gr32, preferred-register: '', flags: [ ] } + - { id: 4, class: gr32, preferred-register: '', flags: [ ] } + - { id: 5, class: gr32, preferred-register: '', flags: [ ] } +liveins: + - { reg: '$edi', virtual-reg: '%0' } + - { reg: '$esi', virtual-reg: '%1' } +frameInfo: + isFrameAddressTaken: false + isReturnAddressTaken: false + hasStackMap: false + hasPatchPoint: false + stackSize: 0 + offsetAdjustment: 0 + maxAlignment: 4 + adjustsStack: false + hasCalls: false + stackProtector: '' + functionContext: '' + maxCallFrameSize: 4294967295 + cvBytesOfCalleeSavedRegisters: 0 + hasOpaqueSPAdjustment: false + hasVAStart: false + hasMustTailInVarArgFunc: false + hasTailCall: false + isCalleeSavedInfoValid: false + localFrameSize: 0 +fixedStack: [] +stack: + - { id: 0, name: retval, type: default, offset: 0, size: 4, alignment: 4, + stack-id: default, callee-saved-register: '', callee-saved-restored: true, + debug-info-variable: '', debug-info-expression: '', debug-info-location: '' } + - { id: 1, name: a.addr, type: default, offset: 0, size: 4, alignment: 4, + stack-id: default, callee-saved-register: '', callee-saved-restored: true, + debug-info-variable: '', debug-info-expression: '', debug-info-location: '' } + - { id: 2, name: b.addr, type: default, offset: 0, size: 4, alignment: 4, + stack-id: default, 
callee-saved-register: '', callee-saved-restored: true, + debug-info-variable: '', debug-info-expression: '', debug-info-location: '' } +entry_values: [] +callSites: [] +debugValueSubstitutions: [] +constants: [] +machineFunctionInfo: + amxProgModel: None +body: | + bb.0.entry: + successors: %bb.1(0x40000000), %bb.2(0x40000000) + liveins: $edi, $esi + + %1:gr32 = COPY $esi + %0:gr32 = COPY $edi + MOV32mr %stack.1.a.addr, 1, $noreg, 0, $noreg, %0 :: (store (s32) into %ir.a.addr) + MOV32mr %stack.2.b.addr, 1, $noreg, 0, $noreg, %1 :: (store (s32) into %ir.b.addr) + %2:gr32 = SUB32rr %0, %1, implicit-def $eflags + JCC_1 %bb.2, 14, implicit $eflags + JMP_1 %bb.1 + + bb.1.if.then: + successors: %bb.3(0x80000000) + + %4:gr32 = MOV32rm %stack.2.b.addr, 1, $noreg, 0, $noreg :: (dereferenceable load (s32) from %ir.b.addr) + MOV32mr %stack.0.retval, 1, $noreg, 0, $noreg, killed %4 :: (store (s32) into %ir.retval) + JMP_1 %bb.3 + + bb.2.if.else: + successors: %bb.3(0x80000000) + + %3:gr32 = MOV32rm %stack.1.a.addr, 1, $noreg, 0, $noreg :: (dereferenceable load (s32) from %ir.a.addr) + MOV32mr %stack.0.retval, 1, $noreg, 0, $noreg, killed %3 :: (store (s32) into %ir.retval) + + bb.3.return: + %5:gr32 = MOV32rm %stack.0.retval, 1, $noreg, 0, $noreg :: (dereferenceable load (s32) from %ir.retval) + $eax = COPY %5 + RET 0, $eax +... 
+ +# CHECK: Machine basic block vectors: +# CHECK-NEXT: Machine basic block: abc:entry: +# CHECK-NEXT: [ 16.50 17.10 17.70 ] +# CHECK-NEXT: Machine basic block: abc:if.then: +# CHECK-NEXT: [ 4.50 4.80 5.10 ] +# CHECK-NEXT: Machine basic block: abc:if.else: +# CHECK-NEXT: [ 0.80 1.00 1.20 ] +# CHECK-NEXT: Machine basic block: abc:return: +# CHECK-NEXT: [ 6.60 6.90 7.20 ] \ No newline at end of file diff --git a/llvm/test/CodeGen/MIR2Vec/mir2vec-basic-symbolic.mir b/llvm/test/CodeGen/MIR2Vec/mir2vec-basic-symbolic.mir new file mode 100644 index 0000000000000..44240affb2206 --- /dev/null +++ b/llvm/test/CodeGen/MIR2Vec/mir2vec-basic-symbolic.mir @@ -0,0 +1,76 @@ +# REQUIRES: x86_64-linux +# RUN: llc -run-pass=none -print-mir2vec -mir2vec-vocab-path=%S/Inputs/mir2vec_dummy_3D_vocab.json %s -o /dev/null 2>&1 | FileCheck %s + +--- | + target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128" + + define dso_local noundef i32 @add_function(i32 noundef %a, i32 noundef %b) { + entry: + %sum = add nsw i32 %a, %b + %result = mul nsw i32 %sum, 2 + ret i32 %result + } + + define dso_local void @simple_function() { + entry: + ret void + } +... 
+--- +name: add_function +alignment: 16 +tracksRegLiveness: true +registers: + - { id: 0, class: gr32 } + - { id: 1, class: gr32 } + - { id: 2, class: gr32 } + - { id: 3, class: gr32 } +liveins: + - { reg: '$edi', virtual-reg: '%0' } + - { reg: '$esi', virtual-reg: '%1' } +body: | + bb.0.entry: + liveins: $edi, $esi + + %1:gr32 = COPY $esi + %0:gr32 = COPY $edi + %2:gr32 = nsw ADD32rr %0, %1, implicit-def dead $eflags + %3:gr32 = ADD32rr %2, %2, implicit-def dead $eflags + $eax = COPY %3 + RET 0, $eax + +--- +name: simple_function +alignment: 16 +tracksRegLiveness: true +body: | + bb.0.entry: + RET 0 + +# CHECK: MIR2Vec embeddings for machine function add_function: +# CHECK: Function vector: [ 19.20 19.80 20.40 ] +# CHECK-NEXT: Machine basic block vectors: +# CHECK-NEXT: Machine basic block: add_function:entry: +# CHECK-NEXT: [ 19.20 19.80 20.40 ] +# CHECK-NEXT: Machine instruction vectors: +# CHECK-NEXT: Machine instruction: %1:gr32 = COPY $esi +# CHECK-NEXT: [ 5.20 5.30 5.40 ] +# CHECK-NEXT: Machine instruction: %0:gr32 = COPY $edi +# CHECK-NEXT: [ 5.20 5.30 5.40 ] +# CHECK-NEXT: Machine instruction: %2:gr32 = nsw ADD32rr %0:gr32(tied-def 0), %1:gr32, implicit-def dead $eflags +# CHECK-NEXT: [ 1.30 1.40 1.50 ] +# CHECK-NEXT: Machine instruction: %3:gr32 = ADD32rr %2:gr32(tied-def 0), %2:gr32, implicit-def dead $eflags +# CHECK-NEXT: [ 1.30 1.40 1.50 ] +# CHECK-NEXT: Machine instruction: $eax = COPY %3:gr32 +# CHECK-NEXT: [ 5.20 5.30 5.40 ] +# CHECK-NEXT: Machine instruction: RET 0, $eax +# CHECK-NEXT: [ 1.00 1.10 1.20 ] + +# CHECK: MIR2Vec embeddings for machine function simple_function: +# CHECK-NEXT:Function vector: [ 1.00 1.10 1.20 ] +# CHECK-NEXT: Machine basic block vectors: +# CHECK-NEXT: Machine basic block: simple_function:entry: +# CHECK-NEXT: [ 1.00 1.10 1.20 ] +# CHECK-NEXT: Machine instruction vectors: +# CHECK-NEXT: Machine instruction: RET 0 +# CHECK-NEXT: [ 1.00 1.10 1.20 ] \ No newline at end of file diff --git a/llvm/tools/llc/llc.cpp 
b/llvm/tools/llc/llc.cpp index 7551a80fbd5d8..d85b632fa7c3d 100644 --- a/llvm/tools/llc/llc.cpp +++ b/llvm/tools/llc/llc.cpp @@ -171,6 +171,11 @@ static cl::opt<bool> cl::desc("Print MIR2Vec vocabulary contents"), cl::init(false)); +static cl::opt<bool> + PrintMIR2Vec("print-mir2vec", cl::Hidden, + cl::desc("Print MIR2Vec embeddings for functions"), + cl::init(false)); + static cl::list<std::string> IncludeDirs("I", cl::desc("include search path")); static cl::opt<bool> RemarksWithHotness( @@ -736,6 +741,11 @@ static int compileModule(char **argv, LLVMContext &Context) { PM.add(createMIR2VecVocabPrinterLegacyPass(errs())); } + // Add MIR2Vec printer if requested + if (PrintMIR2Vec) { + PM.add(createMIR2VecPrinterLegacyPass(errs())); + } + PM.add(createFreeMachineFunctionPass()); } else { if (Target->addPassesToEmitFile(PM, *OS, DwoOut ? &DwoOut->os() : nullptr, @@ -749,6 +759,11 @@ static int compileModule(char **argv, LLVMContext &Context) { if (PrintMIR2VecVocab) { PM.add(createMIR2VecVocabPrinterLegacyPass(errs())); } + + // Add MIR2Vec printer if requested + if (PrintMIR2Vec) { + PM.add(createMIR2VecPrinterLegacyPass(errs())); + } } Target->getObjFileLowering()->Initialize(MMIWP->getMMI().getContext(), diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp index 7b282c7abe68c..c7e8f339cf90a 100644 --- a/llvm/unittests/CodeGen/MIR2VecTest.cpp +++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp @@ -76,6 +76,9 @@ class MIR2VecVocabTestFixture : public ::testing::Test { M->getTargetTriple(), "", "", Options, Reloc::Model::Static)); ASSERT_TRUE(TM); + // Set the data layout to match the target machine + M->setDataLayout(TM->createDataLayout()); + // Create a dummy function to get subtarget info FunctionType *FT = FunctionType::get(Type::getVoidTy(*Ctx), false); Function *F = @@ -85,16 +88,32 @@ class MIR2VecVocabTestFixture : public ::testing::Test { TII = TM->getSubtargetImpl(*F)->getInstrInfo(); ASSERT_TRUE(TII); } -}; -// Function 
to find an opcode by name -static int findOpcodeByName(const TargetInstrInfo *TII, StringRef Name) { - for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) { - if (TII->getName(Opcode) == Name) - return Opcode; + // Find an opcode by name + int findOpcodeByName(StringRef Name) { + for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) { + if (TII->getName(Opcode) == Name) + return Opcode; + } + return -1; // Not found + } + + // Create a vocabulary with specific opcodes and embeddings + MIRVocabulary + createTestVocab(std::initializer_list<std::pair<const char *, float>> opcodes, + unsigned dimension = 2) { + VocabMap VMap; + for (const auto &[name, value] : opcodes) + VMap[name] = Embedding(dimension, value); + return MIRVocabulary(std::move(VMap), TII); } - return -1; // Not found -} + + // Create empty/invalid vocabulary + MIRVocabulary createEmptyVocab() { + VocabMap EmptyVMap; + return MIRVocabulary(std::move(EmptyVMap), TII); + } +}; TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { // Test that same base opcodes get same canonical indices @@ -107,10 +126,8 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { // Create a MIRVocabulary instance to test the mapping // Use a minimal MIRVocabulary to trigger canonical mapping construction - VocabMap VMap; Embedding Val = Embedding(64, 1.0f); - VMap["ADD"] = Val; - MIRVocabulary TestVocab(std::move(VMap), TII); + auto TestVocab = createTestVocab({{"ADD", 1.0f}}, 64); unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1); unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2); @@ -141,16 +158,16 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { 6880u); // X86 has >6880 unique base opcodes // Check that the embeddings for opcodes not in the vocab are zero vectors - int Add32rrOpcode = findOpcodeByName(TII, "ADD32rr"); + int Add32rrOpcode = findOpcodeByName("ADD32rr"); ASSERT_NE(Add32rrOpcode, -1) << "ADD32rr opcode not 
found"; EXPECT_TRUE(TestVocab[Add32rrOpcode].approximatelyEquals(Val)); - int Sub32rrOpcode = findOpcodeByName(TII, "SUB32rr"); + int Sub32rrOpcode = findOpcodeByName("SUB32rr"); ASSERT_NE(Sub32rrOpcode, -1) << "SUB32rr opcode not found"; EXPECT_TRUE( TestVocab[Sub32rrOpcode].approximatelyEquals(Embedding(64, 0.0f))); - int Mov32rrOpcode = findOpcodeByName(TII, "MOV32rr"); + int Mov32rrOpcode = findOpcodeByName("MOV32rr"); ASSERT_NE(Mov32rrOpcode, -1) << "MOV32rr opcode not found"; EXPECT_TRUE( TestVocab[Mov32rrOpcode].approximatelyEquals(Embedding(64, 0.0f))); @@ -163,15 +180,11 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) { // Create a MIRVocabulary instance to test deterministic mapping // Use a minimal MIRVocabulary to trigger canonical mapping construction - VocabMap VMap; - VMap["ADD"] = Embedding(64, 1.0f); - MIRVocabulary TestVocab(std::move(VMap), TII); + auto TestVocab = createTestVocab({{"ADD", 1.0f}}, 64); unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName); unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName); unsigned Index3 = TestVocab.getCanonicalIndexForBaseName(BaseName); - - EXPECT_EQ(Index1, Index2); EXPECT_EQ(Index2, Index3); // Test across multiple runs @@ -183,11 +196,8 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) { // Test MIRVocabulary construction TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) { - VocabMap VMap; - VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0 - VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0 - - MIRVocabulary Vocab(std::move(VMap), TII); + // Test MIRVocabulary with embeddings via VocabMap + auto Vocab = createTestVocab({{"ADD", 1.0f}, {"SUB", 2.0f}}, 128); EXPECT_TRUE(Vocab.isValid()); EXPECT_EQ(Vocab.getDimension(), 128u); @@ -206,4 +216,268 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) { EXPECT_GT(Count, 0u); } -} // namespace \ No newline at end of file +// Test embedder with invalid vocabulary 
+TEST_F(MIR2VecVocabTestFixture, InvalidVocabulary) { + // Create an invalid vocabulary (empty) + auto InvalidVocab = createEmptyVocab(); + + // Verify vocabulary is invalid + EXPECT_FALSE(InvalidVocab.isValid()); + EXPECT_EQ(InvalidVocab.getDimension(), 0u); +} + +// Fixture for embedding related tests +class MIR2VecEmbeddingTestFixture : public MIR2VecVocabTestFixture { +protected: + std::unique_ptr<MachineModuleInfo> MMI; + MachineFunction *MF = nullptr; + + void SetUp() override { + MIR2VecVocabTestFixture::SetUp(); + + // Create a dummy function for MachineFunction + FunctionType *FT = FunctionType::get(Type::getVoidTy(*Ctx), false); + Function *F = + Function::Create(FT, Function::ExternalLinkage, "test", M.get()); + + MMI = std::make_unique<MachineModuleInfo>(TM.get()); + MF = &MMI->getOrCreateMachineFunction(*F); + } + + void TearDown() override { MIR2VecVocabTestFixture::TearDown(); } + + // Create a machine instruction + MachineInstr *createMachineInstr(MachineBasicBlock &MBB, unsigned Opcode) { + const MCInstrDesc &Desc = TII->get(Opcode); + // Create instruction - operands don't affect opcode-based embeddings + MachineInstr *MI = BuildMI(MBB, MBB.end(), DebugLoc(), Desc); + return MI; + } + + MachineInstr *createMachineInstr(MachineBasicBlock &MBB, + const char *OpcodeName) { + int Opcode = findOpcodeByName(OpcodeName); + if (Opcode == -1) + return nullptr; + return createMachineInstr(MBB, Opcode); + } + + void createMachineInstrs(MachineBasicBlock &MBB, + std::initializer_list<const char *> Opcodes) { + for (const char *OpcodeName : Opcodes) { + MachineInstr *MI = createMachineInstr(MBB, OpcodeName); + ASSERT_TRUE(MI != nullptr); + } + } +}; + +// Test factory method for creating embedder +TEST_F(MIR2VecEmbeddingTestFixture, CreateSymbolicEmbedder) { + auto V = MIRVocabulary::createDummyVocabForTest(*TII, 1); + auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, *MF, V); + EXPECT_NE(Emb, nullptr); +} + +TEST_F(MIR2VecEmbeddingTestFixture, 
CreateInvalidMode) { + auto V = MIRVocabulary::createDummyVocabForTest(*TII, 1); + + // static_cast an invalid int to MIR2VecKind + auto Result = MIREmbedder::create(static_cast<MIR2VecKind>(-1), *MF, V); + EXPECT_FALSE(static_cast<bool>(Result)); +} + +// Test SymbolicMIREmbedder with simple target opcodes +TEST_F(MIR2VecEmbeddingTestFixture, TestSymbolicEmbedder) { + // Create a test vocabulary with specific values + auto Vocab = createTestVocab( + { + {"NOOP", 1.0f}, // [1.0, 1.0, 1.0, 1.0] + {"RET", 2.0f}, // [2.0, 2.0, 2.0, 2.0] + {"TRAP", 3.0f} // [3.0, 3.0, 3.0, 3.0] + }, + 4); + + // Create a basic block using fixture's MF + MachineBasicBlock *MBB = MF->CreateMachineBasicBlock(); + MF->push_back(MBB); + + // Use real X86 opcodes that should exist and not be pseudo + auto NoopInst = createMachineInstr(*MBB, "NOOP"); + ASSERT_TRUE(NoopInst != nullptr); + + auto RetInst = createMachineInstr(*MBB, "RET64"); + ASSERT_TRUE(RetInst != nullptr); + + auto TrapInst = createMachineInstr(*MBB, "TRAP"); + ASSERT_TRUE(TrapInst != nullptr); + + // Verify these are not pseudo instructions + ASSERT_FALSE(NoopInst->isPseudo()) << "NOOP is marked as pseudo instruction"; + ASSERT_FALSE(RetInst->isPseudo()) << "RET is marked as pseudo instruction"; + ASSERT_FALSE(TrapInst->isPseudo()) << "TRAP is marked as pseudo instruction"; + + // Create embedder + auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab); + ASSERT_TRUE(Embedder != nullptr); + + // Test instruction embeddings + const auto &MInstMap = Embedder->getMInstVecMap(); + EXPECT_EQ(MInstMap.size(), 3u); // Should have 3 instructions + + // Check individual instruction embeddings + auto NoopIt = MInstMap.find(NoopInst); + auto RetIt = MInstMap.find(RetInst); + auto TrapIt = MInstMap.find(TrapInst); + + ASSERT_NE(NoopIt, MInstMap.end()); + ASSERT_NE(RetIt, MInstMap.end()); + ASSERT_NE(TrapIt, MInstMap.end()); + + // Verify embeddings match expected values (accounting for weight scaling) + float ExpectedWeight = 
::OpcWeight; // Global weight from command line + EXPECT_TRUE( + NoopIt->second.approximatelyEquals(Embedding(4, 1.0f * ExpectedWeight))); + EXPECT_TRUE( + RetIt->second.approximatelyEquals(Embedding(4, 2.0f * ExpectedWeight))); + EXPECT_TRUE( + TrapIt->second.approximatelyEquals(Embedding(4, 3.0f * ExpectedWeight))); + + // Test basic block embedding (should be sum of instruction embeddings) + const auto &MBBMap = Embedder->getMBBVecMap(); + EXPECT_EQ(MBBMap.size(), 1u); + + auto MBBIt = MBBMap.find(MBB); + ASSERT_NE(MBBIt, MBBMap.end()); + + // Expected BB vector: NOOP + RET + TRAP = [1+2+3, 1+2+3, 1+2+3, 1+2+3] * + // weight = [6, 6, 6, 6] * weight + Embedding ExpectedMBBVector(4, 6.0f * ExpectedWeight); + EXPECT_TRUE(MBBIt->second.approximatelyEquals(ExpectedMBBVector)); + + // Test function embedding (should equal MBB embedding since we have one MBB) + const Embedding &MFuncVector = Embedder->getMFunctionVector(); + EXPECT_TRUE(MFuncVector.approximatelyEquals(ExpectedMBBVector)); +} + +// Test embedder with multiple basic blocks +TEST_F(MIR2VecEmbeddingTestFixture, MultipleBasicBlocks) { + // Create a test vocabulary + auto Vocab = createTestVocab({{"NOOP", 1.0f}, {"TRAP", 2.0f}}); + + // Create two basic blocks using fixture's MF + MachineBasicBlock *MBB1 = MF->CreateMachineBasicBlock(); + MachineBasicBlock *MBB2 = MF->CreateMachineBasicBlock(); + MF->push_back(MBB1); + MF->push_back(MBB2); + + createMachineInstrs(*MBB1, {"NOOP", "NOOP"}); + createMachineInstr(*MBB2, "TRAP"); + + // Create embedder + auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab); + ASSERT_TRUE(Embedder != nullptr); + + // Test basic block embeddings + const auto &MBBMap = Embedder->getMBBVecMap(); + EXPECT_EQ(MBBMap.size(), 2u); + + auto MBB1It = MBBMap.find(MBB1); + auto MBB2It = MBBMap.find(MBB2); + ASSERT_NE(MBB1It, MBBMap.end()); + ASSERT_NE(MBB2It, MBBMap.end()); + + float ExpectedWeight = ::OpcWeight; + // BB1: NOOP + NOOP = 2 * ([1, 1] * weight) + Embedding 
ExpectedBB1Vector(2, 2.0f * ExpectedWeight); + EXPECT_TRUE(MBB1It->second.approximatelyEquals(ExpectedBB1Vector)); + + // BB2: TRAP = [2, 2] * weight + Embedding ExpectedBB2Vector(2, 2.0f * ExpectedWeight); + EXPECT_TRUE(MBB2It->second.approximatelyEquals(ExpectedBB2Vector)); + + // Function embedding: BB1 + BB2 = [2+2, 2+2] * weight = [4, 4] * weight + const Embedding &MFVector = Embedder->getMFunctionVector(); + Embedding ExpectedFuncVector(2, 4.0f * ExpectedWeight); + EXPECT_TRUE(MFVector.approximatelyEquals(ExpectedFuncVector)); +} + +// Test embedder with empty basic block +TEST_F(MIR2VecEmbeddingTestFixture, EmptyBasicBlock) { + + // Create an empty basic block + MachineBasicBlock *MBB = MF->CreateMachineBasicBlock(); + MF->push_back(MBB); + + // Create embedder + auto Vocab = MIRVocabulary::createDummyVocabForTest(*TII, 2); + auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab); + ASSERT_TRUE(Embedder != nullptr); + + // Test that empty BB has zero embedding + const auto &MBBMap = Embedder->getMBBVecMap(); + EXPECT_EQ(MBBMap.size(), 1u); + + auto MBBIt = MBBMap.find(MBB); + ASSERT_NE(MBBIt, MBBMap.end()); + + // Empty BB should have zero embedding + Embedding ExpectedBBVector(2, 0.0f); + EXPECT_TRUE(MBBIt->second.approximatelyEquals(ExpectedBBVector)); + + // Function embedding should also be zero + const Embedding &MFVector = Embedder->getMFunctionVector(); + EXPECT_TRUE(MFVector.approximatelyEquals(ExpectedBBVector)); +} + +// Test embedder with opcodes not in vocabulary +TEST_F(MIR2VecEmbeddingTestFixture, UnknownOpcodes) { + // Create a test vocabulary with limited entries + // SUB is intentionally not included + auto Vocab = createTestVocab({{"ADD", 1.0f}}); + + // Create a basic block + MachineBasicBlock *MBB = MF->CreateMachineBasicBlock(); + MF->push_back(MBB); + + // Find opcodes + int AddOpcode = findOpcodeByName("ADD32rr"); + int SubOpcode = findOpcodeByName("SUB32rr"); + + ASSERT_NE(AddOpcode, -1) << "ADD32rr opcode not found"; + 
ASSERT_NE(SubOpcode, -1) << "SUB32rr opcode not found"; + + // Create instructions + MachineInstr *AddInstr = createMachineInstr(*MBB, AddOpcode); + MachineInstr *SubInstr = createMachineInstr(*MBB, SubOpcode); + + // Create embedder + auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab); + ASSERT_TRUE(Embedder != nullptr); + + // Test instruction embeddings + const auto &MIMap = Embedder->getMInstVecMap(); + + auto AddIt = MIMap.find(AddInstr); + auto SubIt = MIMap.find(SubInstr); + + ASSERT_NE(AddIt, MIMap.end()); + ASSERT_NE(SubIt, MIMap.end()); + + float ExpectedWeight = ::OpcWeight; + // ADD should have the embedding from vocabulary + EXPECT_TRUE( + AddIt->second.approximatelyEquals(Embedding(2, 1.0f * ExpectedWeight))); + + // SUB should have zero embedding (not in vocabulary) + EXPECT_TRUE(SubIt->second.approximatelyEquals(Embedding(2, 0.0f))); + + // Basic block embedding should be ADD + SUB = [1.0, 1.0] * weight + [0.0, + // 0.0] = [1.0, 1.0] * weight + const auto &MBBMap = Embedder->getMBBVecMap(); + auto MBBIt = MBBMap.find(MBB); + ASSERT_NE(MBBIt, MBBMap.end()); + + Embedding ExpectedBBVector(2, 1.0f * ExpectedWeight); + EXPECT_TRUE(MBBIt->second.approximatelyEquals(ExpectedBBVector)); +} +} // namespace _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
