alamb commented on code in PR #20180:
URL: https://github.com/apache/datafusion/pull/20180#discussion_r2800713029


##########
datafusion/optimizer/src/rewrite_aggregate_with_constant.rs:
##########
@@ -0,0 +1,609 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`RewriteAggregateWithConstant`] rewrites `SUM(column ± constant)` to `SUM(column) ± constant * COUNT(column)`
+
+use crate::optimizer::ApplyOrder;
+use crate::{OptimizerConfig, OptimizerRule};
+
+use std::collections::HashMap;
+
+use datafusion_common::tree_node::Transformed;
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::expr::AggregateFunctionParams;
+use datafusion_expr::{
+    Aggregate, BinaryExpr, Expr, LogicalPlan, LogicalPlanBuilder, Operator, binary_expr,
+    col, lit,
+};
+use datafusion_functions_aggregate::expr_fn::{count, sum};
+use indexmap::IndexMap;
+
+/// Optimizer rule that rewrites `SUM(column ± constant)` expressions
+/// into `SUM(column) ± constant * COUNT(column)` when multiple such expressions
+/// exist for the same base column.
+///
+/// This reduces computation by calculating SUM once and deriving other values.
+///
+/// # Example
+/// ```sql
+/// SELECT SUM(a), SUM(a + 1), SUM(a + 2) FROM t;
+/// ```
+/// is rewritten into a Projection on top of an Aggregate:
+/// ```sql
+/// -- New Projection Node
+/// SELECT sum_a, sum_a + 1 * count_a, sum_a + 2 * count_a
+/// -- New Aggregate Node
+/// FROM (SELECT SUM(a) as sum_a, COUNT(a) as count_a FROM t);
+/// ```
+#[derive(Default, Debug)]
+pub struct RewriteAggregateWithConstant {}
+
+impl RewriteAggregateWithConstant {
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl OptimizerRule for RewriteAggregateWithConstant {
+    fn supports_rewrite(&self) -> bool {
+        true
+    }
+
+    fn rewrite(
+        &self,
+        plan: LogicalPlan,
+        _config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>> {
+        match plan {
+            // This rule specifically targets Aggregate nodes
+            LogicalPlan::Aggregate(aggregate) => {
+                // Step 1: Identify which expressions can be rewritten and group them by base column
+                let rewrite_info = analyze_aggregate(&aggregate)?;
+
+                if rewrite_info.is_empty() {
+                    // No groups found with 2+ matching SUM expressions, return original plan
+                    return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate)));
+                }
+
+                // Step 2: Perform the actual transformation into Aggregate + Projection
+                transform_aggregate(aggregate, &rewrite_info)
+            }
+            // Non-aggregate plans are passed through unchanged
+            _ => Ok(Transformed::no(plan)),
+        }
+    }
+
+    fn name(&self) -> &str {
+        "rewrite_aggregate_with_constant"
+    }
+
+    fn apply_order(&self) -> Option<ApplyOrder> {
+        // Bottom-up ensures we optimize subqueries before the outer query
+        Some(ApplyOrder::BottomUp)
+    }
+}
+
+/// Internal structure to track metadata for a SUM expression that qualifies for rewrite.
+#[derive(Debug, Clone)]
+struct SumWithConstant {
+    /// The inner expression being summed (e.g., the `a` in `SUM(a + 1)`)
+    base_expr: Expr,
+    /// The constant value being added/subtracted (e.g., `1` in `SUM(a + 1)`)
+    constant: ScalarValue,
+    /// The operator (`+` or `-`)
+    operator: Operator,
+    /// The index in the original Aggregate's `aggr_expr` list, used to maintain output order
+    original_index: usize,
+    // Note: ORDER BY inside SUM is irrelevant because SUM is commutative —
+    // the order of addition doesn't change the result. If this rule is ever
+    // extended to non-commutative aggregates, ORDER BY handling would need
+    // to be added back.
+}
+
+/// Maps a base expression's schema name to all its SUM(base ± const) variants.
+/// We use IndexMap to preserve insertion order, ensuring deterministic output
+/// in the rewritten plan (important for stable EXPLAIN output in tests).
+type RewriteGroups = IndexMap<String, Vec<SumWithConstant>>;
+
+/// Scans the aggregate expressions to find candidates for the rewrite.
+fn analyze_aggregate(aggregate: &Aggregate) -> Result<RewriteGroups> {
+    let mut groups: RewriteGroups = IndexMap::new();
+
+    for (idx, expr) in aggregate.aggr_expr.iter().enumerate() {
+        // Try to match the pattern SUM(col ± lit)
+        if let Some(sum_info) = extract_sum_with_constant(expr, idx)? {
+            let key = sum_info.base_expr.schema_name().to_string();
+            groups.entry(key).or_default().push(sum_info);
+        }
+    }
+
+    // Optimization: Only rewrite if we have at least 2 expressions for the same column.
+    // If there's only one SUM(a + 1), rewriting it to SUM(a) + 1*COUNT(a)
+    // actually increases the work (1 agg -> 2 aggs).
+    groups.retain(|_, v| v.len() >= 2);
+
+    Ok(groups)
+}
+
+/// Extract SUM(base_expr ± constant) pattern from an expression.
+/// Handles both `Expr::AggregateFunction(...)` and `Expr::Alias(Expr::AggregateFunction(...))`
+/// so the rule works regardless of whether aggregate expressions carry aliases
+/// (e.g., when plans are built via the LogicalPlanBuilder API).
+fn extract_sum_with_constant(expr: &Expr, idx: usize) -> Result<Option<SumWithConstant>> {
+    // Unwrap Expr::Alias if present — the SQL planner puts aliases in a
+    // Projection above the Aggregate, but the builder API allows aliases
+    // directly inside aggr_expr.
+    let inner = match expr {
+        Expr::Alias(alias) => alias.expr.as_ref(),
+        other => other,
+    };
+
+    match inner {
+        Expr::AggregateFunction(agg_fn) => {
+            // Rule only applies to SUM

Review Comment:
   Checking for SUM this way is not ideal: if someone has registered a user-defined 
function named `sum` with different behavior/semantics than the built-in SUM 
aggregate, this rule will still trigger.
   
   > Hi, is this a blocking PR review comment? I am not able to understand how 
it would be simpler?
   > 
   > I am thinking that we still have the same logical flow
   > 
   > * Detect the pattern of SUM(col + literal)
   > * Modify the plan appropriately.
   
   I think it would be better for several reasons:
   1. A new optimizer pass adds non-trivial overhead (I think it ends up copying 
each plan node), so planning time goes up with each new rule we add to base 
DataFusion
   2. It would fix the issue above (a user-defined `sum` with different semantics still triggering this rule)



##########
datafusion/optimizer/src/rewrite_aggregate_with_constant.rs:
##########
@@ -0,0 +1,609 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`RewriteAggregateWithConstant`] rewrites `SUM(column ± constant)` to `SUM(column) ± constant * COUNT(column)`
+
+use crate::optimizer::ApplyOrder;
+use crate::{OptimizerConfig, OptimizerRule};
+
+use std::collections::HashMap;
+
+use datafusion_common::tree_node::Transformed;
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::expr::AggregateFunctionParams;
+use datafusion_expr::{
+    Aggregate, BinaryExpr, Expr, LogicalPlan, LogicalPlanBuilder, Operator, binary_expr,
+    col, lit,
+};
+use datafusion_functions_aggregate::expr_fn::{count, sum};
+use indexmap::IndexMap;
+
+/// Optimizer rule that rewrites `SUM(column ± constant)` expressions
+/// into `SUM(column) ± constant * COUNT(column)` when multiple such expressions
+/// exist for the same base column.
+///
+/// This reduces computation by calculating SUM once and deriving other values.
+///
+/// # Example
+/// ```sql
+/// SELECT SUM(a), SUM(a + 1), SUM(a + 2) FROM t;
+/// ```
+/// is rewritten into a Projection on top of an Aggregate:
+/// ```sql
+/// -- New Projection Node
+/// SELECT sum_a, sum_a + 1 * count_a, sum_a + 2 * count_a
+/// -- New Aggregate Node
+/// FROM (SELECT SUM(a) as sum_a, COUNT(a) as count_a FROM t);
+/// ```
+#[derive(Default, Debug)]
+pub struct RewriteAggregateWithConstant {}
+
+impl RewriteAggregateWithConstant {
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl OptimizerRule for RewriteAggregateWithConstant {
+    fn supports_rewrite(&self) -> bool {
+        true
+    }
+
+    fn rewrite(
+        &self,
+        plan: LogicalPlan,
+        _config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>> {
+        match plan {
+            // This rule specifically targets Aggregate nodes
+            LogicalPlan::Aggregate(aggregate) => {
+                // Step 1: Identify which expressions can be rewritten and group them by base column
+                let rewrite_info = analyze_aggregate(&aggregate)?;
+
+                if rewrite_info.is_empty() {
+                    // No groups found with 2+ matching SUM expressions, return original plan
+                    return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate)));
+                }
+
+                // Step 2: Perform the actual transformation into Aggregate + Projection
+                transform_aggregate(aggregate, &rewrite_info)
+            }
+            // Non-aggregate plans are passed through unchanged
+            _ => Ok(Transformed::no(plan)),
+        }
+    }
+
+    fn name(&self) -> &str {
+        "rewrite_aggregate_with_constant"
+    }
+
+    fn apply_order(&self) -> Option<ApplyOrder> {
+        // Bottom-up ensures we optimize subqueries before the outer query
+        Some(ApplyOrder::BottomUp)
+    }
+}
+
+/// Internal structure to track metadata for a SUM expression that qualifies for rewrite.
+#[derive(Debug, Clone)]
+struct SumWithConstant {
+    /// The inner expression being summed (e.g., the `a` in `SUM(a + 1)`)
+    base_expr: Expr,
+    /// The constant value being added/subtracted (e.g., `1` in `SUM(a + 1)`)
+    constant: ScalarValue,
+    /// The operator (`+` or `-`)
+    operator: Operator,
+    /// The index in the original Aggregate's `aggr_expr` list, used to maintain output order
+    original_index: usize,
+    // Note: ORDER BY inside SUM is irrelevant because SUM is commutative —
+    // the order of addition doesn't change the result. If this rule is ever
+    // extended to non-commutative aggregates, ORDER BY handling would need
+    // to be added back.
+}
+
+/// Maps a base expression's schema name to all its SUM(base ± const) variants.
+/// We use IndexMap to preserve insertion order, ensuring deterministic output
+/// in the rewritten plan (important for stable EXPLAIN output in tests).
+type RewriteGroups = IndexMap<String, Vec<SumWithConstant>>;
+
+/// Scans the aggregate expressions to find candidates for the rewrite.
+fn analyze_aggregate(aggregate: &Aggregate) -> Result<RewriteGroups> {
+    let mut groups: RewriteGroups = IndexMap::new();
+
+    for (idx, expr) in aggregate.aggr_expr.iter().enumerate() {
+        // Try to match the pattern SUM(col ± lit)
+        if let Some(sum_info) = extract_sum_with_constant(expr, idx)? {
+            let key = sum_info.base_expr.schema_name().to_string();
+            groups.entry(key).or_default().push(sum_info);
+        }
+    }
+
+    // Optimization: Only rewrite if we have at least 2 expressions for the same column.
+    // If there's only one SUM(a + 1), rewriting it to SUM(a) + 1*COUNT(a)
+    // actually increases the work (1 agg -> 2 aggs).
+    groups.retain(|_, v| v.len() >= 2);
+
+    Ok(groups)
+}
+
+/// Extract SUM(base_expr ± constant) pattern from an expression.
+/// Handles both `Expr::AggregateFunction(...)` and `Expr::Alias(Expr::AggregateFunction(...))`
+/// so the rule works regardless of whether aggregate expressions carry aliases
+/// (e.g., when plans are built via the LogicalPlanBuilder API).
+fn extract_sum_with_constant(expr: &Expr, idx: usize) -> Result<Option<SumWithConstant>> {
+    // Unwrap Expr::Alias if present — the SQL planner puts aliases in a
+    // Projection above the Aggregate, but the builder API allows aliases
+    // directly inside aggr_expr.
+    let inner = match expr {
+        Expr::Alias(alias) => alias.expr.as_ref(),
+        other => other,
+    };
+
+    match inner {
+        Expr::AggregateFunction(agg_fn) => {
+            // Rule only applies to SUM

Review Comment:
   I think @UBarney made the same suggestion here:
   - https://github.com/apache/datafusion/issues/15524#issuecomment-3876600012



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to