This is an automated email from the ASF dual-hosted git repository.
zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 5f99ce8 [FLINK-29011] Add Transformer for Binarizer
5f99ce8 is described below
commit 5f99ce8687b00c0cd0392a67c677f56a4f121a91
Author: weibo <[email protected]>
AuthorDate: Tue Aug 30 15:19:40 2022 +0800
[FLINK-29011] Add Transformer for Binarizer
This closes #146.
---
.../ml/examples/feature/BinarizerExample.java | 85 +++++++++++
.../flink/ml/feature/binarizer/Binarizer.java | 160 +++++++++++++++++++++
.../ml/feature/binarizer/BinarizerParams.java | 50 +++++++
.../org/apache/flink/ml/feature/BinarizerTest.java | 159 ++++++++++++++++++++
.../examples/ml/feature/binarizer_example.py | 69 +++++++++
.../pyflink/ml/lib/feature/binarizer.py | 70 +++++++++
.../pyflink/ml/lib/feature/tests/test_binarizer.py | 95 ++++++++++++
7 files changed, 688 insertions(+)
diff --git
a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/BinarizerExample.java
b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/BinarizerExample.java
new file mode 100644
index 0000000..0b51363
--- /dev/null
+++
b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/BinarizerExample.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.feature;
+
+import org.apache.flink.ml.feature.binarizer.Binarizer;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import java.util.Arrays;
+
+/** Simple program that creates a Binarizer instance and uses it for feature
engineering. */
+public class BinarizerExample {
+ public static void main(String[] args) {
+ StreamExecutionEnvironment env =
StreamExecutionEnvironment.getExecutionEnvironment();
+ StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+ // Generates input data.
+ DataStream<Row> inputStream =
+ env.fromElements(
+ Row.of(
+ 1,
+ Vectors.dense(1, 2),
+ Vectors.sparse(
+ 17, new int[] {0, 3, 9}, new double[]
{1.0, 2.0, 7.0})),
+ Row.of(
+ 2,
+ Vectors.dense(2, 1),
+ Vectors.sparse(
+ 17, new int[] {0, 2, 14}, new double[]
{5.0, 4.0, 1.0})),
+ Row.of(
+ 3,
+ Vectors.dense(5, 18),
+ Vectors.sparse(
+ 17, new int[] {0, 11, 12}, new
double[] {2.0, 4.0, 4.0})));
+
+ Table inputTable = tEnv.fromDataStream(inputStream).as("f0", "f1",
"f2");
+
+ // Creates a Binarizer object and initializes its parameters.
+ Binarizer binarizer =
+ new Binarizer()
+ .setInputCols("f0", "f1", "f2")
+ .setOutputCols("of0", "of1", "of2")
+ .setThresholds(0.0, 0.0, 0.0);
+
+ // Transforms input data.
+ Table outputTable = binarizer.transform(inputTable)[0];
+
+ // Extracts and displays the results.
+ for (CloseableIterator<Row> it = outputTable.execute().collect();
it.hasNext(); ) {
+ Row row = it.next();
+
+ Object[] inputValues = new Object[binarizer.getInputCols().length];
+ Object[] outputValues = new
Object[binarizer.getInputCols().length];
+ for (int i = 0; i < inputValues.length; i++) {
+ inputValues[i] = row.getField(binarizer.getInputCols()[i]);
+ outputValues[i] = row.getField(binarizer.getOutputCols()[i]);
+ }
+
+ System.out.printf(
+ "Input Values: %s\tOutput Values: %s\n",
+ Arrays.toString(inputValues),
Arrays.toString(outputValues));
+ }
+ }
+}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/Binarizer.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/Binarizer.java
new file mode 100644
index 0000000..bdf1637
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/Binarizer.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.binarizer;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Transformer;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Transformer that binarizes the columns of continuous features by the
given thresholds. The
+ * continuous features may be DenseVector, SparseVector, or numerical values.
+ */
+public class Binarizer implements Transformer<Binarizer>,
BinarizerParams<Binarizer> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ public Binarizer() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public Table[] transform(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
+
+ RowTypeInfo inputTypeInfo =
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+ String[] inputCols = getInputCols();
+ Preconditions.checkArgument(inputCols.length ==
getThresholds().length);
+ TypeInformation<?>[] outputTypes = new
TypeInformation[inputCols.length];
+
+ for (int i = 0; i < inputCols.length; ++i) {
+ int idx = inputTypeInfo.getFieldIndex(inputCols[i]);
+ if (inputTypeInfo.getFieldTypes()[idx] instanceof
SparseVectorTypeInfo) {
+ outputTypes[i] = SparseVectorTypeInfo.INSTANCE;
+ } else if (inputTypeInfo.getFieldTypes()[idx] instanceof
DenseVectorTypeInfo) {
+ outputTypes[i] = DenseVectorTypeInfo.INSTANCE;
+ } else if (inputTypeInfo.getFieldTypes()[idx] instanceof
VectorTypeInfo) {
+ outputTypes[i] = VectorTypeInfo.INSTANCE;
+ } else {
+ outputTypes[i] = Types.DOUBLE;
+ }
+ }
+
+ RowTypeInfo outputTypeInfo =
+ new RowTypeInfo(
+ ArrayUtils.addAll(inputTypeInfo.getFieldTypes(),
outputTypes),
+ ArrayUtils.addAll(inputTypeInfo.getFieldNames(),
getOutputCols()));
+
+ DataStream<Row> output =
+ tEnv.toDataStream(inputs[0])
+ .map(new BinarizeFunction(inputCols, getThresholds()),
outputTypeInfo);
+ Table outputTable = tEnv.fromDataStream(output);
+
+ return new Table[] {outputTable};
+ }
+
+ private static class BinarizeFunction implements MapFunction<Row, Row> {
+ private final String[] inputCols;
+ private final Double[] thresholds;
+
+ public BinarizeFunction(String[] inputCols, Double[] thresholds) {
+ this.inputCols = inputCols;
+ this.thresholds = thresholds;
+ }
+
+ @Override
+ public Row map(Row input) {
+ if (null == input) {
+ return null;
+ }
+
+ Row result = new Row(inputCols.length);
+ for (int i = 0; i < inputCols.length; ++i) {
+ result.setField(i, binarizerFunc(input.getField(inputCols[i]),
thresholds[i]));
+ }
+ return Row.join(input, result);
+ }
+
+ private Object binarizerFunc(Object obj, double threshold) {
+ if (obj instanceof DenseVector) {
+ DenseVector inputVec = (DenseVector) obj;
+ DenseVector vec = inputVec.clone();
+ for (int i = 0; i < vec.size(); ++i) {
+ vec.values[i] = inputVec.get(i) > threshold ? 1.0 : 0.0;
+ }
+ return vec;
+ } else if (obj instanceof SparseVector) {
+ SparseVector inputVec = (SparseVector) obj;
+ int[] newIndices = new int[inputVec.indices.length];
+ int pos = 0;
+
+ for (int i = 0; i < inputVec.indices.length; ++i) {
+ if (inputVec.values[i] > threshold) {
+ newIndices[pos++] = inputVec.indices[i];
+ }
+ }
+
+ double[] newValues = new double[pos];
+ Arrays.fill(newValues, 1.0);
+ return new SparseVector(inputVec.size(),
Arrays.copyOf(newIndices, pos), newValues);
+ } else {
+ return Double.parseDouble(obj.toString()) > threshold ? 1.0 :
0.0;
+ }
+ }
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static Binarizer load(StreamTableEnvironment env, String path)
throws IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/BinarizerParams.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/BinarizerParams.java
new file mode 100644
index 0000000..abbf544
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/BinarizerParams.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.binarizer;
+
+import org.apache.flink.ml.common.param.HasInputCols;
+import org.apache.flink.ml.common.param.HasOutputCols;
+import org.apache.flink.ml.param.DoubleArrayParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link Binarizer}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface BinarizerParams<T> extends HasInputCols<T>, HasOutputCols<T> {
+ Param<Double[]> THRESHOLDS =
+ new DoubleArrayParam(
+ "thresholds",
+ "The thresholds used to binarize continuous features. Each
threshold would be used "
+ + "against one input column. If the value of a
continuous feature is greater than the "
+ + "threshold, it will be binarized to 1.0. If the
value is equal to or less than the "
+ + "threshold, it will be binarized to 0.0.",
+ null,
+ ParamValidators.nonEmptyArray());
+
+ default Double[] getThresholds() {
+ return get(THRESHOLDS);
+ }
+
+ default T setThresholds(Double... value) {
+ return set(THRESHOLDS, value);
+ }
+}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/BinarizerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/BinarizerTest.java
new file mode 100644
index 0000000..9285555
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/BinarizerTest.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.binarizer.Binarizer;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link Binarizer}. */
+public class BinarizerTest extends AbstractTestBase {
+
+ private StreamTableEnvironment tEnv;
+ private Table inputDataTable;
+
+ private static final List<Row> INPUT_DATA =
+ Arrays.asList(
+ Row.of(
+ 1,
+ Vectors.dense(1, 2),
+ Vectors.sparse(17, new int[] {0, 3, 9}, new
double[] {1.0, 2.0, 7.0})),
+ Row.of(
+ 2,
+ Vectors.dense(2, 1),
+ Vectors.sparse(17, new int[] {0, 2, 14}, new
double[] {5.0, 4.0, 1.0})),
+ Row.of(
+ 3,
+ Vectors.dense(5, 18),
+ Vectors.sparse(
+ 17, new int[] {0, 11, 12}, new double[]
{2.0, 4.0, 4.0})));
+
+ private static final Double[] EXPECTED_VALUE_OUTPUT = new Double[] {0.0,
1.0, 1.0};
+
+ private static final List<Vector> EXPECTED_DENSE_OUTPUT =
+ Arrays.asList(
+ Vectors.dense(0.0, 1.0), Vectors.dense(1.0, 0.0),
Vectors.dense(1.0, 1.0));
+
+ private static final List<Vector> EXPECTED_SPARSE_OUTPUT =
+ Arrays.asList(
+ Vectors.sparse(17, new int[] {9}, new double[] {1.0}),
+ Vectors.sparse(17, new int[] {0, 2}, new double[] {1.0,
1.0}),
+ Vectors.sparse(17, new int[] {11, 12}, new double[] {1.0,
1.0}));
+
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH,
true);
+ StreamExecutionEnvironment env =
StreamExecutionEnvironment.getExecutionEnvironment(config);
+
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+
+ tEnv = StreamTableEnvironment.create(env);
+ DataStream<Row> dataStream = env.fromCollection(INPUT_DATA);
+ inputDataTable = tEnv.fromDataStream(dataStream).as("f0", "f1", "f2");
+ }
+
+ private void verifyOutputResult(Table output, String[] outputCols) throws
Exception {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
output).getTableEnvironment();
+ DataStream<Row> stream = tEnv.toDataStream(output);
+
+ List<Row> results = IteratorUtils.toList(stream.executeAndCollect());
+ List<Double> doubleValues = new ArrayList<>(results.size());
+ List<Vector> sparseVectorValues = new ArrayList<>(results.size());
+ List<Vector> denseVectorValues = new ArrayList<>(results.size());
+ for (Row row : results) {
+ doubleValues.add(row.getFieldAs(outputCols[0]));
+ denseVectorValues.add(row.getFieldAs(outputCols[1]));
+ sparseVectorValues.add(row.getFieldAs(outputCols[2]));
+ }
+ doubleValues.sort(Double::compare);
+ assertArrayEquals(EXPECTED_VALUE_OUTPUT, doubleValues.toArray());
+ compareResultCollections(EXPECTED_DENSE_OUTPUT, denseVectorValues,
TestUtils::compare);
+ compareResultCollections(EXPECTED_SPARSE_OUTPUT, sparseVectorValues,
TestUtils::compare);
+ }
+
+ @Test
+ public void testParam() {
+ Binarizer binarizer =
+ new Binarizer()
+ .setInputCols("f0", "f1", "f2")
+ .setOutputCols("of0", "of1", "of2")
+ .setThresholds(0.0, 0.0, 0.0);
+
+ assertArrayEquals(new String[] {"f0", "f1", "f2"},
binarizer.getInputCols());
+ assertArrayEquals(new String[] {"of0", "of1", "of2"},
binarizer.getOutputCols());
+ assertArrayEquals(new Double[] {0.0, 0.0, 0.0},
binarizer.getThresholds());
+ }
+
+ @Test
+ public void testOutputSchema() {
+ Binarizer binarizer =
+ new Binarizer()
+ .setInputCols("f0", "f1", "f2")
+ .setOutputCols("of0", "of1", "of2")
+ .setThresholds(0.0, 0.0, 0.0);
+
+ Table output = binarizer.transform(inputDataTable)[0];
+
+ assertEquals(
+ Arrays.asList("f0", "f1", "f2", "of0", "of1", "of2"),
+ output.getResolvedSchema().getColumnNames());
+ }
+
+ @Test
+ public void testSaveLoadAndTransform() throws Exception {
+ Binarizer binarizer =
+ new Binarizer()
+ .setInputCols("f0", "f1", "f2")
+ .setOutputCols("of0", "of1", "of2")
+ .setThresholds(1.0, 1.5, 2.5);
+
+ Binarizer loadedBinarizer =
+ TestUtils.saveAndReload(
+ tEnv, binarizer,
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+
+ Table output = loadedBinarizer.transform(inputDataTable)[0];
+ verifyOutputResult(output, loadedBinarizer.getOutputCols());
+ }
+}
diff --git a/flink-ml-python/pyflink/examples/ml/feature/binarizer_example.py
b/flink-ml-python/pyflink/examples/ml/feature/binarizer_example.py
new file mode 100644
index 0000000..8b82e6e
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/feature/binarizer_example.py
@@ -0,0 +1,69 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that creates a Binarizer instance and uses it for feature
+# engineering.
+#
+# Before executing this program, please make sure you have followed Flink ML's
+# quick start guideline to set up Flink ML and Flink environment. The guideline
+# can be found at
+#
+#
https://nightlies.apache.org/flink/flink-ml-docs-master/docs/try-flink-ml/quick-start/
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+from pyflink.ml.lib.feature.binarizer import Binarizer
+from pyflink.table import StreamTableEnvironment
+
+# create a new StreamExecutionEnvironment
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# create a StreamTableEnvironment
+t_env = StreamTableEnvironment.create(env)
+
+# generate input data
+input_data_table = t_env.from_data_stream(
+ env.from_collection([
+ (1,
+ Vectors.dense(3, 4)),
+ (2,
+ Vectors.dense(6, 2))
+ ],
+ type_info=Types.ROW_NAMED(
+ ['f0', 'f1'],
+ [Types.INT(), DenseVectorTypeInfo()])))
+
+# create a binarizer object and initialize its parameters
+binarizer = Binarizer() \
+ .set_input_cols('f0', 'f1') \
+ .set_output_cols('of0', 'of1') \
+ .set_thresholds(1.5, 3.5)
+
+# use the binarizer for feature engineering
+output = binarizer.transform(input_data_table)[0]
+
+# extract and display the results
+field_names = output.get_schema().get_field_names()
+input_values = [None for _ in binarizer.get_input_cols()]
+output_values = [None for _ in binarizer.get_output_cols()]
+for result in t_env.to_data_stream(output).execute_and_collect():
+ for i in range(len(binarizer.get_input_cols())):
+ input_values[i] =
result[field_names.index(binarizer.get_input_cols()[i])]
+ output_values[i] =
result[field_names.index(binarizer.get_output_cols()[i])]
+ print('Input Values: ' + str(input_values) + '\tOutput Values: ' +
str(output_values))
diff --git a/flink-ml-python/pyflink/ml/lib/feature/binarizer.py
b/flink-ml-python/pyflink/ml/lib/feature/binarizer.py
new file mode 100644
index 0000000..c17eb6b
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/binarizer.py
@@ -0,0 +1,70 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from typing import Tuple
+
+from pyflink.ml.core.param import ParamValidators, Param, FloatArrayParam
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.feature.common import JavaFeatureTransformer
+from pyflink.ml.lib.param import HasInputCols, HasOutputCols
+
+
+class _BinarizerParams(
+ JavaWithParams,
+ HasInputCols,
+ HasOutputCols
+):
+ """
+ Params for :class:`Binarizer`.
+ """
+
+ THRESHOLDS: Param[Tuple[float, ...]] = FloatArrayParam(
+ "thresholds",
+ "The thresholds used to binarize continuous features. Each threshold
would be used "
+ + "against one input column. If the value of a continuous feature is
greater than the "
+ + "threshold, it will be binarized to 1.0. If the value is equal to or
less than the "
+ + "threshold, it will be binarized to 0.0.",
+ None,
+ ParamValidators.non_empty_array())
+
+ def set_thresholds(self, *thresholds: float):
+ return self.set(self.THRESHOLDS, thresholds)
+
+ def get_thresholds(self) -> Tuple[float, ...]:
+ return self.get(self.THRESHOLDS)
+
+ @property
+ def thresholds(self) -> Tuple[float, ...]:
+ return self.get_thresholds()
+
+
+class Binarizer(JavaFeatureTransformer, _BinarizerParams):
+ """
+ A Transformer that binarizes the columns of continuous features by the
given thresholds.
+ The continuous features may be DenseVector, SparseVector, or numerical
values.
+ """
+
+ def __init__(self, java_model=None):
+ super(Binarizer, self).__init__(java_model)
+
+ @classmethod
+ def _java_transformer_package_name(cls) -> str:
+ return "binarizer"
+
+ @classmethod
+ def _java_transformer_class_name(cls) -> str:
+ return "Binarizer"
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_binarizer.py
b/flink-ml-python/pyflink/ml/lib/feature/tests/test_binarizer.py
new file mode 100644
index 0000000..e153f69
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_binarizer.py
@@ -0,0 +1,95 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import os
+
+from pyflink.common import Types
+
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo,
SparseVectorTypeInfo
+from pyflink.ml.lib.feature.binarizer import Binarizer
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
+class BinarizerTest(PyFlinkMLTestCase):
+ def setUp(self):
+ super(BinarizerTest, self).setUp()
+ self.input_data_table = self.t_env.from_data_stream(
+ self.env.from_collection([
+ (1,
+ Vectors.dense(1, 2),
+ Vectors.sparse(17, [0, 3, 9], [1.0, 2.0, 7.0])),
+ (2,
+ Vectors.dense(2, 1),
+ Vectors.sparse(17, [0, 2, 14], [5.0, 4.0, 1.0])),
+ (3,
+ Vectors.dense(5, 18),
+ Vectors.sparse(17, [0, 11, 12], [2.0, 4.0, 4.0]))
+ ],
+ type_info=Types.ROW_NAMED(
+ ['f0', 'f1', 'f2'],
+ [Types.INT(), DenseVectorTypeInfo(),
SparseVectorTypeInfo()])))
+
+ self.expected_output_data = [[0.0,
+ Vectors.dense(0.0, 1.0),
+ Vectors.sparse(17, [9], [1.0])],
+ [1.0,
+ Vectors.dense(1.0, 0.0),
+ Vectors.sparse(17, [0, 2], [1.0, 1.0])],
+ [1.0,
+ Vectors.dense(1.0, 1.0),
+ Vectors.sparse(17, [11, 12], [1.0,
1.0])]]
+
+ def test_param(self):
+ binarizer = Binarizer()
+
+ binarizer.set_input_cols('f0', 'f1') \
+ .set_output_cols('of0', 'of1') \
+ .set_thresholds(1.5, 2.5)
+
+ self.assertEqual(('f0', 'f1'), binarizer.input_cols)
+ self.assertEqual(('of0', 'of1'), binarizer.output_cols)
+ self.assertEqual((1.5, 2.5), binarizer.get_thresholds())
+
+ def test_save_load_transform(self):
+ binarizer = Binarizer() \
+ .set_input_cols('f0', 'f1', 'f2') \
+ .set_output_cols('of0', 'of1', 'of2') \
+ .set_thresholds(1.0, 1.5, 2.5)
+
+ path = os.path.join(self.temp_dir,
'test_save_load_transform_binarizer')
+ binarizer.save(path)
+ binarizer = Binarizer.load(self.t_env, path)
+
+ output_table = binarizer.transform(self.input_data_table)[0]
+ actual_outputs = [(result[0], result[3], result[4], result[5]) for
result in
+
self.t_env.to_data_stream(output_table).execute_and_collect()]
+
+ self.assertEqual(3, len(actual_outputs))
+
+ actual_outputs.sort()
+
+ for i in range(len(actual_outputs)):
+ actual_output = actual_outputs[i]
+ self.assertAlmostEqual(self.expected_output_data[i][0],
actual_output[1], delta=1.0e-7)
+ self.assertEqual(2, len(actual_output[2]))
+ for j in range(len(actual_output[2])):
+ self.assertAlmostEqual(self.expected_output_data[i][1].get(j),
+ actual_output[2].get(j), delta=1e-7)
+ self.assertEqual(17, len(actual_output[3]))
+ for j in range(len(actual_output[3])):
+ self.assertAlmostEqual(self.expected_output_data[i][2].get(j),
+ actual_output[3].get(j), delta=1e-7)