This is an automated email from the ASF dual-hosted git repository.
github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 6d3854f691 #17972 Restore case expr/expr optimisation while ensuring
lazy evaluation (#17973)
6d3854f691 is described below
commit 6d3854f691545808b0a2a15e5fa634971953d4a1
Author: Pepijn Van Eeckhoudt <[email protected]>
AuthorDate: Wed Oct 15 12:35:54 2025 +0200
#17972 Restore case expr/expr optimisation while ensuring lazy evaluation
(#17973)
* #17972 Restore case expr/expr optimisation while ensuring lazy evaluation
* Avoid calling `PhysicalExpr::evaluate` from
`PhysicalExpr::evaluate_selection` for empty selections.
* Make `PhysicalExpr::evaluate_selection` correctly handle empty input sets
and all false filters
* Reorganize code to avoid scatter codepath when using `evaluate` fast path.
* Clarify comments in case
* Move null handling after true count check.
* Tweaking comments
* Add unit tests to help define the boundary case behaviour of
evaluate_selection
* Code polishing
- Add extra comments
- Use match for the scatter paragraph
- Validate that the size of selection and batch match
* Fix clippy errors
* Add additional case SLTs
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
.../physical-expr-common/src/physical_expr.rs | 268 +++++++++++++++++++--
datafusion/physical-expr/src/expressions/case.rs | 15 +-
datafusion/sqllogictest/test_files/case.slt | 27 +++
3 files changed, 282 insertions(+), 28 deletions(-)
diff --git a/datafusion/physical-expr-common/src/physical_expr.rs
b/datafusion/physical-expr-common/src/physical_expr.rs
index 6f7c432c75..e5e7d6c00f 100644
--- a/datafusion/physical-expr-common/src/physical_expr.rs
+++ b/datafusion/physical-expr-common/src/physical_expr.rs
@@ -23,14 +23,14 @@ use std::sync::Arc;
use crate::utils::scatter;
-use arrow::array::{ArrayRef, BooleanArray};
+use arrow::array::{new_empty_array, ArrayRef, BooleanArray};
use arrow::compute::filter_record_batch;
use arrow::datatypes::{DataType, Field, FieldRef, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
-use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue};
+use datafusion_common::{exec_err, internal_err, not_impl_err, Result,
ScalarValue};
use datafusion_expr_common::columnar_value::ColumnarValue;
use datafusion_expr_common::interval_arithmetic::Interval;
use datafusion_expr_common::sort_properties::ExprProperties;
@@ -90,36 +90,69 @@ pub trait PhysicalExpr: Any + Send + Sync + Display + Debug
+ DynEq + DynHash {
self.nullable(input_schema)?,
)))
}
- /// Evaluate an expression against a RecordBatch after first applying a
- /// validity array
+ /// Evaluate an expression against a RecordBatch after first applying a
validity array
+ ///
+ /// # Errors
+ ///
+ /// Returns an `Err` if the expression could not be evaluated or if the
length of the
+ /// `selection` validity array and the number of rows in `batch` are not
equal.
fn evaluate_selection(
&self,
batch: &RecordBatch,
selection: &BooleanArray,
) -> Result<ColumnarValue> {
- let tmp_batch = filter_record_batch(batch, selection)?;
-
- let tmp_result = self.evaluate(&tmp_batch)?;
-
- if batch.num_rows() == tmp_batch.num_rows() {
- // All values from the `selection` filter are true.
- Ok(tmp_result)
- } else if let ColumnarValue::Array(a) = tmp_result {
- scatter(selection, a.as_ref()).map(ColumnarValue::Array)
- } else if let ColumnarValue::Scalar(ScalarValue::Boolean(value)) =
&tmp_result {
- // When the scalar is true or false, skip the scatter process
- if let Some(v) = value {
- if *v {
- Ok(ColumnarValue::from(Arc::new(selection.clone()) as
ArrayRef))
+ let row_count = batch.num_rows();
+ if row_count != selection.len() {
+ return exec_err!("Selection array length does not match batch row
count: {} != {row_count}", selection.len());
+ }
+
+ let selection_count = selection.true_count();
+
+ // First, check if we can avoid filtering altogether.
+ if selection_count == row_count {
+ // All values from the `selection` filter are true and match the
input batch.
+ // No need to perform any filtering.
+ return self.evaluate(batch);
+ }
+
+ // Next, prepare the result array for each 'true' row in the selection
vector.
+ let filtered_result = if selection_count == 0 {
+ // Do not call `evaluate` when the selection is empty.
+ // `evaluate_selection` is used to conditionally evaluate
expressions.
+ // When the expression in question is fallible, evaluating it with
an empty
+ // record batch may trigger a runtime error (e.g. division by
zero).
+ //
+ // Instead, create an empty array matching the expected return
type.
+ let datatype = self.data_type(batch.schema_ref().as_ref())?;
+ ColumnarValue::Array(new_empty_array(&datatype))
+ } else {
+ // If we reach this point, there's no other option than to filter
the batch.
+ // This is a fairly costly operation since it requires creating
partial copies
+ // (worst case of length `row_count - 1`) of all the arrays in the
record batch.
+ // The resulting `filtered_batch` will contain `selection_count`
rows.
+ let filtered_batch = filter_record_batch(batch, selection)?;
+ self.evaluate(&filtered_batch)?
+ };
+
+ // Finally, scatter the filtered result array so that the indices
match the input rows again.
+ match &filtered_result {
+ ColumnarValue::Array(a) => {
+ scatter(selection, a.as_ref()).map(ColumnarValue::Array)
+ }
+ ColumnarValue::Scalar(ScalarValue::Boolean(value)) => {
+ // When the scalar is true or false, skip the scatter process
+ if let Some(v) = value {
+ if *v {
+ Ok(ColumnarValue::from(Arc::new(selection.clone()) as
ArrayRef))
+ } else {
+ Ok(filtered_result)
+ }
} else {
- Ok(tmp_result)
+ let array = BooleanArray::from(vec![None; row_count]);
+ scatter(selection, &array).map(ColumnarValue::Array)
}
- } else {
- let array = BooleanArray::from(vec![None; batch.num_rows()]);
- scatter(selection, &array).map(ColumnarValue::Array)
}
- } else {
- Ok(tmp_result)
+ ColumnarValue::Scalar(_) => Ok(filtered_result),
}
}
@@ -601,3 +634,190 @@ pub fn is_volatile(expr: &Arc<dyn PhysicalExpr>) -> bool {
.expect("infallible closure should not fail");
is_volatile
}
+
+#[cfg(test)]
+mod test {
+ use crate::physical_expr::PhysicalExpr;
+ use arrow::array::{Array, BooleanArray, Int64Array, RecordBatch};
+ use arrow::datatypes::{DataType, Schema};
+ use datafusion_expr_common::columnar_value::ColumnarValue;
+ use std::fmt::{Display, Formatter};
+ use std::sync::Arc;
+
+ #[derive(Debug, PartialEq, Eq, Hash)]
+ struct TestExpr {}
+
+ impl PhysicalExpr for TestExpr {
+ fn as_any(&self) -> &dyn std::any::Any {
+ self
+ }
+
+ fn data_type(&self, _schema: &Schema) ->
datafusion_common::Result<DataType> {
+ Ok(DataType::Int64)
+ }
+
+ fn nullable(&self, _schema: &Schema) ->
datafusion_common::Result<bool> {
+ Ok(false)
+ }
+
+ fn evaluate(
+ &self,
+ batch: &RecordBatch,
+ ) -> datafusion_common::Result<ColumnarValue> {
+ let data = vec![1; batch.num_rows()];
+ Ok(ColumnarValue::Array(Arc::new(Int64Array::from(data))))
+ }
+
+ fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+ vec![]
+ }
+
+ fn with_new_children(
+ self: Arc<Self>,
+ _children: Vec<Arc<dyn PhysicalExpr>>,
+ ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
+ Ok(Arc::new(Self {}))
+ }
+
+ fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ f.write_str("TestExpr")
+ }
+ }
+
+ impl Display for TestExpr {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ self.fmt_sql(f)
+ }
+ }
+
+ macro_rules! assert_arrays_eq {
+ ($EXPECTED: expr, $ACTUAL: expr, $MESSAGE: expr) => {
+ let expected = $EXPECTED.to_array(1).unwrap();
+ let actual = $ACTUAL;
+
+ let actual_array = actual.to_array(expected.len()).unwrap();
+ let actual_ref = actual_array.as_ref();
+ let expected_ref = expected.as_ref();
+ assert!(
+ actual_ref == expected_ref,
+ "{}: expected: {:?}, actual: {:?}",
+ $MESSAGE,
+ $EXPECTED,
+ actual_ref
+ );
+ };
+ }
+
+ fn test_evaluate_selection(
+ batch: &RecordBatch,
+ selection: &BooleanArray,
+ expected: &ColumnarValue,
+ ) {
+ let expr = TestExpr {};
+
+ // First check that the `evaluate_selection` is the expected one
+ let selection_result = expr.evaluate_selection(batch,
selection).unwrap();
+ assert_eq!(
+ expected.to_array(1).unwrap().len(),
+ selection_result.to_array(1).unwrap().len(),
+ "evaluate_selection should output row count should match input
record batch"
+ );
+ assert_arrays_eq!(
+ expected,
+ &selection_result,
+ "evaluate_selection returned unexpected value"
+ );
+
+ // If we're selecting all rows, the result should be the same as
calling `evaluate`
+ // with the full record batch.
+ if (0..batch.num_rows())
+ .all(|row_idx| row_idx < selection.len() &&
selection.value(row_idx))
+ {
+ let empty_result = expr.evaluate(batch).unwrap();
+
+ assert_arrays_eq!(
+ empty_result,
+ &selection_result,
+ "evaluate_selection does not match unfiltered evaluate result"
+ );
+ }
+ }
+
+ fn test_evaluate_selection_error(batch: &RecordBatch, selection:
&BooleanArray) {
+ let expr = TestExpr {};
+
+ // First check that the `evaluate_selection` is the expected one
+ let selection_result = expr.evaluate_selection(batch, selection);
+ assert!(selection_result.is_err(), "evaluate_selection should fail");
+ }
+
+ #[test]
+ pub fn test_evaluate_selection_with_empty_record_batch() {
+ test_evaluate_selection(
+ &RecordBatch::new_empty(Arc::new(Schema::empty())),
+ &BooleanArray::from(vec![false; 0]),
+ &ColumnarValue::Array(Arc::new(Int64Array::new_null(0))),
+ );
+ }
+
+ #[test]
+ pub fn
test_evaluate_selection_with_empty_record_batch_with_larger_false_selection() {
+ test_evaluate_selection_error(
+ &RecordBatch::new_empty(Arc::new(Schema::empty())),
+ &BooleanArray::from(vec![false; 10]),
+ );
+ }
+
+ #[test]
+ pub fn
test_evaluate_selection_with_empty_record_batch_with_larger_true_selection() {
+ test_evaluate_selection_error(
+ &RecordBatch::new_empty(Arc::new(Schema::empty())),
+ &BooleanArray::from(vec![true; 10]),
+ );
+ }
+
+ #[test]
+ pub fn test_evaluate_selection_with_non_empty_record_batch() {
+ test_evaluate_selection(
+ unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()),
vec![], 10) },
+ &BooleanArray::from(vec![true; 10]),
+ &ColumnarValue::Array(Arc::new(Int64Array::from(vec![1; 10]))),
+ );
+ }
+
+ #[test]
+ pub fn
test_evaluate_selection_with_non_empty_record_batch_with_larger_false_selection(
+ ) {
+ test_evaluate_selection_error(
+ unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()),
vec![], 10) },
+ &BooleanArray::from(vec![false; 20]),
+ );
+ }
+
+ #[test]
+ pub fn
test_evaluate_selection_with_non_empty_record_batch_with_larger_true_selection(
+ ) {
+ test_evaluate_selection_error(
+ unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()),
vec![], 10) },
+ &BooleanArray::from(vec![true; 20]),
+ );
+ }
+
+ #[test]
+ pub fn
test_evaluate_selection_with_non_empty_record_batch_with_smaller_false_selection(
+ ) {
+ test_evaluate_selection_error(
+ unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()),
vec![], 10) },
+ &BooleanArray::from(vec![false; 5]),
+ );
+ }
+
+ #[test]
+ pub fn
test_evaluate_selection_with_non_empty_record_batch_with_smaller_true_selection(
+ ) {
+ test_evaluate_selection_error(
+ unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()),
vec![], 10) },
+ &BooleanArray::from(vec![true; 5]),
+ );
+ }
+}
diff --git a/datafusion/physical-expr/src/expressions/case.rs
b/datafusion/physical-expr/src/expressions/case.rs
index 5409cfe8e7..d14146a20d 100644
--- a/datafusion/physical-expr/src/expressions/case.rs
+++ b/datafusion/physical-expr/src/expressions/case.rs
@@ -155,10 +155,7 @@ impl CaseExpr {
&& else_expr.as_ref().unwrap().as_any().is::<Literal>()
{
EvalMethod::ScalarOrScalar
- } else if when_then_expr.len() == 1
- && is_cheap_and_infallible(&(when_then_expr[0].1))
- && else_expr.as_ref().is_some_and(is_cheap_and_infallible)
- {
+ } else if when_then_expr.len() == 1 && else_expr.is_some() {
EvalMethod::ExpressionOrExpression
} else {
EvalMethod::NoExpression
@@ -425,6 +422,16 @@ impl CaseExpr {
)
})?;
+ // For the true and false/null selection vectors, bypass
`evaluate_selection` and merging
+ // results. This avoids materializing the array for the other branch
which we will discard
+ // entirely anyway.
+ let true_count = when_value.true_count();
+ if true_count == batch.num_rows() {
+ return self.when_then_expr[0].1.evaluate(batch);
+ } else if true_count == 0 {
+ return self.else_expr.as_ref().unwrap().evaluate(batch);
+ }
+
// Treat 'NULL' as false value
let when_value = match when_value.null_count() {
0 => Cow::Borrowed(when_value),
diff --git a/datafusion/sqllogictest/test_files/case.slt
b/datafusion/sqllogictest/test_files/case.slt
index 69f80f4593..9bc1f83ed1 100644
--- a/datafusion/sqllogictest/test_files/case.slt
+++ b/datafusion/sqllogictest/test_files/case.slt
@@ -467,6 +467,7 @@ FROM t;
----
[{foo: blarg}]
+# mix of then and else
query II
SELECT v, CASE WHEN v != 0 THEN 10/v ELSE 42 END FROM (VALUES (0), (1), (2))
t(v)
----
@@ -474,12 +475,38 @@ SELECT v, CASE WHEN v != 0 THEN 10/v ELSE 42 END FROM
(VALUES (0), (1), (2)) t(v
1 10
2 5
+# when the WHEN expression is always false, the THEN branch should never be evaluated
query II
SELECT v, CASE WHEN v < 0 THEN 10/0 ELSE 1 END FROM (VALUES (1), (2)) t(v)
----
1 1
2 1
+# when the WHEN expression is always true, the ELSE branch should never be evaluated
+query II
+SELECT v, CASE WHEN v > 0 THEN 1 ELSE 10/0 END FROM (VALUES (1), (2)) t(v)
+----
+1 1
+2 1
+
+
+# lazy evaluation of multiple when branches, else branch should never be
evaluated
+query II
+SELECT v, CASE WHEN v == 1 THEN -1 WHEN v == 2 THEN -2 WHEN v == 3 THEN -3
ELSE 10/0 END FROM (VALUES (1), (2), (3)) t(v)
+----
+1 -1
+2 -2
+3 -3
+
+# covers the InfallibleExprOrNull evaluation strategy
+query II
+SELECT v, CASE WHEN v THEN 1 END FROM (VALUES (1), (2), (3), (NULL)) t(v)
+----
+1 1
+2 1
+3 1
+NULL NULL
+
statement ok
drop table t
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]