msokolov commented on code in PR #15549:
URL: https://github.com/apache/lucene/pull/15549#discussion_r2833833291
##########
lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java:
##########
@@ -658,78 +758,111 @@ static int mergeAndRecalculateCentroids(
static int calculateCentroid(MergeState mergeState, FieldInfo fieldInfo,
float[] centroid)
throws IOException {
- assert fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32);
+
+ VectorEncoding vectorEncoding = fieldInfo.getVectorEncoding();
+ assert vectorEncoding == VectorEncoding.FLOAT32 || vectorEncoding ==
VectorEncoding.FLOAT16;
// clear out the centroid
Arrays.fill(centroid, 0);
int count = 0;
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
if (knnVectorsReader == null) continue;
- FloatVectorValues vectorValues =
- mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name);
- if (vectorValues == null) {
- continue;
- }
- KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
- for (int doc = iterator.nextDoc();
- doc != DocIdSetIterator.NO_MORE_DOCS;
- doc = iterator.nextDoc()) {
- ++count;
- float[] vector = vectorValues.vectorValue(iterator.index());
- for (int j = 0; j < vector.length; j++) {
- centroid[j] += vector[j];
- }
- }
+
+ count += accumulateCentroid(knnVectorsReader, fieldInfo, vectorEncoding,
centroid);
}
+
if (count == 0) {
return count;
}
+
for (int i = 0; i < centroid.length; i++) {
centroid[i] /= count;
}
+
if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
VectorUtil.l2normalize(centroid);
}
+
+ return count;
+ }
+
+ private static int accumulateCentroid(
+ KnnVectorsReader reader, FieldInfo fieldInfo, VectorEncoding encoding,
float[] centroid)
+ throws IOException {
+ int count = 0;
+
+ if (encoding == VectorEncoding.FLOAT32) {
+ FloatVectorValues vectorValues =
reader.getFloatVectorValues(fieldInfo.name);
+ if (vectorValues == null) return 0;
+
+ KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
+ for (int doc = iterator.nextDoc(); doc != NO_MORE_DOCS; doc =
iterator.nextDoc()) {
+ count++;
+ float[] vector = vectorValues.vectorValue(iterator.index());
+ for (int j = 0; j < vector.length; j++) {
+ centroid[j] += vector[j];
+ }
+ }
+ } else if (encoding == VectorEncoding.FLOAT16) {
+ Float16VectorValues vectorValues =
reader.getFloat16VectorValues(fieldInfo.name);
+ if (vectorValues == null) return 0;
+
+ KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
+ for (int doc = iterator.nextDoc(); doc != NO_MORE_DOCS; doc =
iterator.nextDoc()) {
+ count++;
+ short[] vector = vectorValues.vectorValue(iterator.index());
+ for (int j = 0; j < vector.length; j++) {
+ centroid[j] += Float.float16ToFloat(vector[j]);
+ }
+ }
+ }
+
return count;
}
@Override
public long ramBytesUsed() {
long total = SHALLOW_RAM_BYTES_USED;
- for (FieldWriter field : fields) {
+ for (FieldWriter<?> field : fields) {
// the field tracks the delegate field usage
total += field.ramBytesUsed();
}
return total;
}
- static class FieldWriter extends FlatFieldVectorsWriter<float[]> {
+ private abstract static class FieldWriter<T> extends
FlatFieldVectorsWriter<T> {
Review Comment:
ah I see — we already had generics, so we're almost forced to use them here now
##########
lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java:
##########
@@ -328,7 +356,8 @@ public void finish() throws IOException {
@Override
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws
IOException {
- if (!fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
+ VectorEncoding vectorEncoding = fieldInfo.getVectorEncoding();
+ if (vectorEncoding != VectorEncoding.FLOAT32 && vectorEncoding !=
VectorEncoding.FLOAT16) {
Review Comment:
Maybe we could introduce a method `VectorEncoding.isFloatingPoint()`?
##########
lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java:
##########
@@ -658,78 +758,111 @@ static int mergeAndRecalculateCentroids(
static int calculateCentroid(MergeState mergeState, FieldInfo fieldInfo,
float[] centroid)
throws IOException {
- assert fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32);
+
+ VectorEncoding vectorEncoding = fieldInfo.getVectorEncoding();
+ assert vectorEncoding == VectorEncoding.FLOAT32 || vectorEncoding ==
VectorEncoding.FLOAT16;
// clear out the centroid
Arrays.fill(centroid, 0);
int count = 0;
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
if (knnVectorsReader == null) continue;
- FloatVectorValues vectorValues =
- mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name);
- if (vectorValues == null) {
- continue;
- }
- KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
- for (int doc = iterator.nextDoc();
- doc != DocIdSetIterator.NO_MORE_DOCS;
- doc = iterator.nextDoc()) {
- ++count;
- float[] vector = vectorValues.vectorValue(iterator.index());
- for (int j = 0; j < vector.length; j++) {
- centroid[j] += vector[j];
- }
- }
+
+ count += accumulateCentroid(knnVectorsReader, fieldInfo, vectorEncoding,
centroid);
}
+
if (count == 0) {
return count;
}
+
for (int i = 0; i < centroid.length; i++) {
centroid[i] /= count;
}
+
if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
VectorUtil.l2normalize(centroid);
}
+
+ return count;
+ }
+
+ private static int accumulateCentroid(
+ KnnVectorsReader reader, FieldInfo fieldInfo, VectorEncoding encoding,
float[] centroid)
+ throws IOException {
+ int count = 0;
+
+ if (encoding == VectorEncoding.FLOAT32) {
+ FloatVectorValues vectorValues =
reader.getFloatVectorValues(fieldInfo.name);
+ if (vectorValues == null) return 0;
+
+ KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
+ for (int doc = iterator.nextDoc(); doc != NO_MORE_DOCS; doc =
iterator.nextDoc()) {
+ count++;
+ float[] vector = vectorValues.vectorValue(iterator.index());
+ for (int j = 0; j < vector.length; j++) {
+ centroid[j] += vector[j];
+ }
+ }
+ } else if (encoding == VectorEncoding.FLOAT16) {
+ Float16VectorValues vectorValues =
reader.getFloat16VectorValues(fieldInfo.name);
+ if (vectorValues == null) return 0;
+
+ KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
+ for (int doc = iterator.nextDoc(); doc != NO_MORE_DOCS; doc =
iterator.nextDoc()) {
+ count++;
+ short[] vector = vectorValues.vectorValue(iterator.index());
+ for (int j = 0; j < vector.length; j++) {
+ centroid[j] += Float.float16ToFloat(vector[j]);
+ }
+ }
+ }
+
return count;
}
@Override
public long ramBytesUsed() {
long total = SHALLOW_RAM_BYTES_USED;
- for (FieldWriter field : fields) {
+ for (FieldWriter<?> field : fields) {
// the field tracks the delegate field usage
total += field.ramBytesUsed();
}
return total;
}
- static class FieldWriter extends FlatFieldVectorsWriter<float[]> {
+ private abstract static class FieldWriter<T> extends
FlatFieldVectorsWriter<T> {
private static final long SHALLOW_SIZE =
shallowSizeOfInstance(FieldWriter.class);
- private final FieldInfo fieldInfo;
+ protected final FieldInfo fieldInfo;
private boolean finished;
- private final FlatFieldVectorsWriter<float[]> flatFieldVectorsWriter;
- private final float[] dimensionSums;
- private final FloatArrayList magnitudes = new FloatArrayList();
+ protected final FlatFieldVectorsWriter<T> flatFieldVectorsWriter;
+ protected final float[] dimensionSums;
+ protected final FloatArrayList magnitudes = new FloatArrayList();
+ protected final int dim;
- FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter<float[]>
flatFieldVectorsWriter) {
+ FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter<T>
flatFieldVectorsWriter) {
this.fieldInfo = fieldInfo;
this.flatFieldVectorsWriter = flatFieldVectorsWriter;
- this.dimensionSums = new float[fieldInfo.getVectorDimension()];
+ this.dim = fieldInfo.getVectorDimension();
+ this.dimensionSums = new float[dim];
}
- @Override
- public List<float[]> getVectors() {
- return flatFieldVectorsWriter.getVectors();
+ @SuppressWarnings("unchecked")
+ static FieldWriter<?> create(
Review Comment:
would it help to have <T> co-varying with `FlatFieldVectorsWriter<T>`
instead of `<?>`?
##########
lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java:
##########
@@ -177,6 +178,39 @@ public ByteVectorValues getByteVectorValues(String field)
throws IOException {
return new SimpleTextByteVectorValues(fieldEntry, bytesSlice);
}
+ @Override
+ public Float16VectorValues getFloat16VectorValues(String field) throws
IOException {
+ FieldInfo info = readState.fieldInfos.fieldInfo(field);
+ if (info == null) {
+ // mirror the handling in Lucene90VectorReader#getVectorValues
+ // needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs
+ return null;
+ }
+ int dimension = info.getVectorDimension();
+ if (dimension == 0) {
+ throw new IllegalStateException(
+ "KNN vectors readers should not be called on fields that don't
enable KNN vectors");
+ }
+ FieldEntry fieldEntry = fieldEntries.get(info.number);
+ if (fieldEntry == null) {
+ // mirror the handling in Lucene90VectorReader#getVectorValues
Review Comment:
same here
##########
lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java:
##########
@@ -201,9 +203,22 @@ private void writeVectors(
new byte[encoding.getDocPackedLength(scratch.length)];
};
for (int i = 0; i < fieldData.getVectors().size(); i++) {
- float[] v = fieldData.getVectors().get(i);
- OptimizedScalarQuantizer.QuantizationResult corrections =
- scalarQuantizer.scalarQuantize(v, scratch, encoding.getBits(),
clusterCenter);
+ OptimizedScalarQuantizer.QuantizationResult corrections = null;
+ if (fieldData.fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32) {
+ corrections =
+ scalarQuantizer.scalarQuantize(
+ (float[]) fieldData.getVectors().get(i),
Review Comment:
hmm, even using generics we still have a cast here -- I wonder if the
introduction of generics is worth it
##########
lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java:
##########
@@ -314,6 +320,11 @@ public float floatDotProductVector() {
return VectorUtil.dotProduct(floatsA, floatsB);
}
+ @Benchmark
+ public float shortDotProductScalar() {
Review Comment:
This treats the `short[]` as an array of fp16, right? Maybe we should change
the name of the method to `fp16DotProductScalar`.
##########
lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java:
##########
@@ -177,6 +178,39 @@ public ByteVectorValues getByteVectorValues(String field)
throws IOException {
return new SimpleTextByteVectorValues(fieldEntry, bytesSlice);
}
+ @Override
+ public Float16VectorValues getFloat16VectorValues(String field) throws
IOException {
+ FieldInfo info = readState.fieldInfos.fieldInfo(field);
+ if (info == null) {
+ // mirror the handling in Lucene90VectorReader#getVectorValues
Review Comment:
This comment seems out of date since `Lucene90VectorReader` no longer
exists. I'd simply delete the whole comment
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]