cboumalh commented on code in PR #51393:
URL: https://github.com/apache/spark/pull/51393#discussion_r2224345479
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala:
##########
@@ -305,4 +318,341 @@ object ApproxTopK {
new ArrayOfDecimalsSerDe(dt).asInstanceOf[ArrayOfItemsSerDe[Any]]
}
}
+
+ def getSketchStateDataType(itemDataType: DataType): StructType =
+ StructType(
+ StructField("Sketch", BinaryType, nullable = false) ::
+ StructField("ItemTypeNull", itemDataType) ::
+ StructField("MaxItemsTracked", IntegerType, nullable = false) ::
+ StructField("TypeCode", BinaryType, nullable = false) :: Nil)
+
+ def dataTypeToBytes(dataType: DataType): Array[Byte] = {
+ dataType match {
+ case _: BooleanType => Array(0, 0, 0)
+ case _: ByteType => Array(1, 0, 0)
+ case _: ShortType => Array(2, 0, 0)
+ case _: IntegerType => Array(3, 0, 0)
+ case _: LongType => Array(4, 0, 0)
+ case _: FloatType => Array(5, 0, 0)
+ case _: DoubleType => Array(6, 0, 0)
+ case _: DateType => Array(7, 0, 0)
+ case _: TimestampType => Array(8, 0, 0)
+ case _: TimestampNTZType => Array(9, 0, 0)
+ case _: StringType => Array(10, 0, 0)
+ case dt: DecimalType => Array(11, dt.precision.toByte, dt.scale.toByte)
+ }
+ }
+
+ def bytesToDataType(bytes: Array[Byte]): DataType = {
+ bytes(0) match {
+ case 0 => BooleanType
+ case 1 => ByteType
+ case 2 => ShortType
+ case 3 => IntegerType
+ case 4 => LongType
+ case 5 => FloatType
+ case 6 => DoubleType
+ case 7 => DateType
+ case 8 => TimestampType
+ case 9 => TimestampNTZType
+ case 10 => StringType
+ case 11 => DecimalType(bytes(1).toInt, bytes(2).toInt)
+ }
+ }
+}
+
/**
 * An aggregate function that accumulates items into a sketch, which can then be used
 * to combine with other sketches, via ApproxTopKCombine,
 * or to estimate the top K items, via ApproxTopKEstimate.
 *
 * The output of this function is a struct containing the sketch in binary format,
 * a null object indicating the type of items in the sketch,
 * and the maximum number of items tracked by the sketch.
 *
 * @param expr the child expression to accumulate items from
 * @param maxItemsTracked the maximum number of items to track in the sketch
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage =
    """
    _FUNC_(expr, maxItemsTracked) - Accumulates items into a sketch.
      `maxItemsTracked` An optional positive INTEGER literal with upper limit of 1000000. If maxItemsTracked is not specified, it defaults to 10000.
    """,
  examples =
    """
    Examples:
      > SELECT approx_top_k_estimate(_FUNC_(expr)) FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) AS tab(expr);
       [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}]

      > SELECT approx_top_k_estimate(_FUNC_(expr, 100), 2) FROM VALUES 'a', 'b', 'c', 'c', 'c', 'c', 'd', 'd' AS tab(expr);
       [{"item":"c","count":4},{"item":"d","count":2}]
    """,
  group = "agg_funcs",
  since = "4.1.0")
// scalastyle:on line.size.limit
case class ApproxTopKAccumulate(
    expr: Expression,
    maxItemsTracked: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0)
  extends TypedImperativeAggregate[ItemsSketch[Any]]
  with ImplicitCastInputTypes
  with BinaryLike[Expression] {

  def this(child: Expression, maxItemsTracked: Expression) =
    this(child, maxItemsTracked, 0, 0)

  def this(child: Expression, maxItemsTracked: Int) =
    this(child, Literal(maxItemsTracked), 0, 0)

  def this(child: Expression) =
    this(child, Literal(ApproxTopK.DEFAULT_MAX_ITEMS_TRACKED), 0, 0)

  // Item type comes from the child expression; lazy because expr.dataType is
  // only meaningful once the plan is resolved.
  private lazy val itemDataType: DataType = expr.dataType

  // Validated once on first use: maxItemsTracked must be a non-null literal
  // within the supported range.
  private lazy val maxItemsTrackedVal: Int = {
    ApproxTopK.checkExpressionNotNull(maxItemsTracked, "maxItemsTracked")
    val maxItemsTrackedVal = maxItemsTracked.eval().asInstanceOf[Int]
    ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal)
    maxItemsTrackedVal
  }

  override def left: Expression = expr

  override def right: Expression = maxItemsTracked

  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType)

  override def checkInputDataTypes(): TypeCheckResult = {
    val defaultCheck = super.checkInputDataTypes()
    if (defaultCheck.isFailure) {
      defaultCheck
    } else if (!ApproxTopK.isDataTypeSupported(itemDataType)) {
      // s-interpolator instead of f: there are no printf-style format specifiers.
      TypeCheckFailure(s"${itemDataType.typeName} columns are not supported")
    } else if (!maxItemsTracked.foldable) {
      TypeCheckFailure("Number of items tracked must be a constant literal")
    } else {
      TypeCheckSuccess
    }
  }

  override def dataType: DataType = ApproxTopK.getSketchStateDataType(itemDataType)

  override def createAggregationBuffer(): ItemsSketch[Any] = {
    val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
    ApproxTopK.createAggregationBuffer(expr, maxMapSize)
  }

  override def update(buffer: ItemsSketch[Any], input: InternalRow): ItemsSketch[Any] =
    ApproxTopK.updateSketchBuffer(expr, buffer, input)

  override def merge(buffer: ItemsSketch[Any], input: ItemsSketch[Any]): ItemsSketch[Any] =
    buffer.merge(input)

  // Emits the sketch-state struct: (sketch bytes, typed null, max items, type code).
  override def eval(buffer: ItemsSketch[Any]): Any = {
    val sketchBytes = serialize(buffer)
    val typeCode = ApproxTopK.dataTypeToBytes(itemDataType)
    InternalRow.apply(sketchBytes, null, maxItemsTrackedVal, typeCode)
  }

  override def serialize(buffer: ItemsSketch[Any]): Array[Byte] =
    buffer.toByteArray(ApproxTopK.genSketchSerDe(itemDataType))

  override def deserialize(storageFormat: Array[Byte]): ItemsSketch[Any] =
    ItemsSketch.getInstance(Memory.wrap(storageFormat), ApproxTopK.genSketchSerDe(itemDataType))

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override protected def withNewChildrenInternal(
      newLeft: Expression,
      newRight: Expression): Expression =
    copy(expr = newLeft, maxItemsTracked = newRight)

  // The output struct itself is never null; nulls appear only inside it.
  override def nullable: Boolean = false

  override def prettyName: String =
    getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_accumulate")
}
+
/**
 * Mutable holder for the combine-side aggregation state: the sketch being merged
 * into, plus the item type and max-items setting discovered from the inputs.
 * Both `itemDataType` and `maxItemsTracked` may start as placeholders (null /
 * sentinel) and are filled in as input sketches arrive.
 */
