This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 8ebc94c fix: casting Int64 to Float64 unsuccessfully caused tpch8 to fail (#1601)
8ebc94c is described below
commit 8ebc94c42e301e00cb77e36daa1e7b17a74a2b5b
Author: xudong.w <[email protected]>
AuthorDate: Wed Jan 19 00:01:00 2022 +0800
fix: casting Int64 to Float64 unsuccessfully caused tpch8 to fail (#1601)
---
datafusion/src/physical_plan/expressions/case.rs | 39 ++++++++++++++++++++++--
1 file changed, 37 insertions(+), 2 deletions(-)
diff --git a/datafusion/src/physical_plan/expressions/case.rs b/datafusion/src/physical_plan/expressions/case.rs
index 551d87a..2a680d3 100644
--- a/datafusion/src/physical_plan/expressions/case.rs
+++ b/datafusion/src/physical_plan/expressions/case.rs
@@ -18,6 +18,7 @@
use std::{any::Any, sync::Arc};
use crate::error::{DataFusionError, Result};
+use crate::physical_plan::expressions::try_cast;
use crate::physical_plan::{ColumnarValue, PhysicalExpr};
use arrow::array::{self, *};
use arrow::compute::{eq, eq_utf8};
@@ -324,7 +325,10 @@ impl CaseExpr {
// start with the else condition, or nulls
let mut current_value: Option<ArrayRef> = if let Some(e) = &self.else_expr {
- Some(e.evaluate(batch)?.into_array(batch.num_rows()))
+ // keep `else_expr`'s data type and return type consistent
+ let expr = try_cast(e.clone(), &*batch.schema(), return_type.clone())
+ .unwrap_or_else(|_| e.clone());
+ Some(expr.evaluate(batch)?.into_array(batch.num_rows()))
} else {
Some(new_null_array(&return_type, batch.num_rows()))
};
@@ -365,7 +369,9 @@ impl CaseExpr {
// start with the else condition, or nulls
let mut current_value: Option<ArrayRef> = if let Some(e) = &self.else_expr {
- Some(e.evaluate(batch)?.into_array(batch.num_rows()))
+ let expr = try_cast(e.clone(), &*batch.schema(), return_type.clone())
+ .unwrap_or_else(|_| e.clone());
+ Some(expr.evaluate(batch)?.into_array(batch.num_rows()))
} else {
Some(new_null_array(&return_type, batch.num_rows()))
};
@@ -589,6 +595,35 @@ mod tests {
Ok(())
}
+ #[test]
+ fn case_with_type_cast() -> Result<()> {
+ let batch = case_test_batch()?;
+ let schema = batch.schema();
+
+ // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END
+ let when = binary(
+ col("a", &schema)?,
+ Operator::Eq,
+ lit(ScalarValue::Utf8(Some("foo".to_string()))),
+ &batch.schema(),
+ )?;
+ let then = lit(ScalarValue::Float64(Some(123.3)));
+ let else_value = lit(ScalarValue::Int32(Some(999)));
+
+ let expr = case(None, &[(when, then)], Some(else_value))?;
+ let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
+ let result = result
+ .as_any()
+ .downcast_ref::<Float64Array>()
+ .expect("failed to downcast to Float64Array");
+
+ let expected =
+ &Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]);
+
+ assert_eq!(expected, result);
+
+ Ok(())
+ }
fn case_test_batch() -> Result<RecordBatch> {
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);