This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new e9c10b8  [FLINK-23959][FLIP-175] Compose Estimator/Model/AlgoOperator from DAG of Estimator/Model/AlgoOperator
e9c10b8 is described below

commit e9c10b8acd4300a62c2193aef2948f933bc4c67d
Author: Dong Lin <[email protected]>
AuthorDate: Mon Oct 25 17:08:56 2021 +0800

    [FLINK-23959][FLIP-175] Compose Estimator/Model/AlgoOperator from DAG of Estimator/Model/AlgoOperator
    
    This closes #20.
---
 .../java/org/apache/flink/ml/builder/Graph.java    | 153 ++++++++
 .../org/apache/flink/ml/builder/GraphBuilder.java  | 434 +++++++++++++++++++++
 .../org/apache/flink/ml/builder/GraphData.java     | 105 +++++
 .../flink/ml/builder/GraphExecutionHelper.java     | 128 ++++++
 .../org/apache/flink/ml/builder/GraphModel.java    | 148 +++++++
 .../org/apache/flink/ml/builder/GraphNode.java     | 143 +++++++
 .../java/org/apache/flink/ml/builder/Pipeline.java |   3 +-
 .../org/apache/flink/ml/builder/PipelineModel.java |   3 +-
 .../java/org/apache/flink/ml/builder/TableId.java  |  73 ++++
 .../org/apache/flink/ml/util/ReadWriteUtils.java   |  79 ++++
 .../org/apache/flink/ml/api/ExampleStages.java     |  88 +++--
 .../java/org/apache/flink/ml/api/GraphTest.java    | 253 ++++++++++++
 .../java/org/apache/flink/ml/api/PipelineTest.java |  72 ++--
 .../java/org/apache/flink/ml/api/StageTest.java    |   4 +-
 .../java/org/apache/flink/ml/api/TestUtils.java    | 147 +++++++
 15 files changed, 1750 insertions(+), 83 deletions(-)

diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/Graph.java b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/Graph.java
new file mode 100644
index 0000000..8123e04
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/Graph.java
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.builder;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.ml.builder.GraphNode.StageType;
+
+/**
+ * A Graph acts as an Estimator. A Graph consists of a DAG of stages, each of which could be an
+ * Estimator, Model, Transformer or AlgoOperator. When `Graph::fit` is called, the stages are
+ * executed in a topologically-sorted order. If a stage is an Estimator, its `Estimator::fit` method
+ * will be called on the input tables (from the input edges) to fit a Model. Then the Model will be
+ * used to transform the input tables and produce output tables to the output edges. If a stage is
+ * an AlgoOperator, its `AlgoOperator::transform` method will be called on the input tables and
+ * produce output tables to the output edges. The GraphModel fitted from a Graph consists of the
+ * fitted Models and AlgoOperators, corresponding to the Graph's stages.
+ */
+@PublicEvolving
+public final class Graph implements Estimator<Graph, GraphModel> {
+    private static final long serialVersionUID = 6354253958813529308L;
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private final List<GraphNode> nodes;
+    private final TableId[] estimatorInputIds;
+    private final TableId[] modelInputIds;
+    private final TableId[] outputIds;
+    private final @Nullable TableId[] inputModelDataIds;
+    private final @Nullable TableId[] outputModelDataIds;
+
+    public Graph(
+            List<GraphNode> nodes,
+            TableId[] estimatorInputIds,
+            TableId[] modelInputs,
+            TableId[] outputs,
+            TableId[] inputModelDataIds,
+            TableId[] outputModelDataIds) {
+        this.nodes = Preconditions.checkNotNull(nodes);
+        this.estimatorInputIds = Preconditions.checkNotNull(estimatorInputIds);
+        this.modelInputIds = Preconditions.checkNotNull(modelInputs);
+        this.outputIds = Preconditions.checkNotNull(outputs);
+        this.inputModelDataIds = inputModelDataIds;
+        this.outputModelDataIds = outputModelDataIds;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public GraphModel fit(Table... inputTables) {
+        Preconditions.checkArgument(
+                estimatorInputIds.length == inputTables.length,
+                "number of provided tables %s does not match the expected 
number of tables %s",
+                inputTables.length,
+                estimatorInputIds.length);
+        List<GraphNode> modelNodes = new ArrayList<>();
+        GraphExecutionHelper executionHelper = new GraphExecutionHelper(nodes);
+        // Maps estimatorInputIds to inputTables.
+        executionHelper.setTables(estimatorInputIds, inputTables);
+        // Iterates until we have executed all ready nodes.
+        GraphNode node;
+        while ((node = executionHelper.pollNextReadyNode()) != null) {
+            Stage<?> stage = node.stage;
+            // Invokes fit(...) if stageType == ESTIMATOR.
+            if (node.stageType == StageType.ESTIMATOR) {
+                stage =
+                        ((Estimator<?, ?>) stage)
+                                
.fit(executionHelper.getTables(node.estimatorInputIds));
+            }
+            // Invokes setModelData(...).
+            if (node.inputModelDataIds != null) {
+                Table[] nodeInputModelData = 
executionHelper.getTables(node.inputModelDataIds);
+                ((Model<?>) stage).setModelData(nodeInputModelData);
+            }
+            // Invokes transform(...).
+            Table[] nodeOutputs =
+                    ((AlgoOperator<?>) stage)
+                            
.transform(executionHelper.getTables(node.algoOpInputIds));
+            executionHelper.setTables(node.outputIds, nodeOutputs);
+            // Invokes getModelData().
+            if (node.outputModelDataIds != null) {
+                Table[] nodeOutputModelData = ((Model<?>) 
stage).getModelData();
+                executionHelper.setTables(node.outputModelDataIds, 
nodeOutputModelData);
+            }
+
+            modelNodes.add(
+                    new GraphNode(
+                            node.nodeId,
+                            stage,
+                            StageType.ALGO_OPERATOR,
+                            null,
+                            node.algoOpInputIds,
+                            node.outputIds,
+                            node.inputModelDataIds,
+                            node.outputModelDataIds));
+        }
+        return new GraphModel(
+                modelNodes, modelInputIds, outputIds, inputModelDataIds, 
outputModelDataIds);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        GraphData graphData =
+                new GraphData(
+                        nodes,
+                        estimatorInputIds,
+                        modelInputIds,
+                        outputIds,
+                        inputModelDataIds,
+                        outputModelDataIds);
+        ReadWriteUtils.saveGraph(this, graphData, path);
+    }
+
+    public static Graph load(StreamExecutionEnvironment env, String path) 
throws IOException {
+        return (Graph) ReadWriteUtils.loadGraph(env, path, 
Graph.class.getName());
+    }
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphBuilder.java b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphBuilder.java
new file mode 100644
index 0000000..c3dd657
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphBuilder.java
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.builder;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.api.Stage;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.ml.builder.GraphNode.StageType;
+
+/**
+ * A GraphBuilder provides APIs to build Estimator/Model/AlgoOperator from a DAG of stages, each of
+ * which could be an Estimator, Model, Transformer or AlgoOperator.
+ */
+@PublicEvolving
+public final class GraphBuilder {
+
+    private int maxOutputLength = 20;
+
+    private int nextTableId = 0;
+
+    private int nextNodeId = 0;
+
+    /** An ordered list of nodes in the graph. */
+    private final List<GraphNode> nodes = new ArrayList<>();
+    /** A map from stage instance to the corresponding node in the graph. */
+    private final Map<Stage<?>, GraphNode> existingNodes = new HashMap<>();
+
+    public GraphBuilder() {}
+
+    /**
+     * Specifies the loose upper bound of the number of output tables that can be returned by the
+     * Model::getModelData() and AlgoOperator::transform() methods, for any stage involved in this
+     * Graph.
+     *
+     * <p>The default upper bound is 20.
+     */
+    public GraphBuilder setMaxOutputTableNum(int maxOutputLength) {
+        this.maxOutputLength = maxOutputLength;
+        return this;
+    }
+
+    /**
+     * Creates a TableId associated with this GraphBuilder. It can be used to specify the passing of
+     * tables between stages, as well as the input/output tables of the Graph/GraphModel generated
+     * by this builder.
+     *
+     * @return A TableId.
+     */
+    public TableId createTableId() {
+        return new TableId(nextTableId++);
+    }
+
+    /**
+     * Adds an AlgoOperator in the graph.
+     *
+     * <p>When the graph runs as Estimator, the transform() of the given AlgoOperator would be
+     * invoked with the given inputs. Then when the GraphModel fitted by this graph runs, the
+     * transform() of the given AlgoOperator would be invoked with the given inputs.
+     *
+     * <p>When the graph runs as AlgoOperator or Model, the transform() of the given AlgoOperator
+     * would be invoked with the given inputs.
+     *
+     * <p>NOTE: the number of the returned TableIds does not represent the actual number of Tables
+     * outputted by transform(). This number could be configured using {@link
+     * #setMaxOutputTableNum(int)}. Users should make sure that this number >= the actual number of
+     * Tables outputted by transform().
+     *
+     * @param algoOp An AlgoOperator instance.
+     * @param inputs A list of TableIds which represents inputs to transform() of the given
+     *     AlgoOperator.
+     * @return A list of TableIds which represents the outputs of transform() of the given
+     *     AlgoOperator.
+     */
+    public TableId[] addAlgoOperator(AlgoOperator<?> algoOp, TableId... 
inputs) {
+        return addStage(algoOp, StageType.ALGO_OPERATOR, null, inputs);
+    }
+
+    /**
+     * Adds an Estimator in the graph.
+     *
+     * <p>When the graph runs as Estimator, the fit() of the given Estimator would be invoked with
+     * the given inputs. Then when the GraphModel fitted by this graph runs, the transform() of the
+     * Model fitted by the given Estimator would be invoked with the given inputs.
+     *
+     * <p>When the graph runs as AlgoOperator or Model, the fit() of the given Estimator would be
+     * invoked with the given inputs, then the transform() of the Model fitted by the given
+     * Estimator would be invoked with the given inputs.
+     *
+     * <p>NOTE: the number of the returned TableIds does not represent the actual number of Tables
+     * outputted by transform(). This number could be configured using {@link
+     * #setMaxOutputTableNum(int)}. Users should make sure that this number >= the actual number of
+     * Tables outputted by transform().
+     *
+     * @param estimator An Estimator instance.
+     * @param inputs A list of TableIds which represents inputs to fit() of the given Estimator as
+     *     well as inputs to transform() of the Model fitted by the given Estimator.
+     * @return A list of TableIds which represents the outputs of transform() of the Model fitted by
+     *     the given Estimator.
+     */
+    public TableId[] addEstimator(Estimator<?, ?> estimator, TableId... 
inputs) {
+        return addEstimator(estimator, inputs, inputs);
+    }
+
+    /**
+     * Adds an Estimator in the graph.
+     *
+     * <p>When the graph runs as Estimator, the fit() of the given Estimator would be invoked with
+     * estimatorInputs. Then when the GraphModel fitted by this graph runs, the transform() of the
+     * Model fitted by the given Estimator would be invoked with modelInputs.
+     *
+     * <p>When the graph runs as AlgoOperator or Model, the fit() of the given Estimator would be
+     * invoked with estimatorInputs, then the transform() of the Model fitted by the given Estimator
+     * would be invoked with modelInputs.
+     *
+     * <p>NOTE: the number of the returned TableIds does not represent the actual number of Tables
+     * outputted by transform(). This number could be configured using {@link
+     * #setMaxOutputTableNum(int)}. Users should make sure that this number >= the actual number of
+     * Tables outputted by transform().
+     *
+     * @param estimator An Estimator instance.
+     * @param estimatorInputs A list of TableIds which represents inputs to fit() of the given
+     *     Estimator.
+     * @param modelInputs A list of TableIds which represents inputs to transform() of the Model
+     *     fitted by the given Estimator.
+     * @return A list of TableIds which represents the outputs of transform() of the Model fitted by
+     *     the given Estimator.
+     */
+    public TableId[] addEstimator(
+            Estimator<?, ?> estimator, TableId[] estimatorInputs, TableId[] 
modelInputs) {
+        return addStage(estimator, StageType.ESTIMATOR, estimatorInputs, 
modelInputs);
+    }
+
+    /**
+     * When the graph runs as Estimator, it first generates a GraphModel that contains the Model
+     * fitted by the given Estimator. Then when this GraphModel runs, the setModelData() of the
+     * fitted Model would be invoked with the given inputs before its transform() is invoked.
+     *
+     * <p>When the graph runs as AlgoOperator or Model, the setModelData() of the Model fitted by
+     * the given Estimator would be invoked with the given inputs before its transform() is invoked.
+     *
+     * @param estimator An Estimator instance.
+     * @param inputs A list of TableIds which represents inputs to setModelData() of the Model
+     *     fitted by the given Estimator.
+     */
+    public void setModelDataOnEstimator(Estimator<?, ?> estimator, TableId... 
inputs) {
+        GraphNode node = existingNodes.get(estimator);
+        if (node == null) {
+            throw new RuntimeException("the Estimator has not been added to 
the graph");
+        }
+        if (node.stageType != StageType.ESTIMATOR) {
+            throw new RuntimeException("the Estimator was previously added as 
an AlgoOperator");
+        }
+        if (node.inputModelDataIds != null) {
+            throw new RuntimeException("the model data of this Estimator has 
already been set");
+        }
+        node.inputModelDataIds = inputs;
+    }
+
+    /**
+     * When the graph runs as Estimator, the setModelData() of the given Model would be invoked with
+     * the given inputs before its transform() is invoked. Then when the GraphModel fitted by this
+     * graph runs, the setModelData() of the given Model would be invoked with the given inputs.
+     *
+     * <p>When the graph runs as AlgoOperator or Model, the setModelData() of the given Model would
+     * be invoked with the given inputs before its transform() is invoked.
+     *
+     * @param model A Model instance.
+     * @param inputs A list of TableIds which represents inputs to setModelData() of the given
+     *     Model.
+     */
+    public void setModelDataOnModel(Model<?> model, TableId... inputs) {
+        GraphNode node = existingNodes.get(model);
+        if (node == null) {
+            throw new RuntimeException("the Model has not been added to the 
graph");
+        }
+        if (node.stageType != StageType.ALGO_OPERATOR) {
+            throw new RuntimeException("the Model was previously added as an 
Estimator");
+        }
+        if (node.inputModelDataIds != null) {
+            throw new RuntimeException("the model data of this Model has 
already been set");
+        }
+        node.inputModelDataIds = inputs;
+    }
+
+    /**
+     * When the graph runs as Estimator, it first generates a GraphModel that contains the Model
+     * fitted by the given Estimator. Then when this GraphModel runs, the getModelData() of the
+     * fitted Model would be invoked.
+     *
+     * <p>When the graph runs as AlgoOperator or Model, the getModelData() of the Model fitted by
+     * the given Estimator would be invoked.
+     *
+     * <p>NOTE: the number of the returned TableIds does not represent the actual number of Tables
+     * outputted by getModelData(). This number could be configured using {@link
+     * #setMaxOutputTableNum(int)}. Users should make sure that this number >= the actual number of
+     * Tables outputted by getModelData().
+     *
+     * @param estimator An Estimator instance.
+     * @return A list of TableIds which represents the outputs of getModelData() of the Model fitted
+     *     by the given Estimator.
+     */
+    public TableId[] getModelDataFromEstimator(Estimator<?, ?> estimator) {
+        GraphNode node = existingNodes.get(estimator);
+        if (node == null) {
+            throw new RuntimeException("the Estimator has not been added to 
the graph");
+        }
+        if (node.stageType != StageType.ESTIMATOR) {
+            throw new RuntimeException("the Estimator was previously added as 
an AlgoOperator");
+        }
+        if (node.outputModelDataIds != null) {
+            throw new RuntimeException("the model data of this Estimator has 
already been fetched");
+        }
+        node.outputModelDataIds = createTableIds(maxOutputLength);
+        return node.outputModelDataIds;
+    }
+
+    /**
+     * When the graph runs as Estimator, the getModelData() of the given Model would be invoked.
+     * Then when the GraphModel fitted by this graph runs, the getModelData() of the given Model
+     * would be invoked.
+     *
+     * <p>When the graph runs as AlgoOperator or Model, the getModelData() of the given Model would
+     * be invoked.
+     *
+     * <p>NOTE: the number of the returned TableIds does not represent the actual number of Tables
+     * outputted by getModelData(). This number could be configured using {@link
+     * #setMaxOutputTableNum(int)}. Users should make sure that this number >= the actual number of
+     * Tables outputted by getModelData().
+     *
+     * @param model A Model instance.
+     * @return A list of TableIds which represents the outputs of getModelData() of the given Model.
+     */
+    public TableId[] getModelDataFromModel(Model<?> model) {
+        GraphNode node = existingNodes.get(model);
+        if (node == null) {
+            throw new RuntimeException("the Model has not been added to the 
graph");
+        }
+        if (node.stageType != StageType.ALGO_OPERATOR) {
+            throw new RuntimeException("the Model was previously added as an 
Estimator");
+        }
+        if (node.outputModelDataIds != null) {
+            throw new RuntimeException("the model data of this Model has 
already been fetched");
+        }
+        node.outputModelDataIds = createTableIds(maxOutputLength);
+        return node.outputModelDataIds;
+    }
+
+    /**
+     * Wraps nodes of the graph into an Estimator.
+     *
+     * <p>When the returned Estimator runs, and when the Model fitted by the returned Estimator
+     * runs, the sequence of operations recorded by the {@code addAlgoOperator(...)}, {@code
+     * addEstimator(...)}, {@code setModelData(...)} and {@code getModelData(...)} would be executed
+     * as specified in the Java doc of the corresponding methods.
+     *
+     * @param inputs A list of TableIds which represents inputs to fit() of the returned Estimator
+     *     as well as inputs to transform() of the Model fitted by the returned Estimator.
+     * @param outputs A list of TableIds which represents outputs of transform() of the Model fitted
+     *     by the returned Estimator.
+     * @return An Estimator which wraps the nodes of this graph.
+     */
+    public Estimator<?, ?> buildEstimator(TableId[] inputs, TableId[] outputs) 
{
+        return buildEstimator(inputs, inputs, outputs, null, null);
+    }
+
+    /**
+     * Wraps nodes of the graph into an Estimator.
+     *
+     * <p>When the returned Estimator runs, and when the Model fitted by the returned Estimator
+     * runs, the sequence of operations recorded by the {@code addAlgoOperator(...)}, {@code
+     * addEstimator(...)}, {@code setModelData(...)} and {@code getModelData(...)} would be executed
+     * as specified in the Java doc of the corresponding methods.
+     *
+     * @param inputs A list of TableIds which represents inputs to fit() of the returned Estimator
+     *     as well as inputs to transform() of the Model fitted by the returned Estimator.
+     * @param outputs A list of TableIds which represents outputs of transform() of the Model fitted
+     *     by the returned Estimator.
+     * @param inputModelData A list of TableIds which represents inputs to setModelData() of the
+     *     Model fitted by the returned Estimator.
+     * @param outputModelData A list of TableIds which represents outputs of getModelData() of the
+     *     Model fitted by the returned Estimator.
+     * @return An Estimator which wraps the nodes of this graph.
+     */
+    public Estimator<?, ?> buildEstimator(
+            TableId[] inputs,
+            TableId[] outputs,
+            TableId[] inputModelData,
+            TableId[] outputModelData) {
+        return buildEstimator(inputs, inputs, outputs, inputModelData, 
outputModelData);
+    }
+
+    /**
+     * Wraps nodes of the graph into an Estimator.
+     *
+     * <p>When the returned Estimator runs, and when the Model fitted by the returned Estimator
+     * runs, the sequence of operations recorded by the {@code addAlgoOperator(...)}, {@code
+     * addEstimator(...)}, {@code setModelData(...)} and {@code getModelData(...)} would be executed
+     * as specified in the Java doc of the corresponding methods.
+     *
+     * @param estimatorInputs A list of TableIds which represents inputs to fit() of the returned
+     *     Estimator.
+     * @param modelInputs A list of TableIds which represents inputs to transform() of the Model
+     *     fitted by the returned Estimator.
+     * @param outputs A list of TableIds which represents outputs of transform() of the Model fitted
+     *     by the returned Estimator.
+     * @param inputModelData A list of TableIds which represents inputs to setModelData() of the
+     *     Model fitted by the returned Estimator.
+     * @param outputModelData A list of TableIds which represents outputs of getModelData() of the
+     *     Model fitted by the returned Estimator.
+     * @return An Estimator which wraps the nodes of this graph.
+     */
+    public Estimator<?, ?> buildEstimator(
+            TableId[] estimatorInputs,
+            TableId[] modelInputs,
+            TableId[] outputs,
+            TableId[] inputModelData,
+            TableId[] outputModelData) {
+        return new Graph(
+                nodes, estimatorInputs, modelInputs, outputs, inputModelData, 
outputModelData);
+    }
+
+    /**
+     * Wraps nodes of the graph into an AlgoOperator.
+     *
+     * <p>When the returned AlgoOperator runs, the sequence of operations recorded by the {@code
+     * addAlgoOperator(...)} and {@code addEstimator(...)} would be executed as specified in the
+     * Java doc of the corresponding methods.
+     *
+     * @param inputs A list of TableIds which represents inputs to transform() of the returned
+     *     AlgoOperator.
+     * @param outputs A list of TableIds which represents outputs of transform() of the returned
+     *     AlgoOperator.
+     * @return An AlgoOperator which wraps the nodes of this graph.
+     */
+    public AlgoOperator<?> buildAlgoOperator(TableId[] inputs, TableId[] 
outputs) {
+        return buildModel(inputs, outputs, null, null);
+    }
+
+    /**
+     * Wraps nodes of the graph into a Model.
+     *
+     * <p>When the returned Model runs, the sequence of operations recorded by the {@code
+     * addAlgoOperator(...)} and {@code addEstimator(...)} would be executed as specified in the
+     * Java doc of the corresponding methods.
+     *
+     * @param inputs A list of TableIds which represents inputs to transform() of the returned
+     *     Model.
+     * @param outputs A list of TableIds which represents outputs of transform() of the returned
+     *     Model.
+     * @return A Model which wraps the nodes of this graph.
+     */
+    public Model<?> buildModel(TableId[] inputs, TableId[] outputs) {
+        return buildModel(inputs, outputs, null, null);
+    }
+
+    /**
+     * Wraps nodes of the graph into a Model.
+     *
+     * <p>When the returned Model runs, the sequence of operations recorded by the {@code
+     * addAlgoOperator(...)}, {@code addEstimator(...)}, {@code setModelData(...)} and {@code
+     * getModelData(...)} would be executed as specified in the Java doc of the corresponding
+     * methods.
+     *
+     * @param inputs A list of TableIds which represents inputs to transform() of the returned
+     *     Model.
+     * @param outputs A list of TableIds which represents outputs of transform() of the returned
+     *     Model.
+     * @param inputModelData A list of TableIds which represents inputs to setModelData() of the
+     *     returned Model.
+     * @param outputModelData A list of TableIds which represents outputs of getModelData() of the
+     *     returned Model.
+     * @return A Model which wraps the nodes of this graph.
+     */
+    public Model<?> buildModel(
+            TableId[] inputs,
+            TableId[] outputs,
+            TableId[] inputModelData,
+            TableId[] outputModelData) {
+        return new GraphModel(nodes, inputs, outputs, inputModelData, 
outputModelData);
+    }
+
+    private TableId[] createTableIds(int count) {
+        TableId[] result = new TableId[count];
+        for (int i = 0; i < count; i++) {
+            result[i] = createTableId();
+        }
+        return result;
+    }
+
+    private TableId[] addStage(
+            Stage<?> stage, StageType stageType, TableId[] estimatorInputs, 
TableId[] modelInputs) {
+        TableId[] outputs = createTableIds(maxOutputLength);
+        if (existingNodes.containsKey(stage)) {
+            throw new RuntimeException("The stage " + stage + " has already 
been added.");
+        }
+        GraphNode node =
+                new GraphNode(
+                        nextNodeId++,
+                        stage,
+                        stageType,
+                        estimatorInputs,
+                        modelInputs,
+                        outputs,
+                        null,
+                        null);
+        nodes.add(node);
+        existingNodes.put(stage, node);
+        return outputs;
+    }
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphData.java b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphData.java
new file mode 100644
index 0000000..b688a93
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphData.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.builder;
+
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nullable;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** This class contains fields that can be used to re-construct Graph and 
GraphModel. */
+public class GraphData {
+    public final List<GraphNode> nodes;
+    public final @Nullable TableId[] estimatorInputIds;
+    public final TableId[] modelInputIds;
+    public final TableId[] outputIds;
+    public final @Nullable TableId[] inputModelDataIds;
+    public final @Nullable TableId[] outputModelDataIds;
+
+    public GraphData(
+            List<GraphNode> nodes,
+            TableId[] estimatorInputIds,
+            TableId[] modelInputIds,
+            TableId[] outputIds,
+            TableId[] inputModelDataIds,
+            TableId[] outputModelDataIds) {
+        this.nodes = Preconditions.checkNotNull(nodes);
+        this.estimatorInputIds = estimatorInputIds;
+        this.modelInputIds = Preconditions.checkNotNull(modelInputIds);
+        this.outputIds = Preconditions.checkNotNull(outputIds);
+        this.inputModelDataIds = inputModelDataIds;
+        this.outputModelDataIds = outputModelDataIds;
+    }
+
+    public Map<String, Object> toMap() {
+        Map<String, Object> result = new HashMap<>();
+
+        List<Map<String, Object>> nodeInfos = new ArrayList<>();
+        for (GraphNode node : nodes) {
+            nodeInfos.add(node.toMap());
+        }
+        result.put("nodes", nodeInfos);
+        if (estimatorInputIds != null) {
+            result.put("estimatorInputIds", TableId.toList(estimatorInputIds));
+        }
+        result.put("modelInputIds", TableId.toList(modelInputIds));
+        result.put("outputIds", TableId.toList(outputIds));
+        if (inputModelDataIds != null) {
+            result.put("inputModelDataIds", TableId.toList(inputModelDataIds));
+        }
+        if (outputModelDataIds != null) {
+            result.put("outputModelDataIds", 
TableId.toList(outputModelDataIds));
+        }
+        return result;
+    }
+
+    public static GraphData fromMap(Map<String, Object> map) {
+        List<GraphNode> nodes = new ArrayList<>();
+        List<Map<String, Object>> nodeInfos = (List<Map<String, Object>>) 
map.get("nodes");
+        for (Map<String, Object> nodeInfo : nodeInfos) {
+            nodes.add(GraphNode.fromMap(nodeInfo));
+        }
+
+        TableId[] estimatorInputIds = null;
+        if (map.containsKey("estimatorInputIds")) {
+            estimatorInputIds = TableId.fromList((List<Integer>) 
map.get("estimatorInputIds"));
+        }
+        TableId[] modelInputIds = TableId.fromList((List<Integer>) 
map.get("modelInputIds"));
+        TableId[] outputIds = TableId.fromList((List<Integer>) 
map.get("outputIds"));
+        TableId[] inputModelDataIds = null;
+        if (map.containsKey("inputModelDataIds")) {
+            inputModelDataIds = TableId.fromList((List<Integer>) 
map.get("inputModelDataIds"));
+        }
+        TableId[] outputModelDataIds = null;
+        if (map.containsKey("outputModelDataIds")) {
+            outputModelDataIds = TableId.fromList((List<Integer>) 
map.get("outputModelDataIds"));
+        }
+        return new GraphData(
+                nodes,
+                estimatorInputIds,
+                modelInputIds,
+                outputIds,
+                inputModelDataIds,
+                outputModelDataIds);
+    }
+}
diff --git 
a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphExecutionHelper.java
 
b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphExecutionHelper.java
new file mode 100644
index 0000000..09973f4
--- /dev/null
+++ 
b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphExecutionHelper.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.builder;
+
+import org.apache.flink.table.api.Table;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * A container class that maintains the execution state of the graph (e.g. 
which nodes are ready to
+ * run).
+ */
+class GraphExecutionHelper {
+    /** A map from tableId to the list of nodes which take this table as 
input. */
+    private final Map<TableId, List<GraphNode>> consumerNodes = new 
HashMap<>();
+    /**
+     * A map from tableId to the corresponding table. A TableId would be 
mapped iff its
+     * corresponding Table has been constructed.
+     */
+    private final Map<TableId, Table> constructedTables = new HashMap<>();
+    /**
+     * A map that maintains the number of input tables that have not been 
constructed for each node
+     * in the graph.
+     */
+    private final Map<GraphNode, Integer> numUnConstructedInputTables = new 
HashMap<>();
+    /**
+     * An ordered list of nodes whose input tables have all been constructed AND which have
+     * not yet been fetched via pollNextReadyNode.
+     */
+    private final Deque<GraphNode> unFetchedReadyNodes = new LinkedList<>();
+
+    /**
+     * Creates a helper for the given nodes and indexes which node consumes which table.
+     *
+     * @param nodes The nodes of the graph to be executed.
+     */
+    public GraphExecutionHelper(List<GraphNode> nodes) {
+        // Initializes consumerNodes and numUnConstructedInputTables.
+        for (GraphNode node : nodes) {
+            List<TableId> inputs = new ArrayList<>(Arrays.asList(node.algoOpInputIds));
+            if (node.estimatorInputIds != null) {
+                inputs.addAll(Arrays.asList(node.estimatorInputIds));
+            }
+            if (node.inputModelDataIds != null) {
+                inputs.addAll(Arrays.asList(node.inputModelDataIds));
+            }
+            // A table id that occurs several times in the inputs is counted once per occurrence
+            // and the node is registered as its consumer once per occurrence, so the counter
+            // still reaches zero when that table is set.
+            // NOTE(review): a node without any input is stored with a count of 0 and is never
+            // moved to unFetchedReadyNodes by setTable -- confirm such nodes cannot occur.
+            numUnConstructedInputTables.put(node, inputs.size());
+            for (TableId tableId : inputs) {
+                consumerNodes.computeIfAbsent(tableId, key -> new ArrayList<>()).add(node);
+            }
+        }
+    }
+
+    public void setTables(TableId[] tableIds, Table[] tables) {
+        // The length of tableIds could be larger than the length of tables 
because we over-allocate
+        // the number of tableIds (which is 20 by default) as placeholder of 
the stage's output
+        // tables when adding a stage in GraphBuilder.
+        Preconditions.checkArgument(
+                tableIds.length >= tables.length,
+                "the length of tablesIds %s is less than the length of tables 
%s",
+                tableIds.length,
+                tables.length);
+        for (int i = 0; i < tables.length; i++) {
+            setTable(tableIds[i], tables[i]);
+        }
+    }
+
+    private void setTable(TableId tableId, Table table) {
+        Preconditions.checkArgument(
+                !constructedTables.containsKey(tableId),
+                "the table with id=%s has already been constructed",
+                tableId.toString());
+        constructedTables.put(tableId, table);
+
+        for (GraphNode node : consumerNodes.getOrDefault(tableId, new 
ArrayList<>())) {
+            int prevNum = numUnConstructedInputTables.get(node);
+            if (prevNum == 1) {
+                unFetchedReadyNodes.addLast(node);
+                numUnConstructedInputTables.remove(node);
+            } else {
+                numUnConstructedInputTables.put(node, prevNum - 1);
+            }
+        }
+    }
+
+    public Table[] getTables(TableId[] tableIds) {
+        Table[] tables = new Table[tableIds.length];
+        for (int i = 0; i < tableIds.length; i++) {
+            tables[i] = getTable(tableIds[i]);
+        }
+        return tables;
+    }
+
+    private Table getTable(TableId tableId) {
+        Preconditions.checkArgument(
+                constructedTables.containsKey(tableId),
+                "the table with id=%s has not been constructed yet",
+                tableId.toString());
+        return constructedTables.get(tableId);
+    }
+
+    public GraphNode pollNextReadyNode() {
+        if (unFetchedReadyNodes.isEmpty() && 
!numUnConstructedInputTables.isEmpty()) {
+            throw new RuntimeException("there exists node whose input can not 
be constructed");
+        }
+        return unFetchedReadyNodes.pollFirst();
+    }
+}
diff --git 
a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphModel.java 
b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphModel.java
new file mode 100644
index 0000000..894098f
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphModel.java
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.builder;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.ml.builder.GraphNode.StageType;
+
+/**
+ * A GraphModel acts as a Model. A GraphModel consists of a DAG of stages, 
each of which could be an
+ * Estimator, Model, Transformer or AlgoOperators. When 
`GraphModel::transform` is called, the
+ * stages are executed in a topologically-sorted order. When a stage is 
executed, its
+ * `AlgoOperator::transform` method will be called on the input tables (from 
the input edges) and
+ * produce output tables to the output edges.
+ */
+@PublicEvolving
+public final class GraphModel implements Model<GraphModel> {
+    private static final long serialVersionUID = 6354856913812529398L;
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private final List<GraphNode> nodes;
+    private final TableId[] inputIds;
+    private final TableId[] outputIds;
+    private final @Nullable TableId[] inputModelDataIds;
+    private final @Nullable TableId[] outputModelDataIds;
+    private final GraphExecutionHelper executionHelper;
+
+    public GraphModel(
+            List<GraphNode> nodes,
+            TableId[] inputIds,
+            TableId[] outputIds,
+            TableId[] inputModelDataIds,
+            TableId[] outputModelDataIds) {
+        this.nodes = Preconditions.checkNotNull(nodes);
+        this.inputIds = Preconditions.checkNotNull(inputIds);
+        this.outputIds = Preconditions.checkNotNull(outputIds);
+        this.inputModelDataIds = inputModelDataIds;
+        this.outputModelDataIds = outputModelDataIds;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        executionHelper = new GraphExecutionHelper(nodes);
+    }
+
+    @Override
+    public Table[] transform(Table... inputTables) {
+        Preconditions.checkArgument(
+                inputIds.length == inputTables.length,
+                "number of provided tables %s does not match the expected 
number of tables %s",
+                inputTables.length,
+                inputIds.length);
+        // Maps inputIds to inputTables.
+        executionHelper.setTables(inputIds, inputTables);
+        // Iterates until we have executed all ready nodes.
+        GraphNode node;
+        while ((node = executionHelper.pollNextReadyNode()) != null) {
+            Stage<?> stage = node.stage;
+            // Invokes fit(...) if stageType == ESTIMATOR.
+            if (node.stageType == StageType.ESTIMATOR) {
+                stage =
+                        ((Estimator<?, ?>) stage)
+                                
.fit(executionHelper.getTables(node.estimatorInputIds));
+            }
+            // Invokes setModelData(...).
+            if (node.inputModelDataIds != null) {
+                Table[] nodeInputModelData = 
executionHelper.getTables(node.inputModelDataIds);
+                ((Model<?>) stage).setModelData(nodeInputModelData);
+            }
+            // Invokes transform(...).
+            Table[] nodeOutputs =
+                    ((AlgoOperator<?>) stage)
+                            
.transform(executionHelper.getTables(node.algoOpInputIds));
+            executionHelper.setTables(node.outputIds, nodeOutputs);
+            // Invokes getModelData().
+            if (node.outputModelDataIds != null) {
+                Table[] nodeOutputModelData = ((Model<?>) 
stage).getModelData();
+                executionHelper.setTables(node.outputModelDataIds, 
nodeOutputModelData);
+            }
+        }
+        return executionHelper.getTables(outputIds);
+    }
+
+    /**
+     * Binds the given tables to the model data input ids of this GraphModel.
+     *
+     * @param inputTables The model data tables. Their number must equal the number of
+     *     inputModelDataIds.
+     * @return This GraphModel.
+     * @throws IllegalArgumentException if this GraphModel has no inputModelDataIds or the number
+     *     of tables does not match.
+     */
+    @Override
+    public GraphModel setModelData(Table... inputTables) {
+        Preconditions.checkArgument(inputModelDataIds != null, "setModelData() is not supported");
+        // The expected count is the number of model data ids, not the number of transform() inputs.
+        Preconditions.checkArgument(
+                inputModelDataIds.length == inputTables.length,
+                "number of provided tables %s does not match the expected number of tables %s",
+                inputTables.length,
+                inputModelDataIds.length);
+        // Maps inputModelDataIds to inputTables.
+        executionHelper.setTables(inputModelDataIds, inputTables);
+        return this;
+    }
+
+    /**
+     * Returns the tables bound to the model data output ids of this GraphModel.
+     *
+     * @return The model data tables, in the order of outputModelDataIds.
+     * @throws IllegalArgumentException if this GraphModel has no outputModelDataIds or one of the
+     *     tables has not been constructed yet.
+     */
+    @Override
+    public Table[] getModelData() {
+        Preconditions.checkArgument(outputModelDataIds != null, "getModelData() is not supported");
+        return executionHelper.getTables(outputModelDataIds);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        GraphData graphData =
+                new GraphData(
+                        nodes, null, inputIds, outputIds, inputModelDataIds, 
outputModelDataIds);
+        ReadWriteUtils.saveGraph(this, graphData, path);
+    }
+
+    public static GraphModel load(StreamExecutionEnvironment env, String path) 
throws IOException {
+        return (GraphModel) ReadWriteUtils.loadGraph(env, path, 
GraphModel.class.getName());
+    }
+}
diff --git 
a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphNode.java 
b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphNode.java
new file mode 100644
index 0000000..e9d680e
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphNode.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.builder;
+
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nullable;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+/** The Graph node class. */
+public class GraphNode {
+    /** This class specifies whether a node should be used as Estimator or 
AlgoOperator. */
+    public enum StageType {
+        ESTIMATOR,
+        ALGO_OPERATOR;
+    }
+
+    public final int nodeId;
+    public @Nullable Stage<?> stage;
+    public final StageType stageType;
+    public final @Nullable TableId[] estimatorInputIds;
+    public final TableId[] algoOpInputIds;
+    public final TableId[] outputIds;
+    public @Nullable TableId[] inputModelDataIds;
+    public @Nullable TableId[] outputModelDataIds;
+
+    /**
+     * Creates a node of the graph.
+     *
+     * @param nodeId The id of this node. Equality and hash code of a node depend only on this id.
+     * @param stage The stage wrapped by this node. May be null, e.g. for a node created by
+     *     fromMap whose stage is assigned afterwards.
+     * @param stageType Whether the node should be used as Estimator or AlgoOperator.
+     * @param estimatorInputIds The estimator input ids. May be null.
+     * @param algoOpInputIds The algo operator input ids.
+     * @param outputIds The output ids.
+     * @param inputModelDataIds The model data input ids. May be null.
+     * @param outputModelDataIds The model data output ids. May be null.
+     */
+    public GraphNode(
+            int nodeId,
+            Stage<?> stage,
+            StageType stageType,
+            TableId[] estimatorInputIds,
+            TableId[] algoOpInputIds,
+            TableId[] outputIds,
+            TableId[] inputModelDataIds,
+            TableId[] outputModelDataIds) {
+        // nodeId is a primitive and can not be null, so no null check is needed.
+        this.nodeId = nodeId;
+        this.stage = stage;
+        this.stageType = Preconditions.checkNotNull(stageType);
+        this.estimatorInputIds = estimatorInputIds;
+        this.algoOpInputIds = Preconditions.checkNotNull(algoOpInputIds);
+        this.outputIds = Preconditions.checkNotNull(outputIds);
+        this.inputModelDataIds = inputModelDataIds;
+        this.outputModelDataIds = outputModelDataIds;
+    }
+
+    public Map<String, Object> toMap() {
+        Map<String, Object> result = new HashMap<>();
+        result.put("nodeId", nodeId);
+        result.put("stageType", stageType.name());
+        if (estimatorInputIds != null) {
+            result.put("estimatorInputIds", TableId.toList(estimatorInputIds));
+        }
+        result.put("algoOpInputIds", TableId.toList(algoOpInputIds));
+        result.put("outputIds", TableId.toList(outputIds));
+        if (inputModelDataIds != null) {
+            result.put("inputModelDataIds", TableId.toList(inputModelDataIds));
+        }
+        if (outputModelDataIds != null) {
+            result.put("outputModelDataIds", 
TableId.toList(outputModelDataIds));
+        }
+        return result;
+    }
+
+    public static GraphNode fromMap(Map<String, Object> map) {
+        int nodeId = (Integer) map.get("nodeId");
+        StageType stageType = StageType.valueOf((String) map.get("stageType"));
+        TableId[] estimatorInputIds = null;
+        if (map.containsKey("estimatorInputIds")) {
+            estimatorInputIds = TableId.fromList((List<Integer>) 
map.get("estimatorInputIds"));
+        }
+        TableId[] algoOpInputIds = TableId.fromList((List<Integer>) 
map.get("algoOpInputIds"));
+        TableId[] outputIds = TableId.fromList((List<Integer>) 
map.get("outputIds"));
+        TableId[] inputModelDataIds = null;
+        if (map.containsKey("inputModelDataIds")) {
+            inputModelDataIds = TableId.fromList((List<Integer>) 
map.get("inputModelDataIds"));
+        }
+        TableId[] outputModelDataIds = null;
+        if (map.containsKey("outputModelDataIds")) {
+            outputModelDataIds = TableId.fromList((List<Integer>) 
map.get("outputModelDataIds"));
+        }
+        return new GraphNode(
+                nodeId,
+                null,
+                stageType,
+                estimatorInputIds,
+                algoOpInputIds,
+                outputIds,
+                inputModelDataIds,
+                outputModelDataIds);
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (obj == null) {
+            return false;
+        }
+        if (!(obj instanceof GraphNode)) {
+            return false;
+        }
+        GraphNode other = (GraphNode) obj;
+        return nodeId == other.nodeId;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hashCode(nodeId);
+    }
+
+    @Override
+    public String toString() {
+        return String.format(
+                "GraphNode(nodeId=%d, stageType=%s, estimatorInputIds=%s, 
algoOpInputIds=%s, outputIds=%s, inputModelDataIds=%s, outputModelDataIds=%s)",
+                nodeId,
+                stageType.name(),
+                Arrays.toString(estimatorInputIds),
+                Arrays.toString(algoOpInputIds),
+                Arrays.toString(outputIds),
+                Arrays.toString(inputModelDataIds),
+                Arrays.toString(outputModelDataIds));
+    }
+}
diff --git 
a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/Pipeline.java 
b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/Pipeline.java
index 408198a..3e021ab 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/Pipeline.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/Pipeline.java
@@ -28,6 +28,7 @@ import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
+import org.apache.flink.util.Preconditions;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -47,7 +48,7 @@ public final class Pipeline implements Estimator<Pipeline, 
PipelineModel> {
     private final Map<Param<?>, Object> paramMap = new HashMap<>();
 
     public Pipeline(List<Stage<?>> stages) {
-        this.stages = stages;
+        this.stages = Preconditions.checkNotNull(stages);
         ParamUtils.initializeMapWithDefaultValues(paramMap, this);
     }
 
diff --git 
a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/PipelineModel.java 
b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/PipelineModel.java
index 2668f82..ee5f099 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/PipelineModel.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/PipelineModel.java
@@ -28,6 +28,7 @@ import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
+import org.apache.flink.util.Preconditions;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -46,7 +47,7 @@ public final class PipelineModel implements 
Model<PipelineModel> {
     private final Map<Param<?>, Object> paramMap = new HashMap<>();
 
     public PipelineModel(List<Stage<?>> stages) {
-        this.stages = stages;
+        this.stages = Preconditions.checkNotNull(stages);
         ParamUtils.initializeMapWithDefaultValues(paramMap, this);
     }
 
diff --git 
a/flink-ml-core/src/main/java/org/apache/flink/ml/builder/TableId.java 
b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/TableId.java
new file mode 100644
index 0000000..51f566b
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/builder/TableId.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.builder;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * The TableId is necessary to pass the inputs/outputs of various API calls 
across the nodes of
+ * Graph and GraphModel.
+ */
+public class TableId {
+    public final int id;
+
+    public TableId(int id) {
+        this.id = id;
+    }
+
+    public static List<Integer> toList(TableId[] tableIds) {
+        List<Integer> result = new ArrayList<>(tableIds.length);
+        for (int i = 0; i < tableIds.length; i++) {
+            result.add(tableIds[i].id);
+        }
+        return result;
+    }
+
+    public static TableId[] fromList(List<Integer> tableIds) {
+        TableId[] result = new TableId[tableIds.size()];
+        for (int i = 0; i < tableIds.size(); i++) {
+            result[i] = new TableId(tableIds.get(i));
+        }
+        return result;
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (obj == null) {
+            return false;
+        }
+        if (!(obj instanceof TableId)) {
+            return false;
+        }
+        TableId other = (TableId) obj;
+        return id == other.id;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hashCode(id);
+    }
+
+    @Override
+    public String toString() {
+        return "TableId(" + id + ")";
+    }
+}
diff --git 
a/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java 
b/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
index 2c21e20..eedc74d 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
@@ -25,12 +25,17 @@ import org.apache.flink.connector.file.sink.FileSink;
 import org.apache.flink.connector.file.src.FileSource;
 import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
 import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Graph;
+import org.apache.flink.ml.builder.GraphData;
+import org.apache.flink.ml.builder.GraphModel;
+import org.apache.flink.ml.builder.GraphNode;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
 import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
 import org.apache.flink.util.InstantiationUtil;
+import org.apache.flink.util.Preconditions;
 
 import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
 
@@ -45,6 +50,7 @@ import java.lang.reflect.Method;
 import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.util.ArrayList;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -231,6 +237,79 @@ public class ReadWriteUtils {
         return stages;
     }
 
+    /**
+     * Saves a Graph or GraphModel with the given GraphData to the given path.
+     *
+     * @param graph A Graph or GraphModel instance.
+     * @param graphData A GraphData instance.
+     * @param path The parent directory to save the graph metadata and its 
stages.
+     */
+    public static void saveGraph(Stage<?> graph, GraphData graphData, String 
path)
+            throws IOException {
+        // Creates parent directories if not already created.
+        new File(path).mkdirs();
+
+        Map<String, Object> extraMetadata = new HashMap<>();
+        extraMetadata.put("graphData", graphData.toMap());
+        saveMetadata(graph, path, extraMetadata);
+        int maxNodeId =
+                graphData.nodes.stream()
+                        .map(node -> node.nodeId)
+                        .max(Comparator.naturalOrder())
+                        .orElse(-1);
+
+        for (GraphNode node : graphData.nodes) {
+            String stagePath = getPathForPipelineStage(node.nodeId, maxNodeId 
+ 1, path);
+            node.stage.save(stagePath);
+        }
+    }
+
+    /**
+     * Loads a Graph or GraphModel from the given path.
+     *
+     * <p>The method throws RuntimeException if the expectedClassName is not empty AND it does not
+     * match the className of the previously saved Graph or GraphModel.
+     *
+     * @param env A StreamExecutionEnvironment instance.
+     * @param path The parent directory to load the graph metadata and its stages.
+     * @param expectedClassName The expected class name of the graph.
+     * @return A Graph or GraphModel instance.
+     */
+    public static Stage<?> loadGraph(
+            StreamExecutionEnvironment env, String path, String 
expectedClassName)
+            throws IOException {
+        Map<String, ?> metadata = loadMetadata(path, expectedClassName);
+        GraphData graphData = GraphData.fromMap((Map<String, Object>) 
metadata.get("graphData"));
+
+        int maxNodeId =
+                graphData.nodes.stream()
+                        .map(node -> node.nodeId)
+                        .max(Comparator.naturalOrder())
+                        .orElse(-1);
+
+        for (GraphNode node : graphData.nodes) {
+            String stagePath = getPathForPipelineStage(node.nodeId, maxNodeId 
+ 1, path);
+            node.stage = loadStage(env, stagePath);
+        }
+
+        if (expectedClassName.equals(GraphModel.class.getName())) {
+            return new GraphModel(
+                    graphData.nodes,
+                    graphData.modelInputIds,
+                    graphData.outputIds,
+                    graphData.inputModelDataIds,
+                    graphData.outputModelDataIds);
+        }
+        
Preconditions.checkState(expectedClassName.equals(Graph.class.getName()));
+        return new Graph(
+                graphData.nodes,
+                graphData.estimatorInputIds,
+                graphData.modelInputIds,
+                graphData.outputIds,
+                graphData.inputModelDataIds,
+                graphData.outputModelDataIds);
+    }
+
     // A helper method that sets stage's parameter value. We can not call 
stage.set(param, value)
     // directly because stage::set(...) needs the actual type of the value.
     public static <T> void setParam(Stage<?> stage, Param<T> param, Object 
value) {
diff --git 
a/flink-ml-core/src/test/java/org/apache/flink/ml/api/ExampleStages.java 
b/flink-ml-core/src/test/java/org/apache/flink/ml/api/ExampleStages.java
index 204d477..63de3a2 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/api/ExampleStages.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/ExampleStages.java
@@ -38,16 +38,9 @@ import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.internal.TableImpl;
 
-import org.apache.commons.collections.IteratorUtils;
 import org.junit.Assert;
 
-import java.io.DataInputStream;
-import java.io.DataOutputStream;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileOutputStream;
 import java.io.IOException;
-import java.nio.file.Paths;
 import java.util.HashMap;
 import java.util.Map;
 
@@ -60,6 +53,7 @@ public class ExampleStages {
     public static class SumModel implements Model<SumModel> {
         private final Map<Param<?>, Object> paramMap = new HashMap<>();
         private DataStream<Integer> modelData;
+        private Table modelDataTable;
 
         // This empty constructor is necessary in order for ModelA to be 
loaded by
         // ReadWriteUtils.createStageWithParam
@@ -94,46 +88,29 @@ public class ExampleStages {
                     (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
 
             modelData = tEnv.toDataStream(inputs[0], Integer.class);
+            modelDataTable = inputs[0];
             return this;
         }
 
         @Override
+        public Table[] getModelData() {
+            return new Table[] {modelDataTable};
+        }
+
+        @Override
         public void save(String path) throws IOException {
+            ReadWriteUtils.saveModelData(modelData, path, new 
TestUtils.IntEncoder());
             ReadWriteUtils.saveMetadata(this, path);
-
-            File dataDir = new File(path, "data");
-            if (!dataDir.mkdir()) {
-                throw new IOException("Directory " + dataDir.toString() + " 
already exists.");
-            }
-
-            File dataFile = new File(dataDir, "delta");
-            if (!dataFile.createNewFile()) {
-                throw new IOException("File " + dataFile.toString() + " 
already exists.");
-            }
-
-            int delta;
-            try {
-                delta = (Integer) 
IteratorUtils.toList(modelData.executeAndCollect()).get(0);
-            } catch (Exception e) {
-                throw new RuntimeException(e);
-            }
-
-            try (DataOutputStream outputStream =
-                    new DataOutputStream(new FileOutputStream(dataFile))) {
-                outputStream.writeInt(delta);
-            }
         }
 
         public static SumModel load(StreamExecutionEnvironment env, String 
path)
                 throws IOException {
-            SumModel model = ReadWriteUtils.loadStageParam(path);
-            File dataFile = Paths.get(path, "data", "delta").toFile();
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+            DataStream<Integer> modelData =
+                    ReadWriteUtils.loadModelData(env, path, new 
TestUtils.IntegerStreamFormat());
 
-            try (DataInputStream inputStream = new DataInputStream(new 
FileInputStream(dataFile))) {
-                StreamTableEnvironment tEnv = 
StreamTableEnvironment.create(env);
-                Table modelData = 
tEnv.fromDataStream(env.fromElements(inputStream.readInt()));
-                return model.setModelData(modelData);
-            }
+            SumModel model = ReadWriteUtils.loadStageParam(path);
+            return model.setModelData(tEnv.fromDataStream(modelData));
         }
     }
 
@@ -239,4 +216,43 @@ public class ExampleStages {
             sum += input.getValue();
         }
     }
+
+    /**
+     * A {@link Transformer} that requires exactly two Integer input tables and returns one table
+     * containing the union of both inputs.
+     */
+    public static class UnionAlgoOperator implements Transformer<UnionAlgoOperator> {
+        private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+        // Public no-arg constructor; fills paramMap with the default parameter values.
+        public UnionAlgoOperator() {
+            ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        }
+
+        @Override
+        public Map<Param<?>, Object> getParamMap() {
+            return paramMap;
+        }
+
+        @Override
+        public Table[] transform(Table... inputs) {
+            Assert.assertEquals(2, inputs.length);
+            Table first = inputs[0];
+            Table second = inputs[1];
+            // Both inputs are converted through the table environment that owns the first input.
+            StreamTableEnvironment tableEnv =
+                    (StreamTableEnvironment) ((TableImpl) first).getTableEnvironment();
+
+            DataStream<Integer> merged =
+                    tableEnv.toDataStream(first, Integer.class)
+                            .union(tableEnv.toDataStream(second, Integer.class));
+            return new Table[] {tableEnv.fromDataStream(merged)};
+        }
+
+        // Only the metadata is saved.
+        @Override
+        public void save(String path) throws IOException {
+            ReadWriteUtils.saveMetadata(this, path);
+        }
+
+        // The env argument is not used by this stage.
+        public static UnionAlgoOperator load(StreamExecutionEnvironment env, String path)
+                throws IOException {
+            return ReadWriteUtils.loadStageParam(path);
+        }
+    }
 }
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/api/GraphTest.java 
b/flink-ml-core/src/test/java/org/apache/flink/ml/api/GraphTest.java
new file mode 100644
index 0000000..213ed3c
--- /dev/null
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/GraphTest.java
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.ExampleStages.SumEstimator;
+import org.apache.flink.ml.api.ExampleStages.SumModel;
+import org.apache.flink.ml.api.ExampleStages.UnionAlgoOperator;
+import org.apache.flink.ml.builder.Graph;
+import org.apache.flink.ml.builder.GraphBuilder;
+import org.apache.flink.ml.builder.GraphModel;
+import org.apache.flink.ml.builder.TableId;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+/** Tests the behavior of {@link Graph} and {@link GraphModel}. */
+public class GraphTest extends AbstractTestBase {
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    @Before
+    public void before() {
+        // Turns on the ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH option; checkpointing itself is
+        // enabled below with an interval of 100.
+        Configuration conf = new Configuration();
+        conf.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(conf);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        env.enableCheckpointing(100);
+        env.setParallelism(4);
+        tEnv = StreamTableEnvironment.create(env);
+    }
+
+    // Executes the given stage using the given inputs and verifies that it 
produces the expected
+    // output. Then repeats this procedure after saving and loading the given 
stage.
+    private static void executeSaveLoadAndCheckOutput(
+            StreamExecutionEnvironment env,
+            Stage<?> stage,
+            List<List<Integer>> inputs,
+            List<Integer> expectedOutput,
+            List<List<Integer>> modelDataInputs,
+            List<Integer> expectedModelDataOutput,
+            boolean modelDataExists)
+            throws Exception {
+        // Executes the given stage and verifies that it produces the expected 
output.
+        TestUtils.executeAndCheckOutput(
+                env, stage, inputs, expectedOutput, modelDataInputs, 
expectedModelDataOutput);
+        // Saves and loads the given stage.
+        String path = Files.createTempDirectory("").toString();
+        stage.save(path);
+
+        if (modelDataExists) {
+            env.execute();
+        }
+
+        Stage<?> loadedStage = null;
+        if (stage instanceof Estimator) {
+            loadedStage = Graph.load(env, path);
+        } else {
+            loadedStage = GraphModel.load(env, path);
+        }
+        // Executes the loaded stage and verifies that it produces the 
expected output.
+        TestUtils.executeAndCheckOutput(
+                env, loadedStage, inputs, expectedOutput, modelDataInputs, 
expectedModelDataOutput);
+    }
+
+    @Test
+    public void testGraphModelWithoutEstimator() throws Exception {
+        GraphBuilder graphBuilder = new GraphBuilder();
+        // Nodes: two SumModels (model data 2 and 1) whose outputs feed a UnionAlgoOperator.
+        SumModel leftModel = new SumModel().setModelData(tEnv.fromValues(2));
+        SumModel rightModel = new SumModel().setModelData(tEnv.fromValues(1));
+        AlgoOperator<?> union = new UnionAlgoOperator();
+        // One graph input per model.
+        TableId leftInput = graphBuilder.createTableId();
+        TableId rightInput = graphBuilder.createTableId();
+        // Wires the nodes together.
+        TableId leftOutput = graphBuilder.addAlgoOperator(leftModel, leftInput)[0];
+        TableId rightOutput = graphBuilder.addAlgoOperator(rightModel, rightInput)[0];
+        TableId unionOutput = graphBuilder.addAlgoOperator(union, leftOutput, rightOutput)[0];
+
+        // Builds a Model from the graph.
+        Model<?> model =
+                graphBuilder.buildModel(
+                        new TableId[] {leftInput, rightInput}, new TableId[] {unionOutput});
+
+        // Expected values are the first input list plus 2 followed by the second list plus 1.
+        List<List<Integer>> inputValues =
+                new ArrayList<>(
+                        Arrays.asList(Arrays.asList(1, 2, 3), Arrays.asList(10, 11, 12)));
+        List<Integer> expectedOutputValues = Arrays.asList(3, 4, 5, 11, 12, 13);
+        executeSaveLoadAndCheckOutput(
+                env, model, inputValues, expectedOutputValues, null, null, true);
+    }
+
+    @Test
+    public void testGraphModelWithEstimator() throws Exception {
+        GraphBuilder graphBuilder = new GraphBuilder();
+        // Nodes: two SumEstimators whose outputs feed a UnionAlgoOperator.
+        Estimator<?, ?> leftEstimator = new SumEstimator();
+        Estimator<?, ?> rightEstimator = new SumEstimator();
+        AlgoOperator<?> union = new UnionAlgoOperator();
+        // One graph input per estimator.
+        TableId leftInput = graphBuilder.createTableId();
+        TableId rightInput = graphBuilder.createTableId();
+        // Wires the nodes together.
+        TableId leftOutput = graphBuilder.addEstimator(leftEstimator, leftInput)[0];
+        TableId rightOutput = graphBuilder.addEstimator(rightEstimator, rightInput)[0];
+        TableId unionOutput = graphBuilder.addAlgoOperator(union, leftOutput, rightOutput)[0];
+
+        // Builds a Model (not an Estimator) from the graph.
+        Model<?> model =
+                graphBuilder.buildModel(
+                        new TableId[] {leftInput, rightInput}, new TableId[] {unionOutput});
+
+        // Expected values are each input list shifted by that list's own sum (6 and 33).
+        List<List<Integer>> inputValues =
+                new ArrayList<>(
+                        Arrays.asList(Arrays.asList(1, 2, 3), Arrays.asList(10, 11, 12)));
+        List<Integer> expectedOutputValues = Arrays.asList(7, 8, 9, 43, 44, 45);
+        executeSaveLoadAndCheckOutput(
+                env, model, inputValues, expectedOutputValues, null, null, false);
+    }
+
+    @Test
+    public void testGraphModelWithSetGetModelData() throws Exception {
+        GraphBuilder graphBuilder = new GraphBuilder();
+        // Nodes: a chain of three SumModels. The middle one is created without model data and
+        // gets it from a graph-level model data input instead.
+        SumModel firstModel = new SumModel().setModelData(tEnv.fromValues(1));
+        SumModel middleModel = new SumModel();
+        SumModel lastModel = new SumModel().setModelData(tEnv.fromValues(3));
+        // Graph input and model data input.
+        TableId input = graphBuilder.createTableId();
+        TableId modelDataInput = graphBuilder.createTableId();
+        // Wires the nodes together.
+        TableId firstOutput = graphBuilder.addAlgoOperator(firstModel, input)[0];
+        TableId middleOutput = graphBuilder.addAlgoOperator(middleModel, firstOutput)[0];
+        graphBuilder.setModelDataOnModel(middleModel, modelDataInput);
+        TableId lastOutput = graphBuilder.addAlgoOperator(lastModel, middleOutput)[0];
+        // The last model's model data becomes the graph's model data output.
+        TableId modelDataOutput = graphBuilder.getModelDataFromModel(lastModel)[0];
+
+        // Builds a Model from the graph.
+        Model<?> model =
+                graphBuilder.buildModel(
+                        new TableId[] {input},
+                        new TableId[] {lastOutput},
+                        new TableId[] {modelDataInput},
+                        new TableId[] {modelDataOutput});
+
+        // Inputs 1, 2, 3 with model data input 2 are expected to come out as 7, 8, 9, and the
+        // model data output is expected to be 3.
+        executeSaveLoadAndCheckOutput(
+                env,
+                model,
+                Collections.singletonList(Arrays.asList(1, 2, 3)),
+                Arrays.asList(7, 8, 9),
+                Collections.singletonList(Collections.singletonList(2)),
+                Collections.singletonList(3),
+                true);
+    }
+
+    @Test
+    public void testGraphWithEstimator() throws Exception {
+        GraphBuilder graphBuilder = new GraphBuilder();
+        // Nodes: two SumEstimators whose outputs feed a UnionAlgoOperator.
+        Estimator<?, ?> leftEstimator = new SumEstimator();
+        Estimator<?, ?> rightEstimator = new SumEstimator();
+        AlgoOperator<?> union = new UnionAlgoOperator();
+        // One graph input per estimator.
+        TableId leftInput = graphBuilder.createTableId();
+        TableId rightInput = graphBuilder.createTableId();
+        // Wires the nodes together.
+        TableId leftOutput = graphBuilder.addEstimator(leftEstimator, leftInput)[0];
+        TableId rightOutput = graphBuilder.addEstimator(rightEstimator, rightInput)[0];
+        TableId unionOutput = graphBuilder.addAlgoOperator(union, leftOutput, rightOutput)[0];
+
+        // Builds an Estimator from the graph.
+        Estimator<?, ?> estimator =
+                graphBuilder.buildEstimator(
+                        new TableId[] {leftInput, rightInput}, new TableId[] {unionOutput});
+
+        // Expected values are each input list shifted by that list's own sum (6 and 33).
+        List<List<Integer>> inputValues =
+                new ArrayList<>(
+                        Arrays.asList(Arrays.asList(1, 2, 3), Arrays.asList(10, 11, 12)));
+        List<Integer> expectedOutputValues = Arrays.asList(7, 8, 9, 43, 44, 45);
+        executeSaveLoadAndCheckOutput(
+                env, estimator, inputValues, expectedOutputValues, null, null, false);
+    }
+
+    @Test
+    public void testGraphWithSetGetModelData() throws Exception {
+        GraphBuilder graphBuilder = new GraphBuilder();
+        // Nodes: a SumEstimator, a SumModel without model data of its own, and a union of both.
+        Estimator<?, ?> estimatorStage = new SumEstimator();
+        SumModel modelStage = new SumModel();
+        AlgoOperator<?> union = new UnionAlgoOperator();
+        // One graph input per branch.
+        TableId estimatorInput = graphBuilder.createTableId();
+        TableId modelInput = graphBuilder.createTableId();
+        // Wires the nodes together; the estimator's model data is routed into the SumModel.
+        TableId estimatorOutput = graphBuilder.addEstimator(estimatorStage, estimatorInput)[0];
+        TableId modelDataOutput = graphBuilder.getModelDataFromEstimator(estimatorStage)[0];
+        TableId modelOutput = graphBuilder.addAlgoOperator(modelStage, modelInput)[0];
+        graphBuilder.setModelDataOnModel(modelStage, modelDataOutput);
+        TableId unionOutput =
+                graphBuilder.addAlgoOperator(union, estimatorOutput, modelOutput)[0];
+
+        // Builds an Estimator from the graph, with no model data inputs and one model data output.
+        Estimator<?, ?> estimator =
+                graphBuilder.buildEstimator(
+                        new TableId[] {estimatorInput, modelInput},
+                        new TableId[] {unionOutput},
+                        null,
+                        new TableId[] {modelDataOutput});
+
+        // Both input lists are expected to be shifted by 6, which is also the expected model
+        // data output.
+        List<List<Integer>> inputValues =
+                new ArrayList<>(
+                        Arrays.asList(Arrays.asList(1, 2, 3), Arrays.asList(10, 11, 12)));
+        List<Integer> expectedOutputValues = Arrays.asList(7, 8, 9, 16, 17, 18);
+        List<Integer> expectedModelDataOutputValues = Collections.singletonList(6);
+        executeSaveLoadAndCheckOutput(
+                env,
+                estimator,
+                inputValues,
+                expectedOutputValues,
+                null,
+                expectedModelDataOutputValues,
+                true);
+    }
+}
diff --git 
a/flink-ml-core/src/test/java/org/apache/flink/ml/api/PipelineTest.java 
b/flink-ml-core/src/test/java/org/apache/flink/ml/api/PipelineTest.java
index 8ce3fe3..105fcf9 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/api/PipelineTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/PipelineTest.java
@@ -18,57 +18,43 @@
 
 package org.apache.flink.ml.api;
 
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
 import org.apache.flink.ml.api.ExampleStages.SumEstimator;
 import org.apache.flink.ml.api.ExampleStages.SumModel;
 import org.apache.flink.ml.builder.Pipeline;
 import org.apache.flink.ml.builder.PipelineModel;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.test.util.AbstractTestBase;
 
-import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
 import org.junit.Test;
 
 import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.Paths;
 import java.util.Arrays;
-import java.util.Comparator;
+import java.util.Collections;
 import java.util.List;
 
 /** Tests the behavior of Pipeline and PipelineModel. */
 public class PipelineTest extends AbstractTestBase {
-
-    // Executes the given stage and verifies that it produces the expected 
output.
-    private static void executeAndCheckOutput(
-            Stage<?> stage, List<Integer> input, List<Integer> expectedOutput) 
throws Exception {
-        StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
-        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
         env.setParallelism(4);
-
-        Table inputTable = tEnv.fromDataStream(env.fromCollection(input));
-
-        Table outputTable;
-
-        if (stage instanceof AlgoOperator) {
-            outputTable = ((AlgoOperator<?>) stage).transform(inputTable)[0];
-        } else {
-            Estimator<?, ?> estimator = (Estimator<?, ?>) stage;
-            Model<?> model = estimator.fit(inputTable);
-            outputTable = model.transform(inputTable)[0];
-        }
-
-        List<Integer> output =
-                IteratorUtils.toList(
-                        tEnv.toDataStream(outputTable, 
Integer.class).executeAndCollect());
-        compareResultCollections(expectedOutput, output, 
Comparator.naturalOrder());
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
     }
 
     @Test
     public void testPipelineModel() throws Exception {
-        StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
-        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
         // Builds a PipelineModel that increments input value by 60. This 
PipelineModel consists of
         // three stages where each stage increments input value by 10, 20, and 
30 respectively.
         SumModel modelA = new SumModel().setModelData(tEnv.fromValues(10));
@@ -77,41 +63,43 @@ public class PipelineTest extends AbstractTestBase {
 
         List<Stage<?>> stages = Arrays.asList(modelA, modelB, modelC);
         Model<?> model = new PipelineModel(stages);
+        List<List<Integer>> inputs = 
Collections.singletonList(Arrays.asList(1, 2, 3));
+        List<Integer> output = Arrays.asList(61, 62, 63);
 
         // Executes the original PipelineModel and verifies that it produces 
the expected output.
-        executeAndCheckOutput(model, Arrays.asList(1, 2, 3), Arrays.asList(61, 
62, 63));
+        TestUtils.executeAndCheckOutput(env, model, inputs, output, null, 
null);
 
         // Saves and loads the PipelineModel.
-        Path tempDir = Files.createTempDirectory("PipelineTest");
-        String path = Paths.get(tempDir.toString(), 
"testPipelineModelSaveLoad").toString();
+        String path = Files.createTempDirectory("").toString();
         model.save(path);
-        Model<?> loadedModel = PipelineModel.load(env, path);
+        env.execute();
 
+        Model<?> loadedModel = PipelineModel.load(env, path);
         // Executes the loaded PipelineModel and verifies that it produces the 
expected output.
-        executeAndCheckOutput(loadedModel, Arrays.asList(1, 2, 3), 
Arrays.asList(61, 62, 63));
+        TestUtils.executeAndCheckOutput(env, loadedModel, inputs, output, 
null, null);
     }
 
     @Test
     public void testPipeline() throws Exception {
-        StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
-        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
         // Builds a Pipeline that consists of a Model, an Estimator, and a 
model.
         SumModel modelA = new SumModel().setModelData(tEnv.fromValues(10));
         SumModel modelB = new SumModel().setModelData(tEnv.fromValues(30));
 
         List<Stage<?>> stages = Arrays.asList(modelA, new SumEstimator(), 
modelB);
         Estimator<?, ?> estimator = new Pipeline(stages);
+        List<List<Integer>> inputs = 
Collections.singletonList(Arrays.asList(1, 2, 3));
+        List<Integer> output = Arrays.asList(77, 78, 79);
 
         // Executes the original Pipeline and verifies that it produces the 
expected output.
-        executeAndCheckOutput(estimator, Arrays.asList(1, 2, 3), 
Arrays.asList(77, 78, 79));
+        TestUtils.executeAndCheckOutput(env, estimator, inputs, output, null, 
null);
 
         // Saves and loads the Pipeline.
-        Path tempDir = Files.createTempDirectory("PipelineTest");
-        String path = Paths.get(tempDir.toString(), "testPipeline").toString();
+        String path = Files.createTempDirectory("").toString();
         estimator.save(path);
-        Estimator<?, ?> loadedEstimator = Pipeline.load(env, path);
+        env.execute();
 
+        Estimator<?, ?> loadedEstimator = Pipeline.load(env, path);
         // Executes the loaded Pipeline and verifies that it produces the 
expected output.
-        executeAndCheckOutput(loadedEstimator, Arrays.asList(1, 2, 3), 
Arrays.asList(77, 78, 79));
+        TestUtils.executeAndCheckOutput(env, loadedEstimator, inputs, output, 
null, null);
     }
 }
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java 
b/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
index df0db64..0d5115e 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
@@ -42,7 +42,6 @@ import org.junit.Test;
 
 import java.io.IOException;
 import java.nio.file.Files;
-import java.nio.file.Paths;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
@@ -175,8 +174,7 @@ public class StageTest {
             ReadWriteUtils.setParam(stage, param, entry.getValue());
         }
 
-        String tempDir = Files.createTempDirectory("").toString();
-        String path = Paths.get(tempDir, "test").toString();
+        String path = Files.createTempDirectory("").toString();
         stage.save(path);
         try {
             stage.save(path);
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/api/TestUtils.java 
b/flink-ml-core/src/test/java/org/apache/flink/ml/api/TestUtils.java
new file mode 100644
index 0000000..512d4a0
--- /dev/null
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/TestUtils.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.TestBaseUtils;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.Comparator;
+import java.util.List;
+
+/** Utility methods for tests. */
+public class TestUtils {
+
+    // Executes the given stage using the given inputs and verifies that it produces the expected
+    // output.
+    //
+    // An AlgoOperator stage is applied to the inputs directly. Any other stage is cast to
+    // Estimator, fitted on the inputs, and the fitted Model is applied to the same inputs.
+    // If modelDataInputs is non-null, it is set on the Model before transform(...) is called.
+    // If expectedModelDataOutput is non-null, the Model's first model data table is verified too.
+    // Only the first output table is checked; results are compared using natural ordering.
+    public static void executeAndCheckOutput(
+            StreamExecutionEnvironment env,
+            Stage<?> stage,
+            List<List<Integer>> inputs,
+            List<Integer> expectedOutput,
+            List<List<Integer>> modelDataInputs,
+            List<Integer> expectedModelDataOutput)
+            throws Exception {
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+        Table[] inputTables = createTables(env, tEnv, inputs);
+
+        // The two kinds of stage differ only in how the operator to run is obtained.
+        AlgoOperator<?> operator;
+        if (stage instanceof AlgoOperator) {
+            operator = (AlgoOperator<?>) stage;
+        } else {
+            operator = ((Estimator<?, ?>) stage).fit(inputTables);
+        }
+
+        if (modelDataInputs != null) {
+            Table[] inputModelDataTables = createTables(env, tEnv, modelDataInputs);
+            // Only a Model accepts model data; any other AlgoOperator fails the cast here.
+            ((Model<?>) operator).setModelData(inputModelDataTables);
+        }
+        Table outputTable = operator.transform(inputTables)[0];
+        Table modelDataOutputTable = null;
+        if (expectedModelDataOutput != null) {
+            modelDataOutputTable = ((Model<?>) operator).getModelData()[0];
+        }
+
+        TestBaseUtils.compareResultCollections(
+                expectedOutput, collectIntegers(tEnv, outputTable), Comparator.naturalOrder());
+
+        if (expectedModelDataOutput != null) {
+            TestBaseUtils.compareResultCollections(
+                    expectedModelDataOutput,
+                    collectIntegers(tEnv, modelDataOutputTable),
+                    Comparator.naturalOrder());
+        }
+    }
+
+    // Turns each list of integers into a Table created from a collection stream on env.
+    private static Table[] createTables(
+            StreamExecutionEnvironment env,
+            StreamTableEnvironment tEnv,
+            List<List<Integer>> values) {
+        Table[] tables = new Table[values.size()];
+        for (int i = 0; i < tables.length; i++) {
+            tables[i] = tEnv.fromDataStream(env.fromCollection(values.get(i)));
+        }
+        return tables;
+    }
+
+    // Executes the job that produces the given table and collects its rows as Integers.
+    // IteratorUtils.toList returns a raw List, hence the suppressed unchecked conversion.
+    @SuppressWarnings("unchecked")
+    private static List<Integer> collectIntegers(StreamTableEnvironment tEnv, Table table)
+            throws Exception {
+        return IteratorUtils.toList(tEnv.toDataStream(table, Integer.class).executeAndCollect());
+    }
+
+    /**
+     * Encoder for Integer. Each element is written with {@link DataOutputStream#writeInt(int)}.
+     */
+    public static class IntEncoder implements Encoder<Integer> {
+        @Override
+        public void encode(Integer element, OutputStream stream) throws IOException {
+            // The wrapper is flushed rather than closed, so the given stream stays open.
+            DataOutputStream out = new DataOutputStream(stream);
+            out.writeInt(element);
+            out.flush();
+        }
+    }
+
+    /** Decoder for Integer. */
+    public static class IntegerStreamFormat extends 
SimpleStreamFormat<Integer> {
+        @Override
+        public Reader<Integer> createReader(Configuration config, 
FSDataInputStream stream) {
+            return new Reader<Integer>() {
+                private final DataInputStream dataStream = new 
DataInputStream(stream);
+
+                @Override
+                public Integer read() throws IOException {
+                    try {
+                        return dataStream.readInt();
+                    } catch (EOFException e) {
+                        return null;
+                    }
+                }
+
+                @Override
+                public void close() throws IOException {
+                    dataStream.close();
+                }
+            };
+        }
+
+        @Override
+        public TypeInformation<Integer> getProducedType() {
+            return BasicTypeInfo.INT_TYPE_INFO;
+        }
+    }
+}

Reply via email to