This is an automated email from the ASF dual-hosted git repository.
corgy pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git
The following commit(s) were added to refs/heads/dev by this push:
new 9c18e8ca82 [Hotfix][Transform-V2] Fix the number of dimensions returned by
some embedding models (#9644)
9c18e8ca82 is described below
commit 9c18e8ca82f8bbb6b65c666e72fa0369b796e35c
Author: loupipalien <[email protected]>
AuthorDate: Mon Aug 4 13:56:06 2025 +0800
[Hotfix][Transform-V2] Fix the number of dimensions returned by some embedding models (#9644)
---
.../embedding/remote/custom/CustomModel.java | 15 ++-
.../embedding/remote/doubao/DoubaoModel.java | 13 +-
.../embedding/remote/openai/OpenAIModel.java | 13 +-
.../embedding/EmbeddingModelDimensionTest.java | 145 +++++++++++++++++++++
4 files changed, 180 insertions(+), 6 deletions(-)
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java
index 179315f956..8f9970e9b4 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java
@@ -57,13 +57,24 @@ public class CustomModel extends AbstractModel {
Map<String, Object> body,
String parse,
Integer vectorizedNumber) {
+ this(model, apiPath, header, body, parse, vectorizedNumber,
HttpClients.createDefault());
+ }
+
+ public CustomModel(
+ String model,
+ String apiPath,
+ Map<String, String> header,
+ Map<String, Object> body,
+ String parse,
+ Integer vectorizedNumber,
+ CloseableHttpClient client) {
super(vectorizedNumber);
this.apiPath = apiPath;
this.model = model;
this.header = header;
this.body = body;
this.parse = parse;
- this.client = HttpClients.createDefault();
+ this.client = client;
}
@Override
@@ -73,7 +84,7 @@ public class CustomModel extends AbstractModel {
@Override
public Integer dimension() throws IOException {
- return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size();
+ return vectorGeneration(new Object[]
{DIMENSION_EXAMPLE}).get(0).size();
}
private List<List<Float>> vectorGeneration(Object[] fields) throws
IOException {
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
index f2b1e348c7..2174e61996 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
@@ -46,11 +46,20 @@ public class DoubaoModel extends AbstractModel {
private final String apiPath;
public DoubaoModel(String apiKey, String model, String apiPath, Integer
vectorizedNumber) {
+ this(apiKey, model, apiPath, vectorizedNumber,
HttpClients.createDefault());
+ }
+
+ public DoubaoModel(
+ String apiKey,
+ String model,
+ String apiPath,
+ Integer vectorizedNumber,
+ CloseableHttpClient client) {
super(vectorizedNumber);
this.apiKey = apiKey;
this.model = model;
this.apiPath = apiPath;
- this.client = HttpClients.createDefault();
+ this.client = client;
}
@Override
@@ -60,7 +69,7 @@ public class DoubaoModel extends AbstractModel {
@Override
public Integer dimension() throws IOException {
- return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size();
+ return vectorGeneration(new Object[]
{DIMENSION_EXAMPLE}).get(0).size();
}
private List<List<Float>> vectorGeneration(Object[] fields) throws
IOException {
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java
index 2a45cc829f..932b77df92 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java
@@ -45,11 +45,20 @@ public class OpenAIModel extends AbstractModel {
private final String apiPath;
public OpenAIModel(String apiKey, String model, String apiPath, Integer
vectorizedNumber) {
+ this(apiKey, model, apiPath, vectorizedNumber,
HttpClients.createDefault());
+ }
+
+ public OpenAIModel(
+ String apiKey,
+ String model,
+ String apiPath,
+ Integer vectorizedNumber,
+ CloseableHttpClient client) {
super(vectorizedNumber);
this.apiKey = apiKey;
this.model = model;
this.apiPath = apiPath;
- this.client = HttpClients.createDefault();
+ this.client = client;
}
@Override
@@ -62,7 +71,7 @@ public class OpenAIModel extends AbstractModel {
@Override
public Integer dimension() throws IOException {
- return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size();
+ return vectorGeneration(new Object[]
{DIMENSION_EXAMPLE}).get(0).size();
}
private List<List<Float>> vectorGeneration(Object[] fields) throws
IOException {
diff --git
a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingModelDimensionTest.java
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingModelDimensionTest.java
new file mode 100644
index 0000000000..efb65bb341
--- /dev/null
+++
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingModelDimensionTest.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.embedding;
+
+import
org.apache.seatunnel.transform.nlpmodel.embedding.remote.custom.CustomModel;
+import
org.apache.seatunnel.transform.nlpmodel.embedding.remote.doubao.DoubaoModel;
+import
org.apache.seatunnel.transform.nlpmodel.embedding.remote.openai.OpenAIModel;
+
+import org.apache.http.ProtocolVersion;
+import org.apache.http.client.methods.CloseableHttpResponse;
+import org.apache.http.impl.client.CloseableHttpClient;
+import org.apache.http.message.BasicStatusLine;
+import org.apache.http.util.EntityUtils;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.concurrent.ThreadLocalRandom;
+
+public class EmbeddingModelDimensionTest {
+
+ @Test
+ void testCustomModelDimension() throws IOException {
+ CloseableHttpClient client = Mockito.mock(CloseableHttpClient.class);
+ CustomModel model =
+ new CustomModel(
+ "modelName",
+ "https://api.custom.com/v1/chat/completions",
+ new HashMap<>(),
+ new HashMap<>(),
+ "$.data[*].embedding",
+ 1,
+ client);
+
+ int dimension = ThreadLocalRandom.current().nextInt(1024, 4097);
+ List<Float> vector = generateVector(dimension);
+ String responseStr =
+ "{\"created\":\"1753944315\",\"data\":[{\"embedding\":"
+ + vector
+ +
",\"index\":0,\"object\":\"embedding\"}],\"id\":\"021753944315445384c5dcd581d413bdefc6446277658dfef1939\",\"model\":\"doubao-embedding-text-240715\",\"object\":\"list\",\"usage\":{\"completionTokens\":0,\"promptTokens\":3,\"totalTokens\":3}}";
+
+ try (MockedStatic<EntityUtils> entityUtils =
Mockito.mockStatic(EntityUtils.class)) {
+ CloseableHttpResponse response =
Mockito.mock(CloseableHttpResponse.class);
+ Mockito.when(client.execute(Mockito.any())).thenReturn(response);
+ Mockito.when(response.getStatusLine())
+ .thenReturn(new BasicStatusLine(new
ProtocolVersion("HTTP", 1, 1), 200, "OK"));
+ entityUtils
+ .when(() -> EntityUtils.toString(response.getEntity()))
+ .thenReturn(responseStr);
+
+ Assertions.assertEquals(dimension, model.dimension());
+ }
+ }
+
+ @Test
+ void testDoubleModelDimension() throws IOException {
+ CloseableHttpClient client = Mockito.mock(CloseableHttpClient.class);
+ DoubaoModel model =
+ new DoubaoModel(
+ "apikey",
+ "modelName",
+ "https://api.doubao.io/v1/chat/completions",
+ 1,
+ client);
+
+ int dimension = ThreadLocalRandom.current().nextInt(1024, 2561);
+ List<Float> vector = generateVector(dimension);
+ String responseStr =
+ "{\"created\":\"1753944315\",\"data\":[{\"embedding\":"
+ + vector
+ +
",\"index\":0,\"object\":\"embedding\"}],\"id\":\"021753944315445384c5dcd581d413bdefc6446277658dfef1939\",\"model\":\"doubao-embedding-text-240715\",\"object\":\"list\",\"usage\":{\"completionTokens\":0,\"promptTokens\":3,\"totalTokens\":3}}";
+
+ try (MockedStatic<EntityUtils> entityUtils =
Mockito.mockStatic(EntityUtils.class)) {
+ CloseableHttpResponse response =
Mockito.mock(CloseableHttpResponse.class);
+ Mockito.when(client.execute(Mockito.any())).thenReturn(response);
+ Mockito.when(response.getStatusLine())
+ .thenReturn(new BasicStatusLine(new
ProtocolVersion("HTTP", 1, 1), 200, "OK"));
+ entityUtils
+ .when(() -> EntityUtils.toString(response.getEntity()))
+ .thenReturn(responseStr);
+
+ Assertions.assertEquals(dimension, model.dimension());
+ }
+ }
+
+ @Test
+ void testOpenAIModelDimension() throws IOException {
+ CloseableHttpClient client = Mockito.mock(CloseableHttpClient.class);
+ OpenAIModel model =
+ new OpenAIModel(
+ "apikey",
+ "modelName",
+ "https://api.openai.com/v1/chat/completions",
+ 1,
+ client);
+
+ int dimension = ThreadLocalRandom.current().nextInt(1024, 1537);
+ List<Float> vector = generateVector(dimension);
+ String responseStr =
+
"{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"embedding\":"
+ + vector
+ +
",\"index\":0}],\"model\":\"text-embedding-ada-002\",\"usage\":{\"prompt_tokens\":8,\"total_tokens\":8}}";
+
+ try (MockedStatic<EntityUtils> entityUtils =
Mockito.mockStatic(EntityUtils.class)) {
+ CloseableHttpResponse response =
Mockito.mock(CloseableHttpResponse.class);
+ Mockito.when(response.getStatusLine())
+ .thenReturn(new BasicStatusLine(new
ProtocolVersion("HTTP", 1, 1), 200, "OK"));
+ Mockito.when(client.execute(Mockito.any())).thenReturn(response);
+ entityUtils
+ .when(() -> EntityUtils.toString(response.getEntity()))
+ .thenReturn(responseStr);
+
+ Assertions.assertEquals(dimension, model.dimension());
+ }
+ }
+
+ private List<Float> generateVector(int dimension) {
+ List<Float> vector = new ArrayList<>();
+ for (int i = 0; i < dimension; i++) {
+ vector.add(ThreadLocalRandom.current().nextFloat());
+ }
+ return vector;
+ }
+}