This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new fda561e803 fix: incorrect nullability calculation of InListExpr (#7496)
fda561e803 is described below

commit fda561e803d309018368696608afced4e771d17e
Author: Jonah Gao <[email protected]>
AuthorDate: Fri Sep 8 19:00:18 2023 +0800

    fix: incorrect nullability calculation of InListExpr (#7496)
---
 .../physical-expr/src/expressions/in_list.rs       | 61 +++++++++++++++++++++-
 datafusion/sqllogictest/test_files/scalar.slt      | 17 +++---
 2 files changed, 69 insertions(+), 9 deletions(-)

diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs
index c4b1369f6b..c92bbbb74f 100644
--- a/datafusion/physical-expr/src/expressions/in_list.rs
+++ b/datafusion/physical-expr/src/expressions/in_list.rs
@@ -66,6 +66,7 @@ impl Debug for InListExpr {
 /// A type-erased container of array elements
 pub trait Set: Send + Sync {
     fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray>;
+    fn has_nulls(&self) -> bool;
 }
 
 struct ArrayHashSet {
@@ -134,6 +135,10 @@ where
             })
             .collect())
     }
+
+    fn has_nulls(&self) -> bool {
+        self.array.null_count() != 0
+    }
 }
 
 /// Computes an [`ArrayHashSet`] for the provided [`Array`] if there
@@ -325,7 +330,20 @@ impl PhysicalExpr for InListExpr {
     }
 
     fn nullable(&self, input_schema: &Schema) -> Result<bool> {
-        self.expr.nullable(input_schema)
+        if self.expr.nullable(input_schema)? {
+            return Ok(true);
+        }
+
+        if let Some(static_filter) = &self.static_filter {
+            Ok(static_filter.has_nulls())
+        } else {
+            for expr in &self.list {
+                if expr.nullable(input_schema)? {
+                    return Ok(true);
+                }
+            }
+            Ok(false)
+        }
     }
 
     fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
@@ -1203,4 +1221,45 @@ mod tests {
 
         Ok(())
     }
+
+    macro_rules! test_nullable {
+        ($COL:expr, $LIST:expr, $SCHEMA:expr, $EXPECTED:expr) => {{
+            let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST, $SCHEMA)?;
+            let expr = in_list(cast_expr, cast_list_exprs, &false, $SCHEMA).unwrap();
+            let result = expr.nullable($SCHEMA)?;
+            assert_eq!($EXPECTED, result);
+        }};
+    }
+
+    #[test]
+    fn in_list_nullable() -> Result<()> {
+        let schema = Schema::new(vec![
+            Field::new("c1_nullable", DataType::Int64, true),
+            Field::new("c2_non_nullable", DataType::Int64, false),
+        ]);
+
+        let c1_nullable = col("c1_nullable", &schema)?;
+        let c2_non_nullable = col("c2_non_nullable", &schema)?;
+
+        // static_filter has no nulls
+        let list = vec![lit(1_i64), lit(2_i64)];
+        test_nullable!(c1_nullable.clone(), list.clone(), &schema, true);
+        test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, false);
+
+        // static_filter has nulls
+        let list = vec![lit(1_i64), lit(2_i64), lit(ScalarValue::Null)];
+        test_nullable!(c1_nullable.clone(), list.clone(), &schema, true);
+        test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, true);
+
+        let list = vec![c1_nullable.clone()];
+        test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, true);
+
+        let list = vec![c2_non_nullable.clone()];
+        test_nullable!(c1_nullable.clone(), list.clone(), &schema, true);
+
+        let list = vec![c2_non_nullable.clone(), c2_non_nullable.clone()];
+        test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, false);
+
+        Ok(())
+    }
 }
diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt
index 7c1add8891..7d7f269606 100644
--- a/datafusion/sqllogictest/test_files/scalar.slt
+++ b/datafusion/sqllogictest/test_files/scalar.slt
@@ -1325,21 +1325,22 @@ SELECT arrow_typeof(c8), arrow_typeof(c6), arrow_typeof(c8 + c6) FROM aggregate_
 Int32 Int64 Int64
 
 # in list array
-query BBBBB rowsort
+query BBBBBB rowsort
 SELECT c1 IN ('a', 'c') AS utf8_in_true
       ,c1 IN ('x', 'y') AS utf8_in_false
       ,c1 NOT IN ('x', 'y') AS utf8_not_in_true
       ,c1 NOT IN ('a', 'c') AS utf8_not_in_false
       ,NULL IN ('a', 'c') AS utf8_in_null
+      ,'a' IN (c1, NULL, 'c') uft8_in_column
 FROM aggregate_test_100 WHERE c12 < 0.05
 ----
-false false true true NULL
-false false true true NULL
-false false true true NULL
-false false true true NULL
-true false true false NULL
-true false true false NULL
-true false true false NULL
+false false true true NULL NULL
+false false true true NULL NULL
+false false true true NULL NULL
+false false true true NULL NULL
+true false true false NULL NULL
+true false true false NULL true
+true false true false NULL true
 
 # csv count star
 query III

Reply via email to