Copilot commented on code in PR #7330: URL: https://github.com/apache/paimon/pull/7330#discussion_r2904146082
########## paimon-lumina/src/main/java/org/apache/paimon/lumina/index/LuminaVectorGlobalIndexWriter.java: ########## @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.lumina.index; + +import org.apache.paimon.data.InternalArray; +import org.apache.paimon.fs.PositionOutputStream; +import org.apache.paimon.globalindex.GlobalIndexSingletonWriter; +import org.apache.paimon.globalindex.ResultEntry; +import org.apache.paimon.globalindex.io.GlobalIndexFileWriter; +import org.apache.paimon.types.ArrayType; +import org.apache.paimon.types.DataType; +import org.apache.paimon.types.FloatType; + +import org.aliyun.lumina.LuminaFileOutput; + +import java.io.Closeable; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.FloatBuffer; +import java.nio.LongBuffer; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; + +/** + * Vector global index writer using Lumina. + * + * <p>Vectors are collected until the current index reaches {@code sizePerIndex} vectors, then + * pretrained, inserted in a single batch, and dumped to a file. 
DiskANN requires exactly one + * pretrain and one insertBatch call per index. + * + * <p>Each written vector is assigned a monotonically increasing 64-bit row ID ({@code count}) that + * spans across all produced index files. The second index file's IDs therefore start from {@code + * sizePerIndex}, not from 0. The min/max IDs stored in {@link LuminaIndexMeta} reflect this global + * range, enabling the reader to skip index files that have no overlap with a given filter set. + */ +public class LuminaVectorGlobalIndexWriter implements GlobalIndexSingletonWriter, Closeable { + + private final GlobalIndexFileWriter fileWriter; + private final LuminaVectorIndexOptions options; + private final int sizePerIndex; + private final int dim; + + private long count = 0; // monotonically increasing global row ID across all index files + private long currentIndexMinId = Long.MAX_VALUE; + private long currentIndexMaxId = Long.MIN_VALUE; + private ByteBuffer pendingVectors; + private ByteBuffer pendingIds; + private FloatBuffer pendingFloatView; + private LongBuffer pendingLongView; + private int pendingCount = 0; + private final List<ResultEntry> results; + + public LuminaVectorGlobalIndexWriter( + GlobalIndexFileWriter fileWriter, + DataType fieldType, + LuminaVectorIndexOptions options) { + this.fileWriter = fileWriter; + this.options = options; + this.dim = options.dimension(); + int configuredSize = options.sizePerIndex(); + long buildMemoryLimit = options.buildMemoryLimit(); + int maxByDim = + (int) Math.min(configuredSize, buildMemoryLimit / ((long) dim * Float.BYTES)); + this.sizePerIndex = Math.max(maxByDim, 1); + this.pendingVectors = LuminaIndex.allocateVectorBuffer(sizePerIndex, dim); + this.pendingIds = LuminaIndex.allocateIdBuffer(sizePerIndex); + this.pendingFloatView = pendingVectors.asFloatBuffer(); + this.pendingLongView = pendingIds.asLongBuffer(); + this.results = new ArrayList<>(); + + validateFieldType(fieldType); + } + + private void 
validateFieldType(DataType dataType) { + if (!(dataType instanceof ArrayType)) { + throw new IllegalArgumentException( + "Lumina vector index requires ArrayType, but got: " + dataType); + } + DataType elementType = ((ArrayType) dataType).getElementType(); + if (!(elementType instanceof FloatType)) { + throw new IllegalArgumentException( + "Lumina vector index requires float array, but got: " + elementType); + } + } + + @Override + public void write(Object fieldData) { + float[] vector; + if (fieldData instanceof float[]) { + vector = (float[]) fieldData; + } else if (fieldData instanceof InternalArray) { + vector = ((InternalArray) fieldData).toFloatArray(); + } else { + throw new RuntimeException( + "Unsupported vector type: " + fieldData.getClass().getName()); + } + checkDimension(vector); + if (options.normalize()) { + LuminaVectorUtils.normalizeL2(vector); + } + currentIndexMinId = Math.min(currentIndexMinId, count); + currentIndexMaxId = Math.max(currentIndexMaxId, count); + int offset = pendingCount * dim; + for (int i = 0; i < dim; i++) { + pendingFloatView.put(offset + i, vector[i]); + } + pendingLongView.put(pendingCount, count); + pendingCount++; + count++; + + try { + if (pendingCount >= sizePerIndex) { + buildAndFlushIndex(); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public List<ResultEntry> finish() { + try { + if (pendingCount > 0) { + buildAndFlushIndex(); + } + return results; + } catch (IOException e) { + throw new RuntimeException("Failed to write Lumina vector global index", e); + } + } + + /** + * Build a complete DiskANN index from the current pending batch: create index, pretrain, insert + * all vectors in a single batch, dump directly to the output stream, and close. 
+ */ + private void buildAndFlushIndex() throws IOException { + if (pendingCount == 0) { + return; + } + + int n = pendingCount; + LuminaIndex index = createIndex(); + + try { + int trainingSize = Math.min(n, options.trainingSize()); + int[] sampleIndices = reservoirSample(n, trainingSize); + ByteBuffer trainingBuffer = LuminaIndex.allocateVectorBuffer(trainingSize, dim); + FloatBuffer trainingFloatView = trainingBuffer.asFloatBuffer(); + for (int i = 0; i < trainingSize; i++) { + int srcOffset = sampleIndices[i] * dim; + for (int j = 0; j < dim; j++) { + trainingFloatView.put(i * dim + j, pendingFloatView.get(srcOffset + j)); + } + } + index.pretrain(trainingBuffer, trainingSize); + trainingBuffer = null; + trainingFloatView = null; + + index.insertBatch(pendingVectors, pendingIds, n); + + String fileName = fileWriter.newFileName(LuminaVectorGlobalIndexerFactory.IDENTIFIER); + try (PositionOutputStream out = fileWriter.newOutputStream(fileName)) { + index.dump(new OutputStreamFileOutput(out)); + out.flush(); + } + + LuminaIndexMeta meta = + new LuminaIndexMeta( + dim, + options.metric().getValue(), + options.indexType().name(), + n, + currentIndexMinId, + currentIndexMaxId); + results.add(new ResultEntry(fileName, n, meta.serialize())); + } finally { + index.close(); + } + + pendingCount = 0; + currentIndexMinId = Long.MAX_VALUE; + currentIndexMaxId = Long.MIN_VALUE; + } + + private LuminaIndex createIndex() { + Map<String, String> extraOptions = new LinkedHashMap<>(); + extraOptions.put("encoding.type", options.encodingType()); + + if (options.pretrainSampleRatio() != 1.0) { + extraOptions.put( + "pretrain.sample_ratio", String.valueOf(options.pretrainSampleRatio())); + } + + if (options.diskannEfConstruction() != null) { + extraOptions.put( + "diskann.build.ef_construction", + String.valueOf(options.diskannEfConstruction())); + } + if (options.diskannNeighborCount() != null) { + extraOptions.put( + "diskann.build.neighbor_count", 
String.valueOf(options.diskannNeighborCount())); + } + if (options.diskannBuildThreadCount() != null) { + extraOptions.put( + "diskann.build.thread_count", + String.valueOf(options.diskannBuildThreadCount())); + } + + return LuminaIndex.createForBuild(dim, options.metric(), options.indexType(), extraOptions); + } + + /** + * Selects {@code k} indices from [0, n) using reservoir sampling (Algorithm R). + * + * <p>When {@code k >= n} all indices are returned in order. Otherwise a random representative + * subset is chosen, ensuring training data covers the full vector distribution instead of being + * biased toward the first {@code k} inserted vectors. + */ + private static int[] reservoirSample(int n, int k) { + int[] reservoir = new int[k]; + for (int i = 0; i < k; i++) { + reservoir[i] = i; + } + if (k < n) { + Random random = new Random(42); + for (int i = k; i < n; i++) { + int j = random.nextInt(i + 1); + if (j < k) { + reservoir[j] = i; + } + } + } + return reservoir; + } + + private void checkDimension(float[] vector) { + if (vector.length != dim) { + throw new IllegalArgumentException( + String.format( + "Vector dimension mismatch: expected %d, but got %d", + dim, vector.length)); + } + } + + @Override + public void close() { + pendingVectors = null; + pendingIds = null; + pendingFloatView = null; + pendingLongView = null; + pendingCount = 0; + } + + /** Adapts a {@link PositionOutputStream} to the {@link LuminaFileOutput} JNI callback API. */ + private static class OutputStreamFileOutput implements LuminaFileOutput { + private final OutputStream out; + + OutputStreamFileOutput(OutputStream out) { + this.out = out; + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + out.write(b, off, len); + } + + @Override + public void flush() throws IOException { + out.flush(); + } + + @Override + public long getPos() { + return -1; + } Review Comment: `LuminaFileOutput#getPos()` returning `-1` is very likely incorrect. 
If Lumina uses `getPos()` to track the write offset (common for streaming/native dump implementations), this can corrupt the dump or trigger undefined behavior. Consider delegating to the underlying `PositionOutputStream` position (e.g., store it as a `PositionOutputStream` field and return its current position), or implement position tracking in this adapter. ########## paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/LuminaVectorIndexTest.scala: ########## @@ -0,0 +1,391 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.sql + +import org.apache.paimon.spark.PaimonSparkTestBase + +import scala.collection.JavaConverters._ + +/** Tests for Lumina vector index read/write operations. 
*/ +class LuminaVectorIndexTest extends PaimonSparkTestBase { + + private val indexType = "lumina-vector-ann" + private val defaultOptions = "vector.dim=3,vector.index-type=DISKANN" + + // ========== Index Creation Tests ========== + + test("create lumina vector index - basic") { + withTable("T") { + spark.sql(""" + |CREATE TABLE T (id INT, v ARRAY<FLOAT>) + |TBLPROPERTIES ( + | 'bucket' = '-1', + | 'global-index.row-count-per-shard' = '10000', + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + |""".stripMargin) + + val values = (0 until 100) + .map( + i => s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)))") + .mkString(",") + spark.sql(s"INSERT INTO T VALUES $values") + + val output = spark + .sql( + s"CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => '$indexType', options => '$defaultOptions')") + .collect() + .head + assert(output.getBoolean(0)) + + val table = loadTable("T") + val indexEntries = table + .store() + .newIndexFileHandler() + .scanEntries() + .asScala + .filter(_.indexFile().indexType() == indexType) + + assert(indexEntries.nonEmpty) + val totalRowCount = indexEntries.map(_.indexFile().rowCount()).sum + assert(totalRowCount == 100L) + } + } + + test("create lumina vector index - with different index types") { + withTable("T") { + spark.sql(""" + |CREATE TABLE T (id INT, v ARRAY<FLOAT>) + |TBLPROPERTIES ( + | 'bucket' = '-1', + | 'global-index.row-count-per-shard' = '10000', + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + |""".stripMargin) + + val values = (0 until 50) + .map( + i => s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)))") + .mkString(",") + spark.sql(s"INSERT INTO T VALUES $values") + + val output = spark + .sql( + s"CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => '$indexType', options => '$defaultOptions')") + .collect() + .head + 
assert(output.getBoolean(0)) + + val table = loadTable("T") + val indexEntries = table + .store() + .newIndexFileHandler() + .scanEntries() + .asScala + .filter(_.indexFile().indexType() == indexType) + + assert(indexEntries.nonEmpty) + } + } + + test("create lumina vector index - with partitioned table") { + withTable("T") { + spark.sql(""" + |CREATE TABLE T (id INT, v ARRAY<FLOAT>, pt STRING) + |TBLPROPERTIES ( + | 'bucket' = '-1', + | 'global-index.row-count-per-shard' = '10000', + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + | PARTITIONED BY (pt) + |""".stripMargin) + + var values = (0 until 500) + .map( + i => + s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)), 'p0')") + .mkString(",") + spark.sql(s"INSERT INTO T VALUES $values") + + values = (0 until 300) + .map( + i => + s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)), 'p1')") + .mkString(",") + spark.sql(s"INSERT INTO T VALUES $values") + + val output = spark + .sql( + s"CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => '$indexType', options => '$defaultOptions')") + .collect() + .head + assert(output.getBoolean(0)) + + val table = loadTable("T") + val indexEntries = table + .store() + .newIndexFileHandler() + .scanEntries() + .asScala + .filter(_.indexFile().indexType() == indexType) + + assert(indexEntries.nonEmpty) + val totalRowCount = indexEntries.map(_.indexFile().rowCount()).sum + assert(totalRowCount == 800L) + } + } + + // ========== Index Write Tests ========== + + test("write vectors - large dataset") { + withTable("T") { + spark.sql(""" + |CREATE TABLE T (id INT, v ARRAY<FLOAT>) + |TBLPROPERTIES ( + | 'bucket' = '-1', + | 'global-index.row-count-per-shard' = '10000', + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + |""".stripMargin) + + val values = (0 until 10000) + .map( + i => s"($i, array(cast($i as float), cast(${i + 1} as float), 
cast(${i + 2} as float)))") + .mkString(",") + spark.sql(s"INSERT INTO T VALUES $values") Review Comment: This test builds a single SQL `INSERT INTO ... VALUES ...` statement containing 10,000 rows. That can be slow/flaky (very large SQL string, parser overhead, potential SQL length limits) and can dominate test runtime. Prefer writing data via a DataFrame/Dataset (e.g., `spark.createDataFrame`/`spark.range` + `selectExpr`) or batching inserts into smaller chunks. ```suggestion val df = spark .range(0, 10000) .selectExpr( "cast(id as int) as id", "array(cast(id as float), cast(id + 1 as float), cast(id + 2 as float)) as v") df.write.insertInto("T") ``` ########## paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/LuminaVectorIndexTest.scala: ########## @@ -0,0 +1,391 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.sql + +import org.apache.paimon.spark.PaimonSparkTestBase + +import scala.collection.JavaConverters._ + +/** Tests for Lumina vector index read/write operations. 
*/ +class LuminaVectorIndexTest extends PaimonSparkTestBase { + + private val indexType = "lumina-vector-ann" + private val defaultOptions = "vector.dim=3,vector.index-type=DISKANN" + + // ========== Index Creation Tests ========== + + test("create lumina vector index - basic") { + withTable("T") { + spark.sql(""" + |CREATE TABLE T (id INT, v ARRAY<FLOAT>) + |TBLPROPERTIES ( + | 'bucket' = '-1', + | 'global-index.row-count-per-shard' = '10000', + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + |""".stripMargin) + + val values = (0 until 100) + .map( + i => s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)))") + .mkString(",") + spark.sql(s"INSERT INTO T VALUES $values") + + val output = spark + .sql( + s"CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => '$indexType', options => '$defaultOptions')") + .collect() + .head + assert(output.getBoolean(0)) + + val table = loadTable("T") + val indexEntries = table + .store() + .newIndexFileHandler() + .scanEntries() + .asScala + .filter(_.indexFile().indexType() == indexType) + + assert(indexEntries.nonEmpty) + val totalRowCount = indexEntries.map(_.indexFile().rowCount()).sum + assert(totalRowCount == 100L) + } + } + + test("create lumina vector index - with different index types") { + withTable("T") { + spark.sql(""" + |CREATE TABLE T (id INT, v ARRAY<FLOAT>) + |TBLPROPERTIES ( + | 'bucket' = '-1', + | 'global-index.row-count-per-shard' = '10000', + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + |""".stripMargin) + + val values = (0 until 50) + .map( + i => s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)))") + .mkString(",") + spark.sql(s"INSERT INTO T VALUES $values") + + val output = spark + .sql( + s"CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => '$indexType', options => '$defaultOptions')") Review Comment: The test name says “with 
different index types”, but the test uses a single `indexType` (`lumina-vector-ann`) and `defaultOptions` specifies only `DISKANN`. Either vary the index configuration(s) exercised by the test or rename the test to reflect what it actually validates. ########## paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/LuminaVectorIndexTest.scala: ########## @@ -0,0 +1,391 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.sql + +import org.apache.paimon.spark.PaimonSparkTestBase + +import scala.collection.JavaConverters._ + +/** Tests for Lumina vector index read/write operations. 
*/ +class LuminaVectorIndexTest extends PaimonSparkTestBase { + + private val indexType = "lumina-vector-ann" + private val defaultOptions = "vector.dim=3,vector.index-type=DISKANN" + + // ========== Index Creation Tests ========== + + test("create lumina vector index - basic") { + withTable("T") { + spark.sql(""" + |CREATE TABLE T (id INT, v ARRAY<FLOAT>) + |TBLPROPERTIES ( + | 'bucket' = '-1', + | 'global-index.row-count-per-shard' = '10000', + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + |""".stripMargin) + + val values = (0 until 100) + .map( + i => s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)))") + .mkString(",") + spark.sql(s"INSERT INTO T VALUES $values") + + val output = spark + .sql( + s"CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => '$indexType', options => '$defaultOptions')") + .collect() + .head + assert(output.getBoolean(0)) + + val table = loadTable("T") + val indexEntries = table + .store() + .newIndexFileHandler() + .scanEntries() + .asScala + .filter(_.indexFile().indexType() == indexType) + + assert(indexEntries.nonEmpty) + val totalRowCount = indexEntries.map(_.indexFile().rowCount()).sum + assert(totalRowCount == 100L) + } + } + + test("create lumina vector index - with different index types") { + withTable("T") { + spark.sql(""" + |CREATE TABLE T (id INT, v ARRAY<FLOAT>) + |TBLPROPERTIES ( + | 'bucket' = '-1', + | 'global-index.row-count-per-shard' = '10000', + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + |""".stripMargin) Review Comment: The test name says “with different index types”, but the test uses a single `indexType` (`lumina-vector-ann`) and `defaultOptions` specifies only `DISKANN`. Either vary the index configuration(s) exercised by the test or rename the test to reflect what it actually validates. 
########## paimon-lumina/src/main/java/org/apache/paimon/lumina/index/LuminaVectorGlobalIndexReader.java: ########## @@ -0,0 +1,531 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.lumina.index; + +import org.apache.paimon.fs.SeekableInputStream; +import org.apache.paimon.globalindex.GlobalIndexIOMeta; +import org.apache.paimon.globalindex.GlobalIndexReader; +import org.apache.paimon.globalindex.GlobalIndexResult; +import org.apache.paimon.globalindex.io.GlobalIndexFileReader; +import org.apache.paimon.predicate.FieldRef; +import org.apache.paimon.predicate.VectorSearch; +import org.apache.paimon.types.ArrayType; +import org.apache.paimon.types.DataType; +import org.apache.paimon.types.FloatType; +import org.apache.paimon.utils.IOUtils; +import org.apache.paimon.utils.RoaringNavigableMap64; + +import org.aliyun.lumina.LuminaFileInput; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.PriorityQueue; + +/** + * Vector global index reader 
using Lumina. + * + * <p>This reader loads Lumina indices from global index files and performs vector similarity + * search. + */ +public class LuminaVectorGlobalIndexReader implements GlobalIndexReader { + + private final LuminaIndex[] indices; + private final LuminaIndexMeta[] indexMetas; + private final List<SeekableInputStream> openStreams; + private final List<GlobalIndexIOMeta> ioMetas; + private final GlobalIndexFileReader fileReader; + private final DataType fieldType; + private final LuminaVectorIndexOptions options; + private volatile boolean metasLoaded = false; + private volatile boolean indicesLoaded = false; + + public LuminaVectorGlobalIndexReader( + GlobalIndexFileReader fileReader, + List<GlobalIndexIOMeta> ioMetas, + DataType fieldType, + LuminaVectorIndexOptions options) { + this.fileReader = fileReader; + this.ioMetas = ioMetas; + this.fieldType = fieldType; + this.options = options; + this.indices = new LuminaIndex[ioMetas.size()]; + this.indexMetas = new LuminaIndexMeta[ioMetas.size()]; + this.openStreams = Collections.synchronizedList(new ArrayList<>()); + } + + @Override + public Optional<GlobalIndexResult> visitVectorSearch(VectorSearch vectorSearch) { + try { + ensureLoadMetas(); + + RoaringNavigableMap64 includeRowIds = vectorSearch.includeRowIds(); + + if (includeRowIds != null) { + List<Integer> matchingIndices = new ArrayList<>(); + for (int i = 0; i < indexMetas.length; i++) { + LuminaIndexMeta meta = indexMetas[i]; + if (includeRowIds.containsRange(meta.minId(), meta.maxId())) { + matchingIndices.add(i); + } + } + if (matchingIndices.isEmpty()) { + return Optional.empty(); + } + ensureLoadIndices(matchingIndices); + } else { + ensureLoadAllIndices(); + } + + return Optional.ofNullable(search(vectorSearch)); + } catch (IOException e) { + throw new RuntimeException( + String.format( + "Failed to search Lumina vector index with fieldName=%s, limit=%d", + vectorSearch.fieldName(), vectorSearch.limit()), + e); + } + } + + private 
GlobalIndexResult search(VectorSearch vectorSearch) throws IOException { + validateVectorType(vectorSearch.vector()); + float[] queryVector = ((float[]) vectorSearch.vector()).clone(); + if (options.normalize()) { + LuminaVectorUtils.normalizeL2(queryVector); + } + int limit = vectorSearch.limit(); + + PriorityQueue<ScoredRow> result = + new PriorityQueue<>(Comparator.comparingDouble(sr -> sr.score)); + + RoaringNavigableMap64 includeRowIds = vectorSearch.includeRowIds(); + + // Extract sorted filter IDs once; per-index scoping happens inside the loop. + long[] allFilterIds = null; + if (includeRowIds != null) { + long cardinality = includeRowIds.getLongCardinality(); + if (cardinality <= 0) { + return new LuminaScoredGlobalIndexResult( + new RoaringNavigableMap64(), new HashMap<>()); + } + if (cardinality > Integer.MAX_VALUE) { + throw new RuntimeException( + "Filter bitmap cardinality (" + + cardinality + + ") exceeds maximum supported size for native pre-filtering"); + } + allFilterIds = new long[(int) cardinality]; + int idx = 0; + for (long id : includeRowIds) { + allFilterIds[idx++] = id; + } + } + + Map<String, String> filterSearchOptions = null; + Map<String, String> plainSearchOptions = null; + if (allFilterIds != null) { + filterSearchOptions = new LinkedHashMap<>(); + int listSize = Math.max(limit * options.searchFactor(), options.searchListSize()); + filterSearchOptions.put("diskann.search.list_size", String.valueOf(listSize)); + filterSearchOptions.put("search.thread_safe_filter", "true"); + } else { + plainSearchOptions = new LinkedHashMap<>(); + int listSize = Math.max(limit, options.searchListSize()); + plainSearchOptions.put("diskann.search.list_size", String.valueOf(listSize)); + } + + for (int i = 0; i < indices.length; i++) { + LuminaIndex index = indices[i]; + if (index == null) { + continue; + } + + int effectiveK = (int) Math.min(limit, index.size()); + if (effectiveK <= 0) { + continue; + } + + if (allFilterIds != null) { + LuminaIndexMeta 
meta = indexMetas[i]; + long[] scopedIds = scopeFilterIds(allFilterIds, meta.minId(), meta.maxId()); + if (scopedIds.length == 0) { + continue; + } + effectiveK = (int) Math.min(effectiveK, scopedIds.length); + + float[] distances = new float[effectiveK]; + long[] labels = new long[effectiveK]; + index.searchWithFilter( + queryVector, + 1, + effectiveK, + distances, + labels, + scopedIds, + filterSearchOptions); + collectResults(distances, labels, effectiveK, limit, result); + } else { + float[] distances = new float[effectiveK]; + long[] labels = new long[effectiveK]; + index.search(queryVector, 1, effectiveK, distances, labels, plainSearchOptions); + collectResults(distances, labels, effectiveK, limit, result); + } + } + + RoaringNavigableMap64 roaringBitmap64 = new RoaringNavigableMap64(); + HashMap<Long, Float> id2scores = new HashMap<>(result.size()); + for (ScoredRow scoredRow : result) { + id2scores.put(scoredRow.rowId, scoredRow.score); + roaringBitmap64.add(scoredRow.rowId); + } + return new LuminaScoredGlobalIndexResult(roaringBitmap64, id2scores); + } + + private void collectResults( + float[] distances, + long[] labels, + int count, + int limit, + PriorityQueue<ScoredRow> result) { + for (int i = 0; i < count; i++) { + long rowId = labels[i]; + if (rowId < 0) { + continue; + } + float score = convertDistanceToScore(distances[i]); + if (result.size() < limit) { + result.offer(new ScoredRow(rowId, score)); + } else if (result.peek() != null && score > result.peek().score) { + result.poll(); + result.offer(new ScoredRow(rowId, score)); + } + } + } + + /** + * Extract the subset of {@code sortedIds} that falls within [{@code minId}, {@code maxId}] + * using binary search. The input array must be sorted in ascending order (guaranteed by roaring + * bitmap iteration order). 
+ */ + private static long[] scopeFilterIds(long[] sortedIds, long minId, long maxId) { + int from = lowerBound(sortedIds, minId); + int to = upperBound(sortedIds, maxId); + if (from >= to) { + return new long[0]; + } + if (from == 0 && to == sortedIds.length) { + return sortedIds; + } + return Arrays.copyOfRange(sortedIds, from, to); + } + + /** Return the index of the first element >= target. */ + private static int lowerBound(long[] arr, long target) { + int lo = 0; + int hi = arr.length; + while (lo < hi) { + int mid = (lo + hi) >>> 1; + if (arr[mid] < target) { + lo = mid + 1; + } else { + hi = mid; + } + } + return lo; + } + + /** Return the index of the first element > target. */ + private static int upperBound(long[] arr, long target) { + int lo = 0; + int hi = arr.length; + while (lo < hi) { + int mid = (lo + hi) >>> 1; + if (arr[mid] <= target) { + lo = mid + 1; + } else { + hi = mid; + } + } + return lo; + } + + private float convertDistanceToScore(float distance) { + if (options.metric() == LuminaVectorMetric.L2) { + return 1.0f / (1.0f + distance); + } else if (options.metric() == LuminaVectorMetric.COSINE) { + // Cosine distance is in [0, 2]; convert to similarity in [-1, 1] + return 1.0f - distance; + } else { + // Inner product is already a similarity + return distance; + } + } + + private void validateVectorType(Object vector) { + if (!(vector instanceof float[])) { + throw new IllegalArgumentException( + "Expected float[] vector but got: " + vector.getClass()); + } + if (!(fieldType instanceof ArrayType) + || !(((ArrayType) fieldType).getElementType() instanceof FloatType)) { + throw new IllegalArgumentException( + "Lumina currently only supports float arrays, but field type is: " + fieldType); + } + } + + private void ensureLoadMetas() throws IOException { + if (!metasLoaded) { + synchronized (this) { + if (!metasLoaded) { + for (int i = 0; i < ioMetas.size(); i++) { + byte[] metaBytes = ioMetas.get(i).metadata(); + indexMetas[i] = 
LuminaIndexMeta.deserialize(metaBytes); + } + metasLoaded = true; + } + } + } + } + + private void ensureLoadAllIndices() throws IOException { + if (!indicesLoaded) { + synchronized (this) { + if (!indicesLoaded) { + for (int i = 0; i < ioMetas.size(); i++) { + if (indices[i] == null) { + loadIndexAt(i); + } + } + indicesLoaded = true; + } + } + } + } + + private void ensureLoadIndices(List<Integer> positions) throws IOException { + synchronized (this) { + for (int pos : positions) { + if (indices[pos] == null) { + loadIndexAt(pos); + } + } + // Check if all indices are now loaded. + if (!indicesLoaded) { + boolean allLoaded = true; + for (LuminaIndex idx : indices) { + if (idx == null) { + allLoaded = false; + break; + } + } + if (allLoaded) { + indicesLoaded = true; + } + } + } + } + + private void loadIndexAt(int position) throws IOException { + GlobalIndexIOMeta ioMeta = ioMetas.get(position); + SeekableInputStream in = fileReader.getInputStream(ioMeta); + LuminaIndex index = null; + try { + index = loadIndex(in, position, ioMeta.fileSize()); + openStreams.add(in); + indices[position] = index; + } catch (Exception e) { + IOUtils.closeQuietly(index); + IOUtils.closeQuietly(in); + throw e; + } + } + + private LuminaIndex loadIndex(SeekableInputStream in, int position, long fileSize) + throws IOException { + LuminaIndexMeta meta = indexMetas[position]; + LuminaVectorMetric metric = LuminaVectorMetric.fromValue(meta.metricValue()); + LuminaIndexType indexType = meta.indexType(); + + LuminaFileInput fileInput = + new LuminaFileInput() { + @Override + public int read(byte[] b, int off, int len) throws IOException { + return in.read(b, off, len); + } + + @Override + public void seek(long position) throws IOException { + in.seek(position); + } + + @Override + public long getPos() throws IOException { + return in.getPos(); + } + + @Override + public void close() throws IOException { + // Do not close: stream lifecycle is managed by the Reader. 
+ } + }; + + Map<String, String> extraOptions = new LinkedHashMap<>(); + return LuminaIndex.fromStream( + fileInput, fileSize, meta.dim(), metric, indexType, extraOptions); + } + + @Override + public void close() throws IOException { + Throwable firstException = null; + + // Close all indices first (releases native FileReader → JNI global refs). + for (int i = 0; i < indices.length; i++) { + LuminaIndex index = indices[i]; + if (index == null) { + continue; + } + try { + index.close(); + } catch (Throwable t) { + if (firstException == null) { + firstException = t; + } else { + firstException.addSuppressed(t); + } + } + indices[i] = null; + } + + // Then close underlying streams. + for (SeekableInputStream stream : openStreams) { + try { + if (stream != null) { + stream.close(); + } + } catch (Throwable t) { + if (firstException == null) { + firstException = t; + } else { + firstException.addSuppressed(t); + } + } + } + openStreams.clear(); Review Comment: `openStreams` is a `Collections.synchronizedList(...)`, but iterating it without synchronizing on the list can still race with concurrent `openStreams.add(in)` in `loadIndexAt`. If this reader can be used concurrently (search in one thread while another thread closes), consider synchronizing on `openStreams` during iteration/clear, or replace with a concurrency-friendly structure (e.g., `CopyOnWriteArrayList`) or close streams via a snapshot copy. ```suggestion List<SeekableInputStream> streamsSnapshot; synchronized (openStreams) { if (openStreams.isEmpty()) { streamsSnapshot = null; } else { streamsSnapshot = new ArrayList<>(openStreams); openStreams.clear(); } } if (streamsSnapshot != null) { for (SeekableInputStream stream : streamsSnapshot) { try { if (stream != null) { stream.close(); } } catch (Throwable t) { if (firstException == null) { firstException = t; } else { firstException.addSuppressed(t); } } } } ``` -- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected]. For queries about this service, please contact Infrastructure at: [email protected]
