This is an automated email from the ASF dual-hosted git repository.

jqin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 5ff346e  [FLINK-22915][FLIP-173] Update Flink ML API to support AlgoOperator with multiple input tables and multiple output tables
5ff346e is described below

commit 5ff346ea1a508a00b89759492f09e7330e69baef
Author: Dong Lin <[email protected]>
AuthorDate: Wed Sep 22 13:47:39 2021 +0800

    [FLINK-22915][FLIP-173] Update Flink ML API to support AlgoOperator with multiple input tables and multiple output tables
---
 .../core/{PipelineStage.java => AlgoOperator.java} |  35 ++-
 .../org/apache/flink/ml/api/core/Estimator.java    |  24 +-
 .../java/org/apache/flink/ml/api/core/Model.java   |  34 ++-
 .../org/apache/flink/ml/api/core/Pipeline.java     | 257 +++++----------------
 .../apache/flink/ml/api/core/PipelineModel.java    |  83 +++++++
 .../java/org/apache/flink/ml/api/core/Stage.java   |  44 ++++
 .../org/apache/flink/ml/api/core/Transformer.java  |  22 +-
 .../org/apache/flink/ml/api/core/PipelineTest.java |  69 +++---
 8 files changed, 269 insertions(+), 299 deletions(-)

diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineStage.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/AlgoOperator.java
similarity index 50%
rename from flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineStage.java
rename to flink-ml-api/src/main/java/org/apache/flink/ml/api/core/AlgoOperator.java
index 0a3dd23..7f2d4b4 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineStage.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/AlgoOperator.java
@@ -18,29 +18,22 @@
 
 package org.apache.flink.ml.api.core;
 
-import org.apache.flink.ml.api.misc.param.WithParams;
-
-import java.io.Serializable;
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.table.api.Table;
 
 /**
- * Base class for a stage in a pipeline. The interface is only a concept, and 
does not have any
- * actual functionality. Its subclasses must be either Estimator or 
Transformer. No other classes
- * should inherit this interface directly.
- *
- * <p>Each pipeline stage is with parameters, and requires a public empty 
constructor for
- * restoration in Pipeline.
+ * An AlgoOperator takes a list of tables as inputs and produces a list of 
tables as results. It can
+ * be used to encode generic multi-input multi-output computation logic.
  *
- * @param <T> The class type of the PipelineStage implementation itself, used 
by {@link
- *     org.apache.flink.ml.api.misc.param.WithParams}
- * @see WithParams
+ * @param <T> The class type of the AlgoOperator implementation itself.
  */
