This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new e5a167695 Add UnionFields (#3955) (#3981)
e5a167695 is described below
commit e5a1676950ab5c04b0a74953ec5418da67cedb45
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Thu Mar 30 14:58:21 2023 +0100
Add UnionFields (#3955) (#3981)
* Add UnionFields (#3955)
* Fix array_cast
* Review feedback
* Clippy
---
arrow-array/src/array/mod.rs | 16 +++--
arrow-array/src/array/union_array.rs | 53 ++++++++-------
arrow-array/src/record_batch.rs | 2 +-
arrow-cast/src/display.rs | 14 ++--
arrow-cast/src/pretty.rs | 42 +++++++-----
arrow-data/src/data/mod.rs | 25 +++----
arrow-data/src/equal/mod.rs | 2 +-
arrow-data/src/equal/union.rs | 26 ++++----
arrow-data/src/equal/utils.rs | 2 +-
arrow-data/src/transform/mod.rs | 6 +-
arrow-integration-test/src/datatype.rs | 22 +++----
arrow-integration-test/src/field.rs | 29 +++++---
arrow-integration-test/src/lib.rs | 9 +--
arrow-ipc/src/convert.rs | 99 +++++++++++++++-------------
arrow-ipc/src/reader.rs | 31 +++++----
arrow-ipc/src/writer.rs | 20 +++---
arrow-schema/src/datatype.rs | 26 ++++----
arrow-schema/src/ffi.rs | 19 ++++--
arrow-schema/src/field.rs | 37 ++---------
arrow-schema/src/fields.rs | 117 ++++++++++++++++++++++++++++++++-
arrow-schema/src/schema.rs | 44 ++++++++-----
arrow/src/datatypes/mod.rs | 2 +-
arrow/src/ffi.rs | 8 +--
arrow/tests/array_cast.rs | 14 ++--
arrow/tests/array_validation.rs | 50 ++++++++------
parquet/src/arrow/arrow_writer/mod.rs | 2 +-
parquet/src/arrow/schema/mod.rs | 2 +-
27 files changed, 430 insertions(+), 289 deletions(-)
diff --git a/arrow-array/src/array/mod.rs b/arrow-array/src/array/mod.rs
index 8d20c6cb2..9a5172d0d 100644
--- a/arrow-array/src/array/mod.rs
+++ b/arrow-array/src/array/mod.rs
@@ -586,7 +586,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef {
DataType::LargeList(_) => Arc::new(LargeListArray::from(data)) as
ArrayRef,
DataType::Struct(_) => Arc::new(StructArray::from(data)) as ArrayRef,
DataType::Map(_, _) => Arc::new(MapArray::from(data)) as ArrayRef,
- DataType::Union(_, _, _) => Arc::new(UnionArray::from(data)) as
ArrayRef,
+ DataType::Union(_, _) => Arc::new(UnionArray::from(data)) as ArrayRef,
DataType::FixedSizeList(_, _) => {
Arc::new(FixedSizeListArray::from(data)) as ArrayRef
}
@@ -740,7 +740,7 @@ mod tests {
use crate::cast::{as_union_array, downcast_array};
use crate::downcast_run_array;
use arrow_buffer::{Buffer, MutableBuffer};
- use arrow_schema::{Field, Fields, UnionMode};
+ use arrow_schema::{Field, Fields, UnionFields, UnionMode};
#[test]
fn test_empty_primitive() {
@@ -874,11 +874,13 @@ mod tests {
fn test_null_union() {
for mode in [UnionMode::Sparse, UnionMode::Dense] {
let data_type = DataType::Union(
- vec![
- Field::new("foo", DataType::Int32, true),
- Field::new("bar", DataType::Int64, true),
- ],
- vec![2, 1],
+ UnionFields::new(
+ vec![2, 1],
+ vec![
+ Field::new("foo", DataType::Int32, true),
+ Field::new("bar", DataType::Int64, true),
+ ],
+ ),
mode,
);
let array = new_null_array(&data_type, 4);
diff --git a/arrow-array/src/array/union_array.rs
b/arrow-array/src/array/union_array.rs
index 00ad94111..335b6b14f 100644
--- a/arrow-array/src/array/union_array.rs
+++ b/arrow-array/src/array/union_array.rs
@@ -19,7 +19,7 @@ use crate::{make_array, Array, ArrayRef};
use arrow_buffer::buffer::NullBuffer;
use arrow_buffer::{Buffer, ScalarBuffer};
use arrow_data::ArrayData;
-use arrow_schema::{ArrowError, DataType, Field, UnionMode};
+use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode};
/// Contains the `UnionArray` type.
///
use std::any::Any;
@@ -145,8 +145,7 @@ impl UnionArray {
value_offsets: Option<Buffer>,
child_arrays: Vec<(Field, ArrayRef)>,
) -> Self {
- let (field_types, field_values): (Vec<_>, Vec<_>) =
- child_arrays.into_iter().unzip();
+ let (fields, field_values): (Vec<_>, Vec<_>) =
child_arrays.into_iter().unzip();
let len = type_ids.len();
let mode = if value_offsets.is_some() {
@@ -156,8 +155,7 @@ impl UnionArray {
};
let builder = ArrayData::builder(DataType::Union(
- field_types,
- Vec::from(field_type_ids),
+ UnionFields::new(field_type_ids.iter().copied(), fields),
mode,
))
.add_buffer(type_ids)
@@ -282,9 +280,9 @@ impl UnionArray {
/// Returns the names of the types in the union.
pub fn type_names(&self) -> Vec<&str> {
match self.data.data_type() {
- DataType::Union(fields, _, _) => fields
+ DataType::Union(fields, _) => fields
.iter()
- .map(|f| f.name().as_str())
+ .map(|(_, f)| f.name().as_str())
.collect::<Vec<&str>>(),
_ => unreachable!("Union array's data type is not a union!"),
}
@@ -293,7 +291,7 @@ impl UnionArray {
/// Returns whether the `UnionArray` is dense (or sparse if `false`).
fn is_dense(&self) -> bool {
match self.data.data_type() {
- DataType::Union(_, _, mode) => mode == &UnionMode::Dense,
+ DataType::Union(_, mode) => mode == &UnionMode::Dense,
_ => unreachable!("Union array's data type is not a union!"),
}
}
@@ -307,8 +305,8 @@ impl UnionArray {
impl From<ArrayData> for UnionArray {
fn from(data: ArrayData) -> Self {
- let (field_ids, mode) = match data.data_type() {
- DataType::Union(_, ids, mode) => (ids, *mode),
+ let (fields, mode) = match data.data_type() {
+ DataType::Union(fields, mode) => (fields, *mode),
d => panic!("UnionArray expected ArrayData with type Union got
{d}"),
};
let (type_ids, offsets) = match mode {
@@ -326,10 +324,10 @@ impl From<ArrayData> for UnionArray {
),
};
- let max_id = field_ids.iter().copied().max().unwrap_or_default() as
usize;
+ let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default()
as usize;
let mut boxed_fields = vec![None; max_id + 1];
- for (cd, field_id) in data.child_data().iter().zip(field_ids) {
- boxed_fields[*field_id as usize] = Some(make_array(cd.clone()));
+ for (cd, (field_id, _)) in data.child_data().iter().zip(fields.iter())
{
+ boxed_fields[field_id as usize] = Some(make_array(cd.clone()));
}
Self {
data,
@@ -402,19 +400,18 @@ impl std::fmt::Debug for UnionArray {
writeln!(f, "-- type id buffer:")?;
writeln!(f, "{:?}", self.type_ids)?;
- let (fields, ids) = match self.data_type() {
- DataType::Union(f, ids, _) => (f, ids),
- _ => unreachable!(),
- };
-
if let Some(offsets) = &self.offsets {
writeln!(f, "-- offsets buffer:")?;
writeln!(f, "{:?}", offsets)?;
}
- assert_eq!(fields.len(), ids.len());
- for (field, type_id) in fields.iter().zip(ids) {
- let child = self.child(*type_id);
+ let fields = match self.data_type() {
+ DataType::Union(fields, _) => fields,
+ _ => unreachable!(),
+ };
+
+ for (type_id, field) in fields.iter() {
+ let child = self.child(type_id);
writeln!(
f,
"-- child {}: \"{}\" ({:?})",
@@ -1058,12 +1055,14 @@ mod tests {
#[test]
fn test_custom_type_ids() {
let data_type = DataType::Union(
- vec![
- Field::new("strings", DataType::Utf8, false),
- Field::new("integers", DataType::Int32, false),
- Field::new("floats", DataType::Float64, false),
- ],
- vec![8, 4, 9],
+ UnionFields::new(
+ vec![8, 4, 9],
+ vec![
+ Field::new("strings", DataType::Utf8, false),
+ Field::new("integers", DataType::Int32, false),
+ Field::new("floats", DataType::Float64, false),
+ ],
+ ),
UnionMode::Dense,
);
diff --git a/arrow-array/src/record_batch.rs b/arrow-array/src/record_batch.rs
index 2754d04bf..17b1f04e8 100644
--- a/arrow-array/src/record_batch.rs
+++ b/arrow-array/src/record_batch.rs
@@ -590,7 +590,7 @@ mod tests {
let record_batch =
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a),
Arc::new(b)])
.unwrap();
- assert_eq!(record_batch.get_array_memory_size(), 628);
+ assert_eq!(record_batch.get_array_memory_size(), 564);
}
fn check_batch(record_batch: RecordBatch, num_rows: usize) {
diff --git a/arrow-cast/src/display.rs b/arrow-cast/src/display.rs
index d10903697..0bca9ce65 100644
--- a/arrow-cast/src/display.rs
+++ b/arrow-cast/src/display.rs
@@ -278,7 +278,7 @@ fn make_formatter<'a>(
}
DataType::Struct(_) => array_format(as_struct_array(array), options),
DataType::Map(_, _) => array_format(as_map_array(array), options),
- DataType::Union(_, _, _) => array_format(as_union_array(array),
options),
+ DataType::Union(_, _) => array_format(as_union_array(array), options),
d => Err(ArrowError::NotYetImplemented(format!("formatting {d} is not
yet supported"))),
}
}
@@ -801,16 +801,16 @@ impl<'a> DisplayIndexState<'a> for &'a UnionArray {
);
fn prepare(&self, options: &FormatOptions<'a>) -> Result<Self::State,
ArrowError> {
- let (fields, type_ids, mode) = match (*self).data_type() {
- DataType::Union(fields, type_ids, mode) => (fields, type_ids,
mode),
+ let (fields, mode) = match (*self).data_type() {
+ DataType::Union(fields, mode) => (fields, mode),
_ => unreachable!(),
};
- let max_id = type_ids.iter().copied().max().unwrap_or_default() as
usize;
+ let max_id = fields.iter().map(|(id, _)| id).max().unwrap_or_default()
as usize;
let mut out: Vec<Option<FieldDisplay>> = (0..max_id + 1).map(|_|
None).collect();
- for (i, field) in type_ids.iter().zip(fields) {
- let formatter = make_formatter(self.child(*i).as_ref(), options)?;
- out[*i as usize] = Some((field.name().as_str(), formatter))
+ for (i, field) in fields.iter() {
+ let formatter = make_formatter(self.child(i).as_ref(), options)?;
+ out[i as usize] = Some((field.name().as_str(), formatter))
}
Ok((out, *mode))
}
diff --git a/arrow-cast/src/pretty.rs b/arrow-cast/src/pretty.rs
index ffa5af82d..818e9d3c0 100644
--- a/arrow-cast/src/pretty.rs
+++ b/arrow-cast/src/pretty.rs
@@ -703,11 +703,13 @@ mod tests {
let schema = Schema::new(vec![Field::new(
"Teamsters",
DataType::Union(
- vec![
- Field::new("a", DataType::Int32, false),
- Field::new("b", DataType::Float64, false),
- ],
- vec![0, 1],
+ UnionFields::new(
+ vec![0, 1],
+ vec![
+ Field::new("a", DataType::Int32, false),
+ Field::new("b", DataType::Float64, false),
+ ],
+ ),
UnionMode::Dense,
),
false,
@@ -743,11 +745,13 @@ mod tests {
let schema = Schema::new(vec![Field::new(
"Teamsters",
DataType::Union(
- vec![
- Field::new("a", DataType::Int32, false),
- Field::new("b", DataType::Float64, false),
- ],
- vec![0, 1],
+ UnionFields::new(
+ vec![0, 1],
+ vec![
+ Field::new("a", DataType::Int32, false),
+ Field::new("b", DataType::Float64, false),
+ ],
+ ),
UnionMode::Sparse,
),
false,
@@ -785,11 +789,13 @@ mod tests {
let inner_field = Field::new(
"European Union",
DataType::Union(
- vec![
- Field::new("b", DataType::Int32, false),
- Field::new("c", DataType::Float64, false),
- ],
- vec![0, 1],
+ UnionFields::new(
+ vec![0, 1],
+ vec![
+ Field::new("b", DataType::Int32, false),
+ Field::new("c", DataType::Float64, false),
+ ],
+ ),
UnionMode::Dense,
),
false,
@@ -809,8 +815,10 @@ mod tests {
let schema = Schema::new(vec![Field::new(
"Teamsters",
DataType::Union(
- vec![Field::new("a", DataType::Int32, true), inner_field],
- vec![0, 1],
+ UnionFields::new(
+ vec![0, 1],
+ vec![Field::new("a", DataType::Int32, true), inner_field],
+ ),
UnionMode::Sparse,
),
false,
diff --git a/arrow-data/src/data/mod.rs b/arrow-data/src/data/mod.rs
index c47c83663..581d4a10c 100644
--- a/arrow-data/src/data/mod.rs
+++ b/arrow-data/src/data/mod.rs
@@ -136,7 +136,7 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity:
usize) -> [MutableBuff
MutableBuffer::new(capacity * mem::size_of::<u8>()),
empty_buffer,
],
- DataType::Union(_, _, mode) => {
+ DataType::Union(_, mode) => {
let type_ids = MutableBuffer::new(capacity * mem::size_of::<i8>());
match mode {
UnionMode::Sparse => [type_ids, empty_buffer],
@@ -162,7 +162,7 @@ pub(crate) fn into_buffers(
| DataType::Binary
| DataType::LargeUtf8
| DataType::LargeBinary => vec![buffer1.into(), buffer2.into()],
- DataType::Union(_, _, mode) => {
+ DataType::Union(_, mode) => {
match mode {
// Based on Union's DataTypeLayout
UnionMode::Sparse => vec![buffer1.into()],
@@ -621,8 +621,9 @@ impl ArrayData {
vec![ArrayData::new_empty(v.as_ref())],
true,
),
- DataType::Union(f, i, mode) => {
- let ids =
Buffer::from_iter(std::iter::repeat(i[0]).take(len));
+ DataType::Union(f, mode) => {
+ let (id, _) = f.iter().next().unwrap();
+ let ids =
Buffer::from_iter(std::iter::repeat(id).take(len));
let buffers = match mode {
UnionMode::Sparse => vec![ids],
UnionMode::Dense => {
@@ -634,7 +635,7 @@ impl ArrayData {
let children = f
.iter()
.enumerate()
- .map(|(idx, f)| match idx {
+ .map(|(idx, (_, f))| match idx {
0 => Self::new_null(f.data_type(), len),
_ => Self::new_empty(f.data_type()),
})
@@ -986,10 +987,10 @@ impl ArrayData {
}
Ok(())
}
- DataType::Union(fields, _, mode) => {
+ DataType::Union(fields, mode) => {
self.validate_num_child_data(fields.len())?;
- for (i, field) in fields.iter().enumerate() {
+ for (i, (_, field)) in fields.iter().enumerate() {
let field_data = self.get_valid_child_data(i,
field.data_type())?;
if mode == &UnionMode::Sparse
@@ -1255,7 +1256,7 @@ impl ArrayData {
let child = &self.child_data[0];
self.validate_offsets_full::<i64>(child.len)
}
- DataType::Union(_, _, _) => {
+ DataType::Union(_, _) => {
// Validate Union Array as part of implementing new Union
semantics
// See comments in `ArrayData::validate()`
// https://github.com/apache/arrow-rs/issues/85
@@ -1568,7 +1569,7 @@ pub fn layout(data_type: &DataType) -> DataTypeLayout {
DataType::LargeList(_) =>
DataTypeLayout::new_fixed_width(size_of::<i64>()),
DataType::Struct(_) => DataTypeLayout::new_empty(), // all in child
data,
DataType::RunEndEncoded(_, _) => DataTypeLayout::new_empty(), // all
in child data,
- DataType::Union(_, _, mode) => {
+ DataType::Union(_, mode) => {
let type_ids = BufferSpec::FixedWidth {
byte_width: size_of::<i8>(),
};
@@ -1823,7 +1824,7 @@ impl From<ArrayData> for ArrayDataBuilder {
#[cfg(test)]
mod tests {
use super::*;
- use arrow_schema::Field;
+ use arrow_schema::{Field, UnionFields};
// See arrow/tests/array_data_validation.rs for test of array validation
@@ -2072,8 +2073,8 @@ mod tests {
#[test]
fn test_into_buffers() {
let data_types = vec![
- DataType::Union(vec![], vec![], UnionMode::Dense),
- DataType::Union(vec![], vec![], UnionMode::Sparse),
+ DataType::Union(UnionFields::empty(), UnionMode::Dense),
+ DataType::Union(UnionFields::empty(), UnionMode::Sparse),
];
for data_type in data_types {
diff --git a/arrow-data/src/equal/mod.rs b/arrow-data/src/equal/mod.rs
index 871a312ca..fbc868d3f 100644
--- a/arrow-data/src/equal/mod.rs
+++ b/arrow-data/src/equal/mod.rs
@@ -112,7 +112,7 @@ fn equal_values(
fixed_list_equal(lhs, rhs, lhs_start, rhs_start, len)
}
DataType::Struct(_) => struct_equal(lhs, rhs, lhs_start, rhs_start,
len),
- DataType::Union(_, _, _) => union_equal(lhs, rhs, lhs_start,
rhs_start, len),
+ DataType::Union(_, _) => union_equal(lhs, rhs, lhs_start, rhs_start,
len),
DataType::Dictionary(data_type, _) => match data_type.as_ref() {
DataType::Int8 => dictionary_equal::<i8>(lhs, rhs, lhs_start,
rhs_start, len),
DataType::Int16 => {
diff --git a/arrow-data/src/equal/union.rs b/arrow-data/src/equal/union.rs
index fdf770096..4f04bc287 100644
--- a/arrow-data/src/equal/union.rs
+++ b/arrow-data/src/equal/union.rs
@@ -16,7 +16,7 @@
// under the License.
use crate::data::ArrayData;
-use arrow_schema::{DataType, UnionMode};
+use arrow_schema::{DataType, UnionFields, UnionMode};
use super::equal_range;
@@ -28,8 +28,8 @@ fn equal_dense(
rhs_type_ids: &[i8],
lhs_offsets: &[i32],
rhs_offsets: &[i32],
- lhs_field_type_ids: &[i8],
- rhs_field_type_ids: &[i8],
+ lhs_fields: &UnionFields,
+ rhs_fields: &UnionFields,
) -> bool {
let offsets = lhs_offsets.iter().zip(rhs_offsets.iter());
@@ -38,13 +38,13 @@ fn equal_dense(
.zip(rhs_type_ids.iter())
.zip(offsets)
.all(|((l_type_id, r_type_id), (l_offset, r_offset))| {
- let lhs_child_index = lhs_field_type_ids
+ let lhs_child_index = lhs_fields
.iter()
- .position(|r| r == l_type_id)
+ .position(|(r, _)| r == *l_type_id)
.unwrap();
- let rhs_child_index = rhs_field_type_ids
+ let rhs_child_index = rhs_fields
.iter()
- .position(|r| r == r_type_id)
+ .position(|(r, _)| r == *r_type_id)
.unwrap();
let lhs_values = &lhs.child_data()[lhs_child_index];
let rhs_values = &rhs.child_data()[rhs_child_index];
@@ -89,8 +89,8 @@ pub(super) fn union_equal(
match (lhs.data_type(), rhs.data_type()) {
(
- DataType::Union(_, lhs_type_ids, UnionMode::Dense),
- DataType::Union(_, rhs_type_ids, UnionMode::Dense),
+ DataType::Union(lhs_fields, UnionMode::Dense),
+ DataType::Union(rhs_fields, UnionMode::Dense),
) => {
let lhs_offsets = lhs.buffer::<i32>(1);
let rhs_offsets = rhs.buffer::<i32>(1);
@@ -106,13 +106,13 @@ pub(super) fn union_equal(
rhs_type_id_range,
lhs_offsets_range,
rhs_offsets_range,
- lhs_type_ids,
- rhs_type_ids,
+ lhs_fields,
+ rhs_fields,
)
}
(
- DataType::Union(_, _, UnionMode::Sparse),
- DataType::Union(_, _, UnionMode::Sparse),
+ DataType::Union(_, UnionMode::Sparse),
+ DataType::Union(_, UnionMode::Sparse),
) => {
lhs_type_id_range == rhs_type_id_range
&& equal_sparse(lhs, rhs, lhs_start, rhs_start, len)
diff --git a/arrow-data/src/equal/utils.rs b/arrow-data/src/equal/utils.rs
index 6b9a7940d..fa6211542 100644
--- a/arrow-data/src/equal/utils.rs
+++ b/arrow-data/src/equal/utils.rs
@@ -59,7 +59,7 @@ pub(super) fn equal_nulls(
#[inline]
pub(super) fn base_equal(lhs: &ArrayData, rhs: &ArrayData) -> bool {
let equal_type = match (lhs.data_type(), rhs.data_type()) {
- (DataType::Union(l_fields, _, l_mode), DataType::Union(r_fields, _,
r_mode)) => {
+ (DataType::Union(l_fields, l_mode), DataType::Union(r_fields, r_mode))
=> {
l_fields == r_fields && l_mode == r_mode
}
(DataType::Map(l_field, l_sorted), DataType::Map(r_field, r_sorted))
=> {
diff --git a/arrow-data/src/transform/mod.rs b/arrow-data/src/transform/mod.rs
index 2719b96b6..ccdbaec3b 100644
--- a/arrow-data/src/transform/mod.rs
+++ b/arrow-data/src/transform/mod.rs
@@ -231,7 +231,7 @@ fn build_extend(array: &ArrayData) -> Extend {
DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array),
DataType::Float16 => primitive::build_extend::<f16>(array),
DataType::FixedSizeList(_, _) => fixed_size_list::build_extend(array),
- DataType::Union(_, _, mode) => match mode {
+ DataType::Union(_, mode) => match mode {
UnionMode::Sparse => union::build_extend_sparse(array),
UnionMode::Dense => union::build_extend_dense(array),
},
@@ -283,7 +283,7 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls {
DataType::FixedSizeBinary(_) => fixed_binary::extend_nulls,
DataType::Float16 => primitive::extend_nulls::<f16>,
DataType::FixedSizeList(_, _) => fixed_size_list::extend_nulls,
- DataType::Union(_, _, mode) => match mode {
+ DataType::Union(_, mode) => match mode {
UnionMode::Sparse => union::extend_nulls_sparse,
UnionMode::Dense => union::extend_nulls_dense,
},
@@ -501,7 +501,7 @@ impl<'a> MutableArrayData<'a> {
.collect::<Vec<_>>();
vec![MutableArrayData::new(childs, use_nulls, array_capacity)]
}
- DataType::Union(fields, _, _) => (0..fields.len())
+ DataType::Union(fields, _) => (0..fields.len())
.map(|i| {
let child_arrays = arrays
.iter()
diff --git a/arrow-integration-test/src/datatype.rs
b/arrow-integration-test/src/datatype.rs
index a08368d58..5a5dd67fc 100644
--- a/arrow-integration-test/src/datatype.rs
+++ b/arrow-integration-test/src/datatype.rs
@@ -17,6 +17,7 @@
use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit,
UnionMode};
use arrow::error::{ArrowError, Result};
+use std::sync::Arc;
/// Parse a data type from a JSON representation.
pub fn data_type_from_json(json: &serde_json::Value) -> Result<DataType> {
@@ -229,20 +230,15 @@ pub fn data_type_from_json(json: &serde_json::Value) ->
Result<DataType> {
"Unknown union mode {mode:?} for union"
)));
};
- if let Some(type_ids) = map.get("typeIds") {
- let type_ids = type_ids
- .as_array()
- .unwrap()
+ if let Some(values) = map.get("typeIds") {
+ let field = Arc::new(default_field);
+ let values = values.as_array().unwrap();
+ let fields = values
.iter()
- .map(|t| t.as_i64().unwrap() as i8)
- .collect::<Vec<_>>();
+ .map(|t| (t.as_i64().unwrap() as i8,
field.clone()))
+ .collect();
- let default_fields = type_ids
- .iter()
- .map(|_| default_field.clone())
- .collect::<Vec<_>>();
-
- Ok(DataType::Union(default_fields, type_ids,
union_mode))
+ Ok(DataType::Union(fields, union_mode))
} else {
Err(ArrowError::ParseError(
"Expecting a typeIds for union ".to_string(),
@@ -290,7 +286,7 @@ pub fn data_type_to_json(data_type: &DataType) ->
serde_json::Value {
json!({"name": "fixedsizebinary", "byteWidth": byte_width})
}
DataType::Struct(_) => json!({"name": "struct"}),
- DataType::Union(_, _, _) => json!({"name": "union"}),
+ DataType::Union(_, _) => json!({"name": "union"}),
DataType::List(_) => json!({ "name": "list"}),
DataType::LargeList(_) => json!({ "name": "largelist"}),
DataType::FixedSizeList(_, length) => {
diff --git a/arrow-integration-test/src/field.rs
b/arrow-integration-test/src/field.rs
index a60cd91c5..c714fe467 100644
--- a/arrow-integration-test/src/field.rs
+++ b/arrow-integration-test/src/field.rs
@@ -19,6 +19,7 @@ use crate::{data_type_from_json, data_type_to_json};
use arrow::datatypes::{DataType, Field};
use arrow::error::{ArrowError, Result};
use std::collections::HashMap;
+use std::sync::Arc;
/// Parse a `Field` definition from a JSON representation.
pub fn field_from_json(json: &serde_json::Value) -> Result<Field> {
@@ -194,11 +195,17 @@ pub fn field_from_json(json: &serde_json::Value) ->
Result<Field> {
}
}
}
- DataType::Union(_, type_ids, mode) => match
map.get("children") {
+ DataType::Union(fields, mode) => match map.get("children") {
Some(Value::Array(values)) => {
- let union_fields: Vec<Field> =
-
values.iter().map(field_from_json).collect::<Result<_>>()?;
- DataType::Union(union_fields, type_ids, mode)
+ let fields = fields
+ .iter()
+ .zip(values)
+ .map(|((id, _), value)| {
+ Ok((id, Arc::new(field_from_json(value)?)))
+ })
+ .collect::<Result<_>>()?;
+
+ DataType::Union(fields, mode)
}
Some(_) => {
return Err(ArrowError::ParseError(
@@ -296,7 +303,7 @@ pub fn field_to_json(field: &Field) -> serde_json::Value {
#[cfg(test)]
mod tests {
use super::*;
- use arrow::datatypes::{Fields, UnionMode};
+ use arrow::datatypes::{Fields, UnionFields, UnionMode};
use serde_json::Value;
#[test]
@@ -569,11 +576,13 @@ mod tests {
let expected = Field::new(
"my_union",
DataType::Union(
- vec![
- Field::new("f1", DataType::Int32, true),
- Field::new("f2", DataType::Utf8, true),
- ],
- vec![5, 7],
+ UnionFields::new(
+ vec![5, 7],
+ vec![
+ Field::new("f1", DataType::Int32, true),
+ Field::new("f2", DataType::Utf8, true),
+ ],
+ ),
UnionMode::Sparse,
),
false,
diff --git a/arrow-integration-test/src/lib.rs
b/arrow-integration-test/src/lib.rs
index 06f16ca1d..61bcbea5a 100644
--- a/arrow-integration-test/src/lib.rs
+++ b/arrow-integration-test/src/lib.rs
@@ -858,7 +858,7 @@ pub fn array_from_json(
let array = MapArray::from(array_data);
Ok(Arc::new(array))
}
- DataType::Union(fields, field_type_ids, _) => {
+ DataType::Union(fields, _) => {
let type_ids = if let Some(type_id) = json_col.type_id {
type_id
} else {
@@ -874,13 +874,14 @@ pub fn array_from_json(
});
let mut children: Vec<(Field, Arc<dyn Array>)> = vec![];
- for (field, col) in fields.iter().zip(json_col.children.unwrap()) {
+ for ((_, field), col) in
fields.iter().zip(json_col.children.unwrap()) {
let array = array_from_json(field, col, dictionaries)?;
- children.push((field.clone(), array));
+ children.push((field.as_ref().clone(), array));
}
+ let field_type_ids = fields.iter().map(|(id, _)|
id).collect::<Vec<_>>();
let array = UnionArray::try_new(
- field_type_ids,
+ &field_type_ids,
Buffer::from(&type_ids.to_byte_slice()),
offset,
children,
diff --git a/arrow-ipc/src/convert.rs b/arrow-ipc/src/convert.rs
index 7e44f37d4..8ca0d514f 100644
--- a/arrow-ipc/src/convert.rs
+++ b/arrow-ipc/src/convert.rs
@@ -410,16 +410,16 @@ pub(crate) fn get_data_type(field: crate::Field,
may_be_dictionary: bool) -> Dat
let mut fields = vec![];
if let Some(children) = field.children() {
for i in 0..children.len() {
- fields.push(children.get(i).into());
+ fields.push(Field::from(children.get(i)));
}
};
- let type_ids: Vec<i8> = match union.typeIds() {
- None => (0_i8..fields.len() as i8).collect(),
- Some(ids) => ids.iter().map(|i| i as i8).collect(),
+ let fields = match union.typeIds() {
+ None => UnionFields::new(0_i8..fields.len() as i8, fields),
+ Some(ids) => UnionFields::new(ids.iter().map(|i| i as i8),
fields),
};
- DataType::Union(fields, type_ids, union_mode)
+ DataType::Union(fields, union_mode)
}
t => unimplemented!("Type {:?} not supported", t),
}
@@ -769,9 +769,9 @@ pub(crate) fn get_fb_field_type<'a>(
children: Some(fbb.create_vector(&empty_fields[..])),
}
}
- Union(fields, type_ids, mode) => {
+ Union(fields, mode) => {
let mut children = vec![];
- for field in fields {
+ for (_, field) in fields.iter() {
children.push(build_field(fbb, field));
}
@@ -781,7 +781,7 @@ pub(crate) fn get_fb_field_type<'a>(
};
let fbb_type_ids = fbb
- .create_vector(&type_ids.iter().map(|t| *t as
i32).collect::<Vec<_>>());
+ .create_vector(&fields.iter().map(|(t, _)| t as
i32).collect::<Vec<_>>());
let mut builder = crate::UnionBuilder::new(fbb);
builder.add_mode(union_mode);
builder.add_typeIds(fbb_type_ids);
@@ -962,38 +962,47 @@ mod tests {
Field::new(
"union<int64, list[union<date32, list[union<>]>]>",
DataType::Union(
- vec![
- Field::new("int64", DataType::Int64, true),
- Field::new(
- "list[union<date32, list[union<>]>]",
- DataType::List(Box::new(Field::new(
- "union<date32, list[union<>]>",
- DataType::Union(
- vec![
- Field::new("date32",
DataType::Date32, true),
- Field::new(
- "list[union<>]",
-
DataType::List(Box::new(Field::new(
- "union",
- DataType::Union(
- vec![],
- vec![],
- UnionMode::Sparse,
+ UnionFields::new(
+ vec![0, 1],
+ vec![
+ Field::new("int64", DataType::Int64, true),
+ Field::new(
+ "list[union<date32, list[union<>]>]",
+ DataType::List(Box::new(Field::new(
+ "union<date32, list[union<>]>",
+ DataType::Union(
+ UnionFields::new(
+ vec![0, 1],
+ vec![
+ Field::new(
+ "date32",
+ DataType::Date32,
+ true,
+ ),
+ Field::new(
+ "list[union<>]",
+
DataType::List(Box::new(
+ Field::new(
+ "union",
+
DataType::Union(
+
UnionFields::empty(),
+
UnionMode::Sparse,
+ ),
+ false,
+ ),
+ )),
+ false,
),
- false,
- ))),
- false,
+ ],
),
- ],
- vec![0, 1],
- UnionMode::Dense,
- ),
+ UnionMode::Dense,
+ ),
+ false,
+ ))),
false,
- ))),
- false,
- ),
- ],
- vec![0, 1],
+ ),
+ ],
+ ),
UnionMode::Sparse,
),
false,
@@ -1001,22 +1010,24 @@ mod tests {
Field::new("struct<>", DataType::Struct(Fields::empty()),
true),
Field::new(
"union<>",
- DataType::Union(vec![], vec![], UnionMode::Dense),
+ DataType::Union(UnionFields::empty(), UnionMode::Dense),
true,
),
Field::new(
"union<>",
- DataType::Union(vec![], vec![], UnionMode::Sparse),
+ DataType::Union(UnionFields::empty(), UnionMode::Sparse),
true,
),
Field::new(
"union<int32, utf8>",
DataType::Union(
- vec![
- Field::new("int32", DataType::Int32, true),
- Field::new("utf8", DataType::Utf8, true),
- ],
- vec![2, 3], // non-default type ids
+ UnionFields::new(
+ vec![2, 3], // non-default type ids
+ vec![
+ Field::new("int32", DataType::Int32, true),
+ Field::new("utf8", DataType::Utf8, true),
+ ],
+ ),
UnionMode::Dense,
),
true,
diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs
index 4597ed82d..4f2e51336 100644
--- a/arrow-ipc/src/reader.rs
+++ b/arrow-ipc/src/reader.rs
@@ -263,7 +263,7 @@ fn create_array(
value_array.clone(),
)?
}
- Union(fields, field_type_ids, mode) => {
+ Union(fields, mode) => {
let union_node = nodes.get(node_index);
node_index += 1;
@@ -292,9 +292,10 @@ fn create_array(
UnionMode::Sparse => None,
};
- let mut children = vec![];
+ let mut children = Vec::with_capacity(fields.len());
+ let mut ids = Vec::with_capacity(fields.len());
- for field in fields {
+ for (id, field) in fields.iter() {
let triple = create_array(
nodes,
field,
@@ -310,11 +311,11 @@ fn create_array(
node_index = triple.1;
buffer_index = triple.2;
- children.push((field.clone(), triple.0));
+ children.push((field.as_ref().clone(), triple.0));
+ ids.push(id);
}
- let array =
- UnionArray::try_new(field_type_ids, type_ids, value_offsets,
children)?;
+ let array = UnionArray::try_new(&ids, type_ids, value_offsets,
children)?;
Arc::new(array)
}
Null => {
@@ -418,7 +419,7 @@ fn skip_field(
node_index += 1;
buffer_index += 2;
}
- Union(fields, _field_type_ids, mode) => {
+ Union(fields, mode) => {
node_index += 1;
buffer_index += 1;
@@ -429,7 +430,7 @@ fn skip_field(
UnionMode::Sparse => {}
};
- for field in fields {
+ for (_, field) in fields.iter() {
let tuple = skip_field(field.data_type(), node_index,
buffer_index)?;
node_index = tuple.0;
@@ -1265,11 +1266,15 @@ mod tests {
let dict_data_type =
DataType::Dictionary(Box::new(key_type), Box::new(value_type));
- let union_fileds = vec![
- Field::new("a", DataType::Int32, false),
- Field::new("b", DataType::Float64, false),
- ];
- let union_data_type = DataType::Union(union_fileds, vec![0, 1],
UnionMode::Dense);
+ let union_fields = UnionFields::new(
+ vec![0, 1],
+ vec![
+ Field::new("a", DataType::Int32, false),
+ Field::new("b", DataType::Float64, false),
+ ],
+ );
+
+ let union_data_type = DataType::Union(union_fields, UnionMode::Dense);
let struct_fields = Fields::from(vec![
Field::new("id", DataType::Int32, false),
diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs
index ceb9b6ffa..0e999dc72 100644
--- a/arrow-ipc/src/writer.rs
+++ b/arrow-ipc/src/writer.rs
@@ -298,10 +298,10 @@ impl IpcDataGenerator {
write_options,
)?;
}
- DataType::Union(fields, type_ids, _) => {
+ DataType::Union(fields, _) => {
let union = as_union_array(column);
- for (field, type_id) in fields.iter().zip(type_ids) {
- let column = union.child(*type_id);
+ for (type_id, field) in fields.iter() {
+ let column = union.child(type_id);
self.encode_dictionaries(
field,
column,
@@ -1069,7 +1069,7 @@ fn has_validity_bitmap(data_type: &DataType,
write_options: &IpcWriteOptions) ->
} else {
!matches!(
data_type,
- DataType::Null | DataType::Union(_, _, _) |
DataType::RunEndEncoded(_, _)
+ DataType::Null | DataType::Union(_, _) |
DataType::RunEndEncoded(_, _)
)
}
}
@@ -1781,11 +1781,13 @@ mod tests {
let schema = Schema::new(vec![Field::new(
"union",
DataType::Union(
- vec![
- Field::new("a", DataType::Int32, false),
- Field::new("c", DataType::Float64, false),
- ],
- vec![0, 1],
+ UnionFields::new(
+ vec![0, 1],
+ vec![
+ Field::new("a", DataType::Int32, false),
+ Field::new("c", DataType::Float64, false),
+ ],
+ ),
UnionMode::Sparse,
),
true,
diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs
index 58747fb26..57a5c6838 100644
--- a/arrow-schema/src/datatype.rs
+++ b/arrow-schema/src/datatype.rs
@@ -19,7 +19,7 @@ use std::fmt;
use std::sync::Arc;
use crate::field::Field;
-use crate::Fields;
+use crate::{Fields, UnionFields};
/// The set of datatypes that are supported by this implementation of Apache
Arrow.
///
@@ -194,10 +194,9 @@ pub enum DataType {
Struct(Fields),
/// A nested datatype that can represent slots of differing types.
Components:
///
- /// 1. [`Field`] for each possible child type the Union can hold
- /// 2. The corresponding `type_id` used to identify which Field
- /// 3. The type of union (Sparse or Dense)
- Union(Vec<Field>, Vec<i8>, UnionMode),
+ /// 1. [`UnionFields`]
+ /// 2. The type of union (Sparse or Dense)
+ Union(UnionFields, UnionMode),
/// A dictionary encoded array (`key_type`, `value_type`), where
/// each array element is an index of `key_type` into an
/// associated dictionary of `value_type`.
@@ -384,7 +383,7 @@ impl DataType {
| FixedSizeList(_, _)
| LargeList(_)
| Struct(_)
- | Union(_, _, _)
+ | Union(_, _)
| Map(_, _) => true,
_ => false,
}
@@ -446,7 +445,7 @@ impl DataType {
DataType::List(_) | DataType::LargeList(_) | DataType::Map(_, _)
=> None,
DataType::FixedSizeList(_, _) => None,
DataType::Struct(_) => None,
- DataType::Union(_, _, _) => None,
+ DataType::Union(_, _) => None,
DataType::Dictionary(_, _) => None,
DataType::RunEndEncoded(_, _) => None,
}
@@ -492,13 +491,7 @@ impl DataType {
| DataType::LargeList(field)
| DataType::Map(field, _) => field.size(),
DataType::Struct(fields) => fields.size(),
- DataType::Union(fields, _, _) => {
- fields
- .iter()
- .map(|field| field.size() -
std::mem::size_of_val(field))
- .sum::<usize>()
- + (std::mem::size_of::<Field>() * fields.capacity())
- }
+ DataType::Union(fields, _) => fields.size(),
DataType::Dictionary(dt1, dt2) => dt1.size() + dt2.size(),
DataType::RunEndEncoded(run_ends, values) => {
run_ends.size() - std::mem::size_of_val(run_ends) +
values.size()
@@ -670,4 +663,9 @@ mod tests {
Box::new(list)
)));
}
+
+ #[test]
+ fn size_should_not_regress() {
+ assert_eq!(std::mem::size_of::<DataType>(), 24);
+ }
}
diff --git a/arrow-schema/src/ffi.rs b/arrow-schema/src/ffi.rs
index 0cfc1800f..72afc5b0b 100644
--- a/arrow-schema/src/ffi.rs
+++ b/arrow-schema/src/ffi.rs
@@ -34,7 +34,9 @@
//! assert_eq!(schema, back);
//! ```
-use crate::{ArrowError, DataType, Field, FieldRef, Schema, TimeUnit,
UnionMode};
+use crate::{
+ ArrowError, DataType, Field, FieldRef, Schema, TimeUnit, UnionFields,
UnionMode,
+};
use bitflags::bitflags;
use std::sync::Arc;
use std::{
@@ -484,7 +486,7 @@ impl TryFrom<&FFI_ArrowSchema> for DataType {
));
}
- DataType::Union(fields, type_ids, UnionMode::Dense)
+ DataType::Union(UnionFields::new(type_ids, fields),
UnionMode::Dense)
}
// SparseUnion
["+us", extra] => {
@@ -506,7 +508,7 @@ impl TryFrom<&FFI_ArrowSchema> for DataType {
));
}
- DataType::Union(fields, type_ids, UnionMode::Sparse)
+ DataType::Union(UnionFields::new(type_ids, fields),
UnionMode::Sparse)
}
// Timestamps in format "tts:" and "tts:America/New_York"
for no timezones and timezones resp.
@@ -585,9 +587,9 @@ impl TryFrom<&DataType> for FFI_ArrowSchema {
| DataType::Map(child, _) => {
vec![FFI_ArrowSchema::try_from(child.as_ref())?]
}
- DataType::Union(fields, _, _) => fields
+ DataType::Union(fields, _) => fields
.iter()
- .map(FFI_ArrowSchema::try_from)
+ .map(|(_, f)| f.as_ref().try_into())
.collect::<Result<Vec<_>, ArrowError>>()?,
DataType::Struct(fields) => fields
.iter()
@@ -658,8 +660,11 @@ fn get_format_string(dtype: &DataType) -> Result<String,
ArrowError> {
DataType::Struct(_) => Ok("+s".to_string()),
DataType::Map(_, _) => Ok("+m".to_string()),
DataType::Dictionary(key_data_type, _) =>
get_format_string(key_data_type),
- DataType::Union(_, type_ids, mode) => {
- let formats = type_ids.iter().map(|t|
t.to_string()).collect::<Vec<_>>();
+ DataType::Union(fields, mode) => {
+ let formats = fields
+ .iter()
+ .map(|(t, _)| t.to_string())
+ .collect::<Vec<_>>();
match mode {
UnionMode::Dense => Ok(format!("{}:{}", "+ud",
formats.join(","))),
UnionMode::Sparse => Ok(format!("{}:{}", "+us",
formats.join(","))),
diff --git a/arrow-schema/src/field.rs b/arrow-schema/src/field.rs
index 8ef9fd2b8..d68392f51 100644
--- a/arrow-schema/src/field.rs
+++ b/arrow-schema/src/field.rs
@@ -235,8 +235,8 @@ impl Field {
fn _fields(dt: &DataType) -> Vec<&Field> {
match dt {
DataType::Struct(fields) => fields.iter().flat_map(|f|
f.fields()).collect(),
- DataType::Union(fields, _, _) => {
- fields.iter().flat_map(|f| f.fields()).collect()
+ DataType::Union(fields, _) => {
+ fields.iter().flat_map(|(_, f)| f.fields()).collect()
}
DataType::List(field)
| DataType::LargeList(field)
@@ -341,36 +341,9 @@ impl Field {
self.name, from.data_type)
))}
},
- DataType::Union(nested_fields, type_ids, _) => match
&from.data_type {
- DataType::Union(from_nested_fields, from_type_ids, _) => {
- for (idx, from_field) in
from_nested_fields.iter().enumerate() {
- let mut is_new_field = true;
- let field_type_id = from_type_ids.get(idx).unwrap();
-
- for (self_idx, self_field) in
nested_fields.iter_mut().enumerate()
- {
- if from_field == self_field {
- let self_type_id =
type_ids.get(self_idx).unwrap();
-
- // If the nested fields in two unions are the
same, they must have same
- // type id.
- if self_type_id != field_type_id {
- return Err(ArrowError::SchemaError(
- format!("Fail to merge schema field
'{}' because the self_type_id = {} does not equal field_type_id = {}",
- self.name, self_type_id,
field_type_id)
- ));
- }
-
- is_new_field = false;
- break;
- }
- }
-
- if is_new_field {
- nested_fields.push(from_field.clone());
- type_ids.push(*field_type_id);
- }
- }
+ DataType::Union(nested_fields, _) => match &from.data_type {
+ DataType::Union(from_nested_fields, _) => {
+ nested_fields.try_merge(from_nested_fields)?
}
_ => {
return Err(ArrowError::SchemaError(
diff --git a/arrow-schema/src/fields.rs b/arrow-schema/src/fields.rs
index 268226136..1de5e5efd 100644
--- a/arrow-schema/src/fields.rs
+++ b/arrow-schema/src/fields.rs
@@ -15,13 +15,13 @@
// specific language governing permissions and limitations
// under the License.
-use crate::{Field, FieldRef};
+use crate::{ArrowError, Field, FieldRef};
use std::ops::Deref;
use std::sync::Arc;
/// A cheaply cloneable, owned slice of [`FieldRef`]
///
-/// Similar to `Arc<Vec<FieldPtr>>` or `Arc<[FieldPtr]>`
+/// Similar to `Arc<Vec<FieldRef>>` or `Arc<[FieldRef]>`
///
/// Can be constructed in a number of ways
///
@@ -55,7 +55,9 @@ impl Fields {
/// Return size of this instance in bytes.
pub fn size(&self) -> usize {
- self.iter().map(|field| field.size()).sum()
+ self.iter()
+ .map(|field| field.size() + std::mem::size_of::<FieldRef>())
+ .sum()
}
/// Searches for a field by name, returning it along with its index if
found
@@ -148,3 +150,112 @@ impl<'de> serde::Deserialize<'de> for Fields {
Ok(Vec::<Field>::deserialize(deserializer)?.into())
}
}
+
+/// A cheaply cloneable, owned collection of [`FieldRef`] and their
corresponding type ids
+#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[cfg_attr(feature = "serde", serde(transparent))]
+pub struct UnionFields(Arc<[(i8, FieldRef)]>);
+
+impl std::fmt::Debug for UnionFields {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ self.0.as_ref().fmt(f)
+ }
+}
+
+impl UnionFields {
+ /// Create a new [`UnionFields`] with no fields
+ pub fn empty() -> Self {
+ Self(Arc::from([]))
+ }
+
+ /// Create a new [`UnionFields`] from a [`Fields`] and array of type_ids
+ ///
+ /// See <https://arrow.apache.org/docs/format/Columnar.html#union-layout>
+ ///
+ /// ```
+ /// use arrow_schema::{DataType, Field, UnionFields};
+ /// // Create a new UnionFields with type id mapping
+ /// // 1 -> DataType::UInt8
+ /// // 3 -> DataType::Utf8
+ /// UnionFields::new(
+ /// vec![1, 3],
+ /// vec![
+ /// Field::new("field1", DataType::UInt8, false),
+ /// Field::new("field3", DataType::Utf8, false),
+ /// ],
+ /// );
+ /// ```
+ pub fn new<F, T>(type_ids: T, fields: F) -> Self
+ where
+ F: IntoIterator,
+ F::Item: Into<FieldRef>,
+ T: IntoIterator<Item = i8>,
+ {
+ let fields = fields.into_iter().map(Into::into);
+ type_ids.into_iter().zip(fields).collect()
+ }
+
+ /// Return size of this instance in bytes.
+ pub fn size(&self) -> usize {
+ self.iter()
+ .map(|(_, field)| field.size() + std::mem::size_of::<(i8,
FieldRef)>())
+ .sum()
+ }
+
+ /// Returns the number of fields in this [`UnionFields`]
+ pub fn len(&self) -> usize {
+ self.0.len()
+ }
+
+ /// Returns `true` if this is empty
+ pub fn is_empty(&self) -> bool {
+ self.0.is_empty()
+ }
+
+ /// Returns an iterator over the fields and type ids in this
[`UnionFields`]
+ ///
+ /// Note: the iteration order is not guaranteed
+ pub fn iter(&self) -> impl Iterator<Item = (i8, &FieldRef)> + '_ {
+ self.0.iter().map(|(id, f)| (*id, f))
+ }
+
+ /// Merge this field into self if it is compatible.
+ ///
+ /// See [`Field::try_merge`]
+ pub(crate) fn try_merge(&mut self, other: &Self) -> Result<(), ArrowError>
{
+ // TODO: This currently may produce duplicate type IDs (#3982)
+ let mut output: Vec<_> = self.iter().map(|(id, f)| (id,
f.clone())).collect();
+ for (field_type_id, from_field) in other.iter() {
+ let mut is_new_field = true;
+ for (self_type_id, self_field) in output.iter_mut() {
+ if from_field == self_field {
+ // If the nested fields in two unions are the same, they
must have same
+ // type id.
+ if *self_type_id != field_type_id {
+ return Err(ArrowError::SchemaError(
+ format!("Fail to merge schema field '{}' because
the self_type_id = {} does not equal field_type_id = {}",
+ self_field.name(), self_type_id,
field_type_id)
+ ));
+ }
+
+ is_new_field = false;
+ break;
+ }
+ }
+
+ if is_new_field {
+ output.push((field_type_id, from_field.clone()))
+ }
+ }
+ *self = output.into_iter().collect();
+ Ok(())
+ }
+}
+
+impl FromIterator<(i8, FieldRef)> for UnionFields {
+ fn from_iter<T: IntoIterator<Item = (i8, FieldRef)>>(iter: T) -> Self {
+ // TODO: Should this validate type IDs are unique (#3982)
+ Self(iter.into_iter().collect())
+ }
+}
diff --git a/arrow-schema/src/schema.rs b/arrow-schema/src/schema.rs
index 6089c1ae5..501c5c7fd 100644
--- a/arrow-schema/src/schema.rs
+++ b/arrow-schema/src/schema.rs
@@ -365,7 +365,7 @@ impl Hash for Schema {
#[cfg(test)]
mod tests {
use crate::datatype::DataType;
- use crate::{TimeUnit, UnionMode};
+ use crate::{TimeUnit, UnionFields, UnionMode};
use super::*;
@@ -778,11 +778,13 @@ mod tests {
Schema::new(vec![Field::new(
"c1",
DataType::Union(
- vec![
- Field::new("c11", DataType::Utf8, true),
- Field::new("c12", DataType::Utf8, true),
- ],
- vec![0, 1],
+ UnionFields::new(
+ vec![0, 1],
+ vec![
+ Field::new("c11", DataType::Utf8, true),
+ Field::new("c12", DataType::Utf8, true),
+ ]
+ ),
UnionMode::Dense
),
false
@@ -790,11 +792,17 @@ mod tests {
Schema::new(vec![Field::new(
"c1",
DataType::Union(
- vec![
- Field::new("c12", DataType::Utf8, true),
- Field::new("c13",
DataType::Time64(TimeUnit::Second), true),
- ],
- vec![1, 2],
+ UnionFields::new(
+ vec![1, 2],
+ vec![
+ Field::new("c12", DataType::Utf8, true),
+ Field::new(
+ "c13",
+ DataType::Time64(TimeUnit::Second),
+ true
+ ),
+ ]
+ ),
UnionMode::Dense
),
false
@@ -804,12 +812,14 @@ mod tests {
Schema::new(vec![Field::new(
"c1",
DataType::Union(
- vec![
- Field::new("c11", DataType::Utf8, true),
- Field::new("c12", DataType::Utf8, true),
- Field::new("c13", DataType::Time64(TimeUnit::Second),
true),
- ],
- vec![0, 1, 2],
+ UnionFields::new(
+ vec![0, 1, 2],
+ vec![
+ Field::new("c11", DataType::Utf8, true),
+ Field::new("c12", DataType::Utf8, true),
+ Field::new("c13",
DataType::Time64(TimeUnit::Second), true),
+ ]
+ ),
UnionMode::Dense
),
false
diff --git a/arrow/src/datatypes/mod.rs b/arrow/src/datatypes/mod.rs
index d1977d42b..74dad6b4a 100644
--- a/arrow/src/datatypes/mod.rs
+++ b/arrow/src/datatypes/mod.rs
@@ -30,7 +30,7 @@ pub use arrow_buffer::{i256, ArrowNativeType, ToByteSlice};
pub use arrow_data::decimal::*;
pub use arrow_schema::{
DataType, Field, FieldRef, Fields, IntervalUnit, Schema, SchemaBuilder,
SchemaRef,
- TimeUnit, UnionMode,
+ TimeUnit, UnionFields, UnionMode,
};
#[cfg(feature = "ffi")]
diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs
index 2d6bbf1a0..fe2e186a7 100644
--- a/arrow/src/ffi.rs
+++ b/arrow/src/ffi.rs
@@ -174,15 +174,15 @@ fn bit_width(data_type: &DataType, i: usize) ->
Result<usize> {
)))
}
// type ids. UnionArray doesn't have null bitmap so buffer index
begins with 0.
- (DataType::Union(_, _, _), 0) => i8::BITS as _,
+ (DataType::Union(_, _), 0) => i8::BITS as _,
// Only DenseUnion has 2nd buffer
- (DataType::Union(_, _, UnionMode::Dense), 1) => i32::BITS as _,
- (DataType::Union(_, _, UnionMode::Sparse), _) => {
+ (DataType::Union(_, UnionMode::Dense), 1) => i32::BITS as _,
+ (DataType::Union(_, UnionMode::Sparse), _) => {
return Err(ArrowError::CDataInterface(format!(
"The datatype \"{data_type:?}\" expects 1 buffer, but
requested {i}. Please verify that the C data interface is correctly
implemented."
)))
}
- (DataType::Union(_, _, UnionMode::Dense), _) => {
+ (DataType::Union(_, UnionMode::Dense), _) => {
return Err(ArrowError::CDataInterface(format!(
"The datatype \"{data_type:?}\" expects 2 buffer, but
requested {i}. Please verify that the C data interface is correctly
implemented."
)))
diff --git a/arrow/tests/array_cast.rs b/arrow/tests/array_cast.rs
index b113ec04c..27fb1dcd2 100644
--- a/arrow/tests/array_cast.rs
+++ b/arrow/tests/array_cast.rs
@@ -41,7 +41,7 @@ use arrow_cast::pretty::pretty_format_columns;
use arrow_cast::{can_cast_types, cast};
use arrow_data::ArrayData;
use arrow_schema::{
- ArrowError, DataType, Field, Fields, IntervalUnit, TimeUnit, UnionMode,
+ ArrowError, DataType, Field, Fields, IntervalUnit, TimeUnit, UnionFields,
UnionMode,
};
use half::f16;
use std::sync::Arc;
@@ -405,11 +405,13 @@ fn get_all_types() -> Vec<DataType> {
Field::new("f2", DataType::Utf8, true),
])),
Union(
- vec![
- Field::new("f1", DataType::Int32, false),
- Field::new("f2", DataType::Utf8, true),
- ],
- vec![0, 1],
+ UnionFields::new(
+ vec![0, 1],
+ vec![
+ Field::new("f1", DataType::Int32, false),
+ Field::new("f2", DataType::Utf8, true),
+ ],
+ ),
UnionMode::Dense,
),
Decimal128(38, 0),
diff --git a/arrow/tests/array_validation.rs b/arrow/tests/array_validation.rs
index 73e013ff1..ef0d40d64 100644
--- a/arrow/tests/array_validation.rs
+++ b/arrow/tests/array_validation.rs
@@ -22,7 +22,7 @@ use arrow::array::{
use arrow_array::Decimal128Array;
use arrow_buffer::{ArrowNativeType, Buffer};
use arrow_data::ArrayData;
-use arrow_schema::{DataType, Field, UnionMode};
+use arrow_schema::{DataType, Field, UnionFields, UnionMode};
use std::ptr::NonNull;
use std::sync::Arc;
@@ -768,11 +768,13 @@ fn test_validate_union_different_types() {
ArrayData::try_new(
DataType::Union(
- vec![
- Field::new("field1", DataType::Int32, true),
- Field::new("field2", DataType::Int64, true), // data is int32
- ],
- vec![0, 1],
+ UnionFields::new(
+ vec![0, 1],
+ vec![
+ Field::new("field1", DataType::Int32, true),
+ Field::new("field2", DataType::Int64, true), // data is
int32
+ ],
+ ),
UnionMode::Sparse,
),
2,
@@ -799,11 +801,13 @@ fn test_validate_union_sparse_different_child_len() {
ArrayData::try_new(
DataType::Union(
- vec![
- Field::new("field1", DataType::Int32, true),
- Field::new("field2", DataType::Int64, true),
- ],
- vec![0, 1],
+ UnionFields::new(
+ vec![0, 1],
+ vec![
+ Field::new("field1", DataType::Int32, true),
+ Field::new("field2", DataType::Int64, true),
+ ],
+ ),
UnionMode::Sparse,
),
2,
@@ -826,11 +830,13 @@ fn test_validate_union_dense_without_offsets() {
ArrayData::try_new(
DataType::Union(
- vec![
- Field::new("field1", DataType::Int32, true),
- Field::new("field2", DataType::Int64, true),
- ],
- vec![0, 1],
+ UnionFields::new(
+ vec![0, 1],
+ vec![
+ Field::new("field1", DataType::Int32, true),
+ Field::new("field2", DataType::Int64, true),
+ ],
+ ),
UnionMode::Dense,
),
2,
@@ -854,11 +860,13 @@ fn test_validate_union_dense_with_bad_len() {
ArrayData::try_new(
DataType::Union(
- vec![
- Field::new("field1", DataType::Int32, true),
- Field::new("field2", DataType::Int64, true),
- ],
- vec![0, 1],
+ UnionFields::new(
+ vec![0, 1],
+ vec![
+ Field::new("field1", DataType::Int32, true),
+ Field::new("field2", DataType::Int64, true),
+ ],
+ ),
UnionMode::Dense,
),
2,
diff --git a/parquet/src/arrow/arrow_writer/mod.rs
b/parquet/src/arrow/arrow_writer/mod.rs
index 94c19cb2e..f594f2f79 100644
--- a/parquet/src/arrow/arrow_writer/mod.rs
+++ b/parquet/src/arrow/arrow_writer/mod.rs
@@ -360,7 +360,7 @@ fn write_leaves<W: Write>(
ArrowDataType::Float16 => Err(ParquetError::ArrowError(
"Float16 arrays not supported".to_string(),
)),
- ArrowDataType::FixedSizeList(_, _) | ArrowDataType::Union(_, _, _) |
ArrowDataType::RunEndEncoded(_, _) => {
+ ArrowDataType::FixedSizeList(_, _) | ArrowDataType::Union(_, _) |
ArrowDataType::RunEndEncoded(_, _) => {
Err(ParquetError::NYI(
format!(
"Attempting to write an Arrow type {data_type:?} to
parquet that is not yet implemented"
diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs
index 09109d290..b541a754b 100644
--- a/parquet/src/arrow/schema/mod.rs
+++ b/parquet/src/arrow/schema/mod.rs
@@ -501,7 +501,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result<Type> {
))
}
}
- DataType::Union(_, _, _) => unimplemented!("See ARROW-8817."),
+ DataType::Union(_, _) => unimplemented!("See ARROW-8817."),
DataType::Dictionary(_, ref value) => {
// Dictionary encoding not handled at the schema level
let dict_field = Field::new(name, *value.clone(),
field.is_nullable());