Repository: spark
Updated Branches:
  refs/heads/master 6b8cb1fe5 -> 2b0cc4e0d


[SPARK-12978][SQL] Skip unnecessary final group-by when input data already 
clustered with group-by keys

This ticket targets an optimization that skips the unnecessary group-by operation 
shown below:

Without the optimization:
```
== Physical Plan ==
TungstenAggregate(key=[col0#159], 
functions=[(sum(col1#160),mode=Final,isDistinct=false),(avg(col2#161),mode=Final,isDistinct=false)],
 output=[col0#159,sum(col1)#177,avg(col2)#178])
+- TungstenAggregate(key=[col0#159], 
functions=[(sum(col1#160),mode=Partial,isDistinct=false),(avg(col2#161),mode=Partial,isDistinct=false)],
 output=[col0#159,sum#200,sum#201,count#202L])
   +- TungstenExchange hashpartitioning(col0#159,200), None
      +- InMemoryColumnarTableScan [col0#159,col1#160,col2#161], 
InMemoryRelation [col0#159,col1#160,col2#161], true, 10000, StorageLevel(true, 
true, false, true, 1), ConvertToUnsafe, None
```

With the optimization:
```
== Physical Plan ==
TungstenAggregate(key=[col0#159], 
functions=[(sum(col1#160),mode=Complete,isDistinct=false),(avg(col2#161),mode=Complete,isDistinct=false)],
 output=[col0#159,sum(col1)#177,avg(col2)#178])
+- TungstenExchange hashpartitioning(col0#159,200), None
  +- InMemoryColumnarTableScan [col0#159,col1#160,col2#161], InMemoryRelation 
[col0#159,col1#160,col2#161], true, 10000, StorageLevel(true, true, false, 
true, 1), ConvertToUnsafe, None
```

Author: Takeshi YAMAMURO <[email protected]>

Closes #10896 from maropu/SkipGroupbySpike.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2b0cc4e0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2b0cc4e0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2b0cc4e0

Branch: refs/heads/master
Commit: 2b0cc4e0dfa4ffb9f21ff4a303015bc9c962d42b
Parents: 6b8cb1f
Author: Takeshi YAMAMURO <[email protected]>
Authored: Thu Aug 25 12:39:58 2016 +0200
Committer: Herman van Hovell <[email protected]>
Committed: Thu Aug 25 12:39:58 2016 +0200

----------------------------------------------------------------------
 .../spark/sql/execution/SparkStrategies.scala   |  17 +-
 .../sql/execution/aggregate/AggUtils.scala      | 250 +++++++++----------
 .../sql/execution/aggregate/AggregateExec.scala |  56 +++++
 .../execution/aggregate/HashAggregateExec.scala |  22 +-
 .../execution/aggregate/SortAggregateExec.scala |  24 +-
 .../execution/exchange/EnsureRequirements.scala |  38 ++-
 .../org/apache/spark/sql/DataFrameSuite.scala   |  15 +-
 .../spark/sql/execution/PlannerSuite.scala      |  59 +++--
 8 files changed, 257 insertions(+), 224 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2b0cc4e0/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 4aaf454..cda3b2b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -259,24 +259,17 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
         }
 
         val aggregateOperator =
