This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 6d3a846483 Rewrite `SUM(expr + scalar)` --> `SUM(expr) + 
scalar*COUNT(expr)`  (#20749)
6d3a846483 is described below

commit 6d3a8464836eb59605cc59579b6eb3af8c4788bf
Author: Andrew Lamb <[email protected]>
AuthorDate: Sat Mar 14 08:22:01 2026 -0400

    Rewrite `SUM(expr + scalar)` --> `SUM(expr) + scalar*COUNT(expr)`  (#20749)
    
    ## Which issue does this PR close?
    
    - Part of #18489
    - Closes https://github.com/apache/datafusion/pull/20180
    - Closes https://github.com/apache/datafusion/issues/15524
    - Replaces https://github.com/apache/datafusion/pull/20665
    
    ## Rationale for this change
    
    I [want DataFusion to be the fastest parquet engine on
    ClickBench](https://github.com/apache/datafusion/issues/18489). One of
    the queries where DataFusion is significantly slower is Query 29 which
    has a very strange pattern of many aggregate functions that are offset
    by a constant:
    
    
    
https://github.com/apache/datafusion/blob/0ca9d6586a43c323525b2e299448e0f1af4d6195/benchmarks/queries/clickbench/queries/q29.sql#L4
    
    This is not a pattern I have ever seen in a real query, but it seems
    like the engine currently at the top of the ClickBench leaderboard has a
    special case for this pattern. ClickHouse probably does too. See
    - https://github.com/duckdb/duckdb/pull/15017
    - Discussion on https://github.com/apache/datafusion/issues/15524
    
    Thus I reluctantly conclude that we should have one too.
    
    ## What changes are included in this PR?
    
    This is an alternative to my first attempt:
     - https://github.com/apache/datafusion/pull/20665
    
    In particular, since this is such a ClickBench-specific rule, I wanted
    to:
    1. Minimize the downstream API / upgrade impact (i.e. not change existing
    APIs)
    2. Optimize performance for the case where this rewrite does not apply
    (most of the time)
    
    1. Add a rewrite `SUM(expr + scalar)` --> `SUM(expr) +
    scalar*COUNT(expr)`
    2. Tests for same
    
    Note that there are quite a few other ideas to potentially make this more
    general on https://github.com/apache/datafusion/issues/15524, but I am
    going with the simple thing of making it work for the use case we have in
    hand (ClickBench).
    
    ## Are these changes tested?
    
    Yes, new tests are added
    
    ## Are there any user-facing changes?
    
    Faster performance
    
    🚀
    ```
    │ QQuery 29 │  1012.63 ms │                 139.02 ms │ +7.28x faster │
    ```
---
 datafusion/expr/src/expr.rs                        |   2 +-
 datafusion/expr/src/logical_plan/plan.rs           |   4 +-
 datafusion/expr/src/udaf.rs                        |  96 +++++++++
 datafusion/functions-aggregate/src/sum.rs          |  50 ++++-
 .../src/simplify_expressions/linear_aggregates.rs  | 229 +++++++++++++++++++++
 .../optimizer/src/simplify_expressions/mod.rs      |   1 +
 .../src/simplify_expressions/simplify_exprs.rs     | 155 +++++++++++++-
 .../test_files/aggregates_simplify.slt             |  66 +++---
 datafusion/sqllogictest/test_files/clickbench.slt  |  15 +-
 9 files changed, 574 insertions(+), 44 deletions(-)

diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 5c6acd480e..12c879a515 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -600,7 +600,7 @@ impl Alias {
     }
 }
 
-/// Binary expression
+/// Binary expression for [`Expr::BinaryExpr`]
 #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
 pub struct BinaryExpr {
     /// Left-hand side of the expression
diff --git a/datafusion/expr/src/logical_plan/plan.rs 
b/datafusion/expr/src/logical_plan/plan.rs
index fe8a8dd870..b2a5697183 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -3490,7 +3490,9 @@ pub struct Aggregate {
     pub input: Arc<LogicalPlan>,
     /// Grouping expressions
     pub group_expr: Vec<Expr>,
-    /// Aggregate expressions
+    /// Aggregate expressions.
+    ///
+    /// Note these *must* be either [`Expr::AggregateFunction`] or 
[`Expr::Alias`]
     pub aggr_expr: Vec<Expr>,
     /// The schema description of the aggregate output
     pub schema: DFSchemaRef,
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index f2e2c53cdb..245a80c02c 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -28,6 +28,7 @@ use arrow::datatypes::{DataType, Field, FieldRef};
 
 use datafusion_common::{Result, ScalarValue, Statistics, exec_err, 
not_impl_err};
 use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
+use datafusion_expr_common::operator::Operator;
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
 
 use crate::expr::{
@@ -301,6 +302,21 @@ impl AggregateUDF {
         self.inner.simplify()
     }
 
+    /// Rewrite the aggregate to have simpler arguments, if the wrapped
+    /// implementation supports it (`Ok(None)` means "not rewritten").
+    ///
+    /// See [`AggregateUDFImpl::simplify_expr_op_literal`] for more details
+    pub fn simplify_expr_op_literal(
+        &self,
+        agg_function: &AggregateFunction,
+        arg: &Expr,
+        op: Operator,
+        lit: &Expr,
+        arg_is_left: bool,
+    ) -> Result<Option<Expr>> {
+        self.inner
+            .simplify_expr_op_literal(agg_function, arg, op, lit, arg_is_left)
+    }
+    }
+
     /// Returns true if the function is max, false if the function is min
     /// None in all other cases, used in certain optimizations for
     /// or aggregate
@@ -691,6 +707,74 @@ pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send 
+ Sync {
         None
     }
 
+    /// Rewrite the aggregate to have simpler arguments
+    ///
+    /// This query pattern is not common in most real workloads, and most
+    /// aggregate implementations can safely ignore it. This API is included in
+    /// DataFusion because it is important for ClickBench Q29. See backstory
+    /// on <https://github.com/apache/datafusion/issues/15524>
+    ///
+    /// # Rewrite Overview
+    ///
+    /// The idea is to rewrite multiple aggregates with "complex arguments" into
+    /// ones with simpler arguments that can be optimized by common subexpression
+    /// elimination (CSE). At a high level the rewrite looks like
+    ///
+    /// * `Aggregate(SUM(x + 1), SUM(x + 2), ...)`
+    ///
+    /// Into
+    ///
+    /// * `Aggregate(SUM(x) + 1 * COUNT(x), SUM(x) + 2 * COUNT(x), ...)`
+    ///
+    /// While this rewrite may seem worse (slower) than the original as it
+    /// computes *more* aggregate expressions, the common subexpression
+    /// elimination (CSE) can then reduce the number of distinct aggregates the
+    /// query actually needs to compute with a rewrite like
+    ///
+    /// * `Projection(_A + 1*_B, _A + 2*_B)`
+    /// * `  Aggregate(_A = SUM(x), _B = COUNT(x))`
+    ///
+    /// This optimization is extremely important for ClickBench Q29, which has 90
+    /// such expressions, and so this optimization results in only two
+    /// aggregates being needed. The DataFusion optimizer will invoke this
+    /// method when it detects multiple aggregates in a query whose arguments
+    /// have the form `<arg> <op> <literal>` (or `<literal> <op> <arg>`) with the
+    /// same `<arg>`.
+    ///
+    /// # API
+    ///
+    /// If `agg_function` supports the rewrite, it should return a semantically
+    /// equivalent expression (likely with more aggregate expressions, but
+    /// simpler arguments). Returning `Ok(None)` leaves the aggregate unchanged.
+    ///
+    /// This is only called when:
+    /// 1. There are no "special" aggregate params (`DISTINCT`, `FILTER`,
+    ///    `ORDER BY`, null handling)
+    /// 2. The aggregate function has exactly one [`Expr`] argument
+    /// 3. That argument is not volatile
+    ///
+    /// Arguments
+    /// * `agg_function`: the original aggregate function detected with complex
+    ///   arguments.
+    /// * `arg`: The common argument shared across multiple aggregates (e.g. `x`
+    ///   in the example above)
+    /// * `op`: the operator between the common argument and the literal (e.g.
+    ///   `+` in `x + 1` or `1 + x`)
+    /// * `lit`: the literal argument (e.g. `1` or `2` in the example above).
+    ///   The optimizer's rewrite always passes an [`Expr::Literal`] here.
+    /// * `arg_is_left`: whether the common argument is on the left or right of
+    ///   the operator (e.g. `true` for `x + 1` and `false` for `1 + x`)
+    ///
+    /// The default implementation returns `Ok(None)`, which is what most
+    /// aggregates should do.
+    fn simplify_expr_op_literal(
+        &self,
+        _agg_function: &AggregateFunction,
+        _arg: &Expr,
+        _op: Operator,
+        _lit: &Expr,
+        _arg_is_left: bool,
+    ) -> Result<Option<Expr>> {
+        Ok(None)
+    }
+
     /// Returns the reverse expression of the aggregate function.
     fn reverse_expr(&self) -> ReversedUDAF {
         ReversedUDAF::NotSupported
@@ -1243,6 +1327,18 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
         self.inner.simplify()
     }
 
+    // Forward to the wrapped implementation; without this override the trait
+    // default (`Ok(None)`) would silently disable the rewrite for aliased UDAFs.
+    fn simplify_expr_op_literal(
+        &self,
+        agg_function: &AggregateFunction,
+        arg: &Expr,
+        op: Operator,
+        lit: &Expr,
+        arg_is_left: bool,
+    ) -> Result<Option<Expr>> {
+        self.inner
+            .simplify_expr_op_literal(agg_function, arg, op, lit, arg_is_left)
+    }
+
     fn reverse_expr(&self) -> ReversedUDAF {
         self.inner.reverse_expr()
     }
diff --git a/datafusion/functions-aggregate/src/sum.rs 
b/datafusion/functions-aggregate/src/sum.rs
index 198ba54adf..5cced80d99 100644
--- a/datafusion/functions-aggregate/src/sum.rs
+++ b/datafusion/functions-aggregate/src/sum.rs
@@ -27,17 +27,20 @@ use arrow::datatypes::{
     DurationMillisecondType, DurationNanosecondType, DurationSecondType, 
FieldRef,
     Float64Type, Int64Type, TimeUnit, UInt64Type,
 };
+use datafusion_common::internal_err;
 use datafusion_common::types::{
     NativeType, logical_float64, logical_int8, logical_int16, logical_int32,
     logical_int64, logical_uint8, logical_uint16, logical_uint32, 
logical_uint64,
 };
 use datafusion_common::{HashMap, Result, ScalarValue, exec_err, not_impl_err};
+use datafusion_expr::expr::AggregateFunction;
+use datafusion_expr::expr_fn::cast;
 use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
 use datafusion_expr::utils::{AggregateOrderSensitivity, format_state_name};
 use datafusion_expr::{
     Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, 
GroupsAccumulator,
-    ReversedUDAF, SetMonotonicity, Signature, TypeSignature, 
TypeSignatureClass,
-    Volatility,
+    Operator, ReversedUDAF, SetMonotonicity, Signature, TypeSignature,
+    TypeSignatureClass, Volatility,
 };
 use 
datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
 use 
datafusion_functions_aggregate_common::aggregate::sum_distinct::DistinctSumAccumulator;
@@ -54,7 +57,7 @@ make_udaf_expr_and_func!(
 );
 
 pub fn sum_distinct(expr: Expr) -> Expr {
-    Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
+    Expr::AggregateFunction(AggregateFunction::new_udf(
         sum_udaf(),
         vec![expr],
         true,
@@ -346,6 +349,47 @@ impl AggregateUDFImpl for Sum {
             _ => SetMonotonicity::NotMonotonic,
         }
     }
+
+    /// Implement ClickBench Q29 specific optimization:
+    /// `SUM(arg + constant)` --> `SUM(arg) + constant * COUNT(arg)`
+    ///
+    /// The rewrite is only applied for integer and floating point literals.
+    /// `COUNT(arg)` is cast to the widest type of the literal's family
+    /// (`Int64`, `UInt64` or `Float64`, the types `SUM` accumulates in) rather
+    /// than to the literal's own type: casting the row count to a narrow type
+    /// (e.g. `Int8`, `Float32` or a small-precision `Decimal128`) can overflow
+    /// or lose precision, and some types (e.g. durations) may not support
+    /// multiplication by a value of the same type at all.
+    ///
+    /// See background on [`AggregateUDFImpl::simplify_expr_op_literal`]
+    fn simplify_expr_op_literal(
+        &self,
+        agg_function: &AggregateFunction,
+        arg: &Expr,
+        op: Operator,
+        lit: &Expr,
+        // Only support '+' so the order of the args doesn't matter
+        _arg_is_left: bool,
+    ) -> Result<Option<Expr>> {
+        if op != Operator::Plus {
+            return Ok(None);
+        }
+
+        let lit_type = match lit {
+            Expr::Literal(value, _) => value.data_type(),
+            _ => {
+                return internal_err!(
+                    "Sum::simplify_expr_op_literal got a non literal argument"
+                );
+            }
+        };
+
+        // Pick a type for COUNT(arg) that can hold any row count and that can
+        // be multiplied with the literal without narrowing. Anything else
+        // (NULL, decimals, durations, ...) is left alone: not rewriting is
+        // always correct.
+        let count_type = match lit_type {
+            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
+                DataType::Int64
+            }
+            DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
+                DataType::UInt64
+            }
+            DataType::Float16 | DataType::Float32 | DataType::Float64 => {
+                DataType::Float64
+            }
+            _ => return Ok(None),
+        };
+
+        // Build up SUM(arg)
+        let mut sum_agg = agg_function.clone();
+        sum_agg.params.args = vec![arg.clone()];
+        let sum_agg = Expr::AggregateFunction(sum_agg);
+
+        // COUNT(arg), cast to the widened type chosen above
+        let count_agg = cast(crate::count::count(arg.clone()), count_type);
+
+        // SUM(arg) + lit * COUNT(arg)
+        Ok(Some(sum_agg + (lit.clone() * count_agg)))
+    }
 }
 
 /// This accumulator computes SUM incrementally
diff --git a/datafusion/optimizer/src/simplify_expressions/linear_aggregates.rs 
b/datafusion/optimizer/src/simplify_expressions/linear_aggregates.rs
new file mode 100644
index 0000000000..21389cf326
--- /dev/null
+++ b/datafusion/optimizer/src/simplify_expressions/linear_aggregates.rs
@@ -0,0 +1,229 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Simplification to refactor multiple aggregate functions to use the same 
aggregate function
+
+use datafusion_common::HashMap;
+use datafusion_expr::expr::AggregateFunctionParams;
+use datafusion_expr::{BinaryExpr, Expr};
+use datafusion_expr_common::operator::Operator;
+
+/// Minimum number of candidate aggregates that must share the same `<arg>`
+/// (in `<arg> <op> <lit>` or `<lit> <op> <arg>` form) before the rewrite is
+/// attempted.
+///
+/// There is a threshold because the canonical SUM rewrite described in
+/// [`AggregateUDFImpl::simplify_expr_op_literal`] actually results in more
+/// aggregates (2) for each original aggregate. It is only a win if CSE then
+/// eliminates the duplicates.
+///
+/// [`AggregateUDFImpl::simplify_expr_op_literal`]: datafusion_expr::AggregateUDFImpl::simplify_expr_op_literal
+const DUPLICATE_THRESHOLD: usize = 2;
+
+/// Rewrites multiple aggregate expressions that have a common linear component
+/// into multiple aggregate expressions that share that common component.
+///
+/// For example, rewrites patterns such as
+/// * `SUM(x + 1), SUM(x + 2), ...`
+///
+/// Into
+/// * `SUM(x) + 1 * COUNT(x), SUM(x) + 2 * COUNT(x), ...`
+///
+/// See the background [`AggregateUDFImpl::simplify_expr_op_literal`] for details.
+///
+/// Rewritten entries of `agg_expr` are aliased to their original names, so the
+/// output names are preserved. They are generally no longer bare
+/// [`Expr::AggregateFunction`]s, so the caller must restore the `Aggregate`
+/// invariant (e.g. by introducing a `Projection`).
+///
+/// Returns `true` if any of the arguments are rewritten (modified), `false`
+/// otherwise.
+///
+/// Note that candidates are grouped by `<arg>` only, not by aggregate
+/// function: e.g. `SUM(x + 1), MIN(x + 2)` meets the threshold, so the `SUM`
+/// is still offered the rewrite even though no other `SUM` shares its argument.
+///
+/// ## Design goals:
+/// 1. Keep the aggregate specific logic out of the optimizer (can't depend directly on SUM)
+/// 2. Optimize for the case that this rewrite will not apply (it almost never does)
+///
+/// [`AggregateUDFImpl::simplify_expr_op_literal`]: datafusion_expr::AggregateUDFImpl::simplify_expr_op_literal
+pub(super) fn rewrite_multiple_linear_aggregates(
+    agg_expr: &mut [Expr],
+) -> datafusion_common::Result<bool> {
+    // map <arg expr> -> number of candidate aggregates sharing that argument
+    let mut common_args = HashMap::new();
+
+    // First pass -- figure out any aggregates that can be split and have common
+    // expressions.
+    for agg in agg_expr.iter() {
+        let Expr::AggregateFunction(agg_function) = agg else {
+            continue;
+        };
+
+        let Some(arg) = candidate_linear_param(&agg_function.params) else {
+            continue;
+        };
+
+        let Some(expr_literal) = ExprLiteral::try_new(arg) else {
+            continue;
+        };
+
+        let counter = common_args.entry(expr_literal.expr()).or_insert(0);
+        *counter += 1;
+    }
+
+    // (agg_index, new_expr)
+    let mut new_aggs = vec![];
+
+    // Second pass, actually rewrite any aggregates that have a common
+    // expression and enough duplicates.
+    for (idx, agg) in agg_expr.iter().enumerate() {
+        let Expr::AggregateFunction(agg_function) = agg else {
+            continue;
+        };
+
+        let Some(arg) = candidate_linear_param(&agg_function.params) else {
+            continue;
+        };
+
+        let Some(expr_literal) = ExprLiteral::try_new(arg) else {
+            continue;
+        };
+
+        // Not enough common expressions to make it worth rewriting
+        if common_args.get(expr_literal.expr()).unwrap_or(&0) < &DUPLICATE_THRESHOLD {
+            continue;
+        }
+
+        // The UDAF decides whether (and how) it can be rewritten; `None`
+        // leaves this aggregate untouched.
+        if let Some(new_agg_function) = agg_function.func.simplify_expr_op_literal(
+            agg_function,
+            expr_literal.expr(),
+            expr_literal.op(),
+            expr_literal.lit(),
+            expr_literal.arg_is_left(),
+        )? {
+            new_aggs.push((idx, new_agg_function));
+        }
+    }
+
+    if new_aggs.is_empty() {
+        return Ok(false);
+    }
+
+    // Otherwise replace the aggregate expressions
+    drop(common_args); // keys borrow from `agg_expr`; release before mutating it
+    for (idx, new_agg) in new_aggs {
+        // Keep the original output name so references to it still resolve
+        let orig_name = agg_expr[idx].name_for_alias()?;
+        agg_expr[idx] = new_agg.alias_if_changed(orig_name)?
+    }
+
+    Ok(true)
+}
+
+/// Returns `Some(&Expr)` with the single argument if this is a suitable
+/// candidate for the linear rewrite: exactly one non-volatile argument and no
+/// `DISTINCT`, `FILTER`, `ORDER BY` or null treatment.
+fn candidate_linear_param(params: &AggregateFunctionParams) -> Option<&Expr> {
+    // Explicitly destructure to ensure we check all relevant fields
+    let AggregateFunctionParams {
+        args,
+        distinct,
+        filter,
+        order_by,
+        null_treatment,
+    } = params;
+
+    // Disqualify anything "non standard"
+    if *distinct
+        || filter.is_some()
+        || !order_by.is_empty()
+        || null_treatment.is_some()
+        || args.len() != 1
+    {
+        return None;
+    }
+    let arg = args.first()?;
+    // Volatile arguments (e.g. random()) cannot be split into SUM and COUNT
+    // without changing results
+    if arg.is_volatile() {
+        return None;
+    };
+    Some(arg)
+}
+
+/// A view into an [`Expr::BinaryExpr`] where one side is an arbitrary
+/// expression and the other side is an [`Expr::Literal`]
+///
+/// This is an enum to distinguish the direction of the operator arguments
+#[derive(Debug, Clone)]
+pub enum ExprLiteral<'a> {
+    /// if the expression is `<arg> <op> <lit>`
+    ArgOpLit {
+        arg: &'a Expr,
+        op: Operator,
+        lit: &'a Expr,
+    },
+    /// if the expression is `<lit> <op> <arg>`
+    LitOpArg {
+        lit: &'a Expr,
+        op: Operator,
+        arg: &'a Expr,
+    },
+}
+
+impl<'a> ExprLiteral<'a> {
+    /// Try and split the Expr into its parts
+    ///
+    /// Returns `None` unless `expr` is a binary expression with a literal on at
+    /// least one side. If both sides are literals, the left one is treated as
+    /// the literal (`LitOpArg`).
+    fn try_new(expr: &'a Expr) -> Option<Self> {
+        match expr {
+            // <lit> <op> <expr>
+            Expr::BinaryExpr(BinaryExpr { left, op, right })
+                if matches!(left.as_ref(), Expr::Literal(..)) =>
+            {
+                Some(Self::LitOpArg {
+                    arg: right,
+                    lit: left,
+                    op: *op,
+                })
+            }
+
+            // <expr> <op> <lit>
+            Expr::BinaryExpr(BinaryExpr { left, op, right })
+                if matches!(right.as_ref(), Expr::Literal(..)) =>
+            {
+                Some(Self::ArgOpLit {
+                    arg: left,
+                    lit: right,
+                    op: *op,
+                })
+            }
+            _ => None,
+        }
+    }
+
+    /// The non-literal side of the expression (`<arg>`)
+    fn expr(&self) -> &'a Expr {
+        match self {
+            Self::ArgOpLit { arg, .. } => arg,
+            Self::LitOpArg { arg, .. } => arg,
+        }
+    }
+
+    /// The literal side of the expression; always an [`Expr::Literal`]
+    fn lit(&self) -> &'a Expr {
+        match self {
+            Self::ArgOpLit { lit, .. } => lit,
+            Self::LitOpArg { lit, .. } => lit,
+        }
+    }
+
+    /// The binary operator joining `<arg>` and `<lit>`
+    fn op(&self) -> Operator {
+        match self {
+            Self::ArgOpLit { op, .. } => *op,
+            Self::LitOpArg { op, .. } => *op,
+        }
+    }
+
+    /// `true` for `<arg> <op> <lit>`, `false` for `<lit> <op> <arg>`
+    fn arg_is_left(&self) -> bool {
+        matches!(self, Self::ArgOpLit { .. })
+    }
+}
diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs 
b/datafusion/optimizer/src/simplify_expressions/mod.rs
index b85b000821..89c79d3fb4 100644
--- a/datafusion/optimizer/src/simplify_expressions/mod.rs
+++ b/datafusion/optimizer/src/simplify_expressions/mod.rs
@@ -20,6 +20,7 @@
 
 pub mod expr_simplifier;
 mod inlist_simplifier;
