zhaih commented on code in PR #12050:
URL: https://github.com/apache/lucene/pull/12050#discussion_r1061007997


##########
lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java:
##########
@@ -461,6 +467,126 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState 
mergeState) throws IOE
     }
   }
 
+  private void maybeInitializeFromGraph(
+      HnswGraphBuilder<?> hnswGraphBuilder, MergeState mergeState, FieldInfo 
fieldInfo)
+      throws IOException {
+    int initializerIndex = selectGraphForInitialization(mergeState, fieldInfo);
+    if (initializerIndex == -1) {
+      return;
+    }
+
+    HnswGraph initializerGraph =
+        getHnswGraphFromReader(fieldInfo.name, 
mergeState.knnVectorsReaders[initializerIndex]);
+    Map<Integer, Integer> ordinalMapper =
+        getOldToNewOrdinalMap(mergeState, fieldInfo, initializerIndex);
+    hnswGraphBuilder.initializeFromGraph(initializerGraph, ordinalMapper);
+  }
+
+  private int selectGraphForInitialization(MergeState mergeState, FieldInfo 
fieldInfo)
+      throws IOException {
+    // Find the KnnVectorReader with the most docs that meets the following 
criteria:
+    //  1. Does not contain any deleted docs
+    //  2. Is a Lucene95HnswVectorsReader/PerFieldKnnVectorReader
+    // If no readers exist that meet this criteria, return -1. If they do, 
return their index in
+    // merge state
+    int maxCandidateVectorCount = 0;
+    int initializerIndex = -1;
+
+    for (int i = 0; i < mergeState.liveDocs.length; i++) {
+      KnnVectorsReader currKnnVectorsReader = mergeState.knnVectorsReaders[i];
+      if (mergeState.knnVectorsReaders[i]
+          instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
+        currKnnVectorsReader = candidateReader.getFieldReader(fieldInfo.name);
+      }
+
+      if (!allMatch(mergeState.liveDocs[i])
+          || !(currKnnVectorsReader instanceof Lucene95HnswVectorsReader 
candidateReader)) {
+        continue;
+      }
+
+      VectorValues vectorValues = 
candidateReader.getVectorValues(fieldInfo.name);
+      if (vectorValues == null) {
+        continue;
+      }
+
+      int candidateVectorCount = vectorValues.size();
+      if (candidateVectorCount > maxCandidateVectorCount) {
+        maxCandidateVectorCount = candidateVectorCount;
+        initializerIndex = i;
+      }
+    }
+    return initializerIndex;
+  }
+
+  private HnswGraph getHnswGraphFromReader(String fieldName, KnnVectorsReader 
knnVectorsReader)
+      throws IOException {
+    if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader 
perFieldReader
+        && perFieldReader.getFieldReader(fieldName)
+            instanceof Lucene95HnswVectorsReader fieldReader) {
+      return fieldReader.getGraph(fieldName);
+    }
+
+    if (knnVectorsReader instanceof Lucene95HnswVectorsReader) {
+      return ((Lucene95HnswVectorsReader) 
knnVectorsReader).getGraph(fieldName);
+    }
+
+    throw new IllegalArgumentException(
+        "Invalid KnnVectorsReader. Must be of type 
PerFieldKnnVectorsFormat.FieldsReader or Lucene94HnswVectorsReader");
+  }
+
+  private Map<Integer, Integer> getOldToNewOrdinalMap(
+      MergeState mergeState, FieldInfo fieldInfo, int initializerIndex) throws 
IOException {
+    VectorValues initializerVectorValues =
+        
mergeState.knnVectorsReaders[initializerIndex].getVectorValues(fieldInfo.name);
+    MergeState.DocMap initializerDocMap = mergeState.docMaps[initializerIndex];
+
+    Map<Integer, Integer> newIdToOldOrdinal = new HashMap<>();
+    int oldOrd = 0;
+    for (int oldId = initializerVectorValues.nextDoc();
+        oldId != NO_MORE_DOCS;
+        oldId = initializerVectorValues.nextDoc()) {
+      if (initializerVectorValues.vectorValue() == null) {
+        continue;
+      }
+      int newId = initializerDocMap.get(oldId);
+      newIdToOldOrdinal.put(newId, oldOrd);
+      oldOrd++;
+    }
+
+    Map<Integer, Integer> oldToNewOrdinalMap = new HashMap<>();
+    int newOrd = 0;
+    int maxNewDocID = Collections.max(newIdToOldOrdinal.keySet());

Review Comment:
   It might be a bit faster to compute this max inside the previous loop (while populating `newIdToOldOrdinal`) instead of calling `Collections.max` afterwards?



##########
lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java:
##########
@@ -461,6 +467,126 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState 
mergeState) throws IOE
     }
   }
 
