This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 61602f00c Add `InList` support for binary type. (#3324)
61602f00c is described below

commit 61602f00caaa6935b53f95dc22dc69d82f9b2f52
Author: Remzi Yang <[email protected]>
AuthorDate: Sun Sep 4 19:47:15 2022 +0800

    Add `InList` support for binary type. (#3324)
    
    * add binary support
    
    Signed-off-by: remzi <[email protected]>
    
    * support inset and add tests
    
    Signed-off-by: remzi <[email protected]>
    
    * clean
    
    Signed-off-by: remzi <[email protected]>
    
    * fmt
    
    Signed-off-by: remzi <[email protected]>
    
    Signed-off-by: remzi <[email protected]>
---
 .../physical-expr/src/expressions/in_list.rs       | 259 +++++++++++++++------
 1 file changed, 184 insertions(+), 75 deletions(-)

diff --git a/datafusion/physical-expr/src/expressions/in_list.rs 
b/datafusion/physical-expr/src/expressions/in_list.rs
index a391bf51d..aed25e580 100644
--- a/datafusion/physical-expr/src/expressions/in_list.rs
+++ b/datafusion/physical-expr/src/expressions/in_list.rs
@@ -34,11 +34,10 @@ use arrow::{
 
 use crate::PhysicalExpr;
 use arrow::array::*;
-use arrow::buffer::{Buffer, MutableBuffer};
 use datafusion_common::ScalarValue;
 use datafusion_common::ScalarValue::{
-    Boolean, Decimal128, Int16, Int32, Int64, Int8, LargeUtf8, UInt16, UInt32, 
UInt64,
-    UInt8, Utf8,
+    Binary, Boolean, Decimal128, Int16, Int32, Int64, Int8, LargeBinary, 
LargeUtf8,
+    UInt16, UInt32, UInt64, UInt8, Utf8,
 };
 use datafusion_common::{DataFusionError, Result};
 use datafusion_expr::ColumnarValue;
@@ -49,30 +48,6 @@ use datafusion_expr::ColumnarValue;
 /// TODO: add switch codeGen in In_List
 static OPTIMIZER_INSET_THRESHOLD: usize = 30;
 
-macro_rules! compare_op_scalar {
-    ($left: expr, $right:expr, $op:expr) => {{
-        let null_bit_buffer = $left.data().null_buffer().cloned();
-
-        let comparison =
-            (0..$left.len()).map(|i| unsafe { $op($left.value_unchecked(i), 
$right) });
-        // same as $left.len()
-        let buffer = unsafe { 
MutableBuffer::from_trusted_len_iter_bool(comparison) };
-
-        let data = unsafe {
-            ArrayData::new_unchecked(
-                DataType::Boolean,
-                $left.len(),
-                None,
-                null_bit_buffer,
-                0,
-                vec![Buffer::from(buffer)],
-                vec![],
-            )
-        };
-        Ok(BooleanArray::from(data))
-    }};
-}
-
 /// InList
 #[derive(Debug)]
 pub struct InListExpr {
@@ -293,21 +268,6 @@ macro_rules! collection_contains_check_decimal {
     }};
 }
 
-// whether each value on the left (can be null) is contained in the non-null 
list
-fn in_list_utf8<OffsetSize: OffsetSizeTrait>(
-    array: &GenericStringArray<OffsetSize>,
-    values: &[&str],
-) -> Result<BooleanArray> {
-    compare_op_scalar!(array, values, |x, v: &[&str]| v.contains(&x))
-}
-
-fn not_in_list_utf8<OffsetSize: OffsetSizeTrait>(
-    array: &GenericStringArray<OffsetSize>,
-    values: &[&str],
-) -> Result<BooleanArray> {
-    compare_op_scalar!(array, values, |x, v: &[&str]| !v.contains(&x))
-}
-
 // try evaluate all list exprs and check if the exprs are constants or not
 fn try_cast_static_filter_to_set(
     list: &[Arc<dyn PhysicalExpr>],
@@ -386,8 +346,7 @@ fn set_contains_utf8<OffsetSize: OffsetSizeTrait>(
     let native_array = set
         .iter()
         .flat_map(|v| match v {
-            Utf8(v) => v.as_deref(),
-            LargeUtf8(v) => v.as_deref(),
+            Utf8(v) | LargeUtf8(v) => v.as_deref(),
             datatype => {
                 unreachable!("InList can't reach other data type {} for {}.", 
datatype, v)
             }
@@ -398,6 +357,26 @@ fn set_contains_utf8<OffsetSize: OffsetSizeTrait>(
     collection_contains_check!(array, native_set, negated, contains_null)
 }
 
+fn set_contains_binary<OffsetSize: OffsetSizeTrait>(
+    array: &GenericBinaryArray<OffsetSize>,
+    set: &HashSet<ScalarValue>,
+    negated: bool,
+) -> ColumnarValue {
+    let contains_null = set.iter().any(|v| v.is_null());
+    let native_array = set
+        .iter()
+        .flat_map(|v| match v {
+            Binary(v) | LargeBinary(v) => v.as_deref(),
+            datatype => {
+                unreachable!("InList can't reach other data type {} for {}.", 
datatype, v)
+            }
+        })
+        .collect::<Vec<_>>();
+    let native_set: HashSet<&[u8]> = HashSet::from_iter(native_array);
+
+    collection_contains_check!(array, native_set, negated, contains_null)
+}
+
 impl InListExpr {
     /// Create a new InList expression
     pub fn new(
@@ -471,37 +450,50 @@ impl InListExpr {
             })
             .collect::<Vec<&str>>();
 
-        if negated {
-            if contains_null {
-                Ok(ColumnarValue::Array(Arc::new(
-                    array
-                        .iter()
-                        .map(|x| match x.map(|v| !values.contains(&v)) {
-                            Some(true) => None,
-                            x => x,
-                        })
-                        .collect::<BooleanArray>(),
-                )))
-            } else {
-                Ok(ColumnarValue::Array(Arc::new(not_in_list_utf8(
-                    array, &values,
-                )?)))
-            }
-        } else if contains_null {
-            Ok(ColumnarValue::Array(Arc::new(
-                array
-                    .iter()
-                    .map(|x| match x.map(|v| values.contains(&v)) {
-                        Some(false) => None,
-                        x => x,
-                    })
-                    .collect::<BooleanArray>(),
-            )))
-        } else {
-            Ok(ColumnarValue::Array(Arc::new(in_list_utf8(
-                array, &values,
-            )?)))
-        }
+        Ok(collection_contains_check!(
+            array,
+            values,
+            negated,
+            contains_null
+        ))
+    }
+
+    fn compare_binary<T: OffsetSizeTrait>(
+        &self,
+        array: ArrayRef,
+        list_values: Vec<ColumnarValue>,
+        negated: bool,
+    ) -> Result<ColumnarValue> {
+        let array = array
+            .as_any()
+            .downcast_ref::<GenericBinaryArray<T>>()
+            .unwrap();
+
+        let contains_null = list_values
+            .iter()
+            .any(|v| matches!(v, ColumnarValue::Scalar(s) if s.is_null()));
+        let values = list_values
+            .iter()
+            .flat_map(|expr| match expr {
+                ColumnarValue::Scalar(s) => match s {
+                    ScalarValue::Binary(Some(v)) | 
ScalarValue::LargeBinary(Some(v)) => {
+                        Some(v.as_slice())
+                    }
+                    ScalarValue::Binary(None) | ScalarValue::LargeBinary(None) 
=> None,
+                    datatype => unimplemented!("Unexpected type {} for 
InList", datatype),
+                },
+                ColumnarValue::Array(_) => {
+                    unimplemented!("InList does not yet support nested 
columns.")
+                }
+            })
+            .collect::<Vec<&[u8]>>();
+
+        Ok(collection_contains_check!(
+            array,
+            values,
+            negated,
+            contains_null
+        ))
     }
 }
 
@@ -670,6 +662,20 @@ impl PhysicalExpr for InListExpr {
                         .unwrap();
                     Ok(set_contains_utf8(array, set, self.negated))
                 }
+                DataType::Binary => {
+                    let array = array
+                        .as_any()
+                        .downcast_ref::<GenericBinaryArray<i32>>()
+                        .unwrap();
+                    Ok(set_contains_binary(array, set, self.negated))
+                }
+                DataType::LargeBinary => {
+                    let array = array
+                        .as_any()
+                        .downcast_ref::<GenericBinaryArray<i64>>()
+                        .unwrap();
+                    Ok(set_contains_binary(array, set, self.negated))
+                }
                 DataType::Decimal128(_, _) => {
                     let array = 
array.as_any().downcast_ref::<Decimal128Array>().unwrap();
                     Ok(make_set_contains_decimal(array, set, self.negated))
@@ -795,6 +801,12 @@ impl PhysicalExpr for InListExpr {
                 DataType::LargeUtf8 => {
                     self.compare_utf8::<i64>(array, list_values, self.negated)
                 }
+                DataType::Binary => {
+                    self.compare_binary::<i32>(array, list_values, 
self.negated)
+                }
+                DataType::LargeBinary => {
+                    self.compare_binary::<i64>(array, list_values, 
self.negated)
+                }
                 DataType::Null => {
                     let null_array = new_null_array(&DataType::Boolean, 
array.len());
                     Ok(ColumnarValue::Array(Arc::new(null_array)))
@@ -906,6 +918,66 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn in_list_binary() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Binary, 
true)]);
+        let a = BinaryArray::from(vec![
+            Some([1, 2, 3].as_slice()),
+            Some([1, 2, 2].as_slice()),
+            None,
+        ]);
+        let col_a = col("a", &schema)?;
+        let batch = RecordBatch::try_new(Arc::new(schema.clone()), 
vec![Arc::new(a)])?;
+
+        // expression: "a in ([1, 2, 3], [4, 5, 6])"
+        let list = vec![lit([1, 2, 3].as_slice()), lit([4, 5, 6].as_slice())];
+        in_list!(
+            batch,
+            list.clone(),
+            &false,
+            vec![Some(true), Some(false), None],
+            col_a.clone(),
+            &schema
+        );
+
+        // expression: "a not in ([1, 2, 3], [4, 5, 6])"
+        in_list!(
+            batch,
+            list,
+            &true,
+            vec![Some(false), Some(true), None],
+            col_a.clone(),
+            &schema
+        );
+
+        // expression: "a in ([1, 2, 3], [4, 5, 6], null)"
+        let list = vec![
+            lit([1, 2, 3].as_slice()),
+            lit([4, 5, 6].as_slice()),
+            lit(ScalarValue::Binary(None)),
+        ];
+        in_list!(
+            batch,
+            list.clone(),
+            &false,
+            vec![Some(true), None, None],
+            col_a.clone(),
+            &schema
+        );
+
+        // expression: "a not in ([1, 2, 3], [4, 5, 6], null)"
+        in_list!(
+            batch,
+            list,
+            &true,
+            vec![Some(false), None, None],
+            col_a.clone(),
+            &schema
+        );
+
+        Ok(())
+    }
+
     #[test]
     fn in_list_int64() -> Result<()> {
         let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
@@ -1316,6 +1388,43 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn in_list_set_binary() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Binary, 
true)]);
+        let a = BinaryArray::from(vec![
+            Some([1, 2, 3].as_slice()),
+            Some([3, 2, 1].as_slice()),
+            None,
+        ]);
+        let col_a = col("a", &schema)?;
+        let batch = RecordBatch::try_new(Arc::new(schema.clone()), 
vec![Arc::new(a)])?;
+
+        let mut list = vec![lit([1, 2, 3].as_slice()), 
lit(ScalarValue::Binary(None))];
+        for v in 0..OPTIMIZER_INSET_THRESHOLD {
+            list.push(lit([v as u8].as_slice()));
+        }
+
+        in_list!(
+            batch,
+            list.clone(),
+            &false,
+            vec![Some(true), None, None],
+            col_a.clone(),
+            &schema
+        );
+
+        in_list!(
+            batch,
+            list.clone(),
+            &true,
+            vec![Some(false), None, None],
+            col_a.clone(),
+            &schema
+        );
+
+        Ok(())
+    }
+
     #[test]
     fn in_list_set_decimal() -> Result<()> {
         let schema =

Reply via email to