[SPARK-4233] [SPARK-4367] [SPARK-3947] [SPARK-3056] [SQL] Aggregation 
Improvement

This is the first PR for the aggregation improvement, which is tracked by 
https://issues.apache.org/jira/browse/SPARK-4366 (umbrella JIRA). This PR 
contains work for its subtasks, SPARK-3056, SPARK-3947, SPARK-4233, and 
SPARK-4367.

This PR introduces a new code path for evaluating aggregate functions. This 
code path is guarded by `spark.sql.useAggregate2` and by default the value of 
this flag is true.

This new code path contains:
* A new aggregate function interface (`AggregateFunction2`) and 7 built-in 
aggregate functions based on this new interface (`AVG`, `COUNT`, `FIRST`, 
`LAST`, `MAX`, `MIN`, `SUM`)
* A UDAF interface (`UserDefinedAggregateFunction`) based on the new code path 
and two example UDAFs (`MyDoubleAvg` and `MyDoubleSum`).
* A sort-based aggregate operator (`Aggregate2Sort`) for the new aggregate 
function interface.
* A sort-based aggregate operator (`FinalAndCompleteAggregate2Sort`) for 
distinct aggregations (for distinct aggregations the query plan will use 
`Aggregate2Sort` and `FinalAndCompleteAggregate2Sort` together).
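
As a rough illustration of the initialize/update/merge/evaluate contract that the new interface is built around, here is a minimal Spark-free sketch of an average. `AvgBuffer` and `AvgSketch` are hypothetical names used only for this example. They are not classes from this PR, and nulls are modeled with `Option` as an assumption.

```scala
// Hypothetical, Spark-free model of an aggregation buffer for AVG.
// None of these names are Spark classes.
final case class AvgBuffer(sum: Double, count: Long)

object AvgSketch {
  // Initial buffer: sum = 0, count = 0.
  val initial: AvgBuffer = AvgBuffer(0.0, 0L)

  // Update: a null input (modeled as None) leaves the buffer unchanged.
  def update(buf: AvgBuffer, input: Option[Double]): AvgBuffer = input match {
    case Some(v) => AvgBuffer(buf.sum + v, buf.count + 1L)
    case None    => buf
  }

  // Merge: combine two partial buffers field by field.
  def merge(left: AvgBuffer, right: AvgBuffer): AvgBuffer =
    AvgBuffer(left.sum + right.sum, left.count + right.count)

  // Evaluate: a group with no non-null input yields no value.
  def evaluate(buf: AvgBuffer): Option[Double] =
    if (buf.count == 0L) None else Some(buf.sum / buf.count)
}
```

Each partition would fold its rows with `update`, and the partial buffers would then be combined with `merge` before `evaluate` produces the result.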

With this change, when `spark.sql.useAggregate2` is `true`, the flow of 
compiling an aggregation query is:
1. Our analyzer looks up functions and returns aggregate functions built based 
on the old aggregate function interface.
2. When our planner is compiling the physical plan, it tries to convert all 
aggregate functions to the ones built based on the new interface. The planner 
will fall back to the old code path if any of the following conditions is 
true:
* Code-gen is disabled.
* There is any function that cannot be converted (right now, Hive UDAFs).
* The schema of grouping expressions contains any complex data type.
* There are multiple distinct columns.
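
The fallback conditions listed above can be read as a single predicate. The sketch below is only an illustration of that reading. `AggregationTraits` and `PlannerSketch` are hypothetical names and do not correspond to the planner code in this PR.

```scala
// Hypothetical summary of an aggregation query. These fields are illustrative,
// not Spark APIs.
final case class AggregationTraits(
    codegenEnabled: Boolean,
    hasUnconvertibleFunction: Boolean, // e.g. a Hive UDAF
    groupingHasComplexType: Boolean,
    distinctColumnCount: Int)

object PlannerSketch {
  // True when the new code path may be used; false means falling back
  // to the old aggregation code path.
  def useNewCodePath(useAggregate2: Boolean, q: AggregationTraits): Boolean =
    useAggregate2 &&
      q.codegenEnabled &&
      !q.hasUnconvertibleFunction &&
      !q.groupingHasComplexType &&
      q.distinctColumnCount <= 1
}
```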

Right now, the new code path handles a single distinct column in the query (you 
can have multiple aggregate functions using that distinct column). For a query 
having an aggregate function with DISTINCT and regular aggregate functions, the 
generated plan will do partial aggregations for those regular aggregate 
functions.
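
To make the combination of one DISTINCT aggregate and a regular aggregate concrete, here is a small Spark-free sketch over plain Scala collections. It assumes that the regular function is partially aggregated per (grouping key, distinct column) pair before a final merge. That two-step shape is an assumption made for illustration and is not taken from the operators in this PR. `DistinctSketch` and `Row` are hypothetical names.

```scala
object DistinctSketch {
  // One input row: grouping key k, value v for SUM, column d for COUNT(DISTINCT d).
  final case class Row(k: String, v: Long, d: Option[Int])

  // Models: SELECT k, SUM(v), COUNT(DISTINCT d) FROM t GROUP BY k
  def run(rows: Seq[Row]): Map[String, (Long, Long)] = {
    // Step 1: partial aggregation of the regular function, grouped by (k, d).
    val partial: Map[(String, Option[Int]), Long] =
      rows.groupBy(r => (r.k, r.d)).map { case (key, rs) => (key, rs.map(_.v).sum) }

    // Step 2: per k, merge the partial sums and count the non-null distinct d values.
    partial.groupBy { case ((k, _), _) => k }.map { case (k, entries) =>
      val sum = entries.values.sum
      val distinctCount = entries.keys.count { case (_, d) => d.isDefined }.toLong
      (k, (sum, distinctCount))
    }
  }
}
```

Because step 1 leaves exactly one entry per distinct (k, d) pair, counting the non-null d entries in step 2 gives the distinct count without a separate de-duplication pass.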

Thanks chenghao-intel for his initial work on it.

Author: Yin Huai <[email protected]>
Author: Michael Armbrust <[email protected]>

Closes #7458 from yhuai/UDAF and squashes the following commits:

7865f5e [Yin Huai] Put the catalyst expression in the comment of the generated 
code for it.
b04d6c8 [Yin Huai] Remove unnecessary change.
f1d5901 [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
35b0520 [Yin Huai] Use semanticEquals to replace grouping expressions in the 
output of the aggregate operator.
3b43b24 [Yin Huai] bug fix.
00eb298 [Yin Huai] Make it compile.
a3ca551 [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
e0afca3 [Yin Huai] Gracefully fallback to old aggregation code path.
8a8ac4a [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
88c7d4d [Yin Huai] Enable spark.sql.useAggregate2 by default for testing 
purpose.
dc96fd1 [Yin Huai] Many updates:
85c9c4b [Yin Huai] newline.
43de3de [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
c3614d7 [Yin Huai] Handle single distinct column.
68b8ee9 [Yin Huai] Support single distinct column set. WIP
3013579 [Yin Huai] Format.
d678aee [Yin Huai] Remove AggregateExpressionSuite.scala since our built-in 
aggregate functions will be based on AlgebraicAggregate and we need to have 
another way to test it.
e243ca6 [Yin Huai] Add aggregation iterators.
a101960 [Yin Huai] Change MyJavaUDAF to MyDoubleSum.
594cdf5 [Yin Huai] Change existing AggregateExpression to AggregateExpression1 
and add an AggregateExpression as the common interface for both 
AggregateExpression1 and AggregateExpression2.
380880f [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
0a827b3 [Yin Huai] Add comments and doc. Move some classes to the right places.
a19fea6 [Yin Huai] Add UDAF interface.
262d4c4 [Yin Huai] Make it compile.
b2e358e [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
6edb5ac [Yin Huai] Format update.
70b169c [Yin Huai] Remove groupOrdering.
4721936 [Yin Huai] Add CheckAggregateFunction to extendedCheckRules.
d821a34 [Yin Huai] Cleanup.
32aea9c [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
5b46d41 [Yin Huai] Bug fix.
aff9534 [Yin Huai] Make Aggregate2Sort work with both algebraic 
AggregateFunctions and non-algebraic AggregateFunctions.
2857b55 [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
4435f20 [Yin Huai] Add ConvertAggregateFunction to HiveContext's analyzer.
1b490ed [Michael Armbrust] make hive test
8cfa6a9 [Michael Armbrust] add test
1b0bb3f [Yin Huai] Do not bind references in AlgebraicAggregate and use code 
gen for all places.
072209f [Yin Huai] Bug fix: Handle expressions in grouping columns that are not 
attribute references.
f7d9e54 [Michael Armbrust] Merge remote-tracking branch 'apache/master' into 
UDAF
39ee975 [Yin Huai] Code cleanup: Remove unnecesary AttributeReferences.
b7720ba [Yin Huai] Add an analysis rule to convert aggregate function to the 
new version.
5c00f3f [Michael Armbrust] First draft of codegen
6bbc6ba [Michael Armbrust] now with correct answers\!
f7996d0 [Michael Armbrust] Add AlgebraicAggregate
dded1c5 [Yin Huai] wip


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c03299a1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c03299a1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c03299a1

Branch: refs/heads/master
Commit: c03299a18b4e076cabb4b7833a1e7632c5c0dabe
Parents: f4785f5
Author: Yin Huai <[email protected]>
Authored: Tue Jul 21 23:26:11 2015 -0700
Committer: Reynold Xin <[email protected]>
Committed: Tue Jul 21 23:26:11 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/SqlParser.scala   |   3 +-
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  24 +-
 .../sql/catalyst/analysis/CheckAnalysis.scala   |   1 +
 .../sql/catalyst/analysis/unresolved.scala      |   5 +-
 .../catalyst/expressions/BoundAttribute.scala   |   2 +-
 .../sql/catalyst/expressions/Expression.scala   |   3 +-
 .../expressions/aggregate/functions.scala       | 292 ++++++++
 .../expressions/aggregate/interfaces.scala      | 206 +++++
 .../sql/catalyst/expressions/aggregates.scala   | 100 +--
 .../codegen/GenerateMutableProjection.scala     |  21 +-
 .../spark/sql/catalyst/planning/patterns.scala  |   4 +-
 .../catalyst/plans/logical/basicOperators.scala |   1 +
 .../scala/org/apache/spark/sql/SQLConf.scala    |   5 +
 .../scala/org/apache/spark/sql/SQLContext.scala |   4 +
 .../org/apache/spark/sql/UDAFRegistration.scala |  35 +
 .../apache/spark/sql/execution/Aggregate.scala  |  12 +-
 .../apache/spark/sql/execution/Exchange.scala   |  11 +-
 .../sql/execution/GeneratedAggregate.scala      |   2 +-
 .../spark/sql/execution/SparkStrategies.scala   | 100 ++-
 .../aggregate/aggregateOperators.scala          | 173 +++++
 .../aggregate/sortBasedIterators.scala          | 749 +++++++++++++++++++
 .../spark/sql/execution/aggregate/utils.scala   | 364 +++++++++
 .../spark/sql/expressions/aggregate/udaf.scala  | 280 +++++++
 .../scala/org/apache/spark/sql/functions.scala  |   4 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala    |   4 +-
 .../spark/sql/execution/PlannerSuite.scala      |  26 +-
 .../HiveWindowFunctionQuerySuite.scala          |   1 +
 .../execution/SortMergeCompatibilitySuite.scala |   7 +
 .../org/apache/spark/sql/hive/HiveContext.scala |   1 +
 .../org/apache/spark/sql/hive/HiveQl.scala      |   7 +-
 .../org/apache/spark/sql/hive/hiveUDFs.scala    |   8 +-
 .../spark/sql/hive/aggregate/MyDoubleAvg.java   | 107 +++
 .../spark/sql/hive/aggregate/MyDoubleSum.java   | 100 +++
 ...udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada |   1 +
 ...udf_unhex-1-11eb3cc5216d5446f4165007203acc47 |   1 +
 ...udf_unhex-2-a660886085b8651852b9b77934848ae4 |  14 +
 ...udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e |   1 +
 ...udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 |   1 +
 .../hive/execution/AggregationQuerySuite.scala  | 507 +++++++++++++
 39 files changed, 3087 insertions(+), 100 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index d4ef04c..c04bd6c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -266,11 +266,12 @@ class SqlParser extends AbstractSparkSQLParser with 
DataTypeParser {
       }
     }
     | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^
-      { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) }
+      { case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct 
= false) }
     | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case 
udfName ~ exprs =>
       lexical.normalizeKeyword(udfName) match {
         case "sum" => SumDistinct(exprs.head)
         case "count" => CountDistinct(exprs)
+        case name => UnresolvedFunction(name, exprs, isDistinct = true)
         case _ => throw new AnalysisException(s"function $udfName does not 
support DISTINCT")
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e58f3f6..8cadbc5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, 
AggregateExpression2, AggregateFunction2}
 import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -277,7 +278,7 @@ class Analyzer(
         Project(
           projectList.flatMap {
             case s: Star => s.expand(child.output, resolver)
-            case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if 
containsStar(args) =>
+            case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if 
containsStar(args) =>
               val expandedArgs = args.flatMap {
                 case s: Star => s.expand(child.output, resolver)
                 case o => o :: Nil
@@ -517,9 +518,26 @@ class Analyzer(
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
       case q: LogicalPlan =>
         q transformExpressions {
-          case u @ UnresolvedFunction(name, children) =>
+          case u @ UnresolvedFunction(name, children, isDistinct) =>
             withPosition(u) {
-              registry.lookupFunction(name, children)
+              registry.lookupFunction(name, children) match {
+                // We get an aggregate function built based on 
AggregateFunction2 interface.
+                // So, we wrap it in AggregateExpression2.
+                case agg2: AggregateFunction2 => AggregateExpression2(agg2, 
Complete, isDistinct)
+                // Currently, our old aggregate function interface supports 
SUM(DISTINCT ...)
+                // and COUNT(DISTINCT ...).
+                case sumDistinct: SumDistinct => sumDistinct
+                case countDistinct: CountDistinct => countDistinct
+                // DISTINCT is not meaningful with Max and Min.
+                case max: Max if isDistinct => max
+                case min: Min if isDistinct => min
+                // For other aggregate functions, DISTINCT keyword is not 
supported for now.
+                // Once we converted to the new code path, we will allow using 
DISTINCT keyword.
+                case other if isDistinct =>
+                  failAnalysis(s"$name does not support DISTINCT keyword.")
+                // If it does not have DISTINCT keyword, we will return it as 
is.
+                case other => other
+              }
             }
         }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index c7f9713..c203fce 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 0daee19..03da45b 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -73,7 +73,10 @@ object UnresolvedAttribute {
   def quoted(name: String): UnresolvedAttribute = new 
UnresolvedAttribute(Seq(name))
 }
 
-case class UnresolvedFunction(name: String, children: Seq[Expression])
+case class UnresolvedFunction(
+    name: String,
+    children: Seq[Expression],
+    isDistinct: Boolean)
   extends Expression with Unevaluable {
 
   override def dataType: DataType = throw new UnresolvedException(this, 
"dataType")

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index b09aea0..b10a3c8 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types._
 case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
   extends LeafExpression with NamedExpression {
 
-  override def toString: String = s"input[$ordinal]"
+  override def toString: String = s"input[$ordinal, $dataType]"
 
   override def eval(input: InternalRow): Any = input(ordinal)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index aada252..29ae47e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -96,7 +96,8 @@ abstract class Expression extends TreeNode[Expression] {
     val primitive = ctx.freshName("primitive")
     val ve = GeneratedExpressionCode("", isNull, primitive)
     ve.code = genCode(ctx, ve)
-    ve
+    // Add `this` in the comment.
+    ve.copy(s"/* $this */\n" + ve.code)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
new file mode 100644
index 0000000..b924af4
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+
+case class Average(child: Expression) extends AlgebraicAggregate {
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  // Return data type.
+  override def dataType: DataType = resultType
+
+  // Expected input data type.
+  // TODO: Once we remove the old code path, we can use our analyzer to cast 
NullType
+  // to the default data type of the NumericType.
+  override def inputTypes: Seq[AbstractDataType] = 
Seq(TypeCollection(NumericType, NullType))
+
+  private val resultType = child.dataType match {
+    case DecimalType.Fixed(precision, scale) =>
+      DecimalType(precision + 4, scale + 4)
+    case DecimalType.Unlimited => DecimalType.Unlimited
+    case _ => DoubleType
+  }
+
+  private val sumDataType = child.dataType match {
+    case _ @ DecimalType() => DecimalType.Unlimited
+    case _ => DoubleType
+  }
+
+  private val currentSum = AttributeReference("currentSum", sumDataType)()
+  private val currentCount = AttributeReference("currentCount", LongType)()
+
+  override val bufferAttributes = currentSum :: currentCount :: Nil
+
+  override val initialValues = Seq(
+    /* currentSum = */ Cast(Literal(0), sumDataType),
+    /* currentCount = */ Literal(0L)
+  )
+
+  override val updateExpressions = Seq(
+    /* currentSum = */
+    Add(
+      currentSum,
+      Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: 
Nil)),
+    /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
+  )
+
+  override val mergeExpressions = Seq(
+    /* currentSum = */ currentSum.left + currentSum.right,
+    /* currentCount = */ currentCount.left + currentCount.right
+  )
+
+  // If all inputs are null, currentCount will be 0 and we will get null after 
the division.
+  override val evaluateExpression = Cast(currentSum, resultType) / 
Cast(currentCount, resultType)
+}
+
+case class Count(child: Expression) extends AlgebraicAggregate {
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = false
+
+  // Return data type.
+  override def dataType: DataType = LongType
+
+  // Expected input data type.
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+  private val currentCount = AttributeReference("currentCount", LongType)()
+
+  override val bufferAttributes = currentCount :: Nil
+
+  override val initialValues = Seq(
+    /* currentCount = */ Literal(0L)
+  )
+
+  override val updateExpressions = Seq(
+    /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
+  )
+
+  override val mergeExpressions = Seq(
+    /* currentCount = */ currentCount.left + currentCount.right
+  )
+
+  override val evaluateExpression = Cast(currentCount, LongType)
+}
+
+case class First(child: Expression) extends AlgebraicAggregate {
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  // First is not a deterministic function.
+  override def deterministic: Boolean = false
+
+  // Return data type.
+  override def dataType: DataType = child.dataType
+
+  // Expected input data type.
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+  private val first = AttributeReference("first", child.dataType)()
+
+  override val bufferAttributes = first :: Nil
+
+  override val initialValues = Seq(
+    /* first = */ Literal.create(null, child.dataType)
+  )
+
+  override val updateExpressions = Seq(
+    /* first = */ If(IsNull(first), child, first)
+  )
+
+  override val mergeExpressions = Seq(
+    /* first = */ If(IsNull(first.left), first.right, first.left)
+  )
+
+  override val evaluateExpression = first
+}
+
+case class Last(child: Expression) extends AlgebraicAggregate {
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  // Last is not a deterministic function.
+  override def deterministic: Boolean = false
+
+  // Return data type.
+  override def dataType: DataType = child.dataType
+
+  // Expected input data type.
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+  private val last = AttributeReference("last", child.dataType)()
+
+  override val bufferAttributes = last :: Nil
+
+  override val initialValues = Seq(
+    /* last = */ Literal.create(null, child.dataType)
+  )
+
+  override val updateExpressions = Seq(
+    /* last = */ If(IsNull(child), last, child)
+  )
+
+  override val mergeExpressions = Seq(
+    /* last = */ If(IsNull(last.right), last.left, last.right)
+  )
+
+  override val evaluateExpression = last
+}
+
+case class Max(child: Expression) extends AlgebraicAggregate {
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  // Return data type.
+  override def dataType: DataType = child.dataType
+
+  // Expected input data type.
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+  private val max = AttributeReference("max", child.dataType)()
+
+  override val bufferAttributes = max :: Nil
+
+  override val initialValues = Seq(
+    /* max = */ Literal.create(null, child.dataType)
+  )
+
+  override val updateExpressions = Seq(
+    /* max = */ If(IsNull(child), max, If(IsNull(max), child, 
Greatest(Seq(max, child))))
+  )
+
+  override val mergeExpressions = {
+    val greatest = Greatest(Seq(max.left, max.right))
+    Seq(
+      /* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), 
max.right, greatest))
+    )
+  }
+
+  override val evaluateExpression = max
+}
+
+case class Min(child: Expression) extends AlgebraicAggregate {
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  // Return data type.
+  override def dataType: DataType = child.dataType
+
+  // Expected input data type.
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+  private val min = AttributeReference("min", child.dataType)()
+
+  override val bufferAttributes = min :: Nil
+
+  override val initialValues = Seq(
+    /* min = */ Literal.create(null, child.dataType)
+  )
+
+  override val updateExpressions = Seq(
+    /* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, 
child))))
+  )
+
+  override val mergeExpressions = {
+    val least = Least(Seq(min.left, min.right))
+    Seq(
+      /* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), 
min.right, least))
+    )
+  }
+
+  override val evaluateExpression = min
+}
+
+case class Sum(child: Expression) extends AlgebraicAggregate {
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  // Return data type.
+  override def dataType: DataType = resultType
+
+  // Expected input data type.
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType))
+
+  private val resultType = child.dataType match {
+    case DecimalType.Fixed(precision, scale) =>
+      DecimalType(precision + 4, scale + 4)
+    case DecimalType.Unlimited => DecimalType.Unlimited
+    case _ => child.dataType
+  }
+
+  private val sumDataType = child.dataType match {
+    case _ @ DecimalType() => DecimalType.Unlimited
+    case _ => child.dataType
+  }
+
+  private val currentSum = AttributeReference("currentSum", sumDataType)()
+
+  private val zero = Cast(Literal(0), sumDataType)
+
+  override val bufferAttributes = currentSum :: Nil
+
+  override val initialValues = Seq(
+    /* currentSum = */ Literal.create(null, sumDataType)
+  )
+
+  override val updateExpressions = Seq(
+    /* currentSum = */
+    Coalesce(Seq(Add(Coalesce(Seq(currentSum, zero)), Cast(child, 
sumDataType)), currentSum))
+  )
+
+  override val mergeExpressions = {
+    val add = Add(Coalesce(Seq(currentSum.left, zero)), Cast(currentSum.right, 
sumDataType))
+    Seq(
+      /* currentSum = */
+      Coalesce(Seq(add, currentSum.left))
+    )
+  }
+
+  override val evaluateExpression = Cast(currentSum, resultType)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
new file mode 100644
index 0000000..577ede7
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.expressions._
+import 
org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, 
CodeGenContext}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types._
+
+/** The mode of an [[AggregateFunction2]]. */
+private[sql] sealed trait AggregateMode
+
+/**
+ * An [[AggregateFunction2]] with [[Partial]] mode is used for partial 
aggregation.
+ * This function updates the given aggregation buffer with the original input 
of this
+ * function. When it has processed all input rows, the aggregation buffer is 
returned.
+ */
+private[sql] case object Partial extends AggregateMode
+
+/**
+ * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge 
aggregation buffers
+ * containing intermediate results for this function.
+ * This function updates the given aggregation buffer by merging multiple 
aggregation buffers.
+ * When it has processed all input rows, the aggregation buffer is returned.
+ */
+private[sql] case object PartialMerge extends AggregateMode
+
+/**
+ * An [[AggregateFunction2]] with [[Final]] mode is used to merge 
aggregation buffers
+ * containing intermediate results for this function and then generate the final 
result.
+ * This function updates the given aggregation buffer by merging multiple 
aggregation buffers.
+ * When it has processed all input rows, the final result of this function is 
returned.
+ */
+private[sql] case object Final extends AggregateMode
+
+/**
+ * An [[AggregateFunction2]] with [[Complete]] mode is used to evaluate this 
function directly
+ * from original input rows without any partial aggregation.
+ * This function updates the given aggregation buffer with the original input 
of this
+ * function. When it has processed all input rows, the final result of this 
function is returned.
+ */
+private[sql] case object Complete extends AggregateMode
+
+/**
+ * A placeholder expression used in code-gen. It does not change the 
corresponding value
+ * in the row.
+ */
+private[sql] case object NoOp extends Expression with Unevaluable {
+  override def nullable: Boolean = true
+  override def eval(input: InternalRow): Any = {
+    throw new TreeNodeException(
+      this, s"No function to evaluate expression. type: ${this.nodeName}")
+  }
+  override def dataType: DataType = NullType
+  override def children: Seq[Expression] = Nil
+}
+
+/**
+ * A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a 
field
+ * (`isDistinct`) indicating if DISTINCT keyword is specified for this 
function.
+ * @param aggregateFunction
+ * @param mode
+ * @param isDistinct
+ */
+private[sql] case class AggregateExpression2(
+    aggregateFunction: AggregateFunction2,
+    mode: AggregateMode,
+    isDistinct: Boolean) extends AggregateExpression {
+
+  override def children: Seq[Expression] = aggregateFunction :: Nil
+  override def dataType: DataType = aggregateFunction.dataType
+  override def foldable: Boolean = false
+  override def nullable: Boolean = aggregateFunction.nullable
+
+  override def references: AttributeSet = {
+    val childReferences = mode match {
+      case Partial | Complete => aggregateFunction.references.toSeq
+      case PartialMerge | Final => aggregateFunction.bufferAttributes
+    }
+
+    AttributeSet(childReferences)
+  }
+
+  override def toString: String = 
s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)"
+}
+
+abstract class AggregateFunction2
+  extends Expression with ImplicitCastInputTypes {
+
+  self: Product =>
+
+  /** An aggregate function is not foldable. */
+  override def foldable: Boolean = false
+
+  /**
+   * The offset of this function's buffer in the underlying buffer shared with 
other functions.
+   */
+  var bufferOffset: Int = 0
+
+  /** The schema of the aggregation buffer. */
+  def bufferSchema: StructType
+
+  /** Attributes of fields in bufferSchema. */
+  def bufferAttributes: Seq[AttributeReference]
+
+  /** Clones bufferAttributes. */
+  def cloneBufferAttributes: Seq[Attribute]
+
+  /**
+   * Initializes its aggregation buffer located in `buffer`.
+   * It will use bufferOffset to find the starting point of
+   * its buffer in the given `buffer` shared with other functions.
+   */
+  def initialize(buffer: MutableRow): Unit
+
+  /**
+   * Updates its aggregation buffer located in `buffer` based on the given 
`input`.
+   * It will use bufferOffset to find the starting point of its buffer in the 
given `buffer`
+   * shared with other functions.
+   */
+  def update(buffer: MutableRow, input: InternalRow): Unit
+
+  /**
+   * Updates its aggregation buffer located in `buffer1` by combining 
intermediate results
+   * in the current buffer and intermediate results from another buffer 
`buffer2`.
+   * It will use bufferOffset to find the starting point of its buffer in the 
given `buffer1`
+   * and `buffer2`.
+   */
+  def merge(buffer1: MutableRow, buffer2: InternalRow): Unit
+
+  override protected def genCode(ctx: CodeGenContext, ev: 
GeneratedExpressionCode): String =
+    throw new UnsupportedOperationException(s"Cannot evaluate expression: 
$this")
+}
+
+/**
+ * A helper class for aggregate functions that can be implemented in terms of 
catalyst expressions.
+ */
+abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable 
{
+  self: Product =>
+
+  val initialValues: Seq[Expression]
+  val updateExpressions: Seq[Expression]
+  val mergeExpressions: Seq[Expression]
+  val evaluateExpression: Expression
+
+  override lazy val cloneBufferAttributes = 
bufferAttributes.map(_.newInstance())
+
+  /**
+   * A helper class for representing an attribute used in merging two
+   * aggregation buffers. When merging two buffers, `bufferLeft` and 
`bufferRight`,
+   * we merge buffer values and then update bufferLeft. A [[RichAttribute]]
+   * of an [[AttributeReference]] `a` has two functions `left` and `right`,
+   * which represent `a` in `bufferLeft` and `bufferRight`, respectively.
+   * @param a the buffer attribute to be referenced on either side of the merge
+   */
+  implicit class RichAttribute(a: AttributeReference) {
+    /** Represents this attribute at the mutable buffer side. */
+    def left: AttributeReference = a
+
+    /** Represents this attribute at the input buffer side (the data value is 
read-only). */
+    def right: AttributeReference = 
cloneBufferAttributes(bufferAttributes.indexOf(a))
+  }
+
+  /** An AlgebraicAggregate's bufferSchema is derived from bufferAttributes. */
+  override def bufferSchema: StructType = 
StructType.fromAttributes(bufferAttributes)
+
+  override def initialize(buffer: MutableRow): Unit = {
+    var i = 0
+    while (i < bufferAttributes.size) {
+      buffer(i + bufferOffset) = initialValues(i).eval()
+      i += 1
+    }
+  }
+
+  override def update(buffer: MutableRow, input: InternalRow): Unit = {
+    throw new UnsupportedOperationException(
+      "AlgebraicAggregate's update should not be called directly")
+  }
+
+  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+    throw new UnsupportedOperationException(
+      "AlgebraicAggregate's merge should not be called directly")
+  }
+
+  override def eval(buffer: InternalRow): Any = {
+    throw new UnsupportedOperationException(
+      "AlgebraicAggregate's eval should not be called directly")
+  }
+}
+
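
The `bufferOffset` contract documented on `AggregateFunction2` (each function reads and writes its own slice of one buffer shared with other functions) can be illustrated with a small standalone sketch. It does not use any Spark classes: `ToyAggregate`, `ToySum`, `ToyAvg`, `layout` and `aggregatePartition` are hypothetical names invented for this illustration, and a plain `Array[Double]` stands in for `MutableRow`.

```scala
object SharedBufferSketch {
  /** Toy stand-in for AggregateFunction2: each function owns a slice of a shared buffer. */
  trait ToyAggregate {
    var bufferOffset: Int = 0 // start of this function's slice in the shared buffer
    def bufferSize: Int       // number of slots this function needs
    def initialize(buffer: Array[Double]): Unit
    def update(buffer: Array[Double], input: Double): Unit
    def merge(buffer1: Array[Double], buffer2: Array[Double]): Unit
    def eval(buffer: Array[Double]): Double
  }

  /** Uses one slot: the running sum. */
  final class ToySum extends ToyAggregate {
    val bufferSize: Int = 1
    def initialize(buffer: Array[Double]): Unit = { buffer(bufferOffset) = 0.0 }
    def update(buffer: Array[Double], input: Double): Unit = { buffer(bufferOffset) += input }
    def merge(buffer1: Array[Double], buffer2: Array[Double]): Unit = {
      buffer1(bufferOffset) += buffer2(bufferOffset)
    }
    def eval(buffer: Array[Double]): Double = buffer(bufferOffset)
  }

  /** Uses two slots: the running sum and the row count. */
  final class ToyAvg extends ToyAggregate {
    val bufferSize: Int = 2
    def initialize(buffer: Array[Double]): Unit = {
      buffer(bufferOffset) = 0.0
      buffer(bufferOffset + 1) = 0.0
    }
    def update(buffer: Array[Double], input: Double): Unit = {
      buffer(bufferOffset) += input
      buffer(bufferOffset + 1) += 1.0
    }
    def merge(buffer1: Array[Double], buffer2: Array[Double]): Unit = {
      buffer1(bufferOffset) += buffer2(bufferOffset)
      buffer1(bufferOffset + 1) += buffer2(bufferOffset + 1)
    }
    def eval(buffer: Array[Double]): Double =
      if (buffer(bufferOffset + 1) == 0.0) Double.NaN
      else buffer(bufferOffset) / buffer(bufferOffset + 1)
  }

  /** Assigns consecutive offsets and returns the total size of the shared buffer. */
  def layout(functions: Seq[ToyAggregate]): Int = {
    var offset = 0
    functions.foreach { f =>
      f.bufferOffset = offset
      offset += f.bufferSize
    }
    offset
  }

  /** Partial aggregation of one partition into a fresh shared buffer. */
  def aggregatePartition(
      functions: Seq[ToyAggregate],
      size: Int,
      rows: Seq[Double]): Array[Double] = {
    val buffer = new Array[Double](size)
    functions.foreach(_.initialize(buffer))
    rows.foreach(row => functions.foreach(_.update(buffer, row)))
    buffer
  }

  /** Aggregates two partitions separately, then merges the right buffer into the left. */
  def run(): (Double, Double) = {
    val sum = new ToySum
    val avg = new ToyAvg
    val functions = Seq(sum, avg)
    val size = layout(functions)
    val left = aggregatePartition(functions, size, Seq(1.0, 2.0))
    val right = aggregatePartition(functions, size, Seq(3.0, 6.0))
    functions.foreach(_.merge(left, right))
    (sum.eval(left), avg.eval(left))
  }
}
```

In this toy layout `ToySum` occupies slot 0 and `ToyAvg` occupies slots 1 and 2, so both functions can be updated in one pass over a single buffer per group.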

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index d705a12..e07c920 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -27,7 +27,9 @@ import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.OpenHashSet
 
 
-trait AggregateExpression extends Expression with Unevaluable {
+trait AggregateExpression extends Expression with Unevaluable
+
+trait AggregateExpression1 extends AggregateExpression {
 
   /**
    * Aggregate expressions should not be foldable.
@@ -38,7 +40,7 @@ trait AggregateExpression extends Expression with Unevaluable 
{
    * Creates a new instance that can be used to compute this aggregate 
expression for a group
    * of input rows/
    */
-  def newInstance(): AggregateFunction
+  def newInstance(): AggregateFunction1
 }
 
 /**
@@ -54,10 +56,10 @@ case class SplitEvaluation(
     partialEvaluations: Seq[NamedExpression])
 
 /**
- * An [[AggregateExpression]] that can be partially computed without seeing 
all relevant tuples.
+ * An [[AggregateExpression1]] that can be partially computed without seeing 
all relevant tuples.
  * These partial evaluations can then be combined to compute the actual answer.
  */
-trait PartialAggregate extends AggregateExpression {
+trait PartialAggregate1 extends AggregateExpression1 {
 
   /**
    * Returns a [[SplitEvaluation]] that computes this aggregation using 
partial aggregation.
@@ -67,13 +69,13 @@ trait PartialAggregate extends AggregateExpression {
 
 /**
  * A specific implementation of an aggregate function. Used to wrap a generic
- * [[AggregateExpression]] with an algorithm that will be used to compute one 
specific result.
+ * [[AggregateExpression1]] with an algorithm that will be used to compute one 
specific result.
  */
-abstract class AggregateFunction
-  extends LeafExpression with AggregateExpression with Serializable {
+abstract class AggregateFunction1
+  extends LeafExpression with AggregateExpression1 with Serializable {
 
   /** Base should return the generic aggregate expression that this function 
is computing */
-  val base: AggregateExpression
+  val base: AggregateExpression1
 
   override def nullable: Boolean = base.nullable
   override def dataType: DataType = base.dataType
@@ -81,12 +83,12 @@ abstract class AggregateFunction
   def update(input: InternalRow): Unit
 
   // Do we really need this?
-  override def newInstance(): AggregateFunction = {
+  override def newInstance(): AggregateFunction1 = {
     makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
   }
 }
 
-case class Min(child: Expression) extends UnaryExpression with 
PartialAggregate {
+case class Min(child: Expression) extends UnaryExpression with 
PartialAggregate1 {
 
   override def nullable: Boolean = true
   override def dataType: DataType = child.dataType
@@ -102,7 +104,7 @@ case class Min(child: Expression) extends UnaryExpression 
with PartialAggregate
     TypeUtils.checkForOrderingExpr(child.dataType, "function min")
 }
 
-case class MinFunction(expr: Expression, base: AggregateExpression) extends 
AggregateFunction {
+case class MinFunction(expr: Expression, base: AggregateExpression1) extends 
AggregateFunction1 {
   def this() = this(null, null) // Required for serialization.
 
   val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType)
@@ -119,7 +121,7 @@ case class MinFunction(expr: Expression, base: 
AggregateExpression) extends Aggr
   override def eval(input: InternalRow): Any = currentMin.value
 }
 
-case class Max(child: Expression) extends UnaryExpression with 
PartialAggregate {
+case class Max(child: Expression) extends UnaryExpression with 
PartialAggregate1 {
 
   override def nullable: Boolean = true
   override def dataType: DataType = child.dataType
@@ -135,7 +137,7 @@ case class Max(child: Expression) extends UnaryExpression 
with PartialAggregate
     TypeUtils.checkForOrderingExpr(child.dataType, "function max")
 }
 
-case class MaxFunction(expr: Expression, base: AggregateExpression) extends 
AggregateFunction {
+case class MaxFunction(expr: Expression, base: AggregateExpression1) extends 
AggregateFunction1 {
   def this() = this(null, null) // Required for serialization.
 
   val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType)
@@ -152,7 +154,7 @@ case class MaxFunction(expr: Expression, base: 
AggregateExpression) extends Aggr
   override def eval(input: InternalRow): Any = currentMax.value
 }
 
-case class Count(child: Expression) extends UnaryExpression with 
PartialAggregate {
+case class Count(child: Expression) extends UnaryExpression with 
PartialAggregate1 {
 
   override def nullable: Boolean = false
   override def dataType: LongType.type = LongType
@@ -165,7 +167,7 @@ case class Count(child: Expression) extends UnaryExpression 
with PartialAggregat
   override def newInstance(): CountFunction = new CountFunction(child, this)
 }
 
-case class CountFunction(expr: Expression, base: AggregateExpression) extends 
AggregateFunction {
+case class CountFunction(expr: Expression, base: AggregateExpression1) extends 
AggregateFunction1 {
   def this() = this(null, null) // Required for serialization.
 
   var count: Long = _
@@ -180,7 +182,7 @@ case class CountFunction(expr: Expression, base: 
AggregateExpression) extends Ag
   override def eval(input: InternalRow): Any = count
 }
 
-case class CountDistinct(expressions: Seq[Expression]) extends 
PartialAggregate {
+case class CountDistinct(expressions: Seq[Expression]) extends 
PartialAggregate1 {
   def this() = this(null)
 
   override def children: Seq[Expression] = expressions
@@ -200,8 +202,8 @@ case class CountDistinct(expressions: Seq[Expression]) 
extends PartialAggregate
 
 case class CountDistinctFunction(
     @transient expr: Seq[Expression],
-    @transient base: AggregateExpression)
-  extends AggregateFunction {
+    @transient base: AggregateExpression1)
+  extends AggregateFunction1 {
 
   def this() = this(null, null) // Required for serialization.
 
@@ -220,7 +222,7 @@ case class CountDistinctFunction(
   override def eval(input: InternalRow): Any = seen.size.toLong
 }
 
-case class CollectHashSet(expressions: Seq[Expression]) extends 
AggregateExpression {
+case class CollectHashSet(expressions: Seq[Expression]) extends 
AggregateExpression1 {
   def this() = this(null)
 
   override def children: Seq[Expression] = expressions
@@ -233,8 +235,8 @@ case class CollectHashSet(expressions: Seq[Expression]) 
extends AggregateExpress
 
 case class CollectHashSetFunction(
     @transient expr: Seq[Expression],
-    @transient base: AggregateExpression)
-  extends AggregateFunction {
+    @transient base: AggregateExpression1)
+  extends AggregateFunction1 {
 
   def this() = this(null, null) // Required for serialization.
 
@@ -255,7 +257,7 @@ case class CollectHashSetFunction(
   }
 }
 
-case class CombineSetsAndCount(inputSet: Expression) extends 
AggregateExpression {
+case class CombineSetsAndCount(inputSet: Expression) extends 
AggregateExpression1 {
   def this() = this(null)
 
   override def children: Seq[Expression] = inputSet :: Nil
@@ -269,8 +271,8 @@ case class CombineSetsAndCount(inputSet: Expression) 
extends AggregateExpression
 
 case class CombineSetsAndCountFunction(
     @transient inputSet: Expression,
-    @transient base: AggregateExpression)
-  extends AggregateFunction {
+    @transient base: AggregateExpression1)
+  extends AggregateFunction1 {
 
   def this() = this(null, null) // Required for serialization.
 
@@ -305,7 +307,7 @@ private[sql] case object HyperLogLogUDT extends 
UserDefinedType[HyperLogLog] {
 }
 
 case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
-  extends UnaryExpression with AggregateExpression {
+  extends UnaryExpression with AggregateExpression1 {
 
   override def nullable: Boolean = false
   override def dataType: DataType = HyperLogLogUDT
@@ -317,9 +319,9 @@ case class ApproxCountDistinctPartition(child: Expression, 
relativeSD: Double)
 
 case class ApproxCountDistinctPartitionFunction(
     expr: Expression,
-    base: AggregateExpression,
+    base: AggregateExpression1,
     relativeSD: Double)
-  extends AggregateFunction {
+  extends AggregateFunction1 {
   def this() = this(null, null, 0) // Required for serialization.
 
   private val hyperLogLog = new HyperLogLog(relativeSD)
@@ -335,7 +337,7 @@ case class ApproxCountDistinctPartitionFunction(
 }
 
 case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
-  extends UnaryExpression with AggregateExpression {
+  extends UnaryExpression with AggregateExpression1 {
 
   override def nullable: Boolean = false
   override def dataType: LongType.type = LongType
@@ -347,9 +349,9 @@ case class ApproxCountDistinctMerge(child: Expression, 
relativeSD: Double)
 
 case class ApproxCountDistinctMergeFunction(
     expr: Expression,
-    base: AggregateExpression,
+    base: AggregateExpression1,
     relativeSD: Double)
-  extends AggregateFunction {
+  extends AggregateFunction1 {
   def this() = this(null, null, 0) // Required for serialization.
 
   private val hyperLogLog = new HyperLogLog(relativeSD)
@@ -363,7 +365,7 @@ case class ApproxCountDistinctMergeFunction(
 }
 
 case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
-  extends UnaryExpression with PartialAggregate {
+  extends UnaryExpression with PartialAggregate1 {
 
   override def nullable: Boolean = false
   override def dataType: LongType.type = LongType
@@ -381,7 +383,7 @@ case class ApproxCountDistinct(child: Expression, 
relativeSD: Double = 0.05)
   override def newInstance(): CountDistinctFunction = new 
CountDistinctFunction(child :: Nil, this)
 }
 
-case class Average(child: Expression) extends UnaryExpression with 
PartialAggregate {
+case class Average(child: Expression) extends UnaryExpression with 
PartialAggregate1 {
 
   override def prettyName: String = "avg"
 
@@ -427,8 +429,8 @@ case class Average(child: Expression) extends 
UnaryExpression with PartialAggreg
     TypeUtils.checkForNumericExpr(child.dataType, "function average")
 }
 
-case class AverageFunction(expr: Expression, base: AggregateExpression)
-  extends AggregateFunction {
+case class AverageFunction(expr: Expression, base: AggregateExpression1)
+  extends AggregateFunction1 {
 
   def this() = this(null, null) // Required for serialization.
 
@@ -474,7 +476,7 @@ case class AverageFunction(expr: Expression, base: 
AggregateExpression)
   }
 }
 
-case class Sum(child: Expression) extends UnaryExpression with 
PartialAggregate {
+case class Sum(child: Expression) extends UnaryExpression with 
PartialAggregate1 {
 
   override def nullable: Boolean = true
 
@@ -509,7 +511,7 @@ case class Sum(child: Expression) extends UnaryExpression 
with PartialAggregate
     TypeUtils.checkForNumericExpr(child.dataType, "function sum")
 }
 
-case class SumFunction(expr: Expression, base: AggregateExpression) extends 
AggregateFunction {
+case class SumFunction(expr: Expression, base: AggregateExpression1) extends 
AggregateFunction1 {
   def this() = this(null, null) // Required for serialization.
 
   private val calcType =
@@ -554,7 +556,7 @@ case class SumFunction(expr: Expression, base: 
AggregateExpression) extends Aggr
  *          <-- null         <-- no data
  * null     <-- null         <-- no data
  */
-case class CombineSum(child: Expression) extends AggregateExpression {
+case class CombineSum(child: Expression) extends AggregateExpression1 {
   def this() = this(null)
 
   override def children: Seq[Expression] = child :: Nil
@@ -564,8 +566,8 @@ case class CombineSum(child: Expression) extends 
AggregateExpression {
   override def newInstance(): CombineSumFunction = new 
CombineSumFunction(child, this)
 }
 
-case class CombineSumFunction(expr: Expression, base: AggregateExpression)
-  extends AggregateFunction {
+case class CombineSumFunction(expr: Expression, base: AggregateExpression1)
+  extends AggregateFunction1 {
 
   def this() = this(null, null) // Required for serialization.
 
@@ -601,7 +603,7 @@ case class CombineSumFunction(expr: Expression, base: 
AggregateExpression)
   }
 }
 
-case class SumDistinct(child: Expression) extends UnaryExpression with 
PartialAggregate {
+case class SumDistinct(child: Expression) extends UnaryExpression with 
PartialAggregate1 {
 
   def this() = this(null)
   override def nullable: Boolean = true
@@ -627,8 +629,8 @@ case class SumDistinct(child: Expression) extends 
UnaryExpression with PartialAg
     TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct")
 }
 
-case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
-  extends AggregateFunction {
+case class SumDistinctFunction(expr: Expression, base: AggregateExpression1)
+  extends AggregateFunction1 {
 
   def this() = this(null, null) // Required for serialization.
 
@@ -653,7 +655,7 @@ case class SumDistinctFunction(expr: Expression, base: 
AggregateExpression)
   }
 }
 
-case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends 
AggregateExpression {
+case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends 
AggregateExpression1 {
   def this() = this(null, null)
 
   override def children: Seq[Expression] = inputSet :: Nil
@@ -667,8 +669,8 @@ case class CombineSetsAndSum(inputSet: Expression, base: 
Expression) extends Agg
 
 case class CombineSetsAndSumFunction(
     @transient inputSet: Expression,
-    @transient base: AggregateExpression)
-  extends AggregateFunction {
+    @transient base: AggregateExpression1)
+  extends AggregateFunction1 {
 
   def this() = this(null, null) // Required for serialization.
 
@@ -695,7 +697,7 @@ case class CombineSetsAndSumFunction(
   }
 }
 
-case class First(child: Expression) extends UnaryExpression with 
PartialAggregate {
+case class First(child: Expression) extends UnaryExpression with 
PartialAggregate1 {
   override def nullable: Boolean = true
   override def dataType: DataType = child.dataType
   override def toString: String = s"FIRST($child)"
@@ -709,7 +711,7 @@ case class First(child: Expression) extends UnaryExpression 
with PartialAggregat
   override def newInstance(): FirstFunction = new FirstFunction(child, this)
 }
 
-case class FirstFunction(expr: Expression, base: AggregateExpression) extends 
AggregateFunction {
+case class FirstFunction(expr: Expression, base: AggregateExpression1) extends 
AggregateFunction1 {
   def this() = this(null, null) // Required for serialization.
 
   var result: Any = null
@@ -723,7 +725,7 @@ case class FirstFunction(expr: Expression, base: 
AggregateExpression) extends Ag
   override def eval(input: InternalRow): Any = result
 }
 
-case class Last(child: Expression) extends UnaryExpression with 
PartialAggregate {
+case class Last(child: Expression) extends UnaryExpression with 
PartialAggregate1 {
   override def references: AttributeSet = child.references
   override def nullable: Boolean = true
   override def dataType: DataType = child.dataType
@@ -738,7 +740,7 @@ case class Last(child: Expression) extends UnaryExpression 
with PartialAggregate
   override def newInstance(): LastFunction = new LastFunction(child, this)
 }
 
-case class LastFunction(expr: Expression, base: AggregateExpression) extends 
AggregateFunction {
+case class LastFunction(expr: Expression, base: AggregateExpression1) extends 
AggregateFunction1 {
   def this() = this(null, null) // Required for serialization.
 
   var result: Any = null

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 03b4b3c..d838268 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.codegen
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 
 import scala.collection.mutable.ArrayBuffer
 
@@ -38,15 +39,17 @@ object GenerateMutableProjection extends 
CodeGenerator[Seq[Expression], () => Mu
 
   protected def create(expressions: Seq[Expression]): (() => 
MutableProjection) = {
     val ctx = newCodeGenContext()
-    val projectionCode = expressions.zipWithIndex.map { case (e, i) =>
-      val evaluationCode = e.gen(ctx)
-      evaluationCode.code +
-        s"""
-          if(${evaluationCode.isNull})
-            mutableRow.setNullAt($i);
-          else
-            ${ctx.setColumn("mutableRow", e.dataType, i, 
evaluationCode.primitive)};
-        """
+    val projectionCode = expressions.zipWithIndex.map {
+      case (NoOp, _) => ""
+      case (e, i) =>
+        val evaluationCode = e.gen(ctx)
+        evaluationCode.code +
+          s"""
+            if(${evaluationCode.isNull})
+              mutableRow.setNullAt($i);
+            else
+              ${ctx.setColumn("mutableRow", e.dataType, i, 
evaluationCode.primitive)};
+          """
     }
     // collect projections into blocks as function has 64kb codesize limit in 
JVM
     val projectionBlocks = new ArrayBuffer[String]()
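
The effect of the new `case (NoOp, _) => ""` branch is that no code is generated for that ordinal, so the corresponding slot of the target row keeps its previous value. A minimal standalone sketch of the same idea, written as an interpreter rather than a code generator: `ToyExpr`, `ToyNoOp`, `ToyAdd` and `project` are hypothetical names and are not part of Catalyst.

```scala
object NoOpProjectionSketch {
  sealed trait ToyExpr
  /** Placeholder: the projection must leave the target slot untouched. */
  case object ToyNoOp extends ToyExpr
  /** Evaluates to input(ordinal) + delta. */
  final case class ToyAdd(ordinal: Int, delta: Int) extends ToyExpr

  /** Writes the i-th expression's value into target(i), skipping placeholders. */
  def project(exprs: Seq[ToyExpr], input: Array[Int], target: Array[Int]): Unit =
    exprs.zipWithIndex.foreach {
      case (ToyNoOp, _) => () // nothing is written for this ordinal
      case (ToyAdd(ordinal, delta), i) => target(i) = input(ordinal) + delta
    }

  def run(): List[Int] = {
    val target = Array(10, 20, 30)
    project(Seq(ToyAdd(0, 1), ToyNoOp, ToyAdd(1, 5)), Array(100, 200), target)
    target.toList
  }
}
```

Here the middle slot keeps the value it had before the projection ran, which is the behavior an aggregate buffer relies on when only some of its fields are updated by a given projection.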

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 179a348..b8e3b0d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -129,10 +129,10 @@ object PartialAggregation {
     case logical.Aggregate(groupingExpressions, aggregateExpressions, child) =>
       // Collect all aggregate expressions.
       val allAggregates =
-        aggregateExpressions.flatMap(_ collect { case a: AggregateExpression 
=> a})
+        aggregateExpressions.flatMap(_ collect { case a: AggregateExpression1 
=> a})
       // Collect all aggregate expressions that can be computed partially.
       val partialAggregates =
-        aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => 
p})
+        aggregateExpressions.flatMap(_ collect { case p: PartialAggregate1 => 
p})
 
       // Only do partial aggregation if supported by all aggregate expressions.
       if (allAggregates.size == partialAggregates.size) {

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 986c315..6aefa9f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.OpenHashSet

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 78c780b..1474b17 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -402,6 +402,9 @@ private[spark] object SQLConf {
     defaultValue = Some(true),
     isPublic = false)
 
+  val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2",
+    defaultValue = Some(true), doc = "<TODO>")
+
   val USE_SQL_SERIALIZER2 = booleanConf(
     "spark.sql.useSerializer2",
     defaultValue = Some(true), isPublic = false)
@@ -473,6 +476,8 @@ private[sql] class SQLConf extends Serializable with 
CatalystConf {
 
   private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED)
 
+  private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2)
+
   private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2)
 
   private[spark] def autoBroadcastJoinThreshold: Int = 
getConf(AUTO_BROADCASTJOIN_THRESHOLD)

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 8b4528b..49bfe74 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -285,6 +285,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
   @transient
   val udf: UDFRegistration = new UDFRegistration(this)
 
+  @transient
+  val udaf: UDAFRegistration = new UDAFRegistration(this)
+
   /**
    * Returns true if the table is currently cached in-memory.
    * @group cachemgmt
@@ -863,6 +866,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       DDLStrategy ::
       TakeOrderedAndProject ::
       HashAggregation ::
+      Aggregation ::
       LeftSemiJoin ::
       HashJoin ::
       InMemoryScans ::

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
new file mode 100644
index 0000000..5b872f5
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.expressions.aggregate.{ScalaUDAF, 
UserDefinedAggregateFunction}
+
+class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
+
+  private val functionRegistry = sqlContext.functionRegistry
+
+  def register(
+      name: String,
+      func: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
+    def builder(children: Seq[Expression]) = ScalaUDAF(children, func)
+    functionRegistry.registerFunction(name, builder)
+    func
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index 3cd60a2..c2c9453 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -68,14 +68,14 @@ case class Aggregate(
    *                        output.
    */
   case class ComputedAggregate(
-      unbound: AggregateExpression,
-      aggregate: AggregateExpression,
+      unbound: AggregateExpression1,
+      aggregate: AggregateExpression1,
       resultAttribute: AttributeReference)
 
   /** A list of aggregates that need to be computed for each group. */
   private[this] val computedAggregates = aggregateExpressions.flatMap { agg =>
     agg.collect {
-      case a: AggregateExpression =>
+      case a: AggregateExpression1 =>
         ComputedAggregate(
           a,
           BindReferences.bindReference(a, child.output),
@@ -87,8 +87,8 @@ case class Aggregate(
   private[this] val computedSchema = computedAggregates.map(_.resultAttribute)
 
   /** Creates a new aggregate buffer for a group. */
-  private[this] def newAggregateBuffer(): Array[AggregateFunction] = {
-    val buffer = new Array[AggregateFunction](computedAggregates.length)
+  private[this] def newAggregateBuffer(): Array[AggregateFunction1] = {
+    val buffer = new Array[AggregateFunction1](computedAggregates.length)
     var i = 0
     while (i < computedAggregates.length) {
       buffer(i) = computedAggregates(i).aggregate.newInstance()
@@ -146,7 +146,7 @@ case class Aggregate(
       }
     } else {
       child.execute().mapPartitions { iter =>
-        val hashTable = new HashMap[InternalRow, Array[AggregateFunction]]
+        val hashTable = new HashMap[InternalRow, Array[AggregateFunction1]]
         val groupingProjection = new 
InterpretedMutableProjection(groupingExpressions, child.output)
 
         var currentRow: InternalRow = null

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 2750053..d31e265 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -247,8 +247,15 @@ private[sql] case class EnsureRequirements(sqlContext: 
SQLContext) extends Rule[
         }
 
         def addSortIfNecessary(child: SparkPlan): SparkPlan = {
-          if (rowOrdering.nonEmpty && child.outputOrdering != rowOrdering) {
-            sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, 
global = false, child)
+
+          if (rowOrdering.nonEmpty) {
+            // If child.outputOrdering is [a, b] and rowOrdering is [a], we do 
not need to sort.
+            val minSize = Seq(rowOrdering.size, child.outputOrdering.size).min
+            if (minSize == 0 || rowOrdering.take(minSize) != 
child.outputOrdering.take(minSize)) {
+              sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, 
global = false, child)
+            } else {
+              child
+            }
           } else {
             child
           }
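
The prefix comparison in the `addSortIfNecessary` hunk above can be examined in isolation. The sketch below models orderings as sequences of column names instead of `SortOrder` expressions; `needsSortAsWritten` mirrors the condition in the hunk, while `needsSortStrict` is a hypothetical alternative added here only for comparison and is not part of this commit.

```scala
object SortPrefixSketch {
  /** Mirrors the hunk: compare only the common-length prefix of the two orderings. */
  def needsSortAsWritten(required: Seq[String], childOrdering: Seq[String]): Boolean =
    if (required.isEmpty) {
      false
    } else {
      val minSize = math.min(required.size, childOrdering.size)
      minSize == 0 || required.take(minSize) != childOrdering.take(minSize)
    }

  /** Hypothetical stricter check: the required ordering must be a prefix of the child's. */
  def needsSortStrict(required: Seq[String], childOrdering: Seq[String]): Boolean =
    required.nonEmpty && !childOrdering.startsWith(required)
}
```

Both variants skip the sort when the child is ordered by `[a, b]` and `[a]` is required. They differ when `[a, b]` is required and the child is ordered only by `[a]`: the condition as written skips the sort, whereas the stricter variant adds one. It may be worth confirming that the former is the intended behavior.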

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index ecde9c5..0e63f2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -69,7 +69,7 @@ case class GeneratedAggregate(
 
   protected override def doExecute(): RDD[InternalRow] = {
     val aggregatesToCompute = aggregateExpressions.flatMap { a =>
-      a.collect { case agg: AggregateExpression => agg}
+      a.collect { case agg: AggregateExpression1 => agg}
     }
 
     // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 8cef7f2..f54aa20 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.{SQLContext, Strategy, execution}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2}
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
@@ -148,7 +149,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
              if canBeCodeGened(
                   allAggregates(partialComputation) ++
                   allAggregates(rewrittenAggregateExpressions)) &&
-               codegenEnabled =>
+               codegenEnabled &&
+               !canBeConvertedToNewAggregation(plan) =>
           execution.GeneratedAggregate(
             partial = false,
             namedGroupingAttributes,
@@ -167,7 +169,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
              rewrittenAggregateExpressions,
              groupingExpressions,
              partialComputation,
-             child) =>
+             child) if !canBeConvertedToNewAggregation(plan) =>
         execution.Aggregate(
           partial = false,
           namedGroupingAttributes,
@@ -181,7 +183,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case _ => Nil
     }
 
-    def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists {
+    def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = {
+      aggregate.Utils.tryConvert(
+        plan,
+        sqlContext.conf.useSqlAggregate2,
+        sqlContext.conf.codegenEnabled).isDefined
+    }
+
+    def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = !aggs.exists {
       case _: CombineSum | _: Sum | _: Count | _: Max | _: Min |  _: CombineSetsAndCount => false
       // The generated set implementation is pretty limited ATM.
       case CollectHashSet(exprs) if exprs.size == 1  &&
@@ -189,10 +198,74 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case _ => true
     }
 
-    def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression] =
-      exprs.flatMap(_.collect { case a: AggregateExpression => a })
+    def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] =
+      exprs.flatMap(_.collect { case a: AggregateExpression1 => a })
   }
 
+  /**
+   * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
+   */
+  object Aggregation extends Strategy {
+    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      case p: logical.Aggregate =>
+        val converted =
+          aggregate.Utils.tryConvert(
+            p,
+            sqlContext.conf.useSqlAggregate2,
+            sqlContext.conf.codegenEnabled)
+        converted match {
+          case None => Nil // Cannot convert to new aggregation code path.
+          case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) =>
+            // Extracts all distinct aggregate expressions from the resultExpressions.
+            val aggregateExpressions = resultExpressions.flatMap { expr =>
+              expr.collect {
+                case agg: AggregateExpression2 => agg
+              }
+            }.toSet.toSeq
+            // For those distinct aggregate expressions, we create a map from the
+            // aggregate function to the corresponding attribute of the function.
+            val aggregateFunctionMap = aggregateExpressions.map { agg =>
+              val aggregateFunction = agg.aggregateFunction
+              (aggregateFunction, agg.isDistinct) ->
+                Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
+            }.toMap
+
+            val (functionsWithDistinct, functionsWithoutDistinct) =
+              aggregateExpressions.partition(_.isDistinct)
+            if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
+              // This is a sanity check. We should not reach here when we have multiple distinct
+              // column sets (aggregate.NewAggregation will not match).
+              sys.error(
+                "Multiple distinct column sets are not supported by the new aggregation " +
+                  "code path.")
+            }
+
+            val aggregateOperator =
+              if (functionsWithDistinct.isEmpty) {
+                aggregate.Utils.planAggregateWithoutDistinct(
+                  groupingExpressions,
+                  aggregateExpressions,
+                  aggregateFunctionMap,
+                  resultExpressions,
+                  planLater(child))
+              } else {
+                aggregate.Utils.planAggregateWithOneDistinct(
+                  groupingExpressions,
+                  functionsWithDistinct,
+                  functionsWithoutDistinct,
+                  aggregateFunctionMap,
+                  resultExpressions,
+                  planLater(child))
+              }
+
+            aggregateOperator
+        }
+
+      case _ => Nil
+    }
+  }
+
+
   object BroadcastNestedLoopJoin extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Join(left, right, joinType, condition) =>
@@ -336,8 +409,21 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.Filter(condition, planLater(child)) :: Nil
       case e @ logical.Expand(_, _, _, child) =>
         execution.Expand(e.projections, e.output, planLater(child)) :: Nil
-      case logical.Aggregate(group, agg, child) =>
-        execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
+      case a @ logical.Aggregate(group, agg, child) => {
+        val useNewAggregation =
+          aggregate.Utils.tryConvert(
+            a,
+            sqlContext.conf.useSqlAggregate2,
+            sqlContext.conf.codegenEnabled).isDefined
+        if (useNewAggregation) {
+          // If this logical.Aggregate can be planned to use new aggregation code path
+          // (i.e. it can be planned by the Strategy Aggregation), we will not use the old
+          // aggregation code path.
+          Nil
+        } else {
+          execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
+        }
+      }
       case logical.Window(projectList, windowExpressions, spec, child) =>
         execution.Window(projectList, windowExpressions, spec, planLater(child)) :: Nil
       case logical.Sample(lb, ub, withReplacement, seed, child) =>
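
The way the old and new strategies in this file hand off to each other can be sketched with plain functions. This is an illustrative model, not Spark's planner: `Query`, `canConvert`, `newPath`, `oldPath`, and `plan` are hypothetical names, and the conversion test is reduced to two flags plus a count of distinct column sets (the real `tryConvert` checks more). Each strategy returns an empty sequence when it declines a plan, mirroring how the old strategies return `Nil` once `tryConvert` succeeds.

```scala
object PlanningSketch {
  // Hypothetical stand-in for a logical aggregate plus the session flags.
  final case class Query(
      useAggregate2: Boolean,
      codegenEnabled: Boolean,
      distinctColumnSets: Int)

  // Assumed conversion rule for this sketch only: both flags are on and the
  // query has at most one distinct column set.
  def canConvert(q: Query): Boolean =
    q.useAggregate2 && q.codegenEnabled && q.distinctColumnSets <= 1

  // New code path: plans only the queries that can be converted.
  def newPath(q: Query): Seq[String] =
    if (!canConvert(q)) Nil
    else if (q.distinctColumnSets == 0) Seq("new aggregate without distinct")
    else Seq("new aggregate with one distinct")

  // Old code path: declines whenever the new path will take the query.
  def oldPath(q: Query): Seq[String] =
    if (canConvert(q)) Nil else Seq("old aggregate")

  // Exactly one of the two strategies yields a plan for any query.
  def plan(q: Query): String = (newPath(q) ++ oldPath(q)).head
}
```

Because the two guards are exact complements, no query is planned twice and none is left unplanned, which is the property the `!canBeConvertedToNewAggregation(plan)` guards aim for.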

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
new file mode 100644
index 0000000..0c90828
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
+
+case class Aggregate2Sort(
+    requiredChildDistributionExpressions: Option[Seq[Expression]],
+    groupingExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[AggregateExpression2],
+    aggregateAttributes: Seq[Attribute],
+    resultExpressions: Seq[NamedExpression],
+    child: SparkPlan)
+  extends UnaryNode {
+
+  override def canProcessUnsafeRows: Boolean = true
+
+  override def references: AttributeSet = {
+    val referencesInResults =
+      AttributeSet(resultExpressions.flatMap(_.references)) -- AttributeSet(aggregateAttributes)
+
+    AttributeSet(
+      groupingExpressions.flatMap(_.references) ++
+      aggregateExpressions.flatMap(_.references) ++
+      referencesInResults)
+  }
+
+  override def requiredChildDistribution: List[Distribution] = {
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
+      case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
+    }
+  }
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
+    // TODO: We should not sort the input rows if they are just in reversed order.
+    groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
+  }
+
+  override def outputOrdering: Seq[SortOrder] = {
+    // It is possible that the child.outputOrdering starts with the required
+    // ordering expressions (e.g. we require [a] as the sort expression and the
+    // child's outputOrdering is [a, b]). We can only guarantee the output rows
+    // are sorted by values of groupingExpressions.
+    groupingExpressions.map(SortOrder(_, Ascending))
+  }
+
+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+    child.execute().mapPartitions { iter =>
+      if (aggregateExpressions.length == 0) {
+        new GroupingIterator(
+          groupingExpressions,
+          resultExpressions,
+          newMutableProjection,
+          child.output,
+          iter)
+      } else {
+        val aggregationIterator: SortAggregationIterator = {
+          aggregateExpressions.map(_.mode).distinct.toList match {
+            case Partial :: Nil =>
+              new PartialSortAggregationIterator(
+                groupingExpressions,
+                aggregateExpressions,
+                newMutableProjection,
+                child.output,
+                iter)
+            case PartialMerge :: Nil =>
+              new PartialMergeSortAggregationIterator(
+                groupingExpressions,
+                aggregateExpressions,
+                newMutableProjection,
+                child.output,
+                iter)
+            case Final :: Nil =>
+              new FinalSortAggregationIterator(
+                groupingExpressions,
+                aggregateExpressions,
+                aggregateAttributes,
+                resultExpressions,
+                newMutableProjection,
+                child.output,
+                iter)
+            case other =>
+              sys.error(
+                s"Could not evaluate ${aggregateExpressions} because we do not support evaluate " +
+                  s"modes $other in this operator.")
+          }
+        }
+
+        aggregationIterator
+      }
+    }
+  }
+}
+
+case class FinalAndCompleteAggregate2Sort(
+    previousGroupingExpressions: Seq[NamedExpression],
+    groupingExpressions: Seq[NamedExpression],
+    finalAggregateExpressions: Seq[AggregateExpression2],
+    finalAggregateAttributes: Seq[Attribute],
+    completeAggregateExpressions: Seq[AggregateExpression2],
+    completeAggregateAttributes: Seq[Attribute],
+    resultExpressions: Seq[NamedExpression],
+    child: SparkPlan)
+  extends UnaryNode {
+  override def references: AttributeSet = {
+    val referencesInResults =
+      AttributeSet(resultExpressions.flatMap(_.references)) --
+        AttributeSet(finalAggregateExpressions) --
+        AttributeSet(completeAggregateExpressions)
+
+    AttributeSet(
+      groupingExpressions.flatMap(_.references) ++
+        finalAggregateExpressions.flatMap(_.references) ++
+        completeAggregateExpressions.flatMap(_.references) ++
+        referencesInResults)
+  }
+
+  override def requiredChildDistribution: List[Distribution] = {
+    if (groupingExpressions.isEmpty) {
+      AllTuples :: Nil
+    } else {
+      ClusteredDistribution(groupingExpressions) :: Nil
+    }
+  }
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+    groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
+
+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+    child.execute().mapPartitions { iter =>
+
+      new FinalAndCompleteSortAggregationIterator(
+        previousGroupingExpressions.length,
+        groupingExpressions,
+        finalAggregateExpressions,
+        finalAggregateAttributes,
+        completeAggregateExpressions,
+        completeAggregateAttributes,
+        resultExpressions,
+        newMutableProjection,
+        child.output,
+        iter)
+    }
+  }
+
+}
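
Both operators in this file require their input to be sorted by the grouping expressions, which lets an iterator finish a group as soon as the key changes. The sketch below illustrates that idea for a simple sum. It is not the `SortAggregationIterator` implementation from this commit; `sumSortedGroups` and its `(String, Long)` row shape are hypothetical simplifications.

```scala
object SortAggSketch {
  // Sums values per key, assuming `rows` is already sorted (or at least
  // clustered) by key. Only one group's running state is held at a time.
  def sumSortedGroups(rows: Iterator[(String, Long)]): Iterator[(String, Long)] =
    new Iterator[(String, Long)] {
      private val input = rows.buffered

      def hasNext: Boolean = input.hasNext

      def next(): (String, Long) = {
        val key = input.head._1
        var sum = 0L
        // Consume rows until the key changes, which marks the end of the group.
        while (input.hasNext && input.head._1 == key) {
          sum += input.next()._2
        }
        (key, sum)
      }
    }
}
```

If the input is not clustered by key, the same key is emitted more than once, which is why the operators declare `requiredChildOrdering`.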

