Repository: spark
Updated Branches:
  refs/heads/master 4c33a34ba -> 7893cd95d


[SPARK-11119] [SQL] cleanup for unsafe array and map

The purpose of this PR is to keep the details of the unsafe format inside the
unsafe classes themselves, so that when we use them (e.g. an unsafe array inside
an unsafe map, or unsafe arrays and maps in the columnar cache), we don't need
to understand the format before using them.

change list:
* The unsafe array's 4-byte `numElements` header is now required (it was
optional) and has become part of the unsafe array format.
* As a consequence of the previous change, the `sizeInBytes` of an unsafe array
now includes the 4-byte header.
* The unsafe map's format used to be `[numElements] [key array numBytes] [key
array content (without numElements header)] [value array content (without
numElements header)]`, which is a little hacky because it makes the unsafe
array's header optional. I think saving 4 bytes is not a big deal, so the format
is now: `[key array numBytes] [unsafe key array] [unsafe value array]`.
* As a consequence of the previous change, the `sizeInBytes` of an unsafe map
now includes both the map's header and the arrays' headers.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #9131 from cloud-fan/unsafe.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7893cd95
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7893cd95
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7893cd95

Branch: refs/heads/master
Commit: 7893cd95db5f2caba59ff5c859d7e4964ad7938d
Parents: 4c33a34
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Mon Oct 19 11:02:26 2015 -0700
Committer: Davies Liu <davies....@gmail.com>
Committed: Mon Oct 19 11:02:26 2015 -0700

----------------------------------------------------------------------
 .../catalyst/expressions/UnsafeArrayData.java   | 43 ++++++----
 .../sql/catalyst/expressions/UnsafeMapData.java | 88 +++++++++++++++-----
 .../sql/catalyst/expressions/UnsafeReaders.java | 54 ------------
 .../sql/catalyst/expressions/UnsafeRow.java     |  8 +-
 .../expressions/codegen/UnsafeArrayWriter.java  | 24 ++----
 .../expressions/codegen/UnsafeRowWriter.java    | 15 ----
 .../codegen/GenerateUnsafeProjection.scala      | 60 +++++++------
 .../expressions/UnsafeRowConverterSuite.scala   | 42 +++++-----
 .../apache/spark/sql/columnar/ColumnType.scala  | 30 +++----
 .../spark/sql/columnar/ColumnTypeSuite.scala    |  2 +-
 10 files changed, 174 insertions(+), 192 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7893cd95/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 4c63abb..761f044 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -30,19 +30,18 @@ import org.apache.spark.unsafe.types.UTF8String;
 /**
  * An Unsafe implementation of Array which is backed by raw memory instead of 
Java objects.
  *
- * Each tuple has two parts: [offsets] [values]
+ * Each tuple has three parts: [numElements] [offsets] [values]
  *
- * In the `offsets` region, we store 4 bytes per element, represents the start 
address of this
- * element in `values` region. We can get the length of this element by 
subtracting next offset.
+ * The `numElements` is 4 bytes storing the number of elements of this array.
+ *
+ * In the `offsets` region, we store 4 bytes per element, represents the 
relative offset (w.r.t. the
+ * base address of the array) of this element in `values` region. We can get 
the length of this
+ * element by subtracting next offset.
  * Note that offset can by negative which means this element is null.
  *
  * In the `values` region, we store the content of elements. As we can get 
length info, so elements
  * can be variable-length.
  *
- * Note that when we write out this array, we should write out the 
`numElements` at first 4 bytes,
- * then follows content. When we read in an array, we should read first 4 
bytes as `numElements`
- * and take the rest as content.
- *
  * Instances of `UnsafeArrayData` act as pointers to row data stored in this 
format.
  */
 // todo: there is a lof of duplicated code between UnsafeRow and 
UnsafeArrayData.
@@ -54,11 +53,16 @@ public class UnsafeArrayData extends ArrayData {
   // The number of elements in this array
   private int numElements;
 
-  // The size of this array's backing data, in bytes
+  // The size of this array's backing data, in bytes.
+  // The 4-bytes header of `numElements` is also included.
   private int sizeInBytes;
 
+  public Object getBaseObject() { return baseObject; }
+  public long getBaseOffset() { return baseOffset; }
+  public int getSizeInBytes() { return sizeInBytes; }
+
   private int getElementOffset(int ordinal) {
-    return Platform.getInt(baseObject, baseOffset + ordinal * 4L);
+    return Platform.getInt(baseObject, baseOffset + 4 + ordinal * 4L);
   }
 
   private int getElementSize(int offset, int ordinal) {
@@ -85,10 +89,6 @@ public class UnsafeArrayData extends ArrayData {
    */
   public UnsafeArrayData() { }
 
-  public Object getBaseObject() { return baseObject; }
-  public long getBaseOffset() { return baseOffset; }
-  public int getSizeInBytes() { return sizeInBytes; }
-
   @Override
   public int numElements() { return numElements; }
 
@@ -97,10 +97,13 @@ public class UnsafeArrayData extends ArrayData {
    *
    * @param baseObject the base object
    * @param baseOffset the offset within the base object
-   * @param sizeInBytes the size of this row's backing data, in bytes
+   * @param sizeInBytes the size of this array's backing data, in bytes
    */
-  public void pointTo(Object baseObject, long baseOffset, int numElements, int 
sizeInBytes) {
+  public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) {
+    // Read the number of elements from the first 4 bytes.
+    final int numElements = Platform.getInt(baseObject, baseOffset);
     assert numElements >= 0 : "numElements (" + numElements + ") should >= 0";
+
     this.numElements = numElements;
     this.baseObject = baseObject;
     this.baseOffset = baseOffset;
@@ -277,7 +280,9 @@ public class UnsafeArrayData extends ArrayData {
     final int offset = getElementOffset(ordinal);
     if (offset < 0) return null;
     final int size = getElementSize(offset, ordinal);
-    return UnsafeReaders.readArray(baseObject, baseOffset + offset, size);
+    final UnsafeArrayData array = new UnsafeArrayData();
+    array.pointTo(baseObject, baseOffset + offset, size);
+    return array;
   }
 
   @Override
@@ -286,7 +291,9 @@ public class UnsafeArrayData extends ArrayData {
     final int offset = getElementOffset(ordinal);
     if (offset < 0) return null;
     final int size = getElementSize(offset, ordinal);
-    return UnsafeReaders.readMap(baseObject, baseOffset + offset, size);
+    final UnsafeMapData map = new UnsafeMapData();
+    map.pointTo(baseObject, baseOffset + offset, size);
+    return map;
   }
 
   @Override
@@ -328,7 +335,7 @@ public class UnsafeArrayData extends ArrayData {
     final byte[] arrayDataCopy = new byte[sizeInBytes];
     Platform.copyMemory(
       baseObject, baseOffset, arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, 
sizeInBytes);
-    arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, numElements, 
sizeInBytes);
+    arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
     return arrayCopy;
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7893cd95/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
index e9dab9e..5bebe2a 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
@@ -17,41 +17,73 @@
 
 package org.apache.spark.sql.catalyst.expressions;
 
+import java.nio.ByteBuffer;
+
 import org.apache.spark.sql.types.MapData;
+import org.apache.spark.unsafe.Platform;
 
 /**
  * An Unsafe implementation of Map which is backed by raw memory instead of 
Java objects.
  *
- * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData.
- *
- * Note that when we write out this map, we should write out the `numElements` 
at first 4 bytes,
- * and numBytes of key array at second 4 bytes, then follows key array content 
and value array
- * content without `numElements` header.
- * When we read in a map, we should read first 4 bytes as `numElements` and 
second 4 bytes as
- * numBytes of key array, and construct unsafe key array and value array with 
these 2 information.
+ * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData, with 
extra 4 bytes at head
+ * to indicate the number of bytes of the unsafe key array.
+ * [unsafe key array numBytes] [unsafe key array] [unsafe value array]
  */
+// TODO: Use a more efficient format which doesn't depend on unsafe array.
 public class UnsafeMapData extends MapData {
 
-  private final UnsafeArrayData keys;
-  private final UnsafeArrayData values;
-  // The number of elements in this array
-  private int numElements;
-  // The size of this array's backing data, in bytes
+  private Object baseObject;
+  private long baseOffset;
+
+  // The size of this map's backing data, in bytes.
+  // The 4-bytes header of key array `numBytes` is also included, so it's 
actually equal to
+  // 4 + key array numBytes + value array numBytes.
   private int sizeInBytes;
 
+  public Object getBaseObject() { return baseObject; }
+  public long getBaseOffset() { return baseOffset; }
   public int getSizeInBytes() { return sizeInBytes; }
 
-  public UnsafeMapData(UnsafeArrayData keys, UnsafeArrayData values) {
+  private final UnsafeArrayData keys;
+  private final UnsafeArrayData values;
+
+  /**
+   * Construct a new UnsafeMapData. The resulting UnsafeMapData won't be 
usable until
+   * `pointTo()` has been called, since the value returned by this constructor 
is equivalent
+   * to a null pointer.
+   */
+  public UnsafeMapData() {
+    keys = new UnsafeArrayData();
+    values = new UnsafeArrayData();
+  }
+
+  /**
+   * Update this UnsafeMapData to point to different backing data.
+   *
+   * @param baseObject the base object
+   * @param baseOffset the offset within the base object
+   * @param sizeInBytes the size of this map's backing data, in bytes
+   */
+  public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) {
+    // Read the numBytes of key array from the first 4 bytes.
+    final int keyArraySize = Platform.getInt(baseObject, baseOffset);
+    final int valueArraySize = sizeInBytes - keyArraySize - 4;
+    assert keyArraySize >= 0 : "keyArraySize (" + keyArraySize + ") should >= 
0";
+    assert valueArraySize >= 0 : "valueArraySize (" + valueArraySize + ") 
should >= 0";
+
+    keys.pointTo(baseObject, baseOffset + 4, keyArraySize);
+    values.pointTo(baseObject, baseOffset + 4 + keyArraySize, valueArraySize);
+
     assert keys.numElements() == values.numElements();
-    this.sizeInBytes = keys.getSizeInBytes() + values.getSizeInBytes();
-    this.numElements = keys.numElements();
-    this.keys = keys;
-    this.values = values;
+
+    this.baseObject = baseObject;
+    this.baseOffset = baseOffset;
+    this.sizeInBytes = sizeInBytes;
   }
 
   @Override
   public int numElements() {
-    return numElements;
+    return keys.numElements();
   }
 
   @Override
@@ -64,8 +96,26 @@ public class UnsafeMapData extends MapData {
     return values;
   }
 
+  public void writeToMemory(Object target, long targetOffset) {
+    Platform.copyMemory(baseObject, baseOffset, target, targetOffset, 
sizeInBytes);
+  }
+
+  public void writeTo(ByteBuffer buffer) {
+    assert(buffer.hasArray());
+    byte[] target = buffer.array();
+    int offset = buffer.arrayOffset();
+    int pos = buffer.position();
+    writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos);
+    buffer.position(pos + sizeInBytes);
+  }
+
   @Override
   public UnsafeMapData copy() {
-    return new UnsafeMapData(keys.copy(), values.copy());
+    UnsafeMapData mapCopy = new UnsafeMapData();
+    final byte[] mapDataCopy = new byte[sizeInBytes];
+    Platform.copyMemory(
+      baseObject, baseOffset, mapDataCopy, Platform.BYTE_ARRAY_OFFSET, 
sizeInBytes);
+    mapCopy.pointTo(mapDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
+    return mapCopy;
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7893cd95/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
deleted file mode 100644
index 6c5fcbc..0000000
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions;
-
-import org.apache.spark.unsafe.Platform;
-
-public class UnsafeReaders {
-
-  /**
-   * Reads in unsafe array according to the format described in 
`UnsafeArrayData`.
-   */
-  public static UnsafeArrayData readArray(Object baseObject, long baseOffset, 
int numBytes) {
-    // Read the number of elements from first 4 bytes.
-    final int numElements = Platform.getInt(baseObject, baseOffset);
-    final UnsafeArrayData array = new UnsafeArrayData();
-    // Skip the first 4 bytes.
-    array.pointTo(baseObject, baseOffset + 4, numElements, numBytes - 4);
-    return array;
-  }
-
-  /**
-   * Reads in unsafe map according to the format described in `UnsafeMapData`.
-   */
-  public static UnsafeMapData readMap(Object baseObject, long baseOffset, int 
numBytes) {
-    // Read the number of elements from first 4 bytes.
-    final int numElements = Platform.getInt(baseObject, baseOffset);
-    // Read the numBytes of key array in second 4 bytes.
-    final int keyArraySize = Platform.getInt(baseObject, baseOffset + 4);
-    final int valueArraySize = numBytes - 8 - keyArraySize;
-
-    final UnsafeArrayData keyArray = new UnsafeArrayData();
-    keyArray.pointTo(baseObject, baseOffset + 8, numElements, keyArraySize);
-
-    final UnsafeArrayData valueArray = new UnsafeArrayData();
-    valueArray.pointTo(baseObject, baseOffset + 8 + keyArraySize, numElements, 
valueArraySize);
-
-    return new UnsafeMapData(keyArray, valueArray);
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/7893cd95/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 36859fb..366615f 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -461,7 +461,9 @@ public final class UnsafeRow extends MutableRow implements 
Externalizable, KryoS
       final long offsetAndSize = getLong(ordinal);
       final int offset = (int) (offsetAndSize >> 32);
       final int size = (int) (offsetAndSize & ((1L << 32) - 1));
-      return UnsafeReaders.readArray(baseObject, baseOffset + offset, size);
+      final UnsafeArrayData array = new UnsafeArrayData();
+      array.pointTo(baseObject, baseOffset + offset, size);
+      return array;
     }
   }
 
@@ -473,7 +475,9 @@ public final class UnsafeRow extends MutableRow implements 
Externalizable, KryoS
       final long offsetAndSize = getLong(ordinal);
       final int offset = (int) (offsetAndSize >> 32);
       final int size = (int) (offsetAndSize & ((1L << 32) - 1));
-      return UnsafeReaders.readMap(baseObject, baseOffset + offset, size);
+      final UnsafeMapData map = new UnsafeMapData();
+      map.pointTo(baseObject, baseOffset + offset, size);
+      return map;
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7893cd95/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
index 138178c..7f2a1cb 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
@@ -30,17 +30,19 @@ import org.apache.spark.unsafe.types.UTF8String;
 public class UnsafeArrayWriter {
 
   private BufferHolder holder;
+
   // The offset of the global buffer where we start to write this array.
   private int startingOffset;
 
   public void initialize(BufferHolder holder, int numElements, int 
fixedElementSize) {
-    // We need 4 bytes each element to store offset.
-    final int fixedSize = 4 * numElements;
+    // We need 4 bytes to store numElements and 4 bytes each element to store 
offset.
+    final int fixedSize = 4 + 4 * numElements;
 
     this.holder = holder;
     this.startingOffset = holder.cursor;
 
     holder.grow(fixedSize);
+    Platform.putInt(holder.buffer, holder.cursor, numElements);
     holder.cursor += fixedSize;
 
     // Grows the global buffer ahead for fixed size data.
@@ -48,7 +50,7 @@ public class UnsafeArrayWriter {
   }
 
   private long getElementOffset(int ordinal) {
-    return startingOffset + 4 * ordinal;
+    return startingOffset + 4 + 4 * ordinal;
   }
 
   public void setNullAt(int ordinal) {
@@ -132,20 +134,4 @@ public class UnsafeArrayWriter {
     // move the cursor forward.
     holder.cursor += 16;
   }
-
-
-
-  // If this array is already an UnsafeArray, we don't need to go through all 
elements, we can
-  // directly write it.
-  public static void directWrite(BufferHolder holder, UnsafeArrayData input) {
-    final int numBytes = input.getSizeInBytes();
-
-    // grow the global buffer before writing data.
-    holder.grow(numBytes);
-
-    // Writes the array content to the variable length portion.
-    input.writeToMemory(holder.buffer, holder.cursor);
-
-    holder.cursor += numBytes;
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7893cd95/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
index 8b7debd..e1f5a05 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
@@ -181,19 +181,4 @@ public class UnsafeRowWriter {
     // move the cursor forward.
     holder.cursor += 16;
   }
-
-
-
-  // If this struct is already an UnsafeRow, we don't need to go through all 
fields, we can
-  // directly write it.
-  public static void directWrite(BufferHolder holder, UnsafeRow input) {
-    // No need to zero-out the bytes as UnsafeRow is word aligned for sure.
-    final int numBytes = input.getSizeInBytes();
-    // grow the global buffer before writing data.
-    holder.grow(numBytes);
-    // Write the bytes to the variable length portion.
-    input.writeToMemory(holder.buffer, holder.cursor);
-    // move the cursor forward.
-    holder.cursor += numBytes;
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7893cd95/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 1b957a5..dbe92d6 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -62,7 +62,7 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
 
     s"""
       if ($input instanceof UnsafeRow) {
-        $rowWriterClass.directWrite($bufferHolder, (UnsafeRow) $input);
+        ${writeUnsafeData(ctx, s"((UnsafeRow) $input)", bufferHolder)}
       } else {
         ${writeExpressionsToBuffer(ctx, input, fieldEvals, fieldTypes, 
bufferHolder)}
       }
@@ -164,8 +164,7 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
       ctx: CodeGenContext,
       input: String,
       elementType: DataType,
-      bufferHolder: String,
-      needHeader: Boolean = true): String = {
+      bufferHolder: String): String = {
     val arrayWriter = ctx.freshName("arrayWriter")
     ctx.addMutableState(arrayWriterClass, arrayWriter,
       s"this.$arrayWriter = new $arrayWriterClass();")
@@ -227,21 +226,11 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
       case _ => s"$arrayWriter.write($index, $element);"
     }
 
-    val writeHeader = if (needHeader) {
-      // If header is required, we need to write the number of elements into 
first 4 bytes.
-      s"""
-        $bufferHolder.grow(4);
-        Platform.putInt($bufferHolder.buffer, $bufferHolder.cursor, 
$numElements);
-        $bufferHolder.cursor += 4;
-      """
-    } else ""
-
     s"""
-      final int $numElements = $input.numElements();
-      $writeHeader
       if ($input instanceof UnsafeArrayData) {
-        $arrayWriterClass.directWrite($bufferHolder, (UnsafeArrayData) $input);
+        ${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)}
       } else {
+        final int $numElements = $input.numElements();
         $arrayWriter.initialize($bufferHolder, $numElements, 
$fixedElementSize);
 
         for (int $index = 0; $index < $numElements; $index++) {
@@ -270,23 +259,40 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
 
     // Writes out unsafe map according to the format described in 
`UnsafeMapData`.
     s"""
-      final ArrayData $keys = $input.keyArray();
-      final ArrayData $values = $input.valueArray();
+      if ($input instanceof UnsafeMapData) {
+        ${writeUnsafeData(ctx, s"((UnsafeMapData) $input)", bufferHolder)}
+      } else {
+        final ArrayData $keys = $input.keyArray();
+        final ArrayData $values = $input.valueArray();
 
-      $bufferHolder.grow(8);
+        // preserve 4 bytes to write the key array numBytes later.
+        $bufferHolder.grow(4);
+        $bufferHolder.cursor += 4;
 
-      // Write the numElements into first 4 bytes.
-      Platform.putInt($bufferHolder.buffer, $bufferHolder.cursor, 
$keys.numElements());
+        // Remember the current cursor so that we can write numBytes of key 
array later.
+        final int $tmpCursor = $bufferHolder.cursor;
 
-      $bufferHolder.cursor += 8;
-      // Remember the current cursor so that we can write numBytes of key 
array later.
-      final int $tmpCursor = $bufferHolder.cursor;
+        ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)}
+        // Write the numBytes of key array into the first 4 bytes.
+        Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, 
$bufferHolder.cursor - $tmpCursor);
 
-      ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder, needHeader = 
false)}
-      // Write the numBytes of key array into second 4 bytes.
-      Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, 
$bufferHolder.cursor - $tmpCursor);
+        ${writeArrayToBuffer(ctx, values, valueType, bufferHolder)}
+      }
+    """
+  }
 
-      ${writeArrayToBuffer(ctx, values, valueType, bufferHolder, needHeader = 
false)}
+  /**
+   * If the input is already in unsafe format, we don't need to go through all 
elements/fields,
+   * we can directly write it.
+   */
+  private def writeUnsafeData(ctx: CodeGenContext, input: String, 
bufferHolder: String) = {
+    val sizeInBytes = ctx.freshName("sizeInBytes")
+    s"""
+      final int $sizeInBytes = $input.getSizeInBytes();
+      // grow the global buffer before writing data.
+      $bufferHolder.grow($sizeInBytes);
+      $input.writeToMemory($bufferHolder.buffer, $bufferHolder.cursor);
+      $bufferHolder.cursor += $sizeInBytes;
     """
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7893cd95/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index c991cd8..c6aad34 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -296,13 +296,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with 
Matchers {
     new ArrayBasedMapData(createArray(keys: _*), createArray(values: _*))
   }
 
-  private def arraySizeInRow(numBytes: Int): Int = roundedSize(4 + numBytes)
-
-  private def mapSizeInRow(numBytes: Int): Int = roundedSize(8 + numBytes)
-
   private def testArrayInt(array: UnsafeArrayData, values: Seq[Int]): Unit = {
     assert(array.numElements == values.length)
-    assert(array.getSizeInBytes == (4 + 4) * values.length)
+    assert(array.getSizeInBytes == 4 + (4 + 4) * values.length)
     values.zipWithIndex.foreach {
       case (value, index) => assert(array.getInt(index) == value)
     }
@@ -315,7 +311,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with 
Matchers {
     testArrayInt(map.keyArray, keys)
     testArrayInt(map.valueArray, values)
 
-    assert(map.getSizeInBytes == map.keyArray.getSizeInBytes + 
map.valueArray.getSizeInBytes)
+    assert(map.getSizeInBytes == 4 + map.keyArray.getSizeInBytes + 
map.valueArray.getSizeInBytes)
   }
 
   test("basic conversion with array type") {
@@ -341,10 +337,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with 
Matchers {
     val nestedArray = unsafeArray2.getArray(0)
     testArrayInt(nestedArray, Seq(3, 4))
 
-    assert(unsafeArray2.getSizeInBytes == 4 + (4 + nestedArray.getSizeInBytes))
+    assert(unsafeArray2.getSizeInBytes == 4 + 4 + nestedArray.getSizeInBytes)
 
-    val array1Size = arraySizeInRow(unsafeArray1.getSizeInBytes)
-    val array2Size = arraySizeInRow(unsafeArray2.getSizeInBytes)
+    val array1Size = roundedSize(unsafeArray1.getSizeInBytes)
+    val array2Size = roundedSize(unsafeArray2.getSizeInBytes)
     assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size)
   }
 
@@ -384,13 +380,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with 
Matchers {
       val nestedMap = valueArray.getMap(0)
       testMapInt(nestedMap, Seq(5, 6), Seq(7, 8))
 
-      assert(valueArray.getSizeInBytes == 4 + (8 + nestedMap.getSizeInBytes))
+      assert(valueArray.getSizeInBytes == 4 + 4 + nestedMap.getSizeInBytes)
     }
 
-    assert(unsafeMap2.getSizeInBytes == keyArray.getSizeInBytes + 
valueArray.getSizeInBytes)
+    assert(unsafeMap2.getSizeInBytes == 4 + keyArray.getSizeInBytes + 
valueArray.getSizeInBytes)
 
-    val map1Size = mapSizeInRow(unsafeMap1.getSizeInBytes)
-    val map2Size = mapSizeInRow(unsafeMap2.getSizeInBytes)
+    val map1Size = roundedSize(unsafeMap1.getSizeInBytes)
+    val map2Size = roundedSize(unsafeMap2.getSizeInBytes)
     assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size)
   }
 
@@ -414,7 +410,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with 
Matchers {
     val innerArray = field1.getArray(0)
     testArrayInt(innerArray, Seq(1))
 
-    assert(field1.getSizeInBytes == 8 + 8 + 
arraySizeInRow(innerArray.getSizeInBytes))
+    assert(field1.getSizeInBytes == 8 + 8 + 
roundedSize(innerArray.getSizeInBytes))
 
     val field2 = unsafeRow.getArray(1)
     assert(field2.numElements == 1)
@@ -427,10 +423,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with 
Matchers {
       assert(innerStruct.getLong(0) == 2L)
     }
 
-    assert(field2.getSizeInBytes == 4 + innerStruct.getSizeInBytes)
+    assert(field2.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes)
 
     assert(unsafeRow.getSizeInBytes ==
-      8 + 8 * 2 + field1.getSizeInBytes + 
arraySizeInRow(field2.getSizeInBytes))
+      8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
   }
 
   test("basic conversion with struct and map") {
@@ -453,7 +449,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with 
Matchers {
     val innerMap = field1.getMap(0)
     testMapInt(innerMap, Seq(1), Seq(2))
 
-    assert(field1.getSizeInBytes == 8 + 8 + 
mapSizeInRow(innerMap.getSizeInBytes))
+    assert(field1.getSizeInBytes == 8 + 8 + 
roundedSize(innerMap.getSizeInBytes))
 
     val field2 = unsafeRow.getMap(1)
 
@@ -470,13 +466,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with 
Matchers {
       assert(innerStruct.getSizeInBytes == 8 + 8)
       assert(innerStruct.getLong(0) == 4L)
 
-      assert(valueArray.getSizeInBytes == 4 + innerStruct.getSizeInBytes)
+      assert(valueArray.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes)
     }
 
-    assert(field2.getSizeInBytes == keyArray.getSizeInBytes + 
valueArray.getSizeInBytes)
+    assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + 
valueArray.getSizeInBytes)
 
     assert(unsafeRow.getSizeInBytes ==
-      8 + 8 * 2 + field1.getSizeInBytes + mapSizeInRow(field2.getSizeInBytes))
+      8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
   }
 
   test("basic conversion with array and map") {
@@ -499,7 +495,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val innerMap = field1.getMap(0)
     testMapInt(innerMap, Seq(1), Seq(2))
 
-    assert(field1.getSizeInBytes == 4 + (8 + innerMap.getSizeInBytes))
+    assert(field1.getSizeInBytes == 4 + 4 + innerMap.getSizeInBytes)
 
     val field2 = unsafeRow.getMap(1)
     assert(field2.numElements == 1)
@@ -518,9 +514,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       assert(valueArray.getSizeInBytes == 4 + (4 + innerArray.getSizeInBytes))
     }
 
-    assert(field2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+    assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
 
     assert(unsafeRow.getSizeInBytes ==
-      8 + 8 * 2 + arraySizeInRow(field1.getSizeInBytes) + mapSizeInRow(field2.getSizeInBytes))
+      8 + 8 * 2 + roundedSize(field1.getSizeInBytes) + roundedSize(field2.getSizeInBytes))
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7893cd95/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 2bc2c96..a41f04d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -482,12 +482,14 @@ private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRo
   override def extract(buffer: ByteBuffer): UnsafeRow = {
     val sizeInBytes = buffer.getInt()
     assert(buffer.hasArray)
-    val base = buffer.array()
-    val offset = buffer.arrayOffset()
     val cursor = buffer.position()
     buffer.position(cursor + sizeInBytes)
     val unsafeRow = new UnsafeRow
-    unsafeRow.pointTo(base, Platform.BYTE_ARRAY_OFFSET + offset + cursor, numOfFields, sizeInBytes)
+    unsafeRow.pointTo(
+      buffer.array(),
+      Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
+      numOfFields,
+      sizeInBytes)
     unsafeRow
   }
 
@@ -508,12 +510,11 @@ private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArra
 
   override def actualSize(row: InternalRow, ordinal: Int): Int = {
     val unsafeArray = getField(row, ordinal)
-    4 + 4 + unsafeArray.getSizeInBytes
+    4 + unsafeArray.getSizeInBytes
   }
 
   override def append(value: UnsafeArrayData, buffer: ByteBuffer): Unit = {
-    buffer.putInt(4 + value.getSizeInBytes)
-    buffer.putInt(value.numElements())
+    buffer.putInt(value.getSizeInBytes)
     value.writeTo(buffer)
   }
 
@@ -522,10 +523,12 @@ private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArra
     assert(buffer.hasArray)
     val cursor = buffer.position()
     buffer.position(cursor + numBytes)
-    UnsafeReaders.readArray(
+    val array = new UnsafeArrayData
+    array.pointTo(
       buffer.array(),
       Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
       numBytes)
+    array
   }
 
   override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy()
@@ -545,15 +548,12 @@ private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData]
 
   override def actualSize(row: InternalRow, ordinal: Int): Int = {
     val unsafeMap = getField(row, ordinal)
-    12 + unsafeMap.keyArray().getSizeInBytes + unsafeMap.valueArray().getSizeInBytes
+    4 + unsafeMap.getSizeInBytes
   }
 
   override def append(value: UnsafeMapData, buffer: ByteBuffer): Unit = {
-    buffer.putInt(8 + value.keyArray().getSizeInBytes + value.valueArray().getSizeInBytes)
-    buffer.putInt(value.numElements())
-    buffer.putInt(value.keyArray().getSizeInBytes)
-    value.keyArray().writeTo(buffer)
-    value.valueArray().writeTo(buffer)
+    buffer.putInt(value.getSizeInBytes)
+    value.writeTo(buffer)
   }
 
   override def extract(buffer: ByteBuffer): UnsafeMapData = {
@@ -561,10 +561,12 @@ private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData]
     assert(buffer.hasArray)
     val cursor = buffer.position()
     buffer.position(cursor + numBytes)
-    UnsafeReaders.readMap(
+    val map = new UnsafeMapData
+    map.pointTo(
       buffer.array(),
       Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
       numBytes)
+    map
   }
 
   override def clone(v: UnsafeMapData): UnsafeMapData = v.copy()

http://git-wip-us.apache.org/repos/asf/spark/blob/7893cd95/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 0e6e1bc..63bc39b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -73,7 +73,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
     checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
     checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5)
     checkActualSize(ARRAY_TYPE, Array[Any](1), 16)
-    checkActualSize(MAP_TYPE, Map(1 -> "a"), 25)
+    checkActualSize(MAP_TYPE, Map(1 -> "a"), 29)
     checkActualSize(STRUCT_TYPE, Row("hello"), 28)
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to