This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new c203785ca39 Add `UnionArray::into_parts` (#5585)
c203785ca39 is described below
commit c203785ca398b879960bffbd30b988c9728b7c23
Author: Matthijs Brobbel <[email protected]>
AuthorDate: Tue Apr 9 11:42:44 2024 +0200
Add `UnionArray::into_parts` (#5585)
* Add `UnionArray::into_parts`
* Return `UnionFields` and `UnionMode` instead of `DataType`
* Add `into_parts` test with custom type ids
* Change `into_parts` output to better match `try_new`
* Remove UnionMode
---------
Co-authored-by: Raphael Taylor-Davies <[email protected]>
---
arrow-array/src/array/union_array.rs | 166 +++++++++++++++++++++++++++++++++++
1 file changed, 166 insertions(+)
diff --git a/arrow-array/src/array/union_array.rs b/arrow-array/src/array/union_array.rs
index e3e63724753..22d4cf90a09 100644
--- a/arrow-array/src/array/union_array.rs
+++ b/arrow-array/src/array/union_array.rs
@@ -23,6 +23,7 @@ use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode};
/// Contains the `UnionArray` type.
///
use std::any::Any;
+use std::collections::HashMap;
use std::sync::Arc;
/// An array of [values of varying types](https://arrow.apache.org/docs/format/Columnar.html#union-layout)
@@ -319,6 +320,70 @@ impl UnionArray {
fields,
}
}
+
+    /// Deconstruct this array into its constituent parts
+    ///
+    /// The returned tuple contains, in order:
+    ///
+    /// 1. the buffer of per-element type ids,
+    /// 2. the buffer of per-element offsets, if the array has one,
+    /// 3. the type ids of the populated child slots, in ascending order,
+    /// 4. the `(Field, ArrayRef)` pair for each of those type ids, in the
+    ///    same order.
+    ///
+    /// The output is shaped so that it can be fed back into
+    /// [`UnionArray::try_new`], as shown below.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// # use arrow_array::array::UnionArray;
+    /// # use arrow_array::types::Int32Type;
+    /// # use arrow_array::builder::UnionBuilder;
+    /// # use arrow_buffer::ScalarBuffer;
+    /// # fn main() -> Result<(), arrow_schema::ArrowError> {
+    /// let mut builder = UnionBuilder::new_dense();
+    /// builder.append::<Int32Type>("a", 1).unwrap();
+    /// let union_array = builder.build()?;
+    ///
+    /// // Deconstruct into parts
+    /// let (type_ids, offsets, field_type_ids, fields) = union_array.into_parts();
+    ///
+    /// // Reconstruct from parts
+    /// let union_array = UnionArray::try_new(
+    ///     &field_type_ids,
+    ///     type_ids.into_inner(),
+    ///     offsets.map(ScalarBuffer::into_inner),
+    ///     fields,
+    /// )?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    // The 4-tuple return type mirrors the arguments of `try_new`; a dedicated
+    // struct would not make it clearer, hence the lint is silenced.
+    #[allow(clippy::type_complexity)]
+    pub fn into_parts(
+        self,
+    ) -> (
+        ScalarBuffer<i8>,
+        Option<ScalarBuffer<i32>>,
+        Vec<i8>,
+        Vec<(Field, ArrayRef)>,
+    ) {
+        let Self {
+            data_type,
+            type_ids,
+            offsets,
+            fields,
+        } = self;
+        match data_type {
+            DataType::Union(union_fields, _) => {
+                // Index the schema fields by type id so each populated child
+                // slot can be paired with its `Field` below.
+                let union_fields = union_fields.iter().collect::<HashMap<_, _>>();
+                let (field_type_ids, fields) = fields
+                    .into_iter()
+                    .enumerate()
+                    .filter_map(|(type_id, array_ref)| {
+                        // Slots without a child array are skipped.
+                        let array_ref = array_ref?;
+                        // The position in `fields` is the type id; it is
+                        // looked up as an `i8` key in `union_fields`.
+                        let type_id = type_id as i8;
+                        // Clone the `Field` straight out of the shared
+                        // reference; no extra `Arc` clone is needed.
+                        let field = union_fields[&type_id].as_ref().clone();
+                        Some((type_id, (field, array_ref)))
+                    })
+                    .unzip();
+                (type_ids, offsets, field_type_ids, fields)
+            }
+            _ => unreachable!("UnionArray data type must be DataType::Union"),
+        }
+    }
}
impl From<ArrayData> for UnionArray {
@@ -505,6 +570,7 @@ impl std::fmt::Debug for UnionArray {
mod tests {
use super::*;
+ use crate::array::Int8Type;
use crate::builder::UnionBuilder;
use crate::cast::AsArray;
use crate::types::{Float32Type, Float64Type, Int32Type, Int64Type};
@@ -1201,4 +1267,104 @@ mod tests {
assert_eq!(v.len(), 1);
assert_eq!(v.as_string::<i32>().value(0), "baz");
}
+
+    #[test]
+    fn into_parts() {
+        // Dense union: deconstruct, check every part, then rebuild.
+        let mut dense_builder = UnionBuilder::new_dense();
+        dense_builder.append::<Int32Type>("a", 1).unwrap();
+        dense_builder.append::<Int8Type>("b", 2).unwrap();
+        dense_builder.append::<Int32Type>("a", 3).unwrap();
+        let dense_union = dense_builder.build().unwrap();
+
+        let expected_fields = vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Int8, false),
+        ];
+        let (type_ids, offsets, field_type_ids, fields) = dense_union.into_parts();
+        assert_eq!(field_type_ids, [0, 1]);
+        let actual_fields: Vec<Field> = fields.iter().map(|(field, _)| field.clone()).collect();
+        assert_eq!(expected_fields, actual_fields);
+        assert_eq!(type_ids, [0, 1, 0]);
+        assert!(offsets.is_some());
+        assert_eq!(offsets.as_ref().unwrap(), &[0, 0, 1]);
+
+        // The parts must be accepted by `try_new` unchanged.
+        let rebuilt = UnionArray::try_new(
+            &field_type_ids,
+            type_ids.into_inner(),
+            offsets.map(ScalarBuffer::into_inner),
+            fields,
+        )
+        .unwrap();
+        assert_eq!(rebuilt.len(), 3);
+
+        // Sparse union: same values, but no offsets buffer is expected.
+        let mut sparse_builder = UnionBuilder::new_sparse();
+        sparse_builder.append::<Int32Type>("a", 1).unwrap();
+        sparse_builder.append::<Int8Type>("b", 2).unwrap();
+        sparse_builder.append::<Int32Type>("a", 3).unwrap();
+        let sparse_union = sparse_builder.build().unwrap();
+
+        let (type_ids, offsets, field_type_ids, fields) = sparse_union.into_parts();
+        assert_eq!(type_ids, [0, 1, 0]);
+        assert!(offsets.is_none());
+
+        let rebuilt = UnionArray::try_new(
+            &field_type_ids,
+            type_ids.into_inner(),
+            offsets.map(ScalarBuffer::into_inner),
+            fields,
+        )
+        .unwrap();
+        assert_eq!(rebuilt.len(), 3);
+    }
+
+    #[test]
+    fn into_parts_custom_type_ids() {
+        // Type ids are deliberately non-contiguous and not sorted.
+        let mut set_field_type_ids: [i8; 3] = [8, 4, 9];
+        let union_fields = UnionFields::new(
+            set_field_type_ids,
+            [
+                Field::new("strings", DataType::Utf8, false),
+                Field::new("integers", DataType::Int32, false),
+                Field::new("floats", DataType::Float64, false),
+            ],
+        );
+        let data_type = DataType::Union(union_fields, UnionMode::Dense);
+
+        // Child arrays, listed in the same order as the fields above.
+        let children = vec![
+            StringArray::from(vec!["foo", "bar", "baz"]).into_data(),
+            Int32Array::from(vec![5, 6, 4]).into_data(),
+            Float64Array::from(vec![10.0]).into_data(),
+        ];
+        let type_id_buffer = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]);
+        let offset_buffer = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]);
+        let data = ArrayData::builder(data_type)
+            .len(7)
+            .buffers(vec![type_id_buffer, offset_buffer])
+            .child_data(children)
+            .build()
+            .unwrap();
+        let array = UnionArray::from(data);
+
+        let (type_ids, offsets, field_type_ids, fields) = array.into_parts();
+        // The returned type ids are compared against the sorted input ids.
+        set_field_type_ids.sort();
+        assert_eq!(field_type_ids, set_field_type_ids);
+
+        // The parts must round-trip through `try_new`.
+        let rebuilt = UnionArray::try_new(
+            &field_type_ids,
+            type_ids.into_inner(),
+            offsets.map(ScalarBuffer::into_inner),
+            fields,
+        )
+        .unwrap();
+        assert_eq!(rebuilt.len(), 7);
+    }
}