This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push:
new ef91857d40 arrow-select: add support for merging primitive dictionary values (#7519)
ef91857d40 is described below
commit ef91857d40a8f5f325296b7e6a70d1ecf302dd48
Author: Alfonso Subiotto Marqués <[email protected]>
AuthorDate: Thu Jun 5 22:56:28 2025 +0200
arrow-select: add support for merging primitive dictionary values (#7519)
Previously, `should_merge_dictionary_values` would always return false in the
`ptr_eq` closure creation match arm for types that were not
{Large}{Utf8,Binary}. This could lead to excessive memory usage, because the
values of primitive dictionaries were simply concatenated rather than merged.
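A minimal sketch of the behaviour this enables, mirroring the new test added in this patch; the crate path `arrow_select::concat::concat` and the exact array shapes are assumptions for illustration, not part of the commit:

```rust
use std::sync::Arc;

use arrow_array::{cast::AsArray, types::Int8Type, Array, DictionaryArray, Int32Array, Int8Array};
use arrow_select::concat::concat;

fn main() {
    // Five keys that all reference value index 1, out of ten stored values.
    let keys = vec![1i8; 5];
    let values = (10..20).collect::<Vec<i32>>();
    let a = DictionaryArray::new(
        Int8Array::from(keys.clone()),
        Arc::new(Int32Array::from(values.clone())),
    );
    // Same contents but a distinct values allocation, so the pointer-equality
    // fast path does not apply and the merge heuristic has to do real work.
    let b = DictionaryArray::new(
        Int8Array::from(keys),
        Arc::new(Int32Array::from(values)),
    );

    let merged = concat(&[&a, &b]).unwrap();
    let merged_values = merged.as_dictionary::<Int8Type>().values();
    // Previously both ten-entry values buffers were carried over (20 values);
    // with primitive merging only the single referenced value remains.
    assert_eq!(merged_values.len(), 1);
}
```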
# Which issue does this PR close?
- Closes #7518
# What changes are included in this PR?
Updates the match arm in `should_merge_dictionary_values` so that it no longer
short circuits on primitive types. Also uses the primitive values' byte
representations to reuse the `Interner` pipeline already used for the byte
array types.
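This reuse works because every primitive native value can be viewed as a fixed-width `&[u8]`, which is the same key shape the existing byte-array `Interner` consumes. A small illustration (not from the patch), assuming `arrow_buffer::ToByteSlice` is in scope:

```rust
use arrow_buffer::ToByteSlice;

fn main() {
    // A primitive value is keyed by its native-endian byte representation,
    // the same `&[u8]` shape the Utf8/Binary interning path already produces.
    let v: i32 = 42;
    let bytes: &[u8] = v.to_byte_slice();
    assert_eq!(bytes, 42i32.to_ne_bytes().as_slice());
}
```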
# Are there any user-facing changes?
No
---
arrow-select/src/concat.rs | 43 +++++++++++++++++++++++++++++++
arrow-select/src/dictionary.rs | 58 ++++++++++++++++++++++++++++++++----------
2 files changed, 88 insertions(+), 13 deletions(-)
diff --git a/arrow-select/src/concat.rs b/arrow-select/src/concat.rs
index 486afbd144..1c99caef5a 100644
--- a/arrow-select/src/concat.rs
+++ b/arrow-select/src/concat.rs
@@ -1081,6 +1081,49 @@ mod tests {
assert!((30..40).contains(&values_len), "{values_len}")
}
+ #[test]
+ fn test_primitive_dictionary_merge() {
+ // Same value repeated 5 times.
+ let keys = vec![1; 5];
+ let values = (10..20).collect::<Vec<_>>();
+ let dict = DictionaryArray::new(
+ Int8Array::from(keys.clone()),
+ Arc::new(Int32Array::from(values.clone())),
+ );
+ let other = DictionaryArray::new(
+ Int8Array::from(keys.clone()),
+ Arc::new(Int32Array::from(values.clone())),
+ );
+
+ let result_same_dictionary = concat(&[&dict, &dict]).unwrap();
+ // Verify pointer equality check succeeds, and therefore the
+ // dictionaries are not merged. A single values buffer should be reused
+ // in this case.
+ assert!(dict.values().to_data().ptr_eq(
+ &result_same_dictionary
+ .as_dictionary::<Int8Type>()
+ .values()
+ .to_data()
+ ));
+ assert_eq!(
+ result_same_dictionary
+ .as_dictionary::<Int8Type>()
+ .values()
+ .len(),
+ values.len(),
+ );
+
+ let result_cloned_dictionary = concat(&[&dict, &other]).unwrap();
+ // Should have only 1 underlying value since all keys reference it.
+ assert_eq!(
+ result_cloned_dictionary
+ .as_dictionary::<Int8Type>()
+ .values()
+ .len(),
+ 1
+ );
+ }
+
#[test]
fn test_concat_string_sizes() {
let a: LargeStringArray = ((0..150).map(|_| Some("foo"))).collect();
diff --git a/arrow-select/src/dictionary.rs b/arrow-select/src/dictionary.rs
index 57aed644fe..c5773b16a4 100644
--- a/arrow-select/src/dictionary.rs
+++ b/arrow-select/src/dictionary.rs
@@ -18,12 +18,13 @@
use crate::interleave::interleave;
use ahash::RandomState;
use arrow_array::builder::BooleanBufferBuilder;
-use arrow_array::cast::AsArray;
use arrow_array::types::{
- ArrowDictionaryKeyType, BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, Utf8Type,
+ ArrowDictionaryKeyType, ArrowPrimitiveType, BinaryType, ByteArrayType, LargeBinaryType,
+ LargeUtf8Type, Utf8Type,
};
-use arrow_array::{Array, ArrayRef, DictionaryArray, GenericByteArray};
-use arrow_buffer::{ArrowNativeType, BooleanBuffer, ScalarBuffer};
+use arrow_array::{cast::AsArray, downcast_primitive};
+use arrow_array::{Array, ArrayRef, DictionaryArray, GenericByteArray, PrimitiveArray};
+use arrow_buffer::{ArrowNativeType, BooleanBuffer, ScalarBuffer, ToByteSlice};
use arrow_schema::{ArrowError, DataType};
/// A best effort interner that maintains a fixed number of buckets
@@ -102,7 +103,7 @@ fn bytes_ptr_eq<T: ByteArrayType>(a: &dyn Array, b: &dyn Array) -> bool {
}
/// A type-erased function that compares two array for pointer equality
-type PtrEq = dyn Fn(&dyn Array, &dyn Array) -> bool;
+type PtrEq = fn(&dyn Array, &dyn Array) -> bool;
/// A weak heuristic of whether to merge dictionary values that aims to only
/// perform the expensive merge computation when it is likely to yield at least
@@ -115,12 +116,17 @@ pub fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
) -> bool {
use DataType::*;
let first_values = dictionaries[0].values().as_ref();
- let ptr_eq: Box<PtrEq> = match first_values.data_type() {
- Utf8 => Box::new(bytes_ptr_eq::<Utf8Type>),
- LargeUtf8 => Box::new(bytes_ptr_eq::<LargeUtf8Type>),
- Binary => Box::new(bytes_ptr_eq::<BinaryType>),
- LargeBinary => Box::new(bytes_ptr_eq::<LargeBinaryType>),
- _ => return false,
+ let ptr_eq: PtrEq = match first_values.data_type() {
+ Utf8 => bytes_ptr_eq::<Utf8Type>,
+ LargeUtf8 => bytes_ptr_eq::<LargeUtf8Type>,
+ Binary => bytes_ptr_eq::<BinaryType>,
+ LargeBinary => bytes_ptr_eq::<LargeBinaryType>,
+ dt => {
+ if !dt.is_primitive() {
+ return false;
+ }
+ |a, b| a.to_data().ptr_eq(&b.to_data())
+ }
};
let mut single_dictionary = true;
@@ -233,17 +239,43 @@ fn compute_values_mask<K: ArrowNativeType>(
builder.finish()
}
+/// Process primitive array values to bytes
+fn masked_primitives_to_bytes<'a, T: ArrowPrimitiveType>(
+ array: &'a PrimitiveArray<T>,
+ mask: &BooleanBuffer,
+) -> Vec<(usize, Option<&'a [u8]>)>
+where
+ T::Native: ToByteSlice,
+{
+ let mut out = Vec::with_capacity(mask.count_set_bits());
+ let values = array.values();
+ for idx in mask.set_indices() {
+ out.push((
+ idx,
+ array.is_valid(idx).then_some(values[idx].to_byte_slice()),
+ ))
+ }
+ out
+}
+
+macro_rules! masked_primitive_to_bytes_helper {
+ ($t:ty, $array:expr, $mask:expr) => {
+ masked_primitives_to_bytes::<$t>($array.as_primitive(), $mask)
+ };
+}
+
/// Return a Vec containing for each set index in `mask`, the index and byte value of that index
fn get_masked_values<'a>(
array: &'a dyn Array,
mask: &BooleanBuffer,
) -> Vec<(usize, Option<&'a [u8]>)> {
- match array.data_type() {
+ downcast_primitive! {
+ array.data_type() => (masked_primitive_to_bytes_helper, array, mask),
DataType::Utf8 => masked_bytes(array.as_string::<i32>(), mask),
DataType::LargeUtf8 => masked_bytes(array.as_string::<i64>(), mask),
DataType::Binary => masked_bytes(array.as_binary::<i32>(), mask),
DataType::LargeBinary => masked_bytes(array.as_binary::<i64>(), mask),
- _ => unimplemented!(),
+ _ => unimplemented!("Dictionary merging for type {} is not implemented", array.data_type()),
}
}