+  private void maybeInitializeFromGraph(
+      HnswGraphBuilder<?> hnswGraphBuilder, MergeState mergeState, FieldInfo 
fieldInfo)
+      throws IOException {
+    int initializerIndex = selectGraphForInitialization(mergeState, fieldInfo);
+    if (initializerIndex == -1) {
+      return;
+    }
+
+    HnswGraph initializerGraph =
+        getHnswGraphFromReader(fieldInfo.name, 
mergeState.knnVectorsReaders[initializerIndex]);
+    Map<Integer, Integer> ordinalMapper =
+        getOldToNewOrdinalMap(mergeState, fieldInfo, initializerIndex);
+    hnswGraphBuilder.initializeFromGraph(initializerGraph, ordinalMapper);
+  }
+
+  private int selectGraphForInitialization(MergeState mergeState, FieldInfo 
fieldInfo)
+      throws IOException {
+    // Find the KnnVectorReader with the most docs that meets the following 
criteria:
+    //  1. Does not contain any deleted docs
+    //  2. Is a Lucene95HnswVectorsReader/PerFieldKnnVectorReader
+    // If no readers exist that meet this criteria, return -1. If they do, 
return their index in
+    // merge state
+    int maxCandidateVectorCount = 0;
+    int initializerIndex = -1;
+
+    for (int i = 0; i < mergeState.liveDocs.length; i++) {
+      KnnVectorsReader currKnnVectorsReader = mergeState.knnVectorsReaders[i];
+      if (mergeState.knnVectorsReaders[i]
+          instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
+        currKnnVectorsReader = candidateReader.getFieldReader(fieldInfo.name);
+      }
+
+      if (!allMatch(mergeState.liveDocs[i])
+          || !(currKnnVectorsReader instanceof Lucene95HnswVectorsReader 
candidateReader)) {
+        continue;
+      }
+
+      VectorValues vectorValues = 
candidateReader.getVectorValues(fieldInfo.name);
+      if (vectorValues == null) {
+        continue;
+      }
+
+      int candidateVectorCount = vectorValues.size();
+      if (candidateVectorCount > maxCandidateVectorCount) {
+        maxCandidateVectorCount = candidateVectorCount;
+        initializerIndex = i;
+      }
+    }
+    return initializerIndex;
+  }
+
+  private HnswGraph getHnswGraphFromReader(String fieldName, KnnVectorsReader 
knnVectorsReader)
+      throws IOException {
+    if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader 
perFieldReader
+        && perFieldReader.getFieldReader(fieldName)
+            instanceof Lucene95HnswVectorsReader fieldReader) {
+      return fieldReader.getGraph(fieldName);
+    }
+
+    if (knnVectorsReader instanceof Lucene95HnswVectorsReader) {
+      return ((Lucene95HnswVectorsReader) 
knnVectorsReader).getGraph(fieldName);
+    }
+
+    throw new IllegalArgumentException(
+        "Invalid KnnVectorsReader. Must be of type 
PerFieldKnnVectorsFormat.FieldsReader or Lucene94HnswVectorsReader");

Review Comment:
   Maybe say the following instead (the current message also still refers to `Lucene94HnswVectorsReader`):
   `"Invalid KnnVectorsReader type for field: " + fieldName + ". Must be Lucene95HnswVectorsReader or newer"`?



##########
lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java:
##########
@@ -461,6 +467,126 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState 
mergeState) throws IOE
     }
   }
 
