gengliangwang commented on code in PR #51505:
URL: https://github.com/apache/spark/pull/51505#discussion_r2417560164


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala:
##########
@@ -435,3 +507,203 @@ case class ApproxTopKAccumulate(
   override def prettyName: String =
     getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_accumulate")
 }
+
+class CombineInternal[T](
+    sketch: ItemsSketch[T],
+    var itemDataType: DataType,
+    var maxItemsTracked: Int) {
+  def getSketch: ItemsSketch[T] = sketch
+
+  def getItemDataType: DataType = itemDataType
+
+  def setItemDataType(dataType: DataType): Unit = {
+    if (this.itemDataType == null) {
+      this.itemDataType = dataType
+    } else if (this.itemDataType != dataType) {
+      throw QueryExecutionErrors.approxTopKSketchTypeNotMatch(this.itemDataType, dataType)
+    }
+  }
+
+  def getMaxItemsTracked: Int = maxItemsTracked
+
+  def setMaxItemsTracked(maxItemsTracked: Int): Unit = this.maxItemsTracked = maxItemsTracked
+}
+
+/**
+ * An aggregate function that combines multiple sketches into a single sketch.
+ *
+ * @param state                  the expression containing the sketches to combine
+ * @param maxItemsTracked        the maximum number of items to track in the sketch
+ * @param mutableAggBufferOffset the offset for mutable aggregation buffer
+ * @param inputAggBufferOffset   the offset for input aggregation buffer
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(state, maxItemsTracked) - Combines multiple sketches into a single sketch.
+      `maxItemsTracked` An optional positive INTEGER literal with upper limit of 1000000. If maxItemsTracked is not specified, it defaults to 10000.
+  """,
+  examples = """
+    Examples:
+      > SELECT approx_top_k_estimate(_FUNC_(sketch, 10000), 5) FROM (SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (0), (0), (1), (1) AS tab(expr) UNION ALL SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (2), (3), (4), (4) AS tab(expr));
+       [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}]
+  """,
+  group = "agg_funcs",
+  since = "4.1.0")
+// scalastyle:on line.size.limit
+case class ApproxTopKCombine(
+    state: Expression,
+    maxItemsTracked: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends TypedImperativeAggregate[CombineInternal[Any]]
+  with ImplicitCastInputTypes
+  with BinaryLike[Expression] {
+
+  def this(child: Expression, maxItemsTracked: Expression) = {
+    this(child, maxItemsTracked, 0, 0)
+    ApproxTopK.checkExpressionNotNull(maxItemsTracked, "maxItemsTracked")
+    ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal)
+  }
+
+  def this(child: Expression, maxItemsTracked: Int) = this(child, Literal(maxItemsTracked))
+
+  def this(child: Expression) = this(child, Literal(ApproxTopK.VOID_MAX_ITEMS_TRACKED), 0, 0)
+
+  private lazy val uncheckedItemDataType: DataType =
+    state.dataType.asInstanceOf[StructType](1).dataType
+  private lazy val maxItemsTrackedVal: Int = maxItemsTracked.eval().asInstanceOf[Int]
+  private lazy val combineSizeSpecified: Boolean =
+    maxItemsTrackedVal != ApproxTopK.VOID_MAX_ITEMS_TRACKED
+
+  override def left: Expression = state
+
+  override def right: Expression = maxItemsTracked
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val defaultCheck = super.checkInputDataTypes()
+    if (defaultCheck.isFailure) {
+      defaultCheck
+    } else {
+      val stateCheck = ApproxTopK.checkStateFieldAndType(state)
+      if (stateCheck.isFailure) {
+        stateCheck
+      } else if (!maxItemsTracked.foldable) {
+        TypeCheckFailure("Number of items tracked must be a constant literal")
+      } else {
+        TypeCheckSuccess
+      }
+    }
+  }
+
+  override def dataType: DataType = ApproxTopK.getSketchStateDataType(uncheckedItemDataType)
+
+  override def createAggregationBuffer(): CombineInternal[Any] = {
+    if (combineSizeSpecified) {
+      val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
+      new CombineInternal[Any](
+        new ItemsSketch[Any](maxMapSize),
+        null,
+        maxItemsTrackedVal)
+    } else {
+      new CombineInternal[Any](
+        new ItemsSketch[Any](ApproxTopK.SKETCH_SIZE_PLACEHOLDER),
+        null,
+        ApproxTopK.VOID_MAX_ITEMS_TRACKED)

Review Comment:
   Shall we simply set it as `MAX_ITEMS_TRACKED_LIMIT: Int = 1000000` here?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to