Hi Simon,

Yes, we are interested in any and all kinds of improvements! Please create a JIRA issue and send a pull request with the diff above. Thank you!
Martin On Sat, Nov 4, 2023 at 3:25 AM Simon Gittins <sgitt...@gmail.com> wrote: > Hello > > I've written some failing rust unit tests showing: > - Non-trivial union schemas (without null variant) failing to > encode/decode round-trip (serde) > - Non-trivial union schemas including null variant at position 0 failing > to round-trip (serde) > - Non-trivial union schemas including null variant at position other than 0 > failing to round-trip (serde) > > I've also attached and appended an update to the encoder that makes the > above tests pass. Is this a patch that the team is interested in? > > Thanks > Simon > > diff --git a/lang/rust/avro/src/encode.rs b/lang/rust/avro/src/encode.rs > index 4593779ac..829a8ee6c 100644 > --- a/lang/rust/avro/src/encode.rs > +++ b/lang/rust/avro/src/encode.rs > @@ -19,7 +19,7 @@ use crate::{ > decimal::serialize_big_decimal, > schema::{ > DecimalSchema, EnumSchema, FixedSchema, Name, Namespace, > RecordSchema, ResolvedSchema, > - Schema, SchemaKind, > + UnionSchema, Schema, SchemaKind, > }, > types::{Value, ValueKind}, > util::{zig_i32, zig_i64}, > @@ -71,7 +71,20 @@ pub(crate) fn encode_internal<S: Borrow<Schema>>( > } > > match value { > - Value::Null => (), > + Value::Null => { > + match schema { > + Schema::Union(s) => { > + match s.schemas.iter().position(|sch|*sch == > Schema::Null) { > + None => > + return Err(Error::EncodeValueAsSchemaError { > + value_kind: ValueKind::Null, > + supported_schema: vec![SchemaKind::Null, > SchemaKind::Union], }), > + Some(p) => encode_long(p as i64, buffer), > + } > + } > + _ => () > + } > + }, > Value::Boolean(b) => buffer.push(u8::from(*b)), > // Pattern | Pattern here to signify that these _must_ have the > same encoding. > Value::Int(i) | Value::Date(i) | Value::TimeMillis(i) => > encode_int(*i, buffer), > @@ -242,6 +255,21 @@ pub(crate) fn encode_internal<S: Borrow<Schema>>( > )); > } > } > + } else if let Schema::Union(UnionSchema{ schemas, .. 
}) = > schema { > + let original_size = buffer.len(); > + for (index,s) in schemas.iter().enumerate() { > + encode_long(index as i64, buffer); > + match encode_internal(value, s.borrow(), names, > enclosing_namespace, buffer) { > + Ok(_) => return Ok(()), > + Err(e) => { > + buffer.truncate(original_size); //undo any > partial encoding > + } > + } > + } > + return Err(Error::EncodeValueAsSchemaError { > + value_kind: ValueKind::Record, > + supported_schema: vec![SchemaKind::Record, > SchemaKind::Union], > + }); > } else { > error!("invalid schema type for Record: {:?}", schema); > return Err(Error::EncodeValueAsSchemaError { > diff --git a/lang/rust/avro/tests/union_schema.rs b/lang/rust/avro/tests/ > union_schema.rs > index e69de29bb..1dc19d25c 100644 > --- a/lang/rust/avro/tests/union_schema.rs > +++ b/lang/rust/avro/tests/union_schema.rs > @@ -0,0 +1,185 @@ > +// Licensed to the Apache Software Foundation (ASF) under one > +// or more contributor license agreements. See the NOTICE file > +// distributed with this work for additional information > +// regarding copyright ownership. The ASF licenses this file > +// to you under the Apache License, Version 2.0 (the > +// "License"); you may not use this file except in compliance > +// with the License. You may obtain a copy of the License at > +// > +// http://www.apache.org/licenses/LICENSE-2.0 > +// > +// Unless required by applicable law or agreed to in writing, > +// software distributed under the License is distributed on an > +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY > +// KIND, either express or implied. See the License for the > +// specific language governing permissions and limitations > +// under the License. 
> + > +use serde::{Deserialize, Serialize}; > +use serde::de::DeserializeOwned; > +use apache_avro::{from_avro_datum, to_avro_datum, to_value, from_value, > types, Schema, Writer, Reader, Codec}; > + > + > +static SCHEMA_A_STR: &str = r#"{ > + "name": "A", > + "type": "record", > + "fields": [ > + {"name": "field_a", "type": "float"} > + ] > + }"#; > + > +static SCHEMA_B_STR: &str = r#"{ > + "name": "B", > + "type": "record", > + "fields": [ > + {"name": "field_b", "type": "long"} > + ] > + }"#; > + > +static SCHEMA_C_STR: &str = r#"{ > + "name": "C", > + "type": "record", > + "fields": [ > + {"name": "field_union", "type": ["A", "B"]}, > + {"name": "field_c", "type": "string"} > + ] > + }"#; > + > +#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)] > +struct A { > + field_a: f32, > +} > + > +#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)] > +struct B { > + field_b: i64, > +} > + > +#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)] > +#[serde(untagged)] > +enum UnionAB { > + A(A), > + B(B), > +} > + > +#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)] > +struct C { > + field_union: UnionAB, > + field_c: String > +} > + > +fn encode_decode<T> (input: &T,schema: &Schema,schemata: &Vec<Schema>) -> > T > + where T: DeserializeOwned + Serialize { > + let mut encoded: Vec<u8> = Vec::new(); > + let mut writer = Writer::with_schemata(&schema, > schemata.iter().collect(), &mut encoded, Codec::Null); > + writer.append_ser((input.clone())).unwrap(); > + writer.flush().unwrap(); > + > + let mut reader = Reader::with_schemata(schema, > schemata.iter().collect(), encoded.as_slice()).unwrap(); > + from_value::<T>(&reader.next().unwrap().unwrap()).unwrap() > +} > + > + > +#[test] > +fn union_schema_round_trip_no_null() { > + let schemata: Vec<Schema> = Schema::parse_list(&[SCHEMA_A_STR, > SCHEMA_B_STR, SCHEMA_C_STR]).expect("parsing schemata"); > + > + { > + let input = C { field_union: (UnionAB::A(A { field_a: 45.5 })), > field_c: "foo".to_string() }; 
> + let output = encode_decode(&input,&schemata[2],&schemata); > + assert_eq!(input,output); > + } > + { > + let input = C { field_union: (UnionAB::B(B { field_b: 73 })), > field_c: "bar".to_string() }; > + let output = encode_decode(&input,&schemata[2],&schemata); > + assert_eq!(input,output); > + } > +} > + > +static SCHEMA_D_STR: &str = r#"{ > + "name": "D", > + "type": "record", > + "fields": [ > + {"name": "field_union", "type": ["null", "A", "B"]}, > + {"name": "field_d", "type": "string"} > + ] > + }"#; > + > +#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)] > +#[serde(untagged)] > +enum UnionNoneAB { > + None, > + A(A), > + B(B), > +} > + > +#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)] > +struct D { > + field_union: UnionNoneAB, > + field_d: String > +} > + > +#[test] > +fn union_schema_round_trip_null_at_start() { > + let schemata: Vec<Schema> = Schema::parse_list(&[SCHEMA_A_STR, > SCHEMA_B_STR, SCHEMA_D_STR]).expect("parsing schemata"); > + > + { > + let input = D { field_union: UnionNoneAB::A(A { field_a: 54.25 > }), field_d: "fooy".to_string() }; > + let output = encode_decode(&input,&schemata[2],&schemata); > + assert_eq!(input,output); > + } > + { > + let input = D { field_union: UnionNoneAB::None, field_d: > "fooyy".to_string() }; > + let output = encode_decode(&input,&schemata[2],&schemata); > + assert_eq!(input,output); > + } > + { > + let input = D { field_union: UnionNoneAB::B(B { field_b: 103 }), > field_d: "foov".to_string() }; > + let output = encode_decode(&input,&schemata[2],&schemata); > + assert_eq!(input,output); > + } > +} > + > +static SCHEMA_E_STR: &str = r#"{ > + "name": "E", > + "type": "record", > + "fields": [ > + {"name": "field_union", "type": ["A", "null", "B"]}, > + {"name": "field_e", "type": "string"} > + ] > + }"#; > + > +#[derive(Serialize,Deserialize,Clone, PartialEq, Debug)] > +#[serde(untagged)] > +enum UnionANoneB { > + A(A), > + None, > + B(B), > +} > + > +#[derive(Serialize,Deserialize,Clone, 
PartialEq, Debug)] > +struct E { > + field_union: UnionANoneB, > + field_e: String > +} > + > +#[test] > +fn union_schema_round_trip_with_out_of_order_null() { > + let schemata: Vec<Schema> = Schema::parse_list(&[SCHEMA_A_STR, > SCHEMA_B_STR, SCHEMA_E_STR]).expect("parsing schemata"); > + > + { > + let input = E { field_union: UnionANoneB::A(A { field_a: 23.75 > }), field_e: "barme".to_string() }; > + let output = encode_decode(&input,&schemata[2],&schemata); > + assert_eq!(input,output); > + } > + { > + let input = E { field_union: UnionANoneB::None, field_e: > "barme2".to_string() }; > + let output = encode_decode(&input,&schemata[2],&schemata); > + assert_eq!(input,output); > + } > + { > + let input = E { field_union: UnionANoneB::B(B { field_b: 89 }), > field_e: "barme3".to_string() }; > + let output = encode_decode(&input,&schemata[2],&schemata); > + assert_eq!(input,output); > + } > +} > > > >