This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-vector-index.git
The following commit(s) were added to refs/heads/main by this push:
new 711a715 Tighten HNSW v1 storage metadata (#35)
711a715 is described below
commit 711a715508c8783efad51a83d9ca8fd04050557e
Author: Jingsong Lee <[email protected]>
AuthorDate: Thu Jun 11 11:13:54 2026 +0800
Tighten HNSW v1 storage metadata (#35)
---
STORAGE_FORMAT.md | 50 +++-
core/src/index_io_util.rs | 211 +++++++++++++--
core/src/io.rs | 220 +++++++--------
core/src/ivfhnswflat_io.rs | 449 +++++++++++++++++++++++++++----
core/src/ivfhnswsq_io.rs | 379 +++++++++++++++++++++++---
core/tests/fixtures/ivf_hnsw_flat_v1.hex | 12 +-
core/tests/fixtures/ivf_hnsw_sq_v1.hex | 10 +-
7 files changed, 1083 insertions(+), 248 deletions(-)
diff --git a/STORAGE_FORMAT.md b/STORAGE_FORMAT.md
index 218f8c7..2e1bf5f 100644
--- a/STORAGE_FORMAT.md
+++ b/STORAGE_FORMAT.md
@@ -55,15 +55,21 @@ decoded sequence that is not monotonically non-decreasing
in signed order.
### HNSW Graph Section
IVF-HNSW-FLAT and IVF-HNSW-SQ store one graph section per non-empty list. The
-section is a contiguous sequence of little-endian `u32` values:
+section starts with a fixed header, followed by a contiguous sequence of
+unsigned LEB128 varints. Neighbor ids within each adjacency group are sorted by
+local vector id and stored as unsigned deltas from the previous neighbor id,
+with an initial previous id of `0`:
| Field | Count |
| --- | --- |
-| `graph_count` | 1 |
-| `entry_point` | 1 |
-| `max_observed_level` | 1 |
-| `level[node]` | `graph_count` |
-| `degree[node][level]` followed by neighbor ids | one group for each node level |
+| `graph_magic` | 1 little-endian `u32`, `HWGR` (`0x48574752`) |
+| `graph_version` | 1 little-endian `u32`, currently `1` |
+| `graph_flags` | 1 little-endian `u32`; bit 0 delta-varint adjacency is required |
+| `graph_count` | 1 varint |
+| `entry_point` | 1 varint |
+| `max_observed_level` | 1 varint |
+| `level[node]` | `graph_count` varints |
+| `degree[node][level]` followed by neighbor id deltas | one group for each node level |
Each node has levels `0..=level[node]`. A level-0 node may have at most `2 * m`
neighbors, and higher levels may have at most `m` neighbors.
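The adjacency-group rules above (a degree varint, then neighbor ids sorted and stored as unsigned deltas from an initial previous id of `0`, capped at `2 * m` entries on level 0 and `m` above) can be sketched as follows. This is a minimal illustration of one group, not the crate's `encode_graph`/`decode_graph`; the helpers `write_varint`, `read_varint`, `encode_group`, and `decode_group` are hypothetical, and the `HWGR` header fields are left out.

```rust
/// Append `value` as an unsigned LEB128 varint.
fn write_varint(out: &mut Vec<u8>, mut value: u64) {
    while value >= 0x80 {
        out.push((value as u8 & 0x7f) | 0x80); // low 7 bits, continuation bit set
        value >>= 7;
    }
    out.push(value as u8); // final byte, continuation bit clear
}

/// Read one unsigned LEB128 varint of at most 10 bytes.
fn read_varint(bytes: &[u8], pos: &mut usize) -> Result<u64, String> {
    let mut value = 0u64;
    for i in 0..10 {
        let byte = *bytes.get(*pos).ok_or("truncated varint")?;
        *pos += 1;
        // The 10th byte sits at shift 63, so only its lowest payload bit fits.
        if i == 9 && byte & 0x7e != 0 {
            return Err("varint exceeds u64".to_string());
        }
        value |= u64::from(byte & 0x7f) << (7 * i);
        if byte & 0x80 == 0 {
            return Ok(value);
        }
    }
    Err("varint longer than 10 bytes".to_string())
}

/// Encode one adjacency group: degree, then sorted ids as deltas starting from 0.
fn encode_group(out: &mut Vec<u8>, neighbors: &[u32]) {
    let mut sorted = neighbors.to_vec();
    sorted.sort_unstable();
    write_varint(out, sorted.len() as u64);
    let mut previous = 0u32;
    for id in sorted {
        write_varint(out, u64::from(id - previous)); // never underflows: ids are sorted
        previous = id;
    }
}

/// Decode one adjacency group at `level`, enforcing the per-level degree cap.
fn decode_group(bytes: &[u8], pos: &mut usize, level: usize, m: usize) -> Result<Vec<u32>, String> {
    let degree = read_varint(bytes, pos)?;
    let max_degree = if level == 0 { 2 * m } else { m };
    if degree > max_degree as u64 {
        return Err(format!("degree {} exceeds {} at level {}", degree, max_degree, level));
    }
    let mut neighbors = Vec::with_capacity(degree as usize);
    let mut previous = 0u64;
    for _ in 0..degree {
        // Rebuild the absolute id and keep it inside the u32 id space.
        let id = previous
            .checked_add(read_varint(bytes, pos)?)
            .filter(|&id| id <= u64::from(u32::MAX))
            .ok_or("neighbor id out of u32 range")?;
        neighbors.push(id as u32);
        previous = id;
    }
    Ok(neighbors)
}
```

For example, neighbors `[300, 5, 7]` sort to `[5, 7, 300]` and encode as degree `3` followed by deltas `5`, `2`, `293`; only the last delta needs a second byte, which is typically where the saving over fixed `u32` values comes from.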
@@ -167,20 +173,30 @@ Magic: `IHFL` (`0x4948464C`). Version: `1`. Header size: 64 bytes.
| 28 | 4 | `i32` | HNSW `m` |
| 32 | 4 | `i32` | HNSW `ef_construction` |
| 36 | 4 | `i32` | HNSW `max_level` |
-| 40 | 24 | bytes | reserved |
+| 40 | 4 | `u32` | flags |
+| 44 | 20 | bytes | reserved |
+
+Flags:
+
+| Bit | Meaning |
+| ---: | --- |
+| 0 | sorted delta-varint ids are stored; required in v1 |
+| 1 | HNSW graph section uses the v1 delta-varint graph encoding; required in v1 |
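A reader that honors this flags table performs two separate checks: any bit outside the supported set is rejected, and every required bit must be present. A minimal sketch, assuming a hypothetical `validate_flags` helper; the constant values follow the bit table, where bits 0 and 1 are both required in v1 and no optional bits exist yet.

```rust
// Bit assignments from the flags table; in v1 both bits are required.
const FLAG_DELTA_IDS: u32 = 1 << 0;
const FLAG_GRAPH_V1: u32 = 1 << 1;
const REQUIRED_FLAGS: u32 = FLAG_DELTA_IDS | FLAG_GRAPH_V1;
// No optional bits are defined yet, so "supported" equals "required".
const SUPPORTED_FLAGS: u32 = REQUIRED_FLAGS;

fn validate_flags(flags: u32) -> Result<(), String> {
    // Reject bits this reader does not understand.
    let unknown = flags & !SUPPORTED_FLAGS;
    if unknown != 0 {
        return Err(format!("unsupported flags: 0x{:08X}", unknown));
    }
    // Reject files that lack a bit the v1 layout depends on.
    let missing = REQUIRED_FLAGS & !flags;
    if missing != 0 {
        return Err(format!("missing required flags: 0x{:08X}", missing));
    }
    Ok(())
}
```

Keeping "supported" and "required" as separate masks lets a later version add an optional bit without changing how old required bits are enforced.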
Sections after the header:
1. IVF coarse centroids: `nlist * d` `f32` values.
2. Offset table: `nlist` entries of
- `(offset: i64, count: i32, graph_bytes_len: i32, reserved: i64)`.
+ `(offset: i64, count: i32, graph_bytes_len: i32, payload_bytes_len: i64)`.
3. List payloads.
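Each offset-table entry is a fixed 24 bytes, and in this layout the former `reserved: i64` slot carries `payload_bytes_len`. Below is a minimal little-endian round trip of one entry, assuming a hypothetical `OffsetEntry` struct. The reader in the diff rejects a negative `payload_bytes_len`; this sketch also rejects negative `count` and `graph_bytes_len`, which is an assumption of the sketch.

```rust
#[derive(Debug, PartialEq)]
struct OffsetEntry {
    offset: i64,
    count: i32,
    graph_bytes_len: i32,
    payload_bytes_len: i64,
}

const ENTRY_SIZE: usize = 8 + 4 + 4 + 8; // 24 bytes per list

fn encode_entry(e: &OffsetEntry, out: &mut Vec<u8>) {
    out.extend_from_slice(&e.offset.to_le_bytes());
    out.extend_from_slice(&e.count.to_le_bytes());
    out.extend_from_slice(&e.graph_bytes_len.to_le_bytes());
    out.extend_from_slice(&e.payload_bytes_len.to_le_bytes());
}

fn decode_entry(b: &[u8]) -> Result<OffsetEntry, String> {
    if b.len() < ENTRY_SIZE {
        return Err("truncated offset table entry".into());
    }
    // Fixed-width little-endian readers at a byte offset within the entry.
    let i64_at = |o: usize| {
        let mut a = [0u8; 8];
        a.copy_from_slice(&b[o..o + 8]);
        i64::from_le_bytes(a)
    };
    let i32_at = |o: usize| {
        let mut a = [0u8; 4];
        a.copy_from_slice(&b[o..o + 4]);
        i32::from_le_bytes(a)
    };
    let e = OffsetEntry {
        offset: i64_at(0),
        count: i32_at(8),
        graph_bytes_len: i32_at(12),
        payload_bytes_len: i64_at(16),
    };
    if e.count < 0 || e.graph_bytes_len < 0 || e.payload_bytes_len < 0 {
        return Err(format!("negative field in offset table entry: {:?}", e));
    }
    Ok(e)
}
```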
For each non-empty list payload:
| Field | Type | Notes |
| --- | --- | --- |
-| `ids` | `count` `i64` | row ids in list order |
+| `base_id` | `i64` | first sorted row id |
+| `id_bytes_len` | `i32` | byte length of encoded id stream |
+| `id_bytes` | bytes | delta-varint ids |
| `vectors` | `count * d` `f32` | raw stored vectors |
| `graph` | bytes | HNSW graph section |
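Because `id_bytes` has variable length, a reader must parse `base_id` and `id_bytes_len` before it knows where `vectors` starts. The graph section is whatever follows the `count * d` `f32` values. The slicing step can be sketched as follows, assuming a hypothetical `split_flat_payload` helper that only locates byte ranges and leaves varint and graph decoding to other code:

```rust
/// Byte ranges of one IHFL list payload (hypothetical helper type).
struct PayloadParts<'a> {
    base_id: i64,
    id_bytes: &'a [u8],
    vector_bytes: &'a [u8],
    graph_bytes: &'a [u8],
}

fn split_flat_payload(payload: &[u8], count: usize, d: usize) -> Result<PayloadParts<'_>, String> {
    // Fixed prefix: base_id (i64) + id_bytes_len (i32).
    if payload.len() < 12 {
        return Err("truncated ID header".into());
    }
    let mut base = [0u8; 8];
    base.copy_from_slice(&payload[0..8]);
    let mut len = [0u8; 4];
    len.copy_from_slice(&payload[8..12]);
    let id_bytes_len = i32::from_le_bytes(len);
    if id_bytes_len < 0 {
        return Err(format!("negative id_bytes_len {}", id_bytes_len));
    }
    // Every boundary is computed with checked arithmetic before slicing.
    let ids_end = 12usize.checked_add(id_bytes_len as usize).ok_or("ID length overflow")?;
    let vector_bytes_len = count
        .checked_mul(d)
        .and_then(|n| n.checked_mul(4))
        .ok_or("vector length overflow")?;
    let vectors_end = ids_end.checked_add(vector_bytes_len).ok_or("vector length overflow")?;
    if vectors_end > payload.len() {
        return Err("truncated vector payload".into());
    }
    Ok(PayloadParts {
        base_id: i64::from_le_bytes(base),
        id_bytes: &payload[12..ids_end],
        vector_bytes: &payload[ids_end..vectors_end],
        graph_bytes: &payload[vectors_end..],
    })
}
```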
@@ -201,7 +217,15 @@ Magic: `IHSQ` (`0x49485351`). Version: `1`. Header size: 64 bytes.
| 36 | 4 | `i32` | HNSW `max_level` |
| 40 | 4 | `f32` | global minimum SQ bound summary |
| 44 | 4 | `f32` | global maximum SQ bound summary |
-| 48 | 16 | bytes | reserved |
+| 48 | 4 | `u32` | flags |
+| 52 | 12 | bytes | reserved |
+
+Flags:
+
+| Bit | Meaning |
+| ---: | --- |
+| 0 | sorted delta-varint ids are stored; required in v1 |
+| 1 | HNSW graph section uses the v1 delta-varint graph encoding; required in v1 |
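In the IHSQ header, the flags word sits at offset 48, after the two SQ bound summaries. In the IHFL header it sits at offset 40, so code shared between the formats must choose the offset based on the magic. A sketch of the IHSQ header tail, assuming hypothetical `write_sq_header_tail` and `read_u32_at` helpers:

```rust
/// Append bytes 40..64 of an IHSQ v1 header (offsets per the header table).
fn write_sq_header_tail(out: &mut Vec<u8>, global_min: f32, global_max: f32, flags: u32) {
    out.extend_from_slice(&global_min.to_le_bytes()); // offset 40: global minimum SQ bound
    out.extend_from_slice(&global_max.to_le_bytes()); // offset 44: global maximum SQ bound
    out.extend_from_slice(&flags.to_le_bytes()); // offset 48: flags
    out.extend_from_slice(&[0u8; 12]); // offsets 52..64: reserved, zero-filled
}

/// Read a little-endian u32 at `offset`, or None if the header is too short.
fn read_u32_at(header: &[u8], offset: usize) -> Option<u32> {
    let b = header.get(offset..offset.checked_add(4)?)?;
    Some(u32::from_le_bytes([b[0], b[1], b[2], b[3]]))
}
```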
Sections after the header:
@@ -211,13 +235,15 @@ Sections after the header:
max `f32` values.
4. IVF coarse centroids: `nlist * d` `f32` values.
5. Offset table: `nlist` entries of
- `(offset: i64, count: i32, graph_bytes_len: i32, reserved: i64)`.
+ `(offset: i64, count: i32, graph_bytes_len: i32, payload_bytes_len: i64)`.
6. List payloads.
For each non-empty list payload:
| Field | Type | Notes |
| --- | --- | --- |
-| `ids` | `count` `i64` | row ids in list order |
+| `base_id` | `i64` | first sorted row id |
+| `id_bytes_len` | `i32` | byte length of encoded id stream |
+| `id_bytes` | bytes | delta-varint ids |
| `codes` | bytes | scalar quantized residual codes, `count * d` bytes |
| `graph` | bytes | HNSW graph section over decoded vectors |
diff --git a/core/src/index_io_util.rs b/core/src/index_io_util.rs
index d7d1fb1..d788764 100644
--- a/core/src/index_io_util.rs
+++ b/core/src/index_io_util.rs
@@ -65,22 +65,103 @@ pub(crate) fn validate_search_inputs(
Ok(())
}
+const HNSW_GRAPH_MAGIC: u32 = 0x48574752; // "HWGR"
+const HNSW_GRAPH_VERSION: u32 = 1;
+const HNSW_GRAPH_FLAG_DELTA_VARINT: u32 = 1 << 0;
+const HNSW_GRAPH_REQUIRED_FLAGS: u32 = HNSW_GRAPH_FLAG_DELTA_VARINT;
+const HNSW_GRAPH_SUPPORTED_FLAGS: u32 = HNSW_GRAPH_REQUIRED_FLAGS;
+
+pub(crate) fn encode_delta_varint_ids(ids: &[i64]) -> (i64, Vec<u8>) {
+ if ids.is_empty() {
+ return (0, Vec::new());
+ }
+ let base = ids[0];
+ let mut buf = Vec::with_capacity(ids.len() * 2);
+ let mut prev = base;
+ for &id in ids {
+ let delta = (id as u64).wrapping_sub(prev as u64);
+ write_u64_varint(&mut buf, delta);
+ prev = id;
+ }
+ (base, buf)
+}
+
+pub(crate) fn decode_delta_varint_ids(base: i64, buf: &[u8], count: usize) -> io::Result<Vec<i64>> {
+ let mut ids = Vec::with_capacity(count);
+ let mut pos = 0;
+ let mut current = base as u64;
+ let mut prev_signed = base;
+ for _ in 0..count {
+ let delta = read_u64_varint(buf, &mut pos)?;
+ current = current.wrapping_add(delta);
+ let id = current as i64;
+ if id < prev_signed {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ "decoded ID sequence is not monotonically non-decreasing",
+ ));
+ }
+ prev_signed = id;
+ ids.push(id);
+ }
+ if pos != buf.len() {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ "trailing bytes in delta-varint ID section",
+ ));
+ }
+ Ok(ids)
+}
+
pub(crate) fn encode_graph(graph: Option<&HnswGraph>) -> io::Result<Vec<u8>> {
let Some(graph) = graph else {
return Ok(Vec::new());
};
let mut buf = Vec::new();
- write_u32_vec(&mut buf, graph.len())?;
- write_u32_vec(&mut buf, graph.entry_point())?;
- write_u32_vec(&mut buf, graph.max_observed_level())?;
+ write_u32_fixed(&mut buf, HNSW_GRAPH_MAGIC)?;
+ write_u32_fixed(&mut buf, HNSW_GRAPH_VERSION)?;
+ write_u32_fixed(&mut buf, HNSW_GRAPH_REQUIRED_FLAGS)?;
+ write_u32_varint(&mut buf, graph.len())?;
+ write_u32_varint(&mut buf, graph.entry_point())?;
+ write_u32_varint(&mut buf, graph.max_observed_level())?;
+ for &level in graph.levels() {
+ write_u32_varint(&mut buf, level)?;
+ }
+ for node_levels in graph.neighbors() {
+ for level_neighbors in node_levels {
+ write_u32_varint(&mut buf, level_neighbors.len())?;
+ let mut sorted_neighbors = level_neighbors.clone();
+ sorted_neighbors.sort_unstable();
+ let mut previous = 0usize;
+ for neighbor in sorted_neighbors {
+ let delta = neighbor.checked_sub(previous).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "graph neighbor delta underflow",
+ )
+ })?;
+ write_u32_varint(&mut buf, delta)?;
+ previous = neighbor;
+ }
+ }
+ }
+ Ok(buf)
+}
+
+#[cfg(test)]
+pub(crate) fn encode_graph_u32_for_size_estimate(graph: &HnswGraph) -> io::Result<Vec<u8>> {
+ let mut buf = Vec::new();
+ write_u32_fixed(&mut buf, checked_u32_value(graph.len())?)?;
+ write_u32_fixed(&mut buf, checked_u32_value(graph.entry_point())?)?;
+ write_u32_fixed(&mut buf, checked_u32_value(graph.max_observed_level())?)?;
for &level in graph.levels() {
- write_u32_vec(&mut buf, level)?;
+ write_u32_fixed(&mut buf, checked_u32_value(level)?)?;
}
for node_levels in graph.neighbors() {
for level_neighbors in node_levels {
- write_u32_vec(&mut buf, level_neighbors.len())?;
+ write_u32_fixed(&mut buf, checked_u32_value(level_neighbors.len())?)?;
for &neighbor in level_neighbors {
- write_u32_vec(&mut buf, neighbor)?;
+ write_u32_fixed(&mut buf, checked_u32_value(neighbor)?)?;
}
}
}
@@ -99,7 +180,38 @@ pub(crate) fn decode_graph(
return Ok(None);
}
let mut pos = 0usize;
- let graph_count = read_u32_vec(bytes, &mut pos)? as usize;
+ let graph_magic = read_u32_fixed(bytes, &mut pos)?;
+ if graph_magic != HNSW_GRAPH_MAGIC {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("Invalid HNSW graph magic: 0x{:08X}", graph_magic),
+ ));
+ }
+ let graph_version = read_u32_fixed(bytes, &mut pos)?;
+ if graph_version != HNSW_GRAPH_VERSION {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("Unsupported HNSW graph version: {}", graph_version),
+ ));
+ }
+ let graph_flags = read_u32_fixed(bytes, &mut pos)?;
+ let unknown_graph_flags = graph_flags & !HNSW_GRAPH_SUPPORTED_FLAGS;
+ if unknown_graph_flags != 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!(
+ "Unsupported HNSW graph flags: 0x{:08X}",
+ unknown_graph_flags
+ ),
+ ));
+ }
+ if graph_flags & HNSW_GRAPH_REQUIRED_FLAGS != HNSW_GRAPH_REQUIRED_FLAGS {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ "HNSW graph v1 requires delta-varint adjacency encoding",
+ ));
+ }
+ let graph_count = read_u32_varint(bytes, &mut pos)? as usize;
if graph_count != count {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
@@ -109,11 +221,11 @@ pub(crate) fn decode_graph(
),
));
}
- let entry_point = read_u32_vec(bytes, &mut pos)? as usize;
- let max_observed_level = read_u32_vec(bytes, &mut pos)? as usize;
+ let entry_point = read_u32_varint(bytes, &mut pos)? as usize;
+ let max_observed_level = read_u32_varint(bytes, &mut pos)? as usize;
let mut levels = Vec::with_capacity(count);
for node in 0..count {
- let level = read_u32_vec(bytes, &mut pos)? as usize;
+ let level = read_u32_varint(bytes, &mut pos)? as usize;
if level >= hnsw_params.max_level {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
@@ -131,7 +243,7 @@ pub(crate) fn decode_graph(
for (node, &level) in levels.iter().enumerate() {
let mut node_levels = Vec::with_capacity(level + 1);
for graph_level in 0..=level {
- let degree = read_u32_vec(bytes, &mut pos)? as usize;
+ let degree = read_u32_varint(bytes, &mut pos)? as usize;
let max_degree = if graph_level == 0 {
hnsw_params.m.saturating_mul(2)
} else {
@@ -147,8 +259,20 @@ pub(crate) fn decode_graph(
));
}
let mut level_neighbors = Vec::with_capacity(degree);
+ let mut previous = 0usize;
for _ in 0..degree {
- level_neighbors.push(read_u32_vec(bytes, &mut pos)? as usize);
+ let delta = read_u32_varint(bytes, &mut pos)? as usize;
+ let neighbor = previous.checked_add(delta).ok_or_else(|| {
+ io::Error::new(io::ErrorKind::InvalidData, "graph neighbor id overflow")
+ })?;
+ if neighbor > u32::MAX as usize {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("graph neighbor id {} is out of range", neighbor),
+ ));
+ }
+ level_neighbors.push(neighbor);
+ previous = neighbor;
}
node_levels.push(level_neighbors);
}
@@ -173,18 +297,22 @@ pub(crate) fn decode_graph(
)?))
}
-fn write_u32_vec(buf: &mut Vec<u8>, value: usize) -> io::Result<()> {
+fn checked_u32_value(value: usize) -> io::Result<u32> {
if value > u32::MAX as usize {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("value {} exceeds u32 limit", value),
));
}
- buf.extend_from_slice(&(value as u32).to_le_bytes());
+ Ok(value as u32)
+}
+
+fn write_u32_fixed(buf: &mut Vec<u8>, value: u32) -> io::Result<()> {
+ buf.extend_from_slice(&value.to_le_bytes());
Ok(())
}
-fn read_u32_vec(bytes: &[u8], pos: &mut usize) -> io::Result<u32> {
+fn read_u32_fixed(bytes: &[u8], pos: &mut usize) -> io::Result<u32> {
let end = pos.checked_add(4).ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
@@ -202,6 +330,59 @@ fn read_u32_vec(bytes: &[u8], pos: &mut usize) -> io::Result<u32> {
Ok(value)
}
+fn write_u32_varint(buf: &mut Vec<u8>, value: usize) -> io::Result<()> {
+ write_u64_varint(buf, checked_u32_value(value)? as u64);
+ Ok(())
+}
+
+fn write_u64_varint(buf: &mut Vec<u8>, mut value: u64) {
+ while value >= 0x80 {
+ buf.push((value as u8 & 0x7f) | 0x80);
+ value >>= 7;
+ }
+ buf.push(value as u8);
+}
+
+fn read_u32_varint(bytes: &[u8], pos: &mut usize) -> io::Result<u32> {
+ let value = read_u64_varint(bytes, pos)?;
+ u32::try_from(value).map_err(|_| {
+ io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("HNSW graph varint value {} exceeds u32 limit", value),
+ )
+ })
+}
+
+fn read_u64_varint(bytes: &[u8], pos: &mut usize) -> io::Result<u64> {
+ let mut value = 0u64;
+ let mut shift = 0u32;
+ for _ in 0..10 {
+ if *pos >= bytes.len() {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ "truncated HNSW graph varint",
+ ));
+ }
+ let byte = bytes[*pos];
+ *pos += 1;
+ if shift == 63 && (byte & 0x7e) != 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ "HNSW graph varint exceeds u64 limit",
+ ));
+ }
+ value |= ((byte & 0x7f) as u64) << shift;
+ if byte & 0x80 == 0 {
+ return Ok(value);
+ }
+ shift += 7;
+ }
+ Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ "HNSW graph varint exceeds u64 limit",
+ ))
+}
+
pub(crate) fn write_u32_le(out: &mut dyn SeekWrite, v: u32) -> io::Result<()> {
out.write_all(&v.to_le_bytes())
}
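This hunk moves `encode_delta_varint_ids` and `decode_delta_varint_ids` into `index_io_util.rs`, where `io.rs` and `ivfhnswflat_io.rs` import them, and adds a trailing-bytes check on decode. The self-contained sketch below mirrors their logic to show why wrapping `u64` arithmetic covers the full `i64` range: deltas are computed modulo 2^64, and the signed-order check catches a stream that wraps past `i64::MAX`. The names `encode_ids`, `decode_ids`, `write_varint`, and `read_varint` are local to the sketch.

```rust
/// Append `value` as an unsigned LEB128 varint.
fn write_varint(out: &mut Vec<u8>, mut value: u64) {
    while value >= 0x80 {
        out.push((value as u8 & 0x7f) | 0x80);
        value >>= 7;
    }
    out.push(value as u8);
}

/// Read one varint; the 10th byte may only contribute bit 63.
fn read_varint(bytes: &[u8], pos: &mut usize) -> Result<u64, String> {
    let mut value = 0u64;
    for i in 0..10 {
        let byte = *bytes.get(*pos).ok_or("truncated varint")?;
        *pos += 1;
        if i == 9 && byte & 0x7e != 0 {
            return Err("varint exceeds u64".to_string());
        }
        value |= u64::from(byte & 0x7f) << (7 * i);
        if byte & 0x80 == 0 {
            return Ok(value);
        }
    }
    Err("varint longer than 10 bytes".to_string())
}

/// Sorted ids -> (base id, varint stream of wrapping u64 deltas).
fn encode_ids(sorted: &[i64]) -> (i64, Vec<u8>) {
    let base = sorted.first().copied().unwrap_or(0);
    let mut buf = Vec::new();
    let mut prev = base;
    for &id in sorted {
        // Deltas are taken mod 2^64, so i64::MIN..=i64::MAX never overflows.
        write_varint(&mut buf, (id as u64).wrapping_sub(prev as u64));
        prev = id;
    }
    (base, buf)
}

fn decode_ids(base: i64, buf: &[u8], count: usize) -> Result<Vec<i64>, String> {
    let mut ids = Vec::with_capacity(count);
    let mut pos = 0;
    let mut current = base as u64;
    let mut prev = base;
    for _ in 0..count {
        current = current.wrapping_add(read_varint(buf, &mut pos)?);
        let id = current as i64;
        // A delta that wraps past i64::MAX shows up as a decrease in signed order.
        if id < prev {
            return Err("decoded IDs are not non-decreasing".to_string());
        }
        prev = id;
        ids.push(id);
    }
    // Every byte of the stream must be consumed by exactly `count` ids.
    if pos != buf.len() {
        return Err("trailing bytes in ID stream".to_string());
    }
    Ok(ids)
}
```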
diff --git a/core/src/io.rs b/core/src/io.rs
index 42d42d9..67c7e24 100644
--- a/core/src/io.rs
+++ b/core/src/io.rs
@@ -16,6 +16,7 @@
// under the License.
use crate::distance::MetricType;
+use crate::index_io_util::{decode_delta_varint_ids, encode_delta_varint_ids};
use crate::ivfpq::IVFPQIndex;
use crate::opq::OPQMatrix;
use crate::pq::ProductQuantizer;
@@ -116,91 +117,6 @@ impl<W: io::Write + Send> SeekWrite for PosWriter<W> {
}
}
-// --- Varint encoding ---
-
-fn encode_varint(mut val: u64, buf: &mut Vec<u8>) {
- while val >= 0x80 {
- buf.push((val as u8) | 0x80);
- val >>= 7;
- }
- buf.push(val as u8);
-}
-
-fn decode_varint(buf: &[u8], pos: &mut usize) -> io::Result<u64> {
- let mut val: u64 = 0;
- let mut shift = 0u32;
- loop {
- if *pos >= buf.len() {
- return Err(io::Error::new(
- io::ErrorKind::InvalidData,
- "truncated varint",
- ));
- }
- let b = buf[*pos] as u64;
- *pos += 1;
- let payload = b & 0x7F;
- if shift == 63 && payload > 1 {
- return Err(io::Error::new(
- io::ErrorKind::InvalidData,
- "varint exceeds u64 range",
- ));
- }
- val |= payload << shift;
- if b & 0x80 == 0 {
- break;
- }
- shift += 7;
- if shift > 63 {
- return Err(io::Error::new(
- io::ErrorKind::InvalidData,
- "varint exceeds 64 bits",
- ));
- }
- }
- Ok(val)
-}
-
-/// Encode sorted i64 IDs as delta-varint. Returns (base_id, encoded_bytes).
-/// Uses unsigned subtraction to handle the full i64 range without overflow.
-fn encode_delta_varint_ids(ids: &[i64]) -> (i64, Vec<u8>) {
- if ids.is_empty() {
- return (0, Vec::new());
- }
- let base = ids[0];
- let mut buf = Vec::with_capacity(ids.len() * 2);
- let mut prev = base;
- for &id in ids {
- let delta = (id as u64).wrapping_sub(prev as u64);
- encode_varint(delta, &mut buf);
- prev = id;
- }
- (base, buf)
-}
-
-/// Decode delta-varint encoded IDs using wrapping unsigned arithmetic
-/// (inverse of encode_delta_varint_ids). Validates monotonically non-decreasing
-/// signed order — rejects corrupt data that would wrap around.
-fn decode_delta_varint_ids(base: i64, buf: &[u8], count: usize) -> io::Result<Vec<i64>> {
- let mut ids = Vec::with_capacity(count);
- let mut pos = 0;
- let mut current = base as u64;
- let mut prev_signed = base;
- for _ in 0..count {
- let delta = decode_varint(buf, &mut pos)?;
- current = current.wrapping_add(delta);
- let id = current as i64;
- if id < prev_signed {
- return Err(io::Error::new(
- io::ErrorKind::InvalidData,
- "decoded ID sequence is not monotonically non-decreasing",
- ));
- }
- prev_signed = id;
- ids.push(id);
- }
- Ok(ids)
-}
-
// --- Read/write helpers ---
fn write_u32_le(out: &mut dyn SeekWrite, v: u32) -> io::Result<()> {
@@ -310,6 +226,11 @@ pub fn write_index(index: &IVFPQIndex, out: &mut dyn SeekWrite) -> io::Result<()
let ksub = index.pq.ksub;
let dsub = index.pq.dsub;
let code_size = index.pq.code_size();
+ let d_i32 = usize_to_i32(d, "dimension")?;
+ let nlist_i32 = usize_to_i32(nlist, "nlist")?;
+ let m_i32 = usize_to_i32(m, "pq m")?;
+ let ksub_i32 = usize_to_i32(ksub, "pq ksub")?;
+ let dsub_i32 = usize_to_i32(dsub, "pq dsub")?;
let mut flags: u32 = FLAG_DELTA_IDS | FLAG_TRANSPOSED_CODES;
if index.opq.is_some() {
@@ -319,7 +240,15 @@ pub fn write_index(index: &IVFPQIndex, out: &mut dyn SeekWrite) -> io::Result<()
flags |= FLAG_BY_RESIDUAL;
}
- let total_vectors: i64 = index.ids.iter().map(|l| l.len() as i64).sum();
+ let total_vectors = index.ids.iter().try_fold(0i64, |sum, ids| {
+ let count = usize_to_i64(ids.len(), "total vector count")?;
+ sum.checked_add(count).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "total vector count exceeds i64 length limit",
+ )
+ })
+ })?;
// Sort IDs within each list and prepare delta-varint encoded data
let mut sorted_lists: Vec<(Vec<i64>, Vec<u8>, Vec<u8>)> =
Vec::with_capacity(nlist);
@@ -335,7 +264,8 @@ pub fn write_index(index: &IVFPQIndex, out: &mut dyn SeekWrite) -> io::Result<()
indices.sort_by_key(|&idx| index.ids[i][idx]);
let sorted_ids: Vec<i64> = indices.iter().map(|&idx| index.ids[i][idx]).collect();
- let mut sorted_codes = vec![0u8; count * code_size];
+ let code_bytes = checked_list_bytes(count, code_size)?;
+ let mut sorted_codes = vec![0u8; code_bytes];
for (new_idx, &old_idx) in indices.iter().enumerate() {
sorted_codes[new_idx * code_size..(new_idx + 1) * code_size]
.copy_from_slice(&index.codes[i][old_idx * code_size..(old_idx + 1) * code_size]);
@@ -348,11 +278,11 @@ pub fn write_index(index: &IVFPQIndex, out: &mut dyn SeekWrite) -> io::Result<()
// Header
write_u32_le(out, MAGIC)?;
write_u32_le(out, VERSION)?;
- write_i32_le(out, d as i32)?;
- write_i32_le(out, nlist as i32)?;
- write_i32_le(out, m as i32)?;
- write_i32_le(out, ksub as i32)?;
- write_i32_le(out, dsub as i32)?;
+ write_i32_le(out, d_i32)?;
+ write_i32_le(out, nlist_i32)?;
+ write_i32_le(out, m_i32)?;
+ write_i32_le(out, ksub_i32)?;
+ write_i32_le(out, dsub_i32)?;
write_u32_le(out, index.metric as u32)?;
write_i64_le(out, total_vectors)?;
write_u32_le(out, flags)?;
@@ -367,8 +297,21 @@ pub fn write_index(index: &IVFPQIndex, out: &mut dyn SeekWrite) -> io::Result<()
// Compute offsets for inverted lists
// Delta-varint format per list: [base_id: i64][id_bytes_len: u32][id_bytes][codes]
- let offset_table_size = nlist * 16;
- let data_start = out.pos() + offset_table_size as u64;
+ let offset_table_size = nlist.checked_mul(16).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "IVFPQ offset table size overflow",
+ )
+ })?;
+ let data_start = out
+ .pos()
+ .checked_add(offset_table_size as u64)
+ .ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "IVFPQ data start offset overflow",
+ )
+ })?;
let mut list_offsets = vec![0i64; nlist];
let mut list_counts = vec![0i32; nlist];
@@ -376,20 +319,25 @@ pub fn write_index(index: &IVFPQIndex, out: &mut dyn SeekWrite) -> io::Result<()
let mut current_offset = data_start;
for i in 0..nlist {
- list_offsets[i] = current_offset as i64;
+ list_offsets[i] = u64_to_i64(current_offset, "list offset")?;
let count = sorted_lists[i].0.len();
- list_counts[i] = count as i32;
+ list_counts[i] = usize_to_i32(count, "list count")?;
if count > 0 {
// base_id(8) + id_bytes_len(4) + id_bytes + codes
let id_bytes_len = sorted_lists[i].1.len();
- if id_bytes_len > i32::MAX as usize {
- return Err(io::Error::new(
- io::ErrorKind::InvalidInput,
- "delta ID section exceeds i32 length limit",
- ));
- }
- list_id_bytes_lens[i] = id_bytes_len as i32;
- current_offset += 8 + 4 + id_bytes_len as u64 + (count * code_size) as u64;
+ list_id_bytes_lens[i] = usize_to_i32(id_bytes_len, "delta ID section")?;
+ let code_bytes = checked_list_bytes(count, code_size)?;
+ let list_bytes = 12usize
+ .checked_add(id_bytes_len)
+ .and_then(|len| len.checked_add(code_bytes))
+ .ok_or_else(|| {
+ io::Error::new(io::ErrorKind::InvalidInput, "IVFPQ list size overflow")
+ })?;
+ current_offset = current_offset
+ .checked_add(list_bytes as u64)
+ .ok_or_else(|| {
+ io::Error::new(io::ErrorKind::InvalidInput, "IVFPQ offset overflow")
+ })?;
}
}
@@ -409,13 +357,14 @@ pub fn write_index(index: &IVFPQIndex, out: &mut dyn SeekWrite) -> io::Result<()
// base_id
write_i64_le(out, sorted_ids[0])?;
// id_bytes_len + id_bytes
- write_i32_le(out, id_bytes.len() as i32)?;
+ write_i32_le(out, usize_to_i32(id_bytes.len(), "delta ID section")?)?;
out.write_all(id_bytes)?;
// PQ codes — transpose for cache-friendly SIMD scan
let count = sorted_ids.len();
if code_size == m {
// 8-bit: transpose from [n][M] to [M][n]
- let mut transposed = vec![0u8; count * m];
+ let transposed_len = checked_list_bytes(count, m)?;
+ let mut transposed = vec![0u8; transposed_len];
for vec_idx in 0..count {
for sub in 0..m {
transposed[sub * count + vec_idx] = sorted_codes[vec_idx * m + sub];
@@ -426,7 +375,8 @@ pub fn write_index(index: &IVFPQIndex, out: &mut dyn SeekWrite) -> io::Result<()
// 4-bit: transpose from [n][M/2] to [M/2][n]
// Each byte at position `pair` in a vector goes to column `pair`
let cs = code_size;
- let mut transposed = vec![0u8; count * cs];
+ let transposed_len = checked_list_bytes(count, cs)?;
+ let mut transposed = vec![0u8; transposed_len];
for vec_idx in 0..count {
for pair in 0..cs {
transposed[pair * count + vec_idx] = sorted_codes[vec_idx * cs + pair];
@@ -439,6 +389,36 @@ pub fn write_index(index: &IVFPQIndex, out: &mut dyn SeekWrite) -> io::Result<()
Ok(())
}
+fn usize_to_i32(value: usize, field: &str) -> io::Result<i32> {
+ if value > i32::MAX as usize {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!("{} exceeds i32 length limit: {}", field, value),
+ ));
+ }
+ Ok(value as i32)
+}
+
+fn usize_to_i64(value: usize, field: &str) -> io::Result<i64> {
+ if value > i64::MAX as usize {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!("{} exceeds i64 length limit: {}", field, value),
+ ));
+ }
+ Ok(value as i64)
+}
+
+fn u64_to_i64(value: u64, field: &str) -> io::Result<i64> {
+ if value > i64::MAX as u64 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!("{} exceeds i64 offset limit: {}", field, value),
+ ));
+ }
+ Ok(value as i64)
+}
+
// --- Reader ---
pub struct IVFPQIndexReader<R: SeekRead> {
@@ -954,27 +934,19 @@ mod tests {
#[test]
fn test_varint_roundtrip() {
- let mut buf = Vec::new();
- encode_varint(0, &mut buf);
- encode_varint(127, &mut buf);
- encode_varint(128, &mut buf);
- encode_varint(16383, &mut buf);
- encode_varint(1_000_000, &mut buf);
-
- let mut pos = 0;
- assert_eq!(decode_varint(&buf, &mut pos).unwrap(), 0);
- assert_eq!(decode_varint(&buf, &mut pos).unwrap(), 127);
- assert_eq!(decode_varint(&buf, &mut pos).unwrap(), 128);
- assert_eq!(decode_varint(&buf, &mut pos).unwrap(), 16383);
- assert_eq!(decode_varint(&buf, &mut pos).unwrap(), 1_000_000);
+ let ids = [0, 127, 128, 16_383, 1_000_000];
+ let (base, encoded) = encode_delta_varint_ids(&ids);
+ assert_eq!(
+ decode_delta_varint_ids(base, &encoded, ids.len()).unwrap(),
+ ids
+ );
}
#[test]
fn test_varint_above_u64_max_returns_error() {
let mut bytes = vec![0xFFu8; 9];
bytes.push(0x02); // 10th byte with payload > 1 at shift=63
- let mut pos = 0;
- assert!(decode_varint(&bytes, &mut pos).is_err());
+ assert!(decode_delta_varint_ids(0, &bytes, 1).is_err());
}
#[test]
@@ -1297,8 +1269,8 @@ mod tests {
#[test]
fn test_delta_ids_wraparound_returns_error() {
// base_id = i64::MAX, delta = 1 would wrap to i64::MIN (non-monotonic)
- let mut id_bytes = Vec::new();
- encode_varint(1, &mut id_bytes);
+ let (_, id_bytes) = encode_delta_varint_ids(&[i64::MAX, i64::MIN]);
+ let id_bytes = id_bytes[1..].to_vec();
let result = decode_delta_varint_ids(i64::MAX, &id_bytes, 1);
assert!(
result.is_err(),
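The `io.rs` changes replace unchecked `as i32` casts and unchecked offset addition with checked arithmetic that returns `InvalidInput` instead of silently truncating. The pattern can be sketched in isolation as follows. The names `ivfpq_list_bytes`, `next_offset`, and `to_i32` are hypothetical. Unlike the diff, which checks the `i64` limit when each list offset is stored, `next_offset` folds that limit into the addition.

```rust
/// Checked size of one IVFPQ list payload:
/// base_id (8) + id_bytes_len (4) + id_bytes + count * code_size.
fn ivfpq_list_bytes(id_bytes_len: usize, count: usize, code_size: usize) -> Result<usize, String> {
    let code_bytes = count.checked_mul(code_size).ok_or("code bytes overflow")?;
    12usize
        .checked_add(id_bytes_len)
        .and_then(|n| n.checked_add(code_bytes))
        .ok_or_else(|| "list size overflow".to_string())
}

/// Advance the running file offset, keeping it representable as the
/// `i64` that the offset table stores.
fn next_offset(current: u64, list_bytes: usize) -> Result<u64, String> {
    current
        .checked_add(list_bytes as u64)
        .filter(|&offset| offset <= i64::MAX as u64)
        .ok_or_else(|| "offset overflow".to_string())
}

/// Narrow a length to an on-disk `i32` field, failing instead of truncating.
fn to_i32(value: usize, field: &str) -> Result<i32, String> {
    if value > i32::MAX as usize {
        return Err(format!("{} exceeds i32 length limit: {}", field, value));
    }
    Ok(value as i32)
}
```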
diff --git a/core/src/ivfhnswflat_io.rs b/core/src/ivfhnswflat_io.rs
index d7d970a..0e949b2 100644
--- a/core/src/ivfhnswflat_io.rs
+++ b/core/src/ivfhnswflat_io.rs
@@ -19,10 +19,11 @@ use crate::distance::{fvec_distance, preprocess_vectors, MetricType};
use crate::hnsw::{HnswBuildParams, HnswGraph};
use crate::hnsw_search::{search_hnsw_lists, HnswSearchList};
use crate::index_io_util::{
- bytes_to_f32_vec, checked_list_bytes, checked_list_offset, checked_section_size, decode_graph,
- decode_roaring_filter, encode_graph, read_f32_vec, read_i32_le, read_i64_le, read_u32_le,
- u64_to_i64, usize_to_i32, usize_to_i64, validate_positive_i32, validate_search_inputs,
- write_f32_slice, write_i32_le, write_i64_le, write_u32_le,
+ bytes_to_f32_vec, checked_list_bytes, checked_list_offset, checked_section_size,
+ decode_delta_varint_ids, decode_graph, decode_roaring_filter, encode_delta_varint_ids,
+ encode_graph, read_f32_vec, read_i32_le, read_i64_le, read_u32_le, u64_to_i64, usize_to_i32,
+ usize_to_i64, validate_positive_i32, validate_search_inputs, write_f32_slice, write_i32_le,
+ write_i64_le, write_u32_le,
};
use crate::io::{PreadCursor, ReadRequest, SeekRead, SeekWrite};
use crate::ivfhnswflat::IVFHNSWFlatIndex;
@@ -34,6 +35,10 @@ use std::io;
pub const IVF_HNSW_FLAT_MAGIC: u32 = 0x4948464C; // "IHFL"
pub const IVF_HNSW_FLAT_VERSION: u32 = 1;
pub const IVF_HNSW_FLAT_HEADER_SIZE: usize = 64;
+const FLAG_DELTA_IDS: u32 = 1 << 0;
+const FLAG_GRAPH_V1: u32 = 1 << 1;
+const REQUIRED_FLAGS: u32 = FLAG_DELTA_IDS | FLAG_GRAPH_V1;
+const SUPPORTED_FLAGS: u32 = REQUIRED_FLAGS;
const MAX_COALESCED_READ_GAP_BYTES: u64 = 1 << 20;
pub fn write_ivfhnswflat_index(
@@ -52,14 +57,8 @@ pub fn write_ivfhnswflat_index(
)
})
})?;
- let graph_bytes: Vec<Vec<u8>> = (0..nlist)
- .map(|list_id| {
- if index.flat.ids[list_id].is_empty() {
- Ok(Vec::new())
- } else {
- encode_graph(index.graphs[list_id].as_ref())
- }
- })
+ let sorted_lists: Vec<SortedFlatGraphList> = (0..nlist)
+ .map(|list_id| build_sorted_flat_graph_list(index, list_id))
.collect::<io::Result<_>>()?;
write_u32_le(out, IVF_HNSW_FLAT_MAGIC)?;
@@ -75,7 +74,8 @@ pub fn write_ivfhnswflat_index(
usize_to_i32(params.ef_construction, "hnsw ef_construction")?,
)?;
write_i32_le(out, usize_to_i32(params.max_level, "hnsw max_level")?)?;
- out.write_all(&[0u8; 24])?;
+ write_u32_le(out, REQUIRED_FLAGS)?;
+ out.write_all(&[0u8; 20])?;
write_f32_slice(out, &index.flat.quantizer_centroids)?;
@@ -97,15 +97,23 @@ pub fn write_ivfhnswflat_index(
let mut list_offsets = vec![0i64; nlist];
let mut list_counts = vec![0i32; nlist];
let mut list_graph_bytes_lens = vec![0i32; nlist];
+ let mut list_payload_bytes_lens = vec![0i64; nlist];
let mut current_offset = data_start;
for list_id in 0..nlist {
list_offsets[list_id] = u64_to_i64(current_offset, "list offset")?;
- let count = index.flat.ids[list_id].len();
+ let count = sorted_lists[list_id].ids.len();
list_counts[list_id] = usize_to_i32(count, "list count")?;
- list_graph_bytes_lens[list_id] = usize_to_i32(graph_bytes[list_id].len(), "graph bytes")?;
+ list_graph_bytes_lens[list_id] =
+ usize_to_i32(sorted_lists[list_id].graph_bytes.len(), "graph bytes")?;
if count > 0 {
- let payload_len = list_payload_len(count, d, graph_bytes[list_id].len())?;
+ let payload_len = list_payload_len(
+ count,
+ d,
+ sorted_lists[list_id].id_bytes.len(),
+ sorted_lists[list_id].graph_bytes.len(),
+ )?;
+ list_payload_bytes_lens[list_id] = usize_to_i64(payload_len, "list payload bytes")?;
current_offset = current_offset
.checked_add(payload_len as u64)
.ok_or_else(|| {
@@ -118,18 +126,19 @@ pub fn write_ivfhnswflat_index(
write_i64_le(out, list_offsets[list_id])?;
write_i32_le(out, list_counts[list_id])?;
write_i32_le(out, list_graph_bytes_lens[list_id])?;
- write_i64_le(out, 0)?;
+ write_i64_le(out, list_payload_bytes_lens[list_id])?;
}
for list_id in 0..nlist {
- if index.flat.ids[list_id].is_empty() {
+ let list = &sorted_lists[list_id];
+ if list.ids.is_empty() {
continue;
}
- for &id in &index.flat.ids[list_id] {
- write_i64_le(out, id)?;
- }
- write_f32_slice(out, &index.flat.vectors[list_id])?;
- out.write_all(&graph_bytes[list_id])?;
+ write_i64_le(out, list.ids[0])?;
+ write_i32_le(out, usize_to_i32(list.id_bytes.len(), "delta ID section")?)?;
+ out.write_all(&list.id_bytes)?;
+ write_f32_slice(out, &list.vectors)?;
+ out.write_all(&list.graph_bytes)?;
}
Ok(())
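The writer now goes through `SortedFlatGraphList` and `build_sorted_flat_graph_list`, whose body is cut off in this email. The reader decodes IDs with delta-varint, which requires non-decreasing IDs, so each list must be stored in ascending row-id order. If the existing HNSW graph is kept rather than rebuilt, its local neighbor ids and entry point must be renumbered to match. The sketch below shows that renumbering under this assumption. `SortedList` and `sort_list` are hypothetical names, and the actual helper may handle the graph differently.

```rust
/// Hypothetical reordered list: local ids follow ascending row id.
struct SortedList {
    ids: Vec<i64>,
    vectors: Vec<f32>,
    neighbors: Vec<Vec<Vec<usize>>>, // [node][level] -> neighbor local ids
    entry_point: usize,
}

fn sort_list(
    ids: &[i64],
    vectors: &[f32],
    d: usize,
    neighbors: &[Vec<Vec<usize>>],
    entry_point: usize,
) -> SortedList {
    // order[new_local] = old_local; sort_by_key is stable, so equal ids keep their order.
    let mut order: Vec<usize> = (0..ids.len()).collect();
    order.sort_by_key(|&old| ids[old]);
    // new_of[old_local] = new_local, used to rewrite graph edges.
    let mut new_of = vec![0usize; ids.len()];
    for (new, &old) in order.iter().enumerate() {
        new_of[old] = new;
    }
    let sorted_ids: Vec<i64> = order.iter().map(|&old| ids[old]).collect();
    // Move each d-dimensional vector along with its id.
    let mut sorted_vectors = Vec::with_capacity(vectors.len());
    for &old in &order {
        sorted_vectors.extend_from_slice(&vectors[old * d..(old + 1) * d]);
    }
    // Reorder nodes, then translate every neighbor id into the new numbering.
    let sorted_neighbors: Vec<Vec<Vec<usize>>> = order
        .iter()
        .map(|&old| {
            neighbors[old]
                .iter()
                .map(|level| level.iter().map(|&n| new_of[n]).collect::<Vec<usize>>())
                .collect::<Vec<Vec<usize>>>()
        })
        .collect();
    SortedList {
        ids: sorted_ids,
        vectors: sorted_vectors,
        neighbors: sorted_neighbors,
        entry_point: new_of[entry_point],
    }
}
```

The remapped adjacency lists need not be sorted here, because the graph encoder sorts each group before taking deltas.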
@@ -146,6 +155,7 @@ pub struct IVFHNSWFlatIndexReader<R: SeekRead> {
pub list_offsets: Vec<i64>,
pub list_counts: Vec<i32>,
pub list_graph_bytes_lens: Vec<i32>,
+ pub list_payload_bytes_lens: Vec<i64>,
loaded: bool,
}
@@ -187,8 +197,22 @@ impl<R: SeekRead> IVFHNSWFlatIndexReader<R> {
max_level: validate_positive_i32(read_i32_le(&mut cursor)?, "hnsw max_level")? as usize,
}
.sanitized();
- let mut reserved = [0u8; 24];
+ let flags = read_u32_le(&mut cursor)?;
+ let mut reserved = [0u8; 20];
cursor.read_exact(&mut reserved)?;
+ let unknown_flags = flags & !SUPPORTED_FLAGS;
+ if unknown_flags != 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("Unsupported IVF_HNSW_FLAT flags: 0x{:08X}", unknown_flags),
+ ));
+ }
+ if flags & REQUIRED_FLAGS != REQUIRED_FLAGS {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ "IVF_HNSW_FLAT v1 requires delta-varint IDs and graph v1",
+ ));
+ }
Ok(Self {
reader,
@@ -201,6 +225,7 @@ impl<R: SeekRead> IVFHNSWFlatIndexReader<R> {
list_offsets: Vec::new(),
list_counts: Vec::new(),
list_graph_bytes_lens: Vec::new(),
+ list_payload_bytes_lens: Vec::new(),
loaded: false,
})
}
@@ -216,6 +241,7 @@ impl<R: SeekRead> IVFHNSWFlatIndexReader<R> {
self.list_offsets = vec![0; self.nlist];
self.list_counts = vec![0; self.nlist];
self.list_graph_bytes_lens = vec![0; self.nlist];
+ self.list_payload_bytes_lens = vec![0; self.nlist];
for list_id in 0..self.nlist {
self.list_offsets[list_id] = read_i64_le(&mut cursor)?;
let count = read_i32_le(&mut cursor)?;
@@ -237,7 +263,17 @@ impl<R: SeekRead> IVFHNSWFlatIndexReader<R> {
));
}
self.list_graph_bytes_lens[list_id] = graph_bytes_len;
- let _reserved = read_i64_le(&mut cursor)?;
+ let payload_bytes_len = read_i64_le(&mut cursor)?;
+ if payload_bytes_len < 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!(
+ "negative payload_bytes_len {} at list {}",
+ payload_bytes_len, list_id
+ ),
+ ));
+ }
+ self.list_payload_bytes_lens[list_id] = payload_bytes_len;
}
self.loaded = true;
@@ -402,7 +438,28 @@ impl<R: SeekRead> IVFHNSWFlatIndexReader<R> {
format!("list {} is missing HNSW graph", list_id),
));
}
- let payload_len = list_payload_len(count, self.d, graph_bytes_len)?;
+ let payload_len = self.list_payload_bytes_lens[list_id] as usize;
+ if payload_len == 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("list {} is missing payload length", list_id),
+ ));
+ }
+ let minimum_payload_len = 12usize.checked_add(graph_bytes_len).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidData,
+ "IVF-HNSW-FLAT minimum payload length overflow",
+ )
+ })?;
+ if payload_len < minimum_payload_len {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!(
+ "list {} payload length {} is shorter than expected {}",
+ list_id, payload_len, minimum_payload_len
+ ),
+ ));
+ }
Ok(Some(ListPayloadMeta {
list_id,
offset,
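The reader-side check added in this hunk only requires `payload_bytes_len` to cover the 12-byte ID header plus the graph bytes. Because `count` and `d` are already known from the offset table and header before the payload is read, a reader could also include the `count * d * 4` vector bytes in the lower bound. The following is a hypothetical stricter variant (`check_flat_payload_len` is not part of the crate). It sketches that idea and does not reflect what the commit does.

```rust
/// Hypothetical, stricter variant of the reader's `payload_bytes_len` check.
fn check_flat_payload_len(
    list_id: usize,
    payload_bytes_len: i64,
    count: usize,
    d: usize,
    graph_bytes_len: usize,
) -> Result<u64, String> {
    // Zero means "not recorded"; negative values are corrupt.
    if payload_bytes_len <= 0 {
        return Err(format!("list {} is missing payload length", list_id));
    }
    let payload_len = payload_bytes_len as u64;
    let minimum = (count as u64)
        .checked_mul(d as u64)
        .and_then(|n| n.checked_mul(4)) // f32 vector bytes
        .and_then(|n| n.checked_add(12)) // base_id (8) + id_bytes_len (4)
        .and_then(|n| n.checked_add(graph_bytes_len as u64))
        .ok_or_else(|| format!("list {} minimum payload length overflow", list_id))?;
    if payload_len < minimum {
        return Err(format!(
            "list {} payload length {} is shorter than expected {}",
            list_id, payload_len, minimum
        ));
    }
    Ok(payload_len)
}
```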
@@ -427,7 +484,31 @@ impl<R: SeekRead> IVFHNSWFlatIndexReader<R> {
),
));
}
- let ids_bytes_len = checked_list_bytes(meta.count, 8)?;
+ let base_header_len = 12usize;
+ if payload.len() < base_header_len {
+ return Err(io::Error::new(
+ io::ErrorKind::UnexpectedEof,
+ format!("list {} has truncated ID header", meta.list_id),
+ ));
+ }
+ let base_id = i64::from_le_bytes(payload[0..8].try_into().unwrap());
+ let id_bytes_len = i32::from_le_bytes(payload[8..12].try_into().unwrap());
+ if id_bytes_len < 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!(
+ "negative id_bytes_len {} at list {}",
+ id_bytes_len, meta.list_id
+ ),
+ ));
+ }
+ let id_bytes_len = id_bytes_len as usize;
+ let ids_end = base_header_len.checked_add(id_bytes_len).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidData,
+ "IVF-HNSW-FLAT ID payload length overflow",
+ )
+ })?;
let vector_bytes_len = checked_list_bytes(
meta.count,
self.d.checked_mul(4).ok_or_else(|| {
@@ -437,13 +518,22 @@ impl<R: SeekRead> IVFHNSWFlatIndexReader<R> {
)
})?,
)?;
- let ids = payload[..ids_bytes_len]
- .chunks_exact(8)
- .map(|c| i64::from_le_bytes(c.try_into().unwrap()))
- .collect();
- let vectors = bytes_to_f32_vec(&payload[ids_bytes_len..ids_bytes_len + vector_bytes_len])?;
+ let vectors_end = ids_end.checked_add(vector_bytes_len).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidData,
+ "IVF-HNSW-FLAT vector payload length overflow",
+ )
+ })?;
+ if vectors_end > payload.len() {
+ return Err(io::Error::new(
+ io::ErrorKind::UnexpectedEof,
+ format!("list {} has truncated vector payload", meta.list_id),
+ ));
+ }
+ let ids = decode_delta_varint_ids(base_id, &payload[base_header_len..ids_end], meta.count)?;
+ let vectors = bytes_to_f32_vec(&payload[ids_end..vectors_end])?;
let graph = decode_graph(
- &payload[ids_bytes_len + vector_bytes_len..],
+ &payload[vectors_end..],
vectors.clone(),
meta.count,
self.d,
@@ -833,8 +923,128 @@ fn validate_index_shape(index: &IVFHNSWFlatIndex) -> io::Result<()> {
Ok(())
}
-fn list_payload_len(count: usize, d: usize, graph_bytes_len: usize) -> io::Result<usize> {
- let id_bytes = checked_list_bytes(count, 8)?;
+struct SortedFlatGraphList {
+ ids: Vec<i64>,
+ id_bytes: Vec<u8>,
+ vectors: Vec<f32>,
+ graph_bytes: Vec<u8>,
+}
+
+fn build_sorted_flat_graph_list(
+ index: &IVFHNSWFlatIndex,
+ list_id: usize,
+) -> io::Result<SortedFlatGraphList> {
+ let count = index.flat.ids[list_id].len();
+ if count == 0 {
+ return Ok(SortedFlatGraphList {
+ ids: Vec::new(),
+ id_bytes: Vec::new(),
+ vectors: Vec::new(),
+ graph_bytes: Vec::new(),
+ });
+ }
+
+ let mut order: Vec<usize> = (0..count).collect();
+ order.sort_by_key(|&idx| index.flat.ids[list_id][idx]);
+
+ let ids: Vec<i64> = order
+ .iter()
+ .map(|&idx| index.flat.ids[list_id][idx])
+ .collect();
+ let (_, id_bytes) = encode_delta_varint_ids(&ids);
+
+ let mut vectors = Vec::with_capacity(count * index.flat.d);
+ for &idx in &order {
+ vectors.extend_from_slice(
+ &index.flat.vectors[list_id][idx * index.flat.d..(idx + 1) * index.flat.d],
+ );
+ }
+ let old_to_new = old_to_new_order(&order);
+ let source_graph = index.graphs[list_id].as_ref().ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!("list {} is missing HNSW graph", list_id),
+ )
+ })?;
+ let graph = reorder_graph(
+ source_graph,
+ &order,
+ &old_to_new,
+ vectors.clone(),
+ index.flat.d,
+ index.flat.metric,
+ index.hnsw_params,
+ )?;
+ let graph_bytes = encode_graph(Some(&graph))?;
+
+ Ok(SortedFlatGraphList {
+ ids,
+ id_bytes,
+ vectors,
+ graph_bytes,
+ })
+}
+
+fn old_to_new_order(order: &[usize]) -> Vec<usize> {
+ let mut old_to_new = vec![0; order.len()];
+ for (new_idx, &old_idx) in order.iter().enumerate() {
+ old_to_new[old_idx] = new_idx;
+ }
+ old_to_new
+}
+
+fn reorder_graph(
+ graph: &HnswGraph,
+ order: &[usize],
+ old_to_new: &[usize],
+ vectors: Vec<f32>,
+ d: usize,
+ metric: MetricType,
+ hnsw_params: HnswBuildParams,
+) -> io::Result<HnswGraph> {
+ let levels: Vec<usize> = order
+ .iter()
+ .map(|&old_idx| graph.levels()[old_idx])
+ .collect();
+ let neighbors: Vec<Vec<Vec<usize>>> = order
+ .iter()
+ .map(|&old_idx| {
+ graph.neighbors()[old_idx]
+ .iter()
+ .map(|level_neighbors| {
+ level_neighbors
+ .iter()
+ .map(|&neighbor| old_to_new[neighbor])
+ .collect()
+ })
+ .collect()
+ })
+ .collect();
+ HnswGraph::from_parts(
+ vectors,
+ order.len(),
+ d,
+ metric,
+ levels,
+ neighbors,
+ old_to_new[graph.entry_point()],
+ graph.max_observed_level(),
+ hnsw_params,
+ )
+}
+
+fn list_payload_len(
+ count: usize,
+ d: usize,
+ id_bytes_len: usize,
+ graph_bytes_len: usize,
+) -> io::Result<usize> {
+ let id_bytes = 12usize.checked_add(id_bytes_len).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidData,
+ "IVF-HNSW-FLAT ID payload length overflow",
+ )
+ })?;
let vector_bytes = checked_list_bytes(
count,
d.checked_mul(4).ok_or_else(|| {
@@ -857,14 +1067,16 @@ fn list_payload_len(count: usize, d: usize, graph_bytes_len: usize) -> io::Resul
#[cfg(test)]
mod tests {
+ use super::REQUIRED_FLAGS;
use crate::distance::MetricType;
- use crate::hnsw::HnswBuildParams;
- use crate::index_io_util::decode_graph;
+ use crate::hnsw::{HnswBuildParams, HnswGraph};
+ use crate::index_io_util::{decode_graph, encode_graph, encode_graph_u32_for_size_estimate};
use crate::io::{PosWriter, ReadRequest, SeekRead};
use crate::ivfhnswflat::IVFHNSWFlatIndex;
use crate::ivfhnswflat_io::{
search_batch_ivfhnswflat_reader,
search_batch_ivfhnswflat_reader_roaring_filter,
write_ivfhnswflat_index, IVFHNSWFlatIndexReader, IVF_HNSW_FLAT_HEADER_SIZE,
+ IVF_HNSW_FLAT_MAGIC, IVF_HNSW_FLAT_VERSION,
};
use roaring::RoaringTreemap;
use std::io;
@@ -872,6 +1084,50 @@ mod tests {
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
+ #[test]
+ fn test_ivfhnswflat_reader_rejects_missing_required_flags() {
+ let mut buf = vec![0u8; IVF_HNSW_FLAT_HEADER_SIZE];
+ buf[0..4].copy_from_slice(&IVF_HNSW_FLAT_MAGIC.to_le_bytes());
+ buf[4..8].copy_from_slice(&IVF_HNSW_FLAT_VERSION.to_le_bytes());
+ buf[8..12].copy_from_slice(&2i32.to_le_bytes());
+ buf[12..16].copy_from_slice(&1i32.to_le_bytes());
+ buf[16..20].copy_from_slice(&(MetricType::L2 as u32).to_le_bytes());
+ buf[20..28].copy_from_slice(&0i64.to_le_bytes());
+ buf[28..32].copy_from_slice(&2i32.to_le_bytes());
+ buf[32..36].copy_from_slice(&8i32.to_le_bytes());
+ buf[36..40].copy_from_slice(&3i32.to_le_bytes());
+ buf[40..44].copy_from_slice(&0u32.to_le_bytes());
+
+ let err = match IVFHNSWFlatIndexReader::open(Cursor::new(buf)) {
+ Ok(_) => panic!("missing required flags should be rejected"),
+ Err(err) => err,
+ };
+ assert!(err
+ .to_string()
+ .contains("requires delta-varint IDs and graph v1"));
+ }
+
+ #[test]
+ fn test_ivfhnswflat_reader_rejects_unknown_flags() {
+ let mut buf = vec![0u8; IVF_HNSW_FLAT_HEADER_SIZE];
+ buf[0..4].copy_from_slice(&IVF_HNSW_FLAT_MAGIC.to_le_bytes());
+ buf[4..8].copy_from_slice(&IVF_HNSW_FLAT_VERSION.to_le_bytes());
+ buf[8..12].copy_from_slice(&2i32.to_le_bytes());
+ buf[12..16].copy_from_slice(&1i32.to_le_bytes());
+ buf[16..20].copy_from_slice(&(MetricType::L2 as u32).to_le_bytes());
+ buf[20..28].copy_from_slice(&0i64.to_le_bytes());
+ buf[28..32].copy_from_slice(&2i32.to_le_bytes());
+ buf[32..36].copy_from_slice(&8i32.to_le_bytes());
+ buf[36..40].copy_from_slice(&3i32.to_le_bytes());
+ buf[40..44].copy_from_slice(&(REQUIRED_FLAGS | (1 << 31)).to_le_bytes());
+
+ let err = match IVFHNSWFlatIndexReader::open(Cursor::new(buf)) {
+ Ok(_) => panic!("unknown flags should be rejected"),
+ Err(err) => err,
+ };
+ assert!(err.to_string().contains("Unsupported IVF_HNSW_FLAT flags"));
+ }
+
#[test]
fn test_ivfhnswflat_write_read_search_roundtrip() {
let d = 4;
@@ -1273,6 +1529,55 @@ mod tests {
assert!(err.to_string().contains("missing HNSW graph"));
}
+ #[test]
+ fn test_ivfhnswflat_graph_delta_varint_reduces_graph_bytes() {
+ let d = 2;
+ let n = 128;
+ let data: Vec<f32> = (0..n).flat_map(|i| [i as f32, 0.0]).collect();
+ let params = HnswBuildParams {
+ m: 8,
+ ef_construction: 32,
+ max_level: 4,
+ };
+ let graph = HnswGraph::build(&data, n, d, MetricType::L2, params).unwrap();
+
+ let fixed = encode_graph_u32_for_size_estimate(&graph).unwrap();
+ let compressed = encode_graph(Some(&graph)).unwrap();
+
+ assert!(compressed.len() < fixed.len());
+ assert!(compressed.len() * 2 < fixed.len());
+ }
+
+ #[test]
+ #[ignore]
+ fn print_ivfhnswflat_graph_delta_varint_size_report() {
+ let d = 8;
+ for n in [128usize, 1_024, 4_096] {
+ let data: Vec<f32> = (0..n)
+ .flat_map(|i| {
+ (0..d).map(move |j| {
+ let bucket = (i % 64) as f32;
+ bucket * 0.01 + i as f32 * 0.0001 + j as f32 * 0.001
+ })
+ })
+ .collect();
+ let params = HnswBuildParams {
+ m: 16,
+ ef_construction: 64,
+ max_level: 5,
+ };
+ let graph = HnswGraph::build(&data, n, d, MetricType::L2, params).unwrap();
+ let fixed = encode_graph_u32_for_size_estimate(&graph).unwrap();
+ let compressed = encode_graph(Some(&graph)).unwrap();
+ println!(
+ "n={n}, fixed_u32={} bytes, delta_varint={} bytes, saved={:.1}%",
+ fixed.len(),
+ compressed.len(),
+ 100.0 - (compressed.len() as f64 * 100.0 / fixed.len() as f64)
+ );
+ }
+ }
+
#[test]
fn test_ivfhnswflat_decoder_rejects_level_above_hnsw_max_before_allocation() {
let params = HnswBuildParams {
@@ -1281,10 +1586,11 @@ mod tests {
max_level: 3,
};
let mut graph_bytes = Vec::new();
- append_u32(&mut graph_bytes, 1);
- append_u32(&mut graph_bytes, 0);
- append_u32(&mut graph_bytes, 0);
- append_u32(&mut graph_bytes, params.max_level as u32 + 1);
+ append_graph_header(&mut graph_bytes);
+ append_u32_varint(&mut graph_bytes, 1);
+ append_u32_varint(&mut graph_bytes, 0);
+ append_u32_varint(&mut graph_bytes, 0);
+ append_u32_varint(&mut graph_bytes, params.max_level as u32 + 1);
let err =
decode_graph(&graph_bytes, vec![0.0, 0.0], 1, 2, MetricType::L2, params).unwrap_err();
@@ -1301,11 +1607,12 @@ mod tests {
max_level: 3,
};
let mut graph_bytes = Vec::new();
- append_u32(&mut graph_bytes, 1);
- append_u32(&mut graph_bytes, 0);
- append_u32(&mut graph_bytes, 0);
- append_u32(&mut graph_bytes, 0);
- append_u32(&mut graph_bytes, (params.m * 2) as u32 + 1);
+ append_graph_header(&mut graph_bytes);
+ append_u32_varint(&mut graph_bytes, 1);
+ append_u32_varint(&mut graph_bytes, 0);
+ append_u32_varint(&mut graph_bytes, 0);
+ append_u32_varint(&mut graph_bytes, 0);
+ append_u32_varint(&mut graph_bytes, (params.m * 2) as u32 + 1);
let err =
decode_graph(&graph_bytes, vec![0.0, 0.0], 1, 2, MetricType::L2, params).unwrap_err();
@@ -1314,6 +1621,42 @@ mod tests {
assert!(err.to_string().contains("degree"));
}
+ #[test]
+ fn test_ivfhnswflat_decoder_rejects_truncated_graph_varint() {
+ let params = HnswBuildParams {
+ m: 2,
+ ef_construction: 8,
+ max_level: 3,
+ };
+ let mut graph_bytes = Vec::new();
+ append_graph_header(&mut graph_bytes);
+ graph_bytes.push(0x81);
+
+ let err =
+ decode_graph(&graph_bytes, vec![0.0, 0.0], 1, 2, MetricType::L2, params).unwrap_err();
+
+ assert_eq!(err.kind(), io::ErrorKind::InvalidData);
+ assert!(err.to_string().contains("truncated HNSW graph varint"));
+ }
+
+ #[test]
+ fn test_ivfhnswflat_decoder_rejects_oversized_graph_varint() {
+ let params = HnswBuildParams {
+ m: 2,
+ ef_construction: 8,
+ max_level: 3,
+ };
+ let mut graph_bytes = Vec::new();
+ append_graph_header(&mut graph_bytes);
+ graph_bytes.extend_from_slice(&[0xff; 10]);
+
+ let err =
+ decode_graph(&graph_bytes, vec![0.0, 0.0], 1, 2, MetricType::L2, params).unwrap_err();
+
+ assert_eq!(err.kind(), io::ErrorKind::InvalidData);
+ assert!(err.to_string().contains("varint exceeds u64 limit"));
+ }
+
#[test]
fn test_ivfhnswflat_writer_requires_built_graphs() {
let d = 2;
@@ -1364,7 +1707,17 @@ mod tests {
}
}
- fn append_u32(buf: &mut Vec<u8>, value: u32) {
- buf.extend_from_slice(&value.to_le_bytes());
+ fn append_u32_varint(buf: &mut Vec<u8>, mut value: u32) {
+ while value >= 0x80 {
+ buf.push((value as u8 & 0x7f) | 0x80);
+ value >>= 7;
+ }
+ buf.push(value as u8);
+ }
+
+ fn append_graph_header(buf: &mut Vec<u8>) {
+ buf.extend_from_slice(&0x48574752u32.to_le_bytes());
+ buf.extend_from_slice(&1u32.to_le_bytes());
+ buf.extend_from_slice(&1u32.to_le_bytes());
}
}
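Note on the writer change: `build_sorted_flat_graph_list` and `build_sorted_sq_graph_list` now reorder each list by external ID so that the IDs can be delta-encoded. The HNSW graph therefore has to be renumbered with the same permutation. The sketch below isolates that step. It mirrors `old_to_new_order` and the neighbor rewrite in `reorder_graph`, but it uses plain nested `Vec`s (`neighbors[node][level]`) in place of the crate's `HnswGraph`. `sort_and_remap` is a hypothetical helper and is not part of this patch.

```rust
/// Inverse of a permutation: `order[new] = old` becomes `old_to_new[old] = new`.
fn old_to_new_order(order: &[usize]) -> Vec<usize> {
    let mut old_to_new = vec![0; order.len()];
    for (new_idx, &old_idx) in order.iter().enumerate() {
        old_to_new[old_idx] = new_idx;
    }
    old_to_new
}

/// Hypothetical helper: sort nodes by external ID, then rewrite every adjacency
/// list (indexed as neighbors[node][level]) and the entry point into new local ids.
fn sort_and_remap(
    ids: &[i64],
    neighbors: &[Vec<Vec<usize>>],
    entry_point: usize,
) -> (Vec<i64>, Vec<Vec<Vec<usize>>>, usize) {
    let mut order: Vec<usize> = (0..ids.len()).collect();
    order.sort_by_key(|&idx| ids[idx]);
    let old_to_new = old_to_new_order(&order);

    // Per-node data is gathered through `order`...
    let sorted_ids: Vec<i64> = order.iter().map(|&old| ids[old]).collect();
    // ...while neighbor ids are translated through the inverse map.
    let remapped: Vec<Vec<Vec<usize>>> = order
        .iter()
        .map(|&old| {
            neighbors[old]
                .iter()
                .map(|level| level.iter().map(|&n| old_to_new[n]).collect())
                .collect()
        })
        .collect();
    (sorted_ids, remapped, old_to_new[entry_point])
}

fn main() {
    // Three nodes inserted out of ID order; node 0 links to nodes 1 and 2 on level 0.
    let ids = [30, 10, 20];
    let neighbors = vec![vec![vec![1, 2]], vec![vec![0]], vec![vec![0]]];
    let (sorted, remapped, entry) = sort_and_remap(&ids, &neighbors, 0);
    // Prints: [10, 20, 30] [[[2]], [[2]], [[0, 1]]] 2
    println!("{:?} {:?} {}", sorted, remapped, entry);
}
```

Per-node data (IDs, levels, vectors, codes) is gathered through `order`: new slot `i` takes old node `order[i]`. Neighbor ids and the entry point go through the inverse map `old_to_new`. If the two directions are swapped, the result is still a structurally valid graph, but it is the wrong one. A small hand-checked case like `main` catches that mistake at little cost.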
diff --git a/core/src/ivfhnswsq_io.rs b/core/src/ivfhnswsq_io.rs
index c1c94b6..c6c57e2 100644
--- a/core/src/ivfhnswsq_io.rs
+++ b/core/src/ivfhnswsq_io.rs
@@ -19,10 +19,11 @@ use crate::distance::{preprocess_vectors, MetricType};
use crate::hnsw::{HnswBuildParams, HnswGraph};
use crate::hnsw_search::{search_hnsw_lists, HnswSearchList};
use crate::index_io_util::{
- checked_list_bytes, checked_list_offset, checked_section_size, decode_graph,
- decode_roaring_filter, encode_graph, read_f32_vec, read_i32_le, read_i64_le, read_u32_le,
- u64_to_i64, usize_to_i32, usize_to_i64, validate_positive_i32, validate_search_inputs,
- write_f32_slice, write_i32_le, write_i64_le, write_u32_le,
+ checked_list_bytes, checked_list_offset, checked_section_size, decode_delta_varint_ids,
+ decode_graph, decode_roaring_filter, encode_delta_varint_ids, encode_graph, read_f32_vec,
+ read_i32_le, read_i64_le, read_u32_le, u64_to_i64, usize_to_i32, usize_to_i64,
+ validate_positive_i32, validate_search_inputs, write_f32_slice, write_i32_le, write_i64_le,
+ write_u32_le,
};
use crate::io::{PreadCursor, ReadRequest, SeekRead, SeekWrite};
use crate::ivfhnswsq::IVFHNSWSQIndex;
@@ -35,6 +36,10 @@ use std::io;
pub const IVF_HNSW_SQ_MAGIC: u32 = 0x49485351; // "IHSQ"
pub const IVF_HNSW_SQ_VERSION: u32 = 1;
pub const IVF_HNSW_SQ_HEADER_SIZE: usize = 64;
+const FLAG_DELTA_IDS: u32 = 1 << 0;
+const FLAG_GRAPH_V1: u32 = 1 << 1;
+const REQUIRED_FLAGS: u32 = FLAG_DELTA_IDS | FLAG_GRAPH_V1;
+const SUPPORTED_FLAGS: u32 = REQUIRED_FLAGS;
const MAX_COALESCED_READ_GAP_BYTES: u64 = 1 << 20;
pub fn write_ivfhnswsq_index(index: &IVFHNSWSQIndex, out: &mut dyn SeekWrite) -> io::Result<()> {
@@ -48,14 +53,8 @@ pub fn write_ivfhnswsq_index(index: &IVFHNSWSQIndex, out: &mut dyn SeekWrite) ->
)
})
})?;
- let graph_bytes: Vec<Vec<u8>> = (0..index.nlist)
- .map(|list_id| {
- if index.ids[list_id].is_empty() {
- Ok(Vec::new())
- } else {
- encode_graph(index.graphs[list_id].as_ref())
- }
- })
+ let sorted_lists: Vec<SortedSqGraphList> = (0..index.nlist)
+ .map(|list_id| build_sorted_sq_graph_list(index, list_id))
.collect::<io::Result<_>>()?;
write_u32_le(out, IVF_HNSW_SQ_MAGIC)?;
@@ -74,7 +73,8 @@ pub fn write_ivfhnswsq_index(index: &IVFHNSWSQIndex, out: &mut dyn SeekWrite) ->
let (sq_min, sq_max) = sq_global_bounds(&index.sq.mins, &index.sq.maxs);
out.write_all(&sq_min.to_le_bytes())?;
out.write_all(&sq_max.to_le_bytes())?;
- out.write_all(&[0u8; 16])?;
+ write_u32_le(out, REQUIRED_FLAGS)?;
+ out.write_all(&[0u8; 12])?;
write_f32_slice(out, &index.sq.mins)?;
write_f32_slice(out, &index.sq.maxs)?;
@@ -102,16 +102,23 @@ pub fn write_ivfhnswsq_index(index: &IVFHNSWSQIndex, out: &mut dyn SeekWrite) ->
let mut list_offsets = vec![0i64; index.nlist];
let mut list_counts = vec![0i32; index.nlist];
let mut list_graph_bytes_lens = vec![0i32; index.nlist];
+ let mut list_payload_bytes_lens = vec![0i64; index.nlist];
let mut current_offset = data_start;
for list_id in 0..index.nlist {
list_offsets[list_id] = u64_to_i64(current_offset, "list offset")?;
- let count = index.ids[list_id].len();
+ let count = sorted_lists[list_id].ids.len();
list_counts[list_id] = usize_to_i32(count, "list count")?;
- list_graph_bytes_lens[list_id] = usize_to_i32(graph_bytes[list_id].len(), "graph bytes")?;
+ list_graph_bytes_lens[list_id] =
+ usize_to_i32(sorted_lists[list_id].graph_bytes.len(), "graph bytes")?;
if count > 0 {
- let payload_len =
- list_payload_len(count, index.sq.code_size(), graph_bytes[list_id].len())?;
+ let payload_len = list_payload_len(
+ count,
+ index.sq.code_size(),
+ sorted_lists[list_id].id_bytes.len(),
+ sorted_lists[list_id].graph_bytes.len(),
+ )?;
+ list_payload_bytes_lens[list_id] = usize_to_i64(payload_len, "list payload bytes")?;
current_offset = current_offset
.checked_add(payload_len as u64)
.ok_or_else(|| {
@@ -124,18 +131,19 @@ pub fn write_ivfhnswsq_index(index: &IVFHNSWSQIndex, out: &mut dyn SeekWrite) ->
write_i64_le(out, list_offsets[list_id])?;
write_i32_le(out, list_counts[list_id])?;
write_i32_le(out, list_graph_bytes_lens[list_id])?;
- write_i64_le(out, 0)?;
+ write_i64_le(out, list_payload_bytes_lens[list_id])?;
}
for list_id in 0..index.nlist {
- if index.ids[list_id].is_empty() {
+ let list = &sorted_lists[list_id];
+ if list.ids.is_empty() {
continue;
}
- for &id in &index.ids[list_id] {
- write_i64_le(out, id)?;
- }
- out.write_all(&index.codes[list_id])?;
- out.write_all(&graph_bytes[list_id])?;
+ write_i64_le(out, list.ids[0])?;
+ write_i32_le(out, usize_to_i32(list.id_bytes.len(), "delta ID section")?)?;
+ out.write_all(&list.id_bytes)?;
+ out.write_all(&list.codes)?;
+ out.write_all(&list.graph_bytes)?;
}
Ok(())
@@ -154,6 +162,7 @@ pub struct IVFHNSWSQIndexReader<R: SeekRead> {
pub list_offsets: Vec<i64>,
pub list_counts: Vec<i32>,
pub list_graph_bytes_lens: Vec<i32>,
+ pub list_payload_bytes_lens: Vec<i64>,
loaded: bool,
}
@@ -195,14 +204,37 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
max_level: validate_positive_i32(read_i32_le(&mut cursor)?, "hnsw max_level")? as usize,
}
.sanitized();
- let mut bounds_summary = [0u8; 8];
- cursor.read_exact(&mut bounds_summary)?;
- let mut reserved = [0u8; 16];
+ let sq_min_summary = read_f32_le(&mut cursor)?;
+ let sq_max_summary = read_f32_le(&mut cursor)?;
+ let flags = read_u32_le(&mut cursor)?;
+ let mut reserved = [0u8; 12];
cursor.read_exact(&mut reserved)?;
+ let unknown_flags = flags & !SUPPORTED_FLAGS;
+ if unknown_flags != 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("Unsupported IVF_HNSW_SQ flags: 0x{:08X}", unknown_flags),
+ ));
+ }
+ if flags & REQUIRED_FLAGS != REQUIRED_FLAGS {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ "IVF_HNSW_SQ v1 requires delta-varint IDs and graph v1",
+ ));
+ }
let mins = read_f32_vec(&mut cursor, d)?;
let maxs = read_f32_vec(&mut cursor, d)?;
validate_sq_bounds(d, &mins, &maxs)?;
+ let (sq_min, sq_max) = sq_global_bounds(&mins, &maxs);
+ if sq_min.to_bits() != sq_min_summary.to_bits()
+ || sq_max.to_bits() != sq_max_summary.to_bits()
+ {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ "SQ bounds summary does not match global SQ bounds",
+ ));
+ }
let sq = ScalarQuantizer::with_dimension_bounds(d, mins, maxs);
let mut list_sqs = Vec::with_capacity(nlist);
for _ in 0..nlist {
@@ -225,6 +257,7 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
list_offsets: Vec::new(),
list_counts: Vec::new(),
list_graph_bytes_lens: Vec::new(),
+ list_payload_bytes_lens: Vec::new(),
loaded: false,
})
}
@@ -242,6 +275,7 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
self.list_offsets = vec![0; self.nlist];
self.list_counts = vec![0; self.nlist];
self.list_graph_bytes_lens = vec![0; self.nlist];
+ self.list_payload_bytes_lens = vec![0; self.nlist];
for list_id in 0..self.nlist {
self.list_offsets[list_id] = read_i64_le(&mut cursor)?;
let count = read_i32_le(&mut cursor)?;
@@ -263,7 +297,17 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
));
}
self.list_graph_bytes_lens[list_id] = graph_bytes_len;
- let _reserved = read_i64_le(&mut cursor)?;
+ let payload_bytes_len = read_i64_le(&mut cursor)?;
+ if payload_bytes_len < 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!(
+ "negative payload_bytes_len {} at list {}",
+ payload_bytes_len, list_id
+ ),
+ ));
+ }
+ self.list_payload_bytes_lens[list_id] = payload_bytes_len;
}
self.loaded = true;
@@ -427,7 +471,28 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
format!("list {} is missing HNSW graph", list_id),
));
}
- let payload_len = list_payload_len(count, self.sq.code_size(), graph_bytes_len)?;
+ let payload_len = self.list_payload_bytes_lens[list_id] as usize;
+ if payload_len == 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("list {} is missing payload length", list_id),
+ ));
+ }
+ let minimum_payload_len = 12usize.checked_add(graph_bytes_len).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidData,
+ "IVF-HNSW-SQ minimum payload length overflow",
+ )
+ })?;
+ if payload_len < minimum_payload_len {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!(
+ "list {} payload length {} is shorter than expected {}",
+ list_id, payload_len, minimum_payload_len
+ ),
+ ));
+ }
Ok(Some(ListPayloadMeta {
list_id,
offset,
@@ -452,14 +517,47 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
),
));
}
- let ids_bytes_len = checked_list_bytes(meta.count, 8)?;
+ let base_header_len = 12usize;
+ if payload.len() < base_header_len {
+ return Err(io::Error::new(
+ io::ErrorKind::UnexpectedEof,
+ format!("list {} has truncated ID header", meta.list_id),
+ ));
+ }
+ let base_id = i64::from_le_bytes(payload[0..8].try_into().unwrap());
+ let id_bytes_len = i32::from_le_bytes(payload[8..12].try_into().unwrap());
+ if id_bytes_len < 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!(
+ "negative id_bytes_len {} at list {}",
+ id_bytes_len, meta.list_id
+ ),
+ ));
+ }
+ let id_bytes_len = id_bytes_len as usize;
+ let ids_end = base_header_len.checked_add(id_bytes_len).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidData,
+ "IVF-HNSW-SQ ID payload length overflow",
+ )
+ })?;
let code_size = self.sq.code_size();
let codes_bytes_len = checked_list_bytes(meta.count, code_size)?;
- let ids = payload[..ids_bytes_len]
- .chunks_exact(8)
- .map(|c| i64::from_le_bytes(c.try_into().unwrap()))
- .collect();
- let codes = payload[ids_bytes_len..ids_bytes_len + codes_bytes_len].to_vec();
+ let codes_end = ids_end.checked_add(codes_bytes_len).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidData,
+ "IVF-HNSW-SQ codes payload length overflow",
+ )
+ })?;
+ if codes_end > payload.len() {
+ return Err(io::Error::new(
+ io::ErrorKind::UnexpectedEof,
+ format!("list {} has truncated SQ codes payload", meta.list_id),
+ ));
+ }
+ let ids = decode_delta_varint_ids(base_id, &payload[base_header_len..ids_end], meta.count)?;
+ let codes = payload[ids_end..codes_end].to_vec();
let mut vectors = vec![0.0f32; meta.count * self.d];
self.list_sq(meta.list_id)
.decode_batch(&codes, meta.count, &mut vectors);
@@ -470,7 +568,7 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
}
}
let graph = decode_graph(
- &payload[ids_bytes_len + codes_bytes_len..],
+ &payload[codes_end..],
vectors,
meta.count,
self.d,
@@ -928,6 +1026,12 @@ fn validate_sq_bounds(d: usize, mins: &[f32], maxs: &[f32]) -> io::Result<()> {
Ok(())
}
+fn read_f32_le<R: SeekRead + ?Sized>(reader: &mut PreadCursor<'_, R>) -> io::Result<f32> {
+ let mut buf = [0u8; 4];
+ reader.read_exact(&mut buf)?;
+ Ok(f32::from_le_bytes(buf))
+}
+
fn sq_global_bounds(mins: &[f32], maxs: &[f32]) -> (f32, f32) {
let min = mins.iter().copied().fold(f32::INFINITY, f32::min);
let max = maxs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
@@ -938,8 +1042,136 @@ fn sq_global_bounds(mins: &[f32], maxs: &[f32]) -> (f32, f32) {
}
}
-fn list_payload_len(count: usize, code_size: usize, graph_bytes_len: usize) -> io::Result<usize> {
- let id_bytes = checked_list_bytes(count, 8)?;
+struct SortedSqGraphList {
+ ids: Vec<i64>,
+ id_bytes: Vec<u8>,
+ codes: Vec<u8>,
+ graph_bytes: Vec<u8>,
+}
+
+fn build_sorted_sq_graph_list(
+ index: &IVFHNSWSQIndex,
+ list_id: usize,
+) -> io::Result<SortedSqGraphList> {
+ let count = index.ids[list_id].len();
+ if count == 0 {
+ return Ok(SortedSqGraphList {
+ ids: Vec::new(),
+ id_bytes: Vec::new(),
+ codes: Vec::new(),
+ graph_bytes: Vec::new(),
+ });
+ }
+
+ let code_size = index.sq.code_size();
+ let mut order: Vec<usize> = (0..count).collect();
+ order.sort_by_key(|&idx| index.ids[list_id][idx]);
+
+ let ids: Vec<i64> = order.iter().map(|&idx| index.ids[list_id][idx]).collect();
+ let (_, id_bytes) = encode_delta_varint_ids(&ids);
+
+ let mut codes = vec![0u8; checked_list_bytes(count, code_size)?];
+ for (new_idx, &old_idx) in order.iter().enumerate() {
+ codes[new_idx * code_size..(new_idx + 1) * code_size]
+ .copy_from_slice(&index.codes[list_id][old_idx * code_size..(old_idx + 1) * code_size]);
+ }
+
+ let mut vectors = vec![0.0f32; count * index.d];
+ index
+ .list_sq(list_id)
+ .decode_batch(&codes, count, &mut vectors);
+ let centroid = &index.quantizer_centroids[list_id * index.d..(list_id + 1) * index.d];
+ for vector in vectors.chunks_exact_mut(index.d) {
+ for i in 0..index.d {
+ vector[i] += centroid[i];
+ }
+ }
+ let old_to_new = old_to_new_order(&order);
+ let source_graph = index.graphs[list_id].as_ref().ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!("list {} is missing HNSW graph", list_id),
+ )
+ })?;
+ let graph = reorder_graph(
+ source_graph,
+ &order,
+ &old_to_new,
+ vectors,
+ index.d,
+ index.metric,
+ index.hnsw_params,
+ )?;
+ let graph_bytes = encode_graph(Some(&graph))?;
+
+ Ok(SortedSqGraphList {
+ ids,
+ id_bytes,
+ codes,
+ graph_bytes,
+ })
+}
+
+fn old_to_new_order(order: &[usize]) -> Vec<usize> {
+ let mut old_to_new = vec![0; order.len()];
+ for (new_idx, &old_idx) in order.iter().enumerate() {
+ old_to_new[old_idx] = new_idx;
+ }
+ old_to_new
+}
+
+fn reorder_graph(
+ graph: &HnswGraph,
+ order: &[usize],
+ old_to_new: &[usize],
+ vectors: Vec<f32>,
+ d: usize,
+ metric: MetricType,
+ hnsw_params: HnswBuildParams,
+) -> io::Result<HnswGraph> {
+ let levels: Vec<usize> = order
+ .iter()
+ .map(|&old_idx| graph.levels()[old_idx])
+ .collect();
+ let neighbors: Vec<Vec<Vec<usize>>> = order
+ .iter()
+ .map(|&old_idx| {
+ graph.neighbors()[old_idx]
+ .iter()
+ .map(|level_neighbors| {
+ level_neighbors
+ .iter()
+ .map(|&neighbor| old_to_new[neighbor])
+ .collect()
+ })
+ .collect()
+ })
+ .collect();
+ HnswGraph::from_parts(
+ vectors,
+ order.len(),
+ d,
+ metric,
+ levels,
+ neighbors,
+ old_to_new[graph.entry_point()],
+ graph.max_observed_level(),
+ hnsw_params,
+ )
+}
+
+fn list_payload_len(
+ count: usize,
+ code_size: usize,
+ id_bytes_len: usize,
+ graph_bytes_len: usize,
+) -> io::Result<usize> {
+ let id_bytes = 12usize.checked_add(id_bytes_len).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "IVF-HNSW-SQ ID payload length overflow",
+ )
+ })?;
let code_bytes = checked_list_bytes(count, code_size)?;
id_bytes
.checked_add(code_bytes)
@@ -963,6 +1195,54 @@ mod tests {
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
+ #[test]
+ fn test_ivfhnswsq_reader_rejects_missing_required_flags() {
+ let mut buf = vec![0u8; IVF_HNSW_SQ_HEADER_SIZE + 16];
+ buf[0..4].copy_from_slice(&IVF_HNSW_SQ_MAGIC.to_le_bytes());
+ buf[4..8].copy_from_slice(&IVF_HNSW_SQ_VERSION.to_le_bytes());
+ buf[8..12].copy_from_slice(&2i32.to_le_bytes());
+ buf[12..16].copy_from_slice(&1i32.to_le_bytes());
+ buf[16..20].copy_from_slice(&(MetricType::L2 as u32).to_le_bytes());
+ buf[20..28].copy_from_slice(&0i64.to_le_bytes());
+ buf[28..32].copy_from_slice(&2i32.to_le_bytes());
+ buf[32..36].copy_from_slice(&8i32.to_le_bytes());
+ buf[36..40].copy_from_slice(&3i32.to_le_bytes());
+ buf[40..44].copy_from_slice(&0.0f32.to_le_bytes());
+ buf[44..48].copy_from_slice(&0.0f32.to_le_bytes());
+ buf[48..52].copy_from_slice(&0u32.to_le_bytes());
+
+ let err = match IVFHNSWSQIndexReader::open(Cursor::new(buf)) {
+ Ok(_) => panic!("missing required flags should be rejected"),
+ Err(err) => err,
+ };
+ assert!(err
+ .to_string()
+ .contains("requires delta-varint IDs and graph v1"));
+ }
+
+ #[test]
+ fn test_ivfhnswsq_reader_rejects_unknown_flags() {
+ let mut buf = vec![0u8; IVF_HNSW_SQ_HEADER_SIZE + 16];
+ buf[0..4].copy_from_slice(&IVF_HNSW_SQ_MAGIC.to_le_bytes());
+ buf[4..8].copy_from_slice(&IVF_HNSW_SQ_VERSION.to_le_bytes());
+ buf[8..12].copy_from_slice(&2i32.to_le_bytes());
+ buf[12..16].copy_from_slice(&1i32.to_le_bytes());
+ buf[16..20].copy_from_slice(&(MetricType::L2 as u32).to_le_bytes());
+ buf[20..28].copy_from_slice(&0i64.to_le_bytes());
+ buf[28..32].copy_from_slice(&2i32.to_le_bytes());
+ buf[32..36].copy_from_slice(&8i32.to_le_bytes());
+ buf[36..40].copy_from_slice(&3i32.to_le_bytes());
+ buf[40..44].copy_from_slice(&0.0f32.to_le_bytes());
+ buf[44..48].copy_from_slice(&0.0f32.to_le_bytes());
+ buf[48..52].copy_from_slice(&(REQUIRED_FLAGS | (1 << 31)).to_le_bytes());
+
+ let err = match IVFHNSWSQIndexReader::open(Cursor::new(buf)) {
+ Ok(_) => panic!("unknown flags should be rejected"),
+ Err(err) => err,
+ };
+ assert!(err.to_string().contains("Unsupported IVF_HNSW_SQ flags"));
+ }
+
#[test]
fn test_ivfhnswsq_write_read_search_roundtrip() {
let d = 4;
@@ -1029,6 +1309,29 @@ mod tests {
assert_eq!(reader.list_sqs[0].maxs, index.list_sqs[0].maxs);
}
+ #[test]
+ fn test_ivfhnswsq_reader_rejects_mismatched_sq_bounds_summary() {
+ let d = 2;
+ let nlist = 1;
+ let data = vec![0.0, -100.0, 1.0, 100.0];
+ let ids = vec![10, 11];
+ let mut index = IVFHNSWSQIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+ index.train(&data, 2);
+ index.add(&data, &ids, 2);
+ index.build_graphs().unwrap();
+
+ let mut buf = Vec::new();
+ write_ivfhnswsq_index(&index, &mut PosWriter::new(&mut buf)).unwrap();
+ buf[40..44].copy_from_slice(&123.0f32.to_le_bytes());
+
+ let err = match IVFHNSWSQIndexReader::open(Cursor::new(buf)) {
+ Ok(_) => panic!("mismatched SQ bounds summary should be rejected"),
+ Err(err) => err,
+ };
+ assert_eq!(err.kind(), io::ErrorKind::InvalidData);
+ assert!(err.to_string().contains("SQ bounds summary"));
+ }
+
#[test]
fn test_ivfhnswsq_reader_search_with_roaring_filter() {
let d = 2;
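Both readers now check the header flags word before they read any list data. Unknown bits are rejected first, and then both required bits (`FLAG_DELTA_IDS`, `FLAG_GRAPH_V1`) must be present. The sketch below reproduces that logic on its own, using the constants and error messages from `ivfhnswsq_io.rs` in this commit. `validate_flags` is a hypothetical wrapper; the reader performs these checks inline while it parses the header.

```rust
use std::io;

// Flag bits as defined in ivfhnswsq_io.rs in this commit.
const FLAG_DELTA_IDS: u32 = 1 << 0;
const FLAG_GRAPH_V1: u32 = 1 << 1;
const REQUIRED_FLAGS: u32 = FLAG_DELTA_IDS | FLAG_GRAPH_V1;
const SUPPORTED_FLAGS: u32 = REQUIRED_FLAGS;

/// Hypothetical standalone version of the reader's flag check:
/// unknown bits are rejected first, then every required bit must be set.
fn validate_flags(flags: u32) -> io::Result<()> {
    let unknown = flags & !SUPPORTED_FLAGS;
    if unknown != 0 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("Unsupported IVF_HNSW_SQ flags: 0x{:08X}", unknown),
        ));
    }
    if flags & REQUIRED_FLAGS != REQUIRED_FLAGS {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "IVF_HNSW_SQ v1 requires delta-varint IDs and graph v1",
        ));
    }
    Ok(())
}

fn main() {
    // Prints:
    // 0x00000003: ok
    // 0x00000000: IVF_HNSW_SQ v1 requires delta-varint IDs and graph v1
    // 0x80000003: Unsupported IVF_HNSW_SQ flags: 0x80000000
    for flags in [REQUIRED_FLAGS, 0, REQUIRED_FLAGS | (1 << 31)] {
        match validate_flags(flags) {
            Ok(()) => println!("0x{:08X}: ok", flags),
            Err(e) => println!("0x{:08X}: {}", flags, e),
        }
    }
}
```

A few lines later, the SQ reader recomputes the global bounds from `mins`/`maxs` and compares them with the stored summary using `f32::to_bits`. That comparison is by exact bit pattern, so `0.0` and `-0.0` count as different values, and a NaN matches only an identical NaN bit pattern.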
diff --git a/core/tests/fixtures/ivf_hnsw_flat_v1.hex b/core/tests/fixtures/ivf_hnsw_flat_v1.hex
index f00b7e3..d88c2d7 100644
--- a/core/tests/fixtures/ivf_hnsw_flat_v1.hex
+++ b/core/tests/fixtures/ivf_hnsw_flat_v1.hex
@@ -1,7 +1,7 @@
4c 46 48 49 01 00 00 00 02 00 00 00 02 00 00 00 00 00 00 00 02 00 00 00 00 00 00 00 02 00 00 00
-08 00 00 00 03 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00
-00 00 00 00 00 00 00 00 00 00 20 41 00 00 20 41 80 00 00 00 00 00 00 00 01 00 00 00 14 00 00 00
-00 00 00 00 00 00 00 00 a4 00 00 00 00 00 00 00 01 00 00 00 14 00 00 00 00 00 00 00 00 00 00 00
-07 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 01 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00
-00 00 00 00 63 00 00 00 00 00 00 00 00 00 20 41 00 00 20 41 01 00 00 00 00 00 00 00 00 00 00 00
-00 00 00 00 00 00 00 00
+08 00 00 00 03 00 00 00 03 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00
+00 00 00 00 00 00 00 00 00 00 20 41 00 00 20 41 80 00 00 00 00 00 00 00 01 00 00 00 11 00 00 00
+26 00 00 00 00 00 00 00 a6 00 00 00 00 00 00 00 01 00 00 00 11 00 00 00 26 00 00 00 00 00 00 00
+07 00 00 00 00 00 00 00 01 00 00 00 00 00 00 00 00 00 00 00 00 52 47 57 48 01 00 00 00 01 00 00
+00 01 00 00 00 00 63 00 00 00 00 00 00 00 01 00 00 00 00 00 00 20 41 00 00 20 41 52 47 57 48 01
+00 00 00 01 00 00 00 01 00 00 00 00
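As a hand cross-check of the new list layout, the sketch below takes list 1 of the regenerated flat fixture (the 0x26 bytes at offset 0xa6, copied from the hex above) and splits it into the base ID, the delta-ID section, the raw vectors and the graph section. It uses the same checked-offset style as the reader and then validates the 12-byte graph header documented in STORAGE_FORMAT.md. `split_flat_payload` and `check_graph_header` are hypothetical helpers, not the crate's API; the real reader passes these slices to `decode_delta_varint_ids` and `decode_graph`.

```rust
/// "HWGR" graph-section magic, as documented in STORAGE_FORMAT.md.
const GRAPH_MAGIC: u32 = 0x4857_4752;

/// List 1 payload copied from ivf_hnsw_flat_v1.hex (offset 0xa6, 0x26 bytes).
const LIST1_PAYLOAD: [u8; 38] = [
    0x63, 0, 0, 0, 0, 0, 0, 0, // base_id = 99 (i64 LE)
    0x01, 0, 0, 0, // id_bytes_len = 1 (i32 LE)
    0x00, // delta-varint ID section: a single zero delta
    0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x20, 0x41, // vector [10.0, 10.0] (f32 LE)
    0x52, 0x47, 0x57, 0x48, 1, 0, 0, 0, 1, 0, 0, 0, // graph magic, version 1, flags 1
    0x01, 0x00, 0x00, 0x00, 0x00, // varints: graph_count, entry_point, max level, level[0], degree
];

#[derive(Debug)]
struct FlatPayload<'a> {
    base_id: i64,
    id_bytes: &'a [u8],
    vectors: Vec<f32>,
    graph_bytes: &'a [u8],
}

/// Hypothetical splitter that mirrors the reader's checked offsets.
fn split_flat_payload(payload: &[u8], count: usize, d: usize) -> Result<FlatPayload<'_>, String> {
    const BASE_HEADER_LEN: usize = 12; // i64 base_id + i32 id_bytes_len
    if payload.len() < BASE_HEADER_LEN {
        return Err("truncated ID header".into());
    }
    let base_id = i64::from_le_bytes(payload[0..8].try_into().unwrap());
    let id_bytes_len = i32::from_le_bytes(payload[8..12].try_into().unwrap());
    let id_bytes_len =
        usize::try_from(id_bytes_len).map_err(|_| "negative id_bytes_len".to_string())?;
    let ids_end = BASE_HEADER_LEN.checked_add(id_bytes_len).ok_or("ID length overflow")?;
    let vector_bytes = count
        .checked_mul(d)
        .and_then(|n| n.checked_mul(4))
        .ok_or("vector length overflow")?;
    let vectors_end = ids_end.checked_add(vector_bytes).ok_or("vector length overflow")?;
    if vectors_end > payload.len() {
        return Err("truncated vector payload".into());
    }
    let vectors = payload[ids_end..vectors_end]
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes(c.try_into().unwrap()))
        .collect();
    Ok(FlatPayload {
        base_id,
        id_bytes: &payload[BASE_HEADER_LEN..ids_end],
        vectors,
        graph_bytes: &payload[vectors_end..],
    })
}

/// Checks the fixed 12-byte graph header: magic, version 1, and flag bit 0.
fn check_graph_header(graph: &[u8]) -> Result<(), String> {
    if graph.len() < 12 {
        return Err("truncated graph header".into());
    }
    let word = |i: usize| u32::from_le_bytes(graph[i..i + 4].try_into().unwrap());
    if word(0) != GRAPH_MAGIC {
        return Err(format!("bad graph magic 0x{:08X}", word(0)));
    }
    if word(4) != 1 {
        return Err(format!("unsupported graph version {}", word(4)));
    }
    if word(8) & 1 == 0 {
        return Err("delta-varint adjacency flag (bit 0) not set".into());
    }
    Ok(())
}

fn main() {
    let p = split_flat_payload(&LIST1_PAYLOAD, 1, 2).unwrap();
    check_graph_header(p.graph_bytes).unwrap();
    println!("{:?}", p);
}
```

The arithmetic agrees with the list-table entry for this list: 12 header bytes + 1 delta byte + 8 vector bytes + 17 graph bytes (`graph_bytes_len` 0x11) = 38 (`payload_bytes_len` 0x26).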
diff --git a/core/tests/fixtures/ivf_hnsw_sq_v1.hex b/core/tests/fixtures/ivf_hnsw_sq_v1.hex
index 5c90955..1715feb 100644
--- a/core/tests/fixtures/ivf_hnsw_sq_v1.hex
+++ b/core/tests/fixtures/ivf_hnsw_sq_v1.hex
@@ -1,8 +1,8 @@
51 53 48 49 01 00 00 00 02 00 00 00 02 00 00 00 00 00 00 00 02 00 00 00 00 00 00 00 02 00 00 00
-08 00 00 00 03 00 00 00 00 00 00 00 00 00 80 3f 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00
+08 00 00 00 03 00 00 00 00 00 00 00 00 00 80 3f 03 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00 00 00 80 3f 00 00 80 3f 00 00 00 00 00 00 00 00 00 00 80 3f 00 00 80 3f
00 00 00 00 00 00 00 00 00 00 80 3f 00 00 80 3f 00 00 00 00 00 00 00 00 00 00 20 41 00 00 20 41
-b0 00 00 00 00 00 00 00 01 00 00 00 14 00 00 00 00 00 00 00 00 00 00 00 ce 00 00 00 00 00 00 00
-01 00 00 00 14 00 00 00 00 00 00 00 00 00 00 00 07 00 00 00 00 00 00 00 00 00 01 00 00 00 00 00
-00 00 00 00 00 00 00 00 00 00 00 00 00 00 63 00 00 00 00 00 00 00 00 00 01 00 00 00 00 00 00 00
-00 00 00 00 00 00 00 00 00 00 00 00
+b0 00 00 00 00 00 00 00 01 00 00 00 11 00 00 00 20 00 00 00 00 00 00 00 d0 00 00 00 00 00 00 00
+01 00 00 00 11 00 00 00 20 00 00 00 00 00 00 00 07 00 00 00 00 00 00 00 01 00 00 00 00 00 00 52
+47 57 48 01 00 00 00 01 00 00 00 01 00 00 00 00 63 00 00 00 00 00 00 00 01 00 00 00 00 00 00 52
+47 57 48 01 00 00 00 01 00 00 00 01 00 00 00 00
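Each regenerated fixture holds one vector per list. As a result, every ID section is just the base ID (7 and 99) followed by `id_bytes_len = 1` and a single `00` delta byte. For larger lists, the sketch below shows how sorted IDs might round-trip through a base ID plus unsigned LEB128 deltas. All four functions are hypothetical stand-ins. The crate's `encode_delta_varint_ids` and `decode_delta_varint_ids` are not shown in this patch, so their exact signatures and error messages may differ. The varint reader rejects the same two malformed inputs that the new graph-decoder tests use: a lone `0x81` byte and ten `0xff` bytes.

```rust
/// Append `value` as an unsigned LEB128 varint (7 bits per byte, high bit = "more").
fn write_varint(buf: &mut Vec<u8>, mut value: u64) {
    while value >= 0x80 {
        buf.push((value as u8 & 0x7f) | 0x80);
        value >>= 7;
    }
    buf.push(value as u8);
}

/// Read one varint at `*pos`, rejecting truncated input and encodings that overflow u64.
fn read_varint(bytes: &[u8], pos: &mut usize) -> Result<u64, String> {
    let mut result = 0u64;
    for i in 0..10u32 {
        let byte = *bytes.get(*pos).ok_or("truncated varint")?;
        *pos += 1;
        let low = u64::from(byte & 0x7f);
        // The 10th byte may only contribute bit 63.
        if i == 9 && low > 1 {
            return Err("varint exceeds u64 limit".into());
        }
        result |= low << (7 * i);
        if byte & 0x80 == 0 {
            return Ok(result);
        }
    }
    Err("varint exceeds u64 limit".into())
}

/// Sorted IDs -> (base ID, delta bytes). The first delta is 0, as in the fixtures.
fn encode_delta_ids(ids: &[i64]) -> (i64, Vec<u8>) {
    let base = ids.first().copied().unwrap_or(0);
    let mut prev = base;
    let mut out = Vec::new();
    for &id in ids {
        // For non-decreasing input, the wrapping difference read as u64 is the exact gap.
        write_varint(&mut out, id.wrapping_sub(prev) as u64);
        prev = id;
    }
    (base, out)
}

fn decode_delta_ids(base: i64, bytes: &[u8], count: usize) -> Result<Vec<i64>, String> {
    // Each varint takes at least one byte, so never reserve more than the input can hold.
    let mut ids = Vec::with_capacity(count.min(bytes.len()));
    let mut pos = 0;
    let mut prev = base;
    for _ in 0..count {
        let delta = read_varint(bytes, &mut pos)?;
        let id = prev.wrapping_add(delta as i64);
        // Deltas are non-negative, so a decrease can only mean a wrap past i64::MAX.
        if id < prev {
            return Err("decoded IDs are not non-decreasing".into());
        }
        ids.push(id);
        prev = id;
    }
    if pos != bytes.len() {
        return Err("trailing bytes in ID section".into());
    }
    Ok(ids)
}

fn main() {
    let ids = [7, 10, 300, 301];
    let (base, bytes) = encode_delta_ids(&ids);
    // Deltas 0, 3, 290, 1 encode as 00 03 a2 02 01: 5 bytes instead of 32 for raw i64s.
    println!("base={} bytes={:02x?}", base, bytes);
    assert_eq!(decode_delta_ids(base, &bytes, ids.len()).unwrap(), ids);
}
```

Rejecting any step that wraps past `i64::MAX` is sufficient here. Every delta is non-negative, so a wrap is the only way a decoded sequence could fail to be non-decreasing in signed order. Treating trailing bytes in the ID section as an error is a choice made for this sketch only.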