This is an automated email from the ASF dual-hosted git repository.
wanghailin pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git
The following commit(s) were added to refs/heads/dev by this push:
new 65c14aabea [Feature][Transforms-V2] llm trans support field projection
(#7621)
65c14aabea is described below
commit 65c14aabeaa39a12a670e33e40afb1b177d2d7ee
Author: zhangdonghao <[email protected]>
AuthorDate: Wed Sep 11 13:21:32 2024 +0800
[Feature][Transforms-V2] llm trans support field projection (#7621)
---
docs/en/transform-v2/llm.md | 26 ++++++--
docs/zh/transform-v2/llm.md | 43 ++++++++----
.../apache/seatunnel/e2e/transform/TestLLMIT.java | 7 ++
.../resources/llm_openai_transform_columns.conf | 76 ++++++++++++++++++++++
.../transform/nlpmodel/llm/LLMTransform.java | 2 +
.../transform/nlpmodel/llm/LLMTransformConfig.java | 8 +++
.../nlpmodel/llm/remote/AbstractModel.java | 61 +++++++++++++++--
.../nlpmodel/llm/remote/custom/CustomModel.java | 3 +-
.../nlpmodel/llm/remote/openai/OpenAIModel.java | 3 +-
.../transform/llm/LLMRequestJsonTest.java | 41 ++++++++++++
10 files changed, 247 insertions(+), 23 deletions(-)
diff --git a/docs/en/transform-v2/llm.md b/docs/en/transform-v2/llm.md
index 8caaad00a0..d1b8e6fc6e 100644
--- a/docs/en/transform-v2/llm.md
+++ b/docs/en/transform-v2/llm.md
@@ -11,17 +11,18 @@ more.
## Options
| name | type | required | default value |
-|------------------------|--------|----------|---------------|
+| ---------------------- | ------ | -------- | ------------- |
| model_provider | enum | yes | |
| output_data_type | enum | no | String |
| prompt | string | yes | |
+| inference_columns | list | no | |
| model | string | yes | |
| api_key | string | yes | |
| api_path | string | no | |
-| custom_config | map | no | |
-| custom_response_parse | string | no | |
+| custom_config | map | no | |
+| custom_response_parse | string | no | |
| custom_request_headers | map | no | |
-| custom_request_body | map | no | |
+| custom_request_body | map | no | |
### model_provider
@@ -62,6 +63,23 @@ The result will be:
| Eric | 20 | American |
| Guangdong Liu | 20 | Chinese |
+### inference_columns
+
+The `inference_columns` option allows you to specify which columns from the
input data should be used as inputs for the LLM. By default, all columns will
be used as inputs.
+
+For example:
+```hocon
+transform {
+ LLM {
+ model_provider = OPENAI
+ model = gpt-4o-mini
+ api_key = sk-xxx
+ inference_columns = ["name", "age"]
+ prompt = "Determine whether someone is Chinese or American by their name"
+ }
+}
+```
+
### model
The model to use. Different model providers have different models. For
example, the OpenAI model can be `gpt-4o-mini`.
diff --git a/docs/zh/transform-v2/llm.md b/docs/zh/transform-v2/llm.md
index 5efcf47125..c2d3c0f6ca 100644
--- a/docs/zh/transform-v2/llm.md
+++ b/docs/zh/transform-v2/llm.md
@@ -8,18 +8,19 @@
## 属性
-| 名称 | 类型 | 是否必须 | 默认值 |
-|------------------------|--------|------|--------|
-| model_provider | enum | yes | |
-| output_data_type | enum | no | String |
-| prompt | string | yes | |
-| model | string | yes | |
-| api_key | string | yes | |
-| api_path | string | no | |
-| custom_config | map | no | |
-| custom_response_parse | string | no | |
-| custom_request_headers | map | no | |
-| custom_request_body | map | no | |
+| 名称 | 类型 | 是否必须 | 默认值 |
+| ---------------------- | ------ | -------- | ------ |
+| model_provider | enum | yes | |
+| output_data_type | enum | no | String |
+| prompt | string | yes | |
+| inference_columns | list | no | |
+| model | string | yes | |
+| api_key | string | yes | |
+| api_path | string | no | |
+| custom_config | map | no | |
+| custom_response_parse | string | no | |
+| custom_request_headers | map | no | |
+| custom_request_body | map | no | |
### model_provider
@@ -60,6 +61,23 @@ Determine whether someone is Chinese or American by their
name
| Eric | 20 | American |
| Guangdong Liu | 20 | Chinese |
+### inference_columns
+
+`inference_columns`选项允许您指定应该将输入数据中的哪些列用作LLM的输入。默认情况下,所有列都将用作输入。
+
+例如:
+```hocon
+transform {
+ LLM {
+ model_provider = OPENAI
+ model = gpt-4o-mini
+ api_key = sk-xxx
+ inference_columns = ["name", "age"]
+ prompt = "Determine whether someone is Chinese or American by their name"
+ }
+}
+```
+
### model
要使用的模型。不同的模型提供者有不同的模型。例如,OpenAI 模型可以是 `gpt-4o-mini`。
@@ -253,4 +271,3 @@ sink {
}
```
-
diff --git
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
index 712a6d7f90..244bca1e9c 100644
---
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
+++
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
@@ -93,6 +93,14 @@ public class TestLLMIT extends TestSuiteBase implements
TestResource {
throws IOException, InterruptedException {
Container.ExecResult execResult =
container.executeJob("/llm_openai_transform_boolean.conf");
+ Assertions.assertEquals(0, execResult.getExitCode());
+ }
+
+ @TestTemplate
+ public void testLLMWithOpenAIColumns(TestContainer container)
+ throws IOException, InterruptedException {
+ Container.ExecResult execResult =
+ container.executeJob("/llm_openai_transform_columns.conf");
Assertions.assertEquals(0, execResult.getExitCode());
}
diff --git
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_openai_transform_columns.conf
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_openai_transform_columns.conf
new file mode 100644
index 0000000000..e4286ba762
--- /dev/null
+++
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_openai_transform_columns.conf
@@ -0,0 +1,76 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+######
+###### This config file is a demonstration of streaming processing in
seatunnel config
+######
+
+env {
+ job.mode = "BATCH"
+}
+
+source {
+ FakeSource {
+ row.num = 5
+ schema = {
+ fields {
+ id = "int"
+ name = "string"
+ }
+ }
+ rows = [
+ {fields = [1, "Jia Fan"], kind = INSERT}
+ {fields = [2, "Hailin Wang"], kind = INSERT}
+ {fields = [3, "Tomas"], kind = INSERT}
+ {fields = [4, "Eric"], kind = INSERT}
+ {fields = [5, "Guangdong Liu"], kind = INSERT}
+ ]
+ result_table_name = "fake"
+ }
+}
+
+transform {
+ LLM {
+ source_table_name = "fake"
+ model_provider = OPENAI
+ model = gpt-4o-mini
+ api_key = sk-xxx
+ inference_columns = ["name"]
+ prompt = "Determine whether someone is Chinese or American by their name"
+ openai.api_path = "http://mockserver:1080/v1/chat/completions"
+ result_table_name = "llm_output"
+ }
+}
+
+sink {
+ Assert {
+ source_table_name = "llm_output"
+ rules =
+ {
+ field_rules = [
+ {
+ field_name = llm_output
+ field_type = string
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ }
+ ]
+ }
+ }
+}
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
index 92db061ccc..705253a2fe 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
@@ -79,6 +79,7 @@ public class LLMTransform extends SingleFieldOutputTransform {
new CustomModel(
inputCatalogTable.getSeaTunnelRowType(),
outputDataType.getSqlType(),
+
config.get(LLMTransformConfig.INFERENCE_COLUMNS),
config.get(LLMTransformConfig.PROMPT),
config.get(LLMTransformConfig.MODEL),
provider.usedLLMPath(config.get(LLMTransformConfig.API_PATH)),
@@ -97,6 +98,7 @@ public class LLMTransform extends SingleFieldOutputTransform {
new OpenAIModel(
inputCatalogTable.getSeaTunnelRowType(),
outputDataType.getSqlType(),
+
config.get(LLMTransformConfig.INFERENCE_COLUMNS),
config.get(LLMTransformConfig.PROMPT),
config.get(LLMTransformConfig.MODEL),
config.get(LLMTransformConfig.API_KEY),
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformConfig.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformConfig.java
index 8800f061db..c45bfb8f39 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformConfig.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformConfig.java
@@ -21,6 +21,8 @@ import org.apache.seatunnel.api.configuration.Option;
import org.apache.seatunnel.api.configuration.Options;
import org.apache.seatunnel.transform.nlpmodel.ModelTransformConfig;
+import java.util.List;
+
public class LLMTransformConfig extends ModelTransformConfig {
public static final Option<String> PROMPT =
@@ -29,6 +31,12 @@ public class LLMTransformConfig extends ModelTransformConfig
{
.noDefaultValue()
.withDescription("The prompt of LLM");
+ public static final Option<List<String>> INFERENCE_COLUMNS =
+ Options.key("inference_columns")
+ .listType()
+ .noDefaultValue()
+ .withDescription("The row projection field of each
inference");
+
public static final Option<Integer> INFERENCE_BATCH_SIZE =
Options.key("inference_batch_size")
.intType()
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/AbstractModel.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/AbstractModel.java
index 4ee271c408..5d0fcee637 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/AbstractModel.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/AbstractModel.java
@@ -21,25 +21,59 @@ import
org.apache.seatunnel.shade.com.fasterxml.jackson.databind.ObjectMapper;
import
org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ArrayNode;
import
org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode;
+import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.api.table.type.SqlType;
import org.apache.seatunnel.format.json.RowToJsonConverters;
+import com.google.common.annotations.VisibleForTesting;
+
import java.io.IOException;
+import java.util.ArrayList;
import java.util.List;
public abstract class AbstractModel implements Model {
protected static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
- private final RowToJsonConverters.RowToJsonConverter rowToJsonConverters;
+ private final RowToJsonConverters.RowToJsonConverter rowToJsonConverter;
+ private final SeaTunnelRowType rowType;
private final String prompt;
private final SqlType outputType;
+ private final List<String> projectionColumns;
- public AbstractModel(SeaTunnelRowType rowType, SqlType outputType, String
prompt) {
+ public AbstractModel(
+ SeaTunnelRowType rowType,
+ SqlType outputType,
+ List<String> projectionColumns,
+ String prompt) {
+ this.rowType = rowType;
this.prompt = prompt;
this.outputType = outputType;
- this.rowToJsonConverters = new
RowToJsonConverters().createConverter(rowType, null);
+ this.projectionColumns = projectionColumns;
+ this.rowToJsonConverter = getRowToJsonConverter();
+ }
+
+ public RowToJsonConverters.RowToJsonConverter getRowToJsonConverter() {
+ RowToJsonConverters converters = new RowToJsonConverters();
+ if (projectionColumns != null && !projectionColumns.isEmpty()) {
+ List<SeaTunnelDataType> fieldTypes = new ArrayList<>();
+ for (String fieldName : projectionColumns) {
+ int fieldIndex = rowType.indexOf(fieldName);
+ if (fieldIndex != -1) {
+ fieldTypes.add(rowType.getFieldType(fieldIndex));
+ } else {
+ throw new IllegalArgumentException(
+ "Field name " + fieldName + " does not exist in
the row type.");
+ }
+ }
+ SeaTunnelRowType projectionRowType =
+ new SeaTunnelRowType(
+ projectionColumns.toArray(new String[0]),
+ fieldTypes.toArray(new SeaTunnelDataType[0]));
+ return converters.createConverter(projectionRowType, null);
+ }
+ return converters.createConverter(rowType, null);
}
private String getPromptWithLimit() {
@@ -58,12 +92,31 @@ public abstract class AbstractModel implements Model {
ArrayNode rowsNode = OBJECT_MAPPER.createArrayNode();
for (SeaTunnelRow row : rows) {
ObjectNode rowNode = OBJECT_MAPPER.createObjectNode();
- rowToJsonConverters.convert(OBJECT_MAPPER, rowNode, row);
+ rowToJsonConverter.convert(OBJECT_MAPPER, rowNode,
createProjectionSeaTunnelRow(row));
rowsNode.add(rowNode);
}
return chatWithModel(getPromptWithLimit(),
OBJECT_MAPPER.writeValueAsString(rowsNode));
}
+ @VisibleForTesting
+ public SeaTunnelRow createProjectionSeaTunnelRow(SeaTunnelRow row) {
+ if (row == null || projectionColumns == null ||
projectionColumns.isEmpty()) {
+ return row;
+ }
+ SeaTunnelRow projectionRow = new
SeaTunnelRow(projectionColumns.size());
+ for (int i = 0; i < projectionColumns.size(); i++) {
+ String fieldName = projectionColumns.get(i);
+ int fieldIndex = rowType.indexOf(fieldName);
+ if (fieldIndex != -1) {
+ projectionRow.setField(i, row.getField(fieldIndex));
+ } else {
+ throw new IllegalArgumentException(
+ "Field name " + fieldName + " does not exist in the
row type.");
+ }
+ }
+ return projectionRow;
+ }
+
protected abstract List<String> chatWithModel(String promptWithLimit,
String rowsJson)
throws IOException;
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/custom/CustomModel.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/custom/CustomModel.java
index af893e92dd..dfc2bfc868 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/custom/CustomModel.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/custom/CustomModel.java
@@ -56,13 +56,14 @@ public class CustomModel extends AbstractModel {
public CustomModel(
SeaTunnelRowType rowType,
SqlType outputType,
+ List<String> projectionColumns,
String prompt,
String model,
String apiPath,
Map<String, String> header,
Map<String, Object> body,
String parse) {
- super(rowType, outputType, prompt);
+ super(rowType, outputType, projectionColumns, prompt);
this.apiPath = apiPath;
this.model = model;
this.header = header;
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/openai/OpenAIModel.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/openai/OpenAIModel.java
index 8dc12ec0cd..aeea00b49b 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/openai/OpenAIModel.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/openai/OpenAIModel.java
@@ -54,11 +54,12 @@ public class OpenAIModel extends AbstractModel {
public OpenAIModel(
SeaTunnelRowType rowType,
SqlType outputType,
+ List<String> projectionColumns,
String prompt,
String model,
String apiKey,
String apiPath) {
- super(rowType, outputType, prompt);
+ super(rowType, outputType, projectionColumns, prompt);
this.apiKey = apiKey;
this.apiPath = apiPath;
this.model = model;
diff --git
a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
index 2de785a1a8..97eb1b8f96 100644
---
a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
+++
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
@@ -22,14 +22,18 @@ import
org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode
import org.apache.seatunnel.api.table.type.BasicType;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
+import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.api.table.type.SqlType;
+import org.apache.seatunnel.format.json.RowToJsonConverters;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.custom.CustomModel;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.openai.OpenAIModel;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
+import com.google.common.collect.Lists;
+
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
@@ -50,6 +54,7 @@ public class LLMRequestJsonTest {
new OpenAIModel(
rowType,
SqlType.STRING,
+ null,
"Determine whether someone is Chinese or American by
their name",
"gpt-3.5-turbo",
"sk-xxx",
@@ -64,6 +69,41 @@ public class LLMRequestJsonTest {
model.close();
}
+ @Test
+ void testOpenAIProjectionRequestJson() throws IOException {
+ SeaTunnelRowType rowType =
+ new SeaTunnelRowType(
+ new String[] {"id", "name", "city"},
+ new SeaTunnelDataType[] {
+ BasicType.INT_TYPE, BasicType.STRING_TYPE,
BasicType.STRING_TYPE
+ });
+ OpenAIModel model =
+ new OpenAIModel(
+ rowType,
+ SqlType.STRING,
+ Lists.newArrayList("name", "city"),
+ "Determine whether someone is Chinese or American by
their name",
+ "gpt-3.5-turbo",
+ "sk-xxx",
+ "https://api.openai.com/v1/chat/completions");
+
+ SeaTunnelRow row = new SeaTunnelRow(rowType.getFieldTypes().length);
+ row.setField(0, 1);
+ row.setField(1, "John");
+ row.setField(2, "New York");
+ ObjectNode rowNode = OBJECT_MAPPER.createObjectNode();
+ RowToJsonConverters.RowToJsonConverter rowToJsonConverter =
model.getRowToJsonConverter();
+ rowToJsonConverter.convert(OBJECT_MAPPER, rowNode,
model.createProjectionSeaTunnelRow(row));
+ ObjectNode node =
+ model.createJsonNodeFromData(
+ "Determine whether someone is Chinese or American by
their name",
+ OBJECT_MAPPER.writeValueAsString(rowNode));
+ Assertions.assertEquals(
+
"{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"system\",\"content\":\"Determine
whether someone is Chinese or American by their
name\"},{\"role\":\"user\",\"content\":\"{\\\"name\\\":\\\"John\\\",\\\"city\\\":\\\"New
York\\\"}\"}]}",
+ OBJECT_MAPPER.writeValueAsString(node));
+ model.close();
+ }
+
@Test
void testCustomRequestJson() throws IOException {
SeaTunnelRowType rowType =
@@ -95,6 +135,7 @@ public class LLMRequestJsonTest {
new CustomModel(
rowType,
SqlType.STRING,
+ null,
"Determine whether someone is Chinese or American by
their name",
"custom-model",
"https://api.custom.com/v1/chat/completions",