This is an automated email from the ASF dual-hosted git repository.

corgy pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git


The following commit(s) were added to refs/heads/dev by this push:
     new 9c18e8ca82 [Hotfix][Transform-V2] Fix some model return number of dimensions (#9644)
9c18e8ca82 is described below

commit 9c18e8ca82f8bbb6b65c666e72fa0369b796e35c
Author: loupipalien <[email protected]>
AuthorDate: Mon Aug 4 13:56:06 2025 +0800

    [Hotfix][Transform-V2] Fix some model return number of dimensions (#9644)
---
 .../embedding/remote/custom/CustomModel.java       |  15 ++-
 .../embedding/remote/doubao/DoubaoModel.java       |  13 +-
 .../embedding/remote/openai/OpenAIModel.java       |  13 +-
 .../embedding/EmbeddingModelDimensionTest.java     | 145 +++++++++++++++++++++
 4 files changed, 180 insertions(+), 6 deletions(-)

diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java
index 179315f956..8f9970e9b4 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java
@@ -57,13 +57,24 @@ public class CustomModel extends AbstractModel {
             Map<String, Object> body,
             String parse,
             Integer vectorizedNumber) {
+        this(model, apiPath, header, body, parse, vectorizedNumber, HttpClients.createDefault());
+    }
+
+    public CustomModel(
+            String model,
+            String apiPath,
+            Map<String, String> header,
+            Map<String, Object> body,
+            String parse,
+            Integer vectorizedNumber,
+            CloseableHttpClient client) {
         super(vectorizedNumber);
         this.apiPath = apiPath;
         this.model = model;
         this.header = header;
         this.body = body;
         this.parse = parse;
-        this.client = HttpClients.createDefault();
+        this.client = client;
     }
 
     @Override
@@ -73,7 +84,7 @@ public class CustomModel extends AbstractModel {
 
     @Override
     public Integer dimension() throws IOException {
-        return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size();
+        return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).get(0).size();
     }
 
     private List<List<Float>> vectorGeneration(Object[] fields) throws IOException {
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
index f2b1e348c7..2174e61996 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
@@ -46,11 +46,20 @@ public class DoubaoModel extends AbstractModel {
     private final String apiPath;
 
     public DoubaoModel(String apiKey, String model, String apiPath, Integer vectorizedNumber) {
+        this(apiKey, model, apiPath, vectorizedNumber, HttpClients.createDefault());
+    }
+
+    public DoubaoModel(
+            String apiKey,
+            String model,
+            String apiPath,
+            Integer vectorizedNumber,
+            CloseableHttpClient client) {
         super(vectorizedNumber);
         this.apiKey = apiKey;
         this.model = model;
         this.apiPath = apiPath;
-        this.client = HttpClients.createDefault();
+        this.client = client;
     }
 
     @Override
@@ -60,7 +69,7 @@ public class DoubaoModel extends AbstractModel {
 
     @Override
     public Integer dimension() throws IOException {
-        return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size();
+        return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).get(0).size();
     }
 
     private List<List<Float>> vectorGeneration(Object[] fields) throws IOException {
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java
index 2a45cc829f..932b77df92 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java
@@ -45,11 +45,20 @@ public class OpenAIModel extends AbstractModel {
     private final String apiPath;
 
     public OpenAIModel(String apiKey, String model, String apiPath, Integer vectorizedNumber) {
+        this(apiKey, model, apiPath, vectorizedNumber, HttpClients.createDefault());
+    }
+
+    public OpenAIModel(
+            String apiKey,
+            String model,
+            String apiPath,
+            Integer vectorizedNumber,
+            CloseableHttpClient client) {
         super(vectorizedNumber);
         this.apiKey = apiKey;
         this.model = model;
         this.apiPath = apiPath;
-        this.client = HttpClients.createDefault();
+        this.client = client;
     }
 
     @Override
@@ -62,7 +71,7 @@ public class OpenAIModel extends AbstractModel {
 
     @Override
     public Integer dimension() throws IOException {
-        return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size();
+        return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).get(0).size();
     }
 
     private List<List<Float>> vectorGeneration(Object[] fields) throws IOException {
diff --git a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingModelDimensionTest.java b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingModelDimensionTest.java
new file mode 100644
index 0000000000..efb65bb341
--- /dev/null
+++ b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingModelDimensionTest.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.embedding;
+
+import org.apache.seatunnel.transform.nlpmodel.embedding.remote.custom.CustomModel;
+import org.apache.seatunnel.transform.nlpmodel.embedding.remote.doubao.DoubaoModel;
+import org.apache.seatunnel.transform.nlpmodel.embedding.remote.openai.OpenAIModel;
+
+import org.apache.http.ProtocolVersion;
+import org.apache.http.client.methods.CloseableHttpResponse;
+import org.apache.http.impl.client.CloseableHttpClient;
+import org.apache.http.message.BasicStatusLine;
+import org.apache.http.util.EntityUtils;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.concurrent.ThreadLocalRandom;
+
+public class EmbeddingModelDimensionTest {
+
+    @Test
+    void testCustomModelDimension() throws IOException {
+        CloseableHttpClient client = Mockito.mock(CloseableHttpClient.class);
+        CustomModel model =
+                new CustomModel(
+                        "modelName",
+                        "https://api.custom.com/v1/chat/completions",
+                        new HashMap<>(),
+                        new HashMap<>(),
+                        "$.data[*].embedding",
+                        1,
+                        client);
+
+        int dimension = ThreadLocalRandom.current().nextInt(1024, 4097);
+        List<Float> vector = generateVector(dimension);
+        String responseStr =
+                "{\"created\":\"1753944315\",\"data\":[{\"embedding\":"
+                        + vector
+                        + ",\"index\":0,\"object\":\"embedding\"}],\"id\":\"021753944315445384c5dcd581d413bdefc6446277658dfef1939\",\"model\":\"doubao-embedding-text-240715\",\"object\":\"list\",\"usage\":{\"completionTokens\":0,\"promptTokens\":3,\"totalTokens\":3}}";
+
+        try (MockedStatic<EntityUtils> entityUtils = Mockito.mockStatic(EntityUtils.class)) {
+            CloseableHttpResponse response = Mockito.mock(CloseableHttpResponse.class);
+            Mockito.when(client.execute(Mockito.any())).thenReturn(response);
+            Mockito.when(response.getStatusLine())
+                    .thenReturn(new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"));
+            entityUtils
+                    .when(() -> EntityUtils.toString(response.getEntity()))
+                    .thenReturn(responseStr);
+
+            Assertions.assertEquals(dimension, model.dimension());
+        }
+    }
+
+    @Test
+    void testDoubleModelDimension() throws IOException {
+        CloseableHttpClient client = Mockito.mock(CloseableHttpClient.class);
+        DoubaoModel model =
+                new DoubaoModel(
+                        "apikey",
+                        "modelName",
+                        "https://api.doubao.io/v1/chat/completions",
+                        1,
+                        client);
+
+        int dimension = ThreadLocalRandom.current().nextInt(1024, 2561);
+        List<Float> vector = generateVector(dimension);
+        String responseStr =
+                "{\"created\":\"1753944315\",\"data\":[{\"embedding\":"
+                        + vector
+                        + ",\"index\":0,\"object\":\"embedding\"}],\"id\":\"021753944315445384c5dcd581d413bdefc6446277658dfef1939\",\"model\":\"doubao-embedding-text-240715\",\"object\":\"list\",\"usage\":{\"completionTokens\":0,\"promptTokens\":3,\"totalTokens\":3}}";
+
+        try (MockedStatic<EntityUtils> entityUtils = Mockito.mockStatic(EntityUtils.class)) {
+            CloseableHttpResponse response = Mockito.mock(CloseableHttpResponse.class);
+            Mockito.when(client.execute(Mockito.any())).thenReturn(response);
+            Mockito.when(response.getStatusLine())
+                    .thenReturn(new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"));
+            entityUtils
+                    .when(() -> EntityUtils.toString(response.getEntity()))
+                    .thenReturn(responseStr);
+
+            Assertions.assertEquals(dimension, model.dimension());
+        }
+    }
+
+    @Test
+    void testOpenAIModelDimension() throws IOException {
+        CloseableHttpClient client = Mockito.mock(CloseableHttpClient.class);
+        OpenAIModel model =
+                new OpenAIModel(
+                        "apikey",
+                        "modelName",
+                        "https://api.openai.com/v1/chat/completions",
+                        1,
+                        client);
+
+        int dimension = ThreadLocalRandom.current().nextInt(1024, 1537);
+        List<Float> vector = generateVector(dimension);
+        String responseStr =
+                "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"embedding\":"
+                        + vector
+                        + ",\"index\":0}],\"model\":\"text-embedding-ada-002\",\"usage\":{\"prompt_tokens\":8,\"total_tokens\":8}}";
+
+        try (MockedStatic<EntityUtils> entityUtils = Mockito.mockStatic(EntityUtils.class)) {
+            CloseableHttpResponse response = Mockito.mock(CloseableHttpResponse.class);
+            Mockito.when(response.getStatusLine())
+                    .thenReturn(new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"));
+            Mockito.when(client.execute(Mockito.any())).thenReturn(response);
+            entityUtils
+                    .when(() -> EntityUtils.toString(response.getEntity()))
+                    .thenReturn(responseStr);
+
+            Assertions.assertEquals(dimension, model.dimension());
+        }
+    }
+
+    private List<Float> generateVector(int dimension) {
+        List<Float> vector = new ArrayList<>();
+        for (int i = 0; i < dimension; i++) {
+            vector.add(ThreadLocalRandom.current().nextFloat());
+        }
+        return vector;
+    }
+}

Reply via email to