+mod linear_aggregates;
 mod regex;
 pub mod simplify_exprs;
 pub mod simplify_literal;
diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs 
b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
index 2114c5ef3d..29ee593422 100644
--- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
+++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
@@ -20,18 +20,20 @@
 use std::sync::Arc;
 
 use datafusion_common::tree_node::{Transformed, TreeNode};
-use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
+use datafusion_common::{Column, DFSchema, DFSchemaRef, DataFusionError, 
Result};
 use datafusion_expr::Expr;
-use datafusion_expr::logical_plan::LogicalPlan;
+use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection};
 use datafusion_expr::simplify::SimplifyContext;
-use datafusion_expr::utils::merge_schema;
+use datafusion_expr::utils::{
+    columnize_expr, find_aggregate_exprs, grouping_set_to_exprlist, 
merge_schema,
+};
 
+use super::ExprSimplifier;
 use crate::optimizer::ApplyOrder;
+use 
crate::simplify_expressions::linear_aggregates::rewrite_multiple_linear_aggregates;
 use crate::utils::NamePreserver;
 use crate::{OptimizerConfig, OptimizerRule};
 
-use super::ExprSimplifier;
-
 /// Optimizer Pass that simplifies [`LogicalPlan`]s by rewriting
 /// [`Expr`]`s evaluating constants and applying algebraic
 /// simplifications
