This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new e9947d9 Support DecimalType in sort and take kernels (#1172)
e9947d9 is described below
commit e9947d944934e108f9ffa8b14eefaead0461930a
Author: Kun Liu <[email protected]>
AuthorDate: Mon Jan 17 23:42:45 2022 +0800
Support DecimalType in sort and take kernels (#1172)
---
arrow/src/compute/kernels/sort.rs | 234 +++++++++++++++++++++++++++++++++++++-
arrow/src/compute/kernels/take.rs | 105 +++++++++++++++++
2 files changed, 334 insertions(+), 5 deletions(-)
diff --git a/arrow/src/compute/kernels/sort.rs b/arrow/src/compute/kernels/sort.rs
index 5aa1fdf..6549477 100644
--- a/arrow/src/compute/kernels/sort.rs
+++ b/arrow/src/compute/kernels/sort.rs
@@ -170,6 +170,7 @@ pub fn sort_to_indices(
let (v, n) = partition_validity(values);
Ok(match values.data_type() {
+ DataType::Decimal(_, _) => sort_decimal(values, v, n, cmp, &options, limit),
DataType::Boolean => sort_boolean(values, v, n, &options, limit),
DataType::Int8 => {
sort_primitive::<Int8Type, _>(values, v, n, cmp, &options, limit)
@@ -293,7 +294,7 @@ pub fn sort_to_indices(
return Err(ArrowError::ComputeError(format!(
"Sort not supported for list type {:?}",
t
- )))
+ )));
}
},
DataType::LargeList(field) => match field.data_type() {
@@ -321,7 +322,7 @@ pub fn sort_to_indices(
return Err(ArrowError::ComputeError(format!(
"Sort not supported for list type {:?}",
t
- )))
+ )));
}
},
DataType::FixedSizeList(field, _) => match field.data_type() {
@@ -349,7 +350,7 @@ pub fn sort_to_indices(
return Err(ArrowError::ComputeError(format!(
"Sort not supported for list type {:?}",
t
- )))
+ )));
}
},
DataType::Dictionary(key_type, value_type)
@@ -384,7 +385,7 @@ pub fn sort_to_indices(
return Err(ArrowError::ComputeError(format!(
"Sort not supported for dictionary key type {:?}",
t
- )))
+ )));
}
}
}
@@ -396,7 +397,7 @@ pub fn sort_to_indices(
return Err(ArrowError::ComputeError(format!(
"Sort not supported for data type {:?}",
t
- )))
+ )));
}
})
}
@@ -509,6 +510,30 @@ fn sort_boolean(
UInt32Array::from(result_data)
}
+/// Sort Decimal array
+fn sort_decimal<F>(
+ decimal_values: &ArrayRef,
+ value_indices: Vec<u32>,
+ null_indices: Vec<u32>,
+ cmp: F,
+ options: &SortOptions,
+ limit: Option<usize>,
+) -> UInt32Array
+where
+ F: Fn(i128, i128) -> std::cmp::Ordering,
+{
+ // downcast to decimal array
+ let decimal_array = decimal_values
+ .as_any()
+ .downcast_ref::<DecimalArray>()
+ .expect("Unable to downcast to decimal array");
+ let valids = value_indices
+ .into_iter()
+ .map(|index| (index, decimal_array.value(index as usize)))
+ .collect::<Vec<(u32, i128)>>();
+ sort_primitive_inner(decimal_values, null_indices, cmp, options, limit, valids)
+}
+
/// Sort primitive values
fn sort_primitive<T, F>(
values: &ArrayRef,
@@ -1080,6 +1105,49 @@ mod tests {
use std::convert::TryFrom;
use std::sync::Arc;
+ fn create_decimal_array(data: &[Option<i128>]) -> DecimalArray {
+ let mut builder = DecimalBuilder::new(20, 23, 6);
+
+ for d in data {
+ if let Some(v) = d {
+ builder.append_value(*v).unwrap();
+ } else {
+ builder.append_null().unwrap();
+ }
+ }
+ builder.finish()
+ }
+
+ fn test_sort_to_indices_decimal_array(
+ data: Vec<Option<i128>>,
+ options: Option<SortOptions>,
+ limit: Option<usize>,
+ expected_data: Vec<u32>,
+ ) {
+ let output = create_decimal_array(&data);
+ let expected = UInt32Array::from(expected_data);
+ let output =
+ sort_to_indices(&(Arc::new(output) as ArrayRef), options, limit).unwrap();
+ assert_eq!(output, expected)
+ }
+
+ fn test_sort_decimal_array(
+ data: Vec<Option<i128>>,
+ options: Option<SortOptions>,
+ limit: Option<usize>,
+ expected_data: Vec<Option<i128>>,
+ ) {
+ let output = create_decimal_array(&data);
+ let expected = Arc::new(create_decimal_array(&expected_data)) as ArrayRef;
+ let output = match limit {
+ Some(_) => {
+ sort_limit(&(Arc::new(output) as ArrayRef), options, limit).unwrap()
+ }
+ _ => sort(&(Arc::new(output) as ArrayRef), options).unwrap(),
+ };
+ assert_eq!(&output, &expected)
+ }
+
fn test_sort_to_indices_boolean_arrays(
data: Vec<Option<bool>>,
options: Option<SortOptions>,
@@ -1660,6 +1728,162 @@ mod tests {
}
#[test]
+ fn test_sort_indices_decimal128() {
+ // decimal default
+ test_sort_to_indices_decimal_array(
+ vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+ None,
+ None,
+ vec![0, 6, 4, 2, 3, 5, 1],
+ );
+ // decimal descending
+ test_sort_to_indices_decimal_array(
+ vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+ Some(SortOptions {
+ descending: true,
+ nulls_first: false,
+ }),
+ None,
+ vec![1, 5, 3, 2, 4, 6, 0],
+ );
+ // decimal null_first and descending
+ test_sort_to_indices_decimal_array(
+ vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+ Some(SortOptions {
+ descending: true,
+ nulls_first: true,
+ }),
+ None,
+ vec![6, 0, 1, 5, 3, 2, 4],
+ );
+ // decimal null_first
+ test_sort_to_indices_decimal_array(
+ vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: true,
+ }),
+ None,
+ vec![0, 6, 4, 2, 3, 5, 1],
+ );
+ // limit
+ test_sort_to_indices_decimal_array(
+ vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+ None,
+ Some(3),
+ vec![0, 6, 4],
+ );
+ // limit descending
+ test_sort_to_indices_decimal_array(
+ vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+ Some(SortOptions {
+ descending: true,
+ nulls_first: false,
+ }),
+ Some(3),
+ vec![1, 5, 3],
+ );
+ // limit descending null_first
+ test_sort_to_indices_decimal_array(
+ vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+ Some(SortOptions {
+ descending: true,
+ nulls_first: true,
+ }),
+ Some(3),
+ vec![6, 0, 1],
+ );
+ // limit null_first
+ test_sort_to_indices_decimal_array(
+ vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: true,
+ }),
+ Some(3),
+ vec![0, 6, 4],
+ );
+ }
+
+ #[test]
+ fn test_sort_decimal128() {
+ // decimal default
+ test_sort_decimal_array(
+ vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+ None,
+ None,
+ vec![None, None, Some(1), Some(2), Some(3), Some(4), Some(5)],
+ );
+ // decimal descending
+ test_sort_decimal_array(
+ vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+ Some(SortOptions {
+ descending: true,
+ nulls_first: false,
+ }),
+ None,
+ vec![Some(5), Some(4), Some(3), Some(2), Some(1), None, None],
+ );
+ // decimal null_first and descending
+ test_sort_decimal_array(
+ vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+ Some(SortOptions {
+ descending: true,
+ nulls_first: true,
+ }),
+ None,
+ vec![None, None, Some(5), Some(4), Some(3), Some(2), Some(1)],
+ );
+ // decimal null_first
+ test_sort_decimal_array(
+ vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: true,
+ }),
+ None,
+ vec![None, None, Some(1), Some(2), Some(3), Some(4), Some(5)],
+ );
+ // limit
+ test_sort_decimal_array(
+ vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+ None,
+ Some(3),
+ vec![None, None, Some(1)],
+ );
+ // limit descending
+ test_sort_decimal_array(
+ vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+ Some(SortOptions {
+ descending: true,
+ nulls_first: false,
+ }),
+ Some(3),
+ vec![Some(5), Some(4), Some(3)],
+ );
+ // limit descending null_first
+ test_sort_decimal_array(
+ vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+ Some(SortOptions {
+ descending: true,
+ nulls_first: true,
+ }),
+ Some(3),
+ vec![None, None, Some(5)],
+ );
+ // limit null_first
+ test_sort_decimal_array(
+ vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None],
+ Some(SortOptions {
+ descending: false,
+ nulls_first: true,
+ }),
+ Some(3),
+ vec![None, None, Some(1)],
+ );
+ }
+
+ #[test]
fn test_sort_primitives() {
// default case
test_sort_primitive_arrays::<UInt8Type>(
diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs
index 4e14fac..a3b5283 100644
--- a/arrow/src/compute/kernels/take.rs
+++ b/arrow/src/compute/kernels/take.rs
@@ -131,6 +131,10 @@ where
let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
Ok(Arc::new(take_boolean(values, indices)?))
}
+ DataType::Decimal(_, _) => {
+ let decimal_values = values.as_any().downcast_ref::<DecimalArray>().unwrap();
+ Ok(Arc::new(take_decimal128(decimal_values, indices)?))
+ }
DataType::Int8 => downcast_take!(Int8Type, values, indices),
DataType::Int16 => downcast_take!(Int16Type, values, indices),
DataType::Int32 => downcast_take!(Int32Type, values, indices),
@@ -483,6 +487,38 @@ where
Ok((buffer, nulls))
}
+/// `take` implementation for decimal arrays
+fn take_decimal128<IndexType>(
+ decimal_values: &DecimalArray,
+ indices: &PrimitiveArray<IndexType>,
+) -> Result<DecimalArray>
+where
+ IndexType: ArrowNumericType,
+ IndexType::Native: ToPrimitive,
+{
+ // TODO optimize decimal take and construct decimal array from MutableBuffer
+ let mut builder = DecimalBuilder::new(
+ indices.len(),
+ decimal_values.precision(),
+ decimal_values.scale(),
+ );
+ for i in 0..indices.len() {
+ if indices.is_null(i) {
+ builder.append_null()?;
+ } else {
+ let index = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| {
+ ArrowError::ComputeError("Cast to usize failed".to_string())
+ })?;
+ if decimal_values.is_null(index) {
+ builder.append_null()?
+ } else {
+ builder.append_value(decimal_values.value(index))?
+ }
+ }
+ }
+ Ok(builder.finish())
+}
+
/// `take` implementation for all primitive arrays
///
/// This checks if an `indices` slot is populated, and gets the value from `values`
@@ -921,6 +957,43 @@ mod tests {
use super::*;
use crate::compute::util::tests::build_fixed_size_list_nullable;
+ fn test_take_decimal_arrays(
+ data: Vec<Option<i128>>,
+ index: &UInt32Array,
+ options: Option<TakeOptions>,
+ expected_data: Vec<Option<i128>>,
+ precision: &usize,
+ scale: &usize,
+ ) -> Result<()> {
+ let mut builder = DecimalBuilder::new(data.len(), *precision, *scale);
+ for value in data {
+ match value {
+ None => {
+ builder.append_null()?;
+ }
+ Some(v) => {
+ builder.append_value(v)?;
+ }
+ }
+ }
+ let output = builder.finish();
+ let mut builder = DecimalBuilder::new(expected_data.len(), *precision, *scale);
+ for value in expected_data {
+ match value {
+ None => {
+ builder.append_null()?;
+ }
+ Some(v) => {
+ builder.append_value(v)?;
+ }
+ }
+ }
+ let expected = Arc::new(builder.finish()) as ArrayRef;
+ let output = take(&output, index, options).unwrap();
+ assert_eq!(&output, &expected);
+ Ok(())
+ }
+
fn test_take_boolean_arrays(
data: Vec<Option<bool>>,
index: &UInt32Array,
@@ -1018,6 +1091,38 @@ mod tests {
}
#[test]
+ fn test_take_decimal128_non_null_indices() {
+ let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
+ let precision: usize = 10;
+ let scale: usize = 5;
+ test_take_decimal_arrays(
+ vec![None, Some(3), Some(5), Some(2), Some(3), None],
+ &index,
+ None,
+ vec![None, None, Some(2), Some(3), Some(3), Some(5)],
+ &precision,
+ &scale,
+ )
+ .unwrap();
+ }
+
+ #[test]
+ fn test_take_decimal128() {
+ let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
+ let precision: usize = 10;
+ let scale: usize = 5;
+ test_take_decimal_arrays(
+ vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
+ &index,
+ None,
+ vec![Some(3), None, Some(1), Some(3), Some(2)],
+ &precision,
+ &scale,
+ )
+ .unwrap();
+ }
+
+ #[test]
fn test_take_primitive_non_null_indices() {
let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
test_take_primitive_arrays::<Int8Type>(