This is an automated email from the ASF dual-hosted git repository.

wanghailin pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git


The following commit(s) were added to refs/heads/dev by this push:
     new e0c99ace23 [Feature][Transforms-V2] Handling LLM non-standard format responses (#8551)
e0c99ace23 is described below

commit e0c99ace23c8d9f64f325c22819fe996e6edde7c
Author: 在下uptown <[email protected]>
AuthorDate: Mon Feb 17 20:29:07 2025 +0800

    [Feature][Transforms-V2] Handling LLM non-standard format responses (#8551)
---
 seatunnel-transforms-v2/pom.xml                    |  6 ++
 .../nlpmodel/llm/remote/custom/CustomModel.java    | 12 +++-
 .../transform/llm/LLMRequestJsonTest.java          | 77 +++++++++++++++++++++-
 3 files changed, 92 insertions(+), 3 deletions(-)

diff --git a/seatunnel-transforms-v2/pom.xml b/seatunnel-transforms-v2/pom.xml
index 4cbef9a4b8..f15c1aae40 100644
--- a/seatunnel-transforms-v2/pom.xml
+++ b/seatunnel-transforms-v2/pom.xml
@@ -92,6 +92,12 @@
             <artifactId>httpcore</artifactId>
             <version>${httpcore.version}</version>
         </dependency>
+        <dependency>
+            <groupId>com.squareup.okhttp3</groupId>
+            <artifactId>mockwebserver</artifactId>
+            <version>3.6.0</version>
+            <scope>test</scope>
+        </dependency>
     </dependencies>
 
     <build>
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/custom/CustomModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/custom/CustomModel.java
index e3670b1ebc..67cda24062 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/custom/CustomModel.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/custom/CustomModel.java
@@ -40,6 +40,7 @@ import org.apache.http.util.EntityUtils;
 import com.jayway.jsonpath.JsonPath;
 
 import java.io.IOException;
+import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
@@ -94,8 +95,15 @@ public class CustomModel extends AbstractModel {
         if (response.getStatusLine().getStatusCode() != 200) {
             throw new IOException("Failed to get vector from custom, response: " + responseStr);
         }
-        return OBJECT_MAPPER.convertValue(
-                parseResponse(responseStr), new TypeReference<List<String>>() {});
+        try {
+            return OBJECT_MAPPER.convertValue(
+                    parseResponse(responseStr), new TypeReference<List<String>>() {});
+        } catch (Exception e) {
+            String result =
+                    OBJECT_MAPPER.convertValue(
+                            parseResponse(responseStr), new TypeReference<String>() {});
+            return Collections.singletonList(result);
+        }
     }
 
     @VisibleForTesting
diff --git a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
index 6dd2ff4149..03bc017dbd 100644
--- a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
+++ b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
@@ -35,9 +35,13 @@ import org.apache.seatunnel.transform.nlpmodel.llm.remote.openai.OpenAIModel;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
+import okhttp3.mockwebserver.MockResponse;
+import okhttp3.mockwebserver.MockWebServer;
+
 import java.io.IOException;
 import java.lang.reflect.Field;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -172,7 +176,6 @@ public class LLMRequestJsonTest {
 
         Map<String, String> header = new HashMap<>();
         header.put("Content-Type", "application/json");
-        header.put("Authorization", "Bearer " + "apikey");
 
         List<Map<String, String>> messagesList = new ArrayList<>();
 
@@ -209,4 +212,76 @@ public class LLMRequestJsonTest {
                 "{\"messages\":[{\"role\":\"system\",\"content\":\"Determine whether someone is Chinese or American by their name\"},{\"role\":\"user\",\"content\":\"{\\\"id\\\":1, \\\"name\\\":\\\"John\\\"}\"}],\"model\":\"custom-model\"}",
                 OBJECT_MAPPER.writeValueAsString(node));
     }
+
+    @Test
+    void testCustomOllamaRequestJson() throws IOException {
+
+        MockWebServer mockWebServer = new MockWebServer();
+        mockWebServer.start(11434);
+        String jsonResponse =
+                "{\n"
+                        + "    \"model\": \"qwen:7b\",\n"
+                        + "    \"created_at\": \"2025-02-07T01:22:46.589856Z\",\n"
+                        + "    \"message\": {\n"
+                        + "        \"role\": \"assistant\",\n"
+                        + "        \"content\": \"Based on the information provided in the JSON object, \\\"John\\\" does not inherently indicate if the person is Chinese or American. The name \\\"John\\\" is commonly used across many cultures. To determine a person's nationality based solely on their name, more context would be needed.\"\n"
+                        + "    },\n"
+                        + "    \"done_reason\": \"stop\",\n"
+                        + "    \"done\": true,\n"
+                        + "    \"total_duration\": 14435322300,\n"
+                        + "    \"load_duration\": 28998200,\n"
+                        + "    \"prompt_eval_count\": 34,\n"
+                        + "    \"prompt_eval_duration\": 302000000,\n"
+                        + "    \"eval_count\": 56,\n"
+                        + "    \"eval_duration\": 14102000000\n"
+                        + "}";
+
+        mockWebServer.enqueue(
+                new MockResponse()
+                        .setBody(jsonResponse)
+                        .addHeader("Content-Type", "application/json"));
+
+        SeaTunnelRowType rowType =
+                new SeaTunnelRowType(
+                        new String[] {"id", "name"},
+                        new SeaTunnelDataType[] {BasicType.INT_TYPE, BasicType.STRING_TYPE});
+
+        Map<String, String> header = new HashMap<>();
+        header.put("Content-Type", "application/json");
+
+        List<Map<String, String>> messagesList = new ArrayList<>();
+
+        Map<String, String> systemMessage = new HashMap<>();
+        systemMessage.put("role", "system");
+        systemMessage.put("content", "${prompt}");
+        messagesList.add(systemMessage);
+
+        Map<String, String> userMessage = new HashMap<>();
+        userMessage.put("role", "user");
+        userMessage.put("content", "${input}");
+        messagesList.add(userMessage);
+
+        Map<String, Object> resultMap = new HashMap<>();
+        resultMap.put("model", "${model}");
+        resultMap.put("stream", false);
+        resultMap.put("messages", messagesList);
+
+        CustomModel model =
+                new CustomModel(
+                        rowType,
+                        SqlType.STRING,
+                        null,
+                        "Determine whether someone is Chinese or American by their name",
+                        "qwen:7b",
+                        "http://localhost:11434/api/chat",
+                        header,
+                        resultMap,
+                        "$.message.content");
+
+        SeaTunnelRow row = new SeaTunnelRow(rowType.getFieldTypes().length);
+        row.setField(0, 1);
+        row.setField(1, "John");
+        List<String> successResult = model.inference(Collections.singletonList(row));
+        Assertions.assertFalse(successResult.isEmpty());
+    }
 }

Reply via email to