-          if 
(aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
-            if (functionsWithDistinct.nonEmpty) {
-              sys.error("Distinct columns cannot exist in Aggregate operator 
containing " +
-                "aggregate functions which don't support partial aggregation.")
-            } else {
-              aggregate.AggUtils.planAggregateWithoutPartial(
-                groupingExpressions,
-                aggregateExpressions,
-                resultExpressions,
-                planLater(child))
-            }
-          } else if (functionsWithDistinct.isEmpty) {
+          if (functionsWithDistinct.isEmpty) {
             aggregate.AggUtils.planAggregateWithoutDistinct(
               groupingExpressions,
               aggregateExpressions,
               resultExpressions,
               planLater(child))
           } else {
+            if 
(aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
+              sys.error("Distinct columns cannot exist in Aggregate operator 
containing " +
+                "aggregate functions which don't support partial aggregation.")
+            }
             aggregate.AggUtils.planAggregateWithOneDistinct(
               groupingExpressions,
               functionsWithDistinct,

http://git-wip-us.apache.org/repos/asf/spark/blob/2b0cc4e0/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 4fbb9d5..fe75ece 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -19,34 +19,97 @@ package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, 
StateStoreSaveExec}
 
 /**
+ * A pattern that finds aggregate operators that support partial aggregations.
+ */
+object PartialAggregate {
+
+  def unapply(plan: SparkPlan): Option[Distribution] = plan match {
+    case agg: AggregateExec if 
AggUtils.supportPartialAggregate(agg.aggregateExpressions) =>
+      Some(agg.requiredChildDistribution.head)
+    case _ =>
+      None
+  }
+}
+
+/**
  * Utility functions used by the query planner to convert our plan to new 
aggregation code path.
  */
 object AggUtils {
 
-  def planAggregateWithoutPartial(
+  def supportPartialAggregate(aggregateExpressions: Seq[AggregateExpression]): 
Boolean = {
+    aggregateExpressions.map(_.aggregateFunction).forall(_.supportsPartial)
+  }
+
+  private def createPartialAggregateExec(
       groupingExpressions: Seq[NamedExpression],
       aggregateExpressions: Seq[AggregateExpression],
-      resultExpressions: Seq[NamedExpression],
-      child: SparkPlan): Seq[SparkPlan] = {
+      child: SparkPlan): SparkPlan = {
+    val groupingAttributes = groupingExpressions.map(_.toAttribute)
+    val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
+    val partialAggregateExpressions = aggregateExpressions.map {
+      case agg @ AggregateExpression(_, _, false, _) if 
functionsWithDistinct.length > 0 =>
+        agg.copy(mode = PartialMerge)
+      case agg =>
+        agg.copy(mode = Partial)
+    }
+    val partialAggregateAttributes =
+      
partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+    val partialResultExpressions =
+      groupingAttributes ++
+        
partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
 
-    val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = 
Complete))
-    val completeAggregateAttributes = 
completeAggregateExpressions.map(_.resultAttribute)
-    SortAggregateExec(
-      requiredChildDistributionExpressions = Some(groupingExpressions),
+    createAggregateExec(
+      requiredChildDistributionExpressions = None,
       groupingExpressions = groupingExpressions,
-      aggregateExpressions = completeAggregateExpressions,
-      aggregateAttributes = completeAggregateAttributes,
-      initialInputBufferOffset = 0,
-      resultExpressions = resultExpressions,
-      child = child
-    ) :: Nil
+      aggregateExpressions = partialAggregateExpressions,
+      aggregateAttributes = partialAggregateAttributes,
+      initialInputBufferOffset = if (functionsWithDistinct.length > 0) {
+        groupingExpressions.length + 
functionsWithDistinct.head.aggregateFunction.children.length
+      } else {
+        0
+      },
+      resultExpressions = partialResultExpressions,
+      child = child)
   }
 
-  private def createAggregate(
+  private def updateMergeAggregateMode(aggregateExpressions: 
Seq[AggregateExpression]) = {
+    def updateMode(mode: AggregateMode) = mode match {
+      case Partial => PartialMerge
+      case Complete => Final
+      case mode => mode
+    }
+    aggregateExpressions.map(e => e.copy(mode = updateMode(e.mode)))
+  }
+
+  /**
+   * Builds new merge and map-side [[AggregateExec]]s from an input aggregate 
operator.
+   * If an aggregation needs a shuffle for satisfying its own distribution and 
supports partial
+   * aggregations, a map-side aggregation is appended before the shuffle in
+   * [[org.apache.spark.sql.execution.exchange.EnsureRequirements]].
+   */
+  def createMapMergeAggregatePair(operator: SparkPlan): (SparkPlan, SparkPlan) 
= operator match {
+    case agg: AggregateExec =>
+      val mapSideAgg = createPartialAggregateExec(
+        agg.groupingExpressions, agg.aggregateExpressions, agg.child)
+      val mergeAgg = createAggregateExec(
+        requiredChildDistributionExpressions = 
agg.requiredChildDistributionExpressions,
+        groupingExpressions = agg.groupingExpressions.map(_.toAttribute),
+        aggregateExpressions = 
updateMergeAggregateMode(agg.aggregateExpressions),
+        aggregateAttributes = agg.aggregateAttributes,
+        initialInputBufferOffset = agg.groupingExpressions.length,
+        resultExpressions = agg.resultExpressions,
+        child = mapSideAgg
+      )
+
+      (mergeAgg, mapSideAgg)
+  }
+
+  private def createAggregateExec(
       requiredChildDistributionExpressions: Option[Seq[Expression]] = None,
       groupingExpressions: Seq[NamedExpression] = Nil,
       aggregateExpressions: Seq[AggregateExpression] = Nil,
@@ -55,7 +118,8 @@ object AggUtils {
       resultExpressions: Seq[NamedExpression] = Nil,
       child: SparkPlan): SparkPlan = {
     val useHash = HashAggregateExec.supportsAggregate(
-      aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
+      aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) &&
+      supportPartialAggregate(aggregateExpressions)
     if (useHash) {
       HashAggregateExec(
         requiredChildDistributionExpressions = 
requiredChildDistributionExpressions,
@@ -82,43 +146,21 @@ object AggUtils {
       aggregateExpressions: Seq[AggregateExpression],
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {
-    // Check if we can use HashAggregate.
-
-    // 1. Create an Aggregate Operator for partial aggregations.
-
     val groupingAttributes = groupingExpressions.map(_.toAttribute)
-    val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = 
Partial))
-    val partialAggregateAttributes =
-      
partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-    val partialResultExpressions =
-      groupingAttributes ++
-        
partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
-
-    val partialAggregate = createAggregate(
-        requiredChildDistributionExpressions = None,
-        groupingExpressions = groupingExpressions,
-        aggregateExpressions = partialAggregateExpressions,
-        aggregateAttributes = partialAggregateAttributes,
-        initialInputBufferOffset = 0,
-        resultExpressions = partialResultExpressions,
-        child = child)
-
-    // 2. Create an Aggregate Operator for final aggregations.
-    val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = 
Final))
-    // The attributes of the final aggregation buffer, which is presented as 
input to the result
-    // projection:
-    val finalAggregateAttributes = 
finalAggregateExpressions.map(_.resultAttribute)
-
-    val finalAggregate = createAggregate(
-        requiredChildDistributionExpressions = Some(groupingAttributes),
-        groupingExpressions = groupingAttributes,
-        aggregateExpressions = finalAggregateExpressions,
-        aggregateAttributes = finalAggregateAttributes,
-        initialInputBufferOffset = groupingExpressions.length,
-        resultExpressions = resultExpressions,
-        child = partialAggregate)
+    val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = 
Complete))
+    val completeAggregateAttributes = 
completeAggregateExpressions.map(_.resultAttribute)
+    val supportPartial = supportPartialAggregate(aggregateExpressions)
 
-    finalAggregate :: Nil
+    createAggregateExec(
+      requiredChildDistributionExpressions =
+        Some(if (supportPartial) groupingAttributes else groupingExpressions),
+      groupingExpressions = groupingExpressions,
+      aggregateExpressions = completeAggregateExpressions,
+      aggregateAttributes = completeAggregateAttributes,
+      initialInputBufferOffset = 0,
+      resultExpressions = resultExpressions,
+      child = child
+    ) :: Nil
   }
 
   def planAggregateWithOneDistinct(
@@ -141,39 +183,23 @@ object AggUtils {
     val distinctAttributes = namedDistinctExpressions.map(_.toAttribute)
     val groupingAttributes = groupingExpressions.map(_.toAttribute)
 
-    // 1. Create an Aggregate Operator for partial aggregations.
+    // 1. Create an Aggregate Operator for non-distinct aggregations.
     val partialAggregate: SparkPlan = {
       val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = 
Partial))
       val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
-      // We will group by the original grouping expression, plus an additional 
expression for the
-      // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, 
the grouping
-      // expressions will be [key, value].
-      createAggregate(
-        groupingExpressions = groupingExpressions ++ namedDistinctExpressions,
-        aggregateExpressions = aggregateExpressions,
-        aggregateAttributes = aggregateAttributes,
-        resultExpressions = groupingAttributes ++ distinctAttributes ++
-          
aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
-        child = child)
-    }
-
-    // 2. Create an Aggregate Operator for partial merge aggregations.
-    val partialMergeAggregate: SparkPlan = {
-      val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = 
PartialMerge))
-      val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
-      createAggregate(
+      createAggregateExec(
         requiredChildDistributionExpressions =
           Some(groupingAttributes ++ distinctAttributes),
-        groupingExpressions = groupingAttributes ++ distinctAttributes,
+        groupingExpressions = groupingExpressions ++ namedDistinctExpressions,
         aggregateExpressions = aggregateExpressions,
         aggregateAttributes = aggregateAttributes,
         initialInputBufferOffset = (groupingAttributes ++ 
distinctAttributes).length,
         resultExpressions = groupingAttributes ++ distinctAttributes ++
           
aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
-        child = partialAggregate)
+        child = child)
     }
 
