cboumalh commented on code in PR #51393:
URL: https://github.com/apache/spark/pull/51393#discussion_r2224336021
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala:
##########
@@ -305,4 +318,341 @@ object ApproxTopK {
new ArrayOfDecimalsSerDe(dt).asInstanceOf[ArrayOfItemsSerDe[Any]]
}
}
+
+ def getSketchStateDataType(itemDataType: DataType): StructType =
+ StructType(
+ StructField("Sketch", BinaryType, nullable = false) ::
+ StructField("ItemTypeNull", itemDataType) ::
+ StructField("MaxItemsTracked", IntegerType, nullable = false) ::
+ StructField("TypeCode", BinaryType, nullable = false) :: Nil)
+
+ def dataTypeToBytes(dataType: DataType): Array[Byte] = {
+ dataType match {
+ case _: BooleanType => Array(0, 0, 0)
+ case _: ByteType => Array(1, 0, 0)
+ case _: ShortType => Array(2, 0, 0)
+ case _: IntegerType => Array(3, 0, 0)
+ case _: LongType => Array(4, 0, 0)
+ case _: FloatType => Array(5, 0, 0)
+ case _: DoubleType => Array(6, 0, 0)
+ case _: DateType => Array(7, 0, 0)
+ case _: TimestampType => Array(8, 0, 0)
+ case _: TimestampNTZType => Array(9, 0, 0)
+ case _: StringType => Array(10, 0, 0)
+ case dt: DecimalType => Array(11, dt.precision.toByte, dt.scale.toByte)
+ }
+ }
+
+ def bytesToDataType(bytes: Array[Byte]): DataType = {
+ bytes(0) match {
+ case 0 => BooleanType
+ case 1 => ByteType
+ case 2 => ShortType
+ case 3 => IntegerType
+ case 4 => LongType
+ case 5 => FloatType
+ case 6 => DoubleType
+ case 7 => DateType
+ case 8 => TimestampType
+ case 9 => TimestampNTZType
+ case 10 => StringType
+ case 11 => DecimalType(bytes(1).toInt, bytes(2).toInt)
+ }
+ }
+}
+
+/**
+ * An aggregate function that accumulates items into a sketch, which can then be used
+ * to combine with other sketches, via ApproxTopKCombine,
+ * or to estimate the top K items, via ApproxTopKEstimate.
+ *
+ * The output of this function is a struct containing the sketch in binary format,
+ * a null object indicating the type of items in the sketch,
+ * and the maximum number of items tracked by the sketch.
+ *
+ * @param expr the child expression to accumulate items from
+ * @param maxItemsTracked the maximum number of items to track in the sketch
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage =
+ """
+ _FUNC_(expr, maxItemsTracked) - Accumulates items into a sketch.
+ `maxItemsTracked` An optional positive INTEGER literal with upper limit of 1000000. If maxItemsTracked is not specified, it defaults to 10000.
+ """,
+ examples =
+ """
+ Examples:
+ > SELECT approx_top_k_accumulate(_FUNC_(expr)) FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) AS tab(expr);
+ [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}]
+
+ > SELECT approx_top_k_accumulate(_FUNC_(expr, 100), 2) FROM VALUES 'a', 'b', 'c', 'c', 'c', 'c', 'd', 'd' AS tab(expr);
+ [{"item":"c","count":4},{"item":"d","count":2}]
+ """,
+ group = "agg_funcs",
+ since = "4.1.0")
+// scalastyle:on line.size.limit
+case class ApproxTopKAccumulate(
+ expr: Expression,
+ maxItemsTracked: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0)
+ extends TypedImperativeAggregate[ItemsSketch[Any]]
+ with ImplicitCastInputTypes
+ with BinaryLike[Expression] {
+
+ def this(child: Expression, maxItemsTracked: Expression) = this(child, maxItemsTracked, 0, 0)
+
+ def this(child: Expression, maxItemsTracked: Int) = this(child, Literal(maxItemsTracked), 0, 0)
+
+ def this(child: Expression) = this(child, Literal(ApproxTopK.DEFAULT_MAX_ITEMS_TRACKED), 0, 0)
+
+ private lazy val itemDataType: DataType = expr.dataType
+
+ private lazy val maxItemsTrackedVal: Int = {
+ ApproxTopK.checkExpressionNotNull(maxItemsTracked, "maxItemsTracked")
+ val maxItemsTrackedVal = maxItemsTracked.eval().asInstanceOf[Int]
+ ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal)
+ maxItemsTrackedVal
+ }
+
+ override def left: Expression = expr
+
+ override def right: Expression = maxItemsTracked
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) {
+ defaultCheck
+ } else if (!ApproxTopK.isDataTypeSupported(itemDataType)) {
+ TypeCheckFailure(f"${itemDataType.typeName} columns are not supported")
+ } else if (!maxItemsTracked.foldable) {
+ TypeCheckFailure("Number of items tracked must be a constant literal")
+ } else {
+ TypeCheckSuccess
+ }
+ }
+
+ override def dataType: DataType = ApproxTopK.getSketchStateDataType(itemDataType)
+
+ override def createAggregationBuffer(): ItemsSketch[Any] = {
+ val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
+ ApproxTopK.createAggregationBuffer(expr, maxMapSize)
+ }
+
+ override def update(buffer: ItemsSketch[Any], input: InternalRow): ItemsSketch[Any] =
+ ApproxTopK.updateSketchBuffer(expr, buffer, input)
+
+ override def merge(buffer: ItemsSketch[Any], input: ItemsSketch[Any]): ItemsSketch[Any] =
+ buffer.merge(input)
+
+ override def eval(buffer: ItemsSketch[Any]): Any = {
+ val sketchBytes = serialize(buffer)
+ val typeCode = ApproxTopK.dataTypeToBytes(itemDataType)
+ InternalRow.apply(sketchBytes, null, maxItemsTrackedVal, typeCode)
+ }
+
+ override def serialize(buffer: ItemsSketch[Any]): Array[Byte] =
+ buffer.toByteArray(ApproxTopK.genSketchSerDe(itemDataType))
+
+ override def deserialize(storageFormat: Array[Byte]): ItemsSketch[Any] =
+ ItemsSketch.getInstance(Memory.wrap(storageFormat), ApproxTopK.genSketchSerDe(itemDataType))
+
+ override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+ override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+ override protected def withNewChildrenInternal(
+ newLeft: Expression,
+ newRight: Expression): Expression =
+ copy(expr = newLeft, maxItemsTracked = newRight)
+
+ override def nullable: Boolean = false
+
+ override def prettyName: String =
+ getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_accumulate")
+}
+
+class CombineInternal[T](
+ sketch: ItemsSketch[T],
+ var itemDataType: DataType,
+ var maxItemsTracked: Int) {
+ def getSketch: ItemsSketch[T] = sketch
+
+ def getItemDataType: DataType = itemDataType
+
+ def setItemDataType(dataType: DataType): Unit = {
+ if (this.itemDataType == null) {
+ this.itemDataType = dataType
+ } else if (this.itemDataType != dataType) {
+ throw new SparkUnsupportedOperationException(
+ errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED",
+ messageParameters = Map(
+ "type1" -> this.itemDataType.typeName,
+ "type2" -> dataType.typeName))
+ }
+ }
+
+ def getMaxItemsTracked: Int = maxItemsTracked
+
+ def setMaxItemsTracked(maxItemsTracked: Int): Unit = this.maxItemsTracked = maxItemsTracked
+}
+
+case class ApproxTopKCombine(
Review Comment:
Is an `ExpressionDescription` planned to be added to `ApproxTopKCombine`, like the one on `ApproxTopKAccumulate`?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]