[SPARK-4233] [SPARK-4367] [SPARK-3947] [SPARK-3056] [SQL] Aggregation Improvement
This is the first PR for the aggregation improvement, which is tracked by https://issues.apache.org/jira/browse/SPARK-4366 (umbrella JIRA). This PR contains work for its subtasks, SPARK-3056, SPARK-3947, SPARK-4233, and SPARK-4367.

This PR introduces a new code path for evaluating aggregate functions. This code path is guarded by `spark.sql.useAggregate2` and by default the value of this flag is true.

This new code path contains:
* A new aggregate function interface (`AggregateFunction2`) and 7 built-in aggregate functions based on this new interface (`AVG`, `COUNT`, `FIRST`, `LAST`, `MAX`, `MIN`, `SUM`).
* A UDAF interface (`UserDefinedAggregateFunction`) based on the new code path and two example UDAFs (`MyDoubleAvg` and `MyDoubleSum`).
* A sort-based aggregate operator (`Aggregate2Sort`) for the new aggregate function interface.
* A sort-based aggregate operator (`FinalAndCompleteAggregate2Sort`) for distinct aggregations (for distinct aggregations the query plan will use `Aggregate2Sort` and `FinalAndCompleteAggregate2Sort` together).

With this change, when `spark.sql.useAggregate2` is `true`, the flow of compiling an aggregation query is:
1. Our analyzer looks up functions and returns aggregate functions built based on the old aggregate function interface.
2. When our planner is compiling the physical plan, it tries to convert all aggregate functions to the ones built based on the new interface. The planner will fall back to the old code path if any of the following conditions is true:
 * Code-gen is disabled.
 * There is any function that cannot be converted (right now, Hive UDAFs).
 * The schema of grouping expressions contains any complex data type.
 * There are multiple distinct columns. Right now, the new code path handles a single distinct column in the query (you can have multiple aggregate functions using that distinct column).
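The fallback decision in step 2 can be read as a single predicate over the query. The sketch below is a simplified model under stated assumptions, not the planner code in this PR: `AggCall`, `AggQuery`, and `usesNewPath` are hypothetical names that exist only for this illustration, and a DISTINCT argument is modeled as a single column name.

```scala
// Illustrative model only: these types are hypothetical and are not Spark classes.
object FallbackSketch {
  // One aggregate call in a query. `distinctColumn` is set when DISTINCT is used.
  final case class AggCall(
      name: String,
      convertible: Boolean,
      distinctColumn: Option[String] = None)

  final case class AggQuery(
      codegenEnabled: Boolean,
      groupingHasComplexType: Boolean,
      calls: Seq[AggCall])

  // True when none of the four fallback conditions listed above holds.
  def usesNewPath(q: AggQuery): Boolean = {
    // Several calls may share one DISTINCT column; only the number of
    // different columns matters.
    val distinctColumns = q.calls.flatMap(_.distinctColumn).distinct
    q.codegenEnabled &&
      q.calls.forall(_.convertible) &&
      !q.groupingHasComplexType &&
      distinctColumns.size <= 1
  }
}
```

Under this model, a query with `SUM(DISTINCT a)`, `COUNT(DISTINCT a)`, and `AVG(b)` stays on the new path, while adding `COUNT(DISTINCT c)` or a Hive UDAF makes the whole query fall back to the old path.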

For a query having an aggregate function with DISTINCT and regular aggregate functions, the generated plan will do partial aggregations for those regular aggregate functions.

Thanks chenghao-intel for his initial work on it.

Author: Yin Huai <[email protected]>
Author: Michael Armbrust <[email protected]>

Closes #7458 from yhuai/UDAF and squashes the following commits:

7865f5e [Yin Huai] Put the catalyst expression in the comment of the generated code for it.
b04d6c8 [Yin Huai] Remove unnecessary change.
f1d5901 [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
35b0520 [Yin Huai] Use semanticEquals to replace grouping expressions in the output of the aggregate operator.
3b43b24 [Yin Huai] bug fix.
00eb298 [Yin Huai] Make it compile.
a3ca551 [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
e0afca3 [Yin Huai] Gracefully fallback to old aggregation code path.
8a8ac4a [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
88c7d4d [Yin Huai] Enable spark.sql.useAggregate2 by default for testing purpose.
dc96fd1 [Yin Huai] Many updates:
85c9c4b [Yin Huai] newline.
43de3de [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
c3614d7 [Yin Huai] Handle single distinct column.
68b8ee9 [Yin Huai] Support single distinct column set. WIP
3013579 [Yin Huai] Format.
d678aee [Yin Huai] Remove AggregateExpressionSuite.scala since our built-in aggregate functions will be based on AlgebraicAggregate and we need to have another way to test it.
e243ca6 [Yin Huai] Add aggregation iterators.
a101960 [Yin Huai] Change MyJavaUDAF to MyDoubleSum.
594cdf5 [Yin Huai] Change existing AggregateExpression to AggregateExpression1 and add an AggregateExpression as the common interface for both AggregateExpression1 and AggregateExpression2.
380880f [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
0a827b3 [Yin Huai] Add comments and doc. Move some classes to the right places.
a19fea6 [Yin Huai] Add UDAF interface.
262d4c4 [Yin Huai] Make it compile.
b2e358e [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
6edb5ac [Yin Huai] Format update.
70b169c [Yin Huai] Remove groupOrdering.
4721936 [Yin Huai] Add CheckAggregateFunction to extendedCheckRules.
d821a34 [Yin Huai] Cleanup.
32aea9c [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
5b46d41 [Yin Huai] Bug fix.
aff9534 [Yin Huai] Make Aggregate2Sort work with both algebraic AggregateFunctions and non-algebraic AggregateFunctions.
2857b55 [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
4435f20 [Yin Huai] Add ConvertAggregateFunction to HiveContext's analyzer.
1b490ed [Michael Armbrust] make hive test
8cfa6a9 [Michael Armbrust] add test
1b0bb3f [Yin Huai] Do not bind references in AlgebraicAggregate and use code gen for all places.
072209f [Yin Huai] Bug fix: Handle expressions in grouping columns that are not attribute references.
f7d9e54 [Michael Armbrust] Merge remote-tracking branch 'apache/master' into UDAF
39ee975 [Yin Huai] Code cleanup: Remove unnecesary AttributeReferences.
b7720ba [Yin Huai] Add an analysis rule to convert aggregate function to the new version.
5c00f3f [Michael Armbrust] First draft of codegen
6bbc6ba [Michael Armbrust] now with correct answers\!
f7996d0 [Michael Armbrust] Add AlgebraicAggregate
dded1c5 [Yin Huai] wip

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c03299a1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c03299a1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c03299a1

Branch: refs/heads/master
Commit: c03299a18b4e076cabb4b7833a1e7632c5c0dabe
Parents: f4785f5
Author: Yin Huai <[email protected]>
Authored: Tue Jul 21 23:26:11 2015 -0700
Committer: Reynold Xin <[email protected]>
Committed: Tue Jul 21 23:26:11 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/SqlParser.scala | 3 +-
 .../spark/sql/catalyst/analysis/Analyzer.scala | 24 +-
 .../sql/catalyst/analysis/CheckAnalysis.scala | 1 +
 .../sql/catalyst/analysis/unresolved.scala | 5 +-
 .../catalyst/expressions/BoundAttribute.scala | 2 +-
 .../sql/catalyst/expressions/Expression.scala | 3 +-
 .../expressions/aggregate/functions.scala | 292 ++++++++
 .../expressions/aggregate/interfaces.scala | 206 +++++
 .../sql/catalyst/expressions/aggregates.scala | 100 +--
 .../codegen/GenerateMutableProjection.scala | 21 +-
 .../spark/sql/catalyst/planning/patterns.scala | 4 +-
 .../catalyst/plans/logical/basicOperators.scala | 1 +
 .../scala/org/apache/spark/sql/SQLConf.scala | 5 +
 .../scala/org/apache/spark/sql/SQLContext.scala | 4 +
 .../org/apache/spark/sql/UDAFRegistration.scala | 35 +
 .../apache/spark/sql/execution/Aggregate.scala | 12 +-
 .../apache/spark/sql/execution/Exchange.scala | 11 +-
 .../sql/execution/GeneratedAggregate.scala | 2 +-
 .../spark/sql/execution/SparkStrategies.scala | 100 ++-
 .../aggregate/aggregateOperators.scala | 173 +++++
 .../aggregate/sortBasedIterators.scala | 749 +++++++++++++++++++
 .../spark/sql/execution/aggregate/utils.scala | 364 +++++++++
 .../spark/sql/expressions/aggregate/udaf.scala | 280 +++++++
 .../scala/org/apache/spark/sql/functions.scala | 4 +-
.../org/apache/spark/sql/SQLQuerySuite.scala | 4 +- .../spark/sql/execution/PlannerSuite.scala | 26 +- .../HiveWindowFunctionQuerySuite.scala | 1 + .../execution/SortMergeCompatibilitySuite.scala | 7 + .../org/apache/spark/sql/hive/HiveContext.scala | 1 + .../org/apache/spark/sql/hive/HiveQl.scala | 7 +- .../org/apache/spark/sql/hive/hiveUDFs.scala | 8 +- .../spark/sql/hive/aggregate/MyDoubleAvg.java | 107 +++ .../spark/sql/hive/aggregate/MyDoubleSum.java | 100 +++ ...udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada | 1 + ...udf_unhex-1-11eb3cc5216d5446f4165007203acc47 | 1 + ...udf_unhex-2-a660886085b8651852b9b77934848ae4 | 14 + ...udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e | 1 + ...udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 | 1 + .../hive/execution/AggregationQuerySuite.scala | 507 +++++++++++++ 39 files changed, 3087 insertions(+), 100 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index d4ef04c..c04bd6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -266,11 +266,12 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } } | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ - { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) } + { case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) } | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs => lexical.normalizeKeyword(udfName) match { case "sum" => SumDistinct(exprs.head) case "count" => CountDistinct(exprs) + 
case name => UnresolvedFunction(name, exprs, isDistinct = true) case _ => throw new AnalysisException(s"function $udfName does not support DISTINCT") } } http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e58f3f6..8cadbc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2} import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -277,7 +278,7 @@ class Analyzer( Project( projectList.flatMap { case s: Star => s.expand(child.output, resolver) - case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) => + case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child.output, resolver) case o => o :: Nil @@ -517,9 +518,26 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { - case u @ UnresolvedFunction(name, children) => + case u @ UnresolvedFunction(name, children, isDistinct) => withPosition(u) { - registry.lookupFunction(name, children) + registry.lookupFunction(name, children) match { + // We get an aggregate function built based on AggregateFunction2 interface. 
+ // So, we wrap it in AggregateExpression2. + case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct) + // Currently, our old aggregate function interface supports SUM(DISTINCT ...) + // and COUTN(DISTINCT ...). + case sumDistinct: SumDistinct => sumDistinct + case countDistinct: CountDistinct => countDistinct + // DISTINCT is not meaningful with Max and Min. + case max: Max if isDistinct => max + case min: Min if isDistinct => min + // For other aggregate functions, DISTINCT keyword is not supported for now. + // Once we converted to the new code path, we will allow using DISTINCT keyword. + case other if isDistinct => + failAnalysis(s"$name does not support DISTINCT keyword.") + // If it does not have DISTINCT keyword, we will return it as is. + case other => other + } } } } http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index c7f9713..c203fce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala ---------------------------------------------------------------------- diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 0daee19..03da45b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -73,7 +73,10 @@ object UnresolvedAttribute { def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name)) } -case class UnresolvedFunction(name: String, children: Seq[Expression]) +case class UnresolvedFunction( + name: String, + children: Seq[Expression], + isDistinct: Boolean) extends Expression with Unevaluable { override def dataType: DataType = throw new UnresolvedException(this, "dataType") http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index b09aea0..b10a3c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types._ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) extends LeafExpression with NamedExpression { - override def toString: String = s"input[$ordinal]" + override def toString: String = s"input[$ordinal, $dataType]" override def eval(input: InternalRow): Any = input(ordinal) http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala 
---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index aada252..29ae47e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -96,7 +96,8 @@ abstract class Expression extends TreeNode[Expression] { val primitive = ctx.freshName("primitive") val ve = GeneratedExpressionCode("", isNull, primitive) ve.code = genCode(ctx, ve) - ve + // Add `this` in the comment. + ve.copy(s"/* $this */\n" + ve.code) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala new file mode 100644 index 0000000..b924af4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -0,0 +1,292 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +case class Average(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = resultType + + // Expected input data type. + // TODO: Once we remove the old code path, we can use our analyzer to cast NullType + // to the default data type of the NumericType. 
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + + private val resultType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 4, scale + 4) + case DecimalType.Unlimited => DecimalType.Unlimited + case _ => DoubleType + } + + private val sumDataType = child.dataType match { + case _ @ DecimalType() => DecimalType.Unlimited + case _ => DoubleType + } + + private val currentSum = AttributeReference("currentSum", sumDataType)() + private val currentCount = AttributeReference("currentCount", LongType)() + + override val bufferAttributes = currentSum :: currentCount :: Nil + + override val initialValues = Seq( + /* currentSum = */ Cast(Literal(0), sumDataType), + /* currentCount = */ Literal(0L) + ) + + override val updateExpressions = Seq( + /* currentSum = */ + Add( + currentSum, + Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), + /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L) + ) + + override val mergeExpressions = Seq( + /* currentSum = */ currentSum.left + currentSum.right, + /* currentCount = */ currentCount.left + currentCount.right + ) + + // If all input are nulls, currentCount will be 0 and we will get null after the division. + override val evaluateExpression = Cast(currentSum, resultType) / Cast(currentCount, resultType) +} + +case class Count(child: Expression) extends AlgebraicAggregate { + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = false + + // Return data type. + override def dataType: DataType = LongType + + // Expected input data type. 
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val currentCount = AttributeReference("currentCount", LongType)() + + override val bufferAttributes = currentCount :: Nil + + override val initialValues = Seq( + /* currentCount = */ Literal(0L) + ) + + override val updateExpressions = Seq( + /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L) + ) + + override val mergeExpressions = Seq( + /* currentCount = */ currentCount.left + currentCount.right + ) + + override val evaluateExpression = Cast(currentCount, LongType) +} + +case class First(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // First is not a deterministic function. + override def deterministic: Boolean = false + + // Return data type. + override def dataType: DataType = child.dataType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val first = AttributeReference("first", child.dataType)() + + override val bufferAttributes = first :: Nil + + override val initialValues = Seq( + /* first = */ Literal.create(null, child.dataType) + ) + + override val updateExpressions = Seq( + /* first = */ If(IsNull(first), child, first) + ) + + override val mergeExpressions = Seq( + /* first = */ If(IsNull(first.left), first.right, first.left) + ) + + override val evaluateExpression = first +} + +case class Last(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Last is not a deterministic function. + override def deterministic: Boolean = false + + // Return data type. + override def dataType: DataType = child.dataType + + // Expected input data type. 
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val last = AttributeReference("last", child.dataType)() + + override val bufferAttributes = last :: Nil + + override val initialValues = Seq( + /* last = */ Literal.create(null, child.dataType) + ) + + override val updateExpressions = Seq( + /* last = */ If(IsNull(child), last, child) + ) + + override val mergeExpressions = Seq( + /* last = */ If(IsNull(last.right), last.left, last.right) + ) + + override val evaluateExpression = last +} + +case class Max(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = child.dataType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val max = AttributeReference("max", child.dataType)() + + override val bufferAttributes = max :: Nil + + override val initialValues = Seq( + /* max = */ Literal.create(null, child.dataType) + ) + + override val updateExpressions = Seq( + /* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child)))) + ) + + override val mergeExpressions = { + val greatest = Greatest(Seq(max.left, max.right)) + Seq( + /* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest)) + ) + } + + override val evaluateExpression = max +} + +case class Min(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = child.dataType + + // Expected input data type. 
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val min = AttributeReference("min", child.dataType)() + + override val bufferAttributes = min :: Nil + + override val initialValues = Seq( + /* min = */ Literal.create(null, child.dataType) + ) + + override val updateExpressions = Seq( + /* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child)))) + ) + + override val mergeExpressions = { + val least = Least(Seq(min.left, min.right)) + Seq( + /* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least)) + ) + } + + override val evaluateExpression = min +} + +case class Sum(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = resultType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType)) + + private val resultType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 4, scale + 4) + case DecimalType.Unlimited => DecimalType.Unlimited + case _ => child.dataType + } + + private val sumDataType = child.dataType match { + case _ @ DecimalType() => DecimalType.Unlimited + case _ => child.dataType + } + + private val currentSum = AttributeReference("currentSum", sumDataType)() + + private val zero = Cast(Literal(0), sumDataType) + + override val bufferAttributes = currentSum :: Nil + + override val initialValues = Seq( + /* currentSum = */ Literal.create(null, sumDataType) + ) + + override val updateExpressions = Seq( + /* currentSum = */ + Coalesce(Seq(Add(Coalesce(Seq(currentSum, zero)), Cast(child, sumDataType)), currentSum)) + ) + + override val mergeExpressions = { + val add = Add(Coalesce(Seq(currentSum.left, zero)), Cast(currentSum.right, sumDataType)) + Seq( + /* currentSum = */ + Coalesce(Seq(add, 
currentSum.left)) + ) + } + + override val evaluateExpression = Cast(currentSum, resultType) +} http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala new file mode 100644 index 0000000..577ede7 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ + +/** The mode of an [[AggregateFunction1]]. 
*/ +private[sql] sealed trait AggregateMode + +/** + * An [[AggregateFunction1]] with [[Partial]] mode is used for partial aggregation. + * This function updates the given aggregation buffer with the original input of this + * function. When it has processed all input rows, the aggregation buffer is returned. + */ +private[sql] case object Partial extends AggregateMode + +/** + * An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers + * containing intermediate results for this function. + * This function updates the given aggregation buffer by merging multiple aggregation buffers. + * When it has processed all input rows, the aggregation buffer is returned. + */ +private[sql] case object PartialMerge extends AggregateMode + +/** + * An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers + * containing intermediate results for this function and the generate final result. + * This function updates the given aggregation buffer by merging multiple aggregation buffers. + * When it has processed all input rows, the final result of this function is returned. + */ +private[sql] case object Final extends AggregateMode + +/** + * An [[AggregateFunction2]] with [[Partial]] mode is used to evaluate this function directly + * from original input rows without any partial aggregation. + * This function updates the given aggregation buffer with the original input of this + * function. When it has processed all input rows, the final result of this function is returned. + */ +private[sql] case object Complete extends AggregateMode + +/** + * A place holder expressions used in code-gen, it does not change the corresponding value + * in the row. + */ +private[sql] case object NoOp extends Expression with Unevaluable { + override def nullable: Boolean = true + override def eval(input: InternalRow): Any = { + throw new TreeNodeException( + this, s"No function to evaluate expression. 
type: ${this.nodeName}") + } + override def dataType: DataType = NullType + override def children: Seq[Expression] = Nil +} + +/** + * A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field + * (`isDistinct`) indicating if DISTINCT keyword is specified for this function. + * @param aggregateFunction + * @param mode + * @param isDistinct + */ +private[sql] case class AggregateExpression2( + aggregateFunction: AggregateFunction2, + mode: AggregateMode, + isDistinct: Boolean) extends AggregateExpression { + + override def children: Seq[Expression] = aggregateFunction :: Nil + override def dataType: DataType = aggregateFunction.dataType + override def foldable: Boolean = false + override def nullable: Boolean = aggregateFunction.nullable + + override def references: AttributeSet = { + val childReferemces = mode match { + case Partial | Complete => aggregateFunction.references.toSeq + case PartialMerge | Final => aggregateFunction.bufferAttributes + } + + AttributeSet(childReferemces) + } + + override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)" +} + +abstract class AggregateFunction2 + extends Expression with ImplicitCastInputTypes { + + self: Product => + + /** An aggregate function is not foldable. */ + override def foldable: Boolean = false + + /** + * The offset of this function's buffer in the underlying buffer shared with other functions. + */ + var bufferOffset: Int = 0 + + /** The schema of the aggregation buffer. */ + def bufferSchema: StructType + + /** Attributes of fields in bufferSchema. */ + def bufferAttributes: Seq[AttributeReference] + + /** Clones bufferAttributes. */ + def cloneBufferAttributes: Seq[Attribute] + + /** + * Initializes its aggregation buffer located in `buffer`. + * It will use bufferOffset to find the starting point of + * its buffer in the given `buffer` shared with other functions. 
+ */ + def initialize(buffer: MutableRow): Unit + + /** + * Updates its aggregation buffer located in `buffer` based on the given `input`. + * It will use bufferOffset to find the starting point of its buffer in the given `buffer` + * shared with other functions. + */ + def update(buffer: MutableRow, input: InternalRow): Unit + + /** + * Updates its aggregation buffer located in `buffer1` by combining intermediate results + * in the current buffer and intermediate results from another buffer `buffer2`. + * It will use bufferOffset to find the starting point of its buffer in the given `buffer1` + * and `buffer2`. + */ + def merge(buffer1: MutableRow, buffer2: InternalRow): Unit + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") +} + +/** + * A helper class for aggregate functions that can be implemented in terms of catalyst expressions. + */ +abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable { + self: Product => + + val initialValues: Seq[Expression] + val updateExpressions: Seq[Expression] + val mergeExpressions: Seq[Expression] + val evaluateExpression: Expression + + override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance()) + + /** + * A helper class for representing an attribute used in merging two + * aggregation buffers. When merging two buffers, `bufferLeft` and `bufferRight`, + * we merge buffer values and then update bufferLeft. A [[RichAttribute]] + * of an [[AttributeReference]] `a` has two functions `left` and `right`, + * which represent `a` in `bufferLeft` and `bufferRight`, respectively. + * @param a + */ + implicit class RichAttribute(a: AttributeReference) { + /** Represents this attribute at the mutable buffer side. */ + def left: AttributeReference = a + + /** Represents this attribute at the input buffer side (the data value is read-only). 
*/ + def right: AttributeReference = cloneBufferAttributes(bufferAttributes.indexOf(a)) + } + + /** An AlgebraicAggregate's bufferSchema is derived from bufferAttributes. */ + override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes) + + override def initialize(buffer: MutableRow): Unit = { + var i = 0 + while (i < bufferAttributes.size) { + buffer(i + bufferOffset) = initialValues(i).eval() + i += 1 + } + } + + override def update(buffer: MutableRow, input: InternalRow): Unit = { + throw new UnsupportedOperationException( + "AlgebraicAggregate's update should not be called directly") + } + + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + throw new UnsupportedOperationException( + "AlgebraicAggregate's merge should not be called directly") + } + + override def eval(buffer: InternalRow): Any = { + throw new UnsupportedOperationException( + "AlgebraicAggregate's eval should not be called directly") + } +} + http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index d705a12..e07c920 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -27,7 +27,9 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet -trait AggregateExpression extends Expression with Unevaluable { +trait AggregateExpression extends Expression with Unevaluable + +trait AggregateExpression1 extends AggregateExpression { /** * Aggregate expressions should not be foldable. 
@@ -38,7 +40,7 @@ trait AggregateExpression extends Expression with Unevaluable { * Creates a new instance that can be used to compute this aggregate expression for a group * of input rows/ */ - def newInstance(): AggregateFunction + def newInstance(): AggregateFunction1 } /** @@ -54,10 +56,10 @@ case class SplitEvaluation( partialEvaluations: Seq[NamedExpression]) /** - * An [[AggregateExpression]] that can be partially computed without seeing all relevant tuples. + * An [[AggregateExpression1]] that can be partially computed without seeing all relevant tuples. * These partial evaluations can then be combined to compute the actual answer. */ -trait PartialAggregate extends AggregateExpression { +trait PartialAggregate1 extends AggregateExpression1 { /** * Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation. @@ -67,13 +69,13 @@ trait PartialAggregate extends AggregateExpression { /** * A specific implementation of an aggregate function. Used to wrap a generic - * [[AggregateExpression]] with an algorithm that will be used to compute one specific result. + * [[AggregateExpression1]] with an algorithm that will be used to compute one specific result. */ -abstract class AggregateFunction - extends LeafExpression with AggregateExpression with Serializable { +abstract class AggregateFunction1 + extends LeafExpression with AggregateExpression1 with Serializable { /** Base should return the generic aggregate expression that this function is computing */ - val base: AggregateExpression + val base: AggregateExpression1 override def nullable: Boolean = base.nullable override def dataType: DataType = base.dataType @@ -81,12 +83,12 @@ abstract class AggregateFunction def update(input: InternalRow): Unit // Do we really need this? 
- override def newInstance(): AggregateFunction = { + override def newInstance(): AggregateFunction1 = { makeCopy(productIterator.map { case a: AnyRef => a }.toArray) } } -case class Min(child: Expression) extends UnaryExpression with PartialAggregate { +case class Min(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = true override def dataType: DataType = child.dataType @@ -102,7 +104,7 @@ case class Min(child: Expression) extends UnaryExpression with PartialAggregate TypeUtils.checkForOrderingExpr(child.dataType, "function min") } -case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class MinFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType) @@ -119,7 +121,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr override def eval(input: InternalRow): Any = currentMin.value } -case class Max(child: Expression) extends UnaryExpression with PartialAggregate { +case class Max(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = true override def dataType: DataType = child.dataType @@ -135,7 +137,7 @@ case class Max(child: Expression) extends UnaryExpression with PartialAggregate TypeUtils.checkForOrderingExpr(child.dataType, "function max") } -case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class MaxFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. 
val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType) @@ -152,7 +154,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr override def eval(input: InternalRow): Any = currentMax.value } -case class Count(child: Expression) extends UnaryExpression with PartialAggregate { +case class Count(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = false override def dataType: LongType.type = LongType @@ -165,7 +167,7 @@ case class Count(child: Expression) extends UnaryExpression with PartialAggregat override def newInstance(): CountFunction = new CountFunction(child, this) } -case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class CountFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. var count: Long = _ @@ -180,7 +182,7 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag override def eval(input: InternalRow): Any = count } -case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate { +case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate1 { def this() = this(null) override def children: Seq[Expression] = expressions @@ -200,8 +202,8 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate case class CountDistinctFunction( @transient expr: Seq[Expression], - @transient base: AggregateExpression) - extends AggregateFunction { + @transient base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. 
@@ -220,7 +222,7 @@ case class CountDistinctFunction( override def eval(input: InternalRow): Any = seen.size.toLong } -case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression { +case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression1 { def this() = this(null) override def children: Seq[Expression] = expressions @@ -233,8 +235,8 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpress case class CollectHashSetFunction( @transient expr: Seq[Expression], - @transient base: AggregateExpression) - extends AggregateFunction { + @transient base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -255,7 +257,7 @@ case class CollectHashSetFunction( } } -case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression { +case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression1 { def this() = this(null) override def children: Seq[Expression] = inputSet :: Nil @@ -269,8 +271,8 @@ case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression case class CombineSetsAndCountFunction( @transient inputSet: Expression, - @transient base: AggregateExpression) - extends AggregateFunction { + @transient base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. 
@@ -305,7 +307,7 @@ private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] { } case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) - extends UnaryExpression with AggregateExpression { + extends UnaryExpression with AggregateExpression1 { override def nullable: Boolean = false override def dataType: DataType = HyperLogLogUDT @@ -317,9 +319,9 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) case class ApproxCountDistinctPartitionFunction( expr: Expression, - base: AggregateExpression, + base: AggregateExpression1, relativeSD: Double) - extends AggregateFunction { + extends AggregateFunction1 { def this() = this(null, null, 0) // Required for serialization. private val hyperLogLog = new HyperLogLog(relativeSD) @@ -335,7 +337,7 @@ case class ApproxCountDistinctPartitionFunction( } case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) - extends UnaryExpression with AggregateExpression { + extends UnaryExpression with AggregateExpression1 { override def nullable: Boolean = false override def dataType: LongType.type = LongType @@ -347,9 +349,9 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) case class ApproxCountDistinctMergeFunction( expr: Expression, - base: AggregateExpression, + base: AggregateExpression1, relativeSD: Double) - extends AggregateFunction { + extends AggregateFunction1 { def this() = this(null, null, 0) // Required for serialization. 
private val hyperLogLog = new HyperLogLog(relativeSD) @@ -363,7 +365,7 @@ case class ApproxCountDistinctMergeFunction( } case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) - extends UnaryExpression with PartialAggregate { + extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = false override def dataType: LongType.type = LongType @@ -381,7 +383,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this) } -case class Average(child: Expression) extends UnaryExpression with PartialAggregate { +case class Average(child: Expression) extends UnaryExpression with PartialAggregate1 { override def prettyName: String = "avg" @@ -427,8 +429,8 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg TypeUtils.checkForNumericExpr(child.dataType, "function average") } -case class AverageFunction(expr: Expression, base: AggregateExpression) - extends AggregateFunction { +case class AverageFunction(expr: Expression, base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -474,7 +476,7 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) } } -case class Sum(child: Expression) extends UnaryExpression with PartialAggregate { +case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = true @@ -509,7 +511,7 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate TypeUtils.checkForNumericExpr(child.dataType, "function sum") } -case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class SumFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. 
private val calcType = @@ -554,7 +556,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr * <-- null <-- no data * null <-- null <-- no data */ -case class CombineSum(child: Expression) extends AggregateExpression { +case class CombineSum(child: Expression) extends AggregateExpression1 { def this() = this(null) override def children: Seq[Expression] = child :: Nil @@ -564,8 +566,8 @@ case class CombineSum(child: Expression) extends AggregateExpression { override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) } -case class CombineSumFunction(expr: Expression, base: AggregateExpression) - extends AggregateFunction { +case class CombineSumFunction(expr: Expression, base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -601,7 +603,7 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression) } } -case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate { +case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 { def this() = this(null) override def nullable: Boolean = true @@ -627,8 +629,8 @@ case class SumDistinct(child: Expression) extends UnaryExpression with PartialAg TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct") } -case class SumDistinctFunction(expr: Expression, base: AggregateExpression) - extends AggregateFunction { +case class SumDistinctFunction(expr: Expression, base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. 
@@ -653,7 +655,7 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression) } } -case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { +case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression1 { def this() = this(null, null) override def children: Seq[Expression] = inputSet :: Nil @@ -667,8 +669,8 @@ case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends Agg case class CombineSetsAndSumFunction( @transient inputSet: Expression, - @transient base: AggregateExpression) - extends AggregateFunction { + @transient base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -695,7 +697,7 @@ case class CombineSetsAndSumFunction( } } -case class First(child: Expression) extends UnaryExpression with PartialAggregate { +case class First(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = true override def dataType: DataType = child.dataType override def toString: String = s"FIRST($child)" @@ -709,7 +711,7 @@ case class First(child: Expression) extends UnaryExpression with PartialAggregat override def newInstance(): FirstFunction = new FirstFunction(child, this) } -case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class FirstFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. 
var result: Any = null @@ -723,7 +725,7 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag override def eval(input: InternalRow): Any = result } -case class Last(child: Expression) extends UnaryExpression with PartialAggregate { +case class Last(child: Expression) extends UnaryExpression with PartialAggregate1 { override def references: AttributeSet = child.references override def nullable: Boolean = true override def dataType: DataType = child.dataType @@ -738,7 +740,7 @@ case class Last(child: Expression) extends UnaryExpression with PartialAggregate override def newInstance(): LastFunction = new LastFunction(child, this) } -case class LastFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class LastFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. var result: Any = null http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 03b4b3c..d838268 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import scala.collection.mutable.ArrayBuffer @@ -38,15 +39,17 @@ object GenerateMutableProjection extends 
CodeGenerator[Seq[Expression], () => Mu protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { val ctx = newCodeGenContext() - val projectionCode = expressions.zipWithIndex.map { case (e, i) => - val evaluationCode = e.gen(ctx) - evaluationCode.code + - s""" - if(${evaluationCode.isNull}) - mutableRow.setNullAt($i); - else - ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; - """ + val projectionCode = expressions.zipWithIndex.map { + case (NoOp, _) => "" + case (e, i) => + val evaluationCode = e.gen(ctx) + evaluationCode.code + + s""" + if(${evaluationCode.isNull}) + mutableRow.setNullAt($i); + else + ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; + """ } // collect projections into blocks as function has 64kb codesize limit in JVM val projectionBlocks = new ArrayBuffer[String]() http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 179a348..b8e3b0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -129,10 +129,10 @@ object PartialAggregation { case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => // Collect all aggregate expressions. val allAggregates = - aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a}) + aggregateExpressions.flatMap(_ collect { case a: AggregateExpression1 => a}) // Collect all aggregate expressions that can be computed partially. 
val partialAggregates = - aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p}) + aggregateExpressions.flatMap(_ collect { case p: PartialAggregate1 => p}) // Only do partial aggregation if supported by all aggregate expressions. if (allAggregates.size == partialAggregates.size) { http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 986c315..6aefa9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 78c780b..1474b17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -402,6 +402,9 @@ private[spark] object SQLConf { defaultValue = Some(true), isPublic = false) + val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2", + defaultValue = Some(true), doc = "<TODO>") + val 
USE_SQL_SERIALIZER2 = booleanConf( "spark.sql.useSerializer2", defaultValue = Some(true), isPublic = false) @@ -473,6 +476,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED) + private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2) + private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2) private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 8b4528b..49bfe74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -285,6 +285,9 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient val udf: UDFRegistration = new UDFRegistration(this) + @transient + val udaf: UDAFRegistration = new UDAFRegistration(this) + /** * Returns true if the table is currently cached in-memory. 
* @group cachemgmt @@ -863,6 +866,7 @@ class SQLContext(@transient val sparkContext: SparkContext) DDLStrategy :: TakeOrderedAndProject :: HashAggregation :: + Aggregation :: LeftSemiJoin :: HashJoin :: InMemoryScans :: http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala new file mode 100644 index 0000000..5b872f5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.{Expression} +import org.apache.spark.sql.expressions.aggregate.{ScalaUDAF, UserDefinedAggregateFunction} + +class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging { + + private val functionRegistry = sqlContext.functionRegistry + + def register( + name: String, + func: UserDefinedAggregateFunction): UserDefinedAggregateFunction = { + def builder(children: Seq[Expression]) = ScalaUDAF(children, func) + functionRegistry.registerFunction(name, builder) + func + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index 3cd60a2..c2c9453 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -68,14 +68,14 @@ case class Aggregate( * output. */ case class ComputedAggregate( - unbound: AggregateExpression, - aggregate: AggregateExpression, + unbound: AggregateExpression1, + aggregate: AggregateExpression1, resultAttribute: AttributeReference) /** A list of aggregates that need to be computed for each group. */ private[this] val computedAggregates = aggregateExpressions.flatMap { agg => agg.collect { - case a: AggregateExpression => + case a: AggregateExpression1 => ComputedAggregate( a, BindReferences.bindReference(a, child.output), @@ -87,8 +87,8 @@ case class Aggregate( private[this] val computedSchema = computedAggregates.map(_.resultAttribute) /** Creates a new aggregate buffer for a group. 
*/ - private[this] def newAggregateBuffer(): Array[AggregateFunction] = { - val buffer = new Array[AggregateFunction](computedAggregates.length) + private[this] def newAggregateBuffer(): Array[AggregateFunction1] = { + val buffer = new Array[AggregateFunction1](computedAggregates.length) var i = 0 while (i < computedAggregates.length) { buffer(i) = computedAggregates(i).aggregate.newInstance() @@ -146,7 +146,7 @@ case class Aggregate( } } else { child.execute().mapPartitions { iter => - val hashTable = new HashMap[InternalRow, Array[AggregateFunction]] + val hashTable = new HashMap[InternalRow, Array[AggregateFunction1]] val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output) var currentRow: InternalRow = null http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 2750053..d31e265 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -247,8 +247,15 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ } def addSortIfNecessary(child: SparkPlan): SparkPlan = { - if (rowOrdering.nonEmpty && child.outputOrdering != rowOrdering) { - sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) + + if (rowOrdering.nonEmpty) { + // If child.outputOrdering is [a, b] and rowOrdering is [a], we do not need to sort. 
+ val minSize = Seq(rowOrdering.size, child.outputOrdering.size).min + if (minSize == 0 || rowOrdering.take(minSize) != child.outputOrdering.take(minSize)) { + sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) + } else { + child + } } else { child } http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index ecde9c5..0e63f2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -69,7 +69,7 @@ case class GeneratedAggregate( protected override def doExecute(): RDD[InternalRow] = { val aggregatesToCompute = aggregateExpressions.flatMap { a => - a.collect { case agg: AggregateExpression => agg} + a.collect { case agg: AggregateExpression1 => agg} } // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 8cef7f2..f54aa20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{SQLContext, Strategy, execution} import org.apache.spark.sql.catalyst.InternalRow import 
org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2} import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} @@ -148,7 +149,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { if canBeCodeGened( allAggregates(partialComputation) ++ allAggregates(rewrittenAggregateExpressions)) && - codegenEnabled => + codegenEnabled && + !canBeConvertedToNewAggregation(plan) => execution.GeneratedAggregate( partial = false, namedGroupingAttributes, @@ -167,7 +169,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { rewrittenAggregateExpressions, groupingExpressions, partialComputation, - child) => + child) if !canBeConvertedToNewAggregation(plan) => execution.Aggregate( partial = false, namedGroupingAttributes, @@ -181,7 +183,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => Nil } - def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists { + def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = { + aggregate.Utils.tryConvert( + plan, + sqlContext.conf.useSqlAggregate2, + sqlContext.conf.codegenEnabled).isDefined + } + + def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = !aggs.exists { case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => false // The generated set implementation is pretty limited ATM. 
case CollectHashSet(exprs) if exprs.size == 1 && @@ -189,10 +198,74 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => true } - def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression] = - exprs.flatMap(_.collect { case a: AggregateExpression => a }) + def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] = + exprs.flatMap(_.collect { case a: AggregateExpression1 => a }) } + /** + * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface. + */ + object Aggregation extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case p: logical.Aggregate => + val converted = + aggregate.Utils.tryConvert( + p, + sqlContext.conf.useSqlAggregate2, + sqlContext.conf.codegenEnabled) + converted match { + case None => Nil // Cannot convert to new aggregation code path. + case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) => + // Extracts all distinct aggregate expressions from the resultExpressions. + val aggregateExpressions = resultExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression2 => agg + } + }.toSet.toSeq + // For those distinct aggregate expressions, we create a map from the + // aggregate function to the corresponding attribute of the function. + val aggregateFunctionMap = aggregateExpressions.map { agg => + val aggregateFunction = agg.aggregateFunction + (aggregateFunction, agg.isDistinct) -> + Alias(aggregateFunction, aggregateFunction.toString)().toAttribute + }.toMap + + val (functionsWithDistinct, functionsWithoutDistinct) = + aggregateExpressions.partition(_.isDistinct) + if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + // This is a sanity check. We should not reach here when we have multiple distinct + // column sets (aggregate.NewAggregation will not match). 
+ sys.error( + "Multiple distinct column sets are not supported by the new aggregation " + + "code path.") + } + + val aggregateOperator = + if (functionsWithDistinct.isEmpty) { + aggregate.Utils.planAggregateWithoutDistinct( + groupingExpressions, + aggregateExpressions, + aggregateFunctionMap, + resultExpressions, + planLater(child)) + } else { + aggregate.Utils.planAggregateWithOneDistinct( + groupingExpressions, + functionsWithDistinct, + functionsWithoutDistinct, + aggregateFunctionMap, + resultExpressions, + planLater(child)) + } + + aggregateOperator + } + + case _ => Nil + } + } + + object BroadcastNestedLoopJoin extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, joinType, condition) => @@ -336,8 +409,21 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Filter(condition, planLater(child)) :: Nil case e @ logical.Expand(_, _, _, child) => execution.Expand(e.projections, e.output, planLater(child)) :: Nil - case logical.Aggregate(group, agg, child) => - execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil + case a @ logical.Aggregate(group, agg, child) => { + val useNewAggregation = + aggregate.Utils.tryConvert( + a, + sqlContext.conf.useSqlAggregate2, + sqlContext.conf.codegenEnabled).isDefined + if (useNewAggregation) { + // If this logical.Aggregate can be planned to use new aggregation code path + // (i.e. it can be planned by the Strategy Aggregation), we will not use the old + // aggregation code path.
+ Nil + } else { + execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil + } + } case logical.Window(projectList, windowExpressions, spec, child) => execution.Window(projectList, windowExpressions, spec, planLater(child)) :: Nil case logical.Sample(lb, ub, withReplacement, seed, child) => http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala new file mode 100644 index 0000000..0c90828 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution} +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} + +case class Aggregate2Sort( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression2], + aggregateAttributes: Seq[Attribute], + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + override def canProcessUnsafeRows: Boolean = true + + override def references: AttributeSet = { + val referencesInResults = + AttributeSet(resultExpressions.flatMap(_.references)) -- AttributeSet(aggregateAttributes) + + AttributeSet( + groupingExpressions.flatMap(_.references) ++ + aggregateExpressions.flatMap(_.references) ++ + referencesInResults) + } + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.length == 0 => AllTuples :: Nil + case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + // TODO: We should not sort the input rows if they are just in reversed order. + groupingExpressions.map(SortOrder(_, Ascending)) :: Nil + } + + override def outputOrdering: Seq[SortOrder] = { + // It is possible that the child.outputOrdering starts with the required + // ordering expressions (e.g. we require [a] as the sort expression and the + // child's outputOrdering is [a, b]). 
We can only guarantee the output rows + // are sorted by values of groupingExpressions. + groupingExpressions.map(SortOrder(_, Ascending)) + } + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + if (aggregateExpressions.length == 0) { + new GroupingIterator( + groupingExpressions, + resultExpressions, + newMutableProjection, + child.output, + iter) + } else { + val aggregationIterator: SortAggregationIterator = { + aggregateExpressions.map(_.mode).distinct.toList match { + case Partial :: Nil => + new PartialSortAggregationIterator( + groupingExpressions, + aggregateExpressions, + newMutableProjection, + child.output, + iter) + case PartialMerge :: Nil => + new PartialMergeSortAggregationIterator( + groupingExpressions, + aggregateExpressions, + newMutableProjection, + child.output, + iter) + case Final :: Nil => + new FinalSortAggregationIterator( + groupingExpressions, + aggregateExpressions, + aggregateAttributes, + resultExpressions, + newMutableProjection, + child.output, + iter) + case other => + sys.error( + s"Could not evaluate ${aggregateExpressions} because we do not support evaluate " + + s"modes $other in this operator.") + } + } + + aggregationIterator + } + } + } +} + +case class FinalAndCompleteAggregate2Sort( + previousGroupingExpressions: Seq[NamedExpression], + groupingExpressions: Seq[NamedExpression], + finalAggregateExpressions: Seq[AggregateExpression2], + finalAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + override def references: AttributeSet = { + val referencesInResults = + AttributeSet(resultExpressions.flatMap(_.references)) -- + AttributeSet(finalAggregateExpressions) -- + 
+ AttributeSet(completeAggregateExpressions) + + AttributeSet( + groupingExpressions.flatMap(_.references) ++ + finalAggregateExpressions.flatMap(_.references) ++ + completeAggregateExpressions.flatMap(_.references) ++ + referencesInResults) + } + + override def requiredChildDistribution: List[Distribution] = { + if (groupingExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + groupingExpressions.map(SortOrder(_, Ascending)) :: Nil + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + + new FinalAndCompleteSortAggregationIterator( + previousGroupingExpressions.length, + groupingExpressions, + finalAggregateExpressions, + finalAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + resultExpressions, + newMutableProjection, + child.output, + iter) + } + } + +}
