This is an automated email from the ASF dual-hosted git repository.
github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 6d3a846483 Rewrite `SUM(expr + scalar)` --> `SUM(expr) + scalar*COUNT(expr)` (#20749)
6d3a846483 is described below
commit 6d3a8464836eb59605cc59579b6eb3af8c4788bf
Author: Andrew Lamb <[email protected]>
AuthorDate: Sat Mar 14 08:22:01 2026 -0400
Rewrite `SUM(expr + scalar)` --> `SUM(expr) + scalar*COUNT(expr)` (#20749)
## Which issue does this PR close?
- Part of #18489
- Closes https://github.com/apache/datafusion/pull/20180
- Closes https://github.com/apache/datafusion/issues/15524
- Replaces https://github.com/apache/datafusion/pull/20665
## Rationale for this change
I [want DataFusion to be the fastest Parquet engine on
ClickBench](https://github.com/apache/datafusion/issues/18489). One of
the queries where DataFusion is significantly slower is Query 29, which
has a very strange pattern of many aggregate functions that are offset
by a constant:
https://github.com/apache/datafusion/blob/0ca9d6586a43c323525b2e299448e0f1af4d6195/benchmarks/queries/clickbench/queries/q29.sql#L4
This is not a pattern I have ever seen in a real query, but it seems
like the engine currently at the top of the ClickBench leaderboard has a
special case for this pattern. ClickHouse probably does too. See
- https://github.com/duckdb/duckdb/pull/15017
- Discussion on https://github.com/apache/datafusion/issues/15524
Thus I reluctantly conclude that we should have one too.
## What changes are included in this PR?
This is an alternative to my first attempt:
- https://github.com/apache/datafusion/pull/20665
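In particular, since this is such a ClickBench-specific rule, I wanted to:
1. Minimize the downstream API / upgrade impact (i.e. not change existing APIs)
2. Optimize performance for the case where this rewrite does not apply (most of the time)
This PR:
1. Adds a new `AggregateUDFImpl::simplify_expr_op_literal` API (whose default implementation returns `None`, so existing aggregates are unaffected) and uses it to rewrite `SUM(expr + scalar)` --> `SUM(expr) + scalar*COUNT(expr)`
2. Adds tests for the rewrite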
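The rewrite relies on the identity `SUM(x + c) = SUM(x) + c * COUNT(x)`, which also holds under SQL NULL semantics: rows where `x` is NULL contribute to neither side, and if every `x` is NULL both sides are NULL. Below is a minimal standalone Rust sketch (not DataFusion code) that models a nullable column as `&[Option<i64>]`; the helpers `sql_sum`, `sql_count`, `sum_plus_scalar` and `rewritten` are hypothetical. It ignores integer overflow and the cast of `COUNT` to the literal's type that the real rewrite inserts.

```rust
/// SQL-style SUM: NULLs (None) are skipped; the result is NULL when
/// there are no non-NULL inputs.
fn sql_sum(values: &[Option<i64>]) -> Option<i64> {
    values
        .iter()
        .flatten()
        .fold(None, |acc: Option<i64>, v| Some(acc.unwrap_or(0) + v))
}

/// SQL-style COUNT(expr): the number of non-NULL inputs (never NULL).
fn sql_count(values: &[Option<i64>]) -> i64 {
    values.iter().filter(|v| v.is_some()).count() as i64
}

/// Direct evaluation of SUM(x + c). NULL + c is NULL, so NULL rows stay skipped.
fn sum_plus_scalar(values: &[Option<i64>], c: i64) -> Option<i64> {
    let shifted: Vec<Option<i64>> = values.iter().map(|v| v.map(|x| x + c)).collect();
    sql_sum(&shifted)
}

/// Rewritten form SUM(x) + c * COUNT(x). A NULL SUM keeps the result NULL.
fn rewritten(values: &[Option<i64>], c: i64) -> Option<i64> {
    sql_sum(values).map(|s| s + c * sql_count(values))
}

fn main() {
    let x = vec![Some(10), None, Some(-3), Some(7)];
    // One SUM(x) and one COUNT(x) are enough for every offset c.
    let (s, n) = (sql_sum(&x), sql_count(&x));
    for c in 1..=3 {
        assert_eq!(sum_plus_scalar(&x, c), rewritten(&x, c));
        println!("SUM(x + {c}) = {:?} (SUM(x) = {s:?}, COUNT(x) = {n})", rewritten(&x, c));
    }
    // All-NULL input: both forms are NULL.
    let all_null: Vec<Option<i64>> = vec![None, None];
    assert_eq!(sum_plus_scalar(&all_null, 5), None);
    assert_eq!(rewritten(&all_null, 5), None);
}
```

The loop shows why the rewrite pays off once CSE runs: however many offsets appear, only the two aggregates `SUM(x)` and `COUNT(x)` have to be computed over the data.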
Note there are quite a few other ideas for making this more general on
https://github.com/apache/datafusion/issues/15524, but I am going with the
simplest approach that works for the use case at hand (ClickBench).
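Design goal 2 above (keep the common, no-rewrite case cheap) suggests a two-pass shape. The following is a rough sketch modelled on the PR's `rewrite_multiple_linear_aggregates`: the first pass counts how many aggregates share each common argument, and the second pass rewrites only those whose argument is shared at least `DUPLICATE_THRESHOLD` (2) times. The `Agg` and `Rewritten` types and the `rewrite_linear_aggregates` function are hypothetical stand-ins for DataFusion's `Expr` machinery, and unlike the real code this toy decides the rewrite itself instead of asking each UDAF via `simplify_expr_op_literal`.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for an aggregate expression.
#[derive(Debug, Clone, PartialEq)]
enum Agg {
    /// `SUM(<column> + <lit>)`: a candidate for the linear rewrite.
    SumPlusLit { column: String, lit: i64 },
    /// Any other aggregate; never rewritten.
    Other(String),
}

/// Result of the toy rewrite.
#[derive(Debug, PartialEq)]
enum Rewritten {
    /// Left as is.
    Unchanged(Agg),
    /// `SUM(<column>) + <lit> * COUNT(<column>)`
    SumPlusCount { column: String, lit: i64 },
}

/// Same value as `DUPLICATE_THRESHOLD` in the PR.
const DUPLICATE_THRESHOLD: usize = 2;

fn rewrite_linear_aggregates(aggs: Vec<Agg>) -> Vec<Rewritten> {
    // First pass: count how many candidates share each common argument.
    let mut common_args: HashMap<String, usize> = HashMap::new();
    for agg in &aggs {
        if let Agg::SumPlusLit { column, .. } = agg {
            *common_args.entry(column.clone()).or_insert(0) += 1;
        }
    }

    // Second pass: rewrite only candidates whose argument is shared often
    // enough for a later CSE pass to collapse the new SUM/COUNT pairs.
    aggs.into_iter()
        .map(|agg| match agg {
            Agg::SumPlusLit { column, lit }
                if common_args.get(&column).copied().unwrap_or(0) >= DUPLICATE_THRESHOLD =>
            {
                Rewritten::SumPlusCount { column, lit }
            }
            other => Rewritten::Unchanged(other),
        })
        .collect()
}

fn main() {
    let aggs = vec![
        Agg::SumPlusLit { column: "x".into(), lit: 1 },
        Agg::SumPlusLit { column: "x".into(), lit: 2 },
        Agg::SumPlusLit { column: "y".into(), lit: 5 }, // `y` is used only once
        Agg::Other("MAX(z)".into()),
    ];
    for r in rewrite_linear_aggregates(aggs) {
        println!("{r:?}");
    }
}
```

With these inputs both `x` aggregates are rewritten, while `SUM(y + 5)` and `MAX(z)` stay unchanged because `y` appears only once and `MAX(z)` is not a candidate.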
## Are these changes tested?
Yes, new unit tests (in `simplify_exprs.rs`) and sqllogictest cases are added
## Are there any user-facing changes?
Faster performance for ClickBench Q29 🚀
```
│ QQuery 29 │ 1012.63 ms │ 139.02 ms │ +7.28x faster │
```
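For reference, the reported speedup is the ratio of the two Q29 timings in the row above:

```latex
\frac{1012.63\,\text{ms}}{139.02\,\text{ms}} \approx 7.28
```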
---
datafusion/expr/src/expr.rs | 2 +-
datafusion/expr/src/logical_plan/plan.rs | 4 +-
datafusion/expr/src/udaf.rs | 96 +++++++++
datafusion/functions-aggregate/src/sum.rs | 50 ++++-
.../src/simplify_expressions/linear_aggregates.rs | 229 +++++++++++++++++++++
.../optimizer/src/simplify_expressions/mod.rs | 1 +
.../src/simplify_expressions/simplify_exprs.rs | 155 +++++++++++++-
.../test_files/aggregates_simplify.slt | 66 +++---
datafusion/sqllogictest/test_files/clickbench.slt | 15 +-
9 files changed, 574 insertions(+), 44 deletions(-)
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 5c6acd480e..12c879a515 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -600,7 +600,7 @@ impl Alias {
}
}
-/// Binary expression
+/// Binary expression for [`Expr::BinaryExpr`]
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct BinaryExpr {
/// Left-hand side of the expression
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index fe8a8dd870..b2a5697183 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -3490,7 +3490,9 @@ pub struct Aggregate {
pub input: Arc<LogicalPlan>,
/// Grouping expressions
pub group_expr: Vec<Expr>,
- /// Aggregate expressions
+ /// Aggregate expressions.
+ ///
+ /// Note these *must* be either [`Expr::AggregateFunction`] or [`Expr::Alias`]
pub aggr_expr: Vec<Expr>,
/// The schema description of the aggregate output
pub schema: DFSchemaRef,
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index f2e2c53cdb..245a80c02c 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -28,6 +28,7 @@ use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::{Result, ScalarValue, Statistics, exec_err, not_impl_err};
use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
+use datafusion_expr_common::operator::Operator;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use crate::expr::{
@@ -301,6 +302,21 @@ impl AggregateUDF {
self.inner.simplify()
}
+ /// Rewrite aggregate to have simpler arguments
+ ///
+ /// See [`AggregateUDFImpl::simplify_expr_op_literal`] for more details
+ pub fn simplify_expr_op_literal(
+ &self,
+ agg_function: &AggregateFunction,
+ arg: &Expr,
+ op: Operator,
+ lit: &Expr,
+ arg_is_left: bool,
+ ) -> Result<Option<Expr>> {
+ self.inner
+ .simplify_expr_op_literal(agg_function, arg, op, lit, arg_is_left)
+ }
+
/// Returns true if the function is max, false if the function is min
/// None in all other cases, used in certain optimizations for
/// or aggregate
@@ -691,6 +707,74 @@ pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync {
None
}
+ /// Rewrite the aggregate to have simpler arguments
+ ///
+ /// This query pattern is not common in most real workloads, and most
+ /// aggregate implementations can safely ignore it. This API is included in
+ /// DataFusion because it is important for ClickBench Q29. See backstory
+ /// on <https://github.com/apache/datafusion/issues/15524>
+ ///
+ /// # Rewrite Overview
+ ///
+ /// The idea is to rewrite multiple aggregates with "complex arguments" into
+ /// ones with simpler arguments that can be optimized by common subexpression
+ /// elimination (CSE). At a high level the rewrite looks like
+ ///
+ /// * `Aggregate(SUM(x + 1), SUM(x + 2), ...)`
+ ///
+ /// Into
+ ///
+ /// * `Aggregate(SUM(x) + 1 * COUNT(x), SUM(x) + 2 * COUNT(x), ...)`
+ ///
+ /// While this rewrite may seem worse (slower) than the original as it
+ /// computes *more* aggregate expressions, the common subexpression
+ /// elimination (CSE) can then reduce the number of distinct aggregates the
+ /// query actually needs to compute with a rewrite like
+ ///
+ /// * `Projection(_A + 1*_B, _A + 2*_B)`
+ /// * ` Aggregate(_A = SUM(x), _B = COUNT(x))`
+ ///
+ /// This optimization is extremely important for ClickBench Q29, which has 90
+ /// such expressions for some reason, and so this optimization results in
+ /// only two aggregates being needed. The DataFusion optimizer will invoke
+ /// this method when it detects multiple aggregates in a query that share
+ /// arguments of the form `<arg> <op> <literal>`.
+ ///
+ /// # API
+ ///
+ /// If `agg_function` supports the rewrite, it should return a semantically
+ /// equivalent expression (likely with more aggregate expressions, but
+ /// simpler arguments).
+ ///
+ /// This is only called when:
+ /// 1. There are no "special" aggregate params (DISTINCT, FILTER, ORDER BY,
+ /// null treatment)
+ /// 2. The aggregate function has exactly one [`Expr`] argument
+ /// 3. The argument is not volatile
+ ///
+ /// Arguments
+ /// * `agg_function`: the original aggregate function detected with complex
+ /// arguments.
+ /// * `arg`: The common argument shared across multiple aggregates (e.g. `x`
+ /// in the example above)
+ /// * `op`: the operator between the common argument and the literal (e.g.
+ /// `+` in `x + 1` or `1 + x`)
+ /// * `lit`: the literal argument (e.g. `1` or `2` in the example above)
+ /// * `arg_is_left`: whether the common argument is on the left or right of
+ /// the operator (e.g. `true` for `x + 1` and `false` for `1 + x`)
+ ///
+ /// The default implementation returns `None`, which is what most aggregates
+ /// should do.
+ fn simplify_expr_op_literal(
+ &self,
+ _agg_function: &AggregateFunction,
+ _arg: &Expr,
+ _op: Operator,
+ _lit: &Expr,
+ _arg_is_left: bool,
+ ) -> Result<Option<Expr>> {
+ Ok(None)
+ }
+
/// Returns the reverse expression of the aggregate function.
fn reverse_expr(&self) -> ReversedUDAF {
ReversedUDAF::NotSupported
@@ -1243,6 +1327,18 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
self.inner.simplify()
}
+ fn simplify_expr_op_literal(
+ &self,
+ agg_function: &AggregateFunction,
+ arg: &Expr,
+ op: Operator,
+ lit: &Expr,
+ arg_is_left: bool,
+ ) -> Result<Option<Expr>> {
+ self.inner
+ .simplify_expr_op_literal(agg_function, arg, op, lit, arg_is_left)
+ }
+
fn reverse_expr(&self) -> ReversedUDAF {
self.inner.reverse_expr()
}
diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs
index 198ba54adf..5cced80d99 100644
--- a/datafusion/functions-aggregate/src/sum.rs
+++ b/datafusion/functions-aggregate/src/sum.rs
@@ -27,17 +27,20 @@ use arrow::datatypes::{
DurationMillisecondType, DurationNanosecondType, DurationSecondType, FieldRef,
Float64Type, Int64Type, TimeUnit, UInt64Type,
};
+use datafusion_common::internal_err;
use datafusion_common::types::{
NativeType, logical_float64, logical_int8, logical_int16, logical_int32,
logical_int64, logical_uint8, logical_uint16, logical_uint32, logical_uint64,
};
use datafusion_common::{HashMap, Result, ScalarValue, exec_err, not_impl_err};
+use datafusion_expr::expr::AggregateFunction;
+use datafusion_expr::expr_fn::cast;
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::{AggregateOrderSensitivity, format_state_name};
use datafusion_expr::{
Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, GroupsAccumulator,
- ReversedUDAF, SetMonotonicity, Signature, TypeSignature, TypeSignatureClass,
- Volatility,
+ Operator, ReversedUDAF, SetMonotonicity, Signature, TypeSignature,
+ TypeSignatureClass, Volatility,
};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
use datafusion_functions_aggregate_common::aggregate::sum_distinct::DistinctSumAccumulator;
@@ -54,7 +57,7 @@ make_udaf_expr_and_func!(
);
pub fn sum_distinct(expr: Expr) -> Expr {
- Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
+ Expr::AggregateFunction(AggregateFunction::new_udf(
sum_udaf(),
vec![expr],
true,
@@ -346,6 +349,47 @@ impl AggregateUDFImpl for Sum {
_ => SetMonotonicity::NotMonotonic,
}
}
+
+ /// Implement ClickBench Q29 specific optimization:
+ /// `SUM(arg + constant)` --> `SUM(arg) + constant * COUNT(arg)`
+ ///
+ /// See background on [`AggregateUDFImpl::simplify_expr_op_literal`]
+ fn simplify_expr_op_literal(
+ &self,
+ agg_function: &AggregateFunction,
+ arg: &Expr,
+ op: Operator,
+ lit: &Expr,
+ // Only support '+' so the order of the args doesn't matter
+ _arg_is_left: bool,
+ ) -> Result<Option<Expr>> {
+ if op != Operator::Plus {
+ return Ok(None);
+ }
+
+ let lit_type = match &lit {
+ Expr::Literal(value, _) => value.data_type(),
+ _ => {
+ return internal_err!(
+ "Sum::simplify_expr_op_literal got a non literal argument"
+ );
+ }
+ };
+ if lit_type == DataType::Null {
+ return Ok(None);
+ }
+
+ // Build up SUM(arg)
+ let mut sum_agg = agg_function.clone();
+ sum_agg.params.args = vec![arg.clone()];
+ let sum_agg = Expr::AggregateFunction(sum_agg);
+
+ // COUNT(arg) - cast to the correct type
+ let count_agg = cast(crate::count::count(arg.clone()), lit_type);
+
+ // SUM(arg) + lit * COUNT(arg)
+ Ok(Some(sum_agg + (lit.clone() * count_agg)))
+ }
}
/// This accumulator computes SUM incrementally
diff --git a/datafusion/optimizer/src/simplify_expressions/linear_aggregates.rs b/datafusion/optimizer/src/simplify_expressions/linear_aggregates.rs
new file mode 100644
index 0000000000..21389cf326
--- /dev/null
+++ b/datafusion/optimizer/src/simplify_expressions/linear_aggregates.rs
@@ -0,0 +1,229 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Simplification to refactor multiple aggregate functions to use the same aggregate function
+
+use datafusion_common::HashMap;
+use datafusion_expr::expr::AggregateFunctionParams;
+use datafusion_expr::{BinaryExpr, Expr};
+use datafusion_expr_common::operator::Operator;
+
+/// Threshold of the number of aggregates that share similar arguments before
+/// triggering rewrite.
+///
+/// There is a threshold because the canonical SUM rewrite described in
+/// [`AggregateUDFImpl::simplify_expr_op_literal`] actually results in more
+/// aggregates (2) for each original aggregate. It is important that CSE then
+/// eliminate them.
+///
+/// [`AggregateUDFImpl::simplify_expr_op_literal`]: datafusion_expr::AggregateUDFImpl::simplify_expr_op_literal
+const DUPLICATE_THRESHOLD: usize = 2;
+
+/// Rewrites multiple aggregate expressions that have a common linear component
+/// into multiple aggregate expressions that share that common component.
+///
+/// For example, rewrites patterns such as
+/// * `SUM(x + 1), SUM(x + 2), ...`
+///
+/// Into
+/// * `SUM(x) + 1 * COUNT(x), SUM(x) + 2 * COUNT(x), ...`
+///
+/// See the background on [`AggregateUDFImpl::simplify_expr_op_literal`] for details.
+///
+/// Returns `true` if any of the arguments are rewritten (modified), `false`
+/// otherwise.
+///
+/// ## Design goals:
+/// 1. Keep the aggregate specific logic out of the optimizer (can't depend directly on SUM)
+/// 2. Optimize for the case that this rewrite will not apply (it almost never does)
+///
+/// [`AggregateUDFImpl::simplify_expr_op_literal`]: datafusion_expr::AggregateUDFImpl::simplify_expr_op_literal
+pub(super) fn rewrite_multiple_linear_aggregates(
+ agg_expr: &mut [Expr],
+) -> datafusion_common::Result<bool> {
+ // map <expr>: count of expressions that have a common argument
+ let mut common_args = HashMap::new();
+
+ // First pass -- figure out any aggregates that can be split and have common
+ // expressions.
+ for agg in agg_expr.iter() {
+ let Expr::AggregateFunction(agg_function) = agg else {
+ continue;
+ };
+
+ let Some(arg) = candidate_linear_param(&agg_function.params) else {
+ continue;
+ };
+
+ let Some(expr_literal) = ExprLiteral::try_new(arg) else {
+ continue;
+ };
+
+ let counter = common_args.entry(expr_literal.expr()).or_insert(0);
+ *counter += 1;
+ }
+
+ // (agg_index, new_expr)
+ let mut new_aggs = vec![];
+
+ // Second pass, actually rewrite any aggregates that have a common
+ // expression and enough duplicates.
+ for (idx, agg) in agg_expr.iter().enumerate() {
+ let Expr::AggregateFunction(agg_function) = agg else {
+ continue;
+ };
+
+ let Some(arg) = candidate_linear_param(&agg_function.params) else {
+ continue;
+ };
+
+ let Some(expr_literal) = ExprLiteral::try_new(arg) else {
+ continue;
+ };
+
+ // Not enough common expressions to make it worth rewriting
+ if common_args.get(expr_literal.expr()).unwrap_or(&0) < &DUPLICATE_THRESHOLD {
+ continue;
+ }
+
+ if let Some(new_agg_function) = agg_function.func.simplify_expr_op_literal(
+ agg_function,
+ expr_literal.expr(),
+ expr_literal.op(),
+ expr_literal.lit(),
+ expr_literal.arg_is_left(),
+ )? {
+ new_aggs.push((idx, new_agg_function));
+ }
+ }
+
+ if new_aggs.is_empty() {
+ return Ok(false);
+ }
+
+ // Otherwise replace the aggregate expressions
+ drop(common_args); // release borrow
+ for (idx, new_agg) in new_aggs {
+ let orig_name = agg_expr[idx].name_for_alias()?;
+ agg_expr[idx] = new_agg.alias_if_changed(orig_name)?
+ }
+
+ Ok(true)
+}
+
+/// Returns Some(&Expr) with the single argument if this is a suitable candidate
+/// for the linear rewrite
+fn candidate_linear_param(params: &AggregateFunctionParams) -> Option<&Expr> {
+ // Explicitly destructure to ensure we check all relevant fields
+ let AggregateFunctionParams {
+ args,
+ distinct,
+ filter,
+ order_by,
+ null_treatment,
+ } = params;
+
+ // Disqualify anything "non standard"
+ if *distinct
+ || filter.is_some()
+ || !order_by.is_empty()
+ || null_treatment.is_some()
+ || args.len() != 1
+ {
+ return None;
+ }
+ let arg = args.first()?;
+ if arg.is_volatile() {
+ return None;
+ };
+ Some(arg)
+}
+
+/// A view into a [`Expr::BinaryExpr`] that combines an arbitrary expression
+/// with a literal
+///
+/// This is an enum to distinguish the direction of the operator arguments
+#[derive(Debug, Clone)]
+pub enum ExprLiteral<'a> {
+ /// if the expression is `<arg> <op> <lit>`
+ ArgOpLit {
+ arg: &'a Expr,
+ op: Operator,
+ lit: &'a Expr,
+ },
+ /// if the expression is `<lit> <op> <arg>`
+ LitOpArg {
+ lit: &'a Expr,
+ op: Operator,
+ arg: &'a Expr,
+ },
+}
+
+impl<'a> ExprLiteral<'a> {
+ /// Try and split the Expr into its parts
+ fn try_new(expr: &'a Expr) -> Option<Self> {
+ match expr {
+ // <lit> <op> <expr>
+ Expr::BinaryExpr(BinaryExpr { left, op, right })
+ if matches!(left.as_ref(), Expr::Literal(..)) =>
+ {
+ Some(Self::LitOpArg {
+ arg: right,
+ lit: left,
+ op: *op,
+ })
+ }
+
+ // <expr> <op> <lit>
+ Expr::BinaryExpr(BinaryExpr { left, op, right })
+ if matches!(right.as_ref(), Expr::Literal(..)) =>
+ {
+ Some(Self::ArgOpLit {
+ arg: left,
+ lit: right,
+ op: *op,
+ })
+ }
+ _ => None,
+ }
+ }
+
+ fn expr(&self) -> &'a Expr {
+ match self {
+ Self::ArgOpLit { arg, .. } => arg,
+ Self::LitOpArg { arg, .. } => arg,
+ }
+ }
+
+ fn lit(&self) -> &'a Expr {
+ match self {
+ Self::ArgOpLit { lit, .. } => lit,
+ Self::LitOpArg { lit, .. } => lit,
+ }
+ }
+
+ fn op(&self) -> Operator {
+ match self {
+ Self::ArgOpLit { op, .. } => *op,
+ Self::LitOpArg { op, .. } => *op,
+ }
+ }
+
+ fn arg_is_left(&self) -> bool {
+ matches!(self, Self::ArgOpLit { .. })
+ }
+}
diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs
index b85b000821..89c79d3fb4 100644
--- a/datafusion/optimizer/src/simplify_expressions/mod.rs
+++ b/datafusion/optimizer/src/simplify_expressions/mod.rs
@@ -20,6 +20,7 @@
pub mod expr_simplifier;
mod inlist_simplifier;
+mod linear_aggregates;
mod regex;
pub mod simplify_exprs;
pub mod simplify_literal;
diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
index 2114c5ef3d..29ee593422 100644
--- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
+++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
@@ -20,18 +20,20 @@
use std::sync::Arc;
use datafusion_common::tree_node::{Transformed, TreeNode};
-use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
+use datafusion_common::{Column, DFSchema, DFSchemaRef, DataFusionError, Result};
use datafusion_expr::Expr;
-use datafusion_expr::logical_plan::LogicalPlan;
+use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection};
use datafusion_expr::simplify::SimplifyContext;
-use datafusion_expr::utils::merge_schema;
+use datafusion_expr::utils::{
+ columnize_expr, find_aggregate_exprs, grouping_set_to_exprlist, merge_schema,
+};
+use super::ExprSimplifier;
use crate::optimizer::ApplyOrder;
+use crate::simplify_expressions::linear_aggregates::rewrite_multiple_linear_aggregates;
use crate::utils::NamePreserver;
use crate::{OptimizerConfig, OptimizerRule};
-use super::ExprSimplifier;
-
/// Optimizer Pass that simplifies [`LogicalPlan`]s by rewriting
/// [`Expr`]`s evaluating constants and applying algebraic
/// simplifications
@@ -137,7 +139,8 @@ impl SimplifyExpressions {
} else {
rewrite_expr(expr)
}
- })
+ })?
+ .transform_data(rewrite_aggregate_non_aggregate_aggr_expr)
}
}
@@ -148,6 +151,98 @@ impl SimplifyExpressions {
}
}
+/// Ensures that `LogicalPlan::Aggregate` is well formed after rewrites
+/// by potentially introducing an extra `Projection`.
+///
+/// Also applies the [`rewrite_multiple_linear_aggregates`] special case
+///
+/// # Rationale:
+///
+/// [`LogicalPlan::Aggregate`] requires agg expressions to be (possibly aliased)
+/// [`Expr::AggregateFunction`]. Some UDAF simplifiers may return other [`Expr`]
+/// variants.
+/// # Operation
+///
+/// Rewrites things like this (note that `exp1` is not an aggregate):
+/// * `Aggregate(group_expr, aggr_expr=[exp1 + agg(exp2)])`
+///
+/// into:
+/// * `Projection(exp1 + _X)`
+/// * ` Aggregate(group_expr, aggr_expr=[agg(exp2) AS _X])`
+fn rewrite_aggregate_non_aggregate_aggr_expr(
+ plan: LogicalPlan,
+) -> Result<Transformed<LogicalPlan>> {
+ let LogicalPlan::Aggregate(Aggregate {
+ input,
+ group_expr,
+ mut aggr_expr,
+ schema,
+ ..
+ }) = plan
+ else {
+ return Ok(Transformed::no(plan));
+ };
+
+ let rewrote_aggs = rewrite_multiple_linear_aggregates(&mut aggr_expr)?;
+
+ // Ensure that all Aggregate arguments are AggregateExpr
+ if aggr_expr.iter().all(is_top_level_aggregate_expr) {
+ let new_plan = LogicalPlan::Aggregate(Aggregate::try_new_with_schema(
+ input, group_expr, aggr_expr, schema,
+ )?);
+ return if !rewrote_aggs {
+ Ok(Transformed::no(new_plan))
+ } else {
+ Ok(Transformed::yes(new_plan))
+ };
+ }
+
+ // Otherwise we need to add a Projection above Aggregate to calculate
+ // the final output expressions.
+
+ let inner_aggr_expr = find_aggregate_exprs(aggr_expr.iter());
+ let inner_aggregate = LogicalPlan::Aggregate(Aggregate::try_new(
+ Arc::clone(&input),
+ group_expr.clone(),
+ inner_aggr_expr,
+ )?);
+ let inner_aggregate = Arc::new(inner_aggregate);
+
+ let mut projection_exprs = aggregate_output_exprs(&group_expr)?;
+ projection_exprs.extend(aggr_expr);
+ let projection_exprs = projection_exprs
+ .into_iter()
+ .map(|expr| columnize_expr(expr, inner_aggregate.as_ref()))
+ .collect::<Result<Vec<_>>>()?;
+
+ Ok(Transformed::yes(LogicalPlan::Projection(
+ Projection::try_new(projection_exprs, inner_aggregate)?,
+ )))
+}
+
+fn is_top_level_aggregate_expr(expr: &Expr) -> bool {
+ matches!(
+ expr.clone().unalias_nested().data,
+ Expr::AggregateFunction(_)
+ )
+}
+
+fn aggregate_output_exprs(group_expr: &[Expr]) -> Result<Vec<Expr>> {
+ let mut output_exprs = grouping_set_to_exprlist(group_expr)?
+ .into_iter()
+ .cloned()
+ .collect::<Vec<_>>();
+
+ if matches!(group_expr, [Expr::GroupingSet(_)]) {
+ output_exprs.push(Expr::Column(Column::from_name(
+ Aggregate::INTERNAL_GROUPING_ID,
+ )));
+ }
+
+ Ok(output_exprs)
+}
+
#[cfg(test)]
mod tests {
use std::ops::Not;
@@ -159,7 +254,7 @@ mod tests {
use datafusion_expr::logical_plan::builder::table_scan_with_filters;
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::*;
- use datafusion_functions_aggregate::expr_fn::{max, min};
+ use datafusion_functions_aggregate::expr_fn::{max, min, sum};
use crate::OptimizerContext;
use crate::assert_optimized_plan_eq_snapshot;
@@ -259,6 +354,52 @@ mod tests {
)
}
+ #[test]
+ fn test_simplify_udaf_to_non_aggregate_expr() -> Result<()> {
+ let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
+ let table_scan = table_scan(Some("test"), &schema, None)?
+ .build()
+ .expect("building scan");
+
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .aggregate(Vec::<Expr>::new(), vec![sum(col("a") + lit(2i64))])?
+ .build()?;
+
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Aggregate: groupBy=[[]], aggr=[[sum(test.a + Int64(2))]]
+ TableScan: test
+ "
+ )?;
+ Ok(())
+ }
+
+ #[test]
+ fn test_simplify_common_sum_arg() -> Result<()> {
+ let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
+ let table_scan = table_scan(Some("test"), &schema, None)?
+ .build()
+ .expect("building scan");
+
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .aggregate(
+ Vec::<Expr>::new(),
+ vec![sum(col("a") + lit(2i64)), sum(col("a") + lit(3i64))],
+ )?
+ .build()?;
+
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Projection: sum(test.a) + Int64(2) * CAST(count(test.a) AS Int64) AS sum(test.a + Int64(2)), sum(test.a) + Int64(3) * CAST(count(test.a) AS Int64) AS sum(test.a + Int64(3))
+ Aggregate: groupBy=[[]], aggr=[[sum(test.a), count(test.a)]]
+ TableScan: test
+ "
+ )?;
+ Ok(())
+ }
+
#[test]
fn test_simplify_optimized_plan_with_or() -> Result<()> {
let table_scan = test_table_scan();
diff --git a/datafusion/sqllogictest/test_files/aggregates_simplify.slt b/datafusion/sqllogictest/test_files/aggregates_simplify.slt
index cc2e40540b..9aa3ecf7a2 100644
--- a/datafusion/sqllogictest/test_files/aggregates_simplify.slt
+++ b/datafusion/sqllogictest/test_files/aggregates_simplify.slt
@@ -106,11 +106,14 @@ query TT
EXPLAIN SELECT SUM(column1 + 1), SUM(column1 + 2) FROM sum_simplify_t;
----
logical_plan
-01)Aggregate: groupBy=[[]], aggr=[[sum(sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1 + Int64(2))]]
-02)--TableScan: sum_simplify_t projection=[column1]
+01)Projection: sum(sum_simplify_t.column1) + __common_expr_1 AS sum(sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1) + Int64(2) * __common_expr_1 AS sum(sum_simplify_t.column1 + Int64(2))
+02)--Projection: CAST(count(sum_simplify_t.column1) AS Int64) AS __common_expr_1, sum(sum_simplify_t.column1)
+03)----Aggregate: groupBy=[[]], aggr=[[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]]
+04)------TableScan: sum_simplify_t projection=[column1]
physical_plan
-01)AggregateExec: mode=Single, gby=[], aggr=[sum(sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1 + Int64(2))]
-02)--DataSourceExec: partitions=1, partition_sizes=[1]
+01)ProjectionExec: expr=[sum(sum_simplify_t.column1)@0 + count(sum_simplify_t.column1)@1 as sum(sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1)@0 + 2 * count(sum_simplify_t.column1)@1 as sum(sum_simplify_t.column1 + Int64(2))]
+02)--AggregateExec: mode=Single, gby=[], aggr=[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]
+03)----DataSourceExec: partitions=1, partition_sizes=[1]
# Reordered expressions that still compute the same thing
query II
@@ -122,11 +125,14 @@ query TT
EXPLAIN SELECT SUM(1 + column1), SUM(column1 + 2) FROM sum_simplify_t;
----
logical_plan
-01)Aggregate: groupBy=[[]], aggr=[[sum(Int64(1) + sum_simplify_t.column1), sum(sum_simplify_t.column1 + Int64(2))]]
-02)--TableScan: sum_simplify_t projection=[column1]
+01)Projection: sum(sum_simplify_t.column1) + __common_expr_1 AS sum(Int64(1) + sum_simplify_t.column1), sum(sum_simplify_t.column1) + Int64(2) * __common_expr_1 AS sum(sum_simplify_t.column1 + Int64(2))
+02)--Projection: CAST(count(sum_simplify_t.column1) AS Int64) AS __common_expr_1, sum(sum_simplify_t.column1)
+03)----Aggregate: groupBy=[[]], aggr=[[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]]
+04)------TableScan: sum_simplify_t projection=[column1]
physical_plan
-01)AggregateExec: mode=Single, gby=[], aggr=[sum(Int64(1) + sum_simplify_t.column1), sum(sum_simplify_t.column1 + Int64(2))]
-02)--DataSourceExec: partitions=1, partition_sizes=[1]
+01)ProjectionExec: expr=[sum(sum_simplify_t.column1)@0 + count(sum_simplify_t.column1)@1 as sum(Int64(1) + sum_simplify_t.column1), sum(sum_simplify_t.column1)@0 + 2 * count(sum_simplify_t.column1)@1 as sum(sum_simplify_t.column1 + Int64(2))]
+02)--AggregateExec: mode=Single, gby=[], aggr=[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]
+03)----DataSourceExec: partitions=1, partition_sizes=[1]
# DISTINCT aggregates with different arguments
query II
@@ -259,15 +265,18 @@ EXPLAIN SELECT column2, SUM(column1 + 1), SUM(column1 + 2) FROM sum_simplify_t G
----
logical_plan
01)Sort: sum_simplify_t.column2 DESC NULLS LAST
-02)--Aggregate: groupBy=[[sum_simplify_t.column2]], aggr=[[sum(sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1 + Int64(2))]]
-03)----TableScan: sum_simplify_t projection=[column1, column2]
+02)--Projection: sum_simplify_t.column2, sum(sum_simplify_t.column1) + __common_expr_1 AS sum(sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1) + Int64(2) * __common_expr_1 AS sum(sum_simplify_t.column1 + Int64(2))
+03)----Projection: CAST(count(sum_simplify_t.column1) AS Int64) AS __common_expr_1, sum_simplify_t.column2, sum(sum_simplify_t.column1)
+04)------Aggregate: groupBy=[[sum_simplify_t.column2]], aggr=[[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]]
+05)--------TableScan: sum_simplify_t projection=[column1, column2]
physical_plan
01)SortPreservingMergeExec: [column2@0 DESC NULLS LAST]
02)--SortExec: expr=[column2@0 DESC NULLS LAST], preserve_partitioning=[true]
-03)----AggregateExec: mode=FinalPartitioned, gby=[column2@0 as column2], aggr=[sum(sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1 + Int64(2))]
-04)------RepartitionExec: partitioning=Hash([column2@0], 4), input_partitions=1
-05)--------AggregateExec: mode=Partial, gby=[column2@1 as column2], aggr=[sum(sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1 + Int64(2))]
-06)----------DataSourceExec: partitions=1, partition_sizes=[1]
+03)----ProjectionExec: expr=[column2@0 as column2, sum(sum_simplify_t.column1)@1 + count(sum_simplify_t.column1)@2 as sum(sum_simplify_t.column1 + Int64(1)), sum(sum_simplify_t.column1)@1 + 2 * count(sum_simplify_t.column1)@2 as sum(sum_simplify_t.column1 + Int64(2))]
+04)------AggregateExec: mode=FinalPartitioned, gby=[column2@0 as column2], aggr=[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]
+05)--------RepartitionExec: partitioning=Hash([column2@0], 4), input_partitions=1
+06)----------AggregateExec: mode=Partial, gby=[column2@1 as column2], aggr=[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]
+07)------------DataSourceExec: partitions=1, partition_sizes=[1]
# Checks commutative forms of equivalent aggregate arguments are simplified consistently.
query II
@@ -279,13 +288,15 @@ query TT
EXPLAIN SELECT SUM(1 + column1), SUM(column1 + 1) FROM sum_simplify_t;
----
logical_plan
-01)Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS Int64(1) + sum_simplify_t.column1), sum(__common_expr_1 AS sum_simplify_t.column1 + Int64(1))]]
-02)--Projection: Int64(1) + sum_simplify_t.column1 AS __common_expr_1
-03)----TableScan: sum_simplify_t projection=[column1]
+01)Projection: __common_expr_1 AS sum(Int64(1) + sum_simplify_t.column1), __common_expr_1 AS sum(sum_simplify_t.column1 + Int64(1))
+02)--Projection: sum(sum_simplify_t.column1) + CAST(count(sum_simplify_t.column1) AS Int64) AS __common_expr_1
+03)----Aggregate: groupBy=[[]], aggr=[[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]]
+04)------TableScan: sum_simplify_t projection=[column1]
physical_plan
-01)AggregateExec: mode=Single, gby=[], aggr=[sum(Int64(1) + sum_simplify_t.column1), sum(sum_simplify_t.column1 + Int64(1))]
-02)--ProjectionExec: expr=[1 + column1@0 as __common_expr_1]
-03)----DataSourceExec: partitions=1, partition_sizes=[1]
+01)ProjectionExec: expr=[__common_expr_1@0 as sum(Int64(1) + sum_simplify_t.column1), __common_expr_1@0 as sum(sum_simplify_t.column1 + Int64(1))]
+02)--ProjectionExec: expr=[sum(sum_simplify_t.column1)@0 + count(sum_simplify_t.column1)@1 as __common_expr_1]
+03)----AggregateExec: mode=Single, gby=[], aggr=[sum(sum_simplify_t.column1), count(sum_simplify_t.column1)]
+04)------DataSourceExec: partitions=1, partition_sizes=[1]
# Checks unsigned overflow edge case from PR discussion using transformed SUM arguments.
statement ok
@@ -308,14 +319,17 @@ EXPLAIN SELECT arrow_typeof(SUM(val + 1)), SUM(val + 1), SUM(val + 2) FROM tbl;
----
logical_plan
01)Projection: arrow_typeof(sum(tbl.val + Int64(1))), sum(tbl.val + Int64(1)), sum(tbl.val + Int64(2))
-02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS tbl.val + Int64(1)), sum(__common_expr_1 AS tbl.val + Int64(2))]]
-03)----Projection: CAST(tbl.val AS Int64) AS __common_expr_1
-04)------TableScan: tbl projection=[val]
+02)--Projection: sum(tbl.val) + __common_expr_1 AS sum(tbl.val + Int64(1)), sum(tbl.val) + Int64(2) * __common_expr_1 AS sum(tbl.val + Int64(2))
+03)----Projection: CAST(count(tbl.val) AS Int64) AS __common_expr_1, sum(tbl.val)
+04)------Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_2 AS tbl.val), count(__common_expr_2 AS tbl.val)]]
+05)--------Projection: CAST(tbl.val AS Int64) AS __common_expr_2
+06)----------TableScan: tbl projection=[val]
physical_plan
01)ProjectionExec: expr=[arrow_typeof(sum(tbl.val + Int64(1))@0) as arrow_typeof(sum(tbl.val + Int64(1))), sum(tbl.val + Int64(1))@0 as sum(tbl.val + Int64(1)), sum(tbl.val + Int64(2))@1 as sum(tbl.val + Int64(2))]
-02)--AggregateExec: mode=Single, gby=[], aggr=[sum(tbl.val + Int64(1)), sum(tbl.val + Int64(2))]
-03)----ProjectionExec: expr=[CAST(val@0 AS Int64) as __common_expr_1]
-04)------DataSourceExec: partitions=1, partition_sizes=[2]
+02)--ProjectionExec: expr=[sum(tbl.val)@0 + count(tbl.val)@1 as sum(tbl.val + Int64(1)), sum(tbl.val)@0 + 2 * count(tbl.val)@1 as sum(tbl.val + Int64(2))]
+03)----AggregateExec: mode=Single, gby=[], aggr=[sum(tbl.val), count(tbl.val)]
+04)------ProjectionExec: expr=[CAST(val@0 AS Int64) as __common_expr_2]
+05)--------DataSourceExec: partitions=1, partition_sizes=[2]
# Checks equivalent rewritten form (SUM + COUNT terms) matches transformed SUM semantics.
query RR
diff --git a/datafusion/sqllogictest/test_files/clickbench.slt b/datafusion/sqllogictest/test_files/clickbench.slt
index e14d28d5ef..42f066a80d 100644
--- a/datafusion/sqllogictest/test_files/clickbench.slt
+++ b/datafusion/sqllogictest/test_files/clickbench.slt
@@ -787,13 +787,16 @@ query TT
EXPLAIN SELECT SUM("ResolutionWidth"), SUM("ResolutionWidth" + 1), SUM("ResolutionWidth" + 2), SUM("ResolutionWidth" + 3), SUM("ResolutionWidth" + 4), SUM("ResolutionWidth" + 5), SUM("ResolutionWidth" + 6), SUM("ResolutionWidth" + 7), SUM("ResolutionWidth" + 8), SUM("ResolutionWidth" + 9), SUM("ResolutionWidth" + 10), SUM("ResolutionWidth" + 11), SUM("ResolutionWidth" + 12), SUM("ResolutionWidth" + 13), SUM("ResolutionWidth" + 14), SUM("ResolutionWidth" + 15), SUM("ResolutionWidth" + 16) [...]
----
logical_plan
-01)Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS hits.ResolutionWidth), sum(__common_expr_1 AS hits.ResolutionWidth + Int64(1)), sum(__common_expr_1 AS hits.ResolutionWidth + Int64(2)), sum(__common_expr_1 AS hits.ResolutionWidth + Int64(3)), sum(__common_expr_1 AS hits.ResolutionWidth + Int64(4)), sum(__common_expr_1 AS hits.ResolutionWidth + Int64(5)), sum(__common_expr_1 AS hits.ResolutionWidth + Int64(6)), sum(__common_expr_1 AS hits.ResolutionWidth + Int64(7)), sum(__common [...]
-02)--Projection: CAST(hits.ResolutionWidth AS Int64) AS __common_expr_1
-03)----SubqueryAlias: hits
-04)------TableScan: hits_raw projection=[ResolutionWidth]
+01)Projection: sum(hits.ResolutionWidth), sum(hits.ResolutionWidth) + __common_expr_1 AS sum(hits.ResolutionWidth + Int64(1)), sum(hits.ResolutionWidth) + Int64(2) * __common_expr_1 AS sum(hits.ResolutionWidth + Int64(2)), sum(hits.ResolutionWidth) + Int64(3) * __common_expr_1 AS sum(hits.ResolutionWidth + Int64(3)), sum(hits.ResolutionWidth) + Int64(4) * __common_expr_1 AS sum(hits.ResolutionWidth + Int64(4)), sum(hits.ResolutionWidth) + Int64(5) * __common_expr_1 AS sum(hits.Resolution [...]
+02)--Projection: CAST(count(hits.ResolutionWidth) AS Int64) AS __common_expr_1, sum(hits.ResolutionWidth)
+03)----Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_2 AS hits.ResolutionWidth), count(__common_expr_2 AS hits.ResolutionWidth)]]
+04)------Projection: CAST(hits.ResolutionWidth AS Int64) AS __common_expr_2
+05)--------SubqueryAlias: hits
+06)----------TableScan: hits_raw projection=[ResolutionWidth]
physical_plan
-01)AggregateExec: mode=Single, gby=[], aggr=[sum(hits.ResolutionWidth), sum(hits.ResolutionWidth + Int64(1)), sum(hits.ResolutionWidth + Int64(2)), sum(hits.ResolutionWidth + Int64(3)), sum(hits.ResolutionWidth + Int64(4)), sum(hits.ResolutionWidth + Int64(5)), sum(hits.ResolutionWidth + Int64(6)), sum(hits.ResolutionWidth + Int64(7)), sum(hits.ResolutionWidth + Int64(8)), sum(hits.ResolutionWidth + Int64(9)), sum(hits.ResolutionWidth + Int64(10)), sum(hits.ResolutionWidth + Int64(11)), [...]
-02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, projection=[CAST(ResolutionWidth@20 AS Int64) as __common_expr_1], file_type=parquet
+01)ProjectionExec: expr=[sum(hits.ResolutionWidth)@0 as sum(hits.ResolutionWidth), sum(hits.ResolutionWidth)@0 + count(hits.ResolutionWidth)@1 as sum(hits.ResolutionWidth + Int64(1)), sum(hits.ResolutionWidth)@0 + 2 * count(hits.ResolutionWidth)@1 as sum(hits.ResolutionWidth + Int64(2)), sum(hits.ResolutionWidth)@0 + 3 * count(hits.ResolutionWidth)@1 as sum(hits.ResolutionWidth + Int64(3)), sum(hits.ResolutionWidth)@0 + 4 * count(hits.ResolutionWidth)@1 as sum(hits.ResolutionWidth + Int6 [...]
+02)--AggregateExec: mode=Single, gby=[], aggr=[sum(hits.ResolutionWidth), count(hits.ResolutionWidth)]
+03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, projection=[CAST(ResolutionWidth@20 AS Int64) as __common_expr_2], file_type=parquet
query IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII
SELECT SUM("ResolutionWidth"), SUM("ResolutionWidth" + 1), SUM("ResolutionWidth" + 2), SUM("ResolutionWidth" + 3), SUM("ResolutionWidth" + 4), SUM("ResolutionWidth" + 5), SUM("ResolutionWidth" + 6), SUM("ResolutionWidth" + 7), SUM("ResolutionWidth" + 8), SUM("ResolutionWidth" + 9), SUM("ResolutionWidth" + 10), SUM("ResolutionWidth" + 11), SUM("ResolutionWidth" + 12), SUM("ResolutionWidth" + 13), SUM("ResolutionWidth" + 14), SUM("ResolutionWidth" + 15), SUM("ResolutionWidth" + 16), SUM("R [...]
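For reference, the expected plans above depend on the identity SUM(x + k) = SUM(x) + k * COUNT(x). The standalone Rust sketch below checks that identity. It is not DataFusion code. `sum_plus_naive`, `sum_and_count`, and `sum_plus_rewritten` are hypothetical helpers, `Option<i64>` stands in for a nullable Int64 column, and overflow is ignored. The sketch also shows why the rewrite helps ClickBench Q29: a single SUM and a single COUNT are enough to derive every `SUM(x + k)`.

```rust
/// Hypothetical helper: the original form SUM(x + k) with SQL semantics.
/// NULL inputs are skipped. The result is NULL (None) when no non-NULL input exists.
fn sum_plus_naive(values: &[Option<i64>], k: i64) -> Option<i64> {
    let mut acc: Option<i64> = None;
    for v in values.iter().flatten() {
        acc = Some(acc.unwrap_or(0) + v + k);
    }
    acc
}

/// Hypothetical helper: compute SUM(x) and COUNT(x) together in one pass.
/// COUNT(x) counts only non-NULL values, unlike COUNT(*).
fn sum_and_count(values: &[Option<i64>]) -> (Option<i64>, i64) {
    let mut sum: Option<i64> = None;
    let mut count: i64 = 0;
    for v in values.iter().flatten() {
        sum = Some(sum.unwrap_or(0) + v);
        count += 1;
    }
    (sum, count)
}

/// Hypothetical helper: the rewritten form SUM(x) + k * COUNT(x).
/// NULL + anything is NULL, so an empty or all-NULL input still yields NULL.
fn sum_plus_rewritten(sum: Option<i64>, count: i64, k: i64) -> Option<i64> {
    sum.map(|s| s + k * count)
}

fn main() {
    let values = vec![Some(10), None, Some(20), Some(30)];
    // A single aggregation pass serves every offset k.
    let (sum, count) = sum_and_count(&values);
    for k in 0..=16 {
        assert_eq!(sum_plus_naive(&values, k), sum_plus_rewritten(sum, count, k));
    }

    // All-NULL input: both forms return NULL.
    let all_null: Vec<Option<i64>> = vec![None, None];
    let (s, c) = sum_and_count(&all_null);
    assert_eq!(sum_plus_naive(&all_null, 5), None);
    assert_eq!(sum_plus_rewritten(s, c, 5), None);

    println!("SUM(x + 16) = {:?}", sum_plus_rewritten(sum, count, 16));
}
```

The multiplier has to be `COUNT(x)` and not `COUNT(*)`. Rows where `x` is NULL contribute nothing to `SUM(x + k)`, so they must not add `k` to the rewritten form either. This matches the `count(hits.ResolutionWidth)` term in the new plans.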