cboumalh commented on code in PR #51505:
URL: https://github.com/apache/spark/pull/51505#discussion_r2430661588


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala:
##########
@@ -435,3 +491,283 @@ case class ApproxTopKAccumulate(
   override def prettyName: String =
     
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_accumulate")
 }
+
+/**
+ * An internal class used as the aggregation buffer for ApproxTopKCombine.
+ *
+ * @param sketch          the ItemsSketch instance
+ * @param itemDataType    the data type of items in the sketch
+ * @param maxItemsTracked the maximum number of items tracked in the sketch
+ */
+class CombineInternal[T](
+    sketch: ItemsSketch[T],
+    var itemDataType: DataType,
+    var maxItemsTracked: Int) {
+  def getSketch: ItemsSketch[T] = sketch
+
+  def getItemDataType: DataType = itemDataType
+
+  def getMaxItemsTracked: Int = maxItemsTracked
+
+  def updateMaxItemsTracked(combineSizeSpecified: Boolean, newMaxItemsTracked: 
Int): Unit = {
+    if (!combineSizeSpecified) {
+      // check size
+      if (this.maxItemsTracked == ApproxTopK.VOID_MAX_ITEMS_TRACKED) {
+        // If buffer's maxItemsTracked is VOID_MAX_ITEMS_TRACKED, it means the 
buffer is a placeholder
+        // sketch that has not been updated by any input sketch yet.
+        // So we can set it to the input sketch's max items tracked.
+        this.maxItemsTracked = newMaxItemsTracked
+      } else {
+        if (this.maxItemsTracked != newMaxItemsTracked) {
+          // If buffer's maxItemsTracked is not VOID_MAX_ITEMS_TRACKED, it 
means the buffer has been
+          // updated by some input sketch. So if buffer and input sketch have 
different
+          // maxItemsTracked values, it means at least two of the input 
sketches have different
+          // maxItemsTracked values. In this case, we should throw an error.
+          throw QueryExecutionErrors.approxTopKSketchSizeNotMatch(
+            this.maxItemsTracked, newMaxItemsTracked)
+        }
+      }
+    }
+  }
+
+  def updateItemDataType(inputItemDataType: DataType): Unit = {
+    // When the buffer's dataType hasn't been set, set it to the input 
sketch's item data type
+    // When input sketch's item data type is null, buffer's item data type 
will remain null
+    if (this.itemDataType == null) {
+      this.itemDataType = inputItemDataType
+    } else {
+      // When the buffer's dataType has been set, throw an error
+      // if the input sketch's item data type is not null and the two data types 
don't match
+      if (inputItemDataType != null && this.itemDataType != inputItemDataType) 
{
+        throw QueryExecutionErrors.approxTopKSketchTypeNotMatch(
+          this.itemDataType, inputItemDataType)
+      }
+    }
+  }
+
+  /**
+   * Serialize the CombineInternal instance to a byte array.
+   * Serialization format:
+   *     maxItemsTracked (4 bytes int) +
+   *     itemDataTypeDDL length n in bytes (4 bytes int) +
+   *     itemDataTypeDDL (n bytes) +
+   *     sketchBytes
+   */
+  def serialize(): Array[Byte] = {
+    val sketchBytes = sketch.toByteArray(
+      
ApproxTopK.genSketchSerDe(itemDataType).asInstanceOf[ArrayOfItemsSerDe[T]])
+    val itemDataTypeDDL = ApproxTopK.dataTypeToDDL(itemDataType)
+    val ddlBytes: Array[Byte] = 
itemDataTypeDDL.getBytes(StandardCharsets.UTF_8)
+    val byteArray = new Array[Byte](sketchBytes.length + 4 + 4 + 
ddlBytes.length)

Review Comment:
   Super nit: we could consider `Integer.BYTES` instead of the literal `4` here, 
to make it clear that the extra bytes are the two int fields (maxItemsTracked 
and the DDL length). Not necessary to make the change, your call.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to