This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e356f6a [SPARK-36741][SQL] ArrayDistinct should handle duplicated Double.NaN and Float.NaN
e356f6a is described below
commit e356f6aa1119f4ceeafc7bcdea5f7b8f1f010638
Author: Angerszhuuuu <[email protected]>
AuthorDate: Fri Sep 17 20:48:17 2021 +0800
[SPARK-36741][SQL] ArrayDistinct should handle duplicated Double.NaN and Float.NaN
### What changes were proposed in this pull request?
For the query
```
select array_distinct(array(cast('nan' as double), cast('nan' as double)))
```
This returns [NaN, NaN], but it should return [NaN].
The root cause is that `OpenHashSet` cannot deduplicate `Double.NaN` and `Float.NaN` either.
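For context, here is a minimal REPL sketch (an illustration only, not Spark code) of the IEEE 754 semantics that break equality-based hash probing:
```
// Primitive comparison follows IEEE 754: NaN is never equal to itself,
// so a hash set probing with `==` cannot find a previously inserted NaN.
val x = Double.NaN
println(x == x)                                // false
// Boxed equality treats NaN as equal to itself, which is the
// deduplication behavior array_distinct needs.
println(java.lang.Double.valueOf(x).equals(x)) // true
// The explicit check that NaN-aware handling relies on:
println(java.lang.Double.isNaN(x))             // true
```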
This PR fixes the issue following the approach of https://github.com/apache/spark/pull/33955.
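To illustrate the approach, here is a simplified sketch of how the interpreted path below uses the new helper (names follow the diff; the length-overflow check is elided):
```
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.util.SQLOpenHashSet

val hs = new SQLOpenHashSet[Any]()
val out = new ArrayBuffer[Any]
val dedup = SQLOpenHashSet.withNaNCheckFunc(DoubleType, hs,
  (value: Any) =>                     // non-NaN values: ordinary set-based dedup
    if (!hs.contains(value)) {
      out += value
      hs.add(value)
    },
  (valueNaN: Any) => out += valueNaN) // invoked for the first NaN only; the
                                      // helper records the NaN in the set itself
Seq(Double.NaN, Double.NaN, 1d).foreach(dedup)
// out is now ArrayBuffer(NaN, 1.0)
```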
### Why are the changes needed?
Fixes a correctness bug in `array_distinct`.
### Does this PR introduce _any_ user-facing change?
Yes. `ArrayDistinct` will no longer return duplicated `NaN` values.
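Concretely (a sketch, assuming a running `SparkSession` named `spark`):
```
// One-row result: before this patch the array column held [NaN, NaN];
// with this patch it holds [NaN].
spark.sql(
  "select array_distinct(array(cast('nan' as double), cast('nan' as double)))"
).show(false)
```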
### How was this patch tested?
Added a unit test.
Closes #33993 from AngersZhuuuu/SPARK-36741.
Authored-by: Angerszhuuuu <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../expressions/collectionOperations.scala | 124 +++++++++++----------
.../org/apache/spark/sql/util/SQLOpenHashSet.scala | 54 +++++++--
.../expressions/CollectionExpressionsSuite.scala | 9 ++
3 files changed, 121 insertions(+), 66 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index ba000a3..a50263c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -3412,32 +3412,59 @@ case class ArrayDistinct(child: Expression)
}
override def nullSafeEval(array: Any): Any = {
- val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
+ val data = array.asInstanceOf[ArrayData]
doEvaluation(data)
}
@transient private lazy val doEvaluation = if (TypeUtils.typeWithProperEquals(elementType)) {
- (data: Array[AnyRef]) => new GenericArrayData(data.distinct.asInstanceOf[Array[Any]])
+ (array: ArrayData) =>
+ val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
+ val hs = new SQLOpenHashSet[Any]()
+ val withNaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
+ (value: Any) =>
+ if (!hs.contains(value)) {
+ if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+ ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
+ }
+ arrayBuffer += value
+ hs.add(value)
+ },
+ (valueNaN: Any) => arrayBuffer += valueNaN)
+ var i = 0
+ while (i < array.numElements()) {
+ if (array.isNullAt(i)) {
+ if (!hs.containsNull) {
+ hs.addNull
+ arrayBuffer += null
+ }
+ } else {
+ val elem = array.get(i, elementType)
+ withNaNCheckFunc(elem)
+ }
+ i += 1
+ }
+ new GenericArrayData(arrayBuffer.toSeq)
} else {
- (data: Array[AnyRef]) => {
+ (data: ArrayData) => {
+ val array = data.toArray[AnyRef](elementType)
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[AnyRef]
var alreadyStoredNull = false
- for (i <- 0 until data.length) {
- if (data(i) != null) {
+ for (i <- 0 until array.length) {
+ if (array(i) != null) {
var found = false
var j = 0
while (!found && j < arrayBuffer.size) {
val va = arrayBuffer(j)
- found = (va != null) && ordering.equiv(va, data(i))
+ found = (va != null) && ordering.equiv(va, array(i))
j += 1
}
if (!found) {
- arrayBuffer += data(i)
+ arrayBuffer += array(i)
}
} else {
// De-duplicate the null values.
if (!alreadyStoredNull) {
- arrayBuffer += data(i)
+ arrayBuffer += array(i)
alreadyStoredNull = true
}
}
@@ -3456,10 +3483,9 @@ case class ArrayDistinct(child: Expression)
val ptName = CodeGenerator.primitiveTypeName(jt)
nullSafeCodeGen(ctx, ev, (array) => {
- val foundNullElement = ctx.freshName("foundNullElement")
val nullElementIndex = ctx.freshName("nullElementIndex")
val builder = ctx.freshName("builder")
- val openHashSet = classOf[OpenHashSet[_]].getName
+ val openHashSet = classOf[SQLOpenHashSet[_]].getName
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
val hashSet = ctx.freshName("hashSet")
val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
@@ -3468,7 +3494,6 @@ case class ArrayDistinct(child: Expression)
// Only need to track null element index when array's element is nullable.
val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
s"""
- |boolean $foundNullElement = false;
|int $nullElementIndex = -1;
""".stripMargin
} else {
@@ -3479,9 +3504,9 @@ case class ArrayDistinct(child: Expression)
if (dataType.asInstanceOf[ArrayType].containsNull) {
s"""
|if ($array.isNullAt($i)) {
- | if (!$foundNullElement) {
+ | if (!$hashSet.containsNull()) {
| $nullElementIndex = $size;
- | $foundNullElement = true;
+ | $hashSet.addNull();
| $size++;
| $builder.$$plus$$eq($nullValueHolder);
| }
@@ -3493,9 +3518,8 @@ case class ArrayDistinct(child: Expression)
body
}
- val processArray = withArrayNullAssignment(
+ val body =
s"""
- |$jt $value = ${genGetValue(array, i)};
|if (!$hashSet.contains($hsValueCast$value)) {
| if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| break;
@@ -3503,7 +3527,16 @@ case class ArrayDistinct(child: Expression)
| $hashSet.add$hsPostFix($hsValueCast$value);
| $builder.$$plus$$eq($value);
|}
- """.stripMargin)
+ """.stripMargin
+
+ val processArray = withArrayNullAssignment(
+ s"$jt $value = ${genGetValue(array, i)};" +
+ SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet, body,
+ (valueNaN: String) =>
+ s"""
+ |$size++;
+ |$builder.$$plus$$eq($valueNaN);
+ |""".stripMargin))
s"""
|$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag);
@@ -3579,8 +3612,16 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
(array1, array2) =>
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
val hs = new SQLOpenHashSet[Any]()
- val isNaN = SQLOpenHashSet.isNaN(elementType)
- val valueNaN = SQLOpenHashSet.valueNaN(elementType)
+ val withNaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
+ (value: Any) =>
+ if (!hs.contains(value)) {
+ if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+ ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
+ }
+ arrayBuffer += value
+ hs.add(value)
+ },
+ (valueNaN: Any) => arrayBuffer += valueNaN)
Seq(array1, array2).foreach { array =>
var i = 0
while (i < array.numElements()) {
@@ -3591,20 +3632,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
}
} else {
val elem = array.get(i, elementType)
- if (isNaN(elem)) {
- if (!hs.containsNaN) {
- arrayBuffer += valueNaN
- hs.addNaN
- }
- } else {
- if (!hs.contains(elem)) {
- if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
- ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
- }
- arrayBuffer += elem
- hs.add(elem)
- }
- }
+ withNaNCheckFunc(elem)
}
i += 1
}
@@ -3689,28 +3717,6 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
body
}
- def withNaNCheck(body: String): String = {
- (elementType match {
- case DoubleType =>
- Some((s"java.lang.Double.isNaN((double)$value)", "java.lang.Double.NaN"))
- case FloatType =>
- Some((s"java.lang.Float.isNaN((float)$value)", "java.lang.Float.NaN"))
- case _ => None
- }).map { case (isNaN, valueNaN) =>
- s"""
- |if ($isNaN) {
- | if (!$hashSet.containsNaN()) {
- | $size++;
- | $hashSet.addNaN();
- | $builder.$$plus$$eq($valueNaN);
- | }
- |} else {
- | $body
- |}
- """.stripMargin
- }
- }.getOrElse(body)
-
val body =
s"""
|if (!$hashSet.contains($hsValueCast$value)) {
@@ -3721,8 +3727,14 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
| $builder.$$plus$$eq($value);
|}
""".stripMargin
- val processArray =
- withArrayNullAssignment(s"$jt $value = ${genGetValue(array, i)};" + withNaNCheck(body))
+ val processArray = withArrayNullAssignment(
+ s"$jt $value = ${genGetValue(array, i)};" +
+ SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet, body,
+ (valueNaN: String) =>
+ s"""
+ |$size++;
+ |$builder.$$plus$$eq($valueNaN);
+ """.stripMargin))
// Only need to track null element index when result array's element is nullable.
val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala
index 083cfdd..e09cd95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala
@@ -60,21 +60,55 @@ class SQLOpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
}
object SQLOpenHashSet {
- def isNaN(dataType: DataType): Any => Boolean = {
- dataType match {
+ def withNaNCheckFunc(
+ dataType: DataType,
+ hashSet: SQLOpenHashSet[Any],
+ handleNotNaN: Any => Unit,
+ handleNaN: Any => Unit): Any => Unit = {
+ val (isNaN, valueNaN) = dataType match {
case DoubleType =>
- (value: Any) => java.lang.Double.isNaN(value.asInstanceOf[java.lang.Double])
+ ((value: Any) => java.lang.Double.isNaN(value.asInstanceOf[java.lang.Double]),
+ java.lang.Double.NaN)
case FloatType =>
- (value: Any) => java.lang.Float.isNaN(value.asInstanceOf[java.lang.Float])
- case _ => (_: Any) => false
+ ((value: Any) => java.lang.Float.isNaN(value.asInstanceOf[java.lang.Float]),
+ java.lang.Float.NaN)
+ case _ => ((_: Any) => false, null)
}
+ (value: Any) =>
+ if (isNaN(value)) {
+ if (!hashSet.containsNaN) {
+ hashSet.addNaN
+ handleNaN(valueNaN)
+ }
+ } else {
+ handleNotNaN(value)
+ }
}
- def valueNaN(dataType: DataType): Any = {
- dataType match {
- case DoubleType => java.lang.Double.NaN
- case FloatType => java.lang.Float.NaN
- case _ => null
+ def withNaNCheckCode(
+ dataType: DataType,
+ valueName: String,
+ hashSet: String,
+ handleNotNaN: String,
+ handleNaN: String => String): String = {
+ val ret = dataType match {
+ case DoubleType =>
+ Some((s"java.lang.Double.isNaN((double)$valueName)", "java.lang.Double.NaN"))
+ case FloatType =>
+ Some((s"java.lang.Float.isNaN((float)$valueName)", "java.lang.Float.NaN"))
+ case _ => None
}
+ ret.map { case (isNaN, valueNaN) =>
+ s"""
+ |if ($isNaN) {
+ | if (!$hashSet.containsNaN()) {
+ | $hashSet.addNaN();
+ | ${handleNaN(valueNaN)}
+ | }
+ |} else {
+ | $handleNotNaN
+ |}
+ """.stripMargin
+ }.getOrElse(handleNotNaN)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index b9b3abe..1be9b1f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -2327,6 +2327,15 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
Seq(Float.NaN, null, 1f))
}
+ test("SPARK-36741: ArrayDistinct should handle duplicated Double.NaN and Float.NaN") {
+ checkEvaluation(ArrayDistinct(
+ Literal.create(Seq(Double.NaN, Double.NaN, null, null, 1d, 1d), ArrayType(DoubleType))),
+ Seq(Double.NaN, null, 1d))
+ checkEvaluation(ArrayDistinct(
+ Literal.create(Seq(Float.NaN, Float.NaN, null, null, 1f, 1f), ArrayType(FloatType))),
+ Seq(Float.NaN, null, 1f))
+ }
+
test("SPARK-36755: ArraysOverlap should handle duplicated Double.NaN and Float.NaN") {
checkEvaluation(ArraysOverlap(
Literal.apply(Array(Double.NaN, 1d)), Literal.apply(Array(Double.NaN))),
true)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]