yunfengzhou-hub commented on code in PR #86: URL: https://github.com/apache/flink-ml/pull/86#discussion_r869795956
########## flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java: ########## @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.evaluation; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluator; +import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluatorParams; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import 
static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** Tests {@link BinaryClassificationEvaluator}. */ +public class BinaryClassificationEvaluatorTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamTableEnvironment tEnv; + private Table inputDataTable; + private Table inputDataTableScore; + private Table inputDataTableWithMultiScore; + private Table inputDataTableWithWeight; + + private static final List<Row> INPUT_DATA = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.2, 0.8)), + Row.of(1.0, Vectors.dense(0.3, 0.7)), + Row.of(0.0, Vectors.dense(0.25, 0.75)), + Row.of(0.0, Vectors.dense(0.4, 0.6)), + Row.of(1.0, Vectors.dense(0.35, 0.65)), + Row.of(1.0, Vectors.dense(0.45, 0.55)), + Row.of(0.0, Vectors.dense(0.6, 0.4)), + Row.of(0.0, Vectors.dense(0.7, 0.3)), + Row.of(1.0, Vectors.dense(0.65, 0.35)), + Row.of(0.0, Vectors.dense(0.8, 0.2)), + Row.of(1.0, Vectors.dense(0.9, 0.1))); + + private static final List<Row> INPUT_DATA_DOUBLE_RAW = + Arrays.asList( + Row.of(1, 0.9), + Row.of(1, 0.8), + Row.of(1, 0.7), + Row.of(0, 0.75), + Row.of(0, 0.6), + Row.of(1, 0.65), + Row.of(1, 0.55), + Row.of(0, 0.4), + Row.of(0, 0.3), + Row.of(1, 0.35), + Row.of(0, 0.2), + Row.of(1, 0.1)); + + private static final List<Row> INPUT_DATA_WITH_MULTI_SCORE = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.25, 0.75)), + Row.of(0.0, Vectors.dense(0.4, 0.6)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.6, 0.4)), + Row.of(0.0, Vectors.dense(0.7, 0.3)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.8, 0.2)), + Row.of(1.0, Vectors.dense(0.9, 0.1))); + + private static final List<Row> INPUT_DATA_WITH_WEIGHT = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8), + 
Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7), + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5), + Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2), + Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4), + Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3), + Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9), + Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2), + Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0)); + + private static final double[] EXPECTED_DATA = + new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571}; + private static final double[] EXPECTED_DATA_M = + new double[] { + 0.8571428571428571, 0.9377705627705628, 0.8571428571428571, 0.6488095238095237 + }; + private static final double EXPECTED_DATA_W = 0.8911680911680911; + private static final double EPS = 1.0e-5; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(3); Review Comment: Thanks for fixing the problem here. Could we also add a test to verify that the algorithm still works even when the parallelism is larger than the number of data records? `LogisticRegressionTest.testMoreSubtaskThanData` is a good example. Otherwise reviewers would have to manually change this to `env.setParallelism(20)` and re-run the tests. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassfication/BinaryClassificationEvaluatorParams.java: ########## @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.evaluation.binaryclassfication; + +import org.apache.flink.ml.common.param.HasLabelCol; +import org.apache.flink.ml.common.param.HasRawPredictionCol; +import org.apache.flink.ml.common.param.HasWeightCol; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringArrayParam; + +/** + * Params of BinaryClassificationEvaluator. + * + * @param <T> The class type of this instance. + */ +public interface BinaryClassificationEvaluatorParams<T> + extends HasLabelCol<T>, HasRawPredictionCol<T>, HasWeightCol<T> { + String AREA_UNDER_ROC = "areaUnderROC"; + String AREA_UNDER_PR = "areaUnderPR"; + String AREA_UNDER_LORENZ = "areaUnderLorenz"; + String KS = "ks"; + + /** + * Param for supported metric names in binary classification evaluation (supports + * 'areaUnderROC', 'areaUnderPR', 'ks' and 'areaUnderLorenz'). + * + * <ul> + * <li>areaUnderROC: the area under the receiver operating characteristic (ROC) curve. + * <li>areaUnderPR: the area under the precision-recall curve. + * <li>ks: Kolmogorov-Smirnov, measures the ability of the model to separate positive and + * negative samples. + * <li>areaUnderLorenz: the area under the lorenz curve. 
+ * </ul> + */ + Param<String[]> METRICS_NAMES = + new StringArrayParam( + "metricsNames", + "Names of output metrics, which may contains 'areaUnderROC', 'areaUnderPR', 'ks' or 'areaUnderLorenz'", Review Comment: nit: We don't need to present possible values in the description string. "Names of output metrics." is enough. Please check `HasHandleInvalid` as an example. ########## flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java: ########## @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.evaluation; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluator; +import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluatorParams; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** Tests {@link BinaryClassificationEvaluator}. 
*/ +public class BinaryClassificationEvaluatorTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamTableEnvironment tEnv; + private Table inputDataTable; + private Table inputDataTableScore; + private Table inputDataTableWithMultiScore; + private Table inputDataTableWithWeight; + + private static final List<Row> INPUT_DATA = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.2, 0.8)), + Row.of(1.0, Vectors.dense(0.3, 0.7)), + Row.of(0.0, Vectors.dense(0.25, 0.75)), + Row.of(0.0, Vectors.dense(0.4, 0.6)), + Row.of(1.0, Vectors.dense(0.35, 0.65)), + Row.of(1.0, Vectors.dense(0.45, 0.55)), + Row.of(0.0, Vectors.dense(0.6, 0.4)), + Row.of(0.0, Vectors.dense(0.7, 0.3)), + Row.of(1.0, Vectors.dense(0.65, 0.35)), + Row.of(0.0, Vectors.dense(0.8, 0.2)), + Row.of(1.0, Vectors.dense(0.9, 0.1))); + + private static final List<Row> INPUT_DATA_DOUBLE_RAW = + Arrays.asList( + Row.of(1, 0.9), + Row.of(1, 0.8), + Row.of(1, 0.7), + Row.of(0, 0.75), + Row.of(0, 0.6), + Row.of(1, 0.65), + Row.of(1, 0.55), + Row.of(0, 0.4), + Row.of(0, 0.3), + Row.of(1, 0.35), + Row.of(0, 0.2), + Row.of(1, 0.1)); + + private static final List<Row> INPUT_DATA_WITH_MULTI_SCORE = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.25, 0.75)), + Row.of(0.0, Vectors.dense(0.4, 0.6)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.6, 0.4)), + Row.of(0.0, Vectors.dense(0.7, 0.3)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.8, 0.2)), + Row.of(1.0, Vectors.dense(0.9, 0.1))); + + private static final List<Row> INPUT_DATA_WITH_WEIGHT = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8), + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7), + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5), + Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2), + 
Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4), + Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3), + Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9), + Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2), + Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0)); + + private static final double[] EXPECTED_DATA = + new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571}; + private static final double[] EXPECTED_DATA_M = + new double[] { + 0.8571428571428571, 0.9377705627705628, 0.8571428571428571, 0.6488095238095237 + }; + private static final double EXPECTED_DATA_W = 0.8911680911680911; + private static final double EPS = 1.0e-5; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(3); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + inputDataTable = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA)).as("label", "rawPrediction"); + inputDataTableScore = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA_DOUBLE_RAW)) + .as("label", "rawPrediction"); + + inputDataTableWithMultiScore = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_MULTI_SCORE)) + .as("label", "rawPrediction"); + inputDataTableWithWeight = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_WEIGHT)) + .as("label", "rawPrediction", "weight"); + } + + @Test + public void testParam() { + BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator(); + assertEquals("label", binaryEval.getLabelCol()); + assertNull(binaryEval.getWeightCol()); + assertEquals("rawPrediction", binaryEval.getRawPredictionCol()); + assertArrayEquals( + new String[] 
{"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames()); + binaryEval + .setLabelCol("labelCol") + .setRawPredictionCol("raw") + .setMetricsNames(BinaryClassificationEvaluatorParams.AREA_UNDER_ROC) Review Comment: Thanks for making this change. There are some other places that have directly used the string values like `"areaUnderROC"` in BinaryClassificationEvaluatorTest. Shall we replace all of them with constants like `AREA_UNDER_ROC`? nit: Besides, statically importing the constants might make the code look better. ```java import static org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluatorParams.AREA_UNDER_ROC; ... .setMetricsNames(AREA_UNDER_ROC) ``` ########## flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java: ########## @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.ml.evaluation; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluator; +import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluatorParams; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** Tests {@link BinaryClassificationEvaluator}. 
*/ +public class BinaryClassificationEvaluatorTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamTableEnvironment tEnv; + private Table inputDataTable; + private Table inputDataTableScore; + private Table inputDataTableWithMultiScore; + private Table inputDataTableWithWeight; + + private static final List<Row> INPUT_DATA = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.2, 0.8)), + Row.of(1.0, Vectors.dense(0.3, 0.7)), + Row.of(0.0, Vectors.dense(0.25, 0.75)), + Row.of(0.0, Vectors.dense(0.4, 0.6)), + Row.of(1.0, Vectors.dense(0.35, 0.65)), + Row.of(1.0, Vectors.dense(0.45, 0.55)), + Row.of(0.0, Vectors.dense(0.6, 0.4)), + Row.of(0.0, Vectors.dense(0.7, 0.3)), + Row.of(1.0, Vectors.dense(0.65, 0.35)), + Row.of(0.0, Vectors.dense(0.8, 0.2)), + Row.of(1.0, Vectors.dense(0.9, 0.1))); + + private static final List<Row> INPUT_DATA_DOUBLE_RAW = + Arrays.asList( + Row.of(1, 0.9), + Row.of(1, 0.8), + Row.of(1, 0.7), + Row.of(0, 0.75), + Row.of(0, 0.6), + Row.of(1, 0.65), + Row.of(1, 0.55), + Row.of(0, 0.4), + Row.of(0, 0.3), + Row.of(1, 0.35), + Row.of(0, 0.2), + Row.of(1, 0.1)); + + private static final List<Row> INPUT_DATA_WITH_MULTI_SCORE = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.25, 0.75)), + Row.of(0.0, Vectors.dense(0.4, 0.6)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.6, 0.4)), + Row.of(0.0, Vectors.dense(0.7, 0.3)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.8, 0.2)), + Row.of(1.0, Vectors.dense(0.9, 0.1))); + + private static final List<Row> INPUT_DATA_WITH_WEIGHT = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8), + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7), + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5), + Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2), + 
Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4), + Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3), + Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9), + Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2), + Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0)); + + private static final double[] EXPECTED_DATA = + new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571}; + private static final double[] EXPECTED_DATA_M = + new double[] { + 0.8571428571428571, 0.9377705627705628, 0.8571428571428571, 0.6488095238095237 + }; + private static final double EXPECTED_DATA_W = 0.8911680911680911; + private static final double EPS = 1.0e-5; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(3); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + inputDataTable = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA)).as("label", "rawPrediction"); + inputDataTableScore = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA_DOUBLE_RAW)) + .as("label", "rawPrediction"); + + inputDataTableWithMultiScore = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_MULTI_SCORE)) + .as("label", "rawPrediction"); + inputDataTableWithWeight = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_WEIGHT)) + .as("label", "rawPrediction", "weight"); + } + + @Test + public void testParam() { + BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator(); + assertEquals("label", binaryEval.getLabelCol()); + assertNull(binaryEval.getWeightCol()); + assertEquals("rawPrediction", binaryEval.getRawPredictionCol()); + assertArrayEquals( + new String[] 
{"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames()); + binaryEval + .setLabelCol("labelCol") + .setRawPredictionCol("raw") + .setMetricsNames(BinaryClassificationEvaluatorParams.AREA_UNDER_ROC) + .setWeightCol("weight"); + assertEquals("labelCol", binaryEval.getLabelCol()); + assertEquals("weight", binaryEval.getWeightCol()); + assertEquals("raw", binaryEval.getRawPredictionCol()); + assertArrayEquals(new String[] {"areaUnderROC"}, binaryEval.getMetricsNames()); + } + + @Test + public void testSaveLoadAndEvaluate() throws Exception { + BinaryClassificationEvaluator eval = + new BinaryClassificationEvaluator() + .setMetricsNames("areaUnderPR", "ks", "areaUnderROC"); + BinaryClassificationEvaluator loadedEval = + StageTestUtils.saveAndReload(tEnv, eval, tempFolder.newFolder().getAbsolutePath()); + Table evalResult = loadedEval.transform(inputDataTable)[0]; + assertArrayEquals( + new String[] {"areaUnderPR", "ks", "areaUnderROC"}, + evalResult.getResolvedSchema().getColumnNames().toArray()); + List<Row> results = IteratorUtils.toList(evalResult.execute().collect()); + Row result = results.get(0); + for (int i = 0; i < EXPECTED_DATA.length; ++i) { + assertEquals(EXPECTED_DATA[i], result.getFieldAs(i), EPS); + } + } + + @Test + public void testEvaluate() throws Exception { + BinaryClassificationEvaluator eval = + new BinaryClassificationEvaluator() + .setMetricsNames("areaUnderPR", "ks", "areaUnderROC"); + Table evalResult = eval.transform(inputDataTable)[0]; + List<Row> results = IteratorUtils.toList(evalResult.execute().collect()); + assertArrayEquals( + new String[] {"areaUnderPR", "ks", "areaUnderROC"}, + evalResult.getResolvedSchema().getColumnNames().toArray()); + Row result = results.get(0); + for (int i = 0; i < EXPECTED_DATA.length; ++i) { + assertEquals(EXPECTED_DATA[i], result.getFieldAs(i), EPS); + } + } + + @Test + public void testEvaluateWithDoubleRaw() throws Exception { + BinaryClassificationEvaluator eval = + new 
BinaryClassificationEvaluator() + .setMetricsNames("areaUnderPR", "ks", "areaUnderROC"); + Table evalResult = eval.transform(inputDataTableScore)[0]; + List<Row> results = IteratorUtils.toList(evalResult.execute().collect()); + assertArrayEquals( + new String[] {"areaUnderPR", "ks", "areaUnderROC"}, + evalResult.getResolvedSchema().getColumnNames().toArray()); + Row result = results.get(0); + for (int i = 0; i < EXPECTED_DATA.length; ++i) { + assertEquals(EXPECTED_DATA[i], result.getFieldAs(i), EPS); + } + } + + @Test + public void testEvaluateWithMultiScore() throws Exception { Review Comment: nit: Shall we check and clean up warnings in `BinaryClassificationEvaluator` and `BinaryClassificationEvaluatorTest`? For example, Exception is never thrown in this method. ########## flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java: ########## @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.evaluation; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluator; +import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluatorParams; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** Tests {@link BinaryClassificationEvaluator}. 
*/ +public class BinaryClassificationEvaluatorTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamTableEnvironment tEnv; + private Table inputDataTable; + private Table inputDataTableScore; + private Table inputDataTableWithMultiScore; + private Table inputDataTableWithWeight; + + private static final List<Row> INPUT_DATA = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.2, 0.8)), + Row.of(1.0, Vectors.dense(0.3, 0.7)), + Row.of(0.0, Vectors.dense(0.25, 0.75)), + Row.of(0.0, Vectors.dense(0.4, 0.6)), + Row.of(1.0, Vectors.dense(0.35, 0.65)), + Row.of(1.0, Vectors.dense(0.45, 0.55)), + Row.of(0.0, Vectors.dense(0.6, 0.4)), + Row.of(0.0, Vectors.dense(0.7, 0.3)), + Row.of(1.0, Vectors.dense(0.65, 0.35)), + Row.of(0.0, Vectors.dense(0.8, 0.2)), + Row.of(1.0, Vectors.dense(0.9, 0.1))); + + private static final List<Row> INPUT_DATA_DOUBLE_RAW = + Arrays.asList( + Row.of(1, 0.9), + Row.of(1, 0.8), + Row.of(1, 0.7), + Row.of(0, 0.75), + Row.of(0, 0.6), + Row.of(1, 0.65), + Row.of(1, 0.55), + Row.of(0, 0.4), + Row.of(0, 0.3), + Row.of(1, 0.35), + Row.of(0, 0.2), + Row.of(1, 0.1)); + + private static final List<Row> INPUT_DATA_WITH_MULTI_SCORE = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.25, 0.75)), + Row.of(0.0, Vectors.dense(0.4, 0.6)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.6, 0.4)), + Row.of(0.0, Vectors.dense(0.7, 0.3)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.8, 0.2)), + Row.of(1.0, Vectors.dense(0.9, 0.1))); + + private static final List<Row> INPUT_DATA_WITH_WEIGHT = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8), + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7), + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5), + Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2), + 
Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4), + Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3), + Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9), + Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2), + Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0)); + + private static final double[] EXPECTED_DATA = + new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571}; + private static final double[] EXPECTED_DATA_M = + new double[] { + 0.8571428571428571, 0.9377705627705628, 0.8571428571428571, 0.6488095238095237 + }; + private static final double EXPECTED_DATA_W = 0.8911680911680911; + private static final double EPS = 1.0e-5; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(3); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + inputDataTable = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA)).as("label", "rawPrediction"); + inputDataTableScore = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA_DOUBLE_RAW)) + .as("label", "rawPrediction"); + + inputDataTableWithMultiScore = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_MULTI_SCORE)) + .as("label", "rawPrediction"); + inputDataTableWithWeight = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_WEIGHT)) + .as("label", "rawPrediction", "weight"); + } + + @Test + public void testParam() { + BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator(); + assertEquals("label", binaryEval.getLabelCol()); + assertNull(binaryEval.getWeightCol()); + assertEquals("rawPrediction", binaryEval.getRawPredictionCol()); + assertArrayEquals( + new String[] 
{"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames()); + binaryEval + .setLabelCol("labelCol") + .setRawPredictionCol("raw") + .setMetricsNames(BinaryClassificationEvaluatorParams.AREA_UNDER_ROC) + .setWeightCol("weight"); + assertEquals("labelCol", binaryEval.getLabelCol()); + assertEquals("weight", binaryEval.getWeightCol()); + assertEquals("raw", binaryEval.getRawPredictionCol()); + assertArrayEquals(new String[] {"areaUnderROC"}, binaryEval.getMetricsNames()); + } + + @Test + public void testSaveLoadAndEvaluate() throws Exception { + BinaryClassificationEvaluator eval = + new BinaryClassificationEvaluator() + .setMetricsNames("areaUnderPR", "ks", "areaUnderROC"); + BinaryClassificationEvaluator loadedEval = + StageTestUtils.saveAndReload(tEnv, eval, tempFolder.newFolder().getAbsolutePath()); + Table evalResult = loadedEval.transform(inputDataTable)[0]; + assertArrayEquals( + new String[] {"areaUnderPR", "ks", "areaUnderROC"}, + evalResult.getResolvedSchema().getColumnNames().toArray()); + List<Row> results = IteratorUtils.toList(evalResult.execute().collect()); + Row result = results.get(0); + for (int i = 0; i < EXPECTED_DATA.length; ++i) { + assertEquals(EXPECTED_DATA[i], result.getFieldAs(i), EPS); + } + } + + @Test + public void testEvaluate() throws Exception { + BinaryClassificationEvaluator eval = + new BinaryClassificationEvaluator() + .setMetricsNames("areaUnderPR", "ks", "areaUnderROC"); + Table evalResult = eval.transform(inputDataTable)[0]; + List<Row> results = IteratorUtils.toList(evalResult.execute().collect()); + assertArrayEquals( + new String[] {"areaUnderPR", "ks", "areaUnderROC"}, + evalResult.getResolvedSchema().getColumnNames().toArray()); + Row result = results.get(0); + for (int i = 0; i < EXPECTED_DATA.length; ++i) { + assertEquals(EXPECTED_DATA[i], result.getFieldAs(i), EPS); + } + } + + @Test + public void testEvaluateWithDoubleRaw() throws Exception { Review Comment: Thanks for adding this test. 
In Spark's BinaryClassification algorithm, it also tests when label's data type is any possible numeric type, like int, shorts or decimal. Could we also add these tests? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassfication/BinaryClassificationEvaluator.java: ########## @@ -0,0 +1,742 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.evaluation.binaryclassfication; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.api.scala.typeutils.Types; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamMap; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import 
org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Random; + +import static org.apache.flink.runtime.blob.BlobWriter.LOG; + +/** + * An Estimator which calculates the evaluation metrics for binary classification. The input data + * has columns rawPrediction, label and an optional weight column. The rawPrediction can be of type + * double (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of + * raw predictions, scores, or label probabilities). The output may contain different metrics which + * will be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams. Review Comment: nit: `{@link BinaryClassificationEvaluatorParams}`. `parameter MetricsNames` might be outdated. ########## flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java: ########## @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.evaluation; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluator; +import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluatorParams; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** Tests {@link BinaryClassificationEvaluator}. 
*/ +public class BinaryClassificationEvaluatorTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamTableEnvironment tEnv; + private Table inputDataTable; + private Table inputDataTableScore; + private Table inputDataTableWithMultiScore; + private Table inputDataTableWithWeight; + + private static final List<Row> INPUT_DATA = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.2, 0.8)), + Row.of(1.0, Vectors.dense(0.3, 0.7)), + Row.of(0.0, Vectors.dense(0.25, 0.75)), + Row.of(0.0, Vectors.dense(0.4, 0.6)), + Row.of(1.0, Vectors.dense(0.35, 0.65)), + Row.of(1.0, Vectors.dense(0.45, 0.55)), + Row.of(0.0, Vectors.dense(0.6, 0.4)), + Row.of(0.0, Vectors.dense(0.7, 0.3)), + Row.of(1.0, Vectors.dense(0.65, 0.35)), + Row.of(0.0, Vectors.dense(0.8, 0.2)), + Row.of(1.0, Vectors.dense(0.9, 0.1))); + + private static final List<Row> INPUT_DATA_DOUBLE_RAW = + Arrays.asList( + Row.of(1, 0.9), + Row.of(1, 0.8), + Row.of(1, 0.7), + Row.of(0, 0.75), + Row.of(0, 0.6), + Row.of(1, 0.65), + Row.of(1, 0.55), + Row.of(0, 0.4), + Row.of(0, 0.3), + Row.of(1, 0.35), + Row.of(0, 0.2), + Row.of(1, 0.1)); + + private static final List<Row> INPUT_DATA_WITH_MULTI_SCORE = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.25, 0.75)), + Row.of(0.0, Vectors.dense(0.4, 0.6)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.6, 0.4)), + Row.of(0.0, Vectors.dense(0.7, 0.3)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.8, 0.2)), + Row.of(1.0, Vectors.dense(0.9, 0.1))); + + private static final List<Row> INPUT_DATA_WITH_WEIGHT = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8), + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7), + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5), + Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2), + 
Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4), + Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3), + Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9), + Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2), + Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0)); + + private static final double[] EXPECTED_DATA = + new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571}; + private static final double[] EXPECTED_DATA_M = + new double[] { + 0.8571428571428571, 0.9377705627705628, 0.8571428571428571, 0.6488095238095237 + }; + private static final double EXPECTED_DATA_W = 0.8911680911680911; + private static final double EPS = 1.0e-5; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(3); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + inputDataTable = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA)).as("label", "rawPrediction"); + inputDataTableScore = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA_DOUBLE_RAW)) + .as("label", "rawPrediction"); + + inputDataTableWithMultiScore = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_MULTI_SCORE)) + .as("label", "rawPrediction"); + inputDataTableWithWeight = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_WEIGHT)) + .as("label", "rawPrediction", "weight"); + } + + @Test + public void testParam() { + BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator(); + assertEquals("label", binaryEval.getLabelCol()); + assertNull(binaryEval.getWeightCol()); + assertEquals("rawPrediction", binaryEval.getRawPredictionCol()); + assertArrayEquals( + new String[] 
{"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames()); + binaryEval + .setLabelCol("labelCol") + .setRawPredictionCol("raw") + .setMetricsNames(BinaryClassificationEvaluatorParams.AREA_UNDER_ROC) + .setWeightCol("weight"); + assertEquals("labelCol", binaryEval.getLabelCol()); + assertEquals("weight", binaryEval.getWeightCol()); + assertEquals("raw", binaryEval.getRawPredictionCol()); + assertArrayEquals(new String[] {"areaUnderROC"}, binaryEval.getMetricsNames()); + } + + @Test + public void testSaveLoadAndEvaluate() throws Exception { Review Comment: nit: Shall we adjust the order of these test cases to follow other classes? For example, First is `testParam`, followed by `testOutputSchema`, `testEvaluate` and then save/load and other special cases. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassfication/BinaryClassificationEvaluator.java: ########## @@ -0,0 +1,742 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.evaluation.binaryclassfication; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.api.scala.typeutils.Types; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamMap; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import 
org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Random; + +import static org.apache.flink.runtime.blob.BlobWriter.LOG; + +/** + * An Estimator which calculates the evaluation metrics for binary classification. The input data + * has columns rawPrediction, label and an optional weight column. The rawPrediction can be of type + * double (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of + * raw predictions, scores, or label probabilities). The output may contain different metrics which + * will be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams. + */ +public class BinaryClassificationEvaluator + implements AlgoOperator<BinaryClassificationEvaluator>, + BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 50; + + public BinaryClassificationEvaluator() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Tuple3<Double, Boolean, Double>> evalData = + tEnv.toDataStream(inputs[0]) + .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol())); + final String boundaryRangeKey = "boundaryRange"; + final String partitionSummariesKey = "partitionSummaries"; + + DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(evalData), + Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)), + inputList -> { + DataStream input = inputList.get(0); + return input.map(new AppendTaskId(boundaryRangeKey)); + }); + + /* Repartition the evaluated data by range. */ + evalDataWithTaskId = + evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3); + + /* Sorts local data by score.*/ + evalData = + DataStreamUtils.mapPartition( + evalDataWithTaskId, + new MapPartitionFunction< + Tuple4<Double, Boolean, Double, Integer>, + Tuple3<Double, Boolean, Double>>() { + @Override + public void mapPartition( + Iterable<Tuple4<Double, Boolean, Double, Integer>> values, + Collector<Tuple3<Double, Boolean, Double>> out) { + List<Tuple3<Double, Boolean, Double>> bufferedData = + new LinkedList<>(); + for (Tuple4<Double, Boolean, Double, Integer> t4 : values) { + bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2)); + } + bufferedData.sort(Comparator.comparingDouble(o -> -o.f0)); + for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) { + out.collect(dataPoint); + } + } + }); + + /* Calculates the summary of local data. */ + DataStream<BinarySummary> partitionSummaries = + evalData.transform( + "reduceInEachPartition", + TypeInformation.of(BinarySummary.class), + new PartitionSummaryOperator()); + + /* Sorts global data. 
Output Tuple4 : <score, order, isPositive, weight> */ + DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(evalData), + Collections.singletonMap(partitionSummariesKey, partitionSummaries), + inputList -> { + DataStream input = inputList.get(0); + return input.flatMap(new CalcSampleOrders(partitionSummariesKey)); + }); + + dataWithOrders = + dataWithOrders.transform( + "appendMaxWaterMark", + dataWithOrders.getType(), + new AppendMaxWatermark(x -> x)); + + DataStream<double[]> localAreaUnderROCVariable = + dataWithOrders.transform( + "AccumulateMultiScore", + TypeInformation.of(double[].class), + new AccumulateMultiScoreOperator()); + + DataStream<double[]> middleAreaUnderROC = + DataStreamUtils.reduce( + localAreaUnderROCVariable, + (ReduceFunction<double[]>) + (t1, t2) -> { + t2[0] += t1[0]; + t2[1] += t1[1]; + t2[2] += t1[2]; + return t2; + }); + + DataStream<Double> areaUnderROC = + middleAreaUnderROC.map( + (MapFunction<double[], Double>) + value -> { + if (value[1] > 0 && value[2] > 0) { + return (value[0] - 1. 
* value[1] * (value[1] + 1) / 2) + / (value[1] * value[2]); + } else { + return Double.NaN; + } + }); + + Map<String, DataStream<?>> broadcastMap = new HashMap<>(); + broadcastMap.put(partitionSummariesKey, partitionSummaries); + broadcastMap.put(AREA_UNDER_ROC, areaUnderROC); + DataStream<BinaryMetrics> localMetrics = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(evalData), + broadcastMap, + inputList -> { + DataStream input = inputList.get(0); + return DataStreamUtils.mapPartition( + input, new CalcBinaryMetrics(partitionSummariesKey)); + }); + + DataStream<Map<String, Double>> metrics = + DataStreamUtils.mapPartition(localMetrics, new MergeMetrics()); + metrics.getTransformation().setParallelism(1); + + final String[] metricsNames = getMetricsNames(); + TypeInformation<?>[] metricTypes = new TypeInformation[metricsNames.length]; + Arrays.fill(metricTypes, Types.DOUBLE()); + RowTypeInfo outputTypeInfo = new RowTypeInfo(metricTypes, metricsNames); + + DataStream<Row> evalResult = + metrics.map( + (MapFunction<Map<String, Double>, Row>) + value -> { + Row ret = new Row(metricsNames.length); + for (int i = 0; i < metricsNames.length; ++i) { + ret.setField(i, value.get(metricsNames[i])); + } + return ret; + }, + outputTypeInfo); + return new Table[] {tEnv.fromDataStream(evalResult)}; + } + + /** Updates variables for calculating AreaUnderROC. 
*/ + private static class AccumulateMultiScoreOperator extends AbstractStreamOperator<double[]> + implements OneInputStreamOperator<Tuple4<Double, Long, Boolean, Double>, double[]>, + BoundedOneInput { + private ListState<double[]> accValueState; + private ListState<Double> scoreState; + + double[] accValue; + double score; + + @Override + public void endInput() { + if (accValue != null) { + output.collect( + new StreamRecord<>( + new double[] { + accValue[0] / accValue[1] * accValue[2], + accValue[2], + accValue[3] + })); + } + } + + @Override + public void processElement( + StreamRecord<Tuple4<Double, Long, Boolean, Double>> streamRecord) { + Tuple4<Double, Long, Boolean, Double> t = streamRecord.getValue(); + if (accValue == null) { + accValue = new double[4]; + score = t.f0; + } else if (score != t.f0) { + output.collect( + new StreamRecord<>( + new double[] { + accValue[0] / accValue[1] * accValue[2], + accValue[2], + accValue[3] + })); + Arrays.fill(accValue, 0.0); + } + accValue[0] += t.f1; + accValue[1] += 1.0; + if (t.f2) { + accValue[2] += t.f3; + } else { + accValue[3] += t.f3; + } + } + + @Override + @SuppressWarnings("unchecked") + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + accValueState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "accValueState", TypeInformation.of(double[].class))); + accValue = + OperatorStateUtils.getUniqueElement(accValueState, "accValueState") + .orElse(null); + + scoreState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "scoreState", TypeInformation.of(Double.class))); + score = OperatorStateUtils.getUniqueElement(scoreState, "scoreState").orElse(0.0); Review Comment: score has not been saved in snapshot yet before restoring from operator state. -- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected]. For queries about this service, please contact Infrastructure at: [email protected]
