Repository: spark Updated Branches: refs/heads/branch-2.4 ea4068a0a -> 443d12dbb
[SPARK-25538][SQL] Zero-out all bytes when writing decimal ## What changes were proposed in this pull request? In #20850 when writing non-null decimals, instead of zero-ing all the 16 allocated bytes, we zero-out only the padding bytes. Since we always allocate 16 bytes, if the number of bytes needed for a decimal is lower than 9, then this means that the bytes between 8 and 16 are not zero-ed. I see 2 solutions here: - we can zero-out all the bytes in advance as it was done before #20850 (safer solution IMHO); - we can allocate only the needed bytes (may be a bit more efficient in terms of memory used, but I have not investigated the feasibility of this option). Hence I propose here the first solution in order to fix the correctness issue. We can eventually switch to the second if we think it is more efficient later. ## How was this patch tested? Running the test attached in the JIRA + added UT Closes #22602 from mgaido91/SPARK-25582. Authored-by: Marco Gaido <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]> (cherry picked from commit d7ae36a810bfcbedfe7360eb2cdbbc3ca970e4d0) Signed-off-by: Dongjoon Hyun <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/443d12db Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/443d12db Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/443d12db Branch: refs/heads/branch-2.4 Commit: 443d12dbbe40e932978a9a1a811128da8afba89b Parents: ea4068a Author: Marco Gaido <[email protected]> Authored: Wed Oct 3 07:28:34 2018 -0700 Committer: Dongjoon Hyun <[email protected]> Committed: Wed Oct 3 07:28:48 2018 -0700 ---------------------------------------------------------------------- .../expressions/codegen/UnsafeRowWriter.java | 10 ++-- .../codegen/UnsafeRowWriterSuite.scala | 53 ++++++++++++++++++++ 2 files changed, 57 insertions(+), 6 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/spark/blob/443d12db/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 71c49d8..3960d6d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -185,13 +185,13 @@ public final class UnsafeRowWriter extends UnsafeWriter { // grow the global buffer before writing data. holder.grow(16); + // always zero-out the 16-byte buffer + Platform.putLong(getBuffer(), cursor(), 0L); + Platform.putLong(getBuffer(), cursor() + 8, 0L); + // Make sure Decimal object has the same scale as DecimalType. // Note that we may pass in null Decimal object to set null for it. if (input == null || !input.changePrecision(precision, scale)) { - // zero-out the bytes - Platform.putLong(getBuffer(), cursor(), 0L); - Platform.putLong(getBuffer(), cursor() + 8, 0L); - BitSetMethods.set(getBuffer(), startingOffset, ordinal); // keep the offset for future update setOffsetAndSize(ordinal, 0); @@ -200,8 +200,6 @@ public final class UnsafeRowWriter extends UnsafeWriter { final int numBytes = bytes.length; assert numBytes <= 16; - zeroOutPaddingBytes(numBytes); - // Write the bytes to the variable length portion. 
Platform.copyMemory( bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes); http://git-wip-us.apache.org/repos/asf/spark/blob/443d12db/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala new file mode 100644 index 0000000..fb651b7 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.Decimal + +class UnsafeRowWriterSuite extends SparkFunSuite { + + def checkDecimalSizeInBytes(decimal: Decimal, numBytes: Int): Unit = { + assert(decimal.toJavaBigDecimal.unscaledValue().toByteArray.length == numBytes) + } + + test("SPARK-25538: zero-out all bits for decimals") { + val decimal1 = Decimal(0.431) + decimal1.changePrecision(38, 18) + checkDecimalSizeInBytes(decimal1, 8) + + val decimal2 = Decimal(123456789.1232456789) + decimal2.changePrecision(38, 18) + checkDecimalSizeInBytes(decimal2, 11) + // On an UnsafeRowWriter we write decimal2 first and then decimal1 + val unsafeRowWriter1 = new UnsafeRowWriter(1) + unsafeRowWriter1.resetRowWriter() + unsafeRowWriter1.write(0, decimal2, decimal2.precision, decimal2.scale) + unsafeRowWriter1.reset() + unsafeRowWriter1.write(0, decimal1, decimal1.precision, decimal1.scale) + val res1 = unsafeRowWriter1.getRow + // On a second UnsafeRowWriter we write directly decimal1 + val unsafeRowWriter2 = new UnsafeRowWriter(1) + unsafeRowWriter2.resetRowWriter() + unsafeRowWriter2.write(0, decimal1, decimal1.precision, decimal1.scale) + val res2 = unsafeRowWriter2.getRow + // The two rows should be equal + assert(res1 == res2) + } + +} --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
