This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 09fd4528d Enforce struct nullability in JSON raw reader (#3900) (#3904) (#3906)
09fd4528d is described below

commit 09fd4528dd3fe3539511aa3f528891eb1cabea1e
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Fri Mar 24 12:25:38 2023 +0000

    Enforce struct nullability in JSON raw reader (#3900) (#3904) (#3906)
    
    * Enforce struct nullability in JSON raw reader (#3900) (#3904)
    
    * Fix tests
    
    * Review feedback
---
 arrow-array/src/array/boolean_array.rs      |  11 +--
 arrow-buffer/src/buffer/boolean.rs          |  10 +--
 arrow-buffer/src/util/bit_chunk_iterator.rs |   6 ++
 arrow-data/src/data/mod.rs                  |   9 +--
 arrow-data/src/equal/utils.rs               |  13 +---
 arrow-json/src/raw/mod.rs                   | 102 ++++++++++++++++++++++++++--
 arrow-json/src/raw/struct_array.rs          |  32 +++++++--
 7 files changed, 142 insertions(+), 41 deletions(-)

diff --git a/arrow-array/src/array/boolean_array.rs b/arrow-array/src/array/boolean_array.rs
index 98de62da0..dea5c07da 100644
--- a/arrow-array/src/array/boolean_array.rs
+++ b/arrow-array/src/array/boolean_array.rs
@@ -127,15 +127,10 @@ impl BooleanArray {
     pub fn true_count(&self) -> usize {
         match self.data.nulls() {
             Some(nulls) => {
-                let null_chunks = nulls.inner().bit_chunks();
-                let value_chunks = self.values().bit_chunks();
+                let null_chunks = nulls.inner().bit_chunks().iter_padded();
+                let value_chunks = self.values().bit_chunks().iter_padded();
                 null_chunks
-                    .iter()
-                    .zip(value_chunks.iter())
-                    .chain(std::iter::once((
-                        null_chunks.remainder_bits(),
-                        value_chunks.remainder_bits(),
-                    )))
+                    .zip(value_chunks)
                     .map(|(a, b)| (a & b).count_ones() as usize)
                     .sum()
             }
diff --git a/arrow-buffer/src/buffer/boolean.rs b/arrow-buffer/src/buffer/boolean.rs
index 53ead4573..c89cfb332 100644
--- a/arrow-buffer/src/buffer/boolean.rs
+++ b/arrow-buffer/src/buffer/boolean.rs
@@ -36,13 +36,9 @@ impl PartialEq for BooleanBuffer {
             return false;
         }
 
-        let lhs = self.bit_chunks();
-        let rhs = other.bit_chunks();
-
-        if lhs.iter().zip(rhs.iter()).any(|(a, b)| a != b) {
-            return false;
-        }
-        lhs.remainder_bits() == rhs.remainder_bits()
+        let lhs = self.bit_chunks().iter_padded();
+        let rhs = other.bit_chunks().iter_padded();
+        lhs.zip(rhs).all(|(a, b)| a == b)
     }
 }
 
diff --git a/arrow-buffer/src/util/bit_chunk_iterator.rs b/arrow-buffer/src/util/bit_chunk_iterator.rs
index a739a9694..3d9632e73 100644
--- a/arrow-buffer/src/util/bit_chunk_iterator.rs
+++ b/arrow-buffer/src/util/bit_chunk_iterator.rs
@@ -296,6 +296,12 @@ impl<'a> BitChunks<'a> {
             index: 0,
         }
     }
+
+    /// Returns an iterator over chunks of 64 bits, followed by the remaining bits zero padded to 64-bits (the padded chunk is always yielded, even if zero)
+    #[inline]
+    pub fn iter_padded(&self) -> impl Iterator<Item = u64> + 'a {
+        self.iter().chain(std::iter::once(self.remainder_bits()))
+    }
 }
 
 impl<'a> IntoIterator for BitChunks<'a> {
diff --git a/arrow-data/src/data/mod.rs b/arrow-data/src/data/mod.rs
index cc908d639..7241a5d80 100644
--- a/arrow-data/src/data/mod.rs
+++ b/arrow-data/src/data/mod.rs
@@ -1218,12 +1218,9 @@ impl ArrayData {
                 let mask = BitChunks::new(mask, offset, child.len);
                 let nulls = BitChunks::new(nulls.validity(), nulls.offset(), child.len);
                 mask
-                    .iter()
-                    .zip(nulls.iter())
-                    .chain(std::iter::once((
-                        mask.remainder_bits(),
-                        nulls.remainder_bits(),
-                    ))).try_for_each(|(m, c)| {
+                    .iter_padded()
+                    .zip(nulls.iter_padded())
+                    .try_for_each(|(m, c)| {
                     if (m & !c) != 0 {
                         return Err(ArrowError::InvalidArgumentError(format!(
                             "non-nullable child of type {} contains nulls not present in parent",
diff --git a/arrow-data/src/equal/utils.rs b/arrow-data/src/equal/utils.rs
index d1f0f392a..6b9a7940d 100644
--- a/arrow-data/src/equal/utils.rs
+++ b/arrow-data/src/equal/utils.rs
@@ -29,16 +29,9 @@ pub(super) fn equal_bits(
     rhs_start: usize,
     len: usize,
 ) -> bool {
-    let lhs = BitChunks::new(lhs_values, lhs_start, len);
-    let rhs = BitChunks::new(rhs_values, rhs_start, len);
-
-    for (a, b) in lhs.iter().zip(rhs.iter()) {
-        if a != b {
-            return false;
-        }
-    }
-
-    lhs.remainder_bits() == rhs.remainder_bits()
+    let lhs = BitChunks::new(lhs_values, lhs_start, len).iter_padded();
+    let rhs = BitChunks::new(rhs_values, rhs_start, len).iter_padded();
+    lhs.zip(rhs).all(|(a, b)| a == b)
 }
 
 #[inline]
diff --git a/arrow-json/src/raw/mod.rs b/arrow-json/src/raw/mod.rs
index 2e5055bf1..21e6191ac 100644
--- a/arrow-json/src/raw/mod.rs
+++ b/arrow-json/src/raw/mod.rs
@@ -359,7 +359,7 @@ mod tests {
     use crate::ReaderBuilder;
     use arrow_array::cast::AsArray;
     use arrow_array::types::Int32Type;
-    use arrow_array::Array;
+    use arrow_array::{Array, StructArray};
     use arrow_buffer::ArrowNativeType;
     use arrow_cast::display::{ArrayFormatter, FormatOptions};
     use arrow_schema::{DataType, Field, Schema};
@@ -511,8 +511,8 @@ mod tests {
             Field::new(
                 "nested",
                 DataType::Struct(vec![
-                    Field::new("a", DataType::Int32, false),
-                    Field::new("b", DataType::Int32, false),
+                    Field::new("a", DataType::Int32, true),
+                    Field::new("b", DataType::Int32, true),
                 ]),
                 true,
             ),
@@ -591,7 +591,7 @@ mod tests {
                     "list2",
                     DataType::List(Box::new(Field::new(
                         "element",
-                        DataType::Struct(vec![Field::new("d", DataType::Int32, false)]),
+                        DataType::Struct(vec![Field::new("d", DataType::Int32, true)]),
                         false,
                     ))),
                     true,
@@ -1001,4 +1001,98 @@ mod tests {
         test_time::<Time64MicrosecondType>();
         test_time::<Time64NanosecondType>();
     }
+
+    #[test]
+    fn test_delta_checkpoint() {
+        let json = r#"{"protocol":{"minReaderVersion":1,"minWriterVersion":2}}"#;
+        let schema = Arc::new(Schema::new(vec![
+            Field::new(
+                "protocol",
+                DataType::Struct(vec![
+                    Field::new("minReaderVersion", DataType::Int32, true),
+                    Field::new("minWriterVersion", DataType::Int32, true),
+                ]),
+                true,
+            ),
+            Field::new(
+                "add",
+                DataType::Struct(vec![Field::new(
+                    "partitionValues",
+                    DataType::Map(
+                        Box::new(Field::new(
+                            "key_value",
+                            DataType::Struct(vec![
+                                Field::new("key", DataType::Utf8, false),
+                                Field::new("value", DataType::Utf8, true),
+                            ]),
+                            false,
+                        )),
+                        false,
+                    ),
+                    false,
+                )]),
+                true,
+            ),
+        ]));
+
+        let batches = do_read(json, 1024, true, schema);
+        assert_eq!(batches.len(), 1);
+
+        let s: StructArray = batches.into_iter().next().unwrap().into();
+        let opts = FormatOptions::default().with_null("null");
+        let formatter = ArrayFormatter::try_new(&s, &opts).unwrap();
+        assert_eq!(
+            formatter.value(0).to_string(),
+            "{protocol: {minReaderVersion: 1, minWriterVersion: 2}, add: null}"
+        );
+    }
+
+    #[test]
+    fn struct_nullability() {
+        let do_test = |child: DataType| {
+            // Test correctly enforced nullability
+            let non_null = r#"{"foo": {}}"#;
+            let schema = Arc::new(Schema::new(vec![Field::new(
+                "foo",
+                DataType::Struct(vec![Field::new("bar", child, false)]),
+                true,
+            )]));
+            let mut reader = RawReaderBuilder::new(schema.clone())
+                .build(Cursor::new(non_null.as_bytes()))
+                .unwrap();
+            assert!(reader.next().unwrap().is_err()); // Should error as not nullable
+
+            let null = r#"{"foo": {"bar": null}}"#;
+            let mut reader = RawReaderBuilder::new(schema.clone())
+                .build(Cursor::new(null.as_bytes()))
+                .unwrap();
+            assert!(reader.next().unwrap().is_err()); // Should error as explicit null in non-nullable "bar"
+
+            // Test nulls in nullable parent can mask nulls in non-nullable child
+            let null = r#"{"foo": null}"#;
+            let mut reader = RawReaderBuilder::new(schema)
+                .build(Cursor::new(null.as_bytes()))
+                .unwrap();
+            let batch = reader.next().unwrap().unwrap();
+            assert_eq!(batch.num_columns(), 1);
+            let foo = batch.column(0).as_struct();
+            assert_eq!(foo.len(), 1);
+            assert!(foo.is_null(0));
+            assert_eq!(foo.num_columns(), 1);
+
+            let bar = foo.column(0);
+            assert_eq!(bar.len(), 1);
+            // Non-nullable child can still contain null as masked by parent
+            assert!(bar.is_null(0));
+        };
+
+        do_test(DataType::Boolean);
+        do_test(DataType::Int32);
+        do_test(DataType::Utf8);
+        do_test(DataType::Decimal128(2, 1));
+        do_test(DataType::Timestamp(
+            TimeUnit::Microsecond,
+            Some("+00:00".to_string()),
+        ));
+    }
 }
diff --git a/arrow-json/src/raw/struct_array.rs b/arrow-json/src/raw/struct_array.rs
index 1d0019993..219f56ae6 100644
--- a/arrow-json/src/raw/struct_array.rs
+++ b/arrow-json/src/raw/struct_array.rs
@@ -37,7 +37,11 @@ impl StructArrayDecoder {
         let decoders = struct_fields(&data_type)
             .iter()
             .map(|f| {
-                make_decoder(f.data_type().clone(), coerce_primitive, f.is_nullable())
+                // If this struct is nullable, need to permit nullability in child array
+                // StructArrayDecoder::decode verifies that if the child is not nullable
+                // it doesn't contain any nulls not masked by its parent
+                let nullable = f.is_nullable() || is_nullable;
+                make_decoder(f.data_type().clone(), coerce_primitive, nullable)
             })
             .collect::<Result<Vec<_>, ArrowError>>()?;
 
@@ -102,15 +106,31 @@ impl ArrayDecoder for StructArrayDecoder {
             .map(|(d, pos)| d.decode(tape, &pos))
             .collect::<Result<Vec<_>, ArrowError>>()?;
 
-        // Sanity check
-        child_data
-            .iter()
-            .for_each(|x| assert_eq!(x.len(), pos.len()));
-
         let nulls = nulls
             .as_mut()
             .map(|x| NullBuffer::new(BooleanBuffer::new(x.finish(), 0, pos.len())));
 
+        for (c, f) in child_data.iter().zip(fields) {
+            // Sanity check
+            assert_eq!(c.len(), pos.len());
+
+            if !f.is_nullable() && c.null_count() != 0 {
+                // Need to verify nulls
+                let valid = match nulls.as_ref() {
+                    Some(nulls) => {
+                        let lhs = nulls.inner().bit_chunks().iter_padded();
+                        let rhs = c.nulls().unwrap().inner().bit_chunks().iter_padded();
+                        lhs.zip(rhs).all(|(l, r)| (l & !r) == 0)
+                    }
+                    None => false,
+                };
+
+                if !valid {
+                    return Err(ArrowError::JsonError(format!("Encountered unmasked nulls in non-nullable StructArray child: {f}")));
+                }
+            }
+        }
+
         let data = ArrayDataBuilder::new(self.data_type.clone())
             .len(pos.len())
             .nulls(nulls)

Reply via email to