m44444 commented on a change in pull request #24459: 
[SPARK-24935][SQL][followup] support INIT -> UPDATE -> MERGE -> FINISH in Hive 
UDAF adapter
URL: https://github.com/apache/spark/pull/24459#discussion_r278758245
 
 

 ##########
 File path: sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
 ##########
 @@ -410,55 +417,70 @@ private[hive] case class HiveUDAFFunction(
   // aggregate buffer. However, the Spark UDAF framework does not expose this 
information when
   // creating the buffer. Here we return null, and create the buffer in 
`update` and `merge`
   // on demand, so that we can know what input we are dealing with.
-  override def createAggregationBuffer(): AggregationBuffer = null
+  override def createAggregationBuffer(): HiveUDAFBuffer = null
 
   @transient
   private lazy val inputProjection = UnsafeProjection.create(children)
 
-  override def update(buffer: AggregationBuffer, input: InternalRow): 
AggregationBuffer = {
+  override def update(buffer: HiveUDAFBuffer, input: InternalRow): 
HiveUDAFBuffer = {
     // The input is original data, we create buffer with the partial1 
evaluator.
     val nonNullBuffer = if (buffer == null) {
-      partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+      HiveUDAFBuffer(partial1HiveEvaluator.evaluator.getNewAggregationBuffer, 
false)
     } else {
       buffer
     }
 
+    assert(!nonNullBuffer.canDoMerge, "can not call `merge` then `update` on a 
Hive UDAF.")
+
     partial1HiveEvaluator.evaluator.iterate(
-      nonNullBuffer, wrap(inputProjection(input), inputWrappers, cached, 
inputDataTypes))
+      nonNullBuffer.buf, wrap(inputProjection(input), inputWrappers, cached, 
inputDataTypes))
     nonNullBuffer
   }
 
-  override def merge(buffer: AggregationBuffer, input: AggregationBuffer): 
AggregationBuffer = {
+  override def merge(buffer: HiveUDAFBuffer, input: HiveUDAFBuffer): 
HiveUDAFBuffer = {
     // The input is aggregate buffer, we create buffer with the final 
evaluator.
     val nonNullBuffer = if (buffer == null) {
-      finalHiveEvaluator.evaluator.getNewAggregationBuffer
+      HiveUDAFBuffer(finalHiveEvaluator.evaluator.getNewAggregationBuffer, 
true)
     } else {
       buffer
     }
 
+    // It's possible that we've called `update` of this Hive UDAF, and some 
specific Hive UDAF
+    // implementation can't mix the `update` and `merge` calls during its life 
cycle. To work
+    // around it, here we create a fresh buffer with final evaluator, and 
merge the existing buffer
+    // to it, and replace the existing buffer with it.
+    val mergeableBuf = if (!nonNullBuffer.canDoMerge) {
+      val newBuf = finalHiveEvaluator.evaluator.getNewAggregationBuffer
+      finalHiveEvaluator.evaluator.merge(
+        newBuf, 
partial1HiveEvaluator.evaluator.terminatePartial(nonNullBuffer.buf))
+      HiveUDAFBuffer(newBuf, true)
+    } else {
+      nonNullBuffer
+    }
+
     // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is 
an input aggregation
     // buffer in the 3rd format mentioned in the ScalaDoc of this class. 
Originally, Hive converts
     // this `AggregationBuffer`s into this format before shuffling partial 
aggregation results, and
     // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
     finalHiveEvaluator.evaluator.merge(
-      nonNullBuffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
-    nonNullBuffer
+      mergeableBuf.buf, 
partial1HiveEvaluator.evaluator.terminatePartial(input.buf))
+    mergeableBuf
   }
 
-  override def eval(buffer: AggregationBuffer): Any = {
-    resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer))
+  override def eval(buffer: HiveUDAFBuffer): Any = {
+    resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer.buf))
   }
 
-  override def serialize(buffer: AggregationBuffer): Array[Byte] = {
+  override def serialize(buffer: HiveUDAFBuffer): Array[Byte] = {
     // Serializes an `AggregationBuffer` that holds partial aggregation 
results so that we can
     // shuffle it for global aggregation later.
-    aggBufferSerDe.serialize(buffer)
+    aggBufferSerDe.serialize(buffer.buf)
   }
 
-  override def deserialize(bytes: Array[Byte]): AggregationBuffer = {
+  override def deserialize(bytes: Array[Byte]): HiveUDAFBuffer = {
     // Deserializes an `AggregationBuffer` from the shuffled partial 
aggregation phase to prepare
     // for global aggregation by merging multiple partial aggregation results 
within a single group.
-    aggBufferSerDe.deserialize(bytes)
+    HiveUDAFBuffer(aggBufferSerDe.deserialize(bytes), false)
 
 Review comment:
   Since canDoMerge is always set to false after deserialization, the merge() 
function will always re-create the aggregation buffer, even if the passed buffer 
parameter is actually in the Partial2 or Final state. This, correct me if I am 
wrong, is a flaw that causes a performance degradation.
   We may need to do non-trivial work in serialize() to include the state.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to