alamb commented on code in PR #8839:
URL: https://github.com/apache/arrow-rs/pull/8839#discussion_r2586241375


##########
arrow-row/src/lib.rs:
##########
@@ -1762,6 +1880,111 @@ unsafe fn decode_column(
             },
             _ => unreachable!(),
         },
+        Codec::Union(converters, null_rows) => {
+            let len = rows.len();
+
+            let DataType::Union(union_fields, mode) = &field.data_type else {
+                unreachable!()
+            };
+
+            let mut type_ids = Vec::with_capacity(len);
+            let mut rows_by_field: Vec<Vec<(usize, &[u8])>> = vec![Vec::new(); 
converters.len()];
+
+            for (idx, row) in rows.iter_mut().enumerate() {
+                let mut cursor = 0;

Review Comment:
   `cursor` is only ever 0 or 1 -- the code might be clearer if you used the
`0` and `1` literals directly rather than a mutable cursor that gets incremented.



##########
arrow-row/src/lib.rs:
##########
@@ -592,6 +624,29 @@ impl Codec {
                 let rows = 
converter.convert_columns(std::slice::from_ref(values))?;
                 Ok(Encoder::RunEndEncoded(rows))
             }
+            Codec::Union(converters, _, mode) => {
+                let union_array = array
+                    .as_any()
+                    .downcast_ref::<UnionArray>()
+                    .expect("expected Union array");
+
+                let type_ids = union_array.type_ids().clone();
+                let offsets = union_array.offsets().cloned();
+
+                let mut child_rows = Vec::with_capacity(converters.len());
+                for (type_id, converter) in converters.iter().enumerate() {
+                    let child_array = union_array.child(type_id as i8);

Review Comment:
   resolved in the latest commit 
   
   ```rust
       /// Row converters for each union field (indexed by type_id)
       /// and the encoding of null rows for each field
       Union(Vec<RowConverter>, Vec<OwnedRow>),
   ```



##########
arrow-row/src/lib.rs:
##########
@@ -3598,4 +3821,237 @@ mod tests {
         assert_eq!(unchecked_values_len, 13);
         assert!(checked_values_len > unchecked_values_len);
     }
+
+    #[test]
+    fn test_sparse_union() {
+        // create a sparse union with Int32 (type_id = 0) and Utf8 (type_id = 
1)
+        let int_array = Int32Array::from(vec![Some(1), None, Some(3), None, 
Some(5)]);
+        let str_array = StringArray::from(vec![None, Some("b"), None, 
Some("d"), None]);
+
+        // [1, "b", 3, "d", 5]
+        let type_ids = vec![0, 1, 0, 1, 0].into();
+
+        let union_fields = [
+            (0, Arc::new(Field::new("int", DataType::Int32, false))),
+            (1, Arc::new(Field::new("str", DataType::Utf8, false))),
+        ]
+        .into_iter()
+        .collect();
+
+        let union_array = UnionArray::try_new(
+            union_fields,
+            type_ids,
+            None,
+            vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)],
+        )
+        .unwrap();
+
+        let union_type = union_array.data_type().clone();
+        let converter = 
RowConverter::new(vec![SortField::new(union_type)]).unwrap();
+
+        let rows = converter
+            .convert_columns(&[Arc::new(union_array.clone())])
+            .unwrap();
+
+        // round trip
+        let back = converter.convert_rows(&rows).unwrap();
+        let back_union = 
back[0].as_any().downcast_ref::<UnionArray>().unwrap();
+
+        assert_eq!(union_array.len(), back_union.len());
+        for i in 0..union_array.len() {
+            assert_eq!(union_array.type_id(i), back_union.type_id(i));
+        }
+    }
+
+    #[test]
+    fn test_sparse_union_with_nulls() {
+        // create a sparse union with Int32 (type_id = 0) and Utf8 (type_id = 
1)
+        let int_array = Int32Array::from(vec![Some(1), None, Some(3), None, 
Some(5)]);
+        let str_array = StringArray::from(vec![None::<&str>; 5]);
+
+        // [1, null (both children null), 3, null (both children null), 5]
+        let type_ids = vec![0, 1, 0, 1, 0].into();
+
+        let union_fields = [
+            (0, Arc::new(Field::new("int", DataType::Int32, true))),
+            (1, Arc::new(Field::new("str", DataType::Utf8, true))),
+        ]
+        .into_iter()
+        .collect();
+
+        let union_array = UnionArray::try_new(
+            union_fields,
+            type_ids,
+            None,
+            vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)],
+        )
+        .unwrap();
+
+        let union_type = union_array.data_type().clone();
+        let converter = 
RowConverter::new(vec![SortField::new(union_type)]).unwrap();
+
+        let rows = converter
+            .convert_columns(&[Arc::new(union_array.clone())])
+            .unwrap();
+
+        // round trip
+        let back = converter.convert_rows(&rows).unwrap();
+        let back_union = 
back[0].as_any().downcast_ref::<UnionArray>().unwrap();
+
+        assert_eq!(union_array.len(), back_union.len());
+        for i in 0..union_array.len() {
+            let expected_null = union_array.is_null(i);
+            let actual_null = back_union.is_null(i);
+            assert_eq!(expected_null, actual_null, "Null mismatch at index 
{i}");
+            if !expected_null {
+                assert_eq!(union_array.type_id(i), back_union.type_id(i));
+            }
+        }
+    }
+
+    #[test]
+    fn test_dense_union() {
+        // create a dense union with Int32 (type_id = 0) and use Utf8 (type_id 
= 1)
+        let int_array = Int32Array::from(vec![1, 3, 5]);
+        let str_array = StringArray::from(vec!["a", "b"]);
+
+        let type_ids = vec![0, 1, 0, 1, 0].into();
+
+        // [1, "a", 3, "b", 5]
+        let offsets = vec![0, 0, 1, 1, 2].into();
+
+        let union_fields = [
+            (0, Arc::new(Field::new("int", DataType::Int32, false))),
+            (1, Arc::new(Field::new("str", DataType::Utf8, false))),
+        ]
+        .into_iter()
+        .collect();
+
+        let union_array = UnionArray::try_new(
+            union_fields,
+            type_ids,
+            Some(offsets), // Dense mode
+            vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)],
+        )
+        .unwrap();
+
+        let union_type = union_array.data_type().clone();
+        let converter = 
RowConverter::new(vec![SortField::new(union_type)]).unwrap();
+
+        let rows = converter
+            .convert_columns(&[Arc::new(union_array.clone())])
+            .unwrap();
+
+        // round trip
+        let back = converter.convert_rows(&rows).unwrap();
+        let back_union = 
back[0].as_any().downcast_ref::<UnionArray>().unwrap();
+
+        assert_eq!(union_array.len(), back_union.len());
+        for i in 0..union_array.len() {
+            assert_eq!(union_array.type_id(i), back_union.type_id(i));
+        }
+    }
+
+    #[test]
+    fn test_dense_union_with_nulls() {
+        // create a dense union with Int32 (type_id = 0) and Utf8 (type_id = 1)
+        let int_array = Int32Array::from(vec![Some(1), None, Some(5)]);
+        let str_array = StringArray::from(vec![Some("a"), None]);
+
+        // [1, "a", 5, null (str null), null (int null)]
+        let type_ids = vec![0, 1, 0, 1, 0].into();
+        let offsets = vec![0, 0, 1, 1, 2].into();
+
+        let union_fields = [
+            (0, Arc::new(Field::new("int", DataType::Int32, true))),
+            (1, Arc::new(Field::new("str", DataType::Utf8, true))),
+        ]
+        .into_iter()
+        .collect();
+
+        let union_array = UnionArray::try_new(
+            union_fields,
+            type_ids,
+            Some(offsets),
+            vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)],
+        )
+        .unwrap();
+
+        let union_type = union_array.data_type().clone();
+        let converter = 
RowConverter::new(vec![SortField::new(union_type)]).unwrap();
+
+        let rows = converter
+            .convert_columns(&[Arc::new(union_array.clone())])
+            .unwrap();
+
+        // round trip
+        let back = converter.convert_rows(&rows).unwrap();
+        let back_union = 
back[0].as_any().downcast_ref::<UnionArray>().unwrap();
+
+        assert_eq!(union_array.len(), back_union.len());
+        for i in 0..union_array.len() {
+            let expected_null = union_array.is_null(i);
+            let actual_null = back_union.is_null(i);
+            assert_eq!(expected_null, actual_null, "Null mismatch at index 
{i}");
+            if !expected_null {
+                assert_eq!(union_array.type_id(i), back_union.type_id(i));
+            }
+        }
+    }
+
+    #[test]
+    fn test_union_ordering() {
+        let int_array = Int32Array::from(vec![100, 5, 20]);
+        let str_array = StringArray::from(vec!["z", "a"]);
+
+        // [100, "z", 5, "a", 20]
+        let type_ids = vec![0, 1, 0, 1, 0].into();
+        let offsets = vec![0, 0, 1, 1, 2].into();
+
+        let union_fields = [
+            (0, Arc::new(Field::new("int", DataType::Int32, false))),
+            (1, Arc::new(Field::new("str", DataType::Utf8, false))),
+        ]
+        .into_iter()
+        .collect();
+
+        let union_array = UnionArray::try_new(
+            union_fields,
+            type_ids,
+            Some(offsets),
+            vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)],
+        )
+        .unwrap();
+
+        let union_type = union_array.data_type().clone();
+        let converter = 
RowConverter::new(vec![SortField::new(union_type)]).unwrap();
+
+        let rows = 
converter.convert_columns(&[Arc::new(union_array)]).unwrap();
+
+        /*
+        expected ordering
+
+        row 2: 5    - type_id 0
+        row 4: 20   - type_id 0
+        row 0: 100  - type id 0
+        row 3: "a"  - type id 1
+        row 1: "z"  - type id 1
+        */
+
+        // 5 < "z"
+        assert!(rows.row(2) < rows.row(1));
+
+        // 100 < "a"
+        assert!(rows.row(0) < rows.row(3));
+
+        // among ints
+        // 5 < 20
+        assert!(rows.row(2) < rows.row(4));
+        // 20 < 100
+        assert!(rows.row(4) < rows.row(0));
+
+        // among strigns
+        // "a" < "z"
+        assert!(rows.row(3) < rows.row(1));
+    }

