This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 7a4e5d13 [FLINK-31306] Add Servable for PipelineModel
7a4e5d13 is described below
commit 7a4e5d13a7832b2b411beff6127c923662b67401
Author: JiangXin <[email protected]>
AuthorDate: Mon Mar 13 20:08:19 2023 +0800
[FLINK-31306] Add Servable for PipelineModel
This closes #218.
---
docs/content/docs/development/overview.md | 2 +-
.../org/apache/flink/ml/benchmark/Benchmark.java | 3 +-
.../apache/flink/ml/benchmark/BenchmarkUtils.java | 7 +-
.../clustering/KMeansModelDataGenerator.java | 3 +-
flink-ml-core/pom.xml | 8 +
.../org/apache/flink/ml/builder/PipelineModel.java | 30 ++++
.../org/apache/flink/ml/util/ReadWriteUtils.java | 169 ++-------------------
.../org/apache/flink/ml/api/ExampleStages.java | 5 +
.../java/org/apache/flink/ml/api/PipelineTest.java | 72 +++++++++
.../java/org/apache/flink/ml/api/ServableTest.java | 106 +++++++++++++
.../java/org/apache/flink/ml/api/StageTest.java | 2 +-
.../java/org/apache/flink/ml/util/TestUtils.java | 28 ++++
flink-ml-lib/pom.xml | 7 +
.../apache/flink/ml/classification/knn/Knn.java | 2 +-
.../ml/classification/linearsvc/LinearSVC.java | 2 +-
.../logisticregression/LogisticRegression.java | 2 +-
.../OnlineLogisticRegression.java | 2 +-
.../ml/classification/naivebayes/NaiveBayes.java | 2 +-
.../apache/flink/ml/clustering/kmeans/KMeans.java | 2 +-
.../flink/ml/clustering/kmeans/OnlineKMeans.java | 2 +-
.../feature/countvectorizer/CountVectorizer.java | 2 +-
.../java/org/apache/flink/ml/feature/idf/IDF.java | 2 +-
.../apache/flink/ml/feature/imputer/Imputer.java | 2 +-
.../feature/kbinsdiscretizer/KBinsDiscretizer.java | 2 +-
.../apache/flink/ml/feature/lsh/MinHashLSH.java | 3 +-
.../ml/feature/maxabsscaler/MaxAbsScaler.java | 2 +-
.../ml/feature/minmaxscaler/MinMaxScaler.java | 2 +-
.../ml/feature/onehotencoder/OneHotEncoder.java | 2 +-
.../ml/feature/robustscaler/RobustScaler.java | 2 +-
.../standardscaler/OnlineStandardScaler.java | 2 +-
.../ml/feature/standardscaler/StandardScaler.java | 2 +-
.../ml/feature/stringindexer/StringIndexer.java | 2 +-
.../UnivariateFeatureSelector.java | 2 +-
.../VarianceThresholdSelector.java | 2 +-
.../ml/feature/vectorindexer/VectorIndexer.java | 2 +-
.../linearregression/LinearRegression.java | 2 +-
.../apache/flink/ml/classification/KnnTest.java | 4 +-
.../flink/ml/classification/LinearSVCTest.java | 4 +-
.../ml/classification/LogisticRegressionTest.java | 4 +-
.../flink/ml/classification/NaiveBayesTest.java | 4 +-
.../org/apache/flink/ml/clustering/KMeansTest.java | 4 +-
.../java/org/apache/flink/ml/feature/IDFTest.java | 4 +-
.../flink/ml/feature/KBinsDiscretizerTest.java | 4 +-
.../apache/flink/ml/feature/MaxAbsScalerTest.java | 4 +-
.../apache/flink/ml/feature/MinHashLSHTest.java | 4 +-
.../apache/flink/ml/feature/MinMaxScalerTest.java | 4 +-
.../apache/flink/ml/feature/OneHotEncoderTest.java | 4 +-
.../flink/ml/feature/OnlineStandardScalerTest.java | 4 +-
.../flink/ml/feature/StandardScalerTest.java | 4 +-
.../apache/flink/ml/feature/VectorIndexerTest.java | 4 +-
.../feature/stringindexer/StringIndexerTest.java | 4 +-
.../flink/ml/regression/LinearRegressionTest.java | 4 +-
.../pyflink/ml/feature/tests/test_minhashlsh.py | 2 +-
flink-ml-python/pyflink/ml/tests/test_utils.py | 2 +-
.../apache/flink/ml/servable/api/DataFrame.java | 4 +-
.../ml/servable/builder/PipelineModelServable.java | 67 ++++++++
.../apache/flink/ml/servable/types/DataTypes.java | 49 ++++++
.../java/org/apache/flink/ml/util/FileUtils.java | 117 ++++++++++++++
.../java/org/apache/flink/ml/util/ParamUtils.java | 54 +++++++
.../flink/ml/util/ServableReadWriteUtils.java | 152 ++++++++++++++++++
.../org/apache/flink/ml/servable/TestUtils.java | 49 ++++++
.../ml/servable/builder/ExampleServables.java | 115 ++++++++++++++
.../builder/PipelineModelServableTest.java | 77 ++++++++++
tools/maven/suppressions.xml | 1 +
64 files changed, 1019 insertions(+), 223 deletions(-)
diff --git a/docs/content/docs/development/overview.md
b/docs/content/docs/development/overview.md
index 6075e239..5be08ebe 100644
--- a/docs/content/docs/development/overview.md
+++ b/docs/content/docs/development/overview.md
@@ -239,7 +239,7 @@ following ways.
the number of clusters, of a K-means algorithm, users can directly invoke
`setK()` method on that `KMeans` instance.
- Pass a parameter map containing new values to the stage through
- `ReadWriteUtils.updateExistingParams()` method.
+ `ParamUtils.updateExistingParams()` method.
If a `Model` is generated through an `Estimator`'s `fit()` method, the `Model`
would inherit the `Estimator` object's parameters. Thus there is no need to set
diff --git
a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/Benchmark.java
b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/Benchmark.java
index 6886bc28..3b69f128 100644
---
a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/Benchmark.java
+++
b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/Benchmark.java
@@ -19,6 +19,7 @@
package org.apache.flink.ml.benchmark;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.ml.util.FileUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
@@ -118,7 +119,7 @@ public class Benchmark {
.writerWithDefaultPrettyPrinter()
.writeValueAsString(benchmarks);
if (commandLine.hasOption(OUTPUT_FILE_OPTION.getLongOpt())) {
- ReadWriteUtils.saveToFile(saveFile, benchmarkResultsJson, true);
+ FileUtils.saveToFile(saveFile, benchmarkResultsJson, true);
System.out.printf("Benchmark results saved as json in %s.\n",
saveFile);
} else {
System.out.printf("Benchmark results summary:\n%s\n",
benchmarkResultsJson);
diff --git
a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/BenchmarkUtils.java
b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/BenchmarkUtils.java
index 22300fab..8532c32d 100644
---
a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/BenchmarkUtils.java
+++
b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/BenchmarkUtils.java
@@ -28,6 +28,7 @@ import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.benchmark.datagenerator.DataGenerator;
import org.apache.flink.ml.benchmark.datagenerator.InputDataGenerator;
import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
@@ -77,12 +78,12 @@ public class BenchmarkUtils {
Map<String, Map<String, ?>> params,
boolean dryRun)
throws Exception {
- Stage stage =
ReadWriteUtils.instantiateWithParams(params.get("stage"));
+ Stage stage = ParamUtils.instantiateWithParams(params.get("stage"));
InputDataGenerator inputDataGenerator =
- ReadWriteUtils.instantiateWithParams(params.get("inputData"));
+ ParamUtils.instantiateWithParams(params.get("inputData"));
DataGenerator modelDataGenerator = null;
if (params.containsKey("modelData")) {
- modelDataGenerator =
ReadWriteUtils.instantiateWithParams(params.get("modelData"));
+ modelDataGenerator =
ParamUtils.instantiateWithParams(params.get("modelData"));
}
return runBenchmark(tEnv, name, stage, inputDataGenerator,
modelDataGenerator, dryRun);
diff --git
a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/clustering/KMeansModelDataGenerator.java
b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/clustering/KMeansModelDataGenerator.java
index 1b5cc365..4a272e54 100644
---
a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/clustering/KMeansModelDataGenerator.java
+++
b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/clustering/KMeansModelDataGenerator.java
@@ -27,7 +27,6 @@ import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
-import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
@@ -59,7 +58,7 @@ public class KMeansModelDataGenerator
@Override
public Table[] getData(StreamTableEnvironment tEnv) {
InputDataGenerator<?> vectorArrayGenerator = new
DenseVectorArrayGenerator();
- ReadWriteUtils.updateExistingParams(vectorArrayGenerator, paramMap);
+ ParamUtils.updateExistingParams(vectorArrayGenerator, paramMap);
vectorArrayGenerator.setNumValues(1);
vectorArrayGenerator.setColNames(new String[] {"centroids"});
diff --git a/flink-ml-core/pom.xml b/flink-ml-core/pom.xml
index 34b77eb5..71b7d02f 100644
--- a/flink-ml-core/pom.xml
+++ b/flink-ml-core/pom.xml
@@ -99,6 +99,14 @@ under the License.
<type>test-jar</type>
</dependency>
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-ml-servable-core</artifactId>
+ <version>${project.version}</version>
+ <scope>test</scope>
+ <type>test-jar</type>
+ </dependency>
+
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-test-utils</artifactId>
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/PipelineModel.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/PipelineModel.java
index 7c6bfdbf..60207d77 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/PipelineModel.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/PipelineModel.java
@@ -23,7 +23,10 @@ import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.ml.api.AlgoOperator;
import org.apache.flink.ml.api.Model;
import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.api.Transformer;
import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.servable.api.TransformerServable;
+import org.apache.flink.ml.servable.builder.PipelineModelServable;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.table.api.Table;
@@ -82,6 +85,33 @@ public final class PipelineModel implements
Model<PipelineModel> {
ReadWriteUtils.loadPipeline(tEnv, path,
PipelineModel.class.getName()));
}
+ public static PipelineModelServable loadServable(String path) throws
IOException {
+ return PipelineModelServable.load(path);
+ }
+
+ /**
+ * Whether all stages in the pipeline have corresponding {@link
TransformerServable} so that the
+ * PipelineModel can be turned into a TransformerServable and used in an
online inference
+ * program.
+ *
+ * @return true if all stages have corresponding TransformerServable,
false if not.
+ */
+ public boolean supportServable() {
+ for (Stage<?> stage : stages) {
+ if (!(stage instanceof Transformer)) {
+ return false;
+ }
+ Transformer<?> transformer = (Transformer<?>) stage;
+ Class<?> clazz = transformer.getClass();
+ try {
+ clazz.getMethod("loadServable", String.class);
+ } catch (NoSuchMethodException e) {
+ return false;
+ }
+ }
+ return true;
+ }
+
/**
* Returns a list of all stages in this PipelineModel in order. The list
is immutable.
*
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
index 98d8d335..b284bbb3 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
@@ -24,7 +24,6 @@ import org.apache.flink.api.connector.source.Source;
import org.apache.flink.connector.file.sink.FileSink;
import org.apache.flink.connector.file.src.FileSource;
import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
-import org.apache.flink.core.fs.FileSystem;
import org.apache.flink.core.fs.Path;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.builder.Graph;
@@ -33,25 +32,18 @@ import org.apache.flink.ml.builder.GraphModel;
import org.apache.flink.ml.builder.GraphNode;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.param.Param;
-import org.apache.flink.ml.param.WithParams;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
import
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
-import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.util.Preconditions;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonParser;
import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
-import java.io.BufferedReader;
-import java.io.BufferedWriter;
-import java.io.File;
import java.io.IOException;
-import java.io.InputStreamReader;
-import java.io.OutputStreamWriter;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
@@ -59,7 +51,6 @@ import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.Set;
/** Utility methods for reading and writing stages. */
public class ReadWriteUtils {
@@ -104,28 +95,7 @@ public class ReadWriteUtils {
// TODO: add version in the metadata.
String metadataStr = OBJECT_MAPPER.writeValueAsString(metadata);
- saveToFile(new Path(path, "metadata").toUri().toString(), metadataStr,
false);
- }
-
- /** Saves a given string to the specified file. */
- public static void saveToFile(String pathStr, String content, boolean
isOverwrite)
- throws IOException {
- Path path = new Path(pathStr);
-
- // Creates parent directories if not already created.
- FileSystem fs = mkdirs(path.getParent());
-
- FileSystem.WriteMode writeMode = FileSystem.WriteMode.OVERWRITE;
- if (!isOverwrite) {
- writeMode = FileSystem.WriteMode.NO_OVERWRITE;
- if (fs.exists(path)) {
- throw new IOException("File " + path + " already exists.");
- }
- }
- try (BufferedWriter writer =
- new BufferedWriter(new OutputStreamWriter(fs.create(path,
writeMode)))) {
- writer.write(content);
- }
+ FileUtils.saveToFile(new Path(path, "metadata").toUri().toString(),
metadataStr, false);
}
/**
@@ -141,62 +111,6 @@ public class ReadWriteUtils {
saveMetadata(stage, path, new HashMap<>());
}
- /** Returns a subdirectory of the given path for saving/loading model
data. */
- private static String getDataPath(String path) {
- return new Path(path, "data").toString();
- }
-
- /**
- * Loads the metadata from the metadata file under the given path.
- *
- * <p>The method throws RuntimeException if the expectedClassName is not
empty AND it does not
- * match the className of the previously saved stage.
- *
- * @param path The parent directory of the metadata file to read from.
- * @param expectedClassName The expected class name of the stage.
- * @return A map from metadata name to metadata value.
- */
- public static Map<String, ?> loadMetadata(String path, String
expectedClassName)
- throws IOException {
- Path metadataPath = new Path(path, "metadata");
- FileSystem fs = metadataPath.getFileSystem();
-
- StringBuilder buffer = new StringBuilder();
- try (BufferedReader br = new BufferedReader(new
InputStreamReader(fs.open(metadataPath)))) {
- String line;
- while ((line = br.readLine()) != null) {
- if (!line.startsWith("#")) {
- buffer.append(line);
- }
- }
- }
-
- @SuppressWarnings("unchecked")
- Map<String, ?> result = OBJECT_MAPPER.readValue(buffer.toString(),
Map.class);
-
- String className = (String) result.get("className");
- if (!expectedClassName.isEmpty() &&
!expectedClassName.equals(className)) {
- throw new RuntimeException(
- "Class name "
- + className
- + " does not match the expected class name "
- + expectedClassName
- + ".");
- }
-
- return result;
- }
-
- // Returns a string with value {parentPath}/stages/{stageIdx}, where the
stageIdx is prefixed
- // with zero or more `0` to have the same length as numStages. The
resulting string can be
- // used as the directory to save a stage of the Pipeline or PipelineModel.
- private static String getPathForPipelineStage(int stageIdx, int numStages,
String parentPath) {
- String format =
- String.format("stages%s%%0%dd", File.separator,
String.valueOf(numStages).length());
- String fileName = String.format(format, stageIdx);
- return new Path(parentPath, fileName).toString();
- }
-
/**
* Saves a Pipeline or PipelineModel with the given list of stages to the
given path.
*
@@ -207,7 +121,7 @@ public class ReadWriteUtils {
public static void savePipeline(Stage<?> pipeline, List<Stage<?>> stages,
String path)
throws IOException {
// Creates parent directories if not already created.
- mkdirs(new Path(path));
+ FileUtils.mkdirs(new Path(path));
Map<String, Object> extraMetadata = new HashMap<>();
extraMetadata.put("numStages", stages.size());
@@ -215,7 +129,7 @@ public class ReadWriteUtils {
int numStages = stages.size();
for (int i = 0; i < numStages; i++) {
- String stagePath = getPathForPipelineStage(i, numStages, path);
+ String stagePath = FileUtils.getPathForPipelineStage(i, numStages,
path);
stages.get(i).save(stagePath);
}
}
@@ -233,23 +147,17 @@ public class ReadWriteUtils {
*/
public static List<Stage<?>> loadPipeline(
StreamTableEnvironment tEnv, String path, String
expectedClassName) throws IOException {
- Map<String, ?> metadata = loadMetadata(path, expectedClassName);
+ Map<String, ?> metadata = FileUtils.loadMetadata(path,
expectedClassName);
int numStages = (Integer) metadata.get("numStages");
List<Stage<?>> stages = new ArrayList<>(numStages);
for (int i = 0; i < numStages; i++) {
- String stagePath = getPathForPipelineStage(i, numStages, path);
+ String stagePath = FileUtils.getPathForPipelineStage(i, numStages,
path);
stages.add(loadStage(tEnv, stagePath));
}
return stages;
}
- private static FileSystem mkdirs(Path path) throws IOException {
- FileSystem fs = path.getFileSystem();
- fs.mkdirs(path);
- return fs;
- }
-
/**
* Saves a Graph or GraphModel with the given GraphData to the given path.
*
@@ -260,7 +168,7 @@ public class ReadWriteUtils {
public static void saveGraph(Stage<?> graph, GraphData graphData, String
path)
throws IOException {
// Creates parent directories if not already created.
- mkdirs(new Path(path));
+ FileUtils.mkdirs(new Path(path));
Map<String, Object> extraMetadata = new HashMap<>();
extraMetadata.put("graphData", graphData.toMap());
@@ -272,7 +180,7 @@ public class ReadWriteUtils {
.orElse(-1);
for (GraphNode node : graphData.nodes) {
- String stagePath = getPathForPipelineStage(node.nodeId, maxNodeId
+ 1, path);
+ String stagePath = FileUtils.getPathForPipelineStage(node.nodeId,
maxNodeId + 1, path);
node.stage.save(stagePath);
}
}
@@ -291,7 +199,7 @@ public class ReadWriteUtils {
@SuppressWarnings("unchecked")
public static Stage<?> loadGraph(
StreamTableEnvironment tEnv, String path, String
expectedClassName) throws IOException {
- Map<String, ?> metadata = loadMetadata(path, expectedClassName);
+ Map<String, ?> metadata = FileUtils.loadMetadata(path,
expectedClassName);
GraphData graphData = GraphData.fromMap((Map<String, Object>)
metadata.get("graphData"));
int maxNodeId =
@@ -301,7 +209,7 @@ public class ReadWriteUtils {
.orElse(-1);
for (GraphNode node : graphData.nodes) {
- String stagePath = getPathForPipelineStage(node.nodeId, maxNodeId
+ 1, path);
+ String stagePath = FileUtils.getPathForPipelineStage(node.nodeId,
maxNodeId + 1, path);
node.stage = loadStage(tEnv, stagePath);
}
@@ -323,27 +231,6 @@ public class ReadWriteUtils {
graphData.outputModelDataIds);
}
- // A helper method that sets WithParams object's parameter value. We can
not call
- // WithParams.set(param, value)
- // directly because WithParams::set(...) needs the actual type of the
value.
- @SuppressWarnings("unchecked")
- public static <T> void setParam(WithParams<?> instance, Param<T> param,
Object value) {
- instance.set(param, (T) value);
- }
-
- // A helper method that updates WithParams instance's param map using
values from the
- // paramOverrides. This method only updates values for parameters already
defined in the
- // instance's param map.
- public static void updateExistingParams(
- WithParams<?> instance, Map<Param<?>, Object> paramOverrides) {
- Set<Param<?>> existingParams = instance.getParamMap().keySet();
- for (Map.Entry<Param<?>, Object> entry : paramOverrides.entrySet()) {
- if (existingParams.contains(entry.getKey())) {
- setParam(instance, entry.getKey(), entry.getValue());
- }
- }
- }
-
/**
* Loads the stage with the saved parameters from the given path. This
method reads the metadata
* file under the given path, instantiates the stage using its no-argument
constructor, and
@@ -360,41 +247,12 @@ public class ReadWriteUtils {
*/
public static <T extends Stage<T>> T loadStageParam(String path) throws
IOException {
try {
- return instantiateWithParams(loadMetadata(path, ""));
+ return
ParamUtils.instantiateWithParams(FileUtils.loadMetadata(path, ""));
} catch (ClassNotFoundException e) {
throw new RuntimeException("Failed to load stage.", e);
}
}
- /**
- * Instantiates a WithParams subclass from the provided json map.
- *
- * @param jsonMap a map containing className and paramMap.
- * @return the instantiated WithParams subclass instance.
- */
- @SuppressWarnings("unchecked")
- public static <T extends WithParams<T>> T
instantiateWithParams(Map<String, ?> jsonMap)
- throws ClassNotFoundException, IOException {
- String className = (String) jsonMap.get("className");
- Class<T> clazz = (Class<T>) Class.forName(className);
- T instance = InstantiationUtil.instantiate(clazz);
-
- Map<String, Param<?>> nameToParam = new HashMap<>();
- for (Param<?> param : ParamUtils.getPublicFinalParamFields(instance)) {
- nameToParam.put(param.name, param);
- }
-
- if (jsonMap.containsKey("paramMap")) {
- Map<String, Object> paramMap = (Map<String, Object>)
jsonMap.get("paramMap");
- for (Map.Entry<String, Object> entry : paramMap.entrySet()) {
- Param<?> param = nameToParam.get(entry.getKey());
- setParam(instance, param, param.jsonDecode(entry.getValue()));
- }
- }
-
- return instance;
- }
-
/**
* Loads the stage from the given path by invoking the static load()
method of the stage. The
* stage class name is read from the metadata file under the given path.
The load() method is
@@ -408,7 +266,7 @@ public class ReadWriteUtils {
* @return An instance of Stage.
*/
public static Stage<?> loadStage(StreamTableEnvironment tEnv, String path)
throws IOException {
- Map<String, ?> metadata = loadMetadata(path, "");
+ Map<String, ?> metadata = FileUtils.loadMetadata(path, "");
String className = (String) metadata.get("className");
try {
@@ -440,8 +298,7 @@ public class ReadWriteUtils {
public static <T> void saveModelData(
DataStream<T> model, String path, Encoder<T> modelEncoder) {
FileSink<T> sink =
- FileSink.forRowFormat(
- new
org.apache.flink.core.fs.Path(getDataPath(path)), modelEncoder)
+ FileSink.forRowFormat(FileUtils.getDataPath(path),
modelEncoder)
.withRollingPolicy(OnCheckpointRollingPolicy.build())
.withBucketAssigner(new BasePathBucketAssigner<>())
.build();
@@ -461,7 +318,7 @@ public class ReadWriteUtils {
StreamTableEnvironment tEnv, String path, SimpleStreamFormat<T>
modelDecoder) {
StreamExecutionEnvironment env =
TableUtils.getExecutionEnvironment(tEnv);
Source<T, ?, ?> source =
- FileSource.forRecordStreamFormat(modelDecoder, new
Path(getDataPath(path))).build();
+ FileSource.forRecordStreamFormat(modelDecoder,
FileUtils.getDataPath(path)).build();
DataStream<T> modelDataStream =
env.fromSource(source, WatermarkStrategy.noWatermarks(),
"modelData");
return tEnv.fromDataStream(modelDataStream);
diff --git
a/flink-ml-core/src/test/java/org/apache/flink/ml/api/ExampleStages.java
b/flink-ml-core/src/test/java/org/apache/flink/ml/api/ExampleStages.java
index d6c0ab21..477725f5 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/api/ExampleStages.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/ExampleStages.java
@@ -24,6 +24,7 @@ import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.servable.builder.ExampleServables.SumModelServable;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.ml.util.TestUtils;
@@ -110,6 +111,10 @@ public class ExampleStages {
SumModel model = ReadWriteUtils.loadStageParam(path);
return model.setModelData(modelDataTable);
}
+
+ public static SumModelServable loadServable(String path) throws
IOException {
+ return SumModelServable.load(path);
+ }
}
// Adds delta from the 2nd input to every element in the 1st input and
returns the added values.
diff --git
a/flink-ml-core/src/test/java/org/apache/flink/ml/api/PipelineTest.java
b/flink-ml-core/src/test/java/org/apache/flink/ml/api/PipelineTest.java
index e24cdbe6..0c564d97 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/api/PipelineTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/PipelineTest.java
@@ -20,26 +20,39 @@ package org.apache.flink.ml.api;
import org.apache.flink.ml.api.ExampleStages.SumEstimator;
import org.apache.flink.ml.api.ExampleStages.SumModel;
+import org.apache.flink.ml.api.ExampleStages.UnionAlgoOperator;
import org.apache.flink.ml.builder.Pipeline;
import org.apache.flink.ml.builder.PipelineModel;
+import org.apache.flink.ml.servable.api.DataFrame;
+import org.apache.flink.ml.servable.api.Row;
+import org.apache.flink.ml.servable.builder.PipelineModelServable;
+import org.apache.flink.ml.servable.types.DataTypes;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.test.util.AbstractTestBase;
import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
import java.nio.file.Files;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import static org.apache.flink.ml.servable.TestUtils.assertDataFrameEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
/** Tests the behavior of Pipeline and PipelineModel. */
public class PipelineTest extends AbstractTestBase {
private StreamExecutionEnvironment env;
private StreamTableEnvironment tEnv;
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
@Before
public void before() {
env = TestUtils.getExecutionEnvironment();
@@ -95,4 +108,63 @@ public class PipelineTest extends AbstractTestBase {
// Executes the loaded Pipeline and verifies that it produces the
expected output.
TestUtils.executeAndCheckOutput(env, loadedEstimator, inputs, output,
null, null);
}
+
+ @Test
+ public void testSupportServable() {
+ SumEstimator estimatorA = new SumEstimator();
+ UnionAlgoOperator algoOperatorA = new UnionAlgoOperator();
+ SumModel modelA = new SumModel();
+ SumModel modelB = new SumModel();
+
+ List<Stage<?>> stages = Arrays.asList(modelA, modelB);
+ PipelineModel pipelineModel = new PipelineModel(stages);
+ assertTrue(pipelineModel.supportServable());
+
+ stages = Arrays.asList(estimatorA, modelA);
+ pipelineModel = new PipelineModel(stages);
+ assertFalse(pipelineModel.supportServable());
+
+ stages = Arrays.asList(algoOperatorA, modelA);
+ pipelineModel = new PipelineModel(stages);
+ assertFalse(pipelineModel.supportServable());
+ }
+
+ @Test
+ public void testPipelineModelServable() throws Exception {
+ SumModel modelA = new SumModel().setModelData(tEnv.fromValues(10));
+ SumModel modelB = new SumModel().setModelData(tEnv.fromValues(20));
+ SumModel modelC = new SumModel().setModelData(tEnv.fromValues(30));
+
+ List<Stage<?>> stages = Arrays.asList(modelA, modelB, modelC);
+ Model<?> model = new PipelineModel(stages);
+
+ PipelineModelServable servable =
+ TestUtils.saveAndLoadServable(
+ tEnv,
+ model,
+ tempFolder.newFolder().getAbsolutePath(),
+ PipelineModel::loadServable);
+
+ DataFrame input =
+ new DataFrame(
+ Collections.singletonList("input"),
+ Collections.singletonList(DataTypes.INT),
+ Arrays.asList(
+ new Row(Collections.singletonList(1)),
+ new Row(Collections.singletonList(2)),
+ new Row(Collections.singletonList(3))));
+
+ DataFrame output = servable.transform(input);
+
+ DataFrame expectedOutput =
+ new DataFrame(
+ Collections.singletonList("input"),
+ Collections.singletonList(DataTypes.INT),
+ Arrays.asList(
+ new Row(Collections.singletonList(61)),
+ new Row(Collections.singletonList(62)),
+ new Row(Collections.singletonList(63))));
+
+ assertDataFrameEquals(expectedOutput, output);
+ }
}
diff --git
a/flink-ml-core/src/test/java/org/apache/flink/ml/api/ServableTest.java
b/flink-ml-core/src/test/java/org/apache/flink/ml/api/ServableTest.java
new file mode 100644
index 00000000..8a92cad1
--- /dev/null
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/ServableTest.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api;
+
+import org.apache.flink.ml.api.ExampleStages.SumModel;
+import org.apache.flink.ml.servable.api.DataFrame;
+import org.apache.flink.ml.servable.api.Row;
+import org.apache.flink.ml.servable.builder.ExampleServables.SumModelServable;
+import org.apache.flink.ml.servable.types.DataTypes;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.ByteArrayInputStream;
+import java.util.Arrays;
+import java.util.Collections;
+
+import static org.apache.flink.ml.servable.TestUtils.assertDataFrameEquals;
+
+/** Tests the behavior of integration between Transformer and Servable. */
+public class ServableTest extends AbstractTestBase {
+
+ private StreamTableEnvironment tEnv;
+
+ private static final DataFrame INPUT =
+ new DataFrame(
+ Collections.singletonList("input"),
+ Collections.singletonList(DataTypes.INT),
+ Arrays.asList(
+ new Row(Collections.singletonList(1)),
+ new Row(Collections.singletonList(2)),
+ new Row(Collections.singletonList(3))));
+
+ private static final DataFrame EXPECTED_OUTPUT =
+ new DataFrame(
+ Collections.singletonList("input"),
+ Collections.singletonList(DataTypes.INT),
+ Arrays.asList(
+ new Row(Collections.singletonList(11)),
+ new Row(Collections.singletonList(12)),
+ new Row(Collections.singletonList(13))));
+
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+ @Before
+ public void before() {
+ StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment();
+ tEnv = StreamTableEnvironment.create(env);
+ }
+
+ @Test
+ public void testSaveModelLoadServable() throws Exception {
+ String modelPath = tempFolder.newFolder().getAbsolutePath();
+
+ SumModel model = new SumModel().setModelData(tEnv.fromValues(10));
+
+ SumModelServable servable =
+ TestUtils.saveAndLoadServable(tEnv, model, modelPath,
SumModel::loadServable);
+
+ DataFrame output = servable.transform(INPUT);
+
+ assertDataFrameEquals(EXPECTED_OUTPUT, output);
+ }
+
+ @Test
+ public void testSetModelData() throws Exception {
+ SumModel model = new SumModel().setModelData(tEnv.fromValues(10));
+ Table modelDataTable = model.getModelData()[0];
+
+ byte[] serializedModelData =
+ tEnv.toDataStream(modelDataTable)
+ .map(x -> SumModelServable.serialize(x.getField(0)))
+ .executeAndCollect()
+ .next();
+
+ SumModelServable servable =
+ new SumModelServable().setModelData(new
ByteArrayInputStream(serializedModelData));
+
+ DataFrame output = servable.transform(INPUT);
+
+ assertDataFrameEquals(EXPECTED_OUTPUT, output);
+ }
+}
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
b/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
index 9d682c81..f812f8b2 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
@@ -229,7 +229,7 @@ public class StageTest {
throws IOException {
for (Map.Entry<String, Object> entry : paramOverrides.entrySet()) {
Param<?> param = stage.getParam(entry.getKey());
- ReadWriteUtils.setParam(stage, param, entry.getValue());
+ ParamUtils.setParam(stage, param, entry.getValue());
}
String path = Files.createTempDirectory("").toString();
diff --git
a/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java
b/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java
index 4623dafd..26199788 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java
@@ -32,10 +32,12 @@ import org.apache.flink.ml.api.AlgoOperator;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.api.Model;
import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.api.Transformer;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo;
+import org.apache.flink.ml.servable.api.TransformerServable;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
@@ -43,6 +45,7 @@ import
org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.types.DataType;
import org.apache.flink.test.util.TestBaseUtils;
import org.apache.flink.types.Row;
+import org.apache.flink.util.function.FunctionWithException;
import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.lang3.ArrayUtils;
@@ -209,6 +212,31 @@ public class TestUtils {
return (T) method.invoke(null, tEnv, path);
}
+ /**
+ * Saves a transformer to the filesystem and reloads the metadata as a
servable with the given
+ * loadServable function.
+ */
+ public static <T extends TransformerServable<T>> T saveAndLoadServable(
+ StreamTableEnvironment tEnv,
+ Transformer<?> transformer,
+ String path,
+ FunctionWithException<String, T, IOException> loadServableFunc)
+ throws Exception {
+ StreamExecutionEnvironment env =
TableUtils.getExecutionEnvironment(tEnv);
+
+ transformer.save(path);
+ try {
+ env.execute();
+ } catch (RuntimeException e) {
+ if (!e.getMessage()
+ .equals("No operators defined in streaming topology.
Cannot execute.")) {
+ throw e;
+ }
+ }
+
+ return loadServableFunc.apply(path);
+ }
+
/**
* Converts data types in the table to sparse types and integer types.
*
diff --git a/flink-ml-lib/pom.xml b/flink-ml-lib/pom.xml
index b8620c15..777c6b98 100644
--- a/flink-ml-lib/pom.xml
+++ b/flink-ml-lib/pom.xml
@@ -117,6 +117,13 @@ under the License.
<scope>test</scope>
<type>test-jar</type>
</dependency>
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-ml-servable-core</artifactId>
+ <version>${project.version}</version>
+ <scope>test</scope>
+ <type>test-jar</type>
+ </dependency>
</dependencies>
<build>
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
index 36e9d37d..8ad15e27 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
@@ -67,7 +67,7 @@ public class Knn implements Estimator<Knn, KnnModel>,
KnnParams<Knn> {
computeNormSquare(tEnv.toDataStream(inputs[0]));
DataStream<KnnModelData> modelData = genModelData(inputDataWithNorm);
KnnModel model = new
KnnModel().setModelData(tEnv.fromDataStream(modelData));
- ReadWriteUtils.updateExistingParams(model, getParamMap());
+ ParamUtils.updateExistingParams(model, getParamMap());
return model;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java
index 8a8d1cbd..30c166b4 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java
@@ -108,7 +108,7 @@ public class LinearSVC implements Estimator<LinearSVC,
LinearSVCModel>, LinearSV
DataStream<LinearSVCModelData> modelData =
rawModelData.map(LinearSVCModelData::new);
LinearSVCModel model = new
LinearSVCModel().setModelData(tEnv.fromDataStream(modelData));
- ReadWriteUtils.updateExistingParams(model, paramMap);
+ ParamUtils.updateExistingParams(model, paramMap);
return model;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
index 551b66a3..eeb7338a 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
@@ -118,7 +118,7 @@ public class LogisticRegression
rawModelData.map(vector -> new
LogisticRegressionModelData(vector, 0));
LogisticRegressionModel model =
new
LogisticRegressionModel().setModelData(tEnv.fromDataStream(modelData));
- ReadWriteUtils.updateExistingParams(model, paramMap);
+ ParamUtils.updateExistingParams(model, paramMap);
return model;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java
index 0ecbbd90..79566a74 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java
@@ -133,7 +133,7 @@ public class OnlineLogisticRegression
Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
OnlineLogisticRegressionModel model =
new
OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
- ReadWriteUtils.updateExistingParams(model, paramMap);
+ ParamUtils.updateExistingParams(model, paramMap);
return model;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
index 212ee2ca..404c5d19 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
@@ -135,7 +135,7 @@ public class NaiveBayes
NaiveBayesModel model =
new
NaiveBayesModel().setModelData(tEnv.fromDataStream(modelData, schema));
- ReadWriteUtils.updateExistingParams(model, paramMap);
+ ParamUtils.updateExistingParams(model, paramMap);
return model;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
index 6f2120eb..56330742 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
@@ -114,7 +114,7 @@ public class KMeans implements Estimator<KMeans,
KMeansModel>, KMeansParams<KMea
Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
KMeansModel model = new
KMeansModel().setModelData(finalModelDataTable);
- ReadWriteUtils.updateExistingParams(model, paramMap);
+ ParamUtils.updateExistingParams(model, paramMap);
return model;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
index 2c8dbd11..d1783a79 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
@@ -110,7 +110,7 @@ public class OnlineKMeans
Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
OnlineKMeansModel model = new
OnlineKMeansModel().setModelData(onlineModelDataTable);
- ReadWriteUtils.updateExistingParams(model, paramMap);
+ ParamUtils.updateExistingParams(model, paramMap);
return model;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizer.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizer.java
index 657f6418..b2bb1b93 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizer.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizer.java
@@ -88,7 +88,7 @@ public class CountVectorizer
CountVectorizerModel model =
new
CountVectorizerModel().setModelData(tEnv.fromDataStream(modelData));
- ReadWriteUtils.updateExistingParams(model, getParamMap());
+ ParamUtils.updateExistingParams(model, getParamMap());
return model;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDF.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDF.java
index f348dbdd..7c2d051c 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDF.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDF.java
@@ -78,7 +78,7 @@ public class IDF implements Estimator<IDF, IDFModel>,
IDFParams<IDF> {
DataStreamUtils.aggregate(inputData, new
IDFAggregator(getMinDocFreq()));
IDFModel model = new
IDFModel().setModelData(tEnv.fromDataStream(modelData));
- ReadWriteUtils.updateExistingParams(model, getParamMap());
+ ParamUtils.updateExistingParams(model, getParamMap());
return model;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/Imputer.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/Imputer.java
index 585402fe..b83f14c5 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/Imputer.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/Imputer.java
@@ -107,7 +107,7 @@ public class Imputer implements Estimator<Imputer,
ImputerModel>, ImputerParams<
.build();
ImputerModel model =
new ImputerModel().setModelData(tEnv.fromDataStream(modelData,
schema));
- ReadWriteUtils.updateExistingParams(model, getParamMap());
+ ParamUtils.updateExistingParams(model, getParamMap());
return model;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java
index 0a65406c..a5552f1d 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java
@@ -157,7 +157,7 @@ public class KBinsDiscretizer
KBinsDiscretizerModel model =
new
KBinsDiscretizerModel().setModelData(tEnv.fromDataStream(modelData));
- ReadWriteUtils.updateExistingParams(model, getParamMap());
+ ParamUtils.updateExistingParams(model, getParamMap());
return model;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/MinHashLSH.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/MinHashLSH.java
index 8586b964..0c9bf1e3 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/MinHashLSH.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/MinHashLSH.java
@@ -18,6 +18,7 @@
package org.apache.flink.ml.feature.lsh;
+import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
@@ -48,7 +49,7 @@ public class MinHashLSH extends LSH<MinHashLSH,
MinHashLSHModel>
dim,
getSeed()));
MinHashLSHModel model = new
MinHashLSHModel().setModelData(tEnv.fromDataStream(modelData));
- ReadWriteUtils.updateExistingParams(model, getParamMap());
+ ParamUtils.updateExistingParams(model, getParamMap());
return model;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScaler.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScaler.java
index 16dabe66..26aa0caa 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScaler.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScaler.java
@@ -94,7 +94,7 @@ public class MaxAbsScaler
MaxAbsScalerModel model =
new
MaxAbsScalerModel().setModelData(tEnv.fromDataStream(modelData));
- ReadWriteUtils.updateExistingParams(model, getParamMap());
+ ParamUtils.updateExistingParams(model, getParamMap());
return model;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java
index 28545e90..d21fd9a7 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java
@@ -111,7 +111,7 @@ public class MinMaxScaler
MinMaxScalerModel model =
new
MinMaxScalerModel().setModelData(tEnv.fromDataStream(modelData));
- ReadWriteUtils.updateExistingParams(model, getParamMap());
+ ParamUtils.updateExistingParams(model, getParamMap());
return model;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java
index 9b4fefdb..46d1cf35 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java
@@ -96,7 +96,7 @@ public class OneHotEncoder
OneHotEncoderModel model =
new
OneHotEncoderModel().setModelData(tEnv.fromDataStream(modelData));
- ReadWriteUtils.updateExistingParams(model, paramMap);
+ ParamUtils.updateExistingParams(model, paramMap);
return model;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScaler.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScaler.java
index 4004915c..cb689081 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScaler.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScaler.java
@@ -84,7 +84,7 @@ public class RobustScaler
new QuantileAggregator(getRelativeError(), getLower(),
getUpper()));
RobustScalerModel model =
new
RobustScalerModel().setModelData(tEnv.fromDataStream(modelData));
- ReadWriteUtils.updateExistingParams(model, getParamMap());
+ ParamUtils.updateExistingParams(model, getParamMap());
return model;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScaler.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScaler.java
index 89ea7bba..5c7d4474 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScaler.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScaler.java
@@ -104,7 +104,7 @@ public class OnlineStandardScaler
OnlineStandardScalerModel model =
new
OnlineStandardScalerModel().setModelData(tEnv.fromDataStream(modelData));
- ReadWriteUtils.updateExistingParams(model, paramMap);
+ ParamUtils.updateExistingParams(model, paramMap);
return model;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScaler.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScaler.java
index 20a2ffb2..59f519f2 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScaler.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScaler.java
@@ -96,7 +96,7 @@ public class StandardScaler
StandardScalerModel model =
new
StandardScalerModel().setModelData(tEnv.fromDataStream(modelData));
- ReadWriteUtils.updateExistingParams(model, paramMap);
+ ParamUtils.updateExistingParams(model, paramMap);
return model;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
index abaffa03..51dd727c 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
@@ -132,7 +132,7 @@ public class StringIndexer
StringIndexerModel model =
new
StringIndexerModel().setModelData(tEnv.fromDataStream(modelData));
- ReadWriteUtils.updateExistingParams(model, paramMap);
+ ParamUtils.updateExistingParams(model, paramMap);
return model;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelector.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelector.java
index 78810d30..8c66bea9 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelector.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelector.java
@@ -149,7 +149,7 @@ public class UnivariateFeatureSelector
.setParallelism(1);
UnivariateFeatureSelectorModel model =
new
UnivariateFeatureSelectorModel().setModelData(tEnv.fromDataStream(modelData));
- ReadWriteUtils.updateExistingParams(model, getParamMap());
+ ParamUtils.updateExistingParams(model, getParamMap());
return model;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java
index 4c965c75..621415f0 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java
@@ -86,7 +86,7 @@ public class VarianceThresholdSelector
VarianceThresholdSelectorModel model =
new
VarianceThresholdSelectorModel().setModelData(tEnv.fromDataStream(modelData));
- ReadWriteUtils.updateExistingParams(model, getParamMap());
+ ParamUtils.updateExistingParams(model, getParamMap());
return model;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java
index 203fb9c7..0bf4c7e5 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java
@@ -134,7 +134,7 @@ public class VectorIndexer
VectorIndexerModel model =
new
VectorIndexerModel().setModelData(tEnv.fromDataStream(modelData, schema));
- ReadWriteUtils.updateExistingParams(model, paramMap);
+ ParamUtils.updateExistingParams(model, paramMap);
return model;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java
index 3de48049..d977eb0e 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java
@@ -106,7 +106,7 @@ public class LinearRegression
rawModelData.map(LinearRegressionModelData::new);
LinearRegressionModel model =
new
LinearRegressionModel().setModelData(tEnv.fromDataStream(modelData));
- ReadWriteUtils.updateExistingParams(model, paramMap);
+ ParamUtils.updateExistingParams(model, paramMap);
return model;
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/KnnTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/KnnTest.java
index dedafeea..28d9c4f7 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/KnnTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/KnnTest.java
@@ -27,7 +27,7 @@ import org.apache.flink.ml.linalg.DenseMatrix;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -243,7 +243,7 @@ public class KnnTest extends AbstractTestBase {
KnnModel modelA = knn.fit(trainData);
Table modelData = modelA.getModelData()[0];
KnnModel modelB = new KnnModel().setModelData(modelData);
- ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap());
+ ParamUtils.updateExistingParams(modelB, modelA.getParamMap());
Table output = modelB.transform(predictData)[0];
verifyPredictionResult(output, knn.getLabelCol(),
knn.getPredictionCol());
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java
index 5f3e3ac9..9b9f7a76 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java
@@ -29,7 +29,7 @@ import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
-import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
@@ -245,7 +245,7 @@ public class LinearSVCTest extends AbstractTestBase {
LinearSVCModel model = linearSVC.fit(trainDataTable);
LinearSVCModel newModel = new LinearSVCModel();
- ReadWriteUtils.updateExistingParams(newModel, model.getParamMap());
+ ParamUtils.updateExistingParams(newModel, model.getParamMap());
newModel.setModelData(model.getModelData());
Table output = newModel.transform(trainDataTable)[0];
verifyPredictionResult(
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
index b8a70ffd..20db30f0 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
@@ -29,7 +29,7 @@ import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
-import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
@@ -272,7 +272,7 @@ public class LogisticRegressionTest extends
AbstractTestBase {
LogisticRegressionModel model =
logisticRegression.fit(binomialDataTable);
LogisticRegressionModel newModel = new LogisticRegressionModel();
- ReadWriteUtils.updateExistingParams(newModel, model.getParamMap());
+ ParamUtils.updateExistingParams(newModel, model.getParamMap());
newModel.setModelData(model.getModelData());
Table output = newModel.transform(binomialDataTable)[0];
verifyPredictionResult(
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java
index e5473a8c..d727964e 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java
@@ -24,7 +24,7 @@ import
org.apache.flink.ml.classification.naivebayes.NaiveBayesModelData;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
@@ -306,7 +306,7 @@ public class NaiveBayesTest extends AbstractTestBase {
Table modelData = modelA.getModelData()[0];
NaiveBayesModel modelB = new NaiveBayesModel().setModelData(modelData);
- ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap());
+ ParamUtils.updateExistingParams(modelB, modelA.getParamMap());
Table outputTable = modelB.transform(predictTable)[0];
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
index 3ea2c95b..ad54926f 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
@@ -26,7 +26,7 @@ import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -255,7 +255,7 @@ public class KMeansTest extends AbstractTestBase {
KMeans kmeans = new KMeans().setMaxIter(2).setK(2);
KMeansModel modelA = kmeans.fit(dataTable);
KMeansModel modelB = new
KMeansModel().setModelData(modelA.getModelData());
- ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap());
+ ParamUtils.updateExistingParams(modelB, modelA.getParamMap());
Table output = modelB.transform(dataTable)[0];
List<Row> results = IteratorUtils.toList(output.execute().collect());
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/IDFTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/IDFTest.java
index 8d5feb7b..4d4d8e1b 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/IDFTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/IDFTest.java
@@ -23,7 +23,7 @@ import org.apache.flink.ml.feature.idf.IDFModel;
import org.apache.flink.ml.feature.idf.IDFModelData;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
@@ -183,7 +183,7 @@ public class IDFTest extends AbstractTestBase {
IDFModel model = new IDF().fit(inputTable);
IDFModel newModel = new IDFModel();
- ReadWriteUtils.updateExistingParams(newModel, model.getParamMap());
+ ParamUtils.updateExistingParams(newModel, model.getParamMap());
newModel.setModelData(model.getModelData());
Table output = newModel.transform(inputTable)[0];
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/KBinsDiscretizerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/KBinsDiscretizerTest.java
index 1553118c..47fecfbd 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/KBinsDiscretizerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/KBinsDiscretizerTest.java
@@ -24,7 +24,7 @@ import
org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModelData;
import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerParams;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
@@ -253,7 +253,7 @@ public class KBinsDiscretizerTest extends AbstractTestBase {
KBinsDiscretizerModel model = kBinsDiscretizer.fit(trainTable);
KBinsDiscretizerModel newModel = new KBinsDiscretizerModel();
- ReadWriteUtils.updateExistingParams(newModel, model.getParamMap());
+ ParamUtils.updateExistingParams(newModel, model.getParamMap());
newModel.setModelData(model.getModelData());
Table output = newModel.transform(testTable)[0];
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MaxAbsScalerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MaxAbsScalerTest.java
index efc3b243..ece5789d 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MaxAbsScalerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MaxAbsScalerTest.java
@@ -25,7 +25,7 @@ import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
-import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -240,7 +240,7 @@ public class MaxAbsScalerTest {
Table modelData = modelA.getModelData()[0];
MaxAbsScalerModel modelB = new
MaxAbsScalerModel().setModelData(modelData);
- ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap());
+ ParamUtils.updateExistingParams(modelB, modelA.getParamMap());
Table output = modelB.transform(predictDataTable)[0];
verifyPredictionResult(output, maxAbsScaler.getOutputCol(),
EXPECTED_DATA);
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinHashLSHTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinHashLSHTest.java
index fdb71f65..9de24b88 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinHashLSHTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinHashLSHTest.java
@@ -25,7 +25,7 @@ import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -353,7 +353,7 @@ public class MinHashLSHTest extends AbstractTestBase {
MinHashLSHModel modelA = lsh.fit(inputTable);
Table modelDataData = modelA.getModelData()[0];
MinHashLSHModel modelB = new
MinHashLSHModel().setModelData(modelDataData);
- ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap());
+ ParamUtils.updateExistingParams(modelB, modelA.getParamMap());
Table output =
modelB.transform(inputTable)[0].select($(lsh.getOutputCol()));
verifyPredictionResult(output, outputRows);
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
index 3051a061..7c2ac0bb 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
@@ -24,7 +24,7 @@ import
org.apache.flink.ml.feature.minmaxscaler.MinMaxScalerModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -205,7 +205,7 @@ public class MinMaxScalerTest extends AbstractTestBase {
MinMaxScalerModel modelA = minMaxScaler.fit(trainDataTable);
Table modelData = modelA.getModelData()[0];
MinMaxScalerModel modelB = new
MinMaxScalerModel().setModelData(modelData);
- ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap());
+ ParamUtils.updateExistingParams(modelB, modelA.getParamMap());
Table output = modelB.transform(predictDataTable)[0];
verifyPredictionResult(output, minMaxScaler.getOutputCol(),
EXPECTED_DATA);
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
index fe6e9277..76682440 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
@@ -25,7 +25,7 @@ import
org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel;
import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModelData;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
@@ -289,7 +289,7 @@ public class OneHotEncoderTest extends AbstractTestBase {
Table modelData = modelA.getModelData()[0];
OneHotEncoderModel modelB = new
OneHotEncoderModel().setModelData(modelData);
- ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap());
+ ParamUtils.updateExistingParams(modelB, modelA.getParamMap());
Table outputTable = modelB.transform(predictTable)[0];
Map<Double, Vector>[] actualOutput =
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OnlineStandardScalerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OnlineStandardScalerTest.java
index 400902dc..02798303 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OnlineStandardScalerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OnlineStandardScalerTest.java
@@ -36,7 +36,7 @@ import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
-import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -356,7 +356,7 @@ public class OnlineStandardScalerTest extends
AbstractTestBase {
OnlineStandardScalerModel model =
standardScaler.fit(inputTableWithEventTime);
OnlineStandardScalerModel newModel = new OnlineStandardScalerModel();
- ReadWriteUtils.updateExistingParams(newModel, model.getParamMap());
+ ParamUtils.updateExistingParams(newModel, model.getParamMap());
newModel.setModelData(model.getModelData());
Table output = newModel.transform(inputTableWithEventTime)[0];
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/StandardScalerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/StandardScalerTest.java
index 5d305409..b36cd885 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/StandardScalerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/StandardScalerTest.java
@@ -25,7 +25,7 @@ import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
@@ -233,7 +233,7 @@ public class StandardScalerTest extends AbstractTestBase {
StandardScalerModel model = standardScaler.fit(denseTable);
StandardScalerModel newModel = new StandardScalerModel();
- ReadWriteUtils.updateExistingParams(newModel, model.getParamMap());
+ ParamUtils.updateExistingParams(newModel, model.getParamMap());
newModel.setModelData(model.getModelData());
Table output = newModel.transform(denseTable)[0];
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorIndexerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorIndexerTest.java
index f725809f..0d4f1789 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorIndexerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorIndexerTest.java
@@ -23,7 +23,7 @@ import
org.apache.flink.ml.feature.vectorindexer.VectorIndexer;
import org.apache.flink.ml.feature.vectorindexer.VectorIndexerModel;
import org.apache.flink.ml.feature.vectorindexer.VectorIndexerModelData;
import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Expressions;
@@ -250,7 +250,7 @@ public class VectorIndexerTest extends AbstractTestBase {
VectorIndexerModel model = vectorIndexer.fit(trainInputTable);
VectorIndexerModel newModel = new VectorIndexerModel();
- ReadWriteUtils.updateExistingParams(newModel, model.getParamMap());
+ ParamUtils.updateExistingParams(newModel, model.getParamMap());
newModel.setModelData(model.getModelData());
Table output = newModel.transform(testInputTable)[0];
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
index 47e7c980..2b41ddb9 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
@@ -19,7 +19,7 @@
package org.apache.flink.ml.feature.stringindexer;
import org.apache.flink.ml.common.param.HasHandleInvalid;
-import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
@@ -339,7 +339,7 @@ public class StringIndexerTest extends AbstractTestBase {
.fit(trainTable);
StringIndexerModel newModel = new StringIndexerModel();
- ReadWriteUtils.updateExistingParams(newModel, model.getParamMap());
+ ParamUtils.updateExistingParams(newModel, model.getParamMap());
newModel.setModelData(model.getModelData());
Table output = newModel.transform(predictTable)[0];
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java
index ed2fa836..9776965f 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java
@@ -27,7 +27,7 @@ import
org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.regression.linearregression.LinearRegression;
import org.apache.flink.ml.regression.linearregression.LinearRegressionModel;
import
org.apache.flink.ml.regression.linearregression.LinearRegressionModelData;
-import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
@@ -218,7 +218,7 @@ public class LinearRegressionTest extends AbstractTestBase {
LinearRegressionModel model = linearRegression.fit(trainDataTable);
LinearRegressionModel newModel = new LinearRegressionModel();
- ReadWriteUtils.updateExistingParams(newModel, model.getParamMap());
+ ParamUtils.updateExistingParams(newModel, model.getParamMap());
newModel.setModelData(model.getModelData());
Table output = newModel.transform(trainDataTable)[0];
verifyPredictionResult(
diff --git a/flink-ml-python/pyflink/ml/feature/tests/test_minhashlsh.py
b/flink-ml-python/pyflink/ml/feature/tests/test_minhashlsh.py
index cdafca0a..5463a1f7 100644
--- a/flink-ml-python/pyflink/ml/feature/tests/test_minhashlsh.py
+++ b/flink-ml-python/pyflink/ml/feature/tests/test_minhashlsh.py
@@ -238,7 +238,7 @@ class MinHashLSHTest(PyFlinkMLTestCase):
@classmethod
def update_existing_params(cls, target: JavaWithParams, source:
JavaWithParams):
- get_gateway().jvm.org.apache.flink.ml.util.ReadWriteUtils \
+ get_gateway().jvm.org.apache.flink.ml.util.ParamUtils \
.updateExistingParams(target._java_obj,
source._java_obj.getParamMap())
@classmethod
diff --git a/flink-ml-python/pyflink/ml/tests/test_utils.py
b/flink-ml-python/pyflink/ml/tests/test_utils.py
index dae8c9ad..0909971d 100644
--- a/flink-ml-python/pyflink/ml/tests/test_utils.py
+++ b/flink-ml-python/pyflink/ml/tests/test_utils.py
@@ -34,7 +34,7 @@ from pyflink.ml.wrapper import JavaWithParams
def update_existing_params(target: JavaWithParams, source: JavaWithParams):
- get_gateway().jvm.org.apache.flink.ml.util.ReadWriteUtils \
+ get_gateway().jvm.org.apache.flink.ml.util.ParamUtils \
.updateExistingParams(target._java_obj, source._java_obj.getParamMap())
diff --git
a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/servable/api/DataFrame.java
b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/servable/api/DataFrame.java
index 32e20e41..780d0855 100644
---
a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/servable/api/DataFrame.java
+++
b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/servable/api/DataFrame.java
@@ -97,7 +97,7 @@ public class DataFrame {
* @throws IllegalArgumentException if the number of values is different
from the number of
* rows.
*/
- public DataFrame addColumn(String columnName, DataType dataType,
List<Object> values) {
+ public DataFrame addColumn(String columnName, DataType dataType, List<?>
values) {
if (values.size() != rows.size()) {
throw new RuntimeException(
String.format(
@@ -107,7 +107,7 @@ public class DataFrame {
columnNames.add(columnName);
dataTypes.add(dataType);
- Iterator<Object> iter = values.iterator();
+ Iterator<?> iter = values.iterator();
for (Row row : rows) {
Object value = iter.next();
row.add(value);
diff --git
a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/servable/builder/PipelineModelServable.java
b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/servable/builder/PipelineModelServable.java
new file mode 100644
index 00000000..0b90a681
--- /dev/null
+++
b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/servable/builder/PipelineModelServable.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.servable.builder;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.servable.api.DataFrame;
+import org.apache.flink.ml.servable.api.ModelServable;
+import org.apache.flink.ml.servable.api.TransformerServable;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ServableReadWriteUtils;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * A PipelineModelServable acts as a {@link ModelServable}. It consists of an
ordered list of
+ * servables, each of which could be a TransformerServable or ModelServable.
+ */
+@PublicEvolving
+public final class PipelineModelServable implements
ModelServable<PipelineModelServable> {
+
+ private final List<TransformerServable<?>> servables;
+
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ public PipelineModelServable(List<TransformerServable<?>> servables) {
+ this.servables = Preconditions.checkNotNull(servables);
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public DataFrame transform(DataFrame input) {
+ for (TransformerServable<?> servable : servables) {
+ input = servable.transform(input);
+ }
+ return input;
+ }
+
+ public static PipelineModelServable load(String path) throws IOException {
+ return new
PipelineModelServable(ServableReadWriteUtils.loadPipeline(path));
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+}
diff --git
a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/servable/types/DataTypes.java
b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/servable/types/DataTypes.java
new file mode 100644
index 00000000..1fa33fd0
--- /dev/null
+++
b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/servable/types/DataTypes.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.servable.types;
+
+/** This class gives access to the most common types that are used to define
DataFrames. */
+public class DataTypes {
+
+ public static final ScalarType BOOLEAN = new ScalarType(BasicType.BOOLEAN);
+
+ public static final ScalarType BYTE = new ScalarType(BasicType.BYTE);
+
+ public static final ScalarType SHORT = new ScalarType(BasicType.SHORT);
+
+ public static final ScalarType INT = new ScalarType(BasicType.INT);
+
+ public static final ScalarType LONG = new ScalarType(BasicType.LONG);
+
+ public static final ScalarType FLOAT = new ScalarType(BasicType.FLOAT);
+
+ public static final ScalarType DOUBLE = new ScalarType(BasicType.DOUBLE);
+
+ public static final ScalarType STRING = new ScalarType(BasicType.STRING);
+
+ public static final ScalarType BYTE_STRING = new
ScalarType(BasicType.BYTE_STRING);
+
+ public static VectorType VECTOR(BasicType elementType) {
+ return new VectorType(elementType);
+ }
+
+ public static MatrixType MATRIX(BasicType elementType) {
+ return new MatrixType(elementType);
+ }
+}
diff --git
a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/FileUtils.java
b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/FileUtils.java
new file mode 100644
index 00000000..9eb975d4
--- /dev/null
+++
b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/FileUtils.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.io.OutputStreamWriter;
+import java.util.Map;
+
+/** Utility methods for file operations. */
+public class FileUtils {
+
+ /** Saves a given string to the specified file. */
+ public static void saveToFile(String pathStr, String content, boolean
isOverwrite)
+ throws IOException {
+ Path path = new Path(pathStr);
+
+ // Creates parent directories if not already created.
+ FileSystem fs = mkdirs(path.getParent());
+
+ FileSystem.WriteMode writeMode = FileSystem.WriteMode.OVERWRITE;
+ if (!isOverwrite) {
+ writeMode = FileSystem.WriteMode.NO_OVERWRITE;
+ if (fs.exists(path)) {
+ throw new IOException("File " + path + " already exists.");
+ }
+ }
+ try (BufferedWriter writer =
+ new BufferedWriter(new OutputStreamWriter(fs.create(path,
writeMode)))) {
+ writer.write(content);
+ }
+ }
+
+ public static FileSystem mkdirs(Path path) throws IOException {
+ FileSystem fs = path.getFileSystem();
+ fs.mkdirs(path);
+ return fs;
+ }
+
+ /**
+ * Loads the metadata from the metadata file under the given path.
+ *
+ * <p>The method throws RuntimeException if the expectedClassName is not
empty AND it does not
+ * match the className of the previously saved stage.
+ *
+ * @param path The parent directory of the metadata file to read from.
+ * @param expectedClassName The expected class name of the stage.
+ * @return A map from metadata name to metadata value.
+ */
+ public static Map<String, ?> loadMetadata(String path, String
expectedClassName)
+ throws IOException {
+ Path metadataPath = new Path(path, "metadata");
+ FileSystem fs = metadataPath.getFileSystem();
+
+ StringBuilder buffer = new StringBuilder();
+ try (BufferedReader br = new BufferedReader(new
InputStreamReader(fs.open(metadataPath)))) {
+ String line;
+ while ((line = br.readLine()) != null) {
+ if (!line.startsWith("#")) {
+ buffer.append(line);
+ }
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ Map<String, ?> result =
JsonUtils.OBJECT_MAPPER.readValue(buffer.toString(), Map.class);
+
+ String className = (String) result.get("className");
+ if (!expectedClassName.isEmpty() &&
!expectedClassName.equals(className)) {
+ throw new RuntimeException(
+ "Class name "
+ + className
+ + " does not match the expected class name "
+ + expectedClassName
+ + ".");
+ }
+
+ return result;
+ }
+
+ // Returns a string with value {parentPath}/stages/{stageIdx}, where the
stageIdx is prefixed
+ // with zero or more `0` to have the same length as numStages. The
resulting string can be
+ // used as the directory to save a stage of the Pipeline or PipelineModel.
+ public static String getPathForPipelineStage(int stageIdx, int numStages,
String parentPath) {
+ String format =
+ String.format("stages%s%%0%dd", File.separator,
String.valueOf(numStages).length());
+ String fileName = String.format(format, stageIdx);
+ return new Path(parentPath, fileName).toString();
+ }
+
+ /** Returns a subdirectory of the given path for saving/loading model
data. */
+ public static Path getDataPath(String path) {
+ return new Path(path, "data");
+ }
+}
diff --git
a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/ParamUtils.java
b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/ParamUtils.java
index cdbe63d5..a00263e0 100644
---
a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/ParamUtils.java
+++
b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/ParamUtils.java
@@ -20,12 +20,16 @@ package org.apache.flink.ml.util;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.WithParams;
+import org.apache.flink.util.InstantiationUtil;
+import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Set;
/** Utility methods for reading and writing stages. */
public class ParamUtils {
@@ -86,4 +90,54 @@ public class ParamUtils {
}
return result;
}
+
+ // A helper method that sets WithParams object's parameter value. We can
not call
+ // WithParams.set(param, value)
+ // directly because WithParams::set(...) needs the actual type of the
value.
+ @SuppressWarnings("unchecked")
+ public static <T> void setParam(WithParams<?> instance, Param<T> param,
Object value) {
+ instance.set(param, (T) value);
+ }
+
+ // A helper method that updates WithParams instance's param map using
values from the
+ // paramOverrides. This method only updates values for parameters already
defined in the
+ // instance's param map.
+ public static void updateExistingParams(
+ WithParams<?> instance, Map<Param<?>, Object> paramOverrides) {
+ Set<Param<?>> existingParams = instance.getParamMap().keySet();
+ for (Map.Entry<Param<?>, Object> entry : paramOverrides.entrySet()) {
+ if (existingParams.contains(entry.getKey())) {
+ ParamUtils.setParam(instance, entry.getKey(),
entry.getValue());
+ }
+ }
+ }
+
+ /**
+ * Instantiates a WithParams subclass from the provided json map.
+ *
+ * @param jsonMap a map containing className and paramMap.
+ * @return the instantiated WithParams subclass instance.
+ */
+ @SuppressWarnings("unchecked")
+ public static <T extends WithParams<T>> T
instantiateWithParams(Map<String, ?> jsonMap)
+ throws ClassNotFoundException, IOException {
+ String className = (String) jsonMap.get("className");
+ Class<T> clazz = (Class<T>) Class.forName(className);
+ T instance = InstantiationUtil.instantiate(clazz);
+
+ Map<String, Param<?>> nameToParam = new HashMap<>();
+ for (Param<?> param : ParamUtils.getPublicFinalParamFields(instance)) {
+ nameToParam.put(param.name, param);
+ }
+
+ if (jsonMap.containsKey("paramMap")) {
+ Map<String, Object> paramMap = (Map<String, Object>)
jsonMap.get("paramMap");
+ for (Map.Entry<String, Object> entry : paramMap.entrySet()) {
+ Param<?> param = nameToParam.get(entry.getKey());
+ ParamUtils.setParam(instance, param,
param.jsonDecode(entry.getValue()));
+ }
+ }
+
+ return instance;
+ }
}
diff --git
a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/ServableReadWriteUtils.java
b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/ServableReadWriteUtils.java
new file mode 100644
index 00000000..c38e3d2f
--- /dev/null
+++
b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/ServableReadWriteUtils.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.core.fs.FileStatus;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.servable.api.TransformerServable;
+import org.apache.flink.ml.servable.builder.PipelineModelServable;
+import org.apache.flink.util.InstantiationUtil;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.ml.util.FileUtils.loadMetadata;
+
+/** Utility methods for loading Servables. */
+public class ServableReadWriteUtils {
+
+ /**
+ * Loads the servables of a {@link PipelineModelServable} from the given
path.
+ *
+ * <p>The number of stages is read from the "numStages" metadata entry. The
saved class name is
+ * not checked, because the metadata is loaded with an empty expected class name.
+ *
+ * @param path The parent directory to load the PipelineModelServable
metadata and its
+ * servables.
+ * @return A list of servables.
+ */
+ public static List<TransformerServable<?>> loadPipeline(String path)
throws IOException {
+ Map<String, ?> metadata = loadMetadata(path, "");
+ int numStages = (Integer) metadata.get("numStages");
+ List<TransformerServable<?>> servables = new ArrayList<>(numStages);
+
+ for (int i = 0; i < numStages; i++) {
+ String stagePath = FileUtils.getPathForPipelineStage(i, numStages,
path);
+ servables.add(loadServable(stagePath));
+ }
+ return servables;
+ }
+
+ /**
+ * Loads the {@link TransformerServable} from the given path by invoking
the static
+ * loadServable() method of the stage. The stage class name is read from
the metadata file under
+ * the given path. The loadServable() method is expected to construct the
TransformerServable
+ * instance with the saved parameters, model data and other metadata if
exists.
+ *
+ * <p>Required: the stage class must have a static loadServable() method.
+ *
+ * @param path The parent directory of the stage metadata file.
+ * @return An instance of {@link TransformerServable}.
+ */
+ private static TransformerServable<?> loadServable(String path) throws
IOException {
+ Map<String, ?> metadata = FileUtils.loadMetadata(path, "");
+ String className = (String) metadata.get("className");
+
+ try {
+ Class<?> clazz = Class.forName(className);
+ Method method = clazz.getMethod("loadServable", String.class);
+ method.setAccessible(true);
+ return (TransformerServable<?>) method.invoke(null, path);
+ } catch (NoSuchMethodException e) {
+ String methodName = String.format("%s::loadServable(String)",
className);
+ throw new RuntimeException(
+ "Failed to load servable because the static method "
+ + methodName
+ + " is not implemented.",
+ e);
+ } catch (ClassNotFoundException | IllegalAccessException |
InvocationTargetException e) {
+ throw new RuntimeException("Failed to load servable.", e);
+ }
+ }
+
+ /**
+ * Loads the {@link TransformerServable} with the saved parameters from
the given path. This
+ * method reads the metadata file under the given path, instantiates the
servable using its
+ * no-argument constructor, and loads the servable with the paramMap from
the metadata file.
+ *
+ * <p>Note: This method does not attempt to read model data from the given
path. Caller needs to
+ * read and deserialize model data from the given path.
+ *
+ * <p>Required: the class with type T must have a no-argument constructor.
+ *
+ * @param path The parent directory of the metadata file.
+ * @param <T> The class type of the TransformerServable subclass.
+ * @return An instance of class type T.
+ */
+ public static <T extends TransformerServable<T>> T loadServableParam(
+ String path, Class<T> clazz) throws IOException {
+ T instance = InstantiationUtil.instantiate(clazz);
+
+ Map<String, Param<?>> nameToParam = new HashMap<>();
+ for (Param<?> param : ParamUtils.getPublicFinalParamFields(instance)) {
+ nameToParam.put(param.name, param);
+ }
+
+ Map<String, ?> jsonMap = loadMetadata(path, "");
+ if (jsonMap.containsKey("paramMap")) {
+ Map<String, Object> paramMap = (Map<String, Object>)
jsonMap.get("paramMap");
+ for (Map.Entry<String, Object> entry : paramMap.entrySet()) {
+ Param<?> param = nameToParam.get(entry.getKey());
+ ParamUtils.setParam(instance, param,
param.jsonDecode(entry.getValue()));
+ }
+ }
+
+ return instance;
+ }
+
+ /**
+ * Opens an FSDataInputStream to read the model data file in the
directory. Only one model data
+ * file is expected to be in the directory.
+ *
+ * @param path The parent directory of the model data file.
+ * @return An FSDataInputStream to read the model data.
+ */
+ public static InputStream loadModelData(String path) throws IOException {
+ Path modelDataPath = FileUtils.getDataPath(path);
+
+ FileSystem fileSystem = modelDataPath.getFileSystem();
+
+ FileStatus[] files = fileSystem.listStatus(modelDataPath);
+ Preconditions.checkState(
+ files.length == 1,
+ "Only one model data file is expected in the directory %s.",
+ path);
+ return fileSystem.open(files[0].getPath());
+ }
+}
diff --git
a/flink-ml-servable-core/src/test/java/org/apache/flink/ml/servable/TestUtils.java
b/flink-ml-servable-core/src/test/java/org/apache/flink/ml/servable/TestUtils.java
new file mode 100644
index 00000000..ac1aa989
--- /dev/null
+++
b/flink-ml-servable-core/src/test/java/org/apache/flink/ml/servable/TestUtils.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.servable;
+
+import org.apache.flink.ml.servable.api.DataFrame;
+import org.apache.flink.ml.servable.api.Row;
+
+import org.junit.Assert;
+
+import java.util.List;
+
+/** Utility methods for tests. */
+public class TestUtils {
+
+ /** Asserts that the two dataframes are equivalent. */
+ public static void assertDataFrameEquals(DataFrame first, DataFrame
second) {
+
+ List<String> firstColNames = first.getColumnNames();
+ Assert.assertEquals(first.getColumnNames(), second.getColumnNames());
+
+ List<Row> firstRows = first.collect();
+ List<Row> secondRows = second.collect();
+
+ for (int i = 0; i < firstColNames.size(); i++) {
+ String colName = firstColNames.get(i);
+ Assert.assertEquals(first.getDataType(colName),
second.getDataType(colName));
+
+ for (int j = 0; j < firstRows.size(); j++) {
+ Assert.assertEquals(firstRows.get(j).get(i),
secondRows.get(j).get(i));
+ }
+ }
+ }
+}
diff --git
a/flink-ml-servable-core/src/test/java/org/apache/flink/ml/servable/builder/ExampleServables.java
b/flink-ml-servable-core/src/test/java/org/apache/flink/ml/servable/builder/ExampleServables.java
new file mode 100644
index 00000000..813b80f2
--- /dev/null
+++
b/flink-ml-servable-core/src/test/java/org/apache/flink/ml/servable/builder/ExampleServables.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.servable.builder;
+
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.servable.api.DataFrame;
+import org.apache.flink.ml.servable.api.ModelServable;
+import org.apache.flink.ml.servable.api.Row;
+import org.apache.flink.ml.servable.api.TransformerServable;
+import org.apache.flink.ml.servable.types.DataTypes;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ServableReadWriteUtils;
+import org.apache.flink.util.Preconditions;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** Defines Servable subclasses to be used in unit tests. */
+public class ExampleServables {
+
+    /**
+     * A {@link ModelServable} whose model data is a single integer {@code delta}. Transforming a
+     * dataframe adds {@code delta} to the only value of every input row and outputs the sums.
+     */
+    public static class SumModelServable implements ModelServable<SumModelServable> {
+
+        private static final String COL_NAME = "input";
+
+        private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+        private int delta;
+
+        public SumModelServable() {
+            ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        }
+
+        /** Adds {@code delta} to the value of each single-column input row. */
+        @Override
+        public DataFrame transform(DataFrame input) {
+            List<Row> shiftedRows = new ArrayList<>();
+            for (Row inputRow : input.collect()) {
+                Preconditions.checkState(inputRow.size() == 1);
+                int shiftedValue = ((Integer) inputRow.get(0)) + delta;
+                shiftedRows.add(new Row(Collections.singletonList(shiftedValue)));
+            }
+
+            List<String> outputNames = Collections.singletonList(COL_NAME);
+            return new DataFrame(
+                    outputNames, Collections.singletonList(DataTypes.INT), shiftedRows);
+        }
+
+        @Override
+        public Map<Param<?>, Object> getParamMap() {
+            return paramMap;
+        }
+
+        /** Loads the servable's params and its {@code delta} model data from the given path. */
+        public static SumModelServable load(String path) throws IOException {
+            SumModelServable loaded =
+                    ServableReadWriteUtils.loadServableParam(path, SumModelServable.class);
+
+            try (InputStream modelDataStream = ServableReadWriteUtils.loadModelData(path)) {
+                loaded.delta = readDelta(modelDataStream);
+            }
+            return loaded;
+        }
+
+        /** Sets {@code delta} from exactly one stream holding a serialized integer. */
+        public SumModelServable setModelData(InputStream... modelDataInputs) throws IOException {
+            Preconditions.checkArgument(modelDataInputs.length == 1);
+            delta = readDelta(modelDataInputs[0]);
+            return this;
+        }
+
+        /** Serializes the given model data, which must be an {@link Integer}, into bytes. */
+        public static byte[] serialize(Object modelData) throws IOException {
+            ByteArrayOutputStream buffer = new ByteArrayOutputStream();
+            DataOutputViewStreamWrapper bufferView = new DataOutputViewStreamWrapper(buffer);
+            IntSerializer.INSTANCE.serialize((Integer) modelData, bufferView);
+            return buffer.toByteArray();
+        }
+
+        /** Reads one integer written with {@link IntSerializer} from the given stream. */
+        private static int readDelta(InputStream modelDataInput) throws IOException {
+            DataInputViewStreamWrapper inputView = new DataInputViewStreamWrapper(modelDataInput);
+            return IntSerializer.INSTANCE.deserialize(inputView);
+        }
+    }
+}
diff --git
a/flink-ml-servable-core/src/test/java/org/apache/flink/ml/servable/builder/PipelineModelServableTest.java
b/flink-ml-servable-core/src/test/java/org/apache/flink/ml/servable/builder/PipelineModelServableTest.java
new file mode 100644
index 00000000..2d3e5ba9
--- /dev/null
+++
b/flink-ml-servable-core/src/test/java/org/apache/flink/ml/servable/builder/PipelineModelServableTest.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.servable.builder;
+
+import org.apache.flink.ml.servable.TestUtils;
+import org.apache.flink.ml.servable.api.DataFrame;
+import org.apache.flink.ml.servable.api.Row;
+import org.apache.flink.ml.servable.api.TransformerServable;
+import org.apache.flink.ml.servable.builder.ExampleServables.SumModelServable;
+import org.apache.flink.ml.servable.types.DataTypes;
+
+import org.junit.Test;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+/** Tests the {@link PipelineModelServable}. */
+public class PipelineModelServableTest {
+
+ @Test
+ public void testTransform() throws IOException {
+ SumModelServable servableA =
+ new SumModelServable()
+ .setModelData(new
ByteArrayInputStream(SumModelServable.serialize(10)));
+ SumModelServable servableB =
+ new SumModelServable()
+ .setModelData(new
ByteArrayInputStream(SumModelServable.serialize(20)));
+ SumModelServable servableC =
+ new SumModelServable()
+ .setModelData(new
ByteArrayInputStream(SumModelServable.serialize(30)));
+
+ List<TransformerServable<?>> servables = Arrays.asList(servableA,
servableB, servableC);
+
+ TransformerServable<?> pipelineModelServable = new
PipelineModelServable(servables);
+
+ DataFrame input =
+ new DataFrame(
+ Collections.singletonList("input"),
+ Collections.singletonList(DataTypes.INT),
+ Arrays.asList(
+ new Row(Collections.singletonList(1)),
+ new Row(Collections.singletonList(2)),
+ new Row(Collections.singletonList(3))));
+
+ DataFrame output = pipelineModelServable.transform(input);
+
+ DataFrame expectedOutput =
+ new DataFrame(
+ Collections.singletonList("input"),
+ Collections.singletonList(DataTypes.INT),
+ Arrays.asList(
+ new Row(Collections.singletonList(61)),
+ new Row(Collections.singletonList(62)),
+ new Row(Collections.singletonList(63))));
+
+ TestUtils.assertDataFrameEquals(expectedOutput, output);
+ }
+}
diff --git a/tools/maven/suppressions.xml b/tools/maven/suppressions.xml
index 129a0c7a..1158fb76 100644
--- a/tools/maven/suppressions.xml
+++ b/tools/maven/suppressions.xml
@@ -36,6 +36,7 @@ under the License.
<suppress files="WindowOperatorTest.java" checks="FileLength"/>
<suppress files="WindowOperatorContractTest.java"
checks="FileLength"/>
<suppress files="NFAITCase.java" checks="FileLength"/>
+ <suppress files="DataTypes.java" checks="MethodNameCheck"/>
<suppress
files="org[\\/]apache[\\/]flink[\\/]formats[\\/]avro[\\/]generated[\\/].*.java"
checks="[a-zA-Z0-9]*"/>
<suppress
files="org[\\/]apache[\\/]flink[\\/]formats[\\/]parquet[\\/]generated[\\/].*.java"
checks="[a-zA-Z0-9]*"/>