This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 2a0fc7703 Add support of sorting dictionary of other primitive arrays 
(#2701)
2a0fc7703 is described below

commit 2a0fc7703420f99d28141516cabdd0408a583dfc
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Sep 14 14:05:58 2022 -0700

    Add support of sorting dictionary of other primitive arrays (#2701)
    
    * Add support of sorting dictionary of other primitive arrays
    
    * Collapse match statements
    
    * Add one helper to match primitive types
---
 arrow/src/compute/kernels/sort.rs | 259 +++++++++++++++++++++++---------------
 arrow/src/datatypes/datatype.rs   |  24 ++++
 2 files changed, 179 insertions(+), 104 deletions(-)

diff --git a/arrow/src/compute/kernels/sort.rs 
b/arrow/src/compute/kernels/sort.rs
index 34a321910..0bc2d3948 100644
--- a/arrow/src/compute/kernels/sort.rs
+++ b/arrow/src/compute/kernels/sort.rs
@@ -314,119 +314,32 @@ pub fn sort_to_indices(
             }
         },
         DataType::Dictionary(_, _) => {
+            let value_null_first = if options.descending {
+                // When sorting dictionary in descending order, we take 
inverse of null ordering
+                // when sorting the values. Because if `nulls_first` is true, 
null must be in front
+                // of non-null value. As we take the sorted order of value 
array to sort dictionary
+                // keys, these null values will be treated as smallest ones 
and be sorted to the end
+                // of sorted result. So we set `nulls_first` to false when 
sorting dictionary value
+                // array to make them as largest ones, then null values will 
be put at the beginning
+                // of sorted dictionary result.
+                !options.nulls_first
+            } else {
+                options.nulls_first
+            };
+            let value_options = Some(SortOptions {
+                descending: false,
+                nulls_first: value_null_first,
+            });
             downcast_dictionary_array!(
                 values => match values.values().data_type() {
-                    DataType::Int8 => {
-                        let dict_values = values.values();
-                        let value_null_first = if options.descending {
-                            // When sorting dictionary in descending order, we 
take inverse of of null ordering
-                            // when sorting the values. Because if 
`nulls_first` is true, null must be in front
-                            // of non-null value. As we take the sorted order 
of value array to sort dictionary
-                            // keys, these null values will be treated as 
smallest ones and be sorted to the end
-                            // of sorted result. So we set `nulls_first` to 
false when sorting dictionary value
-                            // array to make them as largest ones, then null 
values will be put at the beginning
-                            // of sorted dictionary result.
-                            !options.nulls_first
-                        } else {
-                            options.nulls_first
-                        };
-                        let value_options = Some(SortOptions { descending: 
false, nulls_first: value_null_first });
-                        let sorted_value_indices = 
sort_to_indices(dict_values, value_options, None)?;
-                        let value_indices_map = 
prepare_indices_map(&sorted_value_indices);
-                        sort_primitive_dictionary::<_, _>(values, 
&value_indices_map, v, n, options, limit, cmp)
-                    },
-                    DataType::Int16 => {
+                    dt if DataType::is_primitive(dt) => {
                         let dict_values = values.values();
-                        let value_null_first = if options.descending {
-                            !options.nulls_first
-                        } else {
-                            options.nulls_first
-                        };
-                        let value_options = Some(SortOptions { descending: 
false, nulls_first: value_null_first });
-                        let sorted_value_indices = 
sort_to_indices(dict_values, value_options, None)?;
-                        let value_indices_map = 
prepare_indices_map(&sorted_value_indices);
-                        sort_primitive_dictionary::<_, _>(values, 
&value_indices_map, v, n, options, limit, cmp)
-                    },
-                    DataType::Int32 => {
-                        let dict_values = values.values();
-                        let value_null_first = if options.descending {
-                            !options.nulls_first
-                        } else {
-                            options.nulls_first
-                        };
-                        let value_options = Some(SortOptions { descending: 
false, nulls_first: value_null_first });
-                        let sorted_value_indices = 
sort_to_indices(dict_values, value_options, None)?;
-                        let value_indices_map = 
prepare_indices_map(&sorted_value_indices);
-                        sort_primitive_dictionary::<_, _>(values, 
&value_indices_map, v, n, options, limit, cmp)
-                    },
-                    DataType::Int64 => {
-                        let dict_values = values.values();
-                        let value_null_first = if options.descending {
-                            !options.nulls_first
-                        } else {
-                            options.nulls_first
-                        };
-                        let value_options = Some(SortOptions { descending: 
false, nulls_first: value_null_first });
-                        let sorted_value_indices = 
sort_to_indices(dict_values, value_options, None)?;
-                        let value_indices_map = 
prepare_indices_map(&sorted_value_indices);
-                        sort_primitive_dictionary::<_, _>(values, 
&value_indices_map,v, n, options, limit, cmp)
-                    },
-                    DataType::UInt8 => {
-                        let dict_values = values.values();
-                        let value_null_first = if options.descending {
-                            !options.nulls_first
-                        } else {
-                            options.nulls_first
-                        };
-                        let value_options = Some(SortOptions { descending: 
false, nulls_first: value_null_first });
-                        let sorted_value_indices = 
sort_to_indices(dict_values, value_options, None)?;
-                        let value_indices_map = 
prepare_indices_map(&sorted_value_indices);
-                        sort_primitive_dictionary::<_, _>(values, 
&value_indices_map,v, n, options, limit, cmp)
-                    },
-                    DataType::UInt16 => {
-                        let dict_values = values.values();
-                        let value_null_first = if options.descending {
-                            !options.nulls_first
-                        } else {
-                            options.nulls_first
-                        };
-                        let value_options = Some(SortOptions { descending: 
false, nulls_first: value_null_first });
-                        let sorted_value_indices = 
sort_to_indices(dict_values, value_options, None)?;
-                        let value_indices_map = 
prepare_indices_map(&sorted_value_indices);
-                        sort_primitive_dictionary::<_, _>(values, 
&value_indices_map,v, n, options, limit, cmp)
-                    },
-                    DataType::UInt32 => {
-                        let dict_values = values.values();
-                        let value_null_first = if options.descending {
-                            !options.nulls_first
-                        } else {
-                            options.nulls_first
-                        };
-                        let value_options = Some(SortOptions { descending: 
false, nulls_first: value_null_first });
-                        let sorted_value_indices = 
sort_to_indices(dict_values, value_options, None)?;
-                        let value_indices_map = 
prepare_indices_map(&sorted_value_indices);
-                        sort_primitive_dictionary::<_, _>(values, 
&value_indices_map,v, n, options, limit, cmp)
-                    },
-                    DataType::UInt64 => {
-                        let dict_values = values.values();
-                        let value_null_first = if options.descending {
-                            !options.nulls_first
-                        } else {
-                            options.nulls_first
-                        };
-                        let value_options = Some(SortOptions { descending: 
false, nulls_first: value_null_first });
                         let sorted_value_indices = 
sort_to_indices(dict_values, value_options, None)?;
                         let value_indices_map = 
prepare_indices_map(&sorted_value_indices);
                         sort_primitive_dictionary::<_, _>(values, 
&value_indices_map, v, n, options, limit, cmp)
                     },
                     DataType::Utf8 => {
                         let dict_values = values.values();
-                        let value_null_first = if options.descending {
-                            !options.nulls_first
-                        } else {
-                            options.nulls_first
-                        };
-                        let value_options = Some(SortOptions { descending: 
false, nulls_first: value_null_first });
                         let sorted_value_indices = 
sort_to_indices(dict_values, value_options, None)?;
                         let value_indices_map = 
prepare_indices_map(&sorted_value_indices);
                         sort_string_dictionary::<_>(values, 
&value_indices_map, v, n, &options, limit)
@@ -3552,4 +3465,142 @@ mod tests {
             vec![None, None, None, Some(5), Some(5), Some(3), Some(1)],
         );
     }
+
+    #[test]
+    fn test_sort_f32_dicts() {
+        let keys =
+            Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), 
Some(0)]);
+        let values = Float32Array::from(vec![1.2, 3.0, 5.1]);
+        test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
+            keys,
+            values,
+            None,
+            None,
+            vec![None, None, Some(1.2), Some(3.0), Some(5.1), Some(5.1)],
+        );
+
+        let keys =
+            Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), 
Some(0)]);
+        let values = Float32Array::from(vec![1.2, 3.0, 5.1]);
+        test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
+            keys,
+            values,
+            Some(SortOptions {
+                descending: true,
+                nulls_first: false,
+            }),
+            None,
+            vec![Some(5.1), Some(5.1), Some(3.0), Some(1.2), None, None],
+        );
+
+        let keys =
+            Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), 
Some(0)]);
+        let values = Float32Array::from(vec![1.2, 3.0, 5.1]);
+        test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
+            keys,
+            values,
+            Some(SortOptions {
+                descending: false,
+                nulls_first: false,
+            }),
+            None,
+            vec![Some(1.2), Some(3.0), Some(5.1), Some(5.1), None, None],
+        );
+
+        let keys =
+            Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), 
Some(0)]);
+        let values = Float32Array::from(vec![1.2, 3.0, 5.1]);
+        test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
+            keys,
+            values,
+            Some(SortOptions {
+                descending: true,
+                nulls_first: true,
+            }),
+            Some(3),
+            vec![None, None, Some(5.1)],
+        );
+
+        // Values have `None`.
+        let keys = Int8Array::from(vec![
+            Some(1_i8),
+            None,
+            Some(3),
+            None,
+            Some(2),
+            Some(3),
+            Some(0),
+        ]);
+        let values = Float32Array::from(vec![Some(1.2), Some(3.0), None, 
Some(5.1)]);
+        test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
+            keys,
+            values,
+            None,
+            None,
+            vec![None, None, None, Some(1.2), Some(3.0), Some(5.1), Some(5.1)],
+        );
+
+        let keys = Int8Array::from(vec![
+            Some(1_i8),
+            None,
+            Some(3),
+            None,
+            Some(2),
+            Some(3),
+            Some(0),
+        ]);
+        let values = Float32Array::from(vec![Some(1.2), Some(3.0), None, 
Some(5.1)]);
+        test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
+            keys,
+            values,
+            Some(SortOptions {
+                descending: false,
+                nulls_first: false,
+            }),
+            None,
+            vec![Some(1.2), Some(3.0), Some(5.1), Some(5.1), None, None, None],
+        );
+
+        let keys = Int8Array::from(vec![
+            Some(1_i8),
+            None,
+            Some(3),
+            None,
+            Some(2),
+            Some(3),
+            Some(0),
+        ]);
+        let values = Float32Array::from(vec![Some(1.2), Some(3.0), None, 
Some(5.1)]);
+        test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
+            keys,
+            values,
+            Some(SortOptions {
+                descending: true,
+                nulls_first: false,
+            }),
+            None,
+            vec![Some(5.1), Some(5.1), Some(3.0), Some(1.2), None, None, None],
+        );
+
+        let keys = Int8Array::from(vec![
+            Some(1_i8),
+            None,
+            Some(3),
+            None,
+            Some(2),
+            Some(3),
+            Some(0),
+        ]);
+        let values = Float32Array::from(vec![Some(1.2), Some(3.0), None, 
Some(5.1)]);
+        test_sort_primitive_dict_arrays::<Int8Type, Float32Type>(
+            keys,
+            values,
+            Some(SortOptions {
+                descending: true,
+                nulls_first: true,
+            }),
+            None,
+            vec![None, None, None, Some(5.1), Some(5.1), Some(3.0), Some(1.2)],
+        );
+    }
 }
diff --git a/arrow/src/datatypes/datatype.rs b/arrow/src/datatypes/datatype.rs
index 2ca71ef77..d3189b8b1 100644
--- a/arrow/src/datatypes/datatype.rs
+++ b/arrow/src/datatypes/datatype.rs
@@ -1070,6 +1070,30 @@ impl DataType {
         )
     }
 
+    /// Returns true if the type is primitive, i.e. an integer, `Float32`/`Float64`, or temporal type.
+    pub fn is_primitive(t: &DataType) -> bool {
+        use DataType::*;
+        matches!(
+            t,
+            Int8 | Int16
+                | Int32
+                | Int64
+                | UInt8
+                | UInt16
+                | UInt32
+                | UInt64
+                | Float32
+                | Float64
+                | Date32
+                | Date64
+                | Time32(_)
+                | Time64(_)
+                | Timestamp(_, _)
+                | Interval(_)
+                | Duration(_)
+        )
+    }
+
     /// Returns true if this type is temporal: (Date*, Time*, Duration, or 
Interval).
     pub fn is_temporal(t: &DataType) -> bool {
         use DataType::*;

Reply via email to