@@ -137,7 +139,8 @@ impl SimplifyExpressions {
             } else {
                 rewrite_expr(expr)
             }
-        })
+        })?
+        .transform_data(rewrite_aggregate_non_aggregate_aggr_expr)
     }
 }
 
@@ -148,6 +151,98 @@ impl SimplifyExpressions {
     }
 }
 
+/// Ensures that `LogicalPlan::Aggregate` is well formed after rewrites
+/// by potentially introducing an extra `Projection`.
+///
+/// Also applies the [`rewrite_multiple_linear_aggregates`] special case
+///
+/// # Rationale:
+///
+/// [`LogicalPlan::Aggregate`] requires agg expressions to be (possibly aliased)
+/// [`Expr::AggregateFunction`]. Some UDAF simplifiers may return other [`Expr`]
+/// variants.
+///
+/// # Operation
+///
+/// Rewrites things like this (note that `exp1` is not an aggregate):
+/// * `Aggregate(group_expr, aggr_expr=[exp1 + agg(exp2)])`
+///
+/// into:
+/// * `Projection(exp1 + _X)`
+/// * `  Aggregate(group_expr, aggr_expr=[agg(exp2) AS _X])`
+///
+/// Plans that are not `LogicalPlan::Aggregate` are returned unchanged.
+fn rewrite_aggregate_non_aggregate_aggr_expr(
+    plan: LogicalPlan,
+) -> Result<Transformed<LogicalPlan>> {
+    let LogicalPlan::Aggregate(Aggregate {
+        input,
+        group_expr,
+        mut aggr_expr,
+        schema,
+        ..
+    }) = plan
+    else {
+        return Ok(Transformed::no(plan));
+    };
+
+    let rewrote_aggs = rewrite_multiple_linear_aggregates(&mut aggr_expr)?;
+
+    // Fast path: every aggr expr is still a (possibly aliased) aggregate, so
+    // the node is rebuilt as-is, reusing its original schema.
+    if aggr_expr.iter().all(is_top_level_aggregate_expr) {
+        let new_plan = LogicalPlan::Aggregate(Aggregate::try_new_with_schema(
+            input, group_expr, aggr_expr, schema,
+        )?);
+        return if !rewrote_aggs {
+            Ok(Transformed::no(new_plan))
+        } else {
+            Ok(Transformed::yes(new_plan))
+        };
+    }
+
+    // Otherwise we need to add a Projection above Aggregate to calculate
+    // the final output expressions.
+
+    // The inner Aggregate only computes the aggregate function calls found
+    // inside the (possibly non-aggregate) aggr exprs.
+    let inner_aggr_expr = find_aggregate_exprs(aggr_expr.iter());
+    let inner_aggregate = LogicalPlan::Aggregate(Aggregate::try_new(
+        Arc::clone(&input),
+        group_expr.clone(),
+        inner_aggr_expr,
+    )?);
+    let inner_aggregate = Arc::new(inner_aggregate);
+
+    // Output the grouping columns first, then the original aggr exprs, each
+    // passed through `columnize_expr` against the inner Aggregate.
+    let mut projection_exprs = aggregate_output_exprs(&group_expr)?;
+    projection_exprs.extend(aggr_expr);
+    let projection_exprs = projection_exprs
+        .into_iter()
+        .map(|expr| columnize_expr(expr, inner_aggregate.as_ref()))
+        .collect::<Result<Vec<_>>>()?;
+
+    Ok(Transformed::yes(LogicalPlan::Projection(
+        Projection::try_new(projection_exprs, inner_aggregate)?,
+    )))
+}
+
+/// Returns `true` if `expr`, once aliases are removed, is an
+/// [`Expr::AggregateFunction`].
+///
+/// NOTE(review): this clones the whole expression on every call just to strip
+/// aliases; a borrowed check may be cheaper, but confirm it matches
+/// `unalias_nested` for every alias form before changing it.
+fn is_top_level_aggregate_expr(expr: &Expr) -> bool {
+    matches!(
+        expr.clone().unalias_nested().data,
+        Expr::AggregateFunction(_)
+    )
+}
+
+/// Returns the expressions an `Aggregate` outputs for `group_expr`: the
+/// flattened grouping expressions, plus the [`Aggregate::INTERNAL_GROUPING_ID`]
+/// column when `group_expr` is a single [`Expr::GroupingSet`].
+fn aggregate_output_exprs(group_expr: &[Expr]) -> Result<Vec<Expr>> {
+    let mut output_exprs = grouping_set_to_exprlist(group_expr)?
+        .into_iter()
+        .cloned()
+        .collect::<Vec<_>>();
+
+    if matches!(group_expr, [Expr::GroupingSet(_)]) {
+        output_exprs.push(Expr::Column(Column::from_name(
+            Aggregate::INTERNAL_GROUPING_ID,
+        )));
+    }
+
+    Ok(output_exprs)
+}
+
 #[cfg(test)]
 mod tests {
     use std::ops::Not;
@@ -159,7 +254,7 @@ mod tests {
     use datafusion_expr::logical_plan::builder::table_scan_with_filters;
     use datafusion_expr::logical_plan::table_scan;
     use datafusion_expr::*;
-    use datafusion_functions_aggregate::expr_fn::{max, min};
+    use datafusion_functions_aggregate::expr_fn::{max, min, sum};
 
     use crate::OptimizerContext;
     use crate::assert_optimized_plan_eq_snapshot;
@@ -259,6 +354,52 @@ mod tests {
         )
     }
 
+    // A single `SUM(a + 2)` does not reach the duplicate threshold, so the
+    // aggregate is expected to be left unchanged.
+    #[test]
+    fn test_simplify_udaf_to_non_aggregate_expr() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
+        let table_scan = table_scan(Some("test"), &schema, None)?
+            .build()
+            .expect("building scan");
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(Vec::<Expr>::new(), vec![sum(col("a") + lit(2i64))])?
+            .build()?;
+
+        assert_optimized_plan_equal!(
+            plan,
+            @r"
+        Aggregate: groupBy=[[]], aggr=[[sum(test.a + Int64(2))]]
+          TableScan: test
+        "
+        )?;
+        Ok(())
+    }
+
+    // Two SUMs over `a + <literal>` share `a`, so both are rewritten to
+    // `sum(a) + <literal> * count(a)` in a Projection above one Aggregate
+    // that computes only `sum(a)` and `count(a)`.
+    #[test]
+    fn test_simplify_common_sum_arg() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
+        let table_scan = table_scan(Some("test"), &schema, None)?
+            .build()
+            .expect("building scan");
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(
+                Vec::<Expr>::new(),
+                vec![sum(col("a") + lit(2i64)), sum(col("a") + lit(3i64))],
+            )?
+            .build()?;
+
+        assert_optimized_plan_equal!(
+            plan,
+            @r"
+        Projection: sum(test.a) + Int64(2) * CAST(count(test.a) AS Int64) AS 
sum(test.a + Int64(2)), sum(test.a) + Int64(3) * CAST(count(test.a) AS Int64) 
AS sum(test.a + Int64(3))
+          Aggregate: groupBy=[[]], aggr=[[sum(test.a), count(test.a)]]
+            TableScan: test
+        "
+        )?;
+        Ok(())
+    }
+
     #[test]
     fn test_simplify_optimized_plan_with_or() -> Result<()> {
         let table_scan = test_table_scan();
diff --git a/datafusion/sqllogictest/test_files/aggregates_simplify.slt 
b/datafusion/sqllogictest/test_files/aggregates_simplify.slt
index cc2e40540b..9aa3ecf7a2 100644
--- a/datafusion/sqllogictest/test_files/aggregates_simplify.slt
+++ b/datafusion/sqllogictest/test_files/aggregates_simplify.slt
@@ -106,11 +106,14 @@ query TT
 EXPLAIN SELECT SUM(column1 + 1), SUM(column1 + 2) FROM sum_simplify_t;
 ----
 logical_plan
-01)Aggregate: groupBy=[[]], aggr=[[sum(sum_simplify_t.column1 + Int64(1)), 
sum(sum_simplify_t.column1 + Int64(2))]]
-02)--TableScan: sum_simplify_t projection=[column1]
+01)Projection: sum(sum_simplify_t.column1) + __common_expr_1 AS 
sum(sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1) + Int64(2) 
* __common_expr_1 AS sum(sum_simplify_t.column1 + Int64(2))
+02)--Projection: CAST(count(sum_simplify_t.column1) AS Int64) AS 
__common_expr_1, sum(sum_simplify_t.column1)
+03)----Aggregate: groupBy=[[]], aggr=[[sum(sum_simplify_t.column1), 
count(sum_simplify_t.column1)]]
+04)------TableScan: sum_simplify_t projection=[column1]
 physical_plan
