This is an automated email from the ASF dual-hosted git repository.
aloyszhang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/inlong.git
The following commit(s) were added to refs/heads/master by this push:
new 7b322a7aa1 [INLONG-9473][Sort] Support transform of embedding for LLM
applications (#9474)
7b322a7aa1 is described below
commit 7b322a7aa1437cb42e9e1d04a382547cca24296c
Author: AloysZhang <[email protected]>
AuthorDate: Thu Dec 14 19:08:02 2023 +0800
[INLONG-9473][Sort] Support transform of embedding for LLM applications
(#9474)
* [INLONG-9473][Sort] Support transform of embedding for LLM applications
* fix spotless issue
* fix log
---
.../inlong/sort/function/EmbeddingFunction.java | 105 +++++++++++++++
.../sort/function/embedding/EmbeddingInput.java | 50 +++++++
.../sort/function/embedding/LanguageModel.java | 71 ++++++++++
.../inlong/sort/parser/impl/FlinkSqlParser.java | 3 +
.../sort/parser/impl/NativeFlinkSqlParser.java | 4 +
.../sort/function/EmbeddingFunctionTest.java | 150 +++++++++++++++++++++
6 files changed, 383 insertions(+)
diff --git
a/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/EmbeddingFunction.java
b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/EmbeddingFunction.java
new file mode 100644
index 0000000000..fc04c36b1a
--- /dev/null
+++
b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/EmbeddingFunction.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.inlong.sort.function;
+
import org.apache.inlong.sort.function.embedding.EmbeddingInput;
import org.apache.inlong.sort.function.embedding.LanguageModel;

import com.google.common.base.Strings;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.flink.table.functions.ScalarFunction;
import org.apache.http.HttpHeaders;
import org.apache.http.HttpResponse;
import org.apache.http.HttpStatus;
import org.apache.http.client.HttpClient;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.util.EntityUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Closeable;
import java.nio.charset.StandardCharsets;
+
+/**
+ * Embedding function for LLM applications.
+ * */
+public class EmbeddingFunction extends ScalarFunction {
+
+ public static final Logger logger =
LoggerFactory.getLogger(EmbeddingFunction.class);
+ public static final String DEFAULT_EMBEDDING_FUNCTION_NAME = "EMBEDDING";
+
+ private final ObjectMapper mapper = new ObjectMapper();
+ public static final int DEFAULT_CONNECT_TIMEOUT = 30000;
+ public static final int DEFAULT_SOCKET_TIMEOUT = 30000;
+ public static final String DEFAULT_MODEL =
LanguageModel.BBAI_ZH.getModel();
+ private transient HttpClient httpClient;
+
+ /**
+ * Embedding a LLM document(a String object for now) via http protocol
+ * @param url the service url for embedding service
+ * @param input the source data for embedding
+ * @param model the language model supported in the embedding service
+ * */
+ public String eval(String url, String input, String model) {
+ // url and input is not null
+ if (Strings.isNullOrEmpty(url) || Strings.isNullOrEmpty(input)) {
+ logger.error("Failed to embedding, both url and input can't be
empty or null, url: {}, input: {}",
+ url, input);
+ return null;
+ }
+
+ if (Strings.isNullOrEmpty(model)) {
+ model = DEFAULT_MODEL;
+ logger.info("model is null, use default model: {}", model);
+ }
+
+ if (!LanguageModel.isLanguageModelSupported(model)) {
+ logger.error("Failed to embedding, language model {} not
supported(only {} are supported right now)",
+ model, LanguageModel.getAllSupportedLanguageModels());
+ return null;
+ }
+
+ // initialize httpClient
+ if (httpClient == null) {
+ RequestConfig requestConfig = RequestConfig.custom()
+ .setConnectTimeout(DEFAULT_CONNECT_TIMEOUT)
+ .setSocketTimeout(DEFAULT_SOCKET_TIMEOUT)
+ .build();
+ httpClient =
HttpClientBuilder.create().setDefaultRequestConfig(requestConfig).build();
+ }
+
+ try {
+ HttpPost httpPost = new HttpPost(url);
+ httpPost.setHeader(HttpHeaders.CONTENT_TYPE, "application/json");
+ EmbeddingInput embeddingInput = new EmbeddingInput(input, model);
+ String encodedContents = mapper.writeValueAsString(embeddingInput);
+ httpPost.setEntity(new StringEntity(encodedContents));
+ HttpResponse response = httpClient.execute(httpPost);
+
+ String returnStr = EntityUtils.toString(response.getEntity());
+ int returnCode = response.getStatusLine().getStatusCode();
+ if (Strings.isNullOrEmpty(returnStr) || HttpStatus.SC_OK !=
returnCode) {
+ throw new Exception("Failed to embedding, result: " +
returnStr + ", code: " + returnCode);
+ }
+ return returnStr;
+ } catch (Exception e) {
+ logger.error("Failed to embedding, url: {}, input: {}", url,
input, e);
+ return null;
+ }
+ }
+}
\ No newline at end of file
diff --git
a/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/embedding/EmbeddingInput.java
b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/embedding/EmbeddingInput.java
new file mode 100644
index 0000000000..f040db31d7
--- /dev/null
+++
b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/embedding/EmbeddingInput.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.inlong.sort.function.embedding;
+
+import java.io.Serializable;
+
/**
 * Request payload of the embedding function: the text to embed and the identifier of the
 * language model that should embed it.
 */
public class EmbeddingInput implements Serializable {

    // Declaration order (input, then model) is kept on purpose; it fixes the JSON field order.
    private String input;
    private String model;

    /**
     * Creates a payload.
     *
     * @param input the text to embed
     * @param model the language model identifier
     */
    public EmbeddingInput(String input, String model) {
        this.input = input;
        this.model = model;
    }

    /**
     * @return the text to embed
     */
    public String getInput() {
        return this.input;
    }

    /**
     * @param input the text to embed
     */
    public void setInput(String input) {
        this.input = input;
    }

    /**
     * @return the language model identifier
     */
    public String getModel() {
        return this.model;
    }

    /**
     * @param model the language model identifier
     */
    public void setModel(String model) {
        this.model = model;
    }
}
diff --git
a/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/embedding/LanguageModel.java
b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/embedding/LanguageModel.java
new file mode 100644
index 0000000000..aebe9cedd0
--- /dev/null
+++
b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/embedding/LanguageModel.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.inlong.sort.function.embedding;
+
+import com.google.common.base.Strings;
+
+/**
+ * Supported language model for embedding.
+ */
+public enum LanguageModel {
+
+ /**
+ * Language model for BBAI zh, chinese
+ * */
+ BBAI_ZH("BAAI/bge-large-zh-v1.5"),
+ /**
+ * Language model for BBAI en, english
+ * */
+ BBAI_EN("BAAI/bge-large-en"),
+ /**
+ * Language model for intfloat multi-language
+ * */
+ INTFLOAT_MULTI("intfloat/multilingual-e5-large");
+ String model;
+
+ LanguageModel(String s) {
+ this.model = s;
+ }
+
+ public String getModel() {
+ return this.model;
+ }
+
+ public static boolean isLanguageModelSupported(String s) {
+ if (Strings.isNullOrEmpty(s)) {
+ return false;
+ }
+ for (LanguageModel lm : LanguageModel.values()) {
+ if (s.equalsIgnoreCase(lm.getModel())) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ public static String getAllSupportedLanguageModels() {
+ if (LanguageModel.values().length == 0) {
+ return null;
+ }
+ StringBuilder supportedLMBuilder = new StringBuilder();
+ for (LanguageModel lm : LanguageModel.values()) {
+ supportedLMBuilder.append(lm.getModel()).append(",");
+ }
+ return supportedLMBuilder.substring(0, supportedLMBuilder.length() -
1);
+ }
+}
diff --git
a/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/FlinkSqlParser.java
b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/FlinkSqlParser.java
index 62369e1d54..7604549507 100644
---
a/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/FlinkSqlParser.java
+++
b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/FlinkSqlParser.java
@@ -23,6 +23,7 @@ import org.apache.inlong.sort.formats.common.ArrayFormatInfo;
import org.apache.inlong.sort.formats.common.FormatInfo;
import org.apache.inlong.sort.formats.common.MapFormatInfo;
import org.apache.inlong.sort.formats.common.RowFormatInfo;
+import org.apache.inlong.sort.function.EmbeddingFunction;
import org.apache.inlong.sort.function.EncryptFunction;
import org.apache.inlong.sort.function.JsonGetterFunction;
import org.apache.inlong.sort.function.RegexpReplaceFirstFunction;
@@ -73,6 +74,7 @@ import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.apache.inlong.common.util.MaskDataUtils.maskSensitiveMessage;
+import static
org.apache.inlong.sort.function.EmbeddingFunction.DEFAULT_EMBEDDING_FUNCTION_NAME;
/**
* Flink sql parse handler
@@ -122,6 +124,7 @@ public class FlinkSqlParser implements Parser {
tableEnv.createTemporarySystemFunction("REGEXP_REPLACE",
RegexpReplaceFunction.class);
tableEnv.createTemporarySystemFunction("ENCRYPT",
EncryptFunction.class);
tableEnv.createTemporarySystemFunction("JSON_GETTER",
JsonGetterFunction.class);
+
tableEnv.createTemporarySystemFunction(DEFAULT_EMBEDDING_FUNCTION_NAME,
EmbeddingFunction.class);
}
/**
diff --git
a/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/NativeFlinkSqlParser.java
b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/NativeFlinkSqlParser.java
index 11dd43347c..1e254dabba 100644
---
a/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/NativeFlinkSqlParser.java
+++
b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/NativeFlinkSqlParser.java
@@ -17,6 +17,7 @@
package org.apache.inlong.sort.parser.impl;
+import org.apache.inlong.sort.function.EmbeddingFunction;
import org.apache.inlong.sort.function.EncryptFunction;
import org.apache.inlong.sort.function.JsonGetterFunction;
import org.apache.inlong.sort.function.RegexpReplaceFirstFunction;
@@ -34,6 +35,8 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
+import static
org.apache.inlong.sort.function.EmbeddingFunction.DEFAULT_EMBEDDING_FUNCTION_NAME;
+
/**
* parse flink sql script file
* script file include CREATE TABLE statement
@@ -70,6 +73,7 @@ public class NativeFlinkSqlParser implements Parser {
tableEnv.createTemporarySystemFunction("REGEXP_REPLACE",
RegexpReplaceFunction.class);
tableEnv.createTemporarySystemFunction("ENCRYPT",
EncryptFunction.class);
tableEnv.createTemporarySystemFunction("JSON_GETTER",
JsonGetterFunction.class);
+
tableEnv.createTemporarySystemFunction(DEFAULT_EMBEDDING_FUNCTION_NAME,
EmbeddingFunction.class);
}
/**
diff --git
a/inlong-sort/sort-core/src/test/java/org/apache/inlong/sort/function/EmbeddingFunctionTest.java
b/inlong-sort/sort-core/src/test/java/org/apache/inlong/sort/function/EmbeddingFunctionTest.java
new file mode 100644
index 0000000000..0c14ca9483
--- /dev/null
+++
b/inlong-sort/sort-core/src/test/java/org/apache/inlong/sort/function/EmbeddingFunctionTest.java
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.inlong.sort.function;
+
+import org.apache.inlong.sort.function.embedding.EmbeddingInput;
+import org.apache.inlong.sort.function.embedding.LanguageModel;
+
+import com.sun.net.httpserver.HttpServer;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.EnvironmentSettings;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.net.InetSocketAddress;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Locale;
+
+import static
org.apache.inlong.sort.function.EmbeddingFunction.DEFAULT_EMBEDDING_FUNCTION_NAME;
+
+public class EmbeddingFunctionTest extends AbstractTestBase {
+
+ @Test
+ public void testMapper() throws Exception {
+ ObjectMapper mapper = new ObjectMapper();
+ EmbeddingInput embeddingInput = new EmbeddingInput("Input-Test",
"Model-Test");
+ String encodedContents = mapper.writeValueAsString(embeddingInput);
+ String expect = "{\"input\":\"Input-Test\",\"model\":\"Model-Test\"}";
+ Assert.assertEquals(encodedContents, expect);
+ }
+
+ @Test
+ public void testLanguageModel() {
+ String supportedLMs = LanguageModel.getAllSupportedLanguageModels();
+ Assert.assertNotNull(supportedLMs);
+ String[] supportLMArray = supportedLMs.split(",");
+ Assert.assertEquals(supportLMArray.length,
LanguageModel.values().length);
+
+
Assert.assertTrue(LanguageModel.isLanguageModelSupported("BAAI/bge-large-zh-v1.5"));
+
Assert.assertTrue(LanguageModel.isLanguageModelSupported("BAAI/bge-large-en"));
+
Assert.assertTrue(LanguageModel.isLanguageModelSupported("intfloat/multilingual-e5-large"));
+
Assert.assertFalse(LanguageModel.isLanguageModelSupported("fake/fake-language"));
+ }
+
+ /**
+ * Test for embedding function
+ *
+ * @throws Exception The exception may throw when test Embedding function
+ */
+ @Test
+ public void testEmbeddingFunction() throws Exception {
+ EnvironmentSettings settings = EnvironmentSettings
+ .newInstance()
+ .inStreamingMode()
+ .build();
+ StreamExecutionEnvironment env =
StreamExecutionEnvironment.getExecutionEnvironment();
+ env.setParallelism(1);
+ env.enableCheckpointing(10000);
+ StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env,
settings);
+
+ // step 1. Register custom function of Embedding
+ tableEnv.createTemporaryFunction(DEFAULT_EMBEDDING_FUNCTION_NAME,
EmbeddingFunction.class);
+
+ List<String> udfNames =
Arrays.asList(tableEnv.listUserDefinedFunctions());
+
Assert.assertTrue(udfNames.contains(DEFAULT_EMBEDDING_FUNCTION_NAME.toLowerCase(Locale.ROOT)));
+
+ // step 2. Generate test data and convert to DataStream
+ int numOfMessages = 100;
+ List<String> sourceDateList = new ArrayList<>();
+ String msgPrefix = "Data for embedding-";
+ for (int i = 0; i < numOfMessages; i++) {
+ sourceDateList.add(msgPrefix + i);
+ }
+
+ List<Row> data = new ArrayList<>();
+ sourceDateList.forEach(s -> data.add(Row.of(s)));
+ TypeInformation<?>[] types = {BasicTypeInfo.STRING_TYPE_INFO};
+ String[] names = {"f1"};
+ RowTypeInfo typeInfo = new RowTypeInfo(types, names);
+ DataStream<Row> dataStream =
env.fromCollection(data).returns(typeInfo);
+
+ // step 3. start a web server to mock embedding service
+ String embeddingResult = "{\"result\": \"Result data for embedding\"}";
+ HttpServer httpServer = HttpServer.create(new InetSocketAddress(8899),
0); // or use InetSocketAddress(0) for
+
// ephemeral port
+ httpServer.createContext("/get_embedding", exchange -> {
+ byte[] response = embeddingResult.getBytes(StandardCharsets.UTF_8);
+ exchange.sendResponseHeaders(200, response.length);
+ exchange.getResponseBody().write(response);
+ exchange.close();
+ });
+ httpServer.start();
+
+ // step 4. Convert from DataStream to Table and execute the Embedding
function
+ Table tempView = tableEnv.fromDataStream(dataStream).as("f1");
+ tableEnv.createTemporaryView("temp_view", tempView);
+ Table outputTable = tableEnv.sqlQuery(
+ "SELECT " +
+ "f1," +
+ "EMBEDDING('http://localhost:8899/get_embedding', f1,
'BAAI/bge-large-en') " +
+ "from temp_view");
+
+ // step 5. Get function execution result and parse it
+ DataStream<Row> resultSet = tableEnv.toAppendStream(outputTable,
Row.class);
+ List<String> resultF0 = new ArrayList<>();
+ List<String> resultF1 = new ArrayList<>();
+ for (CloseableIterator<Row> it = resultSet.executeAndCollect();
it.hasNext();) {
+ Row row = it.next();
+ if (row != null) {
+ resultF0.add(row.getField(0).toString());
+ resultF1.add(row.getField(1).toString());
+ }
+ }
+ Assert.assertEquals(resultF0.size(), numOfMessages);
+ Assert.assertEquals(resultF1.size(), numOfMessages);
+ Assert.assertEquals(resultF0, sourceDateList);
+ for (String res : resultF1) {
+ Assert.assertEquals(res, embeddingResult);
+ }
+
+ httpServer.stop(0);
+ }
+}