This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e356f6a  [SPARK-36741][SQL] ArrayDistinct handles duplicated Double.NaN and Float.NaN
e356f6a is described below

commit e356f6aa1119f4ceeafc7bcdea5f7b8f1f010638
Author: Angerszhuuuu <[email protected]>
AuthorDate: Fri Sep 17 20:48:17 2021 +0800

    [SPARK-36741][SQL] ArrayDistinct handles duplicated Double.NaN and Float.NaN
    
    ### What changes were proposed in this pull request?
    For query
    ```
    select array_distinct(array(cast('nan' as double), cast('nan' as double)))
    ```
    This returns [NaN, NaN], but it should return [NaN].
    This issue is also caused by `OpenHashSet` being unable to handle `Double.NaN`
    and `Float.NaN`.
    This PR fixes it, building on the approach of https://github.com/apache/spark/pull/33955
    
    ### Why are the changes needed?
    To fix the bug described above.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. `ArrayDistinct` no longer returns duplicated `NaN` values.
    
    ### How was this patch tested?
    Added unit tests.
    
    Closes #33993 from AngersZhuuuu/SPARK-36741.
    
    Authored-by: Angerszhuuuu <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../expressions/collectionOperations.scala         | 124 +++++++++++----------
 .../org/apache/spark/sql/util/SQLOpenHashSet.scala |  54 +++++++--
 .../expressions/CollectionExpressionsSuite.scala   |   9 ++
 3 files changed, 121 insertions(+), 66 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index ba000a3..a50263c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -3412,32 +3412,59 @@ case class ArrayDistinct(child: Expression)
   }
 
   override def nullSafeEval(array: Any): Any = {
-    val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
+    val data = array.asInstanceOf[ArrayData]
     doEvaluation(data)
   }
 
   @transient private lazy val doEvaluation = if 
