Repository: spark
Updated Branches:
  refs/heads/branch-2.4 ea4068a0a -> 443d12dbb


[SPARK-25538][SQL] Zero-out all bytes when writing decimal

## What changes were proposed in this pull request?

In #20850, when writing non-null decimals, instead of zeroing all 16 
allocated bytes, we zero out only the padding bytes. Since we always allocate 
16 bytes, if the number of bytes needed for a decimal is lower than 9, then 
bytes 8 through 15 are not zeroed.

I see 2 solutions here:
 - we can zero-out all the bytes in advance as it was done before #20850 (safer 
solution IMHO);
 - we can allocate only the needed bytes (may be a bit more efficient in terms 
of memory used, but I have not investigated the feasibility of this option).

Hence I propose here the first solution in order to fix the correctness issue. 
We can eventually switch to the second if we think it is more efficient later.

## How was this patch tested?

Running the test attached in the JIRA + added UT

Closes #22602 from mgaido91/SPARK-25582.

Authored-by: Marco Gaido <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
(cherry picked from commit d7ae36a810bfcbedfe7360eb2cdbbc3ca970e4d0)
Signed-off-by: Dongjoon Hyun <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/443d12db
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/443d12db
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/443d12db

Branch: refs/heads/branch-2.4
Commit: 443d12dbbe40e932978a9a1a811128da8afba89b
Parents: ea4068a
Author: Marco Gaido <[email protected]>
Authored: Wed Oct 3 07:28:34 2018 -0700
Committer: Dongjoon Hyun <[email protected]>
Committed: Wed Oct 3 07:28:48 2018 -0700

----------------------------------------------------------------------
 .../expressions/codegen/UnsafeRowWriter.java    | 10 ++--
 .../codegen/UnsafeRowWriterSuite.scala          | 53 ++++++++++++++++++++
 2 files changed, 57 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/443d12db/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
index 71c49d8..3960d6d 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
@@ -185,13 +185,13 @@ public final class UnsafeRowWriter extends UnsafeWriter {
       // grow the global buffer before writing data.
       holder.grow(16);
 
+      // always zero-out the 16-byte buffer
+      Platform.putLong(getBuffer(), cursor(), 0L);
+      Platform.putLong(getBuffer(), cursor() + 8, 0L);
+
       // Make sure Decimal object has the same scale as DecimalType.
       // Note that we may pass in null Decimal object to set null for it.
       if (input == null || !input.changePrecision(precision, scale)) {
-        // zero-out the bytes
-        Platform.putLong(getBuffer(), cursor(), 0L);
-        Platform.putLong(getBuffer(), cursor() + 8, 0L);
-
         BitSetMethods.set(getBuffer(), startingOffset, ordinal);
         // keep the offset for future update
         setOffsetAndSize(ordinal, 0);
@@ -200,8 +200,6 @@ public final class UnsafeRowWriter extends UnsafeWriter {
         final int numBytes = bytes.length;
         assert numBytes <= 16;
 
-        zeroOutPaddingBytes(numBytes);
-
         // Write the bytes to the variable length portion.
         Platform.copyMemory(
           bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes);

http://git-wip-us.apache.org/repos/asf/spark/blob/443d12db/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
new file mode 100644
index 0000000..fb651b7
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types.Decimal
+
+class UnsafeRowWriterSuite extends SparkFunSuite {
+
+  def checkDecimalSizeInBytes(decimal: Decimal, numBytes: Int): Unit = {
+    assert(decimal.toJavaBigDecimal.unscaledValue().toByteArray.length == 
numBytes)
+  }
+
+  test("SPARK-25538: zero-out all bits for decimals") {
+    val decimal1 = Decimal(0.431)
+    decimal1.changePrecision(38, 18)
+    checkDecimalSizeInBytes(decimal1, 8)
+
+    val decimal2 = Decimal(123456789.1232456789)
+    decimal2.changePrecision(38, 18)
+    checkDecimalSizeInBytes(decimal2, 11)
+    // On an UnsafeRowWriter we write decimal2 first and then decimal1
+    val unsafeRowWriter1 = new UnsafeRowWriter(1)
+    unsafeRowWriter1.resetRowWriter()
+    unsafeRowWriter1.write(0, decimal2, decimal2.precision, decimal2.scale)
+    unsafeRowWriter1.reset()
+    unsafeRowWriter1.write(0, decimal1, decimal1.precision, decimal1.scale)
+    val res1 = unsafeRowWriter1.getRow
+    // On a second UnsafeRowWriter we write directly decimal1
+    val unsafeRowWriter2 = new UnsafeRowWriter(1)
+    unsafeRowWriter2.resetRowWriter()
+    unsafeRowWriter2.write(0, decimal1, decimal1.precision, decimal1.scale)
+    val res2 = unsafeRowWriter2.getRow
+    // The two rows should be equal
+    assert(res1 == res2)
+  }
+
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to