This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push:
new 63212d21f6 Add validated constructors for UnionFields (#8891)
63212d21f6 is described below
commit 63212d21f669d265c568ca1caff748fd152e1657
Author: Matthew Kim <[email protected]>
AuthorDate: Mon Dec 15 13:11:09 2025 -0500
Add validated constructors for UnionFields (#8891)
- related to #3982
# Rationale for this change
This PR introduces `UnionFields::try_new` and `UnionFields::from_fields`
as replacements for the "unsafe" `UnionFields::new` constructor, which
is now deprecated.
Previously, `UnionFields` could be constructed in a state that violates
its invariants, including negative type ids, duplicate type ids, or
duplicate fields. Since `UnionArray`s index by type id, negative values
cause panics at runtime and duplicates break comparisons.
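`UnionFields::try_new` validates that type ids are non-negative, unique,
and under 128, and that the number of type ids matches the number of
fields, returning an error if any constraint is violated. (Duplicate
fields are still accepted as long as their type ids differ.)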
`UnionFields::from_fields` is a convenience constructor that
auto-assigns sequential type ids starting from 0. (It panics if given
more than 128 fields.)
# Note about reviewing
I broke it up into smaller commits. The last commit updates all call
sites across the different kernels.
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
arrow-array/benches/union_array.rs | 4 +-
arrow-array/src/array/mod.rs | 5 +-
arrow-array/src/array/union_array.rs | 25 +-
arrow-avro/benches/avro_writer.rs | 5 +-
arrow-avro/src/codec.rs | 8 +-
arrow-avro/src/reader/mod.rs | 45 ++-
arrow-avro/src/reader/record.rs | 7 +-
arrow-avro/src/writer/encoder.rs | 15 +-
arrow-avro/src/writer/mod.rs | 5 +-
arrow-cast/src/pretty.rs | 5 +-
arrow-flight/src/sql/metadata/sql_info.rs | 5 +-
arrow-ipc/src/convert.rs | 10 +-
arrow-ipc/src/reader.rs | 16 +-
arrow-schema/src/datatype.rs | 30 +-
arrow-schema/src/datatype_parse.rs | 55 ++-
arrow-schema/src/ffi.rs | 4 +-
arrow-schema/src/field.rs | 32 +-
arrow-schema/src/fields.rs | 428 +++++++++++++++++++++++-
arrow-select/src/take.rs | 2 +-
arrow-select/src/union_extract.rs | 17 +-
arrow/tests/array_cast.rs | 5 +-
arrow/tests/array_validation.rs | 20 +-
parquet-variant-compute/src/arrow_to_variant.rs | 31 +-
parquet-variant-compute/src/cast_to_variant.rs | 26 +-
parquet-variant-compute/src/shred_variant.rs | 11 +-
parquet-variant-compute/src/variant_to_arrow.rs | 3 +-
26 files changed, 631 insertions(+), 188 deletions(-)
diff --git a/arrow-array/benches/union_array.rs
b/arrow-array/benches/union_array.rs
index 255bc7cbdb..414529882a 100644
--- a/arrow-array/benches/union_array.rs
+++ b/arrow-array/benches/union_array.rs
@@ -54,10 +54,10 @@ fn criterion_benchmark(c: &mut Criterion) {
|b| {
let type_ids = 0..with_nulls+without_nulls;
- let fields = UnionFields::new(
+ let fields = UnionFields::try_new(
type_ids.clone(),
type_ids.clone().map(|i| Field::new(format!("f{i}"),
DataType::Int32, true)),
- );
+ ).unwrap();
let array = UnionArray::try_new(
fields,
diff --git a/arrow-array/src/array/mod.rs b/arrow-array/src/array/mod.rs
index f05916fd24..bb114be950 100644
--- a/arrow-array/src/array/mod.rs
+++ b/arrow-array/src/array/mod.rs
@@ -1073,13 +1073,14 @@ mod tests {
fn test_null_union() {
for mode in [UnionMode::Sparse, UnionMode::Dense] {
let data_type = DataType::Union(
- UnionFields::new(
+ UnionFields::try_new(
vec![2, 1],
vec![
Field::new("foo", DataType::Int32, true),
Field::new("bar", DataType::Int64, true),
],
- ),
+ )
+ .unwrap(),
mode,
);
let array = new_null_array(&data_type, 4);
diff --git a/arrow-array/src/array/union_array.rs
b/arrow-array/src/array/union_array.rs
index a133998bb2..934107d075 100644
--- a/arrow-array/src/array/union_array.rs
+++ b/arrow-array/src/array/union_array.rs
@@ -1682,14 +1682,15 @@ mod tests {
#[test]
fn test_custom_type_ids() {
let data_type = DataType::Union(
- UnionFields::new(
+ UnionFields::try_new(
vec![8, 4, 9],
vec![
Field::new("strings", DataType::Utf8, false),
Field::new("integers", DataType::Int32, false),
Field::new("floats", DataType::Float64, false),
],
- ),
+ )
+ .unwrap(),
UnionMode::Dense,
);
@@ -1796,14 +1797,15 @@ mod tests {
fn into_parts_custom_type_ids() {
let set_field_type_ids: [i8; 3] = [8, 4, 9];
let data_type = DataType::Union(
- UnionFields::new(
+ UnionFields::try_new(
set_field_type_ids,
[
Field::new("strings", DataType::Utf8, false),
Field::new("integers", DataType::Int32, false),
Field::new("floats", DataType::Float64, false),
],
- ),
+ )
+ .unwrap(),
UnionMode::Dense,
);
let string_array = StringArray::from(vec!["foo", "bar", "baz"]);
@@ -1836,13 +1838,14 @@ mod tests {
#[test]
fn test_invalid() {
- let fields = UnionFields::new(
+ let fields = UnionFields::try_new(
[3, 2],
[
Field::new("a", DataType::Utf8, false),
Field::new("b", DataType::Utf8, false),
],
- );
+ )
+ .unwrap();
let children = vec![
Arc::new(StringArray::from_iter_values(["a", "b"])) as _,
Arc::new(StringArray::from_iter_values(["c", "d"])) as _,
@@ -1912,13 +1915,14 @@ mod tests {
assert_eq!(array.logical_nulls(), None);
- let fields = UnionFields::new(
+ let fields = UnionFields::try_new(
[1, 3],
[
Field::new("a", DataType::Int8, false), // non nullable
Field::new("b", DataType::Int8, false), // non nullable
],
- );
+ )
+ .unwrap();
let array = UnionArray::try_new(
fields,
vec![1].into(),
@@ -1932,13 +1936,14 @@ mod tests {
assert_eq!(array.logical_nulls(), None);
- let nullable_fields = UnionFields::new(
+ let nullable_fields = UnionFields::try_new(
[1, 3],
[
Field::new("a", DataType::Int8, true), // nullable but without
nulls
Field::new("b", DataType::Int8, true), // nullable but without
nulls
],
- );
+ )
+ .unwrap();
let array = UnionArray::try_new(
nullable_fields.clone(),
vec![1, 1].into(),
diff --git a/arrow-avro/benches/avro_writer.rs
b/arrow-avro/benches/avro_writer.rs
index f15f13b989..58b014c5a3 100644
--- a/arrow-avro/benches/avro_writer.rs
+++ b/arrow-avro/benches/avro_writer.rs
@@ -688,14 +688,15 @@ static ENUM_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
static UNION_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
// Basic Dense Union of three types: Utf8, Int32, Float64
- let union_fields = UnionFields::new(
+ let union_fields = UnionFields::try_new(
vec![0, 1, 2],
vec![
Field::new("u_str", DataType::Utf8, true),
Field::new("u_int", DataType::Int32, true),
Field::new("u_f64", DataType::Float64, true),
],
- );
+ )
+ .expect("UnionFields should be valid");
let union_dt = DataType::Union(union_fields.clone(), UnionMode::Dense);
let schema = schema_single("field1", union_dt);
diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs
index 4f4669270d..04ef87d7ef 100644
--- a/arrow-avro/src/codec.rs
+++ b/arrow-avro/src/codec.rs
@@ -993,13 +993,13 @@ fn union_branch_name(dt: &AvroDataType) -> String {
dt.codec.union_field_name()
}
-fn build_union_fields(encodings: &[AvroDataType]) -> UnionFields {
+fn build_union_fields(encodings: &[AvroDataType]) -> Result<UnionFields,
ArrowError> {
let arrow_fields: Vec<Field> = encodings
.iter()
.map(|encoding| encoding.field_with_name(&union_branch_name(encoding)))
.collect();
let type_ids: Vec<i8> = (0..arrow_fields.len()).map(|i| i as i8).collect();
- UnionFields::new(type_ids, arrow_fields)
+ UnionFields::try_new(type_ids, arrow_fields)
}
/// Resolves Avro type names to [`AvroDataType`]
@@ -1267,7 +1267,7 @@ impl<'a> Maker<'a> {
.map(|s| self.parse_type(s, namespace))
.collect::<Result<_, _>>()?;
// Build Arrow layout once here
- let union_fields = build_union_fields(&children);
+ let union_fields = build_union_fields(&children)?;
Ok(AvroDataType::new(
Codec::Union(Arc::from(children), union_fields,
UnionMode::Dense),
Default::default(),
@@ -1620,7 +1620,7 @@ impl<'a> Maker<'a> {
for writer in writer_variants {
writer_to_reader.push(self.find_best_promotion(writer,
reader_variants, namespace));
}
- let union_fields = build_union_fields(&reader_encodings);
+ let union_fields = build_union_fields(&reader_encodings)?;
let mut dt = AvroDataType::new(
Codec::Union(reader_encodings.into(), union_fields,
UnionMode::Dense),
Default::default(),
diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs
index beb8bbd976..546650faf5 100644
--- a/arrow-avro/src/reader/mod.rs
+++ b/arrow-avro/src/reader/mod.rs
@@ -7839,35 +7839,38 @@ mod test {
let uuid1 = uuid16_from_str("fe7bc30b-4ce8-4c5e-b67c-2234a2d38e66");
let uuid2 = uuid16_from_str("0826cc06-d2e3-4599-b4ad-af5fa6905cdb");
let item_name = Field::LIST_FIELD_DEFAULT_NAME;
- let uf_tri = UnionFields::new(
+ let uf_tri = UnionFields::try_new(
vec![0, 1, 2],
vec![
Field::new("int", DataType::Int32, false),
Field::new("string", DataType::Utf8, false),
Field::new("boolean", DataType::Boolean, false),
],
- );
- let uf_arr_items = UnionFields::new(
+ )
+ .unwrap();
+ let uf_arr_items = UnionFields::try_new(
vec![0, 1, 2],
vec![
Field::new("null", DataType::Null, false),
Field::new("string", DataType::Utf8, false),
Field::new("long", DataType::Int64, false),
],
- );
+ )
+ .unwrap();
let arr_items_field = Arc::new(Field::new(
item_name,
DataType::Union(uf_arr_items.clone(), UnionMode::Dense),
true,
));
- let uf_map_vals = UnionFields::new(
+ let uf_map_vals = UnionFields::try_new(
vec![0, 1, 2],
vec![
Field::new("string", DataType::Utf8, false),
Field::new("double", DataType::Float64, false),
Field::new("null", DataType::Null, false),
],
- );
+ )
+ .unwrap();
let map_entries_field = Arc::new(Field::new(
"entries",
DataType::Struct(Fields::from(vec![
@@ -7928,7 +7931,7 @@ mod test {
);
m
};
- let uf_union_big = UnionFields::new(
+ let uf_union_big = UnionFields::try_new(
vec![0, 1, 2, 3, 4],
vec![
Field::new(
@@ -7960,7 +7963,8 @@ mod test {
)
.with_metadata(enum_md_color.clone()),
],
- );
+ )
+ .unwrap();
let fx4_md = {
let mut m = HashMap::<String, String>::new();
m.insert(AVRO_NAME_METADATA_KEY.to_string(), "Fx4".to_string());
@@ -7970,7 +7974,7 @@ mod test {
);
m
};
- let uf_date_fixed4 = UnionFields::new(
+ let uf_date_fixed4 = UnionFields::try_new(
vec![0, 1],
vec![
Field::new(
@@ -7981,7 +7985,8 @@ mod test {
.with_metadata(fx4_md.clone()),
Field::new("date", DataType::Date32, false),
],
- );
+ )
+ .unwrap();
let dur12u_md = {
let mut m = HashMap::<String, String>::new();
m.insert(AVRO_NAME_METADATA_KEY.to_string(), "Dur12U".to_string());
@@ -7991,7 +7996,7 @@ mod test {
);
m
};
- let uf_dur_or_str = UnionFields::new(
+ let uf_dur_or_str = UnionFields::try_new(
vec![0, 1],
vec![
Field::new("string", DataType::Utf8, false),
@@ -8002,7 +8007,8 @@ mod test {
)
.with_metadata(dur12u_md.clone()),
],
- );
+ )
+ .unwrap();
let fx10_md = {
let mut m = HashMap::<String, String>::new();
m.insert(AVRO_NAME_METADATA_KEY.to_string(), "Fx10".to_string());
@@ -8012,7 +8018,7 @@ mod test {
);
m
};
- let uf_uuid_or_fx10 = UnionFields::new(
+ let uf_uuid_or_fx10 = UnionFields::try_new(
vec![0, 1],
vec![
Field::new(
@@ -8023,15 +8029,17 @@ mod test {
.with_metadata(fx10_md.clone()),
add_uuid_ext_union(Field::new("uuid",
DataType::FixedSizeBinary(16), false)),
],
- );
- let uf_kv_val = UnionFields::new(
+ )
+ .unwrap();
+ let uf_kv_val = UnionFields::try_new(
vec![0, 1, 2],
vec![
Field::new("null", DataType::Null, false),
Field::new("int", DataType::Int32, false),
Field::new("long", DataType::Int64, false),
],
- );
+ )
+ .unwrap();
let kv_fields = Fields::from(vec![
Field::new("key", DataType::Utf8, false),
Field::new(
@@ -8053,7 +8061,7 @@ mod test {
])),
false,
));
- let uf_map_or_array = UnionFields::new(
+ let uf_map_or_array = UnionFields::try_new(
vec![0, 1],
vec![
Field::new(
@@ -8063,7 +8071,8 @@ mod test {
),
Field::new("map", DataType::Map(map_int_entries.clone(),
false), false),
],
- );
+ )
+ .unwrap();
let mut enum_md_status = {
let mut m = HashMap::<String, String>::new();
m.insert(
diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs
index 0412a3e754..648baa60c7 100644
--- a/arrow-avro/src/reader/record.rs
+++ b/arrow-avro/src/reader/record.rs
@@ -3674,7 +3674,7 @@ mod tests {
avro_children.push(AvroDataType::new(codec, Default::default(),
None));
fields.push(arrow_schema::Field::new(name, dt, true));
}
- let union_fields = UnionFields::new(type_ids, fields);
+ let union_fields = UnionFields::try_new(type_ids, fields).unwrap();
let union_codec = Codec::Union(avro_children.into(), union_fields,
UnionMode::Dense);
AvroDataType::new(union_codec, Default::default(), None)
}
@@ -3823,13 +3823,14 @@ mod tests {
AvroDataType::new(Codec::Int32, Default::default(), None),
AvroDataType::new(Codec::Utf8, Default::default(), None),
];
- let uf = UnionFields::new(
+ let uf = UnionFields::try_new(
vec![1, 3],
vec![
arrow_schema::Field::new("i", DataType::Int32, true),
arrow_schema::Field::new("s", DataType::Utf8, true),
],
- );
+ )
+ .unwrap();
let codec = Codec::Union(children.into(), uf, UnionMode::Sparse);
let dt = AvroDataType::new(codec, Default::default(), None);
let err = Decoder::try_new(&dt).expect_err("sparse union should not be
supported");
diff --git a/arrow-avro/src/writer/encoder.rs b/arrow-avro/src/writer/encoder.rs
index c638c2b73f..ef9e02c8fa 100644
--- a/arrow-avro/src/writer/encoder.rs
+++ b/arrow-avro/src/writer/encoder.rs
@@ -2428,13 +2428,14 @@ mod tests {
let strings = StringArray::from(vec!["hello", "world"]);
let ints = Int32Array::from(vec![10, 20, 30]);
- let union_fields = UnionFields::new(
+ let union_fields = UnionFields::try_new(
vec![0, 1],
vec![
Field::new("v_str", DataType::Utf8, true),
Field::new("v_int", DataType::Int32, true),
],
- );
+ )
+ .unwrap();
let type_ids = Buffer::from_slice_ref([0_i8, 1, 1, 0, 1]);
let offsets = Buffer::from_slice_ref([0_i32, 0, 1, 1, 2]);
@@ -2485,14 +2486,15 @@ mod tests {
let strings = StringArray::from(vec!["hello"]);
let ints = Int32Array::from(vec![10]);
- let union_fields = UnionFields::new(
+ let union_fields = UnionFields::try_new(
vec![0, 1, 2],
vec![
Field::new("v_null", DataType::Null, true),
Field::new("v_str", DataType::Utf8, true),
Field::new("v_int", DataType::Int32, true),
],
- );
+ )
+ .unwrap();
let type_ids = Buffer::from_slice_ref([0_i8, 1, 2]);
// For a null value in a dense union, no value is added to a child
array.
@@ -2979,13 +2981,14 @@ mod tests {
fn union_encoder_string_int_nonzero_type_ids() {
let strings = StringArray::from(vec!["hello", "world"]);
let ints = Int32Array::from(vec![10, 20, 30]);
- let union_fields = UnionFields::new(
+ let union_fields = UnionFields::try_new(
vec![2, 5],
vec![
Field::new("v_str", DataType::Utf8, true),
Field::new("v_int", DataType::Int32, true),
],
- );
+ )
+ .unwrap();
let type_ids = Buffer::from_slice_ref([2_i8, 5, 5, 2, 5]);
let offsets = Buffer::from_slice_ref([0_i32, 0, 1, 1, 2]);
let union_array = UnionArray::try_new(
diff --git a/arrow-avro/src/writer/mod.rs b/arrow-avro/src/writer/mod.rs
index 9b3eea1d6f..f4a2e60ed5 100644
--- a/arrow-avro/src/writer/mod.rs
+++ b/arrow-avro/src/writer/mod.rs
@@ -683,13 +683,14 @@ mod tests {
use arrow_array::UnionArray;
use arrow_buffer::Buffer;
use arrow_schema::UnionFields;
- let union_fields = UnionFields::new(
+ let union_fields = UnionFields::try_new(
vec![2, 5],
vec![
Field::new("v_str", DataType::Utf8, true),
Field::new("v_int", DataType::Int32, true),
],
- );
+ )
+ .unwrap();
let strings = StringArray::from(vec!["hello", "world"]);
let ints = Int32Array::from(vec![10, 20, 30]);
let type_ids = Buffer::from_slice_ref([2_i8, 5, 5, 2, 5]);
diff --git a/arrow-cast/src/pretty.rs b/arrow-cast/src/pretty.rs
index 1e6535bb12..fbf9e1613b 100644
--- a/arrow-cast/src/pretty.rs
+++ b/arrow-cast/src/pretty.rs
@@ -1610,10 +1610,11 @@ mod tests {
extension::EXTENSION_TYPE_NAME_KEY.to_owned(),
"my_money".to_owned(),
)]);
- let fields = UnionFields::new(
+ let fields = UnionFields::try_new(
vec![0],
vec![Field::new("income", DataType::Int32,
true).with_metadata(money_metadata.clone())],
- );
+ )
+ .unwrap();
// Create nested data and construct it with the correct metadata
let mut array_builder = UnionBuilder::new_dense();
diff --git a/arrow-flight/src/sql/metadata/sql_info.rs
b/arrow-flight/src/sql/metadata/sql_info.rs
index 18adaa877f..155946ea6c 100644
--- a/arrow-flight/src/sql/metadata/sql_info.rs
+++ b/arrow-flight/src/sql/metadata/sql_info.rs
@@ -196,10 +196,7 @@ static UNION_TYPE: Lazy<DataType> = Lazy::new(|| {
),
];
- // create "type ids", one for each type, assume they go from 0 ..
num_fields
- let type_ids: Vec<i8> = (0..fields.len()).map(|v| v as i8).collect();
-
- DataType::Union(UnionFields::new(type_ids, fields), UnionMode::Dense)
+ DataType::Union(UnionFields::from_fields(fields), UnionMode::Dense)
});
impl SqlInfoUnionBuilder {
diff --git a/arrow-ipc/src/convert.rs b/arrow-ipc/src/convert.rs
index a5139e5f21..bf14cd397e 100644
--- a/arrow-ipc/src/convert.rs
+++ b/arrow-ipc/src/convert.rs
@@ -490,8 +490,9 @@ pub(crate) fn get_data_type(field: crate::Field,
may_be_dictionary: bool) -> Dat
};
let fields = match union.typeIds() {
- None => UnionFields::new(0_i8..fields.len() as i8, fields),
- Some(ids) => UnionFields::new(ids.iter().map(|i| i as i8),
fields),
+ None => UnionFields::from_fields(fields),
+ Some(ids) => UnionFields::try_new(ids.iter().map(|i| i as i8),
fields)
+ .expect("invalid union field"),
};
DataType::Union(fields, union_mode)
@@ -1151,13 +1152,14 @@ mod tests {
Field::new(
"union<int32, utf8>",
DataType::Union(
- UnionFields::new(
+ UnionFields::try_new(
vec![2, 3], // non-default type ids
vec![
Field::new("int32", DataType::Int32, true),
Field::new("utf8", DataType::Utf8, true),
],
- ),
+ )
+ .unwrap(),
UnionMode::Dense,
),
true,
diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs
index 681f8f96cd..33363bd158 100644
--- a/arrow-ipc/src/reader.rs
+++ b/arrow-ipc/src/reader.rs
@@ -1835,13 +1835,10 @@ mod tests {
let fixed_size_list_data_type =
DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int32,
false)), 3);
- let union_fields = UnionFields::new(
- vec![0, 1],
- vec![
- Field::new("a", DataType::Int32, false),
- Field::new("b", DataType::Float64, false),
- ],
- );
+ let union_fields = UnionFields::from_fields(vec![
+ Field::new("a", DataType::Int32, false),
+ Field::new("b", DataType::Float64, false),
+ ]);
let union_data_type = DataType::Union(union_fields, UnionMode::Dense);
@@ -3107,13 +3104,14 @@ mod tests {
#[test]
fn test_validation_of_invalid_union_array() {
let array = unsafe {
- let fields = UnionFields::new(
+ let fields = UnionFields::try_new(
vec![1, 3], // typeids : type id 2 is not valid
vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Utf8, false),
],
- );
+ )
+ .unwrap();
let type_ids = ScalarBuffer::from(vec![1i8, 2, 3]); // 2 is invalid
let offsets = None;
let children: Vec<ArrayRef> = vec![
diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs
index f3ee908fae..79e78830be 100644
--- a/arrow-schema/src/datatype.rs
+++ b/arrow-schema/src/datatype.rs
@@ -994,53 +994,58 @@ mod tests {
assert!(!list_s.equals_datatype(&list_v));
let union_a = DataType::Union(
- UnionFields::new(
+ UnionFields::try_new(
vec![1, 2],
vec![
Field::new("f1", DataType::Utf8, false),
Field::new("f2", DataType::UInt8, false),
],
- ),
+ )
+ .unwrap(),
UnionMode::Sparse,
);
let union_b = DataType::Union(
- UnionFields::new(
+ UnionFields::try_new(
vec![1, 2],
vec![
Field::new("ff1", DataType::Utf8, false),
Field::new("ff2", DataType::UInt8, false),
],
- ),
+ )
+ .unwrap(),
UnionMode::Sparse,
);
let union_c = DataType::Union(
- UnionFields::new(
+ UnionFields::try_new(
vec![2, 1],
vec![
Field::new("fff2", DataType::UInt8, false),
Field::new("fff1", DataType::Utf8, false),
],
- ),
+ )
+ .unwrap(),
UnionMode::Sparse,
);
let union_d = DataType::Union(
- UnionFields::new(
+ UnionFields::try_new(
vec![2, 1],
vec![
Field::new("fff1", DataType::Int8, false),
Field::new("fff2", DataType::UInt8, false),
],
- ),
+ )
+ .unwrap(),
UnionMode::Sparse,
);
let union_e = DataType::Union(
- UnionFields::new(
+ UnionFields::try_new(
vec![1, 2],
vec![
Field::new("f1", DataType::Utf8, true),
Field::new("f2", DataType::UInt8, false),
],
- ),
+ )
+ .unwrap(),
UnionMode::Sparse,
);
@@ -1164,13 +1169,14 @@ mod tests {
fn test_union_with_duplicated_type_id() {
let type_ids = vec![1, 1];
let _union = DataType::Union(
- UnionFields::new(
+ UnionFields::try_new(
type_ids,
vec![
Field::new("f1", DataType::Int32, false),
Field::new("f2", DataType::Utf8, false),
],
- ),
+ )
+ .unwrap(),
UnionMode::Dense,
);
}
diff --git a/arrow-schema/src/datatype_parse.rs
b/arrow-schema/src/datatype_parse.rs
index 68775f9d5b..9349635151 100644
--- a/arrow-schema/src/datatype_parse.rs
+++ b/arrow-schema/src/datatype_parse.rs
@@ -484,7 +484,7 @@ impl<'a> Parser<'a> {
fields.push(field);
}
Ok(DataType::Union(
- UnionFields::new(type_ids, fields),
+ UnionFields::try_new(type_ids, fields)?,
union_mode,
))
}
@@ -1109,49 +1109,40 @@ mod test {
2,
),
DataType::Union(
- UnionFields::new(
- vec![0, 1],
- vec![
- Field::new("Int32", DataType::Int32, false),
- Field::new("Utf8", DataType::Utf8, true),
- ],
- ),
+ UnionFields::from_fields(vec![
+ Field::new("Int32", DataType::Int32, false),
+ Field::new("Utf8", DataType::Utf8, true),
+ ]),
UnionMode::Sparse,
),
DataType::Union(
- UnionFields::new(
- vec![0, 1],
- vec![
- Field::new("Int32", DataType::Int32, false),
- Field::new("Utf8", DataType::Utf8, true),
- ],
- ),
+ UnionFields::from_fields(vec![
+ Field::new("Int32", DataType::Int32, false),
+ Field::new("Utf8", DataType::Utf8, true),
+ ]),
UnionMode::Dense,
),
DataType::Union(
- UnionFields::new(
- vec![0, 1],
- vec![
- Field::new_union(
- "nested_union",
- vec![0, 1],
- vec![
- Field::new("Int32", DataType::Int32, false),
- Field::new("Utf8", DataType::Utf8, true),
- ],
- UnionMode::Dense,
- ),
- Field::new("Utf8", DataType::Utf8, true),
- ],
- ),
+ UnionFields::from_fields(vec![
+ Field::new_union(
+ "nested_union",
+ vec![0, 1],
+ vec![
+ Field::new("Int32", DataType::Int32, false),
+ Field::new("Utf8", DataType::Utf8, true),
+ ],
+ UnionMode::Dense,
+ ),
+ Field::new("Utf8", DataType::Utf8, true),
+ ]),
UnionMode::Sparse,
),
DataType::Union(
- UnionFields::new(vec![0], vec![Field::new("Int32",
DataType::Int32, false)]),
+ UnionFields::from_fields(vec![Field::new("Int32",
DataType::Int32, false)]),
UnionMode::Dense,
),
DataType::Union(
- UnionFields::new(Vec::<i8>::new(), Vec::<Field>::new()),
+ UnionFields::try_new(Vec::<i8>::new(),
Vec::<Field>::new()).unwrap(),
UnionMode::Sparse,
),
DataType::Map(Arc::new(Field::new("Int64", DataType::Int64,
true)), true),
diff --git a/arrow-schema/src/ffi.rs b/arrow-schema/src/ffi.rs
index a1a32224a4..46c622a6d3 100644
--- a/arrow-schema/src/ffi.rs
+++ b/arrow-schema/src/ffi.rs
@@ -570,7 +570,7 @@ impl TryFrom<&FFI_ArrowSchema> for DataType {
));
}
- DataType::Union(UnionFields::new(type_ids, fields),
UnionMode::Dense)
+ DataType::Union(UnionFields::try_new(type_ids,
fields)?, UnionMode::Dense)
}
// SparseUnion
["+us", extra] => {
@@ -598,7 +598,7 @@ impl TryFrom<&FFI_ArrowSchema> for DataType {
));
}
- DataType::Union(UnionFields::new(type_ids, fields),
UnionMode::Sparse)
+ DataType::Union(UnionFields::try_new(type_ids,
fields)?, UnionMode::Sparse)
}
// Timestamps in format "tts:" and "tts:America/New_York"
for no timezones and timezones resp.
diff --git a/arrow-schema/src/field.rs b/arrow-schema/src/field.rs
index 66b09915ca..3b3372a78e 100644
--- a/arrow-schema/src/field.rs
+++ b/arrow-schema/src/field.rs
@@ -342,6 +342,13 @@ impl Field {
/// - `type_ids`: the union type ids
/// - `fields`: the union fields
/// - `mode`: the union mode
+ ///
+ /// # Panics
+ ///
+ /// Panics if:
+ /// - any type ID is negative
+ /// - type IDs contain duplicates
+ /// - the number of type IDs does not equal the number of fields
pub fn new_union<S, F, T>(name: S, type_ids: T, fields: F, mode:
UnionMode) -> Self
where
S: Into<String>,
@@ -351,7 +358,10 @@ impl Field {
{
Self::new(
name,
- DataType::Union(UnionFields::new(type_ids, fields), mode),
+ DataType::Union(
+ UnionFields::try_new(type_ids, fields).expect("Invalid
UnionField"),
+ mode,
+ ),
false, // Unions cannot be nullable
)
}
@@ -1373,13 +1383,14 @@ mod test {
let field1 = Field::new(
"field1",
DataType::Union(
- UnionFields::new(
+ UnionFields::try_new(
vec![1, 2],
vec![
Field::new("field1", DataType::UInt8, true),
Field::new("field3", DataType::Utf8, false),
],
- ),
+ )
+ .unwrap(),
UnionMode::Dense,
),
true,
@@ -1387,13 +1398,14 @@ mod test {
let field2 = Field::new(
"field1",
DataType::Union(
- UnionFields::new(
+ UnionFields::try_new(
vec![1, 3],
vec![
Field::new("field1", DataType::UInt8, false),
Field::new("field3", DataType::Utf8, false),
],
- ),
+ )
+ .unwrap(),
UnionMode::Dense,
),
true,
@@ -1404,13 +1416,14 @@ mod test {
let field1 = Field::new(
"field1",
DataType::Union(
- UnionFields::new(
+ UnionFields::try_new(
vec![1, 2],
vec![
Field::new("field1", DataType::UInt8, true),
Field::new("field3", DataType::Utf8, false),
],
- ),
+ )
+ .unwrap(),
UnionMode::Dense,
),
true,
@@ -1418,13 +1431,14 @@ mod test {
let field2 = Field::new(
"field1",
DataType::Union(
- UnionFields::new(
+ UnionFields::try_new(
vec![1, 2],
vec![
Field::new("field1", DataType::UInt8, false),
Field::new("field3", DataType::Utf8, false),
],
- ),
+ )
+ .unwrap(),
UnionMode::Dense,
),
true,
diff --git a/arrow-schema/src/fields.rs b/arrow-schema/src/fields.rs
index a488f88b6c..93638181d9 100644
--- a/arrow-schema/src/fields.rs
+++ b/arrow-schema/src/fields.rs
@@ -355,12 +355,204 @@ impl UnionFields {
///
/// See <https://arrow.apache.org/docs/format/Columnar.html#union-layout>
///
+ /// # Errors
+ ///
+ /// This function returns an error if:
+ /// - Any type_id appears more than once (duplicate type ids)
+ /// - The type_ids are duplicated
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use arrow_schema::{DataType, Field, UnionFields};
+ /// // Create a new UnionFields with type id mapping
+ /// // 1 -> DataType::UInt8
+ /// // 3 -> DataType::Utf8
+ /// let result = UnionFields::try_new(
+ /// vec![1, 3],
+ /// vec![
+ /// Field::new("field1", DataType::UInt8, false),
+ /// Field::new("field3", DataType::Utf8, false),
+ /// ],
+ /// );
+ /// assert!(result.is_ok());
+ ///
+ /// // This will fail due to duplicate type ids
+ /// let result = UnionFields::try_new(
+ /// vec![1, 1],
+ /// vec![
+ /// Field::new("field1", DataType::UInt8, false),
+ /// Field::new("field2", DataType::Utf8, false),
+ /// ],
+ /// );
+ /// assert!(result.is_err());
+ /// ```
+ pub fn try_new<F, T>(type_ids: T, fields: F) -> Result<Self, ArrowError>
+ where
+ F: IntoIterator,
+ F::Item: Into<FieldRef>,
+ T: IntoIterator<Item = i8>,
+ {
+ let mut type_ids_iter = type_ids.into_iter();
+ let mut fields_iter = fields.into_iter().map(Into::into);
+
+ let mut seen_type_ids = 0u128;
+
+ let mut out = Vec::new();
+
+ loop {
+ match (type_ids_iter.next(), fields_iter.next()) {
+ (None, None) => return Ok(Self(out.into())),
+ (Some(type_id), Some(field)) => {
+ // check type id is non-negative
+ if type_id < 0 {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "type ids must be non-negative: {type_id}"
+ )));
+ }
+
+ // check type id uniqueness
+ let mask = 1_u128 << type_id;
+ if (seen_type_ids & mask) != 0 {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "duplicate type id: {type_id}"
+ )));
+ }
+
+ seen_type_ids |= mask;
+
+ out.push((type_id, field));
+ }
+ (None, Some(_)) => {
+ return Err(ArrowError::InvalidArgumentError(
+ "fields iterator has more elements than type_ids
iterator".to_string(),
+ ));
+ }
+ (Some(_), None) => {
+ return Err(ArrowError::InvalidArgumentError(
+ "type_ids iterator has more elements than fields
iterator".to_string(),
+ ));
+ }
+ }
+ }
+ }
+
+ /// Create a new [`UnionFields`] from a collection of fields with
automatically
+ /// assigned type IDs starting from 0.
+ ///
+ /// The type IDs are assigned in increasing order: 0, 1, 2, 3, etc.
+ ///
+ /// See <https://arrow.apache.org/docs/format/Columnar.html#union-layout>
+ ///
+ /// # Panics
+ ///
+ /// Panics if the number of fields exceeds 127 (the maximum value for i8
type IDs).
+ ///
+ /// If you want to avoid panics, use [`UnionFields::try_from_fields`]
instead, which
+ /// returns a `Result`.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use arrow_schema::{DataType, Field, UnionFields};
+ /// // Create a new UnionFields with automatic type id assignment
+ /// // 0 -> DataType::UInt8
+ /// // 1 -> DataType::Utf8
+ /// let union_fields = UnionFields::from_fields(vec![
+ /// Field::new("field1", DataType::UInt8, false),
+ /// Field::new("field2", DataType::Utf8, false),
+ /// ]);
+ /// assert_eq!(union_fields.len(), 2);
+ /// ```
+ pub fn from_fields<F>(fields: F) -> Self
+ where
+ F: IntoIterator,
+ F::Item: Into<FieldRef>,
+ {
+ fields
+ .into_iter()
+ .enumerate()
+ .map(|(i, field)| {
+ let id = i8::try_from(i).expect("UnionFields cannot contain
more than 128 fields");
+
+ (id, field.into())
+ })
+ .collect()
+ }
+
+ /// Create a new [`UnionFields`] from a collection of fields with
automatically
+ /// assigned type IDs starting from 0.
+ ///
+ /// The type IDs are assigned in increasing order: 0, 1, 2, 3, etc.
+ ///
+ /// This is the non-panicking version of [`UnionFields::from_fields`].
+ ///
+ /// See <https://arrow.apache.org/docs/format/Columnar.html#union-layout>
+ ///
+ /// # Errors
+ ///
+ /// Returns an error if the number of fields exceeds 127 (the maximum
value for i8 type IDs).
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use arrow_schema::{DataType, Field, UnionFields};
+ /// // Create a new UnionFields with automatic type id assignment
+ /// // 0 -> DataType::UInt8
+ /// // 1 -> DataType::Utf8
+ /// let result = UnionFields::try_from_fields(vec![
+ /// Field::new("field1", DataType::UInt8, false),
+ /// Field::new("field2", DataType::Utf8, false),
+ /// ]);
+ /// assert!(result.is_ok());
+ /// assert_eq!(result.unwrap().len(), 2);
+ ///
+ /// // This will fail with too many fields
+ /// let many_fields: Vec<_> = (0..200)
+ /// .map(|i| Field::new(format!("field{}", i), DataType::Int32, false))
+ /// .collect();
+ /// let result = UnionFields::try_from_fields(many_fields);
+ /// assert!(result.is_err());
+ /// ```
+ pub fn try_from_fields<F>(fields: F) -> Result<Self, ArrowError>
+ where
+ F: IntoIterator,
+ F::Item: Into<FieldRef>,
+ {
+ let mut out = Vec::with_capacity(i8::MAX as usize + 1);
+
+ for (i, field) in fields.into_iter().enumerate() {
+ let id = i8::try_from(i).map_err(|_| {
+ ArrowError::InvalidArgumentError(
+ "UnionFields cannot contain more than 128 fields".into(),
+ )
+ })?;
+
+ out.push((id, field.into()));
+ }
+
+ Ok(Self(out.into()))
+ }
+
+ /// Create a new [`UnionFields`] from a [`Fields`] and array of type_ids
+ ///
+ /// See <https://arrow.apache.org/docs/format/Columnar.html#union-layout>
+ ///
+ /// # Deprecated
+ ///
+ /// Use [`UnionFields::try_new`] instead. This method panics on invalid
input,
+ /// while `try_new` returns a `Result`.
+ ///
+ /// # Panics
+ ///
+ /// Panics if any type_id appears more than once (duplicate type ids).
+ ///
/// ```
/// use arrow_schema::{DataType, Field, UnionFields};
/// // Create a new UnionFields with type id mapping
/// // 1 -> DataType::UInt8
/// // 3 -> DataType::Utf8
- /// UnionFields::new(
+ /// UnionFields::try_new(
/// vec![1, 3],
/// vec![
/// Field::new("field1", DataType::UInt8, false),
@@ -368,6 +560,7 @@ impl UnionFields {
/// ],
/// );
/// ```
+ #[deprecated(since = "57.0.0", note = "Use `try_new` instead")]
pub fn new<F, T>(type_ids: T, fields: F) -> Self
where
F: IntoIterator,
@@ -486,7 +679,6 @@ impl UnionFields {
impl FromIterator<(i8, FieldRef)> for UnionFields {
fn from_iter<T: IntoIterator<Item = (i8, FieldRef)>>(iter: T) -> Self {
- // TODO: Should this validate type IDs are unique (#3982)
Self(iter.into_iter().collect())
}
}
@@ -541,13 +733,14 @@ mod tests {
Field::new(
"h",
DataType::Union(
- UnionFields::new(
+ UnionFields::try_new(
vec![1, 3],
vec![
Field::new("field1", DataType::UInt8, false),
Field::new("field3", DataType::Utf8, false),
],
- ),
+ )
+ .unwrap(),
UnionMode::Dense,
),
true,
@@ -605,7 +798,8 @@ mod tests {
assert_eq!(r[0], fields[7]);
let union = DataType::Union(
- UnionFields::new(vec![1], vec![Field::new("field1", DataType::UInt8, false)]),
+ UnionFields::try_new(vec![1], vec![Field::new("field1", DataType::UInt8, false)])
+ .unwrap(),
UnionMode::Dense,
);
@@ -621,4 +815,228 @@ mod tests {
let r = fields.try_filter_leaves(|_, _| Err(ArrowError::SchemaError("error".to_string())));
assert!(r.is_err());
}
+
+ #[test]
+ fn test_union_fields_try_new_valid() {
+ let res = UnionFields::try_new(
+ vec![1, 6, 7],
+ vec![
+ Field::new("f1", DataType::UInt8, false),
+ Field::new("f6", DataType::Utf8, false),
+ Field::new("f7", DataType::Int32, true),
+ ],
+ );
+ assert!(res.is_ok());
+ let union_fields = res.unwrap();
+ assert_eq!(union_fields.len(), 3);
+ assert_eq!(
+ union_fields.iter().map(|(id, _)| id).collect::<Vec<_>>(),
+ vec![1, 6, 7]
+ );
+ }
+
+ #[test]
+ fn test_union_fields_try_new_empty() {
+ let res = UnionFields::try_new(Vec::<i8>::new(), Vec::<Field>::new());
+ assert!(res.is_ok());
+ assert!(res.unwrap().is_empty());
+ }
+
+ #[test]
+ fn test_union_fields_try_new_duplicate_type_id() {
+ let res = UnionFields::try_new(
+ vec![1, 1],
+ vec![
+ Field::new("f1", DataType::UInt8, false),
+ Field::new("f2", DataType::Utf8, false),
+ ],
+ );
+ assert!(res.is_err());
+ assert!(
+ res.unwrap_err()
+ .to_string()
+ .contains("duplicate type id: 1")
+ );
+ }
+
+ #[test]
+ fn test_union_fields_try_new_duplicate_field() {
+ let field = Field::new("field", DataType::UInt8, false);
+ let res = UnionFields::try_new(vec![1, 2], vec![field.clone(), field]);
+ assert!(res.is_ok());
+ }
+
+ #[test]
+ fn test_union_fields_try_new_more_type_ids() {
+ let res = UnionFields::try_new(
+ vec![1, 2, 3],
+ vec![
+ Field::new("f1", DataType::UInt8, false),
+ Field::new("f2", DataType::Utf8, false),
+ ],
+ );
+ assert!(res.is_err());
+ assert!(
+ res.unwrap_err()
+ .to_string()
+ .contains("type_ids iterator has more elements")
+ );
+ }
+
+ #[test]
+ fn test_union_fields_try_new_more_fields() {
+ let res = UnionFields::try_new(
+ vec![1, 2],
+ vec![
+ Field::new("f1", DataType::UInt8, false),
+ Field::new("f2", DataType::Utf8, false),
+ Field::new("f3", DataType::Int32, true),
+ ],
+ );
+ assert!(res.is_err());
+ assert!(
+ res.unwrap_err()
+ .to_string()
+ .contains("fields iterator has more elements")
+ );
+ }
+
+ #[test]
+ fn test_union_fields_try_new_negative_type_ids() {
+ let res = UnionFields::try_new(
+ vec![-128, -1, 0, 127],
+ vec![
+ Field::new("field_min", DataType::UInt8, false),
+ Field::new("field_neg", DataType::Utf8, false),
+ Field::new("field_zero", DataType::Int32, true),
+ Field::new("field_max", DataType::Boolean, false),
+ ],
+ );
+ assert!(res.is_err());
+ assert!(
+ res.unwrap_err()
+ .to_string()
+ .contains("type ids must be non-negative")
+ )
+ }
+
+ #[test]
+ fn test_union_fields_try_new_complex_types() {
+ let res = UnionFields::try_new(
+ vec![0, 1, 2],
+ vec![
+ Field::new(
+ "struct_field",
+ DataType::Struct(Fields::from(vec![
+ Field::new("a", DataType::Int32, false),
+ Field::new("b", DataType::Utf8, true),
+ ])),
+ false,
+ ),
+ Field::new_list(
+ "list_field",
+ Field::new("item", DataType::Float64, true),
+ true,
+ ),
+ Field::new(
+ "dict_field",
+ DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
+ false,
+ ),
+ ],
+ );
+ assert!(res.is_ok());
+ assert_eq!(res.unwrap().len(), 3);
+ }
+
+ #[test]
+ fn test_union_fields_try_new_single_field() {
+ let res = UnionFields::try_new(
+ vec![42],
+ vec![Field::new("only_field", DataType::Int64, false)],
+ );
+ assert!(res.is_ok());
+ let union_fields = res.unwrap();
+ assert_eq!(union_fields.len(), 1);
+ assert_eq!(union_fields.iter().next().unwrap().0, 42);
+ }
+
+ #[test]
+ fn test_union_fields_try_from_fields_empty() {
+ let res = UnionFields::try_from_fields(Vec::<Field>::new());
+ assert!(res.is_ok());
+ assert!(res.unwrap().is_empty());
+ }
+
+ #[test]
+ fn test_union_fields_try_from_fields_single() {
+ let res = UnionFields::try_from_fields(vec![Field::new("only", DataType::Int64, false)]);
+ assert!(res.is_ok());
+ let union_fields = res.unwrap();
+ assert_eq!(union_fields.len(), 1);
+ assert_eq!(union_fields.iter().next().unwrap().0, 0);
+ }
+
+ #[test]
+ fn test_union_fields_try_from_fields_too_many() {
+ let many_fields: Vec<_> = (0..200)
+ .map(|i| Field::new(format!("field{}", i), DataType::Int32, false))
+ .collect();
+ let res = UnionFields::try_from_fields(many_fields);
+ assert!(res.is_err());
+ assert!(
+ res.unwrap_err()
+ .to_string()
+ .contains("UnionFields cannot contain more than 128 fields")
+ );
+ }
+
+ #[test]
+ fn test_union_fields_try_from_fields_max_valid() {
+ let fields: Vec<_> = (0..=i8::MAX)
+ .map(|i| Field::new(format!("field{}", i), DataType::Int32, false))
+ .collect();
+ let res = UnionFields::try_from_fields(fields);
+ assert!(res.is_ok());
+ let union_fields = res.unwrap();
+ assert_eq!(union_fields.len(), 128);
+ assert_eq!(union_fields.iter().map(|(id, _)| id).min().unwrap(), 0);
+ assert_eq!(union_fields.iter().map(|(id, _)| id).max().unwrap(), 127);
+ }
+
+ #[test]
+ fn test_union_fields_try_from_fields_over_max() {
+ // 129 fields should fail
+ let fields: Vec<_> = (0..129)
+ .map(|i| Field::new(format!("field{}", i), DataType::Int32, false))
+ .collect();
+ let res = UnionFields::try_from_fields(fields);
+ assert!(res.is_err());
+ }
+
+ #[test]
+ fn test_union_fields_try_from_fields_complex_types() {
+ let res = UnionFields::try_from_fields(vec![
+ Field::new(
+ "struct_field",
+ DataType::Struct(Fields::from(vec![
+ Field::new("a", DataType::Int32, false),
+ Field::new("b", DataType::Utf8, true),
+ ])),
+ false,
+ ),
+ Field::new_list(
+ "list_field",
+ Field::new("item", DataType::Float64, true),
+ true,
+ ),
+ Field::new(
+ "dict_field",
+ DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
+ false,
+ ),
+ ]);
+ assert!(res.is_ok());
+ assert_eq!(res.unwrap().len(), 3);
+ }
}
diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs
index edf50a6a2f..7459018ea4 100644
--- a/arrow-select/src/take.rs
+++ b/arrow-select/src/take.rs
@@ -2636,7 +2636,7 @@ mod tests {
#[test]
fn test_take_union_dense_all_match_issue_6206() {
- let fields = UnionFields::new(vec![0], vec![Field::new("a", DataType::Int64, false)]);
+ let fields = UnionFields::from_fields(vec![Field::new("a", DataType::Int64, false)]);
let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
let array = UnionArray::try_new(
diff --git a/arrow-select/src/union_extract.rs b/arrow-select/src/union_extract.rs
index 0460e59c45..3accecc359 100644
--- a/arrow-select/src/union_extract.rs
+++ b/arrow-select/src/union_extract.rs
@@ -53,13 +53,13 @@ use std::sync::Arc;
/// # use arrow_schema::{DataType, Field, UnionFields};
/// # use arrow_array::{UnionArray, StringArray, Int32Array};
/// # use arrow_select::union_extract::union_extract;
-/// let fields = UnionFields::new(
+/// let fields = UnionFields::try_new(
/// [1, 3],
/// [
/// Field::new("A", DataType::Int32, true),
/// Field::new("B", DataType::Utf8, true)
/// ]
-/// );
+/// ).unwrap();
///
/// let union = UnionArray::try_new(
/// fields,
@@ -543,17 +543,18 @@ mod tests {
}
fn str1() -> UnionFields {
- UnionFields::new(vec![1], vec![Field::new("str", DataType::Utf8, true)])
+ UnionFields::try_new(vec![1], vec![Field::new("str", DataType::Utf8, true)]).unwrap()
}
fn str1_int3() -> UnionFields {
- UnionFields::new(
+ UnionFields::try_new(
vec![1, 3],
vec![
Field::new("str", DataType::Utf8, true),
Field::new("int", DataType::Int32, true),
],
)
+ .unwrap()
}
#[test]
@@ -599,13 +600,14 @@ mod tests {
fn sparse_1_3a_null_target() {
let union = UnionArray::try_new(
// multiple fields
- UnionFields::new(
+ UnionFields::try_new(
vec![1, 3],
vec![
Field::new("str", DataType::Utf8, true),
Field::new("null", DataType::Null, true), // target type is Null
],
- ),
+ )
+ .unwrap(),
ScalarBuffer::from(vec![1]), //not empty
None, // sparse
vec![
@@ -682,13 +684,14 @@ mod tests {
}
fn str1_union3(union3_datatype: DataType) -> UnionFields {
- UnionFields::new(
+ UnionFields::try_new(
vec![1, 3],
vec![
Field::new("str", DataType::Utf8, true),
Field::new("union", union3_datatype, true),
],
)
+ .unwrap()
}
#[test]
diff --git a/arrow/tests/array_cast.rs b/arrow/tests/array_cast.rs
index 3dcbfd970a..0e3b9c597b 100644
--- a/arrow/tests/array_cast.rs
+++ b/arrow/tests/array_cast.rs
@@ -552,13 +552,14 @@ fn get_all_types() -> Vec<DataType> {
Field::new("f2", DataType::Utf8, true),
])),
Union(
- UnionFields::new(
+ UnionFields::try_new(
vec![0, 1],
vec![
Field::new("f1", DataType::Int32, false),
Field::new("f2", DataType::Utf8, true),
],
- ),
+ )
+ .unwrap(),
UnionMode::Dense,
),
Decimal128(38, 0),
diff --git a/arrow/tests/array_validation.rs b/arrow/tests/array_validation.rs
index 66a7b7c452..62e7241f5e 100644
--- a/arrow/tests/array_validation.rs
+++ b/arrow/tests/array_validation.rs
@@ -825,13 +825,14 @@ fn test_validate_union_different_types() {
ArrayData::try_new(
DataType::Union(
- UnionFields::new(
+ UnionFields::try_new(
vec![0, 1],
vec![
Field::new("field1", DataType::Int32, true),
Field::new("field2", DataType::Int64, true), // data is int32
],
- ),
+ )
+ .unwrap(),
UnionMode::Sparse,
),
2,
@@ -858,13 +859,14 @@ fn test_validate_union_sparse_different_child_len() {
ArrayData::try_new(
DataType::Union(
- UnionFields::new(
+ UnionFields::try_new(
vec![0, 1],
vec![
Field::new("field1", DataType::Int32, true),
Field::new("field2", DataType::Int64, true),
],
- ),
+ )
+ .unwrap(),
UnionMode::Sparse,
),
2,
@@ -887,13 +889,14 @@ fn test_validate_union_dense_without_offsets() {
ArrayData::try_new(
DataType::Union(
- UnionFields::new(
+ UnionFields::try_new(
vec![0, 1],
vec![
Field::new("field1", DataType::Int32, true),
Field::new("field2", DataType::Int64, true),
],
- ),
+ )
+ .unwrap(),
UnionMode::Dense,
),
2,
@@ -917,13 +920,14 @@ fn test_validate_union_dense_with_bad_len() {
ArrayData::try_new(
DataType::Union(
- UnionFields::new(
+ UnionFields::try_new(
vec![0, 1],
vec![
Field::new("field1", DataType::Int32, true),
Field::new("field2", DataType::Int64, true),
],
- ),
+ )
+ .unwrap(),
UnionMode::Dense,
),
2,
diff --git a/parquet-variant-compute/src/arrow_to_variant.rs b/parquet-variant-compute/src/arrow_to_variant.rs
index 5db1530025..3009b602cb 100644
--- a/parquet-variant-compute/src/arrow_to_variant.rs
+++ b/parquet-variant-compute/src/arrow_to_variant.rs
@@ -1467,14 +1467,11 @@ mod tests {
let string_array = StringArray::from(vec![None, None, Some("hello"), None, None, None]);
let type_ids = [0, 1, 2, 1, 0, 0].into_iter().collect::<ScalarBuffer<i8>>();
- let union_fields = UnionFields::new(
- vec![0, 1, 2],
- vec![
- Field::new("int_field", DataType::Int32, false),
- Field::new("float_field", DataType::Float64, false),
- Field::new("string_field", DataType::Utf8, false),
- ],
- );
+ let union_fields = UnionFields::from_fields(vec![
+ Field::new("int_field", DataType::Int32, false),
+ Field::new("float_field", DataType::Float64, false),
+ Field::new("string_field", DataType::Utf8, false),
+ ]);
let children: Vec<Arc<dyn Array>> = vec![
Arc::new(int_array),
@@ -1515,14 +1512,11 @@ mod tests {
.into_iter()
.collect::<ScalarBuffer<i32>>();
- let union_fields = UnionFields::new(
- vec![0, 1, 2],
- vec![
- Field::new("int_field", DataType::Int32, false),
- Field::new("float_field", DataType::Float64, false),
- Field::new("string_field", DataType::Utf8, false),
- ],
- );
+ let union_fields = UnionFields::from_fields(vec![
+ Field::new("int_field", DataType::Int32, false),
+ Field::new("float_field", DataType::Float64, false),
+ Field::new("string_field", DataType::Utf8, false),
+ ]);
let children: Vec<Arc<dyn Array>> = vec![
Arc::new(int_array),
@@ -1571,13 +1565,14 @@ mod tests {
let string_array = StringArray::from(vec![None, Some("test")]);
let type_ids = [1, 3].into_iter().collect::<ScalarBuffer<i8>>();
- let union_fields = UnionFields::new(
+ let union_fields = UnionFields::try_new(
vec![1, 3], // Non-contiguous type IDs
vec![
Field::new("int_field", DataType::Int32, false),
Field::new("string_field", DataType::Utf8, false),
],
- );
+ )
+ .unwrap();
let children: Vec<Arc<dyn Array>> = vec![Arc::new(int_array), Arc::new(string_array)];
diff --git a/parquet-variant-compute/src/cast_to_variant.rs b/parquet-variant-compute/src/cast_to_variant.rs
index 4f400a5f7b..c3ffc7a42c 100644
--- a/parquet-variant-compute/src/cast_to_variant.rs
+++ b/parquet-variant-compute/src/cast_to_variant.rs
@@ -2065,14 +2065,11 @@ mod tests {
let string_array = StringArray::from(vec![None, None, Some("hello"), None, None, None]);
let type_ids = [0, 1, 2, 1, 0, 0].into_iter().collect::<ScalarBuffer<i8>>();
- let union_fields = UnionFields::new(
- vec![0, 1, 2],
- vec![
- Field::new("int_field", DataType::Int32, false),
- Field::new("float_field", DataType::Float64, false),
- Field::new("string_field", DataType::Utf8, false),
- ],
- );
+ let union_fields = UnionFields::from_fields(vec![
+ Field::new("int_field", DataType::Int32, false),
+ Field::new("float_field", DataType::Float64, false),
+ Field::new("string_field", DataType::Utf8, false),
+ ]);
let children: Vec<Arc<dyn Array>> = vec![
Arc::new(int_array),
@@ -2112,14 +2109,11 @@ mod tests {
.into_iter()
.collect::<ScalarBuffer<i32>>();
- let union_fields = UnionFields::new(
- vec![0, 1, 2],
- vec![
- Field::new("int_field", DataType::Int32, false),
- Field::new("float_field", DataType::Float64, false),
- Field::new("string_field", DataType::Utf8, false),
- ],
- );
+ let union_fields = UnionFields::from_fields(vec![
+ Field::new("int_field", DataType::Int32, false),
+ Field::new("float_field", DataType::Float64, false),
+ Field::new("string_field", DataType::Utf8, false),
+ ]);
let children: Vec<Arc<dyn Array>> = vec![
Arc::new(int_array),
diff --git a/parquet-variant-compute/src/shred_variant.rs b/parquet-variant-compute/src/shred_variant.rs
index 7367336379..45e7fc95c9 100644
--- a/parquet-variant-compute/src/shred_variant.rs
+++ b/parquet-variant-compute/src/shred_variant.rs
@@ -1325,13 +1325,10 @@ mod tests {
DataType::LargeUtf8,
DataType::FixedSizeBinary(17),
DataType::Union(
- UnionFields::new(
- vec![0_i8, 1_i8],
- vec![
- Field::new("int_field", DataType::Int32, false),
- Field::new("str_field", DataType::Utf8, true),
- ],
- ),
+ UnionFields::from_fields(vec![
+ Field::new("int_field", DataType::Int32, false),
+ Field::new("str_field", DataType::Utf8, true),
+ ]),
UnionMode::Dense,
),
DataType::Map(
diff --git a/parquet-variant-compute/src/variant_to_arrow.rs b/parquet-variant-compute/src/variant_to_arrow.rs
index 7d4c427fa4..57d9944bb5 100644
--- a/parquet-variant-compute/src/variant_to_arrow.rs
+++ b/parquet-variant-compute/src/variant_to_arrow.rs
@@ -798,7 +798,8 @@ mod tests {
true,
));
let union_fields =
- UnionFields::new(vec![1], vec![Field::new("child", DataType::Int32, true)]);
+ UnionFields::try_new(vec![1], vec![Field::new("child", DataType::Int32, true)])
+ .unwrap();
let run_ends_field = Arc::new(Field::new("run_ends", DataType::Int32, false));
let ree_values_field = Arc::new(Field::new("values", DataType::Utf8, true));