This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push:
new 3d49bd4 [SPARK-24935][SQL][FOLLOWUP] support INIT -> UPDATE -> MERGE -> FINISH in Hive UDAF adapter
3d49bd4 is described below
commit 3d49bd496e8abfda816beea03269cd4094f2ec52
Author: Wenchen Fan <[email protected]>
AuthorDate: Tue Apr 30 10:35:23 2019 +0800
[SPARK-24935][SQL][FOLLOWUP] support INIT -> UPDATE -> MERGE -> FINISH in Hive UDAF adapter
## What changes were proposed in this pull request?
This is a followup of https://github.com/apache/spark/pull/24144. #24144 missed
one case: when hash aggregate falls back to sort aggregate, the life cycle of the
UDAF is: INIT -> UPDATE -> MERGE -> FINISH.
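For reference, this fallback can be reproduced by lowering the object-hash-aggregate
fallback threshold; a minimal sketch, assuming the `mock2` UDAF from the test below is
already registered:

```scala
// Sketch only: make ObjectHashAggregateExec fall back to sort-based aggregation
// after buffering a single group, so buffers that have only seen UPDATE are
// later asked to MERGE (INIT -> UPDATE -> MERGE -> FINISH).
spark.conf.set("spark.sql.objectHashAggregate.sortBased.fallbackThreshold", "1")
spark.range(100).createOrReplaceTempView("v")
spark.sql("SELECT id % 2, mock2(id) FROM v GROUP BY id % 2").show()
```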
However, not all Hive UDAFs can support this life cycle. A Hive UDAF knows the
aggregation mode when creating the aggregation buffer, so it can create different
buffers for different inputs: the original data or the aggregation buffer.
Please see an example in the [sketches
library](https://github.com/DataSketches/sketches-hive/blob/7f9e76e9e03807277146291beb2c7bec40e8672b/src/main/java/com/yahoo/sketches/hive/cpc/DataToSketchUDAF.java#L107).
The buffer for UPDATE may not support MERGE.
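To illustrate the pattern, here is a hypothetical mode-sensitive evaluator (the names
are made up; the sketches-hive UDAF linked above follows the same idea): the buffer
implementation is chosen from the mode passed to `init`, so a buffer created for
UPDATE cannot accept MERGE input.

```scala
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.{AggregationBuffer, Mode}
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector

abstract class ModeSensitiveEvaluator extends GenericUDAFEvaluator {
  // Hypothetical buffer types for the two incompatible representations.
  class UpdateOnlyBuffer extends AggregationBuffer  // accumulates raw input rows
  class MergeOnlyBuffer extends AggregationBuffer   // accumulates serialized partials

  private var mode: Mode = _

  override def init(m: Mode, parameters: Array[ObjectInspector]): ObjectInspector = {
    mode = m  // remember which mode buffers will be created for
    super.init(m, parameters)
  }

  // PARTIAL1/COMPLETE buffers consume original rows; PARTIAL2/FINAL buffers
  // consume partial aggregation results, and the two cannot be mixed.
  override def getNewAggregationBuffer(): AggregationBuffer = mode match {
    case Mode.PARTIAL1 | Mode.COMPLETE => new UpdateOnlyBuffer
    case Mode.PARTIAL2 | Mode.FINAL    => new MergeOnlyBuffer
  }
}
```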
This PR updates the Hive UDAF adapter in Spark to support INIT -> UPDATE ->
MERGE -> FINISH, by turning it into INIT -> UPDATE -> FINISH + INIT -> MERGE ->
FINISH.
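In code, the heart of the change (simplified from the patch below; the null-buffer
handling is omitted) is a `merge` that re-creates the buffer with the final evaluator
when it has only seen `update`:

```scala
// Simplified from the patch below: if the buffer was built by `update`
// (canDoMerge == false), finish it with the partial1 evaluator and fold the
// partial result into a fresh final-mode buffer before merging the input.
override def merge(buffer: HiveUDAFBuffer, input: HiveUDAFBuffer): HiveUDAFBuffer = {
  val mergeableBuf = if (buffer.canDoMerge) {
    buffer
  } else {
    val newBuf = finalHiveEvaluator.evaluator.getNewAggregationBuffer
    finalHiveEvaluator.evaluator.merge(
      newBuf, partial1HiveEvaluator.evaluator.terminatePartial(buffer.buf))
    HiveUDAFBuffer(newBuf, canDoMerge = true)
  }
  finalHiveEvaluator.evaluator.merge(
    mergeableBuf.buf, partial1HiveEvaluator.evaluator.terminatePartial(input.buf))
  mergeableBuf
}
```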
## How was this patch tested?
A new test case.
Closes #24459 from cloud-fan/hive-udaf.
Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 7432e7ded44cc0014590d229827546f5d8f93868)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../scala/org/apache/spark/sql/hive/hiveUDFs.scala | 54 ++++++++++++++++------
.../spark/sql/hive/execution/HiveUDAFSuite.scala | 38 +++++++++------
2 files changed, 64 insertions(+), 28 deletions(-)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 8ece4b5..d8d9e97 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -303,6 +303,13 @@ private[hive] case class HiveGenericUDTF(
* - `wrap()`/`wrapperFor()`: from 3 to 1
* - `unwrap()`/`unwrapperFor()`: from 1 to 3
* - `GenericUDAFEvaluator.terminatePartial()`: from 2 to 3
+ *
+ * Note that, Hive UDAF is initialized with aggregate mode, and some specific Hive UDAFs can't
+ * mix UPDATE and MERGE actions during its life cycle. However, Spark may do UPDATE on a UDAF and
+ * then do MERGE, in case of hash aggregate falling back to sort aggregate. To work around this
+ * issue, we track the ability to do MERGE in the Hive UDAF aggregate buffer. If Spark does
+ * UPDATE then MERGE, we can detect it and re-create the aggregate buffer with a different
+ * aggregate mode.
*/
private[hive] case class HiveUDAFFunction(
name: String,
@@ -311,7 +318,7 @@ private[hive] case class HiveUDAFFunction(
isUDAFBridgeRequired: Boolean = false,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
- extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer]
+ extends TypedImperativeAggregate[HiveUDAFBuffer]
with HiveInspectors
with UserDefinedExpression {
@@ -397,55 +404,70 @@ private[hive] case class HiveUDAFFunction(
// aggregate buffer. However, the Spark UDAF framework does not expose this information when
// creating the buffer. Here we return null, and create the buffer in `update` and `merge`
// on demand, so that we can know what input we are dealing with.
- override def createAggregationBuffer(): AggregationBuffer = null
+ override def createAggregationBuffer(): HiveUDAFBuffer = null
@transient
private lazy val inputProjection = UnsafeProjection.create(children)
- override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = {
+ override def update(buffer: HiveUDAFBuffer, input: InternalRow): HiveUDAFBuffer = {
// The input is original data, we create buffer with the partial1 evaluator.
val nonNullBuffer = if (buffer == null) {
- partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+ HiveUDAFBuffer(partial1HiveEvaluator.evaluator.getNewAggregationBuffer, false)
} else {
buffer
}
+ assert(!nonNullBuffer.canDoMerge, "can not call `merge` then `update` on a Hive UDAF.")
+
partial1HiveEvaluator.evaluator.iterate(
- nonNullBuffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
+ nonNullBuffer.buf, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
nonNullBuffer
}
- override def merge(buffer: AggregationBuffer, input: AggregationBuffer): AggregationBuffer = {
+ override def merge(buffer: HiveUDAFBuffer, input: HiveUDAFBuffer): HiveUDAFBuffer = {
// The input is aggregate buffer, we create buffer with the final evaluator.
val nonNullBuffer = if (buffer == null) {
- finalHiveEvaluator.evaluator.getNewAggregationBuffer
+ HiveUDAFBuffer(finalHiveEvaluator.evaluator.getNewAggregationBuffer, true)
} else {
buffer
}
+ // It's possible that we've called `update` of this Hive UDAF, and some specific Hive UDAF
+ // implementation can't mix the `update` and `merge` calls during its life cycle. To work
+ // around it, here we create a fresh buffer with final evaluator, and merge the existing buffer
+ // to it, and replace the existing buffer with it.
+ val mergeableBuf = if (!nonNullBuffer.canDoMerge) {
+ val newBuf = finalHiveEvaluator.evaluator.getNewAggregationBuffer
+ finalHiveEvaluator.evaluator.merge(
+ newBuf, partial1HiveEvaluator.evaluator.terminatePartial(nonNullBuffer.buf))
+ HiveUDAFBuffer(newBuf, true)
+ } else {
+ nonNullBuffer
+ }
+
// The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation
// buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
// this `AggregationBuffer`s into this format before shuffling partial aggregation results, and
// calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
finalHiveEvaluator.evaluator.merge(
- nonNullBuffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
- nonNullBuffer
+ mergeableBuf.buf, partial1HiveEvaluator.evaluator.terminatePartial(input.buf))
+ mergeableBuf
}
- override def eval(buffer: AggregationBuffer): Any = {
- resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer))
+ override def eval(buffer: HiveUDAFBuffer): Any = {
+ resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer.buf))
}
- override def serialize(buffer: AggregationBuffer): Array[Byte] = {
+ override def serialize(buffer: HiveUDAFBuffer): Array[Byte] = {
// Serializes an `AggregationBuffer` that holds partial aggregation results so that we can
// shuffle it for global aggregation later.
- aggBufferSerDe.serialize(buffer)
+ aggBufferSerDe.serialize(buffer.buf)
}
- override def deserialize(bytes: Array[Byte]): AggregationBuffer = {
+ override def deserialize(bytes: Array[Byte]): HiveUDAFBuffer = {
// Deserializes an `AggregationBuffer` from the shuffled partial aggregation phase to prepare
// for global aggregation by merging multiple partial aggregation results within a single group.
- aggBufferSerDe.deserialize(bytes)
+ HiveUDAFBuffer(aggBufferSerDe.deserialize(bytes), false)
}
// Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects
@@ -493,3 +515,5 @@ private[hive] case class HiveUDAFFunction(
}
}
}
+
+case class HiveUDAFBuffer(buf: AggregationBuffer, canDoMerge: Boolean)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
index ed82dbd..27aff26 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
@@ -28,10 +28,10 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo
import test.org.apache.spark.sql.MyDoubleAvg
-import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
@@ -93,21 +93,33 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
))
}
- test("customized Hive UDAF with two aggregation buffers") {
- val df = sql("SELECT key % 2, mock2(value) FROM t GROUP BY key % 2")
+ test("SPARK-24935: customized Hive UDAF with two aggregation buffers") {
+ withTempView("v") {
+ spark.range(100).createTempView("v")
+ val df = sql("SELECT id % 2, mock2(id) FROM v GROUP BY id % 2")
- val aggs = df.queryExecution.executedPlan.collect {
- case agg: ObjectHashAggregateExec => agg
- }
+ val aggs = df.queryExecution.executedPlan.collect {
+ case agg: ObjectHashAggregateExec => agg
+ }
- // There should be two aggregate operators, one for partial aggregation, and the other for
- // global aggregation.
- assert(aggs.length == 2)
+ // There should be two aggregate operators, one for partial aggregation, and the other for
+ // global aggregation.
+ assert(aggs.length == 2)
- checkAnswer(df, Seq(
- Row(0, Row(1, 1)),
- Row(1, Row(1, 1))
- ))
+ withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "1") {
+ checkAnswer(df, Seq(
+ Row(0, Row(50, 0)),
+ Row(1, Row(50, 0))
+ ))
+ }
+
+ withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
+ checkAnswer(df, Seq(
+ Row(0, Row(50, 0)),
+ Row(1, Row(50, 0))
+ ))
+ }
+ }
}
test("call JAVA UDAF") {