This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new af3d19038e Fix ApproxPercentileCont signature (#8825)
af3d19038e is described below

commit af3d19038ef40e45c70c6e9a28798ee53795dfe9
Author: Georgi Krastev <[email protected]>
AuthorDate: Sun Jan 14 20:45:18 2024 +0200

    Fix ApproxPercentileCont signature (#8825)
    
    * Fix ApproxPercentileCont signature
    
    The number of centroids must be an integer in `coerce_types`.
    Reflect that in the type signature.
    
    * Add a unit test for percentile signature error message
---
 datafusion/expr/src/aggregate_function.rs          | 27 +++++++++++++---------
 datafusion/expr/src/type_coercion/aggregates.rs    | 22 +++++-------------
 datafusion/optimizer/src/analyzer/type_coercion.rs | 26 ++++++++++++++++++++-
 datafusion/sqllogictest/test_files/aggregate.slt   |  3 +++
 4 files changed, 50 insertions(+), 28 deletions(-)

diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs
index cea72c3cb5..9db7635d99 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -386,18 +386,23 @@ impl AggregateFunction {
                 Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
             }
             AggregateFunction::ApproxPercentileCont => {
+                let mut variants =
+                    Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
                 // Accept any numeric value paired with a float64 percentile
-                let with_tdigest_size = NUMERICS.iter().map(|t| {
-                    TypeSignature::Exact(vec![t.clone(), DataType::Float64, t.clone()])
-                });
-                Signature::one_of(
-                    NUMERICS
-                        .iter()
-                        .map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64]))
-                        .chain(with_tdigest_size)
-                        .collect(),
-                    Volatility::Immutable,
-                )
+                for num in NUMERICS {
+                    variants
+                        .push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
+                    // Additionally accept an integer number of centroids for T-Digest
+                    for int in INTEGERS {
+                        variants.push(TypeSignature::Exact(vec![
+                            num.clone(),
+                            DataType::Float64,
+                            int.clone(),
+                        ]))
+                    }
+                }
+
+                Signature::one_of(variants, Volatility::Immutable)
             }
             AggregateFunction::ApproxPercentileContWithWeight => Signature::one_of(
                 // Accept any numeric value paired with a float64 percentile
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs
index 7128b57597..56bb5c9b69 100644
--- a/datafusion/expr/src/type_coercion/aggregates.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -223,7 +223,7 @@ pub fn coerce_types(
         | AggregateFunction::RegrSXX
         | AggregateFunction::RegrSYY
         | AggregateFunction::RegrSXY => {
-            let valid_types = [NUMERICS.to_vec(), vec![DataType::Null]].concat();
+            let valid_types = [NUMERICS.to_vec(), vec![Null]].concat();
             let input_types_valid = // number of input already checked before
                 valid_types.contains(&input_types[0]) && valid_types.contains(&input_types[1]);
             if !input_types_valid {
@@ -243,15 +243,15 @@ pub fn coerce_types(
                     input_types[0]
                 );
             }
-            if input_types.len() == 3 && !is_integer_arg_type(&input_types[2]) {
+            if input_types.len() == 3 && !input_types[2].is_integer() {
                 return plan_err!(
                         "The percentile sample points count for {:?} must be integer, not {:?}.",
                         agg_fun, input_types[2]
                     );
             }
             let mut result = input_types.to_vec();
-            if can_coerce_from(&DataType::Float64, &input_types[1]) {
-                result[1] = DataType::Float64;
+            if can_coerce_from(&Float64, &input_types[1]) {
+                result[1] = Float64;
             } else {
                 return plan_err!(
                     "Could not coerce the percent argument for {:?} to Float64. Was  {:?}.",
@@ -275,7 +275,7 @@ pub fn coerce_types(
                     input_types[1]
                 );
             }
-            if !matches!(input_types[2], DataType::Float64) {
+            if !matches!(input_types[2], Float64) {
                 return plan_err!(
                     "The percentile argument for {:?} must be Float64, not {:?}.",
                     agg_fun,
@@ -560,17 +560,7 @@ pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool {
 }
 
 pub fn is_integer_arg_type(arg_type: &DataType) -> bool {
-    matches!(
-        arg_type,
-        DataType::UInt8
-            | DataType::UInt16
-            | DataType::UInt32
-            | DataType::UInt64
-            | DataType::Int8
-            | DataType::Int16
-            | DataType::Int32
-            | DataType::Int64
-    )
+    arg_type.is_integer()
 }
 
 /// Return `true` if `arg_type` is of a [`DataType`] that the
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 3821279fed..8c4e907e67 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -963,7 +963,7 @@ mod test {
     }
 
     #[test]
-    fn agg_function_invalid_input() -> Result<()> {
+    fn agg_function_invalid_input_avg() -> Result<()> {
         let empty = empty();
         let fun: AggregateFunction = AggregateFunction::Avg;
         let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new(
@@ -984,6 +984,30 @@ mod test {
         Ok(())
     }
 
+    #[test]
+    fn agg_function_invalid_input_percentile() {
+        let empty = empty();
+        let fun: AggregateFunction = AggregateFunction::ApproxPercentileCont;
+        let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new(
+            fun,
+            vec![lit(0.95), lit(42.0), lit(100.0)],
+            false,
+            None,
+            None,
+        ));
+
+        let err = Projection::try_new(vec![agg_expr], empty)
+            .err()
+            .unwrap()
+            .strip_backtrace();
+
+        let prefix = "Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT(Float64, Float64, Float64)'. You might need to add explicit type casts.\n\tCandidate functions:";
+        assert!(!err
+            .strip_prefix(prefix)
+            .unwrap()
+            .contains("APPROX_PERCENTILE_CONT(Float64, Float64, Float64)"));
+    }
+
     #[test]
     fn binary_op_date32_op_interval() -> Result<()> {
         //CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640")
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index aa512f6e26..50cdebd054 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -95,6 +95,9 @@ SELECT approx_percentile_cont(c3, 0.95, c1) FROM aggregate_test_100
 statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT\(Int16, Float64, Float64\)'\. You might need to add explicit type casts\.
 SELECT approx_percentile_cont(c3, 0.95, 111.1) FROM aggregate_test_100
 
+statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT\(Float64, Float64, Float64\)'\. You might need to add explicit type casts\.
+SELECT approx_percentile_cont(c12, 0.95, 111.1) FROM aggregate_test_100
+
 # array agg can use order by
 query ?
 SELECT array_agg(c13 ORDER BY c13)

Reply via email to