This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a96afde1b894 [SPARK-55593][SQL] Unify aggregation state for
vector_avg/vector_sum
a96afde1b894 is described below
commit a96afde1b894ceb7a3eb0fab154ea499958633f5
Author: zhidongqu-db <[email protected]>
AuthorDate: Fri Feb 20 10:21:00 2026 +0800
[SPARK-55593][SQL] Unify aggregation state for vector_avg/vector_sum
### What changes were proposed in this pull request?
Unify the aggregate buffer schema for `vector_avg` and `vector_sum` and
extract shared logic into a common base trait to reduce code duplication.
- Unified buffer schema: Both functions now use the same (acc: BINARY,
count: LONG) aggregate state. The dim: INTEGER field is removed — dimension is
inferred from acc.length / 4 (4 bytes per float). For vector_sum, count is
tracked and updated: it marks whether any valid input has been seen (first-vector
initialization and NULL result), but it is not used in the element-wise arithmetic.
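- Extract VectorAggregateBase trait: All shared aggregate lifecycle logic
(initialize, update, merge, eval, input validation, null checking, dimension
mismatch checking, first-vector initialization) is consolidated into a base
trait. Apart from case-class boilerplate (prettyName, the buffer-offset copy
methods, and withNewChildInternal), subclasses only implement two abstract
methods for their element-wise math: updateElements and mergeElements.
- Replace java.nio.ByteBuffer (with explicit LITTLE_ENDIAN order) by
org.apache.spark.unsafe.Platform getFloat/putFloat for reading and writing the
floats stored in the binary aggregate state.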
### Why are the changes needed?
VectorAvg and VectorSum had nearly identical aggregate lifecycle logic
(buffer management, null handling, dimension validation, binary-to-array
conversion) duplicated across ~500 lines. The unified state schema and base
trait remove this duplication (about 150 fewer lines net: 165 insertions, 311
deletions) and make it easier to add new vector aggregate functions in the
future.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing tests
### Was this patch authored or co-authored using generative AI tooling?
Yes, code assistance with Claude Opus 4.6 in combination with manual
editing by the author.
Closes #54368 from zhidongqu-db/unify-vector-agg-func-state.
Authored-by: zhidongqu-db <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../catalyst/expressions/vectorExpressions.scala | 476 +++++++--------------
1 file changed, 165 insertions(+), 311 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/vectorExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/vectorExpressions.scala
index b4f222ccbc5b..e65fae3a2bc2 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/vectorExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/vectorExpressions.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.expressions
-import java.nio.{ByteBuffer, ByteOrder}
-
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -33,11 +31,11 @@ import org.apache.spark.sql.types.{
BinaryType,
DataType,
FloatType,
- IntegerType,
LongType,
StringType,
StructType
}
+import org.apache.spark.unsafe.Platform
// scalastyle:off line.size.limit
@ExpressionDescription(
@@ -345,37 +343,16 @@ case class VectorNormalize(vector: Expression, degree:
Expression)
}
}
-// scalastyle:off line.size.limit
-@ExpressionDescription(
- usage = """
- _FUNC_(array) - Returns the element-wise mean of float vectors in a group.
- All vectors must have the same dimension.
- """,
- examples = """
- Examples:
- > SELECT _FUNC_(col) FROM VALUES (array(1.0F, 2.0F)), (array(3.0F,
4.0F)) AS tab(col);
- [2.0,3.0]
- """,
- since = "4.2.0",
- group = "vector_funcs"
-)
-// scalastyle:on line.size.limit
-// Note: This implementation uses single-precision floating-point arithmetic
(Float).
-// Precision loss is expected for very large aggregates due to:
-// 1. Accumulated rounding errors in incremental average updates
-// 2. Loss of significance when dividing by large counts
-case class VectorAvg(
- child: Expression,
- mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0
-) extends ImperativeAggregate
+// Base trait for vector aggregate functions (vector_avg, vector_sum).
+// Provides a unified aggregate buffer schema: (acc: BINARY, count: LONG)
+// - acc: BINARY representation of the running vector (sum or average)
+// - count: number of valid vectors seen so far
+// - dimension is inferred from acc.length / 4 (4 bytes per float)
+// Subclasses only need to implement the element-wise update and merge logic.
+trait VectorAggregateBase extends ImperativeAggregate
with UnaryLike[Expression]
with QueryErrorsBase {
- def this(child: Expression) = this(child, 0, 0)
-
- override def prettyName: String = "vector_avg"
-
override def nullable: Boolean = true
override def dataType: DataType = ArrayType(FloatType, containsNull = false)
@@ -397,26 +374,17 @@ case class VectorAvg(
}
}
- // Aggregate buffer schema: (avg: BINARY, dim: INTEGER, count: LONG)
- // avg is a BINARY representation of the average vector of floats in the
group
- // dim is the dimension of the vector
- // count is the number of vectors in the group
- // null avg means no valid input has been seen yet
- private lazy val avgAttr = AttributeReference(
- "avg",
+ // Aggregate buffer schema: (acc: BINARY, count: LONG)
+ private lazy val accAttr = AttributeReference(
+ "acc",
BinaryType,
nullable = true
)()
- private lazy val dimAttr = AttributeReference(
- "dim",
- IntegerType,
- nullable = true
- )()
private lazy val countAttr =
AttributeReference("count", LongType, nullable = false)()
override def aggBufferAttributes: Seq[AttributeReference] =
- Seq(avgAttr, dimAttr, countAttr)
+ Seq(accAttr, countAttr)
override def aggBufferSchema: StructType =
DataTypeUtils.fromAttributes(aggBufferAttributes)
@@ -425,27 +393,37 @@ case class VectorAvg(
aggBufferAttributes.map(_.newInstance())
// Buffer indices
- private val avgIndex = 0
- private val dimIndex = 1
- private val countIndex = 2
+ protected val accIndex = 0
+ protected val countIndex = 1
- override def withNewMutableAggBufferOffset(
- newMutableAggBufferOffset: Int
- ): ImperativeAggregate =
- copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
- override def withNewInputAggBufferOffset(
- newInputAggBufferOffset: Int
- ): ImperativeAggregate =
- copy(inputAggBufferOffset = newInputAggBufferOffset)
+ protected lazy val inputContainsNull =
+ child.dataType.asInstanceOf[ArrayType].containsNull
override def initialize(buffer: InternalRow): Unit = {
- buffer.update(mutableAggBufferOffset + avgIndex, null)
- buffer.update(mutableAggBufferOffset + dimIndex, null)
+ buffer.update(mutableAggBufferOffset + accIndex, null)
buffer.setLong(mutableAggBufferOffset + countIndex, 0L)
}
- private lazy val inputContainsNull =
child.dataType.asInstanceOf[ArrayType].containsNull
+ // Infer vector dimension from byte array length (4 bytes per float)
+ protected def getDim(bytes: Array[Byte]): Int = bytes.length / 4
+
+ // Element-wise update for non-first vectors.
+ // accBytes contains the running vector; update it in-place with inputArray.
+ protected def updateElements(
+ accBytes: Array[Byte],
+ inputArray: ArrayData,
+ dim: Int,
+ newCount: Long): Unit
+
+ // Element-wise merge of two non-empty buffers.
+ // accBytes contains the left running vector; update it in-place.
+ protected def mergeElements(
+ accBytes: Array[Byte],
+ inputBytes: Array[Byte],
+ dim: Int,
+ currentCount: Long,
+ inputCount: Long,
+ newCount: Long): Unit
override def update(buffer: InternalRow, input: InternalRow): Unit = {
val inputValue = child.eval(input)
@@ -468,68 +446,50 @@ case class VectorAvg(
}
}
- val avgOffset = mutableAggBufferOffset + avgIndex
- val dimOffset = mutableAggBufferOffset + dimIndex
+ val accOffset = mutableAggBufferOffset + accIndex
val countOffset = mutableAggBufferOffset + countIndex
val currentCount = buffer.getLong(countOffset)
if (currentCount == 0L) {
- // First valid vector - just copy it as the initial average
- val byteBuffer =
- ByteBuffer.allocate(inputLen * 4).order(ByteOrder.LITTLE_ENDIAN)
+ // First valid vector - just copy it
+ val bytes = new Array[Byte](inputLen * 4)
var i = 0
while (i < inputLen) {
- byteBuffer.putFloat(inputArray.getFloat(i))
+ Platform.putFloat(bytes, Platform.BYTE_ARRAY_OFFSET + i.toLong * 4,
inputArray.getFloat(i))
i += 1
}
- buffer.update(avgOffset, byteBuffer.array())
- buffer.setInt(dimOffset, inputLen)
+ buffer.update(accOffset, bytes)
buffer.setLong(countOffset, 1L)
} else {
- val currentDim = buffer.getInt(dimOffset)
+ val accBytes = buffer.getBinary(accOffset)
+ val accDim = getDim(accBytes)
// Empty array case - if current is empty and input is empty, keep empty
- if (currentDim == 0 && inputLen == 0) {
+ if (accDim == 0 && inputLen == 0) {
buffer.setLong(countOffset, currentCount + 1L)
return
}
// Dimension mismatch check
- if (currentDim != inputLen) {
+ if (accDim != inputLen) {
throw QueryExecutionErrors.vectorDimensionMismatchError(
prettyName,
- currentDim,
+ accDim,
inputLen
)
}
- // Update running average: new_avg = old_avg + (new_value - old_avg) /
(count + 1)
val newCount = currentCount + 1L
- val invCount = 1.0f / newCount
- val currentAvgBytes = buffer.getBinary(avgOffset)
- // reuse the buffer without reallocation
- val avgBuffer =
- ByteBuffer.wrap(currentAvgBytes).order(ByteOrder.LITTLE_ENDIAN)
- var i = 0
- var idx = 0
- while (i < currentDim) {
- val oldAvg = avgBuffer.getFloat(idx)
- val newVal = inputArray.getFloat(i)
- avgBuffer.putFloat(idx, oldAvg + (newVal - oldAvg) * invCount)
- i += 1
- idx += 4 // 4 bytes per float
- }
+ updateElements(accBytes, inputArray, accDim, newCount)
buffer.setLong(countOffset, newCount)
}
}
override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = {
- val avgOffset = mutableAggBufferOffset + avgIndex
- val dimOffset = mutableAggBufferOffset + dimIndex
+ val accOffset = mutableAggBufferOffset + accIndex
val countOffset = mutableAggBufferOffset + countIndex
- val inputAvgOffset = inputAggBufferOffset + avgIndex
- val inputDimOffset = inputAggBufferOffset + dimIndex
+ val inputAccOffset = inputAggBufferOffset + accIndex
val inputCountOffset = inputAggBufferOffset + countIndex
val inputCount = inputBuffer.getLong(inputCountOffset)
@@ -537,91 +497,68 @@ case class VectorAvg(
return
}
- val inputAvgBytes = inputBuffer.getBinary(inputAvgOffset)
- val inputDim = inputBuffer.getInt(inputDimOffset)
+ val inputAccBytes = inputBuffer.getBinary(inputAccOffset)
val currentCount = buffer.getLong(countOffset)
if (currentCount == 0L) {
// Copy input buffer to current buffer
- buffer.update(avgOffset, inputAvgBytes.clone())
- buffer.setInt(dimOffset, inputDim)
+ buffer.update(accOffset, inputAccBytes.clone())
buffer.setLong(countOffset, inputCount)
} else {
- val currentDim = buffer.getInt(dimOffset)
+ val accBytes = buffer.getBinary(accOffset)
+ val accDim = getDim(accBytes)
+ val inputDim = getDim(inputAccBytes)
// Empty array case
- if (currentDim == 0 && inputDim == 0) {
+ if (accDim == 0 && inputDim == 0) {
buffer.setLong(countOffset, currentCount + inputCount)
return
}
// Dimension mismatch check
- if (currentDim != inputDim) {
+ if (accDim != inputDim) {
throw QueryExecutionErrors.vectorDimensionMismatchError(
prettyName,
- currentDim,
+ accDim,
inputDim
)
}
- // Merge running averages:
- // combined_avg = (left_avg * left_count) / (left_count + right_count) +
- // (right_avg * right_count) / (left_count + right_count)
val newCount = currentCount + inputCount
- val leftWeight = currentCount.toFloat / newCount
- val rightWeight = inputCount.toFloat / newCount
- val currentAvgBytes = buffer.getBinary(avgOffset)
- // reuse the buffer without reallocation
- val avgBuffer =
- ByteBuffer.wrap(currentAvgBytes).order(ByteOrder.LITTLE_ENDIAN)
- val inputAvgBuffer =
- ByteBuffer.wrap(inputAvgBytes).order(ByteOrder.LITTLE_ENDIAN)
- var i = 0
- var idx = 0
- while (i < currentDim) {
- val leftAvg = avgBuffer.getFloat(idx)
- val rightAvg = inputAvgBuffer.getFloat(idx)
- avgBuffer.putFloat(idx, leftAvg * leftWeight + rightAvg * rightWeight)
- i += 1
- idx += 4 // 4 bytes per float
- }
+ mergeElements(accBytes, inputAccBytes, accDim,
+ currentCount, inputCount, newCount)
buffer.setLong(countOffset, newCount)
}
}
override def eval(buffer: InternalRow): Any = {
- val countOffset = mutableAggBufferOffset + countIndex
- val count = buffer.getLong(countOffset)
+ val count = buffer.getLong(mutableAggBufferOffset + countIndex)
if (count == 0L) {
null
} else {
- val dim = buffer.getInt(mutableAggBufferOffset + dimIndex)
- val avgBytes = buffer.getBinary(mutableAggBufferOffset + avgIndex)
- val avgBuffer = ByteBuffer.wrap(avgBytes).order(ByteOrder.LITTLE_ENDIAN)
+ val accBytes = buffer.getBinary(mutableAggBufferOffset + accIndex)
+ val dim = getDim(accBytes)
val result = new Array[Float](dim)
var i = 0
while (i < dim) {
- result(i) = avgBuffer.getFloat()
+ result(i) = Platform.getFloat(accBytes, Platform.BYTE_ARRAY_OFFSET +
i.toLong * 4)
i += 1
}
ArrayData.toArrayData(result)
}
}
-
- override protected def withNewChildInternal(newChild: Expression): VectorAvg
=
- copy(child = newChild)
}
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
- _FUNC_(array) - Returns the element-wise sum of float vectors in a group.
+ _FUNC_(array) - Returns the element-wise mean of float vectors in a group.
All vectors must have the same dimension.
""",
examples = """
Examples:
> SELECT _FUNC_(col) FROM VALUES (array(1.0F, 2.0F)), (array(3.0F,
4.0F)) AS tab(col);
- [4.0,6.0]
+ [2.0,3.0]
""",
since = "4.2.0",
group = "vector_funcs"
@@ -629,68 +566,17 @@ case class VectorAvg(
// scalastyle:on line.size.limit
// Note: This implementation uses single-precision floating-point arithmetic
(Float).
// Precision loss is expected for very large aggregates due to:
-// 1. Accumulated rounding errors when summing many values
-// 2. Loss of significance when adding small values to large accumulated sums
-case class VectorSum(
+// 1. Accumulated rounding errors in incremental average updates
+// 2. Loss of significance when dividing by large counts
+case class VectorAvg(
child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0
-) extends ImperativeAggregate
- with UnaryLike[Expression]
- with QueryErrorsBase {
+) extends VectorAggregateBase {
def this(child: Expression) = this(child, 0, 0)
- override def prettyName: String = "vector_sum"
-
- override def nullable: Boolean = true
-
- override def dataType: DataType = ArrayType(FloatType, containsNull = false)
-
- override def checkInputDataTypes(): TypeCheckResult = {
- child.dataType match {
- case ArrayType(FloatType, _) =>
- TypeCheckResult.TypeCheckSuccess
- case _ =>
- DataTypeMismatch(
- errorSubClass = "UNEXPECTED_INPUT_TYPE",
- messageParameters = Map(
- "paramIndex" -> ordinalNumber(0),
- "requiredType" -> toSQLType(ArrayType(FloatType)),
- "inputSql" -> toSQLExpr(child),
- "inputType" -> toSQLType(child.dataType)
- )
- )
- }
- }
-
- // Aggregate buffer schema: (sum: BINARY, dim: INTEGER)
- // sum is a BINARY representation of the sum vector of floats in the group
- // dim is the dimension of the vector
- // null sum means no valid input has been seen yet
- private lazy val sumAttr = AttributeReference(
- "sum",
- BinaryType,
- nullable = true
- )()
- private lazy val dimAttr = AttributeReference(
- "dim",
- IntegerType,
- nullable = true
- )()
-
- override def aggBufferAttributes: Seq[AttributeReference] =
- Seq(sumAttr, dimAttr)
-
- override def aggBufferSchema: StructType =
- DataTypeUtils.fromAttributes(aggBufferAttributes)
-
- override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
- aggBufferAttributes.map(_.newInstance())
-
- // Buffer indices
- private val sumIndex = 0
- private val dimIndex = 1
+ override def prettyName: String = "vector_avg"
override def withNewMutableAggBufferOffset(
newMutableAggBufferOffset: Int
@@ -702,146 +588,114 @@ case class VectorSum(
): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)
- private lazy val inputContainsNull =
child.dataType.asInstanceOf[ArrayType].containsNull
-
- override def initialize(buffer: InternalRow): Unit = {
- buffer.update(mutableAggBufferOffset + sumIndex, null)
- buffer.update(mutableAggBufferOffset + dimIndex, null)
- }
-
- override def update(buffer: InternalRow, input: InternalRow): Unit = {
- val inputValue = child.eval(input)
- if (inputValue == null) {
- return
- }
-
- val inputArray = inputValue.asInstanceOf[ArrayData]
- val inputLen = inputArray.numElements()
-
- // Check for NULL elements in input vector - skip if any NULL element found
- // Only check if the array type can contain nulls
- if (inputContainsNull) {
- var i = 0
- while (i < inputLen) {
- if (inputArray.isNullAt(i)) {
- return
- }
- i += 1
- }
+ override protected def updateElements(
+ accBytes: Array[Byte],
+ inputArray: ArrayData,
+ dim: Int,
+ newCount: Long): Unit = {
+ // Update running average: new_avg = old_avg + (new_value - old_avg) /
new_count
+ val invCount = 1.0f / newCount
+ var i = 0
+ while (i < dim) {
+ val off = Platform.BYTE_ARRAY_OFFSET + i.toLong * 4
+ val oldAvg = Platform.getFloat(accBytes, off)
+ Platform.putFloat(accBytes, off, oldAvg + (inputArray.getFloat(i) -
oldAvg) * invCount)
+ i += 1
}
+ }
- val sumOffset = mutableAggBufferOffset + sumIndex
- val dimOffset = mutableAggBufferOffset + dimIndex
-
- if (buffer.isNullAt(sumOffset)) {
- // First valid vector - just copy it as the initial sum
- val byteBuffer =
- ByteBuffer.allocate(inputLen * 4).order(ByteOrder.LITTLE_ENDIAN)
- var i = 0
- while (i < inputLen) {
- byteBuffer.putFloat(inputArray.getFloat(i))
- i += 1
- }
- buffer.update(sumOffset, byteBuffer.array())
- buffer.setInt(dimOffset, inputLen)
- } else {
- val currentDim = buffer.getInt(dimOffset)
-
- // Empty array case - if current is empty and input is empty, keep empty
- if (currentDim == 0 && inputLen == 0) {
- return
- }
-
- // Dimension mismatch check
- if (currentDim != inputLen) {
- throw QueryExecutionErrors.vectorDimensionMismatchError(
- prettyName,
- currentDim,
- inputLen
- )
- }
-
- // Update sum: new_sum = old_sum + new_value
- val currentSumBytes = buffer.getBinary(sumOffset)
- // reuse the buffer without reallocation
- val sumBuffer =
- ByteBuffer.wrap(currentSumBytes).order(ByteOrder.LITTLE_ENDIAN)
- var i = 0
- var idx = 0
- while (i < currentDim) {
- sumBuffer.putFloat(idx, sumBuffer.getFloat(idx) +
inputArray.getFloat(i))
- i += 1
- idx += 4 // 4 bytes per float
- }
+ override protected def mergeElements(
+ accBytes: Array[Byte],
+ inputBytes: Array[Byte],
+ dim: Int,
+ currentCount: Long,
+ inputCount: Long,
+ newCount: Long): Unit = {
+ // Merge running averages:
+ // combined_avg = left_avg * (left_count / total) + right_avg *
(right_count / total)
+ val leftWeight = currentCount.toFloat / newCount
+ val rightWeight = inputCount.toFloat / newCount
+ var i = 0
+ while (i < dim) {
+ val off = Platform.BYTE_ARRAY_OFFSET + i.toLong * 4
+ val leftAvg = Platform.getFloat(accBytes, off)
+ val rightAvg = Platform.getFloat(inputBytes, off)
+ Platform.putFloat(accBytes, off, leftAvg * leftWeight + rightAvg *
rightWeight)
+ i += 1
}
}
- override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = {
- val sumOffset = mutableAggBufferOffset + sumIndex
- val dimOffset = mutableAggBufferOffset + dimIndex
- val inputSumOffset = inputAggBufferOffset + sumIndex
- val inputDimOffset = inputAggBufferOffset + dimIndex
+ override protected def withNewChildInternal(newChild: Expression): VectorAvg
=
+ copy(child = newChild)
+}
- if (inputBuffer.isNullAt(inputSumOffset)) {
- return
- }
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(array) - Returns the element-wise sum of float vectors in a group.
+ All vectors must have the same dimension.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(col) FROM VALUES (array(1.0F, 2.0F)), (array(3.0F,
4.0F)) AS tab(col);
+ [4.0,6.0]
+ """,
+ since = "4.2.0",
+ group = "vector_funcs"
+)
+// scalastyle:on line.size.limit
+// Note: This implementation uses single-precision floating-point arithmetic
(Float).
+// Precision loss is expected for very large aggregates due to:
+// 1. Accumulated rounding errors when summing many values
+// 2. Loss of significance when adding small values to large accumulated sums
+case class VectorSum(
+ child: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0
+) extends VectorAggregateBase {
- val inputSumBytes = inputBuffer.getBinary(inputSumOffset)
- val inputDim = inputBuffer.getInt(inputDimOffset)
+ def this(child: Expression) = this(child, 0, 0)
- if (buffer.isNullAt(sumOffset)) {
- // Copy input buffer to current buffer
- buffer.update(sumOffset, inputSumBytes.clone())
- buffer.setInt(dimOffset, inputDim)
- } else {
- val currentDim = buffer.getInt(dimOffset)
+ override def prettyName: String = "vector_sum"
- // Empty array case
- if (currentDim == 0 && inputDim == 0) {
- return
- }
+ override def withNewMutableAggBufferOffset(
+ newMutableAggBufferOffset: Int
+ ): ImperativeAggregate =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
- // Dimension mismatch check
- if (currentDim != inputDim) {
- throw QueryExecutionErrors.vectorDimensionMismatchError(
- prettyName,
- currentDim,
- inputDim
- )
- }
+ override def withNewInputAggBufferOffset(
+ newInputAggBufferOffset: Int
+ ): ImperativeAggregate =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
- // Merge sums: combined_sum = left_sum + right_sum
- val currentSumBytes = buffer.getBinary(sumOffset)
- // reuse the buffer without reallocation
- val sumBuffer =
- ByteBuffer.wrap(currentSumBytes).order(ByteOrder.LITTLE_ENDIAN)
- val inputSumBuffer =
- ByteBuffer.wrap(inputSumBytes).order(ByteOrder.LITTLE_ENDIAN)
- var i = 0
- var idx = 0
- while (i < currentDim) {
- sumBuffer.putFloat(idx, sumBuffer.getFloat(idx) +
inputSumBuffer.getFloat(idx))
- i += 1
- idx += 4 // 4 bytes per float
- }
+ override protected def updateElements(
+ accBytes: Array[Byte],
+ inputArray: ArrayData,
+ dim: Int,
+ newCount: Long): Unit = {
+ // Update sum: new_sum = old_sum + new_value
+ var i = 0
+ while (i < dim) {
+ val off = Platform.BYTE_ARRAY_OFFSET + i.toLong * 4
+ Platform.putFloat(accBytes, off, Platform.getFloat(accBytes, off) +
inputArray.getFloat(i))
+ i += 1
}
}
- override def eval(buffer: InternalRow): Any = {
- val sumOffset = mutableAggBufferOffset + sumIndex
- if (buffer.isNullAt(sumOffset)) {
- null
- } else {
- val dim = buffer.getInt(mutableAggBufferOffset + dimIndex)
- val sumBytes = buffer.getBinary(sumOffset)
- val sumBuffer = ByteBuffer.wrap(sumBytes).order(ByteOrder.LITTLE_ENDIAN)
- val result = new Array[Float](dim)
- var i = 0
- while (i < dim) {
- result(i) = sumBuffer.getFloat()
- i += 1
- }
- ArrayData.toArrayData(result)
+ override protected def mergeElements(
+ accBytes: Array[Byte],
+ inputBytes: Array[Byte],
+ dim: Int,
+ currentCount: Long,
+ inputCount: Long,
+ newCount: Long): Unit = {
+ // Merge sums: combined_sum = left_sum + right_sum
+ var i = 0
+ while (i < dim) {
+ val off = Platform.BYTE_ARRAY_OFFSET + i.toLong * 4
+ Platform.putFloat(accBytes, off,
+ Platform.getFloat(accBytes, off) + Platform.getFloat(inputBytes, off))
+ i += 1
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]