andygrove commented on code in PR #4749:
URL: https://github.com/apache/datafusion-comet/pull/4749#discussion_r3493123450
##########
spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala:
##########
@@ -80,18 +79,72 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] {
// operator itself carries a fallback attribution. Without this, the plan
// prints a bare `Window` and the real reason lives on a sub-expression
// that isn't obvious in the standard explain output.
- val failing = winExprs.toSeq.zip(windowExprProto).collect { case (we, None) => we } ++
+ val failing = winExprs.toSeq.zip(windowExprProto).collect { case (we, None) =>
+ we.windowExpression
+ } ++
op.partitionSpec.zip(partitionExprs).collect { case (e, None) => e } ++
op.orderSpec.zip(sortOrders).collect { case (e, None) => e }
withFallbackReason(op, failing: _*)
None
}
}
+ private case class WindowExpressionInfo(
+ windowExpression: WindowExpression,
+ resultDataType: DataType)
+
+ private def extractWindowExpression(expr: Expression): Option[WindowExpressionInfo] = {
+ expr match {
+ case Alias(child, _) =>
+ extractWindowExpression(child)
+ case w: WindowExpression =>
+ Some(WindowExpressionInfo(w, w.dataType))
+ case m @ MakeDecimal(child, _, _, _) =>
+ extractWindowExpression(child).map { info =>
+ info.copy(
+ windowExpression = restoreDecimalAggregateInput(info.windowExpression),
+ resultDataType = m.dataType)
+ }
+ case c @ Cast(Divide(child, _, _), _: DecimalType, _, _) =>
+ extractWindowExpression(child).map { info =>
+ info.copy(
+ windowExpression = restoreDecimalAggregateInput(info.windowExpression),
+ resultDataType = c.dataType)
+ }
Review Comment:
This match extracts the window expression from the numerator and discards
the divisor. That substitution is only correct for the exact shape
`DecimalAggregates` emits: divisor `Literal(10^scale, DoubleType)` and inner
`Average(UnscaledValue(decimal))`. For any other `Cast(Divide(windowExpr, X),
decimal)` the divisor would be silently dropped and the result would be wrong.
I believe it is safe today, because `DecimalAggregates` is the only rule
that injects a `Cast`/`Divide` directly inside a `WindowExec.windowExpression`
(it runs in the optimizer, after `ExtractWindowExpressions` has pulled window
expressions into the operator). A user-written `CAST(SUM(x) OVER w / 2 AS
DECIMAL)` keeps its cast and divide in the `Project` above the window, so it
never reaches here.
Could we tighten the guard to verify that the divisor is the expected `10^scale`
literal and that the inner aggregate is `Average(UnscaledValue(_))`, and fall back otherwise?
That way an unexpected shape degrades to a Spark fallback instead of a silently
dropped divide. A short comment naming `DecimalAggregates` as the source of
this shape would also help the next reader.
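
To make the requested guard concrete, here is a minimal, self-contained sketch. It assumes stand-in case classes that only mirror the shapes involved; they are not Spark's real `Cast`, `Divide`, `Literal`, or `Average` classes, and `DecimalAvgGuard.unwrap` is a hypothetical name. The expected cast type `decimal(p+4, s+4)` comes from the AVG rewrite described in this review.

```scala
// Stand-in expression tree. These are NOT Spark's classes; only the shapes matter.
sealed trait Expr
final case class DecimalCol(name: String, precision: Int, scale: Int) extends Expr
final case class UnscaledValue(child: Expr) extends Expr
final case class Average(child: Expr) extends Expr
final case class WindowExpr(function: Expr) extends Expr
final case class DoubleLit(value: Double) extends Expr
final case class Divide(left: Expr, right: Expr) extends Expr
final case class CastToDecimal(child: Expr, precision: Int, scale: Int) extends Expr

object DecimalAvgGuard {
  // Unwraps the window expression only when the whole shape matches what the
  // DecimalAggregates AVG rewrite is described as emitting:
  //   Cast(Divide(window(Average(UnscaledValue(d))), 10^s as double), decimal(p+4, s+4))
  // Any other divisor, inner aggregate, or cast type returns None (fall back).
  def unwrap(expr: Expr): Option[WindowExpr] = expr match {
    case CastToDecimal(
          Divide(w @ WindowExpr(Average(UnscaledValue(DecimalCol(_, p, s)))), DoubleLit(d)),
          castPrecision,
          castScale)
        if d == math.pow(10.0, s) && castPrecision == p + 4 && castScale == s + 4 =>
      Some(w)
    case _ =>
      None
  }
}
```

With this shape check, a divisor such as `2.0` or a non-`Average` inner aggregate yields `None`, so the caller can record a fallback reason instead of dropping the divide.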
##########
spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala:
##########
@@ -80,18 +79,72 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] {
// operator itself carries a fallback attribution. Without this, the plan
// prints a bare `Window` and the real reason lives on a sub-expression
// that isn't obvious in the standard explain output.
- val failing = winExprs.toSeq.zip(windowExprProto).collect { case (we, None) => we } ++
+ val failing = winExprs.toSeq.zip(windowExprProto).collect { case (we, None) =>
+ we.windowExpression
+ } ++
op.partitionSpec.zip(partitionExprs).collect { case (e, None) => e } ++
op.orderSpec.zip(sortOrders).collect { case (e, None) => e }
withFallbackReason(op, failing: _*)
None
}
}
+ private case class WindowExpressionInfo(
+ windowExpression: WindowExpression,
+ resultDataType: DataType)
+
+ private def extractWindowExpression(expr: Expression): Option[WindowExpressionInfo] = {
+ expr match {
+ case Alias(child, _) =>
+ extractWindowExpression(child)
+ case w: WindowExpression =>
+ Some(WindowExpressionInfo(w, w.dataType))
+ case m @ MakeDecimal(child, _, _, _) =>
+ extractWindowExpression(child).map { info =>
+ info.copy(
+ windowExpression = restoreDecimalAggregateInput(info.windowExpression),
+ resultDataType = m.dataType)
+ }
+ case c @ Cast(Divide(child, _, _), _: DecimalType, _, _) =>
+ extractWindowExpression(child).map { info =>
+ info.copy(
+ windowExpression = restoreDecimalAggregateInput(info.windowExpression),
+ resultDataType = c.dataType)
+ }
+ case _ =>
+ None
+ }
+ }
+
+ // Spark wraps decimal SUM / AVG window aggregates around UnscaledValue plus
+ // rescaling arithmetic. Comet's native decimal aggregates expect the original
+ // decimal child, so restore that child when unwrapping those Spark wrappers.
+ private def restoreDecimalAggregateInput(windowExpr: WindowExpression): WindowExpression = {
+ windowExpr
+ .transform {
+ case agg @ AggregateExpression(avg: Average, _, _, _, _) =>
+ avg.child match {
+ case UnscaledValue(child) if child.dataType.isInstanceOf[DecimalType] =>
+ agg.copy(aggregateFunction = avg.copy(child = child))
+ case _ =>
+ agg
+ }
+ case agg @ AggregateExpression(sum: Sum, _, _, _, _) =>
+ sum.child match {
+ case UnscaledValue(child) if child.dataType.isInstanceOf[DecimalType] =>
+ agg.copy(aggregateFunction = sum.copy(child = child))
+ case _ =>
+ agg
+ }
Review Comment:
This `Sum` branch and the `MakeDecimal` result-type handling are a behavior
change from the old `case Alias(MakeDecimal(w: WindowExpression, _, _, _), _)
=> w`, which extracted the bare window expression and used
`windowExpr.dataType`. The new test only covers AVG, so this branch is
unexercised. Could we add a decimal SUM over a window with small enough
precision to trigger the `MakeDecimal` rewrite (`prec + 10 <= 18`)? That
exercises the other half of the new extraction logic.
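
For picking the column type in such a test, the two precision conditions quoted in this review (`prec + 10 <= 18` for SUM and `prec + 4 <= 15` for AVG) can be written down as a small helper. This is only a sketch: the object, constant, and method names are hypothetical, and the limits are taken from the conditions above rather than read from Spark at runtime.

```scala
object DecimalRewriteThresholds {
  // Limits as quoted in the review conditions; names here are illustrative.
  val MaxLongDigits = 18
  val MaxDoubleDigits = 15

  // True when a decimal SUM over a window would get the MakeDecimal rewrite.
  def sumUsesMakeDecimal(precision: Int): Boolean =
    precision + 10 <= MaxLongDigits

  // True when a decimal AVG over a window would get the Cast/Divide (double) rewrite.
  def avgUsesDoublePath(precision: Int): Boolean =
    precision + 4 <= MaxDoubleDigits
}
```

Under these conditions a `DECIMAL(8, 2)` column should trigger the `MakeDecimal` path for SUM while `DECIMAL(9, 2)` should not, so a test that uses both precisions would cover the rewritten and the unrewritten plan.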
##########
spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala:
##########
@@ -80,18 +79,72 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] {
// operator itself carries a fallback attribution. Without this, the plan
// prints a bare `Window` and the real reason lives on a sub-expression
// that isn't obvious in the standard explain output.
- val failing = winExprs.toSeq.zip(windowExprProto).collect { case (we, None) => we } ++
+ val failing = winExprs.toSeq.zip(windowExprProto).collect { case (we, None) =>
+ we.windowExpression
+ } ++
op.partitionSpec.zip(partitionExprs).collect { case (e, None) => e } ++
op.orderSpec.zip(sortOrders).collect { case (e, None) => e }
withFallbackReason(op, failing: _*)
None
}
}
+ private case class WindowExpressionInfo(
+ windowExpression: WindowExpression,
+ resultDataType: DataType)
+
+ private def extractWindowExpression(expr: Expression): Option[WindowExpressionInfo] = {
+ expr match {
+ case Alias(child, _) =>
+ extractWindowExpression(child)
+ case w: WindowExpression =>
+ Some(WindowExpressionInfo(w, w.dataType))
+ case m @ MakeDecimal(child, _, _, _) =>
+ extractWindowExpression(child).map { info =>
+ info.copy(
+ windowExpression = restoreDecimalAggregateInput(info.windowExpression),
+ resultDataType = m.dataType)
+ }
+ case c @ Cast(Divide(child, _, _), _: DecimalType, _, _) =>
+ extractWindowExpression(child).map { info =>
+ info.copy(
+ windowExpression = restoreDecimalAggregateInput(info.windowExpression),
+ resultDataType = c.dataType)
+ }
+ case _ =>
+ None
+ }
+ }
+
+ // Spark wraps decimal SUM / AVG window aggregates around UnscaledValue plus
+ // rescaling arithmetic. Comet's native decimal aggregates expect the original
+ // decimal child, so restore that child when unwrapping those Spark wrappers.
+ private def restoreDecimalAggregateInput(windowExpr: WindowExpression): WindowExpression = {
+ windowExpr
+ .transform {
+ case agg @ AggregateExpression(avg: Average, _, _, _, _) =>
+ avg.child match {
+ case UnscaledValue(child) if child.dataType.isInstanceOf[DecimalType] =>
+ agg.copy(aggregateFunction = avg.copy(child = child))
Review Comment:
Restoring the decimal child here makes Comet run its precise `avg_decimal`
UDAF (decimal sum at `decimal(p+10, s)`, decimal division with HALF_UP). But
the plan that `DecimalAggregates` produced computes the average in `Double`:
`Average(UnscaledValue(e))` is a double average, divided by `10^scale` as a
double, then cast to `decimal(p+4, s+4)`. So for the case this PR targets we
are substituting precise decimal arithmetic for Spark's floating-point
arithmetic, and the two can round differently at the `s+4` scale.
Spark only takes this double path when `prec + 4 <= 15`, so values stay in
double's exact integer range and any divergence would be rare and in the last
digit, but it is not guaranteed to be zero. For contrast, the regular
non-window decimal AVG stays faithful: the `HashAggregate` computes the double
average and a `Project` does the divide and cast, matching Spark exactly. That
asymmetry is what worries me here.
Could we add a fuzz or property test that runs many random decimal values
across a few precision and scale combinations through window AVG and compares
against Spark? The clean values in the new test all happen to round the same
way in both schemes, so they would not catch this.
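
As a rough illustration of what such a property test would compare, the sketch below models the two schemes with plain `java.math.BigDecimal`. It is an approximation under stated assumptions, not Spark's or Comet's actual code path: the double-to-decimal cast is modelled as `BigDecimal.valueOf` followed by HALF_UP rounding, and all names (`WindowAvgDivergence`, `doublePathAvg`, `decimalPathAvg`, `countMismatches`) are hypothetical. A real test would run the query through Spark and Comet and compare the results.

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode}
import scala.util.Random

object WindowAvgDivergence {
  // Model of the rewritten plan: average the unscaled longs as doubles,
  // divide by 10^scale as a double, then round to scale + 4 (HALF_UP).
  def doublePathAvg(unscaled: Seq[Long], scale: Int): JBigDecimal = {
    val avg = unscaled.map(_.toDouble).sum / unscaled.size
    JBigDecimal
      .valueOf(avg / math.pow(10.0, scale))
      .setScale(scale + 4, RoundingMode.HALF_UP)
  }

  // Model of a precise decimal average: exact decimal sum,
  // then HALF_UP division at scale + 4.
  def decimalPathAvg(unscaled: Seq[Long], scale: Int): JBigDecimal = {
    val sum = unscaled.foldLeft(JBigDecimal.ZERO) { (acc, v) =>
      acc.add(JBigDecimal.valueOf(v, scale))
    }
    sum.divide(JBigDecimal.valueOf(unscaled.size.toLong), scale + 4, RoundingMode.HALF_UP)
  }

  // Counts how many random inputs make the two models disagree.
  def countMismatches(trials: Int, precision: Int, scale: Int, seed: Long): Int = {
    val rnd = new Random(seed)
    val bound = math.pow(10.0, precision).toLong // unscaled values stay within `precision` digits
    (1 to trials).count { _ =>
      val n = 1 + rnd.nextInt(7) // small window of 1 to 7 rows
      val values = Seq.fill(n)(rnd.nextLong() % bound)
      doublePathAvg(values, scale).compareTo(decimalPathAvg(values, scale)) != 0
    }
  }
}
```

Calling `WindowAvgDivergence.countMismatches(100000, 11, 2, seed = 42L)` gives a quick feel for whether the two roundings ever disagree in this model; a non-zero count there would be a reason to prioritise the real Spark-versus-Comet fuzz test.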
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]