This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new a1c794c  fix return type conflict when calling builtin math functions 
(#716)
a1c794c is described below

commit a1c794cec233f7fe34f34c7d64f529625a507669
Author: Cui Wenzheng <[email protected]>
AuthorDate: Fri Jul 16 19:27:03 2021 +0800

    fix return type conflict when calling builtin math functions (#716)
---
 datafusion/src/execution/context.rs              | 76 +++++++++++++++++++++++-
 datafusion/src/physical_plan/functions.rs        | 25 +++++---
 datafusion/src/physical_plan/math_expressions.rs |  2 +-
 3 files changed, 92 insertions(+), 11 deletions(-)

diff --git a/datafusion/src/execution/context.rs 
b/datafusion/src/execution/context.rs
index d2dcec5..d4d3a8a 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -916,9 +916,10 @@ mod tests {
         physical_plan::expressions::AvgAccumulator,
     };
     use arrow::array::{
-        Array, ArrayRef, BinaryArray, DictionaryArray, Float64Array, 
Int32Array,
-        Int64Array, LargeBinaryArray, LargeStringArray, StringArray,
-        TimestampNanosecondArray,
+        Array, ArrayRef, BinaryArray, DictionaryArray, Float32Array, 
Float64Array,
+        Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray,
+        LargeStringArray, StringArray, TimestampNanosecondArray, UInt16Array,
+        UInt32Array, UInt64Array, UInt8Array,
     };
     use arrow::compute::add;
     use arrow::datatypes::*;
@@ -2365,6 +2366,75 @@ mod tests {
     }
 
     #[tokio::test]
+    async fn case_builtin_math_expression() {
+        let mut ctx = ExecutionContext::new();
+
+        let type_values = vec![
+            (
+                DataType::Int8,
+                Arc::new(Int8Array::from(vec![1])) as ArrayRef,
+            ),
+            (
+                DataType::Int16,
+                Arc::new(Int16Array::from(vec![1])) as ArrayRef,
+            ),
+            (
+                DataType::Int32,
+                Arc::new(Int32Array::from(vec![1])) as ArrayRef,
+            ),
+            (
+                DataType::Int64,
+                Arc::new(Int64Array::from(vec![1])) as ArrayRef,
+            ),
+            (
+                DataType::UInt8,
+                Arc::new(UInt8Array::from(vec![1])) as ArrayRef,
+            ),
+            (
+                DataType::UInt16,
+                Arc::new(UInt16Array::from(vec![1])) as ArrayRef,
+            ),
+            (
+                DataType::UInt32,
+                Arc::new(UInt32Array::from(vec![1])) as ArrayRef,
+            ),
+            (
+                DataType::UInt64,
+                Arc::new(UInt64Array::from(vec![1])) as ArrayRef,
+            ),
+            (
+                DataType::Float32,
+                Arc::new(Float32Array::from(vec![1.0_f32])) as ArrayRef,
+            ),
+            (
+                DataType::Float64,
+                Arc::new(Float64Array::from(vec![1.0_f64])) as ArrayRef,
+            ),
+        ];
+
+        for (data_type, array) in type_values.iter() {
+            let schema =
+                Arc::new(Schema::new(vec![Field::new("v", data_type.clone(), 
false)]));
+            let batch =
+                RecordBatch::try_new(schema.clone(), 
vec![array.clone()]).unwrap();
+            let provider = MemTable::try_new(schema, 
vec![vec![batch]]).unwrap();
+            ctx.register_table("t", Arc::new(provider)).unwrap();
+            let expected = vec![
+                "+---------+",
+                "| sqrt(v) |",
+                "+---------+",
+                "| 1       |",
+                "+---------+",
+            ];
+            let results = plan_and_collect(&mut ctx, "SELECT sqrt(v) FROM t")
+                .await
+                .unwrap();
+
+            assert_batches_sorted_eq!(expected, &results);
+        }
+    }
+
+    #[tokio::test]
     async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> 
{
         let mut ctx = ExecutionContext::new();
         ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
diff --git a/datafusion/src/physical_plan/functions.rs 
b/datafusion/src/physical_plan/functions.rs
index 01f7e95..d856ca4 100644
--- a/datafusion/src/physical_plan/functions.rs
+++ b/datafusion/src/physical_plan/functions.rs
@@ -468,7 +468,18 @@ pub fn return_type(
         | BuiltinScalarFunction::Sin
         | BuiltinScalarFunction::Sqrt
         | BuiltinScalarFunction::Tan
-        | BuiltinScalarFunction::Trunc => Ok(DataType::Float64),
+        | BuiltinScalarFunction::Trunc => {
+            if arg_types.is_empty() {
+                return Err(DataFusionError::Internal(format!(
+                    "builtin scalar function {} does not support empty 
arguments",
+                    fun
+                )));
+            }
+            match arg_types[0] {
+                DataType::Float32 => Ok(DataType::Float32),
+                _ => Ok(DataType::Float64),
+            }
+        }
     }
 }
 
@@ -1427,8 +1438,8 @@ mod tests {
     };
     use arrow::{
         array::{
-            Array, ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray, 
Float64Array,
-            Int32Array, StringArray, UInt32Array, UInt64Array,
+            Array, ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray, 
Float32Array,
+            Float64Array, Int32Array, StringArray, UInt32Array, UInt64Array,
         },
         datatypes::Field,
         record_batch::RecordBatch,
@@ -1857,10 +1868,10 @@ mod tests {
         test_function!(
             Exp,
             &[lit(ScalarValue::Float32(Some(1.0)))],
-            Ok(Some((1.0_f32).exp() as f64)),
-            f64,
-            Float64,
-            Float64Array
+            Ok(Some((1.0_f32).exp())),
+            f32,
+            Float32,
+            Float32Array
         );
         test_function!(
             InitCap,
diff --git a/datafusion/src/physical_plan/math_expressions.rs 
b/datafusion/src/physical_plan/math_expressions.rs
index cfc239c..eabacfc 100644
--- a/datafusion/src/physical_plan/math_expressions.rs
+++ b/datafusion/src/physical_plan/math_expressions.rs
@@ -60,7 +60,7 @@ macro_rules! unary_primitive_array_op {
             },
             ColumnarValue::Scalar(a) => match a {
                 ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(
-                    ScalarValue::Float64(a.map(|x| x.$FUNC() as f64)),
+                    ScalarValue::Float32(a.map(|x| x.$FUNC())),
                 )),
                 ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(
                     ScalarValue::Float64(a.map(|x| x.$FUNC())),

Reply via email to