This is an automated email from the ASF dual-hosted git repository.
jqin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 5ff346e [FLINK-22915][FLIP-173] Update Flink ML API to support
AlgoOperator with multiple input tables and multiple output tables
5ff346e is described below
commit 5ff346ea1a508a00b89759492f09e7330e69baef
Author: Dong Lin <[email protected]>
AuthorDate: Wed Sep 22 13:47:39 2021 +0800
[FLINK-22915][FLIP-173] Update Flink ML API to support AlgoOperator with
multiple input tables and multiple output tables
---
.../core/{PipelineStage.java => AlgoOperator.java} | 35 ++-
.../org/apache/flink/ml/api/core/Estimator.java | 24 +-
.../java/org/apache/flink/ml/api/core/Model.java | 34 ++-
.../org/apache/flink/ml/api/core/Pipeline.java | 257 +++++----------------
.../apache/flink/ml/api/core/PipelineModel.java | 83 +++++++
.../java/org/apache/flink/ml/api/core/Stage.java | 44 ++++
.../org/apache/flink/ml/api/core/Transformer.java | 22 +-
.../org/apache/flink/ml/api/core/PipelineTest.java | 69 +++---
8 files changed, 269 insertions(+), 299 deletions(-)
diff --git
a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineStage.java
b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/AlgoOperator.java
similarity index 50%
rename from
flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineStage.java
rename to
flink-ml-api/src/main/java/org/apache/flink/ml/api/core/AlgoOperator.java
index 0a3dd23..7f2d4b4 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineStage.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/AlgoOperator.java
@@ -18,29 +18,22 @@
package org.apache.flink.ml.api.core;
-import org.apache.flink.ml.api.misc.param.WithParams;
-
-import java.io.Serializable;
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.table.api.Table;
/**
- * Base class for a stage in a pipeline. The interface is only a concept, and
does not have any
- * actual functionality. Its subclasses must be either Estimator or
Transformer. No other classes
- * should inherit this interface directly.
- *
- * <p>Each pipeline stage is with parameters, and requires a public empty
constructor for
- * restoration in Pipeline.
+ * An AlgoOperator takes a list of tables as inputs and produces a list of
tables as results. It can
+ * be used to encode generic multi-input multi-output computation logic.
*
- * @param <T> The class type of the PipelineStage implementation itself, used
by {@link
- * org.apache.flink.ml.api.misc.param.WithParams}
- * @see WithParams
+ * @param <T> The class type of the AlgoOperator implementation itself.
*/
-interface PipelineStage<T extends PipelineStage<T>> extends WithParams<T>,
Serializable {
-
- default String toJson() {
- return getParams().toJson();
- }
-
- default void loadJson(String json) {
- getParams().loadJson(json);
- }
+@PublicEvolving
+public interface AlgoOperator<T extends AlgoOperator<T>> extends Stage<T> {
+ /**
+ * Applies the AlgoOperator on the given input tables and returns the
result tables.
+ *
+ * @param inputs a list of tables
+ * @return a list of tables
+ */
+ Table[] transform(Table... inputs);
}
diff --git
a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Estimator.java
b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Estimator.java
index 24c8349..bab9c7d 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Estimator.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Estimator.java
@@ -20,28 +20,20 @@ package org.apache.flink.ml.api.core;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.table.api.Table;
-import org.apache.flink.table.api.TableEnvironment;
/**
- * Estimators are {@link PipelineStage}s responsible for training and
generating machine learning
- * models.
+ * Estimators are responsible for training and generating Models.
*
- * <p>The implementations are expected to take an input table as training
samples and generate a
- * {@link Model} which fits these samples.
- *
- * @param <E> class type of the Estimator implementation itself, used by {@link
- * org.apache.flink.ml.api.misc.param.WithParams}.
- * @param <M> class type of the {@link Model} this Estimator produces.
+ * @param <E> class type of the Estimator implementation itself.
+ * @param <M> class type of the Model this Estimator produces.
*/
@PublicEvolving
-public interface Estimator<E extends Estimator<E, M>, M extends Model<M>>
extends PipelineStage<E> {
-
+public interface Estimator<E extends Estimator<E, M>, M extends Model<M>>
extends Stage<E> {
/**
- * Train and produce a {@link Model} which fits the records in the given
{@link Table}.
+ * Trains on the given inputs and produces a Model.
*
- * @param tEnv the table environment to which the input table is bound.
- * @param input the table with records to train the Model.
- * @return a model trained to fit on the given Table.
+ * @param inputs a list of tables
+ * @return a Model
*/
- M fit(TableEnvironment tEnv, Table input);
+ M fit(Table... inputs);
}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Model.java
b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Model.java
index 6f15bc5..8caffe3 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Model.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Model.java
@@ -22,16 +22,30 @@ import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.table.api.Table;
/**
- * A model is an ordinary {@link Transformer} except how it is created. While
ordinary transformers
- * are defined by specifying the parameters directly, a model is usually
generated by an {@link
- * Estimator} when {@link
Estimator#fit(org.apache.flink.table.api.TableEnvironment, Table)} is
- * invoked.
+ * A Model is typically generated by invoking {@link Estimator#fit(Table...)}.
A Model is a
+ * Transformer with the extra APIs to set and get model data.
*
- * <p>We separate Model from {@link Transformer} in order to support potential
model specific logic
- * such as linking a Model to the {@link Estimator} from which the model was
generated.
- *
- * @param <M> The class type of the Model implementation itself, used by {@link
- * org.apache.flink.ml.api.misc.param.WithParams}
+ * @param <T> The class type of the Model implementation itself.
*/
@PublicEvolving
-public interface Model<M extends Model<M>> extends Transformer<M> {}
+public interface Model<T extends Model<T>> extends Transformer<T> {
+ /**
+ * Sets model data using the given list of tables. Each table could be an
unbounded stream of
+ * model data changes.
+ *
+ * @param inputs a list of tables
+ */
+ default void setModelData(Table... inputs) {
+ throw new UnsupportedOperationException("this operation is not
supported");
+ }
+
+ /**
+ * Gets a list of tables representing the model data. Each table could be
an unbounded stream of
+ * model data changes.
+ *
+ * @return a list of tables
+ */
+ default Table[] getModelData() {
+ throw new UnsupportedOperationException("this operation is not
supported");
+ }
+}
diff --git
a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java
b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java
index c5a56a5..a5fed01 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java
@@ -19,241 +19,104 @@
package org.apache.flink.ml.api.core;
import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
-import org.apache.flink.table.api.TableEnvironment;
-import org.apache.flink.util.InstantiationUtil;
-
-import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
-import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
-import java.util.HashMap;
import java.util.List;
-import java.util.Map;
/**
- * A pipeline is a linear workflow which chains {@link Estimator}s and {@link
Transformer}s to
- * execute an algorithm.
- *
- * <p>A pipeline itself can either act as an Estimator or a Transformer,
depending on the stages it
- * includes. More specifically:
- *
- * <ul>
- * <li>If a Pipeline has an {@link Estimator}, one needs to call {@link
- * Pipeline#fit(TableEnvironment, Table)} before use the pipeline as a
{@link Transformer} .
- * In this case the Pipeline is an {@link Estimator} and can produce a
Pipeline as a {@link
- * Model}.
- * <li>If a Pipeline has no {@link Estimator}, it is a {@link Transformer}
and can be applied to a
- * Table directly. In this case, {@link Pipeline#fit(TableEnvironment,
Table)} will simply
- * return the pipeline itself.
- * </ul>
- *
- * <p>In addition, a pipeline can also be used as a {@link PipelineStage} in
another pipeline, just
- * like an ordinary {@link Estimator} or {@link Transformer} as describe above.
+ * A Pipeline acts as an Estimator. It consists of an ordered list of stages,
each of which could be
+ * an Estimator, Model, Transformer or AlgoOperator.
*/
@PublicEvolving
-public final class Pipeline
- implements Estimator<Pipeline, Pipeline>, Transformer<Pipeline>,
Model<Pipeline> {
- private static final long serialVersionUID = 1L;
- private final List<PipelineStage> stages = new ArrayList<>();
+public final class Pipeline implements Estimator<Pipeline, PipelineModel> {
+ private static final long serialVersionUID = 6384850154817512318L;
+ private final List<Stage<?>> stages;
private final Params params = new Params();
- private int lastEstimatorIndex = -1;
-
- public Pipeline() {}
-
- public Pipeline(String pipelineJson) {
- this.loadJson(pipelineJson);
- }
-
- public Pipeline(List<PipelineStage> stages) {
- for (PipelineStage s : stages) {
- appendStage(s);
- }
- }
-
- // is the stage a simple Estimator or pipeline with Estimator
- private static boolean isStageNeedFit(PipelineStage stage) {
- return (stage instanceof Pipeline && ((Pipeline) stage).needFit())
- || (!(stage instanceof Pipeline) && stage instanceof
Estimator);
- }
-
- /**
- * Appends a PipelineStage to the tail of this pipeline. Pipeline is
editable only via this
- * method. The PipelineStage must be Estimator, Transformer, Model or
Pipeline.
- *
- * @param stage the stage to be appended
- */
- public Pipeline appendStage(PipelineStage stage) {
- if (isStageNeedFit(stage)) {
- lastEstimatorIndex = stages.size();
- } else if (!(stage instanceof Transformer)) {
- throw new RuntimeException(
- "All PipelineStages should be Estimator or Transformer,
got:"
- + stage.getClass().getSimpleName());
- }
- stages.add(stage);
- return this;
+ public Pipeline(List<Stage<?>> stages) {
+ this.stages = stages;
}
/**
- * Returns a list of all stages in this pipeline in order, the list is
immutable.
+ * Trains the pipeline to fit on the given tables.
*
- * @return an immutable list of all stages in this pipeline in order.
- */
- public List<PipelineStage> getStages() {
- return Collections.unmodifiableList(stages);
- }
-
- /**
- * Check whether the pipeline acts as an {@link Estimator} or not. When
the return value is
- * true, that means this pipeline contains an {@link Estimator} and thus
users must invoke
- * {@link #fit(TableEnvironment, Table)} before they can use this pipeline
as a {@link
- * Transformer}. Otherwise, the pipeline can be used as a {@link
Transformer} directly.
- *
- * @return {@code true} if this pipeline has an Estimator, {@code false}
otherwise
- */
- public boolean needFit() {
- return this.getIndexOfLastEstimator() >= 0;
- }
-
- public Params getParams() {
- return params;
- }
-
- // find the last Estimator or Pipeline that needs fit in stages, -1 stand
for no Estimator in
- // Pipeline
- private int getIndexOfLastEstimator() {
- return lastEstimatorIndex;
- }
-
- /**
- * Train the pipeline to fit on the records in the given {@link Table}.
- *
- * <p>This method go through all the {@link PipelineStage}s in order and
does the following on
- * each stage until the last {@link Estimator}(inclusive).
+ * <p>This method goes through all stages of this pipeline in order and
does the following on
+ * each stage until the last Estimator (inclusive).
*
* <ul>
- * <li>If a stage is an {@link Estimator}, invoke {@link
Estimator#fit(TableEnvironment,
- * Table)} with the input table to generate a {@link Model},
transform the the input table
- * with the generated {@link Model} to get a result table, then pass
the result table to
- * the next stage as input.
- * <li>If a stage is a {@link Transformer}, invoke {@link
- * Transformer#transform(TableEnvironment, Table)} on the input
table to get a result
- * table, and pass the result table to the next stage as input.
+ * <li>If a stage is an Estimator, invoke {@link
Estimator#fit(Table...)} with the input
+     *         tables to generate a Model. And if there is an Estimator after this
stage, transform the
+ * input tables using the generated Model to get result tables, then
pass the result
+ * tables to the next stage as inputs.
+     *     <li>If a stage is an AlgoOperator AND there is an Estimator after this
stage, transform the
+ * input tables using this stage to get result tables, then pass the
result tables to the
+ * next stage as inputs.
* </ul>
*
- * <p>After all the {@link Estimator}s are trained to fit their input
tables, a new pipeline
- * will be created with the same stages in this pipeline, except that all
the Estimators in the
- * new pipeline are replaced with their corresponding Models generated in
the above process.
+ * <p>After all the Estimators are trained to fit their input tables, a
new PipelineModel will
+ * be created with the same stages in this pipeline, except that all the
Estimators in the
+ * PipelineModel are replaced with the models generated in the above
process.
*
- * <p>If there is no {@link Estimator} in the pipeline, the method returns
a copy of this
- * pipeline.
- *
- * @param tEnv the table environment to which the input table is bound.
- * @param input the table with records to train the Pipeline.
- * @return a pipeline with same stages as this Pipeline except all
Estimators replaced with
- * their corresponding Models.
+ * @param inputs a list of tables
+ * @return a PipelineModel
*/
@Override
- public Pipeline fit(TableEnvironment tEnv, Table input) {
- List<PipelineStage> transformStages = new ArrayList<>(stages.size());
- int lastEstimatorIdx = getIndexOfLastEstimator();
+ public PipelineModel fit(Table... inputs) {
+ int lastEstimatorIdx = -1;
for (int i = 0; i < stages.size(); i++) {
- PipelineStage s = stages.get(i);
- if (i <= lastEstimatorIdx) {
- Transformer t;
- boolean needFit = isStageNeedFit(s);
- if (needFit) {
- t = ((Estimator) s).fit(tEnv, input);
- } else {
- // stage is Transformer, guaranteed in appendStage() method
- t = (Transformer) s;
- }
- transformStages.add(t);
- input = t.transform(tEnv, input);
- } else {
- transformStages.add(s);
+ if (stages.get(i) instanceof Estimator) {
+ lastEstimatorIdx = i;
}
}
- return new Pipeline(transformStages);
- }
- /**
- * Generate a result table by applying all the stages in this pipeline to
the input table in
- * order.
- *
- * @param tEnv the table environment to which the input table is bound.
- * @param input the table to be transformed
- * @return a result table with all the stages applied to the input tables
in order.
- */
- @Override
- public Table transform(TableEnvironment tEnv, Table input) {
- if (needFit()) {
- throw new RuntimeException("Pipeline contains Estimator, need to
fit first.");
- }
- for (PipelineStage s : stages) {
- input = ((Transformer) s).transform(tEnv, input);
- }
- return input;
- }
+ List<Stage<?>> modelStages = new ArrayList<>(stages.size());
+ Table[] lastInputs = inputs;
- @Override
- public String toJson() {
- ObjectMapper mapper = new ObjectMapper();
+ for (int i = 0; i < stages.size(); i++) {
+ Stage<?> stage = stages.get(i);
+ AlgoOperator<?> modelStage;
+ if (stage instanceof AlgoOperator) {
+ modelStage = (AlgoOperator<?>) stage;
+ } else {
+ modelStage = ((Estimator<?, ?>) stage).fit(lastInputs);
+ }
+ modelStages.add(modelStage);
- List<Map<String, String>> stageJsons = new ArrayList<>();
- for (PipelineStage s : getStages()) {
- Map<String, String> stageMap = new HashMap<>();
- stageMap.put("stageClassName", s.getClass().getTypeName());
- stageMap.put("stageJson", s.toJson());
- stageJsons.add(stageMap);
+            // Transforms inputs only if there exists an Estimator stage after
this stage.
+ if (i < lastEstimatorIdx) {
+ lastInputs = modelStage.transform(lastInputs);
+ }
}
- try {
- return mapper.writeValueAsString(stageJsons);
- } catch (JsonProcessingException e) {
- throw new RuntimeException("Failed to serialize pipeline", e);
- }
+ return new PipelineModel(modelStages);
}
@Override
- @SuppressWarnings("unchecked")
- public void loadJson(String json) {
- ObjectMapper mapper = new ObjectMapper();
- List<Map<String, String>> stageJsons;
- try {
- stageJsons = mapper.readValue(json, List.class);
- } catch (IOException e) {
- throw new RuntimeException("Failed to deserialize pipeline json:"
+ json, e);
- }
- for (Map<String, String> stageMap : stageJsons) {
- appendStage(restoreInnerStage(stageMap));
- }
+ public void save(String path) throws IOException {
+ throw new UnsupportedOperationException();
}
- private PipelineStage<?> restoreInnerStage(Map<String, String> stageMap) {
- String className = stageMap.get("stageClassName");
- Class<?> clz;
- try {
- clz = Class.forName(className);
- } catch (ClassNotFoundException e) {
- throw new RuntimeException("PipelineStage class " + className + "
not exists", e);
- }
- InstantiationUtil.checkForInstantiation(clz);
+ public static Pipeline load(String path) throws IOException {
+ throw new UnsupportedOperationException();
+ }
- PipelineStage<?> s;
- try {
- s = (PipelineStage<?>) clz.newInstance();
- } catch (Exception e) {
- throw new RuntimeException("Class is instantiable but failed to
new an instance", e);
- }
+ @Override
+ public Params getParams() {
+ return params;
+ }
- String stageJson = stageMap.get("stageJson");
- s.loadJson(stageJson);
- return s;
+ /**
+ * Returns a list of all stages in this Pipeline in order. The list is
immutable.
+ *
+ * @return an immutable list of stages.
+ */
+ @VisibleForTesting
+ List<Stage<?>> getStages() {
+ return Collections.unmodifiableList(stages);
}
}
diff --git
a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java
b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java
new file mode 100644
index 0000000..704fa8e
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api.core;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.ml.api.misc.param.Params;
+import org.apache.flink.table.api.Table;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * A PipelineModel acts as a Model. It consists of an ordered list of stages,
each of which could be
+ * a Model, Transformer or AlgoOperator.
+ */
+@PublicEvolving
+public final class PipelineModel implements Model<PipelineModel> {
+ private static final long serialVersionUID = 6184950154217411318L;
+ private final List<Stage<?>> stages;
+ private final Params params = new Params();
+
+ public PipelineModel(List<Stage<?>> stages) {
+ this.stages = stages;
+ }
+
+ /**
+ * Applies all stages in this PipelineModel on the input tables in order.
The output of one
+ * stage is used as the input of the next stage (if any). The output of
the last stage is
+ * returned as the result of this method.
+ *
+ * @param inputs a list of tables
+ * @return a list of tables
+ */
+ @Override
+ public Table[] transform(Table... inputs) {
+ for (Stage<?> stage : stages) {
+ inputs = ((AlgoOperator<?>) stage).transform(inputs);
+ }
+ return inputs;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
+ public static PipelineModel load(String path) throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Params getParams() {
+ return params;
+ }
+
+ /**
+ * Returns a list of all stages in this PipelineModel in order. The list
is immutable.
+ *
+ * @return an immutable list of transformers.
+ */
+ @VisibleForTesting
+ List<Stage<?>> getStages() {
+ return Collections.unmodifiableList(stages);
+ }
+}
diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java
b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java
new file mode 100644
index 0000000..551c5e5
--- /dev/null
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.api.core;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.ml.api.misc.param.WithParams;
+
+import java.io.IOException;
+import java.io.Serializable;
+
+/**
+ * Base class for a node in a Pipeline or Graph. The interface is only a
concept, and does not have
+ * any actual functionality. Its subclasses could be Estimator, Model,
Transformer or AlgoOperator.
+ * No other classes should inherit this interface directly.
+ *
+ * <p>Each stage is with parameters, and requires a public empty constructor
for restoration.
+ *
+ * @param <T> The class type of the Stage implementation itself.
+ */
+@PublicEvolving
+public interface Stage<T extends Stage<T>> extends WithParams<T>, Serializable
{
+ /** Saves this stage to the given path. */
+ void save(String path) throws IOException;
+
+ // NOTE: every Stage subclass should implement a static method with
signature "static T
+ // load(String path)", where T refers to the concrete subclass. This
static method should
+ // instantiate a new stage instance based on the data from the given path.
+}
diff --git
a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Transformer.java
b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Transformer.java
index f9a152c..ec86968 100644
--- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Transformer.java
+++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Transformer.java
@@ -19,24 +19,14 @@
package org.apache.flink.ml.api.core;
import org.apache.flink.annotation.PublicEvolving;
-import org.apache.flink.table.api.Table;
-import org.apache.flink.table.api.TableEnvironment;
/**
- * A transformer is a {@link PipelineStage} that transforms an input {@link
Table} to a result
- * {@link Table}.
+ * A Transformer is an AlgoOperator with the semantic difference that it
encodes the transformation
+ * logic, such that a record in the output typically corresponds to one record
in the input. In
+ * contrast, an AlgoOperator is a better fit to express aggregation logic
where a record in the
+ * output could be computed from an arbitrary number of records in the input.
*
- * @param <T> The class type of the Transformer implementation itself, used by
{@link
- * org.apache.flink.ml.api.misc.param.WithParams}
+ * @param <T> The class type of the Transformer implementation itself.
*/
@PublicEvolving
-public interface Transformer<T extends Transformer<T>> extends
PipelineStage<T> {
- /**
- * Applies the transformer on the input table, and returns the result
table.
- *
- * @param tEnv the table environment to which the input table is bound.
- * @param input the table to be transformed
- * @return the transformed table
- */
- Table transform(TableEnvironment tEnv, Table input);
-}
+public interface Transformer<T extends Transformer<T>> extends AlgoOperator<T>
{}
diff --git
a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
index 87ec13a..6d46430 100644
--- a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
+++ b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java
@@ -22,55 +22,37 @@ import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
-import org.apache.flink.table.api.TableEnvironment;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
/** Tests the behavior of {@link Pipeline}. */
public class PipelineTest {
@Rule public ExpectedException thrown = ExpectedException.none();
@Test
public void testPipelineBehavior() {
- Pipeline pipeline = new Pipeline();
- pipeline.appendStage(new MockTransformer("a"));
- pipeline.appendStage(new MockEstimator("b"));
- pipeline.appendStage(new MockEstimator("c"));
- pipeline.appendStage(new MockTransformer("d"));
- assert describePipeline(pipeline).equals("a_b_c_d");
-
- Pipeline pipelineModel = pipeline.fit(null, null);
- assert describePipeline(pipelineModel).equals("a_mb_mc_d");
-
- thrown.expect(RuntimeException.class);
- thrown.expectMessage("Pipeline contains Estimator, need to fit
first.");
- pipeline.transform(null, null);
- }
+ List<Stage<?>> stages = new ArrayList<>();
+ stages.add(new MockTransformer("a"));
+ stages.add(new MockEstimator("b"));
+ stages.add(new MockEstimator("c"));
+ stages.add(new MockTransformer("d"));
- @Test
- public void testPipelineRestore() {
- Pipeline pipeline = new Pipeline();
- pipeline.appendStage(new MockTransformer("a"));
- pipeline.appendStage(new MockEstimator("b"));
- pipeline.appendStage(new MockEstimator("c"));
- pipeline.appendStage(new MockTransformer("d"));
- String pipelineJson = pipeline.toJson();
-
- Pipeline restoredPipeline = new Pipeline(pipelineJson);
- assert describePipeline(restoredPipeline).equals("a_b_c_d");
-
- Pipeline pipelineModel = pipeline.fit(null, null);
- String modelJson = pipelineModel.toJson();
-
- Pipeline restoredPipelineModel = new Pipeline(modelJson);
- assert describePipeline(restoredPipelineModel).equals("a_mb_mc_d");
+ Pipeline pipeline = new Pipeline(stages);
+ assert describePipeline(pipeline.getStages()).equals("a_b_c_d");
+
+ PipelineModel pipelineModel = pipeline.fit(null, null);
+ assert describePipeline(pipelineModel.getStages()).equals("a_mb_mc_d");
}
- private static String describePipeline(Pipeline p) {
+ private static String describePipeline(List<Stage<?>> stages) {
StringBuilder res = new StringBuilder();
- for (PipelineStage s : p.getStages()) {
+ for (Stage<?> s : stages) {
if (res.length() != 0) {
res.append("_");
}
@@ -98,7 +80,7 @@ public class PipelineTest {
}
@Override
- public MockModel fit(TableEnvironment tEnv, Table input) {
+ public MockModel fit(Table... inputs) {
return new MockModel("m" + describe());
}
@@ -111,6 +93,9 @@ public class PipelineTest {
public String describe() {
return get(DESCRIPTION);
}
+
+ @Override
+ public void save(String path) throws IOException {}
}
/** Mock transformer for pipeline test. */
@@ -124,8 +109,8 @@ public class PipelineTest {
}
@Override
- public Table transform(TableEnvironment tEnv, Table input) {
- return input;
+ public Table[] transform(Table... inputs) {
+ return inputs;
}
@Override
@@ -137,6 +122,9 @@ public class PipelineTest {
public String describe() {
return get(DESCRIPTION);
}
+
+ @Override
+ public void save(String path) throws IOException {}
}
/** Mock model for pipeline test. */
@@ -150,8 +138,8 @@ public class PipelineTest {
}
@Override
- public Table transform(TableEnvironment tEnv, Table input) {
- return input;
+ public Table[] transform(Table... inputs) {
+ return inputs;
}
@Override
@@ -163,5 +151,8 @@ public class PipelineTest {
public String describe() {
return get(DESCRIPTION);
}
+
+ @Override
+ public void save(String path) throws IOException {}
}
}