This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new af40ea382 Implement specialized min/max for `GenericBinaryView` (`StringView` and `BinaryView`) (#6089)
af40ea382 is described below

commit af40ea382275dba967bfabc1632fded07d2129b9
Author: Xiangpeng Hao <[email protected]>
AuthorDate: Tue Jul 23 15:48:25 2024 -0400

    Implement specialized min/max for `GenericBinaryView` (`StringView` and `BinaryView`) (#6089)
    
    * implement better min/max for string view
    
    * Apply suggestions from code review
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * address review comments
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 arrow-arith/src/aggregate.rs             | 51 ++++++++++++++++++++++++---
 arrow-array/src/array/byte_view_array.rs | 60 ++++++++++++++++++++++++++++++++
 arrow-ord/src/cmp.rs                     |  7 ++--
 arrow/benches/aggregate_kernels.rs       | 23 ++++++++++--
 4 files changed, 132 insertions(+), 9 deletions(-)

diff --git a/arrow-arith/src/aggregate.rs b/arrow-arith/src/aggregate.rs
index f526268bf..a81e7fc1a 100644
--- a/arrow-arith/src/aggregate.rs
+++ b/arrow-arith/src/aggregate.rs
@@ -24,7 +24,9 @@ use arrow_buffer::{ArrowNativeType, NullBuffer};
 use arrow_data::bit_iterator::try_for_each_valid_idx;
 use arrow_schema::*;
 use std::borrow::BorrowMut;
+use std::cmp::{self, Ordering};
 use std::ops::{BitAnd, BitOr, BitXor};
+use types::ByteViewType;
 
 /// An accumulator for primitive numeric values.
 trait NumericAccumulator<T: ArrowNativeTypeOp>: Copy + Default {
@@ -425,6 +427,47 @@ where
     }
 }
 
+/// Helper to compute min/max of [`GenericByteViewArray<T>`].
+/// The specialized min/max leverages the inlined values to compare the byte 
views.
+/// `swap_cond` is the condition to swap current min/max with the new value.
+/// For example, `Ordering::Greater` for max and `Ordering::Less` for min.
+fn min_max_view_helper<T: ByteViewType>(
+    array: &GenericByteViewArray<T>,
+    swap_cond: cmp::Ordering,
+) -> Option<&T::Native> {
+    let null_count = array.null_count();
+    if null_count == array.len() {
+        None
+    } else if null_count == 0 {
+        let target_idx = (0..array.len()).reduce(|acc, item| {
+            // SAFETY:  array's length is correct so item is within bounds
+            let cmp = unsafe { GenericByteViewArray::compare_unchecked(array, 
item, array, acc) };
+            if cmp == swap_cond {
+                item
+            } else {
+                acc
+            }
+        });
+        // SAFETY: idx came from valid range `0..array.len()`
+        unsafe { target_idx.map(|idx| array.value_unchecked(idx)) }
+    } else {
+        let nulls = array.nulls().unwrap();
+
+        let target_idx = nulls.valid_indices().reduce(|acc_idx, idx| {
+            let cmp =
+                unsafe { GenericByteViewArray::compare_unchecked(array, idx, 
array, acc_idx) };
+            if cmp == swap_cond {
+                idx
+            } else {
+                acc_idx
+            }
+        });
+
+        // SAFETY: idx came from valid range `0..array.len()`
+        unsafe { target_idx.map(|idx| array.value_unchecked(idx)) }
+    }
+}
+
 /// Returns the maximum value in the binary array, according to the natural 
order.
 pub fn max_binary<T: OffsetSizeTrait>(array: &GenericBinaryArray<T>) -> 
Option<&[u8]> {
     min_max_helper::<&[u8], _, _>(array, |a, b| *a < *b)
@@ -432,7 +475,7 @@ pub fn max_binary<T: OffsetSizeTrait>(array: 
&GenericBinaryArray<T>) -> Option<&
 
 /// Returns the maximum value in the binary view array, according to the 
natural order.
 pub fn max_binary_view(array: &BinaryViewArray) -> Option<&[u8]> {
-    min_max_helper::<&[u8], _, _>(array, |a, b| *a < *b)
+    min_max_view_helper(array, Ordering::Greater)
 }
 
 /// Returns the minimum value in the binary array, according to the natural 
order.
@@ -442,7 +485,7 @@ pub fn min_binary<T: OffsetSizeTrait>(array: 
&GenericBinaryArray<T>) -> Option<&
 
 /// Returns the minimum value in the binary view array, according to the 
natural order.
 pub fn min_binary_view(array: &BinaryViewArray) -> Option<&[u8]> {
-    min_max_helper::<&[u8], _, _>(array, |a, b| *a > *b)
+    min_max_view_helper(array, Ordering::Less)
 }
 
 /// Returns the maximum value in the string array, according to the natural 
order.
@@ -452,7 +495,7 @@ pub fn max_string<T: OffsetSizeTrait>(array: 
&GenericStringArray<T>) -> Option<&
 
 /// Returns the maximum value in the string view array, according to the 
natural order.
 pub fn max_string_view(array: &StringViewArray) -> Option<&str> {
-    min_max_helper::<&str, _, _>(array, |a, b| *a < *b)
+    min_max_view_helper(array, Ordering::Greater)
 }
 
 /// Returns the minimum value in the string array, according to the natural 
order.
@@ -462,7 +505,7 @@ pub fn min_string<T: OffsetSizeTrait>(array: 
&GenericStringArray<T>) -> Option<&
 
 /// Returns the minimum value in the string view array, according to the 
natural order.
 pub fn min_string_view(array: &StringViewArray) -> Option<&str> {
-    min_max_helper::<&str, _, _>(array, |a, b| *a > *b)
+    min_max_view_helper(array, Ordering::Less)
 }
 
 /// Returns the sum of values in the array.
diff --git a/arrow-array/src/array/byte_view_array.rs 
b/arrow-array/src/array/byte_view_array.rs
index 7017add49..bd8c0cebc 100644
--- a/arrow-array/src/array/byte_view_array.rs
+++ b/arrow-array/src/array/byte_view_array.rs
@@ -336,6 +336,66 @@ impl<T: ByteViewType + ?Sized> GenericByteViewArray<T> {
 
         builder.finish()
     }
+
+    /// Comparing two [`GenericByteViewArray`] at index `left_idx` and 
`right_idx`
+    ///
+    /// Comparing two ByteView types are non-trivial.
+    /// It takes a bit of patience to understand why we don't just compare two 
&[u8] directly.
+    ///
+    /// ByteView types give us the following two advantages, and we need to be 
careful not to lose them:
+    /// (1) For string/byte smaller than 12 bytes, the entire data is inlined 
in the view.
+    ///     Meaning that reading one array element requires only one memory 
access
+    ///     (two memory access required for StringArray, one for offset 
buffer, the other for value buffer).
+    ///
+    /// (2) For string/byte larger than 12 bytes, we can still be faster than 
(for certain operations) StringArray/ByteArray,
+    ///     thanks to the inlined 4 bytes.
+    ///     Consider equality check:
+    ///     If the first four bytes of the two strings are different, we can 
return false immediately (with just one memory access).
+    ///
+    /// If we directly compare two &[u8], we materialize the entire string 
(i.e., make multiple memory accesses), which might be unnecessary.
+    /// - Most of the time (eq, ord), we only need to look at the first 4 
bytes to know the answer,
+    ///   e.g., if the inlined 4 bytes are different, we can directly return 
unequal without looking at the full string.
+    ///
+    /// # Order check flow
+    /// (1) if both string are smaller than 12 bytes, we can directly compare 
the data inlined to the view.
+    /// (2) if any of the string is larger than 12 bytes, we need to compare 
the full string.
+    ///     (2.1) if the inlined 4 bytes are different, we can return the 
result immediately.
+    ///     (2.2) o.w., we need to compare the full string.
+    ///
+    /// # Safety
+    /// The left/right_idx must within range of each array
+    pub unsafe fn compare_unchecked(
+        left: &GenericByteViewArray<T>,
+        left_idx: usize,
+        right: &GenericByteViewArray<T>,
+        right_idx: usize,
+    ) -> std::cmp::Ordering {
+        let l_view = left.views().get_unchecked(left_idx);
+        let l_len = *l_view as u32;
+
+        let r_view = right.views().get_unchecked(right_idx);
+        let r_len = *r_view as u32;
+
+        if l_len <= 12 && r_len <= 12 {
+            let l_data = unsafe { 
GenericByteViewArray::<T>::inline_value(l_view, l_len as usize) };
+            let r_data = unsafe { 
GenericByteViewArray::<T>::inline_value(r_view, r_len as usize) };
+            return l_data.cmp(r_data);
+        }
+
+        // one of the string is larger than 12 bytes,
+        // we then try to compare the inlined data first
+        let l_inlined_data = unsafe { 
GenericByteViewArray::<T>::inline_value(l_view, 4) };
+        let r_inlined_data = unsafe { 
GenericByteViewArray::<T>::inline_value(r_view, 4) };
+        if r_inlined_data != l_inlined_data {
+            return l_inlined_data.cmp(r_inlined_data);
+        }
+
+        // unfortunately, we need to compare the full data
+        let l_full_data: &[u8] = unsafe { 
left.value_unchecked(left_idx).as_ref() };
+        let r_full_data: &[u8] = unsafe { 
right.value_unchecked(right_idx).as_ref() };
+
+        l_full_data.cmp(r_full_data)
+    }
 }
 
 impl<T: ByteViewType + ?Sized> Debug for GenericByteViewArray<T> {
diff --git a/arrow-ord/src/cmp.rs b/arrow-ord/src/cmp.rs
index 26eb0d8d6..9d7874c64 100644
--- a/arrow-ord/src/cmp.rs
+++ b/arrow-ord/src/cmp.rs
@@ -579,13 +579,13 @@ impl<'a, T: ByteViewType> ArrayOrd for &'a 
GenericByteViewArray<T> {
             return false;
         }
 
-        unsafe { compare_byte_view_unchecked(l.0, l.1, r.0, r.1).is_eq() }
+        unsafe { GenericByteViewArray::compare_unchecked(l.0, l.1, r.0, 
r.1).is_eq() }
     }
 
     fn is_lt(l: Self::Item, r: Self::Item) -> bool {
         // # Safety
         // The index is within bounds as it is checked in value()
-        unsafe { compare_byte_view_unchecked(l.0, l.1, r.0, r.1).is_lt() }
+        unsafe { GenericByteViewArray::compare_unchecked(l.0, l.1, r.0, 
r.1).is_lt() }
     }
 
     fn len(&self) -> usize {
@@ -626,7 +626,7 @@ pub fn compare_byte_view<T: ByteViewType>(
 ) -> std::cmp::Ordering {
     assert!(left_idx < left.len());
     assert!(right_idx < right.len());
-    unsafe { compare_byte_view_unchecked(left, left_idx, right, right_idx) }
+    unsafe { GenericByteViewArray::compare_unchecked(left, left_idx, right, 
right_idx) }
 }
 
 /// Comparing two [`GenericByteViewArray`] at index `left_idx` and `right_idx`
@@ -656,6 +656,7 @@ pub fn compare_byte_view<T: ByteViewType>(
 ///
 /// # Safety
 /// The left/right_idx must within range of each array
+#[deprecated(note = "Use `GenericByteViewArray::compare_unchecked` instead")]
 pub unsafe fn compare_byte_view_unchecked<T: ByteViewType>(
     left: &GenericByteViewArray<T>,
     left_idx: usize,
diff --git a/arrow/benches/aggregate_kernels.rs 
b/arrow/benches/aggregate_kernels.rs
index 9bb866f36..434bb4778 100644
--- a/arrow/benches/aggregate_kernels.rs
+++ b/arrow/benches/aggregate_kernels.rs
@@ -57,8 +57,8 @@ fn add_benchmark(c: &mut Criterion) {
     primitive_benchmark::<Int64Type>(c, "int64");
 
     {
-        let nonnull_strings = create_string_array::<i32>(BATCH_SIZE, 0.0);
-        let nullable_strings = create_string_array::<i32>(BATCH_SIZE, 0.5);
+        let nonnull_strings = create_string_array_with_len::<i32>(BATCH_SIZE, 
0.0, 16);
+        let nullable_strings = create_string_array_with_len::<i32>(BATCH_SIZE, 
0.5, 16);
         c.benchmark_group("string")
             .throughput(Throughput::Elements(BATCH_SIZE as u64))
             .bench_function("min nonnull", |b| b.iter(|| 
min_string(&nonnull_strings)))
@@ -67,6 +67,25 @@ fn add_benchmark(c: &mut Criterion) {
             .bench_function("max nullable", |b| b.iter(|| 
max_string(&nullable_strings)));
     }
 
+    {
+        let nonnull_strings = create_string_view_array_with_len(BATCH_SIZE, 
0.0, 16, false);
+        let nullable_strings = create_string_view_array_with_len(BATCH_SIZE, 
0.5, 16, false);
+        c.benchmark_group("string view")
+            .throughput(Throughput::Elements(BATCH_SIZE as u64))
+            .bench_function("min nonnull", |b| {
+                b.iter(|| min_string_view(&nonnull_strings))
+            })
+            .bench_function("max nonnull", |b| {
+                b.iter(|| max_string_view(&nonnull_strings))
+            })
+            .bench_function("min nullable", |b| {
+                b.iter(|| min_string_view(&nullable_strings))
+            })
+            .bench_function("max nullable", |b| {
+                b.iter(|| max_string_view(&nullable_strings))
+            });
+    }
+
     {
         let nonnull_bools_mixed = create_boolean_array(BATCH_SIZE, 0.0, 0.5);
         let nonnull_bools_all_false = create_boolean_array(BATCH_SIZE, 0.0, 
0.0);

Reply via email to