This is an automated email from the ASF dual-hosted git repository.

zhangzp pushed a commit to branch yuhe_release
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit 9c6ab9e4f48fd7cdf1d41355cf9237ac9f6c6426
Author: weibo <[email protected]>
AuthorDate: Tue Oct 11 10:50:27 2022 +0800

    Add AlgoOperator for Filter, TypeTransform, Splitter on Flink
---
 docs/content/docs/operators/feature/splitter.md    | 144 ++++++
 .../flink/ml/examples/feature/SplitterExample.java |  63 +++
 .../org/apache/flink/ml/feature/filter/Filter.java |  82 +++
 .../flink/ml/feature/filter/FilterParams.java      |  67 +++
 .../apache/flink/ml/feature/splitter/Splitter.java | 113 +++++
 .../flink/ml/feature/splitter/SplitterParams.java  |  57 +++
 .../flink/ml/feature/transform/TypeTransform.java  | 186 +++++++
 .../ml/feature/transform/TypeTransformParams.java  | 191 +++++++
 .../org/apache/flink/ml/feature/FilterTest.java    |  85 ++++
 .../org/apache/flink/ml/feature/SplitterTest.java  | 181 +++++++
 .../apache/flink/ml/feature/TypeTransformTest.java | 547 +++++++++++++++++++++
 .../examples/ml/feature/splitter_example.py        |  60 +++
 flink-ml-python/pyflink/ml/lib/feature/splitter.py |  82 +++
 .../pyflink/ml/lib/feature/tests/test_splitter.py  | 104 ++++
 14 files changed, 1962 insertions(+)

diff --git a/docs/content/docs/operators/feature/splitter.md b/docs/content/docs/operators/feature/splitter.md
new file mode 100644
index 00000000..28f35429
--- /dev/null
+++ b/docs/content/docs/operators/feature/splitter.md
@@ -0,0 +1,144 @@
+---
+title: "Splitter"
+weight: 1
+type: docs
+aliases:
+- /operators/feature/splitter.html
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## Splitter
+
+An AlgoOperator which splits a dataset into two datasets according to a given fraction.
+
+### Parameters
+
+| Key      | Default | Type    | Required | Description                                                   |
+|:---------|:--------|:--------|:---------|:--------------------------------------------------------------|
+| fraction | `0.5`   | Double  | no       | Proportion of data allocated to left output after splitting.  |
+| seed     | `null`  | Integer | no       | The random seed.                                               |
+
+### Examples
+
+{{< tabs examples >}}
+
+{{< tab "Java">}}
+
+```java
+import org.apache.flink.ml.feature.splitter.Splitter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+/** Simple program that creates a Splitter instance and uses it for data splitting. */
+public class SplitterExample {
+       public static void main(String[] args) {
+               StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+               StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+               // Generates input data.
+               DataStream<Row> inputStream =
+                       env.fromElements(
+                               Row.of(1, 10, 0),
+                               Row.of(1, 10, 0),
+                               Row.of(1, 10, 0),
+                               Row.of(4, 10, 0),
+                               Row.of(5, 10, 0),
+                               Row.of(6, 10, 0),
+                               Row.of(7, 10, 0),
+                               Row.of(10, 10, 0),
+                               Row.of(13, 10, 3));
+               Table inputTable = tEnv.fromDataStream(inputStream).as("input");
+
+               // Creates a Splitter object and initializes its parameters.
+               Splitter splitter = new Splitter().setFraction(0.4).setSeed(1);
+
+               // Uses the Splitter to split the input table.
+               Table[] outputTable = splitter.transform(inputTable);
+
+               // Extracts and displays the results.
+               for (CloseableIterator<Row> it = outputTable[0].execute().collect(); it.hasNext(); ) {
+                       System.out.printf("Data 1 : %s\n", it.next());
+               }
+               for (CloseableIterator<Row> it = outputTable[1].execute().collect(); it.hasNext(); ) {
+                       System.out.printf("Data 2 : %s\n", it.next());
+               }
+       }
+}
+
+```
+
+{{< /tab>}}
+
+{{< tab "Python">}}
+
+```python
+# Simple program that creates a Splitter instance and uses it for data
+# splitting.
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.splitter import Splitter
+from pyflink.table import StreamTableEnvironment
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input.
+input_table = t_env.from_data_stream(
+    env.from_collection([
+        (1, 10, 0),
+        (1, 10, 0),
+        (1, 10, 0),
+        (4, 10, 0),
+        (5, 10, 0),
+        (6, 10, 0),
+        (7, 10, 0),
+        (10, 10, 0),
+        (13, 10, 0)
+    ],
+        type_info=Types.ROW_NAMED(
+            ['input', ],
+            [Types.INT(), ])))
+
+# Creates a Splitter object and initializes its parameters.
+splitter = Splitter().set_seed(1).set_fraction(0.4)
+
+# Uses the Splitter to split the dataset.
+output = splitter.transform(input_table)
+
+# Extracts and displays the results.
+for result in t_env.to_data_stream(output[0]).execute_and_collect():
+    print('Data 1: ' + str(result))
+for result in t_env.to_data_stream(output[1]).execute_and_collect():
+    print('Data 2: ' + str(result))
+
+```
+
+{{< /tab>}}
+
+{{< /tabs>}}
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/SplitterExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/SplitterExample.java
new file mode 100644
index 00000000..5d9eb27e
--- /dev/null
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/SplitterExample.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.feature;
+
+import org.apache.flink.ml.feature.splitter.Splitter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+/** Simple program that creates a Splitter instance and uses it for data splitting. */
+public class SplitterExample {
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Generates input data.
+        DataStream<Row> inputStream =
+                env.fromElements(
+                        Row.of(1, 10, 0),
+                        Row.of(1, 10, 0),
+                        Row.of(1, 10, 0),
+                        Row.of(4, 10, 0),
+                        Row.of(5, 10, 0),
+                        Row.of(6, 10, 0),
+                        Row.of(7, 10, 0),
+                        Row.of(10, 10, 0),
+                        Row.of(13, 10, 3));
+        Table inputTable = tEnv.fromDataStream(inputStream).as("input");
+
+        // Creates a Splitter object and initializes its parameters.
+        Splitter splitter = new Splitter().setFraction(0.4).setSeed(1);
+
+        // Uses the Splitter to split the input table.
+        Table[] outputTable = splitter.transform(inputTable);
+
+        // Extracts and displays the results.
+        for (CloseableIterator<Row> it = outputTable[0].execute().collect(); it.hasNext(); ) {
+            System.out.printf("Data 1 : %s\n", it.next());
+        }
+        for (CloseableIterator<Row> it = outputTable[1].execute().collect(); it.hasNext(); ) {
+            System.out.printf("Data 2 : %s\n", it.next());
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/filter/Filter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/filter/Filter.java
new file mode 100644
index 00000000..81ad0ba0
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/filter/Filter.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.filter;
+
+import org.apache.flink.ml.api.Transformer;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.table.api.ApiExpression;
+import org.apache.flink.table.api.Expressions;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/** A Transformer that filters rows of the input table using a SQL where clause. */
+public class Filter implements Transformer<Filter>, FilterParams<Filter> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public Filter() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        String[] inputCols = getInputCols();
+        String[] outputCols = getOutputCols();
+        ApiExpression[] expressions;
+        if (inputCols != null) {
+            expressions = new ApiExpression[inputCols.length];
+
+            for (int i = 0; i < inputCols.length; ++i) {
+                expressions[i] = Expressions.$(inputCols[i]);
+            }
+
+            Preconditions.checkArgument(outputCols.length == inputCols.length);
+            for (int i = 0; i < outputCols.length; ++i) {
+                expressions[i] = expressions[i].as(outputCols[i]);
+            }
+        } else {
+            expressions = new ApiExpression[] {Expressions.$("*")};
+        }
+        Table outputTable =
+                inputs[0].where(Expressions.callSql(getWhereClause())).select(expressions);
+        return new Table[] {outputTable};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static Filter load(StreamTableEnvironment env, String path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+}
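
Filter is added here without a dedicated docs page or example class. Below is a minimal usage sketch of the where-clause path, assuming a hypothetical two-column table; the class name, column names, and rows are illustrative only and not part of the commit.

```java
import org.apache.flink.ml.feature.filter.Filter;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;

/** Hypothetical usage sketch for the Filter transformer added above. */
public class FilterSketch {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Hypothetical input table with an id column and a score column.
        Table input =
                tEnv.fromDataStream(env.fromElements(Row.of(1, 0.5), Row.of(2, 1.5)))
                        .as("id", "score");

        // Keeps only the rows that satisfy the SQL where clause; with no
        // input/output columns configured, all columns pass through unchanged.
        Filter filter = new Filter().setWhereClause("score > 1.0");
        Table output = filter.transform(input)[0];

        output.execute().print();
    }
}
```

When setInputCols and setOutputCols are also supplied, transform selects exactly those columns and renames them to the output names after applying the clause.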
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/filter/FilterParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/filter/FilterParams.java
new file mode 100644
index 00000000..0d22782d
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/filter/FilterParams.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.filter;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringArrayParam;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Params of {@link Filter}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface FilterParams<T> extends WithParams<T> {
+    Param<String[]> INPUT_COLS =
+            new StringArrayParam(
+                    "inputCols", "Input column names.", null, ParamValidators.alwaysTrue());
+
+    Param<String[]> OUTPUT_COLS =
+            new StringArrayParam(
+                    "outputCols", "Output column names.", null, ParamValidators.alwaysTrue());
+
+    Param<String> WHERE_CLAUSE =
+            new StringParam("whereClause", "where clause.", null, ParamValidators.notNull());
+
+    default String[] getOutputCols() {
+        return get(OUTPUT_COLS);
+    }
+
+    default T setOutputCols(String... value) {
+        return set(OUTPUT_COLS, value);
+    }
+
+    default String getWhereClause() {
+        return get(WHERE_CLAUSE);
+    }
+
+    default T setWhereClause(String value) {
+        return set(WHERE_CLAUSE, value);
+    }
+
+    default String[] getInputCols() {
+        return get(INPUT_COLS);
+    }
+
+    default T setInputCols(String... value) {
+        return set(INPUT_COLS, value);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/splitter/Splitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/splitter/Splitter.java
new file mode 100644
index 00000000..773bb2d4
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/splitter/Splitter.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.splitter;
+
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+
+/** An AlgoOperator which splits a Table into two Tables according to the given fraction. */
+public class Splitter implements AlgoOperator<Splitter>, SplitterParams<Splitter> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public Splitter() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        RowTypeInfo outputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+
+        OutputTag<Row> outputTag = new OutputTag<Row>("outputTag", outputTypeInfo) {};
+
+        int seed = getSeed() != null ? getSeed() : new Random().nextInt();
+        SingleOutputStreamOperator<Row> results =
+                tEnv.toDataStream(inputs[0])
+                        .transform(
+                                "SplitterOperator",
+                                outputTypeInfo,
+                                new SplitterOperator(outputTag, getFraction(), seed));
+
+        Table[] outputTables = new Table[2];
+        outputTables[0] = tEnv.fromDataStream(results);
+
+        DataStream<Row> dataStream = results.getSideOutput(outputTag);
+        outputTables[1] = tEnv.fromDataStream(dataStream);
+
+        return outputTables;
+    }
+
+    private static class SplitterOperator extends AbstractStreamOperator<Row>
+            implements OneInputStreamOperator<Row, Row> {
+        private final Random random;
+        OutputTag<Row> outputTag;
+        final double fraction;
+
+        public SplitterOperator(OutputTag<Row> outputTag, double fraction, int seed) {
+            this.outputTag = outputTag;
+            this.fraction = fraction;
+            random = new Random(seed);
+        }
+
+        @Override
+        public void processElement(StreamRecord<Row> streamRecord) throws Exception {
+            if (random.nextDouble() < fraction) {
+                output.collect(streamRecord);
+            } else {
+                output.collect(outputTag, streamRecord);
+            }
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static Splitter load(StreamTableEnvironment env, String path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/splitter/SplitterParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/splitter/SplitterParams.java
new file mode 100644
index 00000000..c316597e
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/splitter/SplitterParams.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.splitter;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Params of {@link Splitter}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface SplitterParams<T> extends WithParams<T> {
+    Param<Double> FRACTION =
+            new DoubleParam(
+                    "fraction",
+                    "Proportion of data allocated to left output after splitting.",
+                    0.5,
+                    ParamValidators.inRange(0.0, 1.0));
+
+    default double getFraction() {
+        return get(FRACTION);
+    }
+
+    default T setFraction(Double value) {
+        return set(FRACTION, value);
+    }
+
+    Param<Integer> SEED = new IntParam("seed", "The random seed.", null);
+
+    default Integer getSeed() {
+        return get(SEED);
+    }
+
+    default T setSeed(int value) {
+        return set(SEED, value);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/transform/TypeTransform.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/transform/TypeTransform.java
new file mode 100644
index 00000000..2b1b3bec
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/transform/TypeTransform.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.transform;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Transformer;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.ApiExpression;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Expressions;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** A Transformer that transforms the types of the specified columns. */
+public class TypeTransform
+        implements Transformer<TypeTransform>, TypeTransformParams<TypeTransform> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final String PREFIX = "typed_";
+
+    public TypeTransform() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        ResolvedSchema schema = inputs[0].getResolvedSchema();
+        String[] allCols = schema.getColumnNames().toArray(new String[0]);
+        List<String> toDoubleCols = Arrays.asList(getToDoubleCols());
+        List<String> toFloatCols = Arrays.asList(getToFloatCols());
+        List<String> toIntCols = Arrays.asList(getToIntCols());
+        List<String> toLongCols = Arrays.asList(getToLongCols());
+        List<String> toStringCols = Arrays.asList(getToStringCols());
+        final boolean keepOrigin = getKeepOldCols();
+        final double defaultDoubleVal = getDefaultDoubleValue();
+        final float defaultFloatVal = getDefaultFloatValue();
+        final long defaultLongVal = getDefaultLongValue();
+        final int defaultIntVal = getDefaultIntValue();
+        final String defaultStringVal = getDefaultStringValue();
+
+        ApiExpression[] expressions =
+                new ApiExpression
+                        [allCols.length
+                                + (keepOrigin
+                                        ? (toDoubleCols.size()
+                                                + toLongCols.size()
+                                                + toStringCols.size()
+                                                + toFloatCols.size()
+                                                + toIntCols.size())
+                                        : 0)];
+        int iter = 0;
+        if (keepOrigin) {
+            for (String colName : allCols) {
+                expressions[iter++] = Expressions.$(colName);
+            }
+        }
+
+        for (String colName : allCols) {
+            if (toDoubleCols.contains(colName)) {
+                expressions[iter++] =
+                        Expressions.$(colName)
+                                .tryCast(DataTypes.DOUBLE())
+                                .as((keepOrigin ? PREFIX : "") + colName);
+            } else if (toFloatCols.contains(colName)) {
+                expressions[iter++] =
+                        Expressions.$(colName)
+                                .tryCast(DataTypes.FLOAT())
+                                .as((keepOrigin ? PREFIX : "") + colName);
+            } else if (toIntCols.contains(colName)) {
+                expressions[iter++] =
+                        Expressions.$(colName)
+                                .tryCast(DataTypes.INT())
+                                .as((keepOrigin ? PREFIX : "") + colName);
+            } else if (toLongCols.contains(colName)) {
+                expressions[iter++] =
+                        Expressions.$(colName)
+                                .tryCast(DataTypes.BIGINT())
+                                .as((keepOrigin ? PREFIX : "") + colName);
+            } else if (toStringCols.contains(colName)) {
+                expressions[iter++] =
+                        Expressions.$(colName)
+                                .tryCast(DataTypes.STRING())
+                                .as((keepOrigin ? PREFIX : "") + colName);
+            } else {
+                if (!keepOrigin) {
+                    expressions[iter++] = Expressions.$(colName);
+                }
+            }
+        }
+
+        Table middleTable = inputs[0].select(expressions);
+        DataStream<Row> outputStream = tEnv.toDataStream(middleTable);
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(middleTable.getResolvedSchema());
+
+        outputStream =
+                outputStream.map(
+                        (MapFunction<Row, Row>)
+                                row -> {
+                                    for (String colName : allCols) {
+                                        if (toDoubleCols.contains(colName)) {
+                                            String tmpName =
+                                                    keepOrigin ? PREFIX + colName : colName;
+                                            if (row.getField(tmpName) == null) {
+                                                row.setField(tmpName, defaultDoubleVal);
+                                            }
+                                        } else if (toFloatCols.contains(colName)) {
+                                            String tmpName =
+                                                    keepOrigin ? PREFIX + colName : colName;
+                                            if (row.getField(tmpName) == null) {
+                                                row.setField(tmpName, defaultFloatVal);
+                                            }
+                                        } else if (toIntCols.contains(colName)) {
+                                            String tmpName =
+                                                    keepOrigin ? PREFIX + colName : colName;
+                                            if (row.getField(tmpName) == null) {
+                                                row.setField(tmpName, defaultIntVal);
+                                            }
+                                        } else if (toLongCols.contains(colName)) {
+                                            String tmpName =
+                                                    keepOrigin ? PREFIX + colName : colName;
+                                            if (row.getField(tmpName) == null) {
+                                                row.setField(tmpName, defaultLongVal);
+                                            }
+                                        } else if (toStringCols.contains(colName)) {
+                                            String tmpName =
+                                                    keepOrigin ? PREFIX + colName : colName;
+                                            if (row.getField(tmpName) == null) {
+                                                row.setField(tmpName, defaultStringVal);
+                                            }
+                                        }
+                                    }
+
+                                    return row;
+                                },
+                        inputTypeInfo);
+        Table outputTable = tEnv.fromDataStream(outputStream);
+        return new Table[] {outputTable};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static TypeTransform load(StreamTableEnvironment env, String path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+}
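
TypeTransform likewise ships in this change without a docs page or example class. Below is a minimal usage sketch against the class added above, assuming a hypothetical table with an int column and a string column; the class name, column names, and data are illustrative only.

```java
import org.apache.flink.ml.feature.transform.TypeTransform;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;

/** Hypothetical usage sketch for the TypeTransform transformer added above. */
public class TypeTransformSketch {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Hypothetical input: an int id and a string column holding numeric text.
        Table input =
                tEnv.fromDataStream(env.fromElements(Row.of(1, "2.5"), Row.of(2, "not-a-number")))
                        .as("id", "raw");

        // Casts "raw" to DOUBLE. With keepOldCols(true) the original column is
        // kept and the cast result is added as "typed_raw"; values that fail the
        // cast (null after tryCast) fall back to the configured default value.
        TypeTransform transform =
                new TypeTransform()
                        .setToDoubleCols("raw")
                        .setKeepOldCols(true)
                        .setDefaultDoubleValue(0.0);
        Table output = transform.transform(input)[0];

        output.execute().print();
    }
}
```

With setKeepOldCols(false), the cast columns replace the originals and keep their original names.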
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/transform/TypeTransformParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/transform/TypeTransformParams.java
new file mode 100644
index 00000000..5749d177
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/transform/TypeTransformParams.java
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.transform;
+
+import org.apache.flink.ml.param.BooleanParam;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.FloatParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.LongParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringArrayParam;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Params of {@link TypeTransform}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface TypeTransformParams<T> extends WithParams<T> {
+
+    Param<Double> DEFAULT_DOUBLE_VALUE =
+            new DoubleParam(
+                    "defaultDoubleValue",
+                    "The default double value.",
+                    0.0,
+                    ParamValidators.alwaysTrue());
+
+    Param<Float> DEFAULT_FLOAT_VALUE =
+            new FloatParam(
+                    "defaultFloatValue",
+                    "The default float value.",
+                    0.0F,
+                    ParamValidators.alwaysTrue());
+    Param<Integer> DEFAULT_INT_VALUE =
+            new IntParam(
+                    "defaultIntValue", "The default int value.", 0, ParamValidators.alwaysTrue());
+
+    Param<Long> DEFAULT_LONG_VALUE =
+            new LongParam(
+                    "defaultLongValue",
+                    "The default long value.",
+                    0L,
+                    ParamValidators.alwaysTrue());
+
+    Param<String> DEFAULT_STRING_VALUE =
+            new StringParam(
+                    "defaultStringValue",
+                    "The default string value.",
+                    "",
+                    ParamValidators.alwaysTrue());
+
+    Param<String[]> TO_DOUBLE_COLS =
+            new StringArrayParam(
+                    "toDoubleCols",
+                    "Input column names to double.",
+                    new String[] {},
+                    ParamValidators.alwaysTrue());
+    Param<String[]> TO_FLOAT_COLS =
+            new StringArrayParam(
+                    "toFloatCols",
+                    "Input column names to float.",
+                    new String[] {},
+                    ParamValidators.alwaysTrue());
+    Param<String[]> TO_INT_COLS =
+            new StringArrayParam(
+                    "toIntCols",
+                    "Input column names to int.",
+                    new String[] {},
+                    ParamValidators.alwaysTrue());
+    Param<String[]> TO_LONG_COLS =
+            new StringArrayParam(
+                    "toLongCols",
+                    "Input column names to long.",
+                    new String[] {},
+                    ParamValidators.alwaysTrue());
+    Param<String[]> TO_STRING_COLS =
+            new StringArrayParam(
+                    "toStringCols",
+                    "Input column names to string.",
+                    new String[] {},
+                    ParamValidators.alwaysTrue());
+
+    Param<Boolean> KEEP_OLD_COLS =
+            new BooleanParam("keepOldCols", "Whether to keep the old columns.", false);
+
+    default Double getDefaultDoubleValue() {
+        return get(DEFAULT_DOUBLE_VALUE);
+    }
+
+    default T setDefaultDoubleValue(Double value) {
+        return set(DEFAULT_DOUBLE_VALUE, value);
+    }
+
+    default Float getDefaultFloatValue() {
+        return get(DEFAULT_FLOAT_VALUE);
+    }
+
+    default T setDefaultFloatValue(Float value) {
+        return set(DEFAULT_FLOAT_VALUE, value);
+    }
+
+    default Integer getDefaultIntValue() {
+        return get(DEFAULT_INT_VALUE);
+    }
+
+    default T setDefaultIntValue(Integer value) {
+        return set(DEFAULT_INT_VALUE, value);
+    }
+
+    default Long getDefaultLongValue() {
+        return get(DEFAULT_LONG_VALUE);
+    }
+
+    default T setDefaultLongValue(Long value) {
+        return set(DEFAULT_LONG_VALUE, value);
+    }
+
+    default String getDefaultStringValue() {
+        return get(DEFAULT_STRING_VALUE);
+    }
+
+    default T setDefaultStringValue(String value) {
+        return set(DEFAULT_STRING_VALUE, value);
+    }
+
+    default String[] getToDoubleCols() {
+        return get(TO_DOUBLE_COLS);
+    }
+
+    default T setToDoubleCols(String... value) {
+        return set(TO_DOUBLE_COLS, value);
+    }
+
+    default String[] getToFloatCols() {
+        return get(TO_FLOAT_COLS);
+    }
+
+    default T setToFloatCols(String... value) {
+        return set(TO_FLOAT_COLS, value);
+    }
+
+    default String[] getToIntCols() {
+        return get(TO_INT_COLS);
+    }
+
+    default T setToIntCols(String... value) {
+        return set(TO_INT_COLS, value);
+    }
+
+    default String[] getToLongCols() {
+        return get(TO_LONG_COLS);
+    }
+
+    default T setToLongCols(String... value) {
+        return set(TO_LONG_COLS, value);
+    }
+
+    default String[] getToStringCols() {
+        return get(TO_STRING_COLS);
+    }
+
+    default T setToStringCols(String... value) {
+        return set(TO_STRING_COLS, value);
+    }
+
+    default boolean getKeepOldCols() {
+        return get(KEEP_OLD_COLS);
+    }
+
+    default T setKeepOldCols(boolean value) {
+        return set(KEEP_OLD_COLS, value);
+    }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FilterTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FilterTest.java
new file mode 100644
index 00000000..a190a209
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FilterTest.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.filter.Filter;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+/** Tests the {@link Filter}. */
+public class FilterTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table inputTable;
+
+    private static final List<Row> inputData =
+            Arrays.asList(
+                    Row.of(1, "-0.5", "0.0", 1.0, 0.0),
+                    Row.of(
+                            2,
+                            "Double.NEGATIVE_INFINITY",
+                            "1.0",
+                            Double.POSITIVE_INFINITY,
+                            Double.NaN),
+                    Row.of(3, "1.0", "1.0", -0.5, null));
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        inputTable =
+                tEnv.fromDataStream(env.fromCollection(inputData)).as("id", "f1", "f2", "f3", "f4");
+    }
+
+    @Test
+    public void testTransform() throws Exception {
+        Filter filter =
+                new Filter()
+                        // .setInputCols("id", "f1", "f2", "f3", "f4")
+                        // .setOutputCols("id", "c1", "c2", "c4", "f4")
+                        // .setWhereClause("id=1 and f3<>-10");
+                        .setWhereClause("f1 rlike f2");
+        Table output = filter.transform(inputTable)[0];
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/SplitterTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/SplitterTest.java
new file mode 100644
index 00000000..bc779524
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/SplitterTest.java
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.splitter.Splitter;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link Splitter}. */
+public class SplitterTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(2);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+
+        tEnv = StreamTableEnvironment.create(env);
+    }
+
+    private Table getTable(int size) {
+        List<Row> data = new ArrayList<>();
+        for (int i = 0; i < size; ++i) {
+            data.add(Row.of(i));
+        }
+        return tEnv.fromDataStream(env.fromCollection(data));
+    }
+
+    private void verifyOutputResult(Table output1, Table output2, int outputSize) throws Exception {
+        List<Integer> results1 =
+                IteratorUtils.toList(
+                        tEnv.toDataStream(output1)
+                                .map((MapFunction<Row, Integer>) row -> row.getFieldAs(0))
+                                .executeAndCollect());
+        List<Integer> results2 =
+                IteratorUtils.toList(
+                        tEnv.toDataStream(output2)
+                                .map((MapFunction<Row, Integer>) row -> row.getFieldAs(0))
+                                .executeAndCollect());
+
+        assertEquals(outputSize, results2.size());
+        assertEquals(outputSize, results1.size());
+
+        results1.sort(Integer::compare);
+        results2.sort(Integer::compare);
+        assertArrayEquals(results1.toArray(), results2.toArray());
+    }
+
+    @Test
+    public void testParam() {
+        Splitter splitter = new Splitter();
+        assertEquals(0.5, splitter.getFraction(), 1.0e-5);
+        assertNull(splitter.getSeed());
+
+        splitter.setSeed(2).setFraction(0.3);
+        assertEquals(0.3, splitter.getFraction(), 1.0e-5);
+        assertEquals(Integer.valueOf(2), splitter.getSeed());
+    }
+
+    @Test
+    public void testOutputSchema() {
+        Table tempTable =
+                tEnv.fromDataStream(env.fromElements(Row.of("", "")))
+                        .as("test_input", "dummy_input");
+
+        Splitter splitter = new Splitter();
+        Table output = splitter.transform(tempTable)[0];
+
+        assertEquals(
+                Arrays.asList("test_input", "dummy_input"),
+                output.getResolvedSchema().getColumnNames());
+    }
+
+    @Test
+    public void testFraction() throws Exception {
+        Table data = getTable(100);
+        Splitter splitter = new Splitter().setFraction(0.4).setSeed(2);
+        Table[] output = splitter.transform(data);
+        List<Integer> results1 =
+                IteratorUtils.toList(
+                        tEnv.toDataStream(output[0])
+                                .map((MapFunction<Row, Integer>) row -> row.getFieldAs(0))
+                                .executeAndCollect());
+        List<Integer> results2 =
+                IteratorUtils.toList(
+                        tEnv.toDataStream(output[1])
+                                .map((MapFunction<Row, Integer>) row -> row.getFieldAs(0))
+                                .executeAndCollect());
+        assertEquals(results1.size(), 42);
+        assertEquals(results2.size(), 58);
+    }
+
+    @Test
+    public void testSaveLoadAndTransform() throws Exception {
+        Table data = getTable(100);
+        Splitter splitter = new Splitter().setFraction(0.4).setSeed(2);
+
+        Splitter splitterLoad =
+                TestUtils.saveAndReload(
+                        tEnv, splitter, TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+
+        Table[] output = splitterLoad.transform(data);
+
+        List<Integer> results1 =
+                IteratorUtils.toList(
+                        tEnv.toDataStream(output[0])
+                                .map((MapFunction<Row, Integer>) row -> row.getFieldAs(0))
+                                .executeAndCollect());
+        List<Integer> results2 =
+                IteratorUtils.toList(
+                        tEnv.toDataStream(output[1])
+                                .map((MapFunction<Row, Integer>) row -> row.getFieldAs(0))
+                                .executeAndCollect());
+        assertEquals(results1.size(), 42);
+        assertEquals(results2.size(), 58);
+    }
+
+    @Test
+    public void testSeed() throws Exception {
+        Table data = getTable(20);
+        Splitter splitter = new Splitter().setFraction(0.4).setSeed(1);
+
+        Table[] output1 = splitter.transform(data);
+        Table[] output2 = splitter.transform(data);
+        Table[] output3 = splitter.transform(data);
+
+        verifyOutputResult(output1[0], output2[0], 6);
+        verifyOutputResult(output1[1], output2[1], 14);
+
+        verifyOutputResult(output1[0], output3[0], 6);
+        verifyOutputResult(output1[1], output3[1], 14);
+
+        verifyOutputResult(output2[0], output3[0], 6);
+        verifyOutputResult(output2[1], output3[1], 14);
+    }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/TypeTransformTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/TypeTransformTest.java
new file mode 100644
index 00000000..01039f9f
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/TypeTransformTest.java
@@ -0,0 +1,547 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.transform.TypeTransform;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.math.BigDecimal;
+import java.util.Arrays;
+import java.util.List;
+
+/** Tests the {@link TypeTransform}. */
+public class TypeTransformTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table testTable;
+
+    private static final List<Row> testData =
+            Arrays.asList(
+                    Row.of(1, 1.0, "1.0"),
+                    Row.of(null, 2.3, "3.0"),
+                    Row.of(2, null, "3.0"),
+                    Row.of(2, 3.0, null),
+                    Row.of(1, 2.3, "ddsef"));
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        testTable = tEnv.fromDataStream(env.fromCollection(testData)).as("f0", "f1", "class");
+    }
+
+    Table getTable(Row r1, Row r2) {
+        List<Row> inputData = Arrays.asList(r1, r2);
+        return tEnv.fromDataStream(env.fromCollection(inputData)).as("f1", "f2");
+    }
+
+    @Test
+    public void test() throws Exception {
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToLongCols("f0")
+                        .setToDoubleCols("f1", "class")
+                        .setKeepOldCols(true)
+                        .setDefaultLongValue(1000L)
+                        .setDefaultDoubleValue(1000.0);
+        Table output = transform.transform(testTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testIntToLong() throws Exception {
+        Table inputTable = getTable(Row.of(1, Integer.MAX_VALUE), Row.of(3, null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToLongCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultLongValue(1000L);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testIntToFloat() throws Exception {
+        Table inputTable = getTable(Row.of(1, Integer.MAX_VALUE), Row.of(3, null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToFloatCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultFloatValue(1000.0F);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testIntToDouble() throws Exception {
+        Table inputTable = getTable(Row.of(1, Integer.MAX_VALUE), Row.of(3, null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToDoubleCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultDoubleValue(1000.0);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testIntToString() throws Exception {
+        Table inputTable = getTable(Row.of(1, Integer.MAX_VALUE), Row.of(3, null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToStringCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultStringValue("1000.0F");
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testDecimalToLong() throws Exception {
+        Table inputTable =
+                getTable(
+                        Row.of(new BigDecimal("1.0"), BigDecimal.ZERO),
+                        Row.of(BigDecimal.ONE, null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToLongCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultLongValue(1000L);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testDecimalToFloat() throws Exception {
+        Table inputTable =
+                getTable(
+                        Row.of(new BigDecimal("1.0"), BigDecimal.ZERO),
+                        Row.of(BigDecimal.ONE, null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToFloatCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultFloatValue(1000.0F);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testDecimalToDouble() throws Exception {
+        Table inputTable =
+                getTable(
+                        Row.of(new BigDecimal("1.0"), BigDecimal.ZERO),
+                        Row.of(BigDecimal.ONE, null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToDoubleCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultDoubleValue(1000.0);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testDecimalToString() throws Exception {
+        Table inputTable =
+                getTable(
+                        Row.of(new BigDecimal("1.0"), BigDecimal.ZERO),
+                        Row.of(BigDecimal.ONE, null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToStringCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultStringValue("1000.0F");
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testDecimalToInt() throws Exception {
+        Table inputTable =
+                getTable(
+                        Row.of(new BigDecimal("1.0"), BigDecimal.ZERO),
+                        Row.of(BigDecimal.ONE, null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToIntCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultIntValue(1000);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testLongToFloat() throws Exception {
+        Table inputTable = getTable(Row.of(1L, Long.MIN_VALUE), Row.of(3L, null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToFloatCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultFloatValue(1000.0F);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testLongToDouble() throws Exception {
+        Table inputTable = getTable(Row.of(1L, Long.MIN_VALUE), Row.of(3L, null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToDoubleCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultDoubleValue(1000.0);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testLongToString() throws Exception {
+        Table inputTable = getTable(Row.of(1L, Long.MIN_VALUE), Row.of(3L, null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToStringCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultStringValue("1000.0F");
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testLongToInt() throws Exception {
+        Table inputTable = getTable(Row.of(1L, Long.MIN_VALUE), Row.of(3L, null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToIntCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultIntValue(1000);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testDoubleToLong() throws Exception {
+        Table inputTable = getTable(Row.of(1.0, Double.MAX_VALUE), Row.of(2.0, null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToLongCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultLongValue(1000L);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testDoubleToFloat() throws Exception {
+        Table inputTable = getTable(Row.of(1.0, Double.MAX_VALUE), Row.of(2.0, null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToFloatCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultFloatValue(1000.0F);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testDoubleToString() throws Exception {
+        Table inputTable = getTable(Row.of(1.0, Double.MAX_VALUE), Row.of(2.0, null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToStringCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultStringValue("1000.0F");
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testDoubleToInt() throws Exception {
+        Table inputTable = getTable(Row.of(1.0, Double.MAX_VALUE), Row.of(2.0, null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToIntCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultIntValue(1000);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testFloatToLong() throws Exception {
+        Table inputTable = getTable(Row.of(1.0f, Float.MAX_VALUE), Row.of(2.0f, null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToLongCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultLongValue(1000L);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testFloatToDouble() throws Exception {
+        Table inputTable = getTable(Row.of(1.0f, Float.MAX_VALUE), Row.of(2.0f, null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToDoubleCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultDoubleValue(1000.0);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testFloatToString() throws Exception {
+        Table inputTable = getTable(Row.of(1.0f, Float.MAX_VALUE), Row.of(2.0f, null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToStringCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultStringValue("1000.0F");
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testFloatToInt() throws Exception {
+        Table inputTable = getTable(Row.of(1.0f, Float.MAX_VALUE), Row.of(2.0f, null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToIntCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultIntValue(1000);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testStringToLong() throws Exception {
+        Table inputTable = getTable(Row.of("1.0", "2.0"), Row.of("2.0", null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToLongCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultLongValue(1000L);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testStringToDouble() throws Exception {
+        Table inputTable = getTable(Row.of("1.0", "2.0"), Row.of("2.0", null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToDoubleCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultDoubleValue(1000.0);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testStringToFloat() throws Exception {
+        Table inputTable = getTable(Row.of("1.0", "2.0"), Row.of("2.0", null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToFloatCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultFloatValue(1000.0F);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testStringToInt() throws Exception {
+        Table inputTable = getTable(Row.of("1.0", "2.6"), Row.of("2.0", null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToIntCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultIntValue(1000);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testErrStringToInt() throws Exception {
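+        // "d212.6ded" is not parseable as a number; the conversion is expected to fall
+        // back to the default int value configured below.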
+        Table inputTable = getTable(Row.of("1.0", "d212.6ded"), Row.of("2.0", null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToIntCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultIntValue(1000);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+}
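
Taken together, the tests above show the public surface of TypeTransform. Below is a minimal, self-contained sketch of how that surface might be driven from application code; it uses only the setters that appear in the tests, and the class name, input rows, and column names are illustrative assumptions rather than part of this commit.

```java
import org.apache.flink.ml.feature.transform.TypeTransform;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;

import java.util.Arrays;

/** Sketch: drives TypeTransform with the same configuration as the combined test above. */
public class TypeTransformSketch {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Illustrative input: an int feature, a string feature and a string label column.
        Table input =
                tEnv.fromDataStream(
                                env.fromCollection(
                                        Arrays.asList(Row.of(1, "2.5", "0"), Row.of(2, "3.5", "1"))))
                        .as("f0", "f1", "class");

        // "f0" to Long, "f1"/"class" to Double, original columns kept, defaults of 1000.
        TypeTransform transform =
                new TypeTransform()
                        .setToLongCols("f0")
                        .setToDoubleCols("f1", "class")
                        .setKeepOldCols(true)
                        .setDefaultLongValue(1000L)
                        .setDefaultDoubleValue(1000.0);

        Table output = transform.transform(input)[0];
        System.out.println(output.getResolvedSchema());
        tEnv.toDataStream(output).executeAndCollect().forEachRemaining(System.out::println);
    }
}
```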
diff --git a/flink-ml-python/pyflink/examples/ml/feature/splitter_example.py b/flink-ml-python/pyflink/examples/ml/feature/splitter_example.py
new file mode 100644
index 00000000..58bb417e
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/feature/splitter_example.py
@@ -0,0 +1,60 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that creates a Splitter instance and uses it for data
+# splitting.
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.splitter import Splitter
+from pyflink.table import StreamTableEnvironment
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input table.
+input_table = t_env.from_data_stream(
+    env.from_collection([
+        (1, 10, 0),
+        (1, 10, 0),
+        (1, 10, 0),
+        (4, 10, 0),
+        (5, 10, 0),
+        (6, 10, 0),
+        (7, 10, 0),
+        (10, 10, 0),
+        (13, 10, 0)
+    ],
+        type_info=Types.ROW_NAMED(
+            ['f0', 'f1', "f2"],
+            [Types.INT(), Types.INT(), Types.INT()])))
+
+# Creates a Splitter object and initializes its parameters.
+splitter = Splitter().set_seed(1).set_fraction(0.4)
+
+# Uses the Splitter to split the dataset.
+output = splitter.transform(input_table)
+
+# Extracts and displays the results.
+for result in t_env.to_data_stream(output[0]).execute_and_collect():
+    print('Data 1: ' + str(result))
+for result in t_env.to_data_stream(output[1]).execute_and_collect():
+    print('Data 2: ' + str(result))
diff --git a/flink-ml-python/pyflink/ml/lib/feature/splitter.py b/flink-ml-python/pyflink/ml/lib/feature/splitter.py
new file mode 100644
index 00000000..218bf874
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/splitter.py
@@ -0,0 +1,82 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import typing
+
+from pyflink.ml.core.param import Param, FloatParam, IntParam, ParamValidators
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.feature.common import JavaFeatureTransformer
+
+
+class _SplitterParams(
+    JavaWithParams
+):
+    """
+    Params for :class:`Splitter`.
+    """
+
+    FRACTION: Param[float] = FloatParam(
+        "fraction",
+        "Proportion of data allocated to left output after splitting.",
+        0.5,
+        ParamValidators.in_range(0.0, 1.0))
+
+    SEED: Param[int] = IntParam(
+        "seed",
+        "The random seed.",
+        None,
+        ParamValidators.always_true())
+
+    def __init__(self, java_params):
+        super(_SplitterParams, self).__init__(java_params)
+
+    def set_fraction(self, value: float):
+        return typing.cast(_SplitterParams, self.set(self.FRACTION, value))
+
+    def set_seed(self, value: int):
+        return typing.cast(_SplitterParams, self.set(self.SEED, value))
+
+    def get_fraction(self) -> float:
+        return self.get(self.FRACTION)
+
+    def get_seed(self) -> int:
+        return self.get(self.SEED)
+
+    @property
+    def fraction(self):
+        return self.get_fraction()
+
+    @property
+    def seed(self):
+        return self.get_seed()
+
+
+class Splitter(JavaFeatureTransformer, _SplitterParams):
+    """
+    An AlgoOperator which splits a dataset into two datasets according to a given fraction.
+    """
+
+    def __init__(self, java_model=None):
+        super(Splitter, self).__init__(java_model)
+
+    @classmethod
+    def _java_transformer_package_name(cls) -> str:
+        return "splitter"
+
+    @classmethod
+    def _java_transformer_class_name(cls) -> str:
+        return "Splitter"
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_splitter.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_splitter.py
new file mode 100644
index 00000000..ec31561c
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_splitter.py
@@ -0,0 +1,104 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import os
+
+from pyflink.common import Types
+
+from pyflink.ml.lib.feature.splitter import Splitter
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
+class SplitterTest(PyFlinkMLTestCase):
+    def setUp(self):
+        super(SplitterTest, self).setUp()
+        self.input_table = self.t_env.from_data_stream(
+            self.env.from_collection([
+                (1,),
+                (2,),
+                (3,),
+                (4,),
+                (5,),
+                (6,),
+                (7,),
+                (8,),
+                (9,),
+                (0,)
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['f0', ],
+                    [Types.INT(), ])))
+
+        self.expect_data1 = [1, 3, 4, 7]
+        self.expect_data2 = [0, 2, 5, 6, 8, 9]
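+        # Expected split results for seed=1 and fraction=0.4, the settings used below.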
+
+    def test_param(self):
+        splitter = Splitter()
+        self.assertEqual(0.5, splitter.fraction)
+
+        splitter \
+            .set_seed(1) \
+            .set_fraction(0.2)
+
+        self.assertEqual(0.2, splitter.fraction)
+        self.assertEqual(1, splitter.seed)
+
+    def test_output_schema(self):
+        splitter = Splitter()
+        input_data_table = self.t_env.from_data_stream(
+            self.env.from_collection([
+                ('', ''),
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['test_input', 'dummy_input'],
+                    [Types.STRING(), Types.STRING()])))
+        output = splitter \
+            .transform(input_data_table)[0]
+
+        self.assertEqual(
+            ['test_input', 'dummy_input'],
+            output.get_schema().get_field_names())
+
+    def verify_split_result(self, expected1, expected2, output_table1, output_table2):
+        predicted_results1 = [result[0] for result in
+                              self.t_env.to_data_stream(output_table1).execute_and_collect()]
+        predicted_results1.sort(key=lambda x: (x))
+        expected1.sort(key=lambda x: (x))
+        self.assertEqual(expected1, predicted_results1)
+
+        predicted_results2 = [result[0] for result in
+                              self.t_env.to_data_stream(output_table2).execute_and_collect()]
+        predicted_results2.sort(key=lambda x: (x))
+        expected2.sort(key=lambda x: (x))
+        self.assertEqual(expected2, predicted_results2)
+
+    def test_fit_and_predict(self):
+        splitter = Splitter().set_seed(1).set_fraction(0.4)
+
+        output = splitter.transform(self.input_table)
+        self.verify_split_result(self.expect_data1, self.expect_data2, output[0], output[1])
+
+    def test_save_load_predict(self):
+        splitter = Splitter().set_seed(1).set_fraction(0.4)
+
+        estimator_path = os.path.join(self.temp_dir, 'test_save_load_predict_splitter')
+        splitter.save(estimator_path)
+        splitter = Splitter.load(self.t_env, estimator_path)
+
+        output = splitter.transform(self.input_table)
+        self.verify_split_result(self.expect_data1, self.expect_data2, output[0], output[1])