-01)AggregateExec: mode=Single, gby=[], aggr=[sum(sum_simplify_t.column1 + 
Int64(1)), sum(sum_simplify_t.column1 + Int64(2))]
-02)--DataSourceExec: partitions=1, partition_sizes=[1]
+01)ProjectionExec: expr=[sum(sum_simplify_t.column1)@0 + 
count(sum_simplify_t.column1)@1 as sum(sum_simplify_t.column1 + Int64(1)), 
sum(sum_simplify_t.column1)@0 + 2 * count(sum_simplify_t.column1)@1 as 
sum(sum_simplify_t.column1 + Int64(2))]
+02)--AggregateExec: mode=Single, gby=[], aggr=[sum(sum_simplify_t.column1), 
count(sum_simplify_t.column1)]
+03)----DataSourceExec: partitions=1, partition_sizes=[1]
 
 # Reordered expressions that still compute the same thing
 query II
@@ -122,11 +125,14 @@ query TT
 EXPLAIN SELECT SUM(1 + column1), SUM(column1 + 2) FROM sum_simplify_t;
 ----
 logical_plan
-01)Aggregate: groupBy=[[]], aggr=[[sum(Int64(1) + sum_simplify_t.column1), 
sum(sum_simplify_t.column1 + Int64(2))]]
-02)--TableScan: sum_simplify_t projection=[column1]
+01)Projection: sum(sum_simplify_t.column1) + __common_expr_1 AS sum(Int64(1) + 
sum_simplify_t.column1), sum(sum_simplify_t.column1) + Int64(2) * 
__common_expr_1 AS sum(sum_simplify_t.column1 + Int64(2))
+02)--Projection: CAST(count(sum_simplify_t.column1) AS Int64) AS 
__common_expr_1, sum(sum_simplify_t.column1)
+03)----Aggregate: groupBy=[[]], aggr=[[sum(sum_simplify_t.column1), 
count(sum_simplify_t.column1)]]
+04)------TableScan: sum_simplify_t projection=[column1]
 physical_plan