-    // 3. Create an Aggregate operator for partial aggregation (for distinct)
+    // 2. Create an Aggregate Operator for the final aggregation.
     val distinctColumnAttributeLookup = 
distinctExpressions.zip(distinctAttributes).toMap
     val rewrittenDistinctFunctions = functionsWithDistinct.map {
       // Children of an AggregateFunction with DISTINCT keyword has already
@@ -183,38 +209,6 @@ object AggUtils {
         aggregateFunction.transformDown(distinctColumnAttributeLookup)
           .asInstanceOf[AggregateFunction]
     }
-
-    val partialDistinctAggregate: SparkPlan = {
-      val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode 
= PartialMerge))
-      // The attributes of the final aggregation buffer, which is presented as 
input to the result
-      // projection:
-      val mergeAggregateAttributes = 
mergeAggregateExpressions.map(_.resultAttribute)
-      val (distinctAggregateExpressions, distinctAggregateAttributes) =
-        rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
-          // We rewrite the aggregate function to a non-distinct aggregation 
because
-          // its input will have distinct arguments.
-          // We just keep the isDistinct setting to true, so when users look 
at the query plan,
-          // they still can see distinct aggregations.
-          val expr = AggregateExpression(func, Partial, isDistinct = true)
-          // Use original AggregationFunction to lookup attributes, which is 
used to build
-          // aggregateFunctionToAttribute
-          val attr = functionsWithDistinct(i).resultAttribute
-          (expr, attr)
-      }.unzip
-
-      val partialAggregateResult = groupingAttributes ++
-          
mergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) 
++
-          
distinctAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
-      createAggregate(
-        groupingExpressions = groupingAttributes,
-        aggregateExpressions = mergeAggregateExpressions ++ 
distinctAggregateExpressions,
-        aggregateAttributes = mergeAggregateAttributes ++ 
distinctAggregateAttributes,
-        initialInputBufferOffset = (groupingAttributes ++ 
distinctAttributes).length,
-        resultExpressions = partialAggregateResult,
-        child = partialMergeAggregate)
-    }
-
-    // 4. Create an Aggregate Operator for the final aggregation.
     val finalAndCompleteAggregate: SparkPlan = {
       val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode 
= Final))
       // The attributes of the final aggregation buffer, which is presented as 
