Dandandan commented on code in PR #6288: URL: https://github.com/apache/arrow-rs/pull/6288#discussion_r1751886691
########## arrow-buffer/src/util/bit_mask.rs: ########## @@ -30,35 +29,129 @@ pub fn set_bits( offset_read: usize, len: usize, ) -> usize { + assert!(offset_write + len <= write_data.len() * 8); + assert!(offset_read + len <= data.len() * 8); let mut null_count = 0; - - let mut bits_to_align = offset_write % 8; - if bits_to_align > 0 { - bits_to_align = std::cmp::min(len, 8 - bits_to_align); + let mut acc = 0; + while len > acc { + // SAFETY: the arguments to `set_upto_64bits` are within the valid range because + // (offset_write + acc) + (len - acc) == offset_write + len <= write_data.len() * 8 + // (offset_read + acc) + (len - acc) == offset_read + len <= data.len() * 8 + let (n, len_set) = unsafe { + set_upto_64bits( + write_data, + data, + offset_write + acc, + offset_read + acc, + len - acc, + ) + }; + null_count += n; + acc += len_set; } - let mut write_byte_index = ceil(offset_write + bits_to_align, 8); - - // Set full bytes provided by bit chunk iterator (which iterates in 64 bits at a time) - let chunks = BitChunks::new(data, offset_read + bits_to_align, len - bits_to_align); - chunks.iter().for_each(|chunk| { - null_count += chunk.count_zeros(); - write_data[write_byte_index..write_byte_index + 8].copy_from_slice(&chunk.to_le_bytes()); - write_byte_index += 8; - }); - - // Set individual bits both to align write_data to a byte offset and the remainder bits not covered by the bit chunk iterator - let remainder_offset = len - chunks.remainder_len(); - (0..bits_to_align) - .chain(remainder_offset..len) - .for_each(|i| { - if get_bit(data, offset_read + i) { - set_bit(write_data, offset_write + i); + + null_count +} + +/// # Safety +/// The caller must ensure all arguments are within the valid range. 
+#[inline] +unsafe fn set_upto_64bits( + write_data: &mut [u8], + data: &[u8], + offset_write: usize, + offset_read: usize, + len: usize, +) -> (usize, usize) { + let read_byte = offset_read / 8; + let read_shift = offset_read % 8; + let write_byte = offset_write / 8; + let write_shift = offset_write % 8; + + if len >= 64 { + let chunk = unsafe { (data.as_ptr().add(read_byte) as *const u64).read_unaligned() }; + if read_shift == 0 { + if write_shift == 0 { + // no shifting necessary + let len = 64; + let null_count = chunk.count_zeros() as usize; + unsafe { write_u64_bytes(write_data, write_byte, chunk) }; + (null_count, len) } else { - null_count += 1; + // only write shifting necessary + let len = 64 - write_shift; + let chunk = chunk << write_shift; + let null_count = len - chunk.count_ones() as usize; + unsafe { or_write_u64_bytes(write_data, write_byte, chunk) }; + (null_count, len) } - }); + } else if write_shift == 0 { + // only read shifting necessary + let len = 64 - 8; // 56 bits so the next set_upto_64bits call will see write_shift == 0 + let chunk = (chunk >> read_shift) & 0x00FFFFFFFFFFFFFF; // 56 bits mask + let null_count = len - chunk.count_ones() as usize; + unsafe { write_u64_bytes(write_data, write_byte, chunk) }; + (null_count, len) + } else { + let len = 64 - std::cmp::max(read_shift, write_shift); + let chunk = (chunk >> read_shift) << write_shift; + let null_count = len - chunk.count_ones() as usize; + unsafe { or_write_u64_bytes(write_data, write_byte, chunk) }; + (null_count, len) + } + } else if len == 1 { + let byte_chunk = (unsafe { data.get_unchecked(read_byte) } >> read_shift) & 1; + let ptr = write_data.as_mut_ptr(); + unsafe { *ptr.add(write_byte) |= byte_chunk << write_shift }; + ((byte_chunk ^ 1) as usize, 1) + } else { + let len = std::cmp::min(len, 64 - std::cmp::max(read_shift, write_shift)); + let bytes = ceil(len + read_shift, 8); + let chunk = unsafe { read_bytes_to_u64(data, read_byte, bytes) }; Review Comment: Could you add 
some `// SAFETY:` explanations to the `unsafe` usages? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org