Github user juliuszsompolski commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19324#discussion_r141037818
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala ---
    @@ -462,18 +464,36 @@ case class HashAggregateExec(
            $evaluateAggResults
            ${consume(ctx, resultVars)}
            """
    -
         } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
    -      // This should be the last operator in a stage, we should output UnsafeRow directly
    -      val joinerTerm = ctx.freshName("unsafeRowJoiner")
    -      ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm,
    -        s"$joinerTerm = $plan.createUnsafeJoiner();")
    -      val resultRow = ctx.freshName("resultRow")
    +      // resultExpressions are Attributes of groupingExpressions and aggregateBufferAttributes.
    +      assert(resultExpressions.forall(_.isInstanceOf[Attribute]))
    +      assert(resultExpressions.length ==
    +        groupingExpressions.length + aggregateBufferAttributes.length)
    +
    +      ctx.currentVars = null
    +
    +      ctx.INPUT_ROW = keyTerm
    +      val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
    +        BoundReference(i, e.dataType, e.nullable).genCode(ctx)
    +      }
    +      val evaluateKeyVars = evaluateVariables(keyVars)
    +
    +      ctx.INPUT_ROW = bufferTerm
    +      val resultBufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) =>
    +        BoundReference(i, e.dataType, e.nullable).genCode(ctx)
    +      }
    +      val evaluateResultBufferVars = evaluateVariables(resultBufferVars)
    +
    +      ctx.currentVars = keyVars ++ resultBufferVars
    +      val inputAttrs = resultExpressions.map(_.toAttribute)
    +      val resultVars = resultExpressions.map { e =>
    +        BindReferences.bindReference(e, inputAttrs).genCode(ctx)
    +      }
           s"""
    -       UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm);
    -       ${consume(ctx, null, resultRow)}
    +       $evaluateKeyVars
    +       $evaluateResultBufferVars
    +       ${consume(ctx, resultVars)}
            """
    -
         } else {
           // generate result based on grouping key
    --- End diff --
    
    Yes, e.g. for aggregation coming from Distinct.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to