Github user sethah commented on a diff in the pull request:
https://github.com/apache/spark/pull/9003#discussion_r42923519
--- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala ---
@@ -930,3 +930,332 @@ object HyperLogLogPlusPlus {
)
// scalastyle:on
}
+
+/**
+ * A central moment is the expected value of a specified power of the deviation of a random
+ * variable from the mean. Central moments are often used to characterize the properties of
+ * the shape of a distribution.
+ *
+ * This class implements online, one-pass algorithms for computing the central moments of a
+ * set of points.
+ *
+ * Behavior:
+ * - null values are ignored
+ * - returns `Double.NaN` when the column contains `Double.NaN` values
+ *
+ * References:
+ * - Xiangrui Meng. "Simpler Online Updates for Arbitrary-Order Central Moments."
+ *   2015. http://arxiv.org/abs/1510.04923
+ *
+ * @see [[https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+ * Algorithms for calculating variance (Wikipedia)]]
+ *
+ * @param child the expression to compute central moments of.
+ */
+abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable {
+
+ /**
+ * The central moment order to be computed.
+ */
+ protected def momentOrder: Int
+
+ override def children: Seq[Expression] = Seq(child)
+
+ override def nullable: Boolean = false
+
+ override def dataType: DataType = DoubleType
+
+ // Expected input data type.
+ // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) with the
+ // new version at planning time (after the analysis phase). For now, NullType is added here so
+ // that cases like `select avg(null)` resolve. Once we remove the old aggregate functions, the
+ // analyzer can cast NullType to the default data type of the NumericType, and NullType will no
+ // longer be needed here.
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
+
+ override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+
+ /**
+ * Size of aggregation buffer.
+ */
+ private[this] val bufferSize = 5
+
+ override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(bufferSize) { i =>
+ AttributeReference(s"M$i", DoubleType)()
+ }
+
+ // Note: although this simply copies aggBufferAttributes, this common code cannot be placed
+ // in the superclass because that will lead to initialization ordering issues.
+ override val inputAggBufferAttributes: Seq[AttributeReference] =
+ aggBufferAttributes.map(_.newInstance())
+
+ // buffer offsets
+ private[this] val nOffset = mutableAggBufferOffset
+ private[this] val meanOffset = mutableAggBufferOffset + 1
+ private[this] val secondMomentOffset = mutableAggBufferOffset + 2
+ private[this] val thirdMomentOffset = mutableAggBufferOffset + 3
+ private[this] val fourthMomentOffset = mutableAggBufferOffset + 4
+
+ // frequently used values for online updates
+ private[this] var delta = 0.0
+ private[this] var deltaN = 0.0
+ private[this] var delta2 = 0.0
+ private[this] var deltaN2 = 0.0
+ private[this] var n = 0.0
+ private[this] var mean = 0.0
+ private[this] var m2 = 0.0
+ private[this] var m3 = 0.0
+ private[this] var m4 = 0.0
+
+ /**
+ * Initialize all moments to zero.
+ */
+ override def initialize(buffer: MutableRow): Unit = {
+ for (aggIndex <- 0 until bufferSize) {
+ buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0)
+ }
+ }
+
+ /**
+ * Update the central moments buffer.
+ */
+ override def update(buffer: MutableRow, input: InternalRow): Unit = {
+ val v = Cast(child, DoubleType).eval(input)
+ if (v != null) {
+ val updateValue = v match {
+ case d: Double => d
+ }
+
+ n = buffer.getDouble(nOffset)
+ mean = buffer.getDouble(meanOffset)
+
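+ // increment the count and move the running mean toward x by delta / n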
+ n += 1.0
+ buffer.setDouble(nOffset, n)
+ delta = updateValue - mean
+ deltaN = delta / n
+ mean += deltaN
+ buffer.setDouble(meanOffset, mean)
+
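+ // second central moment: delta * (delta - deltaN) == (x - old mean) * (x - new mean)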
+ if (momentOrder >= 2) {
+ m2 = buffer.getDouble(secondMomentOffset)
+ m2 += delta * (delta - deltaN)
+ buffer.setDouble(secondMomentOffset, m2)
+ }
+
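+ // third central moment, written in terms of the already-updated m2 (Meng 2015)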
+ if (momentOrder >= 3) {
+ delta2 = delta * delta
+ deltaN2 = deltaN * deltaN
+ m3 = buffer.getDouble(thirdMomentOffset)
+ m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2)
+ buffer.setDouble(thirdMomentOffset, m3)
+ }
+
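+ // fourth central moment, written in terms of the already-updated m2 and m3 (Meng 2015)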
+ if (momentOrder >= 4) {
+ m4 = buffer.getDouble(fourthMomentOffset)
+ m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 +
+ delta * (delta * delta2 - deltaN * deltaN2)
+ buffer.setDouble(fourthMomentOffset, m4)
+ }
+ }
+ }
+
+ /**
+ * Merge two central moment buffers.
+ */
+ override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+ val n1 = buffer1.getDouble(nOffset)
+ val n2 = buffer2.getDouble(inputAggBufferOffset)
+ val mean1 = buffer1.getDouble(meanOffset)
+ val mean2 = buffer2.getDouble(inputAggBufferOffset + 1)
+
+ var secondMoment1 = 0.0
+ var secondMoment2 = 0.0
+
+ var thirdMoment1 = 0.0
+ var thirdMoment2 = 0.0
+
+ var fourthMoment1 = 0.0
+ var fourthMoment2 = 0.0
+
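+ // pairwise merge: combine the counts, then use delta = mean2 - mean1 (with deltaN guarded
+ // against an empty merged buffer) to update the merged mean and higher moments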
+ n = n1 + n2
+ buffer1.setDouble(nOffset, n)
+ delta = mean2 - mean1
+ deltaN = if (n == 0.0) 0.0 else delta / n
--- End diff ---

Removed the divide-by-zero case here, which was causing problems when the number of partitions exceeded the number of samples.
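
For reference, here is a minimal, self-contained Scala sketch of the online update and pairwise merge that the diff above implements. It is not the Spark code: `MomentState`, `update` and `merge` are hypothetical names, and since the diff is truncated before the higher-order merge, the merge step below uses the standard pairwise-merge formulas from the Wikipedia reference cited in the class doc. The update step mirrors the formulas visible in the diff (Meng 2015, using the already-updated lower-order moments), and the merge step shows the zero-count guard discussed in this comment.

    case class MomentState(
        n: Double = 0.0,
        mean: Double = 0.0,
        m2: Double = 0.0,
        m3: Double = 0.0,
        m4: Double = 0.0) {

      // Fold one observation into the running moments (same formulas as update() above).
      def update(x: Double): MomentState = {
        val newN = n + 1.0
        val delta = x - mean
        val deltaN = delta / newN
        val newMean = mean + deltaN
        // m2 uses the old mean through delta and the new mean through (delta - deltaN)
        val newM2 = m2 + delta * (delta - deltaN)
        val delta2 = delta * delta
        val deltaN2 = deltaN * deltaN
        // m3 and m4 are written in terms of the already-updated lower-order moments
        val newM3 = m3 - 3.0 * deltaN * newM2 + delta * (delta2 - deltaN2)
        val newM4 = m4 - 4.0 * deltaN * newM3 - 6.0 * deltaN2 * newM2 +
          delta * (delta * delta2 - deltaN * deltaN2)
        MomentState(newN, newMean, newM2, newM3, newM4)
      }

      // Merge two partial aggregates; the early return avoids 0/0 when both sides are
      // empty, e.g. when there are more partitions than samples.
      def merge(other: MomentState): MomentState = {
        val newN = n + other.n
        if (newN == 0.0) {
          this
        } else {
          val delta = other.mean - mean
          val deltaN = delta / newN
          val newMean = mean + deltaN * other.n
          // standard pairwise-merge formulas for higher-order central moments
          val newM2 = m2 + other.m2 + delta * deltaN * n * other.n
          val newM3 = m3 + other.m3 +
            delta * deltaN * deltaN * n * other.n * (n - other.n) +
            3.0 * deltaN * (n * other.m2 - other.n * m2)
          val newM4 = m4 + other.m4 +
            delta * deltaN * deltaN * deltaN * n * other.n *
              (n * n - n * other.n + other.n * other.n) +
            6.0 * deltaN * deltaN * (n * n * other.m2 + other.n * other.n * m2) +
            4.0 * deltaN * (n * other.m3 - other.n * m3)
          MomentState(newN, newMean, newM2, newM3, newM4)
        }
      }
    }

A quick sanity check against hand-computed central moments:

    val state = Seq(1.0, 2.0, 4.0).foldLeft(MomentState())((s, x) => s.update(x))
    // state.m2 ≈ 4.667, state.m3 ≈ 2.222, state.m4 ≈ 10.889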