This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 79bda7d36 Handle non-contiguous type_ids in UnionArray (#3653) (#3654)
79bda7d36 is described below
commit 79bda7d361579cda88fec2eb9b8793ad7f653442
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Fri Feb 3 11:38:09 2023 +0000
Handle non-contiguous type_ids in UnionArray (#3653) (#3654)
---
arrow-array/src/array/union_array.rs | 111 ++++++++++++++++++++++++++++++-----
arrow-ipc/src/writer.rs | 9 +--
2 files changed, 98 insertions(+), 22 deletions(-)
diff --git a/arrow-array/src/array/union_array.rs b/arrow-array/src/array/union_array.rs
index 5870952d7..f215fb0de 100644
--- a/arrow-array/src/array/union_array.rs
+++ b/arrow-array/src/array/union_array.rs
@@ -107,7 +107,7 @@ use std::any::Any;
#[derive(Clone)]
pub struct UnionArray {
data: ArrayData,
- boxed_fields: Vec<ArrayRef>,
+ boxed_fields: Vec<Option<ArrayRef>>,
}
impl UnionArray {
@@ -229,9 +229,8 @@ impl UnionArray {
/// Panics if the `type_id` provided is less than zero or greater than the
number of types
/// in the `Union`.
pub fn child(&self, type_id: i8) -> &ArrayRef {
- assert!(0 <= type_id);
- assert!((type_id as usize) < self.boxed_fields.len());
- &self.boxed_fields[type_id as usize]
+ let boxed = &self.boxed_fields[type_id as usize];
+ boxed.as_ref().expect("invalid type id")
}
/// Returns the `type_id` for the array slot at `index`.
@@ -264,8 +263,8 @@ impl UnionArray {
pub fn value(&self, i: usize) -> ArrayRef {
let type_id = self.type_id(i);
let value_offset = self.value_offset(i) as usize;
- let child_data = self.boxed_fields[type_id as usize].clone();
- child_data.slice(value_offset, 1)
+ let child = self.child(type_id);
+ child.slice(value_offset, 1)
}
/// Returns the names of the types in the union.
@@ -290,9 +289,14 @@ impl UnionArray {
impl From<ArrayData> for UnionArray {
fn from(data: ArrayData) -> Self {
- let mut boxed_fields = vec![];
- for cd in data.child_data() {
- boxed_fields.push(make_array(cd.clone()));
+ let field_ids = match data.data_type() {
+ DataType::Union(_, ids, _) => ids,
+ d => panic!("UnionArray expected ArrayData with type Union got
{d}"),
+ };
+ let max_id = field_ids.iter().copied().max().unwrap_or_default() as
usize;
+ let mut boxed_fields = vec![None; max_id + 1];
+ for (cd, field_id) in data.child_data().iter().zip(field_ids) {
+ boxed_fields[*field_id as usize] = Some(make_array(cd.clone()));
}
Self { data, boxed_fields }
}
@@ -348,21 +352,27 @@ impl std::fmt::Debug for UnionArray {
writeln!(f, "-- type id buffer:")?;
writeln!(f, "{:?}", self.data().buffers()[0])?;
- if self.is_dense() {
+ let (fields, ids, mode) = match self.data_type() {
+ DataType::Union(f, ids, mode) => (f, ids, mode),
+ _ => unreachable!(),
+ };
+
+ if mode == &UnionMode::Dense {
writeln!(f, "-- offsets buffer:")?;
writeln!(f, "{:?}", self.data().buffers()[1])?;
}
- for (child_index, name) in self.type_names().iter().enumerate() {
- let column = &self.boxed_fields[child_index];
+ assert_eq!(fields.len(), ids.len());
+ for (field, type_id) in fields.iter().zip(ids) {
+ let child = self.child(*type_id);
writeln!(
f,
"-- child {}: \"{}\" ({:?})",
- child_index,
- *name,
- column.data_type()
+ type_id,
+ field.name(),
+ field.data_type()
)?;
- std::fmt::Debug::fmt(column, f)?;
+ std::fmt::Debug::fmt(child, f)?;
writeln!(f)?;
}
writeln!(f, "]")
@@ -374,6 +384,7 @@ mod tests {
use super::*;
use crate::builder::UnionBuilder;
+ use crate::cast::{as_primitive_array, as_string_array};
use crate::types::{Float32Type, Float64Type, Int32Type, Int64Type};
use crate::RecordBatch;
use crate::{Float64Array, Int32Array, Int64Array, StringArray};
@@ -1017,4 +1028,72 @@ mod tests {
let record_batch_slice = record_batch.slice(1, 3);
test_slice_union(record_batch_slice);
}
+
+ #[test]
+ fn test_custom_type_ids() {
+ let data_type = DataType::Union(
+ vec![
+ Field::new("strings", DataType::Utf8, false),
+ Field::new("integers", DataType::Int32, false),
+ Field::new("floats", DataType::Float64, false),
+ ],
+ vec![8, 4, 9],
+ UnionMode::Dense,
+ );
+
+ let string_array = StringArray::from(vec!["foo", "bar", "baz"]);
+ let int_array = Int32Array::from(vec![5, 6, 4]);
+ let float_array = Float64Array::from(vec![10.0]);
+
+ let type_ids = Buffer::from_iter([4_i8, 8, 4, 8, 9, 4, 8]);
+ let value_offsets = Buffer::from_iter([0_i32, 0, 1, 1, 0, 2, 2]);
+
+ let data = ArrayData::builder(data_type)
+ .len(7)
+ .buffers(vec![type_ids, value_offsets])
+ .child_data(vec![
+ string_array.into_data(),
+ int_array.into_data(),
+ float_array.into_data(),
+ ])
+ .build()
+ .unwrap();
+
+ let array = UnionArray::from(data);
+
+ let v = array.value(0);
+ assert_eq!(v.data_type(), &DataType::Int32);
+ assert_eq!(v.len(), 1);
+ assert_eq!(as_primitive_array::<Int32Type>(v.as_ref()).value(0), 5);
+
+ let v = array.value(1);
+ assert_eq!(v.data_type(), &DataType::Utf8);
+ assert_eq!(v.len(), 1);
+ assert_eq!(as_string_array(v.as_ref()).value(0), "foo");
+
+ let v = array.value(2);
+ assert_eq!(v.data_type(), &DataType::Int32);
+ assert_eq!(v.len(), 1);
+ assert_eq!(as_primitive_array::<Int32Type>(v.as_ref()).value(0), 6);
+
+ let v = array.value(3);
+ assert_eq!(v.data_type(), &DataType::Utf8);
+ assert_eq!(v.len(), 1);
+ assert_eq!(as_string_array(v.as_ref()).value(0), "bar");
+
+ let v = array.value(4);
+ assert_eq!(v.data_type(), &DataType::Float64);
+ assert_eq!(v.len(), 1);
+ assert_eq!(as_primitive_array::<Float64Type>(v.as_ref()).value(0),
10.0);
+
+ let v = array.value(5);
+ assert_eq!(v.data_type(), &DataType::Int32);
+ assert_eq!(v.len(), 1);
+ assert_eq!(as_primitive_array::<Int32Type>(v.as_ref()).value(0), 4);
+
+ let v = array.value(6);
+ assert_eq!(v.data_type(), &DataType::Utf8);
+ assert_eq!(v.len(), 1);
+ assert_eq!(as_string_array(v.as_ref()).value(0), "baz");
+ }
}
diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs
index 1879dde08..8835cb49f 100644
--- a/arrow-ipc/src/writer.rs
+++ b/arrow-ipc/src/writer.rs
@@ -279,13 +279,10 @@ impl IpcDataGenerator {
write_options,
)?;
}
- DataType::Union(fields, _, _) => {
+ DataType::Union(fields, type_ids, _) => {
let union = as_union_array(column);
- for (field, column) in fields
- .iter()
- .enumerate()
- .map(|(n, f)| (f, union.child(n as i8)))
- {
+ for (field, type_id) in fields.iter().zip(type_ids) {
+ let column = union.child(*type_id);
self.encode_dictionaries(
field,
column,