llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlgo Author: S. VenkataKeerthy (svkeerthy) <details> <summary>Changes</summary> Initialize `Embedding` vectors with zeros by default when only size is provided. --- Full diff: https://github.com/llvm/llvm-project/pull/155690.diff 2 Files Affected: - (modified) llvm/include/llvm/Analysis/IR2Vec.h (+1-1) - (modified) llvm/lib/Analysis/IR2Vec.cpp (+4-4) ``````````diff diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h index 44932a3385e16..6fb8f736da092 100644 --- a/llvm/include/llvm/Analysis/IR2Vec.h +++ b/llvm/include/llvm/Analysis/IR2Vec.h @@ -92,7 +92,7 @@ struct Embedding { Embedding(std::vector<double> &&V) : Data(std::move(V)) {} Embedding(std::initializer_list<double> IL) : Data(IL) {} - explicit Embedding(size_t Size) : Data(Size) {} + explicit Embedding(size_t Size) : Data(Size, 0.0) {} Embedding(size_t Size, double InitialValue) : Data(Size, InitialValue) {} size_t size() const { return Data.size(); } diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index 565ec2a6287b7..6b90f1aabacfa 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -155,7 +155,7 @@ void Embedding::print(raw_ostream &OS) const { Embedder::Embedder(const Function &F, const Vocabulary &Vocab) : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()), OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight), - FuncVector(Embedding(Dimension, 0)) {} + FuncVector(Embedding(Dimension)) {} std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab) { @@ -472,7 +472,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { // Handle Opcodes std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes, - Embedding(Dim, 0)); + Embedding(Dim)); NumericOpcodeEmbeddings.reserve(Vocabulary::MaxOpcodes); for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) { StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1); @@ -487,7 +487,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { // Handle Types - only canonical types are present in vocabulary std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs, - Embedding(Dim, 0)); + Embedding(Dim)); NumericTypeEmbeddings.reserve(Vocabulary::MaxCanonicalTypeIDs); for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) { StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID( @@ -503,7 +503,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { // Handle Arguments/Operands std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds, - Embedding(Dim, 0)); + Embedding(Dim)); NumericArgEmbeddings.reserve(Vocabulary::MaxOperandKinds); for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) { Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind); `````````` </details> https://github.com/llvm/llvm-project/pull/155690 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits