This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 2222abda25 refactor: remove unused `type_coercion/aggregate.rs` 
functions (#18091)
2222abda25 is described below

commit 2222abda2545812c25f1720d9c3cfc45c0b9bc85
Author: Jeffrey Vo <[email protected]>
AuthorDate: Sat Oct 18 03:39:13 2025 +1100

    refactor: remove unused `type_coercion/aggregate.rs` functions (#18091)
    
    ## Which issue does this PR close?
    
    <!--
    We generally require a GitHub issue to be filed for all bug fixes and
    enhancements and this helps us generate change logs for our releases.
    You can link an issue to this PR using the GitHub syntax. For example
    `Closes #123` indicates that this PR will close issue #123.
    -->
    
    N/A
    
    ## Rationale for this change
    
    <!--
    Why are you proposing this change? If this is already explained clearly
    in the issue then this section is not needed.
    Explaining clearly why changes are proposed helps reviewers understand
    your changes and offer better suggestions for fixes.
    -->
    
    There are a few functions in
    `datafusion/expr-common/src/type_coercion/aggregates.rs` that are unused
    elsewhere in the codebase, likely remnants from before the refactor to
    UDFs, so this PR removes them. Some are still used (`coerce_avg_type()`
    and `avg_return_type()`), so these are inlined into the Avg aggregate
    function (similar to Sum). Also refactor some window functions to use
    already-available macros.
    
    ## What changes are included in this PR?
    
    <!--
    There is no need to duplicate the description in the issue here but it
    is sometimes worth providing a summary of the individual changes in this
    PR.
    -->
    
    - Remove some unused functions
    - Inline avg coerce & return type logic
    - Refactor Spark Avg a bit to remove unnecessary code
    - Refactor the `ntile`, `first_value` and `last_value` window functions to use available macros
    
    ## Are these changes tested?
    
    <!--
    We typically require tests for all PRs in order to:
    1. Prevent the code from being accidentally broken by subsequent changes
    2. Serve as another way to document the expected behavior of the code
    
    If tests are not included in your PR, please explain why (for example,
    are they covered by existing tests)?
    -->
    
    Existing tests.
    
    ## Are there any user-facing changes?
    
    Yes, as these functions were publicly exported; however, I'm not sure they
    were meant to be used by users anyway, given what they do.
    
    <!--
    If there are user-facing changes then we may require documentation to be
    updated before approving the PR.
    -->
    
    <!--
    If there are any breaking changes to public APIs, please add the `api
    change` label.
    -->
---
 .../expr-common/src/type_coercion/aggregates.rs    | 300 +--------------------
 datafusion/expr/src/test/function_stub.rs          |  67 ++++-
 datafusion/functions-aggregate/src/average.rs      |  67 ++++-
 datafusion/functions-aggregate/src/lib.rs          |   6 -
 datafusion/functions-window/src/nth_value.rs       |  25 +-
 datafusion/functions-window/src/ntile.rs           |  12 +-
 datafusion/spark/src/function/aggregate/avg.rs     | 110 ++++----
 datafusion/spark/src/function/aggregate/mod.rs     |   7 +-
 8 files changed, 194 insertions(+), 400 deletions(-)

diff --git a/datafusion/expr-common/src/type_coercion/aggregates.rs 
b/datafusion/expr-common/src/type_coercion/aggregates.rs
index e77a072a84..55a8843394 100644
--- a/datafusion/expr-common/src/type_coercion/aggregates.rs
+++ b/datafusion/expr-common/src/type_coercion/aggregates.rs
@@ -16,31 +16,12 @@
 // under the License.
 
 use crate::signature::TypeSignature;
-use arrow::datatypes::{
-    DataType, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, 
DECIMAL128_MAX_SCALE,
-    DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION,
-    DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE,
-};
+use arrow::datatypes::{DataType, FieldRef};
 
 use datafusion_common::{internal_err, plan_err, Result};
 
-pub static STRINGS: &[DataType] =
-    &[DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View];
-
-pub static SIGNED_INTEGERS: &[DataType] = &[
-    DataType::Int8,
-    DataType::Int16,
-    DataType::Int32,
-    DataType::Int64,
-];
-
-pub static UNSIGNED_INTEGERS: &[DataType] = &[
-    DataType::UInt8,
-    DataType::UInt16,
-    DataType::UInt32,
-    DataType::UInt64,
-];
-
+// TODO: remove usage of these (INTEGERS and NUMERICS) in favour of signatures
+//       see https://github.com/apache/datafusion/issues/18092
 pub static INTEGERS: &[DataType] = &[
     DataType::Int8,
     DataType::Int16,
@@ -65,24 +46,6 @@ pub static NUMERICS: &[DataType] = &[
     DataType::Float64,
 ];
 
-pub static TIMESTAMPS: &[DataType] = &[
-    DataType::Timestamp(TimeUnit::Second, None),
-    DataType::Timestamp(TimeUnit::Millisecond, None),
-    DataType::Timestamp(TimeUnit::Microsecond, None),
-    DataType::Timestamp(TimeUnit::Nanosecond, None),
-];
-
-pub static DATES: &[DataType] = &[DataType::Date32, DataType::Date64];
-
-pub static BINARYS: &[DataType] = &[DataType::Binary, DataType::LargeBinary];
-
-pub static TIMES: &[DataType] = &[
-    DataType::Time32(TimeUnit::Second),
-    DataType::Time32(TimeUnit::Millisecond),
-    DataType::Time64(TimeUnit::Microsecond),
-    DataType::Time64(TimeUnit::Nanosecond),
-];
-
 /// Validate the length of `input_fields` matches the `signature` for 
`agg_fun`.
 ///
 /// This method DOES NOT validate the argument fields - only that (at least 
one,
@@ -144,260 +107,3 @@ pub fn check_arg_count(
     }
     Ok(())
 }
-
-/// Function return type of a sum
-pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
-    match arg_type {
-        DataType::Int64 => Ok(DataType::Int64),
-        DataType::UInt64 => Ok(DataType::UInt64),
-        DataType::Float64 => Ok(DataType::Float64),
-        DataType::Decimal32(precision, scale) => {
-            // in the spark, the result type is DECIMAL(min(38,precision+10), 
s)
-            // ref: 
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
-            let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
-            Ok(DataType::Decimal32(new_precision, *scale))
-        }
-        DataType::Decimal64(precision, scale) => {
-            // in the spark, the result type is DECIMAL(min(38,precision+10), 
s)
-            // ref: 
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
-            let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
-            Ok(DataType::Decimal64(new_precision, *scale))
-        }
-        DataType::Decimal128(precision, scale) => {
-            // In the spark, the result type is DECIMAL(min(38,precision+10), 
s)
-            // Ref: 
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
-            let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
-            Ok(DataType::Decimal128(new_precision, *scale))
-        }
-        DataType::Decimal256(precision, scale) => {
-            // In the spark, the result type is DECIMAL(min(38,precision+10), 
s)
-            // Ref: 
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
-            let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
-            Ok(DataType::Decimal256(new_precision, *scale))
-        }
-        other => plan_err!("SUM does not support type \"{other:?}\""),
-    }
-}
-
-/// Function return type of variance
-pub fn variance_return_type(arg_type: &DataType) -> Result<DataType> {
-    if NUMERICS.contains(arg_type) {
-        Ok(DataType::Float64)
-    } else {
-        plan_err!("VAR does not support {arg_type}")
-    }
-}
-
-/// Function return type of covariance
-pub fn covariance_return_type(arg_type: &DataType) -> Result<DataType> {
-    if NUMERICS.contains(arg_type) {
-        Ok(DataType::Float64)
-    } else {
-        plan_err!("COVAR does not support {arg_type}")
-    }
-}
-
-/// Function return type of correlation
-pub fn correlation_return_type(arg_type: &DataType) -> Result<DataType> {
-    if NUMERICS.contains(arg_type) {
-        Ok(DataType::Float64)
-    } else {
-        plan_err!("CORR does not support {arg_type}")
-    }
-}
-
-/// Function return type of an average
-pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> 
Result<DataType> {
-    match arg_type {
-        DataType::Decimal32(precision, scale) => {
-            // In the spark, the result type is DECIMAL(min(38,precision+4), 
min(38,scale+4)).
-            // Ref: 
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
-            let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4);
-            let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4);
-            Ok(DataType::Decimal32(new_precision, new_scale))
-        }
-        DataType::Decimal64(precision, scale) => {
-            // In the spark, the result type is DECIMAL(min(38,precision+4), 
min(38,scale+4)).
-            // Ref: 
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
-            let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4);
-            let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4);
-            Ok(DataType::Decimal64(new_precision, new_scale))
-        }
-        DataType::Decimal128(precision, scale) => {
-            // In the spark, the result type is DECIMAL(min(38,precision+4), 
min(38,scale+4)).
-            // Ref: 
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
-            let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4);
-            let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
-            Ok(DataType::Decimal128(new_precision, new_scale))
-        }
-        DataType::Decimal256(precision, scale) => {
-            // In the spark, the result type is DECIMAL(min(38,precision+4), 
min(38,scale+4)).
-            // Ref: 
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
-            let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4);
-            let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4);
-            Ok(DataType::Decimal256(new_precision, new_scale))
-        }
-        DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
-        arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
-        DataType::Dictionary(_, dict_value_type) => {
-            avg_return_type(func_name, dict_value_type.as_ref())
-        }
-        other => plan_err!("{func_name} does not support {other:?}"),
-    }
-}
-
-/// Internal sum type of an average
-pub fn avg_sum_type(arg_type: &DataType) -> Result<DataType> {
-    match arg_type {
-        DataType::Decimal32(precision, scale) => {
-            // In the spark, the sum type of avg is 
DECIMAL(min(38,precision+10), s)
-            let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
-            Ok(DataType::Decimal32(new_precision, *scale))
-        }
-        DataType::Decimal64(precision, scale) => {
-            // In the spark, the sum type of avg is 
DECIMAL(min(38,precision+10), s)
-            let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
-            Ok(DataType::Decimal64(new_precision, *scale))
-        }
-        DataType::Decimal128(precision, scale) => {
-            // In the spark, the sum type of avg is 
DECIMAL(min(38,precision+10), s)
-            let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
-            Ok(DataType::Decimal128(new_precision, *scale))
-        }
-        DataType::Decimal256(precision, scale) => {
-            // In Spark the sum type of avg is DECIMAL(min(38,precision+10), s)
-            let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
-            Ok(DataType::Decimal256(new_precision, *scale))
-        }
-        DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
-        arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
-        DataType::Dictionary(_, dict_value_type) => {
-            avg_sum_type(dict_value_type.as_ref())
-        }
-        other => plan_err!("AVG does not support {other:?}"),
-    }
-}
-
-pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool {
-    match arg_type {
-        DataType::Dictionary(_, dict_value_type) => {
-            is_sum_support_arg_type(dict_value_type.as_ref())
-        }
-        _ => matches!(
-            arg_type,
-            arg_type if NUMERICS.contains(arg_type)
-            || matches!(arg_type, DataType::Decimal32(_, _) | 
DataType::Decimal64(_, _) |DataType::Decimal128(_, _) | DataType::Decimal256(_, 
_))
-        ),
-    }
-}
-
-pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool {
-    match arg_type {
-        DataType::Dictionary(_, dict_value_type) => {
-            is_avg_support_arg_type(dict_value_type.as_ref())
-        }
-        _ => matches!(
-            arg_type,
-            arg_type if NUMERICS.contains(arg_type)
-            || matches!(arg_type, DataType::Decimal32(_, _) | 
DataType::Decimal64(_, _) |DataType::Decimal128(_, _) | DataType::Decimal256(_, 
_))
-        ),
-    }
-}
-
-pub fn is_variance_support_arg_type(arg_type: &DataType) -> bool {
-    matches!(
-        arg_type,
-        arg_type if NUMERICS.contains(arg_type)
-    )
-}
-
-pub fn is_covariance_support_arg_type(arg_type: &DataType) -> bool {
-    matches!(
-        arg_type,
-        arg_type if NUMERICS.contains(arg_type)
-    )
-}
-
-pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool {
-    matches!(
-        arg_type,
-        arg_type if NUMERICS.contains(arg_type)
-    )
-}
-
-pub fn is_integer_arg_type(arg_type: &DataType) -> bool {
-    arg_type.is_integer()
-}
-
-pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> 
Result<Vec<DataType>> {
-    // Supported types smallint, int, bigint, real, double precision, decimal, 
or interval
-    // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html 
doc
-    fn coerced_type(func_name: &str, data_type: &DataType) -> Result<DataType> 
{
-        match &data_type {
-            DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)),
-            DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)),
-            DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)),
-            DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)),
-            d if d.is_numeric() => Ok(DataType::Float64),
-            DataType::Duration(time_unit) => 
Ok(DataType::Duration(*time_unit)),
-            DataType::Dictionary(_, v) => coerced_type(func_name, v.as_ref()),
-            _ => {
-                plan_err!(
-                    "The function {:?} does not support inputs of type {}.",
-                    func_name,
-                    data_type
-                )
-            }
-        }
-    }
-    Ok(vec![coerced_type(func_name, &arg_types[0])?])
-}
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn test_variance_return_data_type() -> Result<()> {
-        let data_type = DataType::Float64;
-        let result_type = variance_return_type(&data_type)?;
-        assert_eq!(DataType::Float64, result_type);
-
-        let data_type = DataType::Decimal128(36, 10);
-        assert!(variance_return_type(&data_type).is_err());
-        Ok(())
-    }
-
-    #[test]
-    fn test_sum_return_data_type() -> Result<()> {
-        let data_type = DataType::Decimal128(10, 5);
-        let result_type = sum_return_type(&data_type)?;
-        assert_eq!(DataType::Decimal128(20, 5), result_type);
-
-        let data_type = DataType::Decimal128(36, 10);
-        let result_type = sum_return_type(&data_type)?;
-        assert_eq!(DataType::Decimal128(38, 10), result_type);
-        Ok(())
-    }
-
-    #[test]
-    fn test_covariance_return_data_type() -> Result<()> {
-        let data_type = DataType::Float64;
-        let result_type = covariance_return_type(&data_type)?;
-        assert_eq!(DataType::Float64, result_type);
-
-        let data_type = DataType::Decimal128(36, 10);
-        assert!(covariance_return_type(&data_type).is_err());
-        Ok(())
-    }
-
-    #[test]
-    fn test_correlation_return_data_type() -> Result<()> {
-        let data_type = DataType::Float64;
-        let result_type = correlation_return_type(&data_type)?;
-        assert_eq!(DataType::Float64, result_type);
-
-        let data_type = DataType::Decimal128(36, 10);
-        assert!(correlation_return_type(&data_type).is_err());
-        Ok(())
-    }
-}
diff --git a/datafusion/expr/src/test/function_stub.rs 
b/datafusion/expr/src/test/function_stub.rs
index 41bc645058..8609afeae6 100644
--- a/datafusion/expr/src/test/function_stub.rs
+++ b/datafusion/expr/src/test/function_stub.rs
@@ -22,13 +22,15 @@
 use std::any::Any;
 
 use arrow::datatypes::{
-    DataType, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
-    DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION,
+    DataType, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
+    DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION,
+    DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE,
 };
 
