This is an automated email from the ASF dual-hosted git repository.
gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new e9c10b8 [FLINK-23959][FLIP-175] Compose Estimator/Model/AlgoOperator
from DAG of Estimator/Model/AlgoOperator
e9c10b8 is described below
commit e9c10b8acd4300a62c2193aef2948f933bc4c67d
Author: Dong Lin <[email protected]>
AuthorDate: Mon Oct 25 17:08:56 2021 +0800
[FLINK-23959][FLIP-175] Compose Estimator/Model/AlgoOperator from DAG of
Estimator/Model/AlgoOperator
This closes #20.
---
.../java/org/apache/flink/ml/builder/Graph.java | 153 ++++++++
.../org/apache/flink/ml/builder/GraphBuilder.java | 434 +++++++++++++++++++++
.../org/apache/flink/ml/builder/GraphData.java | 105 +++++
.../flink/ml/builder/GraphExecutionHelper.java | 128 ++++++
.../org/apache/flink/ml/builder/GraphModel.java | 148 +++++++
.../org/apache/flink/ml/builder/GraphNode.java | 143 +++++++
.../java/org/apache/flink/ml/builder/Pipeline.java | 3 +-
.../org/apache/flink/ml/builder/PipelineModel.java | 3 +-
.../java/org/apache/flink/ml/builder/TableId.java | 73 ++++
.../org/apache/flink/ml/util/ReadWriteUtils.java | 79 ++++
.../org/apache/flink/ml/api/ExampleStages.java | 88 +++--
.../java/org/apache/flink/ml/api/GraphTest.java | 253 ++++++++++++
.../java/org/apache/flink/ml/api/PipelineTest.java | 72 ++--
.../java/org/apache/flink/ml/api/StageTest.java | 4 +-
.../java/org/apache/flink/ml/api/TestUtils.java | 147 +++++++
15 files changed, 1750 insertions(+), 83 deletions(-)
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/Graph.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/Graph.java
new file mode 100644
index 0000000..8123e04
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/Graph.java
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.builder;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.ml.builder.GraphNode.StageType;
+
+/**
+ * A Graph acts as an Estimator. It holds a DAG of stages, each of which could be an Estimator,
+ * Model, Transformer or AlgoOperator.
+ *
+ * <p>When {@code Graph::fit} is called, the stages are visited in a topologically-sorted order. For
+ * an Estimator stage, {@code Estimator::fit} is first invoked on its estimator input tables to
+ * obtain a Model, and that Model then transforms the stage's input tables into the tables of its
+ * output edges. For an AlgoOperator stage, {@code AlgoOperator::transform} is invoked directly on
+ * the input tables. The GraphModel returned by {@code fit} consists of the fitted Models and the
+ * AlgoOperators corresponding to this Graph's stages.
+ */
+@PublicEvolving
+public final class Graph implements Estimator<Graph, GraphModel> {
+    private static final long serialVersionUID = 6354253958813529308L;
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    /** The stages of this graph, wrapped as nodes. */
+    private final List<GraphNode> nodes;
+
+    /** Ids bound to the tables passed to {@link #fit(Table...)}. */
+    private final TableId[] estimatorInputIds;
+
+    /** Ids bound to the inputs of the fitted GraphModel's transform(). */
+    private final TableId[] modelInputIds;
+
+    /** Ids of the tables returned by the fitted GraphModel's transform(). */
+    private final TableId[] outputIds;
+
+    private final @Nullable TableId[] inputModelDataIds;
+    private final @Nullable TableId[] outputModelDataIds;
+
+    public Graph(
+            List<GraphNode> nodes,
+            TableId[] estimatorInputIds,
+            TableId[] modelInputs,
+            TableId[] outputs,
+            @Nullable TableId[] inputModelDataIds,
+            @Nullable TableId[] outputModelDataIds) {
+        this.nodes = Preconditions.checkNotNull(nodes);
+        this.estimatorInputIds = Preconditions.checkNotNull(estimatorInputIds);
+        this.modelInputIds = Preconditions.checkNotNull(modelInputs);
+        this.outputIds = Preconditions.checkNotNull(outputs);
+        this.inputModelDataIds = inputModelDataIds;
+        this.outputModelDataIds = outputModelDataIds;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    /**
+     * Runs every ready node of the graph on the given tables and collects the resulting stages
+     * into a GraphModel.
+     *
+     * @param inputTables tables bound, in order, to this Graph's estimator input ids.
+     * @return a GraphModel made of the fitted Models and AlgoOperators.
+     * @throws IllegalArgumentException if the number of tables differs from the number of
+     *     estimator input ids.
+     */
+    @Override
+    public GraphModel fit(Table... inputTables) {
+        Preconditions.checkArgument(
+                inputTables.length == estimatorInputIds.length,
+                "number of provided tables %s does not match the expected number of tables %s",
+                inputTables.length,
+                estimatorInputIds.length);
+
+        GraphExecutionHelper helper = new GraphExecutionHelper(nodes);
+        // Binds the tables given to fit() to this Graph's estimator input ids.
+        helper.setTables(estimatorInputIds, inputTables);
+
+        List<GraphNode> fittedNodes = new ArrayList<>();
+        for (GraphNode node = helper.pollNextReadyNode();
+                node != null;
+                node = helper.pollNextReadyNode()) {
+            fittedNodes.add(fitNode(node, helper));
+        }
+        return new GraphModel(
+                fittedNodes, modelInputIds, outputIds, inputModelDataIds, outputModelDataIds);
+    }
+
+    /**
+     * Executes one ready node during fit(), publishes its output tables (and model data, if
+     * requested) to the helper, and returns the node that represents it in the fitted GraphModel.
+     */
+    private static GraphNode fitNode(GraphNode node, GraphExecutionHelper helper) {
+        // An Estimator is replaced by the Model it fits; any other stage is used unchanged.
+        Stage<?> stage =
+                node.stageType == StageType.ESTIMATOR
+                        ? ((Estimator<?, ?>) node.stage)
+                                .fit(helper.getTables(node.estimatorInputIds))
+                        : node.stage;
+
+        if (node.inputModelDataIds != null) {
+            ((Model<?>) stage).setModelData(helper.getTables(node.inputModelDataIds));
+        }
+
+        Table[] transformed =
+                ((AlgoOperator<?>) stage).transform(helper.getTables(node.algoOpInputIds));
+        helper.setTables(node.outputIds, transformed);
+
+        if (node.outputModelDataIds != null) {
+            helper.setTables(node.outputModelDataIds, ((Model<?>) stage).getModelData());
+        }
+
+        // After fitting, every stage acts as an AlgoOperator, so no estimator inputs are kept.
+        return new GraphNode(
+                node.nodeId,
+                stage,
+                StageType.ALGO_OPERATOR,
+                null,
+                node.algoOpInputIds,
+                node.outputIds,
+                node.inputModelDataIds,
+                node.outputModelDataIds);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    /** Saves this Graph, together with its node layout and table ids, to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveGraph(
+                this,
+                new GraphData(
+                        nodes,
+                        estimatorInputIds,
+                        modelInputIds,
+                        outputIds,
+                        inputModelDataIds,
+                        outputModelDataIds),
+                path);
+    }
+
+    /** Loads a Graph previously written by {@link #save(String)}. */
+    public static Graph load(StreamExecutionEnvironment env, String path) throws IOException {
+        return (Graph) ReadWriteUtils.loadGraph(env, path, Graph.class.getName());
+    }
+}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphBuilder.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphBuilder.java
new file mode 100644
index 0000000..c3dd657
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphBuilder.java
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.builder;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.api.Stage;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.ml.builder.GraphNode.StageType;
+
+/**
+ * A GraphBuilder provides APIs to build Estimator/Model/AlgoOperator from a
DAG of stages, each of
+ * which could be an Estimator, Model, Transformer or AlgoOperator.
+ */
+@PublicEvolving
+public final class GraphBuilder {
+
+ private int maxOutputLength = 20;
+
+ private int nextTableId = 0;
+
+ private int nextNodeId = 0;
+
+ /** An ordered list of nodes in the graph. */
+ private final List<GraphNode> nodes = new ArrayList<>();
+ /** A map from stage instance to the corresponding node in the graph. */
+ private final Map<Stage<?>, GraphNode> existingNodes = new HashMap<>();
+
+ public GraphBuilder() {}
+
+ /**
+ * Specifies the loose upper bound of the number of output tables that can
be returned by the
+ * Model::getModelData() and AlgoOperator::transform() methods, for any
stage involved in this
+ * Graph.
+ *
+ * <p>The default upper bound is 20.
+ */
+ public GraphBuilder setMaxOutputTableNum(int maxOutputLength) {
+ this.maxOutputLength = maxOutputLength;
+ return this;
+ }
+
+ /**
+ * Creates a TableId associated with this GraphBuilder. It can be used to
specify the passing of
+ * tables between stages, as well as the input/output tables of the
Graph/GraphModel generated
+ * by this builder.
+ *
+ * @return A TableId.
+ */
+ public TableId createTableId() {
+ return new TableId(nextTableId++);
+ }
+
+ /**
+ * Adds an AlgoOperator in the graph.
+ *
+ * <p>When the graph runs as Estimator, the transform() of the given
AlgoOperator would be
+ * invoked with the given inputs. Then when the GraphModel fitted by this
graph runs, the
+ * transform() of the given AlgoOperator would be invoked with the given
inputs.
+ *
+ * <p>When the graph runs as AlgoOperator or Model, the transform() of the
given AlgoOperator
+ * would be invoked with the given inputs.
+ *
+ * <p>NOTE: the number of the returned TableIds does not represent the
actual number of Tables
+ * outputted by transform(). This number could be configured using {@link
+ * #setMaxOutputTableNum(int)}. Users should make sure that this number >=
the actual number of
+ * Tables outputted by transform().
+ *
+ * @param algoOp An AlgoOperator instance.
+ * @param inputs A list of TableIds which represents inputs to transform()
of the given
+ * AlgoOperator.
+ * @return A list of TableIds which represents the outputs of transform()
of the given
+ * AlgoOperator.
+ */
+ public TableId[] addAlgoOperator(AlgoOperator<?> algoOp, TableId...
inputs) {
+ return addStage(algoOp, StageType.ALGO_OPERATOR, null, inputs);
+ }
+
+ /**
+ * Adds an Estimator in the graph.
+ *
+ * <p>When the graph runs as Estimator, the fit() of the given Estimator
would be invoked with
+ * the given inputs. Then when the GraphModel fitted by this graph runs,
the transform() of the
+ * Model fitted by the given Estimator would be invoked with the given
inputs.
+ *
+ * <p>When the graph runs as AlgoOperator or Model, the fit() of the given
Estimator would be
+ * invoked with the given inputs, then the transform() of the Model fitted
by the given
+ * Estimator would be invoked with the given inputs.
+ *
+ * <p>NOTE: the number of the returned TableIds does not represent the
actual number of Tables
+ * outputted by transform(). This number could be configured using {@link
+ * #setMaxOutputTableNum(int)}. Users should make sure that this number >=
the actual number of
+ * Tables outputted by transform().
+ *
+ * @param estimator An Estimator instance.
+ * @param inputs A list of TableIds which represents inputs to fit() of
the given Estimator as
+ * well as inputs to transform() of the Model fitted by the given
Estimator.
+ * @return A list of TableIds which represents the outputs of transform()
of the Model fitted by
+ * the given Estimator.
+ */
+ public TableId[] addEstimator(Estimator<?, ?> estimator, TableId...
inputs) {
+ return addEstimator(estimator, inputs, inputs);
+ }
+
+ /**
+ * Adds an Estimator in the graph.
+ *
+ * <p>When the graph runs as Estimator, the fit() of the given Estimator
would be invoked with
+ * estimatorInputs. Then when the GraphModel fitted by this graph runs,
the transform() of the
+ * Model fitted by the given Estimator would be invoked with modelInputs.
+ *
+ * <p>When the graph runs as AlgoOperator or Model, the fit() of the given
Estimator would be
+ * invoked with estimatorInputs, then the transform() of the Model fitted
by the given Estimator
+ * would be invoked with modelInputs.
+ *
+ * <p>NOTE: the number of the returned TableIds does not represent the
actual number of Tables
+ * outputted by transform(). This number could be configured using {@link
+ * #setMaxOutputTableNum(int)}. Users should make sure that this number >=
the actual number of
+ * Tables outputted by transform().
+ *
+ * @param estimator An Estimator instance.
+ * @param estimatorInputs A list of TableIds which represents inputs to
fit() of the given
+ * Estimator.
+ * @param modelInputs A list of TableIds which represents inputs to
transform() of the Model
+ * fitted by the given Estimator.
+ * @return A list of TableIds which represents the outputs of transform()
of the Model fitted by
+ * the given Estimator.
+ */
+ public TableId[] addEstimator(
+ Estimator<?, ?> estimator, TableId[] estimatorInputs, TableId[]
modelInputs) {
+ return addStage(estimator, StageType.ESTIMATOR, estimatorInputs,
modelInputs);
+ }
+
+ /**
+ * When the graph runs as Estimator, it first generates a GraphModel that
contains the Model
+ * fitted by the given Estimator. Then when this GraphModel runs, the
setModelData() of the
+ * fitted Model would be invoked with the given inputs before its
transform() is invoked.
+ *
+ * <p>When the graph runs as AlgoOperator or Model, the setModelData() of
the Model fitted by
+ * the given Estimator would be invoked with the given inputs before its
transform() is invoked.
+ *
+ * @param estimator An Estimator instance.
+ * @param inputs A list of TableIds which represents inputs to
setModelData() of the Model
+ * fitted by the given Estimator.
+ */
+ public void setModelDataOnEstimator(Estimator<?, ?> estimator, TableId...
inputs) {
+ GraphNode node = existingNodes.get(estimator);
+ if (node == null) {
+ throw new RuntimeException("the Estimator has not been added to
the graph");
+ }
+ if (node.stageType != StageType.ESTIMATOR) {
+ throw new RuntimeException("the Estimator was previously added as
an AlgoOperator");
+ }
+ if (node.inputModelDataIds != null) {
+ throw new RuntimeException("the model data of this Estimator has
already been set");
+ }
+ node.inputModelDataIds = inputs;
+ }
+
+ /**
+ * When the graph runs as Estimator, the setModelData() of the given Model
would be invoked with
+ * the given inputs before its transform() is invoked. Then when the
GraphModel fitted by this
+ * graph runs, the setModelData() of the given Model would be invoked with
the given inputs.
+ *
+ * <p>When the graph runs as AlgoOperator or Model, the setModelData() of
the given Model would
+ * be invoked with the given inputs before its transform() is invoked.
+ *
+ * @param model A Model instance.
+ * @param inputs A list of TableIds which represents inputs to
setModelData() of the given
+ * Model.
+ */
+ public void setModelDataOnModel(Model<?> model, TableId... inputs) {
+ GraphNode node = existingNodes.get(model);
+ if (node == null) {
+ throw new RuntimeException("the Model has not been added to the
graph");
+ }
+ if (node.stageType != StageType.ALGO_OPERATOR) {
+ throw new RuntimeException("the Model was previously added as an
Estimator");
+ }
+ if (node.inputModelDataIds != null) {
+ throw new RuntimeException("the model data of this Model has
already been set");
+ }
+ node.inputModelDataIds = inputs;
+ }
+
+ /**
+ * When the graph runs as Estimator, it first generates a GraphModel that
contains the Model
+ * fitted by the given Estimator. Then when this GraphModel runs, the
getModelData() of the
+ * fitted Model would be invoked.
+ *
+ * <p>When the graph runs as AlgoOperator or Model, the getModelData() of
the Model fitted by
+ * the given Estimator would be invoked.
+ *
+ * <p>NOTE: the number of the returned TableIds does not represent the
actual number of Tables
+ * outputted by getModelData(). This number could be configured using
{@link
+ * #setMaxOutputTableNum(int)}. Users should make sure that this number >=
the actual number of
+ * Tables outputted by getModelData().
+ *
+ * @param estimator An Estimator instance.
+ * @return A list of TableIds which represents the outputs of
getModelData() of the Model fitted
+ * by the given Estimator.
+ */
+ public TableId[] getModelDataFromEstimator(Estimator<?, ?> estimator) {
+ GraphNode node = existingNodes.get(estimator);
+ if (node == null) {
+ throw new RuntimeException("the Estimator has not been added to
the graph");
+ }
+ if (node.stageType != StageType.ESTIMATOR) {
+ throw new RuntimeException("the Estimator was previously added as
an AlgoOperator");
+ }
+ if (node.outputModelDataIds != null) {
+ throw new RuntimeException("the model data of this Estimator has
already been fetched");
+ }
+ node.outputModelDataIds = createTableIds(maxOutputLength);
+ return node.outputModelDataIds;
+ }
+
+ /**
+ * When the graph runs as Estimator, the getModelData() of the given Model
would be invoked.
+ * Then when the GraphModel fitted by this graph runs, the getModelData()
of the given Model
+ * would be invoked.
+ *
+ * <p>When the graph runs as AlgoOperator or Model, the getModelData() of
the given Model would
+ * be invoked.
+ *
+ * <p>NOTE: the number of the returned TableIds does not represent the
actual number of Tables
+ * outputted by getModelData(). This number could be configured using
{@link
+ * #setMaxOutputTableNum(int)}. Users should make sure that this number >=
the actual number of
+ * Tables outputted by getModelData().
+ *
+ * @param model A Model instance.
+ * @return A list of TableIds which represents the outputs of
getModelData() of the given Model.
+ */
+ public TableId[] getModelDataFromModel(Model<?> model) {
+ GraphNode node = existingNodes.get(model);
+ if (node == null) {
+ throw new RuntimeException("the Model has not been added to the
graph");
+ }
+ if (node.stageType != StageType.ALGO_OPERATOR) {
+ throw new RuntimeException("the Model was previously added as an
Estimator");
+ }
+ if (node.outputModelDataIds != null) {
+ throw new RuntimeException("the model data of this Model has
already been fetched");
+ }
+ node.outputModelDataIds = createTableIds(maxOutputLength);
+ return node.outputModelDataIds;
+ }
+
+ /**
+ * Wraps nodes of the graph into an Estimator.
+ *
+ * <p>When the returned Estimator runs, and when the Model fitted by the
returned Estimator
+ * runs, the sequence of operations recorded by the {@code
addAlgoOperator(...)}, {@code
+ * addEstimator(...)}, {@code setModelData(...)} and {@code
getModelData(...)} would be executed
+ * as specified in the Java doc of the corresponding methods.
+ *
+ * @param inputs A list of TableIds which represents inputs to fit() of
the returned Estimator
+ * as well as inputs to transform() of the Model fitted by the
returned Estimator.
+ * @param outputs A list of TableIds which represents outputs of
transform() of the Model fitted
+ * by the returned Estimator.
+ * @return An Estimator which wraps the nodes of this graph.
+ */
+ public Estimator<?, ?> buildEstimator(TableId[] inputs, TableId[] outputs)
{
+ return buildEstimator(inputs, inputs, outputs, null, null);
+ }
+
+ /**
+ * Wraps nodes of the graph into an Estimator.
+ *
+ * <p>When the returned Estimator runs, and when the Model fitted by the
returned Estimator
+ * runs, the sequence of operations recorded by the {@code
addAlgoOperator(...)}, {@code
+ * addEstimator(...)}, {@code setModelData(...)} and {@code
getModelData(...)} would be executed
+ * as specified in the Java doc of the corresponding methods.
+ *
+ * @param inputs A list of TableIds which represents inputs to fit() of
the returned Estimator
+ * as well as inputs to transform() of the Model fitted by the
returned Estimator.
+ * @param outputs A list of TableIds which represents outputs of
transform() of the Model fitted
+ * by the returned Estimator.
+ * @param inputModelData A list of TableIds which represents inputs to
setModelData() of the
+ * Model fitted by the returned Estimator.
+ * @param outputModelData A list of TableIds which represents outputs of
getModelData() of the
+ * Model fitted by the returned Estimator.
+ * @return An Estimator which wraps the nodes of this graph.
+ */
+ public Estimator<?, ?> buildEstimator(
+ TableId[] inputs,
+ TableId[] outputs,
+ TableId[] inputModelData,
+ TableId[] outputModelData) {
+ return buildEstimator(inputs, inputs, outputs, inputModelData,
outputModelData);
+ }
+
+ /**
+ * Wraps nodes of the graph into an Estimator.
+ *
+ * <p>When the returned Estimator runs, and when the Model fitted by the
returned Estimator
+ * runs, the sequence of operations recorded by the {@code
addAlgoOperator(...)}, {@code
+ * addEstimator(...)}, {@code setModelData(...)} and {@code
getModelData(...)} would be executed
+ * as specified in the Java doc of the corresponding methods.
+ *
+ * @param estimatorInputs A list of TableIds which represents inputs to
fit() of the returned
+ * Estimator.
+ * @param modelInputs A list of TableIds which represents inputs to
transform() of the Model
+ * fitted by the returned Estimator.
+ * @param outputs A list of TableIds which represents outputs of
transform() of the Model fitted
+ * by the returned Estimator.
+ * @param inputModelData A list of TableIds which represents inputs to
setModelData() of the
+ * Model fitted by the returned Estimator.
+ * @param outputModelData A list of TableIds which represents outputs of
getModelData() of the
+ * Model fitted by the returned Estimator.
+ * @return An Estimator which wraps the nodes of this graph.
+ */
+ public Estimator<?, ?> buildEstimator(
+ TableId[] estimatorInputs,
+ TableId[] modelInputs,
+ TableId[] outputs,
+ TableId[] inputModelData,
+ TableId[] outputModelData) {
+ return new Graph(
+ nodes, estimatorInputs, modelInputs, outputs, inputModelData,
outputModelData);
+ }
+
+ /**
+ * Wraps nodes of the graph into an AlgoOperator.
+ *
+ * <p>When the returned AlgoOperator runs, the sequence of operations
recorded by the {@code
+ * addAlgoOperator(...)} and {@code addEstimator(...)} would be executed
as specified in the
+ * Java doc of the corresponding methods.
+ *
+ * @param inputs A list of TableIds which represents inputs to transform()
of the returned
+ * AlgoOperator.
+ * @param outputs A list of TableIds which represents outputs of
transform() of the returned
+ * AlgoOperator.
+ * @return An AlgoOperator which wraps the nodes of this graph.
+ */
+ public AlgoOperator<?> buildAlgoOperator(TableId[] inputs, TableId[]
outputs) {
+ return buildModel(inputs, outputs, null, null);
+ }
+
+ /**
+ * Wraps nodes of the graph into a Model.
+ *
+ * <p>When the returned Model runs, the sequence of operations recorded by
the {@code
+ * addAlgoOperator(...)} and {@code addEstimator(...)} would be executed
as specified in the
+ * Java doc of the corresponding methods.
+ *
+ * @param inputs A list of TableIds which represents inputs to transform()
of the returned
+ * Model.
+ * @param outputs A list of TableIds which represents outputs of
transform() of the returned
+ * Model.
+ * @return A Model which wraps the nodes of this graph.
+ */
+ public Model<?> buildModel(TableId[] inputs, TableId[] outputs) {
+ return buildModel(inputs, outputs, null, null);
+ }
+
+ /**
+ * Wraps nodes of the graph into a Model.
+ *
+ * <p>When the returned Model runs, the sequence of operations recorded by
the {@code
+ * addAlgoOperator(...)}, {@code addEstimator(...)}, {@code
setModelData(...)} and {@code
+ * getModelData(...)} would be executed as specified in the Java doc of
the corresponding
+ * methods.
+ *
+ * @param inputs A list of TableIds which represents inputs to transform()
of the returned
+ * Model.
+ * @param outputs A list of TableIds which represents outputs of
transform() of the returned
+ * Model.
+ * @param inputModelData A list of TableIds which represents inputs to
setModelData() of the
+ * returned Model.
+ * @param outputModelData A list of TableIds which represents outputs of
getModelData() of the
+ * returned Model.
+ * @return A Model which wraps the nodes of this graph.
+ */
+ public Model<?> buildModel(
+ TableId[] inputs,
+ TableId[] outputs,
+ TableId[] inputModelData,
+ TableId[] outputModelData) {
+ return new GraphModel(nodes, inputs, outputs, inputModelData,
outputModelData);
+ }
+
+ private TableId[] createTableIds(int count) {
+ TableId[] result = new TableId[count];
+ for (int i = 0; i < count; i++) {
+ result[i] = createTableId();
+ }
+ return result;
+ }
+
+ private TableId[] addStage(
+ Stage<?> stage, StageType stageType, TableId[] estimatorInputs,
TableId[] modelInputs) {
+ TableId[] outputs = createTableIds(maxOutputLength);
+ if (existingNodes.containsKey(stage)) {
+ throw new RuntimeException("The stage " + stage + " has already
been added.");
+ }
+ GraphNode node =
+ new GraphNode(
+ nextNodeId++,
+ stage,
+ stageType,
+ estimatorInputs,
+ modelInputs,
+ outputs,
+ null,
+ null);
+ nodes.add(node);
+ existingNodes.put(stage, node);
+ return outputs;
+ }
+}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphData.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphData.java
new file mode 100644
index 0000000..b688a93
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphData.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.builder;
+
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nullable;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Holds the fields needed to re-construct a Graph or a GraphModel, and converts them to and from a
+ * map representation (see {@link #toMap()} and {@link #fromMap(Map)}).
+ */
+public class GraphData {
+    public final List<GraphNode> nodes;
+    public final @Nullable TableId[] estimatorInputIds;
+    public final TableId[] modelInputIds;
+    public final TableId[] outputIds;
+    public final @Nullable TableId[] inputModelDataIds;
+    public final @Nullable TableId[] outputModelDataIds;
+
+    public GraphData(
+            List<GraphNode> nodes,
+            @Nullable TableId[] estimatorInputIds,
+            TableId[] modelInputIds,
+            TableId[] outputIds,
+            @Nullable TableId[] inputModelDataIds,
+            @Nullable TableId[] outputModelDataIds) {
+        this.nodes = Preconditions.checkNotNull(nodes);
+        this.estimatorInputIds = estimatorInputIds;
+        this.modelInputIds = Preconditions.checkNotNull(modelInputIds);
+        this.outputIds = Preconditions.checkNotNull(outputIds);
+        this.inputModelDataIds = inputModelDataIds;
+        this.outputModelDataIds = outputModelDataIds;
+    }
+
+    /**
+     * Converts this GraphData into a map.
+     *
+     * @return a map that always has the keys "nodes", "modelInputIds" and "outputIds", and has
+     *     "estimatorInputIds", "inputModelDataIds" and "outputModelDataIds" only when the
+     *     corresponding field is non-null.
+     */
+    public Map<String, Object> toMap() {
+        Map<String, Object> result = new HashMap<>();
+
+        List<Map<String, Object>> nodeInfos = new ArrayList<>(nodes.size());
+        for (GraphNode node : nodes) {
+            nodeInfos.add(node.toMap());
+        }
+        result.put("nodes", nodeInfos);
+        putIdsIfNotNull(result, "estimatorInputIds", estimatorInputIds);
+        result.put("modelInputIds", TableId.toList(modelInputIds));
+        result.put("outputIds", TableId.toList(outputIds));
+        putIdsIfNotNull(result, "inputModelDataIds", inputModelDataIds);
+        putIdsIfNotNull(result, "outputModelDataIds", outputModelDataIds);
+        return result;
+    }
+
+    /**
+     * Re-constructs a GraphData from a map produced by {@link #toMap()}. Optional id arrays whose
+     * key is absent are restored as null.
+     */
+    @SuppressWarnings("unchecked") // The map layout is the one written by toMap().
+    public static GraphData fromMap(Map<String, Object> map) {
+        List<GraphNode> nodes = new ArrayList<>();
+        for (Map<String, Object> nodeInfo : (List<Map<String, Object>>) map.get("nodes")) {
+            nodes.add(GraphNode.fromMap(nodeInfo));
+        }
+        return new GraphData(
+                nodes,
+                getOptionalIds(map, "estimatorInputIds"),
+                TableId.fromList((List<Integer>) map.get("modelInputIds")),
+                TableId.fromList((List<Integer>) map.get("outputIds")),
+                getOptionalIds(map, "inputModelDataIds"),
+                getOptionalIds(map, "outputModelDataIds"));
+    }
+
+    /** Stores {@code ids} under {@code key} unless {@code ids} is null. */
+    private static void putIdsIfNotNull(
+            Map<String, Object> map, String key, @Nullable TableId[] ids) {
+        if (ids != null) {
+            map.put(key, TableId.toList(ids));
+        }
+    }
+
+    /** Reads the TableIds stored under {@code key}, or returns null when the key is absent. */
+    @SuppressWarnings("unchecked") // The map layout is the one written by toMap().
+    private static @Nullable TableId[] getOptionalIds(Map<String, Object> map, String key) {
+        return map.containsKey(key) ? TableId.fromList((List<Integer>) map.get(key)) : null;
+    }
+}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphExecutionHelper.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphExecutionHelper.java
new file mode 100644
index 0000000..09973f4
--- /dev/null
+++
b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphExecutionHelper.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.builder;
+
+import org.apache.flink.table.api.Table;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * A container class that maintains the execution state of the graph (e.g.
which nodes are ready to
+ * run).
+ */
+class GraphExecutionHelper {
+ /** A map from tableId to the list of nodes which take this table as
input. */
+ private final Map<TableId, List<GraphNode>> consumerNodes = new
HashMap<>();
+ /**
+ * A map from tableId to the corresponding table. A TableId would be
mapped iff its
+ * corresponding Table has been constructed.
+ */
+ private final Map<TableId, Table> constructedTables = new HashMap<>();
+ /**
+ * A map that maintains the number of input tables that have not been
constructed for each node
+ * in the graph.
+ */
+ private final Map<GraphNode, Integer> numUnConstructedInputTables = new
HashMap<>();
+ /**
+ * An ordered list of nodes whose input tables have all been constructed
AND who has not been
+ * fetch via pollNextReadyNode.
+ */
+ private final Deque<GraphNode> unFetchedReadyNodes = new LinkedList<>();
+
+ public GraphExecutionHelper(List<GraphNode> nodes) {
+ // Initializes dependentNodes and numUnConstructedInputs.
+ for (GraphNode node : nodes) {
+ List<TableId> inputs = new ArrayList<>();
+ inputs.addAll(Arrays.asList(node.algoOpInputIds));
+ if (node.estimatorInputIds != null) {
+ inputs.addAll(Arrays.asList(node.estimatorInputIds));
+ }
+ if (node.inputModelDataIds != null) {
+ inputs.addAll(Arrays.asList(node.inputModelDataIds));
+ }
+ numUnConstructedInputTables.put(node, inputs.size());
+ for (TableId tableId : inputs) {
+ consumerNodes.putIfAbsent(tableId, new ArrayList<>());
+ consumerNodes.get(tableId).add(node);
+ }
+ }
+ }
+
+ public void setTables(TableId[] tableIds, Table[] tables) {
+ // The length of tableIds could be larger than the length of tables
because we over-allocate
+ // the number of tableIds (which is 20 by default) as placeholder of
the stage's output
+ // tables when adding a stage in GraphBuilder.
+ Preconditions.checkArgument(
+ tableIds.length >= tables.length,
+ "the length of tablesIds %s is less than the length of tables
%s",
+ tableIds.length,
+ tables.length);
+ for (int i = 0; i < tables.length; i++) {
+ setTable(tableIds[i], tables[i]);
+ }
+ }
+
+ private void setTable(TableId tableId, Table table) {
+ Preconditions.checkArgument(
+ !constructedTables.containsKey(tableId),
+ "the table with id=%s has already been constructed",
+ tableId.toString());
+ constructedTables.put(tableId, table);
+
+ for (GraphNode node : consumerNodes.getOrDefault(tableId, new
ArrayList<>())) {
+ int prevNum = numUnConstructedInputTables.get(node);
+ if (prevNum == 1) {
+ unFetchedReadyNodes.addLast(node);
+ numUnConstructedInputTables.remove(node);
+ } else {
+ numUnConstructedInputTables.put(node, prevNum - 1);
+ }
+ }
+ }
+
+ public Table[] getTables(TableId[] tableIds) {
+ Table[] tables = new Table[tableIds.length];
+ for (int i = 0; i < tableIds.length; i++) {
+ tables[i] = getTable(tableIds[i]);
+ }
+ return tables;
+ }
+
+ private Table getTable(TableId tableId) {
+ Preconditions.checkArgument(
+ constructedTables.containsKey(tableId),
+ "the table with id=%s has not been constructed yet",
+ tableId.toString());
+ return constructedTables.get(tableId);
+ }
+
+ public GraphNode pollNextReadyNode() {
+ if (unFetchedReadyNodes.isEmpty() &&
!numUnConstructedInputTables.isEmpty()) {
+ throw new RuntimeException("there exists node whose input can not
be constructed");
+ }
+ return unFetchedReadyNodes.pollFirst();
+ }
+}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphModel.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphModel.java
new file mode 100644
index 0000000..894098f
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphModel.java
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.builder;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.ml.builder.GraphNode.StageType;
+
+/**
+ * A GraphModel acts as a Model. A GraphModel consists of a DAG of stages,
each of which could be an
+ * Estimator, Model, Transformer or AlgoOperators. When
`GraphModel::transform` is called, the
+ * stages are executed in a topologically-sorted order. When a stage is
executed, its
+ * `AlgoOperator::transform` method will be called on the input tables (from
the input edges) and
+ * produce output tables to the output edges.
+ */
+@PublicEvolving
+public final class GraphModel implements Model<GraphModel> {
+ private static final long serialVersionUID = 6354856913812529398L;
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+ private final List<GraphNode> nodes;
+ private final TableId[] inputIds;
+ private final TableId[] outputIds;
+ private final @Nullable TableId[] inputModelDataIds;
+ private final @Nullable TableId[] outputModelDataIds;
+ private final GraphExecutionHelper executionHelper;
+
+ public GraphModel(
+ List<GraphNode> nodes,
+ TableId[] inputIds,
+ TableId[] outputIds,
+ TableId[] inputModelDataIds,
+ TableId[] outputModelDataIds) {
+ this.nodes = Preconditions.checkNotNull(nodes);
+ this.inputIds = Preconditions.checkNotNull(inputIds);
+ this.outputIds = Preconditions.checkNotNull(outputIds);
+ this.inputModelDataIds = inputModelDataIds;
+ this.outputModelDataIds = outputModelDataIds;
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ executionHelper = new GraphExecutionHelper(nodes);
+ }
+
+ @Override
+ public Table[] transform(Table... inputTables) {
+ Preconditions.checkArgument(
+ inputIds.length == inputTables.length,
+ "number of provided tables %s does not match the expected
number of tables %s",
+ inputTables.length,
+ inputIds.length);
+ // Maps inputIds to inputTables.
+ executionHelper.setTables(inputIds, inputTables);
+ // Iterates until we have executed all ready nodes.
+ GraphNode node;
+ while ((node = executionHelper.pollNextReadyNode()) != null) {
+ Stage<?> stage = node.stage;
+ // Invokes fit(...) if stageType == ESTIMATOR.
+ if (node.stageType == StageType.ESTIMATOR) {
+ stage =
+ ((Estimator<?, ?>) stage)
+
.fit(executionHelper.getTables(node.estimatorInputIds));
+ }
+ // Invokes setModelData(...).
+ if (node.inputModelDataIds != null) {
+ Table[] nodeInputModelData =
executionHelper.getTables(node.inputModelDataIds);
+ ((Model<?>) stage).setModelData(nodeInputModelData);
+ }
+ // Invokes transform(...).
+ Table[] nodeOutputs =
+ ((AlgoOperator<?>) stage)
+
.transform(executionHelper.getTables(node.algoOpInputIds));
+ executionHelper.setTables(node.outputIds, nodeOutputs);
+ // Invokes getModelData().
+ if (node.outputModelDataIds != null) {
+ Table[] nodeOutputModelData = ((Model<?>)
stage).getModelData();
+ executionHelper.setTables(node.outputModelDataIds,
nodeOutputModelData);
+ }
+ }
+ return executionHelper.getTables(outputIds);
+ }
+
+ @Override
+ public GraphModel setModelData(Table... inputTables) {
+ Preconditions.checkArgument(inputModelDataIds != null, "setModelData()
is not supported");
+ Preconditions.checkArgument(
+ inputModelDataIds.length == inputTables.length,
+ "number of provided tables %s does not match the expected
number of tables %s",
+ inputTables.length,
+ inputIds.length);
+ // Maps inputModelDataIds to inputTables.
+ executionHelper.setTables(inputModelDataIds, inputTables);
+ return this;
+ }
+
+ @Override
+ public Table[] getModelData() {
+ Preconditions.checkArgument(outputModelDataIds != null);
+ return executionHelper.getTables(outputModelDataIds);
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ GraphData graphData =
+ new GraphData(
+ nodes, null, inputIds, outputIds, inputModelDataIds,
outputModelDataIds);
+ ReadWriteUtils.saveGraph(this, graphData, path);
+ }
+
+ public static GraphModel load(StreamExecutionEnvironment env, String path)
throws IOException {
+ return (GraphModel) ReadWriteUtils.loadGraph(env, path,
GraphModel.class.getName());
+ }
+}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphNode.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphNode.java
new file mode 100644
index 0000000..e9d680e
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphNode.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.builder;
+
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nullable;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+/** The Graph node class. */
+public class GraphNode {
+ /** This class specifies whether a node should be used as Estimator or
AlgoOperator. */
+ public enum StageType {
+ ESTIMATOR,
+ ALGO_OPERATOR;
+ }
+
+ public final int nodeId;
+ public @Nullable Stage<?> stage;
+ public final StageType stageType;
+ public final @Nullable TableId[] estimatorInputIds;
+ public final TableId[] algoOpInputIds;
+ public final TableId[] outputIds;
+ public @Nullable TableId[] inputModelDataIds;
+ public @Nullable TableId[] outputModelDataIds;
+
+ public GraphNode(
+ int nodeId,
+ Stage<?> stage,
+ StageType stageType,
+ TableId[] estimatorInputIds,
+ TableId[] algoOpInputIds,
+ TableId[] outputIds,
+ TableId[] inputModelDataIds,
+ TableId[] outputModelDataIds) {
+ this.nodeId = Preconditions.checkNotNull(nodeId);
+ this.stage = stage;
+ this.stageType = Preconditions.checkNotNull(stageType);
+ this.estimatorInputIds = estimatorInputIds;
+ this.algoOpInputIds = Preconditions.checkNotNull(algoOpInputIds);
+ this.outputIds = Preconditions.checkNotNull(outputIds);
+ this.inputModelDataIds = inputModelDataIds;
+ this.outputModelDataIds = outputModelDataIds;
+ }
+
+ public Map<String, Object> toMap() {
+ Map<String, Object> result = new HashMap<>();
+ result.put("nodeId", nodeId);
+ result.put("stageType", stageType.name());
+ if (estimatorInputIds != null) {
+ result.put("estimatorInputIds", TableId.toList(estimatorInputIds));
+ }
+ result.put("algoOpInputIds", TableId.toList(algoOpInputIds));
+ result.put("outputIds", TableId.toList(outputIds));
+ if (inputModelDataIds != null) {
+ result.put("inputModelDataIds", TableId.toList(inputModelDataIds));
+ }
+ if (outputModelDataIds != null) {
+ result.put("outputModelDataIds",
TableId.toList(outputModelDataIds));
+ }
+ return result;
+ }
+
+ public static GraphNode fromMap(Map<String, Object> map) {
+ int nodeId = (Integer) map.get("nodeId");
+ StageType stageType = StageType.valueOf((String) map.get("stageType"));
+ TableId[] estimatorInputIds = null;
+ if (map.containsKey("estimatorInputIds")) {
+ estimatorInputIds = TableId.fromList((List<Integer>)
map.get("estimatorInputIds"));
+ }
+ TableId[] algoOpInputIds = TableId.fromList((List<Integer>)
map.get("algoOpInputIds"));
+ TableId[] outputIds = TableId.fromList((List<Integer>)
map.get("outputIds"));
+ TableId[] inputModelDataIds = null;
+ if (map.containsKey("inputModelDataIds")) {
+ inputModelDataIds = TableId.fromList((List<Integer>)
map.get("inputModelDataIds"));
+ }
+ TableId[] outputModelDataIds = null;
+ if (map.containsKey("outputModelDataIds")) {
+ outputModelDataIds = TableId.fromList((List<Integer>)
map.get("outputModelDataIds"));
+ }
+ return new GraphNode(
+ nodeId,
+ null,
+ stageType,
+ estimatorInputIds,
+ algoOpInputIds,
+ outputIds,
+ inputModelDataIds,
+ outputModelDataIds);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == null) {
+ return false;
+ }
+ if (!(obj instanceof GraphNode)) {
+ return false;
+ }
+ GraphNode other = (GraphNode) obj;
+ return nodeId == other.nodeId;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(nodeId);
+ }
+
+ @Override
+ public String toString() {
+ return String.format(
+ "GraphNode(nodeId=%d, stageType=%s, estimatorInputIds=%s,
algoOpInputIds=%s, outputIds=%s, inputModelDataIds=%s, outputModelDataIds=%s)",
+ nodeId,
+ stageType.name(),
+ Arrays.toString(estimatorInputIds),
+ Arrays.toString(algoOpInputIds),
+ Arrays.toString(outputIds),
+ Arrays.toString(inputModelDataIds),
+ Arrays.toString(outputModelDataIds));
+ }
+}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/Pipeline.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/Pipeline.java
index 408198a..3e021ab 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/Pipeline.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/Pipeline.java
@@ -28,6 +28,7 @@ import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
+import org.apache.flink.util.Preconditions;
import java.io.IOException;
import java.util.ArrayList;
@@ -47,7 +48,7 @@ public final class Pipeline implements Estimator<Pipeline,
PipelineModel> {
private final Map<Param<?>, Object> paramMap = new HashMap<>();
public Pipeline(List<Stage<?>> stages) {
- this.stages = stages;
+ this.stages = Preconditions.checkNotNull(stages);
ParamUtils.initializeMapWithDefaultValues(paramMap, this);
}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/PipelineModel.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/PipelineModel.java
index 2668f82..ee5f099 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/PipelineModel.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/PipelineModel.java
@@ -28,6 +28,7 @@ import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
+import org.apache.flink.util.Preconditions;
import java.io.IOException;
import java.util.Collections;
@@ -46,7 +47,7 @@ public final class PipelineModel implements
Model<PipelineModel> {
private final Map<Param<?>, Object> paramMap = new HashMap<>();
public PipelineModel(List<Stage<?>> stages) {
- this.stages = stages;
+ this.stages = Preconditions.checkNotNull(stages);
ParamUtils.initializeMapWithDefaultValues(paramMap, this);
}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/TableId.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/TableId.java
new file mode 100644
index 0000000..51f566b
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/TableId.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.builder;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * The TableId is necessary to pass the inputs/outputs of various API calls
across the nodes of
+ * Graph and GraphModel.
+ */
+public class TableId {
+ public final int id;
+
+ public TableId(int id) {
+ this.id = id;
+ }
+
+ public static List<Integer> toList(TableId[] tableIds) {
+ List<Integer> result = new ArrayList<>(tableIds.length);
+ for (int i = 0; i < tableIds.length; i++) {
+ result.add(tableIds[i].id);
+ }
+ return result;
+ }
+
+ public static TableId[] fromList(List<Integer> tableIds) {
+ TableId[] result = new TableId[tableIds.size()];
+ for (int i = 0; i < tableIds.size(); i++) {
+ result[i] = new TableId(tableIds.get(i));
+ }
+ return result;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == null) {
+ return false;
+ }
+ if (!(obj instanceof TableId)) {
+ return false;
+ }
+ TableId other = (TableId) obj;
+ return id == other.id;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(id);
+ }
+
+ @Override
+ public String toString() {
+ return "TableId(" + id + ")";
+ }
+}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
index 2c21e20..eedc74d 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
@@ -25,12 +25,17 @@ import org.apache.flink.connector.file.sink.FileSink;
import org.apache.flink.connector.file.src.FileSource;
import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Graph;
+import org.apache.flink.ml.builder.GraphData;
+import org.apache.flink.ml.builder.GraphModel;
+import org.apache.flink.ml.builder.GraphNode;
import org.apache.flink.ml.param.Param;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
import
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
import org.apache.flink.util.InstantiationUtil;
+import org.apache.flink.util.Preconditions;
import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
@@ -45,6 +50,7 @@ import java.lang.reflect.Method;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
+import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -231,6 +237,79 @@ public class ReadWriteUtils {
return stages;
}
+ /**
+ * Saves a Graph or GraphModel with the given GraphData to the given path.
+ *
+ * @param graph A Graph or GraphModel instance.
+ * @param graphData A GraphData instance.
+ * @param path The parent directory to save the graph metadata and its
stages.
+ */
+ public static void saveGraph(Stage<?> graph, GraphData graphData, String
path)
+ throws IOException {
+ // Creates parent directories if not already created.
+ new File(path).mkdirs();
+
+ Map<String, Object> extraMetadata = new HashMap<>();
+ extraMetadata.put("graphData", graphData.toMap());
+ saveMetadata(graph, path, extraMetadata);
+ int maxNodeId =
+ graphData.nodes.stream()
+ .map(node -> node.nodeId)
+ .max(Comparator.naturalOrder())
+ .orElse(-1);
+
+ for (GraphNode node : graphData.nodes) {
+ String stagePath = getPathForPipelineStage(node.nodeId, maxNodeId
+ 1, path);
+ node.stage.save(stagePath);
+ }
+ }
+
+ /**
+ * Loads a Graph or GraphModel from the given path.
+ *
+ * <p>The method throws RuntimeException if the expectedClassName is not
empty AND it does not
+ * match the className of the previously saved Pipeline or PipelineModel.
+ *
+ * @param env A StreamExecutionEnvironment instance.
+ * @param path The parent directory to load the pipeline metadata and its
stages.
+ * @param expectedClassName The expected class name of the pipeline.
+ * @return A Graph or GraphModel instance.
+ */
+ public static Stage<?> loadGraph(
+ StreamExecutionEnvironment env, String path, String
expectedClassName)
+ throws IOException {
+ Map<String, ?> metadata = loadMetadata(path, expectedClassName);
+ GraphData graphData = GraphData.fromMap((Map<String, Object>)
metadata.get("graphData"));
+
+ int maxNodeId =
+ graphData.nodes.stream()
+ .map(node -> node.nodeId)
+ .max(Comparator.naturalOrder())
+ .orElse(-1);
+
+ for (GraphNode node : graphData.nodes) {
+ String stagePath = getPathForPipelineStage(node.nodeId, maxNodeId
+ 1, path);
+ node.stage = loadStage(env, stagePath);
+ }
+
+ if (expectedClassName.equals(GraphModel.class.getName())) {
+ return new GraphModel(
+ graphData.nodes,
+ graphData.modelInputIds,
+ graphData.outputIds,
+ graphData.inputModelDataIds,
+ graphData.outputModelDataIds);
+ }
+
Preconditions.checkState(expectedClassName.equals(Graph.class.getName()));
+ return new Graph(
+ graphData.nodes,
+ graphData.estimatorInputIds,
+ graphData.modelInputIds,
+ graphData.outputIds,
+ graphData.inputModelDataIds,
+ graphData.outputModelDataIds);
+ }
+
// A helper method that sets stage's parameter value. We can not call
stage.set(param, value)
// directly because stage::set(...) needs the actual type of the value.
public static <T> void setParam(Stage<?> stage, Param<T> param, Object
value) {
diff --git
a/flink-ml-core/src/test/java/org/apache/flink/ml/api/ExampleStages.java
b/flink-ml-core/src/test/java/org/apache/flink/ml/api/ExampleStages.java
index 204d477..63de3a2 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/api/ExampleStages.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/ExampleStages.java
@@ -38,16 +38,9 @@ import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
-import org.apache.commons.collections.IteratorUtils;
import org.junit.Assert;
-import java.io.DataInputStream;
-import java.io.DataOutputStream;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileOutputStream;
import java.io.IOException;
-import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Map;
@@ -60,6 +53,7 @@ public class ExampleStages {
public static class SumModel implements Model<SumModel> {
private final Map<Param<?>, Object> paramMap = new HashMap<>();
private DataStream<Integer> modelData;
+ private Table modelDataTable;
// This empty constructor is necessary in order for ModelA to be
loaded by
// ReadWriteUtils.createStageWithParam
@@ -94,46 +88,29 @@ public class ExampleStages {
(StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
modelData = tEnv.toDataStream(inputs[0], Integer.class);
+ modelDataTable = inputs[0];
return this;
}
@Override
+ public Table[] getModelData() {
+ return new Table[] {modelDataTable};
+ }
+
+ @Override
public void save(String path) throws IOException {
+ ReadWriteUtils.saveModelData(modelData, path, new
TestUtils.IntEncoder());
ReadWriteUtils.saveMetadata(this, path);
-
- File dataDir = new File(path, "data");
- if (!dataDir.mkdir()) {
- throw new IOException("Directory " + dataDir.toString() + "
already exists.");
- }
-
- File dataFile = new File(dataDir, "delta");
- if (!dataFile.createNewFile()) {
- throw new IOException("File " + dataFile.toString() + "
already exists.");
- }
-
- int delta;
- try {
- delta = (Integer)
IteratorUtils.toList(modelData.executeAndCollect()).get(0);
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
-
- try (DataOutputStream outputStream =
- new DataOutputStream(new FileOutputStream(dataFile))) {
- outputStream.writeInt(delta);
- }
}
public static SumModel load(StreamExecutionEnvironment env, String
path)
throws IOException {
- SumModel model = ReadWriteUtils.loadStageParam(path);
- File dataFile = Paths.get(path, "data", "delta").toFile();
+ StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+ DataStream<Integer> modelData =
+ ReadWriteUtils.loadModelData(env, path, new
TestUtils.IntegerStreamFormat());
- try (DataInputStream inputStream = new DataInputStream(new
FileInputStream(dataFile))) {
- StreamTableEnvironment tEnv =
StreamTableEnvironment.create(env);
- Table modelData =
tEnv.fromDataStream(env.fromElements(inputStream.readInt()));
- return model.setModelData(modelData);
- }
+ SumModel model = ReadWriteUtils.loadStageParam(path);
+ return model.setModelData(tEnv.fromDataStream(modelData));
}
}
@@ -239,4 +216,43 @@ public class ExampleStages {
sum += input.getValue();
}
}
+
+ /**
+ * A Transformer subclass that takes 2 inputs and returns the union of
these two inputs as the
+ * output.
+ */
+ public static class UnionAlgoOperator implements
Transformer<UnionAlgoOperator> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ public UnionAlgoOperator() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+
+ @Override
+ public Table[] transform(Table... inputs) {
+ Assert.assertEquals(2, inputs.length);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
+
+ DataStream<Integer> inputA = tEnv.toDataStream(inputs[0],
Integer.class);
+ DataStream<Integer> inputB = tEnv.toDataStream(inputs[1],
Integer.class);
+
+ return new Table[] {tEnv.fromDataStream(inputA.union(inputB))};
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static UnionAlgoOperator load(StreamExecutionEnvironment env,
String path)
+ throws IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+ }
}
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/api/GraphTest.java
b/flink-ml-core/src/test/java/org/apache/flink/ml/api/GraphTest.java
new file mode 100644
index 0000000..213ed3c
--- /dev/null
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/GraphTest.java
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.ExampleStages.SumEstimator;
+import org.apache.flink.ml.api.ExampleStages.SumModel;
+import org.apache.flink.ml.api.ExampleStages.UnionAlgoOperator;
+import org.apache.flink.ml.builder.Graph;
+import org.apache.flink.ml.builder.GraphBuilder;
+import org.apache.flink.ml.builder.GraphModel;
+import org.apache.flink.ml.builder.TableId;
+import
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+/** Tests the behavior of {@link Graph} and {@link GraphModel}. */
+public class GraphTest extends AbstractTestBase {
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    /**
+     * Creates a fresh environment for every test: parallelism 4, checkpointing every 100 ms
+     * (with checkpoints allowed after tasks finish), and no restarts on failure.
+     */
+    @Before
+    public void before() {
+        Configuration conf = new Configuration();
+        conf.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(conf);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+
+        tEnv = StreamTableEnvironment.create(env);
+    }
+
+    /** Returns a fresh, mutable list holding the two inputs {1, 2, 3} and {10, 11, 12}. */
+    private static List<List<Integer>> twoInputLists() {
+        List<List<Integer>> lists = new ArrayList<>();
+        lists.add(Arrays.asList(1, 2, 3));
+        lists.add(Arrays.asList(10, 11, 12));
+        return lists;
+    }
+
+    /**
+     * Checks {@code stage} against the expected output, then saves it to a new temporary
+     * directory, loads it back ({@link Graph} for estimators, {@link GraphModel} otherwise) and
+     * checks the loaded copy against the same expectations.
+     *
+     * @param modelDataExists when true, {@code env.execute()} is called after saving —
+     *     presumably because writing model data is done by a pending Flink job (TODO confirm)
+     */
+    private static void executeSaveLoadAndCheckOutput(
+            StreamExecutionEnvironment env,
+            Stage<?> stage,
+            List<List<Integer>> inputs,
+            List<Integer> expectedOutput,
+            List<List<Integer>> modelDataInputs,
+            List<Integer> expectedModelDataOutput,
+            boolean modelDataExists)
+            throws Exception {
+        TestUtils.executeAndCheckOutput(
+                env, stage, inputs, expectedOutput, modelDataInputs, expectedModelDataOutput);
+
+        String saveDir = Files.createTempDirectory("").toString();
+        stage.save(saveDir);
+        if (modelDataExists) {
+            env.execute();
+        }
+
+        final Stage<?> reloaded;
+        if (stage instanceof Estimator) {
+            reloaded = Graph.load(env, saveDir);
+        } else {
+            reloaded = GraphModel.load(env, saveDir);
+        }
+
+        TestUtils.executeAndCheckOutput(
+                env, reloaded, inputs, expectedOutput, modelDataInputs, expectedModelDataOutput);
+    }
+
+    @Test
+    public void testGraphModelWithoutEstimator() throws Exception {
+        GraphBuilder graphBuilder = new GraphBuilder();
+
+        // Two models adding 2 and 1 respectively, followed by a union of their outputs.
+        SumModel addTwo = new SumModel().setModelData(tEnv.fromValues(2));
+        SumModel addOne = new SumModel().setModelData(tEnv.fromValues(1));
+        AlgoOperator<?> union = new UnionAlgoOperator();
+
+        TableId firstInput = graphBuilder.createTableId();
+        TableId secondInput = graphBuilder.createTableId();
+
+        TableId addTwoOutput = graphBuilder.addAlgoOperator(addTwo, firstInput)[0];
+        TableId addOneOutput = graphBuilder.addAlgoOperator(addOne, secondInput)[0];
+        TableId unionOutput = graphBuilder.addAlgoOperator(union, addTwoOutput, addOneOutput)[0];
+
+        Model<?> graphModel =
+                graphBuilder.buildModel(
+                        new TableId[] {firstInput, secondInput}, new TableId[] {unionOutput});
+
+        executeSaveLoadAndCheckOutput(
+                env,
+                graphModel,
+                twoInputLists(),
+                Arrays.asList(3, 4, 5, 11, 12, 13),
+                null,
+                null,
+                true);
+    }
+
+    @Test
+    public void testGraphModelWithEstimator() throws Exception {
+        GraphBuilder graphBuilder = new GraphBuilder();
+
+        // Two estimators on separate inputs, whose outputs are then unioned.
+        Estimator<?, ?> firstEstimator = new SumEstimator();
+        Estimator<?, ?> secondEstimator = new SumEstimator();
+        AlgoOperator<?> union = new UnionAlgoOperator();
+
+        TableId firstInput = graphBuilder.createTableId();
+        TableId secondInput = graphBuilder.createTableId();
+
+        TableId firstOutput = graphBuilder.addEstimator(firstEstimator, firstInput)[0];
+        TableId secondOutput = graphBuilder.addEstimator(secondEstimator, secondInput)[0];
+        TableId unionOutput = graphBuilder.addAlgoOperator(union, firstOutput, secondOutput)[0];
+
+        Model<?> graphModel =
+                graphBuilder.buildModel(
+                        new TableId[] {firstInput, secondInput}, new TableId[] {unionOutput});
+
+        // Each input is shifted by its own sum: 1 + 2 + 3 = 6 and 10 + 11 + 12 = 33.
+        executeSaveLoadAndCheckOutput(
+                env,
+                graphModel,
+                twoInputLists(),
+                Arrays.asList(7, 8, 9, 43, 44, 45),
+                null,
+                null,
+                false);
+    }
+
+    @Test
+    public void testGraphModelWithSetGetModelData() throws Exception {
+        GraphBuilder graphBuilder = new GraphBuilder();
+
+        // A chain of three models; the middle one receives its model data as a graph input and
+        // the last one exposes its model data as a graph output.
+        SumModel first = new SumModel().setModelData(tEnv.fromValues(1));
+        SumModel middle = new SumModel();
+        SumModel last = new SumModel().setModelData(tEnv.fromValues(3));
+
+        TableId dataInput = graphBuilder.createTableId();
+        TableId modelDataInput = graphBuilder.createTableId();
+
+        TableId firstOutput = graphBuilder.addAlgoOperator(first, dataInput)[0];
+        TableId middleOutput = graphBuilder.addAlgoOperator(middle, firstOutput)[0];
+        graphBuilder.setModelDataOnModel(middle, modelDataInput);
+        TableId lastOutput = graphBuilder.addAlgoOperator(last, middleOutput)[0];
+        TableId modelDataOutput = graphBuilder.getModelDataFromModel(last)[0];
+
+        Model<?> graphModel =
+                graphBuilder.buildModel(
+                        new TableId[] {dataInput},
+                        new TableId[] {lastOutput},
+                        new TableId[] {modelDataInput},
+                        new TableId[] {modelDataOutput});
+
+        // 1 + 2 + 3 is added to each value; the exposed model data is the last model's 3.
+        executeSaveLoadAndCheckOutput(
+                env,
+                graphModel,
+                Collections.singletonList(Arrays.asList(1, 2, 3)),
+                Arrays.asList(7, 8, 9),
+                Collections.singletonList(Collections.singletonList(2)),
+                Collections.singletonList(3),
+                true);
+    }
+
+    @Test
+    public void testGraphWithEstimator() throws Exception {
+        GraphBuilder graphBuilder = new GraphBuilder();
+
+        Estimator<?, ?> firstEstimator = new SumEstimator();
+        Estimator<?, ?> secondEstimator = new SumEstimator();
+        AlgoOperator<?> union = new UnionAlgoOperator();
+
+        TableId firstInput = graphBuilder.createTableId();
+        TableId secondInput = graphBuilder.createTableId();
+
+        TableId firstOutput = graphBuilder.addEstimator(firstEstimator, firstInput)[0];
+        TableId secondOutput = graphBuilder.addEstimator(secondEstimator, secondInput)[0];
+        TableId unionOutput = graphBuilder.addAlgoOperator(union, firstOutput, secondOutput)[0];
+
+        // Unlike the GraphModel tests, the graph is built into an Estimator here.
+        Estimator<?, ?> graphEstimator =
+                graphBuilder.buildEstimator(
+                        new TableId[] {firstInput, secondInput}, new TableId[] {unionOutput});
+
+        executeSaveLoadAndCheckOutput(
+                env,
+                graphEstimator,
+                twoInputLists(),
+                Arrays.asList(7, 8, 9, 43, 44, 45),
+                null,
+                null,
+                false);
+    }
+
+    @Test
+    public void testGraphWithSetGetModelData() throws Exception {
+        GraphBuilder graphBuilder = new GraphBuilder();
+
+        // The estimator's model data (the sum of the first input) feeds the model on the second
+        // input, and is also exposed as a model data output of the graph.
+        Estimator<?, ?> estimator = new SumEstimator();
+        SumModel model = new SumModel();
+        AlgoOperator<?> union = new UnionAlgoOperator();
+
+        TableId firstInput = graphBuilder.createTableId();
+        TableId secondInput = graphBuilder.createTableId();
+
+        TableId estimatorOutput = graphBuilder.addEstimator(estimator, firstInput)[0];
+        TableId modelDataOutput = graphBuilder.getModelDataFromEstimator(estimator)[0];
+        TableId modelOutput = graphBuilder.addAlgoOperator(model, secondInput)[0];
+        graphBuilder.setModelDataOnModel(model, modelDataOutput);
+        TableId unionOutput = graphBuilder.addAlgoOperator(union, estimatorOutput, modelOutput)[0];
+
+        Estimator<?, ?> graphEstimator =
+                graphBuilder.buildEstimator(
+                        new TableId[] {firstInput, secondInput},
+                        new TableId[] {unionOutput},
+                        null,
+                        new TableId[] {modelDataOutput});
+
+        // Both inputs are shifted by 6, the sum of the first input.
+        executeSaveLoadAndCheckOutput(
+                env,
+                graphEstimator,
+                twoInputLists(),
+                Arrays.asList(7, 8, 9, 16, 17, 18),
+                null,
+                Collections.singletonList(6),
+                true);
+    }
+}
diff --git
a/flink-ml-core/src/test/java/org/apache/flink/ml/api/PipelineTest.java
b/flink-ml-core/src/test/java/org/apache/flink/ml/api/PipelineTest.java
index 8ce3fe3..105fcf9 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/api/PipelineTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/PipelineTest.java
@@ -18,57 +18,43 @@
package org.apache.flink.ml.api;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.ExampleStages.SumEstimator;
import org.apache.flink.ml.api.ExampleStages.SumModel;
import org.apache.flink.ml.builder.Pipeline;
import org.apache.flink.ml.builder.PipelineModel;
+import
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.test.util.AbstractTestBase;
-import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
import org.junit.Test;
import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.Paths;
import java.util.Arrays;
-import java.util.Comparator;
+import java.util.Collections;
import java.util.List;
/** Tests the behavior of Pipeline and PipelineModel. */
public class PipelineTest extends AbstractTestBase {
-
- // Executes the given stage and verifies that it produces the expected
output.
- private static void executeAndCheckOutput(
- Stage<?> stage, List<Integer> input, List<Integer> expectedOutput)
throws Exception {
- StreamExecutionEnvironment env =
StreamExecutionEnvironment.getExecutionEnvironment();
- StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+ private StreamExecutionEnvironment env;
+ private StreamTableEnvironment tEnv;
+
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH,
true);
+ env = StreamExecutionEnvironment.getExecutionEnvironment(config);
env.setParallelism(4);
-
- Table inputTable = tEnv.fromDataStream(env.fromCollection(input));
-
- Table outputTable;
-
- if (stage instanceof AlgoOperator) {
- outputTable = ((AlgoOperator<?>) stage).transform(inputTable)[0];
- } else {
- Estimator<?, ?> estimator = (Estimator<?, ?>) stage;
- Model<?> model = estimator.fit(inputTable);
- outputTable = model.transform(inputTable)[0];
- }
-
- List<Integer> output =
- IteratorUtils.toList(
- tEnv.toDataStream(outputTable,
Integer.class).executeAndCollect());
- compareResultCollections(expectedOutput, output,
Comparator.naturalOrder());
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ tEnv = StreamTableEnvironment.create(env);
}
@Test
public void testPipelineModel() throws Exception {
- StreamExecutionEnvironment env =
StreamExecutionEnvironment.getExecutionEnvironment();
- StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
// Builds a PipelineModel that increments input value by 60. This
PipelineModel consists of
// three stages where each stage increments input value by 10, 20, and
30 respectively.
SumModel modelA = new SumModel().setModelData(tEnv.fromValues(10));
@@ -77,41 +63,43 @@ public class PipelineTest extends AbstractTestBase {
List<Stage<?>> stages = Arrays.asList(modelA, modelB, modelC);
Model<?> model = new PipelineModel(stages);
+ List<List<Integer>> inputs =
Collections.singletonList(Arrays.asList(1, 2, 3));
+ List<Integer> output = Arrays.asList(61, 62, 63);
// Executes the original PipelineModel and verifies that it produces
the expected output.
- executeAndCheckOutput(model, Arrays.asList(1, 2, 3), Arrays.asList(61,
62, 63));
+ TestUtils.executeAndCheckOutput(env, model, inputs, output, null,
null);
// Saves and loads the PipelineModel.
- Path tempDir = Files.createTempDirectory("PipelineTest");
- String path = Paths.get(tempDir.toString(),
"testPipelineModelSaveLoad").toString();
+ String path = Files.createTempDirectory("").toString();
model.save(path);
- Model<?> loadedModel = PipelineModel.load(env, path);
+ env.execute();
+ Model<?> loadedModel = PipelineModel.load(env, path);
// Executes the loaded PipelineModel and verifies that it produces the
expected output.
- executeAndCheckOutput(loadedModel, Arrays.asList(1, 2, 3),
Arrays.asList(61, 62, 63));
+ TestUtils.executeAndCheckOutput(env, loadedModel, inputs, output,
null, null);
}
@Test
public void testPipeline() throws Exception {
- StreamExecutionEnvironment env =
StreamExecutionEnvironment.getExecutionEnvironment();
- StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
// Builds a Pipeline that consists of a Model, an Estimator, and a
model.
SumModel modelA = new SumModel().setModelData(tEnv.fromValues(10));
SumModel modelB = new SumModel().setModelData(tEnv.fromValues(30));
List<Stage<?>> stages = Arrays.asList(modelA, new SumEstimator(),
modelB);
Estimator<?, ?> estimator = new Pipeline(stages);
+ List<List<Integer>> inputs =
Collections.singletonList(Arrays.asList(1, 2, 3));
+ List<Integer> output = Arrays.asList(77, 78, 79);
// Executes the original Pipeline and verifies that it produces the
expected output.
- executeAndCheckOutput(estimator, Arrays.asList(1, 2, 3),
Arrays.asList(77, 78, 79));
+ TestUtils.executeAndCheckOutput(env, estimator, inputs, output, null,
null);
// Saves and loads the Pipeline.
- Path tempDir = Files.createTempDirectory("PipelineTest");
- String path = Paths.get(tempDir.toString(), "testPipeline").toString();
+ String path = Files.createTempDirectory("").toString();
estimator.save(path);
- Estimator<?, ?> loadedEstimator = Pipeline.load(env, path);
+ env.execute();
+ Estimator<?, ?> loadedEstimator = Pipeline.load(env, path);
// Executes the loaded Pipeline and verifies that it produces the
expected output.
- executeAndCheckOutput(loadedEstimator, Arrays.asList(1, 2, 3),
Arrays.asList(77, 78, 79));
+ TestUtils.executeAndCheckOutput(env, loadedEstimator, inputs, output,
null, null);
}
}
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
b/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
index df0db64..0d5115e 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
@@ -42,7 +42,6 @@ import org.junit.Test;
import java.io.IOException;
import java.nio.file.Files;
-import java.nio.file.Paths;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
@@ -175,8 +174,7 @@ public class StageTest {
ReadWriteUtils.setParam(stage, param, entry.getValue());
}
- String tempDir = Files.createTempDirectory("").toString();
- String path = Paths.get(tempDir, "test").toString();
+ String path = Files.createTempDirectory("").toString();
stage.save(path);
try {
stage.save(path);
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/api/TestUtils.java
b/flink-ml-core/src/test/java/org/apache/flink/ml/api/TestUtils.java
new file mode 100644
index 0000000..512d4a0
--- /dev/null
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/TestUtils.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.TestBaseUtils;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.Comparator;
+import java.util.List;
+
+/** Utility methods for tests. */
+public class TestUtils {
+
+ // Executes the given stage using the given inputs and verifies that it
produces the expected
+ // output.
+ public static void executeAndCheckOutput(
+ StreamExecutionEnvironment env,
+ Stage<?> stage,
+ List<List<Integer>> inputs,
+ List<Integer> expectedOutput,
+ List<List<Integer>> modelDataInputs,
+ List<Integer> expectedModelDataOutput)
+ throws Exception {
+ StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+ Table[] inputTables = new Table[inputs.size()];
+ for (int i = 0; i < inputTables.length; i++) {
+ inputTables[i] =
tEnv.fromDataStream(env.fromCollection(inputs.get(i)));
+ }
+ Table outputTable = null;
+ Table modelDataOutputTable = null;
+
+ if (stage instanceof AlgoOperator) {
+ if (modelDataInputs != null) {
+ Table[] inputModelDataTables = new
Table[modelDataInputs.size()];
+ for (int i = 0; i < inputModelDataTables.length; i++) {
+ inputModelDataTables[i] =
+
tEnv.fromDataStream(env.fromCollection(modelDataInputs.get(i)));
+ }
+ ((Model<?>) stage).setModelData(inputModelDataTables);
+ }
+ outputTable = ((AlgoOperator<?>) stage).transform(inputTables)[0];
+ if (expectedModelDataOutput != null) {
+ modelDataOutputTable = ((Model<?>) stage).getModelData()[0];
+ }
+ } else {
+ Estimator<?, ?> estimator = (Estimator<?, ?>) stage;
+ Model<?> model = estimator.fit(inputTables);
+
+ if (modelDataInputs != null) {
+ Table[] inputModelDataTables = new
Table[modelDataInputs.size()];
+ for (int i = 0; i < inputModelDataTables.length; i++) {
+ inputModelDataTables[i] =
+
tEnv.fromDataStream(env.fromCollection(modelDataInputs.get(i)));
+ }
+ model.setModelData(inputModelDataTables);
+ }
+ outputTable = model.transform(inputTables)[0];
+ if (expectedModelDataOutput != null) {
+ modelDataOutputTable = model.getModelData()[0];
+ }
+ }
+
+ List<Integer> output =
+ IteratorUtils.toList(
+ tEnv.toDataStream(outputTable,
Integer.class).executeAndCollect());
+ TestBaseUtils.compareResultCollections(expectedOutput, output,
Comparator.naturalOrder());
+
+ if (expectedModelDataOutput != null) {
+ List<Integer> modelDataOutput =
+ IteratorUtils.toList(
+ tEnv.toDataStream(modelDataOutputTable,
Integer.class)
+ .executeAndCollect());
+ TestBaseUtils.compareResultCollections(
+ expectedModelDataOutput, modelDataOutput,
Comparator.naturalOrder());
+ }
+ }
+
+ /** Encoder for Integer. */
+ public static class IntEncoder implements Encoder<Integer> {
+ @Override
+ public void encode(Integer element, OutputStream stream) throws
IOException {
+ DataOutputStream dataStream = new DataOutputStream(stream);
+ dataStream.writeInt(element);
+ dataStream.flush();
+ }
+ }
+
+ /** Decoder for Integer. */
+ public static class IntegerStreamFormat extends
SimpleStreamFormat<Integer> {
+ @Override
+ public Reader<Integer> createReader(Configuration config,
FSDataInputStream stream) {
+ return new Reader<Integer>() {
+ private final DataInputStream dataStream = new
DataInputStream(stream);
+
+ @Override
+ public Integer read() throws IOException {
+ try {
+ return dataStream.readInt();
+ } catch (EOFException e) {
+ return null;
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ dataStream.close();
+ }
+ };
+ }
+
+ @Override
+ public TypeInformation<Integer> getProducedType() {
+ return BasicTypeInfo.INT_TYPE_INFO;
+ }
+ }
+}