Review Comment:
   Here are some other test suggestions from codex:
   
    - arrow-row/src/lib.rs:3826 — All union round-trip tests only assert type_id
      equality (and nulls) after conversion, so a broken child encoding/decoding
      would slip through. Add value-equality checks for both children (e.g.,
      compare rebuilt children to the originals, or project back into typed arrays).
    - arrow-row/src/lib.rs:3826 — Coverage never exercises unions whose type_ids
      are non-contiguous or don't start at 0, yet the conversion code indexes
      children by `type_id as usize`. Add a dense and a sparse case with ids like
      2 and 7 (with offsets) to catch mapping errors.
    - arrow-row/src/lib.rs:3826 — No test slices a UnionArray before conversion.
      Because type_ids/offsets are cloned directly, a sliced array could expose
      offset-handling bugs. Add a test that builds a longer union, slices it
      (non-zero offset), and round-trips it.
    - arrow-row/src/lib.rs:4011 — Ordering is only checked for dense unions with
      default sort options. Add ordering assertions for sparse unions and for
      SortOptions variations (descending, nulls_last) to ensure the option
      flipping in Codec::Union behaves as intended.



##########
arrow-row/src/lib.rs:
##########
@@ -1762,6 +1880,111 @@ unsafe fn decode_column(
             },
             _ => unreachable!(),
         },