input to the result
@@ -225,23 +219,23 @@ object AggUtils {
         rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
           // We rewrite the aggregate function to a non-distinct aggregation 
because
           // its input will have distinct arguments.
-          // We just keep the isDistinct setting to true, so when users look 
at the query plan,
-          // they still can see distinct aggregations.
-          val expr = AggregateExpression(func, Final, isDistinct = true)
+          // We keep the isDistinct setting to true because this flag is used 
to generate partial
+          // aggregations and it is easy to see aggregation types in the query 
plan.
+          val expr = AggregateExpression(func, Complete, isDistinct = true)
           // Use original AggregationFunction to lookup attributes, which is 
used to build
           // aggregateFunctionToAttribute
           val attr = functionsWithDistinct(i).resultAttribute
           (expr, attr)
-      }.unzip
+        }.unzip
 
-      createAggregate(
+      createAggregateExec(
         requiredChildDistributionExpressions = Some(groupingAttributes),
         groupingExpressions = groupingAttributes,
         aggregateExpressions = finalAggregateExpressions ++ 
distinctAggregateExpressions,
         aggregateAttributes = finalAggregateAttributes ++ 
distinctAggregateAttributes,
         initialInputBufferOffset = groupingAttributes.length,
         resultExpressions = resultExpressions,
-        child = partialDistinctAggregate)
+        child = partialAggregate)
     }
 
     finalAndCompleteAggregate :: Nil
