This is an automated email from the ASF dual-hosted git repository.
wanghailin pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git
The following commit(s) were added to refs/heads/dev by this push:
new 51ffc5a97e [Feature][Transform-v2] Add support for Zhipu AI in
Embedding and LLM module (#8790)
51ffc5a97e is described below
commit 51ffc5a97e4f7a75d1df47ecba9039a52d59abae
Author: xiaochen <[email protected]>
AuthorDate: Tue Feb 25 10:47:29 2025 +0800
[Feature][Transform-v2] Add support for Zhipu AI in Embedding and LLM
module (#8790)
---
docs/en/transform-v2/embedding.md | 29 ++++++-----
docs/en/transform-v2/llm.md | 2 +-
docs/zh/transform-v2/embedding.md | 29 ++++++-----
docs/zh/transform-v2/llm.md | 2 +-
seatunnel-transforms-v2/pom.xml | 4 +-
.../transform/nlpmodel/ModelProvider.java | 3 ++
.../transform/nlpmodel/ModelTransformConfig.java | 3 ++
.../nlpmodel/embedding/EmbeddingTransform.java | 13 +++++
.../embedding/EmbeddingTransformFactory.java | 4 ++
.../nlpmodel/embedding/remote/AbstractModel.java | 14 ++---
.../embedding/remote/custom/CustomModel.java | 6 +--
.../embedding/remote/doubao/DoubaoModel.java | 10 ++--
.../embedding/remote/openai/OpenAIModel.java | 10 ++--
.../embedding/remote/qianfan/QianfanModel.java | 10 ++--
.../DoubaoModel.java => zhipu/ZhipuModel.java} | 60 ++++++++++++++--------
.../transform/nlpmodel/llm/LLMTransform.java | 1 +
.../embedding/EmbeddingRequestJsonTest.java | 25 ++++++++-
17 files changed, 147 insertions(+), 78 deletions(-)
diff --git a/docs/en/transform-v2/embedding.md
b/docs/en/transform-v2/embedding.md
index 350a23fc55..cbebd535b1 100644
--- a/docs/en/transform-v2/embedding.md
+++ b/docs/en/transform-v2/embedding.md
@@ -10,20 +10,21 @@ different API endpoints.
## Options
-| Name | Type | Required | Default Value |
Description
|
-|--------------------------------|--------|----------|---------------|-------------------------------------------------------------------------------------------------------------|
-| model_provider | enum | yes | - | The
model provider for embedding. Options may include `QIANFAN`, `OPENAI`, etc.
|
-| api_key | string | yes | - | The API
key required to authenticate with the embedding service.
|
-| secret_key | string | yes | - | The
secret key required for additional authentication with the embedding service.
|
-| single_vectorized_input_number | int | no | 1 | The
number of inputs vectorized in one request. Default is 1.
|
-| vectorization_fields | map | yes | - | A
mapping between input fields and their corresponding output vector fields.
|
-| model | string | yes | - | The
specific model to use for embedding (e.g: `text-embedding-3-small` for OPENAI).
|
-| api_path | string | no | - | The API
endpoint for the embedding service. Typically provided by the model provider.
|
-| oauth_path | string | no | - | The API
endpoint for the oauth service.
|
-| custom_config | map | no | | Custom
configurations for the model.
|
-| custom_response_parse | string | no | |
Specifies how to parse the response from the model using JsonPath. Example:
`$.choices[*].message.content`. |
-| custom_request_headers | map | no | | Custom
headers for the request to the model.
|
-| custom_request_body | map | no | | Custom
body for the request. Supports placeholders like `${model}`, `${input}`.
|
+| Name | Type | Required | Default Value |
Description
|
+|----------------------------------|--------|----------|---------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| model_provider | enum | yes | - | The
model provider for embedding. Options may include `QIANFAN`, `OPENAI`, etc.
|
+| api_key | string | yes | - | The
API key required to authenticate with the embedding service.
|
+| secret_key | string | yes | - | The
secret key required for additional authentication with the embedding service.
|
+| single_vectorized_input_number | int | no | 1 | The
number of inputs vectorized in one request. Default is 1.
|
+| vectorization_fields | map | yes | - | A
mapping between input fields and their corresponding output vector fields.
|
+| model | string | yes | - | The
specific model to use for embedding (e.g. `text-embedding-3-small` for OPENAI).
|
+| api_path | string | no | - | The
API endpoint for the embedding service. Typically provided by the model
provider.
|
+| dimension                        | int    | no       | 2048          | The
vector dimension, used only when `model_provider` is `ZHIPU`. Defaults to 2048.
The Embedding-3 model supports custom vector dimensions; 256, 512, 1024, or
2048 are recommended. |
+| oauth_path | string | no | - | The
API endpoint for the oauth service.
|
+| custom_config | map | no | |
Custom configurations for the model.
|
+| custom_response_parse | string | no | |
Specifies how to parse the response from the model using JsonPath. Example:
`$.choices[*].message.content`.
|
+| custom_request_headers | map | no | |
Custom headers for the request to the model.
|
+| custom_request_body | map | no | |
Custom body for the request. Supports placeholders like `${model}`, `${input}`.
|
### model_provider
diff --git a/docs/en/transform-v2/llm.md b/docs/en/transform-v2/llm.md
index 680121cb4d..0bc137bded 100644
--- a/docs/en/transform-v2/llm.md
+++ b/docs/en/transform-v2/llm.md
@@ -28,7 +28,7 @@ more.
### model_provider
The model provider to use. The available options are:
-OPENAI, DOUBAO, DEEPSEEK, KIMIAI, MICROSOFT, CUSTOM
+OPENAI, DOUBAO, DEEPSEEK, KIMIAI, MICROSOFT, ZHIPU, CUSTOM
> tips: If you use Microsoft, please make sure api_path cannot be empty
diff --git a/docs/zh/transform-v2/embedding.md
b/docs/zh/transform-v2/embedding.md
index e05c9c2442..8ea2b68f70 100644
--- a/docs/zh/transform-v2/embedding.md
+++ b/docs/zh/transform-v2/embedding.md
@@ -8,20 +8,21 @@
## 配置选项
-| 名称 | 类型 | 是否必填 | 默认值 | 描述
|
-|--------------------------------|--------|------|-----|------------------------------------------------------------------|
-| model_provider | enum | 是 | - | embedding模型的提供商。可选项包括
`QIANFAN`、`OPENAI` 等。 |
-| api_key | string | 是 | - |
用于验证embedding服务的API密钥。 |
-| secret_key | string | 是 | - |
用于额外验证的密钥。一些提供商可能需要此密钥进行安全的API请求。 |
-| single_vectorized_input_number | int | 否 | 1 | 单次请求向量化的输入数量。默认值为1。
|
-| vectorization_fields | map | 是 | - | 输入字段和相应的输出向量字段之间的映射。
|
-| model | string | 是 | - |
要使用的具体embedding模型。例如,如果提供商为OPENAI,可以指定 `text-embedding-3-small`。 |
-| api_path | string | 否 | - |
embedding服务的API。通常由模型提供商提供。 |
-| oauth_path | string | 否 | - | oauth 服务的 API 。
|
-| custom_config | map | 否 | | 模型的自定义配置。
|
-| custom_response_parse | string | 否 | | 使用 JsonPath
解析模型响应的方式。示例:`$.choices[*].message.content`。 |
-| custom_request_headers | map | 否 | | 发送到模型的请求的自定义头信息。
|
-| custom_request_body | map | 否 | | 请求体的自定义配置。支持占位符如
`${model}`、`${input}`。 |
+| 名称 | 类型 | 是否必填 | 默认值 | 描述
|
+|----------------------------------|--------|------|--------|--------------------------------------------------------------------|
+| model_provider | enum | 是 | - |
embedding模型的提供商。可选项包括 `QIANFAN`、`OPENAI` 等。 |
+| api_key | string | 是 | - |
用于验证embedding服务的API密钥。 |
+| secret_key | string | 是 | - |
用于额外验证的密钥。一些提供商可能需要此密钥进行安全的API请求。 |
+| single_vectorized_input_number | int | 否 | 1 |
单次请求向量化的输入数量。默认值为1。 |
+| vectorization_fields | map | 是 | - |
输入字段和相应的输出向量字段之间的映射。 |
+| model | string | 是 | - |
要使用的具体embedding模型。例如,如果提供商为OPENAI,可以指定 `text-embedding-3-small`。 |
+| api_path | string | 否 | - |
embedding服务的API。通常由模型提供商提供。 |
+| dimension | int | 否 | 2048 | 向量维度默认为
2048,Embedding-3模型支持自定义向量维度,建议选择256、512、1024或2048维度。 |
+| oauth_path | string | 否 | - | oauth 服务的 API 。
|
+| custom_config | map | 否 | | 模型的自定义配置。
|
+| custom_response_parse | string | 否 | | 使用 JsonPath
解析模型响应的方式。示例:`$.choices[*].message.content`。 |
+| custom_request_headers | map | 否 | | 发送到模型的请求的自定义头信息。
|
+| custom_request_body | map | 否 | | 请求体的自定义配置。支持占位符如
`${model}`、`${input}`。 |
### embedding_model_provider
diff --git a/docs/zh/transform-v2/llm.md b/docs/zh/transform-v2/llm.md
index c1d05d59a3..c6cead3dfd 100644
--- a/docs/zh/transform-v2/llm.md
+++ b/docs/zh/transform-v2/llm.md
@@ -26,7 +26,7 @@
### model_provider
要使用的模型提供者。可用选项为:
-OPENAI,DOUBAO,DEEPSEEK,KIMIAI,MICROSOFT, CUSTOM
+OPENAI, DOUBAO, DEEPSEEK, KIMIAI, MICROSOFT, ZHIPU, CUSTOM
> tips: 如果使用 Microsoft, 请确保 api_path 配置不能为空
diff --git a/seatunnel-transforms-v2/pom.xml b/seatunnel-transforms-v2/pom.xml
index f15c1aae40..5f74ad156b 100644
--- a/seatunnel-transforms-v2/pom.xml
+++ b/seatunnel-transforms-v2/pom.xml
@@ -32,6 +32,8 @@
<properties>
<httpclient.version>4.5.13</httpclient.version>
<httpcore.version>4.4.4</httpcore.version>
+ <mockwebserver.version>3.6.0</mockwebserver.version>
+ <zhipu.version>release-V4-2.3.0</zhipu.version>
</properties>
<dependencyManagement>
@@ -95,7 +97,7 @@
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>mockwebserver</artifactId>
- <version>3.6.0</version>
+ <version>${mockwebserver.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
index f18ffdfc8e..aaeaee90ad 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
@@ -28,6 +28,9 @@ public enum ModelProvider {
KIMIAI("https://api.moonshot.cn/v1/chat/completions", ""),
DEEPSEEK("https://api.deepseek.com/chat/completions", ""),
MICROSOFT("", ""),
+ ZHIPU(
+ "https://open.bigmodel.cn/api/paas/v4/chat/completions",
+ "https://open.bigmodel.cn/api/paas/v4/embeddings"),
CUSTOM("", ""),
LOCAL("", "");
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelTransformConfig.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelTransformConfig.java
index b123459750..c3709c70b6 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelTransformConfig.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelTransformConfig.java
@@ -79,6 +79,9 @@ public class ModelTransformConfig implements Serializable {
.withFallbackKeys("inference_batch_size")
.withDescription("The row batch size of each process");
+ public static final Option<Integer> DIMENSION =
+
Options.key("dimension").intType().defaultValue(2048).withDescription("dimension");
+
public static class CustomRequestConfig {
// Custom response parsing
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java
index c699c6bfe8..6a8729a198 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java
@@ -33,6 +33,7 @@ import
org.apache.seatunnel.transform.nlpmodel.embedding.remote.custom.CustomMod
import
org.apache.seatunnel.transform.nlpmodel.embedding.remote.doubao.DoubaoModel;
import
org.apache.seatunnel.transform.nlpmodel.embedding.remote.openai.OpenAIModel;
import
org.apache.seatunnel.transform.nlpmodel.embedding.remote.qianfan.QianfanModel;
+import
org.apache.seatunnel.transform.nlpmodel.embedding.remote.zhipu.ZhipuModel;
import org.apache.seatunnel.transform.nlpmodel.llm.LLMTransformConfig;
import lombok.NonNull;
@@ -136,6 +137,18 @@ public class EmbeddingTransform extends
MultipleFieldOutputTransform {
EmbeddingTransformConfig
.SINGLE_VECTORIZED_INPUT_NUMBER));
break;
+ case ZHIPU:
+ model =
+ new ZhipuModel(
+ config.get(ModelTransformConfig.API_KEY),
+ config.get(ModelTransformConfig.MODEL),
+ provider.usedEmbeddingPath(
+
config.get(ModelTransformConfig.API_PATH)),
+ config.get(ModelTransformConfig.DIMENSION),
+ config.get(
+ EmbeddingTransformConfig
+
.SINGLE_VECTORIZED_INPUT_NUMBER));
+ break;
case LOCAL:
default:
throw new IllegalArgumentException("Unsupported model
provider: " + provider);
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransformFactory.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransformFactory.java
index 56e252e1bb..5f8e397e69 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransformFactory.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransformFactory.java
@@ -62,6 +62,10 @@ public class EmbeddingTransformFactory implements
TableTransformFactory {
LLMTransformConfig.MODEL_PROVIDER,
ModelProvider.CUSTOM,
LLMTransformConfig.CustomRequestConfig.CUSTOM_CONFIG)
+ .conditional(
+ EmbeddingTransformConfig.MODEL_PROVIDER,
+ ModelProvider.ZHIPU,
+ EmbeddingTransformConfig.DIMENSION)
.optional(TransformCommonOptions.MULTI_TABLES)
.optional(TransformCommonOptions.TABLE_MATCH_REGEX)
.build();
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java
index 0803dfd7ad..a53fa4684c 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java
@@ -42,23 +42,23 @@ public abstract class AbstractModel implements Model {
public List<ByteBuffer> vectorization(Object[] fields) throws IOException {
List<ByteBuffer> result = new ArrayList<>();
- List<List<Float>> vectors = batchProcess(fields,
singleVectorizedInputNumber);
- for (List<Float> vector : vectors) {
- result.add(BufferUtils.toByteBuffer(vector.toArray(new Float[0])));
+ List<List<Double>> vectors = batchProcess(fields,
singleVectorizedInputNumber);
+ for (List<Double> vector : vectors) {
+ result.add(BufferUtils.toByteBuffer(vector.toArray(new
Double[0])));
}
return result;
}
- protected abstract List<List<Float>> vector(Object[] fields) throws
IOException;
+ protected abstract List<List<Double>> vector(Object[] fields) throws
IOException;
- public List<List<Float>> batchProcess(Object[] array, int batchSize)
throws IOException {
- List<List<Float>> merged = new ArrayList<>();
+ public List<List<Double>> batchProcess(Object[] array, int batchSize)
throws IOException {
+ List<List<Double>> merged = new ArrayList<>();
if (array == null || array.length == 0) {
return merged;
}
for (int i = 0; i < array.length; i += batchSize) {
Object[] batch = ArrayUtils.subarray(array, i, i + batchSize);
- List<List<Float>> vector = vector(batch);
+ List<List<Double>> vector = vector(batch);
merged.addAll(vector);
}
if (array.length != merged.size()) {
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java
index 179315f956..ea39f15462 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java
@@ -67,7 +67,7 @@ public class CustomModel extends AbstractModel {
}
@Override
- protected List<List<Float>> vector(Object[] fields) throws IOException {
+ protected List<List<Double>> vector(Object[] fields) throws IOException {
return vectorGeneration(fields);
}
@@ -76,7 +76,7 @@ public class CustomModel extends AbstractModel {
return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size();
}
- private List<List<Float>> vectorGeneration(Object[] fields) throws
IOException {
+ private List<List<Double>> vectorGeneration(Object[] fields) throws
IOException {
HttpPost post = new HttpPost(apiPath);
// Construct a request with custom parameters
for (Map.Entry<String, String> entry : header.entrySet()) {
@@ -96,7 +96,7 @@ public class CustomModel extends AbstractModel {
}
return OBJECT_MAPPER.convertValue(
- parseResponse(responseStr), new
TypeReference<List<List<Float>>>() {});
+ parseResponse(responseStr), new
TypeReference<List<List<Double>>>() {});
}
@VisibleForTesting
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
index f2b1e348c7..1591cd587e 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
@@ -54,7 +54,7 @@ public class DoubaoModel extends AbstractModel {
}
@Override
- protected List<List<Float>> vector(Object[] fields) throws IOException {
+ protected List<List<Double>> vector(Object[] fields) throws IOException {
return vectorGeneration(fields);
}
@@ -63,7 +63,7 @@ public class DoubaoModel extends AbstractModel {
return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size();
}
- private List<List<Float>> vectorGeneration(Object[] fields) throws
IOException {
+ private List<List<Double>> vectorGeneration(Object[] fields) throws
IOException {
HttpPost post = new HttpPost(apiPath);
post.setHeader("Authorization", "Bearer " + apiKey);
post.setHeader("Content-Type", "application/json");
@@ -82,14 +82,14 @@ public class DoubaoModel extends AbstractModel {
}
JsonNode data = OBJECT_MAPPER.readTree(responseStr).get("data");
- List<List<Float>> embeddings = new ArrayList<>();
+ List<List<Double>> embeddings = new ArrayList<>();
if (data.isArray()) {
for (JsonNode node : data) {
JsonNode embeddingNode = node.get("embedding");
- List<Float> embedding =
+ List<Double> embedding =
OBJECT_MAPPER.readValue(
- embeddingNode.traverse(), new
TypeReference<List<Float>>() {});
+ embeddingNode.traverse(), new
TypeReference<List<Double>>() {});
embeddings.add(embedding);
}
}
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java
index 2a45cc829f..467d6cb406 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java
@@ -53,7 +53,7 @@ public class OpenAIModel extends AbstractModel {
}
@Override
- protected List<List<Float>> vector(Object[] fields) throws IOException {
+ protected List<List<Double>> vector(Object[] fields) throws IOException {
if (fields.length > 1) {
throw new IllegalArgumentException("OpenAI model only supports
single input");
}
@@ -65,7 +65,7 @@ public class OpenAIModel extends AbstractModel {
return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size();
}
- private List<List<Float>> vectorGeneration(Object[] fields) throws
IOException {
+ private List<List<Double>> vectorGeneration(Object[] fields) throws
IOException {
HttpPost post = new HttpPost(apiPath);
post.setHeader("Authorization", "Bearer " + apiKey);
post.setHeader("Content-Type", "application/json");
@@ -84,14 +84,14 @@ public class OpenAIModel extends AbstractModel {
}
JsonNode data = OBJECT_MAPPER.readTree(responseStr).get("data");
- List<List<Float>> embeddings = new ArrayList<>();
+ List<List<Double>> embeddings = new ArrayList<>();
if (data.isArray()) {
for (JsonNode node : data) {
JsonNode embeddingNode = node.get("embedding");
- List<Float> embedding =
+ List<Double> embedding =
OBJECT_MAPPER.readValue(
- embeddingNode.traverse(), new
TypeReference<List<Float>>() {});
+ embeddingNode.traverse(), new
TypeReference<List<Double>>() {});
embeddings.add(embedding);
}
}
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/qianfan/QianfanModel.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/qianfan/QianfanModel.java
index f85619eb3e..67c1a8147a 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/qianfan/QianfanModel.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/qianfan/QianfanModel.java
@@ -100,7 +100,7 @@ public class QianfanModel extends AbstractModel {
}
@Override
- public List<List<Float>> vector(Object[] fields) throws IOException {
+ public List<List<Double>> vector(Object[] fields) throws IOException {
return vectorGeneration(fields);
}
@@ -109,7 +109,7 @@ public class QianfanModel extends AbstractModel {
return vectorGeneration(new Object[]
{DIMENSION_EXAMPLE}).get(0).size();
}
- private List<List<Float>> vectorGeneration(Object[] fields) throws
IOException {
+ private List<List<Double>> vectorGeneration(Object[] fields) throws
IOException {
String formattedApiPath =
String.format(
(apiPath.endsWith("/") ? apiPath : apiPath + "/") +
"%s?access_token=%s",
@@ -143,14 +143,14 @@ public class QianfanModel extends AbstractModel {
"Failed to get vector from qianfan, response: " +
result.get("error_msg"));
}
- List<List<Float>> embeddings = new ArrayList<>();
+ List<List<Double>> embeddings = new ArrayList<>();
JsonNode data = result.get("data");
if (data.isArray()) {
for (JsonNode node : data) {
- List<Float> embedding =
+ List<Double> embedding =
OBJECT_MAPPER.readValue(
node.get("embedding").traverse(),
- new TypeReference<List<Float>>() {});
+ new TypeReference<List<Double>>() {});
embeddings.add(embedding);
}
}
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/zhipu/ZhipuModel.java
similarity index 68%
copy from
seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
copy to
seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/zhipu/ZhipuModel.java
index f2b1e348c7..df72261bb5 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/zhipu/ZhipuModel.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.seatunnel.transform.nlpmodel.embedding.remote.doubao;
+package org.apache.seatunnel.transform.nlpmodel.embedding.remote.zhipu;
import
org.apache.seatunnel.shade.com.fasterxml.jackson.core.type.TypeReference;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.JsonNode;
@@ -25,6 +25,7 @@ import
org.apache.seatunnel.shade.com.google.common.annotations.VisibleForTestin
import org.apache.seatunnel.transform.nlpmodel.embedding.remote.AbstractModel;
+import org.apache.http.HttpHeaders;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
@@ -34,62 +35,77 @@ import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import java.io.IOException;
+import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
-public class DoubaoModel extends AbstractModel {
+/** Zhipu model. Refer <a
href="https://bigmodel.cn/dev/api/vector/embedding">embedding api </a> */
+public class ZhipuModel extends AbstractModel {
private final CloseableHttpClient client;
- private final String apiKey;
private final String model;
- private final String apiPath;
-
- public DoubaoModel(String apiKey, String model, String apiPath, Integer
vectorizedNumber) {
+ private final String apiKey;
+ private final String apiPath;;
+ private final Integer dimension;
+ private final Integer MAX_INPUT_SIZE = 64;
+
+ public ZhipuModel(
+ String apiKey,
+ String model,
+ String apiPath,
+ Integer dimension,
+ Integer vectorizedNumber)
+ throws IOException {
super(vectorizedNumber);
- this.apiKey = apiKey;
this.model = model;
+ this.apiKey = apiKey;
this.apiPath = apiPath;
+ this.dimension = dimension;
this.client = HttpClients.createDefault();
}
@Override
- protected List<List<Float>> vector(Object[] fields) throws IOException {
+ public List<List<Double>> vector(Object[] fields) throws IOException {
return vectorGeneration(fields);
}
@Override
public Integer dimension() throws IOException {
- return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size();
+ return dimension;
}
- private List<List<Float>> vectorGeneration(Object[] fields) throws
IOException {
+ private List<List<Double>> vectorGeneration(Object[] fields) throws
IOException {
+
+ if (fields == null || fields.length > MAX_INPUT_SIZE) {
+ throw new IOException(
+ "Zhipu input text for vectorization, with a maximum limit
of 64 entries.");
+ }
HttpPost post = new HttpPost(apiPath);
- post.setHeader("Authorization", "Bearer " + apiKey);
- post.setHeader("Content-Type", "application/json");
+ post.setHeader(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey);
+ post.setHeader(HttpHeaders.CONTENT_TYPE, "application/json");
post.setConfig(
RequestConfig.custom().setConnectTimeout(20000).setSocketTimeout(20000).build());
post.setEntity(
new StringEntity(
-
OBJECT_MAPPER.writeValueAsString(createJsonNodeFromData(fields)), "UTF-8"));
+
OBJECT_MAPPER.writeValueAsString(createJsonNodeFromData(fields)),
+ StandardCharsets.UTF_8.name()));
CloseableHttpResponse response = client.execute(post);
String responseStr = EntityUtils.toString(response.getEntity());
-
if (response.getStatusLine().getStatusCode() != 200) {
- throw new IOException("Failed to get vector from doubao, response:
" + responseStr);
+ throw new IOException("Failed to get vector from zhipu, response:
" + responseStr);
}
-
JsonNode data = OBJECT_MAPPER.readTree(responseStr).get("data");
- List<List<Float>> embeddings = new ArrayList<>();
+ List<List<Double>> embeddings = new ArrayList<>();
if (data.isArray()) {
for (JsonNode node : data) {
JsonNode embeddingNode = node.get("embedding");
- List<Float> embedding =
+ List<Double> embedding =
OBJECT_MAPPER.readValue(
- embeddingNode.traverse(), new
TypeReference<List<Float>>() {});
+ embeddingNode.traverse(), new
TypeReference<List<Double>>() {});
embeddings.add(embedding);
}
}
@@ -99,7 +115,11 @@ public class DoubaoModel extends AbstractModel {
@VisibleForTesting
public ObjectNode createJsonNodeFromData(Object[] fields) {
ArrayNode arrayNode = OBJECT_MAPPER.valueToTree(Arrays.asList(fields));
- return OBJECT_MAPPER.createObjectNode().put("model",
model).set("input", arrayNode);
+ return OBJECT_MAPPER
+ .createObjectNode()
+ .put("model", model)
+ .put("dimensions", dimension)
+ .set("input", arrayNode);
}
@Override
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
index 048e1bffce..8160cdc647 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
@@ -109,6 +109,7 @@ public class LLMTransform extends
SingleFieldOutputTransform {
case DEEPSEEK:
case OPENAI:
case DOUBAO:
+ case ZHIPU:
model =
new OpenAIModel(
inputCatalogTable.getSeaTunnelRowType(),
diff --git
a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingRequestJsonTest.java
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingRequestJsonTest.java
index dc43cdeb23..46c893d182 100644
---
a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingRequestJsonTest.java
+++
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingRequestJsonTest.java
@@ -26,6 +26,7 @@ import
org.apache.seatunnel.transform.nlpmodel.embedding.remote.custom.CustomMod
import
org.apache.seatunnel.transform.nlpmodel.embedding.remote.doubao.DoubaoModel;
import
org.apache.seatunnel.transform.nlpmodel.embedding.remote.openai.OpenAIModel;
import
org.apache.seatunnel.transform.nlpmodel.embedding.remote.qianfan.QianfanModel;
+import
org.apache.seatunnel.transform.nlpmodel.embedding.remote.zhipu.ZhipuModel;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
@@ -93,6 +94,26 @@ public class EmbeddingRequestJsonTest {
model.close();
}
+ @Test
+ void testZhipuRequestJson() throws IOException {
+ ZhipuModel model =
+ new ZhipuModel(
+ "apikey",
+ "modelName",
+ "https://open.bigmodel.cn/api/paas/v4/embeddings",
+ 64,
+ 1);
+ ObjectNode node =
+ model.createJsonNodeFromData(
+ new Object[] {
+ "Determine whether someone is Chinese or American
by their name"
+ });
+ Assertions.assertEquals(
+
"{\"model\":\"modelName\",\"dimensions\":64,\"input\":[\"Determine whether
someone is Chinese or American by their name\"]}",
+ OBJECT_MAPPER.writeValueAsString(node));
+ model.close();
+ }
+
@Test
void testCustomRequestJson() throws IOException {
Map<String, String> header = new HashMap<>();
@@ -131,11 +152,11 @@ public class EmbeddingRequestJsonTest {
new HashMap<>(),
"$.data[*].embedding",
1);
- List<List<Float>> lists =
+ List<List<Double>> lists =
OBJECT_MAPPER.convertValue(
customModel.parseResponse(
"{\"created\":1725001256,\"id\":\"02172500125677376580aba8475a41c550bbf05104842f0405ef5\",\"data\":[{\"embedding\":[-1.625,0.07958984375,-1.5703125,-3.03125,-1.4609375,3.46875,-0.73046875,-2.578125,-0.66796875,1.71875,0.361328125,2,5.125,2.25,4.6875,1.4921875,-0.77734375,-0.466796875,0.0439453125,-2.46875,3.59375,4.96875,2.34375,-5.34375,0.11083984375,-5.875,3.0625,4.09375,3.4375,0.2265625,9,-1.9296875,2.25,0.765625,3.671875,-2.484375,-1.171875,-1.6171875,
[...]
- new TypeReference<List<List<Float>>>() {});
+ new TypeReference<List<List<Double>>>() {});
Assertions.assertEquals(2, lists.size());
Assertions.assertEquals(2560, lists.get(0).size());
}