This is an automated email from the ASF dual-hosted git repository.

nevime pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 8b35014  ARROW-5350: [Rust] Allow filtering on simple lists
8b35014 is described below

commit 8b350146f92388897af75a52e77f43ae939b6bef
Author: Neville Dipale <nevilled...@gmail.com>
AuthorDate: Sun Oct 18 07:46:19 2020 +0200

    ARROW-5350: [Rust] Allow filtering on simple lists
    
    This extends filters to simple lists.  CC @yordan-pavlov
    
    Closes #8364 from nevi-me/ARROW-5350
    
    Lead-authored-by: Neville Dipale <nevilled...@gmail.com>
    Co-authored-by: Yordan Pavlov <64363766+yordan-pav...@users.noreply.github.com>
    Signed-off-by: Neville Dipale <nevilled...@gmail.com>
---
 rust/arrow/src/compute/kernels/filter.rs | 372 ++++++++++++++++++++++++++++++-
 1 file changed, 367 insertions(+), 5 deletions(-)

diff --git a/rust/arrow/src/compute/kernels/filter.rs b/rust/arrow/src/compute/kernels/filter.rs
index cb5812c..8286f5c6a 100644
--- a/rust/arrow/src/compute/kernels/filter.rs
+++ b/rust/arrow/src/compute/kernels/filter.rs
@@ -17,8 +17,9 @@
 
 //! Defines miscellaneous array kernels.
 
+use crate::array::PrimitiveArrayOps;
 use crate::array::*;
