Github user marmbrus commented on a diff in the pull request:

    https://github.com/apache/spark/pull/3247#discussion_r22499696
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
 ---
    @@ -373,284 +191,276 @@ case class SumDistinct(child: Expression)
         case _ =>
           child.dataType
       }
    -  override def toString = s"SUM(DISTINCT ${child})"
    -  override def newInstance() = new SumDistinctFunction(child, this)
    -
    -  override def asPartial = {
    -    val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")()
    -    SplitEvaluation(
    -      CombineSetsAndSum(partialSet.toAttribute, this),
    -      partialSet :: Nil)
    -  }
    -}
    -
    -case class CombineSetsAndSum(inputSet: Expression, base: Expression) 
extends AggregateExpression {
    -  def this() = this(null, null)
    -
    -  override def children = inputSet :: Nil
    -  override def nullable = true
    -  override def dataType = base.dataType
    -  override def toString = s"CombineAndSum($inputSet)"
    -  override def newInstance() = new CombineSetsAndSumFunction(inputSet, 
this)
    -}
     
    -case class CombineSetsAndSumFunction(
    -    @transient inputSet: Expression,
    -    @transient base: AggregateExpression)
    -  extends AggregateFunction {
    -
    -  def this() = this(null, null) // Required for serialization.
    -
    -  val seen = new OpenHashSet[Any]()
    -
    -  override def update(input: Row): Unit = {
    -    val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]]
    -    val inputIterator = inputSetEval.iterator
    -    while (inputIterator.hasNext) {
    -      seen.add(inputIterator.next)
    -    }
    -  }
    +  override def toString = s"SUM($child)"
     
    -  override def eval(input: Row): Any = {
    -    val casted = seen.asInstanceOf[OpenHashSet[Row]]
    -    if (casted.size == 0) {
    -      null
    -    } else {
    -      Cast(Literal(
    -        casted.iterator.map(f => f.apply(0)).reduceLeft(
    -          
base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
    -        base.dataType).eval(null)
    -    }
    -  }
    +  override def bufferDataType: Seq[DataType] = dataType :: Nil
    +  override def newInstance(buffers: Seq[BoundReference]) = 
SumFunction(buffers(0), this)
     }
     
    -case class First(child: Expression) extends PartialAggregate with 
trees.UnaryNode[Expression] {
    +case class First(child: Expression, distinct: Boolean = false)
    +  extends UnaryAggregateExpression {
       override def nullable = true
       override def dataType = child.dataType
       override def toString = s"FIRST($child)"
     
    -  override def asPartial: SplitEvaluation = {
    -    val partialFirst = Alias(First(child), "PartialFirst")()
    -    SplitEvaluation(
    -      First(partialFirst.toAttribute),
    -      partialFirst :: Nil)
    -  }
    -  override def newInstance() = new FirstFunction(child, this)
    +  override def bufferDataType: Seq[DataType] = dataType :: Nil
    +  override def newInstance(buffers: Seq[BoundReference]) = 
FirstFunction(buffers(0), this)
     }
     
    -case class Last(child: Expression) extends PartialAggregate with 
trees.UnaryNode[Expression] {
    -  override def references = child.references
    +case class Last(child: Expression, distinct: Boolean = false)
    +  extends UnaryAggregateExpression {
       override def nullable = true
       override def dataType = child.dataType
       override def toString = s"LAST($child)"
     
    -  override def asPartial: SplitEvaluation = {
    -    val partialLast = Alias(Last(child), "PartialLast")()
    -    SplitEvaluation(
    -      Last(partialLast.toAttribute),
    -      partialLast :: Nil)
    -  }
    -  override def newInstance() = new LastFunction(child, this)
    +  override def bufferDataType: Seq[DataType] = dataType :: Nil
    +  override def newInstance(buffers: Seq[BoundReference]) = 
LastFunction(buffers(0), this)
     }
     
    -case class AverageFunction(expr: Expression, base: AggregateExpression)
    -  extends AggregateFunction {
    -
    -  def this() = this(null, null) // Required for serialization.
    -
    -  private val calcType =
    -    expr.dataType match {
    -      case DecimalType.Fixed(_, _) =>
    -        DecimalType.Unlimited
    -      case _ =>
    -        expr.dataType
    -    }
    -
    -  private val zero = Cast(Literal(0), calcType)
    -
    -  private var count: Long = _
    -  private val sum = MutableLiteral(zero.eval(null), calcType)
    +case class MinFunction(aggr: BoundReference, base: Min) extends 
AggregateFunction {
    +  val arg: MutableLiteral = MutableLiteral(null, base.dataType)
    +  val buffer: MutableLiteral = MutableLiteral(null, base.dataType)
    +  val cmp = LessThan(arg, buffer)
     
    -  private def addFunction(value: Any) = Add(sum, Cast(Literal(value, 
expr.dataType), calcType))
    +  override def reset(buf: MutableRow): Unit = {
    +    buf.update(aggr.ordinal, null)
    +  }
     
    -  override def eval(input: Row): Any = {
    -    if (count == 0L) {
    -      null
    -    } else {
    -      expr.dataType match {
    -        case DecimalType.Fixed(_, _) =>
    -          Cast(Divide(
    -            Cast(sum, DecimalType.Unlimited),
    -            Cast(Literal(count), DecimalType.Unlimited)), 
dataType).eval(null)
    -        case _ =>
    -          Divide(
    -            Cast(sum, dataType),
    -            Cast(Literal(count), dataType)).eval(null)
    +  override def iterate(argument: Any, buf: MutableRow): Unit = {
    +    if (argument != null) {
    +      arg.value = argument
    +      buffer.value = buf(aggr.ordinal)
    +      if (buf.isNullAt(aggr.ordinal) || cmp.eval(null) == true) {
    +        buf.update(aggr.ordinal, argument)
           }
         }
       }
     
    -  override def update(input: Row): Unit = {
    -    val evaluatedExpr = expr.eval(input)
    -    if (evaluatedExpr != null) {
    -      count += 1
    -      sum.update(addFunction(evaluatedExpr), input)
    +  override def merge(value: Row, rowBuf: MutableRow): Unit = {
    +    if (!value.isNullAt(aggr.ordinal)) {
    +      arg.value = value(aggr.ordinal)
    +      buffer.value = rowBuf(aggr.ordinal)
    +      if (rowBuf.isNullAt(aggr.ordinal) || cmp.eval(null) == true) {
    +        rowBuf.update(aggr.ordinal, arg.value)
    +      }
         }
       }
    -}
     
    -case class CountFunction(expr: Expression, base: AggregateExpression) 
extends AggregateFunction {
    -  def this() = this(null, null) // Required for serialization.
    +  override def terminate(row: Row): Any = aggr.eval(row)
    +}
     
    -  var count: Long = _
    +case class AverageFunction(count: BoundReference, sum: BoundReference, 
base: Average)
    +  extends AggregateFunction {
    +  // for iterate
    +  val arg = MutableLiteral(null, base.child.dataType)
    +  val cast = if (arg.dataType != base.dataType) Cast(arg, base.dataType) 
else arg
    +  val add = Add(cast, sum)
    +
    +  // for merge
    +  val argInMerge = MutableLiteral(null, base.dataType)
    +  val addInMerge = Add(argInMerge, sum)
    +
    +  // for terminate
    +  val divide = Divide(sum, Cast(count, base.dataType))
    +
    +  override def reset(buf: MutableRow): Unit = {
    +    buf.update(count.ordinal, 0L)
    +    buf.update(sum.ordinal, null)
    +  }
    +
    +  override def iterate(argument: Any, buf: MutableRow): Unit = {
    +    if (argument != null) {
    +      arg.value = argument
    +      buf.update(count.ordinal, buf.getLong(count.ordinal) + 1)
    +      if (buf.isNullAt(sum.ordinal)) {
    +        buf.update(sum.ordinal, cast.eval())
    +      } else {
    +        buf.update(sum.ordinal, add.eval(buf))
    +      }
    +    }
    +  }
     
    -  override def update(input: Row): Unit = {
    -    val evaluatedExpr = expr.eval(input)
    -    if (evaluatedExpr != null) {
    -      count += 1L
    +  override def merge(value: Row, buf: MutableRow): Unit = {
    +    if (!value.isNullAt(sum.ordinal)) {
    +      buf.setLong(count.ordinal, value.getLong(count.ordinal) + 
buf.getLong(count.ordinal))
    +      if (buf.isNullAt(sum.ordinal)) {
    +        buf.update(sum.ordinal, value(sum.ordinal))
    +      } else {
    +        argInMerge.value = value(sum.ordinal)
    +        buf.update(sum.ordinal, addInMerge.eval(buf))
    +      }
         }
       }
     
    -  override def eval(input: Row): Any = count
    +  override def terminate(row: Row): Any = if (count.eval(row) == 0) null 
else divide.eval(row)
     }
     
    -case class ApproxCountDistinctPartitionFunction(
    -    expr: Expression,
    -    base: AggregateExpression,
    -    relativeSD: Double)
    -  extends AggregateFunction {
    -  def this() = this(null, null, 0) // Required for serialization.
    +case class MaxFunction(aggr: BoundReference, base: Max) extends 
AggregateFunction {
    +  val arg: MutableLiteral = MutableLiteral(null, base.dataType)
    +  val buffer: MutableLiteral = MutableLiteral(null, base.dataType)
    +  val cmp = GreaterThan(arg, buffer)
     
    -  private val hyperLogLog = new HyperLogLog(relativeSD)
    +  override def reset(buf: MutableRow): Unit = {
    +    buf.update(aggr.ordinal, null)
    +  }
     
    -  override def update(input: Row): Unit = {
    -    val evaluatedExpr = expr.eval(input)
    -    if (evaluatedExpr != null) {
    -      hyperLogLog.offer(evaluatedExpr)
    +  override def iterate(argument: Any, buf: MutableRow): Unit = {
    +    if (argument != null) {
    +      arg.value = argument
    +      buffer.value = buf(aggr.ordinal)
    +      if (buf.isNullAt(aggr.ordinal) || cmp.eval(null) == true) {
    +        buf.update(aggr.ordinal, argument)
    +      }
         }
       }
     
    -  override def eval(input: Row): Any = hyperLogLog
    -}
    -
    -case class ApproxCountDistinctMergeFunction(
    -    expr: Expression,
    -    base: AggregateExpression,
    -    relativeSD: Double)
    -  extends AggregateFunction {
    -  def this() = this(null, null, 0) // Required for serialization.
    -
    -  private val hyperLogLog = new HyperLogLog(relativeSD)
    -
    -  override def update(input: Row): Unit = {
    -    val evaluatedExpr = expr.eval(input)
    -    hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog])
    +  override def merge(value: Row, rowBuf: MutableRow): Unit = {
    +    if (!value.isNullAt(aggr.ordinal)) {
    +      arg.value = value(aggr.ordinal)
    +      buffer.value = rowBuf(aggr.ordinal)
    +      if (rowBuf.isNullAt(aggr.ordinal) || cmp.eval(null) == true) {
    +        rowBuf.update(aggr.ordinal, arg.value)
    +      }
    +    }
       }
     
    -  override def eval(input: Row): Any = hyperLogLog.cardinality()
    +  override def terminate(row: Row): Any = aggr.eval(row)
     }
     
    -case class SumFunction(expr: Expression, base: AggregateExpression) 
extends AggregateFunction {
    -  def this() = this(null, null) // Required for serialization.
    +case class CountFunction(aggr: BoundReference, base: Count)
    +    extends AggregateFunction {
    +  override def reset(buf: MutableRow): Unit = {
    +    buf.update(aggr.ordinal, 0L)
    +  }
     
    -  private val calcType =
    -    expr.dataType match {
    -      case DecimalType.Fixed(_, _) =>
    -        DecimalType.Unlimited
    -      case _ =>
    -        expr.dataType
    +  override def iterate(argument: Any, buf: MutableRow): Unit = {
    +    if (argument != null) {
    +      if (buf.isNullAt(aggr.ordinal)) {
    +        buf.setLong(aggr.ordinal, 1L)
    +      } else {
    +        buf.update(aggr.ordinal, buf.getLong(aggr.ordinal) + 1L)
    +      }
         }
    -
    -  private val zero = Cast(Literal(0), calcType)
    -
    -  private val sum = MutableLiteral(null, calcType)
    -
    -  private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), 
Cast(expr, calcType)), sum))
    -
    -  override def update(input: Row): Unit = {
    -    sum.update(addFunction, input)
       }
     
    -  override def eval(input: Row): Any = {
    -    expr.dataType match {
    -      case DecimalType.Fixed(_, _) =>
    -        Cast(sum, dataType).eval(null)
    -      case _ => sum.eval(null)
    +  override def merge(value: Row, rowBuf: MutableRow): Unit = {
    +    if (value.isNullAt(aggr.ordinal)) {
    +      // do nothing
    +    } else if (rowBuf.isNullAt(aggr.ordinal)) {
    +      rowBuf(aggr.ordinal) = value(aggr.ordinal)
    +    } else {
    +      rowBuf.update(aggr.ordinal, value.getLong(aggr.ordinal) + 
rowBuf.getLong(aggr.ordinal))
         }
       }
    +
    +  override def terminate(row: Row): Any = aggr.eval(row)
     }
     
    -case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
    +case class CountDistinctFunction(aggr: BoundReference, base: CountDistinct)
       extends AggregateFunction {
    -
    -  def this() = this(null, null) // Required for serialization.
    -
    -  private val seen = new scala.collection.mutable.HashSet[Any]()
    -
    -  override def update(input: Row): Unit = {
    -    val evaluatedExpr = expr.eval(input)
    -    if (evaluatedExpr != null) {
    -      seen += evaluatedExpr
    +  override def reset(buf: MutableRow): Unit = {
    +    buf.update(aggr.ordinal, 0L)
    --- End diff --
    
    Style nit: prefer `buf(aggr.ordinal) = 0L` (Scala's assignment sugar for `buf.update(aggr.ordinal, 0L)`).
    
    We might also consider adding `def update(ref: BoundReference, newValue: Any)` 
as sugar.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to