maedhroz commented on code in PR #4353:
URL: https://github.com/apache/cassandra/pull/4353#discussion_r2710962528


##########
src/java/org/apache/cassandra/index/sai/memory/VectorMemoryIndex.java:
##########
@@ -172,73 +186,92 @@ public KeyRangeIterator search(QueryContext queryContext, 
Expression expr, Abstr
             PrimaryKey right = isMaxToken ? null : 
index.keyFactory().create(keyRange.right.getToken()); // upper bound
 
             Set<PrimaryKey> resultKeys = isMaxToken ? 
primaryKeys.tailSet(left, leftInclusive) : primaryKeys.subSet(left, 
leftInclusive, right, rightInclusive);
-            if (!vectorQueryContext.getShadowedPrimaryKeys().isEmpty())
-                resultKeys = resultKeys.stream().filter(pk -> 
!vectorQueryContext.containsShadowedPrimaryKey(pk)).collect(Collectors.toSet());
 
             if (resultKeys.isEmpty())
-                return KeyRangeIterator.empty();
+                return CloseableIterator.empty();
 
-            int bruteForceRows = maxBruteForceRows(vectorQueryContext.limit(), 
resultKeys.size(), graph.size());
+            int bruteForceRows = maxBruteForceRows(queryContext.limit(), 
resultKeys.size(), graph.size());
             Tracing.trace("Search range covers {} rows; max brute force rows 
is {} for memtable index with {} nodes, LIMIT {}",
-                          resultKeys.size(), bruteForceRows, graph.size(), 
vectorQueryContext.limit());
-            if (resultKeys.size() < Math.max(vectorQueryContext.limit(), 
bruteForceRows))
-                return new ReorderingRangeIterator(new 
PriorityQueue<>(resultKeys));
+                          resultKeys.size(), bruteForceRows, graph.size(), 
queryContext.limit());
+            if (resultKeys.size() < Math.max(queryContext.limit(), 
bruteForceRows))
+                return orderByBruteForce(qv, resultKeys);
             else
-                bits = new KeyRangeFilteringBits(keyRange, 
vectorQueryContext.bitsetForShadowedPrimaryKeys(graph));
+                bits = new KeyRangeFilteringBits(keyRange, null);
         }
         else
         {
-            // partition/range deletion won't trigger index update, so we have 
to filter shadow primary keys in memtable index
-            bits = 
queryContext.vectorContext().bitsetForShadowedPrimaryKeys(graph);
+            // Accept all bits
+            bits = new Bits.MatchAllBits(Integer.MAX_VALUE);
         }
 
-        PriorityQueue<PrimaryKey> keyQueue = graph.search(qv, 
queryContext.vectorContext().limit(), bits);
-        if (keyQueue.isEmpty())
-            return KeyRangeIterator.empty();
-        return new ReorderingRangeIterator(keyQueue);
+        CloseableIterator<SearchResult.NodeScore> iterator = graph.search(qv, 
queryContext.limit(), bits);
+        return new NodeScoreToScoredPrimaryKeyIterator(iterator);
     }
 
     @Override
-    public KeyRangeIterator limitToTopResults(List<PrimaryKey> primaryKeys, 
Expression expression, int limit)
+    public CloseableIterator<PrimaryKeyWithScore> orderResultsBy(QueryContext 
queryContext, List<PrimaryKey> results, Expression orderer)
     {
         if (minimumKey == null)
             // This case implies maximumKey is empty too.
-            return KeyRangeIterator.empty();
+            return CloseableIterator.empty();
 
-        List<PrimaryKey> results = primaryKeys.stream()
-                                              .dropWhile(k -> 
k.compareTo(minimumKey) < 0)
-                                              .takeWhile(k -> 
k.compareTo(maximumKey) <= 0)
-                                              .collect(Collectors.toList());
+        int limit = queryContext.limit();
 
-        int maxBruteForceRows = maxBruteForceRows(limit, results.size(), 
graph.size());
+        List<PrimaryKey> resultsInRange = results.stream()
+                                                 .dropWhile(k -> 
k.compareTo(minimumKey) < 0)
+                                                 .takeWhile(k -> 
k.compareTo(maximumKey) <= 0)
+                                                 .collect(Collectors.toList());
+
+        int maxBruteForceRows = maxBruteForceRows(limit, 
resultsInRange.size(), graph.size());
         Tracing.trace("SAI materialized {} rows; max brute force rows is {} 
for memtable index with {} nodes, LIMIT {}",
-                      results.size(), maxBruteForceRows, graph.size(), limit);
-        if (results.size() <= maxBruteForceRows)
-        {
-            if (results.isEmpty())
-                return KeyRangeIterator.empty();
-            return new KeyRangeListIterator(minimumKey, maximumKey, results);
-        }
+                      resultsInRange.size(), maxBruteForceRows, graph.size(), 
limit);
 
-        ByteBuffer buffer = expression.lower().value.raw;
+        if (resultsInRange.isEmpty())
+            return CloseableIterator.empty();
+
+        ByteBuffer buffer = orderer.lower().value.raw;
         float[] qv = index.termType().decomposeVector(buffer);
-        KeyFilteringBits bits = new KeyFilteringBits(results);
-        PriorityQueue<PrimaryKey> keyQueue = graph.search(qv, limit, bits);
-        if (keyQueue.isEmpty())
-            return KeyRangeIterator.empty();
-        return new ReorderingRangeIterator(keyQueue);
+
+        if (resultsInRange.size() <= maxBruteForceRows)
+            return orderByBruteForce(qv, resultsInRange);
+
+        // Search the graph for the topK vectors near the query
+        KeyFilteringBits bits = new KeyFilteringBits(resultsInRange);
+        CloseableIterator<SearchResult.NodeScore> nodeScores = 
graph.search(qv, limit, bits);
+        return new NodeScoreToScoredPrimaryKeyIterator(nodeScores);
     }
 
     private int maxBruteForceRows(int limit, int nPermittedOrdinals, int 
graphSize)
     {
         int expectedNodesVisited = expectedNodesVisited(limit, 
nPermittedOrdinals, graphSize);
-        int expectedComparisons = 
index.indexWriterConfig().getMaximumNodeConnections() * expectedNodesVisited;
-        // in-memory comparisons are cheaper than pulling a row off disk and 
then comparing
-        // VSTODO this is dramatically oversimplified
-        // larger dimension should increase this, because comparisons are more 
expensive
-        // lower chunk cache hit ratio should decrease this, because loading 
rows is more expensive
-        double memoryToDiskFactor = 0.25;
-        return (int) max(limit, memoryToDiskFactor * expectedComparisons);
+        // ANN index will do a bunch of extra work besides the full 
comparisons (performing PQ similarity for each edge);
+        // VSTODO I'm not sure which one is more expensive (and it depends on 
things like sstable chunk cache hit ratio)

Review Comment:
   SSTable chunk cache has nothing to do with this, right? At this point it's 
comparisons vs. in-memory graph node visits (and we're just roughly 
treating them as equal)?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to