This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 2a0fc7703 Add support of sorting dictionary of other primitive arrays (#2701)
2a0fc7703 is described below
commit 2a0fc7703420f99d28141516cabdd0408a583dfc
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Sep 14 14:05:58 2022 -0700
Add support of sorting dictionary of other primitive arrays (#2701)
* Add support of sorting dictionary of other primitive arrays
* Collapse match statements
* Add one helper to match primitive types
---
arrow/src/compute/kernels/sort.rs | 259 +++++++++++++++++++++++---------------
arrow/src/datatypes/datatype.rs | 24 ++++
2 files changed, 179 insertions(+), 104 deletions(-)
diff --git a/arrow/src/compute/kernels/sort.rs
b/arrow/src/compute/kernels/sort.rs
index 34a321910..0bc2d3948 100644
--- a/arrow/src/compute/kernels/sort.rs
+++ b/arrow/src/compute/kernels/sort.rs
@@ -314,119 +314,32 @@ pub fn sort_to_indices(
}
},
DataType::Dictionary(_, _) => {
+ let value_null_first = if options.descending {
+ // When sorting dictionary in descending order, we take
inverse of of null ordering
+ // when sorting the values. Because if `nulls_first` is true,
null must be in front
+ // of non-null value. As we take the sorted order of value
array to sort dictionary
+ // keys, these null values will be treated as smallest ones
and be sorted to the end
+ // of sorted result. So we set `nulls_first` to false when
sorting dictionary value
+ // array to make them as largest ones, then null values will
be put at the beginning
+ // of sorted dictionary result.
+ !options.nulls_first
+ } else {
+ options.nulls_first
+ };
+ let value_options = Some(SortOptions {
+ descending: false,
+ nulls_first: value_null_first,
+ });
downcast_dictionary_array!(
values => match values.values().data_type() {
- DataType::Int8 => {
- let dict_values = values.values();
- let value_null_first = if options.descending {
- // When sorting dictionary in descending order, we
take inverse of of null ordering
- // when sorting the values. Because if
`nulls_first` is true, null must be in front
- // of non-null value. As we take the sorted order
of value array to sort dictionary
- // keys, these null values will be treated as
smallest ones and be sorted to the end
- // of sorted result. So we set `nulls_first` to
false when sorting dictionary value
- // array to make them as largest ones, then null
values will be put at the beginning
- // of sorted dictionary result.
- !options.nulls_first
- } else {
- options.nulls_first
- };
- let value_options = Some(SortOptions { descending:
false, nulls_first: value_null_first });
- let sorted_value_indices =
sort_to_indices(dict_values, value_options, None)?;
- let value_indices_map =
prepare_indices_map(&sorted_value_indices);
- sort_primitive_dictionary::<_, _>(values,
&value_indices_map, v, n, options, limit, cmp)
- },
- DataType::Int16 => {
+ dt if DataType::is_primitive(dt) => {
let dict_values = values.values();
- let value_null_first = if options.descending {
- !options.nulls_first
- } else {
- options.nulls_first
- };
- let value_options = Some(SortOptions { descending:
false, nulls_first: value_null_first });
- let sorted_value_indices =
sort_to_indices(dict_values, value_options, None)?;
- let value_indices_map =
prepare_indices_map(&sorted_value_indices);
- sort_primitive_dictionary::<_, _>(values,
&value_indices_map, v, n, options, limit, cmp)
- },
- DataType::Int32 => {
- let dict_values = values.values();
- let value_null_first = if options.descending {
- !options.nulls_first
- } else {
- options.nulls_first
- };
- let value_options = Some(SortOptions { descending:
false, nulls_first: value_null_first });
- let sorted_value_indices =
sort_to_indices(dict_values, value_options, None)?;
- let value_indices_map =
prepare_indices_map(&sorted_value_indices);
- sort_primitive_dictionary::<_, _>(values,
&value_indices_map, v, n, options, limit, cmp)
- },
- DataType::Int64 => {
- let dict_values = values.values();
- let value_null_first = if options.descending {
- !options.nulls_first
- } else {
- options.nulls_first
- };
- let value_options = Some(SortOptions { descending:
false, nulls_first: value_null_first });
- let sorted_value_indices =
sort_to_indices(dict_values, value_options, None)?;
- let value_indices_map =
prepare_indices_map(&sorted_value_indices);
- sort_primitive_dictionary::<_, _>(values,
&value_indices_map,v, n, options, limit, cmp)
- },
- DataType::UInt8 => {
- let dict_values = values.values();
- let value_null_first = if options.descending {
- !options.nulls_first
- } else {
- options.nulls_first
- };
- let value_options = Some(SortOptions { descending:
false, nulls_first: value_null_first });
- let sorted_value_indices =
sort_to_indices(dict_values, value_options, None)?;
- let value_indices_map =
prepare_indices_map(&sorted_value_indices);
- sort_primitive_dictionary::<_, _>(values,
&value_indices_map,v, n, options, limit, cmp)
- },
- DataType::UInt16 => {
- let dict_values = values.values();
- let value_null_first = if options.descending {
- !options.nulls_first
- } else {
- options.nulls_first
- };
- let value_options = Some(SortOptions { descending:
false, nulls_first: value_null_first });
- let sorted_value_indices =
sort_to_indices(dict_values, value_options, None)?;
- let value_indices_map =
prepare_indices_map(&sorted_value_indices);
- sort_primitive_dictionary::<_, _>(values,
&value_indices_map,v, n, options, limit, cmp)
- },
- DataType::UInt32 => {
- let dict_values = values.values();
- let value_null_first = if options.descending {
- !options.nulls_first
- } else {
- options.nulls_first
- };
- let value_options = Some(SortOptions { descending:
false, nulls_first: value_null_first });
- let sorted_value_indices =
sort_to_indices(dict_values, value_options, None)?;
- let value_indices_map =
prepare_indices_map(&sorted_value_indices);
- sort_primitive_dictionary::<_, _>(values,
&value_indices_map,v, n, options, limit, cmp)
- },
- DataType::UInt64 => {
- let dict_values = values.values();
- let value_null_first = if options.descending {
- !options.nulls_first
- } else {
- options.nulls_first
- };
- let value_options = Some(SortOptions { descending:
false, nulls_first: value_null_first });
let sorted_value_indices =
sort_to_indices(dict_values, value_options, None)?;
let value_indices_map =
prepare_indices_map(&sorted_value_indices);
sort_primitive_dictionary::<_, _>(values,
&value_indices_map, v, n, options, limit, cmp)
},
DataType::Utf8 => {
let dict_values = values.values();
- let value_null_first = if options.descending {
- !options.nulls_first
- } else {
- options.nulls_first
- };
- let value_options = Some(SortOptions { descending:
false, nulls_first: value_null_first });
let sorted_value_indices =
sort_to_indices(dict_values, value_options, None)?;
let value_indices_map =
prepare_indices_map(&sorted_value_indices);
sort_string_dictionary::<_>(values,
&value_indices_map, v, n, &options, limit)
@@ -3552,4 +3465,142 @@ mod tests {
vec![None, None, None, Some(5), Some(5), Some(3), Some(1)],
);
}
+
+ // Exercises the `dt if DataType::is_primitive(dt)` branch of the dictionary
+ // case in `sort_to_indices` with Float32 dictionary values (previously only
+ // the integer value types had dedicated arms).
+ #[test]
+ fn test_sort_f32_dicts() {
+ // Keys [1, null, 2, null, 2, 0] over values [1.2, 3.0, 5.1] resolve to
+ // [3.0, null, 5.1, null, 5.1, 1.2]. With no options the expected output
+ // is ascending with the two null keys first.
+ let keys =
+ Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2),
Some(0)]);
+ let values = Float32Array::from(vec![1.2, 3.0, 5.1]);
+ test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
+ keys,
+ values,
+ None,
+ None,
+ vec![None, None, Some(1.2), Some(3.0), Some(5.1), Some(5.1)],
+ );
+
+ // Same input, descending with nulls last.
+ let keys =
+ Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2),
Some(0)]);
+ let values = Float32Array::from(vec![1.2, 3.0, 5.1]);
+ test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
+ keys,
+ values,
+ Some(SortOptions {
+ descending: true,
+ nulls_first: false,
+ }),
+ None,
+ vec![Some(5.1), Some(5.1), Some(3.0), Some(1.2), None, None],
+ );
+
+ // Same input, ascending with nulls last.
+ let keys =
+ Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2),
Some(0)]);
+ let values = Float32Array::from(vec![1.2, 3.0, 5.1]);
+ test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
+ keys,
+ values,
+ Some(SortOptions {
+ descending: false,
+ nulls_first: false,
+ }),
+ None,
+ vec![Some(1.2), Some(3.0), Some(5.1), Some(5.1), None, None],
+ );
+
+ // Descending with nulls first. The fourth argument (`Some(3)`) appears to
+ // be a limit: only the first three sorted entries are expected.
+ let keys =
+ Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2),
Some(0)]);
+ let values = Float32Array::from(vec![1.2, 3.0, 5.1]);
+ test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
+ keys,
+ values,
+ Some(SortOptions {
+ descending: true,
+ nulls_first: true,
+ }),
+ Some(3),
+ vec![None, None, Some(5.1)],
+ );
+
+ // Values have `None`.
+ // Key 2 points at a null dictionary value, so the resolved array
+ // [3.0, null, 5.1, null, null, 5.1, 1.2] has three nulls: two from null
+ // keys and one from a null value. All three must be ordered together.
+ let keys = Int8Array::from(vec![
+ Some(1_i8),
+ None,
+ Some(3),
+ None,
+ Some(2),
+ Some(3),
+ Some(0),
+ ]);
+ let values = Float32Array::from(vec![Some(1.2), Some(3.0), None,
Some(5.1)]);
+ test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
+ keys,
+ values,
+ None,
+ None,
+ vec![None, None, None, Some(1.2), Some(3.0), Some(5.1), Some(5.1)],
+ );
+
+ // Null-bearing values, ascending with nulls last.
+ let keys = Int8Array::from(vec![
+ Some(1_i8),
+ None,
+ Some(3),
+ None,
+ Some(2),
+ Some(3),
+ Some(0),
+ ]);
+ let values = Float32Array::from(vec![Some(1.2), Some(3.0), None,
Some(5.1)]);
+ test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
+ keys,
+ values,
+ Some(SortOptions {
+ descending: false,
+ nulls_first: false,
+ }),
+ None,
+ vec![Some(1.2), Some(3.0), Some(5.1), Some(5.1), None, None, None],
+ );
+
+ // Null-bearing values, descending with nulls last.
+ let keys = Int8Array::from(vec![
+ Some(1_i8),
+ None,
+ Some(3),
+ None,
+ Some(2),
+ Some(3),
+ Some(0),
+ ]);
+ let values = Float32Array::from(vec![Some(1.2), Some(3.0), None,
Some(5.1)]);
+ test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
+ keys,
+ values,
+ Some(SortOptions {
+ descending: true,
+ nulls_first: false,
+ }),
+ None,
+ vec![Some(5.1), Some(5.1), Some(3.0), Some(1.2), None, None, None],
+ );
+
+ // Null-bearing values, descending with nulls first. This is the case
+ // that relies on `sort_to_indices` inverting `nulls_first` when it sorts
+ // the dictionary's value array for a descending sort.
+ let keys = Int8Array::from(vec![
+ Some(1_i8),
+ None,
+ Some(3),
+ None,
+ Some(2),
+ Some(3),
+ Some(0),
+ ]);
+ let values = Float32Array::from(vec![Some(1.2), Some(3.0), None,
Some(5.1)]);
+ test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
+ keys,
+ values,
+ Some(SortOptions {
+ descending: true,
+ nulls_first: true,
+ }),
+ None,
+ vec![None, None, None, Some(5.1), Some(5.1), Some(3.0), Some(1.2)],
+ );
+ }
}
diff --git a/arrow/src/datatypes/datatype.rs b/arrow/src/datatypes/datatype.rs
index 2ca71ef77..d3189b8b1 100644
--- a/arrow/src/datatypes/datatype.rs
+++ b/arrow/src/datatypes/datatype.rs
@@ -1070,6 +1070,30 @@ impl DataType {
)
}
+ /// Returns true if the type is primitive: (numeric, temporal).
+ ///
+ /// Matches the fixed-width integer types (`Int8`, `Int16`, `Int32`, `Int64`,
+ /// `UInt8`, `UInt16`, `UInt32`, `UInt64`), `Float32` and `Float64`, and the
+ /// temporal types `Date32`, `Date64`, `Time32`, `Time64`, `Timestamp`,
+ /// `Interval` and `Duration`, whatever their inner parameters. Every other
+ /// variant (for example `Utf8` or `Dictionary`) returns `false`.
+ pub fn is_primitive(t: &DataType) -> bool {
+ use DataType::*;
+ // NOTE(review): `Float16` is not in this list; confirm whether that
+ // omission is intended before relying on this for all primitive types.
+ matches!(
+ t,
+ Int8 | Int16
+ | Int32
+ | Int64
+ | UInt8
+ | UInt16
+ | UInt32
+ | UInt64
+ | Float32
+ | Float64
+ | Date32
+ | Date64
+ | Time32(_)
+ | Time64(_)
+ | Timestamp(_, _)
+ | Interval(_)
+ | Duration(_)
+ )
+ }
+
/// Returns true if this type is temporal: (Date*, Time*, Duration, or
Interval).
pub fn is_temporal(t: &DataType) -> bool {
use DataType::*;