This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new aac3aa993 Improve in-place primitive sorts by 13-67% (#4473)
aac3aa993 is described below
commit aac3aa99398c4f4fe59c60d1839d3a8ab60d00f3
Author: Vrishabh <[email protected]>
AuthorDate: Tue Jul 4 21:34:52 2023 +0530
Improve in-place primitive sorts by 13-67% (#4473)
* Adding sort_primitives benchmark
* Adding sort_primitives improvements
* Fix lints
* Remove all unsafe code and handle offset cases
* Incorporate review comments
* Remove unneeded returns
---
arrow-ord/src/sort.rs | 72 +++++++++++++++++++++++++++++++--
arrow/Cargo.toml | 5 +++
arrow/benches/sort_kernel_primitives.rs | 59 +++++++++++++++++++++++++++
3 files changed, 132 insertions(+), 4 deletions(-)
diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs
index 1d9653259..147af1e30 100644
--- a/arrow-ord/src/sort.rs
+++ b/arrow-ord/src/sort.rs
@@ -22,6 +22,7 @@ use arrow_array::builder::BufferBuilder;
use arrow_array::cast::*;
use arrow_array::types::*;
use arrow_array::*;
+use arrow_buffer::BooleanBufferBuilder;
use arrow_buffer::{ArrowNativeType, MutableBuffer, NullBuffer};
use arrow_data::ArrayData;
use arrow_data::ArrayDataBuilder;
@@ -57,11 +58,74 @@ pub fn sort(
values: &dyn Array,
options: Option<SortOptions>,
) -> Result<ArrayRef, ArrowError> {
- if let DataType::RunEndEncoded(_, _) = values.data_type() {
- return sort_run(values, options, None);
+ downcast_primitive_array!(
+ values => sort_native_type(values, options),
+ DataType::RunEndEncoded(_, _) => sort_run(values, options, None),
+ _ => {
+ let indices = sort_to_indices(values, options, None)?;
+ take(values, &indices, None)
+ }
+ )
+}
+
+fn sort_native_type<T>(
+ primitive_values: &PrimitiveArray<T>,
+ options: Option<SortOptions>,
+) -> Result<ArrayRef, ArrowError>
+where
+ T: ArrowPrimitiveType,
+{
+ let sort_options = options.unwrap_or_default();
+
+ let mut mutable_buffer = vec![T::default_value(); primitive_values.len()];
+ let mutable_slice = &mut mutable_buffer;
+
+ let input_values = primitive_values.values().as_ref();
+
+ let nulls_count = primitive_values.null_count();
+ let valid_count = primitive_values.len() - nulls_count;
+
+ let null_bit_buffer = match nulls_count > 0 {
+ true => {
+ let mut validity_buffer = BooleanBufferBuilder::new(primitive_values.len());
+ if sort_options.nulls_first {
+ validity_buffer.append_n(nulls_count, false);
+ validity_buffer.append_n(valid_count, true);
+ } else {
+ validity_buffer.append_n(valid_count, true);
+ validity_buffer.append_n(nulls_count, false);
+ }
+ Some(validity_buffer.finish().into())
+ }
+ false => None,
+ };
+
+ if let Some(nulls) = primitive_values.nulls().filter(|n| n.null_count() > 0) {
+ let values_slice = match sort_options.nulls_first {
+ true => &mut mutable_slice[nulls_count..],
+ false => &mut mutable_slice[..valid_count],
+ };
+
+ for (write_index, index) in nulls.valid_indices().enumerate() {
+ values_slice[write_index] = primitive_values.value(index);
+ }
+
+ values_slice.sort_unstable_by(|a, b| a.compare(*b));
+ if sort_options.descending {
+ values_slice.reverse();
+ }
+ } else {
+ mutable_slice.copy_from_slice(input_values);
+ mutable_slice.sort_unstable_by(|a, b| a.compare(*b));
+ if sort_options.descending {
+ mutable_slice.reverse();
+ }
}
- let indices = sort_to_indices(values, options, None)?;
- take(values, &indices, None)
+
+ Ok(Arc::new(
+ PrimitiveArray::<T>::new(mutable_buffer.into(), null_bit_buffer)
+ .with_data_type(primitive_values.data_type().clone()),
+ ))
}
/// Sort the `ArrayRef` partially.
diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml
index bc126a2f4..ed4786fb3 100644
--- a/arrow/Cargo.toml
+++ b/arrow/Cargo.toml
@@ -185,6 +185,11 @@ name = "sort_kernel"
harness = false
required-features = ["test_utils"]
+[[bench]]
+name = "sort_kernel_primitives"
+harness = false
+required-features = ["test_utils"]
+
[[bench]]
name = "partition_kernels"
harness = false
diff --git a/arrow/benches/sort_kernel_primitives.rs b/arrow/benches/sort_kernel_primitives.rs
new file mode 100644
index 000000000..ca9183580
--- /dev/null
+++ b/arrow/benches/sort_kernel_primitives.rs
@@ -0,0 +1,59 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#[macro_use]
+extern crate criterion;
+use arrow_ord::sort::sort;
+use criterion::Criterion;
+
+use std::sync::Arc;
+
+extern crate arrow;
+
+use arrow::util::bench_util::*;
+use arrow::{array::*, datatypes::Int64Type};
+
+fn create_i64_array(size: usize, with_nulls: bool) -> ArrayRef {
+ let null_density = if with_nulls { 0.5 } else { 0.0 };
+ let array = create_primitive_array::<Int64Type>(size, null_density);
+ Arc::new(array)
+}
+
+fn bench_sort(array: &ArrayRef) {
+ criterion::black_box(sort(criterion::black_box(array), None).unwrap());
+}
+
+fn add_benchmark(c: &mut Criterion) {
+ let arr_a = create_i64_array(2u64.pow(10) as usize, false);
+
+ c.bench_function("sort 2^10", |b| b.iter(|| bench_sort(&arr_a)));
+
+ let arr_a = create_i64_array(2u64.pow(12) as usize, false);
+
+ c.bench_function("sort 2^12", |b| b.iter(|| bench_sort(&arr_a)));
+
+ let arr_a = create_i64_array(2u64.pow(10) as usize, true);
+
+ c.bench_function("sort nulls 2^10", |b| b.iter(|| bench_sort(&arr_a)));
+
+ let arr_a = create_i64_array(2u64.pow(12) as usize, true);
+
+ c.bench_function("sort nulls 2^12", |b| b.iter(|| bench_sort(&arr_a)));
+}
+
+criterion_group!(benches, add_benchmark);
+criterion_main!(benches);