This is an automated email from the ASF dual-hosted git repository.
liugddx pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git
The following commit(s) were added to refs/heads/dev by this push:
new 5b5ee84130 [Improve][Transform] Add LLM model provider microsoft
(#7778)
5b5ee84130 is described below
commit 5b5ee8413052730cf5f361b701465b6f4fbfb219
Author: corgy-w <[email protected]>
AuthorDate: Tue Oct 8 09:35:23 2024 +0800
[Improve][Transform] Add LLM model provider microsoft (#7778)
Co-authored-by: Jia Fan <[email protected]>
---
docs/en/transform-v2/llm.md | 7 +-
docs/zh/transform-v2/llm.md | 4 +-
.../apache/seatunnel/e2e/transform/TestLLMIT.java | 7 ++
.../test/resources/llm_microsoft_transform.conf | 75 +++++++++++++++
.../src/test/resources/mockserver-config.json | 32 +++++++
.../transform/nlpmodel/ModelProvider.java | 1 +
.../transform/nlpmodel/llm/LLMTransform.java | 12 +++
.../nlpmodel/llm/LLMTransformFactory.java | 8 +-
.../llm/remote/microsoft/MicrosoftModel.java | 103 +++++++++++++++++++++
.../transform/llm/LLMRequestJsonTest.java | 34 +++++++
10 files changed, 277 insertions(+), 6 deletions(-)
diff --git a/docs/en/transform-v2/llm.md b/docs/en/transform-v2/llm.md
index 8ee5a36a9a..81dc9b3c70 100644
--- a/docs/en/transform-v2/llm.md
+++ b/docs/en/transform-v2/llm.md
@@ -11,7 +11,7 @@ more.
## Options
| name | type | required | default value |
-|------------------------| ------ | -------- |---------------|
+|------------------------|--------|----------|---------------|
| model_provider | enum | yes | |
| output_data_type | enum | no | String |
| output_column_name | string | no | llm_output |
@@ -28,7 +28,9 @@ more.
### model_provider
The model provider to use. The available options are:
-OPENAI, DOUBAO, KIMIAI, CUSTOM
+OPENAI, DOUBAO, KIMIAI, MICROSOFT, CUSTOM
+
+> Tips: If you use Microsoft, please make sure that api_path is not empty.
### output_data_type
@@ -254,6 +256,7 @@ sink {
}
}
```
+
### Customize the LLM model
```hocon
diff --git a/docs/zh/transform-v2/llm.md b/docs/zh/transform-v2/llm.md
index c6f7aeefea..5ab37f5870 100644
--- a/docs/zh/transform-v2/llm.md
+++ b/docs/zh/transform-v2/llm.md
@@ -26,7 +26,9 @@
### model_provider
要使用的模型提供者。可用选项为:
-OPENAI、DOUBAO、KIMIAI、CUSTOM
+OPENAI、DOUBAO、KIMIAI、MICROSOFT、CUSTOM
+
+> Tips: 如果使用 Microsoft，请确保 api_path 配置不为空
### output_data_type
diff --git
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
index d98a5e7e33..f739e7af96 100644
---
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
+++
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
@@ -88,6 +88,13 @@ public class TestLLMIT extends TestSuiteBase implements
TestResource {
Assertions.assertEquals(0, execResult.getExitCode());
}
+ @TestTemplate
+ public void testLLMWithMicrosoft(TestContainer container)
+ throws IOException, InterruptedException {
+ Container.ExecResult execResult =
container.executeJob("/llm_microsoft_transform.conf");
+ Assertions.assertEquals(0, execResult.getExitCode());
+ }
+
@TestTemplate
public void testLLMWithOpenAIBoolean(TestContainer container)
throws IOException, InterruptedException {
diff --git
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_microsoft_transform.conf
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_microsoft_transform.conf
new file mode 100644
index 0000000000..37205a3aca
--- /dev/null
+++
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_microsoft_transform.conf
@@ -0,0 +1,75 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+######
+###### This config file is a demonstration of streaming processing in
seatunnel config
+######
+
+env {
+ job.mode = "BATCH"
+}
+
+source {
+ FakeSource {
+ row.num = 5
+ schema = {
+ fields {
+ id = "int"
+ name = "string"
+ }
+ }
+ rows = [
+ {fields = [1, "Jia Fan"], kind = INSERT}
+ {fields = [2, "Hailin Wang"], kind = INSERT}
+ {fields = [3, "Tomas"], kind = INSERT}
+ {fields = [4, "Eric"], kind = INSERT}
+ {fields = [5, "Guangdong Liu"], kind = INSERT}
+ ]
+ result_table_name = "fake"
+ }
+}
+
+transform {
+ LLM {
+ source_table_name = "fake"
+ model_provider = MICROSOFT
+ model = gpt-35-turbo
+ api_key = sk-xxx
+ prompt = "Determine whether someone is Chinese or American by their name"
+ api_path =
"http://mockserver:1080/openai/deployments/${model}/chat/completions?api-version=2024-02-01"
+ result_table_name = "llm_output"
+ }
+}
+
+sink {
+ Assert {
+ source_table_name = "llm_output"
+ rules =
+ {
+ field_rules = [
+ {
+ field_name = llm_output
+ field_type = string
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ }
+ ]
+ }
+ }
+}
\ No newline at end of file
diff --git
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json
index 44dd94396e..ffdb409c9c 100644
---
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json
+++
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json
@@ -104,5 +104,37 @@
"Content-Type": "application/json"
}
}
+ },
+ {
+ "httpRequest": {
+ "method": "POST",
+ "path": "/openai/deployments/gpt-35-turbo/chat/.*"
+ },
+ "httpResponse": {
+ "body": {
+ "id": "chatcmpl-6v7mkQj980V1yBec6ETrKPRqFjNw9",
+ "object": "chat.completion",
+ "created": 1679072642,
+ "model": "gpt-35-turbo",
+ "usage": {
+ "prompt_tokens": 58,
+ "completion_tokens": 68,
+ "total_tokens": 126
+ },
+ "choices": [
+ {
+ "message": {
+ "role": "assistant",
+ "content": "[\"Chinese\"]"
+ },
+ "finish_reason": "stop",
+ "index": 0
+ }
+ ]
+ },
+ "headers": {
+ "Content-Type": "application/json"
+ }
+ }
}
]
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
index ce22bc5a6d..3172137706 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
@@ -26,6 +26,7 @@ public enum ModelProvider {
"https://ark.cn-beijing.volces.com/api/v3/embeddings"),
QIANFAN("",
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings"),
KIMIAI("https://api.moonshot.cn/v1/chat/completions", ""),
+ MICROSOFT("", ""),
CUSTOM("", ""),
LOCAL("", "");
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
index 08ae42e443..069945951b 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
@@ -31,6 +31,7 @@ import
org.apache.seatunnel.transform.nlpmodel.ModelTransformConfig;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.Model;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.custom.CustomModel;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.kimiai.KimiAIModel;
+import
org.apache.seatunnel.transform.nlpmodel.llm.remote.microsoft.MicrosoftModel;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.openai.OpenAIModel;
import lombok.NonNull;
@@ -94,6 +95,17 @@ public class LLMTransform extends SingleFieldOutputTransform
{
LLMTransformConfig.CustomRequestConfig
.CUSTOM_RESPONSE_PARSE));
break;
+ case MICROSOFT:
+ model =
+ new MicrosoftModel(
+ inputCatalogTable.getSeaTunnelRowType(),
+ outputDataType.getSqlType(),
+
config.get(LLMTransformConfig.INFERENCE_COLUMNS),
+ config.get(LLMTransformConfig.PROMPT),
+ config.get(LLMTransformConfig.MODEL),
+ config.get(LLMTransformConfig.API_KEY),
+
provider.usedLLMPath(config.get(LLMTransformConfig.API_PATH)));
+ break;
case OPENAI:
case DOUBAO:
model =
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformFactory.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformFactory.java
index eda57e1275..834c0b4d17 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformFactory.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformFactory.java
@@ -26,7 +26,6 @@ import org.apache.seatunnel.api.table.factory.Factory;
import org.apache.seatunnel.api.table.factory.TableTransformFactory;
import org.apache.seatunnel.api.table.factory.TableTransformFactoryContext;
import org.apache.seatunnel.transform.nlpmodel.ModelProvider;
-import org.apache.seatunnel.transform.nlpmodel.ModelTransformConfig;
import com.google.auto.service.AutoService;
@@ -50,14 +49,17 @@ public class LLMTransformFactory implements
TableTransformFactory {
LLMTransformConfig.PROCESS_BATCH_SIZE)
.conditional(
LLMTransformConfig.MODEL_PROVIDER,
- Lists.newArrayList(ModelProvider.OPENAI,
ModelProvider.DOUBAO),
+ Lists.newArrayList(
+ ModelProvider.OPENAI,
+ ModelProvider.DOUBAO,
+ ModelProvider.MICROSOFT),
LLMTransformConfig.API_KEY)
.conditional(
LLMTransformConfig.MODEL_PROVIDER,
ModelProvider.QIANFAN,
LLMTransformConfig.API_KEY,
LLMTransformConfig.SECRET_KEY,
- ModelTransformConfig.OAUTH_PATH)
+ LLMTransformConfig.OAUTH_PATH)
.conditional(
LLMTransformConfig.MODEL_PROVIDER,
ModelProvider.CUSTOM,
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/microsoft/MicrosoftModel.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/microsoft/MicrosoftModel.java
new file mode 100644
index 0000000000..b6362c41a3
--- /dev/null
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/microsoft/MicrosoftModel.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.nlpmodel.llm.remote.microsoft;
+
+import
org.apache.seatunnel.shade.com.fasterxml.jackson.core.type.TypeReference;
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.JsonNode;
+import
org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ArrayNode;
+import
org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode;
+
+import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
+import org.apache.seatunnel.api.table.type.SqlType;
+import org.apache.seatunnel.transform.nlpmodel.CustomConfigPlaceholder;
+import org.apache.seatunnel.transform.nlpmodel.llm.remote.AbstractModel;
+
+import org.apache.http.client.config.RequestConfig;
+import org.apache.http.client.methods.CloseableHttpResponse;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.entity.StringEntity;
+import org.apache.http.impl.client.CloseableHttpClient;
+import org.apache.http.impl.client.HttpClients;
+import org.apache.http.util.EntityUtils;
+
+import com.google.common.annotations.VisibleForTesting;
+
+import java.io.IOException;
+import java.util.List;
+
+public class MicrosoftModel extends AbstractModel {
+
+ private final CloseableHttpClient client;
+ private final String apiKey;
+ private final String model;
+ private final String apiPath;
+
+ public MicrosoftModel(
+ SeaTunnelRowType rowType,
+ SqlType outputType,
+ List<String> projectionColumns,
+ String prompt,
+ String model,
+ String apiKey,
+ String apiPath) {
+ super(rowType, outputType, projectionColumns, prompt);
+ this.model = model;
+ this.apiKey = apiKey;
+ this.apiPath =
+ CustomConfigPlaceholder.replacePlaceholders(
+ apiPath,
CustomConfigPlaceholder.REPLACE_PLACEHOLDER_MODEL, model, null);
+ this.client = HttpClients.createDefault();
+ }
+
+ @Override
+ protected List<String> chatWithModel(String prompt, String data) throws
IOException {
+ HttpPost post = new HttpPost(apiPath);
+ post.setHeader("Authorization", "Bearer " + apiKey);
+ post.setHeader("Content-Type", "application/json");
+ ObjectNode objectNode = createJsonNodeFromData(prompt, data);
+ post.setEntity(new
StringEntity(OBJECT_MAPPER.writeValueAsString(objectNode), "UTF-8"));
+ post.setConfig(
+
RequestConfig.custom().setConnectTimeout(20000).setSocketTimeout(20000).build());
+ CloseableHttpResponse response = client.execute(post);
+ String responseStr = EntityUtils.toString(response.getEntity());
+ if (response.getStatusLine().getStatusCode() != 200) {
+ throw new IOException("Failed to chat with model, response: " +
responseStr);
+ }
+
+ JsonNode result = OBJECT_MAPPER.readTree(responseStr);
+ String resultData =
result.get("choices").get(0).get("message").get("content").asText();
+ return OBJECT_MAPPER.readValue(
+ convertData(resultData), new TypeReference<List<String>>() {});
+ }
+
+ @VisibleForTesting
+ public ObjectNode createJsonNodeFromData(String prompt, String data) {
+ ObjectNode objectNode = OBJECT_MAPPER.createObjectNode();
+ ArrayNode messages = objectNode.putArray("messages");
+ messages.addObject().put("role", "system").put("content", prompt);
+ messages.addObject().put("role", "user").put("content", data);
+ return objectNode;
+ }
+
+ @Override
+ public void close() throws IOException {
+ if (client != null) {
+ client.close();
+ }
+ }
+}
diff --git
a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
index 91666c4139..870af980fe 100644
---
a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
+++
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
@@ -28,6 +28,7 @@ import org.apache.seatunnel.api.table.type.SqlType;
import org.apache.seatunnel.format.json.RowToJsonConverters;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.custom.CustomModel;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.kimiai.KimiAIModel;
+import
org.apache.seatunnel.transform.nlpmodel.llm.remote.microsoft.MicrosoftModel;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.openai.OpenAIModel;
import org.junit.jupiter.api.Assertions;
@@ -36,6 +37,7 @@ import org.junit.jupiter.api.Test;
import com.google.common.collect.Lists;
import java.io.IOException;
+import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@@ -130,6 +132,38 @@ public class LLMRequestJsonTest {
model.close();
}
+ @Test
+ void testMicrosoftRequestJson() throws Exception {
+ SeaTunnelRowType rowType =
+ new SeaTunnelRowType(
+ new String[] {"id", "name"},
+ new SeaTunnelDataType[] {BasicType.INT_TYPE,
BasicType.STRING_TYPE});
+ MicrosoftModel model =
+ new MicrosoftModel(
+ rowType,
+ SqlType.STRING,
+ null,
+ "Determine whether someone is Chinese or American by
their name",
+ "gpt-35-turbo",
+ "sk-xxx",
+
"https://api.moonshot.cn/openai/deployments/${model}/chat/completions?api-version=2024-02-01");
+ Field apiPathField = model.getClass().getDeclaredField("apiPath");
+ apiPathField.setAccessible(true);
+ String apiPath = (String) apiPathField.get(model);
+ Assertions.assertEquals(
+
"https://api.moonshot.cn/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-02-01",
+ apiPath);
+
+ ObjectNode node =
+ model.createJsonNodeFromData(
+ "Determine whether someone is Chinese or American by
their name",
+ "{\"id\":1, \"name\":\"John\"}");
+ Assertions.assertEquals(
+ "{\"messages\":[{\"role\":\"system\",\"content\":\"Determine
whether someone is Chinese or American by their
name\"},{\"role\":\"user\",\"content\":\"{\\\"id\\\":1,
\\\"name\\\":\\\"John\\\"}\"}]}",
+ OBJECT_MAPPER.writeValueAsString(node));
+ model.close();
+ }
+
@Test
void testCustomRequestJson() throws IOException {
SeaTunnelRowType rowType =