devanshu0987 commented on code in PR #20180:
URL: https://github.com/apache/datafusion/pull/20180#discussion_r2796616754


##########
datafusion/optimizer/src/rewrite_aggregate_with_constant.rs:
##########
@@ -0,0 +1,607 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`RewriteAggregateWithConstant`] rewrites `SUM(column ± constant)` to
+//! `SUM(column) ± constant * COUNT(column)`
+
+use crate::optimizer::ApplyOrder;
+use crate::{OptimizerConfig, OptimizerRule};
+
+use std::collections::HashMap;
+
+use datafusion_common::tree_node::Transformed;
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::expr::AggregateFunctionParams;
+use datafusion_expr::{
+    Aggregate, BinaryExpr, Expr, LogicalPlan, LogicalPlanBuilder, Operator, 
binary_expr,
+    col, lit,
+};
+use datafusion_functions_aggregate::expr_fn::{count, sum};
+use indexmap::IndexMap;
+
+/// Optimizer rule that turns `SUM(column ± constant)` aggregates into
+/// `SUM(column) ± constant * COUNT(column)` when several such aggregates
+/// share the same base column.
+///
+/// The shared `SUM` is computed once and the other values are derived
+/// from it, which reduces the amount of aggregation work.
+///
+/// # Example
+/// ```sql
+/// SELECT SUM(a), SUM(a + 1), SUM(a + 2) FROM t;
+/// ```
+/// becomes a Projection stacked on top of an Aggregate:
+/// ```sql
+/// -- New Projection Node
+/// SELECT sum_a, sum_a + 1 * count_a, sum_a + 2 * count_a
+/// -- New Aggregate Node
+/// FROM (SELECT SUM(a) as sum_a, COUNT(a) as count_a FROM t);
+/// ```
+#[derive(Debug, Default)]
+pub struct RewriteAggregateWithConstant {}
+
+impl RewriteAggregateWithConstant {
+    /// Creates the rule; it has no fields, so this equals `Default`.
+    pub fn new() -> Self {
+        Self::default()
+    }
+}
+
+impl OptimizerRule for RewriteAggregateWithConstant {
+    fn name(&self) -> &str {
+        "rewrite_aggregate_with_constant"
+    }
+
+    fn apply_order(&self) -> Option<ApplyOrder> {
+        // Bottom-up, so that subqueries are optimized before the outer
+        // query that contains them.
+        Some(ApplyOrder::BottomUp)
+    }
+
+    fn supports_rewrite(&self) -> bool {
+        true
+    }
+
+    /// Rewrites an Aggregate node into Aggregate + Projection when
+    /// `analyze_aggregate` finds at least one rewritable group; every
+    /// other plan is returned untouched as `Transformed::no`.
+    fn rewrite(
+        &self,
+        plan: LogicalPlan,
+        _config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>> {
+        // Only Aggregate nodes are of interest; hand anything else back.
+        let aggregate = match plan {
+            LogicalPlan::Aggregate(aggregate) => aggregate,
+            other => return Ok(Transformed::no(other)),
+        };
+
+        // Step 1: find the rewritable expressions, grouped by base column.
+        let rewrite_info = analyze_aggregate(&aggregate)?;
+
+        // No group with 2+ matching SUM expressions: keep the original.
+        if rewrite_info.is_empty() {
+            return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate)));
+        }
+
+        // Step 2: build the replacement Aggregate + Projection.
+        transform_aggregate(aggregate, &rewrite_info)
+    }
+}
+
+/// Metadata recorded for one `SUM(base ± constant)` aggregate that
+/// qualifies for the rewrite.
+#[derive(Clone, Debug)]
+struct SumWithConstant {
+    /// Expression being summed once the constant is taken away
+    /// (e.g., the `a` in `SUM(a + 1)`)
+    base_expr: Expr,
+    /// Literal being added or subtracted (e.g., the `1` in `SUM(a + 1)`)
+    constant: ScalarValue,
+    /// Either `+` or `-`
+    operator: Operator,
+    /// Position in the original Aggregate's `aggr_expr` list, kept so
+    /// the output order can be maintained
+    original_index: usize,
+    // Note: no ORDER BY is stored. ORDER BY inside SUM is irrelevant
+    // because SUM is commutative: the order of addition doesn't change
+    // the result. If this rule is ever extended to non-commutative
+    // aggregates, ORDER BY handling would need to be added back.
+}
+
+/// All `SUM(base ± const)` variants, keyed by the schema name of their
+/// base expression. An `IndexMap` keeps insertion order, so the
+/// rewritten plan is deterministic (important for stable EXPLAIN output
+/// in tests).
+type RewriteGroups = IndexMap<String, Vec<SumWithConstant>>;
+
+/// Collects the aggregate expressions that qualify for the rewrite,
+/// grouped by the schema name of their base expression.
+///
+/// Groups with fewer than two members are dropped, so an empty result
+/// means there is nothing worth rewriting.
+fn analyze_aggregate(aggregate: &Aggregate) -> Result<RewriteGroups> {
+    let mut by_base = RewriteGroups::new();
+
+    for (position, candidate_expr) in aggregate.aggr_expr.iter().enumerate() {
+        // Skip everything that is not of the form SUM(col ± lit).
+        let Some(candidate) =
+            extract_sum_with_constant(candidate_expr, position)?
+        else {
+            continue;
+        };
+
+        let base_name = candidate.base_expr.schema_name().to_string();
+        by_base.entry(base_name).or_default().push(candidate);
+    }
+
+    // A lone SUM(a + 1) is not worth rewriting: SUM(a) + 1*COUNT(a)
+    // would replace one aggregate with two. Require at least two
+    // expressions over the same base.
+    by_base.retain(|_, members| members.len() >= 2);
+
+    Ok(by_base)
+}
+
+/// Extract SUM(base_expr ± constant) pattern from an expression.
+/// Handles both `Expr::AggregateFunction(...)` and 
`Expr::Alias(Expr::AggregateFunction(...))`
+/// so the rule works regardless of whether aggregate expressions carry aliases
+/// (e.g., when plans are built via the LogicalPlanBuilder API).
+fn extract_sum_with_constant(expr: &Expr, idx: usize) -> 
Result<Option<SumWithConstant>> {
+    // Unwrap Expr::Alias if present — the SQL planner puts aliases in a
+    // Projection above the Aggregate, but the builder API allows aliases
+    // directly inside aggr_expr.
+    let inner = match expr {
+        Expr::Alias(alias) => alias.expr.as_ref(),
+        other => other,
+    };
+
+    match inner {
+        Expr::AggregateFunction(agg_fn) => {
+            // Rule only applies to SUM
+            if agg_fn.func.name().to_lowercase() != "sum" {
+                return Ok(None);
+            }
+
+            let AggregateFunctionParams {
+                args,
+                distinct,
+                filter,
+                order_by: _,
+                null_treatment: _,
+            } = &agg_fn.params;
+
+            // We cannot easily rewrite SUM(DISTINCT a + 1) or SUM(a + 1) 
FILTER (...)
+            // as the math SUM(a) + k*COUNT(a) wouldn't hold correctly with 
these modifiers.
+            if *distinct || filter.is_some() {
+                return Ok(None);
+            }
+
+            // SUM must have exactly one argument (e.g. SUM(a + 1)).
+            // This rejects invalid calls like SUM() or non-standard 
multi-argument variations.
+            if args.len() != 1 {
+                return Ok(None);
+            }
+
+            let arg = &args[0];
+
+            // Try to match: base_expr +/- constant
+            // Note: If the base_expr is complex (e.g., SUM(a + b + 1)), 
base_expr will be "a + b".
+            // The rule will still work if multiple SUMs have the exact same 
complex base_expr,
+            // as they will be grouped by the string representation of that 
expression.
+            if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = arg
+                && matches!(op, Operator::Plus | Operator::Minus)
+            {
+                // Check if right side is a literal constant
+                // Check if right side is a literal constant (e.g., SUM(a + 1))
+                if let Expr::Literal(constant, _) = right.as_ref()
+                    && is_numeric_constant(constant)
+                {
+                    return Ok(Some(SumWithConstant {
+                        base_expr: (**left).clone(),
+                        constant: constant.clone(),
+                        operator: *op,
+                        original_index: idx,
+                    }));
+                }
+
+                // Also check left side for commutative addition (e.g., SUM(1 
+ a))
+                // Does NOT apply to subtraction: SUM(5 - a) ≠ SUM(a - 5)
+                if let Expr::Literal(constant, _) = left.as_ref()
+                    && is_numeric_constant(constant)
+                    && *op == Operator::Plus
+                {
+                    return Ok(Some(SumWithConstant {
+                        base_expr: (**right).clone(),
+                        constant: constant.clone(),
+                        operator: Operator::Plus,
+                        original_index: idx,
+                    }));
+                }
+            }
+
+            Ok(None)
+        }
+        _ => Ok(None),
+    }
+}
+
+/// Check if a scalar value is a numeric constant
+/// (guards against non-arithmetic types like strings, booleans, dates, etc.)
+fn is_numeric_constant(value: &ScalarValue) -> bool {
+    matches!(
+        value,
+        ScalarValue::Int8(_)
+            | ScalarValue::Int16(_)
+            | ScalarValue::Int32(_)
+            | ScalarValue::Int64(_)
+            | ScalarValue::UInt8(_)
+            | ScalarValue::UInt16(_)
+            | ScalarValue::UInt32(_)
+            | ScalarValue::UInt64(_)
+            | ScalarValue::Float32(_)
+            | ScalarValue::Float64(_)
+            | ScalarValue::Decimal128(_, _, _)
+            | ScalarValue::Decimal256(_, _, _)

Review Comment:
   Hi, yes, let me update this. Thanks.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to