This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 9b353cc [FLINK-26904] Update Stage::load to use StreamTableEnvironment
9b353cc is described below
commit 9b353cc875f7cdd6e1a80a63c9f7eb7eabd8fb41
Author: yunfengzhou-hub <[email protected]>
AuthorDate: Wed Mar 30 15:43:07 2022 +0800
[FLINK-26904] Update Stage::load to use StreamTableEnvironment
This closes #76.
---
.../main/java/org/apache/flink/ml/api/Stage.java | 6 ++--
.../java/org/apache/flink/ml/builder/Graph.java | 6 ++--
.../org/apache/flink/ml/builder/GraphModel.java | 6 ++--
.../java/org/apache/flink/ml/builder/Pipeline.java | 6 ++--
.../org/apache/flink/ml/builder/PipelineModel.java | 7 ++--
.../flink/ml/common/datastream/TableUtils.java | 13 +++++++-
.../org/apache/flink/ml/util/ReadWriteUtils.java | 37 ++++++++++++----------
.../org/apache/flink/ml/api/ExampleStages.java | 15 ++++-----
.../java/org/apache/flink/ml/api/GraphTest.java | 19 ++++++-----
.../java/org/apache/flink/ml/api/PipelineTest.java | 4 +--
.../java/org/apache/flink/ml/api/StageTest.java | 16 ++++++----
.../apache/flink/ml/util/ReadWriteUtilsTest.java | 2 +-
.../apache/flink/ml/classification/knn/Knn.java | 3 +-
.../flink/ml/classification/knn/KnnModel.java | 12 +++----
.../logisticregression/LogisticRegression.java | 3 +-
.../LogisticRegressionModel.java | 10 +++---
.../ml/classification/naivebayes/NaiveBayes.java | 3 +-
.../classification/naivebayes/NaiveBayesModel.java | 9 ++----
.../apache/flink/ml/clustering/kmeans/KMeans.java | 3 +-
.../flink/ml/clustering/kmeans/KMeansModel.java | 9 ++----
.../ml/clustering/kmeans/KMeansModelData.java | 7 ++--
.../flink/ml/clustering/kmeans/OnlineKMeans.java | 10 ++----
.../ml/clustering/kmeans/OnlineKMeansModel.java | 3 +-
.../ml/feature/minmaxscaler/MinMaxScaler.java | 4 +--
.../ml/feature/minmaxscaler/MinMaxScalerModel.java | 12 +++----
.../ml/feature/onehotencoder/OneHotEncoder.java | 4 +--
.../feature/onehotencoder/OneHotEncoderModel.java | 10 +++---
.../feature/stringindexer/IndexToStringModel.java | 12 +++----
.../ml/feature/stringindexer/StringIndexer.java | 4 +--
.../feature/stringindexer/StringIndexerModel.java | 12 +++----
.../apache/flink/ml/classification/KnnTest.java | 6 ++--
.../ml/classification/LogisticRegressionTest.java | 4 +--
.../flink/ml/classification/NaiveBayesTest.java | 4 +--
.../org/apache/flink/ml/clustering/KMeansTest.java | 5 +--
.../flink/ml/clustering/OnlineKMeansTest.java | 10 +++---
.../apache/flink/ml/feature/MinMaxScalerTest.java | 4 +--
.../apache/flink/ml/feature/OneHotEncoderTest.java | 4 +--
.../stringindexer/IndexToStringModelTest.java | 2 +-
.../feature/stringindexer/StringIndexerTest.java | 4 +--
.../org/apache/flink/ml/util/StageTestUtils.java | 10 ++++--
40 files changed, 153 insertions(+), 167 deletions(-)
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/api/Stage.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/api/Stage.java
index e144934..439b956 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/api/Stage.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/api/Stage.java
@@ -32,9 +32,9 @@ import java.io.Serializable;
* <p>Each stage is with parameters, and requires a public empty constructor
for restoration.
*
* <p>NOTE: every Stage subclass should implement a static method with
signature {@code static T
- * load(StreamExecutionEnvironment env, String path)}, where {@code T} refers
to the concrete
- * subclass. This static method should instantiate a new stage instance based
on the data read from
- * the given path.
+ * load(StreamTableEnvironment tEnv, String path)}, where {@code T} refers to
the concrete subclass.
+ * This static method should instantiate a new stage instance based on the
data read from the given
+ * path.
*
* @param <T> The class type of the Stage implementation itself.
*/
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/Graph.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/Graph.java
index 8123e04..05f0c99 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/Graph.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/Graph.java
@@ -26,8 +26,8 @@ import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.util.Preconditions;
import javax.annotation.Nullable;
@@ -147,7 +147,7 @@ public final class Graph implements Estimator<Graph,
GraphModel> {
ReadWriteUtils.saveGraph(this, graphData, path);
}
- public static Graph load(StreamExecutionEnvironment env, String path)
throws IOException {
- return (Graph) ReadWriteUtils.loadGraph(env, path,
Graph.class.getName());
+ public static Graph load(StreamTableEnvironment tEnv, String path) throws
IOException {
+ return (Graph) ReadWriteUtils.loadGraph(tEnv, path,
Graph.class.getName());
}
}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphModel.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphModel.java
index c39a774..2073cf2 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphModel.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphModel.java
@@ -26,8 +26,8 @@ import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.util.Preconditions;
import javax.annotation.Nullable;
@@ -142,7 +142,7 @@ public final class GraphModel implements Model<GraphModel> {
ReadWriteUtils.saveGraph(this, graphData, path);
}
- public static GraphModel load(StreamExecutionEnvironment env, String path)
throws IOException {
- return (GraphModel) ReadWriteUtils.loadGraph(env, path,
GraphModel.class.getName());
+ public static GraphModel load(StreamTableEnvironment tEnv, String path)
throws IOException {
+ return (GraphModel) ReadWriteUtils.loadGraph(tEnv, path,
GraphModel.class.getName());
}
}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/Pipeline.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/Pipeline.java
index 3e021ab..6498fb8 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/Pipeline.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/Pipeline.java
@@ -26,8 +26,8 @@ import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.util.Preconditions;
import java.io.IOException;
@@ -116,8 +116,8 @@ public final class Pipeline implements Estimator<Pipeline,
PipelineModel> {
ReadWriteUtils.savePipeline(this, stages, path);
}
- public static Pipeline load(StreamExecutionEnvironment env, String path)
throws IOException {
- return new Pipeline(ReadWriteUtils.loadPipeline(env, path,
Pipeline.class.getName()));
+ public static Pipeline load(StreamTableEnvironment tEnv, String path)
throws IOException {
+ return new Pipeline(ReadWriteUtils.loadPipeline(tEnv, path,
Pipeline.class.getName()));
}
/**
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/PipelineModel.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/PipelineModel.java
index ee5f099..7c6bfdb 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/PipelineModel.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/PipelineModel.java
@@ -26,8 +26,8 @@ import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.util.Preconditions;
import java.io.IOException;
@@ -77,10 +77,9 @@ public final class PipelineModel implements
Model<PipelineModel> {
ReadWriteUtils.savePipeline(this, stages, path);
}
- public static PipelineModel load(StreamExecutionEnvironment env, String
path)
- throws IOException {
+ public static PipelineModel load(StreamTableEnvironment tEnv, String path)
throws IOException {
return new PipelineModel(
- ReadWriteUtils.loadPipeline(env, path,
PipelineModel.class.getName()));
+ ReadWriteUtils.loadPipeline(tEnv, path,
PipelineModel.class.getName()));
}
/**
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
index 51aa18c..4dc175a 100644
---
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
+++
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
@@ -20,10 +20,15 @@ package org.apache.flink.ml.common.datastream;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.catalog.Column;
import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.types.Row;
-/** Utility class for table-related operations. */
+/** Utility class for operations related to Table API. */
public class TableUtils {
// Constructs a RowTypeInfo from the given schema.
public static RowTypeInfo getRowTypeInfo(ResolvedSchema schema) {
@@ -37,4 +42,10 @@ public class TableUtils {
}
return new RowTypeInfo(types, names);
}
+
+ public static StreamExecutionEnvironment
getExecutionEnvironment(StreamTableEnvironment tEnv) {
+ Table table = tEnv.fromValues();
+ DataStream<Row> dataStream = tEnv.toDataStream(table);
+ return dataStream.getExecutionEnvironment();
+ }
}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
index 674440b..54aefbe 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
@@ -31,11 +31,14 @@ import org.apache.flink.ml.builder.Graph;
import org.apache.flink.ml.builder.GraphData;
import org.apache.flink.ml.builder.GraphModel;
import org.apache.flink.ml.builder.GraphNode;
+import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.param.Param;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
import
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.util.Preconditions;
@@ -211,21 +214,20 @@ public class ReadWriteUtils {
* <p>The method throws RuntimeException if the expectedClassName is not
empty AND it does not
* match the className of the previously saved Pipeline or PipelineModel.
*
- * @param env A StreamExecutionEnvironment instance.
+ * @param tEnv A StreamTableEnvironment instance.
* @param path The parent directory to load the pipeline metadata and its
stages.
* @param expectedClassName The expected class name of the pipeline.
* @return A list of stages.
*/
public static List<Stage<?>> loadPipeline(
- StreamExecutionEnvironment env, String path, String
expectedClassName)
- throws IOException {
+ StreamTableEnvironment tEnv, String path, String
expectedClassName) throws IOException {
Map<String, ?> metadata = loadMetadata(path, expectedClassName);
int numStages = (Integer) metadata.get("numStages");
List<Stage<?>> stages = new ArrayList<>(numStages);
for (int i = 0; i < numStages; i++) {
String stagePath = getPathForPipelineStage(i, numStages, path);
- stages.add(loadStage(env, stagePath));
+ stages.add(loadStage(tEnv, stagePath));
}
return stages;
}
@@ -270,14 +272,13 @@ public class ReadWriteUtils {
* <p>The method throws RuntimeException if the expectedClassName is not
empty AND it does not
* match the className of the previously saved Pipeline or PipelineModel.
*
- * @param env A StreamExecutionEnvironment instance.
+ * @param tEnv A StreamTableEnvironment instance.
* @param path The parent directory to load the pipeline metadata and its
stages.
* @param expectedClassName The expected class name of the pipeline.
* @return A Graph or GraphModel instance.
*/
public static Stage<?> loadGraph(
- StreamExecutionEnvironment env, String path, String
expectedClassName)
- throws IOException {
+ StreamTableEnvironment tEnv, String path, String
expectedClassName) throws IOException {
Map<String, ?> metadata = loadMetadata(path, expectedClassName);
GraphData graphData = GraphData.fromMap((Map<String, Object>)
metadata.get("graphData"));
@@ -289,7 +290,7 @@ public class ReadWriteUtils {
for (GraphNode node : graphData.nodes) {
String stagePath = getPathForPipelineStage(node.nodeId, maxNodeId
+ 1, path);
- node.stage = loadStage(env, stagePath);
+ node.stage = loadStage(tEnv, stagePath);
}
if (expectedClassName.equals(GraphModel.class.getName())) {
@@ -375,20 +376,19 @@ public class ReadWriteUtils {
*
* <p>Required: the stage class must have a static load() method.
*
- * @param env A StreamExecutionEnvironment instance.
+ * @param tEnv A StreamTableEnvironment instance.
* @param path The parent directory of the stage metadata file.
* @return An instance of Stage.
*/
- public static Stage<?> loadStage(StreamExecutionEnvironment env, String
path)
- throws IOException {
+ public static Stage<?> loadStage(StreamTableEnvironment tEnv, String path)
throws IOException {
Map<String, ?> metadata = loadMetadata(path, "");
String className = (String) metadata.get("className");
try {
Class<?> clazz = Class.forName(className);
- Method method = clazz.getMethod("load",
StreamExecutionEnvironment.class, String.class);
+ Method method = clazz.getMethod("load",
StreamTableEnvironment.class, String.class);
method.setAccessible(true);
- return (Stage<?>) method.invoke(null, env, path);
+ return (Stage<?>) method.invoke(null, tEnv, path);
} catch (NoSuchMethodException e) {
String methodName = String.format("%s::load(String)", className);
throw new RuntimeException(
@@ -423,16 +423,19 @@ public class ReadWriteUtils {
/**
* Loads the model data from the given path using the model decoder.
*
- * @param env A StreamExecutionEnvironment instance.
+ * @param tEnv A StreamTableEnvironment instance.
* @param path The parent directory of the model data file.
* @param modelDecoder The decoder used to decode the model data.
* @param <T> The class type of the model data.
* @return The loaded model data.
*/
- public static <T> DataStream<T> loadModelData(
- StreamExecutionEnvironment env, String path, SimpleStreamFormat<T>
modelDecoder) {
+ public static <T> Table loadModelData(
+ StreamTableEnvironment tEnv, String path, SimpleStreamFormat<T>
modelDecoder) {
+ StreamExecutionEnvironment env =
TableUtils.getExecutionEnvironment(tEnv);
Source<T, ?, ?> source =
FileSource.forRecordStreamFormat(modelDecoder, new
Path(getDataPath(path))).build();
- return env.fromSource(source, WatermarkStrategy.noWatermarks(),
"modelData");
+ DataStream<T> modelDataStream =
+ env.fromSource(source, WatermarkStrategy.noWatermarks(),
"modelData");
+ return tEnv.fromDataStream(modelDataStream);
}
}
diff --git
a/flink-ml-core/src/test/java/org/apache/flink/ml/api/ExampleStages.java
b/flink-ml-core/src/test/java/org/apache/flink/ml/api/ExampleStages.java
index 63de3a2..7eea301 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/api/ExampleStages.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/ExampleStages.java
@@ -28,7 +28,6 @@ import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
@@ -103,14 +102,12 @@ public class ExampleStages {
ReadWriteUtils.saveMetadata(this, path);
}
- public static SumModel load(StreamExecutionEnvironment env, String
path)
- throws IOException {
- StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
- DataStream<Integer> modelData =
- ReadWriteUtils.loadModelData(env, path, new
TestUtils.IntegerStreamFormat());
+ public static SumModel load(StreamTableEnvironment tEnv, String path)
throws IOException {
+ Table modelDataTable =
+ ReadWriteUtils.loadModelData(tEnv, path, new
TestUtils.IntegerStreamFormat());
SumModel model = ReadWriteUtils.loadStageParam(path);
- return model.setModelData(tEnv.fromDataStream(modelData));
+ return model.setModelData(modelDataTable);
}
}
@@ -196,7 +193,7 @@ public class ExampleStages {
ReadWriteUtils.saveMetadata(this, path);
}
- public static SumEstimator load(StreamExecutionEnvironment env, String
path)
+ public static SumEstimator load(StreamTableEnvironment tEnv, String
path)
throws IOException {
return ReadWriteUtils.loadStageParam(path);
}
@@ -250,7 +247,7 @@ public class ExampleStages {
ReadWriteUtils.saveMetadata(this, path);
}
- public static UnionAlgoOperator load(StreamExecutionEnvironment env,
String path)
+ public static UnionAlgoOperator load(StreamTableEnvironment tEnv,
String path)
throws IOException {
return ReadWriteUtils.loadStageParam(path);
}
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/api/GraphTest.java
b/flink-ml-core/src/test/java/org/apache/flink/ml/api/GraphTest.java
index 213ed3c..6cad1b1 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/api/GraphTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/GraphTest.java
@@ -27,6 +27,7 @@ import org.apache.flink.ml.builder.Graph;
import org.apache.flink.ml.builder.GraphBuilder;
import org.apache.flink.ml.builder.GraphModel;
import org.apache.flink.ml.builder.TableId;
+import org.apache.flink.ml.common.datastream.TableUtils;
import
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
@@ -60,7 +61,7 @@ public class GraphTest extends AbstractTestBase {
// Executes the given stage using the given inputs and verifies that it
produces the expected
// output. Then repeats this procedure after saving and loading the given
stage.
private static void executeSaveLoadAndCheckOutput(
- StreamExecutionEnvironment env,
+ StreamTableEnvironment tEnv,
Stage<?> stage,
List<List<Integer>> inputs,
List<Integer> expectedOutput,
@@ -68,6 +69,8 @@ public class GraphTest extends AbstractTestBase {
List<Integer> expectedModelDataOutput,
boolean modelDataExists)
throws Exception {
+ StreamExecutionEnvironment env =
TableUtils.getExecutionEnvironment(tEnv);
+
// Executes the given stage and verifies that it produces the expected
output.
TestUtils.executeAndCheckOutput(
env, stage, inputs, expectedOutput, modelDataInputs,
expectedModelDataOutput);
@@ -81,9 +84,9 @@ public class GraphTest extends AbstractTestBase {
Stage<?> loadedStage = null;
if (stage instanceof Estimator) {
- loadedStage = Graph.load(env, path);
+ loadedStage = Graph.load(tEnv, path);
} else {
- loadedStage = GraphModel.load(env, path);
+ loadedStage = GraphModel.load(tEnv, path);
}
// Executes the loaded stage and verifies that it produces the
expected output.
TestUtils.executeAndCheckOutput(
@@ -115,7 +118,7 @@ public class GraphTest extends AbstractTestBase {
inputValues.add(Arrays.asList(10, 11, 12));
List<Integer> expectedOutputValues = Arrays.asList(3, 4, 5, 11, 12,
13);
executeSaveLoadAndCheckOutput(
- env, model, inputValues, expectedOutputValues, null, null,
true);
+ tEnv, model, inputValues, expectedOutputValues, null, null,
true);
}
@Test
@@ -143,7 +146,7 @@ public class GraphTest extends AbstractTestBase {
inputValues.add(Arrays.asList(10, 11, 12));
List<Integer> expectedOutputValues = Arrays.asList(7, 8, 9, 43, 44,
45);
executeSaveLoadAndCheckOutput(
- env, model, inputValues, expectedOutputValues, null, null,
false);
+ tEnv, model, inputValues, expectedOutputValues, null, null,
false);
}
@Test
@@ -176,7 +179,7 @@ public class GraphTest extends AbstractTestBase {
Collections.singletonList(Collections.singletonList(2));
List<Integer> expectedModelDataOutputValues =
Collections.singletonList(3);
executeSaveLoadAndCheckOutput(
- env,
+ tEnv,
model,
inputValues,
expectedOutputValues,
@@ -210,7 +213,7 @@ public class GraphTest extends AbstractTestBase {
inputValues.add(Arrays.asList(10, 11, 12));
List<Integer> expectedOutputValues = Arrays.asList(7, 8, 9, 43, 44,
45);
executeSaveLoadAndCheckOutput(
- env, estimator, inputValues, expectedOutputValues, null, null,
false);
+ tEnv, estimator, inputValues, expectedOutputValues, null,
null, false);
}
@Test
@@ -242,7 +245,7 @@ public class GraphTest extends AbstractTestBase {
List<Integer> expectedOutputValues = Arrays.asList(7, 8, 9, 16, 17,
18);
List<Integer> expectedModelDataOutputValues =
Collections.singletonList(6);
executeSaveLoadAndCheckOutput(
- env,
+ tEnv,
estimator,
inputValues,
expectedOutputValues,
diff --git
a/flink-ml-core/src/test/java/org/apache/flink/ml/api/PipelineTest.java
b/flink-ml-core/src/test/java/org/apache/flink/ml/api/PipelineTest.java
index 105fcf9..80c0024 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/api/PipelineTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/PipelineTest.java
@@ -74,7 +74,7 @@ public class PipelineTest extends AbstractTestBase {
model.save(path);
env.execute();
- Model<?> loadedModel = PipelineModel.load(env, path);
+ Model<?> loadedModel = PipelineModel.load(tEnv, path);
// Executes the loaded PipelineModel and verifies that it produces the
expected output.
TestUtils.executeAndCheckOutput(env, loadedModel, inputs, output,
null, null);
}
@@ -98,7 +98,7 @@ public class PipelineTest extends AbstractTestBase {
estimator.save(path);
env.execute();
- Estimator<?, ?> loadedEstimator = Pipeline.load(env, path);
+ Estimator<?, ?> loadedEstimator = Pipeline.load(tEnv, path);
// Executes the loaded Pipeline and verifies that it produces the
expected output.
TestUtils.executeAndCheckOutput(env, loadedEstimator, inputs, output,
null, null);
}
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
b/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
index 0d5115e..18b9be8 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
@@ -36,6 +36,7 @@ import org.apache.flink.ml.param.WithParams;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.junit.Assert;
import org.junit.Test;
@@ -121,7 +122,7 @@ public class StageTest {
ReadWriteUtils.saveMetadata(this, path);
}
- public static MyStage load(StreamExecutionEnvironment env, String
path) throws IOException {
+ public static MyStage load(StreamTableEnvironment tEnv, String path)
throws IOException {
return ReadWriteUtils.loadStageParam(path);
}
}
@@ -167,7 +168,7 @@ public class StageTest {
// Saves and loads the given stage. And verifies that the loaded stage has
same parameter values
// as the original stage.
private static Stage<?> validateStageSaveLoad(
- StreamExecutionEnvironment env, Stage<?> stage, Map<String,
Object> paramOverrides)
+ StreamTableEnvironment tEnv, Stage<?> stage, Map<String, Object>
paramOverrides)
throws IOException {
for (Map.Entry<String, Object> entry : paramOverrides.entrySet()) {
Param<?> param = stage.getParam(entry.getKey());
@@ -183,7 +184,7 @@ public class StageTest {
// This is expected.
}
- Stage<?> loadedStage = ReadWriteUtils.loadStage(env, path);
+ Stage<?> loadedStage = ReadWriteUtils.loadStage(tEnv, path);
for (Map.Entry<String, Object> entry : paramOverrides.entrySet()) {
Param<?> param = loadedStage.getParam(entry.getKey());
Assert.assertEquals(entry.getValue(), loadedStage.get(param));
@@ -308,28 +309,31 @@ public class StageTest {
@Test
public void testStageSaveLoad() throws IOException {
StreamExecutionEnvironment env =
StreamExecutionEnvironment.getExecutionEnvironment();
+ StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
MyStage stage = new MyStage();
stage.set(stage.paramWithNullDefault, 1);
- Stage<?> loadedStage = validateStageSaveLoad(env, stage,
Collections.emptyMap());
+ Stage<?> loadedStage = validateStageSaveLoad(tEnv, stage,
Collections.emptyMap());
Assert.assertEquals(1, (int) loadedStage.get(MyParams.INT_PARAM));
}
@Test
public void testStageSaveLoadWithParamOverrides() throws IOException {
StreamExecutionEnvironment env =
StreamExecutionEnvironment.getExecutionEnvironment();
+ StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
MyStage stage = new MyStage();
stage.set(stage.paramWithNullDefault, 1);
Stage<?> loadedStage =
- validateStageSaveLoad(env, stage,
Collections.singletonMap("intParam", 10));
+ validateStageSaveLoad(tEnv, stage,
Collections.singletonMap("intParam", 10));
Assert.assertEquals(10, (int) loadedStage.get(MyParams.INT_PARAM));
}
@Test
public void testStageLoadWithoutLoadMethod() throws IOException {
StreamExecutionEnvironment env =
StreamExecutionEnvironment.getExecutionEnvironment();
+ StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
MyStageWithoutLoad stage = new MyStageWithoutLoad();
try {
- validateStageSaveLoad(env, stage, Collections.emptyMap());
+ validateStageSaveLoad(tEnv, stage, Collections.emptyMap());
Assert.fail("Expected RuntimeException");
} catch (RuntimeException e) {
Assert.assertTrue(e.getMessage().contains("not implemented"));
diff --git
a/flink-ml-core/src/test/java/org/apache/flink/ml/util/ReadWriteUtilsTest.java
b/flink-ml-core/src/test/java/org/apache/flink/ml/util/ReadWriteUtilsTest.java
index e050647..10596b3 100644
---
a/flink-ml-core/src/test/java/org/apache/flink/ml/util/ReadWriteUtilsTest.java
+++
b/flink-ml-core/src/test/java/org/apache/flink/ml/util/ReadWriteUtilsTest.java
@@ -83,7 +83,7 @@ public class ReadWriteUtilsTest extends AbstractTestBase {
model.save(path);
env.execute();
- ExampleStages.SumModel loadedModel = ExampleStages.SumModel.load(env,
path);
+ ExampleStages.SumModel loadedModel = ExampleStages.SumModel.load(tEnv,
path);
// Executes the loaded SumModel and verifies that it produces the
expected output.
TestUtils.executeAndCheckOutput(env, loadedModel, inputs, output,
null, null);
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
index 54fcf80..93c0e5c 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
@@ -30,7 +30,6 @@ import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
@@ -81,7 +80,7 @@ public class Knn implements Estimator<Knn, KnnModel>,
KnnParams<Knn> {
ReadWriteUtils.saveMetadata(this, path);
}
- public static Knn load(StreamExecutionEnvironment env, String path) throws
IOException {
+ public static Knn load(StreamTableEnvironment tEnv, String path) throws
IOException {
return ReadWriteUtils.loadStageParam(path);
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
index 97aa965..28dcaea 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
@@ -31,7 +31,6 @@ import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
@@ -113,16 +112,15 @@ public class KnnModel implements Model<KnnModel>,
KnnModelParams<KnnModel> {
/**
* Loads model data from path.
*
- * @param env Stream execution environment.
+ * @param tEnv A StreamTableEnvironment instance.
* @param path Model path.
* @return Knn model.
*/
- public static KnnModel load(StreamExecutionEnvironment env, String path)
throws IOException {
- StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+ public static KnnModel load(StreamTableEnvironment tEnv, String path)
throws IOException {
KnnModel model = ReadWriteUtils.loadStageParam(path);
- DataStream<KnnModelData> modelData =
- ReadWriteUtils.loadModelData(env, path, new
KnnModelData.ModelDataDecoder());
- return model.setModelData(tEnv.fromDataStream(modelData));
+ Table modelDataTable =
+ ReadWriteUtils.loadModelData(tEnv, path, new
KnnModelData.ModelDataDecoder());
+ return model.setModelData(modelDataTable);
}
/** This operator loads model data and predicts result. */
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
index 08d9c78..58cf0ce 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
@@ -45,7 +45,6 @@ import
org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
@@ -93,7 +92,7 @@ public class LogisticRegression
ReadWriteUtils.saveMetadata(this, path);
}
- public static LogisticRegression load(StreamExecutionEnvironment env, String path)
+ public static LogisticRegression load(StreamTableEnvironment tEnv, String path)
throws IOException {
return ReadWriteUtils.loadStageParam(path);
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java
index ab1fb96..3c30a8d 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java
@@ -33,7 +33,6 @@ import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
@@ -74,14 +73,13 @@ public class LogisticRegressionModel
new LogisticRegressionModelData.ModelDataEncoder());
}
- public static LogisticRegressionModel load(StreamExecutionEnvironment env, String path)
+ public static LogisticRegressionModel load(StreamTableEnvironment tEnv, String path)
throws IOException {
- StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
- DataStream<LogisticRegressionModelData> modelData =
+ Table modelDataTable =
ReadWriteUtils.loadModelData(
- env, path, new LogisticRegressionModelData.ModelDataDecoder());
- return model.setModelData(tEnv.fromDataStream(modelData));
+ tEnv, path, new LogisticRegressionModelData.ModelDataDecoder());
+ return model.setModelData(modelDataTable);
}
@Override
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
index 7a3cc3d..a4803c7 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
@@ -32,7 +32,6 @@ import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
@@ -115,7 +114,7 @@ public class NaiveBayes
ReadWriteUtils.saveMetadata(this, path);
}
- public static NaiveBayes load(StreamExecutionEnvironment env, String path) throws IOException {
+ public static NaiveBayes load(StreamTableEnvironment tEnv, String path) throws IOException {
return ReadWriteUtils.loadStageParam(path);
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java
index 1307a9f..1ac9c00 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java
@@ -32,7 +32,6 @@ import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
@@ -106,13 +105,11 @@ public class NaiveBayesModel
new NaiveBayesModelData.ModelDataEncoder());
}
- public static NaiveBayesModel load(StreamExecutionEnvironment env, String path)
+ public static NaiveBayesModel load(StreamTableEnvironment tEnv, String path)
throws IOException {
- StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
NaiveBayesModel model = ReadWriteUtils.loadStageParam(path);
- DataStream<NaiveBayesModelData> modelData =
- ReadWriteUtils.loadModelData(env, path, new ModelDataDecoder());
- return model.setModelData(tEnv.fromDataStream(modelData));
+ Table modelDataTable = ReadWriteUtils.loadModelData(tEnv, path, new ModelDataDecoder());
+ return model.setModelData(modelDataTable);
}
@Override
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
index ff54c6a..31e6d6c 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
@@ -50,7 +50,6 @@ import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
@@ -125,7 +124,7 @@ public class KMeans implements Estimator<KMeans,
KMeansModel>, KMeansParams<KMea
ReadWriteUtils.saveMetadata(this, path);
}
- public static KMeans load(StreamExecutionEnvironment env, String path) throws IOException {
+ public static KMeans load(StreamTableEnvironment tEnv, String path) throws IOException {
return ReadWriteUtils.loadStageParam(path);
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java
index 1e0a1c8..d569d97 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java
@@ -31,7 +31,6 @@ import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
@@ -154,11 +153,9 @@ public class KMeansModel implements Model<KMeansModel>,
KMeansModelParams<KMeans
}
// TODO: Add INFO level logging.
- public static KMeansModel load(StreamExecutionEnvironment env, String path) throws IOException {
- StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
- DataStream<KMeansModelData> modelData =
- ReadWriteUtils.loadModelData(env, path, new ModelDataDecoder());
+ public static KMeansModel load(StreamTableEnvironment tEnv, String path) throws IOException {
+ Table modelDataTable = ReadWriteUtils.loadModelData(tEnv, path, new ModelDataDecoder());
KMeansModel model = ReadWriteUtils.loadStageParam(path);
- return model.setModelData(tEnv.fromDataStream(modelData));
+ return model.setModelData(modelDataTable);
}
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
index 287854b..a7a62c7 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
@@ -27,6 +27,7 @@ import
org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
import org.apache.flink.core.fs.FSDataInputStream;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
import org.apache.flink.streaming.api.datastream.DataStream;
@@ -73,15 +74,15 @@ public class KMeansModelData {
* Generates a Table containing a {@link KMeansModelData} instance with randomly generated
* centroids.
*
- * @param env The environment where to create the table.
+ * @param tEnv The environment where to create the table.
* @param k The number of generated centroids.
* @param dim The size of generated centroids.
* @param weight The weight of the centroids.
* @param seed Random seed.
*/
public static Table generateRandomModelData(
- StreamExecutionEnvironment env, int k, int dim, double weight, long seed) {
- StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+ StreamTableEnvironment tEnv, int k, int dim, double weight, long seed) {
+ StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv);
return tEnv.fromDataStream(
env.fromElements(1).map(new RandomCentroidsCreator(k, dim, weight, seed)));
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
index 51b48d5..7ede3e4 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
@@ -40,7 +40,6 @@ import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
@@ -129,13 +128,10 @@ public class OnlineKMeans
new KMeansModelData.ModelDataEncoder());
}
- public static OnlineKMeans load(StreamExecutionEnvironment env, String
path)
- throws IOException {
+ public static OnlineKMeans load(StreamTableEnvironment tEnv, String path)
throws IOException {
OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path);
- DataStream<KMeansModelData> initModelDataStream =
- ReadWriteUtils.loadModelData(env, path, new
KMeansModelData.ModelDataDecoder());
- StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
- onlineKMeans.initModelDataTable =
tEnv.fromDataStream(initModelDataStream);
+ onlineKMeans.initModelDataTable =
+ ReadWriteUtils.loadModelData(tEnv, path, new
KMeansModelData.ModelDataDecoder());
return onlineKMeans;
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
index 7b262ba..3643d19 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
@@ -32,7 +32,6 @@ import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
@@ -207,7 +206,7 @@ public class OnlineKMeansModel
}
// TODO: Add INFO level logging.
- public static OnlineKMeansModel load(StreamExecutionEnvironment env,
String path)
+ public static OnlineKMeansModel load(StreamTableEnvironment tEnv, String
path)
throws IOException {
return ReadWriteUtils.loadStageParam(path);
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java
index 19a9f6f..40e13d5 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java
@@ -33,7 +33,6 @@ import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
@@ -196,8 +195,7 @@ public class MinMaxScaler
ReadWriteUtils.saveMetadata(this, path);
}
- public static MinMaxScaler load(StreamExecutionEnvironment env, String
path)
- throws IOException {
+ public static MinMaxScaler load(StreamTableEnvironment tEnv, String path)
throws IOException {
return ReadWriteUtils.loadStageParam(path);
}
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java
index 762d74a..379ae8f 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java
@@ -29,7 +29,6 @@ import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
@@ -117,18 +116,17 @@ public class MinMaxScalerModel
/**
* Loads model data from path.
*
- * @param env Stream execution environment.
+ * @param tEnv Stream table environment.
* @param path Model path.
* @return MinMaxScalerModel model.
*/
- public static MinMaxScalerModel load(StreamExecutionEnvironment env,
String path)
+ public static MinMaxScalerModel load(StreamTableEnvironment tEnv, String
path)
throws IOException {
- StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
MinMaxScalerModel model = ReadWriteUtils.loadStageParam(path);
- DataStream<MinMaxScalerModelData> modelData =
+ Table modelDataTable =
ReadWriteUtils.loadModelData(
- env, path, new
MinMaxScalerModelData.ModelDataDecoder());
- return model.setModelData(tEnv.fromDataStream(modelData));
+ tEnv, path, new
MinMaxScalerModelData.ModelDataDecoder());
+ return model.setModelData(modelDataTable);
}
/** This operator loads model data and predicts result. */
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java
index 984f863..9cb23a1 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java
@@ -28,7 +28,6 @@ import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
@@ -85,8 +84,7 @@ public class OneHotEncoder
ReadWriteUtils.saveMetadata(this, path);
}
- public static OneHotEncoder load(StreamExecutionEnvironment env, String
path)
- throws IOException {
+ public static OneHotEncoder load(StreamTableEnvironment tEnv, String path)
throws IOException {
return ReadWriteUtils.loadStageParam(path);
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java
index 447fe77..482cf57 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java
@@ -31,7 +31,6 @@ import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
@@ -118,14 +117,13 @@ public class OneHotEncoderModel
new OneHotEncoderModelData.ModelDataEncoder());
}
- public static OneHotEncoderModel load(StreamExecutionEnvironment env,
String path)
+ public static OneHotEncoderModel load(StreamTableEnvironment tEnv, String
path)
throws IOException {
- StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
OneHotEncoderModel model = ReadWriteUtils.loadStageParam(path);
- DataStream<Tuple2<Integer, Integer>> modelData =
+ Table modelDataTable =
ReadWriteUtils.loadModelData(
- env, path, new
OneHotEncoderModelData.ModelDataStreamFormat());
- return model.setModelData(tEnv.fromDataStream(modelData));
+ tEnv, path, new
OneHotEncoderModelData.ModelDataStreamFormat());
+ return model.setModelData(modelDataTable);
}
@Override
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModel.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModel.java
index b66a337..4090193 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModel.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModel.java
@@ -29,7 +29,6 @@ import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
@@ -67,16 +66,13 @@ public class IndexToStringModel
new StringIndexerModelData.ModelDataEncoder());
}
- public static IndexToStringModel load(StreamExecutionEnvironment env,
String path)
+ public static IndexToStringModel load(StreamTableEnvironment tEnv, String
path)
throws IOException {
- StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
-
IndexToStringModel model = ReadWriteUtils.loadStageParam(path);
- DataStream<StringIndexerModelData> modelData =
+ Table modelDataTable =
ReadWriteUtils.loadModelData(
- env, path, new
StringIndexerModelData.ModelDataDecoder());
-
- return model.setModelData(tEnv.fromDataStream(modelData));
+ tEnv, path, new
StringIndexerModelData.ModelDataDecoder());
+ return model.setModelData(modelDataTable);
}
@Override
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
index 0a8f7e9..b920be2 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
@@ -29,7 +29,6 @@ import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
@@ -70,8 +69,7 @@ public class StringIndexer
ReadWriteUtils.saveMetadata(this, path);
}
- public static StringIndexer load(StreamExecutionEnvironment env, String
path)
- throws IOException {
+ public static StringIndexer load(StreamTableEnvironment tEnv, String path)
throws IOException {
return ReadWriteUtils.loadStageParam(path);
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModel.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModel.java
index 37307ba..066086c 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModel.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModel.java
@@ -29,7 +29,6 @@ import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
@@ -67,16 +66,13 @@ public class StringIndexerModel
new StringIndexerModelData.ModelDataEncoder());
}
- public static StringIndexerModel load(StreamExecutionEnvironment env,
String path)
+ public static StringIndexerModel load(StreamTableEnvironment tEnv, String
path)
throws IOException {
- StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
-
StringIndexerModel model = ReadWriteUtils.loadStageParam(path);
- DataStream<StringIndexerModelData> modelData =
+ Table modelDataTable =
ReadWriteUtils.loadModelData(
- env, path, new
StringIndexerModelData.ModelDataDecoder());
-
- return model.setModelData(tEnv.fromDataStream(modelData));
+ tEnv, path, new
StringIndexerModelData.ModelDataDecoder());
+ return model.setModelData(modelDataTable);
}
@Override
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/KnnTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/KnnTest.java
index 133cd74..d134e37 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/KnnTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/KnnTest.java
@@ -183,11 +183,11 @@ public class KnnTest {
public void testSaveLoadAndPredict() throws Exception {
Knn knn = new Knn();
Knn loadedKnn =
- StageTestUtils.saveAndReload(env, knn,
tempFolder.newFolder().getAbsolutePath());
+ StageTestUtils.saveAndReload(tEnv, knn,
tempFolder.newFolder().getAbsolutePath());
KnnModel knnModel = loadedKnn.fit(trainData);
knnModel =
StageTestUtils.saveAndReload(
- env, knnModel,
tempFolder.newFolder().getAbsolutePath());
+ tEnv, knnModel,
tempFolder.newFolder().getAbsolutePath());
assertEquals(
Arrays.asList("packedFeatures", "featureNormSquares",
"labels"),
knnModel.getModelData()[0].getResolvedSchema().getColumnNames());
@@ -201,7 +201,7 @@ public class KnnTest {
KnnModel knnModel = knn.fit(trainData);
KnnModel newModel =
StageTestUtils.saveAndReload(
- env, knnModel,
tempFolder.newFolder().getAbsolutePath());
+ tEnv, knnModel,
tempFolder.newFolder().getAbsolutePath());
Table output = newModel.transform(predictData)[0];
verifyPredictionResult(output, knn.getLabelCol(),
knn.getPredictionCol());
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
index ea6d80a..e3c1ce0 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
@@ -227,9 +227,9 @@ public class LogisticRegressionTest {
LogisticRegression logisticRegression = new
LogisticRegression().setWeightCol("weight");
logisticRegression =
StageTestUtils.saveAndReload(
- env, logisticRegression,
tempFolder.newFolder().getAbsolutePath());
+ tEnv, logisticRegression,
tempFolder.newFolder().getAbsolutePath());
LogisticRegressionModel model =
logisticRegression.fit(binomialDataTable);
- model = StageTestUtils.saveAndReload(env, model,
tempFolder.newFolder().getAbsolutePath());
+ model = StageTestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
assertEquals(
Collections.singletonList("coefficient"),
model.getModelData()[0].getResolvedSchema().getColumnNames());
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java
index 581242f..801b443 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java
@@ -262,11 +262,11 @@ public class NaiveBayesTest {
public void testSaveLoad() throws Exception {
estimator =
StageTestUtils.saveAndReload(
- env, estimator,
tempFolder.newFolder().getAbsolutePath());
+ tEnv, estimator,
tempFolder.newFolder().getAbsolutePath());
NaiveBayesModel model = estimator.fit(trainTable);
- model = StageTestUtils.saveAndReload(env, model,
tempFolder.newFolder().getAbsolutePath());
+ model = StageTestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
Table outputTable = model.transform(predictTable)[0];
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
index de0c2c2..5d41bf2 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
@@ -207,10 +207,11 @@ public class KMeansTest extends AbstractTestBase {
public void testSaveLoadAndPredict() throws Exception {
KMeans kmeans = new KMeans().setMaxIter(2).setK(2);
KMeans loadedKmeans =
- StageTestUtils.saveAndReload(env, kmeans,
tempFolder.newFolder().getAbsolutePath());
+ StageTestUtils.saveAndReload(
+ tEnv, kmeans,
tempFolder.newFolder().getAbsolutePath());
KMeansModel model = loadedKmeans.fit(dataTable);
KMeansModel loadedModel =
- StageTestUtils.saveAndReload(env, model,
tempFolder.newFolder().getAbsolutePath());
+ StageTestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
Table output = loadedModel.transform(dataTable)[0];
assertEquals(
Arrays.asList("centroids", "weights"),
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
index c212f1d..d13beb0 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
@@ -275,7 +275,7 @@ public class OnlineKMeansTest extends TestLogger {
.setPredictionCol("prediction")
.setGlobalBatchSize(6)
.setInitialModelData(
- KMeansModelData.generateRandomModelData(env,
2, 2, 0.0, 0));
+ KMeansModelData.generateRandomModelData(tEnv,
2, 2, 0.0, 0));
OnlineKMeansModel onlineModel = onlineKMeans.fit(onlineTrainTable);
transformAndOutputData(onlineModel);
@@ -367,7 +367,7 @@ public class OnlineKMeansTest extends TestLogger {
.setPredictionCol("prediction")
.setGlobalBatchSize(2)
.setInitialModelData(
- KMeansModelData.generateRandomModelData(env,
2, 2, 0.0, 0));
+ KMeansModelData.generateRandomModelData(tEnv,
2, 2, 0.0, 0));
try {
onlineKMeans.fit(onlineTrainTable);
@@ -400,13 +400,13 @@ public class OnlineKMeansTest extends TestLogger {
String savePath = tempFolder.newFolder().getAbsolutePath();
onlineKMeans.save(savePath);
miniCluster.executeJobBlocking(env.getStreamGraph().getJobGraph());
- OnlineKMeans loadedOnlineKMeans = OnlineKMeans.load(env, savePath);
+ OnlineKMeans loadedOnlineKMeans = OnlineKMeans.load(tEnv, savePath);
OnlineKMeansModel onlineModel =
loadedOnlineKMeans.fit(onlineTrainTable);
String modelSavePath = tempFolder.newFolder().getAbsolutePath();
onlineModel.save(modelSavePath);
- OnlineKMeansModel loadedOnlineModel = OnlineKMeansModel.load(env,
modelSavePath);
+ OnlineKMeansModel loadedOnlineModel = OnlineKMeansModel.load(tEnv,
modelSavePath);
loadedOnlineModel.setModelData(onlineModel.getModelData());
transformAndOutputData(loadedOnlineModel);
@@ -430,7 +430,7 @@ public class OnlineKMeansTest extends TestLogger {
.setPredictionCol("prediction")
.setGlobalBatchSize(6)
.setInitialModelData(
- KMeansModelData.generateRandomModelData(env,
2, 2, 0.0, 0));
+ KMeansModelData.generateRandomModelData(tEnv,
2, 2, 0.0, 0));
OnlineKMeansModel onlineModel = onlineKMeans.fit(onlineTrainTable);
transformAndOutputData(onlineModel);
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
index 24ec885..a5f930b 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
@@ -180,10 +180,10 @@ public class MinMaxScalerTest {
MinMaxScaler minMaxScaler = new MinMaxScaler();
MinMaxScaler loadedMinMaxScaler =
StageTestUtils.saveAndReload(
- env, minMaxScaler,
tempFolder.newFolder().getAbsolutePath());
+ tEnv, minMaxScaler,
tempFolder.newFolder().getAbsolutePath());
MinMaxScalerModel model = loadedMinMaxScaler.fit(trainDataTable);
MinMaxScalerModel loadedModel =
- StageTestUtils.saveAndReload(env, model,
tempFolder.newFolder().getAbsolutePath());
+ StageTestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
assertEquals(
Arrays.asList("minVector", "maxVector"),
model.getModelData()[0].getResolvedSchema().getColumnNames());
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
index 51f9735..85034e0 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
@@ -258,9 +258,9 @@ public class OneHotEncoderTest {
public void testSaveLoad() throws Exception {
estimator =
StageTestUtils.saveAndReload(
- env, estimator,
tempFolder.newFolder().getAbsolutePath());
+ tEnv, estimator,
tempFolder.newFolder().getAbsolutePath());
OneHotEncoderModel model = estimator.fit(trainTable);
- model = StageTestUtils.saveAndReload(env, model,
tempFolder.newFolder().getAbsolutePath());
+ model = StageTestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
Table outputTable = model.transform(predictTable)[0];
Map<Double, Vector>[] actualOutput =
executeAndCollect(outputTable, model.getInputCols(),
model.getOutputCols());
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModelTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModelTest.java
index e02276a..63a3216 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModelTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModelTest.java
@@ -136,7 +136,7 @@ public class IndexToStringModelTest extends
AbstractTestBase {
.setInputCols("inputCol1", "inputCol2")
.setOutputCols("outputCol1", "outputCol2")
.setModelData(modelTable);
- model = StageTestUtils.saveAndReload(env, model,
tempFolder.newFolder().getAbsolutePath());
+ model = StageTestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
assertEquals(
Collections.singletonList("stringArrays"),
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
index 805da5b..a325113 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
@@ -251,10 +251,10 @@ public class StringIndexerTest extends AbstractTestBase {
.setHandleInvalid(StringIndexerParams.KEEP_INVALID);
stringIndexer =
StageTestUtils.saveAndReload(
- env, stringIndexer,
tempFolder.newFolder().getAbsolutePath());
+ tEnv, stringIndexer,
tempFolder.newFolder().getAbsolutePath());
StringIndexerModel model = stringIndexer.fit(trainTable);
- model = StageTestUtils.saveAndReload(env, model,
tempFolder.newFolder().getAbsolutePath());
+ model = StageTestUtils.saveAndReload(tEnv, model,
tempFolder.newFolder().getAbsolutePath());
assertEquals(
Collections.singletonList("stringArrays"),
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/util/StageTestUtils.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/util/StageTestUtils.java
index 27283e8..138a6fc 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/util/StageTestUtils.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/util/StageTestUtils.java
@@ -19,7 +19,9 @@
package org.apache.flink.ml.util;
import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import java.lang.reflect.Method;
@@ -30,7 +32,9 @@ public class StageTestUtils {
* stage.
*/
public static <T extends Stage<T>> T saveAndReload(
- StreamExecutionEnvironment env, T stage, String path) throws
Exception {
+ StreamTableEnvironment tEnv, T stage, String path) throws
Exception {
+ StreamExecutionEnvironment env =
TableUtils.getExecutionEnvironment(tEnv);
+
stage.save(path);
try {
env.execute();
@@ -42,7 +46,7 @@ public class StageTestUtils {
}
Method method =
- stage.getClass().getMethod("load",
StreamExecutionEnvironment.class, String.class);
- return (T) method.invoke(null, env, path);
+ stage.getClass().getMethod("load",
StreamTableEnvironment.class, String.class);
+ return (T) method.invoke(null, tEnv, path);
}
}