Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5357#discussion_r29737910
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
 ---
    @@ -616,6 +617,119 @@ case class SumFunction(expr: Expression, base: 
AggregateExpression) extends Aggr
       }
     }
     
    +case class StdDeviation(child: Expression)
    +  extends PartialAggregate with trees.UnaryNode[Expression]{
    +  override def nullable: Boolean = true
    +
    +  override def dataType: DataType = child.dataType match {
    +    case DecimalType.Fixed(precision, scale) =>
    +      DecimalType(precision + 4, scale + 4)  // Add 4 digits after decimal 
point, like Hive
    +    case DecimalType.Unlimited =>
    +      DecimalType.Unlimited
    +    case _ =>
    +      DoubleType
    +  }
    +
    +  override def asPartial: SplitEvaluation = {
    +    val (seqPartialData, castStddev) = child.dataType match {
    +      case DecimalType.Fixed(_, _) =>
    +        // Turn the child to unlimited decimals for calculation, before 
going back to fixed
    +        val (seqPartialData, stddev) = 
calcPartialStddev(DecimalType.Unlimited)
    +        (seqPartialData, Cast(stddev, dataType))
    +      case _ =>
    +        calcPartialStddev(dataType)
    +    }
    +
    +    SplitEvaluation(castStddev, seqPartialData)
    +  }
    +
    +  /**
    +   * partialSquredSum = partial of sum(xi^2)
    +   * @param calcType data type during calcuation
    +   * @return seqPartialData is data for partialEvaluations
    +   */
    +  protected def calcPartialStddev(calcType: DataType): (List[Alias], 
Expression) = {
    +    val castedChild = Cast(child, calcType)
    +    val partialSquredSum = Alias(Sum(Multiply(castedChild, castedChild)), 
"PartialSquredSum")()
    +    val partialSum = Alias(Sum(castedChild), "PartialSum")()
    +    val partialCount = Alias(Count(child), "PartialCount")()
    +    val seqPartialData = partialCount :: partialSum :: partialSquredSum :: 
Nil
    +
    +    val castedSquredSum = Cast(Sum(partialSquredSum.toAttribute), calcType)
    +    val castedSum = Cast(Sum(partialSum.toAttribute), calcType)
    +    val castedCount = Cast(Sum(partialCount.toAttribute), calcType)
    +    val castedAvg = Divide(castedSum, castedCount)
    +    val stddev = Sqrt(Divide(
    +                        Subtract(castedSquredSum, Multiply(castedSum, 
castedAvg)),
    +                        castedCount));
    +    (seqPartialData, stddev)
    +  }
    +
    +  override def toString: String = s"STDDEV($child)"
    +
    +  override def newInstance(): StdDeviationFunction = new 
StdDeviationFunction(child, this)
    +}
    +
    +/**
    + * the standard deviation = sqrt(sum((xi-avg)^2)/N)
    --- End diff --
    
    `N` -> `N - 1` for sample standard deviation. Please check the result against R and make sure we output the same. We should also leave a TODO here to implement the numerically stable method in the future. See:
    
    
https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala#L36


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to