@@ -249,13 +243,14 @@ object AggUtils {
 
   /**
    * Plans a streaming aggregation using the following progression:
-   *  - Partial Aggregation
-   *  - Shuffle
-   *  - Partial Merge (now there is at most 1 tuple per group)
+   *  - Partial Aggregation (now there is at most 1 tuple per group)
    *  - StateStoreRestore (now there is 1 tuple from this batch + optionally 
one from the previous)
    *  - PartialMerge (now there is at most 1 tuple per group)
    *  - StateStoreSave (saves the tuple for the next batch)
    *  - Complete (output the current result of the aggregation)
+   *
+   *  If the first aggregation needs a shuffle to satisfy its distribution, a 
map-side partial
+   *  aggregation and a shuffle are added in `EnsureRequirements`.
    */
   def planStreamingAggregation(
       groupingExpressions: Seq[NamedExpression],
@@ -268,39 +263,24 @@ object AggUtils {
     val partialAggregate: SparkPlan = {
       val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = 
Partial))
       val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
-      // We will group by the original grouping expression, plus an additional 
expression for the
-      // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, 
the grouping
-      // expressions will be [key, value].
-      createAggregate(
-        groupingExpressions = groupingExpressions,
-        aggregateExpressions = aggregateExpressions,
-        aggregateAttributes = aggregateAttributes,
-        resultExpressions = groupingAttributes ++
-            
aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
-        child = child)
-    }
-
-    val partialMerged1: SparkPlan = {
-      val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = 
PartialMerge))
-      val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
-      createAggregate(
+      createAggregateExec(
         requiredChildDistributionExpressions =
             Some(groupingAttributes),
-        groupingExpressions = groupingAttributes,
+        groupingExpressions = groupingExpressions,
         aggregateExpressions = aggregateExpressions,
         aggregateAttributes = aggregateAttributes,
         initialInputBufferOffset = groupingAttributes.length,
         resultExpressions = groupingAttributes ++
             
aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
-        child = partialAggregate)
+        child = child)
     }
 
-    val restored = StateStoreRestoreExec(groupingAttributes, None, 
partialMerged1)
+    val restored = StateStoreRestoreExec(groupingAttributes, None, 
partialAggregate)
 
-    val partialMerged2: SparkPlan = {
+    val partialMerged: SparkPlan = {
       val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = 
PartialMerge))
       val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