-01)AggregateExec: mode=Single, gby=[], aggr=[sum(Int64(1) + 
sum_simplify_t.column1), sum(sum_simplify_t.column1 + Int64(2))]
-02)--DataSourceExec: partitions=1, partition_sizes=[1]
+01)ProjectionExec: expr=[sum(sum_simplify_t.column1)@0 + 
count(sum_simplify_t.column1)@1 as sum(Int64(1) + sum_simplify_t.column1), 
sum(sum_simplify_t.column1)@0 + 2 * count(sum_simplify_t.column1)@1 as 
sum(sum_simplify_t.column1 + Int64(2))]
+02)--AggregateExec: mode=Single, gby=[], aggr=[sum(sum_simplify_t.column1), 
count(sum_simplify_t.column1)]
+03)----DataSourceExec: partitions=1, partition_sizes=[1]
 
 # DISTINCT aggregates with different arguments
 query II
@@ -259,15 +265,18 @@ EXPLAIN SELECT column2, SUM(column1 + 1), SUM(column1 + 
2) FROM sum_simplify_t G
 ----
 logical_plan
 01)Sort: sum_simplify_t.column2 DESC NULLS LAST
-02)--Aggregate: groupBy=[[sum_simplify_t.column2]], 
aggr=[[sum(sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1 + 
Int64(2))]]
-03)----TableScan: sum_simplify_t projection=[column1, column2]
+02)--Projection: sum_simplify_t.column2, sum(sum_simplify_t.column1) + 
__common_expr_1 AS sum(sum_simplify_t.column1 + Int64(1)), 
sum(sum_simplify_t.column1) + Int64(2) * __common_expr_1 AS 
sum(sum_simplify_t.column1 + Int64(2))
+03)----Projection: CAST(count(sum_simplify_t.column1) AS Int64) AS 
__common_expr_1, sum_simplify_t.column2, sum(sum_simplify_t.column1)
+04)------Aggregate: groupBy=[[sum_simplify_t.column2]], 
aggr=[[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]]
+05)--------TableScan: sum_simplify_t projection=[column1, column2]
 physical_plan
 01)SortPreservingMergeExec: [column2@0 DESC NULLS LAST]
 02)--SortExec: expr=[column2@0 DESC NULLS LAST], preserve_partitioning=[true]