+use datafusion_common::plan_err;
 use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, 
Result};
 
-use crate::type_coercion::aggregates::{avg_return_type, coerce_avg_type, 
NUMERICS};
+use crate::type_coercion::aggregates::NUMERICS;
 use crate::Volatility::Immutable;
 use crate::{
     expr::AggregateFunction,
@@ -488,8 +490,61 @@ impl AggregateUDFImpl for Avg {
         &self.signature
     }
 
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        let [args] = take_function_args(self.name(), arg_types)?;
+
+        // Supported types smallint, int, bigint, real, double precision, 
decimal, or interval
+        // Refer to 
https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
+        fn coerced_type(data_type: &DataType) -> Result<DataType> {
+            match &data_type {
+                DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)),
+                DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)),
+                DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)),
+                DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)),
+                d if d.is_numeric() => Ok(DataType::Float64),
+                DataType::Duration(time_unit) => 
Ok(DataType::Duration(*time_unit)),
+                DataType::Dictionary(_, v) => coerced_type(v.as_ref()),
+                _ => {
+                    plan_err!("Avg does not support inputs of type 
{data_type}.")
+                }
+            }
+        }
+        Ok(vec![coerced_type(args)?])
+    }
+
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        avg_return_type(self.name(), &arg_types[0])
+        match &arg_types[0] {
+            DataType::Decimal32(precision, scale) => {
+                // In the spark, the result type is 
DECIMAL(min(38,precision+4), min(38,scale+4)).
+                // Ref: 
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+                let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 
4);
+                let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4);
+                Ok(DataType::Decimal32(new_precision, new_scale))
+            }
+            DataType::Decimal64(precision, scale) => {
+                // In the spark, the result type is 
DECIMAL(min(38,precision+4), min(38,scale+4)).
+                // Ref: 
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+                let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 
4);
+                let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4);
+                Ok(DataType::Decimal64(new_precision, new_scale))
+            }
+            DataType::Decimal128(precision, scale) => {
+                // In the spark, the result type is 
DECIMAL(min(38,precision+4), min(38,scale+4)).
+                // Ref: 
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 
4);
+                let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
+                Ok(DataType::Decimal128(new_precision, new_scale))
+            }
+            DataType::Decimal256(precision, scale) => {
+                // In the spark, the result type is 
DECIMAL(min(38,precision+4), min(38,scale+4)).
+                // Ref: 
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 
4);
+                let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4);
+                Ok(DataType::Decimal256(new_precision, new_scale))
+            }
+            DataType::Duration(time_unit) => 
Ok(DataType::Duration(*time_unit)),
+            _ => Ok(DataType::Float64),
+        }
     }
 
     fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>> {
@@ -503,8 +558,4 @@ impl AggregateUDFImpl for Avg {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
-
-    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
-        coerce_avg_type(self.name(), arg_types)
-    }
 }
