cboumalh commented on code in PR #51505:
URL: https://github.com/apache/spark/pull/51505#discussion_r2229725214


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala:
##########
@@ -435,3 +507,203 @@ case class ApproxTopKAccumulate(
   override def prettyName: String =
     
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_accumulate")
 }
+
+class CombineInternal[T](
+    sketch: ItemsSketch[T],
+    var itemDataType: DataType,
+    var maxItemsTracked: Int) {
+  def getSketch: ItemsSketch[T] = sketch
+
+  def getItemDataType: DataType = itemDataType
+
+  def setItemDataType(dataType: DataType): Unit = {
+    if (this.itemDataType == null) {
+      this.itemDataType = dataType
+    } else if (this.itemDataType != dataType) {
+      throw 
QueryExecutionErrors.approxTopKSketchTypeNotMatch(this.itemDataType, dataType)
+    }
+  }
+
+  def getMaxItemsTracked: Int = maxItemsTracked
+
+  def setMaxItemsTracked(maxItemsTracked: Int): Unit = this.maxItemsTracked = 
maxItemsTracked
+}
+
+/**
+ * An aggregate function that combines multiple sketches into a single sketch.
+ *
+ * @param state                  the expression containing the sketches to 
combine
+ * @param maxItemsTracked        the maximum number of items to track in the 
sketch
+ * @param mutableAggBufferOffset the offset for mutable aggregation buffer
+ * @param inputAggBufferOffset   the offset for input aggregation buffer
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(state, maxItemsTracked) - Combines multiple sketches into a single 
sketch.
+      `maxItemsTracked` An optional positive INTEGER literal with upper limit 
of 1000000. If maxItemsTracked is not specified, it defaults to 10000.
+  """,
+  examples = """
+    Examples:
+      > SELECT approx_top_k_estimate(_FUNC_(sketch, 10000), 5) FROM (SELECT 
approx_top_k_accumulate(expr) AS sketch FROM VALUES (0), (0), (1), (1) AS 
tab(expr) UNION ALL SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES 
(2), (3), (4), (4) AS tab(expr));
+       
[{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}]
+  """,
+  group = "agg_funcs",
+  since = "4.1.0")
+// scalastyle:on line.size.limit
+case class ApproxTopKCombine(
+    state: Expression,
+    maxItemsTracked: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends TypedImperativeAggregate[CombineInternal[Any]]
+  with ImplicitCastInputTypes
+  with BinaryLike[Expression] {
+
+  def this(child: Expression, maxItemsTracked: Expression) = {
+    this(child, maxItemsTracked, 0, 0)
+    ApproxTopK.checkExpressionNotNull(maxItemsTracked, "maxItemsTracked")
+    ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal)
+  }
+
+  def this(child: Expression, maxItemsTracked: Int) = this(child, 
Literal(maxItemsTracked))
+
+  def this(child: Expression) = this(child, 
Literal(ApproxTopK.VOID_MAX_ITEMS_TRACKED), 0, 0)
+
+  private lazy val uncheckedItemDataType: DataType =
+    state.dataType.asInstanceOf[StructType](1).dataType
+  private lazy val maxItemsTrackedVal: Int = 
maxItemsTracked.eval().asInstanceOf[Int]
+  private lazy val combineSizeSpecified: Boolean =
+    maxItemsTrackedVal != ApproxTopK.VOID_MAX_ITEMS_TRACKED
+
+  override def left: Expression = state
+
+  override def right: Expression = maxItemsTracked
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val defaultCheck = super.checkInputDataTypes()
+    if (defaultCheck.isFailure) {
+      defaultCheck
+    } else {
+      val stateCheck = ApproxTopK.checkStateFieldAndType(state)
+      if (stateCheck.isFailure) {
+        stateCheck
+      } else if (!maxItemsTracked.foldable) {
+        TypeCheckFailure("Number of items tracked must be a constant literal")
+      } else {
+        TypeCheckSuccess
+      }
+    }
+  }
+
+  override def dataType: DataType = 
ApproxTopK.getSketchStateDataType(uncheckedItemDataType)
+
+  override def createAggregationBuffer(): CombineInternal[Any] = {
+    if (combineSizeSpecified) {
+      val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
+      new CombineInternal[Any](
+        new ItemsSketch[Any](maxMapSize),
+        null,
+        maxItemsTrackedVal)
+    } else {
+      new CombineInternal[Any](
+        new ItemsSketch[Any](ApproxTopK.SKETCH_SIZE_PLACEHOLDER),
+        null,
+        ApproxTopK.VOID_MAX_ITEMS_TRACKED)
+    }
+  }
+
+  override def update(buffer: CombineInternal[Any], input: InternalRow): 
CombineInternal[Any] = {
+    val inputState = state.eval(input).asInstanceOf[InternalRow]
+    val inputSketchBytes = inputState.getBinary(0)
+    val inputMaxItemsTracked = inputState.getInt(2)
+    val typeCode = inputState.getBinary(3)
+    val actualItemDataType = ApproxTopK.bytesToDataType(typeCode)
+    buffer.setItemDataType(actualItemDataType)
+    val inputSketch = ItemsSketch.getInstance(
+      Memory.wrap(inputSketchBytes), 
ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+    buffer.getSketch.merge(inputSketch)
+    if (!combineSizeSpecified) {
+      buffer.setMaxItemsTracked(inputMaxItemsTracked)
+    }
+    buffer
+  }
+
+  override def merge(buffer: CombineInternal[Any], input: CombineInternal[Any])
+  : CombineInternal[Any] = {
+    if (!combineSizeSpecified) {
+      // check size
+      if (buffer.getMaxItemsTracked == ApproxTopK.VOID_MAX_ITEMS_TRACKED) {
+        // If buffer is a placeholder sketch, set it to the input sketch's max 
items tracked
+        buffer.setMaxItemsTracked(input.getMaxItemsTracked)
+      }
+      if (buffer.getMaxItemsTracked != input.getMaxItemsTracked) {
+        throw QueryExecutionErrors.approxTopKSketchSizeNotMatch(
+          buffer.getMaxItemsTracked,
+          input.getMaxItemsTracked
+        )
+      }
+    }
+    // check item data type
+    if (buffer.getItemDataType != null && input.getItemDataType != null &&
+      buffer.getItemDataType != input.getItemDataType) {
+      throw QueryExecutionErrors.approxTopKSketchTypeNotMatch(
+        buffer.getItemDataType,
+        input.getItemDataType
+      )
+    } else if (buffer.getItemDataType == null) {
+      // If buffer is a placeholder sketch, set it to the input sketch's item 
data type
+      buffer.setItemDataType(input.getItemDataType)
+    }
+    buffer.getSketch.merge(input.getSketch)
+    buffer
+  }
+
+  override def eval(buffer: CombineInternal[Any]): Any = {
+    val sketchBytes =
+      
buffer.getSketch.toByteArray(ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+    val maxItemsTracked = buffer.getMaxItemsTracked
+    val typeCode = ApproxTopK.dataTypeToBytes(buffer.getItemDataType)
+    InternalRow.apply(sketchBytes, null, maxItemsTracked, typeCode)
+  }
+
+  override def serialize(buffer: CombineInternal[Any]): Array[Byte] = {

Review Comment:
   Serialization and deserialization here are a bit concerning. The current
implementation relies on brittle byte-level manipulation with no robust error
checking and little flexibility. Hard-coded assumptions about byte sizes and
type conversions create significant risks for future compatibility and data
integrity. I think it may be worth looking into using a `ByteBuffer` and adding
a version field to the serialized sketch buffer, or exploring other approaches.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala:
##########
@@ -177,6 +177,8 @@ object ApproxTopK {
   val DEFAULT_K: Int = 5
   val DEFAULT_MAX_ITEMS_TRACKED: Int = 10000
   private val MAX_ITEMS_TRACKED_LIMIT: Int = 1000000
+  val VOID_MAX_ITEMS_TRACKED = -1

Review Comment:
   Can we document what exactly `VOID_MAX_ITEMS_TRACKED` and
`SKETCH_SIZE_PLACEHOLDER` mean?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala:
##########
@@ -435,3 +507,203 @@ case class ApproxTopKAccumulate(
   override def prettyName: String =
     
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_accumulate")
 }
+
+class CombineInternal[T](
+    sketch: ItemsSketch[T],
+    var itemDataType: DataType,
+    var maxItemsTracked: Int) {
+  def getSketch: ItemsSketch[T] = sketch
+
+  def getItemDataType: DataType = itemDataType
+
+  def setItemDataType(dataType: DataType): Unit = {
+    if (this.itemDataType == null) {
+      this.itemDataType = dataType
+    } else if (this.itemDataType != dataType) {
+      throw 
QueryExecutionErrors.approxTopKSketchTypeNotMatch(this.itemDataType, dataType)
+    }
+  }
+
+  def getMaxItemsTracked: Int = maxItemsTracked
+
+  def setMaxItemsTracked(maxItemsTracked: Int): Unit = this.maxItemsTracked = 
maxItemsTracked
+}
+
+/**
+ * An aggregate function that combines multiple sketches into a single sketch.
+ *
+ * @param state                  the expression containing the sketches to 
combine
+ * @param maxItemsTracked        the maximum number of items to track in the 
sketch
+ * @param mutableAggBufferOffset the offset for mutable aggregation buffer
+ * @param inputAggBufferOffset   the offset for input aggregation buffer
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(state, maxItemsTracked) - Combines multiple sketches into a single 
sketch.
+      `maxItemsTracked` An optional positive INTEGER literal with upper limit 
of 1000000. If maxItemsTracked is not specified, it defaults to 10000.
+  """,
+  examples = """
+    Examples:
+      > SELECT approx_top_k_estimate(_FUNC_(sketch, 10000), 5) FROM (SELECT 
approx_top_k_accumulate(expr) AS sketch FROM VALUES (0), (0), (1), (1) AS 
tab(expr) UNION ALL SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES 
(2), (3), (4), (4) AS tab(expr));
+       
[{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}]
+  """,
+  group = "agg_funcs",
+  since = "4.1.0")
+// scalastyle:on line.size.limit
+case class ApproxTopKCombine(
+    state: Expression,
+    maxItemsTracked: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends TypedImperativeAggregate[CombineInternal[Any]]
+  with ImplicitCastInputTypes
+  with BinaryLike[Expression] {
+
+  def this(child: Expression, maxItemsTracked: Expression) = {
+    this(child, maxItemsTracked, 0, 0)
+    ApproxTopK.checkExpressionNotNull(maxItemsTracked, "maxItemsTracked")
+    ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal)
+  }
+
+  def this(child: Expression, maxItemsTracked: Int) = this(child, 
Literal(maxItemsTracked))
+
+  def this(child: Expression) = this(child, 
Literal(ApproxTopK.VOID_MAX_ITEMS_TRACKED), 0, 0)
+
+  private lazy val uncheckedItemDataType: DataType =
+    state.dataType.asInstanceOf[StructType](1).dataType
+  private lazy val maxItemsTrackedVal: Int = 
maxItemsTracked.eval().asInstanceOf[Int]
+  private lazy val combineSizeSpecified: Boolean =
+    maxItemsTrackedVal != ApproxTopK.VOID_MAX_ITEMS_TRACKED
+
+  override def left: Expression = state
+
+  override def right: Expression = maxItemsTracked
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val defaultCheck = super.checkInputDataTypes()
+    if (defaultCheck.isFailure) {
+      defaultCheck
+    } else {
+      val stateCheck = ApproxTopK.checkStateFieldAndType(state)
+      if (stateCheck.isFailure) {
+        stateCheck
+      } else if (!maxItemsTracked.foldable) {
+        TypeCheckFailure("Number of items tracked must be a constant literal")
+      } else {
+        TypeCheckSuccess
+      }
+    }
+  }
+
+  override def dataType: DataType = 
ApproxTopK.getSketchStateDataType(uncheckedItemDataType)
+
+  override def createAggregationBuffer(): CombineInternal[Any] = {
+    if (combineSizeSpecified) {
+      val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
+      new CombineInternal[Any](
+        new ItemsSketch[Any](maxMapSize),
+        null,
+        maxItemsTrackedVal)
+    } else {
+      new CombineInternal[Any](
+        new ItemsSketch[Any](ApproxTopK.SKETCH_SIZE_PLACEHOLDER),
+        null,
+        ApproxTopK.VOID_MAX_ITEMS_TRACKED)
+    }
+  }
+
+  override def update(buffer: CombineInternal[Any], input: InternalRow): 
CombineInternal[Any] = {
+    val inputState = state.eval(input).asInstanceOf[InternalRow]
+    val inputSketchBytes = inputState.getBinary(0)
+    val inputMaxItemsTracked = inputState.getInt(2)
+    val typeCode = inputState.getBinary(3)
+    val actualItemDataType = ApproxTopK.bytesToDataType(typeCode)
+    buffer.setItemDataType(actualItemDataType)
+    val inputSketch = ItemsSketch.getInstance(
+      Memory.wrap(inputSketchBytes), 
ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+    buffer.getSketch.merge(inputSketch)
+    if (!combineSizeSpecified) {
+      buffer.setMaxItemsTracked(inputMaxItemsTracked)
+    }
+    buffer
+  }
+
+  override def merge(buffer: CombineInternal[Any], input: CombineInternal[Any])
+  : CombineInternal[Any] = {
+    if (!combineSizeSpecified) {
+      // check size
+      if (buffer.getMaxItemsTracked == ApproxTopK.VOID_MAX_ITEMS_TRACKED) {
+        // If buffer is a placeholder sketch, set it to the input sketch's max 
items tracked
+        buffer.setMaxItemsTracked(input.getMaxItemsTracked)
+      }
+      if (buffer.getMaxItemsTracked != input.getMaxItemsTracked) {
+        throw QueryExecutionErrors.approxTopKSketchSizeNotMatch(
+          buffer.getMaxItemsTracked,
+          input.getMaxItemsTracked
+        )
+      }
+    }
+    // check item data type
+    if (buffer.getItemDataType != null && input.getItemDataType != null &&
+      buffer.getItemDataType != input.getItemDataType) {
+      throw QueryExecutionErrors.approxTopKSketchTypeNotMatch(
+        buffer.getItemDataType,
+        input.getItemDataType
+      )
+    } else if (buffer.getItemDataType == null) {
+      // If buffer is a placeholder sketch, set it to the input sketch's item 
data type
+      buffer.setItemDataType(input.getItemDataType)
+    }
+    buffer.getSketch.merge(input.getSketch)
+    buffer
+  }
+
+  override def eval(buffer: CombineInternal[Any]): Any = {
+    val sketchBytes =
+      
buffer.getSketch.toByteArray(ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+    val maxItemsTracked = buffer.getMaxItemsTracked
+    val typeCode = ApproxTopK.dataTypeToBytes(buffer.getItemDataType)
+    InternalRow.apply(sketchBytes, null, maxItemsTracked, typeCode)
+  }
+
+  override def serialize(buffer: CombineInternal[Any]): Array[Byte] = {
+    val sketchBytes = buffer.getSketch.toByteArray(
+      ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+    val maxItemsTrackedByte = buffer.getMaxItemsTracked.toByte

Review Comment:
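   I believe this is casting an `Int` to a `Byte`. Could this cause an
overflow? `toByte` keeps only the low 8 bits, so any `maxItemsTracked` above
127 would be silently truncated. For example, the default of 10000 becomes 16,
and the limit is 1000000.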



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to