This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 04bd39521 Support sorting dictionary encoded primitive integer arrays
(#2680)
04bd39521 is described below
commit 04bd39521e264693d38048059c372b0712cc87a2
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Sat Sep 10 00:22:38 2022 -0700
Support sorting dictionary encoded primitive integer arrays (#2680)
* Support sorting dictionary encoded primitive arrays
* Reduce combinatorial fanout
* Change from &SortOptions to SortOptions
* Fix value order and add a test
* Fix null ordering and add test
* Add comment and increase test coverage.
---
arrow/src/compute/kernels/sort.rs | 405 ++++++++++++++++++++++++++++++++++----
1 file changed, 363 insertions(+), 42 deletions(-)
diff --git a/arrow/src/compute/kernels/sort.rs
b/arrow/src/compute/kernels/sort.rs
index 0e2273e92..7a2d47786 100644
--- a/arrow/src/compute/kernels/sort.rs
+++ b/arrow/src/compute/kernels/sort.rs
@@ -21,8 +21,10 @@ use crate::array::*;
use crate::buffer::MutableBuffer;
use crate::compute::take;
use crate::datatypes::*;
+use crate::downcast_dictionary_array;
use crate::error::{ArrowError, Result};
use std::cmp::Ordering;
+use std::collections::HashMap;
use TimeUnit::*;
/// Sort the `ArrayRef` using `SortOptions`.
@@ -311,41 +313,121 @@ pub fn sort_to_indices(
)));
}
},
- DataType::Dictionary(key_type, value_type)
- if *value_type.as_ref() == DataType::Utf8 =>
- {
- match key_type.as_ref() {
- DataType::Int8 => {
- sort_string_dictionary::<Int8Type>(values, v, n, &options,
limit)
- }
- DataType::Int16 => {
- sort_string_dictionary::<Int16Type>(values, v, n,
&options, limit)
- }
- DataType::Int32 => {
- sort_string_dictionary::<Int32Type>(values, v, n,
&options, limit)
- }
- DataType::Int64 => {
- sort_string_dictionary::<Int64Type>(values, v, n,
&options, limit)
- }
- DataType::UInt8 => {
- sort_string_dictionary::<UInt8Type>(values, v, n,
&options, limit)
- }
- DataType::UInt16 => {
- sort_string_dictionary::<UInt16Type>(values, v, n,
&options, limit)
- }
- DataType::UInt32 => {
- sort_string_dictionary::<UInt32Type>(values, v, n,
&options, limit)
- }
- DataType::UInt64 => {
- sort_string_dictionary::<UInt64Type>(values, v, n,
&options, limit)
- }
- t => {
- return Err(ArrowError::ComputeError(format!(
- "Sort not supported for dictionary key type {:?}",
- t
- )));
- }
- }
+ DataType::Dictionary(_, _) => {
+ downcast_dictionary_array!(
+ values => match values.values().data_type() {
+ DataType::Int8 => {
+ let dict_values = values.values();
+ let value_null_first = if options.descending {
+ // When sorting dictionary in descending order, we
take the inverse of the null ordering
+ // when sorting the values. Because if
`nulls_first` is true, null must be in front
+ // of non-null value. As we take the sorted order
of value array to sort dictionary
+ // keys, these null values will be treated as
smallest ones and be sorted to the end
+ // of sorted result. So we set `nulls_first` to
false when sorting dictionary value
+ // array to make them as largest ones, then null
values will be put at the beginning
+ // of sorted dictionary result.
+ !options.nulls_first
+ } else {
+ options.nulls_first
+ };
+ let value_options = Some(SortOptions { descending:
false, nulls_first: value_null_first });
+ let sorted_value_indices =
sort_to_indices(dict_values, value_options, None)?;
+ let value_indices_map =
prepare_indices_map(&sorted_value_indices);
+ sort_primitive_dictionary::<_, _>(values,
&value_indices_map, v, n, options, limit, cmp)
+ },
+ DataType::Int16 => {
+ let dict_values = values.values();
+ let value_null_first = if options.descending {
+ !options.nulls_first
+ } else {
+ options.nulls_first
+ };
+ let value_options = Some(SortOptions { descending:
false, nulls_first: value_null_first });
+ let sorted_value_indices =
sort_to_indices(dict_values, value_options, None)?;
+ let value_indices_map =
prepare_indices_map(&sorted_value_indices);
+ sort_primitive_dictionary::<_, _>(values,
&value_indices_map, v, n, options, limit, cmp)
+ },
+ DataType::Int32 => {
+ let dict_values = values.values();
+ let value_null_first = if options.descending {
+ !options.nulls_first
+ } else {
+ options.nulls_first
+ };
+ let value_options = Some(SortOptions { descending:
false, nulls_first: value_null_first });
+ let sorted_value_indices =
sort_to_indices(dict_values, value_options, None)?;
+ let value_indices_map =
prepare_indices_map(&sorted_value_indices);
+ sort_primitive_dictionary::<_, _>(values,
&value_indices_map, v, n, options, limit, cmp)
+ },
+ DataType::Int64 => {
+ let dict_values = values.values();
+ let value_null_first = if options.descending {
+ !options.nulls_first
+ } else {
+ options.nulls_first
+ };
+ let value_options = Some(SortOptions { descending:
false, nulls_first: value_null_first });
+ let sorted_value_indices =
sort_to_indices(dict_values, value_options, None)?;
+ let value_indices_map =
prepare_indices_map(&sorted_value_indices);
+ sort_primitive_dictionary::<_, _>(values,
&value_indices_map,v, n, options, limit, cmp)
+ },
+ DataType::UInt8 => {
+ let dict_values = values.values();
+ let value_null_first = if options.descending {
+ !options.nulls_first
+ } else {
+ options.nulls_first
+ };
+ let value_options = Some(SortOptions { descending:
false, nulls_first: value_null_first });
+ let sorted_value_indices =
sort_to_indices(dict_values, value_options, None)?;
+ let value_indices_map =
prepare_indices_map(&sorted_value_indices);
+ sort_primitive_dictionary::<_, _>(values,
&value_indices_map,v, n, options, limit, cmp)
+ },
+ DataType::UInt16 => {
+ let dict_values = values.values();
+ let value_null_first = if options.descending {
+ !options.nulls_first
+ } else {
+ options.nulls_first
+ };
+ let value_options = Some(SortOptions { descending:
false, nulls_first: value_null_first });
+ let sorted_value_indices =
sort_to_indices(dict_values, value_options, None)?;
+ let value_indices_map =
prepare_indices_map(&sorted_value_indices);
+ sort_primitive_dictionary::<_, _>(values,
&value_indices_map,v, n, options, limit, cmp)
+ },
+ DataType::UInt32 => {
+ let dict_values = values.values();
+ let value_null_first = if options.descending {
+ !options.nulls_first
+ } else {
+ options.nulls_first
+ };
+ let value_options = Some(SortOptions { descending:
false, nulls_first: value_null_first });
+ let sorted_value_indices =
sort_to_indices(dict_values, value_options, None)?;
+ let value_indices_map =
prepare_indices_map(&sorted_value_indices);
+ sort_primitive_dictionary::<_, _>(values,
&value_indices_map,v, n, options, limit, cmp)
+ },
+ DataType::UInt64 => {
+ let dict_values = values.values();
+ let value_null_first = if options.descending {
+ !options.nulls_first
+ } else {
+ options.nulls_first
+ };
+ let value_options = Some(SortOptions { descending:
false, nulls_first: value_null_first });
+ let sorted_value_indices =
sort_to_indices(dict_values, value_options, None)?;
+ let value_indices_map =
prepare_indices_map(&sorted_value_indices);
+ sort_primitive_dictionary::<_, _>(values,
&value_indices_map, v, n, options, limit, cmp)
+ },
+ DataType::Utf8 => sort_string_dictionary::<_>(values, v,
n, &options, limit),
+ t => return Err(ArrowError::ComputeError(format!(
+ "Unsupported dictionary value type {}", t
+ ))),
+ },
+ t => return Err(ArrowError::ComputeError(format!(
+ "Unsupported datatype {}", t
+ ))),
+ )
}
DataType::Binary | DataType::FixedSizeBinary(_) => {
sort_binary::<i32>(values, v, n, &options, limit)
@@ -489,7 +571,14 @@ where
.into_iter()
.map(|index| (index, decimal_array.value(index as usize).as_i128()))
.collect::<Vec<(u32, i128)>>();
- sort_primitive_inner(decimal_values, null_indices, cmp, options, limit,
valids)
+ sort_primitive_inner(
+ decimal_values.len(),
+ null_indices,
+ cmp,
+ options,
+ limit,
+ valids,
+ )
}
/// Sort primitive values
@@ -514,12 +603,55 @@ where
.map(|index| (index, values.value(index as usize)))
.collect::<Vec<(u32, T::Native)>>()
};
- sort_primitive_inner(values, null_indices, cmp, options, limit, valids)
+ sort_primitive_inner(values.len(), null_indices, cmp, options, limit,
valids)
+}
+
+/// A helper function that converts sorted value indices into a map from each
+/// value index to its position in the sorted order, for later lookup.
+fn prepare_indices_map(sorted_value_indices: &UInt32Array) -> HashMap<usize,
u32> {
+ sorted_value_indices
+ .into_iter()
+ .enumerate()
+ .map(|(idx, index)| {
+ // Indices don't have None value
+ let index = index.unwrap();
+ (index as usize, idx as u32)
+ })
+ .collect::<HashMap<usize, u32>>()
+}
+
+/// Sort dictionary encoded primitive values
+fn sort_primitive_dictionary<K, F>(
+ values: &DictionaryArray<K>,
+ value_indices_map: &HashMap<usize, u32>,
+ value_indices: Vec<u32>,
+ null_indices: Vec<u32>,
+ options: SortOptions,
+ limit: Option<usize>,
+ cmp: F,
+) -> UInt32Array
+where
+ K: ArrowDictionaryKeyType,
+ F: Fn(u32, u32) -> std::cmp::Ordering,
+{
+ let keys: &PrimitiveArray<K> = values.keys();
+
+ // create tuples that are used for sorting
+ let valids = value_indices
+ .into_iter()
+ .map(|index| {
+ let key: K::Native = keys.value(index as usize);
+ let value_order =
value_indices_map.get(&key.to_usize().unwrap()).unwrap();
+ (index, *value_order)
+ })
+ .collect::<Vec<(u32, u32)>>();
+
+ sort_primitive_inner::<_, _>(keys.len(), null_indices, cmp, &options,
limit, valids)
}
// sort is instantiated a lot so we only compile this inner version for each
native type
fn sort_primitive_inner<T, F>(
- values: &ArrayRef,
+ value_len: usize,
null_indices: Vec<u32>,
cmp: F,
options: &SortOptions,
@@ -535,7 +667,7 @@ where
let valids_len = valids.len();
let nulls_len = nulls.len();
- let mut len = values.len();
+ let mut len = value_len;
if let Some(limit) = limit {
len = limit.min(len);
@@ -620,14 +752,12 @@ fn sort_string<Offset: OffsetSizeTrait>(
/// Sort dictionary encoded strings
fn sort_string_dictionary<T: ArrowDictionaryKeyType>(
- values: &ArrayRef,
+ values: &DictionaryArray<T>,
value_indices: Vec<u32>,
null_indices: Vec<u32>,
options: &SortOptions,
limit: Option<usize>,
) -> UInt32Array {
- let values: &DictionaryArray<T> = as_dictionary_array::<T>(values);
-
let keys: &PrimitiveArray<T> = values.keys();
let dict = values.values();
@@ -1239,6 +1369,59 @@ mod tests {
assert_eq!(sorted_strings, expected)
}
+ fn test_sort_primitive_dict_arrays<K: ArrowDictionaryKeyType, T:
ArrowPrimitiveType>(
+ keys: PrimitiveArray<K>,
+ values: PrimitiveArray<T>,
+ options: Option<SortOptions>,
+ limit: Option<usize>,
+ expected_data: Vec<Option<T::Native>>,
+ ) where
+ PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
+ {
+ let array = DictionaryArray::<K>::try_new(&keys, &values).unwrap();
+ let array_values = array.values().clone();
+ let dict = array_values
+ .as_any()
+ .downcast_ref::<PrimitiveArray<T>>()
+ .expect("Unable to get dictionary values");
+
+ let sorted = match limit {
+ Some(_) => {
+ sort_limit(&(Arc::new(array) as ArrayRef), options,
limit).unwrap()
+ }
+ _ => sort(&(Arc::new(array) as ArrayRef), options).unwrap(),
+ };
+ let sorted = sorted
+ .as_any()
+ .downcast_ref::<DictionaryArray<K>>()
+ .unwrap();
+ let sorted_values = sorted.values();
+ let sorted_dict = sorted_values
+ .as_any()
+ .downcast_ref::<PrimitiveArray<T>>()
+ .expect("Unable to get dictionary values");
+ let sorted_keys = sorted.keys();
+
+ assert_eq!(sorted_dict, dict);
+
+ let sorted_values: PrimitiveArray<T> =
From::<Vec<Option<T::Native>>>::from(
+ (0..sorted.len())
+ .map(|i| {
+ let key = sorted_keys.value(i).to_usize().unwrap();
+ if sorted.is_valid(i) && sorted_dict.is_valid(key) {
+ Some(sorted_dict.value(key))
+ } else {
+ None
+ }
+ })
+ .collect::<Vec<Option<T::Native>>>(),
+ );
+ let expected: PrimitiveArray<T> =
+ From::<Vec<Option<T::Native>>>::from(expected_data);
+
+ assert_eq!(sorted_values, expected)
+ }
+
fn test_sort_list_arrays<T>(
data: Vec<Option<Vec<Option<T::Native>>>>,
options: Option<SortOptions>,
@@ -3222,4 +3405,142 @@ mod tests {
partial_sort(&mut before, last, |a, b| a.cmp(b));
assert_eq!(&d[0..last], &before[0..last]);
}
+
+ #[test]
+ fn test_sort_int8_dicts() {
+ let keys =
+ Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2),
Some(0)]);
+ let values = Int8Array::from(vec![1, 3, 5]);
+ test_sort_primitive_dict_arrays::<Int8Type, Int8Type>(
+ keys,
+ values,
+ None,
+ None,
+ vec![None, None, Some(1), Some(3), Some(5), Some(5)],
+ );
+
+ let keys =
+ Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2),
Some(0)]);
+ let values = Int8Array::from(vec![1, 3, 5]);
+ test_sort_primitive_dict_arrays::<Int8Type, Int8Type>(
+ keys,
+ values,
+ Some(SortOptions {
+ descending: true,
+ nulls_first: false,
+ }),
+ None,
+ vec![Some(5), Some(5), Some(3), Some(1), None, None],
+ );
+
+ let keys =
+ Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2),
Some(0)]);
+ let values = Int8Array::from(vec![1, 3, 5]);
+ test_sort_primitive_dict_arrays::<Int8Type, Int8Type>(
+ keys,
+ values,
+ Some(SortOptions {
+ descending: false,
+ nulls_first: false,
+ }),
+ None,
+ vec![Some(1), Some(3), Some(5), Some(5), None, None],
+ );
+
+ let keys =
+ Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2),
Some(0)]);
+ let values = Int8Array::from(vec![1, 3, 5]);
+ test_sort_primitive_dict_arrays::<Int8Type, Int8Type>(
+ keys,
+ values,
+ Some(SortOptions {
+ descending: true,
+ nulls_first: true,
+ }),
+ Some(3),
+ vec![None, None, Some(5)],
+ );
+
+ // Values have `None`.
+ let keys = Int8Array::from(vec![
+ Some(1_i8),
+ None,
+ Some(3),
+ None,
+ Some(2),
+ Some(3),
+ Some(0),
+ ]);
+ let values = Int8Array::from(vec![Some(1), Some(3), None, Some(5)]);
+ test_sort_primitive_dict_arrays::<Int8Type, Int8Type>(
+ keys,
+ values,
+ None,
+ None,
+ vec![None, None, None, Some(1), Some(3), Some(5), Some(5)],
+ );
+
+ let keys = Int8Array::from(vec![
+ Some(1_i8),
+ None,
+ Some(3),
+ None,
+ Some(2),
+ Some(3),
+ Some(0),
+ ]);
+ let values = Int8Array::from(vec![Some(1), Some(3), None, Some(5)]);
+ test_sort_primitive_dict_arrays::<Int8Type, Int8Type>(
+ keys,
+ values,
+ Some(SortOptions {
+ descending: false,
+ nulls_first: false,
+ }),
+ None,
+ vec![Some(1), Some(3), Some(5), Some(5), None, None, None],
+ );
+
+ let keys = Int8Array::from(vec![
+ Some(1_i8),
+ None,
+ Some(3),
+ None,
+ Some(2),
+ Some(3),
+ Some(0),
+ ]);
+ let values = Int8Array::from(vec![Some(1), Some(3), None, Some(5)]);
+ test_sort_primitive_dict_arrays::<Int8Type, Int8Type>(
+ keys,
+ values,
+ Some(SortOptions {
+ descending: true,
+ nulls_first: false,
+ }),
+ None,
+ vec![Some(5), Some(5), Some(3), Some(1), None, None, None],
+ );
+
+ let keys = Int8Array::from(vec![
+ Some(1_i8),
+ None,
+ Some(3),
+ None,
+ Some(2),
+ Some(3),
+ Some(0),
+ ]);
+ let values = Int8Array::from(vec![Some(1), Some(3), None, Some(5)]);
+ test_sort_primitive_dict_arrays::<Int8Type, Int8Type>(
+ keys,
+ values,
+ Some(SortOptions {
+ descending: true,
+ nulls_first: true,
+ }),
+ None,
+ vec![None, None, None, Some(5), Some(5), Some(3), Some(1)],
+ );
+ }
}