-03)----AggregateExec: mode=FinalPartitioned, gby=[column2@0 as column2], 
aggr=[sum(sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1 + 
Int64(2))]
-04)------RepartitionExec: partitioning=Hash([column2@0], 4), input_partitions=1
-05)--------AggregateExec: mode=Partial, gby=[column2@1 as column2], 
aggr=[sum(sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1 + 
Int64(2))]
-06)----------DataSourceExec: partitions=1, partition_sizes=[1]
+03)----ProjectionExec: expr=[column2@0 as column2, 
sum(sum_simplify_t.column1)@1 + count(sum_simplify_t.column1)@2 as 
sum(sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1)@1 + 2 * 
count(sum_simplify_t.column1)@2 as sum(sum_simplify_t.column1 + Int64(2))]
+04)------AggregateExec: mode=FinalPartitioned, gby=[column2@0 as column2], 
aggr=[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]
+05)--------RepartitionExec: partitioning=Hash([column2@0], 4), 
input_partitions=1
+06)----------AggregateExec: mode=Partial, gby=[column2@1 as column2], 
aggr=[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]
+07)------------DataSourceExec: partitions=1, partition_sizes=[1]
 
 # Checks commutative forms of equivalent aggregate arguments are simplified 
consistently.
 query II
@@ -279,13 +288,15 @@ query TT
 EXPLAIN SELECT SUM(1 + column1), SUM(column1 + 1) FROM sum_simplify_t;
 ----
 logical_plan
