This is an automated email from the ASF dual-hosted git repository.
cloud-fan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 846376a8db0d [SPARK-56983][SQL] DecimalAggregates widened-Cast peel
must preserve query semantics
846376a8db0d is described below
commit 846376a8db0df88ef720644a1ffa1fafa561fab6
Author: Wenchen Fan <[email protected]>
AuthorDate: Fri May 22 10:46:34 2026 +0800
[SPARK-56983][SQL] DecimalAggregates widened-Cast peel must preserve query
semantics
## What changes were proposed in this pull request?
Two semantic-preserving fixes to the SPARK-56627 widened-Cast peel in
`DecimalAggregates`:
1. **SUM arm**: Replace
`Cast(MakeDecimal(_, p + 10, s), bounded(pPrime + 10, s))`
with `MakeDecimal(_, min(pPrime + 10, 38), s)`. The merged form's inner
`MakeDecimal` narrowed the overflow check to `10^(p+10)`, where the
un-rewritten `SUM(Cast(x, dec(pPrime, s)))` accepted up to
`10^min(pPrime+10, 38)`. For `pPrime + 10 > 18`, `Decimal.setOrNull` falls
into the BigDecimal branch and never rejects a Long, so the cleaner form
preserves the original overflow boundary for all Long-fit sums.
2. **AVG arm**: Change the guard from `p <= 7`
(`AVG_PEEL_MAX_INNER_PRECISION`) to `pPrime + 4 <= 15`
(`MAX_DOUBLE_DIGITS`). The merged guard switched
`AVG(CAST(x AS DECIMAL(pPrime, s)))` from full Decimal arithmetic to
Double-regime whenever `p <= 7`, including the `pPrime > 11` band where
the un-rewritten expression was Decimal-exact. The new guard ensures the
rewrite only fires inside the existing AVG fast-path's Double envelope.
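The two guards described above can be sketched outside Spark as plain predicates. This is an illustrative model, not the rule's code: the object name `PeelGuards` and its method names are hypothetical, the `pPrime > p` condition is an assumption standing in for the `WidenedDecimalChild` match, and only the three constants mirror values named in `DecimalAggregates`.

```scala
// Standalone model of the peel guards; no Spark dependency.
// All names here are hypothetical and exist only for illustration.
object PeelGuards {
  val MaxLongDigits = 18   // mirrors MAX_LONG_DIGITS
  val MaxDoubleDigits = 15 // mirrors MAX_DOUBLE_DIGITS
  val MaxPrecision = 38    // mirrors DecimalType.MAX_PRECISION

  // SUM arm: the guard itself is unchanged by this commit.
  // Assumption: "widening" is modelled as pPrime > p.
  def sumPeels(p: Int, pPrime: Int): Boolean =
    pPrime > p && p + 10 <= MaxLongDigits

  // Precision of the inner MakeDecimal in the merged (SPARK-56627) form.
  def sumTargetMerged(p: Int): Int = p + 10

  // Precision of the single MakeDecimal after this commit.
  def sumTargetFixed(pPrime: Int): Int = math.min(pPrime + 10, MaxPrecision)

  // AVG arm, merged guard: keyed on the INNER precision p.
  def avgPeelsMerged(p: Int, pPrime: Int): Boolean =
    pPrime > p && p <= 7

  // AVG arm, fixed guard: keyed on the OUTER precision pPrime.
  def avgPeelsFixed(p: Int, pPrime: Int): Boolean =
    pPrime > p && pPrime + 4 <= MaxDoubleDigits
}
```

Under this model the TPC-DS q18 shape (`p = 7`, `pPrime = 12`) satisfies the merged AVG guard but not the fixed one, while `p = 7`, `pPrime = 10` satisfies both.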
## Why are the changes needed?
A query rewrite should preserve the query's observable semantics. The
SPARK-56627 rule changed semantics in two ways:
- **SUM**: Inner `MakeDecimal(_, p + 10, s)` could return null (non-ANSI) or
throw (ANSI) for Long-fit sums in `(10^(p+10), 10^min(pPrime+10, 38))`.
Example: `SUM(CAST(x AS DECIMAL(15, 2)))` where `x: DECIMAL(5, 2)`
rejected at `10^15` instead of `10^25` -- reachable around `~10^10`
rows of small-precision input.
- **AVG**: For `pPrime > 11, p <= 7`, the rule switched an exact-Decimal
computation into Double-regime, visible as last-digit rounding differences
at any input size. TPC-DS q18 (`p=7, pPrime=12`) was affected.
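The SUM example's arithmetic can be checked with a small standalone sketch. It is a simplified model under stated assumptions: `SumBoundary` and its methods are hypothetical names, and `fits` only models "magnitude below `10^precision`" as a stand-in for the precision check, not Spark's `Decimal.setOrNull` itself.

```scala
// Simplified model of the SUM overflow boundary; no Spark dependency.
// All names here are hypothetical and exist only for illustration.
object SumBoundary {
  private val MaxPrecision = 38

  // An unscaled value fits `precision` digits when |value| < 10^precision.
  def fits(unscaled: Long, precision: Int): Boolean =
    BigInt(unscaled).abs < BigInt(10).pow(precision)

  // Merged form: the inner MakeDecimal checked against p + 10 digits.
  def mergedAccepts(unscaledSum: Long, p: Int): Boolean =
    fits(unscaledSum, p + 10)

  // Fixed form: one MakeDecimal checked against min(pPrime + 10, 38) digits.
  def fixedAccepts(unscaledSum: Long, pPrime: Int): Boolean =
    fits(unscaledSum, math.min(pPrime + 10, MaxPrecision))

  // Smallest count of maximal-magnitude DECIMAL(p, _) rows whose unscaled
  // sum reaches 10^digits (ceiling division).
  def rowsToReach(p: Int, digits: Int): BigInt = {
    val maxUnscaled = BigInt(10).pow(p) - 1
    (BigInt(10).pow(digits) + maxUnscaled - 1) / maxUnscaled
  }
}
```

With `p = 5` and `pPrime = 15`, an unscaled sum of `2 * 10^15` is rejected by the merged-form model and accepted by the fixed-form model, and `rowsToReach(5, 15)` comes out slightly above `10^10`, in line with the row count quoted above.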
## Does this PR introduce _any_ user-facing change?
Yes -- restores the un-rewritten semantics for affected queries:
- `SUM(CAST(x AS DECIMAL(pPrime, s)))` no longer rejects Long-fit sums
that the un-rewritten expression would have accepted.
- `AVG(CAST(x AS DECIMAL(pPrime, s)))` with
`pPrime + 4 > MAX_DOUBLE_DIGITS` is no longer peeled -- the un-rewritten
Decimal-exact path is preserved.
TPC-DS q18 AVG aggregations revert to the un-peeled form (visible in the
regenerated `q18` plan-stability files).
## How was this patch tested?
- `DecimalAggregatesSuite`: 37/37 pass, including updated SUM arm shape
assertions, the property-test invariants for the new MakeDecimal-only
form, and the new AVG bound boundary tests.
- `DataFrameAggregateSuite`: SPARK-56627 numerical-equivalence PBTs pass
under the new AVG generator domain (`pPrime in [p+1, 11]`).
- TPC-DS q18 plan-stability files regenerated locally.
- `DecimalAggregatesBenchmark` AVG cases (B2, B4) updated to
`pPrime = 11` (the new bound). Result files left stale relative to the
code change -- the `benchmark.yml` GHA workflow can regenerate them.
## Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Opus 4.7
Closes #56036 from cloud-fan/SPARK-decimal-aggregates-semantics-fix.
Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/optimizer/Optimizer.scala | 36 ++--
.../optimizer/DecimalAggregatesSuite.scala | 218 ++++++++-------------
.../approved-plans-v1_4/q18.sf100/explain.txt | 8 +-
.../approved-plans-v1_4/q18.sf100/simplified.txt | 2 +-
.../approved-plans-v1_4/q18/explain.txt | 8 +-
.../approved-plans-v1_4/q18/simplified.txt | 2 +-
.../apache/spark/sql/DataFrameAggregateSuite.scala | 55 +++---
.../benchmark/DecimalAggregatesBenchmark.scala | 39 ++--
8 files changed, 152 insertions(+), 216 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index b3e7eb44ae65..95d774c6e991 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -2564,12 +2564,8 @@ object DecimalAggregates extends Rule[LogicalPlan] {
/** Maximum number of decimal digits representable precisely in a Double */
private val MAX_DOUBLE_DIGITS = 15
- /** Tighter than the AVG fast path's `prec + 4 <= MAX_DOUBLE_DIGITS` (= 11):
- * the strict-subset keeps SPARK-37024 Double-regime exposure unchanged. */
- private val AVG_PEEL_MAX_INNER_PRECISION = 7
-
- /** Matches a scale-preserving widening decimal Cast; refuses CheckOverflow
- * to preserve overflow semantics on the unscaled value. */
+ /** Matches a scale-preserving widening decimal Cast.
+ * Refuses CheckOverflow so per-row overflow checks are not hoisted out. */
private object WidenedDecimalChild {
def unapply(e: Expression): Option[(Expression, Int, Int, Int)] = e match {
case Cast(inner @ DecimalExpression(p, s), DecimalType.Fixed(pPrime,
sPrime), _, _)
@@ -2604,27 +2600,35 @@ object DecimalAggregates extends Rule[LogicalPlan] {
case _ => we
}
case ae @ AggregateExpression(af, _, _, _, _) => af match {
+ // Hoist a scale-preserving widening Cast out of Sum so the existing
+ // Long fast-path can fire on the inner. The MakeDecimal target type
+ // matches `Sum(Cast(x, dec(pPrime, s))).dataType` (see Sum.resultType)
+ // so the final-value overflow boundary is the same as the un-rewritten
+ // expression.
case s @ Sum(WidenedDecimalChild(inner, p, pPrime, s_scale), _)
if p + 10 <= MAX_LONG_DIGITS =>
- Cast(
- MakeDecimal(
- ae.copy(aggregateFunction = s.copy(child =
UnscaledValue(inner))),
- p + 10, s_scale),
- DecimalType.bounded(pPrime + 10, s_scale),
- Option(conf.sessionLocalTimeZone))
+ val target = DecimalType.bounded(pPrime + 10, s_scale)
+ MakeDecimal(
+ ae.copy(aggregateFunction = s.copy(child = UnscaledValue(inner))),
+ target.precision, target.scale)
case s @ Sum(e @ DecimalExpression(prec, scale), _) if prec + 10 <=
MAX_LONG_DIGITS =>
MakeDecimal(ae.copy(aggregateFunction = s.copy(child =
UnscaledValue(e))),
prec + 10, scale)
- // Ordered before the un-widened Average arm: when pPrime in [8, 11],
- // the outer Cast's DecimalType would otherwise match that arm first.
+ // Hoist a scale-preserving widening Cast out of Average. Guarded on
+ // the OUTER precision `pPrime + 4 <= MAX_DOUBLE_DIGITS` so the
+ // rewrite only fires inside the existing Double-regime envelope;
+ // for wider outer casts the un-rewritten Decimal-exact path is
+ // preserved. Ordered before the un-widened arm so the outer Cast's
+ // dataType does not let that arm intercept first (when pPrime <= 11,
+ // it would also match -- but on the outer Cast, not the inner).
case a @ Average(WidenedDecimalChild(inner, p, pPrime, s_scale), _)
- if p <= AVG_PEEL_MAX_INNER_PRECISION =>
+ if pPrime + 4 <= MAX_DOUBLE_DIGITS =>
val newAggExpr = ae.copy(aggregateFunction = a.copy(child =
UnscaledValue(inner)))
Cast(
Divide(newAggExpr, Literal.create(math.pow(10.0, s_scale),
DoubleType)),
- DecimalType.bounded(pPrime + 4, s_scale + 4),
Option(conf.sessionLocalTimeZone))
+ DecimalType(pPrime + 4, s_scale + 4),
Option(conf.sessionLocalTimeZone))
case a @ Average(e @ DecimalExpression(prec, scale), _) if prec + 4 <=
MAX_DOUBLE_DIGITS =>
val newAggExpr = ae.copy(aggregateFunction = a.copy(child =
UnscaledValue(e)))
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala
index 6f8c0db261b2..b65ce3a0f017 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala
@@ -74,13 +74,15 @@ class DecimalAggregatesSuite extends PlanTest with
ScalaCheckDrivenPropertyCheck
val testRelationC = LocalRelation($"c".decimal(7, 2))
- test("Decimal Average Aggregation widened-cast peel: Optimized (p=7,
p'=12)") {
+ test("Decimal Average Aggregation widened-cast peel: " +
+ "Not Optimized (pPrime+4 > MAX_DOUBLE_DIGITS preserves Decimal-exact
path)") {
+ // pPrime=12, pPrime+4=16 > 15. The new AVG arm only fires inside the
+ // existing Double-regime envelope (pPrime+4 <= 15); for wider outer casts
+ // the un-rewritten Decimal-exact path is preserved.
val widened = $"c".cast(DecimalType(12, 2))
val originalQuery = testRelationC.select(avg(widened))
val optimized = Optimize.execute(originalQuery.analyze)
- val correctAnswer = testRelationC
- .select((avg(UnscaledValue($"c")) / 100.0).cast(DecimalType(16, 6))
- .as("avg(CAST(c AS DECIMAL(12,2)))")).analyze
+ val correctAnswer = originalQuery.analyze
comparePlans(optimized, correctAnswer)
}
@@ -107,7 +109,10 @@ class DecimalAggregatesSuite extends PlanTest with
ScalaCheckDrivenPropertyCheck
comparePlans(optimized, correctAnswer)
}
- test("Decimal Average Aggregation widened-cast peel: Not Optimized (boundary
p=8)") {
+ test("Decimal Average Aggregation widened-cast peel: " +
+ "Not Optimized (pPrime+4 > MAX_DOUBLE_DIGITS, boundary)") {
+ // pPrime=13, pPrime+4=17 > 15. AVG peel does not fire; existing un-widened
+ // arm also does not fire on the outer Cast (same guard). Plan unchanged.
val testRelationE = LocalRelation($"e".decimal(8, 2))
val widened = $"e".cast(DecimalType(13, 2))
val originalQuery = testRelationE.select(avg(widened))
@@ -117,29 +122,25 @@ class DecimalAggregatesSuite extends PlanTest with
ScalaCheckDrivenPropertyCheck
comparePlans(optimized, correctAnswer)
}
- // SPARK-56627 F2 regression: with `pPrime in [8, 11]`, the outer Cast's
- // dataType `Decimal(pPrime, s)` would let the un-widened existing
- // `Average(DecimalExpression)` arm match first via `prec + 4 <=
MAX_DOUBLE_DIGITS`
- // (= pPrime <= 11). New AVG peel arm must be ordered before to win this band
- // and rewrite via the inner `p`-based UnscaledValue path.
+ // Cast-hoisting plan simplification: when pPrime+4 <= MAX_DOUBLE_DIGITS, the
+ // existing un-widened AVG arm would also match the outer Cast, but wraps
+ // UnscaledValue around the OUTER Cast (running the Cast per row). The new
+ // arm is ordered before so that UnscaledValue feeds directly off the inner.
test("Decimal Average Aggregation widened-cast peel: " +
- "Optimized for pPrime band [8, 11] (must beat existing AVG fast-path
arm)") {
+ "Optimized for pPrime band [p+1, 11] (drops per-row inner Cast)") {
val testRelationE = LocalRelation($"e".decimal(7, 2))
val widened = $"e".cast(DecimalType(10, 2))
val originalQuery = testRelationE.select(avg(widened).as("avg_widened"))
val optimized = Optimize.execute(originalQuery.analyze)
// Expected: peeled via WidenedDecimalChild(inner=e, p=7, pPrime=10, s=2),
- // outer Cast bounded(pPrime+4=14, s+4=6). NOT
- // `MakeDecimal(Sum(UnscaledValue(cast(e as dec(10,2)))), 14, 2)` (existing
- // arm form), which would lose F2's intent of avoiding the widened-cast
- // intermediate.
+ // outer Cast target DecimalType(pPrime+4=14, s+4=6).
val correctAnswer = testRelationE
.select(
Cast(
Divide(
avg(UnscaledValue($"e")),
Literal.create(math.pow(10.0, 2), DoubleType)),
- DecimalType.bounded(14, 6),
+ DecimalType(14, 6),
Option(conf.sessionLocalTimeZone))
.as("avg_widened"))
.analyze
@@ -147,18 +148,19 @@ class DecimalAggregatesSuite extends PlanTest with
ScalaCheckDrivenPropertyCheck
comparePlans(optimized, correctAnswer)
}
- // SPARK-56627 F1 regression: `WidenedDecimalChild` must NOT peel when the
- // inner expression is a `CheckOverflow` (introduced by `DecimalPrecision`
- // analyzer for nullOnOverflow semantics). Peeling through `CheckOverflow`
- // would change the overflow behavior of the inner aggregate.
+ // WidenedDecimalChild must NOT peel when the inner expression is a
+ // CheckOverflow (introduced by DecimalPrecision for nullOnOverflow
+ // semantics). Peeling through CheckOverflow would hoist a per-row
+ // overflow check out of the aggregate.
//
- // The existing un-widened `Average(DecimalExpression)` arm still fires on
- // the outer Cast (dataType `Decimal(pPrime=10, s=2)`, `pPrime + 4 = 14 <=
15`),
- // so the optimized plan wraps `UnscaledValue` around the OUTER cast (not
- // the inner CheckOverflow). The peel-arm-fired form would instead be
- // `UnscaledValue(CheckOverflow(e))` (no outer cast), which we want to AVOID.
+ // The existing un-widened Average(DecimalExpression) arm still fires on
+ // the outer Cast (dataType Decimal(pPrime=10, s=2), pPrime + 4 = 14 <= 15),
+ // so the optimized plan wraps UnscaledValue around the OUTER cast. Without
+ // the CheckOverflow guard the peel arm would feed UnscaledValue off the
+ // inner CheckOverflow instead, which we want to AVOID.
test("Decimal Average Aggregation widened-cast peel: " +
- "Not peeled for Cast(CheckOverflow(inner), wider) form (F1 guard)") {
+ "Not peeled for Cast(CheckOverflow(inner), wider) form " +
+ "(CheckOverflow guard)") {
val testRelationE = LocalRelation($"e".decimal(7, 2))
val co = CheckOverflow($"e", DecimalType(7, 2), nullOnOverflow = true)
val widened = Cast(co, DecimalType(10, 2))
@@ -255,28 +257,22 @@ class DecimalAggregatesSuite extends PlanTest with
ScalaCheckDrivenPropertyCheck
$"i".int)
test("SPARK-56627: SUM(CAST(dec(7,2) AS dec(17,2))) peels via widened-Cast
fast path") {
- // Witness chosen so p+10=17 <= MAX_LONG_DIGITS(18) < pPrime+10=27 -- the
- // new case fires (a bare-Cast inner cannot fall through to the existing
- // DecimalExpression case). Expected shape:
- // Cast(MakeDecimal(Sum(UnscaledValue(d7_2)), p+10=17, s=2),
- // DecimalType.bounded(pPrime+10=27, s=2),
- // Option(conf.sessionLocalTimeZone))
+ // Cast-hoisting framing: SUM(Cast(x, dec(pPrime, s))) is rewritten to
+ // SUM(x) wrapped in a MakeDecimal whose precision equals the un-rewritten
+ // Sum's output type `min(pPrime + 10, 38)`. Expected shape:
+ // MakeDecimal(Sum(UnscaledValue(d7_2)), min(pPrime+10, 38)=27, s=2)
val q = widenRel.select(sum($"d7_2".cast(DecimalType(17, 2))))
val optimized = Optimize.execute(q.analyze)
val correctAnswer = widenRel
- .select(Cast(
- MakeDecimal(sum(UnscaledValue($"d7_2")), 17, 2),
- DecimalType.bounded(27, 2),
- Option(conf.sessionLocalTimeZone))
+ .select(MakeDecimal(sum(UnscaledValue($"d7_2")), 27, 2)
.as("sum(CAST(d7_2 AS DECIMAL(17,2)))")).analyze
comparePlans(optimized, correctAnswer)
}
test("SPARK-56627: SUM(CAST(dec(7,2) AS dec(17,2))) -- peel preserves
schema") {
- // Schema invariance via DataType equality (not string).
- // Top-level output type of SUM(dec(p,s)) is DecimalType(min(p+10,38), s);
- // peeled tree wraps inner with outer Cast(_, dec(pPrime+10,s)) = dec(27,2)
- // -- identical to baseline schema.
+ // Schema invariance via DataType equality. Top-level output type of
+ // `SUM(Cast(x, dec(pPrime, s)))` is `DecimalType.bounded(pPrime+10, s)`;
+ // the peeled MakeDecimal target precision matches.
val q = widenRel.select(sum($"d7_2".cast(DecimalType(17, 2))))
val baselineSchema = q.analyze.schema
val optimized = Optimize.execute(q.analyze)
@@ -293,8 +289,9 @@ class DecimalAggregatesSuite extends PlanTest with
ScalaCheckDrivenPropertyCheck
comparePlans(optimized, correctAnswer)
}
- test("SPARK-56627: AVG(CAST(dec(7,2) AS dec(17,2))) -- peel preserves
schema") {
- val q = widenRel.select(avg($"d7_2".cast(DecimalType(17, 2))))
+ test("SPARK-56627: AVG(CAST(dec(7,2) AS dec(10,2))) -- peel preserves
schema") {
+ // Witness inside the new AVG peel bound (pPrime+4 = 14 <= 15).
+ val q = widenRel.select(avg($"d7_2".cast(DecimalType(10, 2))))
val baselineSchema = q.analyze.schema
val optimized = Optimize.execute(q.analyze)
assert(optimized.schema === baselineSchema,
@@ -387,10 +384,7 @@ class DecimalAggregatesSuite extends PlanTest with
ScalaCheckDrivenPropertyCheck
val q = widenRel.select(sum(nullLit.cast(DecimalType(17, 2))))
val optimized = Optimize.execute(q.analyze)
val correctAnswer = widenRel
- .select(Cast(
- MakeDecimal(sum(UnscaledValue(nullLit)), 17, 2),
- DecimalType.bounded(27, 2),
- Option(conf.sessionLocalTimeZone))
+ .select(MakeDecimal(sum(UnscaledValue(nullLit)), 27, 2)
.as("sum(CAST(NULL AS DECIMAL(17,2)))")).analyze
comparePlans(optimized, correctAnswer)
}
@@ -401,10 +395,7 @@ class DecimalAggregatesSuite extends PlanTest with
ScalaCheckDrivenPropertyCheck
val q = emptyRel.select(sum($"d7_2".cast(DecimalType(17, 2))))
val optimized = Optimize.execute(q.analyze)
val correctAnswer = emptyRel
- .select(Cast(
- MakeDecimal(sum(UnscaledValue($"d7_2")), 17, 2),
- DecimalType.bounded(27, 2),
- Option(conf.sessionLocalTimeZone))
+ .select(MakeDecimal(sum(UnscaledValue($"d7_2")), 27, 2)
.as("sum(CAST(d7_2 AS DECIMAL(17,2)))")).analyze
comparePlans(optimized, correctAnswer)
}
@@ -453,69 +444,39 @@ class DecimalAggregatesSuite extends PlanTest with
ScalaCheckDrivenPropertyCheck
comparePlans(optimized, q)
}
- // Plan-shape property: structural invariants on the peeled tree.
+ // Plan-shape property: structural invariants on the peeled SUM tree.
//
- // Sweeps the (p, p', s) lattice where the widened-cast peel fires:
- // regime (ii): p + 10 <= 18 <= p' + 10 (new arm, old fast-path off)
- // regime (iii): p + 10 <= 18 < p' + 10 <= 38
- // Assertion (peel-on, structural -- NOT a hand-typed RHS clone):
- // - aggregate expression is wrapped by exactly one outer Cast
- // - the outer Cast wraps exactly one MakeDecimal
- // - inside MakeDecimal, the Sum's child has dataType=LongType (i.e.
- // UnscaledValue was inserted)
- // - outer Cast target precision = p' + 10 (or 38, clamped)
- // - outer Cast target scale = s
- // Reframed away from RHS-equality to detect behavioural regressions
- // rather than just refactor drift.
- // Peel-off branch: plan is unchanged relative to its analyzed form
- // (the local RuleExecutor runs only DecimalAggregates; no other rule
- // can rewrite the SUM when the peel does not fire for a Cast child).
+ // Sweeps the (p, p', s) lattice where the widened-cast SUM peel fires:
+ // p + 10 <= 18 and p' > p, with p' <= 38. The rewrite produces a single
+ // MakeDecimal at precision min(p' + 10, 38) wrapping Sum(UnscaledValue(x)).
+ // I1. exactly one Sum node, whose child has LongType.
+ // I2. exactly one MakeDecimal node, with precision = min(p' + 10, 38)
+ // and scale = s -- matches Sum(Cast(x, dec(p', s))).dataType, so the
+ // final-value overflow boundary is unchanged from un-rewritten.
private case class PeelInputs(p: Int, pPrime: Int, s: Int)
- private val peelGen: Gen[PeelInputs] = Gen.frequency(
- 5 -> (for {
- p <- Gen.choose(1, 8)
- pPrime <- Gen.choose(math.max(p + 1, 9), 28)
- s <- Gen.choose(0, p)
- } yield PeelInputs(p, pPrime, s)),
- 5 -> (for {
- p <- Gen.choose(1, 8)
- pPrime <- Gen.choose(9, 28)
- s <- Gen.choose(0, p)
- } yield PeelInputs(p, pPrime, s))
- )
-
- private val boundaryGen: Gen[PeelInputs] = Gen.oneOf(
- PeelInputs(7, 17, 2), PeelInputs(7, 18, 2), PeelInputs(7, 19, 2))
-
- private val peelSpaceGen: Gen[PeelInputs] = Gen.frequency(
- 8 -> peelGen,
- 2 -> boundaryGen
- ).retryUntil(in => in.p + 10 <= 18 && in.p < in.pPrime && in.pPrime + 10 <=
38)
+ // Bounds already enforce the peel-firing predicate:
+ // p + 10 <= 18 (p <= 8), p < pPrime (pPrime >= p+1), pPrime + 10 <= 38
+ // (pPrime <= 28).
+ private val peelGen: Gen[PeelInputs] = for {
+ p <- Gen.choose(1, 8)
+ pPrime <- Gen.choose(p + 1, 28)
+ s <- Gen.choose(0, p)
+ } yield PeelInputs(p, pPrime, s)
implicit override val generatorDrivenConfig: PropertyCheckConfiguration =
PropertyCheckConfiguration(minSuccessful = 50, minSize = 0, sizeRange = 0)
test("SPARK-56627: DecimalAggregates widened-Cast SUM peel -- plan-shape " +
"structural-invariants property") {
- forAll(peelSpaceGen) { in =>
+ forAll(peelGen) { in =>
val rel = LocalRelation($"x".decimal(in.p, in.s))
val q = rel.select(sum($"x".cast(DecimalType(in.pPrime, in.s))))
val analyzed = q.analyze
val optimized = Optimize.execute(analyzed)
- // Structural invariants the peel rewrite must establish, regardless
- // of incidental tree-shape changes from neighbouring rules:
- //
- // I1. exactly one Sum node, whose child has LongType (the peeled
- // UnscaledValue feed);
- // I2. exactly one MakeDecimal node in the tree (rebuilds Decimal
- // from the LONG accumulator);
- // I3. an outer Cast whose target DecimalType has precision at
- // least as wide as the user-written widened cast, so we never
- // narrow result precision below the baseline plan.
val sums = optimized.expressions.flatMap(_.collect { case s: Sum => s })
assert(sums.size == 1, s"expected exactly 1 Sum, got ${sums.size} in
$optimized")
assert(sums.head.child.dataType == LongType,
@@ -524,43 +485,23 @@ class DecimalAggregatesSuite extends PlanTest with
ScalaCheckDrivenPropertyCheck
val mds = optimized.expressions.flatMap(_.collect { case m: MakeDecimal
=> m })
assert(mds.size == 1,
s"expected exactly 1 MakeDecimal, got ${mds.size} in $optimized")
-
- val outerCasts = optimized.expressions.flatMap(_.collect {
- case c @ Cast(_, _: DecimalType, _, _) => c
- })
- assert(outerCasts.nonEmpty,
- s"expected an outer Cast to DecimalType, got none in $optimized")
- val outerPrec =
outerCasts.map(_.dataType.asInstanceOf[DecimalType].precision).max
- assert(outerPrec >= in.pPrime,
- s"outer Cast precision $outerPrec < baseline ${in.pPrime} in
$optimized")
+ val expectedPrec = math.min(in.pPrime + 10, DecimalType.MAX_PRECISION)
+ assert(mds.head.precision == expectedPrec && mds.head.scale == in.s,
+ s"expected MakeDecimal($expectedPrec, ${in.s}), got " +
+ s"MakeDecimal(${mds.head.precision}, ${mds.head.scale}) in
$optimized")
}
}
- //
---------------------------------------------------------------------------
- // F5 (skeptic round 1): Long-accumulator / Double-regime safety boundary
- // invariant guards.
- //
- // Background: a strict "overflow oracle" cannot be written at unit-test
- // scale -- the existing fast-path guards (`p + 10 <= MAX_LONG_DIGITS = 18`
- // for SUM, `AVG_PEEL_MAX_INNER_PRECISION = 7` for AVG) keep the
peel-eligible
- // inner-precision band so narrow that the Long accumulator (~9.22e18) cannot
- // wrap on any reachable peel input: at `p=8` we'd need ~9.22e10 rows. So
- // there is no production input that exercises a "peeled Long-wrap vs
- // un-peeled CheckOverflow" asymmetry to oracle against.
- //
- // What we CAN lock is the boundary itself: if someone in the future relaxes
- // either guard (raising `MAX_LONG_DIGITS - 10` for SUM, or
- // `AVG_PEEL_MAX_INNER_PRECISION` for AVG), the input shapes below WOULD
- // start peeling -- and the assertion that the rule is a no-op for these
- // inputs would fail. That is the safety net we want: a mechanical guard
- // that catches accidental widening of the peel-trigger surface.
+ // Safety-boundary guards: pin the SUM Long-fast-path and AVG
Double-fast-path
+ // bounds. If either guard is later relaxed (raising `MAX_LONG_DIGITS - 10`
+ // for SUM, or relaxing `pPrime + 4 <= MAX_DOUBLE_DIGITS` for AVG), the input
+ // shapes below would start peeling and these tests would fail, flagging the
+ // change for re-review.
test("SPARK-56627: SUM(CAST(dec(9,2) AS dec(19,2))) does NOT peel " +
"(Long-accumulator safety boundary)") {
- // Boundary witness: inner p=9 makes widened-arm `p + 10 = 19 > 18` reject,
- // AND outer-cast existing-arm `prec + 10 = 29 > 18` reject. Both arms are
- // no-ops by design -- peel cannot fire on this shape today, and must not
- // start firing if the inner-precision band is later widened without
- // re-deriving the Long-accumulator bound.
+ // Inner p=9 makes the widened-arm guard p + 10 = 19 > 18 reject. The
+ // existing un-widened arm also rejects (prec + 10 = 29 > 18 on the outer
+ // Cast). Both arms are no-ops by design.
val q = widenRel.select(sum($"d9_2".cast(DecimalType(19, 2))))
val optimized = Optimize.execute(q.analyze)
val correctAnswer = q.analyze
@@ -568,15 +509,10 @@ class DecimalAggregatesSuite extends PlanTest with
ScalaCheckDrivenPropertyCheck
}
test("SPARK-56627: AVG(CAST(dec(8,2) AS dec(20,2))) does NOT peel " +
- "(Double-regime / SPARK-37024 safety boundary)") {
- // Boundary witness: inner p=8 makes widened-AVG arm
- // `p > AVG_PEEL_MAX_INNER_PRECISION (7)` reject, AND outer-cast existing
- // AVG arm `prec + 4 = 24 > MAX_DOUBLE_DIGITS (15)` reject. The strict-
- // subset guard `p <= 7` keeps this rule's trigger surface strictly
- // inside the existing AVG fast path's surface, so SPARK-37024
- // (Double-regime silent precision loss) is not amplified. If someone
- // raises `AVG_PEEL_MAX_INNER_PRECISION` past 7 without first fixing
- // SPARK-37024, this test will start firing and flag the regression.
+ "(Double-regime safety boundary)") {
+ // pPrime=20, pPrime+4 = 24 > 15 rejects the widened AVG peel arm. The
+ // existing un-widened AVG arm also rejects on the outer Cast (same
+ // guard). Plan unchanged.
val q = widenRel.select(avg($"d8_2".cast(DecimalType(20, 2))))
val optimized = Optimize.execute(q.analyze)
val correctAnswer = q.analyze
@@ -638,11 +574,13 @@ class DecimalAggregatesSuite extends PlanTest with
ScalaCheckDrivenPropertyCheck
test("SPARK-56949: DecimalAggregates preserves Average.evalMode " +
"for try_avg on widened-cast peel arm") {
- val tryAvg = Average($"d7_2".cast(DecimalType(12, 2)), EvalMode.TRY)
+ // pPrime=10 keeps pPrime+4=14 <= MAX_DOUBLE_DIGITS so the AVG peel arm
+ // fires. (pPrime=12 is outside the new bound; see SPARK-56983.)
+ val tryAvg = Average($"d7_2".cast(DecimalType(10, 2)), EvalMode.TRY)
val q = widenRel.select(tryAvg.toAggregateExpression().as("ta"))
val optimized = Optimize.execute(q.analyze)
val avgs = findAverage(optimized)
- assert(avgs.nonEmpty, "widened-cast AVG peel should fire for
dec(7,2)->dec(12,2)")
+ assert(avgs.nonEmpty, "widened-cast AVG peel should fire for
dec(7,2)->dec(10,2)")
assert(avgs.forall(_.evalMode == EvalMode.TRY),
s"evalMode should be preserved as TRY after rewrite, got " +
avgs.map(_.evalMode).mkString(","))
diff --git
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/explain.txt
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/explain.txt
index ff0b0e468530..f7c0dcd7c56b 100644
---
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/explain.txt
+++
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/explain.txt
@@ -257,7 +257,7 @@ Arguments: [[cs_quantity#4, cs_list_price#5,
cs_sales_price#6, cs_coupon_amt#7,
(46) HashAggregate [codegen id : 13]
Input [12]: [cs_quantity#4, cs_list_price#5, cs_sales_price#6,
cs_coupon_amt#7, cs_net_profit#8, cd_dep_count#14, c_birth_year#22,
i_item_id#28, ca_country#29, ca_state#30, ca_county#31, spark_grouping_id#32]
Keys [5]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
spark_grouping_id#32]
-Functions [7]: [partial_avg(cast(cs_quantity#4 as decimal(12,2))),
partial_avg(UnscaledValue(cs_list_price#5)),
partial_avg(UnscaledValue(cs_coupon_amt#7)),
partial_avg(UnscaledValue(cs_sales_price#6)),
partial_avg(UnscaledValue(cs_net_profit#8)), partial_avg(cast(c_birth_year#22
as decimal(12,2))), partial_avg(cast(cd_dep_count#14 as decimal(12,2)))]
+Functions [7]: [partial_avg(cast(cs_quantity#4 as decimal(12,2))),
partial_avg(cast(cs_list_price#5 as decimal(12,2))),
partial_avg(cast(cs_coupon_amt#7 as decimal(12,2))),
partial_avg(cast(cs_sales_price#6 as decimal(12,2))),
partial_avg(cast(cs_net_profit#8 as decimal(12,2))),
partial_avg(cast(c_birth_year#22 as decimal(12,2))),
partial_avg(cast(cd_dep_count#14 as decimal(12,2)))]
Aggregate Attributes [14]: [sum#33, count#34, sum#35, count#36, sum#37,
count#38, sum#39, count#40, sum#41, count#42, sum#43, count#44, sum#45,
count#46]
Results [19]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
spark_grouping_id#32, sum#47, count#48, sum#49, count#50, sum#51, count#52,
sum#53, count#54, sum#55, count#56, sum#57, count#58, sum#59, count#60]
@@ -268,9 +268,9 @@ Arguments: hashpartitioning(i_item_id#28, ca_country#29,
ca_state#30, ca_county#
(48) HashAggregate [codegen id : 14]
Input [19]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
spark_grouping_id#32, sum#47, count#48, sum#49, count#50, sum#51, count#52,
sum#53, count#54, sum#55, count#56, sum#57, count#58, sum#59, count#60]
Keys [5]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
spark_grouping_id#32]
-Functions [7]: [avg(cast(cs_quantity#4 as decimal(12,2))),
avg(UnscaledValue(cs_list_price#5)), avg(UnscaledValue(cs_coupon_amt#7)),
avg(UnscaledValue(cs_sales_price#6)), avg(UnscaledValue(cs_net_profit#8)),
avg(cast(c_birth_year#22 as decimal(12,2))), avg(cast(cd_dep_count#14 as
decimal(12,2)))]
-Aggregate Attributes [7]: [avg(cast(cs_quantity#4 as decimal(12,2)))#61,
avg(UnscaledValue(cs_list_price#5))#62, avg(UnscaledValue(cs_coupon_amt#7))#63,
avg(UnscaledValue(cs_sales_price#6))#64,
avg(UnscaledValue(cs_net_profit#8))#65, avg(cast(c_birth_year#22 as
decimal(12,2)))#66, avg(cast(cd_dep_count#14 as decimal(12,2)))#67]
-Results [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
avg(cast(cs_quantity#4 as decimal(12,2)))#61 AS agg1#68,
cast((avg(UnscaledValue(cs_list_price#5))#62 / 100.0) as decimal(16,6)) AS
agg2#69, cast((avg(UnscaledValue(cs_coupon_amt#7))#63 / 100.0) as
decimal(16,6)) AS agg3#70, cast((avg(UnscaledValue(cs_sales_price#6))#64 /
100.0) as decimal(16,6)) AS agg4#71,
cast((avg(UnscaledValue(cs_net_profit#8))#65 / 100.0) as decimal(16,6)) AS
agg5#72, avg(cast(c_birth_year#22 as [...]
+Functions [7]: [avg(cast(cs_quantity#4 as decimal(12,2))),
avg(cast(cs_list_price#5 as decimal(12,2))), avg(cast(cs_coupon_amt#7 as
decimal(12,2))), avg(cast(cs_sales_price#6 as decimal(12,2))),
avg(cast(cs_net_profit#8 as decimal(12,2))), avg(cast(c_birth_year#22 as
decimal(12,2))), avg(cast(cd_dep_count#14 as decimal(12,2)))]
+Aggregate Attributes [7]: [avg(cast(cs_quantity#4 as decimal(12,2)))#61,
avg(cast(cs_list_price#5 as decimal(12,2)))#62, avg(cast(cs_coupon_amt#7 as
decimal(12,2)))#63, avg(cast(cs_sales_price#6 as decimal(12,2)))#64,
avg(cast(cs_net_profit#8 as decimal(12,2)))#65, avg(cast(c_birth_year#22 as
decimal(12,2)))#66, avg(cast(cd_dep_count#14 as decimal(12,2)))#67]
+Results [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
avg(cast(cs_quantity#4 as decimal(12,2)))#61 AS agg1#68,
avg(cast(cs_list_price#5 as decimal(12,2)))#62 AS agg2#69,
avg(cast(cs_coupon_amt#7 as decimal(12,2)))#63 AS agg3#70,
avg(cast(cs_sales_price#6 as decimal(12,2)))#64 AS agg4#71,
avg(cast(cs_net_profit#8 as decimal(12,2)))#65 AS agg5#72,
avg(cast(c_birth_year#22 as decimal(12,2)))#66 AS agg6#73,
avg(cast(cd_dep_count#14 as decimal(12,2)))#67 AS agg7#74]
(49) TakeOrderedAndProject
Input [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, agg1#68,
agg2#69, agg3#70, agg4#71, agg5#72, agg6#73, agg7#74]
diff --git
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/simplified.txt
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/simplified.txt
index 079bb6aba3ec..276165729be5 100644
---
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/simplified.txt
+++
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/simplified.txt
@@ -1,6 +1,6 @@
TakeOrderedAndProject
[ca_country,ca_state,ca_county,i_item_id,agg1,agg2,agg3,agg4,agg5,agg6,agg7]
WholeStageCodegen (14)
- HashAggregate
[i_item_id,ca_country,ca_state,ca_county,spark_grouping_id,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count]
[avg(cast(cs_quantity as
decimal(12,2))),avg(UnscaledValue(cs_list_price)),avg(UnscaledValue(cs_coupon_amt)),avg(UnscaledValue(cs_sales_price)),avg(UnscaledValue(cs_net_profit)),avg(cast(c_birth_year
as decimal(12,2))),avg(cast(cd_dep_count as
decimal(12,2))),agg1,agg2,agg3,agg4,agg5,agg6,agg7,sum,count,sum,count,sum,count,sum,count,sum,count
[...]
+ HashAggregate
[i_item_id,ca_country,ca_state,ca_county,spark_grouping_id,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count]
[avg(cast(cs_quantity as decimal(12,2))),avg(cast(cs_list_price as
decimal(12,2))),avg(cast(cs_coupon_amt as
decimal(12,2))),avg(cast(cs_sales_price as
decimal(12,2))),avg(cast(cs_net_profit as decimal(12,2))),avg(cast(c_birth_year
as decimal(12,2))),avg(cast(cd_dep_count as
decimal(12,2))),agg1,agg2,agg3,agg4,agg5,agg6,agg7,sum,count,sum,cou [...]
InputAdapter
Exchange [i_item_id,ca_country,ca_state,ca_county,spark_grouping_id] #1
WholeStageCodegen (13)
diff --git
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/explain.txt
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/explain.txt
index 8f25c83767ff..7db1c87c52a6 100644
---
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/explain.txt
+++
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/explain.txt
@@ -227,7 +227,7 @@ Arguments: [[cs_quantity#4, cs_list_price#5,
cs_sales_price#6, cs_coupon_amt#7,
(40) HashAggregate [codegen id : 7]
Input [12]: [cs_quantity#4, cs_list_price#5, cs_sales_price#6,
cs_coupon_amt#7, cs_net_profit#8, cd_dep_count#14, c_birth_year#19,
i_item_id#28, ca_country#29, ca_state#30, ca_county#31, spark_grouping_id#32]
Keys [5]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
spark_grouping_id#32]
-Functions [7]: [partial_avg(cast(cs_quantity#4 as decimal(12,2))),
partial_avg(UnscaledValue(cs_list_price#5)),
partial_avg(UnscaledValue(cs_coupon_amt#7)),
partial_avg(UnscaledValue(cs_sales_price#6)),
partial_avg(UnscaledValue(cs_net_profit#8)), partial_avg(cast(c_birth_year#19
as decimal(12,2))), partial_avg(cast(cd_dep_count#14 as decimal(12,2)))]
+Functions [7]: [partial_avg(cast(cs_quantity#4 as decimal(12,2))),
partial_avg(cast(cs_list_price#5 as decimal(12,2))),
partial_avg(cast(cs_coupon_amt#7 as decimal(12,2))),
partial_avg(cast(cs_sales_price#6 as decimal(12,2))),
partial_avg(cast(cs_net_profit#8 as decimal(12,2))),
partial_avg(cast(c_birth_year#19 as decimal(12,2))),
partial_avg(cast(cd_dep_count#14 as decimal(12,2)))]
Aggregate Attributes [14]: [sum#33, count#34, sum#35, count#36, sum#37,
count#38, sum#39, count#40, sum#41, count#42, sum#43, count#44, sum#45,
count#46]
Results [19]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
spark_grouping_id#32, sum#47, count#48, sum#49, count#50, sum#51, count#52,
sum#53, count#54, sum#55, count#56, sum#57, count#58, sum#59, count#60]
@@ -238,9 +238,9 @@ Arguments: hashpartitioning(i_item_id#28, ca_country#29,
ca_state#30, ca_county#
(42) HashAggregate [codegen id : 8]
Input [19]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
spark_grouping_id#32, sum#47, count#48, sum#49, count#50, sum#51, count#52,
sum#53, count#54, sum#55, count#56, sum#57, count#58, sum#59, count#60]
Keys [5]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
spark_grouping_id#32]
-Functions [7]: [avg(cast(cs_quantity#4 as decimal(12,2))),
avg(UnscaledValue(cs_list_price#5)), avg(UnscaledValue(cs_coupon_amt#7)),
avg(UnscaledValue(cs_sales_price#6)), avg(UnscaledValue(cs_net_profit#8)),
avg(cast(c_birth_year#19 as decimal(12,2))), avg(cast(cd_dep_count#14 as
decimal(12,2)))]
-Aggregate Attributes [7]: [avg(cast(cs_quantity#4 as decimal(12,2)))#61,
avg(UnscaledValue(cs_list_price#5))#62, avg(UnscaledValue(cs_coupon_amt#7))#63,
avg(UnscaledValue(cs_sales_price#6))#64,
avg(UnscaledValue(cs_net_profit#8))#65, avg(cast(c_birth_year#19 as
decimal(12,2)))#66, avg(cast(cd_dep_count#14 as decimal(12,2)))#67]
-Results [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
avg(cast(cs_quantity#4 as decimal(12,2)))#61 AS agg1#68,
cast((avg(UnscaledValue(cs_list_price#5))#62 / 100.0) as decimal(16,6)) AS
agg2#69, cast((avg(UnscaledValue(cs_coupon_amt#7))#63 / 100.0) as
decimal(16,6)) AS agg3#70, cast((avg(UnscaledValue(cs_sales_price#6))#64 /
100.0) as decimal(16,6)) AS agg4#71,
cast((avg(UnscaledValue(cs_net_profit#8))#65 / 100.0) as decimal(16,6)) AS
agg5#72, avg(cast(c_birth_year#19 as [...]
+Functions [7]: [avg(cast(cs_quantity#4 as decimal(12,2))),
avg(cast(cs_list_price#5 as decimal(12,2))), avg(cast(cs_coupon_amt#7 as
decimal(12,2))), avg(cast(cs_sales_price#6 as decimal(12,2))),
avg(cast(cs_net_profit#8 as decimal(12,2))), avg(cast(c_birth_year#19 as
decimal(12,2))), avg(cast(cd_dep_count#14 as decimal(12,2)))]
+Aggregate Attributes [7]: [avg(cast(cs_quantity#4 as decimal(12,2)))#61,
avg(cast(cs_list_price#5 as decimal(12,2)))#62, avg(cast(cs_coupon_amt#7 as
decimal(12,2)))#63, avg(cast(cs_sales_price#6 as decimal(12,2)))#64,
avg(cast(cs_net_profit#8 as decimal(12,2)))#65, avg(cast(c_birth_year#19 as
decimal(12,2)))#66, avg(cast(cd_dep_count#14 as decimal(12,2)))#67]
+Results [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31,
avg(cast(cs_quantity#4 as decimal(12,2)))#61 AS agg1#68,
avg(cast(cs_list_price#5 as decimal(12,2)))#62 AS agg2#69,
avg(cast(cs_coupon_amt#7 as decimal(12,2)))#63 AS agg3#70,
avg(cast(cs_sales_price#6 as decimal(12,2)))#64 AS agg4#71,
avg(cast(cs_net_profit#8 as decimal(12,2)))#65 AS agg5#72,
avg(cast(c_birth_year#19 as decimal(12,2)))#66 AS agg6#73,
avg(cast(cd_dep_count#14 as decimal(12,2)))#67 AS agg7#74]
(43) TakeOrderedAndProject
Input [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, agg1#68,
agg2#69, agg3#70, agg4#71, agg5#72, agg6#73, agg7#74]
diff --git
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/simplified.txt
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/simplified.txt
index 7c3075e26fa2..269bfd3f44fc 100644
---
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/simplified.txt
+++
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/simplified.txt
@@ -1,6 +1,6 @@
TakeOrderedAndProject
[ca_country,ca_state,ca_county,i_item_id,agg1,agg2,agg3,agg4,agg5,agg6,agg7]
WholeStageCodegen (8)
- HashAggregate
[i_item_id,ca_country,ca_state,ca_county,spark_grouping_id,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count]
[avg(cast(cs_quantity as
decimal(12,2))),avg(UnscaledValue(cs_list_price)),avg(UnscaledValue(cs_coupon_amt)),avg(UnscaledValue(cs_sales_price)),avg(UnscaledValue(cs_net_profit)),avg(cast(c_birth_year
as decimal(12,2))),avg(cast(cd_dep_count as
decimal(12,2))),agg1,agg2,agg3,agg4,agg5,agg6,agg7,sum,count,sum,count,sum,count,sum,count,sum,count
[...]
+ HashAggregate
[i_item_id,ca_country,ca_state,ca_county,spark_grouping_id,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count]
[avg(cast(cs_quantity as decimal(12,2))),avg(cast(cs_list_price as
decimal(12,2))),avg(cast(cs_coupon_amt as
decimal(12,2))),avg(cast(cs_sales_price as
decimal(12,2))),avg(cast(cs_net_profit as decimal(12,2))),avg(cast(c_birth_year
as decimal(12,2))),avg(cast(cd_dep_count as
decimal(12,2))),agg1,agg2,agg3,agg4,agg5,agg6,agg7,sum,count,sum,cou [...]
InputAdapter
Exchange [i_item_id,ca_country,ca_state,ca_county,spark_grouping_id] #1
WholeStageCodegen (7)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 180dd5d5db94..694b087d6c3f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -4822,11 +4822,10 @@ class DataFrameAggregateSuite extends SharedSparkSession
// Numerical-equivalence property (sql-core layer).
//
- // Sweeps the (p, p', s, n) lattice where the widened-cast peel fires,
- // asserting that SUM(CAST(x AS DECIMAL(p', s))) on an on-vs-off SQLConf
- // pair returns bit-equal java.math.BigDecimal (same unscaled value AND
- // same scale). Domain is restricted to the non-overflow regime so the
- // peeled LONG accumulator cannot wrap.
+ // Sweeps the (p, p', s, n) lattice where the widened-cast SUM peel fires
+ // and asserts that the optimized result matches an external java.math.BigDecimal
+ // reference computed in pure Scala. Domain is restricted to the non-overflow
+ // regime so the peeled LONG accumulator cannot wrap.
//
// Non-overflow bound: with |unscaled(x)| < 10^p, p <= 8, n <= 1000,
// worst-case accumulator is 1000 * (10^8 - 1) < 10^12 << 2^63.
@@ -4913,29 +4912,26 @@ class DataFrameAggregateSuite extends SharedSparkSession
// identical to the existing fast path on AVG(x) directly. Both arms in
// Optimizer.DecimalAggregates produce
// Cast(Divide(Avg(UnscaledValue(<inner>)), Lit(10^s, Double)),
- // DecimalType.bounded(<outerP>, s + 4))
+ // DecimalType(<outerP>, s + 4))
// and the peel arm makes <inner> equal to the user's column, so the
// Double-divide dividends are bit-identical between the two paths; only
// the outer Cast target precision differs (pPrime+4 vs p+4), a widening
// precision Cast that preserves numerical value. We therefore assert
// BigDecimal.compareTo == 0 (value equality across differing precisions).
//
- // Domain: inner p in [1, 7] (the AVG strict-subset guard
- // `AVG_PEEL_MAX_INNER_PRECISION = 7`), pPrime in [8, 11] (the band where
- // the existing `Average(DecimalExpression)` arm would intercept on the
- // outer Cast type if not for our prepended arm), s in [0, p],
- // n <= 1000 rows. The inner DataFrame schema is constructed as
- // DecimalType(p, s) explicitly (NOT via tuple-inference, which would
- // infer DecimalType.SYSTEM_DEFAULT and silently route through a DIFFERENT
- // rule arm than intended -- the failure mode this PBT must lock down).
+ // Domain: pPrime in [p+1, 11] -- the band where pPrime + 4 <= MAX_DOUBLE_DIGITS
+ // so the new arm fires and the existing un-widened arm would also have
+ // matched the outer Cast (allowing comparison against AVG(x) as oracle).
+ // The inner DataFrame schema is constructed as DecimalType(p, s) explicitly
+ // (NOT via tuple-inference, which would infer DecimalType.SYSTEM_DEFAULT and
+ // silently route through a DIFFERENT rule arm than intended).
private case class AvgDomain(p: Int, pPrime: Int, s: Int)
private val avgDomainGen: Gen[AvgDomain] = (for {
- p <- Gen.choose(1, 7)
- pPrime <- Gen.choose(8, 11)
+ p <- Gen.choose(1, 10)
+ pPrime <- Gen.choose(p + 1, 11)
s <- Gen.choose(0, p)
} yield AvgDomain(p, pPrime, s))
- .retryUntil(d => d.p < d.pPrime)
private def avgInputDf(unscaledLongs: Seq[Long], d: AvgDomain) = {
val rows = unscaledLongs.map(u => Row(java.math.BigDecimal.valueOf(u, d.s)))
@@ -4972,7 +4968,7 @@ class DataFrameAggregateSuite extends SharedSparkSession
val peeled = avgCastResult(xs, d)
val direct = avgDirectResult(xs, d)
// BigDecimal.compareTo ignores trailing-zero precision differences:
- // peeled has output DecimalType.bounded(pPrime+4, s+4), direct has
+ // peeled has output DecimalType(pPrime+4, s+4), direct has
// DecimalType(p+4, s+4). Both wrap the same Double-divide bit pattern
// so the underlying value is identical.
assert(peeled.compareTo(direct) == 0,
@@ -4982,17 +4978,13 @@ class DataFrameAggregateSuite extends SharedSparkSession
}
}
- // Wider-pPrime regime shape witness: (p=4, p'=20, s=2). The equivalence
- // PBT above only covers pPrime in [8, 11] (where the existing AVG arm
- // would otherwise intercept and provide a comparable oracle). For pPrime
- // outside that band the new arm still fires (only constrained by inner
- // p <= 7), but the comparison oracle "AVG(x) directly" is no longer
- // available because the existing arm targets a narrower output type.
- // This witness asserts non-null result and the expected widened output
- // schema, locking the rule's shape contract without claiming an
- // unreachable oracle.
- test("SPARK-56627: AVG(CAST(dec(4,2) AS dec(20,2))) peels and yields " +
- "widened output schema (wider-pPrime regime shape witness)") {
+ // Wider-pPrime regime: when pPrime + 4 > MAX_DOUBLE_DIGITS the AVG peel arm
+ // is intentionally NOT fired so the un-rewritten Decimal-exact path is
+ // preserved. Witness: (p=4, p'=20, s=2) -- pPrime + 4 = 24 > 15. Asserts
+ // non-null result and the expected widened output schema; rule shape is
+ // covered by the catalyst-layer suite.
+ test("SPARK-56627: AVG(CAST(dec(4,2) AS dec(20,2))) yields " +
+ "widened output schema (Decimal-exact path preserved)") {
val rows = Seq(123L, -456L, 789L, 0L)
.map(u => Row(java.math.BigDecimal.valueOf(u, 2)))
val schema = StructType(StructField("x", DecimalType(4, 2)) :: Nil)
@@ -5001,10 +4993,9 @@ class DataFrameAggregateSuite extends SharedSparkSession
val row = df.collect()(0)
assert(!row.isNullAt(0), s"expected non-null AVG, got null; df schema = ${df.schema}")
val outType = df.schema("a").dataType.asInstanceOf[DecimalType]
- // Widened-arm output Cast target = DecimalType.bounded(pPrime + 4, s + 4)
- // = DecimalType.bounded(24, 6).
+ // Un-rewritten Average.dataType = bounded(pPrime + 4, s + 4) = (24, 6).
assert(outType.precision == 24 && outType.scale == 6,
- s"expected DecimalType(24, 6) from widened-arm peel, got $outType")
+ s"expected DecimalType(24, 6) from un-rewritten AVG, got $outType")
}
}
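The test comments above rely on `BigDecimal.compareTo == 0` to compare results whose declared precision and scale differ. A minimal sketch of that distinction, using only `java.math.BigDecimal`; the object and value names are hypothetical and not part of the suite:

```scala
import java.math.BigDecimal

// Hypothetical illustration; not code from DataFrameAggregateSuite.
object CompareToSketch {
  // Same numeric value, different scales (2 vs 6 fractional digits).
  val narrow: BigDecimal = new BigDecimal("1.50")
  val wide: BigDecimal = new BigDecimal("1.500000")

  // compareTo looks at the numeric value only, so these compare as equal.
  val sameValue: Boolean = narrow.compareTo(wide) == 0

  // equals also compares scale, so it reports a difference here.
  val sameRepresentation: Boolean = narrow.equals(wide)
}
```

This is why an `equals`-based assertion would be too strict for the peeled-versus-direct AVG comparison, where only the output type differs.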
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DecimalAggregatesBenchmark.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DecimalAggregatesBenchmark.scala
index fc00bea62dd1..e006787dbfa1 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DecimalAggregatesBenchmark.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DecimalAggregatesBenchmark.scala
@@ -46,9 +46,11 @@ import org.apache.spark.sql.types.Decimal
* PBT in `DataFrameAggregateSuite`).
*
* Sections:
- * A -- Aggregate SUM widened-cast sweep over (p, s, p') cases.
- * B -- Aggregate AVG widened-cast sweep (p <= 7 per
- * AVG_PEEL_MAX_INNER_PRECISION).
+ * A -- Aggregate SUM widened-cast sweep (`p + 10 <= MAX_LONG_DIGITS`,
+ * any `pPrime > p` up to 38).
+ * B -- Aggregate AVG widened-cast sweep (`pPrime + 4 <= MAX_DOUBLE_DIGITS`
+ * so the rule fires only inside the existing AVG Double-regime
+ * envelope; wider casts stay on the Decimal-exact path).
*
* NOTE on Window arm: the optimizer does not extend widened-Cast peel to
* the Window arm (see DecimalAggregates rule comment) because the analyzer
@@ -56,12 +58,13 @@ import org.apache.spark.sql.types.Decimal
* not exercise this rule. A Window benchmark belongs with a future
* plan-layer rewrite, not here.
*
- * Case design (`p+1` boundary vs `p+10`-class wider) deliberately includes
- * both the minimum widening (one extra digit, e.g. `dec(7,2) -> dec(8,2)`)
- * and a production-canonical wider one (e.g. `dec(7,2) -> dec(17,2)` is the
- * inner-widened-decimal shape in TPC-DS q18) so reviewers see whether
- * widening magnitude matters and whether the canonical shape behaves like
- * the boundary one.
+ * Case design:
+ * - Section A pairs a `p+1` boundary widening with a `p+10`-class wider
+ * cast (A2 mirrors the TPC-DS q18 inner-widened-decimal shape), so
+ * reviewers see whether widening magnitude matters.
+ * - Section B pairs a `p+1` boundary widening with the `pPrime <= 11`
+ * upper bound, the widest cast the AVG arm will accept under the
+ * semantics-preserving guard.
*
* Args: aN (Section A/B row count), iters, apl
* (`spark.sql.decimalOperations.allowPrecisionLoss`; default true).
@@ -106,17 +109,16 @@ object DecimalAggregatesBenchmark extends SqlBasedBenchmark {
/**
* Aggregate AVG cases: (label, p, s, widened p').
*
- * All `p <= 7` per the conservative `AVG_PEEL_MAX_INNER_PRECISION = 7`
- * guard (see design doc 0001 rev 7 S7.1 -- strict-subset narrowing so
- * SPARK-37024 Double-regime exposure is NOT amplified by this rule).
- * Same `p+1` / `p+10` split as Section A. B2 mirrors the canonical
- * TPC-DS q18 AVG shape.
+ * All `pPrime + 4 <= MAX_DOUBLE_DIGITS = 15`, i.e. `pPrime <= 11` -- the
+ * AVG peel arm only fires inside the existing Double-regime envelope, so
+ * the un-rewritten Decimal-exact path is preserved for wider casts (see
+ * SPARK-56983).
*/
private val AvgAggCases: Seq[(String, Int, Int, Int)] = Seq(
("B1 p=7 s=2 p'=8", 7, 2, 8), // p+1 boundary
- ("B2 p=7 s=2 p'=12", 7, 2, 12), // canonical TPC-DS q18 AVG shape
+ ("B2 p=7 s=2 p'=11", 7, 2, 11), // pPrime upper bound
("B3 p=5 s=0 p'=6", 5, 0, 6), // p+1 boundary, zero scale
- ("B4 p=5 s=0 p'=15", 5, 0, 15) // p+10, zero scale
+ ("B4 p=5 s=0 p'=11", 5, 0, 11) // pPrime upper bound, zero scale
)
/** Clamp generator to `10^(p-s) - 1` so rand() * bound fits `DECIMAL(p, s)`. */
@@ -195,8 +197,9 @@ object DecimalAggregatesBenchmark extends SqlBasedBenchmark {
runBenchmark("DecimalAggregates AVG widened-cast peel (Aggregate)") {
AvgAggCases.foreach { case (label, p, s, pPrime) =>
require(pPrime > p, s"$label: p'=$pPrime must exceed p=$p")
- require(p <= 7,
- s"$label: p=$p violates conservative AVG_PEEL_MAX_INNER_PRECISION=7 guard")
+ require(pPrime + 4 <= 15,
+ s"$label: p'=$pPrime violates AVG fast-path guard " +
+ s"pPrime+4<=MAX_DOUBLE_DIGITS=15; rule would not fire")
setupAggTable(spark, aN, p, s)
runThreeWay(label, aN,
nativeSql = "select avg(x) from t",
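The two `require` checks above amount to a small piece of arithmetic: the cast must widen, and the AVG result precision `pPrime + 4` must not exceed 15, which is the same as `pPrime <= 11`. A minimal sketch of that check, assuming the value `MAX_DOUBLE_DIGITS = 15` quoted in the benchmark comment; `avgPeelFires` is a hypothetical helper that mirrors the benchmark's preconditions, not the optimizer rule itself:

```scala
// Hypothetical illustration; not code from DecimalAggregatesBenchmark.
object AvgPeelGuardSketch {
  // Value quoted in the benchmark comment above.
  val MaxDoubleDigits = 15

  // Mirrors the two require(...) preconditions: the cast widens (pPrime > p)
  // and the AVG result precision pPrime + 4 stays within the Double envelope.
  def avgPeelFires(p: Int, pPrime: Int): Boolean =
    pPrime > p && pPrime + 4 <= MaxDoubleDigits
}
```

Under this check the new B2 case (`p=7, p'=11`) passes, while the replaced B2 case (`p'=12`) and the `(p=4, p'=20)` witness from the test suite do not.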