This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new c7cf8f7 feat(ipc): Support writing dictionaries nested in structs and unions (#870)
c7cf8f7 is described below
commit c7cf8f77318ecc61531c5ee0785e81f7f26fe69a
Author: Helgi Kristvin Sigurbjarnarson <[email protected]>
AuthorDate: Fri Oct 29 06:11:40 2021 -0700
feat(ipc): Support writing dictionaries nested in structs and unions (#870)
* feat(ipc): Support for writing dictionaries nested in structs and unions
Dictionaries nested inside other columns (e.g. a dictionary field within a
struct or union) are lost when serializing a RecordBatch for IPC, producing
invalid Arrow data. This PR changes `encoded_batch` to recursively walk each
column's data type and encode every dictionary it finds (currently only inside
structs and unions), so nested dictionaries are properly serialized.
* Address lint and clippy findings
---
arrow/src/array/cast.rs | 1 +
arrow/src/array/mod.rs | 2 +-
arrow/src/ipc/writer.rs | 138 +++++++++++++++++++++++++++++++++++++++++++-----
3 files changed, 127 insertions(+), 14 deletions(-)
diff --git a/arrow/src/array/cast.rs b/arrow/src/array/cast.rs
index dfc1560..e4284ef 100644
--- a/arrow/src/array/cast.rs
+++ b/arrow/src/array/cast.rs
@@ -92,3 +92,4 @@ array_downcast_fn!(as_largestring_array, LargeStringArray);
array_downcast_fn!(as_boolean_array, BooleanArray);
array_downcast_fn!(as_null_array, NullArray);
array_downcast_fn!(as_struct_array, StructArray);
+array_downcast_fn!(as_union_array, UnionArray);
diff --git a/arrow/src/array/mod.rs b/arrow/src/array/mod.rs
index 63b8b61..5d4e57a 100644
--- a/arrow/src/array/mod.rs
+++ b/arrow/src/array/mod.rs
@@ -441,7 +441,7 @@ pub use self::ord::{build_compare, DynComparator};
pub use self::cast::{
as_boolean_array, as_dictionary_array, as_generic_binary_array,
as_generic_list_array, as_large_list_array, as_largestring_array,
as_list_array,
- as_null_array, as_primitive_array, as_string_array, as_struct_array,
+ as_null_array, as_primitive_array, as_string_array, as_struct_array,
as_union_array,
};
// ------------------------------ C Data Interface ---------------------------
diff --git a/arrow/src/ipc/writer.rs b/arrow/src/ipc/writer.rs
index 0376265..853fc0f 100644
--- a/arrow/src/ipc/writer.rs
+++ b/arrow/src/ipc/writer.rs
@@ -25,7 +25,7 @@ use std::io::{BufWriter, Write};
use flatbuffers::FlatBufferBuilder;
-use crate::array::{ArrayData, ArrayRef};
+use crate::array::{as_struct_array, as_union_array, ArrayData, ArrayRef};
use crate::buffer::{Buffer, MutableBuffer};
use crate::datatypes::*;
use crate::error::{ArrowError, Result};
@@ -137,20 +137,45 @@ impl IpcDataGenerator {
}
}
- pub fn encoded_batch(
+ fn encode_dictionaries(
&self,
- batch: &RecordBatch,
+ field: &Field,
+ column: &ArrayRef,
+ encoded_dictionaries: &mut Vec<EncodedData>,
dictionary_tracker: &mut DictionaryTracker,
write_options: &IpcWriteOptions,
- ) -> Result<(Vec<EncodedData>, EncodedData)> {
- // TODO: handle nested dictionaries
- let schema = batch.schema();
- let mut encoded_dictionaries =
Vec::with_capacity(schema.fields().len());
-
- for (i, field) in schema.fields().iter().enumerate() {
- let column = batch.column(i);
-
- if let DataType::Dictionary(_key_type, _value_type) =
column.data_type() {
+ ) -> Result<()> {
+ // TODO: Handle other nested types (map, list, etc)
+ match column.data_type() {
+ DataType::Struct(fields) => {
+ let s = as_struct_array(column);
+ for (field, &column) in fields.iter().zip(s.columns().iter()) {
+ self.encode_dictionaries(
+ field,
+ column,
+ encoded_dictionaries,
+ dictionary_tracker,
+ write_options,
+ )?;
+ }
+ }
+ DataType::Union(fields) => {
+ let union = as_union_array(column);
+ for (field, ref column) in fields
+ .iter()
+ .enumerate()
+ .map(|(n, f)| (f, union.child(n as i8)))
+ {
+ self.encode_dictionaries(
+ field,
+ column,
+ encoded_dictionaries,
+ dictionary_tracker,
+ write_options,
+ )?;
+ }
+ }
+ DataType::Dictionary(_key_type, _value_type) => {
let dict_id = field
.dict_id()
.expect("All Dictionary types have `dict_id`");
@@ -167,10 +192,33 @@ impl IpcDataGenerator {
));
}
}
+ _ => (),
}
- let encoded_message = self.record_batch_to_bytes(batch, write_options);
+ Ok(())
+ }
+
+ pub fn encoded_batch(
+ &self,
+ batch: &RecordBatch,
+ dictionary_tracker: &mut DictionaryTracker,
+ write_options: &IpcWriteOptions,
+ ) -> Result<(Vec<EncodedData>, EncodedData)> {
+ let schema = batch.schema();
+ let mut encoded_dictionaries =
Vec::with_capacity(schema.fields().len());
+ for (i, field) in schema.fields().iter().enumerate() {
+ let column = batch.column(i);
+ self.encode_dictionaries(
+ field,
+ column,
+ &mut encoded_dictionaries,
+ dictionary_tracker,
+ write_options,
+ )?;
+ }
+
+ let encoded_message = self.record_batch_to_bytes(batch, write_options);
Ok((encoded_dictionaries, encoded_message))
}
@@ -1161,4 +1209,68 @@ mod tests {
let arrow_json: ArrowJson = serde_json::from_str(&s).unwrap();
arrow_json
}
+
+ #[test]
+ fn track_union_nested_dict() {
+ let inner: DictionaryArray<Int32Type> = vec!["a", "b",
"a"].into_iter().collect();
+
+ let array = Arc::new(inner) as ArrayRef;
+
+ // Dict field with id 2
+ let dctfield =
+ Field::new_dict("dict", array.data_type().clone(), false, 2,
false);
+
+ let types = Buffer::from_slice_ref(&[0_i8, 0, 0]);
+ let offsets = Buffer::from_slice_ref(&[0_i32, 1, 2]);
+
+ let union =
+ UnionArray::try_new(types, Some(offsets), vec![(dctfield, array)],
None)
+ .unwrap();
+
+ let schema = Arc::new(Schema::new(vec![Field::new(
+ "union",
+ union.data_type().clone(),
+ false,
+ )]));
+
+ let batch = RecordBatch::try_new(schema,
vec![Arc::new(union)]).unwrap();
+
+ let gen = IpcDataGenerator {};
+ let mut dict_tracker = DictionaryTracker::new(false);
+ gen.encoded_batch(&batch, &mut dict_tracker, &Default::default())
+ .unwrap();
+
+ // Dictionary with id 2 should have been written to the dict tracker
+ assert!(dict_tracker.written.contains_key(&2));
+ }
+
+ #[test]
+ fn track_struct_nested_dict() {
+ let inner: DictionaryArray<Int32Type> = vec!["a", "b",
"a"].into_iter().collect();
+
+ let array = Arc::new(inner) as ArrayRef;
+
+ // Dict field with id 2
+ let dctfield =
+ Field::new_dict("dict", array.data_type().clone(), false, 2,
false);
+
+ let s = StructArray::from(vec![(dctfield, array)]);
+ let struct_array = Arc::new(s) as ArrayRef;
+
+ let schema = Arc::new(Schema::new(vec![Field::new(
+ "struct",
+ struct_array.data_type().clone(),
+ false,
+ )]));
+
+ let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap();
+
+ let gen = IpcDataGenerator {};
+ let mut dict_tracker = DictionaryTracker::new(false);
+ gen.encoded_batch(&batch, &mut dict_tracker, &Default::default())
+ .unwrap();
+
+ // Dictionary with id 2 should have been written to the dict tracker
+ assert!(dict_tracker.written.contains_key(&2));
+ }
}