gengliangwang commented on code in PR #51505:
URL: https://github.com/apache/spark/pull/51505#discussion_r2417562983
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala:
##########
@@ -435,3 +507,203 @@ case class ApproxTopKAccumulate(
override def prettyName: String =
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_accumulate")
}
+
/**
 * Aggregation buffer used by [[ApproxTopKCombine]]. It bundles the running [[ItemsSketch]] with
 * the data type of the sketched items and the `maxItemsTracked` value recorded for the buffer.
 *
 * @param sketch          the sketch that input sketches are merged into
 * @param itemDataType    the data type of the sketched items, or `null` while not yet known
 * @param maxItemsTracked the `maxItemsTracked` value associated with this buffer
 * @tparam T              the item type held by the sketch
 */
class CombineInternal[T](
    sketch: ItemsSketch[T],
    var itemDataType: DataType,
    var maxItemsTracked: Int) {

  def getSketch: ItemsSketch[T] = sketch

  def getItemDataType: DataType = itemDataType

  /**
   * Records the item data type. The first call on a buffer whose type is still `null` stores
   * `dataType`. Later calls must pass the same type, otherwise
   * `approxTopKSketchTypeNotMatch(recorded, dataType)` is thrown.
   */
  def setItemDataType(dataType: DataType): Unit = itemDataType match {
    case null =>
      itemDataType = dataType
    case recorded if recorded != dataType =>
      throw QueryExecutionErrors.approxTopKSketchTypeNotMatch(recorded, dataType)
    case _ =>
      // Same type as the one already recorded: nothing to do.
  }

  def getMaxItemsTracked: Int = maxItemsTracked

  /** Overwrites the recorded `maxItemsTracked` without any validation. */
  def setMaxItemsTracked(maxItemsTracked: Int): Unit = {
    this.maxItemsTracked = maxItemsTracked
  }
}
+
+/**
+ * An aggregate function that combines multiple sketches into a single sketch.
+ *
+ * @param state the expression containing the sketches to combine
+ * @param maxItemsTracked the maximum number of items to track in the sketch
+ * @param mutableAggBufferOffset the offset for mutable aggregation buffer
+ * @param inputAggBufferOffset the offset for input aggregation buffer
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(state, maxItemsTracked) - Combines multiple sketches into a single
sketch.
+ `maxItemsTracked` An optional positive INTEGER literal with upper limit
of 1000000. If maxItemsTracked is not specified, it defaults to 10000.
+ """,
+ examples = """
+ Examples:
+ > SELECT approx_top_k_estimate(_FUNC_(sketch, 10000), 5) FROM (SELECT
approx_top_k_accumulate(expr) AS sketch FROM VALUES (0), (0), (1), (1) AS
tab(expr) UNION ALL SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES
(2), (3), (4), (4) AS tab(expr));
+
[{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}]
+ """,
+ group = "agg_funcs",
+ since = "4.1.0")
+// scalastyle:on line.size.limit
+case class ApproxTopKCombine(
+ state: Expression,
+ maxItemsTracked: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0)
+ extends TypedImperativeAggregate[CombineInternal[Any]]
+ with ImplicitCastInputTypes
+ with BinaryLike[Expression] {
+
+ def this(child: Expression, maxItemsTracked: Expression) = {
+ this(child, maxItemsTracked, 0, 0)
+ ApproxTopK.checkExpressionNotNull(maxItemsTracked, "maxItemsTracked")
+ ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal)
+ }
+
+ def this(child: Expression, maxItemsTracked: Int) = this(child,
Literal(maxItemsTracked))
+
+ def this(child: Expression) = this(child,
Literal(ApproxTopK.VOID_MAX_ITEMS_TRACKED), 0, 0)
+
+ private lazy val uncheckedItemDataType: DataType =
+ state.dataType.asInstanceOf[StructType](1).dataType
+ private lazy val maxItemsTrackedVal: Int =
maxItemsTracked.eval().asInstanceOf[Int]
+ private lazy val combineSizeSpecified: Boolean =
+ maxItemsTrackedVal != ApproxTopK.VOID_MAX_ITEMS_TRACKED
+
+ override def left: Expression = state
+
+ override def right: Expression = maxItemsTracked
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) {
+ defaultCheck
+ } else {
+ val stateCheck = ApproxTopK.checkStateFieldAndType(state)
+ if (stateCheck.isFailure) {
+ stateCheck
+ } else if (!maxItemsTracked.foldable) {
+ TypeCheckFailure("Number of items tracked must be a constant literal")
+ } else {
+ TypeCheckSuccess
+ }
+ }
+ }
+
+ override def dataType: DataType =
ApproxTopK.getSketchStateDataType(uncheckedItemDataType)
+
+ override def createAggregationBuffer(): CombineInternal[Any] = {
+ if (combineSizeSpecified) {
+ val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
+ new CombineInternal[Any](
+ new ItemsSketch[Any](maxMapSize),
+ null,
+ maxItemsTrackedVal)
+ } else {
+ new CombineInternal[Any](
+ new ItemsSketch[Any](ApproxTopK.SKETCH_SIZE_PLACEHOLDER),
+ null,
+ ApproxTopK.VOID_MAX_ITEMS_TRACKED)
+ }
+ }
+
+ override def update(buffer: CombineInternal[Any], input: InternalRow):
CombineInternal[Any] = {
+ val inputState = state.eval(input).asInstanceOf[InternalRow]
+ val inputSketchBytes = inputState.getBinary(0)
+ val inputMaxItemsTracked = inputState.getInt(2)
+ val typeCode = inputState.getBinary(3)
+ val actualItemDataType = ApproxTopK.bytesToDataType(typeCode)
+ buffer.setItemDataType(actualItemDataType)
+ val inputSketch = ItemsSketch.getInstance(
+ Memory.wrap(inputSketchBytes),
ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+ buffer.getSketch.merge(inputSketch)
+ if (!combineSizeSpecified) {
+ buffer.setMaxItemsTracked(inputMaxItemsTracked)
+ }
+ buffer
+ }
+
+ override def merge(buffer: CombineInternal[Any], input: CombineInternal[Any])
+ : CombineInternal[Any] = {
+ if (!combineSizeSpecified) {
+ // check size
+ if (buffer.getMaxItemsTracked == ApproxTopK.VOID_MAX_ITEMS_TRACKED) {
+ // If buffer is a placeholder sketch, set it to the input sketch's max
items tracked
+ buffer.setMaxItemsTracked(input.getMaxItemsTracked)
+ }
+ if (buffer.getMaxItemsTracked != input.getMaxItemsTracked) {
+ throw QueryExecutionErrors.approxTopKSketchSizeNotMatch(
+ buffer.getMaxItemsTracked,
+ input.getMaxItemsTracked
+ )
+ }
+ }
+ // check item data type
+ if (buffer.getItemDataType != null && input.getItemDataType != null &&
+ buffer.getItemDataType != input.getItemDataType) {
+ throw QueryExecutionErrors.approxTopKSketchTypeNotMatch(
+ buffer.getItemDataType,
+ input.getItemDataType
+ )
+ } else if (buffer.getItemDataType == null) {
+ // If buffer is a placeholder sketch, set it to the input sketch's item
data type
+ buffer.setItemDataType(input.getItemDataType)
+ }
+ buffer.getSketch.merge(input.getSketch)
+ buffer
+ }
+
+ override def eval(buffer: CombineInternal[Any]): Any = {
+ val sketchBytes =
+
buffer.getSketch.toByteArray(ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+ val maxItemsTracked = buffer.getMaxItemsTracked
+ val typeCode = ApproxTopK.dataTypeToBytes(buffer.getItemDataType)
+ InternalRow.apply(sketchBytes, null, maxItemsTracked, typeCode)
+ }
+
+ override def serialize(buffer: CombineInternal[Any]): Array[Byte] = {
+ val sketchBytes = buffer.getSketch.toByteArray(
+ ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+ val maxItemsTrackedByte = buffer.getMaxItemsTracked.toByte
Review Comment:
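Nice catch. This will cause an overflow: `toByte` keeps only the low 8 bits, so any `maxItemsTracked` above 127 is corrupted when the buffer is serialized. For example, the default of 10000 becomes 16, and the documented upper limit is 1000000.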
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]