andygrove commented on code in PR #4749:
URL: https://github.com/apache/datafusion-comet/pull/4749#discussion_r3493123450


##########
spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala:
##########
@@ -80,18 +79,72 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] {
       // operator itself carries a fallback attribution. Without this, the plan
       // prints a bare `Window` and the real reason lives on a sub-expression
       // that isn't obvious in the standard explain output.
-      val failing = winExprs.toSeq.zip(windowExprProto).collect { case (we, None) => we } ++
+      val failing = winExprs.toSeq.zip(windowExprProto).collect { case (we, None) =>
+        we.windowExpression
+      } ++
         op.partitionSpec.zip(partitionExprs).collect { case (e, None) => e } ++
         op.orderSpec.zip(sortOrders).collect { case (e, None) => e }
       withFallbackReason(op, failing: _*)
       None
     }
   }
 
+  private case class WindowExpressionInfo(
+      windowExpression: WindowExpression,
+      resultDataType: DataType)
+
+  private def extractWindowExpression(expr: Expression): Option[WindowExpressionInfo] = {
+    expr match {
+      case Alias(child, _) =>
+        extractWindowExpression(child)
+      case w: WindowExpression =>
+        Some(WindowExpressionInfo(w, w.dataType))
+      case m @ MakeDecimal(child, _, _, _) =>
+        extractWindowExpression(child).map { info =>
+          info.copy(
+            windowExpression = restoreDecimalAggregateInput(info.windowExpression),
+            resultDataType = m.dataType)
+        }
+      case c @ Cast(Divide(child, _, _), _: DecimalType, _, _) =>
+        extractWindowExpression(child).map { info =>
+          info.copy(
+            windowExpression = restoreDecimalAggregateInput(info.windowExpression),
+            resultDataType = c.dataType)
+        }

Review Comment:
   This match extracts the window expression from the numerator and discards 
the divisor. That substitution is only correct for the exact shape 
`DecimalAggregates` emits: divisor `Literal(10^scale, DoubleType)` and inner 
`Average(UnscaledValue(decimal))`. For any other `Cast(Divide(windowExpr, X), 
decimal)` the divisor would be silently dropped and the result would be wrong.
   
   I believe it is safe today, because `DecimalAggregates` is the only rule 
that injects a `Cast`/`Divide` directly inside a `WindowExec.windowExpression` 
(it runs in the optimizer, after `ExtractWindowExpressions` has pulled window 
expressions into the operator). A user-written `CAST(SUM(x) OVER w / 2 AS 
DECIMAL)` keeps its cast and divide in the `Project` above the window, so it 
never reaches here.
   
   Could we tighten the guard to verify the divisor is the expected `10^scale` 
literal and the inner is `Average(UnscaledValue(_))`, and fall back otherwise? 
That way an unexpected shape degrades to a Spark fallback instead of a silently 
dropped divide. A short comment naming `DecimalAggregates` as the source of 
this shape would also help the next reader.
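
   Here is a sketch of the tightened guard. It uses a hypothetical, simplified expression tree: the `Expr` case classes are illustrative stand-ins, not Catalyst's classes or field layouts. It unwraps only the exact `DecimalAggregates` AVG shape. It checks that the inner call is `Average(UnscaledValue(col))` and that the divisor equals `10^scale` of that column. Anything else returns `None`, which in the real serde would mean a Spark fallback rather than a silently dropped divide.

```scala
object GuardedDecimalAvgExtraction {
  // Hypothetical stand-ins for the Catalyst nodes involved; NOT Spark's classes.
  sealed trait Expr
  final case class Column(name: String, precision: Int, scale: Int) extends Expr
  final case class UnscaledValue(child: Expr) extends Expr
  final case class Average(child: Expr) extends Expr
  final case class WindowExpr(function: Expr) extends Expr
  final case class DoubleLiteral(value: Double) extends Expr
  final case class Divide(left: Expr, right: Expr) extends Expr
  final case class CastToDecimal(child: Expr, precision: Int, scale: Int) extends Expr

  // Shape emitted by DecimalAggregates for window AVG over decimal(p, s):
  //   CastToDecimal(Divide(WindowExpr(Average(UnscaledValue(col))), 10^s), p + 4, s + 4)
  // Only that shape is unwrapped, and only when the divisor is exactly 10^s.
  def extractDecimalAvg(e: Expr): Option[WindowExpr] = e match {
    case CastToDecimal(Divide(WindowExpr(Average(UnscaledValue(col @ Column(_, _, s)))), DoubleLiteral(d)), _, _)
        if d == math.pow(10.0, s) =>
      // Restore the original decimal child for a native decimal AVG.
      Some(WindowExpr(Average(col)))
    case _ =>
      // Unexpected shape: no match, so the caller falls back to Spark
      // instead of discarding the divisor.
      None
  }
}
```

   Mapping this onto Catalyst would presumably mean matching the `Literal`'s value and `DoubleType`, plus the `AggregateExpression`/`Average` fields. Their arity differs across Spark versions, so that part is left as an assumption here.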



##########
spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala:
##########
@@ -80,18 +79,72 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] {
       // operator itself carries a fallback attribution. Without this, the plan
       // prints a bare `Window` and the real reason lives on a sub-expression
       // that isn't obvious in the standard explain output.
-      val failing = winExprs.toSeq.zip(windowExprProto).collect { case (we, None) => we } ++
+      val failing = winExprs.toSeq.zip(windowExprProto).collect { case (we, None) =>
+        we.windowExpression
+      } ++
         op.partitionSpec.zip(partitionExprs).collect { case (e, None) => e } ++
         op.orderSpec.zip(sortOrders).collect { case (e, None) => e }
       withFallbackReason(op, failing: _*)
       None
     }
   }
 
+  private case class WindowExpressionInfo(
+      windowExpression: WindowExpression,
+      resultDataType: DataType)
+
+  private def extractWindowExpression(expr: Expression): Option[WindowExpressionInfo] = {
+    expr match {
+      case Alias(child, _) =>
+        extractWindowExpression(child)
+      case w: WindowExpression =>
+        Some(WindowExpressionInfo(w, w.dataType))
+      case m @ MakeDecimal(child, _, _, _) =>
+        extractWindowExpression(child).map { info =>
+          info.copy(
+            windowExpression = restoreDecimalAggregateInput(info.windowExpression),
+            resultDataType = m.dataType)
+        }
+      case c @ Cast(Divide(child, _, _), _: DecimalType, _, _) =>
+        extractWindowExpression(child).map { info =>
+          info.copy(
+            windowExpression = restoreDecimalAggregateInput(info.windowExpression),
+            resultDataType = c.dataType)
+        }
+      case _ =>
+        None
+    }
+  }
+
+  // Spark wraps decimal SUM / AVG window aggregates around UnscaledValue plus
+  // rescaling arithmetic. Comet's native decimal aggregates expect the original
+  // decimal child, so restore that child when unwrapping those Spark wrappers.
+  private def restoreDecimalAggregateInput(windowExpr: WindowExpression): WindowExpression = {
+    windowExpr
+      .transform {
+        case agg @ AggregateExpression(avg: Average, _, _, _, _) =>
+          avg.child match {
+            case UnscaledValue(child) if child.dataType.isInstanceOf[DecimalType] =>
+              agg.copy(aggregateFunction = avg.copy(child = child))
+            case _ =>
+              agg
+          }
+        case agg @ AggregateExpression(sum: Sum, _, _, _, _) =>
+          sum.child match {
+            case UnscaledValue(child) if child.dataType.isInstanceOf[DecimalType] =>
+              agg.copy(aggregateFunction = sum.copy(child = child))
+            case _ =>
+              agg
+          }

Review Comment:
   This `Sum` branch and the `MakeDecimal` result-type handling are a behavior 
change from the old `case Alias(MakeDecimal(w: WindowExpression, _, _, _), _) 
=> w`, which extracted the bare window expression and used 
`windowExpr.dataType`. The new test only covers AVG, so this branch is 
unexercised. Could we add a decimal SUM over a window with small enough 
precision to trigger the `MakeDecimal` rewrite (`prec + 10 <= 18`)? That 
exercises the other half of the new extraction logic.
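
   To help choose the test's column type, here is a small sketch of the two gates in `DecimalAggregates`. The 18 comes from `prec + 10 <= 18` above, and 15 is the matching AVG bound, `prec + 4 <= 15`. The object and method names are hypothetical. It also shows that, unlike the AVG path, the `MakeDecimal` rewrite involves no floating point: it adds unscaled `Long`s and reattaches the scale.

```scala
object DecimalAggregatesGates {
  // Bounds used by Spark's DecimalAggregates rule for the two window rewrites.
  val MaxDoubleDigits = 15
  val MaxLongDigits = 18

  // Window SUM over decimal(p, s) -> MakeDecimal(<window Sum(UnscaledValue(e))>, p + 10, s)
  def sumIsRewritten(precision: Int): Boolean = precision + 10 <= MaxLongDigits

  // Window AVG over decimal(p, s) -> Cast(Divide(<window Average(UnscaledValue(e))>, 10^s), decimal(p + 4, s + 4))
  def avgIsRewritten(precision: Int): Boolean = precision + 4 <= MaxDoubleDigits

  // Model of the SUM rewrite: sum the unscaled Longs, then reattach the scale.
  // This is exact decimal arithmetic (overflow handling is not modeled here).
  def sumViaUnscaled(values: Seq[BigDecimal], scale: Int): BigDecimal = {
    val unscaledSum = values.map(_.bigDecimal.setScale(scale).unscaledValue.longValue).sum
    BigDecimal(BigInt(unscaledSum), scale)
  }
}
```

   Under these gates, a window `SUM` over `decimal(8, s)` or narrower should take the `MakeDecimal` branch. `decimal(9, s)` and wider keep the plain decimal `Sum`.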



##########
spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala:
##########
@@ -80,18 +79,72 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] {
       // operator itself carries a fallback attribution. Without this, the plan
       // prints a bare `Window` and the real reason lives on a sub-expression
       // that isn't obvious in the standard explain output.
-      val failing = winExprs.toSeq.zip(windowExprProto).collect { case (we, None) => we } ++
+      val failing = winExprs.toSeq.zip(windowExprProto).collect { case (we, None) =>
+        we.windowExpression
+      } ++
         op.partitionSpec.zip(partitionExprs).collect { case (e, None) => e } ++
         op.orderSpec.zip(sortOrders).collect { case (e, None) => e }
       withFallbackReason(op, failing: _*)
       None
     }
   }
 
+  private case class WindowExpressionInfo(
+      windowExpression: WindowExpression,
+      resultDataType: DataType)
+
+  private def extractWindowExpression(expr: Expression): Option[WindowExpressionInfo] = {
+    expr match {
+      case Alias(child, _) =>
+        extractWindowExpression(child)
+      case w: WindowExpression =>
+        Some(WindowExpressionInfo(w, w.dataType))
+      case m @ MakeDecimal(child, _, _, _) =>
+        extractWindowExpression(child).map { info =>
+          info.copy(
+            windowExpression = restoreDecimalAggregateInput(info.windowExpression),
+            resultDataType = m.dataType)
+        }
+      case c @ Cast(Divide(child, _, _), _: DecimalType, _, _) =>
+        extractWindowExpression(child).map { info =>
+          info.copy(
+            windowExpression = restoreDecimalAggregateInput(info.windowExpression),
+            resultDataType = c.dataType)
+        }
+      case _ =>
+        None
+    }
+  }
+
+  // Spark wraps decimal SUM / AVG window aggregates around UnscaledValue plus
+  // rescaling arithmetic. Comet's native decimal aggregates expect the original
+  // decimal child, so restore that child when unwrapping those Spark wrappers.
+  private def restoreDecimalAggregateInput(windowExpr: WindowExpression): WindowExpression = {
+    windowExpr
+      .transform {
+        case agg @ AggregateExpression(avg: Average, _, _, _, _) =>
+          avg.child match {
+            case UnscaledValue(child) if child.dataType.isInstanceOf[DecimalType] =>
+              agg.copy(aggregateFunction = avg.copy(child = child))

Review Comment:
   Restoring the decimal child here makes Comet run its precise `avg_decimal` 
UDAF (decimal sum at `decimal(p+10, s)`, decimal division with HALF_UP). But 
the plan that `DecimalAggregates` produced computes the average in `Double`: 
`Average(UnscaledValue(e))` is a double average, divided by `10^scale` as a 
double, then cast to `decimal(p+4, s+4)`. So for the case this PR targets we 
are substituting precise decimal arithmetic for Spark's floating-point 
arithmetic, and the two can round differently at the `s+4` scale.
   
   Spark only takes this double path when `prec + 4 <= 15`, so values stay in 
double's exact integer range and any divergence would be rare and in the last 
digit, but it is not guaranteed to be zero. For contrast, the regular 
non-window decimal AVG stays faithful: the `HashAggregate` computes the double 
average and a `Project` does the divide and cast, matching Spark exactly. That 
asymmetry is what worries me here.
   
   Could we add a fuzz or property test that runs many random decimal values 
across a few precision and scale combinations through window AVG and compares 
against Spark? The clean values in the new test all happen to round the same 
way in both schemes, so they would not catch this.
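
   To make the suggestion concrete, here is a rough model of what such a property test would compare. It is plain Scala rather than Spark, and all names are hypothetical.

   `sparkDoubleAvg` follows the plan described above: average the unscaled values as `Double`, divide by `10^scale` as `Double`, then round HALF_UP to scale `s + 4`. It assumes the double-to-decimal cast goes through the double's shortest decimal string, as Scala's `BigDecimal(Double)` does.

   `preciseDecimalAvg` stands in for `avg_decimal`: an exact decimal sum and a single HALF_UP division at scale `s + 4`. Comet's actual kernel may differ in detail.

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode => JRoundingMode}
import scala.math.BigDecimal.RoundingMode
import scala.util.Random

object WindowAvgDivergence {
  private def unscaled(v: BigDecimal, scale: Int): Long =
    v.bigDecimal.setScale(scale).unscaledValue.longValue

  // Models Cast(Divide(Average(UnscaledValue(e)), 10^s), decimal(p + 4, s + 4)).
  def sparkDoubleAvg(values: Seq[BigDecimal], scale: Int): BigDecimal = {
    val avg = values.map(v => unscaled(v, scale).toDouble).sum / values.size // Double average
    val divided = avg / math.pow(10.0, scale)                               // Double divide by 10^s
    BigDecimal(divided).setScale(scale + 4, RoundingMode.HALF_UP)           // cast with HALF_UP
  }

  // Exact decimal sum, then one HALF_UP rounding at scale s + 4.
  def preciseDecimalAvg(values: Seq[BigDecimal], scale: Int): BigDecimal = {
    val total = values.map(_.bigDecimal).reduce((a, b) => a.add(b))
    BigDecimal(total.divide(JBigDecimal.valueOf(values.size.toLong), scale + 4, JRoundingMode.HALF_UP))
  }

  // Counts random windows (1 to 8 rows of decimal(precision, scale)) where the two models disagree.
  def countMismatches(precision: Int, scale: Int, trials: Int, seed: Long): Int = {
    val rnd = new Random(seed)
    val bound = BigInt(10).pow(precision).toLong // |unscaled| < 10^precision
    (1 to trials).count { _ =>
      val rows = Seq.fill(1 + rnd.nextInt(8))(BigDecimal(BigInt(rnd.nextLong() % bound), scale))
      sparkDoubleAvg(rows, scale) != preciseDecimalAvg(rows, scale)
    }
  }
}
```

   The real test would instead generate random rows for a few `(precision, scale)` pairs with `precision <= 11`. That keeps `prec + 4 <= 15` true, so Spark takes the double path. It would then compare Comet's window AVG against Spark's answer. The model only shows where to look for disagreements.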



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
