This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new e9bb9f6f30 Add support for `Union` types in `RowConverter` (#8839)
e9bb9f6f30 is described below

commit e9bb9f6f301f63806fc4f0983e92032b616d24f2
Author: Matthew Kim <[email protected]>
AuthorDate: Mon Dec 8 12:59:43 2025 +0100

    Add support for `Union` types in `RowConverter` (#8839)
    
    # Which issue does this PR close?
    
    - Closes https://github.com/apache/arrow-rs/issues/8828
    
    # Rationale for this change
    
    This PR implements row format conversion for Union types (both sparse
    and dense modes) in the row kernel. Union types can now be encoded into
    the row format for sorting and comparison operations.
    
    It handles both sparse and dense union modes by encoding each row as a
    null sentinel byte, followed by the type id byte, and then the encoded
    child row data. During decoding, rows are grouped by their type id and
    routed to the appropriate child converter.
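    
    A minimal usage sketch (illustrative only, not part of the diff; the
    function name, field names, and values are made up) showing how a union
    column can be converted to the row format, compared bytewise, and decoded
    back, using the `RowConverter` / `SortField` and `UnionArray` APIs
    exercised by the tests below:
    
    ```rust
    use std::sync::Arc;
    use arrow_array::{Array, ArrayRef, Int32Array, StringArray, UnionArray};
    use arrow_row::{RowConverter, SortField};
    use arrow_schema::{ArrowError, DataType, Field};
    
    fn sort_union_rows() -> Result<(), ArrowError> {
        // A dense union holding [1, "z", 5]: ints use type_id 0, strings type_id 1
        let ints = Int32Array::from(vec![1, 5]);
        let strs = StringArray::from(vec!["z"]);
        let fields = [
            (0, Arc::new(Field::new("int", DataType::Int32, false))),
            (1, Arc::new(Field::new("str", DataType::Utf8, false))),
        ]
        .into_iter()
        .collect();
        let union_array = UnionArray::try_new(
            fields,
            vec![0, 1, 0].into(),       // type_ids
            Some(vec![0, 0, 1].into()), // offsets (dense mode)
            vec![Arc::new(ints) as ArrayRef, Arc::new(strs)],
        )?;
    
        // Each encoded row is the type_id byte followed by the child row bytes,
        // so rows compare bytewise: first by type_id, then by the child value
        let converter =
            RowConverter::new(vec![SortField::new(union_array.data_type().clone())])?;
        let rows = converter.convert_columns(&[Arc::new(union_array) as ArrayRef])?;
    
        let mut order: Vec<usize> = (0..rows.num_rows()).collect();
        order.sort_by(|&a, &b| rows.row(a).as_ref().cmp(rows.row(b).as_ref()));
        assert_eq!(order, vec![0, 2, 1]); // the ints (1, 5) sort before the string "z"
    
        // Decoding reconstructs a UnionArray equivalent to the input
        let decoded = converter.convert_rows(&rows)?;
        assert_eq!(decoded.len(), 1);
        Ok(())
    }
    ```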
---
 arrow-row/src/lib.rs | 452 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 452 insertions(+)

diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs
index 5f690e9a67..72a295627e 100644
--- a/arrow-row/src/lib.rs
+++ b/arrow-row/src/lib.rs
@@ -458,6 +458,9 @@ enum Codec {
     List(RowConverter),
     /// A row converter for the values array of a run-end encoded array
     RunEndEncoded(RowConverter),
+    /// Row converters for each union field (indexed by type_id)
+    /// and the encoding of null rows for each field
+    Union(Vec<RowConverter>, Vec<OwnedRow>),
 }
 
 impl Codec {
@@ -524,6 +527,35 @@ impl Codec {
 
                 Ok(Self::Struct(converter, owned))
             }
+            DataType::Union(fields, _mode) => {
+                // similar to dictionaries and lists, we set descending to false and negate nulls_first
+                // since the encoded contents will be inverted if descending is set
+                let options = SortOptions {
+                    descending: false,
+                    nulls_first: sort_field.options.nulls_first != sort_field.options.descending,
+                };
+
+                let mut converters = Vec::with_capacity(fields.len());
+                let mut null_rows = Vec::with_capacity(fields.len());
+
+                for (_type_id, field) in fields.iter() {
+                    let sort_field =
+                        SortField::new_with_options(field.data_type().clone(), options);
+                    let converter = RowConverter::new(vec![sort_field])?;
+
+                    let null_array = new_null_array(field.data_type(), 1);
+                    let nulls = converter.convert_columns(&[null_array])?;
+                    let owned = OwnedRow {
+                        data: nulls.buffer.into(),
+                        config: nulls.config,
+                    };
+
+                    converters.push(converter);
+                    null_rows.push(owned);
+                }
+
+                Ok(Self::Union(converters, null_rows))
+            }
             _ => Err(ArrowError::NotYetImplemented(format!(
                 "not yet implemented: {:?}",
                 sort_field.data_type
@@ -592,6 +624,28 @@ impl Codec {
                 let rows = converter.convert_columns(std::slice::from_ref(values))?;
                 Ok(Encoder::RunEndEncoded(rows))
             }
+            Codec::Union(converters, _) => {
+                let union_array = array
+                    .as_any()
+                    .downcast_ref::<UnionArray>()
+                    .expect("expected Union array");
+
+                let type_ids = union_array.type_ids().clone();
+                let offsets = union_array.offsets().cloned();
+
+                let mut child_rows = Vec::with_capacity(converters.len());
+                for (type_id, converter) in converters.iter().enumerate() {
+                    let child_array = union_array.child(type_id as i8);
+                    let rows = converter.convert_columns(std::slice::from_ref(child_array))?;
+                    child_rows.push(rows);
+                }
+
+                Ok(Encoder::Union {
+                    child_rows,
+                    type_ids,
+                    offsets,
+                })
+            }
         }
     }
 
@@ -602,6 +656,10 @@ impl Codec {
             Codec::Struct(converter, nulls) => converter.size() + nulls.data.len(),
             Codec::List(converter) => converter.size(),
             Codec::RunEndEncoded(converter) => converter.size(),
+            Codec::Union(converters, null_rows) => {
+                converters.iter().map(|c| c.size()).sum::<usize>()
+                    + null_rows.iter().map(|n| n.data.len()).sum::<usize>()
+            }
         }
     }
 }
@@ -622,6 +680,12 @@ enum Encoder<'a> {
     List(Rows),
     /// The row encoding of the values array
     RunEndEncoded(Rows),
+    /// The row encoding of each union field's child array, the type_ids buffer, and the offsets buffer (for dense unions)
+    Union {
+        child_rows: Vec<Rows>,
+        type_ids: ScalarBuffer<i8>,
+        offsets: Option<ScalarBuffer<i32>>,
+    },
 }
 
 /// Configure the data type and sort order for a given column
@@ -681,6 +745,9 @@ impl RowConverter {
             }
             DataType::Struct(f) => f.iter().all(|x| Self::supports_datatype(x.data_type())),
             DataType::RunEndEncoded(_, values) => Self::supports_datatype(values.data_type()),
+            DataType::Union(fs, _mode) => fs
+                .iter()
+                .all(|(_, f)| Self::supports_datatype(f.data_type())),
             _ => false,
         }
     }
@@ -1523,6 +1590,27 @@ fn row_lengths(cols: &[ArrayRef], encoders: &[Encoder]) -> LengthTracker {
                 },
                 _ => unreachable!(),
             },
+            Encoder::Union {
+                child_rows,
+                type_ids,
+                offsets,
+            } => {
+                let union_array = array
+                    .as_any()
+                    .downcast_ref::<UnionArray>()
+                    .expect("expected UnionArray");
+
+                let lengths = (0..union_array.len()).map(|i| {
+                    let type_id = type_ids[i];
+                    let child_row_i = offsets.as_ref().map(|o| o[i] as usize).unwrap_or(i);
+                    let child_row = child_rows[type_id as usize].row(child_row_i);
+
+                    // length: 1 byte type_id + child row bytes
+                    1 + child_row.as_ref().len()
+                });
+
+                tracker.push_variable(lengths);
+            }
         }
     }
 
@@ -1637,6 +1725,36 @@ fn encode_column(
             },
             _ => unreachable!(),
         },
+        Encoder::Union {
+            child_rows,
+            type_ids,
+            offsets: offsets_buf,
+        } => {
+            offsets
+                .iter_mut()
+                .skip(1)
+                .enumerate()
+                .for_each(|(i, offset)| {
+                    let type_id = type_ids[i];
+
+                    let child_row_idx = offsets_buf.as_ref().map(|o| o[i] as usize).unwrap_or(i);
+                    let child_row = child_rows[type_id as usize].row(child_row_idx);
+                    let child_bytes = child_row.as_ref();
+
+                    let type_id_byte = if opts.descending {
+                        !(type_id as u8)
+                    } else {
+                        type_id as u8
+                    };
+                    data[*offset] = type_id_byte;
+
+                    let child_start = *offset + 1;
+                    let child_end = child_start + child_bytes.len();
+                    data[child_start..child_end].copy_from_slice(child_bytes);
+
+                    *offset = child_end;
+                });
+        }
     }
 }
 
@@ -1762,6 +1880,107 @@ unsafe fn decode_column(
             },
             _ => unreachable!(),
         },
+        Codec::Union(converters, null_rows) => {
+            let len = rows.len();
+
+            let DataType::Union(union_fields, mode) = &field.data_type else {
+                unreachable!()
+            };
+
+            let mut type_ids = Vec::with_capacity(len);
+            let mut rows_by_field: Vec<Vec<(usize, &[u8])>> = vec![Vec::new(); converters.len()];
+
+            for (idx, row) in rows.iter_mut().enumerate() {
+                let type_id_byte = {
+                    let id = row[0];
+                    if options.descending { !id } else { id }
+                };
+
+                let type_id = type_id_byte as i8;
+                type_ids.push(type_id);
+
+                let field_idx = type_id as usize;
+
+                let child_row = &row[1..];
+                rows_by_field[field_idx].push((idx, child_row));
+
+                *row = &row[row.len()..];
+            }
+
+            let mut child_arrays: Vec<ArrayRef> = Vec::with_capacity(converters.len());
+
+            let mut offsets = (*mode == UnionMode::Dense).then(|| Vec::with_capacity(len));
+
+            for (field_idx, converter) in converters.iter().enumerate() {
+                let field_rows = &rows_by_field[field_idx];
+
+                match &mode {
+                    UnionMode::Dense => {
+                        if field_rows.is_empty() {
+                            let (_, field) = union_fields.iter().nth(field_idx).unwrap();
+                            child_arrays.push(arrow_array::new_empty_array(field.data_type()));
+                            continue;
+                        }
+
+                        let mut child_data = field_rows
+                            .iter()
+                            .map(|(_, bytes)| *bytes)
+                            .collect::<Vec<_>>();
+
+                        let child_array =
+                            unsafe { converter.convert_raw(&mut child_data, validate_utf8) }?;
+
+                        child_arrays.push(child_array.into_iter().next().unwrap());
+                    }
+                    UnionMode::Sparse => {
+                        let mut sparse_data: Vec<&[u8]> = Vec::with_capacity(len);
+                        let mut field_row_iter = field_rows.iter().peekable();
+                        let null_row_bytes: &[u8] = &null_rows[field_idx].data;
+
+                        for idx in 0..len {
+                            if let Some((next_idx, bytes)) = field_row_iter.peek() {
+                                if *next_idx == idx {
+                                    sparse_data.push(*bytes);
+
+                                    field_row_iter.next();
+                                    continue;
+                                }
+                            }
+                            sparse_data.push(null_row_bytes);
+                        }
+
+                        let child_array =
+                            unsafe { converter.convert_raw(&mut sparse_data, validate_utf8) }?;
+                        child_arrays.push(child_array.into_iter().next().unwrap());
+                    }
+                }
+            }
+
+            // build offsets for dense unions
+            if let Some(ref mut offsets_vec) = offsets {
+                let mut count = vec![0i32; converters.len()];
+                for type_id in &type_ids {
+                    let field_idx = *type_id as usize;
+                    offsets_vec.push(count[field_idx]);
+
+                    count[field_idx] += 1;
+                }
+            }
+
+            let type_ids_buffer = ScalarBuffer::from(type_ids);
+            let offsets_buffer = offsets.map(ScalarBuffer::from);
+
+            let union_array = UnionArray::try_new(
+                union_fields.clone(),
+                type_ids_buffer,
+                offsets_buffer,
+                child_arrays,
+            )?;
+
+            // note: union arrays don't support physical null buffers;
+            // nulls are represented logically through child arrays
+            Arc::new(union_array)
+        }
     };
     Ok(array)
 }
@@ -3598,4 +3817,237 @@ mod tests {
         assert_eq!(unchecked_values_len, 13);
         assert!(checked_values_len > unchecked_values_len);
     }
+
+    #[test]
+    fn test_sparse_union() {
+        // create a sparse union with Int32 (type_id = 0) and Utf8 (type_id = 1)
+        let int_array = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
+        let str_array = StringArray::from(vec![None, Some("b"), None, Some("d"), None]);
+
+        // [1, "b", 3, "d", 5]
+        let type_ids = vec![0, 1, 0, 1, 0].into();
+
+        let union_fields = [
+            (0, Arc::new(Field::new("int", DataType::Int32, false))),
+            (1, Arc::new(Field::new("str", DataType::Utf8, false))),
+        ]
+        .into_iter()
+        .collect();
+
+        let union_array = UnionArray::try_new(
+            union_fields,
+            type_ids,
+            None,
+            vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)],
+        )
+        .unwrap();
+
+        let union_type = union_array.data_type().clone();
+        let converter = RowConverter::new(vec![SortField::new(union_type)]).unwrap();
+
+        let rows = converter
+            .convert_columns(&[Arc::new(union_array.clone())])
+            .unwrap();
+
+        // round trip
+        let back = converter.convert_rows(&rows).unwrap();
+        let back_union = back[0].as_any().downcast_ref::<UnionArray>().unwrap();
+
+        assert_eq!(union_array.len(), back_union.len());
+        for i in 0..union_array.len() {
+            assert_eq!(union_array.type_id(i), back_union.type_id(i));
+        }
+    }
+
+    #[test]
+    fn test_sparse_union_with_nulls() {
+        // create a sparse union with Int32 (type_id = 0) and Utf8 (type_id = 1)
+        let int_array = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
+        let str_array = StringArray::from(vec![None::<&str>; 5]);
+
+        // [1, null (both children null), 3, null (both children null), 5]
+        let type_ids = vec![0, 1, 0, 1, 0].into();
+
+        let union_fields = [
+            (0, Arc::new(Field::new("int", DataType::Int32, true))),
+            (1, Arc::new(Field::new("str", DataType::Utf8, true))),
+        ]
+        .into_iter()
+        .collect();
+
+        let union_array = UnionArray::try_new(
+            union_fields,
+            type_ids,
+            None,
+            vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)],
+        )
+        .unwrap();
+
+        let union_type = union_array.data_type().clone();
+        let converter = RowConverter::new(vec![SortField::new(union_type)]).unwrap();
+
+        let rows = converter
+            .convert_columns(&[Arc::new(union_array.clone())])
+            .unwrap();
+
+        // round trip
+        let back = converter.convert_rows(&rows).unwrap();
+        let back_union = back[0].as_any().downcast_ref::<UnionArray>().unwrap();
+
+        assert_eq!(union_array.len(), back_union.len());
+        for i in 0..union_array.len() {
+            let expected_null = union_array.is_null(i);
+            let actual_null = back_union.is_null(i);
+            assert_eq!(expected_null, actual_null, "Null mismatch at index {i}");
+            if !expected_null {
+                assert_eq!(union_array.type_id(i), back_union.type_id(i));
+            }
+        }
+    }
+
+    #[test]
+    fn test_dense_union() {
+        // create a dense union with Int32 (type_id = 0) and Utf8 (type_id = 1)
+        let int_array = Int32Array::from(vec![1, 3, 5]);
+        let str_array = StringArray::from(vec!["a", "b"]);
+
+        let type_ids = vec![0, 1, 0, 1, 0].into();
+
+        // [1, "a", 3, "b", 5]
+        let offsets = vec![0, 0, 1, 1, 2].into();
+
+        let union_fields = [
+            (0, Arc::new(Field::new("int", DataType::Int32, false))),
+            (1, Arc::new(Field::new("str", DataType::Utf8, false))),
+        ]
+        .into_iter()
+        .collect();
+
+        let union_array = UnionArray::try_new(
+            union_fields,
+            type_ids,
+            Some(offsets), // Dense mode
+            vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)],
+        )
+        .unwrap();
+
+        let union_type = union_array.data_type().clone();
+        let converter = RowConverter::new(vec![SortField::new(union_type)]).unwrap();
+
+        let rows = converter
+            .convert_columns(&[Arc::new(union_array.clone())])
+            .unwrap();
+
+        // round trip
+        let back = converter.convert_rows(&rows).unwrap();
+        let back_union = back[0].as_any().downcast_ref::<UnionArray>().unwrap();
+
+        assert_eq!(union_array.len(), back_union.len());
+        for i in 0..union_array.len() {
+            assert_eq!(union_array.type_id(i), back_union.type_id(i));
+        }
+    }
+
+    #[test]
+    fn test_dense_union_with_nulls() {
+        // create a dense union with Int32 (type_id = 0) and Utf8 (type_id = 1)
+        let int_array = Int32Array::from(vec![Some(1), None, Some(5)]);
+        let str_array = StringArray::from(vec![Some("a"), None]);
+
+        // [1, "a", 5, null (str null), null (int null)]
+        let type_ids = vec![0, 1, 0, 1, 0].into();
+        let offsets = vec![0, 0, 1, 1, 2].into();
+
+        let union_fields = [
+            (0, Arc::new(Field::new("int", DataType::Int32, true))),
+            (1, Arc::new(Field::new("str", DataType::Utf8, true))),
+        ]
+        .into_iter()
+        .collect();
+
+        let union_array = UnionArray::try_new(
+            union_fields,
+            type_ids,
+            Some(offsets),
+            vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)],
+        )
+        .unwrap();
+
+        let union_type = union_array.data_type().clone();
+        let converter = RowConverter::new(vec![SortField::new(union_type)]).unwrap();
+
+        let rows = converter
+            .convert_columns(&[Arc::new(union_array.clone())])
+            .unwrap();
+
+        // round trip
+        let back = converter.convert_rows(&rows).unwrap();
+        let back_union = back[0].as_any().downcast_ref::<UnionArray>().unwrap();
+
+        assert_eq!(union_array.len(), back_union.len());
+        for i in 0..union_array.len() {
+            let expected_null = union_array.is_null(i);
+            let actual_null = back_union.is_null(i);
+            assert_eq!(expected_null, actual_null, "Null mismatch at index {i}");
+            if !expected_null {
+                assert_eq!(union_array.type_id(i), back_union.type_id(i));
+            }
+        }
+    }
+
+    #[test]
+    fn test_union_ordering() {
+        let int_array = Int32Array::from(vec![100, 5, 20]);
+        let str_array = StringArray::from(vec!["z", "a"]);
+
+        // [100, "z", 5, "a", 20]
+        let type_ids = vec![0, 1, 0, 1, 0].into();
+        let offsets = vec![0, 0, 1, 1, 2].into();
+
+        let union_fields = [
+            (0, Arc::new(Field::new("int", DataType::Int32, false))),
+            (1, Arc::new(Field::new("str", DataType::Utf8, false))),
+        ]
+        .into_iter()
+        .collect();
+
+        let union_array = UnionArray::try_new(
+            union_fields,
+            type_ids,
+            Some(offsets),
+            vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)],
+        )
+        .unwrap();
+
+        let union_type = union_array.data_type().clone();
+        let converter = RowConverter::new(vec![SortField::new(union_type)]).unwrap();
+
+        let rows = converter.convert_columns(&[Arc::new(union_array)]).unwrap();
+
+        /*
+        expected ordering
+
+        row 2: 5    - type_id 0
+        row 4: 20   - type_id 0
+        row 0: 100  - type_id 0
+        row 3: "a"  - type_id 1
+        row 1: "z"  - type_id 1
+        */
+
+        // 5 < "z"
+        assert!(rows.row(2) < rows.row(1));
+
+        // 100 < "a"
+        assert!(rows.row(0) < rows.row(3));
+
+        // among ints
+        // 5 < 20
+        assert!(rows.row(2) < rows.row(4));
+        // 20 < 100
+        assert!(rows.row(4) < rows.row(0));
+
+        // among strings
+        // "a" < "z"
+        assert!(rows.row(3) < rows.row(1));
+    }
 }
