cloud-fan closed pull request #23043: [SPARK-26021][SQL] replace minus zero with zero in Platform.putDouble/Float
URL: https://github.com/apache/spark/pull/23043
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
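For context on the bug: 0.0 and -0.0 compare equal with ==, but their
IEEE 754 bit patterns differ, so code that hashes or compares raw bytes
(as Spark's UnsafeRow-based grouping does) sees them as two distinct
keys. A minimal standalone illustration, not part of the patch:

    // Why -0.0 breaks byte-level grouping: == says equal, the bits disagree.
    public class MinusZeroDemo {
      public static void main(String[] args) {
        System.out.println(0.0d == -0.0d);                  // true
        System.out.println(Double.doubleToLongBits(0.0d));  // 0
        System.out.println(Double.doubleToLongBits(-0.0d)); // -9223372036854775808 (sign bit set)
        System.out.println(Float.floatToIntBits(0.0f));     // 0
        System.out.println(Float.floatToIntBits(-0.0f));    // -2147483648 (sign bit set)
      }
    }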
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
index aca6fca00c48b..bc94f2171228a 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
@@ -120,6 +120,11 @@ public static float getFloat(Object object, long offset) {
}
public static void putFloat(Object object, long offset, float value) {
+ if (Float.isNaN(value)) {
+ value = Float.NaN;
+ } else if (value == -0.0f) {
+ value = 0.0f;
+ }
_UNSAFE.putFloat(object, offset, value);
}
@@ -128,6 +133,11 @@ public static double getDouble(Object object, long offset) {
}
public static void putDouble(Object object, long offset, double value) {
+ if (Double.isNaN(value)) {
+ value = Double.NaN;
+ } else if (value == -0.0d) {
+ value = 0.0d;
+ }
_UNSAFE.putDouble(object, offset, value);
}
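The isNaN branch above matters because NaN is not a single bit pattern:
IEEE 754 permits many NaN payloads, and Double.NaN is merely the
canonical one. An illustrative snippet (runnable in jshell; the payload
0x7ff8000000000123L is an arbitrary example, not taken from the patch):

    // A non-canonical NaN: isNaN agrees, the raw bits do not.
    double oddNaN = Double.longBitsToDouble(0x7ff8000000000123L);
    System.out.println(Double.isNaN(oddNaN));                     // true
    System.out.println(Double.doubleToRawLongBits(oddNaN)
        == Double.doubleToRawLongBits(Double.NaN));               // false: payloads differ
    // After the putDouble normalization above, both NaNs are stored
    // with the canonical bits 0x7ff8000000000000L.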
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
index 3ad9ac7b4de9c..ab34324eb54cc 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
@@ -157,4 +157,18 @@ public void heapMemoryReuse() {
Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7);
Assert.assertEquals(obj3, onheap4.getBaseObject());
}
+
+ @Test
+ // SPARK-26021
+ public void writeMinusZeroIsReplacedWithZero() {
+ byte[] doubleBytes = new byte[Double.BYTES];
+ byte[] floatBytes = new byte[Float.BYTES];
+ Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, -0.0d);
+ Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f);
+ double doubleFromPlatform = Platform.getDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET);
+ float floatFromPlatform = Platform.getFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET);
+
+ Assert.assertEquals(Double.doubleToLongBits(0.0d), Double.doubleToLongBits(doubleFromPlatform));
+ Assert.assertEquals(Float.floatToIntBits(0.0f), Float.floatToIntBits(floatFromPlatform));
+ }
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index a76e6ef8c91c1..9bf9452855f5f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -224,9 +224,6 @@ public void setLong(int ordinal, long value) {
public void setDouble(int ordinal, double value) {
assertIndexIsValid(ordinal);
setNotNullAt(ordinal);
- if (Double.isNaN(value)) {
- value = Double.NaN;
- }
Platform.putDouble(baseObject, getFieldOffset(ordinal), value);
}
@@ -255,9 +252,6 @@ public void setByte(int ordinal, byte value) {
public void setFloat(int ordinal, float value) {
assertIndexIsValid(ordinal);
setNotNullAt(ordinal);
- if (Float.isNaN(value)) {
- value = Float.NaN;
- }
Platform.putFloat(baseObject, getFieldOffset(ordinal), value);
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
index 2781655002000..95263a0da95a8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
@@ -199,16 +199,10 @@ protected final void writeLong(long offset, long value) {
}
protected final void writeFloat(long offset, float value) {
- if (Float.isNaN(value)) {
- value = Float.NaN;
- }
Platform.putFloat(getBuffer(), offset, value);
}
protected final void writeDouble(long offset, double value) {
- if (Double.isNaN(value)) {
- value = Double.NaN;
- }
Platform.putDouble(getBuffer(), offset, value);
}
}
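The deletions in UnsafeRow and UnsafeWriter are safe because the
normalization now lives in Platform.putFloat/putDouble, the single
choke point both writers call into. A sketch of the end-to-end effect,
mirroring the new unit test (assumes org.apache.spark.unsafe.Platform
from this patch on the classpath):

    byte[] buf = new byte[Double.BYTES];
    Platform.putDouble(buf, Platform.BYTE_ARRAY_OFFSET, -0.0d);   // normalized on write
    double read = Platform.getDouble(buf, Platform.BYTE_ARRAY_OFFSET);
    System.out.println(Double.doubleToRawLongBits(read) == 0L);   // true: stored as +0.0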
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index d9ba6e2ce5120..ff64edcd07f4b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -723,4 +723,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
"grouping expressions: [current_date(None)], value: [key: int, value: string], " +
"type: GroupBy]"))
}
+
+ test("SPARK-26021: Double and Float 0.0/-0.0 should be equal when grouping") {
+ val colName = "i"
+ val doubles = Seq(0.0d, -0.0d, 0.0d).toDF(colName).groupBy(colName).count().collect()
+ val floats = Seq(0.0f, -0.0f, 0.0f).toDF(colName).groupBy(colName).count().collect()
+
+ assert(doubles.length == 1)
+ assert(floats.length == 1)
+ // using compare since 0.0 == -0.0 is true
+ assert(java.lang.Double.compare(doubles(0).getDouble(0), 0.0d) == 0)
+ assert(java.lang.Float.compare(floats(0).getFloat(0), 0.0f) == 0)
+ assert(doubles(0).getLong(1) == 3)
+ assert(floats(0).getLong(1) == 3)
+ }
}
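The test asserts through java.lang.Double.compare rather than ==
because compare, unlike ==, tells the two zeros apart (it orders -0.0
below 0.0), so compare(result, 0.0d) == 0 proves the grouped key is
truly +0.0. Runnable in jshell:

    System.out.println(Double.compare(-0.0d, 0.0d)); // -1: compare distinguishes the zeros
    System.out.println(Double.compare(0.0d, 0.0d));  //  0
    System.out.println(-0.0d == 0.0d);               // true: == cannot catch a leaked -0.0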
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index baca9c1cfb9a0..8ba67239fb907 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -289,7 +289,7 @@ object QueryTest {
def prepareRow(row: Row): Row = {
Row.fromSeq(row.toSeq.map {
case null => null
- case d: java.math.BigDecimal => BigDecimal(d)
+ case bd: java.math.BigDecimal => BigDecimal(bd)
// Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+
case seq: Seq[_] => seq.map {
case b: java.lang.Byte => b.byteValue
@@ -303,6 +303,9 @@ object QueryTest {
// Convert array to Seq for easy equality check.
case b: Array[_] => b.toSeq
case r: Row => prepareRow(r)
+ // spark treats -0.0 as 0.0
+ case d: Double if d == -0.0d => 0.0d
+ case f: Float if f == -0.0f => 0.0f
case o => o
})
}
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]