This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 84832ac3f1 fix: incorrect nullability of `InList` expr (#6799)
84832ac3f1 is described below

commit 84832ac3f13af78336f3fa291e88e05f3909faac
Author: Jonah Gao <[email protected]>
AuthorDate: Sat Jul 1 15:21:39 2023 +0800

    fix: incorrect nullability of `InList` expr (#6799)
    
    * fix: incorrect nullability of InList expr
    
    * Update datafusion/expr/src/expr_schema.rs
    
    Improve code readability
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/core/src/physical_plan/insert.rs |  3 +-
 datafusion/expr/src/expr_schema.rs          | 64 +++++++++++++++++++++++++++--
 2 files changed, 62 insertions(+), 5 deletions(-)

diff --git a/datafusion/core/src/physical_plan/insert.rs b/datafusion/core/src/physical_plan/insert.rs
index 15f77914d4..8bcb9bab83 100644
--- a/datafusion/core/src/physical_plan/insert.rs
+++ b/datafusion/core/src/physical_plan/insert.rs
@@ -278,8 +278,7 @@ fn check_not_null_contraits(
     batch: RecordBatch,
     column_indices: &Vec<usize>,
 ) -> Result<RecordBatch> {
-    for i in column_indices {
-        let index = *i;
+    for &index in column_indices {
         if batch.num_columns() <= index {
             return Err(DataFusionError::Execution(format!(
                 "Invalid batch column count {} expected > {}",
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index 83ccc8c596..5861a76e88 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -173,8 +173,30 @@ impl ExprSchemable for Expr {
             Expr::Alias(Alias { expr, .. })
             | Expr::Not(expr)
             | Expr::Negative(expr)
-            | Expr::Sort(Sort { expr, .. })
-            | Expr::InList(InList { expr, .. }) => expr.nullable(input_schema),
+            | Expr::Sort(Sort { expr, .. }) => expr.nullable(input_schema),
+
+            Expr::InList(InList { expr, list, .. }) => {
+                // Avoid inspecting too many expressions.
+                const MAX_INSPECT_LIMIT: usize = 6;
+                // Stop if a nullable expression is found or an error occurs.
+                let has_nullable = std::iter::once(expr.as_ref())
+                    .chain(list)
+                    .take(MAX_INSPECT_LIMIT)
+                    .find_map(|e| {
+                        e.nullable(input_schema)
+                            .map(|nullable| if nullable { Some(()) } else { None })
+                            .transpose()
+                    })
+                    .transpose()?;
+                Ok(match has_nullable {
+                    // If a nullable subexpression is found, the result may also be nullable.
+                    Some(_) => true,
+                    // If the list is too long, we assume it is nullable.
+                    None if list.len() + 1 > MAX_INSPECT_LIMIT => true,
+                    // All the subexpressions are non-nullable, so the result must be non-nullable.
+                    _ => false,
+                })
+            }
 
             Expr::Between(Between {
                 expr, low, high, ..
@@ -390,6 +412,31 @@ mod tests {
         assert!(expr.nullable(&get_schema(false)).unwrap());
     }
 
+    #[test]
+    fn test_inlist_nullability() {
+        let get_schema = |nullable| {
+            MockExprSchema::new()
+                .with_data_type(DataType::Int32)
+                .with_nullable(nullable)
+        };
+
+        let expr = col("foo").in_list(vec![lit(1); 5], false);
+        assert!(!expr.nullable(&get_schema(false)).unwrap());
+        assert!(expr.nullable(&get_schema(true)).unwrap());
+        // Testing nullable() returns an error.
+        assert!(expr
+            .nullable(&get_schema(false).with_error_on_nullable(true))
+            .is_err());
+
+        let null = lit(ScalarValue::Int32(None));
+        let expr = col("foo").in_list(vec![null, lit(1)], false);
+        assert!(expr.nullable(&get_schema(false)).unwrap());
+
+        // Testing on long list
+        let expr = col("foo").in_list(vec![lit(1); 6], false);
+        assert!(expr.nullable(&get_schema(false)).unwrap());
+    }
+
     #[test]
     fn expr_schema_data_type() {
         let expr = col("foo");
@@ -404,6 +451,7 @@ mod tests {
     struct MockExprSchema {
         nullable: bool,
         data_type: DataType,
+        error_on_nullable: bool,
     }
 
     impl MockExprSchema {
@@ -411,6 +459,7 @@ mod tests {
             Self {
                 nullable: false,
                 data_type: DataType::Null,
+                error_on_nullable: false,
             }
         }
 
@@ -423,11 +472,20 @@ mod tests {
             self.data_type = data_type;
             self
         }
+
+        fn with_error_on_nullable(mut self, error_on_nullable: bool) -> Self {
+            self.error_on_nullable = error_on_nullable;
+            self
+        }
     }
 
     impl ExprSchema for MockExprSchema {
         fn nullable(&self, _col: &Column) -> Result<bool> {
-            Ok(self.nullable)
+            if self.error_on_nullable {
+                Err(DataFusionError::Internal("nullable error".into()))
+            } else {
+                Ok(self.nullable)
+            }
         }
 
         fn data_type(&self, _col: &Column) -> Result<&DataType> {

Reply via email to