cboumalh commented on code in PR #51505:
URL: https://github.com/apache/spark/pull/51505#discussion_r2229725214
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala:
##########
@@ -435,3 +507,203 @@ case class ApproxTopKAccumulate(
override def prettyName: String =
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_accumulate")
}
+
+class CombineInternal[T](
+ sketch: ItemsSketch[T],
+ var itemDataType: DataType,
+ var maxItemsTracked: Int) {
+ def getSketch: ItemsSketch[T] = sketch
+
+ def getItemDataType: DataType = itemDataType
+
+ def setItemDataType(dataType: DataType): Unit = {
+ if (this.itemDataType == null) {
+ this.itemDataType = dataType
+ } else if (this.itemDataType != dataType) {
+ throw QueryExecutionErrors.approxTopKSketchTypeNotMatch(this.itemDataType, dataType)
+ }
+ }
+
+ def getMaxItemsTracked: Int = maxItemsTracked
+
+ def setMaxItemsTracked(maxItemsTracked: Int): Unit = this.maxItemsTracked = maxItemsTracked
+}
+
+/**
+ * An aggregate function that combines multiple sketches into a single sketch.
+ *
+ * @param state the expression containing the sketches to combine
+ * @param maxItemsTracked the maximum number of items to track in the sketch
+ * @param mutableAggBufferOffset the offset for mutable aggregation buffer
+ * @param inputAggBufferOffset the offset for input aggregation buffer
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(state, maxItemsTracked) - Combines multiple sketches into a single sketch.
+ `maxItemsTracked` An optional positive INTEGER literal with upper limit of 1000000. If maxItemsTracked is not specified, it defaults to 10000.
+ """,
+ examples = """
+ Examples:
+ > SELECT approx_top_k_estimate(_FUNC_(sketch, 10000), 5) FROM (SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (0), (0), (1), (1) AS tab(expr) UNION ALL SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (2), (3), (4), (4) AS tab(expr));
+ [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}]
+ """,
+ group = "agg_funcs",
+ since = "4.1.0")
+// scalastyle:on line.size.limit
+case class ApproxTopKCombine(
+ state: Expression,
+ maxItemsTracked: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0)
+ extends TypedImperativeAggregate[CombineInternal[Any]]
+ with ImplicitCastInputTypes
+ with BinaryLike[Expression] {
+
+ def this(child: Expression, maxItemsTracked: Expression) = {
+ this(child, maxItemsTracked, 0, 0)
+ ApproxTopK.checkExpressionNotNull(maxItemsTracked, "maxItemsTracked")
+ ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal)
+ }
+
+ def this(child: Expression, maxItemsTracked: Int) = this(child, Literal(maxItemsTracked))
+
+ def this(child: Expression) = this(child, Literal(ApproxTopK.VOID_MAX_ITEMS_TRACKED), 0, 0)
+
+ private lazy val uncheckedItemDataType: DataType =
+ state.dataType.asInstanceOf[StructType](1).dataType
+ private lazy val maxItemsTrackedVal: Int = maxItemsTracked.eval().asInstanceOf[Int]
+ private lazy val combineSizeSpecified: Boolean =
+ maxItemsTrackedVal != ApproxTopK.VOID_MAX_ITEMS_TRACKED
+
+ override def left: Expression = state
+
+ override def right: Expression = maxItemsTracked
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) {
+ defaultCheck
+ } else {
+ val stateCheck = ApproxTopK.checkStateFieldAndType(state)
+ if (stateCheck.isFailure) {
+ stateCheck
+ } else if (!maxItemsTracked.foldable) {
+ TypeCheckFailure("Number of items tracked must be a constant literal")
+ } else {
+ TypeCheckSuccess
+ }
+ }
+ }
+
+ override def dataType: DataType = ApproxTopK.getSketchStateDataType(uncheckedItemDataType)
+
+ override def createAggregationBuffer(): CombineInternal[Any] = {
+ if (combineSizeSpecified) {
+ val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
+ new CombineInternal[Any](
+ new ItemsSketch[Any](maxMapSize),
+ null,
+ maxItemsTrackedVal)
+ } else {
+ new CombineInternal[Any](
+ new ItemsSketch[Any](ApproxTopK.SKETCH_SIZE_PLACEHOLDER),
+ null,
+ ApproxTopK.VOID_MAX_ITEMS_TRACKED)
+ }
+ }
+
+ override def update(buffer: CombineInternal[Any], input: InternalRow): CombineInternal[Any] = {
+ val inputState = state.eval(input).asInstanceOf[InternalRow]
+ val inputSketchBytes = inputState.getBinary(0)
+ val inputMaxItemsTracked = inputState.getInt(2)
+ val typeCode = inputState.getBinary(3)
+ val actualItemDataType = ApproxTopK.bytesToDataType(typeCode)
+ buffer.setItemDataType(actualItemDataType)
+ val inputSketch = ItemsSketch.getInstance(
+ Memory.wrap(inputSketchBytes), ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+ buffer.getSketch.merge(inputSketch)
+ if (!combineSizeSpecified) {
+ buffer.setMaxItemsTracked(inputMaxItemsTracked)
+ }
+ buffer
+ }
+
+ override def merge(buffer: CombineInternal[Any], input: CombineInternal[Any])
+ : CombineInternal[Any] = {
+ if (!combineSizeSpecified) {
+ // check size
+ if (buffer.getMaxItemsTracked == ApproxTopK.VOID_MAX_ITEMS_TRACKED) {
+ // If buffer is a placeholder sketch, set it to the input sketch's max items tracked
+ buffer.setMaxItemsTracked(input.getMaxItemsTracked)
+ }
+ if (buffer.getMaxItemsTracked != input.getMaxItemsTracked) {
+ throw QueryExecutionErrors.approxTopKSketchSizeNotMatch(
+ buffer.getMaxItemsTracked,
+ input.getMaxItemsTracked
+ )
+ }
+ }
+ // check item data type
+ if (buffer.getItemDataType != null && input.getItemDataType != null &&
+ buffer.getItemDataType != input.getItemDataType) {
+ throw QueryExecutionErrors.approxTopKSketchTypeNotMatch(
+ buffer.getItemDataType,
+ input.getItemDataType
+ )
+ } else if (buffer.getItemDataType == null) {
+ // If buffer is a placeholder sketch, set it to the input sketch's item data type
+ buffer.setItemDataType(input.getItemDataType)
+ }
+ buffer.getSketch.merge(input.getSketch)
+ buffer
+ }
+
+ override def eval(buffer: CombineInternal[Any]): Any = {
+ val sketchBytes =
+ buffer.getSketch.toByteArray(ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+ val maxItemsTracked = buffer.getMaxItemsTracked
+ val typeCode = ApproxTopK.dataTypeToBytes(buffer.getItemDataType)
+ InternalRow.apply(sketchBytes, null, maxItemsTracked, typeCode)
+ }
+
+ override def serialize(buffer: CombineInternal[Any]): Array[Byte] = {
Review Comment:
Serialization and deserialization here are a bit concerning. The current implementation relies on brittle byte-level manipulation with no robust error checking or flexibility, and the hardcoded assumptions about byte sizes and type conversions create significant risks for future compatibility and data integrity. I think it may be worth looking into using a ByteBuffer and adding versioning to the sketch buffer, or exploring other approaches!
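
For illustration, a minimal sketch of a versioned layout using `java.nio.ByteBuffer`, reusing the helpers already in this file (`genSketchSerDe`, `dataTypeToBytes`). The `SERDE_VERSION` constant, the method name, and the field order/widths are assumptions on my side, not part of the PR:

```scala
import java.nio.ByteBuffer

// Hypothetical version tag for the serialized buffer layout (not in the PR).
val SERDE_VERSION: Byte = 1

def serializeSketch(buffer: CombineInternal[Any]): Array[Byte] = {
  val sketchBytes = buffer.getSketch.toByteArray(
    ApproxTopK.genSketchSerDe(buffer.getItemDataType))
  val typeCode = ApproxTopK.dataTypeToBytes(buffer.getItemDataType)
  val bb = ByteBuffer.allocate(
    1 +                    // version byte, so the layout can evolve
    4 +                    // maxItemsTracked as a full Int (no narrowing cast)
    4 + typeCode.length +  // length-prefixed type code
    sketchBytes.length)    // sketch payload
  bb.put(SERDE_VERSION)
  bb.putInt(buffer.getMaxItemsTracked)
  bb.putInt(typeCode.length)
  bb.put(typeCode)
  bb.put(sketchBytes)
  bb.array()
}
```

Deserialization could then check the version byte first and fail with a clear error on an unknown layout, instead of silently misreading bytes.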
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala:
##########
@@ -177,6 +177,8 @@ object ApproxTopK {
val DEFAULT_K: Int = 5
val DEFAULT_MAX_ITEMS_TRACKED: Int = 10000
private val MAX_ITEMS_TRACKED_LIMIT: Int = 1000000
+ val VOID_MAX_ITEMS_TRACKED = -1
Review Comment:
Can we document what exactly VOID_MAX_ITEMS_TRACKED and SKETCH_SIZE_PLACEHOLDER mean?
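
For example, scaladoc along these lines could capture it, based on how the two constants are used in `ApproxTopKCombine` (the `SKETCH_SIZE_PLACEHOLDER` value shown is illustrative, since its actual value is not in this hunk):

```scala
/**
 * Sentinel meaning "no maxItemsTracked was supplied to approx_top_k_combine";
 * in that case the combined sketch adopts the size recorded in the input sketches.
 */
val VOID_MAX_ITEMS_TRACKED = -1

/**
 * Map size for the temporary sketch created before the real size is known,
 * i.e. while maxItemsTracked is still VOID_MAX_ITEMS_TRACKED.
 */
val SKETCH_SIZE_PLACEHOLDER = 8 // illustrative value only
```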
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala:
##########
@@ -435,3 +507,203 @@ case class ApproxTopKAccumulate(
override def prettyName: String =
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_accumulate")
}
+
+class CombineInternal[T](
+ sketch: ItemsSketch[T],
+ var itemDataType: DataType,
+ var maxItemsTracked: Int) {
+ def getSketch: ItemsSketch[T] = sketch
+
+ def getItemDataType: DataType = itemDataType
+
+ def setItemDataType(dataType: DataType): Unit = {
+ if (this.itemDataType == null) {
+ this.itemDataType = dataType
+ } else if (this.itemDataType != dataType) {
+ throw QueryExecutionErrors.approxTopKSketchTypeNotMatch(this.itemDataType, dataType)
+ }
+ }
+
+ def getMaxItemsTracked: Int = maxItemsTracked
+
+ def setMaxItemsTracked(maxItemsTracked: Int): Unit = this.maxItemsTracked = maxItemsTracked
+}
+
+/**
+ * An aggregate function that combines multiple sketches into a single sketch.
+ *
+ * @param state the expression containing the sketches to combine
+ * @param maxItemsTracked the maximum number of items to track in the sketch
+ * @param mutableAggBufferOffset the offset for mutable aggregation buffer
+ * @param inputAggBufferOffset the offset for input aggregation buffer
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(state, maxItemsTracked) - Combines multiple sketches into a single sketch.
+ `maxItemsTracked` An optional positive INTEGER literal with upper limit of 1000000. If maxItemsTracked is not specified, it defaults to 10000.
+ """,
+ examples = """
+ Examples:
+ > SELECT approx_top_k_estimate(_FUNC_(sketch, 10000), 5) FROM (SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (0), (0), (1), (1) AS tab(expr) UNION ALL SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (2), (3), (4), (4) AS tab(expr));
+ [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}]
+ """,
+ group = "agg_funcs",
+ since = "4.1.0")
+// scalastyle:on line.size.limit
+case class ApproxTopKCombine(
+ state: Expression,
+ maxItemsTracked: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0)
+ extends TypedImperativeAggregate[CombineInternal[Any]]
+ with ImplicitCastInputTypes
+ with BinaryLike[Expression] {
+
+ def this(child: Expression, maxItemsTracked: Expression) = {
+ this(child, maxItemsTracked, 0, 0)
+ ApproxTopK.checkExpressionNotNull(maxItemsTracked, "maxItemsTracked")
+ ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal)
+ }
+
+ def this(child: Expression, maxItemsTracked: Int) = this(child, Literal(maxItemsTracked))
+
+ def this(child: Expression) = this(child, Literal(ApproxTopK.VOID_MAX_ITEMS_TRACKED), 0, 0)
+
+ private lazy val uncheckedItemDataType: DataType =
+ state.dataType.asInstanceOf[StructType](1).dataType
+ private lazy val maxItemsTrackedVal: Int = maxItemsTracked.eval().asInstanceOf[Int]
+ private lazy val combineSizeSpecified: Boolean =
+ maxItemsTrackedVal != ApproxTopK.VOID_MAX_ITEMS_TRACKED
+
+ override def left: Expression = state
+
+ override def right: Expression = maxItemsTracked
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) {
+ defaultCheck
+ } else {
+ val stateCheck = ApproxTopK.checkStateFieldAndType(state)
+ if (stateCheck.isFailure) {
+ stateCheck
+ } else if (!maxItemsTracked.foldable) {
+ TypeCheckFailure("Number of items tracked must be a constant literal")
+ } else {
+ TypeCheckSuccess
+ }
+ }
+ }
+
+ override def dataType: DataType = ApproxTopK.getSketchStateDataType(uncheckedItemDataType)
+
+ override def createAggregationBuffer(): CombineInternal[Any] = {
+ if (combineSizeSpecified) {
+ val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
+ new CombineInternal[Any](
+ new ItemsSketch[Any](maxMapSize),
+ null,
+ maxItemsTrackedVal)
+ } else {
+ new CombineInternal[Any](
+ new ItemsSketch[Any](ApproxTopK.SKETCH_SIZE_PLACEHOLDER),
+ null,
+ ApproxTopK.VOID_MAX_ITEMS_TRACKED)
+ }
+ }
+
+ override def update(buffer: CombineInternal[Any], input: InternalRow): CombineInternal[Any] = {
+ val inputState = state.eval(input).asInstanceOf[InternalRow]
+ val inputSketchBytes = inputState.getBinary(0)
+ val inputMaxItemsTracked = inputState.getInt(2)
+ val typeCode = inputState.getBinary(3)
+ val actualItemDataType = ApproxTopK.bytesToDataType(typeCode)
+ buffer.setItemDataType(actualItemDataType)
+ val inputSketch = ItemsSketch.getInstance(
+ Memory.wrap(inputSketchBytes), ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+ buffer.getSketch.merge(inputSketch)
+ if (!combineSizeSpecified) {
+ buffer.setMaxItemsTracked(inputMaxItemsTracked)
+ }
+ buffer
+ }
+
+ override def merge(buffer: CombineInternal[Any], input: CombineInternal[Any])
+ : CombineInternal[Any] = {
+ if (!combineSizeSpecified) {
+ // check size
+ if (buffer.getMaxItemsTracked == ApproxTopK.VOID_MAX_ITEMS_TRACKED) {
+ // If buffer is a placeholder sketch, set it to the input sketch's max items tracked
+ buffer.setMaxItemsTracked(input.getMaxItemsTracked)
+ }
+ if (buffer.getMaxItemsTracked != input.getMaxItemsTracked) {
+ throw QueryExecutionErrors.approxTopKSketchSizeNotMatch(
+ buffer.getMaxItemsTracked,
+ input.getMaxItemsTracked
+ )
+ }
+ }
+ // check item data type
+ if (buffer.getItemDataType != null && input.getItemDataType != null &&
+ buffer.getItemDataType != input.getItemDataType) {
+ throw QueryExecutionErrors.approxTopKSketchTypeNotMatch(
+ buffer.getItemDataType,
+ input.getItemDataType
+ )
+ } else if (buffer.getItemDataType == null) {
+ // If buffer is a placeholder sketch, set it to the input sketch's item data type
+ buffer.setItemDataType(input.getItemDataType)
+ }
+ buffer.getSketch.merge(input.getSketch)
+ buffer
+ }
+
+ override def eval(buffer: CombineInternal[Any]): Any = {
+ val sketchBytes =
+ buffer.getSketch.toByteArray(ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+ val maxItemsTracked = buffer.getMaxItemsTracked
+ val typeCode = ApproxTopK.dataTypeToBytes(buffer.getItemDataType)
+ InternalRow.apply(sketchBytes, null, maxItemsTracked, typeCode)
+ }
+
+ override def serialize(buffer: CombineInternal[Any]): Array[Byte] = {
+ val sketchBytes = buffer.getSketch.toByteArray(
+ ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+ val maxItemsTrackedByte = buffer.getMaxItemsTracked.toByte
Review Comment:
I believe this is casting an `Int` to a `Byte`. Could this cause an overflow?
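
For concreteness, a quick Scala REPL check of the concern (the values come from DEFAULT_MAX_ITEMS_TRACKED and MAX_ITEMS_TRACKED_LIMIT shown earlier in this file):

```scala
// Int -> Byte narrowing keeps only the low 8 bits, so any value above 127
// (or below -128) cannot round-trip through a single Byte.
val defaultMax = 10000    // DEFAULT_MAX_ITEMS_TRACKED
val limitMax = 1000000    // MAX_ITEMS_TRACKED_LIMIT

println(defaultMax.toByte)  // prints 16  (10000   & 0xFF)
println(limitMax.toByte)    // prints 64  (1000000 & 0xFF)
```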
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]