Repository: spark
Updated Branches:
refs/heads/master 6b8cb1fe5 -> 2b0cc4e0d
[SPARK-12978][SQL] Skip unnecessary final group-by when input data already
clustered with group-by keys
This ticket adds an optimization that skips an unnecessary group-by operation when the input data is already clustered by the group-by keys; compare the plans below.
Without opt.:
```
== Physical Plan ==
TungstenAggregate(key=[col0#159],
functions=[(sum(col1#160),mode=Final,isDistinct=false),(avg(col2#161),mode=Final,isDistinct=false)],
output=[col0#159,sum(col1)#177,avg(col2)#178])
+- TungstenAggregate(key=[col0#159],
functions=[(sum(col1#160),mode=Partial,isDistinct=false),(avg(col2#161),mode=Partial,isDistinct=false)],
output=[col0#159,sum#200,sum#201,count#202L])
+- TungstenExchange hashpartitioning(col0#159,200), None
+- InMemoryColumnarTableScan [col0#159,col1#160,col2#161],
InMemoryRelation [col0#159,col1#160,col2#161], true, 10000, StorageLevel(true,
true, false, true, 1), ConvertToUnsafe, None
```
With opt.:
```
== Physical Plan ==
TungstenAggregate(key=[col0#159],
functions=[(sum(col1#160),mode=Complete,isDistinct=false),(avg(col2#161),mode=Final,isDistinct=false)],
output=[col0#159,sum(col1)#177,avg(col2)#178])
+- TungstenExchange hashpartitioning(col0#159,200), None
+- InMemoryColumnarTableScan [col0#159,col1#160,col2#161], InMemoryRelation
[col0#159,col1#160,col2#161], true, 10000, StorageLevel(true, true, false,
true, 1), ConvertToUnsafe, None
```
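For reference, here is a minimal, self-contained sketch (not part of the commit) of a query that exercises this code path: the input is repartitioned by the group-by key before aggregating, so the planner can emit a single aggregate instead of a partial/final pair. The SparkSession setup and the column names `col0`/`col1`/`col2` are hypothetical.
```scala
import org.apache.spark.sql.SparkSession

object SkipPartialAggregateExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("skip-partial-aggregate-example")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Cluster the input on the future group-by key before aggregating.
    val df = Seq((1, 10L, 1.0), (1, 20L, 2.0), (2, 30L, 3.0))
      .toDF("col0", "col1", "col2")
      .repartition($"col0")

    // With this optimization, the plan should show a single aggregate operator
    // (Complete mode) rather than a Partial/Final pair.
    df.groupBy("col0")
      .agg(Map("col1" -> "sum", "col2" -> "avg"))
      .explain()

    spark.stop()
  }
}
```
Against a build that includes this patch, the explain output should resemble the "With opt." plan above.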
Author: Takeshi YAMAMURO <[email protected]>
Closes #10896 from maropu/SkipGroupbySpike.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2b0cc4e0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2b0cc4e0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2b0cc4e0
Branch: refs/heads/master
Commit: 2b0cc4e0dfa4ffb9f21ff4a303015bc9c962d42b
Parents: 6b8cb1f
Author: Takeshi YAMAMURO <[email protected]>
Authored: Thu Aug 25 12:39:58 2016 +0200
Committer: Herman van Hovell <[email protected]>
Committed: Thu Aug 25 12:39:58 2016 +0200
----------------------------------------------------------------------
.../spark/sql/execution/SparkStrategies.scala | 17 +-
.../sql/execution/aggregate/AggUtils.scala | 250 +++++++++----------
.../sql/execution/aggregate/AggregateExec.scala | 56 +++++
.../execution/aggregate/HashAggregateExec.scala | 22 +-
.../execution/aggregate/SortAggregateExec.scala | 24 +-
.../execution/exchange/EnsureRequirements.scala | 38 ++-
.../org/apache/spark/sql/DataFrameSuite.scala | 15 +-
.../spark/sql/execution/PlannerSuite.scala | 59 +++--
8 files changed, 257 insertions(+), 224 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/2b0cc4e0/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 4aaf454..cda3b2b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -259,24 +259,17 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
val aggregateOperator =
-        if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
-          if (functionsWithDistinct.nonEmpty) {
-            sys.error("Distinct columns cannot exist in Aggregate operator containing " +
-              "aggregate functions which don't support partial aggregation.")
- } else {
- aggregate.AggUtils.planAggregateWithoutPartial(
- groupingExpressions,
- aggregateExpressions,
- resultExpressions,
- planLater(child))
- }
- } else if (functionsWithDistinct.isEmpty) {
+ if (functionsWithDistinct.isEmpty) {
aggregate.AggUtils.planAggregateWithoutDistinct(
groupingExpressions,
aggregateExpressions,
resultExpressions,
planLater(child))
} else {
+        if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
+          sys.error("Distinct columns cannot exist in Aggregate operator containing " +
+            "aggregate functions which don't support partial aggregation.")
+ }
aggregate.AggUtils.planAggregateWithOneDistinct(
groupingExpressions,
functionsWithDistinct,
http://git-wip-us.apache.org/repos/asf/spark/blob/2b0cc4e0/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 4fbb9d5..fe75ece 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -19,34 +19,97 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.physical.Distribution
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec}
/**
+ * A pattern that finds aggregate operators to support partial aggregations.
+ */
+object PartialAggregate {
+
+ def unapply(plan: SparkPlan): Option[Distribution] = plan match {
+    case agg: AggregateExec if AggUtils.supportPartialAggregate(agg.aggregateExpressions) =>
+ Some(agg.requiredChildDistribution.head)
+ case _ =>
+ None
+ }
+}
+
+/**
 * Utility functions used by the query planner to convert our plan to new aggregation code path.
*/
object AggUtils {
- def planAggregateWithoutPartial(
+  def supportPartialAggregate(aggregateExpressions: Seq[AggregateExpression]): Boolean = {
+ aggregateExpressions.map(_.aggregateFunction).forall(_.supportsPartial)
+ }
+
+ private def createPartialAggregateExec(
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
- resultExpressions: Seq[NamedExpression],
- child: SparkPlan): Seq[SparkPlan] = {
+ child: SparkPlan): SparkPlan = {
+ val groupingAttributes = groupingExpressions.map(_.toAttribute)
+ val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
+ val partialAggregateExpressions = aggregateExpressions.map {
+      case agg @ AggregateExpression(_, _, false, _) if functionsWithDistinct.length > 0 =>
+ agg.copy(mode = PartialMerge)
+ case agg =>
+ agg.copy(mode = Partial)
+ }
+    val partialAggregateAttributes =
+      partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+    val partialResultExpressions =
+      groupingAttributes ++
+        partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
-    val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
-    val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute)
- SortAggregateExec(
- requiredChildDistributionExpressions = Some(groupingExpressions),
+ createAggregateExec(
+ requiredChildDistributionExpressions = None,
groupingExpressions = groupingExpressions,
- aggregateExpressions = completeAggregateExpressions,
- aggregateAttributes = completeAggregateAttributes,
- initialInputBufferOffset = 0,
- resultExpressions = resultExpressions,
- child = child
- ) :: Nil
+ aggregateExpressions = partialAggregateExpressions,
+ aggregateAttributes = partialAggregateAttributes,
+ initialInputBufferOffset = if (functionsWithDistinct.length > 0) {
+        groupingExpressions.length + functionsWithDistinct.head.aggregateFunction.children.length
+ } else {
+ 0
+ },
+ resultExpressions = partialResultExpressions,
+ child = child)
}
- private def createAggregate(
+  private def updateMergeAggregateMode(aggregateExpressions: Seq[AggregateExpression]) = {
+ def updateMode(mode: AggregateMode) = mode match {
+ case Partial => PartialMerge
+ case Complete => Final
+ case mode => mode
+ }
+ aggregateExpressions.map(e => e.copy(mode = updateMode(e.mode)))
+ }
+
+ /**
+   * Builds new merge and map-side [[AggregateExec]]s from an input aggregate operator.
+   * If an aggregation needs a shuffle for satisfying its own distribution and supports partial
+   * aggregations, a map-side aggregation is appended before the shuffle in
+   * [[org.apache.spark.sql.execution.exchange.EnsureRequirements]].
+   */
+  def createMapMergeAggregatePair(operator: SparkPlan): (SparkPlan, SparkPlan) = operator match {
+ case agg: AggregateExec =>
+ val mapSideAgg = createPartialAggregateExec(
+ agg.groupingExpressions, agg.aggregateExpressions, agg.child)
+ val mergeAgg = createAggregateExec(
+        requiredChildDistributionExpressions = agg.requiredChildDistributionExpressions,
+        groupingExpressions = agg.groupingExpressions.map(_.toAttribute),
+        aggregateExpressions = updateMergeAggregateMode(agg.aggregateExpressions),
+ aggregateAttributes = agg.aggregateAttributes,
+ initialInputBufferOffset = agg.groupingExpressions.length,
+ resultExpressions = agg.resultExpressions,
+ child = mapSideAgg
+ )
+
+ (mergeAgg, mapSideAgg)
+ }
+
+ private def createAggregateExec(
requiredChildDistributionExpressions: Option[Seq[Expression]] = None,
groupingExpressions: Seq[NamedExpression] = Nil,
aggregateExpressions: Seq[AggregateExpression] = Nil,
@@ -55,7 +118,8 @@ object AggUtils {
resultExpressions: Seq[NamedExpression] = Nil,
child: SparkPlan): SparkPlan = {
val useHash = HashAggregateExec.supportsAggregate(
- aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
+ aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) &&
+ supportPartialAggregate(aggregateExpressions)
if (useHash) {
HashAggregateExec(
        requiredChildDistributionExpressions = requiredChildDistributionExpressions,
@@ -82,43 +146,21 @@ object AggUtils {
aggregateExpressions: Seq[AggregateExpression],
resultExpressions: Seq[NamedExpression],
child: SparkPlan): Seq[SparkPlan] = {
- // Check if we can use HashAggregate.
-
- // 1. Create an Aggregate Operator for partial aggregations.
-
val groupingAttributes = groupingExpressions.map(_.toAttribute)
-    val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial))
-    val partialAggregateAttributes =
-      partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-    val partialResultExpressions =
-      groupingAttributes ++
-        partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
-
- val partialAggregate = createAggregate(
- requiredChildDistributionExpressions = None,
- groupingExpressions = groupingExpressions,
- aggregateExpressions = partialAggregateExpressions,
- aggregateAttributes = partialAggregateAttributes,
- initialInputBufferOffset = 0,
- resultExpressions = partialResultExpressions,
- child = child)
-
- // 2. Create an Aggregate Operator for final aggregations.
-    val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
-    // The attributes of the final aggregation buffer, which is presented as input to the result
-    // projection:
-    val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
-
- val finalAggregate = createAggregate(
- requiredChildDistributionExpressions = Some(groupingAttributes),
- groupingExpressions = groupingAttributes,
- aggregateExpressions = finalAggregateExpressions,
- aggregateAttributes = finalAggregateAttributes,
- initialInputBufferOffset = groupingExpressions.length,
- resultExpressions = resultExpressions,
- child = partialAggregate)
+    val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
+    val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute)
+ val supportPartial = supportPartialAggregate(aggregateExpressions)
- finalAggregate :: Nil
+ createAggregateExec(
+ requiredChildDistributionExpressions =
+ Some(if (supportPartial) groupingAttributes else groupingExpressions),
+ groupingExpressions = groupingExpressions,
+ aggregateExpressions = completeAggregateExpressions,
+ aggregateAttributes = completeAggregateAttributes,
+ initialInputBufferOffset = 0,
+ resultExpressions = resultExpressions,
+ child = child
+ ) :: Nil
}
def planAggregateWithOneDistinct(
@@ -141,39 +183,23 @@ object AggUtils {
val distinctAttributes = namedDistinctExpressions.map(_.toAttribute)
val groupingAttributes = groupingExpressions.map(_.toAttribute)
- // 1. Create an Aggregate Operator for partial aggregations.
+ // 1. Create an Aggregate Operator for non-distinct aggregations.
val partialAggregate: SparkPlan = {
      val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
-      // We will group by the original grouping expression, plus an additional expression for the
-      // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
-      // expressions will be [key, value].
-      createAggregate(
-        groupingExpressions = groupingExpressions ++ namedDistinctExpressions,
-        aggregateExpressions = aggregateExpressions,
-        aggregateAttributes = aggregateAttributes,
-        resultExpressions = groupingAttributes ++ distinctAttributes ++
-          aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
- child = child)
- }
-
- // 2. Create an Aggregate Operator for partial merge aggregations.
- val partialMergeAggregate: SparkPlan = {
-      val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
- val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
- createAggregate(
+ createAggregateExec(
requiredChildDistributionExpressions =
Some(groupingAttributes ++ distinctAttributes),
- groupingExpressions = groupingAttributes ++ distinctAttributes,
+ groupingExpressions = groupingExpressions ++ namedDistinctExpressions,
aggregateExpressions = aggregateExpressions,
aggregateAttributes = aggregateAttributes,
        initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length,
resultExpressions = groupingAttributes ++ distinctAttributes ++
aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
- child = partialAggregate)
+ child = child)
}
- // 3. Create an Aggregate operator for partial aggregation (for distinct)
+ // 2. Create an Aggregate Operator for the final aggregation.
    val distinctColumnAttributeLookup = distinctExpressions.zip(distinctAttributes).toMap
val rewrittenDistinctFunctions = functionsWithDistinct.map {
// Children of an AggregateFunction with DISTINCT keyword has already
@@ -183,38 +209,6 @@ object AggUtils {
aggregateFunction.transformDown(distinctColumnAttributeLookup)
.asInstanceOf[AggregateFunction]
}
-
- val partialDistinctAggregate: SparkPlan = {
-      val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
-      // The attributes of the final aggregation buffer, which is presented as input to the result
-      // projection:
-      val mergeAggregateAttributes = mergeAggregateExpressions.map(_.resultAttribute)
-      val (distinctAggregateExpressions, distinctAggregateAttributes) =
-        rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
-          // We rewrite the aggregate function to a non-distinct aggregation because
-          // its input will have distinct arguments.
-          // We just keep the isDistinct setting to true, so when users look at the query plan,
-          // they still can see distinct aggregations.
-          val expr = AggregateExpression(func, Partial, isDistinct = true)
-          // Use original AggregationFunction to lookup attributes, which is used to build
-          // aggregateFunctionToAttribute
-          val attr = functionsWithDistinct(i).resultAttribute
-          (expr, attr)
-        }.unzip
-
-      val partialAggregateResult = groupingAttributes ++
-        mergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) ++
-        distinctAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
-      createAggregate(
-        groupingExpressions = groupingAttributes,
-        aggregateExpressions = mergeAggregateExpressions ++ distinctAggregateExpressions,
-        aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes,
-        initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length,
-        resultExpressions = partialAggregateResult,
-        child = partialMergeAggregate)
- }
-
- // 4. Create an Aggregate Operator for the final aggregation.
val finalAndCompleteAggregate: SparkPlan = {
      val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
      // The attributes of the final aggregation buffer, which is presented as input to the result
@@ -225,23 +219,23 @@ object AggUtils {
rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
          // We rewrite the aggregate function to a non-distinct aggregation because
          // its input will have distinct arguments.
-          // We just keep the isDistinct setting to true, so when users look at the query plan,
-          // they still can see distinct aggregations.
-          val expr = AggregateExpression(func, Final, isDistinct = true)
+          // We keep the isDistinct setting to true because this flag is used to generate partial
+          // aggregations and it is easy to see aggregation types in the query plan.
+          val expr = AggregateExpression(func, Complete, isDistinct = true)
          // Use original AggregationFunction to lookup attributes, which is used to build
          // aggregateFunctionToAttribute
val attr = functionsWithDistinct(i).resultAttribute
(expr, attr)
- }.unzip
+ }.unzip
- createAggregate(
+ createAggregateExec(
requiredChildDistributionExpressions = Some(groupingAttributes),
groupingExpressions = groupingAttributes,
        aggregateExpressions = finalAggregateExpressions ++ distinctAggregateExpressions,
        aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes,
initialInputBufferOffset = groupingAttributes.length,
resultExpressions = resultExpressions,
- child = partialDistinctAggregate)
+ child = partialAggregate)
}
finalAndCompleteAggregate :: Nil
@@ -249,13 +243,14 @@ object AggUtils {
/**
* Plans a streaming aggregation using the following progression:
- * - Partial Aggregation
- * - Shuffle
- * - Partial Merge (now there is at most 1 tuple per group)
+ * - Partial Aggregation (now there is at most 1 tuple per group)
   * - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous)
   * - PartialMerge (now there is at most 1 tuple per group)
   * - StateStoreSave (saves the tuple for the next batch)
   * - Complete (output the current result of the aggregation)
+  *
+  * If the first aggregation needs a shuffle to satisfy its distribution, a map-side partial
+  * aggregation and a shuffle are added in `EnsureRequirements`.
*/
def planStreamingAggregation(
groupingExpressions: Seq[NamedExpression],
@@ -268,39 +263,24 @@ object AggUtils {
val partialAggregate: SparkPlan = {
      val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
-      // We will group by the original grouping expression, plus an additional expression for the
-      // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
-      // expressions will be [key, value].
-      createAggregate(
-        groupingExpressions = groupingExpressions,
-        aggregateExpressions = aggregateExpressions,
-        aggregateAttributes = aggregateAttributes,
-        resultExpressions = groupingAttributes ++
-          aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
- child = child)
- }
-
- val partialMerged1: SparkPlan = {
-      val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
- val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
- createAggregate(
+ createAggregateExec(
requiredChildDistributionExpressions =
Some(groupingAttributes),
- groupingExpressions = groupingAttributes,
+ groupingExpressions = groupingExpressions,
aggregateExpressions = aggregateExpressions,
aggregateAttributes = aggregateAttributes,
initialInputBufferOffset = groupingAttributes.length,
resultExpressions = groupingAttributes ++
aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
- child = partialAggregate)
+ child = child)
}
-    val restored = StateStoreRestoreExec(groupingAttributes, None, partialMerged1)
+    val restored = StateStoreRestoreExec(groupingAttributes, None, partialAggregate)
- val partialMerged2: SparkPlan = {
+ val partialMerged: SparkPlan = {
      val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
- createAggregate(
+ createAggregateExec(
requiredChildDistributionExpressions =
Some(groupingAttributes),
groupingExpressions = groupingAttributes,
@@ -314,7 +294,7 @@ object AggUtils {
    // Note: stateId and returnAllStates are filled in later with preparation rules
    // in IncrementalExecution.
    val saved = StateStoreSaveExec(
-      groupingAttributes, stateId = None, returnAllStates = None, partialMerged2)
+      groupingAttributes, stateId = None, returnAllStates = None, partialMerged)
val finalAndCompleteAggregate: SparkPlan = {
      val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
@@ -322,7 +302,7 @@ object AggUtils {
// projection:
      val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
- createAggregate(
+ createAggregateExec(
requiredChildDistributionExpressions = Some(groupingAttributes),
groupingExpressions = groupingAttributes,
aggregateExpressions = finalAggregateExpressions,
http://git-wip-us.apache.org/repos/asf/spark/blob/2b0cc4e0/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala
new file mode 100644
index 0000000..b88a8aa
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.UnaryExecNode
+
+/**
+ * A base class for aggregate implementation.
+ */
+abstract class AggregateExec extends UnaryExecNode {
+
+ def requiredChildDistributionExpressions: Option[Seq[Expression]]
+ def groupingExpressions: Seq[NamedExpression]
+ def aggregateExpressions: Seq[AggregateExpression]
+ def aggregateAttributes: Seq[Attribute]
+ def initialInputBufferOffset: Int
+ def resultExpressions: Seq[NamedExpression]
+
+ protected[this] val aggregateBufferAttributes = {
+ aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+ }
+
+ override def producedAttributes: AttributeSet =
+    AttributeSet(aggregateAttributes) ++
+    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+    AttributeSet(aggregateBufferAttributes)
+
+ override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+ override def requiredChildDistribution: List[Distribution] = {
+ requiredChildDistributionExpressions match {
+ case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+ case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
+ case None => UnspecifiedDistribution :: Nil
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/2b0cc4e0/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index bd7efa6..525c7e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
@@ -42,11 +41,7 @@ case class HashAggregateExec(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
- extends UnaryExecNode with CodegenSupport {
-
- private[this] val aggregateBufferAttributes = {
- aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
- }
+ extends AggregateExec with CodegenSupport {
require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
@@ -60,21 +55,6 @@ case class HashAggregateExec(
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
"aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"))
- override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
- override def producedAttributes: AttributeSet =
-    AttributeSet(aggregateAttributes) ++
-    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
-    AttributeSet(aggregateBufferAttributes)
-
- override def requiredChildDistribution: List[Distribution] = {
- requiredChildDistributionExpressions match {
- case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
- case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
- case None => UnspecifiedDistribution :: Nil
- }
- }
-
  // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
  // map and/or the sort-based aggregation once it has processed a given number of input rows.
private val testFallbackStartsAt: Option[(Int, Int)] = {
http://git-wip-us.apache.org/repos/asf/spark/blob/2b0cc4e0/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index 2a81a82..68f86fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -22,8 +22,7 @@ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.Utils
@@ -38,30 +37,11 @@ case class SortAggregateExec(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
- extends UnaryExecNode {
-
- private[this] val aggregateBufferAttributes = {
- aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
- }
-
- override def producedAttributes: AttributeSet =
-    AttributeSet(aggregateAttributes) ++
-    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
-    AttributeSet(aggregateBufferAttributes)
+ extends AggregateExec {
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output
rows"))
- override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
- override def requiredChildDistribution: List[Distribution] = {
- requiredChildDistributionExpressions match {
- case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
- case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
- case None => UnspecifiedDistribution :: Nil
- }
- }
-
override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
}
http://git-wip-us.apache.org/repos/asf/spark/blob/2b0cc4e0/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 446571a..951051c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.aggregate.AggUtils
+import org.apache.spark.sql.execution.aggregate.PartialAggregate
import org.apache.spark.sql.internal.SQLConf
/**
@@ -151,18 +153,30 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
    val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
    val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
- var children: Seq[SparkPlan] = operator.children
- assert(requiredChildDistributions.length == children.length)
- assert(requiredChildOrderings.length == children.length)
+ assert(requiredChildDistributions.length == operator.children.length)
+ assert(requiredChildOrderings.length == operator.children.length)
-    // Ensure that the operator's children satisfy their output distribution requirements:
-    children = children.zip(requiredChildDistributions).map {
-      case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
-        child
-      case (child, BroadcastDistribution(mode)) =>
-        BroadcastExchangeExec(mode, child)
-      case (child, distribution) =>
-        ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
+    def createShuffleExchange(dist: Distribution, child: SparkPlan) =
+      ShuffleExchange(createPartitioning(dist, defaultNumPreShufflePartitions), child)
+
+ var (parent, children) = operator match {
+      case PartialAggregate(childDist) if !operator.outputPartitioning.satisfies(childDist) =>
+        // If an aggregation needs a shuffle and support partial aggregations, a map-side partial
+        // aggregation and a shuffle are added as children.
+        val (mergeAgg, mapSideAgg) = AggUtils.createMapMergeAggregatePair(operator)
+        (mergeAgg, createShuffleExchange(requiredChildDistributions.head, mapSideAgg) :: Nil)
+      case _ =>
+        // Ensure that the operator's children satisfy their output distribution requirements:
+        val childrenWithDist = operator.children.zip(requiredChildDistributions)
+        val newChildren = childrenWithDist.map {
+          case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
+            child
+          case (child, BroadcastDistribution(mode)) =>
+            BroadcastExchangeExec(mode, child)
+          case (child, distribution) =>
+            createShuffleExchange(distribution, child)
+ }
+ (operator, newChildren)
}
    // If the operator has multiple children and specifies child output distributions (e.g. join),
@@ -246,7 +260,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
}
}
- operator.withNewChildren(children)
+ parent.withNewChildren(children)
}
def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
http://git-wip-us.apache.org/repos/asf/spark/blob/2b0cc4e0/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 499f318..cd48577 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1248,17 +1248,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
/**
- * Verifies that there is no Exchange between the Aggregations for `df`
+ * Verifies that there is a single Aggregation for `df`
*/
- private def verifyNonExchangingAgg(df: DataFrame) = {
+ private def verifyNonExchangingSingleAgg(df: DataFrame) = {
var atFirstAgg: Boolean = false
df.queryExecution.executedPlan.foreach {
case agg: HashAggregateExec =>
- atFirstAgg = !atFirstAgg
- case _ =>
if (atFirstAgg) {
- fail("Should not have operators between the two aggregations")
+ fail("Should not have back to back Aggregates")
}
+ atFirstAgg = true
+ case _ =>
}
}
@@ -1292,9 +1292,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
    // Group by the column we are distributed by. This should generate a plan with no exchange
    // between the aggregates
val df3 = testData.repartition($"key").groupBy("key").count()
- verifyNonExchangingAgg(df3)
- verifyNonExchangingAgg(testData.repartition($"key", $"value")
+ verifyNonExchangingSingleAgg(df3)
+ verifyNonExchangingSingleAgg(testData.repartition($"key", $"value")
.groupBy("key", "value").count())
+ verifyNonExchangingSingleAgg(testData.repartition($"key").groupBy("key",
"value").count())
// Grouping by just the first distributeBy expr, need to exchange.
verifyExchangingAgg(testData.repartition($"key", $"value")
http://git-wip-us.apache.org/repos/asf/spark/blob/2b0cc4e0/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 13490c3..436ff59 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, Row}
+import org.apache.spark.sql.{execution, DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.Inner
@@ -37,36 +37,65 @@ class PlannerSuite extends SharedSQLContext {
setupTestData()
-  private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
+  private def testPartialAggregationPlan(query: LogicalPlan): Seq[SparkPlan] = {
     val planner = spark.sessionState.planner
     import planner._
-    val plannedOption = Aggregation(query).headOption
-    val planned =
-      plannedOption.getOrElse(
-        fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
-    val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
-
-    // For the new aggregation code path, there will be four aggregate operator for
-    // distinct aggregations.
-    assert(
-      aggregations.size == 2 || aggregations.size == 4,
-      s"The plan of query $query does not have partial aggregations.")
+    val ensureRequirements = EnsureRequirements(spark.sessionState.conf)
+    val planned = Aggregation(query).headOption.map(ensureRequirements(_))
+      .getOrElse(fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
+    planned.collect { case n if n.nodeName contains "Aggregate" => n }
}
test("count is partially aggregated") {
    val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
- testPartialAggregationPlan(query)
+ assert(testPartialAggregationPlan(query).size == 2,
+ s"The plan of query $query does not have partial aggregations.")
}
test("count distinct is partially aggregated") {
    val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
testPartialAggregationPlan(query)
+    // For the new aggregation code path, there will be four aggregate operator for distinct
+    // aggregations.
+ assert(testPartialAggregationPlan(query).size == 4,
+ s"The plan of query $query does not have partial aggregations.")
}
test("mixed aggregates are partially aggregated") {
    val query =
      testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
-    testPartialAggregationPlan(query)
+    // For the new aggregation code path, there will be four aggregate operator for distinct
+    // aggregations.
+ assert(testPartialAggregationPlan(query).size == 4,
+ s"The plan of query $query does not have partial aggregations.")
+ }
+
+ test("non-partial aggregation for aggregates") {
+ withTempView("testNonPartialAggregation") {
+ val schema = StructType(StructField(s"value", IntegerType, true) :: Nil)
+ val row = Row.fromSeq(Seq.fill(1)(null))
+ val rowRDD = sparkContext.parallelize(row :: Nil)
+ spark.createDataFrame(rowRDD, schema).repartition($"value")
+ .createOrReplaceTempView("testNonPartialAggregation")
+
+      val planned1 = sql("SELECT SUM(value) FROM testNonPartialAggregation GROUP BY value")
+        .queryExecution.executedPlan
+
+      // If input data are already partitioned and the same columns are used in grouping keys and
+      // aggregation values, no partial aggregation exist in query plans.
+      val aggOps1 = planned1.collect { case n if n.nodeName contains "Aggregate" => n }
+      assert(aggOps1.size == 1, s"The plan $planned1 has partial aggregations.")
+
+ val planned2 = sql(
+ """
+ |SELECT t.value, SUM(DISTINCT t.value)
+ |FROM (SELECT * FROM testNonPartialAggregation ORDER BY value) t
+ |GROUP BY t.value
+ """.stripMargin).queryExecution.executedPlan
+
+      val aggOps2 = planned2.collect { case n if n.nodeName contains "Aggregate" => n }
+      assert(aggOps2.size == 1, s"The plan $planned2 has partial aggregations.")
+ }
}
test("sizeInBytes estimation of limit operator for broadcast hash join
optimization") {