This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 924b6e9d0e IPC writer truncated sliced list/map values (#5071)
924b6e9d0e is described below

commit 924b6e9d0e62ad8cb85419268d8765611a72631e
Author: Jeffrey <[email protected]>
AuthorDate: Tue Nov 14 08:01:10 2023 +1100

    IPC writer truncated sliced list/map values (#5071)
    
    * IPC writer truncated sliced list/map values
    
    * Add empty list test
    
    * Revert submodule update
---
 arrow-ipc/src/writer.rs | 429 ++++++++++++++++++++++++++++++++----------------
 1 file changed, 285 insertions(+), 144 deletions(-)

diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs
index a58cbfc514..1f6bf5f6fa 100644
--- a/arrow-ipc/src/writer.rs
+++ b/arrow-ipc/src/writer.rs
@@ -1139,6 +1139,29 @@ fn get_buffer_element_width(spec: &BufferSpec) -> usize {
     }
 }
 
+/// Common functionality for re-encoding offsets. Returns the new offsets as well as
+/// original start offset and length for use in slicing child data.
+fn reencode_offsets<O: OffsetSizeTrait>(
+    offsets: &Buffer,
+    data: &ArrayData,
+) -> (Buffer, usize, usize) {
+    let offsets_slice: &[O] = offsets.typed_data::<O>();
+    let offset_slice = &offsets_slice[data.offset()..data.offset() + 
data.len() + 1];
+
+    let start_offset = offset_slice.first().unwrap();
+    let end_offset = offset_slice.last().unwrap();
+
+    let offsets = match start_offset.as_usize() {
+        0 => offsets.clone(),
+        _ => offset_slice.iter().map(|x| *x - *start_offset).collect(),
+    };
+
+    let start_offset = start_offset.as_usize();
+    let end_offset = end_offset.as_usize();
+
+    (offsets, start_offset, end_offset - start_offset)
+}
+
 /// Returns the values and offsets [`Buffer`] for a ByteArray with offset type `O`
 ///
 /// In particular, this handles re-encoding the offsets if they don't start at `0`,
@@ -1149,23 +1172,24 @@ fn get_byte_array_buffers<O: OffsetSizeTrait>(data: &ArrayData) -> (Buffer, Buff
         return (MutableBuffer::new(0).into(), MutableBuffer::new(0).into());
     }
 
-    let buffers = data.buffers();
-    let offsets: &[O] = buffers[0].typed_data::<O>();
-    let offset_slice = &offsets[data.offset()..data.offset() + data.len() + 1];
-
-    let start_offset = offset_slice.first().unwrap();
-    let end_offset = offset_slice.last().unwrap();
+    let (offsets, original_start_offset, len) = 
reencode_offsets::<O>(&data.buffers()[0], data);
+    let values = data.buffers()[1].slice_with_length(original_start_offset, 
len);
+    (offsets, values)
+}
 
-    let offsets = match start_offset.as_usize() {
-        0 => buffers[0].clone(),
-        _ => offset_slice.iter().map(|x| *x - *start_offset).collect(),
-    };
+/// Similar logic as [`get_byte_array_buffers()`] but slices the child array instead
+/// of a values buffer.
+fn get_list_array_buffers<O: OffsetSizeTrait>(data: &ArrayData) -> (Buffer, 
ArrayData) {
+    if data.is_empty() {
+        return (
+            MutableBuffer::new(0).into(),
+            data.child_data()[0].slice(0, 0),
+        );
+    }
 
-    let values = buffers[1].slice_with_length(
-        start_offset.as_usize(),
-        end_offset.as_usize() - start_offset.as_usize(),
-    );
-    (offsets, values)
+    let (offsets, original_start_offset, len) = 
reencode_offsets::<O>(&data.buffers()[0], data);
+    let child_data = data.child_data()[0].slice(original_start_offset, len);
+    (offsets, child_data)
 }
 
 /// Write array data to a vector of bytes
