This is an automated email from the ASF dual-hosted git repository.

jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new ff116c3da6 Support filter for List (#11091)
ff116c3da6 is described below

commit ff116c3da69897358f210a3ea944c8e51dcb7b52
Author: Jay Zhan <[email protected]>
AuthorDate: Thu Jun 27 16:40:16 2024 +0800

    Support filter for List (#11091)
    
    * support basic list cmp
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add more ops
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add distinct
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * nested
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add comment
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
---
 datafusion/physical-expr-common/src/datum.rs       | 180 +++++++++++++++++++++
 datafusion/physical-expr-common/src/lib.rs         |   1 +
 datafusion/physical-expr/src/expressions/binary.rs |   9 +-
 datafusion/physical-expr/src/expressions/datum.rs  |  58 -------
 datafusion/physical-expr/src/expressions/like.rs   |   2 +-
 datafusion/physical-expr/src/expressions/mod.rs    |   1 -
 datafusion/sqllogictest/test_files/array_query.slt | 128 ++++++++++++++-
 7 files changed, 312 insertions(+), 67 deletions(-)

diff --git a/datafusion/physical-expr-common/src/datum.rs b/datafusion/physical-expr-common/src/datum.rs
new file mode 100644
index 0000000000..f4ce0eebc0
--- /dev/null
+++ b/datafusion/physical-expr-common/src/datum.rs
@@ -0,0 +1,180 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::BooleanArray;
+use arrow::array::{make_comparator, ArrayRef, Datum};
+use arrow::buffer::NullBuffer;
+use arrow::compute::SortOptions;
+use arrow::error::ArrowError;
+use datafusion_common::internal_err;
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::{ColumnarValue, Operator};
+use std::sync::Arc;
+
+/// Runs the binary [`Datum`] kernel `f` over `lhs` and `rhs`.
+///
+/// Bridges arrow-rs' [`Datum`] kernels and DataFusion's [`ColumnarValue`] abstraction:
+/// the result is an array if either input is an array, and a scalar if both are scalars.
+pub fn apply(
+    lhs: &ColumnarValue,
+    rhs: &ColumnarValue,
+    f: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
+) -> Result<ColumnarValue> {
+    let array = match (lhs, rhs) {
+        (ColumnarValue::Array(l), ColumnarValue::Array(r)) => {
+            f(&l.as_ref(), &r.as_ref())?
+        }
+        (ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => {
+            f(&l.to_scalar()?, &r.as_ref())?
+        }
+        (ColumnarValue::Array(l), ColumnarValue::Scalar(r)) => {
+            f(&l.as_ref(), &r.to_scalar()?)?
+        }
+        (ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) => {
+            // Both inputs are scalars: turn the single-row kernel output back into a scalar.
+            let single_row = f(&l.to_scalar()?, &r.to_scalar()?)?;
+            return ScalarValue::try_from_array(single_row.as_ref(), 0)
+                .map(ColumnarValue::Scalar);
+        }
+    };
+    Ok(ColumnarValue::Array(array))
+}
+
+/// Runs the binary [`Datum`] comparison kernel `f` over `lhs` and `rhs`.
+///
+/// A thin wrapper over [`apply`] that wraps the kernel's [`BooleanArray`] in an [`ArrayRef`].
+pub fn apply_cmp(
+    lhs: &ColumnarValue,
+    rhs: &ColumnarValue,
+    f: impl Fn(&dyn Datum, &dyn Datum) -> Result<BooleanArray, ArrowError>,
+) -> Result<ColumnarValue> {
+    apply(lhs, rhs, |l, r| f(l, r).map(|mask| Arc::new(mask) as ArrayRef))
+}
+
+/// Applies the comparison `op` to `lhs` and `rhs` when they are of a nested type such as
+/// List, FixedSizeList, LargeList, Struct, Union, Map, or a dictionary of a nested type.
+///
+/// Only `=`, `!=`, `<`, `>`, `<=`, `>=`, `IS DISTINCT FROM` and `IS NOT DISTINCT FROM`
+/// are accepted; any other operator is an internal error. Whether a particular nested
+/// type can actually be compared is decided by arrow's comparator, whose error is
+/// propagated.
+pub fn apply_cmp_for_nested(
+    op: Operator,
+    lhs: &ColumnarValue,
+    rhs: &ColumnarValue,
+) -> Result<ColumnarValue> {
+    let supported = matches!(
+        op,
+        Operator::Eq
+            | Operator::NotEq
+            | Operator::Lt
+            | Operator::Gt
+            | Operator::LtEq
+            | Operator::GtEq
+            | Operator::IsDistinctFrom
+            | Operator::IsNotDistinctFrom
+    );
+    if !supported {
+        return internal_err!("invalid operator for nested");
+    }
+
+    apply(lhs, rhs, |l, r| {
+        let mask = compare_op_for_nested(op, l, r)?;
+        Ok(Arc::new(mask))
+    })
+}
+
+/// Compare on nested type List, Struct, and so on
+fn compare_op_for_nested(
+    op: Operator,
+    lhs: &dyn Datum,
+    rhs: &dyn Datum,
+) -> Result<BooleanArray> {
+    let (l, is_l_scalar) = lhs.get();
+    let (r, is_r_scalar) = rhs.get();
+    let l_len = l.len();
+    let r_len = r.len();
+
+    if l_len != r_len && !is_l_scalar && !is_r_scalar {
+        return internal_err!("len mismatch");
+    }
+
+    let len = match is_l_scalar {
+        true => r_len,
+        false => l_len,
+    };
+
+    // fast path, if compare with one null and operator is not 'distinct', 
then we can return null array directly
+    if !matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom)
+        && (is_l_scalar && l.null_count() == 1 || is_r_scalar && 
r.null_count() == 1)
+    {
+        return Ok(BooleanArray::new_null(len));
+    }
+
+    // TODO: make SortOptions configurable
+    // we choose the default behaviour from arrow-rs which has null-first that 
follow spark's behaviour
+    let cmp = make_comparator(l, r, SortOptions::default())?;
+
+    let cmp_with_op = |i, j| match op {
+        Operator::Eq | Operator::IsNotDistinctFrom => cmp(i, j).is_eq(),
+        Operator::Lt => cmp(i, j).is_lt(),
+        Operator::Gt => cmp(i, j).is_gt(),
+        Operator::LtEq => !cmp(i, j).is_gt(),
+        Operator::GtEq => !cmp(i, j).is_lt(),
+        Operator::NotEq | Operator::IsDistinctFrom => !cmp(i, j).is_eq(),
+        _ => unreachable!("unexpected operator found"),
+    };
+
+    let values = match (is_l_scalar, is_r_scalar) {
+        (false, false) => (0..len).map(|i| cmp_with_op(i, i)).collect(),
+        (true, false) => (0..len).map(|i| cmp_with_op(0, i)).collect(),
+        (false, true) => (0..len).map(|i| cmp_with_op(i, 0)).collect(),
+        (true, true) => std::iter::once(cmp_with_op(0, 0)).collect(),
+    };
+
+    // Distinct understand how to compare with NULL
+    // i.e NULL is distinct from NULL -> false
+    if matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom) {
+        Ok(BooleanArray::new(values, None))
+    } else {
+        // If one of the side is NULL, we returns NULL
+        // i.e. NULL eq NULL -> NULL
+        let nulls = NullBuffer::union(l.nulls(), r.nulls());
+        Ok(BooleanArray::new(values, nulls))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use arrow::{
+        array::{make_comparator, Array, BooleanArray, ListArray},
+        buffer::NullBuffer,
+        compute::SortOptions,
+        datatypes::Int32Type,
+    };
+
+    #[test]
+    fn test123() {
+        let data = vec![
+            Some(vec![Some(0), Some(1), Some(2)]),
+            None,
+            Some(vec![Some(3), None, Some(5)]),
+            Some(vec![Some(6), Some(7)]),
+        ];
+        let a = ListArray::from_iter_primitive::<Int32Type, _, _>(data);
+        let data = vec![
+            Some(vec![Some(0), Some(1), Some(2)]),
+            None,
+            Some(vec![Some(3), None, Some(5)]),
+            Some(vec![Some(6), Some(7)]),
+        ];
+        let b = ListArray::from_iter_primitive::<Int32Type, _, _>(data);
+        let cmp = make_comparator(&a, &b, SortOptions::default()).unwrap();
+        let len = a.len().min(b.len());
+        let values = (0..len).map(|i| cmp(i, i).is_eq()).collect();
+        let nulls = NullBuffer::union(a.nulls(), b.nulls());
+        println!("res: {:?}", BooleanArray::new(values, nulls));
+    }
+}
diff --git a/datafusion/physical-expr-common/src/lib.rs b/datafusion/physical-expr-common/src/lib.rs
index 0ddb84141a..8d50e0b964 100644
--- a/datafusion/physical-expr-common/src/lib.rs
+++ b/datafusion/physical-expr-common/src/lib.rs
@@ -17,6 +17,7 @@
 
 pub mod aggregate;
 pub mod binary_map;
+pub mod datum;
 pub mod expressions;
 pub mod physical_expr;
 pub mod sort_expr;
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index 98df0cba9f..3a8f7ee56a 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -20,7 +20,6 @@ mod kernels;
 use std::hash::{Hash, Hasher};
 use std::{any::Any, sync::Arc};
 
-use crate::expressions::datum::{apply, apply_cmp};
 use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
 use crate::physical_expr::down_cast_any_ref;
 use crate::PhysicalExpr;
@@ -40,6 +39,7 @@ use datafusion_expr::interval_arithmetic::{apply_operator, Interval};
 use datafusion_expr::sort_properties::ExprProperties;
 use datafusion_expr::type_coercion::binary::get_result_type;
 use datafusion_expr::{ColumnarValue, Operator};
+use datafusion_physical_expr_common::datum::{apply, apply_cmp, apply_cmp_for_nested};
 
 use kernels::{
     bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar,
@@ -265,6 +265,13 @@ impl PhysicalExpr for BinaryExpr {
         let schema = batch.schema();
         let input_schema = schema.as_ref();
 
+        if left_data_type.is_nested() {
+            if right_data_type != left_data_type {
+                return internal_err!("type mismatch");
+            }
+            return apply_cmp_for_nested(self.op, &lhs, &rhs);
+        }
+
         match self.op {
             Operator::Plus => return apply(&lhs, &rhs, add_wrapping),
             Operator::Minus => return apply(&lhs, &rhs, sub_wrapping),
diff --git a/datafusion/physical-expr/src/expressions/datum.rs b/datafusion/physical-expr/src/expressions/datum.rs
deleted file mode 100644
index 2bb79922cf..0000000000
--- a/datafusion/physical-expr/src/expressions/datum.rs
+++ /dev/null
@@ -1,58 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use arrow::array::{ArrayRef, Datum};
-use arrow::error::ArrowError;
-use arrow_array::BooleanArray;
-use datafusion_common::{Result, ScalarValue};
-use datafusion_expr::ColumnarValue;
-use std::sync::Arc;
-
-/// Applies a binary [`Datum`] kernel `f` to `lhs` and `rhs`
-///
-/// This maps arrow-rs' [`Datum`] kernels to DataFusion's [`ColumnarValue`] abstraction
-pub(crate) fn apply(
-    lhs: &ColumnarValue,
-    rhs: &ColumnarValue,
-    f: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
-) -> Result<ColumnarValue> {
-    match (&lhs, &rhs) {
-        (ColumnarValue::Array(left), ColumnarValue::Array(right)) => {
-            Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())?))
-        }
-        (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok(
-            ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?),
-        ),
-        (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok(
-            ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()?)?),
-        ),
-        (ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => {
-            let array = f(&left.to_scalar()?, &right.to_scalar()?)?;
-            let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?;
-            Ok(ColumnarValue::Scalar(scalar))
-        }
-    }
-}
-
-/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs`
-pub(crate) fn apply_cmp(
-    lhs: &ColumnarValue,
-    rhs: &ColumnarValue,
-    f: impl Fn(&dyn Datum, &dyn Datum) -> Result<BooleanArray, ArrowError>,
-) -> Result<ColumnarValue> {
-    apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?)))
-}
diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs
index d18651c641..e0c02b0a90 100644
--- a/datafusion/physical-expr/src/expressions/like.rs
+++ b/datafusion/physical-expr/src/expressions/like.rs
@@ -20,11 +20,11 @@ use std::{any::Any, sync::Arc};
 
 use crate::{physical_expr::down_cast_any_ref, PhysicalExpr};
 
-use crate::expressions::datum::apply_cmp;
 use arrow::record_batch::RecordBatch;
 use arrow_schema::{DataType, Schema};
 use datafusion_common::{internal_err, Result};
 use datafusion_expr::ColumnarValue;
+use datafusion_physical_expr_common::datum::apply_cmp;
 
 // Like expression
 #[derive(Debug, Hash)]
diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs
index c98bcc56ad..608609b81d 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -21,7 +21,6 @@
 mod binary;
 mod case;
 mod column;
-mod datum;
 mod in_list;
 mod is_not_null;
 mod is_null;
diff --git a/datafusion/sqllogictest/test_files/array_query.slt b/datafusion/sqllogictest/test_files/array_query.slt
index 24c99fc849..b29b5f5efd 100644
--- a/datafusion/sqllogictest/test_files/array_query.slt
+++ b/datafusion/sqllogictest/test_files/array_query.slt
@@ -41,17 +41,68 @@ SELECT * FROM data;
 # Filtering
 ###########
 
-query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) == List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\)
+query ??I rowsort
 SELECT * FROM data WHERE column1 = [1,2,3];
+----
+[1, 2, 3] NULL 1
+[1, 2, 3] [4, 5] 1
 
-query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) == List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\)
-SELECT * FROM data WHERE column1 = column2
-
-query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) != List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\)
+query ??I
 SELECT * FROM data WHERE column1 != [1,2,3];
+----
+[2, 3] [2, 3] 1
 
-query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) != List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\)
+query ??I
 SELECT * FROM data WHERE column1 != column2
+----
+[1, 2, 3] [4, 5] 1
+
+query ??I rowsort
+SELECT * FROM data WHERE column1 < [1,2,3,4];
+----
+[1, 2, 3] NULL 1
+[1, 2, 3] [4, 5] 1
+
+query ??I rowsort
+SELECT * FROM data WHERE column1 <= [2, 3];
+----
+[1, 2, 3] NULL 1
+[1, 2, 3] [4, 5] 1
+[2, 3] [2, 3] 1
+
+query ??I rowsort
+SELECT * FROM data WHERE column1 > [1,2];
+----
+[1, 2, 3] NULL 1
+[1, 2, 3] [4, 5] 1
+[2, 3] [2, 3] 1
+
+query ??I rowsort
+SELECT * FROM data WHERE column1 >= [1, 2, 3];
+----
+[1, 2, 3] NULL 1
+[1, 2, 3] [4, 5] 1
+[2, 3] [2, 3] 1
+
+# test with scalar null
+query ??I
+SELECT * FROM data WHERE column2 = null;
+----
+
+query ??I
+SELECT * FROM data WHERE null = column2;
+----
+
+query ??I
+SELECT * FROM data WHERE column2 is distinct from null;
+----
+[2, 3] [2, 3] 1
+[1, 2, 3] [4, 5] 1
+
+query ??I
+SELECT * FROM data WHERE column2 is not distinct from null;
+----
+[1, 2, 3] NULL 1
 
 ###########
 # Aggregates
@@ -158,3 +209,68 @@ SELECT * FROM data ORDER BY column1, column3, column2;
 
 statement ok
 drop table data
+
+
+# test filter column with all nulls
+statement ok
+create table data (a int) as values (null), (null), (null);
+
+query I
+select * from data where a = null;
+----
+
+query I
+select * from data where a is not distinct from null;
+----
+NULL
+NULL
+NULL
+
+statement ok
+drop table data;
+
+statement ok
+create table data (a int[][], b int) as values ([[1,2,3]], 1), ([[2,3], [4,5]], 2), (null, 3);
+
+query ?I
+select * from data;
+----
+[[1, 2, 3]] 1
+[[2, 3], [4, 5]] 2
+NULL 3
+
+query ?I
+select * from data where a = [[1,2,3]];
+----
+[[1, 2, 3]] 1
+
+query ?I
+select * from data where a > [[1,2,3]];
+----
+[[2, 3], [4, 5]] 2
+
+query ?I
+select * from data where a > [[1,2]];
+----
+[[1, 2, 3]] 1
+[[2, 3], [4, 5]] 2
+
+query ?I
+select * from data where a < [[2, 3]];
+----
+[[1, 2, 3]] 1
+
+# compare with null with eq results in null
+query ?I
+select * from data where a = null;
+----
+
+query ?I
+select * from data where a != null;
+----
+
+# compare with null with distinct results in true/false
+query ?I
+select * from data where a is not distinct from null;
+----
+NULL 3


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to