This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new e88e5aa92e5 Implement `filter` kernel for byte view arrays. (#5624)
e88e5aa92e5 is described below

commit e88e5aa92e5419b01073cc0ea6649303fcb86b5f
Author: RinChanNOW <[email protected]>
AuthorDate: Mon Apr 15 20:02:12 2024 +0800

    Implement `filter` kernel for byte view arrays. (#5624)
    
    * Implement `filter` kernel for byte view arrays.
    
    * Add unit tests and fix.
    
    * Deprecate `ArrowPrimitiveType::get_byte_width`.
    
    * Add string view filter benchmark.
---
 arrow-arith/src/arity.rs        |   3 +-
 arrow-array/src/types.rs        |   1 +
 arrow-buffer/src/native.rs      |   5 ++
 arrow-data/src/data.rs          |   5 ++
 arrow-select/src/filter.rs      | 112 ++++++++++++++++++++++++++++++++++++----
 arrow/benches/filter_kernels.rs |  26 ++++++++++
 6 files changed, 142 insertions(+), 10 deletions(-)

diff --git a/arrow-arith/src/arity.rs b/arrow-arith/src/arity.rs
index 3d8214d89dc..ff5c8e822cc 100644
--- a/arrow-arith/src/arity.rs
+++ b/arrow-arith/src/arity.rs
@@ -21,6 +21,7 @@ use arrow_array::builder::BufferBuilder;
 use arrow_array::types::ArrowDictionaryKeyType;
 use arrow_array::*;
 use arrow_buffer::buffer::NullBuffer;
+use arrow_buffer::ArrowNativeType;
 use arrow_buffer::{Buffer, MutableBuffer};
 use arrow_data::ArrayData;
 use arrow_schema::ArrowError;
@@ -386,7 +387,7 @@ where
     O: ArrowPrimitiveType,
     F: Fn(A::Item, B::Item) -> Result<O::Native, ArrowError>,
 {
-    let mut buffer = MutableBuffer::new(len * O::get_byte_width());
+    let mut buffer = MutableBuffer::new(len * O::Native::get_byte_width());
     for idx in 0..len {
         unsafe {
             buffer.push_unchecked(op(a.value_unchecked(idx), b.value_unchecked(idx))?);
diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs
index e33f7bde7cb..038b2a291f5 100644
--- a/arrow-array/src/types.rs
+++ b/arrow-array/src/types.rs
@@ -59,6 +59,7 @@ pub trait ArrowPrimitiveType: primitive::PrimitiveTypeSealed + 'static {
     const DATA_TYPE: DataType;
 
     /// Returns the byte width of this primitive type.
+    #[deprecated(note = "Use ArrowNativeType::get_byte_width")]
     fn get_byte_width() -> usize {
         std::mem::size_of::<Self::Native>()
     }
diff --git a/arrow-buffer/src/native.rs b/arrow-buffer/src/native.rs
index 680974351a4..de665d4e387 100644
--- a/arrow-buffer/src/native.rs
+++ b/arrow-buffer/src/native.rs
@@ -47,6 +47,11 @@ mod private {
 pub trait ArrowNativeType:
     std::fmt::Debug + Send + Sync + Copy + PartialOrd + Default + private::Sealed + 'static
 {
+    /// Returns the byte width of this native type.
+    fn get_byte_width() -> usize {
+        std::mem::size_of::<Self>()
+    }
+
     /// Convert native integer type from usize
     ///
     /// Returns `None` if [`Self`] is not an integer or conversion would result
diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs
index 358c44e41b4..0ea74f789dd 100644
--- a/arrow-data/src/data.rs
+++ b/arrow-data/src/data.rs
@@ -1770,6 +1770,11 @@ impl ArrayDataBuilder {
         self
     }
 
+    pub fn add_buffers(mut self, bs: Vec<Buffer>) -> Self {
+        self.buffers.extend(bs);
+        self
+    }
+
     pub fn child_data(mut self, v: Vec<ArrayData>) -> Self {
         self.child_data = v;
         self
diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs
index 2af19ff8505..8e06b07f5ef 100644
--- a/arrow-select/src/filter.rs
+++ b/arrow-select/src/filter.rs
@@ -23,10 +23,10 @@ use std::sync::Arc;
 use arrow_array::builder::BooleanBufferBuilder;
 use arrow_array::cast::AsArray;
 use arrow_array::types::{
-    ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, RunEndIndexType,
+    ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
 };
 use arrow_array::*;
-use arrow_buffer::{bit_util, BooleanBuffer, NullBuffer, RunEndBuffer};
+use arrow_buffer::{bit_util, ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer};
 use arrow_buffer::{Buffer, MutableBuffer};
 use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
 use arrow_data::transform::MutableArrayData;
@@ -333,12 +333,18 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<Array
             DataType::LargeUtf8 => {
                 Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
             }
+            DataType::Utf8View => {
+                Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
+            }
             DataType::Binary => {
                 Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
             }
             DataType::LargeBinary => {
                 Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
             }
+            DataType::BinaryView => {
+                Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
+            }
             DataType::RunEndEncoded(_, _) => {
                 downcast_run_array!{
                     values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
@@ -508,12 +514,8 @@ fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanA
     BooleanArray::from(data)
 }
 
-/// `filter` implementation for primitive arrays
-fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
-where
-    T: ArrowPrimitiveType,
-{
-    let values = array.values();
+#[inline(never)]
+fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
     assert!(values.len() >= predicate.filter.len());
 
     let buffer = match &predicate.strategy {
@@ -546,9 +548,19 @@ where
         IterationStrategy::All | IterationStrategy::None => unreachable!(),
     };
 
+    buffer.into()
+}
+
+/// `filter` implementation for primitive arrays
+fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
+where
+    T: ArrowPrimitiveType,
+{
+    let values = array.values();
+    let buffer = filter_native(values, predicate);
     let mut builder = ArrayDataBuilder::new(array.data_type().clone())
         .len(predicate.count)
-        .add_buffer(buffer.into());
+        .add_buffer(buffer);
 
     if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
         builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
@@ -673,6 +685,25 @@ where
     GenericByteArray::from(data)
 }
 
+/// `filter` implementation for byte view arrays.
+fn filter_byte_view<T: ByteViewType>(
+    array: &GenericByteViewArray<T>,
+    predicate: &FilterPredicate,
+) -> GenericByteViewArray<T> {
+    let new_view_buffer = filter_native(array.views(), predicate);
+
+    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
+        .len(predicate.count)
+        .add_buffer(new_view_buffer)
+        .add_buffers(array.data_buffers().to_vec());
+
+    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
+        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
+    }
+
+    GenericByteViewArray::from(unsafe { builder.build_unchecked() })
+}
+
 /// `filter` implementation for dictionaries
 fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
 where
@@ -888,6 +919,69 @@ mod tests {
         assert!(d.is_null(1));
     }
 
+    fn _test_filter_byte_view<T>()
+    where
+        T: ByteViewType,
+        str: AsRef<T::Native>,
+        T::Native: PartialEq,
+    {
+        let array = {
+            // ["hello", "world", null, "large payload over 12 bytes", "lulu"]
+            let mut builder = GenericByteViewBuilder::<T>::new();
+            builder.append_value("hello");
+            builder.append_value("world");
+            builder.append_null();
+            builder.append_value("large payload over 12 bytes");
+            builder.append_value("lulu");
+            builder.finish()
+        };
+
+        {
+            let predicate = BooleanArray::from(vec![true, false, true, true, false]);
+            let actual = filter(&array, &predicate).unwrap();
+
+            assert_eq!(actual.len(), 3);
+
+            let expected = {
+                // ["hello", null, "large payload over 12 bytes"]
+                let mut builder = GenericByteViewBuilder::<T>::new();
+                builder.append_value("hello");
+                builder.append_null();
+                builder.append_value("large payload over 12 bytes");
+                builder.finish()
+            };
+
+            assert_eq!(actual.as_ref(), &expected);
+        }
+
+        {
+            let predicate = BooleanArray::from(vec![true, false, false, false, true]);
+            let actual = filter(&array, &predicate).unwrap();
+
+            assert_eq!(actual.len(), 2);
+
+            let expected = {
+                // ["hello", "lulu"]
+                let mut builder = GenericByteViewBuilder::<T>::new();
+                builder.append_value("hello");
+                builder.append_value("lulu");
+                builder.finish()
+            };
+
+            assert_eq!(actual.as_ref(), &expected);
+        }
+    }
+
+    #[test]
+    fn test_filter_string_view() {
+        _test_filter_byte_view::<StringViewType>()
+    }
+
+    #[test]
+    fn test_filter_binary_view() {
+        _test_filter_byte_view::<BinaryViewType>()
+    }
+
     #[test]
     fn test_filter_array_slice_with_null() {
         let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
diff --git a/arrow/benches/filter_kernels.rs b/arrow/benches/filter_kernels.rs
index 50f3cb40094..e48b5302241 100644
--- a/arrow/benches/filter_kernels.rs
+++ b/arrow/benches/filter_kernels.rs
@@ -214,6 +214,32 @@ fn add_benchmark(c: &mut Criterion) {
     c.bench_function("filter single record batch", |b| {
         b.iter(|| filter_record_batch(&batch, &filter_array))
     });
+
+    let data_array = create_string_view_array_with_len(size, 0.5, 4, false);
+    c.bench_function("filter context short string view (kept 1/2)", |b| {
+        b.iter(|| bench_built_filter(&filter, &data_array))
+    });
+    c.bench_function(
+        "filter context short string view high selectivity (kept 1023/1024)",
+        |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)),
+    );
+    c.bench_function(
+        "filter context short string view low selectivity (kept 1/1024)",
+        |b| b.iter(|| bench_built_filter(&sparse_filter, &data_array)),
+    );
+
+    let data_array = create_string_view_array_with_len(size, 0.5, 4, true);
+    c.bench_function("filter context mixed string view (kept 1/2)", |b| {
+        b.iter(|| bench_built_filter(&filter, &data_array))
+    });
+    c.bench_function(
+        "filter context mixed string view high selectivity (kept 1023/1024)",
+        |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)),
+    );
+    c.bench_function(
+        "filter context mixed string view low selectivity (kept 1/1024)",
+        |b| b.iter(|| bench_built_filter(&sparse_filter, &data_array)),
+    );
 }
 
 criterion_group!(benches, add_benchmark);

Reply via email to