This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new a8a63c28d1 Add comparison support for Union arrays (#8838)
a8a63c28d1 is described below

commit a8a63c28d14b99d8f50b32f3184ab986bad15e50
Author: Matthew Kim <[email protected]>
AuthorDate: Mon Nov 24 16:13:32 2025 -0500

    Add comparison support for Union arrays (#8838)
    
    # Which issue does this PR close?
    
    - Closes https://github.com/apache/arrow-rs/issues/8837
    - Related to https://github.com/apache/arrow-rs/issues/8828
    
    # Rationale for this change
    
    This PR implements comparison functionality for Union arrays. The
    implementation follows a simple ordering strategy: unions are first
    compared by their type identifier, and only when the type identifiers
    match are the actual values within those types compared.
    
    This approach handles both sparse and dense union modes correctly by
    using offsets when present (dense unions) or direct indices (sparse
    unions) to locate the appropriate child array values.
---
 arrow-ord/src/ord.rs | 320 ++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 316 insertions(+), 4 deletions(-)

diff --git a/arrow-ord/src/ord.rs b/arrow-ord/src/ord.rs
index 6e3025576c..b12a06732d 100644
--- a/arrow-ord/src/ord.rs
+++ b/arrow-ord/src/ord.rs
@@ -21,8 +21,8 @@ use arrow_array::cast::AsArray;
 use arrow_array::types::*;
 use arrow_array::*;
 use arrow_buffer::{ArrowNativeType, NullBuffer};
-use arrow_schema::{ArrowError, SortOptions};
-use std::cmp::Ordering;
+use arrow_schema::{ArrowError, DataType, SortOptions};
+use std::{cmp::Ordering, collections::HashMap};
 
 /// Compare the values at two arbitrary indices in two arrays.
 pub type DynComparator = Box<dyn Fn(usize, usize) -> Ordering + Send + Sync>;
@@ -296,6 +296,78 @@ fn compare_struct(
     Ok(f)
 }
 
+fn compare_union(
+    left: &dyn Array,
+    right: &dyn Array,
+    opts: SortOptions,
+) -> Result<DynComparator, ArrowError> {
+    let left = left.as_union();
+    let right = right.as_union();
+
+    let (left_fields, left_mode) = match left.data_type() {
+        DataType::Union(fields, mode) => (fields, mode),
+        _ => unreachable!(),
+    };
+    let (right_fields, right_mode) = match right.data_type() {
+        DataType::Union(fields, mode) => (fields, mode),
+        _ => unreachable!(),
+    };
+
+    if left_fields != right_fields {
+        return Err(ArrowError::InvalidArgumentError(format!(
+            "Cannot compare UnionArrays with different fields: left={:?}, 
right={:?}",
+            left_fields, right_fields
+        )));
+    }
+
+    if left_mode != right_mode {
+        return Err(ArrowError::InvalidArgumentError(format!(
+            "Cannot compare UnionArrays with different modes: left={:?}, 
right={:?}",
+            left_mode, right_mode
+        )));
+    }
+
+    let c_opts = child_opts(opts);
+
+    let mut field_comparators = HashMap::with_capacity(left_fields.len());
+
+    for (type_id, _field) in left_fields.iter() {
+        let left_child = left.child(type_id);
+        let right_child = right.child(type_id);
+        let cmp = make_comparator(left_child.as_ref(), right_child.as_ref(), 
c_opts)?;
+
+        field_comparators.insert(type_id, cmp);
+    }
+
+    let left_type_ids = left.type_ids().clone();
+    let right_type_ids = right.type_ids().clone();
+
+    let left_offsets = left.offsets().cloned();
+    let right_offsets = right.offsets().cloned();
+
+    let f = compare(left, right, opts, move |i, j| {
+        let left_type_id = left_type_ids[i];
+        let right_type_id = right_type_ids[j];
+
+        // first, compare by type_id
+        match left_type_id.cmp(&right_type_id) {
+            Ordering::Equal => {
+                // second, compare by values
+                let left_offset = left_offsets.as_ref().map(|o| o[i] as 
usize).unwrap_or(i);
+                let right_offset = right_offsets.as_ref().map(|o| o[j] as 
usize).unwrap_or(j);
+
+                let cmp = field_comparators
+                    .get(&left_type_id)
+                    .expect("type id not found in field_comparators");
+
+                cmp(left_offset, right_offset)
+            }
+            other => other,
+        }
+    });
+    Ok(f)
+}
+
 /// Returns a comparison function that compares two values at two different 
positions
 /// between the two arrays.
 ///
@@ -412,6 +484,7 @@ pub fn make_comparator(
              }
         },
         (Map(_, _), Map(_, _)) => compare_map(left, right, opts),
+        (Union(_, _), Union(_, _)) => compare_union(left, right, opts),
         (lhs, rhs) => Err(ArrowError::InvalidArgumentError(match lhs == rhs {
             true => format!("The data type type {lhs:?} has no natural order"),
             false => "Can't compare arrays of different types".to_string(),
@@ -423,8 +496,8 @@ pub fn make_comparator(
 mod tests {
     use super::*;
     use arrow_array::builder::{Int32Builder, ListBuilder, MapBuilder, 
StringBuilder};
-    use arrow_buffer::{IntervalDayTime, OffsetBuffer, i256};
-    use arrow_schema::{DataType, Field, Fields};
+    use arrow_buffer::{IntervalDayTime, OffsetBuffer, ScalarBuffer, i256};
+    use arrow_schema::{DataType, Field, Fields, UnionFields};
     use half::f16;
     use std::sync::Arc;
 
@@ -1189,4 +1262,243 @@ mod tests {
             }
         }
     }
+
+    #[test]
+    fn test_dense_union() {
+        // create a dense union array with Int32 (type_id = 0) and Utf8 
(type_id=1)
+        // the values are: [1, "b", 2, "a", 3]
+        //  type_ids are: [0,  1,  0,  1,  0]
+        //   offsets are: [0, 0, 1, 1, 2] from [1, 2, 3] and ["b", "a"]
+        let int_array = Int32Array::from(vec![1, 2, 3]);
+        let str_array = StringArray::from(vec!["b", "a"]);
+
+        let type_ids = [0, 1, 0, 1, 
0].into_iter().collect::<ScalarBuffer<i8>>();
+        let offsets = [0, 0, 1, 1, 
2].into_iter().collect::<ScalarBuffer<i32>>();
+
+        let union_fields = [
+            (0, Arc::new(Field::new("A", DataType::Int32, false))),
+            (1, Arc::new(Field::new("B", DataType::Utf8, false))),
+        ]
+        .into_iter()
+        .collect::<UnionFields>();
+
+        let children = vec![Arc::new(int_array) as ArrayRef, 
Arc::new(str_array)];
+
+        let array1 =
+            UnionArray::try_new(union_fields.clone(), type_ids, Some(offsets), 
children).unwrap();
+
+        // create a second array: [2, "a", 1, "c"]
+        //          type ids are: [0,  1,  0,  1]
+        //           offsets are: [0, 0, 1, 1] from [2, 1] and ["a", "c"]
+        let int_array2 = Int32Array::from(vec![2, 1]);
+        let str_array2 = StringArray::from(vec!["a", "c"]);
+        let type_ids2 = [0, 1, 0, 1].into_iter().collect::<ScalarBuffer<i8>>();
+        let offsets2 = [0, 0, 1, 1].into_iter().collect::<ScalarBuffer<i32>>();
+
+        let children2 = vec![Arc::new(int_array2) as ArrayRef, 
Arc::new(str_array2)];
+
+        let array2 =
+            UnionArray::try_new(union_fields, type_ids2, Some(offsets2), 
children2).unwrap();
+
+        let opts = SortOptions {
+            descending: false,
+            nulls_first: true,
+        };
+
+        // comparing
+        // [1, "b", 2, "a", 3]
+        // [2, "a", 1, "c"]
+        let cmp = make_comparator(&array1, &array2, opts).unwrap();
+
+        // array1[0] = (type_id=0, value=1)
+        // array2[0] = (type_id=0, value=2)
+        assert_eq!(cmp(0, 0), Ordering::Less); // 1 < 2
+
+        // array1[0] = (type_id=0, value=1)
+        // array2[1] = (type_id=1, value="a")
+        assert_eq!(cmp(0, 1), Ordering::Less); // type_id 0 < 1
+
+        // array1[1] = (type_id=1, value="b")
+        // array2[1] = (type_id=1, value="a")
+        assert_eq!(cmp(1, 1), Ordering::Greater); // "b" > "a"
+
+        // array1[2] = (type_id=0, value=2)
+        // array2[0] = (type_id=0, value=2)
+        assert_eq!(cmp(2, 0), Ordering::Equal); // 2 == 2
+
+        // array1[3] = (type_id=1, value="a")
+        // array2[1] = (type_id=1, value="a")
+        assert_eq!(cmp(3, 1), Ordering::Equal); // "a" == "a"
+
+        // array1[1] = (type_id=1, value="b")
+        // array2[3] = (type_id=1, value="c")
+        assert_eq!(cmp(1, 3), Ordering::Less); // "b" < "c"
+
+        let opts_desc = SortOptions {
+            descending: true,
+            nulls_first: true,
+        };
+        let cmp_desc = make_comparator(&array1, &array2, opts_desc).unwrap();
+
+        assert_eq!(cmp_desc(0, 0), Ordering::Greater); // 1 > 2 (reversed)
+        assert_eq!(cmp_desc(0, 1), Ordering::Greater); // type_id 0 < 1, 
reversed to Greater
+        assert_eq!(cmp_desc(1, 1), Ordering::Less); // "b" < "a" (reversed)
+    }
+
+    #[test]
+    fn test_sparse_union() {
+        // create a sparse union array with Int32 (type_id=0) and Utf8 
(type_id=1)
+        // values: [1, "b", 3]
+        // note, in sparse unions, child arrays have the same length as the 
union
+        let int_array = Int32Array::from(vec![Some(1), None, Some(3)]);
+        let str_array = StringArray::from(vec![None, Some("b"), None]);
+        let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
+
+        let union_fields = [
+            (0, Arc::new(Field::new("a", DataType::Int32, false))),
+            (1, Arc::new(Field::new("b", DataType::Utf8, false))),
+        ]
+        .into_iter()
+        .collect::<UnionFields>();
+
+        let children = vec![Arc::new(int_array) as ArrayRef, 
Arc::new(str_array)];
+
+        let array = UnionArray::try_new(union_fields, type_ids, None, 
children).unwrap();
+
+        let opts = SortOptions::default();
+        let cmp = make_comparator(&array, &array, opts).unwrap();
+
+        // array[0] = (type_id=0, value=1), array[2] = (type_id=0, value=3)
+        assert_eq!(cmp(0, 2), Ordering::Less); // 1 < 3
+        // array[0] = (type_id=0, value=1), array[1] = (type_id=1, value="b")
+        assert_eq!(cmp(0, 1), Ordering::Less); // type_id 0 < 1
+    }
+
+    #[test]
+    #[should_panic(expected = "index out of bounds")]
+    fn test_union_out_of_bounds() {
+        // create a dense union array with 3 elements
+        let int_array = Int32Array::from(vec![1, 2]);
+        let str_array = StringArray::from(vec!["a"]);
+
+        let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
+        let offsets = [0, 0, 1].into_iter().collect::<ScalarBuffer<i32>>();
+
+        let union_fields = [
+            (0, Arc::new(Field::new("A", DataType::Int32, false))),
+            (1, Arc::new(Field::new("B", DataType::Utf8, false))),
+        ]
+        .into_iter()
+        .collect::<UnionFields>();
+
+        let children = vec![Arc::new(int_array) as ArrayRef, 
Arc::new(str_array)];
+
+        let array = UnionArray::try_new(union_fields, type_ids, Some(offsets), 
children).unwrap();
+
+        let opts = SortOptions::default();
+        let cmp = make_comparator(&array, &array, opts).unwrap();
+
+        // oob
+        cmp(0, 3);
+    }
+
+    #[test]
+    fn test_union_incompatible_fields() {
+        // create first union with Int32 and Utf8
+        let int_array1 = Int32Array::from(vec![1, 2]);
+        let str_array1 = StringArray::from(vec!["a", "b"]);
+
+        let type_ids1 = [0, 1].into_iter().collect::<ScalarBuffer<i8>>();
+        let offsets1 = [0, 0].into_iter().collect::<ScalarBuffer<i32>>();
+
+        let union_fields1 = [
+            (0, Arc::new(Field::new("A", DataType::Int32, false))),
+            (1, Arc::new(Field::new("B", DataType::Utf8, false))),
+        ]
+        .into_iter()
+        .collect::<UnionFields>();
+
+        let children1 = vec![Arc::new(int_array1) as ArrayRef, 
Arc::new(str_array1)];
+
+        let array1 =
+            UnionArray::try_new(union_fields1, type_ids1, Some(offsets1), 
children1).unwrap();
+
+        // create second union with Int32 and Float64 (incompatible with first)
+        let int_array2 = Int32Array::from(vec![3, 4]);
+        let float_array2 = Float64Array::from(vec![1.0, 2.0]);
+
+        let type_ids2 = [0, 1].into_iter().collect::<ScalarBuffer<i8>>();
+        let offsets2 = [0, 0].into_iter().collect::<ScalarBuffer<i32>>();
+
+        let union_fields2 = [
+            (0, Arc::new(Field::new("A", DataType::Int32, false))),
+            (1, Arc::new(Field::new("C", DataType::Float64, false))),
+        ]
+        .into_iter()
+        .collect::<UnionFields>();
+
+        let children2 = vec![Arc::new(int_array2) as ArrayRef, 
Arc::new(float_array2)];
+
+        let array2 =
+            UnionArray::try_new(union_fields2, type_ids2, Some(offsets2), 
children2).unwrap();
+
+        let opts = SortOptions::default();
+
+        let Result::Err(ArrowError::InvalidArgumentError(out)) =
+            make_comparator(&array1, &array2, opts)
+        else {
+            panic!("expected error when making comparator of incompatible 
union arrays");
+        };
+
+        assert_eq!(
+            &out,
+            "Cannot compare UnionArrays with different fields: left=[(0, Field 
{ name: \"A\", data_type: Int32 }), (1, Field { name: \"B\", data_type: Utf8 
})], right=[(0, Field { name: \"A\", data_type: Int32 }), (1, Field { name: 
\"C\", data_type: Float64 })]"
+        );
+    }
+
+    #[test]
+    fn test_union_incompatible_modes() {
+        // create first union as Dense with Int32 and Utf8
+        let int_array1 = Int32Array::from(vec![1, 2]);
+        let str_array1 = StringArray::from(vec!["a", "b"]);
+
+        let type_ids1 = [0, 1].into_iter().collect::<ScalarBuffer<i8>>();
+        let offsets1 = [0, 0].into_iter().collect::<ScalarBuffer<i32>>();
+
+        let union_fields1 = [
+            (0, Arc::new(Field::new("A", DataType::Int32, false))),
+            (1, Arc::new(Field::new("B", DataType::Utf8, false))),
+        ]
+        .into_iter()
+        .collect::<UnionFields>();
+
+        let children1 = vec![Arc::new(int_array1) as ArrayRef, 
Arc::new(str_array1)];
+
+        let array1 =
+            UnionArray::try_new(union_fields1.clone(), type_ids1, 
Some(offsets1), children1)
+                .unwrap();
+
+        // create second union as Sparse with same fields (Int32 and Utf8)
+        let int_array2 = Int32Array::from(vec![Some(3), None]);
+        let str_array2 = StringArray::from(vec![None, Some("c")]);
+
+        let type_ids2 = [0, 1].into_iter().collect::<ScalarBuffer<i8>>();
+
+        let children2 = vec![Arc::new(int_array2) as ArrayRef, 
Arc::new(str_array2)];
+
+        let array2 = UnionArray::try_new(union_fields1, type_ids2, None, 
children2).unwrap();
+
+        let opts = SortOptions::default();
+
+        let Result::Err(ArrowError::InvalidArgumentError(out)) =
+            make_comparator(&array1, &array2, opts)
+        else {
+            panic!("expected error when making comparator of union arrays with 
different modes");
+        };
+
+        assert_eq!(
+            &out,
+            "Cannot compare UnionArrays with different modes: left=Dense, 
right=Sparse"
+        );
+    }
 }

Reply via email to