Repository: spark
Updated Branches:
  refs/heads/branch-2.3 850b9f391 -> fd46a276c


[SPARK-22984] Fix incorrect bitmap copying and offset adjustment in 
GenerateUnsafeRowJoiner

## What changes were proposed in this pull request?

This PR fixes a longstanding correctness bug in `GenerateUnsafeRowJoiner`. This 
class was introduced in https://github.com/apache/spark/pull/7821 (July 2015 / 
Spark 1.5.0+) and is used to combine pairs of UnsafeRows in 
TungstenAggregationIterator, CartesianProductExec, and AppendColumns.

### Bugs fixed by this patch

1. **Incorrect combining of null-tracking bitmaps**: when concatenating two 
UnsafeRows, the implementation is meant to "Concatenate the two bitsets 
together into a single one, taking padding into account". A row with no columns 
has a bitset size of 0, but the code incorrectly assumed that if the left row 
had a non-zero number of fields then the right row would also have at least one 
field, so it copied invalid bytes and treated them as part of the bitset. I'm 
not sure whether this bug was also present in the original implementation or 
whether it was introduced in https://github.com/apache/spark/pull/7892 (which 
fixed another bug in this code).
2. **Incorrect updating of data offsets for null variable-length fields**: 
after updating the bitsets and copying fixed-length and variable-length data, 
we need to adjust the offsets pointing to the start of each variable-length 
field's data. The existing code was _unconditionally_ adding a fixed offset to 
every variable-length field's slot to correct for the new length of the 
combined row, but it is unsafe to do this if the variable-length field has a 
null value: we always represent nulls by storing `0` in the fixed-length slot, 
but this code was incorrectly incrementing those values. This bug was present 
since the original version of `GenerateUnsafeRowJoiner`.
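
The two fixes above can be illustrated with a small standalone sketch. It assumes the layout described in this patch's code comments: a null-tracking bitset made of 8-byte words, and, for each variable-length field, a fixed-length slot that stores the data offset in its upper 32 bits (the lower 32 bits are assumed here to hold the size). The names `OffsetAdjustSketch`, `bitsetWords`, `pack`, and `adjust` are hypothetical and are not part of Spark.

```java
final class OffsetAdjustSketch {
    /** Number of 8-byte bitset words needed for a row with numFields fields. */
    static int bitsetWords(int numFields) {
        return (numFields + 63) / 64;
    }

    /** Packs an offset (upper 32 bits) and a size (lower 32 bits) into one slot. */
    static long pack(int offset, int size) {
        return (((long) offset) << 32) | (size & 0xFFFFFFFFL);
    }

    /**
     * Moves the stored offset forward by shiftBytes. An all-zero slot is
     * treated as a null field and is left untouched, mirroring the
     * `existingOffset != 0` guard added by this patch.
     */
    static long adjust(long slot, long shiftBytes) {
        return slot != 0 ? slot + (shiftBytes << 32) : slot;
    }

    public static void main(String[] args) {
        // A row with zero fields has an empty bitset: there are no words to copy.
        System.out.println(bitsetWords(0));
        System.out.println(bitsetWords(65));

        long nullSlot = 0L;              // null field: slot stays 0
        long emptyString = pack(24, 0);  // empty but non-null: offset is non-zero
        long foo = pack(24, 3);          // three bytes of data at offset 24

        System.out.println(adjust(nullSlot, 16) == 0L);
        System.out.println(adjust(emptyString, 16) == pack(40, 0));
        System.out.println(adjust(foo, 16) == pack(40, 3));
    }
}
```

The empty-string case is the one that makes the zero check a valid stand-in for a null check: its size is zero, but its stored offset is not.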

### Why this bug remained latent for so long

The PR which introduced `GenerateUnsafeRowJoiner` features several randomized 
tests, including tests of the cases where one side of the join has no fields 
and where string-valued fields are null. However, the existing assertions were 
too weak to uncover this bug:

- If a null field has a non-zero value in its fixed-length data slot, this 
does not cause problems for field accesses: the null-tracking bitmap is still 
correct, so we never use the incorrect offset for anything.
- If the null-tracking bitmap is corrupted by joining against a row with no 
fields, the corruption occurs in bit positions past the last field actually 
contained in the row. Thus valid `isNullAt()` calls will not read the 
incorrectly-set bits.

The existing `GenerateUnsafeRowJoinerSuite` tests only exercised `.get()` and 
`isNullAt()` and never checked the UnsafeRows for bit-for-bit equality, so 
these bugs could not fail any assertion. There was even a 
[GenerateUnsafeRowJoinerBitsetSuite](https://github.com/apache/spark/blob/03377d2522776267a07b7d6ae9bddf79a4e0f516/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala), 
but it also missed this problem because it only tested the bitsets end-to-end 
by accessing them through the `UnsafeRow` interface instead of comparing the 
bitsets' bytes.
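
As a rough illustration of why accessor-based assertions were too weak, the sketch below models a single 64-bit bitset word for a two-field row. The stray bit position (5) and the names `StrayBitSketch`, `isNullAt`, and `sameVisibleNulls` are assumptions made for this example, not Spark code.

```java
final class StrayBitSketch {
    /** Reads the null bit for one field from a single bitset word. */
    static boolean isNullAt(long bitsetWord, int ordinal) {
        return (bitsetWord & (1L << ordinal)) != 0;
    }

    /** True if both words agree on the null bits of fields 0..numFields-1. */
    static boolean sameVisibleNulls(long a, long b, int numFields) {
        for (int i = 0; i < numFields; i++) {
            if (isNullAt(a, i) != isNullAt(b, i)) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        int numFields = 2;
        long expected = 0L;        // no field is null
        long corrupted = 1L << 5;  // stray bit past the last real field

        // Per-field checks cannot tell the two words apart ...
        System.out.println(sameVisibleNulls(expected, corrupted, numFields));
        // ... but a raw comparison of the words (or their bytes) can.
        System.out.println(expected == corrupted);
    }
}
```

This is the gap the strengthened tests close by comparing the joiner's output bytes against a row built through a slower reference path.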

### Impact of these bugs

- These bugs cause `equals()` and `hashCode()` to be incorrect for the joined 
rows, which is problematic when `GenerateUnsafeRowJoiner`'s results are used as 
join or grouping keys.
- Chained / repeated invocations of `GenerateUnsafeRowJoiner` may result in 
reads from invalid null bitmap positions, causing fields to incorrectly become 
NULL (see the end-to-end example below).
  - It looks like this generally only happens in `CartesianProductExec`, which 
our query optimizer often avoids executing (usually we try to plan a 
`BroadcastNestedLoopJoin` instead).

### End-to-end test case demonstrating the problem

The following query demonstrates how this bug may result in incorrect query 
results:

```sql
set spark.sql.autoBroadcastJoinThreshold=-1; -- Needed to trigger CartesianProductExec

create table a as select * from values 1;
create table b as select * from values 2;

SELECT
  t3.col1,
  t1.col1
FROM a t1
CROSS JOIN b t2
CROSS JOIN b t3
```

This should return `(2, 1)` but instead was returning `(null, 1)`.

Column pruning ends up trimming off all columns from `t2`, so when `t2` joins 
with another table this triggers the bitmap-copying bug. This incorrect bitmap 
is subsequently copied again when performing the final join, causing the final 
output to have an incorrectly-set null bit for the first field.
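
The chained failure can be sketched with a simplified single-word model of the bitset merge, in which the right row's bits are shifted past the left row's fields and OR-ed in. The position of the stray bit and the names `ChainedJoinSketch`, `joinBitsets`, and `isNullAt` are assumptions for illustration only; the real joiner works on multi-word bitsets in generated code.

```java
final class ChainedJoinSketch {
    /**
     * Simplified merge for rows whose combined bitset fits in one word:
     * keep the left word as-is and place the right row's bits after the
     * left row's fields. Assumes 0 < leftFields < 64.
     */
    static long joinBitsets(long left, int leftFields, long right) {
        return left | (right << leftFields);
    }

    static boolean isNullAt(long word, int ordinal) {
        return (word & (1L << ordinal)) != 0;
    }

    public static void main(String[] args) {
        int leftFields = 1;
        long rightNoNulls = 0L;

        // Healthy left row: one non-null field, no stray bits.
        long clean = joinBitsets(0L, leftFields, rightNoNulls);
        System.out.println(isNullAt(clean, 1));

        // Left row produced by an earlier buggy join: bit 1 is set even
        // though the row only has field 0.
        long strayLeft = 0b10L;
        long bad = joinBitsets(strayLeft, leftFields, rightNoNulls);
        // The stray bit now lines up with the right row's first field.
        System.out.println(isNullAt(bad, 0));
        System.out.println(isNullAt(bad, 1));
    }
}
```

In this model the stray bit is harmless until a second join appends a field at exactly that position, which is consistent with the bug surfacing only after repeated joins.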

## How was this patch tested?

Strengthened the assertions in existing tests in GenerateUnsafeRowJoinerSuite. 
Also verified that the end-to-end test case which uncovered this now passes.

Author: Josh Rosen <[email protected]>

Closes #20181 from 
JoshRosen/SPARK-22984-fix-generate-unsaferow-joiner-bitmap-bugs.

(cherry picked from commit f20131dd35939734fe16b0005a086aa72400893b)
Signed-off-by: Wenchen Fan <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fd46a276
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fd46a276
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fd46a276

Branch: refs/heads/branch-2.3
Commit: fd46a276c04d2a6617c56abe84be80a1be5d7f99
Parents: 850b9f3
Author: Josh Rosen <[email protected]>
Authored: Tue Jan 9 11:49:10 2018 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Tue Jan 9 11:59:39 2018 +0800

----------------------------------------------------------------------
 .../codegen/GenerateUnsafeRowJoiner.scala       | 52 ++++++++++-
 .../codegen/GenerateUnsafeRowJoinerSuite.scala  | 95 ++++++++++++++++++--
 2 files changed, 138 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fd46a276/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
index be5f5a7..febf7b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
@@ -70,7 +70,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
 
     // --------------------- copy bitset from row 1 and row 2 --------------------------- //
     val copyBitset = Seq.tabulate(outputBitsetWords) { i =>
-      val bits = if (bitset1Remainder > 0) {
+      val bits = if (bitset1Remainder > 0 && bitset2Words != 0) {
         if (i < bitset1Words - 1) {
           s"$getLong(obj1, offset1 + ${i * 8})"
         } else if (i == bitset1Words - 1) {
@@ -152,7 +152,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
       } else {
         // Number of bytes to increase for the offset. Note that since in UnsafeRow we store the
         // offset in the upper 32 bit of the words, we can just shift the offset to the left by
-        // 32 and increment that amount in place.
+        // 32 and increment that amount in place. However, we need to handle the important special
+        // case of a null field, in which case the offset should be zero and should not have a
+        // shift added to it.
         val shift =
           if (i < schema1.size) {
             s"${(outputBitsetWords - bitset1Words + schema2.size) * 8}L"
@@ -160,14 +162,55 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
             s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)"
           }
         val cursor = offset + outputBitsetWords * 8 + i * 8
-        s"$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32));\n"
+        // UnsafeRow is a little underspecified, so in what follows we'll treat UnsafeRowWriter's
+        // output as a de-facto specification for the internal layout of data.
+        //
+        // Null-valued fields will always have a data offset of 0 because
+        // UnsafeRowWriter.setNullAt(ordinal) sets the null bit and stores 0 to in field's
+        // position in the fixed-length section of the row. As a result, we must NOT add
+        // `shift` to the offset for null fields.
+        //
+        // We could perform a null-check here by inspecting the null-tracking bitmap, but doing
+        // so could be expensive and will add significant bloat to the generated code. Instead,
+        // we'll rely on the invariant "stored offset == 0 for variable-length data type implies
+        // that the field's value is null."
+        //
+        // To establish that this invariant holds, we'll prove that a non-null field can never
+        // have a stored offset of 0. There are two cases to consider:
+        //
+        //   1. The non-null field's data is of non-zero length: reading this field's value
+        //      must read data from the variable-length section of the row, so the stored offset
+        //      will actually be used in address calculation and must be correct. The offsets
+        //      count bytes from the start of the UnsafeRow so these offsets will always be
+        //      non-zero because the storage of the offsets themselves takes up space at the
+        //      start of the row.
+        //   2. The non-null field's data is of zero length (i.e. its data is empty). In this
+        //      case, we have to worry about the possibility that an arbitrary offset value was
+        //      stored because we never actually read any bytes using this offset and therefore
+        //      would not crash if it was incorrect. The variable-sized data writing paths in
+        //      UnsafeRowWriter unconditionally calls setOffsetAndSize(ordinal, numBytes) with
+        //      no special handling for the case where `numBytes == 0`. Internally,
+        //      setOffsetAndSize computes the offset without taking the size into account. Thus
+        //      the stored offset is the same non-zero offset that would be used if the field's
+        //      dataSize was non-zero (and in (1) above we've shown that case behaves as we
+        //      expect).
+        //
+        // Thus it is safe to perform `existingOffset != 0` checks here in the place of
+        // more expensive null-bit checks.
+        s"""
+           |existingOffset = $getLong(buf, $cursor);
+           |if (existingOffset != 0) {
+           |    $putLong(buf, $cursor, existingOffset + ($shift << 32));
+           |}
+         """.stripMargin
       }
     }
 
     val updateOffsets = ctx.splitExpressions(
       expressions = updateOffset,
       funcName = "copyBitsetFunc",
-      arguments = ("long", "numBytesVariableRow1") :: Nil)
+      arguments = ("long", "numBytesVariableRow1") :: Nil,
+      makeSplitFunction = (s: String) => "long existingOffset;\n" + s)
 
     // ------------------------ Finally, put everything together  --------------------------- //
     val codeBody = s"""
@@ -200,6 +243,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
        |    $copyFixedLengthRow2
        |    $copyVariableLengthRow1
        |    $copyVariableLengthRow2
+       |    long existingOffset;
        |    $updateOffsets
        |
        |    out.pointTo(buf, sizeInBytes);

http://git-wip-us.apache.org/repos/asf/spark/blob/fd46a276/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
index f203f25..75c6bee 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
@@ -22,8 +22,10 @@ import scala.util.Random
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.RandomDataGenerator
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * Test suite for [[GenerateUnsafeRowJoiner]].
@@ -45,6 +47,32 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {
     testConcat(64, 64, fixed)
   }
 
+  test("rows with all empty strings") {
+    val schema = StructType(Seq(
+      StructField("f1", StringType), StructField("f2", StringType)))
+    val row: UnsafeRow = UnsafeProjection.create(schema).apply(
+      InternalRow(UTF8String.EMPTY_UTF8, UTF8String.EMPTY_UTF8))
+     testConcat(schema, row, schema, row)
+  }
+
+  test("rows with all empty int arrays") {
+    val schema = StructType(Seq(
+      StructField("f1", ArrayType(IntegerType)), StructField("f2", ArrayType(IntegerType))))
+    val emptyIntArray =
+      ExpressionEncoder[Array[Int]]().resolveAndBind().toRow(Array.emptyIntArray).getArray(0)
+    val row: UnsafeRow = UnsafeProjection.create(schema).apply(
+      InternalRow(emptyIntArray, emptyIntArray))
+    testConcat(schema, row, schema, row)
+  }
+
+  test("alternating empty and non-empty strings") {
+    val schema = StructType(Seq(
+      StructField("f1", StringType), StructField("f2", StringType)))
+    val row: UnsafeRow = UnsafeProjection.create(schema).apply(
+      InternalRow(UTF8String.EMPTY_UTF8, UTF8String.fromString("foo")))
+    testConcat(schema, row, schema, row)
+  }
+
   test("randomized fix width types") {
     for (i <- 0 until 20) {
       testConcatOnce(Random.nextInt(100), Random.nextInt(100), fixed)
@@ -94,27 +122,84 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {
     val extRow2 = RandomDataGenerator.forType(schema2, nullable = false).get.apply()
     val row1 = converter1.apply(internalConverter1.apply(extRow1).asInstanceOf[InternalRow])
     val row2 = converter2.apply(internalConverter2.apply(extRow2).asInstanceOf[InternalRow])
+    testConcat(schema1, row1, schema2, row2)
+  }
+
+  private def testConcat(
+      schema1: StructType,
+      row1: UnsafeRow,
+      schema2: StructType,
+      row2: UnsafeRow) {
 
     // Run the joiner.
     val mergedSchema = StructType(schema1 ++ schema2)
     val concater = GenerateUnsafeRowJoiner.create(schema1, schema2)
-    val output = concater.join(row1, row2)
+    val output: UnsafeRow = concater.join(row1, row2)
+
+    // We'll also compare to an UnsafeRow produced with JoinedRow + UnsafeProjection. This ensures
+    // that unused space in the row (e.g. leftover bits in the null-tracking bitmap) is written
+    // correctly.
+    val expectedOutput: UnsafeRow = {
+      val joinedRowProjection = UnsafeProjection.create(mergedSchema)
+      val joined = new JoinedRow()
+      joinedRowProjection.apply(joined.apply(row1, row2))
+    }
 
     // Test everything equals ...
     for (i <- mergedSchema.indices) {
+      val dataType = mergedSchema(i).dataType
       if (i < schema1.size) {
         assert(output.isNullAt(i) === row1.isNullAt(i))
         if (!output.isNullAt(i)) {
-          assert(output.get(i, mergedSchema(i).dataType) === row1.get(i, mergedSchema(i).dataType))
+          assert(output.get(i, dataType) === row1.get(i, dataType))
+          assert(output.get(i, dataType) === expectedOutput.get(i, dataType))
         }
       } else {
         assert(output.isNullAt(i) === row2.isNullAt(i - schema1.size))
         if (!output.isNullAt(i)) {
-          assert(output.get(i, mergedSchema(i).dataType) ===
-            row2.get(i - schema1.size, mergedSchema(i).dataType))
+          assert(output.get(i, dataType) === row2.get(i - schema1.size, dataType))
+          assert(output.get(i, dataType) === expectedOutput.get(i, dataType))
         }
       }
     }
+
+
+    assert(
+      expectedOutput.getSizeInBytes == output.getSizeInBytes,
+      "output isn't same size in bytes as slow path")
+
+    // Compare the UnsafeRows byte-by-byte so that we can print more useful debug information in
+    // case this assertion fails:
+    val actualBytes = output.getBaseObject.asInstanceOf[Array[Byte]]
+      .take(output.getSizeInBytes)
+    val expectedBytes = expectedOutput.getBaseObject.asInstanceOf[Array[Byte]]
+      .take(expectedOutput.getSizeInBytes)
+
+    val bitsetWidth = UnsafeRow.calculateBitSetWidthInBytes(expectedOutput.numFields())
+    val actualBitset = actualBytes.take(bitsetWidth)
+    val expectedBitset = expectedBytes.take(bitsetWidth)
+    assert(actualBitset === expectedBitset, "bitsets were not equal")
+
+    val fixedLengthSize = expectedOutput.numFields() * 8
+    val actualFixedLength = actualBytes.slice(bitsetWidth, bitsetWidth + fixedLengthSize)
+    val expectedFixedLength = expectedBytes.slice(bitsetWidth, bitsetWidth + fixedLengthSize)
+    if (actualFixedLength !== expectedFixedLength) {
+      actualFixedLength.grouped(8)
+        .zip(expectedFixedLength.grouped(8))
+        .zip(mergedSchema.fields.toIterator)
+        .foreach {
+          case ((actual, expected), field) =>
+            assert(actual === expected, s"Fixed length sections are not equal for field $field")
+      }
+      fail("Fixed length sections were not equal")
+    }
+
+    val variableLengthStart = bitsetWidth + fixedLengthSize
+    val actualVariableLength = actualBytes.drop(variableLengthStart)
+    val expectedVariableLength = expectedBytes.drop(variableLengthStart)
+    assert(actualVariableLength === expectedVariableLength, "fixed length sections were not equal")
+
+    assert(output.hashCode() == expectedOutput.hashCode(), "hash codes were not equal")
   }
 
 }

