yunfengzhou-hub commented on code in PR #86: URL: https://github.com/apache/flink-ml/pull/86#discussion_r867784038
########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassificationevaluator/BinaryClassificationEvaluatorParams.java: ########## @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.evaluation.binaryclassificationevaluator; + +import org.apache.flink.ml.common.param.HasLabelCol; +import org.apache.flink.ml.common.param.HasRawPredictionCol; +import org.apache.flink.ml.common.param.HasWeightCol; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringArrayParam; + +/** + * Params of BinaryClassificationEvaluator. + * + * @param <T> The class type of this instance. + */ +public interface BinaryClassificationEvaluatorParams<T> + extends HasLabelCol<T>, HasRawPredictionCol<T>, HasWeightCol<T> { + String AREA_UNDER_ROC = "areaUnderROC"; + String AREA_UNDER_PR = "areaUnderPR"; + String AREA_UNDER_LORENZ = "areaUnderLorenz"; + String KS = "ks"; + + /** + * param for metric names in evaluation (supports 'areaUnderROC', 'areaUnderPR', 'KS' and + * 'areaUnderLorenz'). + * + * <ul> + * <li>areaUnderROC: the area under the receiver operating characteristic (ROC) curve. 
+ * <li>areaUnderPR: the area under the precision-recall curve. + * <li>ks: Kolmogorov-Smirnov, measures the ability of the model to separate positive and + * negative samples. + * <li>areaUnderLorenz: the area under the lorenz curve. + * </ul> + */ + Param<String[]> METRICS_NAMES = + new StringArrayParam( + "metricsNames", + "Names of output metrics. The array element must be 'areaUnderROC', 'areaUnderPR', 'ks' and 'areaUnderLorenz'", Review Comment: nit: it might be better to remove "The array element ...", since other array-typed parameters, like `HasHandleInvalid` or `HasBatchStrategy`, do not contain this sentence. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassificationevaluator/BinaryClassificationEvaluatorParams.java: ########## @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.evaluation.binaryclassificationevaluator; Review Comment: nit: Maybe `binaryclassification` makes a better package name, unless we have something like `binaryclassificationregressor` or `binaryclassificationscaler`. 
########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassificationevaluator/BinaryClassificationEvaluatorParams.java: ########## @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.evaluation.binaryclassificationevaluator; + +import org.apache.flink.ml.common.param.HasLabelCol; +import org.apache.flink.ml.common.param.HasRawPredictionCol; +import org.apache.flink.ml.common.param.HasWeightCol; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringArrayParam; + +/** + * Params of BinaryClassificationEvaluator. + * + * @param <T> The class type of this instance. + */ +public interface BinaryClassificationEvaluatorParams<T> + extends HasLabelCol<T>, HasRawPredictionCol<T>, HasWeightCol<T> { + String AREA_UNDER_ROC = "areaUnderROC"; + String AREA_UNDER_PR = "areaUnderPR"; + String AREA_UNDER_LORENZ = "areaUnderLorenz"; + String KS = "ks"; + + /** + * param for metric names in evaluation (supports 'areaUnderROC', 'areaUnderPR', 'KS' and Review Comment: nit: Javadocs should start with an uppercase letter. Besides, "metric names in evaluation" might be ambiguous. 
"Param for supported metric names" or "Param for supported metric names in binary classification evaluation" looks good to me. ########## flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java: ########## @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.evaluation.binaryeval; Review Comment: If `BinaryClassificationEvaluator` is in `org.apache.flink.ml.evaluation.binaryclassificationevaluator`, `BinaryClassificationEvaluatorTest` should be in `org.apache.flink.ml.evaluation.binaryclassificationevaluator` or `org.apache.flink.ml.evaluation`. ########## flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java: ########## @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.evaluation.binaryeval; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** Tests {@link BinaryClassificationEvaluator}. */ +public class BinaryClassificationEvaluatorTest { Review Comment: Spark's BinaryClassification algorithm has test cases for the following situations. Could you please add corresponding test cases? - `rawPredictionCol`'s data type is double - `label`'s data type is any possible numeric type, like int, short or decimal. 
########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassificationevaluator/BinaryClassificationEvaluator.java: ########## @@ -0,0 +1,730 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.evaluation.binaryclassificationevaluator; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.api.scala.typeutils.Types; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; 
+import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamMap; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Random; + +import static org.apache.flink.runtime.blob.BlobWriter.LOG; + +/** + * An Estimator which calculates the evaluation metrics for binary classification. The input data + * has columns rawPrediction, label and an optional weight column. The rawPrediction can be of type + * double (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of + * raw predictions, scores, or label probabilities). The output may contain different metrics which + * will be defined by parameter MetricsNames. 
See @BinaryClassificationEvaluatorParams. + */ +public class BinaryClassificationEvaluator + implements AlgoOperator<BinaryClassificationEvaluator>, + BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100; + + public BinaryClassificationEvaluator() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Tuple3<Double, Boolean, Double>> evalData = + tEnv.toDataStream(inputs[0]) + .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol())); + final String boundaryRangeKey = "boundaryRange"; + final String partitionSummariesKey = "partitionSummaries"; + + DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(evalData), + Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)), + inputList -> { + DataStream input = inputList.get(0); + return input.map(new AppendTaskId(boundaryRangeKey)); + }); + + /* Repartition the evaluated data by range. 
*/ + evalDataWithTaskId = + evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3); + + /* Sorts local data by score.*/ + evalData = + DataStreamUtils.mapPartition( + evalDataWithTaskId, + new MapPartitionFunction< + Tuple4<Double, Boolean, Double, Integer>, + Tuple3<Double, Boolean, Double>>() { + @Override + public void mapPartition( + Iterable<Tuple4<Double, Boolean, Double, Integer>> values, + Collector<Tuple3<Double, Boolean, Double>> out) { + List<Tuple3<Double, Boolean, Double>> bufferedData = + new LinkedList<>(); + for (Tuple4<Double, Boolean, Double, Integer> t4 : values) { + bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2)); + } + bufferedData.sort(Comparator.comparingDouble(o -> -o.f0)); + for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) { + out.collect(dataPoint); + } + } + }); + + /* Calculates the summary of local data. */ + DataStream<BinarySummary> partitionSummaries = + evalData.transform( + "reduceInEachPartition", + TypeInformation.of(BinarySummary.class), + new PartitionSummaryOperator()); + + /* Sorts global data. 
Output Tuple4 : <score, order, isPositive, weight> */ + DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(evalData), + Collections.singletonMap(partitionSummariesKey, partitionSummaries), + inputList -> { + DataStream input = inputList.get(0); + return input.flatMap(new CalcSampleOrders(partitionSummariesKey)); + }); + + dataWithOrders = + dataWithOrders.transform( + "appendMaxWaterMark", + dataWithOrders.getType(), + new AppendMaxWatermark(x -> x)); + + DataStream<double[]> localAreaUnderROCVariable = + dataWithOrders.transform( + "AccumulateMultiScore", + TypeInformation.of(double[].class), + new AccumulateMultiScoreOperator()); + + DataStream<double[]> middleAreaUnderROC = + DataStreamUtils.reduce( + localAreaUnderROCVariable, + (ReduceFunction<double[]>) + (t1, t2) -> { + t2[0] += t1[0]; + t2[1] += t1[1]; + t2[2] += t1[2]; + return t2; + }); + + DataStream<Double> areaUnderROC = + middleAreaUnderROC.map( + (MapFunction<double[], Double>) + value -> { + if (value[1] > 0 && value[2] > 0) { + return (value[0] - 1. 
* value[1] * (value[1] + 1) / 2) + / (value[1] * value[2]); + } else { + return Double.NaN; + } + }); + + Map<String, DataStream<?>> broadcastMap = new HashMap<>(); + broadcastMap.put(partitionSummariesKey, partitionSummaries); + broadcastMap.put(AREA_UNDER_ROC, areaUnderROC); + DataStream<BinaryMetrics> localMetrics = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(evalData), + broadcastMap, + inputList -> { + DataStream input = inputList.get(0); + return DataStreamUtils.mapPartition( + input, new CalcBinaryMetrics(partitionSummariesKey)); + }); + + DataStream<Map<String, Double>> metrics = + DataStreamUtils.mapPartition(localMetrics, new MergeMetrics()); + metrics.getTransformation().setParallelism(1); + + final String[] metricsNames = getMetricsNames(); + TypeInformation<?>[] metricTypes = new TypeInformation[metricsNames.length]; + Arrays.fill(metricTypes, Types.DOUBLE()); + RowTypeInfo outputTypeInfo = new RowTypeInfo(metricTypes, metricsNames); + + DataStream<Row> evalResult = + metrics.map( + (MapFunction<Map<String, Double>, Row>) + value -> { + Row ret = new Row(metricsNames.length); + for (int i = 0; i < metricsNames.length; ++i) { + ret.setField(i, value.get(metricsNames[i])); + } + return ret; + }, + outputTypeInfo); + return new Table[] {tEnv.fromDataStream(evalResult)}; + } + + /** Updates variables for calculating AreaUnderROC. 
*/ + private static class AccumulateMultiScoreOperator extends AbstractStreamOperator<double[]> + implements OneInputStreamOperator<Tuple4<Double, Long, Boolean, Double>, double[]>, + BoundedOneInput { + private ListState<double[]> accValueState; + double[] accValue; + double score; + + @Override + public void endInput() { + if (accValue != null) { + output.collect( + new StreamRecord<>( + new double[] { + accValue[0] / accValue[1] * accValue[2], + accValue[2], + accValue[3] + })); + } + } + + @Override + public void processElement( + StreamRecord<Tuple4<Double, Long, Boolean, Double>> streamRecord) { + Tuple4<Double, Long, Boolean, Double> t = streamRecord.getValue(); + if (accValue == null) { + accValue = new double[4]; + score = t.f0; Review Comment: score is not checkpointed or restored. When there is failover, stream records of different scores might be accumulated in the same `accValue`. ########## flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java: ########## @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.evaluation.binaryeval; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.evaluation.binaryclassificationevaluator.BinaryClassificationEvaluator; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** Tests {@link BinaryClassificationEvaluator}. 
*/ +public class BinaryClassificationEvaluatorTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamTableEnvironment tEnv; + private Table inputDataTable; + private Table inputDataTableWithMultiScore; + private Table inputDataTableWithWeight; + + private static final List<Row> INPUT_DATA = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.2, 0.8)), + Row.of(1.0, Vectors.dense(0.3, 0.7)), + Row.of(0.0, Vectors.dense(0.25, 0.75)), + Row.of(0.0, Vectors.dense(0.4, 0.6)), + Row.of(1.0, Vectors.dense(0.35, 0.65)), + Row.of(1.0, Vectors.dense(0.45, 0.55)), + Row.of(0.0, Vectors.dense(0.6, 0.4)), + Row.of(0.0, Vectors.dense(0.7, 0.3)), + Row.of(1.0, Vectors.dense(0.65, 0.35)), + Row.of(0.0, Vectors.dense(0.8, 0.2)), + Row.of(1.0, Vectors.dense(0.9, 0.1))); + + private static final List<Row> INPUT_DATA_WITH_MULTI_SCORE = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.25, 0.75)), + Row.of(0.0, Vectors.dense(0.4, 0.6)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.6, 0.4)), + Row.of(0.0, Vectors.dense(0.7, 0.3)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.8, 0.2)), + Row.of(1.0, Vectors.dense(0.9, 0.1))); + + private static final List<Row> INPUT_DATA_WITH_WEIGHT = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8), + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7), + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5), + Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2), + Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4), + Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3), + Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9), + Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2), + Row.of(1.0, 
Vectors.dense(0.9, 0.1), 1.0)); + + private static final double[] EXPECTED_DATA = + new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571}; + private static final double[] EXPECTED_DATA_M = + new double[] {0.8571428571428571, 0.9377705627705628, 0.8571428571428571}; + private static final double EXPECTED_DATA_W = 0.8911680911680911; + private static final double EPS = 1.0e-5; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(3); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + inputDataTable = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA)).as("label", "rawPrediction"); + inputDataTableWithMultiScore = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_MULTI_SCORE)) + .as("label", "rawPrediction"); + inputDataTableWithWeight = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_WEIGHT)) + .as("label", "rawPrediction", "weight"); + } + + @Test + public void testParam() { + BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator(); + assertEquals("label", binaryEval.getLabelCol()); + assertNull(binaryEval.getWeightCol()); + assertEquals("rawPrediction", binaryEval.getRawPredictionCol()); + assertArrayEquals( + new String[] {"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames()); + binaryEval + .setLabelCol("labelCol") + .setRawPredictionCol("raw") + .setMetricsNames("areaUnderROC") Review Comment: nit: It might be better to use the constants we have already defined. 
For example, `setMetricsNames(BinaryClassificationEvaluatorParams.AREA_UNDER_ROC)` ########## flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java: ########## @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.evaluation.binaryeval; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.evaluation.binaryclassificationevaluator.BinaryClassificationEvaluator; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** Tests {@link BinaryClassificationEvaluator}. 
*/ +public class BinaryClassificationEvaluatorTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamTableEnvironment tEnv; + private Table inputDataTable; + private Table inputDataTableWithMultiScore; + private Table inputDataTableWithWeight; + + private static final List<Row> INPUT_DATA = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.2, 0.8)), + Row.of(1.0, Vectors.dense(0.3, 0.7)), + Row.of(0.0, Vectors.dense(0.25, 0.75)), + Row.of(0.0, Vectors.dense(0.4, 0.6)), + Row.of(1.0, Vectors.dense(0.35, 0.65)), + Row.of(1.0, Vectors.dense(0.45, 0.55)), + Row.of(0.0, Vectors.dense(0.6, 0.4)), + Row.of(0.0, Vectors.dense(0.7, 0.3)), + Row.of(1.0, Vectors.dense(0.65, 0.35)), + Row.of(0.0, Vectors.dense(0.8, 0.2)), + Row.of(1.0, Vectors.dense(0.9, 0.1))); + + private static final List<Row> INPUT_DATA_WITH_MULTI_SCORE = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.25, 0.75)), + Row.of(0.0, Vectors.dense(0.4, 0.6)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.6, 0.4)), + Row.of(0.0, Vectors.dense(0.7, 0.3)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.8, 0.2)), + Row.of(1.0, Vectors.dense(0.9, 0.1))); + + private static final List<Row> INPUT_DATA_WITH_WEIGHT = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8), + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7), + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5), + Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2), + Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4), + Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3), + Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9), + Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2), + Row.of(1.0, 
Vectors.dense(0.9, 0.1), 1.0)); + + private static final double[] EXPECTED_DATA = + new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571}; + private static final double[] EXPECTED_DATA_M = + new double[] {0.8571428571428571, 0.9377705627705628, 0.8571428571428571}; + private static final double EXPECTED_DATA_W = 0.8911680911680911; + private static final double EPS = 1.0e-5; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(3); Review Comment: I tried changing this to `env.setParallelism(20);`, and these test cases failed. Could you please fix this problem and add a test similar to `LinearRegressionTest.testMoreSubtaskThanData`? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassificationevaluator/BinaryClassificationEvaluator.java: ########## @@ -0,0 +1,730 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.evaluation.binaryclassificationevaluator; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.api.scala.typeutils.Types; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamMap; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import 
org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Random; + +import static org.apache.flink.runtime.blob.BlobWriter.LOG; + +/** + * An Estimator which calculates the evaluation metrics for binary classification. The input data + * has columns rawPrediction, label and an optional weight column. The rawPrediction can be of type + * double (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of + * raw predictions, scores, or label probabilities). The output may contain different metrics which + * will be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams. + */ +public class BinaryClassificationEvaluator + implements AlgoOperator<BinaryClassificationEvaluator>, + BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100; + + public BinaryClassificationEvaluator() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Tuple3<Double, Boolean, Double>> evalData = + tEnv.toDataStream(inputs[0]) + .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol())); + final String boundaryRangeKey = "boundaryRange"; + final String partitionSummariesKey = "partitionSummaries"; + + DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(evalData), + Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)), + inputList -> { + DataStream input = inputList.get(0); + return input.map(new AppendTaskId(boundaryRangeKey)); + }); + + /* Repartition the evaluated data by range. */ + evalDataWithTaskId = + evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3); + + /* Sorts local data by score.*/ + evalData = + DataStreamUtils.mapPartition( + evalDataWithTaskId, + new MapPartitionFunction< + Tuple4<Double, Boolean, Double, Integer>, + Tuple3<Double, Boolean, Double>>() { + @Override + public void mapPartition( + Iterable<Tuple4<Double, Boolean, Double, Integer>> values, + Collector<Tuple3<Double, Boolean, Double>> out) { + List<Tuple3<Double, Boolean, Double>> bufferedData = + new LinkedList<>(); + for (Tuple4<Double, Boolean, Double, Integer> t4 : values) { + bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2)); + } + bufferedData.sort(Comparator.comparingDouble(o -> -o.f0)); + for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) { + out.collect(dataPoint); + } + } + }); + + /* Calculates the summary of local data. */ + DataStream<BinarySummary> partitionSummaries = + evalData.transform( + "reduceInEachPartition", + TypeInformation.of(BinarySummary.class), + new PartitionSummaryOperator()); + + /* Sorts global data. 
Output Tuple4 : <score, order, isPositive, weight> */ + DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(evalData), + Collections.singletonMap(partitionSummariesKey, partitionSummaries), + inputList -> { + DataStream input = inputList.get(0); + return input.flatMap(new CalcSampleOrders(partitionSummariesKey)); + }); + + dataWithOrders = + dataWithOrders.transform( + "appendMaxWaterMark", + dataWithOrders.getType(), + new AppendMaxWatermark(x -> x)); + + DataStream<double[]> localAreaUnderROCVariable = + dataWithOrders.transform( + "AccumulateMultiScore", + TypeInformation.of(double[].class), + new AccumulateMultiScoreOperator()); + + DataStream<double[]> middleAreaUnderROC = + DataStreamUtils.reduce( + localAreaUnderROCVariable, + (ReduceFunction<double[]>) + (t1, t2) -> { + t2[0] += t1[0]; + t2[1] += t1[1]; + t2[2] += t1[2]; + return t2; + }); + + DataStream<Double> areaUnderROC = + middleAreaUnderROC.map( + (MapFunction<double[], Double>) + value -> { + if (value[1] > 0 && value[2] > 0) { + return (value[0] - 1. 
* value[1] * (value[1] + 1) / 2) + / (value[1] * value[2]); + } else { + return Double.NaN; + } + }); + + Map<String, DataStream<?>> broadcastMap = new HashMap<>(); + broadcastMap.put(partitionSummariesKey, partitionSummaries); + broadcastMap.put(AREA_UNDER_ROC, areaUnderROC); + DataStream<BinaryMetrics> localMetrics = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(evalData), + broadcastMap, + inputList -> { + DataStream input = inputList.get(0); + return DataStreamUtils.mapPartition( + input, new CalcBinaryMetrics(partitionSummariesKey)); + }); + + DataStream<Map<String, Double>> metrics = + DataStreamUtils.mapPartition(localMetrics, new MergeMetrics()); + metrics.getTransformation().setParallelism(1); + + final String[] metricsNames = getMetricsNames(); + TypeInformation<?>[] metricTypes = new TypeInformation[metricsNames.length]; + Arrays.fill(metricTypes, Types.DOUBLE()); + RowTypeInfo outputTypeInfo = new RowTypeInfo(metricTypes, metricsNames); + + DataStream<Row> evalResult = + metrics.map( + (MapFunction<Map<String, Double>, Row>) + value -> { + Row ret = new Row(metricsNames.length); + for (int i = 0; i < metricsNames.length; ++i) { + ret.setField(i, value.get(metricsNames[i])); + } + return ret; + }, + outputTypeInfo); + return new Table[] {tEnv.fromDataStream(evalResult)}; + } + + /** Updates variables for calculating AreaUnderROC. 
*/ + private static class AccumulateMultiScoreOperator extends AbstractStreamOperator<double[]> + implements OneInputStreamOperator<Tuple4<Double, Long, Boolean, Double>, double[]>, + BoundedOneInput { + private ListState<double[]> accValueState; + double[] accValue; + double score; + + @Override + public void endInput() { + if (accValue != null) { + output.collect( + new StreamRecord<>( + new double[] { + accValue[0] / accValue[1] * accValue[2], + accValue[2], + accValue[3] + })); + } + } + + @Override + public void processElement( + StreamRecord<Tuple4<Double, Long, Boolean, Double>> streamRecord) { + Tuple4<Double, Long, Boolean, Double> t = streamRecord.getValue(); + if (accValue == null) { + accValue = new double[4]; + score = t.f0; + } else if (score != t.f0) { + output.collect( + new StreamRecord<>( + new double[] { + accValue[0] / accValue[1] * accValue[2], + accValue[2], + accValue[3] + })); + Arrays.fill(accValue, 0.0); + } + accValue[0] += t.f1; + accValue[1] += 1.0; + if (t.f2) { + accValue[2] += t.f3; + } else { + accValue[3] += t.f3; + } + } + + @Override + @SuppressWarnings("unchecked") + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + accValueState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "accValueState", TypeInformation.of(double[].class))); + accValue = + OperatorStateUtils.getUniqueElement(accValueState, "valueState").orElse(null); + } + + @Override + @SuppressWarnings("unchecked") + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + accValueState.clear(); + if (accValue != null) { + accValueState.add(accValue); + } + } + } + + private static class PartitionSummaryOperator extends AbstractStreamOperator<BinarySummary> + implements OneInputStreamOperator<Tuple3<Double, Boolean, Double>, BinarySummary>, + BoundedOneInput { + private ListState<BinarySummary> summaryState; + 
private BinarySummary summary; + + @Override + public void endInput() { + if (summary != null) { + output.collect(new StreamRecord<>(summary)); + } + } + + @Override + public void processElement(StreamRecord<Tuple3<Double, Boolean, Double>> streamRecord) { + updateBinarySummary(summary, streamRecord.getValue()); + } + + @Override + @SuppressWarnings("unchecked") + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + summaryState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "summaryState", + TypeInformation.of(BinarySummary.class))); + summary = + OperatorStateUtils.getUniqueElement(summaryState, "summaryState") + .orElse( + new BinarySummary( + getRuntimeContext().getIndexOfThisSubtask(), + -Double.MAX_VALUE, + 0, + 0)); + } + + @Override + @SuppressWarnings("unchecked") + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + summaryState.clear(); + if (summary != null) { + summaryState.add(summary); + } + } + } + + /** Merges the metrics calculated locally and output metrics data. 
*/ + private static class MergeMetrics + implements MapPartitionFunction<BinaryMetrics, Map<String, Double>> { + @Override + public void mapPartition( + Iterable<BinaryMetrics> values, Collector<Map<String, Double>> out) { + Iterator<BinaryMetrics> iter = values.iterator(); + BinaryMetrics reduceMetrics = iter.next(); + while (iter.hasNext()) { + reduceMetrics = reduceMetrics.merge(iter.next()); + } + Map<String, Double> map = new HashMap<>(); + map.put(AREA_UNDER_ROC, reduceMetrics.areaUnderROC); + map.put(AREA_UNDER_PR, reduceMetrics.areaUnderPR); + map.put(AREA_UNDER_LORENZ, reduceMetrics.areaUnderLorenz); + map.put(KS, reduceMetrics.ks); + out.collect(map); + } + } + + private static class CalcBinaryMetrics + extends RichMapPartitionFunction<Tuple3<Double, Boolean, Double>, BinaryMetrics> { + private final String partitionSummariesKey; + + public CalcBinaryMetrics(String partitionSummariesKey) { + this.partitionSummariesKey = partitionSummariesKey; + } + + @Override + public void mapPartition( + Iterable<Tuple3<Double, Boolean, Double>> iterable, + Collector<BinaryMetrics> collector) { + + List<BinarySummary> statistics = + getRuntimeContext().getBroadcastVariable(partitionSummariesKey); + long[] countValues = + reduceBinarySummary(statistics, getRuntimeContext().getIndexOfThisSubtask()); + + double areaUnderROC = + getRuntimeContext().<Double>getBroadcastVariable(AREA_UNDER_ROC).get(0); + long totalTrue = countValues[2]; + long totalFalse = countValues[3]; + if (totalTrue == 0) { + LOG.warn("There is no positive sample in data!"); + } + if (totalFalse == 0) { + LOG.warn("There is no negative sample in data!"); + } + + BinaryMetrics metrics = new BinaryMetrics(0L, areaUnderROC); + double[] tprFprPrecision = new double[4]; + for (Tuple3<Double, Boolean, Double> t3 : iterable) { + updateBinaryMetrics(t3, metrics, countValues, tprFprPrecision); + } + collector.collect(metrics); + } + } + + private static void updateBinaryMetrics( + Tuple3<Double, Boolean, Double> 
cur, + BinaryMetrics binaryMetrics, + long[] countValues, + double[] recordValues) { + if (binaryMetrics.count == 0) { + recordValues[0] = countValues[2] == 0 ? 1.0 : 1.0 * countValues[0] / countValues[2]; + recordValues[1] = countValues[3] == 0 ? 1.0 : 1.0 * countValues[1] / countValues[3]; + recordValues[2] = + countValues[0] + countValues[1] == 0 + ? 1.0 + : 1.0 * countValues[0] / (countValues[0] + countValues[1]); + recordValues[3] = + 1.0 * (countValues[0] + countValues[1]) / (countValues[2] + countValues[3]); + } + + binaryMetrics.count++; + if (cur.f1) { + countValues[0]++; + } else { + countValues[1]++; + } + + double tpr = countValues[2] == 0 ? 1.0 : 1.0 * countValues[0] / countValues[2]; + double fpr = countValues[3] == 0 ? 1.0 : 1.0 * countValues[1] / countValues[3]; + double precision = + countValues[0] + countValues[1] == 0 + ? 1.0 + : 1.0 * countValues[0] / (countValues[0] + countValues[1]); + double positiveRate = + 1.0 * (countValues[0] + countValues[1]) / (countValues[2] + countValues[3]); + + binaryMetrics.areaUnderLorenz += + ((positiveRate - recordValues[3]) * (tpr + recordValues[0]) / 2); + binaryMetrics.areaUnderPR += ((tpr - recordValues[0]) * (precision + recordValues[2]) / 2); + binaryMetrics.ks = Math.max(Math.abs(fpr - tpr), binaryMetrics.ks); + + recordValues[0] = tpr; + recordValues[1] = fpr; + recordValues[2] = precision; + recordValues[3] = positiveRate; + } + + /** + * For each sample, calculates its score order among all samples. The sample with minimum score + * has order 1, while the sample with maximum score has order samples. + * + * <p>Input is a dataset of tuple (score, is real positive, wight), output is a dataset of tuple Review Comment: "weight". ########## flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java: ########## @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.evaluation.binaryeval; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.evaluation.binaryclassificationevaluator.BinaryClassificationEvaluator; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** Tests {@link BinaryClassificationEvaluator}. 
*/ +public class BinaryClassificationEvaluatorTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamTableEnvironment tEnv; + private Table inputDataTable; + private Table inputDataTableWithMultiScore; + private Table inputDataTableWithWeight; + + private static final List<Row> INPUT_DATA = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.2, 0.8)), + Row.of(1.0, Vectors.dense(0.3, 0.7)), + Row.of(0.0, Vectors.dense(0.25, 0.75)), + Row.of(0.0, Vectors.dense(0.4, 0.6)), + Row.of(1.0, Vectors.dense(0.35, 0.65)), + Row.of(1.0, Vectors.dense(0.45, 0.55)), + Row.of(0.0, Vectors.dense(0.6, 0.4)), + Row.of(0.0, Vectors.dense(0.7, 0.3)), + Row.of(1.0, Vectors.dense(0.65, 0.35)), + Row.of(0.0, Vectors.dense(0.8, 0.2)), + Row.of(1.0, Vectors.dense(0.9, 0.1))); + + private static final List<Row> INPUT_DATA_WITH_MULTI_SCORE = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.25, 0.75)), + Row.of(0.0, Vectors.dense(0.4, 0.6)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.6, 0.4)), + Row.of(0.0, Vectors.dense(0.7, 0.3)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.8, 0.2)), + Row.of(1.0, Vectors.dense(0.9, 0.1))); + + private static final List<Row> INPUT_DATA_WITH_WEIGHT = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8), + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7), + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5), + Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2), + Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4), + Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3), + Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9), + Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2), + Row.of(1.0, 
Vectors.dense(0.9, 0.1), 1.0)); + + private static final double[] EXPECTED_DATA = + new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571}; + private static final double[] EXPECTED_DATA_M = + new double[] {0.8571428571428571, 0.9377705627705628, 0.8571428571428571}; + private static final double EXPECTED_DATA_W = 0.8911680911680911; + private static final double EPS = 1.0e-5; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(3); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + inputDataTable = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA)).as("label", "rawPrediction"); + inputDataTableWithMultiScore = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_MULTI_SCORE)) + .as("label", "rawPrediction"); + inputDataTableWithWeight = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_WEIGHT)) + .as("label", "rawPrediction", "weight"); + } + + @Test + public void testParam() { + BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator(); + assertEquals("label", binaryEval.getLabelCol()); + assertNull(binaryEval.getWeightCol()); + assertEquals("rawPrediction", binaryEval.getRawPredictionCol()); + assertArrayEquals( + new String[] {"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames()); + binaryEval + .setLabelCol("labelCol") + .setRawPredictionCol("raw") + .setMetricsNames("areaUnderROC") + .setWeightCol("weight"); + assertEquals("labelCol", binaryEval.getLabelCol()); + assertEquals("weight", binaryEval.getWeightCol()); + assertEquals("raw", binaryEval.getRawPredictionCol()); + assertArrayEquals(new String[] {"areaUnderROC"}, binaryEval.getMetricsNames()); + } + + @Test + public void 
testSaveLoadAndEvaluate() throws Exception { + BinaryClassificationEvaluator eval = + new BinaryClassificationEvaluator() + .setMetricsNames("areaUnderPR", "ks", "areaUnderROC"); Review Comment: `areaUnderLorenz` is not tested in these test cases yet. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java: ########## @@ -0,0 +1,783 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.evaluation.binaryeval; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.api.scala.typeutils.Types; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamMap; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import 
org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Random; + +import static org.apache.flink.runtime.blob.BlobWriter.LOG; + +/** + * Calculates the evaluation metrics for binary classification. The input data has columns + * rawPrediction, label and an optional weight column. The rawPrediction can be of type double + * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw + * predictions, scores, or label probabilities). The output may contain different metrics which will + * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams. + */ +public class BinaryClassificationEvaluator + implements AlgoOperator<BinaryClassificationEvaluator>, + BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100; + + public BinaryClassificationEvaluator() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... inputs) { Review Comment: Maybe we can dive deeper into performance optimization chances later after this PR. Please feel free to close this conversation for now. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java: ########## @@ -0,0 +1,783 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.evaluation.binaryeval; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.api.scala.typeutils.Types; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import 
org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamMap; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Random; + +import static org.apache.flink.runtime.blob.BlobWriter.LOG; + +/** + * Calculates the evaluation metrics for binary classification. The input data has columns + * rawPrediction, label and an optional weight column. The rawPrediction can be of type double + * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw + * predictions, scores, or label probabilities). The output may contain different metrics which will + * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams. + */ +public class BinaryClassificationEvaluator + implements AlgoOperator<BinaryClassificationEvaluator>, + BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100; Review Comment: Got it. 
########## flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java: ########## @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.evaluation.binaryeval; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.evaluation.binaryclassificationevaluator.BinaryClassificationEvaluator; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import 
static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** Tests {@link BinaryClassificationEvaluator}. */ +public class BinaryClassificationEvaluatorTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamTableEnvironment tEnv; + private Table inputDataTable; + private Table inputDataTableWithMultiScore; + private Table inputDataTableWithWeight; + + private static final List<Row> INPUT_DATA = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.2, 0.8)), + Row.of(1.0, Vectors.dense(0.3, 0.7)), + Row.of(0.0, Vectors.dense(0.25, 0.75)), + Row.of(0.0, Vectors.dense(0.4, 0.6)), + Row.of(1.0, Vectors.dense(0.35, 0.65)), + Row.of(1.0, Vectors.dense(0.45, 0.55)), + Row.of(0.0, Vectors.dense(0.6, 0.4)), + Row.of(0.0, Vectors.dense(0.7, 0.3)), + Row.of(1.0, Vectors.dense(0.65, 0.35)), + Row.of(0.0, Vectors.dense(0.8, 0.2)), + Row.of(1.0, Vectors.dense(0.9, 0.1))); + + private static final List<Row> INPUT_DATA_WITH_MULTI_SCORE = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.25, 0.75)), + Row.of(0.0, Vectors.dense(0.4, 0.6)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.6, 0.4)), + Row.of(0.0, Vectors.dense(0.7, 0.3)), + Row.of(1.0, Vectors.dense(0.1, 0.9)), + Row.of(0.0, Vectors.dense(0.8, 0.2)), + Row.of(1.0, Vectors.dense(0.9, 0.1))); + + private static final List<Row> INPUT_DATA_WITH_WEIGHT = + Arrays.asList( + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8), + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7), + Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5), + Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2), + Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4), + Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3), + Row.of(0.0, 
Vectors.dense(0.7, 0.3), 0.5), + Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9), + Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2), + Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0)); + + private static final double[] EXPECTED_DATA = + new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571}; + private static final double[] EXPECTED_DATA_M = + new double[] {0.8571428571428571, 0.9377705627705628, 0.8571428571428571}; + private static final double EXPECTED_DATA_W = 0.8911680911680911; + private static final double EPS = 1.0e-5; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(3); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + inputDataTable = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA)).as("label", "rawPrediction"); + inputDataTableWithMultiScore = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_MULTI_SCORE)) + .as("label", "rawPrediction"); + inputDataTableWithWeight = + tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_WEIGHT)) + .as("label", "rawPrediction", "weight"); + } + + @Test + public void testParam() { + BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator(); + assertEquals("label", binaryEval.getLabelCol()); + assertNull(binaryEval.getWeightCol()); + assertEquals("rawPrediction", binaryEval.getRawPredictionCol()); + assertArrayEquals( + new String[] {"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames()); + binaryEval + .setLabelCol("labelCol") + .setRawPredictionCol("raw") + .setMetricsNames("areaUnderROC") + .setWeightCol("weight"); + assertEquals("labelCol", binaryEval.getLabelCol()); + assertEquals("weight", binaryEval.getWeightCol()); + assertEquals("raw", 
binaryEval.getRawPredictionCol()); + assertArrayEquals(new String[] {"areaUnderROC"}, binaryEval.getMetricsNames()); + } + + @Test + public void testSaveLoadAndEvaluate() throws Exception { + BinaryClassificationEvaluator eval = + new BinaryClassificationEvaluator() + .setMetricsNames("areaUnderPR", "ks", "areaUnderROC"); + BinaryClassificationEvaluator loadedEval = + StageTestUtils.saveAndReload(tEnv, eval, tempFolder.newFolder().getAbsolutePath()); + Table evalResult = loadedEval.transform(inputDataTable)[0]; + DataStream<Row> dataStream = tEnv.toDataStream(evalResult); + assertArrayEquals( + new String[] {"areaUnderPR", "ks", "areaUnderROC"}, + evalResult.getResolvedSchema().getColumnNames().toArray()); + List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect()); Review Comment: nit: `List<Row> results = IteratorUtils.toList(evalResult.execute().collect());` is enough. `tEnv.toDataStream` is unnecessary. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
