jorgecarleitao commented on a change in pull request #8172:
URL: https://github.com/apache/arrow/pull/8172#discussion_r487404081
##########
File path: rust/datafusion/src/physical_plan/aggregates.rs
##########
@@ -103,42 +103,54 @@ pub fn create_aggregate_expr(
fun: &AggregateFunction,
args: &Vec<Arc<dyn PhysicalExpr>>,
input_schema: &Schema,
+ name: String,
) -> Result<Arc<dyn AggregateExpr>> {
// coerce
let arg = coerce(args, input_schema, &signature(fun))?[0].clone();
+ let arg_types = args
+ .iter()
+ .map(|e| e.data_type(input_schema))
+ .collect::<Result<Vec<_>>>()?;
+
+ let return_type = return_type(&fun, &arg_types)?;
+
Ok(match fun {
- AggregateFunction::Count => expressions::count(arg),
- AggregateFunction::Sum => expressions::sum(arg),
- AggregateFunction::Min => expressions::min(arg),
- AggregateFunction::Max => expressions::max(arg),
- AggregateFunction::Avg => expressions::avg(arg),
+ AggregateFunction::Count => {
+ Arc::new(expressions::Count::new(arg, name, return_type))
+ }
+ AggregateFunction::Sum => Arc::new(expressions::Sum::new(arg, name, return_type)),
+ AggregateFunction::Min => Arc::new(expressions::Min::new(arg, name, return_type)),
+ AggregateFunction::Max => Arc::new(expressions::Max::new(arg, name, return_type)),
+ AggregateFunction::Avg => Arc::new(expressions::Avg::new(arg, name, return_type)),
})
}
+static NUMERICS: &'static [DataType] = &[
+ DataType::Int8,
+ DataType::Int16,
+ DataType::Int32,
+ DataType::Int64,
+ DataType::UInt8,
+ DataType::UInt16,
+ DataType::UInt32,
+ DataType::UInt64,
+ DataType::Float32,
+ DataType::Float64,
+];
+
/// the signatures supported by the function `fun`.
fn signature(fun: &AggregateFunction) -> Signature {
// note: the physical expression must accept the type returned by this function or the execution panics.
-
match fun {
AggregateFunction::Count => Signature::Any(1),
- AggregateFunction::Min
- | AggregateFunction::Max
- | AggregateFunction::Avg
- | AggregateFunction::Sum => Signature::Uniform(
- 1,
- vec![
- DataType::Int8,
- DataType::Int16,
- DataType::Int32,
- DataType::Int64,
- DataType::UInt8,
- DataType::UInt16,
- DataType::UInt32,
- DataType::UInt64,
- DataType::Float32,
- DataType::Float64,
- ],
- ),
+ AggregateFunction::Min | AggregateFunction::Max => {
+ let mut valid = vec![DataType::Utf8, DataType::LargeUtf8];
+ valid.extend_from_slice(NUMERICS);
Review comment:
I had not thought about that, but it is an interesting idea 👍
In this PR, `max` continues to support only a single column, which we select in
[this line](https://github.com/apache/arrow/pull/8172/files#diff-a98d5d588d3c5b525c6840271a5bdddcR571).
This PR does enable us to create aggregate functions with more than one
argument, so that option is open if we want it. My initial thinking was to
support multi-argument aggregate functions just for things like `covariance`
and `correlation`, but now that you mention it, we can do a lot of other things
as well. Another one is count distinct over N columns.
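To make the multi-argument case concrete, here is a rough sketch of a two-argument aggregate such as `covariance`. It is not code from this PR and does not use DataFusion's traits: `CovarianceAccumulator`, `update` and `evaluate` are hypothetical names, and plain `f64` values stand in for the two argument columns. It uses the usual single-pass (Welford-style) update.

```rust
/// Hypothetical accumulator for a two-argument aggregate (sample covariance).
/// Illustrative only: this is not DataFusion's API.
#[derive(Debug, Default)]
pub struct CovarianceAccumulator {
    count: u64,
    mean_x: f64,
    mean_y: f64,
    /// Running sum of (x - mean_x) * (y - mean_y).
    co_moment: f64,
}

impl CovarianceAccumulator {
    pub fn new() -> Self {
        Self::default()
    }

    /// Consume one row; each of the two argument columns contributes a value.
    pub fn update(&mut self, x: f64, y: f64) {
        self.count += 1;
        let n = self.count as f64;
        // Deviation of x from the mean *before* this row is folded in.
        let dx = x - self.mean_x;
        self.mean_x += dx / n;
        self.mean_y += (y - self.mean_y) / n;
        // Pair the old-mean deviation of x with the new-mean deviation of y.
        self.co_moment += dx * (y - self.mean_y);
    }

    /// Sample covariance, or `None` when fewer than two rows were seen.
    pub fn evaluate(&self) -> Option<f64> {
        if self.count < 2 {
            None
        } else {
            Some(self.co_moment / (self.count - 1) as f64)
        }
    }
}
```

Feeding it the rows `(1, 2)`, `(2, 4)`, `(3, 6)` gives a sample covariance of 2. The point is only that the accumulator state depends on both arguments at once, which a single-column aggregate cannot express.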