Repository: spark
Updated Branches:
  refs/heads/branch-1.5 5598b6238 -> 47e473550


[SPARK-9753] [SQL] TungstenAggregate should also accept InternalRow instead of just UnsafeRow

https://issues.apache.org/jira/browse/SPARK-9753

This PR makes TungstenAggregate accept `InternalRow` instead of just
`UnsafeRow`. It also adds a `getAggregationBufferFromUnsafeRow` method to
`UnsafeFixedWidthAggregationMap`, which is useful when the grouping keys are
already stored in `UnsafeRow`s, because it skips the grouping-key projection.
Finally, it wraps the `InputStream` and `OutputStream` in `UnsafeRowSerializer`
with `BufferedInputStream` and `BufferedOutputStream`, respectively.

Author: Yin Huai <[email protected]>

Closes #8041 from yhuai/joinedRowForProjection and squashes the following 
commits:

7753e34 [Yin Huai] Use BufferedInputStream and BufferedOutputStream.
d68b74e [Yin Huai] Use joinedRow instead of UnsafeRowJoiner.
e93c009 [Yin Huai] Add getAggregationBufferFromUnsafeRow for cases that the 
given groupingKeyRow is already an UnsafeRow.

(cherry picked from commit c564b27447ed99e55b359b3df1d586d5766b85ea)
Signed-off-by: Reynold Xin <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/47e47355
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/47e47355
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/47e47355

Branch: refs/heads/branch-1.5
Commit: 47e47355069d8af0ae7cd6e7fce7fbb0c2810277
Parents: 5598b62
Author: Yin Huai <[email protected]>
Authored: Fri Aug 7 20:04:17 2015 -0700
Committer: Reynold Xin <[email protected]>
Committed: Fri Aug 7 20:04:24 2015 -0700

