mbrobbel commented on code in PR #8550:
URL: https://github.com/apache/arrow-rs/pull/8550#discussion_r2409850298
##########
arrow-avro/src/writer/mod.rs:
##########
@@ -1017,6 +1019,251 @@ mod tests {
Ok(())
}
+ // Union Roundtrip Test Helpers
+
+ // Asserts that the `actual` schema is a semantically equivalent superset
of the `expected` one.
+ // This allows the `actual` schema to contain additional metadata keys
+ // (`arrowUnionMode`, `arrowUnionTypeIds`, `avro.name`) that are added
during an Arrow-to-Avro-to-Arrow
+ // roundtrip, while ensuring no other information was lost or changed.
+ fn assert_schema_is_semantically_equivalent(expected: &Schema, actual:
&Schema) {
+ // Compare top-level schema metadata using the same superset logic.
+ assert_metadata_is_superset(expected.metadata(), actual.metadata(),
"Schema");
+
+ // Compare fields.
+ assert_eq!(
+ expected.fields().len(),
+ actual.fields().len(),
+ "Schema must have the same number of fields"
+ );
+
+ for (expected_field, actual_field) in
expected.fields().iter().zip(actual.fields().iter()) {
+ assert_field_is_semantically_equivalent(expected_field,
actual_field);
+ }
+ }
+
+ fn assert_field_is_semantically_equivalent(expected: &Field, actual:
&Field) {
+ let context = format!("Field '{}'", expected.name());
+
+ assert_eq!(
+ expected.name(),
+ actual.name(),
+ "{context}: names must match"
+ );
+ assert_eq!(
+ expected.is_nullable(),
+ actual.is_nullable(),
+ "{context}: nullability must match"
+ );
+
+ // Recursively check the data types.
+ assert_datatype_is_semantically_equivalent(
+ expected.data_type(),
+ actual.data_type(),
+ &context,
+ );
+
+ // Check that metadata is a valid superset.
+ assert_metadata_is_superset(expected.metadata(), actual.metadata(),
&context);
+ }
+
+ fn assert_datatype_is_semantically_equivalent(
+ expected: &DataType,
+ actual: &DataType,
+ context: &str,
+ ) {
+ match (expected, actual) {
+ (DataType::List(expected_field), DataType::List(actual_field))
+ | (DataType::LargeList(expected_field),
DataType::LargeList(actual_field))
+ | (DataType::Map(expected_field, _), DataType::Map(actual_field,
_)) => {
+ assert_field_is_semantically_equivalent(expected_field,
actual_field);
+ }
+ (DataType::Struct(expected_fields),
DataType::Struct(actual_fields)) => {
+ assert_eq!(
+ expected_fields.len(),
+ actual_fields.len(),
+ "{context}: struct must have same number of fields"
+ );
+ for (ef, af) in
expected_fields.iter().zip(actual_fields.iter()) {
+ assert_field_is_semantically_equivalent(ef, af);
+ }
+ }
+ (
+ DataType::Union(expected_fields, expected_mode),
+ DataType::Union(actual_fields, actual_mode),
+ ) => {
+ assert_eq!(
+ expected_mode, actual_mode,
+ "{context}: union mode must match"
+ );
+ assert_eq!(
+ expected_fields.len(),
+ actual_fields.len(),
+ "{context}: union must have same number of variants"
+ );
+ for ((exp_id, exp_field), (act_id, act_field)) in
+ expected_fields.iter().zip(actual_fields.iter())
+ {
+ assert_eq!(exp_id, act_id, "{context}: union type ids must
match");
+ assert_field_is_semantically_equivalent(exp_field,
act_field);
+ }
+ }
+ _ => {
+ assert_eq!(expected, actual, "{context}: data types must be
identical");
+ }
+ }
+ }
+
+ fn assert_batch_data_is_identical(expected: &RecordBatch, actual:
&RecordBatch) {
+ assert_eq!(
+ expected.num_columns(),
+ actual.num_columns(),
+ "RecordBatches must have the same number of columns"
+ );
+ assert_eq!(
+ expected.num_rows(),
+ actual.num_rows(),
+ "RecordBatches must have the same number of rows"
+ );
+
+ for i in 0..expected.num_columns() {
+ let context = format!("Column {}", i);
Review Comment:
```suggestion
let context = format!("Column {i}");
```
##########
arrow-avro/src/writer/encoder.rs:
##########
@@ -1086,6 +1173,58 @@ impl EnumEncoder<'_> {
}
}
+struct UnionEncoder<'a> {
+ encoders: Vec<FieldEncoder<'a>>,
+ array: &'a UnionArray,
+}
+
+impl<'a> UnionEncoder<'a> {
+ fn try_new(array: &'a UnionArray, field_bindings: &[FieldBinding]) ->
Result<Self, ArrowError> {
+ let DataType::Union(fields, UnionMode::Dense) = array.data_type() else
{
+ return Err(ArrowError::SchemaError("Expected Dense
UnionArray".into()));
+ };
+
+ if fields.len() != field_bindings.len() {
+ return Err(ArrowError::SchemaError(format!(
+ "Mismatched number of union branches between Arrow array ({})
and encoding plan ({})",
+ fields.len(),
+ field_bindings.len()
+ )));
+ }
+ let mut encoders = Vec::with_capacity(fields.len());
+ for (type_id, field_ref) in fields.iter() {
+ let binding = field_bindings
+ .get(type_id as usize)
+ .ok_or_else(|| ArrowError::SchemaError("Binding and field
mismatch".to_string()))?;
+
+ let child = array.child(type_id).as_ref();
+
+ let encoder = prepare_value_site_encoder(
+ child,
+ field_ref.as_ref(),
+ binding.nullability,
+ &binding.plan,
+ )?;
+ encoders.push(encoder);
+ }
+ Ok(Self { encoders, array })
+ }
+
+ fn encode<W: Write + ?Sized>(&mut self, out: &mut W, idx: usize) ->
Result<(), ArrowError> {
+ let type_id = self.array.type_ids()[idx];
+ let branch_index = type_id as usize;
+ write_int(out, type_id as i32)?;
+ let child_row = self.array.value_offset(idx);
+
+ let encoder = self
+ .encoders
+ .get_mut(branch_index)
+ .ok_or_else(|| ArrowError::SchemaError(format!("Invalid type_id
{}", type_id)))?;
Review Comment:
```suggestion
.ok_or_else(|| ArrowError::SchemaError(format!("Invalid type_id {type_id}")))?;
```
##########
arrow-avro/src/writer/mod.rs:
##########
@@ -1017,6 +1019,251 @@ mod tests {
Ok(())
}
+ // Union Roundtrip Test Helpers
+
+ // Asserts that the `actual` schema is a semantically equivalent superset
of the `expected` one.
+ // This allows the `actual` schema to contain additional metadata keys
+ // (`arrowUnionMode`, `arrowUnionTypeIds`, `avro.name`) that are added
during an Arrow-to-Avro-to-Arrow
+ // roundtrip, while ensuring no other information was lost or changed.
+ fn assert_schema_is_semantically_equivalent(expected: &Schema, actual:
&Schema) {
+ // Compare top-level schema metadata using the same superset logic.
+ assert_metadata_is_superset(expected.metadata(), actual.metadata(),
"Schema");
+
+ // Compare fields.
+ assert_eq!(
+ expected.fields().len(),
+ actual.fields().len(),
+ "Schema must have the same number of fields"
+ );
+
+ for (expected_field, actual_field) in
expected.fields().iter().zip(actual.fields().iter()) {
+ assert_field_is_semantically_equivalent(expected_field,
actual_field);
+ }
+ }
+
+ fn assert_field_is_semantically_equivalent(expected: &Field, actual:
&Field) {
+ let context = format!("Field '{}'", expected.name());
+
+ assert_eq!(
+ expected.name(),
+ actual.name(),
+ "{context}: names must match"
+ );
+ assert_eq!(
+ expected.is_nullable(),
+ actual.is_nullable(),
+ "{context}: nullability must match"
+ );
+
+ // Recursively check the data types.
+ assert_datatype_is_semantically_equivalent(
+ expected.data_type(),
+ actual.data_type(),
+ &context,
+ );
+
+ // Check that metadata is a valid superset.
+ assert_metadata_is_superset(expected.metadata(), actual.metadata(),
&context);
+ }
+
+ fn assert_datatype_is_semantically_equivalent(
+ expected: &DataType,
+ actual: &DataType,
+ context: &str,
+ ) {
+ match (expected, actual) {
+ (DataType::List(expected_field), DataType::List(actual_field))
+ | (DataType::LargeList(expected_field),
DataType::LargeList(actual_field))
+ | (DataType::Map(expected_field, _), DataType::Map(actual_field,
_)) => {
+ assert_field_is_semantically_equivalent(expected_field,
actual_field);
+ }
+ (DataType::Struct(expected_fields),
DataType::Struct(actual_fields)) => {
+ assert_eq!(
+ expected_fields.len(),
+ actual_fields.len(),
+ "{context}: struct must have same number of fields"
+ );
+ for (ef, af) in
expected_fields.iter().zip(actual_fields.iter()) {
+ assert_field_is_semantically_equivalent(ef, af);
+ }
+ }
+ (
+ DataType::Union(expected_fields, expected_mode),
+ DataType::Union(actual_fields, actual_mode),
+ ) => {
+ assert_eq!(
+ expected_mode, actual_mode,
+ "{context}: union mode must match"
+ );
+ assert_eq!(
+ expected_fields.len(),
+ actual_fields.len(),
+ "{context}: union must have same number of variants"
+ );
+ for ((exp_id, exp_field), (act_id, act_field)) in
+ expected_fields.iter().zip(actual_fields.iter())
+ {
+ assert_eq!(exp_id, act_id, "{context}: union type ids must
match");
+ assert_field_is_semantically_equivalent(exp_field,
act_field);
+ }
+ }
+ _ => {
+ assert_eq!(expected, actual, "{context}: data types must be
identical");
+ }
+ }
+ }
+
+ fn assert_batch_data_is_identical(expected: &RecordBatch, actual:
&RecordBatch) {
+ assert_eq!(
+ expected.num_columns(),
+ actual.num_columns(),
+ "RecordBatches must have the same number of columns"
+ );
+ assert_eq!(
+ expected.num_rows(),
+ actual.num_rows(),
+ "RecordBatches must have the same number of rows"
+ );
+
+ for i in 0..expected.num_columns() {
+ let context = format!("Column {}", i);
+ let expected_col = expected.column(i);
+ let actual_col = actual.column(i);
+ assert_array_data_is_identical(expected_col, actual_col, &context);
+ }
+ }
+
+ /// Recursively asserts that the data content of two Arrays is identical.
+ fn assert_array_data_is_identical(expected: &dyn Array, actual: &dyn
Array, context: &str) {
+ assert_eq!(
+ expected.nulls(),
+ actual.nulls(),
+ "{context}: null buffers must match"
+ );
+ assert_eq!(
+ expected.len(),
+ actual.len(),
+ "{context}: array lengths must match"
+ );
+
+ match (expected.data_type(), actual.data_type()) {
+ (DataType::Union(expected_fields, _), DataType::Union(..)) => {
+ let expected_union =
expected.as_any().downcast_ref::<UnionArray>().unwrap();
+ let actual_union =
actual.as_any().downcast_ref::<UnionArray>().unwrap();
+
+ // Compare the type_ids buffer (always the first buffer).
+ assert_eq!(
+ &expected.to_data().buffers()[0],
+ &actual.to_data().buffers()[0],
+ "{context}: union type_ids buffer mismatch"
+ );
+
+ // For dense unions, compare the value_offsets buffer (the
second buffer).
+ if expected.to_data().buffers().len() > 1 {
+ assert_eq!(
+ &expected.to_data().buffers()[1],
+ &actual.to_data().buffers()[1],
+ "{context}: union value_offsets buffer mismatch"
+ );
+ }
+
+ // Recursively compare children based on the fields in the
DataType.
+ for (type_id, _) in expected_fields.iter() {
+ let child_context = format!("{context} -> child variant
{type_id}");
+ assert_array_data_is_identical(
+ expected_union.child(type_id),
+ actual_union.child(type_id),
+ &child_context,
+ );
+ }
+ }
+ (DataType::Struct(_), DataType::Struct(_)) => {
+ let expected_struct =
expected.as_any().downcast_ref::<StructArray>().unwrap();
+ let actual_struct =
actual.as_any().downcast_ref::<StructArray>().unwrap();
+ for i in 0..expected_struct.num_columns() {
+ let child_context = format!("{context} -> struct child
{i}");
+ assert_array_data_is_identical(
+ expected_struct.column(i),
+ actual_struct.column(i),
+ &child_context,
+ );
+ }
+ }
+ // Fallback for primitive types and other types where buffer
comparison is sufficient.
+ _ => {
+ assert_eq!(
+ expected.to_data().buffers(),
+ actual.to_data().buffers(),
+ "{context}: data buffers must match"
+ );
+ }
+ }
+ }
+
+ /// Checks that `actual_meta` contains all of `expected_meta`, and any
additional
+ /// keys in `actual_meta` are from a permitted set.
+ fn assert_metadata_is_superset(
+ expected_meta: &HashMap<String, String>,
+ actual_meta: &HashMap<String, String>,
+ context: &str,
+ ) {
+ let allowed_additions: HashSet<&str> =
+ vec!["arrowUnionMode", "arrowUnionTypeIds", "avro.name"]
+ .into_iter()
+ .collect();
+ for (key, expected_value) in expected_meta {
+ match actual_meta.get(key) {
+ Some(actual_value) => assert_eq!(
+ expected_value, actual_value,
+ "{context}: preserved metadata for key '{key}' must have
the same value"
+ ),
+ None => panic!("{context}: metadata key '{key}' was lost
during roundtrip"),
+ }
+ }
+ for key in actual_meta.keys() {
+ if !expected_meta.contains_key(key) &&
!allowed_additions.contains(key.as_str()) {
+ panic!("{context}: unexpected metadata key '{key}' was added
during roundtrip");
+ }
+ }
+ }
+
+ #[test]
+ fn test_union_roundtrip() -> Result<(), ArrowError> {
+ let file_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
+ .join("test/data/union_fields.avro")
+ .to_string_lossy()
+ .into_owned();
+ let rdr_file = File::open(&file_path).expect("open
avro/union_fields.avro");
+ let reader = ReaderBuilder::new()
+ .build(BufReader::new(rdr_file))
+ .expect("build reader for union_fields.avro");
+ let schema = reader.schema();
+ println!("schema: {:?}", schema);
Review Comment:
Please remove this debug `println!` (the empty suggestion below deletes the line):
```suggestion
```
##########
arrow-avro/src/writer/mod.rs:
##########
@@ -1017,6 +1019,251 @@ mod tests {
Ok(())
}
+ // Union Roundtrip Test Helpers
+
+ // Asserts that the `actual` schema is a semantically equivalent superset
of the `expected` one.
+ // This allows the `actual` schema to contain additional metadata keys
+ // (`arrowUnionMode`, `arrowUnionTypeIds`, `avro.name`) that are added
during an Arrow-to-Avro-to-Arrow
+ // roundtrip, while ensuring no other information was lost or changed.
+ fn assert_schema_is_semantically_equivalent(expected: &Schema, actual:
&Schema) {
+ // Compare top-level schema metadata using the same superset logic.
+ assert_metadata_is_superset(expected.metadata(), actual.metadata(),
"Schema");
+
+ // Compare fields.
+ assert_eq!(
+ expected.fields().len(),
+ actual.fields().len(),
+ "Schema must have the same number of fields"
+ );
+
+ for (expected_field, actual_field) in
expected.fields().iter().zip(actual.fields().iter()) {
+ assert_field_is_semantically_equivalent(expected_field,
actual_field);
+ }
+ }
+
+ fn assert_field_is_semantically_equivalent(expected: &Field, actual:
&Field) {
+ let context = format!("Field '{}'", expected.name());
+
+ assert_eq!(
+ expected.name(),
+ actual.name(),
+ "{context}: names must match"
+ );
+ assert_eq!(
+ expected.is_nullable(),
+ actual.is_nullable(),
+ "{context}: nullability must match"
+ );
+
+ // Recursively check the data types.
+ assert_datatype_is_semantically_equivalent(
+ expected.data_type(),
+ actual.data_type(),
+ &context,
+ );
+
+ // Check that metadata is a valid superset.
+ assert_metadata_is_superset(expected.metadata(), actual.metadata(),
&context);
+ }
+
+ fn assert_datatype_is_semantically_equivalent(
+ expected: &DataType,
+ actual: &DataType,
+ context: &str,
+ ) {
+ match (expected, actual) {
+ (DataType::List(expected_field), DataType::List(actual_field))
+ | (DataType::LargeList(expected_field),
DataType::LargeList(actual_field))
+ | (DataType::Map(expected_field, _), DataType::Map(actual_field,
_)) => {
+ assert_field_is_semantically_equivalent(expected_field,
actual_field);
+ }
+ (DataType::Struct(expected_fields),
DataType::Struct(actual_fields)) => {
+ assert_eq!(
+ expected_fields.len(),
+ actual_fields.len(),
+ "{context}: struct must have same number of fields"
+ );
+ for (ef, af) in
expected_fields.iter().zip(actual_fields.iter()) {
+ assert_field_is_semantically_equivalent(ef, af);
+ }
+ }
+ (
+ DataType::Union(expected_fields, expected_mode),
+ DataType::Union(actual_fields, actual_mode),
+ ) => {
+ assert_eq!(
+ expected_mode, actual_mode,
+ "{context}: union mode must match"
+ );
+ assert_eq!(
+ expected_fields.len(),
+ actual_fields.len(),
+ "{context}: union must have same number of variants"
+ );
+ for ((exp_id, exp_field), (act_id, act_field)) in
+ expected_fields.iter().zip(actual_fields.iter())
+ {
+ assert_eq!(exp_id, act_id, "{context}: union type ids must
match");
+ assert_field_is_semantically_equivalent(exp_field,
act_field);
+ }
+ }
+ _ => {
+ assert_eq!(expected, actual, "{context}: data types must be
identical");
+ }
+ }
+ }
+
+ fn assert_batch_data_is_identical(expected: &RecordBatch, actual:
&RecordBatch) {
+ assert_eq!(
+ expected.num_columns(),
+ actual.num_columns(),
+ "RecordBatches must have the same number of columns"
+ );
+ assert_eq!(
+ expected.num_rows(),
+ actual.num_rows(),
+ "RecordBatches must have the same number of rows"
+ );
+
+ for i in 0..expected.num_columns() {
+ let context = format!("Column {}", i);
+ let expected_col = expected.column(i);
+ let actual_col = actual.column(i);
+ assert_array_data_is_identical(expected_col, actual_col, &context);
+ }
+ }
+
+ /// Recursively asserts that the data content of two Arrays is identical.
+ fn assert_array_data_is_identical(expected: &dyn Array, actual: &dyn
Array, context: &str) {
+ assert_eq!(
+ expected.nulls(),
+ actual.nulls(),
+ "{context}: null buffers must match"
+ );
+ assert_eq!(
+ expected.len(),
+ actual.len(),
+ "{context}: array lengths must match"
+ );
+
+ match (expected.data_type(), actual.data_type()) {
+ (DataType::Union(expected_fields, _), DataType::Union(..)) => {
+ let expected_union =
expected.as_any().downcast_ref::<UnionArray>().unwrap();
+ let actual_union =
actual.as_any().downcast_ref::<UnionArray>().unwrap();
+
+ // Compare the type_ids buffer (always the first buffer).
+ assert_eq!(
+ &expected.to_data().buffers()[0],
+ &actual.to_data().buffers()[0],
+ "{context}: union type_ids buffer mismatch"
+ );
+
+ // For dense unions, compare the value_offsets buffer (the
second buffer).
+ if expected.to_data().buffers().len() > 1 {
+ assert_eq!(
+ &expected.to_data().buffers()[1],
+ &actual.to_data().buffers()[1],
+ "{context}: union value_offsets buffer mismatch"
+ );
+ }
+
+ // Recursively compare children based on the fields in the
DataType.
+ for (type_id, _) in expected_fields.iter() {
+ let child_context = format!("{context} -> child variant
{type_id}");
+ assert_array_data_is_identical(
+ expected_union.child(type_id),
+ actual_union.child(type_id),
+ &child_context,
+ );
+ }
+ }
+ (DataType::Struct(_), DataType::Struct(_)) => {
+ let expected_struct =
expected.as_any().downcast_ref::<StructArray>().unwrap();
+ let actual_struct =
actual.as_any().downcast_ref::<StructArray>().unwrap();
+ for i in 0..expected_struct.num_columns() {
+ let child_context = format!("{context} -> struct child
{i}");
+ assert_array_data_is_identical(
+ expected_struct.column(i),
+ actual_struct.column(i),
+ &child_context,
+ );
+ }
+ }
+ // Fallback for primitive types and other types where buffer
comparison is sufficient.
+ _ => {
+ assert_eq!(
+ expected.to_data().buffers(),
+ actual.to_data().buffers(),
+ "{context}: data buffers must match"
+ );
+ }
+ }
+ }
+
+ /// Checks that `actual_meta` contains all of `expected_meta`, and any
additional
+ /// keys in `actual_meta` are from a permitted set.
+ fn assert_metadata_is_superset(
+ expected_meta: &HashMap<String, String>,
+ actual_meta: &HashMap<String, String>,
+ context: &str,
+ ) {
+ let allowed_additions: HashSet<&str> =
+ vec!["arrowUnionMode", "arrowUnionTypeIds", "avro.name"]
+ .into_iter()
+ .collect();
+ for (key, expected_value) in expected_meta {
+ match actual_meta.get(key) {
+ Some(actual_value) => assert_eq!(
+ expected_value, actual_value,
+ "{context}: preserved metadata for key '{key}' must have
the same value"
+ ),
+ None => panic!("{context}: metadata key '{key}' was lost
during roundtrip"),
+ }
+ }
+ for key in actual_meta.keys() {
+ if !expected_meta.contains_key(key) &&
!allowed_additions.contains(key.as_str()) {
+ panic!("{context}: unexpected metadata key '{key}' was added
during roundtrip");
+ }
+ }
+ }
+
+ #[test]
+ fn test_union_roundtrip() -> Result<(), ArrowError> {
+ let file_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
+ .join("test/data/union_fields.avro")
+ .to_string_lossy()
+ .into_owned();
+ let rdr_file = File::open(&file_path).expect("open
avro/union_fields.avro");
+ let reader = ReaderBuilder::new()
+ .build(BufReader::new(rdr_file))
+ .expect("build reader for union_fields.avro");
+ let schema = reader.schema();
+ println!("schema: {:?}", schema);
+ let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
+ let original =
+ arrow::compute::concat_batches(&schema,
&input_batches).expect("concat input");
+ let tmp = NamedTempFile::new().expect("create temp file");
+ let out_file = File::create(tmp.path()).expect("create temp avro");
Review Comment:
Maybe we can also use in-memory buffers here instead of a temporary file
(https://github.com/apache/arrow-rs/pull/8546#discussion_r2405484016)?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]