@@ -1250,20 +1274,14 @@ fn write_array_data(
 
         let byte_width = get_buffer_element_width(spec);
         let min_length = array_data.len() * byte_width;
-        if buffer_need_truncate(array_data.offset(), buffer, spec, min_length) 
{
+        let buffer_slice = if buffer_need_truncate(array_data.offset(), 
buffer, spec, min_length) {
             let byte_offset = array_data.offset() * byte_width;
             let buffer_length = min(min_length, buffer.len() - byte_offset);
-            let buffer_slice = &buffer.as_slice()[byte_offset..(byte_offset + 
buffer_length)];
-            offset = write_buffer(buffer_slice, buffers, arrow_data, offset, 
compression_codec)?;
+            &buffer.as_slice()[byte_offset..(byte_offset + buffer_length)]
         } else {
-            offset = write_buffer(
-                buffer.as_slice(),
-                buffers,
-                arrow_data,
-                offset,
-                compression_codec,
-            )?;
-        }
+            buffer.as_slice()
+        };
+        offset = write_buffer(buffer_slice, buffers, arrow_data, offset, 
compression_codec)?;
     } else if matches!(data_type, DataType::Boolean) {
         // Bools are special because the payload (= 1 bit) is smaller than the physical container elements (= bytes).
         // The array data may not start at the physical boundary of the underlying buffer, so we need to shift bits around.
@@ -1272,6 +1290,39 @@ fn write_array_data(
         let buffer = &array_data.buffers()[0];
         let buffer = buffer.bit_slice(array_data.offset(), array_data.len());
         offset = write_buffer(&buffer, buffers, arrow_data, offset, 
compression_codec)?;
+    } else if matches!(
+        data_type,
+        DataType::List(_) | DataType::LargeList(_) | DataType::Map(_, _)
+    ) {
+        assert_eq!(array_data.buffers().len(), 1);
+        assert_eq!(array_data.child_data().len(), 1);
+
+        // Truncate offsets and the child data to avoid writing unnecessary data
+        let (offsets, sliced_child_data) = match data_type {
+            DataType::List(_) => get_list_array_buffers::<i32>(array_data),
+            DataType::Map(_, _) => get_list_array_buffers::<i32>(array_data),
+            DataType::LargeList(_) => 
get_list_array_buffers::<i64>(array_data),
+            _ => unreachable!(),
+        };
+        offset = write_buffer(
+            offsets.as_slice(),
+            buffers,
+            arrow_data,
+            offset,
+            compression_codec,
+        )?;
+        offset = write_array_data(
+            &sliced_child_data,
+            buffers,
+            arrow_data,
+            nodes,
+            offset,
+            sliced_child_data.len(),
+            sliced_child_data.null_count(),
+            compression_codec,
+            write_options,
+        )?;
+        return Ok(offset);
     } else {
         for buffer in array_data.buffers() {
             offset = write_buffer(buffer, buffers, arrow_data, offset, 
compression_codec)?;
@@ -1372,8 +1423,10 @@ mod tests {
     use std::io::Seek;
     use std::sync::Arc;
 
+    use arrow_array::builder::GenericListBuilder;
+    use arrow_array::builder::MapBuilder;
     use arrow_array::builder::UnionBuilder;
-    use arrow_array::builder::{ListBuilder, PrimitiveRunBuilder, 
UInt32Builder};
+    use arrow_array::builder::{PrimitiveRunBuilder, UInt32Builder};
     use arrow_array::types::*;
     use arrow_schema::DataType;
 
@@ -1382,6 +1435,30 @@ mod tests {
 
     use super::*;
 
+    fn serialize_file(rb: &RecordBatch) -> Vec<u8> {
+        let mut writer = FileWriter::try_new(vec![], &rb.schema()).unwrap();
+        writer.write(rb).unwrap();
+        writer.finish().unwrap();
+        writer.into_inner().unwrap()
+    }
+
+    fn deserialize_file(bytes: Vec<u8>) -> RecordBatch {
+        let mut reader = FileReader::try_new(Cursor::new(bytes), 
None).unwrap();
+        reader.next().unwrap().unwrap()
+    }
+
+    fn serialize_stream(record: &RecordBatch) -> Vec<u8> {
+        let mut stream_writer = StreamWriter::try_new(vec![], 
&record.schema()).unwrap();
+        stream_writer.write(record).unwrap();
+        stream_writer.finish().unwrap();
+        stream_writer.into_inner().unwrap()
+    }
+
+    fn deserialize_stream(bytes: Vec<u8>) -> RecordBatch {
+        let mut stream_reader = StreamReader::try_new(Cursor::new(bytes), 
None).unwrap();
+        stream_reader.next().unwrap().unwrap()
+    }
+
     #[test]
     #[cfg(feature = "lz4")]
     fn test_write_empty_record_batch_lz4_compression() {
@@ -1407,27 +1484,18 @@ mod tests {
         file.rewind().unwrap();
         {
             // read file
-            let mut reader = FileReader::try_new(file, None).unwrap();
-            loop {
-                match reader.next() {
-                    Some(Ok(read_batch)) => {
-                        read_batch
-                            .columns()
-                            .iter()
-                            .zip(record_batch.columns())
-                            .for_each(|(a, b)| {
-                                assert_eq!(a.data_type(), b.data_type());
-                                assert_eq!(a.len(), b.len());
-                                assert_eq!(a.null_count(), b.null_count());
-                            });
-                    }
-                    Some(Err(e)) => {
-                        panic!("{}", e);
-                    }
-                    None => {
-                        break;
-                    }
-                }
+            let reader = FileReader::try_new(file, None).unwrap();
+            for read_batch in reader {
+                read_batch
+                    .unwrap()
+                    .columns()
+                    .iter()
+                    .zip(record_batch.columns())
+                    .for_each(|(a, b)| {
+                        assert_eq!(a.data_type(), b.data_type());
+                        assert_eq!(a.len(), b.len());
+                        assert_eq!(a.null_count(), b.null_count());
+                    });
             }
         }
     }
@@ -1456,27 +1524,18 @@ mod tests {
         file.rewind().unwrap();
         {
             // read file
-            let mut reader = FileReader::try_new(file, None).unwrap();
-            loop {
-                match reader.next() {
-                    Some(Ok(read_batch)) => {
-                        read_batch
-                            .columns()
-                            .iter()
-                            .zip(record_batch.columns())
-                            .for_each(|(a, b)| {
-                                assert_eq!(a.data_type(), b.data_type());
-                                assert_eq!(a.len(), b.len());
-                                assert_eq!(a.null_count(), b.null_count());
-                            });
-                    }
-                    Some(Err(e)) => {
-                        panic!("{}", e);
-                    }
-                    None => {
-                        break;
-                    }
-                }
+            let reader = FileReader::try_new(file, None).unwrap();
+            for read_batch in reader {
+                read_batch
+                    .unwrap()
+                    .columns()
+                    .iter()
+                    .zip(record_batch.columns())
+                    .for_each(|(a, b)| {
+                        assert_eq!(a.data_type(), b.data_type());
+                        assert_eq!(a.len(), b.len());
+                        assert_eq!(a.null_count(), b.null_count());
+                    });
             }
         }
     }
@@ -1504,27 +1563,18 @@ mod tests {
         file.rewind().unwrap();
         {
             // read file
-            let mut reader = FileReader::try_new(file, None).unwrap();
-            loop {
-                match reader.next() {
-                    Some(Ok(read_batch)) => {
-                        read_batch
-                            .columns()
-                            .iter()
-                            .zip(record_batch.columns())
-                            .for_each(|(a, b)| {
-                                assert_eq!(a.data_type(), b.data_type());
-                                assert_eq!(a.len(), b.len());
-                                assert_eq!(a.null_count(), b.null_count());
-                            });
-                    }
-                    Some(Err(e)) => {
-                        panic!("{}", e);
-                    }
-                    None => {
-                        break;
-                    }
-                }
+            let reader = FileReader::try_new(file, None).unwrap();
+            for read_batch in reader {
+                read_batch
+                    .unwrap()
+                    .columns()
+                    .iter()
+                    .zip(record_batch.columns())
+                    .for_each(|(a, b)| {
+                        assert_eq!(a.data_type(), b.data_type());
+                        assert_eq!(a.len(), b.len());
+                        assert_eq!(a.null_count(), b.null_count());
+                    });
             }
         }
     }
@@ -1754,20 +1804,6 @@ mod tests {
         write_union_file(IpcWriteOptions::try_new(8, false, 
MetadataVersion::V5).unwrap());
     }
 
-    fn serialize(record: &RecordBatch) -> Vec<u8> {
-        let buffer: Vec<u8> = Vec::new();
-        let mut stream_writer = StreamWriter::try_new(buffer, 
&record.schema()).unwrap();
-        stream_writer.write(record).unwrap();
-        stream_writer.finish().unwrap();
-        stream_writer.into_inner().unwrap()
-    }
-
-    fn deserialize(bytes: Vec<u8>) -> RecordBatch {
-        let mut stream_reader =
-            crate::reader::StreamReader::try_new(std::io::Cursor::new(bytes), 
None).unwrap();
-        stream_reader.next().unwrap().unwrap()
-    }
-
     #[test]
     fn truncate_ipc_record_batch() {
         fn create_batch(rows: usize) -> RecordBatch {
@@ -1789,14 +1825,16 @@ mod tests {
 
         let offset = 2;
         let record_batch_slice = big_record_batch.slice(offset, length);
-        assert!(serialize(&big_record_batch).len() > 
serialize(&small_record_batch).len());
+        assert!(
+            serialize_stream(&big_record_batch).len() > 
serialize_stream(&small_record_batch).len()
+        );
         assert_eq!(
-            serialize(&small_record_batch).len(),
-            serialize(&record_batch_slice).len()
+            serialize_stream(&small_record_batch).len(),
+            serialize_stream(&record_batch_slice).len()
         );
 
         assert_eq!(
-            deserialize(serialize(&record_batch_slice)),
+            deserialize_stream(serialize_stream(&record_batch_slice)),
             record_batch_slice
         );
     }
@@ -1817,9 +1855,11 @@ mod tests {
 
         let record_batch = create_batch();
         let record_batch_slice = record_batch.slice(1, 2);
-        let deserialized_batch = deserialize(serialize(&record_batch_slice));
+        let deserialized_batch = 
deserialize_stream(serialize_stream(&record_batch_slice));
 
-        assert!(serialize(&record_batch).len() > 
serialize(&record_batch_slice).len());
+        assert!(
+            serialize_stream(&record_batch).len() > 
serialize_stream(&record_batch_slice).len()
+        );
 
         assert!(deserialized_batch.column(0).is_null(0));
         assert!(deserialized_batch.column(0).is_valid(1));
@@ -1846,9 +1886,11 @@ mod tests {
 
         let record_batch = create_batch();
         let record_batch_slice = record_batch.slice(1, 2);
-        let deserialized_batch = deserialize(serialize(&record_batch_slice));
+        let deserialized_batch = 
deserialize_stream(serialize_stream(&record_batch_slice));
 
-        assert!(serialize(&record_batch).len() > 
serialize(&record_batch_slice).len());
+        assert!(
+            serialize_stream(&record_batch).len() > 
serialize_stream(&record_batch_slice).len()
+        );
 
         assert!(deserialized_batch.column(0).is_valid(0));
         assert!(deserialized_batch.column(0).is_null(1));
@@ -1886,9 +1928,11 @@ mod tests {
 
         let record_batch = create_batch();
         let record_batch_slice = record_batch.slice(1, 2);
-        let deserialized_batch = deserialize(serialize(&record_batch_slice));
+        let deserialized_batch = 
deserialize_stream(serialize_stream(&record_batch_slice));
 
-        assert!(serialize(&record_batch).len() > 
serialize(&record_batch_slice).len());
+        assert!(
+            serialize_stream(&record_batch).len() > 
serialize_stream(&record_batch_slice).len()
+        );
 
         let structs = deserialized_batch
             .column(0)
@@ -1913,9 +1957,11 @@ mod tests {
 
         let record_batch = create_batch();
         let record_batch_slice = record_batch.slice(0, 1);
-        let deserialized_batch = deserialize(serialize(&record_batch_slice));
+        let deserialized_batch = 
deserialize_stream(serialize_stream(&record_batch_slice));
 
-        assert!(serialize(&record_batch).len() > 
serialize(&record_batch_slice).len());
+        assert!(
+            serialize_stream(&record_batch).len() > 
serialize_stream(&record_batch_slice).len()
+        );
         assert_eq!(record_batch_slice, deserialized_batch);
     }
 
@@ -1996,13 +2042,8 @@ mod tests {
         let batch = RecordBatch::try_new(Arc::clone(&schema), 
vec![Arc::new(bools)]).unwrap();
         let batch = batch.slice(offset, length);
 
-        let mut writer = StreamWriter::try_new(Vec::<u8>::new(), 
&schema).unwrap();
-        writer.write(&batch).unwrap();
-        writer.finish().unwrap();
-        let data = writer.into_inner().unwrap();
-
-        let mut reader = StreamReader::try_new(Cursor::new(data), 
None).unwrap();
-        let batch2 = reader.next().unwrap().unwrap();
+        let data = serialize_stream(&batch);
+        let batch2 = deserialize_stream(data);
         assert_eq!(batch, batch2);
     }
 
@@ -2060,37 +2101,137 @@ mod tests {
         }
     }
 
+    fn generate_list_data<O: OffsetSizeTrait>() -> GenericListArray<O> {
+        let mut ls = GenericListBuilder::<O, _>::new(UInt32Builder::new());
+
+        for i in 0..100_000 {
+            for value in [i, i, i] {
+                ls.values().append_value(value);
+            }
+            ls.append(true)
+        }
+
+        ls.finish()
+    }
+
+    fn generate_nested_list_data<O: OffsetSizeTrait>() -> GenericListArray<O> {
+        let mut ls =
+            GenericListBuilder::<O, _>::new(GenericListBuilder::<O, 
_>::new(UInt32Builder::new()));
+
+        for _i in 0..10_000 {
+            for j in 0..10 {
+                for value in [j, j, j, j] {
+                    ls.values().values().append_value(value);
+                }
+                ls.values().append(true)
+            }
+            ls.append(true);
+        }
+
+        ls.finish()
+    }
+
+    fn generate_map_array_data() -> MapArray {
+        let keys_builder = UInt32Builder::new();
+        let values_builder = UInt32Builder::new();
+
+        let mut builder = MapBuilder::new(None, keys_builder, values_builder);
+
+        for i in 0..100_000 {
+            for _j in 0..3 {
+                builder.keys().append_value(i);
+                builder.values().append_value(i * 2);
+            }
+            builder.append(true).unwrap();
+        }
+
+        builder.finish()
+    }
+
+    /// Ensure when serde full & sliced versions they are equal to original input.
+    /// Also ensure serialized sliced version is significantly smaller than serialized full.
+    fn roundtrip_ensure_sliced_smaller(in_batch: RecordBatch, 
expected_size_factor: usize) {
+        // test both full and sliced versions
+        let in_sliced = in_batch.slice(999, 1);
+
+        let bytes_batch = serialize_file(&in_batch);
+        let bytes_sliced = serialize_file(&in_sliced);
+
+        // serializing 1 row should be significantly smaller than serializing 100,000
+        assert!(bytes_sliced.len() < (bytes_batch.len() / 
expected_size_factor));
+
+        // ensure both are still valid and equal to originals
+        let out_batch = deserialize_file(bytes_batch);
+        assert_eq!(in_batch, out_batch);
+
+        let out_sliced = deserialize_file(bytes_sliced);
+        assert_eq!(in_sliced, out_sliced);
+    }
+
     #[test]
     fn encode_lists() {
         let val_inner = Field::new("item", DataType::UInt32, true);
-        let val_list_field = Field::new_list("val", val_inner, false);
+        let val_list_field = Field::new("val", 
DataType::List(Arc::new(val_inner)), false);
+        let schema = Arc::new(Schema::new(vec![val_list_field]));
+
+        let values = Arc::new(generate_list_data::<i32>());
+
+        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
+        roundtrip_ensure_sliced_smaller(in_batch, 1000);
+    }
+
+    #[test]
+    fn encode_empty_list() {
+        let val_inner = Field::new("item", DataType::UInt32, true);
+        let val_list_field = Field::new("val", 
DataType::List(Arc::new(val_inner)), false);
+        let schema = Arc::new(Schema::new(vec![val_list_field]));
+
+        let values = Arc::new(generate_list_data::<i32>());
 
+        let in_batch = RecordBatch::try_new(schema, vec![values])
+            .unwrap()
+            .slice(999, 0);
+        let out_batch = deserialize_file(serialize_file(&in_batch));
+        assert_eq!(in_batch, out_batch);
+    }
+
+    #[test]
+    fn encode_large_lists() {
+        let val_inner = Field::new("item", DataType::UInt32, true);
+        let val_list_field = Field::new("val", 
DataType::LargeList(Arc::new(val_inner)), false);
         let schema = Arc::new(Schema::new(vec![val_list_field]));
 
-        let values = {
-            let u32 = UInt32Builder::new();
-            let mut ls = ListBuilder::new(u32);
+        let values = Arc::new(generate_list_data::<i64>());
 
-            for list in [vec![1u32, 2, 3], vec![4, 5, 6], vec![7, 8, 9, 10]] {
-                for value in list {
-                    ls.values().append_value(value);
-                }
-                ls.append(true)
-            }
+        // ensure when serde full & sliced versions they are equal to original input
+        // also ensure serialized sliced version is significantly smaller than serialized full
+        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
+        roundtrip_ensure_sliced_smaller(in_batch, 1000);
+    }
 
-            ls.finish()
-        };
+    #[test]
+    fn encode_nested_lists() {
+        let inner_int = Arc::new(Field::new("item", DataType::UInt32, true));
+        let inner_list_field = Arc::new(Field::new("item", 
DataType::List(inner_int), true));
+        let list_field = Field::new("val", DataType::List(inner_list_field), 
true);
+        let schema = Arc::new(Schema::new(vec![list_field]));
 
-        let batch = RecordBatch::try_new(Arc::clone(&schema), 
vec![Arc::new(values)]).unwrap();
-        let batch = batch.slice(1, 1);
+        let values = Arc::new(generate_nested_list_data::<i32>());
 
-        let mut writer = FileWriter::try_new(Vec::<u8>::new(), 
&schema).unwrap();
-        writer.write(&batch).unwrap();
-        writer.finish().unwrap();
-        let data = writer.into_inner().unwrap();
+        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
+        roundtrip_ensure_sliced_smaller(in_batch, 1000);
+    }
 
-        let mut reader = FileReader::try_new(Cursor::new(data), None).unwrap();
-        let batch2 = reader.next().unwrap().unwrap();
-        assert_eq!(batch, batch2);
+    #[test]
+    fn encode_map_array() {
+        let keys = Arc::new(Field::new("keys", DataType::UInt32, false));
+        let values = Arc::new(Field::new("values", DataType::UInt32, true));
+        let map_field = Field::new_map("map", "entries", keys, values, false, 
true);
+        let schema = Arc::new(Schema::new(vec![map_field]));
+
+        let values = Arc::new(generate_map_array_data());
+
+        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
+        roundtrip_ensure_sliced_smaller(in_batch, 1000);
     }
 }

Reply via email to