This is an automated email from the ASF dual-hosted git repository.

zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 5f99ce8  [FLINK-29011] Add Transformer for Binarizer
5f99ce8 is described below

commit 5f99ce8687b00c0cd0392a67c677f56a4f121a91
Author: weibo <[email protected]>
AuthorDate: Tue Aug 30 15:19:40 2022 +0800

    [FLINK-29011] Add Transformer for Binarizer
    
    This closes #146.
---
 .../ml/examples/feature/BinarizerExample.java      |  85 +++++++++++
 .../flink/ml/feature/binarizer/Binarizer.java      | 160 +++++++++++++++++++++
 .../ml/feature/binarizer/BinarizerParams.java      |  50 +++++++
 .../org/apache/flink/ml/feature/BinarizerTest.java | 159 ++++++++++++++++++++
 .../examples/ml/feature/binarizer_example.py       |  69 +++++++++
 .../pyflink/ml/lib/feature/binarizer.py            |  70 +++++++++
 .../pyflink/ml/lib/feature/tests/test_binarizer.py |  95 ++++++++++++
 7 files changed, 688 insertions(+)

diff --git 
a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/BinarizerExample.java
 
b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/BinarizerExample.java
new file mode 100644
index 0000000..0b51363
--- /dev/null
+++ 
b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/BinarizerExample.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.feature;
+
+import org.apache.flink.ml.feature.binarizer.Binarizer;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import java.util.Arrays;
+
+/** Simple program that creates a Binarizer instance and uses it for feature 
engineering. */
+public class BinarizerExample {
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Generates input data.
+        DataStream<Row> inputStream =
+                env.fromElements(
+                        Row.of(
+                                1,
+                                Vectors.dense(1, 2),
+                                Vectors.sparse(
+                                        17, new int[] {0, 3, 9}, new double[] 
{1.0, 2.0, 7.0})),
+                        Row.of(
+                                2,
+                                Vectors.dense(2, 1),
+                                Vectors.sparse(
+                                        17, new int[] {0, 2, 14}, new double[] 
{5.0, 4.0, 1.0})),
+                        Row.of(
+                                3,
+                                Vectors.dense(5, 18),
+                                Vectors.sparse(
+                                        17, new int[] {0, 11, 12}, new 
double[] {2.0, 4.0, 4.0})));
+
+        Table inputTable = tEnv.fromDataStream(inputStream).as("f0", "f1", 
"f2");
+
+        // Creates a Binarizer object and initializes its parameters.
+        Binarizer binarizer =
+                new Binarizer()
+                        .setInputCols("f0", "f1", "f2")
+                        .setOutputCols("of0", "of1", "of2")
+                        .setThresholds(0.0, 0.0, 0.0);
+
+        // Transforms input data.
+        Table outputTable = binarizer.transform(inputTable)[0];
+
+        // Extracts and displays the results.
+        for (CloseableIterator<Row> it = outputTable.execute().collect(); 
it.hasNext(); ) {
+            Row row = it.next();
+
+            Object[] inputValues = new Object[binarizer.getInputCols().length];
+            Object[] outputValues = new 
Object[binarizer.getInputCols().length];
+            for (int i = 0; i < inputValues.length; i++) {
+                inputValues[i] = row.getField(binarizer.getInputCols()[i]);
+                outputValues[i] = row.getField(binarizer.getOutputCols()[i]);
+            }
+
+            System.out.printf(
+                    "Input Values: %s\tOutput Values: %s\n",
+                    Arrays.toString(inputValues), 
Arrays.toString(outputValues));
+        }
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/Binarizer.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/Binarizer.java
new file mode 100644
index 0000000..bdf1637
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/Binarizer.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.binarizer;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Transformer;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Transformer that binarizes the columns of continuous features by the 
given thresholds. The
+ * continuous features may be DenseVector, SparseVector, or Numerical Value.
+ */
+public class Binarizer implements Transformer<Binarizer>, 
BinarizerParams<Binarizer> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public Binarizer() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        String[] inputCols = getInputCols();
+        Preconditions.checkArgument(inputCols.length == 
getThresholds().length);
+        TypeInformation<?>[] outputTypes = new 
TypeInformation[inputCols.length];
+
+        for (int i = 0; i < inputCols.length; ++i) {
+            int idx = inputTypeInfo.getFieldIndex(inputCols[i]);
+            if (inputTypeInfo.getFieldTypes()[idx] instanceof 
SparseVectorTypeInfo) {
+                outputTypes[i] = SparseVectorTypeInfo.INSTANCE;
+            } else if (inputTypeInfo.getFieldTypes()[idx] instanceof 
DenseVectorTypeInfo) {
+                outputTypes[i] = DenseVectorTypeInfo.INSTANCE;
+            } else if (inputTypeInfo.getFieldTypes()[idx] instanceof 
VectorTypeInfo) {
+                outputTypes[i] = VectorTypeInfo.INSTANCE;
+            } else {
+                outputTypes[i] = Types.DOUBLE;
+            }
+        }
+
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), 
outputTypes),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getOutputCols()));
+
+        DataStream<Row> output =
+                tEnv.toDataStream(inputs[0])
+                        .map(new BinarizeFunction(inputCols, getThresholds()), 
outputTypeInfo);
+        Table outputTable = tEnv.fromDataStream(output);
+
+        return new Table[] {outputTable};
+    }
+
+    private static class BinarizeFunction implements MapFunction<Row, Row> {
+        private final String[] inputCols;
+        private final Double[] thresholds;
+
+        public BinarizeFunction(String[] inputCols, Double[] thresholds) {
+            this.inputCols = inputCols;
+            this.thresholds = thresholds;
+        }
+
+        @Override
+        public Row map(Row input) {
+            if (null == input) {
+                return null;
+            }
+
+            Row result = new Row(inputCols.length);
+            for (int i = 0; i < inputCols.length; ++i) {
+                result.setField(i, binarizerFunc(input.getField(inputCols[i]), 
thresholds[i]));
+            }
+            return Row.join(input, result);
+        }
+
+        private Object binarizerFunc(Object obj, double threshold) {
+            if (obj instanceof DenseVector) {
+                DenseVector inputVec = (DenseVector) obj;
+                DenseVector vec = inputVec.clone();
+                for (int i = 0; i < vec.size(); ++i) {
+                    vec.values[i] = inputVec.get(i) > threshold ? 1.0 : 0.0;
+                }
+                return vec;
+            } else if (obj instanceof SparseVector) {
+                SparseVector inputVec = (SparseVector) obj;
+                int[] newIndices = new int[inputVec.indices.length];
+                int pos = 0;
+
+                for (int i = 0; i < inputVec.indices.length; ++i) {
+                    if (inputVec.values[i] > threshold) {
+                        newIndices[pos++] = inputVec.indices[i];
+                    }
+                }
+
+                double[] newValues = new double[pos];
+                Arrays.fill(newValues, 1.0);
+                return new SparseVector(inputVec.size(), 
Arrays.copyOf(newIndices, pos), newValues);
+            } else {
+                return Double.parseDouble(obj.toString()) > threshold ? 1.0 : 
0.0;
+            }
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static Binarizer load(StreamTableEnvironment env, String path) 
throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/BinarizerParams.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/BinarizerParams.java
new file mode 100644
index 0000000..abbf544
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/BinarizerParams.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.binarizer;
+
+import org.apache.flink.ml.common.param.HasInputCols;
+import org.apache.flink.ml.common.param.HasOutputCols;
+import org.apache.flink.ml.param.DoubleArrayParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link Binarizer}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface BinarizerParams<T> extends HasInputCols<T>, HasOutputCols<T> {
+    Param<Double[]> THRESHOLDS =
+            new DoubleArrayParam(
+                    "thresholds",
+                    "The thresholds used to binarize continuous features. Each 
threshold would be used "
+                            + "against one input column. If the value of a 
continuous feature is greater than the "
+                            + "threshold, it will be binarized to 1.0. If the 
value is equal to or less than the "
+                            + "threshold, it will be binarized to 0.0.",
+                    null,
+                    ParamValidators.nonEmptyArray());
+
+    default Double[] getThresholds() {
+        return get(THRESHOLDS);
+    }
+
+    default T setThresholds(Double... value) {
+        return set(THRESHOLDS, value);
+    }
+}
diff --git 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/BinarizerTest.java 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/BinarizerTest.java
new file mode 100644
index 0000000..9285555
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/BinarizerTest.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.binarizer.Binarizer;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link Binarizer}. */
+public class BinarizerTest extends AbstractTestBase {
+
+    private StreamTableEnvironment tEnv;
+    private Table inputDataTable;
+
+    private static final List<Row> INPUT_DATA =
+            Arrays.asList(
+                    Row.of(
+                            1,
+                            Vectors.dense(1, 2),
+                            Vectors.sparse(17, new int[] {0, 3, 9}, new 
double[] {1.0, 2.0, 7.0})),
+                    Row.of(
+                            2,
+                            Vectors.dense(2, 1),
+                            Vectors.sparse(17, new int[] {0, 2, 14}, new 
double[] {5.0, 4.0, 1.0})),
+                    Row.of(
+                            3,
+                            Vectors.dense(5, 18),
+                            Vectors.sparse(
+                                    17, new int[] {0, 11, 12}, new double[] 
{2.0, 4.0, 4.0})));
+
+    private static final Double[] EXPECTED_VALUE_OUTPUT = new Double[] {0.0, 
1.0, 1.0};
+
+    private static final List<Vector> EXPECTED_DENSE_OUTPUT =
+            Arrays.asList(
+                    Vectors.dense(0.0, 1.0), Vectors.dense(1.0, 0.0), 
Vectors.dense(1.0, 1.0));
+
+    private static final List<Vector> EXPECTED_SPARSE_OUTPUT =
+            Arrays.asList(
+                    Vectors.sparse(17, new int[] {9}, new double[] {1.0}),
+                    Vectors.sparse(17, new int[] {0, 2}, new double[] {1.0, 
1.0}),
+                    Vectors.sparse(17, new int[] {11, 12}, new double[] {1.0, 
1.0}));
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+        StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment(config);
+
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+
+        tEnv = StreamTableEnvironment.create(env);
+        DataStream<Row> dataStream = env.fromCollection(INPUT_DATA);
+        inputDataTable = tEnv.fromDataStream(dataStream).as("f0", "f1", "f2");
+    }
+
+    private void verifyOutputResult(Table output, String[] outputCols) throws 
Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
output).getTableEnvironment();
+        DataStream<Row> stream = tEnv.toDataStream(output);
+
+        List<Row> results = IteratorUtils.toList(stream.executeAndCollect());
+        List<Double> doubleValues = new ArrayList<>(results.size());
+        List<Vector> sparseVectorValues = new ArrayList<>(results.size());
+        List<Vector> denseVectorValues = new ArrayList<>(results.size());
+        for (Row row : results) {
+            doubleValues.add(row.getFieldAs(outputCols[0]));
+            denseVectorValues.add(row.getFieldAs(outputCols[1]));
+            sparseVectorValues.add(row.getFieldAs(outputCols[2]));
+        }
+        doubleValues.sort(Double::compare);
+        assertArrayEquals(EXPECTED_VALUE_OUTPUT, doubleValues.toArray());
+        compareResultCollections(EXPECTED_DENSE_OUTPUT, denseVectorValues, 
TestUtils::compare);
+        compareResultCollections(EXPECTED_SPARSE_OUTPUT, sparseVectorValues, 
TestUtils::compare);
+    }
+
+    @Test
+    public void testParam() {
+        Binarizer binarizer =
+                new Binarizer()
+                        .setInputCols("f0", "f1", "f2")
+                        .setOutputCols("of0", "of1", "of2")
+                        .setThresholds(0.0, 0.0, 0.0);
+
+        assertArrayEquals(new String[] {"f0", "f1", "f2"}, 
binarizer.getInputCols());
+        assertArrayEquals(new String[] {"of0", "of1", "of2"}, 
binarizer.getOutputCols());
+        assertArrayEquals(new Double[] {0.0, 0.0, 0.0}, 
binarizer.getThresholds());
+    }
+
+    @Test
+    public void testOutputSchema() {
+        Binarizer binarizer =
+                new Binarizer()
+                        .setInputCols("f0", "f1", "f2")
+                        .setOutputCols("of0", "of1", "of2")
+                        .setThresholds(0.0, 0.0, 0.0);
+
+        Table output = binarizer.transform(inputDataTable)[0];
+
+        assertEquals(
+                Arrays.asList("f0", "f1", "f2", "of0", "of1", "of2"),
+                output.getResolvedSchema().getColumnNames());
+    }
+
+    @Test
+    public void testSaveLoadAndTransform() throws Exception {
+        Binarizer binarizer =
+                new Binarizer()
+                        .setInputCols("f0", "f1", "f2")
+                        .setOutputCols("of0", "of1", "of2")
+                        .setThresholds(1.0, 1.5, 2.5);
+
+        Binarizer loadedBinarizer =
+                TestUtils.saveAndReload(
+                        tEnv, binarizer, 
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+
+        Table output = loadedBinarizer.transform(inputDataTable)[0];
+        verifyOutputResult(output, loadedBinarizer.getOutputCols());
+    }
+}
diff --git a/flink-ml-python/pyflink/examples/ml/feature/binarizer_example.py 
b/flink-ml-python/pyflink/examples/ml/feature/binarizer_example.py
new file mode 100644
index 0000000..8b82e6e
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/feature/binarizer_example.py
@@ -0,0 +1,69 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that creates a Binarizer instance and uses it for feature
+# engineering.
+#
+# Before executing this program, please make sure you have followed Flink ML's
+# quick start guideline to set up Flink ML and Flink environment. The guideline
+# can be found at
+#
+# 
https://nightlies.apache.org/flink/flink-ml-docs-master/docs/try-flink-ml/quick-start/
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+from pyflink.ml.lib.feature.binarizer import Binarizer
+from pyflink.table import StreamTableEnvironment
+
# Set up the streaming environment and the table environment built on top of it.
env = StreamExecutionEnvironment.get_execution_environment()
t_env = StreamTableEnvironment.create(env)

