This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 5c3ed61  chore: Reduce the amount of code generated by monomorphization (#715)
5c3ed61 is described below

commit 5c3ed6123d9ea0130a1eca95a0aae776b458208f
Author: Markus Westerlind <[email protected]>
AuthorDate: Mon Sep 13 18:55:56 2021 +0200

    chore: Reduce the amount of code generated by monomorphization (#715)
    
    * chore: Reduce the number of instantiations of take* (-3%)
    
    Many types have the same native type, so simplifying these functions to
    work directly with native types reduces the number of instantiations.
    
    Reduces the number of LLVM lines generated by ~3%
    
    * chore: Shrink try_from_trusted_len_iter (-0.5%)
    
    * chore: Only compile sort_primitive per native type (-8.5%)
    
    * chore: Make the inner take_ functions less generic (-3.5%)
    
    * chore: Don't duplicate sort_list (-13%)
    
    * chore: Extract the "valid" sorting (-7%)
    
    * chore: Extract the array sorter (-1%)
---
 arrow/src/buffer/mutable.rs       |  22 ++++--
 arrow/src/compute/kernels/sort.rs | 160 ++++++++++++++++++++++----------------
 arrow/src/compute/kernels/take.rs | 107 ++++++++++++++++---------
 3 files changed, 178 insertions(+), 111 deletions(-)

diff --git a/arrow/src/buffer/mutable.rs b/arrow/src/buffer/mutable.rs
index 7d336e0..d83997a 100644
--- a/arrow/src/buffer/mutable.rs
+++ b/arrow/src/buffer/mutable.rs
@@ -530,12 +530,22 @@ impl MutableBuffer {
             std::ptr::write(dst, item?);
             dst = dst.add(1);
         }
-        assert_eq!(
-            dst.offset_from(buffer.data.as_ptr() as *mut T) as usize,
-            upper,
-            "Trusted iterator length was not accurately reported"
-        );
-        buffer.len = len;
+        // try_from_trusted_len_iter is instantiated a lot, so we extract part 
of it into a less
+        // generic method to reduce compile time
+        unsafe fn finalize_buffer<T>(
+            dst: *mut T,
+            buffer: &mut MutableBuffer,
+            upper: usize,
+            len: usize,
+        ) {
+            assert_eq!(
+                dst.offset_from(buffer.data.as_ptr() as *mut T) as usize,
+                upper,
+                "Trusted iterator length was not accurately reported"
+            );
+            buffer.len = len;
+        }
+        finalize_buffer(dst, &mut buffer, upper, len);
         Ok(buffer)
     }
 }
diff --git a/arrow/src/compute/kernels/sort.rs b/arrow/src/compute/kernels/sort.rs
index 925f42e..6f42be3 100644
--- a/arrow/src/compute/kernels/sort.rs
+++ b/arrow/src/compute/kernels/sort.rs
@@ -425,7 +425,7 @@ impl Default for SortOptions {
 fn sort_boolean(
     values: &ArrayRef,
     value_indices: Vec<u32>,
-    null_indices: Vec<u32>,
+    mut null_indices: Vec<u32>,
     options: &SortOptions,
     limit: Option<usize>,
 ) -> UInt32Array {
@@ -446,15 +446,8 @@ fn sort_boolean(
             .into_iter()
             .map(|index| (index, values.value(index as usize)))
             .collect::<Vec<(u32, bool)>>();
-        if !descending {
-            sort_unstable_by(&mut valids, len.saturating_sub(nulls_len), |a, 
b| {
-                cmp(a.1, b.1)
-            });
-        } else {
-            sort_unstable_by(&mut valids, len.saturating_sub(nulls_len), |a, 
b| {
-                cmp(a.1, b.1).reverse()
-            });
-        }
+
+        sort_valids(descending, &mut valids, &mut null_indices, len, cmp);
         valids
     } else {
         // when limit is not present, we have a better way than sorting: we 
can just partition
@@ -465,13 +458,13 @@ fn sort_boolean(
             .map(|index| (index, values.value(index as usize)))
             .partition(|(_, value)| *value == descending);
         a.extend(b);
+        if descending {
+            null_indices.reverse();
+        }
         a
     };
 
-    let mut nulls = null_indices;
-    if descending {
-        nulls.reverse();
-    }
+    let nulls = null_indices;
 
     // collect results directly into a buffer instead of a vec to avoid 
another aligned allocation
     let result_capacity = len * std::mem::size_of::<u32>();
@@ -522,15 +515,31 @@ where
     T::Native: std::cmp::PartialOrd,
     F: Fn(T::Native, T::Native) -> std::cmp::Ordering,
 {
-    let values = as_primitive_array::<T>(values);
-    let descending = options.descending;
-
     // create tuples that are used for sorting
-    let mut valids = value_indices
-        .into_iter()
-        .map(|index| (index, values.value(index as usize)))
-        .collect::<Vec<(u32, T::Native)>>();
+    let valids = {
+        let values = as_primitive_array::<T>(values);
+        value_indices
+            .into_iter()
+            .map(|index| (index, values.value(index as usize)))
+            .collect::<Vec<(u32, T::Native)>>()
+    };
+    sort_primitive_inner(values, null_indices, cmp, options, limit, valids)
+}
 
+// sort is instantiated a lot so we only compile this inner version for each 
native type
+fn sort_primitive_inner<T, F>(
+    values: &ArrayRef,
+    null_indices: Vec<u32>,
+    cmp: F,
+    options: &SortOptions,
+    limit: Option<usize>,
+    mut valids: Vec<(u32, T)>,
+) -> UInt32Array
+where
+    T: ArrowNativeType,
+    T: std::cmp::PartialOrd,
+    F: Fn(T, T) -> std::cmp::Ordering,
+{
     let mut nulls = null_indices;
 
     let valids_len = valids.len();
@@ -540,17 +549,8 @@ where
     if let Some(limit) = limit {
         len = limit.min(len);
     }
-    if !descending {
-        sort_unstable_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
-            cmp(a.1, b.1)
-        });
-    } else {
-        sort_unstable_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
-            cmp(a.1, b.1).reverse()
-        });
-        // reverse to keep a stable ordering
-        nulls.reverse();
-    }
+
+    sort_valids(options.descending, &mut valids, &mut nulls, len, cmp);
 
     // collect results directly into a buffer instead of a vec to avoid 