(TypeUtils.typeWithProperEquals(elementType)) {
-    (data: Array[AnyRef]) => new 
GenericArrayData(data.distinct.asInstanceOf[Array[Any]])
+    (array: ArrayData) =>
+      val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
+      val hs = new SQLOpenHashSet[Any]()
+      val withNaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
+        (value: Any) =>
+          if (!hs.contains(value)) {
+            if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+              
ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
+            }
+            arrayBuffer += value
+            hs.add(value)
+          },
+        (valueNaN: Any) => arrayBuffer += valueNaN)
+      var i = 0
+      while (i < array.numElements()) {
+        if (array.isNullAt(i)) {
+          if (!hs.containsNull) {
+            hs.addNull
+            arrayBuffer += null
+          }
+        } else {
+          val elem = array.get(i, elementType)
+          withNaNCheckFunc(elem)
+        }
+        i += 1
+      }
+      new GenericArrayData(arrayBuffer.toSeq)
   } else {
-    (data: Array[AnyRef]) => {
+    (data: ArrayData) => {
+      val array = data.toArray[AnyRef](elementType)
       val arrayBuffer = new scala.collection.mutable.ArrayBuffer[AnyRef]
       var alreadyStoredNull = false
-      for (i <- 0 until data.length) {
-        if (data(i) != null) {
+      for (i <- 0 until array.length) {
+        if (array(i) != null) {
           var found = false
           var j = 0
           while (!found && j < arrayBuffer.size) {
             val va = arrayBuffer(j)
-            found = (va != null) && ordering.equiv(va, data(i))
+            found = (va != null) && ordering.equiv(va, array(i))
             j += 1
           }
           if (!found) {
-            arrayBuffer += data(i)
+            arrayBuffer += array(i)
           }
         } else {
           // De-duplicate the null values.
           if (!alreadyStoredNull) {
-            arrayBuffer += data(i)
+            arrayBuffer += array(i)
             alreadyStoredNull = true
           }
         }
@@ -3456,10 +3483,9 @@ case class ArrayDistinct(child: Expression)
       val ptName = CodeGenerator.primitiveTypeName(jt)
 
       nullSafeCodeGen(ctx, ev, (array) => {
-        val foundNullElement = ctx.freshName("foundNullElement")
         val nullElementIndex = ctx.freshName("nullElementIndex")
         val builder = ctx.freshName("builder")
-        val openHashSet = classOf[OpenHashSet[_]].getName
+        val openHashSet = classOf[SQLOpenHashSet[_]].getName
         val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
         val hashSet = ctx.freshName("hashSet")
         val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
@@ -3468,7 +3494,6 @@ case class ArrayDistinct(child: Expression)
         // Only need to track null element index when array's element is 
nullable.
         val declareNullTrackVariables = if 
(dataType.asInstanceOf[ArrayType].containsNull) {
           s"""
-             |boolean $foundNullElement = false;
              |int $nullElementIndex = -1;
            """.stripMargin
         } else {
@@ -3479,9 +3504,9 @@ case class ArrayDistinct(child: Expression)
           if (dataType.asInstanceOf[ArrayType].containsNull) {
             s"""
                |if ($array.isNullAt($i)) {
-               |  if (!$foundNullElement) {
+               |  if (!$hashSet.containsNull()) {
                |    $nullElementIndex = $size;
-               |    $foundNullElement = true;
+               |    $hashSet.addNull();
                |    $size++;
                |    $builder.$$plus$$eq($nullValueHolder);
                |  }
@@ -3493,9 +3518,8 @@ case class ArrayDistinct(child: Expression)
             body
           }
 
-        val processArray = withArrayNullAssignment(
+        val body =
           s"""
-             |$jt $value = ${genGetValue(array, i)};
              |if (!$hashSet.contains($hsValueCast$value)) {
              |  if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
              |    break;
@@ -3503,7 +3527,16 @@ case class ArrayDistinct(child: Expression)
              |  $hashSet.add$hsPostFix($hsValueCast$value);
              |  $builder.$$plus$$eq($value);
              |}
-           """.stripMargin)
+           """.stripMargin
+
+        val processArray = withArrayNullAssignment(
+          s"$jt $value = ${genGetValue(array, i)};" +
+            SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet, body,
+              (valueNaN: String) =>
+                s"""
+                   |$size++;
+                   |$builder.$$plus$$eq($valueNaN);
+                   |""".stripMargin))
 
         s"""
            |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag);
@@ -3579,8 +3612,16 @@ case class ArrayUnion(left: Expression, right: 
Expression) extends ArrayBinaryLi
       (array1, array2) =>
         val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
         val hs = new SQLOpenHashSet[Any]()
-        val isNaN = SQLOpenHashSet.isNaN(elementType)
-        val valueNaN = SQLOpenHashSet.valueNaN(elementType)
+        val withNaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
+          (value: Any) =>
+            if (!hs.contains(value)) {
+              if (arrayBuffer.size > 
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+                
ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
+              }
+              arrayBuffer += value
+              hs.add(value)
+            },
+          (valueNaN: Any) => arrayBuffer += valueNaN)
         Seq(array1, array2).foreach { array =>
           var i = 0
           while (i < array.numElements()) {
@@ -3591,20 +3632,7 @@ case class ArrayUnion(left: Expression, right: 
Expression) extends ArrayBinaryLi
               }
             } else {
               val elem = array.get(i, elementType)
-              if (isNaN(elem)) {
-                if (!hs.containsNaN) {
-                  arrayBuffer += valueNaN
-                  hs.addNaN
-                }
-              } else {
-                if (!hs.contains(elem)) {
-                  if (arrayBuffer.size > 
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
-                    
ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
-                  }
-                  arrayBuffer += elem
-                  hs.add(elem)
-                }
-              }
+              withNaNCheckFunc(elem)
             }
             i += 1
           }
@@ -3689,28 +3717,6 @@ case class ArrayUnion(left: Expression, right: 
Expression) extends ArrayBinaryLi
             body
           }
 
-        def withNaNCheck(body: String): String = {
-          (elementType match {
-            case DoubleType =>
-              Some((s"java.lang.Double.isNaN((double)$value)", 
"java.lang.Double.NaN"))
-            case FloatType =>
-              Some((s"java.lang.Float.isNaN((float)$value)", 
"java.lang.Float.NaN"))
-            case _ => None
-          }).map { case (isNaN, valueNaN) =>
-            s"""
-               |if ($isNaN) {
-               |  if (!$hashSet.containsNaN()) {
-               |     $size++;
-               |     $hashSet.addNaN();
-               |     $builder.$$plus$$eq($valueNaN);
-               |  }
-               |} else {
-               |  $body
-               |}
-             """.stripMargin
-          }
-        }.getOrElse(body)
-
         val body =
           s"""
              |if (!$hashSet.contains($hsValueCast$value)) {
@@ -3721,8 +3727,14 @@ case class ArrayUnion(left: Expression, right: 
Expression) extends ArrayBinaryLi
              |  $builder.$$plus$$eq($value);
              |}
            """.stripMargin
-        val processArray =
-          withArrayNullAssignment(s"$jt $value = ${genGetValue(array, i)};" + 
withNaNCheck(body))
+        val processArray = withArrayNullAssignment(
+          s"$jt $value = ${genGetValue(array, i)};" +
+            SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet, body,
+              (valueNaN: String) =>
+                s"""
+                   |$size++;
+                   |$builder.$$plus$$eq($valueNaN);
+                 """.stripMargin))
 
         // Only need to track null element index when result array's element 
is nullable.
         val declareNullTrackVariables = if 
(dataType.asInstanceOf[ArrayType].containsNull) {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala
index 083cfdd..e09cd95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala
@@ -60,21 +60,55 @@ class SQLOpenHashSet[@specialized(Long, Int, Double, Float) 
T: ClassTag](
 }
 
 object SQLOpenHashSet {
-  def isNaN(dataType: DataType): Any => Boolean = {
-    dataType match {
+  def withNaNCheckFunc(
+      dataType: DataType,
+      hashSet: SQLOpenHashSet[Any],
+      handleNotNaN: Any => Unit,
+      handleNaN: Any => Unit): Any => Unit = {
+    val (isNaN, valueNaN) = dataType match {
       case DoubleType =>
-        (value: Any) => 
java.lang.Double.isNaN(value.asInstanceOf[java.lang.Double])
+        ((value: Any) => 
java.lang.Double.isNaN(value.asInstanceOf[java.lang.Double]),
+          java.lang.Double.NaN)
       case FloatType =>
-        (value: Any) => 
java.lang.Float.isNaN(value.asInstanceOf[java.lang.Float])
-      case _ => (_: Any) => false
+        ((value: Any) => 
java.lang.Float.isNaN(value.asInstanceOf[java.lang.Float]),
+          java.lang.Float.NaN)
+      case _ => ((_: Any) => false, null)
     }
+    (value: Any) =>
+      if (isNaN(value)) {
+        if (!hashSet.containsNaN) {
+          hashSet.addNaN
+          handleNaN(valueNaN)
+        }
+      } else {
+        handleNotNaN(value)
+      }
   }
 
-  def valueNaN(dataType: DataType): Any = {
-    dataType match {
-      case DoubleType => java.lang.Double.NaN
-      case FloatType => java.lang.Float.NaN
-      case _ => null
+  def withNaNCheckCode(
+      dataType: DataType,
+      valueName: String,
+      hashSet: String,
+      handleNotNaN: String,
+      handleNaN: String => String): String = {
+    val ret = dataType match {
+      case DoubleType =>
+        Some((s"java.lang.Double.isNaN((double)$valueName)", 
"java.lang.Double.NaN"))
+      case FloatType =>
+        Some((s"java.lang.Float.isNaN((float)$valueName)", 
"java.lang.Float.NaN"))
+      case _ => None
     }
+    ret.map { case (isNaN, valueNaN) =>
+      s"""
+         |if ($isNaN) {
+         |  if (!$hashSet.containsNaN()) {
+         |     $hashSet.addNaN();
+         |     ${handleNaN(valueNaN)}
+         |  }
+         |} else {
+         |  $handleNotNaN
+         |}
+       """.stripMargin
+    }.getOrElse(handleNotNaN)
   }
 }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index b9b3abe..1be9b1f 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -2327,6 +2327,15 @@ class CollectionExpressionsSuite extends SparkFunSuite 
with ExpressionEvalHelper
       Seq(Float.NaN, null, 1f))
   }
 
+  test("SPARK-36741: ArrayDistinct should handle duplicated Double.NaN and 
Float.Nan") {
+    checkEvaluation(ArrayDistinct(
+      Literal.create(Seq(Double.NaN, Double.NaN, null, null, 1d, 1d), 
ArrayType(DoubleType))),
+      Seq(Double.NaN, null, 1d))
+    checkEvaluation(ArrayDistinct(
+      Literal.create(Seq(Float.NaN, Float.NaN, null, null, 1f, 1f), 
ArrayType(FloatType))),
+      Seq(Float.NaN, null, 1f))
+  }
+
   test("SPARK-36755: ArraysOverlap hould handle duplicated Double.NaN and 
Float.Nan") {
     checkEvaluation(ArraysOverlap(
       Literal.apply(Array(Double.NaN, 1d)), Literal.apply(Array(Double.NaN))), 
true)

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to