-use crate::datatypes::{ArrowNumericType, DataType, TimeUnit};
+use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 use crate::record_batch::RecordBatch;
 use crate::{
@@ -161,7 +162,7 @@ fn filter_array_impl(
         // foreach u64 batch
         let filter_batch = *filter_batch;
         if filter_batch == 0 {
-            // if batch == 0: skip
+            // if batch == 0, all items are filtered out, so skip entire batch
             continue;
         } else if filter_batch == all_ones_batch {
             // if batch == all 1s: copy all 64 values in one go
@@ -230,6 +231,86 @@ macro_rules! filter_dictionary_array {
     }};
 }
 
+macro_rules! filter_primitive_item_list_array {
+    ($context:expr, $array:expr, $item_type:ident, $list_type:ident, $list_builder_type:ident) => {{
+        let input_array = $array.as_any().downcast_ref::<$list_type>().unwrap();
+        let values_builder = PrimitiveBuilder::<$item_type>::new($context.filtered_count);
+        let mut builder = $list_builder_type::new(values_builder);
+        for i in 0..$context.filter_u64.len() {
+            // foreach u64 batch
+            let filter_batch = $context.filter_u64[i];
+            if filter_batch == 0 {
+                // if batch == 0, all items are filtered out, so skip entire batch
+                continue;
+            }
+            for j in 0..64 {
+                // foreach bit in batch:
+                if (filter_batch & $context.filter_mask[j]) != 0 {
+                    let data_index = (i * 64) + j;
+                    if input_array.is_null(data_index) {
+                        builder.append(false)?;
+                    } else {
+                        let this_inner_list = input_array.value(data_index);
+                        let inner_list = this_inner_list
+                            .as_any()
+                            .downcast_ref::<PrimitiveArray<$item_type>>()
+                            .unwrap();
+                        for k in 0..inner_list.len() {
+                            if inner_list.is_null(k) {
+                                builder.values().append_null()?;
+                            } else {
+                                builder.values().append_value(inner_list.value(k))?;
+                            }
+                        }
+                        builder.append(true)?;
+                    }
+                }
+            }
+        }
+        Ok(Arc::new(builder.finish()))
+    }};
+}
+
+macro_rules! filter_non_primitive_item_list_array {
+    ($context:expr, $array:expr, $item_array_type:ident, $item_builder:ident, $list_type:ident, $list_builder_type:ident) => {{
+        let input_array = $array.as_any().downcast_ref::<$list_type>().unwrap();
+        let values_builder = $item_builder::new($context.filtered_count);
+        let mut builder = $list_builder_type::new(values_builder);
+        for i in 0..$context.filter_u64.len() {
+            // foreach u64 batch
+            let filter_batch = $context.filter_u64[i];
+            if filter_batch == 0 {
+                // if batch == 0, all items are filtered out, so skip entire batch
+                continue;
+            }
+            for j in 0..64 {
+                // foreach bit in batch:
+                if (filter_batch & $context.filter_mask[j]) != 0 {
+                    let data_index = (i * 64) + j;
+                    if input_array.is_null(data_index) {
+                        builder.append(false)?;
+                    } else {
+                        let this_inner_list = input_array.value(data_index);
+                        let inner_list = this_inner_list
+                            .as_any()
+                            .downcast_ref::<$item_array_type>()
+                            .unwrap();
+                        for k in 0..inner_list.len() {
+                            if inner_list.is_null(k) {
+                                builder.values().append_null()?;
+                            } else {
+                                builder.values().append_value(inner_list.value(k))?;
+                            }
+                        }
+                        builder.append(true)?;
+                    }
+                }
+            }
+        }
+        Ok(Arc::new(builder.finish()))
+    }};
+}
+
 impl FilterContext {
     /// Returns a new instance of FilterContext
     pub fn new(filter_array: &BooleanArray) -> Result<Self> {
@@ -285,7 +366,7 @@ impl FilterContext {
                     // foreach u64 batch
                     let filter_batch = self.filter_u64[i];
                     if filter_batch == 0 {
-                        // if batch == 0: skip
+                        // if batch == 0, all items are filtered out, so skip entire batch
                         continue;
                     }
                     for j in 0..64 {
@@ -347,7 +428,7 @@ impl FilterContext {
                     // foreach u64 batch
                     let filter_batch = self.filter_u64[i];
                     if filter_batch == 0 {
-                        // if batch == 0: skip
+                        // if batch == 0, all items are filtered out, so skip entire batch
                         continue;
                     }
                     for j in 0..64 {
@@ -371,7 +452,7 @@ impl FilterContext {
                     // foreach u64 batch
                     let filter_batch = self.filter_u64[i];
                     if filter_batch == 0 {
-                        // if batch == 0: skip
+                        // if batch == 0, all items are filtered out, so skip entire batch
                         continue;
                     }
                     for j in 0..64 {
@@ -408,6 +489,232 @@ impl FilterContext {
                     key_type, value_type
                 )))
             }
+            DataType::List(dt) => match &**dt {
+                DataType::UInt8 => {
+                    filter_primitive_item_list_array!(self, array, UInt8Type, ListArray, ListBuilder)
+                }
+                DataType::UInt16 => {
+                    filter_primitive_item_list_array!(self, array, UInt16Type, ListArray, ListBuilder)
+                }
+                DataType::UInt32 => {
+                    filter_primitive_item_list_array!(self, array, UInt32Type, ListArray, ListBuilder)
+                }
+                DataType::UInt64 => {
+                    filter_primitive_item_list_array!(self, array, UInt64Type, ListArray, ListBuilder)
+                }
+                DataType::Int8 => filter_primitive_item_list_array!(self, array, Int8Type, ListArray, ListBuilder),
+                DataType::Int16 => {
+                    filter_primitive_item_list_array!(self, array, Int16Type, ListArray, ListBuilder)
+                }
+                DataType::Int32 => {
+                    filter_primitive_item_list_array!(self, array, Int32Type, ListArray, ListBuilder)
+                }
+                DataType::Int64 => {
+                    filter_primitive_item_list_array!(self, array, Int64Type, ListArray, ListBuilder)
+                }
+                DataType::Float32 => {
+                    filter_primitive_item_list_array!(self, array, Float32Type, ListArray, ListBuilder)
+                }
+                DataType::Float64 => {
+                    filter_primitive_item_list_array!(self, array, Float64Type, ListArray, ListBuilder)
+                }
+                DataType::Boolean => {
+                    filter_primitive_item_list_array!(self, array, BooleanType, ListArray, ListBuilder)
+                }
+                DataType::Date32(_) => {
+                    filter_primitive_item_list_array!(self, array, Date32Type, ListArray, ListBuilder)
+                }
+                DataType::Date64(_) => {
+                    filter_primitive_item_list_array!(self, array, Date64Type, ListArray, ListBuilder)
+                }
+                DataType::Time32(TimeUnit::Second) => {
+                    filter_primitive_item_list_array!(self, array, Time32SecondType, ListArray, ListBuilder)
+                }
+                DataType::Time32(TimeUnit::Millisecond) => {
+                    filter_primitive_item_list_array!(self, array, Time32MillisecondType, ListArray, ListBuilder)
+                }
+                DataType::Time64(TimeUnit::Microsecond) => {
+                    filter_primitive_item_list_array!(self, array, Time64MicrosecondType, ListArray, ListBuilder)
+                }
+                DataType::Time64(TimeUnit::Nanosecond) => {
+                    filter_primitive_item_list_array!(self, array, Time64NanosecondType, ListArray, ListBuilder)
+                }
+                DataType::Duration(TimeUnit::Second) => {
+                    filter_primitive_item_list_array!(self, array, DurationSecondType, ListArray, ListBuilder)
+                }
+                DataType::Duration(TimeUnit::Millisecond) => {
+                    filter_primitive_item_list_array!(self, array, DurationMillisecondType, ListArray, ListBuilder)
+                }
+                DataType::Duration(TimeUnit::Microsecond) => {
+                    filter_primitive_item_list_array!(self, array, DurationMicrosecondType, ListArray, ListBuilder)
+                }
+                DataType::Duration(TimeUnit::Nanosecond) => {
+                    filter_primitive_item_list_array!(self, array, DurationNanosecondType, ListArray, ListBuilder)
+                }
+                DataType::Timestamp(TimeUnit::Second, _) => {
+                    filter_primitive_item_list_array!(self, array, TimestampSecondType, ListArray, ListBuilder)
+                }
+                DataType::Timestamp(TimeUnit::Millisecond, _) => {
+                    filter_primitive_item_list_array!(self, array, TimestampMillisecondType, ListArray, ListBuilder)
+                }
+                DataType::Timestamp(TimeUnit::Microsecond, _) => {
+                    filter_primitive_item_list_array!(self, array, TimestampMicrosecondType, ListArray, ListBuilder)
+                }
+                DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+                    filter_primitive_item_list_array!(self, array, TimestampNanosecondType, ListArray, ListBuilder)
+                }
+                DataType::Binary => filter_non_primitive_item_list_array!(
+                    self,
+                    array,
+                    BinaryArray,
+                    BinaryBuilder,
+                    ListArray,
+                    ListBuilder
+                ),
+                DataType::LargeBinary => filter_non_primitive_item_list_array!(
+                    self,
+                    array,
+                    LargeBinaryArray,
+                    LargeBinaryBuilder,
+                    ListArray,
+                    ListBuilder
+                ),
+                DataType::Utf8 => filter_non_primitive_item_list_array!(
+                    self,
+                    array,
+                    StringArray,
+                    StringBuilder,
+                    ListArray
+                    ,ListBuilder
+                ),
+                DataType::LargeUtf8 => filter_non_primitive_item_list_array!(
+                    self,
+                    array,
+                    LargeStringArray,
+                    LargeStringBuilder,
+                    ListArray,
+                    ListBuilder
+                ),
+                other => {
+                    Err(ArrowError::ComputeError(format!(
+                        "filter not supported for List({:?})",
+                        other
+                    )))
+                }
+            }
+            DataType::LargeList(dt) => match &**dt {
+                DataType::UInt8 => {
+                    filter_primitive_item_list_array!(self, array, UInt8Type, LargeListArray, LargeListBuilder)
+                }
+                DataType::UInt16 => {
+                    filter_primitive_item_list_array!(self, array, UInt16Type, LargeListArray, LargeListBuilder)
+                }
+                DataType::UInt32 => {
+                    filter_primitive_item_list_array!(self, array, UInt32Type, LargeListArray, LargeListBuilder)
+                }
+                DataType::UInt64 => {
+                    filter_primitive_item_list_array!(self, array, UInt64Type, LargeListArray, LargeListBuilder)
+                }
+                DataType::Int8 => filter_primitive_item_list_array!(self, array, Int8Type, LargeListArray, LargeListBuilder),
+                DataType::Int16 => {
+                    filter_primitive_item_list_array!(self, array, Int16Type, LargeListArray, LargeListBuilder)
+                }
+                DataType::Int32 => {
+                    filter_primitive_item_list_array!(self, array, Int32Type, LargeListArray, LargeListBuilder)
+                }
+                DataType::Int64 => {
+                    filter_primitive_item_list_array!(self, array, Int64Type, LargeListArray, LargeListBuilder)
+                }
+                DataType::Float32 => {
+                    filter_primitive_item_list_array!(self, array, Float32Type, LargeListArray, LargeListBuilder)
+                }
+                DataType::Float64 => {
+                    filter_primitive_item_list_array!(self, array, Float64Type, LargeListArray, LargeListBuilder)
+                }
+                DataType::Boolean => {
+                    filter_primitive_item_list_array!(self, array, BooleanType, LargeListArray, LargeListBuilder)
+                }
+                DataType::Date32(_) => {
+                    filter_primitive_item_list_array!(self, array, Date32Type, LargeListArray, LargeListBuilder)
+                }
+                DataType::Date64(_) => {
+                    filter_primitive_item_list_array!(self, array, Date64Type, LargeListArray, LargeListBuilder)
+                }
+                DataType::Time32(TimeUnit::Second) => {
+                    filter_primitive_item_list_array!(self, array, Time32SecondType, LargeListArray, LargeListBuilder)
+                }
+                DataType::Time32(TimeUnit::Millisecond) => {
+                    filter_primitive_item_list_array!(self, array, Time32MillisecondType, LargeListArray, LargeListBuilder)
+                }
+                DataType::Time64(TimeUnit::Microsecond) => {
+                    filter_primitive_item_list_array!(self, array, Time64MicrosecondType, LargeListArray, LargeListBuilder)
+                }
+                DataType::Time64(TimeUnit::Nanosecond) => {
+                    filter_primitive_item_list_array!(self, array, Time64NanosecondType, LargeListArray, LargeListBuilder)
+                }
+                DataType::Duration(TimeUnit::Second) => {
+                    filter_primitive_item_list_array!(self, array, DurationSecondType, LargeListArray, LargeListBuilder)
+                }
+                DataType::Duration(TimeUnit::Millisecond) => {
+                    filter_primitive_item_list_array!(self, array, DurationMillisecondType, LargeListArray, LargeListBuilder)
+                }
+                DataType::Duration(TimeUnit::Microsecond) => {
+                    filter_primitive_item_list_array!(self, array, DurationMicrosecondType, LargeListArray, LargeListBuilder)
+                }
+                DataType::Duration(TimeUnit::Nanosecond) => {
+                    filter_primitive_item_list_array!(self, array, DurationNanosecondType, LargeListArray, LargeListBuilder)
+                }
+                DataType::Timestamp(TimeUnit::Second, _) => {
+                    filter_primitive_item_list_array!(self, array, TimestampSecondType, LargeListArray, LargeListBuilder)
+                }
+                DataType::Timestamp(TimeUnit::Millisecond, _) => {
+                    filter_primitive_item_list_array!(self, array, TimestampMillisecondType, LargeListArray, LargeListBuilder)
+                }
+                DataType::Timestamp(TimeUnit::Microsecond, _) => {
+                    filter_primitive_item_list_array!(self, array, TimestampMicrosecondType, LargeListArray, LargeListBuilder)
+                }
+                DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+                    filter_primitive_item_list_array!(self, array, TimestampNanosecondType, LargeListArray, LargeListBuilder)
+                }
+                DataType::Binary => filter_non_primitive_item_list_array!(
+                    self,
+                    array,
+                    BinaryArray,
+                    BinaryBuilder,
+                    LargeListArray,
+                    LargeListBuilder
+                ),
+                DataType::LargeBinary => filter_non_primitive_item_list_array!(
+                    self,
+                    array,
+                    LargeBinaryArray,
+                    LargeBinaryBuilder,
+                    LargeListArray,
+                    LargeListBuilder
+                ),
+                DataType::Utf8 => filter_non_primitive_item_list_array!(
+                    self,
+                    array,
+                    StringArray,
+                    StringBuilder,
+                    LargeListArray,
+                    LargeListBuilder
+                ),
+                DataType::LargeUtf8 => filter_non_primitive_item_list_array!(
+                    self,
+                    array,
+                    LargeStringArray,
+                    LargeStringBuilder,
+                    LargeListArray,
+                    LargeListBuilder
+                ),
+                other => {
+                    Err(ArrowError::ComputeError(format!(
+                        "filter not supported for LargeList({:?})",
+                        other
+                    )))
+                }
+            }
             other => Err(ArrowError::ComputeError(format!(
                 "filter not supported for {:?}",
                 other
@@ -500,6 +807,8 @@ pub fn filter_record_batch(
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::buffer::Buffer;
+    use crate::datatypes::ToByteSlice;
 
     macro_rules! def_temporal_test {
         ($test:ident, $array_type: ident, $data: expr) => {
@@ -767,4 +1076,57 @@ mod tests {
         assert_eq!("hello", d.value(0));
         assert_eq!("world", d.value(1));
     }
+
+    #[test]
+    fn test_filter_list_array() {
+        let value_data = ArrayData::builder(DataType::Int32)
+            .len(8)
+            .add_buffer(Buffer::from(&[0, 1, 2, 3, 4, 5, 6, 7].to_byte_slice()))
+            .build();
+
+        let value_offsets = Buffer::from(&[0i64, 3, 6, 8, 8].to_byte_slice());
+
+        let list_data_type = DataType::LargeList(Box::new(DataType::Int32));
+        let list_data = ArrayData::builder(list_data_type)
+            .len(4)
+            .add_buffer(value_offsets)
+            .add_child_data(value_data)
+            .null_bit_buffer(Buffer::from([0b00000111]))
+            .build();
+
+        //  a = [[0, 1, 2], [3, 4, 5], [6, 7], null]
+        let a = LargeListArray::from(list_data);
+        let b = BooleanArray::from(vec![false, true, false, true]);
+        let c = filter(&a, &b).unwrap();
+        let d = c
+            .as_ref()
+            .as_any()
+            .downcast_ref::<LargeListArray>()
+            .unwrap();
+
+        assert_eq!(DataType::Int32, d.value_type());
+
+        // result should be [[3, 4, 5], null]
+        assert_eq!(2, d.len());
+        assert_eq!(1, d.null_count());
+        assert_eq!(true, d.is_null(1));
+
+        assert_eq!(0, d.value_offset(0));
+        assert_eq!(3, d.value_length(0));
+        assert_eq!(3, d.value_offset(1));
+        assert_eq!(0, d.value_length(1));
+        assert_eq!(
+            Buffer::from(&[3, 4, 5].to_byte_slice()),
+            d.values().data().buffers()[0].clone()
+        );
+        assert_eq!(
+            Buffer::from(&[0i64, 3, 3].to_byte_slice()),
+            d.data().buffers()[0].clone()
+        );
+        let inner_list = d.value(0);
+        let inner_list = inner_list.as_any().downcast_ref::<Int32Array>().unwrap();
+        assert_eq!(3, inner_list.len());
+        assert_eq!(0, inner_list.null_count());
+        assert_eq!(inner_list, &Int32Array::from(vec![3, 4, 5]));
+    }
 }

Reply via email to