This is an automated email from the ASF dual-hosted git repository. kriskras99 pushed a commit to branch feat/enums in repository https://gitbox.apache.org/repos/asf/avro-rs.git
commit 9211ad411489894efea9b766313ffbe5e24217d5 Author: Kriskras99 <[email protected]> AuthorDate: Sun Feb 22 22:02:55 2026 +0100 wip: Full enum support --- avro/src/error.rs | 14 +- avro/src/schema/mod.rs | 4 + avro/src/schema/name.rs | 8 +- avro/src/schema/union.rs | 217 +++++++++- avro/src/serde/derive.rs | 20 + avro/src/serde/mod.rs | 1 + .../src/serde/{ser_schema.rs => ser_schema/mod.rs} | 260 +++++------- avro/src/serde/ser_schema/tuples.rs | 243 ++++++++++++ avro/src/serde/ser_schema2/mod.rs | 441 +++++++++++++++++++++ avro/src/serde/ser_schema2/union.rs | 355 +++++++++++++++++ avro/src/util.rs | 5 +- avro_derive/src/attributes/avro.rs | 16 + avro_derive/src/attributes/mod.rs | 123 ++++-- avro_derive/src/enums/discriminator_value.rs | 130 ++++++ avro_derive/src/enums/mod.rs | 66 +++ avro_derive/src/enums/plain.rs | 49 +++ avro_derive/src/enums/union.rs | 14 + avro_derive/src/enums/union_of_records.rs | 14 + avro_derive/src/lib.rs | 87 +--- avro_derive/tests/enum.rs | 134 +++++++ 20 files changed, 1922 insertions(+), 279 deletions(-) diff --git a/avro/src/error.rs b/avro/src/error.rs index 6b951c1..eda9483 100644 --- a/avro/src/error.rs +++ b/avro/src/error.rs @@ -300,8 +300,18 @@ pub enum Details { #[error("Unions may not directly contain a union")] GetNestedUnion, - #[error("Unions cannot contain duplicate types")] - GetUnionDuplicate, + #[error( + "Found two different maps while building Union: Schema::Map({0:?}), Schema::Map({1:?})" + )] + GetUnionDuplicateMap(Schema, Schema), + + #[error( + "Found two different arrays while building Union: Schema::Array({0:?}), Schema::Array({1:?})" + )] + GetUnionDuplicateArray(Schema, Schema), + + #[error("Unions cannot contain duplicate types, found at least two {0:?}")] + GetUnionDuplicate(SchemaKind), #[error("Unions cannot contain more than one named schema with the same name: {0}")] GetUnionDuplicateNamedSchemas(String), diff --git a/avro/src/schema/mod.rs b/avro/src/schema/mod.rs index a5166d8..029292e 100644 --- a/avro/src/schema/mod.rs +++ b/avro/src/schema/mod.rs @@ -50,6 +50,10 @@ pub(crate) use crate::schema::resolve::{ ResolvedOwnedSchema, resolve_names, resolve_names_with_schemata, }; pub use crate::schema::{ + builders::{ + SchemaArrayBuilder, SchemaEnumBuilder, SchemaFixedBuilder, SchemaMapBuilder, + SchemaRecordBuilder, + }, name::{Alias, Aliases, Name, Names, NamesRef, Namespace}, record::{ RecordField, RecordFieldBuilder, RecordFieldOrder, RecordSchema, RecordSchemaBuilder, diff --git a/avro/src/schema/name.rs b/avro/src/schema/name.rs index e572d8b..b551584 100644 --- a/avro/src/schema/name.rs +++ b/avro/src/schema/name.rs @@ -202,12 +202,12 @@ impl Alias { Name::new(name).map(Self) } - pub fn name(&self) -> String { - self.0.name.clone() + pub fn name(&self) -> &str { + &self.0.name } - pub fn namespace(&self) -> Namespace { - self.0.namespace.clone() + pub fn namespace(&self) -> &Namespace { + &self.0.namespace } pub fn fullname(&self, default_namespace: Namespace) -> String { diff --git a/avro/src/schema/union.rs b/avro/src/schema/union.rs index 7510a13..8e2f085 100644 --- a/avro/src/schema/union.rs +++ b/avro/src/schema/union.rs @@ -15,13 +15,17 @@ // specific language governing permissions and limitations // under the License. -use crate::AvroResult; use crate::error::Details; -use crate::schema::{Name, Namespace, ResolvedSchema, Schema, SchemaKind}; +use crate::schema::{ + DecimalSchema, InnerDecimalSchema, Name, NamesRef, Namespace, ResolvedSchema, Schema, + SchemaKind, UuidSchema, +}; use crate::types; +use crate::{AvroResult, Error}; use std::borrow::Borrow; use std::collections::{BTreeMap, HashMap, HashSet}; use std::fmt::Debug; +use strum::IntoDiscriminant; /// A description of a Union schema #[derive(Debug, Clone)] @@ -42,25 +46,15 @@ impl UnionSchema { /// Will return an error if `schemas` has duplicate unnamed schemas or if `schemas` /// contains a union. pub fn new(schemas: Vec<Schema>) -> AvroResult<Self> { - let mut named_schemas: HashSet<&Name> = HashSet::default(); - let mut vindex = BTreeMap::new(); - for (i, schema) in schemas.iter().enumerate() { - if let Schema::Union(_) = schema { - return Err(Details::GetNestedUnion.into()); - } else if !schema.is_named() && vindex.insert(SchemaKind::from(schema), i).is_some() { - return Err(Details::GetUnionDuplicate.into()); - } else if schema.is_named() { - let name = schema.name().unwrap(); - if !named_schemas.insert(name) { - return Err(Details::GetUnionDuplicateNamedSchemas(name.to_string()).into()); - } - vindex.insert(SchemaKind::from(schema), i); - } + let mut builder = Self::builder(); + for schema in schemas { + builder.variant(schema)?; } - Ok(UnionSchema { - schemas, - variant_index: vindex, - }) + Ok(builder.build()) + } + + pub fn builder() -> UnionSchemaBuilder { + UnionSchemaBuilder::new() } /// Returns a slice to all variants of this schema. @@ -121,6 +115,189 @@ impl UnionSchema { }) } } + + pub(crate) fn index(&self) -> &BTreeMap<SchemaKind, usize> { + &self.variant_index + } + + /// Strips the logical type of primitive types. + /// + /// Leaves logical types of complex types untouched. + fn schema_kind_without_logical_type(schema: &Schema) -> SchemaKind { + let kind = schema.discriminant(); + match kind { + SchemaKind::Date | SchemaKind::TimeMillis => SchemaKind::Int, + SchemaKind::TimeMicros + | SchemaKind::TimestampMillis + | SchemaKind::TimestampMicros + | SchemaKind::TimestampNanos + | SchemaKind::LocalTimestampMillis + | SchemaKind::LocalTimestampMicros + | SchemaKind::LocalTimestampNanos => SchemaKind::Long, + SchemaKind::Uuid => match schema { + Schema::Uuid(UuidSchema::Bytes) => SchemaKind::Bytes, + Schema::Uuid(UuidSchema::String) => SchemaKind::String, + Schema::Uuid(UuidSchema::Fixed(_)) => SchemaKind::Fixed, + _ => unreachable!(), + }, + SchemaKind::Decimal => match schema { + Schema::Decimal(DecimalSchema { + inner: InnerDecimalSchema::Bytes, + .. + }) => SchemaKind::Bytes, + Schema::Decimal(DecimalSchema { + inner: InnerDecimalSchema::Fixed(_), + .. + }) => SchemaKind::Fixed, + _ => unreachable!(), + }, + SchemaKind::Duration => SchemaKind::Fixed, + _ => kind, + } + } +} + +pub struct UnionSchemaBuilder { + schemas: Vec<Schema>, + array_items_type: Option<Schema>, + map_types_type: Option<Schema>, + kinds: HashSet<SchemaKind>, + names: HashSet<Name>, + variant_index: BTreeMap<SchemaKind, usize>, +} + +impl UnionSchemaBuilder { + pub fn new() -> Self { + Self { + schemas: Vec::new(), + array_items_type: None, + map_types_type: None, + kinds: HashSet::new(), + names: HashSet::new(), + variant_index: BTreeMap::new(), + } + } + + /// Add a variant to this union, if it already exists ignore it. + /// + /// # Errors + /// Will return a [`Details::GetUnionDuplicateMap`] or [`Details::GetUnionDuplicateArray`] if + /// duplicate maps or arrays are encountered with different subtypes. + pub fn variant_ignore_duplicates(&mut self, schema: Schema) -> Result<&mut Self, Error> { + if let Some(name) = schema.name() { + // Returns true if this name is not known yet + if self.names.insert(name.clone()) { + self.schemas.push(schema); + } + } else if let Schema::Map(map) = &schema { + if let Some(set_schema) = &self.map_types_type { + if set_schema != map.types.as_ref() { + return Err(Details::GetUnionDuplicateMap(set_schema.clone(), schema).into()); + } + } else { + self.map_types_type = Some(map.types.as_ref().clone()); + self.variant_index + .insert(SchemaKind::Map, self.schemas.len()); + self.schemas.push(schema); + } + } else if let Schema::Array(array) = &schema { + if let Some(set_schema) = &self.array_items_type { + if set_schema != array.items.as_ref() { + return Err(Details::GetUnionDuplicateArray(set_schema.clone(), schema).into()); + } + } else { + self.array_items_type = Some(array.items.as_ref().clone()); + self.variant_index + .insert(SchemaKind::Array, self.schemas.len()); + self.schemas.push(schema); + } + } else { + let discriminant = UnionSchema::schema_kind_without_logical_type(&schema); + // Returns true if this discriminant wasn't known yet + if self.kinds.insert(discriminant) { + self.variant_index.insert(discriminant, self.schemas.len()); + self.schemas.push(schema); + } + } + Ok(self) + } + + /// Add a variant to this union. + /// + /// # Errors + /// Will return a [`Details::GetUnionDuplicateNamedSchemas`] or [`Details::GetUnionDuplicate`] if + /// duplicate names or schema kinds are found. + pub fn variant(&mut self, schema: Schema) -> Result<&mut Self, Error> { + if let Some(name) = schema.name() { + if self.names.insert(name.clone()) { + self.schemas.push(schema); + } else { + return Err(Details::GetUnionDuplicateNamedSchemas(name.to_string()).into()); + } + } else if let Schema::Map(_) = &schema { + if self.map_types_type.is_some() { + return Err(Details::GetUnionDuplicate(SchemaKind::Map).into()); + } else { + self.map_types_type = Some(schema.clone()); + self.variant_index + .insert(SchemaKind::Map, self.schemas.len()); + self.schemas.push(schema); + } + } else if let Schema::Array(_) = &schema { + if self.array_items_type.is_some() { + return Err(Details::GetUnionDuplicate(SchemaKind::Array).into()); + } else { + self.array_items_type = Some(schema.clone()); + self.variant_index + .insert(SchemaKind::Array, self.schemas.len()); + self.schemas.push(schema); + } + } else { + let discriminant = UnionSchema::schema_kind_without_logical_type(&schema); + // Returns true if this discriminant wasn't known yet + if self.kinds.insert(discriminant) { + self.variant_index.insert(discriminant, self.schemas.len()); + self.schemas.push(schema); + } else { + return Err(Details::GetUnionDuplicate(discriminant).into()); + } + } + Ok(self) + } + + /// Check if a schema already exists in this union. + pub fn contains(&self, schema: &Schema) -> bool { + if let Some(name) = schema.name() { + self.names.contains(name) + } else if let Schema::Map(map) = &schema { + if let Some(set_schema) = &self.map_types_type + && set_schema == map.types.as_ref() + { + true + } else { + false + } + } else if let Schema::Array(array) = &schema { + if let Some(set_schema) = &self.array_items_type + && set_schema == array.items.as_ref() + { + true + } else { + false + } + } else { + let discriminant = UnionSchema::schema_kind_without_logical_type(schema); + self.kinds.contains(&discriminant) + } + } + + pub fn build(mut self) -> UnionSchema { + self.schemas.shrink_to_fit(); + UnionSchema { + variant_index: self.variant_index, + schemas: self.schemas, + } + } } // No need to compare variant_index, it is derivative of schemas. diff --git a/avro/src/serde/derive.rs b/avro/src/serde/derive.rs index 0ce846e..5186ae1 100644 --- a/avro/src/serde/derive.rs +++ b/avro/src/serde/derive.rs @@ -19,6 +19,7 @@ use crate::Schema; use crate::schema::{ FixedSchema, Name, Namespace, RecordField, RecordSchema, UnionSchema, UuidSchema, }; +use serde_json::Value; use std::borrow::Cow; use std::collections::{HashMap, HashSet}; @@ -684,6 +685,7 @@ where } } +// TODO: This does not match the Serde implementation and therefore does not work impl AvroSchemaComponent for core::time::Duration { /// The schema is [`Schema::Duration`] with the name `duration`. /// @@ -873,6 +875,24 @@ impl AvroSchemaComponent for i128 { } } +impl AvroSchemaComponent for () { + fn get_schema_in_ctxt(_: &mut HashSet<Name>, _: &Namespace) -> Schema { + Schema::Null + } + + fn get_record_fields_in_ctxt( + _: usize, + _: &mut HashSet<Name>, + _: &Namespace, + ) -> Option<Vec<RecordField>> { + None + } + + fn field_default() -> Option<Value> { + Some(Value::Null) + } +} + #[cfg(test)] mod tests { use crate::{ diff --git a/avro/src/serde/mod.rs b/avro/src/serde/mod.rs index b3bfd2a..281bf9a 100644 --- a/avro/src/serde/mod.rs +++ b/avro/src/serde/mod.rs @@ -111,6 +111,7 @@ mod de; mod derive; mod ser; pub(crate) mod ser_schema; +mod ser_schema2; mod util; mod with; diff --git a/avro/src/serde/ser_schema.rs b/avro/src/serde/ser_schema/mod.rs similarity index 95% rename from avro/src/serde/ser_schema.rs rename to avro/src/serde/ser_schema/mod.rs index 61534b9..7296606 100644 --- a/avro/src/serde/ser_schema.rs +++ b/avro/src/serde/ser_schema/mod.rs @@ -17,7 +17,12 @@ //! Logic for serde-compatible schema-aware serialization which writes directly to a writer. -use crate::schema::{DecimalSchema, InnerDecimalSchema, UuidSchema}; +mod tuples; + +use crate::schema::{DecimalSchema, InnerDecimalSchema, SchemaKind, UuidSchema}; +use crate::serde::ser_schema::tuples::{ + SchemaAwareTupleSerializer, SchemaAwareTupleSerializerRecord, +}; use crate::{ bigdecimal::big_decimal_as_bytes, encode::{encode_int, encode_long}, @@ -28,6 +33,7 @@ use crate::{ use bigdecimal::BigDecimal; use serde::{Serialize, ser}; use std::{borrow::Cow, cmp::Ordering, collections::HashMap, io::Write, str::FromStr}; +use strum::IntoDiscriminant; const COLLECTION_SERIALIZER_ITEM_LIMIT: usize = 1024; const COLLECTION_SERIALIZER_DEFAULT_INIT_ITEM_CAPACITY: usize = 32; @@ -532,70 +538,6 @@ impl<W: Write> ser::SerializeMap for SchemaAwareWriteSerializeMapOrStruct<'_, '_ } } -/// The tuple struct serializer for [`SchemaAwareWriteSerializer`]. -/// -/// This can serialize to an Avro array, record, or big-decimal. -/// When serializing to a record, fields must be provided in the correct order, since no names are provided. -pub enum SchemaAwareWriteSerializeTupleStruct<'a, 's, W: Write> { - Record(SchemaAwareWriteSerializeStruct<'a, 's, W>), - Array(SchemaAwareWriteSerializeSeq<'a, 's, W>), -} - -impl<W: Write> SchemaAwareWriteSerializeTupleStruct<'_, '_, W> { - fn serialize_field<T>(&mut self, value: &T) -> Result<(), Error> - where - T: ?Sized + ser::Serialize, - { - use SchemaAwareWriteSerializeTupleStruct::*; - match self { - Record(_record_ser) => { - unimplemented!("Tuple struct serialization to record is not supported!"); - } - Array(array_ser) => array_ser.serialize_element(&value), - } - } - - fn end(self) -> Result<usize, Error> { - use SchemaAwareWriteSerializeTupleStruct::*; - match self { - Record(record_ser) => record_ser.end(), - Array(array_ser) => array_ser.end(), - } - } -} - -impl<W: Write> ser::SerializeTupleStruct for SchemaAwareWriteSerializeTupleStruct<'_, '_, W> { - type Ok = usize; - type Error = Error; - - fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error> - where - T: ?Sized + ser::Serialize, - { - self.serialize_field(&value) - } - - fn end(self) -> Result<Self::Ok, Self::Error> { - self.end() - } -} - -impl<W: Write> ser::SerializeTupleVariant for SchemaAwareWriteSerializeTupleStruct<'_, '_, W> { - type Ok = usize; - type Error = Error; - - fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error> - where - T: ?Sized + ser::Serialize, - { - self.serialize_field(&value) - } - - fn end(self) -> Result<Self::Ok, Self::Error> { - self.end() - } -} - /// A [`Serializer`](ser::Serializer) implementation that serializes directly to raw Avro data. /// /// If data does not match with the schema it will return an error. @@ -1587,8 +1529,8 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> { &'a mut self, len: usize, schema: &'s Schema, - ) -> Result<SchemaAwareWriteSerializeSeq<'a, 's, W>, Error> { - let create_error = |cause: String| { + ) -> Result<SchemaAwareTupleSerializer<'a, 's, W>, Error> { + let create_error = |cause: &str| { Error::new(Details::SerializeValueWithSchema { value_type: "tuple", value: format!("tuple (len={len}). Cause: {cause}"), @@ -1597,26 +1539,16 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> { }; match schema { - Schema::Array(array_schema) => Ok(SchemaAwareWriteSerializeSeq::new( - self, - array_schema.items.as_ref(), - Some(len), - )), - Schema::Union(union_schema) => { - for (i, variant_schema) in union_schema.schemas.iter().enumerate() { - match variant_schema { - Schema::Array(_) => { - encode_int(i as i32, &mut *self.writer)?; - return self.serialize_tuple_with_schema(len, variant_schema); - } - _ => { /* skip */ } - } - } - Err(create_error(format!( - "Expected Array schema in {union_schema:?}" - ))) + Schema::Null if len == 0 => Ok(SchemaAwareTupleSerializer::null(0)), + _ if len == 1 => Ok(SchemaAwareTupleSerializer::transparent(self, schema, 0)), + Schema::Record(record) => Ok(SchemaAwareTupleSerializer::record(self, record, 0)), + Schema::Ref { name } => { + let schema = self.get_ref_schema(name)?; + self.serialize_tuple_with_schema(len, schema) } - _ => Err(create_error(format!("Expected: {schema}. Got: Array"))), + _ => Err(create_error( + "Expected: Null for 0-tuple, anything for 1-tuple, record for >=2-tuple.", + )), } } @@ -1625,60 +1557,25 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> { name: &'static str, len: usize, schema: &'s Schema, - ) -> Result<SchemaAwareWriteSerializeTupleStruct<'a, 's, W>, Error> { - let create_error = |cause: String| { + ) -> Result<SchemaAwareTupleSerializerRecord<'a, 's, W>, Error> { + let create_error = |cause: &str| { Error::new(Details::SerializeValueWithSchema { value_type: "tuple struct", - value: format!( - "{name}({}). Cause: {cause}", - vec!["?"; len].as_slice().join(",") - ), + value: format!("{name}({len}). Cause: {cause}"), schema: schema.clone(), }) }; match schema { - Schema::Array(sch) => Ok(SchemaAwareWriteSerializeTupleStruct::Array( - SchemaAwareWriteSerializeSeq::new(self, &sch.items, Some(len)), - )), - Schema::Record(sch) => Ok(SchemaAwareWriteSerializeTupleStruct::Record( - SchemaAwareWriteSerializeStruct::new(self, sch), - )), + Schema::Record(record) => { + assert_eq!(name, record.name.name, "Expected names to be the same"); + Ok(SchemaAwareTupleSerializerRecord::new(self, record, 0)) + } Schema::Ref { name: ref_name } => { let ref_schema = self.get_ref_schema(ref_name)?; self.serialize_tuple_struct_with_schema(name, len, ref_schema) } - Schema::Union(union_schema) => { - for (i, variant_schema) in union_schema.schemas.iter().enumerate() { - match variant_schema { - Schema::Record(inner) => { - if inner.fields.len() == len { - encode_int(i as i32, &mut *self.writer)?; - return self.serialize_tuple_struct_with_schema( - name, - len, - variant_schema, - ); - } - } - Schema::Array(_) | Schema::Ref { name: _ } => { - encode_int(i as i32, &mut *self.writer)?; - return self.serialize_tuple_struct_with_schema( - name, - len, - variant_schema, - ); - } - _ => { /* skip */ } - } - } - Err(create_error(format!( - "Expected Record, Array or Ref schema in {union_schema:?}" - ))) - } - _ => Err(create_error(format!( - "Expected Record, Array, Ref or Union schema. Got: {schema}" - ))), + _ => Err(create_error("Expected Record schema.")), } } @@ -1689,35 +1586,96 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> { variant: &'static str, len: usize, schema: &'s Schema, - ) -> Result<SchemaAwareWriteSerializeTupleStruct<'a, 's, W>, Error> { - let create_error = |cause: String| { + ) -> Result<SchemaAwareTupleSerializer<'a, 's, W>, Error> { + let create_error = |cause: &str| { Error::new(Details::SerializeValueWithSchema { value_type: "tuple variant", - value: format!( - "{name}::{variant}({}) (index={variant_index}). Cause: {cause}", - vec!["?"; len].as_slice().join(",") - ), + value: format!("{name}::{variant}({len}) (index={variant_index}). Cause: {cause}",), schema: schema.clone(), }) }; match schema { + // Bare union or Union of Records Schema::Union(union_schema) => { - let variant_schema = union_schema - .schemas - .get(variant_index as usize) - .ok_or_else(|| { - create_error(format!( - "Cannot find a variant at position {variant_index} in {union_schema:?}" - )) - })?; + todo!() + } + // Discriminator/Value + Schema::Record(record) => { + assert_eq!(name, record.name.name, "Expected names to be the same"); + assert_eq!( + 2, + record.fields.len(), + "Expected two fields for `discriminator_value` record" + ); - encode_int(variant_index as i32, &mut self.writer)?; - self.serialize_tuple_struct_with_schema(variant, len, variant_schema) + // Write the discriminator + let Schema::Enum(enum_schema) = &record.fields[0].schema else { + panic!("Expected enum for the first field of `discriminator_value` record"); + }; + let symbol_index = if enum_schema.symbols[variant_index as usize] == variant { + variant_index as i32 + } else if let Some((index, _symbol)) = enum_schema + .symbols + .iter() + .enumerate() + .find(|(_i, s)| *s == variant) + { + index as i32 + } else { + panic!("Could not find `{variant}` in `{:?}`", enum_schema.symbols) + }; + let mut bytes_written = encode_int(symbol_index, &mut self.writer)?; + + // Return the value writer + let Schema::Union(union_schema) = &record.fields[1].schema else { + panic!("Expected union for the second field of `discriminator_value` record"); + }; + if len == 0 { + let Some((index, _)) = union_schema + .variants() + .iter() + .map(Schema::discriminant) + .enumerate() + .find(|(i, k)| *k == SchemaKind::Null) + else { + panic!( + "Expected to find Schema::Null in variants {:?}", + union_schema.variants() + ) + }; + bytes_written += encode_int(index as i32, &mut *self.writer)?; + Ok(SchemaAwareTupleSerializer::null(bytes_written)) + } else if len == 1 { + // Maybe UnionAwareWriteSerializer or something?? + todo!("Deal with unions in a nice way") + } else { + let Some((index, Schema::Record(record))) = union_schema + .variants() + .iter() + .enumerate() + .find(|(i, s)| s.name().is_some_and(|n| n.name == variant)) + else { + panic!("Expected record with name `{variant}` for >=2-tuple"); + }; + assert_eq!( + record.fields.len(), + len, + "Expected record to have the same amount of fields as the tuple" + ); + bytes_written += encode_int(index as i32, &mut *self.writer)?; + Ok(SchemaAwareTupleSerializer::record( + self, + record, + bytes_written, + )) + } } - _ => Err(create_error(format!( - "Expected Union schema. Got: {schema}" - ))), + Schema::Ref { name: schema_name } => { + let schema = self.get_ref_schema(schema_name)?; + self.serialize_tuple_variant_with_schema(name, variant_index, variant, len, schema) + } + _ => Err(create_error("Expected Union or Record schema")), } } @@ -1858,9 +1816,9 @@ impl<'a, 's, W: Write> ser::Serializer for &'a mut SchemaAwareWriteSerializer<'s type Ok = usize; type Error = Error; type SerializeSeq = SchemaAwareWriteSerializeSeq<'a, 's, W>; - type SerializeTuple = SchemaAwareWriteSerializeSeq<'a, 's, W>; - type SerializeTupleStruct = SchemaAwareWriteSerializeTupleStruct<'a, 's, W>; - type SerializeTupleVariant = SchemaAwareWriteSerializeTupleStruct<'a, 's, W>; + type SerializeTuple = SchemaAwareTupleSerializer<'a, 's, W>; + type SerializeTupleStruct = SchemaAwareTupleSerializerRecord<'a, 's, W>; + type SerializeTupleVariant = SchemaAwareTupleSerializer<'a, 's, W>; type SerializeMap = SchemaAwareWriteSerializeMapOrStruct<'a, 's, W>; type SerializeStruct = SchemaAwareWriteSerializeStruct<'a, 's, W>; type SerializeStructVariant = SchemaAwareWriteSerializeStruct<'a, 's, W>; diff --git a/avro/src/serde/ser_schema/tuples.rs b/avro/src/serde/ser_schema/tuples.rs new file mode 100644 index 0000000..2730d1d --- /dev/null +++ b/avro/src/serde/ser_schema/tuples.rs @@ -0,0 +1,243 @@ +use crate::schema::RecordSchema; +use crate::serde::ser_schema::SchemaAwareWriteSerializer; +use crate::{Error, Schema}; +use serde::Serialize; +use serde::ser::{SerializeTuple, SerializeTupleStruct, SerializeTupleVariant}; +use std::io::Write; + +pub enum SchemaAwareTupleSerializer<'a, 's, W: Write> { + Null(usize), + Transparent(SchemaAwareTupleSerializerTransparent<'a, 's, W>), + Record(SchemaAwareTupleSerializerRecord<'a, 's, W>), +} + +impl<'a, 's, W: Write> SchemaAwareTupleSerializer<'a, 's, W> { + pub fn null(bytes_written: usize) -> Self { + Self::Null(bytes_written) + } + + pub fn transparent( + ser: &'a mut SchemaAwareWriteSerializer<'s, W>, + schema: &'s Schema, + bytes_written: usize, + ) -> Self { + Self::Transparent(SchemaAwareTupleSerializerTransparent::new( + ser, + schema, + bytes_written, + )) + } + + pub fn record( + ser: &'a mut SchemaAwareWriteSerializer<'s, W>, + schema: &'s RecordSchema, + bytes_written: usize, + ) -> Self { + Self::Record(SchemaAwareTupleSerializerRecord::new( + ser, + schema, + bytes_written, + )) + } +} + +impl<'a, 's, W: Write> SerializeTuple for SchemaAwareTupleSerializer<'a, 's, W> { + type Ok = usize; + type Error = Error; + + fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + match self { + SchemaAwareTupleSerializer::Null(_) => { + unreachable!("The Null variant is only created for empty tuples"); + } + SchemaAwareTupleSerializer::Transparent(transparent) => { + transparent.serialize_element(value) + } + SchemaAwareTupleSerializer::Record(record) => record.serialize_element(value), + } + } + + fn end(self) -> Result<Self::Ok, Self::Error> { + match self { + SchemaAwareTupleSerializer::Null(bytes) => Ok(bytes), + SchemaAwareTupleSerializer::Transparent(transparent) => { + SerializeTuple::end(transparent) + } + SchemaAwareTupleSerializer::Record(record) => SerializeTuple::end(record), + } + } +} + +impl<'a, 's, W: Write> SerializeTupleVariant for SchemaAwareTupleSerializer<'a, 's, W> { + type Ok = usize; + type Error = Error; + + fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.serialize_element(value) + } + + fn end(self) -> Result<Self::Ok, Self::Error> { + SerializeTuple::end(self) + } +} + +/// Serialize a tuple with one field as the inner field. +pub struct SchemaAwareTupleSerializerTransparent<'a, 's, W: Write> { + ser: &'a mut SchemaAwareWriteSerializer<'s, W>, + schema: &'s Schema, + bytes_written: usize, + field_written: bool, +} + +impl<'a, 's, W: Write> SchemaAwareTupleSerializerTransparent<'a, 's, W> { + pub fn new( + ser: &'a mut SchemaAwareWriteSerializer<'s, W>, + schema: &'s Schema, + bytes_written: usize, + ) -> Self { + Self { + ser, + schema, + bytes_written, + field_written: false, + } + } +} + +impl<'a, 's, W: Write> SerializeTuple for SchemaAwareTupleSerializerTransparent<'a, 's, W> { + type Ok = usize; + type Error = Error; + + fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + if self.field_written { + unreachable!("This struct should only be created for tuples of length 1"); + } + self.bytes_written += value.serialize(&mut *self.ser)?; + self.field_written = true; + Ok(()) + } + + fn end(self) -> Result<Self::Ok, Self::Error> { + if !self.field_written { + unreachable!("This struct should only be created for tuples of length 1"); + } + Ok(self.bytes_written) + } +} + +impl<'a, 's, W: Write> SerializeTupleVariant for SchemaAwareTupleSerializerTransparent<'a, 's, W> { + type Ok = usize; + type Error = Error; + + fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.serialize_element(value) + } + + fn end(self) -> Result<Self::Ok, Self::Error> { + SerializeTuple::end(self) + } +} + +/// Serialize a tuple as a record. +/// +/// The tuple must have the same amount of fields as the record. Field names are ignored, so the +/// tuple field order must match the record field order. +pub struct SchemaAwareTupleSerializerRecord<'a, 's, W: Write> { + ser: &'a mut SchemaAwareWriteSerializer<'s, W>, + schema: &'s RecordSchema, + field: usize, + bytes_written: usize, +} + +impl<'a, 's, W: Write> SchemaAwareTupleSerializerRecord<'a, 's, W> { + pub fn new( + ser: &'a mut SchemaAwareWriteSerializer<'s, W>, + schema: &'s RecordSchema, + bytes_written: usize, + ) -> Self { + Self { + ser, + schema, + field: 0, + bytes_written, + } + } +} + +impl<'a, 's, W: Write> SerializeTuple for SchemaAwareTupleSerializerRecord<'a, 's, W> { + type Ok = usize; + type Error = Error; + + fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + assert!( + self.field < self.schema.fields.len(), + "This struct should only be created for tuples with the same amount of fields as the record" + ); + let schema = &self.schema.fields[self.field].schema; + let mut value_ser = SchemaAwareWriteSerializer::new( + &mut *self.ser.writer, + schema, + self.ser.names, + self.ser.enclosing_namespace.clone(), + ); + self.bytes_written += value.serialize(&mut value_ser)?; + self.field += 1; + Ok(()) + } + + fn end(self) -> Result<Self::Ok, Self::Error> { + assert_eq!( + self.field, + self.schema.fields.len(), + "This struct should only be created for tuples with the same amount of fields as the record" + ); + Ok(self.bytes_written) + } +} + +impl<'a, 's, W: Write> SerializeTupleStruct for SchemaAwareTupleSerializerRecord<'a, 's, W> { + type Ok = usize; + type Error = Error; + + fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.serialize_element(value) + } + + fn end(self) -> Result<Self::Ok, Self::Error> { + SerializeTuple::end(self) + } +} + +impl<'a, 's, W: Write> SerializeTupleVariant for SchemaAwareTupleSerializerRecord<'a, 's, W> { + type Ok = usize; + type Error = Error; + + fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.serialize_element(value) + } + + fn end(self) -> Result<Self::Ok, Self::Error> { + SerializeTuple::end(self) + } +} diff --git a/avro/src/serde/ser_schema2/mod.rs b/avro/src/serde/ser_schema2/mod.rs new file mode 100644 index 0000000..0446c30 --- /dev/null +++ b/avro/src/serde/ser_schema2/mod.rs @@ -0,0 +1,441 @@ +mod union; + +use crate::encode::{encode_int, encode_long}; +use crate::error::Details; +use crate::schema::{DecimalSchema, InnerDecimalSchema, NamesRef, SchemaKind, UuidSchema}; +use crate::{Error, Schema}; +use serde::{Serialize, Serializer}; +use std::io::Write; + +pub struct SchemaAwareSerializer<'s, W: Write> { + writer: W, + schema: &'s Schema, + names: &'s NamesRef<'s>, +} + +impl<'s, W: Write> SchemaAwareSerializer<'s, W> { + pub fn new(writer: W, schema: &'s Schema, names: &'s NamesRef<'s>) -> Result<Self, Error> { + if let Schema::Ref { name } = schema { + let schema = names + .get(name) + .ok_or_else(|| Details::SchemaResolutionError(name.clone()))?; + Self::new(writer, schema, names) + } else { + Ok(Self { + writer, + schema, + names, + }) + } + } + + fn error(&self, ty: &'static str, error: impl Into<String>) -> Error { + Error::new(Details::SerializeValueWithSchema { + value_type: ty, + value: error.into(), + schema: self.schema.clone(), + }) + } + + fn serialize_int(mut self, original_ty: &'static str, v: i32) -> Result<usize, Error> { + match self.schema { + Schema::Int | Schema::Date | Schema::TimeMillis => encode_int(v, &mut self.writer), + _ => Err(self.error( + original_ty, + "Expected Schema::Int | Schema::Date | Schema::TimeMillis", + )), + } + } + + fn serialize_long(mut self, original_ty: &'static str, v: i64) -> Result<usize, Error> { + match self.schema { + Schema::Long | Schema::TimeMicros | Schema::TimestampMillis | Schema::TimestampMicros + | Schema::TimestampNanos | Schema::LocalTimestampMillis | Schema::LocalTimestampMicros + | Schema::LocalTimestampNanos => { + encode_long(v, &mut self.writer) + } + _ => { + Err(self.error(original_ty, "Expected Schema::Long | Schema::TimeMicros | Schema::{,Local}Timestamp{Millis,Micros,Nanos}")) + } + } + } + + fn write_bytes(&mut self, bytes: &[u8]) -> Result<usize, Error> { + self.writer.write_all(bytes).map_err(Details::WriteBytes)?; + + Ok(bytes.len()) + } + + fn write_bytes_with_len(&mut self, bytes: &[u8]) -> Result<usize, Error> { + let mut bytes_written: usize = 0; + + bytes_written += encode_long(bytes.len() as i64, &mut self.writer)?; + bytes_written += self.write_bytes(bytes)?; + + Ok(bytes_written) + } +} + +impl<'s, W: Write> Serializer for SchemaAwareSerializer<'s, W> { + /// Amount of bytes written + type Ok = usize; + type Error = Error; + type SerializeSeq = (); + type SerializeTuple = (); + type SerializeTupleStruct = (); + type SerializeTupleVariant = (); + type SerializeMap = (); + type SerializeStruct = (); + type SerializeStructVariant = (); + + fn serialize_bool(mut self, v: bool) -> Result<Self::Ok, Self::Error> { + let Schema::Boolean = self.schema else { + return Err(self.error("bool", "Expected Schema::Boolean")); + }; + self.writer + .write_all(&[v as u8]) + .map_err(Details::WriteBytes)?; + Ok(1) + } + + fn serialize_i8(self, v: i8) -> Result<Self::Ok, Self::Error> { + self.serialize_int("i8", i32::from(v)) + } + + fn serialize_i16(self, v: i16) -> Result<Self::Ok, Self::Error> { + self.serialize_int("i16", i32::from(v)) + } + + fn serialize_i32(self, v: i32) -> Result<Self::Ok, Self::Error> { + self.serialize_int("i32", v) + } + + fn serialize_i64(self, v: i64) -> Result<Self::Ok, Self::Error> { + self.serialize_long("i64", v) + } + + fn serialize_i128(mut self, v: i128) -> Result<Self::Ok, Self::Error> { + let Schema::Fixed(fixed) = self.schema else { + return Err(self.error("i128", r#"Expected Schema::Fixed(name: "i128", size: 16)"#)); + }; + if fixed.name.name != "i128" || fixed.size != 16 { + return Err(self.error("i128", r#"Expected Schema::Fixed(name: "i128", size: 16)"#)); + } + let bytes = v.to_le_bytes(); + self.write_bytes(&bytes) + } + + fn serialize_u8(self, v: u8) -> Result<Self::Ok, Self::Error> { + self.serialize_int("u8", i32::from(v)) + } + + fn serialize_u16(self, v: u16) -> Result<Self::Ok, Self::Error> { + self.serialize_int("u16", i32::from(v)) + } + + fn serialize_u32(self, v: u32) -> Result<Self::Ok, Self::Error> { + self.serialize_long("u32", i64::from(v)) + } + + fn serialize_u64(mut self, v: u64) -> Result<Self::Ok, Self::Error> { + let Schema::Fixed(fixed) = self.schema else { + return Err(self.error("u64", r#"Expected Schema::Fixed(name: "u64", size: 8)"#)); + }; + if fixed.name.name != "u64" || fixed.size != 8 { + return Err(self.error("u64", r#"Expected Schema::Fixed(name: "u64", size: 8)"#)); + } + let bytes = v.to_le_bytes(); + self.write_bytes(&bytes) + } + + fn serialize_u128(mut self, v: u128) -> Result<Self::Ok, Self::Error> { + let Schema::Fixed(fixed) = self.schema else { + return Err(self.error("u128", r#"Expected Schema::Fixed(name: "u128", size: 16)"#)); + }; + if fixed.name.name != "u128" || fixed.size != 16 { + return Err(self.error("u128", r#"Expected Schema::Fixed(name: "u128", size: 16)"#)); + } + let bytes = v.to_le_bytes(); + self.write_bytes(&bytes) + } + + fn serialize_f32(mut self, v: f32) -> Result<Self::Ok, Self::Error> { + let Schema::Float = self.schema else { + return Err(self.error("f32", "Expected Schema::Float")); + }; + let bytes = v.to_le_bytes(); + self.write_bytes(&bytes) + } + + fn serialize_f64(mut self, v: f64) -> Result<Self::Ok, Self::Error> { + let Schema::Double = self.schema else { + return Err(self.error("f64", "Expected Schema::Double")); + }; + let bytes = v.to_le_bytes(); + self.write_bytes(&bytes) + } + + fn serialize_char(mut self, v: char) -> Result<Self::Ok, Self::Error> { + let Schema::String = self.schema else { + return Err(self.error("char", "Expected Schema::String")); + }; + let bytes = v.to_string().into_bytes(); + self.write_bytes_with_len(&bytes) + } + + fn serialize_str(mut self, v: &str) -> Result<Self::Ok, Self::Error> { + match self.schema { + Schema::String | Schema::Uuid(UuidSchema::String) => { + let bytes = v.as_bytes(); + self.write_bytes_with_len(&bytes) + } + _ => Err(self.error("str", "Expected Schema::String | Schema::Uuid(String)")), + } + } + + fn serialize_bytes(mut self, v: &[u8]) -> Result<Self::Ok, Self::Error> { + match self.schema { + Schema::Bytes | Schema::BigDecimal | Schema::Decimal(DecimalSchema { inner: InnerDecimalSchema::Bytes, .. }) => { + self.write_bytes_with_len(v) + } + Schema::Fixed(fixed) | Schema::Decimal(DecimalSchema { inner: InnerDecimalSchema::Fixed(fixed), ..}) => { + if fixed.size != v.len() { + Err(self.error("bytes", format!("Fixed size ({}) does not match value length ({})", fixed.size, v.len()))) + } else { + self.write_bytes(v) + } + } + _ => Err(self.error("bytes", "Expected Schema::Bytes | Schema::Uuid(Fixed) | Schema::BigDecimal | Schema::Decimal")), + } + } + + fn serialize_none(mut self) -> Result<Self::Ok, Self::Error> { + let Schema::Union(union) = self.schema else { + return Err(self.error("None", "Expected Schema::Union([null, _])")); + }; + if union.variants().len() != 2 { + return Err(self.error("None", "Expected Schema::Union([null, _])")); + } + let Some(index) = union.index().get(&SchemaKind::Null).copied() else { + return Err(self.error("None", "Expected Schema::Union([null, _])")); + }; + encode_int(index as i32, &mut self.writer) + } + + fn serialize_some<T>(mut self, value: &T) -> Result<Self::Ok, Self::Error> + where + T: ?Sized + Serialize, + { + let Schema::Union(union) = self.schema else { + return Err(self.error("None", "Expected Schema::Union([null, _])")); + }; + if union.variants().len() != 2 { + return Err(self.error("None", "Expected Schema::Union([null, _])")); + } + let Some(index) = union.index().get(&SchemaKind::Null).copied() else { + return Err(self.error("None", "Expected Schema::Union([null, _])")); + }; + // Convert the index of null to the other index + let index = (index + 1) & 1; + let mut bytes_written = encode_int(index as i32, &mut self.writer)?; + let ser = Self::new(self.writer, &union.variants()[index], self.names)?; + bytes_written += value.serialize(ser)?; + Ok(bytes_written) + } + + fn serialize_unit(self) -> Result<Self::Ok, Self::Error> { + let Schema::Null = self.schema else { + return Err(self.error("()", "Expected Schema::Null")); + }; + Ok(0) + } + + fn serialize_unit_struct(self, name: &'static str) -> Result<Self::Ok, Self::Error> { + let Schema::Record(record) = self.schema else { + return Err(self.error( + "unit struct", + format!("Expected Schema::Record(name: {name}, fields: [])"), + )); + }; + if record.name.name != name || !record.fields.is_empty() { + return Err(self.error( + "unit struct", + format!("Expected Schema::Record(name: {name}, fields: [])"), + )); + } + Ok(0) + } + + fn serialize_unit_variant( + mut self, + name: &'static str, + variant_index: u32, + variant: &'static str, + ) -> Result<Self::Ok, Self::Error> { + match self.schema { + Schema::Enum(enum_schema) => { + if enum_schema.name.name != name { + return Err(self.error( + "unit variant", + format!("Enum name ({name}) does not match schema name"), + )); + } else if enum_schema.symbols[variant_index as usize] != variant { + return Err(self.error( + "unit variant", + format!( + "Enum variant ({variant}) is not at index {variant_index} in symbols" + ), + )); + } + encode_int(variant_index as i32, &mut self.writer) + } + Schema::Union(union_schema) => { + if let Some(index) = union_schema.index().get(&SchemaKind::Null).copied() { + // Bare union + encode_int(index as i32, &mut self.writer) + } else { + // Union of records + let Some(Schema::Record(record)) = + union_schema.variants().get(variant_index as usize) + else { + return Err(self.error("unit variant", format!("Union does not contain null and variant at index {variant_index} is not a record"))); + }; + if record.name.name != variant { + return Err(self.error("unit variant", format!("Union does not contain null and variant at index {variant_index} is not named {variant}"))); + } + encode_int(variant_index as i32, &mut self.writer) + } + } + Schema::Record(record) => { + // Discriminator value + if record.name.name != name { + return Err(self.error( + "unit variant", + format!("Enum name ({name}) does not match schema name"), + )); + } else if record.fields.len() != 2 { + return Err(self.error("unit variant", "Record does not have two fields")); + } + let Schema::Enum(tag) = &record.fields[0].schema else { + return Err(self.error("unit variant", "First field of record is not an enum")); + }; + let Schema::Union(content) = &record.fields[1].schema else { + return Err(self.error("unit variant", "Second field of record is not a union")); + }; + + if tag.name.name != name { + return Err(self.error( + "unit variant", + format!("Tag name ({name}) does not match schema name"), + )); + } else if tag.symbols[variant_index as usize] != variant { + return Err(self.error( + "unit variant", + format!( + "Tag variant ({variant}) is not at index {variant_index} in symbols" + ), + )); + } + let mut bytes_written = encode_int(variant_index as i32, &mut self.writer)?; + + let Some(index) = content.index().get(&SchemaKind::Null).copied() else { + return Err( + self.error("unit variant", "Content does not contain a Null schema") + ); + }; + bytes_written += encode_int(index as i32, &mut self.writer)?; + Ok(bytes_written) + } + _ => Err(self.error( + "unit variant", + "Expected Enum | Union | UnionOfRecords | Record(fields: [tag, content])", + )), + } + } + + fn serialize_newtype_struct<T>( + self, + name: &'static str, + value: &T, + ) -> Result<Self::Ok, Self::Error> + where + T: ?Sized + Serialize, + { + let Schema::Record(record) = self.schema else { + return Err(self.error( + "newtype struct", + format!("Expected Schema::Record(name: {name}) with one field"), + )); + }; + if record.name.name != name || record.fields.len() != 1 { + return Err(self.error( + "newtype struct", + format!("Expected Schema::Record(name: {name}) with one field"), + )); + }; + let schema = &record.fields[0].schema; + let ser = Self::new(self.writer, schema, self.names)?; + value.serialize(ser) + } + + fn serialize_newtype_variant<T>( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result<Self::Ok, Self::Error> + where + T: ?Sized + Serialize, + { + todo!("Create UnionAwareSerializer") + } + + fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> { + todo!() + } + + fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple, Self::Error> { + todo!() + } + + fn serialize_tuple_struct( + self, + name: &'static str, + len: usize, + ) -> Result<Self::SerializeTupleStruct, Self::Error> { + todo!() + } + + fn serialize_tuple_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result<Self::SerializeTupleVariant, Self::Error> { + todo!() + } + + fn serialize_map(self, len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> { + todo!() + } + + fn serialize_struct( + self, + name: &'static str, + len: usize, + ) -> Result<Self::SerializeStruct, Self::Error> { + todo!() + } + + fn serialize_struct_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result<Self::SerializeStructVariant, Self::Error> { + todo!() + } +} diff --git a/avro/src/serde/ser_schema2/union.rs b/avro/src/serde/ser_schema2/union.rs new file mode 100644 index 0000000..26296f2 --- /dev/null +++ b/avro/src/serde/ser_schema2/union.rs @@ -0,0 +1,355 @@ +use crate::encode::encode_int; +use crate::error::Details; +use crate::schema::{Name, NamesRef, SchemaKind, UnionSchema}; +use crate::serde::ser_schema2::SchemaAwareSerializer; +use crate::{Error, Schema}; +use serde::{Serialize, Serializer}; +use std::io::Write; + +pub struct UnionAwareSerializer<'s, W: Write> { + writer: W, + union_schema: &'s UnionSchema, + names: &'s NamesRef<'s>, +} + +impl<'s, W: Write> UnionAwareSerializer<'s, W> { + pub fn new(writer: W, union_schema: &'s UnionSchema, names: &'s NamesRef<'s>) -> Self { + Self { + writer, + union_schema, + names, + } + } + + fn error(&self, ty: &'static str, error: impl Into<String>) -> Error { + Error::new(Details::SerializeValueWithSchema { + value_type: ty, + value: error.into(), + schema: Schema::Union(self.union_schema.clone()), + }) + } + + fn serialize_int(mut self, original_ty: &'static str, v: i32) -> Result<usize, Error> { + let Some(index) = self.union_schema.index().get(&SchemaKind::Int).copied() else { + return Err(self.error( + original_ty, + "Expected Schema::Int | Schema::Date | Schema::TimeMillis in variants", + )); + }; + let mut bytes_written = encode_int(index as i32, &mut self.writer)?; + let ser = SchemaAwareSerializer::new(self.writer, &Schema::Int, self.names)?; + bytes_written += ser.serialize_int(original_ty, v)?; + Ok(bytes_written) + } + + fn serialize_long(mut self, original_ty: &'static str, v: i64) -> Result<usize, Error> { + let Some(index) = self.union_schema.index().get(&SchemaKind::Long).copied() else { + return Err(self.error(original_ty, "Expected Schema::Long | Schema::TimeMicros | Schema::{,Local}Timestamp{Millis,Micros,Nanos} in variants")); + }; + let mut bytes_written = encode_int(index as i32, &mut self.writer)?; + let ser = SchemaAwareSerializer::new(self.writer, &Schema::Long, self.names)?; + bytes_written += ser.serialize_long(original_ty, v)?; + Ok(bytes_written) + } + + fn find_named_schema(&self, name: &str) -> Option<(usize, &'s Schema)> { + self.union_schema + .variants() + .iter() + .enumerate() + .find(|(_i, s)| s.name().is_some_and(|n| n.name == name)) + } +} + +impl<'s, W: Write> Serializer for UnionAwareSerializer<'s, W> { + type Ok = usize; + type Error = Error; + type SerializeSeq = SchemaAwareSerializer<'s, W>; + type SerializeTuple = SchemaAwareSerializer<'s, W>; + type SerializeTupleStruct = SchemaAwareSerializer<'s, W>; + type SerializeTupleVariant = SchemaAwareSerializer<'s, W>; + type SerializeMap = SchemaAwareSerializer<'s, W>; + type SerializeStruct = SchemaAwareSerializer<'s, W>; + type SerializeStructVariant = SchemaAwareSerializer<'s, W>; + + fn serialize_bool(mut self, v: bool) -> Result<Self::Ok, Self::Error> { + let Some(index) = self.union_schema.index().get(&SchemaKind::Boolean).copied() else { + return Err(self.error("bool", "Expected Schema::Boolean in variants")); + }; + let mut bytes_written = encode_int(index as i32, &mut self.writer)?; + let ser = SchemaAwareSerializer::new(self.writer, &Schema::Boolean, self.names)?; + bytes_written += ser.serialize_bool(v)?; + Ok(bytes_written) + } + + fn serialize_i8(self, v: i8) -> Result<Self::Ok, Self::Error> { + self.serialize_int("i8", i32::from(v)) + } + + fn serialize_i16(self, v: i16) -> Result<Self::Ok, Self::Error> { + self.serialize_int("i16", i32::from(v)) + } + + fn serialize_i32(self, v: i32) -> Result<Self::Ok, Self::Error> { + self.serialize_int("i32", v) + } + + fn serialize_i64(self, v: i64) -> Result<Self::Ok, Self::Error> { + self.serialize_long("i64", v) + } + + fn serialize_i128(mut self, v: i128) -> Result<Self::Ok, Self::Error> { + let Some((index, schema)) = self.find_named_schema("i128") else { + return Err(self.error( + "i128", + r#"Expected Schema::Fixed(name: "i128", size: 16) in variants"#, + )); + }; + let mut bytes_written = encode_int(index as i32, &mut self.writer)?; + let ser = SchemaAwareSerializer::new(self.writer, schema, self.names)?; + bytes_written += ser.serialize_i128(v)?; + Ok(bytes_written) + } + + fn serialize_u8(self, v: u8) -> Result<Self::Ok, Self::Error> { + self.serialize_int("u8", i32::from(v)) + } + + fn serialize_u16(self, v: u16) -> Result<Self::Ok, Self::Error> { + self.serialize_int("u16", i32::from(v)) + } + + fn serialize_u32(self, v: u32) -> Result<Self::Ok, Self::Error> { + self.serialize_long("u32", i64::from(v)) + } + + fn serialize_u64(mut self, v: u64) -> Result<Self::Ok, Self::Error> { + let Some((index, schema)) = self.find_named_schema("u64") else { + return Err(self.error( + "u64", + r#"Expected Schema::Fixed(name: "u64", size: 8) in variants"#, + )); + }; + let mut bytes_written = encode_int(index as i32, &mut self.writer)?; + let ser = SchemaAwareSerializer::new(self.writer, schema, self.names)?; + bytes_written += ser.serialize_u64(v)?; + Ok(bytes_written) + } + + fn serialize_u128(mut self, v: u128) -> Result<Self::Ok, Self::Error> { + let Some((index, schema)) = self.find_named_schema("u128") else { + return Err(self.error( + "i128", + r#"Expected Schema::Fixed(name: "u128", size: 16) in variants"#, + )); + }; + let mut bytes_written = encode_int(index as i32, &mut self.writer)?; + let ser = SchemaAwareSerializer::new(self.writer, schema, self.names)?; + bytes_written += ser.serialize_u128(v)?; + Ok(bytes_written) + } + + fn serialize_f32(mut self, v: f32) -> Result<Self::Ok, Self::Error> { + let Some(index) = self.union_schema.index().get(&SchemaKind::Float).copied() else { + return Err(self.error("f32", "Expected Schema::Float in variants")); + }; + let mut bytes_written = encode_int(index as i32, &mut self.writer)?; + let ser = SchemaAwareSerializer::new(self.writer, &Schema::Float, self.names)?; + bytes_written += ser.serialize_f32(v)?; + Ok(bytes_written) + } + + fn serialize_f64(mut self, v: f64) -> Result<Self::Ok, Self::Error> { + let Some(index) = self.union_schema.index().get(&SchemaKind::Double).copied() else { + return Err(self.error("f64", "Expected Schema::Double in variants")); + }; + let mut bytes_written = encode_int(index as i32, &mut self.writer)?; + let ser = SchemaAwareSerializer::new(self.writer, &Schema::Double, self.names)?; + bytes_written += ser.serialize_f64(v)?; + Ok(bytes_written) + } + + fn serialize_char(mut self, v: char) -> Result<Self::Ok, Self::Error> { + let Some(index) = self.union_schema.index().get(&SchemaKind::String).copied() else { + return Err(self.error("char", "Expected Schema::String in variants")); + }; + let mut bytes_written = encode_int(index as i32, &mut self.writer)?; + let ser = SchemaAwareSerializer::new(self.writer, &Schema::String, self.names)?; + bytes_written += ser.serialize_char(v)?; + Ok(bytes_written) + } + + fn serialize_str(mut self, v: &str) -> Result<Self::Ok, Self::Error> { + let Some(index) = self.union_schema.index().get(&SchemaKind::String).copied() else { + return Err(self.error("str", "Expected Schema::String in variants")); + }; + let mut bytes_written = encode_int(index as i32, &mut self.writer)?; + let ser = SchemaAwareSerializer::new(self.writer, &Schema::String, self.names)?; + bytes_written += ser.serialize_str(v)?; + Ok(bytes_written) + } + + fn serialize_bytes(mut self, v: &[u8]) -> Result<Self::Ok, Self::Error> { + let potential_bytes_index = self.union_schema.index().get(&SchemaKind::Bytes).copied(); + let potential_fixed_index = + self.union_schema + .variants() + .iter() + .enumerate() + .find(|(_i, s)| { + if let Schema::Fixed(f) = s { + f.size == v.len() + } else { + false + } + }); + let (index, schema) = match (potential_bytes_index, potential_fixed_index) { + (Some(bytes_index), Some((fixed_index, fixed_schema))) => { + if bytes_index < fixed_index { + (bytes_index, &Schema::Bytes) + } else { + (fixed_index, fixed_schema) + } + } + (Some(bytes_index), None) => (bytes_index, &Schema::Bytes), + (None, Some((fixed_index, fixed_schema))) => (fixed_index, fixed_schema), + (None, None) => { + return Err(self.error( + "bytes", + format!( + "Expected Schema::Bytes or Schema::Fixed(size: {}) in variants", + v.len() + ), + )); + } + }; + let mut bytes_written = encode_int(index as i32, &mut self.writer)?; + let ser = SchemaAwareSerializer::new(self.writer, schema, self.names)?; + bytes_written += ser.serialize_bytes(v)?; + Ok(bytes_written) + } + + fn serialize_none(self) -> Result<Self::Ok, Self::Error> { + Err(self.error("None", "Nested unions are not supported")) + } + + fn serialize_some<T>(self, value: &T) -> Result<Self::Ok, Self::Error> + where + T: ?Sized + Serialize, + { + Err(self.error("Some", "Nested unions are not supported")) + } + + fn serialize_unit(self) -> Result<Self::Ok, Self::Error> { + let Some(index) = self.union_schema.index().get(&SchemaKind::Null).copied() else { + return Err(self.error("()", "Expected Schema::Null in variants")); + }; + encode_int(index as i32, self.writer) + } + + fn serialize_unit_struct(self, name: &'static str) -> Result<Self::Ok, Self::Error> { + let Some((index, Schema::Record(schema))) = self.find_named_schema(name) else { + return Err(self.error( + "unit struct", + format!("Expected Schema::Record(name: {name}, fields: []) in variants"), + )); + }; + if !schema.fields.is_empty() { + return Err(self.error( + "unit struct", + format!("Expected Schema::Record(name: {name}, , fields: []) in variants"), + )); + } + encode_int(index as i32, self.writer) + } + + fn serialize_unit_variant( + mut self, + name: &'static str, + variant_index: u32, + variant: &'static str, + ) -> Result<Self::Ok, Self::Error> { + let Some((index, schema)) = self.find_named_schema(name) else { + return Err(self.error( + "unit variant", + "Expected Enum | Record(fields: [tag, content]) in variants", + )); + }; + let mut bytes_written = encode_int(index as i32, &mut self.writer)?; + let ser = SchemaAwareSerializer::new(self.writer, schema, self.names)?; + bytes_written += ser.serialize_unit_variant(name, variant_index, variant)?; + Ok(bytes_written) + } + + fn serialize_newtype_struct<T>( + self, + name: &'static str, + value: &T, + ) -> Result<Self::Ok, Self::Error> + where + T: ?Sized + Serialize, + { + todo!() + } + + fn serialize_newtype_variant<T>( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result<Self::Ok, Self::Error> + where + T: ?Sized + Serialize, + { + todo!() + } + + fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> { + todo!() + } + + fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple, Self::Error> { + todo!() + } + + fn serialize_tuple_struct( + self, + name: &'static str, + len: usize, + ) -> Result<Self::SerializeTupleStruct, Self::Error> { + todo!() + } + + fn serialize_tuple_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result<Self::SerializeTupleVariant, Self::Error> { + todo!() + } + + fn serialize_map(self, len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> { + todo!() + } + + fn serialize_struct( + self, + name: &'static str, + len: usize, + ) -> Result<Self::SerializeStruct, Self::Error> { + todo!() + } + + fn serialize_struct_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result<Self::SerializeStructVariant, Self::Error> { + todo!() + } +} diff --git a/avro/src/util.rs b/avro/src/util.rs index 8acbfa4..3127f8d 100644 --- a/avro/src/util.rs +++ b/avro/src/util.rs @@ -127,8 +127,9 @@ fn encode_variable<W: Write>(mut zigzagged: u64, mut writer: W) -> AvroResult<us } } writer - .write(&buffer[..i]) - .map_err(|e| Details::WriteBytes(e).into()) + .write_all(&buffer[..i]) + .map_err(Details::WriteBytes)?; + Ok(i) } /// Read a varint from the reader. diff --git a/avro_derive/src/attributes/avro.rs b/avro_derive/src/attributes/avro.rs index fdf755e..f4380f5 100644 --- a/avro_derive/src/attributes/avro.rs +++ b/avro_derive/src/attributes/avro.rs @@ -57,6 +57,22 @@ pub struct ContainerAttributes { /// Set the default value if this schema is used as a field #[darling(default)] pub default: Option<String>, + #[darling(default)] + pub repr: Option<EnumRepr>, +} + +/// What kind of schema to use for an enum. +#[derive(Debug, FromMeta, PartialEq, Default, Clone, Copy)] +pub enum EnumRepr { + /// Create a `Schema::Enum`, only works for plain enums. + #[default] + Enum, + /// Untagged + Union, + /// Externally tagged + UnionOfRecords, + /// Adjacently tagged (`#[serde(tag = "type", content = "value")]`) + DiscriminatorValue, } impl ContainerAttributes { diff --git a/avro_derive/src/attributes/mod.rs b/avro_derive/src/attributes/mod.rs index 7dbab53..5b08f3d 100644 --- a/avro_derive/src/attributes/mod.rs +++ b/avro_derive/src/attributes/mod.rs @@ -24,6 +24,20 @@ use syn::{AttrStyle, Attribute, Expr, Ident, Path, spanned::Spanned}; mod avro; mod serde; +/// What kind of schema to use for an enum. +#[derive(Debug, PartialEq, Default)] +pub enum EnumRepr { + /// Create a `Schema::Enum`, only works for plain enums. + #[default] + Enum, + /// Untagged + Union, + /// Externally tagged + UnionOfRecords, + /// Adjacently tagged (`#[serde(tag = "type", content = "value")]`) + DiscriminatorValue { tag: String, content: String }, +} + #[derive(Default)] pub struct NamedTypeOptions { pub name: String, @@ -32,6 +46,7 @@ pub struct NamedTypeOptions { pub rename_all: RenameRule, pub transparent: bool, pub default: TokenStream, + pub repr: Option<EnumRepr>, } impl NamedTypeOptions { @@ -52,15 +67,10 @@ impl NamedTypeOptions { let mut errors = Vec::new(); // Check for any Serde attributes that are hard errors - if serde.tag.is_some() - || serde.content.is_some() - || serde.untagged - || serde.variant_identifier - || serde.field_identifier - { + if serde.variant_identifier || serde.field_identifier { errors.push(syn::Error::new( span, - "AvroSchema derive does not support changing the tagging Serde generates (`tag`, `content`, `untagged`, `variant_identifier`, `field_identifier`)", + "AvroSchema derive does not support `variant_identifier` and `field_identifier`", )); } if serde.remote.is_some() { @@ -80,13 +90,13 @@ impl NamedTypeOptions { if avro.name.is_some() && avro.name != serde.rename { errors.push(syn::Error::new( span, - "#[avro(name = \"..\")] must match #[serde(rename = \"..\")], it's also deprecated. Please use only `#[serde(rename = \"..\")]`", + "`#[avro(name = \"..\")]` must match `#[serde(rename = \"..\")]`, it's also deprecated. Please use only `#[serde(rename = \"..\")]`", )); } if avro.rename_all != RenameRule::None && serde.rename_all.serialize != avro.rename_all { errors.push(syn::Error::new( span, - "#[avro(rename_all = \"..\")] must match #[serde(rename_all = \"..\")], it's also deprecated. Please use only `#[serde(rename_all = \"..\")]`", + "`#[avro(rename_all = \"..\")]` must match `#[serde(rename_all = \"..\")]`, it's also deprecated. Please use only `#[serde(rename_all = \"..\")]`", )); } if serde.transparent @@ -101,10 +111,86 @@ impl NamedTypeOptions { { errors.push(syn::Error::new( span, - "AvroSchema: #[serde(transparent)] is incompatible with all other attributes", + "AvroSchema: `#[serde(transparent)]` is incompatible with all other attributes", )); } + let repr = if let Some(repr) = avro.repr { + match repr { + avro::EnumRepr::Enum => { + if serde.tag.is_some() || serde.content.is_some() || serde.untagged { + errors.push(syn::Error::new( + span, + r#"AvroSchema: `#[avro(repr = "enum")]` is incompatible with `#[serde(tag = "..")]`, `#[serde(content = "..")]`, and `#[serde(untagged)]`"#, + )); + } + Some(EnumRepr::Enum) + } + avro::EnumRepr::Union => { + if serde.tag.is_some() || serde.content.is_some() { + errors.push(syn::Error::new( + span, + r#"AvroSchema: `#[avro(repr = "union")]` is incompatible with `#[serde(tag = "..")]` and `#[serde(content = "..")]`"#, + )); + } + Some(EnumRepr::Union) + } + avro::EnumRepr::UnionOfRecords => { + if serde.tag.is_some() || serde.content.is_some() || serde.untagged { + errors.push(syn::Error::new( + span, + r#"AvroSchema: `#[avro(repr = "union_of_records")]` is incompatible with `#[serde(tag = "..")]`, `#[serde(content = "..")]`, and `#[serde(untagged)]`"#, + )); + } + Some(EnumRepr::UnionOfRecords) + } + avro::EnumRepr::DiscriminatorValue => { + if serde.untagged || serde.tag.is_none() || serde.content.is_none() { + errors.push(syn::Error::new( + span, + r#"AvroSchema: `#[avro(repr = "discriminator_value")]` requires `#[serde(tag = "..", content = "..")]`"#, + )); + } + let tag = serde.tag.unwrap_or("error".to_string()); + let content = serde.content.unwrap_or("error".to_string()); + Some(EnumRepr::DiscriminatorValue { tag, content }) + } + } + } else { + if serde.tag.is_some() && serde.content.is_some() { + let tag = serde.tag.unwrap_or("type".to_string()); + let content = serde.content.unwrap_or("value".to_string()); + Some(EnumRepr::DiscriminatorValue { tag, content }) + } else if serde.untagged { + Some(EnumRepr::Union) + } else if serde.tag.is_some() != serde.content.is_some() { + errors.push(syn::Error::new( + span, + r#"AvroSchema does not support `#[serde(tag = "..")]` without `#[serde(content = "..")]`"#, + )); + None + } else { + None + } + }; + + let default = match avro.default { + None => quote! { None }, + Some(default_value) => { + if let Err(err) = serde_json::from_str::<serde_json::Value>(&default_value[..]) { + errors.push(syn::Error::new( + ident.span(), + format!("Invalid Avro `default` JSON: \n{err}"), + )); + quote! { None } + } else { + quote! { + Some(serde_json::from_str(#default_value).expect(format!("Invalid JSON: {:?}", #default_value).as_str())) + } + } + } + }; + if !errors.is_empty() { return Err(errors); } @@ -118,22 +204,6 @@ impl NamedTypeOptions { let doc = avro.doc.or_else(|| extract_rustdoc(attributes)); - let default = match avro.default { - None => quote! { None }, - Some(default_value) => { - let _: serde_json::Value = - serde_json::from_str(&default_value[..]).map_err(|e| { - vec![syn::Error::new( - ident.span(), - format!("Invalid Avro `default` JSON: \n{e}"), - )] - })?; - quote! { - Some(serde_json::from_str(#default_value).expect(format!("Invalid JSON: {:?}", #default_value).as_str())) - } - } - }; - Ok(Self { name: full_schema_name, doc, @@ -141,6 +211,7 @@ impl NamedTypeOptions { rename_all: serde.rename_all.serialize, transparent: serde.transparent, default, + repr, }) } } diff --git a/avro_derive/src/enums/discriminator_value.rs b/avro_derive/src/enums/discriminator_value.rs new file mode 100644 index 0000000..f71cbe8 --- /dev/null +++ b/avro_derive/src/enums/discriminator_value.rs @@ -0,0 +1,130 @@ +use crate::attributes::{NamedTypeOptions, VariantOptions}; +use crate::case::RenameRule; +use crate::{aliases, preserve_optional, type_to_schema_expr}; +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use syn::spanned::Spanned; +use syn::{DataEnum, Fields}; + +pub fn get_data_enum_schema_def( + container_attrs: &NamedTypeOptions, + data_enum: DataEnum, + tag: &str, + content: &str, +) -> Result<TokenStream, Vec<syn::Error>> { + let doc = preserve_optional(container_attrs.doc.as_ref()); + let enum_aliases = aliases(&container_attrs.aliases); + let mut symbols = Vec::new(); + let mut schema_definitions = Vec::new(); + for variant in &data_enum.variants { + let field_attrs = VariantOptions::new(&variant.attrs, variant.span())?; + let name = match (field_attrs.rename, container_attrs.rename_all) { + (Some(rename), _) => rename, + (None, rename_all) if !matches!(rename_all, RenameRule::None) => { + rename_all.apply_to_variant(&variant.ident.to_string()) + } + _ => variant.ident.to_string(), + }; + match &variant.fields { + Fields::Named(named) => { + let mut fields = Vec::with_capacity(named.named.len()); + for (index, field) in named.named.iter().enumerate() { + let ident = field.ident.as_ref().unwrap().to_string(); + let schema_expr = type_to_schema_expr(&field.ty)?; + fields.push(quote! { + ::apache_avro::schema::RecordField::builder() + .name(#ident.to_string()) + .schema(#schema_expr) + .position(#index) + .build() + }); + } + + let schema_expr = quote! { + ::apache_avro::schema::Schema::Record( + ::apache_avro::schema::RecordSchema::builder() + .name(::apache_avro::schema::Name::new(#name).expect(&format!("Unable to parse variant record name for schema {}", #name)[..]).fully_qualified_name(&full_schema_name.namespace)) + .fields(vec![ + #(#fields,)* + ]) + .build() + ) + }; + schema_definitions.push(schema_expr); + } + Fields::Unnamed(unnamed) => { + if unnamed.unnamed.is_empty() { + // TODO: Maybe replace this with an empty Record? + schema_definitions.push(quote! { ::apache_avro::schema::Schema::Null }) + } else if unnamed.unnamed.len() == 1 { + let only_one = unnamed.unnamed.iter().next().expect("There is one"); + let schema_expr = type_to_schema_expr(&only_one.ty)?; + schema_definitions.push(schema_expr); + } else if unnamed.unnamed.len() > 1 { + let mut fields = Vec::with_capacity(unnamed.unnamed.len()); + for (index, field) in unnamed.unnamed.iter().enumerate() { + let schema_expr = type_to_schema_expr(&field.ty)?; + fields.push(quote! { + ::apache_avro::schema::RecordField::builder() + .name(#index.to_string()) + .schema(#schema_expr) + .position(#index) + .build() + }); + } + + let schema_expr = quote! { + ::apache_avro::schema::Schema::Record( + ::apache_avro::schema::RecordSchema::builder() + .name(::apache_avro::schema::Name::new(#name).expect(&format!("Unable to parse variant record name for schema {}", #name)[..]).fully_qualified_name(&full_schema_name.namespace)) + .fields(vec![ + #(#fields,)* + ]) + .build() + ) + }; + schema_definitions.push(schema_expr); + } + } + Fields::Unit => schema_definitions.push(quote! { ::apache_avro::schema::Schema::Null }), + } + symbols.push(name); + } + let full_schema_name = &container_attrs.name; + Ok(quote! { + let full_schema_name = ::apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to parse name for schema {}", #full_schema_name)[..]); + let mut builder = ::apache_avro::schema::UnionSchema::builder(); + + #( + builder.variant_ignore_duplicates(#schema_definitions).expect("Unions cannot have duplicates"); + )* + + let content_contains_null = builder.contains(::apache_avro::schema::Schema::Null); + let content_schema = ::apache_avro::schema::Schema::Union(builder.build()); + + let tag_name = ::apache_avro::schema::Name::new(#tag).expect(&format!("Unable to parse name for schema tag {}", #tag)[..]); + let tag_schema = ::apache_avro::schema::Schema::r#enum(tag_name, vec![#(#symbols.to_owned()),*]).build(); + + let mut fields = ::std::vec::Vec::with_capacity(2); + fields.push(::apache_avro::schema::RecordField::builder() + .name(#tag) + .schema(tag_schema) + .position(0) + .build() + ); + fields.push(::apache_avro::schema::RecordField::builder() + .name(#content) + .schema(content_schema) + .maybe_default(if content_contains_null { Some(::serde_json::Value::Null) } else { None }) + .position(1) + .build() + ); + ::apache_avro::schema::Schema::Record(::apache_avro::schema::RecordSchema::builder() + .name(full_schema_name) + .maybe_aliases(#enum_aliases) + .maybe_doc(#doc) + .fields(fields) + .build() + ) + }) +} diff --git a/avro_derive/src/enums/mod.rs b/avro_derive/src/enums/mod.rs new file mode 100644 index 0000000..d77fd99 --- /dev/null +++ b/avro_derive/src/enums/mod.rs @@ -0,0 +1,66 @@ +use crate::attributes::{EnumRepr, NamedTypeOptions}; +use proc_macro2::{Ident, Span, TokenStream}; +use syn::{Attribute, DataEnum, Fields, Meta}; + +mod discriminator_value; +mod plain; +mod union; +mod union_of_records; + +/// Generate a schema definition for a enum. +pub fn get_data_enum_schema_def( + container_attrs: &NamedTypeOptions, + data_enum: DataEnum, + ident_span: Span, +) -> Result<TokenStream, Vec<syn::Error>> { + match &container_attrs.repr { + None => { + if data_enum.variants.iter().all(|v| Fields::Unit == v.fields) { + plain::get_data_enum_schema_def(container_attrs, data_enum, ident_span) + } else { + union_of_records::get_data_enum_schema_def(container_attrs, data_enum, ident_span) + } + } + Some(EnumRepr::Enum) => { + plain::get_data_enum_schema_def(container_attrs, data_enum, ident_span) + } + Some(EnumRepr::Union) => { + union::get_data_enum_schema_def(container_attrs, data_enum, ident_span) + } + Some(EnumRepr::UnionOfRecords) => { + union_of_records::get_data_enum_schema_def(container_attrs, data_enum, ident_span) + } + Some(EnumRepr::DiscriminatorValue { tag, content }) => { + discriminator_value::get_data_enum_schema_def(container_attrs, data_enum, tag, content) + } + } +} + +fn default_enum_variant( + data_enum: &syn::DataEnum, + error_span: Span, +) -> Result<Option<String>, Vec<syn::Error>> { + match data_enum + .variants + .iter() + .filter(|v| v.attrs.iter().any(is_default_attr)) + .collect::<Vec<_>>() + { + variants if variants.is_empty() => Ok(None), + single if single.len() == 1 => Ok(Some(single[0].ident.to_string())), + multiple => Err(vec![syn::Error::new( + error_span, + format!( + "Multiple defaults defined: {:?}", + multiple + .iter() + .map(|v| v.ident.to_string()) + .collect::<Vec<String>>() + ), + )]), + } +} + +fn is_default_attr(attr: &Attribute) -> bool { + matches!(attr, Attribute { meta: Meta::Path(path), .. } if path.get_ident().map(Ident::to_string).as_deref() == Some("default")) +} diff --git a/avro_derive/src/enums/plain.rs b/avro_derive/src/enums/plain.rs new file mode 100644 index 0000000..3c974f5 --- /dev/null +++ b/avro_derive/src/enums/plain.rs @@ -0,0 +1,49 @@ +use crate::attributes::{NamedTypeOptions, VariantOptions}; +use crate::case::RenameRule; +use crate::enums::default_enum_variant; +use crate::{aliases, preserve_optional}; +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use syn::spanned::Spanned; +use syn::{DataEnum, Fields}; + +pub fn get_data_enum_schema_def( + container_attrs: &NamedTypeOptions, + data_enum: DataEnum, + ident_span: Span, +) -> Result<TokenStream, Vec<syn::Error>> { + let doc = preserve_optional(container_attrs.doc.as_ref()); + let enum_aliases = aliases(&container_attrs.aliases); + if data_enum.variants.iter().all(|v| Fields::Unit == v.fields) { + let default_value = default_enum_variant(&data_enum, ident_span)?; + let default = preserve_optional(default_value); + let mut symbols = Vec::new(); + for variant in &data_enum.variants { + let field_attrs = VariantOptions::new(&variant.attrs, variant.span())?; + let name = match (field_attrs.rename, container_attrs.rename_all) { + (Some(rename), _) => rename, + (None, rename_all) if !matches!(rename_all, RenameRule::None) => { + rename_all.apply_to_variant(&variant.ident.to_string()) + } + _ => variant.ident.to_string(), + }; + symbols.push(name); + } + let full_schema_name = &container_attrs.name; + Ok(quote! { + apache_avro::schema::Schema::Enum(apache_avro::schema::EnumSchema { + name: apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to parse enum name for schema {}", #full_schema_name)[..]), + aliases: #enum_aliases, + doc: #doc, + symbols: vec![#(#symbols.to_owned()),*], + default: #default, + attributes: Default::default(), + }) + }) + } else { + Err(vec![syn::Error::new( + ident_span, + r#"AvroSchema: `#[avro(repr = "enum")]` does not work for enums with non-unit variants"#, + )]) + } +} diff --git a/avro_derive/src/enums/union.rs b/avro_derive/src/enums/union.rs new file mode 100644 index 0000000..7b92631 --- /dev/null +++ b/avro_derive/src/enums/union.rs @@ -0,0 +1,14 @@ +use crate::attributes::NamedTypeOptions; +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use syn::DataEnum; + +pub fn get_data_enum_schema_def( + container_attrs: &NamedTypeOptions, + data_enum: DataEnum, + ident_span: Span, +) -> Result<TokenStream, Vec<syn::Error>> { + Ok(quote! { + panic!("Hello world") + }) +} diff --git a/avro_derive/src/enums/union_of_records.rs b/avro_derive/src/enums/union_of_records.rs new file mode 100644 index 0000000..7b92631 --- /dev/null +++ b/avro_derive/src/enums/union_of_records.rs @@ -0,0 +1,14 @@ +use crate::attributes::NamedTypeOptions; +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use syn::DataEnum; + +pub fn get_data_enum_schema_def( + container_attrs: &NamedTypeOptions, + data_enum: DataEnum, + ident_span: Span, +) -> Result<TokenStream, Vec<syn::Error>> { + Ok(quote! { + panic!("Hello world") + }) +} diff --git a/avro_derive/src/lib.rs b/avro_derive/src/lib.rs index f0bcbc2..e11cbab 100644 --- a/avro_derive/src/lib.rs +++ b/avro_derive/src/lib.rs @@ -31,6 +31,7 @@ mod attributes; mod case; +mod enums; use proc_macro2::{Span, TokenStream}; use quote::quote; @@ -60,6 +61,12 @@ fn derive_avro_schema(input: DeriveInput) -> Result<TokenStream, Vec<syn::Error> match input.data { syn::Data::Struct(data_struct) => { let named_type_options = NamedTypeOptions::new(&input.ident, &input.attrs, input_span)?; + if named_type_options.repr.is_some() { + return Err(vec![syn::Error::new( + input_span, + r#"AvroSchema: `#[avro(repr = "..")]`, `#[serde(tag = "..")]`, `#[serde(content = "..")]`, and `#[serde(untagged)]` are only supported on enums"#, + )]); + } let (get_schema_impl, get_record_fields_impl) = if named_type_options.transparent { get_transparent_struct_schema_def(data_struct.fields, input_span)? } else { @@ -86,8 +93,11 @@ fn derive_avro_schema(input: DeriveInput) -> Result<TokenStream, Vec<syn::Error> "AvroSchema: `#[serde(transparent)]` is only supported on structs", )]); } - let schema_def = - get_data_enum_schema_def(&named_type_options, data_enum, input.ident.span())?; + let schema_def = enums::get_data_enum_schema_def( + &named_type_options, + data_enum, + input.ident.span(), + )?; let inner = handle_named_schemas(named_type_options.name, schema_def); Ok(create_trait_definition( input.ident, @@ -99,7 +109,7 @@ fn derive_avro_schema(input: DeriveInput) -> Result<TokenStream, Vec<syn::Error> } syn::Data::Union(_) => Err(vec![syn::Error::new( input_span, - "AvroSchema: derive only works for structs and simple enums", + "AvroSchema: derive only works for structs and enums", )]), } } @@ -383,48 +393,6 @@ fn get_field_get_record_fields_expr( } } -/// Generate a schema definition for a enum. -fn get_data_enum_schema_def( - container_attrs: &NamedTypeOptions, - data_enum: DataEnum, - ident_span: Span, -) -> Result<TokenStream, Vec<syn::Error>> { - let doc = preserve_optional(container_attrs.doc.as_ref()); - let enum_aliases = aliases(&container_attrs.aliases); - if data_enum.variants.iter().all(|v| Fields::Unit == v.fields) { - let default_value = default_enum_variant(&data_enum, ident_span)?; - let default = preserve_optional(default_value); - let mut symbols = Vec::new(); - for variant in &data_enum.variants { - let field_attrs = VariantOptions::new(&variant.attrs, variant.span())?; - let name = match (field_attrs.rename, container_attrs.rename_all) { - (Some(rename), _) => rename, - (None, rename_all) if !matches!(rename_all, RenameRule::None) => { - rename_all.apply_to_variant(&variant.ident.to_string()) - } - _ => variant.ident.to_string(), - }; - symbols.push(name); - } - let full_schema_name = &container_attrs.name; - Ok(quote! { - apache_avro::schema::Schema::Enum(apache_avro::schema::EnumSchema { - name: apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to parse enum name for schema {}", #full_schema_name)[..]), - aliases: #enum_aliases, - doc: #doc, - symbols: vec![#(#symbols.to_owned()),*], - default: #default, - attributes: Default::default(), - }) - }) - } else { - Err(vec![syn::Error::new( - ident_span, - "AvroSchema: derive does not work for enums with non unit structs", - )]) - } -} - /// Takes in the Tokens of a type and returns the tokens of an expression with return type `Schema` fn type_to_schema_expr(ty: &Type) -> Result<TokenStream, Vec<syn::Error>> { match ty { @@ -492,35 +460,6 @@ fn type_to_field_default_expr(ty: &Type) -> Result<TokenStream, Vec<syn::Error>> } } -fn default_enum_variant( - data_enum: &syn::DataEnum, - error_span: Span, -) -> Result<Option<String>, Vec<syn::Error>> { - match data_enum - .variants - .iter() - .filter(|v| v.attrs.iter().any(is_default_attr)) - .collect::<Vec<_>>() - { - variants if variants.is_empty() => Ok(None), - single if single.len() == 1 => Ok(Some(single[0].ident.to_string())), - multiple => Err(vec![syn::Error::new( - error_span, - format!( - "Multiple defaults defined: {:?}", - multiple - .iter() - .map(|v| v.ident.to_string()) - .collect::<Vec<String>>() - ), - )]), - } -} - -fn is_default_attr(attr: &Attribute) -> bool { - matches!(attr, Attribute { meta: Meta::Path(path), .. } if path.get_ident().map(Ident::to_string).as_deref() == Some("default")) -} - /// Stolen from serde fn to_compile_errors(errors: Vec<syn::Error>) -> proc_macro2::TokenStream { let compile_errors = errors.iter().map(syn::Error::to_compile_error); diff --git a/avro_derive/tests/enum.rs b/avro_derive/tests/enum.rs new file mode 100644 index 0000000..065f06a --- /dev/null +++ b/avro_derive/tests/enum.rs @@ -0,0 +1,134 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use apache_avro::{AvroSchema, Error, Reader, Schema, Writer, from_value}; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; + +/// Takes in a type that implements the right combination of traits and runs it through a Serde +/// round-trip and asserts the result is the same. +fn serde_assert<T>(obj: T) +where + T: std::fmt::Debug + Serialize + DeserializeOwned + AvroSchema + Clone + PartialEq, +{ + assert_eq!(obj, serde(obj.clone()).unwrap()); +} + +/// Takes in a type that implements the right combination of traits and runs it through a Serde +/// round-trip and asserts that the error matches the expected string. +fn serde_assert_err<T>(obj: T, expected: &str) +where + T: std::fmt::Debug + Serialize + DeserializeOwned + AvroSchema + Clone + PartialEq, +{ + let error = serde(obj).unwrap_err().to_string(); + assert!( + error.contains(expected), + "Error `{error}` does not contain `{expected}`" + ); +} + +fn serde<T>(obj: T) -> Result<T, Error> +where + T: Serialize + DeserializeOwned + AvroSchema, +{ + de(ser(obj)?) +} + +fn ser<T>(obj: T) -> Result<Vec<u8>, Error> +where + T: Serialize + AvroSchema, +{ + let schema = T::get_schema(); + let mut writer = Writer::new(&schema, Vec::new())?; + writer.append_ser(obj)?; + writer.into_inner() +} + +fn de<T>(encoded: Vec<u8>) -> Result<T, Error> +where + T: DeserializeOwned + AvroSchema, +{ + assert!(!encoded.is_empty()); + let schema = T::get_schema(); + let mut reader = Reader::builder(&encoded[..]) + .reader_schema(&schema) + .build()?; + if let Some(res) = reader.next() { + return res.and_then(|v| from_value::<T>(&v)); + } + panic!("Nothing was encoded!") +} + +#[test] +fn avro_rs_xxx_enum_repr_default() { + #[derive(AvroSchema, Debug, Serialize, Deserialize, Clone, PartialEq)] + enum Foo { + A, + B, + C, + } + + assert!(matches!(Foo::get_schema(), Schema::Enum(_))); + serde_assert(Foo::A); +} + +#[test] +fn avro_rs_xxx_enum_repr_enum() { + #[derive(AvroSchema, Debug, Serialize, Deserialize, Clone, PartialEq)] + #[avro(repr = "enum")] + enum Foo { + A, + B, + C, + } + + assert!(matches!(Foo::get_schema(), Schema::Enum(_))); + serde_assert(Foo::A); +} + +#[test] +fn avro_rs_xxx_enum_repr_discriminator_value_plain() { + #[derive(AvroSchema, Debug, Serialize, Deserialize, Clone, PartialEq)] + #[avro(repr = "discriminator_value")] + #[serde(tag = "type", content = "value")] + enum Foo { + A, + B, + C, + } + + assert!(matches!(Foo::get_schema(), Schema::Record(_))); + println!("{:#?}", Foo::get_schema()); + serde_assert(Foo::A); +} + +#[test] +fn avro_rs_xxx_enum_repr_discriminator_value_tuple() { + #[derive(AvroSchema, Debug, Serialize, Deserialize, Clone, PartialEq)] + #[avro(repr = "discriminator_value")] + #[serde(tag = "type", content = "value")] + enum Foo { + A(), + B(String), + C(String, bool), + } + + assert!(matches!(Foo::get_schema(), Schema::Record(_))); + println!("{:#?}", Foo::get_schema()); + serde_assert(Foo::A()); + serde_assert(Foo::B("Something".to_string())); + serde_assert(Foo::C("Something".to_string(), true)); +}
