This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new af3d19038e Fix ApproxPercentileCont signature (#8825)
af3d19038e is described below
commit af3d19038ef40e45c70c6e9a28798ee53795dfe9
Author: Georgi Krastev <[email protected]>
AuthorDate: Sun Jan 14 20:45:18 2024 +0200
Fix ApproxPercentileCont signature (#8825)
* Fix ApproxPercentileCont signature
The number of centroids must be an integer in `coerce_types`.
Reflect that in the type signature.
* Add a unit test for percentile signature error message
---
datafusion/expr/src/aggregate_function.rs | 27 +++++++++++++---------
datafusion/expr/src/type_coercion/aggregates.rs | 22 +++++-------------
datafusion/optimizer/src/analyzer/type_coercion.rs | 26 ++++++++++++++++++++-
datafusion/sqllogictest/test_files/aggregate.slt | 3 +++
4 files changed, 50 insertions(+), 28 deletions(-)
diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs
index cea72c3cb5..9db7635d99 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -386,18 +386,23 @@ impl AggregateFunction {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::ApproxPercentileCont => {
+ let mut variants =
+ Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
// Accept any numeric value paired with a float64 percentile
- let with_tdigest_size = NUMERICS.iter().map(|t| {
- TypeSignature::Exact(vec![t.clone(), DataType::Float64,
t.clone()])
- });
- Signature::one_of(
- NUMERICS
- .iter()
- .map(|t| TypeSignature::Exact(vec![t.clone(),
DataType::Float64]))
- .chain(with_tdigest_size)
- .collect(),
- Volatility::Immutable,
- )
+ for num in NUMERICS {
+ variants
+ .push(TypeSignature::Exact(vec![num.clone(),
DataType::Float64]));
+ // Additionally accept an integer number of centroids for
T-Digest
+ for int in INTEGERS {
+ variants.push(TypeSignature::Exact(vec![
+ num.clone(),
+ DataType::Float64,
+ int.clone(),
+ ]))
+ }
+ }
+
+ Signature::one_of(variants, Volatility::Immutable)
}
AggregateFunction::ApproxPercentileContWithWeight =>
Signature::one_of(
// Accept any numeric value paired with a float64 percentile
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs
index 7128b57597..56bb5c9b69 100644
--- a/datafusion/expr/src/type_coercion/aggregates.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -223,7 +223,7 @@ pub fn coerce_types(
| AggregateFunction::RegrSXX
| AggregateFunction::RegrSYY
| AggregateFunction::RegrSXY => {
- let valid_types = [NUMERICS.to_vec(),
vec![DataType::Null]].concat();
+ let valid_types = [NUMERICS.to_vec(), vec![Null]].concat();
let input_types_valid = // number of input already checked before
valid_types.contains(&input_types[0]) &&
valid_types.contains(&input_types[1]);
if !input_types_valid {
@@ -243,15 +243,15 @@ pub fn coerce_types(
input_types[0]
);
}
- if input_types.len() == 3 && !is_integer_arg_type(&input_types[2])
{
+ if input_types.len() == 3 && !input_types[2].is_integer() {
return plan_err!(
"The percentile sample points count for {:?} must be
integer, not {:?}.",
agg_fun, input_types[2]
);
}
let mut result = input_types.to_vec();
- if can_coerce_from(&DataType::Float64, &input_types[1]) {
- result[1] = DataType::Float64;
+ if can_coerce_from(&Float64, &input_types[1]) {
+ result[1] = Float64;
} else {
return plan_err!(
"Could not coerce the percent argument for {:?} to
Float64. Was {:?}.",
@@ -275,7 +275,7 @@ pub fn coerce_types(
input_types[1]
);
}
- if !matches!(input_types[2], DataType::Float64) {
+ if !matches!(input_types[2], Float64) {
return plan_err!(
"The percentile argument for {:?} must be Float64, not
{:?}.",
agg_fun,
@@ -560,17 +560,7 @@ pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool {
}
pub fn is_integer_arg_type(arg_type: &DataType) -> bool {
- matches!(
- arg_type,
- DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- )
+ arg_type.is_integer()
}
/// Return `true` if `arg_type` is of a [`DataType`] that the
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 3821279fed..8c4e907e67 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -963,7 +963,7 @@ mod test {
}
#[test]
- fn agg_function_invalid_input() -> Result<()> {
+ fn agg_function_invalid_input_avg() -> Result<()> {
let empty = empty();
let fun: AggregateFunction = AggregateFunction::Avg;
let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new(
@@ -984,6 +984,30 @@ mod test {
Ok(())
}
+ #[test]
+ fn agg_function_invalid_input_percentile() {
+ let empty = empty();
+ let fun: AggregateFunction = AggregateFunction::ApproxPercentileCont;
+ let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new(
+ fun,
+ vec![lit(0.95), lit(42.0), lit(100.0)],
+ false,
+ None,
+ None,
+ ));
+
+ let err = Projection::try_new(vec![agg_expr], empty)
+ .err()
+ .unwrap()
+ .strip_backtrace();
+
+ let prefix = "Error during planning: No function matches the given
name and argument types 'APPROX_PERCENTILE_CONT(Float64, Float64, Float64)'.
You might need to add explicit type casts.\n\tCandidate functions:";
+ assert!(!err
+ .strip_prefix(prefix)
+ .unwrap()
+ .contains("APPROX_PERCENTILE_CONT(Float64, Float64, Float64)"));
+ }
+
#[test]
fn binary_op_date32_op_interval() -> Result<()> {
//CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640")
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index aa512f6e26..50cdebd054 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -95,6 +95,9 @@ SELECT approx_percentile_cont(c3, 0.95, c1) FROM aggregate_test_100
statement error DataFusion error: Error during planning: No function matches
the given name and argument types 'APPROX_PERCENTILE_CONT\(Int16, Float64,
Float64\)'\. You might need to add explicit type casts\.
SELECT approx_percentile_cont(c3, 0.95, 111.1) FROM aggregate_test_100
+statement error DataFusion error: Error during planning: No function matches
the given name and argument types 'APPROX_PERCENTILE_CONT\(Float64, Float64,
Float64\)'\. You might need to add explicit type casts\.
+SELECT approx_percentile_cont(c12, 0.95, 111.1) FROM aggregate_test_100
+
# array agg can use order by
query ?
SELECT array_agg(c13 ORDER BY c13)