gengliangwang commented on code in PR #51505:
URL: https://github.com/apache/spark/pull/51505#discussion_r2430780027


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala:
##########
@@ -435,3 +491,283 @@ case class ApproxTopKAccumulate(
   override def prettyName: String =
     
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_accumulate")
 }
+
+/**
+ * An internal class used as the aggregation buffer for ApproxTopKCombine.
+ *
+ * @param sketch          the ItemsSketch instance
+ * @param itemDataType    the data type of items in the sketch
+ * @param maxItemsTracked the maximum number of items tracked in the sketch
+ */
+class CombineInternal[T](
+    sketch: ItemsSketch[T],
+    var itemDataType: DataType,
+    var maxItemsTracked: Int) {
+  def getSketch: ItemsSketch[T] = sketch
+
+  def getItemDataType: DataType = itemDataType
+
+  def getMaxItemsTracked: Int = maxItemsTracked
+
+  def updateMaxItemsTracked(combineSizeSpecified: Boolean, newMaxItemsTracked: 
Int): Unit = {
+    if (!combineSizeSpecified) {
+      // check size
+      if (this.maxItemsTracked == ApproxTopK.VOID_MAX_ITEMS_TRACKED) {
+        // If buffer's maxItemsTracked is VOID_MAX_ITEMS_TRACKED, it means the 
buffer is a placeholder
+        // sketch that has not been updated by any input sketch yet.
+        // So we can set it to the input sketch's max items tracked.
+        this.maxItemsTracked = newMaxItemsTracked
+      } else {
+        if (this.maxItemsTracked != newMaxItemsTracked) {
+          // If buffer's maxItemsTracked is not VOID_MAX_ITEMS_TRACKED, it 
means the buffer has been
+          // updated by some input sketch. So if buffer and input sketch have 
different
+          // maxItemsTracked values, it means at least two of the input 
sketches have different
+          // maxItemsTracked values. In this case, we should throw an error.
+          throw QueryExecutionErrors.approxTopKSketchSizeNotMatch(
+            this.maxItemsTracked, newMaxItemsTracked)
+        }
+      }
+    }
+  }
+
+  def updateItemDataType(inputItemDataType: DataType): Unit = {
+    // When the buffer's dataType hasn't been set, set it to the input 
sketch's item data type
+    // When input sketch's item data type is null, buffer's item data type 
will remain null
+    if (this.itemDataType == null) {
+      this.itemDataType = inputItemDataType
+    } else {
+      // When the buffer's dataType has been set, throw an error
+      // if the input sketch's item data type is not null and the two data types 
don't match
+      if (inputItemDataType != null && this.itemDataType != inputItemDataType) 
{
+        throw QueryExecutionErrors.approxTopKSketchTypeNotMatch(
+          this.itemDataType, inputItemDataType)
+      }
+    }
+  }
+
+  /**
+   * Serialize the CombineInternal instance to a byte array.
+   * Serialization format:
+   *     maxItemsTracked (4 bytes int) +
+   *     itemDataTypeDDL length n in bytes (4 bytes int) +
+   *     itemDataTypeDDL (n bytes) +
+   *     sketchBytes
+   */
+  def serialize(): Array[Byte] = {
+    val sketchBytes = sketch.toByteArray(
+      
ApproxTopK.genSketchSerDe(itemDataType).asInstanceOf[ArrayOfItemsSerDe[T]])
+    val itemDataTypeDDL = ApproxTopK.dataTypeToDDL(itemDataType)
+    val ddlBytes: Array[Byte] = 
itemDataTypeDDL.getBytes(StandardCharsets.UTF_8)
+    val byteArray = new Array[Byte](sketchBytes.length + 4 + 4 + 
ddlBytes.length)
+
+    val byteBuffer = ByteBuffer.wrap(byteArray)
+    byteBuffer.putInt(maxItemsTracked)
+    byteBuffer.putInt(ddlBytes.length)
+    byteBuffer.put(ddlBytes)
+    byteBuffer.put(sketchBytes)
+    byteArray
+  }
+}
+
+object CombineInternal {
+  /**
+   * Deserialize a byte array to a CombineInternal instance.
+   * Serialization format:
+   *     maxItemsTracked (4 bytes int) +
+   *     itemDataTypeDDL length n in bytes (4 bytes int) +
+   *     itemDataTypeDDL (n bytes) +
+   *     sketchBytes
+   */
+  def deserialize(buffer: Array[Byte]): CombineInternal[Any] = {
+    val byteBuffer = ByteBuffer.wrap(buffer)
+    // read maxItemsTracked
+    val maxItemsTracked = byteBuffer.getInt
+    // read itemDataTypeDDL
+    val ddlLength = byteBuffer.getInt
+    val ddlBytes = new Array[Byte](ddlLength)
+    byteBuffer.get(ddlBytes)
+    val itemDataTypeDDL = new String(ddlBytes, StandardCharsets.UTF_8)
+    val itemDataType = ApproxTopK.DDLToDataType(itemDataTypeDDL)
+    // read sketchBytes
+    val sketchBytes = new Array[Byte](buffer.length - 4 - 4 - ddlLength)
+    byteBuffer.get(sketchBytes)
+    val sketch = ItemsSketch.getInstance(
+      Memory.wrap(sketchBytes), ApproxTopK.genSketchSerDe(itemDataType))
+    new CombineInternal[Any](sketch, itemDataType, maxItemsTracked)
+  }
+}
+
+/**
+ * An aggregate function that combines multiple sketches into a single sketch.
+ *
+ * @param state                  the expression containing the sketches to 
combine
+ * @param maxItemsTracked        the maximum number of items to track in the 
sketch
+ * @param mutableAggBufferOffset the offset for mutable aggregation buffer
+ * @param inputAggBufferOffset   the offset for input aggregation buffer
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(state, maxItemsTracked) - Combines multiple sketches into a single 
sketch.
+      `maxItemsTracked` An optional positive INTEGER literal with upper limit 
of 1000000. If maxItemsTracked is specified, it will be set for the combined 
sketch. If maxItemsTracked is not specified, the input sketches must have the 
same maxItemsTracked value, otherwise an error will be thrown. The output 
sketch will use the same value from the input sketches.
+  """,
+  examples = """
+    Examples:
+      > SELECT approx_top_k_estimate(_FUNC_(sketch, 10000), 5) FROM (SELECT 
approx_top_k_accumulate(expr) AS sketch FROM VALUES (0), (0), (1), (1) AS 
tab(expr) UNION ALL SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES 
(2), (3), (4), (4) AS tab(expr));
+       
[{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}]
+  """,
+  group = "agg_funcs",
+  since = "4.1.0")
+// scalastyle:on line.size.limit
+case class ApproxTopKCombine(
+    state: Expression,
+    maxItemsTracked: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends TypedImperativeAggregate[CombineInternal[Any]]
+  with ImplicitCastInputTypes
+  with BinaryLike[Expression] {
+
+  def this(child: Expression, maxItemsTracked: Expression) = {
+    this(child, maxItemsTracked, 0, 0)
+    ApproxTopK.checkExpressionNotNull(maxItemsTracked, "maxItemsTracked")
+    ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal)
+  }
+
+  def this(child: Expression, maxItemsTracked: Int) = this(child, 
Literal(maxItemsTracked))
+
+  // If maxItemsTracked is not specified, set it to VOID_MAX_ITEMS_TRACKED.
+  // This indicates that there is no explicit maxItemsTracked input from the 
function call.
+  // Hence, function needs to check the input sketches' maxItemsTracked values 
during merge.
+  def this(child: Expression) = this(child, 
Literal(ApproxTopK.VOID_MAX_ITEMS_TRACKED), 0, 0)
+
+  // The item data type extracted from the third field of the state struct.
+  // It is named "unchecked" because it may be inaccurate when input sketches 
have different
+  // item data types. For example, if one sketch has int type null and another 
has string type
+  // null, the union of the two sketches will have bigint type null.
+  // The accurate item data type will be tracked in the aggregation buffer 
during update/merge.
+  // It is okay to use uncheckedItemDataType to create the output data type of 
this function,
+  // because if the input sketches have different item data types, an error 
will be thrown
+  // during update/merge. Otherwise, the uncheckedItemDataType is accurate.
+  private lazy val uncheckedItemDataType: DataType =
+    state.dataType.asInstanceOf[StructType](2).dataType
+  private lazy val maxItemsTrackedVal: Int = 
maxItemsTracked.eval().asInstanceOf[Int]
+  private lazy val combineSizeSpecified: Boolean =
+    maxItemsTrackedVal != ApproxTopK.VOID_MAX_ITEMS_TRACKED
+
+  override def left: Expression = state
+
+  override def right: Expression = maxItemsTracked
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val defaultCheck = super.checkInputDataTypes()
+    if (defaultCheck.isFailure) {
+      defaultCheck
+    } else {
+      val stateCheck = ApproxTopK.checkStateFieldAndType(state)
+      if (stateCheck.isFailure) {
+        stateCheck
+      } else if (!maxItemsTracked.foldable) {
+        TypeCheckFailure("Number of items tracked must be a constant literal")
+      } else {
+        TypeCheckSuccess
+      }
+    }
+  }
+
+  override def dataType: DataType = 
ApproxTopK.getSketchStateDataType(uncheckedItemDataType)
+
+  /**
+   * If maxItemsTracked is specified in function call, use it for the output 
sketch.
+   * Otherwise, create a placeholder sketch with VOID_MAX_ITEMS_TRACKED. The 
actual value will be
+   * decided during the first update.
+   */
+  override def createAggregationBuffer(): CombineInternal[Any] = {
+    if (combineSizeSpecified) {
+      val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
+      new CombineInternal[Any](
+        new ItemsSketch[Any](maxMapSize),
+        null,
+        maxItemsTrackedVal)
+    } else {
+      // If maxItemsTracked is not specified, create a sketch with the maximum 
allowed size.
+      // No need to worry about memory waste, as the sketch always grows from 
a small init size.
+      // The actual maxItemsTracked will be checked during the updates.
+      val maxMapSize = 
ApproxTopK.calMaxMapSize(ApproxTopK.MAX_ITEMS_TRACKED_LIMIT)
+      new CombineInternal[Any](
+        new ItemsSketch[Any](maxMapSize),
+        null,
+        ApproxTopK.VOID_MAX_ITEMS_TRACKED)
+    }
+  }
+
+  /**
+   * Update the aggregation buffer with an input sketch. The input has the 
same schema as the
+   * ApproxTopKAccumulate output, i.e., sketchBytes + maxItemsTracked + null + 
DDL.
+   */
+  override def update(buffer: CombineInternal[Any], input: InternalRow): 
CombineInternal[Any] = {
+    val inputState = state.eval(input).asInstanceOf[InternalRow]
+    val inputSketchBytes = inputState.getBinary(0)
+    val inputMaxItemsTracked = inputState.getInt(1)
+    val inputItemDataTypeDDL = inputState.getUTF8String(3).toString
+    val inputItemDataType = ApproxTopK.DDLToDataType(inputItemDataTypeDDL)
+    // update maxItemsTracked (throw error if not match)
+    buffer.updateMaxItemsTracked(combineSizeSpecified, inputMaxItemsTracked)
+    // update itemDataType (throw error if not match)
+    buffer.updateItemDataType(inputItemDataType)
+    // update sketch
+    val inputSketch = ItemsSketch.getInstance(
+      Memory.wrap(inputSketchBytes), 
ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+    buffer.getSketch.merge(inputSketch)
+    buffer
+  }
+
+  override def merge(buffer: CombineInternal[Any], input: CombineInternal[Any])

Review Comment:
   ```suggestion
     override def merge(
         buffer: CombineInternal[Any],
         input: CombineInternal[Any]): CombineInternal[Any] = {```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to