This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 09fd4528d Enforce struct nullability in JSON raw reader (#3900) (#3904) (#3906)
09fd4528d is described below
commit 09fd4528dd3fe3539511aa3f528891eb1cabea1e
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Fri Mar 24 12:25:38 2023 +0000
Enforce struct nullability in JSON raw reader (#3900) (#3904) (#3906)
* Enforce struct nullability in JSON raw reader (#3900) (#3904)
* Fix tests
* Review feedback
---
arrow-array/src/array/boolean_array.rs | 11 +--
arrow-buffer/src/buffer/boolean.rs | 10 +--
arrow-buffer/src/util/bit_chunk_iterator.rs | 6 ++
arrow-data/src/data/mod.rs | 9 +--
arrow-data/src/equal/utils.rs | 13 +---
arrow-json/src/raw/mod.rs | 102 ++++++++++++++++++++++++++--
arrow-json/src/raw/struct_array.rs | 32 +++++++--
7 files changed, 142 insertions(+), 41 deletions(-)
diff --git a/arrow-array/src/array/boolean_array.rs b/arrow-array/src/array/boolean_array.rs
index 98de62da0..dea5c07da 100644
--- a/arrow-array/src/array/boolean_array.rs
+++ b/arrow-array/src/array/boolean_array.rs
@@ -127,15 +127,10 @@ impl BooleanArray {
pub fn true_count(&self) -> usize {
match self.data.nulls() {
Some(nulls) => {
- let null_chunks = nulls.inner().bit_chunks();
- let value_chunks = self.values().bit_chunks();
+ let null_chunks = nulls.inner().bit_chunks().iter_padded();
+ let value_chunks = self.values().bit_chunks().iter_padded();
null_chunks
- .iter()
- .zip(value_chunks.iter())
- .chain(std::iter::once((
- null_chunks.remainder_bits(),
- value_chunks.remainder_bits(),
- )))
+ .zip(value_chunks)
.map(|(a, b)| (a & b).count_ones() as usize)
.sum()
}
diff --git a/arrow-buffer/src/buffer/boolean.rs b/arrow-buffer/src/buffer/boolean.rs
index 53ead4573..c89cfb332 100644
--- a/arrow-buffer/src/buffer/boolean.rs
+++ b/arrow-buffer/src/buffer/boolean.rs
@@ -36,13 +36,9 @@ impl PartialEq for BooleanBuffer {
return false;
}
- let lhs = self.bit_chunks();
- let rhs = other.bit_chunks();
-
- if lhs.iter().zip(rhs.iter()).any(|(a, b)| a != b) {
- return false;
- }
- lhs.remainder_bits() == rhs.remainder_bits()
+ let lhs = self.bit_chunks().iter_padded();
+ let rhs = other.bit_chunks().iter_padded();
+ lhs.zip(rhs).all(|(a, b)| a == b)
}
}
diff --git a/arrow-buffer/src/util/bit_chunk_iterator.rs b/arrow-buffer/src/util/bit_chunk_iterator.rs
index a739a9694..3d9632e73 100644
--- a/arrow-buffer/src/util/bit_chunk_iterator.rs
+++ b/arrow-buffer/src/util/bit_chunk_iterator.rs
@@ -296,6 +296,12 @@ impl<'a> BitChunks<'a> {
index: 0,
}
}
+
+ /// Returns an iterator over chunks of 64 bits, with the remaining bits zero padded to 64-bits
+ #[inline]
+ pub fn iter_padded(&self) -> impl Iterator<Item = u64> + 'a {
+ self.iter().chain(std::iter::once(self.remainder_bits()))
+ }
}
impl<'a> IntoIterator for BitChunks<'a> {
diff --git a/arrow-data/src/data/mod.rs b/arrow-data/src/data/mod.rs
index cc908d639..7241a5d80 100644
--- a/arrow-data/src/data/mod.rs
+++ b/arrow-data/src/data/mod.rs
@@ -1218,12 +1218,9 @@ impl ArrayData {
let mask = BitChunks::new(mask, offset, child.len);
let nulls = BitChunks::new(nulls.validity(), nulls.offset(), child.len);
mask
- .iter()
- .zip(nulls.iter())
- .chain(std::iter::once((
- mask.remainder_bits(),
- nulls.remainder_bits(),
- ))).try_for_each(|(m, c)| {
+ .iter_padded()
+ .zip(nulls.iter_padded())
+ .try_for_each(|(m, c)| {
if (m & !c) != 0 {
return Err(ArrowError::InvalidArgumentError(format!(
"non-nullable child of type {} contains nulls not present in parent",
diff --git a/arrow-data/src/equal/utils.rs b/arrow-data/src/equal/utils.rs
index d1f0f392a..6b9a7940d 100644
--- a/arrow-data/src/equal/utils.rs
+++ b/arrow-data/src/equal/utils.rs
@@ -29,16 +29,9 @@ pub(super) fn equal_bits(
rhs_start: usize,
len: usize,
) -> bool {
- let lhs = BitChunks::new(lhs_values, lhs_start, len);
- let rhs = BitChunks::new(rhs_values, rhs_start, len);
-
- for (a, b) in lhs.iter().zip(rhs.iter()) {
- if a != b {
- return false;
- }
- }
-
- lhs.remainder_bits() == rhs.remainder_bits()
+ let lhs = BitChunks::new(lhs_values, lhs_start, len).iter_padded();
+ let rhs = BitChunks::new(rhs_values, rhs_start, len).iter_padded();
+ lhs.zip(rhs).all(|(a, b)| a == b)
}
#[inline]
diff --git a/arrow-json/src/raw/mod.rs b/arrow-json/src/raw/mod.rs
index 2e5055bf1..21e6191ac 100644
--- a/arrow-json/src/raw/mod.rs
+++ b/arrow-json/src/raw/mod.rs
@@ -359,7 +359,7 @@ mod tests {
use crate::ReaderBuilder;
use arrow_array::cast::AsArray;
use arrow_array::types::Int32Type;
- use arrow_array::Array;
+ use arrow_array::{Array, StructArray};
use arrow_buffer::ArrowNativeType;
use arrow_cast::display::{ArrayFormatter, FormatOptions};
use arrow_schema::{DataType, Field, Schema};
@@ -511,8 +511,8 @@ mod tests {
Field::new(
"nested",
DataType::Struct(vec![
- Field::new("a", DataType::Int32, false),
- Field::new("b", DataType::Int32, false),
+ Field::new("a", DataType::Int32, true),
+ Field::new("b", DataType::Int32, true),
]),
true,
),
@@ -591,7 +591,7 @@ mod tests {
"list2",
DataType::List(Box::new(Field::new(
"element",
- DataType::Struct(vec![Field::new("d", DataType::Int32, false)]),
+ DataType::Struct(vec![Field::new("d", DataType::Int32, true)]),
false,
))),
true,
@@ -1001,4 +1001,98 @@ mod tests {
test_time::<Time64MicrosecondType>();
test_time::<Time64NanosecondType>();
}
+
+ #[test]
+ fn test_delta_checkpoint() {
+ let json = "{\"protocol\":{\"minReaderVersion\":1,\"minWriterVersion\":2}}";
+ let schema = Arc::new(Schema::new(vec![
+ Field::new(
+ "protocol",
+ DataType::Struct(vec![
+ Field::new("minReaderVersion", DataType::Int32, true),
+ Field::new("minWriterVersion", DataType::Int32, true),
+ ]),
+ true,
+ ),
+ Field::new(
+ "add",
+ DataType::Struct(vec![Field::new(
+ "partitionValues",
+ DataType::Map(
+ Box::new(Field::new(
+ "key_value",
+ DataType::Struct(vec![
+ Field::new("key", DataType::Utf8, false),
+ Field::new("value", DataType::Utf8, true),
+ ]),
+ false,
+ )),
+ false,
+ ),
+ false,
+ )]),
+ true,
+ ),
+ ]));
+
+ let batches = do_read(json, 1024, true, schema);
+ assert_eq!(batches.len(), 1);
+
+ let s: StructArray = batches.into_iter().next().unwrap().into();
+ let opts = FormatOptions::default().with_null("null");
+ let formatter = ArrayFormatter::try_new(&s, &opts).unwrap();
+ assert_eq!(
+ formatter.value(0).to_string(),
+ "{protocol: {minReaderVersion: 1, minWriterVersion: 2}, add: null}"
+ );
+ }
+
+ #[test]
+ fn struct_nullability() {
+ let do_test = |child: DataType| {
+ // Test correctly enforced nullability
+ let non_null = r#"{"foo": {}}"#;
+ let schema = Arc::new(Schema::new(vec![Field::new(
+ "foo",
+ DataType::Struct(vec![Field::new("bar", child, false)]),
+ true,
+ )]));
+ let mut reader = RawReaderBuilder::new(schema.clone())
+ .build(Cursor::new(non_null.as_bytes()))
+ .unwrap();
+ assert!(reader.next().unwrap().is_err()); // Should error as not nullable
+
+ let null = r#"{"foo": {bar: null}}"#;
+ let mut reader = RawReaderBuilder::new(schema.clone())
+ .build(Cursor::new(null.as_bytes()))
+ .unwrap();
+ assert!(reader.next().unwrap().is_err()); // Should error as not nullable
+
+ // Test nulls in nullable parent can mask nulls in non-nullable child
+ let null = r#"{"foo": null}"#;
+ let mut reader = RawReaderBuilder::new(schema)
+ .build(Cursor::new(null.as_bytes()))
+ .unwrap();
+ let batch = reader.next().unwrap().unwrap();
+ assert_eq!(batch.num_columns(), 1);
+ let foo = batch.column(0).as_struct();
+ assert_eq!(foo.len(), 1);
+ assert!(foo.is_null(0));
+ assert_eq!(foo.num_columns(), 1);
+
+ let bar = foo.column(0);
+ assert_eq!(bar.len(), 1);
+ // Non-nullable child can still contain null as masked by parent
+ assert!(bar.is_null(0));
+ };
+
+ do_test(DataType::Boolean);
+ do_test(DataType::Int32);
+ do_test(DataType::Utf8);
+ do_test(DataType::Decimal128(2, 1));
+ do_test(DataType::Timestamp(
+ TimeUnit::Microsecond,
+ Some("+00:00".to_string()),
+ ));
+ }
}
diff --git a/arrow-json/src/raw/struct_array.rs b/arrow-json/src/raw/struct_array.rs
index 1d0019993..219f56ae6 100644
--- a/arrow-json/src/raw/struct_array.rs
+++ b/arrow-json/src/raw/struct_array.rs
@@ -37,7 +37,11 @@ impl StructArrayDecoder {
let decoders = struct_fields(&data_type)
.iter()
.map(|f| {
- make_decoder(f.data_type().clone(), coerce_primitive, f.is_nullable())
+ // If this struct nullable, need to permit nullability in child array
+ // StructArrayDecoder::decode verifies that if the child is not nullable
+ // it doesn't contain any nulls not masked by its parent
+ let nullable = f.is_nullable() || is_nullable;
+ make_decoder(f.data_type().clone(), coerce_primitive, nullable)
})
.collect::<Result<Vec<_>, ArrowError>>()?;
@@ -102,15 +106,31 @@ impl ArrayDecoder for StructArrayDecoder {
.map(|(d, pos)| d.decode(tape, &pos))
.collect::<Result<Vec<_>, ArrowError>>()?;
- // Sanity check
- child_data
- .iter()
- .for_each(|x| assert_eq!(x.len(), pos.len()));
-
let nulls = nulls
.as_mut()
.map(|x| NullBuffer::new(BooleanBuffer::new(x.finish(), 0, pos.len())));
+ for (c, f) in child_data.iter().zip(fields) {
+ // Sanity check
+ assert_eq!(c.len(), pos.len());
+
+ if !f.is_nullable() && c.null_count() != 0 {
+ // Need to verify nulls
+ let valid = match nulls.as_ref() {
+ Some(nulls) => {
+ let lhs = nulls.inner().bit_chunks().iter_padded();
+ let rhs = c.nulls().unwrap().inner().bit_chunks().iter_padded();
+ lhs.zip(rhs).all(|(l, r)| (l & !r) == 0)
+ }
+ None => false,
+ };
+
+ if !valid {
+ return Err(ArrowError::JsonError(format!("Encountered unmasked nulls in non-nullable StructArray child: {f}")));
+ }
+ }
+ }
+
let data = ArrayDataBuilder::new(self.data_type.clone())
.len(pos.len())
.nulls(nulls)