This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 20ffaf8ef Add type ids in Union datatype (#1703)
20ffaf8ef is described below

commit 20ffaf8ef9be737f60b59d5a6ae258962bd191cb
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue May 17 12:52:27 2022 -0700

    Add type ids in Union datatype (#1703)
    
    * Store type ids in Union datatype
    
    * Add doc as suggested and put type ids in ipc
    
    * Add test
    
    * Fix equal_dense
    
    * Fix clippy
---
 arrow/src/array/array.rs          |  4 ++--
 arrow/src/array/array_union.rs    | 35 ++++++++++++++++++++++++-----------
 arrow/src/array/builder.rs        |  4 +++-
 arrow/src/array/data.rs           | 20 ++++++++++++--------
 arrow/src/array/equal/mod.rs      |  2 +-
 arrow/src/array/equal/union.rs    | 26 +++++++++++++++++++++-----
 arrow/src/array/equal/utils.rs    |  2 +-
 arrow/src/array/transform/mod.rs  |  6 +++---
 arrow/src/compute/kernels/cast.rs |  1 +
 arrow/src/datatypes/datatype.rs   | 25 ++++++++++---------------
 arrow/src/datatypes/field.rs      | 39 ++++++++++++++++++++++++---------------
 arrow/src/datatypes/mod.rs        | 18 ++++++------------
 arrow/src/ipc/convert.rs          | 39 +++++++++++++++++++++++++++++++++++----
 arrow/src/ipc/reader.rs           |  5 +++--
 arrow/src/ipc/writer.rs           |  7 ++++---
 arrow/src/util/display.rs         | 21 +++++++++++----------
 arrow/src/util/pretty.rs          |  6 +++++-
 integration-testing/src/lib.rs    | 35 +++++------------------------------
 parquet/src/arrow/arrow_writer.rs |  2 +-
 parquet/src/arrow/levels.rs       |  6 +++---
 parquet/src/arrow/schema.rs       |  2 +-
 21 files changed, 176 insertions(+), 129 deletions(-)

diff --git a/arrow/src/array/array.rs b/arrow/src/array/array.rs
index 421e60f04..ed99a6b9f 100644
--- a/arrow/src/array/array.rs
+++ b/arrow/src/array/array.rs
@@ -364,7 +364,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef {
         DataType::LargeList(_) => Arc::new(LargeListArray::from(data)) as 
ArrayRef,
         DataType::Struct(_) => Arc::new(StructArray::from(data)) as ArrayRef,
         DataType::Map(_, _) => Arc::new(MapArray::from(data)) as ArrayRef,
-        DataType::Union(_, _) => Arc::new(UnionArray::from(data)) as ArrayRef,
+        DataType::Union(_, _, _) => Arc::new(UnionArray::from(data)) as 
ArrayRef,
         DataType::FixedSizeList(_, _) => {
             Arc::new(FixedSizeListArray::from(data)) as ArrayRef
         }
@@ -535,7 +535,7 @@ pub fn new_null_array(data_type: &DataType, length: usize) 
-> ArrayRef {
         DataType::Map(field, _keys_sorted) => {
             new_null_list_array::<i32>(data_type, field.data_type(), length)
         }
-        DataType::Union(_, _) => {
+        DataType::Union(_, _, _) => {
             unimplemented!("Creating null Union array not yet supported")
         }
         DataType::Dictionary(key, value) => {
diff --git a/arrow/src/array/array_union.rs b/arrow/src/array/array_union.rs
index 5ebf3d2d3..5cfab0bbf 100644
--- a/arrow/src/array/array_union.rs
+++ b/arrow/src/array/array_union.rs
@@ -58,6 +58,7 @@ use std::any::Any;
 /// ];
 ///
 /// let array = UnionArray::try_new(
+///     &vec![0, 1],
 ///     type_id_buffer,
 ///     Some(value_offsets_buffer),
 ///     children,
@@ -90,6 +91,7 @@ use std::any::Any;
 /// ];
 ///
 /// let array = UnionArray::try_new(
+///     &vec![0, 1],
 ///     type_id_buffer,
 ///     None,
 ///     children,
@@ -135,6 +137,7 @@ impl UnionArray {
     /// `i8` and `i32` values respectively.  `Buffer` objects are untyped and 
no attempt is made
     /// to ensure that the data provided is valid.
     pub unsafe fn new_unchecked(
+        field_type_ids: &[i8],
         type_ids: Buffer,
         value_offsets: Option<Buffer>,
         child_arrays: Vec<(Field, ArrayRef)>,
@@ -149,10 +152,14 @@ impl UnionArray {
             UnionMode::Sparse
         };
 
-        let builder = ArrayData::builder(DataType::Union(field_types, mode))
-            .add_buffer(type_ids)
-            .child_data(field_values.into_iter().map(|a| 
a.data().clone()).collect())
-            .len(len);
+        let builder = ArrayData::builder(DataType::Union(
+            field_types,
+            Vec::from(field_type_ids),
+            mode,
+        ))
+        .add_buffer(type_ids)
+        .child_data(field_values.into_iter().map(|a| 
a.data().clone()).collect())
+        .len(len);
 
         let data = match value_offsets {
             Some(b) => builder.add_buffer(b).build_unchecked(),
@@ -163,6 +170,7 @@ impl UnionArray {
 
     /// Attempts to create a new `UnionArray`, validating the inputs provided.
     pub fn try_new(
+        field_type_ids: &[i8],
         type_ids: Buffer,
         value_offsets: Option<Buffer>,
         child_arrays: Vec<(Field, ArrayRef)>,
@@ -209,8 +217,9 @@ impl UnionArray {
 
         // Unsafe Justification: arguments were validated above (and
         // re-revalidated as part of data().validate() below)
-        let new_self =
-            unsafe { Self::new_unchecked(type_ids, value_offsets, 
child_arrays) };
+        let new_self = unsafe {
+            Self::new_unchecked(field_type_ids, type_ids, value_offsets, 
child_arrays)
+        };
         new_self.data().validate()?;
 
         Ok(new_self)
@@ -269,7 +278,7 @@ impl UnionArray {
     /// Returns the names of the types in the union.
     pub fn type_names(&self) -> Vec<&str> {
         match self.data.data_type() {
-            DataType::Union(fields, _) => fields
+            DataType::Union(fields, _, _) => fields
                 .iter()
                 .map(|f| f.name().as_str())
                 .collect::<Vec<&str>>(),
@@ -280,7 +289,7 @@ impl UnionArray {
     /// Returns whether the `UnionArray` is dense (or sparse if `false`).
     fn is_dense(&self) -> bool {
         match self.data.data_type() {
-            DataType::Union(_, mode) => mode == &UnionMode::Dense,
+            DataType::Union(_, _, mode) => mode == &UnionMode::Dense,
             _ => unreachable!("Union array's data type is not a union!"),
         }
     }
@@ -626,9 +635,13 @@ mod tests {
                 Arc::new(float_array),
             ),
         ];
-        let array =
-            UnionArray::try_new(type_id_buffer, Some(value_offsets_buffer), 
children)
-                .unwrap();
+        let array = UnionArray::try_new(
+            &[0, 1, 2],
+            type_id_buffer,
+            Some(value_offsets_buffer),
+            children,
+        )
+        .unwrap();
 
         // Check type ids
         assert_eq!(Buffer::from_slice_ref(&type_ids), 
array.data().buffers()[0]);
diff --git a/arrow/src/array/builder.rs b/arrow/src/array/builder.rs
index da6d2f1c3..091a51b15 100644
--- a/arrow/src/array/builder.rs
+++ b/arrow/src/array/builder.rs
@@ -2168,7 +2168,9 @@ impl UnionBuilder {
         });
         let children: Vec<_> = children.into_iter().map(|(_, b)| b).collect();
 
-        UnionArray::try_new(type_id_buffer, value_offsets_buffer, children)
+        let type_ids: Vec<i8> = (0_i8..children.len() as i8).collect();
+
+        UnionArray::try_new(&type_ids, type_id_buffer, value_offsets_buffer, 
children)
     }
 }
 
diff --git a/arrow/src/array/data.rs b/arrow/src/array/data.rs
index f7eacb0a4..15cd9cc5f 100644
--- a/arrow/src/array/data.rs
+++ b/arrow/src/array/data.rs
@@ -194,7 +194,7 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: 
usize) -> [MutableBuff
             MutableBuffer::new(capacity * mem::size_of::<u8>()),
             empty_buffer,
         ],
-        DataType::Union(_, mode) => {
+        DataType::Union(_, _, mode) => {
             let type_ids = MutableBuffer::new(capacity * mem::size_of::<i8>());
             match mode {
                 UnionMode::Sparse => [type_ids, empty_buffer],
@@ -220,7 +220,7 @@ pub(crate) fn into_buffers(
         | DataType::Binary
         | DataType::LargeUtf8
         | DataType::LargeBinary => vec![buffer1.into(), buffer2.into()],
-        DataType::Union(_, mode) => {
+        DataType::Union(_, _, mode) => {
             match mode {
                 // Based on Union's DataTypeLayout
                 UnionMode::Sparse => vec![buffer1.into()],
@@ -581,7 +581,7 @@ impl ArrayData {
             DataType::Map(field, _) => {
                 vec![Self::new_empty(field.data_type())]
             }
-            DataType::Union(fields, _) => fields
+            DataType::Union(fields, _, _) => fields
                 .iter()
                 .map(|field| Self::new_empty(field.data_type()))
                 .collect(),
@@ -856,7 +856,7 @@ impl ArrayData {
                 }
                 Ok(())
             }
-            DataType::Union(fields, mode) => {
+            DataType::Union(fields, _, mode) => {
                 self.validate_num_child_data(fields.len())?;
 
                 for (i, field) in fields.iter().enumerate() {
@@ -1004,7 +1004,7 @@ impl ArrayData {
                 let child = &self.child_data[0];
                 self.validate_offsets_full::<i64>(child.len + child.offset)
             }
-            DataType::Union(_, _) => {
+            DataType::Union(_, _, _) => {
                 // Validate Union Array as part of implementing new Union 
semantics
                 // See comments in `ArrayData::validate()`
                 // https://github.com/apache/arrow-rs/issues/85
@@ -1269,7 +1269,7 @@ fn layout(data_type: &DataType) -> DataTypeLayout {
         DataType::FixedSizeList(_, _) => DataTypeLayout::new_empty(), // all 
in child data
         DataType::LargeList(_) => 
DataTypeLayout::new_fixed_width(size_of::<i32>()),
         DataType::Struct(_) => DataTypeLayout::new_empty(), // all in child 
data,
-        DataType::Union(_, mode) => {
+        DataType::Union(_, _, mode) => {
             let type_ids = BufferSpec::FixedWidth {
                 byte_width: size_of::<i8>(),
             };
@@ -2505,6 +2505,7 @@ mod tests {
                     Field::new("field1", DataType::Int32, true),
                     Field::new("field2", DataType::Int64, true), // data is 
int32
                 ],
+                vec![0, 1],
                 UnionMode::Sparse,
             ),
             2,
@@ -2536,6 +2537,7 @@ mod tests {
                     Field::new("field1", DataType::Int32, true),
                     Field::new("field2", DataType::Int64, true),
                 ],
+                vec![0, 1],
                 UnionMode::Sparse,
             ),
             2,
@@ -2563,6 +2565,7 @@ mod tests {
                     Field::new("field1", DataType::Int32, true),
                     Field::new("field2", DataType::Int64, true),
                 ],
+                vec![0, 1],
                 UnionMode::Dense,
             ),
             2,
@@ -2593,6 +2596,7 @@ mod tests {
                     Field::new("field1", DataType::Int32, true),
                     Field::new("field2", DataType::Int64, true),
                 ],
+                vec![0, 1],
                 UnionMode::Dense,
             ),
             2,
@@ -2705,8 +2709,8 @@ mod tests {
     #[test]
     fn test_into_buffers() {
         let data_types = vec![
-            DataType::Union(vec![], UnionMode::Dense),
-            DataType::Union(vec![], UnionMode::Sparse),
+            DataType::Union(vec![], vec![], UnionMode::Dense),
+            DataType::Union(vec![], vec![], UnionMode::Sparse),
         ];
 
         for data_type in data_types {
diff --git a/arrow/src/array/equal/mod.rs b/arrow/src/array/equal/mod.rs
index 1a6b9f331..c45b30ccc 100644
--- a/arrow/src/array/equal/mod.rs
+++ b/arrow/src/array/equal/mod.rs
@@ -193,7 +193,7 @@ fn equal_values(
             fixed_list_equal(lhs, rhs, lhs_start, rhs_start, len)
         }
         DataType::Struct(_) => struct_equal(lhs, rhs, lhs_start, rhs_start, 
len),
-        DataType::Union(_, _) => union_equal(lhs, rhs, lhs_start, rhs_start, 
len),
+        DataType::Union(_, _, _) => union_equal(lhs, rhs, lhs_start, 
rhs_start, len),
         DataType::Dictionary(data_type, _) => match data_type.as_ref() {
             DataType::Int8 => dictionary_equal::<i8>(lhs, rhs, lhs_start, 
rhs_start, len),
             DataType::Int16 => {
diff --git a/arrow/src/array/equal/union.rs b/arrow/src/array/equal/union.rs
index 021b0a3b7..e8b9d27b6 100644
--- a/arrow/src/array/equal/union.rs
+++ b/arrow/src/array/equal/union.rs
@@ -19,6 +19,7 @@ use crate::{array::ArrayData, datatypes::DataType, 
datatypes::UnionMode};
 
 use super::equal_range;
 
+#[allow(clippy::too_many_arguments)]
 fn equal_dense(
     lhs: &ArrayData,
     rhs: &ArrayData,
@@ -26,6 +27,8 @@ fn equal_dense(
     rhs_type_ids: &[i8],
     lhs_offsets: &[i32],
     rhs_offsets: &[i32],
+    lhs_field_type_ids: &[i8],
+    rhs_field_type_ids: &[i8],
 ) -> bool {
     let offsets = lhs_offsets.iter().zip(rhs_offsets.iter());
 
@@ -34,8 +37,16 @@ fn equal_dense(
         .zip(rhs_type_ids.iter())
         .zip(offsets)
         .all(|((l_type_id, r_type_id), (l_offset, r_offset))| {
-            let lhs_values = &lhs.child_data()[*l_type_id as usize];
-            let rhs_values = &rhs.child_data()[*r_type_id as usize];
+            let lhs_child_index = lhs_field_type_ids
+                .iter()
+                .position(|r| r == l_type_id)
+                .unwrap();
+            let rhs_child_index = rhs_field_type_ids
+                .iter()
+                .position(|r| r == r_type_id)
+                .unwrap();
+            let lhs_values = &lhs.child_data()[lhs_child_index];
+            let rhs_values = &rhs.child_data()[rhs_child_index];
 
             equal_range(
                 lhs_values,
@@ -76,7 +87,10 @@ pub(super) fn union_equal(
     let rhs_type_id_range = &rhs_type_ids[rhs_start..rhs_start + len];
 
     match (lhs.data_type(), rhs.data_type()) {
-        (DataType::Union(_, UnionMode::Dense), DataType::Union(_, 
UnionMode::Dense)) => {
+        (
+            DataType::Union(_, lhs_type_ids, UnionMode::Dense),
+            DataType::Union(_, rhs_type_ids, UnionMode::Dense),
+        ) => {
             let lhs_offsets = lhs.buffer::<i32>(1);
             let rhs_offsets = rhs.buffer::<i32>(1);
 
@@ -91,11 +105,13 @@ pub(super) fn union_equal(
                     rhs_type_id_range,
                     lhs_offsets_range,
                     rhs_offsets_range,
+                    lhs_type_ids,
+                    rhs_type_ids,
                 )
         }
         (
-            DataType::Union(_, UnionMode::Sparse),
-            DataType::Union(_, UnionMode::Sparse),
+            DataType::Union(_, _, UnionMode::Sparse),
+            DataType::Union(_, _, UnionMode::Sparse),
         ) => {
             lhs_type_id_range == rhs_type_id_range
                 && equal_sparse(lhs, rhs, lhs_start, rhs_start, len)
diff --git a/arrow/src/array/equal/utils.rs b/arrow/src/array/equal/utils.rs
index 8875239ca..fed3933a0 100644
--- a/arrow/src/array/equal/utils.rs
+++ b/arrow/src/array/equal/utils.rs
@@ -68,7 +68,7 @@ pub(super) fn equal_nulls(
 #[inline]
 pub(super) fn base_equal(lhs: &ArrayData, rhs: &ArrayData) -> bool {
     let equal_type = match (lhs.data_type(), rhs.data_type()) {
-        (DataType::Union(l_fields, l_mode), DataType::Union(r_fields, r_mode)) 
=> {
+        (DataType::Union(l_fields, _, l_mode), DataType::Union(r_fields, _, 
r_mode)) => {
             l_fields == r_fields && l_mode == r_mode
         }
         (DataType::Map(l_field, l_sorted), DataType::Map(r_field, r_sorted)) 
=> {
diff --git a/arrow/src/array/transform/mod.rs b/arrow/src/array/transform/mod.rs
index aa7d417a1..586a4fec2 100644
--- a/arrow/src/array/transform/mod.rs
+++ b/arrow/src/array/transform/mod.rs
@@ -274,7 +274,7 @@ fn build_extend(array: &ArrayData) -> Extend {
         DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array),
         DataType::Float16 => primitive::build_extend::<f16>(array),
         DataType::FixedSizeList(_, _) => fixed_size_list::build_extend(array),
-        DataType::Union(_, mode) => match mode {
+        DataType::Union(_, _, mode) => match mode {
             UnionMode::Sparse => union::build_extend_sparse(array),
             UnionMode::Dense => union::build_extend_dense(array),
         },
@@ -325,7 +325,7 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls {
         DataType::FixedSizeBinary(_) => fixed_binary::extend_nulls,
         DataType::Float16 => primitive::extend_nulls::<f16>,
         DataType::FixedSizeList(_, _) => fixed_size_list::extend_nulls,
-        DataType::Union(_, mode) => match mode {
+        DataType::Union(_, _, mode) => match mode {
             UnionMode::Sparse => union::extend_nulls_sparse,
             UnionMode::Dense => union::extend_nulls_dense,
         },
@@ -524,7 +524,7 @@ impl<'a> MutableArrayData<'a> {
                     .collect::<Vec<_>>();
                 vec![MutableArrayData::new(childs, use_nulls, array_capacity)]
             }
-            DataType::Union(fields, _) => (0..fields.len())
+            DataType::Union(fields, _, _) => (0..fields.len())
                 .map(|i| {
                     let child_arrays = arrays
                         .iter()
diff --git a/arrow/src/compute/kernels/cast.rs 
b/arrow/src/compute/kernels/cast.rs
index 2c0ebb1e2..c989cd2fe 100644
--- a/arrow/src/compute/kernels/cast.rs
+++ b/arrow/src/compute/kernels/cast.rs
@@ -4776,6 +4776,7 @@ mod tests {
                     Field::new("f1", DataType::Int32, false),
                     Field::new("f2", DataType::Utf8, true),
                 ],
+                vec![0, 1],
                 UnionMode::Dense,
             ),
             Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
diff --git a/arrow/src/datatypes/datatype.rs b/arrow/src/datatypes/datatype.rs
index c5cc8f017..a740e8ecc 100644
--- a/arrow/src/datatypes/datatype.rs
+++ b/arrow/src/datatypes/datatype.rs
@@ -114,8 +114,12 @@ pub enum DataType {
     LargeList(Box<Field>),
     /// A nested datatype that contains a number of sub-fields.
     Struct(Vec<Field>),
-    /// A nested datatype that can represent slots of differing types.
-    Union(Vec<Field>, UnionMode),
+    /// A nested datatype that can represent slots of differing types. 
Components:
+    ///
+    /// 1. [`Field`] for each possible child type the Union can hold
+    /// 2. The corresponding `type_id` used to identify each Field
+    /// 3. The type of union (Sparse or Dense)
+    Union(Vec<Field>, Vec<i8>, UnionMode),
     /// A dictionary encoded array (`key_type`, `value_type`), where
     /// each array element is an index of `key_type` into an
     /// associated dictionary of `value_type`.
@@ -516,24 +520,15 @@ impl DataType {
                                 .as_array()
                                 .unwrap()
                                 .iter()
-                                .map(|t| t.as_i64().unwrap())
+                                .map(|t| t.as_i64().unwrap() as i8)
                                 .collect::<Vec<_>>();
 
                             let default_fields = type_ids
                                 .iter()
-                                .map(|t| {
-                                    Field::new("", DataType::Boolean, 
true).with_metadata(
-                                        Some(
-                                            [("type_id".to_string(), 
t.to_string())]
-                                                .iter()
-                                                .cloned()
-                                                .collect(),
-                                        ),
-                                    )
-                                })
+                                .map(|_| default_field.clone())
                                 .collect::<Vec<_>>();
 
-                            Ok(DataType::Union(default_fields, union_mode))
+                            Ok(DataType::Union(default_fields, type_ids, 
union_mode))
                         } else {
                             Err(ArrowError::ParseError(
                                 "Expecting a typeIds for union ".to_string(),
@@ -581,7 +576,7 @@ impl DataType {
                 json!({"name": "fixedsizebinary", "byteWidth": byte_width})
             }
             DataType::Struct(_) => json!({"name": "struct"}),
-            DataType::Union(_, _) => json!({"name": "union"}),
+            DataType::Union(_, _, _) => json!({"name": "union"}),
             DataType::List(_) => json!({ "name": "list"}),
             DataType::LargeList(_) => json!({ "name": "largelist"}),
             DataType::FixedSizeList(_, length) => {
diff --git a/arrow/src/datatypes/field.rs b/arrow/src/datatypes/field.rs
index 6471f1ed7..5025d32a4 100644
--- a/arrow/src/datatypes/field.rs
+++ b/arrow/src/datatypes/field.rs
@@ -168,7 +168,7 @@ impl Field {
         let mut collected_fields = vec![];
 
         match dt {
-            DataType::Struct(fields) | DataType::Union(fields, _) => {
+            DataType::Struct(fields) | DataType::Union(fields, _, _) => {
                 collected_fields.extend(fields.iter().flat_map(|f| f.fields()))
             }
             DataType::List(field)
@@ -390,18 +390,11 @@ impl Field {
                             }
                         }
                     }
-                    DataType::Union(fields, mode) => match map.get("children") 
{
+                    DataType::Union(_, type_ids, mode) => match 
map.get("children") {
                         Some(Value::Array(values)) => {
-                            let mut union_fields: Vec<Field> =
+                            let union_fields: Vec<Field> =
                                 
values.iter().map(Field::from).collect::<Result<_>>()?;
-                            
fields.iter().zip(union_fields.iter_mut()).for_each(
-                                |(f, union_field)| {
-                                    union_field.set_metadata(Some(
-                                        f.metadata().unwrap().clone(),
-                                    ));
-                                },
-                            );
-                            DataType::Union(union_fields, mode)
+                            DataType::Union(union_fields, type_ids, mode)
                         }
                         Some(_) => {
                             return Err(ArrowError::ParseError(
@@ -568,18 +561,34 @@ impl Field {
                     ));
                 }
             },
-            DataType::Union(nested_fields, _) => match &from.data_type {
-                DataType::Union(from_nested_fields, _) => {
-                    for from_field in from_nested_fields {
+            DataType::Union(nested_fields, type_ids, _) => match 
&from.data_type {
+                DataType::Union(from_nested_fields, from_type_ids, _) => {
+                    for (idx, from_field) in 
from_nested_fields.iter().enumerate() {
                         let mut is_new_field = true;
-                        for self_field in nested_fields.iter_mut() {
+                        let field_type_id = from_type_ids.get(idx).unwrap();
+
+                        for (self_idx, self_field) in 
nested_fields.iter_mut().enumerate()
+                        {
                             if from_field == self_field {
+                                let self_type_id = 
type_ids.get(self_idx).unwrap();
+
+                                // If the nested fields in two unions are the 
same, they must have the same
+                                // type id.
+                                if self_type_id != field_type_id {
+                                    return Err(ArrowError::SchemaError(
+                                        "Fail to merge schema Field due to 
conflicting type ids in union datatype"
+                                            .to_string(),
+                                    ));
+                                }
+
                                 is_new_field = false;
                                 break;
                             }
                         }
+
                         if is_new_field {
                             nested_fields.push(from_field.clone());
+                            type_ids.push(*field_type_id);
                         }
                     }
                 }
diff --git a/arrow/src/datatypes/mod.rs b/arrow/src/datatypes/mod.rs
index c3015972a..47074633d 100644
--- a/arrow/src/datatypes/mod.rs
+++ b/arrow/src/datatypes/mod.rs
@@ -435,19 +435,10 @@ mod tests {
             "my_union",
             DataType::Union(
                 vec![
-                    Field::new("f1", DataType::Int32, true).with_metadata(Some(
-                        [("type_id".to_string(), "5".to_string())]
-                            .iter()
-                            .cloned()
-                            .collect(),
-                    )),
-                    Field::new("f2", DataType::Utf8, true).with_metadata(Some(
-                        [("type_id".to_string(), "7".to_string())]
-                            .iter()
-                            .cloned()
-                            .collect(),
-                    )),
+                    Field::new("f1", DataType::Int32, true),
+                    Field::new("f2", DataType::Utf8, true),
                 ],
+                vec![5, 7],
                 UnionMode::Sparse,
             ),
             false,
@@ -1444,6 +1435,7 @@ mod tests {
                             Field::new("c11", DataType::Utf8, true),
                             Field::new("c12", DataType::Utf8, true),
                         ],
+                        vec![0, 1],
                         UnionMode::Dense
                     ),
                     false
@@ -1455,6 +1447,7 @@ mod tests {
                             Field::new("c12", DataType::Utf8, true),
                             Field::new("c13", 
DataType::Time64(TimeUnit::Second), true),
                         ],
+                        vec![1, 2],
                         UnionMode::Dense
                     ),
                     false
@@ -1468,6 +1461,7 @@ mod tests {
                         Field::new("c12", DataType::Utf8, true),
                         Field::new("c13", DataType::Time64(TimeUnit::Second), 
true),
                     ],
+                    vec![0, 1, 2],
                     UnionMode::Dense
                 ),
                 false
diff --git a/arrow/src/ipc/convert.rs b/arrow/src/ipc/convert.rs
index 97ed9ed78..c81ea8278 100644
--- a/arrow/src/ipc/convert.rs
+++ b/arrow/src/ipc/convert.rs
@@ -338,7 +338,12 @@ pub(crate) fn get_data_type(field: ipc::Field, 
may_be_dictionary: bool) -> DataT
                 }
             };
 
-            DataType::Union(fields, union_mode)
+            let type_ids: Vec<i8> = match union.typeIds() {
+                None => (0_i8..fields.len() as i8).collect(),
+                Some(ids) => ids.iter().map(|i| i as i8).collect(),
+            };
+
+            DataType::Union(fields, type_ids, union_mode)
         }
         t => unimplemented!("Type {:?} not supported", t),
     }
@@ -666,7 +671,7 @@ pub(crate) fn get_fb_field_type<'a>(
                 children: Some(fbb.create_vector(&empty_fields[..])),
             }
         }
-        Union(fields, mode) => {
+        Union(fields, type_ids, mode) => {
             let mut children = vec![];
             for field in fields {
                 children.push(build_field(fbb, field));
@@ -677,8 +682,11 @@ pub(crate) fn get_fb_field_type<'a>(
                 UnionMode::Dense => ipc::UnionMode::Dense,
             };
 
+            let fbb_type_ids = fbb
+                .create_vector(&type_ids.iter().map(|t| *t as 
i32).collect::<Vec<_>>());
             let mut builder = ipc::UnionBuilder::new(fbb);
             builder.add_mode(union_mode);
+            builder.add_typeIds(fbb_type_ids);
 
             FBFieldType {
                 type_type: ipc::Type::Union,
@@ -874,6 +882,7 @@ mod tests {
                                                 
DataType::List(Box::new(Field::new(
                                                     "union",
                                                     DataType::Union(
+                                                        vec![],
                                                         vec![],
                                                         UnionMode::Sparse,
                                                     ),
@@ -882,6 +891,7 @@ mod tests {
                                                 false,
                                             ),
                                         ],
+                                        vec![0, 1],
                                         UnionMode::Dense,
                                     ),
                                     false,
@@ -889,13 +899,34 @@ mod tests {
                                 false,
                             ),
                         ],
+                        vec![0, 1],
                         UnionMode::Sparse,
                     ),
                     false,
                 ),
                 Field::new("struct<>", DataType::Struct(vec![]), true),
-                Field::new("union<>", DataType::Union(vec![], 
UnionMode::Dense), true),
-                Field::new("union<>", DataType::Union(vec![], 
UnionMode::Sparse), true),
+                Field::new(
+                    "union<>",
+                    DataType::Union(vec![], vec![], UnionMode::Dense),
+                    true,
+                ),
+                Field::new(
+                    "union<>",
+                    DataType::Union(vec![], vec![], UnionMode::Sparse),
+                    true,
+                ),
+                Field::new(
+                    "union<int32, utf8>",
+                    DataType::Union(
+                        vec![
+                            Field::new("int32", DataType::Int32, true),
+                            Field::new("utf8", DataType::Utf8, true),
+                        ],
+                        vec![2, 3], // non-default type ids
+                        UnionMode::Dense,
+                    ),
+                    true,
+                ),
                 Field::new_dict(
                     "dictionary<int32, utf8>",
                     DataType::Dictionary(
diff --git a/arrow/src/ipc/reader.rs b/arrow/src/ipc/reader.rs
index 4a73269e5..662b384c4 100644
--- a/arrow/src/ipc/reader.rs
+++ b/arrow/src/ipc/reader.rs
@@ -195,7 +195,7 @@ fn create_array(
                 value_array.clone(),
             )
         }
-        Union(fields, mode) => {
+        Union(fields, field_type_ids, mode) => {
             let union_node = nodes[node_index];
             node_index += 1;
 
@@ -234,7 +234,8 @@ fn create_array(
                 children.push((field.clone(), triple.0));
             }
 
-            let array = UnionArray::try_new(type_ids, value_offsets, 
children)?;
+            let array =
+                UnionArray::try_new(field_type_ids, type_ids, value_offsets, 
children)?;
             Arc::new(array)
         }
         Null => {
diff --git a/arrow/src/ipc/writer.rs b/arrow/src/ipc/writer.rs
index c03d5e449..f61d4ce4c 100644
--- a/arrow/src/ipc/writer.rs
+++ b/arrow/src/ipc/writer.rs
@@ -221,7 +221,7 @@ impl IpcDataGenerator {
                     write_options,
                 )?;
             }
-            DataType::Union(fields, _) => {
+            DataType::Union(fields, _, _) => {
                 let union = as_union_array(column);
                 for (field, ref column) in fields
                     .iter()
@@ -865,7 +865,7 @@ fn write_array_data(
     // UnionArray does not have a validity buffer
     if !matches!(
         array_data.data_type(),
-        DataType::Null | DataType::Union(_, _)
+        DataType::Null | DataType::Union(_, _, _)
     ) {
         // write null buffer if exists
         let null_buffer = match array_data.null_buffer() {
@@ -1328,7 +1328,8 @@ mod tests {
         let offsets = Buffer::from_slice_ref(&[0_i32, 1, 2]);
 
         let union =
-            UnionArray::try_new(types, Some(offsets), vec![(dctfield, array)]).unwrap();
+            UnionArray::try_new(&[0], types, Some(offsets), vec![(dctfield, array)])
+                .unwrap();
 
         let schema = Arc::new(Schema::new(vec![Field::new(
             "union",
diff --git a/arrow/src/util/display.rs b/arrow/src/util/display.rs
index b0493b6ce..6da73e4cf 100644
--- a/arrow/src/util/display.rs
+++ b/arrow/src/util/display.rs
@@ -396,7 +396,9 @@ pub fn array_value_to_string(column: &array::ArrayRef, row: usize) -> Result<Str
 
             Ok(s)
         }
-        DataType::Union(field_vec, mode) => union_to_string(column, row, field_vec, mode),
+        DataType::Union(field_vec, type_ids, mode) => {
+            union_to_string(column, row, field_vec, type_ids, mode)
+        }
         _ => Err(ArrowError::InvalidArgumentError(format!(
             "Pretty printing not implemented for {:?} type",
             column.data_type()
@@ -409,6 +411,7 @@ fn union_to_string(
     column: &array::ArrayRef,
     row: usize,
     fields: &[Field],
+    type_ids: &[i8],
     mode: &UnionMode,
 ) -> Result<String> {
     let list = column
@@ -420,15 +423,13 @@ fn union_to_string(
             )
         })?;
     let type_id = list.type_id(row);
-    let name = fields
-        .get(type_id as usize)
-        .ok_or_else(|| {
-            ArrowError::InvalidArgumentError(format!(
-                "Repl error: could not get field name for type id: {} in union array.",
-                type_id,
-            ))
-        })?
-        .name();
+    let field_idx = type_ids.iter().position(|t| t == &type_id).ok_or_else(|| {
+        ArrowError::InvalidArgumentError(format!(
+            "Repl error: could not get field name for type id: {} in union array.",
+            type_id,
+        ))
+    })?;
+    let name = fields.get(field_idx).unwrap().name();
 
     let value = array_value_to_string(
         &list.child(type_id),
diff --git a/arrow/src/util/pretty.rs b/arrow/src/util/pretty.rs
index 3fa2729ba..124de6127 100644
--- a/arrow/src/util/pretty.rs
+++ b/arrow/src/util/pretty.rs
@@ -664,6 +664,7 @@ mod tests {
                     Field::new("a", DataType::Int32, false),
                     Field::new("b", DataType::Float64, false),
                 ],
+                vec![0, 1],
                 UnionMode::Dense,
             ),
             false,
@@ -704,6 +705,7 @@ mod tests {
                     Field::new("a", DataType::Int32, false),
                     Field::new("b", DataType::Float64, false),
                 ],
+                vec![0, 1],
                 UnionMode::Sparse,
             ),
             false,
@@ -746,6 +748,7 @@ mod tests {
                     Field::new("b", DataType::Int32, false),
                     Field::new("c", DataType::Float64, false),
                 ],
+                vec![0, 1],
                 UnionMode::Dense,
             ),
             false,
@@ -760,12 +763,13 @@ mod tests {
             (inner_field.clone(), Arc::new(inner)),
         ];
 
-        let outer = UnionArray::try_new(type_ids, None, children).unwrap();
+        let outer = UnionArray::try_new(&[0, 1], type_ids, None, children).unwrap();
 
         let schema = Schema::new(vec![Field::new(
             "Teamsters",
             DataType::Union(
                 vec![Field::new("a", DataType::Int32, true), inner_field],
+                vec![0, 1],
                 UnionMode::Sparse,
             ),
             false,
diff --git a/integration-testing/src/lib.rs b/integration-testing/src/lib.rs
index c70459938..c57ef32bc 100644
--- a/integration-testing/src/lib.rs
+++ b/integration-testing/src/lib.rs
@@ -632,39 +632,13 @@ fn array_from_json(
             let array = MapArray::from(array_data);
             Ok(Arc::new(array))
         }
-        DataType::Union(fields, _) => {
-            let field_type_ids = fields
-                .iter()
-                .enumerate()
-                .into_iter()
-                .map(|(idx, f)| {
-                    (
-                        f.metadata()
-                            .and_then(|m| m.get("type_id"))
-                            .unwrap()
-                            .parse::<i8>()
-                            .unwrap(),
-                        idx,
-                    )
-                })
-                .collect::<HashMap<_, _>>();
-
+        DataType::Union(fields, field_type_ids, _) => {
             let type_ids = if let Some(type_id) = json_col.type_id {
                 type_id
-                    .iter()
-                    .map(|t| {
-                        if field_type_ids.contains_key(t) {
-                            Ok(*(field_type_ids.get(t).unwrap()) as i8)
-                        } else {
-                            Err(ArrowError::JsonError(format!(
-                                "Unable to find type id {:?}",
-                                t
-                            )))
-                        }
-                    })
-                    .collect::<Result<_>>()?
             } else {
-                vec![]
+                return Err(ArrowError::JsonError(
+                    "Cannot find expected type_id in json column".to_string(),
+                ));
             };
 
             let offset: Option<Buffer> = json_col.offset.map(|offsets| {
@@ -680,6 +654,7 @@ fn array_from_json(
             }
 
             let array = UnionArray::try_new(
+                field_type_ids,
                 Buffer::from(&type_ids.to_byte_slice()),
                 offset,
                 children,
diff --git a/parquet/src/arrow/arrow_writer.rs b/parquet/src/arrow/arrow_writer.rs
index 7ddd64432..1918c9675 100644
--- a/parquet/src/arrow/arrow_writer.rs
+++ b/parquet/src/arrow/arrow_writer.rs
@@ -324,7 +324,7 @@ fn write_leaves(
         ArrowDataType::Float16 => Err(ParquetError::ArrowError(
             "Float16 arrays not supported".to_string(),
         )),
-        ArrowDataType::FixedSizeList(_, _) | ArrowDataType::Union(_, _) => {
+        ArrowDataType::FixedSizeList(_, _) | ArrowDataType::Union(_, _, _) => {
             Err(ParquetError::NYI(
                 format!(
                     "Attempting to write an Arrow type {:?} to parquet that is not yet implemented",
diff --git a/parquet/src/arrow/levels.rs b/parquet/src/arrow/levels.rs
index a1979e591..be9a5e993 100644
--- a/parquet/src/arrow/levels.rs
+++ b/parquet/src/arrow/levels.rs
@@ -240,7 +240,7 @@ impl LevelInfo {
                         list_level.calculate_array_levels(&child_array, list_field)
                     }
                     DataType::FixedSizeList(_, _) => unimplemented!(),
-                    DataType::Union(_, _) => unimplemented!(),
+                    DataType::Union(_, _, _) => unimplemented!(),
                 }
             }
             DataType::Map(map_field, _) => {
@@ -310,7 +310,7 @@ impl LevelInfo {
                     });
                 struct_levels
             }
-            DataType::Union(_, _) => unimplemented!(),
+            DataType::Union(_, _, _) => unimplemented!(),
             DataType::Dictionary(_, _) => {
                 // Need to check for these cases not implemented in C++:
                 // - "Writing DictionaryArray with nested dictionary type not yet supported"
@@ -749,7 +749,7 @@ impl LevelInfo {
                     array_mask,
                 )
             }
-            DataType::FixedSizeList(_, _) | DataType::Union(_, _) => {
+            DataType::FixedSizeList(_, _) | DataType::Union(_, _, _) => {
                 unimplemented!("Getting offsets not yet implemented")
             }
         }
diff --git a/parquet/src/arrow/schema.rs b/parquet/src/arrow/schema.rs
index 71184e0b6..07c50d11c 100644
--- a/parquet/src/arrow/schema.rs
+++ b/parquet/src/arrow/schema.rs
@@ -520,7 +520,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result<Type> {
                 ))
             }
         }
-        DataType::Union(_, _) => unimplemented!("See ARROW-8817."),
+        DataType::Union(_, _, _) => unimplemented!("See ARROW-8817."),
         DataType::Dictionary(_, ref value) => {
             // Dictionary encoding not handled at the schema level
             let dict_field = Field::new(name, *value.clone(), field.is_nullable());

Reply via email to