This is an automated email from the ASF dual-hosted git repository.

liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new da392f4b3 CaseWhen: coerce the all then and else data type to a common 
data type (#2819)
da392f4b3 is described below

commit da392f4b3d77ad5fec0018a50146746a0efabac6
Author: Kun Liu <[email protected]>
AuthorDate: Mon Jul 4 10:13:19 2022 +0800

    CaseWhen: coerce the all then and else data type to a common data type 
(#2819)
    
    * case when: coerce to the same data type for the result data type
    
    * case when: support result type coerced
    
    * change usage of literal
---
 datafusion/physical-expr/src/expressions/case.rs | 157 +++++++++++++++++++++--
 datafusion/physical-expr/src/planner.rs          |  11 +-
 2 files changed, 150 insertions(+), 18 deletions(-)

diff --git a/datafusion/physical-expr/src/expressions/case.rs 
b/datafusion/physical-expr/src/expressions/case.rs
index 6e67ba4ad..d677c5a08 100644
--- a/datafusion/physical-expr/src/expressions/case.rs
+++ b/datafusion/physical-expr/src/expressions/case.rs
@@ -25,6 +25,7 @@ use arrow::compute::{and, eq_dyn, is_null, not, or, 
or_kleene};
 use arrow::datatypes::{DataType, Schema};
 use arrow::record_batch::RecordBatch;
 use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::binary_rule::comparison_eq_coercion;
 use datafusion_expr::ColumnarValue;
 
 type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
@@ -76,7 +77,7 @@ impl CaseExpr {
     /// Create a new CASE WHEN expression
     pub fn try_new(
         expr: Option<Arc<dyn PhysicalExpr>>,
-        when_then_expr: &[WhenThen],
+        when_then_expr: Vec<WhenThen>,
         else_expr: Option<Arc<dyn PhysicalExpr>>,
     ) -> Result<Self> {
         if when_then_expr.is_empty() {
@@ -86,7 +87,7 @@ impl CaseExpr {
         } else {
             Ok(Self {
                 expr,
-                when_then_expr: when_then_expr.to_vec(),
+                when_then_expr,
                 else_expr,
             })
         }
@@ -291,12 +292,68 @@ impl PhysicalExpr for CaseExpr {
 /// Create a CASE expression
 pub fn case(
     expr: Option<Arc<dyn PhysicalExpr>>,
-    when_thens: &[WhenThen],
+    when_thens: Vec<WhenThen>,
     else_expr: Option<Arc<dyn PhysicalExpr>>,
+    input_schema: &Schema,
 ) -> Result<Arc<dyn PhysicalExpr>> {
+    // all the results of then and else should be converted to a common data type;
+    // if they can't be coerced to a common data type, return an error.
+    let coerce_type = get_case_common_type(&when_thens, else_expr.clone(), 
input_schema);
+    let (when_thens, else_expr) = match coerce_type {
+        None => Err(DataFusionError::Plan(format!(
+            "Can't get a common type for then {:?} and else {:?} expression",
+            when_thens, else_expr
+        ))),
+        Some(data_type) => {
+            // cast then expr
+            let left = when_thens
+                .into_iter()
+                .map(|(when, then)| {
+                    let then = try_cast(then, input_schema, 
data_type.clone())?;
+                    Ok((when, then))
+                })
+                .collect::<Result<Vec<_>>>()?;
+            let right = match else_expr {
+                None => None,
+                Some(expr) => Some(try_cast(expr, input_schema, 
data_type.clone())?),
+            };
+
+            Ok((left, right))
+        }
+    }?;
+
     Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
 }
 
+fn get_case_common_type(
+    when_thens: &[WhenThen],
+    else_expr: Option<Arc<dyn PhysicalExpr>>,
+    input_schema: &Schema,
+) -> Option<DataType> {
+    let thens_type = when_thens
+        .iter()
+        .map(|when_then| {
+            let data_type = &when_then.1.data_type(input_schema).unwrap();
+            data_type.clone()
+        })
+        .collect::<Vec<_>>();
+    let else_type = match else_expr {
+        None => {
+            // a case when expr must have at least one then value
+            thens_type[0].clone()
+        }
+        Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
+    };
+    thens_type
+        .iter()
+        .fold(Some(else_type), |left, right_type| match left {
+            None => None,
+            // TODO: for now, just use the `equal` coercion rule for case when; if an
+            // issue is found, refactor again.
+            Some(left_type) => comparison_eq_coercion(&left_type, right_type),
+        })
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -323,8 +380,9 @@ mod tests {
 
         let expr = case(
             Some(col("a", &schema)?),
-            &[(when1, then1), (when2, then2)],
+            vec![(when1, then1), (when2, then2)],
             None,
+            schema.as_ref(),
         )?;
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         let result = result
@@ -353,8 +411,9 @@ mod tests {
 
         let expr = case(
             Some(col("a", &schema)?),
-            &[(when1, then1), (when2, then2)],
+            vec![(when1, then1), (when2, then2)],
             Some(else_value),
+            schema.as_ref(),
         )?;
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         let result = result
@@ -387,8 +446,9 @@ mod tests {
 
         let expr = case(
             Some(col("a", &schema)?),
-            &[(when1, then1)],
+            vec![(when1, then1)],
             Some(else_value),
+            schema.as_ref(),
         )?;
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         let result = result
@@ -424,7 +484,12 @@ mod tests {
         )?;
         let then2 = lit(456i32);
 
-        let expr = case(None, &[(when1, then1), (when2, then2)], None)?;
+        let expr = case(
+            None,
+            vec![(when1, then1), (when2, then2)],
+            None,
+            schema.as_ref(),
+        )?;
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         let result = result
             .as_any()
@@ -453,7 +518,7 @@ mod tests {
         )?;
         let x = lit(ScalarValue::Float64(None));
 
-        let expr = case(None, &[(when1, then1)], Some(x))?;
+        let expr = case(None, vec![(when1, then1)], Some(x), schema.as_ref())?;
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         let result = result
             .as_any()
@@ -496,7 +561,12 @@ mod tests {
         let then2 = lit(456i32);
         let else_value = lit(999i32);
 
-        let expr = case(None, &[(when1, then1), (when2, then2)], 
Some(else_value))?;
+        let expr = case(
+            None,
+            vec![(when1, then1), (when2, then2)],
+            Some(else_value),
+            schema.as_ref(),
+        )?;
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         let result = result
             .as_any()
@@ -526,7 +596,7 @@ mod tests {
         let then = lit(123.3f64);
         let else_value = lit(999i32);
 
-        let expr = case(None, &[(when, then)], Some(else_value))?;
+        let expr = case(None, vec![(when, then)], Some(else_value), 
schema.as_ref())?;
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         let result = result
             .as_any()
@@ -555,7 +625,7 @@ mod tests {
         )?;
         let then = col("load4", &schema)?;
 
-        let expr = case(None, &[(when, then)], None)?;
+        let expr = case(None, vec![(when, then)], None, schema.as_ref())?;
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         let result = result
             .as_any()
@@ -580,7 +650,7 @@ mod tests {
         let when = lit(1.77f64);
         let then = col("load4", &schema)?;
 
-        let expr = case(Some(expr), &[(when, then)], None)?;
+        let expr = case(Some(expr), vec![(when, then)], None, 
schema.as_ref())?;
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         let result = result
             .as_any()
@@ -630,4 +700,67 @@ mod tests {
             RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as 
ArrayRef)])?;
         Ok(batch)
     }
+
+    #[test]
+    fn case_test_incompatible() -> Result<()> {
+        // then 1 is int32
+        // then 2 is boolean
+        let batch = case_test_batch()?;
+        let schema = batch.schema();
+
+        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END
+        let when1 = binary(
+            col("a", &schema)?,
+            Operator::Eq,
+            lit("foo"),
+            &batch.schema(),
+        )?;
+        let then1 = lit(123i32);
+        let when2 = binary(
+            col("a", &schema)?,
+            Operator::Eq,
+            lit("bar"),
+            &batch.schema(),
+        )?;
+        let then2 = lit(true);
+
+        let expr = case(
+            None,
+            vec![(when1, then1), (when2, then2)],
+            None,
+            schema.as_ref(),
+        );
+        assert!(expr.is_err());
+
+        // then 1 is int32
+        // then 2 is int64
+        // else is float
+        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END
+        let when1 = binary(
+            col("a", &schema)?,
+            Operator::Eq,
+            lit("foo"),
+            &batch.schema(),
+        )?;
+        let then1 = lit(123i32);
+        let when2 = binary(
+            col("a", &schema)?,
+            Operator::Eq,
+            lit("bar"),
+            &batch.schema(),
+        )?;
+        let then2 = lit(456i64);
+        let else_expr = lit(1.23f64);
+
+        let expr = case(
+            None,
+            vec![(when1, then1), (when2, then2)],
+            Some(else_expr),
+            schema.as_ref(),
+        );
+        assert!(expr.is_ok());
+        let result_type = expr.unwrap().data_type(schema.as_ref())?;
+        assert_eq!(DataType::Float64, result_type);
+        Ok(())
+    }
 }
diff --git a/datafusion/physical-expr/src/planner.rs 
b/datafusion/physical-expr/src/planner.rs
index d8a7a3004..1eb975bed 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -18,9 +18,7 @@
 use crate::expressions::try_cast;
 use crate::{
     execution_props::ExecutionProps,
-    expressions::{
-        self, binary, CaseExpr, Column, DateIntervalExpr, GetIndexedFieldExpr, 
Literal,
-    },
+    expressions::{self, binary, Column, DateIntervalExpr, GetIndexedFieldExpr, 
Literal},
     functions, udf,
     var_provider::VarType,
     PhysicalExpr,
@@ -162,11 +160,12 @@ pub fn create_physical_expr(
             } else {
                 None
             };
-            Ok(Arc::new(CaseExpr::try_new(
+            Ok(expressions::case(
                 expr,
-                &when_then_expr,
+                when_then_expr,
                 else_expr,
-            )?))
+                input_schema,
+            )?)
         }
         Expr::Cast { expr, data_type } => expressions::cast(
             create_physical_expr(expr, input_dfschema, input_schema, 
execution_props)?,

Reply via email to