This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch cherry_pick_c7cf8f77
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
commit f8215f9624ec1d5dcbc31583daa44d977efc13c2
Author: Helgi Kristvin Sigurbjarnarson <[email protected]>
AuthorDate: Fri Oct 29 06:11:40 2021 -0700

    feat(ipc): Support writing dictionaries nested in structs and unions (#870)

    * feat(ipc): Support for writing dictionaries nested in structs and unions

    Dictionaries are lost when serializing a RecordBatch for IPC, producing invalid arrow data. This PR changes encoded_batch to recursively find all dictionary fields within the schema (currently only in structs and unions) so nested dictionaries are properly serialized.

    * address lint and clippy
---
 arrow/src/array/cast.rs |   1 +
 arrow/src/array/mod.rs  |   2 +-
 arrow/src/ipc/writer.rs | 138 +++++++++++++++++++++++++++++++++++++++++++-----
 3 files changed, 127 insertions(+), 14 deletions(-)

diff --git a/arrow/src/array/cast.rs b/arrow/src/array/cast.rs
index dfc1560..e4284ef 100644
--- a/arrow/src/array/cast.rs
+++ b/arrow/src/array/cast.rs
@@ -92,3 +92,4 @@ array_downcast_fn!(as_largestring_array, LargeStringArray);
 array_downcast_fn!(as_boolean_array, BooleanArray);
 array_downcast_fn!(as_null_array, NullArray);
 array_downcast_fn!(as_struct_array, StructArray);
+array_downcast_fn!(as_union_array, UnionArray);
diff --git a/arrow/src/array/mod.rs b/arrow/src/array/mod.rs
index 63b8b61..5d4e57a 100644
--- a/arrow/src/array/mod.rs
+++ b/arrow/src/array/mod.rs
@@ -441,7 +441,7 @@ pub use self::ord::{build_compare, DynComparator};
 pub use self::cast::{
     as_boolean_array, as_dictionary_array, as_generic_binary_array,
     as_generic_list_array, as_large_list_array, as_largestring_array, as_list_array,
-    as_null_array, as_primitive_array, as_string_array, as_struct_array,
+    as_null_array, as_primitive_array, as_string_array, as_struct_array, as_union_array,
 };
 
 // ------------------------------ C Data Interface ---------------------------
diff --git a/arrow/src/ipc/writer.rs b/arrow/src/ipc/writer.rs
index 0376265..853fc0f 100644
--- a/arrow/src/ipc/writer.rs
+++ b/arrow/src/ipc/writer.rs
@@ -25,7 +25,7 @@ use std::io::{BufWriter, Write}; use flatbuffers::FlatBufferBuilder; -use crate::array::{ArrayData, ArrayRef}; +use crate::array::{as_struct_array, as_union_array, ArrayData, ArrayRef}; use crate::buffer::{Buffer, MutableBuffer}; use crate::datatypes::*; use crate::error::{ArrowError, Result}; @@ -137,20 +137,45 @@ impl IpcDataGenerator { } } - pub fn encoded_batch( + fn encode_dictionaries( &self, - batch: &RecordBatch, + field: &Field, + column: &ArrayRef, + encoded_dictionaries: &mut Vec<EncodedData>, dictionary_tracker: &mut DictionaryTracker, write_options: &IpcWriteOptions, - ) -> Result<(Vec<EncodedData>, EncodedData)> { - // TODO: handle nested dictionaries - let schema = batch.schema(); - let mut encoded_dictionaries = Vec::with_capacity(schema.fields().len()); - - for (i, field) in schema.fields().iter().enumerate() { - let column = batch.column(i); - - if let DataType::Dictionary(_key_type, _value_type) = column.data_type() { + ) -> Result<()> { + // TODO: Handle other nested types (map, list, etc) + match column.data_type() { + DataType::Struct(fields) => { + let s = as_struct_array(column); + for (field, &column) in fields.iter().zip(s.columns().iter()) { + self.encode_dictionaries( + field, + column, + encoded_dictionaries, + dictionary_tracker, + write_options, + )?; + } + } + DataType::Union(fields) => { + let union = as_union_array(column); + for (field, ref column) in fields + .iter() + .enumerate() + .map(|(n, f)| (f, union.child(n as i8))) + { + self.encode_dictionaries( + field, + column, + encoded_dictionaries, + dictionary_tracker, + write_options, + )?; + } + } + DataType::Dictionary(_key_type, _value_type) => { let dict_id = field .dict_id() .expect("All Dictionary types have `dict_id`"); @@ -167,10 +192,33 @@ impl IpcDataGenerator { )); } } + _ => (), } - let encoded_message = self.record_batch_to_bytes(batch, write_options); + Ok(()) + } + + pub fn encoded_batch( + &self, + batch: &RecordBatch, + dictionary_tracker: &mut 
DictionaryTracker, + write_options: &IpcWriteOptions, + ) -> Result<(Vec<EncodedData>, EncodedData)> { + let schema = batch.schema(); + let mut encoded_dictionaries = Vec::with_capacity(schema.fields().len()); + for (i, field) in schema.fields().iter().enumerate() { + let column = batch.column(i); + self.encode_dictionaries( + field, + column, + &mut encoded_dictionaries, + dictionary_tracker, + write_options, + )?; + } + + let encoded_message = self.record_batch_to_bytes(batch, write_options); Ok((encoded_dictionaries, encoded_message)) } @@ -1161,4 +1209,68 @@ mod tests { let arrow_json: ArrowJson = serde_json::from_str(&s).unwrap(); arrow_json } + + #[test] + fn track_union_nested_dict() { + let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect(); + + let array = Arc::new(inner) as ArrayRef; + + // Dict field with id 2 + let dctfield = + Field::new_dict("dict", array.data_type().clone(), false, 2, false); + + let types = Buffer::from_slice_ref(&[0_i8, 0, 0]); + let offsets = Buffer::from_slice_ref(&[0_i32, 1, 2]); + + let union = + UnionArray::try_new(types, Some(offsets), vec![(dctfield, array)], None) + .unwrap(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "union", + union.data_type().clone(), + false, + )])); + + let batch = RecordBatch::try_new(schema, vec![Arc::new(union)]).unwrap(); + + let gen = IpcDataGenerator {}; + let mut dict_tracker = DictionaryTracker::new(false); + gen.encoded_batch(&batch, &mut dict_tracker, &Default::default()) + .unwrap(); + + // Dictionary with id 2 should have been written to the dict tracker + assert!(dict_tracker.written.contains_key(&2)); + } + + #[test] + fn track_struct_nested_dict() { + let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect(); + + let array = Arc::new(inner) as ArrayRef; + + // Dict field with id 2 + let dctfield = + Field::new_dict("dict", array.data_type().clone(), false, 2, false); + + let s = StructArray::from(vec![(dctfield, array)]); 
+ let struct_array = Arc::new(s) as ArrayRef; + + let schema = Arc::new(Schema::new(vec![Field::new( + "struct", + struct_array.data_type().clone(), + false, + )])); + + let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap(); + + let gen = IpcDataGenerator {}; + let mut dict_tracker = DictionaryTracker::new(false); + gen.encoded_batch(&batch, &mut dict_tracker, &Default::default()) + .unwrap(); + + // Dictionary with id 2 should have been written to the dict tracker + assert!(dict_tracker.written.contains_key(&2)); + } }
