This is an automated email from the ASF dual-hosted git repository.
zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 5619c3b8 [FLINK-32889] Fix calculation of weighted areaUnderROC and
areaUnderPRC in BinaryClassificationEvaluator
5619c3b8 is described below
commit 5619c3b8591b220e78a0a792c1f940e06149c8f0
Author: Fan Hong <[email protected]>
AuthorDate: Thu Aug 24 14:40:12 2023 +0800
[FLINK-32889] Fix calculation of weighted areaUnderROC and areaUnderPRC in
BinaryClassificationEvaluator
This closes #252.
---
.../BinaryClassificationEvaluator.java | 298 +++++----------------
.../BinaryClassificationEvaluatorTest.java | 17 +-
.../evaluation/tests/tests_binaryclassification.py | 10 +-
3 files changed, 86 insertions(+), 239 deletions(-)
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java
index d74e40b2..051d7513 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java
@@ -20,8 +20,6 @@ package org.apache.flink.ml.evaluation.binaryclassification;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
-import org.apache.flink.api.common.functions.ReduceFunction;
-import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.state.ListState;
@@ -64,7 +62,6 @@ import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
-import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
@@ -124,7 +121,7 @@ public class BinaryClassificationEvaluator
Iterable<Tuple4<Double, Boolean, Double,
Integer>> values,
Collector<Tuple3<Double, Boolean, Double>>
out) {
List<Tuple3<Double, Boolean, Double>>
bufferedData =
- new LinkedList<>();
+ new ArrayList<>();
for (Tuple4<Double, Boolean, Double, Integer>
t4 : values) {
bufferedData.add(Tuple3.of(t4.f0, t4.f1,
t4.f2));
}
@@ -142,48 +139,8 @@ public class BinaryClassificationEvaluator
TypeInformation.of(BinarySummary.class),
new PartitionSummaryOperator());
- /* Sorts global data. Output Tuple4 : <score, order, isPositive,
weight>. */
- DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
- BroadcastUtils.withBroadcastStream(
- Collections.singletonList(sortEvalData),
- Collections.singletonMap(partitionSummariesKey,
partitionSummaries),
- inputList -> {
- DataStream input = inputList.get(0);
- return input.flatMap(new
CalcSampleOrders(partitionSummariesKey));
- });
-
- DataStream<double[]> localAreaUnderROCVariable =
- dataWithOrders.transform(
- "AccumulateMultiScore",
- TypeInformation.of(double[].class),
- new AccumulateMultiScoreOperator());
-
- DataStream<double[]> middleAreaUnderROC =
- DataStreamUtils.reduce(
- localAreaUnderROCVariable,
- (ReduceFunction<double[]>)
- (t1, t2) -> {
- t2[0] += t1[0];
- t2[1] += t1[1];
- t2[2] += t1[2];
- return t2;
- });
-
- DataStream<Double> areaUnderROC =
- middleAreaUnderROC.map(
- (MapFunction<double[], Double>)
- value -> {
- if (value[1] > 0 && value[2] > 0) {
- return (value[0] - 1. * value[1] *
(value[1] + 1) / 2)
- / (value[1] * value[2]);
- } else {
- return Double.NaN;
- }
- });
-
Map<String, DataStream<?>> broadcastMap = new HashMap<>();
broadcastMap.put(partitionSummariesKey, partitionSummaries);
- broadcastMap.put(AREA_UNDER_ROC, areaUnderROC);
DataStream<BinaryMetrics> localMetrics =
BroadcastUtils.withBroadcastStream(
Collections.singletonList(sortEvalData),
@@ -218,89 +175,6 @@ public class BinaryClassificationEvaluator
return new Table[] {tEnv.fromDataStream(evalResult)};
}
- /** Updates variables for calculating AreaUnderROC. */
- private static class AccumulateMultiScoreOperator extends
AbstractStreamOperator<double[]>
- implements OneInputStreamOperator<Tuple4<Double, Long, Boolean,
Double>, double[]>,
- BoundedOneInput {
- private ListState<double[]> accValueState;
- private ListState<Double> scoreState;
-
- double[] accValue;
- double score;
-
- @Override
- public void endInput() {
- if (accValue != null) {
- output.collect(
- new StreamRecord<>(
- new double[] {
- accValue[0] / accValue[1] * accValue[2],
- accValue[2],
- accValue[3]
- }));
- }
- }
-
- @Override
- public void processElement(
- StreamRecord<Tuple4<Double, Long, Boolean, Double>>
streamRecord) {
- Tuple4<Double, Long, Boolean, Double> t = streamRecord.getValue();
- if (accValue == null) {
- accValue = new double[4];
- score = t.f0;
- } else if (score != t.f0) {
- output.collect(
- new StreamRecord<>(
- new double[] {
- accValue[0] / accValue[1] * accValue[2],
- accValue[2],
- accValue[3]
- }));
- Arrays.fill(accValue, 0.0);
- }
- accValue[0] += t.f1;
- accValue[1] += 1.0;
- if (t.f2) {
- accValue[2] += t.f3;
- } else {
- accValue[3] += t.f3;
- }
- }
-
- @Override
- @SuppressWarnings("unchecked")
- public void initializeState(StateInitializationContext context) throws
Exception {
- super.initializeState(context);
- accValueState =
- context.getOperatorStateStore()
- .getListState(
- new ListStateDescriptor<>(
- "accValueState",
TypeInformation.of(double[].class)));
- accValue =
- OperatorStateUtils.getUniqueElement(accValueState,
"accValueState")
- .orElse(null);
-
- scoreState =
- context.getOperatorStateStore()
- .getListState(
- new ListStateDescriptor<>(
- "scoreState",
TypeInformation.of(Double.class)));
- score = OperatorStateUtils.getUniqueElement(scoreState,
"scoreState").orElse(0.0);
- }
-
- @Override
- @SuppressWarnings("unchecked")
- public void snapshotState(StateSnapshotContext context) throws
Exception {
- super.snapshotState(context);
- accValueState.clear();
- scoreState.clear();
- if (accValue != null) {
- accValueState.add(accValue);
- scoreState.add(score);
- }
- }
- }
-
private static class PartitionSummaryOperator extends
AbstractStreamOperator<BinarySummary>
implements OneInputStreamOperator<Tuple3<Double, Boolean, Double>,
BinarySummary>,
BoundedOneInput {
@@ -320,7 +194,6 @@ public class BinaryClassificationEvaluator
}
@Override
- @SuppressWarnings("unchecked")
public void initializeState(StateInitializationContext context) throws
Exception {
super.initializeState(context);
summaryState =
@@ -340,7 +213,6 @@ public class BinaryClassificationEvaluator
}
@Override
- @SuppressWarnings("unchecked")
public void snapshotState(StateSnapshotContext context) throws
Exception {
super.snapshotState(context);
summaryState.clear();
@@ -385,24 +257,24 @@ public class BinaryClassificationEvaluator
List<BinarySummary> statistics =
getRuntimeContext().getBroadcastVariable(partitionSummariesKey);
- long[] countValues =
+ double[] accWeights =
reduceBinarySummary(statistics,
getRuntimeContext().getIndexOfThisSubtask());
- double areaUnderROC =
-
getRuntimeContext().<Double>getBroadcastVariable(AREA_UNDER_ROC).get(0);
- long totalTrue = countValues[2];
- long totalFalse = countValues[3];
- if (totalTrue == 0) {
- LOG.warn("There is no positive sample in data!");
+ double totalSumWeightsPos = accWeights[2];
+ double totalSumWeightsNeg = accWeights[3];
+ if (totalSumWeightsPos == 0) {
+ LOG.warn("There is no positive samples in data!");
}
- if (totalFalse == 0) {
- LOG.warn("There is no negative sample in data!");
+ if (totalSumWeightsNeg == 0) {
+ LOG.warn("There is no negative samples in data!");
}
- BinaryMetrics metrics = new BinaryMetrics(0L, areaUnderROC);
+ BinaryMetrics metrics = new BinaryMetrics(0);
+ // Stores values of TPR, FPR, Precision, and PR calculated from
samples with scores
+ // ranging from the maximum to the current one.
double[] tprFprPrecision = new double[4];
for (Tuple3<Double, Boolean, Double> t3 : iterable) {
- updateBinaryMetrics(t3, metrics, countValues, tprFprPrecision);
+ updateBinaryMetrics(t3, metrics, accWeights, tprFprPrecision);
}
collector.collect(metrics);
}
@@ -411,35 +283,36 @@ public class BinaryClassificationEvaluator
private static void updateBinaryMetrics(
Tuple3<Double, Boolean, Double> cur,
BinaryMetrics binaryMetrics,
- long[] countValues,
+ double[] accWeights,
double[] recordValues) {
- if (binaryMetrics.count == 0) {
- recordValues[0] = countValues[2] == 0 ? 1.0 : 1.0 * countValues[0]
/ countValues[2];
- recordValues[1] = countValues[3] == 0 ? 1.0 : 1.0 * countValues[1]
/ countValues[3];
+ if (binaryMetrics.sumWeights == 0) {
+ recordValues[0] = accWeights[2] == 0 ? 1.0 : accWeights[0] /
accWeights[2];
+ recordValues[1] = accWeights[3] == 0 ? 1.0 : accWeights[1] /
accWeights[3];
recordValues[2] =
- countValues[0] + countValues[1] == 0
+ accWeights[0] + accWeights[1] == 0
? 1.0
- : 1.0 * countValues[0] / (countValues[0] +
countValues[1]);
- recordValues[3] =
- 1.0 * (countValues[0] + countValues[1]) / (countValues[2]
+ countValues[3]);
+ : accWeights[0] / (accWeights[0] + accWeights[1]);
+ recordValues[3] = (accWeights[0] + accWeights[1]) / (accWeights[2]
+ accWeights[3]);
}
- binaryMetrics.count++;
- if (cur.f1) {
- countValues[0]++;
+ boolean isPos = cur.f1;
+ double weight = cur.f2;
+ binaryMetrics.sumWeights += weight;
+ if (isPos) {
+ accWeights[0] += weight;
} else {
- countValues[1]++;
+ accWeights[1] += weight;
}
- double tpr = countValues[2] == 0 ? 1.0 : 1.0 * countValues[0] /
countValues[2];
- double fpr = countValues[3] == 0 ? 1.0 : 1.0 * countValues[1] /
countValues[3];
+ double tpr = accWeights[2] == 0 ? 1.0 : accWeights[0] / accWeights[2];
+ double fpr = accWeights[3] == 0 ? 1.0 : accWeights[1] / accWeights[3];
double precision =
- countValues[0] + countValues[1] == 0
+ accWeights[0] + accWeights[1] == 0
? 1.0
- : 1.0 * countValues[0] / (countValues[0] +
countValues[1]);
- double positiveRate =
- 1.0 * (countValues[0] + countValues[1]) / (countValues[2] +
countValues[3]);
+ : accWeights[0] / (accWeights[0] + accWeights[1]);
+ double positiveRate = (accWeights[0] + accWeights[1]) / (accWeights[2]
+ accWeights[3]);
+ binaryMetrics.areaUnderROC += (fpr - recordValues[1]) * (tpr +
recordValues[0]) / 2;
binaryMetrics.areaUnderLorenz +=
((positiveRate - recordValues[3]) * (tpr + recordValues[0]) /
2);
binaryMetrics.areaUnderPR += ((tpr - recordValues[0]) * (precision +
recordValues[2]) / 2);
@@ -451,65 +324,31 @@ public class BinaryClassificationEvaluator
recordValues[3] = positiveRate;
}
- /**
- * For each sample, calculates its score order among all samples. The
sample with minimum score
- * has order 1, while the sample with maximum score has order samples.
- *
- * <p>Input is a dataset of tuple (score, is real positive, weight),
output is a dataset of
- * tuple (score, order, is real positive, weight).
- */
- private static class CalcSampleOrders
- extends RichFlatMapFunction<
- Tuple3<Double, Boolean, Double>, Tuple4<Double, Long,
Boolean, Double>> {
- private long startIndex;
- private long total = -1;
- private final String partitionSummariesKey;
-
- public CalcSampleOrders(String partitionSummariesKey) {
- this.partitionSummariesKey = partitionSummariesKey;
- }
-
- @Override
- public void flatMap(
- Tuple3<Double, Boolean, Double> value,
- Collector<Tuple4<Double, Long, Boolean, Double>> out)
- throws Exception {
- if (total == -1) {
- List<BinarySummary> statistics =
-
getRuntimeContext().getBroadcastVariable(partitionSummariesKey);
- long[] countValues =
- reduceBinarySummary(
- statistics,
getRuntimeContext().getIndexOfThisSubtask());
- startIndex = countValues[1] + countValues[0] + 1;
- total = countValues[2] + countValues[3];
- }
- out.collect(Tuple4.of(value.f0, total - startIndex + 1, value.f1,
value.f2));
- startIndex++;
- }
- }
-
/**
* @param values Reduce Summary of all workers.
* @param taskId current taskId.
- * @return [curTrue, curFalse, TotalTrue, TotalFalse]
+ * @return An array storing sum of weights of positives/negatives of tasks
before the current
+ * one, and sum of weights of positives/negatives of all tasks.
*/
- private static long[] reduceBinarySummary(List<BinarySummary> values, int
taskId) {
+ private static double[] reduceBinarySummary(List<BinarySummary> values,
int taskId) {
List<BinarySummary> list = new ArrayList<>(values);
list.sort(Comparator.comparingDouble(t -> -t.maxScore));
- long curTrue = 0;
- long curFalse = 0;
- long totalTrue = 0;
- long totalFalse = 0;
+ double prefixSumWeightsPos = 0;
+ double prefixSumWeightsNeg = 0;
+ double totalSumWeightsPos = 0;
+ double totalSumWeightsNeg = 0;
for (BinarySummary statistics : list) {
if (statistics.taskId == taskId) {
- curFalse = totalFalse;
- curTrue = totalTrue;
+ prefixSumWeightsNeg = totalSumWeightsNeg;
+ prefixSumWeightsPos = totalSumWeightsPos;
}
- totalTrue += statistics.curPositive;
- totalFalse += statistics.curNegative;
+ totalSumWeightsPos += statistics.sumWeightsPos;
+ totalSumWeightsNeg += statistics.sumWeightsNeg;
}
- return new long[] {curTrue, curFalse, totalTrue, totalFalse};
+ return new double[] {
+ prefixSumWeightsPos, prefixSumWeightsNeg, totalSumWeightsPos,
totalSumWeightsNeg
+ };
}
/**
@@ -520,13 +359,16 @@ public class BinaryClassificationEvaluator
*/
private static void updateBinarySummary(
BinarySummary statistics, Tuple3<Double, Boolean, Double>
evalElement) {
- if (evalElement.f1) {
- statistics.curPositive++;
+ boolean isPos = evalElement.f1;
+ double weight = evalElement.f2;
+ double score = evalElement.f0;
+ if (isPos) {
+ statistics.sumWeightsPos += weight;
} else {
- statistics.curNegative++;
+ statistics.sumWeightsNeg += weight;
}
- if (Double.compare(statistics.maxScore, evalElement.f0) < 0) {
- statistics.maxScore = evalElement.f0;
+ if (Double.compare(statistics.maxScore, score) < 0) {
+ statistics.maxScore = score;
}
}
@@ -673,25 +515,26 @@ public class BinaryClassificationEvaluator
public Integer taskId;
// maximum score in this partition
public double maxScore;
- // real positives in this partition
- public long curPositive;
- // real negatives in this partition
- public long curNegative;
+ // sum of weights of positives in this partition
+ public double sumWeightsPos;
+ // sum of weights of negatives in this partition
+ public double sumWeightsNeg;
public BinarySummary() {}
- public BinarySummary(Integer taskId, double maxScore, long
curPositive, long curNegative) {
+ public BinarySummary(
+ Integer taskId, double maxScore, double sumWeightsPos, double
sumWeightsNeg) {
this.taskId = taskId;
this.maxScore = maxScore;
- this.curPositive = curPositive;
- this.curNegative = curNegative;
+ this.sumWeightsPos = sumWeightsPos;
+ this.sumWeightsNeg = sumWeightsNeg;
}
}
/** The evaluation metrics for binary classification. */
public static class BinaryMetrics {
- /* The count of samples. */
- public long count;
+ /* The sum of weights of samples. */
+ public double sumWeights;
/* Area under ROC */
public double areaUnderROC;
@@ -707,22 +550,19 @@ public class BinaryClassificationEvaluator
public BinaryMetrics() {}
- public BinaryMetrics(long count, double areaUnderROC) {
- this.count = count;
- this.areaUnderROC = areaUnderROC;
+ public BinaryMetrics(long sumWeights) {
+ this.sumWeights = sumWeights;
}
public BinaryMetrics merge(BinaryMetrics binaryClassMetrics) {
if (null == binaryClassMetrics) {
return this;
}
- Preconditions.checkState(
- Double.compare(areaUnderROC,
binaryClassMetrics.areaUnderROC) == 0,
- "AreaUnderROC not equal!");
- count += binaryClassMetrics.count;
- ks = Math.max(ks, binaryClassMetrics.ks);
- areaUnderPR += binaryClassMetrics.areaUnderPR;
+ sumWeights += binaryClassMetrics.sumWeights;
+ areaUnderROC += binaryClassMetrics.areaUnderROC;
areaUnderLorenz += binaryClassMetrics.areaUnderLorenz;
+ areaUnderPR += binaryClassMetrics.areaUnderPR;
+ ks = Math.max(ks, binaryClassMetrics.ks);
return this;
}
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java
index 0c146a3d..f6f38b29 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java
@@ -118,7 +118,8 @@ public class BinaryClassificationEvaluatorTest extends
AbstractTestBase {
new double[] {
0.8571428571428571, 0.9377705627705628, 0.8571428571428571,
0.6488095238095237
};
- private static final double EXPECTED_DATA_W = 0.8911680911680911;
+ private static final double[] EXPECTED_DATA_W =
+ new double[] {0.8717948717948718, 0.9510202726261435};
private static final double EPS = 1.0e-5;
@Before
@@ -297,14 +298,20 @@ public class BinaryClassificationEvaluatorTest extends
AbstractTestBase {
public void testEvaluateWithWeight() {
BinaryClassificationEvaluator eval =
new BinaryClassificationEvaluator()
-
.setMetricsNames(BinaryClassificationEvaluatorParams.AREA_UNDER_ROC)
+ .setMetricsNames(
+
BinaryClassificationEvaluatorParams.AREA_UNDER_ROC,
+
BinaryClassificationEvaluatorParams.AREA_UNDER_PR)
.setWeightCol("weight");
Table evalResult = eval.transform(inputDataTableWithWeight)[0];
- List<Row> results =
IteratorUtils.toList(evalResult.execute().collect());
+ Row result = (Row)
IteratorUtils.toList(evalResult.execute().collect()).get(0);
assertArrayEquals(
- new String[]
{BinaryClassificationEvaluatorParams.AREA_UNDER_ROC},
+ new String[] {
+ BinaryClassificationEvaluatorParams.AREA_UNDER_ROC,
+ BinaryClassificationEvaluatorParams.AREA_UNDER_PR
+ },
evalResult.getResolvedSchema().getColumnNames().toArray());
- assertEquals(EXPECTED_DATA_W, results.get(0).getFieldAs(0), EPS);
+ assertArrayEquals(
+ EXPECTED_DATA_W, new double[] {result.getFieldAs(0),
result.getFieldAs(1)}, EPS);
}
@Test
diff --git
a/flink-ml-python/pyflink/ml/evaluation/tests/tests_binaryclassification.py
b/flink-ml-python/pyflink/ml/evaluation/tests/tests_binaryclassification.py
index 7f34eb4c..55a8f2b9 100644
--- a/flink-ml-python/pyflink/ml/evaluation/tests/tests_binaryclassification.py
+++ b/flink-ml-python/pyflink/ml/evaluation/tests/tests_binaryclassification.py
@@ -16,10 +16,10 @@
# limitations under the License.
################################################################################
import os
-
from pyflink.common import Types
-from pyflink.ml.linalg import Vectors, DenseVectorTypeInfo
+
from pyflink.ml.evaluation.binaryclassification import
BinaryClassificationEvaluator
+from pyflink.ml.linalg import Vectors, DenseVectorTypeInfo
from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
@@ -111,7 +111,7 @@ class BinaryClassificationEvaluatorTest(PyFlinkMLTestCase):
self.expected_data_m = [0.8571428571428571, 0.9377705627705628,
0.8571428571428571, 0.6488095238095237]
- self.expected_data_w = 0.8911680911680911
+ self.expected_data_w = [0.8717948717948718, 0.9510202726261435]
self.eps = 1e-5
@@ -185,11 +185,11 @@ class
BinaryClassificationEvaluatorTest(PyFlinkMLTestCase):
def test_evaluate_with_weight(self):
evaluator = BinaryClassificationEvaluator() \
- .set_metrics_names("areaUnderROC") \
+ .set_metrics_names("areaUnderROC", "areaUnderPR") \
.set_weight_col("weight")
output = evaluator.transform(self.input_data_table_with_weight)[0]
self.assertEqual(
- ["areaUnderROC"],
+ ["areaUnderROC", "areaUnderPR"],
output.get_schema().get_field_names())
results = [result for result in output.execute().collect()]
result = results[0]