----------------------------------------------------------------------
 .../UnsafeFixedWidthAggregationMap.java         |  4 ++
 .../sql/execution/UnsafeRowSerializer.scala     | 30 +++---------
 .../execution/aggregate/TungstenAggregate.scala |  4 +-
 .../aggregate/TungstenAggregationIterator.scala | 51 ++++++++++----------
 4 files changed, 39 insertions(+), 50 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/47e47355/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index b08a4a1..00218f2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -121,6 +121,10 @@ public final class UnsafeFixedWidthAggregationMap {
   public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
     final UnsafeRow unsafeGroupingKeyRow = 
this.groupingKeyProjection.apply(groupingKey);
 
+    return getAggregationBufferFromUnsafeRow(unsafeGroupingKeyRow);
+  }
+
+  public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow 
unsafeGroupingKeyRow) {
     // Probe our map using the serialized key
     final BytesToBytesMap.Location loc = map.lookup(
       unsafeGroupingKeyRow.getBaseObject(),

http://git-wip-us.apache.org/repos/asf/spark/blob/47e47355/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
index 39f8f99..6c7e5ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -58,27 +58,14 @@ private class UnsafeRowSerializerInstance(numFields: Int) 
extends SerializerInst
    */
   override def serializeStream(out: OutputStream): SerializationStream = new 
SerializationStream {
     private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
-    // When `out` is backed by ChainedBufferOutputStream, we will get an
-    // UnsupportedOperationException when we call dOut.writeInt because it 
internally calls
-    // ChainedBufferOutputStream's write(b: Int), which is not supported.
-    // To workaround this issue, we create an array for sorting the int value.
-    // To reproduce the problem, use dOut.writeInt(row.getSizeInBytes) and
-    // run SparkSqlSerializer2SortMergeShuffleSuite.
-    private[this] var intBuffer: Array[Byte] = new Array[Byte](4)
-    private[this] val dOut: DataOutputStream = new DataOutputStream(out)
+    private[this] val dOut: DataOutputStream =
+      new DataOutputStream(new BufferedOutputStream(out))
 
     override def writeValue[T: ClassTag](value: T): SerializationStream = {
       val row = value.asInstanceOf[UnsafeRow]
-      val size = row.getSizeInBytes
-      // This part is based on DataOutputStream's writeInt.
-      // It is for dOut.writeInt(row.getSizeInBytes).
-      intBuffer(0) = ((size >>> 24) & 0xFF).toByte
-      intBuffer(1) = ((size >>> 16) & 0xFF).toByte
-      intBuffer(2) = ((size >>> 8) & 0xFF).toByte
-      intBuffer(3) = ((size >>> 0) & 0xFF).toByte
-      dOut.write(intBuffer, 0, 4)
-
-      row.writeToStream(out, writeBuffer)
+
+      dOut.writeInt(row.getSizeInBytes)
+      row.writeToStream(dOut, writeBuffer)
       this
     }
 
@@ -105,7 +92,6 @@ private class UnsafeRowSerializerInstance(numFields: Int) 
extends SerializerInst
 
     override def close(): Unit = {
       writeBuffer = null
-      intBuffer = null
       dOut.writeInt(EOF)
       dOut.close()
     }
@@ -113,7 +99,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) 
extends SerializerInst
 
   override def deserializeStream(in: InputStream): DeserializationStream = {
     new DeserializationStream {
-      private[this] val dIn: DataInputStream = new DataInputStream(in)
+      private[this] val dIn: DataInputStream = new DataInputStream(new 
BufferedInputStream(in))
       // 1024 is a default buffer size; this buffer will grow to accommodate 
larger rows
       private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024)
       private[this] var row: UnsafeRow = new UnsafeRow()
@@ -129,7 +115,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) 
extends SerializerInst
             if (rowBuffer.length < rowSize) {
               rowBuffer = new Array[Byte](rowSize)
             }
-            ByteStreams.readFully(in, rowBuffer, 0, rowSize)
+            ByteStreams.readFully(dIn, rowBuffer, 0, rowSize)
             row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, 
numFields, rowSize)
             rowSize = dIn.readInt() // read the next row's size
             if (rowSize == EOF) { // We are returning the last row in this 
stream
@@ -163,7 +149,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) 
extends SerializerInst
         if (rowBuffer.length < rowSize) {
           rowBuffer = new Array[Byte](rowSize)
         }
-        ByteStreams.readFully(in, rowBuffer, 0, rowSize)
+        ByteStreams.readFully(dIn, rowBuffer, 0, rowSize)
         row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, 
rowSize)
         row.asInstanceOf[T]
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/47e47355/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index c3dcbd2..1694794 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -39,7 +39,7 @@ case class TungstenAggregate(
 
   override def canProcessUnsafeRows: Boolean = true
 
-  override def canProcessSafeRows: Boolean = false
+  override def canProcessSafeRows: Boolean = true
 
   override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
 
@@ -77,7 +77,7 @@ case class TungstenAggregate(
             resultExpressions,
             newMutableProjection,
             child.output,
-            iter.asInstanceOf[Iterator[UnsafeRow]],
+            iter,
             testFallbackStartsAt)
 
         if (!hasInput && groupingExpressions.isEmpty) {

http://git-wip-us.apache.org/repos/asf/spark/blob/47e47355/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 440bef3..3216090 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -22,6 +22,7 @@ import org.apache.spark.{InternalAccumulator, Logging, 
SparkEnv, TaskContext}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import 
org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, 
UnsafeFixedWidthAggregationMap}
 import org.apache.spark.sql.types.StructType
 
@@ -46,8 +47,7 @@ import org.apache.spark.sql.types.StructType
  *            processing input rows from inputIter, and generating output
  *            rows.
  *  - Part 3: Methods and fields used by hash-based aggregation.
- *  - Part 4: The function used to switch this iterator from hash-based
- *            aggregation to sort-based aggregation.
+ *  - Part 4: Methods and fields used when we switch to sort-based aggregation.
  *  - Part 5: Methods and fields used by sort-based aggregation.
  *  - Part 6: Loads input and process input rows.
  *  - Part 7: Public methods of this iterator.
