This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new e9947d9  Support DecimalType in sort and take kernels (#1172)
e9947d9 is described below

commit e9947d944934e108f9ffa8b14eefaead0461930a
Author: Kun Liu <[email protected]>
AuthorDate: Mon Jan 17 23:42:45 2022 +0800

    Support DecimalType in sort and take kernels (#1172)
---
 arrow/src/compute/kernels/sort.rs | 234 +++++++++++++++++++++++++++++++++++++-
 arrow/src/compute/kernels/take.rs | 105 +++++++++++++++++
 2 files changed, 334 insertions(+), 5 deletions(-)

diff --git a/arrow/src/compute/kernels/sort.rs b/arrow/src/compute/kernels/sort.rs
index 5aa1fdf..6549477 100644
--- a/arrow/src/compute/kernels/sort.rs
+++ b/arrow/src/compute/kernels/sort.rs
@@ -170,6 +170,7 @@ pub fn sort_to_indices(
     let (v, n) = partition_validity(values);
 
     Ok(match values.data_type() {
+        DataType::Decimal(_, _) => sort_decimal(values, v, n, cmp, &options, limit),
         DataType::Boolean => sort_boolean(values, v, n, &options, limit),
         DataType::Int8 => {
             sort_primitive::<Int8Type, _>(values, v, n, cmp, &options, limit)
@@ -293,7 +294,7 @@ pub fn sort_to_indices(
                 return Err(ArrowError::ComputeError(format!(
                     "Sort not supported for list type {:?}",
                     t
-                )))
+                )));
             }
         },
         DataType::LargeList(field) => match field.data_type() {
@@ -321,7 +322,7 @@ pub fn sort_to_indices(
                 return Err(ArrowError::ComputeError(format!(
                     "Sort not supported for list type {:?}",
                     t
-                )))
+                )));
             }
         },
         DataType::FixedSizeList(field, _) => match field.data_type() {
@@ -349,7 +350,7 @@ pub fn sort_to_indices(
                 return Err(ArrowError::ComputeError(format!(
                     "Sort not supported for list type {:?}",
                     t
-                )))
+                )));
             }
         },
         DataType::Dictionary(key_type, value_type)
@@ -384,7 +385,7 @@ pub fn sort_to_indices(
                     return Err(ArrowError::ComputeError(format!(
                         "Sort not supported for dictionary key type {:?}",
                         t
-                    )))
+                    )));
                 }
             }
         }
@@ -396,7 +397,7 @@ pub fn sort_to_indices(
             return Err(ArrowError::ComputeError(format!(
                 "Sort not supported for data type {:?}",
                 t
-            )))
+            )));
         }
     })
 }
@@ -509,6 +510,30 @@ fn sort_boolean(
     UInt32Array::from(result_data)
 }
 
