This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 04bd39521 Support sorting dictionary encoded primitive integer arrays 
(#2680)
04bd39521 is described below

commit 04bd39521e264693d38048059c372b0712cc87a2
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Sat Sep 10 00:22:38 2022 -0700

    Support sorting dictionary encoded primitive integer arrays (#2680)
    
    * Support sorting dictionary encoded primitive arrays
    
    * Reduce combinatorial fanout
    
    * Change from &SortOptions to SortOptions
    
    * Fix value order and add a test
    
    * Fix null ordering and add test
    
    * Add comment and increase test coverage.
---
 arrow/src/compute/kernels/sort.rs | 405 ++++++++++++++++++++++++++++++++++----
 1 file changed, 363 insertions(+), 42 deletions(-)

diff --git a/arrow/src/compute/kernels/sort.rs 
b/arrow/src/compute/kernels/sort.rs
index 0e2273e92..7a2d47786 100644
--- a/arrow/src/compute/kernels/sort.rs
+++ b/arrow/src/compute/kernels/sort.rs
@@ -21,8 +21,10 @@ use crate::array::*;
 use crate::buffer::MutableBuffer;
 use crate::compute::take;
 use crate::datatypes::*;
+use crate::downcast_dictionary_array;
 use crate::error::{ArrowError, Result};
 use std::cmp::Ordering;
+use std::collections::HashMap;
 use TimeUnit::*;
 
 /// Sort the `ArrayRef` using `SortOptions`.
@@ -311,41 +313,121 @@ pub fn sort_to_indices(
                 )));
             }
         },
