Jefffrey commented on code in PR #18091:
URL: https://github.com/apache/datafusion/pull/18091#discussion_r2434608691
##########
datafusion/expr/src/test/function_stub.rs:
##########
@@ -488,8 +490,61 @@ impl AggregateUDFImpl for Avg {
&self.signature
}
+ fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
Review Comment:
Technically this duplicates the avg logic, but I'm not sure why
we need these stubs in the first place 🤔
##########
datafusion/spark/src/function/aggregate/avg.rs:
##########
@@ -25,41 +25,37 @@ use arrow::array::{
use arrow::compute::sum;
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::utils::take_function_args;
-use datafusion_common::{not_impl_err, Result, ScalarValue};
+use datafusion_common::{not_impl_err, plan_err, Result, ScalarValue};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
-use datafusion_expr::type_coercion::aggregates::coerce_avg_type;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::Volatility::Immutable;
use datafusion_expr::{
- type_coercion::aggregates::avg_return_type, Accumulator, AggregateUDFImpl,
EmitTo,
- GroupsAccumulator, ReversedUDAF, Signature,
+ Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF,
Signature,
};
use std::{any::Any, sync::Arc};
-use DataType::*;
/// AVG aggregate expression
/// Spark average aggregate expression. Differs from standard DataFusion
average aggregate
/// in that it uses an `i64` for the count (DataFusion version uses `u64`);
also there is ANSI mode
/// support planned in the future for Spark version.
+// TODO: see if can deduplicate with DF version
Review Comment:
I'll do this in a followup
##########
datafusion/spark/src/function/aggregate/avg.rs:
##########
@@ -69,63 +65,87 @@ impl AggregateUDFImpl for SparkAvg {
self
}
- fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn
Accumulator>> {
+ fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+ let [args] = take_function_args(self.name(), arg_types)?;
+
+ fn coerced_type(data_type: &DataType) -> Result<DataType> {
+ match &data_type {
+ d if d.is_numeric() => Ok(DataType::Float64),
+ DataType::Dictionary(_, v) => coerced_type(v.as_ref()),
+ _ => {
+ plan_err!("Avg does not support inputs of type
{data_type}.")
+ }
+ }
+ }
Review Comment:
I reduced it to the minimal version, since the current version only accepts
numerics (excluding decimals) by casting to float, so I didn't see a need for the
full logic from the DataFusion version; in a follow-up I can hopefully find a way to
fold this into the DataFusion version
##########
datafusion/functions-aggregate/src/average.rs:
##########
@@ -125,8 +126,61 @@ impl AggregateUDFImpl for Avg {
&self.signature
}
+ fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
Review Comment:
Here we inline it so the logic is closer to function implementation
##########
datafusion/expr-common/src/type_coercion/aggregates.rs:
##########
@@ -16,31 +16,11 @@
// under the License.
use crate::signature::TypeSignature;
-use arrow::datatypes::{
- DataType, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION,
DECIMAL128_MAX_SCALE,
- DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION,
- DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE,
-};
+use arrow::datatypes::{DataType, FieldRef};
use datafusion_common::{internal_err, plan_err, Result};
-pub static STRINGS: &[DataType] =
- &[DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View];
-
-pub static SIGNED_INTEGERS: &[DataType] = &[
- DataType::Int8,
- DataType::Int16,
- DataType::Int32,
- DataType::Int64,
-];
-
-pub static UNSIGNED_INTEGERS: &[DataType] = &[
- DataType::UInt8,
- DataType::UInt16,
- DataType::UInt32,
- DataType::UInt64,
-];
-
+// TODO: remove usage of these (INTEGERS and NUMERICS) in favour of signatures
Review Comment:
I'll raise an issue for this
##########
datafusion/expr-common/src/type_coercion/aggregates.rs:
##########
@@ -16,31 +16,11 @@
// under the License.
use crate::signature::TypeSignature;
-use arrow::datatypes::{
- DataType, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION,
DECIMAL128_MAX_SCALE,
- DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION,
- DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE,
-};
+use arrow::datatypes::{DataType, FieldRef};
use datafusion_common::{internal_err, plan_err, Result};
-pub static STRINGS: &[DataType] =
Review Comment:
I assume these were leftover from before moving to the Signature API
##########
datafusion/expr-common/src/type_coercion/aggregates.rs:
##########
@@ -144,260 +106,3 @@ pub fn check_arg_count(
}
Ok(())
}
-
-/// Function return type of a sum
-pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
Review Comment:
All of these functions except the avg ones were unused in our code, and I
don't think it makes sense to have them available for users anyway
##########
datafusion/functions-window/src/nth_value.rs:
##########
@@ -40,39 +40,28 @@ use std::hash::Hash;
use std::ops::Range;
use std::sync::{Arc, LazyLock};
-get_or_init_udwf!(
+define_udwf_and_expr!(
First,
first_value,
- "returns the first value in the window frame",
+ [arg],
+ "Returns the first value in the window frame",
NthValue::first
);
-get_or_init_udwf!(
+define_udwf_and_expr!(
Review Comment:
Taking advantage of the `define_udwf_and_expr!()` macro, which creates the
function for us; no change to the public API — the function definition is the same
##########
datafusion/functions-window/src/ntile.rs:
##########
@@ -37,16 +36,13 @@ use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;
-get_or_init_udwf!(
+define_udwf_and_expr!(
Ntile,
ntile,
- "integer ranging from 1 to the argument value, dividing the partition as
equally as possible"
+ [arg],
+ "Integer ranging from 1 to the argument value, dividing the partition as
equally as possible."
);
-pub fn ntile(arg: Expr) -> Expr {
- ntile_udwf().call(vec![arg])
-}
-
Review Comment:
Ditto
##########
datafusion/functions-aggregate/src/lib.rs:
##########
@@ -207,13 +207,7 @@ mod tests {
#[test]
fn test_no_duplicate_name() -> Result<()> {
let mut names = HashSet::new();
- let migrated_functions = ["array_agg", "count", "max", "min"];
for func in all_default_aggregate_functions() {
- // TODO: remove this
- // These functions are in intermediate migration state, skip them
- if
migrated_functions.contains(&func.name().to_lowercase().as_str()) {
- continue;
- }
Review Comment:
Minor cleanup of TODO here
##########
datafusion/spark/src/function/aggregate/avg.rs:
##########
@@ -69,63 +65,87 @@ impl AggregateUDFImpl for SparkAvg {
self
}
- fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn
Accumulator>> {
+ fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+ let [args] = take_function_args(self.name(), arg_types)?;
+
+ fn coerced_type(data_type: &DataType) -> Result<DataType> {
+ match &data_type {
+ d if d.is_numeric() => Ok(DataType::Float64),
+ DataType::Dictionary(_, v) => coerced_type(v.as_ref()),
+ _ => {
+ plan_err!("Avg does not support inputs of type
{data_type}.")
+ }
+ }
+ }
+ Ok(vec![coerced_type(args)?])
+ }
+
+ fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+ Ok(DataType::Float64)
+ }
+
+ fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn
Accumulator>> {
+ if acc_args.is_distinct {
+ return not_impl_err!("DistinctAvgAccumulator");
+ }
Review Comment:
Added a missing distinct check
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]