-01)Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS Int64(1) + 
sum_simplify_t.column1), sum(__common_expr_1 AS sum_simplify_t.column1 + 
Int64(1))]]
-02)--Projection: Int64(1) + sum_simplify_t.column1 AS __common_expr_1
-03)----TableScan: sum_simplify_t projection=[column1]
+01)Projection: __common_expr_1 AS sum(Int64(1) + sum_simplify_t.column1), 
__common_expr_1 AS sum(sum_simplify_t.column1 + Int64(1))
+02)--Projection: sum(sum_simplify_t.column1) + 
CAST(count(sum_simplify_t.column1) AS Int64) AS __common_expr_1
+03)----Aggregate: groupBy=[[]], aggr=[[sum(sum_simplify_t.column1), 
count(sum_simplify_t.column1)]]
+04)------TableScan: sum_simplify_t projection=[column1]
 physical_plan
-01)AggregateExec: mode=Single, gby=[], aggr=[sum(Int64(1) + 
sum_simplify_t.column1), sum(sum_simplify_t.column1 + Int64(1))]
-02)--ProjectionExec: expr=[1 + column1@0 as __common_expr_1]
-03)----DataSourceExec: partitions=1, partition_sizes=[1]
+01)ProjectionExec: expr=[__common_expr_1@0 as sum(Int64(1) + 
sum_simplify_t.column1), __common_expr_1@0 as sum(sum_simplify_t.column1 + 
Int64(1))]
+02)--ProjectionExec: expr=[sum(sum_simplify_t.column1)@0 + 
count(sum_simplify_t.column1)@1 as __common_expr_1]
+03)----AggregateExec: mode=Single, gby=[], aggr=[sum(sum_simplify_t.column1), 
count(sum_simplify_t.column1)]
+04)------DataSourceExec: partitions=1, partition_sizes=[1]
 
 # Checks unsigned overflow edge case from PR discussion using transformed SUM 
arguments.
 statement ok
@@ -308,14 +319,17 @@ EXPLAIN SELECT arrow_typeof(SUM(val + 1)), SUM(val + 1), 
SUM(val + 2) FROM tbl;
 ----
 logical_plan
 01)Projection: arrow_typeof(sum(tbl.val + Int64(1))), sum(tbl.val + Int64(1)), 
sum(tbl.val + Int64(2))
-02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS tbl.val + 
Int64(1)), sum(__common_expr_1 AS tbl.val + Int64(2))]]
-03)----Projection: CAST(tbl.val AS Int64) AS __common_expr_1
-04)------TableScan: tbl projection=[val]
+02)--Projection: sum(tbl.val) + __common_expr_1 AS sum(tbl.val + Int64(1)), 
sum(tbl.val) + Int64(2) * __common_expr_1 AS sum(tbl.val + Int64(2))
+03)----Projection: CAST(count(tbl.val) AS Int64) AS __common_expr_1, 
sum(tbl.val)
+04)------Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_2 AS tbl.val), 
count(__common_expr_2 AS tbl.val)]]
+05)--------Projection: CAST(tbl.val AS Int64) AS __common_expr_2
+06)----------TableScan: tbl projection=[val]
 physical_plan
 01)ProjectionExec: expr=[arrow_typeof(sum(tbl.val + Int64(1))@0) as 
arrow_typeof(sum(tbl.val + Int64(1))), sum(tbl.val + Int64(1))@0 as sum(tbl.val 
+ Int64(1)), sum(tbl.val + Int64(2))@1 as sum(tbl.val + Int64(2))]
-02)--AggregateExec: mode=Single, gby=[], aggr=[sum(tbl.val + Int64(1)), 
sum(tbl.val + Int64(2))]
-03)----ProjectionExec: expr=[CAST(val@0 AS Int64) as __common_expr_1]
-04)------DataSourceExec: partitions=1, partition_sizes=[2]
+02)--ProjectionExec: expr=[sum(tbl.val)@0 + count(tbl.val)@1 as sum(tbl.val + 
Int64(1)), sum(tbl.val)@0 + 2 * count(tbl.val)@1 as sum(tbl.val + Int64(2))]
+03)----AggregateExec: mode=Single, gby=[], aggr=[sum(tbl.val), count(tbl.val)]
+04)------ProjectionExec: expr=[CAST(val@0 AS Int64) as __common_expr_2]
+05)--------DataSourceExec: partitions=1, partition_sizes=[2]
 
 # Checks equivalent rewritten form (SUM + COUNT terms) matches transformed SUM 