+/// Sort Decimal array
+fn sort_decimal<F>(
+    decimal_values: &ArrayRef,
+    value_indices: Vec<u32>,
+    null_indices: Vec<u32>,
+    cmp: F,
+    options: &SortOptions,
+    limit: Option<usize>,
+) -> UInt32Array
+where
+    F: Fn(i128, i128) -> std::cmp::Ordering,
+{
+    // downcast to decimal array
+    let decimal_array = decimal_values
+        .as_any()
+        .downcast_ref::<DecimalArray>()
+        .expect("Unable to downcast to decimal array");
+    let valids = value_indices
+        .into_iter()
+        .map(|index| (index, decimal_array.value(index as usize)))
+        .collect::<Vec<(u32, i128)>>();
+    sort_primitive_inner(decimal_values, null_indices, cmp, options, limit, valids)
+}
+
 /// Sort primitive values
 fn sort_primitive<T, F>(
     values: &ArrayRef,
@@ -1080,6 +1105,49 @@ mod tests {
     use std::convert::TryFrom;
     use std::sync::Arc;
 
+    fn create_decimal_array(data: &[Option<i128>]) -> DecimalArray {
+        let mut builder = DecimalBuilder::new(20, 23, 6);
+
+        for d in data {
+            if let Some(v) = d {
+                builder.append_value(*v).unwrap();
+            } else {
+                builder.append_null().unwrap();
+            }
+        }
+        builder.finish()
+    }
+
+    fn test_sort_to_indices_decimal_array(
+        data: Vec<Option<i128>>,
+        options: Option<SortOptions>,
+        limit: Option<usize>,
+        expected_data: Vec<u32>,
+    ) {
+        let output = create_decimal_array(&data);
+        let expected = UInt32Array::from(expected_data);
+        let output =
+            sort_to_indices(&(Arc::new(output) as ArrayRef), options, limit).unwrap();
+        assert_eq!(output, expected)
+    }
+
+    fn test_sort_decimal_array(
+        data: Vec<Option<i128>>,
+        options: Option<SortOptions>,
+        limit: Option<usize>,
+        expected_data: Vec<Option<i128>>,
+    ) {
+        let output = create_decimal_array(&data);
+        let expected = Arc::new(create_decimal_array(&expected_data)) as ArrayRef;
+        let output = match limit {
+            Some(_) => {
+                sort_limit(&(Arc::new(output) as ArrayRef), options, limit).unwrap()
+            }
+            _ => sort(&(Arc::new(output) as ArrayRef), options).unwrap(),
+        };
+        assert_eq!(&output, &expected)
+    }
+
     fn test_sort_to_indices_boolean_arrays(
         data: Vec<Option<bool>>,
         options: Option<SortOptions>,
@@ -1660,6 +1728,162 @@ mod tests {
     }
 
     #[test]
+    fn test_sort_indices_decimal128() {
+        // decimal default
+        test_sort_to_indices_decimal_array(
+            vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+            None,
+            None,
+            vec![0, 6, 4, 2, 3, 5, 1],
+        );
+        // decimal descending
+        test_sort_to_indices_decimal_array(
+            vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+            Some(SortOptions {
+                descending: true,
+                nulls_first: false,
+            }),
+            None,
+            vec![1, 5, 3, 2, 4, 6, 0],
+        );
+        // decimal null_first and descending
+        test_sort_to_indices_decimal_array(
+            vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+            Some(SortOptions {
+                descending: true,
+                nulls_first: true,
+            }),
+            None,
+            vec![6, 0, 1, 5, 3, 2, 4],
+        );
+        // decimal null_first
+        test_sort_to_indices_decimal_array(
+            vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+            Some(SortOptions {
+                descending: false,
+                nulls_first: true,
+            }),
+            None,
+            vec![0, 6, 4, 2, 3, 5, 1],
+        );
+        // limit
+        test_sort_to_indices_decimal_array(
+            vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+            None,
+            Some(3),
+            vec![0, 6, 4],
+        );
+        // limit descending
+        test_sort_to_indices_decimal_array(
+            vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+            Some(SortOptions {
+                descending: true,
+                nulls_first: false,
+            }),
+            Some(3),
+            vec![1, 5, 3],
+        );
+        // limit descending null_first
+        test_sort_to_indices_decimal_array(
+            vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+            Some(SortOptions {
+                descending: true,
+                nulls_first: true,
+            }),
+            Some(3),
+            vec![6, 0, 1],
+        );
+        // limit null_first
+        test_sort_to_indices_decimal_array(
+            vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+            Some(SortOptions {
+                descending: false,
+                nulls_first: true,
+            }),
+            Some(3),
+            vec![0, 6, 4],
+        );
+    }
+
+    #[test]
+    fn test_sort_decimal128() {
+        // decimal default
+        test_sort_decimal_array(
+            vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+            None,
+            None,
+            vec![None, None, Some(1), Some(2), Some(3), Some(4), Some(5)],
+        );
+        // decimal descending
+        test_sort_decimal_array(
+            vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+            Some(SortOptions {
+                descending: true,
+                nulls_first: false,
+            }),
+            None,
+            vec![Some(5), Some(4), Some(3), Some(2), Some(1), None, None],
+        );
+        // decimal null_first and descending
+        test_sort_decimal_array(
+            vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+            Some(SortOptions {
+                descending: true,
+                nulls_first: true,
+            }),
+            None,
+            vec![None, None, Some(5), Some(4), Some(3), Some(2), Some(1)],
+        );
+        // decimal null_first
+        test_sort_decimal_array(
+            vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+            Some(SortOptions {
+                descending: false,
+                nulls_first: true,
+            }),
+            None,
+            vec![None, None, Some(1), Some(2), Some(3), Some(4), Some(5)],
+        );
+        // limit
+        test_sort_decimal_array(
+            vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+            None,
+            Some(3),
+            vec![None, None, Some(1)],
+        );
+        // limit descending
+        test_sort_decimal_array(
+            vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+            Some(SortOptions {
+                descending: true,
+                nulls_first: false,
+            }),
+            Some(3),
+            vec![Some(5), Some(4), Some(3)],
+        );
+        // limit descending null_first
+        test_sort_decimal_array(
+            vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+            Some(SortOptions {
+                descending: true,
+                nulls_first: true,
+            }),
+            Some(3),
+            vec![None, None, Some(5)],
+        );
+        // limit null_first
+        test_sort_decimal_array(
+            vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+            Some(SortOptions {
+                descending: false,
+                nulls_first: true,
+            }),
+            Some(3),
+            vec![None, None, Some(1)],
+        );
+    }
+
+    #[test]
     fn test_sort_primitives() {
         // default case
         test_sort_primitive_arrays::<UInt8Type>(
diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs
index 4e14fac..a3b5283 100644
--- a/arrow/src/compute/kernels/take.rs
+++ b/arrow/src/compute/kernels/take.rs
@@ -131,6 +131,10 @@ where
             let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
             Ok(Arc::new(take_boolean(values, indices)?))
         }
+        DataType::Decimal(_, _) => {
+            let decimal_values = values.as_any().downcast_ref::<DecimalArray>().unwrap();
+            Ok(Arc::new(take_decimal128(decimal_values, indices)?))
+        }
         DataType::Int8 => downcast_take!(Int8Type, values, indices),
         DataType::Int16 => downcast_take!(Int16Type, values, indices),
         DataType::Int32 => downcast_take!(Int32Type, values, indices),
@@ -483,6 +487,38 @@ where
     Ok((buffer, nulls))
 }
 
+/// `take` implementation for decimal arrays
+fn take_decimal128<IndexType>(
+    decimal_values: &DecimalArray,
+    indices: &PrimitiveArray<IndexType>,
+) -> Result<DecimalArray>
+where
+    IndexType: ArrowNumericType,
+    IndexType::Native: ToPrimitive,
+{
+    // TODO optimize decimal take and construct decimal array from MutableBuffer
+    let mut builder = DecimalBuilder::new(
+        indices.len(),
+        decimal_values.precision(),
+        decimal_values.scale(),
+    );
+    for i in 0..indices.len() {
+        if indices.is_null(i) {
+            builder.append_null()?;
+        } else {
+            let index = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| {
+                ArrowError::ComputeError("Cast to usize failed".to_string())
+            })?;
+            if decimal_values.is_null(index) {
+                builder.append_null()?
+            } else {
+                builder.append_value(decimal_values.value(index))?
+            }
+        }
+    }
+    Ok(builder.finish())
+}
+
 /// `take` implementation for all primitive arrays
 ///
 /// This checks if an `indices` slot is populated, and gets the value from `values`
@@ -921,6 +957,43 @@ mod tests {
     use super::*;
     use crate::compute::util::tests::build_fixed_size_list_nullable;
 
+    fn test_take_decimal_arrays(
+        data: Vec<Option<i128>>,
+        index: &UInt32Array,
+        options: Option<TakeOptions>,
+        expected_data: Vec<Option<i128>>,
+        precision: &usize,
+        scale: &usize,
+    ) -> Result<()> {
+        let mut builder = DecimalBuilder::new(data.len(), *precision, *scale);
+        for value in data {
+            match value {
+                None => {
+                    builder.append_null()?;
+                }
+                Some(v) => {
+                    builder.append_value(v)?;
+                }
+            }
+        }
+        let output = builder.finish();
+        let mut builder = DecimalBuilder::new(expected_data.len(), *precision, *scale);
+        for value in expected_data {
+            match value {
+                None => {
+                    builder.append_null()?;
+                }
+                Some(v) => {
+                    builder.append_value(v)?;
+                }
+            }
+        }
+        let expected = Arc::new(builder.finish()) as ArrayRef;
+        let output = take(&output, index, options).unwrap();
+        assert_eq!(&output, &expected);
+        Ok(())
+    }
+
     fn test_take_boolean_arrays(
         data: Vec<Option<bool>>,
         index: &UInt32Array,
@@ -1018,6 +1091,38 @@ mod tests {
     }
 
     #[test]
+    fn test_take_decimal128_non_null_indices() {
+        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
+        let precision: usize = 10;
+        let scale: usize = 5;
+        test_take_decimal_arrays(
+            vec![None, Some(3), Some(5), Some(2), Some(3), None],
+            &index,
+            None,
+            vec![None, None, Some(2), Some(3), Some(3), Some(5)],
+            &precision,
+            &scale,
+        )
+        .unwrap();
+    }
+
+    #[test]
+    fn test_take_decimal128() {
+        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
+        let precision: usize = 10;
+        let scale: usize = 5;
+        test_take_decimal_arrays(
+            vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
+            &index,
+            None,
+            vec![Some(3), None, Some(1), Some(3), Some(2)],
+            &precision,
+            &scale,
+        )
+        .unwrap();
+    }
+
+    #[test]
     fn test_take_primitive_non_null_indices() {
         let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
         test_take_primitive_arrays::<Int8Type>(

Reply via email to