Repository: spark
Updated Branches:
  refs/heads/master e5c307c50 -> 52636226d


[SPARK-18932][SQL] Support partial aggregation for collect_set/collect_list

## What changes were proposed in this pull request?

Currently, the collect_set/collect_list aggregate expressions don't support 
partial aggregation. This patch enables partial aggregation for them.

## How was this patch tested?

Jenkins tests.

Please review http://spark.apache.org/contributing.html before opening a pull 
request.

Author: Liang-Chi Hsieh <vii...@gmail.com>

Closes #16371 from viirya/collect-partial-support.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/52636226
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/52636226
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/52636226

Branch: refs/heads/master
Commit: 52636226dc8cb7fcf00381d65e280d651b25a382
Parents: e5c307c
Author: Liang-Chi Hsieh <vii...@gmail.com>
Authored: Tue Jan 3 22:11:54 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue Jan 3 22:11:54 2017 +0800

----------------------------------------------------------------------
 .../expressions/aggregate/Percentile.scala      |  7 +--
 .../expressions/aggregate/collect.scala         | 62 +++++++++++---------
 .../expressions/aggregate/interfaces.scala      |  4 +-
 .../RewriteDistinctAggregatesSuite.scala        |  9 ---
 4 files changed, 39 insertions(+), 43 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/52636226/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 2f4d68d..eaeb010 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -33,10 +33,9 @@ import org.apache.spark.util.collection.OpenHashMap
  * The Percentile aggregate function returns the exact percentile(s) of 
numeric column `expr` at
  * the given percentage(s) with value range in [0.0, 1.0].
  *
- * The operator is bound to the slower sort based aggregation path because the 
number of elements
- * and their partial order cannot be determined in advance. Therefore we have 
to store all the
- * elements in memory, and that too many elements can cause GC paused and 
eventually OutOfMemory
- * Errors.
+ * Because the number of elements and their partial order cannot be determined 
in advance.
+ * Therefore we have to store all the elements in memory, and so notice that 
too many elements can
+ * cause GC paused and eventually OutOfMemory Errors.
  *
  * @param child child expression that produce numeric column value with 
`child.eval(inputRow)`
  * @param percentageExpression Expression that represents a single percentage 
value or an array of

http://git-wip-us.apache.org/repos/asf/spark/blob/52636226/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index b176e2a..411f058 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, 
DataOutputStream}
+
 import scala.collection.generic.Growable
 import scala.collection.mutable
 
@@ -27,14 +29,12 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 
 /**
- * The Collect aggregate function collects all seen expression values into a 
list of values.
+ * A base class for collect_list and collect_set aggregate functions.
  *
- * The operator is bound to the slower sort based aggregation path because the 
number of
- * elements (and their memory usage) can not be determined in advance. This 
also means that the
- * collected elements are stored on heap, and that too many elements can cause 
GC pauses and
- * eventually Out of Memory Errors.
+ * We have to store all the collected elements in memory, and so notice that 
too many elements
+ * can cause GC paused and eventually OutOfMemory Errors.
  */