class CombineInternal[T](
    sketch: ItemsSketch[T],
    var itemDataType: DataType,
    var maxItemsTracked: Int) {

  def getSketch: ItemsSketch[T] = sketch

  def getItemDataType: DataType = itemDataType

  /**
   * Records the item type on first sight; afterwards every input must carry the
   * same type, otherwise combining sketches of different types is rejected.
   */
  def setItemDataType(dataType: DataType): Unit = {
    itemDataType match {
      case null =>
        // First sketch seen: adopt its item type.
        itemDataType = dataType
      case known if known != dataType =>
        throw new SparkUnsupportedOperationException(
          errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED",
          messageParameters = Map(
            "type1" -> known.typeName,
            "type2" -> dataType.typeName))
      case _ =>
        // Same type as before; nothing to do.
    }
  }

  def getMaxItemsTracked: Int = maxItemsTracked

  def setMaxItemsTracked(maxItemsTracked: Int): Unit = this.maxItemsTracked = maxItemsTracked
}
+
+case class ApproxTopKCombine(
+ expr: Expression,
+ maxItemsTracked: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0)
+ extends TypedImperativeAggregate[CombineInternal[Any]]
+ with ImplicitCastInputTypes
+ with BinaryLike[Expression] {
+
  // Eagerly validates maxItemsTracked when it is supplied explicitly: it must be
  // a non-null expression within the supported range (forcing maxItemsTrackedVal
  // evaluates it). NOTE(review): the primary constructor performs no such check —
  // confirm all construction paths go through an auxiliary constructor.
  def this(child: Expression, maxItemsTracked: Expression) = {
    this(child, maxItemsTracked, 0, 0)
    ApproxTopK.checkExpressionNotNull(maxItemsTracked, "maxItemsTracked")
    ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal)
  }

  def this(child: Expression, maxItemsTracked: Int) = this(child, Literal(maxItemsTracked))

  // With no size given, the VOID sentinel means "inherit the size from the inputs".
  def this(child: Expression) = this(child, Literal(ApproxTopK.VOID_MAX_ITEMS_TRACKED), 0, 0)
+
  // Item type advertised by the input sketch-state struct (its "ItemTypeNull"
  // field). "Unchecked" because at analysis time this declared type is trusted;
  // the actual type code inside each row is verified during update().
  private lazy val uncheckedItemDataType: DataType =
    expr.dataType.asInstanceOf[StructType]("ItemTypeNull").dataType
  // NOTE(review): eval() is assumed to be non-null here; only the two-arg
  // auxiliary constructor validates that — confirm for other construction paths.
  private lazy val maxItemsTrackedVal: Int = maxItemsTracked.eval().asInstanceOf[Int]
  // True when the caller supplied an explicit size (i.e. not the VOID sentinel).
  private lazy val combineSizeSpecified: Boolean =
    maxItemsTrackedVal != ApproxTopK.VOID_MAX_ITEMS_TRACKED
