alamb commented on code in PR #4613:
URL: https://github.com/apache/arrow-rs/pull/4613#discussion_r1283679351
##########
arrow-ord/src/sort.rs:
##########
@@ -204,210 +196,33 @@ fn partition_validity(array: &dyn Array) -> (Vec<u32>, Vec<u32>) {
/// For floating point arrays any NaN values are considered to be greater than
any other non-null value.
/// `limit` is an option for [partial_sort].
pub fn sort_to_indices(
- values: &dyn Array,
+ array: &dyn Array,
options: Option<SortOptions>,
limit: Option<usize>,
) -> Result<UInt32Array, ArrowError> {
let options = options.unwrap_or_default();
- let (v, n) = partition_validity(values);
-
- Ok(match values.data_type() {
- DataType::Decimal128(_, _) => {
- sort_primitive::<Decimal128Type, _>(values, v, n, cmp, &options,
limit)
- }
- DataType::Decimal256(_, _) => {
- sort_primitive::<Decimal256Type, _>(values, v, n, cmp, &options,
limit)
- }
- DataType::Boolean => sort_boolean(values, v, n, &options, limit),
- DataType::Int8 => {
- sort_primitive::<Int8Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Int16 => {
- sort_primitive::<Int16Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Int32 => {
- sort_primitive::<Int32Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Int64 => {
- sort_primitive::<Int64Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::UInt8 => {
- sort_primitive::<UInt8Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::UInt16 => {
- sort_primitive::<UInt16Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::UInt32 => {
- sort_primitive::<UInt32Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::UInt64 => {
- sort_primitive::<UInt64Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Float16 => sort_primitive::<Float16Type, _>(
- values,
- v,
- n,
- |x, y| x.total_cmp(&y),
- &options,
- limit,
- ),
- DataType::Float32 => sort_primitive::<Float32Type, _>(
- values,
- v,
- n,
- |x, y| x.total_cmp(&y),
- &options,
- limit,
- ),
- DataType::Float64 => sort_primitive::<Float64Type, _>(
- values,
- v,
- n,
- |x, y| x.total_cmp(&y),
- &options,
- limit,
- ),
- DataType::Date32 => {
- sort_primitive::<Date32Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Date64 => {
- sort_primitive::<Date64Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Time32(TimeUnit::Second) => {
- sort_primitive::<Time32SecondType, _>(values, v, n, cmp, &options,
limit)
- }
- DataType::Time32(TimeUnit::Millisecond) => {
- sort_primitive::<Time32MillisecondType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Time64(TimeUnit::Microsecond) => {
- sort_primitive::<Time64MicrosecondType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Time64(TimeUnit::Nanosecond) => {
- sort_primitive::<Time64NanosecondType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Timestamp(TimeUnit::Second, _) => {
- sort_primitive::<TimestampSecondType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Timestamp(TimeUnit::Millisecond, _) => {
- sort_primitive::<TimestampMillisecondType, _>(
- values, v, n, cmp, &options, limit,
- )
- }
- DataType::Timestamp(TimeUnit::Microsecond, _) => {
- sort_primitive::<TimestampMicrosecondType, _>(
- values, v, n, cmp, &options, limit,
- )
- }
- DataType::Timestamp(TimeUnit::Nanosecond, _) => {
- sort_primitive::<TimestampNanosecondType, _>(
- values, v, n, cmp, &options, limit,
- )
- }
- DataType::Interval(IntervalUnit::YearMonth) => {
- sort_primitive::<IntervalYearMonthType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Interval(IntervalUnit::DayTime) => {
- sort_primitive::<IntervalDayTimeType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Interval(IntervalUnit::MonthDayNano) => {
- sort_primitive::<IntervalMonthDayNanoType, _>(
- values, v, n, cmp, &options, limit,
- )
- }
- DataType::Duration(TimeUnit::Second) => {
- sort_primitive::<DurationSecondType, _>(values, v, n, cmp,
&options, limit)
+ let (v, n) = partition_validity(array);
Review Comment:
Since `partition_validity` uses `u32` for indices, doesn't that mean this
code cannot sort arrays with more than ~4B (`u32::MAX`) values? Maybe for
things like `LargeList` this could be a problem 🤔
However, it seems this limitation also existed prior to this PR, so it
probably isn't a new issue introduced here.
##########
arrow-ord/src/sort.rs:
##########
@@ -834,200 +608,6 @@ where
(values_indices, run_values)
}
-/// Sort strings
-fn sort_string<Offset: OffsetSizeTrait>(
- values: &dyn Array,
- value_indices: Vec<u32>,
- null_indices: Vec<u32>,
- options: &SortOptions,
- limit: Option<usize>,
-) -> UInt32Array {
- let values = values
- .as_any()
- .downcast_ref::<GenericStringArray<Offset>>()
- .unwrap();
-
- sort_string_helper(
- values,
- value_indices,
- null_indices,
- options,
- limit,
- |array, idx| array.value(idx as usize),
- )
-}
-
-/// shared implementation between dictionary encoded and plain string arrays
-#[inline]
-fn sort_string_helper<'a, A: Array, F>(
- values: &'a A,
- value_indices: Vec<u32>,
- null_indices: Vec<u32>,
- options: &SortOptions,
- limit: Option<usize>,
- value_fn: F,
-) -> UInt32Array
-where
- F: Fn(&'a A, u32) -> &str,
-{
- let mut valids = value_indices
- .into_iter()
- .map(|index| (index, value_fn(values, index)))
- .collect::<Vec<(u32, &str)>>();
- let mut nulls = null_indices;
- let descending = options.descending;
- let mut len = values.len();
-
- if let Some(limit) = limit {
- len = limit.min(len);
- }
-
- sort_valids(descending, &mut valids, len, cmp);
- // collect the order of valid tuplies
- let mut valid_indices: Vec<u32> = valids.iter().map(|tuple|
tuple.0).collect();
-
- if options.nulls_first {
- nulls.append(&mut valid_indices);
- nulls.truncate(len);
- UInt32Array::from(nulls)
- } else {
- // no need to sort nulls as they are in the correct order already
- valid_indices.append(&mut nulls);
- valid_indices.truncate(len);
- UInt32Array::from(valid_indices)
- }
-}
-
-fn sort_list<S>(
- values: &dyn Array,
- value_indices: Vec<u32>,
- null_indices: Vec<u32>,
- options: &SortOptions,
- limit: Option<usize>,
-) -> UInt32Array
-where
- S: OffsetSizeTrait,
-{
- sort_list_inner::<S>(values, value_indices, null_indices, options, limit)
-}
-
-fn sort_list_inner<S>(
- values: &dyn Array,
- value_indices: Vec<u32>,
- mut null_indices: Vec<u32>,
- options: &SortOptions,
- limit: Option<usize>,
-) -> UInt32Array
-where
- S: OffsetSizeTrait,
-{
- let mut valids: Vec<(u32, ArrayRef)> = values
- .as_any()
- .downcast_ref::<FixedSizeListArray>()
- .map_or_else(
- || {
- let values = as_generic_list_array::<S>(values);
- value_indices
- .iter()
- .copied()
- .map(|index| (index, values.value(index as usize)))
- .collect()
- },
- |values| {
- value_indices
- .iter()
- .copied()
- .map(|index| (index, values.value(index as usize)))
- .collect()
- },
- );
-
- let mut len = values.len();
- let descending = options.descending;
-
- if let Some(limit) = limit {
- len = limit.min(len);
- }
- sort_valids_array(descending, &mut valids, &mut null_indices, len);
-
- let mut valid_indices: Vec<u32> = valids.iter().map(|tuple|
tuple.0).collect();
- if options.nulls_first {
- null_indices.append(&mut valid_indices);
- null_indices.truncate(len);
- UInt32Array::from(null_indices)
- } else {
- valid_indices.append(&mut null_indices);
- valid_indices.truncate(len);
- UInt32Array::from(valid_indices)
- }
-}
-
-fn sort_binary<S>(
- values: &dyn Array,
- value_indices: Vec<u32>,
- mut null_indices: Vec<u32>,
- options: &SortOptions,
- limit: Option<usize>,
-) -> UInt32Array
-where
- S: OffsetSizeTrait,
-{
- let mut valids: Vec<(u32, &[u8])> = values
- .as_any()
- .downcast_ref::<FixedSizeBinaryArray>()
- .map_or_else(
- || {
- let values = as_generic_binary_array::<S>(values);
- value_indices
- .iter()
- .copied()
- .map(|index| (index, values.value(index as usize)))
- .collect()
- },
- |values| {
- value_indices
- .iter()
- .copied()
- .map(|index| (index, values.value(index as usize)))
- .collect()
- },
- );
-
- let mut len = values.len();
- let descending = options.descending;
-
- if let Some(limit) = limit {
- len = limit.min(len);
- }
-
- sort_valids(descending, &mut valids, len, cmp);
-
- let mut valid_indices: Vec<u32> = valids.iter().map(|tuple|
tuple.0).collect();
- if options.nulls_first {
- null_indices.append(&mut valid_indices);
- null_indices.truncate(len);
- UInt32Array::from(null_indices)
- } else {
- valid_indices.append(&mut null_indices);
- valid_indices.truncate(len);
- UInt32Array::from(valid_indices)
- }
-}
-
-/// Compare two `Array`s based on the ordering defined in [build_compare]
-fn cmp_array(a: &dyn Array, b: &dyn Array) -> Ordering {
- let cmp_op = build_compare(a, b).unwrap();
Review Comment:
It looks like `cmp_array` (which calls `build_compare` anew on every
comparison) was used when sorting `ListArray`, so list sorting might become
substantially faster with this change.
##########
arrow-ord/src/sort.rs:
##########
@@ -422,231 +237,177 @@ pub fn sort_to_indices(
})
}
-/// Sort boolean values
-///
-/// when a limit is present, the sort is pair-comparison based as k-select
might be more efficient,
-/// when the limit is absent, binary partition is used to speed up (which is
linear).
-///
-/// TODO maybe partition_validity call can be eliminated in this case
-/// and [tri-color
sort](https://en.wikipedia.org/wiki/Dutch_national_flag_problem)
-/// can be used instead.
fn sort_boolean(
- values: &dyn Array,
+ values: &BooleanArray,
value_indices: Vec<u32>,
- mut null_indices: Vec<u32>,
- options: &SortOptions,
+ null_indices: Vec<u32>,
+ options: SortOptions,
limit: Option<usize>,
) -> UInt32Array {
- let values = values
- .as_any()
- .downcast_ref::<BooleanArray>()
- .expect("Unable to downcast to boolean array");
- let descending = options.descending;
-
- let valids_len = value_indices.len();
- let nulls_len = null_indices.len();
-
- let mut len = values.len();
- let valids = if let Some(limit) = limit {
- len = limit.min(len);
- // create tuples that are used for sorting
- let mut valids = value_indices
- .into_iter()
- .map(|index| (index, values.value(index as usize)))
- .collect::<Vec<(u32, bool)>>();
-
- sort_valids(descending, &mut valids, len, cmp);
- valids
- } else {
- // when limit is not present, we have a better way than sorting: we
can just partition
- // the vec into [false..., true...] or [true..., false...] when
descending
- // TODO when https://github.com/rust-lang/rust/issues/62543 is merged
we can use partition_in_place
- let (mut a, b): (Vec<_>, Vec<_>) = value_indices
- .into_iter()
- .map(|index| (index, values.value(index as usize)))
- .partition(|(_, value)| *value == descending);
- a.extend(b);
- if descending {
- null_indices.reverse();
- }
- a
- };
-
- let nulls = null_indices;
-
- // collect results directly into a buffer instead of a vec to avoid
another aligned allocation
- let result_capacity = len * std::mem::size_of::<u32>();
- let mut result = MutableBuffer::new(result_capacity);
- // sets len to capacity so we can access the whole buffer as a typed slice
- result.resize(result_capacity, 0);
- let result_slice: &mut [u32] = result.typed_data_mut();
-
- if options.nulls_first {
- let size = nulls_len.min(len);
- result_slice[0..size].copy_from_slice(&nulls[0..size]);
- if nulls_len < len {
- insert_valid_values(result_slice, nulls_len, &valids[0..len -
size]);
- }
- } else {
- // nulls last
- let size = valids.len().min(len);
- insert_valid_values(result_slice, 0, &valids[0..size]);
- if len > size {
- result_slice[valids_len..].copy_from_slice(&nulls[0..(len -
valids_len)]);
- }
- }
-
- let result_data = unsafe {
- ArrayData::new_unchecked(
- DataType::UInt32,
- len,
- Some(0),
- None,
- 0,
- vec![result.into()],
- vec![],
- )
- };
+ let mut valids = value_indices
+ .into_iter()
+ .map(|index| (index, values.value(index as usize)))
+ .collect::<Vec<(u32, bool)>>();
+ sort_impl(options, &mut valids, &null_indices, limit, |a, b|
a.cmp(&b)).into()
Review Comment:
👌 `sort_impl` is very nice
##########
arrow-ord/src/sort.rs:
##########
@@ -1980,7 +1526,7 @@ mod tests {
nulls_first: false,
}),
None,
- vec![2, 3, 1, 4, 5, 0],
+ vec![2, 3, 1, 4, 0, 5],
Review Comment:
`0` and `5` here are the indices of the nulls, so only the relative order of the nulls changed. Makes sense.
##########
arrow-ord/src/sort.rs:
##########
@@ -204,210 +196,33 @@ fn partition_validity(array: &dyn Array) -> (Vec<u32>, Vec<u32>) {
/// For floating point arrays any NaN values are considered to be greater than
any other non-null value.
/// `limit` is an option for [partial_sort].
pub fn sort_to_indices(
- values: &dyn Array,
+ array: &dyn Array,
options: Option<SortOptions>,
limit: Option<usize>,
) -> Result<UInt32Array, ArrowError> {
let options = options.unwrap_or_default();
- let (v, n) = partition_validity(values);
-
- Ok(match values.data_type() {
- DataType::Decimal128(_, _) => {
- sort_primitive::<Decimal128Type, _>(values, v, n, cmp, &options,
limit)
- }
- DataType::Decimal256(_, _) => {
- sort_primitive::<Decimal256Type, _>(values, v, n, cmp, &options,
limit)
- }
- DataType::Boolean => sort_boolean(values, v, n, &options, limit),
- DataType::Int8 => {
- sort_primitive::<Int8Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Int16 => {
- sort_primitive::<Int16Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Int32 => {
- sort_primitive::<Int32Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Int64 => {
- sort_primitive::<Int64Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::UInt8 => {
- sort_primitive::<UInt8Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::UInt16 => {
- sort_primitive::<UInt16Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::UInt32 => {
- sort_primitive::<UInt32Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::UInt64 => {
- sort_primitive::<UInt64Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Float16 => sort_primitive::<Float16Type, _>(
- values,
- v,
- n,
- |x, y| x.total_cmp(&y),
- &options,
- limit,
- ),
- DataType::Float32 => sort_primitive::<Float32Type, _>(
- values,
- v,
- n,
- |x, y| x.total_cmp(&y),
- &options,
- limit,
- ),
- DataType::Float64 => sort_primitive::<Float64Type, _>(
- values,
- v,
- n,
- |x, y| x.total_cmp(&y),
- &options,
- limit,
- ),
- DataType::Date32 => {
- sort_primitive::<Date32Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Date64 => {
- sort_primitive::<Date64Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Time32(TimeUnit::Second) => {
- sort_primitive::<Time32SecondType, _>(values, v, n, cmp, &options,
limit)
- }
- DataType::Time32(TimeUnit::Millisecond) => {
- sort_primitive::<Time32MillisecondType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Time64(TimeUnit::Microsecond) => {
- sort_primitive::<Time64MicrosecondType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Time64(TimeUnit::Nanosecond) => {
- sort_primitive::<Time64NanosecondType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Timestamp(TimeUnit::Second, _) => {
- sort_primitive::<TimestampSecondType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Timestamp(TimeUnit::Millisecond, _) => {
- sort_primitive::<TimestampMillisecondType, _>(
- values, v, n, cmp, &options, limit,
- )
- }
- DataType::Timestamp(TimeUnit::Microsecond, _) => {
- sort_primitive::<TimestampMicrosecondType, _>(
- values, v, n, cmp, &options, limit,
- )
- }
- DataType::Timestamp(TimeUnit::Nanosecond, _) => {
- sort_primitive::<TimestampNanosecondType, _>(
- values, v, n, cmp, &options, limit,
- )
- }
- DataType::Interval(IntervalUnit::YearMonth) => {
- sort_primitive::<IntervalYearMonthType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Interval(IntervalUnit::DayTime) => {
- sort_primitive::<IntervalDayTimeType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Interval(IntervalUnit::MonthDayNano) => {
- sort_primitive::<IntervalMonthDayNanoType, _>(
- values, v, n, cmp, &options, limit,
- )
- }
- DataType::Duration(TimeUnit::Second) => {
- sort_primitive::<DurationSecondType, _>(values, v, n, cmp,
&options, limit)
+ let (v, n) = partition_validity(array);
Review Comment:
Since `partition_validity` uses `u32` for indices, doesn't that mean this
code cannot sort arrays with more than ~4B (`u32::MAX`) values? Maybe for
things like `LargeList` this could be a problem 🤔
However, it seems this limitation also existed prior to this PR, so it
probably isn't a new issue introduced here.
##########
arrow-ord/src/sort.rs:
##########
@@ -422,231 +237,177 @@ pub fn sort_to_indices(
})
}
-/// Sort boolean values
-///
-/// when a limit is present, the sort is pair-comparison based as k-select
might be more efficient,
-/// when the limit is absent, binary partition is used to speed up (which is
linear).
-///
-/// TODO maybe partition_validity call can be eliminated in this case
-/// and [tri-color
sort](https://en.wikipedia.org/wiki/Dutch_national_flag_problem)
-/// can be used instead.
fn sort_boolean(
- values: &dyn Array,
+ values: &BooleanArray,
value_indices: Vec<u32>,
- mut null_indices: Vec<u32>,
- options: &SortOptions,
+ null_indices: Vec<u32>,
+ options: SortOptions,
limit: Option<usize>,
) -> UInt32Array {
- let values = values
- .as_any()
- .downcast_ref::<BooleanArray>()
- .expect("Unable to downcast to boolean array");
- let descending = options.descending;
-
- let valids_len = value_indices.len();
- let nulls_len = null_indices.len();
-
- let mut len = values.len();
- let valids = if let Some(limit) = limit {
- len = limit.min(len);
- // create tuples that are used for sorting
- let mut valids = value_indices
- .into_iter()
- .map(|index| (index, values.value(index as usize)))
- .collect::<Vec<(u32, bool)>>();
-
- sort_valids(descending, &mut valids, len, cmp);
- valids
- } else {
- // when limit is not present, we have a better way than sorting: we
can just partition
- // the vec into [false..., true...] or [true..., false...] when
descending
- // TODO when https://github.com/rust-lang/rust/issues/62543 is merged
we can use partition_in_place
- let (mut a, b): (Vec<_>, Vec<_>) = value_indices
- .into_iter()
- .map(|index| (index, values.value(index as usize)))
- .partition(|(_, value)| *value == descending);
- a.extend(b);
- if descending {
- null_indices.reverse();
- }
- a
- };
-
- let nulls = null_indices;
-
- // collect results directly into a buffer instead of a vec to avoid
another aligned allocation
- let result_capacity = len * std::mem::size_of::<u32>();
- let mut result = MutableBuffer::new(result_capacity);
- // sets len to capacity so we can access the whole buffer as a typed slice
- result.resize(result_capacity, 0);
- let result_slice: &mut [u32] = result.typed_data_mut();
-
- if options.nulls_first {
- let size = nulls_len.min(len);
- result_slice[0..size].copy_from_slice(&nulls[0..size]);
- if nulls_len < len {
- insert_valid_values(result_slice, nulls_len, &valids[0..len -
size]);
- }
- } else {
- // nulls last
- let size = valids.len().min(len);
- insert_valid_values(result_slice, 0, &valids[0..size]);
- if len > size {
- result_slice[valids_len..].copy_from_slice(&nulls[0..(len -
valids_len)]);
- }
- }
-
- let result_data = unsafe {
- ArrayData::new_unchecked(
- DataType::UInt32,
- len,
- Some(0),
- None,
- 0,
- vec![result.into()],
- vec![],
- )
- };
+ let mut valids = value_indices
+ .into_iter()
+ .map(|index| (index, values.value(index as usize)))
+ .collect::<Vec<(u32, bool)>>();
+ sort_impl(options, &mut valids, &null_indices, limit, |a, b|
a.cmp(&b)).into()
Review Comment:
👌 `sort_impl` is very nice
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]