# Build the input table: an integer column 'f0' and a dense-vector column 'f1'.
input_rows = [
    (1, Vectors.dense(3, 4)),
    (2, Vectors.dense(6, 2)),
]
input_type = Types.ROW_NAMED(
    ['f0', 'f1'],
    [Types.INT(), DenseVectorTypeInfo()])
input_data_table = t_env.from_data_stream(
    env.from_collection(input_rows, type_info=input_type))

# Create a binarizer and configure one threshold per input column.
binarizer = Binarizer() \
    .set_input_cols('f0', 'f1') \
    .set_output_cols('of0', 'of1') \
    .set_thresholds(1.5, 3.5)

# Apply the binarizer to the input table.
output = binarizer.transform(input_data_table)[0]

# Collect the results and print each row's input and binarized output values.
field_names = output.get_schema().get_field_names()
input_cols = binarizer.get_input_cols()
output_cols = binarizer.get_output_cols()
input_positions = [field_names.index(col) for col in input_cols]
output_positions = [field_names.index(col) for col in output_cols]
input_values = [None] * len(input_cols)
output_values = [None] * len(output_cols)
for result in t_env.to_data_stream(output).execute_and_collect():
    input_values = [result[pos] for pos in input_positions]
    output_values = [result[pos] for pos in output_positions]
    print('Input Values: ' + str(input_values) + '\tOutput Values: ' + str(output_values))
