This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 6d3854f691 #17972 Restore case expr/expr optimisation while ensuring 
lazy evaluation (#17973)
6d3854f691 is described below

commit 6d3854f691545808b0a2a15e5fa634971953d4a1
Author: Pepijn Van Eeckhoudt <[email protected]>
AuthorDate: Wed Oct 15 12:35:54 2025 +0200

    #17972 Restore case expr/expr optimisation while ensuring lazy evaluation 
(#17973)
    
    * #17972 Restore case expr/expr optimisation while ensuring lazy evaluation
    
    * Avoid calling `PhysicalExpr::evaluate` from 
`PhysicalExpr::evaluate_selection` for empty selections.
    
    * Make `PhysicalExpr::evaluate_selection` correctly handle empty input sets 
and all false filters
    
    * Reorganize code to avoid scatter codepath when using `evaluate` fast path.
    
    * Clarify comments in case
    
    * Move null handling after true count check.
    
    * Tweaking comments
    
    * Add unit tests to help define the boundary case behaviour of 
evaluate_selection
    
    * Code polishing
    - Add extra comments
    - Use match for the scatter paragraph
    - Validate that the size of selection and batch match
    
    * Fix clippy errors
    
    * Add additional case SLTs
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 .../physical-expr-common/src/physical_expr.rs      | 268 +++++++++++++++++++--
 datafusion/physical-expr/src/expressions/case.rs   |  15 +-
 datafusion/sqllogictest/test_files/case.slt        |  27 +++
 3 files changed, 282 insertions(+), 28 deletions(-)

diff --git a/datafusion/physical-expr-common/src/physical_expr.rs 
b/datafusion/physical-expr-common/src/physical_expr.rs
index 6f7c432c75..e5e7d6c00f 100644
--- a/datafusion/physical-expr-common/src/physical_expr.rs
+++ b/datafusion/physical-expr-common/src/physical_expr.rs
@@ -23,14 +23,14 @@ use std::sync::Arc;
 
 use crate::utils::scatter;
 
-use arrow::array::{ArrayRef, BooleanArray};
+use arrow::array::{new_empty_array, ArrayRef, BooleanArray};
 use arrow::compute::filter_record_batch;
 use arrow::datatypes::{DataType, Field, FieldRef, Schema};
 use arrow::record_batch::RecordBatch;
 use datafusion_common::tree_node::{
     Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
 };
-use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue};
+use datafusion_common::{exec_err, internal_err, not_impl_err, Result, 
ScalarValue};
 use datafusion_expr_common::columnar_value::ColumnarValue;
 use datafusion_expr_common::interval_arithmetic::Interval;
 use datafusion_expr_common::sort_properties::ExprProperties;
@@ -90,36 +90,69 @@ pub trait PhysicalExpr: Any + Send + Sync + Display + Debug 
+ DynEq + DynHash {
             self.nullable(input_schema)?,
         )))
     }
