This is an automated email from the ASF dual-hosted git repository.
yaooqinn pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.x by this push:
new f74802e99282 [SPARK-56949][SQL] DecimalAggregates drops evalMode when rewriting Sum/Average
f74802e99282 is described below
commit f74802e992825d19b5817f1561a54bc898768d58
Author: Kent Yao <[email protected]>
AuthorDate: Thu May 21 02:06:41 2026 +0800
[SPARK-56949][SQL] DecimalAggregates drops evalMode when rewriting Sum/Average
### What changes were proposed in this pull request?
In `Optimizer.DecimalAggregates`, switch the 6 `Sum` / `Average` rewrites
from the single-arg helper ctor (`Sum(child)` / `Average(child)`) to
`original.copy(child = UnscaledValue(...))`, so sibling fields like
`evalMode` are preserved instead of re-read from `SQLConf`. The 6 sites
are 4 Aggregate arms (SUM/AVG x un-widened/widened peel) plus 2 Window
arms (un-widened SUM and AVG); both Window arms are directly reachable
because the input is a bare `AttributeReference`, not a hoisted
composite Cast.
The helper ctors themselves are retained because other call sites still use them.
### Why are the changes needed?
Latent bug since SPARK-38432 added `evalMode` to `Sum` / `Average`: the
helper ctor re-reads `EvalMode` from `SQLConf`, silently dropping
`EvalMode.TRY` from `try_sum` / `try_avg`. `Sum.shouldTrackIsEmpty =
evalMode == TRY` then drops the `isEmpty` aggregation buffer column, so
overflow no longer returns NULL.
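
The mechanism can be reproduced without Spark. The following is a minimal sketch, not the Catalyst code: `ToySum`, `FakeConf`, `rewriteWithHelperCtor`, and `rewriteWithCopy` are hypothetical names introduced only for illustration, and the child expression is modeled as a plain string. It assumes only what is described above: a helper constructor that fills `evalMode` from configuration, and a `shouldTrackIsEmpty` flag derived from `evalMode`.

```scala
// Toy model of the bug; none of these types are Spark classes.
object EvalModePreservation {
  sealed trait EvalMode
  object EvalMode {
    case object LEGACY extends EvalMode
    case object ANSI extends EvalMode
    case object TRY extends EvalMode
  }

  // Hypothetical stand-in for the session config consulted by the helper ctor.
  object FakeConf {
    val defaultMode: EvalMode = EvalMode.ANSI
  }

  final case class ToySum(child: String, evalMode: EvalMode) {
    // Single-arg helper ctor: evalMode comes from config, not from the caller.
    def this(child: String) = this(child, FakeConf.defaultMode)

    // Mirrors the relationship quoted above: only TRY tracks emptiness.
    def shouldTrackIsEmpty: Boolean = evalMode == EvalMode.TRY
  }

  // Pre-fix shape: build a fresh node, so sibling fields are re-derived.
  def rewriteWithHelperCtor(s: ToySum): ToySum =
    new ToySum(s"UnscaledValue(${s.child})")

  // Post-fix shape: copy the matched node and replace only the child.
  def rewriteWithCopy(s: ToySum): ToySum =
    s.copy(child = s"UnscaledValue(${s.child})")

  def main(args: Array[String]): Unit = {
    val trySum = ToySum("a", EvalMode.TRY)
    val lost = rewriteWithHelperCtor(trySum)
    val kept = rewriteWithCopy(trySum)

    assert(lost.evalMode == EvalMode.ANSI && !lost.shouldTrackIsEmpty)
    assert(kept.evalMode == EvalMode.TRY && kept.shouldTrackIsEmpty)
    println(s"helper ctor: $lost")
    println(s"copy:        $kept")
  }
}
```

Using `copy` keeps the rewrite correct even if more sibling fields are added later, since only the named field changes.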
### Does this PR introduce _any_ user-facing change?
Yes. `try_sum` / `try_avg` on narrow-precision decimals now preserve
`EvalMode.TRY` through the optimizer rewrite, so overflow returns NULL
as documented instead of wrapping (`ansi=false`) or throwing
(`ansi=true`). Affected surface: SUM `precision <= 8`, AVG `precision <= 7`
(both Aggregate and Window arms).
This is observable directly in the optimized plan — the rewritten Sum /
Average node carries `evalMode = TRY` post-fix vs `evalMode = ANSI` (the
SQLConf default, re-read by the helper ctor) pre-fix.
### How was this patch tested?
6 new unit tests in `DecimalAggregatesSuite` asserting `evalMode == TRY`
post-rewrite for: un-widened Aggregate SUM, un-widened Aggregate AVG,
widened-cast Aggregate SUM, widened-cast Aggregate AVG, un-widened
Window SUM, un-widened Window AVG.
- Pre-fix: 6/6 FAIL with `Expected evalMode=TRY post-rewrite, got ANSI`
- Post-fix: 37/37 PASS (6 new + 31 existing, 0 regression)
- Sibling regressions clean: `DecimalPrecisionSuite` +
`AggregateOptimizeSuite` 20/20 PASS
- `dev/lint-scala` clean
**Backport candidates**: branch-3.5 and branch-4.0 carry the 4 helper-ctor
sites (2 Aggregate + 2 Window) — verified by `git grep "aggregateFunction =
(Sum|Average)\(UnscaledValue"`. The 2 widened-cast Aggregate sites are
master-only (SPARK-56627). Backport patches would touch only the 4 shared
sites.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Opus 4.7
Closes #56009 from yaooqinn/users/kentyao/decimal-aggregate-evalmode-preserve.
Authored-by: Kent Yao <[email protected]>
Signed-off-by: Kent Yao <[email protected]>
(cherry picked from commit 97325877edfad314bff3326b6b9b98a53bbce43d)
Signed-off-by: Kent Yao <[email protected]>
---
.../spark/sql/catalyst/optimizer/Optimizer.scala | 36 ++++----
.../optimizer/DecimalAggregatesSuite.scala | 95 +++++++++++++++++++++-
2 files changed, 114 insertions(+), 17 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 22cd7fda2414..b3e7eb44ae65 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -2587,13 +2587,16 @@ object DecimalAggregates extends Rule[LogicalPlan] {
// Window arm: `ExtractWindowExpressions` hoists composite children
// (here the widening Cast) into a child Project, so widened-Cast
// peel is unreachable from this expression-level rule.
- case Sum(e @ DecimalExpression(prec, scale), _) if prec + 10 <= MAX_LONG_DIGITS =>
- MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))),
+ case s @ Sum(e @ DecimalExpression(prec, scale), _) if prec + 10 <= MAX_LONG_DIGITS =>
+ MakeDecimal(we.copy(windowFunction =
+ ae.copy(aggregateFunction = s.copy(child = UnscaledValue(e)))),
prec + 10, scale)
- case Average(e @ DecimalExpression(prec, scale), _) if prec + 4 <= MAX_DOUBLE_DIGITS =>
+ case a @ Average(e @ DecimalExpression(prec, scale), _)
+ if prec + 4 <= MAX_DOUBLE_DIGITS =>
val newAggExpr =
- we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e))))
+ we.copy(windowFunction = ae.copy(aggregateFunction =
+ a.copy(child = UnscaledValue(e))))
Cast(
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
DecimalType(prec + 4, scale + 4), Option(conf.sessionLocalTimeZone))
@@ -2601,29 +2604,30 @@ object DecimalAggregates extends Rule[LogicalPlan] {
case _ => we
}
case ae @ AggregateExpression(af, _, _, _, _) => af match {
- case Sum(WidenedDecimalChild(inner, p, pPrime, s), _)
+ case s @ Sum(WidenedDecimalChild(inner, p, pPrime, s_scale), _)
if p + 10 <= MAX_LONG_DIGITS =>
Cast(
MakeDecimal(
- ae.copy(aggregateFunction = Sum(UnscaledValue(inner))),
- p + 10, s),
- DecimalType.bounded(pPrime + 10, s),
+ ae.copy(aggregateFunction = s.copy(child = UnscaledValue(inner))),
+ p + 10, s_scale),
+ DecimalType.bounded(pPrime + 10, s_scale),
Option(conf.sessionLocalTimeZone))
- case Sum(e @ DecimalExpression(prec, scale), _) if prec + 10 <= MAX_LONG_DIGITS =>
- MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)
+ case s @ Sum(e @ DecimalExpression(prec, scale), _) if prec + 10 <= MAX_LONG_DIGITS =>
+ MakeDecimal(ae.copy(aggregateFunction = s.copy(child = UnscaledValue(e))),
+ prec + 10, scale)
// Ordered before the un-widened Average arm: when pPrime in [8, 11],
// the outer Cast's DecimalType would otherwise match that arm first.
- case Average(WidenedDecimalChild(inner, p, pPrime, s), _)
+ case a @ Average(WidenedDecimalChild(inner, p, pPrime, s_scale), _)
if p <= AVG_PEEL_MAX_INNER_PRECISION =>
- val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(inner)))
+ val newAggExpr = ae.copy(aggregateFunction = a.copy(child = UnscaledValue(inner)))
Cast(
- Divide(newAggExpr, Literal.create(math.pow(10.0, s), DoubleType)),
- DecimalType.bounded(pPrime + 4, s + 4), Option(conf.sessionLocalTimeZone))
+ Divide(newAggExpr, Literal.create(math.pow(10.0, s_scale), DoubleType)),
+ DecimalType.bounded(pPrime + 4, s_scale + 4), Option(conf.sessionLocalTimeZone))
- case Average(e @ DecimalExpression(prec, scale), _) if prec + 4 <= MAX_DOUBLE_DIGITS =>
- val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
+ case a @ Average(e @ DecimalExpression(prec, scale), _) if prec + 4 <= MAX_DOUBLE_DIGITS =>
+ val newAggExpr = ae.copy(aggregateFunction = a.copy(child = UnscaledValue(e)))
Cast(
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
DecimalType(prec + 4, scale + 4), Option(conf.sessionLocalTimeZone))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala
index bc5f8b984f5c..6f8c0db261b2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala
@@ -23,7 +23,7 @@ import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -582,4 +582,97 @@ class DecimalAggregatesSuite extends PlanTest with ScalaCheckDrivenPropertyCheck
val correctAnswer = q.analyze
comparePlans(optimized, correctAnswer)
}
+
+ // ---------------------------------------------------------------------------
+ // SPARK-56949: DecimalAggregates must preserve evalMode / evalContext when
+ // rewriting Sum / Average through the fast-path. The pre-fix rule called the
+ // single-arg helper ctor `Sum(child)` / `Average(child)`, which re-reads
+ // EvalMode from SQLConf and silently drops EvalMode.TRY from try_sum /
+ // try_avg, breaking their "return NULL on overflow" semantics.
+ //
+ // Vanilla 3.5.3 ground-truth (rule OFF vs ON) recorded in todos repo:
+ // features/spark-decimal-aggregate-evalmode-preserve/docs/0001-idea.md (section 3)
+ // ---------------------------------------------------------------------------
+
+ private def findSum(plan: LogicalPlan): Seq[Sum] =
+ plan.collect { case n => n.expressions }.flatten
+ .flatMap(_.collect { case s: Sum => s })
+ private def findAverage(plan: LogicalPlan): Seq[Average] =
+ plan.collect { case n => n.expressions }.flatten
+ .flatMap(_.collect { case a: Average => a })
+
+ test("SPARK-56949: DecimalAggregates preserves Sum.evalContext for try_sum") {
+ val trySum = Sum($"a", NumericEvalContext(EvalMode.TRY))
+ val q = testRelation.select(trySum.toAggregateExpression().as("ts"))
+ val optimized = Optimize.execute(q.analyze)
+ val sums = findSum(optimized)
+ assert(sums.nonEmpty, "DecimalAggregates fast path should fire for dec(2,1)")
+ assert(sums.forall(_.evalContext.evalMode == EvalMode.TRY),
+ s"evalMode should be preserved as TRY after rewrite, got " +
+ sums.map(_.evalContext.evalMode).mkString(","))
+ }
+
+ test("SPARK-56949: DecimalAggregates preserves Average.evalMode for try_avg") {
+ val tryAvg = Average($"a", EvalMode.TRY)
+ val q = testRelation.select(tryAvg.toAggregateExpression().as("ta"))
+ val optimized = Optimize.execute(q.analyze)
+ val avgs = findAverage(optimized)
+ assert(avgs.nonEmpty, "DecimalAggregates fast path should fire for dec(2,1)")
+ assert(avgs.forall(_.evalMode == EvalMode.TRY),
+ s"evalMode should be preserved as TRY after rewrite, got " +
+ avgs.map(_.evalMode).mkString(","))
+ }
+
+ test("SPARK-56949: DecimalAggregates preserves Sum.evalContext " +
+ "for try_sum on widened-cast peel arm") {
+ val trySum = Sum($"d7_2".cast(DecimalType(12, 2)),
+ NumericEvalContext(EvalMode.TRY))
+ val q = widenRel.select(trySum.toAggregateExpression().as("ts"))
+ val optimized = Optimize.execute(q.analyze)
+ val sums = findSum(optimized)
+ assert(sums.nonEmpty, "widened-cast SUM peel should fire for dec(7,2)->dec(12,2)")
+ assert(sums.forall(_.evalContext.evalMode == EvalMode.TRY),
+ s"evalMode should be preserved as TRY after rewrite, got " +
+ sums.map(_.evalContext.evalMode).mkString(","))
+ }
+
+ test("SPARK-56949: DecimalAggregates preserves Average.evalMode " +
+ "for try_avg on widened-cast peel arm") {
+ val tryAvg = Average($"d7_2".cast(DecimalType(12, 2)), EvalMode.TRY)
+ val q = widenRel.select(tryAvg.toAggregateExpression().as("ta"))
+ val optimized = Optimize.execute(q.analyze)
+ val avgs = findAverage(optimized)
+ assert(avgs.nonEmpty, "widened-cast AVG peel should fire for dec(7,2)->dec(12,2)")
+ assert(avgs.forall(_.evalMode == EvalMode.TRY),
+ s"evalMode should be preserved as TRY after rewrite, got " +
+ avgs.map(_.evalMode).mkString(","))
+ }
+
+ test("SPARK-56949: DecimalAggregates preserves Sum.evalContext " +
+ "for try_sum over Window (un-widened arm)") {
+ val spec = windowSpec(Seq($"a"), Nil, UnspecifiedFrame)
+ val trySum = Sum($"a", NumericEvalContext(EvalMode.TRY))
+ val q = testRelation.select(
+ windowExpr(trySum.toAggregateExpression(), spec).as("ts"))
+ val optimized = Optimize.execute(q.analyze)
+ val sums = findSum(optimized)
+ assert(sums.nonEmpty, "Window-arm SUM peel should fire for dec(2,1)")
+ assert(sums.forall(_.evalContext.evalMode == EvalMode.TRY),
+ s"evalMode should be preserved as TRY after rewrite, got " +
+ sums.map(_.evalContext.evalMode).mkString(","))
+ }
+
+ test("SPARK-56949: DecimalAggregates preserves Average.evalMode " +
+ "for try_avg over Window (un-widened arm)") {
+ val spec = windowSpec(Seq($"a"), Nil, UnspecifiedFrame)
+ val tryAvg = Average($"a", EvalMode.TRY)
+ val q = testRelation.select(
+ windowExpr(tryAvg.toAggregateExpression(), spec).as("ta"))
+ val optimized = Optimize.execute(q.analyze)
+ val avgs = findAverage(optimized)
+ assert(avgs.nonEmpty, "Window-arm AVG peel should fire for dec(2,1)")
+ assert(avgs.forall(_.evalMode == EvalMode.TRY),
+ s"evalMode should be preserved as TRY after rewrite, got " +
+ avgs.map(_.evalMode).mkString(","))
+ }
}