diff --git a/flink-ml-python/pyflink/ml/lib/feature/binarizer.py 
b/flink-ml-python/pyflink/ml/lib/feature/binarizer.py
new file mode 100644
index 0000000..c17eb6b
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/binarizer.py
@@ -0,0 +1,70 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from typing import Tuple
+
+from pyflink.ml.core.param import ParamValidators, Param, FloatArrayParam
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.feature.common import JavaFeatureTransformer
+from pyflink.ml.lib.param import HasInputCols, HasOutputCols
+
+
class _BinarizerParams(
    JavaWithParams,
    HasInputCols,
    HasOutputCols
):
    """
    Params for :class:`Binarizer`: the input/output column names plus one binarization
    threshold per input column.
    """

    # Thresholds applied to the input columns, one per column; no default (None) and
    # validated with ParamValidators.non_empty_array().
    THRESHOLDS: Param[Tuple[float, ...]] = FloatArrayParam(
        "thresholds",
        ("The thresholds used to binarize continuous features. "
         "Each threshold would be used against one input column. "
         "If the value of a continuous feature is greater than "
         "the threshold, it will be binarized to 1.0. "
         "If the value is equal to or less than the threshold, "
         "it will be binarized to 0.0."),
        None,
        ParamValidators.non_empty_array())

    def get_thresholds(self) -> Tuple[float, ...]:
        """Returns the configured thresholds, one per input column."""
        return self.get(self.THRESHOLDS)

    def set_thresholds(self, *thresholds: float):
        """Stores the given thresholds as a tuple and returns the result of ``set``."""
        return self.set(self.THRESHOLDS, thresholds)

    @property
    def thresholds(self) -> Tuple[float, ...]:
        """Property form of :meth:`get_thresholds`."""
        return self.get_thresholds()
