This is an automated email from the ASF dual-hosted git repository.
corgy pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git
The following commit(s) were added to refs/heads/dev by this push:
new c115613de1 [Feature][Transform V2] Add vector dimension reduction
transform (#9783)
c115613de1 is described below
commit c115613de1c62af680a43363a766f72165be9bfc
Author: CosmosNi <[email protected]>
AuthorDate: Tue Sep 2 15:25:42 2025 +0800
[Feature][Transform V2] Add vector dimension reduction transform (#9783)
---
docs/en/transform-v2/sql-functions.md | 42 +++-
docs/zh/transform-v2/sql-functions.md | 41 +++-
.../apache/seatunnel/e2e/transform/TestSQLIT.java | 12 ++
.../test/resources/sql_transform/func_vector.conf | 142 +++++++++++++
.../transform/sql/zeta/ZetaSQLFunction.java | 8 +
.../seatunnel/transform/sql/zeta/ZetaSQLType.java | 5 +
.../sql/zeta/functions/VectorFunction.java | 156 ++++++++++++++
.../transform/sql/SQLVectorFunctionTest.java | 235 +++++++++++++++++++++
8 files changed, 639 insertions(+), 2 deletions(-)
diff --git a/docs/en/transform-v2/sql-functions.md
b/docs/en/transform-v2/sql-functions.md
index 7fcaf0344c..857932c2bd 100644
--- a/docs/en/transform-v2/sql-functions.md
+++ b/docs/en/transform-v2/sql-functions.md
@@ -1221,4 +1221,44 @@ Calculates the Euclidean (L2) distance between two
vectors.
Example:
-L2_DISTANCE(vector1, vector2)
\ No newline at end of file
+L2_DISTANCE(vector1, vector2)
+
+### VECTOR_REDUCE
+
+```VECTOR_REDUCE(vector_field, target_dimension, method)```
+
+Generic vector dimension reduction function that supports multiple reduction
methods.
+
+**Parameters:**
+- `vector_field`: The vector field to reduce (VECTOR type)
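+- `target_dimension`: The target dimension (INTEGER). It should be smaller than the source dimension; a vector whose dimension is already less than or equal to the target is returned unchanged.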
+- `method`: The reduction method (STRING):
+ - **'TRUNCATE'**: Truncates the vector by keeping only the first N elements.
This is the simplest and fastest dimension reduction method, but may lose
important information in the truncated dimensions.
+ - **'RANDOM_PROJECTION'**: Uses Gaussian random projection with normally
distributed random matrix. This method preserves relative distances between
vectors while reducing dimensionality, following the Johnson-Lindenstrauss
lemma.
+ - **'SPARSE_RANDOM_PROJECTION'**: Uses sparse random projection, where each matrix element is +√3 or −√3 with probability 1/6 each and 0 with probability 2/3, so most elements are zero. This is more computationally efficient than regular random projection while maintaining similar distance preservation properties.
+
+**Returns:** VECTOR type with reduced dimensions
+
+**Example:**
+```sql
+SELECT id, VECTOR_REDUCE(embedding, 256, 'TRUNCATE') as reduced_embedding FROM
table
+SELECT id, VECTOR_REDUCE(embedding, 128, 'RANDOM_PROJECTION') as
reduced_embedding FROM table
+SELECT id, VECTOR_REDUCE(embedding, 64, 'SPARSE_RANDOM_PROJECTION') as
reduced_embedding FROM table
+```
+
+### VECTOR_NORMALIZE
+
+```VECTOR_NORMALIZE(vector_field)```
+
+Normalizes a vector to unit length (magnitude = 1). This is useful for computing cosine similarity. A zero vector cannot be normalized and is returned unchanged.
+
+**Parameters:**
+- `vector_field`: The vector field to normalize (VECTOR type)
+
+**Returns:** VECTOR type - the normalized vector
+
+**Example:**
+```sql
+SELECT id, VECTOR_NORMALIZE(embedding) as normalized_embedding FROM table
+```
+
diff --git a/docs/zh/transform-v2/sql-functions.md
b/docs/zh/transform-v2/sql-functions.md
index ad47beeb4a..188d12f665 100644
--- a/docs/zh/transform-v2/sql-functions.md
+++ b/docs/zh/transform-v2/sql-functions.md
@@ -1215,4 +1215,43 @@ L1_DISTANCE(vector1, vector2)
示例:
-L2_DISTANCE(vector1, vector2)
\ No newline at end of file
+L2_DISTANCE(vector1, vector2)
+
+### VECTOR_REDUCE
+
+```VECTOR_REDUCE(vector_field, target_dimension, method)```
+
+通用向量降维函数,支持多种降维方法。
+
+**参数:**
+- `vector_field`: 要降维的向量字段 (VECTOR 类型)
+- `target_dimension`: 目标维度 (INTEGER,应小于源维度;若源向量维度小于或等于目标维度,则原样返回该向量)
+- `method`: 降维方法 (STRING):
+ - **'TRUNCATE'**: 截断法,通过保留前N个元素来缩减向量维度。这是最简单、最快速的降维方法,但可能会丢失被截断维度中的重要信息。
+ - **'RANDOM_PROJECTION'**:
随机投影法,使用高斯随机投影和正态分布的随机矩阵。该方法在降维的同时保持向量间的相对距离,遵循Johnson-Lindenstrauss引理。
+ - **'SPARSE_RANDOM_PROJECTION'**: 稀疏随机投影法,矩阵元素以各 1/6 的概率取 +√3 或 −√3,以 2/3 的概率取 0,因此大部分元素为零。比常规随机投影在计算上更高效,同时保持相似的距离保持特性。
+
+**返回值:** 降维后的 VECTOR 类型
+
+**示例:**
+```sql
+SELECT id, VECTOR_REDUCE(embedding, 256, 'TRUNCATE') as reduced_embedding FROM
table
+SELECT id, VECTOR_REDUCE(embedding, 128, 'RANDOM_PROJECTION') as
reduced_embedding FROM table
+SELECT id, VECTOR_REDUCE(embedding, 64, 'SPARSE_RANDOM_PROJECTION') as
reduced_embedding FROM table
+```
+
+### VECTOR_NORMALIZE
+
+```VECTOR_NORMALIZE(vector_field)```
+
+将向量归一化为单位长度(模长 = 1)。这对于计算余弦相似度很有用。零向量无法归一化,将原样返回。
+
+**参数:**
+- `vector_field`: 要归一化的向量字段 (VECTOR 类型)
+
+**返回值:** VECTOR 类型 - 归一化后的向量
+
+**示例:**
+```sql
+SELECT id, VECTOR_NORMALIZE(embedding) as normalized_embedding FROM table
+```
\ No newline at end of file
diff --git
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/java/org/apache/seatunnel/e2e/transform/TestSQLIT.java
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/java/org/apache/seatunnel/e2e/transform/TestSQLIT.java
index cb588d0aef..6d9dd11856 100644
---
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/java/org/apache/seatunnel/e2e/transform/TestSQLIT.java
+++
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/java/org/apache/seatunnel/e2e/transform/TestSQLIT.java
@@ -86,6 +86,18 @@ public class TestSQLIT extends TestSuiteBase {
Assertions.assertEquals(0, multiIfSql.getExitCode());
}
+    /**
+     * Runs the vector-function SQL job and expects it to exit with code 0. Disabled on the Spark
+     * engine, where vector functions are not supported.
+     */
+    @TestTemplate
+    @DisabledOnContainer(
+            value = {},
+            type = {EngineType.SPARK},
+            disabledReason = "Vector functions are not supported in Spark engine")
+    public void testVectorFunctions(TestContainer container)
+            throws IOException, InterruptedException {
+        Container.ExecResult execResult = container.executeJob("/sql_transform/func_vector.conf");
+        int exitCode = execResult.getExitCode();
+        Assertions.assertEquals(0, exitCode);
+    }
+
@TestTemplate
public void testSQLTransformMultiTable(TestContainer container)
throws IOException, InterruptedException {
diff --git
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/func_vector.conf
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/func_vector.conf
new file mode 100644
index 0000000000..daaaa3a8d5
--- /dev/null
+++
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/func_vector.conf
@@ -0,0 +1,142 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+######
+###### This config file is a demonstration of vector functions in SQL transform
+######
+
+env {
+ parallelism = 1
+ job.mode = "BATCH"
+ checkpoint.interval = 10000
+}
+
+source {
+ FakeSource {
+ plugin_output = "fake"
+ schema = {
+ fields {
+ id = "int"
+ name = "string"
+ vector_field = "array<float>"
+ vector_field2 = "array<float>"
+ }
+ }
+ rows = [
+ {
+ fields = [1, "test1", [1.0, 2.0, 3.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0,
5.0]]
+ kind = INSERT
+ },
+ {
+ fields = [2, "test2", [2.0, 4.0, 6.0, 8.0, 10.0], [0.6, 0.8, 0.0, 0.0,
0.0]]
+ kind = INSERT
+ },
+ {
+ fields = [3, "test3", [3.0, 4.0, 0.0, 0.0, 0.0], [3.0, 4.0, 0.0, 0.0,
0.0]]
+ kind = INSERT
+ }
+ ]
+ }
+}
+
+transform {
+ Sql {
+ plugin_input = "fake"
+ plugin_output = "fake1"
+ query = """SELECT
+ id,
+ name,
+ VECTOR_DIMS(vector_field) as original_dim,
+ VECTOR_DIMS(VECTOR_REDUCE(vector_field, 3, 'TRUNCATE')) as truncated_dim,
+ VECTOR_DIMS(VECTOR_REDUCE(vector_field, 3, 'RANDOM_PROJECTION')) as
projected_dim,
+ VECTOR_DIMS(VECTOR_REDUCE(vector_field, 3, 'SPARSE_RANDOM_PROJECTION'))
as sparse_projected_dim,
+ VECTOR_DIMS(VECTOR_NORMALIZE(vector_field)) as normalized_dim
+ FROM dual"""
+ }
+}
+
+sink {
+ Assert {
+ plugin_input = "fake1"
+ rules = {
+ field_rules = [
+ {
+ field_name = "id"
+ field_type = "int"
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ },
+ {
+ field_name = "name"
+ field_type = "string"
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ },
+ {
+ field_name = "original_dim"
+ field_type = "int"
+ field_value = [
+ {equals_to = 5}
+ ]
+ },
+ {
+ field_name = "truncated_dim"
+ field_type = "int"
+ field_value = [
+ {equals_to = 3}
+ ]
+ },
+ {
+ field_name = "projected_dim"
+ field_type = "int"
+ field_value = [
+ {equals_to = 3}
+ ]
+ },
+ {
+ field_name = "sparse_projected_dim"
+ field_type = "int"
+ field_value = [
+ {equals_to = 3}
+ ]
+ },
+ {
+ field_name = "normalized_dim"
+ field_type = "int"
+ field_value = [
+ {equals_to = 5}
+ ]
+ }
+ ]
+ row_rules = [
+ {
+ rule_type = MAX_ROW
+ rule_value = 3
+ },
+ {
+ rule_type = MIN_ROW
+ rule_value = 3
+ }
+ ]
+ }
+ }
+}
\ No newline at end of file
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFunction.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFunction.java
index fef526d0d7..86eaa9ccad 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFunction.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFunction.java
@@ -212,6 +212,9 @@ public class ZetaSQLFunction {
public static final String VECTOR_NORM = "VECTOR_NORM";
public static final String INNER_PRODUCT = "INNER_PRODUCT";
+ public static final String VECTOR_REDUCE = "VECTOR_REDUCE";
+ public static final String VECTOR_NORMALIZE = "VECTOR_NORMALIZE";
+
private final SeaTunnelRowType inputRowType;
private final ZetaSQLType zetaSQLType;
@@ -619,6 +622,11 @@ public class ZetaSQLFunction {
return VectorFunction.vectorNorm(args);
case INNER_PRODUCT:
return VectorFunction.innerProduct(args);
+ case VECTOR_REDUCE:
+ return VectorFunction.vectorReduce(
+ args.get(0), (Integer) args.get(1), (String)
args.get(2));
+ case VECTOR_NORMALIZE:
+ return VectorFunction.vectorNormalize(args.get(0));
default:
for (ZetaUDF udf : udfList) {
if (udf.functionName().equalsIgnoreCase(functionName)) {
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLType.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLType.java
index 067fab2481..83c9550bed 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLType.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLType.java
@@ -25,6 +25,7 @@ import org.apache.seatunnel.api.table.type.MapType;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.api.table.type.SqlType;
+import org.apache.seatunnel.api.table.type.VectorType;
import org.apache.seatunnel.common.exception.CommonErrorCodeDeprecated;
import org.apache.seatunnel.transform.exception.TransformException;
import org.apache.seatunnel.transform.sql.zeta.functions.ArrayFunction;
@@ -489,6 +490,10 @@ public class ZetaSQLType {
case ZetaSQLFunction.MOD:
// Result has the same type as second argument
return
getExpressionType(function.getParameters().getExpressions().get(1));
+ // Vector functions
+ case ZetaSQLFunction.VECTOR_REDUCE:
+ case ZetaSQLFunction.VECTOR_NORMALIZE:
+ return VectorType.VECTOR_FLOAT_TYPE;
default:
for (ZetaUDF udf : udfList) {
if
(udf.functionName().equalsIgnoreCase(function.getName())) {
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/functions/VectorFunction.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/functions/VectorFunction.java
index 7b37acdfbd..e10688702d 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/functions/VectorFunction.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/functions/VectorFunction.java
@@ -25,9 +25,11 @@ import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
+import java.util.Random;
import java.util.stream.IntStream;
public class VectorFunction {
+ private static final Random random = new Random(42);
public static Object cosineDistance(List<Object> args) {
if (args.size() != 2) {
@@ -199,4 +201,158 @@ public class VectorFunction {
String.format("Unsupported vector type: %s",
obj.getClass().getName()));
}
}
+
+    /**
+     * Truncates a vector to its first {@code targetDimension} elements. Usage:
+     * VECTOR_REDUCE(embedding, 256, 'TRUNCATE')
+     *
+     * @param vectorData the source vector, in any representation accepted by {@code
+     *     convertToFloatArray}; may be null
+     * @param targetDimension number of leading elements to keep; may be null
+     * @return null if either argument is null; {@code vectorData} itself when it already has at
+     *     most {@code targetDimension} elements; otherwise the truncated vector as produced by
+     *     {@code VectorUtils.toByteBuffer}
+     * @throws IllegalArgumentException if {@code targetDimension} is zero or negative
+     */
+    public static Object vectorTruncate(Object vectorData, Integer targetDimension) {
+        if (vectorData == null || targetDimension == null) {
+            return null;
+        }
+        // A negative size would otherwise surface as NegativeArraySizeException and zero would
+        // silently produce an empty vector, so reject both with a descriptive error.
+        if (targetDimension <= 0) {
+            throw new IllegalArgumentException(
+                    "Target dimension must be positive: " + targetDimension);
+        }
+
+        Float[] sourceVector = convertToFloatArray(vectorData);
+        if (sourceVector.length <= targetDimension) {
+            return vectorData; // Already small enough, nothing to truncate
+        }
+
+        return VectorUtils.toByteBuffer(Arrays.copyOf(sourceVector, targetDimension));
+    }
+
+    /**
+     * Reduces a vector with a Gaussian random projection. Usage: VECTOR_REDUCE(embedding, 128,
+     * 'RANDOM_PROJECTION')
+     *
+     * <p>Returns null when either argument is null, and hands back {@code vectorData} untouched
+     * when the source already has at most {@code targetDimension} elements.
+     */
+    public static Object vectorRandomProjection(Object vectorData, Integer targetDimension) {
+        if (vectorData != null && targetDimension != null) {
+            Float[] input = convertToFloatArray(vectorData);
+            if (input.length > targetDimension) {
+                float[][] gaussianMatrix =
+                        createGaussianProjectionMatrix(input.length, targetDimension);
+                return VectorUtils.toByteBuffer(
+                        applyProjection(input, gaussianMatrix, targetDimension));
+            }
+            return vectorData; // source is not larger than the target: nothing to reduce
+        }
+        return null;
+    }
+
+    /**
+     * Reduces a vector with a sparse random projection. Usage: VECTOR_REDUCE(embedding, 64,
+     * 'SPARSE_RANDOM_PROJECTION')
+     *
+     * <p>A null argument yields null; a source vector with no more than {@code targetDimension}
+     * elements is returned as-is.
+     */
+    public static Object vectorSparseProjection(Object vectorData, Integer targetDimension) {
+        boolean missingArgument = vectorData == null || targetDimension == null;
+        if (missingArgument) {
+            return null;
+        }
+
+        Float[] input = convertToFloatArray(vectorData);
+        int sourceDimension = input.length;
+        if (sourceDimension <= targetDimension) {
+            return vectorData; // nothing to reduce
+        }
+
+        Float[] projected =
+                applyProjection(
+                        input,
+                        createSparseProjectionMatrix(sourceDimension, targetDimension),
+                        targetDimension);
+        return VectorUtils.toByteBuffer(projected);
+    }
+
+    /**
+     * Generic vector dimension reduction function. Usage: VECTOR_REDUCE(vector_field,
+     * target_dimension, method)
+     *
+     * @param vectorData the vector to reduce; may be null
+     * @param targetDimension the requested dimension; may be null
+     * @param method one of 'TRUNCATE', 'RANDOM_PROJECTION' or 'SPARSE_RANDOM_PROJECTION', matched
+     *     case-insensitively; may be null
+     * @return null if any argument is null, otherwise the result of the selected reduction
+     * @throws IllegalArgumentException if {@code targetDimension} is not positive or {@code
+     *     method} is not one of the supported names
+     */
+    public static Object vectorReduce(Object vectorData, Integer targetDimension, String method) {
+        if (vectorData == null || targetDimension == null || method == null) {
+            return null;
+        }
+        // Validate once for every method; the projection helpers allocate arrays of this size.
+        if (targetDimension <= 0) {
+            throw new IllegalArgumentException(
+                    "Target dimension must be positive: " + targetDimension);
+        }
+
+        // Locale.ROOT keeps the match independent of the JVM default locale: under a Turkish
+        // locale the 'i' in "random_projection" upper-cases to a dotted capital I and the
+        // method name would no longer match.
+        switch (method.toUpperCase(java.util.Locale.ROOT)) {
+            case "TRUNCATE":
+                return vectorTruncate(vectorData, targetDimension);
+            case "RANDOM_PROJECTION":
+                return vectorRandomProjection(vectorData, targetDimension);
+            case "SPARSE_RANDOM_PROJECTION":
+                return vectorSparseProjection(vectorData, targetDimension);
+            default:
+                throw new IllegalArgumentException("Unknown reduction method: " + method);
+        }
+    }
+
+    /**
+     * Scales a vector to unit length. Usage: VECTOR_NORMALIZE(vector_field)
+     *
+     * <p>Null input yields null. Null elements are skipped when computing the magnitude and stay
+     * null in the output. A vector whose magnitude is zero is returned unchanged.
+     */
+    public static Object vectorNormalize(Object vectorData) {
+        if (vectorData == null) {
+            return null;
+        }
+
+        Float[] components = convertToFloatArray(vectorData);
+
+        double sumOfSquares = 0.0;
+        for (int i = 0; i < components.length; i++) {
+            Float component = components[i];
+            if (component == null) {
+                continue;
+            }
+            sumOfSquares += component * component;
+        }
+
+        double length = Math.sqrt(sumOfSquares);
+        if (length == 0.0) {
+            return vectorData; // zero vector: there is no direction to normalize
+        }
+
+        Float[] unit = new Float[components.length];
+        for (int i = 0; i < unit.length; i++) {
+            Float component = components[i];
+            if (component != null) {
+                unit[i] = (float) (component / length);
+            }
+        }
+        return VectorUtils.toByteBuffer(unit);
+    }
+
+    /**
+     * Multiplies {@code projectionMatrix} (one row per output element) by {@code sourceVector}.
+     * Zero matrix entries and null source elements contribute nothing to the sums.
+     */
+    private static Float[] applyProjection(
+            Float[] sourceVector, float[][] projectionMatrix, int targetDimension) {
+        Float[] projected = new Float[targetDimension];
+        for (int row = 0; row < targetDimension; row++) {
+            float dot = 0.0f;
+            for (int col = 0; col < sourceVector.length; col++) {
+                float weight = projectionMatrix[row][col];
+                Float component = sourceVector[col];
+                if (weight == 0 || component == null) {
+                    continue;
+                }
+                dot += component * weight;
+            }
+            projected[row] = dot;
+        }
+        return projected;
+    }
+
+    /**
+     * Builds a {@code targetDimension x sourceDimension} matrix of Gaussian samples scaled by
+     * {@code sqrt(1 / targetDimension)}.
+     *
+     * <p>The generator is seeded locally on every call, so the matrix depends only on the two
+     * dimensions. Drawing from a shared, ever-advancing generator would give each input row a
+     * different matrix, which makes the projected vectors incomparable with each other and the
+     * output dependent on processing order.
+     */
+    private static float[][] createGaussianProjectionMatrix(
+            int sourceDimension, int targetDimension) {
+        Random generator = new Random(42);
+        float[][] matrix = new float[targetDimension][sourceDimension];
+        float scale = (float) Math.sqrt(1.0 / targetDimension);
+
+        for (float[] row : matrix) {
+            for (int j = 0; j < sourceDimension; j++) {
+                row[j] = (float) generator.nextGaussian() * scale;
+            }
+        }
+        return matrix;
+    }
+
+    /**
+     * Builds a {@code targetDimension x sourceDimension} sparse matrix whose entries are +sqrt(3)
+     * with probability 1/6, -sqrt(3) with probability 1/6 and 0 otherwise.
+     *
+     * <p>The generator is seeded locally on every call, so the matrix depends only on the two
+     * dimensions. Drawing from a shared, ever-advancing generator would give each input row a
+     * different matrix, which makes the projected vectors incomparable with each other and the
+     * output dependent on processing order.
+     *
+     * <p>NOTE(review): unlike the Gaussian variant, no 1/sqrt(targetDimension) factor is applied
+     * here; left unchanged because the documentation describes the entries as +/-sqrt(3).
+     */
+    private static float[][] createSparseProjectionMatrix(
+            int sourceDimension, int targetDimension) {
+        Random generator = new Random(42);
+        float[][] matrix = new float[targetDimension][sourceDimension];
+        float scale = (float) Math.sqrt(3.0);
+        double positiveThreshold = 1.0 / 6.0;
+        double negativeThreshold = 2.0 / 6.0;
+
+        for (float[] row : matrix) {
+            for (int j = 0; j < sourceDimension; j++) {
+                double sample = generator.nextDouble();
+                if (sample < positiveThreshold) {
+                    row[j] = scale;
+                } else if (sample < negativeThreshold) {
+                    row[j] = -scale;
+                } else {
+                    row[j] = 0;
+                }
+            }
+        }
+        return matrix;
+    }
}
diff --git
a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/sql/SQLVectorFunctionTest.java
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/sql/SQLVectorFunctionTest.java
new file mode 100644
index 0000000000..e006331ad9
--- /dev/null
+++
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/sql/SQLVectorFunctionTest.java
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.sql;
+
+import org.apache.seatunnel.api.configuration.ReadonlyConfig;
+import org.apache.seatunnel.api.table.catalog.CatalogTable;
+import org.apache.seatunnel.api.table.catalog.PhysicalColumn;
+import org.apache.seatunnel.api.table.catalog.TableIdentifier;
+import org.apache.seatunnel.api.table.catalog.TableSchema;
+import org.apache.seatunnel.api.table.type.BasicType;
+import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
+import org.apache.seatunnel.api.table.type.SeaTunnelRow;
+import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
+import org.apache.seatunnel.api.table.type.VectorType;
+import org.apache.seatunnel.common.utils.VectorUtils;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+
+/**
+ * Unit tests for the VECTOR_REDUCE and VECTOR_NORMALIZE SQL functions, run through {@link
+ * SQLTransform} against a three-column table (id, vector_field, vector_field2).
+ */
+public class SQLVectorFunctionTest {
+
+    private static final String TEST_NAME = "vector_test";
+    private static final String[] FIELD_NAMES =
+            new String[] {"id", "vector_field", "vector_field2"};
+    private CatalogTable catalogTable;
+
+    @BeforeEach
+    void setUp() {
+        SeaTunnelRowType rowType =
+                new SeaTunnelRowType(
+                        FIELD_NAMES,
+                        new SeaTunnelDataType[] {
+                            BasicType.INT_TYPE,
+                            VectorType.VECTOR_FLOAT_TYPE,
+                            VectorType.VECTOR_FLOAT_TYPE
+                        });
+
+        TableSchema.Builder schemaBuilder = TableSchema.builder();
+        for (int i = 0; i < rowType.getTotalFields(); i++) {
+            PhysicalColumn column =
+                    PhysicalColumn.of(
+                            rowType.getFieldName(i), rowType.getFieldType(i), 0, true, null, null);
+            schemaBuilder.column(column);
+        }
+
+        catalogTable =
+                CatalogTable.of(
+                        TableIdentifier.of(TEST_NAME, TEST_NAME, null, TEST_NAME),
+                        schemaBuilder.build(),
+                        new HashMap<>(),
+                        new ArrayList<>(),
+                        "Vector function test table");
+    }
+
+    @Test
+    public void testVectorTruncate() {
+        SQLTransform sqlTransform =
+                newTransform(
+                        "SELECT id, VECTOR_REDUCE(vector_field, 3,'TRUNCATE') "
+                                + "as truncated_vector FROM dual");
+        // Kept from the original test: also resolve the output schema, so a failure while
+        // building it fails this test.
+        sqlTransform.transformTableSchema();
+
+        Float[] resultArray = transformVector(sqlTransform, fiveElementVector());
+
+        assertVectorEquals(new float[] {1.0f, 2.0f, 3.0f}, resultArray);
+    }
+
+    @Test
+    public void testVectorNormalize() {
+        SQLTransform sqlTransform =
+                newTransform(
+                        "SELECT id, VECTOR_NORMALIZE(vector_field) as normalized_vector FROM dual");
+
+        // [3, 4] has magnitude 5, so the normalized vector is [0.6, 0.8]
+        Float[] resultArray = transformVector(sqlTransform, new Float[] {3.0f, 4.0f});
+
+        assertVectorEquals(new float[] {0.6f, 0.8f}, resultArray);
+    }
+
+    @Test
+    public void testVectorReduce() {
+        SQLTransform sqlTransform =
+                newTransform(
+                        "SELECT id, VECTOR_REDUCE(vector_field, 3, 'TRUNCATE') "
+                                + "as reduced_vector FROM dual");
+
+        Float[] resultArray = transformVector(sqlTransform, fiveElementVector());
+
+        assertVectorEquals(new float[] {1.0f, 2.0f, 3.0f}, resultArray);
+    }
+
+    @Test
+    public void testVectorRandomProjection() {
+        SQLTransform sqlTransform =
+                newTransform(
+                        "SELECT id, VECTOR_REDUCE(vector_field, 3,'RANDOM_PROJECTION') "
+                                + "as projected_vector FROM dual");
+
+        Float[] resultArray = transformVector(sqlTransform, fiveElementVector());
+
+        // Only the dimension and the absence of nulls are checked, not the projected values
+        assertDimensionWithoutNulls(3, resultArray);
+    }
+
+    @Test
+    public void testVectorSparseProjection() {
+        SQLTransform sqlTransform =
+                newTransform(
+                        "SELECT id, VECTOR_REDUCE(vector_field, 3,'SPARSE_RANDOM_PROJECTION') "
+                                + "as sparse_projected_vector FROM dual");
+
+        Float[] resultArray = transformVector(sqlTransform, fiveElementVector());
+
+        // Only the dimension and the absence of nulls are checked, not the projected values
+        assertDimensionWithoutNulls(3, resultArray);
+    }
+
+    /** Creates a SQL transform over the test table for the given query. */
+    private SQLTransform newTransform(String query) {
+        ReadonlyConfig config = ReadonlyConfig.fromMap(Collections.singletonMap("query", query));
+        return new SQLTransform(config, catalogTable);
+    }
+
+    /**
+     * Sends one row (id = 1, vector_field = sourceVector, vector_field2 = null) through the
+     * transform, checks that exactly one row with id 1 comes back, and returns its second field
+     * decoded as a float array.
+     */
+    private static Float[] transformVector(SQLTransform sqlTransform, Float[] sourceVector) {
+        ByteBuffer vectorBuffer = VectorUtils.toByteBuffer(sourceVector);
+        SeaTunnelRow inputRow = new SeaTunnelRow(new Object[] {1, vectorBuffer, null});
+        List<SeaTunnelRow> result = sqlTransform.transformRow(inputRow);
+
+        Assertions.assertNotNull(result);
+        Assertions.assertEquals(1, result.size());
+
+        SeaTunnelRow outputRow = result.get(0);
+        Assertions.assertEquals(1, outputRow.getField(0));
+
+        ByteBuffer resultVector = (ByteBuffer) outputRow.getField(1);
+        return VectorUtils.toFloatArray(resultVector);
+    }
+
+    /** Returns a fresh copy of the five-element input used by the reduction tests. */
+    private static Float[] fiveElementVector() {
+        return new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
+    }
+
+    /** Asserts the length and, element by element, the values within a 0.001 tolerance. */
+    private static void assertVectorEquals(float[] expected, Float[] actual) {
+        Assertions.assertEquals(expected.length, actual.length);
+        for (int i = 0; i < expected.length; i++) {
+            Assertions.assertEquals(expected[i], actual[i], 0.001f);
+        }
+    }
+
+    /** Asserts that the vector has the expected length and contains no null element. */
+    private static void assertDimensionWithoutNulls(int expectedDimension, Float[] actual) {
+        Assertions.assertEquals(expectedDimension, actual.length);
+        for (Float value : actual) {
+            Assertions.assertNotNull(value);
+        }
+    }
+}