This is an automated email from the ASF dual-hosted git repository.

aloyszhang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/inlong.git


The following commit(s) were added to refs/heads/master by this push:
     new 7b322a7aa1 [INLONG-9473][Sort] Support transform of embedding for LLM 
applications (#9474)
7b322a7aa1 is described below

commit 7b322a7aa1437cb42e9e1d04a382547cca24296c
Author: AloysZhang <[email protected]>
AuthorDate: Thu Dec 14 19:08:02 2023 +0800

    [INLONG-9473][Sort] Support transform of embedding for LLM applications 
(#9474)
    
    * [INLONG-9473][Sort] Support transform of embedding for LLM applications
    
    * fix spotless issue
    
    * fix log
---
 .../inlong/sort/function/EmbeddingFunction.java    | 105 +++++++++++++++
 .../sort/function/embedding/EmbeddingInput.java    |  50 +++++++
 .../sort/function/embedding/LanguageModel.java     |  71 ++++++++++
 .../inlong/sort/parser/impl/FlinkSqlParser.java    |   3 +
 .../sort/parser/impl/NativeFlinkSqlParser.java     |   4 +
 .../sort/function/EmbeddingFunctionTest.java       | 150 +++++++++++++++++++++
 6 files changed, 383 insertions(+)

diff --git 
a/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/EmbeddingFunction.java
 
b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/EmbeddingFunction.java
new file mode 100644
index 0000000000..fc04c36b1a
--- /dev/null
+++ 
b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/EmbeddingFunction.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.inlong.sort.function;
+
+import org.apache.inlong.sort.function.embedding.EmbeddingInput;
+import org.apache.inlong.sort.function.embedding.LanguageModel;
+
+import com.google.common.base.Strings;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.flink.table.functions.ScalarFunction;
+import org.apache.http.HttpHeaders;
+import org.apache.http.HttpResponse;
+import org.apache.http.HttpStatus;
+import org.apache.http.client.HttpClient;
+import org.apache.http.client.config.RequestConfig;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.entity.StringEntity;
+import org.apache.http.impl.client.HttpClientBuilder;
+import org.apache.http.util.EntityUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Embedding function for LLM applications.
+ * */
+public class EmbeddingFunction extends ScalarFunction {
+
+    public static final Logger logger = 
LoggerFactory.getLogger(EmbeddingFunction.class);
+    public static final String DEFAULT_EMBEDDING_FUNCTION_NAME = "EMBEDDING";
+
+    private final ObjectMapper mapper = new ObjectMapper();
+    public static final int DEFAULT_CONNECT_TIMEOUT = 30000;
+    public static final int DEFAULT_SOCKET_TIMEOUT = 30000;
+    public static final String DEFAULT_MODEL = 
LanguageModel.BBAI_ZH.getModel();
+    private transient HttpClient httpClient;
+
+    /**
+     * Embedding a LLM document(a String object for now) via http protocol
+     * @param url the service url for embedding service
+     * @param input the source data for embedding
+     * @param model the language model supported in the embedding service
+     * */
+    public String eval(String url, String input, String model) {
+        // url and input is not null
+        if (Strings.isNullOrEmpty(url) || Strings.isNullOrEmpty(input)) {
+            logger.error("Failed to embedding, both url and input can't be 
empty or null, url: {}, input: {}",
+                    url, input);
+            return null;
+        }
+
+        if (Strings.isNullOrEmpty(model)) {
+            model = DEFAULT_MODEL;
+            logger.info("model is null, use default model: {}", model);
+        }
+
+        if (!LanguageModel.isLanguageModelSupported(model)) {
+            logger.error("Failed to embedding, language model {} not 
supported(only {} are supported right now)",
+                    model, LanguageModel.getAllSupportedLanguageModels());
+            return null;
+        }
+
+        // initialize httpClient
+        if (httpClient == null) {
+            RequestConfig requestConfig = RequestConfig.custom()
+                    .setConnectTimeout(DEFAULT_CONNECT_TIMEOUT)
+                    .setSocketTimeout(DEFAULT_SOCKET_TIMEOUT)
+                    .build();
+            httpClient = 
HttpClientBuilder.create().setDefaultRequestConfig(requestConfig).build();
+        }
+
+        try {
+            HttpPost httpPost = new HttpPost(url);
+            httpPost.setHeader(HttpHeaders.CONTENT_TYPE, "application/json");
+            EmbeddingInput embeddingInput = new EmbeddingInput(input, model);
+            String encodedContents = mapper.writeValueAsString(embeddingInput);
+            httpPost.setEntity(new StringEntity(encodedContents));
+            HttpResponse response = httpClient.execute(httpPost);
+
+            String returnStr = EntityUtils.toString(response.getEntity());
+            int returnCode = response.getStatusLine().getStatusCode();
+            if (Strings.isNullOrEmpty(returnStr) || HttpStatus.SC_OK != 
returnCode) {
+                throw new Exception("Failed to embedding, result: " + 
returnStr + ", code: " + returnCode);
+            }
+            return returnStr;
+        } catch (Exception e) {
+            logger.error("Failed to embedding, url: {}, input: {}", url, 
input, e);
+            return null;
+        }
+    }
+}
\ No newline at end of file
diff --git 
a/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/embedding/EmbeddingInput.java
 
b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/embedding/EmbeddingInput.java
new file mode 100644
index 0000000000..f040db31d7
--- /dev/null
+++ 
b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/embedding/EmbeddingInput.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.inlong.sort.function.embedding;
+
+import java.io.Serializable;
+
/**
 * Request payload sent to the embedding service.
 *
 * <p>Holds the text to embed ({@code input}) and the language model to use ({@code model}).
 * It is serialized to JSON through its getters.
 */
public class EmbeddingInput implements Serializable {

    // Field order is kept as input, then model: the serialized JSON is expected in that order.
    private String input;
    private String model;

    /**
     * Creates a payload.
     *
     * @param input the text to embed
     * @param model the language model identifier
     */
    public EmbeddingInput(String input, String model) {
        this.input = input;
        this.model = model;
    }

    /** @return the text to embed */
    public String getInput() {
        return this.input;
    }

    /** @param input the text to embed */
    public void setInput(String input) {
        this.input = input;
    }

    /** @return the language model identifier */
    public String getModel() {
        return this.model;
    }

    /** @param model the language model identifier */
    public void setModel(String model) {
        this.model = model;
    }
}
diff --git 
a/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/embedding/LanguageModel.java
 
b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/embedding/LanguageModel.java
new file mode 100644
index 0000000000..aebe9cedd0
--- /dev/null
+++ 
b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/embedding/LanguageModel.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.inlong.sort.function.embedding;
+
+import com.google.common.base.Strings;
+
+/**
+ * Supported language model for embedding.
+ */
+public enum LanguageModel {
+
+    /**
+     * Language model for BBAI zh, chinese
+     * */
+    BBAI_ZH("BAAI/bge-large-zh-v1.5"),
+    /**
+     * Language model for BBAI en, english
+     * */
+    BBAI_EN("BAAI/bge-large-en"),
+    /**
+     * Language model for intfloat multi-language
+     * */
+    INTFLOAT_MULTI("intfloat/multilingual-e5-large");
+    String model;
+
+    LanguageModel(String s) {
+        this.model = s;
+    }
+
+    public String getModel() {
+        return this.model;
+    }
+
+    public static boolean isLanguageModelSupported(String s) {
+        if (Strings.isNullOrEmpty(s)) {
+            return false;
+        }
+        for (LanguageModel lm : LanguageModel.values()) {
+            if (s.equalsIgnoreCase(lm.getModel())) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    public static String getAllSupportedLanguageModels() {
+        if (LanguageModel.values().length == 0) {
+            return null;
+        }
+        StringBuilder supportedLMBuilder = new StringBuilder();
+        for (LanguageModel lm : LanguageModel.values()) {
+            supportedLMBuilder.append(lm.getModel()).append(",");
+        }
+        return supportedLMBuilder.substring(0, supportedLMBuilder.length() - 
1);
+    }
+}
diff --git 
a/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/FlinkSqlParser.java
 
b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/FlinkSqlParser.java
index 62369e1d54..7604549507 100644
--- 
a/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/FlinkSqlParser.java
+++ 
b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/FlinkSqlParser.java
@@ -23,6 +23,7 @@ import org.apache.inlong.sort.formats.common.ArrayFormatInfo;
 import org.apache.inlong.sort.formats.common.FormatInfo;
 import org.apache.inlong.sort.formats.common.MapFormatInfo;
 import org.apache.inlong.sort.formats.common.RowFormatInfo;
+import org.apache.inlong.sort.function.EmbeddingFunction;
 import org.apache.inlong.sort.function.EncryptFunction;
 import org.apache.inlong.sort.function.JsonGetterFunction;
 import org.apache.inlong.sort.function.RegexpReplaceFirstFunction;
@@ -73,6 +74,7 @@ import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
 import static org.apache.inlong.common.util.MaskDataUtils.maskSensitiveMessage;
+import static 
org.apache.inlong.sort.function.EmbeddingFunction.DEFAULT_EMBEDDING_FUNCTION_NAME;
 
 /**
  * Flink sql parse handler
@@ -122,6 +124,7 @@ public class FlinkSqlParser implements Parser {
         tableEnv.createTemporarySystemFunction("REGEXP_REPLACE", 
RegexpReplaceFunction.class);
         tableEnv.createTemporarySystemFunction("ENCRYPT", 
EncryptFunction.class);
         tableEnv.createTemporarySystemFunction("JSON_GETTER", 
JsonGetterFunction.class);
+        
tableEnv.createTemporarySystemFunction(DEFAULT_EMBEDDING_FUNCTION_NAME, 
EmbeddingFunction.class);
     }
 
     /**
diff --git 
a/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/NativeFlinkSqlParser.java
 
b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/NativeFlinkSqlParser.java
index 11dd43347c..1e254dabba 100644
--- 
a/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/NativeFlinkSqlParser.java
+++ 
b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/NativeFlinkSqlParser.java
@@ -17,6 +17,7 @@
 
 package org.apache.inlong.sort.parser.impl;
 
+import org.apache.inlong.sort.function.EmbeddingFunction;
 import org.apache.inlong.sort.function.EncryptFunction;
 import org.apache.inlong.sort.function.JsonGetterFunction;
 import org.apache.inlong.sort.function.RegexpReplaceFirstFunction;
@@ -34,6 +35,8 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.Locale;
 
+import static 
org.apache.inlong.sort.function.EmbeddingFunction.DEFAULT_EMBEDDING_FUNCTION_NAME;
+
 /**
  * parse flink sql script file
  * script file include CREATE TABLE statement
@@ -70,6 +73,7 @@ public class NativeFlinkSqlParser implements Parser {
         tableEnv.createTemporarySystemFunction("REGEXP_REPLACE", 
RegexpReplaceFunction.class);
         tableEnv.createTemporarySystemFunction("ENCRYPT", 
EncryptFunction.class);
         tableEnv.createTemporarySystemFunction("JSON_GETTER", 
JsonGetterFunction.class);
+        
tableEnv.createTemporarySystemFunction(DEFAULT_EMBEDDING_FUNCTION_NAME, 
EmbeddingFunction.class);
     }
 
     /**
diff --git 
a/inlong-sort/sort-core/src/test/java/org/apache/inlong/sort/function/EmbeddingFunctionTest.java
 
b/inlong-sort/sort-core/src/test/java/org/apache/inlong/sort/function/EmbeddingFunctionTest.java
new file mode 100644
index 0000000000..0c14ca9483
--- /dev/null
+++ 
b/inlong-sort/sort-core/src/test/java/org/apache/inlong/sort/function/EmbeddingFunctionTest.java
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.inlong.sort.function;
+
+import org.apache.inlong.sort.function.embedding.EmbeddingInput;
+import org.apache.inlong.sort.function.embedding.LanguageModel;
+
+import com.sun.net.httpserver.HttpServer;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.EnvironmentSettings;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.net.InetSocketAddress;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Locale;
+
+import static 
org.apache.inlong.sort.function.EmbeddingFunction.DEFAULT_EMBEDDING_FUNCTION_NAME;
+
+public class EmbeddingFunctionTest extends AbstractTestBase {
+
+    @Test
+    public void testMapper() throws Exception {
+        ObjectMapper mapper = new ObjectMapper();
+        EmbeddingInput embeddingInput = new EmbeddingInput("Input-Test", 
"Model-Test");
+        String encodedContents = mapper.writeValueAsString(embeddingInput);
+        String expect = "{\"input\":\"Input-Test\",\"model\":\"Model-Test\"}";
+        Assert.assertEquals(encodedContents, expect);
+    }
+
+    @Test
+    public void testLanguageModel() {
+        String supportedLMs = LanguageModel.getAllSupportedLanguageModels();
+        Assert.assertNotNull(supportedLMs);
+        String[] supportLMArray = supportedLMs.split(",");
+        Assert.assertEquals(supportLMArray.length, 
LanguageModel.values().length);
+
+        
Assert.assertTrue(LanguageModel.isLanguageModelSupported("BAAI/bge-large-zh-v1.5"));
+        
Assert.assertTrue(LanguageModel.isLanguageModelSupported("BAAI/bge-large-en"));
+        
Assert.assertTrue(LanguageModel.isLanguageModelSupported("intfloat/multilingual-e5-large"));
+        
Assert.assertFalse(LanguageModel.isLanguageModelSupported("fake/fake-language"));
+    }
+
+    /**
+     * Test for embedding function
+     *
+     * @throws Exception The exception may throw when test Embedding function
+     */
+    @Test
+    public void testEmbeddingFunction() throws Exception {
+        EnvironmentSettings settings = EnvironmentSettings
+                .newInstance()
+                .inStreamingMode()
+                .build();
+        StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+        env.setParallelism(1);
+        env.enableCheckpointing(10000);
+        StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env, 
settings);
+
+        // step 1. Register custom function of Embedding
+        tableEnv.createTemporaryFunction(DEFAULT_EMBEDDING_FUNCTION_NAME, 
EmbeddingFunction.class);
+
+        List<String> udfNames = 
Arrays.asList(tableEnv.listUserDefinedFunctions());
+        
Assert.assertTrue(udfNames.contains(DEFAULT_EMBEDDING_FUNCTION_NAME.toLowerCase(Locale.ROOT)));
+
+        // step 2. Generate test data and convert to DataStream
+        int numOfMessages = 100;
+        List<String> sourceDateList = new ArrayList<>();
+        String msgPrefix = "Data for embedding-";
+        for (int i = 0; i < numOfMessages; i++) {
+            sourceDateList.add(msgPrefix + i);
+        }
+
+        List<Row> data = new ArrayList<>();
+        sourceDateList.forEach(s -> data.add(Row.of(s)));
+        TypeInformation<?>[] types = {BasicTypeInfo.STRING_TYPE_INFO};
+        String[] names = {"f1"};
+        RowTypeInfo typeInfo = new RowTypeInfo(types, names);
+        DataStream<Row> dataStream = 
env.fromCollection(data).returns(typeInfo);
+
+        // step 3. start a web server to mock embedding service
+        String embeddingResult = "{\"result\": \"Result data for embedding\"}";
+        HttpServer httpServer = HttpServer.create(new InetSocketAddress(8899), 
0); // or use InetSocketAddress(0) for
+                                                                               
    // ephemeral port
+        httpServer.createContext("/get_embedding", exchange -> {
+            byte[] response = embeddingResult.getBytes(StandardCharsets.UTF_8);
+            exchange.sendResponseHeaders(200, response.length);
+            exchange.getResponseBody().write(response);
+            exchange.close();
+        });
+        httpServer.start();
+
+        // step 4. Convert from DataStream to Table and execute the Embedding 
function
+        Table tempView = tableEnv.fromDataStream(dataStream).as("f1");
+        tableEnv.createTemporaryView("temp_view", tempView);
+        Table outputTable = tableEnv.sqlQuery(
+                "SELECT " +
+                        "f1," +
+                        "EMBEDDING('http://localhost:8899/get_embedding', f1, 
'BAAI/bge-large-en') " +
+                        "from temp_view");
+
+        // step 5. Get function execution result and parse it
+        DataStream<Row> resultSet = tableEnv.toAppendStream(outputTable, 
Row.class);
+        List<String> resultF0 = new ArrayList<>();
+        List<String> resultF1 = new ArrayList<>();
+        for (CloseableIterator<Row> it = resultSet.executeAndCollect(); 
it.hasNext();) {
+            Row row = it.next();
+            if (row != null) {
+                resultF0.add(row.getField(0).toString());
+                resultF1.add(row.getField(1).toString());
+            }
+        }
+        Assert.assertEquals(resultF0.size(), numOfMessages);
+        Assert.assertEquals(resultF1.size(), numOfMessages);
+        Assert.assertEquals(resultF0, sourceDateList);
+        for (String res : resultF1) {
+            Assert.assertEquals(res, embeddingResult);
+        }
+
+        httpServer.stop(0);
+    }
+}

Reply via email to