This is an automated email from the ASF dual-hosted git repository.

yaooqinn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 97325877edfa [SPARK-56949][SQL] DecimalAggregates drops evalMode when 
rewriting Sum/Average
97325877edfa is described below

commit 97325877edfad314bff3326b6b9b98a53bbce43d
Author: Kent Yao <[email protected]>
AuthorDate: Thu May 21 02:06:41 2026 +0800

    [SPARK-56949][SQL] DecimalAggregates drops evalMode when rewriting 
Sum/Average
    
    ### What changes were proposed in this pull request?
    
    In `Optimizer.DecimalAggregates`, change the 6 `Sum` / `Average` rewrites
    from the single-arg helper ctor (`Sum(child)` / `Average(child)`) to
    `original.copy(child = UnscaledValue(...))`, so sibling fields such as
    `Sum.evalContext` and `Average.evalMode` are preserved instead of being
    re-read from `SQLConf`. The 6 sites are 4 Aggregate arms (SUM/AVG x
    un-widened/widened peel) plus 2 Window arms (un-widened SUM and AVG).
    Both Window arms are directly reachable because the window function's
    input is a bare `AttributeReference`: `ExtractWindowExpressions` hoists
    composite children such as a widening Cast into a child Project, which
    is also why there is no widened-Cast peel on the Window side.
    
    The helper ctors themselves are retained because other call sites still use them.
    
    ### Why are the changes needed?
    
    Latent bug since SPARK-38432 added `evalMode` to `Sum` / `Average`: the
    helper ctor re-reads `EvalMode` from `SQLConf`, silently dropping
    `EvalMode.TRY` from `try_sum` / `try_avg`. `Sum.shouldTrackIsEmpty =
    evalMode == TRY` then drops the `isEmpty` aggregation buffer column, so
    overflow no longer returns NULL.
    
    ### Does this PR introduce _any_ user-facing change?
    
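    Yes. `try_sum` / `try_avg` on narrow-precision decimals now preserve
    `EvalMode.TRY` through the optimizer rewrite, so overflow returns NULL
    as documented instead of wrapping (`ansi=false`) or throwing
    (`ansi=true`). Affected surface: SUM on decimals with `precision <= 8`
    (`prec + 10 <= MAX_LONG_DIGITS`) and AVG on decimals with
    `precision <= 11` (`prec + 4 <= MAX_DOUBLE_DIGITS`), in both the
    Aggregate and Window arms. The widened-cast Aggregate peels are also
    affected (inner precision bounded by the same SUM limit, and by
    `AVG_PEEL_MAX_INNER_PRECISION` for AVG).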
    
    This is observable directly in the optimized plan — the rewritten Sum /
    Average node carries `evalMode = TRY` post-fix vs `evalMode = ANSI` (the
    SQLConf default, re-read by the helper ctor) pre-fix.
    
    ### How was this patch tested?
    
    6 new unit tests in `DecimalAggregatesSuite` asserting `evalMode == TRY`
    post-rewrite for: un-widened Aggregate SUM, un-widened Aggregate AVG,
    widened-cast Aggregate SUM, widened-cast Aggregate AVG, un-widened
    Window SUM, un-widened Window AVG.
    
    - Pre-fix: 6/6 FAIL with `evalMode should be preserved as TRY after rewrite, got ANSI`
    - Post-fix: 37/37 PASS (6 new + 31 existing, 0 regression)
    - Sibling regressions clean: `DecimalPrecisionSuite` +
      `AggregateOptimizeSuite` 20/20 PASS
    - `dev/lint-scala` clean
    
    **Backport candidates**: branch-3.5 and branch-4.0 carry the 4 helper-ctor
    sites (2 Aggregate + 2 Window), verified by
    `git grep -E "aggregateFunction = (Sum|Average)\(UnscaledValue"`.
    The `-E` is required because the pattern relies on extended-regex
    alternation. The 2 widened-cast Aggregate sites are master-only
    (SPARK-56627), so backport patches would touch only the 4 shared sites.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Claude Opus 4.7
    
    Closes #56009 from 
yaooqinn/users/kentyao/decimal-aggregate-evalmode-preserve.
    
    Authored-by: Kent Yao <[email protected]>
    Signed-off-by: Kent Yao <[email protected]>
---
 .../spark/sql/catalyst/optimizer/Optimizer.scala   | 36 ++++----
 .../optimizer/DecimalAggregatesSuite.scala         | 95 +++++++++++++++++++++-
 2 files changed, 114 insertions(+), 17 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 22cd7fda2414..b3e7eb44ae65 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -2587,13 +2587,16 @@ object DecimalAggregates extends Rule[LogicalPlan] {
         // Window arm: `ExtractWindowExpressions` hoists composite children
         // (here the widening Cast) into a child Project, so widened-Cast
         // peel is unreachable from this expression-level rule.
-        case Sum(e @ DecimalExpression(prec, scale), _) if prec + 10 <= 
MAX_LONG_DIGITS =>
-          MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = 
Sum(UnscaledValue(e)))),
+        case s @ Sum(e @ DecimalExpression(prec, scale), _) if prec + 10 <= 
MAX_LONG_DIGITS =>
+          MakeDecimal(we.copy(windowFunction =
+            ae.copy(aggregateFunction = s.copy(child = UnscaledValue(e)))),
             prec + 10, scale)
 
-        case Average(e @ DecimalExpression(prec, scale), _) if prec + 4 <= 
MAX_DOUBLE_DIGITS =>
+        case a @ Average(e @ DecimalExpression(prec, scale), _)
+            if prec + 4 <= MAX_DOUBLE_DIGITS =>
           val newAggExpr =
-            we.copy(windowFunction = ae.copy(aggregateFunction = 
Average(UnscaledValue(e))))
+            we.copy(windowFunction = ae.copy(aggregateFunction =
+              a.copy(child = UnscaledValue(e))))
           Cast(
             Divide(newAggExpr, Literal.create(math.pow(10.0, scale), 
DoubleType)),
             DecimalType(prec + 4, scale + 4), 
Option(conf.sessionLocalTimeZone))
@@ -2601,29 +2604,30 @@ object DecimalAggregates extends Rule[LogicalPlan] {
         case _ => we
       }
       case ae @ AggregateExpression(af, _, _, _, _) => af match {
-        case Sum(WidenedDecimalChild(inner, p, pPrime, s), _)
+        case s @ Sum(WidenedDecimalChild(inner, p, pPrime, s_scale), _)
             if p + 10 <= MAX_LONG_DIGITS =>
           Cast(
             MakeDecimal(
-              ae.copy(aggregateFunction = Sum(UnscaledValue(inner))),
-              p + 10, s),
-            DecimalType.bounded(pPrime + 10, s),
+              ae.copy(aggregateFunction = s.copy(child = 
UnscaledValue(inner))),
+              p + 10, s_scale),
+            DecimalType.bounded(pPrime + 10, s_scale),
             Option(conf.sessionLocalTimeZone))
 
-        case Sum(e @ DecimalExpression(prec, scale), _) if prec + 10 <= 
MAX_LONG_DIGITS =>
-          MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec 
+ 10, scale)
+        case s @ Sum(e @ DecimalExpression(prec, scale), _) if prec + 10 <= 
MAX_LONG_DIGITS =>
+          MakeDecimal(ae.copy(aggregateFunction = s.copy(child = 
UnscaledValue(e))),
+            prec + 10, scale)
 
         // Ordered before the un-widened Average arm: when pPrime in [8, 11],
         // the outer Cast's DecimalType would otherwise match that arm first.
-        case Average(WidenedDecimalChild(inner, p, pPrime, s), _)
+        case a @ Average(WidenedDecimalChild(inner, p, pPrime, s_scale), _)
             if p <= AVG_PEEL_MAX_INNER_PRECISION =>
-          val newAggExpr = ae.copy(aggregateFunction = 
Average(UnscaledValue(inner)))
+          val newAggExpr = ae.copy(aggregateFunction = a.copy(child = 
UnscaledValue(inner)))
           Cast(
-            Divide(newAggExpr, Literal.create(math.pow(10.0, s), DoubleType)),
-            DecimalType.bounded(pPrime + 4, s + 4), 
Option(conf.sessionLocalTimeZone))
+            Divide(newAggExpr, Literal.create(math.pow(10.0, s_scale), 
DoubleType)),
+            DecimalType.bounded(pPrime + 4, s_scale + 4), 
Option(conf.sessionLocalTimeZone))
 
-        case Average(e @ DecimalExpression(prec, scale), _) if prec + 4 <= 
MAX_DOUBLE_DIGITS =>
-          val newAggExpr = ae.copy(aggregateFunction = 
Average(UnscaledValue(e)))
+        case a @ Average(e @ DecimalExpression(prec, scale), _) if prec + 4 <= 
MAX_DOUBLE_DIGITS =>
+          val newAggExpr = ae.copy(aggregateFunction = a.copy(child = 
UnscaledValue(e)))
           Cast(
             Divide(newAggExpr, Literal.create(math.pow(10.0, scale), 
DoubleType)),
             DecimalType(prec + 4, scale + 4), 
Option(conf.sessionLocalTimeZone))
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala
index bc5f8b984f5c..6f8c0db261b2 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala
@@ -23,7 +23,7 @@ import 
org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -582,4 +582,97 @@ class DecimalAggregatesSuite extends PlanTest with 
ScalaCheckDrivenPropertyCheck
     val correctAnswer = q.analyze
     comparePlans(optimized, correctAnswer)
   }
+
+  // 
---------------------------------------------------------------------------
+  // SPARK-56949: DecimalAggregates must preserve evalMode / evalContext when
+  // rewriting Sum / Average through the fast-path. The pre-fix rule called the
+  // single-arg helper ctor `Sum(child)` / `Average(child)`, which re-reads
+  // EvalMode from SQLConf and silently drops EvalMode.TRY from try_sum /
+  // try_avg, breaking their "return NULL on overflow" semantics.
+  //
+  // Vanilla 3.5.3 ground-truth (rule OFF vs ON) recorded in todos repo:
+  //   features/spark-decimal-aggregate-evalmode-preserve/docs/0001-idea.md 
(section 3)
+  // 
---------------------------------------------------------------------------
+
+  private def findSum(plan: LogicalPlan): Seq[Sum] =
+    plan.collect { case n => n.expressions }.flatten
+      .flatMap(_.collect { case s: Sum => s })
+  private def findAverage(plan: LogicalPlan): Seq[Average] =
+    plan.collect { case n => n.expressions }.flatten
+      .flatMap(_.collect { case a: Average => a })
+
+  test("SPARK-56949: DecimalAggregates preserves Sum.evalContext for try_sum") 
{
+    val trySum = Sum($"a", NumericEvalContext(EvalMode.TRY))
+    val q = testRelation.select(trySum.toAggregateExpression().as("ts"))
+    val optimized = Optimize.execute(q.analyze)
+    val sums = findSum(optimized)
+    assert(sums.nonEmpty, "DecimalAggregates fast path should fire for 
dec(2,1)")
+    assert(sums.forall(_.evalContext.evalMode == EvalMode.TRY),
+      s"evalMode should be preserved as TRY after rewrite, got " +
+        sums.map(_.evalContext.evalMode).mkString(","))
+  }
+
+  test("SPARK-56949: DecimalAggregates preserves Average.evalMode for 
try_avg") {
+    val tryAvg = Average($"a", EvalMode.TRY)
+    val q = testRelation.select(tryAvg.toAggregateExpression().as("ta"))
+    val optimized = Optimize.execute(q.analyze)
+    val avgs = findAverage(optimized)
+    assert(avgs.nonEmpty, "DecimalAggregates fast path should fire for 
dec(2,1)")
+    assert(avgs.forall(_.evalMode == EvalMode.TRY),
+      s"evalMode should be preserved as TRY after rewrite, got " +
+        avgs.map(_.evalMode).mkString(","))
+  }
+
+  test("SPARK-56949: DecimalAggregates preserves Sum.evalContext " +
+      "for try_sum on widened-cast peel arm") {
+    val trySum = Sum($"d7_2".cast(DecimalType(12, 2)),
+      NumericEvalContext(EvalMode.TRY))
+    val q = widenRel.select(trySum.toAggregateExpression().as("ts"))
+    val optimized = Optimize.execute(q.analyze)
+    val sums = findSum(optimized)
+    assert(sums.nonEmpty, "widened-cast SUM peel should fire for 
dec(7,2)->dec(12,2)")
+    assert(sums.forall(_.evalContext.evalMode == EvalMode.TRY),
+      s"evalMode should be preserved as TRY after rewrite, got " +
+        sums.map(_.evalContext.evalMode).mkString(","))
+  }
+
+  test("SPARK-56949: DecimalAggregates preserves Average.evalMode " +
+      "for try_avg on widened-cast peel arm") {
+    val tryAvg = Average($"d7_2".cast(DecimalType(12, 2)), EvalMode.TRY)
+    val q = widenRel.select(tryAvg.toAggregateExpression().as("ta"))
+    val optimized = Optimize.execute(q.analyze)
+    val avgs = findAverage(optimized)
+    assert(avgs.nonEmpty, "widened-cast AVG peel should fire for 
dec(7,2)->dec(12,2)")
+    assert(avgs.forall(_.evalMode == EvalMode.TRY),
+      s"evalMode should be preserved as TRY after rewrite, got " +
+        avgs.map(_.evalMode).mkString(","))
+  }
+
+  test("SPARK-56949: DecimalAggregates preserves Sum.evalContext " +
+      "for try_sum over Window (un-widened arm)") {
+    val spec = windowSpec(Seq($"a"), Nil, UnspecifiedFrame)
+    val trySum = Sum($"a", NumericEvalContext(EvalMode.TRY))
+    val q = testRelation.select(
+      windowExpr(trySum.toAggregateExpression(), spec).as("ts"))
+    val optimized = Optimize.execute(q.analyze)
+    val sums = findSum(optimized)
+    assert(sums.nonEmpty, "Window-arm SUM peel should fire for dec(2,1)")
+    assert(sums.forall(_.evalContext.evalMode == EvalMode.TRY),
+      s"evalMode should be preserved as TRY after rewrite, got " +
+        sums.map(_.evalContext.evalMode).mkString(","))
+  }
+
+  test("SPARK-56949: DecimalAggregates preserves Average.evalMode " +
+      "for try_avg over Window (un-widened arm)") {
+    val spec = windowSpec(Seq($"a"), Nil, UnspecifiedFrame)
+    val tryAvg = Average($"a", EvalMode.TRY)
+    val q = testRelation.select(
+      windowExpr(tryAvg.toAggregateExpression(), spec).as("ta"))
+    val optimized = Optimize.execute(q.analyze)
+    val avgs = findAverage(optimized)
+    assert(avgs.nonEmpty, "Window-arm AVG peel should fire for dec(2,1)")
+    assert(avgs.forall(_.evalMode == EvalMode.TRY),
+      s"evalMode should be preserved as TRY after rewrite, got " +
+        avgs.map(_.evalMode).mkString(","))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to