This is an automated email from the ASF dual-hosted git repository.
liugddx pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git
The following commit(s) were added to refs/heads/dev by this push:
new 855254e737 [Feature][Transform] Add LLM transform (#7303)
855254e737 is described below
commit 855254e737051edbaf4ca08b95ca010fe18fd214
Author: Jia Fan <[email protected]>
AuthorDate: Wed Aug 7 11:03:06 2024 +0800
[Feature][Transform] Add LLM transform (#7303)
* [Feature][Transform] Add LLM transform
* update
* update
* retrigger
---
docs/en/transform-v2/llm.md | 122 +++++++++++++++++++++
docs/zh/transform-v2/llm.md | 120 ++++++++++++++++++++
.../apache/seatunnel/e2e/transform/TestLLMIT.java | 90 +++++++++++++++
.../src/test/resources/llm_openai_transform.conf | 75 +++++++++++++
.../src/test/resources/mockserver-config.json | 40 +++++++
seatunnel-transforms-v2/pom.xml | 15 +++
.../transform/common/SeaTunnelRowAccessor.java | 4 +
.../seatunnel/transform/llm/LLMTransform.java | 119 ++++++++++++++++++++
.../transform/llm/LLMTransformConfig.java | 71 ++++++++++++
.../transform/llm/LLMTransformFactory.java | 59 ++++++++++
.../ModelProvider.java} | 28 +----
.../transform/llm/model/AbstractModel.java | 69 ++++++++++++
.../model/Model.java} | 27 +----
.../transform/llm/model/openai/OpenAIModel.java | 104 ++++++++++++++++++
.../transform/LLMTransformFactoryTest.java} | 30 ++---
.../transform/llm/LLMRequestJsonTest.java | 61 +++++++++++
tools/dependencies/known-dependencies.txt | 2 +
17 files changed, 969 insertions(+), 67 deletions(-)
diff --git a/docs/en/transform-v2/llm.md b/docs/en/transform-v2/llm.md
new file mode 100644
index 0000000000..d03b8226f0
--- /dev/null
+++ b/docs/en/transform-v2/llm.md
@@ -0,0 +1,122 @@
+# LLM
+
+> LLM transform plugin
+
+## Description
+
+Leverage a large language model (LLM) to process data by sending it to the LLM and receiving the
+generated results. Use the LLM to label, clean, and enrich data, perform data inference, and
+more.
+
+## Options
+
+| name             | type   | required | default value                              |
+|------------------|--------|----------|--------------------------------------------|
+| model_provider   | enum   | yes      |                                            |
+| output_data_type | enum   | no       | STRING                                     |
+| prompt           | string | yes      |                                            |
+| model            | string | yes      |                                            |
+| api_key          | string | yes      |                                            |
+| openai.api_path  | string | no       | https://api.openai.com/v1/chat/completions |
+
+### model_provider
+
+The model provider to use. The available options are:
+OPENAI
+
+### output_data_type
+
+The data type of the output data. The available options are:
+`STRING`, `INT`, `BIGINT`, `DOUBLE`, `BOOLEAN`.
+The default value is `STRING`.
+
+### prompt
+
+The prompt to send to the LLM. This parameter defines how the LLM processes and returns the data. For example:
+
+Suppose the data read from the source is a table like this:
+
+| name | age |
+|---------------|-----|
+| Jia Fan | 20 |
+| Hailin Wang | 20 |
+| Eric | 20 |
+| Guangdong Liu | 20 |
+
+The prompt can be:
+
+```
+Determine whether someone is Chinese or American by their name
+```
+
+The result will be:
+
+| name | age | llm_output |
+|---------------|-----|------------|
+| Jia Fan | 20 | Chinese |
+| Hailin Wang | 20 | Chinese |
+| Eric | 20 | American |
+| Guangdong Liu | 20 | Chinese |
+
+### model
+
+The model to use. Different model providers offer different models. For example, an OpenAI model can be `gpt-4o-mini`.
+If you use an OpenAI model, refer to https://platform.openai.com/docs/models/model-endpoint-compatibility for the models supported by the `/v1/chat/completions` endpoint.
+
+### api_key
+
+The API key to use for the model provider.
+If you use an OpenAI model, refer to https://platform.openai.com/docs/api-reference/api-keys for how to get an API key.
+
+### openai.api_path
+
+The API path to use for the OpenAI model provider. In most cases, you do not need to change this configuration. If you are using an API proxy service, you may need to set it to the proxy's API address.
+
+### common options [string]
+
+Transform plugin common parameters. Refer to [Transform Plugin](common-options.md) for details.
+
+## Example
+
+Determine the user's country through an LLM.
+
+```hocon
+env {
+ parallelism = 1
+ job.mode = "BATCH"
+}
+
+source {
+ FakeSource {
+ row.num = 5
+ schema = {
+ fields {
+ id = "int"
+ name = "string"
+ }
+ }
+ rows = [
+ {fields = [1, "Jia Fan"], kind = INSERT}
+ {fields = [2, "Hailin Wang"], kind = INSERT}
+ {fields = [3, "Tomas"], kind = INSERT}
+ {fields = [4, "Eric"], kind = INSERT}
+ {fields = [5, "Guangdong Liu"], kind = INSERT}
+ ]
+ }
+}
+
+transform {
+ LLM {
+ model_provider = OPENAI
+ model = gpt-4o-mini
+ api_key = sk-xxx
+ prompt = "Determine whether someone is Chinese or American by their name"
+ }
+}
+
+sink {
+ console {
+ }
+}
+```
+
diff --git a/docs/zh/transform-v2/llm.md b/docs/zh/transform-v2/llm.md
new file mode 100644
index 0000000000..acd3245b8e
--- /dev/null
+++ b/docs/zh/transform-v2/llm.md
@@ -0,0 +1,120 @@
+# LLM
+
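+> LLM transform plugin
+
+## Description
+
+Leverage a large language model (LLM) to process data by sending it to the LLM and receiving the generated results. Use the LLM to label, clean, and enrich data, perform data inference, and more.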
+
+## Options
+
+| name             | type   | required | default value                              |
+|------------------|--------|----------|--------------------------------------------|
+| model_provider   | enum   | yes      |                                            |
+| output_data_type | enum   | no       | STRING                                     |
+| prompt           | string | yes      |                                            |
+| model            | string | yes      |                                            |
+| api_key          | string | yes      |                                            |
+| openai.api_path  | string | no       | https://api.openai.com/v1/chat/completions |
+
+### model_provider
+
+The model provider to use. The available options are:
+OPENAI
+
+### output_data_type
+
+The data type of the output data. The available options are:
+`STRING`, `INT`, `BIGINT`, `DOUBLE`, `BOOLEAN`.
+The default value is `STRING`.
+
+### prompt
+
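+The prompt to send to the LLM. This parameter defines how the LLM processes and returns the data. For example:
+
+Suppose the data read from the source is a table like this: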
+
+| name | age |
+|---------------|-----|
+| Jia Fan | 20 |
+| Hailin Wang | 20 |
+| Eric | 20 |
+| Guangdong Liu | 20 |
+
+We can use the following prompt:
+
+```
+Determine whether someone is Chinese or American by their name
+```
+
+This returns:
+
+| name | age | llm_output |
+|---------------|-----|------------|
+| Jia Fan | 20 | Chinese |
+| Hailin Wang | 20 | Chinese |
+| Eric | 20 | American |
+| Guangdong Liu | 20 | Chinese |
+
+### model
+
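+The model to use. Different model providers offer different models. For example, an OpenAI model can be `gpt-4o-mini`.
+If you use an OpenAI model, refer to https://platform.openai.com/docs/models/model-endpoint-compatibility for the models supported by the `/v1/chat/completions` endpoint.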
+
+### api_key
+
+The API key to use for the model provider.
+If you use an OpenAI model, refer to https://platform.openai.com/docs/api-reference/api-keys for how to get an API key.
+
+### openai.api_path
+
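+The API path to use for the OpenAI model provider. In most cases, you do not need to change this configuration. If you are using an API proxy service, you may need to set it to the proxy's API address.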
+
+### common options [string]
+
+Transform plugin common parameters. Refer to [Transform Plugin](common-options.md) for details.
+
+## Example
+
+Determine the user's country through an LLM.
+
+```hocon
+env {
+ parallelism = 1
+ job.mode = "BATCH"
+}
+
+source {
+ FakeSource {
+ row.num = 5
+ schema = {
+ fields {
+ id = "int"
+ name = "string"
+ }
+ }
+ rows = [
+ {fields = [1, "Jia Fan"], kind = INSERT}
+ {fields = [2, "Hailin Wang"], kind = INSERT}
+ {fields = [3, "Tomas"], kind = INSERT}
+ {fields = [4, "Eric"], kind = INSERT}
+ {fields = [5, "Guangdong Liu"], kind = INSERT}
+ ]
+ }
+}
+
+transform {
+ LLM {
+ model_provider = OPENAI
+ model = gpt-4o-mini
+ api_key = sk-xxx
+ prompt = "Determine whether someone is Chinese or American by their name"
+ }
+}
+
+sink {
+ console {
+ }
+}
+```
+
diff --git a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
new file mode 100644
index 0000000000..6f17c5a94f
--- /dev/null
+++ b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.e2e.transform;
+
+import org.apache.seatunnel.e2e.common.TestResource;
+import org.apache.seatunnel.e2e.common.container.TestContainer;
+
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.TestTemplate;
+import org.testcontainers.containers.Container;
+import org.testcontainers.containers.GenericContainer;
+import org.testcontainers.containers.output.Slf4jLogConsumer;
+import org.testcontainers.containers.wait.strategy.HttpWaitStrategy;
+import org.testcontainers.lifecycle.Startables;
+import org.testcontainers.utility.DockerImageName;
+import org.testcontainers.utility.DockerLoggerFactory;
+import org.testcontainers.utility.MountableFile;
+
+import java.io.File;
+import java.io.IOException;
+import java.net.URL;
+import java.util.Optional;
+import java.util.stream.Stream;
+
+public class TestLLMIT extends TestSuiteBase implements TestResource {
+ private static final String TMP_DIR = "/tmp";
+ private GenericContainer<?> mockserverContainer;
+ private static final String IMAGE = "mockserver/mockserver:5.14.0";
+
+ @BeforeAll
+ @Override
+ public void startUp() {
+ Optional<URL> resource =
+ Optional.ofNullable(TestLLMIT.class.getResource("/mockserver-config.json"));
+ this.mockserverContainer =
+ new GenericContainer<>(DockerImageName.parse(IMAGE))
+ .withNetwork(NETWORK)
+ .withNetworkAliases("mockserver")
+ .withExposedPorts(1080)
+ .withCopyFileToContainer(
+ MountableFile.forHostPath(
+ new File(
+ resource.orElseThrow(
+ () ->
+ new IllegalArgumentException(
+ "Can not get config file of mockServer"))
+ .getPath())
+ .getAbsolutePath()),
+ TMP_DIR + "/mockserver-config.json")
+ .withEnv(
+ "MOCKSERVER_INITIALIZATION_JSON_PATH",
+ TMP_DIR + "/mockserver-config.json")
+ .withEnv("MOCKSERVER_LOG_LEVEL", "WARN")
+ .withLogConsumer(new Slf4jLogConsumer(DockerLoggerFactory.getLogger(IMAGE)))
+ .waitingFor(new HttpWaitStrategy().forPath("/").forStatusCode(404));
+ Startables.deepStart(Stream.of(mockserverContainer)).join();
+ }
+
+ @AfterAll
+ @Override
+ public void tearDown() throws Exception {
+ if (mockserverContainer != null) {
+ mockserverContainer.stop();
+ }
+ }
+
+ @TestTemplate
+ public void testLLMWithOpenAI(TestContainer container)
+ throws IOException, InterruptedException {
+ Container.ExecResult execResult = container.executeJob("/llm_openai_transform.conf");
+ Assertions.assertEquals(0, execResult.getExitCode());
+ }
+}
diff --git a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_openai_transform.conf b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_openai_transform.conf
new file mode 100644
index 0000000000..5449593589
--- /dev/null
+++ b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_openai_transform.conf
@@ -0,0 +1,75 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+######
+###### This config file is a demonstration of batch processing in seatunnel config
+######
+
+env {
+ job.mode = "BATCH"
+}
+
+source {
+ FakeSource {
+ row.num = 5
+ schema = {
+ fields {
+ id = "int"
+ name = "string"
+ }
+ }
+ rows = [
+ {fields = [1, "Jia Fan"], kind = INSERT}
+ {fields = [2, "Hailin Wang"], kind = INSERT}
+ {fields = [3, "Tomas"], kind = INSERT}
+ {fields = [4, "Eric"], kind = INSERT}
+ {fields = [5, "Guangdong Liu"], kind = INSERT}
+ ]
+ result_table_name = "fake"
+ }
+}
+
+transform {
+ LLM {
+ source_table_name = "fake"
+ model_provider = OPENAI
+ model = gpt-4o-mini
+ api_key = sk-xxx
+ prompt = "Determine whether someone is Chinese or American by their name"
+ openai.api_path = "http://mockserver:1080/v1/chat/completions"
+ result_table_name = "llm_output"
+ }
+}
+
+sink {
+ Assert {
+ source_table_name = "llm_output"
+ rules =
+ {
+ field_rules = [
+ {
+ field_name = llm_output
+ field_type = string
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ }
+ ]
+ }
+ }
+}
\ No newline at end of file
diff --git a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json
new file mode 100644
index 0000000000..b4a2e53bea
--- /dev/null
+++ b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json
@@ -0,0 +1,40 @@
+// https://www.mock-server.com/mock_server/getting_started.html#request_matchers
+
+[
+ {
+ "httpRequest": {
+ "method": "POST",
+ "path": "/v1/chat/completions"
+ },
+ "httpResponse": {
+ "body": {
+ "id": "chatcmpl-9s4hoBNGV0d9Mudkhvgzg64DAWPnx",
+ "object": "chat.completion",
+ "created": 1722674828,
+ "model": "gpt-4o-mini",
+ "choices": [
+ {
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": "[\"Chinese\"]"
+ },
+ "logprobs": null,
+ "finish_reason": "stop"
+ }
+ ],
+ "usage": {
+ "prompt_tokens": 107,
+ "completion_tokens": 3,
+ "total_tokens": 110
+ },
+ "system_fingerprint": "fp_0f03d4f0ee",
+ "code": 0,
+ "msg": "ok"
+ },
+ "headers": {
+ "Content-Type": "application/json"
+ }
+ }
+ }
+]
diff --git a/seatunnel-transforms-v2/pom.xml b/seatunnel-transforms-v2/pom.xml
index ae8909f463..4cbef9a4b8 100644
--- a/seatunnel-transforms-v2/pom.xml
+++ b/seatunnel-transforms-v2/pom.xml
@@ -29,6 +29,11 @@
<artifactId>seatunnel-transforms-v2</artifactId>
<name>SeaTunnel : Transforms : V2</name>
+ <properties>
+ <httpclient.version>4.5.13</httpclient.version>
+ <httpcore.version>4.4.4</httpcore.version>
+ </properties>
+
<dependencyManagement>
<dependencies>
<dependency>
@@ -77,6 +82,16 @@
<version>${project.version}</version>
<classifier>optional</classifier>
</dependency>
+ <dependency>
+ <groupId>org.apache.httpcomponents</groupId>
+ <artifactId>httpclient</artifactId>
+ <version>${httpclient.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.httpcomponents</groupId>
+ <artifactId>httpcore</artifactId>
+ <version>${httpcore.version}</version>
+ </dependency>
</dependencies>
<build>
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/common/SeaTunnelRowAccessor.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/common/SeaTunnelRowAccessor.java
index 0224ef4b8f..5b97f34168 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/common/SeaTunnelRowAccessor.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/common/SeaTunnelRowAccessor.java
@@ -41,4 +41,8 @@ public class SeaTunnelRowAccessor {
public Object getField(int pos) {
return row.getField(pos);
}
+
+ public Object[] getFields() {
+ return row.getFields();
+ }
}
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/llm/LLMTransform.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/llm/LLMTransform.java
new file mode 100644
index 0000000000..d19960044f
--- /dev/null
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/llm/LLMTransform.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.llm;
+
+import org.apache.seatunnel.api.configuration.ReadonlyConfig;
+import org.apache.seatunnel.api.table.catalog.CatalogTable;
+import org.apache.seatunnel.api.table.catalog.Column;
+import org.apache.seatunnel.api.table.catalog.PhysicalColumn;
+import org.apache.seatunnel.api.table.catalog.SeaTunnelDataTypeConvertorUtil;
+import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
+import org.apache.seatunnel.api.table.type.SeaTunnelRow;
+import org.apache.seatunnel.transform.common.SeaTunnelRowAccessor;
+import org.apache.seatunnel.transform.common.SingleFieldOutputTransform;
+import org.apache.seatunnel.transform.llm.model.Model;
+import org.apache.seatunnel.transform.llm.model.openai.OpenAIModel;
+
+import lombok.NonNull;
+import lombok.SneakyThrows;
+
+import java.util.Collections;
+import java.util.List;
+
+public class LLMTransform extends SingleFieldOutputTransform {
+ private final ReadonlyConfig config;
+ private final SeaTunnelDataType<?> outputDataType;
+ private Model model;
+
+ public LLMTransform(@NonNull ReadonlyConfig config, @NonNull CatalogTable inputCatalogTable) {
+ super(inputCatalogTable);
+ this.config = config;
+ this.outputDataType =
+ SeaTunnelDataTypeConvertorUtil.deserializeSeaTunnelDataType(
+ "output", config.get(LLMTransformConfig.OUTPUT_DATA_TYPE).toString());
+ }
+
+ private void tryOpen() {
+ if (model == null) {
+ open();
+ }
+ }
+
+ @Override
+ public String getPluginName() {
+ return "LLM";
+ }
+
+ @Override
+ public void open() {
+ ModelProvider provider = config.get(LLMTransformConfig.MODEL_PROVIDER);
+ if (provider.equals(ModelProvider.OPENAI)) {
+ model =
+ new OpenAIModel(
+ inputCatalogTable.getSeaTunnelRowType(),
+ outputDataType.getSqlType(),
+ config.get(LLMTransformConfig.PROMPT),
+ config.get(LLMTransformConfig.MODEL),
+ config.get(LLMTransformConfig.API_KEY),
+ config.get(LLMTransformConfig.OPENAI_API_PATH));
+ } else {
+ throw new IllegalArgumentException("Unsupported model provider: " + provider);
+ }
+ }
+
+ @Override
+ protected Object getOutputFieldValue(SeaTunnelRowAccessor inputRow) {
+ tryOpen();
+ SeaTunnelRow seaTunnelRow = new SeaTunnelRow(inputRow.getFields());
+ try {
+ List<String> values = model.inference(Collections.singletonList(seaTunnelRow));
+ switch (outputDataType.getSqlType()) {
+ case STRING:
+ return String.valueOf(values.get(0));
+ case INT:
+ return Integer.parseInt(values.get(0));
+ case BIGINT:
+ return Long.parseLong(values.get(0));
+ case DOUBLE:
+ return Double.parseDouble(values.get(0));
+ case BOOLEAN:
+ return Boolean.parseBoolean(values.get(0));
+ default:
+ throw new IllegalArgumentException(
+ "Unsupported output data type: " + outputDataType);
+ }
+ } catch (Exception e) {
+ throw new RuntimeException(
+ String.format("Failed to inference model with row %s", seaTunnelRow), e);
+ }
+ }
+
+ @Override
+ protected Column getOutputColumn() {
+ return PhysicalColumn.of(
+ "llm_output", outputDataType, (Long) null, true, null, "Output column of LLM");
+ }
+
+ @SneakyThrows
+ @Override
+ public void close() {
+ if (model != null) {
+ model.close();
+ }
+ }
+}
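The type handling in `getOutputFieldValue` above can be exercised in isolation. The following is a minimal sketch, not SeaTunnel code: `OutputConversionSketch`, its `OutputType` enum, and `convert` are hypothetical names that mirror the `switch` on `outputDataType.getSqlType()` using only the JDK. It illustrates two consequences of relying on the standard parsers: `Integer.parseInt` rejects a reply that is not a plain integer, while `Boolean.parseBoolean` maps anything other than a case-insensitive `true` to `false` instead of failing.

```java
// Hypothetical stand-alone sketch; these names are illustrative and not part of SeaTunnel.
final class OutputConversionSketch {

    // Assumed stand-in for the SqlType values listed in the docs for output_data_type.
    enum OutputType { STRING, INT, BIGINT, DOUBLE, BOOLEAN }

    // Converts the raw text returned by the model into the configured output type,
    // using the same JDK parsers as the switch in getOutputFieldValue.
    static Object convert(OutputType type, String raw) {
        switch (type) {
            case STRING:
                return String.valueOf(raw);
            case INT:
                return Integer.parseInt(raw);
            case BIGINT:
                return Long.parseLong(raw);
            case DOUBLE:
                return Double.parseDouble(raw);
            case BOOLEAN:
                return Boolean.parseBoolean(raw);
            default:
                throw new IllegalArgumentException("Unsupported output data type: " + type);
        }
    }

    public static void main(String[] args) {
        System.out.println(convert(OutputType.INT, "42"));
        // parseBoolean does not throw on unexpected text; it yields false.
        System.out.println(convert(OutputType.BOOLEAN, "yes"));
        try {
            convert(OutputType.INT, "forty-two");
        } catch (NumberFormatException e) {
            // A non-numeric reply surfaces as an exception rather than a value.
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

In the transform itself, such a parsing failure is caught by the surrounding `catch (Exception e)` and rethrown as a `RuntimeException` that includes the offending row.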
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/llm/LLMTransformConfig.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/llm/LLMTransformConfig.java
new file mode 100644
index 0000000000..ca3da7e670
--- /dev/null
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/llm/LLMTransformConfig.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.llm;
+
+import org.apache.seatunnel.api.configuration.Option;
+import org.apache.seatunnel.api.configuration.Options;
+import org.apache.seatunnel.api.table.type.SqlType;
+
+import java.io.Serializable;
+
+public class LLMTransformConfig implements Serializable {
+
+ public static final Option<ModelProvider> MODEL_PROVIDER =
+ Options.key("model_provider")
+ .enumType(ModelProvider.class)
+ .noDefaultValue()
+ .withDescription("The model provider of LLM");
+
+ public static final Option<SqlType> OUTPUT_DATA_TYPE =
+ Options.key("output_data_type")
+ .enumType(SqlType.class)
+ .defaultValue(SqlType.STRING)
+ .withDescription("The output data type of LLM");
+
+ public static final Option<String> PROMPT =
+ Options.key("prompt")
+ .stringType()
+ .noDefaultValue()
+ .withDescription("The prompt of LLM");
+
+ public static final Option<String> MODEL =
+ Options.key("model")
+ .stringType()
+ .noDefaultValue()
+ .withDescription(
+ "The model of LLM, eg: if the model provider is OpenAI, the model should be gpt-3.5-turbo/gpt-4o-mini, etc.");
+
+ public static final Option<String> API_KEY =
+ Options.key("api_key")
+ .stringType()
+ .noDefaultValue()
+ .withDescription("The API key of LLM");
+
+ public static final Option<Integer> INFERENCE_BATCH_SIZE =
+ Options.key("inference_batch_size")
+ .intType()
+ .defaultValue(100)
+ .withDescription("The row batch size of each inference");
+
+ // OPENAI specific options
+ public static final Option<String> OPENAI_API_PATH =
+ Options.key("openai.api_path")
+ .stringType()
+ .defaultValue("https://api.openai.com/v1/chat/completions")
+ .withDescription("The API path of OpenAI");
+}
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/llm/LLMTransformFactory.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/llm/LLMTransformFactory.java
new file mode 100644
index 0000000000..6fe5d53fe5
--- /dev/null
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/llm/LLMTransformFactory.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.llm;
+
+import org.apache.seatunnel.api.configuration.util.OptionRule;
+import org.apache.seatunnel.api.table.catalog.CatalogTable;
+import org.apache.seatunnel.api.table.connector.TableTransform;
+import org.apache.seatunnel.api.table.factory.Factory;
+import org.apache.seatunnel.api.table.factory.TableTransformFactory;
+import org.apache.seatunnel.api.table.factory.TableTransformFactoryContext;
+
+import com.google.auto.service.AutoService;
+
+@AutoService(Factory.class)
+public class LLMTransformFactory implements TableTransformFactory {
+ @Override
+ public String factoryIdentifier() {
+ return "LLM";
+ }
+
+ @Override
+ public OptionRule optionRule() {
+ return OptionRule.builder()
+ .required(
+ LLMTransformConfig.MODEL_PROVIDER,
+ LLMTransformConfig.MODEL,
+ LLMTransformConfig.PROMPT,
+ LLMTransformConfig.API_KEY)
+ .optional(
+ LLMTransformConfig.OUTPUT_DATA_TYPE,
+ LLMTransformConfig.INFERENCE_BATCH_SIZE)
+ .conditional(
+ LLMTransformConfig.MODEL_PROVIDER,
+ ModelProvider.OPENAI,
+ LLMTransformConfig.OPENAI_API_PATH)
+ .build();
+ }
+
+ @Override
+ public TableTransform createTransform(TableTransformFactoryContext context) {
+ CatalogTable catalogTable = context.getCatalogTables().get(0);
+ return () -> new LLMTransform(context.getOptions(), catalogTable);
+ }
+}
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/common/SeaTunnelRowAccessor.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/llm/ModelProvider.java
similarity index 58%
copy from seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/common/SeaTunnelRowAccessor.java
copy to seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/llm/ModelProvider.java
index 0224ef4b8f..a55d706c09 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/common/SeaTunnelRowAccessor.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/llm/ModelProvider.java
@@ -15,30 +15,8 @@
* limitations under the License.
*/
-package org.apache.seatunnel.transform.common;
+package org.apache.seatunnel.transform.llm;
-import org.apache.seatunnel.api.table.type.RowKind;
-import org.apache.seatunnel.api.table.type.SeaTunnelRow;
-
-import lombok.AllArgsConstructor;
-
-@AllArgsConstructor
-public class SeaTunnelRowAccessor {
- private final SeaTunnelRow row;
-
- public int getArity() {
- return row.getArity();
- }
-
- public String getTableId() {
- return row.getTableId();
- }
-
- public RowKind getRowKind() {
- return row.getRowKind();
- }
-
- public Object getField(int pos) {
- return row.getField(pos);
- }
+public enum ModelProvider {
+ OPENAI
}
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/llm/model/AbstractModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/llm/model/AbstractModel.java
new file mode 100644
index 0000000000..51d674c0ad
--- /dev/null
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/llm/model/AbstractModel.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.llm.model;
+
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ArrayNode;
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode;
+
+import org.apache.seatunnel.api.table.type.SeaTunnelRow;
+import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
+import org.apache.seatunnel.api.table.type.SqlType;
+import org.apache.seatunnel.format.json.RowToJsonConverters;
+
+import java.io.IOException;
+import java.util.List;
+
+public abstract class AbstractModel implements Model {
+
+ protected static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+ private final RowToJsonConverters.RowToJsonConverter rowToJsonConverters;
+ private final String prompt;
+ private final SqlType outputType;
+
+ public AbstractModel(SeaTunnelRowType rowType, SqlType outputType, String prompt) {
+ this.prompt = prompt;
+ this.outputType = outputType;
+ this.rowToJsonConverters = new RowToJsonConverters().createConverter(rowType, null);
+ }
+
+ private String getPromptWithLimit() {
+ return prompt
+ + "\n The following rules need to be followed: "
+ + "\n 1. The received data is an array, and the result is returned in the form of an array."
+ + "\n 2. Only the result needs to be returned, and no other information can be returned."
+ + "\n 3. The element type of the array is "
+ + outputType.toString()
+ + "."
+ + "\n Eg: [\"value1\", \"value2\"]";
+ }
+
+ @Override
+ public List<String> inference(List<SeaTunnelRow> rows) throws IOException {
+ ArrayNode rowsNode = OBJECT_MAPPER.createArrayNode();
+ for (SeaTunnelRow row : rows) {
+ ObjectNode rowNode = OBJECT_MAPPER.createObjectNode();
+ rowToJsonConverters.convert(OBJECT_MAPPER, rowNode, row);
+ rowsNode.add(rowNode);
+ }
+ return chatWithModel(getPromptWithLimit(), OBJECT_MAPPER.writeValueAsString(rowsNode));
+ }
+
+ protected abstract List<String> chatWithModel(String promptWithLimit, String rowsJson)
+ throws IOException;
+}
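How `getPromptWithLimit` shapes the final prompt is easiest to see by printing it. The sketch below copies the rule text from `AbstractModel` into a hypothetical stand-alone class (`PromptSketch` and `withRules` are illustrative names, and the output type is passed as a plain string rather than a `SqlType`), so it runs without SeaTunnel on the classpath.

```java
// Hypothetical stand-alone sketch; PromptSketch is not part of SeaTunnel.
final class PromptSketch {

    // Appends the fixed formatting rules from AbstractModel to the user's prompt.
    // outputType is assumed to be the name of the configured type, e.g. "STRING".
    static String withRules(String prompt, String outputType) {
        return prompt
                + "\n The following rules need to be followed: "
                + "\n 1. The received data is an array, and the result is returned in the form of an array."
                + "\n 2. Only the result needs to be returned, and no other information can be returned."
                + "\n 3. The element type of the array is "
                + outputType
                + "."
                + "\n Eg: [\"value1\", \"value2\"]";
    }

    public static void main(String[] args) {
        // Prints the user prompt followed by the three rules and the example array.
        System.out.println(
                withRules("Determine whether someone is Chinese or American by their name", "STRING"));
    }
}
```

In `inference`, this combined prompt is passed to `chatWithModel` together with the input rows serialized as a JSON array, which is why the rules ask the model to answer with an array as well.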
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/common/SeaTunnelRowAccessor.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/llm/model/Model.java
similarity index 62%
copy from seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/common/SeaTunnelRowAccessor.java
copy to seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/llm/model/Model.java
index 0224ef4b8f..77a8da6328 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/common/SeaTunnelRowAccessor.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/llm/model/Model.java
@@ -15,30 +15,15 @@
* limitations under the License.
*/
-package org.apache.seatunnel.transform.common;
+package org.apache.seatunnel.transform.llm.model;
-import org.apache.seatunnel.api.table.type.RowKind;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
-import lombok.AllArgsConstructor;
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.List;
-@AllArgsConstructor
-public class SeaTunnelRowAccessor {
- private final SeaTunnelRow row;
+public interface Model extends Closeable {
- public int getArity() {
- return row.getArity();
- }
-
- public String getTableId() {
- return row.getTableId();
- }
-
- public RowKind getRowKind() {
- return row.getRowKind();
- }
-
- public Object getField(int pos) {
- return row.getField(pos);
- }
+ List<String> inference(List<SeaTunnelRow> rows) throws IOException;
}
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/llm/model/openai/OpenAIModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/llm/model/openai/OpenAIModel.java
new file mode 100644
index 0000000000..9477b87320
--- /dev/null
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/llm/model/openai/OpenAIModel.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.llm.model.openai;
+
+import org.apache.seatunnel.shade.com.fasterxml.jackson.core.type.TypeReference;
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.JsonNode;
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ArrayNode;
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode;
+
+import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
+import org.apache.seatunnel.api.table.type.SqlType;
+import org.apache.seatunnel.transform.llm.model.AbstractModel;
+
+import org.apache.http.client.config.RequestConfig;
+import org.apache.http.client.methods.CloseableHttpResponse;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.entity.StringEntity;
+import org.apache.http.impl.client.CloseableHttpClient;
+import org.apache.http.impl.client.HttpClients;
+import org.apache.http.util.EntityUtils;
+
+import com.google.common.annotations.VisibleForTesting;
+import lombok.extern.slf4j.Slf4j;
+
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * OpenAI model. Refer <a href="https://platform.openai.com/docs/api-reference/chat">chat api </a>
+ */
+@Slf4j
+public class OpenAIModel extends AbstractModel {
+
+ private final CloseableHttpClient client;
+ private final String apiKey;
+ private final String model;
+ private final String apiPath;
+
+ public OpenAIModel(
+ SeaTunnelRowType rowType,
+ SqlType outputType,
+ String prompt,
+ String model,
+ String apiKey,
+ String apiPath) {
+ super(rowType, outputType, prompt);
+ this.apiKey = apiKey;
+ this.apiPath = apiPath;
+ this.model = model;
+ this.client = HttpClients.createDefault();
+ }
+
+ @Override
+ protected List<String> chatWithModel(String prompt, String data) throws IOException {
+ HttpPost post = new HttpPost(apiPath);
+ post.setHeader("Authorization", "Bearer " + apiKey);
+ post.setHeader("Content-Type", "application/json");
+ ObjectNode objectNode = createJsonNodeFromData(prompt, data);
+ post.setEntity(new StringEntity(OBJECT_MAPPER.writeValueAsString(objectNode), "UTF-8"));
+ post.setConfig(
+ RequestConfig.custom().setConnectTimeout(20000).setSocketTimeout(20000).build());
+ CloseableHttpResponse response = client.execute(post);
+ String responseStr = EntityUtils.toString(response.getEntity());
+ if (response.getStatusLine().getStatusCode() != 200) {
+ throw new IOException("Failed to chat with model, response: " + responseStr);
+ }
+
+ JsonNode result = OBJECT_MAPPER.readTree(responseStr);
+ String resultData = result.get("choices").get(0).get("message").get("content").asText();
+ return OBJECT_MAPPER.readValue(resultData, new TypeReference<List<String>>() {});
+ }
+
+ @VisibleForTesting
+ public ObjectNode createJsonNodeFromData(String prompt, String data) {
+ ObjectNode objectNode = OBJECT_MAPPER.createObjectNode();
+ objectNode.put("model", model);
+ ArrayNode messages = objectNode.putArray("messages");
+ messages.addObject().put("role", "system").put("content", prompt);
+ messages.addObject().put("role", "user").put("content", data);
+ return objectNode;
+ }
+
+ @Override
+ public void close() throws IOException {
+ if (client != null) {
+ client.close();
+ }
+ }
+}
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/common/SeaTunnelRowAccessor.java b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/LLMTransformFactoryTest.java
similarity index 58%
copy from seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/common/SeaTunnelRowAccessor.java
copy to seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/LLMTransformFactoryTest.java
index 0224ef4b8f..39b2769480 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/common/SeaTunnelRowAccessor.java
+++ b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/LLMTransformFactoryTest.java
@@ -15,30 +15,18 @@
* limitations under the License.
*/
-package org.apache.seatunnel.transform.common;
+package org.apache.seatunnel.transform;
-import org.apache.seatunnel.api.table.type.RowKind;
-import org.apache.seatunnel.api.table.type.SeaTunnelRow;
+import org.apache.seatunnel.transform.llm.LLMTransformFactory;
-import lombok.AllArgsConstructor;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
-@AllArgsConstructor
-public class SeaTunnelRowAccessor {
- private final SeaTunnelRow row;
+public class LLMTransformFactoryTest {
- public int getArity() {
- return row.getArity();
- }
-
- public String getTableId() {
- return row.getTableId();
- }
-
- public RowKind getRowKind() {
- return row.getRowKind();
- }
-
- public Object getField(int pos) {
- return row.getField(pos);
+ @Test
+ public void testOptionRule() throws Exception {
+ LLMTransformFactory replaceTransformFactory = new LLMTransformFactory();
+ Assertions.assertNotNull(replaceTransformFactory.optionRule());
}
}
diff --git a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
new file mode 100644
index 0000000000..f32cc87055
--- /dev/null
+++ b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.llm;
+
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode;
+
+import org.apache.seatunnel.api.table.type.BasicType;
+import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
+import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
+import org.apache.seatunnel.api.table.type.SqlType;
+import org.apache.seatunnel.transform.llm.model.openai.OpenAIModel;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+
+public class LLMRequestJsonTest {
+
+ private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+
+ @Test
+ void testOpenAIRequestJson() throws IOException {
+ SeaTunnelRowType rowType =
+ new SeaTunnelRowType(
+ new String[] {"id", "name"},
+ new SeaTunnelDataType[] {BasicType.INT_TYPE, BasicType.STRING_TYPE});
+ OpenAIModel model =
+ new OpenAIModel(
+ rowType,
+ SqlType.STRING,
+ "Determine whether someone is Chinese or American by their name",
+ "gpt-3.5-turbo",
+ "sk-xxx",
+ "https://api.openai.com/v1/chat/completions");
+ ObjectNode node =
+ model.createJsonNodeFromData(
+ "Determine whether someone is Chinese or American by their name",
+ "{\"id\":1, \"name\":\"John\"}");
+ Assertions.assertEquals(
+ "{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"system\",\"content\":\"Determine whether someone is Chinese or American by their name\"},{\"role\":\"user\",\"content\":\"{\\\"id\\\":1, \\\"name\\\":\\\"John\\\"}\"}]}",
+ OBJECT_MAPPER.writeValueAsString(node));
+ model.close();
+ }
+}
diff --git a/tools/dependencies/known-dependencies.txt b/tools/dependencies/known-dependencies.txt
index 161134511c..eda697369e 100755
--- a/tools/dependencies/known-dependencies.txt
+++ b/tools/dependencies/known-dependencies.txt
@@ -8,6 +8,8 @@ config-1.3.3.jar
disruptor-3.4.4.jar
guava-27.0-jre.jar
hazelcast-5.1.jar
+httpclient-4.5.13.jar
+httpcore-4.4.4.jar
jackson-annotations-2.13.3.jar
jackson-core-2.13.3.jar
jackson-databind-2.13.3.jar