This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 20ffaf8ef Add type ids in Union datatype (#1703)
20ffaf8ef is described below
commit 20ffaf8ef9be737f60b59d5a6ae258962bd191cb
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue May 17 12:52:27 2022 -0700
Add type ids in Union datatype (#1703)
* Store type ids in Union datatype
* Add doc as suggested and put type ids in ipc
* Add test
* Fix equal_dense
* Fix clippy
---
arrow/src/array/array.rs | 4 ++--
arrow/src/array/array_union.rs | 35 ++++++++++++++++++++++++-----------
arrow/src/array/builder.rs | 4 +++-
arrow/src/array/data.rs | 20 ++++++++++++--------
arrow/src/array/equal/mod.rs | 2 +-
arrow/src/array/equal/union.rs | 26 +++++++++++++++++++++-----
arrow/src/array/equal/utils.rs | 2 +-
arrow/src/array/transform/mod.rs | 6 +++---
arrow/src/compute/kernels/cast.rs | 1 +
arrow/src/datatypes/datatype.rs | 25 ++++++++++---------------
arrow/src/datatypes/field.rs | 39 ++++++++++++++++++++++++---------------
arrow/src/datatypes/mod.rs | 18 ++++++------------
arrow/src/ipc/convert.rs | 39 +++++++++++++++++++++++++++++++++++----
arrow/src/ipc/reader.rs | 5 +++--
arrow/src/ipc/writer.rs | 7 ++++---
arrow/src/util/display.rs | 21 +++++++++++----------
arrow/src/util/pretty.rs | 6 +++++-
integration-testing/src/lib.rs | 35 +++++------------------------------
parquet/src/arrow/arrow_writer.rs | 2 +-
parquet/src/arrow/levels.rs | 6 +++---
parquet/src/arrow/schema.rs | 2 +-
21 files changed, 176 insertions(+), 129 deletions(-)
diff --git a/arrow/src/array/array.rs b/arrow/src/array/array.rs
index 421e60f04..ed99a6b9f 100644
--- a/arrow/src/array/array.rs
+++ b/arrow/src/array/array.rs
@@ -364,7 +364,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef {
DataType::LargeList(_) => Arc::new(LargeListArray::from(data)) as
ArrayRef,
DataType::Struct(_) => Arc::new(StructArray::from(data)) as ArrayRef,
DataType::Map(_, _) => Arc::new(MapArray::from(data)) as ArrayRef,
- DataType::Union(_, _) => Arc::new(UnionArray::from(data)) as ArrayRef,
+ DataType::Union(_, _, _) => Arc::new(UnionArray::from(data)) as
ArrayRef,
DataType::FixedSizeList(_, _) => {
Arc::new(FixedSizeListArray::from(data)) as ArrayRef
}
@@ -535,7 +535,7 @@ pub fn new_null_array(data_type: &DataType, length: usize)
-> ArrayRef {
DataType::Map(field, _keys_sorted) => {
new_null_list_array::<i32>(data_type, field.data_type(), length)
}
- DataType::Union(_, _) => {
+ DataType::Union(_, _, _) => {
unimplemented!("Creating null Union array not yet supported")
}
DataType::Dictionary(key, value) => {
diff --git a/arrow/src/array/array_union.rs b/arrow/src/array/array_union.rs
index 5ebf3d2d3..5cfab0bbf 100644
--- a/arrow/src/array/array_union.rs
+++ b/arrow/src/array/array_union.rs
@@ -58,6 +58,7 @@ use std::any::Any;
/// ];
///
/// let array = UnionArray::try_new(
+/// &vec![0, 1],
/// type_id_buffer,
/// Some(value_offsets_buffer),
/// children,
@@ -90,6 +91,7 @@ use std::any::Any;
/// ];
///
/// let array = UnionArray::try_new(
+/// &vec![0, 1],
/// type_id_buffer,
/// None,
/// children,
@@ -135,6 +137,7 @@ impl UnionArray {
/// `i8` and `i32` values respectively. `Buffer` objects are untyped and
no attempt is made
/// to ensure that the data provided is valid.
pub unsafe fn new_unchecked(
+ field_type_ids: &[i8],
type_ids: Buffer,
value_offsets: Option<Buffer>,
child_arrays: Vec<(Field, ArrayRef)>,
@@ -149,10 +152,14 @@ impl UnionArray {
UnionMode::Sparse
};
- let builder = ArrayData::builder(DataType::Union(field_types, mode))
- .add_buffer(type_ids)
- .child_data(field_values.into_iter().map(|a|
a.data().clone()).collect())
- .len(len);
+ let builder = ArrayData::builder(DataType::Union(
+ field_types,
+ Vec::from(field_type_ids),
+ mode,
+ ))
+ .add_buffer(type_ids)
+ .child_data(field_values.into_iter().map(|a|
a.data().clone()).collect())
+ .len(len);
let data = match value_offsets {
Some(b) => builder.add_buffer(b).build_unchecked(),
@@ -163,6 +170,7 @@ impl UnionArray {
/// Attempts to create a new `UnionArray`, validating the inputs provided.
pub fn try_new(
+ field_type_ids: &[i8],
type_ids: Buffer,
value_offsets: Option<Buffer>,
child_arrays: Vec<(Field, ArrayRef)>,
@@ -209,8 +217,9 @@ impl UnionArray {
// Unsafe Justification: arguments were validated above (and
// re-validated as part of data().validate() below)
- let new_self =
- unsafe { Self::new_unchecked(type_ids, value_offsets,
child_arrays) };
+ let new_self = unsafe {
+ Self::new_unchecked(field_type_ids, type_ids, value_offsets,
child_arrays)
+ };
new_self.data().validate()?;
Ok(new_self)
@@ -269,7 +278,7 @@ impl UnionArray {
/// Returns the names of the types in the union.
pub fn type_names(&self) -> Vec<&str> {
match self.data.data_type() {
- DataType::Union(fields, _) => fields
+ DataType::Union(fields, _, _) => fields
.iter()
.map(|f| f.name().as_str())
.collect::<Vec<&str>>(),
@@ -280,7 +289,7 @@ impl UnionArray {
/// Returns whether the `UnionArray` is dense (or sparse if `false`).
fn is_dense(&self) -> bool {
match self.data.data_type() {
- DataType::Union(_, mode) => mode == &UnionMode::Dense,
+ DataType::Union(_, _, mode) => mode == &UnionMode::Dense,
_ => unreachable!("Union array's data type is not a union!"),
}
}
@@ -626,9 +635,13 @@ mod tests {
Arc::new(float_array),
),
];
- let array =
- UnionArray::try_new(type_id_buffer, Some(value_offsets_buffer),
children)
- .unwrap();
+ let array = UnionArray::try_new(
+ &[0, 1, 2],
+ type_id_buffer,
+ Some(value_offsets_buffer),
+ children,
+ )
+ .unwrap();
// Check type ids
assert_eq!(Buffer::from_slice_ref(&type_ids),
array.data().buffers()[0]);
diff --git a/arrow/src/array/builder.rs b/arrow/src/array/builder.rs
index da6d2f1c3..091a51b15 100644
--- a/arrow/src/array/builder.rs
+++ b/arrow/src/array/builder.rs
@@ -2168,7 +2168,9 @@ impl UnionBuilder {
});
let children: Vec<_> = children.into_iter().map(|(_, b)| b).collect();
- UnionArray::try_new(type_id_buffer, value_offsets_buffer, children)
+ let type_ids: Vec<i8> = (0_i8..children.len() as i8).collect();
+
+ UnionArray::try_new(&type_ids, type_id_buffer, value_offsets_buffer,
children)
}
}
diff --git a/arrow/src/array/data.rs b/arrow/src/array/data.rs
index f7eacb0a4..15cd9cc5f 100644
--- a/arrow/src/array/data.rs
+++ b/arrow/src/array/data.rs
@@ -194,7 +194,7 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity:
usize) -> [MutableBuff
MutableBuffer::new(capacity * mem::size_of::<u8>()),
empty_buffer,
],
- DataType::Union(_, mode) => {
+ DataType::Union(_, _, mode) => {
let type_ids = MutableBuffer::new(capacity * mem::size_of::<i8>());
match mode {
UnionMode::Sparse => [type_ids, empty_buffer],
@@ -220,7 +220,7 @@ pub(crate) fn into_buffers(
| DataType::Binary
| DataType::LargeUtf8
| DataType::LargeBinary => vec![buffer1.into(), buffer2.into()],
- DataType::Union(_, mode) => {
+ DataType::Union(_, _, mode) => {
match mode {
// Based on Union's DataTypeLayout
UnionMode::Sparse => vec![buffer1.into()],
@@ -581,7 +581,7 @@ impl ArrayData {
DataType::Map(field, _) => {
vec![Self::new_empty(field.data_type())]
}
- DataType::Union(fields, _) => fields
+ DataType::Union(fields, _, _) => fields
.iter()
.map(|field| Self::new_empty(field.data_type()))
.collect(),
@@ -856,7 +856,7 @@ impl ArrayData {
}
Ok(())
}
- DataType::Union(fields, mode) => {
+ DataType::Union(fields, _, mode) => {
self.validate_num_child_data(fields.len())?;
for (i, field) in fields.iter().enumerate() {
@@ -1004,7 +1004,7 @@ impl ArrayData {
let child = &self.child_data[0];
self.validate_offsets_full::<i64>(child.len + child.offset)
}
- DataType::Union(_, _) => {
+ DataType::Union(_, _, _) => {
// Validate Union Array as part of implementing new Union
semantics
// See comments in `ArrayData::validate()`
// https://github.com/apache/arrow-rs/issues/85
@@ -1269,7 +1269,7 @@ fn layout(data_type: &DataType) -> DataTypeLayout {
DataType::FixedSizeList(_, _) => DataTypeLayout::new_empty(), // all
in child data
DataType::LargeList(_) =>
DataTypeLayout::new_fixed_width(size_of::<i32>()),
DataType::Struct(_) => DataTypeLayout::new_empty(), // all in child
data,
- DataType::Union(_, mode) => {
+ DataType::Union(_, _, mode) => {
let type_ids = BufferSpec::FixedWidth {
byte_width: size_of::<i8>(),
};
@@ -2505,6 +2505,7 @@ mod tests {
Field::new("field1", DataType::Int32, true),
Field::new("field2", DataType::Int64, true), // data is
int32
],
+ vec![0, 1],
UnionMode::Sparse,
),
2,
@@ -2536,6 +2537,7 @@ mod tests {
Field::new("field1", DataType::Int32, true),
Field::new("field2", DataType::Int64, true),
],
+ vec![0, 1],
UnionMode::Sparse,
),
2,
@@ -2563,6 +2565,7 @@ mod tests {
Field::new("field1", DataType::Int32, true),
Field::new("field2", DataType::Int64, true),
],
+ vec![0, 1],
UnionMode::Dense,
),
2,
@@ -2593,6 +2596,7 @@ mod tests {
Field::new("field1", DataType::Int32, true),
Field::new("field2", DataType::Int64, true),
],
+ vec![0, 1],
UnionMode::Dense,
),
2,
@@ -2705,8 +2709,8 @@ mod tests {
#[test]
fn test_into_buffers() {
let data_types = vec![
- DataType::Union(vec![], UnionMode::Dense),
- DataType::Union(vec![], UnionMode::Sparse),
+ DataType::Union(vec![], vec![], UnionMode::Dense),
+ DataType::Union(vec![], vec![], UnionMode::Sparse),
];
for data_type in data_types {
diff --git a/arrow/src/array/equal/mod.rs b/arrow/src/array/equal/mod.rs
index 1a6b9f331..c45b30ccc 100644
--- a/arrow/src/array/equal/mod.rs
+++ b/arrow/src/array/equal/mod.rs
@@ -193,7 +193,7 @@ fn equal_values(
fixed_list_equal(lhs, rhs, lhs_start, rhs_start, len)
}
DataType::Struct(_) => struct_equal(lhs, rhs, lhs_start, rhs_start,
len),
- DataType::Union(_, _) => union_equal(lhs, rhs, lhs_start, rhs_start,
len),
+ DataType::Union(_, _, _) => union_equal(lhs, rhs, lhs_start,
rhs_start, len),
DataType::Dictionary(data_type, _) => match data_type.as_ref() {
DataType::Int8 => dictionary_equal::<i8>(lhs, rhs, lhs_start,
rhs_start, len),
DataType::Int16 => {
diff --git a/arrow/src/array/equal/union.rs b/arrow/src/array/equal/union.rs
index 021b0a3b7..e8b9d27b6 100644
--- a/arrow/src/array/equal/union.rs
+++ b/arrow/src/array/equal/union.rs
@@ -19,6 +19,7 @@ use crate::{array::ArrayData, datatypes::DataType,
datatypes::UnionMode};
use super::equal_range;
+#[allow(clippy::too_many_arguments)]
fn equal_dense(
lhs: &ArrayData,
rhs: &ArrayData,
@@ -26,6 +27,8 @@ fn equal_dense(
rhs_type_ids: &[i8],
lhs_offsets: &[i32],
rhs_offsets: &[i32],
+ lhs_field_type_ids: &[i8],
+ rhs_field_type_ids: &[i8],
) -> bool {
let offsets = lhs_offsets.iter().zip(rhs_offsets.iter());
@@ -34,8 +37,16 @@ fn equal_dense(
.zip(rhs_type_ids.iter())
.zip(offsets)
.all(|((l_type_id, r_type_id), (l_offset, r_offset))| {
- let lhs_values = &lhs.child_data()[*l_type_id as usize];
- let rhs_values = &rhs.child_data()[*r_type_id as usize];
+ let lhs_child_index = lhs_field_type_ids
+ .iter()
+ .position(|r| r == l_type_id)
+ .unwrap();
+ let rhs_child_index = rhs_field_type_ids
+ .iter()
+ .position(|r| r == r_type_id)
+ .unwrap();
+ let lhs_values = &lhs.child_data()[lhs_child_index];
+ let rhs_values = &rhs.child_data()[rhs_child_index];
equal_range(
lhs_values,
@@ -76,7 +87,10 @@ pub(super) fn union_equal(
let rhs_type_id_range = &rhs_type_ids[rhs_start..rhs_start + len];
match (lhs.data_type(), rhs.data_type()) {
- (DataType::Union(_, UnionMode::Dense), DataType::Union(_,
UnionMode::Dense)) => {
+ (
+ DataType::Union(_, lhs_type_ids, UnionMode::Dense),
+ DataType::Union(_, rhs_type_ids, UnionMode::Dense),
+ ) => {
let lhs_offsets = lhs.buffer::<i32>(1);
let rhs_offsets = rhs.buffer::<i32>(1);
@@ -91,11 +105,13 @@ pub(super) fn union_equal(
rhs_type_id_range,
lhs_offsets_range,
rhs_offsets_range,
+ lhs_type_ids,
+ rhs_type_ids,
)
}
(
- DataType::Union(_, UnionMode::Sparse),
- DataType::Union(_, UnionMode::Sparse),
+ DataType::Union(_, _, UnionMode::Sparse),
+ DataType::Union(_, _, UnionMode::Sparse),
) => {
lhs_type_id_range == rhs_type_id_range
&& equal_sparse(lhs, rhs, lhs_start, rhs_start, len)
diff --git a/arrow/src/array/equal/utils.rs b/arrow/src/array/equal/utils.rs
index 8875239ca..fed3933a0 100644
--- a/arrow/src/array/equal/utils.rs
+++ b/arrow/src/array/equal/utils.rs
@@ -68,7 +68,7 @@ pub(super) fn equal_nulls(
#[inline]
pub(super) fn base_equal(lhs: &ArrayData, rhs: &ArrayData) -> bool {
let equal_type = match (lhs.data_type(), rhs.data_type()) {
- (DataType::Union(l_fields, l_mode), DataType::Union(r_fields, r_mode))
=> {
+ (DataType::Union(l_fields, _, l_mode), DataType::Union(r_fields, _,
r_mode)) => {
l_fields == r_fields && l_mode == r_mode
}
(DataType::Map(l_field, l_sorted), DataType::Map(r_field, r_sorted))
=> {
diff --git a/arrow/src/array/transform/mod.rs b/arrow/src/array/transform/mod.rs
index aa7d417a1..586a4fec2 100644
--- a/arrow/src/array/transform/mod.rs
+++ b/arrow/src/array/transform/mod.rs
@@ -274,7 +274,7 @@ fn build_extend(array: &ArrayData) -> Extend {
DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array),
DataType::Float16 => primitive::build_extend::<f16>(array),
DataType::FixedSizeList(_, _) => fixed_size_list::build_extend(array),
- DataType::Union(_, mode) => match mode {
+ DataType::Union(_, _, mode) => match mode {
UnionMode::Sparse => union::build_extend_sparse(array),
UnionMode::Dense => union::build_extend_dense(array),
},
@@ -325,7 +325,7 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls {
DataType::FixedSizeBinary(_) => fixed_binary::extend_nulls,
DataType::Float16 => primitive::extend_nulls::<f16>,
DataType::FixedSizeList(_, _) => fixed_size_list::extend_nulls,
- DataType::Union(_, mode) => match mode {
+ DataType::Union(_, _, mode) => match mode {
UnionMode::Sparse => union::extend_nulls_sparse,
UnionMode::Dense => union::extend_nulls_dense,
},
@@ -524,7 +524,7 @@ impl<'a> MutableArrayData<'a> {
.collect::<Vec<_>>();
vec![MutableArrayData::new(childs, use_nulls, array_capacity)]
}
- DataType::Union(fields, _) => (0..fields.len())
+ DataType::Union(fields, _, _) => (0..fields.len())
.map(|i| {
let child_arrays = arrays
.iter()
diff --git a/arrow/src/compute/kernels/cast.rs
b/arrow/src/compute/kernels/cast.rs
index 2c0ebb1e2..c989cd2fe 100644
--- a/arrow/src/compute/kernels/cast.rs
+++ b/arrow/src/compute/kernels/cast.rs
@@ -4776,6 +4776,7 @@ mod tests {
Field::new("f1", DataType::Int32, false),
Field::new("f2", DataType::Utf8, true),
],
+ vec![0, 1],
UnionMode::Dense,
),
Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
diff --git a/arrow/src/datatypes/datatype.rs b/arrow/src/datatypes/datatype.rs
index c5cc8f017..a740e8ecc 100644
--- a/arrow/src/datatypes/datatype.rs
+++ b/arrow/src/datatypes/datatype.rs
@@ -114,8 +114,12 @@ pub enum DataType {
LargeList(Box<Field>),
/// A nested datatype that contains a number of sub-fields.
Struct(Vec<Field>),
- /// A nested datatype that can represent slots of differing types.
- Union(Vec<Field>, UnionMode),
+ /// A nested datatype that can represent slots of differing types.
Components:
+ ///
+ /// 1. [`Field`] for each possible child type the Union can hold
+ /// 2. The corresponding `type_id` used to identify which Field
+ /// 3. The type of union (Sparse or Dense)
+ Union(Vec<Field>, Vec<i8>, UnionMode),
/// A dictionary encoded array (`key_type`, `value_type`), where
/// each array element is an index of `key_type` into an
/// associated dictionary of `value_type`.
@@ -516,24 +520,15 @@ impl DataType {
.as_array()
.unwrap()
.iter()
- .map(|t| t.as_i64().unwrap())
+ .map(|t| t.as_i64().unwrap() as i8)
.collect::<Vec<_>>();
let default_fields = type_ids
.iter()
- .map(|t| {
- Field::new("", DataType::Boolean,
true).with_metadata(
- Some(
- [("type_id".to_string(),
t.to_string())]
- .iter()
- .cloned()
- .collect(),
- ),
- )
- })
+ .map(|_| default_field.clone())
.collect::<Vec<_>>();
- Ok(DataType::Union(default_fields, union_mode))
+ Ok(DataType::Union(default_fields, type_ids,
union_mode))
} else {
Err(ArrowError::ParseError(
"Expecting a typeIds for union ".to_string(),
@@ -581,7 +576,7 @@ impl DataType {
json!({"name": "fixedsizebinary", "byteWidth": byte_width})
}
DataType::Struct(_) => json!({"name": "struct"}),
- DataType::Union(_, _) => json!({"name": "union"}),
+ DataType::Union(_, _, _) => json!({"name": "union"}),
DataType::List(_) => json!({ "name": "list"}),
DataType::LargeList(_) => json!({ "name": "largelist"}),
DataType::FixedSizeList(_, length) => {
diff --git a/arrow/src/datatypes/field.rs b/arrow/src/datatypes/field.rs
index 6471f1ed7..5025d32a4 100644
--- a/arrow/src/datatypes/field.rs
+++ b/arrow/src/datatypes/field.rs
@@ -168,7 +168,7 @@ impl Field {
let mut collected_fields = vec![];
match dt {
- DataType::Struct(fields) | DataType::Union(fields, _) => {
+ DataType::Struct(fields) | DataType::Union(fields, _, _) => {
collected_fields.extend(fields.iter().flat_map(|f| f.fields()))
}
DataType::List(field)
@@ -390,18 +390,11 @@ impl Field {
}
}
}
- DataType::Union(fields, mode) => match map.get("children")
{
+ DataType::Union(_, type_ids, mode) => match
map.get("children") {
Some(Value::Array(values)) => {
- let mut union_fields: Vec<Field> =
+ let union_fields: Vec<Field> =
values.iter().map(Field::from).collect::<Result<_>>()?;
-
fields.iter().zip(union_fields.iter_mut()).for_each(
- |(f, union_field)| {
- union_field.set_metadata(Some(
- f.metadata().unwrap().clone(),
- ));
- },
- );
- DataType::Union(union_fields, mode)
+ DataType::Union(union_fields, type_ids, mode)
}
Some(_) => {
return Err(ArrowError::ParseError(
@@ -568,18 +561,34 @@ impl Field {
));
}
},
- DataType::Union(nested_fields, _) => match &from.data_type {
- DataType::Union(from_nested_fields, _) => {
- for from_field in from_nested_fields {
+ DataType::Union(nested_fields, type_ids, _) => match
&from.data_type {
+ DataType::Union(from_nested_fields, from_type_ids, _) => {
+ for (idx, from_field) in
from_nested_fields.iter().enumerate() {
let mut is_new_field = true;
- for self_field in nested_fields.iter_mut() {
+ let field_type_id = from_type_ids.get(idx).unwrap();
+
+ for (self_idx, self_field) in
nested_fields.iter_mut().enumerate()
+ {
if from_field == self_field {
+ let self_type_id =
type_ids.get(self_idx).unwrap();
+
+ // If the nested fields in two unions are the
same, they must have same
+ // type id.
+ if self_type_id != field_type_id {
+ return Err(ArrowError::SchemaError(
+ "Fail to merge schema Field due to
conflicting type ids in union datatype"
+ .to_string(),
+ ));
+ }
+
is_new_field = false;
break;
}
}
+
if is_new_field {
nested_fields.push(from_field.clone());
+ type_ids.push(*field_type_id);
}
}
}
diff --git a/arrow/src/datatypes/mod.rs b/arrow/src/datatypes/mod.rs
index c3015972a..47074633d 100644
--- a/arrow/src/datatypes/mod.rs
+++ b/arrow/src/datatypes/mod.rs
@@ -435,19 +435,10 @@ mod tests {
"my_union",
DataType::Union(
vec![
- Field::new("f1", DataType::Int32, true).with_metadata(Some(
- [("type_id".to_string(), "5".to_string())]
- .iter()
- .cloned()
- .collect(),
- )),
- Field::new("f2", DataType::Utf8, true).with_metadata(Some(
- [("type_id".to_string(), "7".to_string())]
- .iter()
- .cloned()
- .collect(),
- )),
+ Field::new("f1", DataType::Int32, true),
+ Field::new("f2", DataType::Utf8, true),
],
+ vec![5, 7],
UnionMode::Sparse,
),
false,
@@ -1444,6 +1435,7 @@ mod tests {
Field::new("c11", DataType::Utf8, true),
Field::new("c12", DataType::Utf8, true),
],
+ vec![0, 1],
UnionMode::Dense
),
false
@@ -1455,6 +1447,7 @@ mod tests {
Field::new("c12", DataType::Utf8, true),
Field::new("c13",
DataType::Time64(TimeUnit::Second), true),
],
+ vec![1, 2],
UnionMode::Dense
),
false
@@ -1468,6 +1461,7 @@ mod tests {
Field::new("c12", DataType::Utf8, true),
Field::new("c13", DataType::Time64(TimeUnit::Second),
true),
],
+ vec![0, 1, 2],
UnionMode::Dense
),
false
diff --git a/arrow/src/ipc/convert.rs b/arrow/src/ipc/convert.rs
index 97ed9ed78..c81ea8278 100644
--- a/arrow/src/ipc/convert.rs
+++ b/arrow/src/ipc/convert.rs
@@ -338,7 +338,12 @@ pub(crate) fn get_data_type(field: ipc::Field,
may_be_dictionary: bool) -> DataT
}
};
- DataType::Union(fields, union_mode)
+ let type_ids: Vec<i8> = match union.typeIds() {
+ None => (0_i8..fields.len() as i8).collect(),
+ Some(ids) => ids.iter().map(|i| i as i8).collect(),
+ };
+
+ DataType::Union(fields, type_ids, union_mode)
}
t => unimplemented!("Type {:?} not supported", t),
}
@@ -666,7 +671,7 @@ pub(crate) fn get_fb_field_type<'a>(
children: Some(fbb.create_vector(&empty_fields[..])),
}
}
- Union(fields, mode) => {
+ Union(fields, type_ids, mode) => {
let mut children = vec![];
for field in fields {
children.push(build_field(fbb, field));
@@ -677,8 +682,11 @@ pub(crate) fn get_fb_field_type<'a>(
UnionMode::Dense => ipc::UnionMode::Dense,
};
+ let fbb_type_ids = fbb
+ .create_vector(&type_ids.iter().map(|t| *t as
i32).collect::<Vec<_>>());
let mut builder = ipc::UnionBuilder::new(fbb);
builder.add_mode(union_mode);
+ builder.add_typeIds(fbb_type_ids);
FBFieldType {
type_type: ipc::Type::Union,
@@ -874,6 +882,7 @@ mod tests {
DataType::List(Box::new(Field::new(
"union",
DataType::Union(
+ vec![],
vec![],
UnionMode::Sparse,
),
@@ -882,6 +891,7 @@ mod tests {
false,
),
],
+ vec![0, 1],
UnionMode::Dense,
),
false,
@@ -889,13 +899,34 @@ mod tests {
false,
),
],
+ vec![0, 1],
UnionMode::Sparse,
),
false,
),
Field::new("struct<>", DataType::Struct(vec![]), true),
- Field::new("union<>", DataType::Union(vec![],
UnionMode::Dense), true),
- Field::new("union<>", DataType::Union(vec![],
UnionMode::Sparse), true),
+ Field::new(
+ "union<>",
+ DataType::Union(vec![], vec![], UnionMode::Dense),
+ true,
+ ),
+ Field::new(
+ "union<>",
+ DataType::Union(vec![], vec![], UnionMode::Sparse),
+ true,
+ ),
+ Field::new(
+ "union<int32, utf8>",
+ DataType::Union(
+ vec![
+ Field::new("int32", DataType::Int32, true),
+ Field::new("utf8", DataType::Utf8, true),
+ ],
+ vec![2, 3], // non-default type ids
+ UnionMode::Dense,
+ ),
+ true,
+ ),
Field::new_dict(
"dictionary<int32, utf8>",
DataType::Dictionary(
diff --git a/arrow/src/ipc/reader.rs b/arrow/src/ipc/reader.rs
index 4a73269e5..662b384c4 100644
--- a/arrow/src/ipc/reader.rs
+++ b/arrow/src/ipc/reader.rs
@@ -195,7 +195,7 @@ fn create_array(
value_array.clone(),
)
}
- Union(fields, mode) => {
+ Union(fields, field_type_ids, mode) => {
let union_node = nodes[node_index];
node_index += 1;
@@ -234,7 +234,8 @@ fn create_array(
children.push((field.clone(), triple.0));
}
- let array = UnionArray::try_new(type_ids, value_offsets,
children)?;
+ let array =
+ UnionArray::try_new(field_type_ids, type_ids, value_offsets,
children)?;
Arc::new(array)
}
Null => {
diff --git a/arrow/src/ipc/writer.rs b/arrow/src/ipc/writer.rs
index c03d5e449..f61d4ce4c 100644
--- a/arrow/src/ipc/writer.rs
+++ b/arrow/src/ipc/writer.rs
@@ -221,7 +221,7 @@ impl IpcDataGenerator {
write_options,
)?;
}
- DataType::Union(fields, _) => {
+ DataType::Union(fields, _, _) => {
let union = as_union_array(column);
for (field, ref column) in fields
.iter()
@@ -865,7 +865,7 @@ fn write_array_data(
// UnionArray does not have a validity buffer
if !matches!(
array_data.data_type(),
- DataType::Null | DataType::Union(_, _)
+ DataType::Null | DataType::Union(_, _, _)
) {
// write null buffer if exists
let null_buffer = match array_data.null_buffer() {
@@ -1328,7 +1328,8 @@ mod tests {
let offsets = Buffer::from_slice_ref(&[0_i32, 1, 2]);
let union =
- UnionArray::try_new(types, Some(offsets), vec![(dctfield,
array)]).unwrap();
+ UnionArray::try_new(&[0], types, Some(offsets), vec![(dctfield,
array)])
+ .unwrap();
let schema = Arc::new(Schema::new(vec![Field::new(
"union",
diff --git a/arrow/src/util/display.rs b/arrow/src/util/display.rs
index b0493b6ce..6da73e4cf 100644
--- a/arrow/src/util/display.rs
+++ b/arrow/src/util/display.rs
@@ -396,7 +396,9 @@ pub fn array_value_to_string(column: &array::ArrayRef, row:
usize) -> Result<Str
Ok(s)
}
- DataType::Union(field_vec, mode) => union_to_string(column, row,
field_vec, mode),
+ DataType::Union(field_vec, type_ids, mode) => {
+ union_to_string(column, row, field_vec, type_ids, mode)
+ }
_ => Err(ArrowError::InvalidArgumentError(format!(
"Pretty printing not implemented for {:?} type",
column.data_type()
@@ -409,6 +411,7 @@ fn union_to_string(
column: &array::ArrayRef,
row: usize,
fields: &[Field],
+ type_ids: &[i8],
mode: &UnionMode,
) -> Result<String> {
let list = column
@@ -420,15 +423,13 @@ fn union_to_string(
)
})?;
let type_id = list.type_id(row);
- let name = fields
- .get(type_id as usize)
- .ok_or_else(|| {
- ArrowError::InvalidArgumentError(format!(
- "Repl error: could not get field name for type id: {} in union
array.",
- type_id,
- ))
- })?
- .name();
+ let field_idx = type_ids.iter().position(|t| t == &type_id).ok_or_else(|| {
+ ArrowError::InvalidArgumentError(format!(
+ "Repl error: could not get field name for type id: {} in union
array.",
+ type_id,
+ ))
+ })?;
+ let name = fields.get(field_idx).unwrap().name();
let value = array_value_to_string(
&list.child(type_id),
diff --git a/arrow/src/util/pretty.rs b/arrow/src/util/pretty.rs
index 3fa2729ba..124de6127 100644
--- a/arrow/src/util/pretty.rs
+++ b/arrow/src/util/pretty.rs
@@ -664,6 +664,7 @@ mod tests {
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Float64, false),
],
+ vec![0, 1],
UnionMode::Dense,
),
false,
@@ -704,6 +705,7 @@ mod tests {
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Float64, false),
],
+ vec![0, 1],
UnionMode::Sparse,
),
false,
@@ -746,6 +748,7 @@ mod tests {
Field::new("b", DataType::Int32, false),
Field::new("c", DataType::Float64, false),
],
+ vec![0, 1],
UnionMode::Dense,
),
false,
@@ -760,12 +763,13 @@ mod tests {
(inner_field.clone(), Arc::new(inner)),
];
- let outer = UnionArray::try_new(type_ids, None, children).unwrap();
+ let outer = UnionArray::try_new(&[0, 1], type_ids, None,
children).unwrap();
let schema = Schema::new(vec![Field::new(
"Teamsters",
DataType::Union(
vec![Field::new("a", DataType::Int32, true), inner_field],
+ vec![0, 1],
UnionMode::Sparse,
),
false,
diff --git a/integration-testing/src/lib.rs b/integration-testing/src/lib.rs
index c70459938..c57ef32bc 100644
--- a/integration-testing/src/lib.rs
+++ b/integration-testing/src/lib.rs
@@ -632,39 +632,13 @@ fn array_from_json(
let array = MapArray::from(array_data);
Ok(Arc::new(array))
}
- DataType::Union(fields, _) => {
- let field_type_ids = fields
- .iter()
- .enumerate()
- .into_iter()
- .map(|(idx, f)| {
- (
- f.metadata()
- .and_then(|m| m.get("type_id"))
- .unwrap()
- .parse::<i8>()
- .unwrap(),
- idx,
- )
- })
- .collect::<HashMap<_, _>>();
-
+ DataType::Union(fields, field_type_ids, _) => {
let type_ids = if let Some(type_id) = json_col.type_id {
type_id
- .iter()
- .map(|t| {
- if field_type_ids.contains_key(t) {
- Ok(*(field_type_ids.get(t).unwrap()) as i8)
- } else {
- Err(ArrowError::JsonError(format!(
- "Unable to find type id {:?}",
- t
- )))
- }
- })
- .collect::<Result<_>>()?
} else {
- vec![]
+ return Err(ArrowError::JsonError(
+ "Cannot find expected type_id in json column".to_string(),
+ ));
};
let offset: Option<Buffer> = json_col.offset.map(|offsets| {
@@ -680,6 +654,7 @@ fn array_from_json(
}
let array = UnionArray::try_new(
+ field_type_ids,
Buffer::from(&type_ids.to_byte_slice()),
offset,
children,
diff --git a/parquet/src/arrow/arrow_writer.rs
b/parquet/src/arrow/arrow_writer.rs
index 7ddd64432..1918c9675 100644
--- a/parquet/src/arrow/arrow_writer.rs
+++ b/parquet/src/arrow/arrow_writer.rs
@@ -324,7 +324,7 @@ fn write_leaves(
ArrowDataType::Float16 => Err(ParquetError::ArrowError(
"Float16 arrays not supported".to_string(),
)),
- ArrowDataType::FixedSizeList(_, _) | ArrowDataType::Union(_, _) => {
+ ArrowDataType::FixedSizeList(_, _) | ArrowDataType::Union(_, _, _) => {
Err(ParquetError::NYI(
format!(
"Attempting to write an Arrow type {:?} to parquet that is
not yet implemented",
diff --git a/parquet/src/arrow/levels.rs b/parquet/src/arrow/levels.rs
index a1979e591..be9a5e993 100644
--- a/parquet/src/arrow/levels.rs
+++ b/parquet/src/arrow/levels.rs
@@ -240,7 +240,7 @@ impl LevelInfo {
list_level.calculate_array_levels(&child_array,
list_field)
}
DataType::FixedSizeList(_, _) => unimplemented!(),
- DataType::Union(_, _) => unimplemented!(),
+ DataType::Union(_, _, _) => unimplemented!(),
}
}
DataType::Map(map_field, _) => {
@@ -310,7 +310,7 @@ impl LevelInfo {
});
struct_levels
}
- DataType::Union(_, _) => unimplemented!(),
+ DataType::Union(_, _, _) => unimplemented!(),
DataType::Dictionary(_, _) => {
// Need to check for these cases not implemented in C++:
// - "Writing DictionaryArray with nested dictionary type not
yet supported"
@@ -749,7 +749,7 @@ impl LevelInfo {
array_mask,
)
}
- DataType::FixedSizeList(_, _) | DataType::Union(_, _) => {
+ DataType::FixedSizeList(_, _) | DataType::Union(_, _, _) => {
unimplemented!("Getting offsets not yet implemented")
}
}
diff --git a/parquet/src/arrow/schema.rs b/parquet/src/arrow/schema.rs
index 71184e0b6..07c50d11c 100644
--- a/parquet/src/arrow/schema.rs
+++ b/parquet/src/arrow/schema.rs
@@ -520,7 +520,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result<Type> {
))
}
}
- DataType::Union(_, _) => unimplemented!("See ARROW-8817."),
+ DataType::Union(_, _, _) => unimplemented!("See ARROW-8817."),
DataType::Dictionary(_, ref value) => {
// Dictionary encoding not handled at the schema level
let dict_field = Field::new(name, *value.clone(),
field.is_nullable());