This is an automated email from the ASF dual-hosted git repository.
yaooqinn pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.x by this push:
new 4b12d57c4919 [SPARK-56627][SQL] DecimalAggregates: peel
scale-preserving widening Cast for SUM and AVG
4b12d57c4919 is described below
commit 4b12d57c491914d83b70556febebedafe9b9f617
Author: Kent Yao <[email protected]>
AuthorDate: Tue May 19 19:41:48 2026 +0800
[SPARK-56627][SQL] DecimalAggregates: peel scale-preserving widening Cast
for SUM and AVG
### What changes were proposed in this pull request?
Extend `DecimalAggregates` to peel a scale-preserving widening `Cast`
around `Sum`/`Average` arguments, recovering the long-backed fast path when the
inner expression's precision still fits the existing safety bounds.
When the aggregate argument is `Cast(inner: dec(p, s), dec(p', s))` with `p' >= p`, i.e. `Sum(Cast(...))` or `Average(Cast(...))`:
- SUM arm fires under `p + 10 <= 18`, identical to the existing SUM
fast-path guard.
- AVG arm fires under `p <= 7` (`AVG_PEEL_MAX_INNER_PRECISION`), strictly
tighter than the existing AVG arm's `p + 4 <= 15` (= `p <= 11`), to avoid
amplifying SPARK-37024 Double-regime precision loss.
Both arms share a `WidenedDecimalChild` extractor that refuses to unwrap
`CheckOverflow` (preserves row-level overflow semantics). Window arm is
unchanged: `ExtractWindowExpressions` hoists the `Cast` into a preceding
`Project`, so an expression-level rewrite cannot see it.
### Why are the changes needed?
The existing fast path keys off the declared precision `p'` after a
widening Cast, not the effective precision `p` of the inner expression. User
patterns like `SUM(CAST(small_dec AS larger_dec))` — common from BI tools
generating SQL with normalized types — fall off the fast path even though `p +
10 <= 18`. TPC-DS q18 exhibits this pattern.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- `DecimalAggregatesSuite`, including invariant-guard tests that lock the
SUM/AVG safety boundaries.
- ScalaCheck property-based tests in `DataFrameAggregateSuite` for
numerical equivalence of the peeled and un-peeled paths.
- `TPCDSV1_4PlanStabilitySuite` and `TPCDSV1_4PlanStabilityWithStatsSuite`
regenerated for q18.
- `DecimalAggregatesBenchmark` added; results committed for JDK 17/21/25
under `sql/core/benchmarks/`.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Opus 4.7
Closes #55957 from yaooqinn/decimal-cast-peel.
Authored-by: Kent Yao <[email protected]>
Signed-off-by: Kent Yao <[email protected]>
(cherry picked from commit 8d45d475ca0a40048fa2323b6e8c64fd2baf3017)
Signed-off-by: Kent Yao <[email protected]>
---
.../spark/sql/catalyst/optimizer/Optimizer.scala | 36 ++
.../optimizer/DecimalAggregatesSuite.scala | 466 ++++++++++++++++++++-
.../DecimalAggregatesBenchmark-jdk21-results.txt | 74 ++++
.../DecimalAggregatesBenchmark-jdk25-results.txt | 74 ++++
.../DecimalAggregatesBenchmark-results.txt | 74 ++++
.../approved-plans-v1_4/q18.sf100/explain.txt | 8 +-
.../approved-plans-v1_4/q18.sf100/simplified.txt | 2 +-
.../approved-plans-v1_4/q18/explain.txt | 8 +-
.../approved-plans-v1_4/q18/simplified.txt | 2 +-
.../apache/spark/sql/DataFrameAggregateSuite.scala | 192 ++++++++-
.../benchmark/DecimalAggregatesBenchmark.scala | 208 +++++++++
11 files changed, 1131 insertions(+), 13 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index ddfe80443d56..22cd7fda2414 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -2564,11 +2564,29 @@ object DecimalAggregates extends Rule[LogicalPlan] {
/** Maximum number of decimal digits representable precisely in a Double */
private val MAX_DOUBLE_DIGITS = 15
+ /** Tighter than the AVG fast path's `prec + 4 <= MAX_DOUBLE_DIGITS` (= 11):
+ * the strict-subset keeps SPARK-37024 Double-regime exposure unchanged. */
+ private val AVG_PEEL_MAX_INNER_PRECISION = 7
+
+ /** Matches a scale-preserving widening decimal Cast; refuses CheckOverflow
+ * to preserve overflow semantics on the unscaled value. */
+ private object WidenedDecimalChild {
+ def unapply(e: Expression): Option[(Expression, Int, Int, Int)] = e match {
+ case Cast(inner @ DecimalExpression(p, s), DecimalType.Fixed(pPrime,
sPrime), _, _)
+ if s == sPrime && pPrime >= p && !inner.isInstanceOf[CheckOverflow]
=>
+ Some((inner, p, pPrime, s))
+ case _ => None
+ }
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsAnyPattern(SUM, AVERAGE), ruleId) {
case q: LogicalPlan => q.transformExpressionsDownWithPruning(
_.containsAnyPattern(SUM, AVERAGE), ruleId) {
case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _, _), _)
=> af match {
+ // Window arm: `ExtractWindowExpressions` hoists composite children
+ // (here the widening Cast) into a child Project, so widened-Cast
+ // peel is unreachable from this expression-level rule.
case Sum(e @ DecimalExpression(prec, scale), _) if prec + 10 <=
MAX_LONG_DIGITS =>
MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction =
Sum(UnscaledValue(e)))),
prec + 10, scale)
@@ -2583,9 +2601,27 @@ object DecimalAggregates extends Rule[LogicalPlan] {
case _ => we
}
case ae @ AggregateExpression(af, _, _, _, _) => af match {
+ case Sum(WidenedDecimalChild(inner, p, pPrime, s), _)
+ if p + 10 <= MAX_LONG_DIGITS =>
+ Cast(
+ MakeDecimal(
+ ae.copy(aggregateFunction = Sum(UnscaledValue(inner))),
+ p + 10, s),
+ DecimalType.bounded(pPrime + 10, s),
+ Option(conf.sessionLocalTimeZone))
+
case Sum(e @ DecimalExpression(prec, scale), _) if prec + 10 <=
MAX_LONG_DIGITS =>
MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec
+ 10, scale)
+ // Ordered before the un-widened Average arm: when pPrime in [8, 11],
+ // the outer Cast's DecimalType would otherwise match that arm first.
+ case Average(WidenedDecimalChild(inner, p, pPrime, s), _)
+ if p <= AVG_PEEL_MAX_INNER_PRECISION =>
+ val newAggExpr = ae.copy(aggregateFunction =
Average(UnscaledValue(inner)))
+ Cast(
+ Divide(newAggExpr, Literal.create(math.pow(10.0, s), DoubleType)),
+ DecimalType.bounded(pPrime + 4, s + 4),
Option(conf.sessionLocalTimeZone))
+
case Average(e @ DecimalExpression(prec, scale), _) if prec + 4 <=
MAX_DOUBLE_DIGITS =>
val newAggExpr = ae.copy(aggregateFunction =
Average(UnscaledValue(e)))
Cast(
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala
index 25adbce143fb..bc5f8b984f5c 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala
@@ -17,15 +17,19 @@
package org.apache.spark.sql.catalyst.optimizer
+import org.scalacheck.Gen
+import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks
+
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.types.DecimalType
+import org.apache.spark.sql.types.{Decimal, DecimalType, DoubleType, LongType}
-class DecimalAggregatesSuite extends PlanTest {
+class DecimalAggregatesSuite extends PlanTest with
ScalaCheckDrivenPropertyChecks {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("Decimal Optimizations", FixedPoint(100),
@@ -68,6 +72,115 @@ class DecimalAggregatesSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+ val testRelationC = LocalRelation($"c".decimal(7, 2))
+
+ test("Decimal Average Aggregation widened-cast peel: Optimized (p=7,
p'=12)") {
+ val widened = $"c".cast(DecimalType(12, 2))
+ val originalQuery = testRelationC.select(avg(widened))
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer = testRelationC
+ .select((avg(UnscaledValue($"c")) / 100.0).cast(DecimalType(16, 6))
+ .as("avg(CAST(c AS DECIMAL(12,2)))")).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("Decimal Average Aggregation widened-cast peel: Not Optimized
(narrowing cast)") {
+ val testRelationD = LocalRelation($"d".decimal(10, 2))
+ val narrowed = $"d".cast(DecimalType(8, 2))
+ val originalQuery = testRelationD.select(avg(narrowed))
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer = testRelationD
+ .select((avg(UnscaledValue(narrowed)) / 100.0).cast(DecimalType(12, 6))
+ .as("avg(CAST(d AS DECIMAL(8,2)))")).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("Decimal Average Aggregation widened-cast peel: Not Optimized (scale
change)") {
+ val testRelationD = LocalRelation($"d".decimal(7, 2))
+ val rescaled = $"d".cast(DecimalType(12, 4))
+ val originalQuery = testRelationD.select(avg(rescaled))
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer = originalQuery.analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("Decimal Average Aggregation widened-cast peel: Not Optimized (boundary
p=8)") {
+ val testRelationE = LocalRelation($"e".decimal(8, 2))
+ val widened = $"e".cast(DecimalType(13, 2))
+ val originalQuery = testRelationE.select(avg(widened))
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer = originalQuery.analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ // SPARK-56627 F2 regression: with `pPrime in [8, 11]`, the outer Cast's
+ // dataType `Decimal(pPrime, s)` would let the un-widened existing
+ // `Average(DecimalExpression)` arm match first via `prec + 4 <=
MAX_DOUBLE_DIGITS`
+ // (= pPrime <= 11). The new AVG peel arm must be ordered before it to win
+ // this band and rewrite via the inner `p`-based UnscaledValue path.
+ test("Decimal Average Aggregation widened-cast peel: " +
+ "Optimized for pPrime band [8, 11] (must beat existing AVG fast-path
arm)") {
+ val testRelationE = LocalRelation($"e".decimal(7, 2))
+ val widened = $"e".cast(DecimalType(10, 2))
+ val originalQuery = testRelationE.select(avg(widened).as("avg_widened"))
+ val optimized = Optimize.execute(originalQuery.analyze)
+ // Expected: peeled via WidenedDecimalChild(inner=e, p=7, pPrime=10, s=2),
+ // outer Cast bounded(pPrime+4=14, s+4=6). NOT the existing AVG arm form
+ // `Cast(avg(UnscaledValue(cast(e as dec(10,2)))) / 100.0, dec(14,6))`,
+ // which would lose F2's intent of avoiding the widened-cast
+ // intermediate.
+ val correctAnswer = testRelationE
+ .select(
+ Cast(
+ Divide(
+ avg(UnscaledValue($"e")),
+ Literal.create(math.pow(10.0, 2), DoubleType)),
+ DecimalType.bounded(14, 6),
+ Option(conf.sessionLocalTimeZone))
+ .as("avg_widened"))
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ // SPARK-56627 F1 regression: `WidenedDecimalChild` must NOT peel when the
+ // inner expression is a `CheckOverflow` (introduced by `DecimalPrecision`
+ // analyzer for nullOnOverflow semantics). Peeling through `CheckOverflow`
+ // would change the overflow behavior of the inner aggregate.
+ //
+ // The existing un-widened `Average(DecimalExpression)` arm still fires on
+ // the outer Cast (dataType `Decimal(pPrime=10, s=2)`, `pPrime + 4 = 14 <=
15`),
+ // so the optimized plan wraps `UnscaledValue` around the OUTER cast (not
+ // the inner CheckOverflow). The peel-arm-fired form would instead be
+ // `UnscaledValue(CheckOverflow(e))` (no outer cast), which we want to AVOID.
+ test("Decimal Average Aggregation widened-cast peel: " +
+ "Not peeled for Cast(CheckOverflow(inner), wider) form (F1 guard)") {
+ val testRelationE = LocalRelation($"e".decimal(7, 2))
+ val co = CheckOverflow($"e", DecimalType(7, 2), nullOnOverflow = true)
+ val widened = Cast(co, DecimalType(10, 2))
+ val originalQuery = testRelationE.select(avg(widened).as("avg_co"))
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ // Existing un-widened AVG arm fires on the outer Cast (pPrime=10,
+ // pPrime + 4 = 14 <= 15), wrapping UnscaledValue around the OUTER cast.
+ val correctAnswer = testRelationE
+ .select(
+ Cast(
+ Divide(
+ avg(UnscaledValue(widened)),
+ Literal.create(math.pow(10.0, 2), DoubleType)),
+ DecimalType(14, 6),
+ Option(conf.sessionLocalTimeZone))
+ .as("avg_co"))
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
test("Decimal Sum Aggregation over Window: Optimized") {
val spec = windowSpec(Seq($"a"), Nil, UnspecifiedFrame)
val originalQuery = testRelation.select(windowExpr(sum($"a"),
spec).as("sum_a"))
@@ -120,4 +233,353 @@ class DecimalAggregatesSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+
+ //
---------------------------------------------------------------------------
+ // Widened-Cast Peel (SUM, plus AVG schema/boundary guards) -- SPARK-56627
+ //
+ // Extractor `WidenedDecimalChild` recognises a scale-preserving widening
+ // Cast, enabling the existing fast path to fire on `SUM(CAST(x, wider))`
+ // patterns that previously fell off the p+10 <= MAX_LONG_DIGITS guard.
+ //
+ // These tests assert behavioural plan-shape invariants via the local
+ // `Optimize` RuleExecutor (runs only DecimalAggregates). Literal-no-peel
+ // is covered separately via SimpleTestOptimizer because the local
+ // RuleExecutor here does not run ConstantFolding.
+ //
---------------------------------------------------------------------------
+
+ private val widenRel = LocalRelation(
+ $"d7_2".decimal(7, 2),
+ $"d8_2".decimal(8, 2),
+ $"d9_2".decimal(9, 2),
+ $"d17_2".decimal(17, 2),
+ $"i".int)
+
+ test("SPARK-56627: SUM(CAST(dec(7,2) AS dec(17,2))) peels via widened-Cast
fast path") {
+ // Witness chosen so p+10=17 <= MAX_LONG_DIGITS(18) < pPrime+10=27 -- the
+ // new case fires (a bare-Cast inner cannot fall through to the existing
+ // DecimalExpression case). Expected shape:
+ // Cast(MakeDecimal(Sum(UnscaledValue(d7_2)), p+10=17, s=2),
+ // DecimalType.bounded(pPrime+10=27, s=2),
+ // Option(conf.sessionLocalTimeZone))
+ val q = widenRel.select(sum($"d7_2".cast(DecimalType(17, 2))))
+ val optimized = Optimize.execute(q.analyze)
+ val correctAnswer = widenRel
+ .select(Cast(
+ MakeDecimal(sum(UnscaledValue($"d7_2")), 17, 2),
+ DecimalType.bounded(27, 2),
+ Option(conf.sessionLocalTimeZone))
+ .as("sum(CAST(d7_2 AS DECIMAL(17,2)))")).analyze
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("SPARK-56627: SUM(CAST(dec(7,2) AS dec(17,2))) -- peel preserves
schema") {
+ // Schema invariance via DataType equality (not string).
+ // Top-level output type of SUM(dec(p,s)) is DecimalType(min(p+10,38), s);
+ // peeled tree wraps inner with outer Cast(_, dec(pPrime+10,s)) = dec(27,2)
+ // -- identical to baseline schema.
+ val q = widenRel.select(sum($"d7_2".cast(DecimalType(17, 2))))
+ val baselineSchema = q.analyze.schema
+ val optimized = Optimize.execute(q.analyze)
+ assert(optimized.schema === baselineSchema,
+ s"peel changed schema: $baselineSchema -> ${optimized.schema}")
+ }
+
+ test("SPARK-56627: SUM(CAST(int AS dec(10,0))) does NOT peel (non-decimal
inner)") {
+ val q = widenRel.select(sum($"i".cast(DecimalType(10, 0))))
+ val optimized = Optimize.execute(q.analyze)
+ // Peel must NOT fire; plan shape == input analyze.
+ val correctAnswer = q.analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("SPARK-56627: AVG(CAST(dec(7,2) AS dec(17,2))) -- peel preserves
schema") {
+ val q = widenRel.select(avg($"d7_2".cast(DecimalType(17, 2))))
+ val baselineSchema = q.analyze.schema
+ val optimized = Optimize.execute(q.analyze)
+ assert(optimized.schema === baselineSchema,
+ s"peel changed schema: $baselineSchema -> ${optimized.schema}")
+ }
+
+ test("SPARK-56627: SUM(CAST(dec(7,2) AS dec(18,6))) does NOT peel (scale
change)") {
+ val q = widenRel.select(sum($"d7_2".cast(DecimalType(18, 6))))
+ val optimized = Optimize.execute(q.analyze)
+ val correctAnswer = q.analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("SPARK-56627: SUM(CAST(dec(17,2) AS dec(10,2))) does NOT peel
(narrowing)") {
+ val q = widenRel.select(sum($"d17_2".cast(DecimalType(10, 2))))
+ val optimized = Optimize.execute(q.analyze)
+ val correctAnswer = q.analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("SPARK-56627: SUM(CheckOverflow(Cast(...))) does NOT peel") {
+ val co = CheckOverflow(
+ $"d7_2".cast(DecimalType(17, 2)), DecimalType(17, 2), nullOnOverflow =
true)
+ val q = widenRel.select(sum(co).as("s"))
+ val optimized = Optimize.execute(q.analyze)
+ val correctAnswer = q.analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ // Pre-existing fast-path regression guard.
+ // Witness: SUM(d7_2), no Cast. p+10 = 17 <= MAX_LONG_DIGITS(18) hits the
+ // existing `Sum(e @ DecimalExpression(p, s))` case. The new peel case
+ // prepended must NOT shadow the existing fast path on no-cast inputs.
+ test("SPARK-56627: SUM(dec(7,2)) hits existing DecimalExpression fast path")
{
+ val expected = widenRel
+ .select(MakeDecimal(sum(UnscaledValue($"d7_2")), 17,
2).as("sum(d7_2)")).analyze
+ val q = widenRel.select(sum($"d7_2"))
+ val optimized = Optimize.execute(q.analyze)
+ comparePlans(optimized, expected)
+ }
+
+ // Literal-in-Cast no-peel regression guard.
+ //
+ // Uses `SimpleTestOptimizer` (full optimizer batches) rather than the local
+ // `Optimize` RuleExecutor defined above, because this test depends on
+ // `ConstantFolding` running before `DecimalAggregates`: the outer Cast on a
+ // foldable Literal child is folded away before the peel rule ever sees it,
+ // so there is no Cast left to peel. Post-optimization the plan contains
+ // neither `MakeDecimal` nor an `UnscaledValue` call -- SUM sees a bare
+ // `Literal(dec(17,2))` whose precision (17) already fails the existing
+ // `prec + 10 <= MAX_LONG_DIGITS` guard (27 > 18), so the whole SUM arm is
+ // a no-op. The assertion is deliberately absence-of-peel-shape rather than
+ // structural equality, to survive unrelated ConstantFolding changes.
+ test("SPARK-56627: SUM(CAST(Literal(dec(7,2)) AS dec(17,2))) does NOT peel "
+
+ "after ConstantFolding") {
+ val lit = Literal.create(Decimal("1.23"), DecimalType(7, 2))
+ val q = widenRel.select(sum(lit.cast(DecimalType(17, 2))))
+ val optimized = SimpleTestOptimizer.execute(q.analyze)
+ val hasMakeDecimal = optimized.expressions.exists(_.exists {
+ case _: MakeDecimal => true
+ case _ => false
+ })
+ val hasUnscaledValue = optimized.expressions.exists(_.exists {
+ case _: UnscaledValue => true
+ case _ => false
+ })
+ assert(!hasMakeDecimal,
+ s"peel unexpectedly fired on a folded Literal child; plan:\n$optimized")
+ assert(!hasUnscaledValue,
+ s"UnscaledValue unexpectedly present on folded Literal child;
plan:\n$optimized")
+ }
+
+ // Plan-shape invariant guards (null / empty-relation witnesses).
+ //
+ // DecimalAggregatesSuite is a PlanTest without a SparkSession; the local
+ // `Optimize` RuleExecutor runs DecimalAggregates only. At plan level, an
+ // all-null Literal-typed column shares the extractor path of any other
+ // DecimalExpression, and an empty LocalRelation shares the shape of the
+ // non-empty widenRel. These two witnesses assert the peel rule fires
+ // identically to the canonical witness under both inputs -- rule body is
+ // data-independent. End-to-end null-propagation semantics are covered
+ // separately in the sql-core equivalence suite.
+
+ test("SPARK-56627: SUM(CAST(Literal(null, dec(7,2)) AS dec(17,2))) peels " +
+ "(null Literal in Cast, plan-shape invariant)") {
+ val nullLit = Literal.create(null, DecimalType(7, 2))
+ val q = widenRel.select(sum(nullLit.cast(DecimalType(17, 2))))
+ val optimized = Optimize.execute(q.analyze)
+ val correctAnswer = widenRel
+ .select(Cast(
+ MakeDecimal(sum(UnscaledValue(nullLit)), 17, 2),
+ DecimalType.bounded(27, 2),
+ Option(conf.sessionLocalTimeZone))
+ .as("sum(CAST(NULL AS DECIMAL(17,2)))")).analyze
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("SPARK-56627: SUM(CAST(dec(7,2) AS dec(17,2))) on empty LocalRelation
peels " +
+ "(empty-relation plan-shape invariant)") {
+ val emptyRel = LocalRelation($"d7_2".decimal(7, 2))
+ val q = emptyRel.select(sum($"d7_2".cast(DecimalType(17, 2))))
+ val optimized = Optimize.execute(q.analyze)
+ val correctAnswer = emptyRel
+ .select(Cast(
+ MakeDecimal(sum(UnscaledValue($"d7_2")), 17, 2),
+ DecimalType.bounded(27, 2),
+ Option(conf.sessionLocalTimeZone))
+ .as("sum(CAST(d7_2 AS DECIMAL(17,2)))")).analyze
+ comparePlans(optimized, correctAnswer)
+ }
+
+ // Idempotence invariant guard.
+ //
+ // Post-peel, the `Sum` child is `UnscaledValue(DecimalExpression)`, which
+ // types to `LongType` and is no longer a `Cast`, so the
+ // `WidenedDecimalChild` extractor (which only matches a scale-preserving
+ // decimal `Cast` with `p' >= p`) cannot re-match on the second pass. Use
+ // `canonicalized` (not `==`) to neutralise `exprId` drift across `Sum`
+ // aggregate-expression allocation in successive rule applications.
+ test("SPARK-56627: DecimalAggregates is idempotent on canonical widened
witness " +
+ "(peel(peel(t)) == peel(t) under canonicalization)") {
+ val q = widenRel.select(sum($"d7_2".cast(DecimalType(17, 2)))).analyze
+ val once = DecimalAggregates(q)
+ val twice = DecimalAggregates(DecimalAggregates(q))
+ assert(once.canonicalized == twice.canonicalized,
+ s"DecimalAggregates re-fired on already-peeled plan.\n" +
+ s"once:\n$once\ntwice:\n$twice")
+ }
+
+ // RuleExecutor convergence: drive DecimalAggregates inside a fixed-point
+ // RuleExecutor batch and assert it converges in <= 1 application after the
+ // first peel. Catches accidental rewrite oscillations in fixed-point
batches.
+ test("SPARK-56627: DecimalAggregates converges under RuleExecutor on widened
SUM") {
+ object Once extends RuleExecutor[LogicalPlan] {
+ val batches: Seq[Batch] =
+ Seq(Batch("DecimalAggregates", FixedPoint(10), DecimalAggregates))
+ }
+ val q = widenRel.select(sum($"d7_2".cast(DecimalType(17, 2)))).analyze
+ val once = DecimalAggregates(q)
+ val converged = Once.execute(q)
+ assert(once.canonicalized == converged.canonicalized,
+ s"FixedPoint did not converge to single peel.\n" +
+ s"once:\n$once\nconverged:\n$converged")
+ }
+
+ // Negative guard-miss: at p=9, the inner decimal already exceeds the
+ // existing DecimalExpression fast path (p+10=19 > MAX_LONG_DIGITS=18) so
+ // the peel rewrite must NOT fire. Pin via plan-equality against analyzed.
+ test("SPARK-56627: SUM(CAST(dec(9,2) AS dec(19,2))) does NOT peel (p=9
guard)") {
+ val rel = LocalRelation($"d9_2".decimal(9, 2))
+ val q = rel.select(sum($"d9_2".cast(DecimalType(19, 2)))).analyze
+ val optimized = Optimize.execute(q)
+ comparePlans(optimized, q)
+ }
+
+ // Plan-shape property: structural invariants on the peeled tree.
+ //
+ // Sweeps the (p, p', s) lattice where the widened-cast peel fires and the
+ // pre-existing fast path cannot (generator bounds, after `retryUntil`):
+ // 1 <= p <= 8, 9 <= p' <= 28, p < p', 0 <= s <= p
+ // i.e. p + 10 <= 18 < p' + 10 <= 38.
+ // Assertions (structural -- NOT a hand-typed RHS clone):
+ // - the plan contains exactly one Sum, whose child has dataType=LongType
+ // (i.e. UnscaledValue was inserted)
+ // - the plan contains exactly one MakeDecimal
+ // - at least one Cast to DecimalType is present, and the widest such
+ // Cast has precision >= p' (the test does not pin it to p' + 10)
+ // Reframed away from RHS-equality to detect behavioural regressions
+ // rather than just refactor drift.
+ // Only the peel-on regime is generated; no-peel shapes are pinned by the
+ // dedicated "does NOT peel" tests in this suite (the local RuleExecutor
+ // runs only DecimalAggregates).
+
+ private case class PeelInputs(p: Int, pPrime: Int, s: Int)
+
+ private val peelGen: Gen[PeelInputs] = Gen.frequency(
+ 5 -> (for {
+ p <- Gen.choose(1, 8)
+ pPrime <- Gen.choose(math.max(p + 1, 9), 28)
+ s <- Gen.choose(0, p)
+ } yield PeelInputs(p, pPrime, s)),
+ 5 -> (for {
+ p <- Gen.choose(1, 8)
+ pPrime <- Gen.choose(9, 28)
+ s <- Gen.choose(0, p)
+ } yield PeelInputs(p, pPrime, s))
+ )
+
+ private val boundaryGen: Gen[PeelInputs] = Gen.oneOf(
+ PeelInputs(7, 17, 2), PeelInputs(7, 18, 2), PeelInputs(7, 19, 2))
+
+ private val peelSpaceGen: Gen[PeelInputs] = Gen.frequency(
+ 8 -> peelGen,
+ 2 -> boundaryGen
+ ).retryUntil(in => in.p + 10 <= 18 && in.p < in.pPrime && in.pPrime + 10 <=
38)
+
+ implicit override val generatorDrivenConfig: PropertyCheckConfiguration =
+ PropertyCheckConfiguration(minSuccessful = 50, minSize = 0, sizeRange = 0)
+
+ test("SPARK-56627: DecimalAggregates widened-Cast SUM peel -- plan-shape " +
+ "structural-invariants property") {
+ forAll(peelSpaceGen) { in =>
+ val rel = LocalRelation($"x".decimal(in.p, in.s))
+ val q = rel.select(sum($"x".cast(DecimalType(in.pPrime, in.s))))
+ val analyzed = q.analyze
+
+ val optimized = Optimize.execute(analyzed)
+
+ // Structural invariants the peel rewrite must establish, regardless
+ // of incidental tree-shape changes from neighbouring rules:
+ //
+ // I1. exactly one Sum node, whose child has LongType (the peeled
+ // UnscaledValue feed);
+ // I2. exactly one MakeDecimal node in the tree (rebuilds Decimal
+ // from the LONG accumulator);
+ // I3. an outer Cast whose target DecimalType has precision at
+ // least as wide as the user-written widened cast, so we never
+ // narrow result precision below the baseline plan.
+ val sums = optimized.expressions.flatMap(_.collect { case s: Sum => s })
+ assert(sums.size == 1, s"expected exactly 1 Sum, got ${sums.size} in
$optimized")
+ assert(sums.head.child.dataType == LongType,
+ s"expected Sum.child: LongType, got ${sums.head.child.dataType} in
$optimized")
+
+ val mds = optimized.expressions.flatMap(_.collect { case m: MakeDecimal
=> m })
+ assert(mds.size == 1,
+ s"expected exactly 1 MakeDecimal, got ${mds.size} in $optimized")
+
+ val outerCasts = optimized.expressions.flatMap(_.collect {
+ case c @ Cast(_, _: DecimalType, _, _) => c
+ })
+ assert(outerCasts.nonEmpty,
+ s"expected an outer Cast to DecimalType, got none in $optimized")
+ val outerPrec =
outerCasts.map(_.dataType.asInstanceOf[DecimalType].precision).max
+ assert(outerPrec >= in.pPrime,
+ s"outer Cast precision $outerPrec < baseline ${in.pPrime} in
$optimized")
+ }
+ }
+
+ //
---------------------------------------------------------------------------
+ // F5 (skeptic round 1): Long-accumulator / Double-regime safety boundary
+ // invariant guards.
+ //
+ // Background: a strict "overflow oracle" cannot be written at unit-test
+ // scale -- the existing fast-path guards (`p + 10 <= MAX_LONG_DIGITS = 18`
+ // for SUM, `AVG_PEEL_MAX_INNER_PRECISION = 7` for AVG) keep the
peel-eligible
+ // inner-precision band so narrow that the Long accumulator (~9.22e18) cannot
+ // wrap on any reachable peel input: at `p=8` we'd need ~9.22e10 rows. So
+ // there is no production input that exercises a "peeled Long-wrap vs
+ // un-peeled CheckOverflow" asymmetry to oracle against.
+ //
+ // What we CAN lock is the boundary itself: if someone in the future relaxes
+ // either guard (raising `MAX_LONG_DIGITS - 10` for SUM, or
+ // `AVG_PEEL_MAX_INNER_PRECISION` for AVG), the input shapes below WOULD
+ // start peeling -- and the assertion that the rule is a no-op for these
+ // inputs would fail. That is the safety net we want: a mechanical guard
+ // that catches accidental widening of the peel-trigger surface.
+ test("SPARK-56627: SUM(CAST(dec(9,2) AS dec(19,2))) does NOT peel " +
+ "(Long-accumulator safety boundary)") {
+ // Boundary witness: inner p=9 makes widened-arm `p + 10 = 19 > 18` reject,
+ // AND outer-cast existing-arm `prec + 10 = 29 > 18` reject. Both arms are
+ // no-ops by design -- peel cannot fire on this shape today, and must not
+ // start firing if the inner-precision band is later widened without
+ // re-deriving the Long-accumulator bound.
+ val q = widenRel.select(sum($"d9_2".cast(DecimalType(19, 2))))
+ val optimized = Optimize.execute(q.analyze)
+ val correctAnswer = q.analyze
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("SPARK-56627: AVG(CAST(dec(8,2) AS dec(20,2))) does NOT peel " +
+ "(Double-regime / SPARK-37024 safety boundary)") {
+ // Boundary witness: inner p=8 makes widened-AVG arm
+ // `p > AVG_PEEL_MAX_INNER_PRECISION (7)` reject, AND outer-cast existing
+ // AVG arm `prec + 4 = 24 > MAX_DOUBLE_DIGITS (15)` reject. The strict-
+ // subset guard `p <= 7` keeps this rule's trigger surface strictly
+ // inside the existing AVG fast path's surface, so SPARK-37024
+ // (Double-regime silent precision loss) is not amplified. If someone
+ // raises `AVG_PEEL_MAX_INNER_PRECISION` past 7 without first fixing
+ // SPARK-37024, this test will start failing and flag the regression.
+ val q = widenRel.select(avg($"d8_2".cast(DecimalType(20, 2))))
+ val optimized = Optimize.execute(q.analyze)
+ val correctAnswer = q.analyze
+ comparePlans(optimized, correctAnswer)
+ }
}
diff --git a/sql/core/benchmarks/DecimalAggregatesBenchmark-jdk21-results.txt
b/sql/core/benchmarks/DecimalAggregatesBenchmark-jdk21-results.txt
new file mode 100644
index 000000000000..1186901b3575
--- /dev/null
+++ b/sql/core/benchmarks/DecimalAggregatesBenchmark-jdk21-results.txt
@@ -0,0 +1,74 @@
+================================================================================================
+DecimalAggregates SUM widened-cast peel (Aggregate)
+================================================================================================
+
+OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1013-azure
+Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
+A1 p=7 s=2 p'=8: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 2178 2236
56 4.6 217.8 1.0X
+widened cast, peel off 2369 2381
9 4.2 236.9 0.9X
+widened cast, peel on 2105 2118
12 4.8 210.5 1.0X
+
+OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1013-azure
+Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
+A2 p=7 s=2 p'=17: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 2103 2115
17 4.8 210.3 1.0X
+widened cast, peel off 2366 2377
7 4.2 236.6 0.9X
+widened cast, peel on 2100 2109
11 4.8 210.0 1.0X
+
+OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1013-azure
+Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
+A3 p=5 s=0 p'=6: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 2117 2138
29 4.7 211.7 1.0X
+widened cast, peel off 2403 2416
13 4.2 240.3 0.9X
+widened cast, peel on 2157 2164
7 4.6 215.7 1.0X
+
+OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1013-azure
+Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
+A4 p=5 s=0 p'=15: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 2151 2157
7 4.6 215.1 1.0X
+widened cast, peel off 2420 2427
10 4.1 242.0 0.9X
+widened cast, peel on 2152 2159
9 4.6 215.2 1.0X
+
+
+================================================================================================
+DecimalAggregates AVG widened-cast peel (Aggregate)
+================================================================================================
+
+OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1013-azure
+Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
+B1 p=7 s=2 p'=8: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 2130 2136
5 4.7 213.0 1.0X
+widened cast, peel off 2358 2367
15 4.2 235.8 0.9X
+widened cast, peel on 2140 2150
7 4.7 214.0 1.0X
+
+OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1013-azure
+Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
+B2 p=7 s=2 p'=12: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 2147 2151
3 4.7 214.7 1.0X
+widened cast, peel off 2359 2361
2 4.2 235.9 0.9X
+widened cast, peel on 2126 2161
20 4.7 212.6 1.0X
+
+OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1013-azure
+Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
+B3 p=5 s=0 p'=6: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 2173 2185
9 4.6 217.3 1.0X
+widened cast, peel off 2405 2413
7 4.2 240.5 0.9X
+widened cast, peel on 2167 2177
12 4.6 216.7 1.0X
+
+OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1013-azure
+Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
+B4 p=5 s=0 p'=15: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 2173 2179
7 4.6 217.3 1.0X
+widened cast, peel off 2393 2400
11 4.2 239.3 0.9X
+widened cast, peel on 2172 2178
5 4.6 217.2 1.0X
+
+
diff --git a/sql/core/benchmarks/DecimalAggregatesBenchmark-jdk25-results.txt
b/sql/core/benchmarks/DecimalAggregatesBenchmark-jdk25-results.txt
new file mode 100644
index 000000000000..60109cac85ec
--- /dev/null
+++ b/sql/core/benchmarks/DecimalAggregatesBenchmark-jdk25-results.txt
@@ -0,0 +1,74 @@
+================================================================================================
+DecimalAggregates SUM widened-cast peel (Aggregate)
+================================================================================================
+
+OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1013-azure
+Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
+A1 p=7 s=2 p'=8: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 1194 1230
57 8.4 119.4 1.0X
+widened cast, peel off 1421 1433
11 7.0 142.1 0.8X
+widened cast, peel on 1181 1188
5 8.5 118.1 1.0X
+
+OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1013-azure
+Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
+A2 p=7 s=2 p'=17: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 1174 1189
12 8.5 117.4 1.0X
+widened cast, peel off 1401 1414
8 7.1 140.1 0.8X
+widened cast, peel on 1169 1178
8 8.6 116.9 1.0X
+
+OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1013-azure
+Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
+A3 p=5 s=0 p'=6: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 1245 1254
10 8.0 124.5 1.0X
+widened cast, peel off 1498 1503
5 6.7 149.8 0.8X
+widened cast, peel on 1222 1232
10 8.2 122.2 1.0X
+
+OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1013-azure
+Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
+A4 p=5 s=0 p'=15: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 1234 1238
3 8.1 123.4 1.0X
+widened cast, peel off 1473 1478
7 6.8 147.3 0.8X
+widened cast, peel on 1242 1255
16 8.1 124.2 1.0X
+
+
+================================================================================================
+DecimalAggregates AVG widened-cast peel (Aggregate)
+================================================================================================
+
+OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1013-azure
+Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
+B1 p=7 s=2 p'=8: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 1178 1185
9 8.5 117.8 1.0X
+widened cast, peel off 1434 1440
8 7.0 143.4 0.8X
+widened cast, peel on 1232 1235
3 8.1 123.2 1.0X
+
+OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1013-azure
+Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
+B2 p=7 s=2 p'=12: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 1222 1229
7 8.2 122.2 1.0X
+widened cast, peel off 1434 1444
10 7.0 143.4 0.9X
+widened cast, peel on 1216 1223
6 8.2 121.6 1.0X
+
+OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1013-azure
+Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
+B3 p=5 s=0 p'=6: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 1267 1274
6 7.9 126.7 1.0X
+widened cast, peel off 1505 1509
4 6.6 150.5 0.8X
+widened cast, peel on 1272 1277
7 7.9 127.2 1.0X
+
+OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1013-azure
+Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
+B4 p=5 s=0 p'=15: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 1269 1275
5 7.9 126.9 1.0X
+widened cast, peel off 1494 1501
9 6.7 149.4 0.8X
+widened cast, peel on 1268 1274
6 7.9 126.8 1.0X
+
+
diff --git a/sql/core/benchmarks/DecimalAggregatesBenchmark-results.txt
b/sql/core/benchmarks/DecimalAggregatesBenchmark-results.txt
new file mode 100644
index 000000000000..d9c2c9662826
--- /dev/null
+++ b/sql/core/benchmarks/DecimalAggregatesBenchmark-results.txt
@@ -0,0 +1,74 @@
+================================================================================================
+DecimalAggregates SUM widened-cast peel (Aggregate)
+================================================================================================
+
+OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1013-azure
+AMD EPYC 9V74 80-Core Processor
+A1 p=7 s=2 p'=8: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 3068 3095
35 3.3 306.8 1.0X
+widened cast, peel off 3396 3410
19 2.9 339.6 0.9X
+widened cast, peel on 3107 3115
10 3.2 310.7 1.0X
+
+OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1013-azure
+AMD EPYC 9V74 80-Core Processor
+A2 p=7 s=2 p'=17: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 3104 3120
23 3.2 310.4 1.0X
+widened cast, peel off 3386 3407
27 3.0 338.6 0.9X
+widened cast, peel on 3094 3106
17 3.2 309.4 1.0X
+
+OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1013-azure
+AMD EPYC 9V74 80-Core Processor
+A3 p=5 s=0 p'=6: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 3039 3053
21 3.3 303.9 1.0X
+widened cast, peel off 3336 3340
5 3.0 333.6 0.9X
+widened cast, peel on 3034 3048
14 3.3 303.4 1.0X
+
+OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1013-azure
+AMD EPYC 9V74 80-Core Processor
+A4 p=5 s=0 p'=15: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 3037 3049
16 3.3 303.7 1.0X
+widened cast, peel off 3324 3340
16 3.0 332.4 0.9X
+widened cast, peel on 3027 3031
4 3.3 302.7 1.0X
+
+
+================================================================================================
+DecimalAggregates AVG widened-cast peel (Aggregate)
+================================================================================================
+
+OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1013-azure
+AMD EPYC 9V74 80-Core Processor
+B1 p=7 s=2 p'=8: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 3038 3041
2 3.3 303.8 1.0X
+widened cast, peel off 3274 3283
18 3.1 327.4 0.9X
+widened cast, peel on 3056 3074
15 3.3 305.6 1.0X
+
+OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1013-azure
+AMD EPYC 9V74 80-Core Processor
+B2 p=7 s=2 p'=12: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 3029 3033
3 3.3 302.9 1.0X
+widened cast, peel off 3288 3291
2 3.0 328.8 0.9X
+widened cast, peel on 3031 3036
6 3.3 303.1 1.0X
+
+OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1013-azure
+AMD EPYC 9V74 80-Core Processor
+B3 p=5 s=0 p'=6: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 3022 3030
5 3.3 302.2 1.0X
+widened cast, peel off 3275 3307
28 3.1 327.5 0.9X
+widened cast, peel on 3025 3028
3 3.3 302.5 1.0X
+
+OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1013-azure
+AMD EPYC 9V74 80-Core Processor
+B4 p=5 s=0 p'=15: Best Time(ms) Avg Time(ms)
Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+native (no cast, rule on) 3024 3039
21 3.3 302.4 1.0X
+widened cast, peel off 3279 3298
17 3.1 327.9 0.9X
+widened cast, peel on 3016 3023
6 3.3 301.6 1.0X
+
+
diff --git
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/explain.txt
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/explain.txt
index f7c0dcd7c56b..ff0b0e468530 100644
---
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/explain.txt
+++
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/explain.txt
@@ -257,7 +257,7 @@ Arguments: [[cs_quantity#4, cs_list_price#5,
cs_sales_price#6, cs_coupon_amt#7,
(46) HashAggregate [codegen id : 13]
Input [12]: [cs_quantity#4, cs_list_price#5, cs_sales_price#6,
cs_coupon_amt#7, cs_net_profit#8, cd_dep_count#14, c_birth_year#22,
i_item_id#28, ca_country#29, ca_state#30, ca_county#31, spark_grouping_id#32]
Keys [5]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
spark_grouping_id#32]
-Functions [7]: [partial_avg(cast(cs_quantity#4 as decimal(12,2))),
partial_avg(cast(cs_list_price#5 as decimal(12,2))),
partial_avg(cast(cs_coupon_amt#7 as decimal(12,2))),
partial_avg(cast(cs_sales_price#6 as decimal(12,2))),
partial_avg(cast(cs_net_profit#8 as decimal(12,2))),
partial_avg(cast(c_birth_year#22 as decimal(12,2))),
partial_avg(cast(cd_dep_count#14 as decimal(12,2)))]
+Functions [7]: [partial_avg(cast(cs_quantity#4 as decimal(12,2))),
partial_avg(UnscaledValue(cs_list_price#5)),
partial_avg(UnscaledValue(cs_coupon_amt#7)),
partial_avg(UnscaledValue(cs_sales_price#6)),
partial_avg(UnscaledValue(cs_net_profit#8)), partial_avg(cast(c_birth_year#22
as decimal(12,2))), partial_avg(cast(cd_dep_count#14 as decimal(12,2)))]
Aggregate Attributes [14]: [sum#33, count#34, sum#35, count#36, sum#37,
count#38, sum#39, count#40, sum#41, count#42, sum#43, count#44, sum#45,
count#46]
Results [19]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
spark_grouping_id#32, sum#47, count#48, sum#49, count#50, sum#51, count#52,
sum#53, count#54, sum#55, count#56, sum#57, count#58, sum#59, count#60]
@@ -268,9 +268,9 @@ Arguments: hashpartitioning(i_item_id#28, ca_country#29,
ca_state#30, ca_county#
(48) HashAggregate [codegen id : 14]
Input [19]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
spark_grouping_id#32, sum#47, count#48, sum#49, count#50, sum#51, count#52,
sum#53, count#54, sum#55, count#56, sum#57, count#58, sum#59, count#60]
Keys [5]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
spark_grouping_id#32]
-Functions [7]: [avg(cast(cs_quantity#4 as decimal(12,2))),
avg(cast(cs_list_price#5 as decimal(12,2))), avg(cast(cs_coupon_amt#7 as
decimal(12,2))), avg(cast(cs_sales_price#6 as decimal(12,2))),
avg(cast(cs_net_profit#8 as decimal(12,2))), avg(cast(c_birth_year#22 as
decimal(12,2))), avg(cast(cd_dep_count#14 as decimal(12,2)))]
-Aggregate Attributes [7]: [avg(cast(cs_quantity#4 as decimal(12,2)))#61,
avg(cast(cs_list_price#5 as decimal(12,2)))#62, avg(cast(cs_coupon_amt#7 as
decimal(12,2)))#63, avg(cast(cs_sales_price#6 as decimal(12,2)))#64,
avg(cast(cs_net_profit#8 as decimal(12,2)))#65, avg(cast(c_birth_year#22 as
decimal(12,2)))#66, avg(cast(cd_dep_count#14 as decimal(12,2)))#67]
-Results [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
avg(cast(cs_quantity#4 as decimal(12,2)))#61 AS agg1#68,
avg(cast(cs_list_price#5 as decimal(12,2)))#62 AS agg2#69,
avg(cast(cs_coupon_amt#7 as decimal(12,2)))#63 AS agg3#70,
avg(cast(cs_sales_price#6 as decimal(12,2)))#64 AS agg4#71,
avg(cast(cs_net_profit#8 as decimal(12,2)))#65 AS agg5#72,
avg(cast(c_birth_year#22 as decimal(12,2)))#66 AS agg6#73,
avg(cast(cd_dep_count#14 as decimal(12,2)))#67 AS agg7#74]
+Functions [7]: [avg(cast(cs_quantity#4 as decimal(12,2))),
avg(UnscaledValue(cs_list_price#5)), avg(UnscaledValue(cs_coupon_amt#7)),
avg(UnscaledValue(cs_sales_price#6)), avg(UnscaledValue(cs_net_profit#8)),
avg(cast(c_birth_year#22 as decimal(12,2))), avg(cast(cd_dep_count#14 as
decimal(12,2)))]
+Aggregate Attributes [7]: [avg(cast(cs_quantity#4 as decimal(12,2)))#61,
avg(UnscaledValue(cs_list_price#5))#62, avg(UnscaledValue(cs_coupon_amt#7))#63,
avg(UnscaledValue(cs_sales_price#6))#64,
avg(UnscaledValue(cs_net_profit#8))#65, avg(cast(c_birth_year#22 as
decimal(12,2)))#66, avg(cast(cd_dep_count#14 as decimal(12,2)))#67]
+Results [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
avg(cast(cs_quantity#4 as decimal(12,2)))#61 AS agg1#68,
cast((avg(UnscaledValue(cs_list_price#5))#62 / 100.0) as decimal(16,6)) AS
agg2#69, cast((avg(UnscaledValue(cs_coupon_amt#7))#63 / 100.0) as
decimal(16,6)) AS agg3#70, cast((avg(UnscaledValue(cs_sales_price#6))#64 /
100.0) as decimal(16,6)) AS agg4#71,
cast((avg(UnscaledValue(cs_net_profit#8))#65 / 100.0) as decimal(16,6)) AS
agg5#72, avg(cast(c_birth_year#22 as [...]
(49) TakeOrderedAndProject
Input [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, agg1#68,
agg2#69, agg3#70, agg4#71, agg5#72, agg6#73, agg7#74]
diff --git
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/simplified.txt
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/simplified.txt
index 276165729be5..079bb6aba3ec 100644
---
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/simplified.txt
+++
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/simplified.txt
@@ -1,6 +1,6 @@
TakeOrderedAndProject
[ca_country,ca_state,ca_county,i_item_id,agg1,agg2,agg3,agg4,agg5,agg6,agg7]
WholeStageCodegen (14)
- HashAggregate
[i_item_id,ca_country,ca_state,ca_county,spark_grouping_id,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count]
[avg(cast(cs_quantity as decimal(12,2))),avg(cast(cs_list_price as
decimal(12,2))),avg(cast(cs_coupon_amt as
decimal(12,2))),avg(cast(cs_sales_price as
decimal(12,2))),avg(cast(cs_net_profit as decimal(12,2))),avg(cast(c_birth_year
as decimal(12,2))),avg(cast(cd_dep_count as
decimal(12,2))),agg1,agg2,agg3,agg4,agg5,agg6,agg7,sum,count,sum,cou [...]
+ HashAggregate
[i_item_id,ca_country,ca_state,ca_county,spark_grouping_id,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count]
[avg(cast(cs_quantity as
decimal(12,2))),avg(UnscaledValue(cs_list_price)),avg(UnscaledValue(cs_coupon_amt)),avg(UnscaledValue(cs_sales_price)),avg(UnscaledValue(cs_net_profit)),avg(cast(c_birth_year
as decimal(12,2))),avg(cast(cd_dep_count as
decimal(12,2))),agg1,agg2,agg3,agg4,agg5,agg6,agg7,sum,count,sum,count,sum,count,sum,count,sum,count
[...]
InputAdapter
Exchange [i_item_id,ca_country,ca_state,ca_county,spark_grouping_id] #1
WholeStageCodegen (13)
diff --git
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/explain.txt
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/explain.txt
index 7db1c87c52a6..8f25c83767ff 100644
---
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/explain.txt
+++
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/explain.txt
@@ -227,7 +227,7 @@ Arguments: [[cs_quantity#4, cs_list_price#5,
cs_sales_price#6, cs_coupon_amt#7,
(40) HashAggregate [codegen id : 7]
Input [12]: [cs_quantity#4, cs_list_price#5, cs_sales_price#6,
cs_coupon_amt#7, cs_net_profit#8, cd_dep_count#14, c_birth_year#19,
i_item_id#28, ca_country#29, ca_state#30, ca_county#31, spark_grouping_id#32]
Keys [5]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
spark_grouping_id#32]
-Functions [7]: [partial_avg(cast(cs_quantity#4 as decimal(12,2))),
partial_avg(cast(cs_list_price#5 as decimal(12,2))),
partial_avg(cast(cs_coupon_amt#7 as decimal(12,2))),
partial_avg(cast(cs_sales_price#6 as decimal(12,2))),
partial_avg(cast(cs_net_profit#8 as decimal(12,2))),
partial_avg(cast(c_birth_year#19 as decimal(12,2))),
partial_avg(cast(cd_dep_count#14 as decimal(12,2)))]
+Functions [7]: [partial_avg(cast(cs_quantity#4 as decimal(12,2))),
partial_avg(UnscaledValue(cs_list_price#5)),
partial_avg(UnscaledValue(cs_coupon_amt#7)),
partial_avg(UnscaledValue(cs_sales_price#6)),
partial_avg(UnscaledValue(cs_net_profit#8)), partial_avg(cast(c_birth_year#19
as decimal(12,2))), partial_avg(cast(cd_dep_count#14 as decimal(12,2)))]
Aggregate Attributes [14]: [sum#33, count#34, sum#35, count#36, sum#37,
count#38, sum#39, count#40, sum#41, count#42, sum#43, count#44, sum#45,
count#46]
Results [19]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
spark_grouping_id#32, sum#47, count#48, sum#49, count#50, sum#51, count#52,
sum#53, count#54, sum#55, count#56, sum#57, count#58, sum#59, count#60]
@@ -238,9 +238,9 @@ Arguments: hashpartitioning(i_item_id#28, ca_country#29,
ca_state#30, ca_county#
(42) HashAggregate [codegen id : 8]
Input [19]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
spark_grouping_id#32, sum#47, count#48, sum#49, count#50, sum#51, count#52,
sum#53, count#54, sum#55, count#56, sum#57, count#58, sum#59, count#60]
Keys [5]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
spark_grouping_id#32]
-Functions [7]: [avg(cast(cs_quantity#4 as decimal(12,2))),
avg(cast(cs_list_price#5 as decimal(12,2))), avg(cast(cs_coupon_amt#7 as
decimal(12,2))), avg(cast(cs_sales_price#6 as decimal(12,2))),
avg(cast(cs_net_profit#8 as decimal(12,2))), avg(cast(c_birth_year#19 as
decimal(12,2))), avg(cast(cd_dep_count#14 as decimal(12,2)))]
-Aggregate Attributes [7]: [avg(cast(cs_quantity#4 as decimal(12,2)))#61,
avg(cast(cs_list_price#5 as decimal(12,2)))#62, avg(cast(cs_coupon_amt#7 as
decimal(12,2)))#63, avg(cast(cs_sales_price#6 as decimal(12,2)))#64,
avg(cast(cs_net_profit#8 as decimal(12,2)))#65, avg(cast(c_birth_year#19 as
decimal(12,2)))#66, avg(cast(cd_dep_count#14 as decimal(12,2)))#67]
-Results [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
avg(cast(cs_quantity#4 as decimal(12,2)))#61 AS agg1#68,
avg(cast(cs_list_price#5 as decimal(12,2)))#62 AS agg2#69,
avg(cast(cs_coupon_amt#7 as decimal(12,2)))#63 AS agg3#70,
avg(cast(cs_sales_price#6 as decimal(12,2)))#64 AS agg4#71,
avg(cast(cs_net_profit#8 as decimal(12,2)))#65 AS agg5#72,
avg(cast(c_birth_year#19 as decimal(12,2)))#66 AS agg6#73,
avg(cast(cd_dep_count#14 as decimal(12,2)))#67 AS agg7#74]
+Functions [7]: [avg(cast(cs_quantity#4 as decimal(12,2))),
avg(UnscaledValue(cs_list_price#5)), avg(UnscaledValue(cs_coupon_amt#7)),
avg(UnscaledValue(cs_sales_price#6)), avg(UnscaledValue(cs_net_profit#8)),
avg(cast(c_birth_year#19 as decimal(12,2))), avg(cast(cd_dep_count#14 as
decimal(12,2)))]
+Aggregate Attributes [7]: [avg(cast(cs_quantity#4 as decimal(12,2)))#61,
avg(UnscaledValue(cs_list_price#5))#62, avg(UnscaledValue(cs_coupon_amt#7))#63,
avg(UnscaledValue(cs_sales_price#6))#64,
avg(UnscaledValue(cs_net_profit#8))#65, avg(cast(c_birth_year#19 as
decimal(12,2)))#66, avg(cast(cd_dep_count#14 as decimal(12,2)))#67]
+Results [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
avg(cast(cs_quantity#4 as decimal(12,2)))#61 AS agg1#68,
cast((avg(UnscaledValue(cs_list_price#5))#62 / 100.0) as decimal(16,6)) AS
agg2#69, cast((avg(UnscaledValue(cs_coupon_amt#7))#63 / 100.0) as
decimal(16,6)) AS agg3#70, cast((avg(UnscaledValue(cs_sales_price#6))#64 /
100.0) as decimal(16,6)) AS agg4#71,
cast((avg(UnscaledValue(cs_net_profit#8))#65 / 100.0) as decimal(16,6)) AS
agg5#72, avg(cast(c_birth_year#19 as [...]
(43) TakeOrderedAndProject
Input [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, agg1#68,
agg2#69, agg3#70, agg4#71, agg5#72, agg6#73, agg7#74]
diff --git
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/simplified.txt
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/simplified.txt
index 269bfd3f44fc..7c3075e26fa2 100644
---
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/simplified.txt
+++
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/simplified.txt
@@ -1,6 +1,6 @@
TakeOrderedAndProject
[ca_country,ca_state,ca_county,i_item_id,agg1,agg2,agg3,agg4,agg5,agg6,agg7]
WholeStageCodegen (8)
- HashAggregate
[i_item_id,ca_country,ca_state,ca_county,spark_grouping_id,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count]
[avg(cast(cs_quantity as decimal(12,2))),avg(cast(cs_list_price as
decimal(12,2))),avg(cast(cs_coupon_amt as
decimal(12,2))),avg(cast(cs_sales_price as
decimal(12,2))),avg(cast(cs_net_profit as decimal(12,2))),avg(cast(c_birth_year
as decimal(12,2))),avg(cast(cd_dep_count as
decimal(12,2))),agg1,agg2,agg3,agg4,agg5,agg6,agg7,sum,count,sum,cou [...]
+ HashAggregate
[i_item_id,ca_country,ca_state,ca_county,spark_grouping_id,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count]
[avg(cast(cs_quantity as
decimal(12,2))),avg(UnscaledValue(cs_list_price)),avg(UnscaledValue(cs_coupon_amt)),avg(UnscaledValue(cs_sales_price)),avg(UnscaledValue(cs_net_profit)),avg(cast(c_birth_year
as decimal(12,2))),avg(cast(cd_dep_count as
decimal(12,2))),agg1,agg2,agg3,agg4,agg5,agg6,agg7,sum,count,sum,count,sum,count,sum,count,sum,count
[...]
InputAdapter
Exchange [i_item_id,ca_country,ca_state,ca_county,spark_grouping_id] #1
WholeStageCodegen (7)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 3c99c975977a..180dd5d5db94 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -22,7 +22,9 @@ import java.util.Locale
import scala.util.Random
+import org.scalacheck.Gen
import org.scalatest.matchers.must.Matchers.the
+import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks
import org.apache.spark.{SparkArithmeticException, SparkRuntimeException}
import org.apache.spark.sql.catalyst.plans.logical.Expand
@@ -47,7 +49,8 @@ case class Fact(date: Int, hour: Int, minute: Int, room_name:
String, temp: Doub
@SlowSQLTest
class DataFrameAggregateSuite extends SharedSparkSession
- with AdaptiveSparkPlanHelper {
+ with AdaptiveSparkPlanHelper
+ with ScalaCheckDrivenPropertyChecks {
import testImplicits._
val absTol = 1e-8
@@ -4816,6 +4819,193 @@ class DataFrameAggregateSuite extends SharedSparkSession
assert(estimate != null)
assert(estimate.asInstanceOf[Double] == 2.0)
}
+
+ // Numerical-equivalence property (sql-core layer).
+ //
+ // Sweeps the (p, p', s, n) lattice where the widened-cast peel fires,
+ // asserting that the result of SUM(CAST(x AS DECIMAL(p', s))) is
+ // numerically equal (java.math.BigDecimal.compareTo == 0) to a reference
+ // sum computed with java.math.BigDecimal. Domain is restricted to the
+ // non-overflow regime so the peeled LONG accumulator cannot wrap.
+ //
+ // Non-overflow bound: with |unscaled(x)| < 10^p, p <= 8, n <= 1000,
+ // worst-case accumulator is 1000 * (10^8 - 1) < 10^12 << 2^63.
+ //
+ // A wide-target-precision fixed witness (p=8, p'=30, s=2; scale unchanged)
+ // is exercised below as a unit case to guarantee a hand-enumerated boundary
+ // beyond the property generator's p' <= 28 cap, even if the generator shrinks.
+
+ private case class PeelDomain(p: Int, pPrime: Int, s: Int)
+
+ private val peelDomainGen: Gen[PeelDomain] = (for {
+ p <- Gen.choose(1, 8)
+ pPrime <- Gen.choose(math.max(p + 1, 9), 28)
+ s <- Gen.choose(0, p)
+ } yield PeelDomain(p, pPrime, s))
+ .retryUntil(d => d.p + 10 <= 18 && d.p < d.pPrime && d.pPrime + 10 <= 38)
+
+ // Reference SUM via java.math.BigDecimal at the widened target scale.
+ // Inside the non-overflow domain (|sum unscaled| < 10^(p+10)) this is
+ // bit-exact equivalent to both the peeled and the baseline plan, so we
+ // can pin the peeled result against an external oracle without depending
+ // on a baseline plan we no longer exercise.
+ private def referenceSum(
+ unscaledLongs: Seq[Long], d: PeelDomain): java.math.BigDecimal = {
+ if (unscaledLongs.isEmpty) {
+ null
+ } else {
+ val acc = unscaledLongs
+ .map(u => java.math.BigDecimal.valueOf(u, d.s))
+ .foldLeft(java.math.BigDecimal.ZERO)(_.add(_))
+ acc.setScale(d.s)
+ }
+ }
+
+ private def sumCastResult(
+ unscaledLongs: Seq[Long], d: PeelDomain): java.math.BigDecimal = {
+ // Use an explicit DecimalType(p, s) schema rather than Scala-tuple
+ // inference. createDataFrame on Tuple1[java.math.BigDecimal] infers
+ // DecimalType.SYSTEM_DEFAULT (38, 18), which would force the subsequent
+ // CAST to widen from (38, 18) -> (pPrime, s) rather than from the
+ // intended narrow (p, s) -> (pPrime, s) widening, defeating the
+ // WidenedDecimalChild trigger and silently exercising the wrong rule arm.
+ val rows = unscaledLongs.map(u => Row(java.math.BigDecimal.valueOf(u,
d.s)))
+ val schema = StructType(StructField("x", DecimalType(d.p, d.s)) :: Nil)
+ val df = spark.createDataFrame(spark.sparkContext.parallelize(rows),
schema)
+ assert(df.schema("x").dataType == DecimalType(d.p, d.s),
+ s"expected inner schema DecimalType(${d.p}, ${d.s}), got
${df.schema("x").dataType}")
+ df.select(sum(col("x").cast(DecimalType(d.pPrime, d.s))).as("s"))
+ .collect()(0).getDecimal(0)
+ }
+
+ test("SPARK-56627: DecimalAggregates widened-Cast SUM peel -- numerical " +
+ "equivalence property (sql-core layer)") {
+ val combinedGen: Gen[(PeelDomain, List[Long])] = for {
+ d <- peelDomainGen
+ upper = math.pow(10, d.p).toLong - 1
+ n <- Gen.choose(1, 1000)
+ xs <- Gen.listOfN(n, Gen.choose(-upper, upper))
+ } yield (d, xs)
+ forAll(combinedGen, minSuccessful(20), sizeRange(0)) { case (d, xs) =>
+ val r = sumCastResult(xs, d)
+ val ref = referenceSum(xs, d)
+ assert(r.compareTo(ref) == 0,
+ s"peel result diverges from BigDecimal reference for " +
+ s"PeelDomain(p=${d.p}, pPrime=${d.pPrime}, s=${d.s}), n=${xs.size},
" +
+ s"sample=${xs.take(3)}, got=$r ref=$ref")
+ }
+ }
+
+ // Wide target-precision fixed witness: (p=8, p'=30, s=2). Hand-enumerated so
+ // a wide target precision case is always exercised even if the property shrinks.
+ test("SPARK-56627: SUM(CAST(dec(8,2) AS dec(30,2))) matches BigDecimal " +
+ "reference (wide-target-scale fixed witness, sql-core)") {
+ val d = PeelDomain(8, 30, 2)
+ val xs = Seq(0L, 1L, -1L, 99999999L, -99999999L, 12345678L, -87654321L)
+ val r = sumCastResult(xs, d)
+ val ref = referenceSum(xs, d)
+ assert(r.compareTo(ref) == 0, s"got=$r ref=$ref")
+ }
+
+ // AVG widened-Cast peel: equivalence property (sql-core layer).
+ //
+ // Oracle: peel(AVG(CAST(x AS dec(pPrime, s)))) must be observationally
+ // identical to the existing fast path on AVG(x) directly. Both arms in
+ // Optimizer.DecimalAggregates produce
+ // Cast(Divide(Avg(UnscaledValue(<inner>)), Lit(10^s, Double)),
+ // DecimalType.bounded(<outerP>, s + 4))
+ // and the peel arm makes <inner> equal to the user's column, so the
+ // Double-divide dividends are bit-identical between the two paths; only
+ // the outer Cast target precision differs (pPrime+4 vs p+4), a widening
+ // precision Cast that preserves numerical value. We therefore assert
+ // BigDecimal.compareTo == 0 (value equality across differing precisions).
+ //
+ // Domain: inner p in [1, 7] (the AVG strict-subset guard
+ // `AVG_PEEL_MAX_INNER_PRECISION = 7`), pPrime in [8, 11] (the band where
+ // the existing `Average(DecimalExpression)` arm would intercept on the
+ // outer Cast type if not for our prepended arm), s in [0, p],
+ // n <= 1000 rows. The inner DataFrame schema is constructed as
+ // DecimalType(p, s) explicitly (NOT via tuple-inference, which would
+ // infer DecimalType.SYSTEM_DEFAULT and silently route through a DIFFERENT
+ // rule arm than intended -- the failure mode this PBT must lock down).
+ private case class AvgDomain(p: Int, pPrime: Int, s: Int)
+
+ private val avgDomainGen: Gen[AvgDomain] = (for {
+ p <- Gen.choose(1, 7)
+ pPrime <- Gen.choose(8, 11)
+ s <- Gen.choose(0, p)
+ } yield AvgDomain(p, pPrime, s))
+ .retryUntil(d => d.p < d.pPrime)
+
+ private def avgInputDf(unscaledLongs: Seq[Long], d: AvgDomain) = {
+ val rows = unscaledLongs.map(u => Row(java.math.BigDecimal.valueOf(u,
d.s)))
+ val schema = StructType(StructField("x", DecimalType(d.p, d.s)) :: Nil)
+ val df = spark.createDataFrame(spark.sparkContext.parallelize(rows),
schema)
+ assert(df.schema("x").dataType == DecimalType(d.p, d.s),
+ s"expected inner schema DecimalType(${d.p}, ${d.s}), got
${df.schema("x").dataType}")
+ df
+ }
+
+ private def avgCastResult(
+ unscaledLongs: Seq[Long], d: AvgDomain): java.math.BigDecimal = {
+ avgInputDf(unscaledLongs, d)
+ .select(avg(col("x").cast(DecimalType(d.pPrime, d.s))).as("a"))
+ .collect()(0).getDecimal(0)
+ }
+
+ private def avgDirectResult(
+ unscaledLongs: Seq[Long], d: AvgDomain): java.math.BigDecimal = {
+ avgInputDf(unscaledLongs, d)
+ .select(avg(col("x")).as("a"))
+ .collect()(0).getDecimal(0)
+ }
+
+ test("SPARK-56627: DecimalAggregates widened-Cast AVG peel -- " +
+ "equivalence vs unpeeled AVG (sql-core)") {
+ val combinedGen: Gen[(AvgDomain, List[Long])] = for {
+ d <- avgDomainGen
+ upper = math.pow(10, d.p).toLong - 1
+ n <- Gen.choose(1, 1000)
+ xs <- Gen.listOfN(n, Gen.choose(-upper, upper))
+ } yield (d, xs)
+ forAll(combinedGen, minSuccessful(20), sizeRange(0)) { case (d, xs) =>
+ val peeled = avgCastResult(xs, d)
+ val direct = avgDirectResult(xs, d)
+ // BigDecimal.compareTo compares numeric value only, so the differing
+ // declared precisions do not matter: peeled has output
+ // DecimalType.bounded(pPrime+4, s+4), direct has DecimalType(p+4, s+4).
+ // Both wrap the same Double-divide bit pattern, so the value is identical.
+ assert(peeled.compareTo(direct) == 0,
+ s"peeled AVG diverges from unpeeled AVG for " +
+ s"AvgDomain(p=${d.p}, pPrime=${d.pPrime}, s=${d.s}), n=${xs.size}, "
+
+ s"sample=${xs.take(3)}, peeled=$peeled direct=$direct")
+ }
+ }
+
+ // Wider-pPrime regime shape witness: (p=4, p'=20, s=2). The equivalence
+ // PBT above only covers pPrime in [8, 11] (where the existing AVG arm
+ // would otherwise intercept and provide a comparable oracle). For pPrime
+ // outside that band the new arm still fires (only constrained by inner
+ // p <= 7), but the comparison oracle "AVG(x) directly" is no longer
+ // available because the existing arm targets a narrower output type.
+ // This witness asserts non-null result and the expected widened output
+ // schema, locking the rule's shape contract without claiming an
+ // unreachable oracle.
+ test("SPARK-56627: AVG(CAST(dec(4,2) AS dec(20,2))) peels and yields " +
+ "widened output schema (wider-pPrime regime shape witness)") {
+ val rows = Seq(123L, -456L, 789L, 0L)
+ .map(u => Row(java.math.BigDecimal.valueOf(u, 2)))
+ val schema = StructType(StructField("x", DecimalType(4, 2)) :: Nil)
+ val df = spark.createDataFrame(spark.sparkContext.parallelize(rows),
schema)
+ .select(avg(col("x").cast(DecimalType(20, 2))).as("a"))
+ val row = df.collect()(0)
+ assert(!row.isNullAt(0), s"expected non-null AVG, got null; df schema =
${df.schema}")
+ val outType = df.schema("a").dataType.asInstanceOf[DecimalType]
+ // Widened-arm output Cast target = DecimalType.bounded(pPrime + 4, s + 4)
+ // = DecimalType.bounded(24, 6).
+ assert(outType.precision == 24 && outType.scale == 6,
+ s"expected DecimalType(24, 6) from widened-arm peel, got $outType")
+ }
}
case class B(c: Option[Double])
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DecimalAggregatesBenchmark.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DecimalAggregatesBenchmark.scala
new file mode 100644
index 000000000000..fc00bea62dd1
--- /dev/null
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DecimalAggregatesBenchmark.scala
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.Decimal
+
+/**
+ * Benchmark for the DecimalAggregates widened-cast peel optimizer rule
+ * (both SUM and AVG arms).
+ *
+ * Each case is a three-way comparison on the same `DECIMAL(p, s)` input:
+ * 1. `native` -- query writes `SUM(x)` / `AVG(x)` directly,
+ * hitting the existing fast path (rule on).
+ * 2. `widened, peel off` -- query writes `SUM(CAST(x AS DECIMAL(p', s)))`
+ * with `DecimalAggregates` excluded; the cast
+ * defeats the existing fast path, so the
+ * baseline `CheckOverflow` path runs.
+ * 3. `widened, peel on` -- same widened query with the rule enabled;
+ * the new peel arm strips the cast and
+ * restores the fast path.
+ *
+ * Reviewer story:
+ * - `native` vs `widened, peel off` -- shows the widening cast really
+ * evicts user-visible work onto the slow path (motivation).
+ * - `widened, peel off` vs `widened, peel on` -- shows the peel rule
+ * recovers the lost performance (rule benefit).
+ * - `widened, peel on` vs `native` -- shows the peel makes the cast
+ * effectively free (rule correctness echo of the numerical-equivalence
+ * PBT in `DataFrameAggregateSuite`).
+ *
+ * Sections:
+ * A -- Aggregate SUM widened-cast sweep over (p, s, p') cases.
+ * B -- Aggregate AVG widened-cast sweep (p <= 7 per
+ * AVG_PEEL_MAX_INNER_PRECISION).
+ *
+ * NOTE on Window arm: the optimizer does not extend widened-Cast peel to
+ * the Window arm (see DecimalAggregates rule comment) because the analyzer
+ * hoists the Cast into a child Project, so a Window microbenchmark would
+ * not exercise this rule. A Window benchmark belongs with a future
+ * plan-layer rewrite, not here.
+ *
+ * Case design (`p+1` boundary vs `p+10`-class wider) deliberately includes
+ * both the minimum widening (one extra digit, e.g. `dec(7,2) -> dec(8,2)`)
+ * and a production-canonical wider one (e.g. `dec(7,2) -> dec(17,2)` is the
+ * inner-widened-decimal shape in TPC-DS q18) so reviewers see whether
+ * widening magnitude matters and whether the canonical shape behaves like
+ * the boundary one.
+ *
+ * Args: aN (Section A/B row count), iters, apl
+ * (`spark.sql.decimalOperations.allowPrecisionLoss`; default true).
+ * Defaults committed for GHA: aN=10000000, iters=5, apl=true.
+ *
+ * To run this benchmark locally (pre-GHA smoke):
+ * {{{
+ * build/sbt "sql/Test/runMain \
+ * org.apache.spark.sql.execution.benchmark.DecimalAggregatesBenchmark \
+ * 10000000 5"
+ * }}}
+ *
+ * To regenerate committed baselines (via `benchmark.yml` GHA workflow):
+ * {{{
+ * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt \
+ * "sql/Test/runMain org.apache.spark.sql.execution.benchmark.DecimalAggregatesBenchmark"
+ * }}}
+ *
+ * Committed results:
+ * `sql/core/benchmarks/DecimalAggregatesBenchmark-results.txt` (JDK 17).
+ * `sql/core/benchmarks/DecimalAggregatesBenchmark-jdk21-results.txt`.
+ * `sql/core/benchmarks/DecimalAggregatesBenchmark-jdk25-results.txt`.
+ */
+object DecimalAggregatesBenchmark extends SqlBasedBenchmark {
+
+  /**
+   * Aggregate SUM cases: (label, p, s, widened p').
+   *
+   * All `p + 10 <= 18` so the *native* `SUM(x)` query hits the existing
+   * SUM Long fast path -- providing a meaningful baseline for the
+   * peel-on leg. Coverage: `p+1` boundary widening (A1, A3) plus a
+   * `p+10`-class wider cast representative of production shapes (A2,
+   * A4; A2 mirrors the TPC-DS q18 inner-widened-decimal shape).
+   */
+  private val SumAggCases: Seq[(String, Int, Int, Int)] = Seq(
+    ("A1 p=7 s=2 p'=8", 7, 2, 8), // p+1 boundary
+    ("A2 p=7 s=2 p'=17", 7, 2, 17), // p+10, canonical TPC-DS q18 shape
+    ("A3 p=5 s=0 p'=6", 5, 0, 6), // p+1 boundary, zero scale
+    ("A4 p=5 s=0 p'=15", 5, 0, 15) // p+10, zero scale
+  )
+
+  /**
+   * Aggregate AVG cases: (label, p, s, widened p').
+   *
+   * All `p <= 7` per the conservative `AVG_PEEL_MAX_INNER_PRECISION = 7`
+   * guard (see design doc 0001 rev 7 S7.1 -- strict-subset narrowing so
+   * SPARK-37024 Double-regime exposure is NOT amplified by this rule).
+   * Same `p+1` / `p+10` split as Section A. B2 mirrors the canonical
+   * TPC-DS q18 AVG shape.
+   */
+  private val AvgAggCases: Seq[(String, Int, Int, Int)] = Seq(
+    ("B1 p=7 s=2 p'=8", 7, 2, 8), // p+1 boundary
+    ("B2 p=7 s=2 p'=12", 7, 2, 12), // canonical TPC-DS q18 AVG shape
+    ("B3 p=5 s=0 p'=6", 5, 0, 6), // p+1 boundary, zero scale
+    ("B4 p=5 s=0 p'=15", 5, 0, 15) // p+10, zero scale
+  )
+
+  /** Clamp generator to `10^(p-s) - 1` so rand() * bound fits `DECIMAL(p, s)`. */
+  private def unscaledBound(p: Int, s: Int): Long = {
+    require(p - s >= 0, s"p=$p s=$s p-s must be non-negative")
+    math.pow(10.0, (p - s).toDouble).toLong - 1L
+  }
+
+  private def setupAggTable(spark: org.apache.spark.sql.SparkSession,
+      n: Long, p: Int, s: Int): Unit = {
+    val bound = unscaledBound(p, s)
+    spark.range(n)
+      .selectExpr(s"cast(rand(42) * $bound as decimal($p, $s)) as x")
+      .coalesce(1)
+      .createOrReplaceTempView("t")
+  }
+
+  private val ExcludedRulesKey: String = SQLConf.OPTIMIZER_EXCLUDED_RULES.key
+  private val DecimalAggregatesRule: String =
+    "org.apache.spark.sql.catalyst.optimizer.DecimalAggregates"
+
+  /**
+   * Run a single three-way comparison: native (no cast, rule on),
+   * widened with rule off (baseline slow path), widened with rule on
+   * (peel restores fast path). `apl` is held identical across all three
+   * legs so any delta is attributable to (a) the widening cast and
+   * (b) the peel rule respectively.
+   */
+  private def runThreeWay(label: String, n: Long, nativeSql: String,
+      widenedSql: String, iters: Int, apl: String): Unit = {
+    val bench = new Benchmark(label, n, output = output)
+    bench.addCase("native (no cast, rule on)", numIters = iters) { _ =>
+      withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> apl) {
+        spark.sql(nativeSql).noop()
+      }
+    }
+    bench.addCase("widened cast, peel off", numIters = iters) { _ =>
+      withSQLConf(
+        ExcludedRulesKey -> DecimalAggregatesRule,
+        SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> apl) {
+        spark.sql(widenedSql).noop()
+      }
+    }
+    bench.addCase("widened cast, peel on", numIters = iters) { _ =>
+      withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> apl) {
+        spark.sql(widenedSql).noop()
+      }
+    }
+    bench.run()
+  }
+
+  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+    val aN: Long = if (mainArgs.length > 0) mainArgs(0).toLong else 10L * 1000L * 1000L
+    val iters: Int = if (mainArgs.length > 1) mainArgs(1).toInt else 5
+    val apl: String = if (mainArgs.length > 2) mainArgs(2) else "true"
+
+    require(Decimal.MAX_LONG_DIGITS == 18,
+      s"Decimal.MAX_LONG_DIGITS drift: expected 18 got ${Decimal.MAX_LONG_DIGITS}")
+
+    // Section A -- Aggregate SUM widened-cast.
+    runBenchmark("DecimalAggregates SUM widened-cast peel (Aggregate)") {
+      SumAggCases.foreach { case (label, p, s, pPrime) =>
+        require(pPrime > p, s"$label: p'=$pPrime must exceed p=$p")
+        require(p + 10 <= 18,
+          s"$label: p=$p violates SUM Long fast path guard p+10<=MAX_LONG_DIGITS=18; " +
+            s"native baseline would not be meaningful")
+        setupAggTable(spark, aN, p, s)
+        runThreeWay(label, aN,
+          nativeSql = "select sum(x) from t",
+          widenedSql = s"select sum(cast(x as decimal($pPrime, $s))) from t",
+          iters, apl)
+      }
+    }
+
+    // Section B -- Aggregate AVG widened-cast.
+    runBenchmark("DecimalAggregates AVG widened-cast peel (Aggregate)") {
+      AvgAggCases.foreach { case (label, p, s, pPrime) =>
+        require(pPrime > p, s"$label: p'=$pPrime must exceed p=$p")
+        require(p <= 7,
+          s"$label: p=$p violates conservative AVG_PEEL_MAX_INNER_PRECISION=7 guard")
+        setupAggTable(spark, aN, p, s)
+        runThreeWay(label, aN,
+          nativeSql = "select avg(x) from t",
+          widenedSql = s"select avg(cast(x as decimal($pPrime, $s))) from t",
+          iters, apl)
+      }
+    }
+  }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]