This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 5c3ed61 chore: Reduce the amount of code generated by monomorphization (#715)
5c3ed61 is described below
commit 5c3ed6123d9ea0130a1eca95a0aae776b458208f
Author: Markus Westerlind <[email protected]>
AuthorDate: Mon Sep 13 18:55:56 2021 +0200
chore: Reduce the amount of code generated by monomorphization (#715)
* chore: Reduce the number of instantiations of take* (-3%)
Many types have the same native type, so simplifying these functions to
work directly with native types reduces the number of instantiations.
Reduces the number of llvm lines generated by ~3%
* chore: Shrink try_from_trusted_len_iter (-0.5%)
* chore: Only compile sort_primitive per native type (-8.5%)
* chore: Make the inner take_ functions less generic (-3.5%)
* chore: Don't duplicate sort_list (-13%)
* chore: Extract the "valid" sorting (-7%)
* chore: Extract the array sorter (-1%)
---
arrow/src/buffer/mutable.rs | 22 ++++--
arrow/src/compute/kernels/sort.rs | 160 ++++++++++++++++++++++----------------
arrow/src/compute/kernels/take.rs | 107 ++++++++++++++++---------
3 files changed, 178 insertions(+), 111 deletions(-)
diff --git a/arrow/src/buffer/mutable.rs b/arrow/src/buffer/mutable.rs
index 7d336e0..d83997a 100644
--- a/arrow/src/buffer/mutable.rs
+++ b/arrow/src/buffer/mutable.rs
@@ -530,12 +530,22 @@ impl MutableBuffer {
std::ptr::write(dst, item?);
dst = dst.add(1);
}
- assert_eq!(
- dst.offset_from(buffer.data.as_ptr() as *mut T) as usize,
- upper,
- "Trusted iterator length was not accurately reported"
- );
- buffer.len = len;
+ // try_from_trusted_len_iter is instantiated a lot, so we extract part of it into a less
+ // generic method to reduce compile time
+ unsafe fn finalize_buffer<T>(
+ dst: *mut T,
+ buffer: &mut MutableBuffer,
+ upper: usize,
+ len: usize,
+ ) {
+ assert_eq!(
+ dst.offset_from(buffer.data.as_ptr() as *mut T) as usize,
+ upper,
+ "Trusted iterator length was not accurately reported"
+ );
+ buffer.len = len;
+ }
+ finalize_buffer(dst, &mut buffer, upper, len);
Ok(buffer)
}
}
diff --git a/arrow/src/compute/kernels/sort.rs b/arrow/src/compute/kernels/sort.rs
index 925f42e..6f42be3 100644
--- a/arrow/src/compute/kernels/sort.rs
+++ b/arrow/src/compute/kernels/sort.rs
@@ -425,7 +425,7 @@ impl Default for SortOptions {
fn sort_boolean(
values: &ArrayRef,
value_indices: Vec<u32>,
- null_indices: Vec<u32>,
+ mut null_indices: Vec<u32>,
options: &SortOptions,
limit: Option<usize>,
) -> UInt32Array {
@@ -446,15 +446,8 @@ fn sort_boolean(
.into_iter()
.map(|index| (index, values.value(index as usize)))
.collect::<Vec<(u32, bool)>>();
- if !descending {
- sort_unstable_by(&mut valids, len.saturating_sub(nulls_len), |a,
b| {
- cmp(a.1, b.1)
- });
- } else {
- sort_unstable_by(&mut valids, len.saturating_sub(nulls_len), |a,
b| {
- cmp(a.1, b.1).reverse()
- });
- }
+
+ sort_valids(descending, &mut valids, &mut null_indices, len, cmp);
valids
} else {
// when limit is not present, we have a better way than sorting: we
can just partition
@@ -465,13 +458,13 @@ fn sort_boolean(
.map(|index| (index, values.value(index as usize)))
.partition(|(_, value)| *value == descending);
a.extend(b);
+ if descending {
+ null_indices.reverse();
+ }
a
};
- let mut nulls = null_indices;
- if descending {
- nulls.reverse();
- }
+ let nulls = null_indices;
// collect results directly into a buffer instead of a vec to avoid another aligned allocation
let result_capacity = len * std::mem::size_of::<u32>();
@@ -522,15 +515,31 @@ where
T::Native: std::cmp::PartialOrd,
F: Fn(T::Native, T::Native) -> std::cmp::Ordering,
{
- let values = as_primitive_array::<T>(values);
- let descending = options.descending;
-
// create tuples that are used for sorting
- let mut valids = value_indices
- .into_iter()
- .map(|index| (index, values.value(index as usize)))
- .collect::<Vec<(u32, T::Native)>>();
+ let valids = {
+ let values = as_primitive_array::<T>(values);
+ value_indices
+ .into_iter()
+ .map(|index| (index, values.value(index as usize)))
+ .collect::<Vec<(u32, T::Native)>>()
+ };
+ sort_primitive_inner(values, null_indices, cmp, options, limit, valids)
+}
+// sort is instantiated a lot so we only compile this inner version for each native type
+fn sort_primitive_inner<T, F>(
+ values: &ArrayRef,
+ null_indices: Vec<u32>,
+ cmp: F,
+ options: &SortOptions,
+ limit: Option<usize>,
+ mut valids: Vec<(u32, T)>,
+) -> UInt32Array
+where
+ T: ArrowNativeType,
+ T: std::cmp::PartialOrd,
+ F: Fn(T, T) -> std::cmp::Ordering,
+{
let mut nulls = null_indices;
let valids_len = valids.len();
@@ -540,17 +549,8 @@ where
if let Some(limit) = limit {
len = limit.min(len);
}
- if !descending {
- sort_unstable_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
- cmp(a.1, b.1)
- });
- } else {
- sort_unstable_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
- cmp(a.1, b.1).reverse()
- });
- // reverse to keep a stable ordering
- nulls.reverse();
- }
+
+ sort_valids(options.descending, &mut valids, &mut nulls, len, cmp);
// collect results directly into a buffer instead of a vec to avoid another aligned allocation
let result_capacity = len * std::mem::size_of::<u32>();
@@ -673,22 +673,12 @@ where
let mut nulls = null_indices;
let descending = options.descending;
let mut len = values.len();
- let nulls_len = nulls.len();
if let Some(limit) = limit {
len = limit.min(len);
}
- if !descending {
- sort_unstable_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
- cmp(a.1, b.1)
- });
- } else {
- sort_unstable_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
- cmp(a.1, b.1).reverse()
- });
- // reverse to keep a stable ordering
- nulls.reverse();
- }
+
+ sort_valids(descending, &mut valids, &mut nulls, len, cmp);
// collect the order of valid tuplies
let mut valid_indices: Vec<u32> = valids.iter().map(|tuple|
tuple.0).collect();
@@ -707,7 +697,7 @@ where
fn sort_list<S, T>(
values: &ArrayRef,
value_indices: Vec<u32>,
- mut null_indices: Vec<u32>,
+ null_indices: Vec<u32>,
options: &SortOptions,
limit: Option<usize>,
) -> UInt32Array
@@ -716,6 +706,19 @@ where
T: ArrowPrimitiveType,
T::Native: std::cmp::PartialOrd,
{
+ sort_list_inner::<S>(values, value_indices, null_indices, options, limit)
+}
+
+fn sort_list_inner<S>(
+ values: &ArrayRef,
+ value_indices: Vec<u32>,
+ mut null_indices: Vec<u32>,
+ options: &SortOptions,
+ limit: Option<usize>,
+) -> UInt32Array
+where
+ S: OffsetSizeTrait,
+{
let mut valids: Vec<(u32, ArrayRef)> = values
.as_any()
.downcast_ref::<FixedSizeListArray>()
@@ -738,23 +741,12 @@ where
);
let mut len = values.len();
- let nulls_len = null_indices.len();
let descending = options.descending;
if let Some(limit) = limit {
len = limit.min(len);
}
- if !descending {
- sort_unstable_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
- cmp_array(a.1.as_ref(), b.1.as_ref())
- });
- } else {
- sort_unstable_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
- cmp_array(a.1.as_ref(), b.1.as_ref()).reverse()
- });
- // reverse to keep a stable ordering
- null_indices.reverse();
- }
+ sort_valids_array(descending, &mut valids, &mut null_indices, len);
let mut valid_indices: Vec<u32> = valids.iter().map(|tuple|
tuple.0).collect();
if options.nulls_first {
@@ -801,21 +793,12 @@ where
let mut len = values.len();
let descending = options.descending;
- let nulls_len = null_indices.len();
if let Some(limit) = limit {
len = limit.min(len);
}
- if !descending {
- sort_unstable_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
- a.1.cmp(b.1)
- });
- } else {
- sort_unstable_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
- a.1.cmp(b.1).reverse()
- });
- null_indices.reverse();
- }
+
+ sort_valids(descending, &mut valids, &mut null_indices, len, cmp);
let mut valid_indices: Vec<u32> = valids.iter().map(|tuple|
tuple.0).collect();
if options.nulls_first {
@@ -1036,6 +1019,47 @@ impl LexicographicalComparator<'_> {
}
}
+fn sort_valids<T, U>(
+ descending: bool,
+ valids: &mut [(u32, T)],
+ nulls: &mut [U],
+ len: usize,
+ mut cmp: impl FnMut(T, T) -> Ordering,
+) where
+ T: ?Sized + Copy,
+{
+ let nulls_len = nulls.len();
+ if !descending {
+ sort_unstable_by(valids, len.saturating_sub(nulls_len), |a, b|
cmp(a.1, b.1));
+ } else {
+ sort_unstable_by(valids, len.saturating_sub(nulls_len), |a, b| {
+ cmp(a.1, b.1).reverse()
+ });
+ // reverse to keep a stable ordering
+ nulls.reverse();
+ }
+}
+
+fn sort_valids_array<T>(
+ descending: bool,
+ valids: &mut [(u32, ArrayRef)],
+ nulls: &mut [T],
+ len: usize,
+) {
+ let nulls_len = nulls.len();
+ if !descending {
+ sort_unstable_by(valids, len.saturating_sub(nulls_len), |a, b| {
+ cmp_array(a.1.as_ref(), b.1.as_ref())
+ });
+ } else {
+ sort_unstable_by(valids, len.saturating_sub(nulls_len), |a, b| {
+ cmp_array(a.1.as_ref(), b.1.as_ref()).reverse()
+ });
+ // reverse to keep a stable ordering
+ nulls.reverse();
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs
index 225f263..7147972 100644
--- a/arrow/src/compute/kernels/take.rs
+++ b/arrow/src/compute/kernels/take.rs
@@ -302,20 +302,17 @@ impl Default for TakeOptions {
}
#[inline(always)]
-fn maybe_usize<I: ArrowPrimitiveType>(index: I::Native) -> Result<usize> {
+fn maybe_usize<I: ArrowNativeType>(index: I) -> Result<usize> {
index
.to_usize()
.ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))
}
// take implementation when neither values nor indices contain nulls
-fn take_no_nulls<T, I>(
- values: &[T::Native],
- indices: &[I::Native],
-) -> Result<(Buffer, Option<Buffer>)>
+fn take_no_nulls<T, I>(values: &[T], indices: &[I]) -> Result<(Buffer, Option<Buffer>)>
where
- T: ArrowPrimitiveType,
- I: ArrowNumericType,
+ T: ArrowNativeType,
+ I: ArrowNativeType,
{
let values = indices
.iter()
@@ -329,27 +326,36 @@ where
// take implementation when only values contain nulls
fn take_values_nulls<T, I>(
values: &PrimitiveArray<T>,
- indices: &[I::Native],
+ indices: &[I],
) -> Result<(Buffer, Option<Buffer>)>
where
T: ArrowPrimitiveType,
- I: ArrowNumericType,
- I::Native: ToPrimitive,
+ I: ArrowNativeType,
+{
+ take_values_nulls_inner(values.data(), values.values(), indices)
+}
+
+fn take_values_nulls_inner<T, I>(
+ values_data: &ArrayData,
+ values: &[T],
+ indices: &[I],
+) -> Result<(Buffer, Option<Buffer>)>
+where
+ T: ArrowNativeType,
+ I: ArrowNativeType,
{
let num_bytes = bit_util::ceil(indices.len(), 8);
let mut nulls = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
let null_slice = nulls.as_slice_mut();
let mut null_count = 0;
- let values_values = values.values();
-
let values = indices.iter().enumerate().map(|(i, index)| {
let index = maybe_usize::<I>(*index)?;
- if values.is_null(index) {
+ if values_data.is_null(index) {
null_count += 1;
bit_util::unset_bit(null_slice, i);
}
- Result::Ok(values_values[index])
+ Result::Ok(values[index])
});
// Soundness: `slice.map` is `TrustedLen`.
let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? };
@@ -366,21 +372,33 @@ where
// take implementation when only indices contain nulls
fn take_indices_nulls<T, I>(
- values: &[T::Native],
+ values: &[T],
indices: &PrimitiveArray<I>,
) -> Result<(Buffer, Option<Buffer>)>
where
- T: ArrowPrimitiveType,
+ T: ArrowNativeType,
I: ArrowNumericType,
I::Native: ToPrimitive,
{
- let values = indices.values().iter().map(|index| {
+ take_indices_nulls_inner(values, indices.values(), indices.data())
+}
+
+fn take_indices_nulls_inner<T, I>(
+ values: &[T],
+ indices: &[I],
+ indices_data: &ArrayData,
+) -> Result<(Buffer, Option<Buffer>)>
+where
+ T: ArrowNativeType,
+ I: ArrowNativeType,
+{
+ let values = indices.iter().map(|index| {
let index = maybe_usize::<I>(*index)?;
Result::Ok(match values.get(index) {
Some(value) => *value,
None => {
- if indices.is_null(index) {
- T::Native::default()
+ if indices_data.is_null(index) {
+ T::default()
} else {
panic!("Out-of-bounds index {}", index)
}
@@ -393,10 +411,9 @@ where
Ok((
buffer,
- indices
- .data_ref()
+ indices_data
.null_buffer()
- .map(|b| b.bit_slice(indices.offset(), indices.len())),
+ .map(|b| b.bit_slice(indices_data.offset(), indices.len())),
))
}
@@ -410,25 +427,41 @@ where
I: ArrowNumericType,
I::Native: ToPrimitive,
{
+ take_values_indices_nulls_inner(
+ values.values(),
+ values.data(),
+ indices.values(),
+ indices.data(),
+ )
+}
+
+fn take_values_indices_nulls_inner<T, I>(
+ values: &[T],
+ values_data: &ArrayData,
+ indices: &[I],
+ indices_data: &ArrayData,
+) -> Result<(Buffer, Option<Buffer>)>
+where
+ T: ArrowNativeType,
+ I: ArrowNativeType,
+{
let num_bytes = bit_util::ceil(indices.len(), 8);
let mut nulls = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
let null_slice = nulls.as_slice_mut();
let mut null_count = 0;
- let values_values = values.values();
- let values = indices.iter().enumerate().map(|(i, index)| match index {
- Some(index) => {
+ let values = indices.iter().enumerate().map(|(i, &index)| {
+ if indices_data.is_null(i) {
+ null_count += 1;
+ bit_util::unset_bit(null_slice, i);
+ Ok(T::default())
+ } else {
let index = maybe_usize::<I>(index)?;
- if values.is_null(index) {
+ if values_data.is_null(index) {
null_count += 1;
bit_util::unset_bit(null_slice, i);
}
- Result::Ok(values_values[index])
- }
- None => {
- null_count += 1;
- bit_util::unset_bit(null_slice, i);
- Ok(T::Native::default())
+ Result::Ok(values[index])
}
});
// Soundness: `slice.map` is `TrustedLen`.
@@ -471,17 +504,17 @@ where
(false, false) => {
// * no nulls
// * all `indices.values()` are valid
- take_no_nulls::<T, I>(values.values(), indices.values())?
take_no_nulls::<T::Native, I::Native>(values.values(), indices.values())?
}
(true, false) => {
// * nulls come from `values` alone
// * all `indices.values()` are valid
- take_values_nulls::<T, I>(values, indices.values())?
+ take_values_nulls::<T, I::Native>(values, indices.values())?
}
(false, true) => {
// in this branch it is unsound to read and use `index.values()`,
// as doing so is UB when they come from a null slot.
- take_indices_nulls::<T, I>(values.values(), indices)?
+ take_indices_nulls::<T::Native, I>(values.values(), indices)?
}
(true, true) => {
// in this branch it is unsound to read and use `index.values()`,
@@ -795,7 +828,7 @@ where
.values()
.iter()
.map(|idx| {
- let idx = maybe_usize::<IndexType>(*idx)?;
+ let idx = maybe_usize::<IndexType::Native>(*idx)?;
if data_ref.is_valid(idx) {
Ok(Some(values.value(idx)))
} else {
@@ -821,7 +854,7 @@ where
.values()
.iter()
.map(|idx| {
- let idx = maybe_usize::<IndexType>(*idx)?;
+ let idx = maybe_usize::<IndexType::Native>(*idx)?;
if data_ref.is_valid(idx) {
Ok(Some(values.value(idx)))
} else {