This is an automated email from the ASF dual-hosted git repository. ishan pushed a commit to branch jira/solr-17927 in repository https://gitbox.apache.org/repos/asf/solr.git
commit 490e9056e3002a88bb5c3d94c2e57cdf8ecac3eb Author: Ishan Chattopadhyaya <[email protected]> AuthorDate: Thu Dec 11 22:19:20 2025 +0530 Rebasing against latest main --- .../org/apache/solr/schema/DenseVectorField.java | 79 +++++++++++++--------- .../org/apache/solr/search/vector/KnnQParser.java | 4 +- .../solr/search/vector/SolrKnnByteVectorQuery.java | 14 ++++ .../search/vector/SolrKnnFloatVectorQuery.java | 14 ++++ .../apache/solr/schema/DenseVectorFieldTest.java | 13 +++- .../apache/solr/search/vector/KnnQParserTest.java | 10 +-- 6 files changed, 96 insertions(+), 38 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java index 01c51a358d2..8c6d4dfaee7 100644 --- a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java +++ b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java @@ -41,7 +41,6 @@ import org.apache.lucene.search.PatienceKnnVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.SeededKnnVectorQuery; import org.apache.lucene.search.SortField; -import org.apache.lucene.search.knn.KnnSearchStrategy; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.solr.common.SolrException; @@ -511,40 +510,58 @@ public class DenseVectorField extends FloatPointField { DenseVectorParser vectorBuilder = getVectorBuilder(vectorToSearch, DenseVectorParser.BuilderPhase.QUERY); + // Create KnnSearchStrategy if filteredSearchThreshold is provided + org.apache.lucene.search.knn.KnnSearchStrategy searchStrategy = null; + if (filteredSearchThreshold != null) { + searchStrategy = + new org.apache.lucene.search.knn.KnnSearchStrategy.Hnsw(filteredSearchThreshold); + } + + Query baseQuery; switch (vectorEncoding) { case FLOAT32: - SolrKnnFloatVectorQuery knnFloatVectorQuery = - new SolrKnnFloatVectorQuery( - fieldName, vectorBuilder.getFloatVector(), topK, efSearch, filterQuery); - if (earlyTermination.isEnabled()) { - return (earlyTermination.getSaturationThreshold() != null - && earlyTermination.getPatience() != null) - ? PatienceKnnVectorQuery.fromFloatQuery( - knnFloatVectorQuery, - earlyTermination.getSaturationThreshold(), - earlyTermination.getPatience()) - : PatienceKnnVectorQuery.fromFloatQuery(knnFloatVectorQuery); - } - return knnFloatVectorQuery; + baseQuery = + searchStrategy != null + ? new SolrKnnFloatVectorQuery( + fieldName, + vectorBuilder.getFloatVector(), + topK, + efSearch, + filterQuery, + searchStrategy) + : new SolrKnnFloatVectorQuery( + fieldName, vectorBuilder.getFloatVector(), topK, efSearch, filterQuery); + break; case BYTE: - SolrKnnByteVectorQuery knnByteVectorQuery = - new SolrKnnByteVectorQuery( - fieldName, vectorBuilder.getByteVector(), topK, efSearch, filterQuery); - if (earlyTermination.isEnabled()) { - return (earlyTermination.getSaturationThreshold() != null - && earlyTermination.getPatience() != null) - ? PatienceKnnVectorQuery.fromByteQuery( - knnByteVectorQuery, - earlyTermination.getSaturationThreshold(), - earlyTermination.getPatience()) - : PatienceKnnVectorQuery.fromByteQuery(knnByteVectorQuery); - } - return knnByteVectorQuery; + baseQuery = + searchStrategy != null + ? new SolrKnnByteVectorQuery( + fieldName, + vectorBuilder.getByteVector(), + topK, + efSearch, + filterQuery, + searchStrategy) + : new SolrKnnByteVectorQuery( + fieldName, vectorBuilder.getByteVector(), topK, efSearch, filterQuery); + break; default: throw new SolrException( SolrException.ErrorCode.SERVER_ERROR, "Unexpected state. Vector Encoding: " + vectorEncoding); } + + // Apply seeding if seedQuery is provided + if (seedQuery != null) { + baseQuery = getSeededQuery(baseQuery, seedQuery); + } + + // Apply early termination if enabled + if (earlyTermination != null && earlyTermination.isEnabled()) { + baseQuery = getEarlyTerminationQuery(baseQuery, earlyTermination); + } + + return baseQuery; } /** @@ -580,9 +597,9 @@ public class DenseVectorField extends FloatPointField { private Query getSeededQuery(Query knnQuery, Query seed) { return switch (knnQuery) { - case KnnFloatVectorQuery knnFloatQuery -> SeededKnnVectorQuery.fromFloatQuery( + case SolrKnnFloatVectorQuery knnFloatQuery -> SeededKnnVectorQuery.fromFloatQuery( knnFloatQuery, seed); - case KnnByteVectorQuery knnByteQuery -> SeededKnnVectorQuery.fromByteQuery( + case SolrKnnByteVectorQuery knnByteQuery -> SeededKnnVectorQuery.fromByteQuery( knnByteQuery, seed); default -> throw new SolrException( SolrException.ErrorCode.SERVER_ERROR, "Invalid type of knn query"); @@ -594,13 +611,13 @@ public class DenseVectorField extends FloatPointField { (earlyTermination.getSaturationThreshold() != null && earlyTermination.getPatience() != null); return switch (knnQuery) { - case KnnFloatVectorQuery knnFloatQuery -> useExplicitParams + case SolrKnnFloatVectorQuery knnFloatQuery -> useExplicitParams ? PatienceKnnVectorQuery.fromFloatQuery( knnFloatQuery, earlyTermination.getSaturationThreshold(), earlyTermination.getPatience()) : PatienceKnnVectorQuery.fromFloatQuery(knnFloatQuery); - case KnnByteVectorQuery knnByteQuery -> useExplicitParams + case SolrKnnByteVectorQuery knnByteQuery -> useExplicitParams ? PatienceKnnVectorQuery.fromByteQuery( knnByteQuery, earlyTermination.getSaturationThreshold(), diff --git a/solr/core/src/java/org/apache/solr/search/vector/KnnQParser.java b/solr/core/src/java/org/apache/solr/search/vector/KnnQParser.java index 9220afd3c98..1a5beb83880 100644 --- a/solr/core/src/java/org/apache/solr/search/vector/KnnQParser.java +++ b/solr/core/src/java/org/apache/solr/search/vector/KnnQParser.java @@ -125,6 +125,8 @@ public class KnnQParser extends AbstractVectorQParserBase { topK, efSearch, getFilterQuery(), - getEarlyTerminationParams()); + getSeedQuery(), + getEarlyTerminationParams(), + filteredSearchThreshold); } } diff --git a/solr/core/src/java/org/apache/solr/search/vector/SolrKnnByteVectorQuery.java b/solr/core/src/java/org/apache/solr/search/vector/SolrKnnByteVectorQuery.java index bd4e5bcd207..0e100d1fb36 100644 --- a/solr/core/src/java/org/apache/solr/search/vector/SolrKnnByteVectorQuery.java +++ b/solr/core/src/java/org/apache/solr/search/vector/SolrKnnByteVectorQuery.java @@ -19,6 +19,7 @@ package org.apache.solr.search.vector; import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.knn.KnnSearchStrategy; public class SolrKnnByteVectorQuery extends KnnByteVectorQuery { private final int topK; @@ -30,6 +31,19 @@ public class SolrKnnByteVectorQuery extends KnnByteVectorQuery { this.topK = topK; } + public SolrKnnByteVectorQuery( + String field, + byte[] target, + int topK, + int efSearch, + Query filter, + KnnSearchStrategy searchStrategy) { + // efSearch is used as 'k' to explore this many vectors in HNSW, then topK results are returned + // to the user + super(field, target, efSearch, filter, searchStrategy); + this.topK = topK; + } + @Override protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { return TopDocs.merge(topK, perLeafResults); diff --git a/solr/core/src/java/org/apache/solr/search/vector/SolrKnnFloatVectorQuery.java b/solr/core/src/java/org/apache/solr/search/vector/SolrKnnFloatVectorQuery.java index 937b97ff7ec..3fa9af57df7 100644 --- a/solr/core/src/java/org/apache/solr/search/vector/SolrKnnFloatVectorQuery.java +++ b/solr/core/src/java/org/apache/solr/search/vector/SolrKnnFloatVectorQuery.java @@ -19,6 +19,7 @@ package org.apache.solr.search.vector; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.knn.KnnSearchStrategy; public class SolrKnnFloatVectorQuery extends KnnFloatVectorQuery { private final int topK; @@ -31,6 +32,19 @@ public class SolrKnnFloatVectorQuery extends KnnFloatVectorQuery { this.topK = topK; } + public SolrKnnFloatVectorQuery( + String field, + float[] target, + int topK, + int efSearch, + Query filter, + KnnSearchStrategy searchStrategy) { + // efSearch is used as 'k' to explore this many vectors in HNSW then topK results are returned + // to the user + super(field, target, efSearch, filter, searchStrategy); + this.topK = topK; + } + @Override protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { return TopDocs.merge(topK, perLeafResults); diff --git a/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java b/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java index 077871e3b6f..18794907df2 100644 --- a/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java +++ b/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java @@ -975,6 +975,7 @@ public class DenseVectorFieldTest extends AbstractBadConfigTestBase { "vector", "[2, 1, 3, 4]", 3, + 3, null, seedQuery, earlyTermination, @@ -1026,7 +1027,14 @@ public class DenseVectorFieldTest extends AbstractBadConfigTestBase { KnnByteVectorQuery vectorQuery = (KnnByteVectorQuery) type.getKnnVectorQuery( - "vector_byte_encoding", "[2, 1, 3, 4]", 3, 3, null, null, null, expectedThreshold); + "vector_byte_encoding", + "[2, 1, 3, 4]", + 3, + 3, + null, + null, + null, + expectedThreshold); KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); Integer threshold = strategy.filteredSearchThreshold(); @@ -1054,6 +1062,7 @@ public class DenseVectorFieldTest extends AbstractBadConfigTestBase { "vector_byte_encoding", "[2, 1, 3, 4]", 3, + 3, null, seedQuery, null, @@ -1087,6 +1096,7 @@ public class DenseVectorFieldTest extends AbstractBadConfigTestBase { "vector_byte_encoding", "[2, 1, 3, 4]", 3, + 3, null, null, earlyTermination, @@ -1121,6 +1131,7 @@ public class DenseVectorFieldTest extends AbstractBadConfigTestBase { "vector_byte_encoding", "[2, 1, 3, 4]", 3, + 3, null, seedQuery, earlyTermination, diff --git a/solr/core/src/test/org/apache/solr/search/vector/KnnQParserTest.java b/solr/core/src/test/org/apache/solr/search/vector/KnnQParserTest.java index 2d06c3027b4..1350b94b83e 100644 --- a/solr/core/src/test/org/apache/solr/search/vector/KnnQParserTest.java +++ b/solr/core/src/test/org/apache/solr/search/vector/KnnQParserTest.java @@ -1354,7 +1354,7 @@ public class KnnQParserTest extends SolrTestCaseJ4 { "id", "debugQuery", "true"), - "//result[@numFound='4']", + "//result[@numFound='4']"); } @Test @@ -1400,7 +1400,7 @@ public class KnnQParserTest extends SolrTestCaseJ4 { "id", "debugQuery", "true"), - "//result[@numFound='4']", + "//result[@numFound='4']"); } @Test @@ -1428,9 +1428,9 @@ public class KnnQParserTest extends SolrTestCaseJ4 { "//str[@name='parsedquery'][contains(.,'seed=id:1 id:4 id:7 id:8 id:9')]", // Verify that a seedWeight field is present — its value (BooleanWeight@<hash>) includes a // hash code that changes on each run, so it cannot be asserted explicitly - "//str[@name='parsedquery'][contains(.,'seedWeight=')]", - // Verify that the final delegate is a KnnFloatVectorQuery with the expected vector and topK - // value + "//str[@name='parsedquery'][contains(.,'seedWeight=')]"); + // Verify that the final delegate is a KnnFloatVectorQuery with the expected vector and topK + // value } @Test
