This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new a1c794c fix return type conflict when calling builtin math functions (#716)
a1c794c is described below
commit a1c794cec233f7fe34f34c7d64f529625a507669
Author: Cui Wenzheng <[email protected]>
AuthorDate: Fri Jul 16 19:27:03 2021 +0800
fix return type conflict when calling builtin math functions (#716)
---
datafusion/src/execution/context.rs | 76 +++++++++++++++++++++++-
datafusion/src/physical_plan/functions.rs | 25 +++++---
datafusion/src/physical_plan/math_expressions.rs | 2 +-
3 files changed, 92 insertions(+), 11 deletions(-)
diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index d2dcec5..d4d3a8a 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -916,9 +916,10 @@ mod tests {
physical_plan::expressions::AvgAccumulator,
};
use arrow::array::{
- Array, ArrayRef, BinaryArray, DictionaryArray, Float64Array,
Int32Array,
- Int64Array, LargeBinaryArray, LargeStringArray, StringArray,
- TimestampNanosecondArray,
+ Array, ArrayRef, BinaryArray, DictionaryArray, Float32Array,
Float64Array,
+ Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray,
+ LargeStringArray, StringArray, TimestampNanosecondArray, UInt16Array,
+ UInt32Array, UInt64Array, UInt8Array,
};
use arrow::compute::add;
use arrow::datatypes::*;
@@ -2365,6 +2366,75 @@ mod tests {
}
#[tokio::test]
+ async fn case_builtin_math_expression() {
+ let mut ctx = ExecutionContext::new();
+
+ let type_values = vec![
+ (
+ DataType::Int8,
+ Arc::new(Int8Array::from(vec![1])) as ArrayRef,
+ ),
+ (
+ DataType::Int16,
+ Arc::new(Int16Array::from(vec![1])) as ArrayRef,
+ ),
+ (
+ DataType::Int32,
+ Arc::new(Int32Array::from(vec![1])) as ArrayRef,
+ ),
+ (
+ DataType::Int64,
+ Arc::new(Int64Array::from(vec![1])) as ArrayRef,
+ ),
+ (
+ DataType::UInt8,
+ Arc::new(UInt8Array::from(vec![1])) as ArrayRef,
+ ),
+ (
+ DataType::UInt16,
+ Arc::new(UInt16Array::from(vec![1])) as ArrayRef,
+ ),
+ (
+ DataType::UInt32,
+ Arc::new(UInt32Array::from(vec![1])) as ArrayRef,
+ ),
+ (
+ DataType::UInt64,
+ Arc::new(UInt64Array::from(vec![1])) as ArrayRef,
+ ),
+ (
+ DataType::Float32,
+ Arc::new(Float32Array::from(vec![1.0_f32])) as ArrayRef,
+ ),
+ (
+ DataType::Float64,
+ Arc::new(Float64Array::from(vec![1.0_f64])) as ArrayRef,
+ ),
+ ];
+
+ for (data_type, array) in type_values.iter() {
+ let schema =
+ Arc::new(Schema::new(vec![Field::new("v", data_type.clone(),
false)]));
+ let batch =
+ RecordBatch::try_new(schema.clone(),
vec![array.clone()]).unwrap();
+ let provider = MemTable::try_new(schema,
vec![vec![batch]]).unwrap();
+ ctx.register_table("t", Arc::new(provider)).unwrap();
+ let expected = vec![
+ "+---------+",
+ "| sqrt(v) |",
+ "+---------+",
+ "| 1 |",
+ "+---------+",
+ ];
+ let results = plan_and_collect(&mut ctx, "SELECT sqrt(v) FROM t")
+ .await
+ .unwrap();
+
+ assert_batches_sorted_eq!(expected, &results);
+ }
+ }
+
+ #[tokio::test]
async fn case_sensitive_identifiers_user_defined_functions() -> Result<()>
{
let mut ctx = ExecutionContext::new();
ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs
index 01f7e95..d856ca4 100644
--- a/datafusion/src/physical_plan/functions.rs
+++ b/datafusion/src/physical_plan/functions.rs
@@ -468,7 +468,18 @@ pub fn return_type(
| BuiltinScalarFunction::Sin
| BuiltinScalarFunction::Sqrt
| BuiltinScalarFunction::Tan
- | BuiltinScalarFunction::Trunc => Ok(DataType::Float64),
+ | BuiltinScalarFunction::Trunc => {
+ if arg_types.is_empty() {
+ return Err(DataFusionError::Internal(format!(
+ "builtin scalar function {} does not support empty
arguments",
+ fun
+ )));
+ }
+ match arg_types[0] {
+ DataType::Float32 => Ok(DataType::Float32),
+ _ => Ok(DataType::Float64),
+ }
+ }
}
}
@@ -1427,8 +1438,8 @@ mod tests {
};
use arrow::{
array::{
- Array, ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray,
Float64Array,
- Int32Array, StringArray, UInt32Array, UInt64Array,
+ Array, ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray,
Float32Array,
+ Float64Array, Int32Array, StringArray, UInt32Array, UInt64Array,
},
datatypes::Field,
record_batch::RecordBatch,
@@ -1857,10 +1868,10 @@ mod tests {
test_function!(
Exp,
&[lit(ScalarValue::Float32(Some(1.0)))],
- Ok(Some((1.0_f32).exp() as f64)),
- f64,
- Float64,
- Float64Array
+ Ok(Some((1.0_f32).exp())),
+ f32,
+ Float32,
+ Float32Array
);
test_function!(
InitCap,
diff --git a/datafusion/src/physical_plan/math_expressions.rs b/datafusion/src/physical_plan/math_expressions.rs
index cfc239c..eabacfc 100644
--- a/datafusion/src/physical_plan/math_expressions.rs
+++ b/datafusion/src/physical_plan/math_expressions.rs
@@ -60,7 +60,7 @@ macro_rules! unary_primitive_array_op {
},
ColumnarValue::Scalar(a) => match a {
ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(
- ScalarValue::Float64(a.map(|x| x.$FUNC() as f64)),
+ ScalarValue::Float32(a.map(|x| x.$FUNC())),
)),
ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(
ScalarValue::Float64(a.map(|x| x.$FUNC())),