This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push:
new 2fd51814df Add `with_skip_validation` flag to IPC `StreamReader`, `FileReader` and `FileDecoder` (#7120)
2fd51814df is described below
commit 2fd51814df6f85d74179be845e4d2b4cf2bff4e3
Author: Andrew Lamb <[email protected]>
AuthorDate: Thu Feb 27 07:24:54 2025 -0500
Add `with_skip_validation` flag to IPC `StreamReader`, `FileReader` and `FileDecoder` (#7120)
* Make UnsafeFlag pub
* Add `with_skip_validation` flag to IPC `StreamReader`, `FileReader` and `FileDecoder`
* Apply suggestions from code review
Co-authored-by: Raphael Taylor-Davies <[email protected]>
* Update arrow-ipc/src/reader.rs
Co-authored-by: Raphael Taylor-Davies <[email protected]>
* Update arrow-ipc/src/reader.rs
Co-authored-by: Raphael Taylor-Davies <[email protected]>
* Add notes to RecordBatchDecoder::with_skip_validation flag
* remove repetition in read benchmark
---------
Co-authored-by: Raphael Taylor-Davies <[email protected]>
---
arrow-data/src/data.rs | 74 +++++++---
arrow-ipc/benches/ipc_reader.rs | 154 ++++++++++++++++-----
arrow-ipc/src/reader.rs | 300 +++++++++++++++++++++++++++++++++-------
arrow-ipc/src/reader/stream.rs | 8 ++
4 files changed, 426 insertions(+), 110 deletions(-)
diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs
index 376b6f9584..a76e00e267 100644
--- a/arrow-data/src/data.rs
+++ b/arrow-data/src/data.rs
@@ -28,7 +28,6 @@ use std::mem;
use std::ops::Range;
use std::sync::Arc;
-use crate::data::private::UnsafeFlag;
use crate::{equal, validate_binary_view, validate_string_view};
#[inline]
@@ -1781,33 +1780,62 @@ impl PartialEq for ArrayData {
}
}
-mod private {
- /// A boolean flag that cannot be mutated outside of unsafe code.
+/// A boolean flag that cannot be mutated outside of unsafe code.
+///
+/// Defaults to a value of false.
+///
+/// This structure is used to enforce safety in the [`ArrayDataBuilder`]
+///
+/// [`ArrayDataBuilder`]: super::ArrayDataBuilder
+///
+/// # Example
+/// ```rust
+/// use arrow_data::UnsafeFlag;
+/// assert!(!UnsafeFlag::default().get()); // default is false
+/// let mut flag = UnsafeFlag::new();
+/// assert!(!flag.get()); // defaults to false
+/// // can only set it to true in unsafe code
+/// unsafe { flag.set(true) };
+/// assert!(flag.get()); // now true
+/// ```
+#[derive(Debug, Clone)]
+#[doc(hidden)]
+pub struct UnsafeFlag(bool);
+
+impl UnsafeFlag {
+ /// Creates a new `UnsafeFlag` with the value set to `false`.
+ ///
+ /// See examples on [`Self::new`]
+ #[inline]
+ pub const fn new() -> Self {
+ Self(false)
+ }
+
+ /// Sets the value of the flag to the given value
///
- /// Defaults to a value of false.
+ /// Note this can purposely only be done in `unsafe` code
///
- /// This structure is used to enforce safety in the [`ArrayDataBuilder`]
+ /// # Safety
///
- /// [`ArrayDataBuilder`]: super::ArrayDataBuilder
- #[derive(Debug)]
- pub struct UnsafeFlag(bool);
-
- impl UnsafeFlag {
- /// Creates a new `UnsafeFlag` with the value set to `false`
- #[inline]
- pub const fn new() -> Self {
- Self(false)
- }
+ /// If set, the flag will be set to the given value. There is nothing
+ /// immediately unsafe about doing so, however, the flag can be used to
+ /// subsequently bypass safety checks in the [`ArrayDataBuilder`].
+ #[inline]
+ pub unsafe fn set(&mut self, val: bool) {
+ self.0 = val;
+ }
- #[inline]
- pub unsafe fn set(&mut self, val: bool) {
- self.0 = val;
- }
+ /// Returns the value of the flag
+ #[inline]
+ pub fn get(&self) -> bool {
+ self.0
+ }
+}
- #[inline]
- pub fn get(&self) -> bool {
- self.0
- }
+// Manual impl to make it clear you can not construct unsafe with true
+impl Default for UnsafeFlag {
+ fn default() -> Self {
+ Self::new()
}
}
diff --git a/arrow-ipc/benches/ipc_reader.rs b/arrow-ipc/benches/ipc_reader.rs
index 7fc14664b4..ab77449eeb 100644
--- a/arrow-ipc/benches/ipc_reader.rs
+++ b/arrow-ipc/benches/ipc_reader.rs
@@ -24,7 +24,7 @@ use arrow_ipc::writer::{FileWriter, IpcWriteOptions,
StreamWriter};
use arrow_ipc::{root_as_footer, Block, CompressionType};
use arrow_schema::{DataType, Field, Schema};
use criterion::{criterion_group, criterion_main, Criterion};
-use std::io::Cursor;
+use std::io::{Cursor, Write};
use std::sync::Arc;
use tempfile::tempdir;
@@ -32,17 +32,26 @@ fn criterion_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("arrow_ipc_reader");
group.bench_function("StreamReader/read_10", |b| {
- let batch = create_batch(8192, true);
- let mut buffer = Vec::with_capacity(2 * 1024 * 1024);
- let mut writer = StreamWriter::try_new(&mut buffer,
batch.schema().as_ref()).unwrap();
- for _ in 0..10 {
- writer.write(&batch).unwrap();
- }
- writer.finish().unwrap();
+ let buffer = ipc_stream(IpcWriteOptions::default());
+ b.iter(move || {
+ let projection = None;
+ let mut reader = StreamReader::try_new(buffer.as_slice(),
projection).unwrap();
+ for _ in 0..10 {
+ reader.next().unwrap().unwrap();
+ }
+ assert!(reader.next().is_none());
+ })
+ });
+ group.bench_function("StreamReader/no_validation/read_10", |b| {
+ let buffer = ipc_stream(IpcWriteOptions::default());
b.iter(move || {
let projection = None;
let mut reader = StreamReader::try_new(buffer.as_slice(),
projection).unwrap();
+ unsafe {
+ // safety: we created a valid IPC file
+ reader = reader.with_skip_validation(true);
+ }
for _ in 0..10 {
reader.next().unwrap().unwrap();
}
@@ -51,22 +60,34 @@ fn criterion_benchmark(c: &mut Criterion) {
});
group.bench_function("StreamReader/read_10/zstd", |b| {
- let batch = create_batch(8192, true);
- let mut buffer = Vec::with_capacity(2 * 1024 * 1024);
- let options = IpcWriteOptions::default()
- .try_with_compression(Some(CompressionType::ZSTD))
- .unwrap();
- let mut writer =
- StreamWriter::try_new_with_options(&mut buffer,
batch.schema().as_ref(), options)
- .unwrap();
- for _ in 0..10 {
- writer.write(&batch).unwrap();
- }
- writer.finish().unwrap();
+ let buffer = ipc_stream(
+ IpcWriteOptions::default()
+ .try_with_compression(Some(CompressionType::ZSTD))
+ .unwrap(),
+ );
+ b.iter(move || {
+ let projection = None;
+ let mut reader = StreamReader::try_new(buffer.as_slice(),
projection).unwrap();
+ for _ in 0..10 {
+ reader.next().unwrap().unwrap();
+ }
+ assert!(reader.next().is_none());
+ })
+ });
+ group.bench_function("StreamReader/no_validation/read_10/zstd", |b| {
+ let buffer = ipc_stream(
+ IpcWriteOptions::default()
+ .try_with_compression(Some(CompressionType::ZSTD))
+ .unwrap(),
+ );
b.iter(move || {
let projection = None;
let mut reader = StreamReader::try_new(buffer.as_slice(),
projection).unwrap();
+ unsafe {
+ // safety: we created a valid IPC file
+ reader = reader.with_skip_validation(true);
+ }
for _ in 0..10 {
reader.next().unwrap().unwrap();
}
@@ -74,19 +95,30 @@ fn criterion_benchmark(c: &mut Criterion) {
})
});
+ // --- Create IPC File ---
group.bench_function("FileReader/read_10", |b| {
- let batch = create_batch(8192, true);
- let mut buffer = Vec::with_capacity(2 * 1024 * 1024);
- let mut writer = FileWriter::try_new(&mut buffer,
batch.schema().as_ref()).unwrap();
- for _ in 0..10 {
- writer.write(&batch).unwrap();
- }
- writer.finish().unwrap();
+ let buffer = ipc_file();
+ b.iter(move || {
+ let projection = None;
+ let cursor = Cursor::new(buffer.as_slice());
+ let mut reader = FileReader::try_new(cursor, projection).unwrap();
+ for _ in 0..10 {
+ reader.next().unwrap().unwrap();
+ }
+ assert!(reader.next().is_none());
+ })
+ });
+ group.bench_function("FileReader/no_validation/read_10", |b| {
+ let buffer = ipc_file();
b.iter(move || {
let projection = None;
let cursor = Cursor::new(buffer.as_slice());
let mut reader = FileReader::try_new(cursor, projection).unwrap();
+ unsafe {
+ // safety: we created a valid IPC file
+ reader = reader.with_skip_validation(true);
+ }
for _ in 0..10 {
reader.next().unwrap().unwrap();
}
@@ -94,26 +126,42 @@ fn criterion_benchmark(c: &mut Criterion) {
})
});
+ // write to an actual file
+ let dir = tempdir().unwrap();
+ let path = dir.path().join("test.arrow");
+ let mut file = std::fs::File::create(&path).unwrap();
+ file.write_all(&ipc_file()).unwrap();
+ drop(file);
+
group.bench_function("FileReader/read_10/mmap", |b| {
- let batch = create_batch(8192, true);
- // write to an actual file
- let dir = tempdir().unwrap();
- let path = dir.path().join("test.arrow");
- let file = std::fs::File::create(&path).unwrap();
- let mut writer = FileWriter::try_new(file,
batch.schema().as_ref()).unwrap();
- for _ in 0..10 {
- writer.write(&batch).unwrap();
- }
- writer.finish().unwrap();
+ let path = &path;
+ b.iter(move || {
+ let ipc_file = std::fs::File::open(path).expect("failed to open
file");
+ let mmap = unsafe { memmap2::Mmap::map(&ipc_file).expect("failed
to mmap file") };
+
+ // Convert the mmap region to an Arrow `Buffer` to back the arrow
arrays.
+ let bytes = bytes::Bytes::from_owner(mmap);
+ let buffer = Buffer::from(bytes);
+ let decoder = IPCBufferDecoder::new(buffer);
+ assert_eq!(decoder.num_batches(), 10);
+ for i in 0..decoder.num_batches() {
+ decoder.get_batch(i);
+ }
+ })
+ });
+
+ group.bench_function("FileReader/no_validation/read_10/mmap", |b| {
+ let path = &path;
b.iter(move || {
- let ipc_file = std::fs::File::open(&path).expect("failed to open
file");
+ let ipc_file = std::fs::File::open(path).expect("failed to open
file");
let mmap = unsafe { memmap2::Mmap::map(&ipc_file).expect("failed
to mmap file") };
// Convert the mmap region to an Arrow `Buffer` to back the arrow
arrays.
let bytes = bytes::Bytes::from_owner(mmap);
let buffer = Buffer::from(bytes);
let decoder = IPCBufferDecoder::new(buffer);
+ let decoder = unsafe { decoder.with_skip_validation(true) };
assert_eq!(decoder.num_batches(), 10);
for i in 0..decoder.num_batches() {
@@ -123,6 +171,31 @@ fn criterion_benchmark(c: &mut Criterion) {
});
}
+/// Return an IPC stream with 10 record batches
+fn ipc_stream(options: IpcWriteOptions) -> Vec<u8> {
+ let batch = create_batch(8192, true);
+ let mut buffer = Vec::with_capacity(2 * 1024 * 1024);
+ let mut writer =
+ StreamWriter::try_new_with_options(&mut buffer,
batch.schema().as_ref(), options).unwrap();
+ for _ in 0..10 {
+ writer.write(&batch).unwrap();
+ }
+ writer.finish().unwrap();
+ buffer
+}
+
+/// Return an IPC file with 10 record batches
+fn ipc_file() -> Vec<u8> {
+ let batch = create_batch(8192, true);
+ let mut buffer = Vec::with_capacity(2 * 1024 * 1024);
+ let mut writer = FileWriter::try_new(&mut buffer,
batch.schema().as_ref()).unwrap();
+ for _ in 0..10 {
+ writer.write(&batch).unwrap();
+ }
+ writer.finish().unwrap();
+ buffer
+}
+
// copied from the zero_copy_ipc example.
// should we move this to an actual API?
/// Wrapper around the example in the `FileDecoder` which handles the
@@ -166,6 +239,11 @@ impl IPCBufferDecoder {
}
}
+ unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
+ self.decoder = self.decoder.with_skip_validation(skip_validation);
+ self
+ }
+
fn num_batches(&self) -> usize {
self.batches.len()
}
diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs
index ddda179cbe..695ead4fab 100644
--- a/arrow-ipc/src/reader.rs
+++ b/arrow-ipc/src/reader.rs
@@ -36,7 +36,7 @@ use std::sync::Arc;
use arrow_array::*;
use arrow_buffer::{ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer,
ScalarBuffer};
-use arrow_data::ArrayData;
+use arrow_data::{ArrayData, ArrayDataBuilder, UnsafeFlag};
use arrow_schema::*;
use crate::compression::CompressionCodec;
@@ -136,24 +136,7 @@ impl RecordBatchDecoder<'_> {
let child = self.create_array(struct_field,
variadic_counts)?;
struct_arrays.push(child);
}
- let null_count = struct_node.null_count() as usize;
- let struct_array = if struct_arrays.is_empty() {
- // `StructArray::from` can't infer the correct row count
- // if we have zero fields
- let len = struct_node.length() as usize;
- StructArray::new_empty_fields(
- len,
- (null_count > 0).then(||
BooleanBuffer::new(null_buffer, 0, len).into()),
- )
- } else if null_count > 0 {
- // create struct array from fields, arrays and null data
- let len = struct_node.length() as usize;
- let nulls = BooleanBuffer::new(null_buffer, 0, len).into();
- StructArray::try_new(struct_fields.clone(), struct_arrays,
Some(nulls))?
- } else {
- StructArray::try_new(struct_fields.clone(), struct_arrays,
None)?
- };
- Ok(Arc::new(struct_array))
+ self.create_struct_array(struct_node, null_buffer,
struct_fields, struct_arrays)
}
RunEndEncoded(run_ends_field, values_field) => {
let run_node = self.next_node(field)?;
@@ -161,15 +144,12 @@ impl RecordBatchDecoder<'_> {
let values = self.create_array(values_field, variadic_counts)?;
let run_array_length = run_node.length() as usize;
- let array_data = ArrayData::builder(data_type.clone())
+ let builder = ArrayData::builder(data_type.clone())
.len(run_array_length)
.offset(0)
.add_child_data(run_ends.into_data())
- .add_child_data(values.into_data())
- .align_buffers(!self.require_alignment)
- .build()?;
-
- Ok(make_array(array_data))
+ .add_child_data(values.into_data());
+ self.create_array_from_builder(builder)
}
// Create dictionary array from RecordBatch
Dictionary(_, _) => {
@@ -223,7 +203,14 @@ impl RecordBatchDecoder<'_> {
children.push(child);
}
- let array = UnionArray::try_new(fields.clone(), type_ids,
value_offsets, children)?;
+ let array = if self.skip_validation.get() {
+ // safety: flag can only be set via unsafe code
+ unsafe {
+ UnionArray::new_unchecked(fields.clone(), type_ids,
value_offsets, children)
+ }
+ } else {
+ UnionArray::try_new(fields.clone(), type_ids,
value_offsets, children)?
+ };
Ok(Arc::new(array))
}
Null => {
@@ -237,14 +224,10 @@ impl RecordBatchDecoder<'_> {
)));
}
- let array_data = ArrayData::builder(data_type.clone())
+ let builder = ArrayData::builder(data_type.clone())
.len(length as usize)
- .offset(0)
- .align_buffers(!self.require_alignment)
- .build()?;
-
- // no buffer increases
- Ok(Arc::new(NullArray::from(array_data)))
+ .offset(0);
+ self.create_array_from_builder(builder)
}
_ => {
let field_node = self.next_node(field)?;
@@ -286,9 +269,17 @@ impl RecordBatchDecoder<'_> {
t => unreachable!("Data type {:?} either unsupported or not
primitive", t),
};
- let array_data =
builder.align_buffers(!self.require_alignment).build()?;
+ self.create_array_from_builder(builder)
+ }
- Ok(make_array(array_data))
+ /// Update the ArrayDataBuilder based on settings in this decoder
+ fn create_array_from_builder(&self, builder: ArrayDataBuilder) ->
Result<ArrayRef, ArrowError> {
+ let mut builder = builder.align_buffers(!self.require_alignment);
+ if self.skip_validation.get() {
+ // SAFETY: flag can only be set via unsafe code
+ unsafe { builder = builder.skip_validation(true) }
+ };
+ Ok(make_array(builder.build()?))
}
/// Reads the correct number of buffers based on list type and null_count,
and creates a
@@ -318,9 +309,34 @@ impl RecordBatchDecoder<'_> {
_ => unreachable!("Cannot create list or map array from {:?}",
data_type),
};
- let array_data =
builder.align_buffers(!self.require_alignment).build()?;
+ self.create_array_from_builder(builder)
+ }
+
+ fn create_struct_array(
+ &self,
+ struct_node: &FieldNode,
+ null_buffer: Buffer,
+ struct_fields: &Fields,
+ struct_arrays: Vec<ArrayRef>,
+ ) -> Result<ArrayRef, ArrowError> {
+ let null_count = struct_node.null_count() as usize;
+ let len = struct_node.length() as usize;
+
+ let nulls = (null_count > 0).then(|| BooleanBuffer::new(null_buffer,
0, len).into());
+ if struct_arrays.is_empty() {
+ // `StructArray::from` can't infer the correct row count
+ // if we have zero fields
+ return Ok(Arc::new(StructArray::new_empty_fields(len, nulls)));
+ }
- Ok(make_array(array_data))
+ let struct_array = if self.skip_validation.get() {
+ // safety: flag can only be set via unsafe code
+ unsafe { StructArray::new_unchecked(struct_fields.clone(),
struct_arrays, nulls) }
+ } else {
+ StructArray::try_new(struct_fields.clone(), struct_arrays, nulls)?
+ };
+
+ Ok(Arc::new(struct_array))
}
/// Reads the correct number of buffers based on list type and null_count,
and creates a
@@ -334,15 +350,12 @@ impl RecordBatchDecoder<'_> {
) -> Result<ArrayRef, ArrowError> {
if let Dictionary(_, _) = *data_type {
let null_buffer = (field_node.null_count() >
0).then_some(buffers[0].clone());
- let array_data = ArrayData::builder(data_type.clone())
+ let builder = ArrayData::builder(data_type.clone())
.len(field_node.length() as usize)
.add_buffer(buffers[1].clone())
.add_child_data(value_array.into_data())
- .null_bit_buffer(null_buffer)
- .align_buffers(!self.require_alignment)
- .build()?;
-
- Ok(make_array(array_data))
+ .null_bit_buffer(null_buffer);
+ self.create_array_from_builder(builder)
} else {
unreachable!("Cannot create dictionary array from {:?}", data_type)
}
@@ -376,6 +389,10 @@ struct RecordBatchDecoder<'a> {
/// Are buffers required to already be aligned? See
/// [`RecordBatchDecoder::with_require_alignment`] for details
require_alignment: bool,
+ /// Should validation be skipped when reading data? Defaults to false.
+ ///
+ /// See [`FileDecoder::with_skip_validation`] for details.
+ skip_validation: UnsafeFlag,
}
impl<'a> RecordBatchDecoder<'a> {
@@ -410,6 +427,7 @@ impl<'a> RecordBatchDecoder<'a> {
buffers: buffers.iter(),
projection: None,
require_alignment: false,
+ skip_validation: UnsafeFlag::new(),
})
}
@@ -432,6 +450,22 @@ impl<'a> RecordBatchDecoder<'a> {
self
}
+ /// Specifies if validation should be skipped when reading data (defaults
to `false`)
+ ///
+ /// Note this API is somewhat "funky" as it allows the caller to skip
validation
+ /// without having to use `unsafe` code. If this is ever made public
+ /// it should be made clearer that this is a potentially unsafe by
+ /// using an `unsafe` function that takes a boolean flag.
+ ///
+ /// # Safety
+ ///
+ /// Relies on the caller only passing a flag with `true` value if they are
+ /// certain that the data is valid
+ pub(crate) fn with_skip_validation(mut self, skip_validation: UnsafeFlag)
-> Self {
+ self.skip_validation = skip_validation;
+ self
+ }
+
/// Read the record batch, consuming the reader
fn read_record_batch(mut self) -> Result<RecordBatch, ArrowError> {
let mut variadic_counts: VecDeque<i64> = self
@@ -601,7 +635,15 @@ pub fn read_dictionary(
dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
metadata: &MetadataVersion,
) -> Result<(), ArrowError> {
- read_dictionary_impl(buf, batch, schema, dictionaries_by_id, metadata,
false)
+ read_dictionary_impl(
+ buf,
+ batch,
+ schema,
+ dictionaries_by_id,
+ metadata,
+ false,
+ UnsafeFlag::new(),
+ )
}
fn read_dictionary_impl(
@@ -611,6 +653,7 @@ fn read_dictionary_impl(
dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
metadata: &MetadataVersion,
require_alignment: bool,
+ skip_validation: UnsafeFlag,
) -> Result<(), ArrowError> {
if batch.isDelta() {
return Err(ArrowError::InvalidArgumentError(
@@ -642,6 +685,7 @@ fn read_dictionary_impl(
metadata,
)?
.with_require_alignment(require_alignment)
+ .with_skip_validation(skip_validation)
.read_record_batch()?;
Some(record_batch.column(0).clone())
@@ -772,6 +816,7 @@ pub struct FileDecoder {
version: MetadataVersion,
projection: Option<Vec<usize>>,
require_alignment: bool,
+ skip_validation: UnsafeFlag,
}
impl FileDecoder {
@@ -783,6 +828,7 @@ impl FileDecoder {
dictionaries: Default::default(),
projection: None,
require_alignment: false,
+ skip_validation: UnsafeFlag::new(),
}
}
@@ -792,7 +838,7 @@ impl FileDecoder {
self
}
- /// Specifies whether or not array data in input buffers is required to be
properly aligned.
+ /// Specifies if the array data in input buffers is required to be
properly aligned.
///
/// If `require_alignment` is true, this decoder will return an error if
any array data in the
/// input `buf` is not properly aligned.
@@ -809,6 +855,21 @@ impl FileDecoder {
self
}
+ /// Specifies if validation should be skipped when reading data (defaults
to `false`)
+ ///
+ /// # Safety
+ ///
+ /// This flag must only be set to `true` when you trust the input data and
are sure the data you are
+ /// reading is a valid Arrow IPC file, otherwise undefined behavior may
+ /// result.
+ ///
+ /// For example, some programs may wish to trust reading IPC files written
+ /// by the same process that created the files.
+ pub unsafe fn with_skip_validation(mut self, skip_validation: bool) ->
Self {
+ self.skip_validation.set(skip_validation);
+ self
+ }
+
fn read_message<'a>(&self, buf: &'a [u8]) -> Result<Message<'a>,
ArrowError> {
let message = parse_message(buf)?;
@@ -834,6 +895,7 @@ impl FileDecoder {
&mut self.dictionaries,
&message.version(),
self.require_alignment,
+ self.skip_validation.clone(),
)
}
t => Err(ArrowError::ParseError(format!(
@@ -867,6 +929,7 @@ impl FileDecoder {
)?
.with_projection(self.projection.as_deref())
.with_require_alignment(self.require_alignment)
+ .with_skip_validation(self.skip_validation.clone())
.read_record_batch()
.map(Some)
}
@@ -1177,6 +1240,16 @@ impl<R: Read + Seek> FileReader<R> {
pub fn get_mut(&mut self) -> &mut R {
&mut self.reader
}
+
+ /// Specifies if validation should be skipped when reading data (defaults
to `false`)
+ ///
+ /// # Safety
+ ///
+ /// See [`FileDecoder::with_skip_validation`]
+ pub unsafe fn with_skip_validation(mut self, skip_validation: bool) ->
Self {
+ self.decoder = self.decoder.with_skip_validation(skip_validation);
+ self
+ }
}
impl<R: Read + Seek> Iterator for FileReader<R> {
@@ -1250,6 +1323,11 @@ pub struct StreamReader<R> {
/// Optional projection
projection: Option<(Vec<usize>, Schema)>,
+
+ /// Should validation be skipped when reading data? Defaults to false.
+ ///
+ /// See [`FileDecoder::with_skip_validation`] for details.
+ skip_validation: UnsafeFlag,
}
impl<R> fmt::Debug for StreamReader<R> {
@@ -1329,6 +1407,7 @@ impl<R: Read> StreamReader<R> {
finished: false,
dictionaries_by_id,
projection,
+ skip_validation: UnsafeFlag::new(),
})
}
@@ -1417,6 +1496,7 @@ impl<R: Read> StreamReader<R> {
)?
.with_projection(self.projection.as_ref().map(|x|
x.0.as_ref()))
.with_require_alignment(false)
+ .with_skip_validation(self.skip_validation.clone())
.read_record_batch()
.map(Some)
}
@@ -1437,6 +1517,7 @@ impl<R: Read> StreamReader<R> {
&mut self.dictionaries_by_id,
&message.version(),
false,
+ self.skip_validation.clone(),
)?;
// read the next message until we encounter a RecordBatch
@@ -1462,6 +1543,16 @@ impl<R: Read> StreamReader<R> {
pub fn get_mut(&mut self) -> &mut R {
&mut self.reader
}
+
+ /// Specifies if validation should be skipped when reading data (defaults
to `false`)
+ ///
+ /// # Safety
+ ///
+ /// See [`FileDecoder::with_skip_validation`]
+ pub unsafe fn with_skip_validation(mut self, skip_validation: bool) ->
Self {
+ self.skip_validation.set(skip_validation);
+ self
+ }
}
impl<R: Read> Iterator for StreamReader<R> {
@@ -1740,6 +1831,15 @@ mod tests {
reader.next().unwrap()
}
+ /// Return the first record batch read from the IPC File buffer, disabling
+ /// validation
+ fn read_ipc_skip_validation(buf: &[u8]) -> Result<RecordBatch, ArrowError>
{
+ let mut reader = unsafe {
+ FileReader::try_new(std::io::Cursor::new(buf),
None)?.with_skip_validation(true)
+ };
+ reader.next().unwrap()
+ }
+
fn roundtrip_ipc(rb: &RecordBatch) -> RecordBatch {
let buf = write_ipc(rb);
read_ipc(&buf).unwrap()
@@ -1748,6 +1848,19 @@ mod tests {
/// Return the first record batch read from the IPC File buffer
/// using the FileDecoder API
fn read_ipc_with_decoder(buf: Vec<u8>) -> Result<RecordBatch, ArrowError> {
+ read_ipc_with_decoder_inner(buf, false)
+ }
+
+ /// Return the first record batch read from the IPC File buffer
+ /// using the FileDecoder API, disabling validation
+ fn read_ipc_with_decoder_skip_validation(buf: Vec<u8>) ->
Result<RecordBatch, ArrowError> {
+ read_ipc_with_decoder_inner(buf, true)
+ }
+
+ fn read_ipc_with_decoder_inner(
+ buf: Vec<u8>,
+ skip_validation: bool,
+ ) -> Result<RecordBatch, ArrowError> {
let buffer = Buffer::from_vec(buf);
let trailer_start = buffer.len() - 10;
let footer_len =
read_footer_length(buffer[trailer_start..].try_into().unwrap())?;
@@ -1756,7 +1869,10 @@ mod tests {
let schema = fb_to_schema(footer.schema().unwrap());
- let mut decoder = FileDecoder::new(Arc::new(schema), footer.version());
+ let mut decoder = unsafe {
+ FileDecoder::new(Arc::new(schema), footer.version())
+ .with_skip_validation(skip_validation)
+ };
// Read dictionaries
for block in footer.dictionaries().iter().flatten() {
let block_len = block.bodyLength() as usize +
block.metaDataLength() as usize;
@@ -1789,6 +1905,15 @@ mod tests {
reader.next().unwrap()
}
+ /// Return the first record batch read from the IPC Stream buffer,
+ /// disabling validation
+ fn read_stream_skip_validation(buf: &[u8]) -> Result<RecordBatch,
ArrowError> {
+ let mut reader = unsafe {
+ StreamReader::try_new(std::io::Cursor::new(buf),
None)?.with_skip_validation(true)
+ };
+ reader.next().unwrap()
+ }
+
fn roundtrip_ipc_stream(rb: &RecordBatch) -> RecordBatch {
let buf = write_stream(rb);
read_stream(&buf).unwrap()
@@ -2456,6 +2581,57 @@ mod tests {
);
}
+ #[test]
+ fn test_invalid_nested_array_ipc_read_errors() {
+ // one of the nested arrays has invalid data
+ let a_field = Field::new("a", DataType::Int32, false);
+ let b_field = Field::new("b", DataType::Utf8, false);
+
+ let schema = Arc::new(Schema::new(vec![Field::new_struct(
+ "s",
+ vec![a_field.clone(), b_field.clone()],
+ false,
+ )]));
+
+ let a_array_data = ArrayData::builder(a_field.data_type().clone())
+ .len(4)
+ .add_buffer(Buffer::from_slice_ref([1, 2, 3, 4]))
+ .build()
+ .unwrap();
+ // invalid nested child array -- length is correct, but has invalid
utf8 data
+ let b_array_data = {
+ let valid: &[u8] = b" ";
+ let mut invalid = vec![];
+ invalid.extend_from_slice(b"ValidString");
+ invalid.extend_from_slice(INVALID_UTF8_FIRST_CHAR);
+ let binary_array =
+ BinaryArray::from_iter(vec![None, Some(valid), None,
Some(&invalid)]);
+ let array = unsafe {
+ StringArray::new_unchecked(
+ binary_array.offsets().clone(),
+ binary_array.values().clone(),
+ binary_array.nulls().cloned(),
+ )
+ };
+ array.into_data()
+ };
+ let struct_data_type = schema.field(0).data_type();
+
+ let invalid_struct_arr = unsafe {
+ make_array(
+ ArrayData::builder(struct_data_type.clone())
+ .len(4)
+ .add_child_data(a_array_data)
+ .add_child_data(b_array_data)
+ .build_unchecked(),
+ )
+ };
+ expect_ipc_validation_error(
+ Arc::new(invalid_struct_arr),
+ "Invalid argument error: Invalid UTF8 sequence at string index 3
(3..18): invalid utf-8 sequence of 1 bytes from index 11",
+ );
+ }
+
#[test]
fn test_same_dict_id_without_preserve() {
let batch = RecordBatch::try_new(
@@ -2592,6 +2768,32 @@ mod tests {
);
}
+ #[test]
+ fn test_validation_of_invalid_union_array() {
+ let array = unsafe {
+ let fields = UnionFields::new(
+ vec![1, 3], // typeids : type id 2 is not valid
+ vec![
+ Field::new("a", DataType::Int32, false),
+ Field::new("b", DataType::Utf8, false),
+ ],
+ );
+ let type_ids = ScalarBuffer::from(vec![1i8, 2, 3]); // 2 is invalid
+ let offsets = None;
+ let children: Vec<ArrayRef> = vec![
+ Arc::new(Int32Array::from(vec![10, 20, 30])),
+ Arc::new(StringArray::from(vec![Some("a"), Some("b"),
Some("c")])),
+ ];
+
+ UnionArray::new_unchecked(fields, type_ids, offsets, children)
+ };
+
+ expect_ipc_validation_error(
+ Arc::new(array),
+ "Invalid argument error: Type Ids values must match one of the
field type ids",
+ );
+ }
+
/// Invalid Utf-8 sequence in the first character
///
<https://stackoverflow.com/questions/1301402/example-invalid-utf8-string>
const INVALID_UTF8_FIRST_CHAR: &[u8] = &[0xa0, 0xa1, 0x20, 0x20];
@@ -2602,18 +2804,18 @@ mod tests {
// IPC Stream format
let buf = write_stream(&rb); // write is ok
+ read_stream_skip_validation(&buf).unwrap();
let err = read_stream(&buf).unwrap_err();
assert_eq!(err.to_string(), expected_err);
// IPC File format
let buf = write_ipc(&rb); // write is ok
+ read_ipc_skip_validation(&buf).unwrap();
let err = read_ipc(&buf).unwrap_err();
assert_eq!(err.to_string(), expected_err);
- // TODO verify there is no error when validation is disabled
- // see https://github.com/apache/arrow-rs/issues/3287
-
// IPC Format with FileDecoder
+ read_ipc_with_decoder_skip_validation(buf.clone()).unwrap();
let err = read_ipc_with_decoder(buf).unwrap_err();
assert_eq!(err.to_string(), expected_err);
}
diff --git a/arrow-ipc/src/reader/stream.rs b/arrow-ipc/src/reader/stream.rs
index 174e69c1f6..5902cbe4e0 100644
--- a/arrow-ipc/src/reader/stream.rs
+++ b/arrow-ipc/src/reader/stream.rs
@@ -21,6 +21,7 @@ use std::sync::Arc;
use arrow_array::{ArrayRef, RecordBatch};
use arrow_buffer::{Buffer, MutableBuffer};
+use arrow_data::UnsafeFlag;
use arrow_schema::{ArrowError, SchemaRef};
use crate::convert::MessageBuffer;
@@ -42,6 +43,12 @@ pub struct StreamDecoder {
buf: MutableBuffer,
/// Whether or not array data in input buffers are required to be aligned
require_alignment: bool,
+ /// Should validation be skipped when reading data? Defaults to false.
+ ///
+ /// See [`FileDecoder::with_skip_validation`] for details.
+ ///
+ /// [`FileDecoder::with_skip_validation`]:
crate::reader::FileDecoder::with_skip_validation
+ skip_validation: UnsafeFlag,
}
#[derive(Debug)]
@@ -235,6 +242,7 @@ impl StreamDecoder {
&mut self.dictionaries,
&version,
self.require_alignment,
+ self.skip_validation.clone(),
)?;
self.state = DecoderState::default();
}