This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 5c5a94a11 Implement specialized filter kernel for `FixedSizeByteArray` (#6178)
5c5a94a11 is described below
commit 5c5a94a11f01a286dd03b18af0f11c327a9accc6
Author: pn <[email protected]>
AuthorDate: Sat Aug 10 00:40:29 2024 +0800
Implement specialized filter kernel for `FixedSizeByteArray` (#6178)
* refactor filter for FixedSizeByteArray
* fix expect
* remove benchmark code
* fix
* remove from_trusted_len_iter_slice_u8
* fmt
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
arrow-buffer/src/buffer/mutable.rs | 2 +-
arrow-select/src/filter.rs | 133 +++++++++++++++++++++++++++++++++++++
2 files changed, 134 insertions(+), 1 deletion(-)
diff --git a/arrow-buffer/src/buffer/mutable.rs b/arrow-buffer/src/buffer/mutable.rs
index e08d9c190..7fcbd89dd 100644
--- a/arrow-buffer/src/buffer/mutable.rs
+++ b/arrow-buffer/src/buffer/mutable.rs
@@ -896,7 +896,7 @@ mod tests {
#[test]
fn test_from_trusted_len_iter() {
let iter = vec![1u32, 2].into_iter();
- let buf = unsafe { Buffer::from_trusted_len_iter(iter) };
+ let buf = unsafe { MutableBuffer::from_trusted_len_iter(iter) };
assert_eq!(8, buf.len());
assert_eq!(&[1u8, 0, 0, 0, 2, 0, 0, 0], buf.as_slice());
}
diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs
index 8e06b07f5..d72794b3e 100644
--- a/arrow-select/src/filter.rs
+++ b/arrow-select/src/filter.rs
@@ -345,6 +345,9 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<Array
DataType::BinaryView => {
Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
}
+ DataType::FixedSizeBinary(_) => {
+ Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
+ }
DataType::RunEndEncoded(_, _) => {
downcast_run_array!{
values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
@@ -704,6 +707,64 @@ fn filter_byte_view<T: ByteViewType>(
GenericByteViewArray::from(unsafe { builder.build_unchecked() })
}
+fn filter_fixed_size_binary(
+ array: &FixedSizeBinaryArray,
+ predicate: &FilterPredicate,
+) -> FixedSizeBinaryArray {
+ let values: &[u8] = array.values();
+ let value_length = array.value_length() as usize;
+ let calculate_offset_from_index = |index: usize| index * value_length;
+ let buffer = match &predicate.strategy {
+ IterationStrategy::SlicesIterator => {
+ let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
+ for (start, end) in SlicesIterator::new(&predicate.filter) {
+ buffer.extend_from_slice(
+ &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
+ );
+ }
+ buffer
+ }
+ IterationStrategy::Slices(slices) => {
+ let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
+ for (start, end) in slices {
+ buffer.extend_from_slice(
+ &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
+ );
+ }
+ buffer
+ }
+ IterationStrategy::IndexIterator => {
+ let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
+ &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
+ });
+
+ let mut buffer = MutableBuffer::new(predicate.count * value_length);
+ iter.for_each(|item| buffer.extend_from_slice(item));
+ buffer
+ }
+ IterationStrategy::Indices(indices) => {
+ let iter = indices.iter().map(|x| {
+ &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
+ });
+
+ let mut buffer = MutableBuffer::new(predicate.count * value_length);
+ iter.for_each(|item| buffer.extend_from_slice(item));
+ buffer
+ }
+ IterationStrategy::All | IterationStrategy::None => unreachable!(),
+ };
+ let mut builder = ArrayDataBuilder::new(array.data_type().clone())
+ .len(predicate.count)
+ .add_buffer(buffer.into());
+
+ if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
+ builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
+ }
+
+ let data = unsafe { builder.build_unchecked() };
+ FixedSizeBinaryArray::from(data)
+}
+
/// `filter` implementation for dictionaries
fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
where
@@ -982,6 +1043,78 @@ mod tests {
_test_filter_byte_view::<BinaryViewType>()
}
+ #[test]
+ fn test_filter_fixed_binary() {
+ let v1 = [1_u8, 2];
+ let v2 = [3_u8, 4];
+ let v3 = [5_u8, 6];
+ let v = vec![&v1, &v2, &v3];
+ let a = FixedSizeBinaryArray::from(v);
+ let b = BooleanArray::from(vec![true, false, true]);
+ let c = filter(&a, &b).unwrap();
+ let d = c
+ .as_ref()
+ .as_any()
+ .downcast_ref::<FixedSizeBinaryArray>()
+ .unwrap();
+ assert_eq!(d.len(), 2);
+ assert_eq!(d.value(0), &v1);
+ assert_eq!(d.value(1), &v3);
+ let c2 = FilterBuilder::new(&b)
+ .optimize()
+ .build()
+ .filter(&a)
+ .unwrap();
+ let d2 = c2
+ .as_ref()
+ .as_any()
+ .downcast_ref::<FixedSizeBinaryArray>()
+ .unwrap();
+ assert_eq!(d, d2);
+
+ let b = BooleanArray::from(vec![false, false, false]);
+ let c = filter(&a, &b).unwrap();
+ let d = c
+ .as_ref()
+ .as_any()
+ .downcast_ref::<FixedSizeBinaryArray>()
+ .unwrap();
+ assert_eq!(d.len(), 0);
+
+ let b = BooleanArray::from(vec![true, true, true]);
+ let c = filter(&a, &b).unwrap();
+ let d = c
+ .as_ref()
+ .as_any()
+ .downcast_ref::<FixedSizeBinaryArray>()
+ .unwrap();
+ assert_eq!(d.len(), 3);
+ assert_eq!(d.value(0), &v1);
+ assert_eq!(d.value(1), &v2);
+ assert_eq!(d.value(2), &v3);
+
+ let b = BooleanArray::from(vec![false, false, true]);
+ let c = filter(&a, &b).unwrap();
+ let d = c
+ .as_ref()
+ .as_any()
+ .downcast_ref::<FixedSizeBinaryArray>()
+ .unwrap();
+ assert_eq!(d.len(), 1);
+ assert_eq!(d.value(0), &v3);
+ let c2 = FilterBuilder::new(&b)
+ .optimize()
+ .build()
+ .filter(&a)
+ .unwrap();
+ let d2 = c2
+ .as_ref()
+ .as_any()
+ .downcast_ref::<FixedSizeBinaryArray>()
+ .unwrap();
+ assert_eq!(d, d2);
+ }
+
#[test]
fn test_filter_array_slice_with_null() {
let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);