This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4729b9917107 [SPARK-55031][SQL] Add vector avg/sum aggregation function expressions
4729b9917107 is described below
commit 4729b9917107f6bed01ee7e91d3e6a988a11b954
Author: zhidongqu-db <[email protected]>
AuthorDate: Fri Jan 30 11:34:00 2026 +0800
[SPARK-55031][SQL] Add vector avg/sum aggregation function expressions
### What changes were proposed in this pull request?
This PR adds support for vector aggregation functions to Spark SQL,
enabling element-wise sum and average computations across groups of vectors.
- vector_sum(vectors) - Returns the element-wise sum of float vectors in a
group. Each element in the result is the sum of the corresponding elements
across all input vectors.
- vector_avg(vectors) - Returns the element-wise average of float vectors
in a group. Each element in the result is the arithmetic mean of the
corresponding elements across all input vectors.
Key implementation details:
- Type Safety: Functions accept only ARRAY<FLOAT> for vectors. No implicit
type casting is performed; passing ARRAY<DOUBLE> or ARRAY<INT> results in a
DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE error.
- Dimension Validation: All vectors in a group must have the same
dimension; throws VECTOR_DIMENSION_MISMATCH error if dimensions do not match.
- NULL Handling: NULL vectors are skipped during aggregation. Vectors that
contain NULL elements are likewise skipped in their entirety.
- Edge Cases: Returns NULL if every vector in the group is NULL or contains a
NULL element. Returns an empty array [] if all input vectors are empty.
- Compact Buffer Storage: Aggregate state uses a BINARY buffer (dim * 4
bytes) instead of ARRAY<FLOAT>, giving more compact storage with no per-element
null-field overhead (see the sketch after this list).
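To make the storage scheme concrete, here is a minimal standalone sketch
(illustration only, not the PR's code; it mirrors the little-endian float
packing and the running-mean update in `vectorExpressions.scala` below):

```scala
import java.nio.{ByteBuffer, ByteOrder}

// Pack a float vector into a compact dim * 4 byte blob (little-endian).
def encode(vec: Array[Float]): Array[Byte] = {
  val buf = ByteBuffer.allocate(vec.length * 4).order(ByteOrder.LITTLE_ENDIAN)
  vec.foreach(buf.putFloat)
  buf.array()
}

// Unpack the blob back into a float vector of the given dimension.
def decode(bytes: Array[Byte], dim: Int): Array[Float] = {
  val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
  Array.fill(dim)(buf.getFloat())
}

// One element-wise vector_avg update step:
// new_avg = old_avg + (new_value - old_avg) / (count + 1)
def updateAvg(avg: Array[Float], next: Array[Float], count: Long): Array[Float] =
  avg.zip(next).map { case (a, v) => a + (v - a) / (count + 1).toFloat }
```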
This PR only includes SQL language support; the DataFrame API will be added in
a separate PR.
### Why are the changes needed?
Vector aggregation functions are fundamental operations for:
- Clustering workloads: Computing cluster centroids by averaging member
vectors
- RAG applications: Aggregating embeddings across document chunks
- Distributed ML: Gradient accumulation and combining pre-aggregated
vectors across partitions
- Recommendation systems: Computing user preference vectors from
interaction history
These functions complement the vector distance/similarity functions and are
commonly available in other systems (Snowflake's VECTOR_SUM/VECTOR_AVG,
PostgreSQL pgvector's SUM/AVG over vectors).
### Does this PR introduce _any_ user-facing change?
Yes, this PR introduces two new SQL aggregate functions:
```
-- Setup example table
CREATE TABLE vector_data (category STRING, embedding ARRAY<FLOAT>);
INSERT INTO vector_data VALUES
('A', array(1.0F, 2.0F, 3.0F)),
('A', array(4.0F, 5.0F, 6.0F)),
('B', array(2.0F, 1.0F, 4.0F)),
('B', array(3.0F, 2.0F, 1.0F));
-- Element-wise sum per category
SELECT category, vector_sum(embedding) AS sum_vector
FROM vector_data
GROUP BY category
ORDER BY category;
-- category: A, sum_vector: [5.0, 7.0, 9.0]
-- category: B, sum_vector: [5.0, 3.0, 5.0]
-- Element-wise average per category (centroid computation)
SELECT category, vector_avg(embedding) AS centroid
FROM vector_data
GROUP BY category
ORDER BY category;
-- category: A, centroid: [2.5, 3.5, 4.5]
-- category: B, centroid: [2.5, 1.5, 2.5]
-- Scalar aggregation (no GROUP BY)
SELECT vector_sum(embedding) AS total_sum, vector_avg(embedding) AS overall_centroid
FROM vector_data;
-- total_sum: [10.0, 10.0, 14.0]
-- overall_centroid: [2.5, 2.5, 3.5]
-- NULL vectors are skipped
INSERT INTO vector_data VALUES ('A', NULL);
SELECT category, vector_avg(embedding) FROM vector_data WHERE category = 'A' GROUP BY category;
-- Returns: [2.5, 3.5, 4.5] (unchanged, NULL skipped)
-- Vectors with NULL elements are skipped
INSERT INTO vector_data VALUES ('A', array(100.0F, NULL, 100.0F));
SELECT category, vector_avg(embedding) FROM vector_data WHERE category = 'A' GROUP BY category;
-- Returns: [2.5, 3.5, 4.5] (unchanged, vector with NULL element skipped)
```
### How was this patch tested?
SQL Golden File Tests: Added `vector-agg.sql` with the following test coverage:
- Basic functionality tests for both `vector_sum` and `vector_avg`
- GROUP BY aggregation and scalar aggregation (no GROUP BY)
- Mathematical correctness validation
- Empty vector handling (returns empty array)
- NULL vector handling (skipped in aggregation)
- NULL element within vector handling (entire vector skipped)
- All-NULL group handling (returns NULL)
- Dimension mismatch error cases
- Type mismatch error cases
- Single element vectors
- Large vectors (16 elements)
- Window function aggregation (PARTITION BY)
- Special float values: NaN, Infinity, -Infinity (IEEE 754 propagation)
Unit tests: Added `VectorAggSuite.scala` to test the aggregate lifecycle phases (see the sketch after this list):
- `initialize()`: Empty buffer returns null
- `update()`: Single/multiple vectors, NULL handling, special floats
- `merge()`: Two buffers, empty buffers, different counts (for weighted average)
- `eval()`: Result extraction from binary buffer
- Numerical stability test for running average algorithm
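For reference, the lifecycle these tests walk through looks roughly like this
(a condensed sketch built from the same helpers as the suite below, not
additional PR code):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, SpecificInternalRow, VectorSum}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{ArrayType, FloatType}

val agg = new VectorSum(new BoundReference(0, ArrayType(FloatType), nullable = true))

// initialize(): fresh buffer; eval() on it would return null.
val buffer = new SpecificInternalRow(agg.aggBufferAttributes.map(_.dataType))
agg.initialize(buffer)

// update(): fold input rows into the buffer.
agg.update(buffer, InternalRow(ArrayData.toArrayData(Array(1.0f, 2.0f))))
agg.update(buffer, InternalRow(ArrayData.toArrayData(Array(3.0f, 4.0f))))

// merge(): combine a second, independently updated buffer.
val other = new SpecificInternalRow(agg.aggBufferAttributes.map(_.dataType))
agg.initialize(other)
agg.update(other, InternalRow(ArrayData.toArrayData(Array(5.0f, 6.0f))))
agg.merge(buffer, other)

// eval(): extract the final ARRAY<FLOAT>, here [9.0, 12.0].
val result = agg.eval(buffer)
```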
### Was this patch authored or co-authored using generative AI tooling?
Yes, code was written with assistance from Claude Opus 4.5, in combination
with manual editing by the author.
Closes #54011 from zhidongqu-db/vector-avg-sum-functions.
Authored-by: zhidongqu-db <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/catalyst/analysis/FunctionRegistry.scala | 2 +
.../catalyst/expressions/vectorExpressions.scala | 473 +++++++++++++++++++-
.../expressions/aggregate/VectorAggSuite.scala | 497 +++++++++++++++++++++
.../sql-functions/sql-expression-schema.md | 2 +
.../sql-tests/analyzer-results/vector-agg.sql.out | 243 ++++++++++
.../test/resources/sql-tests/inputs/vector-agg.sql | 81 ++++
.../resources/sql-tests/results/vector-agg.sql.out | 265 +++++++++++
7 files changed, 1561 insertions(+), 2 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 9bc181815f9c..014b0ee606fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -553,6 +553,8 @@ object FunctionRegistry {
expression[VectorL2Distance]("vector_l2_distance"),
expression[VectorNorm]("vector_norm"),
expression[VectorNormalize]("vector_normalize"),
+ expression[VectorAvg]("vector_avg"),
+ expression[VectorSum]("vector_sum"),
// string functions
expression[Ascii]("ascii"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/vectorExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/vectorExpressions.scala
index 9ff2ef57c88b..733ac88d952d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/vectorExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/vectorExpressions.scala
@@ -17,11 +17,27 @@
package org.apache.spark.sql.catalyst.expressions
+import java.nio.{ByteBuffer, ByteOrder}
+
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
-import org.apache.spark.sql.errors.QueryErrorsBase
-import org.apache.spark.sql.types.{ArrayType, FloatType, StringType}
+import org.apache.spark.sql.catalyst.trees.UnaryLike
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors}
+import org.apache.spark.sql.types.{
+ ArrayType,
+ BinaryType,
+ DataType,
+ FloatType,
+ IntegerType,
+ LongType,
+ StringType,
+ StructType
+}
// scalastyle:off line.size.limit
@ExpressionDescription(
@@ -328,3 +344,456 @@ case class VectorNormalize(vector: Expression, degree: Expression)
copy(vector = newChildren(0), degree = newChildren(1))
}
}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(array) - Returns the element-wise mean of float vectors in a group.
+ All vectors must have the same dimension.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(col) FROM VALUES (array(1.0F, 2.0F)), (array(3.0F, 4.0F)) AS tab(col);
+ [2.0,3.0]
+ """,
+ since = "4.2.0",
+ group = "vector_funcs"
+)
+// scalastyle:on line.size.limit
+case class VectorAvg(
+ child: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0
+) extends ImperativeAggregate
+ with UnaryLike[Expression]
+ with QueryErrorsBase {
+
+ def this(child: Expression) = this(child, 0, 0)
+
+ override def prettyName: String = "vector_avg"
+
+ override def nullable: Boolean = true
+
+ override def dataType: DataType = ArrayType(FloatType, containsNull = false)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ child.dataType match {
+ case ArrayType(FloatType, _) =>
+ TypeCheckResult.TypeCheckSuccess
+ case _ =>
+ DataTypeMismatch(
+ errorSubClass = "UNEXPECTED_INPUT_TYPE",
+ messageParameters = Map(
+ "paramIndex" -> ordinalNumber(0),
+ "requiredType" -> toSQLType(ArrayType(FloatType)),
+ "inputSql" -> toSQLExpr(child),
+ "inputType" -> toSQLType(child.dataType)
+ )
+ )
+ }
+ }
+
+ // Aggregate buffer schema: (avg: BINARY, dim: INTEGER, count: LONG)
+ // avg is a BINARY representation of the average vector of floats in the group
+ // dim is the dimension of the vector
+ // count is the number of vectors in the group
+ // null avg means no valid input has been seen yet
+ private lazy val avgAttr = AttributeReference(
+ "avg",
+ BinaryType,
+ nullable = true
+ )()
+ private lazy val dimAttr = AttributeReference(
+ "dim",
+ IntegerType,
+ nullable = true
+ )()
+ private lazy val countAttr =
+ AttributeReference("count", LongType, nullable = false)()
+
+ override def aggBufferAttributes: Seq[AttributeReference] =
+ Seq(avgAttr, dimAttr, countAttr)
+
+ override def aggBufferSchema: StructType =
+ DataTypeUtils.fromAttributes(aggBufferAttributes)
+
+ override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
+ aggBufferAttributes.map(_.newInstance())
+
+ // Buffer indices
+ private val avgIndex = 0
+ private val dimIndex = 1
+ private val countIndex = 2
+
+ override def withNewMutableAggBufferOffset(
+ newMutableAggBufferOffset: Int
+ ): ImperativeAggregate =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+ override def withNewInputAggBufferOffset(
+ newInputAggBufferOffset: Int
+ ): ImperativeAggregate =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+ override def initialize(buffer: InternalRow): Unit = {
+ buffer.update(mutableAggBufferOffset + avgIndex, null)
+ buffer.update(mutableAggBufferOffset + dimIndex, null)
+ buffer.setLong(mutableAggBufferOffset + countIndex, 0L)
+ }
+
+ override def update(buffer: InternalRow, input: InternalRow): Unit = {
+ val inputValue = child.eval(input)
+ if (inputValue == null) {
+ return
+ }
+
+ val inputArray = inputValue.asInstanceOf[ArrayData]
+ val inputLen = inputArray.numElements()
+
+ // Check for NULL elements in input vector - skip if any NULL element found
+ for (i <- 0 until inputLen) {
+ if (inputArray.isNullAt(i)) {
+ return
+ }
+ }
+
+ val currentCount = buffer.getLong(mutableAggBufferOffset + countIndex)
+
+ if (currentCount == 0L) {
+ // First valid vector - just copy it as the initial average
+ val byteBuffer =
+ ByteBuffer.allocate(inputLen * 4).order(ByteOrder.LITTLE_ENDIAN)
+ for (i <- 0 until inputLen) {
+ byteBuffer.putFloat(inputArray.getFloat(i))
+ }
+ buffer.update(mutableAggBufferOffset + avgIndex, byteBuffer.array())
+ buffer.setInt(mutableAggBufferOffset + dimIndex, inputLen)
+ buffer.setLong(mutableAggBufferOffset + countIndex, 1L)
+ } else {
+ val currentDim = buffer.getInt(mutableAggBufferOffset + dimIndex)
+
+ // Empty array case - if current is empty and input is empty, keep empty
+ if (currentDim == 0 && inputLen == 0) {
+ buffer.setLong(mutableAggBufferOffset + countIndex, currentCount + 1L)
+ return
+ }
+
+ // Dimension mismatch check
+ if (currentDim != inputLen) {
+ throw QueryExecutionErrors.vectorDimensionMismatchError(
+ prettyName,
+ currentDim,
+ inputLen
+ )
+ }
+
+ // Update running average: new_avg = old_avg + (new_value - old_avg) / (count + 1)
+ val newCount = currentCount + 1L
+ val currentAvgBytes = buffer.getBinary(mutableAggBufferOffset + avgIndex)
+ val currentAvgBuffer =
+ ByteBuffer.wrap(currentAvgBytes).order(ByteOrder.LITTLE_ENDIAN)
+ val newAvgBuffer =
+ ByteBuffer.allocate(currentDim * 4).order(ByteOrder.LITTLE_ENDIAN)
+ for (i <- 0 until currentDim) {
+ val oldAvg = currentAvgBuffer.getFloat()
+ val newVal = inputArray.getFloat(i)
+ newAvgBuffer.putFloat(oldAvg + ((newVal - oldAvg) / newCount.toFloat))
+ }
+ buffer.update(mutableAggBufferOffset + avgIndex, newAvgBuffer.array())
+ buffer.setLong(mutableAggBufferOffset + countIndex, newCount)
+ }
+ }
+
+ override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = {
+ val inputCount = inputBuffer.getLong(inputAggBufferOffset + countIndex)
+ if (inputCount == 0L) {
+ return
+ }
+
+ val inputAvgBytes = inputBuffer.getBinary(inputAggBufferOffset + avgIndex)
+ val inputDim = inputBuffer.getInt(inputAggBufferOffset + dimIndex)
+ val currentCount = buffer.getLong(mutableAggBufferOffset + countIndex)
+
+ if (currentCount == 0L) {
+ // Copy input buffer to current buffer
+ buffer.update(mutableAggBufferOffset + avgIndex, inputAvgBytes.clone())
+ buffer.setInt(mutableAggBufferOffset + dimIndex, inputDim)
+ buffer.setLong(mutableAggBufferOffset + countIndex, inputCount)
+ } else {
+ val currentDim = buffer.getInt(mutableAggBufferOffset + dimIndex)
+
+ // Empty array case
+ if (currentDim == 0 && inputDim == 0) {
+ buffer.setLong(
+ mutableAggBufferOffset + countIndex,
+ currentCount + inputCount
+ )
+ return
+ }
+
+ // Dimension mismatch check
+ if (currentDim != inputDim) {
+ throw QueryExecutionErrors.vectorDimensionMismatchError(
+ prettyName,
+ currentDim,
+ inputDim
+ )
+ }
+
+ // Merge running averages:
+ // combined_avg = (left_avg * left_count) / (left_count + right_count) +
+ // (right_avg * right_count) / (left_count + right_count)
+ val newCount = currentCount + inputCount
+ val currentAvgBytes = buffer.getBinary(mutableAggBufferOffset + avgIndex)
+ val currentAvgBuffer =
+ ByteBuffer.wrap(currentAvgBytes).order(ByteOrder.LITTLE_ENDIAN)
+ val inputAvgBuffer =
+ ByteBuffer.wrap(inputAvgBytes).order(ByteOrder.LITTLE_ENDIAN)
+ val newAvgBuffer =
+ ByteBuffer.allocate(currentDim * 4).order(ByteOrder.LITTLE_ENDIAN)
+ for (_ <- 0 until currentDim) {
+ // getFloat() will auto-increment the buffer's current position by 4
+ val leftAvg = currentAvgBuffer.getFloat()
+ val rightAvg = inputAvgBuffer.getFloat()
+ newAvgBuffer.putFloat(
+ (leftAvg * currentCount) / newCount.toFloat +
+ (rightAvg * inputCount) / newCount.toFloat
+ )
+ }
+ buffer.update(mutableAggBufferOffset + avgIndex, newAvgBuffer.array())
+ buffer.setLong(mutableAggBufferOffset + countIndex, newCount)
+ }
+ }
+
+ override def eval(buffer: InternalRow): Any = {
+ val count = buffer.getLong(mutableAggBufferOffset + countIndex)
+ if (count == 0L) {
+ null
+ } else {
+ val dim = buffer.getInt(mutableAggBufferOffset + dimIndex)
+ val avgBytes = buffer.getBinary(mutableAggBufferOffset + avgIndex)
+ val avgBuffer = ByteBuffer.wrap(avgBytes).order(ByteOrder.LITTLE_ENDIAN)
+ val result = new Array[Float](dim)
+ for (i <- 0 until dim) {
+ result(i) = avgBuffer.getFloat()
+ }
+ ArrayData.toArrayData(result)
+ }
+ }
+
+ override protected def withNewChildInternal(newChild: Expression): VectorAvg =
+ copy(child = newChild)
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(array) - Returns the element-wise sum of float vectors in a group.
+ All vectors must have the same dimension.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(col) FROM VALUES (array(1.0F, 2.0F)), (array(3.0F, 4.0F)) AS tab(col);
+ [4.0,6.0]
+ """,
+ since = "4.2.0",
+ group = "vector_funcs"
+)
+// scalastyle:on line.size.limit
+case class VectorSum(
+ child: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0
+) extends ImperativeAggregate
+ with UnaryLike[Expression]
+ with QueryErrorsBase {
+
+ def this(child: Expression) = this(child, 0, 0)
+
+ override def prettyName: String = "vector_sum"
+
+ override def nullable: Boolean = true
+
+ override def dataType: DataType = ArrayType(FloatType, containsNull = false)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ child.dataType match {
+ case ArrayType(FloatType, _) =>
+ TypeCheckResult.TypeCheckSuccess
+ case _ =>
+ DataTypeMismatch(
+ errorSubClass = "UNEXPECTED_INPUT_TYPE",
+ messageParameters = Map(
+ "paramIndex" -> ordinalNumber(0),
+ "requiredType" -> toSQLType(ArrayType(FloatType)),
+ "inputSql" -> toSQLExpr(child),
+ "inputType" -> toSQLType(child.dataType)
+ )
+ )
+ }
+ }
+
+ // Aggregate buffer schema: (sum: BINARY, dim: INTEGER)
+ // sum is a BINARY representation of the sum vector of floats in the group
+ // dim is the dimension of the vector
+ // null sum means no valid input has been seen yet
+ private lazy val sumAttr = AttributeReference(
+ "sum",
+ BinaryType,
+ nullable = true
+ )()
+ private lazy val dimAttr = AttributeReference(
+ "dim",
+ IntegerType,
+ nullable = true
+ )()
+
+ override def aggBufferAttributes: Seq[AttributeReference] =
+ Seq(sumAttr, dimAttr)
+
+ override def aggBufferSchema: StructType =
+ DataTypeUtils.fromAttributes(aggBufferAttributes)
+
+ override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
+ aggBufferAttributes.map(_.newInstance())
+
+ // Buffer indices
+ private val sumIndex = 0
+ private val dimIndex = 1
+
+ override def withNewMutableAggBufferOffset(
+ newMutableAggBufferOffset: Int
+ ): ImperativeAggregate =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+ override def withNewInputAggBufferOffset(
+ newInputAggBufferOffset: Int
+ ): ImperativeAggregate =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+ override def initialize(buffer: InternalRow): Unit = {
+ buffer.update(mutableAggBufferOffset + sumIndex, null)
+ buffer.update(mutableAggBufferOffset + dimIndex, null)
+ }
+
+ override def update(buffer: InternalRow, input: InternalRow): Unit = {
+ val inputValue = child.eval(input)
+ if (inputValue == null) {
+ return
+ }
+
+ val inputArray = inputValue.asInstanceOf[ArrayData]
+ val inputLen = inputArray.numElements()
+
+ // Check for NULL elements in input vector - skip if any NULL element found
+ for (i <- 0 until inputLen) {
+ if (inputArray.isNullAt(i)) {
+ return
+ }
+ }
+
+ if (buffer.isNullAt(mutableAggBufferOffset + sumIndex)) {
+ // First valid vector - just copy it as the initial sum
+ val byteBuffer =
+ ByteBuffer.allocate(inputLen * 4).order(ByteOrder.LITTLE_ENDIAN)
+ for (i <- 0 until inputLen) {
+ byteBuffer.putFloat(inputArray.getFloat(i))
+ }
+ buffer.update(mutableAggBufferOffset + sumIndex, byteBuffer.array())
+ buffer.setInt(mutableAggBufferOffset + dimIndex, inputLen)
+ } else {
+ val currentDim = buffer.getInt(mutableAggBufferOffset + dimIndex)
+
+ // Empty array case - if current is empty and input is empty, keep empty
+ if (currentDim == 0 && inputLen == 0) {
+ return
+ }
+
+ // Dimension mismatch check
+ if (currentDim != inputLen) {
+ throw QueryExecutionErrors.vectorDimensionMismatchError(
+ prettyName,
+ currentDim,
+ inputLen
+ )
+ }
+
+ // Update sum: new_sum = old_sum + new_value
+ val currentSumBytes = buffer.getBinary(mutableAggBufferOffset + sumIndex)
+ val currentSumBuffer =
+ ByteBuffer.wrap(currentSumBytes).order(ByteOrder.LITTLE_ENDIAN)
+ val newSumBuffer =
+ ByteBuffer.allocate(currentDim * 4).order(ByteOrder.LITTLE_ENDIAN)
+ for (i <- 0 until currentDim) {
+ newSumBuffer.putFloat(
+ currentSumBuffer.getFloat() + inputArray.getFloat(i)
+ )
+ }
+ buffer.update(mutableAggBufferOffset + sumIndex, newSumBuffer.array())
+ }
+ }
+
+ override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = {
+ if (inputBuffer.isNullAt(inputAggBufferOffset + sumIndex)) {
+ return
+ }
+
+ val inputSumBytes = inputBuffer.getBinary(inputAggBufferOffset + sumIndex)
+ val inputDim = inputBuffer.getInt(inputAggBufferOffset + dimIndex)
+
+ if (buffer.isNullAt(mutableAggBufferOffset + sumIndex)) {
+ // Copy input buffer to current buffer
+ buffer.update(mutableAggBufferOffset + sumIndex, inputSumBytes.clone())
+ buffer.setInt(mutableAggBufferOffset + dimIndex, inputDim)
+ } else {
+ val currentDim = buffer.getInt(mutableAggBufferOffset + dimIndex)
+
+ // Empty array case
+ if (currentDim == 0 && inputDim == 0) {
+ return
+ }
+
+ // Dimension mismatch check
+ if (currentDim != inputDim) {
+ throw QueryExecutionErrors.vectorDimensionMismatchError(
+ prettyName,
+ currentDim,
+ inputDim
+ )
+ }
+
+ // Merge sums: combined_sum = left_sum + right_sum
+ val currentSumBytes = buffer.getBinary(mutableAggBufferOffset + sumIndex)
+ val currentSumBuffer =
+ ByteBuffer.wrap(currentSumBytes).order(ByteOrder.LITTLE_ENDIAN)
+ val inputSumBuffer =
+ ByteBuffer.wrap(inputSumBytes).order(ByteOrder.LITTLE_ENDIAN)
+ val newSumBuffer =
+ ByteBuffer.allocate(currentDim * 4).order(ByteOrder.LITTLE_ENDIAN)
+ for (_ <- 0 until currentDim) {
+ newSumBuffer.putFloat(
+ currentSumBuffer.getFloat() + inputSumBuffer.getFloat()
+ )
+ }
+ buffer.update(mutableAggBufferOffset + sumIndex, newSumBuffer.array())
+ }
+ }
+
+ override def eval(buffer: InternalRow): Any = {
+ if (buffer.isNullAt(mutableAggBufferOffset + sumIndex)) {
+ null
+ } else {
+ val dim = buffer.getInt(mutableAggBufferOffset + dimIndex)
+ val sumBytes = buffer.getBinary(mutableAggBufferOffset + sumIndex)
+ val sumBuffer = ByteBuffer.wrap(sumBytes).order(ByteOrder.LITTLE_ENDIAN)
+ val result = new Array[Float](dim)
+ for (i <- 0 until dim) {
+ result(i) = sumBuffer.getFloat()
+ }
+ ArrayData.toArrayData(result)
+ }
+ }
+
+ override protected def withNewChildInternal(newChild: Expression): VectorSum =
+ copy(child = newChild)
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/VectorAggSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/VectorAggSuite.scala
new file mode 100644
index 000000000000..0f1f71807f8d
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/VectorAggSuite.scala
@@ -0,0 +1,497 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, SpecificInternalRow, VectorAvg, VectorSum}
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+import org.apache.spark.sql.types.{ArrayType, FloatType}
+
+class VectorAggSuite extends SparkFunSuite {
+
+ // Helper to create a VectorSum instance with buffer
+ def createVectorSum(): (VectorSum, InternalRow, BoundReference) = {
+ val input = new BoundReference(0, ArrayType(FloatType), nullable = true)
+ val agg = new VectorSum(input)
+ val buffer = new SpecificInternalRow(agg.aggBufferAttributes.map(_.dataType))
+ agg.initialize(buffer)
+ (agg, buffer, input)
+ }
+
+ // Helper to create a VectorAvg instance with buffer
+ def createVectorAvg(): (VectorAvg, InternalRow, BoundReference) = {
+ val input = new BoundReference(0, ArrayType(FloatType), nullable = true)
+ val agg = new VectorAvg(input)
+ val buffer = new SpecificInternalRow(agg.aggBufferAttributes.map(_.dataType))
+ agg.initialize(buffer)
+ (agg, buffer, input)
+ }
+
+ // Helper to create input row with float array
+ def createInputRow(values: Array[Float]): InternalRow = {
+ InternalRow(ArrayData.toArrayData(values))
+ }
+
+ // Helper to create input row with null
+ def createNullInputRow(): InternalRow = {
+ val row = new SpecificInternalRow(Seq(ArrayType(FloatType)))
+ row.setNullAt(0)
+ row
+ }
+
+ // Helper to create input row with array containing null elements
+ def createInputRowWithNullElement(values: Array[java.lang.Float]): InternalRow = {
+ val arrayData = new GenericArrayData(values.map {
+ case null => null
+ case v => v.floatValue().asInstanceOf[AnyRef]
+ })
+ InternalRow(arrayData)
+ }
+
+ // Helper to extract result as float array
+ def evalAsFloatArray(agg: VectorSum, buffer: InternalRow): Array[Float] = {
+ val result = agg.eval(buffer)
+ if (result == null) return null
+ val arrayData = result.asInstanceOf[ArrayData]
+ (0 until arrayData.numElements()).map(i => arrayData.getFloat(i)).toArray
+ }
+
+ def evalAsFloatArray(agg: VectorAvg, buffer: InternalRow): Array[Float] = {
+ val result = agg.eval(buffer)
+ if (result == null) return null
+ val arrayData = result.asInstanceOf[ArrayData]
+ (0 until arrayData.numElements()).map(i => arrayData.getFloat(i)).toArray
+ }
+
+ // Asserts that two floats are approximately equal within a tolerance.
+ def assertFloatEquals(actual: Float, expected: Float, tolerance: Float = 1e-5f): Unit = {
+ assert(
+ java.lang.Float.isNaN(expected) && java.lang.Float.isNaN(actual) ||
+ math.abs(actual - expected) <= tolerance,
+ s"Expected $expected but got $actual (tolerance: $tolerance)"
+ )
+ }
+
+ // Asserts that two float arrays are approximately equal element-wise.
+ def assertFloatArrayEquals(
+ actual: Array[Float],
+ expected: Array[Float],
+ tolerance: Float = 1e-5f): Unit = {
+ assert(actual.length == expected.length,
+ s"Array lengths differ: ${actual.length} vs ${expected.length}")
+ actual.zip(expected).foreach { case (a, e) =>
+ assertFloatEquals(a, e, tolerance)
+ }
+ }
+
+ test("VectorSum - empty buffer returns null") {
+ val (agg, buffer, _) = createVectorSum()
+ assert(agg.eval(buffer) === null)
+ }
+
+ test("VectorSum - single vector") {
+ val (agg, buffer, _) = createVectorSum()
+ agg.update(buffer, createInputRow(Array(1.0f, 2.0f, 3.0f)))
+ val result = evalAsFloatArray(agg, buffer)
+ assert(result === Array(1.0f, 2.0f, 3.0f))
+ }
+
+ test("VectorSum - multiple vectors") {
+ val (agg, buffer, _) = createVectorSum()
+ agg.update(buffer, createInputRow(Array(1.0f, 2.0f)))
+ agg.update(buffer, createInputRow(Array(3.0f, 4.0f)))
+ agg.update(buffer, createInputRow(Array(5.0f, 6.0f)))
+ val result = evalAsFloatArray(agg, buffer)
+ assert(result === Array(9.0f, 12.0f))
+ }
+
+ test("VectorSum - null vectors are skipped") {
+ val (agg, buffer, _) = createVectorSum()
+ agg.update(buffer, createInputRow(Array(1.0f, 2.0f)))
+ agg.update(buffer, createNullInputRow())
+ agg.update(buffer, createInputRow(Array(3.0f, 4.0f)))
+ val result = evalAsFloatArray(agg, buffer)
+ assert(result === Array(4.0f, 6.0f))
+ }
+
+ test("VectorSum - all null vectors returns null") {
+ val (agg, buffer, _) = createVectorSum()
+ agg.update(buffer, createNullInputRow())
+ agg.update(buffer, createNullInputRow())
+ assert(agg.eval(buffer) === null)
+ }
+
+ test("VectorSum - empty vectors") {
+ val (agg, buffer, _) = createVectorSum()
+ agg.update(buffer, createInputRow(Array.empty[Float]))
+ agg.update(buffer, createInputRow(Array.empty[Float]))
+ val result = evalAsFloatArray(agg, buffer)
+ assert(result === Array.empty[Float])
+ }
+
+ test("VectorSum - merge two buffers") {
+ val (agg, buffer1, _) = createVectorSum()
+ val (_, buffer2, _) = createVectorSum()
+
+ // Partition 1: [1, 2] + [3, 4] = [4, 6]
+ agg.update(buffer1, createInputRow(Array(1.0f, 2.0f)))
+ agg.update(buffer1, createInputRow(Array(3.0f, 4.0f)))
+
+ // Partition 2: [5, 6] + [7, 8] = [12, 14]
+ agg.update(buffer2, createInputRow(Array(5.0f, 6.0f)))
+ agg.update(buffer2, createInputRow(Array(7.0f, 8.0f)))
+
+ // Merge: [4, 6] + [12, 14] = [16, 20]
+ agg.merge(buffer1, buffer2)
+
+ val result = evalAsFloatArray(agg, buffer1)
+ assert(result === Array(16.0f, 20.0f))
+ }
+
+ test("VectorSum - merge with empty buffer") {
+ val (agg, buffer1, _) = createVectorSum()
+ val (_, buffer2, _) = createVectorSum()
+
+ agg.update(buffer1, createInputRow(Array(1.0f, 2.0f)))
+ // buffer2 is empty (no updates)
+
+ agg.merge(buffer1, buffer2)
+ val result = evalAsFloatArray(agg, buffer1)
+ assert(result === Array(1.0f, 2.0f))
+ }
+
+ test("VectorSum - merge empty buffer with non-empty") {
+ val (agg, buffer1, _) = createVectorSum()
+ val (_, buffer2, _) = createVectorSum()
+
+ // buffer1 is empty
+ agg.update(buffer2, createInputRow(Array(1.0f, 2.0f)))
+
+ agg.merge(buffer1, buffer2)
+ val result = evalAsFloatArray(agg, buffer1)
+ assert(result === Array(1.0f, 2.0f))
+ }
+
+ test("VectorSum - merge two empty buffers") {
+ val (agg, buffer1, _) = createVectorSum()
+ val (_, buffer2, _) = createVectorSum()
+
+ agg.merge(buffer1, buffer2)
+ assert(agg.eval(buffer1) === null)
+ }
+
+ test("VectorSum - special float: NaN propagates") {
+ val (agg, buffer, _) = createVectorSum()
+ agg.update(buffer, createInputRow(Array(1.0f, 2.0f)))
+ agg.update(buffer, createInputRow(Array(Float.NaN, 3.0f)))
+ val result = evalAsFloatArray(agg, buffer)
+ assert(result(0).isNaN)
+ assert(result(1) === 5.0f)
+ }
+
+ test("VectorSum - special float: Infinity") {
+ val (agg, buffer, _) = createVectorSum()
+ agg.update(buffer, createInputRow(Array(1.0f, 2.0f)))
+ agg.update(buffer, createInputRow(Array(Float.PositiveInfinity, 3.0f)))
+ val result = evalAsFloatArray(agg, buffer)
+ assert(result(0) === Float.PositiveInfinity)
+ assert(result(1) === 5.0f)
+ }
+
+ test("VectorSum - special float: Negative Infinity") {
+ val (agg, buffer, _) = createVectorSum()
+ agg.update(buffer, createInputRow(Array(1.0f, 2.0f)))
+ agg.update(buffer, createInputRow(Array(Float.NegativeInfinity, 3.0f)))
+ val result = evalAsFloatArray(agg, buffer)
+ assert(result(0) === Float.NegativeInfinity)
+ assert(result(1) === 5.0f)
+ }
+
+ test("VectorAvg - empty buffer returns null") {
+ val (agg, buffer, _) = createVectorAvg()
+ assert(agg.eval(buffer) === null)
+ }
+
+ test("VectorAvg - single vector") {
+ val (agg, buffer, _) = createVectorAvg()
+ agg.update(buffer, createInputRow(Array(1.0f, 2.0f, 3.0f)))
+ val result = evalAsFloatArray(agg, buffer)
+ assert(result === Array(1.0f, 2.0f, 3.0f))
+ }
+
+ test("VectorAvg - multiple vectors") {
+ val (agg, buffer, _) = createVectorAvg()
+ agg.update(buffer, createInputRow(Array(1.0f, 1.0f)))
+ agg.update(buffer, createInputRow(Array(2.0f, 2.0f)))
+ agg.update(buffer, createInputRow(Array(3.0f, 3.0f)))
+ val result = evalAsFloatArray(agg, buffer)
+ assert(result === Array(2.0f, 2.0f))
+ }
+
+ test("VectorAvg - null vectors are skipped") {
+ val (agg, buffer, _) = createVectorAvg()
+ agg.update(buffer, createInputRow(Array(1.0f, 2.0f)))
+ agg.update(buffer, createNullInputRow())
+ agg.update(buffer, createInputRow(Array(3.0f, 4.0f)))
+ val result = evalAsFloatArray(agg, buffer)
+ // Average of [1,2] and [3,4] = [2, 3]
+ assert(result === Array(2.0f, 3.0f))
+ }
+
+ test("VectorAvg - all null vectors returns null") {
+ val (agg, buffer, _) = createVectorAvg()
+ agg.update(buffer, createNullInputRow())
+ agg.update(buffer, createNullInputRow())
+ assert(agg.eval(buffer) === null)
+ }
+
+ test("VectorAvg - empty vectors") {
+ val (agg, buffer, _) = createVectorAvg()
+ agg.update(buffer, createInputRow(Array.empty[Float]))
+ agg.update(buffer, createInputRow(Array.empty[Float]))
+ val result = evalAsFloatArray(agg, buffer)
+ assert(result === Array.empty[Float])
+ }
+
+ test("VectorAvg - merge two buffers") {
+ val (agg, buffer1, _) = createVectorAvg()
+ val (_, buffer2, _) = createVectorAvg()
+
+ // Partition 1: avg([1,2], [3,4]) = [2, 3], count=2
+ agg.update(buffer1, createInputRow(Array(1.0f, 2.0f)))
+ agg.update(buffer1, createInputRow(Array(3.0f, 4.0f)))
+
+ // Partition 2: avg([5,6], [7,8]) = [6, 7], count=2
+ agg.update(buffer2, createInputRow(Array(5.0f, 6.0f)))
+ agg.update(buffer2, createInputRow(Array(7.0f, 8.0f)))
+
+ // Merged average: (2*[2,3] + 2*[6,7]) / 4 = [4, 5]
+ agg.merge(buffer1, buffer2)
+
+ val result = evalAsFloatArray(agg, buffer1)
+ assert(result === Array(4.0f, 5.0f))
+ }
+
+ test("VectorAvg - merge with different counts") {
+ val (agg, buffer1, _) = createVectorAvg()
+ val (_, buffer2, _) = createVectorAvg()
+
+ // Partition 1: single vector [10, 20], count=1
+ agg.update(buffer1, createInputRow(Array(10.0f, 20.0f)))
+
+ // Partition 2: three vectors, avg = [2, 2], count=3
+ agg.update(buffer2, createInputRow(Array(1.0f, 1.0f)))
+ agg.update(buffer2, createInputRow(Array(2.0f, 2.0f)))
+ agg.update(buffer2, createInputRow(Array(3.0f, 3.0f)))
+
+ // Merged: (1*[10,20] + 3*[2,2]) / 4 = [16/4, 26/4] = [4, 6.5]
+ agg.merge(buffer1, buffer2)
+
+ val result = evalAsFloatArray(agg, buffer1)
+ assert(result === Array(4.0f, 6.5f))
+ }
+
+ test("VectorAvg - merge with empty buffer") {
+ val (agg, buffer1, _) = createVectorAvg()
+ val (_, buffer2, _) = createVectorAvg()
+
+ agg.update(buffer1, createInputRow(Array(1.0f, 2.0f)))
+ // buffer2 is empty
+
+ agg.merge(buffer1, buffer2)
+ val result = evalAsFloatArray(agg, buffer1)
+ assert(result === Array(1.0f, 2.0f))
+ }
+
+ test("VectorAvg - merge empty buffer with non-empty") {
+ val (agg, buffer1, _) = createVectorAvg()
+ val (_, buffer2, _) = createVectorAvg()
+
+ // buffer1 is empty
+ agg.update(buffer2, createInputRow(Array(1.0f, 2.0f)))
+
+ agg.merge(buffer1, buffer2)
+ val result = evalAsFloatArray(agg, buffer1)
+ assert(result === Array(1.0f, 2.0f))
+ }
+
+ test("VectorAvg - merge two empty buffers") {
+ val (agg, buffer1, _) = createVectorAvg()
+ val (_, buffer2, _) = createVectorAvg()
+
+ agg.merge(buffer1, buffer2)
+ assert(agg.eval(buffer1) === null)
+ }
+
+ test("VectorAvg - special float: NaN propagates") {
+ val (agg, buffer, _) = createVectorAvg()
+ agg.update(buffer, createInputRow(Array(1.0f, 2.0f)))
+ agg.update(buffer, createInputRow(Array(Float.NaN, 4.0f)))
+ val result = evalAsFloatArray(agg, buffer)
+ assert(result(0).isNaN)
+ assert(result(1) === 3.0f)
+ }
+
+ test("VectorAvg - special float: Infinity") {
+ val (agg, buffer, _) = createVectorAvg()
+ agg.update(buffer, createInputRow(Array(1.0f, 2.0f)))
+ agg.update(buffer, createInputRow(Array(Float.PositiveInfinity, 4.0f)))
+ val result = evalAsFloatArray(agg, buffer)
+ assert(result(0) === Float.PositiveInfinity)
+ assert(result(1) === 3.0f)
+ }
+
+ test("VectorAvg - numerical stability with running average") {
+ val (agg, buffer, _) = createVectorAvg()
+ // Add many vectors to test numerical stability of running average
+ for (i <- 1 to 100) {
+ agg.update(buffer, createInputRow(Array(i.toFloat, i.toFloat)))
+ }
+ val result = evalAsFloatArray(agg, buffer)
+ // Average of 1 to 100 = 50.5
+ assertFloatArrayEquals(result, Array(50.5f, 50.5f), tolerance = 1e-3f)
+ }
+
+ test("VectorSum - mathematical correctness: element-wise sum") {
+ val (agg, buffer, _) = createVectorSum()
+ agg.update(buffer, createInputRow(Array(1.0f, 2.0f, 3.0f)))
+ agg.update(buffer, createInputRow(Array(10.0f, 20.0f, 30.0f)))
+ val result = evalAsFloatArray(agg, buffer)
+ // [1, 2, 3] + [10, 20, 30] = [11, 22, 33]
+ assert(result === Array(11.0f, 22.0f, 33.0f))
+ }
+
+ test("VectorAvg - mathematical correctness: element-wise average") {
+ val (agg, buffer, _) = createVectorAvg()
+ agg.update(buffer, createInputRow(Array(0.0f, 0.0f)))
+ agg.update(buffer, createInputRow(Array(10.0f, 20.0f)))
+ val result = evalAsFloatArray(agg, buffer)
+ // avg([0, 0], [10, 20]) = [5, 10]
+ assert(result === Array(5.0f, 10.0f))
+ }
+
+ test("VectorAvg - mathematical correctness: negative values") {
+ val (agg, buffer, _) = createVectorAvg()
+ agg.update(buffer, createInputRow(Array(-5.0f, 10.0f)))
+ agg.update(buffer, createInputRow(Array(5.0f, -10.0f)))
+ val result = evalAsFloatArray(agg, buffer)
+ // avg([-5, 10], [5, -10]) = [0, 0]
+ assert(result === Array(0.0f, 0.0f))
+ }
+
+ test("VectorSum - vectors with null elements are skipped") {
+ val (agg, buffer, _) = createVectorSum()
+ agg.update(buffer, createInputRow(Array(1.0f, 2.0f)))
+ agg.update(buffer, createInputRowWithNullElement(Array(null, 10.0f)))
+ agg.update(buffer, createInputRow(Array(3.0f, 4.0f)))
+ val result = evalAsFloatArray(agg, buffer)
+ // Vector with null element is skipped, so [1, 2] + [3, 4] = [4, 6]
+ assert(result === Array(4.0f, 6.0f))
+ }
+
+ test("VectorAvg - vectors with null elements are skipped") {
+ val (agg, buffer, _) = createVectorAvg()
+ agg.update(buffer, createInputRow(Array(1.0f, 2.0f)))
+ agg.update(buffer, createInputRowWithNullElement(Array(null, 10.0f)))
+ agg.update(buffer, createInputRow(Array(3.0f, 4.0f)))
+ val result = evalAsFloatArray(agg, buffer)
+ // Vector with null element is skipped, so avg([1, 2], [3, 4]) = [2, 3]
+ assert(result === Array(2.0f, 3.0f))
+ }
+
+ test("VectorSum - only vectors with null elements returns null") {
+ val (agg, buffer, _) = createVectorSum()
+ agg.update(buffer, createInputRowWithNullElement(Array(1.0f, null)))
+ agg.update(buffer, createInputRowWithNullElement(Array(null, 2.0f)))
+ assert(agg.eval(buffer) === null)
+ }
+
+ test("VectorAvg - only vectors with null elements returns null") {
+ val (agg, buffer, _) = createVectorAvg()
+ agg.update(buffer, createInputRowWithNullElement(Array(1.0f, null)))
+ agg.update(buffer, createInputRowWithNullElement(Array(null, 2.0f)))
+ assert(agg.eval(buffer) === null)
+ }
+
+ test("VectorSum - mix of null vectors and vectors with null elements") {
+ val (agg, buffer, _) = createVectorSum()
+ agg.update(buffer, createNullInputRow())
+ agg.update(buffer, createInputRowWithNullElement(Array(1.0f, null)))
+ agg.update(buffer, createInputRow(Array(1.0f, 2.0f)))
+ agg.update(buffer, createInputRow(Array(3.0f, 4.0f)))
+ val result = evalAsFloatArray(agg, buffer)
+ // Only valid vectors are summed: [1, 2] + [3, 4] = [4, 6]
+ assert(result === Array(4.0f, 6.0f))
+ }
+
+ test("VectorSum - large vectors (16 elements)") {
+ val (agg, buffer, _) = createVectorSum()
+ val vec1 = Array(1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
+ 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f)
+ val vec2 = Array(16.0f, 15.0f, 14.0f, 13.0f, 12.0f, 11.0f, 10.0f, 9.0f,
+ 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f)
+ agg.update(buffer, createInputRow(vec1))
+ agg.update(buffer, createInputRow(vec2))
+ val result = evalAsFloatArray(agg, buffer)
+ // Each element should sum to 17
+ assert(result === Array.fill(16)(17.0f))
+ }
+
+ test("VectorAvg - large vectors (16 elements)") {
+ val (agg, buffer, _) = createVectorAvg()
+ val vec1 = Array(1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
+ 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f)
+ val vec2 = Array(16.0f, 15.0f, 14.0f, 13.0f, 12.0f, 11.0f, 10.0f, 9.0f,
+ 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f)
+ agg.update(buffer, createInputRow(vec1))
+ agg.update(buffer, createInputRow(vec2))
+ val result = evalAsFloatArray(agg, buffer)
+ // Each element average should be 8.5
+ assert(result === Array.fill(16)(8.5f))
+ }
+
+ test("VectorSum - large vector with null element is skipped") {
+ val (agg, buffer, _) = createVectorSum()
+ val vec1 = Array[java.lang.Float](1.0f, 2.0f, 3.0f, 4.0f, 5.0f, null, 7.0f, 8.0f,
+ 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f)
+ val vec2 = Array(16.0f, 15.0f, 14.0f, 13.0f, 12.0f, 11.0f, 10.0f, 9.0f,
+ 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f)
+ agg.update(buffer, createInputRowWithNullElement(vec1))
+ agg.update(buffer, createInputRow(vec2))
+ val result = evalAsFloatArray(agg, buffer)
+ // First vector is skipped due to null element, result is just vec2
+ assert(result === vec2)
+ }
+
+ test("VectorSum - single element vectors") {
+ val (agg, buffer, _) = createVectorSum()
+ agg.update(buffer, createInputRow(Array(5.0f)))
+ agg.update(buffer, createInputRow(Array(3.0f)))
+ val result = evalAsFloatArray(agg, buffer)
+ assert(result === Array(8.0f))
+ }
+
+ test("VectorAvg - single element vectors") {
+ val (agg, buffer, _) = createVectorAvg()
+ agg.update(buffer, createInputRow(Array(5.0f)))
+ agg.update(buffer, createInputRow(Array(3.0f)))
+ val result = evalAsFloatArray(agg, buffer)
+ assert(result === Array(4.0f))
+ }
+}
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 35d7c00795e1..311f3565f121 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -433,11 +433,13 @@
| org.apache.spark.sql.catalyst.expressions.UrlEncode | url_encode | SELECT url_encode('https://spark.apache.org') | struct<url_encode(https://spark.apache.org):string> |
| org.apache.spark.sql.catalyst.expressions.Uuid | uuid | SELECT uuid() | struct<uuid():string> |
| org.apache.spark.sql.catalyst.expressions.ValidateUTF8 | validate_utf8 | SELECT validate_utf8('Spark') | struct<validate_utf8(Spark):string> |
+| org.apache.spark.sql.catalyst.expressions.VectorAvg | vector_avg | SELECT vector_avg(col) FROM VALUES (array(1.0F, 2.0F)), (array(3.0F, 4.0F)) AS tab(col) | struct<vector_avg(col):array<float>> |
| org.apache.spark.sql.catalyst.expressions.VectorCosineSimilarity | vector_cosine_similarity | SELECT vector_cosine_similarity(array(1.0F, 2.0F, 3.0F), array(4.0F, 5.0F, 6.0F)) | struct<vector_cosine_similarity(array(1.0, 2.0, 3.0), array(4.0, 5.0, 6.0)):float> |
| org.apache.spark.sql.catalyst.expressions.VectorInnerProduct | vector_inner_product | SELECT vector_inner_product(array(1.0F, 2.0F, 3.0F), array(4.0F, 5.0F, 6.0F)) | struct<vector_inner_product(array(1.0, 2.0, 3.0), array(4.0, 5.0, 6.0)):float> |
| org.apache.spark.sql.catalyst.expressions.VectorL2Distance | vector_l2_distance | SELECT vector_l2_distance(array(1.0F, 2.0F, 3.0F), array(4.0F, 5.0F, 6.0F)) | struct<vector_l2_distance(array(1.0, 2.0, 3.0), array(4.0, 5.0, 6.0)):float> |
| org.apache.spark.sql.catalyst.expressions.VectorNorm | vector_norm | SELECT vector_norm(array(3.0F, 4.0F), 2.0F) | struct<vector_norm(array(3.0, 4.0), 2.0):float> |
| org.apache.spark.sql.catalyst.expressions.VectorNormalize | vector_normalize | SELECT vector_normalize(array(3.0F, 4.0F), 2.0F) | struct<vector_normalize(array(3.0, 4.0), 2.0):array<float>> |
+| org.apache.spark.sql.catalyst.expressions.VectorSum | vector_sum | SELECT vector_sum(col) FROM VALUES (array(1.0F, 2.0F)), (array(3.0F, 4.0F)) AS tab(col) | struct<vector_sum(col):array<float>> |
| org.apache.spark.sql.catalyst.expressions.WeekDay | weekday | SELECT weekday('2009-07-30') | struct<weekday(2009-07-30):int> |
| org.apache.spark.sql.catalyst.expressions.WeekOfYear | weekofyear | SELECT weekofyear('2008-02-20') | struct<weekofyear(2008-02-20):int> |
| org.apache.spark.sql.catalyst.expressions.WidthBucket | width_bucket | SELECT width_bucket(5.3, 0.2, 10.6, 5) | struct<width_bucket(5.3, 0.2, 10.6, 5):bigint> |
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/vector-agg.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/vector-agg.sql.out
new file mode 100644
index 000000000000..1df64545f55c
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/vector-agg.sql.out
@@ -0,0 +1,243 @@
+-- Automatically generated by SQLQueryTestSuite
+-- !query
+SELECT vector_sum(col) FROM VALUES (array(1.0F, 2.0F, 3.0F)), (array(4.0F, 5.0F, 6.0F)) AS tab(col)
+-- !query analysis
+Aggregate [vector_sum(col#x, 0, 0) AS vector_sum(col)#x]
++- SubqueryAlias tab
+ +- LocalRelation [col#x]
+
+
+-- !query
+SELECT vector_avg(col) FROM VALUES (array(1.0F, 2.0F, 3.0F)), (array(3.0F, 4.0F, 5.0F)) AS tab(col)
+-- !query analysis
+Aggregate [vector_avg(col#x, 0, 0) AS vector_avg(col)#x]
++- SubqueryAlias tab
+ +- LocalRelation [col#x]
+
+
+-- !query
+SELECT category, vector_sum(embedding) AS sum_vector
+FROM VALUES
+ ('A', array(1.0F, 2.0F, 3.0F)),
+ ('A', array(4.0F, 5.0F, 6.0F)),
+ ('B', array(2.0F, 1.0F, 4.0F)),
+ ('B', array(3.0F, 2.0F, 1.0F))
+AS tab(category, embedding)
+GROUP BY category
+ORDER BY category
+-- !query analysis
+Sort [category#x ASC NULLS FIRST], true
++- Aggregate [category#x], [category#x, vector_sum(embedding#x, 0, 0) AS sum_vector#x]
+ +- SubqueryAlias tab
+ +- LocalRelation [category#x, embedding#x]
+
+
+-- !query
+SELECT category, vector_avg(embedding) AS avg_vector
+FROM VALUES
+ ('A', array(2.0F, 4.0F, 6.0F)),
+ ('A', array(4.0F, 8.0F, 12.0F)),
+ ('B', array(1.0F, 2.0F, 3.0F)),
+ ('B', array(3.0F, 6.0F, 9.0F))
+AS tab(category, embedding)
+GROUP BY category
+ORDER BY category
+-- !query analysis
+Sort [category#x ASC NULLS FIRST], true
++- Aggregate [category#x], [category#x, vector_avg(embedding#x, 0, 0) AS avg_vector#x]
+ +- SubqueryAlias tab
+ +- LocalRelation [category#x, embedding#x]
+
+
+-- !query
+SELECT vector_sum(col) FROM VALUES
+ (array(1.0F, 2.0F)),
+ (CAST(NULL AS ARRAY<FLOAT>)),
+ (array(3.0F, 4.0F)) AS tab(col)
+-- !query analysis
+Aggregate [vector_sum(col#x, 0, 0) AS vector_sum(col)#x]
++- SubqueryAlias tab
+ +- LocalRelation [col#x]
+
+
+-- !query
+SELECT vector_avg(col) FROM VALUES
+ (array(1.0F, 2.0F)),
+ (CAST(NULL AS ARRAY<FLOAT>)),
+ (array(3.0F, 4.0F)) AS tab(col)
+-- !query analysis
+Aggregate [vector_avg(col#x, 0, 0) AS vector_avg(col)#x]
++- SubqueryAlias tab
+ +- LocalRelation [col#x]
+
+
+-- !query
+SELECT vector_sum(col) FROM VALUES
+ (CAST(NULL AS ARRAY<FLOAT>)),
+ (CAST(NULL AS ARRAY<FLOAT>)) AS tab(col)
+-- !query analysis
+Aggregate [vector_sum(col#x, 0, 0) AS vector_sum(col)#x]
++- SubqueryAlias tab
+ +- LocalRelation [col#x]
+
+
+-- !query
+SELECT vector_sum(col) FROM VALUES
+ (CAST(array() AS ARRAY<FLOAT>)),
+ (CAST(array() AS ARRAY<FLOAT>)) AS tab(col)
+-- !query analysis
+Aggregate [vector_sum(col#x, 0, 0) AS vector_sum(col)#x]
++- SubqueryAlias tab
+ +- LocalRelation [col#x]
+
+
+-- !query
+SELECT key, vector_sum(vec) OVER (PARTITION BY grp) as sum_vec
+FROM VALUES
+ ('a', 'g1', array(1.0F, 2.0F)),
+ ('b', 'g1', array(3.0F, 4.0F)),
+ ('c', 'g2', array(5.0F, 6.0F))
+AS tab(key, grp, vec)
+ORDER BY key
+-- !query analysis
+Sort [key#x ASC NULLS FIRST], true
++- Project [key#x, sum_vec#x]
+ +- Project [key#x, vec#x, grp#x, sum_vec#x, sum_vec#x]
+ +- Window [vector_sum(vec#x, 0, 0) windowspecdefinition(grp#x, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS sum_vec#x], [grp#x]
+ +- Project [key#x, vec#x, grp#x]
+ +- SubqueryAlias tab
+ +- LocalRelation [key#x, grp#x, vec#x]
+
+
+-- !query
+SELECT key, vector_avg(vec) OVER (PARTITION BY grp) as avg_vec
+FROM VALUES
+ ('a', 'g1', array(1.0F, 2.0F)),
+ ('b', 'g1', array(3.0F, 4.0F)),
+ ('c', 'g2', array(5.0F, 6.0F))
+AS tab(key, grp, vec)
+ORDER BY key
+-- !query analysis
+Sort [key#x ASC NULLS FIRST], true
++- Project [key#x, avg_vec#x]
+ +- Project [key#x, vec#x, grp#x, avg_vec#x, avg_vec#x]
+ +- Window [vector_avg(vec#x, 0, 0) windowspecdefinition(grp#x, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS avg_vec#x], [grp#x]
+ +- Project [key#x, vec#x, grp#x]
+ +- SubqueryAlias tab
+ +- LocalRelation [key#x, grp#x, vec#x]
+
+
+-- !query
+SELECT vector_sum(col) FROM VALUES
+ (array(1.0F, 2.0F, 3.0F)),
+ (array(1.0F, 2.0F)) AS tab(col)
+-- !query analysis
+Aggregate [vector_sum(col#x, 0, 0) AS vector_sum(col)#x]
++- SubqueryAlias tab
+ +- LocalRelation [col#x]
+
+
+-- !query
+SELECT vector_avg(col) FROM VALUES
+ (array(1.0F, 2.0F, 3.0F)),
+ (array(1.0F, 2.0F)) AS tab(col)
+-- !query analysis
+Aggregate [vector_avg(col#x, 0, 0) AS vector_avg(col)#x]
++- SubqueryAlias tab
+ +- LocalRelation [col#x]
+
+
+-- !query
+SELECT vector_sum(col) FROM VALUES (array(1.0D, 2.0D)), (array(3.0D, 4.0D)) AS tab(col)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col\"",
+ "inputType" : "\"ARRAY<DOUBLE>\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"ARRAY<FLOAT>\"",
+ "sqlExpr" : "\"vector_sum(col)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 22,
+ "fragment" : "vector_sum(col)"
+ } ]
+}
+
+
+-- !query
+SELECT vector_avg(col) FROM VALUES (array(1, 2)), (array(3, 4)) AS tab(col)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col\"",
+ "inputType" : "\"ARRAY<INT>\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"ARRAY<FLOAT>\"",
+ "sqlExpr" : "\"vector_avg(col)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 22,
+ "fragment" : "vector_avg(col)"
+ } ]
+}
+
+
+-- !query
+SELECT vector_sum(col) FROM VALUES ('not an array'), ('also not an array') AS tab(col)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col\"",
+ "inputType" : "\"STRING\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"ARRAY<FLOAT>\"",
+ "sqlExpr" : "\"vector_sum(col)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 22,
+ "fragment" : "vector_sum(col)"
+ } ]
+}
+
+
+-- !query
+SELECT vector_avg(col) FROM VALUES (1.0F), (2.0F) AS tab(col)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col\"",
+ "inputType" : "\"FLOAT\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"ARRAY<FLOAT>\"",
+ "sqlExpr" : "\"vector_avg(col)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 22,
+ "fragment" : "vector_avg(col)"
+ } ]
+}
diff --git a/sql/core/src/test/resources/sql-tests/inputs/vector-agg.sql b/sql/core/src/test/resources/sql-tests/inputs/vector-agg.sql
new file mode 100644
index 000000000000..50d532f6ff2b
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/vector-agg.sql
@@ -0,0 +1,81 @@
+-- Tests for vector aggregation functions: vector_sum, vector_avg
+
+-- Basic functionality
+SELECT vector_sum(col) FROM VALUES (array(1.0F, 2.0F, 3.0F)), (array(4.0F, 5.0F, 6.0F)) AS tab(col);
+SELECT vector_avg(col) FROM VALUES (array(1.0F, 2.0F, 3.0F)), (array(3.0F, 4.0F, 5.0F)) AS tab(col);
+
+-- GROUP BY aggregation
+SELECT category, vector_sum(embedding) AS sum_vector
+FROM VALUES
+ ('A', array(1.0F, 2.0F, 3.0F)),
+ ('A', array(4.0F, 5.0F, 6.0F)),
+ ('B', array(2.0F, 1.0F, 4.0F)),
+ ('B', array(3.0F, 2.0F, 1.0F))
+AS tab(category, embedding)
+GROUP BY category
+ORDER BY category;
+
+SELECT category, vector_avg(embedding) AS avg_vector
+FROM VALUES
+ ('A', array(2.0F, 4.0F, 6.0F)),
+ ('A', array(4.0F, 8.0F, 12.0F)),
+ ('B', array(1.0F, 2.0F, 3.0F)),
+ ('B', array(3.0F, 6.0F, 9.0F))
+AS tab(category, embedding)
+GROUP BY category
+ORDER BY category;
+
+-- NULL vector handling
+SELECT vector_sum(col) FROM VALUES
+ (array(1.0F, 2.0F)),
+ (CAST(NULL AS ARRAY<FLOAT>)),
+ (array(3.0F, 4.0F)) AS tab(col);
+
+SELECT vector_avg(col) FROM VALUES
+ (array(1.0F, 2.0F)),
+ (CAST(NULL AS ARRAY<FLOAT>)),
+ (array(3.0F, 4.0F)) AS tab(col);
+
+-- All NULL vectors returns NULL
+SELECT vector_sum(col) FROM VALUES
+ (CAST(NULL AS ARRAY<FLOAT>)),
+ (CAST(NULL AS ARRAY<FLOAT>)) AS tab(col);
+
+-- Empty vectors
+SELECT vector_sum(col) FROM VALUES
+ (CAST(array() AS ARRAY<FLOAT>)),
+ (CAST(array() AS ARRAY<FLOAT>)) AS tab(col);
+
+-- Window function
+SELECT key, vector_sum(vec) OVER (PARTITION BY grp) as sum_vec
+FROM VALUES
+ ('a', 'g1', array(1.0F, 2.0F)),
+ ('b', 'g1', array(3.0F, 4.0F)),
+ ('c', 'g2', array(5.0F, 6.0F))
+AS tab(key, grp, vec)
+ORDER BY key;
+
+SELECT key, vector_avg(vec) OVER (PARTITION BY grp) as avg_vec
+FROM VALUES
+ ('a', 'g1', array(1.0F, 2.0F)),
+ ('b', 'g1', array(3.0F, 4.0F)),
+ ('c', 'g2', array(5.0F, 6.0F))
+AS tab(key, grp, vec)
+ORDER BY key;
+
+-- Error cases: dimension mismatch
+SELECT vector_sum(col) FROM VALUES
+ (array(1.0F, 2.0F, 3.0F)),
+ (array(1.0F, 2.0F)) AS tab(col);
+
+SELECT vector_avg(col) FROM VALUES
+ (array(1.0F, 2.0F, 3.0F)),
+ (array(1.0F, 2.0F)) AS tab(col);
+
+-- Error cases: wrong element type (only ARRAY<FLOAT> accepted)
+SELECT vector_sum(col) FROM VALUES (array(1.0D, 2.0D)), (array(3.0D, 4.0D)) AS tab(col);
+SELECT vector_avg(col) FROM VALUES (array(1, 2)), (array(3, 4)) AS tab(col);
+
+-- Error cases: non-array argument
+SELECT vector_sum(col) FROM VALUES ('not an array'), ('also not an array') AS tab(col);
+SELECT vector_avg(col) FROM VALUES (1.0F), (2.0F) AS tab(col);
diff --git a/sql/core/src/test/resources/sql-tests/results/vector-agg.sql.out b/sql/core/src/test/resources/sql-tests/results/vector-agg.sql.out
new file mode 100644
index 000000000000..f87757d852ec
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/vector-agg.sql.out
@@ -0,0 +1,265 @@
+-- Automatically generated by SQLQueryTestSuite
+-- !query
+SELECT vector_sum(col) FROM VALUES (array(1.0F, 2.0F, 3.0F)), (array(4.0F, 5.0F, 6.0F)) AS tab(col)
+-- !query schema
+struct<vector_sum(col):array<float>>
+-- !query output
+[5.0,7.0,9.0]
+
+
+-- !query
+SELECT vector_avg(col) FROM VALUES (array(1.0F, 2.0F, 3.0F)), (array(3.0F, 4.0F, 5.0F)) AS tab(col)
+-- !query schema
+struct<vector_avg(col):array<float>>
+-- !query output
+[2.0,3.0,4.0]
+
+
+-- !query
+SELECT category, vector_sum(embedding) AS sum_vector
+FROM VALUES
+ ('A', array(1.0F, 2.0F, 3.0F)),
+ ('A', array(4.0F, 5.0F, 6.0F)),
+ ('B', array(2.0F, 1.0F, 4.0F)),
+ ('B', array(3.0F, 2.0F, 1.0F))
+AS tab(category, embedding)
+GROUP BY category
+ORDER BY category
+-- !query schema
+struct<category:string,sum_vector:array<float>>
+-- !query output
+A [5.0,7.0,9.0]
+B [5.0,3.0,5.0]
+
+
+-- !query
+SELECT category, vector_avg(embedding) AS avg_vector
+FROM VALUES
+ ('A', array(2.0F, 4.0F, 6.0F)),
+ ('A', array(4.0F, 8.0F, 12.0F)),
+ ('B', array(1.0F, 2.0F, 3.0F)),
+ ('B', array(3.0F, 6.0F, 9.0F))
+AS tab(category, embedding)
+GROUP BY category
+ORDER BY category
+-- !query schema
+struct<category:string,avg_vector:array<float>>
+-- !query output
+A [3.0,6.0,9.0]
+B [2.0,4.0,6.0]
+
+
+-- !query
+SELECT vector_sum(col) FROM VALUES
+ (array(1.0F, 2.0F)),
+ (CAST(NULL AS ARRAY<FLOAT>)),
+ (array(3.0F, 4.0F)) AS tab(col)
+-- !query schema
+struct<vector_sum(col):array<float>>
+-- !query output
+[4.0,6.0]
+
+
+-- !query
+SELECT vector_avg(col) FROM VALUES
+ (array(1.0F, 2.0F)),
+ (CAST(NULL AS ARRAY<FLOAT>)),
+ (array(3.0F, 4.0F)) AS tab(col)
+-- !query schema
+struct<vector_avg(col):array<float>>
+-- !query output
+[2.0,3.0]
+
+
+-- !query
+SELECT vector_sum(col) FROM VALUES
+ (CAST(NULL AS ARRAY<FLOAT>)),
+ (CAST(NULL AS ARRAY<FLOAT>)) AS tab(col)
+-- !query schema
+struct<vector_sum(col):array<float>>
+-- !query output
+NULL
+
+
+-- !query
+SELECT vector_sum(col) FROM VALUES
+ (CAST(array() AS ARRAY<FLOAT>)),
+ (CAST(array() AS ARRAY<FLOAT>)) AS tab(col)
+-- !query schema
+struct<vector_sum(col):array<float>>
+-- !query output
+[]
+
+
+-- !query
+SELECT key, vector_sum(vec) OVER (PARTITION BY grp) as sum_vec
+FROM VALUES
+ ('a', 'g1', array(1.0F, 2.0F)),
+ ('b', 'g1', array(3.0F, 4.0F)),
+ ('c', 'g2', array(5.0F, 6.0F))
+AS tab(key, grp, vec)
+ORDER BY key
+-- !query schema
+struct<key:string,sum_vec:array<float>>
+-- !query output
+a [4.0,6.0]
+b [4.0,6.0]
+c [5.0,6.0]
+
+
+-- !query
+SELECT key, vector_avg(vec) OVER (PARTITION BY grp) as avg_vec
+FROM VALUES
+ ('a', 'g1', array(1.0F, 2.0F)),
+ ('b', 'g1', array(3.0F, 4.0F)),
+ ('c', 'g2', array(5.0F, 6.0F))
+AS tab(key, grp, vec)
+ORDER BY key
+-- !query schema
+struct<key:string,avg_vec:array<float>>
+-- !query output
+a [2.0,3.0]
+b [2.0,3.0]
+c [5.0,6.0]
+
+
+-- !query
+SELECT vector_sum(col) FROM VALUES
+ (array(1.0F, 2.0F, 3.0F)),
+ (array(1.0F, 2.0F)) AS tab(col)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkRuntimeException
+{
+ "errorClass" : "VECTOR_DIMENSION_MISMATCH",
+ "sqlState" : "22000",
+ "messageParameters" : {
+ "functionName" : "`vector_sum`",
+ "leftDim" : "3",
+ "rightDim" : "2"
+ }
+}
+
+
+-- !query
+SELECT vector_avg(col) FROM VALUES
+ (array(1.0F, 2.0F, 3.0F)),
+ (array(1.0F, 2.0F)) AS tab(col)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkRuntimeException
+{
+ "errorClass" : "VECTOR_DIMENSION_MISMATCH",
+ "sqlState" : "22000",
+ "messageParameters" : {
+ "functionName" : "`vector_avg`",
+ "leftDim" : "3",
+ "rightDim" : "2"
+ }
+}
+
+
+-- !query
+SELECT vector_sum(col) FROM VALUES (array(1.0D, 2.0D)), (array(3.0D, 4.0D)) AS tab(col)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col\"",
+ "inputType" : "\"ARRAY<DOUBLE>\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"ARRAY<FLOAT>\"",
+ "sqlExpr" : "\"vector_sum(col)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 22,
+ "fragment" : "vector_sum(col)"
+ } ]
+}
+
+
+-- !query
+SELECT vector_avg(col) FROM VALUES (array(1, 2)), (array(3, 4)) AS tab(col)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col\"",
+ "inputType" : "\"ARRAY<INT>\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"ARRAY<FLOAT>\"",
+ "sqlExpr" : "\"vector_avg(col)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 22,
+ "fragment" : "vector_avg(col)"
+ } ]
+}
+
+
+-- !query
+SELECT vector_sum(col) FROM VALUES ('not an array'), ('also not an array') AS tab(col)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col\"",
+ "inputType" : "\"STRING\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"ARRAY<FLOAT>\"",
+ "sqlExpr" : "\"vector_sum(col)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 22,
+ "fragment" : "vector_sum(col)"
+ } ]
+}
+
+
+-- !query
+SELECT vector_avg(col) FROM VALUES (1.0F), (2.0F) AS tab(col)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col\"",
+ "inputType" : "\"FLOAT\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"ARRAY<FLOAT>\"",
+ "sqlExpr" : "\"vector_avg(col)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 22,
+ "fragment" : "vector_avg(col)"
+ } ]
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]