Repository: spark
Updated Branches:
  refs/heads/master 7f6e3ec79 -> be5dd881f


[SPARK-12913] [SQL] Improve performance of stat functions

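As benchmarked and discussed here:
https://github.com/apache/spark/pull/10786/files#r50038294, thanks to
codegen, a declarative aggregate function can be much faster than the
equivalent imperative one. This patch therefore rewrites the central-moment
aggregates (stddev, variance, skewness, kurtosis), covariance, and corr as
DeclarativeAggregates instead of ImperativeAggregates.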

Author: Davies Liu <[email protected]>

Closes #10960 from davies/stddev.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/be5dd881
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/be5dd881
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/be5dd881

Branch: refs/heads/master
Commit: be5dd881f1eff248224a92d57cfd1309cb3acf38
Parents: 7f6e3ec
Author: Davies Liu <[email protected]>
Authored: Tue Feb 2 11:50:14 2016 -0800
Committer: Davies Liu <[email protected]>
Committed: Tue Feb 2 11:50:14 2016 -0800

----------------------------------------------------------------------
 .../catalyst/analysis/HiveTypeCoercion.scala    |  18 +-
 .../aggregate/CentralMomentAgg.scala            | 285 +++++++++----------
 .../catalyst/expressions/aggregate/Corr.scala   | 208 ++++----------
 .../expressions/aggregate/Covariance.scala      | 205 ++++---------
 .../expressions/aggregate/Kurtosis.scala        |  54 ----
 .../expressions/aggregate/Skewness.scala        |  53 ----
 .../catalyst/expressions/aggregate/Stddev.scala |  81 ------
 .../expressions/aggregate/Variance.scala        |  81 ------
 .../spark/sql/catalyst/expressions/misc.scala   |  18 ++
 .../org/apache/spark/sql/execution/Window.scala |   6 +-
 .../execution/aggregate/TungstenAggregate.scala |   1 -
 .../execution/BenchmarkWholeStageCodegen.scala  |  55 +++-
 .../hive/execution/HiveCompatibilitySuite.scala |   4 +-
 .../hive/execution/AggregationQuerySuite.scala  |  17 +-
 14 files changed, 331 insertions(+), 755 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/be5dd881/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 957ac89..57bdb16 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -347,18 +347,12 @@ object HiveTypeCoercion {
 
       case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
       case Average(e @ StringType()) => Average(Cast(e, DoubleType))
-      case StddevPop(e @ StringType(), mutableAggBufferOffset, 
inputAggBufferOffset) =>
-        StddevPop(Cast(e, DoubleType), mutableAggBufferOffset, 
inputAggBufferOffset)
-      case StddevSamp(e @ StringType(), mutableAggBufferOffset, 
inputAggBufferOffset) =>
-        StddevSamp(Cast(e, DoubleType), mutableAggBufferOffset, 
inputAggBufferOffset)
-      case VariancePop(e @ StringType(), mutableAggBufferOffset, 
inputAggBufferOffset) =>
-        VariancePop(Cast(e, DoubleType), mutableAggBufferOffset, 
inputAggBufferOffset)
-      case VarianceSamp(e @ StringType(), mutableAggBufferOffset, 
inputAggBufferOffset) =>
-        VarianceSamp(Cast(e, DoubleType), mutableAggBufferOffset, 
inputAggBufferOffset)
-      case Skewness(e @ StringType(), mutableAggBufferOffset, 
inputAggBufferOffset) =>
-        Skewness(Cast(e, DoubleType), mutableAggBufferOffset, 
inputAggBufferOffset)
-      case Kurtosis(e @ StringType(), mutableAggBufferOffset, 
inputAggBufferOffset) =>
-        Kurtosis(Cast(e, DoubleType), mutableAggBufferOffset, 
inputAggBufferOffset)
+      case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
+      case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
+      case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType))
+      case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType))
+      case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType))
+      case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType))
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/be5dd881/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index 30f6022..9d2db45 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -17,10 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
 /**
@@ -44,7 +42,7 @@ import org.apache.spark.sql.types._
  *
  * @param child to compute central moments of.
  */
-abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate 
with Serializable {
+abstract class CentralMomentAgg(child: Expression) extends 
DeclarativeAggregate {
 
   /**
    * The central moment order to be computed.
@@ -52,178 +50,161 @@ abstract class CentralMomentAgg(child: Expression) 
extends ImperativeAggregate w
   protected def momentOrder: Int
 
   override def children: Seq[Expression] = Seq(child)
-
   override def nullable: Boolean = true
-
   override def dataType: DataType = DoubleType
+  override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType)
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+  protected val n = AttributeReference("n", DoubleType, nullable = false)()
+  protected val avg = AttributeReference("avg", DoubleType, nullable = false)()
+  protected val m2 = AttributeReference("m2", DoubleType, nullable = false)()
+  protected val m3 = AttributeReference("m3", DoubleType, nullable = false)()
+  protected val m4 = AttributeReference("m4", DoubleType, nullable = false)()
 
-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName")
+  private def trimHigherOrder[T](expressions: Seq[T]) = 
expressions.take(momentOrder + 1)
 
-  override def aggBufferSchema: StructType = 
StructType.fromAttributes(aggBufferAttributes)
+  override val aggBufferAttributes = trimHigherOrder(Seq(n, avg, m2, m3, m4))
 
-  /**
-   * Size of aggregation buffer.
-   */
-  private[this] val bufferSize = 5
+  override val initialValues: Seq[Expression] = Array.fill(momentOrder + 
1)(Literal(0.0))
 
-  override val aggBufferAttributes: Seq[AttributeReference] = 
Seq.tabulate(bufferSize) { i =>
-    AttributeReference(s"M$i", DoubleType)()
+  override val updateExpressions: Seq[Expression] = {
+    val newN = n + Literal(1.0)
+    val delta = child - avg
+    val deltaN = delta / newN
+    val newAvg = avg + deltaN
+    val newM2 = m2 + delta * (delta - deltaN)
+
+    val delta2 = delta * delta
+    val deltaN2 = deltaN * deltaN
+    val newM3 = if (momentOrder >= 3) {
+      m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2)
+    } else {
+      Literal(0.0)
+    }
+    val newM4 = if (momentOrder >= 4) {
+      m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 +
+        delta * (delta * delta2 - deltaN * deltaN2)
+    } else {
+      Literal(0.0)
+    }
+
+    trimHigherOrder(Seq(
+      If(IsNull(child), n, newN),
+      If(IsNull(child), avg, newAvg),
+      If(IsNull(child), m2, newM2),
+      If(IsNull(child), m3, newM3),
+      If(IsNull(child), m4, newM4)
+    ))
   }
 
-  // Note: although this simply copies aggBufferAttributes, this common code 
can not be placed
-  // in the superclass because that will lead to initialization ordering 
issues.
-  override val inputAggBufferAttributes: Seq[AttributeReference] =
-    aggBufferAttributes.map(_.newInstance())
-
-  // buffer offsets
-  private[this] val nOffset = mutableAggBufferOffset
-  private[this] val meanOffset = mutableAggBufferOffset + 1
-  private[this] val secondMomentOffset = mutableAggBufferOffset + 2
-  private[this] val thirdMomentOffset = mutableAggBufferOffset + 3
-  private[this] val fourthMomentOffset = mutableAggBufferOffset + 4
-
-  // frequently used values for online updates
-  private[this] var delta = 0.0
-  private[this] var deltaN = 0.0
-  private[this] var delta2 = 0.0
-  private[this] var deltaN2 = 0.0
-  private[this] var n = 0.0
-  private[this] var mean = 0.0
-  private[this] var m2 = 0.0
-  private[this] var m3 = 0.0
-  private[this] var m4 = 0.0
+  override val mergeExpressions: Seq[Expression] = {
 
-  /**
-   * Initialize all moments to zero.
-   */
-  override def initialize(buffer: MutableRow): Unit = {
-    for (aggIndex <- 0 until bufferSize) {
-      buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0)
+    val n1 = n.left
+    val n2 = n.right
+    val newN = n1 + n2
+    val delta = avg.right - avg.left
+    val deltaN = If(newN === Literal(0.0), Literal(0.0), delta / newN)
+    val newAvg = avg.left + deltaN * n2
+
+    // higher order moments computed according to:
+    // 
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics
+    val newM2 = m2.left + m2.right + delta * deltaN * n1 * n2
+    // `m3.right` is not available if momentOrder < 3
+    val newM3 = if (momentOrder >= 3) {
+      m3.left + m3.right + deltaN * deltaN * delta * n1 * n2 * (n1 - n2) +
+        Literal(3.0) * deltaN * (n1 * m2.right - n2 * m2.left)
+    } else {
+      Literal(0.0)
     }
+    // `m4.right` is not available if momentOrder < 4
+    val newM4 = if (momentOrder >= 4) {
+      m4.left + m4.right +
+        deltaN * deltaN * deltaN * delta * n1 * n2 * (n1 * n1 - n1 * n2 + n2 * 
n2) +
+        Literal(6.0) * deltaN * deltaN * (n1 * n1 * m2.right + n2 * n2 * 
m2.left) +
+        Literal(4.0) * deltaN * (n1 * m3.right - n2 * m3.left)
+    } else {
+      Literal(0.0)
+    }
+
+    trimHigherOrder(Seq(newN, newAvg, newM2, newM3, newM4))
   }
+}
 
-  /**
-   * Update the central moments buffer.
-   */
-  override def update(buffer: MutableRow, input: InternalRow): Unit = {
-    val v = Cast(child, DoubleType).eval(input)
-    if (v != null) {
-      val updateValue = v match {
-        case d: Double => d
-      }
-
-      n = buffer.getDouble(nOffset)
-      mean = buffer.getDouble(meanOffset)
-
-      n += 1.0
-      buffer.setDouble(nOffset, n)
-      delta = updateValue - mean
-      deltaN = delta / n
-      mean += deltaN
-      buffer.setDouble(meanOffset, mean)
-
-      if (momentOrder >= 2) {
-        m2 = buffer.getDouble(secondMomentOffset)
-        m2 += delta * (delta - deltaN)
-        buffer.setDouble(secondMomentOffset, m2)
-      }
-
-      if (momentOrder >= 3) {
-        delta2 = delta * delta
-        deltaN2 = deltaN * deltaN
-        m3 = buffer.getDouble(thirdMomentOffset)
-        m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2)
-        buffer.setDouble(thirdMomentOffset, m3)
-      }
-
-      if (momentOrder >= 4) {
-        m4 = buffer.getDouble(fourthMomentOffset)
-        m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 +
-          delta * (delta * delta2 - deltaN * deltaN2)
-        buffer.setDouble(fourthMomentOffset, m4)
-      }
-    }
+// Compute the population standard deviation of a column
+case class StddevPop(child: Expression) extends CentralMomentAgg(child) {
+
+  override protected def momentOrder = 2
+
+  override val evaluateExpression: Expression = {
+    If(n === Literal(0.0), Literal.create(null, DoubleType),
+      Sqrt(m2 / n))
   }
 
-  /**
-   * Merge two central moment buffers.
-   */
-  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
-    val n1 = buffer1.getDouble(nOffset)
-    val n2 = buffer2.getDouble(inputAggBufferOffset)
-    val mean1 = buffer1.getDouble(meanOffset)
-    val mean2 = buffer2.getDouble(inputAggBufferOffset + 1)
+  override def prettyName: String = "stddev_pop"
+}
+
+// Compute the sample standard deviation of a column
+case class StddevSamp(child: Expression) extends CentralMomentAgg(child) {
+
+  override protected def momentOrder = 2
 
-    var secondMoment1 = 0.0
-    var secondMoment2 = 0.0
+  override val evaluateExpression: Expression = {
+    If(n === Literal(0.0), Literal.create(null, DoubleType),
+      If(n === Literal(1.0), Literal(Double.NaN),
+        Sqrt(m2 / (n - Literal(1.0)))))
+  }
 
-    var thirdMoment1 = 0.0
-    var thirdMoment2 = 0.0
+  override def prettyName: String = "stddev_samp"
+}
 
-    var fourthMoment1 = 0.0
-    var fourthMoment2 = 0.0
+// Compute the population variance of a column
+case class VariancePop(child: Expression) extends CentralMomentAgg(child) {
 
-    n = n1 + n2
-    buffer1.setDouble(nOffset, n)
-    delta = mean2 - mean1
-    deltaN = if (n == 0.0) 0.0 else delta / n
-    mean = mean1 + deltaN * n2
-    buffer1.setDouble(mutableAggBufferOffset + 1, mean)
+  override protected def momentOrder = 2
 
-    // higher order moments computed according to:
-    // 
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics
-    if (momentOrder >= 2) {
-      secondMoment1 = buffer1.getDouble(secondMomentOffset)
-      secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2)
-      m2 = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2
-      buffer1.setDouble(secondMomentOffset, m2)
-    }
+  override val evaluateExpression: Expression = {
+    If(n === Literal(0.0), Literal.create(null, DoubleType),
+      m2 / n)
+  }
 
-    if (momentOrder >= 3) {
-      thirdMoment1 = buffer1.getDouble(thirdMomentOffset)
-      thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3)
-      m3 = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 *
-        (n1 - n2) + 3.0 * deltaN * (n1 * secondMoment2 - n2 * secondMoment1)
-      buffer1.setDouble(thirdMomentOffset, m3)
-    }
+  override def prettyName: String = "var_pop"
+}
 
-    if (momentOrder >= 4) {
-      fourthMoment1 = buffer1.getDouble(fourthMomentOffset)
-      fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4)
-      m4 = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * 
n1 *
-        n2 * (n1 * n1 - n1 * n2 + n2 * n2) + deltaN * deltaN * 6.0 *
-        (n1 * n1 * secondMoment2 + n2 * n2 * secondMoment1) +
-        4.0 * deltaN * (n1 * thirdMoment2 - n2 * thirdMoment1)
-      buffer1.setDouble(fourthMomentOffset, m4)
-    }
+// Compute the sample variance of a column
+case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) {
+
+  override protected def momentOrder = 2
+
+  override val evaluateExpression: Expression = {
+    If(n === Literal(0.0), Literal.create(null, DoubleType),
+      If(n === Literal(1.0), Literal(Double.NaN),
+        m2 / (n - Literal(1.0))))
   }
 
-  /**
-   * Compute aggregate statistic from sufficient moments.
-   * @param centralMoments Length `momentOrder + 1` array of central moments 
(un-normalized)
-   *                       needed to compute the aggregate stat.
-   */
-  def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Any
-
-  override final def eval(buffer: InternalRow): Any = {
-    val n = buffer.getDouble(nOffset)
-    val mean = buffer.getDouble(meanOffset)
-    val moments = Array.ofDim[Double](momentOrder + 1)
-    moments(0) = 1.0
-    moments(1) = 0.0
-    if (momentOrder >= 2) {
-      moments(2) = buffer.getDouble(secondMomentOffset)
-    }
-    if (momentOrder >= 3) {
-      moments(3) = buffer.getDouble(thirdMomentOffset)
-    }
-    if (momentOrder >= 4) {
-      moments(4) = buffer.getDouble(fourthMomentOffset)
-    }
+  override def prettyName: String = "var_samp"
+}
+
+case class Skewness(child: Expression) extends CentralMomentAgg(child) {
+
+  override def prettyName: String = "skewness"
+
+  override protected def momentOrder = 3
 
-    getStatistic(n, mean, moments)
+  override val evaluateExpression: Expression = {
+    If(n === Literal(0.0), Literal.create(null, DoubleType),
+      If(m2 === Literal(0.0), Literal(Double.NaN),
+        Sqrt(n) * m3 / Sqrt(m2 * m2 * m2)))
   }
 }
+
+case class Kurtosis(child: Expression) extends CentralMomentAgg(child) {
+
+  override protected def momentOrder = 4
+
+  override val evaluateExpression: Expression = {
+    If(n === Literal(0.0), Literal.create(null, DoubleType),
+      If(m2 === Literal(0.0), Literal(Double.NaN),
+        n * m4 / (m2 * m2) - Literal(3.0)))
+  }
+
+  override def prettyName: String = "kurtosis"
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/be5dd881/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
index d25f333..e6b8214 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
@@ -17,8 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 
@@ -29,165 +28,70 @@ import org.apache.spark.sql.types._
  * Definition of Pearson correlation can be found at
  * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient
  */
-case class Corr(
-    left: Expression,
-    right: Expression,
-    mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0)
-  extends ImperativeAggregate {
-
-  def this(left: Expression, right: Expression) =
-    this(left, right, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
-
-  override def children: Seq[Expression] = Seq(left, right)
+case class Corr(x: Expression, y: Expression) extends DeclarativeAggregate {
 
+  override def children: Seq[Expression] = Seq(x, y)
   override def nullable: Boolean = true
-
   override def dataType: DataType = DoubleType
-
   override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
 
-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (left.dataType.isInstanceOf[DoubleType] && 
right.dataType.isInstanceOf[DoubleType]) {
-      TypeCheckResult.TypeCheckSuccess
-    } else {
-      TypeCheckResult.TypeCheckFailure(
-        s"corr requires that both arguments are double type, " +
-          s"not (${left.dataType}, ${right.dataType}).")
-    }
+  protected val n = AttributeReference("n", DoubleType, nullable = false)()
+  protected val xAvg = AttributeReference("xAvg", DoubleType, nullable = 
false)()
+  protected val yAvg = AttributeReference("yAvg", DoubleType, nullable = 
false)()
+  protected val ck = AttributeReference("ck", DoubleType, nullable = false)()
+  protected val xMk = AttributeReference("xMk", DoubleType, nullable = false)()
+  protected val yMk = AttributeReference("yMk", DoubleType, nullable = false)()
+
+  override val aggBufferAttributes: Seq[AttributeReference] = Seq(n, xAvg, 
yAvg, ck, xMk, yMk)
+
+  override val initialValues: Seq[Expression] = Array.fill(6)(Literal(0.0))
+
+  override val updateExpressions: Seq[Expression] = {
+    val newN = n + Literal(1.0)
+    val dx = x - xAvg
+    val dxN = dx / newN
+    val dy = y - yAvg
+    val dyN = dy / newN
+    val newXAvg = xAvg + dxN
+    val newYAvg = yAvg + dyN
+    val newCk = ck + dx * (y - newYAvg)
+    val newXMk = xMk + dx * (x - newXAvg)
+    val newYMk = yMk + dy * (y - newYAvg)
+
+    val isNull = IsNull(x) || IsNull(y)
+    Seq(
+      If(isNull, n, newN),
+      If(isNull, xAvg, newXAvg),
+      If(isNull, yAvg, newYAvg),
+      If(isNull, ck, newCk),
+      If(isNull, xMk, newXMk),
+      If(isNull, yMk, newYMk)
+    )
   }
 
-  override def aggBufferSchema: StructType = 
StructType.fromAttributes(aggBufferAttributes)
-
-  override def inputAggBufferAttributes: Seq[AttributeReference] = {
-    aggBufferAttributes.map(_.newInstance())
+  override val mergeExpressions: Seq[Expression] = {
+
+    val n1 = n.left
+    val n2 = n.right
+    val newN = n1 + n2
+    val dx = xAvg.right - xAvg.left
+    val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN)
+    val dy = yAvg.right - yAvg.left
+    val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN)
+    val newXAvg = xAvg.left + dxN * n2
+    val newYAvg = yAvg.left + dyN * n2
+    val newCk = ck.left + ck.right + dx * dyN * n1 * n2
+    val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2
+    val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2
+
+    Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk)
   }
 
-  override val aggBufferAttributes: Seq[AttributeReference] = Seq(
-    AttributeReference("xAvg", DoubleType)(),
-    AttributeReference("yAvg", DoubleType)(),
-    AttributeReference("Ck", DoubleType)(),
-    AttributeReference("MkX", DoubleType)(),
-    AttributeReference("MkY", DoubleType)(),
-    AttributeReference("count", LongType)())
-
-  // Local cache of mutableAggBufferOffset(s) that will be used in update and 
merge
-  private[this] val mutableAggBufferOffsetPlus1 = mutableAggBufferOffset + 1
-  private[this] val mutableAggBufferOffsetPlus2 = mutableAggBufferOffset + 2
-  private[this] val mutableAggBufferOffsetPlus3 = mutableAggBufferOffset + 3
-  private[this] val mutableAggBufferOffsetPlus4 = mutableAggBufferOffset + 4
-  private[this] val mutableAggBufferOffsetPlus5 = mutableAggBufferOffset + 5
-
-  // Local cache of inputAggBufferOffset(s) that will be used in update and 
merge
-  private[this] val inputAggBufferOffsetPlus1 = inputAggBufferOffset + 1
-  private[this] val inputAggBufferOffsetPlus2 = inputAggBufferOffset + 2
-  private[this] val inputAggBufferOffsetPlus3 = inputAggBufferOffset + 3
-  private[this] val inputAggBufferOffsetPlus4 = inputAggBufferOffset + 4
-  private[this] val inputAggBufferOffsetPlus5 = inputAggBufferOffset + 5
-
-  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): 
ImperativeAggregate =
-    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
-  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): 
ImperativeAggregate =
-    copy(inputAggBufferOffset = newInputAggBufferOffset)
-
-  override def initialize(buffer: MutableRow): Unit = {
-    buffer.setDouble(mutableAggBufferOffset, 0.0)
-    buffer.setDouble(mutableAggBufferOffsetPlus1, 0.0)
-    buffer.setDouble(mutableAggBufferOffsetPlus2, 0.0)
-    buffer.setDouble(mutableAggBufferOffsetPlus3, 0.0)
-    buffer.setDouble(mutableAggBufferOffsetPlus4, 0.0)
-    buffer.setLong(mutableAggBufferOffsetPlus5, 0L)
+  override val evaluateExpression: Expression = {
+    If(n === Literal(0.0), Literal.create(null, DoubleType),
+      If(n === Literal(1.0), Literal(Double.NaN),
+        ck / Sqrt(xMk * yMk)))
   }
 
-  override def update(buffer: MutableRow, input: InternalRow): Unit = {
-    val leftEval = left.eval(input)
-    val rightEval = right.eval(input)
-
-    if (leftEval != null && rightEval != null) {
-      val x = leftEval.asInstanceOf[Double]
-      val y = rightEval.asInstanceOf[Double]
-
-      var xAvg = buffer.getDouble(mutableAggBufferOffset)
-      var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1)
-      var Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
-      var MkX = buffer.getDouble(mutableAggBufferOffsetPlus3)
-      var MkY = buffer.getDouble(mutableAggBufferOffsetPlus4)
-      var count = buffer.getLong(mutableAggBufferOffsetPlus5)
-
-      val deltaX = x - xAvg
-      val deltaY = y - yAvg
-      count += 1
-      xAvg += deltaX / count
-      yAvg += deltaY / count
-      Ck += deltaX * (y - yAvg)
-      MkX += deltaX * (x - xAvg)
-      MkY += deltaY * (y - yAvg)
-
-      buffer.setDouble(mutableAggBufferOffset, xAvg)
-      buffer.setDouble(mutableAggBufferOffsetPlus1, yAvg)
-      buffer.setDouble(mutableAggBufferOffsetPlus2, Ck)
-      buffer.setDouble(mutableAggBufferOffsetPlus3, MkX)
-      buffer.setDouble(mutableAggBufferOffsetPlus4, MkY)
-      buffer.setLong(mutableAggBufferOffsetPlus5, count)
-    }
-  }
-
-  // Merge counters from other partitions. Formula can be found at:
-  // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
-  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
-    val count2 = buffer2.getLong(inputAggBufferOffsetPlus5)
-
-    // We only go to merge two buffers if there is at least one record 
aggregated in buffer2.
-    // We don't need to check count in buffer1 because if count2 is more than 
zero, totalCount
-    // is more than zero too, then we won't get a divide by zero exception.
-    if (count2 > 0) {
-      var xAvg = buffer1.getDouble(mutableAggBufferOffset)
-      var yAvg = buffer1.getDouble(mutableAggBufferOffsetPlus1)
-      var Ck = buffer1.getDouble(mutableAggBufferOffsetPlus2)
-      var MkX = buffer1.getDouble(mutableAggBufferOffsetPlus3)
-      var MkY = buffer1.getDouble(mutableAggBufferOffsetPlus4)
-      var count = buffer1.getLong(mutableAggBufferOffsetPlus5)
-
-      val xAvg2 = buffer2.getDouble(inputAggBufferOffset)
-      val yAvg2 = buffer2.getDouble(inputAggBufferOffsetPlus1)
-      val Ck2 = buffer2.getDouble(inputAggBufferOffsetPlus2)
-      val MkX2 = buffer2.getDouble(inputAggBufferOffsetPlus3)
-      val MkY2 = buffer2.getDouble(inputAggBufferOffsetPlus4)
-
-      val totalCount = count + count2
-      val deltaX = xAvg - xAvg2
-      val deltaY = yAvg - yAvg2
-      Ck += Ck2 + deltaX * deltaY * count / totalCount * count2
-      xAvg = (xAvg * count + xAvg2 * count2) / totalCount
-      yAvg = (yAvg * count + yAvg2 * count2) / totalCount
-      MkX += MkX2 + deltaX * deltaX * count / totalCount * count2
-      MkY += MkY2 + deltaY * deltaY * count / totalCount * count2
-      count = totalCount
-
-      buffer1.setDouble(mutableAggBufferOffset, xAvg)
-      buffer1.setDouble(mutableAggBufferOffsetPlus1, yAvg)
-      buffer1.setDouble(mutableAggBufferOffsetPlus2, Ck)
-      buffer1.setDouble(mutableAggBufferOffsetPlus3, MkX)
-      buffer1.setDouble(mutableAggBufferOffsetPlus4, MkY)
-      buffer1.setLong(mutableAggBufferOffsetPlus5, count)
-    }
-  }
-
-  override def eval(buffer: InternalRow): Any = {
-    val count = buffer.getLong(mutableAggBufferOffsetPlus5)
-    if (count > 0) {
-      val Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
-      val MkX = buffer.getDouble(mutableAggBufferOffsetPlus3)
-      val MkY = buffer.getDouble(mutableAggBufferOffsetPlus4)
-      val corr = Ck / math.sqrt(MkX * MkY)
-      if (corr.isNaN) {
-        null
-      } else {
-        corr
-      }
-    } else {
-      null
-    }
-  }
+  override def prettyName: String = "corr"
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/be5dd881/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
index f53b01b..c175a8c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
@@ -17,182 +17,79 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
 /**
  * Compute the covariance between two expressions.
  * When applied on empty data (i.e., count is zero), it returns NULL.
- *
  */
-abstract class Covariance(left: Expression, right: Expression) extends 
ImperativeAggregate
-    with Serializable {
-  override def children: Seq[Expression] = Seq(left, right)
+abstract class Covariance(x: Expression, y: Expression) extends 
DeclarativeAggregate {
 
+  override def children: Seq[Expression] = Seq(x, y)
   override def nullable: Boolean = true
-
   override def dataType: DataType = DoubleType
-
   override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
 
-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (left.dataType.isInstanceOf[DoubleType] && 
right.dataType.isInstanceOf[DoubleType]) {
-      TypeCheckResult.TypeCheckSuccess
-    } else {
-      TypeCheckResult.TypeCheckFailure(
-        s"covariance requires that both arguments are double type, " +
-          s"not (${left.dataType}, ${right.dataType}).")
-    }
-  }
-
-  override def aggBufferSchema: StructType = 
StructType.fromAttributes(aggBufferAttributes)
-
-  override def inputAggBufferAttributes: Seq[AttributeReference] = {
-    aggBufferAttributes.map(_.newInstance())
-  }
-
-  override val aggBufferAttributes: Seq[AttributeReference] = Seq(
-    AttributeReference("xAvg", DoubleType)(),
-    AttributeReference("yAvg", DoubleType)(),
-    AttributeReference("Ck", DoubleType)(),
-    AttributeReference("count", LongType)())
-
-  // Local cache of mutableAggBufferOffset(s) that will be used in update and 
merge
-  val xAvgOffset = mutableAggBufferOffset
-  val yAvgOffset = mutableAggBufferOffset + 1
-  val CkOffset = mutableAggBufferOffset + 2
-  val countOffset = mutableAggBufferOffset + 3
-
-  // Local cache of inputAggBufferOffset(s) that will be used in update and 
merge
-  val inputXAvgOffset = inputAggBufferOffset
-  val inputYAvgOffset = inputAggBufferOffset + 1
-  val inputCkOffset = inputAggBufferOffset + 2
-  val inputCountOffset = inputAggBufferOffset + 3
-
-  override def initialize(buffer: MutableRow): Unit = {
-    buffer.setDouble(xAvgOffset, 0.0)
-    buffer.setDouble(yAvgOffset, 0.0)
-    buffer.setDouble(CkOffset, 0.0)
-    buffer.setLong(countOffset, 0L)
-  }
-
-  override def update(buffer: MutableRow, input: InternalRow): Unit = {
-    val leftEval = left.eval(input)
-    val rightEval = right.eval(input)
-
-    if (leftEval != null && rightEval != null) {
-      val x = leftEval.asInstanceOf[Double]
-      val y = rightEval.asInstanceOf[Double]
-
-      var xAvg = buffer.getDouble(xAvgOffset)
-      var yAvg = buffer.getDouble(yAvgOffset)
-      var Ck = buffer.getDouble(CkOffset)
-      var count = buffer.getLong(countOffset)
-
-      val deltaX = x - xAvg
-      val deltaY = y - yAvg
-      count += 1
-      xAvg += deltaX / count
-      yAvg += deltaY / count
-      Ck += deltaX * (y - yAvg)
-
-      buffer.setDouble(xAvgOffset, xAvg)
-      buffer.setDouble(yAvgOffset, yAvg)
-      buffer.setDouble(CkOffset, Ck)
-      buffer.setLong(countOffset, count)
-    }
+  protected val n = AttributeReference("n", DoubleType, nullable = false)()
+  protected val xAvg = AttributeReference("xAvg", DoubleType, nullable = 
false)()
+  protected val yAvg = AttributeReference("yAvg", DoubleType, nullable = 
false)()
+  protected val ck = AttributeReference("ck", DoubleType, nullable = false)()
+
+  override val aggBufferAttributes: Seq[AttributeReference] = Seq(n, xAvg, 
yAvg, ck)
+
+  override val initialValues: Seq[Expression] = Array.fill(4)(Literal(0.0))
+
+  override lazy val updateExpressions: Seq[Expression] = {
+    val newN = n + Literal(1.0)
+    val dx = x - xAvg
+    val dy = y - yAvg
+    val dyN = dy / newN
+    val newXAvg = xAvg + dx / newN
+    val newYAvg = yAvg + dyN
+    val newCk = ck + dx * (y - newYAvg)
+
+    val isNull = IsNull(x) || IsNull(y)
+    Seq(
+      If(isNull, n, newN),
+      If(isNull, xAvg, newXAvg),
+      If(isNull, yAvg, newYAvg),
+      If(isNull, ck, newCk)
+    )
   }
 
-  // Merge counters from other partitions. Formula can be found at:
-  // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
-  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
-    val count2 = buffer2.getLong(inputCountOffset)
-
-    // We only go to merge two buffers if there is at least one record 
aggregated in buffer2.
-    // We don't need to check count in buffer1 because if count2 is more than 
zero, totalCount
-    // is more than zero too, then we won't get a divide by zero exception.
-    if (count2 > 0) {
-      var xAvg = buffer1.getDouble(xAvgOffset)
-      var yAvg = buffer1.getDouble(yAvgOffset)
-      var Ck = buffer1.getDouble(CkOffset)
-      var count = buffer1.getLong(countOffset)
+  override val mergeExpressions: Seq[Expression] = {
 
-      val xAvg2 = buffer2.getDouble(inputXAvgOffset)
-      val yAvg2 = buffer2.getDouble(inputYAvgOffset)
-      val Ck2 = buffer2.getDouble(inputCkOffset)
+    val n1 = n.left
+    val n2 = n.right
+    val newN = n1 + n2
+    val dx = xAvg.right - xAvg.left
+    val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN)
+    val dy = yAvg.right - yAvg.left
+    val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN)
+    val newXAvg = xAvg.left + dxN * n2
+    val newYAvg = yAvg.left + dyN * n2
+    val newCk = ck.left + ck.right + dx * dyN * n1 * n2
 
-      val totalCount = count + count2
-      val deltaX = xAvg - xAvg2
-      val deltaY = yAvg - yAvg2
-      Ck += Ck2 + deltaX * deltaY * count / totalCount * count2
-      xAvg = (xAvg * count + xAvg2 * count2) / totalCount
-      yAvg = (yAvg * count + yAvg2 * count2) / totalCount
-      count = totalCount
-
-      buffer1.setDouble(xAvgOffset, xAvg)
-      buffer1.setDouble(yAvgOffset, yAvg)
-      buffer1.setDouble(CkOffset, Ck)
-      buffer1.setLong(countOffset, count)
-    }
+    Seq(newN, newXAvg, newYAvg, newCk)
   }
 }
 
-case class CovSample(
-    left: Expression,
-    right: Expression,
-    mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0)
-  extends Covariance(left, right) {
-
-  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): 
ImperativeAggregate =
-    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
-  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): 
ImperativeAggregate =
-    copy(inputAggBufferOffset = newInputAggBufferOffset)
-
-  override def eval(buffer: InternalRow): Any = {
-    val count = buffer.getLong(countOffset)
-    if (count > 1) {
-      val Ck = buffer.getDouble(CkOffset)
-      val cov = Ck / (count - 1)
-      if (cov.isNaN) {
-        null
-      } else {
-        cov
-      }
-    } else {
-      null
-    }
+case class CovPopulation(left: Expression, right: Expression) extends 
Covariance(left, right) {
+  override val evaluateExpression: Expression = {
+    If(n === Literal(0.0), Literal.create(null, DoubleType),
+      ck / n)
   }
+  override def prettyName: String = "covar_pop"
 }
 
-case class CovPopulation(
-    left: Expression,
-    right: Expression,
-    mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0)
-  extends Covariance(left, right) {
-
-  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): 
ImperativeAggregate =
-    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
-  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): 
ImperativeAggregate =
-    copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-  override def eval(buffer: InternalRow): Any = {
-    val count = buffer.getLong(countOffset)
-    if (count > 0) {
-      val Ck = buffer.getDouble(CkOffset)
-      if (Ck.isNaN) {
-        null
-      } else {
-        Ck / count
-      }
-    } else {
-      null
-    }
+case class CovSample(left: Expression, right: Expression) extends 
Covariance(left, right) {
+  override val evaluateExpression: Expression = {
+    If(n === Literal(0.0), Literal.create(null, DoubleType),
+      If(n === Literal(1.0), Literal(Double.NaN),
+        ck / (n - Literal(1.0))))
   }
+  override def prettyName: String = "covar_samp"
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/be5dd881/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
deleted file mode 100644
index c2bf2cb..0000000
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions.aggregate
-
-import org.apache.spark.sql.catalyst.expressions._
-
-case class Kurtosis(child: Expression,
-    mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0)
-  extends CentralMomentAgg(child) {
-
-  def this(child: Expression) = this(child, mutableAggBufferOffset = 0, 
inputAggBufferOffset = 0)
-
-  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): 
ImperativeAggregate =
-    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
-  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): 
ImperativeAggregate =
-    copy(inputAggBufferOffset = newInputAggBufferOffset)
-
-  override def prettyName: String = "kurtosis"
-
-  override protected val momentOrder = 4
-
-  // NOTE: this is the formula for excess kurtosis, which is default for R and 
SciPy
-  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): 
Any = {
-    require(moments.length == momentOrder + 1,
-      s"$prettyName requires ${momentOrder + 1} central moments, received: 
${moments.length}")
-    val m2 = moments(2)
-    val m4 = moments(4)
-
-    if (n == 0.0) {
-      null
-    } else if (m2 == 0.0) {
-      Double.NaN
-    } else {
-      n * m4 / (m2 * m2) - 3.0
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/be5dd881/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
deleted file mode 100644
index 9411bce..0000000
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions.aggregate
-
-import org.apache.spark.sql.catalyst.expressions._
-
-case class Skewness(child: Expression,
-    mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0)
-  extends CentralMomentAgg(child) {
-
-  def this(child: Expression) = this(child, mutableAggBufferOffset = 0, 
inputAggBufferOffset = 0)
-
-  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): 
ImperativeAggregate =
-    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
-  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): 
ImperativeAggregate =
-    copy(inputAggBufferOffset = newInputAggBufferOffset)
-
-  override def prettyName: String = "skewness"
-
-  override protected val momentOrder = 3
-
-  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): 
Any = {
-    require(moments.length == momentOrder + 1,
-      s"$prettyName requires ${momentOrder + 1} central moments, received: 
${moments.length}")
-    val m2 = moments(2)
-    val m3 = moments(3)
-
-    if (n == 0.0) {
-      null
-    } else if (m2 == 0.0) {
-      Double.NaN
-    } else {
-      math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2)
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/be5dd881/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
deleted file mode 100644
index eec79a9..0000000
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions.aggregate
-
-import org.apache.spark.sql.catalyst.expressions._
-
-case class StddevSamp(child: Expression,
-    mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0)
-  extends CentralMomentAgg(child) {
-
-  def this(child: Expression) = this(child, mutableAggBufferOffset = 0, 
inputAggBufferOffset = 0)
-
-  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): 
ImperativeAggregate =
-    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
-  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): 
ImperativeAggregate =
-    copy(inputAggBufferOffset = newInputAggBufferOffset)
-
-  override def prettyName: String = "stddev_samp"
-
-  override protected val momentOrder = 2
-
-  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): 
Any = {
-    require(moments.length == momentOrder + 1,
-      s"$prettyName requires ${momentOrder + 1} central moment, received: 
${moments.length}")
-
-    if (n == 0.0) {
-      null
-    } else if (n == 1.0) {
-      Double.NaN
-    } else {
-      math.sqrt(moments(2) / (n - 1.0))
-    }
-  }
-}
-
-case class StddevPop(
-    child: Expression,
-    mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0)
-  extends CentralMomentAgg(child) {
-
-  def this(child: Expression) = this(child, mutableAggBufferOffset = 0, 
inputAggBufferOffset = 0)
-
-  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): 
ImperativeAggregate =
-    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
-  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): 
ImperativeAggregate =
-    copy(inputAggBufferOffset = newInputAggBufferOffset)
-
-  override def prettyName: String = "stddev_pop"
-
-  override protected val momentOrder = 2
-
-  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): 
Any = {
-    require(moments.length == momentOrder + 1,
-      s"$prettyName requires ${momentOrder + 1} central moment, received: 
${moments.length}")
-
-    if (n == 0.0) {
-      null
-    } else {
-      math.sqrt(moments(2) / n)
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/be5dd881/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
deleted file mode 100644
index cf3a740..0000000
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions.aggregate
-
-import org.apache.spark.sql.catalyst.expressions._
-
-case class VarianceSamp(child: Expression,
-    mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0)
-  extends CentralMomentAgg(child) {
-
-  def this(child: Expression) = this(child, mutableAggBufferOffset = 0, 
inputAggBufferOffset = 0)
-
-  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): 
ImperativeAggregate =
-    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
-  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): 
ImperativeAggregate =
-    copy(inputAggBufferOffset = newInputAggBufferOffset)
-
-  override def prettyName: String = "var_samp"
-
-  override protected val momentOrder = 2
-
-  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): 
Any = {
-    require(moments.length == momentOrder + 1,
-      s"$prettyName requires ${momentOrder + 1} central moment, received: 
${moments.length}")
-
-    if (n == 0.0) {
-      null
-    } else if (n == 1.0) {
-      Double.NaN
-    } else {
-      moments(2) / (n - 1.0)
-    }
-  }
-}
-
-case class VariancePop(
-    child: Expression,
-    mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0)
-  extends CentralMomentAgg(child) {
-
-  def this(child: Expression) = this(child, mutableAggBufferOffset = 0, 
inputAggBufferOffset = 0)
-
-  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): 
ImperativeAggregate =
-    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
-  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): 
ImperativeAggregate =
-    copy(inputAggBufferOffset = newInputAggBufferOffset)
-
-  override def prettyName: String = "var_pop"
-
-  override protected val momentOrder = 2
-
-  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): 
Any = {
-    require(moments.length == momentOrder + 1,
-      s"$prettyName requires ${momentOrder + 1} central moment, received: 
${moments.length}")
-
-    if (n == 0.0) {
-      null
-    } else {
-      moments(2) / n
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/be5dd881/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 36e1fa1..f4ccadd 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -424,3 +424,21 @@ case class Murmur3Hash(children: Seq[Expression], seed: 
Int) extends Expression
     }
   }
 }
+
+/**
+ * Print the result of an expression to stderr (used for debugging codegen).
+ */
+case class PrintToStderr(child: Expression) extends UnaryExpression {
+
+  override def dataType: DataType = child.dataType
+
+  protected override def nullSafeEval(input: Any): Any = input
+
+  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+    nullSafeCodeGen(ctx, ev, c =>
+      s"""
+         | System.err.println("Result of ${child.simpleString} is " + $c);
+         | ${ev.value} = $c;
+       """.stripMargin)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/be5dd881/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 26a7340..84154a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -198,7 +198,8 @@ case class Window(
           functions,
           ordinal,
           child.output,
-          (expressions, schema) => newMutableProjection(expressions, schema))
+          (expressions, schema) =>
+            newMutableProjection(expressions, schema, 
subexpressionEliminationEnabled))
 
         // Create the factory
         val factory = key match {
@@ -210,7 +211,8 @@ case class Window(
                 ordinal,
                 functions,
                 child.output,
-                (expressions, schema) => newMutableProjection(expressions, 
schema),
+                (expressions, schema) =>
+                  newMutableProjection(expressions, schema, 
subexpressionEliminationEnabled),
                 offset)
 
           // Growing Frame.

http://git-wip-us.apache.org/repos/asf/spark/blob/be5dd881/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 57db726..a8a81d6 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -240,7 +240,6 @@ case class TungstenAggregate(
          | ${bufVars(i).value} = ${ev.value};
        """.stripMargin
     }
-
     s"""
        | // do aggregate
        | ${aggVals.map(_.code).mkString("\n")}

http://git-wip-us.apache.org/repos/asf/spark/blob/be5dd881/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 2f09c8a..1ccf0e3 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -59,6 +59,55 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     benchmark.run()
   }
 
+  def testStatFunctions(values: Int): Unit = {
+
+    val benchmark = new Benchmark("stat functions", values)
+
+    benchmark.addCase("stddev w/o codegen") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
+      sqlContext.range(values).groupBy().agg("id" -> "stddev").collect()
+    }
+
+    benchmark.addCase("stddev w codegen") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
+      sqlContext.range(values).groupBy().agg("id" -> "stddev").collect()
+    }
+
+    benchmark.addCase("kurtosis w/o codegen") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
+      sqlContext.range(values).groupBy().agg("id" -> "kurtosis").collect()
+    }
+
+    benchmark.addCase("kurtosis w codegen") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
+      sqlContext.range(values).groupBy().agg("id" -> "kurtosis").collect()
+    }
+
+
+    /**
+      Using ImperativeAggregate (as implemented in Spark 1.6):
+
+      Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+      stddev:                            Avg Time(ms)    Avg Rate(M/s)  
Relative Rate
+      
-------------------------------------------------------------------------------
+      stddev w/o codegen                      2019.04            10.39         
1.00 X
+      stddev w codegen                        2097.29            10.00         
0.96 X
+      kurtosis w/o codegen                    2108.99             9.94         
0.96 X
+      kurtosis w codegen                      2090.69            10.03         
0.97 X
+
+      Using DeclarativeAggregate:
+
+      Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+      stddev:                            Avg Time(ms)    Avg Rate(M/s)  
Relative Rate
+      
-------------------------------------------------------------------------------
+      stddev w/o codegen                       989.22            21.20         
1.00 X
+      stddev w codegen                         352.35            59.52         
2.81 X
+      kurtosis w/o codegen                    3636.91             5.77         
0.27 X
+      kurtosis w codegen                       369.25            56.79         
2.68 X
+      */
+    benchmark.run()
+  }
+
   def testAggregateWithKey(values: Int): Unit = {
     val benchmark = new Benchmark("Aggregate with keys", values)
 
@@ -147,8 +196,10 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     benchmark.run()
   }
 
-  test("benchmark") {
-    // testWholeStage(1024 * 1024 * 200)
+  // These benchmarks are skipped in normal builds
+  ignore("benchmark") {
+    // testWholeStage(200 << 20)
+    // testStatFunctions(20 << 20)
     // testAggregateWithKey(20 << 20)
     // testBytesToBytesMap(1024 * 1024 * 50)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/be5dd881/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
 
b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 554d47d..61b73fa 100644
--- 
a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ 
b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -325,6 +325,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with 
BeforeAndAfter {
     "drop_partitions_ignore_protection",
     "protectmode",
 
+    // Hive returns null rather than NaN when n = 1
+    "udaf_covar_samp",
+
     // Spark parser treats numerical literals differently: it creates decimals 
instead of doubles.
     "udf_abs",
     "udf_format_number",
@@ -881,7 +884,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with 
BeforeAndAfter {
     "type_widening",
     "udaf_collect_set",
     "udaf_covar_pop",
-    "udaf_covar_samp",
     "udaf_histogram_numeric",
     "udf2",
     "udf5",

http://git-wip-us.apache.org/repos/asf/spark/blob/be5dd881/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 7a9ed1e..caf1db9 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -798,7 +798,7 @@ abstract class AggregationQuerySuite extends QueryTest with 
SQLTestUtils with Te
         """
           |SELECT corr(b, c) FROM covar_tab WHERE a = 3
         """.stripMargin),
-      Row(null) :: Nil)
+      Row(Double.NaN) :: Nil)
 
     checkAnswer(
       sqlContext.sql(
@@ -807,10 +807,10 @@ abstract class AggregationQuerySuite extends QueryTest 
with SQLTestUtils with Te
         """.stripMargin),
       Row(1, null) ::
       Row(2, null) ::
-      Row(3, null) ::
-      Row(4, null) ::
-      Row(5, null) ::
-      Row(6, null) :: Nil)
+      Row(3, Double.NaN) ::
+      Row(4, Double.NaN) ::
+      Row(5, Double.NaN) ::
+      Row(6, Double.NaN) :: Nil)
 
     val corr7 = sqlContext.sql("SELECT corr(b, c) FROM 
covar_tab").collect()(0).getDouble(0)
     assert(math.abs(corr7 - 0.6633880657639323) < 1e-12)
@@ -841,11 +841,8 @@ abstract class AggregationQuerySuite extends QueryTest 
with SQLTestUtils with Te
 
     // one row test
     val df3 = Seq.tabulate(1)(x => (1 * x, x * x * x - 2)).toDF("a", "b")
-    val cov_samp3 = df3.groupBy().agg(covar_samp("a", "b")).collect()(0).get(0)
-    assert(cov_samp3 == null)
-
-    val cov_pop3 = df3.groupBy().agg(covar_pop("a", 
"b")).collect()(0).getDouble(0)
-    assert(cov_pop3 == 0.0)
+    checkAnswer(df3.groupBy().agg(covar_samp("a", "b")), Row(Double.NaN))
+    checkAnswer(df3.groupBy().agg(covar_pop("a", "b")), Row(0.0))
   }
 
   test("no aggregation function (SPARK-11486)") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to