+
  override def left: Expression = expr

  override def right: Expression = maxItemsTracked

  // First input is the sketch-state struct, second the max-items integer.
  override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) {
+ defaultCheck
+ } else if (!maxItemsTracked.foldable) {
+ TypeCheckFailure("Number of items tracked must be a constant literal")
+ } else {
+ TypeCheckSuccess
+ }
+ }
+
  // Output has the same sketch-state struct shape as the accumulate side.
  override def dataType: DataType = ApproxTopK.getSketchStateDataType(uncheckedItemDataType)
+
+ override def createAggregationBuffer(): CombineInternal[Any] = {
+ if (combineSizeSpecified) {
+ val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
+ new CombineInternal[Any](
+ new ItemsSketch[Any](maxMapSize),
+ null,
+ maxItemsTrackedVal)
+ } else {
+ new CombineInternal[Any](
+ new ItemsSketch[Any](ApproxTopK.SKETCH_SIZE_PLACEHOLDER),
+ null,
+ ApproxTopK.VOID_MAX_ITEMS_TRACKED)
+ }
+ }
+
+ override def update(buffer: CombineInternal[Any], input: InternalRow):
CombineInternal[Any] = {
+ val inputSketchBytes =
expr.eval(input).asInstanceOf[InternalRow].getBinary(0)
+ val inputMaxItemsTracked =
expr.eval(input).asInstanceOf[InternalRow].getInt(2)
+ val typeCode = expr.eval(input).asInstanceOf[InternalRow].getBinary(3)
+ val actualItemDataType = ApproxTopK.bytesToDataType(typeCode)
+ buffer.setItemDataType(actualItemDataType)
+ val inputSketch = ItemsSketch.getInstance(
+ Memory.wrap(inputSketchBytes),
ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+ buffer.getSketch.merge(inputSketch)
+ if (!combineSizeSpecified) {
+ buffer.setMaxItemsTracked(inputMaxItemsTracked)
+ }
+ buffer
+ }
+
  // Merges two partial combine buffers. Order matters: placeholder values
  // (VOID size / null item type) are adopted from `input` before mismatch checks,
  // so a fresh buffer merges cleanly with any well-formed input.
  override def merge(buffer: CombineInternal[Any], input: CombineInternal[Any])
    : CombineInternal[Any] = {
    if (!combineSizeSpecified) {
      // check size
      if (buffer.getMaxItemsTracked == ApproxTopK.VOID_MAX_ITEMS_TRACKED) {
        // If buffer is a placeholder sketch, set it to the input sketch's max items tracked
        buffer.setMaxItemsTracked(input.getMaxItemsTracked)
      }
      // With no explicit size, all inputs must agree on the size they carry.
      if (buffer.getMaxItemsTracked != input.getMaxItemsTracked) {
        throw new SparkUnsupportedOperationException(
          errorClass = "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED",
          messageParameters = Map(
            "size1" -> buffer.getMaxItemsTracked.toString,
            "size2" -> input.getMaxItemsTracked.toString))
      }
    }
    // check item data type
    if (buffer.getItemDataType != null && input.getItemDataType != null &&
      buffer.getItemDataType != input.getItemDataType) {
      throw new SparkUnsupportedOperationException(
        errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED",
        messageParameters = Map(
          "type1" -> buffer.getItemDataType.typeName,
          "type2" -> input.getItemDataType.typeName))
    } else if (buffer.getItemDataType == null) {
      // If buffer is a placeholder sketch, set it to the input sketch's item data type
      buffer.setItemDataType(input.getItemDataType)
    }
    buffer.getSketch.merge(input.getSketch)
    buffer
  }
+
  // Produces the final sketch-state struct. An ArrayStoreException from the
  // serializer indicates the buffered items do not match the recorded item type
  // (sketches of different item types were combined); it is surfaced as the
  // same typed error used for explicit type mismatches.
  override def eval(buffer: CombineInternal[Any]): Any = {
    val sketchBytes = try {
      buffer.getSketch.toByteArray(ApproxTopK.genSketchSerDe(buffer.getItemDataType))
    } catch {
      case _: ArrayStoreException =>
        throw new SparkUnsupportedOperationException(
          errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED"
        )
    }
    val maxItemsTracked = buffer.getMaxItemsTracked
    val typeCode = ApproxTopK.dataTypeToBytes(buffer.getItemDataType)
    InternalRow.apply(sketchBytes, null, maxItemsTracked, typeCode)
  }
+
+ override def serialize(buffer: CombineInternal[Any]): Array[Byte] = {
+ val sketchBytes = buffer.getSketch.toByteArray(
+ ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+ val maxItemsTrackedByte = buffer.getMaxItemsTracked.toByte
Review Comment:
This casts an `Int` to a `Byte`, which silently truncates: `maxItemsTracked` can
be as large as 1,000,000, far outside the `Byte` range of -128..127. The size
should be serialized as a full 4-byte int (e.g. via `ByteBuffer.putInt`).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]