semantics.
 query RR
diff --git a/datafusion/sqllogictest/test_files/clickbench.slt 
b/datafusion/sqllogictest/test_files/clickbench.slt
index e14d28d5ef..42f066a80d 100644
--- a/datafusion/sqllogictest/test_files/clickbench.slt
+++ b/datafusion/sqllogictest/test_files/clickbench.slt
@@ -787,13 +787,16 @@ query TT
 EXPLAIN SELECT SUM("ResolutionWidth"), SUM("ResolutionWidth" + 1), 
SUM("ResolutionWidth" + 2), SUM("ResolutionWidth" + 3), SUM("ResolutionWidth" + 
4), SUM("ResolutionWidth" + 5), SUM("ResolutionWidth" + 6), 
SUM("ResolutionWidth" + 7), SUM("ResolutionWidth" + 8), SUM("ResolutionWidth" + 
9), SUM("ResolutionWidth" + 10), SUM("ResolutionWidth" + 11), 
SUM("ResolutionWidth" + 12), SUM("ResolutionWidth" + 13), SUM("ResolutionWidth" 
+ 14), SUM("ResolutionWidth" + 15), SUM("ResolutionWidth" + 16) [...]
 ----
 logical_plan
-01)Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS 
hits.ResolutionWidth), sum(__common_expr_1 AS hits.ResolutionWidth + Int64(1)), 
sum(__common_expr_1 AS hits.ResolutionWidth + Int64(2)), sum(__common_expr_1 AS 
hits.ResolutionWidth + Int64(3)), sum(__common_expr_1 AS hits.ResolutionWidth + 
Int64(4)), sum(__common_expr_1 AS hits.ResolutionWidth + Int64(5)), 
sum(__common_expr_1 AS hits.ResolutionWidth + Int64(6)), sum(__common_expr_1 AS 
hits.ResolutionWidth + Int64(7)), sum(__common [...]
-02)--Projection: CAST(hits.ResolutionWidth AS Int64) AS __common_expr_1
-03)----SubqueryAlias: hits
-04)------TableScan: hits_raw projection=[ResolutionWidth]
+01)Projection: sum(hits.ResolutionWidth), sum(hits.ResolutionWidth) + 
__common_expr_1 AS sum(hits.ResolutionWidth + Int64(1)), 
sum(hits.ResolutionWidth) + Int64(2) * __common_expr_1 AS 
sum(hits.ResolutionWidth + Int64(2)), sum(hits.ResolutionWidth) + Int64(3) * 
__common_expr_1 AS sum(hits.ResolutionWidth + Int64(3)), 
sum(hits.ResolutionWidth) + Int64(4) * __common_expr_1 AS 
sum(hits.ResolutionWidth + Int64(4)), sum(hits.ResolutionWidth) + Int64(5) * 
__common_expr_1 AS sum(hits.Resolution [...]
+02)--Projection: CAST(count(hits.ResolutionWidth) AS Int64) AS 
__common_expr_1, sum(hits.ResolutionWidth)
+03)----Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_2 AS 
hits.ResolutionWidth), count(__common_expr_2 AS hits.ResolutionWidth)]]
+04)------Projection: CAST(hits.ResolutionWidth AS Int64) AS __common_expr_2
+05)--------SubqueryAlias: hits
+06)----------TableScan: hits_raw projection=[ResolutionWidth]
 physical_plan
-01)AggregateExec: mode=Single, gby=[], aggr=[sum(hits.ResolutionWidth), 
sum(hits.ResolutionWidth + Int64(1)), sum(hits.ResolutionWidth + Int64(2)), 
sum(hits.ResolutionWidth + Int64(3)), sum(hits.ResolutionWidth + Int64(4)), 
sum(hits.ResolutionWidth + Int64(5)), sum(hits.ResolutionWidth + Int64(6)), 
sum(hits.ResolutionWidth + Int64(7)), sum(hits.ResolutionWidth + Int64(8)), 
sum(hits.ResolutionWidth + Int64(9)), sum(hits.ResolutionWidth + Int64(10)), 
sum(hits.ResolutionWidth + Int64(11)),  [...]
-02)--DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, 
projection=[CAST(ResolutionWidth@20 AS Int64) as __common_expr_1], 
file_type=parquet
+01)ProjectionExec: expr=[sum(hits.ResolutionWidth)@0 as 
sum(hits.ResolutionWidth), sum(hits.ResolutionWidth)@0 + 
count(hits.ResolutionWidth)@1 as sum(hits.ResolutionWidth + Int64(1)), 
sum(hits.ResolutionWidth)@0 + 2 * count(hits.ResolutionWidth)@1 as 
sum(hits.ResolutionWidth + Int64(2)), sum(hits.ResolutionWidth)@0 + 3 * 
count(hits.ResolutionWidth)@1 as sum(hits.ResolutionWidth + Int64(3)), 
sum(hits.ResolutionWidth)@0 + 4 * count(hits.ResolutionWidth)@1 as 
sum(hits.ResolutionWidth + Int6 [...]
+02)--AggregateExec: mode=Single, gby=[], aggr=[sum(hits.ResolutionWidth), 
count(hits.ResolutionWidth)]
+03)----DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, 
projection=[CAST(ResolutionWidth@20 AS Int64) as __common_expr_2], 
file_type=parquet
 
 query 
IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII
 SELECT SUM("ResolutionWidth"), SUM("ResolutionWidth" + 1), 
SUM("ResolutionWidth" + 2), SUM("ResolutionWidth" + 3), SUM("ResolutionWidth" + 
4), SUM("ResolutionWidth" + 5), SUM("ResolutionWidth" + 6), 
SUM("ResolutionWidth" + 7), SUM("ResolutionWidth" + 8), SUM("ResolutionWidth" + 
9), SUM("ResolutionWidth" + 10), SUM("ResolutionWidth" + 11), 
SUM("ResolutionWidth" + 12), SUM("ResolutionWidth" + 13), SUM("ResolutionWidth" 
+ 14), SUM("ResolutionWidth" + 15), SUM("ResolutionWidth" + 16), SUM("R [...]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to