+  private void maybeInitializeFromGraph(
+      HnswGraphBuilder<?> hnswGraphBuilder, MergeState mergeState, FieldInfo 
fieldInfo)
+      throws IOException {
+    int initializerIndex = selectGraphForInitialization(mergeState, fieldInfo);
+    if (initializerIndex == -1) {
+      return;
+    }
+
+    HnswGraph initializerGraph =
+        getHnswGraphFromReader(fieldInfo.name, 
mergeState.knnVectorsReaders[initializerIndex]);
+    Map<Integer, Integer> ordinalMapper =
+        getOldToNewOrdinalMap(mergeState, fieldInfo, initializerIndex);
+    hnswGraphBuilder.initializeFromGraph(initializerGraph, ordinalMapper);
+  }
+
+  private int selectGraphForInitialization(MergeState mergeState, FieldInfo 
fieldInfo)
+      throws IOException {
+    // Find the KnnVectorReader with the most docs that meets the following 
criteria:
+    //  1. Does not contain any deleted docs
+    //  2. Is a Lucene95HnswVectorsReader/PerFieldKnnVectorReader
+    // If no readers exist that meet this criteria, return -1. If they do, 
return their index in
+    // merge state
+    int maxCandidateVectorCount = 0;
+    int initializerIndex = -1;
+
+    for (int i = 0; i < mergeState.liveDocs.length; i++) {
+      KnnVectorsReader currKnnVectorsReader = mergeState.knnVectorsReaders[i];
+      if (mergeState.knnVectorsReaders[i]
+          instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
+        currKnnVectorsReader = candidateReader.getFieldReader(fieldInfo.name);
+      }
+
+      if (!allMatch(mergeState.liveDocs[i])
+          || !(currKnnVectorsReader instanceof Lucene95HnswVectorsReader 
candidateReader)) {
+        continue;
+      }
+
+      VectorValues vectorValues = 
candidateReader.getVectorValues(fieldInfo.name);
+      if (vectorValues == null) {
+        continue;
+      }
+
+      int candidateVectorCount = vectorValues.size();
+      if (candidateVectorCount > maxCandidateVectorCount) {
+        maxCandidateVectorCount = candidateVectorCount;
+        initializerIndex = i;
+      }
+    }
+    return initializerIndex;
+  }
+
+  private HnswGraph getHnswGraphFromReader(String fieldName, KnnVectorsReader 
knnVectorsReader)
+      throws IOException {
+    if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader 
perFieldReader
+        && perFieldReader.getFieldReader(fieldName)
+            instanceof Lucene95HnswVectorsReader fieldReader) {
+      return fieldReader.getGraph(fieldName);
+    }
+
+    if (knnVectorsReader instanceof Lucene95HnswVectorsReader) {
+      return ((Lucene95HnswVectorsReader) 
knnVectorsReader).getGraph(fieldName);
+    }
+

Review Comment:
   Can we also add a comment indicating that we should never reach this point, because the reader type has already been checked in `selectGraphForInitialization`?



##########
lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java:
##########
@@ -94,36 +93,83 @@ public int size() {
   }
 
   /**
-   * Add node on the given level
+   * Add node on the given level. Nodes can be inserted out of order, but it 
requires that the nodes

Review Comment:
   Since we need out-of-order insertion, I wonder whether it would be better to have a separate implementation of OnHeapHnswGraph that uses a BST for all layers other than layer 0?



##########
lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java:
##########
@@ -143,10 +148,64 @@ public OnHeapHnswGraph build(RandomAccessVectorValues 
vectorsToAdd) throws IOExc
     return hnsw;
   }
 
+  /**
+   * Initializes the graph of this builder. Transfers the nodes and their 
neighbors from the
+   * initializer graph into the graph being produced by this builder, mapping 
ordinals from the
+   * initializer graph to their new ordinals in this builder's graph. The 
builder's graph must be
+   * empty before calling this method.
+   *
+   * @param initializerGraph graph used for initialization
+   * @param oldToNewOrdinalMap map for converting from ordinals in the 
initializerGraph to this
+   *     builder's graph
+   */
+  public void initializeFromGraph(
+      HnswGraph initializerGraph, Map<Integer, Integer> oldToNewOrdinalMap) 
throws IOException {
+    assert hnsw.size() == 0;
+    float[] vectorValue = null;
+    BytesRef binaryValue = null;
+    for (int level = 0; level < initializerGraph.numLevels(); level++) {
+      HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level);
+
+      while (it.hasNext()) {
+        int oldOrd = it.nextInt();
+        int newOrd = oldToNewOrdinalMap.get(oldOrd);
+
+        hnsw.addNode(level, newOrd);
+
+        if (level == 0) {
+          initializedNodes.add(newOrd);
+        }
+
+        switch (this.vectorEncoding) {
+          case FLOAT32 -> vectorValue = vectors.vectorValue(newOrd);
+          case BYTE -> binaryValue = vectors.binaryValue(newOrd);
+        }
+
+        NeighborArray newNeighbors = this.hnsw.getNeighbors(level, newOrd);
+        initializerGraph.seek(level, oldOrd);
+        for (int oldNeighbor = initializerGraph.nextNeighbor();
+            oldNeighbor != NO_MORE_DOCS;
+            oldNeighbor = initializerGraph.nextNeighbor()) {
+          int newNeighbor = oldToNewOrdinalMap.get(oldNeighbor);
+          float score =
+              switch (this.vectorEncoding) {
+                case FLOAT32 -> this.similarityFunction.compare(

Review Comment:
   I wonder whether we could save some time here by calculating these scores lazily.
   For all the nodes copied from the initializer graph, we already know their neighbor order, so we don't need their scores as long as no new nodes are inserted.
   When a new node is inserted, we can do the binary search as usual and calculate scores only where necessary.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org
For additional commands, e-mail: issues-h...@lucene.apache.org

Reply via email to