This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 924b6e9d0e IPC writer truncated sliced list/map values (#5071)
924b6e9d0e is described below
commit 924b6e9d0e62ad8cb85419268d8765611a72631e
Author: Jeffrey <[email protected]>
AuthorDate: Tue Nov 14 08:01:10 2023 +1100
IPC writer truncated sliced list/map values (#5071)
* IPC writer truncated sliced list/map values
* Add empty list test
* Revert submodule update
---
arrow-ipc/src/writer.rs | 429 ++++++++++++++++++++++++++++++++----------------
1 file changed, 285 insertions(+), 144 deletions(-)
diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs
index a58cbfc514..1f6bf5f6fa 100644
--- a/arrow-ipc/src/writer.rs
+++ b/arrow-ipc/src/writer.rs
@@ -1139,6 +1139,29 @@ fn get_buffer_element_width(spec: &BufferSpec) -> usize {
}
}
+/// Common functionality for re-encoding offsets. Returns the new offsets as well as
+/// original start offset and length for use in slicing child data.
+///
+/// The returned offsets buffer always holds exactly the `data.len() + 1` offsets
+/// visible through `data` (starting at `data.offset()`), rebased so that the
+/// first one is `0`:
+///
+/// * if the first visible offset is already `0`, the existing buffer is sliced
+///   (zero-copy) to that visible window;
+/// * otherwise every visible offset has the first one subtracted from it.
+fn reencode_offsets<O: OffsetSizeTrait>(
+    offsets: &Buffer,
+    data: &ArrayData,
+) -> (Buffer, usize, usize) {
+    let offsets_slice: &[O] = offsets.typed_data::<O>();
+    let offset_slice = &offsets_slice[data.offset()..data.offset() + data.len() + 1];
+
+    let start_offset = offset_slice.first().unwrap();
+    let end_offset = offset_slice.last().unwrap();
+
+    let offsets = match start_offset.as_usize() {
+        // Already zero-based: take only the visible window, without copying.
+        // Cloning the whole buffer is wrong when `data.offset() > 0` (possible
+        // when the slice skips leading empty elements): readers take the first
+        // `len + 1` offsets, which would then be the sliced-away ones. It would
+        // also serialize every trailing offset of the unsliced array.
+        0 => {
+            let width = std::mem::size_of::<O>();
+            offsets.slice_with_length(data.offset() * width, offset_slice.len() * width)
+        }
+        _ => offset_slice.iter().map(|x| *x - *start_offset).collect(),
+    };
+
+    let start_offset = start_offset.as_usize();
+    let end_offset = end_offset.as_usize();
+
+    (offsets, start_offset, end_offset - start_offset)
+}
+
/// Returns the values and offsets [`Buffer`] for a ByteArray with offset type
`O`
///
/// In particular, this handles re-encoding the offsets if they don't start at
`0`,
@@ -1149,23 +1172,24 @@ fn get_byte_array_buffers<O: OffsetSizeTrait>(data:
&ArrayData) -> (Buffer, Buff
return (MutableBuffer::new(0).into(), MutableBuffer::new(0).into());
}
- let buffers = data.buffers();
- let offsets: &[O] = buffers[0].typed_data::<O>();
- let offset_slice = &offsets[data.offset()..data.offset() + data.len() + 1];
-
- let start_offset = offset_slice.first().unwrap();
- let end_offset = offset_slice.last().unwrap();
+ let (offsets, original_start_offset, len) =
reencode_offsets::<O>(&data.buffers()[0], data);
+ let values = data.buffers()[1].slice_with_length(original_start_offset,
len);
+ (offsets, values)
+}
- let offsets = match start_offset.as_usize() {
- 0 => buffers[0].clone(),
- _ => offset_slice.iter().map(|x| *x - *start_offset).collect(),
- };
+/// Similar logic as [`get_byte_array_buffers()`] but slices the child array
instead
+/// of a values buffer.
+fn get_list_array_buffers<O: OffsetSizeTrait>(data: &ArrayData) -> (Buffer,
ArrayData) {
+ if data.is_empty() {
+ return (
+ MutableBuffer::new(0).into(),
+ data.child_data()[0].slice(0, 0),
+ );
+ }
- let values = buffers[1].slice_with_length(
- start_offset.as_usize(),
- end_offset.as_usize() - start_offset.as_usize(),
- );
- (offsets, values)
+ let (offsets, original_start_offset, len) =
reencode_offsets::<O>(&data.buffers()[0], data);
+ let child_data = data.child_data()[0].slice(original_start_offset, len);
+ (offsets, child_data)
}
/// Write array data to a vector of bytes
@@ -1250,20 +1274,14 @@ fn write_array_data(
let byte_width = get_buffer_element_width(spec);
let min_length = array_data.len() * byte_width;
- if buffer_need_truncate(array_data.offset(), buffer, spec, min_length)
{
+ let buffer_slice = if buffer_need_truncate(array_data.offset(),
buffer, spec, min_length) {
let byte_offset = array_data.offset() * byte_width;
let buffer_length = min(min_length, buffer.len() - byte_offset);
- let buffer_slice = &buffer.as_slice()[byte_offset..(byte_offset +
buffer_length)];
- offset = write_buffer(buffer_slice, buffers, arrow_data, offset,
compression_codec)?;
+ &buffer.as_slice()[byte_offset..(byte_offset + buffer_length)]
} else {
- offset = write_buffer(
- buffer.as_slice(),
- buffers,
- arrow_data,
- offset,
- compression_codec,
- )?;
- }
+ buffer.as_slice()
+ };
+ offset = write_buffer(buffer_slice, buffers, arrow_data, offset,
compression_codec)?;
} else if matches!(data_type, DataType::Boolean) {
// Bools are special because the payload (= 1 bit) is smaller than the
physical container elements (= bytes).
// The array data may not start at the physical boundary of the
underlying buffer, so we need to shift bits around.
@@ -1272,6 +1290,39 @@ fn write_array_data(
let buffer = &array_data.buffers()[0];
let buffer = buffer.bit_slice(array_data.offset(), array_data.len());
offset = write_buffer(&buffer, buffers, arrow_data, offset,
compression_codec)?;
+ } else if matches!(
+ data_type,
+ DataType::List(_) | DataType::LargeList(_) | DataType::Map(_, _)
+ ) {
+ assert_eq!(array_data.buffers().len(), 1);
+ assert_eq!(array_data.child_data().len(), 1);
+
+ // Truncate offsets and the child data to avoid writing unnecessary
data
+ let (offsets, sliced_child_data) = match data_type {
+ DataType::List(_) => get_list_array_buffers::<i32>(array_data),
+ DataType::Map(_, _) => get_list_array_buffers::<i32>(array_data),
+ DataType::LargeList(_) =>
get_list_array_buffers::<i64>(array_data),
+ _ => unreachable!(),
+ };
+ offset = write_buffer(
+ offsets.as_slice(),
+ buffers,
+ arrow_data,
+ offset,
+ compression_codec,
+ )?;
+ offset = write_array_data(
+ &sliced_child_data,
+ buffers,
+ arrow_data,
+ nodes,
+ offset,
+ sliced_child_data.len(),
+ sliced_child_data.null_count(),
+ compression_codec,
+ write_options,
+ )?;
+ return Ok(offset);
} else {
for buffer in array_data.buffers() {
offset = write_buffer(buffer, buffers, arrow_data, offset,
compression_codec)?;
@@ -1372,8 +1423,10 @@ mod tests {
use std::io::Seek;
use std::sync::Arc;
+ use arrow_array::builder::GenericListBuilder;
+ use arrow_array::builder::MapBuilder;
use arrow_array::builder::UnionBuilder;
- use arrow_array::builder::{ListBuilder, PrimitiveRunBuilder,
UInt32Builder};
+ use arrow_array::builder::{PrimitiveRunBuilder, UInt32Builder};
use arrow_array::types::*;
use arrow_schema::DataType;
@@ -1382,6 +1435,30 @@ mod tests {
use super::*;
+    /// Encodes `rb` as a complete in-memory IPC file (schema, one batch, footer).
+    fn serialize_file(rb: &RecordBatch) -> Vec<u8> {
+        let schema = rb.schema();
+        let mut file_writer = FileWriter::try_new(Vec::new(), &schema).unwrap();
+        file_writer.write(rb).unwrap();
+        file_writer.finish().unwrap();
+        file_writer.into_inner().unwrap()
+    }
+
+    /// Decodes the first record batch from an in-memory IPC file.
+    fn deserialize_file(bytes: Vec<u8>) -> RecordBatch {
+        FileReader::try_new(Cursor::new(bytes), None)
+            .unwrap()
+            .next()
+            .unwrap()
+            .unwrap()
+    }
+
+    /// Encodes `record` as a complete in-memory IPC stream (schema, one batch, EOS).
+    fn serialize_stream(record: &RecordBatch) -> Vec<u8> {
+        let schema = record.schema();
+        let mut writer = StreamWriter::try_new(Vec::new(), &schema).unwrap();
+        writer.write(record).unwrap();
+        writer.finish().unwrap();
+        writer.into_inner().unwrap()
+    }
+
+    /// Decodes the first record batch from an in-memory IPC stream.
+    fn deserialize_stream(bytes: Vec<u8>) -> RecordBatch {
+        StreamReader::try_new(Cursor::new(bytes), None)
+            .unwrap()
+            .next()
+            .unwrap()
+            .unwrap()
+    }
+
#[test]
#[cfg(feature = "lz4")]
fn test_write_empty_record_batch_lz4_compression() {
@@ -1407,27 +1484,18 @@ mod tests {
file.rewind().unwrap();
{
// read file
- let mut reader = FileReader::try_new(file, None).unwrap();
- loop {
- match reader.next() {
- Some(Ok(read_batch)) => {
- read_batch
- .columns()
- .iter()
- .zip(record_batch.columns())
- .for_each(|(a, b)| {
- assert_eq!(a.data_type(), b.data_type());
- assert_eq!(a.len(), b.len());
- assert_eq!(a.null_count(), b.null_count());
- });
- }
- Some(Err(e)) => {
- panic!("{}", e);
- }
- None => {
- break;
- }
- }
+ let reader = FileReader::try_new(file, None).unwrap();
+ for read_batch in reader {
+ read_batch
+ .unwrap()
+ .columns()
+ .iter()
+ .zip(record_batch.columns())
+ .for_each(|(a, b)| {
+ assert_eq!(a.data_type(), b.data_type());
+ assert_eq!(a.len(), b.len());
+ assert_eq!(a.null_count(), b.null_count());
+ });
}
}
}
@@ -1456,27 +1524,18 @@ mod tests {
file.rewind().unwrap();
{
// read file
- let mut reader = FileReader::try_new(file, None).unwrap();
- loop {
- match reader.next() {
- Some(Ok(read_batch)) => {
- read_batch
- .columns()
- .iter()
- .zip(record_batch.columns())
- .for_each(|(a, b)| {
- assert_eq!(a.data_type(), b.data_type());
- assert_eq!(a.len(), b.len());
- assert_eq!(a.null_count(), b.null_count());
- });
- }
- Some(Err(e)) => {
- panic!("{}", e);
- }
- None => {
- break;
- }
- }
+ let reader = FileReader::try_new(file, None).unwrap();
+ for read_batch in reader {
+ read_batch
+ .unwrap()
+ .columns()
+ .iter()
+ .zip(record_batch.columns())
+ .for_each(|(a, b)| {
+ assert_eq!(a.data_type(), b.data_type());
+ assert_eq!(a.len(), b.len());
+ assert_eq!(a.null_count(), b.null_count());
+ });
}
}
}
@@ -1504,27 +1563,18 @@ mod tests {
file.rewind().unwrap();
{
// read file
- let mut reader = FileReader::try_new(file, None).unwrap();
- loop {
- match reader.next() {
- Some(Ok(read_batch)) => {
- read_batch
- .columns()
- .iter()
- .zip(record_batch.columns())
- .for_each(|(a, b)| {
- assert_eq!(a.data_type(), b.data_type());
- assert_eq!(a.len(), b.len());
- assert_eq!(a.null_count(), b.null_count());
- });
- }
- Some(Err(e)) => {
- panic!("{}", e);
- }
- None => {
- break;
- }
- }
+ let reader = FileReader::try_new(file, None).unwrap();
+ for read_batch in reader {
+ read_batch
+ .unwrap()
+ .columns()
+ .iter()
+ .zip(record_batch.columns())
+ .for_each(|(a, b)| {
+ assert_eq!(a.data_type(), b.data_type());
+ assert_eq!(a.len(), b.len());
+ assert_eq!(a.null_count(), b.null_count());
+ });
}
}
}
@@ -1754,20 +1804,6 @@ mod tests {
write_union_file(IpcWriteOptions::try_new(8, false,
MetadataVersion::V5).unwrap());
}
- fn serialize(record: &RecordBatch) -> Vec<u8> {
- let buffer: Vec<u8> = Vec::new();
- let mut stream_writer = StreamWriter::try_new(buffer,
&record.schema()).unwrap();
- stream_writer.write(record).unwrap();
- stream_writer.finish().unwrap();
- stream_writer.into_inner().unwrap()
- }
-
- fn deserialize(bytes: Vec<u8>) -> RecordBatch {
- let mut stream_reader =
- crate::reader::StreamReader::try_new(std::io::Cursor::new(bytes),
None).unwrap();
- stream_reader.next().unwrap().unwrap()
- }
-
#[test]
fn truncate_ipc_record_batch() {
fn create_batch(rows: usize) -> RecordBatch {
@@ -1789,14 +1825,16 @@ mod tests {
let offset = 2;
let record_batch_slice = big_record_batch.slice(offset, length);
- assert!(serialize(&big_record_batch).len() >
serialize(&small_record_batch).len());
+ assert!(
+ serialize_stream(&big_record_batch).len() >
serialize_stream(&small_record_batch).len()
+ );
assert_eq!(
- serialize(&small_record_batch).len(),
- serialize(&record_batch_slice).len()
+ serialize_stream(&small_record_batch).len(),
+ serialize_stream(&record_batch_slice).len()
);
assert_eq!(
- deserialize(serialize(&record_batch_slice)),
+ deserialize_stream(serialize_stream(&record_batch_slice)),
record_batch_slice
);
}
@@ -1817,9 +1855,11 @@ mod tests {
let record_batch = create_batch();
let record_batch_slice = record_batch.slice(1, 2);
- let deserialized_batch = deserialize(serialize(&record_batch_slice));
+ let deserialized_batch =
deserialize_stream(serialize_stream(&record_batch_slice));
- assert!(serialize(&record_batch).len() >
serialize(&record_batch_slice).len());
+ assert!(
+ serialize_stream(&record_batch).len() >
serialize_stream(&record_batch_slice).len()
+ );
assert!(deserialized_batch.column(0).is_null(0));
assert!(deserialized_batch.column(0).is_valid(1));
@@ -1846,9 +1886,11 @@ mod tests {
let record_batch = create_batch();
let record_batch_slice = record_batch.slice(1, 2);
- let deserialized_batch = deserialize(serialize(&record_batch_slice));
+ let deserialized_batch =
deserialize_stream(serialize_stream(&record_batch_slice));
- assert!(serialize(&record_batch).len() >
serialize(&record_batch_slice).len());
+ assert!(
+ serialize_stream(&record_batch).len() >
serialize_stream(&record_batch_slice).len()
+ );
assert!(deserialized_batch.column(0).is_valid(0));
assert!(deserialized_batch.column(0).is_null(1));
@@ -1886,9 +1928,11 @@ mod tests {
let record_batch = create_batch();
let record_batch_slice = record_batch.slice(1, 2);
- let deserialized_batch = deserialize(serialize(&record_batch_slice));
+ let deserialized_batch =
deserialize_stream(serialize_stream(&record_batch_slice));
- assert!(serialize(&record_batch).len() >
serialize(&record_batch_slice).len());
+ assert!(
+ serialize_stream(&record_batch).len() >
serialize_stream(&record_batch_slice).len()
+ );
let structs = deserialized_batch
.column(0)
@@ -1913,9 +1957,11 @@ mod tests {
let record_batch = create_batch();
let record_batch_slice = record_batch.slice(0, 1);
- let deserialized_batch = deserialize(serialize(&record_batch_slice));
+ let deserialized_batch =
deserialize_stream(serialize_stream(&record_batch_slice));
- assert!(serialize(&record_batch).len() >
serialize(&record_batch_slice).len());
+ assert!(
+ serialize_stream(&record_batch).len() >
serialize_stream(&record_batch_slice).len()
+ );
assert_eq!(record_batch_slice, deserialized_batch);
}
@@ -1996,13 +2042,8 @@ mod tests {
let batch = RecordBatch::try_new(Arc::clone(&schema),
vec![Arc::new(bools)]).unwrap();
let batch = batch.slice(offset, length);
- let mut writer = StreamWriter::try_new(Vec::<u8>::new(),
&schema).unwrap();
- writer.write(&batch).unwrap();
- writer.finish().unwrap();
- let data = writer.into_inner().unwrap();
-
- let mut reader = StreamReader::try_new(Cursor::new(data),
None).unwrap();
- let batch2 = reader.next().unwrap().unwrap();
+ let data = serialize_stream(&batch);
+ let batch2 = deserialize_stream(data);
assert_eq!(batch, batch2);
}
@@ -2060,37 +2101,137 @@ mod tests {
}
}
+    /// Builds 100,000 non-null lists; list `i` is `[i, i, i]`.
+    fn generate_list_data<O: OffsetSizeTrait>() -> GenericListArray<O> {
+        let mut builder = GenericListBuilder::<O, _>::new(UInt32Builder::new());
+        (0..100_000u32).for_each(|row| {
+            builder.values().append_slice(&[row; 3]);
+            builder.append(true);
+        });
+        builder.finish()
+    }
+
+    /// Builds 10,000 non-null outer lists, each containing ten non-null inner
+    /// lists `[j, j, j, j]` for `j` in `0..10`.
+    fn generate_nested_list_data<O: OffsetSizeTrait>() -> GenericListArray<O> {
+        let inner_builder = GenericListBuilder::<O, _>::new(UInt32Builder::new());
+        let mut outer = GenericListBuilder::<O, _>::new(inner_builder);
+        for _ in 0..10_000 {
+            let inner = outer.values();
+            for j in 0..10u32 {
+                inner.values().append_slice(&[j; 4]);
+                inner.append(true);
+            }
+            outer.append(true);
+        }
+        outer.finish()
+    }
+
+    /// Builds 100,000 non-null maps; map `i` holds three entries `i -> 2 * i`.
+    fn generate_map_array_data() -> MapArray {
+        let mut builder = MapBuilder::new(None, UInt32Builder::new(), UInt32Builder::new());
+        for key in 0..100_000u32 {
+            let value = key * 2;
+            for _ in 0..3 {
+                builder.keys().append_value(key);
+                builder.values().append_value(value);
+            }
+            builder.append(true).unwrap();
+        }
+        builder.finish()
+    }
+
+    /// Round-trips `in_batch` and a one-row slice of it (row 999) through the IPC
+    /// file format. Asserts that the sliced encoding is smaller than the full one
+    /// divided by `expected_size_factor`, and that both decode back unchanged.
+    fn roundtrip_ensure_sliced_smaller(in_batch: RecordBatch, expected_size_factor: usize) {
+        let in_sliced = in_batch.slice(999, 1);
+        let (bytes_batch, bytes_sliced) = (serialize_file(&in_batch), serialize_file(&in_sliced));
+
+        // A single row must encode far smaller than the whole batch.
+        let size_limit = bytes_batch.len() / expected_size_factor;
+        assert!(bytes_sliced.len() < size_limit);
+
+        // Both encodings must decode to exactly what was written.
+        assert_eq!(in_batch, deserialize_file(bytes_batch));
+        assert_eq!(in_sliced, deserialize_file(bytes_sliced));
+    }
+
#[test]
fn encode_lists() {
let val_inner = Field::new("item", DataType::UInt32, true);
- let val_list_field = Field::new_list("val", val_inner, false);
+ let val_list_field = Field::new("val",
DataType::List(Arc::new(val_inner)), false);
+ let schema = Arc::new(Schema::new(vec![val_list_field]));
+
+ let values = Arc::new(generate_list_data::<i32>());
+
+ let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
+ roundtrip_ensure_sliced_smaller(in_batch, 1000);
+ }
+
+    #[test]
+    fn encode_empty_list() {
+        let item = Arc::new(Field::new("item", DataType::UInt32, true));
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "val",
+            DataType::List(item),
+            false,
+        )]));
+
+        // A zero-length slice taken from the middle of a populated list column.
+        let column = Arc::new(generate_list_data::<i32>());
+        let full_batch = RecordBatch::try_new(schema, vec![column]).unwrap();
+        let in_batch = full_batch.slice(999, 0);
+
+        let round_tripped = deserialize_file(serialize_file(&in_batch));
+        assert_eq!(in_batch, round_tripped);
+    }
+
+ #[test]
+ fn encode_large_lists() {
+ let val_inner = Field::new("item", DataType::UInt32, true);
+ let val_list_field = Field::new("val",
DataType::LargeList(Arc::new(val_inner)), false);
let schema = Arc::new(Schema::new(vec![val_list_field]));
- let values = {
- let u32 = UInt32Builder::new();
- let mut ls = ListBuilder::new(u32);
+ let values = Arc::new(generate_list_data::<i64>());
- for list in [vec![1u32, 2, 3], vec![4, 5, 6], vec![7, 8, 9, 10]] {
- for value in list {
- ls.values().append_value(value);
- }
- ls.append(true)
- }
+ // ensure when serde full & sliced versions they are equal to original
input
+ // also ensure serialized sliced version is significantly smaller than
serialized full
+ let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
+ roundtrip_ensure_sliced_smaller(in_batch, 1000);
+ }
- ls.finish()
- };
+ #[test]
+ fn encode_nested_lists() {
+ let inner_int = Arc::new(Field::new("item", DataType::UInt32, true));
+ let inner_list_field = Arc::new(Field::new("item",
DataType::List(inner_int), true));
+ let list_field = Field::new("val", DataType::List(inner_list_field),
true);
+ let schema = Arc::new(Schema::new(vec![list_field]));
- let batch = RecordBatch::try_new(Arc::clone(&schema),
vec![Arc::new(values)]).unwrap();
- let batch = batch.slice(1, 1);
+ let values = Arc::new(generate_nested_list_data::<i32>());
- let mut writer = FileWriter::try_new(Vec::<u8>::new(),
&schema).unwrap();
- writer.write(&batch).unwrap();
- writer.finish().unwrap();
- let data = writer.into_inner().unwrap();
+ let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
+ roundtrip_ensure_sliced_smaller(in_batch, 1000);
+ }
- let mut reader = FileReader::try_new(Cursor::new(data), None).unwrap();
- let batch2 = reader.next().unwrap().unwrap();
- assert_eq!(batch, batch2);
+    #[test]
+    fn encode_map_array() {
+        let key_field = Arc::new(Field::new("keys", DataType::UInt32, false));
+        let value_field = Arc::new(Field::new("values", DataType::UInt32, true));
+        let schema = Arc::new(Schema::new(vec![Field::new_map(
+            "map",
+            "entries",
+            key_field,
+            value_field,
+            false,
+            true,
+        )]));
+
+        let column = Arc::new(generate_map_array_data());
+        let in_batch = RecordBatch::try_new(schema, vec![column]).unwrap();
+        roundtrip_ensure_sliced_smaller(in_batch, 1000);
     }
}