zhidongqu-db commented on code in PR #54368:
URL: https://github.com/apache/spark/pull/54368#discussion_r2829607199
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/vectorExpressions.scala:
##########
@@ -702,146 +588,114 @@ case class VectorSum(
): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)
- private lazy val inputContainsNull =
child.dataType.asInstanceOf[ArrayType].containsNull
-
- override def initialize(buffer: InternalRow): Unit = {
- buffer.update(mutableAggBufferOffset + sumIndex, null)
- buffer.update(mutableAggBufferOffset + dimIndex, null)
- }
-
- override def update(buffer: InternalRow, input: InternalRow): Unit = {
- val inputValue = child.eval(input)
- if (inputValue == null) {
- return
- }
-
- val inputArray = inputValue.asInstanceOf[ArrayData]
- val inputLen = inputArray.numElements()
-
- // Check for NULL elements in input vector - skip if any NULL element found
- // Only check if the array type can contain nulls
- if (inputContainsNull) {
- var i = 0
- while (i < inputLen) {
- if (inputArray.isNullAt(i)) {
- return
- }
- i += 1
- }
+ override protected def updateElements(
+ accBytes: Array[Byte],
+ inputArray: ArrayData,
+ dim: Int,
+ newCount: Long): Unit = {
+ // Update running average: new_avg = old_avg + (new_value - old_avg) /
new_count
+ val invCount = 1.0f / newCount
+ var i = 0
+ while (i < dim) {
+ val off = Platform.BYTE_ARRAY_OFFSET + i.toLong * 4
+ val oldAvg = Platform.getFloat(accBytes, off)
+ Platform.putFloat(accBytes, off, oldAvg + (inputArray.getFloat(i) -
oldAvg) * invCount)
+ i += 1
}
+ }
- val sumOffset = mutableAggBufferOffset + sumIndex
- val dimOffset = mutableAggBufferOffset + dimIndex
-
- if (buffer.isNullAt(sumOffset)) {
- // First valid vector - just copy it as the initial sum
- val byteBuffer =
- ByteBuffer.allocate(inputLen * 4).order(ByteOrder.LITTLE_ENDIAN)
- var i = 0
- while (i < inputLen) {
- byteBuffer.putFloat(inputArray.getFloat(i))
- i += 1
- }
- buffer.update(sumOffset, byteBuffer.array())
- buffer.setInt(dimOffset, inputLen)
- } else {
- val currentDim = buffer.getInt(dimOffset)
-
- // Empty array case - if current is empty and input is empty, keep empty
- if (currentDim == 0 && inputLen == 0) {
- return
- }
-
- // Dimension mismatch check
- if (currentDim != inputLen) {
- throw QueryExecutionErrors.vectorDimensionMismatchError(
- prettyName,
- currentDim,
- inputLen
- )
- }
-
- // Update sum: new_sum = old_sum + new_value
- val currentSumBytes = buffer.getBinary(sumOffset)
- // reuse the buffer without reallocation
- val sumBuffer =
- ByteBuffer.wrap(currentSumBytes).order(ByteOrder.LITTLE_ENDIAN)
- var i = 0
- var idx = 0
- while (i < currentDim) {
- sumBuffer.putFloat(idx, sumBuffer.getFloat(idx) +
inputArray.getFloat(i))
- i += 1
- idx += 4 // 4 bytes per float
- }
+ override protected def mergeElements(
+ accBytes: Array[Byte],
+ inputBytes: Array[Byte],
+ dim: Int,
+ currentCount: Long,
+ inputCount: Long,
+ newCount: Long): Unit = {
+ // Merge running averages:
+ // combined_avg = left_avg * (left_count / total) + right_avg *
(right_count / total)
+ val leftWeight = currentCount.toFloat / newCount
+ val rightWeight = inputCount.toFloat / newCount
+ var i = 0
+ while (i < dim) {
+ val off = Platform.BYTE_ARRAY_OFFSET + i.toLong * 4
+ val leftAvg = Platform.getFloat(accBytes, off)
+ val rightAvg = Platform.getFloat(inputBytes, off)
+ Platform.putFloat(accBytes, off, leftAvg * leftWeight + rightAvg *
rightWeight)
+ i += 1
}
}
- override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = {
- val sumOffset = mutableAggBufferOffset + sumIndex
- val dimOffset = mutableAggBufferOffset + dimIndex
- val inputSumOffset = inputAggBufferOffset + sumIndex
- val inputDimOffset = inputAggBufferOffset + dimIndex
+ override protected def withNewChildInternal(newChild: Expression): VectorAvg
=
+ copy(child = newChild)
+}
- if (inputBuffer.isNullAt(inputSumOffset)) {
- return
- }
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(array) - Returns the element-wise sum of float vectors in a group.
+ All vectors must have the same dimension.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(col) FROM VALUES (array(1.0F, 2.0F)), (array(3.0F,
4.0F)) AS tab(col);
+ [4.0,6.0]
+ """,
+ since = "4.2.0",
+ group = "vector_funcs"
+)
+// scalastyle:on line.size.limit
+// Note: This implementation uses single-precision floating-point arithmetic
(Float).
+// Precision loss is expected for very large aggregates due to:
+// 1. Accumulated rounding errors when summing many values
+// 2. Loss of significance when adding small values to large accumulated sums
+case class VectorSum(
+ child: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0
+) extends VectorAggregateBase {
- val inputSumBytes = inputBuffer.getBinary(inputSumOffset)
- val inputDim = inputBuffer.getInt(inputDimOffset)
+ def this(child: Expression) = this(child, 0, 0)
- if (buffer.isNullAt(sumOffset)) {
- // Copy input buffer to current buffer
- buffer.update(sumOffset, inputSumBytes.clone())
- buffer.setInt(dimOffset, inputDim)
- } else {
- val currentDim = buffer.getInt(dimOffset)
+ override def prettyName: String = "vector_sum"
- // Empty array case
- if (currentDim == 0 && inputDim == 0) {
- return
- }
+ override def withNewMutableAggBufferOffset(
+ newMutableAggBufferOffset: Int
+ ): ImperativeAggregate =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
- // Dimension mismatch check
- if (currentDim != inputDim) {
- throw QueryExecutionErrors.vectorDimensionMismatchError(
- prettyName,
- currentDim,
- inputDim
- )
- }
+ override def withNewInputAggBufferOffset(
+ newInputAggBufferOffset: Int
+ ): ImperativeAggregate =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
- // Merge sums: combined_sum = left_sum + right_sum
- val currentSumBytes = buffer.getBinary(sumOffset)
- // reuse the buffer without reallocation
- val sumBuffer =
- ByteBuffer.wrap(currentSumBytes).order(ByteOrder.LITTLE_ENDIAN)
- val inputSumBuffer =
- ByteBuffer.wrap(inputSumBytes).order(ByteOrder.LITTLE_ENDIAN)
- var i = 0
- var idx = 0
- while (i < currentDim) {
- sumBuffer.putFloat(idx, sumBuffer.getFloat(idx) +
inputSumBuffer.getFloat(idx))
- i += 1
- idx += 4 // 4 bytes per float
- }
+ override protected def updateElements(
+ accBytes: Array[Byte],
+ inputArray: ArrayData,
+ dim: Int,
+ newCount: Long): Unit = {
+ // Update sum: new_sum = old_sum + new_value
+ var i = 0
+ while (i < dim) {
+ val off = Platform.BYTE_ARRAY_OFFSET + i.toLong * 4
+ Platform.putFloat(accBytes, off, Platform.getFloat(accBytes, off) +
inputArray.getFloat(i))
Review Comment:
I believe Spark assumes a homogeneous cluster: even though these buffers are
serialized during spilling and shuffles, all executors should run on the same
architecture, so native byte ordering should be safe. UnsafeRow,
UnsafeArrayData, and ApproxCountDistinctForIntervals all use
Platform.putLong/Platform.getLong on Array[Byte] without any endianness
conversion.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]