-        DataType::Dictionary(key_type, value_type)
-            if *value_type.as_ref() == DataType::Utf8 =>
-        {
-            match key_type.as_ref() {
-                DataType::Int8 => {
-                    sort_string_dictionary::<Int8Type>(values, v, n, &options, 
limit)
-                }
-                DataType::Int16 => {
-                    sort_string_dictionary::<Int16Type>(values, v, n, 
&options, limit)
-                }
-                DataType::Int32 => {
-                    sort_string_dictionary::<Int32Type>(values, v, n, 
&options, limit)
-                }
-                DataType::Int64 => {
-                    sort_string_dictionary::<Int64Type>(values, v, n, 
&options, limit)
-                }
-                DataType::UInt8 => {
-                    sort_string_dictionary::<UInt8Type>(values, v, n, 
&options, limit)
-                }
-                DataType::UInt16 => {
-                    sort_string_dictionary::<UInt16Type>(values, v, n, 
&options, limit)
-                }
-                DataType::UInt32 => {
-                    sort_string_dictionary::<UInt32Type>(values, v, n, 
&options, limit)
-                }
-                DataType::UInt64 => {
-                    sort_string_dictionary::<UInt64Type>(values, v, n, 
&options, limit)
-                }
-                t => {
-                    return Err(ArrowError::ComputeError(format!(
-                        "Sort not supported for dictionary key type {:?}",
-                        t
-                    )));
-                }
-            }
+        DataType::Dictionary(_, _) => {
+            downcast_dictionary_array!(
+                values => match values.values().data_type() {
+                    DataType::Int8 => {
+                        let dict_values = values.values();
+                        let value_null_first = if options.descending {
+                            // When sorting dictionary in descending order, we 
take inverse of null ordering
+                            // when sorting the values. Because if 
`nulls_first` is true, null must be in front
+                            // of non-null value. As we take the sorted order 
of value array to sort dictionary
+                            // keys, these null values will be treated as 
smallest ones and be sorted to the end
+                            // of sorted result. So we set `nulls_first` to 
false when sorting dictionary value
+                            // array to make them as largest ones, then null 
values will be put at the beginning
+                            // of sorted dictionary result.
+                            !options.nulls_first
+                        } else {
+                            options.nulls_first
+                        };
+                        let value_options = Some(SortOptions { descending: 
false, nulls_first: value_null_first });
+                        let sorted_value_indices = 
sort_to_indices(dict_values, value_options, None)?;
+                        let value_indices_map = 
prepare_indices_map(&sorted_value_indices);
+                        sort_primitive_dictionary::<_, _>(values, 
&value_indices_map, v, n, options, limit, cmp)
+                    },
+                    DataType::Int16 => {
+                        let dict_values = values.values();
+                        let value_null_first = if options.descending {
+                            !options.nulls_first
+                        } else {
+                            options.nulls_first
+                        };
+                        let value_options = Some(SortOptions { descending: 
false, nulls_first: value_null_first });
+                        let sorted_value_indices = 
sort_to_indices(dict_values, value_options, None)?;
+                        let value_indices_map = 
prepare_indices_map(&sorted_value_indices);
+                        sort_primitive_dictionary::<_, _>(values, 
&value_indices_map, v, n, options, limit, cmp)
+                    },
+                    DataType::Int32 => {
+                        let dict_values = values.values();
+                        let value_null_first = if options.descending {
+                            !options.nulls_first
+                        } else {
+                            options.nulls_first
+                        };
+                        let value_options = Some(SortOptions { descending: 
false, nulls_first: value_null_first });
+                        let sorted_value_indices = 
sort_to_indices(dict_values, value_options, None)?;
+                        let value_indices_map = 
prepare_indices_map(&sorted_value_indices);
+                        sort_primitive_dictionary::<_, _>(values, 
&value_indices_map, v, n, options, limit, cmp)
+                    },
+                    DataType::Int64 => {
+                        let dict_values = values.values();
+                        let value_null_first = if options.descending {
+                            !options.nulls_first
+                        } else {
+                            options.nulls_first
+                        };
+                        let value_options = Some(SortOptions { descending: 
false, nulls_first: value_null_first });
+                        let sorted_value_indices = 
sort_to_indices(dict_values, value_options, None)?;
+                        let value_indices_map = 
prepare_indices_map(&sorted_value_indices);
+                        sort_primitive_dictionary::<_, _>(values, 
&value_indices_map,v, n, options, limit, cmp)
+                    },
+                    DataType::UInt8 => {
+                        let dict_values = values.values();
+                        let value_null_first = if options.descending {
+                            !options.nulls_first
+                        } else {
+                            options.nulls_first
+                        };
+                        let value_options = Some(SortOptions { descending: 
false, nulls_first: value_null_first });
+                        let sorted_value_indices = 
sort_to_indices(dict_values, value_options, None)?;
+                        let value_indices_map = 
prepare_indices_map(&sorted_value_indices);
+                        sort_primitive_dictionary::<_, _>(values, 
&value_indices_map,v, n, options, limit, cmp)
+                    },
+                    DataType::UInt16 => {
+                        let dict_values = values.values();
+                        let value_null_first = if options.descending {
+                            !options.nulls_first
+                        } else {
+                            options.nulls_first
+                        };
+                        let value_options = Some(SortOptions { descending: 
false, nulls_first: value_null_first });
+                        let sorted_value_indices = 
sort_to_indices(dict_values, value_options, None)?;
+                        let value_indices_map = 
prepare_indices_map(&sorted_value_indices);
+                        sort_primitive_dictionary::<_, _>(values, 
&value_indices_map,v, n, options, limit, cmp)
+                    },
+                    DataType::UInt32 => {
+                        let dict_values = values.values();
+                        let value_null_first = if options.descending {
+                            !options.nulls_first
+                        } else {
+                            options.nulls_first
+                        };
+                        let value_options = Some(SortOptions { descending: 
false, nulls_first: value_null_first });
+                        let sorted_value_indices = 
sort_to_indices(dict_values, value_options, None)?;
+                        let value_indices_map = 
prepare_indices_map(&sorted_value_indices);
+                        sort_primitive_dictionary::<_, _>(values, 
&value_indices_map,v, n, options, limit, cmp)
+                    },
+                    DataType::UInt64 => {
+                        let dict_values = values.values();
+                        let value_null_first = if options.descending {
+                            !options.nulls_first
+                        } else {
+                            options.nulls_first
+                        };
+                        let value_options = Some(SortOptions { descending: 
false, nulls_first: value_null_first });
+                        let sorted_value_indices = 
sort_to_indices(dict_values, value_options, None)?;
+                        let value_indices_map = 
prepare_indices_map(&sorted_value_indices);
+                        sort_primitive_dictionary::<_, _>(values, 
&value_indices_map, v, n, options, limit, cmp)
+                    },
+                    DataType::Utf8 => sort_string_dictionary::<_>(values, v, 
n, &options, limit),
+                    t => return Err(ArrowError::ComputeError(format!(
+                        "Unsupported dictionary value type {}", t
+                    ))),
+                },
+                t => return Err(ArrowError::ComputeError(format!(
+                    "Unsupported datatype {}", t
+                ))),
+            )
         }
         DataType::Binary | DataType::FixedSizeBinary(_) => {
             sort_binary::<i32>(values, v, n, &options, limit)
@@ -489,7 +571,14 @@ where
         .into_iter()
         .map(|index| (index, decimal_array.value(index as usize).as_i128()))
         .collect::<Vec<(u32, i128)>>();
-    sort_primitive_inner(decimal_values, null_indices, cmp, options, limit, 
valids)
+    sort_primitive_inner(
+        decimal_values.len(),
+        null_indices,
+        cmp,
+        options,
+        limit,
+        valids,
+    )
 }
 
 /// Sort primitive values
@@ -514,12 +603,55 @@ where
             .map(|index| (index, values.value(index as usize)))
             .collect::<Vec<(u32, T::Native)>>()
     };
