This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new b4de6927c8 Improve performance of set_bits by avoiding to set individual bits (#6288)
b4de6927c8 is described below

commit b4de6927c8ad9955b1b723353a06be4afecdca57
Author: KAZUYUKI TANIMURA <[email protected]>
AuthorDate: Sun Sep 15 04:12:52 2024 -0700

    Improve performance of set_bits by avoiding to set individual bits (#6288)
    
    * bench
    
    * fix: Optimize set_bits
    
    * clippy
    
    * clippyj
    
    * miri
    
    * fix: Optimize set_bits
    
    * fix: Optimize set_bits
    
    * fix: Optimize set_bits
    
    * fix: Optimize set_bits
    
    * fix: Optimize set_bits
    
    * fix: Optimize set_bits
    
    * fix: Optimize set_bits
    
    * fix: Optimize set_bits
    
    * miri
    
    * miri
    
    * miri
    
    * miri
    
    * miri
    
    * miri
    
    * miri
    
    * miri
    
    * miri
    
    * miri
    
    * miri
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * Revert "address review comments"
    
    This reverts commit ef2864fe15d2c856c05eae70693d68eb2ae00fa8.
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * Revert "address review comments"
    
    This reverts commit a15db144effdfdae7dad4d93c8fb6eb93216dab0.
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * address review comments
---
 arrow-buffer/src/util/bit_mask.rs | 184 ++++++++++++++++++++++++++++++++------
 1 file changed, 158 insertions(+), 26 deletions(-)

diff --git a/arrow-buffer/src/util/bit_mask.rs b/arrow-buffer/src/util/bit_mask.rs
index 8f81cb7d04..2074f0fab9 100644
--- a/arrow-buffer/src/util/bit_mask.rs
+++ b/arrow-buffer/src/util/bit_mask.rs
@@ -17,12 +17,12 @@
 
 //! Utils for working with packed bit masks
 
-use crate::bit_chunk_iterator::BitChunks;
-use crate::bit_util::{ceil, get_bit, set_bit};
+use crate::bit_util::ceil;
 
 /// Sets all bits on `write_data` in the range `[offset_write..offset_write+len]` to be equal to the
 /// bits in `data` in the range `[offset_read..offset_read+len]`
 /// returns the number of `0` bits `data[offset_read..offset_read+len]`
+/// `offset_write`, `offset_read`, and `len` are in terms of bits
 pub fn set_bits(
     write_data: &mut [u8],
     data: &[u8],
@@ -30,35 +30,131 @@ pub fn set_bits(
     offset_read: usize,
     len: usize,
 ) -> usize {
+    assert!(offset_write + len <= write_data.len() * 8);
+    assert!(offset_read + len <= data.len() * 8);
     let mut null_count = 0;
-
-    let mut bits_to_align = offset_write % 8;
-    if bits_to_align > 0 {
-        bits_to_align = std::cmp::min(len, 8 - bits_to_align);
+    let mut acc = 0;
+    while len > acc {
+        // SAFETY: the arguments to `set_upto_64bits` are within the valid range because
+        // (offset_write + acc) + (len - acc) == offset_write + len <= write_data.len() * 8
+        // (offset_read + acc) + (len - acc) == offset_read + len <= data.len() * 8
+        let (n, len_set) = unsafe {
+            set_upto_64bits(
+                write_data,
+                data,
+                offset_write + acc,
+                offset_read + acc,
+                len - acc,
+            )
+        };
+        null_count += n;
+        acc += len_set;
     }
-    let mut write_byte_index = ceil(offset_write + bits_to_align, 8);
-
-    // Set full bytes provided by bit chunk iterator (which iterates in 64 bits at a time)
-    let chunks = BitChunks::new(data, offset_read + bits_to_align, len - bits_to_align);
-    chunks.iter().for_each(|chunk| {
-        null_count += chunk.count_zeros();
-        write_data[write_byte_index..write_byte_index + 8].copy_from_slice(&chunk.to_le_bytes());
-        write_byte_index += 8;
-    });
-
-    // Set individual bits both to align write_data to a byte offset and the remainder bits not covered by the bit chunk iterator
-    let remainder_offset = len - chunks.remainder_len();
-    (0..bits_to_align)
-        .chain(remainder_offset..len)
-        .for_each(|i| {
-            if get_bit(data, offset_read + i) {
-                set_bit(write_data, offset_write + i);
+
+    null_count
+}
+
+/// Similar to `set_bits` but sets only upto 64 bits, actual number of bits set may vary.
+/// Returns a pair of the number of `0` bits and the number of bits set
+///
+/// # Safety
+/// The caller must ensure all arguments are within the valid range.
+#[inline]
+unsafe fn set_upto_64bits(
+    write_data: &mut [u8],
+    data: &[u8],
+    offset_write: usize,
+    offset_read: usize,
+    len: usize,
+) -> (usize, usize) {
+    let read_byte = offset_read / 8;
+    let read_shift = offset_read % 8;
+    let write_byte = offset_write / 8;
+    let write_shift = offset_write % 8;
+
+    if len >= 64 {
+        let chunk = unsafe { (data.as_ptr().add(read_byte) as *const u64).read_unaligned() };
+        if read_shift == 0 {
+            if write_shift == 0 {
+                // no shifting necessary
+                let len = 64;
+                let null_count = chunk.count_zeros() as usize;
+                unsafe { write_u64_bytes(write_data, write_byte, chunk) };
+                (null_count, len)
             } else {
-                null_count += 1;
+                // only write shifting necessary
+                let len = 64 - write_shift;
+                let chunk = chunk << write_shift;
+                let null_count = len - chunk.count_ones() as usize;
+                unsafe { or_write_u64_bytes(write_data, write_byte, chunk) };
+                (null_count, len)
             }
-        });
+        } else if write_shift == 0 {
+            // only read shifting necessary
+            let len = 64 - 8; // 56 bits so the next set_upto_64bits call will see write_shift == 0
+            let chunk = (chunk >> read_shift) & 0x00FFFFFFFFFFFFFF; // 56 bits mask
+            let null_count = len - chunk.count_ones() as usize;
+            unsafe { write_u64_bytes(write_data, write_byte, chunk) };
+            (null_count, len)
+        } else {
+            let len = 64 - std::cmp::max(read_shift, write_shift);
+            let chunk = (chunk >> read_shift) << write_shift;
+            let null_count = len - chunk.count_ones() as usize;
+            unsafe { or_write_u64_bytes(write_data, write_byte, chunk) };
+            (null_count, len)
+        }
+    } else if len == 1 {
+        let byte_chunk = (unsafe { data.get_unchecked(read_byte) } >> read_shift) & 1;
+        unsafe { *write_data.get_unchecked_mut(write_byte) |= byte_chunk << write_shift };
+        ((byte_chunk ^ 1) as usize, 1)
+    } else {
+        let len = std::cmp::min(len, 64 - std::cmp::max(read_shift, write_shift));
+        let bytes = ceil(len + read_shift, 8);
+        // SAFETY: the args of `read_bytes_to_u64` are valid as read_byte + bytes <= data.len()
+        let chunk = unsafe { read_bytes_to_u64(data, read_byte, bytes) };
+        let mask = u64::MAX >> (64 - len);
+        let chunk = (chunk >> read_shift) & mask; // masking to read `len` bits only
+        let chunk = chunk << write_shift; // shifting back to align with `write_data`
+        let null_count = len - chunk.count_ones() as usize;
+        let bytes = ceil(len + write_shift, 8);
+        for (i, c) in chunk.to_le_bytes().iter().enumerate().take(bytes) {
+            unsafe { *write_data.get_unchecked_mut(write_byte + i) |= c };
+        }
+        (null_count, len)
+    }
+}
 
-    null_count as usize
+/// # Safety
+/// The caller must ensure all arguments are within the valid range.
+#[inline]
+unsafe fn read_bytes_to_u64(data: &[u8], offset: usize, count: usize) -> u64 {
+    debug_assert!(count <= 8);
+    let mut tmp = std::mem::MaybeUninit::<u64>::new(0);
+    let src = data.as_ptr().add(offset);
+    unsafe {
+        std::ptr::copy_nonoverlapping(src, tmp.as_mut_ptr() as *mut u8, count);
+        tmp.assume_init()
+    }
+}
+
+/// # Safety
+/// The caller must ensure `data` has `offset..(offset + 8)` range
+#[inline]
+unsafe fn write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) {
+    let ptr = data.as_mut_ptr().add(offset) as *mut u64;
+    ptr.write_unaligned(chunk);
+}
+
+/// Similar to `write_u64_bytes`, but this method ORs the offset addressed `data` and `chunk`
+/// instead of overwriting
+///
+/// # Safety
+/// The caller must ensure `data` has `offset..(offset + 8)` range
+#[inline]
+unsafe fn or_write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) {
+    let ptr = data.as_mut_ptr().add(offset);
+    let chunk = chunk | (*ptr) as u64;
+    (ptr as *mut u64).write_unaligned(chunk);
 }
 
 #[cfg(test)]
@@ -185,4 +281,40 @@ mod tests {
         assert_eq!(destination, expected_data);
         assert_eq!(result, expected_null_count);
     }
+
+    #[test]
+    fn test_set_upto_64bits() {
+        // len >= 64
+        let write_data: &mut [u8] = &mut [0; 9];
+        let data: &[u8] = &[
+            0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001,
+            0b00000001, 0b00000001,
+        ];
+        let offset_write = 1;
+        let offset_read = 0;
+        let len = 65;
+        let (n, len_set) =
+            unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) };
+        assert_eq!(n, 55);
+        assert_eq!(len_set, 63);
+        assert_eq!(
+            write_data,
+            &[
+                0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010,
+                0b00000010, 0b00000000
+            ]
+        );
+
+        // len = 1
+        let write_data: &mut [u8] = &mut [0b00000000];
+        let data: &[u8] = &[0b00000001];
+        let offset_write = 1;
+        let offset_read = 0;
+        let len = 1;
+        let (n, len_set) =
+            unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) };
+        assert_eq!(n, 0);
+        assert_eq!(len_set, 1);
+        assert_eq!(write_data, &[0b00000010]);
+    }
 }

Reply via email to