This is an automated email from the ASF dual-hosted git repository. alamb pushed a commit to branch cherry_pick_5c3ed612 in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
commit fd02abada7395872b230021fd808ff72e7cb9b49 Author: Markus Westerlind <[email protected]> AuthorDate: Mon Sep 13 18:55:56 2021 +0200 chore: Reduce the amount of code generated by monomorphization (#715) * chore: Reduce the number of instantiations of take* (-3%) Many types have the same native type, so simplifying these functions to work directly with native types reduces the number of instantiations. Reduces the number of llvm lines generated by ~3% * chore: Shrink try_from_trusted_len_iter (-0.5%) * chore: Make the inner take_ functions less generic (-3.5%) * chore: Extract the array sorter (-1%) --- arrow/src/buffer/mutable.rs | 22 +++++--- arrow/src/compute/kernels/take.rs | 107 +++++++++++++++++++++++++------------- 2 files changed, 86 insertions(+), 43 deletions(-) diff --git a/arrow/src/buffer/mutable.rs b/arrow/src/buffer/mutable.rs index 7d336e0..d83997a 100644 --- a/arrow/src/buffer/mutable.rs +++ b/arrow/src/buffer/mutable.rs @@ -530,12 +530,22 @@ impl MutableBuffer { std::ptr::write(dst, item?); dst = dst.add(1); } - assert_eq!( - dst.offset_from(buffer.data.as_ptr() as *mut T) as usize, - upper, - "Trusted iterator length was not accurately reported" - ); - buffer.len = len; + // try_from_trusted_len_iter is instantiated a lot, so we extract part of it into a less + // generic method to reduce compile time + unsafe fn finalize_buffer<T>( + dst: *mut T, + buffer: &mut MutableBuffer, + upper: usize, + len: usize, + ) { + assert_eq!( + dst.offset_from(buffer.data.as_ptr() as *mut T) as usize, + upper, + "Trusted iterator length was not accurately reported" + ); + buffer.len = len; + } + finalize_buffer(dst, &mut buffer, upper, len); Ok(buffer) } } diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs index 225f263..7147972 100644 --- a/arrow/src/compute/kernels/take.rs +++ b/arrow/src/compute/kernels/take.rs @@ -302,20 +302,17 @@ impl Default for TakeOptions { } #[inline(always)] -fn maybe_usize<I: ArrowPrimitiveType>(index: I::Native) -> Result<usize> { +fn maybe_usize<I: ArrowNativeType>(index: I) -> Result<usize> { index .to_usize() .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string())) } // take implementation when neither values nor indices contain nulls -fn take_no_nulls<T, I>( - values: &[T::Native], - indices: &[I::Native], -) -> Result<(Buffer, Option<Buffer>)> +fn take_no_nulls<T, I>(values: &[T], indices: &[I]) -> Result<(Buffer, Option<Buffer>)> where - T: ArrowPrimitiveType, - I: ArrowNumericType, + T: ArrowNativeType, + I: ArrowNativeType, { let values = indices .iter() @@ -329,27 +326,36 @@ where // take implementation when only values contain nulls fn take_values_nulls<T, I>( values: &PrimitiveArray<T>, - indices: &[I::Native], + indices: &[I], ) -> Result<(Buffer, Option<Buffer>)> where T: ArrowPrimitiveType, - I: ArrowNumericType, - I::Native: ToPrimitive, + I: ArrowNativeType, +{ + take_values_nulls_inner(values.data(), values.values(), indices) +} + +fn take_values_nulls_inner<T, I>( + values_data: &ArrayData, + values: &[T], + indices: &[I], +) -> Result<(Buffer, Option<Buffer>)> +where + T: ArrowNativeType, + I: ArrowNativeType, { let num_bytes = bit_util::ceil(indices.len(), 8); let mut nulls = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true); let null_slice = nulls.as_slice_mut(); let mut null_count = 0; - let values_values = values.values(); - let values = indices.iter().enumerate().map(|(i, index)| { let index = maybe_usize::<I>(*index)?; - if values.is_null(index) { + if values_data.is_null(index) { null_count += 1; bit_util::unset_bit(null_slice, i); } - Result::Ok(values_values[index]) + Result::Ok(values[index]) }); // Soundness: `slice.map` is `TrustedLen`. let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? }; @@ -366,21 +372,33 @@ where // take implementation when only indices contain nulls fn take_indices_nulls<T, I>( - values: &[T::Native], + values: &[T], indices: &PrimitiveArray<I>, ) -> Result<(Buffer, Option<Buffer>)> where - T: ArrowPrimitiveType, + T: ArrowNativeType, I: ArrowNumericType, I::Native: ToPrimitive, { - let values = indices.values().iter().map(|index| { + take_indices_nulls_inner(values, indices.values(), indices.data()) +} + +fn take_indices_nulls_inner<T, I>( + values: &[T], + indices: &[I], + indices_data: &ArrayData, +) -> Result<(Buffer, Option<Buffer>)> +where + T: ArrowNativeType, + I: ArrowNativeType, +{ + let values = indices.iter().map(|index| { let index = maybe_usize::<I>(*index)?; Result::Ok(match values.get(index) { Some(value) => *value, None => { - if indices.is_null(index) { - T::Native::default() + if indices_data.is_null(index) { + T::default() } else { panic!("Out-of-bounds index {}", index) } @@ -393,10 +411,9 @@ where Ok(( buffer, - indices - .data_ref() + indices_data .null_buffer() - .map(|b| b.bit_slice(indices.offset(), indices.len())), + .map(|b| b.bit_slice(indices_data.offset(), indices.len())), )) } @@ -410,25 +427,41 @@ where I: ArrowNumericType, I::Native: ToPrimitive, { + take_values_indices_nulls_inner( + values.values(), + values.data(), + indices.values(), + indices.data(), + ) +} + +fn take_values_indices_nulls_inner<T, I>( + values: &[T], + values_data: &ArrayData, + indices: &[I], + indices_data: &ArrayData, +) -> Result<(Buffer, Option<Buffer>)> +where + T: ArrowNativeType, + I: ArrowNativeType, +{ let num_bytes = bit_util::ceil(indices.len(), 8); let mut nulls = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true); let null_slice = nulls.as_slice_mut(); let mut null_count = 0; - let values_values = values.values(); - let values = indices.iter().enumerate().map(|(i, index)| match index { - Some(index) => { + let values = indices.iter().enumerate().map(|(i, &index)| { + if indices_data.is_null(i) { + null_count += 1; + bit_util::unset_bit(null_slice, i); + Ok(T::default()) + } else { let index = maybe_usize::<I>(index)?; - if values.is_null(index) { + if values_data.is_null(index) { null_count += 1; bit_util::unset_bit(null_slice, i); } - Result::Ok(values_values[index]) - } - None => { - null_count += 1; - bit_util::unset_bit(null_slice, i); - Ok(T::Native::default()) + Result::Ok(values[index]) } }); // Soundness: `slice.map` is `TrustedLen`. @@ -471,17 +504,17 @@ where (false, false) => { // * no nulls // * all `indices.values()` are valid - take_no_nulls::<T, I>(values.values(), indices.values())? + take_no_nulls::<T::Native, I::Native>(values.values(), indices.values())? } (true, false) => { // * nulls come from `values` alone // * all `indices.values()` are valid - take_values_nulls::<T, I>(values, indices.values())? + take_values_nulls::<T, I::Native>(values, indices.values())? } (false, true) => { // in this branch it is unsound to read and use `index.values()`, // as doing so is UB when they come from a null slot. - take_indices_nulls::<T, I>(values.values(), indices)? + take_indices_nulls::<T::Native, I>(values.values(), indices)? } (true, true) => { // in this branch it is unsound to read and use `index.values()`, @@ -795,7 +828,7 @@ where .values() .iter() .map(|idx| { - let idx = maybe_usize::<IndexType>(*idx)?; + let idx = maybe_usize::<IndexType::Native>(*idx)?; if data_ref.is_valid(idx) { Ok(Some(values.value(idx))) } else { @@ -821,7 +854,7 @@ where .values() .iter() .map(|idx| { - let idx = maybe_usize::<IndexType>(*idx)?; + let idx = maybe_usize::<IndexType::Native>(*idx)?; if data_ref.is_valid(idx) { Ok(Some(values.value(idx))) } else {