-    sort_primitive_inner(values, null_indices, cmp, options, limit, valids)
+    sort_primitive_inner(values.len(), null_indices, cmp, options, limit, 
valids)
+}
+
+/// Converts sorted value indices into a map from each value index to its
+/// position in the sorted order, so that position can be looked up later.
+fn prepare_indices_map(sorted_value_indices: &UInt32Array) -> HashMap<usize, 
u32> {
+    sorted_value_indices
+        .into_iter()
+        .enumerate()
+        .map(|(idx, index)| {
+            // Indices don't have None value
+            let index = index.unwrap();
+            (index as usize, idx as u32)
+        })
+        .collect::<HashMap<usize, u32>>()
+}
+
+/// Sort dictionary encoded primitive values
+fn sort_primitive_dictionary<K, F>(
+    values: &DictionaryArray<K>,
+    value_indices_map: &HashMap<usize, u32>,
+    value_indices: Vec<u32>,
+    null_indices: Vec<u32>,
+    options: SortOptions,
+    limit: Option<usize>,
+    cmp: F,
+) -> UInt32Array
+where
+    K: ArrowDictionaryKeyType,
+    F: Fn(u32, u32) -> std::cmp::Ordering,
+{
+    let keys: &PrimitiveArray<K> = values.keys();
+
+    // create tuples that are used for sorting
+    let valids = value_indices
+        .into_iter()
+        .map(|index| {
+            let key: K::Native = keys.value(index as usize);
+            let value_order = 
value_indices_map.get(&key.to_usize().unwrap()).unwrap();
+            (index, *value_order)
+        })
+        .collect::<Vec<(u32, u32)>>();
+
+    sort_primitive_inner::<_, _>(keys.len(), null_indices, cmp, &options, 
limit, valids)
 }
 
 // sort is instantiated a lot so we only compile this inner version for each 
