This is an automated email from the ASF dual-hosted git repository. kriskras99 pushed a commit to branch feat/union_builder in repository https://gitbox.apache.org/repos/asf/avro-rs.git
commit d4caf2e83f28b011bcbb896adecdb91bf5b23b0a Author: Kriskras99 <[email protected]> AuthorDate: Wed Feb 25 23:01:45 2026 +0100 feat: Add a `UnionSchemaBuilder` This also fixes an issue with the original `new` implementation where it would insert named types in the `variant_index` and then `find_schema_with_known_schemata` would use the fast path without checking the schema. `find_schema_with_known_schemata` has also been simplified to use `known_schemata` directly instead of rebuilding it with the current schema, as this would cause duplicate schema errors after the incorrect fast path was removed. The `UnionSchemaBuilder::variant_ignore_duplicates` and `UnionSchemaBuilder::contains` methods are needed for `avro_derive` to implement full support for enums. --- avro/src/error.rs | 14 ++- avro/src/schema/name.rs | 2 +- avro/src/schema/union.rs | 320 ++++++++++++++++++++++++++++++++++++++--------- 3 files changed, 276 insertions(+), 60 deletions(-) diff --git a/avro/src/error.rs b/avro/src/error.rs index 50a09af..12bee1e 100644 --- a/avro/src/error.rs +++ b/avro/src/error.rs @@ -300,8 +300,18 @@ pub enum Details { #[error("Unions may not directly contain a union")] GetNestedUnion, - #[error("Unions cannot contain duplicate types")] - GetUnionDuplicate, + #[error( + "Found two different maps while building Union: Schema::Map({0:?}), Schema::Map({1:?})" + )] + GetUnionDuplicateMap(Schema, Schema), + + #[error( + "Found two different arrays while building Union: Schema::Array({0:?}), Schema::Array({1:?})" + )] + GetUnionDuplicateArray(Schema, Schema), + + #[error("Unions cannot contain duplicate types, found at least two {0:?}")] + GetUnionDuplicate(SchemaKind), #[error("Unions cannot contain more than one named schema with the same name: {0}")] GetUnionDuplicateNamedSchemas(String), diff --git a/avro/src/schema/name.rs b/avro/src/schema/name.rs index e572d8b..1eeac0d 100644 --- a/avro/src/schema/name.rs +++ b/avro/src/schema/name.rs @@ -38,7 +38,7 @@ use crate::{ /// /// More 
information about schema names can be found in the /// [Avro specification](https://avro.apache.org/docs/++version++/specification/#names) -#[derive(Clone, Debug, Hash, PartialEq, Eq)] +#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct Name { pub name: String, pub namespace: Namespace, diff --git a/avro/src/schema/union.rs b/avro/src/schema/union.rs index 7510a13..428f09d 100644 --- a/avro/src/schema/union.rs +++ b/avro/src/schema/union.rs @@ -15,24 +15,32 @@ // specific language governing permissions and limitations // under the License. -use crate::AvroResult; use crate::error::Details; -use crate::schema::{Name, Namespace, ResolvedSchema, Schema, SchemaKind}; +use crate::schema::{ + DecimalSchema, InnerDecimalSchema, Name, Namespace, Schema, SchemaKind, UuidSchema, +}; use crate::types; +use crate::{AvroResult, Error}; use std::borrow::Borrow; -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::{BTreeMap, HashMap}; use std::fmt::Debug; +use strum::IntoDiscriminant; /// A description of a Union schema #[derive(Debug, Clone)] pub struct UnionSchema { /// The schemas that make up this union pub(crate) schemas: Vec<Schema>, - // Used to ensure uniqueness of schema inputs, and provide constant time finding of the - // schema index given a value. - // **NOTE** that this approach does not work for named types, and will have to be modified - // to support that. A simple solution is to also keep a mapping of the names used. + /// The indexes of unnamed types. + /// + /// Logical types are keyed by their inner type (e.g. `date` is stored as `Int`). + /// Used to find the schema index of an unnamed type without scanning `schemas`. + /// Must only contain unnamed types. variant_index: BTreeMap<SchemaKind, usize>, + /// The indexes of named types. + /// + /// The names themselves aren't stored, as lookups only need the indexes into `schemas`.
+ named_index: Vec<usize>, } impl UnionSchema { @@ -42,25 +50,16 @@ impl UnionSchema { /// Will return an error if `schemas` has duplicate unnamed schemas or if `schemas` /// contains a union. pub fn new(schemas: Vec<Schema>) -> AvroResult<Self> { - let mut named_schemas: HashSet<&Name> = HashSet::default(); - let mut vindex = BTreeMap::new(); - for (i, schema) in schemas.iter().enumerate() { - if let Schema::Union(_) = schema { - return Err(Details::GetNestedUnion.into()); - } else if !schema.is_named() && vindex.insert(SchemaKind::from(schema), i).is_some() { - return Err(Details::GetUnionDuplicate.into()); - } else if schema.is_named() { - let name = schema.name().unwrap(); - if !named_schemas.insert(name) { - return Err(Details::GetUnionDuplicateNamedSchemas(name.to_string()).into()); - } - vindex.insert(SchemaKind::from(schema), i); - } + let mut builder = Self::builder(); + for schema in schemas { + builder.variant(schema)?; } - Ok(UnionSchema { - schemas, - variant_index: vindex, - }) + Ok(builder.build()) + } + + /// Create a [`UnionSchemaBuilder`] to build a `UnionSchema` piece-by-piece. + pub fn builder() -> UnionSchemaBuilder { + UnionSchemaBuilder::new() } /// Returns a slice to all variants of this schema. @@ -70,7 +69,7 @@ impl UnionSchema { /// Returns true if the any of the variants of this `UnionSchema` is `Null`.
pub fn is_nullable(&self) -> bool { - self.schemas.iter().any(|x| matches!(x, Schema::Null)) + self.variant_index.contains_key(&SchemaKind::Null) } /// Optionally returns a reference to the schema matched by this value, as well as its position @@ -86,39 +85,31 @@ impl UnionSchema { ) -> Option<(usize, &Schema)> { let schema_kind = SchemaKind::from(value); if let Some(&i) = self.variant_index.get(&schema_kind) { - // fast path + // fast path for unnamed types Some((i, &self.schemas[i])) } else { - // slow path (required for matching logical or named types) - - // first collect what schemas we already know - let mut collected_names: HashMap<Name, &Schema> = known_schemata - .map(|names| { - names - .iter() - .map(|(name, schema)| (name.clone(), schema.borrow())) - .collect() + // slow path required for named types and for logical types, as the kind of a logical + // value (e.g. `Date`) differs from the inner kind (`Int`) used as key in `variant_index` + let known_schemata_if_none = HashMap::new(); + let known_schemata = known_schemata.unwrap_or(&known_schemata_if_none); + + self.named_index + .iter() + .copied() + // prefer named variants of the same kind as the value + .filter(|&i| self.schemas[i].discriminant() == schema_kind) + // then fall back to every variant, in declaration order + .chain(0..self.schemas.len()) + .map(|i| (i, &self.schemas[i])) + .find(|(_i, schema)| { + let namespace = if schema.namespace().is_some() { + &schema.namespace() + } else { + enclosing_namespace + }; + + // TODO: Do this without the clone + value + .clone() + .resolve_internal(schema, known_schemata, namespace, &None) + .is_ok() }) - .unwrap_or_default(); - - self.schemas.iter().enumerate().find(|(_, schema)| { - let resolved_schema = ResolvedSchema::new_with_known_schemata( - vec![*schema], - enclosing_namespace, - &collected_names, - ) - .expect("Schema didn't successfully parse"); - let resolved_names = resolved_schema.names_ref; - - // extend known schemas with just resolved names - collected_names.extend(resolved_names); - let namespace = &schema.namespace().or_else(|| enclosing_namespace.clone()); - - value - .clone() - .resolve_internal(schema, &collected_names, namespace, &None) - .is_ok() - }) } } } @@ -130,11 +121,172 @@ impl PartialEq for UnionSchema { } } +/// Builds a [`UnionSchema`] variant by variant, see [`UnionSchema::builder`]. +#[derive(Debug, Default)] +pub
struct UnionSchemaBuilder { + schemas: Vec<Schema>, + names: BTreeMap<Name, usize>, + variant_index: BTreeMap<SchemaKind, usize>, +} + +impl UnionSchemaBuilder { + /// Create a builder. + /// + /// See also [`UnionSchema::builder`]. + pub fn new() -> Self { + Self { + schemas: Vec::new(), + names: BTreeMap::new(), + variant_index: BTreeMap::new(), + } + } + + /// Add a variant to this union, if an identical variant already exists ignore it. + /// + /// # Errors + /// Will return a [`Details::GetUnionDuplicateNamedSchemas`] if a different schema with the + /// same name was already added, a [`Details::GetUnionDuplicateMap`] or + /// [`Details::GetUnionDuplicateArray`] if duplicate maps or arrays are encountered with + /// different subtypes, a [`Details::GetUnionDuplicate`] if a different schema with the same + /// inner kind was already added (e.g. `int` and `date`), or a [`Details::GetNestedUnion`] if + /// `schema` is a union. The builder is not modified when an error is returned. + pub fn variant_ignore_duplicates(&mut self, schema: Schema) -> Result<&mut Self, Error> { + if let Some(name) = schema.name() { + if let Some(current) = self.names.get(name).copied() { + if self.schemas[current] != schema { + return Err(Details::GetUnionDuplicateNamedSchemas(name.to_string()).into()); + } + } else { + self.names.insert(name.clone(), self.schemas.len()); + self.schemas.push(schema); + } + } else { + let discriminant = Self::schema_kind_without_logical_type(&schema); + if discriminant == SchemaKind::Union { + return Err(Details::GetNestedUnion.into()); + } + if let Some(index) = self.variant_index.get(&discriminant).copied() { + if self.schemas[index] != schema { + // Clone the existing variant instead of removing it, removing would shift the + // indexes stored in `names` and `variant_index` and corrupt the builder. + let details = match discriminant { + SchemaKind::Map => { + Details::GetUnionDuplicateMap(self.schemas[index].clone(), schema) + } + SchemaKind::Array => { + Details::GetUnionDuplicateArray(self.schemas[index].clone(), schema) + } + _ => Details::GetUnionDuplicate(discriminant), + }; + return Err(details.into()); + } + } else {
+ self.variant_index.insert(discriminant, self.schemas.len()); + self.schemas.push(schema); + } + } + Ok(self) + } + + /// Add a variant to this union. + /// + /// # Errors + /// Will return a [`Details::GetUnionDuplicateNamedSchemas`] or [`Details::GetUnionDuplicate`] if + /// duplicate names or schema kinds are found, or a [`Details::GetNestedUnion`] if `schema` is + /// a union. + pub fn variant(&mut self, schema: Schema) -> Result<&mut Self, Error> { + if let Some(name) = schema.name() { + if self.names.contains_key(name) { + return Err(Details::GetUnionDuplicateNamedSchemas(name.to_string()).into()); + } else { + self.names.insert(name.clone(), self.schemas.len()); + self.schemas.push(schema); + } + } else { + let discriminant = Self::schema_kind_without_logical_type(&schema); + if discriminant == SchemaKind::Union { + return Err(Details::GetNestedUnion.into()); + } + if self.variant_index.contains_key(&discriminant) { + return Err(Details::GetUnionDuplicate(discriminant).into()); + } else { + self.variant_index.insert(discriminant, self.schemas.len()); + self.schemas.push(schema); + } + } + Ok(self) + } + + /// Check if a schema already exists in this union. + pub fn contains(&self, schema: &Schema) -> bool { + if let Some(name) = schema.name() { + if let Some(current) = self.names.get(name).copied() { + &self.schemas[current] == schema + } else { + false + } + } else { + let discriminant = Self::schema_kind_without_logical_type(schema); + if let Some(index) = self.variant_index.get(&discriminant).copied() { + &self.schemas[index] == schema + } else { + false + } + } + } + + /// Create the `UnionSchema`. + pub fn build(mut self) -> UnionSchema { + self.schemas.shrink_to_fit(); + // `names` is ordered by name, but named variants must be tried in declaration order + let mut named_index: Vec<usize> = self.names.into_values().collect(); + named_index.sort_unstable(); + UnionSchema { + variant_index: self.variant_index, + named_index, + schemas: self.schemas, + } + } + + /// Get the [`SchemaKind`] of a [`Schema`] converting logical types to their inner type. +
fn schema_kind_without_logical_type(schema: &Schema) -> SchemaKind { + let kind = schema.discriminant(); + match kind { + SchemaKind::Date | SchemaKind::TimeMillis => SchemaKind::Int, + SchemaKind::TimeMicros + | SchemaKind::TimestampMillis + | SchemaKind::TimestampMicros + | SchemaKind::TimestampNanos + | SchemaKind::LocalTimestampMillis + | SchemaKind::LocalTimestampMicros + | SchemaKind::LocalTimestampNanos => SchemaKind::Long, + SchemaKind::Uuid => match schema { + Schema::Uuid(UuidSchema::Bytes) => SchemaKind::Bytes, + Schema::Uuid(UuidSchema::String) => SchemaKind::String, + Schema::Uuid(UuidSchema::Fixed(_)) => SchemaKind::Fixed, + _ => unreachable!("SchemaKind::Uuid is only returned for Schema::Uuid"), + }, + SchemaKind::Decimal => match schema { + Schema::Decimal(DecimalSchema { + inner: InnerDecimalSchema::Bytes, + .. + }) => SchemaKind::Bytes, + Schema::Decimal(DecimalSchema { + inner: InnerDecimalSchema::Fixed(_), + .. + }) => SchemaKind::Fixed, + _ => unreachable!("SchemaKind::Decimal is only returned for Schema::Decimal"), + }, + SchemaKind::BigDecimal => SchemaKind::Bytes, + SchemaKind::Duration => SchemaKind::Fixed, + _ => kind, + } + } +} + #[cfg(test)] mod tests { use super::*; use crate::error::{Details, Error}; - use crate::schema::RecordSchema; + use crate::schema::{EnumSchema, FixedSchema, RecordSchema}; use apache_avro_test_helper::TestResult; #[test] @@ -165,4 +317,58 @@ mod tests { Ok(()) } + + #[test] + fn avro_rs_xxx_union_schema_builder() -> TestResult { + let mut builder = UnionSchema::builder(); + builder.variant(Schema::Null)?; + assert!(builder.variant(Schema::Null).is_err()); + builder.variant_ignore_duplicates(Schema::Null)?; + + let enum_schema = Schema::Enum(EnumSchema { + name: Name::new("ABC")?, + aliases: None, + doc: None, + symbols: vec!["A".into(), "B".into(), "C".into()], + default: None, + attributes: Default::default(), + }); + let enum_schema2 = Schema::Enum(EnumSchema { + name: Name::new("ABC")?, + aliases: None, + doc: None, + symbols: vec!["A".into(), "B".into(), "C".into(), "D".into()], + default: None, + attributes: Default::default(), + }); + let fixed_schema = 
Schema::Fixed(FixedSchema { + name: Name::new("ABC")?, + aliases: None, + doc: None, + size: 1, + attributes: Default::default(), + }); + builder.variant(enum_schema.clone())?; + assert!(builder.variant(enum_schema.clone()).is_err()); + builder.variant_ignore_duplicates(enum_schema.clone())?; + // Name is the same but different schemas, so should always fail + assert!(builder.variant(fixed_schema.clone()).is_err()); + assert!( + builder + .variant_ignore_duplicates(fixed_schema.clone()) + .is_err() + ); + // Name and schema type are the same but symbols are different + assert!(builder.variant(enum_schema2.clone()).is_err()); + assert!( + builder + .variant_ignore_duplicates(enum_schema2.clone()) + .is_err() + ); + + // Logical types conflict with their inner type + builder.variant(Schema::Int)?; + assert!(builder.variant(Schema::Date).is_err()); + assert!(builder.variant_ignore_duplicates(Schema::Date).is_err()); + builder.variant(Schema::Bytes)?; + assert!(builder.variant(Schema::BigDecimal).is_err()); + + // Arrays and maps with different subtypes conflict, and the builder is left intact + builder.variant(Schema::array(Schema::Int))?; + assert!( + builder + .variant_ignore_duplicates(Schema::array(Schema::Long)) + .is_err() + ); + assert!(builder.contains(&Schema::array(Schema::Int))); + builder.variant(Schema::map(Schema::Int))?; + assert!( + builder + .variant_ignore_duplicates(Schema::map(Schema::Long)) + .is_err() + ); + assert!(builder.contains(&Schema::map(Schema::Int))); + + let union = builder.build(); + assert_eq!( + union.variants(), + &[ + Schema::Null, + enum_schema, + Schema::Int, + Schema::Bytes, + Schema::array(Schema::Int), + Schema::map(Schema::Int), + ] + ); + + Ok(()) + } }
