This is an automated email from the ASF dual-hosted git repository. alamb pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push: new 8c80fe17ed Added arrow-avro enum mapping support for schema resolution (#8223) 8c80fe17ed is described below commit 8c80fe17edfb85c1c6a9b57abb25155cb1288631 Author: Connor Sanders <con...@elastiflow.com> AuthorDate: Sat Sep 6 04:32:35 2025 -0500 Added arrow-avro enum mapping support for schema resolution (#8223) # Which issue does this PR close? - Part of https://github.com/apache/arrow-rs/issues/4886 - Follows up on https://github.com/apache/arrow-rs/pull/8047 # Rationale for this change Avro `enum` values are **encoded by index** but are **semantically identified by symbol name**. During schema evolution it is legal for the writer and reader to use different enum symbol *orders* so long as the **symbol set is compatible**. The Avro specification requires that, when resolving a writer enum against a reader enum, the value be mapped **by symbol name**, not by the writer’s numeric index. If the writer’s symbol is not present in the reader’s enum and the reader defines a default, the default is used; otherwise it is an error. # What changes are included in this PR? **Core changes** - Implement **writer to reader enum symbol remapping**: - Build a fast lookup table at schema resolution time from **writer enum index to reader enum index** using symbol **names**. - Apply this mapping during decode so the produced Arrow dictionary keys always reference the **reader’s** symbol order. - If a writer symbol is not found in the reader enum, map it to the reader enum's `default` symbol when one is defined; otherwise surface a clear error. # Are these changes tested? Yes. This PR adds comprehensive **unit tests** for enum mapping in `reader/record.rs` and a **real‑file integration test** in `reader/mod.rs` using `avro/simple_enum.avro`. # Are there any user-facing changes? N/A due to `arrow-avro` not being public yet. 
--- arrow-avro/src/codec.rs | 312 +++++++++++++++++++++++++++------------- arrow-avro/src/reader/mod.rs | 93 ++++++++++++ arrow-avro/src/reader/record.rs | 159 ++++++++++++++++++-- 3 files changed, 454 insertions(+), 110 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index bf2ee6deab..d19e9b8ccc 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -16,7 +16,7 @@ // under the License. use crate::schema::{ - Attributes, AvroSchema, ComplexType, PrimitiveType, Record, Schema, Type, TypeName, + Attributes, AvroSchema, ComplexType, Enum, PrimitiveType, Record, Schema, Type, TypeName, AVRO_ENUM_SYMBOLS_METADATA_KEY, }; use arrow_schema::{ @@ -48,7 +48,7 @@ pub(crate) enum ResolutionInfo { Promotion(Promotion), /// Indicates that a default value should be used for a field. (Implemented in a Follow-up PR) DefaultValue(AvroLiteral), - /// Provides mapping information for resolving enums. (Implemented in a Follow-up PR) + /// Provides mapping information for resolving enums. EnumMapping(EnumMapping), /// Provides resolution information for record fields. 
(Implemented in a Follow-up PR) Record(ResolvedRecord), @@ -587,6 +587,63 @@ impl<'a> Resolver<'a> { } } +fn names_match( + writer_name: &str, + writer_aliases: &[&str], + reader_name: &str, + reader_aliases: &[&str], +) -> bool { + writer_name == reader_name + || reader_aliases.contains(&writer_name) + || writer_aliases.contains(&reader_name) +} + +fn ensure_names_match( + data_type: &str, + writer_name: &str, + writer_aliases: &[&str], + reader_name: &str, + reader_aliases: &[&str], +) -> Result<(), ArrowError> { + if names_match(writer_name, writer_aliases, reader_name, reader_aliases) { + Ok(()) + } else { + Err(ArrowError::ParseError(format!( + "{data_type} name mismatch writer={writer_name}, reader={reader_name}" + ))) + } +} + +fn primitive_of(schema: &Schema) -> Option<PrimitiveType> { + match schema { + Schema::TypeName(TypeName::Primitive(primitive)) => Some(*primitive), + Schema::Type(Type { + r#type: TypeName::Primitive(primitive), + .. + }) => Some(*primitive), + _ => None, + } +} + +fn nullable_union_variants<'x, 'y>( + variant: &'y [Schema<'x>], +) -> Option<(Nullability, &'y Schema<'x>)> { + if variant.len() != 2 { + return None; + } + let is_null = |schema: &Schema<'x>| { + matches!( + schema, + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)) + ) + }; + match (is_null(&variant[0]), is_null(&variant[1])) { + (true, false) => Some((Nullability::NullFirst, &variant[1])), + (false, true) => Some((Nullability::NullSecond, &variant[0])), + _ => None, + } +} + /// Resolves Avro type names to [`AvroDataType`] /// /// See <https://avro.apache.org/docs/1.11.1/specification/#names> @@ -815,77 +872,36 @@ impl<'a> Maker<'a> { reader_schema: &'s Schema<'a>, namespace: Option<&'a str>, ) -> Result<AvroDataType, ArrowError> { + if let (Some(write_primitive), Some(read_primitive)) = + (primitive_of(writer_schema), primitive_of(reader_schema)) + { + return self.resolve_primitives(write_primitive, read_primitive, reader_schema); + } match (writer_schema, 
reader_schema) { - ( - Schema::TypeName(TypeName::Primitive(writer_primitive)), - Schema::TypeName(TypeName::Primitive(reader_primitive)), - ) => self.resolve_primitives(*writer_primitive, *reader_primitive, reader_schema), - ( - Schema::Type(Type { - r#type: TypeName::Primitive(writer_primitive), - .. - }), - Schema::Type(Type { - r#type: TypeName::Primitive(reader_primitive), - .. - }), - ) => self.resolve_primitives(*writer_primitive, *reader_primitive, reader_schema), - ( - Schema::TypeName(TypeName::Primitive(writer_primitive)), - Schema::Type(Type { - r#type: TypeName::Primitive(reader_primitive), - .. - }), - ) => self.resolve_primitives(*writer_primitive, *reader_primitive, reader_schema), - ( - Schema::Type(Type { - r#type: TypeName::Primitive(writer_primitive), - .. - }), - Schema::TypeName(TypeName::Primitive(reader_primitive)), - ) => self.resolve_primitives(*writer_primitive, *reader_primitive, reader_schema), ( Schema::Complex(ComplexType::Record(writer_record)), Schema::Complex(ComplexType::Record(reader_record)), ) => self.resolve_records(writer_record, reader_record, namespace), - (Schema::Union(writer_variants), Schema::Union(reader_variants)) => { - self.resolve_nullable_union(writer_variants, reader_variants, namespace) - } + ( + Schema::Complex(ComplexType::Enum(writer_enum)), + Schema::Complex(ComplexType::Enum(reader_enum)), + ) => self.resolve_enums(writer_enum, reader_enum, reader_schema, namespace), + (Schema::Union(writer_variants), Schema::Union(reader_variants)) => self + .resolve_nullable_union( + writer_variants.as_slice(), + reader_variants.as_slice(), + namespace, + ), + (Schema::TypeName(TypeName::Ref(_)), _) => self.parse_type(reader_schema, namespace), + (_, Schema::TypeName(TypeName::Ref(_))) => self.parse_type(reader_schema, namespace), // if both sides are the same complex kind (non-record), adopt the reader type. 
// This aligns with Avro spec: arrays, maps, and enums resolve recursively; // for identical shapes we can just parse the reader schema. (Schema::Complex(ComplexType::Array(_)), Schema::Complex(ComplexType::Array(_))) | (Schema::Complex(ComplexType::Map(_)), Schema::Complex(ComplexType::Map(_))) - | (Schema::Complex(ComplexType::Fixed(_)), Schema::Complex(ComplexType::Fixed(_))) - | (Schema::Complex(ComplexType::Enum(_)), Schema::Complex(ComplexType::Enum(_))) => { + | (Schema::Complex(ComplexType::Fixed(_)), Schema::Complex(ComplexType::Fixed(_))) => { self.parse_type(reader_schema, namespace) } - // Named-type references (equal on both sides) – parse reader side. - (Schema::TypeName(TypeName::Ref(_)), Schema::TypeName(TypeName::Ref(_))) - | ( - Schema::Type(Type { - r#type: TypeName::Ref(_), - .. - }), - Schema::Type(Type { - r#type: TypeName::Ref(_), - .. - }), - ) - | ( - Schema::TypeName(TypeName::Ref(_)), - Schema::Type(Type { - r#type: TypeName::Ref(_), - .. - }), - ) - | ( - Schema::Type(Type { - r#type: TypeName::Ref(_), - .. 
- }), - Schema::TypeName(TypeName::Ref(_)), - ) => self.parse_type(reader_schema, namespace), _ => Err(ArrowError::NotYetImplemented( "Other resolutions not yet implemented".to_string(), )), @@ -921,64 +937,156 @@ impl<'a> Maker<'a> { Ok(datatype) } - fn resolve_nullable_union( + fn resolve_nullable_union<'s>( &mut self, - writer_variants: &[Schema<'a>], - reader_variants: &[Schema<'a>], + writer_variants: &'s [Schema<'a>], + reader_variants: &'s [Schema<'a>], namespace: Option<&'a str>, ) -> Result<AvroDataType, ArrowError> { - // Only support unions with exactly two branches, one of which is `null` on both sides - if writer_variants.len() != 2 || reader_variants.len() != 2 { - return Err(ArrowError::NotYetImplemented( - "Only 2-branch unions are supported for schema resolution".to_string(), - )); - } - let is_null = |s: &Schema<'a>| { - matches!( - s, - Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)) - ) - }; - let w_null_pos = writer_variants.iter().position(is_null); - let r_null_pos = reader_variants.iter().position(is_null); - match (w_null_pos, r_null_pos) { - (Some(wp), Some(rp)) => { - // Extract a non-null branch on each side - let w_nonnull = &writer_variants[1 - wp]; - let r_nonnull = &reader_variants[1 - rp]; - // Resolve the non-null branch - let mut dt = self.make_data_type(w_nonnull, Some(r_nonnull), namespace)?; + match ( + nullable_union_variants(writer_variants), + nullable_union_variants(reader_variants), + ) { + (Some((_, write_nonnull)), Some((read_nb, read_nonnull))) => { + let mut dt = self.make_data_type(write_nonnull, Some(read_nonnull), namespace)?; // Adopt reader union null ordering - dt.nullability = Some(match rp { - 0 => Nullability::NullFirst, - 1 => Nullability::NullSecond, - _ => unreachable!(), - }); + dt.nullability = Some(read_nb); Ok(dt) } _ => Err(ArrowError::NotYetImplemented( - "Union resolution requires both writer and reader to be nullable unions" + "Union resolution requires both writer and reader to be 
2-branch nullable unions" + .to_string(), )), } } + // Resolve writer vs. reader enum schemas according to Avro 1.11.1. + // + // # How enums resolve (writer to reader) + // Per “Schema Resolution”: + // * The two schemas must refer to the same (unqualified) enum name (or match + // via alias rewriting). + // * If the writer’s symbol is not present in the reader’s enum and the reader + // enum has a `default`, that `default` symbol must be used; otherwise, + // error. + // https://avro.apache.org/docs/1.11.1/specification/#schema-resolution + // * Avro “Aliases” are applied from the reader side to rewrite the writer’s + // names during resolution. For robustness across ecosystems, we also accept + // symmetry here (see note below). + // https://avro.apache.org/docs/1.11.1/specification/#aliases + // + // # Rationale for this code path + // 1. Do the work once at schema‑resolution time. Avro serializes an enum as a + // writer‑side position. Mapping positions on the hot decoder path is expensive + // if done with string lookups. This method builds a `writer_index to reader_index` + // vector once, so decoding just does an O(1) table lookup. + // 2. Adopt the reader’s symbol set and order. We return an Arrow + // `Dictionary(Int32, Utf8)` whose dictionary values are the reader enum + // symbols. This makes downstream semantics match the reader schema, including + // Avro’s sort order rule that orders enums by symbol position in the schema. + // https://avro.apache.org/docs/1.11.1/specification/#sort-order + // 3. Honor Avro’s `default` for enums. Avro 1.9+ allows a type‑level default + // on the enum. When the writer emits a symbol unknown to the reader, we map it + // to the reader’s validated `default` symbol if present; otherwise we signal an + // error at decoding time (or, when `strict_mode` is enabled, while resolving the schemas). + // https://avro.apache.org/docs/1.11.1/specification/#enums + // + // # Implementation notes + // * We first check that enum names match or are alias‑equivalent.
The Avro + // spec describes alias rewriting using reader aliases; this implementation + // additionally treats writer aliases as acceptable for name matching to be + // resilient with schemas produced by different tooling. + // * We build `EnumMapping`: + // - `mapping[i]` = reader index of the writer symbol at writer index `i`. + // - If the writer symbol is absent and the reader has a default, we store the + // reader index of that default. + // - Otherwise we store `-1` as a sentinel meaning unresolvable; the decoder + // must treat encountering such a value as an error, per the spec. + // * We persist the reader symbol list in field metadata under + // `AVRO_ENUM_SYMBOLS_METADATA_KEY`, so consumers can inspect the dictionary + // without needing the original Avro schema. + // * The Arrow representation is `Dictionary(Int32, Utf8)`, which aligns with + // Avro’s integer index encoding for enums. + // + // # Examples + // * Writer `["A","B","C"]`, Reader `["A","B"]`, Reader default `"A"` + // `mapping = [0, 1, 0]`, `default_index = 0`. + // * Writer `["A","B"]`, Reader `["B","A"]` (no default) + // `mapping = [1, 0]`, `default_index = -1`. + // * Writer `["A","B","C"]`, Reader `["A","B"]` (no default) + // `mapping = [0, 1, -1]` (decode must error on `"C"`). 
+ fn resolve_enums( + &mut self, + writer_enum: &Enum<'a>, + reader_enum: &Enum<'a>, + reader_schema: &Schema<'a>, + namespace: Option<&'a str>, + ) -> Result<AvroDataType, ArrowError> { + ensure_names_match( + "Enum", + writer_enum.name, + &writer_enum.aliases, + reader_enum.name, + &reader_enum.aliases, + )?; + if writer_enum.symbols == reader_enum.symbols { + return self.parse_type(reader_schema, namespace); + } + let reader_index: HashMap<&str, i32> = reader_enum + .symbols + .iter() + .enumerate() + .map(|(index, &symbol)| (symbol, index as i32)) + .collect(); + let default_index: i32 = match reader_enum.default { + Some(symbol) => *reader_index.get(symbol).ok_or_else(|| { + ArrowError::SchemaError(format!( + "Reader enum '{}' default symbol '{symbol}' not found in symbols list", + reader_enum.name, + )) + })?, + None => -1, + }; + let mapping: Vec<i32> = writer_enum + .symbols + .iter() + .map(|&write_symbol| { + reader_index + .get(write_symbol) + .copied() + .unwrap_or(default_index) + }) + .collect(); + if self.strict_mode && mapping.iter().any(|&m| m < 0) { + return Err(ArrowError::SchemaError(format!( + "Reader enum '{}' does not cover all writer symbols and no default is provided", + reader_enum.name + ))); + } + let mut dt = self.parse_type(reader_schema, namespace)?; + dt.resolution = Some(ResolutionInfo::EnumMapping(EnumMapping { + mapping: Arc::from(mapping), + default_index, + })); + let reader_ns = reader_enum.namespace.or(namespace); + self.resolver + .register(reader_enum.name, reader_ns, dt.clone()); + Ok(dt) + } + fn resolve_records( &mut self, writer_record: &Record<'a>, reader_record: &Record<'a>, namespace: Option<&'a str>, ) -> Result<AvroDataType, ArrowError> { - // Names must match or be aliased - let names_match = writer_record.name == reader_record.name - || reader_record.aliases.contains(&writer_record.name) - || writer_record.aliases.contains(&reader_record.name); - if !names_match { - return Err(ArrowError::ParseError(format!( - 
"Record name mismatch writer={}, reader={}", - writer_record.name, reader_record.name - ))); - } + ensure_names_match( + "Record", + writer_record.name, + &writer_record.aliases, + reader_record.name, + &reader_record.aliases, + )?; let writer_ns = writer_record.namespace.or(namespace); let reader_ns = reader_record.namespace.or(namespace); // Map writer field name -> index @@ -995,7 +1103,7 @@ impl<'a> Maker<'a> { // Build reader fields and mapping for (reader_idx, r_field) in reader_record.fields.iter().enumerate() { if let Some(&writer_idx) = writer_index_map.get(r_field.name) { - // Field exists in writer: resolve types (including promotions and union-of-null) + // Field exists in a writer: resolve types (including promotions and union-of-null) let w_schema = &writer_record.fields[writer_idx].r#type; let resolved_dt = self.make_data_type(w_schema, Some(&r_field.r#type), reader_ns)?; diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index c7cebb393c..d1910790e5 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -910,6 +910,53 @@ mod test { AvroSchema::new(root.to_string()) } + fn make_reader_schema_with_enum_remap( + path: &str, + remap: &HashMap<&str, Vec<&str>>, + ) -> AvroSchema { + let mut root = load_writer_schema_json(path); + assert_eq!(root["type"], "record", "writer schema must be a record"); + let fields = root + .get_mut("fields") + .and_then(|f| f.as_array_mut()) + .expect("record has fields"); + + fn to_symbols_array(symbols: &[&str]) -> Value { + Value::Array(symbols.iter().map(|s| Value::String((*s).into())).collect()) + } + + fn update_enum_symbols(ty: &mut Value, symbols: &Value) { + match ty { + Value::Object(map) => { + if matches!(map.get("type"), Some(Value::String(t)) if t == "enum") { + map.insert("symbols".to_string(), symbols.clone()); + } + } + Value::Array(arr) => { + for b in arr.iter_mut() { + if let Value::Object(map) = b { + if matches!(map.get("type"), Some(Value::String(t)) if t 
== "enum") { + map.insert("symbols".to_string(), symbols.clone()); + } + } + } + } + _ => {} + } + } + for f in fields.iter_mut() { + let Some(name) = f.get("name").and_then(|n| n.as_str()) else { + continue; + }; + if let Some(new_symbols) = remap.get(name) { + let symbols_val = to_symbols_array(new_symbols); + let ty = f.get_mut("type").expect("field has a type"); + update_enum_symbols(ty, &symbols_val); + } + } + AvroSchema::new(root.to_string()) + } + fn read_alltypes_with_reader_schema(path: &str, reader_schema: AvroSchema) -> RecordBatch { let file = File::open(path).unwrap(); let reader = ReaderBuilder::new() @@ -1289,6 +1336,52 @@ mod test { ); } + #[test] + fn test_simple_enum_with_reader_schema_mapping() { + let file = arrow_test_data("avro/simple_enum.avro"); + let mut remap: HashMap<&str, Vec<&str>> = HashMap::new(); + remap.insert("f1", vec!["d", "c", "b", "a"]); + remap.insert("f2", vec!["h", "g", "f", "e"]); + remap.insert("f3", vec!["k", "i", "j"]); + let reader_schema = make_reader_schema_with_enum_remap(&file, &remap); + let actual = read_alltypes_with_reader_schema(&file, reader_schema); + let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let f1_keys = Int32Array::from(vec![3, 2, 1, 0]); + let f1_vals = StringArray::from(vec!["d", "c", "b", "a"]); + let f1 = DictionaryArray::<Int32Type>::try_new(f1_keys, Arc::new(f1_vals)).unwrap(); + let mut md_f1 = HashMap::new(); + md_f1.insert( + AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(), + r#"["d","c","b","a"]"#.to_string(), + ); + let f1_field = Field::new("f1", dict_type.clone(), false).with_metadata(md_f1); + let f2_keys = Int32Array::from(vec![1, 0, 3, 2]); + let f2_vals = StringArray::from(vec!["h", "g", "f", "e"]); + let f2 = DictionaryArray::<Int32Type>::try_new(f2_keys, Arc::new(f2_vals)).unwrap(); + let mut md_f2 = HashMap::new(); + md_f2.insert( + AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(), + r#"["h","g","f","e"]"#.to_string(), + ); + let f2_field = 
Field::new("f2", dict_type.clone(), false).with_metadata(md_f2); + let f3_keys = Int32Array::from(vec![Some(2), Some(0), None, Some(1)]); + let f3_vals = StringArray::from(vec!["k", "i", "j"]); + let f3 = DictionaryArray::<Int32Type>::try_new(f3_keys, Arc::new(f3_vals)).unwrap(); + let mut md_f3 = HashMap::new(); + md_f3.insert( + AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(), + r#"["k","i","j"]"#.to_string(), + ); + let f3_field = Field::new("f3", dict_type.clone(), true).with_metadata(md_f3); + let expected_schema = Arc::new(Schema::new(vec![f1_field, f2_field, f3_field])); + let expected = RecordBatch::try_new( + expected_schema, + vec![Arc::new(f1) as ArrayRef, Arc::new(f2), Arc::new(f3)], + ) + .unwrap(); + assert_eq!(actual, expected); + } + #[test] fn test_schema_store_register_lookup() { let schema_int = make_record_schema(PrimitiveType::Int); diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index e219efabb9..6e5756ef41 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -248,6 +248,12 @@ enum Decoder { Decimal128(usize, Option<usize>, Option<usize>, Decimal128Builder), Decimal256(usize, Option<usize>, Option<usize>, Decimal256Builder), Nullable(Nullability, NullBufferBuilder, Box<Decoder>), + EnumResolved { + indices: Vec<i32>, + symbols: Arc<[String]>, + mapping: Arc<[i32]>, + default_index: i32, + }, /// Resolved record that needs writer->reader projection and skipping writer-only fields RecordResolved { fields: Fields, @@ -369,7 +375,16 @@ impl Decoder { ) } (Codec::Enum(symbols), _) => { - Self::Enum(Vec::with_capacity(DEFAULT_CAPACITY), symbols.clone()) + if let Some(ResolutionInfo::EnumMapping(mapping)) = data_type.resolution.as_ref() { + Self::EnumResolved { + indices: Vec::with_capacity(DEFAULT_CAPACITY), + symbols: symbols.clone(), + mapping: mapping.mapping.clone(), + default_index: mapping.default_index, + } + } else { + Self::Enum(Vec::with_capacity(DEFAULT_CAPACITY), symbols.clone()) 
+ } } (Codec::Struct(fields), _) => { let mut arrow_fields = Vec::with_capacity(fields.len()); @@ -461,6 +476,7 @@ impl Decoder { Self::Decimal128(_, _, _, builder) => builder.append_value(0), Self::Decimal256(_, _, _, builder) => builder.append_value(i256::ZERO), Self::Enum(indices, _) => indices.push(0), + Self::EnumResolved { indices, .. } => indices.push(0), Self::Duration(builder) => builder.append_null(), Self::Nullable(_, null_buffer, inner) => { null_buffer.append(false); @@ -555,6 +571,26 @@ impl Decoder { Self::Enum(indices, _) => { indices.push(buf.get_int()?); } + Self::EnumResolved { + indices, + mapping, + default_index, + .. + } => { + let raw = buf.get_int()?; + let resolved = usize::try_from(raw) + .ok() + .and_then(|idx| mapping.get(idx).copied()) + .filter(|&idx| idx >= 0) + .unwrap_or(*default_index); + if resolved >= 0 { + indices.push(resolved); + } else { + return Err(ArrowError::ParseError(format!( + "Enum symbol index {raw} not resolvable and no default provided", + ))); + } + } Self::Duration(builder) => { let b = buf.get_fixed(12)?; let months = u32::from_le_bytes(b[0..4].try_into().unwrap()); @@ -722,13 +758,10 @@ impl Decoder { .map_err(|e| ArrowError::ParseError(e.to_string()))?; Arc::new(dec) } - Self::Enum(indices, symbols) => { - let keys = flush_primitive::<Int32Type>(indices, nulls); - let values = Arc::new(StringArray::from( - symbols.iter().map(|s| s.as_str()).collect::<Vec<_>>(), - )); - Arc::new(DictionaryArray::try_new(keys, values)?) - } + Self::Enum(indices, symbols) => flush_dict(indices, symbols, nulls)?, + Self::EnumResolved { + indices, symbols, .. 
+ } => flush_dict(indices, symbols, nulls)?, Self::Duration(builder) => { let (_, vals, _) = builder.finish().into_parts(); let vals = IntervalMonthDayNanoArray::try_new(vals, nulls) @@ -766,6 +799,21 @@ fn skip_blocks( ) } +#[inline] +fn flush_dict( + indices: &mut Vec<i32>, + symbols: &[String], + nulls: Option<NullBuffer>, +) -> Result<ArrayRef, ArrowError> { + let keys = flush_primitive::<Int32Type>(indices, nulls); + let values = Arc::new(StringArray::from_iter_values( + symbols.iter().map(|s| s.as_str()), + )); + DictionaryArray::try_new(keys, values) + .map_err(|e| ArrowError::ParseError(e.to_string())) + .map(|arr| Arc::new(arr) as ArrayRef) +} + #[inline] fn read_blocks( buf: &mut AvroCursor, @@ -1761,6 +1809,101 @@ mod tests { assert_eq!(int_array.value(1), 42); // row3 value is 42 } + #[test] + fn test_enum_mapping_reordered_symbols() { + let reader_symbols: Arc<[String]> = + vec!["B".to_string(), "C".to_string(), "A".to_string()].into(); + let mapping: Arc<[i32]> = Arc::from(vec![2, 0, 1]); + let default_index: i32 = -1; + let mut dec = Decoder::EnumResolved { + indices: Vec::with_capacity(DEFAULT_CAPACITY), + symbols: reader_symbols.clone(), + mapping, + default_index, + }; + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_int(1)); + data.extend_from_slice(&encode_avro_int(2)); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + dec.decode(&mut cur).unwrap(); + dec.decode(&mut cur).unwrap(); + let arr = dec.flush(None).unwrap(); + let dict = arr + .as_any() + .downcast_ref::<DictionaryArray<Int32Type>>() + .unwrap(); + let expected_keys = Int32Array::from(vec![2, 0, 1]); + assert_eq!(dict.keys(), &expected_keys); + let values = dict + .values() + .as_any() + .downcast_ref::<StringArray>() + .unwrap(); + assert_eq!(values.value(0), "B"); + assert_eq!(values.value(1), "C"); + assert_eq!(values.value(2), "A"); + } + + #[test] + fn 
test_enum_mapping_unknown_symbol_and_out_of_range_fall_back_to_default() { + let reader_symbols: Arc<[String]> = vec!["A".to_string(), "B".to_string()].into(); + let default_index: i32 = 1; + let mapping: Arc<[i32]> = Arc::from(vec![0, 1]); + let mut dec = Decoder::EnumResolved { + indices: Vec::with_capacity(DEFAULT_CAPACITY), + symbols: reader_symbols.clone(), + mapping, + default_index, + }; + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_int(1)); + data.extend_from_slice(&encode_avro_int(99)); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + dec.decode(&mut cur).unwrap(); + dec.decode(&mut cur).unwrap(); + let arr = dec.flush(None).unwrap(); + let dict = arr + .as_any() + .downcast_ref::<DictionaryArray<Int32Type>>() + .unwrap(); + let expected_keys = Int32Array::from(vec![0, 1, 1]); + assert_eq!(dict.keys(), &expected_keys); + let values = dict + .values() + .as_any() + .downcast_ref::<StringArray>() + .unwrap(); + assert_eq!(values.value(0), "A"); + assert_eq!(values.value(1), "B"); + } + + #[test] + fn test_enum_mapping_unknown_symbol_without_default_errors() { + let reader_symbols: Arc<[String]> = vec!["A".to_string()].into(); + let default_index: i32 = -1; // indicates no default at type-level + let mapping: Arc<[i32]> = Arc::from(vec![-1]); + let mut dec = Decoder::EnumResolved { + indices: Vec::with_capacity(DEFAULT_CAPACITY), + symbols: reader_symbols, + mapping, + default_index, + }; + let data = encode_avro_int(0); + let mut cur = AvroCursor::new(&data); + let err = dec + .decode(&mut cur) + .expect_err("expected decode error for unresolved enum without default"); + let msg = err.to_string(); + assert!( + msg.contains("not resolvable") && msg.contains("no default"), + "unexpected error message: {msg}" + ); + } + fn make_record_resolved_decoder( reader_fields: &[(&str, DataType, bool)], writer_to_reader: Vec<Option<usize>>,