This is an automated email from the ASF dual-hosted git repository.
wanghailin pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git
The following commit(s) were added to refs/heads/dev by this push:
new ecd317368d [Feature][Transforms-V2] LLM transforms Support KimiAI
(#7630)
ecd317368d is described below
commit ecd317368d1357d35d0fbac7fe3be66da6674127
Author: zhangdonghao <[email protected]>
AuthorDate: Fri Sep 13 18:51:23 2024 +0800
[Feature][Transforms-V2] LLM transforms Support KimiAI (#7630)
---
docs/en/transform-v2/llm.md | 55 ++++++++++-
docs/zh/transform-v2/llm.md | 58 +++++++++++-
.../apache/seatunnel/e2e/transform/TestLLMIT.java | 8 ++
.../src/test/resources/llm_kimiai_transform.conf | 78 ++++++++++++++++
.../src/test/resources/mockserver-config.json | 32 +++++++
.../transform/nlpmodel/ModelProvider.java | 1 +
.../transform/nlpmodel/llm/LLMTransform.java | 11 +++
.../nlpmodel/llm/remote/kimiai/KimiAIModel.java | 103 +++++++++++++++++++++
.../transform/llm/LLMRequestJsonTest.java | 26 ++++++
9 files changed, 364 insertions(+), 8 deletions(-)
diff --git a/docs/en/transform-v2/llm.md b/docs/en/transform-v2/llm.md
index d1b8e6fc6e..6d036064de 100644
--- a/docs/en/transform-v2/llm.md
+++ b/docs/en/transform-v2/llm.md
@@ -27,7 +27,7 @@ more.
### model_provider
The model provider to use. The available options are:
-OPENAI、DOUBAO、CUSTOM
+OPENAI, DOUBAO, KIMIAI, CUSTOM
### output_data_type
@@ -155,7 +155,11 @@ The `custom_request_body` option supports placeholders:
Transform plugin common parameters, please refer to [Transform
Plugin](common-options.md) for details
-## Example
+## Tips
+The API interface usually has a rate limit, which can be configured with
SeaTunnel's speed limit to ensure smooth operation of the task.
+For details about SeaTunnel speed limit settings, please refer to
[speed-limit](../concept/speed-limit.md).
+
+## Example OPENAI
Determine the user's country through a LLM.
@@ -163,6 +167,7 @@ Determine the user's country through a LLM.
env {
parallelism = 1
job.mode = "BATCH"
+ read_limit.rows_per_second = 10
}
source {
@@ -199,6 +204,51 @@ sink {
}
```
+## Example KIMIAI
+
+Determine whether a person is a historical emperor of China.
+
+```hocon
+env {
+ parallelism = 1
+ job.mode = "BATCH"
+ read_limit.rows_per_second = 10
+}
+
+source {
+ FakeSource {
+ row.num = 5
+ schema = {
+ fields {
+ id = "int"
+ name = "string"
+ }
+ }
+ rows = [
+ {fields = [1, "Zhuge Liang"], kind = INSERT}
+ {fields = [2, "Li Shimin"], kind = INSERT}
+ {fields = [3, "Sun Wukong"], kind = INSERT}
+      {fields = [4, "Zhu Yuanzhang"], kind = INSERT}
+ {fields = [5, "George Washington"], kind = INSERT}
+ ]
+ }
+}
+
+transform {
+ LLM {
+ model_provider = KIMIAI
+ model = moonshot-v1-8k
+ api_key = sk-xxx
+ prompt = "Determine whether a person is a historical emperor of China"
+ output_data_type = boolean
+ }
+}
+
+sink {
+ console {
+ }
+}
+```
### Customize the LLM model
```hocon
@@ -277,4 +327,3 @@ sink {
}
}
```
-
diff --git a/docs/zh/transform-v2/llm.md b/docs/zh/transform-v2/llm.md
index c2d3c0f6ca..3ce53b78a6 100644
--- a/docs/zh/transform-v2/llm.md
+++ b/docs/zh/transform-v2/llm.md
@@ -25,7 +25,7 @@
### model_provider
要使用的模型提供者。可用选项为:
-OPENAI、DOUBAO、CUSTOM
+OPENAI、DOUBAO、KIMIAI、CUSTOM
### output_data_type
@@ -81,8 +81,7 @@ transform {
### model
要使用的模型。不同的模型提供者有不同的模型。例如,OpenAI 模型可以是 `gpt-4o-mini`。
-如果使用 OpenAI 模型,请参考
https://platform.openai.com/docs/models/model-endpoint-compatibility
-文档的`/v1/chat/completions` 端点。
+如果使用 OpenAI 模型,请参考
https://platform.openai.com/docs/models/model-endpoint-compatibility
文档的`/v1/chat/completions` 端点。
### api_key
@@ -148,7 +147,11 @@ transform {
转换插件的常见参数, 请参考 [Transform Plugin](common-options.md) 了解详情
-## 示例
+## Tips
+大模型API接口通常会有速率限制,可以配合SeaTunnel的限速配置,以确保任务顺利运行。
+SeaTunnel限速配置,请参考[speed-limit](../concept/speed-limit.md)了解详情
+
+## 示例 OPENAI
通过 LLM 确定用户所在的国家。
@@ -156,6 +159,7 @@ transform {
env {
parallelism = 1
job.mode = "BATCH"
+ read_limit.rows_per_second = 10
}
source {
@@ -192,6 +196,51 @@ sink {
}
```
+## 示例 KIMIAI
+
+通过 LLM 判断人名是否中国历史上的帝王
+
+```hocon
+env {
+ parallelism = 1
+ job.mode = "BATCH"
+ read_limit.rows_per_second = 10
+}
+
+source {
+ FakeSource {
+ row.num = 5
+ schema = {
+ fields {
+ id = "int"
+ name = "string"
+ }
+ }
+ rows = [
+ {fields = [1, "诸葛亮"], kind = INSERT}
+ {fields = [2, "李世民"], kind = INSERT}
+ {fields = [3, "孙悟空"], kind = INSERT}
+ {fields = [4, "朱元璋"], kind = INSERT}
+ {fields = [5, "乔治·华盛顿"], kind = INSERT}
+ ]
+ }
+}
+
+transform {
+ LLM {
+ model_provider = KIMIAI
+ model = moonshot-v1-8k
+ api_key = sk-xxx
+ prompt = "判断是否是中国历史上的帝王"
+ output_data_type = boolean
+ }
+}
+
+sink {
+ console {
+ }
+}
+```
### Customize the LLM model
```hocon
@@ -270,4 +319,3 @@ sink {
}
}
```
-
diff --git
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
index 244bca1e9c..b97d7182e1 100644
---
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
+++
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
@@ -93,6 +93,7 @@ public class TestLLMIT extends TestSuiteBase implements
TestResource {
throws IOException, InterruptedException {
Container.ExecResult execResult =
container.executeJob("/llm_openai_transform_boolean.conf");
+ Assertions.assertEquals(0, execResult.getExitCode());
}
@TestTemplate
@@ -109,4 +110,11 @@ public class TestLLMIT extends TestSuiteBase implements
TestResource {
Container.ExecResult execResult =
container.executeJob("/llm_transform_custom.conf");
Assertions.assertEquals(0, execResult.getExitCode());
}
+
+ @TestTemplate
+ public void testLLMWithKimiAI(TestContainer container)
+ throws IOException, InterruptedException {
+ Container.ExecResult execResult =
container.executeJob("/llm_kimiai_transform.conf");
+ Assertions.assertEquals(0, execResult.getExitCode());
+ }
}
diff --git
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_kimiai_transform.conf
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_kimiai_transform.conf
new file mode 100644
index 0000000000..6833257e2b
--- /dev/null
+++
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_kimiai_transform.conf
@@ -0,0 +1,78 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+######
+###### This config file is a demonstration of batch processing in
seatunnel config
+######
+
+env {
+ parallelism = 1
+ job.mode = "BATCH"
+ read_limit.rows_per_second = 1
+}
+
+source {
+ FakeSource {
+ row.num = 5
+ schema = {
+ fields {
+ id = "int"
+ name = "string"
+ }
+ }
+ rows = [
+ {fields = [1, "Zhuge Liang"], kind = INSERT}
+ {fields = [2, "Li Shimin"], kind = INSERT}
+ {fields = [3, "Sun Wukong"], kind = INSERT}
+      {fields = [4, "Zhu Yuanzhang"], kind = INSERT}
+ {fields = [5, "George Washington"], kind = INSERT}
+ ]
+ result_table_name = "fake"
+ }
+}
+
+transform {
+ LLM {
+ source_table_name = "fake"
+ model_provider = KIMIAI
+ model = moonshot-v1-8k
+ api_key = sk-xxx
+ prompt = "Determine whether a person is a historical emperor of China"
+ api_path = "http://mockserver:1080/v3/chat/completions"
+ output_data_type = boolean
+ result_table_name = "llm_output"
+ }
+}
+
+sink {
+ Assert {
+ source_table_name = "llm_output"
+ rules =
+ {
+ field_rules = [
+ {
+ field_name = llm_output
+ field_type = boolean
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ }
+ ]
+ }
+ }
+}
diff --git
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json
index f7674d3a2a..44dd94396e 100644
---
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json
+++
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json
@@ -72,5 +72,37 @@
"Content-Type": "application/json"
}
}
+ },
+ {
+ "httpRequest": {
+ "method": "POST",
+ "path": "/v3/chat/completions"
+ },
+ "httpResponse": {
+ "body": {
+ "id": "chatcmpl-66e0291f428f9d4703bf4edc",
+ "object": "chat.completion",
+ "created": 1725966623,
+ "model": "moonshot-v1-8k",
+ "choices": [
+ {
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": "[False]"
+ },
+ "finish_reason": "stop"
+ }
+ ],
+ "usage": {
+ "prompt_tokens": 113,
+ "completion_tokens": 10,
+ "total_tokens": 123
+ }
+ },
+ "headers": {
+ "Content-Type": "application/json"
+ }
+ }
}
]
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
index c14877816f..ce22bc5a6d 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
@@ -25,6 +25,7 @@ public enum ModelProvider {
"https://ark.cn-beijing.volces.com/api/v3/chat/completions",
"https://ark.cn-beijing.volces.com/api/v3/embeddings"),
QIANFAN("",
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings"),
+ KIMIAI("https://api.moonshot.cn/v1/chat/completions", ""),
CUSTOM("", ""),
LOCAL("", "");
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
index 705253a2fe..a29fd677ca 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
@@ -30,6 +30,7 @@ import org.apache.seatunnel.transform.nlpmodel.ModelProvider;
import org.apache.seatunnel.transform.nlpmodel.ModelTransformConfig;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.Model;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.custom.CustomModel;
+import org.apache.seatunnel.transform.nlpmodel.llm.remote.kimiai.KimiAIModel;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.openai.OpenAIModel;
import lombok.NonNull;
@@ -103,6 +104,16 @@ public class LLMTransform extends
SingleFieldOutputTransform {
config.get(LLMTransformConfig.MODEL),
config.get(LLMTransformConfig.API_KEY),
provider.usedLLMPath(config.get(LLMTransformConfig.API_PATH)));
+ case KIMIAI:
+ model =
+ new KimiAIModel(
+ inputCatalogTable.getSeaTunnelRowType(),
+ outputDataType.getSqlType(),
+
config.get(LLMTransformConfig.INFERENCE_COLUMNS),
+ config.get(LLMTransformConfig.PROMPT),
+ config.get(LLMTransformConfig.MODEL),
+ config.get(LLMTransformConfig.API_KEY),
+
provider.usedLLMPath(config.get(LLMTransformConfig.API_PATH)));
break;
case QIANFAN:
default:
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/kimiai/KimiAIModel.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/kimiai/KimiAIModel.java
new file mode 100644
index 0000000000..803f646078
--- /dev/null
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/kimiai/KimiAIModel.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.nlpmodel.llm.remote.kimiai;
+
+import
org.apache.seatunnel.shade.com.fasterxml.jackson.core.type.TypeReference;
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.JsonNode;
+import
org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ArrayNode;
+import
org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode;
+
+import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
+import org.apache.seatunnel.api.table.type.SqlType;
+import org.apache.seatunnel.transform.nlpmodel.llm.remote.AbstractModel;
+
+import org.apache.http.client.config.RequestConfig;
+import org.apache.http.client.methods.CloseableHttpResponse;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.entity.StringEntity;
+import org.apache.http.impl.client.CloseableHttpClient;
+import org.apache.http.impl.client.HttpClients;
+import org.apache.http.util.EntityUtils;
+
+import com.google.common.annotations.VisibleForTesting;
+import lombok.extern.slf4j.Slf4j;
+
+import java.io.IOException;
+import java.util.List;
+
+@Slf4j
+public class KimiAIModel extends AbstractModel {
+
+ private final CloseableHttpClient client;
+ private final String apiKey;
+ private final String model;
+ private final String apiPath;
+
+ public KimiAIModel(
+ SeaTunnelRowType rowType,
+ SqlType outputType,
+ List<String> projectionColumns,
+ String prompt,
+ String model,
+ String apiKey,
+ String apiPath) {
+ super(rowType, outputType, projectionColumns, prompt);
+ this.apiKey = apiKey;
+ this.apiPath = apiPath;
+ this.model = model;
+ this.client = HttpClients.createDefault();
+ }
+
+ @Override
+ protected List<String> chatWithModel(String prompt, String data) throws
IOException {
+ HttpPost post = new HttpPost(apiPath);
+ post.setHeader("Authorization", "Bearer " + apiKey);
+ post.setHeader("Content-Type", "application/json");
+ ObjectNode objectNode = createJsonNodeFromData(prompt, data);
+ post.setEntity(new
StringEntity(OBJECT_MAPPER.writeValueAsString(objectNode), "UTF-8"));
+ post.setConfig(
+
RequestConfig.custom().setConnectTimeout(20000).setSocketTimeout(20000).build());
+ CloseableHttpResponse response = client.execute(post);
+ String responseStr = EntityUtils.toString(response.getEntity());
+ if (response.getStatusLine().getStatusCode() != 200) {
+ throw new IOException("Failed to chat with model, response: " +
responseStr);
+ }
+
+ JsonNode result = OBJECT_MAPPER.readTree(responseStr);
+ String resultData =
result.get("choices").get(0).get("message").get("content").asText();
+ return OBJECT_MAPPER.readValue(
+ convertData(resultData), new TypeReference<List<String>>() {});
+ }
+
+ @VisibleForTesting
+ public ObjectNode createJsonNodeFromData(String prompt, String data) {
+ ObjectNode objectNode = OBJECT_MAPPER.createObjectNode();
+ objectNode.put("model", model);
+ ArrayNode messages = objectNode.putArray("messages");
+ messages.addObject().put("role", "system").put("content", prompt);
+ messages.addObject().put("role", "user").put("content", data);
+ return objectNode;
+ }
+
+ @Override
+ public void close() throws IOException {
+ if (client != null) {
+ client.close();
+ }
+ }
+}
diff --git
a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
index 97eb1b8f96..91666c4139 100644
---
a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
+++
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
@@ -27,6 +27,7 @@ import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.api.table.type.SqlType;
import org.apache.seatunnel.format.json.RowToJsonConverters;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.custom.CustomModel;
+import org.apache.seatunnel.transform.nlpmodel.llm.remote.kimiai.KimiAIModel;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.openai.OpenAIModel;
import org.junit.jupiter.api.Assertions;
@@ -104,6 +105,31 @@ public class LLMRequestJsonTest {
model.close();
}
+ @Test
+ void testKimiAIRequestJson() throws IOException {
+ SeaTunnelRowType rowType =
+ new SeaTunnelRowType(
+ new String[] {"id", "name"},
+ new SeaTunnelDataType[] {BasicType.INT_TYPE,
BasicType.STRING_TYPE});
+ KimiAIModel model =
+ new KimiAIModel(
+ rowType,
+ SqlType.STRING,
+ null,
+ "Determine whether someone is Chinese or American by
their name",
+ "moonshot-v1-8k",
+ "sk-xxx",
+ "https://api.moonshot.cn/v1/chat/completions");
+ ObjectNode node =
+ model.createJsonNodeFromData(
+ "Determine whether someone is Chinese or American by
their name",
+ "{\"id\":1, \"name\":\"John\"}");
+ Assertions.assertEquals(
+
"{\"model\":\"moonshot-v1-8k\",\"messages\":[{\"role\":\"system\",\"content\":\"Determine
whether someone is Chinese or American by their
name\"},{\"role\":\"user\",\"content\":\"{\\\"id\\\":1,
\\\"name\\\":\\\"John\\\"}\"}]}",
+ OBJECT_MAPPER.writeValueAsString(node));
+ model.close();
+ }
+
@Test
void testCustomRequestJson() throws IOException {
SeaTunnelRowType rowType =