This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new fda561e803 fix: incorrect nullability calculation of InListExpr (#7496)
fda561e803 is described below
commit fda561e803d309018368696608afced4e771d17e
Author: Jonah Gao <[email protected]>
AuthorDate: Fri Sep 8 19:00:18 2023 +0800
fix: incorrect nullability calculation of InListExpr (#7496)
---
.../physical-expr/src/expressions/in_list.rs | 61 +++++++++++++++++++++-
datafusion/sqllogictest/test_files/scalar.slt | 17 +++---
2 files changed, 69 insertions(+), 9 deletions(-)
diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs
index c4b1369f6b..c92bbbb74f 100644
--- a/datafusion/physical-expr/src/expressions/in_list.rs
+++ b/datafusion/physical-expr/src/expressions/in_list.rs
@@ -66,6 +66,7 @@ impl Debug for InListExpr {
/// A type-erased container of array elements
pub trait Set: Send + Sync {
fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray>;
+ fn has_nulls(&self) -> bool;
}
struct ArrayHashSet {
@@ -134,6 +135,10 @@ where
})
.collect())
}
+
+ fn has_nulls(&self) -> bool {
+ self.array.null_count() != 0
+ }
}
/// Computes an [`ArrayHashSet`] for the provided [`Array`] if there
@@ -325,7 +330,20 @@ impl PhysicalExpr for InListExpr {
}
fn nullable(&self, input_schema: &Schema) -> Result<bool> {
- self.expr.nullable(input_schema)
+ if self.expr.nullable(input_schema)? {
+ return Ok(true);
+ }
+
+ if let Some(static_filter) = &self.static_filter {
+ Ok(static_filter.has_nulls())
+ } else {
+ for expr in &self.list {
+ if expr.nullable(input_schema)? {
+ return Ok(true);
+ }
+ }
+ Ok(false)
+ }
}
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
@@ -1203,4 +1221,45 @@ mod tests {
Ok(())
}
+
+ macro_rules! test_nullable {
+ ($COL:expr, $LIST:expr, $SCHEMA:expr, $EXPECTED:expr) => {{
+ let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST, $SCHEMA)?;
+ let expr = in_list(cast_expr, cast_list_exprs, &false, $SCHEMA).unwrap();
+ let result = expr.nullable($SCHEMA)?;
+ assert_eq!($EXPECTED, result);
+ }};
+ }
+
+ #[test]
+ fn in_list_nullable() -> Result<()> {
+ let schema = Schema::new(vec![
+ Field::new("c1_nullable", DataType::Int64, true),
+ Field::new("c2_non_nullable", DataType::Int64, false),
+ ]);
+
+ let c1_nullable = col("c1_nullable", &schema)?;
+ let c2_non_nullable = col("c2_non_nullable", &schema)?;
+
+ // static_filter has no nulls
+ let list = vec![lit(1_i64), lit(2_i64)];
+ test_nullable!(c1_nullable.clone(), list.clone(), &schema, true);
+ test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, false);
+
+ // static_filter has nulls
+ let list = vec![lit(1_i64), lit(2_i64), lit(ScalarValue::Null)];
+ test_nullable!(c1_nullable.clone(), list.clone(), &schema, true);
+ test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, true);
+
+ let list = vec![c1_nullable.clone()];
+ test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, true);
+
+ let list = vec![c2_non_nullable.clone()];
+ test_nullable!(c1_nullable.clone(), list.clone(), &schema, true);
+
+ let list = vec![c2_non_nullable.clone(), c2_non_nullable.clone()];
+ test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, false);
+
+ Ok(())
+ }
}
diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt
index 7c1add8891..7d7f269606 100644
--- a/datafusion/sqllogictest/test_files/scalar.slt
+++ b/datafusion/sqllogictest/test_files/scalar.slt
@@ -1325,21 +1325,22 @@ SELECT arrow_typeof(c8), arrow_typeof(c6), arrow_typeof(c8 + c6) FROM aggregate_
Int32 Int64 Int64
# in list array
-query BBBBB rowsort
+query BBBBBB rowsort
SELECT c1 IN ('a', 'c') AS utf8_in_true
,c1 IN ('x', 'y') AS utf8_in_false
,c1 NOT IN ('x', 'y') AS utf8_not_in_true
,c1 NOT IN ('a', 'c') AS utf8_not_in_false
,NULL IN ('a', 'c') AS utf8_in_null
+ ,'a' IN (c1, NULL, 'c') uft8_in_column
FROM aggregate_test_100 WHERE c12 < 0.05
----
-false false true true NULL
-false false true true NULL
-false false true true NULL
-false false true true NULL
-true false true false NULL
-true false true false NULL
-true false true false NULL
+false false true true NULL NULL
+false false true true NULL NULL
+false false true true NULL NULL
+false false true true NULL NULL
+true false true false NULL NULL
+true false true false NULL true
+true false true false NULL true
# csv count star
query III