m44444 commented on a change in pull request #24459:
[SPARK-24935][SQL][followup] support INIT -> UPDATE -> MERGE -> FINISH in Hive
UDAF adapter
URL: https://github.com/apache/spark/pull/24459#discussion_r278758245
##########
File path: sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
##########
@@ -410,55 +417,70 @@ private[hive] case class HiveUDAFFunction(
// aggregate buffer. However, the Spark UDAF framework does not expose this
information when
// creating the buffer. Here we return null, and create the buffer in
`update` and `merge`
// on demand, so that we can know what input we are dealing with.
- override def createAggregationBuffer(): AggregationBuffer = null
+ override def createAggregationBuffer(): HiveUDAFBuffer = null
@transient
private lazy val inputProjection = UnsafeProjection.create(children)
- override def update(buffer: AggregationBuffer, input: InternalRow):
AggregationBuffer = {
+ override def update(buffer: HiveUDAFBuffer, input: InternalRow):
HiveUDAFBuffer = {
// The input is original data, we create buffer with the partial1
evaluator.
val nonNullBuffer = if (buffer == null) {
- partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+ HiveUDAFBuffer(partial1HiveEvaluator.evaluator.getNewAggregationBuffer,
false)
} else {
buffer
}
+ assert(!nonNullBuffer.canDoMerge, "can not call `merge` then `update` on a
Hive UDAF.")
+
partial1HiveEvaluator.evaluator.iterate(
- nonNullBuffer, wrap(inputProjection(input), inputWrappers, cached,
inputDataTypes))
+ nonNullBuffer.buf, wrap(inputProjection(input), inputWrappers, cached,
inputDataTypes))
nonNullBuffer
}
- override def merge(buffer: AggregationBuffer, input: AggregationBuffer):
AggregationBuffer = {
+ override def merge(buffer: HiveUDAFBuffer, input: HiveUDAFBuffer):
HiveUDAFBuffer = {
// The input is aggregate buffer, we create buffer with the final
evaluator.
val nonNullBuffer = if (buffer == null) {
- finalHiveEvaluator.evaluator.getNewAggregationBuffer
+ HiveUDAFBuffer(finalHiveEvaluator.evaluator.getNewAggregationBuffer,
true)
} else {
buffer
}
+ // It's possible that we've called `update` of this Hive UDAF, and some
specific Hive UDAF
+ // implementation can't mix the `update` and `merge` calls during its life
cycle. To work
+ // around it, here we create a fresh buffer with final evaluator, and
merge the existing buffer
+ // to it, and replace the existing buffer with it.
+ val mergeableBuf = if (!nonNullBuffer.canDoMerge) {
+ val newBuf = finalHiveEvaluator.evaluator.getNewAggregationBuffer
+ finalHiveEvaluator.evaluator.merge(
+ newBuf,
partial1HiveEvaluator.evaluator.terminatePartial(nonNullBuffer.buf))
+ HiveUDAFBuffer(newBuf, true)
+ } else {
+ nonNullBuffer
+ }
+
// The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is
an input aggregation
// buffer in the 3rd format mentioned in the ScalaDoc of this class.
Originally, Hive converts
// this `AggregationBuffer`s into this format before shuffling partial
aggregation results, and
// calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
finalHiveEvaluator.evaluator.merge(
- nonNullBuffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
- nonNullBuffer
+ mergeableBuf.buf,
partial1HiveEvaluator.evaluator.terminatePartial(input.buf))
+ mergeableBuf
}
- override def eval(buffer: AggregationBuffer): Any = {
- resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer))
+ override def eval(buffer: HiveUDAFBuffer): Any = {
+ resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer.buf))
}
- override def serialize(buffer: AggregationBuffer): Array[Byte] = {
+ override def serialize(buffer: HiveUDAFBuffer): Array[Byte] = {
// Serializes an `AggregationBuffer` that holds partial aggregation
results so that we can
// shuffle it for global aggregation later.
- aggBufferSerDe.serialize(buffer)
+ aggBufferSerDe.serialize(buffer.buf)
}
- override def deserialize(bytes: Array[Byte]): AggregationBuffer = {
+ override def deserialize(bytes: Array[Byte]): HiveUDAFBuffer = {
// Deserializes an `AggregationBuffer` from the shuffled partial
aggregation phase to prepare
// for global aggregation by merging multiple partial aggregation results
within a single group.
- aggBufferSerDe.deserialize(bytes)
+ HiveUDAFBuffer(aggBufferSerDe.deserialize(bytes), false)
Review comment:
Since the value of canDoMerge is always set to false after deserialization, in
the merge() function the aggregation buffer will always be re-created, even when
the passed buffer parameter is actually in a PARTIAL2 or FINAL state. This,
correct me if I am wrong, is a flaw that causes a performance downgrade.
We may need to do non-trivial work in serialize() to include the state.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]