This is an automated email from the ASF dual-hosted git repository.

mbrobbel pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new f19bda3e72 Arrow-avro Writer Dense Union support  (#8550)
f19bda3e72 is described below

commit f19bda3e72d64ef70408499d255b917f3033b53b
Author: nathaniel-d-ef <[email protected]>
AuthorDate: Wed Oct 8 09:03:11 2025 +0200

    Arrow-avro Writer Dense Union support  (#8550)
    
    # Which issue does this PR close?
    
    Relates to:
    https://github.com/apache/arrow-rs/pull/8348
    https://github.com/apache/arrow-rs/issues/4886
    
    # Rationale for this change
    
    This PR completes the efforts of @jecsand838 by adding dense union support
    to the encoder side of the crate, along with four minor extensions to the
    existing time-related encoders (Date32, Time32 milliseconds, Time64
    microseconds, and Timestamp milliseconds).
    
    Note: this PR is currently stacked on
    https://github.com/apache/arrow-rs/pull/8546. Once that is merged, this PR
    will be rebased and will no longer include those changes.
    
    # What changes are included in this PR?
    
    - Dense union support for the writer (see the usage sketch below)
    - Tests
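    
    As a rough usage sketch (hypothetical until the crate is published; the
    `AvroWriter` calls and the union construction below mirror the tests and
    benchmarks in this PR):
    
    ```rust
    use std::sync::Arc;
    use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, UnionArray};
    use arrow_avro::writer::AvroWriter;
    use arrow_buffer::Buffer;
    use arrow_schema::{DataType, Field, Schema, UnionFields, UnionMode};
    
    // Dense union of Utf8 and Int32; two rows: "hi", then 1
    let fields = UnionFields::new(
        vec![0, 1],
        vec![
            Field::new("u_str", DataType::Utf8, true),
            Field::new("u_int", DataType::Int32, true),
        ],
    );
    let union = UnionArray::try_new(
        fields.clone(),
        Buffer::from_slice_ref([0_i8, 1]).into(),
        Some(Buffer::from_slice_ref([0_i32, 0]).into()),
        vec![
            Arc::new(StringArray::from(vec!["hi"])) as ArrayRef,
            Arc::new(Int32Array::from(vec![1])) as ArrayRef,
        ],
    )
    .unwrap();
    let schema = Schema::new(vec![Field::new(
        "f",
        DataType::Union(fields, UnionMode::Dense),
        false,
    )]);
    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(union) as ArrayRef])
        .unwrap();
    let mut writer = AvroWriter::new(Vec::<u8>::new(), schema).unwrap();
    writer.write(&batch).unwrap();
    writer.finish().unwrap();
    let avro_bytes = writer.into_inner();
    ```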
    
    # Are these changes tested?
    
    - A full round-trip test that reads an existing union Avro file and
    asserts that the output matches expectations
    - Unit tests covering the new encoders
    
    # Are there any user-facing changes?
    
    No. The crate is not yet public, so there are no user-facing changes.
    
    ---------
    
    Co-authored-by: Connor Sanders <[email protected]>
    Co-authored-by: Connor Sanders <[email protected]>
    Co-authored-by: Matthijs Brobbel <[email protected]>
---
 arrow-avro/benches/avro_writer.rs |  77 +++++++-
 arrow-avro/src/writer/encoder.rs  | 378 ++++++++++++++++++++++++++++++++++++--
 arrow-avro/src/writer/mod.rs      | 248 ++++++++++++++++++++++++-
 3 files changed, 682 insertions(+), 21 deletions(-)

diff --git a/arrow-avro/benches/avro_writer.rs b/arrow-avro/benches/avro_writer.rs
index 1ac94e865b..085ec66c1f 100644
--- a/arrow-avro/benches/avro_writer.rs
+++ b/arrow-avro/benches/avro_writer.rs
@@ -29,8 +29,8 @@ use arrow_array::{
     RecordBatch, StringArray, StructArray,
 };
 use arrow_avro::writer::AvroWriter;
-use arrow_buffer::i256;
-use arrow_schema::{DataType, Field, IntervalUnit, Schema, TimeUnit};
+use arrow_buffer::{i256, Buffer};
+use arrow_schema::{DataType, Field, IntervalUnit, Schema, TimeUnit, UnionFields, UnionMode};
 use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput};
 use once_cell::sync::Lazy;
 use rand::{
@@ -679,6 +679,78 @@ static ENUM_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
         .collect()
 });
 
+static UNION_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+    // Basic Dense Union of three types: Utf8, Int32, Float64
+    let union_fields = UnionFields::new(
+        vec![0, 1, 2],
+        vec![
+            Field::new("u_str", DataType::Utf8, true),
+            Field::new("u_int", DataType::Int32, true),
+            Field::new("u_f64", DataType::Float64, true),
+        ],
+    );
+    let union_dt = DataType::Union(union_fields.clone(), UnionMode::Dense);
+    let schema = schema_single("field1", union_dt);
+
+    SIZES
+        .iter()
+        .map(|&n| {
+            // Cycle type ids 0 -> 1 -> 2 ... for determinism
+            let mut type_ids: Vec<i8> = Vec::with_capacity(n);
+            let mut offsets: Vec<i32> = Vec::with_capacity(n);
+            let (mut c0, mut c1, mut c2) = (0i32, 0i32, 0i32);
+            for i in 0..n {
+                let tid = (i % 3) as i8;
+                type_ids.push(tid);
+                match tid {
+                    0 => {
+                        offsets.push(c0);
+                        c0 += 1;
+                    }
+                    1 => {
+                        offsets.push(c1);
+                        c1 += 1;
+                    }
+                    _ => {
+                        offsets.push(c2);
+                        c2 += 1;
+                    }
+                }
+            }
+
+            // Build children arrays with lengths equal to counts per type id
+            let mut rng = rng_for(0xDEAD_0003, n);
+            let strings: Vec<String> = (0..c0)
+                .map(|_| rand_ascii_string(&mut rng, 3, 12))
+                .collect();
+            let ints = 0..c1;
+            let floats = (0..c2).map(|_| rng.random::<f64>());
+
+            let str_arr = StringArray::from_iter_values(strings);
+            let int_arr: PrimitiveArray<Int32Type> = PrimitiveArray::from_iter_values(ints);
+            let f_arr = Float64Array::from_iter_values(floats);
+
+            let type_ids_buf = Buffer::from_slice_ref(type_ids.as_slice());
+            let offsets_buf = Buffer::from_slice_ref(offsets.as_slice());
+
+            let union_array = arrow_array::UnionArray::try_new(
+                union_fields.clone(),
+                type_ids_buf.into(),
+                Some(offsets_buf.into()),
+                vec![
+                    Arc::new(str_arr) as ArrayRef,
+                    Arc::new(int_arr) as ArrayRef,
+                    Arc::new(f_arr) as ArrayRef,
+                ],
+            )
+            .unwrap();
+
+            let col: ArrayRef = Arc::new(union_array);
+            RecordBatch::try_new(schema.clone(), vec![col]).unwrap()
+        })
+        .collect()
+});
+
 fn ocf_size_for_batch(batch: &RecordBatch) -> usize {
     let schema_owned: Schema = (*batch.schema()).clone();
     let cursor = Cursor::new(Vec::<u8>::with_capacity(1024));
@@ -756,6 +828,7 @@ fn criterion_benches(c: &mut Criterion) {
     bench_writer_scenario(c, "write-Decimal256(bytes)", &DECIMAL256_DATA);
     bench_writer_scenario(c, "write-Map", &MAP_DATA);
     bench_writer_scenario(c, "write-Enum", &ENUM_DATA);
+    bench_writer_scenario(c, "write-Union", &UNION_DATA);
 }
 
 criterion_group! {
diff --git a/arrow-avro/src/writer/encoder.rs b/arrow-avro/src/writer/encoder.rs
index f39c53c557..363fee6f22 100644
--- a/arrow-avro/src/writer/encoder.rs
+++ b/arrow-avro/src/writer/encoder.rs
@@ -21,19 +21,22 @@ use crate::codec::{AvroDataType, AvroField, Codec};
 use crate::schema::{Fingerprint, Nullability, Prefix};
 use arrow_array::cast::AsArray;
 use arrow_array::types::{
-    ArrowPrimitiveType, DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
-    DurationSecondType, Float32Type, Float64Type, Int32Type, Int64Type, IntervalDayTimeType,
-    IntervalMonthDayNanoType, IntervalYearMonthType, TimestampMicrosecondType,
+    ArrowPrimitiveType, Date32Type, DurationMicrosecondType, DurationMillisecondType,
+    DurationNanosecondType, DurationSecondType, Float32Type, Float64Type, Int32Type, Int64Type,
+    IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, Time32MillisecondType,
+    Time64MicrosecondType, TimestampMicrosecondType, TimestampMillisecondType,
 };
 use arrow_array::{
     Array, Decimal128Array, Decimal256Array, DictionaryArray, FixedSizeBinaryArray,
     GenericBinaryArray, GenericListArray, GenericStringArray, LargeListArray, ListArray, MapArray,
-    OffsetSizeTrait, PrimitiveArray, RecordBatch, StringArray, StructArray,
+    OffsetSizeTrait, PrimitiveArray, RecordBatch, StringArray, StructArray, UnionArray,
 };
 #[cfg(feature = "small_decimals")]
 use arrow_array::{Decimal32Array, Decimal64Array};
 use arrow_buffer::NullBuffer;
-use arrow_schema::{ArrowError, DataType, Field, IntervalUnit, Schema as ArrowSchema, TimeUnit};
+use arrow_schema::{
+    ArrowError, DataType, Field, IntervalUnit, Schema as ArrowSchema, TimeUnit, UnionMode,
+};
 use std::io::Write;
 use std::sync::Arc;
 use uuid::Uuid;
@@ -224,6 +227,7 @@ impl<'a> FieldEncoder<'a> {
     ) -> Result<Self, ArrowError> {
         let encoder = match plan {
             FieldPlan::Scalar => match array.data_type() {
+                DataType::Null => Encoder::Null,
                 DataType::Boolean => Encoder::Boolean(BooleanEncoder(array.as_boolean())),
                 DataType::Utf8 => {
                     Encoder::Utf8(Utf8GenericEncoder::<i32>(array.as_string::<i32>()))
@@ -233,6 +237,13 @@ impl<'a> FieldEncoder<'a> {
                 }
                 DataType::Int32 => Encoder::Int(IntEncoder(array.as_primitive::<Int32Type>())),
                 DataType::Int64 => Encoder::Long(LongEncoder(array.as_primitive::<Int64Type>())),
+                DataType::Date32 => Encoder::Date32(IntEncoder(array.as_primitive::<Date32Type>())),
+                DataType::Time32(TimeUnit::Millisecond) => {
+                    Encoder::Time32Millis(IntEncoder(array.as_primitive::<Time32MillisecondType>()))
+                }
+                DataType::Time64(TimeUnit::Microsecond) => {
+                    Encoder::Time64Micros(LongEncoder(array.as_primitive::<Time64MicrosecondType>()))
+                }
                 DataType::Float32 => {
                     Encoder::Float32(F32Encoder(array.as_primitive::<Float32Type>()))
                 }
@@ -252,9 +263,19 @@ impl<'a> FieldEncoder<'a> {
                         })?;
                     Encoder::Fixed(FixedEncoder(arr))
                 }
-                DataType::Timestamp(TimeUnit::Microsecond, _) => Encoder::Timestamp(LongEncoder(
-                    array.as_primitive::<TimestampMicrosecondType>(),
-                )),
+                DataType::Timestamp(unit, _) => match unit {
+                    TimeUnit::Millisecond => Encoder::TimestampMillis(LongEncoder(
+                        array.as_primitive::<TimestampMillisecondType>(),
+                    )),
+                    TimeUnit::Microsecond => Encoder::Timestamp(LongEncoder(
+                        array.as_primitive::<TimestampMicrosecondType>(),
+                    )),
+                    other => {
+                        return Err(ArrowError::NotYetImplemented(format!(
+                            "Avro writer does not support Timestamp with unit {other:?}"
+                        )));
+                    }
+                },
                 DataType::Interval(unit) => match unit {
                     IntervalUnit::MonthDayNano => {
                         Encoder::IntervalMonthDayNano(DurationEncoder(
@@ -292,12 +313,12 @@ impl<'a> FieldEncoder<'a> {
                     )));
                 }
             },
-            FieldPlan::Struct { encoders } => {
+            FieldPlan::Struct { bindings } => {
                 let arr = array
                     .as_any()
                     .downcast_ref::<StructArray>()
                     .ok_or_else(|| ArrowError::SchemaError("Expected StructArray".into()))?;
-                Encoder::Struct(Box::new(StructEncoder::try_new(arr, encoders)?))
+                Encoder::Struct(Box::new(StructEncoder::try_new(arr, bindings)?))
             }
             FieldPlan::List {
                 items_nullability,
@@ -429,6 +450,14 @@ impl<'a> FieldEncoder<'a> {
                     )))
                 }
             }
+            FieldPlan::Union { bindings } => {
+                let arr = array
+                    .as_any()
+                    .downcast_ref::<UnionArray>()
+                    .ok_or_else(|| ArrowError::SchemaError("Expected UnionArray".into()))?;
+
+                Encoder::Union(Box::new(UnionEncoder::try_new(arr, bindings)?))
+            }
         };
         // Compute the effective null state from writer-declared nullability and data nulls.
         let null_state = match (nullability, array.null_count() > 0) {
@@ -494,7 +523,7 @@ enum FieldPlan {
     /// Non-nested scalar/logical type
     Scalar,
     /// Record/Struct with Avro‑ordered children
-    Struct { encoders: Vec<FieldBinding> },
+    Struct { bindings: Vec<FieldBinding> },
     /// Array with item‑site nullability and nested plan
     List {
         items_nullability: Option<Nullability>,
@@ -512,6 +541,8 @@ enum FieldPlan {
     /// Avro enum; maps to Arrow Dictionary<Int32, Utf8> with dictionary values
     /// exactly equal and ordered as the Avro enum `symbols`.
     Enum { symbols: Arc<[String]> },
+    /// Avro union, maps to Arrow Union.
+    Union { bindings: Vec<FieldBinding> },
 }
 
 #[derive(Debug, Clone)]
@@ -700,7 +731,7 @@ impl FieldPlan {
                         )))
                     }
                 };
-                let mut encoders = Vec::with_capacity(avro_fields.len());
+                let mut bindings = Vec::with_capacity(avro_fields.len());
                 for avro_field in avro_fields.iter() {
                     let name = avro_field.name().to_string();
                     let idx = find_struct_child_index(fields, &name).ok_or_else(|| {
@@ -709,13 +740,13 @@ impl FieldPlan {
                             arrow_field.name()
                         ))
                     })?;
-                    encoders.push(FieldBinding {
+                    bindings.push(FieldBinding {
                         arrow_index: idx,
                         nullability: avro_field.data_type().nullability(),
                         plan: FieldPlan::build(avro_field.data_type(), fields[idx].as_ref())?,
                     });
                 }
-                Ok(FieldPlan::Struct { encoders })
+                Ok(FieldPlan::Struct { bindings })
             }
             Codec::List(items_dt) => match arrow_field.data_type() {
                 DataType::List(field_ref) => Ok(FieldPlan::List {
@@ -812,6 +843,50 @@ impl FieldPlan {
                     "Avro duration logical type requires Arrow 
Interval(MonthDayNano), found: {other:?}"
                 ))),
             }
+            Codec::Union(avro_branches, _, UnionMode::Dense) => {
+                let arrow_union_fields = match arrow_field.data_type() {
+                    DataType::Union(fields, UnionMode::Dense) => fields,
+                    DataType::Union(_, UnionMode::Sparse) => {
+                        return Err(ArrowError::NotYetImplemented(
+                            "Sparse Arrow unions are not yet 
supported".to_string(),
+                        ));
+                    }
+                    other => {
+                        return Err(ArrowError::SchemaError(format!(
+                            "Avro union maps to Arrow Union, found: {other:?}"
+                        )));
+                    }
+                };
+
+                if avro_branches.len() != arrow_union_fields.len() {
+                    return Err(ArrowError::SchemaError(format!(
+                        "Mismatched number of branches between Avro union ({}) 
and Arrow union ({}) for field '{}'",
+                        avro_branches.len(),
+                        arrow_union_fields.len(),
+                        arrow_field.name()
+                    )));
+                }
+
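+                // Branches are paired positionally: the i-th Avro branch is
+                // bound to the i-th Arrow union field.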
+                let bindings = avro_branches
+                    .iter()
+                    .zip(arrow_union_fields.iter())
+                    .enumerate()
+                    .map(|(i, (avro_branch, (_, arrow_child_field)))| {
+                        Ok(FieldBinding {
+                            arrow_index: i,
+                            nullability: avro_branch.nullability(),
+                        plan: FieldPlan::build(avro_branch, arrow_child_field)?,
+                        })
+                    })
+                    .collect::<Result<Vec<_>, ArrowError>>()?;
+
+                Ok(FieldPlan::Union { bindings })
+            }
+            Codec::Union(_, _, UnionMode::Sparse) => {
+                Err(ArrowError::NotYetImplemented(
+                    "Sparse Arrow unions are not yet supported".to_string(),
+                ))
+            }
             _ => Ok(FieldPlan::Scalar),
         }
     }
@@ -822,6 +897,10 @@ enum Encoder<'a> {
     Int(IntEncoder<'a, Int32Type>),
     Long(LongEncoder<'a, Int64Type>),
     Timestamp(LongEncoder<'a, TimestampMicrosecondType>),
+    TimestampMillis(LongEncoder<'a, TimestampMillisecondType>),
+    Date32(IntEncoder<'a, Date32Type>),
+    Time32Millis(IntEncoder<'a, Time32MillisecondType>),
+    Time64Micros(LongEncoder<'a, Time64MicrosecondType>),
     DurationSeconds(LongEncoder<'a, DurationSecondType>),
     DurationMillis(LongEncoder<'a, DurationMillisecondType>),
     DurationMicros(LongEncoder<'a, DurationMicrosecondType>),
@@ -854,6 +933,8 @@ enum Encoder<'a> {
     /// Avro `enum` encoder: writes the key (int) as the enum index.
     Enum(EnumEncoder<'a>),
     Map(Box<MapEncoder<'a>>),
+    Union(Box<UnionEncoder<'a>>),
+    Null,
 }
 
 impl<'a> Encoder<'a> {
@@ -864,6 +945,10 @@ impl<'a> Encoder<'a> {
             Encoder::Int(e) => e.encode(out, idx),
             Encoder::Long(e) => e.encode(out, idx),
             Encoder::Timestamp(e) => e.encode(out, idx),
+            Encoder::TimestampMillis(e) => e.encode(out, idx),
+            Encoder::Date32(e) => e.encode(out, idx),
+            Encoder::Time32Millis(e) => e.encode(out, idx),
+            Encoder::Time64Micros(e) => e.encode(out, idx),
             Encoder::DurationSeconds(e) => e.encode(out, idx),
             Encoder::DurationMicros(e) => e.encode(out, idx),
             Encoder::DurationMillis(e) => e.encode(out, idx),
@@ -890,6 +975,8 @@ impl<'a> Encoder<'a> {
             Encoder::Decimal256(e) => (e).encode(out, idx),
             Encoder::Map(e) => (e).encode(out, idx),
             Encoder::Enum(e) => (e).encode(out, idx),
+            Encoder::Union(e) => (e).encode(out, idx),
+            Encoder::Null => Ok(()),
         }
     }
 }
@@ -1086,6 +1173,58 @@ impl EnumEncoder<'_> {
     }
 }
 
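+/// Encoder for Arrow dense `UnionArray` values: one prepared child encoder
+/// per union branch, indexed by the branch's type id.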
+struct UnionEncoder<'a> {
+    encoders: Vec<FieldEncoder<'a>>,
+    array: &'a UnionArray,
+}
+
+impl<'a> UnionEncoder<'a> {
+    fn try_new(array: &'a UnionArray, field_bindings: &[FieldBinding]) -> Result<Self, ArrowError> {
+        let DataType::Union(fields, UnionMode::Dense) = array.data_type() else {
+            return Err(ArrowError::SchemaError("Expected Dense UnionArray".into()));
+        };
+
+        if fields.len() != field_bindings.len() {
+            return Err(ArrowError::SchemaError(format!(
+                "Mismatched number of union branches between Arrow array ({}) 
and encoding plan ({})",
+                fields.len(),
+                field_bindings.len()
+            )));
+        }
+        let mut encoders = Vec::with_capacity(fields.len());
+        for (type_id, field_ref) in fields.iter() {
+            let binding = field_bindings
+                .get(type_id as usize)
+                .ok_or_else(|| ArrowError::SchemaError("Binding and field mismatch".to_string()))?;
+
+            let child = array.child(type_id).as_ref();
+
+            let encoder = prepare_value_site_encoder(
+                child,
+                field_ref.as_ref(),
+                binding.nullability,
+                &binding.plan,
+            )?;
+            encoders.push(encoder);
+        }
+        Ok(Self { encoders, array })
+    }
+
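+    /// Encodes row `idx` per the Avro union spec: the branch index (taken from
+    /// the Arrow type id) is written as a zigzag-encoded int, followed by the
+    /// child value at the row's dense offset.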
+    fn encode<W: Write + ?Sized>(&mut self, out: &mut W, idx: usize) -> 
Result<(), ArrowError> {
+        let type_id = self.array.type_ids()[idx];
+        let branch_index = type_id as usize;
+        write_int(out, type_id as i32)?;
+        let child_row = self.array.value_offset(idx);
+
+        let encoder = self
+            .encoders
+            .get_mut(branch_index)
+            .ok_or_else(|| ArrowError::SchemaError(format!("Invalid type_id {type_id}")))?;
+
+        encoder.encode(out, child_row)
+    }
+}
+
 struct StructEncoder<'a> {
     encoders: Vec<FieldEncoder<'a>>,
 }
@@ -1420,9 +1559,11 @@ mod tests {
     use arrow_array::types::Int32Type;
     use arrow_array::{
         Array, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array,
-        Int64Array, LargeBinaryArray, LargeListArray, LargeStringArray, ListArray, StringArray,
+        Int64Array, LargeBinaryArray, LargeListArray, LargeStringArray, ListArray, NullArray,
+        StringArray,
     };
-    use arrow_schema::{DataType, Field, Fields};
+    use arrow_buffer::Buffer;
+    use arrow_schema::{DataType, Field, Fields, UnionFields};
 
     fn zigzag_i64(v: i64) -> u64 {
         ((v << 1) ^ (v >> 63)) as u64
@@ -1581,7 +1722,7 @@ mod tests {
             None,
         );
         let plan = FieldPlan::Struct {
-            encoders: vec![
+            bindings: vec![
                 FieldBinding {
                     arrow_index: 0,
                     nullability: None,
@@ -1807,6 +1948,92 @@ mod tests {
         }
     }
 
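+    /// Helper: encodes `non_nullable_data` as a plain Avro scalar and
+    /// `nullable_data` as an Avro ["null", T] union, asserting the expected
+    /// byte output for both.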
+    fn test_scalar_primitive_encoding<T>(
+        non_nullable_data: &[T::Native],
+        nullable_data: &[Option<T::Native>],
+    ) where
+        T: ArrowPrimitiveType,
+        T::Native: Into<i64> + Copy,
+        PrimitiveArray<T>: From<Vec<<T as ArrowPrimitiveType>::Native>>,
+    {
+        let plan = FieldPlan::Scalar;
+
+        let array = PrimitiveArray::<T>::from(non_nullable_data.to_vec());
+        let got = encode_all(&array, &plan, None);
+
+        let mut expected = Vec::new();
+        for &value in non_nullable_data {
+            expected.extend(avro_long_bytes(value.into()));
+        }
+        assert_bytes_eq(&got, &expected);
+
+        let array_nullable: PrimitiveArray<T> = nullable_data.iter().copied().collect();
+        let got_nullable = encode_all(&array_nullable, &plan, Some(Nullability::NullFirst));
+
+        let mut expected_nullable = Vec::new();
+        for &opt_value in nullable_data {
+            match opt_value {
+                Some(value) => {
+                    // Union index 1 for the value, then the value itself
+                    expected_nullable.extend(avro_long_bytes(1));
+                    expected_nullable.extend(avro_long_bytes(value.into()));
+                }
+                None => {
+                    // Union index 0 for the null
+                    expected_nullable.extend(avro_long_bytes(0));
+                }
+            }
+        }
+        assert_bytes_eq(&got_nullable, &expected_nullable);
+    }
+
+    #[test]
+    fn date32_encoder() {
+        test_scalar_primitive_encoding::<Date32Type>(
+            &[
+                19345, // 2022-12-20
+                0,     // 1970-01-01 (epoch)
+                -1,    // 1969-12-31 (pre-epoch)
+            ],
+            &[Some(19345), None],
+        );
+    }
+
+    #[test]
+    fn time32_millis_encoder() {
+        test_scalar_primitive_encoding::<Time32MillisecondType>(
+            &[
+                0,        // Midnight
+                49530123, // 13:45:30.123
+                86399999, // 23:59:59.999
+            ],
+            &[None, Some(49530123)],
+        );
+    }
+
+    #[test]
+    fn time64_micros_encoder() {
+        test_scalar_primitive_encoding::<Time64MicrosecondType>(
+            &[
+                0,           // Midnight
+                86399999999, // 23:59:59.999999
+            ],
+            &[Some(86399999999), None],
+        );
+    }
+
+    #[test]
+    fn timestamp_millis_encoder() {
+        test_scalar_primitive_encoding::<TimestampMillisecondType>(
+            &[
+                1704067200000, // 2024-01-01T00:00:00Z
+                0,             // 1970-01-01T00:00:00Z (epoch)
+                -123456789,    // Pre-epoch timestamp
+            ],
+            &[None, Some(1704067200000)],
+        );
+    }
+
     #[test]
     fn map_encoder_string_keys_int_values() {
         // Build MapArray with two rows
@@ -1849,6 +2076,123 @@ mod tests {
         assert_bytes_eq(&got, &expected);
     }
 
+    #[test]
+    fn union_encoder_string_int() {
+        let strings = StringArray::from(vec!["hello", "world"]);
+        let ints = Int32Array::from(vec![10, 20, 30]);
+
+        let union_fields = UnionFields::new(
+            vec![0, 1],
+            vec![
+                Field::new("v_str", DataType::Utf8, true),
+                Field::new("v_int", DataType::Int32, true),
+            ],
+        );
+
+        let type_ids = Buffer::from_slice_ref([0_i8, 1, 1, 0, 1]);
+        let offsets = Buffer::from_slice_ref([0_i32, 0, 1, 1, 2]);
+
+        let union_array = UnionArray::try_new(
+            union_fields,
+            type_ids.into(),
+            Some(offsets.into()),
+            vec![Arc::new(strings), Arc::new(ints)],
+        )
+        .unwrap();
+
+        let plan = FieldPlan::Union {
+            bindings: vec![
+                FieldBinding {
+                    arrow_index: 0,
+                    nullability: None,
+                    plan: FieldPlan::Scalar,
+                },
+                FieldBinding {
+                    arrow_index: 1,
+                    nullability: None,
+                    plan: FieldPlan::Scalar,
+                },
+            ],
+        };
+
+        let got = encode_all(&union_array, &plan, None);
+
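+        // Avro encodes each union value as its zero-based branch index
+        // (a zigzag-encoded long) followed by the encoded value itself.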
+        let mut expected = Vec::new();
+        expected.extend(avro_long_bytes(0));
+        expected.extend(avro_len_prefixed_bytes(b"hello"));
+        expected.extend(avro_long_bytes(1));
+        expected.extend(avro_long_bytes(10));
+        expected.extend(avro_long_bytes(1));
+        expected.extend(avro_long_bytes(20));
+        expected.extend(avro_long_bytes(0));
+        expected.extend(avro_len_prefixed_bytes(b"world"));
+        expected.extend(avro_long_bytes(1));
+        expected.extend(avro_long_bytes(30));
+
+        assert_bytes_eq(&got, &expected);
+    }
+
+    #[test]
+    fn union_encoder_null_string_int() {
+        let nulls = NullArray::new(1);
+        let strings = StringArray::from(vec!["hello"]);
+        let ints = Int32Array::from(vec![10]);
+
+        let union_fields = UnionFields::new(
+            vec![0, 1, 2],
+            vec![
+                Field::new("v_null", DataType::Null, true),
+                Field::new("v_str", DataType::Utf8, true),
+                Field::new("v_int", DataType::Int32, true),
+            ],
+        );
+
+        let type_ids = Buffer::from_slice_ref([0_i8, 1, 2]);
+        // For a null value in a dense union, no value is added to a child array.
+        // The offset points to the last value of that type. Since there's only one
+        // null, and one of each other type, all offsets are 0.
+        let offsets = Buffer::from_slice_ref([0_i32, 0, 0]);
+
+        let union_array = UnionArray::try_new(
+            union_fields,
+            type_ids.into(),
+            Some(offsets.into()),
+            vec![Arc::new(nulls), Arc::new(strings), Arc::new(ints)],
+        )
+        .unwrap();
+
+        let plan = FieldPlan::Union {
+            bindings: vec![
+                FieldBinding {
+                    arrow_index: 0,
+                    nullability: None,
+                    plan: FieldPlan::Scalar,
+                },
+                FieldBinding {
+                    arrow_index: 1,
+                    nullability: None,
+                    plan: FieldPlan::Scalar,
+                },
+                FieldBinding {
+                    arrow_index: 2,
+                    nullability: None,
+                    plan: FieldPlan::Scalar,
+                },
+            ],
+        };
+
+        let got = encode_all(&union_array, &plan, None);
+
+        let mut expected = Vec::new();
+        expected.extend(avro_long_bytes(0));
+        expected.extend(avro_long_bytes(1));
+        expected.extend(avro_len_prefixed_bytes(b"hello"));
+        expected.extend(avro_long_bytes(2));
+        expected.extend(avro_long_bytes(10));
+
+        assert_bytes_eq(&got, &expected);
+    }
+
     #[test]
     fn list64_encoder_int32() {
         // LargeList [[1,2,3], []]
diff --git a/arrow-avro/src/writer/mod.rs b/arrow-avro/src/writer/mod.rs
index 5bc2700c1b..8f7cb666db 100644
--- a/arrow-avro/src/writer/mod.rs
+++ b/arrow-avro/src/writer/mod.rs
@@ -394,12 +394,14 @@ mod tests {
     use crate::reader::ReaderBuilder;
     use crate::schema::{AvroSchema, SchemaStore};
     use crate::test_util::arrow_test_data;
-    use arrow_array::{ArrayRef, BinaryArray, Int32Array, RecordBatch};
+    use arrow_array::{
+        Array, ArrayRef, BinaryArray, Int32Array, RecordBatch, StructArray, UnionArray,
+    };
     #[cfg(not(feature = "avro_custom_types"))]
     use arrow_schema::{DataType, Field, Schema};
     #[cfg(feature = "avro_custom_types")]
     use arrow_schema::{DataType, Field, Schema, TimeUnit};
-    #[cfg(feature = "avro_custom_types")]
+    use std::collections::HashMap;
     use std::collections::HashSet;
     use std::fs::File;
     use std::io::{BufReader, Cursor};
@@ -1017,6 +1019,248 @@ mod tests {
         Ok(())
     }
 
+    // Union Roundtrip Test Helpers
+
+    // Asserts that the `actual` schema is a semantically equivalent superset of the `expected` one.
+    // This allows the `actual` schema to contain additional metadata keys
+    // (`arrowUnionMode`, `arrowUnionTypeIds`, `avro.name`) that are added during an
+    // Arrow-to-Avro-to-Arrow roundtrip, while ensuring no other information was lost or changed.
+    fn assert_schema_is_semantically_equivalent(expected: &Schema, actual: &Schema) {
+        // Compare top-level schema metadata using the same superset logic.
+        assert_metadata_is_superset(expected.metadata(), actual.metadata(), "Schema");
+
+        // Compare fields.
+        assert_eq!(
+            expected.fields().len(),
+            actual.fields().len(),
+            "Schema must have the same number of fields"
+        );
+
+        for (expected_field, actual_field) in expected.fields().iter().zip(actual.fields().iter()) {
+            assert_field_is_semantically_equivalent(expected_field, actual_field);
+        }
+    }
+
+    fn assert_field_is_semantically_equivalent(expected: &Field, actual: &Field) {
+        let context = format!("Field '{}'", expected.name());
+
+        assert_eq!(
+            expected.name(),
+            actual.name(),
+            "{context}: names must match"
+        );
+        assert_eq!(
+            expected.is_nullable(),
+            actual.is_nullable(),
+            "{context}: nullability must match"
+        );
+
+        // Recursively check the data types.
+        assert_datatype_is_semantically_equivalent(
+            expected.data_type(),
+            actual.data_type(),
+            &context,
+        );
+
+        // Check that metadata is a valid superset.
+        assert_metadata_is_superset(expected.metadata(), actual.metadata(), &context);
+    }
+
+    fn assert_datatype_is_semantically_equivalent(
+        expected: &DataType,
+        actual: &DataType,
+        context: &str,
+    ) {
+        match (expected, actual) {
+            (DataType::List(expected_field), DataType::List(actual_field))
+            | (DataType::LargeList(expected_field), DataType::LargeList(actual_field))
+            | (DataType::Map(expected_field, _), DataType::Map(actual_field, _)) => {
+                assert_field_is_semantically_equivalent(expected_field, actual_field);
+            }
+            (DataType::Struct(expected_fields), DataType::Struct(actual_fields)) => {
+                assert_eq!(
+                    expected_fields.len(),
+                    actual_fields.len(),
+                    "{context}: struct must have same number of fields"
+                );
+                for (ef, af) in expected_fields.iter().zip(actual_fields.iter()) {
+                    assert_field_is_semantically_equivalent(ef, af);
+                }
+            }
+            (
+                DataType::Union(expected_fields, expected_mode),
+                DataType::Union(actual_fields, actual_mode),
+            ) => {
+                assert_eq!(
+                    expected_mode, actual_mode,
+                    "{context}: union mode must match"
+                );
+                assert_eq!(
+                    expected_fields.len(),
+                    actual_fields.len(),
+                    "{context}: union must have same number of variants"
+                );
+                for ((exp_id, exp_field), (act_id, act_field)) in
+                    expected_fields.iter().zip(actual_fields.iter())
+                {
+                    assert_eq!(exp_id, act_id, "{context}: union type ids must match");
+                    assert_field_is_semantically_equivalent(exp_field, act_field);
+                }
+            }
+            _ => {
+                assert_eq!(expected, actual, "{context}: data types must be identical");
+            }
+        }
+    }
+
+    fn assert_batch_data_is_identical(expected: &RecordBatch, actual: &RecordBatch) {
+        assert_eq!(
+            expected.num_columns(),
+            actual.num_columns(),
+            "RecordBatches must have the same number of columns"
+        );
+        assert_eq!(
+            expected.num_rows(),
+            actual.num_rows(),
+            "RecordBatches must have the same number of rows"
+        );
+
+        for i in 0..expected.num_columns() {
+            let context = format!("Column {i}");
+            let expected_col = expected.column(i);
+            let actual_col = actual.column(i);
+            assert_array_data_is_identical(expected_col, actual_col, &context);
+        }
+    }
+
+    /// Recursively asserts that the data content of two Arrays is identical.
+    fn assert_array_data_is_identical(expected: &dyn Array, actual: &dyn Array, context: &str) {
+        assert_eq!(
+            expected.nulls(),
+            actual.nulls(),
+            "{context}: null buffers must match"
+        );
+        assert_eq!(
+            expected.len(),
+            actual.len(),
+            "{context}: array lengths must match"
+        );
+
+        match (expected.data_type(), actual.data_type()) {
+            (DataType::Union(expected_fields, _), DataType::Union(..)) => {
+                let expected_union = expected.as_any().downcast_ref::<UnionArray>().unwrap();
+                let actual_union = actual.as_any().downcast_ref::<UnionArray>().unwrap();
+
+                // Compare the type_ids buffer (always the first buffer).
+                assert_eq!(
+                    &expected.to_data().buffers()[0],
+                    &actual.to_data().buffers()[0],
+                    "{context}: union type_ids buffer mismatch"
+                );
+
+                // For dense unions, compare the value_offsets buffer (the second buffer).
+                if expected.to_data().buffers().len() > 1 {
+                    assert_eq!(
+                        &expected.to_data().buffers()[1],
+                        &actual.to_data().buffers()[1],
+                        "{context}: union value_offsets buffer mismatch"
+                    );
+                }
+
+                // Recursively compare children based on the fields in the DataType.
+                for (type_id, _) in expected_fields.iter() {
+                    let child_context = format!("{context} -> child variant {type_id}");
+                    assert_array_data_is_identical(
+                        expected_union.child(type_id),
+                        actual_union.child(type_id),
+                        &child_context,
+                    );
+                }
+            }
+            (DataType::Struct(_), DataType::Struct(_)) => {
+                let expected_struct = expected.as_any().downcast_ref::<StructArray>().unwrap();
+                let actual_struct = actual.as_any().downcast_ref::<StructArray>().unwrap();
+                for i in 0..expected_struct.num_columns() {
+                    let child_context = format!("{context} -> struct child {i}");
+                    assert_array_data_is_identical(
+                        expected_struct.column(i),
+                        actual_struct.column(i),
+                        &child_context,
+                    );
+                }
+            }
+            // Fallback for primitive types and other types where buffer comparison is sufficient.
+            _ => {
+                assert_eq!(
+                    expected.to_data().buffers(),
+                    actual.to_data().buffers(),
+                    "{context}: data buffers must match"
+                );
+            }
+        }
+    }
+
+    /// Checks that `actual_meta` contains all of `expected_meta`, and any additional
+    /// keys in `actual_meta` are from a permitted set.
+    fn assert_metadata_is_superset(
+        expected_meta: &HashMap<String, String>,
+        actual_meta: &HashMap<String, String>,
+        context: &str,
+    ) {
+        let allowed_additions: HashSet<&str> =
+            vec!["arrowUnionMode", "arrowUnionTypeIds", "avro.name"]
+                .into_iter()
+                .collect();
+        for (key, expected_value) in expected_meta {
+            match actual_meta.get(key) {
+                Some(actual_value) => assert_eq!(
+                    expected_value, actual_value,
+                    "{context}: preserved metadata for key '{key}' must have 
the same value"
+                ),
+                None => panic!("{context}: metadata key '{key}' was lost 
during roundtrip"),
+            }
+        }
+        for key in actual_meta.keys() {
+            if !expected_meta.contains_key(key) && !allowed_additions.contains(key.as_str()) {
+                panic!("{context}: unexpected metadata key '{key}' was added during roundtrip");
+            }
+        }
+    }
+
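+    /// Round-trips `test/data/union_fields.avro` through the reader and writer,
+    /// then compares schema semantics and array data.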
+    #[test]
+    fn test_union_roundtrip() -> Result<(), ArrowError> {
+        let file_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
+            .join("test/data/union_fields.avro")
+            .to_string_lossy()
+            .into_owned();
+        let rdr_file = File::open(&file_path).expect("open avro/union_fields.avro");
+        let reader = ReaderBuilder::new()
+            .build(BufReader::new(rdr_file))
+            .expect("build reader for union_fields.avro");
+        let schema = reader.schema();
+        let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
+        let original =
+            arrow::compute::concat_batches(&schema, 
&input_batches).expect("concat input");
+        let mut writer = AvroWriter::new(Vec::<u8>::new(), 
original.schema().as_ref().clone())?;
+        writer.write(&original)?;
+        writer.finish()?;
+        let bytes = writer.into_inner();
+        let rt_reader = ReaderBuilder::new()
+            .build(Cursor::new(bytes))
+            .expect("build round_trip reader");
+        let rt_schema = rt_reader.schema();
+        let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
+        let round_trip =
+            arrow::compute::concat_batches(&rt_schema, 
&rt_batches).expect("concat round_trip");
+
+        // The writer appends metadata during the roundtrip, so the schemas cannot be
+        // compared directly. Instead, compare the schemas semantically and the data exactly.
+        assert_schema_is_semantically_equivalent(&original.schema(), &round_trip.schema());
+
+        assert_batch_data_is_identical(&original, &round_trip);
+        Ok(())
+    }
+
     #[test]
     fn test_enum_roundtrip_uses_reader_fixture() -> Result<(), ArrowError> {
         // Read the known-good enum file (same as reader::test_simple)

