This is an automated email from the ASF dual-hosted git repository.

corgy pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git


The following commit(s) were added to refs/heads/dev by this push:
     new f78405947d [Fix] [Transform-V2] Fix embedding output columns vector dimension (#9646)
f78405947d is described below

commit f78405947dea0da786884963383921a865cd63a6
Author: loupipalien <[email protected]>
AuthorDate: Mon Aug 4 13:53:41 2025 +0800

    [Fix] [Transform-V2] Fix embedding output columns vector dimension (#9646)
---
 .../nlpmodel/embedding/EmbeddingTransform.java     |  8 ++-
 .../embedding/EmbeddingTransformTest.java          | 64 ++++++++++++++++++++++
 2 files changed, 70 insertions(+), 2 deletions(-)

diff --git 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java
 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java
index e748559b76..71310a1d54 100644
--- 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java
+++ 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java
@@ -17,6 +17,8 @@
 
 package org.apache.seatunnel.transform.nlpmodel.embedding;
 
+import 
org.apache.seatunnel.shade.com.google.common.annotations.VisibleForTesting;
+
 import org.apache.seatunnel.api.configuration.ReadonlyConfig;
 import org.apache.seatunnel.api.table.catalog.CatalogTable;
 import org.apache.seatunnel.api.table.catalog.Column;
@@ -52,7 +54,7 @@ public class EmbeddingTransform extends 
MultipleFieldOutputTransform {
     private final ReadonlyConfig config;
     private List<String> fieldNames;
     private List<Integer> fieldOriginalIndexes;
-    private Model model;
+    private transient Model model;
     private Integer dimension;
 
     public EmbeddingTransform(
@@ -212,7 +214,9 @@ public class EmbeddingTransform extends 
MultipleFieldOutputTransform {
     }
 
     @Override
-    protected Column[] getOutputColumns() {
+    @VisibleForTesting
+    public Column[] getOutputColumns() {
+        tryOpen();
         Column[] columns = new Column[fieldNames.size()];
         for (int i = 0; i < fieldNames.size(); i++) {
             columns[i] =
diff --git 
a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingTransformTest.java
 
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingTransformTest.java
new file mode 100644
index 0000000000..df19ad50e4
--- /dev/null
+++ 
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingTransformTest.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.embedding;
+
+import 
org.apache.seatunnel.shade.com.fasterxml.jackson.core.JsonProcessingException;
+import 
org.apache.seatunnel.shade.com.fasterxml.jackson.core.type.TypeReference;
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.ObjectMapper;
+
+import org.apache.seatunnel.api.configuration.ReadonlyConfig;
+import org.apache.seatunnel.api.table.catalog.CatalogTable;
+import org.apache.seatunnel.api.table.catalog.CatalogTableUtil;
+import org.apache.seatunnel.api.table.catalog.Column;
+import org.apache.seatunnel.transform.nlpmodel.embedding.EmbeddingTransform;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.util.Map;
+
+public class EmbeddingTransformTest {
+
+    @Test
+    void testOutputColumns() throws JsonProcessingException {
+        ObjectMapper objectMapper = new ObjectMapper();
+
+        String sourceConfig =
+                
"{\"path\":\"/seatunnel/test_csv_data.csv\",\"bucket\":\"s3a://ltchen\",\"fs.s3a.endpoint\":\"tos-s3-cn-beijing.volces.com\",\"fs.s3a.aws.credentials.provider\":\"org.apache.hadoop.fs.s3a.SimpleAWSCredentialsProvider\",\"file_format_type\":\"csv\",\"access_key\":\"xxx\",\"secret_key\":\"xxx\",\"csv_use_header_line\":true,\"field_delimiter\":\",\",\"schema\":{\"fields\":{\"id\":\"int\",\"code\":\"int\",\"data\":\"string\",\"success\":\"boolean\"},\"primaryKey\":{\"name\":\
 [...]
+        Map<String, Object> sourceConfigMap =
+                objectMapper.readValue(sourceConfig, new 
TypeReference<Map<String, Object>>() {});
+        ReadonlyConfig readonlyConfig = 
ReadonlyConfig.fromMap(sourceConfigMap);
+        CatalogTable inputCatalogTable = 
CatalogTableUtil.buildWithConfig("S3File", readonlyConfig);
+
+        int dimension = 1024;
+        String embeddingConfig =
+                
"{\"model_provider\":\"AMAZON\",\"model\":\"amazon.titan-embed-text-v2:0\",\"aws_region\":
 \"us-east-1\", \"api_key\":\"xxx\",\"secret_key\":\"xxx\",\"api_path\": 
\"https://aws.amazon.com/bedrock/amazon-models\", \"dimension\": "
+                        + dimension
+                        + 
",\"vectorization_fields\":{\"data_vector\":\"data\"},\"plugin_name\":\"Embedding\"}";
+        Map<String, Object> embeddingConfigMap =
+                objectMapper.readValue(
+                        embeddingConfig, new TypeReference<Map<String, 
Object>>() {});
+        ReadonlyConfig config = ReadonlyConfig.fromMap(embeddingConfigMap);
+        EmbeddingTransform embeddingTransform = new EmbeddingTransform(config, 
inputCatalogTable);
+
+        Column[] columns = embeddingTransform.getOutputColumns();
+        for (Column column : columns) {
+            Assertions.assertEquals(dimension, column.getScale());
+        }
+    }
+}

Reply via email to