This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new a579f9d  feat(ipc): support for reading union arrays through IPC (#1140)
a579f9d is described below

commit a579f9d7cad1d76d06e24836e29e59ab9f581581
Author: Helgi Kristvin Sigurbjarnarson <[email protected]>
AuthorDate: Thu Jan 6 14:46:16 2022 -0800

    feat(ipc): support for reading union arrays through IPC (#1140)
---
 arrow/src/ipc/reader.rs | 124 +++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 108 insertions(+), 16 deletions(-)

diff --git a/arrow/src/ipc/reader.rs b/arrow/src/ipc/reader.rs
index 1d9f36d..27633df 100644
--- a/arrow/src/ipc/reader.rs
+++ b/arrow/src/ipc/reader.rs
@@ -27,7 +27,7 @@ use std::sync::Arc;
 use crate::array::*;
 use crate::buffer::Buffer;
 use crate::compute::cast;
-use crate::datatypes::{DataType, Field, IntervalUnit, Schema, SchemaRef};
+use crate::datatypes::{DataType, Field, IntervalUnit, Schema, SchemaRef, UnionMode};
 use crate::error::{ArrowError, Result};
 use crate::ipc;
 use crate::record_batch::{RecordBatch, RecordBatchReader};
@@ -60,7 +60,7 @@ fn create_array(
     dictionaries: &[Option<ArrayRef>],
     mut node_index: usize,
     mut buffer_index: usize,
-) -> (ArrayRef, usize, usize) {
+) -> Result<(ArrayRef, usize, usize)> {
     use DataType::*;
     let array = match data_type {
         Utf8 | Binary | LargeBinary | LargeUtf8 => {
@@ -105,7 +105,7 @@ fn create_array(
                 dictionaries,
                 node_index,
                 buffer_index,
-            );
+            )?;
             node_index = triple.1;
             buffer_index = triple.2;
 
@@ -127,7 +127,7 @@ fn create_array(
                 dictionaries,
                 node_index,
                 buffer_index,
-            );
+            )?;
             node_index = triple.1;
             buffer_index = triple.2;
 
@@ -152,7 +152,7 @@ fn create_array(
                     dictionaries,
                     node_index,
                     buffer_index,
-                );
+                )?;
                 node_index = triple.1;
                 buffer_index = triple.2;
                 struct_arrays.push((struct_field.clone(), triple.0));
@@ -184,6 +184,55 @@ fn create_array(
                 value_array,
             )
         }
+        Union(fields, mode) => {
+            let union_node = nodes[node_index];
+            node_index += 1;
+
+            let len = union_node.length() as usize;
+
+            let null_buffer: Buffer = read_buffer(&buffers[buffer_index], data);
+            let type_ids: Buffer =
+                read_buffer(&buffers[buffer_index + 1], data)[..len].into();
+
+            buffer_index += 2;
+
+            let value_offsets = match mode {
+                UnionMode::Dense => {
+                    let buffer = read_buffer(&buffers[buffer_index], data);
+                    buffer_index += 1;
+                    Some(buffer[..len * 4].into())
+                }
+                UnionMode::Sparse => None,
+            };
+
+            let mut children = vec![];
+
+            for field in fields {
+                let triple = create_array(
+                    nodes,
+                    field.data_type(),
+                    data,
+                    buffers,
+                    dictionaries,
+                    node_index,
+                    buffer_index,
+                )?;
+
+                node_index = triple.1;
+                buffer_index = triple.2;
+
+                children.push((field.clone(), triple.0));
+            }
+
+            let array = UnionArray::try_new(
+                type_ids,
+                value_offsets,
+                children,
+                Some(null_buffer),
+            )?;
+
+            Arc::new(array)
+        }
         Null => {
             let length = nodes[node_index].length() as usize;
             let data = ArrayData::builder(data_type.clone())
@@ -209,7 +258,7 @@ fn create_array(
             array
         }
     };
-    (array, node_index, buffer_index)
+    Ok((array, node_index, buffer_index))
 }
 
 /// Reads the correct number of buffers based on data type and null_count, and creates a
@@ -438,7 +487,7 @@ pub fn read_record_batch(
             dictionaries,
             node_index,
             buffer_index,
-        );
+        )?;
         node_index = triple.1;
         buffer_index = triple.2;
         arrays.push(triple.0);
@@ -1165,6 +1214,19 @@ mod tests {
         })
     }
 
+    fn roundtrip_ipc(rb: &RecordBatch) -> RecordBatch {
+        let mut buf = Vec::new();
+        let mut writer =
+            ipc::writer::FileWriter::try_new(&mut buf, &rb.schema()).unwrap();
+        writer.write(rb).unwrap();
+        writer.finish().unwrap();
+        drop(writer);
+
+        let mut reader =
+            ipc::reader::FileReader::try_new(std::io::Cursor::new(buf)).unwrap();
+        reader.next().unwrap().unwrap()
+    }
+
     #[test]
     fn test_roundtrip_nested_dict() {
         let inner: DictionaryArray<datatypes::Int32Type> =
@@ -1183,18 +1245,48 @@ mod tests {
             false,
         )]));
 
-        let batch = RecordBatch::try_new(schema.clone(), vec![struct_array]).unwrap();
+        let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap();
 
-        let mut buf = Vec::new();
-        let mut writer = ipc::writer::FileWriter::try_new(&mut buf, &schema).unwrap();
-        writer.write(&batch).unwrap();
-        writer.finish().unwrap();
-        drop(writer);
+        assert_eq!(batch, roundtrip_ipc(&batch));
+    }
+
+    fn check_union_with_builder(mut builder: UnionBuilder) {
+        builder.append::<datatypes::Int32Type>("a", 1).unwrap();
+        builder.append_null().unwrap();
+        builder.append::<datatypes::Float64Type>("c", 3.0).unwrap();
+        builder.append::<datatypes::Int32Type>("a", 4).unwrap();
+        builder.append::<datatypes::Int64Type>("d", 11).unwrap();
+        let union = builder.build().unwrap();
+
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "union",
+            union.data_type().clone(),
+            false,
+        )]));
+
+        let union_array = Arc::new(union) as ArrayRef;
 
-        let reader = ipc::reader::FileReader::try_new(std::io::Cursor::new(buf)).unwrap();
-        let batch2: std::result::Result<Vec<_>, _> = reader.collect();
+        let rb = RecordBatch::try_new(schema, vec![union_array]).unwrap();
+        let rb2 = roundtrip_ipc(&rb);
+        // TODO: equality not yet implemented for union, so we check that the length of the array is
+        // the same and that all of the buffers are the same instead.
+        assert_eq!(rb.schema(), rb2.schema());
+        assert_eq!(rb.num_columns(), rb2.num_columns());
+        assert_eq!(rb.num_rows(), rb2.num_rows());
+        let union1 = rb.column(0);
+        let union2 = rb2.column(0);
 
-        assert_eq!(batch, batch2.unwrap()[0]);
+        assert_eq!(union1.data().buffers(), union2.data().buffers());
+    }
+
+    #[test]
+    fn test_roundtrip_dense_union() {
+        check_union_with_builder(UnionBuilder::new_dense(6));
+    }
+
+    #[test]
+    fn test_roundtrip_sparse_union() {
+        check_union_with_builder(UnionBuilder::new_sparse(6));
     }
 
     /// Read gzipped JSON file

Reply via email to