This is an automated email from the ASF dual-hosted git repository.

cloud-fan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a4aa4adf4fd0 [SPARK-54918][SQL] Normalize floating numbers in array 
set operations
a4aa4adf4fd0 is described below

commit a4aa4adf4fd0f790d7b71b322292050f6cf8db38
Author: Albert Sugranyes <[email protected]>
AuthorDate: Wed May 27 15:13:56 2026 +0800

    [SPARK-54918][SQL] Normalize floating numbers in array set operations
    
    ### What changes were proposed in this pull request?
    
    Extends the `NormalizeFloatingNumbers` Catalyst optimizer rule to normalize
    floating-point numbers in hash-based array set operations:
    
    - `array_distinct`
    - `array_union`
    - `array_intersect`
    - `array_except`
    - `arrays_overlap`
    
    ### Why are the changes needed?
    
    These expressions rely on hash-based set semantics for element comparison,
    which distinguish -0.0 from 0.0 (and different NaN bit patterns from each
    other). Under Spark SQL semantics these values are equivalent, so the
    resulting sets violate the expected algebraic properties of the set operations.
    
    Examples:
    
    ```scala
    // Before fix: returns [0.0, -0.0, 1.0]
    // After fix: returns [0.0, 1.0]
    Seq(Array(0.0, -0.0, 1.0))
      .toDF("values")
      .selectExpr("array_distinct(values)")
      .show()
    
    // Before fix: returns [0.0, -0.0]
    // After fix: returns [0.0]
    Seq((Array(0.0), Array(-0.0)))
      .toDF("a", "b")
      .selectExpr("array_union(a, b)")
      .show()
    
    // Before fix: returns []
    // After fix: returns [0.0]
    Seq((Array(0.0, 1.0), Array(-0.0, 2.0)))
      .toDF("a", "b")
      .selectExpr("array_intersect(a, b)")
      .show()
    
    // Before fix: returns [0.0, 1.0]
    // After fix: returns [1.0]
    Seq((Array(0.0, 1.0), Array(-0.0)))
      .toDF("a", "b")
      .selectExpr("array_except(a, b)")
      .show()
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Hash-based array set operations now treat -0.0/0.0 and different NaN
    representations as equal, consistent with the existing behavior for join keys,
    window partition keys and aggregate grouping keys.
    
    ### How was this patch tested?
    
    - 10 unit tests in `NormalizeFloatingPointNumbersSuite` covering the logical
      plan rewrites for all 5 operations, plus idempotence.
    - 11 end-to-end tests in `DataFrameFunctionsSuite` verifying the runtime bit
      patterns via `Double.doubleToRawLongBits`: IEEE 754 defines 0.0 == -0.0 as
      true, so a plain equality check cannot tell whether -0.0 was normalized.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #53695 from asugranyes/SPARK-54918.
    
    Authored-by: Albert Sugranyes <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../expressions/collectionOperations.scala         |  21 ++-
 .../optimizer/NormalizeFloatingNumbers.scala       |  95 ++++++++-----
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |   1 +
 .../spark/sql/catalyst/trees/TreePatterns.scala    |   5 +
 .../NormalizeFloatingPointNumbersSuite.scala       | 118 +++++++++++++++-
 .../apache/spark/sql/DataFrameFunctionsSuite.scala | 152 +++++++++++++++++++++
 6 files changed, 348 insertions(+), 44 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index b0396188bcdd..85172f795744 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -32,12 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike}
-import org.apache.spark.sql.catalyst.trees.TreePattern.{
-  ARRAYS_ZIP,
-  CONCAT,
-  MAP_FROM_ENTRIES,
-  TreePattern
-}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAY_DISTINCT, 
ARRAY_EXCEPT, ARRAY_INTERSECT, ARRAY_UNION, ARRAYS_OVERLAP, ARRAYS_ZIP, CONCAT, 
MAP_FROM_ENTRIES, TreePattern}
 import org.apache.spark.sql.catalyst.types.{DataTypeUtils, PhysicalDataType, 
PhysicalIntegralType}
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
@@ -1809,6 +1804,9 @@ case class ArrayAppend(left: Expression, right: 
Expression) extends ArrayPendBas
 // scalastyle:off line.size.limit
 case class ArraysOverlap(left: Expression, right: Expression)
   extends BinaryArrayExpressionWithImplicitCast with Predicate {
+
+  final override val nodePatterns: Seq[TreePattern] = Seq(ARRAYS_OVERLAP)
+
   override def nullIntolerant: Boolean = true
 
   override def checkInputDataTypes(): TypeCheckResult = 
super.checkInputDataTypes() match {
@@ -4235,6 +4233,9 @@ trait ArraySetLike {
   since = "2.4.0")
 case class ArrayDistinct(child: Expression)
   extends UnaryExpression with ArraySetLike with ExpectsInputTypes {
+
+  final override val nodePatterns: Seq[TreePattern] = Seq(ARRAY_DISTINCT)
+
   override def nullIntolerant: Boolean = true
   override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
 
@@ -4431,6 +4432,8 @@ trait ArrayBinaryLike
 case class ArrayUnion(left: Expression, right: Expression) extends 
ArrayBinaryLike
   with ComplexTypeMergingExpression {
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(ARRAY_UNION)
+
   @transient lazy val evalUnion: (ArrayData, ArrayData) => ArrayData = {
     if (TypeUtils.typeWithProperEquals(elementType)) {
       (array1, array2) =>
@@ -4608,6 +4611,8 @@ case class ArrayUnion(left: Expression, right: 
Expression) extends ArrayBinaryLi
 case class ArrayIntersect(left: Expression, right: Expression) extends 
ArrayBinaryLike
   with ComplexTypeMergingExpression {
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(ARRAY_INTERSECT)
+
   private lazy val internalDataType: DataType = {
     dataTypeCheck
     ArrayType(elementType, leftArrayElementNullable && 
rightArrayElementNullable)
@@ -4823,7 +4828,7 @@ case class ArrayIntersect(left: Expression, right: 
Expression) extends ArrayBina
 }
 
 /**
- * Returns an array of the elements in the intersect of x and y, without 
duplicates
+ * Returns an array of the elements in x but not in y, without duplicates
  */
 @ExpressionDescription(
   usage = """
@@ -4840,6 +4845,8 @@ case class ArrayIntersect(left: Expression, right: 
Expression) extends ArrayBina
 case class ArrayExcept(left: Expression, right: Expression) extends 
ArrayBinaryLike
   with ComplexTypeMergingExpression {
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(ARRAY_EXCEPT)
+
   private lazy val internalDataType: DataType = {
     dataTypeCheck
     left.dataType
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
index 776efbed273e..44add1796169 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, 
CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, EqualTo, 
ExpectsInputTypes, Expression, GetStructField, If, IsNull, 
KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, 
TransformValues, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayDistinct, 
ArrayExcept, ArrayIntersect, ArraysOverlap, ArrayTransform, ArrayUnion, 
CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, EqualTo, 
ExpectsInputTypes, Expression, GetStructField, If, IsNull, 
KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, 
TransformValues, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, 
ExprCode}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Window}
@@ -31,59 +31,82 @@ import org.apache.spark.util.ArrayImplicits._
  * We need to take care of special floating numbers (NaN and -0.0) in several 
places:
  *   1. When compare values, different NaNs should be treated as same, `-0.0` 
and `0.0` should be
  *      treated as same.
- *   2. In aggregate grouping keys, different NaNs should belong to the same 
group, -0.0 and 0.0
+ *   2. In aggregate grouping keys, different NaNs should belong to the same 
group, `-0.0` and `0.0`
  *      should belong to the same group.
  *   3. In join keys, different NaNs should be treated as same, `-0.0` and 
`0.0` should be
  *      treated as same.
- *   4. In window partition keys, different NaNs should belong to the same 
partition, -0.0 and 0.0
- *      should belong to the same partition.
+ *   4. In window partition keys, different NaNs should belong to the same 
partition, `-0.0`
+ *      and `0.0` should belong to the same partition.
+ *   5. In hash-based array set operations, different NaNs should be treated 
as same, `-0.0`
+ *      and `0.0` should be treated as same.
  *
- * Case 1 is fine, as we handle NaN and -0.0 well during comparison. For 
complex types, we
+ * Case 1 is fine, as we handle NaN and `-0.0` well during comparison. For 
complex types, we
  * recursively compare the fields/elements, so it's also fine.
  *
  * Case 2, 3 and 4 are problematic, as Spark SQL turns grouping/join/window 
partition keys into
  * binary `UnsafeRow` and compare the binary data directly. Different NaNs 
have different binary
- * representation, and the same thing happens for -0.0 and 0.0.
+ * representation, and the same thing happens for `-0.0` and `0.0`.
  *
- * This rule normalizes NaN and -0.0 in window partition keys, join keys and 
aggregate grouping
- * keys.
+ * Case 5 is problematic for a similar reason: hash-based array set operations 
compare elements by
+ * their binary representation via hash sets.
+ *
+ * This rule runs in two places:
+ *    1. Early in `FinishAnalysis` (right after `ReplaceExpressions` and 
before `EvalInlineTables`)
+ *    so that array set-like operations are wrapped before optimizer rules 
that pre-evaluate
+ *    expressions (e.g. `ConstantFolding`, `ConvertToLocalRelation`, 
`EvalInlineTables`).
+ *
+ *    2. As a late batch at the end of the optimizer, because rules like 
subquery rewrite and
+ *    join reorder can create new joins or join conditions after 
`FinishAnalysis` that still
+ *    need their keys to be normalized.
  *
  * Ideally we should do the normalization in the physical operators that 
compare the
  * binary `UnsafeRow` directly. We don't need this normalization if the Spark 
SQL execution engine
  * is not optimized to run on binary data. This rule is created to simplify 
the implementation, so
  * that we have a single place to do normalization, which is more maintainable.
  *
- * Note that, this rule must be executed at the end of optimizer, because the 
optimizer may create
- * new joins(the subquery rewrite) and new join conditions(the join reorder).
  */
 object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan match {
-    case _ => plan.transformWithPruning( _.containsAnyPattern(WINDOW, JOIN)) {
-      case w: Window if w.partitionSpec.exists(p => needNormalize(p)) =>
-        // Although the `windowExpressions` may refer to `partitionSpec` 
expressions, we don't need
-        // to normalize the `windowExpressions`, as they are executed per 
input row and should take
-        // the input row as it is.
-        w.copy(partitionSpec = w.partitionSpec.map(normalize))
-
-      // Only hash join and sort merge join need the normalization. Here we 
catch all Joins with
-      // join keys, assuming Joins with join keys are always planned as hash 
join or sort merge
-      // join. It's very unlikely that we will break this assumption in the 
near future.
-      case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _, 
_)
-          // The analyzer guarantees left and right joins keys are of the same 
data type. Here we
-          // only need to check join keys of one side.
-          if leftKeys.exists(k => needNormalize(k)) =>
-        val newLeftJoinKeys = leftKeys.map(normalize)
-        val newRightJoinKeys = rightKeys.map(normalize)
-        val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map {
-          case (l, r) => EqualTo(l, r)
-        } ++ condition
-        j.copy(condition = Some(newConditions.reduce(And)))
-
-      // TODO: ideally Aggregate should also be handled here, but its grouping 
expressions are
-      // mixed in its aggregate expressions. It's unreliable to change the 
grouping expressions
-      // here. For now we normalize grouping expressions in `AggUtils` during 
planning.
-    }
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    plan
+      .transformWithPruning( _.containsAnyPattern(WINDOW, JOIN)) {
+        case w: Window if w.partitionSpec.exists(p => needNormalize(p)) =>
+          // Although the `windowExpressions` may refer to `partitionSpec` 
expressions,
+          // we don't need to normalize the `windowExpressions`, as they are 
executed
+          // per input row and should take the input row as it is.
+          w.copy(partitionSpec = w.partitionSpec.map(normalize))
+
+        // Only hash join and sort merge join need the normalization. Here we 
catch all Joins with
+        // join keys, assuming Joins with join keys are always planned as hash 
join or sort merge
+        // join. It's very unlikely that we will break this assumption in the 
near future.
+        case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, 
_, _)
+            // The analyzer guarantees left and right joins keys are of the 
same data type. Here we
+            // only need to check join keys of one side.
+            if leftKeys.exists(k => needNormalize(k)) =>
+          val newLeftJoinKeys = leftKeys.map(normalize)
+          val newRightJoinKeys = rightKeys.map(normalize)
+          val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map {
+            case (l, r) => EqualTo(l, r)
+          } ++ condition
+          j.copy(condition = Some(newConditions.reduce(And)))
+
+        // TODO: ideally Aggregate should also be handled here, but its 
grouping expressions are
+        // mixed in its aggregate expressions. It's unreliable to change the 
grouping expressions
+        // here. For now we normalize grouping expressions in `AggUtils` 
during planning.
+      }
+      .transformAllExpressionsWithPruning(_.containsAnyPattern(
+        ARRAY_DISTINCT, ARRAY_UNION, ARRAY_INTERSECT, ARRAY_EXCEPT, 
ARRAYS_OVERLAP)) {
+        case e: ArrayDistinct if needNormalize(e.child) =>
+          e.copy(child = normalize(e.child))
+        case e: ArrayUnion if needNormalize(e.left) =>
+          e.copy(left = normalize(e.left), right = normalize(e.right))
+        case e: ArrayIntersect if needNormalize(e.left) =>
+          e.copy(left = normalize(e.left), right = normalize(e.right))
+        case e: ArrayExcept if needNormalize(e.left) =>
+          e.copy(left = normalize(e.left), right = normalize(e.right))
+        case e: ArraysOverlap if needNormalize(e.left) =>
+          e.copy(left = normalize(e.left), right = normalize(e.right))
+      }
   }
 
   /**
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 1c991729c7d4..0cf03052cbdb 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -328,6 +328,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       EliminateView,
       EliminateSQLFunctionNode,
       ReplaceExpressions,
+      NormalizeFloatingNumbers,
       RewriteNonCorrelatedExists,
       PullOutGroupingExpressions,
       // Put `InsertMapSortInGroupingExpressions` after 
`PullOutGroupingExpressions`,
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 4e06fcb36767..557b01167d88 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -29,7 +29,12 @@ object TreePattern extends Enumeration  {
   val ALIAS: Value = Value
   val ANALYSIS_AWARE_EXPRESSION: Value = Value
   val AND: Value = Value
+  val ARRAYS_OVERLAP: Value = Value
   val ARRAYS_ZIP: Value = Value
+  val ARRAY_DISTINCT: Value = Value
+  val ARRAY_EXCEPT: Value = Value
+  val ARRAY_INTERSECT: Value = Value
+  val ARRAY_UNION: Value = Value
   val ATTRIBUTE_REFERENCE: Value = Value
   val AVERAGE: Value = Value
   val BINARY_ARITHMETIC: Value = Value
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
index 21049ca3546d..a0a9c8ec3224 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
@@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{CaseWhen, If, IsNull, 
KnownFloatingPointNormalized}
+import org.apache.spark.sql.catalyst.expressions.{ArrayDistinct, ArrayExcept, 
ArrayIntersect, ArraysOverlap, ArrayTransform, ArrayUnion, CaseWhen, 
Expression, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, 
NamedLambdaVariable}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.DoubleType
 
 class NormalizeFloatingPointNumbersSuite extends PlanTest {
 
@@ -34,6 +35,18 @@ class NormalizeFloatingPointNumbersSuite extends PlanTest {
   val a = testRelation1.output(0)
   val testRelation2 = LocalRelation($"a".double)
   val b = testRelation2.output(0)
+  val arrayRelation = LocalRelation($"arr1".array(DoubleType), 
$"arr2".array(DoubleType))
+  val arr1 = arrayRelation.output(0)
+  val arr2 = arrayRelation.output(1)
+
+  private def normalizedArray(e: Expression): KnownFloatingPointNormalized = {
+    val lv = NamedLambdaVariable("arg", DoubleType, nullable = true)
+    KnownFloatingPointNormalized(
+      ArrayTransform(e,
+        LambdaFunction(
+          KnownFloatingPointNormalized(NormalizeNaNAndZero(lv)),
+          Seq(lv))))
+  }
 
   test("normalize floating points in window function expressions") {
     val query = testRelation1.window(Seq(sum(a).as("sum")), Seq(a), Seq(a.asc))
@@ -132,5 +145,108 @@ class NormalizeFloatingPointNumbersSuite extends PlanTest 
{
     val normalizedExpr = NormalizeFloatingNumbers.normalize(nestedExpr)
     assert(nestedExpr.dataType == normalizedExpr.dataType)
   }
+
+  test("SPARK-54918: normalize floating points in array_distinct") {
+    val query = arrayRelation.select(ArrayDistinct(arr1).as("result"))
+
+    val optimized = Optimize.execute(query)
+    val correctAnswer = 
arrayRelation.select(ArrayDistinct(normalizedArray(arr1)).as("result"))
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("SPARK-54918: normalize floating points in array_distinct - 
idempotence") {
+    val query = arrayRelation.select(ArrayDistinct(arr1).as("result"))
+
+    val optimized = Optimize.execute(query)
+    val doubleOptimized = Optimize.execute(optimized)
+    val correctAnswer = 
arrayRelation.select(ArrayDistinct(normalizedArray(arr1)).as("result"))
+
+    comparePlans(doubleOptimized, correctAnswer)
+  }
+
+  test("SPARK-54918: normalize floating points in array_union") {
+    val query = arrayRelation.select(ArrayUnion(arr1, arr2).as("result"))
+
+    val optimized = Optimize.execute(query)
+    val correctAnswer = arrayRelation.select(
+      ArrayUnion(normalizedArray(arr1), normalizedArray(arr2)).as("result"))
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("SPARK-54918: normalize floating points in array_union - idempotence") {
+    val query = arrayRelation.select(ArrayUnion(arr1, arr2).as("result"))
+
+    val optimized = Optimize.execute(query)
+    val doubleOptimized = Optimize.execute(optimized)
+    val correctAnswer = arrayRelation.select(
+      ArrayUnion(normalizedArray(arr1), normalizedArray(arr2)).as("result"))
+
+    comparePlans(doubleOptimized, correctAnswer)
+  }
+
+  test("SPARK-54918: normalize floating points in array_intersect") {
+    val query = arrayRelation.select(ArrayIntersect(arr1, arr2).as("result"))
+
+    val optimized = Optimize.execute(query)
+    val correctAnswer = arrayRelation.select(
+      ArrayIntersect(normalizedArray(arr1), 
normalizedArray(arr2)).as("result"))
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("SPARK-54918: normalize floating points in array_intersect - 
idempotence") {
+    val query = arrayRelation.select(ArrayIntersect(arr1, arr2).as("result"))
+
+    val optimized = Optimize.execute(query)
+    val doubleOptimized = Optimize.execute(optimized)
+    val correctAnswer = arrayRelation.select(
+      ArrayIntersect(normalizedArray(arr1), 
normalizedArray(arr2)).as("result"))
+
+    comparePlans(doubleOptimized, correctAnswer)
+  }
+
+  test("SPARK-54918: normalize floating points in array_except") {
+    val query = arrayRelation.select(ArrayExcept(arr1, arr2).as("result"))
+
+    val optimized = Optimize.execute(query)
+    val correctAnswer = arrayRelation.select(
+      ArrayExcept(normalizedArray(arr1), normalizedArray(arr2)).as("result"))
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("SPARK-54918: normalize floating points in array_except - idempotence") 
{
+    val query = arrayRelation.select(ArrayExcept(arr1, arr2).as("result"))
+
+    val optimized = Optimize.execute(query)
+    val doubleOptimized = Optimize.execute(optimized)
+    val correctAnswer = arrayRelation.select(
+      ArrayExcept(normalizedArray(arr1), normalizedArray(arr2)).as("result"))
+
+    comparePlans(doubleOptimized, correctAnswer)
+  }
+
+  test("SPARK-54918: normalize floating points in arrays_overlap") {
+    val query = arrayRelation.select(ArraysOverlap(arr1, arr2).as("result"))
+
+    val optimized = Optimize.execute(query)
+    val correctAnswer = arrayRelation.select(
+      ArraysOverlap(normalizedArray(arr1), normalizedArray(arr2)).as("result"))
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("SPARK-54918: normalize floating points in arrays_overlap - 
idempotence") {
+    val query = arrayRelation.select(ArraysOverlap(arr1, arr2).as("result"))
+
+    val optimized = Optimize.execute(query)
+    val doubleOptimized = Optimize.execute(optimized)
+    val correctAnswer = arrayRelation.select(
+      ArraysOverlap(normalizedArray(arr1), normalizedArray(arr2)).as("result"))
+
+    comparePlans(doubleOptimized, correctAnswer)
+  }
 }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 7faccbde997d..8f3098bedccc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -6348,7 +6348,159 @@ class DataFrameFunctionsSuite extends 
SharedSparkSession {
           call_function("spark_catalog.default.custom_sum", $"a")),
         Row(12.0, 12.0, 12.0))
     }
+  }
+
+  private def isPositiveZero(d: Double): Boolean =
+    java.lang.Double.doubleToRawLongBits(d) == 0L
+
+  test("SPARK-54918: array set ops normalize -0.0 and NaN via VALUES inline 
table") {
+    val r = sql("""
+      SELECT
+        array_distinct(a) AS d,
+        array_union(a, b) AS u,
+        array_intersect(a, b) AS i,
+        array_except(a, b) AS e,
+        arrays_overlap(a, b) AS o
+      FROM VALUES (array(-0.0d, 0.0d, double('NaN')), array(0.0d, 
double('NaN')))
+      AS t(a, b)
+    """).head()
+
+    val distinct = r.getSeq[Double](0)
+    assert(distinct.length == 2)
+    assert(distinct.exists(isPositiveZero))
+    assert(distinct.exists(_.isNaN))
+
+    val union = r.getSeq[Double](1)
+    assert(union.length == 2)
+    assert(union.exists(isPositiveZero))
+    assert(union.exists(_.isNaN))
+
+    val intersect = r.getSeq[Double](2)
+    assert(intersect.length == 2)
+    assert(intersect.exists(isPositiveZero))
+    assert(intersect.exists(_.isNaN))
+
+    val except = r.getSeq[Double](3)
+    assert(except.isEmpty)
+
+    assert(r.getBoolean(4))
+  }
+
+  test("SPARK-54918: array_distinct normalizes -0.0 to +0.0 - literals") {
+    val r1 = Seq(1).toDF()
+      .select(array_distinct(typedLit(Array(-0.0d, 
0.0d)))).head().getSeq[Double](0)
+
+    assert(r1.length == 1)
+    assert(isPositiveZero(r1.head))
+
+    val r2 = Seq(1).toDF()
+      .select(array_distinct(
+        typedLit(Array(Double.NaN, 0.0d, -0.0d, Double.NaN)))
+      ).head().getSeq[Double](0)
+
+    assert(r2.length == 2)
+    assert(r2.exists(_.isNaN))
+    assert(r2.exists(isPositiveZero))
+  }
+
+  test("SPARK-54918: array_distinct normalizes -0.0 to +0.0") {
+    val r1 = Seq(Array(-0.0d, 0.0d)).toDF("a")
+      .select(array_distinct($"a")).head().getSeq[Double](0)
+
+    assert(r1.length == 1)
+    assert(isPositiveZero(r1.head))
+
+    val r2 = Seq(Array(Double.NaN, 0.0d, -0.0d, Double.NaN)).toDF("a")
+      .select(array_distinct($"a")).head().getSeq[Double](0)
+
+    assert(r2.length == 2)
+    assert(r2.exists(_.isNaN))
+    assert(r2.exists(isPositiveZero))
+  }
+
+  test("SPARK-54918: array_union normalizes -0.0 to +0.0 - literals") {
+    val r = Seq(1).toDF()
+      .select(array_union(
+        typedLit(Array(-0.0d)),
+        typedLit(Array(0.0d)))
+      ).head().getSeq[Double](0)
+
+    assert(r.length == 1)
+    assert(isPositiveZero(r.head))
+  }
+
+  test("SPARK-54918: array_union normalizes -0.0 to +0.0") {
+    val r = Seq((Array(-0.0d), Array(0.0d))).toDF("a", "b")
+      .select(array_union($"a", $"b")).head().getSeq[Double](0)
+
+    assert(r.length == 1)
+    assert(isPositiveZero(r.head))
+  }
+
+  test("SPARK-54918: array_intersect normalizes -0.0 to +0.0 - literals") {
+    val r = Seq(1).toDF()
+      .select(array_intersect(
+        typedLit(Array(-0.0d)),
+        typedLit(Array(0.0d)))
+      ).head().getSeq[Double](0)
+
+    assert(r.length == 1)
+    assert(isPositiveZero(r.head))
+  }
+
+  test("SPARK-54918: array_intersect normalizes -0.0 to +0.0") {
+    val r = Seq((Array(-0.0d), Array(0.0d))).toDF("a", "b")
+      .select(array_intersect($"a", $"b")).head().getSeq[Double](0)
+
+    assert(r.length == 1)
+    assert(isPositiveZero(r.head))
+  }
+
+  test("SPARK-54918: array_except normalizes -0.0 to +0.0 - literals") {
+    val r1 = Seq(1).toDF()
+      .select(array_except(
+        typedLit(Array(-0.0d)),
+        typedLit(Array(0.0d)))
+      ).head().getSeq[Double](0)
+
+    assert(r1.isEmpty)
+
+    val r2 = Seq(1).toDF()
+      .select(array_except(
+        typedLit(Array(0.0d)),
+        typedLit(Array(-0.0d)))
+      ).head().getSeq[Double](0)
+
+    assert(r2.isEmpty)
+  }
+
+  test("SPARK-54918: array_except normalizes -0.0 to +0.0") {
+    val r1 = Seq((Array(-0.0d), Array(0.0d))).toDF("a", "b")
+      .select(array_except($"a", $"b")).head().getSeq[Double](0)
+
+    assert(r1.isEmpty)
+
+    val r2 = Seq((Array(0.0d), Array(-0.0d))).toDF("a", "b")
+      .select(array_except($"a", $"b")).head().getSeq[Double](0)
+
+    assert(r2.isEmpty)
+  }
+
+  test("SPARK-54918: arrays_overlap normalizes -0.0 to +0.0 - literals") {
+    val r = Seq(1).toDF()
+      .select(arrays_overlap(
+        typedLit(Array(-0.0d)),
+        typedLit(Array(0.0d)))
+      ).head().getBoolean(0)
+
+    assert(r)
+  }
+
+  test("SPARK-54918: arrays_overlap normalizes -0.0 to +0.0") {
+    val r = Seq((Array(-0.0d), Array(0.0d))).toDF("a", "b")
+      .select(arrays_overlap($"a", $"b")).head().getBoolean(0)
 
+    assert(r)
   }
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to