native type
 fn sort_primitive_inner<T, F>(
-    values: &ArrayRef,
+    value_len: usize,
     null_indices: Vec<u32>,
     cmp: F,
     options: &SortOptions,
@@ -535,7 +667,7 @@ where
 
     let valids_len = valids.len();
     let nulls_len = nulls.len();
-    let mut len = values.len();
+    let mut len = value_len;
 
     if let Some(limit) = limit {
         len = limit.min(len);
@@ -620,14 +752,12 @@ fn sort_string<Offset: OffsetSizeTrait>(
 
 /// Sort dictionary encoded strings
 fn sort_string_dictionary<T: ArrowDictionaryKeyType>(
-    values: &ArrayRef,
+    values: &DictionaryArray<T>,
     value_indices: Vec<u32>,
     null_indices: Vec<u32>,
     options: &SortOptions,
     limit: Option<usize>,
 ) -> UInt32Array {
-    let values: &DictionaryArray<T> = as_dictionary_array::<T>(values);
-
     let keys: &PrimitiveArray<T> = values.keys();
 
     let dict = values.values();
@@ -1239,6 +1369,59 @@ mod tests {
         assert_eq!(sorted_strings, expected)
     }
 
+    fn test_sort_primitive_dict_arrays<K: ArrowDictionaryKeyType, T: 
ArrowPrimitiveType>(
+        keys: PrimitiveArray<K>,
+        values: PrimitiveArray<T>,
+        options: Option<SortOptions>,
+        limit: Option<usize>,
+        expected_data: Vec<Option<T::Native>>,
+    ) where
+        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
+    {
+        let array = DictionaryArray::<K>::try_new(&keys, &values).unwrap();
+        let array_values = array.values().clone();
+        let dict = array_values
+            .as_any()
+            .downcast_ref::<PrimitiveArray<T>>()
+            .expect("Unable to get dictionary values");
+
+        let sorted = match limit {
+            Some(_) => {
+                sort_limit(&(Arc::new(array) as ArrayRef), options, 
limit).unwrap()
+            }
+            _ => sort(&(Arc::new(array) as ArrayRef), options).unwrap(),
+        };
+        let sorted = sorted
+            .as_any()
+            .downcast_ref::<DictionaryArray<K>>()
+            .unwrap();
+        let sorted_values = sorted.values();
+        let sorted_dict = sorted_values
+            .as_any()
+            .downcast_ref::<PrimitiveArray<T>>()
+            .expect("Unable to get dictionary values");
+        let sorted_keys = sorted.keys();
+
+        assert_eq!(sorted_dict, dict);
+
+        let sorted_values: PrimitiveArray<T> = 
From::<Vec<Option<T::Native>>>::from(
+            (0..sorted.len())
+                .map(|i| {
+                    let key = sorted_keys.value(i).to_usize().unwrap();
+                    if sorted.is_valid(i) && sorted_dict.is_valid(key) {
+                        Some(sorted_dict.value(key))
+                    } else {
+                        None
+                    }
+                })
+                .collect::<Vec<Option<T::Native>>>(),
+        );
+        let expected: PrimitiveArray<T> =
+            From::<Vec<Option<T::Native>>>::from(expected_data);
+
+        assert_eq!(sorted_values, expected)
+    }
+
     fn test_sort_list_arrays<T>(
         data: Vec<Option<Vec<Option<T::Native>>>>,
         options: Option<SortOptions>,
@@ -3222,4 +3405,142 @@ mod tests {
         partial_sort(&mut before, last, |a, b| a.cmp(b));
         assert_eq!(&d[0..last], &before[0..last]);
     }
+
+    #[test]
+    fn test_sort_int8_dicts() {
+        let keys =
+            Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), 
Some(0)]);
+        let values = Int8Array::from(vec![1, 3, 5]);
+        test_sort_primitive_dict_arrays::<Int8Type, Int8Type>(
+            keys,
+            values,
+            None,
+            None,
+            vec![None, None, Some(1), Some(3), Some(5), Some(5)],
+        );
+
+        let keys =
+            Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), 
Some(0)]);
+        let values = Int8Array::from(vec![1, 3, 5]);
+        test_sort_primitive_dict_arrays::<Int8Type, Int8Type>(
+            keys,
+            values,
+            Some(SortOptions {
+                descending: true,
+                nulls_first: false,
+            }),
+            None,
+            vec![Some(5), Some(5), Some(3), Some(1), None, None],
+        );
+
+        let keys =
+            Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), 
Some(0)]);
+        let values = Int8Array::from(vec![1, 3, 5]);
+        test_sort_primitive_dict_arrays::<Int8Type, Int8Type>(
+            keys,
+            values,
+            Some(SortOptions {
+                descending: false,
+                nulls_first: false,
+            }),
+            None,
+            vec![Some(1), Some(3), Some(5), Some(5), None, None],
+        );
+
+        let keys =
+            Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), 
Some(0)]);
+        let values = Int8Array::from(vec![1, 3, 5]);
+        test_sort_primitive_dict_arrays::<Int8Type, Int8Type>(
+            keys,
+            values,
+            Some(SortOptions {
+                descending: true,
+                nulls_first: true,
+            }),
+            Some(3),
+            vec![None, None, Some(5)],
+        );
+
+        // Values have `None`.
+        let keys = Int8Array::from(vec![
+            Some(1_i8),
+            None,
+            Some(3),
+            None,
+            Some(2),
+            Some(3),
+            Some(0),
+        ]);
+        let values = Int8Array::from(vec![Some(1), Some(3), None, Some(5)]);
+        test_sort_primitive_dict_arrays::<Int8Type, Int8Type>(
+            keys,
+            values,
+            None,
+            None,
+            vec![None, None, None, Some(1), Some(3), Some(5), Some(5)],
+        );
+
+        let keys = Int8Array::from(vec![
+            Some(1_i8),
+            None,
+            Some(3),
+            None,
+            Some(2),
+            Some(3),
+            Some(0),
+        ]);
+        let values = Int8Array::from(vec![Some(1), Some(3), None, Some(5)]);
+        test_sort_primitive_dict_arrays::<Int8Type, Int8Type>(
+            keys,
+            values,
+            Some(SortOptions {
+                descending: false,
+                nulls_first: false,
+            }),
+            None,
+            vec![Some(1), Some(3), Some(5), Some(5), None, None, None],
+        );
+
+        let keys = Int8Array::from(vec![
+            Some(1_i8),
+            None,
+            Some(3),
+            None,
+            Some(2),
+            Some(3),
+            Some(0),
+        ]);
+        let values = Int8Array::from(vec![Some(1), Some(3), None, Some(5)]);
+        test_sort_primitive_dict_arrays::<Int8Type, Int8Type>(
+            keys,
+            values,
+            Some(SortOptions {
+                descending: true,
+                nulls_first: false,
+            }),
+            None,
+            vec![Some(5), Some(5), Some(3), Some(1), None, None, None],
+        );
+
+        let keys = Int8Array::from(vec![
+            Some(1_i8),
+            None,
+            Some(3),
+            None,
+            Some(2),
+            Some(3),
+            Some(0),
+        ]);
+        let values = Int8Array::from(vec![Some(1), Some(3), None, Some(5)]);
+        test_sort_primitive_dict_arrays::<Int8Type, Int8Type>(
+            keys,
+            values,
+            Some(SortOptions {
+                descending: true,
+                nulls_first: true,
+            }),
+            None,
+            vec![None, None, None, Some(5), Some(5), Some(3), Some(1)],
+        );
+    }
 }

Reply via email to