This is an automated email from the ASF dual-hosted git repository. alamb pushed a commit to branch cherry_pick_e20d3faf in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
commit 183c1ddff7e6c306f8f8bf1da5a43598c8aab8c1 Author: Helgi Kristvin Sigurbjarnarson <[email protected]> AuthorDate: Mon Nov 8 13:32:33 2021 -0800 feat(ipc): add support for deserializing messages with nested dictionary fields (#923) * feat(ipc): read a message containing nested dictionary fields * Apply suggestions from code review Co-authored-by: Andrew Lamb <[email protected]> * address lints Co-authored-by: Andrew Lamb <[email protected]> --- arrow/src/datatypes/field.rs | 88 +++++++++++++++++++++++++++++++++++++++++++ arrow/src/datatypes/schema.rs | 9 ++++- arrow/src/ipc/reader.rs | 38 +++++++++++++++++-- 3 files changed, 131 insertions(+), 4 deletions(-) diff --git a/arrow/src/datatypes/field.rs b/arrow/src/datatypes/field.rs index 497dbb3..4ed0661 100644 --- a/arrow/src/datatypes/field.rs +++ b/arrow/src/datatypes/field.rs @@ -107,6 +107,36 @@ impl Field { self.nullable } + /// Returns a (flattened) vector containing all fields contained within this field (including it self) + pub(crate) fn fields(&self) -> Vec<&Field> { + let mut collected_fields = vec![self]; + match &self.data_type { + DataType::Struct(fields) | DataType::Union(fields) => { + collected_fields.extend(fields.iter().map(|f| f.fields()).flatten()) + } + DataType::List(field) + | DataType::LargeList(field) + | DataType::FixedSizeList(field, _) + | DataType::Map(field, _) => collected_fields.push(field), + _ => (), + } + + collected_fields + } + + /// Returns a vector containing all (potentially nested) `Field` instances selected by the + /// dictionary ID they use + #[inline] + pub(crate) fn fields_with_dict_id(&self, id: i64) -> Vec<&Field> { + self.fields() + .into_iter() + .filter(|&field| { + matches!(field.data_type(), DataType::Dictionary(_, _)) + && field.dict_id == id + }) + .collect() + } + /// Returns the dictionary ID, if this is a dictionary type. #[inline] pub const fn dict_id(&self) -> Option<i64> { @@ -572,3 +602,61 @@ impl std::fmt::Display for Field { write!(f, "{:?}", self) } } + +#[cfg(test)] +mod test { + use super::{DataType, Field}; + + #[test] + fn test_fields_with_dict_id() { + let dict1 = Field::new_dict( + "dict1", + DataType::Dictionary(DataType::Utf8.into(), DataType::Int32.into()), + false, + 10, + false, + ); + let dict2 = Field::new_dict( + "dict2", + DataType::Dictionary(DataType::Int32.into(), DataType::Int8.into()), + false, + 20, + false, + ); + + let field = Field::new( + "struct<dict1, list[struct<dict2, list[struct<dict1]>]>", + DataType::Struct(vec![ + dict1.clone(), + Field::new( + "list[struct<dict1, list[struct<dict2>]>]", + DataType::List(Box::new(Field::new( + "struct<dict1, list[struct<dict2>]>", + DataType::Struct(vec![ + dict1.clone(), + Field::new( + "list[struct<dict2>]", + DataType::List(Box::new(Field::new( + "struct<dict2>", + DataType::Struct(vec![dict2.clone()]), + false, + ))), + false, + ), + ]), + false, + ))), + false, + ), + ]), + false, + ); + + for field in field.fields_with_dict_id(10) { + assert_eq!(dict1, *field); + } + for field in field.fields_with_dict_id(20) { + assert_eq!(dict2, *field); + } + } +} diff --git a/arrow/src/datatypes/schema.rs b/arrow/src/datatypes/schema.rs index cfc0744..cc8ddbd 100644 --- a/arrow/src/datatypes/schema.rs +++ b/arrow/src/datatypes/schema.rs @@ -159,6 +159,12 @@ impl Schema { &self.fields } + /// Returns a vector with references to all fields (including nested fields) + #[inline] + pub(crate) fn all_fields(&self) -> Vec<&Field> { + self.fields.iter().map(|f| f.fields()).flatten().collect() + } + /// Returns an immutable reference of a specific `Field` instance selected using an /// offset within the internal `fields` vector. pub fn field(&self, i: usize) -> &Field { @@ -175,7 +181,8 @@ impl Schema { pub fn fields_with_dict_id(&self, dict_id: i64) -> Vec<&Field> { self.fields .iter() - .filter(|f| f.dict_id() == Some(dict_id)) + .map(|f| f.fields_with_dict_id(dict_id)) + .flatten() .collect() } diff --git a/arrow/src/ipc/reader.rs b/arrow/src/ipc/reader.rs index e925e2a..5bc76d0 100644 --- a/arrow/src/ipc/reader.rs +++ b/arrow/src/ipc/reader.rs @@ -495,7 +495,7 @@ pub fn read_dictionary( // in the reader. Note that a dictionary batch may be shared between many fields. // We don't currently record the isOrdered field. This could be general // attributes of arrays. - for (i, field) in schema.fields().iter().enumerate() { + for (i, field) in schema.all_fields().iter().enumerate() { if field.dict_id() == Some(id) { // Add (possibly multiple) array refs to the dictionaries array. dictionaries_by_field[i] = Some(dictionary_values.clone()); @@ -582,7 +582,7 @@ impl<R: Read + Seek> FileReader<R> { let schema = ipc::convert::fb_to_schema(ipc_schema); // Create an array of optional dictionary value arrays, one per field. - let mut dictionaries_by_field = vec![None; schema.fields().len()]; + let mut dictionaries_by_field = vec![None; schema.all_fields().len()]; for block in footer.dictionaries().unwrap() { // read length from end of offset let mut message_size: [u8; 4] = [0; 4]; @@ -923,7 +923,7 @@ mod tests { use flate2::read::GzDecoder; - use crate::util::integration_util::*; + use crate::{datatypes, util::integration_util::*}; #[test] fn read_generated_files_014() { @@ -1149,6 +1149,38 @@ mod tests { }) } + #[test] + fn test_roundtrip_nested_dict() { + let inner: DictionaryArray<datatypes::Int32Type> = + vec!["a", "b", "a"].into_iter().collect(); + + let array = Arc::new(inner) as ArrayRef; + + let dctfield = Field::new("dict", array.data_type().clone(), false); + + let s = StructArray::from(vec![(dctfield, array)]); + let struct_array = Arc::new(s) as ArrayRef; + + let schema = Arc::new(Schema::new(vec![Field::new( + "struct", + struct_array.data_type().clone(), + false, + )])); + + let batch = RecordBatch::try_new(schema.clone(), vec![struct_array]).unwrap(); + + let mut buf = Vec::new(); + let mut writer = ipc::writer::FileWriter::try_new(&mut buf, &schema).unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + drop(writer); + + let reader = ipc::reader::FileReader::try_new(std::io::Cursor::new(buf)).unwrap(); + let batch2: std::result::Result<Vec<_>, _> = reader.collect(); + + assert_eq!(batch, batch2.unwrap()[0]); + } + /// Read gzipped JSON file fn read_gzip_json(version: &str, path: &str) -> ArrowJson { let testdata = crate::util::test_util::arrow_test_data();
