This is an automated email from the ASF dual-hosted git repository. alamb pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push: new aed2f3b6a7 Add arrow-avro Reader support for Dense Union and Union resolution (Part 1) (#8348) aed2f3b6a7 is described below commit aed2f3b6a72375acb06cf958c9e3ff3c6ecb760f Author: Connor Sanders <con...@elastiflow.com> AuthorDate: Thu Sep 18 05:20:27 2025 -0500 Add arrow-avro Reader support for Dense Union and Union resolution (Part 1) (#8348) # Which issue does this PR close? This work continues arrow-avro schema resolution support and aligns behavior with the Avro spec. - **Related to**: #4886 (“Add Avro Support”): ongoing work to round out the reader/decoder, including schema resolution and type promotion. # Rationale for this change `arrow-avro` lacked end‑to‑end support for Avro unions and Arrow `Union` schemas. Many Avro datasets rely on unions (e.g., `["null","string"]` or tagged unions of different records), and without schema‐level resolution and JSON encoding the crate could not interoperate cleanly. This PR brings union schema resolution to parity with the Avro spec (duplicate-branch and nested‑union checks), adds Arrow-to-Avro union schema conversion (with mode/type‑id metadata), and lays groundwork for data decoding in a follow‑up. # What changes are included in this PR? **Schema resolution & codecs** - Add `Codec::Union(Arc<[AvroDataType]>, UnionFields, UnionMode)` and map it to Arrow `DataType::Union`. - Introduce `ResolvedUnion` and extend `ResolutionInfo` with a `Union(...)` variant to capture writer-to-reader branch mapping (prefers direct matches over promotions). - Support union defaults: permit `null` defaults for unions whose **first** branch is `null`; reject empty unions for defaults. - Enforce Avro spec constraints during parsing/resolution: - Disallow nested unions. - Disallow duplicate branch *kinds* (except distinct named `record`/`enum`/`fixed`). - Keep **writer** null ordering when resolving nullable 2‑branch unions (e.g., `["null", "int"]` vs. `["int", "null"]`). 
- Provide stable union field names derived from branch kind (e.g., `int`, `string`, `map`, ...) and construct dense `UnionFields` consistently. **Arrow and Avro schema conversion** - Implement Arrow `DataType::Union` to Avro union JSON: - Persist Arrow union layout via metadata keys: - `"arrowUnionMode"`: `"dense"` or `"sparse"`. - `"arrowUnionTypeIds"`: ordered list of Arrow type IDs. - Attach union‑level metadata to the **first non‑null** branch object (Avro JSON can’t carry attributes on the union array). - Persist additional Arrow metadata in Avro JSON: - `"arrowBinaryView"` for `BinaryView`. - `"arrowListView"` / `"arrowLargeList"` for list view types. - Reject invalid output shapes (e.g., a union branch that is itself an Avro union). **Reader/decoder stub** - Return a clear error for union **value** decoding in `RecordDecoder` (schema support first; decoding to follow). **Refactors & utilities** - Expose `make_full_name` within the crate for union branch keying; derive `Hash` for `PrimitiveType`; add helpers for branch de‑duplication. # Are these changes tested? Yes. New unit tests cover: - Resolution across writer/reader unions and non‑unions (direct vs. promoted matches, partial coverage). - Nullable‑union semantics (writer null ordering preserved). - Arrow `Union` to Avro union JSON including mode/type‑id metadata and branch shapes. - Validation errors for duplicates and nested unions. # Are there any user-facing changes? N/A --- arrow-avro/src/codec.rs | 514 ++++++++++++++++++++++++++++++++++++---- arrow-avro/src/reader/record.rs | 5 + arrow-avro/src/schema.rs | 187 +++++++++++++-- 3 files changed, 635 insertions(+), 71 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index cf0276f0a2..b3c8da2b5e 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -16,20 +16,21 @@ // under the License. 
use crate::schema::{ - Array, Attributes, AvroSchema, ComplexType, Enum, Fixed, Map, Nullability, PrimitiveType, - Record, Schema, Type, TypeName, AVRO_ENUM_SYMBOLS_METADATA_KEY, + make_full_name, Array, Attributes, AvroSchema, ComplexType, Enum, Fixed, Map, Nullability, + PrimitiveType, Record, Schema, Type, TypeName, AVRO_ENUM_SYMBOLS_METADATA_KEY, AVRO_FIELD_DEFAULT_METADATA_KEY, AVRO_ROOT_RECORD_DEFAULT_NAME, }; use arrow_schema::{ - ArrowError, DataType, Field, Fields, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, - DECIMAL256_MAX_PRECISION, + ArrowError, DataType, Field, Fields, IntervalUnit, TimeUnit, UnionFields, UnionMode, + DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; #[cfg(feature = "small_decimals")] use arrow_schema::{DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION}; use indexmap::IndexMap; use serde_json::Value; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; +use strum_macros::AsRefStr; /// Contains information about how to resolve differences between a writer's and a reader's schema. #[derive(Debug, Clone, PartialEq)] @@ -42,6 +43,8 @@ pub(crate) enum ResolutionInfo { EnumMapping(EnumMapping), /// Provides resolution information for record fields. Record(ResolvedRecord), + /// Provides mapping and shape info for resolving unions. + Union(ResolvedUnion), } /// Represents a literal Avro value. @@ -92,8 +95,10 @@ pub struct ResolvedRecord { /// /// Schema resolution may require promoting a writer's data type to a reader's data type. /// For example, an `int` can be promoted to a `long`, `float`, or `double`. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum Promotion { + /// Direct read with no data type promotion. + Direct, /// Promotes an `int` to a `long`. IntToLong, /// Promotes an `int` to a `float`. 
@@ -112,6 +117,18 @@ pub(crate) enum Promotion { BytesToString, } +/// Information required to resolve a writer union against a reader union (or single type). +#[derive(Debug, Clone, PartialEq)] +pub struct ResolvedUnion { + /// For each writer branch index, the reader branch index and how to read it. + /// `None` means the writer branch doesn't resolve against the reader. + pub(crate) writer_to_reader: Arc<[Option<(usize, Promotion)>]>, + /// Whether the writer schema at this site is a union + pub(crate) writer_is_union: bool, + /// Whether the reader schema at this site is a union + pub(crate) reader_is_union: bool, +} + /// Holds the mapping information for resolving Avro enums. /// /// When resolving schemas, the writer's enum symbols must be mapped to the reader's symbols. @@ -267,6 +284,11 @@ impl AvroDataType { if default_json.is_null() { return match self.codec() { Codec::Null => Ok(AvroLiteral::Null), + Codec::Union(encodings, _, _) if !encodings.is_empty() + && matches!(encodings[0].codec(), Codec::Null) => + { + Ok(AvroLiteral::Null) + } _ if self.nullability() == Some(Nullability::NullFirst) => Ok(AvroLiteral::Null), _ => Err(ArrowError::SchemaError( "JSON null default is only valid for `null` type or for a union whose first branch is `null`" @@ -401,6 +423,14 @@ impl AvroDataType { )) } }, + Codec::Union(encodings, _, _) => { + if encodings.is_empty() { + return Err(ArrowError::SchemaError( + "Union with no branches cannot have a default".to_string(), + )); + } + encodings[0].parse_default_literal(default_json)? 
+ } }; Ok(lit) } @@ -635,6 +665,8 @@ pub enum Codec { Map(Arc<AvroDataType>), /// Represents Avro duration logical type, maps to Arrow's Interval(IntervalUnit::MonthDayNano) data type Interval, + /// Represents Avro union type, maps to Arrow's Union data type + Union(Arc<[AvroDataType]>, UnionFields, UnionMode), } impl Codec { @@ -708,8 +740,42 @@ impl Codec { false, ) } + Self::Union(_, fields, mode) => DataType::Union(fields.clone(), *mode), + } + } + + /// Converts a string codec to use Utf8View if requested + /// + /// The conversion only happens if both: + /// 1. `use_utf8view` is true + /// 2. The codec is currently `Utf8` + /// + /// # Example + /// ``` + /// # use arrow_avro::codec::Codec; + /// let utf8_codec1 = Codec::Utf8; + /// let utf8_codec2 = Codec::Utf8; + /// + /// // Convert to Utf8View + /// let view_codec = utf8_codec1.with_utf8view(true); + /// assert!(matches!(view_codec, Codec::Utf8View)); + /// + /// // Don't convert if use_utf8view is false + /// let unchanged_codec = utf8_codec2.with_utf8view(false); + /// assert!(matches!(unchanged_codec, Codec::Utf8)); + /// ``` + pub fn with_utf8view(self, use_utf8view: bool) -> Self { + if use_utf8view && matches!(self, Self::Utf8) { + Self::Utf8View + } else { + self } } + + #[inline] + fn union_field_name(&self) -> String { + UnionFieldKind::from(self).as_ref().to_owned() + } } impl From<PrimitiveType> for Codec { @@ -804,36 +870,75 @@ fn parse_decimal_attributes( Ok((precision, scale, size)) } -impl Codec { - /// Converts a string codec to use Utf8View if requested - /// - /// The conversion only happens if both: - /// 1. `use_utf8view` is true - /// 2. 
The codec is currently `Utf8` - /// - /// # Example - /// ``` - /// # use arrow_avro::codec::Codec; - /// let utf8_codec1 = Codec::Utf8; - /// let utf8_codec2 = Codec::Utf8; - /// - /// // Convert to Utf8View - /// let view_codec = utf8_codec1.with_utf8view(true); - /// assert!(matches!(view_codec, Codec::Utf8View)); - /// - /// // Don't convert if use_utf8view is false - /// let unchanged_codec = utf8_codec2.with_utf8view(false); - /// assert!(matches!(unchanged_codec, Codec::Utf8)); - /// ``` - pub fn with_utf8view(self, use_utf8view: bool) -> Self { - if use_utf8view && matches!(self, Self::Utf8) { - Self::Utf8View - } else { - self +#[derive(Debug, Clone, Copy, PartialEq, Eq, AsRefStr)] +#[strum(serialize_all = "snake_case")] +enum UnionFieldKind { + Null, + Boolean, + Int, + Long, + Float, + Double, + Bytes, + String, + Date, + TimeMillis, + TimeMicros, + TimestampMillisUtc, + TimestampMillisLocal, + TimestampMicrosUtc, + TimestampMicrosLocal, + Duration, + Fixed, + Decimal, + Enum, + Array, + Record, + Map, + Uuid, + Union, +} + +impl From<&Codec> for UnionFieldKind { + fn from(c: &Codec) -> Self { + match c { + Codec::Null => Self::Null, + Codec::Boolean => Self::Boolean, + Codec::Int32 => Self::Int, + Codec::Int64 => Self::Long, + Codec::Float32 => Self::Float, + Codec::Float64 => Self::Double, + Codec::Binary => Self::Bytes, + Codec::Utf8 | Codec::Utf8View => Self::String, + Codec::Date32 => Self::Date, + Codec::TimeMillis => Self::TimeMillis, + Codec::TimeMicros => Self::TimeMicros, + Codec::TimestampMillis(true) => Self::TimestampMillisUtc, + Codec::TimestampMillis(false) => Self::TimestampMillisLocal, + Codec::TimestampMicros(true) => Self::TimestampMicrosUtc, + Codec::TimestampMicros(false) => Self::TimestampMicrosLocal, + Codec::Interval => Self::Duration, + Codec::Fixed(_) => Self::Fixed, + Codec::Decimal(..) 
=> Self::Decimal, + Codec::Enum(_) => Self::Enum, + Codec::List(_) => Self::Array, + Codec::Struct(_) => Self::Record, + Codec::Map(_) => Self::Map, + Codec::Uuid => Self::Uuid, + Codec::Union(..) => Self::Union, } } } +fn build_union_fields(encodings: &[AvroDataType]) -> UnionFields { + let arrow_fields: Vec<Field> = encodings + .iter() + .map(|encoding| encoding.field_with_name(&encoding.codec().union_field_name())) + .collect(); + let type_ids: Vec<i8> = (0..arrow_fields.len()).map(|i| i as i8).collect(); + UnionFields::new(type_ids, arrow_fields) +} + /// Resolves Avro type names to [`AvroDataType`] /// /// See <https://avro.apache.org/docs/1.11.1/specification/#names> @@ -915,6 +1020,76 @@ fn nullable_union_variants<'x, 'y>( } } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum UnionBranchKey { + Named(String), + Primitive(PrimitiveType), + Array, + Map, +} + +fn branch_key_of<'a>(s: &Schema<'a>, enclosing_ns: Option<&'a str>) -> Option<UnionBranchKey> { + match s { + // Primitives + Schema::TypeName(TypeName::Primitive(p)) => Some(UnionBranchKey::Primitive(*p)), + Schema::Type(Type { + r#type: TypeName::Primitive(p), + .. + }) => Some(UnionBranchKey::Primitive(*p)), + // Named references + Schema::TypeName(TypeName::Ref(name)) => { + let (full, _) = make_full_name(name, None, enclosing_ns); + Some(UnionBranchKey::Named(full)) + } + Schema::Type(Type { + r#type: TypeName::Ref(name), + .. 
+ }) => { + let (full, _) = make_full_name(name, None, enclosing_ns); + Some(UnionBranchKey::Named(full)) + } + // Complex non‑named + Schema::Complex(ComplexType::Array(_)) => Some(UnionBranchKey::Array), + Schema::Complex(ComplexType::Map(_)) => Some(UnionBranchKey::Map), + // Inline named definitions + Schema::Complex(ComplexType::Record(r)) => { + let (full, _) = make_full_name(r.name, r.namespace, enclosing_ns); + Some(UnionBranchKey::Named(full)) + } + Schema::Complex(ComplexType::Enum(e)) => { + let (full, _) = make_full_name(e.name, e.namespace, enclosing_ns); + Some(UnionBranchKey::Named(full)) + } + Schema::Complex(ComplexType::Fixed(f)) => { + let (full, _) = make_full_name(f.name, f.namespace, enclosing_ns); + Some(UnionBranchKey::Named(full)) + } + // Unions are validated separately (and disallowed as immediate branches) + Schema::Union(_) => None, + } +} + +fn union_first_duplicate<'a>( + branches: &'a [Schema<'a>], + enclosing_ns: Option<&'a str>, +) -> Option<String> { + let mut seen: HashSet<UnionBranchKey> = HashSet::with_capacity(branches.len()); + for b in branches { + if let Some(key) = branch_key_of(b, enclosing_ns) { + if !seen.insert(key.clone()) { + let msg = match key { + UnionBranchKey::Named(full) => format!("named type {full}"), + UnionBranchKey::Primitive(p) => format!("primitive {}", p.as_ref()), + UnionBranchKey::Array => "array".to_string(), + UnionBranchKey::Map => "map".to_string(), + }; + return Some(msg); + } + } + } + None +} + /// Resolves Avro type names to [`AvroDataType`] /// /// See <https://avro.apache.org/docs/1.11.1/specification/#names> @@ -969,7 +1144,6 @@ impl<'a> Maker<'a> { )), Schema::TypeName(TypeName::Ref(name)) => self.resolver.resolve(name, namespace), Schema::Union(f) => { - // Special case the common case of nullable primitives let null = f .iter() .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))); @@ -977,7 +1151,7 @@ impl<'a> Maker<'a> { (true, Some(0)) => { let mut field = 
self.parse_type(&f[1], namespace)?; field.nullability = Some(Nullability::NullFirst); - Ok(field) + return Ok(field); } (true, Some(1)) => { if self.strict_mode { @@ -988,12 +1162,34 @@ impl<'a> Maker<'a> { } let mut field = self.parse_type(&f[0], namespace)?; field.nullability = Some(Nullability::NullSecond); - Ok(field) + return Ok(field); } - _ => Err(ArrowError::NotYetImplemented(format!( - "Union of {f:?} not currently supported" - ))), + _ => {} + } + // Validate: unions may not immediately contain unions + if f.iter().any(|s| matches!(s, Schema::Union(_))) { + return Err(ArrowError::SchemaError( + "Avro unions may not immediately contain other unions".to_string(), + )); + } + // Validate: duplicates (named by full name; non-named by kind) + if let Some(dup) = union_first_duplicate(f, namespace) { + return Err(ArrowError::SchemaError(format!( + "Avro union contains duplicate branch type: {dup}" + ))); } + // Parse all branches + let children: Vec<AvroDataType> = f + .iter() + .map(|s| self.parse_type(s, namespace)) + .collect::<Result<_, _>>()?; + // Build Arrow layout once here + let union_fields = build_union_fields(&children); + Ok(AvroDataType::new( + Codec::Union(Arc::from(children), union_fields, UnionMode::Dense), + Default::default(), + None, + )) } Schema::Complex(c) => match c { ComplexType::Record(r) => { @@ -1149,6 +1345,67 @@ impl<'a> Maker<'a> { return self.resolve_primitives(write_primitive, read_primitive, reader_schema); } match (writer_schema, reader_schema) { + (Schema::Union(writer_variants), Schema::Union(reader_variants)) => { + match ( + nullable_union_variants(writer_variants.as_slice()), + nullable_union_variants(reader_variants.as_slice()), + ) { + (Some((w_nb, w_nonnull)), Some((_r_nb, r_nonnull))) => { + let mut dt = self.make_data_type(w_nonnull, Some(r_nonnull), namespace)?; + dt.nullability = Some(w_nb); + Ok(dt) + } + _ => self.resolve_unions( + writer_variants.as_slice(), + reader_variants.as_slice(), + namespace, + ), + } + } 
+ (Schema::Union(writer_variants), reader_non_union) => { + let mut writer_to_reader: Vec<Option<(usize, Promotion)>> = + Vec::with_capacity(writer_variants.len()); + for writer in writer_variants { + match self.resolve_type(writer, reader_non_union, namespace) { + Ok(tmp) => writer_to_reader.push(Some((0usize, Self::coercion_from(&tmp)))), + Err(_) => writer_to_reader.push(None), + } + } + let mut dt = self.parse_type(reader_non_union, namespace)?; + dt.resolution = Some(ResolutionInfo::Union(ResolvedUnion { + writer_to_reader: Arc::from(writer_to_reader), + writer_is_union: true, + reader_is_union: false, + })); + Ok(dt) + } + (writer_non_union, Schema::Union(reader_variants)) => { + let mut direct: Option<(usize, Promotion)> = None; + let mut promo: Option<(usize, Promotion)> = None; + for (reader_index, reader) in reader_variants.iter().enumerate() { + if let Ok(tmp) = self.resolve_type(writer_non_union, reader, namespace) { + let how = Self::coercion_from(&tmp); + if how == Promotion::Direct { + direct = Some((reader_index, how)); + break; // first exact match wins + } else if promo.is_none() { + promo = Some((reader_index, how)); + } + } + } + let (reader_index, promotion) = direct.or(promo).ok_or_else(|| { + ArrowError::SchemaError( + "Writer schema does not match any reader union branch".to_string(), + ) + })?; + let mut dt = self.parse_type(reader_schema, namespace)?; + dt.resolution = Some(ResolutionInfo::Union(ResolvedUnion { + writer_to_reader: Arc::from(vec![Some((reader_index, promotion))]), + writer_is_union: false, + reader_is_union: true, + })); + Ok(dt) + } ( Schema::Complex(ComplexType::Array(writer_array)), Schema::Complex(ComplexType::Array(reader_array)), @@ -1169,12 +1426,6 @@ impl<'a> Maker<'a> { Schema::Complex(ComplexType::Enum(writer_enum)), Schema::Complex(ComplexType::Enum(reader_enum)), ) => self.resolve_enums(writer_enum, reader_enum, reader_schema, namespace), - (Schema::Union(writer_variants), Schema::Union(reader_variants)) => self 
- .resolve_nullable_union( - writer_variants.as_slice(), - reader_variants.as_slice(), - namespace, - ), (Schema::TypeName(TypeName::Ref(_)), _) => self.parse_type(reader_schema, namespace), (_, Schema::TypeName(TypeName::Ref(_))) => self.parse_type(reader_schema, namespace), _ => Err(ArrowError::NotYetImplemented( @@ -1183,6 +1434,56 @@ impl<'a> Maker<'a> { } } + #[inline] + fn coercion_from(dt: &AvroDataType) -> Promotion { + match dt.resolution.as_ref() { + Some(ResolutionInfo::Promotion(promotion)) => *promotion, + _ => Promotion::Direct, + } + } + + fn resolve_unions<'s>( + &mut self, + writer_variants: &'s [Schema<'a>], + reader_variants: &'s [Schema<'a>], + namespace: Option<&'a str>, + ) -> Result<AvroDataType, ArrowError> { + let reader_encodings: Vec<AvroDataType> = reader_variants + .iter() + .map(|reader_schema| self.parse_type(reader_schema, namespace)) + .collect::<Result<_, _>>()?; + let mut writer_to_reader: Vec<Option<(usize, Promotion)>> = + Vec::with_capacity(writer_variants.len()); + for writer in writer_variants { + let mut direct: Option<(usize, Promotion)> = None; + let mut promo: Option<(usize, Promotion)> = None; + for (reader_index, reader) in reader_variants.iter().enumerate() { + if let Ok(tmp) = self.resolve_type(writer, reader, namespace) { + let promotion = Self::coercion_from(&tmp); + if promotion == Promotion::Direct { + direct = Some((reader_index, promotion)); + break; + } else if promo.is_none() { + promo = Some((reader_index, promotion)); + } + } + } + writer_to_reader.push(direct.or(promo)); + } + let union_fields = build_union_fields(&reader_encodings); + let mut dt = AvroDataType::new( + Codec::Union(reader_encodings.into(), union_fields, UnionMode::Dense), + Default::default(), + None, + ); + dt.resolution = Some(ResolutionInfo::Union(ResolvedUnion { + writer_to_reader: Arc::from(writer_to_reader), + writer_is_union: true, + reader_is_union: true, + })); + Ok(dt) + } + fn resolve_array( &mut self, writer_array: &Array<'a>, 
@@ -1281,10 +1582,9 @@ impl<'a> Maker<'a> { nullable_union_variants(writer_variants), nullable_union_variants(reader_variants), ) { - (Some((_, write_nonnull)), Some((read_nb, read_nonnull))) => { + (Some((write_nb, write_nonnull)), Some((_read_nb, read_nonnull))) => { let mut dt = self.make_data_type(write_nonnull, Some(read_nonnull), namespace)?; - // Adopt reader union null ordering - dt.nullability = Some(read_nb); + dt.nullability = Some(write_nb); Ok(dt) } _ => Err(ArrowError::NotYetImplemented( @@ -1557,6 +1857,24 @@ mod tests { .expect("promotion should resolve") } + fn mk_primitive(pt: PrimitiveType) -> Schema<'static> { + Schema::TypeName(TypeName::Primitive(pt)) + } + fn mk_union(branches: Vec<Schema<'static>>) -> Schema<'static> { + Schema::Union(branches) + } + + fn mk_record_named(name: &'static str) -> Schema<'static> { + Schema::Complex(ComplexType::Record(Record { + name, + namespace: None, + doc: None, + aliases: vec![], + fields: vec![], + attributes: Attributes::default(), + })) + } + #[test] fn test_date_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Int, "date"); @@ -1842,7 +2160,7 @@ mod tests { } #[test] - fn test_promotion_within_nullable_union_keeps_reader_null_ordering() { + fn test_promotion_within_nullable_union_keeps_writer_null_ordering() { let writer = Schema::Union(vec![ Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), @@ -1858,7 +2176,105 @@ mod tests { result.resolution, Some(ResolutionInfo::Promotion(Promotion::IntToDouble)) ); - assert_eq!(result.nullability, Some(Nullability::NullSecond)); + assert_eq!(result.nullability, Some(Nullability::NullFirst)); + } + + #[test] + fn test_resolve_writer_union_to_reader_non_union_partial_coverage() { + let writer = mk_union(vec![ + mk_primitive(PrimitiveType::String), + mk_primitive(PrimitiveType::Long), + ]); + let reader = mk_primitive(PrimitiveType::Bytes); + let mut maker = 
Maker::new(false, false); + let dt = maker.make_data_type(&writer, Some(&reader), None).unwrap(); + assert!(matches!(dt.codec(), Codec::Binary)); + let resolved = match dt.resolution { + Some(ResolutionInfo::Union(u)) => u, + other => panic!("expected union resolution info, got {other:?}"), + }; + assert!(resolved.writer_is_union && !resolved.reader_is_union); + assert_eq!( + resolved.writer_to_reader.as_ref(), + &[Some((0, Promotion::StringToBytes)), None] + ); + } + + #[test] + fn test_resolve_writer_non_union_to_reader_union_prefers_direct_over_promotion() { + let writer = mk_primitive(PrimitiveType::Long); + let reader = mk_union(vec![ + mk_primitive(PrimitiveType::Long), + mk_primitive(PrimitiveType::Double), + ]); + let mut maker = Maker::new(false, false); + let dt = maker.make_data_type(&writer, Some(&reader), None).unwrap(); + let resolved = match dt.resolution { + Some(ResolutionInfo::Union(u)) => u, + other => panic!("expected union resolution info, got {other:?}"), + }; + assert!(!resolved.writer_is_union && resolved.reader_is_union); + assert_eq!( + resolved.writer_to_reader.as_ref(), + &[Some((0, Promotion::Direct))] + ); + } + + #[test] + fn test_resolve_writer_non_union_to_reader_union_uses_promotion_when_needed() { + let writer = mk_primitive(PrimitiveType::Int); + let reader = mk_union(vec![ + mk_primitive(PrimitiveType::Null), + mk_primitive(PrimitiveType::Long), + mk_primitive(PrimitiveType::String), + ]); + let mut maker = Maker::new(false, false); + let dt = maker.make_data_type(&writer, Some(&reader), None).unwrap(); + let resolved = match dt.resolution { + Some(ResolutionInfo::Union(u)) => u, + other => panic!("expected union resolution info, got {other:?}"), + }; + assert_eq!( + resolved.writer_to_reader.as_ref(), + &[Some((1, Promotion::IntToLong))] + ); + } + + #[test] + fn test_resolve_both_nullable_unions_direct_match() { + let writer = mk_union(vec![ + mk_primitive(PrimitiveType::Null), + mk_primitive(PrimitiveType::String), + ]); + 
let reader = mk_union(vec![ + mk_primitive(PrimitiveType::String), + mk_primitive(PrimitiveType::Null), + ]); + let mut maker = Maker::new(false, false); + let dt = maker.make_data_type(&writer, Some(&reader), None).unwrap(); + assert!(matches!(dt.codec(), Codec::Utf8)); + assert_eq!(dt.nullability, Some(Nullability::NullFirst)); + assert!(dt.resolution.is_none()); + } + + #[test] + fn test_resolve_both_nullable_unions_with_promotion() { + let writer = mk_union(vec![ + mk_primitive(PrimitiveType::Null), + mk_primitive(PrimitiveType::Int), + ]); + let reader = mk_union(vec![ + mk_primitive(PrimitiveType::Double), + mk_primitive(PrimitiveType::Null), + ]); + let mut maker = Maker::new(false, false); + let dt = maker.make_data_type(&writer, Some(&reader), None).unwrap(); + assert!(matches!(dt.codec(), Codec::Float64)); + assert_eq!(dt.nullability, Some(Nullability::NullFirst)); + assert_eq!( + dt.resolution, + Some(ResolutionInfo::Promotion(Promotion::IntToDouble)) + ); } #[test] diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 9ca8acb45b..80a3c19d5c 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -426,6 +426,11 @@ impl Decoder { ) } (Codec::Uuid, _) => Self::Uuid(Vec::with_capacity(DEFAULT_CAPACITY)), + (&Codec::Union(_, _, _), _) => { + return Err(ArrowError::NotYetImplemented( + "Union type decoding is not yet supported".to_string(), + )) + } }; Ok(match data_type.nullability() { Some(nullability) => Self::Nullable( diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs index 511ba280f7..6c501a56ab 100644 --- a/arrow-avro/src/schema.rs +++ b/arrow-avro/src/schema.rs @@ -17,6 +17,7 @@ use arrow_schema::{ ArrowError, DataType, Field as ArrowField, IntervalUnit, Schema as ArrowSchema, TimeUnit, + UnionMode, }; use serde::{Deserialize, Serialize}; use serde_json::{json, Map as JsonMap, Value}; @@ -94,7 +95,7 @@ pub enum TypeName<'a> { /// A primitive type /// /// 
<https://avro.apache.org/docs/1.11.1/specification/#primitive-types> -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, AsRefStr)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, AsRefStr)] #[serde(rename_all = "camelCase")] #[strum(serialize_all = "lowercase")] pub enum PrimitiveType { @@ -718,7 +719,7 @@ fn quote(s: &str) -> Result<String, ArrowError> { // handling both ways of specifying the name. It prioritizes a namespace // defined within the `name` attribute itself, then the explicit `namespace_attr`, // and finally the `enclosing_ns`. -fn make_full_name( +pub(crate) fn make_full_name( name: &str, namespace_attr: Option<&str>, enclosing_ns: Option<&str>, @@ -955,6 +956,8 @@ fn merge_extras(schema: Value, mut extras: JsonMap<String, Value>) -> Value { Value::Object(map) } Value::Array(mut union) => { + // For unions, we cannot attach attributes to the array itself (per Avro spec). + // As a fallback for extension metadata, attach extras to the first non-null branch object. 
if let Some(non_null) = union.iter_mut().find(|val| val.as_str() != Some("null")) { let original = std::mem::take(non_null); *non_null = merge_extras(original, extras); @@ -970,13 +973,59 @@ fn merge_extras(schema: Value, mut extras: JsonMap<String, Value>) -> Value { } } +#[inline] +fn is_avro_json_null(v: &Value) -> bool { + matches!(v, Value::String(s) if s == "null") +} + fn wrap_nullable(inner: Value, null_order: Nullability) -> Value { let null = Value::String("null".into()); - let elements = match null_order { - Nullability::NullFirst => vec![null, inner], - Nullability::NullSecond => vec![inner, null], - }; - Value::Array(elements) + match inner { + Value::Array(mut union) => { + union.retain(|v| !is_avro_json_null(v)); + match null_order { + Nullability::NullFirst => { + let mut out = Vec::with_capacity(union.len() + 1); + out.push(null); + out.extend(union); + Value::Array(out) + } + Nullability::NullSecond => { + union.push(null); + Value::Array(union) + } + } + } + other => match null_order { + Nullability::NullFirst => Value::Array(vec![null, other]), + Nullability::NullSecond => Value::Array(vec![other, null]), + }, + } +} + +fn union_branch_signature(branch: &Value) -> Result<String, ArrowError> { + match branch { + Value::String(t) => Ok(format!("P:{t}")), + Value::Object(map) => { + let t = map.get("type").and_then(|v| v.as_str()).ok_or_else(|| { + ArrowError::SchemaError("Union branch object missing string 'type'".into()) + })?; + match t { + "record" | "enum" | "fixed" => { + let name = map.get("name").and_then(|v| v.as_str()).unwrap_or_default(); + Ok(format!("N:{t}:{name}")) + } + "array" | "map" => Ok(format!("C:{t}")), + other => Ok(format!("P:{other}")), + } + } + Value::Array(_) => Err(ArrowError::SchemaError( + "Avro union may not immediately contain another union".into(), + )), + _ => Err(ArrowError::SchemaError( + "Invalid JSON for Avro union branch".into(), + )), + } } fn datatype_to_avro( @@ -1028,6 +1077,10 @@ fn datatype_to_avro( 
DataType::Float64 => Value::String("double".into()), DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Value::String("string".into()), DataType::Binary | DataType::LargeBinary => Value::String("bytes".into()), + DataType::BinaryView => { + extras.insert("arrowBinaryView".into(), Value::Bool(true)); + Value::String("bytes".into()) + } DataType::FixedSizeBinary(len) => { let is_uuid = metadata .get("logicalType") @@ -1129,6 +1182,24 @@ fn datatype_to_avro( "items": items_schema }) } + DataType::ListView(child) | DataType::LargeListView(child) => { + if matches!(dt, DataType::LargeListView(_)) { + extras.insert("arrowLargeList".into(), Value::Bool(true)); + } + extras.insert("arrowListView".into(), Value::Bool(true)); + let items_schema = process_datatype( + child.data_type(), + child.name(), + child.metadata(), + name_gen, + null_order, + child.is_nullable(), + )?; + json!({ + "type": "array", + "items": items_schema + }) + } DataType::FixedSizeList(child, len) => { extras.insert("arrowFixedSize".into(), json!(len)); let items_schema = process_datatype( @@ -1205,10 +1276,52 @@ fn datatype_to_avro( null_order, false, )?, - DataType::Union(_, _) => { - return Err(ArrowError::NotYetImplemented( - "Arrow Union to Avro Union not yet supported".into(), - )) + DataType::Union(fields, mode) => { + let mut branches: Vec<Value> = Vec::with_capacity(fields.len()); + let mut type_ids: Vec<i32> = Vec::with_capacity(fields.len()); + for (type_id, field_ref) in fields.iter() { + // NOTE: `process_datatype` would wrap nullability; force is_nullable=false here. 
+ let (branch_schema, _branch_extras) = datatype_to_avro( + field_ref.data_type(), + field_ref.name(), + field_ref.metadata(), + name_gen, + null_order, + )?; + // Avro unions cannot immediately contain another union + if matches!(branch_schema, Value::Array(_)) { + return Err(ArrowError::SchemaError( + "Avro union may not immediately contain another union".into(), + )); + } + branches.push(branch_schema); + type_ids.push(type_id as i32); + } + let mut seen: HashSet<String> = HashSet::with_capacity(branches.len()); + for b in &branches { + let sig = union_branch_signature(b)?; + if !seen.insert(sig) { + return Err(ArrowError::SchemaError( + "Avro union contains duplicate branch types (disallowed by spec)".into(), + )); + } + } + extras.insert( + "arrowUnionMode".into(), + Value::String( + match mode { + UnionMode::Sparse => "sparse", + UnionMode::Dense => "dense", + } + .to_string(), + ), + ); + extras.insert( + "arrowUnionTypeIds".into(), + Value::Array(type_ids.into_iter().map(|id| json!(id)).collect()), + ); + + Value::Array(branches) } other => { return Err(ArrowError::NotYetImplemented(format!( @@ -1281,7 +1394,7 @@ fn arrow_field_to_avro( mod tests { use super::*; use crate::codec::{AvroDataType, AvroField}; - use arrow_schema::{DataType, Fields, SchemaBuilder, TimeUnit}; + use arrow_schema::{DataType, Fields, SchemaBuilder, TimeUnit, UnionFields}; use serde_json::json; use std::sync::Arc; @@ -1988,17 +2101,47 @@ mod tests { } #[test] - fn test_dense_union_error() { - use arrow_schema::UnionFields; - let uf: UnionFields = vec![(0i8, Arc::new(ArrowField::new("a", DataType::Int32, false)))] - .into_iter() - .collect(); - let union_dt = DataType::Union(uf, arrow_schema::UnionMode::Dense); + fn test_dense_union() { + let uf: UnionFields = vec![ + (2i8, Arc::new(ArrowField::new("a", DataType::Int32, false))), + (7i8, Arc::new(ArrowField::new("b", DataType::Utf8, true))), + ] + .into_iter() + .collect(); + let union_dt = DataType::Union(uf, UnionMode::Dense); let s 
= single_field_schema(ArrowField::new("u", union_dt, false)); - let err = AvroSchema::try_from(&s).unwrap_err(); - assert!(err - .to_string() - .contains("Arrow Union to Avro Union not yet supported")); + let avro = + AvroSchema::try_from(&s).expect("Arrow Union -> Avro union conversion should succeed"); + let v: serde_json::Value = serde_json::from_str(&avro.json_string).unwrap(); + let fields = v + .get("fields") + .and_then(|x| x.as_array()) + .expect("fields array"); + let u_field = fields + .iter() + .find(|f| f.get("name").and_then(|n| n.as_str()) == Some("u")) + .expect("field 'u'"); + let union = u_field.get("type").expect("u.type"); + let arr = union.as_array().expect("u.type must be Avro union array"); + assert_eq!(arr.len(), 2, "expected two union branches"); + let first = &arr[0]; + let obj = first + .as_object() + .expect("first branch should be an object with metadata"); + assert_eq!(obj.get("type").and_then(|t| t.as_str()), Some("int")); + assert_eq!( + obj.get("arrowUnionMode").and_then(|m| m.as_str()), + Some("dense") + ); + let type_ids: Vec<i64> = obj + .get("arrowUnionTypeIds") + .and_then(|a| a.as_array()) + .expect("arrowUnionTypeIds array") + .iter() + .map(|n| n.as_i64().expect("i64")) + .collect(); + assert_eq!(type_ids, vec![2, 7], "type id ordering should be preserved"); + assert_eq!(arr[1], Value::String("string".into())); } #[test]