This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new e5a167695 Add UnionFields (#3955) (#3981)
e5a167695 is described below

commit e5a1676950ab5c04b0a74953ec5418da67cedb45
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Thu Mar 30 14:58:21 2023 +0100

    Add UnionFields (#3955) (#3981)
    
    * Add UnionFields (#3955)
    
    * Fix array_cast
    
    * Review feedback
    
    * Clippy
---
 arrow-array/src/array/mod.rs           |  16 +++--
 arrow-array/src/array/union_array.rs   |  53 ++++++++-------
 arrow-array/src/record_batch.rs        |   2 +-
 arrow-cast/src/display.rs              |  14 ++--
 arrow-cast/src/pretty.rs               |  42 +++++++-----
 arrow-data/src/data/mod.rs             |  25 +++----
 arrow-data/src/equal/mod.rs            |   2 +-
 arrow-data/src/equal/union.rs          |  26 ++++----
 arrow-data/src/equal/utils.rs          |   2 +-
 arrow-data/src/transform/mod.rs        |   6 +-
 arrow-integration-test/src/datatype.rs |  22 +++----
 arrow-integration-test/src/field.rs    |  29 +++++---
 arrow-integration-test/src/lib.rs      |   9 +--
 arrow-ipc/src/convert.rs               |  99 +++++++++++++++-------------
 arrow-ipc/src/reader.rs                |  31 +++++----
 arrow-ipc/src/writer.rs                |  20 +++---
 arrow-schema/src/datatype.rs           |  26 ++++----
 arrow-schema/src/ffi.rs                |  19 ++++--
 arrow-schema/src/field.rs              |  37 ++---------
 arrow-schema/src/fields.rs             | 117 ++++++++++++++++++++++++++++++++-
 arrow-schema/src/schema.rs             |  44 ++++++++-----
 arrow/src/datatypes/mod.rs             |   2 +-
 arrow/src/ffi.rs                       |   8 +--
 arrow/tests/array_cast.rs              |  14 ++--
 arrow/tests/array_validation.rs        |  50 ++++++++------
 parquet/src/arrow/arrow_writer/mod.rs  |   2 +-
 parquet/src/arrow/schema/mod.rs        |   2 +-
 27 files changed, 430 insertions(+), 289 deletions(-)

diff --git a/arrow-array/src/array/mod.rs b/arrow-array/src/array/mod.rs
index 8d20c6cb2..9a5172d0d 100644
--- a/arrow-array/src/array/mod.rs
+++ b/arrow-array/src/array/mod.rs
@@ -586,7 +586,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef {
         DataType::LargeList(_) => Arc::new(LargeListArray::from(data)) as 
ArrayRef,
         DataType::Struct(_) => Arc::new(StructArray::from(data)) as ArrayRef,
         DataType::Map(_, _) => Arc::new(MapArray::from(data)) as ArrayRef,
-        DataType::Union(_, _, _) => Arc::new(UnionArray::from(data)) as 
ArrayRef,
+        DataType::Union(_, _) => Arc::new(UnionArray::from(data)) as ArrayRef,
         DataType::FixedSizeList(_, _) => {
             Arc::new(FixedSizeListArray::from(data)) as ArrayRef
         }
@@ -740,7 +740,7 @@ mod tests {
     use crate::cast::{as_union_array, downcast_array};
     use crate::downcast_run_array;
     use arrow_buffer::{Buffer, MutableBuffer};
-    use arrow_schema::{Field, Fields, UnionMode};
+    use arrow_schema::{Field, Fields, UnionFields, UnionMode};
 
     #[test]
     fn test_empty_primitive() {
@@ -874,11 +874,13 @@ mod tests {
     fn test_null_union() {
         for mode in [UnionMode::Sparse, UnionMode::Dense] {
             let data_type = DataType::Union(
-                vec![
-                    Field::new("foo", DataType::Int32, true),
-                    Field::new("bar", DataType::Int64, true),
-                ],
-                vec![2, 1],
+                UnionFields::new(
+                    vec![2, 1],
+                    vec![
+                        Field::new("foo", DataType::Int32, true),
+                        Field::new("bar", DataType::Int64, true),
+                    ],
+                ),
                 mode,
             );
             let array = new_null_array(&data_type, 4);
diff --git a/arrow-array/src/array/union_array.rs 
b/arrow-array/src/array/union_array.rs
index 00ad94111..335b6b14f 100644
--- a/arrow-array/src/array/union_array.rs
+++ b/arrow-array/src/array/union_array.rs
@@ -19,7 +19,7 @@ use crate::{make_array, Array, ArrayRef};
 use arrow_buffer::buffer::NullBuffer;
 use arrow_buffer::{Buffer, ScalarBuffer};
 use arrow_data::ArrayData;
-use arrow_schema::{ArrowError, DataType, Field, UnionMode};
+use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode};
 /// Contains the `UnionArray` type.
 ///
 use std::any::Any;
@@ -145,8 +145,7 @@ impl UnionArray {
         value_offsets: Option<Buffer>,
         child_arrays: Vec<(Field, ArrayRef)>,
     ) -> Self {
-        let (field_types, field_values): (Vec<_>, Vec<_>) =
-            child_arrays.into_iter().unzip();
+        let (fields, field_values): (Vec<_>, Vec<_>) = 
child_arrays.into_iter().unzip();
         let len = type_ids.len();
 
         let mode = if value_offsets.is_some() {
@@ -156,8 +155,7 @@ impl UnionArray {
         };
 
         let builder = ArrayData::builder(DataType::Union(
-            field_types,
-            Vec::from(field_type_ids),
+            UnionFields::new(field_type_ids.iter().copied(), fields),
             mode,
         ))
         .add_buffer(type_ids)
@@ -282,9 +280,9 @@ impl UnionArray {
     /// Returns the names of the types in the union.
     pub fn type_names(&self) -> Vec<&str> {
         match self.data.data_type() {
-            DataType::Union(fields, _, _) => fields
+            DataType::Union(fields, _) => fields
                 .iter()
-                .map(|f| f.name().as_str())
+                .map(|(_, f)| f.name().as_str())
                 .collect::<Vec<&str>>(),
             _ => unreachable!("Union array's data type is not a union!"),
         }
@@ -293,7 +291,7 @@ impl UnionArray {
     /// Returns whether the `UnionArray` is dense (or sparse if `false`).
     fn is_dense(&self) -> bool {
         match self.data.data_type() {
-            DataType::Union(_, _, mode) => mode == &UnionMode::Dense,
+            DataType::Union(_, mode) => mode == &UnionMode::Dense,
             _ => unreachable!("Union array's data type is not a union!"),
         }
     }
@@ -307,8 +305,8 @@ impl UnionArray {
 
 impl From<ArrayData> for UnionArray {
     fn from(data: ArrayData) -> Self {
-        let (field_ids, mode) = match data.data_type() {
-            DataType::Union(_, ids, mode) => (ids, *mode),
+        let (fields, mode) = match data.data_type() {
+            DataType::Union(fields, mode) => (fields, *mode),
             d => panic!("UnionArray expected ArrayData with type Union got 
{d}"),
         };
         let (type_ids, offsets) = match mode {
@@ -326,10 +324,10 @@ impl From<ArrayData> for UnionArray {
             ),
         };
 
-        let max_id = field_ids.iter().copied().max().unwrap_or_default() as 
usize;
+        let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() 
as usize;
         let mut boxed_fields = vec![None; max_id + 1];
-        for (cd, field_id) in data.child_data().iter().zip(field_ids) {
-            boxed_fields[*field_id as usize] = Some(make_array(cd.clone()));
+        for (cd, (field_id, _)) in data.child_data().iter().zip(fields.iter()) 
{
+            boxed_fields[field_id as usize] = Some(make_array(cd.clone()));
         }
         Self {
             data,
@@ -402,19 +400,18 @@ impl std::fmt::Debug for UnionArray {
         writeln!(f, "-- type id buffer:")?;
         writeln!(f, "{:?}", self.type_ids)?;
 
-        let (fields, ids) = match self.data_type() {
-            DataType::Union(f, ids, _) => (f, ids),
-            _ => unreachable!(),
-        };
-
         if let Some(offsets) = &self.offsets {
             writeln!(f, "-- offsets buffer:")?;
             writeln!(f, "{:?}", offsets)?;
         }
 
-        assert_eq!(fields.len(), ids.len());
-        for (field, type_id) in fields.iter().zip(ids) {
-            let child = self.child(*type_id);
+        let fields = match self.data_type() {
+            DataType::Union(fields, _) => fields,
+            _ => unreachable!(),
+        };
+
+        for (type_id, field) in fields.iter() {
+            let child = self.child(type_id);
             writeln!(
                 f,
                 "-- child {}: \"{}\" ({:?})",
@@ -1058,12 +1055,14 @@ mod tests {
     #[test]
     fn test_custom_type_ids() {
         let data_type = DataType::Union(
-            vec![
-                Field::new("strings", DataType::Utf8, false),
-                Field::new("integers", DataType::Int32, false),
-                Field::new("floats", DataType::Float64, false),
-            ],
-            vec![8, 4, 9],
+            UnionFields::new(
+                vec![8, 4, 9],
+                vec![
+                    Field::new("strings", DataType::Utf8, false),
+                    Field::new("integers", DataType::Int32, false),
+                    Field::new("floats", DataType::Float64, false),
+                ],
+            ),
             UnionMode::Dense,
         );
 
diff --git a/arrow-array/src/record_batch.rs b/arrow-array/src/record_batch.rs
index 2754d04bf..17b1f04e8 100644
--- a/arrow-array/src/record_batch.rs
+++ b/arrow-array/src/record_batch.rs
@@ -590,7 +590,7 @@ mod tests {
         let record_batch =
             RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), 
Arc::new(b)])
                 .unwrap();
-        assert_eq!(record_batch.get_array_memory_size(), 628);
+        assert_eq!(record_batch.get_array_memory_size(), 564);
     }
 
     fn check_batch(record_batch: RecordBatch, num_rows: usize) {
diff --git a/arrow-cast/src/display.rs b/arrow-cast/src/display.rs
index d10903697..0bca9ce65 100644
--- a/arrow-cast/src/display.rs
+++ b/arrow-cast/src/display.rs
@@ -278,7 +278,7 @@ fn make_formatter<'a>(
         }
         DataType::Struct(_) => array_format(as_struct_array(array), options),
         DataType::Map(_, _) => array_format(as_map_array(array), options),
-        DataType::Union(_, _, _) => array_format(as_union_array(array), 
options),
+        DataType::Union(_, _) => array_format(as_union_array(array), options),
         d => Err(ArrowError::NotYetImplemented(format!("formatting {d} is not 
yet supported"))),
     }
 }
@@ -801,16 +801,16 @@ impl<'a> DisplayIndexState<'a> for &'a UnionArray {
     );
 
     fn prepare(&self, options: &FormatOptions<'a>) -> Result<Self::State, 
ArrowError> {
-        let (fields, type_ids, mode) = match (*self).data_type() {
-            DataType::Union(fields, type_ids, mode) => (fields, type_ids, 
mode),
+        let (fields, mode) = match (*self).data_type() {
+            DataType::Union(fields, mode) => (fields, mode),
             _ => unreachable!(),
         };
 
-        let max_id = type_ids.iter().copied().max().unwrap_or_default() as 
usize;
+        let max_id = fields.iter().map(|(id, _)| id).max().unwrap_or_default() 
as usize;
         let mut out: Vec<Option<FieldDisplay>> = (0..max_id + 1).map(|_| 
None).collect();
-        for (i, field) in type_ids.iter().zip(fields) {
-            let formatter = make_formatter(self.child(*i).as_ref(), options)?;
-            out[*i as usize] = Some((field.name().as_str(), formatter))
+        for (i, field) in fields.iter() {
+            let formatter = make_formatter(self.child(i).as_ref(), options)?;
+            out[i as usize] = Some((field.name().as_str(), formatter))
         }
         Ok((out, *mode))
     }
diff --git a/arrow-cast/src/pretty.rs b/arrow-cast/src/pretty.rs
index ffa5af82d..818e9d3c0 100644
--- a/arrow-cast/src/pretty.rs
+++ b/arrow-cast/src/pretty.rs
@@ -703,11 +703,13 @@ mod tests {
         let schema = Schema::new(vec![Field::new(
             "Teamsters",
             DataType::Union(
-                vec![
-                    Field::new("a", DataType::Int32, false),
-                    Field::new("b", DataType::Float64, false),
-                ],
-                vec![0, 1],
+                UnionFields::new(
+                    vec![0, 1],
+                    vec![
+                        Field::new("a", DataType::Int32, false),
+                        Field::new("b", DataType::Float64, false),
+                    ],
+                ),
                 UnionMode::Dense,
             ),
             false,
@@ -743,11 +745,13 @@ mod tests {
         let schema = Schema::new(vec![Field::new(
             "Teamsters",
             DataType::Union(
-                vec![
-                    Field::new("a", DataType::Int32, false),
-                    Field::new("b", DataType::Float64, false),
-                ],
-                vec![0, 1],
+                UnionFields::new(
+                    vec![0, 1],
+                    vec![
+                        Field::new("a", DataType::Int32, false),
+                        Field::new("b", DataType::Float64, false),
+                    ],
+                ),
                 UnionMode::Sparse,
             ),
             false,
@@ -785,11 +789,13 @@ mod tests {
         let inner_field = Field::new(
             "European Union",
             DataType::Union(
-                vec![
-                    Field::new("b", DataType::Int32, false),
-                    Field::new("c", DataType::Float64, false),
-                ],
-                vec![0, 1],
+                UnionFields::new(
+                    vec![0, 1],
+                    vec![
+                        Field::new("b", DataType::Int32, false),
+                        Field::new("c", DataType::Float64, false),
+                    ],
+                ),
                 UnionMode::Dense,
             ),
             false,
@@ -809,8 +815,10 @@ mod tests {
         let schema = Schema::new(vec![Field::new(
             "Teamsters",
             DataType::Union(
-                vec![Field::new("a", DataType::Int32, true), inner_field],
-                vec![0, 1],
+                UnionFields::new(
+                    vec![0, 1],
+                    vec![Field::new("a", DataType::Int32, true), inner_field],
+                ),
                 UnionMode::Sparse,
             ),
             false,
diff --git a/arrow-data/src/data/mod.rs b/arrow-data/src/data/mod.rs
index c47c83663..581d4a10c 100644
--- a/arrow-data/src/data/mod.rs
+++ b/arrow-data/src/data/mod.rs
@@ -136,7 +136,7 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: 
usize) -> [MutableBuff
             MutableBuffer::new(capacity * mem::size_of::<u8>()),
             empty_buffer,
         ],
-        DataType::Union(_, _, mode) => {
+        DataType::Union(_, mode) => {
             let type_ids = MutableBuffer::new(capacity * mem::size_of::<i8>());
             match mode {
                 UnionMode::Sparse => [type_ids, empty_buffer],
@@ -162,7 +162,7 @@ pub(crate) fn into_buffers(
         | DataType::Binary
         | DataType::LargeUtf8
         | DataType::LargeBinary => vec![buffer1.into(), buffer2.into()],
-        DataType::Union(_, _, mode) => {
+        DataType::Union(_, mode) => {
             match mode {
                 // Based on Union's DataTypeLayout
                 UnionMode::Sparse => vec![buffer1.into()],
@@ -621,8 +621,9 @@ impl ArrayData {
                     vec![ArrayData::new_empty(v.as_ref())],
                     true,
                 ),
-                DataType::Union(f, i, mode) => {
-                    let ids = 
Buffer::from_iter(std::iter::repeat(i[0]).take(len));
+                DataType::Union(f, mode) => {
+                    let (id, _) = f.iter().next().unwrap();
+                    let ids = 
Buffer::from_iter(std::iter::repeat(id).take(len));
                     let buffers = match mode {
                         UnionMode::Sparse => vec![ids],
                         UnionMode::Dense => {
@@ -634,7 +635,7 @@ impl ArrayData {
                     let children = f
                         .iter()
                         .enumerate()
-                        .map(|(idx, f)| match idx {
+                        .map(|(idx, (_, f))| match idx {
                             0 => Self::new_null(f.data_type(), len),
                             _ => Self::new_empty(f.data_type()),
                         })
@@ -986,10 +987,10 @@ impl ArrayData {
                 }
                 Ok(())
             }
-            DataType::Union(fields, _, mode) => {
+            DataType::Union(fields, mode) => {
                 self.validate_num_child_data(fields.len())?;
 
-                for (i, field) in fields.iter().enumerate() {
+                for (i, (_, field)) in fields.iter().enumerate() {
                     let field_data = self.get_valid_child_data(i, 
field.data_type())?;
 
                     if mode == &UnionMode::Sparse
@@ -1255,7 +1256,7 @@ impl ArrayData {
                 let child = &self.child_data[0];
                 self.validate_offsets_full::<i64>(child.len)
             }
-            DataType::Union(_, _, _) => {
+            DataType::Union(_, _) => {
                 // Validate Union Array as part of implementing new Union 
semantics
                 // See comments in `ArrayData::validate()`
                 // https://github.com/apache/arrow-rs/issues/85
@@ -1568,7 +1569,7 @@ pub fn layout(data_type: &DataType) -> DataTypeLayout {
         DataType::LargeList(_) => 
DataTypeLayout::new_fixed_width(size_of::<i64>()),
         DataType::Struct(_) => DataTypeLayout::new_empty(), // all in child 
data,
         DataType::RunEndEncoded(_, _) => DataTypeLayout::new_empty(), // all 
in child data,
-        DataType::Union(_, _, mode) => {
+        DataType::Union(_, mode) => {
             let type_ids = BufferSpec::FixedWidth {
                 byte_width: size_of::<i8>(),
             };
@@ -1823,7 +1824,7 @@ impl From<ArrayData> for ArrayDataBuilder {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use arrow_schema::Field;
+    use arrow_schema::{Field, UnionFields};
 
     // See arrow/tests/array_data_validation.rs for test of array validation
 
@@ -2072,8 +2073,8 @@ mod tests {
     #[test]
     fn test_into_buffers() {
         let data_types = vec![
-            DataType::Union(vec![], vec![], UnionMode::Dense),
-            DataType::Union(vec![], vec![], UnionMode::Sparse),
+            DataType::Union(UnionFields::empty(), UnionMode::Dense),
+            DataType::Union(UnionFields::empty(), UnionMode::Sparse),
         ];
 
         for data_type in data_types {
diff --git a/arrow-data/src/equal/mod.rs b/arrow-data/src/equal/mod.rs
index 871a312ca..fbc868d3f 100644
--- a/arrow-data/src/equal/mod.rs
+++ b/arrow-data/src/equal/mod.rs
@@ -112,7 +112,7 @@ fn equal_values(
             fixed_list_equal(lhs, rhs, lhs_start, rhs_start, len)
         }
         DataType::Struct(_) => struct_equal(lhs, rhs, lhs_start, rhs_start, 
len),
-        DataType::Union(_, _, _) => union_equal(lhs, rhs, lhs_start, 
rhs_start, len),
+        DataType::Union(_, _) => union_equal(lhs, rhs, lhs_start, rhs_start, 
len),
         DataType::Dictionary(data_type, _) => match data_type.as_ref() {
             DataType::Int8 => dictionary_equal::<i8>(lhs, rhs, lhs_start, 
rhs_start, len),
             DataType::Int16 => {
diff --git a/arrow-data/src/equal/union.rs b/arrow-data/src/equal/union.rs
index fdf770096..4f04bc287 100644
--- a/arrow-data/src/equal/union.rs
+++ b/arrow-data/src/equal/union.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use crate::data::ArrayData;
-use arrow_schema::{DataType, UnionMode};
+use arrow_schema::{DataType, UnionFields, UnionMode};
 
 use super::equal_range;
 
@@ -28,8 +28,8 @@ fn equal_dense(
     rhs_type_ids: &[i8],
     lhs_offsets: &[i32],
     rhs_offsets: &[i32],
-    lhs_field_type_ids: &[i8],
-    rhs_field_type_ids: &[i8],
+    lhs_fields: &UnionFields,
+    rhs_fields: &UnionFields,
 ) -> bool {
     let offsets = lhs_offsets.iter().zip(rhs_offsets.iter());
 
@@ -38,13 +38,13 @@ fn equal_dense(
         .zip(rhs_type_ids.iter())
         .zip(offsets)
         .all(|((l_type_id, r_type_id), (l_offset, r_offset))| {
-            let lhs_child_index = lhs_field_type_ids
+            let lhs_child_index = lhs_fields
                 .iter()
-                .position(|r| r == l_type_id)
+                .position(|(r, _)| r == *l_type_id)
                 .unwrap();
-            let rhs_child_index = rhs_field_type_ids
+            let rhs_child_index = rhs_fields
                 .iter()
-                .position(|r| r == r_type_id)
+                .position(|(r, _)| r == *r_type_id)
                 .unwrap();
             let lhs_values = &lhs.child_data()[lhs_child_index];
             let rhs_values = &rhs.child_data()[rhs_child_index];
@@ -89,8 +89,8 @@ pub(super) fn union_equal(
 
     match (lhs.data_type(), rhs.data_type()) {
         (
-            DataType::Union(_, lhs_type_ids, UnionMode::Dense),
-            DataType::Union(_, rhs_type_ids, UnionMode::Dense),
+            DataType::Union(lhs_fields, UnionMode::Dense),
+            DataType::Union(rhs_fields, UnionMode::Dense),
         ) => {
             let lhs_offsets = lhs.buffer::<i32>(1);
             let rhs_offsets = rhs.buffer::<i32>(1);
@@ -106,13 +106,13 @@ pub(super) fn union_equal(
                     rhs_type_id_range,
                     lhs_offsets_range,
                     rhs_offsets_range,
-                    lhs_type_ids,
-                    rhs_type_ids,
+                    lhs_fields,
+                    rhs_fields,
                 )
         }
         (
-            DataType::Union(_, _, UnionMode::Sparse),
-            DataType::Union(_, _, UnionMode::Sparse),
+            DataType::Union(_, UnionMode::Sparse),
+            DataType::Union(_, UnionMode::Sparse),
         ) => {
             lhs_type_id_range == rhs_type_id_range
                 && equal_sparse(lhs, rhs, lhs_start, rhs_start, len)
diff --git a/arrow-data/src/equal/utils.rs b/arrow-data/src/equal/utils.rs
index 6b9a7940d..fa6211542 100644
--- a/arrow-data/src/equal/utils.rs
+++ b/arrow-data/src/equal/utils.rs
@@ -59,7 +59,7 @@ pub(super) fn equal_nulls(
 #[inline]
 pub(super) fn base_equal(lhs: &ArrayData, rhs: &ArrayData) -> bool {
     let equal_type = match (lhs.data_type(), rhs.data_type()) {
-        (DataType::Union(l_fields, _, l_mode), DataType::Union(r_fields, _, 
r_mode)) => {
+        (DataType::Union(l_fields, l_mode), DataType::Union(r_fields, r_mode)) 
=> {
             l_fields == r_fields && l_mode == r_mode
         }
         (DataType::Map(l_field, l_sorted), DataType::Map(r_field, r_sorted)) 
=> {
diff --git a/arrow-data/src/transform/mod.rs b/arrow-data/src/transform/mod.rs
index 2719b96b6..ccdbaec3b 100644
--- a/arrow-data/src/transform/mod.rs
+++ b/arrow-data/src/transform/mod.rs
@@ -231,7 +231,7 @@ fn build_extend(array: &ArrayData) -> Extend {
         DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array),
         DataType::Float16 => primitive::build_extend::<f16>(array),
         DataType::FixedSizeList(_, _) => fixed_size_list::build_extend(array),
-        DataType::Union(_, _, mode) => match mode {
+        DataType::Union(_, mode) => match mode {
             UnionMode::Sparse => union::build_extend_sparse(array),
             UnionMode::Dense => union::build_extend_dense(array),
         },
@@ -283,7 +283,7 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls {
         DataType::FixedSizeBinary(_) => fixed_binary::extend_nulls,
         DataType::Float16 => primitive::extend_nulls::<f16>,
         DataType::FixedSizeList(_, _) => fixed_size_list::extend_nulls,
-        DataType::Union(_, _, mode) => match mode {
+        DataType::Union(_, mode) => match mode {
             UnionMode::Sparse => union::extend_nulls_sparse,
             UnionMode::Dense => union::extend_nulls_dense,
         },
@@ -501,7 +501,7 @@ impl<'a> MutableArrayData<'a> {
                     .collect::<Vec<_>>();
                 vec![MutableArrayData::new(childs, use_nulls, array_capacity)]
             }
-            DataType::Union(fields, _, _) => (0..fields.len())
+            DataType::Union(fields, _) => (0..fields.len())
                 .map(|i| {
                     let child_arrays = arrays
                         .iter()
diff --git a/arrow-integration-test/src/datatype.rs 
b/arrow-integration-test/src/datatype.rs
index a08368d58..5a5dd67fc 100644
--- a/arrow-integration-test/src/datatype.rs
+++ b/arrow-integration-test/src/datatype.rs
@@ -17,6 +17,7 @@
 
 use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit, 
UnionMode};
 use arrow::error::{ArrowError, Result};
+use std::sync::Arc;
 
 /// Parse a data type from a JSON representation.
 pub fn data_type_from_json(json: &serde_json::Value) -> Result<DataType> {
@@ -229,20 +230,15 @@ pub fn data_type_from_json(json: &serde_json::Value) -> 
Result<DataType> {
                             "Unknown union mode {mode:?} for union"
                         )));
                     };
-                    if let Some(type_ids) = map.get("typeIds") {
-                        let type_ids = type_ids
-                            .as_array()
-                            .unwrap()
+                    if let Some(values) = map.get("typeIds") {
+                        let field = Arc::new(default_field);
+                        let values = values.as_array().unwrap();
+                        let fields = values
                             .iter()
-                            .map(|t| t.as_i64().unwrap() as i8)
-                            .collect::<Vec<_>>();
+                            .map(|t| (t.as_i64().unwrap() as i8, 
field.clone()))
+                            .collect();
 
-                        let default_fields = type_ids
-                            .iter()
-                            .map(|_| default_field.clone())
-                            .collect::<Vec<_>>();
-
-                        Ok(DataType::Union(default_fields, type_ids, 
union_mode))
+                        Ok(DataType::Union(fields, union_mode))
                     } else {
                         Err(ArrowError::ParseError(
                             "Expecting a typeIds for union ".to_string(),
@@ -290,7 +286,7 @@ pub fn data_type_to_json(data_type: &DataType) -> 
serde_json::Value {
             json!({"name": "fixedsizebinary", "byteWidth": byte_width})
         }
         DataType::Struct(_) => json!({"name": "struct"}),
-        DataType::Union(_, _, _) => json!({"name": "union"}),
+        DataType::Union(_, _) => json!({"name": "union"}),
         DataType::List(_) => json!({ "name": "list"}),
         DataType::LargeList(_) => json!({ "name": "largelist"}),
         DataType::FixedSizeList(_, length) => {
diff --git a/arrow-integration-test/src/field.rs 
b/arrow-integration-test/src/field.rs
index a60cd91c5..c714fe467 100644
--- a/arrow-integration-test/src/field.rs
+++ b/arrow-integration-test/src/field.rs
@@ -19,6 +19,7 @@ use crate::{data_type_from_json, data_type_to_json};
 use arrow::datatypes::{DataType, Field};
 use arrow::error::{ArrowError, Result};
 use std::collections::HashMap;
+use std::sync::Arc;
 
 /// Parse a `Field` definition from a JSON representation.
 pub fn field_from_json(json: &serde_json::Value) -> Result<Field> {
@@ -194,11 +195,17 @@ pub fn field_from_json(json: &serde_json::Value) -> 
Result<Field> {
                         }
                     }
                 }
-                DataType::Union(_, type_ids, mode) => match 
map.get("children") {
+                DataType::Union(fields, mode) => match map.get("children") {
                     Some(Value::Array(values)) => {
-                        let union_fields: Vec<Field> =
-                            
values.iter().map(field_from_json).collect::<Result<_>>()?;
-                        DataType::Union(union_fields, type_ids, mode)
+                        let fields = fields
+                            .iter()
+                            .zip(values)
+                            .map(|((id, _), value)| {
+                                Ok((id, Arc::new(field_from_json(value)?)))
+                            })
+                            .collect::<Result<_>>()?;
+
+                        DataType::Union(fields, mode)
                     }
                     Some(_) => {
                         return Err(ArrowError::ParseError(
@@ -296,7 +303,7 @@ pub fn field_to_json(field: &Field) -> serde_json::Value {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use arrow::datatypes::{Fields, UnionMode};
+    use arrow::datatypes::{Fields, UnionFields, UnionMode};
     use serde_json::Value;
 
     #[test]
@@ -569,11 +576,13 @@ mod tests {
         let expected = Field::new(
             "my_union",
             DataType::Union(
-                vec![
-                    Field::new("f1", DataType::Int32, true),
-                    Field::new("f2", DataType::Utf8, true),
-                ],
-                vec![5, 7],
+                UnionFields::new(
+                    vec![5, 7],
+                    vec![
+                        Field::new("f1", DataType::Int32, true),
+                        Field::new("f2", DataType::Utf8, true),
+                    ],
+                ),
                 UnionMode::Sparse,
             ),
             false,
diff --git a/arrow-integration-test/src/lib.rs 
b/arrow-integration-test/src/lib.rs
index 06f16ca1d..61bcbea5a 100644
--- a/arrow-integration-test/src/lib.rs
+++ b/arrow-integration-test/src/lib.rs
@@ -858,7 +858,7 @@ pub fn array_from_json(
             let array = MapArray::from(array_data);
             Ok(Arc::new(array))
         }
-        DataType::Union(fields, field_type_ids, _) => {
+        DataType::Union(fields, _) => {
             let type_ids = if let Some(type_id) = json_col.type_id {
                 type_id
             } else {
@@ -874,13 +874,14 @@ pub fn array_from_json(
             });
 
             let mut children: Vec<(Field, Arc<dyn Array>)> = vec![];
-            for (field, col) in fields.iter().zip(json_col.children.unwrap()) {
+            for ((_, field), col) in 
fields.iter().zip(json_col.children.unwrap()) {
                 let array = array_from_json(field, col, dictionaries)?;
-                children.push((field.clone(), array));
+                children.push((field.as_ref().clone(), array));
             }
 
+            let field_type_ids = fields.iter().map(|(id, _)| 
id).collect::<Vec<_>>();
             let array = UnionArray::try_new(
-                field_type_ids,
+                &field_type_ids,
                 Buffer::from(&type_ids.to_byte_slice()),
                 offset,
                 children,
diff --git a/arrow-ipc/src/convert.rs b/arrow-ipc/src/convert.rs
index 7e44f37d4..8ca0d514f 100644
--- a/arrow-ipc/src/convert.rs
+++ b/arrow-ipc/src/convert.rs
@@ -410,16 +410,16 @@ pub(crate) fn get_data_type(field: crate::Field, 
may_be_dictionary: bool) -> Dat
             let mut fields = vec![];
             if let Some(children) = field.children() {
                 for i in 0..children.len() {
-                    fields.push(children.get(i).into());
+                    fields.push(Field::from(children.get(i)));
                 }
             };
 
-            let type_ids: Vec<i8> = match union.typeIds() {
-                None => (0_i8..fields.len() as i8).collect(),
-                Some(ids) => ids.iter().map(|i| i as i8).collect(),
+            let fields = match union.typeIds() {
+                None => UnionFields::new(0_i8..fields.len() as i8, fields),
+                Some(ids) => UnionFields::new(ids.iter().map(|i| i as i8), 
fields),
             };
 
-            DataType::Union(fields, type_ids, union_mode)
+            DataType::Union(fields, union_mode)
         }
         t => unimplemented!("Type {:?} not supported", t),
     }
@@ -769,9 +769,9 @@ pub(crate) fn get_fb_field_type<'a>(
                 children: Some(fbb.create_vector(&empty_fields[..])),
             }
         }
-        Union(fields, type_ids, mode) => {
+        Union(fields, mode) => {
             let mut children = vec![];
-            for field in fields {
+            for (_, field) in fields.iter() {
                 children.push(build_field(fbb, field));
             }
 
@@ -781,7 +781,7 @@ pub(crate) fn get_fb_field_type<'a>(
             };
 
             let fbb_type_ids = fbb
-                .create_vector(&type_ids.iter().map(|t| *t as 
i32).collect::<Vec<_>>());
+                .create_vector(&fields.iter().map(|(t, _)| t as 
i32).collect::<Vec<_>>());
             let mut builder = crate::UnionBuilder::new(fbb);
             builder.add_mode(union_mode);
             builder.add_typeIds(fbb_type_ids);
@@ -962,38 +962,47 @@ mod tests {
                 Field::new(
                     "union<int64, list[union<date32, list[union<>]>]>",
                     DataType::Union(
-                        vec![
-                            Field::new("int64", DataType::Int64, true),
-                            Field::new(
-                                "list[union<date32, list[union<>]>]",
-                                DataType::List(Box::new(Field::new(
-                                    "union<date32, list[union<>]>",
-                                    DataType::Union(
-                                        vec![
-                                            Field::new("date32", 
DataType::Date32, true),
-                                            Field::new(
-                                                "list[union<>]",
-                                                
DataType::List(Box::new(Field::new(
-                                                    "union",
-                                                    DataType::Union(
-                                                        vec![],
-                                                        vec![],
-                                                        UnionMode::Sparse,
+                        UnionFields::new(
+                            vec![0, 1],
+                            vec![
+                                Field::new("int64", DataType::Int64, true),
+                                Field::new(
+                                    "list[union<date32, list[union<>]>]",
+                                    DataType::List(Box::new(Field::new(
+                                        "union<date32, list[union<>]>",
+                                        DataType::Union(
+                                            UnionFields::new(
+                                                vec![0, 1],
+                                                vec![
+                                                    Field::new(
+                                                        "date32",
+                                                        DataType::Date32,
+                                                        true,
+                                                    ),
+                                                    Field::new(
+                                                        "list[union<>]",
+                                                        
DataType::List(Box::new(
+                                                            Field::new(
+                                                                "union",
+                                                                
DataType::Union(
+                                                                    
UnionFields::empty(),
+                                                                    
UnionMode::Sparse,
+                                                                ),
+                                                                false,
+                                                            ),
+                                                        )),
+                                                        false,
                                                     ),
-                                                    false,
-                                                ))),
-                                                false,
+                                                ],
                                             ),
-                                        ],
-                                        vec![0, 1],
-                                        UnionMode::Dense,
-                                    ),
+                                            UnionMode::Dense,
+                                        ),
+                                        false,
+                                    ))),
                                     false,
-                                ))),
-                                false,
-                            ),
-                        ],
-                        vec![0, 1],
+                                ),
+                            ],
+                        ),
                         UnionMode::Sparse,
                     ),
                     false,
@@ -1001,22 +1010,24 @@ mod tests {
                 Field::new("struct<>", DataType::Struct(Fields::empty()), 
true),
                 Field::new(
                     "union<>",
-                    DataType::Union(vec![], vec![], UnionMode::Dense),
+                    DataType::Union(UnionFields::empty(), UnionMode::Dense),
                     true,
                 ),
                 Field::new(
                     "union<>",
-                    DataType::Union(vec![], vec![], UnionMode::Sparse),
+                    DataType::Union(UnionFields::empty(), UnionMode::Sparse),
                     true,
                 ),
                 Field::new(
                     "union<int32, utf8>",
                     DataType::Union(
-                        vec![
-                            Field::new("int32", DataType::Int32, true),
-                            Field::new("utf8", DataType::Utf8, true),
-                        ],
-                        vec![2, 3], // non-default type ids
+                        UnionFields::new(
+                            vec![2, 3], // non-default type ids
+                            vec![
+                                Field::new("int32", DataType::Int32, true),
+                                Field::new("utf8", DataType::Utf8, true),
+                            ],
+                        ),
                         UnionMode::Dense,
                     ),
                     true,
diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs
index 4597ed82d..4f2e51336 100644
--- a/arrow-ipc/src/reader.rs
+++ b/arrow-ipc/src/reader.rs
@@ -263,7 +263,7 @@ fn create_array(
                 value_array.clone(),
             )?
         }
-        Union(fields, field_type_ids, mode) => {
+        Union(fields, mode) => {
             let union_node = nodes.get(node_index);
             node_index += 1;
 
@@ -292,9 +292,10 @@ fn create_array(
                 UnionMode::Sparse => None,
             };
 
-            let mut children = vec![];
+            let mut children = Vec::with_capacity(fields.len());
+            let mut ids = Vec::with_capacity(fields.len());
 
-            for field in fields {
+            for (id, field) in fields.iter() {
                 let triple = create_array(
                     nodes,
                     field,
@@ -310,11 +311,11 @@ fn create_array(
                 node_index = triple.1;
                 buffer_index = triple.2;
 
-                children.push((field.clone(), triple.0));
+                children.push((field.as_ref().clone(), triple.0));
+                ids.push(id);
             }
 
-            let array =
-                UnionArray::try_new(field_type_ids, type_ids, value_offsets, 
children)?;
+            let array = UnionArray::try_new(&ids, type_ids, value_offsets, 
children)?;
             Arc::new(array)
         }
         Null => {
@@ -418,7 +419,7 @@ fn skip_field(
             node_index += 1;
             buffer_index += 2;
         }
-        Union(fields, _field_type_ids, mode) => {
+        Union(fields, mode) => {
             node_index += 1;
             buffer_index += 1;
 
@@ -429,7 +430,7 @@ fn skip_field(
                 UnionMode::Sparse => {}
             };
 
-            for field in fields {
+            for (_, field) in fields.iter() {
                 let tuple = skip_field(field.data_type(), node_index, 
buffer_index)?;
 
                 node_index = tuple.0;
@@ -1265,11 +1266,15 @@ mod tests {
         let dict_data_type =
             DataType::Dictionary(Box::new(key_type), Box::new(value_type));
 
-        let union_fileds = vec![
-            Field::new("a", DataType::Int32, false),
-            Field::new("b", DataType::Float64, false),
-        ];
-        let union_data_type = DataType::Union(union_fileds, vec![0, 1], 
UnionMode::Dense);
+        let union_fields = UnionFields::new(
+            vec![0, 1],
+            vec![
+                Field::new("a", DataType::Int32, false),
+                Field::new("b", DataType::Float64, false),
+            ],
+        );
+
+        let union_data_type = DataType::Union(union_fields, UnionMode::Dense);
 
         let struct_fields = Fields::from(vec![
             Field::new("id", DataType::Int32, false),
diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs
index ceb9b6ffa..0e999dc72 100644
--- a/arrow-ipc/src/writer.rs
+++ b/arrow-ipc/src/writer.rs
@@ -298,10 +298,10 @@ impl IpcDataGenerator {
                     write_options,
                 )?;
             }
-            DataType::Union(fields, type_ids, _) => {
+            DataType::Union(fields, _) => {
                 let union = as_union_array(column);
-                for (field, type_id) in fields.iter().zip(type_ids) {
-                    let column = union.child(*type_id);
+                for (type_id, field) in fields.iter() {
+                    let column = union.child(type_id);
                     self.encode_dictionaries(
                         field,
                         column,
@@ -1069,7 +1069,7 @@ fn has_validity_bitmap(data_type: &DataType, 
write_options: &IpcWriteOptions) ->
     } else {
         !matches!(
             data_type,
-            DataType::Null | DataType::Union(_, _, _) | 
DataType::RunEndEncoded(_, _)
+            DataType::Null | DataType::Union(_, _) | 
DataType::RunEndEncoded(_, _)
         )
     }
 }
@@ -1781,11 +1781,13 @@ mod tests {
         let schema = Schema::new(vec![Field::new(
             "union",
             DataType::Union(
-                vec![
-                    Field::new("a", DataType::Int32, false),
-                    Field::new("c", DataType::Float64, false),
-                ],
-                vec![0, 1],
+                UnionFields::new(
+                    vec![0, 1],
+                    vec![
+                        Field::new("a", DataType::Int32, false),
+                        Field::new("c", DataType::Float64, false),
+                    ],
+                ),
                 UnionMode::Sparse,
             ),
             true,
diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs
index 58747fb26..57a5c6838 100644
--- a/arrow-schema/src/datatype.rs
+++ b/arrow-schema/src/datatype.rs
@@ -19,7 +19,7 @@ use std::fmt;
 use std::sync::Arc;
 
 use crate::field::Field;
-use crate::Fields;
+use crate::{Fields, UnionFields};
 
 /// The set of datatypes that are supported by this implementation of Apache 
Arrow.
 ///
@@ -194,10 +194,9 @@ pub enum DataType {
     Struct(Fields),
     /// A nested datatype that can represent slots of differing types. 
Components:
     ///
-    /// 1. [`Field`] for each possible child type the Union can hold
-    /// 2. The corresponding `type_id` used to identify which Field
-    /// 3. The type of union (Sparse or Dense)
-    Union(Vec<Field>, Vec<i8>, UnionMode),
+    /// 1. [`UnionFields`]
+    /// 2. The type of union (Sparse or Dense)
+    Union(UnionFields, UnionMode),
     /// A dictionary encoded array (`key_type`, `value_type`), where
     /// each array element is an index of `key_type` into an
     /// associated dictionary of `value_type`.
@@ -384,7 +383,7 @@ impl DataType {
             | FixedSizeList(_, _)
             | LargeList(_)
             | Struct(_)
-            | Union(_, _, _)
+            | Union(_, _)
             | Map(_, _) => true,
             _ => false,
         }
@@ -446,7 +445,7 @@ impl DataType {
             DataType::List(_) | DataType::LargeList(_) | DataType::Map(_, _) 
=> None,
             DataType::FixedSizeList(_, _) => None,
             DataType::Struct(_) => None,
-            DataType::Union(_, _, _) => None,
+            DataType::Union(_, _) => None,
             DataType::Dictionary(_, _) => None,
             DataType::RunEndEncoded(_, _) => None,
         }
@@ -492,13 +491,7 @@ impl DataType {
                 | DataType::LargeList(field)
                 | DataType::Map(field, _) => field.size(),
                 DataType::Struct(fields) => fields.size(),
-                DataType::Union(fields, _, _) => {
-                    fields
-                        .iter()
-                        .map(|field| field.size() - 
std::mem::size_of_val(field))
-                        .sum::<usize>()
-                        + (std::mem::size_of::<Field>() * fields.capacity())
-                }
+                DataType::Union(fields, _) => fields.size(),
                 DataType::Dictionary(dt1, dt2) => dt1.size() + dt2.size(),
                 DataType::RunEndEncoded(run_ends, values) => {
                     run_ends.size() - std::mem::size_of_val(run_ends) + 
values.size()
@@ -670,4 +663,9 @@ mod tests {
             Box::new(list)
         )));
     }
+
+    #[test]
+    fn size_should_not_regress() {
+        assert_eq!(std::mem::size_of::<DataType>(), 24);
+    }
 }
diff --git a/arrow-schema/src/ffi.rs b/arrow-schema/src/ffi.rs
index 0cfc1800f..72afc5b0b 100644
--- a/arrow-schema/src/ffi.rs
+++ b/arrow-schema/src/ffi.rs
@@ -34,7 +34,9 @@
 //! assert_eq!(schema, back);
 //! ```
 
-use crate::{ArrowError, DataType, Field, FieldRef, Schema, TimeUnit, 
UnionMode};
+use crate::{
+    ArrowError, DataType, Field, FieldRef, Schema, TimeUnit, UnionFields, 
UnionMode,
+};
 use bitflags::bitflags;
 use std::sync::Arc;
 use std::{
@@ -484,7 +486,7 @@ impl TryFrom<&FFI_ArrowSchema> for DataType {
                             ));
                         }
 
-                        DataType::Union(fields, type_ids, UnionMode::Dense)
+                        DataType::Union(UnionFields::new(type_ids, fields), 
UnionMode::Dense)
                     }
                     // SparseUnion
                     ["+us", extra] => {
@@ -506,7 +508,7 @@ impl TryFrom<&FFI_ArrowSchema> for DataType {
                             ));
                         }
 
-                        DataType::Union(fields, type_ids, UnionMode::Sparse)
+                        DataType::Union(UnionFields::new(type_ids, fields), 
UnionMode::Sparse)
                     }
 
                     // Timestamps in format "tts:" and "tts:America/New_York" 
for no timezones and timezones resp.
@@ -585,9 +587,9 @@ impl TryFrom<&DataType> for FFI_ArrowSchema {
             | DataType::Map(child, _) => {
                 vec![FFI_ArrowSchema::try_from(child.as_ref())?]
             }
-            DataType::Union(fields, _, _) => fields
+            DataType::Union(fields, _) => fields
                 .iter()
-                .map(FFI_ArrowSchema::try_from)
+                .map(|(_, f)| f.as_ref().try_into())
                 .collect::<Result<Vec<_>, ArrowError>>()?,
             DataType::Struct(fields) => fields
                 .iter()
@@ -658,8 +660,11 @@ fn get_format_string(dtype: &DataType) -> Result<String, 
ArrowError> {
         DataType::Struct(_) => Ok("+s".to_string()),
         DataType::Map(_, _) => Ok("+m".to_string()),
         DataType::Dictionary(key_data_type, _) => 
get_format_string(key_data_type),
-        DataType::Union(_, type_ids, mode) => {
-            let formats = type_ids.iter().map(|t| 
t.to_string()).collect::<Vec<_>>();
+        DataType::Union(fields, mode) => {
+            let formats = fields
+                .iter()
+                .map(|(t, _)| t.to_string())
+                .collect::<Vec<_>>();
             match mode {
                 UnionMode::Dense => Ok(format!("{}:{}", "+ud", 
formats.join(","))),
                 UnionMode::Sparse => Ok(format!("{}:{}", "+us", 
formats.join(","))),
diff --git a/arrow-schema/src/field.rs b/arrow-schema/src/field.rs
index 8ef9fd2b8..d68392f51 100644
--- a/arrow-schema/src/field.rs
+++ b/arrow-schema/src/field.rs
@@ -235,8 +235,8 @@ impl Field {
     fn _fields(dt: &DataType) -> Vec<&Field> {
         match dt {
             DataType::Struct(fields) => fields.iter().flat_map(|f| 
f.fields()).collect(),
-            DataType::Union(fields, _, _) => {
-                fields.iter().flat_map(|f| f.fields()).collect()
+            DataType::Union(fields, _) => {
+                fields.iter().flat_map(|(_, f)| f.fields()).collect()
             }
             DataType::List(field)
             | DataType::LargeList(field)
@@ -341,36 +341,9 @@ impl Field {
                             self.name, from.data_type)
                 ))}
             },
-            DataType::Union(nested_fields, type_ids, _) => match 
&from.data_type {
-                DataType::Union(from_nested_fields, from_type_ids, _) => {
-                    for (idx, from_field) in 
from_nested_fields.iter().enumerate() {
-                        let mut is_new_field = true;
-                        let field_type_id = from_type_ids.get(idx).unwrap();
-
-                        for (self_idx, self_field) in 
nested_fields.iter_mut().enumerate()
-                        {
-                            if from_field == self_field {
-                                let self_type_id = 
type_ids.get(self_idx).unwrap();
-
-                                // If the nested fields in two unions are the 
same, they must have same
-                                // type id.
-                                if self_type_id != field_type_id {
-                                    return Err(ArrowError::SchemaError(
-                                        format!("Fail to merge schema field 
'{}' because the self_type_id = {} does not equal field_type_id = {}",
-                                            self.name, self_type_id, 
field_type_id)
-                                    ));
-                                }
-
-                                is_new_field = false;
-                                break;
-                            }
-                        }
-
-                        if is_new_field {
-                            nested_fields.push(from_field.clone());
-                            type_ids.push(*field_type_id);
-                        }
-                    }
+            DataType::Union(nested_fields, _) => match &from.data_type {
+                DataType::Union(from_nested_fields, _) => {
+                    nested_fields.try_merge(from_nested_fields)?
                 }
                 _ => {
                     return Err(ArrowError::SchemaError(
diff --git a/arrow-schema/src/fields.rs b/arrow-schema/src/fields.rs
index 268226136..1de5e5efd 100644
--- a/arrow-schema/src/fields.rs
+++ b/arrow-schema/src/fields.rs
@@ -15,13 +15,13 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::{Field, FieldRef};
+use crate::{ArrowError, Field, FieldRef};
 use std::ops::Deref;
 use std::sync::Arc;
 
 /// A cheaply cloneable, owned slice of [`FieldRef`]
 ///
-/// Similar to `Arc<Vec<FieldPtr>>` or `Arc<[FieldPtr]>`
+/// Similar to `Arc<Vec<FieldRef>>` or `Arc<[FieldRef]>`
 ///
 /// Can be constructed in a number of ways
 ///
@@ -55,7 +55,9 @@ impl Fields {
 
     /// Return size of this instance in bytes.
     pub fn size(&self) -> usize {
-        self.iter().map(|field| field.size()).sum()
+        self.iter()
+            .map(|field| field.size() + std::mem::size_of::<FieldRef>())
+            .sum()
     }
 
     /// Searches for a field by name, returning it along with its index if 
found
@@ -148,3 +150,112 @@ impl<'de> serde::Deserialize<'de> for Fields {
         Ok(Vec::<Field>::deserialize(deserializer)?.into())
     }
 }
+
+/// A cheaply cloneable, owned collection of [`FieldRef`] and their 
corresponding type ids
+#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[cfg_attr(feature = "serde", serde(transparent))]
+pub struct UnionFields(Arc<[(i8, FieldRef)]>);
+
+impl std::fmt::Debug for UnionFields {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        self.0.as_ref().fmt(f)
+    }
+}
+
+impl UnionFields {
+    /// Create a new [`UnionFields`] with no fields
+    pub fn empty() -> Self {
+        Self(Arc::from([]))
+    }
+
+    /// Create a new [`UnionFields`] from a [`Fields`] and array of type_ids
+    ///
+    /// See <https://arrow.apache.org/docs/format/Columnar.html#union-layout>
+    ///
+    /// ```
+    /// use arrow_schema::{DataType, Field, UnionFields};
+    /// // Create a new UnionFields with type id mapping
+    /// // 1 -> DataType::UInt8
+    /// // 3 -> DataType::Utf8
+    /// UnionFields::new(
+    ///     vec![1, 3],
+    ///     vec![
+    ///         Field::new("field1", DataType::UInt8, false),
+    ///         Field::new("field3", DataType::Utf8, false),
+    ///     ],
+    /// );
+    /// ```
+    pub fn new<F, T>(type_ids: T, fields: F) -> Self
+    where
+        F: IntoIterator,
+        F::Item: Into<FieldRef>,
+        T: IntoIterator<Item = i8>,
+    {
+        let fields = fields.into_iter().map(Into::into);
+        type_ids.into_iter().zip(fields).collect()
+    }
+
+    /// Return size of this instance in bytes.
+    pub fn size(&self) -> usize {
+        self.iter()
+            .map(|(_, field)| field.size() + std::mem::size_of::<(i8, 
FieldRef)>())
+            .sum()
+    }
+
+    /// Returns the number of fields in this [`UnionFields`]
+    pub fn len(&self) -> usize {
+        self.0.len()
+    }
+
+    /// Returns `true` if this is empty
+    pub fn is_empty(&self) -> bool {
+        self.0.is_empty()
+    }
+
+    /// Returns an iterator over the fields and type ids in this 
[`UnionFields`]
+    ///
+    /// Note: the iteration order is not guaranteed
+    pub fn iter(&self) -> impl Iterator<Item = (i8, &FieldRef)> + '_ {
+        self.0.iter().map(|(id, f)| (*id, f))
+    }
+
+    /// Merge this field into self if it is compatible.
+    ///
+    /// See [`Field::try_merge`]
+    pub(crate) fn try_merge(&mut self, other: &Self) -> Result<(), ArrowError> 
{
+        // TODO: This currently may produce duplicate type IDs (#3982)
+        let mut output: Vec<_> = self.iter().map(|(id, f)| (id, 
f.clone())).collect();
+        for (field_type_id, from_field) in other.iter() {
+            let mut is_new_field = true;
+            for (self_type_id, self_field) in output.iter_mut() {
+                if from_field == self_field {
+                    // If the nested fields in two unions are the same, they 
must have same
+                    // type id.
+                    if *self_type_id != field_type_id {
+                        return Err(ArrowError::SchemaError(
+                            format!("Fail to merge schema field '{}' because 
the self_type_id = {} does not equal field_type_id = {}",
+                                    self_field.name(), self_type_id, 
field_type_id)
+                        ));
+                    }
+
+                    is_new_field = false;
+                    break;
+                }
+            }
+
+            if is_new_field {
+                output.push((field_type_id, from_field.clone()))
+            }
+        }
+        *self = output.into_iter().collect();
+        Ok(())
+    }
+}
+
+impl FromIterator<(i8, FieldRef)> for UnionFields {
+    fn from_iter<T: IntoIterator<Item = (i8, FieldRef)>>(iter: T) -> Self {
+        // TODO: Should this validate type IDs are unique (#3982)
+        Self(iter.into_iter().collect())
+    }
+}
diff --git a/arrow-schema/src/schema.rs b/arrow-schema/src/schema.rs
index 6089c1ae5..501c5c7fd 100644
--- a/arrow-schema/src/schema.rs
+++ b/arrow-schema/src/schema.rs
@@ -365,7 +365,7 @@ impl Hash for Schema {
 #[cfg(test)]
 mod tests {
     use crate::datatype::DataType;
-    use crate::{TimeUnit, UnionMode};
+    use crate::{TimeUnit, UnionFields, UnionMode};
 
     use super::*;
 
@@ -778,11 +778,13 @@ mod tests {
                 Schema::new(vec![Field::new(
                     "c1",
                     DataType::Union(
-                        vec![
-                            Field::new("c11", DataType::Utf8, true),
-                            Field::new("c12", DataType::Utf8, true),
-                        ],
-                        vec![0, 1],
+                        UnionFields::new(
+                            vec![0, 1],
+                            vec![
+                                Field::new("c11", DataType::Utf8, true),
+                                Field::new("c12", DataType::Utf8, true),
+                            ]
+                        ),
                         UnionMode::Dense
                     ),
                     false
@@ -790,11 +792,17 @@ mod tests {
                 Schema::new(vec![Field::new(
                     "c1",
                     DataType::Union(
-                        vec![
-                            Field::new("c12", DataType::Utf8, true),
-                            Field::new("c13", 
DataType::Time64(TimeUnit::Second), true),
-                        ],
-                        vec![1, 2],
+                        UnionFields::new(
+                            vec![1, 2],
+                            vec![
+                                Field::new("c12", DataType::Utf8, true),
+                                Field::new(
+                                    "c13",
+                                    DataType::Time64(TimeUnit::Second),
+                                    true
+                                ),
+                            ]
+                        ),
                         UnionMode::Dense
                     ),
                     false
@@ -804,12 +812,14 @@ mod tests {
             Schema::new(vec![Field::new(
                 "c1",
                 DataType::Union(
-                    vec![
-                        Field::new("c11", DataType::Utf8, true),
-                        Field::new("c12", DataType::Utf8, true),
-                        Field::new("c13", DataType::Time64(TimeUnit::Second), 
true),
-                    ],
-                    vec![0, 1, 2],
+                    UnionFields::new(
+                        vec![0, 1, 2],
+                        vec![
+                            Field::new("c11", DataType::Utf8, true),
+                            Field::new("c12", DataType::Utf8, true),
+                            Field::new("c13", 
DataType::Time64(TimeUnit::Second), true),
+                        ]
+                    ),
                     UnionMode::Dense
                 ),
                 false
diff --git a/arrow/src/datatypes/mod.rs b/arrow/src/datatypes/mod.rs
index d1977d42b..74dad6b4a 100644
--- a/arrow/src/datatypes/mod.rs
+++ b/arrow/src/datatypes/mod.rs
@@ -30,7 +30,7 @@ pub use arrow_buffer::{i256, ArrowNativeType, ToByteSlice};
 pub use arrow_data::decimal::*;
 pub use arrow_schema::{
     DataType, Field, FieldRef, Fields, IntervalUnit, Schema, SchemaBuilder, 
SchemaRef,
-    TimeUnit, UnionMode,
+    TimeUnit, UnionFields, UnionMode,
 };
 
 #[cfg(feature = "ffi")]
diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs
index 2d6bbf1a0..fe2e186a7 100644
--- a/arrow/src/ffi.rs
+++ b/arrow/src/ffi.rs
@@ -174,15 +174,15 @@ fn bit_width(data_type: &DataType, i: usize) -> 
Result<usize> {
             )))
         }
         // type ids. UnionArray doesn't have null bitmap so buffer index 
begins with 0.
-        (DataType::Union(_, _, _), 0) => i8::BITS as _,
+        (DataType::Union(_, _), 0) => i8::BITS as _,
         // Only DenseUnion has 2nd buffer
-        (DataType::Union(_, _, UnionMode::Dense), 1) => i32::BITS as _,
-        (DataType::Union(_, _, UnionMode::Sparse), _) => {
+        (DataType::Union(_, UnionMode::Dense), 1) => i32::BITS as _,
+        (DataType::Union(_, UnionMode::Sparse), _) => {
             return Err(ArrowError::CDataInterface(format!(
                 "The datatype \"{data_type:?}\" expects 1 buffer, but 
requested {i}. Please verify that the C data interface is correctly 
implemented."
             )))
         }
-        (DataType::Union(_, _, UnionMode::Dense), _) => {
+        (DataType::Union(_, UnionMode::Dense), _) => {
             return Err(ArrowError::CDataInterface(format!(
                 "The datatype \"{data_type:?}\" expects 2 buffer, but 
requested {i}. Please verify that the C data interface is correctly 
implemented."
             )))
diff --git a/arrow/tests/array_cast.rs b/arrow/tests/array_cast.rs
index b113ec04c..27fb1dcd2 100644
--- a/arrow/tests/array_cast.rs
+++ b/arrow/tests/array_cast.rs
@@ -41,7 +41,7 @@ use arrow_cast::pretty::pretty_format_columns;
 use arrow_cast::{can_cast_types, cast};
 use arrow_data::ArrayData;
 use arrow_schema::{
-    ArrowError, DataType, Field, Fields, IntervalUnit, TimeUnit, UnionMode,
+    ArrowError, DataType, Field, Fields, IntervalUnit, TimeUnit, UnionFields, 
UnionMode,
 };
 use half::f16;
 use std::sync::Arc;
@@ -405,11 +405,13 @@ fn get_all_types() -> Vec<DataType> {
             Field::new("f2", DataType::Utf8, true),
         ])),
         Union(
-            vec![
-                Field::new("f1", DataType::Int32, false),
-                Field::new("f2", DataType::Utf8, true),
-            ],
-            vec![0, 1],
+            UnionFields::new(
+                vec![0, 1],
+                vec![
+                    Field::new("f1", DataType::Int32, false),
+                    Field::new("f2", DataType::Utf8, true),
+                ],
+            ),
             UnionMode::Dense,
         ),
         Decimal128(38, 0),
diff --git a/arrow/tests/array_validation.rs b/arrow/tests/array_validation.rs
index 73e013ff1..ef0d40d64 100644
--- a/arrow/tests/array_validation.rs
+++ b/arrow/tests/array_validation.rs
@@ -22,7 +22,7 @@ use arrow::array::{
 use arrow_array::Decimal128Array;
 use arrow_buffer::{ArrowNativeType, Buffer};
 use arrow_data::ArrayData;
-use arrow_schema::{DataType, Field, UnionMode};
+use arrow_schema::{DataType, Field, UnionFields, UnionMode};
 use std::ptr::NonNull;
 use std::sync::Arc;
 
@@ -768,11 +768,13 @@ fn test_validate_union_different_types() {
 
     ArrayData::try_new(
         DataType::Union(
-            vec![
-                Field::new("field1", DataType::Int32, true),
-                Field::new("field2", DataType::Int64, true), // data is int32
-            ],
-            vec![0, 1],
+            UnionFields::new(
+                vec![0, 1],
+                vec![
+                    Field::new("field1", DataType::Int32, true),
+                    Field::new("field2", DataType::Int64, true), // data is 
int32
+                ],
+            ),
             UnionMode::Sparse,
         ),
         2,
@@ -799,11 +801,13 @@ fn test_validate_union_sparse_different_child_len() {
 
     ArrayData::try_new(
         DataType::Union(
-            vec![
-                Field::new("field1", DataType::Int32, true),
-                Field::new("field2", DataType::Int64, true),
-            ],
-            vec![0, 1],
+            UnionFields::new(
+                vec![0, 1],
+                vec![
+                    Field::new("field1", DataType::Int32, true),
+                    Field::new("field2", DataType::Int64, true),
+                ],
+            ),
             UnionMode::Sparse,
         ),
         2,
@@ -826,11 +830,13 @@ fn test_validate_union_dense_without_offsets() {
 
     ArrayData::try_new(
         DataType::Union(
-            vec![
-                Field::new("field1", DataType::Int32, true),
-                Field::new("field2", DataType::Int64, true),
-            ],
-            vec![0, 1],
+            UnionFields::new(
+                vec![0, 1],
+                vec![
+                    Field::new("field1", DataType::Int32, true),
+                    Field::new("field2", DataType::Int64, true),
+                ],
+            ),
             UnionMode::Dense,
         ),
         2,
@@ -854,11 +860,13 @@ fn test_validate_union_dense_with_bad_len() {
 
     ArrayData::try_new(
         DataType::Union(
-            vec![
-                Field::new("field1", DataType::Int32, true),
-                Field::new("field2", DataType::Int64, true),
-            ],
-            vec![0, 1],
+            UnionFields::new(
+                vec![0, 1],
+                vec![
+                    Field::new("field1", DataType::Int32, true),
+                    Field::new("field2", DataType::Int64, true),
+                ],
+            ),
             UnionMode::Dense,
         ),
         2,
diff --git a/parquet/src/arrow/arrow_writer/mod.rs 
b/parquet/src/arrow/arrow_writer/mod.rs
index 94c19cb2e..f594f2f79 100644
--- a/parquet/src/arrow/arrow_writer/mod.rs
+++ b/parquet/src/arrow/arrow_writer/mod.rs
@@ -360,7 +360,7 @@ fn write_leaves<W: Write>(
         ArrowDataType::Float16 => Err(ParquetError::ArrowError(
             "Float16 arrays not supported".to_string(),
         )),
-        ArrowDataType::FixedSizeList(_, _) | ArrowDataType::Union(_, _, _) | 
ArrowDataType::RunEndEncoded(_, _) => {
+        ArrowDataType::FixedSizeList(_, _) | ArrowDataType::Union(_, _) | 
ArrowDataType::RunEndEncoded(_, _) => {
             Err(ParquetError::NYI(
                 format!(
                     "Attempting to write an Arrow type {data_type:?} to 
parquet that is not yet implemented"
diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs
index 09109d290..b541a754b 100644
--- a/parquet/src/arrow/schema/mod.rs
+++ b/parquet/src/arrow/schema/mod.rs
@@ -501,7 +501,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result<Type> {
                 ))
             }
         }
-        DataType::Union(_, _, _) => unimplemented!("See ARROW-8817."),
+        DataType::Union(_, _) => unimplemented!("See ARROW-8817."),
         DataType::Dictionary(_, ref value) => {
             // Dictionary encoding not handled at the schema level
             let dict_field = Field::new(name, *value.clone(), 
field.is_nullable());

Reply via email to