+
+
class Binarizer(JavaFeatureTransformer, _BinarizerParams):
    """
    A Transformer that binarizes columns of continuous features using the given thresholds,
    one threshold per input column. The continuous features may be DenseVector,
    SparseVector, or numerical values.
    """

    def __init__(self, java_model=None):
        super().__init__(java_model)

    @classmethod
    def _java_transformer_class_name(cls) -> str:
        # Simple name of the Java implementation class.
        return "Binarizer"

    @classmethod
    def _java_transformer_package_name(cls) -> str:
        # The Java Binarizer lives in org.apache.flink.ml.feature.binarizer.
        return "binarizer"
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_binarizer.py 
b/flink-ml-python/pyflink/ml/lib/feature/tests/test_binarizer.py
new file mode 100644
index 0000000..e153f69
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_binarizer.py
@@ -0,0 +1,95 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import os
+
+from pyflink.common import Types
+
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo, 
SparseVectorTypeInfo
+from pyflink.ml.lib.feature.binarizer import Binarizer
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
class BinarizerTest(PyFlinkMLTestCase):
    """Tests the Python :class:`Binarizer` wrapper."""

    def setUp(self):
        super(BinarizerTest, self).setUp()
        # Rows of (int, dense vector, sparse vector) exposed as columns f0, f1 and f2.
        self.input_data_table = self.t_env.from_data_stream(
            self.env.from_collection([
                (1,
                 Vectors.dense(1, 2),
                 Vectors.sparse(17, [0, 3, 9], [1.0, 2.0, 7.0])),
                (2,
                 Vectors.dense(2, 1),
                 Vectors.sparse(17, [0, 2, 14], [5.0, 4.0, 1.0])),
                (3,
                 Vectors.dense(5, 18),
                 Vectors.sparse(17, [0, 11, 12], [2.0, 4.0, 4.0]))
            ],
                type_info=Types.ROW_NAMED(
                    ['f0', 'f1', 'f2'],
                    [Types.INT(), DenseVectorTypeInfo(), SparseVectorTypeInfo()])))

        # Expected outputs for thresholds (1.0, 1.5, 2.5), ordered by f0.
        self.expected_output_data = [[0.0,
                                      Vectors.dense(0.0, 1.0),
                                      Vectors.sparse(17, [9], [1.0])],
                                     [1.0,
                                      Vectors.dense(1.0, 0.0),
                                      Vectors.sparse(17, [0, 2], [1.0, 1.0])],
                                     [1.0,
                                      Vectors.dense(1.0, 1.0),
                                      Vectors.sparse(17, [11, 12], [1.0, 1.0])]]

    def test_param(self):
        binarizer = Binarizer()

        binarizer.set_input_cols('f0', 'f1') \
            .set_output_cols('of0', 'of1') \
            .set_thresholds(1.5, 2.5)

        self.assertEqual(('f0', 'f1'), binarizer.input_cols)
        self.assertEqual(('of0', 'of1'), binarizer.output_cols)
        self.assertEqual((1.5, 2.5), binarizer.get_thresholds())
        self.assertEqual((1.5, 2.5), binarizer.thresholds)

    def test_output_schema(self):
        binarizer = Binarizer() \
            .set_input_cols('f0', 'f1', 'f2') \
            .set_output_cols('of0', 'of1', 'of2') \
            .set_thresholds(0.0, 0.0, 0.0)

        output = binarizer.transform(self.input_data_table)[0]

        # Output columns are appended after the input columns, as in the Java test.
        self.assertEqual(
            ['f0', 'f1', 'f2', 'of0', 'of1', 'of2'],
            output.get_schema().get_field_names())

    def test_save_load_transform(self):
        binarizer = Binarizer() \
            .set_input_cols('f0', 'f1', 'f2') \
            .set_output_cols('of0', 'of1', 'of2') \
            .set_thresholds(1.0, 1.5, 2.5)

        path = os.path.join(self.temp_dir, 'test_save_load_transform_binarizer')
        binarizer.save(path)
        binarizer = Binarizer.load(self.t_env, path)

        output_table = binarizer.transform(self.input_data_table)[0]
        actual_outputs = [(result[0], result[3], result[4], result[5]) for result in
                          self.t_env.to_data_stream(output_table).execute_and_collect()]

        self.assertEqual(3, len(actual_outputs))

        # Sort by f0 (unique per row) so rows line up with expected_output_data.
        actual_outputs.sort()

        for i in range(len(actual_outputs)):
            actual_output = actual_outputs[i]
            self.assertAlmostEqual(self.expected_output_data[i][0], actual_output[1],
                                   delta=1.0e-7)
            self.assertEqual(2, len(actual_output[2]))
            for j in range(len(actual_output[2])):
                self.assertAlmostEqual(self.expected_output_data[i][1].get(j),
                                       actual_output[2].get(j), delta=1e-7)
            self.assertEqual(17, len(actual_output[3]))
            for j in range(len(actual_output[3])):
                self.assertAlmostEqual(self.expected_output_data[i][2].get(j),
                                       actual_output[3].get(j), delta=1e-7)

Reply via email to