another aligned allocation
     let result_capacity = len * std::mem::size_of::<u32>();
@@ -673,22 +673,12 @@ where
     let mut nulls = null_indices;
     let descending = options.descending;
     let mut len = values.len();
-    let nulls_len = nulls.len();
 
     if let Some(limit) = limit {
         len = limit.min(len);
     }
-    if !descending {
-        sort_unstable_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
-            cmp(a.1, b.1)
-        });
-    } else {
-        sort_unstable_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
-            cmp(a.1, b.1).reverse()
-        });
-        // reverse to keep a stable ordering
-        nulls.reverse();
-    }
+
+    sort_valids(descending, &mut valids, &mut nulls, len, cmp);
     // collect the order of valid tuplies
     let mut valid_indices: Vec<u32> = valids.iter().map(|tuple| 
tuple.0).collect();
 
@@ -707,7 +697,7 @@ where
 fn sort_list<S, T>(
     values: &ArrayRef,
     value_indices: Vec<u32>,
-    mut null_indices: Vec<u32>,
+    null_indices: Vec<u32>,
     options: &SortOptions,
     limit: Option<usize>,
 ) -> UInt32Array
@@ -716,6 +706,19 @@ where
     T: ArrowPrimitiveType,
     T::Native: std::cmp::PartialOrd,
 {
+    sort_list_inner::<S>(values, value_indices, null_indices, options, limit)
+}
+
+fn sort_list_inner<S>(
+    values: &ArrayRef,
+    value_indices: Vec<u32>,
+    mut null_indices: Vec<u32>,
+    options: &SortOptions,
+    limit: Option<usize>,
+) -> UInt32Array
+where
+    S: OffsetSizeTrait,
+{
     let mut valids: Vec<(u32, ArrayRef)> = values
         .as_any()
         .downcast_ref::<FixedSizeListArray>()
@@ -738,23 +741,12 @@ where
         );
 
     let mut len = values.len();
-    let nulls_len = null_indices.len();
     let descending = options.descending;
 
     if let Some(limit) = limit {
         len = limit.min(len);
     }
-    if !descending {
-        sort_unstable_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
-            cmp_array(a.1.as_ref(), b.1.as_ref())
-        });
-    } else {
-        sort_unstable_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
-            cmp_array(a.1.as_ref(), b.1.as_ref()).reverse()
-        });
-        // reverse to keep a stable ordering
-        null_indices.reverse();
-    }
+    sort_valids_array(descending, &mut valids, &mut null_indices, len);
 
     let mut valid_indices: Vec<u32> = valids.iter().map(|tuple| 
tuple.0).collect();
     if options.nulls_first {
@@ -801,21 +793,12 @@ where
 
     let mut len = values.len();
     let descending = options.descending;
-    let nulls_len = null_indices.len();
 
     if let Some(limit) = limit {
         len = limit.min(len);
     }
-    if !descending {
-        sort_unstable_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
-            a.1.cmp(b.1)
-        });
-    } else {
-        sort_unstable_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
-            a.1.cmp(b.1).reverse()
-        });
-        null_indices.reverse();
-    }
+
+    sort_valids(descending, &mut valids, &mut null_indices, len, cmp);
 
     let mut valid_indices: Vec<u32> = valids.iter().map(|tuple| 
tuple.0).collect();
     if options.nulls_first {
@@ -1036,6 +1019,47 @@ impl LexicographicalComparator<'_> {
     }
 }
 
