This is an automated email from the ASF dual-hosted git repository.
nevime pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 8b35014 ARROW-5350: [Rust] Allow filtering on simple lists
8b35014 is described below
commit 8b350146f92388897af75a52e77f43ae939b6bef
Author: Neville Dipale <[email protected]>
AuthorDate: Sun Oct 18 07:46:19 2020 +0200
ARROW-5350: [Rust] Allow filtering on simple lists
This extends filters to simple lists. CC @yordan-pavlov
Closes #8364 from nevi-me/ARROW-5350
Lead-authored-by: Neville Dipale <[email protected]>
Co-authored-by: Yordan Pavlov <[email protected]>
Signed-off-by: Neville Dipale <[email protected]>
---
rust/arrow/src/compute/kernels/filter.rs | 372 ++++++++++++++++++++++++++++++-
1 file changed, 367 insertions(+), 5 deletions(-)
diff --git a/rust/arrow/src/compute/kernels/filter.rs b/rust/arrow/src/compute/kernels/filter.rs
index cb5812c..8286f5c6a 100644
--- a/rust/arrow/src/compute/kernels/filter.rs
+++ b/rust/arrow/src/compute/kernels/filter.rs
@@ -17,8 +17,9 @@
//! Defines miscellaneous array kernels.
+use crate::array::PrimitiveArrayOps;
use crate::array::*;
-use crate::datatypes::{ArrowNumericType, DataType, TimeUnit};
+use crate::datatypes::*;
use crate::error::{ArrowError, Result};
use crate::record_batch::RecordBatch;
use crate::{
@@ -161,7 +162,7 @@ fn filter_array_impl(
// foreach u64 batch
let filter_batch = *filter_batch;
if filter_batch == 0 {
- // if batch == 0: skip
+ // if batch == 0, all items are filtered out, so skip entire batch
continue;
} else if filter_batch == all_ones_batch {
// if batch == all 1s: copy all 64 values in one go
@@ -230,6 +231,86 @@ macro_rules! filter_dictionary_array {
}};
}
+macro_rules! filter_primitive_item_list_array {
+    ($context:expr, $array:expr, $item_type:ident, $list_type:ident, $list_builder_type:ident) => {{
+        let input_array = $array.as_any().downcast_ref::<$list_type>().unwrap();
+        let values_builder = PrimitiveBuilder::<$item_type>::new($context.filtered_count);
+ let mut builder = $list_builder_type::new(values_builder);
+ for i in 0..$context.filter_u64.len() {
+ // foreach u64 batch
+ let filter_batch = $context.filter_u64[i];
+ if filter_batch == 0 {
+                // if batch == 0, all items are filtered out, so skip entire batch
+ continue;
+ }
+ for j in 0..64 {
+ // foreach bit in batch:
+ if (filter_batch & $context.filter_mask[j]) != 0 {
+ let data_index = (i * 64) + j;
+ if input_array.is_null(data_index) {
+ builder.append(false)?;
+ } else {
+ let this_inner_list = input_array.value(data_index);
+ let inner_list = this_inner_list
+ .as_any()
+ .downcast_ref::<PrimitiveArray<$item_type>>()
+ .unwrap();
+ for k in 0..inner_list.len() {
+ if inner_list.is_null(k) {
+ builder.values().append_null()?;
+ } else {
+                                builder.values().append_value(inner_list.value(k))?;
+ }
+ }
+ builder.append(true)?;
+ }
+ }
+ }
+ }
+ Ok(Arc::new(builder.finish()))
+ }};
+}
+
+macro_rules! filter_non_primitive_item_list_array {
+    ($context:expr, $array:expr, $item_array_type:ident, $item_builder:ident, $list_type:ident, $list_builder_type:ident) => {{
+        let input_array = $array.as_any().downcast_ref::<$list_type>().unwrap();
+ let values_builder = $item_builder::new($context.filtered_count);
+ let mut builder = $list_builder_type::new(values_builder);
+ for i in 0..$context.filter_u64.len() {
+ // foreach u64 batch
+ let filter_batch = $context.filter_u64[i];
+ if filter_batch == 0 {
+                // if batch == 0, all items are filtered out, so skip entire batch
+ continue;
+ }
+ for j in 0..64 {
+ // foreach bit in batch:
+ if (filter_batch & $context.filter_mask[j]) != 0 {
+ let data_index = (i * 64) + j;
+ if input_array.is_null(data_index) {
+ builder.append(false)?;
+ } else {
+ let this_inner_list = input_array.value(data_index);
+ let inner_list = this_inner_list
+ .as_any()
+ .downcast_ref::<$item_array_type>()
+ .unwrap();
+ for k in 0..inner_list.len() {
+ if inner_list.is_null(k) {
+ builder.values().append_null()?;
+ } else {
+                                builder.values().append_value(inner_list.value(k))?;
+ }
+ }
+ builder.append(true)?;
+ }
+ }
+ }
+ }
+ Ok(Arc::new(builder.finish()))
+ }};
+}
+
impl FilterContext {
/// Returns a new instance of FilterContext
pub fn new(filter_array: &BooleanArray) -> Result<Self> {
@@ -285,7 +366,7 @@ impl FilterContext {
// foreach u64 batch
let filter_batch = self.filter_u64[i];
if filter_batch == 0 {
- // if batch == 0: skip
+                // if batch == 0, all items are filtered out, so skip entire batch
continue;
}
for j in 0..64 {
@@ -347,7 +428,7 @@ impl FilterContext {
// foreach u64 batch
let filter_batch = self.filter_u64[i];
if filter_batch == 0 {
- // if batch == 0: skip
+                    // if batch == 0, all items are filtered out, so skip entire batch
continue;
}
for j in 0..64 {
@@ -371,7 +452,7 @@ impl FilterContext {
// foreach u64 batch
let filter_batch = self.filter_u64[i];
if filter_batch == 0 {
- // if batch == 0: skip
+                    // if batch == 0, all items are filtered out, so skip entire batch
continue;
}
for j in 0..64 {
@@ -408,6 +489,232 @@ impl FilterContext {
key_type, value_type
)))
}
+ DataType::List(dt) => match &**dt {
+ DataType::UInt8 => {
+                    filter_primitive_item_list_array!(self, array, UInt8Type, ListArray, ListBuilder)
+ }
+ DataType::UInt16 => {
+                    filter_primitive_item_list_array!(self, array, UInt16Type, ListArray, ListBuilder)
+ }
+ DataType::UInt32 => {
+                    filter_primitive_item_list_array!(self, array, UInt32Type, ListArray, ListBuilder)
+ }
+ DataType::UInt64 => {
+                    filter_primitive_item_list_array!(self, array, UInt64Type, ListArray, ListBuilder)
+ }
+                DataType::Int8 => filter_primitive_item_list_array!(self, array, Int8Type, ListArray, ListBuilder),
+ DataType::Int16 => {
+                    filter_primitive_item_list_array!(self, array, Int16Type, ListArray, ListBuilder)
+ }
+ DataType::Int32 => {
+                    filter_primitive_item_list_array!(self, array, Int32Type, ListArray, ListBuilder)
+ }
+ DataType::Int64 => {
+                    filter_primitive_item_list_array!(self, array, Int64Type, ListArray, ListBuilder)
+ }
+ DataType::Float32 => {
+                    filter_primitive_item_list_array!(self, array, Float32Type, ListArray, ListBuilder)
+ }
+ DataType::Float64 => {
+                    filter_primitive_item_list_array!(self, array, Float64Type, ListArray, ListBuilder)
+ }
+ DataType::Boolean => {
+                    filter_primitive_item_list_array!(self, array, BooleanType, ListArray, ListBuilder)
+ }
+ DataType::Date32(_) => {
+                    filter_primitive_item_list_array!(self, array, Date32Type, ListArray, ListBuilder)
+ }
+ DataType::Date64(_) => {
+                    filter_primitive_item_list_array!(self, array, Date64Type, ListArray, ListBuilder)
+ }
+ DataType::Time32(TimeUnit::Second) => {
+                    filter_primitive_item_list_array!(self, array, Time32SecondType, ListArray, ListBuilder)
+ }
+ DataType::Time32(TimeUnit::Millisecond) => {
+                    filter_primitive_item_list_array!(self, array, Time32MillisecondType, ListArray, ListBuilder)
+ }
+ DataType::Time64(TimeUnit::Microsecond) => {
+                    filter_primitive_item_list_array!(self, array, Time64MicrosecondType, ListArray, ListBuilder)
+ }
+ DataType::Time64(TimeUnit::Nanosecond) => {
+                    filter_primitive_item_list_array!(self, array, Time64NanosecondType, ListArray, ListBuilder)
+ }
+ DataType::Duration(TimeUnit::Second) => {
+                    filter_primitive_item_list_array!(self, array, DurationSecondType, ListArray, ListBuilder)
+ }
+ DataType::Duration(TimeUnit::Millisecond) => {
+                    filter_primitive_item_list_array!(self, array, DurationMillisecondType, ListArray, ListBuilder)
+ }
+ DataType::Duration(TimeUnit::Microsecond) => {
+                    filter_primitive_item_list_array!(self, array, DurationMicrosecondType, ListArray, ListBuilder)
+ }
+ DataType::Duration(TimeUnit::Nanosecond) => {
+                    filter_primitive_item_list_array!(self, array, DurationNanosecondType, ListArray, ListBuilder)
+ }
+ DataType::Timestamp(TimeUnit::Second, _) => {
+                    filter_primitive_item_list_array!(self, array, TimestampSecondType, ListArray, ListBuilder)
+ }
+ DataType::Timestamp(TimeUnit::Millisecond, _) => {
+                    filter_primitive_item_list_array!(self, array, TimestampMillisecondType, ListArray, ListBuilder)
+ }
+ DataType::Timestamp(TimeUnit::Microsecond, _) => {
+                    filter_primitive_item_list_array!(self, array, TimestampMicrosecondType, ListArray, ListBuilder)
+ }
+ DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+                    filter_primitive_item_list_array!(self, array, TimestampNanosecondType, ListArray, ListBuilder)
+ }
+ DataType::Binary => filter_non_primitive_item_list_array!(
+ self,
+ array,
+ BinaryArray,
+ BinaryBuilder,
+ ListArray,
+ ListBuilder
+ ),
+ DataType::LargeBinary => filter_non_primitive_item_list_array!(
+ self,
+ array,
+ LargeBinaryArray,
+ LargeBinaryBuilder,
+ ListArray,
+ ListBuilder
+ ),
+ DataType::Utf8 => filter_non_primitive_item_list_array!(
+ self,
+ array,
+ StringArray,
+ StringBuilder,
+ ListArray
+ ,ListBuilder
+ ),
+ DataType::LargeUtf8 => filter_non_primitive_item_list_array!(
+ self,
+ array,
+ LargeStringArray,
+ LargeStringBuilder,
+ ListArray,
+ ListBuilder
+ ),
+ other => {
+ Err(ArrowError::ComputeError(format!(
+ "filter not supported for List({:?})",
+ other
+ )))
+ }
+ }
+ DataType::LargeList(dt) => match &**dt {
+ DataType::UInt8 => {
+                    filter_primitive_item_list_array!(self, array, UInt8Type, LargeListArray, LargeListBuilder)
+ }
+ DataType::UInt16 => {
+                    filter_primitive_item_list_array!(self, array, UInt16Type, LargeListArray, LargeListBuilder)
+ }
+ DataType::UInt32 => {
+                    filter_primitive_item_list_array!(self, array, UInt32Type, LargeListArray, LargeListBuilder)
+ }
+ DataType::UInt64 => {
+                    filter_primitive_item_list_array!(self, array, UInt64Type, LargeListArray, LargeListBuilder)
+ }
+                DataType::Int8 => filter_primitive_item_list_array!(self, array, Int8Type, LargeListArray, LargeListBuilder),
+ DataType::Int16 => {
+                    filter_primitive_item_list_array!(self, array, Int16Type, LargeListArray, LargeListBuilder)
+ }
+ DataType::Int32 => {
+                    filter_primitive_item_list_array!(self, array, Int32Type, LargeListArray, LargeListBuilder)
+ }
+ DataType::Int64 => {
+                    filter_primitive_item_list_array!(self, array, Int64Type, LargeListArray, LargeListBuilder)
+ }
+ DataType::Float32 => {
+                    filter_primitive_item_list_array!(self, array, Float32Type, LargeListArray, LargeListBuilder)
+ }
+ DataType::Float64 => {
+                    filter_primitive_item_list_array!(self, array, Float64Type, LargeListArray, LargeListBuilder)
+ }
+ DataType::Boolean => {
+                    filter_primitive_item_list_array!(self, array, BooleanType, LargeListArray, LargeListBuilder)
+ }
+ DataType::Date32(_) => {
+                    filter_primitive_item_list_array!(self, array, Date32Type, LargeListArray, LargeListBuilder)
+ }
+ DataType::Date64(_) => {
+                    filter_primitive_item_list_array!(self, array, Date64Type, LargeListArray, LargeListBuilder)
+ }
+ DataType::Time32(TimeUnit::Second) => {
+                    filter_primitive_item_list_array!(self, array, Time32SecondType, LargeListArray, LargeListBuilder)
+ }
+ DataType::Time32(TimeUnit::Millisecond) => {
+                    filter_primitive_item_list_array!(self, array, Time32MillisecondType, LargeListArray, LargeListBuilder)
+ }
+ DataType::Time64(TimeUnit::Microsecond) => {
+                    filter_primitive_item_list_array!(self, array, Time64MicrosecondType, LargeListArray, LargeListBuilder)
+ }
+ DataType::Time64(TimeUnit::Nanosecond) => {
+                    filter_primitive_item_list_array!(self, array, Time64NanosecondType, LargeListArray, LargeListBuilder)
+ }
+ DataType::Duration(TimeUnit::Second) => {
+                    filter_primitive_item_list_array!(self, array, DurationSecondType, LargeListArray, LargeListBuilder)
+ }
+ DataType::Duration(TimeUnit::Millisecond) => {
+                    filter_primitive_item_list_array!(self, array, DurationMillisecondType, LargeListArray, LargeListBuilder)
+ }
+ DataType::Duration(TimeUnit::Microsecond) => {
+                    filter_primitive_item_list_array!(self, array, DurationMicrosecondType, LargeListArray, LargeListBuilder)
+ }
+ DataType::Duration(TimeUnit::Nanosecond) => {
+                    filter_primitive_item_list_array!(self, array, DurationNanosecondType, LargeListArray, LargeListBuilder)
+ }
+ DataType::Timestamp(TimeUnit::Second, _) => {
+                    filter_primitive_item_list_array!(self, array, TimestampSecondType, LargeListArray, LargeListBuilder)
+ }
+ DataType::Timestamp(TimeUnit::Millisecond, _) => {
+                    filter_primitive_item_list_array!(self, array, TimestampMillisecondType, LargeListArray, LargeListBuilder)
+ }
+ DataType::Timestamp(TimeUnit::Microsecond, _) => {
+                    filter_primitive_item_list_array!(self, array, TimestampMicrosecondType, LargeListArray, LargeListBuilder)
+ }
+ DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+                    filter_primitive_item_list_array!(self, array, TimestampNanosecondType, LargeListArray, LargeListBuilder)
+ }
+ DataType::Binary => filter_non_primitive_item_list_array!(
+ self,
+ array,
+ BinaryArray,
+ BinaryBuilder,
+ LargeListArray,
+ LargeListBuilder
+ ),
+ DataType::LargeBinary => filter_non_primitive_item_list_array!(
+ self,
+ array,
+ LargeBinaryArray,
+ LargeBinaryBuilder,
+ LargeListArray,
+ LargeListBuilder
+ ),
+ DataType::Utf8 => filter_non_primitive_item_list_array!(
+ self,
+ array,
+ StringArray,
+ StringBuilder,
+ LargeListArray,
+ LargeListBuilder
+ ),
+ DataType::LargeUtf8 => filter_non_primitive_item_list_array!(
+ self,
+ array,
+ LargeStringArray,
+ LargeStringBuilder,
+ LargeListArray,
+ LargeListBuilder
+ ),
+ other => {
+ Err(ArrowError::ComputeError(format!(
+ "filter not supported for LargeList({:?})",
+ other
+ )))
+ }
+ }
other => Err(ArrowError::ComputeError(format!(
"filter not supported for {:?}",
other
@@ -500,6 +807,8 @@ pub fn filter_record_batch(
#[cfg(test)]
mod tests {
use super::*;
+ use crate::buffer::Buffer;
+ use crate::datatypes::ToByteSlice;
macro_rules! def_temporal_test {
($test:ident, $array_type: ident, $data: expr) => {
@@ -767,4 +1076,57 @@ mod tests {
assert_eq!("hello", d.value(0));
assert_eq!("world", d.value(1));
}
+
+ #[test]
+ fn test_filter_list_array() {
+ let value_data = ArrayData::builder(DataType::Int32)
+ .len(8)
+            .add_buffer(Buffer::from(&[0, 1, 2, 3, 4, 5, 6, 7].to_byte_slice()))
+ .build();
+
+ let value_offsets = Buffer::from(&[0i64, 3, 6, 8, 8].to_byte_slice());
+
+ let list_data_type = DataType::LargeList(Box::new(DataType::Int32));
+ let list_data = ArrayData::builder(list_data_type)
+ .len(4)
+ .add_buffer(value_offsets)
+ .add_child_data(value_data)
+ .null_bit_buffer(Buffer::from([0b00000111]))
+ .build();
+
+ // a = [[0, 1, 2], [3, 4, 5], [6, 7], null]
+ let a = LargeListArray::from(list_data);
+ let b = BooleanArray::from(vec![false, true, false, true]);
+ let c = filter(&a, &b).unwrap();
+ let d = c
+ .as_ref()
+ .as_any()
+ .downcast_ref::<LargeListArray>()
+ .unwrap();
+
+ assert_eq!(DataType::Int32, d.value_type());
+
+ // result should be [[3, 4, 5], null]
+ assert_eq!(2, d.len());
+ assert_eq!(1, d.null_count());
+ assert_eq!(true, d.is_null(1));
+
+ assert_eq!(0, d.value_offset(0));
+ assert_eq!(3, d.value_length(0));
+ assert_eq!(3, d.value_offset(1));
+ assert_eq!(0, d.value_length(1));
+ assert_eq!(
+ Buffer::from(&[3, 4, 5].to_byte_slice()),
+ d.values().data().buffers()[0].clone()
+ );
+ assert_eq!(
+ Buffer::from(&[0i64, 3, 3].to_byte_slice()),
+ d.data().buffers()[0].clone()
+ );
+ let inner_list = d.value(0);
+        let inner_list = inner_list.as_any().downcast_ref::<Int32Array>().unwrap();
+ assert_eq!(3, inner_list.len());
+ assert_eq!(0, inner_list.null_count());
+ assert_eq!(inner_list, &Int32Array::from(vec![3, 4, 5]));
+ }
}