diff --git a/datafusion/functions-aggregate/src/average.rs 
b/datafusion/functions-aggregate/src/average.rs
index d007163e7c..11960779ed 100644
--- a/datafusion/functions-aggregate/src/average.rs
+++ b/datafusion/functions-aggregate/src/average.rs
@@ -27,14 +27,15 @@ use arrow::datatypes::{
     i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, 
Decimal32Type,
     Decimal64Type, DecimalType, DurationMicrosecondType, 
DurationMillisecondType,
     DurationNanosecondType, DurationSecondType, Field, FieldRef, Float64Type, 
TimeUnit,
-    UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
-    DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION,
+    UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, 
DECIMAL256_MAX_PRECISION,
+    DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE,
+    DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE,
 };
+use datafusion_common::plan_err;
 use datafusion_common::{
     exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
 };
 use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
-use datafusion_expr::type_coercion::aggregates::{avg_return_type, 
coerce_avg_type};
 use datafusion_expr::utils::format_state_name;
 use datafusion_expr::Volatility::Immutable;
 use datafusion_expr::{
@@ -125,8 +126,61 @@ impl AggregateUDFImpl for Avg {
         &self.signature
     }
 
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        let [args] = take_function_args(self.name(), arg_types)?;
+
+        // Supported types smallint, int, bigint, real, double precision, 
decimal, or interval
+        // Refer to 
https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
+        fn coerced_type(data_type: &DataType) -> Result<DataType> {
+            match &data_type {
+                DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)),
+                DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)),
+                DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)),
+                DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)),
+                d if d.is_numeric() => Ok(DataType::Float64),
+                DataType::Duration(time_unit) => 
Ok(DataType::Duration(*time_unit)),
+                DataType::Dictionary(_, v) => coerced_type(v.as_ref()),
+                _ => {
+                    plan_err!("Avg does not support inputs of type 
{data_type}.")
+                }
+            }
+        }
+        Ok(vec![coerced_type(args)?])
+    }
+
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        avg_return_type(self.name(), &arg_types[0])
+        match &arg_types[0] {
+            DataType::Decimal32(precision, scale) => {
+                // In the spark, the result type is 
DECIMAL(min(38,precision+4), min(38,scale+4)).
+                // Ref: 
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+                let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 
4);
+                let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4);
+                Ok(DataType::Decimal32(new_precision, new_scale))
+            }
+            DataType::Decimal64(precision, scale) => {
+                // In the spark, the result type is 
DECIMAL(min(38,precision+4), min(38,scale+4)).
+                // Ref: 
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+                let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 
4);
+                let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4);
+                Ok(DataType::Decimal64(new_precision, new_scale))
+            }
+            DataType::Decimal128(precision, scale) => {
+                // In the spark, the result type is 
DECIMAL(min(38,precision+4), min(38,scale+4)).
+                // Ref: 
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 
4);
+                let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
+                Ok(DataType::Decimal128(new_precision, new_scale))
+            }
+            DataType::Decimal256(precision, scale) => {
+                // In the spark, the result type is 
DECIMAL(min(38,precision+4), min(38,scale+4)).
+                // Ref: 
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 
4);
+                let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4);
+                Ok(DataType::Decimal256(new_precision, new_scale))
+            }
+            DataType::Duration(time_unit) => 
Ok(DataType::Duration(*time_unit)),
+            _ => Ok(DataType::Float64),
+        }
     }
 
     fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>> {
@@ -452,11 +506,6 @@ impl AggregateUDFImpl for Avg {
         ReversedUDAF::Identical
     }
 
-    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
-        let [args] = take_function_args(self.name(), arg_types)?;
-        coerce_avg_type(self.name(), std::slice::from_ref(args))
-    }
-
     fn documentation(&self) -> Option<&Documentation> {
         self.doc()
     }
