Github user sethah commented on a diff in the pull request:
https://github.com/apache/spark/pull/9003#discussion_r42453520
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
---
@@ -842,3 +699,302 @@ object HyperLogLogPlusPlus {
)
// scalastyle:on
}
+
+abstract class CentralMomentAgg(child: Expression) extends
ImperativeAggregate with Serializable {
+
+ /**
+ * The maximum central moment order to be computed
+ */
+ protected def maxMoment: Int
+
+ override def children: Seq[Expression] = Seq(child)
+
+ override def nullable: Boolean = false
+
+ override def dataType: DataType = DoubleType
+
+ // Expected input data type.
+ // TODO: Right now, we replace old aggregate functions (based on
AggregateExpression1) to the
+ // new version at planning time (after analysis phase). For now,
NullType is added at here
+ // to make it resolved when we have cases like `select avg(null)`.
+ // We can use our analyzer to cast NullType to the default data type of
the NumericType once
+ // we remove the old aggregate functions. Then, we will not need
NullType at here.
+ override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(NumericType, NullType))
+
+ override def aggBufferSchema: StructType =
StructType.fromAttributes(aggBufferAttributes)
+
+ /**
+ * The number of central moments to store in the buffer
+ */
+ private[this] val numMoments = 5
+
+ override val aggBufferAttributes: Seq[AttributeReference] =
Seq.tabulate(numMoments) { i =>
+ AttributeReference(s"M$i", DoubleType)()
+ }
+
+ /**
+ * Initialize all moments to zero
+ */
+ override def initialize(buffer: MutableRow): Unit = {
+ var aggIndex = 0
+ while (aggIndex < numMoments) {
+ buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0)
+ aggIndex += 1
+ }
+ }
+
+ /**
+ * Update the central moments buffer.
+ */
+ override def update(buffer: MutableRow, input: InternalRow): Unit = {
+ val v = Cast(child, DoubleType).eval(input)
+ if (v != null) {
+ val updateValue = v match {
+ case d: Double => d
+ case _ => 0.0
+ }
+ val currentM0 = buffer.getDouble(mutableAggBufferOffset)
+ val currentM1 = buffer.getDouble(mutableAggBufferOffset + 1)
+ val currentM2 = buffer.getDouble(mutableAggBufferOffset + 2)
+ val currentM3 = buffer.getDouble(mutableAggBufferOffset + 3)
+ val currentM4 = buffer.getDouble(mutableAggBufferOffset + 4)
+
+ val updateM0 = currentM0 + 1.0
+ val delta = updateValue - currentM1
+ val deltaN = delta / updateM0
+
+ val updateM1 = currentM1 + delta / updateM0
+ val updateM2 = if (maxMoment >= 2) {
+ currentM2 + delta * (delta - deltaN)
+ } else {
+ 0.0
+ }
+ val delta2 = delta * delta
+ val deltaN2 = deltaN * deltaN
+ val updateM3 = if (maxMoment >= 3) {
+ currentM3 - 3.0 * deltaN * updateM2 + delta * (delta2 - deltaN2)
+ } else {
+ 0.0
+ }
+ val updateM4 = if (maxMoment >= 4) {
+ currentM4 - 4.0 * deltaN * updateM3 - 6.0 * deltaN2 * updateM2 +
+ delta * (delta * delta2 - deltaN * deltaN2)
+ } else {
+ 0.0
+ }
+
+ buffer.setDouble(mutableAggBufferOffset, updateM0)
+ buffer.setDouble(mutableAggBufferOffset + 1, updateM1)
+ buffer.setDouble(mutableAggBufferOffset + 2, updateM2)
+ buffer.setDouble(mutableAggBufferOffset + 3, updateM3)
+ buffer.setDouble(mutableAggBufferOffset + 4, updateM4)
+ }
+ }
+
+ /** Merge two central moment buffers. */
+ override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+ val zeroMoment1 = buffer1.getDouble(mutableAggBufferOffset)
+ val zeroMoment2 = buffer2.getDouble(inputAggBufferOffset)
+ val firstMoment1 = buffer1.getDouble(mutableAggBufferOffset + 1)
+ val firstMoment2 = buffer2.getDouble(inputAggBufferOffset + 1)
+ val secondMoment1 = buffer1.getDouble(mutableAggBufferOffset + 2)
+ val secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2)
+ val thirdMoment1 = buffer1.getDouble(mutableAggBufferOffset + 3)
+ val thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3)
+ val fourthMoment1 = buffer1.getDouble(mutableAggBufferOffset + 4)
+ val fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4)
+
+ val zeroMoment = zeroMoment1 + zeroMoment2
+ val delta = firstMoment2 - firstMoment1
+ val deltaN = delta / zeroMoment
+
+ val firstMoment = firstMoment1 + deltaN * zeroMoment2
+
+ val secondMoment = if (maxMoment >= 2) {
+ secondMoment1 + secondMoment2 + delta * deltaN * zeroMoment1 *
zeroMoment2
+ } else {
+ 0.0
+ }
+
+ val thirdMoment = if (maxMoment >= 3) {
+ thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * zeroMoment1
* zeroMoment2 *
+ (zeroMoment1 - zeroMoment2) + 3.0 * deltaN *
+ (zeroMoment1 * secondMoment2 - zeroMoment2 * secondMoment1)
+ } else {
+ 0.0
+ }
+
+ val fourthMoment = if (maxMoment >= 4) {
+ fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta *
zeroMoment1 *
+ zeroMoment2 * (zeroMoment1 * zeroMoment1 - zeroMoment1 *
zeroMoment2 +
+ zeroMoment2 * zeroMoment2) + deltaN * deltaN * 6.0 *
+ (zeroMoment1 * zeroMoment1 * secondMoment2 + zeroMoment2 *
zeroMoment2 * secondMoment1) +
+ 4.0 * deltaN * (zeroMoment1 * thirdMoment2 - zeroMoment2 *
thirdMoment1)
+ } else {
+ 0.0
+ }
+
+ buffer1.setDouble(mutableAggBufferOffset, zeroMoment)
+ buffer1.setDouble(mutableAggBufferOffset + 1, firstMoment)
+ buffer1.setDouble(mutableAggBufferOffset + 2, secondMoment)
+ buffer1.setDouble(mutableAggBufferOffset + 3, thirdMoment)
+ buffer1.setDouble(mutableAggBufferOffset + 4, fourthMoment)
+ }
+}
+
+case class Stddev(child: Expression) extends CentralMomentAgg(child) {
+
+ override def prettyName: String = "stddev"
+
+ protected val maxMoment = 2
+
+ def eval(buffer: InternalRow): Any = {
+ // stddev = sqrt(M2 / (M0 - 1))
+ val M0 = buffer.getDouble(mutableAggBufferOffset)
+ val M2 = buffer.getDouble(mutableAggBufferOffset + 2)
+
+ if (M0 == 0.0) {
+ 0.0
--- End diff --
R and SciPy differ on this: R outputs NaN in divide-by-zero situations,
whereas SciPy outputs 0. I will assume we stick with the R implementation.
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes to enable it, or if the feature is enabled but not working,
please contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]