-    /// Evaluate an expression against a RecordBatch after first applying a
-    /// validity array
+    /// Evaluate an expression against a RecordBatch after first applying a 
validity array
+    ///
+    /// # Errors
+    ///
+    /// Returns an `Err` if the expression could not be evaluated or if the 
length of the
+    /// `selection` validity array and the number of rows in `batch` are not 
equal.
     fn evaluate_selection(
         &self,
         batch: &RecordBatch,
         selection: &BooleanArray,
     ) -> Result<ColumnarValue> {
-        let tmp_batch = filter_record_batch(batch, selection)?;
-
-        let tmp_result = self.evaluate(&tmp_batch)?;
-
-        if batch.num_rows() == tmp_batch.num_rows() {
-            // All values from the `selection` filter are true.
-            Ok(tmp_result)
-        } else if let ColumnarValue::Array(a) = tmp_result {
-            scatter(selection, a.as_ref()).map(ColumnarValue::Array)
-        } else if let ColumnarValue::Scalar(ScalarValue::Boolean(value)) = 
&tmp_result {
-            // When the scalar is true or false, skip the scatter process
-            if let Some(v) = value {
-                if *v {
-                    Ok(ColumnarValue::from(Arc::new(selection.clone()) as 
ArrayRef))
+        let row_count = batch.num_rows();
+        if row_count != selection.len() {
+            return exec_err!("Selection array length does not match batch row 
count: {} != {row_count}", selection.len());
+        }
+
+        let selection_count = selection.true_count();
+
+        // First, check if we can avoid filtering altogether.
+        if selection_count == row_count {
+            // All values from the `selection` filter are true and match the 
input batch.
+            // No need to perform any filtering.
+            return self.evaluate(batch);
+        }
+
+        // Next, prepare the result array for each 'true' row in the selection 
vector.
+        let filtered_result = if selection_count == 0 {
+            // Do not call `evaluate` when the selection is empty.
+            // `evaluate_selection` is used to conditionally evaluate 
expressions.
+            // When the expression in question is fallible, evaluating it with 
an empty
+            // record batch may trigger a runtime error (e.g. division by 
zero).
+            //
+            // Instead, create an empty array matching the expected return 
type.
+            let datatype = self.data_type(batch.schema_ref().as_ref())?;
+            ColumnarValue::Array(new_empty_array(&datatype))
+        } else {
+            // If we reach this point, there's no other option than to filter 
the batch.
+            // This is a fairly costly operation since it requires creating 
partial copies
+            // (worst case of length `row_count - 1`) of all the arrays in the 
record batch.
+            // The resulting `filtered_batch` will contain `selection_count` 
rows.
+            let filtered_batch = filter_record_batch(batch, selection)?;
+            self.evaluate(&filtered_batch)?
+        };
+
+        // Finally, scatter the filtered result array so that the indices 
match the input rows again.
+        match &filtered_result {
+            ColumnarValue::Array(a) => {
+                scatter(selection, a.as_ref()).map(ColumnarValue::Array)
+            }
+            ColumnarValue::Scalar(ScalarValue::Boolean(value)) => {
+                // When the scalar is true or false, skip the scatter process
+                if let Some(v) = value {
+                    if *v {
+                        Ok(ColumnarValue::from(Arc::new(selection.clone()) as 
ArrayRef))
+                    } else {
+                        Ok(filtered_result)
+                    }
                 } else {
-                    Ok(tmp_result)
+                    let array = BooleanArray::from(vec![None; row_count]);
+                    scatter(selection, &array).map(ColumnarValue::Array)
                 }
-            } else {
-                let array = BooleanArray::from(vec![None; batch.num_rows()]);
-                scatter(selection, &array).map(ColumnarValue::Array)
             }
-        } else {
-            Ok(tmp_result)
+            ColumnarValue::Scalar(_) => Ok(filtered_result),
         }
     }
 
@@ -601,3 +634,190 @@ pub fn is_volatile(expr: &Arc<dyn PhysicalExpr>) -> bool {
     .expect("infallible closure should not fail");
     is_volatile
 }
+
+#[cfg(test)]
+mod test {
+    use crate::physical_expr::PhysicalExpr;
+    use arrow::array::{Array, BooleanArray, Int64Array, RecordBatch};
+    use arrow::datatypes::{DataType, Schema};
+    use datafusion_expr_common::columnar_value::ColumnarValue;
+    use std::fmt::{Display, Formatter};
+    use std::sync::Arc;
+
+    #[derive(Debug, PartialEq, Eq, Hash)]
+    struct TestExpr {}
+
+    impl PhysicalExpr for TestExpr {
+        fn as_any(&self) -> &dyn std::any::Any {
+            self
+        }
+
+        fn data_type(&self, _schema: &Schema) -> 
datafusion_common::Result<DataType> {
+            Ok(DataType::Int64)
+        }
+
+        fn nullable(&self, _schema: &Schema) -> 
datafusion_common::Result<bool> {
+            Ok(false)
+        }
+
+        fn evaluate(
+            &self,
+            batch: &RecordBatch,
+        ) -> datafusion_common::Result<ColumnarValue> {
+            let data = vec![1; batch.num_rows()];
+            Ok(ColumnarValue::Array(Arc::new(Int64Array::from(data))))
+        }
+
+        fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+            vec![]
+        }
+
+        fn with_new_children(
+            self: Arc<Self>,
+            _children: Vec<Arc<dyn PhysicalExpr>>,
+        ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
+            Ok(Arc::new(Self {}))
+        }
+
+        fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+            f.write_str("TestExpr")
+        }
+    }
+
+    impl Display for TestExpr {
+        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+            self.fmt_sql(f)
+        }
+    }
+
+    macro_rules! assert_arrays_eq {
+        ($EXPECTED: expr, $ACTUAL: expr, $MESSAGE: expr) => {
+            let expected = $EXPECTED.to_array(1).unwrap();
+            let actual = $ACTUAL;
+
+            let actual_array = actual.to_array(expected.len()).unwrap();
+            let actual_ref = actual_array.as_ref();
+            let expected_ref = expected.as_ref();
+            assert!(
+                actual_ref == expected_ref,
+                "{}: expected: {:?}, actual: {:?}",
+                $MESSAGE,
+                $EXPECTED,
+                actual_ref
+            );
+        };
+    }
+
+    fn test_evaluate_selection(
+        batch: &RecordBatch,
+        selection: &BooleanArray,
+        expected: &ColumnarValue,
+    ) {
+        let expr = TestExpr {};
+
+        // First check that the `evaluate_selection` is the expected one
+        let selection_result = expr.evaluate_selection(batch, 
selection).unwrap();
+        assert_eq!(
+            expected.to_array(1).unwrap().len(),
+            selection_result.to_array(1).unwrap().len(),
+            "evaluate_selection should output row count should match input 
record batch"
+        );
+        assert_arrays_eq!(
+            expected,
+            &selection_result,
+            "evaluate_selection returned unexpected value"
+        );
+
+        // If we're selecting all rows, the result should be the same as 
calling `evaluate`
+        // with the full record batch.
+        if (0..batch.num_rows())
+            .all(|row_idx| row_idx < selection.len() && 
selection.value(row_idx))
+        {
+            let empty_result = expr.evaluate(batch).unwrap();
+
+            assert_arrays_eq!(
+                empty_result,
+                &selection_result,
+                "evaluate_selection does not match unfiltered evaluate result"
+            );
+        }
+    }
+
+    fn test_evaluate_selection_error(batch: &RecordBatch, selection: 
&BooleanArray) {
+        let expr = TestExpr {};
+
+        // First check that the `evaluate_selection` is the expected one
+        let selection_result = expr.evaluate_selection(batch, selection);
+        assert!(selection_result.is_err(), "evaluate_selection should fail");
+    }
+
+    #[test]
+    pub fn test_evaluate_selection_with_empty_record_batch() {
+        test_evaluate_selection(
+            &RecordBatch::new_empty(Arc::new(Schema::empty())),
+            &BooleanArray::from(vec![false; 0]),
+            &ColumnarValue::Array(Arc::new(Int64Array::new_null(0))),
+        );
+    }
+
+    #[test]
+    pub fn 
test_evaluate_selection_with_empty_record_batch_with_larger_false_selection() {
+        test_evaluate_selection_error(
+            &RecordBatch::new_empty(Arc::new(Schema::empty())),
+            &BooleanArray::from(vec![false; 10]),
+        );
+    }
+
+    #[test]
+    pub fn 
test_evaluate_selection_with_empty_record_batch_with_larger_true_selection() {
+        test_evaluate_selection_error(
+            &RecordBatch::new_empty(Arc::new(Schema::empty())),
+            &BooleanArray::from(vec![true; 10]),
+        );
+    }
+
+    #[test]
+    pub fn test_evaluate_selection_with_non_empty_record_batch() {
+        test_evaluate_selection(
+            unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()), 
vec![], 10) },
+            &BooleanArray::from(vec![true; 10]),
+            &ColumnarValue::Array(Arc::new(Int64Array::from(vec![1; 10]))),
+        );
+    }
+
+    #[test]
+    pub fn 
test_evaluate_selection_with_non_empty_record_batch_with_larger_false_selection(
+    ) {
+        test_evaluate_selection_error(
+            unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()), 
vec![], 10) },
+            &BooleanArray::from(vec![false; 20]),
+        );
+    }
+
+    #[test]
+    pub fn 
test_evaluate_selection_with_non_empty_record_batch_with_larger_true_selection(
+    ) {
+        test_evaluate_selection_error(
+            unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()), 
vec![], 10) },
+            &BooleanArray::from(vec![true; 20]),
+        );
+    }
+
+    #[test]
+    pub fn 
test_evaluate_selection_with_non_empty_record_batch_with_smaller_false_selection(
+    ) {
+        test_evaluate_selection_error(
+            unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()), 
vec![], 10) },
+            &BooleanArray::from(vec![false; 5]),
+        );
+    }
+
+    #[test]
+    pub fn 
test_evaluate_selection_with_non_empty_record_batch_with_smaller_true_selection(
+    ) {
+        test_evaluate_selection_error(
+            unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()), 
vec![], 10) },
+            &BooleanArray::from(vec![true; 5]),
+        );
+    }
+}
diff --git a/datafusion/physical-expr/src/expressions/case.rs 
b/datafusion/physical-expr/src/expressions/case.rs
index 5409cfe8e7..d14146a20d 100644
--- a/datafusion/physical-expr/src/expressions/case.rs
+++ b/datafusion/physical-expr/src/expressions/case.rs
@@ -155,10 +155,7 @@ impl CaseExpr {
                 && else_expr.as_ref().unwrap().as_any().is::<Literal>()
             {
                 EvalMethod::ScalarOrScalar
-            } else if when_then_expr.len() == 1
-                && is_cheap_and_infallible(&(when_then_expr[0].1))
-                && else_expr.as_ref().is_some_and(is_cheap_and_infallible)
-            {
+            } else if when_then_expr.len() == 1 && else_expr.is_some() {
                 EvalMethod::ExpressionOrExpression
             } else {
                 EvalMethod::NoExpression
@@ -425,6 +422,16 @@ impl CaseExpr {
             )
         })?;
 