-abstract class Collect extends ImperativeAggregate {
+abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends 
TypedImperativeAggregate[T] {
 
   val child: Expression
 
@@ -44,40 +44,44 @@ abstract class Collect extends ImperativeAggregate {
 
   override def dataType: DataType = ArrayType(child.dataType)
 
-  override def supportsPartial: Boolean = false
-
-  override def aggBufferAttributes: Seq[AttributeReference] = Nil
-
-  override def aggBufferSchema: StructType = 
StructType.fromAttributes(aggBufferAttributes)
-
-  override def inputAggBufferAttributes: Seq[AttributeReference] = Nil
-
   // Both `CollectList` and `CollectSet` are non-deterministic since their 
results depend on the
   // actual order of input rows.
   override def deterministic: Boolean = false
 
-  protected[this] val buffer: Growable[Any] with Iterable[Any]
-
-  override def initialize(b: InternalRow): Unit = {
-    buffer.clear()
-  }
+  override def update(buffer: T, input: InternalRow): T = {
+    val value = child.eval(input)
 
-  override def update(b: InternalRow, input: InternalRow): Unit = {
     // Do not allow null values. We follow the semantics of Hive's 
collect_list/collect_set here.
     // See: 
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
-    val value = child.eval(input)
     if (value != null) {
       buffer += value
     }
+    buffer
   }
 
-  override def merge(buffer: InternalRow, input: InternalRow): Unit = {
-    sys.error("Collect cannot be used in partial aggregations.")
+  override def merge(buffer: T, other: T): T = {
+    buffer ++= other
   }
 
-  override def eval(input: InternalRow): Any = {
+  override def eval(buffer: T): Any = {
     new GenericArrayData(buffer.toArray)
   }
+
+  private lazy val projection = UnsafeProjection.create(
+    Array[DataType](ArrayType(elementType = child.dataType, containsNull = 
false)))
+  private lazy val row = new UnsafeRow(1)
+
+  override def serialize(obj: T): Array[Byte] = {
+    val array = new GenericArrayData(obj.toArray)
+    projection.apply(InternalRow.apply(array)).getBytes()
+  }
+
+  override def deserialize(bytes: Array[Byte]): T = {
+    val buffer = createAggregationBuffer()
+    row.pointTo(bytes, bytes.length)
+    row.getArray(0).foreach(child.dataType, (_, x: Any) => buffer += x)
+    buffer
+  }
 }
 
 /**
@@ -88,7 +92,7 @@ abstract class Collect extends ImperativeAggregate {
 case class CollectList(
     child: Expression,
     mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0) extends Collect {
+    inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {
 
   def this(child: Expression) = this(child, 0, 0)
 
@@ -98,9 +102,9 @@ case class CollectList(
   override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): 
ImperativeAggregate =
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-  override def prettyName: String = "collect_list"
+  override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = 
mutable.ArrayBuffer.empty
 
-  override protected[this] val buffer: mutable.ArrayBuffer[Any] = 
mutable.ArrayBuffer.empty
+  override def prettyName: String = "collect_list"
 }
 
 /**
@@ -111,7 +115,7 @@ case class CollectList(
 case class CollectSet(
     child: Expression,
     mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0) extends Collect {
+    inputAggBufferOffset: Int = 0) extends Collect[mutable.HashSet[Any]] {
 
   def this(child: Expression) = this(child, 0, 0)
 
@@ -131,5 +135,5 @@ case class CollectSet(
 
   override def prettyName: String = "collect_set"
 
-  override protected[this] val buffer: mutable.HashSet[Any] = 
mutable.HashSet.empty
+  override def createAggregationBuffer(): mutable.HashSet[Any] = 
mutable.HashSet.empty
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/52636226/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 8e63fba..ccd4ae6 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -458,7 +458,9 @@ abstract class DeclarativeAggregate
  * instead of hash based aggregation, as TypedImperativeAggregate use 
BinaryType as aggregation
  * buffer's storage format, which is not supported by hash based aggregation. 
Hash based
  * aggregation only support aggregation buffer of mutable types (like 
LongType, IntType that have
- * fixed length and can be mutated in place in UnsafeRow)
+ * fixed length and can be mutated in place in UnsafeRow).
+ * NOTE: The newly added ObjectHashAggregateExec supports 
TypedImperativeAggregate functions in
+ * hash based aggregation under some constraints.
  */
 abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/52636226/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
index 0b973c3..5c1faae 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
@@ -59,15 +59,6 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
     comparePlans(input, rewrite)
   }
 
-  test("single distinct group with non-partial aggregates") {
-    val input = testRelation
-      .groupBy('a, 'd)(
-        countDistinct('e, 'c).as('agg1),
-        CollectSet('b).toAggregateExpression().as('agg2))
-      .analyze
-    checkRewrite(RewriteDistinctAggregates(input))
-  }
-
   test("multiple distinct groups") {
     val input = testRelation
       .groupBy('a)(countDistinct('b, 'c), countDistinct('d))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to