This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push:
new 05e0d15a37 Add Map support to arrow-avro (#7451)
05e0d15a37 is described below
commit 05e0d15a37a159a5fd1ff6abae7ef830d62d2aa6
Author: Connor Sanders <[email protected]>
AuthorDate: Fri May 23 13:37:34 2025 -0500
Add Map support to arrow-avro (#7451)
* Added support for reading Avro Map types
* Fixed lint errors, improved readability of `read_blockwise_items`, added
`Map` comments and improved `Map` nullability handling in `data_type` in
codec.rs
---
arrow-avro/src/codec.rs | 44 ++++++++-
arrow-avro/src/reader/record.rs | 198 +++++++++++++++++++++++++++++++++++++++-
2 files changed, 237 insertions(+), 5 deletions(-)
diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs
index fdd4eb2e81..ca5d8dec33 100644
--- a/arrow-avro/src/codec.rs
+++ b/arrow-avro/src/codec.rs
@@ -17,7 +17,7 @@
use crate::schema::{Attributes, ComplexType, PrimitiveType, Record, Schema,
TypeName};
use arrow_schema::{
- ArrowError, DataType, Field, FieldRef, IntervalUnit, SchemaBuilder,
SchemaRef, TimeUnit,
+ ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit,
SchemaBuilder, SchemaRef, TimeUnit,
};
use std::borrow::Cow;
use std::collections::HashMap;
@@ -45,6 +45,19 @@ pub struct AvroDataType {
}
impl AvroDataType {
+ /// Create a new [`AvroDataType`] with the given parts.
+ pub fn new(
+ codec: Codec,
+ metadata: HashMap<String, String>,
+ nullability: Option<Nullability>,
+ ) -> Self {
+ AvroDataType {
+ codec,
+ metadata,
+ nullability,
+ }
+ }
+
/// Returns an arrow [`Field`] with the given name
pub fn field_with_name(&self, name: &str) -> Field {
let d = self.codec.data_type();
@@ -183,6 +196,8 @@ pub enum Codec {
List(Arc<AvroDataType>),
/// Represents Avro record type, maps to Arrow's Struct data type
Struct(Arc<[AvroField]>),
+ /// Represents Avro map type, maps to Arrow's Map data type
+ Map(Arc<AvroDataType>),
/// Represents Avro duration logical type, maps to Arrow's
Interval(IntervalUnit::MonthDayNano) data type
Interval,
}
@@ -214,6 +229,22 @@ impl Codec {
DataType::List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME)))
}
Self::Struct(f) => DataType::Struct(f.iter().map(|x|
x.field()).collect()),
+ Self::Map(value_type) => {
+ let val_dt = value_type.codec.data_type();
+ let val_field = Field::new("value", val_dt,
value_type.nullability.is_some())
+ .with_metadata(value_type.metadata.clone());
+ DataType::Map(
+ Arc::new(Field::new(
+ "entries",
+ DataType::Struct(Fields::from(vec![
+ Field::new("key", DataType::Utf8, false),
+ val_field,
+ ])),
+ false,
+ )),
+ false,
+ )
+ }
}
}
}
@@ -390,9 +421,14 @@ fn make_data_type<'a>(
ComplexType::Enum(e) => Err(ArrowError::NotYetImplemented(format!(
"Enum of {e:?} not currently supported"
))),
- ComplexType::Map(m) => Err(ArrowError::NotYetImplemented(format!(
- "Map of {m:?} not currently supported"
- ))),
+ ComplexType::Map(m) => {
+ let val = make_data_type(&m.values, namespace, resolver)?;
+ Ok(AvroDataType {
+ nullability: None,
+ metadata: m.attributes.field_metadata(),
+ codec: Codec::Map(Arc::new(val)),
+ })
+ }
},
Schema::Type(t) => {
let mut field = make_data_type(
diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs
index d5d454c8ee..cd9d6e3c13 100644
--- a/arrow-avro/src/reader/record.rs
+++ b/arrow-avro/src/reader/record.rs
@@ -27,6 +27,7 @@ use arrow_buffer::*;
use arrow_schema::{
ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as
ArrowSchema, SchemaRef,
};
+use std::cmp::Ordering;
use std::collections::HashMap;
use std::io::Read;
use std::sync::Arc;
@@ -114,6 +115,13 @@ enum Decoder {
StringView(OffsetBufferBuilder<i32>, Vec<u8>),
List(FieldRef, OffsetBufferBuilder<i32>, Box<Decoder>),
Record(Fields, Vec<Decoder>),
+ Map(
+ FieldRef,
+ OffsetBufferBuilder<i32>,
+ OffsetBufferBuilder<i32>,
+ Vec<u8>,
+ Box<Decoder>,
+ ),
Nullable(Nullability, NullBufferBuilder, Box<Decoder>),
}
@@ -169,6 +177,25 @@ impl Decoder {
}
Self::Record(arrow_fields.into(), encodings)
}
+ Codec::Map(child) => {
+ let val_field =
child.field_with_name("value").with_nullable(true);
+ let map_field = Arc::new(ArrowField::new(
+ "entries",
+ DataType::Struct(Fields::from(vec![
+ ArrowField::new("key", DataType::Utf8, false),
+ val_field,
+ ])),
+ false,
+ ));
+ let val_dec = Self::try_new(child)?;
+ Self::Map(
+ map_field,
+ OffsetBufferBuilder::new(DEFAULT_CAPACITY),
+ OffsetBufferBuilder::new(DEFAULT_CAPACITY),
+ Vec::with_capacity(DEFAULT_CAPACITY),
+ Box::new(val_dec),
+ )
+ }
};
Ok(match data_type.nullability() {
@@ -201,6 +228,9 @@ impl Decoder {
e.append_null();
}
Self::Record(_, e) => e.iter_mut().for_each(|e| e.append_null()),
+ Self::Map(_, _koff, moff, _, _) => {
+ moff.push_length(0);
+ }
Self::Nullable(_, _, _) => unreachable!("Nulls cannot be nested"),
}
}
@@ -236,6 +266,15 @@ impl Decoder {
encoding.decode(buf)?;
}
}
+ Self::Map(_, koff, moff, kdata, valdec) => {
+ let newly_added = read_map_blocks(buf, |cur| {
+ let kb = cur.get_bytes()?;
+ koff.push_length(kb.len());
+ kdata.extend_from_slice(kb);
+ valdec.decode(cur)
+ })?;
+ moff.push_length(newly_added);
+ }
Self::Nullable(nullability, nulls, e) => {
let is_valid = buf.get_bool()? == matches!(nullability,
Nullability::NullFirst);
nulls.append(is_valid);
@@ -273,7 +312,6 @@ impl Decoder {
),
Self::Float32(values) =>
Arc::new(flush_primitive::<Float32Type>(values, nulls)),
Self::Float64(values) =>
Arc::new(flush_primitive::<Float64Type>(values, nulls)),
-
Self::Binary(offsets, values) => {
let offsets = flush_offsets(offsets);
let values = flush_values(values).into();
@@ -313,10 +351,89 @@ impl Decoder {
.collect::<Result<Vec<_>, _>>()?;
Arc::new(StructArray::new(fields.clone(), arrays, nulls))
}
+ Self::Map(map_field, k_off, m_off, kdata, valdec) => {
+ let moff = flush_offsets(m_off);
+ let koff = flush_offsets(k_off);
+ let kd = flush_values(kdata).into();
+ let val_arr = valdec.flush(None)?;
+ let key_arr = StringArray::new(koff, kd, None);
+ if key_arr.len() != val_arr.len() {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "Map keys length ({}) != map values length ({})",
+ key_arr.len(),
+ val_arr.len()
+ )));
+ }
+ let final_len = moff.len() - 1;
+ if let Some(n) = &nulls {
+ if n.len() != final_len {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "Map array null buffer length {} != final map
length {final_len}",
+ n.len()
+ )));
+ }
+ }
+ let entries_struct = StructArray::new(
+ Fields::from(vec![
+ Arc::new(ArrowField::new("key", DataType::Utf8,
false)),
+ Arc::new(ArrowField::new("value",
val_arr.data_type().clone(), true)),
+ ]),
+ vec![Arc::new(key_arr), val_arr],
+ None,
+ );
+ let map_arr = MapArray::new(map_field.clone(), moff,
entries_struct, nulls, false);
+ Arc::new(map_arr)
+ }
})
}
}
+fn read_map_blocks(
+ buf: &mut AvroCursor,
+ decode_entry: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>,
+) -> Result<usize, ArrowError> {
+ read_blockwise_items(buf, true, decode_entry)
+}
+
+fn read_blockwise_items(
+ buf: &mut AvroCursor,
+ read_size_after_negative: bool,
+ mut decode_fn: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>,
+) -> Result<usize, ArrowError> {
+ let mut total = 0usize;
+ loop {
+ // Read the block count
+ // positive = that many items
+ // negative = that many items + read block size
+ // See: https://avro.apache.org/docs/1.11.1/specification/#maps
+ let block_count = buf.get_long()?;
+ match block_count.cmp(&0) {
+ Ordering::Equal => break,
+ Ordering::Less => {
+ // If block_count is negative, read the absolute value of
count,
+ // then read the block size as a long and discard
+ let count = (-block_count) as usize;
+ if read_size_after_negative {
+ let _size_in_bytes = buf.get_long()?;
+ }
+ for _ in 0..count {
+ decode_fn(buf)?;
+ }
+ total += count;
+ }
+ Ordering::Greater => {
+ // If block_count is positive, decode that many items
+ let count = block_count as usize;
+ for _i in 0..count {
+ decode_fn(buf)?;
+ }
+ total += count;
+ }
+ }
+ }
+ Ok(total)
+}
+
#[inline]
fn flush_values<T>(values: &mut Vec<T>) -> Vec<T> {
std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY))
@@ -336,3 +453,82 @@ fn flush_primitive<T: ArrowPrimitiveType>(
}
const DEFAULT_CAPACITY: usize = 1024;
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use arrow_array::{
+ cast::AsArray, Array, Decimal128Array, DictionaryArray,
FixedSizeBinaryArray,
+ IntervalMonthDayNanoArray, ListArray, MapArray, StringArray,
StructArray,
+ };
+
+ fn encode_avro_long(value: i64) -> Vec<u8> {
+ let mut buf = Vec::new();
+ let mut v = (value << 1) ^ (value >> 63);
+ while v & !0x7F != 0 {
+ buf.push(((v & 0x7F) | 0x80) as u8);
+ v >>= 7;
+ }
+ buf.push(v as u8);
+ buf
+ }
+
+ fn encode_avro_bytes(bytes: &[u8]) -> Vec<u8> {
+ let mut buf = encode_avro_long(bytes.len() as i64);
+ buf.extend_from_slice(bytes);
+ buf
+ }
+
+ fn avro_from_codec(codec: Codec) -> AvroDataType {
+ AvroDataType::new(codec, Default::default(), None)
+ }
+
+ #[test]
+ fn test_map_decoding_one_entry() {
+ let value_type = avro_from_codec(Codec::Utf8);
+ let map_type = avro_from_codec(Codec::Map(Arc::new(value_type)));
+ let mut decoder = Decoder::try_new(&map_type).unwrap();
+ // Encode a single map with one entry: {"hello": "world"}
+ let mut data = Vec::new();
+ data.extend_from_slice(&encode_avro_long(1));
+ data.extend_from_slice(&encode_avro_bytes(b"hello")); // key
+ data.extend_from_slice(&encode_avro_bytes(b"world")); // value
+ data.extend_from_slice(&encode_avro_long(0));
+ let mut cursor = AvroCursor::new(&data);
+ decoder.decode(&mut cursor).unwrap();
+ let array = decoder.flush(None).unwrap();
+ let map_arr = array.as_any().downcast_ref::<MapArray>().unwrap();
+ assert_eq!(map_arr.len(), 1); // one map
+ assert_eq!(map_arr.value_length(0), 1);
+ let entries = map_arr.value(0);
+ let struct_entries =
entries.as_any().downcast_ref::<StructArray>().unwrap();
+ assert_eq!(struct_entries.len(), 1);
+ let key_arr = struct_entries
+ .column_by_name("key")
+ .unwrap()
+ .as_any()
+ .downcast_ref::<StringArray>()
+ .unwrap();
+ let val_arr = struct_entries
+ .column_by_name("value")
+ .unwrap()
+ .as_any()
+ .downcast_ref::<StringArray>()
+ .unwrap();
+ assert_eq!(key_arr.value(0), "hello");
+ assert_eq!(val_arr.value(0), "world");
+ }
+
+ #[test]
+ fn test_map_decoding_empty() {
+ let value_type = avro_from_codec(Codec::Utf8);
+ let map_type = avro_from_codec(Codec::Map(Arc::new(value_type)));
+ let mut decoder = Decoder::try_new(&map_type).unwrap();
+ let data = encode_avro_long(0);
+ decoder.decode(&mut AvroCursor::new(&data)).unwrap();
+ let array = decoder.flush(None).unwrap();
+ let map_arr = array.as_any().downcast_ref::<MapArray>().unwrap();
+ assert_eq!(map_arr.len(), 1);
+ assert_eq!(map_arr.value_length(0), 0);
+ }
+}