This is an automated email from the ASF dual-hosted git repository.

blaginin pushed a commit to branch annarose/dict-coercion
in repository https://gitbox.apache.org/repos/asf/datafusion-sandbox.git

commit 4dfc193cbf40d8f3b67b339ec067b6c787f3a6ec
Author: Pepijn Van Eeckhoudt <[email protected]>
AuthorDate: Tue Feb 3 20:40:34 2026 +0100

    Improve performance of `CASE WHEN x THEN y ELSE NULL` expressions (#20097)
    
    ## Which issue does this PR close?
    
    - Related to #11570
    
    ## Rationale for this change
    
    While reviewing #19994 it became clear that the optimised
    `ExpressionOrExpression` code path was not being used when the case
    expression has no `else` branch or has `else null`. In those situations
    the more general (and slower) evaluation strategies could end up being used.
    This PR refines the `ExpressionOrExpression` implementation to also
    handle case expressions whose `else` branch is absent or is `else null`.
    
    ## What changes are included in this PR?
    
    Use `ExpressionOrExpression` for expressions of the form `CASE WHEN x
    THEN y [ELSE NULL] END`
    
    ## Are these changes tested?
    
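    Covered by existing SLTs, plus new `case.slt` cases for a single `WHEN`
    with no `ELSE`, with an explicit `ELSE NULL`, with a fallible `THEN`
    expression, and with no matching rows.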
    
    ## Are there any user-facing changes?
    
    No
---
 datafusion/physical-expr/src/expressions/case.rs | 79 +++++++++++++++---------
 datafusion/sqllogictest/test_files/case.slt      | 32 ++++++++++
 2 files changed, 82 insertions(+), 29 deletions(-)

diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs
index 758317d3d..dac208be5 100644
--- a/datafusion/physical-expr/src/expressions/case.rs
+++ b/datafusion/physical-expr/src/expressions/case.rs
@@ -42,6 +42,7 @@ use crate::expressions::case::literal_lookup_table::LiteralLookupTable;
 use arrow::compute::kernels::merge::{MergeIndex, merge, merge_n};
 use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
 use datafusion_physical_expr_common::datum::compare_with_eq;
+use datafusion_physical_expr_common::utils::scatter;
 use itertools::Itertools;
 use std::fmt::{Debug, Formatter};
 
@@ -64,7 +65,7 @@ enum EvalMethod {
     /// for expressions that are infallible and can be cheaply computed for 
the entire
     /// record batch rather than just for the rows where the predicate is true.
     ///
-    /// CASE WHEN condition THEN column [ELSE NULL] END
+    /// CASE WHEN condition THEN infallible_expression [ELSE NULL] END
     InfallibleExprOrNull,
     /// This is a specialization for a specific use case where we can take a 
fast path
     /// if there is just one when/then pair and both the `then` and `else` 
expressions
@@ -72,9 +73,13 @@ enum EvalMethod {
     /// CASE WHEN condition THEN literal ELSE literal END
     ScalarOrScalar,
     /// This is a specialization for a specific use case where we can take a 
fast path
-    /// if there is just one when/then pair and both the `then` and `else` are 
expressions
+    /// if there is just one when/then pair, the `then` is an expression, and 
`else` is either
+    /// an expression, literal NULL or absent.
     ///
-    /// CASE WHEN condition THEN expression ELSE expression END
+    /// In contrast to [`EvalMethod::InfallibleExprOrNull`], this 
specialization can handle fallible
+    /// `then` expressions.
+    ///
+    /// CASE WHEN condition THEN expression [ELSE expression] END
     ExpressionOrExpression(ProjectedCaseBody),
 
     /// This is a specialization for [`EvalMethod::WithExpression`] when the 
value and results are literals
@@ -659,7 +664,7 @@ impl CaseExpr {
                 && body.else_expr.as_ref().unwrap().as_any().is::<Literal>()
             {
                 EvalMethod::ScalarOrScalar
-            } else if body.when_then_expr.len() == 1 && 
body.else_expr.is_some() {
+            } else if body.when_then_expr.len() == 1 {
                 EvalMethod::ExpressionOrExpression(body.project()?)
             } else {
                 EvalMethod::NoExpression(body.project()?)
@@ -961,32 +966,40 @@ impl CaseBody {
         let then_batch = filter_record_batch(batch, &when_filter)?;
         let then_value = self.when_then_expr[0].1.evaluate(&then_batch)?;
 
-        let else_selection = not(&when_value)?;
-        let else_filter = create_filter(&else_selection, optimize_filter);
-        let else_batch = filter_record_batch(batch, &else_filter)?;
-
-        // keep `else_expr`'s data type and return type consistent
-        let e = self.else_expr.as_ref().unwrap();
-        let return_type = self.data_type(&batch.schema())?;
-        let else_expr = try_cast(Arc::clone(e), &batch.schema(), 
return_type.clone())
-            .unwrap_or_else(|_| Arc::clone(e));
-
-        let else_value = else_expr.evaluate(&else_batch)?;
-
-        Ok(ColumnarValue::Array(match (then_value, else_value) {
-            (ColumnarValue::Array(t), ColumnarValue::Array(e)) => {
-                merge(&when_value, &t, &e)
-            }
-            (ColumnarValue::Scalar(t), ColumnarValue::Array(e)) => {
-                merge(&when_value, &t.to_scalar()?, &e)
-            }
-            (ColumnarValue::Array(t), ColumnarValue::Scalar(e)) => {
-                merge(&when_value, &t, &e.to_scalar()?)
+        match &self.else_expr {
+            None => {
+                let then_array = then_value.to_array(when_value.true_count())?;
+                scatter(&when_value, 
then_array.as_ref()).map(ColumnarValue::Array)
             }
-            (ColumnarValue::Scalar(t), ColumnarValue::Scalar(e)) => {
-                merge(&when_value, &t.to_scalar()?, &e.to_scalar()?)
+            Some(else_expr) => {
+                let else_selection = not(&when_value)?;
+                let else_filter = create_filter(&else_selection, 
optimize_filter);
+                let else_batch = filter_record_batch(batch, &else_filter)?;
+
+                // keep `else_expr`'s data type and return type consistent
+                let return_type = self.data_type(&batch.schema())?;
+                let else_expr =
+                    try_cast(Arc::clone(else_expr), &batch.schema(), 
return_type.clone())
+                        .unwrap_or_else(|_| Arc::clone(else_expr));
+
+                let else_value = else_expr.evaluate(&else_batch)?;
+
+                Ok(ColumnarValue::Array(match (then_value, else_value) {
+                    (ColumnarValue::Array(t), ColumnarValue::Array(e)) => {
+                        merge(&when_value, &t, &e)
+                    }
+                    (ColumnarValue::Scalar(t), ColumnarValue::Array(e)) => {
+                        merge(&when_value, &t.to_scalar()?, &e)
+                    }
+                    (ColumnarValue::Array(t), ColumnarValue::Scalar(e)) => {
+                        merge(&when_value, &t, &e.to_scalar()?)
+                    }
+                    (ColumnarValue::Scalar(t), ColumnarValue::Scalar(e)) => {
+                        merge(&when_value, &t.to_scalar()?, &e.to_scalar()?)
+                    }
+                }?))
             }
-        }?))
+        }
     }
 }
 
@@ -1137,7 +1150,15 @@ impl CaseExpr {
             self.body.when_then_expr[0].1.evaluate(batch)
         } else if true_count == 0 {
             // All input rows are false/null, just call the 'else' expression
-            self.body.else_expr.as_ref().unwrap().evaluate(batch)
+            match &self.body.else_expr {
+                Some(else_expr) => else_expr.evaluate(batch),
+                None => {
+                    let return_type = self.data_type(&batch.schema())?;
+                    Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
+                        &return_type,
+                    )?))
+                }
+            }
         } else if projected.projection.len() < batch.num_columns() {
             // The case expressions do not use all the columns of the input 
batch.
             // Project first to reduce time spent filtering.
diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt
index 8bb17b57f..3953878ce 100644
--- a/datafusion/sqllogictest/test_files/case.slt
+++ b/datafusion/sqllogictest/test_files/case.slt
@@ -642,6 +642,38 @@ NULL
 NULL
 -1
 
+# single WHEN, no ELSE (absent)
+query I
+SELECT CASE WHEN a > 0 THEN b END
+FROM (VALUES (1, 10), (0, 20)) AS t(a, b);
+----
+10
+NULL
+
+# single WHEN, explicit ELSE NULL
+query I
+SELECT CASE WHEN a > 0 THEN b ELSE NULL END
+FROM (VALUES (1, 10), (0, 20)) AS t(a, b);
+----
+10
+NULL
+
+# fallible THEN expression should only be evaluated on true rows
+query I
+SELECT CASE WHEN a > 0 THEN 10 / a END
+FROM (VALUES (1), (0)) AS t(a);
+----
+10
+NULL
+
+# all-false path returns typed NULLs
+query I
+SELECT CASE WHEN a < 0 THEN b END
+FROM (VALUES (1, 10), (2, 20)) AS t(a, b);
+----
+NULL
+NULL
+
 # EvalMethod::WithExpression using subset of all selected columns in case 
expression
 query III
 SELECT CASE a1 WHEN 1 THEN a1 WHEN 2 THEN a2 WHEN 3 THEN b END, b, c


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to