zhidongqu-db commented on code in PR #54368:
URL: https://github.com/apache/spark/pull/54368#discussion_r2824722170


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/vectorExpressions.scala:
##########
@@ -702,146 +600,118 @@ case class VectorSum(
   ): ImperativeAggregate =
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-  private lazy val inputContainsNull = child.dataType.asInstanceOf[ArrayType].containsNull
-
-  override def initialize(buffer: InternalRow): Unit = {
-    buffer.update(mutableAggBufferOffset + sumIndex, null)
-    buffer.update(mutableAggBufferOffset + dimIndex, null)
-  }
-
-  override def update(buffer: InternalRow, input: InternalRow): Unit = {
-    val inputValue = child.eval(input)
-    if (inputValue == null) {
-      return
+  override protected def updateElements(
+      currentBuffer: ByteBuffer,
+      inputArray: ArrayData,
+      dim: Int,
+      newCount: Long): Unit = {
+    // Update running average: new_avg = old_avg + (new_value - old_avg) / new_count
+    val invCount = 1.0f / newCount
+    var i = 0
+    var idx = 0
+    while (i < dim) {
+      val oldAvg = currentBuffer.getFloat(idx)
+      val newVal = inputArray.getFloat(i)
+      currentBuffer.putFloat(idx, oldAvg + (newVal - oldAvg) * invCount)
+      i += 1
+      idx += 4 // 4 bytes per float
     }
+  }
 
-    val inputArray = inputValue.asInstanceOf[ArrayData]
-    val inputLen = inputArray.numElements()
-
-    // Check for NULL elements in input vector - skip if any NULL element found
-    // Only check if the array type can contain nulls
-    if (inputContainsNull) {
-      var i = 0
-      while (i < inputLen) {
-        if (inputArray.isNullAt(i)) {
-          return
-        }
-        i += 1
-      }
-    }
-
-    val sumOffset = mutableAggBufferOffset + sumIndex
-    val dimOffset = mutableAggBufferOffset + dimIndex
-
-    if (buffer.isNullAt(sumOffset)) {
-      // First valid vector - just copy it as the initial sum
-      val byteBuffer =
-        ByteBuffer.allocate(inputLen * 4).order(ByteOrder.LITTLE_ENDIAN)
-      var i = 0
-      while (i < inputLen) {
-        byteBuffer.putFloat(inputArray.getFloat(i))
-        i += 1
-      }
-      buffer.update(sumOffset, byteBuffer.array())
-      buffer.setInt(dimOffset, inputLen)
-    } else {
-      val currentDim = buffer.getInt(dimOffset)
-
-      // Empty array case - if current is empty and input is empty, keep empty
-      if (currentDim == 0 && inputLen == 0) {
-        return
-      }
-
-      // Dimension mismatch check
-      if (currentDim != inputLen) {
-        throw QueryExecutionErrors.vectorDimensionMismatchError(
-          prettyName,
-          currentDim,
-          inputLen
-        )
-      }
-
-      // Update sum: new_sum = old_sum + new_value
-      val currentSumBytes = buffer.getBinary(sumOffset)
-      // reuse the buffer without reallocation
-      val sumBuffer =
-        ByteBuffer.wrap(currentSumBytes).order(ByteOrder.LITTLE_ENDIAN)
-      var i = 0
-      var idx = 0
-      while (i < currentDim) {
-        sumBuffer.putFloat(idx, sumBuffer.getFloat(idx) + inputArray.getFloat(i))
-        i += 1
-        idx += 4 // 4 bytes per float
-      }
+  override protected def mergeElements(
+      currentBuffer: ByteBuffer,
+      inputBuffer: ByteBuffer,
+      dim: Int,
+      currentCount: Long,
+      inputCount: Long,
+      newCount: Long): Unit = {
+    // Merge running averages:
+    // combined_avg = left_avg * (left_count / total) + right_avg * (right_count / total)
+    val leftWeight = currentCount.toFloat / newCount
+    val rightWeight = inputCount.toFloat / newCount
+    var i = 0
+    var idx = 0
+    while (i < dim) {
+      val leftAvg = currentBuffer.getFloat(idx)
+      val rightAvg = inputBuffer.getFloat(idx)
+      currentBuffer.putFloat(idx, leftAvg * leftWeight + rightAvg * rightWeight)
+      i += 1
+      idx += 4 // 4 bytes per float
     }
   }
 
-  override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = {
-    val sumOffset = mutableAggBufferOffset + sumIndex
-    val dimOffset = mutableAggBufferOffset + dimIndex
-    val inputSumOffset = inputAggBufferOffset + sumIndex
-    val inputDimOffset = inputAggBufferOffset + dimIndex
+  override protected def withNewChildInternal(newChild: Expression): VectorAvg =
+    copy(child = newChild)
+}
 
-    if (inputBuffer.isNullAt(inputSumOffset)) {
-      return
-    }
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(array) - Returns the element-wise sum of float vectors in a group.
+    All vectors must have the same dimension.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(col) FROM VALUES (array(1.0F, 2.0F)), (array(3.0F, 4.0F)) AS tab(col);
+       [4.0,6.0]
+  """,
+  since = "4.2.0",
+  group = "vector_funcs"
+)
+// scalastyle:on line.size.limit
+// Note: This implementation uses single-precision floating-point arithmetic (Float).
+// Precision loss is expected for very large aggregates due to:
+// 1. Accumulated rounding errors when summing many values
+// 2. Loss of significance when adding small values to large accumulated sums
+case class VectorSum(
+    child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0
+) extends VectorAggregateBase {
 
-    val inputSumBytes = inputBuffer.getBinary(inputSumOffset)
-    val inputDim = inputBuffer.getInt(inputDimOffset)
+  def this(child: Expression) = this(child, 0, 0)
 
-    if (buffer.isNullAt(sumOffset)) {
-      // Copy input buffer to current buffer
-      buffer.update(sumOffset, inputSumBytes.clone())
-      buffer.setInt(dimOffset, inputDim)
-    } else {
-      val currentDim = buffer.getInt(dimOffset)
+  override def prettyName: String = "vector_sum"
 
-      // Empty array case
-      if (currentDim == 0 && inputDim == 0) {
-        return
-      }
+  override def withNewMutableAggBufferOffset(
+      newMutableAggBufferOffset: Int
+  ): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
 
-      // Dimension mismatch check
-      if (currentDim != inputDim) {
-        throw QueryExecutionErrors.vectorDimensionMismatchError(
-          prettyName,
-          currentDim,
-          inputDim
-        )
-      }
+  override def withNewInputAggBufferOffset(
+      newInputAggBufferOffset: Int
+  ): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-      // Merge sums: combined_sum = left_sum + right_sum
-      val currentSumBytes = buffer.getBinary(sumOffset)
-      // reuse the buffer without reallocation
-      val sumBuffer =
-        ByteBuffer.wrap(currentSumBytes).order(ByteOrder.LITTLE_ENDIAN)
-      val inputSumBuffer =
-        ByteBuffer.wrap(inputSumBytes).order(ByteOrder.LITTLE_ENDIAN)
-      var i = 0
-      var idx = 0
-      while (i < currentDim) {
-        sumBuffer.putFloat(idx, sumBuffer.getFloat(idx) + inputSumBuffer.getFloat(idx))
-        i += 1
-        idx += 4 // 4 bytes per float
-      }
+  override protected def updateElements(
+      currentBuffer: ByteBuffer,
+      inputArray: ArrayData,
+      dim: Int,
+      newCount: Long): Unit = {
+    // Update sum: new_sum = old_sum + new_value
+    var i = 0
+    var idx = 0
+    while (i < dim) {

Review Comment:
   Yeah, if you do some searching on the codebase, it's pretty widely used as far as I could tell. One good example, specifically in the context of an agg function, is `sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala`.
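For reference, here is a minimal standalone sketch of the incremental-mean math used by the new `updateElements`/`mergeElements` in the diff above, using plain `Array[Float]` instead of the agg buffer; the object and method names are illustrative only, not the PR's `VectorAggregateBase` API:

```scala
// Illustrative sketch only: mirrors the formulas in the diff, not the PR's buffer layout.
object RunningAvgSketch {
  // new_avg = old_avg + (new_value - old_avg) / new_count
  def update(avg: Array[Float], input: Array[Float], newCount: Long): Unit = {
    val invCount = 1.0f / newCount
    var i = 0
    while (i < avg.length) {
      avg(i) += (input(i) - avg(i)) * invCount
      i += 1
    }
  }

  // combined_avg = left_avg * (left_count / total) + right_avg * (right_count / total)
  def merge(left: Array[Float], right: Array[Float], leftCount: Long, rightCount: Long): Unit = {
    val total = leftCount + rightCount
    val leftWeight = leftCount.toFloat / total
    val rightWeight = rightCount.toFloat / total
    var i = 0
    while (i < left.length) {
      left(i) = left(i) * leftWeight + right(i) * rightWeight
      i += 1
    }
  }

  def main(args: Array[String]): Unit = {
    // The running average over three 2-dim vectors equals the plain mean.
    val avg = Array(0.0f, 0.0f)
    update(avg, Array(1.0f, 2.0f), newCount = 1)
    update(avg, Array(3.0f, 4.0f), newCount = 2)
    update(avg, Array(5.0f, 6.0f), newCount = 3)
    println(avg.mkString("[", ",", "]")) // [3.0,4.0]
  }
}
```

The `(new_value - old_avg) / new_count` form keeps the accumulator at the magnitude of the mean rather than the raw sum, which is one reason it can behave better numerically than sum-then-divide over long groups of large floats.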


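Similarly, since the removed `VectorSum.update`/`merge` above store the vector as a little-endian float byte array and the new hooks receive it as a `ByteBuffer`, here is a tiny standalone sketch of that packing pattern; again the names are illustrative, not the PR's API:

```scala
import java.nio.{ByteBuffer, ByteOrder}

// Illustrative sketch only: the float-vector-in-a-byte-array pattern used by the sum path.
object FloatVectorBufferSketch {
  // Pack a float vector into a little-endian byte array, 4 bytes per float.
  def pack(vec: Array[Float]): Array[Byte] = {
    val bb = ByteBuffer.allocate(vec.length * 4).order(ByteOrder.LITTLE_ENDIAN)
    vec.foreach(v => bb.putFloat(v))
    bb.array()
  }

  // Add `input` element-wise into the packed accumulator in place,
  // wrapping the existing array so no reallocation is needed.
  def addInPlace(accBytes: Array[Byte], input: Array[Float]): Unit = {
    val bb = ByteBuffer.wrap(accBytes).order(ByteOrder.LITTLE_ENDIAN)
    var i = 0
    var idx = 0
    while (i < input.length) {
      bb.putFloat(idx, bb.getFloat(idx) + input(i))
      i += 1
      idx += 4 // 4 bytes per float
    }
  }

  def main(args: Array[String]): Unit = {
    val acc = pack(Array(1.0f, 2.0f))
    addInPlace(acc, Array(3.0f, 4.0f))
    val out = ByteBuffer.wrap(acc).order(ByteOrder.LITTLE_ENDIAN)
    println(s"[${out.getFloat(0)},${out.getFloat(4)}]") // [4.0,6.0]
  }
}
```

The output matches the `vector_sum` doc example in the diff (`[4.0,6.0]`).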

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