+        // For the true and false/null selection vectors, bypass 
`evaluate_selection` and merging
+        // results. This avoids materializing the array for the other branch 
which we will discard
+        // entirely anyway.
+        let true_count = when_value.true_count();
+        if true_count == batch.num_rows() {
+            return self.when_then_expr[0].1.evaluate(batch);
+        } else if true_count == 0 {
+            return self.else_expr.as_ref().unwrap().evaluate(batch);
+        }
+
         // Treat 'NULL' as false value
         let when_value = match when_value.null_count() {
             0 => Cow::Borrowed(when_value),
diff --git a/datafusion/sqllogictest/test_files/case.slt 
b/datafusion/sqllogictest/test_files/case.slt
index 69f80f4593..9bc1f83ed1 100644
--- a/datafusion/sqllogictest/test_files/case.slt
+++ b/datafusion/sqllogictest/test_files/case.slt
@@ -467,6 +467,7 @@ FROM t;
 ----
 [{foo: blarg}]
 
+# mix of then and else
 query II
 SELECT v, CASE WHEN v != 0 THEN 10/v ELSE 42 END FROM (VALUES (0), (1), (2)) 
t(v)
 ----
@@ -474,12 +475,38 @@ SELECT v, CASE WHEN v != 0 THEN 10/v ELSE 42 END FROM 
(VALUES (0), (1), (2)) t(v
 1 10
 2 5
 
+# when expression is always false, then branch should never be evaluated
 query II
 SELECT v, CASE WHEN v < 0 THEN 10/0 ELSE 1 END FROM (VALUES (1), (2)) t(v)
 ----
 1 1
 2 1
 
+# when expression is always true, else branch should never be evaluated
+query II
+SELECT v, CASE WHEN v > 0 THEN 1 ELSE 10/0 END FROM (VALUES (1), (2)) t(v)
+----
+1 1
+2 1
+
+
+# lazy evaluation of multiple when branches, else branch should never be 
evaluated
+query II
+SELECT v, CASE WHEN v == 1 THEN -1 WHEN v == 2 THEN -2 WHEN v == 3 THEN -3  
ELSE 10/0 END FROM (VALUES (1), (2), (3)) t(v)
+----
+1 -1
+2 -2
+3 -3
+
+# covers the InfallibleExprOrNull evaluation strategy
+query II
+SELECT v, CASE WHEN v THEN 1 END FROM (VALUES (1), (2), (3), (NULL)) t(v)
+----
+1 1
+2 1
+3 1
+NULL NULL
+
 statement ok
 drop table t
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to