This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 2fd51814df Add `with_skip_validation` flag to IPC `StreamReader`, 
`FileReader` and `FileDecoder` (#7120)
2fd51814df is described below

commit 2fd51814df6f85d74179be845e4d2b4cf2bff4e3
Author: Andrew Lamb <[email protected]>
AuthorDate: Thu Feb 27 07:24:54 2025 -0500

    Add `with_skip_validation` flag to IPC `StreamReader`, `FileReader` and 
`FileDecoder` (#7120)
    
    * Make UnsafeFlag pub
    
    * Add `with_skip_validation` flag to IPC `StreamReader`, `FileReader` and 
`FileDecoder`
    
    * Apply suggestions from code review
    
    Co-authored-by: Raphael Taylor-Davies 
<[email protected]>
    
    * Update arrow-ipc/src/reader.rs
    
    Co-authored-by: Raphael Taylor-Davies 
<[email protected]>
    
    * Update arrow-ipc/src/reader.rs
    
    Co-authored-by: Raphael Taylor-Davies 
<[email protected]>
    
    * Add notes to RecordBatchDecoder::with_skip_validation flag
    
    * remove repetition in read benchmark
    
    ---------
    
    Co-authored-by: Raphael Taylor-Davies 
<[email protected]>
---
 arrow-data/src/data.rs          |  74 +++++++---
 arrow-ipc/benches/ipc_reader.rs | 154 ++++++++++++++++-----
 arrow-ipc/src/reader.rs         | 300 +++++++++++++++++++++++++++++++++-------
 arrow-ipc/src/reader/stream.rs  |   8 ++
 4 files changed, 426 insertions(+), 110 deletions(-)

diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs
index 376b6f9584..a76e00e267 100644
--- a/arrow-data/src/data.rs
+++ b/arrow-data/src/data.rs
@@ -28,7 +28,6 @@ use std::mem;
 use std::ops::Range;
 use std::sync::Arc;
 
-use crate::data::private::UnsafeFlag;
 use crate::{equal, validate_binary_view, validate_string_view};
 
 #[inline]
@@ -1781,33 +1780,62 @@ impl PartialEq for ArrayData {
     }
 }
 
-mod private {
-    /// A boolean flag that cannot be mutated outside of unsafe code.
+/// A boolean flag that cannot be mutated outside of unsafe code.
+///
+/// Defaults to a value of false.
+///
+/// This structure is used to enforce safety in the [`ArrayDataBuilder`]
+///
+/// [`ArrayDataBuilder`]: super::ArrayDataBuilder
+///
+/// # Example
+/// ```rust
+/// use arrow_data::UnsafeFlag;
+/// assert!(!UnsafeFlag::default().get()); // default is false
+/// let mut flag = UnsafeFlag::new();
+/// assert!(!flag.get()); // defaults to false
+/// // can only set it to true in unsafe code
+/// unsafe { flag.set(true) };
+/// assert!(flag.get()); // now true
+/// ```
+#[derive(Debug, Clone)]
+#[doc(hidden)]
+pub struct UnsafeFlag(bool);
+
+impl UnsafeFlag {
+    /// Creates a new `UnsafeFlag` with the value set to `false`.
+    ///
+    /// See examples on [`Self::new`]
+    #[inline]
+    pub const fn new() -> Self {
+        Self(false)
+    }
+
+    /// Sets the value of the flag to the given value
     ///
-    /// Defaults to a value of false.
+    /// Note this can purposely only be done in `unsafe` code
     ///
-    /// This structure is used to enforce safety in the [`ArrayDataBuilder`]
+    /// # Safety
     ///
-    /// [`ArrayDataBuilder`]: super::ArrayDataBuilder
-    #[derive(Debug)]
-    pub struct UnsafeFlag(bool);
-
-    impl UnsafeFlag {
-        /// Creates a new `UnsafeFlag` with the value set to `false`
-        #[inline]
-        pub const fn new() -> Self {
-            Self(false)
-        }
+    /// If set, the flag will be set to the given value. There is nothing
+    /// immediately unsafe about doing so, however, the flag can be used to
+    /// subsequently bypass safety checks in the [`ArrayDataBuilder`].
+    #[inline]
+    pub unsafe fn set(&mut self, val: bool) {
+        self.0 = val;
+    }
 
-        #[inline]
-        pub unsafe fn set(&mut self, val: bool) {
-            self.0 = val;
-        }
+    /// Returns the value of the flag
+    #[inline]
+    pub fn get(&self) -> bool {
+        self.0
+    }
+}
 
-        #[inline]
-        pub fn get(&self) -> bool {
-            self.0
-        }
+// Manual impl to make it clear you can not construct unsafe with true
+impl Default for UnsafeFlag {
+    fn default() -> Self {
+        Self::new()
     }
 }
 
diff --git a/arrow-ipc/benches/ipc_reader.rs b/arrow-ipc/benches/ipc_reader.rs
index 7fc14664b4..ab77449eeb 100644
--- a/arrow-ipc/benches/ipc_reader.rs
+++ b/arrow-ipc/benches/ipc_reader.rs
@@ -24,7 +24,7 @@ use arrow_ipc::writer::{FileWriter, IpcWriteOptions, 
StreamWriter};
 use arrow_ipc::{root_as_footer, Block, CompressionType};
 use arrow_schema::{DataType, Field, Schema};
 use criterion::{criterion_group, criterion_main, Criterion};
-use std::io::Cursor;
+use std::io::{Cursor, Write};
 use std::sync::Arc;
 use tempfile::tempdir;
 
@@ -32,17 +32,26 @@ fn criterion_benchmark(c: &mut Criterion) {
     let mut group = c.benchmark_group("arrow_ipc_reader");
 
     group.bench_function("StreamReader/read_10", |b| {
-        let batch = create_batch(8192, true);
-        let mut buffer = Vec::with_capacity(2 * 1024 * 1024);
-        let mut writer = StreamWriter::try_new(&mut buffer, 
batch.schema().as_ref()).unwrap();
-        for _ in 0..10 {
-            writer.write(&batch).unwrap();
-        }
-        writer.finish().unwrap();
+        let buffer = ipc_stream(IpcWriteOptions::default());
+        b.iter(move || {
+            let projection = None;
+            let mut reader = StreamReader::try_new(buffer.as_slice(), 
projection).unwrap();
+            for _ in 0..10 {
+                reader.next().unwrap().unwrap();
+            }
+            assert!(reader.next().is_none());
+        })
+    });
 
+    group.bench_function("StreamReader/no_validation/read_10", |b| {
+        let buffer = ipc_stream(IpcWriteOptions::default());
         b.iter(move || {
             let projection = None;
             let mut reader = StreamReader::try_new(buffer.as_slice(), 
projection).unwrap();
+            unsafe {
+                // safety: we created a valid IPC file
+                reader = reader.with_skip_validation(true);
+            }
             for _ in 0..10 {
                 reader.next().unwrap().unwrap();
             }
@@ -51,22 +60,34 @@ fn criterion_benchmark(c: &mut Criterion) {
     });
 
     group.bench_function("StreamReader/read_10/zstd", |b| {
-        let batch = create_batch(8192, true);
-        let mut buffer = Vec::with_capacity(2 * 1024 * 1024);
-        let options = IpcWriteOptions::default()
-            .try_with_compression(Some(CompressionType::ZSTD))
-            .unwrap();
-        let mut writer =
-            StreamWriter::try_new_with_options(&mut buffer, 
batch.schema().as_ref(), options)
-                .unwrap();
-        for _ in 0..10 {
-            writer.write(&batch).unwrap();
-        }
-        writer.finish().unwrap();
+        let buffer = ipc_stream(
+            IpcWriteOptions::default()
+                .try_with_compression(Some(CompressionType::ZSTD))
+                .unwrap(),
+        );
+        b.iter(move || {
+            let projection = None;
+            let mut reader = StreamReader::try_new(buffer.as_slice(), 
projection).unwrap();
+            for _ in 0..10 {
+                reader.next().unwrap().unwrap();
+            }
+            assert!(reader.next().is_none());
+        })
+    });
 
+    group.bench_function("StreamReader/no_validation/read_10/zstd", |b| {
+        let buffer = ipc_stream(
+            IpcWriteOptions::default()
+                .try_with_compression(Some(CompressionType::ZSTD))
+                .unwrap(),
+        );
         b.iter(move || {
             let projection = None;
             let mut reader = StreamReader::try_new(buffer.as_slice(), 
projection).unwrap();
+            unsafe {
+                // safety: we created a valid IPC file
+                reader = reader.with_skip_validation(true);
+            }
             for _ in 0..10 {
                 reader.next().unwrap().unwrap();
             }
@@ -74,19 +95,30 @@ fn criterion_benchmark(c: &mut Criterion) {
         })
     });
 
+    // --- Create IPC File ---
     group.bench_function("FileReader/read_10", |b| {
-        let batch = create_batch(8192, true);
-        let mut buffer = Vec::with_capacity(2 * 1024 * 1024);
-        let mut writer = FileWriter::try_new(&mut buffer, 
batch.schema().as_ref()).unwrap();
-        for _ in 0..10 {
-            writer.write(&batch).unwrap();
-        }
-        writer.finish().unwrap();
+        let buffer = ipc_file();
+        b.iter(move || {
+            let projection = None;
+            let cursor = Cursor::new(buffer.as_slice());
+            let mut reader = FileReader::try_new(cursor, projection).unwrap();
+            for _ in 0..10 {
+                reader.next().unwrap().unwrap();
+            }
+            assert!(reader.next().is_none());
+        })
+    });
 
+    group.bench_function("FileReader/no_validation/read_10", |b| {
+        let buffer = ipc_file();
         b.iter(move || {
             let projection = None;
             let cursor = Cursor::new(buffer.as_slice());
             let mut reader = FileReader::try_new(cursor, projection).unwrap();
+            unsafe {
+                // safety: we created a valid IPC file
+                reader = reader.with_skip_validation(true);
+            }
             for _ in 0..10 {
                 reader.next().unwrap().unwrap();
             }
@@ -94,26 +126,42 @@ fn criterion_benchmark(c: &mut Criterion) {
         })
     });
 
+    // write to an actual file
+    let dir = tempdir().unwrap();
+    let path = dir.path().join("test.arrow");
+    let mut file = std::fs::File::create(&path).unwrap();
+    file.write_all(&ipc_file()).unwrap();
+    drop(file);
+
     group.bench_function("FileReader/read_10/mmap", |b| {
-        let batch = create_batch(8192, true);
-        // write to an actual file
-        let dir = tempdir().unwrap();
-        let path = dir.path().join("test.arrow");
-        let file = std::fs::File::create(&path).unwrap();
-        let mut writer = FileWriter::try_new(file, 
batch.schema().as_ref()).unwrap();
-        for _ in 0..10 {
-            writer.write(&batch).unwrap();
-        }
-        writer.finish().unwrap();
+        let path = &path;
+        b.iter(move || {
+            let ipc_file = std::fs::File::open(path).expect("failed to open 
file");
+            let mmap = unsafe { memmap2::Mmap::map(&ipc_file).expect("failed 
to mmap file") };
+
+            // Convert the mmap region to an Arrow `Buffer` to back the arrow 
arrays.
+            let bytes = bytes::Bytes::from_owner(mmap);
+            let buffer = Buffer::from(bytes);
+            let decoder = IPCBufferDecoder::new(buffer);
+            assert_eq!(decoder.num_batches(), 10);
 
+            for i in 0..decoder.num_batches() {
+                decoder.get_batch(i);
+            }
+        })
+    });
+
+    group.bench_function("FileReader/no_validation/read_10/mmap", |b| {
+        let path = &path;
         b.iter(move || {
-            let ipc_file = std::fs::File::open(&path).expect("failed to open 
file");
+            let ipc_file = std::fs::File::open(path).expect("failed to open 
file");
             let mmap = unsafe { memmap2::Mmap::map(&ipc_file).expect("failed 
to mmap file") };
 
             // Convert the mmap region to an Arrow `Buffer` to back the arrow 
arrays.
             let bytes = bytes::Bytes::from_owner(mmap);
             let buffer = Buffer::from(bytes);
             let decoder = IPCBufferDecoder::new(buffer);
+            let decoder = unsafe { decoder.with_skip_validation(true) };
             assert_eq!(decoder.num_batches(), 10);
 
             for i in 0..decoder.num_batches() {
@@ -123,6 +171,31 @@ fn criterion_benchmark(c: &mut Criterion) {
     });
 }
 
+/// Return an IPC stream with 10 record batches
+fn ipc_stream(options: IpcWriteOptions) -> Vec<u8> {
+    let batch = create_batch(8192, true);
+    let mut buffer = Vec::with_capacity(2 * 1024 * 1024);
+    let mut writer =
+        StreamWriter::try_new_with_options(&mut buffer, 
batch.schema().as_ref(), options).unwrap();
+    for _ in 0..10 {
+        writer.write(&batch).unwrap();
+    }
+    writer.finish().unwrap();
+    buffer
+}
+
+/// Return an IPC file with 10 record batches
+fn ipc_file() -> Vec<u8> {
+    let batch = create_batch(8192, true);
+    let mut buffer = Vec::with_capacity(2 * 1024 * 1024);
+    let mut writer = FileWriter::try_new(&mut buffer, 
batch.schema().as_ref()).unwrap();
+    for _ in 0..10 {
+        writer.write(&batch).unwrap();
+    }
+    writer.finish().unwrap();
+    buffer
+}
+
 // copied from the zero_copy_ipc example.
 // should we move this to an actual API?
 /// Wrapper around the example in the `FileDecoder` which handles the
@@ -166,6 +239,11 @@ impl IPCBufferDecoder {
         }
     }
 
+    unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
+        self.decoder = self.decoder.with_skip_validation(skip_validation);
+        self
+    }
+
     fn num_batches(&self) -> usize {
         self.batches.len()
     }
diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs
index ddda179cbe..695ead4fab 100644
--- a/arrow-ipc/src/reader.rs
+++ b/arrow-ipc/src/reader.rs
@@ -36,7 +36,7 @@ use std::sync::Arc;
 
 use arrow_array::*;
 use arrow_buffer::{ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, 
ScalarBuffer};
-use arrow_data::ArrayData;
+use arrow_data::{ArrayData, ArrayDataBuilder, UnsafeFlag};
 use arrow_schema::*;
 
 use crate::compression::CompressionCodec;
@@ -136,24 +136,7 @@ impl RecordBatchDecoder<'_> {
                     let child = self.create_array(struct_field, 
variadic_counts)?;
                     struct_arrays.push(child);
                 }
-                let null_count = struct_node.null_count() as usize;
-                let struct_array = if struct_arrays.is_empty() {
-                    // `StructArray::from` can't infer the correct row count
-                    // if we have zero fields
-                    let len = struct_node.length() as usize;
-                    StructArray::new_empty_fields(
-                        len,
-                        (null_count > 0).then(|| 
BooleanBuffer::new(null_buffer, 0, len).into()),
-                    )
-                } else if null_count > 0 {
-                    // create struct array from fields, arrays and null data
-                    let len = struct_node.length() as usize;
-                    let nulls = BooleanBuffer::new(null_buffer, 0, len).into();
-                    StructArray::try_new(struct_fields.clone(), struct_arrays, 
Some(nulls))?
-                } else {
-                    StructArray::try_new(struct_fields.clone(), struct_arrays, 
None)?
-                };
-                Ok(Arc::new(struct_array))
+                self.create_struct_array(struct_node, null_buffer, 
struct_fields, struct_arrays)
             }
             RunEndEncoded(run_ends_field, values_field) => {
                 let run_node = self.next_node(field)?;
@@ -161,15 +144,12 @@ impl RecordBatchDecoder<'_> {
                 let values = self.create_array(values_field, variadic_counts)?;
 
                 let run_array_length = run_node.length() as usize;
-                let array_data = ArrayData::builder(data_type.clone())
+                let builder = ArrayData::builder(data_type.clone())
                     .len(run_array_length)
                     .offset(0)
                     .add_child_data(run_ends.into_data())
-                    .add_child_data(values.into_data())
-                    .align_buffers(!self.require_alignment)
-                    .build()?;
-
-                Ok(make_array(array_data))
+                    .add_child_data(values.into_data());
+                self.create_array_from_builder(builder)
             }
             // Create dictionary array from RecordBatch
             Dictionary(_, _) => {
@@ -223,7 +203,14 @@ impl RecordBatchDecoder<'_> {
                     children.push(child);
                 }
 
-                let array = UnionArray::try_new(fields.clone(), type_ids, 
value_offsets, children)?;
+                let array = if self.skip_validation.get() {
+                    // safety: flag can only be set via unsafe code
+                    unsafe {
+                        UnionArray::new_unchecked(fields.clone(), type_ids, 
value_offsets, children)
+                    }
+                } else {
+                    UnionArray::try_new(fields.clone(), type_ids, 
value_offsets, children)?
+                };
                 Ok(Arc::new(array))
             }
             Null => {
@@ -237,14 +224,10 @@ impl RecordBatchDecoder<'_> {
                     )));
                 }
 
-                let array_data = ArrayData::builder(data_type.clone())
+                let builder = ArrayData::builder(data_type.clone())
                     .len(length as usize)
-                    .offset(0)
-                    .align_buffers(!self.require_alignment)
-                    .build()?;
-
-                // no buffer increases
-                Ok(Arc::new(NullArray::from(array_data)))
+                    .offset(0);
+                self.create_array_from_builder(builder)
             }
             _ => {
                 let field_node = self.next_node(field)?;
@@ -286,9 +269,17 @@ impl RecordBatchDecoder<'_> {
             t => unreachable!("Data type {:?} either unsupported or not 
primitive", t),
         };
 
-        let array_data = 
builder.align_buffers(!self.require_alignment).build()?;
+        self.create_array_from_builder(builder)
+    }
 
-        Ok(make_array(array_data))
+    /// Update the ArrayDataBuilder based on settings in this decoder
+    fn create_array_from_builder(&self, builder: ArrayDataBuilder) -> 
Result<ArrayRef, ArrowError> {
+        let mut builder = builder.align_buffers(!self.require_alignment);
+        if self.skip_validation.get() {
+            // SAFETY: flag can only be set via unsafe code
+            unsafe { builder = builder.skip_validation(true) }
+        };
+        Ok(make_array(builder.build()?))
     }
 
     /// Reads the correct number of buffers based on list type and null_count, 
and creates a
@@ -318,9 +309,34 @@ impl RecordBatchDecoder<'_> {
             _ => unreachable!("Cannot create list or map array from {:?}", 
data_type),
         };
 
-        let array_data = 
builder.align_buffers(!self.require_alignment).build()?;
+        self.create_array_from_builder(builder)
+    }
+
+    fn create_struct_array(
+        &self,
+        struct_node: &FieldNode,
+        null_buffer: Buffer,
+        struct_fields: &Fields,
+        struct_arrays: Vec<ArrayRef>,
+    ) -> Result<ArrayRef, ArrowError> {
+        let null_count = struct_node.null_count() as usize;
+        let len = struct_node.length() as usize;
+
+        let nulls = (null_count > 0).then(|| BooleanBuffer::new(null_buffer, 
0, len).into());
+        if struct_arrays.is_empty() {
+            // `StructArray::from` can't infer the correct row count
+            // if we have zero fields
+            return Ok(Arc::new(StructArray::new_empty_fields(len, nulls)));
+        }
 
-        Ok(make_array(array_data))
+        let struct_array = if self.skip_validation.get() {
+            // safety: flag can only be set via unsafe code
+            unsafe { StructArray::new_unchecked(struct_fields.clone(), 
struct_arrays, nulls) }
+        } else {
+            StructArray::try_new(struct_fields.clone(), struct_arrays, nulls)?
+        };
+
+        Ok(Arc::new(struct_array))
     }
 
     /// Reads the correct number of buffers based on list type and null_count, 
and creates a
@@ -334,15 +350,12 @@ impl RecordBatchDecoder<'_> {
     ) -> Result<ArrayRef, ArrowError> {
         if let Dictionary(_, _) = *data_type {
             let null_buffer = (field_node.null_count() > 
0).then_some(buffers[0].clone());
-            let array_data = ArrayData::builder(data_type.clone())
+            let builder = ArrayData::builder(data_type.clone())
                 .len(field_node.length() as usize)
                 .add_buffer(buffers[1].clone())
                 .add_child_data(value_array.into_data())
-                .null_bit_buffer(null_buffer)
-                .align_buffers(!self.require_alignment)
-                .build()?;
-
-            Ok(make_array(array_data))
+                .null_bit_buffer(null_buffer);
+            self.create_array_from_builder(builder)
         } else {
             unreachable!("Cannot create dictionary array from {:?}", data_type)
         }
@@ -376,6 +389,10 @@ struct RecordBatchDecoder<'a> {
     /// Are buffers required to already be aligned? See
     /// [`RecordBatchDecoder::with_require_alignment`] for details
     require_alignment: bool,
+    /// Should validation be skipped when reading data? Defaults to false.
+    ///
+    /// See [`FileDecoder::with_skip_validation`] for details.
+    skip_validation: UnsafeFlag,
 }
 
 impl<'a> RecordBatchDecoder<'a> {
@@ -410,6 +427,7 @@ impl<'a> RecordBatchDecoder<'a> {
             buffers: buffers.iter(),
             projection: None,
             require_alignment: false,
+            skip_validation: UnsafeFlag::new(),
         })
     }
 
@@ -432,6 +450,22 @@ impl<'a> RecordBatchDecoder<'a> {
         self
     }
 
+    /// Specifies if validation should be skipped when reading data (defaults 
to `false`)
+    ///
+    /// Note this API is somewhat "funky" as it allows the caller to skip 
validation
+    /// without having to use `unsafe` code. If this is ever made public
+    /// it should be made clearer that this is a potentially unsafe by
+    /// using an `unsafe` function that takes a boolean flag.
+    ///
+    /// # Safety
+    ///
+    /// Relies on the caller only passing a flag with `true` value if they are
+    /// certain that the data is valid
+    pub(crate) fn with_skip_validation(mut self, skip_validation: UnsafeFlag) 
-> Self {
+        self.skip_validation = skip_validation;
+        self
+    }
+
     /// Read the record batch, consuming the reader
     fn read_record_batch(mut self) -> Result<RecordBatch, ArrowError> {
         let mut variadic_counts: VecDeque<i64> = self
@@ -601,7 +635,15 @@ pub fn read_dictionary(
     dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
     metadata: &MetadataVersion,
 ) -> Result<(), ArrowError> {
-    read_dictionary_impl(buf, batch, schema, dictionaries_by_id, metadata, 
false)
+    read_dictionary_impl(
+        buf,
+        batch,
+        schema,
+        dictionaries_by_id,
+        metadata,
+        false,
+        UnsafeFlag::new(),
+    )
 }
 
 fn read_dictionary_impl(
@@ -611,6 +653,7 @@ fn read_dictionary_impl(
     dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
     metadata: &MetadataVersion,
     require_alignment: bool,
+    skip_validation: UnsafeFlag,
 ) -> Result<(), ArrowError> {
     if batch.isDelta() {
         return Err(ArrowError::InvalidArgumentError(
@@ -642,6 +685,7 @@ fn read_dictionary_impl(
                 metadata,
             )?
             .with_require_alignment(require_alignment)
+            .with_skip_validation(skip_validation)
             .read_record_batch()?;
 
             Some(record_batch.column(0).clone())
@@ -772,6 +816,7 @@ pub struct FileDecoder {
     version: MetadataVersion,
     projection: Option<Vec<usize>>,
     require_alignment: bool,
+    skip_validation: UnsafeFlag,
 }
 
 impl FileDecoder {
@@ -783,6 +828,7 @@ impl FileDecoder {
             dictionaries: Default::default(),
             projection: None,
             require_alignment: false,
+            skip_validation: UnsafeFlag::new(),
         }
     }
 
@@ -792,7 +838,7 @@ impl FileDecoder {
         self
     }
 
-    /// Specifies whether or not array data in input buffers is required to be 
properly aligned.
+    /// Specifies if the array data in input buffers is required to be 
properly aligned.
     ///
     /// If `require_alignment` is true, this decoder will return an error if 
any array data in the
     /// input `buf` is not properly aligned.
@@ -809,6 +855,21 @@ impl FileDecoder {
         self
     }
 
+    /// Specifies if validation should be skipped when reading data (defaults 
to `false`)
+    ///
+    /// # Safety
+    ///
+    /// This flag must only be set to `true` when you trust the input data and 
are sure the data you are
+    /// reading is a valid Arrow IPC file, otherwise undefined behavior may
+    /// result.
+    ///
+    /// For example, some programs may wish to trust reading IPC files written
+    /// by the same process that created the files.
+    pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> 
Self {
+        self.skip_validation.set(skip_validation);
+        self
+    }
+
     fn read_message<'a>(&self, buf: &'a [u8]) -> Result<Message<'a>, 
ArrowError> {
         let message = parse_message(buf)?;
 
@@ -834,6 +895,7 @@ impl FileDecoder {
                     &mut self.dictionaries,
                     &message.version(),
                     self.require_alignment,
+                    self.skip_validation.clone(),
                 )
             }
             t => Err(ArrowError::ParseError(format!(
@@ -867,6 +929,7 @@ impl FileDecoder {
                 )?
                 .with_projection(self.projection.as_deref())
                 .with_require_alignment(self.require_alignment)
+                .with_skip_validation(self.skip_validation.clone())
                 .read_record_batch()
                 .map(Some)
             }
@@ -1177,6 +1240,16 @@ impl<R: Read + Seek> FileReader<R> {
     pub fn get_mut(&mut self) -> &mut R {
         &mut self.reader
     }
+
+    /// Specifies if validation should be skipped when reading data (defaults 
to `false`)
+    ///
+    /// # Safety
+    ///
+    /// See [`FileDecoder::with_skip_validation`]
+    pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> 
Self {
+        self.decoder = self.decoder.with_skip_validation(skip_validation);
+        self
+    }
 }
 
 impl<R: Read + Seek> Iterator for FileReader<R> {
@@ -1250,6 +1323,11 @@ pub struct StreamReader<R> {
 
     /// Optional projection
     projection: Option<(Vec<usize>, Schema)>,
+
+    /// Should validation be skipped when reading data? Defaults to false.
+    ///
+    /// See [`FileDecoder::with_skip_validation`] for details.
+    skip_validation: UnsafeFlag,
 }
 
 impl<R> fmt::Debug for StreamReader<R> {
@@ -1329,6 +1407,7 @@ impl<R: Read> StreamReader<R> {
             finished: false,
             dictionaries_by_id,
             projection,
+            skip_validation: UnsafeFlag::new(),
         })
     }
 
@@ -1417,6 +1496,7 @@ impl<R: Read> StreamReader<R> {
                 )?
                 .with_projection(self.projection.as_ref().map(|x| 
x.0.as_ref()))
                 .with_require_alignment(false)
+                .with_skip_validation(self.skip_validation.clone())
                 .read_record_batch()
                 .map(Some)
             }
@@ -1437,6 +1517,7 @@ impl<R: Read> StreamReader<R> {
                     &mut self.dictionaries_by_id,
                     &message.version(),
                     false,
+                    self.skip_validation.clone(),
                 )?;
 
                 // read the next message until we encounter a RecordBatch
@@ -1462,6 +1543,16 @@ impl<R: Read> StreamReader<R> {
     pub fn get_mut(&mut self) -> &mut R {
         &mut self.reader
     }
+
+    /// Specifies if validation should be skipped when reading data (defaults 
to `false`)
+    ///
+    /// # Safety
+    ///
+    /// See [`FileDecoder::with_skip_validation`]
+    pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> 
Self {
+        self.skip_validation.set(skip_validation);
+        self
+    }
 }
 
 impl<R: Read> Iterator for StreamReader<R> {
@@ -1740,6 +1831,15 @@ mod tests {
         reader.next().unwrap()
     }
 
+    /// Return the first record batch read from the IPC File buffer, disabling
+    /// validation
+    fn read_ipc_skip_validation(buf: &[u8]) -> Result<RecordBatch, ArrowError> 
{
+        let mut reader = unsafe {
+            FileReader::try_new(std::io::Cursor::new(buf), 
None)?.with_skip_validation(true)
+        };
+        reader.next().unwrap()
+    }
+
     fn roundtrip_ipc(rb: &RecordBatch) -> RecordBatch {
         let buf = write_ipc(rb);
         read_ipc(&buf).unwrap()
@@ -1748,6 +1848,19 @@ mod tests {
     /// Return the first record batch read from the IPC File buffer
     /// using the FileDecoder API
     fn read_ipc_with_decoder(buf: Vec<u8>) -> Result<RecordBatch, ArrowError> {
+        read_ipc_with_decoder_inner(buf, false)
+    }
+
+    /// Return the first record batch read from the IPC File buffer
+    /// using the FileDecoder API, disabling validation
+    fn read_ipc_with_decoder_skip_validation(buf: Vec<u8>) -> 
Result<RecordBatch, ArrowError> {
+        read_ipc_with_decoder_inner(buf, true)
+    }
+
+    fn read_ipc_with_decoder_inner(
+        buf: Vec<u8>,
+        skip_validation: bool,
+    ) -> Result<RecordBatch, ArrowError> {
         let buffer = Buffer::from_vec(buf);
         let trailer_start = buffer.len() - 10;
         let footer_len = 
read_footer_length(buffer[trailer_start..].try_into().unwrap())?;
@@ -1756,7 +1869,10 @@ mod tests {
 
         let schema = fb_to_schema(footer.schema().unwrap());
 
-        let mut decoder = FileDecoder::new(Arc::new(schema), footer.version());
+        let mut decoder = unsafe {
+            FileDecoder::new(Arc::new(schema), footer.version())
+                .with_skip_validation(skip_validation)
+        };
         // Read dictionaries
         for block in footer.dictionaries().iter().flatten() {
             let block_len = block.bodyLength() as usize + 
block.metaDataLength() as usize;
@@ -1789,6 +1905,15 @@ mod tests {
         reader.next().unwrap()
     }
 
+    /// Return the first record batch read from the IPC Stream buffer,
+    /// disabling validation
+    fn read_stream_skip_validation(buf: &[u8]) -> Result<RecordBatch, 
ArrowError> {
+        let mut reader = unsafe {
+            StreamReader::try_new(std::io::Cursor::new(buf), 
None)?.with_skip_validation(true)
+        };
+        reader.next().unwrap()
+    }
+
     fn roundtrip_ipc_stream(rb: &RecordBatch) -> RecordBatch {
         let buf = write_stream(rb);
         read_stream(&buf).unwrap()
@@ -2456,6 +2581,57 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_invalid_nested_array_ipc_read_errors() {
+        // one of the nested arrays has invalid data
+        let a_field = Field::new("a", DataType::Int32, false);
+        let b_field = Field::new("b", DataType::Utf8, false);
+
+        let schema = Arc::new(Schema::new(vec![Field::new_struct(
+            "s",
+            vec![a_field.clone(), b_field.clone()],
+            false,
+        )]));
+
+        let a_array_data = ArrayData::builder(a_field.data_type().clone())
+            .len(4)
+            .add_buffer(Buffer::from_slice_ref([1, 2, 3, 4]))
+            .build()
+            .unwrap();
+        // invalid nested child array -- length is correct, but has invalid 
utf8 data
+        let b_array_data = {
+            let valid: &[u8] = b"   ";
+            let mut invalid = vec![];
+            invalid.extend_from_slice(b"ValidString");
+            invalid.extend_from_slice(INVALID_UTF8_FIRST_CHAR);
+            let binary_array =
+                BinaryArray::from_iter(vec![None, Some(valid), None, 
Some(&invalid)]);
+            let array = unsafe {
+                StringArray::new_unchecked(
+                    binary_array.offsets().clone(),
+                    binary_array.values().clone(),
+                    binary_array.nulls().cloned(),
+                )
+            };
+            array.into_data()
+        };
+        let struct_data_type = schema.field(0).data_type();
+
+        let invalid_struct_arr = unsafe {
+            make_array(
+                ArrayData::builder(struct_data_type.clone())
+                    .len(4)
+                    .add_child_data(a_array_data)
+                    .add_child_data(b_array_data)
+                    .build_unchecked(),
+            )
+        };
+        expect_ipc_validation_error(
+            Arc::new(invalid_struct_arr),
+            "Invalid argument error: Invalid UTF8 sequence at string index 3 
(3..18): invalid utf-8 sequence of 1 bytes from index 11",
+        );
+    }
+
     #[test]
     fn test_same_dict_id_without_preserve() {
         let batch = RecordBatch::try_new(
@@ -2592,6 +2768,32 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_validation_of_invalid_union_array() {
+        let array = unsafe {
+            let fields = UnionFields::new(
+                vec![1, 3], // typeids : type id 2 is not valid
+                vec![
+                    Field::new("a", DataType::Int32, false),
+                    Field::new("b", DataType::Utf8, false),
+                ],
+            );
+            let type_ids = ScalarBuffer::from(vec![1i8, 2, 3]); // 2 is invalid
+            let offsets = None;
+            let children: Vec<ArrayRef> = vec![
+                Arc::new(Int32Array::from(vec![10, 20, 30])),
+                Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")])),
+            ];
+
+            UnionArray::new_unchecked(fields, type_ids, offsets, children)
+        };
+
+        expect_ipc_validation_error(
+            Arc::new(array),
+            "Invalid argument error: Type Ids values must match one of the 
field type ids",
+        );
+    }
+
     /// Invalid Utf-8 sequence in the first character
     /// <https://stackoverflow.com/questions/1301402/example-invalid-utf8-string>
     const INVALID_UTF8_FIRST_CHAR: &[u8] = &[0xa0, 0xa1, 0x20, 0x20];
@@ -2602,18 +2804,18 @@ mod tests {
 
         // IPC Stream format
         let buf = write_stream(&rb); // write is ok
+        read_stream_skip_validation(&buf).unwrap();
         let err = read_stream(&buf).unwrap_err();
         assert_eq!(err.to_string(), expected_err);
 
         // IPC File format
         let buf = write_ipc(&rb); // write is ok
+        read_ipc_skip_validation(&buf).unwrap();
         let err = read_ipc(&buf).unwrap_err();
         assert_eq!(err.to_string(), expected_err);
 
-        // TODO verify there is no error when validation is disabled
-        // see https://github.com/apache/arrow-rs/issues/3287
-
         // IPC Format with FileDecoder
+        read_ipc_with_decoder_skip_validation(buf.clone()).unwrap();
         let err = read_ipc_with_decoder(buf).unwrap_err();
         assert_eq!(err.to_string(), expected_err);
     }
diff --git a/arrow-ipc/src/reader/stream.rs b/arrow-ipc/src/reader/stream.rs
index 174e69c1f6..5902cbe4e0 100644
--- a/arrow-ipc/src/reader/stream.rs
+++ b/arrow-ipc/src/reader/stream.rs
@@ -21,6 +21,7 @@ use std::sync::Arc;
 
 use arrow_array::{ArrayRef, RecordBatch};
 use arrow_buffer::{Buffer, MutableBuffer};
+use arrow_data::UnsafeFlag;
 use arrow_schema::{ArrowError, SchemaRef};
 
 use crate::convert::MessageBuffer;
@@ -42,6 +43,12 @@ pub struct StreamDecoder {
     buf: MutableBuffer,
     /// Whether or not array data in input buffers are required to be aligned
     require_alignment: bool,
+    /// Should validation be skipped when reading data? Defaults to false.
+    ///
+    /// See [`FileDecoder::with_skip_validation`] for details.
+    ///
+    /// [`FileDecoder::with_skip_validation`]: crate::reader::FileDecoder::with_skip_validation
+    skip_validation: UnsafeFlag,
 }
 
 #[derive(Debug)]
@@ -235,6 +242,7 @@ impl StreamDecoder {
                                 &mut self.dictionaries,
                                 &version,
                                 self.require_alignment,
+                                self.skip_validation.clone(),
                             )?;
                             self.state = DecoderState::default();
                         }


Reply via email to