GitHub user rxin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5357#discussion_r28609928
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala ---
    @@ -616,6 +617,115 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr
       }
     }
     
    +case class StdDeviation(child: Expression)
    +  extends PartialAggregate with trees.UnaryNode[Expression]{
    +  override def nullable: Boolean = true
    +
    +  override def dataType: DataType = child.dataType match {
    +    case DecimalType.Fixed(precision, scale) =>
    +      DecimalType(precision + 4, scale + 4)  // Add 4 digits after decimal 
point, like Hive
    +    case DecimalType.Unlimited =>
    +      DecimalType.Unlimited
    +    case _ =>
    +      DoubleType
    +  }
    +
    +  override def asPartial: SplitEvaluation = {
    +    val (seqPartialData, castStddev) = child.dataType match {
    +      case DecimalType.Fixed(_, _) =>
    +        // Turn the child to unlimited decimals for calculation, before 
going back to fixed
    +        val (seqPartialData, stddev) = 
calcPartialStddev(DecimalType.Unlimited)
    +        (seqPartialData, Cast(stddev, dataType))
    +      case _ =>
    +        calcPartialStddev(dataType)
    +    }
    +
    +    SplitEvaluation(castStddev, seqPartialData)
    +  }
    +
    +  /**
    +   *
    +   * @param calcType data type during calcuation
    +   * @return seqPartialData is data for partialEvaluations
    +   */
    +  protected def calcPartialStddev(calcType: DataType): (List[Alias], 
Expression) = {
    +    val castedChild = Cast(child, calcType)
    +    val partialSquredSum = Alias(Sum(Multiply(castedChild, castedChild)), 
"PartialSquredSum")()
    +    val partialSum = Alias(Sum(castedChild), "PartialSum")()
    +    val partialCount = Alias(Count(child), "PartialCount")()
    +    val seqPartialData = partialCount :: partialSum :: partialSquredSum :: 
Nil
    +
    +    val castedSquredSum = Cast(Sum(partialSquredSum.toAttribute), calcType)
    +    val castedSum = Cast(Sum(partialSum.toAttribute), calcType)
    +    val castedCount = Cast(Sum(partialCount.toAttribute), calcType)
    +    val castedAvg = Divide(castedSum, castedCount)
    +    val stddev = Sqrt(Divide(
    +                        Subtract(castedSquredSum, Multiply(castedSum, 
castedAvg)),
    +                        castedCount));
    +    (seqPartialData, stddev)
    +  }
    +
    +  override def toString: String = s"STDDEV($child)"
    +
    +  override def newInstance(): StdDeviationFunction = new 
StdDeviationFunction(child, this)
    +}
    +
    +case class StdDeviationFunction(expr: Expression, base: 
AggregateExpression)
    +  extends AggregateFunction {
    +  def this() = this(null, null) // /Required for serialization.
    +
    +  private val calcType =
    +    expr.dataType match {
    +      case DecimalType.Fixed(precision, scale) =>
    +        DecimalType(precision + 4, scale + 4)  // Add 4 digits after 
decimal point, like Hive
    +      case DecimalType.Unlimited =>
    +        DecimalType.Unlimited
    +      case _ =>
    +        DoubleType
    +    }
    +
    +  private val zero = Cast(Literal(0), calcType)
    +
    +  private var count: Long = _
    +
    --- End diff --
    
    remove the extra blank line here


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and you wish to enable it, or if the feature is enabled but not working,
please contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to