This is an automated email from the ASF dual-hosted git repository.
liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new da392f4b3 CaseWhen: coerce all THEN and ELSE result types to a common
data type (#2819)
da392f4b3 is described below
commit da392f4b3d77ad5fec0018a50146746a0efabac6
Author: Kun Liu <[email protected]>
AuthorDate: Mon Jul 4 10:13:19 2022 +0800
CaseWhen: coerce all THEN and ELSE result types to a common data type
(#2819)
* case when: coerce to the same data type for the result data type
* case when: support result type coerced
* change usage of literal
---
datafusion/physical-expr/src/expressions/case.rs | 157 +++++++++++++++++++++--
datafusion/physical-expr/src/planner.rs | 11 +-
2 files changed, 150 insertions(+), 18 deletions(-)
diff --git a/datafusion/physical-expr/src/expressions/case.rs
b/datafusion/physical-expr/src/expressions/case.rs
index 6e67ba4ad..d677c5a08 100644
--- a/datafusion/physical-expr/src/expressions/case.rs
+++ b/datafusion/physical-expr/src/expressions/case.rs
@@ -25,6 +25,7 @@ use arrow::compute::{and, eq_dyn, is_null, not, or,
or_kleene};
use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::binary_rule::comparison_eq_coercion;
use datafusion_expr::ColumnarValue;
type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
@@ -76,7 +77,7 @@ impl CaseExpr {
/// Create a new CASE WHEN expression
pub fn try_new(
expr: Option<Arc<dyn PhysicalExpr>>,
- when_then_expr: &[WhenThen],
+ when_then_expr: Vec<WhenThen>,
else_expr: Option<Arc<dyn PhysicalExpr>>,
) -> Result<Self> {
if when_then_expr.is_empty() {
@@ -86,7 +87,7 @@ impl CaseExpr {
} else {
Ok(Self {
expr,
- when_then_expr: when_then_expr.to_vec(),
+ when_then_expr,
else_expr,
})
}
@@ -291,12 +292,68 @@ impl PhysicalExpr for CaseExpr {
/// Create a CASE expression
+///
+/// All THEN results and the optional ELSE result are cast to a single common
+/// data type (see `get_case_common_type`) before the `CaseExpr` is built.
+///
+/// # Errors
+/// Returns a `DataFusionError::Plan` if the results have no common type, and
+/// propagates errors from type resolution, casting, or `CaseExpr::try_new`
+/// (which rejects an empty WHEN/THEN list).
pub fn case(
expr: Option<Arc<dyn PhysicalExpr>>,
- when_thens: &[WhenThen],
+    when_thens: Vec<WhenThen>,
else_expr: Option<Arc<dyn PhysicalExpr>>,
+    input_schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>> {
+    // An empty WHEN/THEN list has no result type to coerce; let
+    // `CaseExpr::try_new` report it instead of computing a common type first.
+    if when_thens.is_empty() {
+        return Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?));
+    }
+
+    // All THEN and ELSE results must be converted to one common data type;
+    // if no such type exists, the CASE expression is rejected.
+    let common_type =
+        get_case_common_type(&when_thens, else_expr.as_ref(), input_schema)?;
+    let data_type = match common_type {
+        Some(data_type) => data_type,
+        None => {
+            return Err(DataFusionError::Plan(format!(
+                "Can't get a common type for then {:?} and else {:?} expression",
+                when_thens, else_expr
+            )));
+        }
+    };
+
+    // Cast every THEN result to the common type.
+    let when_thens = when_thens
+        .into_iter()
+        .map(|(when, then)| {
+            try_cast(then, input_schema, data_type.clone()).map(|then| (when, then))
+        })
+        .collect::<Result<Vec<_>>>()?;
+    // Cast the ELSE result, if present, to the common type.
+    let else_expr = else_expr
+        .map(|otherwise| try_cast(otherwise, input_schema, data_type))
+        .transpose()?;
+
Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
}
+
+/// Computes the data type that every THEN result and the optional ELSE result
+/// of a CASE expression can be coerced to.
+///
+/// Types are combined pairwise with `comparison_eq_coercion`. The fold is seeded
+/// with the ELSE type when present, otherwise with the first THEN type, and then
+/// visits every THEN type in order. Returns `Ok(None)` when some pair has no
+/// common type, or when there are neither THEN nor ELSE results.
+///
+/// # Errors
+/// Propagates any error from resolving an expression's data type against
+/// `input_schema`.
+fn get_case_common_type(
+    when_thens: &[WhenThen],
+    else_expr: Option<&Arc<dyn PhysicalExpr>>,
+    input_schema: &Schema,
+) -> Result<Option<DataType>> {
+    let then_types = when_thens
+        .iter()
+        .map(|(_, then)| then.data_type(input_schema))
+        .collect::<Result<Vec<_>>>()?;
+    let else_type = else_expr
+        .map(|otherwise| otherwise.data_type(input_schema))
+        .transpose()?;
+
+    let seed = match else_type.or_else(|| then_types.first().cloned()) {
+        Some(seed) => seed,
+        // No THEN and no ELSE results: there is nothing to coerce.
+        None => return Ok(None),
+    };
+    // TODO: only the `equal` comparison coercion rule is used for CASE results
+    // for now; revisit if it proves too narrow.
+    Ok(then_types
+        .iter()
+        .try_fold(seed, |acc, then_type| comparison_eq_coercion(&acc, then_type)))
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -323,8 +380,9 @@ mod tests {
let expr = case(
Some(col("a", &schema)?),
- &[(when1, then1), (when2, then2)],
+ vec![(when1, then1), (when2, then2)],
None,
+ schema.as_ref(),
)?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
@@ -353,8 +411,9 @@ mod tests {
let expr = case(
Some(col("a", &schema)?),
- &[(when1, then1), (when2, then2)],
+ vec![(when1, then1), (when2, then2)],
Some(else_value),
+ schema.as_ref(),
)?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
@@ -387,8 +446,9 @@ mod tests {
let expr = case(
Some(col("a", &schema)?),
- &[(when1, then1)],
+ vec![(when1, then1)],
Some(else_value),
+ schema.as_ref(),
)?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
@@ -424,7 +484,12 @@ mod tests {
)?;
let then2 = lit(456i32);
- let expr = case(None, &[(when1, then1), (when2, then2)], None)?;
+ let expr = case(
+ None,
+ vec![(when1, then1), (when2, then2)],
+ None,
+ schema.as_ref(),
+ )?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
@@ -453,7 +518,7 @@ mod tests {
)?;
let x = lit(ScalarValue::Float64(None));
- let expr = case(None, &[(when1, then1)], Some(x))?;
+ let expr = case(None, vec![(when1, then1)], Some(x), schema.as_ref())?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
@@ -496,7 +561,12 @@ mod tests {
let then2 = lit(456i32);
let else_value = lit(999i32);
- let expr = case(None, &[(when1, then1), (when2, then2)],
Some(else_value))?;
+ let expr = case(
+ None,
+ vec![(when1, then1), (when2, then2)],
+ Some(else_value),
+ schema.as_ref(),
+ )?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
@@ -526,7 +596,7 @@ mod tests {
let then = lit(123.3f64);
let else_value = lit(999i32);
- let expr = case(None, &[(when, then)], Some(else_value))?;
+ let expr = case(None, vec![(when, then)], Some(else_value),
schema.as_ref())?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
@@ -555,7 +625,7 @@ mod tests {
)?;
let then = col("load4", &schema)?;
- let expr = case(None, &[(when, then)], None)?;
+ let expr = case(None, vec![(when, then)], None, schema.as_ref())?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
@@ -580,7 +650,7 @@ mod tests {
let when = lit(1.77f64);
let then = col("load4", &schema)?;
- let expr = case(Some(expr), &[(when, then)], None)?;
+ let expr = case(Some(expr), vec![(when, then)], None,
schema.as_ref())?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
@@ -630,4 +700,67 @@ mod tests {
RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as
ArrayRef)])?;
Ok(batch)
}
+
+    #[test]
+    fn case_test_incompatible() -> Result<()> {
+        let batch = case_test_batch()?;
+        let schema = batch.schema();
+
+        // THEN results are Int32 and Boolean, which have no common type,
+        // so building the CASE expression must fail.
+        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END
+        let when1 = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
+        let then1 = lit(123i32);
+        let when2 = binary(col("a", &schema)?, Operator::Eq, lit("bar"), &schema)?;
+        let then2 = lit(true);
+
+        let expr = case(
+            None,
+            vec![(when1, then1), (when2, then2)],
+            None,
+            schema.as_ref(),
+        );
+        assert!(expr.is_err());
+
+        // THEN results are Int32 and Int64 and the ELSE result is Float64,
+        // so every result is coerced to Float64.
+        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END
+        let when1 = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
+        let then1 = lit(123i32);
+        let when2 = binary(col("a", &schema)?, Operator::Eq, lit("bar"), &schema)?;
+        let then2 = lit(456i64);
+        let else_expr = lit(1.23f64);
+
+        let expr = case(
+            None,
+            vec![(when1, then1), (when2, then2)],
+            Some(else_expr),
+            schema.as_ref(),
+        )?;
+        assert_eq!(DataType::Float64, expr.data_type(schema.as_ref())?);
+        Ok(())
+    }
}
diff --git a/datafusion/physical-expr/src/planner.rs
b/datafusion/physical-expr/src/planner.rs
index d8a7a3004..1eb975bed 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -18,9 +18,7 @@
use crate::expressions::try_cast;
use crate::{
execution_props::ExecutionProps,
- expressions::{
- self, binary, CaseExpr, Column, DateIntervalExpr, GetIndexedFieldExpr,
Literal,
- },
+ expressions::{self, binary, Column, DateIntervalExpr, GetIndexedFieldExpr,
Literal},
functions, udf,
var_provider::VarType,
PhysicalExpr,
@@ -162,11 +160,12 @@ pub fn create_physical_expr(
} else {
None
};
- Ok(Arc::new(CaseExpr::try_new(
+ Ok(expressions::case(
expr,
- &when_then_expr,
+ when_then_expr,
else_expr,
- )?))
+ input_schema,
+ )?)
}
Expr::Cast { expr, data_type } => expressions::cast(
create_physical_expr(expr, input_dfschema, input_schema,
execution_props)?,