This is an automated email from the ASF dual-hosted git repository. alamb pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push: new 0c4e58f9d8 [Variant]: Implement `DataType::Union` support for `cast_to_variant` kernel (#8196) 0c4e58f9d8 is described below commit 0c4e58f9d8e499237b1e8bd2249a9b06deeae378 Author: Liam Bao <liam.zw....@gmail.com> AuthorDate: Sat Aug 23 07:04:37 2025 -0400 [Variant]: Implement `DataType::Union` support for `cast_to_variant` kernel (#8196) # Which issue does this PR close? - Closes #8195. # Rationale for this change `cast_to_variant` did not handle `DataType::Union` inputs; this adds support for both sparse and dense unions. # What changes are included in this PR? Implement `DataType::Union` for `cast_to_variant` # Are these changes tested? Yes # Are there any user-facing changes? New cast type supported --------- Co-authored-by: Andrew Lamb <and...@nerdnetworks.org> --- parquet-variant-compute/src/cast_to_variant.rs | 198 +++++++++++++++++++++---- 1 file changed, 170 insertions(+), 28 deletions(-) diff --git a/parquet-variant-compute/src/cast_to_variant.rs b/parquet-variant-compute/src/cast_to_variant.rs index 3850579946..782e336b09 100644 --- a/parquet-variant-compute/src/cast_to_variant.rs +++ b/parquet-variant-compute/src/cast_to_variant.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+use std::collections::HashMap; use std::sync::Arc; use crate::type_conversion::{ @@ -39,7 +40,7 @@ use arrow::temporal_conversions::{ timestamp_ms_to_datetime, timestamp_ns_to_datetime, timestamp_s_to_datetime, timestamp_us_to_datetime, }; -use arrow_schema::{ArrowError, DataType, TimeUnit}; +use arrow_schema::{ArrowError, DataType, TimeUnit, UnionFields}; use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc}; use parquet_variant::{ Variant, VariantBuilder, VariantDecimal16, VariantDecimal4, VariantDecimal8, @@ -379,6 +380,9 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> { builder.append_variant(variant); } } + DataType::Union(fields, _) => { + convert_union(fields, input, &mut builder)?; + } DataType::Date32 => { generic_conversion_array!( Date32Type, @@ -398,9 +402,9 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> { ); } DataType::RunEndEncoded(run_ends, _) => match run_ends.data_type() { - DataType::Int16 => process_run_end_encoded::<Int16Type>(input, &mut builder)?, - DataType::Int32 => process_run_end_encoded::<Int32Type>(input, &mut builder)?, - DataType::Int64 => process_run_end_encoded::<Int64Type>(input, &mut builder)?, + DataType::Int16 => convert_run_end_encoded::<Int16Type>(input, &mut builder)?, + DataType::Int32 => convert_run_end_encoded::<Int32Type>(input, &mut builder)?, + DataType::Int64 => convert_run_end_encoded::<Int64Type>(input, &mut builder)?, _ => { return Err(ArrowError::CastError(format!( "Unsupported run ends type: {:?}", @@ -409,25 +413,7 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> { } }, DataType::Dictionary(_, _) => { - let dict_array = input.as_any_dictionary(); - let values_variant_array = cast_to_variant(dict_array.values().as_ref())?; - let normalized_keys = dict_array.normalized_keys(); - let keys = dict_array.keys(); - - for (i, key_idx) in normalized_keys.iter().enumerate() { - if keys.is_null(i) { - 
builder.append_null(); - continue; - } - - if values_variant_array.is_null(*key_idx) { - builder.append_null(); - continue; - } - - let value = values_variant_array.value(*key_idx); - builder.append_variant(value); - } + convert_dictionary_encoded(input, &mut builder)?; } DataType::Map(field, _) => match field.data_type() { @@ -559,8 +545,45 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> { Ok(builder.build()) } -/// Generic function to process run-end encoded arrays -fn process_run_end_encoded<R: RunEndIndexType>( +/// Convert union arrays +fn convert_union( + fields: &UnionFields, + input: &dyn Array, + builder: &mut VariantArrayBuilder, +) -> Result<(), ArrowError> { + let union_array = input.as_union(); + + // Convert each child array to variant arrays + let mut child_variant_arrays = HashMap::new(); + for (type_id, _) in fields.iter() { + let child_array = union_array.child(type_id); + let child_variant_array = cast_to_variant(child_array.as_ref())?; + child_variant_arrays.insert(type_id, child_variant_array); + } + + // Process each element in the union array + for i in 0..union_array.len() { + let type_id = union_array.type_id(i); + let value_offset = union_array.value_offset(i); + + if let Some(child_variant_array) = child_variant_arrays.get(&type_id) { + if child_variant_array.is_null(value_offset) { + builder.append_null(); + } else { + let value = child_variant_array.value(value_offset); + builder.append_variant(value); + } + } else { + // This should not happen in a valid union, but handle gracefully + builder.append_null(); + } + } + + Ok(()) +} + +/// Generic function to convert run-end encoded arrays +fn convert_run_end_encoded<R: RunEndIndexType>( input: &dyn Array, builder: &mut VariantArrayBuilder, ) -> Result<(), ArrowError> { @@ -594,6 +617,34 @@ fn process_run_end_encoded<R: RunEndIndexType>( Ok(()) } +/// Convert dictionary encoded arrays +fn convert_dictionary_encoded( + input: &dyn Array, + builder: &mut 
VariantArrayBuilder, +) -> Result<(), ArrowError> { + let dict_array = input.as_any_dictionary(); + let values_variant_array = cast_to_variant(dict_array.values().as_ref())?; + let normalized_keys = dict_array.normalized_keys(); + let keys = dict_array.keys(); + + for (i, key_idx) in normalized_keys.iter().enumerate() { + if keys.is_null(i) { + builder.append_null(); + continue; + } + + if values_variant_array.is_null(*key_idx) { + builder.append_null(); + continue; + } + + let value = values_variant_array.value(*key_idx); + builder.append_variant(value); + } + + Ok(()) +} + // TODO do we need a cast_with_options to allow specifying conversion behavior, // e.g. how to handle overflows, whether to convert to Variant::Null or return // an error, etc. ? @@ -609,10 +660,10 @@ mod tests { LargeStringArray, ListArray, MapArray, NullArray, StringArray, StringRunBuilder, StringViewArray, StructArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, UInt16Array, UInt32Array, UInt64Array, - UInt8Array, + UInt8Array, UnionArray, }; - use arrow::buffer::{NullBuffer, OffsetBuffer}; - use arrow_schema::{Field, Fields}; + use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; + use arrow_schema::{DataType, Field, Fields, UnionFields}; use arrow_schema::{ DECIMAL128_MAX_PRECISION, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, }; @@ -1637,6 +1688,97 @@ mod tests { assert_eq!(obj4.get("age"), None); } + #[test] + fn test_cast_to_variant_union_sparse() { + // Create a sparse union array with mixed types (int, float, string) + let int_array = Int32Array::from(vec![Some(1), None, None, None, Some(34), None]); + let float_array = Float64Array::from(vec![None, Some(3.2), None, Some(32.5), None, None]); + let string_array = StringArray::from(vec![None, None, Some("hello"), None, None, None]); + let type_ids = [0, 1, 2, 1, 0, 0].into_iter().collect::<ScalarBuffer<i8>>(); + + let union_fields = UnionFields::new( + vec![0, 1, 2], + vec![ + 
Field::new("int_field", DataType::Int32, false), + Field::new("float_field", DataType::Float64, false), + Field::new("string_field", DataType::Utf8, false), + ], + ); + + let children: Vec<Arc<dyn Array>> = vec![ + Arc::new(int_array), + Arc::new(float_array), + Arc::new(string_array), + ]; + + let union_array = UnionArray::try_new( + union_fields, + type_ids, + None, // Sparse union + children, + ) + .unwrap(); + + run_test( + Arc::new(union_array), + vec![ + Some(Variant::Int32(1)), + Some(Variant::Double(3.2)), + Some(Variant::from("hello")), + Some(Variant::Double(32.5)), + Some(Variant::Int32(34)), + None, + ], + ); + } + + #[test] + fn test_cast_to_variant_union_dense() { + // Create a dense union array with mixed types (int, float, string) + let int_array = Int32Array::from(vec![Some(1), Some(34), None]); + let float_array = Float64Array::from(vec![3.2, 32.5]); + let string_array = StringArray::from(vec!["hello"]); + let type_ids = [0, 1, 2, 1, 0, 0].into_iter().collect::<ScalarBuffer<i8>>(); + let offsets = [0, 0, 0, 1, 1, 2] + .into_iter() + .collect::<ScalarBuffer<i32>>(); + + let union_fields = UnionFields::new( + vec![0, 1, 2], + vec![ + Field::new("int_field", DataType::Int32, false), + Field::new("float_field", DataType::Float64, false), + Field::new("string_field", DataType::Utf8, false), + ], + ); + + let children: Vec<Arc<dyn Array>> = vec![ + Arc::new(int_array), + Arc::new(float_array), + Arc::new(string_array), + ]; + + let union_array = UnionArray::try_new( + union_fields, + type_ids, + Some(offsets), // Dense union + children, + ) + .unwrap(); + + run_test( + Arc::new(union_array), + vec![ + Some(Variant::Int32(1)), + Some(Variant::Double(3.2)), + Some(Variant::from("hello")), + Some(Variant::Double(32.5)), + Some(Variant::Int32(34)), + None, + ], + ); + } + #[test] fn test_cast_to_variant_struct_with_nulls() { // Test struct with null values at the struct level