Repository: spark
Updated Branches:
  refs/heads/master bc2767df2 -> f7e26d788


[SPARK-16922] [SPARK-17211] [SQL] make the address of values portable in 
LongToUnsafeRowMap

## What changes were proposed in this pull request?

In LongToUnsafeRowMap, we use the offset of a value as a pointer, stored in an
array and also in the page for chained values. The offset is not portable,
because Platform.LONG_ARRAY_OFFSET differs between JVMs with different heap
sizes, so a deserialized LongToUnsafeRowMap will be corrupt.

This PR changes the map to use a portable address (one that does not include
Platform.LONG_ARRAY_OFFSET).

## How was this patch tested?

Added a test case with randomly generated keys to improve the coverage. This
test is not a regression test, however; a regression test could require a Spark
cluster that has at least a 32G heap in the driver or executor.

Author: Davies Liu <dav...@databricks.com>

Closes #14927 from davies/longmap.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f7e26d78
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f7e26d78
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f7e26d78

Branch: refs/heads/master
Commit: f7e26d788757f917b32749856bb29feb7b4c2987
Parents: bc2767d
Author: Davies Liu <dav...@databricks.com>
Authored: Tue Sep 6 10:46:31 2016 -0700
Committer: Davies Liu <davies....@gmail.com>
Committed: Tue Sep 6 10:46:31 2016 -0700

----------------------------------------------------------------------
 .../sql/execution/joins/HashedRelation.scala    | 27 +++++++---
 .../execution/joins/HashedRelationSuite.scala   | 56 ++++++++++++++++++++
 2 files changed, 75 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f7e26d78/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 0897573..8821c0d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -447,10 +447,20 @@ private[execution] final class LongToUnsafeRowMap(val mm: 
TaskMemoryManager, cap
    */
   private def nextSlot(pos: Int): Int = (pos + 2) & mask
 
+  private[this] def toAddress(offset: Long, size: Int): Long = {
+    ((offset - Platform.LONG_ARRAY_OFFSET) << SIZE_BITS) | size
+  }
+
+  private[this] def toOffset(address: Long): Long = {
+    (address >>> SIZE_BITS) + Platform.LONG_ARRAY_OFFSET
+  }
+
+  private[this] def toSize(address: Long): Int = {
+    (address & SIZE_MASK).toInt
+  }
+
   private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = {
-    val offset = address >>> SIZE_BITS
-    val size = address & SIZE_MASK
-    resultRow.pointTo(page, offset, size.toInt)
+    resultRow.pointTo(page, toOffset(address), toSize(address))
     resultRow
   }
 
@@ -485,9 +495,9 @@ private[execution] final class LongToUnsafeRowMap(val mm: 
TaskMemoryManager, cap
       var addr = address
       override def hasNext: Boolean = addr != 0
       override def next(): UnsafeRow = {
-        val offset = addr >>> SIZE_BITS
-        val size = addr & SIZE_MASK
-        resultRow.pointTo(page, offset, size.toInt)
+        val offset = toOffset(addr)
+        val size = toSize(addr)
+        resultRow.pointTo(page, offset, size)
         addr = Platform.getLong(page, offset + size)
         resultRow
       }
@@ -554,7 +564,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: 
TaskMemoryManager, cap
     Platform.putLong(page, cursor, 0)
     cursor += 8
     numValues += 1
-    updateIndex(key, (offset.toLong << SIZE_BITS) | row.getSizeInBytes)
+    updateIndex(key, toAddress(offset, row.getSizeInBytes))
   }
 
   /**
@@ -562,6 +572,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: 
TaskMemoryManager, cap
    */
   private def updateIndex(key: Long, address: Long): Unit = {
     var pos = firstSlot(key)
+    assert(numKeys < array.length / 2)
     while (array(pos) != key && array(pos + 1) != 0) {
       pos = nextSlot(pos)
     }
@@ -582,7 +593,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: 
TaskMemoryManager, cap
       }
     } else {
       // there are some values for this key, put the address in the front of 
them.
-      val pointer = (address >>> SIZE_BITS) + (address & SIZE_MASK)
+      val pointer = toOffset(address) + toSize(address)
       Platform.putLong(page, pointer, array(pos + 1))
       array(pos + 1) = address
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/f7e26d78/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 1196f5e..ede63fe 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.joins
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, 
ObjectInputStream, ObjectOutputStream}
 
+import scala.util.Random
+
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
 import org.apache.spark.serializer.KryoSerializer
@@ -197,6 +199,60 @@ class HashedRelationSuite extends SparkFunSuite with 
SharedSQLContext {
     }
   }
 
+  test("LongToUnsafeRowMap with random keys") {
+    val taskMemoryManager = new TaskMemoryManager(
+      new StaticMemoryManager(
+        new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+        Long.MaxValue,
+        Long.MaxValue,
+        1),
+      0)
+    val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, 
false)))
+
+    val N = 1000000
+    val rand = new Random
+    val keys = (0 to N).map(x => rand.nextLong()).toArray
+
+    val map = new LongToUnsafeRowMap(taskMemoryManager, 10)
+    keys.foreach { k =>
+      map.append(k, unsafeProj(InternalRow(k)))
+    }
+    map.optimize()
+
+    val os = new ByteArrayOutputStream()
+    val out = new ObjectOutputStream(os)
+    map.writeExternal(out)
+    out.flush()
+    val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
+    val map2 = new LongToUnsafeRowMap(taskMemoryManager, 1)
+    map2.readExternal(in)
+
+    val row = unsafeProj(InternalRow(0L)).copy()
+    keys.foreach { k =>
+      val r = map2.get(k, row)
+      assert(r.hasNext)
+      var c = 0
+      while (r.hasNext) {
+        val rr = r.next()
+        assert(rr.getLong(0) === k)
+        c += 1
+      }
+    }
+    var i = 0
+    while (i < N * 10) {
+      val k = rand.nextLong()
+      val r = map2.get(k, row)
+      if (r != null) {
+        assert(r.hasNext)
+        while (r.hasNext) {
+          assert(r.next().getLong(0) === k)
+        }
+      }
+      i += 1
+    }
+    map.free()
+  }
+
   test("Spark-14521") {
     val ser = new KryoSerializer(
       (new SparkConf).set("spark.kryo.referenceTracking", 
"false")).newInstance()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to