+fn sort_valids<T, U>(
+    descending: bool,
+    valids: &mut [(u32, T)],
+    nulls: &mut [U],
+    len: usize,
+    mut cmp: impl FnMut(T, T) -> Ordering,
+) where
+    T: ?Sized + Copy,
+{
+    let nulls_len = nulls.len();
+    if !descending {
+        sort_unstable_by(valids, len.saturating_sub(nulls_len), |a, b| 
cmp(a.1, b.1));
+    } else {
+        sort_unstable_by(valids, len.saturating_sub(nulls_len), |a, b| {
+            cmp(a.1, b.1).reverse()
+        });
+        // reverse to keep a stable ordering
+        nulls.reverse();
+    }
+}
+
+fn sort_valids_array<T>(
+    descending: bool,
+    valids: &mut [(u32, ArrayRef)],
+    nulls: &mut [T],
+    len: usize,
+) {
+    let nulls_len = nulls.len();
+    if !descending {
+        sort_unstable_by(valids, len.saturating_sub(nulls_len), |a, b| {
+            cmp_array(a.1.as_ref(), b.1.as_ref())
+        });
+    } else {
+        sort_unstable_by(valids, len.saturating_sub(nulls_len), |a, b| {
+            cmp_array(a.1.as_ref(), b.1.as_ref()).reverse()
+        });
+        // reverse to keep a stable ordering
+        nulls.reverse();
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs
index 225f263..7147972 100644
--- a/arrow/src/compute/kernels/take.rs
+++ b/arrow/src/compute/kernels/take.rs
@@ -302,20 +302,17 @@ impl Default for TakeOptions {
 }
 
 #[inline(always)]
-fn maybe_usize<I: ArrowPrimitiveType>(index: I::Native) -> Result<usize> {
+fn maybe_usize<I: ArrowNativeType>(index: I) -> Result<usize> {
     index
         .to_usize()
         .ok_or_else(|| ArrowError::ComputeError("Cast to usize 
failed".to_string()))
 }
 
 // take implementation when neither values nor indices contain nulls
-fn take_no_nulls<T, I>(
-    values: &[T::Native],
-    indices: &[I::Native],
-) -> Result<(Buffer, Option<Buffer>)>
+fn take_no_nulls<T, I>(values: &[T], indices: &[I]) -> Result<(Buffer, 
Option<Buffer>)>
 where
-    T: ArrowPrimitiveType,
-    I: ArrowNumericType,
+    T: ArrowNativeType,
+    I: ArrowNativeType,
 {
     let values = indices
         .iter()
@@ -329,27 +326,36 @@ where
 // take implementation when only values contain nulls
 fn take_values_nulls<T, I>(
     values: &PrimitiveArray<T>,
-    indices: &[I::Native],
+    indices: &[I],
 ) -> Result<(Buffer, Option<Buffer>)>
 where
     T: ArrowPrimitiveType,
-    I: ArrowNumericType,
-    I::Native: ToPrimitive,
+    I: ArrowNativeType,
+{
+    take_values_nulls_inner(values.data(), values.values(), indices)
+}
+
+fn take_values_nulls_inner<T, I>(
+    values_data: &ArrayData,
+    values: &[T],
+    indices: &[I],
+) -> Result<(Buffer, Option<Buffer>)>
+where
+    T: ArrowNativeType,
+    I: ArrowNativeType,
 {
     let num_bytes = bit_util::ceil(indices.len(), 8);
     let mut nulls = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
     let null_slice = nulls.as_slice_mut();
     let mut null_count = 0;
 
-    let values_values = values.values();
-
     let values = indices.iter().enumerate().map(|(i, index)| {
         let index = maybe_usize::<I>(*index)?;
-        if values.is_null(index) {
+        if values_data.is_null(index) {
             null_count += 1;
             bit_util::unset_bit(null_slice, i);
         }
-        Result::Ok(values_values[index])
+        Result::Ok(values[index])
     });
     // Soundness: `slice.map` is `TrustedLen`.
     let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? };
@@ -366,21 +372,33 @@ where
 
 // take implementation when only indices contain nulls
 fn take_indices_nulls<T, I>(
-    values: &[T::Native],
+    values: &[T],
     indices: &PrimitiveArray<I>,
 ) -> Result<(Buffer, Option<Buffer>)>
 where
-    T: ArrowPrimitiveType,
+    T: ArrowNativeType,
     I: ArrowNumericType,
     I::Native: ToPrimitive,
 {
-    let values = indices.values().iter().map(|index| {
+    take_indices_nulls_inner(values, indices.values(), indices.data())
+}
+
+fn take_indices_nulls_inner<T, I>(
+    values: &[T],
+    indices: &[I],
+    indices_data: &ArrayData,
+) -> Result<(Buffer, Option<Buffer>)>
+where
+    T: ArrowNativeType,
+    I: ArrowNativeType,
+{
+    let values = indices.iter().map(|index| {
         let index = maybe_usize::<I>(*index)?;
         Result::Ok(match values.get(index) {
             Some(value) => *value,
             None => {
-                if indices.is_null(index) {
-                    T::Native::default()
+                if indices_data.is_null(index) {
+                    T::default()
                 } else {
                     panic!("Out-of-bounds index {}", index)
                 }
@@ -393,10 +411,9 @@ where
 
     Ok((
         buffer,
-        indices
-            .data_ref()
+        indices_data
             .null_buffer()
-            .map(|b| b.bit_slice(indices.offset(), indices.len())),
+            .map(|b| b.bit_slice(indices_data.offset(), indices.len())),
     ))
 }
 
@@ -410,25 +427,41 @@ where
     I: ArrowNumericType,
     I::Native: ToPrimitive,
 {
+    take_values_indices_nulls_inner(
+        values.values(),
+        values.data(),
+        indices.values(),
+        indices.data(),
+    )
+}
+
+fn take_values_indices_nulls_inner<T, I>(
+    values: &[T],
+    values_data: &ArrayData,
+    indices: &[I],
+    indices_data: &ArrayData,
+) -> Result<(Buffer, Option<Buffer>)>
+where
+    T: ArrowNativeType,
+    I: ArrowNativeType,
+{
     let num_bytes = bit_util::ceil(indices.len(), 8);
     let mut nulls = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
     let null_slice = nulls.as_slice_mut();
     let mut null_count = 0;
 
-    let values_values = values.values();
-    let values = indices.iter().enumerate().map(|(i, index)| match index {
-        Some(index) => {
+    let values = indices.iter().enumerate().map(|(i, &index)| {
+        if indices_data.is_null(i) {
+            null_count += 1;
+            bit_util::unset_bit(null_slice, i);
+            Ok(T::default())
+        } else {
             let index = maybe_usize::<I>(index)?;
-            if values.is_null(index) {
+            if values_data.is_null(index) {
                 null_count += 1;
                 bit_util::unset_bit(null_slice, i);
             }
-            Result::Ok(values_values[index])
-        }
-        None => {
-            null_count += 1;
-            bit_util::unset_bit(null_slice, i);
-            Ok(T::Native::default())
+            Result::Ok(values[index])
         }
     });
     // Soundness: `slice.map` is `TrustedLen`.
@@ -471,17 +504,17 @@ where
         (false, false) => {
             // * no nulls
             // * all `indices.values()` are valid
-            take_no_nulls::<T, I>(values.values(), indices.values())?
+            take_no_nulls::<T::Native, I::Native>(values.values(), 
indices.values())?
         }
         (true, false) => {
             // * nulls come from `values` alone
             // * all `indices.values()` are valid
-            take_values_nulls::<T, I>(values, indices.values())?
+            take_values_nulls::<T, I::Native>(values, indices.values())?
         }
         (false, true) => {
             // in this branch it is unsound to read and use `index.values()`,
             // as doing so is UB when they come from a null slot.
-            take_indices_nulls::<T, I>(values.values(), indices)?
+            take_indices_nulls::<T::Native, I>(values.values(), indices)?
         }
         (true, true) => {
             // in this branch it is unsound to read and use `index.values()`,
@@ -795,7 +828,7 @@ where
         .values()
         .iter()
         .map(|idx| {
-            let idx = maybe_usize::<IndexType>(*idx)?;
+            let idx = maybe_usize::<IndexType::Native>(*idx)?;
             if data_ref.is_valid(idx) {
                 Ok(Some(values.value(idx)))
             } else {
@@ -821,7 +854,7 @@ where
         .values()
         .iter()
         .map(|idx| {
-            let idx = maybe_usize::<IndexType>(*idx)?;
+            let idx = maybe_usize::<IndexType::Native>(*idx)?;
             if data_ref.is_valid(idx) {
                 Ok(Some(values.value(idx)))
             } else {

Reply via email to