This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 0c4e58f9d8 [Variant]: Implement `DataType::Union` support for `cast_to_variant` kernel (#8196)
0c4e58f9d8 is described below

commit 0c4e58f9d8e499237b1e8bd2249a9b06deeae378
Author: Liam Bao <liam.zw....@gmail.com>
AuthorDate: Sat Aug 23 07:04:37 2025 -0400

    [Variant]: Implement `DataType::Union` support for `cast_to_variant` kernel (#8196)
    
    # Which issue does this PR close?
    
    - Closes #8195.
    
    # Rationale for this change
    
    `cast_to_variant` did not support arrays of `DataType::Union` (see #8195).
    
    # What changes are included in this PR?
    
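    - Implement `DataType::Union` (sparse and dense) support for `cast_to_variant`
    - Extract dictionary-array handling into a `convert_dictionary_encoded` helper
    - Rename `process_run_end_encoded` to `convert_run_end_encoded` for consistent naming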
    
    # Are these changes tested?
    
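    Yes: new unit tests `test_cast_to_variant_union_sparse` and `test_cast_to_variant_union_dense` cover sparse and dense unions.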
    
    # Are there any user-facing changes?
    
    Yes: `cast_to_variant` now accepts `DataType::Union` arrays (a newly supported cast).
    
    ---------
    
    Co-authored-by: Andrew Lamb <and...@nerdnetworks.org>
---
 parquet-variant-compute/src/cast_to_variant.rs | 198 +++++++++++++++++++++----
 1 file changed, 170 insertions(+), 28 deletions(-)

diff --git a/parquet-variant-compute/src/cast_to_variant.rs 
b/parquet-variant-compute/src/cast_to_variant.rs
index 3850579946..782e336b09 100644
--- a/parquet-variant-compute/src/cast_to_variant.rs
+++ b/parquet-variant-compute/src/cast_to_variant.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::collections::HashMap;
 use std::sync::Arc;
 
 use crate::type_conversion::{
@@ -39,7 +40,7 @@ use arrow::temporal_conversions::{
     timestamp_ms_to_datetime, timestamp_ns_to_datetime, 
timestamp_s_to_datetime,
     timestamp_us_to_datetime,
 };
-use arrow_schema::{ArrowError, DataType, TimeUnit};
+use arrow_schema::{ArrowError, DataType, TimeUnit, UnionFields};
 use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc};
 use parquet_variant::{
     Variant, VariantBuilder, VariantDecimal16, VariantDecimal4, 
VariantDecimal8,
@@ -379,6 +380,9 @@ pub fn cast_to_variant(input: &dyn Array) -> 
Result<VariantArray, ArrowError> {
                 builder.append_variant(variant);
             }
         }
+        DataType::Union(fields, _) => {
+            convert_union(fields, input, &mut builder)?;
+        }
         DataType::Date32 => {
             generic_conversion_array!(
                 Date32Type,
@@ -398,9 +402,9 @@ pub fn cast_to_variant(input: &dyn Array) -> 
Result<VariantArray, ArrowError> {
             );
         }
         DataType::RunEndEncoded(run_ends, _) => match run_ends.data_type() {
-            DataType::Int16 => process_run_end_encoded::<Int16Type>(input, 
&mut builder)?,
-            DataType::Int32 => process_run_end_encoded::<Int32Type>(input, 
&mut builder)?,
-            DataType::Int64 => process_run_end_encoded::<Int64Type>(input, 
&mut builder)?,
+            DataType::Int16 => convert_run_end_encoded::<Int16Type>(input, 
&mut builder)?,
+            DataType::Int32 => convert_run_end_encoded::<Int32Type>(input, 
&mut builder)?,
+            DataType::Int64 => convert_run_end_encoded::<Int64Type>(input, 
&mut builder)?,
             _ => {
                 return Err(ArrowError::CastError(format!(
                     "Unsupported run ends type: {:?}",
@@ -409,25 +413,7 @@ pub fn cast_to_variant(input: &dyn Array) -> 
Result<VariantArray, ArrowError> {
             }
         },
         DataType::Dictionary(_, _) => {
-            let dict_array = input.as_any_dictionary();
-            let values_variant_array = 
cast_to_variant(dict_array.values().as_ref())?;
-            let normalized_keys = dict_array.normalized_keys();
-            let keys = dict_array.keys();
-
-            for (i, key_idx) in normalized_keys.iter().enumerate() {
-                if keys.is_null(i) {
-                    builder.append_null();
-                    continue;
-                }
-
-                if values_variant_array.is_null(*key_idx) {
-                    builder.append_null();
-                    continue;
-                }
-
-                let value = values_variant_array.value(*key_idx);
-                builder.append_variant(value);
-            }
+            convert_dictionary_encoded(input, &mut builder)?;
         }
 
         DataType::Map(field, _) => match field.data_type() {
@@ -559,8 +545,45 @@ pub fn cast_to_variant(input: &dyn Array) -> 
Result<VariantArray, ArrowError> {
     Ok(builder.build())
 }
 
-/// Generic function to process run-end encoded arrays
-fn process_run_end_encoded<R: RunEndIndexType>(
+/// Convert union arrays
+fn convert_union(
+    fields: &UnionFields,
+    input: &dyn Array,
+    builder: &mut VariantArrayBuilder,
+) -> Result<(), ArrowError> {
+    let union_array = input.as_union();
+
+    // Convert each child array to variant arrays
+    let mut child_variant_arrays = HashMap::new();
+    for (type_id, _) in fields.iter() {
+        let child_array = union_array.child(type_id);
+        let child_variant_array = cast_to_variant(child_array.as_ref())?;
+        child_variant_arrays.insert(type_id, child_variant_array);
+    }
+
+    // Process each element in the union array
+    for i in 0..union_array.len() {
+        let type_id = union_array.type_id(i);
+        let value_offset = union_array.value_offset(i);
+
+        if let Some(child_variant_array) = child_variant_arrays.get(&type_id) {
+            if child_variant_array.is_null(value_offset) {
+                builder.append_null();
+            } else {
+                let value = child_variant_array.value(value_offset);
+                builder.append_variant(value);
+            }
+        } else {
+            // This should not happen in a valid union, but handle gracefully
+            builder.append_null();
+        }
+    }
+
+    Ok(())
+}
+
+/// Generic function to convert run-end encoded arrays
+fn convert_run_end_encoded<R: RunEndIndexType>(
     input: &dyn Array,
     builder: &mut VariantArrayBuilder,
 ) -> Result<(), ArrowError> {
@@ -594,6 +617,34 @@ fn process_run_end_encoded<R: RunEndIndexType>(
     Ok(())
 }
 
+/// Convert dictionary encoded arrays
+fn convert_dictionary_encoded(
+    input: &dyn Array,
+    builder: &mut VariantArrayBuilder,
+) -> Result<(), ArrowError> {
+    let dict_array = input.as_any_dictionary();
+    let values_variant_array = cast_to_variant(dict_array.values().as_ref())?;
+    let normalized_keys = dict_array.normalized_keys();
+    let keys = dict_array.keys();
+
+    for (i, key_idx) in normalized_keys.iter().enumerate() {
+        if keys.is_null(i) {
+            builder.append_null();
+            continue;
+        }
+
+        if values_variant_array.is_null(*key_idx) {
+            builder.append_null();
+            continue;
+        }
+
+        let value = values_variant_array.value(*key_idx);
+        builder.append_variant(value);
+    }
+
+    Ok(())
+}
+
 // TODO do we need a cast_with_options to allow specifying conversion behavior,
 // e.g. how to handle overflows, whether to convert to Variant::Null or return
 // an error, etc. ?
@@ -609,10 +660,10 @@ mod tests {
         LargeStringArray, ListArray, MapArray, NullArray, StringArray, 
StringRunBuilder,
         StringViewArray, StructArray, Time32MillisecondArray, 
Time32SecondArray,
         Time64MicrosecondArray, Time64NanosecondArray, UInt16Array, 
UInt32Array, UInt64Array,
-        UInt8Array,
+        UInt8Array, UnionArray,
     };
-    use arrow::buffer::{NullBuffer, OffsetBuffer};
-    use arrow_schema::{Field, Fields};
+    use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer};
+    use arrow_schema::{DataType, Field, Fields, UnionFields};
     use arrow_schema::{
         DECIMAL128_MAX_PRECISION, DECIMAL32_MAX_PRECISION, 
DECIMAL64_MAX_PRECISION,
     };
@@ -1637,6 +1688,97 @@ mod tests {
         assert_eq!(obj4.get("age"), None);
     }
 
+    #[test]
+    fn test_cast_to_variant_union_sparse() {
+        // Create a sparse union array with mixed types (int, float, string)
+        let int_array = Int32Array::from(vec![Some(1), None, None, None, 
Some(34), None]);
+        let float_array = Float64Array::from(vec![None, Some(3.2), None, 
Some(32.5), None, None]);
+        let string_array = StringArray::from(vec![None, None, Some("hello"), 
None, None, None]);
+        let type_ids = [0, 1, 2, 1, 0, 
0].into_iter().collect::<ScalarBuffer<i8>>();
+
+        let union_fields = UnionFields::new(
+            vec![0, 1, 2],
+            vec![
+                Field::new("int_field", DataType::Int32, false),
+                Field::new("float_field", DataType::Float64, false),
+                Field::new("string_field", DataType::Utf8, false),
+            ],
+        );
+
+        let children: Vec<Arc<dyn Array>> = vec![
+            Arc::new(int_array),
+            Arc::new(float_array),
+            Arc::new(string_array),
+        ];
+
+        let union_array = UnionArray::try_new(
+            union_fields,
+            type_ids,
+            None, // Sparse union
+            children,
+        )
+        .unwrap();
+
+        run_test(
+            Arc::new(union_array),
+            vec![
+                Some(Variant::Int32(1)),
+                Some(Variant::Double(3.2)),
+                Some(Variant::from("hello")),
+                Some(Variant::Double(32.5)),
+                Some(Variant::Int32(34)),
+                None,
+            ],
+        );
+    }
+
+    #[test]
+    fn test_cast_to_variant_union_dense() {
+        // Create a dense union array with mixed types (int, float, string)
+        let int_array = Int32Array::from(vec![Some(1), Some(34), None]);
+        let float_array = Float64Array::from(vec![3.2, 32.5]);
+        let string_array = StringArray::from(vec!["hello"]);
+        let type_ids = [0, 1, 2, 1, 0, 
0].into_iter().collect::<ScalarBuffer<i8>>();
+        let offsets = [0, 0, 0, 1, 1, 2]
+            .into_iter()
+            .collect::<ScalarBuffer<i32>>();
+
+        let union_fields = UnionFields::new(
+            vec![0, 1, 2],
+            vec![
+                Field::new("int_field", DataType::Int32, false),
+                Field::new("float_field", DataType::Float64, false),
+                Field::new("string_field", DataType::Utf8, false),
+            ],
+        );
+
+        let children: Vec<Arc<dyn Array>> = vec![
+            Arc::new(int_array),
+            Arc::new(float_array),
+            Arc::new(string_array),
+        ];
+
+        let union_array = UnionArray::try_new(
+            union_fields,
+            type_ids,
+            Some(offsets), // Dense union
+            children,
+        )
+        .unwrap();
+
+        run_test(
+            Arc::new(union_array),
+            vec![
+                Some(Variant::Int32(1)),
+                Some(Variant::Double(3.2)),
+                Some(Variant::from("hello")),
+                Some(Variant::Double(32.5)),
+                Some(Variant::Int32(34)),
+                None,
+            ],
+        );
+    }
+
     #[test]
     fn test_cast_to_variant_struct_with_nulls() {
         // Test struct with null values at the struct level

Reply via email to