-      createAggregate(
+      createAggregateExec(
         requiredChildDistributionExpressions =
             Some(groupingAttributes),
         groupingExpressions = groupingAttributes,
@@ -314,7 +294,7 @@ object AggUtils {
     // Note: stateId and returnAllStates are filled in later with preparation 
rules
     // in IncrementalExecution.
     val saved = StateStoreSaveExec(
-      groupingAttributes, stateId = None, returnAllStates = None, 
partialMerged2)
+      groupingAttributes, stateId = None, returnAllStates = None, 
partialMerged)
 
     val finalAndCompleteAggregate: SparkPlan = {
       val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode 
= Final))
@@ -322,7 +302,7 @@ object AggUtils {
       // projection:
       val finalAggregateAttributes = 
finalAggregateExpressions.map(_.resultAttribute)
 
-      createAggregate(
+      createAggregateExec(
         requiredChildDistributionExpressions = Some(groupingAttributes),
         groupingExpressions = groupingAttributes,
         aggregateExpressions = finalAggregateExpressions,

http://git-wip-us.apache.org/repos/asf/spark/blob/2b0cc4e0/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala
new file mode 100644
index 0000000..b88a8aa
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.UnaryExecNode
+
+/**
+ * A base class for aggregate implementation.
+ */
+abstract class AggregateExec extends UnaryExecNode {
+
+  def requiredChildDistributionExpressions: Option[Seq[Expression]]
+  def groupingExpressions: Seq[NamedExpression]
+  def aggregateExpressions: Seq[AggregateExpression]
+  def aggregateAttributes: Seq[Attribute]
+  def initialInputBufferOffset: Int
+  def resultExpressions: Seq[NamedExpression]
+
+  protected[this] val aggregateBufferAttributes = {
+    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+  }
+
+  override def producedAttributes: AttributeSet =
+    AttributeSet(aggregateAttributes) ++
+      
AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+      AttributeSet(aggregateBufferAttributes)
+
+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  override def requiredChildDistribution: List[Distribution] = {
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/2b0cc4e0/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index bd7efa6..525c7e3 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
@@ -42,11 +41,7 @@ case class HashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode with CodegenSupport {
-
-  private[this] val aggregateBufferAttributes = {
-    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-  }
+  extends AggregateExec with CodegenSupport {
 
   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
 
@@ -60,21 +55,6 @@ case class HashAggregateExec(
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
     "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"))
 
-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  override def producedAttributes: AttributeSet =
-    AttributeSet(aggregateAttributes) ++
-    
AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
-    AttributeSet(aggregateBufferAttributes)
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
-      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
-      case None => UnspecifiedDistribution :: Nil
-    }
-  }
-
   // This is for testing. We force TungstenAggregationIterator to fall back to 
the unsafe row hash
   // map and/or the sort-based aggregation once it has processed a given 
number of input rows.
   private val testFallbackStartsAt: Option[(Int, Int)] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/2b0cc4e0/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index 2a81a82..68f86fc 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -22,8 +22,7 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, 
ClusteredDistribution, Distribution, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.Utils
 
@@ -38,30 +37,11 @@ case class SortAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode {
-
-  private[this] val aggregateBufferAttributes = {
-    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-  }
-
-  override def producedAttributes: AttributeSet =
-    AttributeSet(aggregateAttributes) ++
-      
AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
-      AttributeSet(aggregateBufferAttributes)
+  extends AggregateExec {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output 
rows"))
 
-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
-      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
-      case None => UnspecifiedDistribution :: Nil
-    }
-  }
-
   override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
     groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/2b0cc4e0/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 446571a..951051c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.aggregate.AggUtils
+import org.apache.spark.sql.execution.aggregate.PartialAggregate
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -151,18 +153,32 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
   private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
     val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
     val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
-    var children: Seq[SparkPlan] = operator.children
-    assert(requiredChildDistributions.length == children.length)
-    assert(requiredChildOrderings.length == children.length)
+    assert(requiredChildDistributions.length == operator.children.length)
+    assert(requiredChildOrderings.length == operator.children.length)
 
-    // Ensure that the operator's children satisfy their output distribution requirements:
-    children = children.zip(requiredChildDistributions).map {
-      case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
-        child
-      case (child, BroadcastDistribution(mode)) =>
-        BroadcastExchangeExec(mode, child)
-      case (child, distribution) =>
-        ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
+    def createShuffleExchange(dist: Distribution, child: SparkPlan) =
+      ShuffleExchange(createPartitioning(dist, defaultNumPreShufflePartitions), child)
+
+    var (parent, children) = operator match {
+      case PartialAggregate(childDist) if !operator.outputPartitioning.satisfies(childDist) =>
+        // If an aggregation needs a shuffle and supports partial aggregations, a map-side partial
+        // aggregation and a shuffle are added as children. The new map-side aggregation is not
+        // visited by `transformUp`, so its own child requirements (e.g. ordering) are ensured here.
+        val (mergeAgg, mapSideAgg) = AggUtils.createMapMergeAggregatePair(operator)
+        (mergeAgg, createShuffleExchange(
+          requiredChildDistributions.head, ensureDistributionAndOrdering(mapSideAgg)) :: Nil)
+      case _ =>
+        // Ensure that the operator's children satisfy their output distribution requirements:
+        val childrenWithDist = operator.children.zip(requiredChildDistributions)
+        val newChildren = childrenWithDist.map {
+          case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
+            child
+          case (child, BroadcastDistribution(mode)) =>
+            BroadcastExchangeExec(mode, child)
+          case (child, distribution) =>
+            createShuffleExchange(distribution, child)
+        }
+        (operator, newChildren)
     }
 
     // If the operator has multiple children and specifies child output distributions (e.g. join),
@@ -246,7 +262,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       }
     }
 
-    operator.withNewChildren(children)
+    parent.withNewChildren(children)
   }
 
   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {

http://git-wip-us.apache.org/repos/asf/spark/blob/2b0cc4e0/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 499f318..cd48577 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1248,17 +1248,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
   }
 
   /**
-   * Verifies that there is no Exchange between the Aggregations for `df`
+   * Verifies that `df` is planned with exactly one aggregation operator, i.e. the aggregation
+   * was not split into a map-side partial aggregation and a merge aggregation around a shuffle.
    */
-  private def verifyNonExchangingAgg(df: DataFrame) = {
-    var atFirstAgg: Boolean = false
-    df.queryExecution.executedPlan.foreach {
-      case agg: HashAggregateExec =>
-        atFirstAgg = !atFirstAgg
-      case _ =>
-        if (atFirstAgg) {
-          fail("Should not have operators between the two aggregations")
-        }
-    }
-  }
+  private def verifyNonExchangingSingleAgg(df: DataFrame) = {
+    val plan = df.queryExecution.executedPlan
+    val aggs = plan.collect { case agg: HashAggregateExec => agg }
+    if (aggs.size != 1) {
+      fail(s"Expected a single aggregation, but found ${aggs.size} in:\n$plan")
+    }
+  }
 
@@ -1292,9 +1289,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
-    // Group by the column we are distributed by. This should generate a plan with no exchange
-    // between the aggregates
+    // Group by the column we are distributed by. This should generate a plan with a single
+    // aggregate and no exchange added for it
     val df3 = testData.repartition($"key").groupBy("key").count()
-    verifyNonExchangingAgg(df3)
-    verifyNonExchangingAgg(testData.repartition($"key", $"value")
+    verifyNonExchangingSingleAgg(df3)
+    verifyNonExchangingSingleAgg(testData.repartition($"key", $"value")
       .groupBy("key", "value").count())
+    verifyNonExchangingSingleAgg(testData.repartition($"key").groupBy("key", "value").count())
 
     // Grouping by just the first distributeBy expr, need to exchange.
     verifyExchangingAgg(testData.repartition($"key", $"value")

http://git-wip-us.apache.org/repos/asf/spark/blob/2b0cc4e0/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 13490c3..436ff59 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, Row}
+import org.apache.spark.sql.{execution, DataFrame, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
 import org.apache.spark.sql.catalyst.plans.Inner
@@ -37,36 +37,66 @@ class PlannerSuite extends SharedSQLContext {
 
   setupTestData()
 
-  private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
+  private def testPartialAggregationPlan(query: LogicalPlan): Seq[SparkPlan] = {
     val planner = spark.sessionState.planner
     import planner._
-    val plannedOption = Aggregation(query).headOption
-    val planned =
-      plannedOption.getOrElse(
-        fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
-    val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
-
-    // For the new aggregation code path, there will be four aggregate operator for
-    // distinct aggregations.
-    assert(
-      aggregations.size == 2 || aggregations.size == 4,
-      s"The plan of query $query does not have partial aggregations.")
+    val ensureRequirements = EnsureRequirements(spark.sessionState.conf)
+    val planned = Aggregation(query).headOption.map(ensureRequirements(_))
+      .getOrElse(fail(s"Could not plan aggregation query $query. Is it an aggregation query?"))
+    planned.collect { case n if n.nodeName contains "Aggregate" => n }
   }
 
   test("count is partially aggregated") {
     val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
-    testPartialAggregationPlan(query)
+    assert(testPartialAggregationPlan(query).size == 2,
+      s"The plan of query $query does not have partial aggregations.")
   }
 
   test("count distinct is partially aggregated") {
     val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
-    testPartialAggregationPlan(query)
+    // For the new aggregation code path, there will be four aggregate operators for distinct
+    // aggregations.
+    assert(testPartialAggregationPlan(query).size == 4,
+      s"The plan of query $query does not have partial aggregations.")
   }
 
   test("mixed aggregates are partially aggregated") {
     val query =
       testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
-    testPartialAggregationPlan(query)
+    // For the new aggregation code path, there will be four aggregate operators for distinct
+    // aggregations.
+    assert(testPartialAggregationPlan(query).size == 4,
+      s"The plan of query $query does not have partial aggregations.")
+  }
+
+  test("non-partial aggregation for aggregates") {
+    withTempView("testNonPartialAggregation") {
+      val schema = StructType(StructField("value", IntegerType, true) :: Nil)
+      val row = Row.fromSeq(Seq.fill(1)(null))
+      val rowRDD = sparkContext.parallelize(row :: Nil)
+      spark.createDataFrame(rowRDD, schema).repartition($"value")
+        .createOrReplaceTempView("testNonPartialAggregation")
+
+      val planned1 = sql("SELECT SUM(value) FROM testNonPartialAggregation GROUP BY value")
+        .queryExecution.executedPlan
+
+      // If input data are already partitioned and the same columns are used in grouping keys and
+      // aggregation values, no partial aggregation exists in the query plan.
+      val aggOps1 = planned1.collect { case n if n.nodeName contains "Aggregate" => n }
+      assert(aggOps1.size == 1, s"The plan $planned1 has partial aggregations.")
+
+      val planned2 = sql(
+        """
+          |SELECT t.value, SUM(DISTINCT t.value)
+          |FROM (SELECT * FROM testNonPartialAggregation ORDER BY value) t
+          |GROUP BY t.value
+        """.stripMargin).queryExecution.executedPlan
+
+      // With partial aggregation a DISTINCT aggregate uses four aggregate operators (see "count
+      // distinct is partially aggregated"), so only two should remain without it.
+      val aggOps2 = planned2.collect { case n if n.nodeName contains "Aggregate" => n }
+      assert(aggOps2.size == 2, s"The plan $planned2 has partial aggregations.")
+    }
   }
 
   test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to