diff --git a/datafusion/functions-aggregate/src/lib.rs 
b/datafusion/functions-aggregate/src/lib.rs
index b56b2b118e..056cd45fa2 100644
--- a/datafusion/functions-aggregate/src/lib.rs
+++ b/datafusion/functions-aggregate/src/lib.rs
@@ -211,13 +211,7 @@ mod tests {
     #[test]
     fn test_no_duplicate_name() -> Result<()> {
         let mut names = HashSet::new();
-        let migrated_functions = ["array_agg", "count", "max", "min"];
         for func in all_default_aggregate_functions() {
-            // TODO: remove this
-            // These functions are in intermediate migration state, skip them
-            if 
migrated_functions.contains(&func.name().to_lowercase().as_str()) {
-                continue;
-            }
             assert!(
                 names.insert(func.name().to_string().to_lowercase()),
                 "duplicate function name: {}",
diff --git a/datafusion/functions-window/src/nth_value.rs 
b/datafusion/functions-window/src/nth_value.rs
index 329d8aa5ab..1ba6ad5ce0 100644
--- a/datafusion/functions-window/src/nth_value.rs
+++ b/datafusion/functions-window/src/nth_value.rs
@@ -40,39 +40,28 @@ use std::hash::Hash;
 use std::ops::Range;
 use std::sync::{Arc, LazyLock};
 
-get_or_init_udwf!(
+define_udwf_and_expr!(
     First,
     first_value,
-    "returns the first value in the window frame",
+    [arg],
+    "Returns the first value in the window frame",
     NthValue::first
 );
-get_or_init_udwf!(
+define_udwf_and_expr!(
     Last,
     last_value,
-    "returns the last value in the window frame",
+    [arg],
+    "Returns the last value in the window frame",
     NthValue::last
 );
 get_or_init_udwf!(
     NthValue,
     nth_value,
-    "returns the nth value in the window frame",
+    "Returns the nth value in the window frame",
     NthValue::nth
 );
 
-/// Create an expression to represent the `first_value` window function
-///
-pub fn first_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
-    first_value_udwf().call(vec![arg])
-}
-
-/// Create an expression to represent the `last_value` window function
-///
-pub fn last_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
-    last_value_udwf().call(vec![arg])
-}
-
 /// Create an expression to represent the `nth_value` window function
-///
 pub fn nth_value(arg: datafusion_expr::Expr, n: i64) -> datafusion_expr::Expr {
     nth_value_udwf().call(vec![arg, n.lit()])
 }
diff --git a/datafusion/functions-window/src/ntile.rs 
b/datafusion/functions-window/src/ntile.rs
index d188db3bbf..008caaa848 100644
--- a/datafusion/functions-window/src/ntile.rs
+++ b/datafusion/functions-window/src/ntile.rs
@@ -25,8 +25,7 @@ use datafusion_common::arrow::array::{ArrayRef, UInt64Array};
 use datafusion_common::arrow::datatypes::{DataType, Field};
 use datafusion_common::{exec_datafusion_err, exec_err, Result};
 use datafusion_expr::{
-    Documentation, Expr, LimitEffect, PartitionEvaluator, Signature, 
Volatility,
-    WindowUDFImpl,
+    Documentation, LimitEffect, PartitionEvaluator, Signature, Volatility, 
WindowUDFImpl,
 };
 use datafusion_functions_window_common::field;
 use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
@@ -37,16 +36,13 @@ use std::any::Any;
 use std::fmt::Debug;
 use std::sync::Arc;
 
-get_or_init_udwf!(
+define_udwf_and_expr!(
     Ntile,
     ntile,
-    "integer ranging from 1 to the argument value, dividing the partition as equally as possible"
+    [arg],
+    "Integer ranging from 1 to the argument value, dividing the partition as equally as possible."
 );
 
-pub fn ntile(arg: Expr) -> Expr {
-    ntile_udwf().call(vec![arg])
-}
-
 #[user_doc(
     doc_section(label = "Ranking Functions"),
     description = "Integer ranging from 1 to the argument value, dividing the partition as equally as possible",
diff --git a/datafusion/spark/src/function/aggregate/avg.rs b/datafusion/spark/src/function/aggregate/avg.rs
index a22561ba8b..65736815fe 100644
--- a/datafusion/spark/src/function/aggregate/avg.rs
+++ b/datafusion/spark/src/function/aggregate/avg.rs
@@ -25,41 +25,38 @@ use arrow::array::{
 use arrow::compute::sum;
 use arrow::datatypes::{DataType, Field, FieldRef};
 use datafusion_common::utils::take_function_args;
-use datafusion_common::{not_impl_err, Result, ScalarValue};
+use datafusion_common::{not_impl_err, plan_err, Result, ScalarValue};
 use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
-use datafusion_expr::type_coercion::aggregates::coerce_avg_type;
 use datafusion_expr::utils::format_state_name;
 use datafusion_expr::Volatility::Immutable;
 use datafusion_expr::{
-    type_coercion::aggregates::avg_return_type, Accumulator, AggregateUDFImpl, EmitTo,
-    GroupsAccumulator, ReversedUDAF, Signature,
+    Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature,
 };
 use std::{any::Any, sync::Arc};
-use DataType::*;
 
 /// AVG aggregate expression
 /// Spark average aggregate expression. Differs from standard DataFusion average aggregate
 /// in that it uses an `i64` for the count (DataFusion version uses `u64`); also there is ANSI mode
 /// support planned in the future for Spark version.
 
+// TODO: see if can deduplicate with DF version
+//       https://github.com/apache/datafusion/issues/17964
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub struct SparkAvg {
-    name: String,
     signature: Signature,
-    input_data_type: DataType,
-    result_data_type: DataType,
+}
+
+impl Default for SparkAvg {
+    fn default() -> Self {
+        Self::new()
+    }
 }
 
 impl SparkAvg {
     /// Implement AVG aggregate function
-    pub fn new(name: impl Into<String>, data_type: DataType) -> Self {
-        let result_data_type = avg_return_type("avg", &data_type).unwrap();
-
+    pub fn new() -> Self {
         Self {
-            name: name.into(),
             signature: Signature::user_defined(Immutable),
-            input_data_type: data_type,
-            result_data_type,
         }
     }
 }
@@ -69,63 +66,87 @@ impl AggregateUDFImpl for SparkAvg {
         self
     }
 
-    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        let [args] = take_function_args(self.name(), arg_types)?;
+
+        fn coerced_type(data_type: &DataType) -> Result<DataType> {
+            match &data_type {
+                d if d.is_numeric() => Ok(DataType::Float64),
+                DataType::Dictionary(_, v) => coerced_type(v.as_ref()),
+                _ => {
+                    plan_err!("Avg does not support inputs of type {data_type}.")
+                }
+            }
+        }
+        Ok(vec![coerced_type(args)?])
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        if acc_args.is_distinct {
+            return not_impl_err!("DistinctAvgAccumulator");
+        }
+
+        let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
+
         // instantiate specialized accumulator based for the type
-        match (&self.input_data_type, &self.result_data_type) {
-            (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
-            _ => not_impl_err!(
-                "AvgAccumulator for ({} --> {})",
-                self.input_data_type,
-                self.result_data_type
-            ),
+        match (&data_type, &acc_args.return_type()) {
+            (DataType::Float64, DataType::Float64) => {
+                Ok(Box::<AvgAccumulator>::default())
+            }
+            (dt, return_type) => {
+                not_impl_err!("AvgAccumulator for ({dt} --> {return_type})")
+            }
         }
     }
 
-    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
         Ok(vec![
             Arc::new(Field::new(
-                format_state_name(&self.name, "sum"),
-                self.input_data_type.clone(),
+                format_state_name(self.name(), "sum"),
+                args.input_fields[0].data_type().clone(),
                 true,
             )),
             Arc::new(Field::new(
-                format_state_name(&self.name, "count"),
-                Int64,
+                format_state_name(self.name(), "count"),
+                DataType::Int64,
                 true,
             )),
         ])
     }
 
     fn name(&self) -> &str {
-        &self.name
+        "avg"
     }
 
     fn reverse_expr(&self) -> ReversedUDAF {
         ReversedUDAF::Identical
     }
 
-    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
-        true
+    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+        !args.is_distinct
     }
 
     fn create_groups_accumulator(
         &self,
-        _args: AccumulatorArgs,
+        args: AccumulatorArgs,
     ) -> Result<Box<dyn GroupsAccumulator>> {
+        let data_type = args.exprs[0].data_type(args.schema)?;
+
         // instantiate specialized accumulator based for the type
-        match (&self.input_data_type, &self.result_data_type) {
-            (Float64, Float64) => {
+        match (&data_type, args.return_type()) {
+            (DataType::Float64, DataType::Float64) => {
                 Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
-                    &self.input_data_type,
+                    args.return_field.data_type(),
                     |sum: f64, count: i64| Ok(sum / count as f64),
                 )))
             }
-
-            _ => not_impl_err!(
-                "AvgGroupsAccumulator for ({} --> {})",
-                self.input_data_type,
-                self.result_data_type
-            ),
+            (dt, return_type) => {
+                not_impl_err!("AvgGroupsAccumulator for ({dt} --> {return_type})")
+            }
         }
     }
 
@@ -136,15 +157,6 @@ impl AggregateUDFImpl for SparkAvg {
     fn signature(&self) -> &Signature {
         &self.signature
     }
-
-    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        avg_return_type(self.name(), &arg_types[0])
-    }
-
-    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
-        let [arg] = take_function_args(self.name(), arg_types)?;
-        coerce_avg_type(self.name(), std::slice::from_ref(arg))
-    }
 }
 
 /// An accumulator to compute the average
diff --git a/datafusion/spark/src/function/aggregate/mod.rs b/datafusion/spark/src/function/aggregate/mod.rs
index 54001d28da..d765d9c82f 100644
--- a/datafusion/spark/src/function/aggregate/mod.rs
+++ b/datafusion/spark/src/function/aggregate/mod.rs
@@ -15,7 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::datatypes::DataType;
 use datafusion_expr::AggregateUDF;
 use std::sync::Arc;
 
@@ -26,11 +25,9 @@ pub mod expr_fn {
     export_functions!((avg, "Returns the average value of a given column", arg1));
 }
 
+// TODO: try use something like datafusion_functions_aggregate::create_func!()
 pub fn avg() -> Arc<AggregateUDF> {
-    Arc::new(AggregateUDF::new_from_impl(avg::SparkAvg::new(
-        "avg",
-        DataType::Float64,
-    )))
+    Arc::new(AggregateUDF::new_from_impl(avg::SparkAvg::new()))
 }
 
 pub fn functions() -> Vec<Arc<AggregateUDF>> {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to