m44444 commented on a change in pull request #24459:
[SPARK-24935][SQL][followup] support INIT -> UPDATE -> MERGE -> FINISH in Hive UDAF adapter
URL: https://github.com/apache/spark/pull/24459#discussion_r279074009
##########
File path: sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
##########
@@ -410,55 +417,70 @@ private[hive] case class HiveUDAFFunction(
   // aggregate buffer. However, the Spark UDAF framework does not expose this information when
   // creating the buffer. Here we return null, and create the buffer in `update` and `merge`
   // on demand, so that we can know what input we are dealing with.
-  override def createAggregationBuffer(): AggregationBuffer = null
+  override def createAggregationBuffer(): HiveUDAFBuffer = null
 
   @transient
   private lazy val inputProjection = UnsafeProjection.create(children)
 
-  override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = {
+  override def update(buffer: HiveUDAFBuffer, input: InternalRow): HiveUDAFBuffer = {
     // The input is original data, we create buffer with the partial1 evaluator.
     val nonNullBuffer = if (buffer == null) {
-      partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+      HiveUDAFBuffer(partial1HiveEvaluator.evaluator.getNewAggregationBuffer, false)
     } else {
       buffer
     }
+    assert(!nonNullBuffer.canDoMerge, "can not call `merge` then `update` on a Hive UDAF.")
+
     partial1HiveEvaluator.evaluator.iterate(
-      nonNullBuffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
+      nonNullBuffer.buf, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
     nonNullBuffer
   }
 
-  override def merge(buffer: AggregationBuffer, input: AggregationBuffer): AggregationBuffer = {
+  override def merge(buffer: HiveUDAFBuffer, input: HiveUDAFBuffer): HiveUDAFBuffer = {
     // The input is aggregate buffer, we create buffer with the final evaluator.
     val nonNullBuffer = if (buffer == null) {
-      finalHiveEvaluator.evaluator.getNewAggregationBuffer
+      HiveUDAFBuffer(finalHiveEvaluator.evaluator.getNewAggregationBuffer, true)
     } else {
       buffer
     }
+
+    // It's possible that we've called `update` on this Hive UDAF, and some specific Hive UDAF
+    // implementations can't mix `update` and `merge` calls during their life cycle. To work
+    // around it, here we create a fresh buffer with the final evaluator, merge the existing
+    // buffer into it, and replace the existing buffer with it.
+    val mergeableBuf = if (!nonNullBuffer.canDoMerge) {
+      val newBuf = finalHiveEvaluator.evaluator.getNewAggregationBuffer
+      finalHiveEvaluator.evaluator.merge(
+        newBuf, partial1HiveEvaluator.evaluator.terminatePartial(nonNullBuffer.buf))
+      HiveUDAFBuffer(newBuf, true)
+    } else {
+      nonNullBuffer
+    }
+
     // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation
     // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
     // these `AggregationBuffer`s into this format before shuffling partial aggregation results, and
     // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
     finalHiveEvaluator.evaluator.merge(
-      nonNullBuffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
-    nonNullBuffer
+      mergeableBuf.buf, partial1HiveEvaluator.evaluator.terminatePartial(input.buf))
+    mergeableBuf
   }
 
-  override def eval(buffer: AggregationBuffer): Any = {
-    resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer))
+  override def eval(buffer: HiveUDAFBuffer): Any = {
+    resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer.buf))
   }
 
-  override def serialize(buffer: AggregationBuffer): Array[Byte] = {
+  override def serialize(buffer: HiveUDAFBuffer): Array[Byte] = {
     // Serializes an `AggregationBuffer` that holds partial aggregation results so that we can
     // shuffle it for global aggregation later.
-    aggBufferSerDe.serialize(buffer)
+    aggBufferSerDe.serialize(buffer.buf)
   }
 
-  override def deserialize(bytes: Array[Byte]): AggregationBuffer = {
+  override def deserialize(bytes: Array[Byte]): HiveUDAFBuffer = {
     // Deserializes an `AggregationBuffer` from the shuffled partial aggregation phase to prepare
     // for global aggregation by merging multiple partial aggregation results within a single group.
-    aggBufferSerDe.deserialize(bytes)
+    HiveUDAFBuffer(aggBufferSerDe.deserialize(bytes), false)
Review comment:
I see: except for the case of falling back from hash agg, which is what you want to address here, this does not impact Spark UDAFs. The logic looks clear and good to me, thanks!
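
To make the fix easier to follow, here is a minimal, self-contained Scala sketch of the pattern the diff implements. This is not the actual Spark/Hive code: `Buf` and `Evaluator` are toy stand-ins for Hive's opaque `AggregationBuffer` and the partial1/final `GenericUDAFEvaluator` instances. Only the `canDoMerge` flag and the lazy conversion through `terminatePartial` mirror the change above.

```scala
// Minimal, self-contained sketch (NOT the real Spark/Hive classes) of the
// buffer-wrapping trick in this PR. A real Hive UDAF has one evaluator in
// PARTIAL1 mode (iterate) and one in FINAL mode (merge), and some
// implementations cannot mix the two call sequences on a single buffer.
object HiveUDAFBufferSketch {

  // Stand-in for Hive's opaque AggregationBuffer: a running sum.
  final case class Buf(var sum: Long)

  // Stand-in for both evaluator instances of a "sum" UDAF.
  object Evaluator {
    def getNewAggregationBuffer: Buf = Buf(0L)
    def iterate(buf: Buf, input: Long): Unit = buf.sum += input   // update path
    def terminatePartial(buf: Buf): Long = buf.sum                // partial result
    def merge(buf: Buf, partial: Long): Unit = buf.sum += partial // merge path
    def terminate(buf: Buf): Long = buf.sum
  }

  // The wrapper added by the diff: the flag records whether the buffer was
  // created (or converted) for the merge phase.
  final case class HiveUDAFBuffer(buf: Buf, canDoMerge: Boolean)

  def update(buffer: HiveUDAFBuffer, input: Long): HiveUDAFBuffer = {
    val b =
      if (buffer == null) HiveUDAFBuffer(Evaluator.getNewAggregationBuffer, false)
      else buffer
    assert(!b.canDoMerge, "can not call `merge` then `update` on a Hive UDAF.")
    Evaluator.iterate(b.buf, input)
    b
  }

  def merge(buffer: HiveUDAFBuffer, input: HiveUDAFBuffer): HiveUDAFBuffer = {
    val b =
      if (buffer == null) HiveUDAFBuffer(Evaluator.getNewAggregationBuffer, true)
      else buffer
    // If this buffer has only seen `update` calls, rebuild it as a merge-mode
    // buffer by folding its partial result into a fresh buffer.
    val mergeable =
      if (!b.canDoMerge) {
        val fresh = Evaluator.getNewAggregationBuffer
        Evaluator.merge(fresh, Evaluator.terminatePartial(b.buf))
        HiveUDAFBuffer(fresh, true)
      } else b
    Evaluator.merge(mergeable.buf, Evaluator.terminatePartial(input.buf))
    mergeable
  }

  def main(args: Array[String]): Unit = {
    // The hash-agg fallback shape: UPDATE a buffer, then MERGE another into it.
    val updated = update(update(null, 1L), 2L)
    val other = update(null, 10L)
    val merged = merge(updated, other)
    println(Evaluator.terminate(merged.buf)) // prints 13
  }
}
```

The point the sketch demonstrates: when hash aggregation falls back to sort-based aggregation, a buffer that has only seen `update` calls can still be merged, because `merge` first rebuilds it as a final-mode buffer instead of mixing `iterate` and `merge` on one buffer.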