yunfengzhou-hub commented on code in PR #86: URL: https://github.com/apache/flink-ml/pull/86#discussion_r864421566
########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorParams.java: ########## @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.evaluation.binaryeval; + +import org.apache.flink.ml.common.param.HasLabelCol; +import org.apache.flink.ml.common.param.HasRawPredictionCol; +import org.apache.flink.ml.common.param.HasWeightCol; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringArrayParam; + +/** + * Params of BinaryClassificationEvaluator. + * + * @param <T> The class type of this instance. + */ +public interface BinaryClassificationEvaluatorParams<T> + extends HasLabelCol<T>, HasRawPredictionCol<T>, HasWeightCol<T> { + String AREA_UNDER_ROC = "areaUnderROC"; + String AREA_UNDER_PR = "areaUnderPR"; + String AREA_UNDER_LORENZ = "areaUnderLorenz"; + String KS = "KS"; + + /** + * param for metric names in evaluation (supports 'areaUnderROC', 'areaUnderPR', 'KS' and Review Comment: nit: Javadocs should start with an uppercase letter.
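For example, the quoted description could start as follows (the trailing 'areaUnderLorenz' is my assumed completion of the quote, which the diff hunk truncates):

```java
/**
 * Param for metric names in evaluation (supports 'areaUnderROC', 'areaUnderPR', 'KS' and
 * 'areaUnderLorenz').
 */
```

########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java: ########## @@ -0,0 +1,783 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.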
+ */ + +package org.apache.flink.ml.evaluation.binaryeval; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.api.scala.typeutils.Types; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamMap; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Random; + +import static org.apache.flink.runtime.blob.BlobWriter.LOG; + +/** + * Calculates the evaluation metrics for binary classification. The input data has columns + * rawPrediction, label and an optional weight column. The rawPrediction can be of type double + * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw + * predictions, scores, or label probabilities). The output may contain different metrics which will + * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams. + */ +public class BinaryClassificationEvaluator + implements AlgoOperator<BinaryClassificationEvaluator>, + BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100; + + public BinaryClassificationEvaluator() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Tuple3<Double, Boolean, Double>> evalData = + tEnv.toDataStream(inputs[0]) + .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol())); + final String boundaryRangeKey = "boundaryRange"; + final String partitionSummariesKey = "partitionSummaries"; + + DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(evalData), + Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)), + inputList -> { + DataStream input = inputList.get(0); + return input.map(new AppendTaskId(boundaryRangeKey)); + }); + + /* Repartition the evaluated data by range. */ + evalDataWithTaskId = + evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3); + + /* Sorts local data by score.*/ + evalData = + DataStreamUtils.mapPartition( + evalDataWithTaskId, + new MapPartitionFunction< + Tuple4<Double, Boolean, Double, Integer>, + Tuple3<Double, Boolean, Double>>() { + @Override + public void mapPartition( + Iterable<Tuple4<Double, Boolean, Double, Integer>> values, + Collector<Tuple3<Double, Boolean, Double>> out) { + List<Tuple3<Double, Boolean, Double>> bufferedData = + new LinkedList<>(); + for (Tuple4<Double, Boolean, Double, Integer> t4 : values) { + bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2)); Review Comment: Storing every record of a stream in a `List` could easily cause OOM problems, so code blocks like this would have to be removed under Flink ML 2.1's release plan to optimize performance in the next few weeks. Could you please find a way to avoid storing or sorting all records entirely in memory? I will try to optimize the `mapPartition` infrastructure so that it is no longer a concern.
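For reference, the usual bounded-memory pattern is to spill sorted runs to disk and then do a k-way merge. A minimal, self-contained sketch of that pattern (the class name, run size and file format below are illustrative only and not part of this PR):

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.function.DoubleConsumer;

/** Illustrative external sort: keeps at most RUN_SIZE records in memory at a time. */
public class ExternalSortSketch {

    private static final int RUN_SIZE = 10_000; // spill threshold, tuned to available memory

    /** Emits the input values in descending order without holding them all in memory. */
    public static void sortDescending(Iterator<Double> input, DoubleConsumer out)
            throws IOException {
        List<Path> runs = new ArrayList<>();
        List<Double> buffer = new ArrayList<>(RUN_SIZE);
        while (input.hasNext()) {
            buffer.add(input.next());
            if (buffer.size() == RUN_SIZE) {
                runs.add(spillSortedRun(buffer));
                buffer.clear();
            }
        }
        if (!buffer.isEmpty()) {
            runs.add(spillSortedRun(buffer));
        }
        mergeRuns(runs, out);
    }

    /** Sorts the buffered chunk and writes it to a temporary file, one value per line. */
    private static Path spillSortedRun(List<Double> buffer) throws IOException {
        buffer.sort(Comparator.reverseOrder());
        Path run = Files.createTempFile("sort-run", ".txt");
        try (PrintWriter writer = new PrintWriter(Files.newBufferedWriter(run))) {
            for (double value : buffer) {
                writer.println(value);
            }
        }
        return run;
    }

    /** K-way merge: the heap holds one head element per run, so memory stays O(#runs). */
    private static void mergeRuns(List<Path> runs, DoubleConsumer out) throws IOException {
        List<BufferedReader> readers = new ArrayList<>();
        // Heap entries are {value, readerIndex}; the largest value is polled first.
        PriorityQueue<double[]> heap = new PriorityQueue<>((a, b) -> Double.compare(b[0], a[0]));
        for (Path run : runs) {
            BufferedReader reader = Files.newBufferedReader(run);
            readers.add(reader);
            String line = reader.readLine();
            if (line != null) {
                heap.add(new double[] {Double.parseDouble(line), readers.size() - 1});
            }
        }
        while (!heap.isEmpty()) {
            double[] head = heap.poll();
            out.accept(head[0]);
            String line = readers.get((int) head[1]).readLine();
            if (line != null) {
                heap.add(new double[] {Double.parseDouble(line), head[1]});
            }
        }
        for (BufferedReader reader : readers) {
            reader.close();
        }
    }
}
```

########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java: ########## @@ -0,0 +1,783 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.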
+ */ + +package org.apache.flink.ml.evaluation.binaryeval; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.api.scala.typeutils.Types; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamMap; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Random; + +import static org.apache.flink.runtime.blob.BlobWriter.LOG; + +/** + * Calculates the evaluation metrics for binary classification. The input data has columns + * rawPrediction, label and an optional weight column. The rawPrediction can be of type double + * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw + * predictions, scores, or label probabilities). The output may contain different metrics which will + * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams. + */ +public class BinaryClassificationEvaluator + implements AlgoOperator<BinaryClassificationEvaluator>, + BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100; + + public BinaryClassificationEvaluator() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... 
inputs) { Review Comment: I noticed that the implementation of this method produces a relatively complicated JobGraph. Is it possible to simplify the JobGraph's structure? For example, do we have to sort all records before proceeding to the next step? Is it a must for us to compute areaUnderROC before computing all other metrics? Can we avoid invoking `withBroadcast` and reduce operations multiple times? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java: ########## @@ -0,0 +1,783 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.evaluation.binaryeval; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.api.scala.typeutils.Types; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamMap; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays;
+import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Random; + +import static org.apache.flink.runtime.blob.BlobWriter.LOG; + +/** + * Calculates the evaluation metrics for binary classification. The input data has columns + * rawPrediction, label and an optional weight column. The rawPrediction can be of type double + * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw + * predictions, scores, or label probabilities). The output may contain different metrics which will + * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams. + */ +public class BinaryClassificationEvaluator + implements AlgoOperator<BinaryClassificationEvaluator>, + BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100; + + public BinaryClassificationEvaluator() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Tuple3<Double, Boolean, Double>> evalData = + tEnv.toDataStream(inputs[0]) + .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol())); + final String boundaryRangeKey = "boundaryRange"; + final String partitionSummariesKey = "partitionSummaries"; + + DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(evalData), + Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)), + inputList -> { + DataStream input = inputList.get(0); + return input.map(new AppendTaskId(boundaryRangeKey)); + }); + + /* Repartition the evaluated data by range. */ + evalDataWithTaskId = + evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3); + + /* Sorts local data by score.*/ + evalData = + DataStreamUtils.mapPartition( + evalDataWithTaskId, + new MapPartitionFunction< + Tuple4<Double, Boolean, Double, Integer>, + Tuple3<Double, Boolean, Double>>() { + @Override + public void mapPartition( + Iterable<Tuple4<Double, Boolean, Double, Integer>> values, + Collector<Tuple3<Double, Boolean, Double>> out) { + List<Tuple3<Double, Boolean, Double>> bufferedData = + new LinkedList<>(); + for (Tuple4<Double, Boolean, Double, Integer> t4 : values) { + bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2)); + } + bufferedData.sort(Comparator.comparingDouble(o -> -o.f0)); + for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) { + out.collect(dataPoint); + } + } + }); + + /* Calculates the summary of local data. */ + DataStream<BinarySummary> partitionSummaries = + evalData.transform( + "reduceInEachPartition", + TypeInformation.of(BinarySummary.class), + new PartitionSummaryOperator()); + + /* Sorts global data. 
Output Tuple4 : <score, order, isPositive, weight> */ + DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(evalData), + Collections.singletonMap(partitionSummariesKey, partitionSummaries), + inputList -> { + DataStream input = inputList.get(0); + return input.flatMap(new CalcSampleOrders(partitionSummariesKey)); + }); + + dataWithOrders = + dataWithOrders.transform( + "appendMaxWaterMark", + dataWithOrders.getType(), + new AppendMaxWatermark(x -> x)); + + DataStream<double[]> localAucVariable = + dataWithOrders.transform( + "AccumulateMultiScore", + TypeInformation.of(double[].class), + new AccumulateMultiScoreOperator()); + + DataStream<double[]> middleAreaUnderROC = + localAucVariable + .transform( + "calcLocalAucValues", + TypeInformation.of(double[].class), + new AucOperator()) + .transform( + "calcGlobalAucValues", + TypeInformation.of(double[].class), + new AucOperator()) + .setParallelism(1); + + DataStream<Double> areaUnderROC = + middleAreaUnderROC.map( + (MapFunction<double[], Double>) + value -> { + if (value[1] > 0 && value[2] > 0) { + return (value[0] - 1. * value[1] * (value[1] + 1) / 2) + / (value[1] * value[2]); + } else { + return Double.NaN; + } + }); + + Map<String, DataStream<?>> broadcastMap = new HashMap<>(); + broadcastMap.put(partitionSummariesKey, partitionSummaries); + broadcastMap.put(AREA_UNDER_ROC, areaUnderROC); + DataStream<BinaryMetrics> localMetrics = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(evalData), + broadcastMap, + inputList -> { + DataStream input = inputList.get(0); + return DataStreamUtils.mapPartition( + input, new CalcBinaryMetrics(partitionSummariesKey)); + }); + + DataStream<Map<String, Double>> metrics = + DataStreamUtils.mapPartition(localMetrics, new MergeMetrics()); + metrics.getTransformation().setParallelism(1); + + final String[] metricsNames = getMetricsNames(); + TypeInformation<?>[] metricTypes = new TypeInformation[metricsNames.length]; + for (int i = 0; i < metricsNames.length; ++i) { + metricTypes[i] = Types.DOUBLE(); + } + RowTypeInfo outputTypeInfo = new RowTypeInfo(metricTypes, metricsNames); + + DataStream<Row> evalResult = + metrics.map( + (MapFunction<Map<String, Double>, Row>) + value -> { + Row ret = new Row(metricsNames.length); + for (int i = 0; i < metricsNames.length; ++i) { + ret.setField(i, value.get(metricsNames[i])); + } + return ret; + }, + outputTypeInfo); + return new Table[] {tEnv.fromDataStream(evalResult)}; + } + + /** Updates variables for calculating Auc. 
*/ + private static class AucOperator extends AbstractStreamOperator<double[]> + implements OneInputStreamOperator<double[], double[]>, BoundedOneInput { + private ListState<double[]> valueState; + private double[] value; + + @Override + public void endInput() { + if (value != null) { + output.collect(new StreamRecord<>(value)); + } + } + + @Override + public void processElement(StreamRecord<double[]> streamRecord) { + double[] tmpValues = streamRecord.getValue(); + value[0] += tmpValues[0]; + value[1] += tmpValues[1]; + value[2] += tmpValues[2]; + } + + @Override + @SuppressWarnings("unchecked") + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + valueState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "valueState", TypeInformation.of(double[].class))); + value = + OperatorStateUtils.getUniqueElement(valueState, "valueState") + .orElse(new double[3]); + } + + @Override + @SuppressWarnings("unchecked") + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + valueState.clear(); + if (value != null) { + valueState.add(value); + } + } + } + + /** Updates variables for calculating Auc. */ + private static class AccumulateMultiScoreOperator extends AbstractStreamOperator<double[]> + implements OneInputStreamOperator<Tuple4<Double, Long, Boolean, Double>, double[]>, + BoundedOneInput { + private ListState<double[]> accValueState; + double[] accValue; + double score; + + @Override + public void endInput() { + if (accValue != null) { + output.collect( + new StreamRecord<>( + new double[] { + accValue[0] / accValue[1] * accValue[2], + accValue[2], + accValue[3] + })); + } + } + + @Override + public void processElement( + StreamRecord<Tuple4<Double, Long, Boolean, Double>> streamRecord) { + Tuple4<Double, Long, Boolean, Double> t = streamRecord.getValue(); + if (accValue == null) { + accValue = new double[4]; + score = t.f0; + } else if (score != t.f0) { + output.collect( + new StreamRecord<>( + new double[] { + accValue[0] / accValue[1] * accValue[2], + accValue[2], + accValue[3] + })); + Arrays.fill(accValue, 0.0); + } + accValue[0] += t.f1; + accValue[1] += 1.0; + if (t.f2) { + accValue[2] += t.f3; + } else { + accValue[3] += t.f3; + } + } + + @Override + @SuppressWarnings("unchecked") + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + accValueState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "accValueState", TypeInformation.of(double[].class))); + accValue = + OperatorStateUtils.getUniqueElement(accValueState, "valueState").orElse(null); + } + + @Override + @SuppressWarnings("unchecked") + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + accValueState.clear(); + if (accValue != null) { + accValueState.add(accValue); + } + } + } + + private static class PartitionSummaryOperator extends AbstractStreamOperator<BinarySummary> + implements OneInputStreamOperator<Tuple3<Double, Boolean, Double>, BinarySummary>, + BoundedOneInput { + private ListState<BinarySummary> summaryState; + private BinarySummary summary; + + @Override + public void endInput() { + if (summary != null) { + output.collect(new StreamRecord<>(summary)); + } + } + + @Override + public void processElement(StreamRecord<Tuple3<Double, Boolean, Double>> streamRecord) { + updateBinarySummary(summary, 
streamRecord.getValue()); + } + + @Override + @SuppressWarnings("unchecked") + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + summaryState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "summaryState", + TypeInformation.of(BinarySummary.class))); + summary = + OperatorStateUtils.getUniqueElement(summaryState, "summaryState") + .orElse( + new BinarySummary( + getRuntimeContext().getIndexOfThisSubtask(), + -Double.MAX_VALUE, + 0, + 0)); + } + + @Override + @SuppressWarnings("unchecked") + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + summaryState.clear(); + if (summary != null) { + summaryState.add(summary); + } + } + } + + /** Merges the metrics calculated locally and output metrics data. */ + private static class MergeMetrics + implements MapPartitionFunction<BinaryMetrics, Map<String, Double>> { + private static final long serialVersionUID = 463407033215369847L; + + @Override + public void mapPartition( + Iterable<BinaryMetrics> values, Collector<Map<String, Double>> out) { + Iterator<BinaryMetrics> iter = values.iterator(); + BinaryMetrics reduceMetrics = iter.next(); + while (iter.hasNext()) { + reduceMetrics = reduceMetrics.merge(iter.next()); + } + Map<String, Double> map = new HashMap<>(); Review Comment: Would it be better to use `Map<String, Double>` from the beginning, instead of introducing a new `BinaryMetrics` class?
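A hypothetical sketch of what that could look like; note it assumes every metric merges by plain summation, which does not hold for ratio metrics such as areaUnderROC, so counts would still need to be carried alongside:

```java
/** Hypothetical variant that merges per-partition metric maps directly. */
private static class MergeMetricMaps
        implements MapPartitionFunction<Map<String, Double>, Map<String, Double>> {

    @Override
    public void mapPartition(
            Iterable<Map<String, Double>> values, Collector<Map<String, Double>> out) {
        Map<String, Double> merged = new HashMap<>();
        for (Map<String, Double> partitionMetrics : values) {
            // Summation is only valid for additive metrics.
            partitionMetrics.forEach((name, value) -> merged.merge(name, value, Double::sum));
        }
        out.collect(merged);
    }
}
```

########## flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java: ########## @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.evaluation.binaryeval; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** Tests {@link BinaryClassificationEvaluator}.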
 */ +public class BinaryClassificationEvaluatorTest { Review Comment: Can we add more test cases covering corner cases? For example, when all labels are 1.0 or 0.0, or when the label values are independent of the rawPrediction results. For these corner cases we can generate test data as follows:

```java
final List<Row> INPUT_DATA = new ArrayList<>();
for (double i = 0.1; i < 1.0; i += 0.001) {
    INPUT_DATA.add(Row.of(i < 0.9 ? 1.0 : 0.0, Vectors.dense(i, 1.0 - i)));
}
```

########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java: ########## @@ -0,0 +1,783 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.evaluation.binaryeval; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.api.scala.typeutils.Types; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamMap; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.io.Serializable;
+import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Random; + +import static org.apache.flink.runtime.blob.BlobWriter.LOG; + +/** + * Calculates the evaluation metrics for binary classification. The input data has columns + * rawPrediction, label and an optional weight column. The rawPrediction can be of type double + * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw + * predictions, scores, or label probabilities). The output may contain different metrics which will + * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams. + */ +public class BinaryClassificationEvaluator + implements AlgoOperator<BinaryClassificationEvaluator>, + BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100; Review Comment: According to our offline discussion, this algorithm would not support `numBins`, given that it would cause large deviations as pointed out in [tensorflow's issue](https://github.com/tensorflow/tensorflow/issues/14834). But as far as I can see, `NUM_SAMPLE_FOR_RANGE_PARTITION` serves a similar function and could introduce the same kind of error. Could you please illustrate the difference between `numBins` and this variable? Besides, why should we choose a fixed `100` as its value, given that the scale of the input data varies? Should we also add tests where the number of training records is larger than `the number of samples * the parallelism of the subtasks`?
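For instance, a hypothetical generator for such a test could look like the following (`4` stands in for the test parallelism):

```java
// More records than NUM_SAMPLE_FOR_RANGE_PARTITION (100) * parallelism (4),
// so the range boundaries are estimated from only a small sample of the input.
final List<Row> largeInputData = new ArrayList<>();
final Random random = new Random(0);
for (int i = 0; i < 100 * 100 * 4; i++) {
    double score = random.nextDouble();
    largeInputData.add(Row.of(score > 0.5 ? 1.0 : 0.0, Vectors.dense(score, 1.0 - score)));
}
```

########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java: ########## @@ -0,0 +1,783 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.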
+ */ + +package org.apache.flink.ml.evaluation.binaryeval; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.api.scala.typeutils.Types; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamMap; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Random; + +import static org.apache.flink.runtime.blob.BlobWriter.LOG; + +/** + * Calculates the evaluation metrics for binary classification. The input data has columns Review Comment: The first sentence of the Javadoc for an Estimator/AlgoOperator class usually starts with a noun. It would be better to follow the practice of existing Javadocs.
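For example (the exact wording is only a suggestion):

```java
/**
 * An AlgoOperator which calculates the evaluation metrics for binary classification.
 */
```

########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java: ########## @@ -0,0 +1,783 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.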
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.evaluation.binaryeval; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.api.scala.typeutils.Types; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamMap; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Random; + +import static org.apache.flink.runtime.blob.BlobWriter.LOG; + +/** + * Calculates the evaluation metrics for binary classification. The input data has columns + * rawPrediction, label and an optional weight column. The rawPrediction can be of type double + * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw + * predictions, scores, or label probabilities). The output may contain different metrics which will + * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams. 
+ */ +public class BinaryClassificationEvaluator + implements AlgoOperator<BinaryClassificationEvaluator>, + BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100; + + public BinaryClassificationEvaluator() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Tuple3<Double, Boolean, Double>> evalData = + tEnv.toDataStream(inputs[0]) + .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol())); + final String boundaryRangeKey = "boundaryRange"; + final String partitionSummariesKey = "partitionSummaries"; + + DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(evalData), + Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)), + inputList -> { + DataStream input = inputList.get(0); + return input.map(new AppendTaskId(boundaryRangeKey)); + }); + + /* Repartition the evaluated data by range. */ + evalDataWithTaskId = + evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3); + + /* Sorts local data by score.*/ + evalData = + DataStreamUtils.mapPartition( + evalDataWithTaskId, + new MapPartitionFunction< + Tuple4<Double, Boolean, Double, Integer>, + Tuple3<Double, Boolean, Double>>() { + @Override + public void mapPartition( + Iterable<Tuple4<Double, Boolean, Double, Integer>> values, + Collector<Tuple3<Double, Boolean, Double>> out) { + List<Tuple3<Double, Boolean, Double>> bufferedData = + new LinkedList<>(); + for (Tuple4<Double, Boolean, Double, Integer> t4 : values) { + bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2)); + } + bufferedData.sort(Comparator.comparingDouble(o -> -o.f0)); + for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) { + out.collect(dataPoint); + } + } + }); + + /* Calculates the summary of local data. */ + DataStream<BinarySummary> partitionSummaries = + evalData.transform( + "reduceInEachPartition", + TypeInformation.of(BinarySummary.class), + new PartitionSummaryOperator()); + + /* Sorts global data. 
Output Tuple4 : <score, order, isPositive, weight> */ + DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(evalData), + Collections.singletonMap(partitionSummariesKey, partitionSummaries), + inputList -> { + DataStream input = inputList.get(0); + return input.flatMap(new CalcSampleOrders(partitionSummariesKey)); + }); + + dataWithOrders = + dataWithOrders.transform( + "appendMaxWaterMark", + dataWithOrders.getType(), + new AppendMaxWatermark(x -> x)); + + DataStream<double[]> localAucVariable = + dataWithOrders.transform( + "AccumulateMultiScore", + TypeInformation.of(double[].class), + new AccumulateMultiScoreOperator()); + + DataStream<double[]> middleAreaUnderROC = + localAucVariable + .transform( + "calcLocalAucValues", + TypeInformation.of(double[].class), + new AucOperator()) + .transform( + "calcGlobalAucValues", + TypeInformation.of(double[].class), + new AucOperator()) + .setParallelism(1); + + DataStream<Double> areaUnderROC = + middleAreaUnderROC.map( + (MapFunction<double[], Double>) + value -> { + if (value[1] > 0 && value[2] > 0) { + return (value[0] - 1. * value[1] * (value[1] + 1) / 2) Review Comment: It might be hard for reviewers to remember the meaning of `value[0]` and `value[1]`. To improve the readability of this PR, do you think it would be better to add more Javadocs explaining the meaning of each element in the double arrays and Tuple objects, or to give each such element a meaningful variable name?
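For example, with named index constants the formula above would read directly as the Mann-Whitney U statistic normalized into AUC; the names below are my guess at the intended semantics, so please correct them if they are wrong:

```java
// Assumed layout of the double[3] accumulator, inferred from the AUC formula.
private static final int SUM_OF_POSITIVE_RANKS = 0;
private static final int WEIGHTED_NUM_POSITIVE = 1;
private static final int WEIGHTED_NUM_NEGATIVE = 2;

double rankSum = value[SUM_OF_POSITIVE_RANKS];
double numPositive = value[WEIGHTED_NUM_POSITIVE];
double numNegative = value[WEIGHTED_NUM_NEGATIVE];
// Mann-Whitney U statistic divided by the number of positive-negative pairs.
return (rankSum - numPositive * (numPositive + 1) / 2) / (numPositive * numNegative);
```

########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java: ########## @@ -0,0 +1,783 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.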
+ */ + +package org.apache.flink.ml.evaluation.binaryeval; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.api.scala.typeutils.Types; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamMap; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Random; + +import static org.apache.flink.runtime.blob.BlobWriter.LOG; + +/** + * Calculates the evaluation metrics for binary classification. The input data has columns + * rawPrediction, label and an optional weight column. The rawPrediction can be of type double + * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw + * predictions, scores, or label probabilities). The output may contain different metrics which will + * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams. + */ +public class BinaryClassificationEvaluator + implements AlgoOperator<BinaryClassificationEvaluator>, + BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100; + + public BinaryClassificationEvaluator() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Tuple3<Double, Boolean, Double>> evalData = + tEnv.toDataStream(inputs[0]) + .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol())); + final String boundaryRangeKey = "boundaryRange"; + final String partitionSummariesKey = "partitionSummaries"; Review Comment: nit: It might be better to define these broadcast keys as static finals of the class, instead of defining them as local variables and passing them to the operators' constructors.
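That is, something like:

```java
// Class-level constants instead of per-call local variables.
private static final String BOUNDARY_RANGE_KEY = "boundaryRange";
private static final String PARTITION_SUMMARIES_KEY = "partitionSummaries";
```

########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java: ########## @@ -0,0 +1,783 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.evaluation.binaryeval; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.api.scala.typeutils.Types; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamMap; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import 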
########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java: ##########
@@ -0,0 +1,783 @@
+        DataStream<Map<String, Double>> metrics =
+                DataStreamUtils.mapPartition(localMetrics, new MergeMetrics());
+        metrics.getTransformation().setParallelism(1);
+
+        final String[] metricsNames = getMetricsNames();
+        TypeInformation<?>[] metricTypes = new TypeInformation[metricsNames.length];
+        for (int i = 0; i < metricsNames.length; ++i) {
+            metricTypes[i] = Types.DOUBLE();

Review Comment:
   nit: `ArrayUtils.addAll` might be better.
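   If the goal is just to avoid the hand-written loop, `java.util.Arrays.fill` is another option (a sketch; note it uses the Java-side `Types.DOUBLE` constant from `org.apache.flink.api.common.typeinfo.Types` rather than the Scala `Types.DOUBLE()` the PR currently imports):

   ```java
   import org.apache.flink.api.common.typeinfo.TypeInformation;
   import org.apache.flink.api.common.typeinfo.Types;

   import java.util.Arrays;

   public class MetricTypesExample {
       public static void main(String[] args) {
           String[] metricsNames = {"areaUnderROC", "KS"}; // illustrative values
           TypeInformation<?>[] metricTypes = new TypeInformation[metricsNames.length];
           // Every metric column has the same type, so fill the array in one call.
           Arrays.fill(metricTypes, Types.DOUBLE);
       }
   }
   ```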
########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java: ##########
@@ -0,0 +1,783 @@
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(partitionSummariesKey, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders(partitionSummariesKey));
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));

Review Comment:
   The watermark seems to be unused in the rest of this method. We can remove code like this.
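   With that stage dropped, the ordered stream can feed the accumulating operator directly, e.g. (the PR's own next statement, minus the watermark step):

   ```java
   DataStream<double[]> localAucVariable =
           dataWithOrders.transform(
                   "AccumulateMultiScore",
                   TypeInformation.of(double[].class),
                   new AccumulateMultiScoreOperator());
   ```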
########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java: ##########
@@ -0,0 +1,783 @@
+        DataStream<double[]> localAucVariable =
+                dataWithOrders.transform(
+                        "AccumulateMultiScore",
+                        TypeInformation.of(double[].class),
+                        new AccumulateMultiScoreOperator());
+
+        DataStream<double[]> middleAreaUnderROC =
+                localAucVariable
+                        .transform(
+                                "calcLocalAucValues",
+                                TypeInformation.of(double[].class),
+                                new AucOperator())
+                        .transform(
+                                "calcGlobalAucValues",
+                                TypeInformation.of(double[].class),
+                                new AucOperator())
+                        .setParallelism(1);

Review Comment:
   In the LinearRegression PR, `DataStreamUtils.reduce()` is introduced to support reduce operations like this one. We can refer to that PR to see how that infrastructure code can be shared.
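   Roughly like this (a sketch; it assumes the `DataStreamUtils.reduce(DataStream<T>, ReduceFunction<T>)` helper proposed in that PR):

   ```java
   // Element-wise sum of the three partial AUC statistics; this would replace
   // the two chained AucOperator transforms and the setParallelism(1) call.
   DataStream<double[]> middleAreaUnderROC =
           DataStreamUtils.reduce(
                   localAucVariable,
                   (a, b) -> new double[] {a[0] + b[0], a[1] + b[1], a[2] + b[2]});
   ```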
########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java: ##########
@@ -0,0 +1,783 @@
+    /** Updates variables for calculating Auc. */
+    private static class AucOperator extends AbstractStreamOperator<double[]>
+            implements OneInputStreamOperator<double[], double[]>, BoundedOneInput {
+        private ListState<double[]> valueState;
+        private double[] value;
+
+        @Override
+        public void endInput() {
+            if (value != null) {

Review Comment:
   nit: `value` is always created in `initializeState`, so it cannot be `null` here.
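   i.e. the method can collect unconditionally (sketch):

   ```java
   @Override
   public void endInput() {
       // value is initialized in initializeState() with a default of new double[3],
       // so it can never be null by the time endInput() runs.
       output.collect(new StreamRecord<>(value));
   }
   ```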
########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java: ##########
@@ -0,0 +1,783 @@
+    /** Updates variables for calculating Auc. */

Review Comment:
   It might be better to improve the JavaDocs of the private static classes in this class. For example, I'm not sure how this operator "updates" variables, what the "variables" to be updated are, or what "Auc" means. The same applies to the other JavaDocs. Besides, we could reorder the static classes and methods to match other classes, or to match the order in which they appear in `transform()`.
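   For instance, something along these lines for `AucOperator` (illustrative wording only):

   ```java
   /**
    * Sums the partial AUC statistics <sum of positive-sample ranks, positive
    * weight sum, negative weight sum> emitted by AccumulateMultiScoreOperator,
    * so that the area under the ROC curve can be computed from the global totals.
    */
   ```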
########## flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java: ##########
@@ -0,0 +1,783 @@
+    /**
+     * For each sample, calculates its score order among all samples. The sample with minimum score
+     * has order 1, while the sample with maximum score has order samples.
+     *
+     * <p>Input is a dataset of tuple (score, is real positive, wight), output is a dataset of tuple
+     * (score, order, is real positive, weight).
+     */
+    private static class CalcSampleOrders
+            extends RichFlatMapFunction<
+                    Tuple3<Double, Boolean, Double>, Tuple4<Double, Long, Boolean, Double>> {
+        private static final long serialVersionUID = 3047511137846831576L;

Review Comment:
   Is `serialVersionUID` necessary for this operator? Existing operator classes in other algorithms do not need `serialVersionUID`, so I think it is OK to remove this.

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
