This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 79bda7d36 Handle non-contiguous type_ids in UnionArray (#3653) (#3654)
79bda7d36 is described below

commit 79bda7d361579cda88fec2eb9b8793ad7f653442
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Fri Feb 3 11:38:09 2023 +0000

    Handle non-contiguous type_ids in UnionArray (#3653) (#3654)
---
 arrow-array/src/array/union_array.rs | 111 ++++++++++++++++++++++++++++++-----
 arrow-ipc/src/writer.rs              |   9 +--
 2 files changed, 98 insertions(+), 22 deletions(-)

diff --git a/arrow-array/src/array/union_array.rs b/arrow-array/src/array/union_array.rs
index 5870952d7..f215fb0de 100644
--- a/arrow-array/src/array/union_array.rs
+++ b/arrow-array/src/array/union_array.rs
@@ -107,7 +107,7 @@ use std::any::Any;
 #[derive(Clone)]
 pub struct UnionArray {
     data: ArrayData,
-    boxed_fields: Vec<ArrayRef>,
+    boxed_fields: Vec<Option<ArrayRef>>,
 }
 
 impl UnionArray {
@@ -229,9 +229,8 @@ impl UnionArray {
     /// Panics if the `type_id` provided is less than zero or greater than the number of types
     /// in the `Union`.
     pub fn child(&self, type_id: i8) -> &ArrayRef {
-        assert!(0 <= type_id);
-        assert!((type_id as usize) < self.boxed_fields.len());
-        &self.boxed_fields[type_id as usize]
+        let boxed = &self.boxed_fields[type_id as usize];
+        boxed.as_ref().expect("invalid type id")
     }
 
     /// Returns the `type_id` for the array slot at `index`.
@@ -264,8 +263,8 @@ impl UnionArray {
     pub fn value(&self, i: usize) -> ArrayRef {
         let type_id = self.type_id(i);
         let value_offset = self.value_offset(i) as usize;
-        let child_data = self.boxed_fields[type_id as usize].clone();
-        child_data.slice(value_offset, 1)
+        let child = self.child(type_id);
+        child.slice(value_offset, 1)
     }
 
     /// Returns the names of the types in the union.
@@ -290,9 +289,14 @@ impl UnionArray {
 
 impl From<ArrayData> for UnionArray {
     fn from(data: ArrayData) -> Self {
-        let mut boxed_fields = vec![];
-        for cd in data.child_data() {
-            boxed_fields.push(make_array(cd.clone()));
+        let field_ids = match data.data_type() {
+            DataType::Union(_, ids, _) => ids,
+            d => panic!("UnionArray expected ArrayData with type Union got {d}"),
+        };
+        let max_id = field_ids.iter().copied().max().unwrap_or_default() as usize;
+        let mut boxed_fields = vec![None; max_id + 1];
+        for (cd, field_id) in data.child_data().iter().zip(field_ids) {
+            boxed_fields[*field_id as usize] = Some(make_array(cd.clone()));
         }
         Self { data, boxed_fields }
     }
@@ -348,21 +352,27 @@ impl std::fmt::Debug for UnionArray {
         writeln!(f, "-- type id buffer:")?;
         writeln!(f, "{:?}", self.data().buffers()[0])?;
 
-        if self.is_dense() {
+        let (fields, ids, mode) = match self.data_type() {
+            DataType::Union(f, ids, mode) => (f, ids, mode),
+            _ => unreachable!(),
+        };
+
+        if mode == &UnionMode::Dense {
             writeln!(f, "-- offsets buffer:")?;
             writeln!(f, "{:?}", self.data().buffers()[1])?;
         }
 
-        for (child_index, name) in self.type_names().iter().enumerate() {
-            let column = &self.boxed_fields[child_index];
+        assert_eq!(fields.len(), ids.len());
+        for (field, type_id) in fields.iter().zip(ids) {
+            let child = self.child(*type_id);
             writeln!(
                 f,
                 "-- child {}: \"{}\" ({:?})",
-                child_index,
-                *name,
-                column.data_type()
+                type_id,
+                field.name(),
+                field.data_type()
             )?;
-            std::fmt::Debug::fmt(column, f)?;
+            std::fmt::Debug::fmt(child, f)?;
             writeln!(f)?;
         }
         writeln!(f, "]")
@@ -374,6 +384,7 @@ mod tests {
     use super::*;
 
     use crate::builder::UnionBuilder;
+    use crate::cast::{as_primitive_array, as_string_array};
     use crate::types::{Float32Type, Float64Type, Int32Type, Int64Type};
     use crate::RecordBatch;
     use crate::{Float64Array, Int32Array, Int64Array, StringArray};
@@ -1017,4 +1028,72 @@ mod tests {
         let record_batch_slice = record_batch.slice(1, 3);
         test_slice_union(record_batch_slice);
     }
+
+    #[test]
+    fn test_custom_type_ids() {
+        let data_type = DataType::Union(
+            vec![
+                Field::new("strings", DataType::Utf8, false),
+                Field::new("integers", DataType::Int32, false),
+                Field::new("floats", DataType::Float64, false),
+            ],
+            vec![8, 4, 9],
+            UnionMode::Dense,
+        );
+
+        let string_array = StringArray::from(vec!["foo", "bar", "baz"]);
+        let int_array = Int32Array::from(vec![5, 6, 4]);
+        let float_array = Float64Array::from(vec![10.0]);
+
+        let type_ids = Buffer::from_iter([4_i8, 8, 4, 8, 9, 4, 8]);
+        let value_offsets = Buffer::from_iter([0_i32, 0, 1, 1, 0, 2, 2]);
+
+        let data = ArrayData::builder(data_type)
+            .len(7)
+            .buffers(vec![type_ids, value_offsets])
+            .child_data(vec![
+                string_array.into_data(),
+                int_array.into_data(),
+                float_array.into_data(),
+            ])
+            .build()
+            .unwrap();
+
+        let array = UnionArray::from(data);
+
+        let v = array.value(0);
+        assert_eq!(v.data_type(), &DataType::Int32);
+        assert_eq!(v.len(), 1);
+        assert_eq!(as_primitive_array::<Int32Type>(v.as_ref()).value(0), 5);
+
+        let v = array.value(1);
+        assert_eq!(v.data_type(), &DataType::Utf8);
+        assert_eq!(v.len(), 1);
+        assert_eq!(as_string_array(v.as_ref()).value(0), "foo");
+
+        let v = array.value(2);
+        assert_eq!(v.data_type(), &DataType::Int32);
+        assert_eq!(v.len(), 1);
+        assert_eq!(as_primitive_array::<Int32Type>(v.as_ref()).value(0), 6);
+
+        let v = array.value(3);
+        assert_eq!(v.data_type(), &DataType::Utf8);
+        assert_eq!(v.len(), 1);
+        assert_eq!(as_string_array(v.as_ref()).value(0), "bar");
+
+        let v = array.value(4);
+        assert_eq!(v.data_type(), &DataType::Float64);
+        assert_eq!(v.len(), 1);
+        assert_eq!(as_primitive_array::<Float64Type>(v.as_ref()).value(0), 10.0);
+
+        let v = array.value(5);
+        assert_eq!(v.data_type(), &DataType::Int32);
+        assert_eq!(v.len(), 1);
+        assert_eq!(as_primitive_array::<Int32Type>(v.as_ref()).value(0), 4);
+
+        let v = array.value(6);
+        assert_eq!(v.data_type(), &DataType::Utf8);
+        assert_eq!(v.len(), 1);
+        assert_eq!(as_string_array(v.as_ref()).value(0), "baz");
+    }
 }
diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs
index 1879dde08..8835cb49f 100644
--- a/arrow-ipc/src/writer.rs
+++ b/arrow-ipc/src/writer.rs
@@ -279,13 +279,10 @@ impl IpcDataGenerator {
                     write_options,
                 )?;
             }
-            DataType::Union(fields, _, _) => {
+            DataType::Union(fields, type_ids, _) => {
                 let union = as_union_array(column);
-                for (field, column) in fields
-                    .iter()
-                    .enumerate()
-                    .map(|(n, f)| (f, union.child(n as i8)))
-                {
+                for (field, type_id) in fields.iter().zip(type_ids) {
+                    let column = union.child(*type_id);
                     self.encode_dictionaries(
                         field,
                         column,

Reply via email to