This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-vector-index.git
The following commit(s) were added to refs/heads/main by this push:
new 19dad70 Optimize vector index reads with batched pread (#26)
19dad70 is described below
commit 19dad70e4cafb0009cd32b467aaaf718488eae5b
Author: Jingsong Lee <[email protected]>
AuthorDate: Wed Jun 10 14:55:30 2026 +0800
Optimize vector index reads with batched pread (#26)
---
README.md | 11 +-
core/src/index.rs | 6 +-
core/src/index_io_util.rs | 19 +-
core/src/io.rs | 397 ++++++++++++++------
core/src/ivfflat_io.rs | 44 ++-
core/src/ivfhnswflat_io.rs | 405 +++++++++++++++++----
core/src/ivfhnswsq_io.rs | 404 +++++++++++++++++---
core/src/ivfpq.rs | 238 ++++++------
.../paimon/index/ivfpq/VectorIndexJavaApiTest.java | 7 +-
.../ivfpq/VectorIndexNativePanicBoundaryTest.java | 40 +-
.../paimon/index/ivfpq/VectorIndexInput.java | 23 ++
.../paimon/index/ivfpq/VectorIndexReader.java | 2 +-
jni/src/stream.rs | 237 ++++--------
python/src/lib.rs | 151 ++++++--
14 files changed, 1385 insertions(+), 599 deletions(-)
diff --git a/README.md b/README.md
index 66e93d9..13f202c 100644
--- a/README.md
+++ b/README.md
@@ -92,14 +92,21 @@ Java:
VectorIndexConfig config =
VectorIndexConfig.ivfHnswFlat(128, 1024, Metric.L2,
HnswConfig.DEFAULT);
VectorIndexWriter writer = new VectorIndexWriter(config);
-VectorIndexReader reader = new VectorIndexReader(input);
+VectorIndexReader reader = new VectorIndexReader(vectorIndexInput);
```
Python:
```python
+class VectorIndexInput:
+ def __init__(self, data: bytes):
+ self.data = data
+
+ def pread_many(self, ranges):
+ return [self.data[pos : pos + length] for pos, length in ranges]
+
writer = VectorIndexWriter(IvfPqConfig(128, 1024, 16, metric="l2"))
-reader = VectorIndexReader(file)
+reader = VectorIndexReader(VectorIndexInput(index_bytes))
ids, distances = reader.search(query, top_k=10, nprobe=16)
```
diff --git a/core/src/index.rs b/core/src/index.rs
index 674fb51..4b9aae6 100644
--- a/core/src/index.rs
+++ b/core/src/index.rs
@@ -17,7 +17,7 @@
use crate::distance::MetricType;
use crate::hnsw::HnswBuildParams;
-use crate::io::{write_index, IVFPQIndexReader, SeekRead, SeekWrite, MAGIC};
+use crate::io::{write_index, IVFPQIndexReader, ReadRequest, SeekRead,
SeekWrite, MAGIC};
use crate::ivfflat::IVFFlatIndex;
use crate::ivfflat_io::{
search_batch_ivfflat_reader, search_batch_ivfflat_reader_roaring_filter,
write_ivfflat_index,
@@ -293,11 +293,9 @@ pub enum VectorIndexReader<R: SeekRead> {
impl<R: SeekRead> VectorIndexReader<R> {
pub fn open(mut reader: R) -> io::Result<Self> {
- reader.seek(0)?;
let mut magic_buf = [0u8; 4];
- reader.read_exact(&mut magic_buf)?;
+ reader.pread(&mut [ReadRequest::new(0, &mut magic_buf)])?;
let magic = u32::from_le_bytes(magic_buf);
- reader.seek(0)?;
match magic {
IVFFLAT_MAGIC =>
Ok(Self::IvfFlat(IVFFlatIndexReader::open(reader)?)),
diff --git a/core/src/index_io_util.rs b/core/src/index_io_util.rs
index 292953b..d7d1fb1 100644
--- a/core/src/index_io_util.rs
+++ b/core/src/index_io_util.rs
@@ -17,7 +17,7 @@
use crate::distance::MetricType;
use crate::hnsw::{HnswBuildParams, HnswGraph};
-use crate::io::{SeekRead, SeekWrite};
+use crate::io::{PreadCursor, SeekRead, SeekWrite};
use roaring::RoaringTreemap;
use std::io;
@@ -219,25 +219,34 @@ pub(crate) fn write_f32_slice(out: &mut dyn SeekWrite,
data: &[f32]) -> io::Resu
out.write_all(&bytes)
}
-pub(crate) fn read_u32_le(reader: &mut dyn SeekRead) -> io::Result<u32> {
+pub(crate) fn read_u32_le<R: SeekRead + ?Sized>(
+ reader: &mut PreadCursor<'_, R>,
+) -> io::Result<u32> {
let mut buf = [0u8; 4];
reader.read_exact(&mut buf)?;
Ok(u32::from_le_bytes(buf))
}
-pub(crate) fn read_i32_le(reader: &mut dyn SeekRead) -> io::Result<i32> {
+pub(crate) fn read_i32_le<R: SeekRead + ?Sized>(
+ reader: &mut PreadCursor<'_, R>,
+) -> io::Result<i32> {
let mut buf = [0u8; 4];
reader.read_exact(&mut buf)?;
Ok(i32::from_le_bytes(buf))
}
-pub(crate) fn read_i64_le(reader: &mut dyn SeekRead) -> io::Result<i64> {
+pub(crate) fn read_i64_le<R: SeekRead + ?Sized>(
+ reader: &mut PreadCursor<'_, R>,
+) -> io::Result<i64> {
let mut buf = [0u8; 8];
reader.read_exact(&mut buf)?;
Ok(i64::from_le_bytes(buf))
}
-pub(crate) fn read_f32_vec(reader: &mut dyn SeekRead, count: usize) ->
io::Result<Vec<f32>> {
+pub(crate) fn read_f32_vec<R: SeekRead + ?Sized>(
+ reader: &mut PreadCursor<'_, R>,
+ count: usize,
+) -> io::Result<Vec<f32>> {
let byte_len = count.checked_mul(4).ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
diff --git a/core/src/io.rs b/core/src/io.rs
index dc8d8bc..7ce857f 100644
--- a/core/src/io.rs
+++ b/core/src/io.rs
@@ -30,21 +30,47 @@ pub const FLAG_BY_RESIDUAL: u32 = 1 << 1;
pub const FLAG_DELTA_IDS: u32 = 1 << 2;
pub const FLAG_TRANSPOSED_CODES: u32 = 1 << 3;
+pub struct ReadRequest<'a> {
+ pub pos: u64,
+ pub buf: &'a mut [u8],
+}
+
+impl<'a> ReadRequest<'a> {
+ pub fn new(pos: u64, buf: &'a mut [u8]) -> Self {
+ Self { pos, buf }
+ }
+}
+
pub trait SeekRead: Send {
- fn seek(&mut self, pos: u64) -> io::Result<()>;
- fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()>;
-
- /// Positional read: read `buf.len()` bytes at `pos` without changing the
cursor.
- /// Thread-safe if the underlying implementation supports it (e.g.,
pread(2)).
- /// Default implementation falls back to seek + read_exact.
- fn pread(&mut self, pos: u64, buf: &mut [u8]) -> io::Result<()> {
- self.seek(pos)?;
- self.read_exact(buf)
+ /// Positional reads for one or more ranges.
+ ///
+ /// Implementations may execute requests sequentially, coalesce them, or
issue
+ /// them concurrently when the underlying source supports independent
+ /// positional reads.
+ fn pread(&mut self, ranges: &mut [ReadRequest<'_>]) -> io::Result<()>;
+}
+
+pub(crate) struct PreadCursor<'a, R: SeekRead + ?Sized> {
+ reader: &'a mut R,
+ pos: u64,
+}
+
+impl<'a, R: SeekRead + ?Sized> PreadCursor<'a, R> {
+ pub(crate) fn new(reader: &'a mut R, pos: u64) -> Self {
+ Self { reader, pos }
}
- /// Whether this implementation supports true concurrent pread (no shared
cursor).
- fn supports_concurrent_pread(&self) -> bool {
- false
+ pub(crate) fn seek(&mut self, pos: u64) {
+ self.pos = pos;
+ }
+
+ pub(crate) fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
+ self.reader.pread(&mut [ReadRequest::new(self.pos, buf)])?;
+ self.pos = self
+ .pos
+ .checked_add(buf.len() as u64)
+ .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "read
cursor overflow"))?;
+ Ok(())
}
}
@@ -54,14 +80,15 @@ pub trait SeekWrite: Send {
}
impl<T: io::Read + io::Seek + Send> SeekRead for T {
- fn seek(&mut self, pos: u64) -> io::Result<()> {
- io::Seek::seek(self, io::SeekFrom::Start(pos))?;
+ fn pread(&mut self, ranges: &mut [ReadRequest<'_>]) -> io::Result<()> {
+ let old_pos = io::Seek::stream_position(self)?;
+ for range in ranges {
+ io::Seek::seek(self, io::SeekFrom::Start(range.pos))?;
+ io::Read::read_exact(self, range.buf)?;
+ }
+ io::Seek::seek(self, io::SeekFrom::Start(old_pos))?;
Ok(())
}
-
- fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
- io::Read::read_exact(self, buf)
- }
}
pub struct PosWriter<W: io::Write> {
@@ -191,19 +218,19 @@ fn write_f32_slice(out: &mut dyn SeekWrite, data: &[f32])
-> io::Result<()> {
out.write_all(&bytes)
}
-fn read_u32_le(reader: &mut dyn SeekRead) -> io::Result<u32> {
+fn read_u32_le<R: SeekRead + ?Sized>(reader: &mut PreadCursor<'_, R>) ->
io::Result<u32> {
let mut buf = [0u8; 4];
reader.read_exact(&mut buf)?;
Ok(u32::from_le_bytes(buf))
}
-fn read_i32_le(reader: &mut dyn SeekRead) -> io::Result<i32> {
+fn read_i32_le<R: SeekRead + ?Sized>(reader: &mut PreadCursor<'_, R>) ->
io::Result<i32> {
let mut buf = [0u8; 4];
reader.read_exact(&mut buf)?;
Ok(i32::from_le_bytes(buf))
}
-fn read_i64_le(reader: &mut dyn SeekRead) -> io::Result<i64> {
+fn read_i64_le<R: SeekRead + ?Sized>(reader: &mut PreadCursor<'_, R>) ->
io::Result<i64> {
let mut buf = [0u8; 8];
reader.read_exact(&mut buf)?;
Ok(i64::from_le_bytes(buf))
@@ -260,7 +287,10 @@ fn checked_list_bytes(count: usize, bytes_per_entry:
usize) -> io::Result<usize>
})
}
-fn read_f32_vec(reader: &mut dyn SeekRead, count: usize) ->
io::Result<Vec<f32>> {
+fn read_f32_vec<R: SeekRead + ?Sized>(
+ reader: &mut PreadCursor<'_, R>,
+ count: usize,
+) -> io::Result<Vec<f32>> {
let mut buf = vec![0u8; count * 4];
reader.read_exact(&mut buf)?;
let floats: Vec<f32> = buf
@@ -503,9 +533,9 @@ impl<R: SeekRead> IVFPQIndexReader<R> {
/// Open an index file. Only reads the 64-byte header.
/// Centroids, codebooks, and offset table are loaded lazily on first
search.
pub fn open(mut reader: R) -> io::Result<Self> {
- reader.seek(0)?;
+ let mut cursor = PreadCursor::new(&mut reader, 0);
- let magic = read_u32_le(&mut reader)?;
+ let magic = read_u32_le(&mut cursor)?;
if magic != MAGIC {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
@@ -513,7 +543,7 @@ impl<R: SeekRead> IVFPQIndexReader<R> {
));
}
- let version = read_u32_le(&mut reader)?;
+ let version = read_u32_le(&mut cursor)?;
if version != VERSION {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
@@ -521,11 +551,11 @@ impl<R: SeekRead> IVFPQIndexReader<R> {
));
}
- let d = validate_positive_i32(read_i32_le(&mut reader)?, "d")? as
usize;
- let nlist = validate_positive_i32(read_i32_le(&mut reader)?, "nlist")?
as usize;
- let m = validate_positive_i32(read_i32_le(&mut reader)?, "m")? as
usize;
- let ksub = validate_positive_i32(read_i32_le(&mut reader)?, "ksub")?
as usize;
- let dsub = validate_positive_i32(read_i32_le(&mut reader)?, "dsub")?
as usize;
+ let d = validate_positive_i32(read_i32_le(&mut cursor)?, "d")? as
usize;
+ let nlist = validate_positive_i32(read_i32_le(&mut cursor)?, "nlist")?
as usize;
+ let m = validate_positive_i32(read_i32_le(&mut cursor)?, "m")? as
usize;
+ let ksub = validate_positive_i32(read_i32_le(&mut cursor)?, "ksub")?
as usize;
+ let dsub = validate_positive_i32(read_i32_le(&mut cursor)?, "dsub")?
as usize;
if ksub != 16 && ksub != 256 {
return Err(io::Error::new(
@@ -552,18 +582,18 @@ impl<R: SeekRead> IVFPQIndexReader<R> {
));
}
- let metric_code = read_u32_le(&mut reader)?;
+ let metric_code = read_u32_le(&mut cursor)?;
let metric = MetricType::from_code(metric_code).ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("Unknown metric type: {}", metric_code),
)
})?;
- let total_vectors = read_i64_le(&mut reader)?;
+ let total_vectors = read_i64_le(&mut cursor)?;
- let flags = read_u32_le(&mut reader)?;
+ let flags = read_u32_le(&mut cursor)?;
let mut skip = [0u8; 20];
- reader.read_exact(&mut skip)?;
+ cursor.read_exact(&mut skip)?;
let by_residual = flags & FLAG_BY_RESIDUAL != 0;
let delta_ids = flags & FLAG_DELTA_IDS != 0;
let transposed_codes = flags & FLAG_TRANSPOSED_CODES != 0;
@@ -629,9 +659,10 @@ impl<R: SeekRead> IVFPQIndexReader<R> {
let pq_centroids_count = checked_section_size(mk, dsub)?;
// Seek to start of data sections
+ let mut cursor = PreadCursor::new(&mut self.reader,
self.centroids_offset);
if self.has_opq {
- self.reader.seek(HEADER_SIZE as u64)?;
- let rotation = read_f32_vec(&mut self.reader, rotation_count)?;
+ cursor.seek(HEADER_SIZE as u64);
+ let rotation = read_f32_vec(&mut cursor, rotation_count)?;
self.opq = Some(OPQMatrix {
d,
m,
@@ -642,13 +673,11 @@ impl<R: SeekRead> IVFPQIndexReader<R> {
niter_pq_0: 0,
max_train_points: 0,
});
- } else {
- self.reader.seek(self.centroids_offset)?;
}
- self.quantizer_centroids = read_f32_vec(&mut self.reader,
centroids_count)?;
+ self.quantizer_centroids = read_f32_vec(&mut cursor, centroids_count)?;
- let pq_centroids = read_f32_vec(&mut self.reader, pq_centroids_count)?;
+ let pq_centroids = read_f32_vec(&mut cursor, pq_centroids_count)?;
self.pq = ProductQuantizer {
d,
m,
@@ -664,8 +693,8 @@ impl<R: SeekRead> IVFPQIndexReader<R> {
self.list_counts = vec![0i32; nlist];
self.list_id_bytes_lens = vec![0i32; nlist];
for i in 0..nlist {
- self.list_offsets[i] = read_i64_le(&mut self.reader)?;
- let count = read_i32_le(&mut self.reader)?;
+ self.list_offsets[i] = read_i64_le(&mut cursor)?;
+ let count = read_i32_le(&mut cursor)?;
if count < 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
@@ -673,7 +702,7 @@ impl<R: SeekRead> IVFPQIndexReader<R> {
));
}
self.list_counts[i] = count;
- let id_bytes_len = read_i32_le(&mut self.reader)?;
+ let id_bytes_len = read_i32_le(&mut cursor)?;
if id_bytes_len < 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
@@ -724,27 +753,15 @@ impl<R: SeekRead> IVFPQIndexReader<R> {
)
})?;
let mut payload = vec![0u8; payload_len];
- self.reader.pread(offset, &mut payload)?;
+ self.reader
+ .pread(&mut [ReadRequest::new(offset, &mut payload)])?;
- let base_id =
i64::from_le_bytes(payload[0..8].try_into().unwrap());
- let encoded_id_bytes_len =
i32::from_le_bytes(payload[8..12].try_into().unwrap());
- if encoded_id_bytes_len != id_bytes_len_from_table {
- return Err(io::Error::new(
- io::ErrorKind::InvalidData,
- format!(
- "offset table id_bytes_len {} does not match list
header {}",
- id_bytes_len_from_table, encoded_id_bytes_len
- ),
- ));
- }
- let id_bytes = &payload[12..12 + id_bytes_len];
- let ids = decode_delta_varint_ids(base_id, id_bytes, count)?;
- let codes = payload[12 + id_bytes_len..].to_vec();
- return Ok((ids, codes));
+ return decode_delta_list_payload(&payload, count,
id_bytes_len_from_table);
}
let mut header = [0u8; 12];
- self.reader.pread(offset, &mut header)?;
+ self.reader
+ .pread(&mut [ReadRequest::new(offset, &mut header)])?;
let base_id = i64::from_le_bytes(header[0..8].try_into().unwrap());
let id_bytes_len =
i32::from_le_bytes(header[8..12].try_into().unwrap());
@@ -762,7 +779,8 @@ impl<R: SeekRead> IVFPQIndexReader<R> {
)
})?;
let mut payload = vec![0u8; rest_len];
- self.reader.pread(offset + 12, &mut payload)?;
+ self.reader
+ .pread(&mut [ReadRequest::new(offset + 12, &mut payload)])?;
let id_bytes = &payload[..id_bytes_len];
let ids = decode_delta_varint_ids(base_id, id_bytes, count)?;
@@ -778,15 +796,132 @@ impl<R: SeekRead> IVFPQIndexReader<R> {
)
})?;
let mut payload = vec![0u8; total_len];
- self.reader.pread(offset, &mut payload)?;
+ self.reader
+ .pread(&mut [ReadRequest::new(offset, &mut payload)])?;
- let ids = payload[..id_bytes_len]
- .chunks_exact(8)
- .map(|c| i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4],
c[5], c[6], c[7]]))
- .collect();
- let codes = payload[id_bytes_len..].to_vec();
- Ok((ids, codes))
+ Ok(decode_raw_list_payload(&payload, id_bytes_len))
+ }
+ }
+
+ /// Read multiple inverted lists. Lists whose payload length is known from
+ /// metadata are issued through a single batched pread call.
+ pub fn read_inverted_lists(&mut self, list_ids: &[usize]) ->
io::Result<Vec<InvertedListData>> {
+ self.ensure_loaded()?;
+
+ let code_size = self.pq.code_size();
+ let mut results: Vec<Option<InvertedListData>> =
+ (0..list_ids.len()).map(|_| None).collect();
+ let mut metas = Vec::new();
+ let mut payloads = Vec::new();
+
+ for (input_idx, &list_id) in list_ids.iter().enumerate() {
+ if list_id >= self.nlist {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!("list_id {} out of range (nlist={})", list_id,
self.nlist),
+ ));
+ }
+ let count = self.list_counts[list_id] as usize;
+ if count == 0 {
+ results[input_idx] = Some(InvertedListData {
+ list_id,
+ ids: Vec::new(),
+ codes: Vec::new(),
+ });
+ continue;
+ }
+
+ let offset = checked_list_offset(self.list_offsets[list_id],
list_id)?;
+ let code_bytes = checked_list_bytes(count, code_size)?;
+
+ if self.delta_ids {
+ let id_bytes_len_from_table = self.list_id_bytes_lens[list_id];
+ if id_bytes_len_from_table > 0 {
+ let id_bytes_len = id_bytes_len_from_table as usize;
+ let payload_len = 12usize
+ .checked_add(id_bytes_len)
+ .and_then(|len| len.checked_add(code_bytes))
+ .ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidData,
+ "inverted list payload size overflow",
+ )
+ })?;
+ metas.push(BatchedListRead {
+ input_idx,
+ list_id,
+ count,
+ offset,
+ format: BatchedListFormat::Delta {
+ id_bytes_len_from_table,
+ },
+ });
+ payloads.push(vec![0u8; payload_len]);
+ } else {
+ let (ids, codes) = self.read_inverted_list(list_id)?;
+ results[input_idx] = Some(InvertedListData {
+ list_id,
+ ids,
+ codes,
+ });
+ }
+ } else {
+ let id_bytes_len = checked_list_bytes(count, 8)?;
+ let payload_len =
id_bytes_len.checked_add(code_bytes).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidData,
+ "inverted list payload size overflow",
+ )
+ })?;
+ metas.push(BatchedListRead {
+ input_idx,
+ list_id,
+ count,
+ offset,
+ format: BatchedListFormat::Raw { id_bytes_len },
+ });
+ payloads.push(vec![0u8; payload_len]);
+ }
+ }
+
+ if !metas.is_empty() {
+ {
+ let mut requests: Vec<_> = payloads
+ .iter_mut()
+ .zip(metas.iter())
+ .map(|(payload, meta)| ReadRequest::new(meta.offset,
payload.as_mut_slice()))
+ .collect();
+ self.reader.pread(&mut requests)?;
+ }
+
+ for (meta, payload) in metas.into_iter().zip(payloads) {
+ let (ids, codes) = match meta.format {
+ BatchedListFormat::Delta {
+ id_bytes_len_from_table,
+ } => decode_delta_list_payload(&payload, meta.count,
id_bytes_len_from_table)?,
+ BatchedListFormat::Raw { id_bytes_len } => {
+ decode_raw_list_payload(&payload, id_bytes_len)
+ }
+ };
+ results[meta.input_idx] = Some(InvertedListData {
+ list_id: meta.list_id,
+ ids,
+ codes,
+ });
+ }
}
+
+ results
+ .into_iter()
+ .map(|result| {
+ result.ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidData,
+ "missing batched inverted list read result",
+ )
+ })
+ })
+ .collect()
}
pub fn search(
@@ -815,10 +950,71 @@ impl<R: SeekRead> IVFPQIndexReader<R> {
roaring_filter_bytes,
)
}
+}
+
+pub struct InvertedListData {
+ pub list_id: usize,
+ pub ids: Vec<i64>,
+ pub codes: Vec<u8>,
+}
+
+#[derive(Clone, Copy)]
+struct BatchedListRead {
+ input_idx: usize,
+ list_id: usize,
+ count: usize,
+ offset: u64,
+ format: BatchedListFormat,
+}
+
+#[derive(Clone, Copy)]
+enum BatchedListFormat {
+ Delta { id_bytes_len_from_table: i32 },
+ Raw { id_bytes_len: usize },
+}
- pub fn supports_concurrent_pread(&self) -> bool {
- self.reader.supports_concurrent_pread()
+fn decode_delta_list_payload(
+ payload: &[u8],
+ count: usize,
+ id_bytes_len_from_table: i32,
+) -> io::Result<(Vec<i64>, Vec<u8>)> {
+ let id_bytes_len = id_bytes_len_from_table as usize;
+ let header_len = 12usize.checked_add(id_bytes_len).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidData,
+ "inverted list payload size overflow",
+ )
+ })?;
+ if payload.len() < header_len {
+ return Err(io::Error::new(
+ io::ErrorKind::UnexpectedEof,
+ "truncated delta inverted list payload",
+ ));
+ }
+ let base_id = i64::from_le_bytes(payload[0..8].try_into().unwrap());
+ let encoded_id_bytes_len =
i32::from_le_bytes(payload[8..12].try_into().unwrap());
+ if encoded_id_bytes_len != id_bytes_len_from_table {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!(
+ "offset table id_bytes_len {} does not match list header {}",
+ id_bytes_len_from_table, encoded_id_bytes_len
+ ),
+ ));
}
+ let id_bytes = &payload[12..header_len];
+ let ids = decode_delta_varint_ids(base_id, id_bytes, count)?;
+ let codes = payload[header_len..].to_vec();
+ Ok((ids, codes))
+}
+
+fn decode_raw_list_payload(payload: &[u8], id_bytes_len: usize) -> (Vec<i64>,
Vec<u8>) {
+ let ids = payload[..id_bytes_len]
+ .chunks_exact(8)
+ .map(|c| i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6],
c[7]]))
+ .collect();
+ let codes = payload[id_bytes_len..].to_vec();
+ (ids, codes)
}
#[allow(dead_code)]
@@ -867,8 +1063,6 @@ mod tests {
#[derive(Default)]
struct ReadStats {
- seek_calls: usize,
- read_exact_calls: usize,
pread_calls: usize,
}
@@ -887,29 +1081,17 @@ mod tests {
}
impl SeekRead for CountingPreadCursor {
- fn seek(&mut self, pos: u64) -> io::Result<()> {
- self.stats.lock().unwrap().seek_calls += 1;
- io::Seek::seek(&mut self.inner, io::SeekFrom::Start(pos))?;
+ fn pread(&mut self, ranges: &mut [ReadRequest<'_>]) -> io::Result<()> {
+ for range in ranges {
+ self.stats.lock().unwrap().pread_calls += 1;
+ let old_pos = io::Seek::stream_position(&mut self.inner)?;
+ io::Seek::seek(&mut self.inner,
io::SeekFrom::Start(range.pos))?;
+ let result = io::Read::read_exact(&mut self.inner, range.buf);
+ io::Seek::seek(&mut self.inner, io::SeekFrom::Start(old_pos))?;
+ result?;
+ }
Ok(())
}
-
- fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
- self.stats.lock().unwrap().read_exact_calls += 1;
- io::Read::read_exact(&mut self.inner, buf)
- }
-
- fn pread(&mut self, pos: u64, buf: &mut [u8]) -> io::Result<()> {
- self.stats.lock().unwrap().pread_calls += 1;
- let old_pos = io::Seek::stream_position(&mut self.inner)?;
- io::Seek::seek(&mut self.inner, io::SeekFrom::Start(pos))?;
- let result = io::Read::read_exact(&mut self.inner, buf);
- io::Seek::seek(&mut self.inner, io::SeekFrom::Start(old_pos))?;
- result
- }
-
- fn supports_concurrent_pread(&self) -> bool {
- true
- }
}
fn offset_table_start(d: usize, nlist: usize, m: usize, ksub: usize) ->
usize {
@@ -1012,8 +1194,6 @@ mod tests {
{
let mut stats = stats.lock().unwrap();
- stats.seek_calls = 0;
- stats.read_exact_calls = 0;
stats.pread_calls = 0;
}
@@ -1032,20 +1212,29 @@ mod tests {
assert!(!codes.is_empty());
let stats = stats.lock().unwrap();
- assert_eq!(
- stats.seek_calls, 0,
- "reading a list should not move the shared cursor"
- );
- assert_eq!(
- stats.read_exact_calls, 0,
- "reading a list should use positional reads after metadata is
loaded"
- );
assert_eq!(
stats.pread_calls, 1,
"delta-varint lists with offset-table id length should use one
pread"
);
}
+ #[test]
+ fn test_default_pread_handles_multiple_ranges() {
+ let mut cursor = Cursor::new(vec![0, 1, 2, 3, 4, 5, 6, 7]);
+ let mut first = [0u8; 2];
+ let mut second = [0u8; 3];
+
+ cursor
+ .pread(&mut [
+ ReadRequest::new(2, &mut first),
+ ReadRequest::new(5, &mut second),
+ ])
+ .unwrap();
+
+ assert_eq!(first, [2, 3]);
+ assert_eq!(second, [5, 6, 7]);
+ }
+
#[test]
fn test_read_inverted_list_falls_back_for_old_delta_offset_table() {
let d = 8;
@@ -1083,8 +1272,6 @@ mod tests {
{
let mut stats = stats.lock().unwrap();
- stats.seek_calls = 0;
- stats.read_exact_calls = 0;
stats.pread_calls = 0;
}
diff --git a/core/src/ivfflat_io.rs b/core/src/ivfflat_io.rs
index da398c4..d293e00 100644
--- a/core/src/ivfflat_io.rs
+++ b/core/src/ivfflat_io.rs
@@ -16,7 +16,7 @@
// under the License.
use crate::distance::{fvec_distance, fvec_normalize, MetricType};
-use crate::io::{SeekRead, SeekWrite};
+use crate::io::{PreadCursor, ReadRequest, SeekRead, SeekWrite};
use crate::ivfflat::IVFFlatIndex;
use crate::ivfpq::RowIdFilter;
use crate::kmeans;
@@ -161,16 +161,16 @@ pub struct IVFFlatIndexReader<R: SeekRead> {
impl<R: SeekRead> IVFFlatIndexReader<R> {
pub fn open(mut reader: R) -> io::Result<Self> {
- reader.seek(0)?;
+ let mut cursor = PreadCursor::new(&mut reader, 0);
- let magic = read_u32_le(&mut reader)?;
+ let magic = read_u32_le(&mut cursor)?;
if magic != IVFFLAT_MAGIC {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Invalid IVFFLAT magic: 0x{:08X}", magic),
));
}
- let version = read_u32_le(&mut reader)?;
+ let version = read_u32_le(&mut cursor)?;
if version != IVFFLAT_VERSION {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
@@ -178,19 +178,19 @@ impl<R: SeekRead> IVFFlatIndexReader<R> {
));
}
- let d = validate_positive_i32(read_i32_le(&mut reader)?, "d")? as
usize;
- let nlist = validate_positive_i32(read_i32_le(&mut reader)?, "nlist")?
as usize;
- let metric_code = read_u32_le(&mut reader)?;
+ let d = validate_positive_i32(read_i32_le(&mut cursor)?, "d")? as
usize;
+ let nlist = validate_positive_i32(read_i32_le(&mut cursor)?, "nlist")?
as usize;
+ let metric_code = read_u32_le(&mut cursor)?;
let metric = MetricType::from_code(metric_code).ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("Unknown metric type: {}", metric_code),
)
})?;
- let total_vectors = read_i64_le(&mut reader)?;
- let flags = read_u32_le(&mut reader)?;
+ let total_vectors = read_i64_le(&mut cursor)?;
+ let flags = read_u32_le(&mut cursor)?;
let mut reserved = [0u8; 32];
- reader.read_exact(&mut reserved)?;
+ cursor.read_exact(&mut reserved)?;
Ok(Self {
reader,
@@ -212,15 +212,15 @@ impl<R: SeekRead> IVFFlatIndexReader<R> {
return Ok(());
}
- self.reader.seek(IVFFLAT_HEADER_SIZE as u64)?;
+ let mut cursor = PreadCursor::new(&mut self.reader,
IVFFLAT_HEADER_SIZE as u64);
self.quantizer_centroids =
- read_f32_vec(&mut self.reader, checked_section_size(self.nlist,
self.d)?)?;
+ read_f32_vec(&mut cursor, checked_section_size(self.nlist,
self.d)?)?;
self.list_offsets = vec![0; self.nlist];
self.list_counts = vec![0; self.nlist];
self.list_id_bytes_lens = vec![0; self.nlist];
for list_id in 0..self.nlist {
- self.list_offsets[list_id] = read_i64_le(&mut self.reader)?;
- let count = read_i32_le(&mut self.reader)?;
+ self.list_offsets[list_id] = read_i64_le(&mut cursor)?;
+ let count = read_i32_le(&mut cursor)?;
if count < 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
@@ -228,7 +228,7 @@ impl<R: SeekRead> IVFFlatIndexReader<R> {
));
}
self.list_counts[list_id] = count;
- let id_bytes_len = read_i32_le(&mut self.reader)?;
+ let id_bytes_len = read_i32_le(&mut cursor)?;
if id_bytes_len < 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
@@ -266,7 +266,8 @@ impl<R: SeekRead> IVFFlatIndexReader<R> {
io::Error::new(io::ErrorKind::InvalidData, "IVF-FLAT list
payload overflow")
})?;
let mut payload = vec![0u8; payload_len];
- self.reader.pread(offset, &mut payload)?;
+ self.reader
+ .pread(&mut [ReadRequest::new(offset, &mut payload)])?;
let base_id =
i64::from_le_bytes(payload[0..8].try_into().unwrap());
let encoded_len =
i32::from_le_bytes(payload[8..12].try_into().unwrap());
if encoded_len < 0 || encoded_len as usize != id_bytes_len {
@@ -555,19 +556,19 @@ fn write_f32_slice(out: &mut dyn SeekWrite, data: &[f32])
-> io::Result<()> {
out.write_all(&bytes)
}
-fn read_u32_le(reader: &mut dyn SeekRead) -> io::Result<u32> {
+fn read_u32_le<R: SeekRead + ?Sized>(reader: &mut PreadCursor<'_, R>) ->
io::Result<u32> {
let mut buf = [0u8; 4];
reader.read_exact(&mut buf)?;
Ok(u32::from_le_bytes(buf))
}
-fn read_i32_le(reader: &mut dyn SeekRead) -> io::Result<i32> {
+fn read_i32_le<R: SeekRead + ?Sized>(reader: &mut PreadCursor<'_, R>) ->
io::Result<i32> {
let mut buf = [0u8; 4];
reader.read_exact(&mut buf)?;
Ok(i32::from_le_bytes(buf))
}
-fn read_i64_le(reader: &mut dyn SeekRead) -> io::Result<i64> {
+fn read_i64_le<R: SeekRead + ?Sized>(reader: &mut PreadCursor<'_, R>) ->
io::Result<i64> {
let mut buf = [0u8; 8];
reader.read_exact(&mut buf)?;
Ok(i64::from_le_bytes(buf))
@@ -709,7 +710,10 @@ fn checked_list_bytes(count: usize, bytes_per_entry:
usize) -> io::Result<usize>
})
}
-fn read_f32_vec(reader: &mut dyn SeekRead, count: usize) ->
io::Result<Vec<f32>> {
+fn read_f32_vec<R: SeekRead + ?Sized>(
+ reader: &mut PreadCursor<'_, R>,
+ count: usize,
+) -> io::Result<Vec<f32>> {
let mut buf = vec![0u8; count * 4];
reader.read_exact(&mut buf)?;
bytes_to_f32_vec(&buf)
diff --git a/core/src/ivfhnswflat_io.rs b/core/src/ivfhnswflat_io.rs
index ebd155c..d7d970a 100644
--- a/core/src/ivfhnswflat_io.rs
+++ b/core/src/ivfhnswflat_io.rs
@@ -24,7 +24,7 @@ use crate::index_io_util::{
u64_to_i64, usize_to_i32, usize_to_i64, validate_positive_i32,
validate_search_inputs,
write_f32_slice, write_i32_le, write_i64_le, write_u32_le,
};
-use crate::io::{SeekRead, SeekWrite};
+use crate::io::{PreadCursor, ReadRequest, SeekRead, SeekWrite};
use crate::ivfhnswflat::IVFHNSWFlatIndex;
use crate::ivfpq::RowIdFilter;
use crate::kmeans;
@@ -34,6 +34,7 @@ use std::io;
pub const IVF_HNSW_FLAT_MAGIC: u32 = 0x4948464C; // "IHFL"
pub const IVF_HNSW_FLAT_VERSION: u32 = 1;
pub const IVF_HNSW_FLAT_HEADER_SIZE: usize = 64;
+const MAX_COALESCED_READ_GAP_BYTES: u64 = 1 << 20;
pub fn write_ivfhnswflat_index(
index: &IVFHNSWFlatIndex,
@@ -150,16 +151,16 @@ pub struct IVFHNSWFlatIndexReader<R: SeekRead> {
impl<R: SeekRead> IVFHNSWFlatIndexReader<R> {
pub fn open(mut reader: R) -> io::Result<Self> {
- reader.seek(0)?;
+ let mut cursor = PreadCursor::new(&mut reader, 0);
- let magic = read_u32_le(&mut reader)?;
+ let magic = read_u32_le(&mut cursor)?;
if magic != IVF_HNSW_FLAT_MAGIC {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Invalid IVF_HNSW_FLAT magic: 0x{:08X}", magic),
));
}
- let version = read_u32_le(&mut reader)?;
+ let version = read_u32_le(&mut cursor)?;
if version != IVF_HNSW_FLAT_VERSION {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
@@ -167,27 +168,27 @@ impl<R: SeekRead> IVFHNSWFlatIndexReader<R> {
));
}
- let d = validate_positive_i32(read_i32_le(&mut reader)?, "d")? as
usize;
- let nlist = validate_positive_i32(read_i32_le(&mut reader)?, "nlist")?
as usize;
- let metric_code = read_u32_le(&mut reader)?;
+ let d = validate_positive_i32(read_i32_le(&mut cursor)?, "d")? as
usize;
+ let nlist = validate_positive_i32(read_i32_le(&mut cursor)?, "nlist")?
as usize;
+ let metric_code = read_u32_le(&mut cursor)?;
let metric = MetricType::from_code(metric_code).ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("Unknown metric type: {}", metric_code),
)
})?;
- let total_vectors = read_i64_le(&mut reader)?;
+ let total_vectors = read_i64_le(&mut cursor)?;
let hnsw_params = HnswBuildParams {
- m: validate_positive_i32(read_i32_le(&mut reader)?, "hnsw m")? as
usize,
+ m: validate_positive_i32(read_i32_le(&mut cursor)?, "hnsw m")? as
usize,
ef_construction: validate_positive_i32(
- read_i32_le(&mut reader)?,
+ read_i32_le(&mut cursor)?,
"hnsw ef_construction",
)? as usize,
- max_level: validate_positive_i32(read_i32_le(&mut reader)?, "hnsw
max_level")? as usize,
+ max_level: validate_positive_i32(read_i32_le(&mut cursor)?, "hnsw
max_level")? as usize,
}
.sanitized();
let mut reserved = [0u8; 24];
- reader.read_exact(&mut reserved)?;
+ cursor.read_exact(&mut reserved)?;
Ok(Self {
reader,
@@ -209,15 +210,15 @@ impl<R: SeekRead> IVFHNSWFlatIndexReader<R> {
return Ok(());
}
- self.reader.seek(IVF_HNSW_FLAT_HEADER_SIZE as u64)?;
+ let mut cursor = PreadCursor::new(&mut self.reader,
IVF_HNSW_FLAT_HEADER_SIZE as u64);
self.quantizer_centroids =
- read_f32_vec(&mut self.reader, checked_section_size(self.nlist,
self.d)?)?;
+ read_f32_vec(&mut cursor, checked_section_size(self.nlist,
self.d)?)?;
self.list_offsets = vec![0; self.nlist];
self.list_counts = vec![0; self.nlist];
self.list_graph_bytes_lens = vec![0; self.nlist];
for list_id in 0..self.nlist {
- self.list_offsets[list_id] = read_i64_le(&mut self.reader)?;
- let count = read_i32_le(&mut self.reader)?;
+ self.list_offsets[list_id] = read_i64_le(&mut cursor)?;
+ let count = read_i32_le(&mut cursor)?;
if count < 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
@@ -225,7 +226,7 @@ impl<R: SeekRead> IVFHNSWFlatIndexReader<R> {
));
}
self.list_counts[list_id] = count;
- let graph_bytes_len = read_i32_le(&mut self.reader)?;
+ let graph_bytes_len = read_i32_le(&mut cursor)?;
if graph_bytes_len < 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
@@ -236,7 +237,7 @@ impl<R: SeekRead> IVFHNSWFlatIndexReader<R> {
));
}
self.list_graph_bytes_lens[list_id] = graph_bytes_len;
- let _reserved = read_i64_le(&mut self.reader)?;
+ let _reserved = read_i64_le(&mut cursor)?;
}
self.loaded = true;
@@ -256,6 +257,132 @@ impl<R: SeekRead> IVFHNSWFlatIndexReader<R> {
fn read_graph_list(&mut self, list_id: usize) ->
io::Result<Option<GraphList>> {
self.ensure_loaded()?;
+ let Some(meta) = self.list_payload_meta(list_id)? else {
+ return Ok(None);
+ };
+ let mut payload = vec![0u8; meta.payload_len];
+ self.reader
+ .pread(&mut [ReadRequest::new(meta.offset, &mut payload)])?;
+
+ self.decode_graph_list_payload(meta, &payload).map(Some)
+ }
+
+    /// Loads the graph lists named in `list_ids`, merging reads of nearby payloads.
+    ///
+    /// Payload metadata is resolved for every requested list (lists for which
+    /// `list_payload_meta` yields `None` are skipped), sorted by file offset, and
+    /// grouped into ranges as long as `should_coalesce_gap` accepts the next
+    /// payload. Every range is fetched by one call to
+    /// `read_coalesced_graph_list_range`.
+    ///
+    /// Returns `(list_id, GraphList)` pairs ordered by the position of each id in
+    /// `list_ids`.
+    ///
+    /// # Errors
+    ///
+    /// Propagates metadata, read and decode errors. Returns `InvalidData` when a
+    /// payload starts before the previous one ends; this also triggers when
+    /// `list_ids` names the same loadable list twice, so callers should pass
+    /// unique ids.
+    fn read_graph_lists_coalesced(
+        &mut self,
+        list_ids: &[usize],
+    ) -> io::Result<Vec<(usize, GraphList)>> {
+        self.ensure_loaded()?;
+        let mut metas = Vec::new();
+        for &list_id in list_ids {
+            if let Some(meta) = self.list_payload_meta(list_id)? {
+                metas.push(meta);
+            }
+        }
+        if metas.is_empty() {
+            return Ok(Vec::new());
+        }
+
+        // Walk payloads in file order so neighbours can share one read.
+        metas.sort_by_key(|meta| meta.offset);
+        let mut loaded = Vec::with_capacity(metas.len());
+        let mut range_start = metas[0].offset;
+        let mut range_end = metas[0].end_offset()?;
+        let mut range_payload_bytes = metas[0].payload_len;
+        let mut range_metas = vec![metas[0]];
+        for &meta in metas.iter().skip(1) {
+            let meta_end = meta.end_offset()?;
+            let gap = meta.offset.checked_sub(range_end).ok_or_else(|| {
+                io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    "IVF-HNSW-FLAT list payload offsets overlap",
+                )
+            })?;
+            if should_coalesce_gap(
+                gap,
+                range_start,
+                meta_end,
+                range_payload_bytes,
+                meta.payload_len,
+            ) {
+                // Extend the current range to cover this payload.
+                range_end = meta_end;
+                range_payload_bytes = range_payload_bytes
+                    .checked_add(meta.payload_len)
+                    .ok_or_else(|| {
+                        io::Error::new(
+                            io::ErrorKind::InvalidData,
+                            "coalesced IVF-HNSW-FLAT requested payload bytes overflow",
+                        )
+                    })?;
+                range_metas.push(meta);
+            } else {
+                // Flush the finished range and start a new one at this payload.
+                self.read_coalesced_graph_list_range(
+                    range_start,
+                    range_end,
+                    &range_metas,
+                    &mut loaded,
+                )?;
+                range_start = meta.offset;
+                range_end = meta_end;
+                range_payload_bytes = meta.payload_len;
+                range_metas.clear();
+                range_metas.push(meta);
+            }
+        }
+        self.read_coalesced_graph_list_range(range_start, range_end, &range_metas, &mut loaded)?;
+
+        // Restore request order. The key is a linear scan of `list_ids`, so it is
+        // cached: computed once per element rather than once per comparison.
+        loaded.sort_by_cached_key(|(list_id, _)| {
+            list_ids
+                .iter()
+                .position(|&requested_id| requested_id == *list_id)
+                .unwrap_or(usize::MAX)
+        });
+        Ok(loaded)
+    }
+
+    /// Reads `[range_start, range_end)` with a single `pread` and decodes every
+    /// list described by `metas` out of that buffer.
+    ///
+    /// Decoded lists are appended to `loaded` in the order of `metas`.
+    ///
+    /// # Errors
+    ///
+    /// Returns `InvalidData` when the range is inverted or does not fit in
+    /// `usize`, or when a payload does not lie inside the range. Read and decode
+    /// errors are propagated.
+    fn read_coalesced_graph_list_range(
+        &mut self,
+        range_start: u64,
+        range_end: u64,
+        metas: &[ListPayloadMeta],
+        loaded: &mut Vec<(usize, GraphList)>,
+    ) -> io::Result<()> {
+        let range_len = range_end.checked_sub(range_start).ok_or_else(|| {
+            io::Error::new(
+                io::ErrorKind::InvalidData,
+                "coalesced IVF-HNSW-FLAT read range is invalid",
+            )
+        })?;
+        let byte_len = usize::try_from(range_len).map_err(|_| {
+            io::Error::new(
+                io::ErrorKind::InvalidData,
+                "coalesced IVF-HNSW-FLAT read range exceeds usize",
+            )
+        })?;
+        let mut payload = vec![0u8; byte_len];
+        self.reader
+            .pread(&mut [ReadRequest::new(range_start, &mut payload)])?;
+
+        for &meta in metas {
+            // Checked arithmetic and slicing: a meta outside the range is reported
+            // as an error instead of panicking.
+            let relative = meta.offset.checked_sub(range_start).ok_or_else(|| {
+                io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    "coalesced IVF-HNSW-FLAT payload offset precedes read range",
+                )
+            })?;
+            let start = usize::try_from(relative).map_err(|_| {
+                io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    "coalesced IVF-HNSW-FLAT payload offset exceeds usize",
+                )
+            })?;
+            let end = start.checked_add(meta.payload_len).ok_or_else(|| {
+                io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    "coalesced IVF-HNSW-FLAT payload slice overflows",
+                )
+            })?;
+            let bytes = payload.get(start..end).ok_or_else(|| {
+                io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    "coalesced IVF-HNSW-FLAT payload slice exceeds read range",
+                )
+            })?;
+            loaded.push((meta.list_id, self.decode_graph_list_payload(meta, bytes)?));
+        }
+        Ok(())
+    }
+
+ fn list_payload_meta(&self, list_id: usize) ->
io::Result<Option<ListPayloadMeta>> {
if list_id >= self.nlist {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
@@ -276,12 +403,33 @@ impl<R: SeekRead> IVFHNSWFlatIndexReader<R> {
));
}
let payload_len = list_payload_len(count, self.d, graph_bytes_len)?;
- let mut payload = vec![0u8; payload_len];
- self.reader.pread(offset, &mut payload)?;
+ Ok(Some(ListPayloadMeta {
+ list_id,
+ offset,
+ count,
+ payload_len,
+ }))
+ }
- let ids_bytes_len = checked_list_bytes(count, 8)?;
+ fn decode_graph_list_payload(
+ &self,
+ meta: ListPayloadMeta,
+ payload: &[u8],
+ ) -> io::Result<GraphList> {
+ if payload.len() != meta.payload_len {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!(
+ "list {} payload length {} does not match expected {}",
+ meta.list_id,
+ payload.len(),
+ meta.payload_len
+ ),
+ ));
+ }
+ let ids_bytes_len = checked_list_bytes(meta.count, 8)?;
let vector_bytes_len = checked_list_bytes(
- count,
+ meta.count,
self.d.checked_mul(4).ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
@@ -297,7 +445,7 @@ impl<R: SeekRead> IVFHNSWFlatIndexReader<R> {
let graph = decode_graph(
&payload[ids_bytes_len + vector_bytes_len..],
vectors.clone(),
- count,
+ meta.count,
self.d,
self.metric,
self.hnsw_params,
@@ -305,10 +453,10 @@ impl<R: SeekRead> IVFHNSWFlatIndexReader<R> {
let graph = graph.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
- format!("list {} is missing HNSW graph", list_id),
+ format!("list {} is missing HNSW graph", meta.list_id),
)
})?;
- Ok(Some(GraphList { ids, graph }))
+ Ok(GraphList { ids, graph })
}
pub fn search(
@@ -357,10 +505,8 @@ impl<R: SeekRead> IVFHNSWFlatIndexReader<R> {
let (probe_indices, _) =
kmeans::find_topk(&q, &self.quantizer_centroids, self.nlist,
self.d, nprobe);
let mut loaded_lists = Vec::with_capacity(probe_indices.len());
- for list_id in probe_indices {
- if let Some(list) = self.read_graph_list(list_id)? {
- loaded_lists.push(list);
- }
+ for (_, list) in self.read_graph_lists_coalesced(&probe_indices)? {
+ loaded_lists.push(list);
}
let search_lists: Vec<_> = loaded_lists
.iter()
@@ -449,10 +595,7 @@ pub fn search_batch_ivfhnswflat_reader_filter<R: SeekRead>(
let mut heaps: Vec<TopKHeap> = (0..nq).map(|_| TopKHeap::new(k)).collect();
let mut query_filtered_counts = vec![0usize; nq];
let mut loaded_lists = Vec::with_capacity(unique_lists.len());
- for list_id in unique_lists {
- let Some(list) = reader.read_graph_list(list_id)? else {
- continue;
- };
+ for (list_id, list) in reader.read_graph_lists_coalesced(&unique_lists)? {
if let Some(f) = filter {
let filtered = list.ids.iter().filter(|&&id|
f.contains(id)).count();
for &qi in &list_to_queries[list_id] {
@@ -552,6 +695,46 @@ struct GraphList {
graph: HnswGraph,
}
+/// Location and size of one list's payload inside the index file.
+#[derive(Clone, Copy)]
+struct ListPayloadMeta {
+    /// Index of the list this payload belongs to.
+    list_id: usize,
+    /// Byte offset at which the payload starts.
+    offset: u64,
+    /// Number of entries stored in the list.
+    count: usize,
+    /// Total payload length in bytes.
+    payload_len: usize,
+}
+
+impl ListPayloadMeta {
+    /// Returns the offset one past the last payload byte, or `InvalidData` when
+    /// `offset + payload_len` does not fit in `u64`.
+    fn end_offset(self) -> io::Result<u64> {
+        match self.offset.checked_add(self.payload_len as u64) {
+            Some(end) => Ok(end),
+            None => Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "IVF-HNSW-FLAT list payload offset overflows u64",
+            )),
+        }
+    }
+}
+
+/// Decides whether the next payload should be folded into the current read range.
+///
+/// Returns `true` only when `gap` is at most `MAX_COALESCED_READ_GAP_BYTES` and
+/// the merged range (`next_range_end - range_start`) is no larger than twice the
+/// bytes actually requested (`current_payload_bytes + next_payload_bytes`).
+/// Any overflow or underflow in those sums yields `false`.
+fn should_coalesce_gap(
+    gap: u64,
+    range_start: u64,
+    next_range_end: u64,
+    current_payload_bytes: usize,
+    next_payload_bytes: usize,
+) -> bool {
+    let gap_is_small = gap <= MAX_COALESCED_READ_GAP_BYTES;
+    let requested = current_payload_bytes.checked_add(next_payload_bytes);
+    let spanned = next_range_end.checked_sub(range_start);
+    match (requested, spanned) {
+        (Some(requested_bytes), Some(range_bytes)) if gap_is_small => {
+            range_bytes <= requested_bytes.saturating_mul(2) as u64
+        }
+        _ => false,
+    }
+}
+
struct LoadedBatchList {
query_ids: Vec<usize>,
ids: Vec<i64>,
@@ -677,7 +860,7 @@ mod tests {
use crate::distance::MetricType;
use crate::hnsw::HnswBuildParams;
use crate::index_io_util::decode_graph;
- use crate::io::{PosWriter, SeekRead};
+ use crate::io::{PosWriter, ReadRequest, SeekRead};
use crate::ivfhnswflat::IVFHNSWFlatIndex;
use crate::ivfhnswflat_io::{
search_batch_ivfhnswflat_reader,
search_batch_ivfhnswflat_reader_roaring_filter,
@@ -844,6 +1027,73 @@ mod tests {
}
}
+    /// A batch search probing every list must fetch all payloads with one pread.
+    #[test]
+    fn test_ivfhnswflat_batch_reader_coalesces_contiguous_list_reads() {
+        let d = 4;
+        let nlist = 4;
+        let n = 128;
+        // Spread vectors over `nlist` well-separated clusters.
+        let vectors: Vec<f32> = (0..n)
+            .flat_map(|i| {
+                let cluster_base = (i % nlist) as f32 * 100.0;
+                [cluster_base + i as f32 * 0.01, 1.0, 2.0, 3.0]
+            })
+            .collect();
+        let row_ids: Vec<i64> = (1000..1000 + n as i64).collect();
+
+        let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+        index.train(&vectors, n);
+        index.add(&vectors, &row_ids, n);
+        index.build_graphs().unwrap();
+
+        let mut serialized = Vec::new();
+        write_ivfhnswflat_index(&index, &mut PosWriter::new(&mut serialized)).unwrap();
+
+        let reads = Arc::new(AtomicUsize::new(0));
+        let source = CountingPreadCursor::new(serialized, Arc::clone(&reads));
+        let mut reader = IVFHNSWFlatIndexReader::open(source).unwrap();
+        reader.ensure_loaded().unwrap();
+        // Only count reads issued by the search itself.
+        reads.store(0, Ordering::SeqCst);
+        let queries = vectors[0..d].to_vec();
+
+        search_batch_ivfhnswflat_reader(&mut reader, &queries, 1, 5, nlist, 32).unwrap();
+
+        assert_eq!(reads.load(Ordering::SeqCst), 1);
+    }
+
+    /// Requesting lists 0 and 2 must still issue one pread even though another
+    /// non-empty list was not requested.
+    #[test]
+    fn test_ivfhnswflat_reader_coalesces_small_gap_between_requested_lists() {
+        let d = 4;
+        let nlist = 4;
+        let n = 128;
+        // Spread vectors over `nlist` well-separated clusters.
+        let vectors: Vec<f32> = (0..n)
+            .flat_map(|i| {
+                let cluster_base = (i % nlist) as f32 * 100.0;
+                [cluster_base + i as f32 * 0.01, 1.0, 2.0, 3.0]
+            })
+            .collect();
+        let row_ids: Vec<i64> = (1000..1000 + n as i64).collect();
+
+        let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+        index.train(&vectors, n);
+        index.add(&vectors, &row_ids, n);
+        index.build_graphs().unwrap();
+
+        let mut serialized = Vec::new();
+        write_ivfhnswflat_index(&index, &mut PosWriter::new(&mut serialized)).unwrap();
+
+        let reads = Arc::new(AtomicUsize::new(0));
+        let source = CountingPreadCursor::new(serialized, Arc::clone(&reads));
+        let mut reader = IVFHNSWFlatIndexReader::open(source).unwrap();
+        reader.ensure_loaded().unwrap();
+        // Lists 0..3 must all hold data for the scenario to be meaningful.
+        assert!(reader.list_counts[..3].iter().all(|&count| count > 0));
+        reads.store(0, Ordering::SeqCst);
+
+        let fetched = reader.read_graph_lists_coalesced(&[0, 2]).unwrap();
+
+        assert_eq!(fetched.len(), 2);
+        assert_eq!(reads.load(Ordering::SeqCst), 1);
+    }
+
#[test]
fn test_ivfhnswflat_batch_reader_search_with_roaring_filter_bytes() {
let d = 2;
@@ -933,6 +1183,8 @@ mod tests {
let cursor = CountingPreadCursor::new(buf, Arc::clone(&pread_count));
let filter: HashSet<i64> = [10, 12].into_iter().collect();
let mut reader = IVFHNSWFlatIndexReader::open(cursor).unwrap();
+ reader.ensure_loaded().unwrap();
+ pread_count.store(0, Ordering::SeqCst);
reader
.search_with_filter(&[0.0, 0.0], 2, 1, 1, Some(&filter))
@@ -941,6 +1193,38 @@ mod tests {
assert_eq!(pread_count.load(Ordering::SeqCst), 1);
}
+    /// A single-query search probing every list must use exactly one pread.
+    #[test]
+    fn test_ivfhnswflat_reader_search_coalesces_contiguous_list_reads() {
+        let d = 4;
+        let nlist = 4;
+        let n = 128;
+        // Spread vectors over `nlist` well-separated clusters.
+        let vectors: Vec<f32> = (0..n)
+            .flat_map(|i| {
+                let cluster_base = (i % nlist) as f32 * 100.0;
+                [cluster_base + i as f32 * 0.01, 1.0, 2.0, 3.0]
+            })
+            .collect();
+        let row_ids: Vec<i64> = (1000..1000 + n as i64).collect();
+
+        let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+        index.train(&vectors, n);
+        index.add(&vectors, &row_ids, n);
+        index.build_graphs().unwrap();
+
+        let mut serialized = Vec::new();
+        write_ivfhnswflat_index(&index, &mut PosWriter::new(&mut serialized)).unwrap();
+
+        let reads = Arc::new(AtomicUsize::new(0));
+        let source = CountingPreadCursor::new(serialized, Arc::clone(&reads));
+        let mut reader = IVFHNSWFlatIndexReader::open(source).unwrap();
+        reader.ensure_loaded().unwrap();
+        // Only count reads issued by the search itself.
+        reads.store(0, Ordering::SeqCst);
+
+        reader.search(&vectors[0..d], 5, nlist, 32).unwrap();
+
+        assert_eq!(reads.load(Ordering::SeqCst), 1);
+    }
+
#[test]
fn test_ivfhnswflat_reader_rejects_truncated_graph_section() {
let d = 2;
@@ -1051,54 +1335,31 @@ mod tests {
struct CountingPreadCursor {
data: Vec<u8>,
- pos: usize,
pread_count: Arc<AtomicUsize>,
}
impl CountingPreadCursor {
fn new(data: Vec<u8>, pread_count: Arc<AtomicUsize>) -> Self {
- Self {
- data,
- pos: 0,
- pread_count,
- }
+ Self { data, pread_count }
}
}
impl SeekRead for CountingPreadCursor {
- fn seek(&mut self, pos: u64) -> io::Result<()> {
- self.pos = pos as usize;
- Ok(())
- }
-
- fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
- let end = self.pos.checked_add(buf.len()).ok_or_else(|| {
- io::Error::new(io::ErrorKind::UnexpectedEof, "cursor position
overflow")
- })?;
- if end > self.data.len() {
- return Err(io::Error::new(
- io::ErrorKind::UnexpectedEof,
- "failed to fill whole buffer",
- ));
- }
- buf.copy_from_slice(&self.data[self.pos..end]);
- self.pos = end;
- Ok(())
- }
-
- fn pread(&mut self, pos: u64, buf: &mut [u8]) -> io::Result<()> {
- self.pread_count.fetch_add(1, Ordering::SeqCst);
- let pos = pos as usize;
- let end = pos.checked_add(buf.len()).ok_or_else(|| {
- io::Error::new(io::ErrorKind::UnexpectedEof, "cursor position
overflow")
- })?;
- if end > self.data.len() {
- return Err(io::Error::new(
- io::ErrorKind::UnexpectedEof,
- "failed to fill whole buffer",
- ));
+ fn pread(&mut self, ranges: &mut [ReadRequest<'_>]) -> io::Result<()> {
+ for range in ranges {
+ self.pread_count.fetch_add(1, Ordering::SeqCst);
+ let pos = range.pos as usize;
+ let end = pos.checked_add(range.buf.len()).ok_or_else(|| {
+ io::Error::new(io::ErrorKind::UnexpectedEof, "cursor
position overflow")
+ })?;
+ if end > self.data.len() {
+ return Err(io::Error::new(
+ io::ErrorKind::UnexpectedEof,
+ "failed to fill whole buffer",
+ ));
+ }
+ range.buf.copy_from_slice(&self.data[pos..end]);
}
- buf.copy_from_slice(&self.data[pos..end]);
Ok(())
}
}
diff --git a/core/src/ivfhnswsq_io.rs b/core/src/ivfhnswsq_io.rs
index c86a478..c1c94b6 100644
--- a/core/src/ivfhnswsq_io.rs
+++ b/core/src/ivfhnswsq_io.rs
@@ -24,7 +24,7 @@ use crate::index_io_util::{
u64_to_i64, usize_to_i32, usize_to_i64, validate_positive_i32,
validate_search_inputs,
write_f32_slice, write_i32_le, write_i64_le, write_u32_le,
};
-use crate::io::{SeekRead, SeekWrite};
+use crate::io::{PreadCursor, ReadRequest, SeekRead, SeekWrite};
use crate::ivfhnswsq::IVFHNSWSQIndex;
use crate::ivfpq::RowIdFilter;
use crate::kmeans;
@@ -35,6 +35,7 @@ use std::io;
pub const IVF_HNSW_SQ_MAGIC: u32 = 0x49485351; // "IHSQ"
pub const IVF_HNSW_SQ_VERSION: u32 = 1;
pub const IVF_HNSW_SQ_HEADER_SIZE: usize = 64;
+const MAX_COALESCED_READ_GAP_BYTES: u64 = 1 << 20;
pub fn write_ivfhnswsq_index(index: &IVFHNSWSQIndex, out: &mut dyn SeekWrite)
-> io::Result<()> {
validate_index_shape(index)?;
@@ -158,16 +159,16 @@ pub struct IVFHNSWSQIndexReader<R: SeekRead> {
impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
pub fn open(mut reader: R) -> io::Result<Self> {
- reader.seek(0)?;
+ let mut cursor = PreadCursor::new(&mut reader, 0);
- let magic = read_u32_le(&mut reader)?;
+ let magic = read_u32_le(&mut cursor)?;
if magic != IVF_HNSW_SQ_MAGIC {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Invalid IVF_HNSW_SQ magic: 0x{:08X}", magic),
));
}
- let version = read_u32_le(&mut reader)?;
+ let version = read_u32_le(&mut cursor)?;
if version != IVF_HNSW_SQ_VERSION {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
@@ -175,38 +176,38 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
));
}
- let d = validate_positive_i32(read_i32_le(&mut reader)?, "d")? as
usize;
- let nlist = validate_positive_i32(read_i32_le(&mut reader)?, "nlist")?
as usize;
- let metric_code = read_u32_le(&mut reader)?;
+ let d = validate_positive_i32(read_i32_le(&mut cursor)?, "d")? as
usize;
+ let nlist = validate_positive_i32(read_i32_le(&mut cursor)?, "nlist")?
as usize;
+ let metric_code = read_u32_le(&mut cursor)?;
let metric = MetricType::from_code(metric_code).ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("Unknown metric type: {}", metric_code),
)
})?;
- let total_vectors = read_i64_le(&mut reader)?;
+ let total_vectors = read_i64_le(&mut cursor)?;
let hnsw_params = HnswBuildParams {
- m: validate_positive_i32(read_i32_le(&mut reader)?, "hnsw m")? as
usize,
+ m: validate_positive_i32(read_i32_le(&mut cursor)?, "hnsw m")? as
usize,
ef_construction: validate_positive_i32(
- read_i32_le(&mut reader)?,
+ read_i32_le(&mut cursor)?,
"hnsw ef_construction",
)? as usize,
- max_level: validate_positive_i32(read_i32_le(&mut reader)?, "hnsw
max_level")? as usize,
+ max_level: validate_positive_i32(read_i32_le(&mut cursor)?, "hnsw
max_level")? as usize,
}
.sanitized();
let mut bounds_summary = [0u8; 8];
- reader.read_exact(&mut bounds_summary)?;
+ cursor.read_exact(&mut bounds_summary)?;
let mut reserved = [0u8; 16];
- reader.read_exact(&mut reserved)?;
+ cursor.read_exact(&mut reserved)?;
- let mins = read_f32_vec(&mut reader, d)?;
- let maxs = read_f32_vec(&mut reader, d)?;
+ let mins = read_f32_vec(&mut cursor, d)?;
+ let maxs = read_f32_vec(&mut cursor, d)?;
validate_sq_bounds(d, &mins, &maxs)?;
let sq = ScalarQuantizer::with_dimension_bounds(d, mins, maxs);
let mut list_sqs = Vec::with_capacity(nlist);
for _ in 0..nlist {
- let mins = read_f32_vec(&mut reader, d)?;
- let maxs = read_f32_vec(&mut reader, d)?;
+ let mins = read_f32_vec(&mut cursor, d)?;
+ let maxs = read_f32_vec(&mut cursor, d)?;
validate_sq_bounds(d, &mins, &maxs)?;
list_sqs.push(ScalarQuantizer::with_dimension_bounds(d, mins,
maxs));
}
@@ -235,15 +236,15 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
let quantizer_centroids_offset =
IVF_HNSW_SQ_HEADER_SIZE as u64 + (self.d as u64) * 8 * (self.nlist
as u64 + 1);
- self.reader.seek(quantizer_centroids_offset)?;
+ let mut cursor = PreadCursor::new(&mut self.reader,
quantizer_centroids_offset);
self.quantizer_centroids =
- read_f32_vec(&mut self.reader, checked_section_size(self.nlist,
self.d)?)?;
+ read_f32_vec(&mut cursor, checked_section_size(self.nlist,
self.d)?)?;
self.list_offsets = vec![0; self.nlist];
self.list_counts = vec![0; self.nlist];
self.list_graph_bytes_lens = vec![0; self.nlist];
for list_id in 0..self.nlist {
- self.list_offsets[list_id] = read_i64_le(&mut self.reader)?;
- let count = read_i32_le(&mut self.reader)?;
+ self.list_offsets[list_id] = read_i64_le(&mut cursor)?;
+ let count = read_i32_le(&mut cursor)?;
if count < 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
@@ -251,7 +252,7 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
));
}
self.list_counts[list_id] = count;
- let graph_bytes_len = read_i32_le(&mut self.reader)?;
+ let graph_bytes_len = read_i32_le(&mut cursor)?;
if graph_bytes_len < 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
@@ -262,7 +263,7 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
));
}
self.list_graph_bytes_lens[list_id] = graph_bytes_len;
- let _reserved = read_i64_le(&mut self.reader)?;
+ let _reserved = read_i64_le(&mut cursor)?;
}
self.loaded = true;
@@ -281,6 +282,132 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
fn read_graph_list(&mut self, list_id: usize) ->
io::Result<Option<GraphList>> {
self.ensure_loaded()?;
+ let Some(meta) = self.list_payload_meta(list_id)? else {
+ return Ok(None);
+ };
+ let mut payload = vec![0u8; meta.payload_len];
+ self.reader
+ .pread(&mut [ReadRequest::new(meta.offset, &mut payload)])?;
+
+ self.decode_graph_list_payload(meta, &payload).map(Some)
+ }
+
+    /// Loads the graph lists named in `list_ids`, merging reads of nearby payloads.
+    ///
+    /// Payload metadata is resolved for every requested list (lists for which
+    /// `list_payload_meta` yields `None` are skipped), sorted by file offset, and
+    /// grouped into ranges as long as `should_coalesce_gap` accepts the next
+    /// payload. Every range is fetched by one call to
+    /// `read_coalesced_graph_list_range`.
+    ///
+    /// Returns `(list_id, GraphList)` pairs ordered by the position of each id in
+    /// `list_ids`.
+    ///
+    /// # Errors
+    ///
+    /// Propagates metadata, read and decode errors. Returns `InvalidData` when a
+    /// payload starts before the previous one ends; this also triggers when
+    /// `list_ids` names the same loadable list twice, so callers should pass
+    /// unique ids.
+    fn read_graph_lists_coalesced(
+        &mut self,
+        list_ids: &[usize],
+    ) -> io::Result<Vec<(usize, GraphList)>> {
+        self.ensure_loaded()?;
+        let mut metas = Vec::new();
+        for &list_id in list_ids {
+            if let Some(meta) = self.list_payload_meta(list_id)? {
+                metas.push(meta);
+            }
+        }
+        if metas.is_empty() {
+            return Ok(Vec::new());
+        }
+
+        // Walk payloads in file order so neighbours can share one read.
+        metas.sort_by_key(|meta| meta.offset);
+        let mut loaded = Vec::with_capacity(metas.len());
+        let mut range_start = metas[0].offset;
+        let mut range_end = metas[0].end_offset()?;
+        let mut range_payload_bytes = metas[0].payload_len;
+        let mut range_metas = vec![metas[0]];
+        for &meta in metas.iter().skip(1) {
+            let meta_end = meta.end_offset()?;
+            let gap = meta.offset.checked_sub(range_end).ok_or_else(|| {
+                io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    "IVF-HNSW-SQ list payload offsets overlap",
+                )
+            })?;
+            if should_coalesce_gap(
+                gap,
+                range_start,
+                meta_end,
+                range_payload_bytes,
+                meta.payload_len,
+            ) {
+                // Extend the current range to cover this payload.
+                range_end = meta_end;
+                range_payload_bytes = range_payload_bytes
+                    .checked_add(meta.payload_len)
+                    .ok_or_else(|| {
+                        io::Error::new(
+                            io::ErrorKind::InvalidData,
+                            "coalesced IVF-HNSW-SQ requested payload bytes overflow",
+                        )
+                    })?;
+                range_metas.push(meta);
+            } else {
+                // Flush the finished range and start a new one at this payload.
+                self.read_coalesced_graph_list_range(
+                    range_start,
+                    range_end,
+                    &range_metas,
+                    &mut loaded,
+                )?;
+                range_start = meta.offset;
+                range_end = meta_end;
+                range_payload_bytes = meta.payload_len;
+                range_metas.clear();
+                range_metas.push(meta);
+            }
+        }
+        self.read_coalesced_graph_list_range(range_start, range_end, &range_metas, &mut loaded)?;
+
+        // Restore request order. The key is a linear scan of `list_ids`, so it is
+        // cached: computed once per element rather than once per comparison.
+        loaded.sort_by_cached_key(|(list_id, _)| {
+            list_ids
+                .iter()
+                .position(|&requested_id| requested_id == *list_id)
+                .unwrap_or(usize::MAX)
+        });
+        Ok(loaded)
+    }
+
+    /// Reads `[range_start, range_end)` with a single `pread` and decodes every
+    /// list described by `metas` out of that buffer.
+    ///
+    /// Decoded lists are appended to `loaded` in the order of `metas`.
+    ///
+    /// # Errors
+    ///
+    /// Returns `InvalidData` when the range is inverted or does not fit in
+    /// `usize`, or when a payload does not lie inside the range. Read and decode
+    /// errors are propagated.
+    fn read_coalesced_graph_list_range(
+        &mut self,
+        range_start: u64,
+        range_end: u64,
+        metas: &[ListPayloadMeta],
+        loaded: &mut Vec<(usize, GraphList)>,
+    ) -> io::Result<()> {
+        let range_len = range_end.checked_sub(range_start).ok_or_else(|| {
+            io::Error::new(
+                io::ErrorKind::InvalidData,
+                "coalesced IVF-HNSW-SQ read range is invalid",
+            )
+        })?;
+        let byte_len = usize::try_from(range_len).map_err(|_| {
+            io::Error::new(
+                io::ErrorKind::InvalidData,
+                "coalesced IVF-HNSW-SQ read range exceeds usize",
+            )
+        })?;
+        let mut payload = vec![0u8; byte_len];
+        self.reader
+            .pread(&mut [ReadRequest::new(range_start, &mut payload)])?;
+
+        for &meta in metas {
+            // Checked arithmetic and slicing: a meta outside the range is reported
+            // as an error instead of panicking.
+            let relative = meta.offset.checked_sub(range_start).ok_or_else(|| {
+                io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    "coalesced IVF-HNSW-SQ payload offset precedes read range",
+                )
+            })?;
+            let start = usize::try_from(relative).map_err(|_| {
+                io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    "coalesced IVF-HNSW-SQ payload offset exceeds usize",
+                )
+            })?;
+            let end = start.checked_add(meta.payload_len).ok_or_else(|| {
+                io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    "coalesced IVF-HNSW-SQ payload slice overflows",
+                )
+            })?;
+            let bytes = payload.get(start..end).ok_or_else(|| {
+                io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    "coalesced IVF-HNSW-SQ payload slice exceeds read range",
+                )
+            })?;
+            loaded.push((meta.list_id, self.decode_graph_list_payload(meta, bytes)?));
+        }
+        Ok(())
+    }
+
+ fn list_payload_meta(&self, list_id: usize) ->
io::Result<Option<ListPayloadMeta>> {
if list_id >= self.nlist {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
@@ -301,21 +428,42 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
));
}
let payload_len = list_payload_len(count, self.sq.code_size(),
graph_bytes_len)?;
- let mut payload = vec![0u8; payload_len];
- self.reader.pread(offset, &mut payload)?;
+ Ok(Some(ListPayloadMeta {
+ list_id,
+ offset,
+ count,
+ payload_len,
+ }))
+ }
- let ids_bytes_len = checked_list_bytes(count, 8)?;
+ fn decode_graph_list_payload(
+ &self,
+ meta: ListPayloadMeta,
+ payload: &[u8],
+ ) -> io::Result<GraphList> {
+ if payload.len() != meta.payload_len {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!(
+ "list {} payload length {} does not match expected {}",
+ meta.list_id,
+ payload.len(),
+ meta.payload_len
+ ),
+ ));
+ }
+ let ids_bytes_len = checked_list_bytes(meta.count, 8)?;
let code_size = self.sq.code_size();
- let codes_bytes_len = checked_list_bytes(count, code_size)?;
+ let codes_bytes_len = checked_list_bytes(meta.count, code_size)?;
let ids = payload[..ids_bytes_len]
.chunks_exact(8)
.map(|c| i64::from_le_bytes(c.try_into().unwrap()))
.collect();
let codes = payload[ids_bytes_len..ids_bytes_len +
codes_bytes_len].to_vec();
- let mut vectors = vec![0.0f32; count * self.d];
- self.list_sq(list_id)
- .decode_batch(&codes, count, &mut vectors);
- let centroid = self.list_centroid(list_id).to_vec();
+ let mut vectors = vec![0.0f32; meta.count * self.d];
+ self.list_sq(meta.list_id)
+ .decode_batch(&codes, meta.count, &mut vectors);
+ let centroid = self.list_centroid(meta.list_id).to_vec();
for vector in vectors.chunks_exact_mut(self.d) {
for i in 0..self.d {
vector[i] += centroid[i];
@@ -324,7 +472,7 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
let graph = decode_graph(
&payload[ids_bytes_len + codes_bytes_len..],
vectors,
- count,
+ meta.count,
self.d,
self.metric,
self.hnsw_params,
@@ -332,16 +480,16 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
let graph = graph.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
- format!("list {} is missing HNSW graph", list_id),
+ format!("list {} is missing HNSW graph", meta.list_id),
)
})?;
- Ok(Some(GraphList {
+ Ok(GraphList {
ids,
codes,
graph,
centroid: Some(centroid),
- sq: self.list_sq(list_id).clone(),
- }))
+ sq: self.list_sq(meta.list_id).clone(),
+ })
}
fn list_centroid(&self, list_id: usize) -> &[f32] {
@@ -377,10 +525,8 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
let (probe_indices, _) =
kmeans::find_topk(&q, &self.quantizer_centroids, self.nlist,
self.d, nprobe);
let mut loaded_lists = Vec::with_capacity(probe_indices.len());
- for list_id in probe_indices {
- if let Some(list) = self.read_graph_list(list_id)? {
- loaded_lists.push(list);
- }
+ for (_, list) in self.read_graph_lists_coalesced(&probe_indices)? {
+ loaded_lists.push(list);
}
let search_lists: Vec<_> = loaded_lists
.iter()
@@ -470,10 +616,7 @@ pub fn search_batch_ivfhnswsq_reader_filter<R: SeekRead>(
let mut heaps: Vec<TopKHeap> = (0..nq).map(|_| TopKHeap::new(k)).collect();
let mut query_filtered_counts = vec![0usize; nq];
let mut loaded_lists = Vec::with_capacity(unique_lists.len());
- for list_id in unique_lists {
- let Some(list) = reader.read_graph_list(list_id)? else {
- continue;
- };
+ for (list_id, list) in reader.read_graph_lists_coalesced(&unique_lists)? {
if let Some(f) = filter {
let filtered = list.ids.iter().filter(|&&id|
f.contains(id)).count();
for &qi in &list_to_queries[list_id] {
@@ -584,6 +727,46 @@ struct GraphList {
sq: ScalarQuantizer,
}
+/// Location and size of one list's payload inside the index file.
+#[derive(Clone, Copy)]
+struct ListPayloadMeta {
+    /// Index of the list this payload belongs to.
+    list_id: usize,
+    /// Byte offset at which the payload starts.
+    offset: u64,
+    /// Number of entries stored in the list.
+    count: usize,
+    /// Total payload length in bytes.
+    payload_len: usize,
+}
+
+impl ListPayloadMeta {
+    /// Returns the offset one past the last payload byte, or `InvalidData` when
+    /// `offset + payload_len` does not fit in `u64`.
+    fn end_offset(self) -> io::Result<u64> {
+        match self.offset.checked_add(self.payload_len as u64) {
+            Some(end) => Ok(end),
+            None => Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "IVF-HNSW-SQ list payload offset overflows u64",
+            )),
+        }
+    }
+}
+
+/// Decides whether the next payload should be folded into the current read range.
+///
+/// Returns `true` only when `gap` is at most `MAX_COALESCED_READ_GAP_BYTES` and
+/// the merged range (`next_range_end - range_start`) is no larger than twice the
+/// bytes actually requested (`current_payload_bytes + next_payload_bytes`).
+/// Any overflow or underflow in those sums yields `false`.
+fn should_coalesce_gap(
+    gap: u64,
+    range_start: u64,
+    next_range_end: u64,
+    current_payload_bytes: usize,
+    next_payload_bytes: usize,
+) -> bool {
+    let gap_is_small = gap <= MAX_COALESCED_READ_GAP_BYTES;
+    let requested = current_payload_bytes.checked_add(next_payload_bytes);
+    let spanned = next_range_end.checked_sub(range_start);
+    match (requested, spanned) {
+        (Some(requested_bytes), Some(range_bytes)) if gap_is_small => {
+            range_bytes <= requested_bytes.saturating_mul(2) as u64
+        }
+        _ => false,
+    }
+}
+
struct LoadedBatchList {
query_ids: Vec<usize>,
ids: Vec<i64>,
@@ -774,8 +957,11 @@ mod tests {
use super::*;
use crate::hnsw::HnswBuildParams;
use crate::io::PosWriter;
+ use crate::io::{ReadRequest, SeekRead};
use roaring::RoaringTreemap;
use std::io::Cursor;
+ use std::sync::atomic::{AtomicUsize, Ordering};
+ use std::sync::Arc;
#[test]
fn test_ivfhnswsq_write_read_search_roundtrip() {
@@ -870,6 +1056,38 @@ mod tests {
assert_eq!(labels, vec![12, -1]);
}
+    /// A single-query search probing every list must use exactly one pread.
+    #[test]
+    fn test_ivfhnswsq_reader_search_coalesces_contiguous_list_reads() {
+        let d = 4;
+        let nlist = 4;
+        let n = 128;
+        // Spread vectors over `nlist` well-separated clusters.
+        let vectors: Vec<f32> = (0..n)
+            .flat_map(|i| {
+                let cluster_base = (i % nlist) as f32 * 100.0;
+                [cluster_base + i as f32 * 0.01, 1.0, 2.0, 3.0]
+            })
+            .collect();
+        let row_ids: Vec<i64> = (0..n as i64).collect();
+
+        let mut index = IVFHNSWSQIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+        index.train(&vectors, n);
+        index.add(&vectors, &row_ids, n);
+        index.build_graphs().unwrap();
+
+        let mut serialized = Vec::new();
+        write_ivfhnswsq_index(&index, &mut PosWriter::new(&mut serialized)).unwrap();
+
+        let reads = Arc::new(AtomicUsize::new(0));
+        let source = CountingPreadCursor::new(serialized, Arc::clone(&reads));
+        let mut reader = IVFHNSWSQIndexReader::open(source).unwrap();
+        reader.ensure_loaded().unwrap();
+        // Only count reads issued by the search itself.
+        reads.store(0, Ordering::SeqCst);
+
+        reader.search(&vectors[0..d], 5, nlist, 32).unwrap();
+
+        assert_eq!(reads.load(Ordering::SeqCst), 1);
+    }
+
#[test]
fn test_ivfhnswsq_write_read_search_roundtrip_cosine() {
let d = 3;
@@ -934,6 +1152,73 @@ mod tests {
}
}
+    /// A batch search probing every list must fetch all payloads with one pread.
+    #[test]
+    fn test_ivfhnswsq_batch_reader_coalesces_contiguous_list_reads() {
+        let d = 4;
+        let nlist = 4;
+        let n = 128;
+        // Spread vectors over `nlist` well-separated clusters.
+        let vectors: Vec<f32> = (0..n)
+            .flat_map(|i| {
+                let cluster_base = (i % nlist) as f32 * 100.0;
+                [cluster_base + i as f32 * 0.01, 1.0, 2.0, 3.0]
+            })
+            .collect();
+        let row_ids: Vec<i64> = (0..n as i64).collect();
+
+        let mut index = IVFHNSWSQIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+        index.train(&vectors, n);
+        index.add(&vectors, &row_ids, n);
+        index.build_graphs().unwrap();
+
+        let mut serialized = Vec::new();
+        write_ivfhnswsq_index(&index, &mut PosWriter::new(&mut serialized)).unwrap();
+
+        let reads = Arc::new(AtomicUsize::new(0));
+        let source = CountingPreadCursor::new(serialized, Arc::clone(&reads));
+        let mut reader = IVFHNSWSQIndexReader::open(source).unwrap();
+        reader.ensure_loaded().unwrap();
+        // Only count reads issued by the search itself.
+        reads.store(0, Ordering::SeqCst);
+        let queries = vectors[0..d].to_vec();
+
+        search_batch_ivfhnswsq_reader(&mut reader, &queries, 1, 5, nlist, 32).unwrap();
+
+        assert_eq!(reads.load(Ordering::SeqCst), 1);
+    }
+
+    /// Requesting lists 0 and 2 must still issue one pread even though another
+    /// non-empty list was not requested.
+    #[test]
+    fn test_ivfhnswsq_reader_coalesces_small_gap_between_requested_lists() {
+        let d = 4;
+        let nlist = 4;
+        let n = 128;
+        // Spread vectors over `nlist` well-separated clusters.
+        let vectors: Vec<f32> = (0..n)
+            .flat_map(|i| {
+                let cluster_base = (i % nlist) as f32 * 100.0;
+                [cluster_base + i as f32 * 0.01, 1.0, 2.0, 3.0]
+            })
+            .collect();
+        let row_ids: Vec<i64> = (0..n as i64).collect();
+
+        let mut index = IVFHNSWSQIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+        index.train(&vectors, n);
+        index.add(&vectors, &row_ids, n);
+        index.build_graphs().unwrap();
+
+        let mut serialized = Vec::new();
+        write_ivfhnswsq_index(&index, &mut PosWriter::new(&mut serialized)).unwrap();
+
+        let reads = Arc::new(AtomicUsize::new(0));
+        let source = CountingPreadCursor::new(serialized, Arc::clone(&reads));
+        let mut reader = IVFHNSWSQIndexReader::open(source).unwrap();
+        reader.ensure_loaded().unwrap();
+        // Lists 0..3 must all hold data for the scenario to be meaningful.
+        assert!(reader.list_counts[..3].iter().all(|&count| count > 0));
+        reads.store(0, Ordering::SeqCst);
+
+        let fetched = reader.read_graph_lists_coalesced(&[0, 2]).unwrap();
+
+        assert_eq!(fetched.len(), 2);
+        assert_eq!(reads.load(Ordering::SeqCst), 1);
+    }
+
#[test]
fn test_ivfhnswsq_write_requires_graphs() {
let mut index = IVFHNSWSQIndex::new(2, 1, MetricType::L2,
HnswBuildParams::default());
@@ -972,4 +1257,35 @@ mod tests {
assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
assert!(err.to_string().contains("graph does not match"));
}
+
+    /// In-memory `SeekRead` test double that counts every range it is asked to read.
+    struct CountingPreadCursor {
+        /// Backing bytes served to readers.
+        data: Vec<u8>,
+        /// Shared counter incremented once per requested range.
+        pread_count: Arc<AtomicUsize>,
+    }
+
+    impl CountingPreadCursor {
+        /// Wraps `data`, reporting read counts through `pread_count`.
+        fn new(data: Vec<u8>, pread_count: Arc<AtomicUsize>) -> Self {
+            Self { data, pread_count }
+        }
+    }
+
+    impl SeekRead for CountingPreadCursor {
+        /// Fills each request from `data`, failing with `UnexpectedEof` when a
+        /// request reaches past the end of the buffer.
+        fn pread(&mut self, ranges: &mut [ReadRequest<'_>]) -> io::Result<()> {
+            for range in ranges {
+                // Count the request before validating it, so failed reads are counted too.
+                self.pread_count.fetch_add(1, Ordering::SeqCst);
+                // Checked conversion: an `as` cast would silently truncate the
+                // position on targets where `usize` is narrower than the offset.
+                let pos = usize::try_from(range.pos).map_err(|_| {
+                    io::Error::new(io::ErrorKind::UnexpectedEof, "cursor position overflow")
+                })?;
+                let end = pos.checked_add(range.buf.len()).ok_or_else(|| {
+                    io::Error::new(io::ErrorKind::UnexpectedEof, "cursor position overflow")
+                })?;
+                if end > self.data.len() {
+                    return Err(io::Error::new(
+                        io::ErrorKind::UnexpectedEof,
+                        "failed to fill whole buffer",
+                    ));
+                }
+                range.buf.copy_from_slice(&self.data[pos..end]);
+            }
+            Ok(())
+        }
+    }
}
diff --git a/core/src/ivfpq.rs b/core/src/ivfpq.rs
index 1684481..9971998 100644
--- a/core/src/ivfpq.rs
+++ b/core/src/ivfpq.rs
@@ -1023,93 +1023,69 @@ pub fn search_with_reader_filter<R: SeekRead>(
let mut heap = TopKHeap::new(k);
- if reader.supports_concurrent_pread() {
- // Pre-read all inverted lists upfront so we can scan them in parallel.
- let mut list_data: Vec<PreReadList> = Vec::new();
- for (probe_idx, &list_id) in probe_indices.iter().enumerate() {
- let count = reader.list_counts[list_id] as usize;
- if count == 0 {
- continue;
- }
- let dis0 = if use_precomputed {
- coarse_dists[probe_idx]
- } else {
- 0.0
- };
- let (ids, codes) = reader.read_inverted_list(list_id)?;
- list_data.push(PreReadList {
- list_id,
- count,
- dis0,
- ids,
- codes,
- });
+ let mut lists_to_read = Vec::new();
+ for (probe_idx, &list_id) in probe_indices.iter().enumerate() {
+ let count = reader.list_counts[list_id] as usize;
+ if count == 0 {
+ continue;
}
-
- let ctx = ReaderSearchContext {
- q: &q,
- ip_table: &ip_table,
- use_precomputed,
- filter,
- d,
- m,
- ksub,
- metric,
- by_residual,
- transposed_codes: reader.transposed_codes,
- pq: &reader.pq,
- quantizer_centroids: &reader.quantizer_centroids,
- precomputed_table: &reader.precomputed_table,
+ let dis0 = if use_precomputed {
+ coarse_dists[probe_idx]
+ } else {
+ 0.0
};
- let per_list_results: Vec<Vec<(f32, i64)>> = list_data
- .par_iter()
- .map(|entry| {
- let mut local_heap = TopKHeap::new(k);
- scan_reader_list(entry, &ctx, &mut local_heap);
- local_heap.into_sorted()
- })
- .collect();
+ lists_to_read.push((list_id, count, dis0));
+ }
+
+ let read_list_ids: Vec<usize> = lists_to_read
+ .iter()
+ .map(|&(list_id, _, _)| list_id)
+ .collect();
+ let read_lists = reader.read_inverted_lists(&read_list_ids)?;
+ let mut list_data: Vec<PreReadList> = Vec::with_capacity(read_lists.len());
+ for ((list_id, count, dis0), read_list) in
lists_to_read.into_iter().zip(read_lists) {
+ if list_id != read_list.list_id {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ "batched inverted list read returned lists out of order",
+ ));
+ }
+ list_data.push(PreReadList {
+ list_id,
+ count,
+ dis0,
+ ids: read_list.ids,
+ codes: read_list.codes,
+ });
+ }
- for results in per_list_results {
- for (dist, id) in results {
- heap.push(dist, id);
- }
- }
- } else {
- for (probe_idx, &list_id) in probe_indices.iter().enumerate() {
- let count = reader.list_counts[list_id] as usize;
- if count == 0 {
- continue;
- }
- let dis0 = if use_precomputed {
- coarse_dists[probe_idx]
- } else {
- 0.0
- };
- let (ids, codes) = reader.read_inverted_list(list_id)?;
- let entry = PreReadList {
- list_id,
- count,
- dis0,
- ids,
- codes,
- };
- let ctx = ReaderSearchContext {
- q: &q,
- ip_table: &ip_table,
- use_precomputed,
- filter,
- d,
- m,
- ksub,
- metric,
- by_residual,
- transposed_codes: reader.transposed_codes,
- pq: &reader.pq,
- quantizer_centroids: &reader.quantizer_centroids,
- precomputed_table: &reader.precomputed_table,
- };
- scan_reader_list(&entry, &ctx, &mut heap);
+ let ctx = ReaderSearchContext {
+ q: &q,
+ ip_table: &ip_table,
+ use_precomputed,
+ filter,
+ d,
+ m,
+ ksub,
+ metric,
+ by_residual,
+ transposed_codes: reader.transposed_codes,
+ pq: &reader.pq,
+ quantizer_centroids: &reader.quantizer_centroids,
+ precomputed_table: &reader.precomputed_table,
+ };
+ let per_list_results: Vec<Vec<(f32, i64)>> = list_data
+ .par_iter()
+ .map(|entry| {
+ let mut local_heap = TopKHeap::new(k);
+ scan_reader_list(entry, &ctx, &mut local_heap);
+ local_heap.into_sorted()
+ })
+ .collect();
+
+ for results in per_list_results {
+ for (dist, id) in results {
+ heap.push(dist, id);
}
}
@@ -1329,23 +1305,23 @@ pub fn search_batch_reader_filter<R: SeekRead>(
let mut heaps: Vec<TopKHeap> = (0..nq).map(|_| TopKHeap::new(k)).collect();
- for list_id in unique_lists {
- let count = reader.list_counts[list_id] as usize;
- if count == 0 {
- continue;
- }
+ let non_empty_lists: Vec<usize> = unique_lists
+ .into_iter()
+ .filter(|&list_id| reader.list_counts[list_id] > 0)
+ .collect();
+ let read_lists = reader.read_inverted_lists(&non_empty_lists)?;
- // Read list once (shared across all queries that probe it)
- let (ids, codes) = reader.read_inverted_list(list_id)?;
+ for read_list in read_lists {
+ let count = read_list.ids.len();
let mut entry = PreReadList {
- list_id,
+ list_id: read_list.list_id,
count,
dis0: 0.0,
- ids,
- codes,
+ ids: read_list.ids,
+ codes: read_list.codes,
};
- for &(qi, coarse_dist) in &list_to_queries[&list_id] {
+ for &(qi, coarse_dist) in &list_to_queries[&entry.list_id] {
let query = &processed[qi * d..(qi + 1) * d];
let dis0 = if use_precomputed { coarse_dist } else { 0.0 };
let ctx = ReaderSearchContext {
@@ -1491,7 +1467,7 @@ fn sift_down(heap: &mut [(f32, i64)], mut i: usize) {
#[cfg(test)]
mod tests {
use super::*;
- use crate::io::SeekRead;
+ use crate::io::{ReadRequest, SeekRead};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::io::Cursor;
@@ -1500,6 +1476,8 @@ mod tests {
#[derive(Default)]
struct ReaderStats {
pread_calls: usize,
+ pread_batches: usize,
+ max_ranges_per_batch: usize,
max_pread_len: usize,
}
@@ -1518,23 +1496,22 @@ mod tests {
}
impl SeekRead for NonConcurrentPreadCursor {
- fn seek(&mut self, pos: u64) -> io::Result<()> {
- io::Seek::seek(&mut self.inner, io::SeekFrom::Start(pos))?;
- Ok(())
- }
-
- fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
- io::Read::read_exact(&mut self.inner, buf)
- }
-
- fn pread(&mut self, pos: u64, buf: &mut [u8]) -> io::Result<()> {
+ fn pread(&mut self, ranges: &mut [ReadRequest<'_>]) -> io::Result<()> {
{
let mut stats = self.stats.lock().unwrap();
- stats.pread_calls += 1;
- stats.max_pread_len = stats.max_pread_len.max(buf.len());
+ stats.pread_batches += 1;
+ stats.max_ranges_per_batch =
stats.max_ranges_per_batch.max(ranges.len());
}
- io::Seek::seek(&mut self.inner, io::SeekFrom::Start(pos))?;
- io::Read::read_exact(&mut self.inner, buf)
+ for range in ranges {
+ {
+ let mut stats = self.stats.lock().unwrap();
+ stats.pread_calls += 1;
+ stats.max_pread_len =
stats.max_pread_len.max(range.buf.len());
+ }
+ io::Seek::seek(&mut self.inner,
io::SeekFrom::Start(range.pos))?;
+ io::Read::read_exact(&mut self.inner, range.buf)?;
+ }
+ Ok(())
}
}
@@ -2116,7 +2093,6 @@ mod tests {
let stats = Arc::new(Mutex::new(ReaderStats::default()));
let stream = NonConcurrentPreadCursor::new(buf, Arc::clone(&stats));
let mut reader = IVFPQIndexReader::open(stream).unwrap();
- assert!(!reader.supports_concurrent_pread());
let (ids, dists) = reader.search(&data[0..d], k, nprobe).unwrap();
@@ -2128,6 +2104,46 @@ mod tests {
);
}
+ #[test]
+ fn test_reader_search_batches_multiple_list_preads() {
+ use crate::io::{write_index, IVFPQIndexReader, PosWriter};
+
+ let d = 16;
+ let nlist = 8;
+ let m = 4;
+ let n = 800;
+ let k = 5;
+ let nprobe = 4;
+
+ let data = generate_clustered_data(n, d, 8, 987);
+ let ids: Vec<i64> = (0..n as i64).collect();
+
+ let mut index = IVFPQIndex::new(d, nlist, m, MetricType::L2, false);
+ index.train(&data, n);
+ index.add(&data, &ids, n);
+
+ let mut buf = Vec::new();
+ write_index(&index, &mut PosWriter::new(&mut buf)).unwrap();
+
+ let stats = Arc::new(Mutex::new(ReaderStats::default()));
+ let stream = NonConcurrentPreadCursor::new(buf, Arc::clone(&stats));
+ let mut reader = IVFPQIndexReader::open(stream).unwrap();
+ reader.ensure_loaded().unwrap();
+
+ {
+ let mut stats = stats.lock().unwrap();
+ *stats = ReaderStats::default();
+ }
+
+ let (_ids, _dists) = reader.search(&data[0..d], k, nprobe).unwrap();
+
+ let stats = stats.lock().unwrap();
+ assert!(
+ stats.max_ranges_per_batch > 1,
+ "multiple probed IVF-PQ lists should share one batched pread"
+ );
+ }
+
#[test]
fn test_reader_search_validates_inputs() {
use crate::io::{write_index, IVFPQIndexReader, PosWriter};
diff --git
a/jni/java-test/org/apache/paimon/index/ivfpq/VectorIndexJavaApiTest.java
b/jni/java-test/org/apache/paimon/index/ivfpq/VectorIndexJavaApiTest.java
index a9502af..4ea2a68 100644
--- a/jni/java-test/org/apache/paimon/index/ivfpq/VectorIndexJavaApiTest.java
+++ b/jni/java-test/org/apache/paimon/index/ivfpq/VectorIndexJavaApiTest.java
@@ -150,7 +150,7 @@ public class VectorIndexJavaApiTest {
closedWriter.close();
if (System.currentTimeMillis() < 0) {
- VectorIndexReader reader = new VectorIndexReader(new Object());
+ VectorIndexReader reader = new VectorIndexReader(new
EmptyVectorIndexInput());
reader.metadata();
reader.indexType();
reader.dimension();
@@ -226,4 +226,9 @@ public class VectorIndexJavaApiTest {
private interface ThrowingRunnable {
void run() throws Throwable;
}
+
+ private static final class EmptyVectorIndexInput implements
VectorIndexInput {
+ @Override
+ public void pread(long[] positions, byte[][] buffers) {}
+ }
}
diff --git
a/jni/java-test/org/apache/paimon/index/ivfpq/VectorIndexNativePanicBoundaryTest.java
b/jni/java-test/org/apache/paimon/index/ivfpq/VectorIndexNativePanicBoundaryTest.java
index 949da6f..045eec3 100644
---
a/jni/java-test/org/apache/paimon/index/ivfpq/VectorIndexNativePanicBoundaryTest.java
+++
b/jni/java-test/org/apache/paimon/index/ivfpq/VectorIndexNativePanicBoundaryTest.java
@@ -122,42 +122,26 @@ public class VectorIndexNativePanicBoundaryTest {
}
}
- public static final class ByteArraySeekableInputStream {
+ public static final class ByteArraySeekableInputStream implements
VectorIndexInput {
private final byte[] data;
- private int position;
ByteArraySeekableInputStream(byte[] data) {
this.data = data.clone();
}
- public void seek(long newPosition) {
- if (newPosition < 0 || newPosition > data.length) {
- throw new IllegalArgumentException("position out of range: " +
newPosition);
+ @Override
+ public void pread(long[] positions, byte[][] buffers) {
+ if (positions.length != buffers.length) {
+ throw new IllegalArgumentException("positions and buffers
length mismatch");
}
- this.position = (int) newPosition;
- }
-
- public int read(byte[] buffer, int offset, int length) {
- if (position >= data.length) {
- return -1;
- }
- int bytesToRead = Math.min(length, data.length - position);
- System.arraycopy(data, position, buffer, offset, bytesToRead);
- position += bytesToRead;
- return bytesToRead;
- }
-
- public int pread(long readPosition, byte[] buffer, int offset, int
length) {
- if (readPosition < 0 || readPosition > data.length) {
- return -1;
- }
- int start = (int) readPosition;
- if (start >= data.length) {
- return -1;
+ for (int i = 0; i < positions.length; i++) {
+ long readPosition = positions[i];
+ byte[] buffer = buffers[i];
+ if (readPosition < 0 || readPosition + buffer.length >
data.length) {
+ throw new IllegalArgumentException("read out of range: " +
readPosition);
+ }
+ System.arraycopy(data, (int) readPosition, buffer, 0,
buffer.length);
}
- int bytesToRead = Math.min(length, data.length - start);
- System.arraycopy(data, start, buffer, offset, bytesToRead);
- return bytesToRead;
}
}
}
diff --git a/jni/java/org/apache/paimon/index/ivfpq/VectorIndexInput.java
b/jni/java/org/apache/paimon/index/ivfpq/VectorIndexInput.java
new file mode 100644
index 0000000..451c884
--- /dev/null
+++ b/jni/java/org/apache/paimon/index/ivfpq/VectorIndexInput.java
@@ -0,0 +1,23 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.paimon.index.ivfpq;
+
+public interface VectorIndexInput {
+
+ void pread(long[] positions, byte[][] buffers);
+}
diff --git a/jni/java/org/apache/paimon/index/ivfpq/VectorIndexReader.java
b/jni/java/org/apache/paimon/index/ivfpq/VectorIndexReader.java
index 45c83df..d0da1f4 100644
--- a/jni/java/org/apache/paimon/index/ivfpq/VectorIndexReader.java
+++ b/jni/java/org/apache/paimon/index/ivfpq/VectorIndexReader.java
@@ -22,7 +22,7 @@ public final class VectorIndexReader implements AutoCloseable
{
private long nativePtr;
private VectorIndexMetadata metadata;
- public VectorIndexReader(Object input) {
+ public VectorIndexReader(VectorIndexInput input) {
if (input == null) {
throw new NullPointerException("input");
}
diff --git a/jni/src/stream.rs b/jni/src/stream.rs
index 1e566a3..f0b0364 100644
--- a/jni/src/stream.rs
+++ b/jni/src/stream.rs
@@ -15,213 +15,106 @@
// specific language governing permissions and limitations
// under the License.
-use jni::objects::GlobalRef;
+use jni::objects::{GlobalRef, JByteArray, JObject, JObjectArray, JValue};
use jni::JavaVM;
-use paimon_vindex_core::io::SeekRead;
+use paimon_vindex_core::io::{ReadRequest, SeekRead};
use std::io;
-use std::sync::{Arc, Mutex};
+use std::sync::Arc;
-/// JNI-backed seekable stream that delegates to Java's SeekableInputStream.
-///
-/// If the Java stream also implements VectoredReadable, pread() is used for
-/// thread-safe positional reads without changing the stream cursor.
+/// JNI-backed input stream that delegates to Java's VectorIndexInput.
pub struct JniSeekableStream {
jvm: Arc<JavaVM>,
stream_ref: Arc<GlobalRef>,
- stream_lock: Arc<Mutex<()>>,
- /// Whether the Java stream supports pread (implements VectoredReadable)
- has_pread: bool,
}
impl JniSeekableStream {
pub fn new(jvm: JavaVM, stream_ref: GlobalRef) -> Self {
- let jvm = Arc::new(jvm);
- let has_pread = check_has_pread(&jvm, &stream_ref);
JniSeekableStream {
- jvm,
+ jvm: Arc::new(jvm),
stream_ref: Arc::new(stream_ref),
- stream_lock: Arc::new(Mutex::new(())),
- has_pread,
}
}
}
-/// Check if the Java object implements VectoredReadable (has pread method).
-fn check_has_pread(jvm: &JavaVM, stream_ref: &GlobalRef) -> bool {
- let mut env = match jvm.attach_current_thread() {
- Ok(e) => e,
- Err(_) => return false,
- };
- // Try to find the pread method — if it exists, the stream supports
positional reads
- let class = match env.get_object_class(stream_ref.as_obj()) {
- Ok(c) => c,
- Err(_) => return false,
- };
- env.get_method_id(&class, "pread", "(J[BII)I").is_ok()
-}
-
impl SeekRead for JniSeekableStream {
- fn seek(&mut self, pos: u64) -> io::Result<()> {
- let _guard = self
- .stream_lock
- .lock()
- .map_err(|e| io::Error::other(format!("Lock poisoned: {}", e)))?;
-
- let mut env = self
- .jvm
- .attach_current_thread()
- .map_err(|e| io::Error::other(format!("JNI attach: {}", e)))?;
-
- env.call_method(
- self.stream_ref.as_obj(),
- "seek",
- "(J)V",
- &[jni::objects::JValue::Long(pos as i64)],
- )
- .map_err(|e| io::Error::other(format!("JNI seek: {}", e)))?;
-
- Ok(())
- }
-
- fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
- let _guard = self
- .stream_lock
- .lock()
- .map_err(|e| io::Error::other(format!("Lock poisoned: {}", e)))?;
-
- read_bytes_from_stream(&self.jvm, &self.stream_ref, buf)
- }
-
- /// Positional read via Java's VectoredReadable.pread(position, buffer,
offset, length).
- /// Thread-safe: does not change the stream cursor position.
- fn pread(&mut self, pos: u64, buf: &mut [u8]) -> io::Result<()> {
- if !self.has_pread {
- // Fallback: seek + read with lock
- let _guard = self
- .stream_lock
- .lock()
- .map_err(|e| io::Error::other(format!("Lock poisoned: {}",
e)))?;
-
- let mut env = self
- .jvm
- .attach_current_thread()
- .map_err(|e| io::Error::other(format!("JNI attach: {}", e)))?;
-
- env.call_method(
- self.stream_ref.as_obj(),
- "seek",
- "(J)V",
- &[jni::objects::JValue::Long(pos as i64)],
- )
- .map_err(|e| io::Error::other(format!("JNI seek: {}", e)))?;
-
- drop(env);
- return read_bytes_from_stream(&self.jvm, &self.stream_ref, buf);
+ /// Positional reads via VectorIndexInput.pread(long[] positions, byte[][]
buffers).
+ fn pread(&mut self, ranges: &mut [ReadRequest<'_>]) -> io::Result<()> {
+ if ranges.is_empty() {
+ return Ok(());
}
- // Use pread — no lock needed, thread-safe positional read
let mut env = self
.jvm
.attach_current_thread()
.map_err(|e| io::Error::other(format!("JNI attach: {}", e)))?;
- let jbuf = env
- .new_byte_array(buf.len() as i32)
- .map_err(|e| io::Error::other(format!("JNI alloc: {}", e)))?;
-
- let mut total_read = 0i32;
- let length = buf.len() as i32;
-
- while total_read < length {
- let remaining = length - total_read;
- let n = env
- .call_method(
- self.stream_ref.as_obj(),
- "pread",
- "(J[BII)I",
- &[
- jni::objects::JValue::Long(pos as i64 + total_read as
i64),
- jni::objects::JValue::Object(&jbuf),
- jni::objects::JValue::Int(total_read),
- jni::objects::JValue::Int(remaining),
- ],
- )
- .map_err(|e| io::Error::other(format!("JNI pread: {}", e)))?
- .i()
- .map_err(|e| io::Error::other(format!("JNI pread return: {}",
e)))?;
-
- if n <= 0 {
- return Err(io::Error::new(
- io::ErrorKind::UnexpectedEof,
- format!("pread EOF: read {} of {} bytes", total_read,
length),
- ));
- }
- total_read += n;
- }
-
- let mut signed_buf = vec![0i8; buf.len()];
- env.get_byte_array_region(&jbuf, 0, &mut signed_buf)
- .map_err(|e| io::Error::other(format!("JNI get_region: {}", e)))?;
-
- for (i, &b) in signed_buf.iter().enumerate() {
- buf[i] = b as u8;
+ let positions = env
+ .new_long_array(ranges.len() as i32)
+ .map_err(|e| io::Error::other(format!("JNI alloc positions: {}",
e)))?;
+ let position_values: Vec<i64> = ranges.iter().map(|range| range.pos as
i64).collect();
+ env.set_long_array_region(&positions, 0, &position_values)
+ .map_err(|e| io::Error::other(format!("JNI set positions: {}",
e)))?;
+
+ let byte_array_class = env
+ .find_class("[B")
+ .map_err(|e| io::Error::other(format!("JNI find byte[] class: {}",
e)))?;
+ let buffers = env
+ .new_object_array(ranges.len() as i32, byte_array_class,
JObject::null())
+ .map_err(|e| io::Error::other(format!("JNI alloc buffers: {}",
e)))?;
+ for (idx, range) in ranges.iter().enumerate() {
+ let jbuf = env
+ .new_byte_array(range.buf.len() as i32)
+ .map_err(|e| io::Error::other(format!("JNI alloc range buffer:
{}", e)))?;
+ env.set_object_array_element(&buffers, idx as i32, jbuf)
+ .map_err(|e| io::Error::other(format!("JNI set buffer: {}",
e)))?;
}
- Ok(())
- }
+ env.call_method(
+ self.stream_ref.as_obj(),
+ "pread",
+ "([J[[B)V",
+ &[JValue::Object(&positions), JValue::Object(&buffers)],
+ )
+ .map_err(|e| io::Error::other(format!("JNI pread: {}", e)))?;
- fn supports_concurrent_pread(&self) -> bool {
- self.has_pread
+ copy_java_buffers(&mut env, &buffers, ranges)
}
}
-/// Helper: read bytes from the Java stream (after seek, under lock).
-fn read_bytes_from_stream(jvm: &JavaVM, stream_ref: &GlobalRef, buf: &mut
[u8]) -> io::Result<()> {
- let mut env = jvm
- .attach_current_thread()
- .map_err(|e| io::Error::other(format!("JNI attach: {}", e)))?;
-
- let jbuf = env
- .new_byte_array(buf.len() as i32)
- .map_err(|e| io::Error::other(format!("JNI alloc: {}", e)))?;
-
- let mut total_read = 0i32;
- let length = buf.len() as i32;
-
- while total_read < length {
- let remaining = length - total_read;
- let n = env
- .call_method(
- stream_ref.as_obj(),
- "read",
- "([BII)I",
- &[
- jni::objects::JValue::Object(&jbuf),
- jni::objects::JValue::Int(total_read),
- jni::objects::JValue::Int(remaining),
- ],
- )
- .map_err(|e| io::Error::other(format!("JNI read: {}", e)))?
- .i()
- .map_err(|e| io::Error::other(format!("JNI read return: {}", e)))?;
-
- if n <= 0 {
+fn copy_java_buffers(
+ env: &mut jni::JNIEnv<'_>,
+ buffers: &JObjectArray<'_>,
+ ranges: &mut [ReadRequest<'_>],
+) -> io::Result<()> {
+ for (idx, range) in ranges.iter_mut().enumerate() {
+ let obj = env
+ .get_object_array_element(buffers, idx as i32)
+ .map_err(|e| io::Error::other(format!("JNI get buffer: {}", e)))?;
+ let jbuf = JByteArray::from(obj);
+ let len = env
+ .get_array_length(&jbuf)
+ .map_err(|e| io::Error::other(format!("JNI get buffer length: {}",
e)))?
+ as usize;
+ if len != range.buf.len() {
return Err(io::Error::new(
- io::ErrorKind::UnexpectedEof,
- format!("EOF: read {} of {} bytes", total_read, length),
+ io::ErrorKind::InvalidData,
+ format!(
+ "Java pread returned buffer length {} != {}",
+ len,
+ range.buf.len()
+ ),
));
}
- total_read += n;
- }
-
- let mut signed_buf = vec![0i8; buf.len()];
- env.get_byte_array_region(&jbuf, 0, &mut signed_buf)
- .map_err(|e| io::Error::other(format!("JNI get_region: {}", e)))?;
+ if len > 0 {
+ let mut signed_buf = vec![0i8; range.buf.len()];
+ env.get_byte_array_region(&jbuf, 0, &mut signed_buf)
+ .map_err(|e| io::Error::other(format!("JNI get_region: {}",
e)))?;
- for (i, &b) in signed_buf.iter().enumerate() {
- buf[i] = b as u8;
+ for (i, &b) in signed_buf.iter().enumerate() {
+ range.buf[i] = b as u8;
+ }
+ }
}
-
Ok(())
}
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 1f146b6..aaeb90b 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -26,50 +26,80 @@ use paimon_vindex_core::index::{
IndexType, VectorIndexConfig, VectorIndexReader as CoreVectorIndexReader,
VectorIndexWriter as CoreVectorIndexWriter, VectorSearchParams,
};
-use paimon_vindex_core::io::{SeekRead, SeekWrite};
+use paimon_vindex_core::io::{ReadRequest, SeekRead, SeekWrite};
use pyo3::exceptions::{PyIOError, PyValueError};
use pyo3::prelude::*;
-use pyo3::types::{PyAny, PyBytes};
+use pyo3::types::{PyAny, PyBytes, PyList};
use std::io;
-struct PyFileStream {
- file: PyObject,
+struct PyVectorIndexInput {
+ input: PyObject,
}
-impl SeekRead for PyFileStream {
- fn seek(&mut self, pos: u64) -> io::Result<()> {
+impl SeekRead for PyVectorIndexInput {
+ fn pread(&mut self, ranges: &mut [ReadRequest<'_>]) -> io::Result<()> {
Python::with_gil(|py| {
- self.file
- .call_method1(py, "seek", (pos as i64,))
- .map_err(|e| io::Error::other(format!("seek: {}", e)))?;
- Ok(())
- })
- }
+ if !self
+ .input
+ .bind(py)
+ .hasattr("pread_many")
+ .map_err(|e| io::Error::other(format!("hasattr(pread_many):
{}", e)))?
+ {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "Python input must define pread_many(ranges)",
+ ));
+ }
- fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
- Python::with_gil(|py| {
+ let request_list = PyList::empty_bound(py);
+ for range in ranges.iter() {
+ request_list
+ .append((range.pos, range.buf.len()))
+ .map_err(|e| io::Error::other(format!("build pread_many
request: {}", e)))?;
+ }
let result = self
- .file
- .call_method1(py, "read", (buf.len(),))
- .map_err(|e| io::Error::other(format!("read: {}", e)))?;
-
- let bytes: &Bound<PyBytes> = result
+ .input
+ .call_method1(py, "pread_many", (request_list,))
+ .map_err(|e| io::Error::other(format!("pread_many: {}", e)))?;
+ let result_list: &Bound<PyList> = result
.downcast_bound(py)
- .map_err(|e| io::Error::other(format!("downcast: {}", e)))?;
-
- let data = bytes.as_bytes();
- if data.len() != buf.len() {
+ .map_err(|e| io::Error::other(format!("pread_many result: {}",
e)))?;
+ if result_list.len() != ranges.len() {
return Err(io::Error::new(
- io::ErrorKind::UnexpectedEof,
- format!("read {} of {} bytes", data.len(), buf.len()),
+ io::ErrorKind::InvalidData,
+ format!(
+ "pread_many returned {} buffers for {} ranges",
+ result_list.len(),
+ ranges.len()
+ ),
));
}
- buf.copy_from_slice(data);
+ for (idx, range) in ranges.iter_mut().enumerate() {
+ let item = result_list
+ .get_item(idx)
+ .map_err(|e| io::Error::other(format!("pread_many item:
{}", e)))?;
+ copy_py_bytes(&item, range.buf)?;
+ }
Ok(())
})
}
}
+fn copy_py_bytes(value: &Bound<'_, PyAny>, buf: &mut [u8]) -> io::Result<()> {
+ let bytes: &Bound<PyBytes> = value
+ .downcast()
+ .map_err(|e| io::Error::other(format!("downcast bytes: {}", e)))?;
+ let data = bytes.as_bytes();
+ if data.len() != buf.len() {
+ return Err(io::Error::new(
+ io::ErrorKind::UnexpectedEof,
+ format!("pread returned {} of {} bytes", data.len(), buf.len()),
+ ));
+ }
+ buf.copy_from_slice(data);
+ Ok(())
+}
+
struct PyOutputStream {
file: PyObject,
pos: u64,
@@ -633,14 +663,14 @@ impl VectorIndexWriter {
#[pyclass]
struct VectorIndexReader {
- inner: CoreVectorIndexReader<PyFileStream>,
+ inner: CoreVectorIndexReader<PyVectorIndexInput>,
}
#[pymethods]
impl VectorIndexReader {
#[new]
- fn new(file: PyObject) -> PyResult<Self> {
- let stream = PyFileStream { file };
+ fn new(input: PyObject) -> PyResult<Self> {
+ let stream = PyVectorIndexInput { input };
let reader = CoreVectorIndexReader::open(stream)
.map_err(|e| PyIOError::new_err(format!("Failed to open index:
{}", e)))?;
Ok(Self { inner: reader })
@@ -836,10 +866,61 @@ mod tests {
.add_vectors(id_array.readonly(), train.readonly())
.unwrap();
writer.write(output.as_any().clone().unbind()).unwrap();
- output.call_method1("seek", (0,)).unwrap();
output
}
+ fn vector_index_input<'py>(
+ py: Python<'py>,
+ output: &Bound<'py, PyAny>,
+ ) -> Bound<'py, PyAny> {
+ let data = output
+ .call_method0("getvalue")
+ .unwrap()
+ .downcast_into::<PyBytes>()
+ .unwrap();
+ Py::new(
+ py,
+ PyBytesVectorIndexInput {
+ data: data.as_bytes().to_vec(),
+ },
+ )
+ .unwrap()
+ .into_bound(py)
+ .into_any()
+ }
+
+ #[pyclass]
+ struct PyBytesVectorIndexInput {
+ data: Vec<u8>,
+ }
+
+ #[pymethods]
+ impl PyBytesVectorIndexInput {
+ fn pread_many<'py>(
+ &self,
+ py: Python<'py>,
+ ranges: &Bound<'_, PyList>,
+ ) -> PyResult<Bound<'py, PyList>> {
+ let result = PyList::empty_bound(py);
+ for item in ranges.iter() {
+ let (pos, len): (usize, usize) = item.extract()?;
+ let end = pos.checked_add(len).ok_or_else(|| {
+ PyIOError::new_err("pread_many range position overflow")
+ })?;
+ if end > self.data.len() {
+ return Err(PyIOError::new_err(format!(
+ "pread_many range {}..{} out of bounds {}",
+ pos,
+ end,
+ self.data.len()
+ )));
+ }
+ result.append(PyBytes::new_bound(py, &self.data[pos..end]))?;
+ }
+ Ok(result)
+ }
+ }
+
#[test]
fn python_unified_writer_reader_roundtrips_supported_indexes() {
Python::with_gil(|py| {
@@ -880,7 +961,8 @@ mod tests {
for (config, d, expected_type) in configs {
let output = write_index_bytes(py, &config, d);
- let mut reader =
VectorIndexReader::new(output.unbind()).unwrap();
+ let input = vector_index_input(py, &output);
+ let mut reader =
VectorIndexReader::new(input.unbind()).unwrap();
assert_eq!(reader.index_type(), expected_type);
assert_eq!(reader.dimension(), d);
assert_eq!(reader.metadata().index_type, expected_type);
@@ -925,8 +1007,8 @@ mod tests {
allowed.serialize_into(&mut filter_bytes).unwrap();
let filter = PyBytes::new_bound(py, &filter_bytes);
- output.call_method1("seek", (0,)).unwrap();
- let mut reader = VectorIndexReader::new(output.unbind()).unwrap();
+ let input = vector_index_input(py, &output);
+ let mut reader = VectorIndexReader::new(input.unbind()).unwrap();
let queries =
PyArray::from_vec2_bound(py, &[vec![0.0f32, 0.0], vec![10.0,
10.0]]).unwrap();
let (result_ids, result_dists) = reader
@@ -950,7 +1032,8 @@ mod tests {
.into_bound(py)
.into_any();
let output = write_index_bytes(py, &config, 16);
- let mut reader = VectorIndexReader::new(output.unbind()).unwrap();
+ let input = vector_index_input(py, &output);
+ let mut reader = VectorIndexReader::new(input.unbind()).unwrap();
let wrong_dim = PyArray::from_vec2_bound(py, &[vec![0.0f32;
15]]).unwrap();
let err = reader