+        Codec::Union(converters, null_rows) => {
+            let len = rows.len();
+
+            let DataType::Union(union_fields, mode) = &field.data_type else {
+                unreachable!()
+            };
+
+            let mut type_ids = Vec::with_capacity(len);
+            let mut rows_by_field: Vec<Vec<(usize, &[u8])>> = vec![Vec::new(); 
converters.len()];
+
+            for (idx, row) in rows.iter_mut().enumerate() {
+                let mut cursor = 0;
+
+                let type_id_byte = {
+                    let id = row[cursor];
+                    cursor += 1;
+
+                    if options.descending { !id } else { id }
+                };
+
+                let type_id = type_id_byte as i8;
+                type_ids.push(type_id);
+
+                let field_idx = type_id as usize;
+
+                let child_row = &row[cursor..];
+                rows_by_field[field_idx].push((idx, child_row));
+
+                *row = &row[row.len()..];
+            }
+
+            let mut child_arrays: Vec<ArrayRef> = 
Vec::with_capacity(converters.len());
+
+            let mut offsets = (*mode == UnionMode::Dense).then(|| 
Vec::with_capacity(len));
+
+            for (field_idx, converter) in converters.iter().enumerate() {
+                let field_rows = &rows_by_field[field_idx];
+
+                match &mode {
+                    UnionMode::Dense => {
+                        if field_rows.is_empty() {
+                            let (_, field) = 
union_fields.iter().nth(field_idx).unwrap();

Review Comment:
   I found the `iter().nth` call odd until I realized there is no way to
index `UnionFields` directly. Maybe we should add an `impl Index` (or a similar accessor).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to