This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 8ebc94c  fix: casting Int64 to Float64 unsuccessfully caused tpch8 to fail (#1601)
8ebc94c is described below

commit 8ebc94c42e301e00cb77e36daa1e7b17a74a2b5b
Author: xudong.w <[email protected]>
AuthorDate: Wed Jan 19 00:01:00 2022 +0800

    fix: casting Int64 to Float64 unsuccessfully caused tpch8 to fail (#1601)
---
 datafusion/src/physical_plan/expressions/case.rs | 39 ++++++++++++++++++++++--
 1 file changed, 37 insertions(+), 2 deletions(-)

diff --git a/datafusion/src/physical_plan/expressions/case.rs b/datafusion/src/physical_plan/expressions/case.rs
index 551d87a..2a680d3 100644
--- a/datafusion/src/physical_plan/expressions/case.rs
+++ b/datafusion/src/physical_plan/expressions/case.rs
@@ -18,6 +18,7 @@
 use std::{any::Any, sync::Arc};
 
 use crate::error::{DataFusionError, Result};
+use crate::physical_plan::expressions::try_cast;
 use crate::physical_plan::{ColumnarValue, PhysicalExpr};
 use arrow::array::{self, *};
 use arrow::compute::{eq, eq_utf8};
@@ -324,7 +325,10 @@ impl CaseExpr {
 
         // start with the else condition, or nulls
         let mut current_value: Option<ArrayRef> = if let Some(e) = 
&self.else_expr {
-            Some(e.evaluate(batch)?.into_array(batch.num_rows()))
+            // keep `else_expr`'s data type and return type consistent
+            let expr = try_cast(e.clone(), &*batch.schema(), 
return_type.clone())
+                .unwrap_or_else(|_| e.clone());
+            Some(expr.evaluate(batch)?.into_array(batch.num_rows()))
         } else {
             Some(new_null_array(&return_type, batch.num_rows()))
         };
@@ -365,7 +369,9 @@ impl CaseExpr {
 
         // start with the else condition, or nulls
         let mut current_value: Option<ArrayRef> = if let Some(e) = 
&self.else_expr {
-            Some(e.evaluate(batch)?.into_array(batch.num_rows()))
+            let expr = try_cast(e.clone(), &*batch.schema(), 
return_type.clone())
+                .unwrap_or_else(|_| e.clone());
+            Some(expr.evaluate(batch)?.into_array(batch.num_rows()))
         } else {
             Some(new_null_array(&return_type, batch.num_rows()))
         };
@@ -589,6 +595,35 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn case_with_type_cast() -> Result<()> {
+        let batch = case_test_batch()?;
+        let schema = batch.schema();
+
+        // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END
+        let when = binary(
+            col("a", &schema)?,
+            Operator::Eq,
+            lit(ScalarValue::Utf8(Some("foo".to_string()))),
+            &batch.schema(),
+        )?;
+        let then = lit(ScalarValue::Float64(Some(123.3)));
+        let else_value = lit(ScalarValue::Int32(Some(999)));
+
+        let expr = case(None, &[(when, then)], Some(else_value))?;
+        let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
+        let result = result
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .expect("failed to downcast to Float64Array");
+
+        let expected =
+            &Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), 
Some(999.0)]);
+
+        assert_eq!(expected, result);
+
+        Ok(())
+    }
     fn case_test_batch() -> Result<RecordBatch> {
         let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
         let a = StringArray::from(vec![Some("foo"), Some("baz"), None, 
Some("bar")]);

Reply via email to