This is an automated email from the ASF dual-hosted git repository.
mbrobbel pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push:
new f19bda3e72 Arrow-avro Writer Dense Union support (#8550)
f19bda3e72 is described below
commit f19bda3e72d64ef70408499d255b917f3033b53b
Author: nathaniel-d-ef <[email protected]>
AuthorDate: Wed Oct 8 09:03:11 2025 +0200
Arrow-avro Writer Dense Union support (#8550)
# Which issue does this PR close?
Relates to:
https://github.com/apache/arrow-rs/pull/8348
https://github.com/apache/arrow-rs/issues/4886
# Rationale for this change
This PR completes the efforts of @jecsand838, adding dense union support
to the encoder side of the crate, along with four minor extensions to the
existing time-related encodings: Date32, Time32(Millisecond),
Time64(Microsecond), and Timestamp(Millisecond).
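As a rough sketch of what the temporal extensions enable (illustrative only;
column names are made up, and the Avro logical-type mappings follow the new
encoder match arms in this diff):

```rust
use std::sync::Arc;

use arrow_array::{
    ArrayRef, Date32Array, RecordBatch, Time32MillisecondArray, Time64MicrosecondArray,
    TimestampMillisecondArray,
};
use arrow_avro::writer::AvroWriter;
use arrow_schema::{ArrowError, DataType, Field, Schema, TimeUnit};

fn main() -> Result<(), ArrowError> {
    // One column per newly supported temporal encoding.
    let schema = Schema::new(vec![
        Field::new("d", DataType::Date32, false), // Avro `date`
        Field::new("t_ms", DataType::Time32(TimeUnit::Millisecond), false), // Avro `time-millis`
        Field::new("t_us", DataType::Time64(TimeUnit::Microsecond), false), // Avro `time-micros`
        Field::new("ts_ms", DataType::Timestamp(TimeUnit::Millisecond, None), false), // ms timestamp
    ]);
    let batch = RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![
            Arc::new(Date32Array::from(vec![19345])) as ArrayRef, // days since the epoch
            Arc::new(Time32MillisecondArray::from(vec![49_530_123])), // 13:45:30.123
            Arc::new(Time64MicrosecondArray::from(vec![86_399_999_999])), // 23:59:59.999999
            Arc::new(TimestampMillisecondArray::from(vec![1_704_067_200_000])), // 2024-01-01T00:00:00
        ],
    )?;
    let mut writer = AvroWriter::new(Vec::<u8>::new(), schema)?;
    writer.write(&batch)?;
    writer.finish()?;
    let _ocf_bytes = writer.into_inner();
    Ok(())
}
```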
Note: this PR is currently stacked on
https://github.com/apache/arrow-rs/pull/8546. Once that PR is merged, this
one will be rebased and will no longer include those changes.
# What changes are included in this PR?
- Dense union support for the writer (see the usage sketch below)
- Tests
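
A minimal usage sketch for the dense union path (branch names and values are
illustrative; the `UnionArray` construction mirrors the benchmark added in
this PR, and the `AvroWriter` calls match the round-trip test):

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, UnionArray};
use arrow_avro::writer::AvroWriter;
use arrow_buffer::Buffer;
use arrow_schema::{ArrowError, DataType, Field, Schema, UnionFields, UnionMode};

fn main() -> Result<(), ArrowError> {
    // Dense union with two branches: Utf8 (type id 0) and Int32 (type id 1).
    let union_fields = UnionFields::new(
        vec![0, 1],
        vec![
            Field::new("u_str", DataType::Utf8, true),
            Field::new("u_int", DataType::Int32, true),
        ],
    );
    // Rows: "a", 1, "b". type_ids pick the branch; offsets index into its child.
    let type_ids = Buffer::from_slice_ref([0_i8, 1, 0]);
    let offsets = Buffer::from_slice_ref([0_i32, 0, 1]);
    let union = UnionArray::try_new(
        union_fields.clone(),
        type_ids.into(),
        Some(offsets.into()),
        vec![
            Arc::new(StringArray::from(vec!["a", "b"])) as ArrayRef,
            Arc::new(Int32Array::from(vec![1])),
        ],
    )?;
    let schema = Schema::new(vec![Field::new(
        "field1",
        DataType::Union(union_fields, UnionMode::Dense),
        false,
    )]);
    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(union) as ArrayRef])?;
    let mut writer = AvroWriter::new(Vec::<u8>::new(), schema)?;
    writer.write(&batch)?;
    writer.finish()?;
    Ok(())
}
```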
# Are these changes tested?
- A full round-trip test, reading an existing union Avro file and asserting
that the output matches expectations
- Unit tests covering the new encoders (the expected byte layout is sketched
below)
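
For context, Avro encodes a union value as the zig-zag varint branch index
followed by the encoded branch value. A minimal sketch of the byte layout the
unit tests assert (`avro_long_bytes` here is an illustrative stand-in for the
test helpers):

```rust
// Zig-zag encoding of an Avro long (same as the `zigzag_i64` test helper).
fn zigzag_i64(v: i64) -> u64 {
    ((v << 1) ^ (v >> 63)) as u64
}

// Variable-length (7 bits per byte, little-endian) encoding of an Avro long.
fn avro_long_bytes(v: i64) -> Vec<u8> {
    let mut n = zigzag_i64(v);
    let mut out = Vec::new();
    while n >= 0x80 {
        out.push((n & 0x7F) as u8 | 0x80);
        n >>= 7;
    }
    out.push(n as u8);
    out
}

fn main() {
    // A dense-union value "hello" in branch 0 is written as:
    //   long(0)   -- branch index
    //   long(5)   -- string length
    //   b"hello"  -- string bytes
    let mut encoded = avro_long_bytes(0); // branch index
    encoded.extend(avro_long_bytes(5)); // length prefix
    encoded.extend_from_slice(b"hello");
    assert_eq!(encoded, [0x00, 0x0A, b'h', b'e', b'l', b'l', b'o']);
}
```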
# Are there any user-facing changes?
No; the crate is not yet public.
---------
Co-authored-by: Connor Sanders <[email protected]>
Co-authored-by: Connor Sanders <[email protected]>
Co-authored-by: Matthijs Brobbel <[email protected]>
---
arrow-avro/benches/avro_writer.rs | 77 +++++++-
arrow-avro/src/writer/encoder.rs | 378 ++++++++++++++++++++++++++++++++++++--
arrow-avro/src/writer/mod.rs | 248 ++++++++++++++++++++++++-
3 files changed, 682 insertions(+), 21 deletions(-)
diff --git a/arrow-avro/benches/avro_writer.rs b/arrow-avro/benches/avro_writer.rs
index 1ac94e865b..085ec66c1f 100644
--- a/arrow-avro/benches/avro_writer.rs
+++ b/arrow-avro/benches/avro_writer.rs
@@ -29,8 +29,8 @@ use arrow_array::{
RecordBatch, StringArray, StructArray,
};
use arrow_avro::writer::AvroWriter;
-use arrow_buffer::i256;
-use arrow_schema::{DataType, Field, IntervalUnit, Schema, TimeUnit};
+use arrow_buffer::{i256, Buffer};
+use arrow_schema::{DataType, Field, IntervalUnit, Schema, TimeUnit, UnionFields, UnionMode};
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput};
use once_cell::sync::Lazy;
use rand::{
@@ -679,6 +679,78 @@ static ENUM_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
.collect()
});
+static UNION_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+ // Basic Dense Union of three types: Utf8, Int32, Float64
+ let union_fields = UnionFields::new(
+ vec![0, 1, 2],
+ vec![
+ Field::new("u_str", DataType::Utf8, true),
+ Field::new("u_int", DataType::Int32, true),
+ Field::new("u_f64", DataType::Float64, true),
+ ],
+ );
+ let union_dt = DataType::Union(union_fields.clone(), UnionMode::Dense);
+ let schema = schema_single("field1", union_dt);
+
+ SIZES
+ .iter()
+ .map(|&n| {
+ // Cycle type ids 0 -> 1 -> 2 ... for determinism
+ let mut type_ids: Vec<i8> = Vec::with_capacity(n);
+ let mut offsets: Vec<i32> = Vec::with_capacity(n);
+ let (mut c0, mut c1, mut c2) = (0i32, 0i32, 0i32);
+ for i in 0..n {
+ let tid = (i % 3) as i8;
+ type_ids.push(tid);
+ match tid {
+ 0 => {
+ offsets.push(c0);
+ c0 += 1;
+ }
+ 1 => {
+ offsets.push(c1);
+ c1 += 1;
+ }
+ _ => {
+ offsets.push(c2);
+ c2 += 1;
+ }
+ }
+ }
+
+ // Build children arrays with lengths equal to counts per type id
+ let mut rng = rng_for(0xDEAD_0003, n);
+ let strings: Vec<String> = (0..c0)
+ .map(|_| rand_ascii_string(&mut rng, 3, 12))
+ .collect();
+ let ints = 0..c1;
+ let floats = (0..c2).map(|_| rng.random::<f64>());
+
+ let str_arr = StringArray::from_iter_values(strings);
+ let int_arr: PrimitiveArray<Int32Type> = PrimitiveArray::from_iter_values(ints);
+ let f_arr = Float64Array::from_iter_values(floats);
+
+ let type_ids_buf = Buffer::from_slice_ref(type_ids.as_slice());
+ let offsets_buf = Buffer::from_slice_ref(offsets.as_slice());
+
+ let union_array = arrow_array::UnionArray::try_new(
+ union_fields.clone(),
+ type_ids_buf.into(),
+ Some(offsets_buf.into()),
+ vec![
+ Arc::new(str_arr) as ArrayRef,
+ Arc::new(int_arr) as ArrayRef,
+ Arc::new(f_arr) as ArrayRef,
+ ],
+ )
+ .unwrap();
+
+ let col: ArrayRef = Arc::new(union_array);
+ RecordBatch::try_new(schema.clone(), vec![col]).unwrap()
+ })
+ .collect()
+});
+
fn ocf_size_for_batch(batch: &RecordBatch) -> usize {
let schema_owned: Schema = (*batch.schema()).clone();
let cursor = Cursor::new(Vec::<u8>::with_capacity(1024));
@@ -756,6 +828,7 @@ fn criterion_benches(c: &mut Criterion) {
bench_writer_scenario(c, "write-Decimal256(bytes)", &DECIMAL256_DATA);
bench_writer_scenario(c, "write-Map", &MAP_DATA);
bench_writer_scenario(c, "write-Enum", &ENUM_DATA);
+ bench_writer_scenario(c, "write-Union", &UNION_DATA);
}
criterion_group! {
diff --git a/arrow-avro/src/writer/encoder.rs b/arrow-avro/src/writer/encoder.rs
index f39c53c557..363fee6f22 100644
--- a/arrow-avro/src/writer/encoder.rs
+++ b/arrow-avro/src/writer/encoder.rs
@@ -21,19 +21,22 @@ use crate::codec::{AvroDataType, AvroField, Codec};
use crate::schema::{Fingerprint, Nullability, Prefix};
use arrow_array::cast::AsArray;
use arrow_array::types::{
- ArrowPrimitiveType, DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
- DurationSecondType, Float32Type, Float64Type, Int32Type, Int64Type, IntervalDayTimeType,
- IntervalMonthDayNanoType, IntervalYearMonthType, TimestampMicrosecondType,
+ ArrowPrimitiveType, Date32Type, DurationMicrosecondType, DurationMillisecondType,
+ DurationNanosecondType, DurationSecondType, Float32Type, Float64Type, Int32Type, Int64Type,
+ IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, Time32MillisecondType,
+ Time64MicrosecondType, TimestampMicrosecondType, TimestampMillisecondType,
};
use arrow_array::{
Array, Decimal128Array, Decimal256Array, DictionaryArray, FixedSizeBinaryArray,
GenericBinaryArray, GenericListArray, GenericStringArray, LargeListArray, ListArray, MapArray,
- OffsetSizeTrait, PrimitiveArray, RecordBatch, StringArray, StructArray,
+ OffsetSizeTrait, PrimitiveArray, RecordBatch, StringArray, StructArray, UnionArray,
};
#[cfg(feature = "small_decimals")]
use arrow_array::{Decimal32Array, Decimal64Array};
use arrow_buffer::NullBuffer;
-use arrow_schema::{ArrowError, DataType, Field, IntervalUnit, Schema as ArrowSchema, TimeUnit};
+use arrow_schema::{
+ ArrowError, DataType, Field, IntervalUnit, Schema as ArrowSchema, TimeUnit, UnionMode,
+};
use std::io::Write;
use std::sync::Arc;
use uuid::Uuid;
@@ -224,6 +227,7 @@ impl<'a> FieldEncoder<'a> {
) -> Result<Self, ArrowError> {
let encoder = match plan {
FieldPlan::Scalar => match array.data_type() {
+ DataType::Null => Encoder::Null,
DataType::Boolean => Encoder::Boolean(BooleanEncoder(array.as_boolean())),
DataType::Utf8 => {
Encoder::Utf8(Utf8GenericEncoder::<i32>(array.as_string::<i32>()))
@@ -233,6 +237,13 @@ impl<'a> FieldEncoder<'a> {
}
DataType::Int32 => Encoder::Int(IntEncoder(array.as_primitive::<Int32Type>())),
DataType::Int64 => Encoder::Long(LongEncoder(array.as_primitive::<Int64Type>())),
+ DataType::Date32 => Encoder::Date32(IntEncoder(array.as_primitive::<Date32Type>())),
+ DataType::Time32(TimeUnit::Millisecond) => {
+ Encoder::Time32Millis(IntEncoder(array.as_primitive::<Time32MillisecondType>()))
+ }
+ DataType::Time64(TimeUnit::Microsecond) => {
+ Encoder::Time64Micros(LongEncoder(array.as_primitive::<Time64MicrosecondType>()))
+ }
DataType::Float32 => {
Encoder::Float32(F32Encoder(array.as_primitive::<Float32Type>()))
}
@@ -252,9 +263,19 @@ impl<'a> FieldEncoder<'a> {
})?;
Encoder::Fixed(FixedEncoder(arr))
}
- DataType::Timestamp(TimeUnit::Microsecond, _) => Encoder::Timestamp(LongEncoder(
- array.as_primitive::<TimestampMicrosecondType>(),
- )),
+ DataType::Timestamp(unit, _) => match unit {
+ TimeUnit::Millisecond => Encoder::TimestampMillis(LongEncoder(
+ array.as_primitive::<TimestampMillisecondType>(),
+ )),
+ TimeUnit::Microsecond => Encoder::Timestamp(LongEncoder(
+ array.as_primitive::<TimestampMicrosecondType>(),
+ )),
+ other => {
+ return Err(ArrowError::NotYetImplemented(format!(
+ "Avro writer does not support Timestamp with unit
{other:?}"
+ )));
+ }
+ },
DataType::Interval(unit) => match unit {
IntervalUnit::MonthDayNano => {
Encoder::IntervalMonthDayNano(DurationEncoder(
@@ -292,12 +313,12 @@ impl<'a> FieldEncoder<'a> {
)));
}
},
- FieldPlan::Struct { encoders } => {
+ FieldPlan::Struct { bindings } => {
let arr = array
.as_any()
.downcast_ref::<StructArray>()
.ok_or_else(|| ArrowError::SchemaError("Expected StructArray".into()))?;
- Encoder::Struct(Box::new(StructEncoder::try_new(arr, encoders)?))
+ Encoder::Struct(Box::new(StructEncoder::try_new(arr, bindings)?))
}
FieldPlan::List {
items_nullability,
@@ -429,6 +450,14 @@ impl<'a> FieldEncoder<'a> {
)))
}
}
+ FieldPlan::Union { bindings } => {
+ let arr = array
+ .as_any()
+ .downcast_ref::<UnionArray>()
+ .ok_or_else(|| ArrowError::SchemaError("Expected UnionArray".into()))?;
+
+ Encoder::Union(Box::new(UnionEncoder::try_new(arr, bindings)?))
+ }
};
// Compute the effective null state from writer-declared nullability and data nulls.
let null_state = match (nullability, array.null_count() > 0) {
@@ -494,7 +523,7 @@ enum FieldPlan {
/// Non-nested scalar/logical type
Scalar,
/// Record/Struct with Avro‑ordered children
- Struct { encoders: Vec<FieldBinding> },
+ Struct { bindings: Vec<FieldBinding> },
/// Array with item‑site nullability and nested plan
List {
items_nullability: Option<Nullability>,
@@ -512,6 +541,8 @@ enum FieldPlan {
/// Avro enum; maps to Arrow Dictionary<Int32, Utf8> with dictionary values
/// exactly equal and ordered as the Avro enum `symbols`.
Enum { symbols: Arc<[String]> },
+ /// Avro union, maps to Arrow Union.
+ Union { bindings: Vec<FieldBinding> },
}
#[derive(Debug, Clone)]
@@ -700,7 +731,7 @@ impl FieldPlan {
)))
}
};
- let mut encoders = Vec::with_capacity(avro_fields.len());
+ let mut bindings = Vec::with_capacity(avro_fields.len());
for avro_field in avro_fields.iter() {
let name = avro_field.name().to_string();
let idx = find_struct_child_index(fields, &name).ok_or_else(|| {
@@ -709,13 +740,13 @@ impl FieldPlan {
arrow_field.name()
))
})?;
- encoders.push(FieldBinding {
+ bindings.push(FieldBinding {
arrow_index: idx,
nullability: avro_field.data_type().nullability(),
plan: FieldPlan::build(avro_field.data_type(), fields[idx].as_ref())?,
});
}
- Ok(FieldPlan::Struct { encoders })
+ Ok(FieldPlan::Struct { bindings })
}
Codec::List(items_dt) => match arrow_field.data_type() {
DataType::List(field_ref) => Ok(FieldPlan::List {
@@ -812,6 +843,50 @@ impl FieldPlan {
"Avro duration logical type requires Arrow
Interval(MonthDayNano), found: {other:?}"
))),
}
+ Codec::Union(avro_branches, _, UnionMode::Dense) => {
+ let arrow_union_fields = match arrow_field.data_type() {
+ DataType::Union(fields, UnionMode::Dense) => fields,
+ DataType::Union(_, UnionMode::Sparse) => {
+ return Err(ArrowError::NotYetImplemented(
+ "Sparse Arrow unions are not yet
supported".to_string(),
+ ));
+ }
+ other => {
+ return Err(ArrowError::SchemaError(format!(
+ "Avro union maps to Arrow Union, found: {other:?}"
+ )));
+ }
+ };
+
+ if avro_branches.len() != arrow_union_fields.len() {
+ return Err(ArrowError::SchemaError(format!(
+ "Mismatched number of branches between Avro union ({})
and Arrow union ({}) for field '{}'",
+ avro_branches.len(),
+ arrow_union_fields.len(),
+ arrow_field.name()
+ )));
+ }
+
+ let bindings = avro_branches
+ .iter()
+ .zip(arrow_union_fields.iter())
+ .enumerate()
+ .map(|(i, (avro_branch, (_, arrow_child_field)))| {
+ Ok(FieldBinding {
+ arrow_index: i,
+ nullability: avro_branch.nullability(),
+ plan: FieldPlan::build(avro_branch, arrow_child_field)?,
+ })
+ })
+ .collect::<Result<Vec<_>, ArrowError>>()?;
+
+ Ok(FieldPlan::Union { bindings })
+ }
+ Codec::Union(_, _, UnionMode::Sparse) => {
+ Err(ArrowError::NotYetImplemented(
+ "Sparse Arrow unions are not yet supported".to_string(),
+ ))
+ }
_ => Ok(FieldPlan::Scalar),
}
}
@@ -822,6 +897,10 @@ enum Encoder<'a> {
Int(IntEncoder<'a, Int32Type>),
Long(LongEncoder<'a, Int64Type>),
Timestamp(LongEncoder<'a, TimestampMicrosecondType>),
+ TimestampMillis(LongEncoder<'a, TimestampMillisecondType>),
+ Date32(IntEncoder<'a, Date32Type>),
+ Time32Millis(IntEncoder<'a, Time32MillisecondType>),
+ Time64Micros(LongEncoder<'a, Time64MicrosecondType>),
DurationSeconds(LongEncoder<'a, DurationSecondType>),
DurationMillis(LongEncoder<'a, DurationMillisecondType>),
DurationMicros(LongEncoder<'a, DurationMicrosecondType>),
@@ -854,6 +933,8 @@ enum Encoder<'a> {
/// Avro `enum` encoder: writes the key (int) as the enum index.
Enum(EnumEncoder<'a>),
Map(Box<MapEncoder<'a>>),
+ Union(Box<UnionEncoder<'a>>),
+ Null,
}
impl<'a> Encoder<'a> {
@@ -864,6 +945,10 @@ impl<'a> Encoder<'a> {
Encoder::Int(e) => e.encode(out, idx),
Encoder::Long(e) => e.encode(out, idx),
Encoder::Timestamp(e) => e.encode(out, idx),
+ Encoder::TimestampMillis(e) => e.encode(out, idx),
+ Encoder::Date32(e) => e.encode(out, idx),
+ Encoder::Time32Millis(e) => e.encode(out, idx),
+ Encoder::Time64Micros(e) => e.encode(out, idx),
Encoder::DurationSeconds(e) => e.encode(out, idx),
Encoder::DurationMicros(e) => e.encode(out, idx),
Encoder::DurationMillis(e) => e.encode(out, idx),
@@ -890,6 +975,8 @@ impl<'a> Encoder<'a> {
Encoder::Decimal256(e) => (e).encode(out, idx),
Encoder::Map(e) => (e).encode(out, idx),
Encoder::Enum(e) => (e).encode(out, idx),
+ Encoder::Union(e) => (e).encode(out, idx),
+ Encoder::Null => Ok(()),
}
}
}
@@ -1086,6 +1173,58 @@ impl EnumEncoder<'_> {
}
}
+struct UnionEncoder<'a> {
+ encoders: Vec<FieldEncoder<'a>>,
+ array: &'a UnionArray,
+}
+
+impl<'a> UnionEncoder<'a> {
+ fn try_new(array: &'a UnionArray, field_bindings: &[FieldBinding]) -> Result<Self, ArrowError> {
+ let DataType::Union(fields, UnionMode::Dense) = array.data_type() else {
+ return Err(ArrowError::SchemaError("Expected Dense UnionArray".into()));
+ };
+
+ if fields.len() != field_bindings.len() {
+ return Err(ArrowError::SchemaError(format!(
+ "Mismatched number of union branches between Arrow array ({})
and encoding plan ({})",
+ fields.len(),
+ field_bindings.len()
+ )));
+ }
+ let mut encoders = Vec::with_capacity(fields.len());
+ for (type_id, field_ref) in fields.iter() {
+ let binding = field_bindings
+ .get(type_id as usize)
+ .ok_or_else(|| ArrowError::SchemaError("Binding and field
mismatch".to_string()))?;
+
+ let child = array.child(type_id).as_ref();
+
+ let encoder = prepare_value_site_encoder(
+ child,
+ field_ref.as_ref(),
+ binding.nullability,
+ &binding.plan,
+ )?;
+ encoders.push(encoder);
+ }
+ Ok(Self { encoders, array })
+ }
+
+ fn encode<W: Write + ?Sized>(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> {
+ let type_id = self.array.type_ids()[idx];
+ let branch_index = type_id as usize;
+ write_int(out, type_id as i32)?;
+ let child_row = self.array.value_offset(idx);
+
+ let encoder = self
+ .encoders
+ .get_mut(branch_index)
+ .ok_or_else(|| ArrowError::SchemaError(format!("Invalid type_id
{type_id}")))?;
+
+ encoder.encode(out, child_row)
+ }
+}
+
struct StructEncoder<'a> {
encoders: Vec<FieldEncoder<'a>>,
}
@@ -1420,9 +1559,11 @@ mod tests {
use arrow_array::types::Int32Type;
use arrow_array::{
Array, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array,
- Int64Array, LargeBinaryArray, LargeListArray, LargeStringArray, ListArray, StringArray,
+ Int64Array, LargeBinaryArray, LargeListArray, LargeStringArray, ListArray, NullArray,
+ StringArray,
};
- use arrow_schema::{DataType, Field, Fields};
+ use arrow_buffer::Buffer;
+ use arrow_schema::{DataType, Field, Fields, UnionFields};
fn zigzag_i64(v: i64) -> u64 {
((v << 1) ^ (v >> 63)) as u64
@@ -1581,7 +1722,7 @@ mod tests {
None,
);
let plan = FieldPlan::Struct {
- encoders: vec![
+ bindings: vec![
FieldBinding {
arrow_index: 0,
nullability: None,
@@ -1807,6 +1948,92 @@ mod tests {
}
}
+ fn test_scalar_primitive_encoding<T>(
+ non_nullable_data: &[T::Native],
+ nullable_data: &[Option<T::Native>],
+ ) where
+ T: ArrowPrimitiveType,
+ T::Native: Into<i64> + Copy,
+ PrimitiveArray<T>: From<Vec<<T as ArrowPrimitiveType>::Native>>,
+ {
+ let plan = FieldPlan::Scalar;
+
+ let array = PrimitiveArray::<T>::from(non_nullable_data.to_vec());
+ let got = encode_all(&array, &plan, None);
+
+ let mut expected = Vec::new();
+ for &value in non_nullable_data {
+ expected.extend(avro_long_bytes(value.into()));
+ }
+ assert_bytes_eq(&got, &expected);
+
+ let array_nullable: PrimitiveArray<T> = nullable_data.iter().copied().collect();
+ let got_nullable = encode_all(&array_nullable, &plan, Some(Nullability::NullFirst));
+
+ let mut expected_nullable = Vec::new();
+ for &opt_value in nullable_data {
+ match opt_value {
+ Some(value) => {
+ // Union index 1 for the value, then the value itself
+ expected_nullable.extend(avro_long_bytes(1));
+ expected_nullable.extend(avro_long_bytes(value.into()));
+ }
+ None => {
+ // Union index 0 for the null
+ expected_nullable.extend(avro_long_bytes(0));
+ }
+ }
+ }
+ assert_bytes_eq(&got_nullable, &expected_nullable);
+ }
+
+ #[test]
+ fn date32_encoder() {
+ test_scalar_primitive_encoding::<Date32Type>(
+ &[
+ 19345, // 2022-12-20
+ 0, // 1970-01-01 (epoch)
+ -1, // 1969-12-31 (pre-epoch)
+ ],
+ &[Some(19345), None],
+ );
+ }
+
+ #[test]
+ fn time32_millis_encoder() {
+ test_scalar_primitive_encoding::<Time32MillisecondType>(
+ &[
+ 0, // Midnight
+ 49530123, // 13:45:30.123
+ 86399999, // 23:59:59.999
+ ],
+ &[None, Some(49530123)],
+ );
+ }
+
+ #[test]
+ fn time64_micros_encoder() {
+ test_scalar_primitive_encoding::<Time64MicrosecondType>(
+ &[
+ 0, // Midnight
+ 86399999999, // 23:59:59.999999
+ ],
+ &[Some(86399999999), None],
+ );
+ }
+
+ #[test]
+ fn timestamp_millis_encoder() {
+ test_scalar_primitive_encoding::<TimestampMillisecondType>(
+ &[
+ 1704067200000, // 2024-01-01T00:00:00Z
+ 0, // 1970-01-01T00:00:00Z (epoch)
+ -123456789, // Pre-epoch timestamp
+ ],
+ &[None, Some(1704067200000)],
+ );
+ }
+
#[test]
fn map_encoder_string_keys_int_values() {
// Build MapArray with two rows
@@ -1849,6 +2076,123 @@ mod tests {
assert_bytes_eq(&got, &expected);
}
+ #[test]
+ fn union_encoder_string_int() {
+ let strings = StringArray::from(vec!["hello", "world"]);
+ let ints = Int32Array::from(vec![10, 20, 30]);
+
+ let union_fields = UnionFields::new(
+ vec![0, 1],
+ vec![
+ Field::new("v_str", DataType::Utf8, true),
+ Field::new("v_int", DataType::Int32, true),
+ ],
+ );
+
+ let type_ids = Buffer::from_slice_ref([0_i8, 1, 1, 0, 1]);
+ let offsets = Buffer::from_slice_ref([0_i32, 0, 1, 1, 2]);
+
+ let union_array = UnionArray::try_new(
+ union_fields,
+ type_ids.into(),
+ Some(offsets.into()),
+ vec![Arc::new(strings), Arc::new(ints)],
+ )
+ .unwrap();
+
+ let plan = FieldPlan::Union {
+ bindings: vec![
+ FieldBinding {
+ arrow_index: 0,
+ nullability: None,
+ plan: FieldPlan::Scalar,
+ },
+ FieldBinding {
+ arrow_index: 1,
+ nullability: None,
+ plan: FieldPlan::Scalar,
+ },
+ ],
+ };
+
+ let got = encode_all(&union_array, &plan, None);
+
+ let mut expected = Vec::new();
+ expected.extend(avro_long_bytes(0));
+ expected.extend(avro_len_prefixed_bytes(b"hello"));
+ expected.extend(avro_long_bytes(1));
+ expected.extend(avro_long_bytes(10));
+ expected.extend(avro_long_bytes(1));
+ expected.extend(avro_long_bytes(20));
+ expected.extend(avro_long_bytes(0));
+ expected.extend(avro_len_prefixed_bytes(b"world"));
+ expected.extend(avro_long_bytes(1));
+ expected.extend(avro_long_bytes(30));
+
+ assert_bytes_eq(&got, &expected);
+ }
+
+ #[test]
+ fn union_encoder_null_string_int() {
+ let nulls = NullArray::new(1);
+ let strings = StringArray::from(vec!["hello"]);
+ let ints = Int32Array::from(vec![10]);
+
+ let union_fields = UnionFields::new(
+ vec![0, 1, 2],
+ vec![
+ Field::new("v_null", DataType::Null, true),
+ Field::new("v_str", DataType::Utf8, true),
+ Field::new("v_int", DataType::Int32, true),
+ ],
+ );
+
+ let type_ids = Buffer::from_slice_ref([0_i8, 1, 2]);
+ // For a null value in a dense union, no value is added to a child array.
+ // The offset points to the last value of that type. Since there's only one
+ // null, and one of each other type, all offsets are 0.
+ let offsets = Buffer::from_slice_ref([0_i32, 0, 0]);
+
+ let union_array = UnionArray::try_new(
+ union_fields,
+ type_ids.into(),
+ Some(offsets.into()),
+ vec![Arc::new(nulls), Arc::new(strings), Arc::new(ints)],
+ )
+ .unwrap();
+
+ let plan = FieldPlan::Union {
+ bindings: vec![
+ FieldBinding {
+ arrow_index: 0,
+ nullability: None,
+ plan: FieldPlan::Scalar,
+ },
+ FieldBinding {
+ arrow_index: 1,
+ nullability: None,
+ plan: FieldPlan::Scalar,
+ },
+ FieldBinding {
+ arrow_index: 2,
+ nullability: None,
+ plan: FieldPlan::Scalar,
+ },
+ ],
+ };
+
+ let got = encode_all(&union_array, &plan, None);
+
+ let mut expected = Vec::new();
+ expected.extend(avro_long_bytes(0));
+ expected.extend(avro_long_bytes(1));
+ expected.extend(avro_len_prefixed_bytes(b"hello"));
+ expected.extend(avro_long_bytes(2));
+ expected.extend(avro_long_bytes(10));
+
+ assert_bytes_eq(&got, &expected);
+ }
+
#[test]
fn list64_encoder_int32() {
// LargeList [[1,2,3], []]
diff --git a/arrow-avro/src/writer/mod.rs b/arrow-avro/src/writer/mod.rs
index 5bc2700c1b..8f7cb666db 100644
--- a/arrow-avro/src/writer/mod.rs
+++ b/arrow-avro/src/writer/mod.rs
@@ -394,12 +394,14 @@ mod tests {
use crate::reader::ReaderBuilder;
use crate::schema::{AvroSchema, SchemaStore};
use crate::test_util::arrow_test_data;
- use arrow_array::{ArrayRef, BinaryArray, Int32Array, RecordBatch};
+ use arrow_array::{
+ Array, ArrayRef, BinaryArray, Int32Array, RecordBatch, StructArray, UnionArray,
+ };
#[cfg(not(feature = "avro_custom_types"))]
use arrow_schema::{DataType, Field, Schema};
#[cfg(feature = "avro_custom_types")]
use arrow_schema::{DataType, Field, Schema, TimeUnit};
- #[cfg(feature = "avro_custom_types")]
+ use std::collections::HashMap;
use std::collections::HashSet;
use std::fs::File;
use std::io::{BufReader, Cursor};
@@ -1017,6 +1019,248 @@ mod tests {
Ok(())
}
+ // Union Roundtrip Test Helpers
+
+ // Asserts that the `actual` schema is a semantically equivalent superset of the `expected` one.
+ // This allows the `actual` schema to contain additional metadata keys
+ // (`arrowUnionMode`, `arrowUnionTypeIds`, `avro.name`) that are added during an Arrow-to-Avro-to-Arrow
+ // roundtrip, while ensuring no other information was lost or changed.
+ fn assert_schema_is_semantically_equivalent(expected: &Schema, actual: &Schema) {
+ // Compare top-level schema metadata using the same superset logic.
+ assert_metadata_is_superset(expected.metadata(), actual.metadata(), "Schema");
+
+ // Compare fields.
+ assert_eq!(
+ expected.fields().len(),
+ actual.fields().len(),
+ "Schema must have the same number of fields"
+ );
+
+ for (expected_field, actual_field) in expected.fields().iter().zip(actual.fields().iter()) {
+ assert_field_is_semantically_equivalent(expected_field, actual_field);
+ }
+ }
+
+ fn assert_field_is_semantically_equivalent(expected: &Field, actual: &Field) {
+ let context = format!("Field '{}'", expected.name());
+
+ assert_eq!(
+ expected.name(),
+ actual.name(),
+ "{context}: names must match"
+ );
+ assert_eq!(
+ expected.is_nullable(),
+ actual.is_nullable(),
+ "{context}: nullability must match"
+ );
+
+ // Recursively check the data types.
+ assert_datatype_is_semantically_equivalent(
+ expected.data_type(),
+ actual.data_type(),
+ &context,
+ );
+
+ // Check that metadata is a valid superset.
+ assert_metadata_is_superset(expected.metadata(), actual.metadata(), &context);
+ }
+
+ fn assert_datatype_is_semantically_equivalent(
+ expected: &DataType,
+ actual: &DataType,
+ context: &str,
+ ) {
+ match (expected, actual) {
+ (DataType::List(expected_field), DataType::List(actual_field))
+ | (DataType::LargeList(expected_field), DataType::LargeList(actual_field))
+ | (DataType::Map(expected_field, _), DataType::Map(actual_field, _)) => {
+ assert_field_is_semantically_equivalent(expected_field, actual_field);
+ }
+ (DataType::Struct(expected_fields), DataType::Struct(actual_fields)) => {
+ assert_eq!(
+ expected_fields.len(),
+ actual_fields.len(),
+ "{context}: struct must have same number of fields"
+ );
+ for (ef, af) in expected_fields.iter().zip(actual_fields.iter()) {
+ assert_field_is_semantically_equivalent(ef, af);
+ }
+ }
+ (
+ DataType::Union(expected_fields, expected_mode),
+ DataType::Union(actual_fields, actual_mode),
+ ) => {
+ assert_eq!(
+ expected_mode, actual_mode,
+ "{context}: union mode must match"
+ );
+ assert_eq!(
+ expected_fields.len(),
+ actual_fields.len(),
+ "{context}: union must have same number of variants"
+ );
+ for ((exp_id, exp_field), (act_id, act_field)) in
+ expected_fields.iter().zip(actual_fields.iter())
+ {
+ assert_eq!(exp_id, act_id, "{context}: union type ids must match");
+ assert_field_is_semantically_equivalent(exp_field, act_field);
+ }
+ }
+ _ => {
+ assert_eq!(expected, actual, "{context}: data types must be identical");
+ }
+ }
+ }
+
+ fn assert_batch_data_is_identical(expected: &RecordBatch, actual: &RecordBatch) {
+ assert_eq!(
+ expected.num_columns(),
+ actual.num_columns(),
+ "RecordBatches must have the same number of columns"
+ );
+ assert_eq!(
+ expected.num_rows(),
+ actual.num_rows(),
+ "RecordBatches must have the same number of rows"
+ );
+
+ for i in 0..expected.num_columns() {
+ let context = format!("Column {i}");
+ let expected_col = expected.column(i);
+ let actual_col = actual.column(i);
+ assert_array_data_is_identical(expected_col, actual_col, &context);
+ }
+ }
+
+ /// Recursively asserts that the data content of two Arrays is identical.
+ fn assert_array_data_is_identical(expected: &dyn Array, actual: &dyn Array, context: &str) {
+ assert_eq!(
+ expected.nulls(),
+ actual.nulls(),
+ "{context}: null buffers must match"
+ );
+ assert_eq!(
+ expected.len(),
+ actual.len(),
+ "{context}: array lengths must match"
+ );
+
+ match (expected.data_type(), actual.data_type()) {
+ (DataType::Union(expected_fields, _), DataType::Union(..)) => {
+ let expected_union = expected.as_any().downcast_ref::<UnionArray>().unwrap();
+ let actual_union = actual.as_any().downcast_ref::<UnionArray>().unwrap();
+
+ // Compare the type_ids buffer (always the first buffer).
+ assert_eq!(
+ &expected.to_data().buffers()[0],
+ &actual.to_data().buffers()[0],
+ "{context}: union type_ids buffer mismatch"
+ );
+
+ // For dense unions, compare the value_offsets buffer (the second buffer).
+ if expected.to_data().buffers().len() > 1 {
+ assert_eq!(
+ &expected.to_data().buffers()[1],
+ &actual.to_data().buffers()[1],
+ "{context}: union value_offsets buffer mismatch"
+ );
+ }
+
+ // Recursively compare children based on the fields in the DataType.
+ for (type_id, _) in expected_fields.iter() {
+ let child_context = format!("{context} -> child variant {type_id}");
+ assert_array_data_is_identical(
+ expected_union.child(type_id),
+ actual_union.child(type_id),
+ &child_context,
+ );
+ }
+ }
+ (DataType::Struct(_), DataType::Struct(_)) => {
+ let expected_struct = expected.as_any().downcast_ref::<StructArray>().unwrap();
+ let actual_struct = actual.as_any().downcast_ref::<StructArray>().unwrap();
+ for i in 0..expected_struct.num_columns() {
+ let child_context = format!("{context} -> struct child {i}");
+ assert_array_data_is_identical(
+ expected_struct.column(i),
+ actual_struct.column(i),
+ &child_context,
+ );
+ }
+ }
+ // Fallback for primitive types and other types where buffer comparison is sufficient.
+ _ => {
+ assert_eq!(
+ expected.to_data().buffers(),
+ actual.to_data().buffers(),
+ "{context}: data buffers must match"
+ );
+ }
+ }
+ }
+
+ /// Checks that `actual_meta` contains all of `expected_meta`, and any additional
+ /// keys in `actual_meta` are from a permitted set.
+ fn assert_metadata_is_superset(
+ expected_meta: &HashMap<String, String>,
+ actual_meta: &HashMap<String, String>,
+ context: &str,
+ ) {
+ let allowed_additions: HashSet<&str> = vec!["arrowUnionMode", "arrowUnionTypeIds", "avro.name"]
+ .into_iter()
+ .collect();
+ for (key, expected_value) in expected_meta {
+ match actual_meta.get(key) {
+ Some(actual_value) => assert_eq!(
+ expected_value, actual_value,
+ "{context}: preserved metadata for key '{key}' must have the same value"
+ ),
+ None => panic!("{context}: metadata key '{key}' was lost during roundtrip"),
+ }
+ }
+ for key in actual_meta.keys() {
+ if !expected_meta.contains_key(key) && !allowed_additions.contains(key.as_str()) {
+ panic!("{context}: unexpected metadata key '{key}' was added during roundtrip");
+ }
+ }
+ }
+
+ #[test]
+ fn test_union_roundtrip() -> Result<(), ArrowError> {
+ let file_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
+ .join("test/data/union_fields.avro")
+ .to_string_lossy()
+ .into_owned();
+ let rdr_file = File::open(&file_path).expect("open avro/union_fields.avro");
+ let reader = ReaderBuilder::new()
+ .build(BufReader::new(rdr_file))
+ .expect("build reader for union_fields.avro");
+ let schema = reader.schema();
+ let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
+ let original =
+ arrow::compute::concat_batches(&schema, &input_batches).expect("concat input");
+ let mut writer = AvroWriter::new(Vec::<u8>::new(), original.schema().as_ref().clone())?;
+ writer.write(&original)?;
+ writer.finish()?;
+ let bytes = writer.into_inner();
+ let rt_reader = ReaderBuilder::new()
+ .build(Cursor::new(bytes))
+ .expect("build round_trip reader");
+ let rt_schema = rt_reader.schema();
+ let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
+ let round_trip =
+ arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
+
+ // The nature of the crate is such that metadata gets appended during the roundtrip,
+ // so we can't compare the schemas directly. Instead, we semantically compare the schemas and data.
+ assert_schema_is_semantically_equivalent(&original.schema(), &round_trip.schema());
+
+ assert_batch_data_is_identical(&original, &round_trip);
+ Ok(())
+ }
+
#[test]
fn test_enum_roundtrip_uses_reader_fixture() -> Result<(), ArrowError> {
// Read the known-good enum file (same as reader::test_simple)