This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 61602f00c Add `InList` support for binary type. (#3324)
61602f00c is described below
commit 61602f00caaa6935b53f95dc22dc69d82f9b2f52
Author: Remzi Yang <[email protected]>
AuthorDate: Sun Sep 4 19:47:15 2022 +0800
Add `InList` support for binary type. (#3324)
* add binary support
Signed-off-by: remzi <[email protected]>
* support inset and add tests
Signed-off-by: remzi <[email protected]>
* clean
Signed-off-by: remzi <[email protected]>
* fmt
Signed-off-by: remzi <[email protected]>
Signed-off-by: remzi <[email protected]>
---
.../physical-expr/src/expressions/in_list.rs | 259 +++++++++++++++------
1 file changed, 184 insertions(+), 75 deletions(-)
diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs
index a391bf51d..aed25e580 100644
--- a/datafusion/physical-expr/src/expressions/in_list.rs
+++ b/datafusion/physical-expr/src/expressions/in_list.rs
@@ -34,11 +34,10 @@ use arrow::{
use crate::PhysicalExpr;
use arrow::array::*;
-use arrow::buffer::{Buffer, MutableBuffer};
use datafusion_common::ScalarValue;
use datafusion_common::ScalarValue::{
- Boolean, Decimal128, Int16, Int32, Int64, Int8, LargeUtf8, UInt16, UInt32, UInt64,
- UInt8, Utf8,
+ Binary, Boolean, Decimal128, Int16, Int32, Int64, Int8, LargeBinary, LargeUtf8,
+ UInt16, UInt32, UInt64, UInt8, Utf8,
};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
@@ -49,30 +48,6 @@ use datafusion_expr::ColumnarValue;
/// TODO: add switch codeGen in In_List
static OPTIMIZER_INSET_THRESHOLD: usize = 30;
-macro_rules! compare_op_scalar {
- ($left: expr, $right:expr, $op:expr) => {{
- let null_bit_buffer = $left.data().null_buffer().cloned();
-
- let comparison =
- (0..$left.len()).map(|i| unsafe { $op($left.value_unchecked(i), $right) });
- // same as $left.len()
- let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) };
-
- let data = unsafe {
- ArrayData::new_unchecked(
- DataType::Boolean,
- $left.len(),
- None,
- null_bit_buffer,
- 0,
- vec![Buffer::from(buffer)],
- vec![],
- )
- };
- Ok(BooleanArray::from(data))
- }};
-}
-
/// InList
#[derive(Debug)]
pub struct InListExpr {
@@ -293,21 +268,6 @@ macro_rules! collection_contains_check_decimal {
}};
}
-// whether each value on the left (can be null) is contained in the non-null list
-fn in_list_utf8<OffsetSize: OffsetSizeTrait>(
- array: &GenericStringArray<OffsetSize>,
- values: &[&str],
-) -> Result<BooleanArray> {
- compare_op_scalar!(array, values, |x, v: &[&str]| v.contains(&x))
-}
-
-fn not_in_list_utf8<OffsetSize: OffsetSizeTrait>(
- array: &GenericStringArray<OffsetSize>,
- values: &[&str],
-) -> Result<BooleanArray> {
- compare_op_scalar!(array, values, |x, v: &[&str]| !v.contains(&x))
-}
-
// try evaluate all list exprs and check if the exprs are constants or not
fn try_cast_static_filter_to_set(
list: &[Arc<dyn PhysicalExpr>],
@@ -386,8 +346,7 @@ fn set_contains_utf8<OffsetSize: OffsetSizeTrait>(
let native_array = set
.iter()
.flat_map(|v| match v {
- Utf8(v) => v.as_deref(),
- LargeUtf8(v) => v.as_deref(),
+ Utf8(v) | LargeUtf8(v) => v.as_deref(),
datatype => {
unreachable!("InList can't reach other data type {} for {}.", datatype, v)
}
@@ -398,6 +357,26 @@ fn set_contains_utf8<OffsetSize: OffsetSizeTrait>(
collection_contains_check!(array, native_set, negated, contains_null)
}
+fn set_contains_binary<OffsetSize: OffsetSizeTrait>(
+ array: &GenericBinaryArray<OffsetSize>,
+ set: &HashSet<ScalarValue>,
+ negated: bool,
+) -> ColumnarValue {
+ let contains_null = set.iter().any(|v| v.is_null());
+ let native_array = set
+ .iter()
+ .flat_map(|v| match v {
+ Binary(v) | LargeBinary(v) => v.as_deref(),
+ datatype => {
+ unreachable!("InList can't reach other data type {} for {}.", datatype, v)
+ }
+ })
+ .collect::<Vec<_>>();
+ let native_set: HashSet<&[u8]> = HashSet::from_iter(native_array);
+
+ collection_contains_check!(array, native_set, negated, contains_null)
+}
+
impl InListExpr {
/// Create a new InList expression
pub fn new(
@@ -471,37 +450,50 @@ impl InListExpr {
})
.collect::<Vec<&str>>();
- if negated {
- if contains_null {
- Ok(ColumnarValue::Array(Arc::new(
- array
- .iter()
- .map(|x| match x.map(|v| !values.contains(&v)) {
- Some(true) => None,
- x => x,
- })
- .collect::<BooleanArray>(),
- )))
- } else {
- Ok(ColumnarValue::Array(Arc::new(not_in_list_utf8(
- array, &values,
- )?)))
- }
- } else if contains_null {
- Ok(ColumnarValue::Array(Arc::new(
- array
- .iter()
- .map(|x| match x.map(|v| values.contains(&v)) {
- Some(false) => None,
- x => x,
- })
- .collect::<BooleanArray>(),
- )))
- } else {
- Ok(ColumnarValue::Array(Arc::new(in_list_utf8(
- array, &values,
- )?)))
- }
+ Ok(collection_contains_check!(
+ array,
+ values,
+ negated,
+ contains_null
+ ))
+ }
+
+ fn compare_binary<T: OffsetSizeTrait>(
+ &self,
+ array: ArrayRef,
+ list_values: Vec<ColumnarValue>,
+ negated: bool,
+ ) -> Result<ColumnarValue> {
+ let array = array
+ .as_any()
+ .downcast_ref::<GenericBinaryArray<T>>()
+ .unwrap();
+
+ let contains_null = list_values
+ .iter()
+ .any(|v| matches!(v, ColumnarValue::Scalar(s) if s.is_null()));
+ let values = list_values
+ .iter()
+ .flat_map(|expr| match expr {
+ ColumnarValue::Scalar(s) => match s {
+ ScalarValue::Binary(Some(v)) | ScalarValue::LargeBinary(Some(v)) => {
+ Some(v.as_slice())
+ }
+ ScalarValue::Binary(None) | ScalarValue::LargeBinary(None) => None,
+ datatype => unimplemented!("Unexpected type {} for InList", datatype),
+ },
+ ColumnarValue::Array(_) => {
+ unimplemented!("InList does not yet support nested columns.")
+ }
+ })
+ .collect::<Vec<&[u8]>>();
+
+ Ok(collection_contains_check!(
+ array,
+ values,
+ negated,
+ contains_null
+ ))
}
}
@@ -670,6 +662,20 @@ impl PhysicalExpr for InListExpr {
.unwrap();
Ok(set_contains_utf8(array, set, self.negated))
}
+ DataType::Binary => {
+ let array = array
+ .as_any()
+ .downcast_ref::<GenericBinaryArray<i32>>()
+ .unwrap();
+ Ok(set_contains_binary(array, set, self.negated))
+ }
+ DataType::LargeBinary => {
+ let array = array
+ .as_any()
+ .downcast_ref::<GenericBinaryArray<i64>>()
+ .unwrap();
+ Ok(set_contains_binary(array, set, self.negated))
+ }
DataType::Decimal128(_, _) => {
let array =
array.as_any().downcast_ref::<Decimal128Array>().unwrap();
Ok(make_set_contains_decimal(array, set, self.negated))
@@ -795,6 +801,12 @@ impl PhysicalExpr for InListExpr {
DataType::LargeUtf8 => {
self.compare_utf8::<i64>(array, list_values, self.negated)
}
+ DataType::Binary => {
+ self.compare_binary::<i32>(array, list_values, self.negated)
+ }
+ DataType::LargeBinary => {
+ self.compare_binary::<i64>(array, list_values, self.negated)
+ }
DataType::Null => {
let null_array = new_null_array(&DataType::Boolean, array.len());
Ok(ColumnarValue::Array(Arc::new(null_array)))
@@ -906,6 +918,66 @@ mod tests {
Ok(())
}
+ #[test]
+ fn in_list_binary() -> Result<()> {
+ let schema = Schema::new(vec![Field::new("a", DataType::Binary, true)]);
+ let a = BinaryArray::from(vec![
+ Some([1, 2, 3].as_slice()),
+ Some([1, 2, 2].as_slice()),
+ None,
+ ]);
+ let col_a = col("a", &schema)?;
+ let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
+
+ // expression: "a in ([1, 2, 3], [4, 5, 6])"
+ let list = vec![lit([1, 2, 3].as_slice()), lit([4, 5, 6].as_slice())];
+ in_list!(
+ batch,
+ list.clone(),
+ &false,
+ vec![Some(true), Some(false), None],
+ col_a.clone(),
+ &schema
+ );
+
+ // expression: "a not in ([1, 2, 3], [4, 5, 6])"
+ in_list!(
+ batch,
+ list,
+ &true,
+ vec![Some(false), Some(true), None],
+ col_a.clone(),
+ &schema
+ );
+
+ // expression: "a in ([1, 2, 3], [4, 5, 6], null)"
+ let list = vec![
+ lit([1, 2, 3].as_slice()),
+ lit([4, 5, 6].as_slice()),
+ lit(ScalarValue::Binary(None)),
+ ];
+ in_list!(
+ batch,
+ list.clone(),
+ &false,
+ vec![Some(true), None, None],
+ col_a.clone(),
+ &schema
+ );
+
+ // expression: "a in ([1, 2, 3], [4, 5, 6], null)"
+ in_list!(
+ batch,
+ list,
+ &true,
+ vec![Some(false), None, None],
+ col_a.clone(),
+ &schema
+ );
+
+ Ok(())
+ }
+
#[test]
fn in_list_int64() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
@@ -1316,6 +1388,43 @@ mod tests {
Ok(())
}
+ #[test]
+ fn in_list_set_binary() -> Result<()> {
+ let schema = Schema::new(vec![Field::new("a", DataType::Binary, true)]);
+ let a = BinaryArray::from(vec![
+ Some([1, 2, 3].as_slice()),
+ Some([3, 2, 1].as_slice()),
+ None,
+ ]);
+ let col_a = col("a", &schema)?;
+ let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
+
+ let mut list = vec![lit([1, 2, 3].as_slice()), lit(ScalarValue::Binary(None))];
+ for v in 0..OPTIMIZER_INSET_THRESHOLD {
+ list.push(lit([v as u8].as_slice()));
+ }
+
+ in_list!(
+ batch,
+ list.clone(),
+ &false,
+ vec![Some(true), None, None],
+ col_a.clone(),
+ &schema
+ );
+
+ in_list!(
+ batch,
+ list.clone(),
+ &true,
+ vec![Some(false), None, None],
+ col_a.clone(),
+ &schema
+ );
+
+ Ok(())
+ }
+
#[test]
fn in_list_set_decimal() -> Result<()> {
let schema =