-interface PipelineStage<T extends PipelineStage<T>> extends WithParams<T>, 
Serializable {
-
-    default String toJson() {
-        return getParams().toJson();
-    }
-
-    default void loadJson(String json) {
-        getParams().loadJson(json);
-    }
+@PublicEvolving
+public interface AlgoOperator<T extends AlgoOperator<T>> extends Stage<T> {
+    /**
+     * Applies the AlgoOperator on the given input tables and returns the 
result tables.
+     *
+     * @param inputs a list of tables
+     * @return a list of tables
+     */
+    Table[] transform(Table... inputs);
 }
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Estimator.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Estimator.java
index 24c8349..bab9c7d 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Estimator.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Estimator.java
@@ -20,28 +20,20 @@ package org.apache.flink.ml.api.core;
 
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.table.api.Table;
-import org.apache.flink.table.api.TableEnvironment;
 
 /**
- * Estimators are {@link PipelineStage}s responsible for training and 
generating machine learning
- * models.
+ * Estimators are responsible for training and generating Models.
  *
- * <p>The implementations are expected to take an input table as training 
samples and generate a
- * {@link Model} which fits these samples.
- *
- * @param <E> class type of the Estimator implementation itself, used by {@link
- *     org.apache.flink.ml.api.misc.param.WithParams}.
- * @param <M> class type of the {@link Model} this Estimator produces.
+ * @param <E> class type of the Estimator implementation itself.
+ * @param <M> class type of the Model this Estimator produces.
  */
 @PublicEvolving
-public interface Estimator<E extends Estimator<E, M>, M extends Model<M>> 
extends PipelineStage<E> {
-
+public interface Estimator<E extends Estimator<E, M>, M extends Model<M>> 
extends Stage<E> {
     /**
-     * Train and produce a {@link Model} which fits the records in the given 
{@link Table}.
+     * Trains on the given inputs and produces a Model.
      *
-     * @param tEnv the table environment to which the input table is bound.
-     * @param input the table with records to train the Model.
-     * @return a model trained to fit on the given Table.
+     * @param inputs a list of tables
+     * @return a Model
      */
-    M fit(TableEnvironment tEnv, Table input);
+    M fit(Table... inputs);
 }
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Model.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Model.java
index 6f15bc5..8caffe3 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Model.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Model.java
@@ -22,16 +22,30 @@ import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.table.api.Table;
 
 /**
- * A model is an ordinary {@link Transformer} except how it is created. While 
ordinary transformers
- * are defined by specifying the parameters directly, a model is usually 
generated by an {@link
- * Estimator} when {@link 
Estimator#fit(org.apache.flink.table.api.TableEnvironment, Table)} is
- * invoked.
+ * A Model is typically generated by invoking {@link Estimator#fit(Table...)}. 
A Model is a
+ * Transformer with the extra APIs to set and get model data.
  *
- * <p>We separate Model from {@link Transformer} in order to support potential 
model specific logic
- * such as linking a Model to the {@link Estimator} from which the model was 
generated.
- *
- * @param <M> The class type of the Model implementation itself, used by {@link
- *     org.apache.flink.ml.api.misc.param.WithParams}
+ * @param <T> The class type of the Model implementation itself.
  */
 @PublicEvolving
-public interface Model<M extends Model<M>> extends Transformer<M> {}
+public interface Model<T extends Model<T>> extends Transformer<T> {
+    /**
+     * Sets model data using the given list of tables. Each table could be an 
unbounded stream of
+     * model data changes.
+     *
+     * @param inputs a list of tables
+     */
+    default void setModelData(Table... inputs) {
+        throw new UnsupportedOperationException("this operation is not 
supported");
+    }
+
+    /**
+     * Gets a list of tables representing the model data. Each table could be 
an unbounded stream of
+     * model data changes.
+     *
+     * @return a list of tables
+     */
+    default Table[] getModelData() {
+        throw new UnsupportedOperationException("this operation is not 
supported");
+    }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java
index c5a56a5..a5fed01 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java
@@ -19,241 +19,104 @@
 package org.apache.flink.ml.api.core;
 
 import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.ml.api.misc.param.Params;
 import org.apache.flink.table.api.Table;
-import org.apache.flink.table.api.TableEnvironment;
-import org.apache.flink.util.InstantiationUtil;
-
-import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
-import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
 
 /**
- * A pipeline is a linear workflow which chains {@link Estimator}s and {@link 
Transformer}s to
- * execute an algorithm.
- *
- * <p>A pipeline itself can either act as an Estimator or a Transformer, 
depending on the stages it
- * includes. More specifically:
- *
- * <ul>
- *   <li>If a Pipeline has an {@link Estimator}, one needs to call {@link
- *       Pipeline#fit(TableEnvironment, Table)} before use the pipeline as a 
{@link Transformer} .
- *       In this case the Pipeline is an {@link Estimator} and can produce a 
Pipeline as a {@link
- *       Model}.
- *   <li>If a Pipeline has no {@link Estimator}, it is a {@link Transformer} 
and can be applied to a
- *       Table directly. In this case, {@link Pipeline#fit(TableEnvironment, 
Table)} will simply
- *       return the pipeline itself.
- * </ul>
- *
- * <p>In addition, a pipeline can also be used as a {@link PipelineStage} in 
another pipeline, just
- * like an ordinary {@link Estimator} or {@link Transformer} as describe above.
+ * A Pipeline acts as an Estimator. It consists of an ordered list of stages, 
each of which could be
+ * an Estimator, Model, Transformer or AlgoOperator.
  */
 @PublicEvolving
-public final class Pipeline
-        implements Estimator<Pipeline, Pipeline>, Transformer<Pipeline>, 
Model<Pipeline> {
-    private static final long serialVersionUID = 1L;
-    private final List<PipelineStage> stages = new ArrayList<>();
+public final class Pipeline implements Estimator<Pipeline, PipelineModel> {
+    private static final long serialVersionUID = 6384850154817512318L;
+    private final List<Stage<?>> stages;
     private final Params params = new Params();
 
-    private int lastEstimatorIndex = -1;
-
-    public Pipeline() {}
-
-    public Pipeline(String pipelineJson) {
-        this.loadJson(pipelineJson);
-    }
-
-    public Pipeline(List<PipelineStage> stages) {
-        for (PipelineStage s : stages) {
-            appendStage(s);
-        }
-    }
-
-    // is the stage a simple Estimator or pipeline with Estimator
-    private static boolean isStageNeedFit(PipelineStage stage) {
-        return (stage instanceof Pipeline && ((Pipeline) stage).needFit())
-                || (!(stage instanceof Pipeline) && stage instanceof 
Estimator);
-    }
-
-    /**
-     * Appends a PipelineStage to the tail of this pipeline. Pipeline is 
editable only via this
-     * method. The PipelineStage must be Estimator, Transformer, Model or 
Pipeline.
-     *
-     * @param stage the stage to be appended
-     */
-    public Pipeline appendStage(PipelineStage stage) {
-        if (isStageNeedFit(stage)) {
-            lastEstimatorIndex = stages.size();
-        } else if (!(stage instanceof Transformer)) {
-            throw new RuntimeException(
-                    "All PipelineStages should be Estimator or Transformer, 
got:"
-                            + stage.getClass().getSimpleName());
-        }
-        stages.add(stage);
-        return this;
+    public Pipeline(List<Stage<?>> stages) {
+        this.stages = stages;
     }
 
     /**
-     * Returns a list of all stages in this pipeline in order, the list is 
immutable.
+     * Trains the pipeline to fit on the given tables.
      *
-     * @return an immutable list of all stages in this pipeline in order.
-     */
-    public List<PipelineStage> getStages() {
-        return Collections.unmodifiableList(stages);
-    }
-
-    /**
-     * Check whether the pipeline acts as an {@link Estimator} or not. When 
the return value is
-     * true, that means this pipeline contains an {@link Estimator} and thus 
users must invoke
-     * {@link #fit(TableEnvironment, Table)} before they can use this pipeline 
as a {@link
-     * Transformer}. Otherwise, the pipeline can be used as a {@link 
Transformer} directly.
-     *
-     * @return {@code true} if this pipeline has an Estimator, {@code false} 
otherwise
-     */
-    public boolean needFit() {
-        return this.getIndexOfLastEstimator() >= 0;
-    }
-
-    public Params getParams() {
-        return params;
-    }
-
-    // find the last Estimator or Pipeline that needs fit in stages, -1 stand 
for no Estimator in
-    // Pipeline
-    private int getIndexOfLastEstimator() {
-        return lastEstimatorIndex;
-    }
-
-    /**
-     * Train the pipeline to fit on the records in the given {@link Table}.
-     *
-     * <p>This method go through all the {@link PipelineStage}s in order and 
does the following on
-     * each stage until the last {@link Estimator}(inclusive).
+     * <p>This method goes through all stages of this pipeline in order and 
does the following on
+     * each stage until the last Estimator (inclusive).
      *
      * <ul>
-     *   <li>If a stage is an {@link Estimator}, invoke {@link 
Estimator#fit(TableEnvironment,
-     *       Table)} with the input table to generate a {@link Model}, 
transform the the input table
-     *       with the generated {@link Model} to get a result table, then pass 
the result table to
-     *       the next stage as input.
-     *   <li>If a stage is a {@link Transformer}, invoke {@link
-     *       Transformer#transform(TableEnvironment, Table)} on the input 
table to get a result
-     *       table, and pass the result table to the next stage as input.
+     *   <li>If a stage is an Estimator, invoke {@link 
Estimator#fit(Table...)} with the input
+     *       tables to generate a Model. And if there is Estimator after this 
stage, transform the
+     *       input tables using the generated Model to get result tables, then 
pass the result
+     *       tables to the next stage as inputs.
+     *   <li>If a stage is an AlgoOperator AND there is Estimator after this 
stage, transform the
+     *       input tables using this stage to get result tables, then pass the 
result tables to the
+     *       next stage as inputs.
      * </ul>
      *
-     * <p>After all the {@link Estimator}s are trained to fit their input 
tables, a new pipeline
-     * will be created with the same stages in this pipeline, except that all 
the Estimators in the
-     * new pipeline are replaced with their corresponding Models generated in 
the above process.
+     * <p>After all the Estimators are trained to fit their input tables, a 
new PipelineModel will
+     * be created with the same stages in this pipeline, except that all the 
Estimators in the
+     * PipelineModel are replaced with the models generated in the above 
process.
      *
-     * <p>If there is no {@link Estimator} in the pipeline, the method returns 
a copy of this
-     * pipeline.
-     *
-     * @param tEnv the table environment to which the input table is bound.
-     * @param input the table with records to train the Pipeline.
-     * @return a pipeline with same stages as this Pipeline except all 
Estimators replaced with
-     *     their corresponding Models.
+     * @param inputs a list of tables
+     * @return a PipelineModel
      */
     @Override
-    public Pipeline fit(TableEnvironment tEnv, Table input) {
-        List<PipelineStage> transformStages = new ArrayList<>(stages.size());
-        int lastEstimatorIdx = getIndexOfLastEstimator();
+    public PipelineModel fit(Table... inputs) {
+        int lastEstimatorIdx = -1;
         for (int i = 0; i < stages.size(); i++) {
-            PipelineStage s = stages.get(i);
-            if (i <= lastEstimatorIdx) {
-                Transformer t;
-                boolean needFit = isStageNeedFit(s);
-                if (needFit) {
-                    t = ((Estimator) s).fit(tEnv, input);
-                } else {
-                    // stage is Transformer, guaranteed in appendStage() method
-                    t = (Transformer) s;
-                }
-                transformStages.add(t);
-                input = t.transform(tEnv, input);
-            } else {
-                transformStages.add(s);
+            if (stages.get(i) instanceof Estimator) {
+                lastEstimatorIdx = i;
             }
         }
-        return new Pipeline(transformStages);
-    }
 
-    /**
-     * Generate a result table by applying all the stages in this pipeline to 
the input table in
-     * order.
-     *
-     * @param tEnv the table environment to which the input table is bound.
-     * @param input the table to be transformed
-     * @return a result table with all the stages applied to the input tables 
in order.
-     */
-    @Override
-    public Table transform(TableEnvironment tEnv, Table input) {
-        if (needFit()) {
-            throw new RuntimeException("Pipeline contains Estimator, need to 
fit first.");
-        }
-        for (PipelineStage s : stages) {
-            input = ((Transformer) s).transform(tEnv, input);
-        }
-        return input;
-    }
+        List<Stage<?>> modelStages = new ArrayList<>(stages.size());
+        Table[] lastInputs = inputs;
 
-    @Override
-    public String toJson() {
-        ObjectMapper mapper = new ObjectMapper();
+        for (int i = 0; i < stages.size(); i++) {
+            Stage<?> stage = stages.get(i);
+            AlgoOperator<?> modelStage;
+            if (stage instanceof AlgoOperator) {
+                modelStage = (AlgoOperator<?>) stage;
+            } else {
+                modelStage = ((Estimator<?, ?>) stage).fit(lastInputs);
+            }
+            modelStages.add(modelStage);
 
-        List<Map<String, String>> stageJsons = new ArrayList<>();
-        for (PipelineStage s : getStages()) {
-            Map<String, String> stageMap = new HashMap<>();
-            stageMap.put("stageClassName", s.getClass().getTypeName());
-            stageMap.put("stageJson", s.toJson());
-            stageJsons.add(stageMap);
+            // Transforms inputs only if there exists Estimator stage after 
this stage.
+            if (i < lastEstimatorIdx) {
+                lastInputs = modelStage.transform(lastInputs);
+            }
         }
 
-        try {
-            return mapper.writeValueAsString(stageJsons);
-        } catch (JsonProcessingException e) {
-            throw new RuntimeException("Failed to serialize pipeline", e);
-        }
+        return new PipelineModel(modelStages);
     }
 
     @Override
-    @SuppressWarnings("unchecked")
-    public void loadJson(String json) {
-        ObjectMapper mapper = new ObjectMapper();
-        List<Map<String, String>> stageJsons;
-        try {
-            stageJsons = mapper.readValue(json, List.class);
-        } catch (IOException e) {
-            throw new RuntimeException("Failed to deserialize pipeline json:" 
+ json, e);
-        }
-        for (Map<String, String> stageMap : stageJsons) {
-            appendStage(restoreInnerStage(stageMap));
-        }
+    public void save(String path) throws IOException {
+        throw new UnsupportedOperationException();
     }
 
-    private PipelineStage<?> restoreInnerStage(Map<String, String> stageMap) {
-        String className = stageMap.get("stageClassName");
-        Class<?> clz;
-        try {
-            clz = Class.forName(className);
-        } catch (ClassNotFoundException e) {
-            throw new RuntimeException("PipelineStage class " + className + " 
not exists", e);
-        }
-        InstantiationUtil.checkForInstantiation(clz);
+    public static Pipeline load(String path) throws IOException {
+        throw new UnsupportedOperationException();
+    }
 
-        PipelineStage<?> s;
-        try {
-            s = (PipelineStage<?>) clz.newInstance();
-        } catch (Exception e) {
-            throw new RuntimeException("Class is instantiable but failed to 
new an instance", e);
-        }
+    @Override
+    public Params getParams() {
+        return params;
+    }
 
-        String stageJson = stageMap.get("stageJson");
-        s.loadJson(stageJson);
-        return s;
+    /**
+     * Returns a list of all stages in this Pipeline in order. The list is 
immutable.
+     *
+     * @return an immutable list of stages.
+     */
+    @VisibleForTesting
+    List<Stage<?>> getStages() {
+        return Collections.unmodifiableList(stages);
     }
 }
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java
new file mode 100644
index 0000000..704fa8e
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api.core;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.ml.api.misc.param.Params;
+import org.apache.flink.table.api.Table;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * A PipelineModel acts as a Model. It consists of an ordered list of stages, 
each of which could be
+ * a Model, Transformer or AlgoOperator.
+ */
+@PublicEvolving
+public final class PipelineModel implements Model<PipelineModel> {
+    private static final long serialVersionUID = 6184950154217411318L;
+    private final List<Stage<?>> stages;
+    private final Params params = new Params();
+
+    public PipelineModel(List<Stage<?>> stages) {
+        this.stages = stages;
+    }
+
+    /**
+     * Applies all stages in this PipelineModel on the input tables in order. 
The output of one
+     * stage is used as the input of the next stage (if any). The output of 
the last stage is
+     * returned as the result of this method.
+     *
+     * @param inputs a list of tables
+     * @return a list of tables
+     */
+    @Override
+    public Table[] transform(Table... inputs) {
+        for (Stage<?> stage : stages) {
+            inputs = ((AlgoOperator<?>) stage).transform(inputs);
+        }
+        return inputs;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        throw new UnsupportedOperationException();
+    }
+
+    public static PipelineModel load(String path) throws IOException {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public Params getParams() {
+        return params;
+    }
+
+    /**
+     * Returns a list of all stages in this PipelineModel in order. The list 
is immutable.
+     *
+     * @return an immutable list of transformers.
+     */
+    @VisibleForTesting
+    List<Stage<?>> getStages() {
+        return Collections.unmodifiableList(stages);
+    }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java
new file mode 100644
index 0000000..551c5e5
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api.core;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.ml.api.misc.param.WithParams;
+
+import java.io.IOException;
+import java.io.Serializable;
+
+/**
+ * Base class for a node in a Pipeline or Graph. The interface is only a 
concept, and does not have
+ * any actual functionality. Its subclasses could be Estimator, Model, 
Transformer or AlgoOperator.
+ * No other classes should inherit this interface directly.
+ *
+ * <p>Each stage is with parameters, and requires a public empty constructor 
for restoration.
+ *
+ * @param <T> The class type of the Stage implementation itself.
+ */
+@PublicEvolving
+public interface Stage<T extends Stage<T>> extends WithParams<T>, Serializable 
{
+    /** Saves this stage to the given path. */
+    void save(String path) throws IOException;
+
+    // NOTE: every Stage subclass should implement a static method with 
signature "static T
+    // load(String path)", where T refers to the concrete subclass. This 
static method should
+    // instantiate a new stage instance based on the data from the given path.
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Transformer.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Transformer.java
index f9a152c..ec86968 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Transformer.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Transformer.java
@@ -19,24 +19,14 @@
 package org.apache.flink.ml.api.core;
 
 import org.apache.flink.annotation.PublicEvolving;
-import org.apache.flink.table.api.Table;
-import org.apache.flink.table.api.TableEnvironment;
 
 /**
- * A transformer is a {@link PipelineStage} that transforms an input {@link 
Table} to a result
- * {@link Table}.
+ * A Transformer is an AlgoOperator with the semantic difference that it 
encodes the Transformation
+ * logic, such that a record in the output typically corresponds to one record 
in the input. In
+ * contrast, an AlgoOperator is a better fit to express aggregation logic 
where a record in the
+ * output could be computed from an arbitrary number of records in the input.
  *
- * @param <T> The class type of the Transformer implementation itself, used by 
{@link
- *     org.apache.flink.ml.api.misc.param.WithParams}
+ * @param <T> The class type of the Transformer implementation itself.
  */
 @PublicEvolving
-public interface Transformer<T extends Transformer<T>> extends 
PipelineStage<T> {
-    /**
-     * Applies the transformer on the input table, and returns the result 
table.
-     *
-     * @param tEnv the table environment to which the input table is bound.
-     * @param input the table to be transformed
-     * @return the transformed table
-     */
-    Table transform(TableEnvironment tEnv, Table input);
-}
+public interface Transformer<T extends Transformer<T>> extends AlgoOperator<T> 
{}
diff --git a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
index 87ec13a..6d46430 100644
--- a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
+++ b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
@@ -22,55 +22,37 @@ import org.apache.flink.ml.api.misc.param.ParamInfo;
 import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
 import org.apache.flink.ml.api.misc.param.Params;
 import org.apache.flink.table.api.Table;
-import org.apache.flink.table.api.TableEnvironment;
 
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
 
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
 /** Tests the behavior of {@link Pipeline}. */
 public class PipelineTest {
     @Rule public ExpectedException thrown = ExpectedException.none();
 
     @Test
     public void testPipelineBehavior() {
-        Pipeline pipeline = new Pipeline();
-        pipeline.appendStage(new MockTransformer("a"));
-        pipeline.appendStage(new MockEstimator("b"));
-        pipeline.appendStage(new MockEstimator("c"));
-        pipeline.appendStage(new MockTransformer("d"));
-        assert describePipeline(pipeline).equals("a_b_c_d");
-
-        Pipeline pipelineModel = pipeline.fit(null, null);
-        assert describePipeline(pipelineModel).equals("a_mb_mc_d");
-
-        thrown.expect(RuntimeException.class);
-        thrown.expectMessage("Pipeline contains Estimator, need to fit 
first.");
-        pipeline.transform(null, null);
-    }
+        List<Stage<?>> stages = new ArrayList<>();
+        stages.add(new MockTransformer("a"));
+        stages.add(new MockEstimator("b"));
+        stages.add(new MockEstimator("c"));
+        stages.add(new MockTransformer("d"));
 
-    @Test
-    public void testPipelineRestore() {
-        Pipeline pipeline = new Pipeline();
-        pipeline.appendStage(new MockTransformer("a"));
-        pipeline.appendStage(new MockEstimator("b"));
-        pipeline.appendStage(new MockEstimator("c"));
-        pipeline.appendStage(new MockTransformer("d"));
-        String pipelineJson = pipeline.toJson();
-
-        Pipeline restoredPipeline = new Pipeline(pipelineJson);
-        assert describePipeline(restoredPipeline).equals("a_b_c_d");
-
-        Pipeline pipelineModel = pipeline.fit(null, null);
-        String modelJson = pipelineModel.toJson();
-
-        Pipeline restoredPipelineModel = new Pipeline(modelJson);
-        assert describePipeline(restoredPipelineModel).equals("a_mb_mc_d");
+        Pipeline pipeline = new Pipeline(stages);
+        assert describePipeline(pipeline.getStages()).equals("a_b_c_d");
+
+        PipelineModel pipelineModel = pipeline.fit(null, null);
+        assert describePipeline(pipelineModel.getStages()).equals("a_mb_mc_d");
     }
 
-    private static String describePipeline(Pipeline p) {
+    private static String describePipeline(List<Stage<?>> stages) {
         StringBuilder res = new StringBuilder();
-        for (PipelineStage s : p.getStages()) {
+        for (Stage<?> s : stages) {
             if (res.length() != 0) {
                 res.append("_");
             }
@@ -98,7 +80,7 @@ public class PipelineTest {
         }
 
         @Override
-        public MockModel fit(TableEnvironment tEnv, Table input) {
+        public MockModel fit(Table... inputs) {
             return new MockModel("m" + describe());
         }
 
@@ -111,6 +93,9 @@ public class PipelineTest {
         public String describe() {
             return get(DESCRIPTION);
         }
+
+        @Override
+        public void save(String path) throws IOException {}
     }
 
     /** Mock transformer for pipeline test. */
@@ -124,8 +109,8 @@ public class PipelineTest {
         }
 
         @Override
-        public Table transform(TableEnvironment tEnv, Table input) {
-            return input;
+        public Table[] transform(Table... inputs) {
+            return inputs;
         }
 
         @Override
@@ -137,6 +122,9 @@ public class PipelineTest {
         public String describe() {
             return get(DESCRIPTION);
         }
+
+        @Override
+        public void save(String path) throws IOException {}
     }
 
     /** Mock model for pipeline test. */
@@ -150,8 +138,8 @@ public class PipelineTest {
         }
 
         @Override
-        public Table transform(TableEnvironment tEnv, Table input) {
-            return input;
+        public Table[] transform(Table... inputs) {
+            return inputs;
         }
 
         @Override
@@ -163,5 +151,8 @@ public class PipelineTest {
         public String describe() {
             return get(DESCRIPTION);
         }
+
+        @Override
+        public void save(String path) throws IOException {}
     }
 }

Reply via email to