This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a96afde1b894 [SPARK-55593][SQL] Unify aggregation state for 
vector_avg/vector_sum
a96afde1b894 is described below

commit a96afde1b894ceb7a3eb0fab154ea499958633f5
Author: zhidongqu-db <[email protected]>
AuthorDate: Fri Feb 20 10:21:00 2026 +0800

    [SPARK-55593][SQL] Unify aggregation state for vector_avg/vector_sum
    
    ### What changes were proposed in this pull request?
    
    Unify the aggregate buffer schema for `vector_avg` and `vector_sum` and 
extract shared logic into a common base trait to reduce code duplication.
    
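    - Unified buffer schema: Both functions now use the same (acc: BINARY, 
count: LONG) aggregate state. The dim: INTEGER field is removed — dimension is 
inferred from acc.length / 4 (4 bytes per float). Floats in acc are now read and 
written with Platform.getFloat/putFloat instead of a little-endian ByteBuffer. 
For vector_sum, count is not used in the element-wise arithmetic; it only 
records whether any valid vector has been seen (count == 0 yields a NULL result).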
    
    - Extract VectorAggregateBase trait: All shared aggregate lifecycle logic 
(initialize, update, merge, eval, input validation, null checking, dimension 
mismatch checking, first-vector initialization) is consolidated into a base 
trait. Subclasses implement their element-wise math through two abstract 
methods, updateElements and mergeElements; beyond that they only supply 
prettyName and the copy-based withNew* overrides.
    
    ### Why are the changes needed?
    
    VectorAvg and VectorSum had nearly identical aggregate lifecycle logic 
(buffer management, null handling, dimension validation, binary-to-array 
conversion) duplicated across ~500 lines. The unified state schema and base 
trait eliminate ~200 lines of duplication and make it easier to add new vector 
aggregate functions in the future.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Existing tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Yes, code assistance with Claude Opus 4.6 in combination with manual 
editing by the author.
    
    Closes #54368 from zhidongqu-db/unify-vector-agg-func-state.
    
    Authored-by: zhidongqu-db <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../catalyst/expressions/vectorExpressions.scala   | 476 +++++++--------------
 1 file changed, 165 insertions(+), 311 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/vectorExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/vectorExpressions.scala
index b4f222ccbc5b..e65fae3a2bc2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/vectorExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/vectorExpressions.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import java.nio.{ByteBuffer, ByteOrder}
-
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -33,11 +31,11 @@ import org.apache.spark.sql.types.{
   BinaryType,
   DataType,
   FloatType,
-  IntegerType,
   LongType,
   StringType,
   StructType
 }
+import org.apache.spark.unsafe.Platform
 
 // scalastyle:off line.size.limit
 @ExpressionDescription(
@@ -345,37 +343,16 @@ case class VectorNormalize(vector: Expression, degree: 
Expression)
   }
 }
 
-// scalastyle:off line.size.limit
-@ExpressionDescription(
-  usage = """
-    _FUNC_(array) - Returns the element-wise mean of float vectors in a group.
-    All vectors must have the same dimension.
-  """,
-  examples = """
-    Examples:
-      > SELECT _FUNC_(col) FROM VALUES (array(1.0F, 2.0F)), (array(3.0F, 
4.0F)) AS tab(col);
-       [2.0,3.0]
-  """,
-  since = "4.2.0",
-  group = "vector_funcs"
-)
-// scalastyle:on line.size.limit
-// Note: This implementation uses single-precision floating-point arithmetic 
(Float).
-// Precision loss is expected for very large aggregates due to:
-// 1. Accumulated rounding errors in incremental average updates
-// 2. Loss of significance when dividing by large counts
-case class VectorAvg(
-    child: Expression,
-    mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0
-) extends ImperativeAggregate
+// Base trait for vector aggregate functions (vector_avg, vector_sum).
+// Provides a unified aggregate buffer schema: (acc: BINARY, count: LONG)
+// - acc: BINARY representation of the running vector (sum or average)
+// - count: number of valid vectors seen so far
+// - dimension is inferred from acc.length / 4 (4 bytes per float)
+// Subclasses only need to implement the element-wise update and merge logic.
+trait VectorAggregateBase extends ImperativeAggregate
     with UnaryLike[Expression]
     with QueryErrorsBase {
 
-  def this(child: Expression) = this(child, 0, 0)
-
-  override def prettyName: String = "vector_avg"
-
   override def nullable: Boolean = true
 
   override def dataType: DataType = ArrayType(FloatType, containsNull = false)
@@ -397,26 +374,17 @@ case class VectorAvg(
     }
   }
 
-  // Aggregate buffer schema: (avg: BINARY, dim: INTEGER, count: LONG)
-  // avg is a BINARY representation of the average vector of floats in the 
group
-  // dim is the dimension of the vector
-  // count is the number of vectors in the group
-  // null avg means no valid input has been seen yet
-  private lazy val avgAttr = AttributeReference(
-    "avg",
+  // Aggregate buffer schema: (acc: BINARY, count: LONG)
+  private lazy val accAttr = AttributeReference(
+    "acc",
     BinaryType,
     nullable = true
   )()
-  private lazy val dimAttr = AttributeReference(
-    "dim",
-    IntegerType,
-    nullable = true
-  )()
   private lazy val countAttr =
     AttributeReference("count", LongType, nullable = false)()
 
   override def aggBufferAttributes: Seq[AttributeReference] =
-    Seq(avgAttr, dimAttr, countAttr)
+    Seq(accAttr, countAttr)
 
   override def aggBufferSchema: StructType =
     DataTypeUtils.fromAttributes(aggBufferAttributes)
@@ -425,27 +393,37 @@ case class VectorAvg(
     aggBufferAttributes.map(_.newInstance())
 
   // Buffer indices
-  private val avgIndex = 0
-  private val dimIndex = 1
-  private val countIndex = 2
+  protected val accIndex = 0
+  protected val countIndex = 1
 
-  override def withNewMutableAggBufferOffset(
-      newMutableAggBufferOffset: Int
-  ): ImperativeAggregate =
-    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
-  override def withNewInputAggBufferOffset(
-      newInputAggBufferOffset: Int
-  ): ImperativeAggregate =
-    copy(inputAggBufferOffset = newInputAggBufferOffset)
+  protected lazy val inputContainsNull =
+    child.dataType.asInstanceOf[ArrayType].containsNull
 
   override def initialize(buffer: InternalRow): Unit = {
-    buffer.update(mutableAggBufferOffset + avgIndex, null)
-    buffer.update(mutableAggBufferOffset + dimIndex, null)
+    buffer.update(mutableAggBufferOffset + accIndex, null)
     buffer.setLong(mutableAggBufferOffset + countIndex, 0L)
   }
 
-  private lazy val inputContainsNull = 
child.dataType.asInstanceOf[ArrayType].containsNull
+  // Infer vector dimension from byte array length (4 bytes per float)
+  protected def getDim(bytes: Array[Byte]): Int = bytes.length / 4
+
+  // Element-wise update for non-first vectors.
+  // accBytes contains the running vector; update it in-place with inputArray.
+  protected def updateElements(
+      accBytes: Array[Byte],
+      inputArray: ArrayData,
+      dim: Int,
+      newCount: Long): Unit
+
+  // Element-wise merge of two non-empty buffers.
+  // accBytes contains the left running vector; update it in-place.
+  protected def mergeElements(
+      accBytes: Array[Byte],
+      inputBytes: Array[Byte],
+      dim: Int,
+      currentCount: Long,
+      inputCount: Long,
+      newCount: Long): Unit
 
   override def update(buffer: InternalRow, input: InternalRow): Unit = {
     val inputValue = child.eval(input)
@@ -468,68 +446,50 @@ case class VectorAvg(
       }
     }
 
-    val avgOffset = mutableAggBufferOffset + avgIndex
-    val dimOffset = mutableAggBufferOffset + dimIndex
+    val accOffset = mutableAggBufferOffset + accIndex
     val countOffset = mutableAggBufferOffset + countIndex
 
     val currentCount = buffer.getLong(countOffset)
 
     if (currentCount == 0L) {
-      // First valid vector - just copy it as the initial average
-      val byteBuffer =
-        ByteBuffer.allocate(inputLen * 4).order(ByteOrder.LITTLE_ENDIAN)
+      // First valid vector - just copy it
+      val bytes = new Array[Byte](inputLen * 4)
       var i = 0
       while (i < inputLen) {
-        byteBuffer.putFloat(inputArray.getFloat(i))
+        Platform.putFloat(bytes, Platform.BYTE_ARRAY_OFFSET + i.toLong * 4, 
inputArray.getFloat(i))
         i += 1
       }
-      buffer.update(avgOffset, byteBuffer.array())
-      buffer.setInt(dimOffset, inputLen)
+      buffer.update(accOffset, bytes)
       buffer.setLong(countOffset, 1L)
     } else {
-      val currentDim = buffer.getInt(dimOffset)
+      val accBytes = buffer.getBinary(accOffset)
+      val accDim = getDim(accBytes)
 
       // Empty array case - if current is empty and input is empty, keep empty
-      if (currentDim == 0 && inputLen == 0) {
+      if (accDim == 0 && inputLen == 0) {
         buffer.setLong(countOffset, currentCount + 1L)
         return
       }
 
       // Dimension mismatch check
-      if (currentDim != inputLen) {
+      if (accDim != inputLen) {
         throw QueryExecutionErrors.vectorDimensionMismatchError(
           prettyName,
-          currentDim,
+          accDim,
           inputLen
         )
       }
 
-      // Update running average: new_avg = old_avg + (new_value - old_avg) / 
(count + 1)
       val newCount = currentCount + 1L
-      val invCount = 1.0f / newCount
-      val currentAvgBytes = buffer.getBinary(avgOffset)
-      // reuse the buffer without reallocation
-      val avgBuffer =
-        ByteBuffer.wrap(currentAvgBytes).order(ByteOrder.LITTLE_ENDIAN)
-      var i = 0
-      var idx = 0
-      while (i < currentDim) {
-        val oldAvg = avgBuffer.getFloat(idx)
-        val newVal = inputArray.getFloat(i)
-        avgBuffer.putFloat(idx, oldAvg + (newVal - oldAvg) * invCount)
-        i += 1
-        idx += 4 // 4 bytes per float
-      }
+      updateElements(accBytes, inputArray, accDim, newCount)
       buffer.setLong(countOffset, newCount)
     }
   }
 
   override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = {
-    val avgOffset = mutableAggBufferOffset + avgIndex
-    val dimOffset = mutableAggBufferOffset + dimIndex
+    val accOffset = mutableAggBufferOffset + accIndex
     val countOffset = mutableAggBufferOffset + countIndex
-    val inputAvgOffset = inputAggBufferOffset + avgIndex
-    val inputDimOffset = inputAggBufferOffset + dimIndex
+    val inputAccOffset = inputAggBufferOffset + accIndex
     val inputCountOffset = inputAggBufferOffset + countIndex
 
     val inputCount = inputBuffer.getLong(inputCountOffset)
@@ -537,91 +497,68 @@ case class VectorAvg(
       return
     }
 
-    val inputAvgBytes = inputBuffer.getBinary(inputAvgOffset)
-    val inputDim = inputBuffer.getInt(inputDimOffset)
+    val inputAccBytes = inputBuffer.getBinary(inputAccOffset)
     val currentCount = buffer.getLong(countOffset)
 
     if (currentCount == 0L) {
       // Copy input buffer to current buffer
-      buffer.update(avgOffset, inputAvgBytes.clone())
-      buffer.setInt(dimOffset, inputDim)
+      buffer.update(accOffset, inputAccBytes.clone())
       buffer.setLong(countOffset, inputCount)
     } else {
-      val currentDim = buffer.getInt(dimOffset)
+      val accBytes = buffer.getBinary(accOffset)
+      val accDim = getDim(accBytes)
+      val inputDim = getDim(inputAccBytes)
 
       // Empty array case
-      if (currentDim == 0 && inputDim == 0) {
+      if (accDim == 0 && inputDim == 0) {
         buffer.setLong(countOffset, currentCount + inputCount)
         return
       }
 
       // Dimension mismatch check
-      if (currentDim != inputDim) {
+      if (accDim != inputDim) {
         throw QueryExecutionErrors.vectorDimensionMismatchError(
           prettyName,
-          currentDim,
+          accDim,
           inputDim
         )
       }
 
-      // Merge running averages:
-      // combined_avg = (left_avg * left_count) / (left_count + right_count) +
-      //   (right_avg * right_count) / (left_count + right_count)
       val newCount = currentCount + inputCount
-      val leftWeight = currentCount.toFloat / newCount
-      val rightWeight = inputCount.toFloat / newCount
-      val currentAvgBytes = buffer.getBinary(avgOffset)
-      // reuse the buffer without reallocation
-      val avgBuffer =
-        ByteBuffer.wrap(currentAvgBytes).order(ByteOrder.LITTLE_ENDIAN)
-      val inputAvgBuffer =
-        ByteBuffer.wrap(inputAvgBytes).order(ByteOrder.LITTLE_ENDIAN)
-      var i = 0
-      var idx = 0
-      while (i < currentDim) {
-        val leftAvg = avgBuffer.getFloat(idx)
-        val rightAvg = inputAvgBuffer.getFloat(idx)
-        avgBuffer.putFloat(idx, leftAvg * leftWeight + rightAvg * rightWeight)
-        i += 1
-        idx += 4 // 4 bytes per float
-      }
+      mergeElements(accBytes, inputAccBytes, accDim,
+        currentCount, inputCount, newCount)
       buffer.setLong(countOffset, newCount)
     }
   }
 
   override def eval(buffer: InternalRow): Any = {
-    val countOffset = mutableAggBufferOffset + countIndex
-    val count = buffer.getLong(countOffset)
+    val count = buffer.getLong(mutableAggBufferOffset + countIndex)
     if (count == 0L) {
       null
     } else {
-      val dim = buffer.getInt(mutableAggBufferOffset + dimIndex)
-      val avgBytes = buffer.getBinary(mutableAggBufferOffset + avgIndex)
-      val avgBuffer = ByteBuffer.wrap(avgBytes).order(ByteOrder.LITTLE_ENDIAN)
+      val accBytes = buffer.getBinary(mutableAggBufferOffset + accIndex)
+      val dim = getDim(accBytes)
       val result = new Array[Float](dim)
       var i = 0
       while (i < dim) {
-        result(i) = avgBuffer.getFloat()
+        result(i) = Platform.getFloat(accBytes, Platform.BYTE_ARRAY_OFFSET + 
i.toLong * 4)
         i += 1
       }
       ArrayData.toArrayData(result)
     }
   }
-
-  override protected def withNewChildInternal(newChild: Expression): VectorAvg 
=
-    copy(child = newChild)
 }
 
 // scalastyle:off line.size.limit
 @ExpressionDescription(
   usage = """
-    _FUNC_(array) - Returns the element-wise sum of float vectors in a group.
+    _FUNC_(array) - Returns the element-wise mean of float vectors in a group.
     All vectors must have the same dimension.
   """,
   examples = """
     Examples:
       > SELECT _FUNC_(col) FROM VALUES (array(1.0F, 2.0F)), (array(3.0F, 
4.0F)) AS tab(col);
-       [4.0,6.0]
+       [2.0,3.0]
   """,
   since = "4.2.0",
   group = "vector_funcs"
@@ -629,68 +566,17 @@ case class VectorAvg(
 // scalastyle:on line.size.limit
 // Note: This implementation uses single-precision floating-point arithmetic 
(Float).
 // Precision loss is expected for very large aggregates due to:
-// 1. Accumulated rounding errors when summing many values
-// 2. Loss of significance when adding small values to large accumulated sums
-case class VectorSum(
+// 1. Accumulated rounding errors in incremental average updates
+// 2. Loss of significance when dividing by large counts
+case class VectorAvg(
     child: Expression,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0
-) extends ImperativeAggregate
-    with UnaryLike[Expression]
-    with QueryErrorsBase {
+) extends VectorAggregateBase {
 
   def this(child: Expression) = this(child, 0, 0)
 
-  override def prettyName: String = "vector_sum"
-
-  override def nullable: Boolean = true
-
-  override def dataType: DataType = ArrayType(FloatType, containsNull = false)
-
-  override def checkInputDataTypes(): TypeCheckResult = {
-    child.dataType match {
-      case ArrayType(FloatType, _) =>
-        TypeCheckResult.TypeCheckSuccess
-      case _ =>
-        DataTypeMismatch(
-          errorSubClass = "UNEXPECTED_INPUT_TYPE",
-          messageParameters = Map(
-            "paramIndex" -> ordinalNumber(0),
-            "requiredType" -> toSQLType(ArrayType(FloatType)),
-            "inputSql" -> toSQLExpr(child),
-            "inputType" -> toSQLType(child.dataType)
-          )
-        )
-    }
-  }
-
-  // Aggregate buffer schema: (sum: BINARY, dim: INTEGER)
-  // sum is a BINARY representation of the sum vector of floats in the group
-  // dim is the dimension of the vector
-  // null sum means no valid input has been seen yet
-  private lazy val sumAttr = AttributeReference(
-    "sum",
-    BinaryType,
-    nullable = true
-  )()
-  private lazy val dimAttr = AttributeReference(
-    "dim",
-    IntegerType,
-    nullable = true
-  )()
-
-  override def aggBufferAttributes: Seq[AttributeReference] =
-    Seq(sumAttr, dimAttr)
-
-  override def aggBufferSchema: StructType =
-    DataTypeUtils.fromAttributes(aggBufferAttributes)
-
-  override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
-    aggBufferAttributes.map(_.newInstance())
-
-  // Buffer indices
-  private val sumIndex = 0
-  private val dimIndex = 1
+  override def prettyName: String = "vector_avg"
 
   override def withNewMutableAggBufferOffset(
       newMutableAggBufferOffset: Int
@@ -702,146 +588,114 @@ case class VectorSum(
   ): ImperativeAggregate =
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-  private lazy val inputContainsNull = 
child.dataType.asInstanceOf[ArrayType].containsNull
-
-  override def initialize(buffer: InternalRow): Unit = {
-    buffer.update(mutableAggBufferOffset + sumIndex, null)
-    buffer.update(mutableAggBufferOffset + dimIndex, null)
-  }
-
-  override def update(buffer: InternalRow, input: InternalRow): Unit = {
-    val inputValue = child.eval(input)
-    if (inputValue == null) {
-      return
-    }
-
-    val inputArray = inputValue.asInstanceOf[ArrayData]
-    val inputLen = inputArray.numElements()
-
-    // Check for NULL elements in input vector - skip if any NULL element found
-    // Only check if the array type can contain nulls
-    if (inputContainsNull) {
-      var i = 0
-      while (i < inputLen) {
-        if (inputArray.isNullAt(i)) {
-          return
-        }
-        i += 1
-      }
+  override protected def updateElements(
+      accBytes: Array[Byte],
+      inputArray: ArrayData,
+      dim: Int,
+      newCount: Long): Unit = {
+    // Update running average: new_avg = old_avg + (new_value - old_avg) / 
new_count
+    val invCount = 1.0f / newCount
+    var i = 0
+    while (i < dim) {
+      val off = Platform.BYTE_ARRAY_OFFSET + i.toLong * 4
+      val oldAvg = Platform.getFloat(accBytes, off)
+      Platform.putFloat(accBytes, off, oldAvg + (inputArray.getFloat(i) - 
oldAvg) * invCount)
+      i += 1
     }
+  }
 
-    val sumOffset = mutableAggBufferOffset + sumIndex
-    val dimOffset = mutableAggBufferOffset + dimIndex
-
-    if (buffer.isNullAt(sumOffset)) {
-      // First valid vector - just copy it as the initial sum
-      val byteBuffer =
-        ByteBuffer.allocate(inputLen * 4).order(ByteOrder.LITTLE_ENDIAN)
-      var i = 0
-      while (i < inputLen) {
-        byteBuffer.putFloat(inputArray.getFloat(i))
-        i += 1
-      }
-      buffer.update(sumOffset, byteBuffer.array())
-      buffer.setInt(dimOffset, inputLen)
-    } else {
-      val currentDim = buffer.getInt(dimOffset)
-
-      // Empty array case - if current is empty and input is empty, keep empty
-      if (currentDim == 0 && inputLen == 0) {
-        return
-      }
-
-      // Dimension mismatch check
-      if (currentDim != inputLen) {
-        throw QueryExecutionErrors.vectorDimensionMismatchError(
-          prettyName,
-          currentDim,
-          inputLen
-        )
-      }
-
-      // Update sum: new_sum = old_sum + new_value
-      val currentSumBytes = buffer.getBinary(sumOffset)
-      // reuse the buffer without reallocation
-      val sumBuffer =
-        ByteBuffer.wrap(currentSumBytes).order(ByteOrder.LITTLE_ENDIAN)
-      var i = 0
-      var idx = 0
-      while (i < currentDim) {
-        sumBuffer.putFloat(idx, sumBuffer.getFloat(idx) + 
inputArray.getFloat(i))
-        i += 1
-        idx += 4 // 4 bytes per float
-      }
+  override protected def mergeElements(
+      accBytes: Array[Byte],
+      inputBytes: Array[Byte],
+      dim: Int,
+      currentCount: Long,
+      inputCount: Long,
+      newCount: Long): Unit = {
+    // Merge running averages:
+    // combined_avg = left_avg * (left_count / total) + right_avg * 
(right_count / total)
+    val leftWeight = currentCount.toFloat / newCount
+    val rightWeight = inputCount.toFloat / newCount
+    var i = 0
+    while (i < dim) {
+      val off = Platform.BYTE_ARRAY_OFFSET + i.toLong * 4
+      val leftAvg = Platform.getFloat(accBytes, off)
+      val rightAvg = Platform.getFloat(inputBytes, off)
+      Platform.putFloat(accBytes, off, leftAvg * leftWeight + rightAvg * 
rightWeight)
+      i += 1
     }
   }
 
-  override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = {
-    val sumOffset = mutableAggBufferOffset + sumIndex
-    val dimOffset = mutableAggBufferOffset + dimIndex
-    val inputSumOffset = inputAggBufferOffset + sumIndex
-    val inputDimOffset = inputAggBufferOffset + dimIndex
+  override protected def withNewChildInternal(newChild: Expression): VectorAvg 
=
+    copy(child = newChild)
+}
 
-    if (inputBuffer.isNullAt(inputSumOffset)) {
-      return
-    }
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(array) - Returns the element-wise sum of float vectors in a group.
+    All vectors must have the same dimension.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(col) FROM VALUES (array(1.0F, 2.0F)), (array(3.0F, 
4.0F)) AS tab(col);
+       [4.0,6.0]
+  """,
+  since = "4.2.0",
+  group = "vector_funcs"
+)
+// scalastyle:on line.size.limit
+// Note: This implementation uses single-precision floating-point arithmetic 
(Float).
+// Precision loss is expected for very large aggregates due to:
+// 1. Accumulated rounding errors when summing many values
+// 2. Loss of significance when adding small values to large accumulated sums
+case class VectorSum(
+    child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0
+) extends VectorAggregateBase {
 
-    val inputSumBytes = inputBuffer.getBinary(inputSumOffset)
-    val inputDim = inputBuffer.getInt(inputDimOffset)
+  def this(child: Expression) = this(child, 0, 0)
 
-    if (buffer.isNullAt(sumOffset)) {
-      // Copy input buffer to current buffer
-      buffer.update(sumOffset, inputSumBytes.clone())
-      buffer.setInt(dimOffset, inputDim)
-    } else {
-      val currentDim = buffer.getInt(dimOffset)
+  override def prettyName: String = "vector_sum"
 
-      // Empty array case
-      if (currentDim == 0 && inputDim == 0) {
-        return
-      }
+  override def withNewMutableAggBufferOffset(
+      newMutableAggBufferOffset: Int
+  ): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
 
-      // Dimension mismatch check
-      if (currentDim != inputDim) {
-        throw QueryExecutionErrors.vectorDimensionMismatchError(
-          prettyName,
-          currentDim,
-          inputDim
-        )
-      }
+  override def withNewInputAggBufferOffset(
+      newInputAggBufferOffset: Int
+  ): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-      // Merge sums: combined_sum = left_sum + right_sum
-      val currentSumBytes = buffer.getBinary(sumOffset)
-      // reuse the buffer without reallocation
-      val sumBuffer =
-        ByteBuffer.wrap(currentSumBytes).order(ByteOrder.LITTLE_ENDIAN)
-      val inputSumBuffer =
-        ByteBuffer.wrap(inputSumBytes).order(ByteOrder.LITTLE_ENDIAN)
-      var i = 0
-      var idx = 0
-      while (i < currentDim) {
-        sumBuffer.putFloat(idx, sumBuffer.getFloat(idx) + 
inputSumBuffer.getFloat(idx))
-        i += 1
-        idx += 4 // 4 bytes per float
-      }
+  override protected def updateElements(
+      accBytes: Array[Byte],
+      inputArray: ArrayData,
+      dim: Int,
+      newCount: Long): Unit = {
+    // Update sum: new_sum = old_sum + new_value
+    var i = 0
+    while (i < dim) {
+      val off = Platform.BYTE_ARRAY_OFFSET + i.toLong * 4
+      Platform.putFloat(accBytes, off, Platform.getFloat(accBytes, off) + 
inputArray.getFloat(i))
+      i += 1
     }
   }
 
-  override def eval(buffer: InternalRow): Any = {
-    val sumOffset = mutableAggBufferOffset + sumIndex
-    if (buffer.isNullAt(sumOffset)) {
-      null
-    } else {
-      val dim = buffer.getInt(mutableAggBufferOffset + dimIndex)
-      val sumBytes = buffer.getBinary(sumOffset)
-      val sumBuffer = ByteBuffer.wrap(sumBytes).order(ByteOrder.LITTLE_ENDIAN)
-      val result = new Array[Float](dim)
-      var i = 0
-      while (i < dim) {
-        result(i) = sumBuffer.getFloat()
-        i += 1
-      }
-      ArrayData.toArrayData(result)
+  override protected def mergeElements(
+      accBytes: Array[Byte],
+      inputBytes: Array[Byte],
+      dim: Int,
+      currentCount: Long,
+      inputCount: Long,
+      newCount: Long): Unit = {
+    // Merge sums: combined_sum = left_sum + right_sum
+    var i = 0
+    while (i < dim) {
+      val off = Platform.BYTE_ARRAY_OFFSET + i.toLong * 4
+      Platform.putFloat(accBytes, off,
+        Platform.getFloat(accBytes, off) + Platform.getFloat(inputBytes, off))
+      i += 1
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to