cloud-fan commented on code in PR #53695:
URL: https://github.com/apache/spark/pull/53695#discussion_r3299180825
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala:
##########
@@ -58,32 +63,46 @@ import org.apache.spark.util.ArrayImplicits._
*/
object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan match {
- case _ => plan.transformWithPruning( _.containsAnyPattern(WINDOW, JOIN)) {
- case w: Window if w.partitionSpec.exists(p => needNormalize(p)) =>
- // Although the `windowExpressions` may refer to `partitionSpec`
expressions, we don't need
- // to normalize the `windowExpressions`, as they are executed per
input row and should take
- // the input row as it is.
- w.copy(partitionSpec = w.partitionSpec.map(normalize))
-
- // Only hash join and sort merge join need the normalization. Here we
catch all Joins with
- // join keys, assuming Joins with join keys are always planned as hash
join or sort merge
- // join. It's very unlikely that we will break this assumption in the
near future.
- case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _,
_)
- // The analyzer guarantees left and right joins keys are of the same
data type. Here we
- // only need to check join keys of one side.
- if leftKeys.exists(k => needNormalize(k)) =>
- val newLeftJoinKeys = leftKeys.map(normalize)
- val newRightJoinKeys = rightKeys.map(normalize)
- val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map {
- case (l, r) => EqualTo(l, r)
- } ++ condition
- j.copy(condition = Some(newConditions.reduce(And)))
-
- // TODO: ideally Aggregate should also be handled here, but its grouping
expressions are
- // mixed in its aggregate expressions. It's unreliable to change the
grouping expressions
- // here. For now we normalize grouping expressions in `AggUtils` during
planning.
- }
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ plan
+ .transformWithPruning( _.containsAnyPattern(WINDOW, JOIN)) {
+ case w: Window if w.partitionSpec.exists(p => needNormalize(p)) =>
+ // Although the `windowExpressions` may refer to `partitionSpec`
expressions,
+ // we don't need to normalize the `windowExpressions`, as they are
executed
+ // per input row and should take the input row as it is.
+ w.copy(partitionSpec = w.partitionSpec.map(normalize))
+
+ // Only hash join and sort merge join need the normalization. Here we
catch all Joins with
+ // join keys, assuming Joins with join keys are always planned as hash
join or sort merge
+ // join. It's very unlikely that we will break this assumption in the
near future.
+ case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _,
_, _)
+ // The analyzer guarantees left and right joins keys are of the
same data type. Here we
+ // only need to check join keys of one side.
+ if leftKeys.exists(k => needNormalize(k)) =>
+ val newLeftJoinKeys = leftKeys.map(normalize)
+ val newRightJoinKeys = rightKeys.map(normalize)
+ val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map {
+ case (l, r) => EqualTo(l, r)
+ } ++ condition
+ j.copy(condition = Some(newConditions.reduce(And)))
+
+ // TODO: ideally Aggregate should also be handled here, but its
grouping expressions are
+ // mixed in its aggregate expressions. It's unreliable to change the
grouping expressions
+ // here. For now we normalize grouping expressions in `AggUtils`
during planning.
+ }
+ .transformAllExpressionsWithPruning(_.containsAnyPattern(
+ ARRAY_DISTINCT, ARRAY_UNION, ARRAY_INTERSECT, ARRAY_EXCEPT,
ARRAYS_OVERLAP)) {
+ case e: ArrayDistinct if needNormalize(e.child.dataType) =>
Review Comment:
The guards here use the dataType form (`needNormalize(e.child.dataType)` /
`needNormalize(e.left.dataType)`), while the existing `Window` and
`ExtractEquiJoinKeys` cases in this same rule use the expression form
(`needNormalize(p)` / `needNormalize(k)`). The expression form short-circuits
when the input is already wrapped in `KnownFloatingPointNormalized`.
Consequence: on a repeated optimizer pass, the array cases still match an
already-normalized child and produce `e.copy(child = e.child)`. That is
semantically equal, but it is a freshly allocated node. Suggest
aligning:
```scala
case e: ArrayDistinct if needNormalize(e.child) =>
e.copy(child = normalize(e.child))
case e: ArrayUnion if needNormalize(e.left) =>
e.copy(left = normalize(e.left), right = normalize(e.right))
// … same for ArrayIntersect / ArrayExcept / ArraysOverlap
```
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala:
##########
@@ -4235,6 +4233,9 @@ trait ArraySetLike {
since = "2.4.0")
case class ArrayDistinct(child: Expression)
extends UnaryExpression with ArraySetLike with ExpectsInputTypes {
+
+ final override val nodePatterns: Seq[TreePattern] = Seq(ARRAY_DISTINCT)
Review Comment:
Concrete suggestion for the constant-folding gap (see top-level point 1).
Add a conditional `foldable` so that each of the five expressions opts out of
constant folding when its input element type contains float/double. Otherwise a
foldable call is evaluated before this rule has a chance to normalize its input.
With the change, `ConstantFolding` skips these expressions, and
`NormalizeFloatingNumbers` wraps them as designed.
For `ArrayDistinct`:
```scala
final override val nodePatterns: Seq[TreePattern] = Seq(ARRAY_DISTINCT)
override lazy val foldable: Boolean =
super.foldable && !NormalizeFloatingNumbers.needNormalize(child.dataType)
```
Same shape for the four binary set ops, using `left.dataType` (the analyzer
guarantees `left` and `right` have matching element types):
```scala
override lazy val foldable: Boolean =
super.foldable && !NormalizeFloatingNumbers.needNormalize(left.dataType)
```
Apply to `ArrayUnion` (line 4435), `ArrayIntersect` (4614), `ArrayExcept`
(4848), `ArraysOverlap` (1808). Bump
`NormalizeFloatingNumbers.needNormalize(dt: DataType)` from `private` to
`private[sql]` so it's callable from here.
Trace for `array_distinct(typedLit(Array(-0.0d, 0.0d)))`:
1. `ArrayDistinct.foldable = true && !needNormalize(ArrayType(DoubleType)) =
false` → `ConstantFolding` skips.
2. `NormalizeFloatingNumbers` matches, wraps the child in
`KnownFloatingPointNormalized(ArrayTransform(literal, NormalizeNaNAndZero(x)))`.
3. At execution, `ArrayTransform` normalizes (`-0.0 → 0.0`); `ArrayDistinct`
dedups to `[0.0]`.
Cost: `array_distinct(array(1.0, 2.0))`-style queries whose literals contain no
`-0.0`/NaN values also lose constant folding and are evaluated per row. This is
negligible: literal float arrays in `array_distinct`/etc. are rare and small.
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala:
##########
@@ -37,6 +37,8 @@ import org.apache.spark.util.ArrayImplicits._
* treated as same.
* 4. In window partition keys, different NaNs should belong to the same
partition, -0.0 and 0.0
* should belong to the same partition.
+ * 5. In hash-based array set operations, different NaNs should be treated as same, `-0.0` and 0.0
Review Comment:
```suggestion
* 5. In hash-based array set operations, different NaNs should be treated as same, `-0.0` and `0.0`
```
Match the backtick formatting used in items 1 and 3 above: both zeros should
be in backticks.
##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala:
##########
@@ -132,5 +145,64 @@ class NormalizeFloatingPointNumbersSuite extends PlanTest {
val normalizedExpr = NormalizeFloatingNumbers.normalize(nestedExpr)
assert(nestedExpr.dataType == normalizedExpr.dataType)
}
+
+ test("SPARK-54918: normalize floating points in array_distinct") {
+ val query = arrayRelation.select(ArrayDistinct(arr1).as("result"))
+
+ val optimized = Optimize.execute(query)
+ val correctAnswer =
arrayRelation.select(ArrayDistinct(normalizedArray(arr1)).as("result"))
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("SPARK-54918: normalize floating points in array_distinct -
idempotence") {
+ val query = arrayRelation.select(ArrayDistinct(arr1).as("result"))
+
+ val optimized = Optimize.execute(query)
+ val doubleOptimized = Optimize.execute(optimized)
+ val correctAnswer =
arrayRelation.select(ArrayDistinct(normalizedArray(arr1)).as("result"))
+
+ comparePlans(doubleOptimized, correctAnswer)
+ }
Review Comment:
Idempotence is only covered for `array_distinct`. The other four
(`array_union`, `array_intersect`, `array_except`, `arrays_overlap`) have no
equivalent test. The `normalizedArray` helper is already in place, so each test
would be about 5 lines. It is worth adding them for symmetry. Also, once the
guard switches to the expression form (see the other inline comment), an
idempotence test on the binary ops would catch any regression of that property.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]