This is an automated email from the ASF dual-hosted git repository.

wanghailin pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git


The following commit(s) were added to refs/heads/dev by this push:
     new 65c14aabea [Future][Transforms-V2] llm trans support field projection (#7621)
65c14aabea is described below

commit 65c14aabeaa39a12a670e33e40afb1b177d2d7ee
Author: zhangdonghao <[email protected]>
AuthorDate: Wed Sep 11 13:21:32 2024 +0800

    [Future][Transforms-V2] llm trans support field projection (#7621)
---
 docs/en/transform-v2/llm.md                        | 26 ++++++--
 docs/zh/transform-v2/llm.md                        | 43 ++++++++----
 .../apache/seatunnel/e2e/transform/TestLLMIT.java  |  7 ++
 .../resources/llm_openai_transform_columns.conf    | 76 ++++++++++++++++++++++
 .../transform/nlpmodel/llm/LLMTransform.java       |  2 +
 .../transform/nlpmodel/llm/LLMTransformConfig.java |  8 +++
 .../nlpmodel/llm/remote/AbstractModel.java         | 61 +++++++++++++++--
 .../nlpmodel/llm/remote/custom/CustomModel.java    |  3 +-
 .../nlpmodel/llm/remote/openai/OpenAIModel.java    |  3 +-
 .../transform/llm/LLMRequestJsonTest.java          | 41 ++++++++++++
 10 files changed, 247 insertions(+), 23 deletions(-)

diff --git a/docs/en/transform-v2/llm.md b/docs/en/transform-v2/llm.md
index 8caaad00a0..d1b8e6fc6e 100644
--- a/docs/en/transform-v2/llm.md
+++ b/docs/en/transform-v2/llm.md
@@ -11,17 +11,18 @@ more.
 ## Options
 
 | name                   | type   | required | default value |
-|------------------------|--------|----------|---------------|
+| ---------------------- | ------ | -------- | ------------- |
 | model_provider         | enum   | yes      |               |
 | output_data_type       | enum   | no       | String        |
 | prompt                 | string | yes      |               |
+| inference_columns   | list   | no       |               |
 | model                  | string | yes      |               |
 | api_key                | string | yes      |               |
 | api_path               | string | no       |               |
-| custom_config          | map    | no       |               | 
-| custom_response_parse  | string | no       |               | 
+| custom_config          | map    | no       |               |
+| custom_response_parse  | string | no       |               |
 | custom_request_headers | map    | no       |               |
-| custom_request_body    | map    | no       |               | 
+| custom_request_body    | map    | no       |               |
 
 ### model_provider
 
@@ -62,6 +63,23 @@ The result will be:
 | Eric          | 20  | American   |
 | Guangdong Liu | 20  | Chinese    |
 
+### inference_columns
+
+The `inference_columns` option allows you to specify which columns from the input data should be used as inputs for the LLM. By default, all columns will be used as inputs.
+
+For example:
+```hocon
+transform {
+  LLM {
+    model_provider = OPENAI
+    model = gpt-4o-mini
+    api_key = sk-xxx
+    inference_columns = ["name", "age"]
+    prompt = "Determine whether someone is Chinese or American by their name"
+  }
+}
+```
+
 ### model
 
 The model to use. Different model providers have different models. For example, the OpenAI model can be `gpt-4o-mini`.
diff --git a/docs/zh/transform-v2/llm.md b/docs/zh/transform-v2/llm.md
index 5efcf47125..c2d3c0f6ca 100644
--- a/docs/zh/transform-v2/llm.md
+++ b/docs/zh/transform-v2/llm.md
@@ -8,18 +8,19 @@
 
 ## 属性
 
-| 名称                     | 类型     | 是否必须 | 默认值    |
-|------------------------|--------|------|--------|
-| model_provider         | enum   | yes  |        |
-| output_data_type       | enum   | no   | String |
-| prompt                 | string | yes  |        |
-| model                  | string | yes  |        |
-| api_key                | string | yes  |        |
-| api_path               | string | no   |        |
-| custom_config          | map    | no   |        | 
-| custom_response_parse  | string | no   |        | 
-| custom_request_headers | map    | no   |        |
-| custom_request_body    | map    | no   |        | 
+| 名称                   | 类型   | 是否必须 | 默认值 |
+| ---------------------- | ------ | -------- | ------ |
+| model_provider         | enum   | yes      |        |
+| output_data_type       | enum   | no       | String |
+| prompt                 | string | yes      |        |
+| inference_columns   | list   | no       |        |
+| model                  | string | yes      |        |
+| api_key                | string | yes      |        |
+| api_path               | string | no       |        |
+| custom_config          | map    | no       |        |
+| custom_response_parse  | string | no       |        |
+| custom_request_headers | map    | no       |        |
+| custom_request_body    | map    | no       |        |
 
 ### model_provider
 
@@ -60,6 +61,23 @@ Determine whether someone is Chinese or American by their name
 | Eric          | 20  | American   |
 | Guangdong Liu | 20  | Chinese    |
 
+### inference_columns
+
+`inference_columns`选项允许您指定应该将输入数据中的哪些列用作LLM的输入。默认情况下,所有列都将用作输入。
+
+For example:
+```hocon
+transform {
+  LLM {
+    model_provider = OPENAI
+    model = gpt-4o-mini
+    api_key = sk-xxx
+    inference_columns = ["name", "age"]
+    prompt = "Determine whether someone is Chinese or American by their name"
+  }
+}
+```
+
 ### model
 
 要使用的模型。不同的模型提供者有不同的模型。例如,OpenAI 模型可以是 `gpt-4o-mini`。
@@ -253,4 +271,3 @@ sink {
 }
 ```
 
-
diff --git a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
index 712a6d7f90..244bca1e9c 100644
--- a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
+++ b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
@@ -93,6 +93,13 @@ public class TestLLMIT extends TestSuiteBase implements TestResource {
             throws IOException, InterruptedException {
         Container.ExecResult execResult =
                 container.executeJob("/llm_openai_transform_boolean.conf");
+    }
+
+    @TestTemplate
+    public void testLLMWithOpenAIColumns(TestContainer container)
+            throws IOException, InterruptedException {
+        Container.ExecResult execResult =
+                container.executeJob("/llm_openai_transform_columns.conf");
         Assertions.assertEquals(0, execResult.getExitCode());
     }
 
diff --git a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_openai_transform_columns.conf b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_openai_transform_columns.conf
new file mode 100644
index 0000000000..e4286ba762
--- /dev/null
+++ b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_openai_transform_columns.conf
@@ -0,0 +1,76 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+######
+###### This config file is a demonstration of streaming processing in seatunnel config
+######
+
+env {
+  job.mode = "BATCH"
+}
+
+source {
+  FakeSource {
+    row.num = 5
+    schema = {
+      fields {
+        id = "int"
+        name = "string"
+      }
+    }
+    rows = [
+      {fields = [1, "Jia Fan"], kind = INSERT}
+      {fields = [2, "Hailin Wang"], kind = INSERT}
+      {fields = [3, "Tomas"], kind = INSERT}
+      {fields = [4, "Eric"], kind = INSERT}
+      {fields = [5, "Guangdong Liu"], kind = INSERT}
+    ]
+    result_table_name = "fake"
+  }
+}
+
+transform {
+  LLM {
+    source_table_name = "fake"
+    model_provider = OPENAI
+    model = gpt-4o-mini
+    api_key = sk-xxx
+    inference_columns = ["name"]
+    prompt = "Determine whether someone is Chinese or American by their name"
+    openai.api_path = "http://mockserver:1080/v1/chat/completions"
+    result_table_name = "llm_output"
+  }
+}
+
+sink {
+  Assert {
+    source_table_name = "llm_output"
+    rules =
+      {
+        field_rules = [
+          {
+            field_name = llm_output
+            field_type = string
+            field_value = [
+              {
+                rule_type = NOT_NULL
+              }
+            ]
+          }
+        ]
+      }
+  }
+}
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
index 92db061ccc..705253a2fe 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
@@ -79,6 +79,7 @@ public class LLMTransform extends SingleFieldOutputTransform {
                         new CustomModel(
                                 inputCatalogTable.getSeaTunnelRowType(),
                                 outputDataType.getSqlType(),
+                                config.get(LLMTransformConfig.INFERENCE_COLUMNS),
                                 config.get(LLMTransformConfig.PROMPT),
                                 config.get(LLMTransformConfig.MODEL),
                                 provider.usedLLMPath(config.get(LLMTransformConfig.API_PATH)),
@@ -97,6 +98,7 @@ public class LLMTransform extends SingleFieldOutputTransform {
                         new OpenAIModel(
                                 inputCatalogTable.getSeaTunnelRowType(),
                                 outputDataType.getSqlType(),
+                                config.get(LLMTransformConfig.INFERENCE_COLUMNS),
                                 config.get(LLMTransformConfig.PROMPT),
                                 config.get(LLMTransformConfig.MODEL),
                                 config.get(LLMTransformConfig.API_KEY),
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformConfig.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformConfig.java
index 8800f061db..c45bfb8f39 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformConfig.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformConfig.java
@@ -21,6 +21,8 @@ import org.apache.seatunnel.api.configuration.Option;
 import org.apache.seatunnel.api.configuration.Options;
 import org.apache.seatunnel.transform.nlpmodel.ModelTransformConfig;
 
+import java.util.List;
+
 public class LLMTransformConfig extends ModelTransformConfig {
 
     public static final Option<String> PROMPT =
@@ -29,6 +31,12 @@ public class LLMTransformConfig extends ModelTransformConfig {
                     .noDefaultValue()
                     .withDescription("The prompt of LLM");
 
+    public static final Option<List<String>> INFERENCE_COLUMNS =
+            Options.key("inference_columns")
+                    .listType()
+                    .noDefaultValue()
+                    .withDescription("The row projection field of each inference");
+
     public static final Option<Integer> INFERENCE_BATCH_SIZE =
             Options.key("inference_batch_size")
                     .intType()
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/AbstractModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/AbstractModel.java
index 4ee271c408..5d0fcee637 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/AbstractModel.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/AbstractModel.java
@@ -21,25 +21,59 @@ import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.ObjectMapper;
 import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ArrayNode;
 import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode;
 
+import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
 import org.apache.seatunnel.api.table.type.SeaTunnelRow;
 import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
 import org.apache.seatunnel.api.table.type.SqlType;
 import org.apache.seatunnel.format.json.RowToJsonConverters;
 
+import com.google.common.annotations.VisibleForTesting;
+
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.List;
 
 public abstract class AbstractModel implements Model {
 
     protected static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
-    private final RowToJsonConverters.RowToJsonConverter rowToJsonConverters;
+    private final RowToJsonConverters.RowToJsonConverter rowToJsonConverter;
+    private final SeaTunnelRowType rowType;
     private final String prompt;
     private final SqlType outputType;
+    private final List<String> projectionColumns;
 
-    public AbstractModel(SeaTunnelRowType rowType, SqlType outputType, String prompt) {
+    public AbstractModel(
+            SeaTunnelRowType rowType,
+            SqlType outputType,
+            List<String> projectionColumns,
+            String prompt) {
+        this.rowType = rowType;
         this.prompt = prompt;
         this.outputType = outputType;
-        this.rowToJsonConverters = new RowToJsonConverters().createConverter(rowType, null);
+        this.projectionColumns = projectionColumns;
+        this.rowToJsonConverter = getRowToJsonConverter();
+    }
+
+    public RowToJsonConverters.RowToJsonConverter getRowToJsonConverter() {
+        RowToJsonConverters converters = new RowToJsonConverters();
+        if (projectionColumns != null && !projectionColumns.isEmpty()) {
+            List<SeaTunnelDataType> fieldTypes = new ArrayList<>();
+            for (String fieldName : projectionColumns) {
+                int fieldIndex = rowType.indexOf(fieldName);
+                if (fieldIndex != -1) {
+                    fieldTypes.add(rowType.getFieldType(fieldIndex));
+                } else {
+                    throw new IllegalArgumentException(
+                            "Field name " + fieldName + " does not exist in the row type.");
+                }
+            }
+            SeaTunnelRowType projectionRowType =
+                    new SeaTunnelRowType(
+                            projectionColumns.toArray(new String[0]),
+                            fieldTypes.toArray(new SeaTunnelDataType[0]));
+            return converters.createConverter(projectionRowType, null);
+        }
+        return converters.createConverter(rowType, null);
     }
 
     private String getPromptWithLimit() {
@@ -58,12 +92,31 @@ public abstract class AbstractModel implements Model {
         ArrayNode rowsNode = OBJECT_MAPPER.createArrayNode();
         for (SeaTunnelRow row : rows) {
             ObjectNode rowNode = OBJECT_MAPPER.createObjectNode();
-            rowToJsonConverters.convert(OBJECT_MAPPER, rowNode, row);
+            rowToJsonConverter.convert(OBJECT_MAPPER, rowNode, createProjectionSeaTunnelRow(row));
             rowsNode.add(rowNode);
         }
         return chatWithModel(getPromptWithLimit(), OBJECT_MAPPER.writeValueAsString(rowsNode));
     }
 
+    @VisibleForTesting
+    public SeaTunnelRow createProjectionSeaTunnelRow(SeaTunnelRow row) {
+        if (row == null || projectionColumns == null || projectionColumns.isEmpty()) {
+            return row;
+        }
+        SeaTunnelRow projectionRow = new SeaTunnelRow(projectionColumns.size());
+        for (int i = 0; i < projectionColumns.size(); i++) {
+            String fieldName = projectionColumns.get(i);
+            int fieldIndex = rowType.indexOf(fieldName);
+            if (fieldIndex != -1) {
+                projectionRow.setField(i, row.getField(fieldIndex));
+            } else {
+                throw new IllegalArgumentException(
+                        "Field name " + fieldName + " does not exist in the row type.");
+            }
+        }
+        return projectionRow;
+    }
+
     protected abstract List<String> chatWithModel(String promptWithLimit, String rowsJson)
             throws IOException;
 
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/custom/CustomModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/custom/CustomModel.java
index af893e92dd..dfc2bfc868 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/custom/CustomModel.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/custom/CustomModel.java
@@ -56,13 +56,14 @@ public class CustomModel extends AbstractModel {
     public CustomModel(
             SeaTunnelRowType rowType,
             SqlType outputType,
+            List<String> projectionColumns,
             String prompt,
             String model,
             String apiPath,
             Map<String, String> header,
             Map<String, Object> body,
             String parse) {
-        super(rowType, outputType, prompt);
+        super(rowType, outputType, projectionColumns, prompt);
         this.apiPath = apiPath;
         this.model = model;
         this.header = header;
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/openai/OpenAIModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/openai/OpenAIModel.java
index 8dc12ec0cd..aeea00b49b 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/openai/OpenAIModel.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/openai/OpenAIModel.java
@@ -54,11 +54,12 @@ public class OpenAIModel extends AbstractModel {
     public OpenAIModel(
             SeaTunnelRowType rowType,
             SqlType outputType,
+            List<String> projectionColumns,
             String prompt,
             String model,
             String apiKey,
             String apiPath) {
-        super(rowType, outputType, prompt);
+        super(rowType, outputType, projectionColumns, prompt);
         this.apiKey = apiKey;
         this.apiPath = apiPath;
         this.model = model;
diff --git a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
index 2de785a1a8..97eb1b8f96 100644
--- a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
+++ b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
@@ -22,14 +22,18 @@ import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode
 
 import org.apache.seatunnel.api.table.type.BasicType;
 import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
+import org.apache.seatunnel.api.table.type.SeaTunnelRow;
 import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
 import org.apache.seatunnel.api.table.type.SqlType;
+import org.apache.seatunnel.format.json.RowToJsonConverters;
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.custom.CustomModel;
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.openai.OpenAIModel;
 
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
+import com.google.common.collect.Lists;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -50,6 +54,7 @@ public class LLMRequestJsonTest {
                 new OpenAIModel(
                         rowType,
                         SqlType.STRING,
+                        null,
                         "Determine whether someone is Chinese or American by their name",
                         "gpt-3.5-turbo",
                         "sk-xxx",
@@ -64,6 +69,41 @@ public class LLMRequestJsonTest {
         model.close();
     }
 
+    @Test
+    void testOpenAIProjectionRequestJson() throws IOException {
+        SeaTunnelRowType rowType =
+                new SeaTunnelRowType(
+                        new String[] {"id", "name", "city"},
+                        new SeaTunnelDataType[] {
+                            BasicType.INT_TYPE, BasicType.STRING_TYPE, BasicType.STRING_TYPE
+                        });
+        OpenAIModel model =
+                new OpenAIModel(
+                        rowType,
+                        SqlType.STRING,
+                        Lists.newArrayList("name", "city"),
+                        "Determine whether someone is Chinese or American by their name",
+                        "gpt-3.5-turbo",
+                        "sk-xxx",
+                        "https://api.openai.com/v1/chat/completions");
+
+        SeaTunnelRow row = new SeaTunnelRow(rowType.getFieldTypes().length);
+        row.setField(0, 1);
+        row.setField(1, "John");
+        row.setField(2, "New York");
+        ObjectNode rowNode = OBJECT_MAPPER.createObjectNode();
+        RowToJsonConverters.RowToJsonConverter rowToJsonConverter = model.getRowToJsonConverter();
+        rowToJsonConverter.convert(OBJECT_MAPPER, rowNode, model.createProjectionSeaTunnelRow(row));
+        ObjectNode node =
+                model.createJsonNodeFromData(
+                        "Determine whether someone is Chinese or American by their name",
+                        OBJECT_MAPPER.writeValueAsString(rowNode));
+        Assertions.assertEquals(
+                "{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"system\",\"content\":\"Determine whether someone is Chinese or American by their name\"},{\"role\":\"user\",\"content\":\"{\\\"name\\\":\\\"John\\\",\\\"city\\\":\\\"New York\\\"}\"}]}",
+                OBJECT_MAPPER.writeValueAsString(node));
+        model.close();
+    }
+
     @Test
     void testCustomRequestJson() throws IOException {
         SeaTunnelRowType rowType =
@@ -95,6 +135,7 @@ public class LLMRequestJsonTest {
                 new CustomModel(
                         rowType,
                         SqlType.STRING,
+                        null,
                         "Determine whether someone is Chinese or American by their name",
                         "custom-model",
                         "https://api.custom.com/v1/chat/completions",

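For readers who want to see the projection step in isolation, here is a minimal sketch of what `createProjectionSeaTunnelRow` in the diff above appears to do. It is an illustration under stated assumptions, not the SeaTunnel implementation: the class `ProjectionSketch` and its `project` method are hypothetical, and plain arrays stand in for `SeaTunnelRowType` and `SeaTunnelRow` so the example runs without any SeaTunnel dependency.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-alone sketch; not part of the commit.
public class ProjectionSketch {

    // Returns the values of the requested columns, in the requested order.
    // A null or empty projection returns the row unchanged, which matches
    // the "all columns are used by default" behavior described in the docs.
    public static Object[] project(
            String[] fieldNames, Object[] row, List<String> projectionColumns) {
        if (row == null || projectionColumns == null || projectionColumns.isEmpty()) {
            return row;
        }
        List<String> names = Arrays.asList(fieldNames);
        Object[] projected = new Object[projectionColumns.size()];
        for (int i = 0; i < projectionColumns.size(); i++) {
            String fieldName = projectionColumns.get(i);
            int fieldIndex = names.indexOf(fieldName);
            if (fieldIndex == -1) {
                // Same failure mode as the diff: unknown column names are rejected.
                throw new IllegalArgumentException(
                        "Field name " + fieldName + " does not exist in the row type.");
            }
            projected[i] = row[fieldIndex];
        }
        return projected;
    }

    public static void main(String[] args) {
        String[] fieldNames = {"id", "name", "city"};
        Object[] row = {1, "John", "New York"};
        Object[] projected = project(fieldNames, row, Arrays.asList("name", "city"));
        System.out.println(Arrays.toString(projected)); // prints [John, New York]
    }
}
```

The sample row mirrors the one in `testOpenAIProjectionRequestJson`: with `inference_columns = ["name", "city"]`, the `id` value is dropped before the row is serialized for the model.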