cloud-fan commented on code in PR #53695:
URL: https://github.com/apache/spark/pull/53695#discussion_r3299180825


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala:
##########
@@ -58,32 +63,46 @@ import org.apache.spark.util.ArrayImplicits._
  */
 object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan match {
-    case _ => plan.transformWithPruning( _.containsAnyPattern(WINDOW, JOIN)) {
-      case w: Window if w.partitionSpec.exists(p => needNormalize(p)) =>
-        // Although the `windowExpressions` may refer to `partitionSpec` expressions, we don't need
-        // to normalize the `windowExpressions`, as they are executed per input row and should take
-        // the input row as it is.
-        w.copy(partitionSpec = w.partitionSpec.map(normalize))
-
-      // Only hash join and sort merge join need the normalization. Here we catch all Joins with
-      // join keys, assuming Joins with join keys are always planned as hash join or sort merge
-      // join. It's very unlikely that we will break this assumption in the near future.
-      case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _, _)
-          // The analyzer guarantees left and right joins keys are of the same data type. Here we
-          // only need to check join keys of one side.
-          if leftKeys.exists(k => needNormalize(k)) =>
-        val newLeftJoinKeys = leftKeys.map(normalize)
-        val newRightJoinKeys = rightKeys.map(normalize)
-        val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map {
-          case (l, r) => EqualTo(l, r)
-        } ++ condition
-        j.copy(condition = Some(newConditions.reduce(And)))
-
-      // TODO: ideally Aggregate should also be handled here, but its grouping expressions are
-      // mixed in its aggregate expressions. It's unreliable to change the grouping expressions
-      // here. For now we normalize grouping expressions in `AggUtils` during planning.
-    }
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    plan
+      .transformWithPruning( _.containsAnyPattern(WINDOW, JOIN)) {
+        case w: Window if w.partitionSpec.exists(p => needNormalize(p)) =>
+          // Although the `windowExpressions` may refer to `partitionSpec` expressions,
+          // we don't need to normalize the `windowExpressions`, as they are executed
+          // per input row and should take the input row as it is.
+          w.copy(partitionSpec = w.partitionSpec.map(normalize))
+
+        // Only hash join and sort merge join need the normalization. Here we catch all Joins with
+        // join keys, assuming Joins with join keys are always planned as hash join or sort merge
+        // join. It's very unlikely that we will break this assumption in the near future.
+        case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _, _)
+            // The analyzer guarantees left and right joins keys are of the same data type. Here we
+            // only need to check join keys of one side.
+            if leftKeys.exists(k => needNormalize(k)) =>
+          val newLeftJoinKeys = leftKeys.map(normalize)
+          val newRightJoinKeys = rightKeys.map(normalize)
+          val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map {
+            case (l, r) => EqualTo(l, r)
+          } ++ condition
+          j.copy(condition = Some(newConditions.reduce(And)))
+
+        // TODO: ideally Aggregate should also be handled here, but its grouping expressions are
+        // mixed in its aggregate expressions. It's unreliable to change the grouping expressions
+        // here. For now we normalize grouping expressions in `AggUtils` during planning.
+      }
+      .transformAllExpressionsWithPruning(_.containsAnyPattern(
+        ARRAY_DISTINCT, ARRAY_UNION, ARRAY_INTERSECT, ARRAY_EXCEPT, ARRAYS_OVERLAP)) {
+        case e: ArrayDistinct if needNormalize(e.child.dataType) =>

Review Comment:
   The guards here use the dataType form `needNormalize(e.child.dataType)` /
`needNormalize(e.left.dataType)`, while the existing `Window` and
`ExtractEquiJoinKeys` cases in this same rule use the expression form
(`needNormalize(p)` / `needNormalize(k)`), which short-circuits when the input
is already `KnownFloatingPointNormalized`. Consequence: on a repeated optimizer
pass, the array cases still match an already-normalized child and rebuild it as
`e.copy(child = e.child)` — semantically equal, but a fresh node created for
nothing. Suggest aligning the guards with the expression form:
   
   ```scala
   case e: ArrayDistinct if needNormalize(e.child) =>
     e.copy(child = normalize(e.child))
   case e: ArrayUnion if needNormalize(e.left) =>
     e.copy(left = normalize(e.left), right = normalize(e.right))
   // … same for ArrayIntersect / ArrayExcept / ArraysOverlap
   ```
   



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala:
##########
@@ -4235,6 +4233,9 @@ trait ArraySetLike {
   since = "2.4.0")
 case class ArrayDistinct(child: Expression)
   extends UnaryExpression with ArraySetLike with ExpectsInputTypes {
+
+  final override val nodePatterns: Seq[TreePattern] = Seq(ARRAY_DISTINCT)

Review Comment:
   Concrete suggestion for the constant-folding gap (see top-level point 1).
Add a conditional `foldable` so each of the five expressions opts out of
folding when its input element type contains float/double. `ConstantFolding`
then skips them, and `NormalizeFloatingNumbers` normalizes their inputs as designed.
   
   For `ArrayDistinct`:
   ```scala
     final override val nodePatterns: Seq[TreePattern] = Seq(ARRAY_DISTINCT)
   
     override lazy val foldable: Boolean =
       super.foldable && !NormalizeFloatingNumbers.needNormalize(child.dataType)
   ```
   
   Same shape for the four binary set ops, using `left.dataType` (the analyzer 
guarantees `left` and `right` have matching element types):
   ```scala
     override lazy val foldable: Boolean =
       super.foldable && !NormalizeFloatingNumbers.needNormalize(left.dataType)
   ```
   
   Apply to `ArrayUnion` (line 4435), `ArrayIntersect` (4614), `ArrayExcept` 
(4848), `ArraysOverlap` (1808). Bump 
`NormalizeFloatingNumbers.needNormalize(dt: DataType)` from `private` to 
`private[sql]` so it's callable from here.
   
   Trace for `array_distinct(typedLit(Array(-0.0d, 0.0d)))`:
   1. `ArrayDistinct.foldable = true && !needNormalize(ArrayType(DoubleType)) = 
false` → `ConstantFolding` skips.
   2. `NormalizeFloatingNumbers` matches, wraps the child in 
`KnownFloatingPointNormalized(ArrayTransform(literal, NormalizeNaNAndZero(x)))`.
   3. At execution, `ArrayTransform` normalizes (`-0.0 → 0.0`); `ArrayDistinct` 
dedups to `[0.0]`.
   
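   Cost: `array_distinct(array(1.0D, 2.0D))`-style queries with no `-0.0`/NaN
values also lose folding and are evaluated per row at execution time. (A plain
`1.0` literal is `DECIMAL` in Spark SQL, so `array(1.0, 2.0)` is unaffected.)
Negligible — literal float/double arrays in `array_distinct`/etc. are rare and small.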
   



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala:
##########
@@ -37,6 +37,8 @@ import org.apache.spark.util.ArrayImplicits._
  *      treated as same.
  *   4. In window partition keys, different NaNs should belong to the same partition, -0.0 and 0.0
  *      should belong to the same partition.
+ *   5. In hash-based array set operations, different NaNs should be treated as same, `-0.0` and 0.0

Review Comment:
   ```suggestion
    *   5. In hash-based array set operations, different NaNs should be treated as same, `-0.0` and `0.0`
   ```
   Match the backticking in items 1 and 3 above — both zeros should be in 
backticks.
   



##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala:
##########
@@ -132,5 +145,64 @@ class NormalizeFloatingPointNumbersSuite extends PlanTest {
     val normalizedExpr = NormalizeFloatingNumbers.normalize(nestedExpr)
     assert(nestedExpr.dataType == normalizedExpr.dataType)
   }
+
+  test("SPARK-54918: normalize floating points in array_distinct") {
+    val query = arrayRelation.select(ArrayDistinct(arr1).as("result"))
+
+    val optimized = Optimize.execute(query)
+    val correctAnswer = arrayRelation.select(ArrayDistinct(normalizedArray(arr1)).as("result"))
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("SPARK-54918: normalize floating points in array_distinct - idempotence") {
+    val query = arrayRelation.select(ArrayDistinct(arr1).as("result"))
+
+    val optimized = Optimize.execute(query)
+    val doubleOptimized = Optimize.execute(optimized)
+    val correctAnswer = arrayRelation.select(ArrayDistinct(normalizedArray(arr1)).as("result"))
+
+    comparePlans(doubleOptimized, correctAnswer)
+  }

Review Comment:
   Idempotence is only covered for `array_distinct`; the other four 
(`array_union`, `array_intersect`, `array_except`, `arrays_overlap`) don't have 
an equivalent. The `normalizedArray` helper is already in place, so each is ~5 
lines. Worth adding for symmetry — and once the guard switches to the 
expression form (other inline comment), an idempotence test on the binary ops 
would catch any regression of that property.
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to