This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 0f4b8b136c Optimize CASE expression for "expr or expr" usage. (#13953)
0f4b8b136c is described below
commit 0f4b8b136ceb9132fd6b6595bd6a09a09707f5d9
Author: Andre Weltsch <[email protected]>
AuthorDate: Sat Jan 4 15:51:25 2025 +0100
Optimize CASE expression for "expr or expr" usage. (#13953)
* Apply optimization for ExprOrExpr.
* Implement optimization similar to existing code.
* Add sqllogictest.
---
datafusion/physical-expr/src/expressions/case.rs | 84 ++++++++++++++++++++++++
datafusion/sqllogictest/test_files/case.slt | 11 ++++
2 files changed, 95 insertions(+)
diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs
index 711a521da1..ee19a8c9dd 100644
--- a/datafusion/physical-expr/src/expressions/case.rs
+++ b/datafusion/physical-expr/src/expressions/case.rs
@@ -60,6 +60,11 @@ enum EvalMethod {
/// are literal values
/// CASE WHEN condition THEN literal ELSE literal END
ScalarOrScalar,
+ /// This is a specialization for a specific use case where we can take a fast path
+ /// if there is just one when/then pair and both the `then` and `else` are expressions
+ ///
+ /// CASE WHEN condition THEN expression ELSE expression END
+ ExpressionOrExpression,
}
/// The CASE expression is similar to a series of nested if/else and there are two forms that
@@ -149,6 +154,8 @@ impl CaseExpr {
&& else_expr.as_ref().unwrap().as_any().is::<Literal>()
{
EvalMethod::ScalarOrScalar
+ } else if when_then_expr.len() == 1 && else_expr.is_some() {
+ EvalMethod::ExpressionOrExpression
} else {
EvalMethod::NoExpression
};
@@ -394,6 +401,43 @@ impl CaseExpr {
Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?))
}
+
+ fn expr_or_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
+ let return_type = self.data_type(&batch.schema())?;
+
+ // evaluate when condition on batch
+ let when_value = self.when_then_expr[0].0.evaluate(batch)?;
+ let when_value = when_value.into_array(batch.num_rows())?;
+ let when_value = as_boolean_array(&when_value).map_err(|e| {
+ DataFusionError::Context(
+ "WHEN expression did not return a BooleanArray".to_string(),
+ Box::new(e),
+ )
+ })?;
+
+ // Treat 'NULL' as false value
+ let when_value = match when_value.null_count() {
+ 0 => Cow::Borrowed(when_value),
+ _ => Cow::Owned(prep_null_mask_filter(when_value)),
+ };
+
+ let then_value = self.when_then_expr[0]
+ .1
+ .evaluate_selection(batch, &when_value)?
+ .into_array(batch.num_rows())?;
+
+ // evaluate else expression on the values not covered by when_value
+ let remainder = not(&when_value)?;
+ let e = self.else_expr.as_ref().unwrap();
+ // keep `else_expr`'s data type and return type consistent
+ let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
+ .unwrap_or_else(|_| Arc::clone(e));
+ let else_ = expr
+ .evaluate_selection(batch, &remainder)?
+ .into_array(batch.num_rows())?;
+
+ Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?))
+ }
}
impl PhysicalExpr for CaseExpr {
@@ -457,6 +501,7 @@ impl PhysicalExpr for CaseExpr {
self.case_column_or_null(batch)
}
EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch),
+ EvalMethod::ExpressionOrExpression => self.expr_or_expr(batch),
}
}
@@ -1174,6 +1219,45 @@ mod tests {
Ok(())
}
+ #[test]
+ fn test_expr_or_expr_specialization() -> Result<()> {
+ let batch = case_test_batch1()?;
+ let schema = batch.schema();
+ let when = binary(
+ col("a", &schema)?,
+ Operator::LtEq,
+ lit(2i32),
+ &batch.schema(),
+ )?;
+ let then = binary(
+ col("a", &schema)?,
+ Operator::Plus,
+ lit(1i32),
+ &batch.schema(),
+ )?;
+ let else_expr = binary(
+ col("a", &schema)?,
+ Operator::Minus,
+ lit(1i32),
+ &batch.schema(),
+ )?;
+ let expr = CaseExpr::try_new(None, vec![(when, then)], Some(else_expr))?;
+ assert!(matches!(
+ expr.eval_method,
+ EvalMethod::ExpressionOrExpression
+ ));
+ let result = expr
+ .evaluate(&batch)?
+ .into_array(batch.num_rows())
+ .expect("Failed to convert to array");
+ let result = as_int32_array(&result).expect("failed to downcast to Int32Array");
+
+ let expected = &Int32Array::from(vec![Some(2), Some(1), None, Some(4)]);
+
+ assert_eq!(expected, result);
+ Ok(())
+ }
+
fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
Arc::new(Column::new(name, index))
}
diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt
index 4f3320931d..6b4dffd12c 100644
--- a/datafusion/sqllogictest/test_files/case.slt
+++ b/datafusion/sqllogictest/test_files/case.slt
@@ -224,3 +224,14 @@ query I
SELECT CASE arrow_cast([1,2,3], 'FixedSizeList(3, Int64)') WHEN arrow_cast([1,2,3], 'FixedSizeList(3, Int64)') THEN 1 ELSE 0 END;
----
1
+
+# CASE WHEN with single predicate and two non-trivial branches (expr or expr usage)
+query I
+SELECT CASE WHEN a < 5 THEN a + b ELSE b - NVL(a, 0) END FROM foo
+----
+3
+7
+1
+NULL
+NULL
+7
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]