@@ -82,7 +82,7 @@ class TungstenAggregationIterator(
     resultExpressions: Seq[NamedExpression],
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => 
MutableProjection),
     originalInputAttributes: Seq[Attribute],
-    inputIter: Iterator[UnsafeRow],
+    inputIter: Iterator[InternalRow],
     testFallbackStartsAt: Option[Int])
   extends Iterator[UnsafeRow] with Logging {
 
@@ -174,13 +174,10 @@ class TungstenAggregationIterator(
 
   // Creates a function used to process a row based on the given 
inputAttributes.
   private def generateProcessRow(
-      inputAttributes: Seq[Attribute]): (UnsafeRow, UnsafeRow) => Unit = {
+      inputAttributes: Seq[Attribute]): (UnsafeRow, InternalRow) => Unit = {
 
     val aggregationBufferAttributes = 
allAggregateFunctions.flatMap(_.bufferAttributes)
-    val aggregationBufferSchema = 
StructType.fromAttributes(aggregationBufferAttributes)
-    val inputSchema = StructType.fromAttributes(inputAttributes)
-    val unsafeRowJoiner =
-      GenerateUnsafeRowJoiner.create(aggregationBufferSchema, inputSchema)
+    val joinedRow = new JoinedRow()
 
     aggregationMode match {
       // Partial-only
@@ -189,9 +186,9 @@ class TungstenAggregationIterator(
         val algebraicUpdateProjection =
           newMutableProjection(updateExpressions, aggregationBufferAttributes 
++ inputAttributes)()
 
-        (currentBuffer: UnsafeRow, row: UnsafeRow) => {
+        (currentBuffer: UnsafeRow, row: InternalRow) => {
           algebraicUpdateProjection.target(currentBuffer)
-          algebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row))
+          algebraicUpdateProjection(joinedRow(currentBuffer, row))
         }
 
       // PartialMerge-only or Final-only
@@ -203,10 +200,10 @@ class TungstenAggregationIterator(
             mergeExpressions,
             aggregationBufferAttributes ++ inputAttributes)()
 
-        (currentBuffer: UnsafeRow, row: UnsafeRow) => {
+        (currentBuffer: UnsafeRow, row: InternalRow) => {
           // Process all algebraic aggregate functions.
           algebraicMergeProjection.target(currentBuffer)
-          algebraicMergeProjection(unsafeRowJoiner.join(currentBuffer, row))
+          algebraicMergeProjection(joinedRow(currentBuffer, row))
         }
 
       // Final-Complete
@@ -233,8 +230,8 @@ class TungstenAggregationIterator(
         val completeAlgebraicUpdateProjection =
           newMutableProjection(updateExpressions, aggregationBufferAttributes 
++ inputAttributes)()
 
-        (currentBuffer: UnsafeRow, row: UnsafeRow) => {
-          val input = unsafeRowJoiner.join(currentBuffer, row)
+        (currentBuffer: UnsafeRow, row: InternalRow) => {
+          val input = joinedRow(currentBuffer, row)
           // For all aggregate functions with mode Complete, update the given 
currentBuffer.
           completeAlgebraicUpdateProjection.target(currentBuffer)(input)
 
@@ -253,14 +250,14 @@ class TungstenAggregationIterator(
         val completeAlgebraicUpdateProjection =
           newMutableProjection(updateExpressions, aggregationBufferAttributes 
++ inputAttributes)()
 
-        (currentBuffer: UnsafeRow, row: UnsafeRow) => {
+        (currentBuffer: UnsafeRow, row: InternalRow) => {
           completeAlgebraicUpdateProjection.target(currentBuffer)
           // For all aggregate functions with mode Complete, update the given 
currentBuffer.
-          
completeAlgebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row))
+          completeAlgebraicUpdateProjection(joinedRow(currentBuffer, row))
         }
 
       // Grouping only.
-      case (None, None) => (currentBuffer: UnsafeRow, row: UnsafeRow) => {}
+      case (None, None) => (currentBuffer: UnsafeRow, row: InternalRow) => {}
 
       case other =>
         throw new IllegalStateException(
@@ -272,15 +269,16 @@ class TungstenAggregationIterator(
   private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow 
= {
 
     val groupingAttributes = groupingExpressions.map(_.toAttribute)
-    val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
     val bufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes)
-    val bufferSchema = StructType.fromAttributes(bufferAttributes)
-    val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, 
bufferSchema)
 
     aggregationMode match {
       // Partial-only or PartialMerge-only: every output row is basically the 
values of
       // the grouping expressions and the corresponding aggregation buffer.
       case (Some(Partial), None) | (Some(PartialMerge), None) =>
+        val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
+        val bufferSchema = StructType.fromAttributes(bufferAttributes)
+        val unsafeRowJoiner = 
GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
+
         (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
           unsafeRowJoiner.join(currentGroupingKey, currentBuffer)
         }
@@ -288,11 +286,12 @@ class TungstenAggregationIterator(
       // Final-only, Complete-only and Final-Complete: a output row is 
generated based on
       // resultExpressions.
       case (Some(Final), None) | (Some(Final) | None, Some(Complete)) =>
+        val joinedRow = new JoinedRow()
         val resultProjection =
           UnsafeProjection.create(resultExpressions, groupingAttributes ++ 
bufferAttributes)
 
         (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
-          resultProjection(unsafeRowJoiner.join(currentGroupingKey, 
currentBuffer))
+          resultProjection(joinedRow(currentGroupingKey, currentBuffer))
         }
 
       // Grouping-only: a output row is generated from values of grouping 
expressions.
@@ -316,7 +315,7 @@ class TungstenAggregationIterator(
 
   // A function used to process a input row. Its first argument is the 
aggregation buffer
   // and the second argument is the input row.
-  private[this] var processRow: (UnsafeRow, UnsafeRow) => Unit =
+  private[this] var processRow: (UnsafeRow, InternalRow) => Unit =
     generateProcessRow(originalInputAttributes)
 
   // A function used to generate output rows based on the grouping keys (first 
argument)
@@ -354,7 +353,7 @@ class TungstenAggregationIterator(
     while (!sortBased && inputIter.hasNext) {
       val newInput = inputIter.next()
       val groupingKey = groupProjection.apply(newInput)
-      val buffer: UnsafeRow = hashMap.getAggregationBuffer(groupingKey)
+      val buffer: UnsafeRow = 
hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
       if (buffer == null) {
         // buffer == null means that we could not allocate more memory.
         // Now, we need to spill the map and switch to sort-based aggregation.
@@ -374,7 +373,7 @@ class TungstenAggregationIterator(
       val newInput = inputIter.next()
       val groupingKey = groupProjection.apply(newInput)
       val buffer: UnsafeRow = if (i < fallbackStartsAt) {
-        hashMap.getAggregationBuffer(groupingKey)
+        hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
       } else {
         null
       }
@@ -397,7 +396,7 @@ class TungstenAggregationIterator(
   private[this] var mapIteratorHasNext: Boolean = false
 
   ///////////////////////////////////////////////////////////////////////////
-  // Part 3: Methods and fields used by sort-based aggregation.
+  // Part 4: Methods and fields used when we switch to sort-based aggregation.
   ///////////////////////////////////////////////////////////////////////////
 
   // This sorter is used for sort-based aggregation. It is initialized as soon 
as
@@ -407,7 +406,7 @@ class TungstenAggregationIterator(
   /**
    * Switch to sort-based aggregation when the hash-based approach is unable 
to acquire memory.
    */
-  private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: 
UnsafeRow): Unit = {
+  private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: 
InternalRow): Unit = {
     logInfo("falling back to sort based aggregation.")
     // Step 1: Get the ExternalSorter containing sorted entries of the map.
     externalSorter = hashMap.destructAndCreateExternalSorter()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to