This is an automated email from the ASF dual-hosted git repository.
corgy pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git
The following commit(s) were added to refs/heads/dev by this push:
new f78405947d [Fix] [Transform-V2] Fix embedding output columns vector
dimension (#9646)
f78405947d is described below
commit f78405947dea0da786884963383921a865cd63a6
Author: loupipalien <[email protected]>
AuthorDate: Mon Aug 4 13:53:41 2025 +0800
[Fix] [Transform-V2] Fix embedding output columns vector dimension (#9646)
---
.../nlpmodel/embedding/EmbeddingTransform.java | 8 ++-
.../embedding/EmbeddingTransformTest.java | 64 ++++++++++++++++++++++
2 files changed, 70 insertions(+), 2 deletions(-)
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java
index e748559b76..71310a1d54 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java
@@ -17,6 +17,8 @@
package org.apache.seatunnel.transform.nlpmodel.embedding;
+import
org.apache.seatunnel.shade.com.google.common.annotations.VisibleForTesting;
+
import org.apache.seatunnel.api.configuration.ReadonlyConfig;
import org.apache.seatunnel.api.table.catalog.CatalogTable;
import org.apache.seatunnel.api.table.catalog.Column;
@@ -52,7 +54,7 @@ public class EmbeddingTransform extends
MultipleFieldOutputTransform {
private final ReadonlyConfig config;
private List<String> fieldNames;
private List<Integer> fieldOriginalIndexes;
- private Model model;
+ private transient Model model;
private Integer dimension;
public EmbeddingTransform(
@@ -212,7 +214,9 @@ public class EmbeddingTransform extends
MultipleFieldOutputTransform {
}
@Override
- protected Column[] getOutputColumns() {
+ @VisibleForTesting
+ public Column[] getOutputColumns() {
+ tryOpen();
Column[] columns = new Column[fieldNames.size()];
for (int i = 0; i < fieldNames.size(); i++) {
columns[i] =
diff --git
a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingTransformTest.java
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingTransformTest.java
new file mode 100644
index 0000000000..df19ad50e4
--- /dev/null
+++
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingTransformTest.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.embedding;
+
+import
org.apache.seatunnel.shade.com.fasterxml.jackson.core.JsonProcessingException;
+import
org.apache.seatunnel.shade.com.fasterxml.jackson.core.type.TypeReference;
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.ObjectMapper;
+
+import org.apache.seatunnel.api.configuration.ReadonlyConfig;
+import org.apache.seatunnel.api.table.catalog.CatalogTable;
+import org.apache.seatunnel.api.table.catalog.CatalogTableUtil;
+import org.apache.seatunnel.api.table.catalog.Column;
+import org.apache.seatunnel.transform.nlpmodel.embedding.EmbeddingTransform;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.util.Map;
+
+public class EmbeddingTransformTest {
+
+ @Test
+ void testOutputColumns() throws JsonProcessingException {
+ ObjectMapper objectMapper = new ObjectMapper();
+
+ String sourceConfig =
+
"{\"path\":\"/seatunnel/test_csv_data.csv\",\"bucket\":\"s3a://ltchen\",\"fs.s3a.endpoint\":\"tos-s3-cn-beijing.volces.com\",\"fs.s3a.aws.credentials.provider\":\"org.apache.hadoop.fs.s3a.SimpleAWSCredentialsProvider\",\"file_format_type\":\"csv\",\"access_key\":\"xxx\",\"secret_key\":\"xxx\",\"csv_use_header_line\":true,\"field_delimiter\":\",\",\"schema\":{\"fields\":{\"id\":\"int\",\"code\":\"int\",\"data\":\"string\",\"success\":\"boolean\"},\"primaryKey\":{\"name\":\
[...]
+ Map<String, Object> sourceConfigMap =
+ objectMapper.readValue(sourceConfig, new
TypeReference<Map<String, Object>>() {});
+ ReadonlyConfig readonlyConfig =
ReadonlyConfig.fromMap(sourceConfigMap);
+ CatalogTable inputCatalogTable =
CatalogTableUtil.buildWithConfig("S3File", readonlyConfig);
+
+ int dimension = 1024;
+ String embeddingConfig =
+
"{\"model_provider\":\"AMAZON\",\"model\":\"amazon.titan-embed-text-v2:0\",\"aws_region\":
\"us-east-1\", \"api_key\":\"xxx\",\"secret_key\":\"xxx\",\"api_path\":
\"https://aws.amazon.com/bedrock/amazon-models\", \"dimension\": "
+ + dimension
+ +
",\"vectorization_fields\":{\"data_vector\":\"data\"},\"plugin_name\":\"Embedding\"}";
+ Map<String, Object> embeddingConfigMap =
+ objectMapper.readValue(
+ embeddingConfig, new TypeReference<Map<String,
Object>>() {});
+ ReadonlyConfig config = ReadonlyConfig.fromMap(embeddingConfigMap);
+ EmbeddingTransform embeddingTransform = new EmbeddingTransform(config,
inputCatalogTable);
+
+ Column[] columns = embeddingTransform.getOutputColumns();
+ for (Column column : columns) {
+ Assertions.assertEquals(dimension, column.getScale());
+ }
+ }
+}