This is an automated email from the ASF dual-hosted git repository. nevime pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push: new 8b35014 ARROW-5350: [Rust] Allow filtering on simple lists 8b35014 is described below commit 8b350146f92388897af75a52e77f43ae939b6bef Author: Neville Dipale <nevilled...@gmail.com> AuthorDate: Sun Oct 18 07:46:19 2020 +0200 ARROW-5350: [Rust] Allow filtering on simple lists This extends filters to simple lists. CC @yordan-pavlov Closes #8364 from nevi-me/ARROW-5350 Lead-authored-by: Neville Dipale <nevilled...@gmail.com> Co-authored-by: Yordan Pavlov <64363766+yordan-pav...@users.noreply.github.com> Signed-off-by: Neville Dipale <nevilled...@gmail.com> --- rust/arrow/src/compute/kernels/filter.rs | 372 ++++++++++++++++++++++++++++++- 1 file changed, 367 insertions(+), 5 deletions(-) diff --git a/rust/arrow/src/compute/kernels/filter.rs b/rust/arrow/src/compute/kernels/filter.rs index cb5812c..8286f5c6a 100644 --- a/rust/arrow/src/compute/kernels/filter.rs +++ b/rust/arrow/src/compute/kernels/filter.rs @@ -17,8 +17,9 @@ //! Defines miscellaneous array kernels. +use crate::array::PrimitiveArrayOps; use crate::array::*; -use crate::datatypes::{ArrowNumericType, DataType, TimeUnit}; +use crate::datatypes::*; use crate::error::{ArrowError, Result}; use crate::record_batch::RecordBatch; use crate::{ @@ -161,7 +162,7 @@ fn filter_array_impl( // foreach u64 batch let filter_batch = *filter_batch; if filter_batch == 0 { - // if batch == 0: skip + // if batch == 0, all items are filtered out, so skip entire batch continue; } else if filter_batch == all_ones_batch { // if batch == all 1s: copy all 64 values in one go @@ -230,6 +231,86 @@ macro_rules! filter_dictionary_array { }}; } +macro_rules! filter_primitive_item_list_array { + ($context:expr, $array:expr, $item_type:ident, $list_type:ident, $list_builder_type:ident) => {{ + let input_array = $array.as_any().downcast_ref::<$list_type>().unwrap(); + let values_builder = PrimitiveBuilder::<$item_type>::new($context.filtered_count); + let mut builder = $list_builder_type::new(values_builder); + for i in 0..$context.filter_u64.len() { + // foreach u64 batch + let filter_batch = $context.filter_u64[i]; + if filter_batch == 0 { + // if batch == 0, all items are filtered out, so skip entire batch + continue; + } + for j in 0..64 { + // foreach bit in batch: + if (filter_batch & $context.filter_mask[j]) != 0 { + let data_index = (i * 64) + j; + if input_array.is_null(data_index) { + builder.append(false)?; + } else { + let this_inner_list = input_array.value(data_index); + let inner_list = this_inner_list + .as_any() + .downcast_ref::<PrimitiveArray<$item_type>>() + .unwrap(); + for k in 0..inner_list.len() { + if inner_list.is_null(k) { + builder.values().append_null()?; + } else { + builder.values().append_value(inner_list.value(k))?; + } + } + builder.append(true)?; + } + } + } + } + Ok(Arc::new(builder.finish())) + }}; +} + +macro_rules! filter_non_primitive_item_list_array { + ($context:expr, $array:expr, $item_array_type:ident, $item_builder:ident, $list_type:ident, $list_builder_type:ident) => {{ + let input_array = $array.as_any().downcast_ref::<$list_type>().unwrap(); + let values_builder = $item_builder::new($context.filtered_count); + let mut builder = $list_builder_type::new(values_builder); + for i in 0..$context.filter_u64.len() { + // foreach u64 batch + let filter_batch = $context.filter_u64[i]; + if filter_batch == 0 { + // if batch == 0, all items are filtered out, so skip entire batch + continue; + } + for j in 0..64 { + // foreach bit in batch: + if (filter_batch & $context.filter_mask[j]) != 0 { + let data_index = (i * 64) + j; + if input_array.is_null(data_index) { + builder.append(false)?; + } else { + let this_inner_list = input_array.value(data_index); + let inner_list = this_inner_list + .as_any() + .downcast_ref::<$item_array_type>() + .unwrap(); + for k in 0..inner_list.len() { + if inner_list.is_null(k) { + builder.values().append_null()?; + } else { + builder.values().append_value(inner_list.value(k))?; + } + } + builder.append(true)?; + } + } + } + } + Ok(Arc::new(builder.finish())) + }}; +} + impl FilterContext { /// Returns a new instance of FilterContext pub fn new(filter_array: &BooleanArray) -> Result<Self> { @@ -285,7 +366,7 @@ impl FilterContext { // foreach u64 batch let filter_batch = self.filter_u64[i]; if filter_batch == 0 { - // if batch == 0: skip + // if batch == 0, all items are filtered out, so skip entire batch continue; } for j in 0..64 { @@ -347,7 +428,7 @@ impl FilterContext { // foreach u64 batch let filter_batch = self.filter_u64[i]; if filter_batch == 0 { - // if batch == 0: skip + // if batch == 0, all items are filtered out, so skip entire batch continue; } for j in 0..64 { @@ -371,7 +452,7 @@ impl FilterContext { // foreach u64 batch let filter_batch = self.filter_u64[i]; if filter_batch == 0 { - // if batch == 0: skip + // if batch == 0, all items are filtered out, so skip entire batch continue; } for j in 0..64 { @@ -408,6 +489,232 @@ impl FilterContext { key_type, value_type ))) } + DataType::List(dt) => match &**dt { + DataType::UInt8 => { + filter_primitive_item_list_array!(self, array, UInt8Type, ListArray, ListBuilder) + } + DataType::UInt16 => { + filter_primitive_item_list_array!(self, array, UInt16Type, ListArray, ListBuilder) + } + DataType::UInt32 => { + filter_primitive_item_list_array!(self, array, UInt32Type, ListArray, ListBuilder) + } + DataType::UInt64 => { + filter_primitive_item_list_array!(self, array, UInt64Type, ListArray, ListBuilder) + } + DataType::Int8 => filter_primitive_item_list_array!(self, array, Int8Type, ListArray, ListBuilder), + DataType::Int16 => { + filter_primitive_item_list_array!(self, array, Int16Type, ListArray, ListBuilder) + } + DataType::Int32 => { + filter_primitive_item_list_array!(self, array, Int32Type, ListArray, ListBuilder) + } + DataType::Int64 => { + filter_primitive_item_list_array!(self, array, Int64Type, ListArray, ListBuilder) + } + DataType::Float32 => { + filter_primitive_item_list_array!(self, array, Float32Type, ListArray, ListBuilder) + } + DataType::Float64 => { + filter_primitive_item_list_array!(self, array, Float64Type, ListArray, ListBuilder) + } + DataType::Boolean => { + filter_primitive_item_list_array!(self, array, BooleanType, ListArray, ListBuilder) + } + DataType::Date32(_) => { + filter_primitive_item_list_array!(self, array, Date32Type, ListArray, ListBuilder) + } + DataType::Date64(_) => { + filter_primitive_item_list_array!(self, array, Date64Type, ListArray, ListBuilder) + } + DataType::Time32(TimeUnit::Second) => { + filter_primitive_item_list_array!(self, array, Time32SecondType, ListArray, ListBuilder) + } + DataType::Time32(TimeUnit::Millisecond) => { + filter_primitive_item_list_array!(self, array, Time32MillisecondType, ListArray, ListBuilder) + } + DataType::Time64(TimeUnit::Microsecond) => { + filter_primitive_item_list_array!(self, array, Time64MicrosecondType, ListArray, ListBuilder) + } + DataType::Time64(TimeUnit::Nanosecond) => { + filter_primitive_item_list_array!(self, array, Time64NanosecondType, ListArray, ListBuilder) + } + DataType::Duration(TimeUnit::Second) => { + filter_primitive_item_list_array!(self, array, DurationSecondType, ListArray, ListBuilder) + } + DataType::Duration(TimeUnit::Millisecond) => { + filter_primitive_item_list_array!(self, array, DurationMillisecondType, ListArray, ListBuilder) + } + DataType::Duration(TimeUnit::Microsecond) => { + filter_primitive_item_list_array!(self, array, DurationMicrosecondType, ListArray, ListBuilder) + } + DataType::Duration(TimeUnit::Nanosecond) => { + filter_primitive_item_list_array!(self, array, DurationNanosecondType, ListArray, ListBuilder) + } + DataType::Timestamp(TimeUnit::Second, _) => { + filter_primitive_item_list_array!(self, array, TimestampSecondType, ListArray, ListBuilder) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + filter_primitive_item_list_array!(self, array, TimestampMillisecondType, ListArray, ListBuilder) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + filter_primitive_item_list_array!(self, array, TimestampMicrosecondType, ListArray, ListBuilder) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + filter_primitive_item_list_array!(self, array, TimestampNanosecondType, ListArray, ListBuilder) + } + DataType::Binary => filter_non_primitive_item_list_array!( + self, + array, + BinaryArray, + BinaryBuilder, + ListArray, + ListBuilder + ), + DataType::LargeBinary => filter_non_primitive_item_list_array!( + self, + array, + LargeBinaryArray, + LargeBinaryBuilder, + ListArray, + ListBuilder + ), + DataType::Utf8 => filter_non_primitive_item_list_array!( + self, + array, + StringArray, + StringBuilder, + ListArray + ,ListBuilder + ), + DataType::LargeUtf8 => filter_non_primitive_item_list_array!( + self, + array, + LargeStringArray, + LargeStringBuilder, + ListArray, + ListBuilder + ), + other => { + Err(ArrowError::ComputeError(format!( + "filter not supported for List({:?})", + other + ))) + } + } + DataType::LargeList(dt) => match &**dt { + DataType::UInt8 => { + filter_primitive_item_list_array!(self, array, UInt8Type, LargeListArray, LargeListBuilder) + } + DataType::UInt16 => { + filter_primitive_item_list_array!(self, array, UInt16Type, LargeListArray, LargeListBuilder) + } + DataType::UInt32 => { + filter_primitive_item_list_array!(self, array, UInt32Type, LargeListArray, LargeListBuilder) + } + DataType::UInt64 => { + filter_primitive_item_list_array!(self, array, UInt64Type, LargeListArray, LargeListBuilder) + } + DataType::Int8 => filter_primitive_item_list_array!(self, array, Int8Type, LargeListArray, LargeListBuilder), + DataType::Int16 => { + filter_primitive_item_list_array!(self, array, Int16Type, LargeListArray, LargeListBuilder) + } + DataType::Int32 => { + filter_primitive_item_list_array!(self, array, Int32Type, LargeListArray, LargeListBuilder) + } + DataType::Int64 => { + filter_primitive_item_list_array!(self, array, Int64Type, LargeListArray, LargeListBuilder) + } + DataType::Float32 => { + filter_primitive_item_list_array!(self, array, Float32Type, LargeListArray, LargeListBuilder) + } + DataType::Float64 => { + filter_primitive_item_list_array!(self, array, Float64Type, LargeListArray, LargeListBuilder) + } + DataType::Boolean => { + filter_primitive_item_list_array!(self, array, BooleanType, LargeListArray, LargeListBuilder) + } + DataType::Date32(_) => { + filter_primitive_item_list_array!(self, array, Date32Type, LargeListArray, LargeListBuilder) + } + DataType::Date64(_) => { + filter_primitive_item_list_array!(self, array, Date64Type, LargeListArray, LargeListBuilder) + } + DataType::Time32(TimeUnit::Second) => { + filter_primitive_item_list_array!(self, array, Time32SecondType, LargeListArray, LargeListBuilder) + } + DataType::Time32(TimeUnit::Millisecond) => { + filter_primitive_item_list_array!(self, array, Time32MillisecondType, LargeListArray, LargeListBuilder) + } + DataType::Time64(TimeUnit::Microsecond) => { + filter_primitive_item_list_array!(self, array, Time64MicrosecondType, LargeListArray, LargeListBuilder) + } + DataType::Time64(TimeUnit::Nanosecond) => { + filter_primitive_item_list_array!(self, array, Time64NanosecondType, LargeListArray, LargeListBuilder) + } + DataType::Duration(TimeUnit::Second) => { + filter_primitive_item_list_array!(self, array, DurationSecondType, LargeListArray, LargeListBuilder) + } + DataType::Duration(TimeUnit::Millisecond) => { + filter_primitive_item_list_array!(self, array, DurationMillisecondType, LargeListArray, LargeListBuilder) + } + DataType::Duration(TimeUnit::Microsecond) => { + filter_primitive_item_list_array!(self, array, DurationMicrosecondType, LargeListArray, LargeListBuilder) + } + DataType::Duration(TimeUnit::Nanosecond) => { + filter_primitive_item_list_array!(self, array, DurationNanosecondType, LargeListArray, LargeListBuilder) + } + DataType::Timestamp(TimeUnit::Second, _) => { + filter_primitive_item_list_array!(self, array, TimestampSecondType, LargeListArray, LargeListBuilder) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + filter_primitive_item_list_array!(self, array, TimestampMillisecondType, LargeListArray, LargeListBuilder) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + filter_primitive_item_list_array!(self, array, TimestampMicrosecondType, LargeListArray, LargeListBuilder) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + filter_primitive_item_list_array!(self, array, TimestampNanosecondType, LargeListArray, LargeListBuilder) + } + DataType::Binary => filter_non_primitive_item_list_array!( + self, + array, + BinaryArray, + BinaryBuilder, + LargeListArray, + LargeListBuilder + ), + DataType::LargeBinary => filter_non_primitive_item_list_array!( + self, + array, + LargeBinaryArray, + LargeBinaryBuilder, + LargeListArray, + LargeListBuilder + ), + DataType::Utf8 => filter_non_primitive_item_list_array!( + self, + array, + StringArray, + StringBuilder, + LargeListArray, + LargeListBuilder + ), + DataType::LargeUtf8 => filter_non_primitive_item_list_array!( + self, + array, + LargeStringArray, + LargeStringBuilder, + LargeListArray, + LargeListBuilder + ), + other => { + Err(ArrowError::ComputeError(format!( + "filter not supported for LargeList({:?})", + other + ))) + } + } other => Err(ArrowError::ComputeError(format!( "filter not supported for {:?}", other @@ -500,6 +807,8 @@ pub fn filter_record_batch( #[cfg(test)] mod tests { use super::*; + use crate::buffer::Buffer; + use crate::datatypes::ToByteSlice; macro_rules! def_temporal_test { ($test:ident, $array_type: ident, $data: expr) => { @@ -767,4 +1076,57 @@ mod tests { assert_eq!("hello", d.value(0)); assert_eq!("world", d.value(1)); } + + #[test] + fn test_filter_list_array() { + let value_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from(&[0, 1, 2, 3, 4, 5, 6, 7].to_byte_slice())) + .build(); + + let value_offsets = Buffer::from(&[0i64, 3, 6, 8, 8].to_byte_slice()); + + let list_data_type = DataType::LargeList(Box::new(DataType::Int32)); + let list_data = ArrayData::builder(list_data_type) + .len(4) + .add_buffer(value_offsets) + .add_child_data(value_data) + .null_bit_buffer(Buffer::from([0b00000111])) + .build(); + + // a = [[0, 1, 2], [3, 4, 5], [6, 7], null] + let a = LargeListArray::from(list_data); + let b = BooleanArray::from(vec![false, true, false, true]); + let c = filter(&a, &b).unwrap(); + let d = c + .as_ref() + .as_any() + .downcast_ref::<LargeListArray>() + .unwrap(); + + assert_eq!(DataType::Int32, d.value_type()); + + // result should be [[3, 4, 5], null] + assert_eq!(2, d.len()); + assert_eq!(1, d.null_count()); + assert_eq!(true, d.is_null(1)); + + assert_eq!(0, d.value_offset(0)); + assert_eq!(3, d.value_length(0)); + assert_eq!(3, d.value_offset(1)); + assert_eq!(0, d.value_length(1)); + assert_eq!( + Buffer::from(&[3, 4, 5].to_byte_slice()), + d.values().data().buffers()[0].clone() + ); + assert_eq!( + Buffer::from(&[0i64, 3, 3].to_byte_slice()), + d.data().buffers()[0].clone() + ); + let inner_list = d.value(0); + let inner_list = inner_list.as_any().downcast_ref::<Int32Array>().unwrap(); + assert_eq!(3, inner_list.len()); + assert_eq!(0, inner_list.null_count()); + assert_eq!(inner_list, &Int32Array::from(vec![3, 4, 5])); + } }